In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# import requi9red module
import sys
  
# append the path of the
# parent directory
sys.path.append("..")

In [3]:
import json
import torch 
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.nn import CrossEntropyLoss, NLLLoss
import torch.nn.functional as F 
from torch import autograd

import numpy as np 

from transformers.models.luke.configuration_luke import LukeConfig
from transformers import AutoTokenizer, AutoModel, AutoConfig

from entitylinker.lukemodel import LukeModel 
from entitylinker.model_utils import *

from entitylinker.dataset import MedMentionsDataset, Collater
from entitylinker.model import EntityLinker 

In [4]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

if torch.cuda.is_available():
    
    print(torch.cuda.current_device())
    print(torch.cuda.device(0))
    print(torch.cuda.device_count())
    print(torch.cuda.get_device_name(0))


0
<torch.cuda.device object at 0x2b5a2572cee0>
1
Quadro RTX 5000


In [5]:
input_data_file = "/home/vs428/project/MedMentions/full/data/corpus_pubtator.txt"
train_pmids_file = "/home/vs428/project/MedMentions/full/data/corpus_pubtator_pmids_trng.txt"
test_pmids_file = "/home/vs428/project/MedMentions/full/data/corpus_pubtator_pmids_test.txt"

entity_vocab_file = "/home/vs428/project/MedMentions/full/pretraining3/entity_vocab.jsonl"
entity_embedding_model = "/home/vs428/project/MedMentions/full/pretraining3//model_epoch20.bin"
entity_embedding_metadata = "/home/vs428/project/MedMentions/full/pretraining3//metadata.json"

MAX_LENGTH = 128
MODEL_NAME = 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract'

In [6]:
# Create Tokenizer and DataLoader
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

collater = Collater(tokenizer, max_length = MAX_LENGTH)
train_dataset = MedMentionsDataset(input_data_file, train_pmids_file, tokenizer, entity_vocab_file, 
                             max_length = MAX_LENGTH, stride=0, first_token_ent_only=True)
test_dataset = MedMentionsDataset(input_data_file, test_pmids_file, tokenizer, entity_vocab_file, 
                             max_length = MAX_LENGTH, stride=0, first_token_ent_only=True)

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collater)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True, collate_fn=collater)
# tokens, bio_tags, entity_ids = next(iter(train_dataloader))


In [7]:

# define optimizer, model, loss etc. 


# with open(entity_embedding_metadata) as f:
#     luke_metadata = json.loads(f.read())

# luke_config = luke_metadata['model_config']
# bert_config = AutoConfig.from_pretrained(MODEL_NAME)

# config = LukeConfig(
#     entity_vocab_size=luke_config['entity_vocab_size'],
#     bert_model_name=luke_config['bert_model_name'],
#     entity_emb_size=luke_config['entity_emb_size'],
#     **bert_config.to_dict(),
# )

# luke_model = LukeModel(config)
# pretrained_entity_embeddings = luke_model.load_state_dict(torch.load(entity_embedding_model, map_location=torch.device('cpu')))
# print(pretrained_entity_embeddings)


In [8]:
# Retrieve pretrained entity embeddings

model_archive = ModelArchive.load(entity_embedding_model)

model_archive.tokenizer
model_archive.entity_vocab
model_archive.bert_model_name
model_archive.config
model_archive.max_mention_length
pretrained_entity_embeddings = model_archive.state_dict['entity_embeddings.entity_embeddings.weight']
pretrained_entity_embeddings = pretrained_entity_embeddings.to(device)

In [9]:
def sim_matrix(a, b, eps=1e-8):
    """
    works on a 3D and 2D array
    """
    a_n, b_n = torch.linalg.norm(a, dim=-1)[:, :, None], torch.linalg.norm(b, dim=-1)[:, None]
    a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
    b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
    sim_mt = torch.matmul(a_norm, b_norm.transpose(0, 1))        
    return sim_mt        


