In [12]:
from representations import sentenceEmb, contextualizedEmb
import numpy as np
import os
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup
import torch
from torch import nn
import sys
import torch.nn.functional as F
import argparse
import re
import pickle as pk
from tqdm import tqdm
from nltk.tokenize import sent_tokenize, word_tokenize
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import confusion_matrix, f1_score
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler

In [2]:
class Args(object):
    pass

args = Args()
args.dataset_name = "agnews"
args.gpu = 2
args.pca = 64
args.emb_dim = 768
args.num_heads = 2
args.batch_size = 64
args.epochs = 5
args.accum_steps = 1
args.max_sent = 150

os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
import torch

In [3]:
data_path = os.path.join("/shared/data2/pk36/multidim/multigran", args.dataset_name, "dataset.txt")
new_data_path = os.path.join("/home/pk36/MEGClass/intermediate_data", args.dataset_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# this is where we save all representations
if not os.path.exists(new_data_path):
    os.makedirs(new_data_path)

### Generate Sentence Embeddings ###

In [5]:
padded_sent_representations, sentence_mask, class_repr = sentenceEmb(args, data_path, new_data_path, device)

 51%|█████▏    | 61780/120000 [15:47<14:55, 65.04it/s] 

Right mark without Left:  $2,137,000 and Counting. David Gardner launches our Foolanthropy campaign for 2004." /&gt;                               Fool.com:  $2,137,000 and Counting Foolanthropy          &lt;META Name="ArticleDate" Content="2004/10/15



 53%|█████▎    | 63805/120000 [16:18<14:34, 64.27it/s]

Right mark without Left: Biotech's 5 Baggers. How can yesterday's biotech winners lead you to today's top performers?" /&gt;                                 Fool.com: Biotech's 5 Baggers Commentary October 18, 2004   &lt;script language="JavaScript



100%|██████████| 120000/120000 [30:48<00:00, 64.90it/s]


Trimmed Documents: 0
Retrieved Class Representations!


In [6]:
padded_sent_representations.dtype, padded_sent_representations.shape

(dtype('float64'), (120000, 150, 768))

In [7]:
sentence_mask.dtype, sentence_mask.shape, class_repr.dtype, class_repr.shape

(dtype('bool'), (120000, 150), dtype('float64'), (4, 768))

### Helper Functions ###

In [14]:
def tensor_to_numpy(tensor):
    return tensor.clone().detach().cpu().numpy()

#Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

def getClassRepr(args, load=True):
    with open(os.path.join("/home/pk36/XClass/data/intermediate_data", args.dataset_name, f"document_repr_lm-bbu-12-mixture-plm.pk"), "rb") as f:
        dictionary = pk.load(f)
        class_representations = dictionary["class_representations"]
    return class_representations

# Given tensor representations for classes and documents, identify the top class
def findMaxClass(class_repr, doc_repr, labels=None, confident=False):
    # labels (N x C): each row has either one class with a value = 1 OR all rows are zero meaning not confident enough

    class_repr = F.normalize(class_repr, dim=1) # C x emb_dim
    doc_repr = F.normalize(doc_repr, dim=1) # N x emb_dim

    # cosine similarity between doc_repr and class_repr
    cos_sim = torch.mm(doc_repr, class_repr.transpose(0,1)) # N x C

    if labels is None:
        # identify closest class i to doc
        i_sim = torch.max(cos_sim, dim=1)[0] # N x 1
    elif (not confident) and (labels is not None):
        i_sim = cos_sim[labels]
    else:
        # get the confident class cos-sim OR get the max cos-sim (1 if no confident class, 0 if yes)
        i_sim = torch.max(cos_sim * labels, dim=1)[0] + (1 - torch.sum(labels, dim=1)) * torch.max(cos_sim, dim=1)[0]
    
    return i_sim

def contrastive_loss(class_repr, doc_repr, i_sim, temp=0.2):
    class_repr = F.normalize(class_repr, dim=1) # C x emb_dim
    doc_repr = F.normalize(doc_repr, dim=1) # N x emb_dim

    # cosine similarity between doc_repr and class_repr
    cos_sim = torch.mm(doc_repr, class_repr.transpose(0,1)) # N x C

    # compute loss
    loss = -torch.log((torch.exp(i_sim)/temp)/torch.sum(torch.exp(cos_sim/temp), dim=1))

    return torch.mean(loss)

def cosine_similarity_embeddings(emb_a, emb_b):
    return np.dot(emb_a, np.transpose(emb_b)) / np.outer(np.linalg.norm(emb_a, axis=1), np.linalg.norm(emb_b, axis=1))

def sentenceToClass(sent_repr, class_repr, weights):
    # sent_repr: N x S x E
    # class_repr: C x E
    # weights: N x S # equals 0 for masked sentences

    #cos-sim between (N x S) x E and (C x E) = N x S x C
    m, n = sent_repr.shape[:2]
    sentcos = cosine_similarity(sent_repr.reshape(m*n,-1), class_repr).reshape(m,n,-1)
    sent_to_class = np.argmax(sentcos, axis=2) # N x S
    sent_to_doc_class = np.sum(np.multiply(sent_to_class, weights), axis=1) # N x 1
    return sent_to_doc_class

def docToClass(doc_repr, class_repr):
    # doc_repr: N x E
    # class_repr: C x E

    #cos-sim between N x E and C x E = N x C
    doccos = cosine_similarity(doc_repr, class_repr)
    doc_to_class = np.argmax(doccos, axis=1) # N x 1
    return doc_to_class

def evaluate_predictions(true_class, predicted_class, output_to_console=True, return_tuple=False):
    confusion = confusion_matrix(true_class, predicted_class)
    if output_to_console:
        print("-" * 80 + "Evaluating" + "-" * 80)
        print(confusion)
    f1_macro = f1_score(true_class, predicted_class, average='macro')
    f1_micro = f1_score(true_class, predicted_class, average='micro')
    if output_to_console:
        print("F1 macro: " + str(f1_macro))
        print("F1 micro: " + str(f1_micro))
    if return_tuple:
        return confusion, f1_macro, f1_micro
    else:
        return {
            "confusion": confusion.tolist(),
            "f1_macro": f1_macro,
            "f1_micro": f1_micro
        }

### MEGClass Model ###

In [35]:
class MEGClassModel(nn.Module):

    def __init__(self, emb_dim, num_heads, dropout=0.1):
        #super().__init__(config)
        super(MEGClassModel, self).__init__()

        self.attention  = torch.nn.MultiheadAttention(emb_dim, num_heads, batch_first=True)
        # Two-layer MLP
        self.ffn1 = nn.Sequential(
            nn.Linear(emb_dim, 2*emb_dim),
            nn.Dropout(dropout),
            nn.ReLU(inplace=True),
            nn.Linear(2*emb_dim, emb_dim)
        )
        self.ffn2 = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.ReLU(inplace=True),
            nn.Linear(emb_dim, emb_dim)
        )

        self.norm1 = nn.LayerNorm(emb_dim)

        self.sent_attention = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.Tanh(),
            nn.Linear(emb_dim, emb_dim, bias=False)
            )

        self.scalar_sent_attention = nn.Sequential(
            nn.Linear(emb_dim, 1),
            nn.Tanh(),
            nn.Linear(1, 1, bias=False)
        )
    
    def forward(self, input_emb, mask=None):
        # input_emb: batch size x sequence length x emb_dim
        X, _ = self.attention(input_emb, input_emb, input_emb, key_padding_mask=mask)
        X = X + input_emb
        X = self.ffn2(X)
        contextualized_sent = self.norm1(X) #[~mask] N x S x E

        # scalar attention weight for each sentence
        exp_sent = torch.exp(self.scalar_sent_attention(contextualized_sent)) # N x S x 1
        exp_sent = torch.squeeze(exp_sent, dim=2) * (~mask).int().float() # N x S x 1 but all masked items are 0
        denom = torch.unsqueeze(torch.sum(exp_sent, dim=1), dim=1) # N x 1
        alpha = torch.unsqueeze(torch.div(exp_sent, denom), dim=2) # N x S x 1
        contextualized_doc = torch.sum(alpha.expand_as(contextualized_sent) * contextualized_sent, dim=1) # N x 1 x E

        # convert mask from N x S to N x S x E
        # full_mask = (~mask).unsqueeze(-1).expand(X.size())
        # exp_sent = torch.exp(self.sent_attention(contextualized_sent)) # N x S x E
        # denom = torch.unsqueeze(torch.sum(exp_sent * (full_mask).int().float(), dim=1), dim=1) # N x 1 x E
        # contextualized_doc = torch.sum((torch.div(exp_sent, denom) * contextualized_sent) * (full_mask), dim=1) # N x 1 x E

        return contextualized_sent, contextualized_doc, alpha

