## Setup ##

In [53]:
import numpy as np
import os
from transformers import AdamW, get_linear_schedule_with_warmup
import torch
from torch import nn
import sys
import torch.nn.functional as F
import argparse
import re
import pickle as pk
from tqdm import tqdm
from scipy.special import softmax
from sklearn.decomposition import PCA
from shutil import copyfile
from sklearn.mixture import GaussianMixture
from sklearn.mixture._gaussian_mixture import _estimate_gaussian_parameters
from nltk.tokenize import sent_tokenize, word_tokenize
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import confusion_matrix, f1_score
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler

In [54]:
class Args(object):
    pass

args = Args()
args.dataset_name = "yelp"
args.gpu = 1
args.pca = 128
args.random_state = 42
args.emb_dim = 768
args.num_heads = 2
args.batch_size = 64
args.temp = 0.2
args.lr = 1e-3
args.epochs = 5
args.accum_steps = 1
args.max_sent = 150

os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

In [55]:
data_path = os.path.join("/shared/data2/pk36/multidim/multigran", args.dataset_name)
new_data_path = os.path.join("/home/pk36/MEGClass/intermediate_data", args.dataset_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# this is where we save all representations
if not os.path.exists(new_data_path):
    os.makedirs(new_data_path)

## Helper Functions ##

In [56]:
DATA_FOLDER_PATH = os.path.join("/shared/data2/pk36/multidim/multigran")
INTERMEDIATE_DATA_FOLDER_PATH = os.path.join("/home/pk36/MEGClass/intermediate_data")

def tensor_to_numpy(tensor):
    return tensor.clone().detach().cpu().numpy()

def sentenceToClass(sent_repr, class_repr, weights):
    # sent_repr: N x S x E
    # class_repr: C x E
    # weights: N x S # equals 0 for masked sentences

    #cos-sim between (N x S) x E and (C x E) = N x S x C
    m, n = sent_repr.shape[:2]
    sentcos = cosine_similarity(sent_repr.reshape(m*n,-1), class_repr).reshape(m,n,-1)
    sent_to_class = np.argmax(sentcos, axis=2) # N x S
    sent_to_doc_class = np.sum(np.multiply(sent_to_class, weights), axis=1) # N x 1
    return sent_to_doc_class

def docToClass(doc_repr, class_repr):
    # doc_repr: N x E
    # class_repr: C x E

    #cos-sim between N x E and C x E = N x C
    doccos = cosine_similarity(doc_repr, class_repr)
    doc_to_class = np.argmax(doccos, axis=1) # N x 1
    return doc_to_class

def evaluate_predictions(true_class, predicted_class, output_to_console=True, return_tuple=False, return_confusion=False):
    confusion = confusion_matrix(true_class, predicted_class)
    if return_confusion and output_to_console:
        print("-" * 80 + "Evaluating" + "-" * 80)
        print(confusion)
    
    f1_micro = f1_score(true_class, predicted_class, average='micro')
    f1_macro = f1_score(true_class, predicted_class, average='macro')
    if output_to_console:
        print("F1 micro: " + str(f1_micro))
        print("F1 macro: " + str(f1_macro))
    if return_tuple:
        return confusion, f1_macro, f1_micro
    else:
        return {
            "confusion": confusion.tolist(),
            "f1_micro": f1_micro,
            "f1_macro": f1_macro
        }
    
def getSentClassRepr(args):
    with open(os.path.join("/home/pk36/XClass/data/intermediate_data", args.dataset_name, f"document_repr_lm-bbu-12-mixture-plm.pk"), "rb") as f:
        dictionary = pk.load(f)
        class_repr = dictionary["class_representations"]
        sent_repr = dictionary["sent_representations"]
    return sent_repr, class_repr


def getDSMapAndGold(args, sent_dict):
    # get the ground truth labels for all documents and assign a "ground truth" label to each sentence based on its parent document
    gold_labels = list(map(int, open(os.path.join("/shared/data2/pk36/multidim/multigran", args.dataset_name, "labels.txt"), "r").read().splitlines()))
    gold_sent_labels = []
    # get all sent ids for each doc
    doc_to_sent = []
    sent_id = 0
    for doc_id, doc in enumerate(sent_dict.values()):
        sent_ids = []
        for sent in doc:
            sent_ids.append(sent_id)
            gold_sent_labels.append(gold_labels[doc_id])
            sent_id += 1
        doc_to_sent.append(sent_ids)
            
    return gold_labels, gold_sent_labels, doc_to_sent

## BERT-Based Sentence Embeddings, Initial Doc Embeddings, & Class Representations ##

In [57]:
def bertSentenceEmb(args, doc_to_sent, sent_repr):
    num_docs = len(doc_to_sent)
    doc_lengths = np.zeros(num_docs, dtype=int)
    # init_doc_repr = np.zeros((num_docs, args.emb_dim))
    padded_sent_repr = np.zeros((num_docs, args.max_sent, args.emb_dim))
    sentence_mask = np.ones((num_docs, args.max_sent))
    trimmed = 0


    for doc_id in tqdm(np.arange(num_docs)):
        start_sent = doc_to_sent[doc_id][0]
        end_sent = doc_to_sent[doc_id][-1]
        num_sent = end_sent - start_sent + 1
        if num_sent > args.max_sent:
            end_sent = start_sent + args.max_sent - 1
            num_sent = args.max_sent
            trimmed += 1
        embeddings = sent_repr[start_sent:end_sent+1]

        # save the number of sentences in each document
        doc_lengths[doc_id] = int(num_sent)

        # Add initial doc representation
        # init_doc_repr[doc_id, :] = np.mean(embeddings, axis=0)
        # Add padded sentences
        padded_sent_repr[doc_id, :embeddings.shape[0], :] = embeddings
        # Update mask so that padded sentences are not included in attention computation
        sentence_mask[doc_id, :num_sent] = 0

    
    print(f"Trimmed Documents: {trimmed}")

    return padded_sent_repr, doc_lengths, sentence_mask

### Get Class Weights ###

In [58]:
def getTargetClasses(padded_sent_repr, doc_lengths, class_repr, weights=None):
    # weights: N x 150
    class_weights = np.zeros((padded_sent_repr.shape[0], class_repr.shape[0])) # N x C
    sent_weights = np.zeros(padded_sent_repr.shape[:2])

    for doc_id in tqdm(np.arange(padded_sent_repr.shape[0])):
        l = doc_lengths[doc_id]
        sent_emb = padded_sent_repr[doc_id, :l, :] # S x E
        sentcos = cosine_similarity(sent_emb, class_repr) # S x C
        sent_to_class = np.argmax(sentcos, axis=1) # S
        
        # default: equal vote weight between all sentences
        if weights is None:
            # w = np.ones(doc_lengths[doc_id])/doc_lengths[doc_id]
            # top cos-sim - second cos-sim
            toptwo = np.partition(sentcos, -2)[:, -2:] # S x 2
            toptwo = toptwo[:, 1] - toptwo[:, 0] # S
            w = toptwo / np.sum(toptwo)
            sent_weights[doc_id, :l] = w
        else:
            w = weights[doc_id, :l]
        
        class_weights[doc_id, :] = np.bincount(sent_to_class, weights=w, minlength=class_repr.shape[0])

    if weights is None:
        return class_weights, sent_weights
    else:
        return class_weights

### Get initial embeddings and class representations + gold labels ###

In [59]:
sent_repr, class_repr = getSentClassRepr(args)
num_classes = class_repr.shape[0]
with open(os.path.join("/home/pk36/XClass/data/intermediate_data", args.dataset_name, "dataset.pk"), "rb") as f:
    dataset = pk.load(f)
    sent_dict = dataset["sent_data"]
    cleaned_text = dataset["cleaned_text"]
    class_names = np.array(dataset["class_names"])
gold_labels, gold_sent_labels, doc_to_sent = getDSMapAndGold(args, sent_dict)

In [60]:

padded_sent_repr, doc_lengths, sentence_mask = bertSentenceEmb(args, doc_to_sent, sent_repr)
init_class_weights, init_sent_weights = getTargetClasses(padded_sent_repr, doc_lengths, class_repr, None)
padded_sent_repr.shape, doc_lengths.shape, sentence_mask.shape, init_class_weights.shape, init_sent_weights.shape

100%|██████████| 38000/38000 [00:00<00:00, 48620.76it/s]


Trimmed Documents: 0


100%|██████████| 38000/38000 [00:09<00:00, 4009.55it/s]


((38000, 150, 768), (38000,), (38000, 150), (38000, 2), (38000, 150))

### Analyze and Evaluate Class Weights ###

In [61]:
evaluate_predictions(gold_labels, np.argmax(init_class_weights, axis=1))

F1 micro: 0.7816578947368421
F1 macro: 0.7744238318317597


{'confusion': [[11449, 7551], [746, 18254]],
 'f1_micro': 0.7816578947368421,
 'f1_macro': 0.7744238318317597}

## Model ##

In [62]:
class MEGClassModel(nn.Module):
    def __init__(self, D_in, D_hidden, head, dropout=0.0):
        super(MEGClassModel, self).__init__()
        self.mha = nn.MultiheadAttention(embed_dim=D_in, num_heads=head, dropout=dropout, batch_first=True)
        self.layernorm = nn.LayerNorm(D_in)
        # self.embd = nn.Sequential(
        #     nn.Linear(D_in, 2*D_in),
        #     nn.ReLU(inplace=True),
        #     nn.Linear(2*D_in, D_in))
        self.embd = nn.Linear(D_in,D_hidden)
        self.attention = nn.Linear(D_hidden,1)
        
    def forward(self, x_org, mask=None):
        x, mha_w = self.mha(x_org,x_org,x_org,key_padding_mask=mask)
        x = self.layernorm(x_org+x)
        
        x = self.embd(x)
        x = torch.tanh(x) # contextualized sentences
        a = self.attention(x)
        if mask is not None:
            a = a.masked_fill_((mask == 1).unsqueeze(-1), float('-inf'))
        w = torch.softmax(a, dim=1) # alpha_k
        o = torch.matmul(w.permute(0,2,1), x) #doc 
        return o, mha_w, w, x # contextualized doc, multi-head attention weights, alpha_k, contextualized sent


## Contrastive Loss ##

In [63]:
def contrastive_loss(sample_outputs, class_indices, class_embds, temp=0.2):
    k = torch.exp(torch.nn.functional.cosine_similarity(sample_outputs[:,None], class_embds, axis=2)/temp)
    loss = -1*(torch.log(k[np.arange(len(class_indices)),class_indices]/k.sum(1))).sum()
    return loss/len(sample_outputs)

def weighted_contrastive_loss(args, sample_outputs, class_weights, class_embds):
    # k: B x C, class_weights: B x C
    k = torch.exp(torch.nn.functional.cosine_similarity(sample_outputs[:,None], class_embds, axis=2)/args.temp)
    weighted_loss = -1 * (torch.log(k/(k.sum(dim=1).unsqueeze(-1))) * class_weights).sum() # B x C -> B
    return weighted_loss/len(sample_outputs)

def weighted_class_contrastive_loss(sample_outputs, class_weights, class_embds, temp=0.2):
    # k: B x C, class_weights: B x C, class_embds: C x E
    target = class_weights @ class_embds # B x E
    target_cos = torch.exp(torch.nn.functional.cosine_similarity(sample_outputs[:,None], target).diagonal(offset=0)/temp)
    k = torch.exp(torch.nn.functional.cosine_similarity(sample_outputs[:,None], class_embds, axis=2)/temp)
    loss = -1*(torch.log(target_cos/k.sum(1))).sum()
    return loss/len(sample_outputs)


## Train ##

In [64]:
def contextEmb(args, sent_representations, mask, class_repr, class_weights, 
                doc_lengths, new_data_path, device):
    sent_representations = torch.from_numpy(sent_representations)
    mask = torch.from_numpy(mask).to(torch.bool)
    class_weights = torch.from_numpy(class_weights)
    dataset = TensorDataset(sent_representations, mask, class_weights)
    sampler = SequentialSampler(dataset)
    dataset_loader = DataLoader(dataset, sampler=sampler, batch_size=args.batch_size, shuffle=False)
    # sent_representations: N docs x L sentences x 768 emb (L with padding is always max_sents=50)
    model = MEGClassModel(args.emb_dim, args.emb_dim, args.num_heads).to(device)

    total_steps = len(dataset_loader) * args.epochs / args.accum_steps
    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, eps=1e-8)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0.1*total_steps, num_training_steps=total_steps)

    print("Starting to train!")

    for i in tqdm(range(args.epochs)):
        total_train_loss = 0

        for batch in tqdm(dataset_loader):
            model.train()
            input_emb = batch[0].to(device).float()
            input_mask = batch[1].to(device)
            input_weights = batch[2].to(device).float()
            
            c_doc, _, alpha, c_sent = model(input_emb, mask=input_mask)
            c_doc = c_doc.squeeze(1)

            loss = weighted_contrastive_loss(args, c_doc, input_weights, torch.from_numpy(class_repr).float().to(device)) / args.accum_steps

            total_train_loss += loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        scheduler.step()

        avg_train_loss = total_train_loss / len(dataset_loader)
        print(f"Average training loss: {avg_train_loss}")

    model.eval()

    torch.save(model.state_dict(), os.path.join(new_data_path, f"{args.dataset_name}_model_e{args.epochs}.pth"))

    print("Starting to evaluate!")

    evalsampler = SequentialSampler(dataset)
    eval_loader = DataLoader(dataset, sampler=evalsampler, batch_size=args.batch_size, shuffle=False)

    doc_predictions = None
    attention_weights = np.zeros_like(mask, dtype=float)
    updated_sent_repr = np.zeros_like(sent_representations)
    final_doc_emb = np.zeros((len(class_weights), args.emb_dim))
    idx = 0

    with torch.no_grad():
        for batch in tqdm(eval_loader):
            input_emb = batch[0].to(device).float()
            input_mask = batch[1].to(device)

            c_doc, _, alpha, c_sent = model(input_emb, mask=input_mask)
            c_doc = c_doc.squeeze(1)
            c_sent, c_doc, alpha = tensor_to_numpy(c_sent), tensor_to_numpy(c_doc), tensor_to_numpy(torch.squeeze(alpha, dim=2))

            final_doc_emb[idx:idx+c_doc.shape[0], :] = c_doc
            attention_weights[idx:idx+c_doc.shape[0], :] = alpha
            updated_sent_repr[idx:idx+c_doc.shape[0], :, :] = c_sent

            idx += c_doc.shape[0]

            doc_class = docToClass(c_doc, class_repr)
            if doc_predictions is None:
                doc_predictions = doc_class
            else:
                doc_predictions = np.append(doc_predictions, doc_class)
    
    updated_class_weights = getTargetClasses(updated_sent_repr, doc_lengths, class_repr, attention_weights)

    return doc_predictions, final_doc_emb, updated_sent_repr, attention_weights, updated_class_weights