In [10]:
# define loss function
def mention_entity_loss(mention_pred, entity_pred, bio_tags, entity_ids, 
                        attention_mask, pretrained_entity_embedding, device, lm=0.1):
    # attention_mask torch.Size([27,256])
    # mention_pred torch.Size([27, 256, 3])
    # entity_pred torch.Size([27, 256, 256])
    # bio_tags torch.Size([27, 256])
    # entity_ids torch.Size([27, 256])
    
    # unpadded_men_pred torch.Size([25, 3, 256, 256])
    # unpadded_bio_tags torch.Size([25, 256, 256])    
    
    ### MENTION LOSS

    # first compute all loss
    mention_loss = F.nll_loss(mention_pred.permute(0,2,1).contiguous(), bio_tags, reduction="none")

    # get only the unpadded losses: batch size
    num_unpadded = torch.sum(attention_mask, dim=1)
    masked_mention_loss = torch.where(attention_mask == 1, mention_loss, torch.tensor(0.).float().to(device))
    
    # compute average of the masked loss
    avg_mention_loss = torch.sum(masked_mention_loss, dim=1) / num_unpadded
    avg_mention_loss = torch.sum(avg_mention_loss)/len(avg_mention_loss)

#     print("MENTION LOSS: ", avg_mention_loss)

    
    ### ENTITY LOSS
#     cos_sim = sim_matrix(entity_pred, pretrained_entity_embedding)
#     print("cos_sim", cos_sim.shape)
    # cos_sim torch.Size([27, 256, 34727])
    
    # get all the IDs in entity_ids that aren't -1
    nonzero_ent_ids = entity_ids[entity_ids!=-1]
    # get the true entity embeddings
    # shape: [nonzero_ent_ids, embed_dim]
    true_ent_embeddings = pretrained_entity_embedding[nonzero_ent_ids, :]
    
    # get non-neg1 idxs 
    ent_idxs = (entity_ids != -1).nonzero()
    real_entity_pred = entity_pred[ent_idxs[:, 0], ent_idxs[:, 1], :]    
#     print("true_ent_embeddings", true_ent_embeddings.shape)
#     print(real_entity_pred.shape)
    entity_loss = torch.sum(1 - F.cosine_similarity(true_ent_embeddings, real_entity_pred, dim=1))/true_ent_embeddings.shape[1]
#     print("ENTITY LOSS", entity_loss)
#     max_ents = torch.max(cos_sim, dim=2)
    
    # target_ent_embeddings = pretrained_entity_embedding[entity_ids, :]
    # target_ent_embeddings torch.Size([27, 256, 256])
    # max_ent_idxs = torch.argmax(cos_sim, dim=2)
#     max_ents = self.entity_embedding[max_ent_idxs, :]
    
    
#     entity_pred[ent_idxs]
#     entity_loss = torch.mean((output - target)**2)
    
    

    loss = (lm * avg_mention_loss) + ((1-lm) * entity_loss)
    return loss

In [11]:
# define model, etc. 
model = EntityLinker(model_name='microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract',
                     pretrained_entity_embeddings=pretrained_entity_embeddings )

model.to(device)


optimizer = optim.Adam(model.parameters(), lr=0.00002)

Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [12]:
# for name, param in model.named_parameters():
#     if param.requires_grad:
#     print(name, param.requires_grad)

In [13]:
def get_bi_spans(batch_tags):
    """Returns a list of N x 2 numpy arrays"""
    batch_tags = batch_tags.cpu().detach().numpy()
    bi_spans = []
    for batch_idx in range(batch_tags.shape[0]):
        b_tags = (batch_tags[batch_idx, :] == 0).astype(int)
        i_tags = (batch_tags[batch_idx, :] == 1).astype(int)
        
        # if there aren't any beginning tags, just return an empty array
        if not b_tags.any():
            bi_spans.append(np.array([]))
            continue
        
        # get all B tag locations
        b_idx, = b_tags.nonzero() 

        # get idxs for I tags
        d = np.diff(i_tags)
        i_idx, = d.nonzero()

        # We need to start things after the change in "condition". Therefore, 
        # we'll shift the index by 1 to the right.
        i_idx += 1

        # add end idxs for all B tags (just interweave +1 of b_idxs)
        b_idx = np.vstack((b_idx,b_idx+1)).reshape((-1,),order='F')    

        if i_tags[0]:
            # If the start of condition is True prepend a 0
            i_idx = np.r_[0, i_idx]

        if i_tags[-1]:
            # If the end of condition is True, append the length of the array
            i_idx = np.r_[i_idx, i_tags.size] # Edit

        # reshape to idxs
        b_idx.shape = (-1,2)
        i_idx.shape = (-1,2)    

        bi_idx = []
        # combine the b and i tags
        for start, stop in b_idx:
            # get idx where I tags start for each B tag, if exists
            i, = np.where(i_idx[:, 0] == stop) 
            if i.size > 0:
                bi_idx.append([start, int(i_idx[i, 1])])
            else:
                bi_idx.append([start, stop])

        bi_spans.append(np.array(bi_idx))

    assert batch_tags.shape[0] == len(bi_spans), f"{batch_tags.shape[0]},{len(bi_spans)},{b_idx}"
    return bi_spans


