In [1]:
import pickle
import os
import glob
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from transformers import *
from tqdm import tqdm
import re
import random
import math
import itertools
import matplotlib.pyplot as plt
import numpy as np
import umap
import spacy

In [2]:
from custom.transformer_sentence import TransformerSentence

In [3]:
class TransformerSentence():
    def __init__(self, sentence_str, 
                 model=BertModel.from_pretrained('scibert-scivocab-uncased'), 
                 tokenizer=BertTokenizer.from_pretrained('scibert-scivocab-uncased')):
        
        self.raw_string = sentence_str
        self.model = model
        self.tokenizer = tokenizer
        self.summary = {}

        
    def write_summary(self, input_tokens=None, 
                      hidden_states=None, 
                      hidden_attentions=None,
                      print_tokens=True):
        
        if (input_tokens or hidden_states or hidden_attentions) is None:
            input_tokens, hidden_states, hidden_attentions = self.forward()
        
        # this replaces adds a "_{counter}" to the repreated tokens, so that 
        # they can be used uniquely as the keys for the embeddings dictionary
        input_tokens = TransformerSentence.make_unique(input_tokens)
        
        if print_tokens:
            print('Sentence Tokenization: ', input_tokens)
            
        # write summary into the object
        self.summary['input_tokens'] = input_tokens
        self.summary['states'] = hidden_states
        self.summary['attentions'] = hidden_attentions

        self.summary['token_embeddings'] = {input_token: hidden_states[:, i, :] 
                                            for i, input_token in enumerate(input_tokens)}
        
    def forward(self):
        encoded_inputs_dict = self.tokenizer.encode_plus(self.raw_string)
        input_ids = encoded_inputs_dict['input_ids']
        input_tensor = torch.tensor([input_ids])
        input_tokens = [self.tokenizer.decode(input_ids[j]).replace(' ', '') 
                        for j in range(len(input_ids))]
        
        final_attention, final_state, hidden_states_tup, hidden_attentions_tup = self.model(input_tensor)
        
        # stacking states and attentions along the first dimention (which corresponds to the batch when necessary)
        hidden_attentions = torch.cat(hidden_attentions_tup, dim=0) # 'layers', 'heads', 'queries', 'keys'
        hidden_states = torch.cat(hidden_states_tup, dim=0) # 'layers', 'tokens', 'embeddings'
        
        return input_tokens, hidden_states.detach(), hidden_attentions.detach()
    
    
    def attention_from_tokens(self, token1, token2, display=True):
        input_tokens = self.summary['input_tokens']
        
        if (token1 and token2) not in input_tokens:
            raise ValueError('One or both of the tokens introduced are not in the sentence!')
            
        idx1, idx2 = input_tokens.index(token1), input_tokens.index(token2)
        attention = self.summary['attentions'][:, :, idx1, idx2].numpy()
        if display:
            TransformerSentence.display_attention(attention, title=(token1, token2))
        return attention
    
    
    def attention_from_idx(self, i, j, display=True):
        attention = self.summary['attentions'][:, :, i, j].numpy()
        if display:
            TransformerSentence.display_attention(attention, title=f'Token idx: {(i, j)}')
        return attention
    
    def visualize_token_path(self, fit, 
                             tokens_to_follow=None, 
                             print_tokens=False, 
                             fig_axs=(None, None), 
                             figsize=(10, 10)):
        
        if tokens_to_follow is None:
            all_tokens = self.summary['input_tokens']
            regex = re.compile(r'^[a-zA-Z]')
            tokens_to_follow = [i for i in all_tokens if regex.search(i)]
            
        if print_tokens: print(tokens_to_follow)  
            
        colors = list(range(len(tokens_to_follow)))
        projections = []
        layer_depth = self.summary['states'].size()[0]
        
        for i in range(layer_depth):
            layer_embeddings = self.summary['states'][i, :, :]
            projection = fit.transform(layer_embeddings)
            projections.append(projection)

        data = np.stack(projections, axis=0)
        if None in fig_axs:
            fig, axs = plt.subplots(figsize=figsize)
        for token in tokens_to_follow:
            i = self.summary['input_tokens'].index(token)
            plt.plot(data[:,i,0], data[:,i,1], '-o', alpha=0.3)
            plt.annotate(s=token, xy=(data[0, i, 0], data[0, i, 1]))

        plt.show()
        
    def visualize_sentence_shape(self, fit, tokens_to_follow=None, 
                                 print_tokens=False, 
                                 fig_axs=(None, None), 
                                 figsize=(10, 10)):

        if tokens_to_follow is None:
            all_tokens = self.summary['input_tokens']
            regex = re.compile(r'^[a-zA-Z]')
            tokens_to_follow = [i for i in all_tokens if regex.search(i)]

        if print_tokens: print(tokens_to_follow)  

        colors = list(range(len(tokens_to_follow)))
        projections = []
        layer_depth = self.summary['states'].size()[0]
        
        # get list of indeces of the tokens to follow
        idxs = [self.summary['input_tokens'].index(token) for token in tokens_to_follow] 
        token_embeddings = self.summary['states'][-1, idxs, :]
        data = fit.transform(token_embeddings)
        
        if None in fig_axs:
            fig, axs = plt.subplots(figsize=figsize)
            
        plt.plot(data[:,0], data[:,1], '-o')
        for i, token in enumerate(tokens_to_follow):
            plt.annotate(s=token, xy=(data[i, 0], data[i, 1]))
        #plt.show()
    
    
    def save(self, name, path='.'):
        with open(os.path.join(path, name), 'wb') as file:
            pickle.dump(self, file)
    
    
    @staticmethod
    def visualize_embedding(embedding, title=None, vmax=None, vmin=None):
        if (vmax or vmin) is None:
            vmax = max(embedding)
            vmin = min(embedding)
            
        N = embedding.size()[0]
        h = math.ceil(math.sqrt(N))
        # N = a*b where abs(a-b) is minimum
        while (N % h != 0):
            h -= 1
        w = int(N / h)
        visualization = embedding.reshape((h, w)).numpy()
        fig, ax = plt.subplots()
        im = ax.imshow(visualization, vmax=vmax, vmin=vmin, cmap='viridis')
        fig.colorbar(im)
        if title is not None:
            ax.set_title(title)
        plt.show()
    
    @staticmethod
    def display_attention(attention, title=None):
        fig, ax = plt.subplots()
        im = ax.imshow(attention, vmin=0., vmax=1., cmap='viridis')
        fig.colorbar(im)
        if title is not None:
            ax.set_title(title)
        ax.set_xlabel('HEADS')
        ax.set_ylabel('LAYERS')
        plt.show()
    
    @staticmethod
    def load(name, path='.'):
        with open(os.path.join(path, name), 'rb') as file:
            SentenceObject = pickle.load(file)
        return SentenceObject
    
    @staticmethod
    def make_unique(L):
        unique_L = []
        for i, v in enumerate(L):
            totalcount = L.count(v)
            count = L[:i].count(v)
            unique_L.append(v + '_' + str(count+1) if totalcount > 1 else v)
        return unique_L