## Train ##

In [39]:
def contextEmb(args, sent_representations, mask, class_repr, new_data_path, device):
    sent_representations = torch.from_numpy(sent_representations)
    mask = torch.from_numpy(mask)
    dataset = TensorDataset(sent_representations, mask)
    sampler = SequentialSampler(dataset)
    dataset_loader = DataLoader(dataset, sampler=sampler, batch_size=args.batch_size, shuffle=False)
    # sent_representations: N docs x L sentences x 1024 emb (L with padding is always max_sents=50)
    model = MEGClassModel(args.emb_dim, args.num_heads).to(device)

    total_steps = len(dataset_loader) * args.epochs / args.accum_steps
    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5, eps=1e-8)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0.1*total_steps, num_training_steps=total_steps)

    print("Starting to train!")

    for i in tqdm(range(args.epochs)):
        model.train()
        total_train_loss = 0
        
        model.zero_grad()
        for j, batch in enumerate(tqdm(dataset_loader)):
            input_emb = batch[0].to(device).float()
            input_mask = batch[1].to(device)
            
            c_sent, c_doc, _ = model(input_emb, mask=input_mask)

            i_sim = findMaxClass(torch.from_numpy(class_repr).float().to(device), c_doc)

            loss = contrastive_loss(torch.from_numpy(class_repr).float().to(device), c_doc, i_sim) / args.accum_steps

            total_train_loss += loss
            loss.backward()
            optimizer.step()
            model.zero_grad()

        scheduler.step()

        avg_train_loss = torch.tensor([total_train_loss / len(dataset_loader) * args.accum_steps])
        print(f"Average training loss: {avg_train_loss.mean()}")

    model.eval()

    torch.save(model.state_dict(), os.path.join(new_data_path, f"{args.dataset_name}_model_weights.pth"))

    print("Starting to evaluate!")

    sentence_predictions = None
    doc_predictions = None

    with torch.no_grad(), open(os.path.join(new_data_path, "contextualized_sent.txt"), 'w') as fs, open(os.path.join(new_data_path, "contextualized_docs.txt"), 'w') as fd:
        for batch in tqdm(dataset_loader):
            input_emb = batch[0].to(device).float()
            input_mask = batch[1].to(device)

            c_sent, c_doc, alpha = model(input_emb, mask=input_mask)
            c_sent = tensor_to_numpy(c_sent)
            c_doc = tensor_to_numpy(c_doc)
            alpha = tensor_to_numpy(torch.squeeze(alpha, dim=2))

            # for row in c_doc:
            #     fd.write(' '.join(map(str, row)) + '\n')

            # fs.write(str(c_sent))
            # fs.write("\n")

            # fd.write(str(c_doc))
            # fd.write("\n")



            sent_class = sentenceToClass(c_sent, class_repr, alpha)
            doc_class = docToClass(c_doc, class_repr)
            if sentence_predictions is None:
                sentence_predictions = sent_class
                doc_predictions = doc_class
            else:
                sentence_predictions = np.append(sentence_predictions, sent_class)
                doc_predictions = np.append(doc_predictions, doc_class)

    # return tensor_to_numpy(context_sent), tensor_to_numpy(context_doc), class_repr
    return sentence_predictions, doc_predictions