In [14]:
def compute_conf_matrices(mention_preds, entity_preds, bio_tags, entity_ids, pretrained_entity_embedding):
    '''Computes both macro and micro confusion matrices given torch tensor outputs from our model
    '''
    # mention_pred Size([27, 256, 3])
    # entity_pred torch.Size([27, 256, 256])

    nb_classes = 2
    micro_confusion_matrix = torch.zeros(nb_classes, nb_classes)
    # micro_confusion_matrix [[TP, FN], [FP, TN]]
        
    macro_confusion_matrix = []        
    
    # first we need to find the max of the BIO tagging we have
    pred_bio_tags = torch.argmax(mention_preds, dim=2)
    # pred_bio_tags torch.Size([28, 256])
    
    # create two lists of spans
    pred_span_list = get_bi_spans(pred_bio_tags)
    target_span_list = get_bi_spans(bio_tags)    

    for batch_idx, (pred_spans, target_spans) in enumerate(zip(pred_span_list, target_span_list)):
        tp_matches = 0
        
        batch_confusion_matrix = torch.zeros(nb_classes, nb_classes)

        # do the STRONG MATCHING here

        # first check if there are any predicted spans, if not, we just add a bunch of FNs
        if pred_spans.size <= 0 or pred_spans.shape[0] == 0:
            micro_confusion_matrix[0,1] += target_spans.shape[0]
            batch_confusion_matrix[0,1] += target_spans.shape[0]
            
        # conversely, if the target spans are empty, we add a bunch of FPs instead
        elif target_spans.size <= 0:
            micro_confusion_matrix[1,0] += pred_spans.shape[0]
            batch_confusion_matrix[1,0] += pred_spans.shape[0]            
            
        else:
            for span_idx in range(pred_spans.shape[0]):
                
                if torch.all(torch.tensor(pred_spans[span_idx, :] == target_spans), dim=1).any():
                    seq_idx = pred_spans[span_idx, 0]
                    # we use the start span idx to get the entity embedding and compare it to the pretrained embeds

                    cos_sim = sim_matrix(entity_preds[batch_idx, seq_idx, :].unsqueeze(0).unsqueeze(0), 
                                        pretrained_entity_embedding)
                    # this should return something of size [1, 1, sim_vals]
                    pred_ent_embed_id = torch.argmax(torch.squeeze(cos_sim))
                    target_ent_embed_id = entity_ids[batch_idx, seq_idx]
                    # just to test: 
                    # TODO: REMOVE
    #                 if batch_idx == 1 and seq_idx == 0:
    #                     target_ent_embed_id = 2
                    # TODO: ENDREMOVE
                    if pred_ent_embed_id == target_ent_embed_id:
                        # TP
                        micro_confusion_matrix[0,0] += 1
                        batch_confusion_matrix[0,0] += 1 
                        tp_matches += 1
                    else:
                        # FP
                        micro_confusion_matrix[1,0] += 1
                        batch_confusion_matrix[1,0] += 1
                else:
                    # FP 
                    micro_confusion_matrix[1,0] += 1
                    batch_confusion_matrix[1,0] += 1
        
        # all FNs 
        micro_confusion_matrix[0,1] += target_spans.shape[0] - tp_matches
        batch_confusion_matrix[0,1] += target_spans.shape[0] - tp_matches
        
        # we don't do TNs since we don't need them for P,R,F1
        
        # add on the batch confusion matrix to the batch
        macro_confusion_matrix.append(batch_confusion_matrix)

    return micro_confusion_matrix, macro_confusion_matrix


