## Setup ##

In [1]:
import numpy as np
import os
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup
import torch
from torch import nn
import sys
import torch.nn.functional as F
import argparse
import re
import pickle as pk
from tqdm import tqdm
from scipy.special import softmax
from sklearn.decomposition import PCA
from shutil import copyfile
from sklearn.mixture import GaussianMixture
from sklearn.mixture._gaussian_mixture import _estimate_gaussian_parameters
from nltk.tokenize import sent_tokenize, word_tokenize
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import confusion_matrix, f1_score
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Args(object):
    pass

args = Args()
args.dataset_name = "20News"
args.gpu = 7
args.pca = 128
args.random_state = 42
args.emb_dim = 768
args.num_heads = 2
args.batch_size = 64
args.temp = 0.2
args.lr = 1e-3
args.epochs = 5
args.accum_steps = 1
args.max_sent = 150

os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

In [3]:
data_path = os.path.join("/shared/data2/pk36/multidim/multigran", args.dataset_name)
new_data_path = os.path.join("/home/pk36/MEGClass/intermediate_data", args.dataset_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# this is where we save all representations
if not os.path.exists(new_data_path):
    os.makedirs(new_data_path)

## Helper Functions ##

In [4]:
DATA_FOLDER_PATH = os.path.join("/shared/data2/pk36/multidim/multigran")
INTERMEDIATE_DATA_FOLDER_PATH = os.path.join("/home/pk36/MEGClass/intermediate_data")

def tensor_to_numpy(tensor):
    return tensor.clone().detach().cpu().numpy()

# Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

def sentenceToClass(sent_repr, class_repr, weights):
    # sent_repr: N x S x E
    # class_repr: C x E
    # weights: N x S # equals 0 for masked sentences

    #cos-sim between (N x S) x E and (C x E) = N x S x C
    m, n = sent_repr.shape[:2]
    sentcos = cosine_similarity(sent_repr.reshape(m*n,-1), class_repr).reshape(m,n,-1)
    sent_to_class = np.argmax(sentcos, axis=2) # N x S
    sent_to_doc_class = np.sum(np.multiply(sent_to_class, weights), axis=1) # N x 1
    return sent_to_doc_class

def docToClass(doc_repr, class_repr):
    # doc_repr: N x E
    # class_repr: C x E

    #cos-sim between N x E and C x E = N x C
    doccos = cosine_similarity(doc_repr, class_repr)
    doc_to_class = np.argmax(doccos, axis=1) # N x 1
    return doc_to_class

def evaluate_predictions(true_class, predicted_class, output_to_console=True, return_tuple=False, return_confusion=False):
    confusion = confusion_matrix(true_class, predicted_class)
    if return_confusion and output_to_console:
        print("-" * 80 + "Evaluating" + "-" * 80)
        print(confusion)
    
    f1_micro = f1_score(true_class, predicted_class, average='micro')
    f1_macro = f1_score(true_class, predicted_class, average='macro')
    if output_to_console:
        print("F1 micro: " + str(f1_micro))
        print("F1 macro: " + str(f1_macro))
    if return_tuple:
        return confusion, f1_macro, f1_micro
    else:
        return {
            "confusion": confusion.tolist(),
            "f1_micro": f1_micro,
            "f1_macro": f1_macro
        }
    
def getSentClassRepr(args):
    with open(os.path.join("/home/pk36/XClass/data/intermediate_data", args.dataset_name, f"document_repr_lm-bbu-12-mixture-plm.pk"), "rb") as f:
        dictionary = pk.load(f)
        class_repr = dictionary["class_representations"]
        sent_repr = dictionary["sent_representations"]
    return sent_repr, class_repr


def getDSMapAndGold(args, sent_dict):
    # get the ground truth labels for all documents and assign a "ground truth" label to each sentence based on its parent document
    gold_labels = list(map(int, open(os.path.join("/shared/data2/pk36/multidim/multigran", args.dataset_name, "labels.txt"), "r").read().splitlines()))
    gold_sent_labels = []
    # get all sent ids for each doc
    doc_to_sent = []
    sent_id = 0
    for doc_id, doc in enumerate(sent_dict.values()):
        sent_ids = []
        for sent in doc:
            sent_ids.append(sent_id)
            gold_sent_labels.append(gold_labels[doc_id])
            sent_id += 1
        doc_to_sent.append(sent_ids)
            
    return gold_labels, gold_sent_labels, doc_to_sent

## BERT-Based Sentence Embeddings, Initial Doc Embeddings, & Class Representations ##

In [21]:
def sentenceEmb(args, sent_dict, doc_to_sent, class_words, device, classonly=False):
    # Load model from HuggingFace Hub
    tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
    model = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')
    model = model.to(device)
    
    num_docs = len(doc_to_sent)
    padded_sent_repr = np.zeros((num_docs, args.max_sent, args.emb_dim))
    sentence_mask = np.ones((num_docs, args.max_sent))
    doc_lengths = np.zeros(num_docs, dtype=int)
    trimmed = 0

    if not classonly:
        for doc_id in tqdm(np.arange(num_docs)):
            sents = sent_dict[str(doc_id)]
            num_sent = len(sents)
            if num_sent > args.max_sent:
                trimmed += 1
                sents = sents[:args.max_sent]
            encoded_input = tokenizer(sents, padding=True, truncation=True, return_tensors='pt')
            encoded_input = encoded_input.to(device)

            # Compute token embeddings
            with torch.no_grad():
                model_output = model(**encoded_input)

            # Perform pooling
            embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

            # Normalize embeddings
            embeddings = tensor_to_numpy(F.normalize(embeddings, p=2, dim=1))

            # save the number of sentences in each document
            doc_lengths[doc_id] = int(num_sent)

            padded_sent_repr[doc_id, :embeddings.shape[0], :] = embeddings
            # Update mask so that padded sentences are not included in attention computation
            sentence_mask[doc_id, :num_sent] = 0
        print(f"Trimmed Documents: {trimmed}")

    # construct class representations
    class_repr = np.zeros((len(class_words), args.emb_dim))
    intro_map = {"20News":"This article is about ", "agnews": "This article is about ",
                "yelp": "This is ", "nyt-coarse":"This article is about ", 
                "nyt-fine": "This article is about "}

    print("Constructing Class Representations...")
    for class_id in tqdm(np.arange(len(class_words))):
        # class_sent = intro_map[args.dataset_name] + ", ".join(class_words[class_id])
        class_sent = intro_map[args.dataset_name] + class_words[class_id][0]
        print(class_sent)
        encoded_input = tokenizer(class_sent, truncation=True, return_tensors='pt')
        encoded_input = encoded_input.to(device)
         # Compute token embeddings
        with torch.no_grad():
            model_output = model(**encoded_input)
        # Perform pooling
        embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
        # Normalize embeddings
        embeddings = tensor_to_numpy(F.normalize(embeddings, p=2, dim=1))
        class_repr[class_id, :] = embeddings
    
    if classonly:
        return class_repr
    else:
        return padded_sent_repr, class_repr, doc_lengths, sentence_mask

In [6]:
def bertSentenceEmb(args, doc_to_sent, sent_repr):
    num_docs = len(doc_to_sent)
    doc_lengths = np.zeros(num_docs, dtype=int)
    # init_doc_repr = np.zeros((num_docs, args.emb_dim))
    padded_sent_repr = np.zeros((num_docs, args.max_sent, args.emb_dim))
    sentence_mask = np.ones((num_docs, args.max_sent))
    trimmed = 0


    for doc_id in tqdm(np.arange(num_docs)):
        start_sent = doc_to_sent[doc_id][0]
        end_sent = doc_to_sent[doc_id][-1]
        num_sent = end_sent - start_sent + 1
        if num_sent > args.max_sent:
            end_sent = start_sent + args.max_sent - 1
            num_sent = args.max_sent
            trimmed += 1
        embeddings = sent_repr[start_sent:end_sent+1]

        # save the number of sentences in each document
        doc_lengths[doc_id] = int(num_sent)

        # Add initial doc representation
        # init_doc_repr[doc_id, :] = np.mean(embeddings, axis=0)
        # Add padded sentences
        padded_sent_repr[doc_id, :embeddings.shape[0], :] = embeddings
        # Update mask so that padded sentences are not included in attention computation
        sentence_mask[doc_id, :num_sent] = 0

    
    print(f"Trimmed Documents: {trimmed}")

    return padded_sent_repr, doc_lengths, sentence_mask

### Get Class Weights ###

In [7]:
def getTargetClasses(padded_sent_repr, doc_lengths, class_repr, weights=None):
    # weights: N x 150
    class_weights = np.zeros((padded_sent_repr.shape[0], class_repr.shape[0])) # N x C
    sent_weights = np.zeros(padded_sent_repr.shape[:2])

    for doc_id in tqdm(np.arange(padded_sent_repr.shape[0])):
        l = doc_lengths[doc_id]
        sent_emb = padded_sent_repr[doc_id, :l, :] # S x E
        sentcos = cosine_similarity(sent_emb, class_repr) # S x C
        sent_to_class = np.argmax(sentcos, axis=1) # S
        
        # default: equal vote weight between all sentences
        if weights is None:
            # w = np.ones(doc_lengths[doc_id])/doc_lengths[doc_id]
            # top cos-sim - second cos-sim
            toptwo = np.partition(sentcos, -2)[:, -2:] # S x 2
            toptwo = toptwo[:, 1] - toptwo[:, 0] # S
            w = toptwo / np.sum(toptwo)
            sent_weights[doc_id, :l] = w
        else:
            w = weights[doc_id, :l]
        
        class_weights[doc_id, :] = np.bincount(sent_to_class, weights=w, minlength=class_repr.shape[0])

    if weights is None:
        return class_weights, sent_weights
    else:
        return class_weights

### Get initial embeddings and class representations + gold labels ###

In [8]:
with open(os.path.join("/home/pk36/XClass/data/intermediate_data", args.dataset_name, "dataset.pk"), "rb") as f:
    dataset = pk.load(f)
    sent_dict = dataset["sent_data"]
    cleaned_text = dataset["cleaned_text"]
    class_names = np.array(dataset["class_names"])
with open(os.path.join("/home/pk36/XClass/data/intermediate_data", args.dataset_name, "document_repr_lm-bbu-12-mixture-plm.pk"), "rb") as f:
    reprpickle = pk.load(f)
    class_words = reprpickle["class_words"]
gold_labels, gold_sent_labels, doc_to_sent = getDSMapAndGold(args, sent_dict)
num_classes = len(class_words)

In [22]:
sent_repr, class_repr = getSentClassRepr(args)
padded_sent_repr, doc_lengths, sentence_mask = bertSentenceEmb(args, doc_to_sent, sent_repr)
# plm_padded_sent_repr, plm_class_repr, doc_lengths, plm_sentence_mask = sentenceEmb(args, sent_dict, doc_to_sent, class_words, device)
plm_class_repr = sentenceEmb(args, sent_dict, doc_to_sent, class_words, device, True)
init_class_weights, init_sent_weights = getTargetClasses(padded_sent_repr, doc_lengths, class_repr, None)
init_plm_class_weights, init_plm_sent_weights = getTargetClasses(plm_padded_sent_repr, doc_lengths, plm_class_repr, None)

padded_sent_repr.shape, doc_lengths.shape, sentence_mask.shape, init_class_weights.shape, init_sent_weights.shape

100%|██████████| 17871/17871 [00:00<00:00, 27721.91it/s]


Trimmed Documents: 127
Constructing Class Representations...


100%|██████████| 5/5 [00:00<00:00, 132.66it/s]


This article is about computer
This article is about sports
This article is about science
This article is about politics
This article is about religion


100%|██████████| 17871/17871 [00:11<00:00, 1547.60it/s]
100%|██████████| 17871/17871 [00:07<00:00, 2485.23it/s]


((17871, 150, 768), (17871,), (17871, 150), (17871, 5), (17871, 150))

In [25]:
cleaned_text[1]

"Re Lyme vaccine Nntp-Posting-Host ucrengr X-Newsreader TIN version 1.1 PL8 Jeff, If you have time to type it in I'd love to have the reference for that paper! thanks! -- kathleen richards email Sometimes you're the windshield, sometimes you're the bug! -dire straits"

In [26]:
init_plm_class_weights[1], gold_labels[1]

(array([0.        , 0.22120812, 0.77205911, 0.00673277, 0.        ]), 2)

### Analyze and Evaluate Class Weights ###

In [27]:
evaluate_predictions(gold_labels, np.argmax(init_plm_class_weights, axis=1))

F1 micro: 0.6967153488892619
F1 macro: 0.6734028008214983


{'confusion': [[4646, 71, 86, 40, 48],
  [540, 2558, 243, 460, 178],
  [1550, 183, 1179, 784, 256],
  [51, 55, 44, 1886, 589],
  [22, 16, 48, 156, 2182]],
 'f1_micro': 0.6967153488892619,
 'f1_macro': 0.6734028008214983}

## Model ##

In [28]:
class MEGClassModel(nn.Module):
    def __init__(self, D_in, D_hidden, head, dropout=0.0):
        super(MEGClassModel, self).__init__()
        self.mha = nn.MultiheadAttention(embed_dim=D_in, num_heads=head, dropout=dropout, batch_first=True)
        self.layernorm = nn.LayerNorm(D_in)
        # self.embd = nn.Sequential(
        #     nn.Linear(D_in, 2*D_in),
        #     nn.ReLU(inplace=True),
        #     nn.Linear(2*D_in, D_in))
        self.embd = nn.Linear(D_in,D_hidden)
        self.attention = nn.Linear(D_hidden,1)
        
    def forward(self, x_org, mask=None):
        x, mha_w = self.mha(x_org,x_org,x_org,key_padding_mask=mask)
        x = self.layernorm(x_org+x)
        
        x = self.embd(x)
        x = torch.tanh(x) # contextualized sentences
        a = self.attention(x)
        if mask is not None:
            a = a.masked_fill_((mask == 1).unsqueeze(-1), float('-inf'))
        w = torch.softmax(a, dim=1) # alpha_k
        o = torch.matmul(w.permute(0,2,1), x) #doc 
        return o, mha_w, w, x # contextualized doc, multi-head attention weights, alpha_k, contextualized sent


## Contrastive Loss ##

In [29]:
def contrastive_loss(sample_outputs, class_indices, class_embds, temp=0.2):
    k = torch.exp(torch.nn.functional.cosine_similarity(sample_outputs[:,None], class_embds, axis=2)/temp)
    loss = -1*(torch.log(k[np.arange(len(class_indices)),class_indices]/k.sum(1))).sum()
    return loss/len(sample_outputs)

def weighted_contrastive_loss(args, sample_outputs, class_weights, class_embds):
    # k: B x C, class_weights: B x C
    numerator = torch.exp(torch.nn.functional.cosine_similarity(sample_outputs[:,None], class_embds, axis=2)/args.temp)
    denom = torch.exp(torch.nn.functional.cosine_similarity(sample_outputs[:,None], class_embds, axis=2)/args.temp)
    weighted_loss = -1 * (torch.log(numerator/(denom.sum(dim=1).unsqueeze(-1))) * class_weights).sum() # B x C -> B
    # print(torch.nn.functional.cosine_similarity(sample_outputs[:,None], class_embds, axis=2)[:5])
    # print(numerator[:5])
    # print(denom[:5])
    # print((torch.log(numerator/(denom.sum(dim=1).unsqueeze(-1))) * class_weights)[:5])
    # 1/0
    return weighted_loss/len(sample_outputs)

def weighted_class_contrastive_loss(args, sample_outputs, class_weights, class_embds):
    # numerator: B x C, class_weights: B x C
    numerator = torch.exp(torch.nn.functional.cosine_similarity(sample_outputs[:,None], class_embds, axis=2)/args.temp)
    denom = torch.exp(torch.nn.functional.cosine_similarity(sample_outputs[:,None], class_embds, axis=2)/args.temp) # * (1 - class_weights)
    weighted_loss = -1 * (torch.log(numerator/(denom.sum(dim=1).unsqueeze(-1))) * class_weights).sum() # B x C -> B
    # print(torch.nn.functional.cosine_similarity(sample_outputs[:,None], class_embds, axis=2)[:5])
    # print(numerator[:5])
    # print(denom.sum(dim=1)[:5])
    # print((-1 * (torch.log(numerator/(denom.sum(dim=1).unsqueeze(-1))) * class_weights))[:5])
    # 1/0
    return weighted_loss/len(sample_outputs)


## Train ##

In [30]:
def contextEmb(args, sent_representations, mask, class_repr, class_weights, 
                doc_lengths, new_data_path, device):
    sent_representations = torch.from_numpy(sent_representations)
    mask = torch.from_numpy(mask).to(torch.bool)
    class_weights = torch.from_numpy(class_weights)
    dataset = TensorDataset(sent_representations, mask, class_weights)
    sampler = SequentialSampler(dataset)
    dataset_loader = DataLoader(dataset, sampler=sampler, batch_size=args.batch_size, shuffle=False)
    # sent_representations: N docs x L sentences x 768 emb (L with padding is always max_sents=50)
    model = MEGClassModel(args.emb_dim, args.emb_dim, args.num_heads).to(device)

    total_steps = len(dataset_loader) * args.epochs / args.accum_steps
    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, eps=1e-8)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0.1*total_steps, num_training_steps=total_steps)

    print("Starting to train!")

    for i in tqdm(range(args.epochs)):
        total_train_loss = 0

        for batch in tqdm(dataset_loader):
            model.train()
            input_emb = batch[0].to(device).float()
            input_mask = batch[1].to(device)
            input_weights = batch[2].to(device).float()
            
            c_doc, _, alpha, c_sent = model(input_emb, mask=input_mask)
            c_doc = c_doc.squeeze(1)

            loss = weighted_class_contrastive_loss(args, c_doc, input_weights, torch.from_numpy(class_repr).float().to(device)) / args.accum_steps

            total_train_loss += loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        scheduler.step()

        avg_train_loss = total_train_loss / len(dataset_loader)
        print(f"Average training loss: {avg_train_loss}")

    model.eval()

    torch.save(model.state_dict(), os.path.join(new_data_path, f"{args.dataset_name}_model_e{args.epochs}.pth"))

    print("Starting to evaluate!")

    evalsampler = SequentialSampler(dataset)
    eval_loader = DataLoader(dataset, sampler=evalsampler, batch_size=args.batch_size, shuffle=False)

    doc_predictions = None
    attention_weights = np.zeros_like(mask, dtype=float)
    updated_sent_repr = np.zeros_like(sent_representations)
    final_doc_emb = np.zeros((len(class_weights), args.emb_dim))
    idx = 0

    with torch.no_grad():
        for batch in tqdm(eval_loader):
            input_emb = batch[0].to(device).float()
            input_mask = batch[1].to(device)

            c_doc, _, alpha, c_sent = model(input_emb, mask=input_mask)
            c_doc = c_doc.squeeze(1)
            c_sent, c_doc, alpha = tensor_to_numpy(c_sent), tensor_to_numpy(c_doc), tensor_to_numpy(torch.squeeze(alpha, dim=2))

            final_doc_emb[idx:idx+c_doc.shape[0], :] = c_doc
            attention_weights[idx:idx+c_doc.shape[0], :] = alpha
            updated_sent_repr[idx:idx+c_doc.shape[0], :, :] = c_sent

            idx += c_doc.shape[0]

            doc_class = docToClass(c_doc, class_repr)
            if doc_predictions is None:
                doc_predictions = doc_class
            else:
                doc_predictions = np.append(doc_predictions, doc_class)
    
    updated_class_weights = getTargetClasses(updated_sent_repr, doc_lengths, class_repr, attention_weights)

    return doc_predictions, final_doc_emb, updated_sent_repr, attention_weights, updated_class_weights

