In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# import requi9red module
import sys
  
# append the path of the
# parent directory
sys.path.append("..")

In [3]:
import json
import torch 
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.nn import CrossEntropyLoss, NLLLoss
import torch.nn.functional as F 
from torch import autograd

import numpy as np 

from transformers.models.luke.configuration_luke import LukeConfig
from transformers import AutoTokenizer, AutoModel, AutoConfig

from entitylinker.lukemodel import LukeModel 
from entitylinker.model_utils import *

from entitylinker.dataset import MedMentionsDataset, Collater
from entitylinker.model import EntityLinker 

In [4]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

if torch.cuda.is_available():
    
    print(torch.cuda.current_device())
    print(torch.cuda.device(0))
    print(torch.cuda.device_count())
    print(torch.cuda.get_device_name(0))


0
<torch.cuda.device object at 0x2b048a5a5ee0>
1
Quadro RTX 5000


In [5]:
input_data_file = "/home/vs428/project/MedMentions/full/data/corpus_pubtator.txt"
train_pmids_file = "/home/vs428/project/MedMentions/full/data/corpus_pubtator_pmids_trng.txt"
test_pmids_file = "/home/vs428/project/MedMentions/full/data/corpus_pubtator_pmids_test.txt"

entity_vocab_file = "/home/vs428/project/MedMentions/full/pretraining5/entity_vocab.jsonl"
entity_embedding_model = "/home/vs428/project/MedMentions/full/pretraining5/model_epoch20.bin"
entity_embedding_metadata = "/home/vs428/project/MedMentions/full/pretraining5/metadata.json"

MAX_LENGTH = 128
MODEL_NAME = 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract'

In [6]:
# Create Tokenizer and DataLoader
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

collater = Collater(tokenizer, max_length = MAX_LENGTH)
train_dataset = MedMentionsDataset(input_data_file, train_pmids_file, tokenizer, entity_vocab_file, 
                             max_length = MAX_LENGTH, stride=0, first_token_ent_only=True)
test_dataset = MedMentionsDataset(input_data_file, test_pmids_file, tokenizer, entity_vocab_file, 
                             max_length = MAX_LENGTH, stride=0, first_token_ent_only=True)

train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collater)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True, collate_fn=collater)
# tokens, bio_tags, entity_ids = next(iter(train_dataloader))


In [7]:

# define optimizer, model, loss etc. 


# with open(entity_embedding_metadata) as f:
#     luke_metadata = json.loads(f.read())

# luke_config = luke_metadata['model_config']
# bert_config = AutoConfig.from_pretrained(MODEL_NAME)

# config = LukeConfig(
#     entity_vocab_size=luke_config['entity_vocab_size'],
#     bert_model_name=luke_config['bert_model_name'],
#     entity_emb_size=luke_config['entity_emb_size'],
#     **bert_config.to_dict(),
# )

# luke_model = LukeModel(config)
# pretrained_entity_embeddings = luke_model.load_state_dict(torch.load(entity_embedding_model, map_location=torch.device('cpu')))
# print(pretrained_entity_embeddings)


In [8]:
# Retrieve pretrained entity embeddings

model_archive = ModelArchive.load(entity_embedding_model)

model_archive.tokenizer
model_archive.entity_vocab
model_archive.bert_model_name
model_archive.config
model_archive.max_mention_length
pretrained_entity_embeddings = model_archive.state_dict['entity_embeddings.entity_embeddings.weight']
pretrained_entity_embeddings = pretrained_entity_embeddings.to(device)

In [9]:
def sim_matrix(a, b, eps=1e-8):
    """
    works on a 3D and 2D array
    """
    a_n, b_n = torch.linalg.norm(a, dim=-1)[:, :, None], torch.linalg.norm(b, dim=-1)[:, None]
    a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
    b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
    sim_mt = torch.matmul(a_norm, b_norm.transpose(0, 1))        
    return sim_mt        


In [10]:
# define loss function
def mention_entity_loss(mention_pred, entity_pred, bio_tags, entity_ids, 
                        attention_mask, pretrained_entity_embedding, device, lm=0.1):
    # attention_mask torch.Size([27,256])
    # mention_pred torch.Size([27, 256, 3])
    # entity_pred torch.Size([27, 256, 256])
    # bio_tags torch.Size([27, 256])
    # entity_ids torch.Size([27, 256])
    
    # unpadded_men_pred torch.Size([25, 3, 256, 256])
    # unpadded_bio_tags torch.Size([25, 256, 256])    
    
    ### MENTION LOSS

    # first compute all loss
    mention_loss = F.nll_loss(mention_pred.permute(0,2,1).contiguous(), bio_tags, reduction="none")

    # get only the unpadded losses: batch size
    num_unpadded = torch.sum(attention_mask, dim=1)
    masked_mention_loss = torch.where(attention_mask == 1, mention_loss, torch.tensor(0.).float().to(device))
    
    # compute average of the masked loss
    avg_mention_loss = torch.sum(masked_mention_loss, dim=1) / num_unpadded
    avg_mention_loss = torch.sum(avg_mention_loss)/len(avg_mention_loss)

#     print("MENTION LOSS: ", avg_mention_loss)

    
    ### ENTITY LOSS
#     cos_sim = sim_matrix(entity_pred, pretrained_entity_embedding)
#     print("cos_sim", cos_sim.shape)
    # cos_sim torch.Size([27, 256, 34727])
    
    # get all the IDs in entity_ids that aren't -1
    nonzero_ent_ids = entity_ids[entity_ids!=-1]
    # get the true entity embeddings
    # shape: [nonzero_ent_ids, embed_dim]
    true_ent_embeddings = pretrained_entity_embedding[nonzero_ent_ids, :]
    
    # get non-neg1 idxs 
    ent_idxs = (entity_ids != -1).nonzero()
    real_entity_pred = entity_pred[ent_idxs[:, 0], ent_idxs[:, 1], :]    
#     print("true_ent_embeddings", true_ent_embeddings.shape)
#     print(real_entity_pred.shape)
    entity_loss = torch.sum(1 - F.cosine_similarity(true_ent_embeddings, real_entity_pred, dim=1))/true_ent_embeddings.shape[0]
#     print("ENTITY LOSS", entity_loss)
#     max_ents = torch.max(cos_sim, dim=2)
    
    # target_ent_embeddings = pretrained_entity_embedding[entity_ids, :]
    # target_ent_embeddings torch.Size([27, 256, 256])
    # max_ent_idxs = torch.argmax(cos_sim, dim=2)
#     max_ents = self.entity_embedding[max_ent_idxs, :]
    
    
#     entity_pred[ent_idxs]
#     entity_loss = torch.mean((output - target)**2)
    
    

    loss = (lm * avg_mention_loss) + ((1-lm) * entity_loss)
    return loss

In [11]:
# define model, etc. 
model = EntityLinker(model_name='microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract',
                     pretrained_entity_embeddings=pretrained_entity_embeddings )

model.to(device)


optimizer = optim.Adam(model.parameters(), lr=0.00002)

Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [12]:
# for name, param in model.named_parameters():
#     if param.requires_grad:
#     print(name, param.requires_grad)

In [13]:
def get_bi_spans(batch_tags):
    """Returns a list of N x 2 numpy arrays"""
    batch_tags = batch_tags.cpu().detach().numpy()
    bi_spans = []
    for batch_idx in range(batch_tags.shape[0]):
        b_tags = (batch_tags[batch_idx, :] == 0).astype(int)
        i_tags = (batch_tags[batch_idx, :] == 1).astype(int)
        
        # if there aren't any beginning tags, just return an empty array
        if not b_tags.any():
            bi_spans.append(np.array([]))
            continue
        
        # get all B tag locations
        b_idx, = b_tags.nonzero() 

        # get idxs for I tags
        d = np.diff(i_tags)
        i_idx, = d.nonzero()

        # We need to start things after the change in "condition". Therefore, 
        # we'll shift the index by 1 to the right.
        i_idx += 1

        # add end idxs for all B tags (just interweave +1 of b_idxs)
        b_idx = np.vstack((b_idx,b_idx+1)).reshape((-1,),order='F')    

        if i_tags[0]:
            # If the start of condition is True prepend a 0
            i_idx = np.r_[0, i_idx]

        if i_tags[-1]:
            # If the end of condition is True, append the length of the array
            i_idx = np.r_[i_idx, i_tags.size] # Edit

        # reshape to idxs
        b_idx.shape = (-1,2)
        i_idx.shape = (-1,2)    

        bi_idx = []
        # combine the b and i tags
        for start, stop in b_idx:
            # get idx where I tags start for each B tag, if exists
            i, = np.where(i_idx[:, 0] == stop) 
            if i.size > 0:
                bi_idx.append([start, int(i_idx[i, 1])])
            else:
                bi_idx.append([start, stop])

        bi_spans.append(np.array(bi_idx))

    assert batch_tags.shape[0] == len(bi_spans), f"{batch_tags.shape[0]},{len(bi_spans)},{b_idx}"
    return bi_spans


