## Setup ##

In [1]:
import numpy as np
import os
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup
import torch
from torch import nn
import sys
import torch.nn.functional as F
import argparse
import re
import pickle as pk
from tqdm import tqdm
from scipy.special import softmax
from sklearn.decomposition import PCA
from shutil import copyfile
from sklearn.mixture import GaussianMixture
from sklearn.mixture._gaussian_mixture import _estimate_gaussian_parameters
from nltk.tokenize import sent_tokenize, word_tokenize
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import confusion_matrix, f1_score
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler

  from .autonotebook import tqdm as notebook_tqdm


In [106]:
class Args(object):
    pass

args = Args()
args.dataset_name = "books"
args.gpu = 1
args.pca = 64
args.random_state = 42
args.emb_dim = 768
args.num_heads = 2
args.batch_size = 64
args.temp = 0.2
args.lr = 1e-3
args.epochs = 4
args.accum_steps = 1
args.max_sent = 150

os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

In [107]:
data_path = os.path.join("/shared/data2/pk36/multidim/multigran", args.dataset_name)
new_data_path = os.path.join("/home/pk36/MEGClass/intermediate_data", args.dataset_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# this is where we save all representations
if not os.path.exists(new_data_path):
    os.makedirs(new_data_path)

## Helper Functions ##

In [108]:
DATA_FOLDER_PATH = os.path.join("/shared/data2/pk36/multidim/multigran")
INTERMEDIATE_DATA_FOLDER_PATH = os.path.join("/home/pk36/MEGClass/intermediate_data")

def tensor_to_numpy(tensor):
    return tensor.clone().detach().cpu().numpy()

# Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

def sentenceToClass(sent_repr, class_repr, weights):
    # sent_repr: N x S x E
    # class_repr: C x E
    # weights: N x S # equals 0 for masked sentences

    #cos-sim between (N x S) x E and (C x E) = N x S x C
    m, n = sent_repr.shape[:2]
    sentcos = cosine_similarity(sent_repr.reshape(m*n,-1), class_repr).reshape(m,n,-1)
    sent_to_class = np.argmax(sentcos, axis=2) # N x S
    sent_to_doc_class = np.sum(np.multiply(sent_to_class, weights), axis=1) # N x 1
    return sent_to_doc_class

def docToClass(doc_repr, class_repr):
    # doc_repr: N x E
    # class_repr: C x E

    #cos-sim between N x E and C x E = N x C
    doccos = cosine_similarity(doc_repr, class_repr)
    doc_to_class = np.argmax(doccos, axis=1) # N x 1
    return doc_to_class

def evaluate_predictions(true_class, predicted_class, output_to_console=True, return_tuple=False, return_confusion=False):
    confusion = confusion_matrix(true_class, predicted_class)
    if return_confusion and output_to_console:
        print("-" * 80 + "Evaluating" + "-" * 80)
        print(confusion)
    
    f1_micro = f1_score(true_class, predicted_class, average='micro')
    f1_macro = f1_score(true_class, predicted_class, average='macro')
    if output_to_console:
        print("F1 micro: " + str(f1_micro))
        print("F1 macro: " + str(f1_macro))
    if return_tuple:
        return f1_macro, f1_micro
    else:
        return {
            "f1_micro": f1_micro,
            "f1_macro": f1_macro
        }
    
def getSentClassRepr(args):
    with open(os.path.join("/home/pk36/XClass/data/intermediate_data", args.dataset_name, f"document_repr_lm-bbu-12-mixture-plm.pk"), "rb") as f:
        dictionary = pk.load(f)
        class_repr = dictionary["class_representations"]
        sent_repr = dictionary["sent_representations"]
        if (args.dataset_name in ["nyt-location", "books", "nyt-fine-updated", "nyt-coarse-nophrase"]):
         return sent_repr, class_repr
        else:
            doc_repr = dictionary["document_representations"]
    return doc_repr, sent_repr, class_repr


def getDSMapAndGold(args, sent_dict):
    # get the ground truth labels for all documents and assign a "ground truth" label to each sentence based on its parent document
    gold_labels = list(map(int, open(os.path.join("/shared/data2/pk36/multidim/multigran", args.dataset_name, "labels.txt"), "r").read().splitlines()))
    gold_sent_labels = []
    # get all sent ids for each doc
    doc_to_sent = []
    sent_id = 0
    for doc_id, doc in enumerate(sent_dict.values()):
        sent_ids = []
        for sent in doc:
            sent_ids.append(sent_id)
            gold_sent_labels.append(gold_labels[doc_id])
            sent_id += 1
        doc_to_sent.append(sent_ids)
            
    return gold_labels, gold_sent_labels, doc_to_sent

## BERT-Based Sentence Embeddings, Initial Doc Embeddings, & Class Representations ##

In [109]:
def sentenceEmb(args, sent_dict, doc_to_sent, class_words, device, classonly=False):
    # Load model from HuggingFace Hub
    tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
    model = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')
    model = model.to(device)
    
    num_docs = len(doc_to_sent)
    padded_sent_repr = np.zeros((num_docs, args.max_sent, args.emb_dim))
    sentence_mask = np.ones((num_docs, args.max_sent))
    doc_lengths = np.zeros(num_docs, dtype=int)
    trimmed = 0

    if not classonly:
        for doc_id in tqdm(np.arange(num_docs)):
            sents = sent_dict[str(doc_id)]
            num_sent = len(sents)
            if num_sent > args.max_sent:
                trimmed += 1
                sents = sents[:args.max_sent]
                num_sent = args.max_sent
            encoded_input = tokenizer(sents, padding=True, truncation=True, return_tensors='pt')
            encoded_input = encoded_input.to(device)

            # Compute token embeddings
            with torch.no_grad():
                model_output = model(**encoded_input)

            # Perform pooling
            embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

            # Normalize embeddings
            embeddings = tensor_to_numpy(F.normalize(embeddings, p=2, dim=1))

            # save the number of sentences in each document
            doc_lengths[doc_id] = int(num_sent)

            padded_sent_repr[doc_id, :embeddings.shape[0], :] = embeddings
            # Update mask so that padded sentences are not included in attention computation
            sentence_mask[doc_id, :num_sent] = 0
        print(f"Trimmed Documents: {trimmed}")

    # construct class representations
    class_repr = np.zeros((len(class_words), args.emb_dim))
    intro_map = {"20News":"This article is about ", "agnews": "This article is about ",
                "yelp": "This is ", "nyt-coarse":"This article is about ", 
                "nyt-fine": "This article is about "}

    print("Constructing Class Representations...")
    for class_id in tqdm(np.arange(len(class_words))):
        # class_sent = intro_map[args.dataset_name] + ", ".join(class_words[class_id])
        class_sent = intro_map[args.dataset_name] + class_words[class_id][0]
        
        print(class_sent)
        encoded_input = tokenizer(class_sent, truncation=True, return_tensors='pt')
        encoded_input = encoded_input.to(device)
         # Compute token embeddings
        with torch.no_grad():
            model_output = model(**encoded_input)
        # Perform pooling
        embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
        # Normalize embeddings
        embeddings = tensor_to_numpy(F.normalize(embeddings, p=2, dim=1))
        class_repr[class_id, :] = embeddings
    
    if classonly:
        return class_repr
    else:
        return padded_sent_repr, class_repr, doc_lengths, sentence_mask

In [110]:
def bertSentenceEmb(args, doc_to_sent, sent_repr):
    num_docs = len(doc_to_sent)
    doc_lengths = np.zeros(num_docs, dtype=int)
    # init_doc_repr = np.zeros((num_docs, args.emb_dim))
    padded_sent_repr = np.zeros((num_docs, args.max_sent, args.emb_dim))
    sentence_mask = np.ones((num_docs, args.max_sent))
    trimmed = 0


    for doc_id in tqdm(np.arange(num_docs)):
        start_sent = doc_to_sent[doc_id][0]
        end_sent = doc_to_sent[doc_id][-1]
        num_sent = end_sent - start_sent + 1
        if num_sent > args.max_sent:
            end_sent = start_sent + args.max_sent - 1
            num_sent = args.max_sent
            trimmed += 1
        embeddings = sent_repr[start_sent:end_sent+1]

        # save the number of sentences in each document
        doc_lengths[doc_id] = int(num_sent)

        # Add initial doc representation
        # init_doc_repr[doc_id, :] = np.mean(embeddings, axis=0)
        # Add padded sentences
        padded_sent_repr[doc_id, :embeddings.shape[0], :] = embeddings
        # Update mask so that padded sentences are not included in attention computation
        sentence_mask[doc_id, :num_sent] = 0

    
    print(f"Trimmed Documents: {trimmed}")

    return padded_sent_repr, doc_lengths, sentence_mask

### Get Class Weights ###

In [111]:
def getTargetClasses(padded_sent_repr, doc_lengths, class_repr, weights=None):
    # weights: N x 150
    class_weights = np.zeros((padded_sent_repr.shape[0], class_repr.shape[0])) # N x C
    sent_weights = np.zeros(padded_sent_repr.shape[:2])

    for doc_id in tqdm(np.arange(padded_sent_repr.shape[0])):
        l = doc_lengths[doc_id]
        sent_emb = padded_sent_repr[doc_id, :l, :] # S x E
        sentcos = cosine_similarity(sent_emb, class_repr) # S x C
        sent_to_class = np.argmax(sentcos, axis=1) # S

        # add sentence centrality to weight
        # cent = np.zeros(l)
        # sent_to_sent = cosine_similarity(sent_emb, sent_emb) # S x S
        # sim = np.triu(sent_to_sent)
        # np.fill_diagonal(sim, 0)
        # sim = sim[:l-1, :]
        # simsum = np.divide(sim.sum(axis=1), np.arange(l - 1, 0, -1)) # S x 1
        # cent[:l-1] = simsum/(simsum.sum())
        
        # default: equal vote weight between all sentences
        if weights is None:
            # w = np.ones(doc_lengths[doc_id])/doc_lengths[doc_id]
            # top cos-sim - second cos-sim
            toptwo = np.partition(sentcos, -2)[:, -2:] # S x 2
            toptwo = toptwo[:, 1] - toptwo[:, 0] # S
            w = toptwo / np.sum(toptwo)
            # w = (w + cent)/2
            sent_weights[doc_id, :l] = w
        else:
            w = weights[doc_id, :l]
        
        class_weights[doc_id, :] = np.bincount(sent_to_class, weights=w, minlength=class_repr.shape[0])

    if weights is None:
        return class_weights, sent_weights
    else:
        return class_weights

In [112]:
def getTargetClassSet(padded_sent_repr, doc_lengths, class_set, alpha=None, set_weights=None):
    # sent_weights: N x 150 -> weigh each sentence based on its contribution to the document
    # set_weights: C x CD -> how confident each class-indicative document is

    class_weights = np.zeros((padded_sent_repr.shape[0], len(class_set))) # N x C
    if alpha is None:
        sent_weights = np.zeros(padded_sent_repr.shape[:2]) # N x 150

    for doc_id in tqdm(np.arange(padded_sent_repr.shape[0])):
        l = doc_lengths[doc_id]
        sent_emb = padded_sent_repr[doc_id, :l, :] # S x E
        sentclass_dist = np.zeros((l, len(class_set))) # S x C

        for class_id in np.arange(len(class_set)):
            sentcos = cosine_similarity(sent_emb, class_set[class_id]) # S x CD
            if set_weights is None:
                sentclass_sim = np.mean(sentcos, axis=1) # on average, how similar each sentence is to the class set
            else:
                sentclass_sim = np.average(sentcos, axis=1, weights=set_weights[class_id]) # same but weighted average based on class-indicativeness
            
            sentclass_dist[:, class_id] = sentclass_sim


        sent_to_class = np.argmax(sentclass_dist, axis=1) # S

        # add sentence centrality to weight
        # cent = np.zeros(l)
        # sent_to_sent = cosine_similarity(sent_emb, sent_emb) # S x S
        # sim = np.triu(sent_to_sent)
        # np.fill_diagonal(sim, 0)
        # sim = sim[:l-1, :]
        # simsum = np.divide(sim.sum(axis=1), np.arange(l - 1, 0, -1)) # S x 1
        # cent[:l-1] = simsum/(simsum.sum())
        
        # default: equal vote weight between all sentences
        if alpha is None:
            # top cos-sim - second cos-sim
            toptwo = np.partition(sentclass_dist, -2)[:, -2:] # S x 2
            toptwo = toptwo[:, 1] - toptwo[:, 0] # S
            w = toptwo / np.sum(toptwo)
            # w = (w + cent)/2
            sent_weights[doc_id, :l] = w
        else:
            w = alpha[doc_id, :l]
        
        class_weights[doc_id, :] = np.bincount(sent_to_class, weights=w, minlength=len(class_set))
        # class_weights[doc_id, :] = np.bincount(sent_to_class, minlength=len(class_set))

    if alpha is None:
        return class_weights, sent_weights
    else:
        return class_weights
    
def docToClassSet(doc_emb, class_set, set_weights=None):
    # set_weights: C x CD -> how confident each class-indicative document is

    class_dist = np.zeros((doc_emb.shape[0], len(class_set))) # D x C
    for class_id in np.arange(len(class_set)):
        doccos = cosine_similarity(doc_emb, class_set[class_id]) # D x CD
        if set_weights is None:
            cos_sim = np.mean(doccos, axis=1) # on average, how similar each document is to the class set
        else:
            cos_sim = np.average(doccos, axis=1, weights=set_weights[class_id]) # same but weighted average based on class-indicativeness
        
        class_dist[:, class_id] = cos_sim
    
    return np.argmax(class_dist, axis=1), class_dist

### Get initial embeddings and class representations + gold labels ###

In [113]:
with open(os.path.join("/home/pk36/XClass/data/intermediate_data", args.dataset_name, "dataset.pk"), "rb") as f:
    dataset = pk.load(f)
    sent_dict = dataset["sent_data"]
    cleaned_text = dataset["cleaned_text"]
    class_names = np.array(dataset["class_names"])
with open(os.path.join("/home/pk36/XClass/data/intermediate_data", args.dataset_name, "document_repr_lm-bbu-12-mixture-plm.pk"), "rb") as f:
    reprpickle = pk.load(f)
    class_words = reprpickle["class_words"]
gold_labels, gold_sent_labels, doc_to_sent = getDSMapAndGold(args, sent_dict)
num_classes = len(class_words)

### Get Representations ###

In [114]:
sent_repr, class_repr = getSentClassRepr(args)
init_class_set = [np.array([class_repr[i]]) for i in np.arange(len(class_repr))] # C x CD x E

In [115]:
class_set = [np.array([class_repr[i]]) for i in np.arange(len(class_repr))] # C x CD x E

padded_sent_repr, doc_lengths, sentence_mask = bertSentenceEmb(args, doc_to_sent, sent_repr)
init_class_weights, init_sent_weights = getTargetClassSet(padded_sent_repr, doc_lengths, class_set, alpha=None, set_weights=None)

100%|██████████| 33594/33594 [00:00<00:00, 49067.67it/s]


Trimmed Documents: 1


100%|██████████| 33594/33594 [01:02<00:00, 536.08it/s]


In [116]:
print("Class-Oriented:")
evaluate_predictions(gold_labels, np.argmax(init_class_weights, axis=1))

Class-Oriented:
F1 micro: 0.5027981187116747
F1 macro: 0.5051026475634897


{'f1_micro': 0.5027981187116747, 'f1_macro': 0.5051026475634897}

## Check Class Imbalance ##

In [34]:
print(class_words[list(class_names).index("health")])

['health', 'medical', 'healthcare', 'medicine', 'hospital', 'physician', 'care', 'hospitalization', 'medically', 'physicians', 'hospitals', 'doctor', 'clinical', 'nonmedical', 'surgeon', 'doctors', 'hospitalizations', 'cardiologists', 'obstetrician', 'radiology', 'cardiologist', 'pathologist', 'patient', 'pediatrician', 'pediatricians', 'obstetricians', 'neurologists', 'pathologists', 'clinicians', 'anesthesiologists', 'cardiology', 'medics', 'medic', 'radiologists', 'neurosurgery', 'anesthesiologist', 'gynecologist', 'neurologist', 'radiologist', 'paramedic', 'oncologist', 'neurosurgeon', 'immunologist', 'pharmacist', 'neurosurgeons', 'gynecologists', 'urologist', 'gastroenterologist', 'surgical', 'obstetric', 'obstetrics', 'pediatrics', 'biomedical', 'neurology', 'obstetrical', 'inpatient', 'outpatient', 'orthopedic', 'sanitarium', 'nurse', 'epidemiologists', 'gynecological', 'epidemiologist', 'ophthalmologist', 'diabetics', 'pharmacists', 'dermatology', 'diagnostic', 'dialysis', 'ou

### Precision & Recall ###

In [28]:
init_doc_preds = np.argmax(init_class_weights, axis=1)
curr_doc_preds = np.argmax(curr_class_weights, axis=1)

def computeConfusion(preds):
    class_imb = np.zeros((num_classes, 3)) # count FP, TP, FN
    miscls = np.zeros((num_classes, num_classes))

    for i in np.arange(len(gold_labels)):
        if gold_labels[i] == preds[i]:
            class_imb[gold_labels[i]][1] += 1
        else:
            class_imb[gold_labels[i]][2] += 1
            class_imb[preds[i]][0] += 1
            
        miscls[gold_labels[i]][preds[i]] += 1

    return class_imb, miscls

init_class_imb, init_miscls = computeConfusion(init_doc_preds)
curr_class_imb, curr_miscls = computeConfusion(curr_doc_preds)

for i in np.arange(num_classes):
    gold_prop = np.sum(gold_labels == i)/len(gold_labels)
    init_prop = np.sum(init_doc_preds == i)/len(init_doc_preds)
    curr_prop = np.sum(curr_doc_preds == i)/len(curr_doc_preds)
    # Precision: out of all of the docs labeled abortion, which ones were correct?
    # Recall: out of all of the true abortion docs, how many were correctly identified?
    print(f'{class_names[i]} (Init Total: {init_class_imb[i][1] + init_class_imb[i][2]}, {round(gold_prop,4)}, {init_class_imb[i]}, {round(init_prop, 4)}):')
    print(f'Precision: {init_class_imb[i][1]/(init_class_imb[i][0] + init_class_imb[i][1])}, Recall: {init_class_imb[i][1]/(init_class_imb[i][1] + init_class_imb[i][2])}')
    print(f'{class_names[i]} (Curr Total: {curr_class_imb[i][1] + curr_class_imb[i][2]}, {round(gold_prop,4)}, {curr_class_imb[i]}, {round(curr_prop, 4)}):')
    print(f'Precision: {curr_class_imb[i][1]/(curr_class_imb[i][0] + curr_class_imb[i][1])}, Recall: {curr_class_imb[i][1]/(curr_class_imb[i][1] + curr_class_imb[i][2])}')
    print("\n")

business (Init Total: 10891.0, 0.3404, [ 651. 9018. 1873.], 0.3022):
Precision: 0.932671424139001, Recall: 0.8280231383711322
business (Curr Total: 10891.0, 0.3404, [ 477. 9158. 1733.], 0.3011):
Precision: 0.950492994291645, Recall: 0.8408777890000918


politics (Init Total: 10221.0, 0.3194, [3791. 9216. 1005.], 0.4065):
Precision: 0.7085415545475513, Recall: 0.9016730261226886
politics (Curr Total: 10221.0, 0.3194, [2827. 8862. 1359.], 0.3653):
Precision: 0.7581486867995552, Recall: 0.8670384502494863


sports (Init Total: 2943.0, 0.092, [ 131. 2562.  381.], 0.0842):
Precision: 0.9513553657630895, Recall: 0.8705402650356778
sports (Curr Total: 2943.0, 0.092, [ 364. 2664.  279.], 0.0946):
Precision: 0.8797886393659181, Recall: 0.9051987767584098


health (Init Total: 2383.0, 0.0745, [ 435. 1170. 1213.], 0.0502):
Precision: 0.7289719626168224, Recall: 0.49097775912715064
health (Curr Total: 2383.0, 0.0745, [ 392. 1400.  983.], 0.056):
Precision: 0.78125, Recall: 0.5874947545111204


edu

### Check the most common misclassifications for each class ###

In [29]:
for i in np.arange(num_classes):
    print(f'Initial {class_names[i]}:')
    for c in np.arange(num_classes):
        if init_miscls[i][c] > 0:
            print(f'{class_names[c]}: {init_miscls[i][c]}', end=" ")
    print(f'\nCurrent {class_names[i]}:')
    for c in np.arange(num_classes):
        if curr_miscls[i][c] > 0:
            print(f'{class_names[c]}: {curr_miscls[i][c]}', end=" ")
    print("\n")

Initial business:
business: 9018.0 politics: 1459.0 sports: 30.0 health: 21.0 education: 21.0 estate: 161.0 arts: 7.0 science: 11.0 technology: 163.0 
Current business:
business: 9158.0 politics: 1206.0 sports: 60.0 health: 20.0 education: 22.0 estate: 190.0 arts: 16.0 science: 13.0 technology: 206.0 

Initial politics:
business: 221.0 politics: 9216.0 sports: 53.0 health: 361.0 education: 35.0 estate: 35.0 arts: 10.0 science: 19.0 technology: 271.0 
Current politics:
business: 200.0 politics: 8862.0 sports: 266.0 health: 298.0 education: 65.0 estate: 78.0 arts: 14.0 science: 10.0 technology: 428.0 

Initial sports:
business: 49.0 politics: 258.0 sports: 2562.0 health: 16.0 education: 22.0 estate: 31.0 arts: 4.0 technology: 1.0 
Current sports:
business: 29.0 politics: 155.0 sports: 2664.0 health: 17.0 education: 28.0 estate: 36.0 arts: 8.0 technology: 6.0 

Initial health:
business: 117.0 politics: 841.0 sports: 13.0 health: 1170.0 education: 60.0 estate: 51.0 arts: 8.0 science: 114.0

In [32]:
minority_cls = []
for i in np.arange(num_classes):
    prop = np.sum(gold_labels == i)/len(gold_labels)
    print(f'{class_names[i]}: ({np.sum(gold_labels == i)} docs), {prop}')
    if prop < (1/num_classes):
        minority_cls.append(i)
print("Minority Classes:")
print([class_names[c] for c in minority_cls])

business: (10891 docs), 0.34037566021814547
politics: (10221 docs), 0.3194361971434822
sports: (2943 docs), 0.09197737287870737
health: (2383 docs), 0.07447573209988437
education: (1783 docs), 0.05572397412257399
estate: (1648 docs), 0.05150482857767916
arts: (1214 docs), 0.03794105697409132
science: (512 docs), 0.016001500140638183
technology: (402 docs), 0.01256367784479795
Minority Classes:
['sports', 'health', 'education', 'estate', 'arts', 'science', 'technology']


In [112]:
class_size = np.array([np.sum(gold_labels == i)/len(gold_labels) for i in np.arange(len(class_names))])
class_size

array([0.34037566, 0.3194362 , 0.09197737, 0.07447573, 0.05572397,
       0.05150483, 0.03794106, 0.0160015 , 0.01256368])

In [126]:
np.log(class_size)/np.sum(np.log(class_size))

array([0.04338665, 0.04594275, 0.09606501, 0.10456231, 0.11623976,
       0.1194095 , 0.13171414, 0.16647125, 0.17620864])

In [10]:
plm_padded_sent_repr, plm_class_repr, doc_lengths, plm_sentence_mask = sentenceEmb(args, sent_dict, doc_to_sent, class_words, device)

init_plm_class_set = [np.array([plm_class_repr[i]]) for i in np.arange(len(plm_class_repr))] # C x CD x E
plm_class_set = [np.array([plm_class_repr[i]]) for i in np.arange(len(plm_class_repr))] # C x CD x E

init_plm_class_weights, init_plm_sent_weights = getTargetClassSet(plm_padded_sent_repr, doc_lengths, init_plm_class_set, alpha=None, set_weights=None)

100%|██████████| 31997/31997 [19:43<00:00, 27.04it/s]


Trimmed Documents: 155
Constructing Class Representations...


100%|██████████| 9/9 [00:00<00:00, 132.89it/s]


This article is about business
This article is about politics
This article is about sports
This article is about health
This article is about education
This article is about estate
This article is about arts
This article is about science
This article is about technology


100%|██████████| 31997/31997 [01:39<00:00, 320.20it/s]


## Model ##

In [117]:
class MEGClassModel(nn.Module):
    def __init__(self, D_in, D_hidden, head, dropout=0.0):
        super(MEGClassModel, self).__init__()
        self.mha = nn.MultiheadAttention(embed_dim=D_in, num_heads=head, dropout=dropout, batch_first=True)
        self.layernorm = nn.LayerNorm(D_in)
        self.embd = nn.Linear(D_in,D_hidden)
        self.attention = nn.Linear(D_hidden,1)
        
    def forward(self, x_org, mask=None):
        x, mha_w = self.mha(x_org,x_org,x_org,key_padding_mask=mask)
        x = self.layernorm(x_org+x)
        
        x = self.embd(x)
        x = torch.tanh(x) # contextualized sentences
        a = self.attention(x)
        if mask is not None:
            a = a.masked_fill_((mask == 1).unsqueeze(-1), float('-inf'))
        w = torch.softmax(a, dim=1) # alpha_k
        o = torch.matmul(w.permute(0,2,1), x) #doc 
        return o, mha_w, w, x # contextualized doc, multi-head attention weights, alpha_k, contextualized sent
    
    def new_forward(self, x_org, classes=None, mask=None):
        x, mha_w = self.mha(x_org,x_org,x_org,key_padding_mask=mask)
        x = self.layernorm(x_org+x)
        
        x = self.embd(x)
        x = torch.tanh(x) # contextualized sentences b x s x emb-dim
        a = self.attention(x) # b x s x 1

        cent = torch.zeros_like(a)
        disc = torch.zeros_like(a)
        for i in torch.arange(a.size(dim=0)):
            # sentence centrality
            if 1 in mask[i]:
                numsent = torch.argmax(mask[i].to(torch.long))
            else:
                numsent = mask[i].size(dim=0)
            xi_norm = x[i] / x[i].norm(dim=1)[:, None]
            sim = torch.mm(xi_norm[:numsent], xi_norm[:numsent].transpose(0,1)) # s x s
            sim = torch.triu(sim).fill_diagonal_(0)
            sim = sim[:numsent-1] # numsent-1 x numsent
            simsum = sim.sum(dim=1) # (numsent - 1) x 1
            simlen = torch.arange(numsent - 1, 0, -1).to(device)
            cent[i, :numsent-1] = nn.functional.normalize(torch.div(simsum, simlen).unsqueeze(-1), dim=0)

            # sentence class contrastive
            if classes is not None:
                classnorm = classes / classes.norm(dim=1)[:,  None]
                class_sim = torch.mm(xi_norm[:numsent], classnorm.transpose(0,1)) # s x c
                toptwo = torch.topk(class_sim, 2, dim=1).values # S x 2
                toptwo = torch.sub(toptwo[:, 0], toptwo[:, 1]) # s
                disc[i, :numsent] = torch.div(toptwo, torch.sum(toptwo)).unsqueeze(-1)

        # print(cent[5].flatten())
        # print(disc[5].flatten())
        # print(a[5].flatten())
        # 1/0
        a = torch.add(a, cent)
        a = torch.add(a, disc)

        if mask is not None:
            a = a.masked_fill_((mask == 1).unsqueeze(-1), float('-inf'))
        w = torch.softmax(a, dim=1) # alpha_k
        o = torch.matmul(w.permute(0,2,1), x) #doc 
        return o, mha_w, w, x # contextualized doc, multi-head attention weights, alpha_k, contextualized sent


## Contrastive Loss ##

In [118]:
def contrastive_loss(sample_outputs, class_indices, class_embds, temp=0.2):
    k = torch.exp(torch.nn.functional.cosine_similarity(sample_outputs[:,None], class_embds, axis=2)/temp)
    loss = -1*(torch.log(k[np.arange(len(class_indices)),class_indices]/k.sum(1))).sum()
    return loss/len(sample_outputs)

def weighted_class_contrastive_loss(args, sample_outputs, class_weights, class_embds):
    # k: B x C, class_weights: B x C
    numerator = torch.exp(torch.nn.functional.cosine_similarity(sample_outputs[:,None], class_embds, axis=2)/args.temp)
    denom = numerator.sum(dim=1).unsqueeze(-1)
    weighted_loss = -1 * (torch.log(numerator/(denom)) * class_weights).sum() # B x C -> B
    return weighted_loss/len(sample_outputs)

def weighted_contrastive_doc_sent_loss(args, sample_outputs, sent_output, sent_weights, sent_mask, class_weights, class_embds):
    # k: B x C, class_weights: B x C
    # weighted loss for each document repr and class repr

    numerator = torch.exp(torch.nn.functional.cosine_similarity(sample_outputs[:,None], class_embds, axis=2)/args.temp)
    denom = numerator.sum(dim=1).unsqueeze(-1)
    weighted_loss = -1 * (torch.log(numerator/(denom)) * class_weights).sum() # B x C -> B

    return weighted_loss/len(sample_outputs)


def classSetContrastiveLoss(args, sample_outputs, class_weights, class_set, device, set_weights=None):
    # sample_outputs: B x E, class_weights: B x C, class_set: C x CD x E, set_weights: C x CD
    # flatten_class_set = np.array([class_set[i][j] for i in np.arange(len(class_set)) for j in np.arange(len(class_set[i]))])
    # flatten_class_set = torch.cat(class_set, dim=0)
    # denom = torch.exp(torch.nn.functional.cosine_similarity(sample_outputs[:, None], flatten_class_set, dim=2)/args.temp) # B x (C*CD)
    # if set_weights is None:
    #     denom = torch.sum(denom, dim=1).unsqueeze(-1) # B
    # else:
    #     denom = torch.sum(denom * set_weights.reshape(-1,), dim=1).unsqueeze(-1) # B
    weighted_loss = torch.zeros((sample_outputs.size(dim=0), len(class_set))).to(device) # B x C
    for i in np.arange(len(class_set)):
        # similarity between contextualized docs and class-indicative docs
        numerator = torch.exp(torch.nn.functional.cosine_similarity(sample_outputs[:, None], class_set[i], dim=2)/args.temp) # B x CD
        weighted_loss[:, i] = numerator.mean(dim=1) # B x 1
        # if set_weights is None:
        #     numerator = torch.sum(numerator, dim=1) # B
        # else:
        #     numerator = torch.sum(torch.mul(numerator, set_weights[i]), dim=1) # B
        # if set_weights is None:
        
    return -1 * (torch.log(weighted_loss/weighted_loss.sum(dim=1).unsqueeze(-1)) * class_weights).sum() / len(sample_outputs)


## Train ##

In [119]:
def contextEmb(args, sent_representations, mask, class_repr, class_weights, 
                doc_lengths, new_data_path, device):
    sent_representations = torch.from_numpy(sent_representations)
    mask = torch.from_numpy(mask).to(torch.bool)
    class_weights = torch.from_numpy(class_weights)
    classes = torch.from_numpy(class_repr).float().to(device)
    dataset = TensorDataset(sent_representations, mask, class_weights)
    sampler = SequentialSampler(dataset)
    dataset_loader = DataLoader(dataset, sampler=sampler, batch_size=args.batch_size, shuffle=False)
    # sent_representations: N docs x L sentences x 768 emb (L with padding is always max_sents=50)
    model = MEGClassModel(args.emb_dim, args.emb_dim, args.num_heads).to(device)

    total_steps = len(dataset_loader) * args.epochs / args.accum_steps
    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, eps=1e-8)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0.1*total_steps, num_training_steps=total_steps)

    print("Starting to train!")

    for i in tqdm(range(args.epochs)):
        total_train_loss = 0

        for batch in tqdm(dataset_loader):
            model.train()
            input_emb = batch[0].to(device).float()
            input_mask = batch[1].to(device)
            input_weights = batch[2].to(device).float()
            
            c_doc, _, alpha, c_sent = model(input_emb, mask=input_mask)
            c_doc = c_doc.squeeze(1)

            loss = weighted_class_contrastive_loss(args, c_doc, input_weights, classes) / args.accum_steps

            total_train_loss += loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        scheduler.step()

        avg_train_loss = total_train_loss / len(dataset_loader)
        print(f"Average training loss: {avg_train_loss}")

    model.eval()

    torch.save(model.state_dict(), os.path.join(new_data_path, f"{args.dataset_name}_model_e{args.epochs}.pth"))

    print("Starting to evaluate!")

    evalsampler = SequentialSampler(dataset)
    eval_loader = DataLoader(dataset, sampler=evalsampler, batch_size=args.batch_size, shuffle=False)

    doc_predictions = None
    attention_weights = np.zeros_like(mask, dtype=float)
    updated_sent_repr = np.zeros_like(sent_representations)
    final_doc_emb = np.zeros((len(class_weights), args.emb_dim))
    idx = 0

    with torch.no_grad():
        for batch in tqdm(eval_loader):
            input_emb = batch[0].to(device).float()
            input_mask = batch[1].to(device)

            c_doc, _, alpha, c_sent = model(input_emb, mask=input_mask)
            c_doc = c_doc.squeeze(1)
            c_sent, c_doc, alpha = tensor_to_numpy(c_sent), tensor_to_numpy(c_doc), tensor_to_numpy(torch.squeeze(alpha, dim=2))

            final_doc_emb[idx:idx+c_doc.shape[0], :] = c_doc
            attention_weights[idx:idx+c_doc.shape[0], :] = alpha
            updated_sent_repr[idx:idx+c_doc.shape[0], :, :] = c_sent

            idx += c_doc.shape[0]

            doc_class = docToClass(c_doc, class_repr)
            if doc_predictions is None:
                doc_predictions = doc_class
            else:
                doc_predictions = np.append(doc_predictions, doc_class)
    
    updated_class_weights = getTargetClasses(updated_sent_repr, doc_lengths, class_repr, attention_weights)

    return doc_predictions, final_doc_emb, updated_sent_repr, attention_weights, updated_class_weights

