## Setup ##

In [1]:
import numpy as np
import os
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup
import torch
from torch import nn
import sys
import torch.nn.functional as F
import argparse
import re
import pickle as pk
from tqdm import tqdm
from scipy.special import softmax
from sklearn.decomposition import PCA
from shutil import copyfile
from sklearn.mixture import GaussianMixture
from sklearn.mixture._gaussian_mixture import _estimate_gaussian_parameters
from nltk.tokenize import sent_tokenize, word_tokenize
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import confusion_matrix, f1_score
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Args(object):
    pass

args = Args()
args.dataset_name = "books"
args.gpu = 7
args.pca = 128
args.random_state = 42
args.emb_dim = 768
args.num_heads = 2
args.batch_size = 64
args.temp = 0.2
args.lr = 1e-3
args.epochs = 5
args.accum_steps = 1
args.max_sent = 150

os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

In [3]:
data_path = os.path.join("/shared/data2/pk36/multidim/multigran", args.dataset_name)
new_data_path = os.path.join("/home/pk36/MEGClass/intermediate_data", args.dataset_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# this is where we save all representations
if not os.path.exists(new_data_path):
    os.makedirs(new_data_path)

## Helper Functions ##

In [12]:
DATA_FOLDER_PATH = os.path.join("/shared/data2/pk36/multidim/multigran")
INTERMEDIATE_DATA_FOLDER_PATH = os.path.join("/home/pk36/MEGClass/intermediate_data")

def tensor_to_numpy(tensor):
    return tensor.clone().detach().cpu().numpy()

# Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

def sentenceToClass(sent_repr, class_repr, weights):
    # sent_repr: N x S x E
    # class_repr: C x E
    # weights: N x S # equals 0 for masked sentences

    #cos-sim between (N x S) x E and (C x E) = N x S x C
    m, n = sent_repr.shape[:2]
    sentcos = cosine_similarity(sent_repr.reshape(m*n,-1), class_repr).reshape(m,n,-1)
    sent_to_class = np.argmax(sentcos, axis=2) # N x S
    sent_to_doc_class = np.sum(np.multiply(sent_to_class, weights), axis=1) # N x 1
    return sent_to_doc_class

def docToClass(doc_repr, class_repr):
    # doc_repr: N x E
    # class_repr: C x E

    #cos-sim between N x E and C x E = N x C
    doccos = cosine_similarity(doc_repr, class_repr)
    doc_to_class = np.argmax(doccos, axis=1) # N x 1
    return doc_to_class

def evaluate_predictions(true_class, predicted_class, output_to_console=True, return_tuple=False, return_confusion=False):
    confusion = confusion_matrix(true_class, predicted_class)
    if return_confusion and output_to_console:
        print("-" * 80 + "Evaluating" + "-" * 80)
        print(confusion)
    
    f1_micro = f1_score(true_class, predicted_class, average='micro')
    f1_macro = f1_score(true_class, predicted_class, average='macro')
    if output_to_console:
        print("F1 micro: " + str(f1_micro))
        print("F1 macro: " + str(f1_macro))
    if return_tuple:
        return f1_macro, f1_micro
    else:
        return {
            "f1_micro": f1_micro,
            "f1_macro": f1_macro
        }
    
def getSentClassRepr(args):
    with open(os.path.join("/home/pk36/XClass/data/intermediate_data", args.dataset_name, f"document_repr_lm-bbu-12-mixture-plm.pk"), "rb") as f:
        dictionary = pk.load(f)
        class_repr = dictionary["class_representations"]
        sent_repr = dictionary["sent_representations"]
        if (args.dataset_name in ["nyt-location", "books"]):
         return sent_repr, class_repr
        else:
            doc_repr = dictionary["document_representations"]
    return doc_repr, sent_repr, class_repr


def getDSMapAndGold(args, sent_dict):
    # get the ground truth labels for all documents and assign a "ground truth" label to each sentence based on its parent document
    gold_labels = list(map(int, open(os.path.join("/shared/data2/pk36/multidim/multigran", args.dataset_name, "labels.txt"), "r").read().splitlines()))
    gold_sent_labels = []
    # get all sent ids for each doc
    doc_to_sent = []
    sent_id = 0
    for doc_id, doc in enumerate(sent_dict.values()):
        sent_ids = []
        for sent in doc:
            sent_ids.append(sent_id)
            gold_sent_labels.append(gold_labels[doc_id])
            sent_id += 1
        doc_to_sent.append(sent_ids)
            
    return gold_labels, gold_sent_labels, doc_to_sent

## BERT-Based Sentence Embeddings, Initial Doc Embeddings, & Class Representations ##

In [5]:
def sentenceEmb(args, sent_dict, doc_to_sent, class_words, device, classonly=False):
    # Load model from HuggingFace Hub
    tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
    model = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')
    model = model.to(device)
    
    num_docs = len(doc_to_sent)
    padded_sent_repr = np.zeros((num_docs, args.max_sent, args.emb_dim))
    sentence_mask = np.ones((num_docs, args.max_sent))
    doc_lengths = np.zeros(num_docs, dtype=int)
    trimmed = 0

    if not classonly:
        for doc_id in tqdm(np.arange(num_docs)):
            sents = sent_dict[str(doc_id)]
            num_sent = len(sents)
            if num_sent > args.max_sent:
                trimmed += 1
                sents = sents[:args.max_sent]
                num_sent = args.max_sent
            encoded_input = tokenizer(sents, padding=True, truncation=True, return_tensors='pt')
            encoded_input = encoded_input.to(device)

            # Compute token embeddings
            with torch.no_grad():
                model_output = model(**encoded_input)

            # Perform pooling
            embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

            # Normalize embeddings
            embeddings = tensor_to_numpy(F.normalize(embeddings, p=2, dim=1))

            # save the number of sentences in each document
            doc_lengths[doc_id] = int(num_sent)

            padded_sent_repr[doc_id, :embeddings.shape[0], :] = embeddings
            # Update mask so that padded sentences are not included in attention computation
            sentence_mask[doc_id, :num_sent] = 0
        print(f"Trimmed Documents: {trimmed}")

    # construct class representations
    class_repr = np.zeros((len(class_words), args.emb_dim))
    intro_map = {"20News":"This article is about ", "agnews": "This article is about ",
                "yelp": "This is ", "nyt-coarse":"This article is about ", 
                "nyt-fine": "This article is about "}

    print("Constructing Class Representations...")
    for class_id in tqdm(np.arange(len(class_words))):
        # class_sent = intro_map[args.dataset_name] + ", ".join(class_words[class_id])
        class_sent = intro_map[args.dataset_name] + class_words[class_id][0]
        
        print(class_sent)
        encoded_input = tokenizer(class_sent, truncation=True, return_tensors='pt')
        encoded_input = encoded_input.to(device)
         # Compute token embeddings
        with torch.no_grad():
            model_output = model(**encoded_input)
        # Perform pooling
        embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
        # Normalize embeddings
        embeddings = tensor_to_numpy(F.normalize(embeddings, p=2, dim=1))
        class_repr[class_id, :] = embeddings
    
    if classonly:
        return class_repr
    else:
        return padded_sent_repr, class_repr, doc_lengths, sentence_mask

In [6]:
def bertSentenceEmb(args, doc_to_sent, sent_repr):
    num_docs = len(doc_to_sent)
    doc_lengths = np.zeros(num_docs, dtype=int)
    # init_doc_repr = np.zeros((num_docs, args.emb_dim))
    padded_sent_repr = np.zeros((num_docs, args.max_sent, args.emb_dim))
    sentence_mask = np.ones((num_docs, args.max_sent))
    trimmed = 0


    for doc_id in tqdm(np.arange(num_docs)):
        start_sent = doc_to_sent[doc_id][0]
        end_sent = doc_to_sent[doc_id][-1]
        num_sent = end_sent - start_sent + 1
        if num_sent > args.max_sent:
            end_sent = start_sent + args.max_sent - 1
            num_sent = args.max_sent
            trimmed += 1
        embeddings = sent_repr[start_sent:end_sent+1]

        # save the number of sentences in each document
        doc_lengths[doc_id] = int(num_sent)

        # Add initial doc representation
        # init_doc_repr[doc_id, :] = np.mean(embeddings, axis=0)
        # Add padded sentences
        padded_sent_repr[doc_id, :embeddings.shape[0], :] = embeddings
        # Update mask so that padded sentences are not included in attention computation
        sentence_mask[doc_id, :num_sent] = 0

    
    print(f"Trimmed Documents: {trimmed}")

    return padded_sent_repr, doc_lengths, sentence_mask

### Get Class Weights ###

In [7]:
def getTargetClasses(padded_sent_repr, doc_lengths, class_repr, weights=None):
    # weights: N x 150
    class_weights = np.zeros((padded_sent_repr.shape[0], class_repr.shape[0])) # N x C
    sent_weights = np.zeros(padded_sent_repr.shape[:2])

    for doc_id in tqdm(np.arange(padded_sent_repr.shape[0])):
        l = doc_lengths[doc_id]
        sent_emb = padded_sent_repr[doc_id, :l, :] # S x E
        sentcos = cosine_similarity(sent_emb, class_repr) # S x C
        sent_to_class = np.argmax(sentcos, axis=1) # S
        
        # default: equal vote weight between all sentences
        if weights is None:
            # w = np.ones(doc_lengths[doc_id])/doc_lengths[doc_id]
            # top cos-sim - second cos-sim
            toptwo = np.partition(sentcos, -2)[:, -2:] # S x 2
            toptwo = toptwo[:, 1] - toptwo[:, 0] # S
            w = toptwo / np.sum(toptwo)
            sent_weights[doc_id, :l] = w
        else:
            w = weights[doc_id, :l]
        
        class_weights[doc_id, :] = np.bincount(sent_to_class, weights=w, minlength=class_repr.shape[0])

    if weights is None:
        return class_weights, sent_weights
    else:
        return class_weights

In [8]:
def getTargetClassSet(padded_sent_repr, doc_lengths, class_set, alpha=None, set_weights=None):
    # sent_weights: N x 150 -> weigh each sentence based on its contribution to the document
    # set_weights: C x CD -> how confident each class-indicative document is

    class_weights = np.zeros((padded_sent_repr.shape[0], len(class_set))) # N x C
    if alpha is None:
        sent_weights = np.zeros(padded_sent_repr.shape[:2]) # N x 150

    for doc_id in tqdm(np.arange(padded_sent_repr.shape[0])):
        l = doc_lengths[doc_id]
        sent_emb = padded_sent_repr[doc_id, :l, :] # S x E
        sentclass_dist = np.zeros((l, len(class_set))) # S x C

        for class_id in np.arange(len(class_set)):
            sentcos = cosine_similarity(sent_emb, class_set[class_id]) # S x CD
            if set_weights is None:
                sentclass_sim = np.mean(sentcos, axis=1) # on average, how similar each sentence is to the class set
            else:
                sentclass_sim = np.average(sentcos, axis=1, weights=set_weights[class_id]) # same but weighted average based on class-indicativeness
            
            sentclass_dist[:, class_id] = sentclass_sim


        sent_to_class = np.argmax(sentclass_dist, axis=1) # S
        
        # default: equal vote weight between all sentences
        if alpha is None:
            # top cos-sim - second cos-sim
            toptwo = np.partition(sentclass_dist, -2)[:, -2:] # S x 2
            toptwo = toptwo[:, 1] - toptwo[:, 0] # S
            w = toptwo / np.sum(toptwo)
            sent_weights[doc_id, :l] = w
        else:
            w = alpha[doc_id, :l]
        
        class_weights[doc_id, :] = np.bincount(sent_to_class, weights=w, minlength=len(class_set))

    if alpha is None:
        return class_weights, sent_weights
    else:
        return class_weights
    
def docToClassSet(doc_emb, class_set, set_weights=None):
    # set_weights: C x CD -> how confident each class-indicative document is

    class_dist = np.zeros((doc_emb.shape[0], len(class_set))) # D x C
    for class_id in np.arange(len(class_set)):
        doccos = cosine_similarity(doc_emb, class_set[class_id]) # D x CD
        if set_weights is None:
            cos_sim = np.mean(doccos, axis=1) # on average, how similar each document is to the class set
        else:
            cos_sim = np.average(doccos, axis=1, weights=set_weights[class_id]) # same but weighted average based on class-indicativeness
        
        class_dist[:, class_id] = cos_sim
    
    return np.argmax(class_dist, axis=1), class_dist

### Get initial embeddings and class representations + gold labels ###

In [9]:
with open(os.path.join("/home/pk36/XClass/data/intermediate_data", args.dataset_name, "dataset.pk"), "rb") as f:
    dataset = pk.load(f)
    sent_dict = dataset["sent_data"]
    cleaned_text = dataset["cleaned_text"]
    class_names = np.array(dataset["class_names"])
with open(os.path.join("/home/pk36/XClass/data/intermediate_data", args.dataset_name, "document_repr_lm-bbu-12-mixture-plm.pk"), "rb") as f:
    reprpickle = pk.load(f)
    class_words = reprpickle["class_words"]
gold_labels, gold_sent_labels, doc_to_sent = getDSMapAndGold(args, sent_dict)
num_classes = len(class_words)

### Check Class Imbalance ###

In [10]:
for i in np.arange(num_classes):
    print(f'{class_names[i]}: ({np.sum(gold_labels == i)} docs), {np.sum(gold_labels == i)/len(gold_labels)}')

children: (4686 docs), 0.13948919449901767
comics_graphic: (4934 docs), 0.14687146514258498
fantasy_paranormal: (4351 docs), 0.12951717568613444
history_biography: (4567 docs), 0.13594689527891884
mystery_thriller_crime: (4888 docs), 0.1455021730070846
poetry: (1226 docs), 0.03649461213311901
romance: (4440 docs), 0.1321664582961243
young_adult: (4502 docs), 0.13401202595701614


In [112]:
class_size = np.array([np.sum(gold_labels == i)/len(gold_labels) for i in np.arange(len(class_names))])
class_size

array([0.34037566, 0.3194362 , 0.09197737, 0.07447573, 0.05572397,
       0.05150483, 0.03794106, 0.0160015 , 0.01256368])

In [126]:
np.log(class_size)/np.sum(np.log(class_size))

array([0.04338665, 0.04594275, 0.09606501, 0.10456231, 0.11623976,
       0.1194095 , 0.13171414, 0.16647125, 0.17620864])

In [10]:
plm_padded_sent_repr, plm_class_repr, doc_lengths, plm_sentence_mask = sentenceEmb(args, sent_dict, doc_to_sent, class_words, device)

init_plm_class_set = [np.array([plm_class_repr[i]]) for i in np.arange(len(plm_class_repr))] # C x CD x E
plm_class_set = [np.array([plm_class_repr[i]]) for i in np.arange(len(plm_class_repr))] # C x CD x E

init_plm_class_weights, init_plm_sent_weights = getTargetClassSet(plm_padded_sent_repr, doc_lengths, init_plm_class_set, alpha=None, set_weights=None)

100%|██████████| 31997/31997 [19:43<00:00, 27.04it/s]


Trimmed Documents: 155
Constructing Class Representations...


100%|██████████| 9/9 [00:00<00:00, 132.89it/s]


This article is about business
This article is about politics
This article is about sports
This article is about health
This article is about education
This article is about estate
This article is about arts
This article is about science
This article is about technology


100%|██████████| 31997/31997 [01:39<00:00, 320.20it/s]


In [13]:
sent_repr, class_repr = getSentClassRepr(args)
init_class_set = [np.array([class_repr[i]]) for i in np.arange(len(class_repr))] # C x CD x E

In [14]:
class_set = [np.array([class_repr[i]]) for i in np.arange(len(class_repr))] # C x CD x E

padded_sent_repr, doc_lengths, sentence_mask = bertSentenceEmb(args, doc_to_sent, sent_repr)
init_class_weights, init_sent_weights = getTargetClassSet(padded_sent_repr, doc_lengths, class_set, alpha=None, set_weights=None)

100%|██████████| 33594/33594 [00:00<00:00, 56697.44it/s]


Trimmed Documents: 1


100%|██████████| 33594/33594 [01:01<00:00, 549.46it/s]


### Analyze and Evaluate Class Weights ###

In [15]:
print("Class-Oriented:")
evaluate_predictions(gold_labels, np.argmax(init_class_weights, axis=1))

Class-Oriented:
F1 micro: 0.5027981187116747
F1 macro: 0.5051026475634897


{'f1_micro': 0.5027981187116747, 'f1_macro': 0.5051026475634897}

## Model ##

In [16]:
class MEGClassModel(nn.Module):
    def __init__(self, D_in, D_hidden, head, dropout=0.0):
        super(MEGClassModel, self).__init__()
        self.mha = nn.MultiheadAttention(embed_dim=D_in, num_heads=head, dropout=dropout, batch_first=True)
        self.layernorm = nn.LayerNorm(D_in)
        self.embd = nn.Linear(D_in,D_hidden)
        self.attention = nn.Linear(D_hidden,1)
        
    def forward(self, x_org, mask=None):
        x, mha_w = self.mha(x_org,x_org,x_org,key_padding_mask=mask)
        x = self.layernorm(x_org+x)
        
        x = self.embd(x)
        x = torch.tanh(x) # contextualized sentences
        a = self.attention(x)
        if mask is not None:
            a = a.masked_fill_((mask == 1).unsqueeze(-1), float('-inf'))
        w = torch.softmax(a, dim=1) # alpha_k
        o = torch.matmul(w.permute(0,2,1), x) #doc 
        return o, mha_w, w, x # contextualized doc, multi-head attention weights, alpha_k, contextualized sent


## Contrastive Loss ##

In [17]:
def contrastive_loss(sample_outputs, class_indices, class_embds, temp=0.2):
    k = torch.exp(torch.nn.functional.cosine_similarity(sample_outputs[:,None], class_embds, axis=2)/temp)
    loss = -1*(torch.log(k[np.arange(len(class_indices)),class_indices]/k.sum(1))).sum()
    return loss/len(sample_outputs)

def weighted_class_contrastive_loss(args, sample_outputs, class_weights, class_embds):
    # k: B x C, class_weights: B x C
    numerator = torch.exp(torch.nn.functional.cosine_similarity(sample_outputs[:,None], class_embds, axis=2)/args.temp)
    denom = numerator.sum(dim=1).unsqueeze(-1)
    weighted_loss = -1 * (torch.log(numerator/(denom)) * class_weights).sum() # B x C -> B
    return weighted_loss/len(sample_outputs)

def classSetContrastiveLoss(args, sample_outputs, class_weights, class_set, device, set_weights=None):
    # sample_outputs: B x E, class_weights: B x C, class_set: C x CD x E, set_weights: C x CD
    # flatten_class_set = np.array([class_set[i][j] for i in np.arange(len(class_set)) for j in np.arange(len(class_set[i]))])
    # flatten_class_set = torch.cat(class_set, dim=0)
    # denom = torch.exp(torch.nn.functional.cosine_similarity(sample_outputs[:, None], flatten_class_set, dim=2)/args.temp) # B x (C*CD)
    # if set_weights is None:
    #     denom = torch.sum(denom, dim=1).unsqueeze(-1) # B
    # else:
    #     denom = torch.sum(denom * set_weights.reshape(-1,), dim=1).unsqueeze(-1) # B
    weighted_loss = torch.zeros((sample_outputs.size(dim=0), len(class_set))).to(device) # B x C
    for i in np.arange(len(class_set)):
        # similarity between contextualized docs and class-indicative docs
        numerator = torch.exp(torch.nn.functional.cosine_similarity(sample_outputs[:, None], class_set[i], dim=2)/args.temp) # B x CD
        weighted_loss[:, i] = numerator.mean(dim=1) # B x 1
        # if set_weights is None:
        #     numerator = torch.sum(numerator, dim=1) # B
        # else:
        #     numerator = torch.sum(torch.mul(numerator, set_weights[i]), dim=1) # B
        # if set_weights is None:
        
    return -1 * (torch.log(weighted_loss/weighted_loss.sum(dim=1).unsqueeze(-1)) * class_weights).sum() / len(sample_outputs)


## Train ##

In [18]:
def contextEmb(args, sent_representations, mask, class_set, class_weights, 
                doc_lengths, new_data_path, device):
    sent_representations = torch.from_numpy(sent_representations)
    mask = torch.from_numpy(mask).to(torch.bool)
    class_weights = torch.from_numpy(class_weights)
    dataset = TensorDataset(sent_representations, mask, class_weights)
    sampler = SequentialSampler(dataset)
    dataset_loader = DataLoader(dataset, sampler=sampler, batch_size=args.batch_size, shuffle=False)
    # sent_representations: N docs x L sentences x 768 emb (L with padding is always max_sents=50)
    model = MEGClassModel(args.emb_dim, args.emb_dim, args.num_heads).to(device)

    total_steps = len(dataset_loader) * args.epochs / args.accum_steps
    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, eps=1e-8)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0.1*total_steps, num_training_steps=total_steps)

    print("Starting to train!")

    for i in tqdm(range(args.epochs)):
        total_train_loss = 0

        for batch in tqdm(dataset_loader):
            model.train()
            input_emb = batch[0].to(device).float()
            input_mask = batch[1].to(device)
            input_weights = batch[2].to(device).float()
            
            c_doc, _, alpha, c_sent = model(input_emb, mask=input_mask)
            c_doc = c_doc.squeeze(1)
            torch_class_set = [torch.from_numpy(class_set[i]).float().to(device) for i in np.arange(len(class_set))]
            loss = classSetContrastiveLoss(args, c_doc, input_weights, torch_class_set, device)/ args.accum_steps
            # loss = weighted_class_contrastive_loss(args, c_doc, input_weights, torch.from_numpy(class_repr).float().to(device)) / args.accum_steps

            total_train_loss += loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        scheduler.step()

        avg_train_loss = total_train_loss / len(dataset_loader)
        print(f"Average training loss: {avg_train_loss}")

    model.eval()

    torch.save(model.state_dict(), os.path.join(new_data_path, f"{args.dataset_name}_model_e{args.epochs}.pth"))

    print("Starting to evaluate!")

    evalsampler = SequentialSampler(dataset)
    eval_loader = DataLoader(dataset, sampler=evalsampler, batch_size=args.batch_size, shuffle=False)

    doc_predictions = None
    attention_weights = np.zeros_like(mask, dtype=float)
    updated_sent_repr = np.zeros_like(sent_representations)
    final_doc_emb = np.zeros((len(class_weights), args.emb_dim))
    idx = 0

    with torch.no_grad():
        for batch in tqdm(eval_loader):
            input_emb = batch[0].to(device).float()
            input_mask = batch[1].to(device)

            c_doc, _, alpha, c_sent = model(input_emb, mask=input_mask)
            c_doc = c_doc.squeeze(1)
            c_sent, c_doc, alpha = tensor_to_numpy(c_sent), tensor_to_numpy(c_doc), tensor_to_numpy(torch.squeeze(alpha, dim=2))

            final_doc_emb[idx:idx+c_doc.shape[0], :] = c_doc
            attention_weights[idx:idx+c_doc.shape[0], :] = alpha
            updated_sent_repr[idx:idx+c_doc.shape[0], :, :] = c_sent

            idx += c_doc.shape[0]

            doc_class, _ = docToClassSet(c_doc, class_set)
            if doc_predictions is None:
                doc_predictions = doc_class
            else:
                doc_predictions = np.append(doc_predictions, doc_class)
    
    updated_class_weights = getTargetClassSet(updated_sent_repr, doc_lengths, class_set, attention_weights)

    return doc_predictions, final_doc_emb, updated_sent_repr, attention_weights, updated_class_weights

