In [None]:
import PyPDF2
import stanfordnlp
import pickle
import os
import glob
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from transformers import *
from tqdm import tqdm
import re
import random
import math
import matplotlib.pyplot as plt

In [None]:
# Seed every random number generator this notebook imports, so that a
# "Restart Kernel -> Run All" reproduces the same results.
SEED = 0
random.seed(SEED)
_ = torch.manual_seed(SEED)

In [None]:
class TransformerSentence():
    """Couple one raw sentence with a model/tokenizer pair and cache the model outputs.

    After ``write_summary`` has run, ``self.summary`` contains:

    * ``'input_tokens'``     - the sentence tokens, made unique by ``make_unique``
    * ``'states'``           - the stacked hidden states returned by ``forward``
    * ``'attentions'``       - the stacked attentions returned by ``forward``
    * ``'token_embeddings'`` - ``{token: states[:, i, :]}`` for the i-th token
    """

    # Name handed to ``from_pretrained`` when the caller supplies no model/tokenizer.
    _DEFAULT_PRETRAINED = 'scibert-scivocab-uncased'
    # Default model/tokenizer, created on first use and then shared by all instances.
    # They are created lazily so that merely defining the class loads nothing.
    _default_model = None
    _default_tokenizer = None

    def __init__(self, sentence_str, model=None, tokenizer=None):
        """Store the sentence and the model/tokenizer used to analyse it.

        Args:
            sentence_str: the raw sentence.
            model: object called as ``model(input_tensor)`` in ``forward``. When
                omitted, ``BertModel.from_pretrained('scibert-scivocab-uncased')``
                is loaded once and reused.
            tokenizer: object providing ``encode_plus`` and ``decode``. When omitted,
                ``BertTokenizer.from_pretrained('scibert-scivocab-uncased')`` is
                loaded once and reused.
        """
        if model is None:
            if TransformerSentence._default_model is None:
                TransformerSentence._default_model = BertModel.from_pretrained(
                    TransformerSentence._DEFAULT_PRETRAINED)
            model = TransformerSentence._default_model
        if tokenizer is None:
            if TransformerSentence._default_tokenizer is None:
                TransformerSentence._default_tokenizer = BertTokenizer.from_pretrained(
                    TransformerSentence._DEFAULT_PRETRAINED)
            tokenizer = TransformerSentence._default_tokenizer

        self.raw_string = sentence_str
        self.model = model
        self.tokenizer = tokenizer
        self.summary = {}

    def write_summary(self, input_tokens=None,
                      hidden_states=None,
                      hidden_attentions=None,
                      print_tokens=True):
        """Fill ``self.summary`` with tokens, states, attentions and per-token embeddings.

        Any of ``input_tokens``, ``hidden_states`` and ``hidden_attentions`` that is
        left as ``None`` is taken from a call to ``self.forward()``; values that are
        supplied are used as given.

        Args:
            input_tokens: list of token strings, or None.
            hidden_states: tensor indexed as ``[:, token_index, :]``, or None.
            hidden_attentions: tensor indexed as ``[:, :, query, key]``, or None.
            print_tokens: when true, print the (uniquified) tokenization.
        """
        # Each argument is tested on its own: chaining them with ``or`` would ask a
        # tensor for its truth value and would not detect a single missing argument.
        if input_tokens is None or hidden_states is None or hidden_attentions is None:
            computed_tokens, computed_states, computed_attentions = self.forward()
            if input_tokens is None:
                input_tokens = computed_tokens
            if hidden_states is None:
                hidden_states = computed_states
            if hidden_attentions is None:
                hidden_attentions = computed_attentions

        # This adds a "_{counter}" suffix to repeated tokens, so that every token
        # can be used as a unique key of the embeddings dictionary.
        input_tokens = TransformerSentence.make_unique(input_tokens)

        if print_tokens:
            print('Sentence Tokenization: ', input_tokens)

        # write summary into the object
        self.summary['input_tokens'] = input_tokens
        self.summary['states'] = hidden_states
        self.summary['attentions'] = hidden_attentions

        self.summary['token_embeddings'] = {input_token: hidden_states[:, i, :]
                                            for i, input_token in enumerate(input_tokens)}

    def forward(self):
        """Tokenize ``self.raw_string`` and run the model on it.

        Returns:
            A tuple ``(input_tokens, hidden_states, hidden_attentions)``: the decoded
            token strings (spaces removed), and the per-layer hidden states and
            attentions concatenated along their first dimension and detached.
        """
        encoded_inputs_dict = self.tokenizer.encode_plus(self.raw_string)
        input_ids = encoded_inputs_dict['input_ids']
        input_tensor = torch.tensor([input_ids])
        input_tokens = [self.tokenizer.decode(token_id).replace(' ', '')
                        for token_id in input_ids]

        # The outputs are detached below, so no autograd graph is needed.
        # NOTE(review): the model must return exactly four items, the last two being
        # sequences of tensors (presumably per-layer hidden states and attentions,
        # which requires the model config to output them) - confirm for each model.
        with torch.no_grad():
            _, _, hidden_states_tup, hidden_attentions_tup = self.model(input_tensor)

        # Stacking along the first dimension (the batch dimension, of size 1 here).
        hidden_attentions = torch.cat(hidden_attentions_tup, dim=0)  # 'layers', 'heads', 'queries', 'keys'
        hidden_states = torch.cat(hidden_states_tup, dim=0)  # 'layers', 'tokens', 'embeddings'

        return input_tokens, hidden_states.detach(), hidden_attentions.detach()

    def attention_from_tokens(self, token1, token2, display=True):
        """Return ``attentions[:, :, idx1, idx2]`` for two tokens of the summary.

        Args:
            token1, token2: keys as stored in ``self.summary['input_tokens']``
                (repeated tokens carry their ``_{counter}`` suffix).
            display: when true, also plot the result with ``display_attention``.

        Returns:
            The selected attention slice as a numpy array.

        Raises:
            ValueError: if either token is not part of the summarized sentence.
        """
        input_tokens = self.summary['input_tokens']

        # Both tokens are checked; ``(token1 and token2) not in ...`` only tested one.
        if token1 not in input_tokens or token2 not in input_tokens:
            raise ValueError('One or both of the tokens introduced are not in the sentence!')

        idx1, idx2 = input_tokens.index(token1), input_tokens.index(token2)
        attention = self.summary['attentions'][:, :, idx1, idx2].numpy()
        if display:
            TransformerSentence.display_attention(attention, title=(token1, token2))
        return attention

    def attention_from_idx(self, i, j, display=True):
        """Return ``attentions[:, :, i, j]`` as a numpy array, optionally plotting it."""
        attention = self.summary['attentions'][:, :, i, j].numpy()
        if display:
            TransformerSentence.display_attention(attention, title=f'Token idx: {(i, j)}')
        return attention

    def save(self, name, path='.'):
        """Pickle this object (including its model and tokenizer) to ``path/name``."""
        with open(os.path.join(path, name), 'wb') as file:
            pickle.dump(self, file)

    @staticmethod
    def visualize_embedding(embedding, title=None, vmax=None, vmin=None):
        """Show a 1-D embedding as the most square ``h x w`` image with ``h * w == N``.

        Args:
            embedding: 1-D tensor.
            title: optional axes title.
            vmax, vmin: colour limits; each one defaults to the embedding's own
                maximum / minimum when left as ``None``.
        """
        # Each limit is resolved independently, so an explicit 0 or a single
        # supplied limit is respected.
        if vmax is None:
            vmax = embedding.max().item()
        if vmin is None:
            vmin = embedding.min().item()

        N = embedding.size()[0]
        h = math.ceil(math.sqrt(N))
        # N = h*w where abs(h-w) is minimum: walk down from sqrt(N) to a divisor of N
        while (N % h != 0):
            h -= 1
        w = N // h
        visualization = embedding.reshape((h, w)).numpy()
        fig, ax = plt.subplots()
        ax.imshow(visualization, vmax=vmax, vmin=vmin)
        if title is not None:
            ax.set_title(title)
        plt.show()

    @staticmethod
    def display_attention(attention, title=None):
        """Plot an attention array as an image on a fixed [0, 1] colour scale."""
        fig, ax = plt.subplots()
        ax.imshow(attention, vmin=0., vmax=1.)
        if title is not None:
            ax.set_title(title)
        ax.set_xlabel('HEADS')
        ax.set_ylabel('LAYERS')
        plt.show()

    @staticmethod
    def load(name, path='.'):
        """Unpickle and return the object stored at ``path/name``.

        SECURITY: ``pickle.load`` can execute arbitrary code; only load files that
        were written by ``save`` from a trusted source.
        """
        with open(os.path.join(path, name), 'rb') as file:
            SentenceObject = pickle.load(file)
        return SentenceObject

    @staticmethod
    def make_unique(L):
        """Return a copy of ``L`` in which repeated items get a ``_{k}`` suffix.

        ``k`` is the 1-based occurrence number; items that appear once are left
        untouched. Example: ``['a', 'b', 'a'] -> ['a_1', 'b', 'a_2']``.
        """
        # Two linear passes (totals, then running counts) instead of calling
        # ``list.count`` for every element.
        totals = {}
        for v in L:
            totals[v] = totals.get(v, 0) + 1

        seen = {}
        unique_L = []
        for v in L:
            if totals[v] > 1:
                seen[v] = seen.get(v, 0) + 1
                unique_L.append(v + '_' + str(seen[v]))
            else:
                unique_L.append(v)
        return unique_L