In [15]:
def compute_metrics(metrics, micro_conf, macro_confs):
    '''Given both the micro and macro confusion matrices, we compute a given set of metrics
    
    Matrix orientation is: 
    |  TP  |  FN |
    |-------------
    |  FP  |  TN |
    '''
    micro_metrics = {}
    macro_metrics = {}
    
    if "precision" in metrics:
        micro_metrics['precision'] = micro_conf[0,0] / (micro_conf[0,0] + micro_conf[1,0])
        macro_metrics['precision'] = torch.mean(torch.tensor([batch_conf[0,0] / (batch_conf[0,0] + batch_conf[1,0]) for batch_conf in macro_confs]))
    if "recall" in metrics:
        micro_metrics['recall'] = micro_conf[0,0] / (micro_conf[0,0] + micro_conf[0,1])
        macro_metrics['recall'] = torch.mean(torch.tensor([batch_conf[0,0] / (batch_conf[0,0] + batch_conf[0,1]) for batch_conf in macro_confs]))
    if "f1" in metrics:
        micro_metrics['f1'] = micro_conf[0,0] / (micro_conf[0,0] + ((1/2) * (micro_conf[0,1] + micro_conf[1,0])))
        macro_metrics['f1'] = torch.mean(torch.tensor([batch_conf[0,0] / (batch_conf[0,0] + ((1/2) * (batch_conf[0,1] + batch_conf[1,0]))) for batch_conf in macro_confs]))
        
    return micro_metrics, macro_metrics


In [16]:
# Train!!
def train(model, train_loader, optimizer, criterion):

    model.train()
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        # get the inputs; data is a list of [inputs, labels]
        tokenized_text, bio_tags, entity_ids = data
        bio_tags = bio_tags.to(device)
        entity_ids = entity_ids.to(device)
#         print("bio_tags", bio_tags.shape)
#         print("entity_ids", entity_ids.shape)
        # zero the parameter gradients
        optimizer.zero_grad()
#         with autograd.detect_anomaly():

        # forward + backward + optimize
        input_ids = tokenized_text['input_ids'].to(device)
        token_type_ids = tokenized_text['token_type_ids'].to(device)
        attention_mask = tokenized_text['attention_mask'].to(device)
                          
        mention_pred, entity_pred = model(input_ids,token_type_ids,attention_mask)
#         print("ent pred", entity_pred.shape)
        loss = criterion(mention_pred, entity_pred, bio_tags, entity_ids, attention_mask, 
                         pretrained_entity_embeddings, device)
#         print(loss)
        loss.backward()
        optimizer.step()
        
        # print statistics
        running_loss += loss.item()
        if i % 5 == 4:    # print every 2000 mini-batches
            print('[%d, %5d] train loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 5))
            running_loss = 0.0
#         if i == 2:
#             break
    print('Finished Training')

In [22]:
def evaluate(model, test_loader, optimizer, criterion):
    
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        running_loss = 0.0
        
        for i, data in enumerate(test_loader, 0):
            
            tokenized_text, bio_tags, entity_ids = data
            bio_tags = bio_tags.to(device)
            entity_ids = entity_ids.to(device)
            
    #         with autograd.detect_anomaly():

            # forward + backward + optimize
            input_ids = tokenized_text['input_ids'].to(device)
            token_type_ids = tokenized_text['token_type_ids'].to(device)
            attention_mask = tokenized_text['attention_mask'].to(device)

            mention_pred, entity_pred = model(input_ids,token_type_ids,attention_mask)
    #         print("ent pred", entity_pred.shape)
            loss = criterion(mention_pred, entity_pred, bio_tags, entity_ids, 
                            attention_mask, pretrained_entity_embeddings, device)
    #         print(loss)


            micro_conf_mat, macro_conf_mat = compute_conf_matrices(mention_pred, entity_pred, bio_tags, entity_ids, pretrained_entity_embeddings)
            micro_metrics, macro_metrics = compute_metrics(["precision", "recall","f1"], micro_conf_mat, macro_conf_mat)
            print('-' * 89)
            print(f"Micro-Precision: {micro_metrics['precision']}, Macro-Precision: {macro_metrics['precision']}\n")
            print(f"Micro-Recall: {micro_metrics['recall']}, Macro-Recall: {macro_metrics['recall']}\n")
            print(f"Micro-F1: {micro_metrics['f1']}, Macro-F1: {macro_metrics['f1']}\n")    
            print('-' * 89)


            running_loss += loss.item()
            total_loss += loss.item()
            if i % 5 == 4:    # print every 2000 mini-batches
                print('[%d, %5d] test loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 5))
                running_loss = 0.0
        # whatever i is at is the len of test_loader
        return total_loss / i

In [24]:
num_epochs = 5
import time
import math 

# try:
for epoch in range(num_epochs):  # loop over the dataset multiple times
    epoch_start_time = time.time()
    train(model, train_dataloader, optimizer, mention_entity_loss)
    val_loss = evaluate(model, test_dataloader, optimizer, mention_entity_loss)
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | valid ppl {:8.2f}'.format(epoch, 
                                      (time.time() - epoch_start_time),
                                       val_loss, math.exp(val_loss)))
    print('-' * 89)        