In [19]:
args.epochs = 4
args.lr = 1e-3
args.temp = 0.1
doc_to_class, final_doc_emb, updated_sent_repr, updated_sent_weights, updated_class_weights = contextEmb(args, padded_sent_repr, sentence_mask, init_class_set, init_class_weights, doc_lengths, new_data_path, device)



Starting to train!


100%|██████████| 525/525 [00:22<00:00, 23.69it/s]
 25%|██▌       | 1/4 [00:22<01:06, 22.16s/it]

Average training loss: 2.118612766265869


100%|██████████| 525/525 [00:18<00:00, 27.87it/s]
 50%|█████     | 2/4 [00:41<00:40, 20.23s/it]

Average training loss: 1.8127846717834473


100%|██████████| 525/525 [00:19<00:00, 27.25it/s]
 75%|███████▌  | 3/4 [01:00<00:19, 19.81s/it]

Average training loss: 1.3866466283798218


100%|██████████| 525/525 [00:18<00:00, 27.81it/s]
100%|██████████| 4/4 [01:19<00:00, 19.82s/it]


Average training loss: 1.1912039518356323
Starting to evaluate!


100%|██████████| 525/525 [00:20<00:00, 25.78it/s]
100%|██████████| 33594/33594 [01:02<00:00, 537.99it/s]


