In [2]:
from transformers import BertModel, BertTokenizer
import io

import numpy as np
from collections import namedtuple
import sys
from typing import List, Tuple, Dict, Set, Union
import torch
import torch.nn as nn
import torch.nn.utils
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from nltk import word_tokenize
import pickle
import timeit
from scipy import spatial

from evaluator import Evaluator
from vocab import Vocab, VocabEntry
from utils import read_corpus, pad_sents, batch_iter

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\willi\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [3]:
words, defs, ft_dict = pickle.load( open( "../data/words_defs_dict_1M.train", "rb" ))

vocab = VocabEntry.from_corpus(defs, 1000000, 0)
for w in ft_dict:
    vocab.add(w)

number of word types: 23437, number of word types w/ frequency >= 0: 23437


In [4]:
def create_emb_layer(weights_matrix, src_pad_token_idx, non_trainable=True):
    num_embeddings, embedding_dim = weights_matrix.shape
    emb_layer = nn.Embedding(num_embeddings, embedding_dim, src_pad_token_idx)
    emb_layer.weight.data.copy_(torch.from_numpy(weights_matrix)) #figure out what is here
    if non_trainable:
        emb_layer.weight.requires_grad = False
    return emb_layer, num_embeddings, embedding_dim

class ModelEmbeddings(nn.Module): 
    """
    Class that converts input words to their embeddings.
    """

    def __init__(self, embed_size, vocab, fasttext_dict):
        """
        Init the Embedding layers.

        @param embed_size (int): Embedding size (dimensionality)
        @param vocab (VocabEntry)
        """
        super(ModelEmbeddings, self).__init__()

        self.embed_size = embed_size

        matrix_len = len(vocab)
        weights_matrix = np.zeros((matrix_len, self.embed_size))
        words_found = 0
        #print(len(vocab), weights_matrix.shape)
        for word, index in vocab.word2id.items():
            try:
                weights_matrix[index] = np.array(fasttext_dict[word])
                words_found += 1
            except KeyError:
                weights_matrix[index] = np.random.normal(scale=0.6, size=(self.embed_size,))

        # default values
        src_pad_token_idx = vocab['<pad>']
        self.source = create_emb_layer(weights_matrix, src_pad_token_idx, True)

In [5]:
def load_vectors(fname):
    fin = io.open(fname, 'r', encoding='utf-8', newline='\n', errors='ignore')
    n, d = map(int, fin.readline().split())
    data = {}
    for line in fin:
        tokens = line.rstrip().split(' ')
        data[tokens[0]] = map(float, tokens[1:])
    return data

In [6]:
class ReverseDictionary(nn.Module):

    def __init__(self, embed_dim, hidden_dim, vocab, ft_dict, freeze_bert = False):
        super(ReverseDictionary, self).__init__()
        #Instantiating BERT model object 
        
        self.ft_embedding = ModelEmbeddings(embed_dim, vocab, ft_dict)
        self.bert_layer = BertModel.from_pretrained('bert-base-uncased')
        
#         Freeze bert layers
        if freeze_bert:
            for p in self.bert_layer.parameters():
                p.requires_grad = False
        
        #Classification layer
        self.lstm_fasttext = nn.LSTM(embed_dim, hidden_dim)
        self.lin_layer = nn.Linear(hidden_dim+768, embed_dim)


    def forward(self, bert_input, ft_input, attn_masks):
        '''
        Inputs:
            -seq : Tensor of shape [B, T] containing token ids of sequences
            -attn_masks : Tensor of shape [B, T] containing attention masks to be used to avoid contibution of PAD tokens
        '''

        #Feeding the input to BERT model to obtain contextualized representations
        embedded = self.ft_embedding.source[0](ft_input)
        
        cont_reps, _ = self.bert_layer(bert_input, attention_mask = attn_masks)
        
        output, (cn, hn) = self.lstm_fasttext(embedded.unsqueeze(1))
        
        cls_rep = cont_reps[:, 0]
        
        #print(cn.squeeze(1).shape, cls_rep.shape)

        toLinear = torch.cat([cls_rep, cn.squeeze(1)], 1)

        #Obtaining the representation of [CLS] head
        
        #feed cls_rep to -> fasttext layer
        projected = self.lin_layer(toLinear)

        return projected


In [34]:
model = ReverseDictionary(300, 300, vocab, ft_dict)
loss_function = nn.L1Loss(reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr = .001)

