In [1]:
from representations import sentenceEmb, contextualizedEmb
import numpy as np
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Args(object):
    pass

args = Args()
args.dataset_name = "agnews"
args.gpu = 2
args.pca = 64
args.emb_dim = 768
args.num_heads = 2
args.batch_size = 64
args.epochs = 5
args.accum_steps = 1
args.max_sent = 150

os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
import torch

In [3]:
data_path = os.path.join("/shared/data2/pk36/multidim/multigran", args.dataset_name, "dataset.txt")
new_data_path = os.path.join("/home/pk36/MEGClass/intermediate_data", args.dataset_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# this is where we save all representations
if not os.path.exists(new_data_path):
    os.makedirs(new_data_path)

### Generate Sentence Embeddings ###

In [5]:
padded_sent_representations, sentence_mask, class_repr = sentenceEmb(args, data_path, new_data_path, device)

 51%|█████▏    | 61779/120000 [09:32<09:07, 106.30it/s]

Right mark without Left:  $2,137,000 and Counting. David Gardner launches our Foolanthropy campaign for 2004." /&gt;                               Fool.com:  $2,137,000 and Counting Foolanthropy          &lt;META Name="ArticleDate" Content="2004/10/15



 53%|█████▎    | 63818/120000 [09:52<08:30, 109.99it/s]

Right mark without Left: Biotech's 5 Baggers. How can yesterday's biotech winners lead you to today's top performers?" /&gt;                                 Fool.com: Biotech's 5 Baggers Commentary October 18, 2004   &lt;script language="JavaScript



100%|██████████| 120000/120000 [18:29<00:00, 108.16it/s]


Trimmed Documents: 0
Retrieved Class Representations!


In [13]:
padded_sent_representations.dtype, padded_sent_representations.shape

(dtype('float64'), (120000, 150, 768))

In [12]:
sentence_mask.dtype, sentence_mask.shape, class_repr.dtype, class_repr.shape

(dtype('bool'), (120000, 150), dtype('float64'), (4, 768))

### Cosine Similarities between Sentences and Docs ###

In [75]:
def cosine_similarity_embeddings(emb_a, emb_b):
    return np.dot(emb_a, np.transpose(emb_b)) / np.outer(np.linalg.norm(emb_a, axis=1), np.linalg.norm(emb_b, axis=1))

### Train ###

In [41]:
from tqdm import tqdm
import numpy as np
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
from torch import nn
import torch.nn.functional as F
import sys


### MEGClass Model ###

In [71]:
class MEGClassModel(nn.Module):

    def __init__(self, emb_dim, num_heads, dropout=0.1):
        #super().__init__(config)
        super(MEGClassModel, self).__init__()

        self.attention  = torch.nn.MultiheadAttention(emb_dim, num_heads, batch_first=True)
        # Two-layer MLP
        self.ffn1 = nn.Sequential(
            nn.Linear(emb_dim, 2*emb_dim),
            nn.Dropout(dropout),
            nn.ReLU(inplace=True),
            nn.Linear(2*emb_dim, emb_dim)
        )
        self.ffn2 = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.ReLU(inplace=True),
            nn.Linear(emb_dim, emb_dim)
        )

        self.norm1 = nn.LayerNorm(emb_dim)

        self.sent_attention = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.Tanh(),
            nn.Linear(emb_dim, emb_dim, bias=False)
            )
    
    def forward(self, input_emb, mask=None):
        # input_emb: batch size x sequence length x emb_dim
        X, _ = self.attention(input_emb, input_emb, input_emb, key_padding_mask=mask)
        X = X + input_emb
        X = self.ffn2(X)
        contextualized_sent = self.norm1(X) #[~mask] N x S x E

        # convert mask from N x S to N x S x E
        full_mask = (~mask).unsqueeze(-1).expand(X.size())
        exp_sent = torch.exp(self.sent_attention(contextualized_sent)) # N x S x E
        denom = torch.unsqueeze(torch.sum(exp_sent * (full_mask).int().float(), dim=1), dim=1) # N x 1 x E
        contextualized_doc = torch.sum((torch.div(exp_sent, denom) * contextualized_sent) * (full_mask), dim=1) # N x 1 x E

        return contextualized_sent, contextualized_doc

### Loss Function ###

In [72]:
def findMaxClass(class_repr, doc_repr, labels=None, confident=False):
    # labels (N x C): each row has either one class with a value = 1 OR all rows are zero meaning not confident enough

    class_repr = F.normalize(class_repr, dim=1) # C x emb_dim
    doc_repr = F.normalize(doc_repr, dim=1) # N x emb_dim

    # cosine similarity between doc_repr and class_repr
    cos_sim = torch.mm(doc_repr, class_repr.transpose(0,1)) # N x C

    if labels is None:
        # identify closest class i to doc
        i_sim = torch.max(cos_sim, dim=1)[0] # N x 1
    elif (not confident) and (labels is not None):
        i_sim = cos_sim[labels]
    else:
        # get the confident class cos-sim OR get the max cos-sim (1 if no confident class, 0 if yes)
        i_sim = torch.max(cos_sim * labels, dim=1)[0] + (1 - torch.sum(labels, dim=1)) * torch.max(cos_sim, dim=1)[0]
    
    return i_sim
    