In [20]:
# epochs: 4
doc_pred = np.rint(doc_to_class)
print("Evaluate Predictions (Document-Based): ")
evaluate_predictions(gold_labels, doc_pred)

Evaluate Predictions (Document-Based): 
F1 micro: 0.5133357147109603
F1 macro: 0.5136440314075857


{'f1_micro': 0.5133357147109603, 'f1_macro': 0.5136440314075857}

### Second Iteration ###

In [59]:
curr_class_weights = getTargetClassSet(updated_sent_repr, doc_lengths, class_set, alpha=updated_sent_weights, set_weights=None)
evaluate_predictions(gold_labels, np.argmax(curr_class_weights, axis=1))

100%|██████████| 33594/33594 [06:20<00:00, 88.19it/s] 

F1 micro: 0.5201524081681252
F1 macro: 0.5128400801910938





{'f1_micro': 0.5201524081681252, 'f1_macro': 0.5128400801910938}

In [60]:
args.epochs = 4
args.lr = 1e-3
args.temp = 0.1
doc_to_class, final_doc_emb, updated_sent_repr, updated_sent_weights, updated_class_weights = contextEmb(args, padded_sent_repr, sentence_mask, init_class_set, curr_class_weights, doc_lengths, new_data_path, device)



Starting to train!