## Evaluate ##

In [32]:
updated_sent_weights[0]

array([ True,  True,  True,  True,  True,  True,  True, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False,

In [104]:
topsentids = np.argmax(updated_sent_weights, axis=1)
for idx in np.arange(10):
    print(sent_dict[str(idx)][topsentids[idx]])

secretary_general_kofi_annan said today that iraq has cooperated up to now with united_nations weapons inspectors , so it would be ''premature'' to take military action before inspectors report back on their investigations to the security_council on jan . 27 .
of all those expelled , mr . hogan is putting up the biggest fight against the restrictions , asking a federal_judge to overrule the board and mounting a public_relations battle .
the school plans to offer scholarships for families who cannot afford tuition , ms . friedman said .
''we 've had to deal with very serious issues , '' mr . mccartney said .
it also came as china has worked hard to strengthen ties with the bush_administration .
''the american administration is trying to create some pretexts to attack iraq , to exercise their aggression against iraq , '' he said .
an adviser to mr . edwards said that campaign had deliberately chosen the day after new year 's to go on ''today , '' because the senator 's aides figured , co

In [121]:
# Find really confident sentences
def getConfidentSent(weights, sent_emb, class_repr):
    topsentids = np.argmax(weights, axis=1)
    topembs = np.zeros((sent_emb.shape[0], sent_emb.shape[2]))
    for doc_id in np.arange(sent_emb.shape[0]):
        topembs[doc_id, :] = sent_emb[doc_id, topsentids[doc_id], :]

    topcossim = cosine_similarity(topembs, class_repr) # N x C
    toppreds = np.argmax(topcossim, axis=1)
    class_sents = [[] for c in range(class_repr.shape[0])]

    for id, p in enumerate(toppreds):
        # doc id, sent id, sent weight, sent emb
        class_sents[p].append((id, topsentids[id], weights[id, topsentids[id]], topembs[id]))
    for c in np.arange(len(class_sents)):
        total = int(len(class_sents[c])*0.05)
        class_sents[c] = sorted(class_sents[c], key=lambda x: x[2], reverse=True)[:total]

    return class_sents

In [133]:
class_sents = getConfidentSent(init_class_weights, padded_sent_repr, class_repr)

In [134]:
for i in range(len(class_sents)):
    print(len(class_sents[i]))
    top = class_sents[i][1]
    print(class_sents[i][0][2], class_sents[i][1][2])
    print(sent_dict[str(top[0])][top[1]])

518
1.0000000000000002 1.0000000000000002
the wpp group 's acquisition of the struggling british advertising company cordiant communications group may not be a done deal after all .
625
1.0000000000000007 1.0000000000000004
he told them , aides said , that he wanted to hold a full scale news conference a few hours later .
111
1.0000000000000002 1.0000000000000002
kiplagat , born in kenya , finished third in new york last year .
69
1.0000000000000002 1.0000000000000002
if a scheduled appointment was missed , the health visitor would make a home visit to make sure all was well and to ensure that appointments for checkups and immunizations were kept .
75
1.0000000000000002 1.0000000000000002
more than half of private_school students attend catholic schools , while about a third are enrolled in other religious schools .
85
1.0000000000000002 1.0000000000000002
ft . co op in a loft building elevator , dining_room , den , office , high ceilings , renovated_kitchen , 3 exposures maintenance 1

In [65]:
args.epochs = 5
args.lr = 1e-3
args.temp = 0.2
doc_to_class, final_doc_emb, updated_sent_repr, updated_sent_weights, updated_class_weights = contextEmb(args, padded_sent_repr, sentence_mask, class_repr, init_class_weights, doc_lengths, new_data_path, device)



Starting to train!


100%|██████████| 594/594 [00:21<00:00, 27.93it/s]
 20%|██        | 1/5 [00:21<01:25, 21.28s/it]

Average training loss: 0.7060002684593201


100%|██████████| 594/594 [00:18<00:00, 32.37it/s]
 40%|████      | 2/5 [00:39<00:58, 19.56s/it]

Average training loss: 0.5426245331764221


100%|██████████| 594/594 [00:17<00:00, 33.21it/s]
 60%|██████    | 3/5 [00:57<00:37, 18.81s/it]

Average training loss: 0.40575507283210754


100%|██████████| 594/594 [00:18<00:00, 31.95it/s]
 80%|████████  | 4/5 [01:16<00:18, 18.73s/it]

Average training loss: 0.3907454311847687


100%|██████████| 594/594 [00:18<00:00, 31.72it/s]
100%|██████████| 5/5 [01:34<00:00, 18.98s/it]


Average training loss: 0.38512274622917175
Starting to evaluate!


100%|██████████| 594/594 [00:19<00:00, 30.08it/s]
100%|██████████| 38000/38000 [00:08<00:00, 4368.39it/s]


In [67]:
doc_pred = np.rint(doc_to_class)
print("Evaluate Predictions (Document-Based): ")
evaluate_predictions(gold_labels, doc_pred)

Evaluate Predictions (Document-Based): 
F1 micro: 0.7907894736842105
F1 macro: 0.7853926828426849


{'confusion': [[12012, 6988], [962, 18038]],
 'f1_micro': 0.7907894736842105,
 'f1_macro': 0.7853926828426849}

In [68]:
(doc_pred == gold_labels)[:10], (gold_labels == np.argmax(init_class_weights, axis=1))[:10]

(array([False,  True, False, False,  True,  True,  True,  True,  True,
         True]),
 array([ True,  True, False, False,  True,  True,  True,  True,  True,
         True]))

In [75]:
init_sent_weights.shape

(38000, 150)

In [82]:
doc_id = 0
print("Class Weights: ", init_class_weights[doc_id])
tok_sents = sent_dict[str(doc_id)]
chosen_class = np.argmax(cosine_similarity(updated_sent_repr[doc_id, :len(tok_sents)], class_repr), axis=1)
sent_classes = cosine_similarity(updated_sent_repr[doc_id, :len(tok_sents)], class_repr)

for s,c,cw,w in zip(tok_sents, class_names[chosen_class.astype(int)], sent_classes, updated_sent_weights[doc_id, :len(tok_sents)]):
    print(f'{c} ({cw}) {w}: {s}')

Class Weights:  [0.47846081 0.52153919]
bad ([-0.04573185 -0.1081845 ]) 0.15926334261894226: contrary to other reviews, i have zero complaints about the service or the prices.
good ([-0.07114672 -0.03382205]) 0.16140052676200867: i have been getting tire service here for the past 5 years now, and compared to my experience with places like pep boys, these guys are experienced and know what they're doing.
bad ([-0.01363032 -0.11388189]) 0.1908542513847351: also, this is one place that i do not feel like i am being taken advantage of, just because of my gender.
bad ([ 0.0168174  -0.11937073]) 0.16015806794166565: other auto mechanics have been notorious for capitalizing on my ignorance of cars, and have sucked my bank account dry.
bad ([-0.00713463 -0.07132438]) 0.1108829602599144: but here, my service and road coverage has all been well explained - and let up to me to decide.
bad ([0.02093715 0.00960349]) 0.09704500436782837: and they just renovated the waiting room.
good ([-0.06087189 -

In [147]:
# Second Iteration:
doc_to_class, final_doc_emb, updated_sent_repr, updated_sent_weights, updated_class_weights = contextEmb(args, updated_sent_repr, sentence_mask, class_repr, updated_class_weights, doc_lengths, new_data_path, device)



Starting to train!


100%|██████████| 500/500 [00:17<00:00, 28.60it/s]
 20%|██        | 1/5 [00:17<01:09, 17.49s/it]

Average training loss: 2.2325527667999268


100%|██████████| 500/500 [00:17<00:00, 27.78it/s]
 40%|████      | 2/5 [00:35<00:53, 17.80s/it]

Average training loss: 0.7424776554107666


100%|██████████| 500/500 [00:18<00:00, 27.63it/s]
 60%|██████    | 3/5 [00:53<00:35, 17.95s/it]

Average training loss: 0.3872785270214081


100%|██████████| 500/500 [00:17<00:00, 28.08it/s]
 80%|████████  | 4/5 [01:11<00:17, 17.90s/it]

Average training loss: 0.3500744700431824


100%|██████████| 500/500 [00:18<00:00, 27.44it/s]
100%|██████████| 5/5 [01:29<00:00, 17.94s/it]


Average training loss: 0.3379160463809967
Starting to evaluate!


100%|██████████| 500/500 [00:19<00:00, 25.61it/s]
100%|██████████| 31997/31997 [00:15<00:00, 2073.08it/s]


Evaluate Predictions (Document-Based): 
F1 micro: 0.8020439416195269
F1 macro: 0.6456836188505892


### Analyze Results ###

In [130]:
with open(os.path.join("/home/pk36/XClass/data/intermediate_data", args.dataset_name, "document_repr_lm-bbu-12-mixture-plm.pk"), "rb") as f:
    reprpickle = pk.load(f)
    class_words = reprpickle["class_words"]
with open(os.path.join("/home/pk36/XClass/data/intermediate_data", args.dataset_name, "dataset.pk"), "rb") as f:
    datapk = pk.load(f)
    sent_dict = datapk["sent_data"]
    class_names = np.array(datapk["class_names"])

In [135]:
class_words[6]

['arts',
 'artists',
 'art',
 'artistic',
 'artist',
 'artworks',
 'artwork',
 'painters',
 'paintings',
 'artistically',
 'painting',
 'masterworks',
 'impressionists',
 'masterpieces',
 'watercolors',
 'painter',
 'canvases',
 'compositions',
 'impressionist',
 'abstractions',
 'collages',
 'frescos',
 'painterly',
 'watercolor',
 'portraiture',
 'printmaking',
 'sculpture',
 'sculptures',
 'drawings',
 'nudes',
 'murals',
 'sculptors',
 'seascapes',
 'woodcuts',
 'expressionist',
 'expressionistic',
 'expressionism',
 'expressionists',
 'mosaics',
 'pictorial',
 'artsy',
 'prints',
 'impressionism',
 'visual',
 'collage',
 'sculptural',
 'impressionistic',
 'abstract',
 'creative',
 'sculptor',
 'painted',
 'muses',
 'surrealism',
 'conceptualism',
 'surrealist',
 'cubism',
 'etching',
 'artisan',
 'portraits',
 'montages',
 'cubist',
 'lifes',
 'renderings',
 'artistry',
 'landscapes',
 'gallery',
 'cartoonists',
 'experimentalist',
 'allegorical',
 'triptych',
 'dada',
 'decorator

In [38]:
correct_class = np.argmax(updated_class_weights, axis=1) == gold_labels
correct_class[20:40]

array([ True, False,  True,  True,  True,  True, False, False,  True,
        True, False,  True,  True,  True,  True,  True,  True,  True,
        True,  True])

In [150]:
init_class_weights[26], gold_labels[26], updated_class_weights[26]

(array([0.10631772, 0.52813451, 0.        , 0.29934825, 0.        ,
        0.06619952, 0.        , 0.        , 0.        ]),
 3,
 array([0.04851887, 0.66976335, 0.        , 0.28171776, 0.        ,
        0.        , 0.        , 0.        , 0.        ]))

In [None]:
args.epochs = 5
doc_to_class, final_doc_emb, updated_sent_repr, updated_class_weights = contextEmb(args, padded_sent_repr, sentence_mask, class_repr, init_class_weights, doc_lengths, new_data_path, device)

In [None]:
doc_to_class.shape, final_doc_emb.shape, updated_sent_repr.shape, updated_class_weights.shape

In [None]:
init_class_weights[0], updated_class_weights[0], gold_labels[0]

In [None]:
init_class_weights[1], updated_class_weights[1], gold_labels[1]

In [None]:
init_class_weights[2], updated_class_weights[2], gold_labels[2]

In [None]:
init_class_weights[3], updated_class_weights[3], gold_labels[3]

In [None]:
args.epochs = 5
print(padded_sent_repr.shape, sentence_mask.shape, class_repr.shape, init_doc_preds.shape)
# padded_sent_repr = padded_sent_repr[:2]
# sentence_mask = sentence_mask[:2]
# init_doc_preds = init_doc_preds[:2]
doc_to_class, doc_emb, weights = contextEmb(args, padded_sent_repr, sentence_mask, class_repr, init_doc_preds, new_data_path, device)

In [None]:
with open(data_path, "r") as f:
    text = f.read().splitlines()

In [None]:
sentence_mask[1]

In [None]:
pkdoc_repr = reprpickle["document_representations"]
pkdoc_repr.shape

In [None]:
pkdoc_cossim = cosine_similarity(pkdoc_repr, class_repr)
pkdoc_classes = np.argmax(pkdoc_cossim, axis=1)
evaluate_predictions(gold_labels, pkdoc_classes)
evaluate_predictions(gold_labels, init_doc_preds)

In [None]:
updated_class_weights.shape

In [None]:
weighted_votes = np.zeros(9)
for i in init_sent_preds[781:829]:
    weighted_votes[i] += weights[0][19][i]
weighted_votes

In [None]:
init_doc_preds[19], doc_to_class[19], gold_labels[19]

In [None]:
doc_to_sent[19][0], doc_to_sent[19][-1]

In [None]:
np.unique(init_sent_preds[781:829], return_counts=True)

In [None]:
(doc_to_class == gold_labels)[:20]

In [None]:
args.epochs = 5
print(padded_sent_repr.shape, sentence_mask.shape, class_repr.shape, init_doc_preds.shape)
doc_to_class, doc_emb = contextEmb(args, padded_sent_repr, sentence_mask, class_repr, init_doc_preds, new_data_path, device)

## Run PCA on Document Embeddings & Fit GMM ##

In [83]:
_pca = PCA(n_components=args.pca, random_state=args.random_state)
pca_doc_repr = _pca.fit_transform(final_doc_emb)
pca_class_repr = _pca.transform(class_repr)
print(f"Explained document variance: {sum(_pca.explained_variance_ratio_)}")

Explained document variance: 0.9488055329297386


In [87]:
cosine_similarities = cosine_similarity(pca_doc_repr, pca_class_repr)
doc_class_assignment = np.argmax(cosine_similarities, axis=1)
doc_class_probs = cosine_similarities[np.arange(pca_doc_repr.shape[0]), doc_class_assignment]

print("Evaluate Document Cosine Similarity Predictions: ")
evaluate_predictions(gold_labels, doc_class_assignment)

Evaluate Document Cosine Similarity Predictions: 
F1 micro: 0.8396315789473684
F1 macro: 0.8396297593448594


{'confusion': [[15889, 3111], [2983, 16017]],
 'f1_micro': 0.8396315789473684,
 'f1_macro': 0.8396297593448594}

In [42]:
cosine_similarities[np.arange(pca_doc_repr.shape[0]), doc_class_assignment][:20]

array([0.71955292, 0.40952983, 0.71538326, 0.72139932, 0.16384949,
       0.66803061, 0.58444152, 0.51261272, 0.75677112, 0.18657242,
       0.75161684, 0.73295012, 0.62532666, 0.76548129, 0.60488691,
       0.34830207, 0.67427778, 0.71308677, 0.67438752, 0.73837452])

## Save Pseudo Training Dataset ##

In [85]:
def write_to_dir(text, labels, prob, data_path, new_data_path):
    assert len(text) == len(labels)
    print("Saving files in:", new_data_path)
    
    with open(os.path.join(new_data_path, "dataset.txt"), "w") as f:
        for i, line in enumerate(text):
            f.write(line)
            f.write("\n")

    with open(os.path.join(new_data_path, "labels.txt"), "w") as f:
        for i, line in enumerate(labels):
            f.write(str(line))
            f.write("\n")

    with open(os.path.join(new_data_path, "probs.txt"), "w") as f:
        for i, line in enumerate(prob):
            f.write(str(line))
            #f.write(",".join(map(str, line)))
            f.write("\n")

    copyfile(os.path.join(data_path, "classes.txt"),
             os.path.join(new_data_path, "classes.txt"))

In [86]:
def generateDataset(documents_to_class, prob, num_classes, cleaned_text, gold_labels, data_path, new_data_path):
    pseudo_document_class_with_confidence = [[] for _ in range(num_classes)]
    for i in range(documents_to_class.shape[0]):
        pseudo_document_class_with_confidence[documents_to_class[i]].append((prob[i], i))

    selected = []
    confidence_threshold = 0.5
    for i in range(num_classes):
        pseudo_document_class_with_confidence[i] = sorted(pseudo_document_class_with_confidence[i], key=lambda x: x[0], reverse=True)
        num_docs_to_take = int(len(pseudo_document_class_with_confidence[i]) * confidence_threshold)
        confident_documents = pseudo_document_class_with_confidence[i][:num_docs_to_take]
        confident_documents = [x[1] for x in confident_documents]
        selected.extend(confident_documents)
    
    selected = sorted(selected)
    text = [cleaned_text[i] for i in selected]
    classes = [documents_to_class[i] for i in selected]
    ###
    gold_classes = [gold_labels[i] for i in selected]
    evaluate_predictions(gold_classes, classes)
    ###
    # write_to_dir(text, classes, prob, data_path, new_data_path)
    return pseudo_document_class_with_confidence

In [None]:
generateDataset(doc_class_assignment, doc_class_probs, num_classes, cleaned_text, gold_labels, data_path, new_data_path)

In [23]:
def doc_gmm(args, document_representations, doc_class_representations, gold_labels):
    num_classes = len(doc_class_representations)
    cosine_similarities = cosine_similarity(document_representations, doc_class_representations)
    doc_class_assignment = np.argmax(cosine_similarities, axis=1)

    print("Evaluate Document Cosine Similarity Predictions: ")
    evaluate_predictions(gold_labels, doc_class_assignment)

    # initialize gmm based on these selected documents
    document_class_assignment_matrix = np.zeros((document_representations.shape[0], num_classes))
    for i in np.arange(len(document_representations)): # iterate through docs and sents
        document_class_assignment_matrix[i][doc_class_assignment[i]] = 1.0

    gmm = GaussianMixture(n_components=num_classes, covariance_type='tied',
                          random_state=args.random_state,
                          n_init=999, warm_start=True)

    gmm._initialize(document_representations, document_class_assignment_matrix)
    gmm.lower_bound_ = -np.infty

    gmm.converged_ = True #HACK FOR NOT RANDOMLY INITIALIZING PARAMS DURING FIT
    gmm.fit(document_representations)

    documents_to_class = gmm.predict(document_representations)
    confidence = gmm.predict_proba(document_representations)

    print("Evaluate Document GMM Predictions: ")
    evaluate_predictions(gold_labels, documents_to_class)

    return documents_to_class, confidence

In [33]:
# 128
doc_preds, doc_prob = doc_gmm(args, pca_doc_repr, pca_class_repr, gold_labels)

Evaluate Document Cosine Similarity Predictions: 
F1 micro: 0.782384869341391
F1 macro: 0.7704668150661399
Evaluate Document GMM Predictions: 
F1 micro: 0.7800346930781712
F1 macro: 0.7686244714799803