## Evaluate ##

In [39]:
# Sentence Transformer Embeddings + Sentence-Transformer Initial Class Distributions
args.epochs = 3
args.lr = 1e-3
args.temp = 0.1
doc_to_class, final_doc_emb, updated_sent_repr, updated_sent_weights, updated_class_weights = contextEmb(args, plm_padded_sent_repr, plm_sentence_mask, plm_class_repr, init_plm_class_weights, doc_lengths, new_data_path, device)



Starting to train!


100%|██████████| 280/280 [00:09<00:00, 28.12it/s]
 33%|███▎      | 1/3 [00:09<00:19,  9.96s/it]

Average training loss: 1.5781413316726685


100%|██████████| 280/280 [00:08<00:00, 34.05it/s]
 67%|██████▋   | 2/3 [00:18<00:08,  8.95s/it]

Average training loss: 1.086656928062439


100%|██████████| 280/280 [00:09<00:00, 29.36it/s]
100%|██████████| 3/3 [00:27<00:00,  9.25s/it]


Average training loss: 0.7707670331001282
Starting to evaluate!


100%|██████████| 280/280 [00:16<00:00, 16.50it/s]
100%|██████████| 17871/17871 [00:06<00:00, 2644.86it/s]


In [40]:
# epochs: 3
doc_pred = np.rint(doc_to_class)
print("Evaluate Predictions (Document-Based): ")
evaluate_predictions(gold_labels, doc_pred)