100%|██████████| 525/525 [00:20<00:00, 25.46it/s]
 25%|██▌       | 1/4 [00:20<01:01, 20.62s/it]

Average training loss: 2.145280122756958


100%|██████████| 525/525 [00:20<00:00, 25.59it/s]
 50%|█████     | 2/4 [00:41<00:41, 20.60s/it]

Average training loss: 1.7258243560791016


100%|██████████| 525/525 [00:18<00:00, 28.20it/s]
 75%|███████▌  | 3/4 [00:59<00:19, 19.72s/it]

Average training loss: 1.0600605010986328


100%|██████████| 525/525 [00:18<00:00, 28.95it/s]
100%|██████████| 4/4 [01:18<00:00, 19.53s/it]


Average training loss: 0.7097876667976379
Starting to evaluate!


100%|██████████| 525/525 [00:23<00:00, 22.64it/s]
100%|██████████| 33594/33594 [01:03<00:00, 529.02it/s]


In [61]:
# epochs: 4; FOURTH ITER (USING INITIAL REPRESENTATIONS)
doc_pred = np.rint(doc_to_class)
print("Evaluate Predictions (Document-Based): ")
evaluate_predictions(gold_labels, doc_pred)

Evaluate Predictions (Document-Based): 
F1 micro: 0.48871822349229027
F1 macro: 0.47107252805139477