In [4]:
# Preloading models (this is the most costly)
# Bert base and large, uncased
BertBaseModel = BertModel.from_pretrained('bert-base-uncased')
BertBaseTokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
BertLargeModel = BertModel.from_pretrained('bert-large-uncased')
BertLargeTokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
# Scibert uncased
SciBertModel = BertModel.from_pretrained('scibert-scivocab-uncased')
SciBertTokenizer = BertTokenizer.from_pretrained('scibert-scivocab-uncased')
SciBertBaseVocabModel = BertModel.from_pretrained('scibert-basevocab-uncased')
SciBertBaseVocabTokenizer = BertTokenizer.from_pretrained('scibert-basevocab-uncased')
# Scibert Cased
#SciBertModelCased = BertModel.from_pretrained('scibert-scivocab-cased')
#SciBertTokenizerCased = BertTokenizer.from_pretrained('scibert-scivocab-cased')
#SciBertBaseVocabModelCased = BertModel.from_pretrained('scibert-basevocab-cased')
#SciBertBaseVocabTokenizerCased = BertTokenizer.from_pretrained('scibert-basevocab-cased')

In [5]:
def load_dataset(txt_path="../datasets/quora_questions.txt", 
                 MODEL=SciBertModel,
                 TOKENIZER=SciBertTokenizer):
    
    # Read input sequences from .txt file and put them in a list
    with open(txt_path) as f:
        text = f.read()
    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
    try:
        sentences.remove('') # remove possible empty strings
    except:
        None
    
    list_SentenceObj, ALL_INITIAL_EMBEDDINGS, ALL_CONTEXT_EMBEDDINGS = [], [], []
    
    for raw_sentence in tqdm(sentences):
        SentenceObj = TransformerSentence(raw_sentence,
                                          model=MODEL,
                                          tokenizer=TOKENIZER)
        SentenceObj.write_summary(print_tokens=False)
        list_SentenceObj.append(SentenceObj)
        ALL_INITIAL_EMBEDDINGS.append(SentenceObj.summary['states'][0, :, :])
        ALL_CONTEXT_EMBEDDINGS.append(SentenceObj.summary['states'][-1, :, :])

    ALL_INITIAL_EMBEDDINGS = torch.cat(ALL_INITIAL_EMBEDDINGS, dim=0)
    ALL_CONTEXT_EMBEDDINGS = torch.cat(ALL_CONTEXT_EMBEDDINGS, dim=0)
    
    return list_SentenceObj, ALL_INITIAL_EMBEDDINGS, ALL_CONTEXT_EMBEDDINGS