Evaluate Predictions (Document-Based): 
F1 micro: 0.718146718146718
F1 macro: 0.6975556787502244


{'confusion': [[4735, 58, 38, 31, 29],
  [542, 2697, 157, 474, 109],
  [1551, 140, 1285, 816, 160],
  [42, 27, 29, 2002, 525],
  [33, 16, 29, 231, 2115]],
 'f1_micro': 0.718146718146718,
 'f1_macro': 0.6975556787502244}

In [49]:
# Sentence Transformer Embeddings + Class-Oriented Initial Class Distributions
args.epochs = 3
args.lr = 1e-3
args.temp = 0.1
doc_to_class, final_doc_emb, updated_sent_repr, updated_sent_weights, updated_class_weights = contextEmb(args, plm_padded_sent_repr, plm_sentence_mask, plm_class_repr, init_class_weights, doc_lengths, new_data_path, device)



Starting to train!


100%|██████████| 280/280 [00:06<00:00, 43.69it/s]
 33%|███▎      | 1/3 [00:06<00:12,  6.41s/it]

Average training loss: 1.6032929420471191


100%|██████████| 280/280 [00:06<00:00, 42.62it/s]
 67%|██████▋   | 2/3 [00:12<00:06,  6.51s/it]

Average training loss: 1.1564592123031616