In [None]:
# Load every model/tokenizer pair once, up front (this is the most costly step).
def _load_pair(pretrained_name):
    """Return ``(BertModel, BertTokenizer)`` loaded from ``pretrained_name``."""
    return (BertModel.from_pretrained(pretrained_name),
            BertTokenizer.from_pretrained(pretrained_name))


BertBaseModel, BertBaseTokenizer = _load_pair('bert-base-uncased')
BertLargeModel, BertLargeTokenizer = _load_pair('bert-large-uncased')
SciBertModel, SciBertTokenizer = _load_pair('scibert-scivocab-uncased')
SciBertBaseVocabModel, SciBertBaseVocabTokenizer = _load_pair('scibert-basevocab-uncased')

In [None]:
raw_sentence = "Computer Vision: What is the difference between local descriptors and global descriptors"

# Wrap the same sentence once per preloaded model/tokenizer pair.
scibert_sentence, bert_sentence, bert_large_sentence = [
    TransformerSentence(raw_sentence, model=bert_model, tokenizer=bert_tokenizer)
    for bert_model, bert_tokenizer in (
        (SciBertModel, SciBertTokenizer),
        (BertBaseModel, BertBaseTokenizer),
        (BertLargeModel, BertLargeTokenizer),
    )
]

In [None]:
# Run each wrapped sentence through its model and cache the result, without
# printing the tokenization.
for _wrapped in (scibert_sentence, bert_sentence, bert_large_sentence):
    _wrapped.write_summary(print_tokens=False)