# except Exception as e:
#     print("An exception occurred", e) 
    

[1,     5] loss: 6.795
[1,    10] loss: 6.572
[1,    15] loss: 6.732
[1,    20] loss: 6.706
[1,    25] loss: 6.600
[1,    30] loss: 6.564
[1,    35] loss: 6.701
[1,    40] loss: 6.623
[1,    45] loss: 6.565
[1,    50] loss: 6.784
[1,    55] loss: 6.433
[1,    60] loss: 6.522
[1,    65] loss: 6.760
[1,    70] loss: 6.644
[1,    75] loss: 6.451
[1,    80] loss: 6.339
Finished Training
-----------------------------------------------------------------------------------------
Micro-Precision: 0.2070000022649765, Macro-Precision: 0.19737441837787628

Micro-Recall: 0.2449704110622406, Macro-Recall: 0.22900573909282684

Micro-F1: 0.22439023852348328, Macro-F1: 0.20945365726947784

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.1898948848247528, Macro-Precision: 0.19537004828453064

Micro-Recall: 0.22125642001628876, Macro-Recall: nan

Micro-F1:

-----------------------------------------------------------------------------------------
Micro-Precision: 0.21802857518196106, Macro-Precision: 0.22363512217998505

Micro-Recall: 0.23962947726249695, Macro-Recall: nan

Micro-F1: 0.22831925749778748, Macro-F1: 0.2247065305709839

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.2053571492433548, Macro-Precision: 0.2081628292798996

Micro-Recall: 0.2524698078632355, Macro-Recall: nan

Micro-F1: 0.22648940980434418, Macro-F1: 0.21811993420124054

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.19710859656333923, Macro-Precision: 0.19016139209270477

Micro-Recall: 0.23418517410755157, Macro-Recall: 0.22046706080436707

Micro-F1: 0.214053228497

-----------------------------------------------------------------------------------------
Micro-Precision: 0.23054379224777222, Macro-Precision: 0.2360098659992218

Micro-Recall: 0.29075974225997925, Macro-Recall: nan

Micro-F1: 0.25717398524284363, Macro-F1: 0.24991273880004883

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.24352671205997467, Macro-Precision: 0.23651044070720673

Micro-Recall: 0.2991143465042114, Macro-Recall: nan

Micro-F1: 0.268473356962204, Macro-F1: 0.25186842679977417

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.257432222366333, Macro-Precision: 0.24168016016483307

Micro-Recall: 0.30926215648651123, Macro-Recall: nan

Micro-F1: 0.2809770107269287, Macro-F1: 0.

-----------------------------------------------------------------------------------------
Micro-Precision: 0.27583643794059753, Macro-Precision: nan

Micro-Recall: 0.3206568658351898, Macro-Recall: nan

Micro-F1: 0.2965627610683441, Macro-F1: nan

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.2989655137062073, Macro-Precision: nan

Micro-Recall: 0.338143527507782, Macro-Recall: 0.3173072934150696

Micro-F1: 0.31734994053840637, Macro-F1: 0.29473164677619934

-----------------------------------------------------------------------------------------
[3,    10] train loss: 6.137
-----------------------------------------------------------------------------------------
Micro-Precision: 0.23402909934520721, Macro-Precision: nan

Micro-Recall: 0.28582465648651123, Macro-Recall: nan

Micro-F1: 0.2573465406894684, Macro-F1: nan

----------------

-----------------------------------------------------------------------------------------
Micro-Precision: 0.2523461878299713, Macro-Precision: 0.25624117255210876

Micro-Recall: 0.3058129847049713, Macro-Recall: nan

Micro-F1: 0.2765187621116638, Macro-F1: 0.26488667726516724

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.2629427909851074, Macro-Precision: 0.2652309536933899

Micro-Recall: 0.30720254778862, Macro-Recall: nan