100%|██████████| 280/280 [00:06<00:00, 43.52it/s]
100%|██████████| 3/3 [00:19<00:00,  6.48s/it]


Average training loss: 0.8360235095024109
Starting to evaluate!


100%|██████████| 280/280 [00:10<00:00, 26.27it/s]
100%|██████████| 17871/17871 [00:06<00:00, 2863.74it/s]


In [50]:
# epochs: 3
doc_pred = np.rint(doc_to_class)
print("Evaluate Predictions (Document-Based): ")
evaluate_predictions(gold_labels, doc_pred)

Evaluate Predictions (Document-Based): 
F1 micro: 0.7877007442224834
F1 macro: 0.7758593030310905


{'confusion': [[4753, 24, 85, 25, 4],
  [228, 3499, 85, 162, 5],
  [1450, 116, 1846, 527, 13],
  [50, 84, 154, 2046, 291],
  [58, 24, 160, 249, 1933]],
 'f1_micro': 0.7877007442224834,
 'f1_macro': 0.7758593030310905}

In [238]:
doc_id = 7
print("Initial Class Weights: ", init_class_weights[doc_id])
print("Updated Class Weights: ", updated_class_weights[doc_id])
print("True Class: ", gold_labels[doc_id])
tok_sents = sent_dict[str(doc_id)]
chosen_class = np.argmax(cosine_similarity(updated_sent_repr[doc_id, :len(tok_sents)], class_repr), axis=1)
sent_classes = cosine_similarity(updated_sent_repr[doc_id, :len(tok_sents)], class_repr)