In [14]:
def compute_conf_matrices(mention_preds, entity_preds, bio_tags, entity_ids, pretrained_entity_embedding):
    '''Computes both macro and micro confusion matrices given torch tensor outputs from our model
    '''
    # mention_pred Size([27, 256, 3])
    # entity_pred torch.Size([27, 256, 256])

    nb_classes = 2
    micro_confusion_matrix = torch.zeros(nb_classes, nb_classes)
    # micro_confusion_matrix [[TP, FN], [FP, TN]]
        
    macro_confusion_matrix = []        
    
    # first we need to find the max of the BIO tagging we have
    pred_bio_tags = torch.argmax(mention_preds, dim=2)
    # pred_bio_tags torch.Size([28, 256])
    
    # create two lists of spans
    pred_span_list = get_bi_spans(pred_bio_tags)
    target_span_list = get_bi_spans(bio_tags)    

    for batch_idx, (pred_spans, target_spans) in enumerate(zip(pred_span_list, target_span_list)):
        tp_matches = 0
        
        batch_confusion_matrix = torch.zeros(nb_classes, nb_classes)

        # do the STRONG MATCHING here

        # first check if there are any predicted spans, if not, we just add a bunch of FNs
        if pred_spans.size <= 0 or pred_spans.shape[0] == 0:
            micro_confusion_matrix[0,1] += target_spans.shape[0]
            batch_confusion_matrix[0,1] += target_spans.shape[0]
            
        # conversely, if the target spans are empty, we add a bunch of FPs instead
        elif target_spans.size <= 0:
            micro_confusion_matrix[1,0] += pred_spans.shape[0]
            batch_confusion_matrix[1,0] += pred_spans.shape[0]            
            
        else:
            for span_idx in range(pred_spans.shape[0]):
                
                if torch.all(torch.tensor(pred_spans[span_idx, :] == target_spans), dim=1).any():
                    seq_idx = pred_spans[span_idx, 0]
                    # we use the start span idx to get the entity embedding and compare it to the pretrained embeds

                    cos_sim = sim_matrix(entity_preds[batch_idx, seq_idx, :].unsqueeze(0).unsqueeze(0), 
                                        pretrained_entity_embedding)
                    # this should return something of size [1, 1, sim_vals]
                    pred_ent_embed_id = torch.argmax(torch.squeeze(cos_sim))
                    target_ent_embed_id = entity_ids[batch_idx, seq_idx]
                    # just to test: 
                    # TODO: REMOVE
    #                 if batch_idx == 1 and seq_idx == 0:
    #                     target_ent_embed_id = 2
                    # TODO: ENDREMOVE
                    if pred_ent_embed_id == target_ent_embed_id:
                        # TP
                        micro_confusion_matrix[0,0] += 1
                        batch_confusion_matrix[0,0] += 1 
                        tp_matches += 1
                    else:
                        # FP
                        micro_confusion_matrix[1,0] += 1
                        batch_confusion_matrix[1,0] += 1
                else:
                    # FP 
                    micro_confusion_matrix[1,0] += 1
                    batch_confusion_matrix[1,0] += 1
        
        # all FNs 
        micro_confusion_matrix[0,1] += target_spans.shape[0] - tp_matches
        batch_confusion_matrix[0,1] += target_spans.shape[0] - tp_matches
        
        # we don't do TNs since we don't need them for P,R,F1
        
        # add on the batch confusion matrix to the batch
        macro_confusion_matrix.append(batch_confusion_matrix)

    return micro_confusion_matrix, macro_confusion_matrix


In [15]:
def compute_metrics(metrics, micro_conf, macro_confs):
    '''Given both the micro and macro confusion matrices, we compute a given set of metrics
    
    Matrix orientation is: 
    |  TP  |  FN |
    |-------------
    |  FP  |  TN |
    '''
    micro_metrics = {}
    macro_metrics = {}
    
    if "precision" in metrics:
        micro_metrics['precision'] = micro_conf[0,0] / (micro_conf[0,0] + micro_conf[1,0])
        macro_metrics['precision'] = torch.mean(torch.tensor([batch_conf[0,0] / (batch_conf[0,0] + batch_conf[1,0]) for batch_conf in macro_confs]))
    if "recall" in metrics:
        micro_metrics['recall'] = micro_conf[0,0] / (micro_conf[0,0] + micro_conf[0,1])
        macro_metrics['recall'] = torch.mean(torch.tensor([batch_conf[0,0] / (batch_conf[0,0] + batch_conf[0,1]) for batch_conf in macro_confs]))
    if "f1" in metrics:
        micro_metrics['f1'] = micro_conf[0,0] / (micro_conf[0,0] + ((1/2) * (micro_conf[0,1] + micro_conf[1,0])))
        macro_metrics['f1'] = torch.mean(torch.tensor([batch_conf[0,0] / (batch_conf[0,0] + ((1/2) * (batch_conf[0,1] + batch_conf[1,0]))) for batch_conf in macro_confs]))
        
    return micro_metrics, macro_metrics


In [16]:
# Train!!
def train(model, train_loader, optimizer, criterion, epoch):

    model.train()
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        # get the inputs; data is a list of [inputs, labels]
        tokenized_text, bio_tags, entity_ids = data
        bio_tags = bio_tags.to(device)
        entity_ids = entity_ids.to(device)
#         print("bio_tags", bio_tags.shape)
#         print("entity_ids", entity_ids.shape)
        # zero the parameter gradients
        optimizer.zero_grad()
#         with autograd.detect_anomaly():

        # forward + backward + optimize
        input_ids = tokenized_text['input_ids'].to(device)
        token_type_ids = tokenized_text['token_type_ids'].to(device)
        attention_mask = tokenized_text['attention_mask'].to(device)
                          
        mention_pred, entity_pred = model(input_ids,token_type_ids,attention_mask)
#         print("ent pred", entity_pred.shape)
        loss = criterion(mention_pred, entity_pred, bio_tags, entity_ids, attention_mask, 
                         pretrained_entity_embeddings, device)
#         print(loss)
        loss.backward()
        optimizer.step()
        
        # print statistics
        running_loss += loss.item()
        if i % 5 == 4:    # print every 2000 mini-batches
            print('[%d, %5d] train loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 5))
            running_loss = 0.0
#         if i == 2:
#             break
    print('Finished Training')

In [17]:
def evaluate(model, test_loader, optimizer, criterion, epoch):
    
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        running_loss = 0.0
        
        for i, data in enumerate(test_loader, 0):
            
            tokenized_text, bio_tags, entity_ids = data
            bio_tags = bio_tags.to(device)
            entity_ids = entity_ids.to(device)
            
    #         with autograd.detect_anomaly():

            # forward + backward + optimize
            input_ids = tokenized_text['input_ids'].to(device)
            token_type_ids = tokenized_text['token_type_ids'].to(device)
            attention_mask = tokenized_text['attention_mask'].to(device)

            mention_pred, entity_pred = model(input_ids,token_type_ids,attention_mask)
    #         print("ent pred", entity_pred.shape)
            loss = criterion(mention_pred, entity_pred, bio_tags, entity_ids, 
                            attention_mask, pretrained_entity_embeddings, device)
    #         print(loss)


            micro_conf_mat, macro_conf_mat = compute_conf_matrices(mention_pred, entity_pred, bio_tags, entity_ids, pretrained_entity_embeddings)
            micro_metrics, macro_metrics = compute_metrics(["precision", "recall","f1"], micro_conf_mat, macro_conf_mat)
            print('-' * 89)
            print(f"Micro-Precision: {micro_metrics['precision']}, Macro-Precision: {macro_metrics['precision']}\n")
            print(f"Micro-Recall: {micro_metrics['recall']}, Macro-Recall: {macro_metrics['recall']}\n")
            print(f"Micro-F1: {micro_metrics['f1']}, Macro-F1: {macro_metrics['f1']}\n")    
            print('-' * 89)


            running_loss += loss.item()
            total_loss += loss.item()
            if i % 5 == 4:    # print every 2000 mini-batches
                print('[%d, %5d] test loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 5))
                running_loss = 0.0
        # whatever i is at is the len of test_loader
        return total_loss / i

In [18]:
num_epochs = 25
import time
import math 

# try:
for epoch in range(num_epochs):  # loop over the dataset multiple times
    epoch_start_time = time.time()
    train(model, train_dataloader, optimizer, mention_entity_loss, epoch)
    val_loss = evaluate(model, test_dataloader, optimizer, mention_entity_loss, epoch)
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | valid ppl {:8.2f}'.format(epoch, 
                                      (time.time() - epoch_start_time),
                                       val_loss, math.exp(val_loss)))
    print('-' * 89)        
# except Exception as e:
#     print("An exception occurred", e) 
    

[1,     5] train loss: 1.014
[1,    10] train loss: 1.004
[1,    15] train loss: 1.000
[1,    20] train loss: 0.994
[1,    25] train loss: 0.992
[1,    30] train loss: 0.981
[1,    35] train loss: 0.976
[1,    40] train loss: 0.977
[1,    45] train loss: 0.967
[1,    50] train loss: 0.966
[1,    55] train loss: 0.957
[1,    60] train loss: 0.953
[1,    65] train loss: 0.946
[1,    70] train loss: 0.938
[1,    75] train loss: 0.940
[1,    80] train loss: 0.935
[1,    85] train loss: 0.938
[1,    90] train loss: 0.926
[1,    95] train loss: 0.927
[1,   100] train loss: 0.920
[1,   105] train loss: 0.917
[1,   110] train loss: 0.919
[1,   115] train loss: 0.910
[1,   120] train loss: 0.904
[1,   125] train loss: 0.903
[1,   130] train loss: 0.898
[1,   135] train loss: 0.899
[1,   140] train loss: 0.903
[1,   145] train loss: 0.895
[1,   150] train loss: 0.892
[1,   155] train loss: 0.891
[1,   160] train loss: 0.890
[1,   165] train loss: 0.889
Finished Training
-------------------------

-----------------------------------------------------------------------------------------
Micro-Precision: 0.04222378507256508, Macro-Precision: 0.04913223534822464

Micro-Recall: 0.046475600451231, Macro-Recall: nan

Micro-F1: 0.044247787445783615, Macro-F1: 0.05058671906590462

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.05124653875827789, Macro-Precision: 0.05434131622314453

Micro-Recall: 0.053740013390779495, Macro-Recall: 0.05872342362999916

Micro-F1: 0.05246366560459137, Macro-F1: 0.056032031774520874

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.0560000017285347, Macro-Precision: 0.05208045616745949

Micro-Recall: 0.05944798141717911, Macro-Recall: nan

Micro-F1: 0.05767250

-----------------------------------------------------------------------------------------
Micro-Precision: 0.052945561707019806, Macro-Precision: 0.05247754231095314

Micro-Recall: 0.059117402881383896, Macro-Recall: nan

Micro-F1: 0.05586152523756027, Macro-F1: 0.050713278353214264

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.07126076519489288, Macro-Precision: 0.0769544467329979

Micro-Recall: 0.07698815315961838, Macro-Recall: 0.08303510397672653

Micro-F1: 0.07401382923126221, Macro-F1: 0.07934790849685669

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.05968928709626198, Macro-Precision: 0.06351488828659058

Micro-Recall: 0.06552962213754654, Macro-Recall: nan

Micro-F1: 0.062473

-----------------------------------------------------------------------------------------
Micro-Precision: 0.13096773624420166, Macro-Precision: 0.1308744251728058

Micro-Recall: 0.15048183500766754, Macro-Recall: 0.14916056394577026