Micro-F1: 0.2833547294139862, Macro-F1: 0.27405864000320435

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.2319551557302475, Macro-Precision: 0.24153731763362885

Micro-Recall: 0.27687159180641174, Macro-Recall: nan

Micro-F1: 0.25243088603019714, Macro-F1: 0.246

-----------------------------------------------------------------------------------------
Micro-Precision: 0.22533711791038513, Macro-Precision: 0.2319248765707016

Micro-Recall: 0.2626137435436249, Macro-Recall: nan

Micro-F1: 0.24255156517028809, Macro-F1: 0.2389853298664093

-----------------------------------------------------------------------------------------
[4,    25] train loss: 6.264
-----------------------------------------------------------------------------------------
Micro-Precision: 0.2652428448200226, Macro-Precision: 0.26159578561782837

Micro-Recall: 0.3087409734725952, Macro-Recall: nan

Micro-F1: 0.2853437066078186, Macro-F1: 0.2680239677429199

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.2682165205478668, Macro-Precision: 0.2577258348464966

Micro-Recall: 0.3165438175201416, Macro-Recall: nan

Micro-F1: 0.29038

-----------------------------------------------------------------------------------------
Micro-Precision: 0.24384485185146332, Macro-Precision: 0.23359854519367218

Micro-Recall: 0.3017529249191284, Macro-Recall: nan

Micro-F1: 0.2697257995605469, Macro-F1: 0.2430112659931183

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.29777488112449646, Macro-Precision: nan

Micro-Recall: 0.3276917636394501, Macro-Recall: nan

Micro-F1: 0.3120178282260895, Macro-F1: nan

-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
Micro-Precision: 0.2517586052417755, Macro-Precision: 0.23856154084205627

Micro-Recall: 0.30276045203208923, Macro-Recall: nan

Micro-F1: 0.27491408586502075, Macro-F1: 0.25013214349746704

------------

In [None]:
x = torch.randn(2,3,5)
m = torch.nn.Softmax(dim=2)
print(x.shape)
m(x)
# torch.nn.Softmax(x, dim=0)

In [None]:
import torch.nn as nn
input1 = torch.randn(100, 128)
input2 = torch.randn(100, 128)
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
cos(input1, input2).shape

In [None]:
# in1 shape [batch_size, seq_len, embed_size]
# in1 shape [16, 256, 256]
# in2 shape [1, entity_vocab, embed_size]
# in2 shape [37456, 256]
import torch
from scipy import spatial
import numpy as np

In [None]:
a = torch.randn(16, 256, 256)
# a = torch.randn(3, 5)
b = torch.randn(37456, 256) # different row number, for the fun

In [None]:
%%timeit
# Given that cos_sim(u, v) = dot(u, v) / (norm(u) * norm(v))
#                          = dot(u / norm(u), v / norm(v))
# We fist normalize the rows, before computing their dot products via transposition:
a_norm = a / a.norm(dim=2)[:, :, None]
# a_norm = a / a.norm(dim=1)[:, None]
b_norm = b / b.norm(dim=1)[:, None]
res = torch.matmul(a_norm, b_norm.transpose(0,1))
# print(res.shape)
print(res)
#  0.9978 -0.9986 -0.9985
# -0.8629  0.9172  0.9172

In [None]:
# %%timeit
# # -------
# # Let's verify with numpy/scipy if our computations are correct:
# a_n = a.numpy()
# b_n = b.numpy()
# res_n = np.zeros((3,5,50))

# for i in range(3):
#     for j in range(5):
#         for k in range(50):
#             # cos_sim(u, v) = 1 - cos_dist(u, v)
#             res_n[i, j, k] = 1 - spatial.distance.cosine(a_n[i, j], b_n[k])
# print(res_n)
# # [[ 0.9978022  -0.99855876 -0.99854881]
# #  [-0.86285472  0.91716063  0.9172349 ]]


from sklearn.metrics.pairwise import cosine_similarity
a_n = a.numpy()
b_n = b.numpy()

res_n = np.zeros((16,256, 37456))
for i in range(3):
    res_n[i, :, :] = cosine_similarity(a_n[i,:,:], b_n)
res_n

In [None]:
# a_n.shape
b_n.shape