for s,c,cw,w in zip(tok_sents, class_names[chosen_class.astype(int)], sent_classes, updated_sent_weights[doc_id, :len(tok_sents)]):
    print(f'{c} ({cw}) {w}: {s}')

Initial Class Weights:  [0.35916079 0.64083921]
Updated Class Weights:  [0.09076648 0.90923354]
True Class:  0
good ([-0.067625    0.19165149]) 0.029811829328536987: ok!
bad ([-0.0163843  -0.03797249]) 0.09076648205518723: let me tell you about my bad experience first.
good ([-0.17282255  0.13541248]) 0.021152541041374207: i went to d b last night for a post wedding party - which, side note, is a great idea!
good ([-0.00954034  0.01940415]) 0.07051163166761398: it was around midnight and the bar wasn't really populated.
good ([-0.01856739  0.03315905]) 0.04901992529630661: there were three bartenders and only one was actually making rounds to see if anyone needed anything.
good ([-0.11018234  0.01576056]) 0.03178786486387253: the two other bartenders were chatting on the far side of the bar that no one was sitting at.
good ([0.00839888 0.05479983]) 0.07281240820884705: kind of counter productive if you ask me.
good ([-0.02061332  0.01664088]) 0.036122485995292664: i stood there for abo

## Generate Pseudo-Training Dataset ##

In [None]:
_pca = PCA(n_components=args.pca, random_state=args.random_state)
pca_doc_repr = _pca.fit_transform(final_doc_emb)
pca_class_repr = _pca.transform(class_repr)
print(f"Explained document variance: {sum(_pca.explained_variance_ratio_)}")
cosine_similarities = cosine_similarity(pca_doc_repr, pca_class_repr)
doc_class_assignment = np.argmax(cosine_similarities, axis=1)
doc_class_probs = cosine_similarities[np.arange(pca_doc_repr.shape[0]), doc_class_assignment]

print("Evaluate Document Cosine Similarity Predictions: ")
evaluate_predictions(gold_labels, doc_class_assignment)

# get cleaned text
cleaned_text = dataset["cleaned_text"]