In [167]:
def contextEmb(args, sent_representations, mask, class_set, class_weights, 
                doc_lengths, new_data_path, device):
    sent_representations = torch.from_numpy(sent_representations)
    mask = torch.from_numpy(mask).to(torch.bool)
    class_weights = torch.from_numpy(class_weights)
    dataset = TensorDataset(sent_representations, mask, class_weights)
    sampler = SequentialSampler(dataset)
    dataset_loader = DataLoader(dataset, sampler=sampler, batch_size=args.batch_size, shuffle=False)
    # sent_representations: N docs x L sentences x 768 emb (L with padding is always max_sents=50)
    model = MEGClassModel(args.emb_dim, args.emb_dim, args.num_heads).to(device)

    total_steps = len(dataset_loader) * args.epochs / args.accum_steps
    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, eps=1e-8)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0.1*total_steps, num_training_steps=total_steps)

    print("Starting to train!")

    for i in tqdm(range(args.epochs)):
        total_train_loss = 0

        for batch in tqdm(dataset_loader):
            model.train()
            input_emb = batch[0].to(device).float()
            input_mask = batch[1].to(device)
            input_weights = batch[2].to(device).float()
            
            c_doc, _, alpha, c_sent = model(input_emb, mask=input_mask)
            c_doc = c_doc.squeeze(1)
            torch_class_set = [torch.from_numpy(class_set[i]).float().to(device) for i in np.arange(len(class_set))]
            loss = classSetContrastiveLoss(args, c_doc, input_weights, torch_class_set, device)/ args.accum_steps
            # loss = weighted_class_contrastive_loss(args, c_doc, input_weights, torch.from_numpy(class_repr).float().to(device)) / args.accum_steps

            total_train_loss += loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        scheduler.step()

        avg_train_loss = total_train_loss / len(dataset_loader)
        print(f"Average training loss: {avg_train_loss}")

    model.eval()

    torch.save(model.state_dict(), os.path.join(new_data_path, f"{args.dataset_name}_model_e{args.epochs}.pth"))

    print("Starting to evaluate!")

    evalsampler = SequentialSampler(dataset)
    eval_loader = DataLoader(dataset, sampler=evalsampler, batch_size=args.batch_size, shuffle=False)

    doc_predictions = None
    attention_weights = np.zeros_like(mask, dtype=float)
    updated_sent_repr = np.zeros_like(sent_representations)
    final_doc_emb = np.zeros((len(class_weights), args.emb_dim))
    idx = 0

    with torch.no_grad():
        for batch in tqdm(eval_loader):
            input_emb = batch[0].to(device).float()
            input_mask = batch[1].to(device)

            c_doc, _, alpha, c_sent = model(input_emb, mask=input_mask)
            c_doc = c_doc.squeeze(1)
            c_sent, c_doc, alpha = tensor_to_numpy(c_sent), tensor_to_numpy(c_doc), tensor_to_numpy(torch.squeeze(alpha, dim=2))

            final_doc_emb[idx:idx+c_doc.shape[0], :] = c_doc
            attention_weights[idx:idx+c_doc.shape[0], :] = alpha
            updated_sent_repr[idx:idx+c_doc.shape[0], :, :] = c_sent

            idx += c_doc.shape[0]

            doc_class, _ = docToClassSet(c_doc, class_set)
            if doc_predictions is None:
                doc_predictions = doc_class
            else:
                doc_predictions = np.append(doc_predictions, doc_class)
    
    updated_class_weights = getTargetClassSet(updated_sent_repr, doc_lengths, class_set, attention_weights)

    return doc_predictions, final_doc_emb, updated_sent_repr, attention_weights, updated_class_weights

