# Use the ngram embeddings

In [1]:
# numpy, useful for efficient vector operations
import numpy as np

# pytorch 
import torch

# transformoer models from huggingface
from transformers import BertTokenizerFast, BertModel

# gensim library
from gensim.models import KeyedVectors

# natural langauge toolkit
import nltk
from nltk.tokenize import word_tokenize

import re

In [30]:
NGRAMS = 1
LAYER = 11

## Load Contextual Model
The contextual model is not actually necessary to find similar words. But if you want to find a similar word in context, or from a word not in the word-model dictionary we need the model. 

In [9]:
# = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu"

In [10]:
# define model name
model = 'bert-base-uncased' #for norwegian you can use: 'NbAiLab/nb-bert-base' or 'ltgoslo/norbert'

# the tokenizer plits the input text into tokens, which in this case is called wordpieces 
tokenizer = BertTokenizerFast.from_pretrained(model)

# download the model online
bert_model = BertModel.from_pretrained(model)



Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [19]:
# load in ngram embeddings
word_model = KeyedVectors.load("academic_ngrams_"+str(NGRAMS)+".kv")

In [11]:
# change to cuda cores if a gpu is available
bert_model.to(device)

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

In [12]:
# bert create a vector representation (actually many, one between each layer in the model)
# this function fetches the tokens and embeddings 
def sent_to_embedding_tokens(sent): 
    with torch.no_grad(): 
        tokenized_sent = tokenizer.tokenize(sent)
        # check if tokens are too long
        if (len(tokenized_sent) > 500): 
            return False, False
        inputs = tokenizer(sent, return_tensors = "pt")
        outputs = bert_model(**inputs.to(device), output_hidden_states=True)
        hidden_states = outputs[2]

        token_embeddings = torch.stack(hidden_states, dim=0) #stack all hidden states into same tensor
        token_embeddings = token_embeddings.squeeze(dim=1) # remove empty dimension
        token_embeddings = token_embeddings.permute(1,0,2)[1:-1] # change dimensions and remove remove cls and 
    
    return tokenized_sent, token_embeddings

In [20]:
def get_word_index2token_index(tokenized_sent, word_array): 
    current_word = word_array[0]
    word_index = 0
    word_index2token_index = {}
    current_token = ""
    for i, token in enumerate(tokenized_sent): 
        current_token += (re.sub("^##", "", token)) # it is standard for BERT tokenizer to add ## for subpieces not first in a word
        if (current_word == current_token): 
            if (word_index in word_index2token_index): 
                word_index2token_index[word_index].append(i)
            else: 
                word_index2token_index.update({word_index: [i]})
            word_index += 1
            # check if it was the last word
            if (len(word_array) <= word_index): 
                break
            current_word = word_array[word_index]
            current_token = ""
        else: 
            if (word_index in word_index2token_index): 
                word_index2token_index[word_index].append(i)
            else: 
                word_index2token_index.update({word_index: [i]})
    return word_index2token_index

def get_ngrams_and_indecies(word_index2token_index, word_array): 
    ngrams = [ngram for ngram in nltk.ngrams(word_array, NGRAMS)]
    ngram_word_indecies = [[i+j for j in range(len(ngrams[0]))] for i in range(len(ngrams))]
    ngram_token_indecies = []
    for word_indecies in ngram_word_indecies:  
        ngram_token_indecies.append([token_index for word_index in word_indecies for token_index in word_index2token_index[word_index]])
    return ngrams, ngram_token_indecies


def get_filtered_ngrams_in_sent(tokenized_sent, sent): 
    word_array = word_tokenize(sent.lower())
    word_index2token_index = get_word_index2token_index(tokenized_sent, word_array)
    ngrams, ngram_token_indecies = get_ngrams_and_indecies(word_index2token_index, word_array)
    # make the ngram of words into one string 
    ngrams_in_sent = [" ".join(ngram) for ngram in ngrams]
    
    filtered_ngrams_in_sent = []
    for i, ngram_in_sent in enumerate(ngrams_in_sent): 
        if(ngram_in_sent in filtered_ngrams_list): 
            filtered_ngrams_in_sent.append((ngram_in_sent, ngram_token_indecies[i]))
            
    return filtered_ngrams_in_sent

def get_substitute_indecies(tokenized_sent, sent, substitute_phrase):
    word_array = word_tokenize(sent.lower())
    substitute_array = word_tokenize(substitute_phrase)
    # find word indecies for overlap
    for i in range(len(word_array)): 
        sub_array = word_array[i:i+len(substitute_array)]
        if (sub_array == substitute_array): 
            word_indecies = [i+j for j in range(len(substitute_array))]
            break
    if (not word_indecies): 
        return false
    
    word_index2token_index = get_word_index2token_index(tokenized_sent, word_array)
    
    return [token_index for word_index in word_indecies for token_index in word_index2token_index[word_index]]


In [26]:
def get_substitute_embedding(token_embeddings, substitute_indecies): 
    return np.mean(token_embeddings[substitute_indecies], axis=0)


def find_similar_ngrams(sent, substitute_phrase, top = 10): 
    # find embeddings of the sentence words
    tokenized_sent, token_embeddings = sent_to_embedding_tokens(sent)
    token_embeddings = token_embeddings.to("cpu").detach().numpy()[:, LAYER, :]

    # extract relevant phrases for ngram
    substitute_indecies = get_substitute_indecies(tokenized_sent, sent, substitute_phrase)
    substitute_embedding = get_substitute_embedding(token_embeddings, substitute_indecies)
    
    return word_model.similar_by_vector(substitute_embedding, topn=top)

## Test embeddings

In [24]:
word_model.most_similar("solve", topn = 10)

[('solving', 0.8838314414024353),
 ('solved', 0.8668434619903564),
 ('address', 0.8500644564628601),
 ('compute', 0.8045283555984497),
 ('tackle', 0.7967976927757263),
 ('overcome', 0.7944207191467285),
 ('implement', 0.7927462458610535),
 ('obtain', 0.777320921421051),
 ('solution', 0.7771925926208496),
 ('perform', 0.774975061416626)]

In [31]:
test_sent = "We would like to figure out the best parameters for the model."
substitute_phrase = "figure out"
find_similar_ngrams(test_sent, substitute_phrase, top = 20)

[('develop', 0.7141934633255005),
 ('determine', 0.6942133903503418),
 ('find', 0.684475839138031),
 ('out', 0.6837274432182312),
 ('explain', 0.6746566295623779),
 ('finding', 0.6740469336509705),
 ('show', 0.6731117963790894),
 ('discover', 0.6675494313240051),
 ('summarize', 0.6672800779342651),
 ('detailed', 0.6641571521759033),
 ('analyze', 0.6611805558204651),
 ('provide', 0.6599010229110718),
 ('evaluate', 0.6570184826850891),
 ('discuss', 0.6554686427116394),
 ('propose', 0.653902530670166),
 ('investigate', 0.6517993807792664),
 ('studied', 0.6510133147239685),
 ('demonstrate', 0.6507325172424316),
 ('developed', 0.6485055088996887),
 ('describe', 0.6476252675056458)]