Micro-F1: 0.1400482952594757, Macro-F1: 0.1382024735212326

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.1478779911994934, Macro-Precision: 0.13662081956863403

Micro-Recall: 0.173540860414505, Macro-Recall: nan

Micro-F1: 0.1596849262714386, Macro-F1: 0.13950298726558685

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.1679029017686844, Macro-Precision: 0.164861261844635

Micro-Recall: 0.1928737461566925, Macro-Recall: nan

Micro-F1: 0.17952415347099304, 

-----------------------------------------------------------------------------------------
Micro-Precision: 0.14573539793491364, Macro-Precision: 0.13951537013053894

Micro-Recall: 0.16426949203014374, Macro-Recall: 0.1495068371295929

Micro-F1: 0.15444840490818024, Macro-F1: 0.1433221995830536

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.12145748734474182, Macro-Precision: 0.13505952060222626

Micro-Recall: 0.135869562625885, Macro-Recall: nan

Micro-F1: 0.12825994193553925, Macro-F1: 0.13470740616321564

-----------------------------------------------------------------------------------------
[2,    30] test loss: 0.815
-----------------------------------------------------------------------------------------
Micro-Precision: 0.16983695328235626, Macro-Precision: 0.16312408447265625

Micro-Recall: 0.20764119923114777, Macro-Recall: n

-----------------------------------------------------------------------------------------
Micro-Precision: 0.17709065973758698, Macro-Precision: 0.17497889697551727

Micro-Recall: 0.19192688167095184, Macro-Recall: 0.18803298473358154

Micro-F1: 0.18421052396297455, Macro-F1: 0.1796400099992752

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.14794732630252838, Macro-Precision: 0.14460699260234833

Micro-Recall: 0.17750929296016693, Macro-Recall: nan

Micro-F1: 0.16138571500778198, Macro-F1: 0.1435750424861908

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.1409727931022644, Macro-Precision: nan

Micro-Recall: 0.14615385234355927, Macro-Recall: 0.161091610789299

Micro-F1: 0.1435165703296

-----------------------------------------------------------------------------------------
Micro-Precision: 0.19889502227306366, Macro-Precision: nan

Micro-Recall: 0.22682268917560577, Macro-Recall: nan

Micro-F1: 0.21194280683994293, Macro-F1: 0.17919869720935822

-----------------------------------------------------------------------------------------
[3,    15] test loss: 0.757
-----------------------------------------------------------------------------------------
Micro-Precision: 0.24436090886592865, Macro-Precision: nan

Micro-Recall: 0.24667932093143463, Macro-Recall: nan

Micro-F1: 0.2455146312713623, Macro-F1: nan

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.2120765894651413, Macro-Precision: 0.2034972757101059

Micro-Recall: 0.24827586114406586, Macro-Recall: nan

Micro-F1: 0.2287529855966568, Macro-F1: 0.20853468775749207

-----------------------------------------------------------------------------------------
Micro-Precision: 0.19879062473773956, Macro-Precision: nan

Micro-Recall: 0.23822464048862457, Macro-Recall: nan

Micro-F1: 0.21672847867012024, Macro-F1: nan

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.20349344611167908, Macro-Precision: 0.18514017760753632

Micro-Recall: 0.22843137383460999, Macro-Recall: nan

Micro-F1: 0.21524249017238617, Macro-F1: 0.18662090599536896

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.21954397857189178, Macro-Precision: 0.2216389924287796

Micro-Recall: 0.2237715870141983, Macro-Recall: 0.22979974746704102

Micro-F1: 0.22163762152194977, Macro-F1: 0.22259892523

[4,   135] train loss: 0.770
[4,   140] train loss: 0.766
[4,   145] train loss: 0.767
[4,   150] train loss: 0.756
[4,   155] train loss: 0.777
[4,   160] train loss: 0.775
[4,   165] train loss: 0.768
Finished Training
-----------------------------------------------------------------------------------------
Micro-Precision: 0.20735573768615723, Macro-Precision: 0.2159823179244995

Micro-Recall: 0.21001926064491272, Macro-Recall: nan

Micro-F1: 0.20867900550365448, Macro-F1: 0.20317131280899048

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.2613636255264282, Macro-Precision: 0.243504598736763

Micro-Recall: 0.2697419822216034, Macro-Recall: nan

Micro-F1: 0.2654867172241211, Macro-F1: 0.24338094890117645

-----------------------------------------------------------------------------------------
-----------------------------------------

-----------------------------------------------------------------------------------------
Micro-Precision: 0.2829716205596924, Macro-Precision: 0.2773948609828949

Micro-Recall: 0.27405011653900146, Macro-Recall: 0.27340593934059143

Micro-F1: 0.2784394323825836, Macro-F1: 0.2722485363483429

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.230158731341362, Macro-Precision: 0.2181260734796524

Micro-Recall: 0.23895131051540375, Macro-Recall: nan

Micro-F1: 0.2344726175069809, Macro-F1: 0.21614299714565277

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.2445193976163864, Macro-Precision: 0.22220492362976074

Micro-Recall: 0.2398676574230194, Macro-Recall: 0.21787813305854797

Micro-F1: 0.24

-----------------------------------------------------------------------------------------
Micro-Precision: 0.24056939780712128, Macro-Precision: nan

Micro-Recall: 0.24707601964473724, Macro-Recall: nan

Micro-F1: 0.24377930164337158, Macro-F1: nan

-----------------------------------------------------------------------------------------
[4,    45] test loss: 0.753
-----------------------------------------------------------------------------------------
Micro-Precision: 0.2112676054239273, Macro-Precision: 0.19920486211776733

Micro-Recall: 0.23096664249897003, Macro-Recall: nan

Micro-F1: 0.2206783890724182, Macro-F1: 0.20166733860969543

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.22857142984867096, Macro-Precision: 0.2173343151807785

Micro-Recall: 0.2591361999511719, Macro-Recall: nan

Micro-F1: 0.24289606511592865, Macro-F1: 0.2

-----------------------------------------------------------------------------------------
Micro-Precision: 0.2301710695028305, Macro-Precision: 0.243543341755867

Micro-Recall: 0.24749164283275604, Macro-Recall: nan

Micro-F1: 0.23851732909679413, Macro-F1: 0.23785264790058136

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.25183016061782837, Macro-Precision: 0.23899680376052856

Micro-Recall: 0.25424981117248535, Macro-Recall: nan

Micro-F1: 0.25303420424461365, Macro-F1: 0.22820930182933807

-----------------------------------------------------------------------------------------
[5,    10] test loss: 0.732
-----------------------------------------------------------------------------------------
Micro-Precision: 0.26752617955207825, Macro-Precision: nan

Micro-Recall: 0.2653876841068268, Macro-Recall: nan

Micro-F1: 0.2664526402950287

-----------------------------------------------------------------------------------------
Micro-Precision: 0.24580536782741547, Macro-Precision: nan

Micro-Recall: 0.2730661630630493, Macro-Recall: nan

Micro-F1: 0.25871965289115906, Macro-F1: nan

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.29204893112182617, Macro-Precision: 0.27209943532943726

Micro-Recall: 0.27053824067115784, Macro-Recall: 0.25546538829803467

Micro-F1: 0.2808823585510254, Macro-F1: 0.2614005208015442

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.3079526126384735, Macro-Precision: 0.31218698620796204

Micro-Recall: 0.27471697330474854, Macro-Recall: 0.2873723804950714

Micro-F1: 0.2903869152069092, Macro-F1: 0

-----------------------------------------------------------------------------------------
Micro-Precision: 0.260155588388443, Macro-Precision: 0.26341959834098816

Micro-Recall: 0.2872137427330017, Macro-Recall: nan

Micro-F1: 0.2730158865451813, Macro-F1: 0.2679191529750824

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.2888730466365814, Macro-Precision: 0.3027568459510803

Micro-Recall: 0.30705079436302185, Macro-Recall: nan

Micro-F1: 0.2976846694946289, Macro-F1: 0.30054184794425964

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.25671643018722534, Macro-Precision: 0.24977177381515503

Micro-Recall: 0.26791277527809143, Macro-Recall: nan

Micro-F1: 0.2621951103210449, Macro-F1: 0.24

-----------------------------------------------------------------------------------------
Micro-Precision: 0.2926470637321472, Macro-Precision: 0.2876233458518982

Micro-Recall: 0.2926470637321472, Macro-Recall: 0.2918437719345093

Micro-F1: 0.2926470637321472, Macro-F1: 0.28773006796836853

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.20923520624637604, Macro-Precision: 0.19907458126544952

Micro-Recall: 0.23927392065525055, Macro-Recall: nan

Micro-F1: 0.2232486456632614, Macro-F1: 0.2010040432214737

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.2645778954029083, Macro-Precision: 0.266370564699173

Micro-Recall: 0.2886989414691925, Macro-Recall: 0.2971557378768921

Micro-F1: 0.2761

-----------------------------------------------------------------------------------------
Micro-Precision: 0.2644444406032562, Macro-Precision: 0.26108500361442566

Micro-Recall: 0.2748267948627472, Macro-Recall: nan

Micro-F1: 0.2695356607437134, Macro-F1: 0.25743192434310913

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.28738147020339966, Macro-Precision: nan

Micro-Recall: 0.31344470381736755, Macro-Recall: nan

Micro-F1: 0.2998477816581726, Macro-F1: nan

-----------------------------------------------------------------------------------------
[6,    40] test loss: 0.707
-----------------------------------------------------------------------------------------
Micro-Precision: 0.27331188321113586, Macro-Precision: 0.27053123712539673

Micro-Recall: 0.2823920249938965, Macro-Recall: 0.2712445557117462

Micro-F1: 0.2777777910232544, 

-----------------------------------------------------------------------------------------
Micro-Precision: 0.30898022651672363, Macro-Precision: 0.3078153431415558

Micro-Recall: 0.33115825057029724, Macro-Recall: nan

Micro-F1: 0.31968504190444946, Macro-F1: 0.30614522099494934

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.2383720874786377, Macro-Precision: 0.22700589895248413

Micro-Recall: 0.2976406514644623, Macro-Recall: nan

Micro-F1: 0.2647296190261841, Macro-F1: 0.24294714629650116

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.2151394486427307, Macro-Precision: 0.21487250924110413