In [120]:
args.epochs = 4 # NYT-Fine IS 5
args.lr = 1e-3
args.temp = 0.1
doc_to_class, final_doc_emb, updated_sent_repr, updated_sent_weights, updated_class_weights = contextEmb(args, padded_sent_repr, sentence_mask, class_repr, init_class_weights, doc_lengths, new_data_path, device)



Starting to train!


100%|██████████| 525/525 [00:17<00:00, 29.19it/s]
 25%|██▌       | 1/4 [00:17<00:53, 17.99s/it]

Average training loss: 2.0641584396362305


100%|██████████| 525/525 [00:15<00:00, 33.59it/s]
 50%|█████     | 2/4 [00:33<00:33, 16.62s/it]

Average training loss: 1.7929142713546753


100%|██████████| 525/525 [00:14<00:00, 35.20it/s]
 75%|███████▌  | 3/4 [00:48<00:15, 15.85s/it]

Average training loss: 1.3813471794128418


100%|██████████| 525/525 [00:15<00:00, 34.00it/s]
100%|██████████| 4/4 [01:04<00:00, 16.01s/it]


Average training loss: 1.1902844905853271
Starting to evaluate!


100%|██████████| 525/525 [00:20<00:00, 25.78it/s]
100%|██████████| 33594/33594 [00:08<00:00, 3886.51it/s]