In [6]:
import torch
_ = torch.manual_seed(0)
s1 = "The conference is in New York and I like it."
s2 = "London is a beautiful city."
sentence1 = TransformerSentence(s1, model=BertBaseModel, tokenizer=BertBaseTokenizer)
sentence2 = TransformerSentence(s2, model=BertBaseModel, tokenizer=BertBaseTokenizer)

sentence1.write_summary(print_tokens=True)
sentence2.write_summary(print_tokens=True)

Sentence Tokenization:  ['[CLS]', 'the', 'conference', 'is', 'in', 'new', 'york', 'and', 'i', 'like', 'it', '.', '[SEP]']
Sentence Tokenization:  ['[CLS]', 'london', 'is', 'a', 'beautiful', 'city', '.', '[SEP]']


# Naive Distance Based Compacting Method

In [7]:
cos = torch.nn.CosineSimilarity(dim=1, eps=1e-5) # similarity func.

def remove_subsets(L):
    filtered = filter(lambda f: not any(set(f) < set(g) for g in L), L)
    return list(filtered)

def indices_to_compact_by_similarity_threshold(sequence_embeddings,
                                               sim_function=cos,
                                               threshold=0.1,
                                               exclude_special_tokens=True,
                                               combinatorics=None):
    # combinatorics= 'sequential', 'all'
    seq_length, embedding_size = sequence_embeddings.size() #make sure the input is proper size!!
    indices = list(range(seq_length))    
    
    # Combinations of indices that are group candidates
    if combinatorics == 'sequential':
        if exclude_special_tokens:
            idx_combinations = [indices[s:e] for s, e in itertools.combinations(range(1, len(indices)), 2)]
        else:
            idx_combinations = [indices[s:e] for s, e in itertools.combinations(range(len(indices)+1), 2)]
            
    elif combinatorics == 'all':
        idx_combinations = []
        for L in range(2, seq_length+1):
            combinations = list(itertools.combinations(indices, r=L))
            idx_combinations.extend(combinations)
    else:
        raise ValueError('You must specify the combinatorics as "sequencial" or "all"!!')
    
    
    all_indices_to_compact = []
    for indices in idx_combinations:
        group_candidate = sequence_embeddings[indices, :]
        group_size = len(indices)
        center = torch.mean(group_candidate, dim=0)
        center = center.repeat(group_size, 1)
        # calculate all embeddings similarities w.r.t. the center of the group
        similarities = sim_function(center, group_candidate)
        worst_sim, _ = torch.min(similarities, dim=0)
        if worst_sim > threshold: all_indices_to_compact.append(indices)
            
    indices_to_compact = remove_subsets(all_indices_to_compact)
    
    return indices_to_compact


def compact_embeddings(original_embeddings, indices_to_compact):
    new_embeddings_list = []
    for indices in indices_to_compact:
        group = original_embeddings[indices, :]
        center = torch.mean(group, dim=0)
        new_embeddings_list.append(center)
        
    new_embeddings = torch.stack(new_embeddings_list, dim=0)
    
    return new_embeddings


In [8]:
#s = "Their car broke down two miles out of town."
s = "Be sure to put on a life jacket before getting into the boat."
#s = "It’s time to get on the plane."
#s = "what is a bayesian network."
sentence = TransformerSentence(s, model=BertBaseModel, tokenizer=BertBaseTokenizer)
sentence.write_summary()
sequence_embeddings = sentence.summary['states'][-1, :, :]
indices_to_compact = indices_to_compact_by_similarity_threshold(sequence_embeddings, 
                                                                sim_function=cos, 
                                                                threshold=0.90,
                                                                exclude_special_tokens=True,
                                                                combinatorics='sequential')
print(indices_to_compact)
new_embeddings = compact_embeddings(sequence_embeddings, indices_to_compact)
print('Original Length: ', sequence_embeddings.size()[0], 'Compact size: ', new_embeddings.size()[0])

Sentence Tokenization:  ['[CLS]', 'be', 'sure', 'to', 'put', 'on', 'a', 'life', 'jacket', 'before', 'getting', 'into', 'the', 'boat', '.', '[SEP]']
[[1], [2, 3], [3, 4], [5, 6], [7], [8], [9], [10], [11, 12], [13], [14]]
Original Length:  16 Compact size:  11


In [9]:
#!python -m spacy download en

nlp = spacy.load("en") # en_core_web_sm

doc = nlp("Be sure to put on a life jacket before getting into the boat.")
print(list(doc.noun_chunks))
for np in doc.noun_chunks: # use np instead of np.text
    print(np)

print()

for sentence in sentences:
    doc = nlp(sentence)
    noun_chunks.append(list(doc.noun_chunks))

[a life jacket, the boat]
a life jacket
the boat



NameError: name 'sentences' is not defined

[0, 2, 4, 6, 8]