{'f1_micro': 0.48871822349229027, 'f1_macro': 0.47107252805139477}

In [48]:
# epochs: 4; THIRD ITER (USING INITIAL REPRESENTATIONS)
doc_pred = np.rint(doc_to_class)
print("Evaluate Predictions (Document-Based): ")
evaluate_predictions(gold_labels, doc_pred)

Evaluate Predictions (Document-Based): 
F1 micro: 0.5086324938977198
F1 macro: 0.49307389487809583


{'f1_micro': 0.5086324938977198, 'f1_macro': 0.49307389487809583}

In [34]:
# epochs: 4 SECOND ITER (USING INITIAL REPRESENTATIONS)
doc_pred = np.rint(doc_to_class)
print("Evaluate Predictions (Document-Based): ")
evaluate_predictions(gold_labels, doc_pred)

Evaluate Predictions (Document-Based): 
F1 micro: 0.5250937667440614
F1 macro: 0.5161383273632094


{'f1_micro': 0.5250937667440614, 'f1_macro': 0.5161383273632094}

In [98]:
# epochs: 4 SECOND ITER (USING UPDATED REPRESENTATIONS)
doc_pred = np.rint(doc_to_class)
print("Evaluate Predictions (Document-Based): ")
evaluate_predictions(gold_labels, doc_pred)