In [121]:
# epochs: 4
doc_pred = np.rint(doc_to_class)
print("Evaluate Predictions (Document-Based): ")
evaluate_predictions(gold_labels, doc_pred)

Evaluate Predictions (Document-Based): 
F1 micro: 0.5150026790498303
F1 macro: 0.518173483949235


{'f1_micro': 0.5150026790498303, 'f1_macro': 0.518173483949235}

### Second Iteration ###

In [139]:
curr_class_weights = getTargetClassSet(updated_sent_repr, doc_lengths, class_set, alpha=updated_sent_weights, set_weights=None)
evaluate_predictions(gold_labels, np.argmax(curr_class_weights, axis=1))

100%|██████████| 33594/33594 [13:01<00:00, 43.01it/s]


F1 micro: 0.5281300232184318
F1 macro: 0.5258843552347254


{'f1_micro': 0.5281300232184318, 'f1_macro': 0.5258843552347254}

In [33]:
args.epochs = 4
args.lr = 1e-3
args.temp = 0.1
doc_to_class, final_doc_emb, updated_sent_repr, updated_sent_weights, updated_class_weights = contextEmb(args, padded_sent_repr, sentence_mask, class_repr, curr_class_weights, doc_lengths, new_data_path, device)



Starting to train!


100%|██████████| 280/280 [00:07<00:00, 37.56it/s]
 25%|██▌       | 1/4 [00:07<00:22,  7.46s/it]