In [None]:
# First ([0, :]) and last ([-1, :]) rows of the stored embeddings of two tokens.
_computer_rows = scibert_sentence.summary['token_embeddings']['computer']
e1i, e1f = _computer_rows[0, :], _computer_rows[-1, :]
_vision_rows = scibert_sentence.summary['token_embeddings']['vision']
e2i, e2f = _vision_rows[0, :], _vision_rows[-1, :]


In [None]:
# Plot the attention between 'what' and 'difference' for the SciBERT and the
# BERT-base wrappers, in that order.
for _wrapped in (scibert_sentence, bert_sentence):
    _ = _wrapped.attention_from_tokens('what', 'difference', display=True)

In [None]:
# Read the input sequences from the .txt file and split them into a list of sentences.
# NOTE(review): the file is assumed to be UTF-8 encoded - adjust `encoding` if the
# dataset was saved differently.
with open("../datasets/quora_questions.txt", encoding="utf-8") as f:
    text = f.read()

# Split on whitespace that follows a '.' or '?', except after patterns such as
# "e.g." or "Mr." (the two negative look-behinds). Empty pieces are filtered out
# explicitly instead of relying on a bare `except` around `list.remove`.
sentences = [
    piece
    for piece in re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
    if piece != ''
]

In [None]:
# Wrap every sentence with the SciBERT model/tokenizer, summarize it, and collect
# the wrappers in input order.
scibert_sentences = []
for raw_sentence in tqdm(sentences):
    sentence_obj = TransformerSentence(raw_sentence, SciBertModel, SciBertTokenizer)
    sentence_obj.write_summary(print_tokens=False)
    scibert_sentences += [sentence_obj]