def contrastive_loss(class_repr, doc_repr, i_sim, temp=0.2):
    class_repr = F.normalize(class_repr, dim=1) # C x emb_dim
    doc_repr = F.normalize(doc_repr, dim=1) # N x emb_dim

    # cosine similarity between doc_repr and class_repr
    cos_sim = torch.mm(doc_repr, class_repr.transpose(0,1)) # N x C

    # compute loss
    loss = -torch.log((torch.exp(i_sim)/temp)/torch.sum(torch.exp(cos_sim/temp), dim=1))

    return torch.mean(loss)

def tensor_to_numpy(tensor):
    return tensor.clone().detach().cpu().numpy()

In [76]:
def contextEmb(args, sent_representations, mask, class_repr, new_data_path, device, gold_labels):
    sent_representations = torch.from_numpy(sent_representations)
    mask = torch.from_numpy(mask)
    dataset = TensorDataset(sent_representations, mask)
    sampler = SequentialSampler(dataset)
    dataset_loader = DataLoader(dataset, sampler=sampler, batch_size=args.batch_size, shuffle=False)
    # sent_representations: N docs x L sentences x 1024 emb (L with padding is always max_sents=50)
    model = MEGClassModel(args.emb_dim, args.num_heads).to(device)

    total_steps = len(dataset_loader) * args.epochs / args.accum_steps
    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5, eps=1e-8)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0.1*total_steps, num_training_steps=total_steps)

    print("Starting to train!")

    for i in tqdm(range(args.epochs)):
        model.train()
        total_train_loss = 0
        
        model.zero_grad()
        for j, batch in enumerate(tqdm(dataset_loader)):
            input_emb = batch[0].to(device).float()
            input_mask = batch[1].to(device)
            
            c_sent, c_doc = model(input_emb, mask=input_mask)

            i_sim = findMaxClass(torch.from_numpy(class_repr).float().to(device), c_doc)

            loss = contrastive_loss(torch.from_numpy(class_repr).float().to(device), c_doc, i_sim) / args.accum_steps

            total_train_loss += loss
            loss.backward()
            optimizer.step()
            model.zero_grad()

        scheduler.step()

        avg_train_loss = torch.tensor([total_train_loss / len(dataset_loader) * args.accum_steps])
        print(f"Average training loss: {avg_train_loss.mean()}")

    model.eval()
    context_sent = None
    context_doc = None

    torch.save(model.state_dict(), os.path.join(new_data_path, f"{args.dataset_name}_model_weights.pth"))

    print("Starting to evaluate!")

    #sentence_predictions = []
    #doc_predictions = []

    with torch.no_grad(), open(os.path.join(new_data_path, "contextualized_sent.txt"), 'w') as fs, open(os.path.join(new_data_path, "contextualized_docs.txt"), 'w') as fd:
        for batch in tqdm(dataset_loader):
            input_emb = batch[0].to(device).float()
            input_mask = batch[1].to(device)

            c_sent, c_doc = model(input_emb, mask=input_mask)
            c_sent = tensor_to_numpy(c_sent)
            c_doc = tensor_to_numpy(c_doc)

            # fs.write(str(c_sent))
            # fs.write("\n")

            # fd.write(str(c_doc))
            # fd.write("\n")

            #sent_class = cosine_similarity_embeddings(c_sent, class_repr)
            #doc_class = 

            # if context_sent is None:
            #     context_sent = c_sent
            #     context_doc = c_doc
            # else:
            #     context_sent = torch.cat((context_sent, c_sent), dim=0)
            #     context_doc = torch.cat((context_doc, c_doc), dim=0)

    # return tensor_to_numpy(context_sent), tensor_to_numpy(context_doc), class_repr


### Run Contextualized Embedding Training ###

In [77]:
with open(os.path.join("/shared/data2/pk36/multidim/multigran", args.dataset_name, "labels.txt"), "r") as l:
    gold_labels = l.read().splitlines()

csent, cdoc, class_repr = contextEmb(args, padded_sent_representations, sentence_mask, class_repr, new_data_path, device, gold_labels)



Starting to train!


100%|██████████| 1875/1875 [01:09<00:00, 26.89it/s]
 20%|██        | 1/5 [01:09<04:38, 69.73s/it]

Average training loss: -0.32817596197128296


100%|██████████| 1875/1875 [01:08<00:00, 27.25it/s]
 40%|████      | 2/5 [02:18<03:27, 69.25s/it]

Average training loss: -0.3425796329975128


100%|██████████| 1875/1875 [01:11<00:00, 26.06it/s]
 60%|██████    | 3/5 [03:30<02:21, 70.52s/it]

Average training loss: -0.38553935289382935


100%|██████████| 1875/1875 [01:10<00:00, 26.72it/s]
 80%|████████  | 4/5 [04:40<01:10, 70.42s/it]

Average training loss: -0.45694342255592346


100%|██████████| 1875/1875 [01:11<00:00, 26.37it/s]
100%|██████████| 5/5 [05:52<00:00, 70.43s/it]


Average training loss: -0.5571542978286743
Starting to evaluate!


100%|██████████| 1875/1875 [01:09<00:00, 26.84it/s]


TypeError: cannot unpack non-iterable NoneType object

### Evaluate ###

In [None]:
cosine_similarity_embedding(cdoc, class_re)