# generate pseudo training dataset
generateDataset(doc_class_assignment, doc_class_probs, num_classes, cleaned_text, gold_labels, data_path, new_data_path)

In [104]:
topsentids = np.argmax(updated_sent_weights, axis=1)
for idx in np.arange(10):
    print(sent_dict[str(idx)][topsentids[idx]])

secretary_general_kofi_annan said today that iraq has cooperated up to now with united_nations weapons inspectors , so it would be ''premature'' to take military action before inspectors report back on their investigations to the security_council on jan . 27 .
of all those expelled , mr . hogan is putting up the biggest fight against the restrictions , asking a federal_judge to overrule the board and mounting a public_relations battle .
the school plans to offer scholarships for families who cannot afford tuition , ms . friedman said .
''we 've had to deal with very serious issues , '' mr . mccartney said .
it also came as china has worked hard to strengthen ties with the bush_administration .
''the american administration is trying to create some pretexts to attack iraq , to exercise their aggression against iraq , '' he said .
an adviser to mr . edwards said that campaign had deliberately chosen the day after new year 's to go on ''today , '' because the senator 's aides figured , co

In [121]:
# Find really confident sentences
def getConfidentSent(weights, sent_emb, class_repr):
    topsentids = np.argmax(weights, axis=1)
    topembs = np.zeros((sent_emb.shape[0], sent_emb.shape[2]))
    for doc_id in np.arange(sent_emb.shape[0]):
        topembs[doc_id, :] = sent_emb[doc_id, topsentids[doc_id], :]

    topcossim = cosine_similarity(topembs, class_repr) # N x C
    toppreds = np.argmax(topcossim, axis=1)
    class_sents = [[] for c in range(class_repr.shape[0])]

    for id, p in enumerate(toppreds):
        # doc id, sent id, sent weight, sent emb
        class_sents[p].append((id, topsentids[id], weights[id, topsentids[id]], topembs[id]))
    for c in np.arange(len(class_sents)):
        total = int(len(class_sents[c])*0.05)
        class_sents[c] = sorted(class_sents[c], key=lambda x: x[2], reverse=True)[:total]

    return class_sents

In [133]:
class_sents = getConfidentSent(init_class_weights, padded_sent_repr, class_repr)

In [134]:
for i in range(len(class_sents)):
    print(len(class_sents[i]))
    top = class_sents[i][1]
    print(class_sents[i][0][2], class_sents[i][1][2])
    print(sent_dict[str(top[0])][top[1]])

518
1.0000000000000002 1.0000000000000002
the wpp group 's acquisition of the struggling british advertising company cordiant communications group may not be a done deal after all .
625
1.0000000000000007 1.0000000000000004
he told them , aides said , that he wanted to hold a full scale news conference a few hours later .
111
1.0000000000000002 1.0000000000000002
kiplagat , born in kenya , finished third in new york last year .
69
1.0000000000000002 1.0000000000000002
if a scheduled appointment was missed , the health visitor would make a home visit to make sure all was well and to ensure that appointments for checkups and immunizations were kept .
75
1.0000000000000002 1.0000000000000002
more than half of private_school students attend catholic schools , while about a third are enrolled in other religious schools .
85
1.0000000000000002 1.0000000000000002
ft . co op in a loft building elevator , dining_room , den , office , high ceilings , renovated_kitchen , 3 exposures maintenance 1

In [62]:
# epochs: 4, class-oriented + new loss
doc_pred = np.rint(doc_to_class)
print("Evaluate Predictions (Document-Based): ")
evaluate_predictions(gold_labels, doc_pred)

Evaluate Predictions (Document-Based): 
F1 micro: 0.7704210526315789
F1 macro: 0.7610170669333791


{'confusion': [[10869, 8131], [593, 18407]],
 'f1_micro': 0.7704210526315789,
 'f1_macro': 0.7610170669333791}