In [None]:
# List every sentence next to its position in `scibert_sentences`.
for i, sentence in enumerate(scibert_sentences):
    print(f'{i} {sentence.raw_string}')

In [None]:
# Hand-picked positions in `scibert_sentences`, grouped by topic.
# NOTE(review): the indices are hard-coded and only valid for the dataset file and
# sentence split used when they were chosen - confirm after any change to either.

# machine learning
list_ml = [scibert_sentences[k] for k in (12, 16, 17, 19)]
ml1, ml2, ml3, ml4 = list_ml
# computer vision
# NOTE(review): index 17 is also used as `ml3` above - confirm this is intended.
list_cv = [scibert_sentences[k] for k in (13, 14, 17, 91)]
cv1, cv2, cv3, cv4 = list_cv
# deep learning
list_dl = [scibert_sentences[k] for k in (5, 22, 26, 75)]
dl1, dl2, dl3, dl4 = list_dl
# neural networks
list_nns = [scibert_sentences[k] for k in (0, 6, 8, 85)]
nns1, nns2, nns3, nns4 = list_nns
# extra neural-network picks that are not part of `list_nns`
nn5 = scibert_sentences[66]
nnf6 = scibert_sentences[41]
nnf7 = scibert_sentences[53]
# facial recognition
list_fr = [scibert_sentences[k] for k in (7, 47, 182, 260)]
fr1, fr2, fr3, fr4 = list_fr

size = (15, 3)

# One panel per machine-learning sentence: attention between 'machine' and 'learning'.
fig, axs = plt.subplots(1, 4, figsize=size)
fig.suptitle('MACHINE LEARNING')
ml_attentions = [picked.attention_from_tokens('machine', 'learning', display=False)
                 for picked in list_ml]
a1, a2, a3, a4 = ml_attentions

for ax, attention in zip(axs, ml_attentions):
    ax.imshow(attention, vmax=1., vmin=0.)
    # same axis labels as TransformerSentence.display_attention
    ax.set_xlabel('HEADS')
    ax.set_ylabel('LAYERS')

# plt.show() matches the rest of the notebook; fig.show() only works with GUI backends.
plt.show()


In [None]:
# Cosine similarity computed along dim=1 of its two inputs.
cos = nn.CosineSimilarity(dim=1, eps=1e-08)
# Sentence inspected in the cells below.
sentence = scibert_sentences[23]

print(sentence.raw_string, sentence.summary['input_tokens'], sep='\n')

In [None]:
# For every ordered pair of tokens of `sentence` (including a token with itself),
# store the list of cosine similarities between their stored embeddings: one value
# per row of the embeddings.
distance_evolution = {
    (token1, token2): list(cos(sentence.summary['token_embeddings'][token1],
                               sentence.summary['token_embeddings'][token2]))
    for token1 in sentence.summary['input_tokens']
    for token2 in sentence.summary['input_tokens']
}

In [None]:
# Overlay one faint curve per entry of `distance_evolution`.
for pair_curve in distance_evolution.values():
    plt.plot(pair_curve, alpha=0.05)

In [None]:
#### SEE EMBEDDING ACROSS LAYERS ####
# Key of the token to inspect; it must exist in sentence.summary['token_embeddings'].
token = 'linear_1'
token_rows = sentence.summary['token_embeddings'][token]

# Iterate over however many rows the model produced, instead of a hard-coded 13.
for i in range(token_rows.size(0)):
    embedding = token_rows[i, :]#.clamp(-2, 2)
    # tensor reductions instead of Python's max()/min() over individual elements
    print(embedding.max(), embedding.min())
    print('argmax, argmin', torch.argmax(embedding), torch.argmin(embedding))
    plt.plot(embedding, alpha=1)
    plt.title(f'{token} - row {i}')
    plt.show()
    sentence.visualize_embedding(embedding, title=f'{token} - row {i}')