### Run Contextualized Embedding Training ###

In [40]:
sent_to_doc_class, doc_to_class = contextEmb(args, padded_sent_representations, sentence_mask, class_repr, new_data_path, device)



Starting to train!


100%|██████████| 1875/1875 [01:35<00:00, 19.73it/s]
 20%|██        | 1/5 [01:35<06:20, 95.07s/it]

Average training loss: -0.35764655470848083


100%|██████████| 1875/1875 [01:35<00:00, 19.64it/s]
 40%|████      | 2/5 [03:10<04:46, 95.38s/it]

Average training loss: -0.3715905249118805


100%|██████████| 1875/1875 [01:34<00:00, 19.77it/s]
 60%|██████    | 3/5 [04:45<03:10, 95.18s/it]

Average training loss: -0.4131544828414917


100%|██████████| 1875/1875 [01:35<00:00, 19.68it/s]
 80%|████████  | 4/5 [06:20<01:35, 95.26s/it]

Average training loss: -0.48218175768852234


100%|██████████| 1875/1875 [01:35<00:00, 19.53it/s]
100%|██████████| 5/5 [07:57<00:00, 95.42s/it]


Average training loss: -0.5789074897766113
Starting to evaluate!


100%|██████████| 1875/1875 [07:07<00:00,  4.38it/s]


In [41]:
sent_to_doc_class.shape, doc_to_class.shape

((120000,), (120000,))

In [44]:
sent_to_doc_class[2]

2.0

### Evaluate ###

In [46]:
with open(os.path.join("/shared/data2/pk36/multidim/multigran", args.dataset_name, "labels.txt"), "r") as l:
        gold_labels = l.read().splitlines()

sent_pred = np.rint(sent_to_doc_class)
doc_pred = np.rint(doc_to_class)

print("Evaluate Predictions (Sentence-Based): ")
evaluate_predictions(gold_labels, sent_pred)
print("Evaluate Predictions (Document-Based): ")
evaluate_predictions(gold_labels, doc_pred)

Evaluate Predictions (Sentence-Based): 


ValueError: Mix of label input types (string and number)