In [None]:
tensor([[-0.7307, -0.7865],
        [-1.4249, -0.0000],
        [-1.0388, -0.4371],
        [-0.7478, -0.8558],
        [-0.3418, -1.1477]]

In [21]:
init_class_weights

array([[0.47846081, 0.52153919],
       [1.        , 0.        ],
       [0.7036531 , 0.2963469 ],
       ...,
       [0.16955682, 0.83044318],
       [1.        , 0.        ],
       [0.71302977, 0.28697023]])

In [68]:
(doc_pred == gold_labels)[:10], (gold_labels == np.argmax(init_class_weights, axis=1))[:10]

(array([False,  True, False, False,  True,  True,  True,  True,  True,
         True]),
 array([ True,  True, False, False,  True,  True,  True,  True,  True,
         True]))

In [75]:
init_sent_weights.shape

(38000, 150)

In [147]:
# Second Iteration:
doc_to_class, final_doc_emb, updated_sent_repr, updated_sent_weights, updated_class_weights = contextEmb(args, updated_sent_repr, sentence_mask, class_repr, updated_class_weights, doc_lengths, new_data_path, device)



Starting to train!


100%|██████████| 500/500 [00:17<00:00, 28.60it/s]
 20%|██        | 1/5 [00:17<01:09, 17.49s/it]

Average training loss: 2.2325527667999268


100%|██████████| 500/500 [00:17<00:00, 27.78it/s]
 40%|████      | 2/5 [00:35<00:53, 17.80s/it]

Average training loss: 0.7424776554107666


100%|██████████| 500/500 [00:18<00:00, 27.63it/s]
 60%|██████    | 3/5 [00:53<00:35, 17.95s/it]

Average training loss: 0.3872785270214081


100%|██████████| 500/500 [00:17<00:00, 28.08it/s]
 80%|████████  | 4/5 [01:11<00:17, 17.90s/it]

Average training loss: 0.3500744700431824


100%|██████████| 500/500 [00:18<00:00, 27.44it/s]
100%|██████████| 5/5 [01:29<00:00, 17.94s/it]


Average training loss: 0.3379160463809967
Starting to evaluate!


100%|██████████| 500/500 [00:19<00:00, 25.61it/s]
100%|██████████| 31997/31997 [00:15<00:00, 2073.08it/s]


Evaluate Predictions (Document-Based): 
F1 micro: 0.8020439416195269
F1 macro: 0.6456836188505892


### Analyze Results ###

In [10]:
with open(os.path.join("/home/pk36/XClass/data/intermediate_data", args.dataset_name, "document_repr_lm-bbu-12-mixture-plm.pk"), "rb") as f:
    reprpickle = pk.load(f)
    class_words = reprpickle["class_words"]

In [12]:
class_words[1]

['good',
 'great',
 'excellent',
 'fantastic',
 'terrific',
 'wonderful',
 'superb',
 'amazing',
 'incredible',
 'fabulous',
 'awesome',
 'outstanding',
 'stellar',
 'phenomenal',
 'marvelous',
 'exceptional',
 'excellently',
 'splendid',
 'tremendous',
 'delicious',
 'superbly',
 'wonderfully',
 'perfect',
 'fantastically',
 'delightful',
 'lovely',
 'fabulously',
 'awesomeness',
 'greatness',
 'gorgeous',
 'magnificent',
 'spectacular',
 'extraordinary',
 'beautifully',
 'brilliant',
 'beautiful',
 'sublime',
 'stunning',
 'amazingly',
 'sensational',
 'masterful',
 'perfectly',
 'exquisite',
 'divine',
 'best',
 'deliciously',
 'fine',
 'decent',
 'superlative',
 'glorious',
 'amazingness',
 'awesomely',
 'impeccable',
 'heavenly',
 'adequate',
 'tasty',
 'deliciousness',
 'yummy',
 'pleasing',
 'delectable',
 'nice',
 'passable',
 'unique',
 'tasteful',
 'kickass',
 'delightfully',
 'nicely',
 'perfection',
 'wholesome',
 'interesting',
 'impressive',
 'abundant',
 'decently',
 'se

In [38]:
correct_class = np.argmax(updated_class_weights, axis=1) == gold_labels
correct_class[20:40]

array([ True, False,  True,  True,  True,  True, False, False,  True,
        True, False,  True,  True,  True,  True,  True,  True,  True,
        True,  True])

In [150]:
init_class_weights[26], gold_labels[26], updated_class_weights[26]

(array([0.10631772, 0.52813451, 0.        , 0.29934825, 0.        ,
        0.06619952, 0.        , 0.        , 0.        ]),
 3,
 array([0.04851887, 0.66976335, 0.        , 0.28171776, 0.        ,
        0.        , 0.        , 0.        , 0.        ]))

## Run PCA on Document Embeddings & Fit GMM ##

In [51]:
_pca = PCA(n_components=args.pca, random_state=args.random_state)
pca_doc_repr = _pca.fit_transform(final_doc_emb)
# pca_doc_repr = _pca.fit_transform(np.sum(padded_sent_repr, axis=1)/doc_lengths.reshape((-1, 1)))
pca_class_repr = _pca.transform(plm_class_repr)
print(f"Explained document variance: {sum(_pca.explained_variance_ratio_)}")

Explained document variance: 0.8698887026923976


In [52]:
# class-oriented final doc
cosine_similarities = cosine_similarity(pca_doc_repr, pca_class_repr)
doc_class_assignment = np.argmax(cosine_similarities, axis=1)
doc_class_probs = cosine_similarities[np.arange(pca_doc_repr.shape[0]), doc_class_assignment]

print("Evaluate Document Cosine Similarity Predictions: ")
evaluate_predictions(gold_labels, doc_class_assignment)

Evaluate Document Cosine Similarity Predictions: 
F1 micro: 0.8009624531363662
F1 macro: 0.7875096605942187


{'confusion': [[4586, 61, 162, 60, 22],
  [78, 3645, 66, 174, 16],
  [1063, 194, 1959, 701, 35],
  [28, 73, 100, 2087, 337],
  [23, 19, 118, 227, 2037]],
 'f1_micro': 0.8009624531363662,
 'f1_macro': 0.7875096605942187}

In [146]:
# class-oriented average
cosine_similarities = cosine_similarity(pca_doc_repr, pca_class_repr)
doc_class_assignment = np.argmax(cosine_similarities, axis=1)
doc_class_probs = cosine_similarities[np.arange(pca_doc_repr.shape[0]), doc_class_assignment]

print("Evaluate Document Cosine Similarity Predictions: ")
evaluate_predictions(gold_labels, doc_class_assignment)

Evaluate Document Cosine Similarity Predictions: 
F1 micro: 0.8893947368421051
F1 macro: 0.8893947349271942


{'confusion': [[16901, 2099], [2104, 16896]],
 'f1_micro': 0.8893947368421051,
 'f1_macro': 0.8893947349271942}

## Save Pseudo Training Dataset ##

In [53]:
def write_to_dir(text, labels, prob, data_path, new_data_path):
    assert len(text) == len(labels)
    print("Saving files in:", new_data_path)
    
    with open(os.path.join(new_data_path, "dataset.txt"), "w") as f:
        for i, line in enumerate(text):
            f.write(line)
            f.write("\n")

    with open(os.path.join(new_data_path, "labels.txt"), "w") as f:
        for i, line in enumerate(labels):
            f.write(str(line))
            f.write("\n")

    with open(os.path.join(new_data_path, "probs.txt"), "w") as f:
        for i, line in enumerate(prob):
            f.write(str(line))
            #f.write(",".join(map(str, line)))
            f.write("\n")

    copyfile(os.path.join(data_path, "classes.txt"),
             os.path.join(new_data_path, "classes.txt"))

In [54]:
def generateDataset(documents_to_class, prob, num_classes, cleaned_text, gold_labels, data_path, new_data_path):
    pseudo_document_class_with_confidence = [[] for _ in range(num_classes)]
    for i in range(documents_to_class.shape[0]):
        pseudo_document_class_with_confidence[documents_to_class[i]].append((prob[i], i))

    selected = []
    confidence_threshold = 0.5
    confident_documents = [[] for _ in range(num_classes)]

    for i in range(num_classes):
        pseudo_document_class_with_confidence[i] = sorted(pseudo_document_class_with_confidence[i], key=lambda x: x[0], reverse=True)
        num_docs_to_take = int(len(pseudo_document_class_with_confidence[i]) * confidence_threshold)
        confident_documents[i] = pseudo_document_class_with_confidence[i][:num_docs_to_take]
        selected.extend([x[1] for x in confident_documents[i]])
    
    selected = sorted(selected)
    text = [cleaned_text[i] for i in selected]
    classes = [documents_to_class[i] for i in selected]
    probs = [prob[i] for i in selected]
    ###
    gold_classes = [gold_labels[i] for i in selected]
    evaluate_predictions(gold_classes, classes)
    ###
    write_to_dir(text, classes, probs, data_path, new_data_path)
    return confident_documents

In [55]:
finalconf = generateDataset(doc_class_assignment, doc_class_probs, num_classes, cleaned_text, gold_labels, data_path, new_data_path)

F1 micro: 0.8843743004253414
F1 macro: 0.8703795616370998
Saving files in: /home/pk36/MEGClass/intermediate_data/20News


In [216]:
cleaned_text[23973], doc_class_assignment[23973], gold_labels[23973]

("I've been waiting for an order for 6 weeks. That would be fine if customer service let me know there would be a delay, but I've contacted them 5 times for updates and so far nothing, other than passing the buck and telling me that the product was to be drop shipped by another company. Since I'm in the business of customer service, I found their non-response unacceptable. I'll update the review if there are new developments. I do not like to trash any business so am hoping they redeem themselves. Update from 10 17. Ally from Muscle Driver just called. The product finally shipped from Fisher. I have a tracking number. Still no excuse for the long wait, but good damage control. Update on 10 22. The lack of communication between Fisher and Muscle Driver is astounding. Fisher sent our package through some private delivery service (not UPS, Fedex, or USPS) and they only do curbside delivery and won't give us a window so now I need to take a day off to make sure I can get the package that I

In [196]:
finalprobs[0]

0.6838244497723575

In [23]:
def doc_gmm(args, document_representations, doc_class_representations, gold_labels):
    num_classes = len(doc_class_representations)
    cosine_similarities = cosine_similarity(document_representations, doc_class_representations)
    doc_class_assignment = np.argmax(cosine_similarities, axis=1)

    print("Evaluate Document Cosine Similarity Predictions: ")
    evaluate_predictions(gold_labels, doc_class_assignment)

    # initialize gmm based on these selected documents
    document_class_assignment_matrix = np.zeros((document_representations.shape[0], num_classes))
    for i in np.arange(len(document_representations)): # iterate through docs and sents
        document_class_assignment_matrix[i][doc_class_assignment[i]] = 1.0

    gmm = GaussianMixture(n_components=num_classes, covariance_type='tied',
                          random_state=args.random_state,
                          n_init=999, warm_start=True)

    gmm._initialize(document_representations, document_class_assignment_matrix)
    gmm.lower_bound_ = -np.infty

    gmm.converged_ = True #HACK FOR NOT RANDOMLY INITIALIZING PARAMS DURING FIT
    gmm.fit(document_representations)

    documents_to_class = gmm.predict(document_representations)
    confidence = gmm.predict_proba(document_representations)

    print("Evaluate Document GMM Predictions: ")
    evaluate_predictions(gold_labels, documents_to_class)

    return documents_to_class, confidence

In [33]:
# 128
doc_preds, doc_prob = doc_gmm(args, pca_doc_repr, pca_class_repr, gold_labels)

Evaluate Document Cosine Similarity Predictions: 
F1 micro: 0.782384869341391
F1 macro: 0.7704668150661399
Evaluate Document GMM Predictions: 
F1 micro: 0.7800346930781712
F1 macro: 0.7686244714799803