In [35]:
int_sents = vocab.words2indices(defs)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
max_len = max(len(x) for x in int_sents)
sents_ft_id = [torch.tensor(i, dtype=torch.long, device="cpu") for i in int_sents]
sents_bert_id = []
masks = []
for d in defs:
    tokens = ['[CLS]'] + d + ['[SEP]']
    padded_tokens = tokens + ['[PAD]' for _ in range(max_len - len(tokens))]
    attn_mask = [1 if token != '[PAD]' else 0 for token in padded_tokens]
    seg_ids = [0 for _ in range(len(padded_tokens))]
    token_ids = tokenizer.convert_tokens_to_ids(padded_tokens)
    token_ids = torch.tensor(token_ids).unsqueeze(0) 
    attn_mask = torch.tensor(attn_mask).unsqueeze(0) 
    sents_bert_id.append(token_ids)
    masks.append(attn_mask)
assert(len(sents_bert_id) == len(masks))
assert(len(masks) == len(sents_ft_id))

In [36]:
start = timeit.default_timer()
losses = []
for epoch in range(5000):
    loss_cum = []
    for i in [100,120]:
        
        print(words[i])
        model.zero_grad()
        tag_scores = model.forward(sents_bert_id[i], sents_ft_id[i], masks[i])
        y_pred = tag_scores[0].double().unsqueeze(1)
        y_array = model.ft_embedding.source[0](torch.tensor(vocab[words[i]])).double().unsqueeze(1)
        #print(y_pred.shape, y_array.shape)
        loss = loss_function(y_pred, y_array)
        loss_cum.append(loss)
        loss.backward()
        optimizer.step() 
#         print(loss)
    
    lossavg = sum(loss_cum)/len(loss_cum)
    losses.append(loss)
    print(epoch, lossavg, timeit.default_timer() - start)

bristle
coltsfoot
0 tensor(0.2158, dtype=torch.float64, grad_fn=<DivBackward0>) 4.246253700000125
bristle
coltsfoot
1 tensor(0.1698, dtype=torch.float64, grad_fn=<DivBackward0>) 11.489367400000447
bristle
coltsfoot
2 tensor(0.1484, dtype=torch.float64, grad_fn=<DivBackward0>) 18.649803500000417
bristle
coltsfoot
3 tensor(0.0970, dtype=torch.float64, grad_fn=<DivBackward0>) 25.116106800000125
bristle
coltsfoot
4 tensor(0.0891, dtype=torch.float64, grad_fn=<DivBackward0>) 30.650640300000305
bristle
coltsfoot
5 tensor(0.0750, dtype=torch.float64, grad_fn=<DivBackward0>) 34.593354000000545
bristle
coltsfoot
6 tensor(0.0655, dtype=torch.float64, grad_fn=<DivBackward0>) 38.527614200000244
bristle
coltsfoot
7 tensor(0.0576, dtype=torch.float64, grad_fn=<DivBackward0>) 42.40188080000007
bristle
coltsfoot
8 tensor(0.0524, dtype=torch.float64, grad_fn=<DivBackward0>) 46.31541350000043
bristle
coltsfoot
9 tensor(0.0467, dtype=torch.float64, grad_fn=<DivBackward0>) 53.481824000000415
bristle
colts

KeyboardInterrupt: 

In [37]:
eval = Evaluator()
model.zero_grad()

for i in [100,120]:
    model.zero_grad()
    tag_scores = model.forward(sents_bert_id[i], sents_ft_id[i], masks[i])
    y_pred = tag_scores[0].double()#.unsqueeze(1)
    #print(y_pred)
    y_array = model.ft_embedding.source[0](torch.tensor(vocab[words[i]])).double().unsqueeze(1)
    #print(y_array)
    #print(y_pred.shape, y_array.shape)
    loss = loss_function(y_pred, y_array)
    eval.top_ten_hundred(ft_dict, words[i], y_pred.detach().numpy())
    print(np.linalg.norm(ft_dict[words[i]]-y_pred.detach().numpy()))
#     print(np.linalg.norm(ft_dict[words[i]]-y_pred.detach().numpy()))
#     print(sorted(ft_dict.keys(), key=lambda word: spatial.distance.cosine(ft_dict[word], y_pred.detach().numpy())))
#     print(ft_dict['fault'].shape, y_pred.detach().numpy().shape)
#     print(loss)

['bristle', 'bristly', 'brush', 'rankle', 'prickly', 'scoff', 'snicker', 'lash', 'flinch', 'grumble']
bristle ['bristle', 'bristly', 'brush', 'rankle', 'prickly', 'scoff', 'snicker', 'lash', 'flinch', 'grumble']
0.40043542281613803
['coltsfoot', 'spicebush', 'bloodroot', 'wolfsbane', 'ageratum', 'thimbleberry', 'fireweed', 'comfrey', 'bayberry', 'snakeroot']
coltsfoot ['coltsfoot', 'spicebush', 'bloodroot', 'wolfsbane', 'ageratum', 'thimbleberry', 'fireweed', 'comfrey', 'bayberry', 'snakeroot']
0.40982832313355433