Micro-Recall: 0.25194400548934937, Macro-Recall: 0.24289365112781525

Micro-F1: 0.2320916950702

-----------------------------------------------------------------------------------------
Micro-Precision: 0.30469372868537903, Macro-Precision: nan

Micro-Recall: 0.2948421835899353, Macro-Recall: 0.28940579295158386

Micro-F1: 0.299686998128891, Macro-F1: 0.29460418224334717

-----------------------------------------------------------------------------------------
[7,    25] test loss: 0.699
-----------------------------------------------------------------------------------------
Micro-Precision: 0.31139418482780457, Macro-Precision: 0.29869526624679565

Micro-Recall: 0.3651452362537384, Macro-Recall: nan

Micro-F1: 0.3361344635486603, Macro-F1: 0.3074894845485687

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.35497236251831055, Macro-Precision: nan

Micro-Recall: 0.37058398127555847, Macro-Recall: nan

Micro-F1: 0.36261022090911865,

-----------------------------------------------------------------------------------------
Micro-Precision: 0.2883755564689636, Macro-Precision: 0.27717074751853943

Micro-Recall: 0.32768839597702026, Macro-Recall: nan

Micro-F1: 0.3067776560783386, Macro-F1: 0.27939870953559875

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.24085576832294464, Macro-Precision: nan

Micro-Recall: 0.3144144117832184, Macro-Recall: nan

Micro-F1: 0.27276280522346497, Macro-F1: nan

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.2783964276313782, Macro-Precision: 0.2597523629665375

Micro-Recall: 0.29667720198631287, Macro-Recall: nan

Micro-F1: 0.2872462570667267, Macro-F1: 0.2645552158355713

-------------

-----------------------------------------------------------------------------------------
Micro-Precision: 0.30618181824684143, Macro-Precision: 0.29783105850219727

Micro-Recall: 0.3283931314945221, Macro-Recall: nan

Micro-F1: 0.31689876317977905, Macro-F1: 0.2984883487224579

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.25145819783210754, Macro-Precision: 0.24296121299266815

Micro-Recall: 0.26963168382644653, Macro-Recall: 0.25770241022109985

Micro-F1: 0.2602280378341675, Macro-F1: 0.2487383633852005

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.30463576316833496, Macro-Precision: 0.30380287766456604

Micro-Recall: 0.32886505126953125, Macro-Recall: 0.33007755875587463

Micro-F1

-----------------------------------------------------------------------------------------
Micro-Precision: 0.30382776260375977, Macro-Precision: 0.28914472460746765

Micro-Recall: 0.3518005609512329, Macro-Recall: nan

Micro-F1: 0.3260590434074402, Macro-F1: 0.29276242852211

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.30320513248443604, Macro-Precision: 0.293804794549942

Micro-Recall: 0.3378571569919586, Macro-Recall: 0.32429152727127075

Micro-F1: 0.3195945918560028, Macro-F1: 0.30580633878707886

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.2420727014541626, Macro-Precision: nan

Micro-Recall: 0.27193745970726013, Macro-Recall: 0.264064759016037

Micro-F1: 0.256137490272522, Mac

-----------------------------------------------------------------------------------------
Micro-Precision: 0.26307693123817444, Macro-Precision: nan

Micro-Recall: 0.2780487835407257, Macro-Recall: nan

Micro-F1: 0.2703557312488556, Macro-F1: nan