In [None]:
matrix1 = np.array([[[1,1,1],[1,1,1],[1,2,1]],[[1,2,1],[1,1,1],[1,1,1]]])
matrix2 = np.array([[1,1,1],[1,2,1], [1,2,1], [1,2,1]])
print(matrix1.shape)
print(matrix2.shape)

In [None]:
import scipy.spatial as sp
out = np.zeros((2,3,4))
for i in range(2):
    out[i, :, :] = cosine_similarity(matrix1[i], matrix2, 'cosine')

In [None]:
out.shape

In [None]:
out

In [None]:
import numpy as np
z = torch.randint(-1,3, size=(10,)) 
print(z)


In [None]:
x = torch.tensor([[1,4,6],[3,5,6]])
embedding = torch.randint(10, size=(10, 10))
embedding

In [None]:
embedding[x, :]

In [None]:
x = torch.randint(-1, 5, size  = (10,10))
idxs = (x == -1).nonzero()
print(x)
idxs

# new_arr_0 = x[x!=-1]
# print(idxs)
# new_arr_0

# ##### 
# x = torch.randint(-1, 5, size  = (5,10)).float()
# y = torch.randint(-1, 5, size  = (5,10)).float()
# import torch.nn.functional as F
# F.cosine_similarity(x, y, dim=1)

In [None]:
y = torch.randint(10, size=(10, 10, 10))
y[idxs[:, 0], idxs[:, 1]].shape
y[idxs[:, 0], idxs[:, 1]]

In [None]:
import torch.nn as nn
# 2D loss example (used, for example, with image inputs)
N, C = 5, 4
loss = nn.NLLLoss()
# input is of size N x C x height x width
data = torch.randn(N, 16, 10, 10)
conv = nn.Conv2d(16, C, (3, 3))
m = nn.LogSoftmax(dim=1)
# each element in target has to have 0 <= value < C
target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
output = loss(m(conv(data)), target)
output.backward()

In [None]:
print(target.shape)
m(conv(data)).shape

In [None]:
# attention_mask torch.Size([27,256])
# mention_pred torch.Size([27, 256, 3])
# entity_pred torch.Size([27, 256, 256])
# bio_tags torch.Size([27, 256])
# entity_ids torch.Size([27, 256])

# unpadded_men_pred torch.Size([25, 3, 256, 256])
# unpadded_bio_tags torch.Size([25, 256, 256])    


attention_mask = torch.randint(0,2,size=(3,5))
# mention_pred = torch.randn(3,5)
losses = torch.randn(3,5)
# bio_tags = torch.randn(3,5)
print(attention_mask)
print(losses)
# print(bio_tags)
num_unpadded = torch.sum(attention_mask, dim=1)
print("num_unpadded",num_unpadded)
x = torch.where(attention_mask == 1, losses, torch.tensor(0.).float())
print(torch.sum(x, dim=1))
torch.sum(torch.sum(x, dim=1)/num_unpadded)/len(x)
# print(torch.sum(x, dim=1))
# torch.sum(torch.sum(x, dim=1) / num_unpadded)

#     unpadded_men_pred = mention_pred[attention_mask,:].permute(0,3,1,2).contiguous()
# mention_pred[attention_mask,:]

    
    
    

In [None]:
import numpy as np
# Construct a random 50x50 RGB image    
image = np.random.random((50, 50, 3))

# Construct mask according to some condition;
# in this case, select all pixels with a red value > 0.3
mask = image[..., 0] > 0.3
mask.shape

In [None]:
def compute_metrics(mention_preds, entity_preds, bio_tags, entity_ids, pretrained_entity_embedding):
    '''Computes both macro and micro confusion matrices given torch tensor outputs from our model
    '''
    # mention_pred Size([27, 256, 3])
    # entity_pred torch.Size([27, 256, 256])

    nb_classes = 2
    micro_confusion_matrix = torch.zeros(nb_classes, nb_classes)
    # micro_confusion_matrix [[TP, FN], [FP, TN]]