Evaluate Predictions (Document-Based): 
F1 micro: 0.7717531195792066
F1 macro: 0.7618630503676733


{'confusion': [[4713, 30, 116, 26, 6],
  [344, 3281, 144, 198, 12],
  [1509, 111, 1770, 536, 26],
  [64, 70, 109, 2056, 326],
  [44, 28, 133, 247, 1972]],
 'f1_micro': 0.7717531195792066,
 'f1_macro': 0.7618630503676733}

## Evaluate ##

In [24]:
# epochs: 3
doc_pred = np.rint(doc_to_class)
print("Evaluate Predictions (Document-Based): ")
evaluate_predictions(gold_labels, doc_pred)

Evaluate Predictions (Document-Based): 
F1 micro: 0.7652061999888087
F1 macro: 0.7372566265293617


{'confusion': [[4798, 56, 4, 31, 2],
  [176, 3698, 1, 100, 4],
  [1763, 312, 1021, 820, 36],
  [49, 162, 3, 2253, 158],
  [41, 44, 10, 424, 1905]],
 'f1_micro': 0.7652061999888087,
 'f1_macro': 0.7372566265293617}

In [26]:
# epochs: 5
doc_pred = np.rint(doc_to_class)
print("Evaluate Predictions (Document-Based): ")
evaluate_predictions(gold_labels, doc_pred)

Evaluate Predictions (Document-Based): 
F1 micro: 0.7949191427452297
F1 macro: 0.7772314826599722


{'confusion': [[4830, 30, 6, 18, 7],
  [163, 3715, 5, 91, 5],
  [1683, 182, 1439, 614, 34],
  [42, 121, 13, 2255, 194],
  [47, 31, 15, 364, 1967]],
 'f1_micro': 0.7949191427452297,
 'f1_macro': 0.7772314826599722}

In [64]:
class_names

array(['computer', 'sports', 'science', 'politics', 'religion'],
      dtype='<U8')

In [151]:
doc_id = 0
print("Initial CO Class Weights: ", init_class_weights[doc_id])
print("Initial PLM Class Weights: ", init_plm_class_weights[doc_id])

print("Updated Class Weights: ", updated_class_weights[doc_id])
print("True Class: ", gold_labels[doc_id])
tok_sents = sent_dict[str(doc_id)]
chosen_class = np.argmax(cosine_similarity(updated_sent_repr[doc_id, :len(tok_sents)], plm_class_repr), axis=1)
sent_classes = cosine_similarity(updated_sent_repr[doc_id, :len(tok_sents)], plm_class_repr)

for s,c,cw,w in zip(tok_sents, class_names[chosen_class.astype(int)], sent_classes, updated_sent_weights[doc_id, :len(tok_sents)]):
    print(f'{c} ({cw}) {w}: {s}')