-----------------------------------------------------------------------------------------
[8,    55] test loss: 0.679
-----------------------------------------------------------------------------------------
| end of epoch   7 | time: 116.34s | valid loss  0.70 | valid ppl     2.01
-----------------------------------------------------------------------------------------
[9,     5] train loss: 0.684
[9,    10] train loss: 0.682
[9,    15] train loss: 0.684
[9,    20] train loss: 0.687
[9,    25] train loss: 0.702
[9,    30] train loss: 0.684
[9,    35] train loss: 0.694
[9,    40] train loss: 0.699
[9,    45] train loss: 0.689
[9,    50] train loss: 0.684
[9,    55] train loss: 0.685
[9,    60] train loss: 0.690
[9,    65] train loss: 0.688
[9

-----------------------------------------------------------------------------------------
Micro-Precision: 0.31532198190689087, Macro-Precision: nan

Micro-Recall: 0.3517753779888153, Macro-Recall: nan

Micro-F1: 0.3325527012348175, Macro-F1: nan

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.28432732820510864, Macro-Precision: 0.2977571189403534

Micro-Recall: 0.3125, Macro-Recall: 0.3236168324947357

Micro-F1: 0.29774871468544006, Macro-F1: 0.30677035450935364

-----------------------------------------------------------------------------------------
[9,    20] test loss: 0.670
-----------------------------------------------------------------------------------------
Micro-Precision: 0.3045977056026459, Macro-Precision: 0.30472448468208313

Micro-Recall: 0.323664128780365, Macro-Recall: 0.31823253631591797

Micro-F1: 0.3138416111469269

-----------------------------------------------------------------------------------------
Micro-Precision: 0.302807480096817, Macro-Precision: 0.28512799739837646

Micro-Recall: 0.30816325545310974, Macro-Recall: 0.2921556830406189

Micro-F1: 0.30546191334724426, Macro-F1: 0.28621384501457214

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.3361823260784149, Macro-Precision: 0.32992270588874817

Micro-Recall: 0.3713611364364624, Macro-Recall: nan

Micro-F1: 0.35289719700813293, Macro-F1: 0.32935604453086853

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.31148648262023926, Macro-Precision: 0.30879461765289307

Micro-Recall: 0.3316546678543091, Macro-Recall: 0.33800816535949707

Micro-F1: 

-----------------------------------------------------------------------------------------
Micro-Precision: 0.3639097809791565, Macro-Precision: nan

Micro-Recall: 0.41296929121017456, Macro-Recall: nan

Micro-F1: 0.3868905007839203, Macro-F1: nan

-----------------------------------------------------------------------------------------
[10,     5] test loss: 0.654
-----------------------------------------------------------------------------------------
Micro-Precision: 0.32867133617401123, Macro-Precision: 0.3120744824409485

Micro-Recall: 0.3425101339817047, Macro-Recall: 0.32457253336906433

Micro-F1: 0.3354480564594269, Macro-F1: 0.31577807664871216

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.3168458640575409, Macro-Precision: 0.29687246680259705

Micro-Recall: 0.3614063858985901, Macro-Recall: 0.34003496170043945

Micro-F1: 0.33

-----------------------------------------------------------------------------------------
Micro-Precision: 0.3146139085292816, Macro-Precision: 0.2916567325592041

Micro-Recall: 0.340234637260437, Macro-Recall: 0.31274017691612244

Micro-F1: 0.32692307233810425, Macro-F1: 0.30040478706359863

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.296875, Macro-Precision: 0.2993662357330322

Micro-Recall: 0.3353798985481262, Macro-Recall: 0.34184110164642334

Micro-F1: 0.3149549663066864, Macro-F1: 0.31382203102111816

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.3146905303001404, Macro-Precision: 0.2911665141582489

Micro-Recall: 0.3320220410823822, Macro-Recall: 0.3032536208629608

Micro-F1: 

-----------------------------------------------------------------------------------------
Micro-Precision: 0.31520992517471313, Macro-Precision: nan

Micro-Recall: 0.4024604558944702, Macro-Recall: nan

Micro-F1: 0.35353145003318787, Macro-F1: nan

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.26075267791748047, Macro-Precision: 0.2802054286003113

Micro-Recall: 0.3442768454551697, Macro-Recall: nan

Micro-F1: 0.2967495322227478, Macro-F1: 0.2970438003540039

-----------------------------------------------------------------------------------------
[10,    50] test loss: 0.672
-----------------------------------------------------------------------------------------
Micro-Precision: 0.3129771053791046, Macro-Precision: 0.3180200755596161

Micro-Recall: 0.34506502747535706, Macro-Recall: 0.35616302490234375

Micro-F1: 0.32823872566223145,

-----------------------------------------------------------------------------------------
Micro-Precision: 0.3250664174556732, Macro-Precision: 0.3153213858604431

Micro-Recall: 0.3136752247810364, Macro-Recall: 0.30906686186790466

Micro-F1: 0.31926923990249634, Macro-F1: 0.3077135980129242

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.41811320185661316, Macro-Precision: 0.41308704018592834

Micro-Recall: 0.3647136390209198, Macro-Recall: 0.37188106775283813

Micro-F1: 0.38959211111068726, Macro-F1: 0.3868838846683502

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.34056398272514343, Macro-Precision: 0.3420860171318054

Micro-Recall: 0.31483957171440125, Macro-Recall: 0.32449945807456

-----------------------------------------------------------------------------------------
Micro-Precision: 0.33622559905052185, Macro-Precision: 0.324625164270401

Micro-Recall: 0.3178400695323944, Macro-Recall: nan

Micro-F1: 0.3267744183540344, Macro-F1: 0.31539812684059143

-----------------------------------------------------------------------------------------
[11,    35] test loss: 0.654
-----------------------------------------------------------------------------------------
Micro-Precision: 0.2925170063972473, Macro-Precision: nan

Micro-Recall: 0.29655173420906067, Macro-Recall: nan

Micro-F1: 0.2945205569267273, Macro-F1: nan

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.32826748490333557, Macro-Precision: nan

Micro-Recall: 0.3467094600200653, Macro-Recall: nan

Micro-F1: 0.33723652362823486, Macro-F1: nan

----------------

[12,    25] train loss: 0.660
[12,    30] train loss: 0.649
[12,    35] train loss: 0.649
[12,    40] train loss: 0.653
[12,    45] train loss: 0.652
[12,    50] train loss: 0.650
[12,    55] train loss: 0.657
[12,    60] train loss: 0.656
[12,    65] train loss: 0.653
[12,    70] train loss: 0.648
[12,    75] train loss: 0.648
[12,    80] train loss: 0.653
[12,    85] train loss: 0.658
[12,    90] train loss: 0.656
[12,    95] train loss: 0.648
[12,   100] train loss: 0.658
[12,   105] train loss: 0.648
[12,   110] train loss: 0.652
[12,   115] train loss: 0.671
[12,   120] train loss: 0.656
[12,   125] train loss: 0.658
[12,   130] train loss: 0.654
[12,   135] train loss: 0.648
[12,   140] train loss: 0.646
[12,   145] train loss: 0.656
[12,   150] train loss: 0.661
[12,   155] train loss: 0.641
[12,   160] train loss: 0.649
[12,   165] train loss: 0.651
Finished Training
-----------------------------------------------------------------------------------------
Micro-Precision: 0.385

-----------------------------------------------------------------------------------------
Micro-Precision: 0.29952457547187805, Macro-Precision: 0.2804766595363617

Micro-Recall: 0.33070865273475647, Macro-Recall: nan

Micro-F1: 0.314345121383667, Macro-F1: 0.28390830755233765

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.31587302684783936, Macro-Precision: 0.30648288130760193

Micro-Recall: 0.32384052872657776, Macro-Recall: 0.3253710865974426

Micro-F1: 0.31980714201927185, Macro-F1: 0.31278371810913086

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.30640000104904175, Macro-Precision: nan

Micro-Recall: 0.31062448024749756, Macro-Recall: nan

Micro-F1: 0.3084977865219116, Macro-F1: 

-----------------------------------------------------------------------------------------
Micro-Precision: 0.31647777557373047, Macro-Precision: 0.3008909523487091

Micro-Recall: 0.3249776065349579, Macro-Recall: 0.30185025930404663

Micro-F1: 0.3206713795661926, Macro-F1: 0.29882609844207764

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.30808401107788086, Macro-Precision: 0.3218681812286377

Micro-Recall: 0.3055555522441864, Macro-Recall: 0.32129043340682983

Micro-F1: 0.306814581155777, Macro-F1: 0.31794777512550354

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.35660091042518616, Macro-Precision: 0.3467491567134857

Micro-Recall: 0.3547169864177704, Macro-Recall: nan

Micro-F1: 0.3

-----------------------------------------------------------------------------------------
Micro-Precision: 0.3134087324142456, Macro-Precision: nan

Micro-Recall: 0.3622782528400421, Macro-Recall: nan

Micro-F1: 0.3360762298107147, Macro-F1: nan

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.33003532886505127, Macro-Precision: 0.3387446403503418

Micro-Recall: 0.3376717269420624, Macro-Recall: 0.34038159251213074

Micro-F1: 0.33380985260009766, Macro-F1: 0.3354630172252655

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.34210526943206787, Macro-Precision: 0.36738064885139465

Micro-Recall: 0.33847635984420776, Macro-Recall: nan

Micro-F1: 0.3402811586856842, Macro-F1: 0.3564115166664123

-----------------------------------------------------------------------------------------
Micro-Precision: 0.34266263246536255, Macro-Precision: 0.33916330337524414

Micro-Recall: 0.34448668360710144, Macro-Recall: 0.3490457236766815

Micro-F1: 0.3435722291469574, Macro-F1: 0.34155213832855225

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.2602739632129669, Macro-Precision: nan

Micro-Recall: 0.33101046085357666, Macro-Recall: nan

Micro-F1: 0.29141104221343994, Macro-F1: nan

-----------------------------------------------------------------------------------------
[13,    30] test loss: 0.655
-----------------------------------------------------------------------------------------
Micro-Precision: 0.36044973134994507, Macro-Precision: 0.3394298255443573

Micro-Recall: 0.36874154210090637, Macro-Recall: nan

Micro-F1: 0.364548504352569

-----------------------------------------------------------------------------------------
Micro-Precision: 0.30180805921554565, Macro-Precision: 0.2899596691131592

Micro-Recall: 0.3181818127632141, Macro-Recall: nan

Micro-F1: 0.30977872014045715, Macro-F1: 0.29803040623664856

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.35820895433425903, Macro-Precision: nan

Micro-Recall: 0.3563474416732788, Macro-Recall: 0.3628469705581665

Micro-F1: 0.3572757840156555, Macro-F1: 0.36289656162261963

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.3055172562599182, Macro-Precision: 0.3065609633922577

Micro-Recall: 0.321480393409729, Macro-Recall: 0.3166470229625702

Micro-F1: 0.3132956027984619, 

-----------------------------------------------------------------------------------------
Micro-Precision: 0.27248501777648926, Macro-Precision: 0.2516711354255676

Micro-Recall: 0.3277243673801422, Macro-Recall: nan

Micro-F1: 0.29756274819374084, Macro-F1: 0.26268887519836426

-----------------------------------------------------------------------------------------
[14,    15] test loss: 0.643
-----------------------------------------------------------------------------------------
Micro-Precision: 0.3123612105846405, Macro-Precision: nan

Micro-Recall: 0.3543241024017334, Macro-Recall: nan

Micro-F1: 0.3320220410823822, Macro-F1: nan

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.33358603715896606, Macro-Precision: 0.31540024280548096

Micro-Recall: 0.37735849618911743, Macro-Recall: nan

Micro-F1: 0.3541247546672821, Macro-F1: 0.32

-----------------------------------------------------------------------------------------
Micro-Precision: 0.3288537561893463, Macro-Precision: 0.3323066234588623

Micro-Recall: 0.3717604875564575, Macro-Recall: nan

Micro-F1: 0.34899330139160156, Macro-F1: 0.3422170877456665

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.26573997735977173, Macro-Precision: nan

Micro-Recall: 0.3140096664428711, Macro-Recall: nan

Micro-F1: 0.28786537051200867, Macro-F1: 0.26937785744667053

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.3501400649547577, Macro-Precision: 0.3531079888343811

Micro-Recall: 0.40551501512527466, Macro-Recall: nan

Micro-F1: 0.3757985830307007, Macro-F1: 0.3546156883239746


[15,   140] train loss: 0.613
[15,   145] train loss: 0.620
[15,   150] train loss: 0.617
[15,   155] train loss: 0.614
[15,   160] train loss: 0.636
[15,   165] train loss: 0.627
Finished Training
-----------------------------------------------------------------------------------------
Micro-Precision: 0.3274894952774048, Macro-Precision: 0.3227660655975342

Micro-Recall: 0.31426647305488586, Macro-Recall: 0.31628087162971497

Micro-F1: 0.32074177265167236, Macro-F1: 0.31492340564727783

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.346712201833725, Macro-Precision: 0.3265223801136017

Micro-Recall: 0.3212025463581085, Macro-Recall: 0.2946632206439972

Micro-F1: 0.3334702253341675, Macro-F1: 0.30643483996391296

-----------------------------------------------------------------------------------------
----------------------------------

-----------------------------------------------------------------------------------------
Micro-Precision: 0.3348484933376312, Macro-Precision: nan

Micro-Recall: 0.3389570415019989, Macro-Recall: nan

Micro-F1: 0.33689025044441223, Macro-F1: nan

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.289055198431015, Macro-Precision: 0.26431044936180115

Micro-Recall: 0.2816773056983948, Macro-Recall: 0.25539493560791016

Micro-F1: 0.2853185534477234, Macro-F1: 0.25720149278640747

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.3177124559879303, Macro-Precision: 0.3039034903049469

Micro-Recall: 0.32599836587905884, Macro-Recall: nan

Micro-F1: 0.3218020796775818, Macro-F1: 0.2932838201522827



-----------------------------------------------------------------------------------------
Micro-Precision: 0.36873406171798706, Macro-Precision: 0.3858030438423157

Micro-Recall: 0.33539411425590515, Macro-Recall: 0.34487679600715637

Micro-F1: 0.3512747883796692, Macro-F1: 0.3572148084640503

-----------------------------------------------------------------------------------------
[15,    45] test loss: 0.654
-----------------------------------------------------------------------------------------
Micro-Precision: 0.34304746985435486, Macro-Precision: nan

Micro-Recall: 0.3371522128582001, Macro-Recall: nan

Micro-F1: 0.3400743007659912, Macro-F1: 0.3047119379043579

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.39904987812042236, Macro-Precision: 0.3965338170528412

Micro-Recall: 0.3579545319080353, Macro-Recall: 0.3500909209251404



-----------------------------------------------------------------------------------------
Micro-Precision: 0.3668639063835144, Macro-Precision: 0.3382145166397095

Micro-Recall: 0.38003501296043396, Macro-Recall: 0.35820648074150085

Micro-F1: 0.3733333349227905, Macro-F1: 0.3415278494358063

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.31987813115119934, Macro-Precision: 0.32262811064720154

Micro-Recall: 0.3849679231643677, Macro-Recall: nan

Micro-F1: 0.34941762685775757, Macro-F1: 0.3192461133003235

-----------------------------------------------------------------------------------------
[16,    10] test loss: 0.619
-----------------------------------------------------------------------------------------
Micro-Precision: 0.2960470914840698, Macro-Precision: 0.29467248916625977

Micro-Recall: 0.3217550218105316, Macro-Recall: nan


-----------------------------------------------------------------------------------------
Micro-Precision: 0.2769097089767456, Macro-Precision: 0.2636127173900604

Micro-Recall: 0.2783595025539398, Macro-Recall: 0.2555888891220093

Micro-F1: 0.2776327133178711, Macro-F1: 0.2570701241493225

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.3690265417098999, Macro-Precision: 0.3796781003475189

Micro-Recall: 0.357326477766037, Macro-Recall: 0.35291963815689087

Micro-F1: 0.36308228969573975, Macro-F1: 0.3620027005672455

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.35881906747817993, Macro-Precision: 0.35766440629959106

Micro-Recall: 0.390444815158844, Macro-Recall: nan

Micro-F1: 0.37396

-----------------------------------------------------------------------------------------
Micro-Precision: 0.34438157081604004, Macro-Precision: 0.33086463809013367

Micro-Recall: 0.3441033959388733, Macro-Recall: 0.31676122546195984

Micro-F1: 0.3442424237728119, Macro-F1: 0.31952598690986633

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.3265145421028137, Macro-Precision: 0.3195367157459259

Micro-Recall: 0.35470086336135864, Macro-Recall: 0.3436727821826935

Micro-F1: 0.34002459049224854, Macro-F1: 0.327850878238678

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.3333333432674408, Macro-Precision: 0.3280140161514282

Micro-Recall: 0.37964776158332825, Macro-Recall: nan

Micro-F1: 0.3

-----------------------------------------------------------------------------------------
Micro-Precision: 0.4126620888710022, Macro-Precision: 0.3931044638156891

Micro-Recall: 0.3828733265399933, Macro-Recall: 0.3678458333015442

Micro-F1: 0.3972099721431732, Macro-F1: 0.37420666217803955

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.3054393231868744, Macro-Precision: 0.32166531682014465

Micro-Recall: 0.33890435099601746, Macro-Recall: nan

Micro-F1: 0.3213028311729431, Macro-F1: 0.323604941368103

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.2831050157546997, Macro-Precision: 0.29830971360206604

Micro-Recall: 0.2978382706642151, Macro-Recall: 0.3106190860271454

Micro-F1: 0.2902

-----------------------------------------------------------------------------------------
Micro-Precision: 0.32486528158187866, Macro-Precision: 0.31167855858802795

Micro-Recall: 0.3403225839138031, Macro-Recall: 0.3234027028083801

Micro-F1: 0.33241432905197144, Macro-F1: 0.30793172121047974

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.30877193808555603, Macro-Precision: 0.35559946298599243

Micro-Recall: 0.30318689346313477, Macro-Recall: 0.3385430574417114

Micro-F1: 0.3059539198875427, Macro-F1: 0.33933737874031067

-----------------------------------------------------------------------------------------
[17,    40] test loss: 0.644
-----------------------------------------------------------------------------------------
Micro-Precision: 0.3547169864177704, Macro-Precision: 0.34917327761650085

Micro-Recall: 0.3824247419834137, 

-----------------------------------------------------------------------------------------
Micro-Precision: 0.2878338396549225, Macro-Precision: 0.28327643871307373

Micro-Recall: 0.3005422055721283, Macro-Recall: 0.2930595576763153

Micro-F1: 0.29405078291893005, Macro-F1: 0.2842530310153961

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.31921109557151794, Macro-Precision: 0.3189978003501892

Micro-Recall: 0.34904152154922485, Macro-Recall: 0.3520031273365021

Micro-F1: 0.3334605097770691, Macro-F1: 0.3284159302711487

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.29890909790992737, Macro-Precision: 0.29338499903678894

Micro-Recall: 0.3726201355457306, Macro-Recall: nan

Micro-F1: 0.3

-----------------------------------------------------------------------------------------
Micro-Precision: 0.3234672248363495, Macro-Precision: 0.31507331132888794

Micro-Recall: 0.33528122305870056, Macro-Recall: 0.3298226296901703

Micro-F1: 0.32926830649375916, Macro-F1: 0.31678879261016846

-----------------------------------------------------------------------------------------
[18,    25] test loss: 0.607
-----------------------------------------------------------------------------------------
Micro-Precision: 0.37809452414512634, Macro-Precision: 0.37264689803123474

Micro-Recall: 0.3956044018268585, Macro-Recall: 0.40328285098075867

Micro-F1: 0.38665133714675903, Macro-F1: 0.3805897831916809

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.36256447434425354, Macro-Precision: 0.34294697642326355

Micro-Recall: 0.36826348304748535

-----------------------------------------------------------------------------------------
Micro-Precision: 0.39886271953582764, Macro-Precision: 0.38513660430908203

Micro-Recall: 0.37944358587265015, Macro-Recall: 0.37059080600738525

Micro-F1: 0.3889108896255493, Macro-F1: 0.37447115778923035

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.30295565724372864, Macro-Precision: 0.2884252071380615

Micro-Recall: 0.37011033296585083, Macro-Recall: nan

Micro-F1: 0.33318284153938293, Macro-F1: 0.29672494530677795

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.38568297028541565, Macro-Precision: 0.3749730587005615

Micro-Recall: 0.3960990309715271, Macro-Recall: nan

Micro-F1: 0.390821605920

-----------------------------------------------------------------------------------------
Micro-Precision: 0.3648068606853485, Macro-Precision: 0.3620712459087372

Micro-Recall: 0.37199124693870544, Macro-Recall: nan

Micro-F1: 0.36836403608322144, Macro-F1: 0.3585646450519562

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.29532402753829956, Macro-Precision: nan

Micro-Recall: 0.3412322402000427, Macro-Recall: nan

Micro-F1: 0.31662270426750183, Macro-F1: nan

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.29741379618644714, Macro-Precision: nan

Micro-Recall: 0.34271523356437683, Macro-Recall: nan

Micro-F1: 0.318461537361145, Macro-F1: nan

--------------------------------------------

-----------------------------------------------------------------------------------------
Micro-Precision: 0.40934932231903076, Macro-Precision: 0.41179072856903076

Micro-Recall: 0.4257555902004242, Macro-Recall: 0.42190903425216675

Micro-F1: 0.417391300201416, Macro-F1: 0.4137822985649109

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.35316845774650574, Macro-Precision: 0.34863603115081787

Micro-Recall: 0.3742833733558655, Macro-Recall: 0.3607860207557678

Micro-F1: 0.36341947317123413, Macro-F1: 0.3520265221595764

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.41239479184150696, Macro-Precision: 0.40856191515922546

Micro-Recall: 0.4139784872531891, Macro-Recall: 0.412648826837539

-----------------------------------------------------------------------------------------
Micro-Precision: 0.2772277295589447, Macro-Precision: nan

Micro-Recall: 0.341143399477005, Macro-Recall: nan

Micro-F1: 0.30588236451148987, Macro-F1: 0.2754981219768524

-----------------------------------------------------------------------------------------
[19,    55] test loss: 0.633
-----------------------------------------------------------------------------------------
| end of epoch  18 | time: 115.55s | valid loss  0.64 | valid ppl     1.89
-----------------------------------------------------------------------------------------
[20,     5] train loss: 0.603
[20,    10] train loss: 0.588
[20,    15] train loss: 0.589
[20,    20] train loss: 0.593
[20,    25] train loss: 0.593
[20,    30] train loss: 0.583
[20,    35] train loss: 0.594
[20,    40] train loss: 0.598
[20,    45] train loss: 0.602
[20,    50] train loss: 0.599
[20,    55] train loss: 0.599
[20,    60] train loss: 0.584
[20,

-----------------------------------------------------------------------------------------
Micro-Precision: 0.3705882430076599, Macro-Precision: nan

Micro-Recall: 0.3649529218673706, Macro-Recall: 0.34689605236053467

Micro-F1: 0.36774900555610657, Macro-F1: 0.3475441038608551

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.33598726987838745, Macro-Precision: 0.3238956928253174

Micro-Recall: 0.3425324559211731, Macro-Recall: 0.32822805643081665

Micro-F1: 0.3392283022403717, Macro-F1: 0.32192865014076233

-----------------------------------------------------------------------------------------
[20,    20] test loss: 0.636
-----------------------------------------------------------------------------------------
Micro-Precision: 0.3614063858985901, Macro-Precision: nan

Micro-Recall: 0.39855724573135376, Macro-Recall: nan

Micro-F1: 0.37

-----------------------------------------------------------------------------------------
Micro-Precision: 0.31506848335266113, Macro-Precision: 0.3247487246990204

Micro-Recall: 0.30363035202026367, Macro-Recall: 0.30446839332580566

Micro-F1: 0.30924370884895325, Macro-F1: 0.3106740415096283

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.34696969389915466, Macro-Precision: 0.3525993824005127

Micro-Recall: 0.364359587430954, Macro-Recall: nan

Micro-F1: 0.3554520905017853, Macro-F1: 0.346609890460968

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.4028925597667694, Macro-Precision: 0.3859015703201294

Micro-Recall: 0.3622291088104248, Macro-Recall: 0.3434162139892578

Micro-F1: 0.3814

In [19]:
x = torch.randn(2,3,5)
m = torch.nn.Softmax(dim=2)
print(x.shape)
m(x)
# torch.nn.Softmax(x, dim=0)

torch.Size([2, 3, 5])


tensor([[[0.0733, 0.0961, 0.2756, 0.3893, 0.1657],
         [0.3559, 0.0187, 0.1055, 0.1910, 0.3289],
         [0.1496, 0.7128, 0.0404, 0.0556, 0.0417]],

        [[0.1742, 0.0728, 0.0993, 0.1983, 0.4553],
         [0.0665, 0.0258, 0.0714, 0.8028, 0.0336],
         [0.1200, 0.0869, 0.1828, 0.1306, 0.4798]]])

In [20]:
import torch.nn as nn
input1 = torch.randn(100, 128)
input2 = torch.randn(100, 128)
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
cos(input1, input2).shape

torch.Size([100])

In [21]:
# in1 shape [batch_size, seq_len, embed_size]
# in1 shape [16, 256, 256]
# in2 shape [1, entity_vocab, embed_size]
# in2 shape [37456, 256]
import torch
from scipy import spatial
import numpy as np

In [22]:
a = torch.randn(16, 256, 256)
# a = torch.randn(3, 5)
b = torch.randn(37456, 256) # different row number, for the fun

In [23]:
%%timeit
# Given that cos_sim(u, v) = dot(u, v) / (norm(u) * norm(v))
#                          = dot(u / norm(u), v / norm(v))
# We fist normalize the rows, before computing their dot products via transposition:
a_norm = a / a.norm(dim=2)[:, :, None]
# a_norm = a / a.norm(dim=1)[:, None]
b_norm = b / b.norm(dim=1)[:, None]
res = torch.matmul(a_norm, b_norm.transpose(0,1))
# print(res.shape)
print(res)
#  0.9978 -0.9986 -0.9985
# -0.8629  0.9172  0.9172

tensor([[[ 7.0350e-02, -7.8589e-02, -5.9236e-03,  ..., -1.6681e-01,
           4.4659e-02, -6.5032e-02],
         [ 4.3987e-02,  1.0746e-02,  4.8974e-02,  ..., -4.6926e-02,
          -1.3284e-02,  1.0489e-02],
         [-5.6106e-02,  7.1678e-02, -9.5512e-02,  ..., -5.5202e-02,
           3.6239e-02, -1.0801e-02],
         ...,
         [-7.2728e-02, -7.9300e-02, -2.6555e-02,  ..., -1.0437e-01,
           1.0003e-01,  1.4226e-01],
         [ 1.3397e-01, -2.2071e-02,  1.0243e-02,  ..., -3.9050e-02,
          -5.8970e-02,  7.0246e-02],
         [-9.6256e-02, -6.0374e-03, -1.7477e-01,  ...,  7.8413e-02,
          -3.0595e-02,  4.8560e-02]],

        [[ 5.6643e-02, -2.9540e-02, -3.3895e-04,  ...,  1.0345e-02,
           7.6199e-02,  2.6074e-02],
         [ 3.5414e-02, -3.5726e-02, -5.6459e-02,  ..., -8.8922e-02,
           2.9471e-02,  1.2311e-02],
         [ 6.3935e-02,  1.7522e-02, -2.8349e-02,  ..., -5.7184e-02,
           4.7690e-02, -2.8668e-02],
         ...,
         [ 7.1103e-02,  1

tensor([[[ 7.0350e-02, -7.8589e-02, -5.9236e-03,  ..., -1.6681e-01,
           4.4659e-02, -6.5032e-02],
         [ 4.3987e-02,  1.0746e-02,  4.8974e-02,  ..., -4.6926e-02,
          -1.3284e-02,  1.0489e-02],
         [-5.6106e-02,  7.1678e-02, -9.5512e-02,  ..., -5.5202e-02,
           3.6239e-02, -1.0801e-02],
         ...,
         [-7.2728e-02, -7.9300e-02, -2.6555e-02,  ..., -1.0437e-01,
           1.0003e-01,  1.4226e-01],
         [ 1.3397e-01, -2.2071e-02,  1.0243e-02,  ..., -3.9050e-02,
          -5.8970e-02,  7.0246e-02],
         [-9.6256e-02, -6.0374e-03, -1.7477e-01,  ...,  7.8413e-02,
          -3.0595e-02,  4.8560e-02]],

        [[ 5.6643e-02, -2.9540e-02, -3.3895e-04,  ...,  1.0345e-02,
           7.6199e-02,  2.6074e-02],
         [ 3.5414e-02, -3.5726e-02, -5.6459e-02,  ..., -8.8922e-02,
           2.9471e-02,  1.2311e-02],
         [ 6.3935e-02,  1.7522e-02, -2.8349e-02,  ..., -5.7184e-02,
           4.7690e-02, -2.8668e-02],
         ...,
         [ 7.1103e-02,  1

tensor([[[ 7.0350e-02, -7.8589e-02, -5.9236e-03,  ..., -1.6681e-01,
           4.4659e-02, -6.5032e-02],
         [ 4.3987e-02,  1.0746e-02,  4.8974e-02,  ..., -4.6926e-02,
          -1.3284e-02,  1.0489e-02],
         [-5.6106e-02,  7.1678e-02, -9.5512e-02,  ..., -5.5202e-02,
           3.6239e-02, -1.0801e-02],
         ...,
         [-7.2728e-02, -7.9300e-02, -2.6555e-02,  ..., -1.0437e-01,
           1.0003e-01,  1.4226e-01],
         [ 1.3397e-01, -2.2071e-02,  1.0243e-02,  ..., -3.9050e-02,
          -5.8970e-02,  7.0246e-02],
         [-9.6256e-02, -6.0374e-03, -1.7477e-01,  ...,  7.8413e-02,
          -3.0595e-02,  4.8560e-02]],

        [[ 5.6643e-02, -2.9540e-02, -3.3895e-04,  ...,  1.0345e-02,
           7.6199e-02,  2.6074e-02],
         [ 3.5414e-02, -3.5726e-02, -5.6459e-02,  ..., -8.8922e-02,
           2.9471e-02,  1.2311e-02],
         [ 6.3935e-02,  1.7522e-02, -2.8349e-02,  ..., -5.7184e-02,
           4.7690e-02, -2.8668e-02],
         ...,
         [ 7.1103e-02,  1

In [24]:
# %%timeit
# # -------
# # Let's verify with numpy/scipy if our computations are correct:
# a_n = a.numpy()
# b_n = b.numpy()
# res_n = np.zeros((3,5,50))

# for i in range(3):
#     for j in range(5):
#         for k in range(50):
#             # cos_sim(u, v) = 1 - cos_dist(u, v)
#             res_n[i, j, k] = 1 - spatial.distance.cosine(a_n[i, j], b_n[k])
# print(res_n)
# # [[ 0.9978022  -0.99855876 -0.99854881]
# #  [-0.86285472  0.91716063  0.9172349 ]]


from sklearn.metrics.pairwise import cosine_similarity
a_n = a.numpy()
b_n = b.numpy()

res_n = np.zeros((16,256, 37456))
for i in range(3):
    res_n[i, :, :] = cosine_similarity(a_n[i,:,:], b_n)
res_n

array([[[ 7.03498796e-02, -7.85887539e-02, -5.92361670e-03, ...,
         -1.66807175e-01,  4.46585976e-02, -6.50324896e-02],
        [ 4.39871177e-02,  1.07464558e-02,  4.89737988e-02, ...,
         -4.69259731e-02, -1.32835405e-02,  1.04894387e-02],
        [-5.61063439e-02,  7.16783926e-02, -9.55120176e-02, ...,
         -5.52019738e-02,  3.62394527e-02, -1.08005898e-02],
        ...,
        [-7.27276579e-02, -7.93004110e-02, -2.65550315e-02, ...,
         -1.04372263e-01,  1.00031167e-01,  1.42261788e-01],
        [ 1.33971170e-01, -2.20705699e-02,  1.02428058e-02, ...,
         -3.90496552e-02, -5.89703135e-02,  7.02463835e-02],
        [-9.62561071e-02, -6.03739591e-03, -1.74768373e-01, ...,
          7.84125999e-02, -3.05945352e-02,  4.85599861e-02]],

       [[ 5.66428304e-02, -2.95396857e-02, -3.38946324e-04, ...,
          1.03453314e-02,  7.61987343e-02,  2.60740984e-02],
        [ 3.54144052e-02, -3.57257277e-02, -5.64586632e-02, ...,
         -8.89221430e-02,  2.94706896e

In [25]:
# a_n.shape
b_n.shape

(37456, 256)

In [26]:
matrix1 = np.array([[[1,1,1],[1,1,1],[1,2,1]],[[1,2,1],[1,1,1],[1,1,1]]])
matrix2 = np.array([[1,1,1],[1,2,1], [1,2,1], [1,2,1]])
print(matrix1.shape)
print(matrix2.shape)

(2, 3, 3)
(4, 3)


In [27]:
import scipy.spatial as sp
out = np.zeros((2,3,4))
for i in range(2):
    out[i, :, :] = cosine_similarity(matrix1[i], matrix2, 'cosine')

In [28]:
out.shape

(2, 3, 4)

In [29]:
out

array([[[1.        , 0.94280904, 0.94280904, 0.94280904],
        [1.        , 0.94280904, 0.94280904, 0.94280904],
        [0.94280904, 1.        , 1.        , 1.        ]],

       [[0.94280904, 1.        , 1.        , 1.        ],
        [1.        , 0.94280904, 0.94280904, 0.94280904],
        [1.        , 0.94280904, 0.94280904, 0.94280904]]])

In [30]:
import numpy as np
z = torch.randint(-1,3, size=(10,)) 
print(z)


tensor([ 1,  2, -1,  0,  0,  1,  0,  1,  2,  0])


In [31]:
x = torch.tensor([[1,4,6],[3,5,6]])
embedding = torch.randint(10, size=(10, 10))
embedding

tensor([[8, 3, 2, 5, 1, 9, 6, 8, 4, 9],
        [2, 6, 8, 1, 5, 1, 6, 5, 7, 8],
        [4, 3, 3, 5, 7, 0, 7, 1, 2, 3],
        [3, 7, 5, 1, 8, 3, 5, 7, 9, 8],
        [6, 8, 4, 1, 0, 3, 5, 6, 6, 3],
        [2, 4, 8, 4, 2, 3, 7, 6, 3, 6],
        [8, 3, 3, 8, 2, 9, 2, 5, 7, 7],
        [0, 8, 4, 3, 9, 5, 3, 9, 0, 6],
        [6, 8, 6, 0, 2, 0, 7, 5, 4, 8],
        [9, 6, 7, 7, 1, 8, 6, 4, 2, 9]])

In [32]:
embedding[x, :]

tensor([[[2, 6, 8, 1, 5, 1, 6, 5, 7, 8],
         [6, 8, 4, 1, 0, 3, 5, 6, 6, 3],
         [8, 3, 3, 8, 2, 9, 2, 5, 7, 7]],

        [[3, 7, 5, 1, 8, 3, 5, 7, 9, 8],
         [2, 4, 8, 4, 2, 3, 7, 6, 3, 6],
         [8, 3, 3, 8, 2, 9, 2, 5, 7, 7]]])

In [33]:
x = torch.randint(-1, 5, size  = (10,10))
idxs = (x == -1).nonzero()
print(x)
idxs

# new_arr_0 = x[x!=-1]
# print(idxs)
# new_arr_0

# ##### 
# x = torch.randint(-1, 5, size  = (5,10)).float()
# y = torch.randint(-1, 5, size  = (5,10)).float()
# import torch.nn.functional as F
# F.cosine_similarity(x, y, dim=1)

tensor([[ 3, -1,  3,  4,  4,  0,  3, -1,  0,  0],
        [ 2,  0,  0,  1,  1,  3, -1,  4,  4, -1],
        [ 2,  2,  4,  2,  2, -1, -1,  0,  2,  1],
        [ 1,  4,  1,  3,  2,  2, -1,  4,  1,  2],
        [-1,  2,  4,  2,  1,  3, -1,  1,  1,  3],
        [-1,  2,  2,  4,  1,  0,  3,  3,  1,  0],
        [ 1,  4,  4, -1, -1,  2,  2,  3, -1,  2],
        [ 2,  2,  4, -1, -1,  1,  0,  4,  4,  4],
        [-1,  0,  4,  3,  2,  2,  2,  0, -1,  4],
        [ 2,  2,  1,  3, -1,  2,  4,  1,  4,  4]])


tensor([[0, 1],
        [0, 7],
        [1, 6],
        [1, 9],
        [2, 5],
        [2, 6],
        [3, 6],
        [4, 0],
        [4, 6],
        [5, 0],
        [6, 3],
        [6, 4],
        [6, 8],
        [7, 3],
        [7, 4],
        [8, 0],
        [8, 8],
        [9, 4]])

In [34]:
y = torch.randint(10, size=(10, 10, 10))
y[idxs[:, 0], idxs[:, 1]].shape
y[idxs[:, 0], idxs[:, 1]]

tensor([[5, 6, 8, 6, 8, 5, 1, 1, 1, 4],
        [6, 1, 1, 7, 8, 9, 9, 7, 5, 1],
        [3, 0, 6, 8, 3, 9, 3, 2, 2, 4],
        [0, 9, 8, 1, 4, 3, 1, 0, 2, 0],
        [7, 5, 1, 4, 3, 6, 3, 2, 4, 0],
        [0, 1, 2, 8, 1, 2, 3, 1, 8, 5],
        [6, 6, 4, 6, 8, 1, 3, 1, 3, 8],
        [1, 7, 4, 2, 7, 0, 6, 1, 7, 8],
        [4, 7, 2, 5, 0, 2, 2, 1, 5, 2],
        [7, 5, 9, 3, 4, 5, 7, 1, 2, 5],
        [0, 9, 4, 1, 6, 0, 9, 9, 5, 6],
        [7, 4, 2, 8, 8, 1, 8, 3, 6, 0],
        [5, 0, 3, 4, 8, 1, 9, 1, 7, 5],
        [3, 9, 2, 4, 7, 1, 2, 9, 8, 6],
        [2, 7, 1, 9, 2, 3, 1, 9, 9, 3],
        [3, 2, 9, 3, 6, 2, 3, 3, 6, 6],
        [8, 7, 8, 9, 7, 5, 0, 3, 7, 4],
        [4, 2, 4, 4, 3, 4, 7, 4, 6, 7]])

In [35]:
import torch.nn as nn
# 2D loss example (used, for example, with image inputs)
N, C = 5, 4
loss = nn.NLLLoss()
# input is of size N x C x height x width
data = torch.randn(N, 16, 10, 10)
conv = nn.Conv2d(16, C, (3, 3))
m = nn.LogSoftmax(dim=1)
# each element in target has to have 0 <= value < C
target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
output = loss(m(conv(data)), target)
output.backward()

In [36]:
print(target.shape)
m(conv(data)).shape

torch.Size([5, 8, 8])


torch.Size([5, 4, 8, 8])

In [37]:
# attention_mask torch.Size([27,256])
# mention_pred torch.Size([27, 256, 3])
# entity_pred torch.Size([27, 256, 256])
# bio_tags torch.Size([27, 256])
# entity_ids torch.Size([27, 256])

# unpadded_men_pred torch.Size([25, 3, 256, 256])
# unpadded_bio_tags torch.Size([25, 256, 256])    


attention_mask = torch.randint(0,2,size=(3,5))
# mention_pred = torch.randn(3,5)
losses = torch.randn(3,5)
# bio_tags = torch.randn(3,5)
print(attention_mask)
print(losses)
# print(bio_tags)
num_unpadded = torch.sum(attention_mask, dim=1)
print("num_unpadded",num_unpadded)
x = torch.where(attention_mask == 1, losses, torch.tensor(0.).float())
print(torch.sum(x, dim=1))
torch.sum(torch.sum(x, dim=1)/num_unpadded)/len(x)
# print(torch.sum(x, dim=1))
# torch.sum(torch.sum(x, dim=1) / num_unpadded)

#     unpadded_men_pred = mention_pred[attention_mask,:].permute(0,3,1,2).contiguous()
# mention_pred[attention_mask,:]

    
    
    

tensor([[0, 0, 0, 0, 0],
        [1, 1, 1, 0, 1],
        [1, 1, 1, 1, 1]])
tensor([[ 1.6735,  0.2509,  0.8096,  0.5903,  0.3839],
        [ 0.0505,  1.0096,  0.4159, -0.8981, -2.6978],
        [ 1.3066, -0.6792,  0.1163,  0.2207, -0.2862]])
num_unpadded tensor([0, 4, 5])
tensor([ 0.0000, -1.2217,  0.6781])


tensor(nan)

In [38]:
import numpy as np
# Construct a random 50x50 RGB image    
image = np.random.random((50, 50, 3))

# Construct mask according to some condition;
# in this case, select all pixels with a red value > 0.3
mask = image[..., 0] > 0.3
mask.shape

(50, 50)

In [39]:
def compute_metrics(mention_preds, entity_preds, bio_tags, entity_ids, pretrained_entity_embedding):
    '''Computes both macro and micro confusion matrices given torch tensor outputs from our model
    '''
    # mention_pred Size([27, 256, 3])
    # entity_pred torch.Size([27, 256, 256])

    nb_classes = 2
    micro_confusion_matrix = torch.zeros(nb_classes, nb_classes)
    # micro_confusion_matrix [[TP, FN], [FP, TN]]
#     for t, p in zip(classes.view(-1), preds.view(-1)):
#         confusion_matrix[t.long(), p.long()] += 1
        
    macro_confusion_matrix = []        
    
    # first we need to find the max of the BIO tagging we have
    pred_bio_tags = torch.argmax(mention_preds, dim=2)
    # pred_bio_tags torch.Size([28, 256])
    
    # TODO: Remove after testing
    # pred_bio_tags = torch.randint(0,2, size=(28,256))

    # create two lists of spans
    pred_span_list = get_bi_spans(pred_bio_tags)
    target_span_list = get_bi_spans(bio_tags)    
    
    for batch_idx, (pred_spans, target_spans) in enumerate(zip(pred_span_list, target_span_list)):
        tp_matches = np.zeros(target_spans.shape[0], dtype=bool)
        
        batch_confusion_matrix = torch.zeros(nb_classes, nb_classes)

        # do the STRONG MATCHING here
        for span_idx in range(pred_spans.shape[0]):
            if torch.all(torch.tensor(pred_spans[span_idx, :] == target_spans), dim=1).any():
                seq_idx = pred_spans[span_idx, 0]
                # we use the start span idx to get the entity embedding and compare it to the pretrained embeds
#                 print(entity_preds[batch_idx, seq_idx, :].shape)
                cos_sim = sim_matrix(entity_preds[batch_idx, seq_idx, :].unsqueeze(0).unsqueeze(0), 
                                     pretrained_entity_embedding)
                # this should return something of size [1, 1, sim_vals]
                pred_ent_embed_id = torch.argmax(torch.squeeze(cos_sim))
                target_ent_embed_id = entity_ids[batch_idx, seq_idx]
                # just to test: 
                # TODO: REMOVE
                if batch_idx == 1 and seq_idx == 0:
                    target_ent_embed_id = 2
                # TODO: ENDREMOVE
                if pred_ent_embed_id == target_ent_embed_id:
                    # TP
                    micro_confusion_matrix[0,0] += 1
                    batch_confusion_matrix[0,0] += 1 
                    tp_matches[span_idx] = True
                else:
                    # FP
                    micro_confusion_matrix[1,0] += 1
                    batch_confusion_matrix[1,0] += 1
            else:
                # FP 
                micro_confusion_matrix[1,0] += 1
                batch_confusion_matrix[1,0] += 1

        print(tp_matches)        
        micro_confusion_matrix[0,1] += np.sum(~tp_matches)
        batch_confusion_matrix[0,1] += np.sum(~tp_matches)

        macro_confusion_matrix.append(batch_confusion_matrix)

        print(micro_confusion_matrix)


In [40]:
batch_size = 2
ent_vocab_size = 8
import torch.nn.functional as F

metric_list = ""
mention_preds = torch.tensor(np.random.rand(batch_size, 5, 3))
# mention_preds = F.softmax(torch.tensor(mention_preds), dim=2)
print(mention_preds)
entity_preds = torch.tensor(np.random.rand(batch_size, 5, 5))
bio_tags = torch.tensor(np.random.randint(0,3, size=(batch_size, 5)))
print(bio_tags.shape)
print(bio_tags)
entity_ids = torch.tensor(np.random.randint(ent_vocab_size, size=(batch_size, 5)))
print(entity_ids)
pretrained_entity_embedding = torch.tensor(np.random.rand(ent_vocab_size, 5))
print(pretrained_entity_embedding.shape)
# mention_pred Size([27, 256, 3])
# entity_pred torch.Size([27, 256, 256])
# bio_tags torch.Size([27, 256])
# entity_ids torch.Size([27, 256])

tensor([[[0.0280, 0.4551, 0.4349],
         [0.7973, 0.9785, 0.9968],
         [0.7352, 0.4448, 0.7065],
         [0.1659, 0.4560, 0.9020],
         [0.4978, 0.8070, 0.2581]],

        [[0.2938, 0.4054, 0.2894],
         [0.7827, 0.5111, 0.3319],
         [0.8681, 0.3724, 0.9933],
         [0.0949, 0.5764, 0.4387],
         [0.0866, 0.0764, 0.4486]]], dtype=torch.float64)
torch.Size([2, 5])
tensor([[1, 0, 2, 1, 1],
        [2, 0, 0, 1, 0]])
tensor([[5, 1, 2, 2, 4],
        [6, 1, 5, 7, 0]])
torch.Size([8, 5])


In [41]:
# bio_tags[1,0] = 2
# bio_tags[1,4] = 2
entity_ids[0,2] = -1
entity_ids[1,1] = -1
entity_ids[1, 3] = -1
entity_ids[1,4] = -1
print(bio_tags)
entity_ids

tensor([[1, 0, 2, 1, 1],
        [2, 0, 0, 1, 0]])


tensor([[ 5,  1, -1,  2,  4],
        [ 6, -1,  5, -1, -1]])

In [42]:
compute_metrics(metric_list, mention_preds, entity_preds, bio_tags, entity_ids, pretrained_entity_embedding)

TypeError: compute_metrics() takes 5 positional arguments but 6 were given

In [None]:
x = torch.tensor([[0, 0, 1, 0, 0],[0, 2, 0, 2, 2]]) ; print(x)
y = torch.tensor([[1, 0, 2, 2, 1], [0, 2, 0, 1, 2]]) ; y

In [None]:
x_span = get_bi_spans(x)
y_span = get_bi_spans(y)
print(x_span)
print(y_span)

In [None]:
print(y_span[0][0,:].shape)
print(x_span[0].shape)
# (y_span[0][0,:] ==  x_span[0]).shape
torch.all(torch.tensor(y_span[0][0,:] == x_span[0]), dim=1).any()

In [None]:
if(np.array([]).size <= 0):
    print("X")

In [None]:
x = torch.tensor([-0.3616, -0.6805])
torch.sum(x)/2