Average training loss: 1.624432921409607


100%|██████████| 280/280 [00:07<00:00, 35.93it/s]
 50%|█████     | 2/4 [00:15<00:15,  7.66s/it]

Average training loss: 1.3301043510437012


100%|██████████| 280/280 [00:07<00:00, 35.44it/s]
 75%|███████▌  | 3/4 [00:23<00:07,  7.78s/it]

Average training loss: 0.7405486106872559


100%|██████████| 280/280 [00:07<00:00, 36.34it/s]
100%|██████████| 4/4 [00:30<00:00,  7.73s/it]


Average training loss: 0.4643602669239044
Starting to evaluate!


100%|██████████| 280/280 [00:10<00:00, 27.59it/s]
100%|██████████| 17871/17871 [00:06<00:00, 2844.55it/s]


In [355]:
# epochs: 4; FOURTH ITER (USING INITIAL REPRESENTATIONS)
doc_pred = np.rint(doc_to_class)
print("Evaluate Predictions (Document-Based): ")
evaluate_predictions(gold_labels, doc_pred)

Evaluate Predictions (Document-Based): 
F1 micro: 0.7865256560908733
F1 macro: 0.7732321095083541


{'f1_micro': 0.7865256560908733, 'f1_macro': 0.7732321095083541}

In [345]:
# epochs: 4; THIRD ITER (USING INITIAL REPRESENTATIONS)
doc_pred = np.rint(doc_to_class)
print("Evaluate Predictions (Document-Based): ")
evaluate_predictions(gold_labels, doc_pred)