Initial CO Class Weights:  [0.01458973 0.00454681 0.88410401 0.04704762 0.04971184]
Initial PLM Class Weights:  [0.12874045 0.12084435 0.44067873 0.17724041 0.13249605]
Updated Class Weights:  [0.         0.         0.99999993 0.         0.        ]
True Class:  2
science ([-0.02891412  0.05887069  0.21833655 -0.00544262 -0.02173259]) 0.052025001496076584: re how to diagnose lyme... really x-newsreader tin version 1.1 pl9 gordon banks wrote in article ( marc gabriel) writes now, i'm not saying that culturing is the best way to diagnose it's very hard to culture bb in most cases.
science ([-0.04523567  0.04875416  0.19182204 -0.02786728 -0.04778766]) 0.06097852438688278: the point is that dr. n has developed a "feel" for what is and what isn't ld.
science ([-0.07668495  0.00333512  0.12947233  0.00652326 -0.12845935]) 0.029841937124729156: this comes from years of experience.
science ([-0.06803242  0.03180561  0.14810391 -0.01367903 -0.08580509]) 0.037689365446567535: no serology can ma

## Run PCA on Document Embeddings & Fit GMM ##

In [62]:
_pca = PCA(n_components=64, random_state=args.random_state)
pca_doc_repr = _pca.fit_transform(final_doc_emb)
# pca_doc_repr = _pca.fit_transform(np.sum(padded_sent_repr, axis=1)/doc_lengths.reshape((-1, 1)))
# pca_doc_repr = _pca.fit_transform(doc_repr)
# pca_class_repr = [_pca.transform(class_set[i]) for i in np.arange(len(class_set))]
pca_class_repr = [_pca.transform(init_class_set[i]) for i in np.arange(len(init_class_set))]
# pca_class_repr = _pca.transform(plm_class_repr)
print(f"Explained document variance: {sum(_pca.explained_variance_ratio_)}")

Explained document variance: 0.8804910204318771


In [63]:
doc_class_assignment, doc_class_probs = docToClassSet(pca_doc_repr, pca_class_repr, None)
doc_class_probs = doc_class_probs[np.arange(pca_doc_repr.shape[0]), doc_class_assignment]

print("Evaluate Document Cosine Similarity Predictions: ")
evaluate_predictions(gold_labels, doc_class_assignment)

Evaluate Document Cosine Similarity Predictions: 
F1 micro: 0.5275942132523664
F1 macro: 0.513313956403899


{'f1_micro': 0.5275942132523664, 'f1_macro': 0.513313956403899}

In [29]:
# class-oriented final doc
# cosine_similarities = cosine_similarity(pca_doc_repr, pca_class_repr)
# doc_class_assignment = np.argmax(cosine_similarities, axis=1)
# doc_class_probs = cosine_similarities[np.arange(pca_doc_repr.shape[0]), doc_class_assignment]
# doc_class_probs = updated_class_weights[np.arange(pca_doc_repr.shape[0]), doc_class_assignment]
doc_class_assignment = docToClassSet(pca_doc_repr, pca_class_repr, None)

print("Evaluate Document Cosine Similarity Predictions: ")
evaluate_predictions(gold_labels, doc_class_assignment)

Evaluate Document Cosine Similarity Predictions: 
F1 micro: 0.7898270941749203
F1 macro: 0.7651374709330063


{'confusion': [[4713, 47, 73, 24, 34],
  [71, 3703, 55, 105, 45],
  [1309, 136, 1771, 587, 149],
  [24, 81, 34, 1678, 808],
  [12, 17, 18, 127, 2250]],
 'f1_micro': 0.7898270941749203,
 'f1_macro': 0.7651374709330063}

In [146]:
# class-oriented average
cosine_similarities = cosine_similarity(pca_doc_repr, pca_class_repr)
doc_class_assignment = np.argmax(cosine_similarities, axis=1)
doc_class_probs = cosine_similarities[np.arange(pca_doc_repr.shape[0]), doc_class_assignment]

print("Evaluate Document Cosine Similarity Predictions: ")
evaluate_predictions(gold_labels, doc_class_assignment)

Evaluate Document Cosine Similarity Predictions: 
F1 micro: 0.8893947368421051
F1 macro: 0.8893947349271942


{'confusion': [[16901, 2099], [2104, 16896]],
 'f1_micro': 0.8893947368421051,
 'f1_macro': 0.8893947349271942}

## Save Pseudo Training Dataset ##

In [23]:
def write_to_dir(text, labels, prob, data_path, new_data_path):
    assert len(text) == len(labels)
    print("Saving files in:", new_data_path)
    
    with open(os.path.join(new_data_path, "dataset.txt"), "w") as f:
        for i, line in enumerate(text):
            f.write(line)
            f.write("\n")

    with open(os.path.join(new_data_path, "labels.txt"), "w") as f:
        for i, line in enumerate(labels):
            f.write(str(line))
            f.write("\n")

    with open(os.path.join(new_data_path, "probs.txt"), "w") as f:
        for i, line in enumerate(prob):
            f.write(str(line))
            #f.write(",".join(map(str, line)))
            f.write("\n")

    copyfile(os.path.join(data_path, "classes.txt"),
             os.path.join(new_data_path, "classes.txt"))