#     for t, p in zip(classes.view(-1), preds.view(-1)):
#         confusion_matrix[t.long(), p.long()] += 1
        
    macro_confusion_matrix = []        
    
    # first we need to find the max of the BIO tagging we have
    pred_bio_tags = torch.argmax(mention_preds, dim=2)
    # pred_bio_tags torch.Size([28, 256])
    
    # TODO: Remove after testing
    # pred_bio_tags = torch.randint(0,2, size=(28,256))

    # create two lists of spans
    pred_span_list = get_bi_spans(pred_bio_tags)
    target_span_list = get_bi_spans(bio_tags)    
    
    for batch_idx, (pred_spans, target_spans) in enumerate(zip(pred_span_list, target_span_list)):
        tp_matches = np.zeros(target_spans.shape[0], dtype=bool)
        
        batch_confusion_matrix = torch.zeros(nb_classes, nb_classes)

        # do the STRONG MATCHING here
        for span_idx in range(pred_spans.shape[0]):
            if torch.all(torch.tensor(pred_spans[span_idx, :] == target_spans), dim=1).any():
                seq_idx = pred_spans[span_idx, 0]
                # we use the start span idx to get the entity embedding and compare it to the pretrained embeds
#                 print(entity_preds[batch_idx, seq_idx, :].shape)
                cos_sim = sim_matrix(entity_preds[batch_idx, seq_idx, :].unsqueeze(0).unsqueeze(0), 
                                     pretrained_entity_embedding)
                # this should return something of size [1, 1, sim_vals]
                pred_ent_embed_id = torch.argmax(torch.squeeze(cos_sim))
                target_ent_embed_id = entity_ids[batch_idx, seq_idx]
                # just to test: 
                # TODO: REMOVE
                if batch_idx == 1 and seq_idx == 0:
                    target_ent_embed_id = 2
                # TODO: ENDREMOVE
                if pred_ent_embed_id == target_ent_embed_id:
                    # TP
                    micro_confusion_matrix[0,0] += 1
                    batch_confusion_matrix[0,0] += 1 
                    tp_matches[span_idx] = True
                else:
                    # FP
                    micro_confusion_matrix[1,0] += 1
                    batch_confusion_matrix[1,0] += 1
            else:
                # FP 
                micro_confusion_matrix[1,0] += 1
                batch_confusion_matrix[1,0] += 1

        print(tp_matches)        
        micro_confusion_matrix[0,1] += np.sum(~tp_matches)
        batch_confusion_matrix[0,1] += np.sum(~tp_matches)

        macro_confusion_matrix.append(batch_confusion_matrix)

        print(micro_confusion_matrix)


In [None]:
batch_size = 2
ent_vocab_size = 8
import torch.nn.functional as F

metric_list = ""
mention_preds = torch.tensor(np.random.rand(batch_size, 5, 3))
# mention_preds = F.softmax(torch.tensor(mention_preds), dim=2)
print(mention_preds)
entity_preds = torch.tensor(np.random.rand(batch_size, 5, 5))
bio_tags = torch.tensor(np.random.randint(0,3, size=(batch_size, 5)))
print(bio_tags.shape)
print(bio_tags)
entity_ids = torch.tensor(np.random.randint(ent_vocab_size, size=(batch_size, 5)))
print(entity_ids)
pretrained_entity_embedding = torch.tensor(np.random.rand(ent_vocab_size, 5))
print(pretrained_entity_embedding.shape)
# mention_pred Size([27, 256, 3])
# entity_pred torch.Size([27, 256, 256])
# bio_tags torch.Size([27, 256])
# entity_ids torch.Size([27, 256])

In [None]:
# bio_tags[1,0] = 2
# bio_tags[1,4] = 2
entity_ids[0,2] = -1
entity_ids[1,1] = -1
entity_ids[1, 3] = -1
entity_ids[1,4] = -1
print(bio_tags)
entity_ids

In [None]:
compute_metrics(metric_list, mention_preds, entity_preds, bio_tags, entity_ids, pretrained_entity_embedding)

In [None]:
x = torch.tensor([[0, 0, 1, 0, 0],[0, 2, 0, 2, 2]]) ; print(x)
y = torch.tensor([[1, 0, 2, 2, 1], [0, 2, 0, 1, 2]]) ; y

In [None]:
x_span = get_bi_spans(x)
y_span = get_bi_spans(y)
print(x_span)
print(y_span)

In [None]:
print(y_span[0][0,:].shape)
print(x_span[0].shape)
# (y_span[0][0,:] ==  x_span[0]).shape
torch.all(torch.tensor(y_span[0][0,:] == x_span[0]), dim=1).any()

In [None]:
if(np.array([]).size <= 0):
    print("X")

In [27]:
x = torch.tensor([-0.3616, -0.6805])
torch.sum(x)/2

tensor(-0.5210)