Evaluate Predictions (Document-Based): 
F1 micro: 0.7814895640982596
F1 macro: 0.7692942005311851


{'f1_micro': 0.7814895640982596, 'f1_macro': 0.7692942005311851}

In [34]:
# epochs: 4 SECOND ITER (USING INITIAL REPRESENTATIONS)
doc_pred = np.rint(doc_to_class)
print("Evaluate Predictions (Document-Based): ")
evaluate_predictions(gold_labels, doc_pred)

Evaluate Predictions (Document-Based): 
F1 micro: 0.7736556432208607
F1 macro: 0.7602823259977058


{'f1_micro': 0.7736556432208607, 'f1_macro': 0.7602823259977058}

In [25]:
def printClasses(x):
    for i in np.arange(num_classes):
        if x[i] > 0:
            print(f'{class_names[i]}: {x[i]}', end=" ")

In [34]:
doc_id = 1
print("Initial CO Class Weights:") 
printClasses(init_class_weights[doc_id])
# print("Initial PLM Class Weights: ", init_plm_class_weights[doc_id])

print("Updated Class Weights:")
printClasses(updated_class_weights[doc_id])

print("True Class: ", class_names[gold_labels[doc_id]])
tok_sents = sent_dict[str(doc_id)]
chosen_class = np.argmax(cosine_similarity(updated_sent_repr[doc_id, :len(tok_sents)], class_repr), axis=1)
sent_classes = cosine_similarity(updated_sent_repr[doc_id, :len(tok_sents)], class_repr)

for s,c,cw,w in zip(tok_sents, class_names[chosen_class.astype(int)], sent_classes, updated_sent_weights[doc_id, :len(tok_sents)]):
    print(f'{c} ({cw}) {w}: {s}')

Initial CO Class Weights:
computer: 0.6559191414637636 sports: 0.09092395911607942 politics: 0.25315689942015696 Updated Class Weights:
computer: 0.8252974152565002 politics: 0.17470255494117737 True Class:  science
computer ([ 0.22602685  0.04150834  0.0242996  -0.06390868 -0.11698047]) 0.2958510220050812: re lyme vaccine nntp-posting-host ucrengr x-newsreader tin version 1.1 pl8 jeff, if you have time to type it in i'd love to have the reference for that paper!
computer ([ 0.03197761 -0.06128826 -0.04751204 -0.04008535 -0.12241898]) 0.20030459761619568: thanks!
computer ([ 0.12055883  0.05709751 -0.00104686 -0.02052284 -0.10099893]) 0.3291417956352234: -- kathleen richards email sometimes you're the windshield, sometimes you're the bug!
politics ([-0.0189478  -0.06962198 -0.06771603 -0.00329312 -0.06759918]) 0.17470255494117737: -dire straits


In [34]:
# sentence centrality
doc_id = 46
print("Initial CO Class Weights:") 
printClasses(init_class_weights[doc_id])
# print("Initial PLM Class Weights: ", init_plm_class_weights[doc_id])

print("Updated Class Weights:")
printClasses(curr_class_weights[doc_id])

print("True Class: ", class_names[gold_labels[doc_id]])
tok_sents = sent_dict[str(doc_id)]
r = padded_sent_repr
chosen_class = np.argmax(cosine_similarity(r[doc_id, :len(tok_sents)], class_repr), axis=1)
sent_classes = cosine_similarity(r[doc_id, :len(tok_sents)], class_repr)

for s,c,cw,w in zip(tok_sents, class_names[chosen_class.astype(int)], sent_classes, init_sent_weights[doc_id, :len(tok_sents)]):
    print(f'{c} ({cw}) {w}: {s}')