In [57]:
def generateDataset(documents_to_class, prob, num_classes, cleaned_text, gold_labels, data_path, new_data_path, thresh=0.5):
    pseudo_document_class_with_confidence = [[] for _ in range(num_classes)]
    for i in range(documents_to_class.shape[0]):
        pseudo_document_class_with_confidence[documents_to_class[i]].append((prob[i], i))

    selected = []
    confident_documents = [[] for _ in range(num_classes)]

    for i in range(num_classes):
        pseudo_document_class_with_confidence[i] = sorted(pseudo_document_class_with_confidence[i], key=lambda x: x[0], reverse=True)
        num_docs_to_take = int(len(pseudo_document_class_with_confidence[i]) * thresh)
        confident_documents[i] = pseudo_document_class_with_confidence[i][:num_docs_to_take]
        selected.extend([x[1] for x in confident_documents[i]])
    
    selected = sorted(selected)
    text = [cleaned_text[i] for i in selected]
    classes = [documents_to_class[i] for i in selected]
    probs = [prob[i] for i in selected]
    ###
    gold_classes = [gold_labels[i] for i in selected]
    evaluate_predictions(gold_classes, classes)
    ###
    # write_to_dir(text, classes, probs, data_path, new_data_path)
    return confident_documents

def updateClassSet(confident_docs, doc_emb, class_set):
    for i in np.arange(len(class_set)):
        doc_ind = [d[1] for d in confident_docs[i]]
        class_set[i] = np.concatenate((np.expand_dims(class_set[i][0], axis=0), doc_emb[doc_ind, :]), axis=0)
    return
        

In [54]:
doc_class_assignment, doc_class_probs = docToClassSet(final_doc_emb, init_class_set, None)
doc_class_probs = doc_class_probs[np.arange(final_doc_emb.shape[0]), doc_class_assignment]

print("Evaluate Document Cosine Similarity Predictions: ")
evaluate_predictions(gold_labels, doc_class_assignment)

Evaluate Document Cosine Similarity Predictions: 
F1 micro: 0.7743271221532091
F1 macro: 0.7638439775669024


{'f1_micro': 0.7743271221532091, 'f1_macro': 0.7638439775669024}

In [65]:
finalconf = generateDataset(doc_class_assignment, doc_class_probs, num_classes, cleaned_text, gold_labels, data_path, new_data_path, 0.5)
# updateClassSet(finalconf, final_doc_emb, class_set)

F1 micro: 0.6315350720495415
F1 macro: 0.620464459861108


In [153]:
for i in np.arange(len(finalconf)):
    print(f'{class_names[i]}: {len(finalconf[i])}')

business: 4850
politics: 4525
sports: 1571
health: 1182
education: 993
estate: 1229
arts: 604
science: 168
technology: 875


In [224]:
class_names

array(['business', 'politics', 'sports', 'health', 'education', 'estate',
       'arts', 'science', 'technology'], dtype='<U10')

In [66]:
# class_id = 2
for class_id in np.arange(len(class_names)):
    print("Class: ", class_names[class_id])
    # results = [doc_class_assignment[i[1]] == gold_labels[i[1]] for i in finalconf[class_id][:50]]
    # print(np.sum(results)/len(results))
    all_results = [doc_class_assignment[i[1]] == gold_labels[i[1]] for i in finalconf[class_id]]
    print(np.sum(all_results)/len(all_results))

Class:  children
0.9348837209302325
Class:  comics_graphic
0.9668556476232011
Class:  fantasy_paranormal
0.43478260869565216
Class:  history_biography
0.9775561097256857
Class:  mystery_thriller_crime
0.8326235093696763
Class:  poetry
0.5741728922091782
Class:  romance
0.4310094408133624
Class:  young_adult
0.40095238095238095


In [97]:
doc_id = finalconf[4][0][1]
cleaned_text[doc_id], doc_class_assignment[doc_id], gold_labels[doc_id]

 4,
 4)

In [96]:
finalconf[4][2000]

(0.7317021510492036, 9181)

In [27]:
def doc_gmm(args, document_representations, doc_class_representations, gold_labels):
    num_classes = len(doc_class_representations)
    cosine_similarities = cosine_similarity(document_representations, doc_class_representations)
    doc_class_assignment = np.argmax(cosine_similarities, axis=1)

    print("Evaluate Document Cosine Similarity Predictions: ")
    evaluate_predictions(gold_labels, doc_class_assignment)

    # initialize gmm based on these selected documents
    document_class_assignment_matrix = np.zeros((document_representations.shape[0], num_classes))
    for i in np.arange(len(document_representations)): # iterate through docs and sents
        document_class_assignment_matrix[i][doc_class_assignment[i]] = 1.0

    gmm = GaussianMixture(n_components=num_classes, covariance_type='tied',
                          random_state=args.random_state,
                          n_init=999, warm_start=True)

    gmm._initialize(document_representations, document_class_assignment_matrix)
    gmm.lower_bound_ = -np.infty

    gmm.converged_ = True #HACK FOR NOT RANDOMLY INITIALIZING PARAMS DURING FIT
    gmm.fit(document_representations)

    documents_to_class = gmm.predict(document_representations)
    confidence = gmm.predict_proba(document_representations)

    print("Evaluate Document GMM Predictions: ")
    evaluate_predictions(gold_labels, documents_to_class)

    return documents_to_class, confidence

In [32]:
pca_class_repr[0].shape

(1, 64)

In [34]:
# 128
doc_preds, doc_prob = doc_gmm(args, pca_doc_repr, np.array(pca_class_repr).squeeze(1), gold_labels)

Evaluate Document Cosine Similarity Predictions: 
F1 micro: 0.7002218958027315
F1 macro: 0.5808469084409915
Evaluate Document GMM Predictions: 
F1 micro: 0.7940431915492078
F1 macro: 0.6569828627562839