Initial CO Class Weights:
business: 0.30759900965269593 politics: 0.41130068009186277 sports: 0.0035528031148018913 education: 0.17317906541533593 estate: 0.03225024527217161 science: 0.044217051177306024 technology: 0.027901145275826073 Updated Class Weights:
business: 0.28279584739357233 politics: 0.4073226824402809 education: 0.27804237697273493 estate: 0.017167776823043823 arts: 0.014671259559690952 True Class:  education
education ([0.57640734 0.57135406 0.55826009 0.46610187 0.66778041 0.5121223
 0.53071282 0.58463734 0.59064805]) 0.029201084216249156: in the first year of the new russia , a group of students enrolled in a new western style graduate school to learn a long forbidden academic discipline market economics .
business ([0.3393767  0.32187175 0.31154182 0.27653444 0.31493925 0.29866308
 0.30016655 0.30211372 0.32391569]) 0.005853290864809215: they were heady days .
politics ([0.47754115 0.54335834 0.45258943 0.44714381 0.43593186 0.42433767
 0.46249914 0.43247215 0.4578

## Run PCA on Document Embeddings & Fit GMM ##

In [134]:
_pca = PCA(n_components=16, random_state=args.random_state)
pca_doc_repr = _pca.fit_transform(final_doc_emb)
# pca_sent_repr = _pca.fit_transform(updated_sent_repr.reshape((-1, 768)))
# pca_sent_repr = pca_sent_repr.reshape((-1, args.max_sent, 64))
# pca_doc_repr = _pca.fit_transform(np.sum(padded_sent_repr, axis=1)/doc_lengths.reshape((-1, 1)))
# pca_doc_repr = _pca.fit_transform(doc_repr)
# pca_class_repr = [_pca.transform(class_set[i]) for i in np.arange(len(class_set))]
pca_class_repr = [_pca.transform(init_class_set[i]) for i in np.arange(len(init_class_set))]
# pca_class_repr = _pca.transform(plm_class_repr)
print(f"Explained document variance: {sum(_pca.explained_variance_ratio_)}")

Explained document variance: 0.6766643806335045


In [135]:
doc_class_assignment, doc_class_cos = docToClassSet(pca_doc_repr, pca_class_repr, None)
print("Evaluate Document Cosine Similarity Predictions: ")
evaluate_predictions(gold_labels, doc_class_assignment)
# doc_class_assignment, doc_class_cos = docToClassSet(final_doc_emb, class_repr.reshape((26, 1, -1)), None)
# NEW: Use regularized softmax to transform cosine-similarities into probabilities and rank confident docs based on top-two ratio
reg_temp = 0.2
doc_class_dist = np.exp(doc_class_cos/reg_temp)/np.sum(np.exp(doc_class_cos/reg_temp), axis=1).reshape(-1,1)

# doc_assignment_probs = doc_class_dist[np.arange(pca_doc_repr.shape[0]), doc_class_assignment]
# doc_toptwo = np.partition(doc_class_dist, -2)[:, -2:] # N x 2
# doc_toptwo = np.partition(doc_class_cos, -2)[:, -2:] # N x 2
# doc_toptwo = doc_toptwo[:, 1] - doc_toptwo[:, 0] # N3

Evaluate Document Cosine Similarity Predictions: 
F1 micro: 0.5294993153539322
F1 macro: 0.5217685750130246


In [103]:
doc_id = 5
print(class_names[gold_labels[doc_id]],": ", cleaned_text[doc_id])
for idx, c in enumerate(class_names):
    print(c, ": ", doc_class_dist[doc_id][idx])

science :  Re Migraines and Estrogen In article (Peggy Wageman) writes I read that hormonal fluctuations can contribute to migraines, could taking supplemental estrogen (ERT) cause migraines? Any information I'm not sure it is the fluctuation so much as the estrogen level. Taking Premarin can certainly cause migraines in some women. -- ---------------------------------------------------------------------------- Gordon Banks N3JXP "Skepticism is the chastity of the intellect, and it is shameful to surrender it too soon." ----------------------------------------------------------------------------
computer :  0.00477684940169282
sports :  0.06798844065948896
science :  0.7312896166292748
politics :  0.0748134569428074
religion :  0.12113163636673609


In [29]:
# class-oriented final doc
# cosine_similarities = cosine_similarity(pca_doc_repr, pca_class_repr)
# doc_class_assignment = np.argmax(cosine_similarities, axis=1)
# doc_class_probs = cosine_similarities[np.arange(pca_doc_repr.shape[0]), doc_class_assignment]
# doc_class_probs = updated_class_weights[np.arange(pca_doc_repr.shape[0]), doc_class_assignment]
doc_class_assignment = docToClassSet(pca_doc_repr, pca_class_repr, None)

print("Evaluate Document Cosine Similarity Predictions: ")
evaluate_predictions(gold_labels, doc_class_assignment)

Evaluate Document Cosine Similarity Predictions: 
F1 micro: 0.7898270941749203
F1 macro: 0.7651374709330063


{'confusion': [[4713, 47, 73, 24, 34],
  [71, 3703, 55, 105, 45],
  [1309, 136, 1771, 587, 149],
  [24, 81, 34, 1678, 808],
  [12, 17, 18, 127, 2250]],
 'f1_micro': 0.7898270941749203,
 'f1_macro': 0.7651374709330063}

In [146]:
# class-oriented average
cosine_similarities = cosine_similarity(pca_doc_repr, pca_class_repr)
doc_class_assignment = np.argmax(cosine_similarities, axis=1)
doc_class_probs = cosine_similarities[np.arange(pca_doc_repr.shape[0]), doc_class_assignment]

print("Evaluate Document Cosine Similarity Predictions: ")
evaluate_predictions(gold_labels, doc_class_assignment)

Evaluate Document Cosine Similarity Predictions: 
F1 micro: 0.8893947368421051
F1 macro: 0.8893947349271942


{'confusion': [[16901, 2099], [2104, 16896]],
 'f1_micro': 0.8893947368421051,
 'f1_macro': 0.8893947349271942}

## Save Pseudo Training Dataset ##

In [136]:
def write_to_dir(text, labels, prob, data_path, new_data_path):
    assert len(text) == len(labels)
    print("Saving files in:", new_data_path)
    
    with open(os.path.join(new_data_path, "dataset.txt"), "w") as f:
        for i, line in enumerate(text):
            f.write(line)
            f.write("\n")

    with open(os.path.join(new_data_path, "labels.txt"), "w") as f:
        for i, line in enumerate(labels):
            f.write(str(line))
            f.write("\n")

    with open(os.path.join(new_data_path, "probs.txt"), "w") as f:
        for i, line in enumerate(prob):
            # f.write(str(line))
            f.write(",".join(map(str, line)))
            f.write("\n")

    copyfile(os.path.join(data_path, "classes.txt"),
             os.path.join(new_data_path, "classes.txt"))

In [137]:
def generateDataset(documents_to_class, ranks, class_dist, num_classes, cleaned_text, gold_labels, data_path, new_data_path, thresh=0.5):
    pseudo_document_class_with_confidence = [[] for _ in range(num_classes)]
    for i in range(documents_to_class.shape[0]):
        pseudo_document_class_with_confidence[documents_to_class[i]].append((ranks[i], i))

    selected = []
    confident_documents = [[] for _ in range(num_classes)]

    for i in range(num_classes):
        pseudo_document_class_with_confidence[i] = sorted(pseudo_document_class_with_confidence[i], key=lambda x: x[0], reverse=True)
        num_docs_to_take = int(len(pseudo_document_class_with_confidence[i]) * thresh)
        confident_documents[i] = pseudo_document_class_with_confidence[i][:num_docs_to_take]
        selected.extend([x[1] for x in confident_documents[i]])
    
    selected = sorted(selected)
    text = [cleaned_text[i] for i in selected]
    classes = [documents_to_class[i] for i in selected]
    probs = [class_dist[i] for i in selected]
    ###
    gold_classes = [gold_labels[i] for i in selected]
    evaluate_predictions(gold_classes, classes)
    ###
    # write_to_dir(text, classes, probs, data_path, new_data_path)
    return confident_documents

def updateClassSet(confident_docs, doc_emb, class_set):
    for i in np.arange(len(class_set)):
        doc_ind = [d[1] for d in confident_docs[i]]
        class_set[i] = np.concatenate((np.expand_dims(class_set[i][0], axis=0), doc_emb[doc_ind, :]), axis=0)
    return
        

In [138]:
thresh = 0.2
reg_temp = 0.2
doc_class_dist = np.exp(doc_class_cos/reg_temp)/np.sum(np.exp(doc_class_cos/reg_temp), axis=1).reshape(-1,1)
doc_rank = np.max(doc_class_cos, axis=1)

finalconf = generateDataset(doc_class_assignment, doc_rank, doc_class_dist, num_classes, cleaned_text, gold_labels, data_path, new_data_path, thresh)
updateClassSet(finalconf, final_doc_emb, class_set)

F1 micro: 0.7200714711137582
F1 macro: 0.7085271151734461


In [23]:
for i in np.arange(len(finalconf)):
    print(f'{class_names[i]}: {len(finalconf[i])}')

business: 708
politics: 673
sports: 231
health: 205
education: 162
estate: 190
arts: 91
science: 23
technology: 111


In [26]:
class_names

array(['computer', 'sports', 'science', 'politics', 'religion'],
      dtype='<U8')

In [27]:
# class_id = 2
for class_id in np.arange(len(class_names)):
    # results = [doc_class_assignment[i[1]] == gold_labels[i[1]] for i in finalconf[class_id][:50]]
    # print(np.sum(results)/len(results))
    all_results = [doc_class_assignment[i[1]] == gold_labels[i[1]] for i in finalconf[class_id]]
    print(class_id, class_names[class_id], ": ", np.sum(all_results)/len(all_results))

0 computer :  0.9274725274725275
1 sports :  1.0
2 science :  0.9886363636363636
3 politics :  0.8904109589041096
4 religion :  0.9903381642512077


In [43]:
doc_id = finalconf[4][0][1]
doc_id, cleaned_text[doc_id], class_names[doc_class_assignment[doc_id]], doc_class_dist[doc_id], class_names[gold_labels[doc_id]]

(5058,
 'Re Monophysites and Mike Walker Nabil Ayoub writes As a final note, the Oriental Orthodox and Eastren Orthodox did sign a common statement of Christology, in which the heresey of Monophysitism was condemned. So the Coptic Orthodox Church does not believe in Monophysitism. Sorry! What does the Coptic Church believe about the will and energy of Christ? Were there one or were there two (i.e. Human and Divine) wills and energies in Him. Also, what is the objection ot the Copts with the Pope of Rome (i.e. why is there a Coptic Catholic Church)? Do you reject the supreme jurisdiction of the 263rd sucessor of St. Peter (who blessed St. John Mark, Bishop of Alexandria was translator for) and his predecessors? Or his infallibility? Or what other things perhaps? Andy Byler',
 'religion',
 array([0.00452341, 0.01535639, 0.01955483, 0.02291949, 0.93764589]),
 'religion')

In [57]:
cleaned_text[doc_id]

"medical journals are the prime source of information about scientific advances that can change how doctors treat patients in offices and in hospitals . and to ensure the quality of what journals publish , their editors , beginning 200 years ago , have increasingly called on scientific peers to review new findings from research in test tubes and on animals and humans . the system , known as peer_review , is now considered a linchpin of science . editors of the journals and many scientists consider the system 's expense and time consumption worthwhile in the belief that it weeds out shoddy work and methodological errors and blunts possible biases by scientific investigators . another main aim is to prevent authors from making claims that cannot be supported by the evidence they report . yet for all its acclaim , the system has long been controversial . despite its system of checks and balances , a number of errors , plagiarism and even outright fraud have slipped through it . at the sam

In [64]:
print("Initial CO Class Weights: ", init_class_weights[doc_id])

print("Updated Class Weights: ", updated_class_weights[doc_id])
print("True Class: ", gold_labels[doc_id])
tok_sents = sent_dict[str(doc_id)]

sent_class_cos = cosine_similarity(padded_sent_repr[doc_id, :len(tok_sents)], class_repr)
reg_temp = 0.5
sent_class_dist = np.exp(sent_class_cos/reg_temp)/np.sum(np.exp(sent_class_cos/reg_temp), axis=1).reshape(-1,1)

chosen_class = np.argmax(sent_class_cos, axis=1)
# sent_classes = cosine_similarity(padded_sent_repr[doc_id, :len(tok_sents)], class_repr)

for c,cw,w,s in zip(class_names[chosen_class.astype(int)], sent_class_dist, updated_sent_weights[doc_id, :len(tok_sents)], tok_sents):
    print(f'{c} ({cw}) {w}: {s}')

Initial CO Class Weights:  [0.0036531  0.         0.         0.05012046 0.         0.
 0.         0.94373815 0.00248829]
Updated Class Weights:  [0.         0.         0.         0.         0.         0.
 0.         0.99999996 0.        ]
True Class:  3
science ([0.11058517 0.1084116  0.10210798 0.11702134 0.10733373 0.09862915
 0.10505607 0.12718393 0.12367104]) 0.009585914145700859: a new study in which mouse brains were examined says prolonged stress and drugs like those given to troops in the persian_gulf_war can have long lasting effects on the brain .
science ([0.11192069 0.11428407 0.10495972 0.11785007 0.10390442 0.10154774
 0.10696098 0.12862321 0.1099491 ]) 0.029937306097804286: the results , the researchers say , could help explain two controversial afflictions of some service members post_traumatic_stress syndrome and gulf_war_syndrome .
science ([0.1090294  0.10522888 0.10683364 0.11248761 0.10348834 0.09249359
 0.10123969 0.15034217 0.11885668]) 0.08042533166769485: the r

In [96]:
finalconf[4][2000]

(0.7317021510492036, 9181)

In [27]:
def doc_gmm(args, document_representations, doc_class_representations, gold_labels):
    num_classes = len(doc_class_representations)
    cosine_similarities = cosine_similarity(document_representations, doc_class_representations)
    doc_class_assignment = np.argmax(cosine_similarities, axis=1)

    print("Evaluate Document Cosine Similarity Predictions: ")
    evaluate_predictions(gold_labels, doc_class_assignment)

    # initialize gmm based on these selected documents
    document_class_assignment_matrix = np.zeros((document_representations.shape[0], num_classes))
    for i in np.arange(len(document_representations)): # iterate through docs and sents
        document_class_assignment_matrix[i][doc_class_assignment[i]] = 1.0

    gmm = GaussianMixture(n_components=num_classes, covariance_type='tied',
                          random_state=args.random_state,
                          n_init=999, warm_start=True)

    gmm._initialize(document_representations, document_class_assignment_matrix)
    gmm.lower_bound_ = -np.infty

    gmm.converged_ = True #HACK FOR NOT RANDOMLY INITIALIZING PARAMS DURING FIT
    gmm.fit(document_representations)

    documents_to_class = gmm.predict(document_representations)
    confidence = gmm.predict_proba(document_representations)

    print("Evaluate Document GMM Predictions: ")
    evaluate_predictions(gold_labels, documents_to_class)

    return documents_to_class, confidence

In [32]:
pca_class_repr[0].shape

(1, 64)

In [34]:
# 128
doc_preds, doc_prob = doc_gmm(args, pca_doc_repr, np.array(pca_class_repr).squeeze(1), gold_labels)

Evaluate Document Cosine Similarity Predictions: 
F1 micro: 0.7002218958027315
F1 macro: 0.5808469084409915
Evaluate Document GMM Predictions: 
F1 micro: 0.7940431915492078
F1 macro: 0.6569828627562839
