In [None]:
# Colab setup: install the third-party packages this notebook imports.
# NOTE(review): versions are unpinned (the captured log shows transformers 4.18.0 was resolved
# for this run); pinning them would make re-runs reproducible.
!pip install transformers
!pip install datasets
!pip install pytorch-lightning

Collecting transformers
  Downloading transformers-4.18.0-py3-none-any.whl (4.0 MB)
[K     |████████████████████████████████| 4.0 MB 9.8 MB/s 
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
[K     |████████████████████████████████| 6.6 MB 17.7 MB/s 
[?25hCollecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.5.1-py3-none-any.whl (77 kB)
[K     |████████████████████████████████| 77 kB 8.4 MB/s 
[?25hCollecting sacremoses
  Downloading sacremoses-0.0.53.tar.gz (880 kB)
[K     |████████████████████████████████| 880 kB 69.1 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 58.4 MB/s 
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Created wheel for

In [None]:
# Imports: standard library first, then third-party packages.
import datetime
import heapq
import json
import math
import os
import pickle
import random
import re
import time
import urllib.request
import zipfile
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

import nltk
from nltk.corpus import wordnet as wn

import torch
from torch import nn
from torch.nn.functional import cross_entropy
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import TensorDataset, DataLoader, Dataset, RandomSampler, SequentialSampler
from torch.optim import AdamW
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

import transformers
from transformers import AutoTokenizer, BertForMaskedLM, BertModel, get_linear_schedule_with_warmup
from datasets import load_dataset

from google.colab import drive

In [None]:
def free_memory():
  """
  Ask PyTorch's CUDA caching allocator to release cached memory blocks that are not in use.
  Gradient tracking is disabled for the duration of the call.
  """
  no_grad_ctx = torch.no_grad()
  with no_grad_ctx:
    torch.cuda.empty_cache()

In [None]:
# Mount Google Drive: every dataset cache, checkpoint directory and results file used below
# lives under /content/gdrive/MyDrive.
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
# ---- Notebook-wide configuration flags ----
# Data download / preprocessing.
force_download = False          # re-download the dictionaries and GloVe even if cached on Drive
process_dictionaries = True     # split Webster/EDMT senses and filter them with should_use_definition
filter_using_glove = True       # keep only headwords that have a GloVe vector
# Dictionary sources to leave out of the training mix.
dont_use_edmt = True
dont_use_webster = True
dont_use_unix = False
# Model / training.
use_lstm_model = True           # NOTE(review): not read anywhere in this notebook
use_multi_layers = False        # True -> BertMultiLSTM, False -> BertLSTM
force_restart_training = False  # ignore existing checkpoints when training starts
NUM_TARGET_EPOCHS = 50          # Used for linear schedule
# Each (architecture, Unix-dictionary usage) combination trains into its own checkpoint directory.
CHECKPT_DIR = {
  (True, True): "/content/gdrive/MyDrive/rd-checkpt-bl-2",
  (True, False): "/content/gdrive/MyDrive/rd-checkpt-bl-4",
  (False, True): "/content/gdrive/MyDrive/rd-checkpt-bl-1",
  (False, False): "/content/gdrive/MyDrive/rd-checkpt-bl-3",
}[(use_multi_layers, dont_use_unix)]
# GloVe tokens that are skipped when decoding predicted vectors back into words.
sample_bad_words = ['timewrn', 'svahng', 'bulletinyyy', 'seabream', 'srivalo', 'nortelnet', 'piyanart', 'prohertrib', 'canyonres']

In [None]:
# Fetch the Webster and EDMT JSON dictionaries into Drive, unless cached copies already exist.
for _dict_url, _dict_path in (
    ("https://raw.githubusercontent.com/matthewreagan/WebstersEnglishDictionary/master/dictionary.json",
     "/content/gdrive/MyDrive/webster_dict.json"),
    ("https://raw.githubusercontent.com/eddydn/DictionaryDatabase/master/EDMTDictionary.json",
     "/content/gdrive/MyDrive/edmt_dict.json"),
):
  if force_download or not os.path.isfile(_dict_path):
    urllib.request.urlretrieve(_dict_url, _dict_path)
nltk.download('wordnet')
# One (first lemma name, gloss) pair per WordNet synset; lemmas containing '_' are dropped.
wordnet = [
    (lemma, gloss)
    for lemma, gloss in ((synset.lemma_names()[0], synset.definition()) for synset in wn.all_synsets())
    if '_' not in lemma
]

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Unzipping corpora/wordnet.zip.


In [None]:
# Build (once) a pickled {word: float32 vector} index from the 100-d GloVe 6B file and cache it on Drive.
if force_download or not os.path.isfile('/content/gdrive/MyDrive/glove_6B_100d.pkl'):
  urllib.request.urlretrieve("http://nlp.stanford.edu/data/glove.6B.zip", "glove.6B.zip")
  with zipfile.ZipFile("glove.6B.zip", 'r') as zip_ref:
    zip_ref.extractall(".")
  print('Indexing word vectors.')
  embeddings_index = {}
  # Each line is "<word> <v1> ... <vN>"; `with` closes the handle even if parsing fails.
  with open('glove.6B.100d.txt', encoding='utf-8') as f:
    for line in f:
      values = line.split()
      word = values[0]
      embeddings_index[word] = np.asarray(values[1:], dtype='float32')

  print('Found %s word vectors.' % len(embeddings_index))
  # Close (and flush) the pickle deterministically instead of relying on garbage collection.
  with open('/content/gdrive/MyDrive/glove_6B_100d.pkl', 'wb') as glove_out:
    pickle.dump({'embeddings_index' : embeddings_index } , glove_out)

In [None]:
# Load the cached GloVe index. pickle.load can execute arbitrary code, so only load files you
# created yourself (this one is written by the previous cell).
with open('/content/gdrive/MyDrive/glove_6B_100d.pkl', 'rb') as glove_file:
  glove_vectors = pickle.load(glove_file)['embeddings_index']
print("Example glove vector for 'sprint': ", glove_vectors['sprint'])

Example glove vector for 'sprint':  [ 1.2558    -0.32234    0.04832    0.36313   -0.012474  -0.67533
 -0.15519   -0.10026   -1.0433    -0.0051245  0.84508    0.69359
 -0.41752   -0.59553   -0.7022    -0.44532   -0.07182   -0.014373
  0.085832  -0.38478    0.17784    0.42696    0.70415    1.0822
  0.13701    0.048887   0.13159    0.36777    0.43866   -0.8762
 -0.60843    0.74679   -0.081127  -0.95482    1.4353    -0.1464
 -0.40491    1.2206    -0.016826   1.285      1.024     -0.0481
 -0.32355   -0.65945   -0.84005    0.60295    0.8954    -0.50376
  0.58893   -0.38534   -0.30326   -0.19669    0.91021    0.43647
  0.50445   -1.371     -0.88019    1.4        1.5329     0.32102
  0.09122   -0.05632    1.0116     0.20832    0.56912   -0.14315
 -0.17157    0.42336    0.55502    0.11152   -0.30011   -0.55684
 -0.87482    0.070793   0.20729    0.24309    0.36296   -0.58297
 -0.038588  -1.1104     0.42161   -0.88943    0.12108    0.95354
 -0.3903    -0.25821   -0.18965    0.018633   0.64512    

In [None]:
def clean_text(text):
  """
  Hook for normalising raw definition text before use.
  Currently the identity function: the input object is returned unchanged.
  """
  cleaned = text
  return cleaned

def remove_duplicates(dictionary):
  """
  De-duplicate a list of (word, meaning) pairs, keeping the meaning from the FIRST occurrence
  of each word (the EDMT dictionary, for example, defines 'A' twice).

  Ordering: pairs come back ordered by the position of each word's LAST occurrence in the
  input. Later random sampling/splitting works on this order, so it is kept as is.
  Based on https://stackoverflow.com/questions/29563953/most-pythonic-way-to-remove-tuples-from-a-list-if-first-element-is-a-duplicate
  """
  # Walking backwards: a word's slot is created by its last occurrence, and its value keeps
  # being overwritten until the first occurrence is reached.
  slots = OrderedDict()
  for word, meaning in reversed(dictionary):
    slots[word] = meaning
  return list(slots.items())[::-1]

def split_definitions_webster(combined, min_word_count=3):
  """
  Split a string holding several numbered definitions into its constituents, then filter by the
  minimum word count. Webster uses the convention -
  1. First definition 2. Second definition 3. ...

  Markers are searched in order: "1.", then "2." after it, and so on. With fewer than two markers
  the whole (stripped) string is kept as one definition, so an incidental "1." in ordinary text
  (e.g. "1.5 inches") does not cause a split. Text before "1." becomes its own candidate piece.

  Args:
    combined: raw definition text for one headword.
    min_word_count: pieces with fewer whitespace-separated words than this are dropped.
  Returns:
    List of stripped definition strings, in their original order.
  """
  # (start, end) character spans of the markers "1.", "2.", ... in order of appearance.
  markers = []
  search_from = 0
  while True:
    marker = "{}.".format(len(markers) + 1)
    start = combined.find(marker, search_from)
    if start == -1:
      break
    markers.append((start, start + len(marker)))
    # Look for the next number only after this marker, so a stray earlier "2." is not
    # mistaken for the end of the list.
    search_from = start + len(marker)
  if len(markers) <= 1:
    defs = [combined.strip()]
  else:
    # Pieces: the preamble before "1.", then from each marker's END (so multi-digit markers such
    # as "10." are removed completely) up to the start of the next marker.
    piece_starts = [0] + [end for _, end in markers]
    piece_ends = [start for start, _ in markers] + [len(combined)]
    defs = [combined[a:b].strip() for a, b in zip(piece_starts, piece_ends)]
  return [defn for defn in defs if len(defn.split()) >= min_word_count]

def split_definitions_edmt(combined, min_word_count=5):
  """
  Break an EDMT definition string of the form
  First definition ; Second definition ; ...
  into its parts and keep only the parts with at least `min_word_count` whitespace-separated words.
  A string without ';' is treated as a single definition.
  """
  # str.split(';') on a string with no ';' already yields [combined], so no special case is needed.
  candidates = (part.strip() for part in combined.split(';'))
  return [part for part in candidates if len(part.split()) >= min_word_count]

def should_use_definition(word, definition, min_prefix_overlap=6, retain_probability = 0):
  """
  Decide whether a (word, definition) pair is useful for training.

  Definitions that give the answer away, e.g. 'The quality of being brutal' for 'brutalistic',
  are rejected. A definition is rejected when one of its whitespace-separated tokens shares its
  first min(min_prefix_overlap, len(word)) characters with the word (case-insensitively).
  A rejected pair is still kept with probability `retain_probability`.

  Returns:
    True to keep the pair, False to drop it.
  """
  prefix_len = min(min_prefix_overlap, len(word))
  word_prefix = word[:prefix_len].lower()
  overlaps = any(
      len(token) >= prefix_len and token[:prefix_len].lower() == word_prefix
      for token in definition.split()
  )
  if not overlaps:
    return True
  # random.random() is consulted only for rejected pairs and only when retention is enabled.
  return retain_probability > 0 and random.random() < retain_probability

def process_dictionary(dictionary, name):
  """
  Expand and filter the (word, definition) pairs of one dictionary source.

  For name 'webster' or 'edmt', each definition is split into senses with the matching
  split_definitions_* helper; any other name keeps each definition whole. Each resulting piece
  is kept only if should_use_definition accepts it.
  Returns a new list of (word, definition) tuples in input order.
  """
  def keep_whole(text):
    return [text]

  if name == 'webster':
    splitter = split_definitions_webster
  elif name == 'edmt':
    splitter = split_definitions_edmt
  else:
    splitter = keep_whole
  return [
      (word, piece)
      for word, defn in dictionary
      for piece in splitter(defn)
      if should_use_definition(word, piece)
  ]

def glove_filter(dictionary):
  """
  Return a new list with only those (word, definition) pairs whose word has a GloVe vector,
  i.e. is a key of the global glove_vectors.
  """
  kept = []
  for word, defn in dictionary:
    if word in glove_vectors:
      kept.append((word, defn))
  return kept

In [None]:
def read_webster_dict(path="/content/gdrive/MyDrive/webster_dict.json"):
  """
  Load the Webster JSON dictionary (an object mapping headword -> definition text) and return
  de-duplicated (lowercased word, cleaned definition) pairs.
  """
  # JSON text is UTF-8; state it explicitly rather than depending on the platform's locale default.
  with open(path, encoding='utf-8') as f:
    webster = json.load(f)
  return remove_duplicates([(key.lower(), clean_text(value)) for key,value in webster.items()])

def read_edmt_dict(path="/content/gdrive/MyDrive/edmt_dict.json"):
  """
  Load the EDMT JSON dictionary (a list of entries with 'word' and 'description' keys) and
  return de-duplicated (lowercased word, cleaned description) pairs.
  """
  # JSON text is UTF-8; state it explicitly rather than depending on the platform's locale default.
  with open(path, encoding='utf-8') as f:
    edmt = json.load(f)
  return remove_duplicates([(entry['word'].lower(), clean_text(entry['description'])) for entry in edmt])

In [None]:
def _report_dictionary_sizes(sources, stage=""):
  """
  For each (display name, pairs) in `sources`, print the number of word-definition pairs (with an
  optional stage suffix such as " after processing") and one randomly chosen pair.
  An empty source prints only its count, because random.choice cannot sample an empty list.
  """
  for display_name, pairs in sources:
    print("{} has {} word-definition pairs{}.".format(display_name, len(pairs), stage))
    if pairs:
      print(random.choice(pairs))

webster = read_webster_dict()
edmt = read_edmt_dict()
# Pre-built Unix dictionary cached on Drive. pickle.load can execute code: only load trusted files.
with open('/content/gdrive/MyDrive/unix-dictionary.pkl', 'rb') as unix_file:
  unix = pickle.load(unix_file)['dictionary']
_report_dictionary_sizes([("Webster", webster), ("EDMT", edmt), ("WordNet", wordnet), ("Unix", unix)])
if process_dictionaries:
  # Webster and EDMT are split into senses and filtered; WordNet and Unix pass through unchanged.
  webster = process_dictionary(webster, 'webster')
  edmt = process_dictionary(edmt, 'edmt')
  _report_dictionary_sizes([("Webster", webster), ("EDMT", edmt)], " after processing")
if filter_using_glove:
  webster = glove_filter(webster)
  edmt = glove_filter(edmt)
  wordnet = glove_filter(wordnet)
  unix = glove_filter(unix)
  _report_dictionary_sizes(
      [("Webster", webster), ("EDMT", edmt), ("WordNet", wordnet), ("Unix", unix)],
      " after Glove filtering")
# Drop disabled sources (they were still loaded above so their sizes are reported).
if dont_use_edmt:
  edmt = []
if dont_use_webster:
  webster = []
if dont_use_unix:
  unix = []

Webster has 102217 word-definition pairs.
('chap', "1. To cause to open in slits or chinks; to split; to cause the skin of to crack or become rough. Then would unbalanced heat licentious reign, Crack the dry hill, and chap the russet plain. Blackmore. Nor winter's blast chap her fair face. Lyly. 2. To strike; to beat. [Scot.]\n\n1. To crack or open in slits; as, the earth chaps; the hands chap. 2. To strike; to knock; to rap. [Scot.]\n\n1. A cleft, crack, or chink, as in the surface of the earth, or in the skin. 2. A division; a breach, as in a party. [Obs.] Many clefts and chaps in our council board. T. Fuller. 3. A blow; a rap. [Scot.]\n\n1. One of the jaws or the fleshy covering of a jaw; -- commonly in the plural, and used of animals, and colloquially of human beings. His chaps were all besmeared with crimson blood. Cowley. He unseamed him [Macdonald] from the nave to the chaps. Shak. 2. One of the jaws or cheeks of a vise, etc.\n\n1. A buyer; a chapman. [Obs.] If you want to sell,

In [None]:
# Tokenizer for 'bert-base-uncased', the same checkpoint the models' BertModel encoders load (downloaded on first use).
bert_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

In [None]:
def encode_data(data, tokenizer, using_lstm_model=False, max_def_length=128):
  """
  Tokenize the definition of every (word, definition) pair and look up the word's GloVe target.

  Newlines in a definition are replaced by spaces, so the words on either side are not glued
  together. The text is then encoded with special tokens and padded to `max_def_length`. Pairs
  whose encoding is longer than `max_def_length` are dropped. An empty word ('') gets an
  all-zero 100-d target; this is how single ad-hoc definitions are encoded for inference.

  Args:
    data: sequence of (word, definition) pairs; non-empty words must be keys of glove_vectors.
    tokenizer: Hugging Face tokenizer providing `encode` and `pad_token_id`.
    using_lstm_model: unused; kept for backward compatibility with existing callers.
    max_def_length: fixed encoded length; longer definitions are skipped.
  Returns:
    (input_ids, attention_masks, targets), each stacked along dim 0. If every pair is dropped,
    the three tensors are empty along dim 0 instead of raising.
  """
  encoded_def = []
  encoded_def_attn_masks = []
  encoded_targets = []
  for entry in data:
    word = entry[0]
    definition = entry[1].replace('\n', ' ')
    iids = tokenizer.encode(definition, add_special_tokens=True, padding="max_length",
                            max_length=max_def_length, return_tensors="pt")[0]
    # No truncation is requested, so a longer encoding means the definition is too long to keep.
    if iids.shape[-1] != max_def_length:
      continue
    attn_mask = (iids != tokenizer.pad_token_id).int()
    if word == '':
      target = torch.zeros(100)
    else:
      target = torch.tensor(glove_vectors[word])
    encoded_def.append(iids)
    encoded_def_attn_masks.append(attn_mask)
    encoded_targets.append(target)
  if not encoded_def:
    # torch.stack([]) raises; return correctly shaped empty tensors instead.
    return (torch.empty((0, max_def_length), dtype=torch.long),
            torch.empty((0, max_def_length), dtype=torch.int),
            torch.empty((0, 100)))
  return (torch.stack(encoded_def), torch.stack(encoded_def_attn_masks), torch.stack(encoded_targets))

In [None]:
# Concatenate all dictionary sources (disabled ones are empty lists by now) and encode them for BERT.
combined_dataset = webster + edmt + wordnet + unix
encoded_dataset = encode_data(data=combined_dataset, tokenizer=bert_tokenizer)

In [None]:
# Sanity check of the encoding: dataset size, plus the token ids and GloVe target of example 14.
print("Overall, using {} examples.".format(encoded_dataset[0].shape[0]))
print("Example encoded sentence:\n{}".format(encoded_dataset[0][14]))
print("Example encoded target:\n{}".format(encoded_dataset[2][14]))

Overall, using 150957 examples.
Example encoded sentence:
tensor([ 101, 3143,  102,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0])
Example encoded target:
tensor([ 2.0031e-03, -3.6897e-01,  1.1625e-01, -2.8270e-01,  

In [None]:
def train_val_test_split(encoded_dataset):
  """
  Partition (input_ids, attention_masks, targets) into train / validation / test sets.

  8% of the examples are held out (random_state=199). 40% of the held-out part becomes the test
  set and the rest becomes validation (random_state=1700). Overall that is 92% / 4.8% / 3.2%.
  All three tensors go through each train_test_split call together, so ids, masks and targets
  of every example stay aligned.

  Returns:
    dict mapping 'train', 'validation' and 'test' to (input_ids, masks, targets).
  """
  all_ids, all_masks, all_targets = encoded_dataset[0], encoded_dataset[1], encoded_dataset[2]
  (train_ids, held_ids,
   train_masks, held_masks,
   train_targets, held_targets) = train_test_split(all_ids, all_masks, all_targets,
                                                   random_state=199, test_size=0.08)
  (val_ids, test_ids,
   val_masks, test_masks,
   val_targets, test_targets) = train_test_split(held_ids, held_masks, held_targets,
                                                 random_state=1700, test_size=0.4)
  return {
      'train' : (train_ids, train_masks, train_targets),
      'validation' : (val_ids, val_masks, val_targets),
      'test' : (test_ids, test_masks, test_targets)
  }

In [None]:
# Split once, then report how many examples landed in each part.
split_dataset = train_val_test_split(encoded_dataset)
print("Number of train examples : {}".format(split_dataset['train'][0].shape[0]))
print("Number of validation examples : {}".format(split_dataset['validation'][0].shape[0]))
print("Number of test examples : {}".format(split_dataset['test'][0].shape[0]))

Number of train examples : 138880
Number of validation examples : 7246
Number of test examples : 4831


In [None]:
class BertLSTM(nn.Module):
  """
  BERT -> bidirectional LSTM -> Dropout -> Linear.

  The final forward and backward LSTM hidden states are concatenated, passed through dropout
  (p=0.2, training mode only) and projected to `out_dim` values (default 100, the GloVe size).
  """
  def __init__(self, out_dim=100, seq_len=128):
    super().__init__()
    self.out_dim = out_dim
    # Expected to match the padded input length (encode_data pads to 128 by default).
    self.seq_len = seq_len
    self.bert = BertModel.from_pretrained('bert-base-uncased')
    self.hidden_size = self.bert.config.hidden_size
    self.LSTM = nn.LSTM(self.hidden_size, self.hidden_size, bidirectional=True)
    self.Linear = nn.Linear(self.hidden_size*2, self.out_dim)
    self.train_mode = True

  def train(self, mode=True):
    """
    Switch this module and all its children (BERT, LSTM) to training (mode=True) or eval mode.
    Delegates to nn.Module.train so dropout inside BERT follows the mode. Returns self, like
    nn.Module.train; `train_mode` is kept in sync for backward compatibility.
    """
    super().train(mode)
    self.train_mode = mode
    return self

  def eval(self):
    """Equivalent to train(False); returns self."""
    return self.train(False)

  def forward(self, input_ids, attention_mask):
    """
    Args:
      input_ids, attention_mask: (batch, seq_len) tensors.
    Returns:
      (batch, out_dim) predicted vectors.
    """
    outputs = self.bert(input_ids, attention_mask)
    encoded_layers = outputs.last_hidden_state           # (batch, seq, hidden)
    seq_lens = encoded_layers.shape[0] * [self.seq_len]
    encoded_layers = encoded_layers.permute(1, 0, 2)     # nn.LSTM default layout: (seq, batch, hidden)
    # NOTE(review): every sequence is packed at full length, so the LSTM also reads padding
    # positions. Existing checkpoints were trained this way, so it is left unchanged.
    _, (last_hidden, _) = self.LSTM(nn.utils.rnn.pack_padded_sequence(encoded_layers, seq_lens))
    # h_n is (num_layers * 2, batch, hidden); the last two rows are the top layer's
    # forward and backward final states.
    output_hidden = torch.cat((last_hidden[-2], last_hidden[-1]), dim=1)
    # One dropout, active only while training.
    output_hidden = nn.functional.dropout(output_hidden, 0.2, training=self.training)
    return self.Linear(output_hidden)

class BertMultiLSTM(nn.Module):
  """
  BERT -> 4-layer bidirectional LSTM (dropout 0.1 between layers) -> Dropout -> Linear.

  The top LSTM layer's final forward and backward states are concatenated, passed through
  dropout (p=0.2, training mode only) and projected to `out_dim` values (default 100).
  """
  def __init__(self, out_dim=100, seq_len=128):
    super().__init__()
    self.out_dim = out_dim
    # Expected to match the padded input length (encode_data pads to 128 by default).
    self.seq_len = seq_len
    self.bert = BertModel.from_pretrained('bert-base-uncased')
    self.hidden_size = self.bert.config.hidden_size
    self.LSTM = nn.LSTM(input_size=self.hidden_size, hidden_size=self.hidden_size, num_layers=4, dropout=0.1, bidirectional=True)
    self.Linear = nn.Linear(self.hidden_size*2, self.out_dim)
    self.train_mode = True

  def train(self, mode=True):
    """
    Switch this module and all its children (BERT, LSTM) to training (mode=True) or eval mode.
    Delegating to nn.Module.train makes BERT's dropout and the LSTM's inter-layer dropout follow
    the mode. Returns self; `train_mode` is kept in sync for backward compatibility.
    """
    super().train(mode)
    self.train_mode = mode
    return self

  def eval(self):
    """Equivalent to train(False); returns self."""
    return self.train(False)

  def forward(self, input_ids, attention_mask):
    """
    Args:
      input_ids, attention_mask: (batch, seq_len) tensors.
    Returns:
      (batch, out_dim) predicted vectors.
    """
    outputs = self.bert(input_ids, attention_mask)
    encoded_layers = outputs.last_hidden_state           # (batch, seq, hidden)
    seq_lens = encoded_layers.shape[0] * [self.seq_len]
    encoded_layers = encoded_layers.permute(1, 0, 2)     # nn.LSTM default layout: (seq, batch, hidden)
    # NOTE(review): sequences are packed at full length, so padding positions are processed too.
    _, (last_hidden, _) = self.LSTM(nn.utils.rnn.pack_padded_sequence(encoded_layers, seq_lens))
    # h_n is (num_layers * 2, batch, hidden), ordered layer by layer. Rows 0/1 are the BOTTOM
    # layer, so the top layer's forward/backward states are rows -2/-1.
    # NOTE(review): checkpoints trained with rows 0/1 will predict differently after this change.
    output_hidden = torch.cat((last_hidden[-2], last_hidden[-1]), dim=1)
    output_hidden = nn.functional.dropout(output_hidden, 0.2, training=self.training)
    return self.Linear(output_hidden)

In [None]:
# Instantiate the architecture selected by `use_multi_layers` (loads bert-base-uncased weights).
bert_lstm_model = (BertMultiLSTM if use_multi_layers else BertLSTM)()
print(bert_lstm_model)

Downloading:   0%|          | 0.00/420M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertLSTM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
   

In [None]:
BATCH_SIZE = 32

# Training batches are drawn in a fresh random order every epoch.
train_dataset = TensorDataset(*split_dataset['train'])
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=BATCH_SIZE)

# Validation only averages per-batch losses, so shuffling buys nothing. The average divides by
# the number of batches, which weights the last partial batch differently; a fixed order keeps
# the reported validation loss identical across epochs and runs.
validation_dataset = TensorDataset(*split_dataset['validation'])
validation_sampler = SequentialSampler(validation_dataset)
validation_dataloader = DataLoader(validation_dataset, sampler=validation_sampler, batch_size=BATCH_SIZE)

In [None]:
# Pick the compute device; on a GPU machine the model parameters are moved there as well.
if not torch.cuda.is_available():
  print("No GPUs available, using CPU")
  device = torch.device("cpu")
else:
  print("Using GPU: {}".format(torch.cuda.get_device_name(0)))
  device = torch.device("cuda")
  bert_lstm_model.cuda()

Using GPU: Tesla P100-PCIE-16GB


In [None]:
def format_time(elapsed):
  """Render a duration in seconds as 'H:MM:SS', rounded to whole seconds with Python's round()."""
  return str(datetime.timedelta(seconds=int(round(elapsed))))

def flat_accuracy(preds, labels):
  """
  Fraction of rows whose arg-max prediction equals the label.

  Args:
    preds: 2-D array of per-class scores (arg-max taken along axis 1).
    labels: array of integer labels, flattened before comparison.
  """
  predicted = np.argmax(preds, axis=1).flatten()
  gold = labels.flatten()
  hits = np.sum(predicted == gold)
  return hits / gold.shape[0]

In [None]:
def get_max_checkpt(checkpt_dir):
  """
  Return the largest N among files named exactly 'checkpt-N.pt' in `checkpt_dir`.

  Returns 0 if the directory does not exist or contains no such file. Names that only start
  with the pattern (e.g. 'checkpt-3.pt.bak') are ignored instead of crashing the number parse.
  """
  if not os.path.isdir(checkpt_dir):
    return 0
  max_checkpt = 0
  for filename in os.listdir(checkpt_dir):
    # fullmatch with an escaped '.' so only exact 'checkpt-<digits>.pt' names count.
    match = re.fullmatch(r"checkpt-([0-9]+)\.pt", filename)
    if match:
      max_checkpt = max(max_checkpt, int(match.group(1)))
  return max_checkpt

def load_latest_checkpt(checkpt_dir=CHECKPT_DIR):
  """
  Load the newest 'checkpt-N.pt' from `checkpt_dir` into bert_lstm_model.

  Returns:
    N (the number of epochs already completed) if a checkpoint was loaded, else 0. This includes
    the case where force_restart_training is set; callers do arithmetic on the result, so it is
    never None.
  """
  if force_restart_training:
    return 0
  mx_checkpt = get_max_checkpt(checkpt_dir)
  if mx_checkpt > 0:
    checkpt_file = os.path.join(checkpt_dir, "checkpt-{}.pt".format(mx_checkpt))
    # Map to CPU first. load_state_dict copies the values into the model's existing tensors,
    # so a GPU-saved checkpoint loads whether or not a GPU is available now.
    bert_lstm_model.load_state_dict(torch.load(checkpt_file, map_location='cpu'))
  return mx_checkpt

In [None]:
# Optimisation setup. NUM_EPOCHS bounds the training loop, while the linear decay is sized by
# NUM_TARGET_EPOCHS from the config cell, so the schedule length can differ from the run length.
NUM_EPOCHS = 50
NUM_STEPS = NUM_TARGET_EPOCHS * len(train_dataloader)   # the scheduler steps once per training batch
optimizer = AdamW(bert_lstm_model.parameters(), lr=2e-5, eps=1e-8)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=NUM_STEPS)

In [None]:
def _batch_squared_distance(outputs, targets):
  """Mean over the batch of the squared Euclidean distance between predicted and target vectors."""
  return nn.functional.mse_loss(outputs, targets, reduction='none').sum(dim=1).mean()

def _resume_scheduler(completed_epochs):
  """
  Fast-forward the per-batch linear LR schedule past `completed_epochs` epochs.

  The scheduler from get_linear_schedule_with_warmup is a LambdaLR stepped once per batch, so
  its step counter must be epochs * batches-per-epoch. Assigning last_epoch does not update the
  optimizer's current lr, so the lr for the next step is written into each param group too.
  NOTE: AdamW moment estimates are not checkpointed, so they restart from zero on resume.
  """
  steps_done = completed_epochs * len(train_dataloader)
  scheduler.last_epoch = steps_done
  for group, base_lr, lr_lambda in zip(optimizer.param_groups, scheduler.base_lrs, scheduler.lr_lambdas):
    group['lr'] = base_lr * lr_lambda(steps_done)

def _train_one_epoch(epoch_start):
  """Run one pass over train_dataloader, stepping optimizer and scheduler per batch; return the mean batch loss."""
  bert_lstm_model.train()
  epoch_loss = 0.0
  for step, batch in enumerate(train_dataloader):
    if step % 40 == 0 and step != 0:
      elapsed = format_time(time.time() - epoch_start)
      print("Batch {} of {}. Elapsed {}".format(step, len(train_dataloader), elapsed))
    # Targets are the GloVe vectors of the defined words.
    batch_enc_def, batch_attn_mask, batch_targets = (t.to(device) for t in batch)
    bert_lstm_model.zero_grad()
    outputs = bert_lstm_model(input_ids=batch_enc_def, attention_mask=batch_attn_mask)
    loss = _batch_squared_distance(outputs, batch_targets)
    loss.backward()
    clip_grad_norm_(bert_lstm_model.parameters(), 1.0)
    optimizer.step()
    scheduler.step()
    # .item() yields a Python float. Summing the tensors instead would chain every batch's loss
    # into one ever-growing autograd graph kept on the device.
    epoch_loss += loss.item()
  return epoch_loss / len(train_dataloader)

def _validate():
  """Return the mean batch loss over validation_dataloader, computed without gradients."""
  bert_lstm_model.eval()
  val_loss = 0.0
  with torch.no_grad():
    for batch in validation_dataloader:
      batch_enc_def, batch_attn_mask, batch_targets = (t.to(device) for t in batch)
      outputs = bert_lstm_model(input_ids=batch_enc_def, attention_mask=batch_attn_mask)
      val_loss += _batch_squared_distance(outputs, batch_targets).item()
  return val_loss / len(validation_dataloader)

def train_bert_lstm():
  """
  Train bert_lstm_model up to NUM_EPOCHS epochs, resuming after the newest checkpoint in CHECKPT_DIR.

  After every epoch, prints the average training and validation losses and saves the model
  state_dict to CHECKPT_DIR/checkpt-<epoch>.pt, where <epoch> is 1-based.

  Returns:
    List of average training losses (floats), one per epoch run in this call.
  """
  os.makedirs(CHECKPT_DIR, exist_ok=True)   # first run: the checkpoint directory may not exist yet
  loss_values = []
  # Number of completed epochs == 0-based index of the next epoch to run.
  start_epoch = load_latest_checkpt() or 0
  _resume_scheduler(start_epoch)
  for epoch in range(start_epoch, NUM_EPOCHS):
    print("Using BERT-LSTM model")
    print("======== Epoch {} / {} ========".format(epoch+1, NUM_EPOCHS))
    print("Training phase")
    epoch_start = time.time()
    avg_train_loss = _train_one_epoch(epoch_start)
    loss_values.append(avg_train_loss)
    print("Average training loss for epoch {} : {}".format(epoch+1, avg_train_loss))
    print("Epoch took {}".format(format_time(time.time()-epoch_start)))

    print("\nValidation phase")
    val_start = time.time()
    avg_val_loss = _validate()
    print("Validation loss: {}".format(avg_val_loss))
    print("Validation took {}".format(format_time(time.time()-val_start)))
    checkpt_path = os.path.join(CHECKPT_DIR, "checkpt-{}.pt".format(epoch+1))
    torch.save(bert_lstm_model.state_dict(), checkpt_path)
  return loss_values

In [None]:
train_bert_lstm()
# Switch the model to evaluation mode for the evaluation cells below.
bert_lstm_model.eval()

In [None]:
def has_blocked_chars(word):
  """
  True if `word` should be skipped when decoding vectors to words. That is the case when it is
  listed in sample_bad_words or contains any character for which str.isalpha() is False
  (e.g. '@' or digits).
  """
  if word in sample_bad_words:
    return True
  return not all(ch.isalpha() for ch in word)

def get_k_closest_words(vec, k=5, skip_implausible=True, vectors=None):
  """
  Return up to `k` vocabulary words whose vectors are nearest (Euclidean distance) to `vec`,
  closest first.

  Args:
    vec: torch tensor holding one query vector (flattened before comparison).
    k: number of words to return; fewer are returned if fewer candidates exist.
    skip_implausible: ignore words rejected by has_blocked_chars.
    vectors: optional {word: numpy vector} to search; defaults to the global glove_vectors.
  Returns:
    List of at most k words.
  """
  if vectors is None:
    vectors = glove_vectors
  query = vec.detach().cpu().numpy().flatten()
  scored = (
      (np.linalg.norm(wvec - query), word)
      for word, wvec in vectors.items()
      if not (skip_implausible and has_blocked_chars(word))
  )
  # Heap-based top-k selection: O(V log k) instead of scanning a k-length list for every word.
  return [word for _, word in heapq.nsmallest(k, scored, key=lambda pair: pair[0])]

def get_closest_word(vec, skip_implausible=True):
  """
  Return the GloVe word whose vector has the smallest Euclidean distance to `vec`. Ties go to
  the word seen first. Returns None if no candidate qualifies.
  Words rejected by has_blocked_chars are ignored when skip_implausible is True.
  """
  query = vec.detach().cpu().numpy().flatten()
  best_word, best_dist = None, math.inf
  for candidate, cand_vec in glove_vectors.items():
    if skip_implausible and has_blocked_chars(candidate):
      continue
    dist = np.linalg.norm(cand_vec - query)
    if dist < best_dist:
      best_word, best_dist = candidate, dist
  return best_word

def is_in_top_1_10_100(word, vec):
  """
  Rank `word` among the 100 vocabulary words nearest to `vec`.

  Returns:
    Three 0/1 flags: (nearest word, within the nearest 10, within the nearest 100).
  """
  nearest = get_k_closest_words(vec=vec, k=100)
  if word in nearest:
    rank = nearest.index(word)
    return (int(rank == 0), int(rank < 10), 1)
  return (0, 0, 0)

In [None]:
# Test split iterated in a fixed order: the resumable accuracy loop skips examples by position,
# which only works if every pass sees them in the same sequence.
test_dataset = TensorDataset(*split_dataset['test'])
test_sampler = SequentialSampler(test_dataset)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, sampler=test_sampler)

In [None]:
def get_baseline(inputs):
  """
  Bag-of-words baseline: the mean GloVe vector of the whitespace-separated tokens of `inputs`
  that have a GloVe entry, accumulated in float64. Returns a length-100 zero vector when no
  token is known.
  """
  known = [glove_vectors[token] for token in inputs.split() if token in glove_vectors]
  mean_vec = np.zeros(100)
  for token_vec in known:
    mean_vec += token_vec
  if known:
    mean_vec /= len(known)
  return mean_vec

def eval_baseline():
  """
  Score the averaged-GloVe baseline (get_baseline) on test_dataloader.

  Each test definition is decoded from its token ids, and the baseline vector is built from it.
  The gold word is recovered as the GloVe word nearest to the stored target vector, as
  get_testing_accuracies does for the model. Prints progress every 100 examples and a final
  "total n1 n10 n100" line.

  Returns:
    (n1, n10, n100, total) hit counts and number of evaluated examples.
  """
  total = n1 = n10 = n100 = 0
  for sample in test_dataloader:
    test_defs = bert_tokenizer.batch_decode(sequences=sample[0], skip_special_tokens=True)
    targets = sample[2]
    for i in range(targets.shape[0]):
      baseline_vec = torch.from_numpy(get_baseline(test_defs[i]))
      # is_in_top_1_10_100 compares against a WORD (string), not against the target vector.
      gold_word = get_closest_word(targets[i])
      top1, top10, top100 = is_in_top_1_10_100(gold_word, baseline_vec)
      n1 += top1
      n10 += top10
      n100 += top100
      total += 1
      if total % 100 == 0:
        print(total, "done")
  print(total, n1, n10, n100)
  return n1, n10, n100, total

In [None]:
def print_examples():
  """
  Run the model on the first test batch and print, for each example: the decoded definition,
  the GloVe word nearest to the model's output, and the GloVe word nearest to the gold target.
  """
  batch_0 = next(iter(test_dataloader))
  batch_0_enc_def = batch_0[0].to(device)
  batch_0_attn_mask = batch_0[1].to(device)
  batch_0_targets = batch_0[2].to(device)

  with torch.no_grad():  # inference only: no autograd graph for the forward pass
    outputs = bert_lstm_model(input_ids=batch_0_enc_def, attention_mask=batch_0_attn_mask)
  # Decode the whole batch once.
  test_defs = bert_tokenizer.batch_decode(sequences=batch_0_enc_def, skip_special_tokens=True)
  tests_decoded = [get_closest_word(outputs[i].cpu()) for i in range(outputs.shape[0])]
  targets_decoded = [get_closest_word(batch_0_targets[i].cpu()) for i in range(outputs.shape[0])]
  print("Some examples:")
  for definition, predicted, actual in zip(test_defs, tests_decoded, targets_decoded):
    print("Definition: ", definition)
    print("Our model's output:", predicted)
    print("Real word:", actual)

In [None]:
def get_testing_accuracies(dataloader, start=0, end=-1):
  """
  Measure top-1/10/100 retrieval accuracy of bert_lstm_model on `dataloader`.

  For each example, the gold word is the GloVe word nearest to the stored target vector. It
  counts as a top-r hit if it is among the r words nearest to the model's output.

  Args:
    dataloader: yields (input_ids, attention_mask, target_vectors) batches in a fixed order, so
      that `start` can resume an earlier partial run.
    start: number of leading examples to skip.
    end: stop after this many evaluated examples; -1 (or any non-positive value) means no limit.
  Returns:
    (n1, n10, n100, total) hit counts and number of evaluated examples.
  """
  total = n1 = n10 = n100 = 0
  skip = start
  for batch in dataloader:
    batch_size = batch[2].shape[0]
    if skip >= batch_size:
      # Whole batch was evaluated in an earlier run; skip by its actual size.
      skip -= batch_size
      continue
    batch_enc_def = batch[0].to(device)
    batch_attn_mask = batch[1].to(device)
    batch_targets = batch[2].to(device)
    with torch.no_grad():  # inference only
      outputs = bert_lstm_model(input_ids=batch_enc_def, attention_mask=batch_attn_mask)
    for i in range(skip, outputs.shape[0]):
      actual = get_closest_word(batch_targets[i].cpu())
      top1, top10, top100 = is_in_top_1_10_100(actual, outputs[i])
      total += 1
      n1 += top1
      n10 += top10
      n100 += top100
      if total % 100 == 0:
        print("{} done.".format(total))
      if total == end:
        break
    skip = 0  # only the first evaluated batch is partially skipped
    if total == end:
      break
  if total == 0:
    print("No examples evaluated.")
    return n1, n10, n100, total
  p1 = 100.0 * n1 / total
  p10 = 100.0 * n10 / total
  p100 = 100.0 * n100 / total
  print("Top 1 accuracy  : 100% * {} / {} = {}%".format(n1, total, p1))
  print("Top 10 accuracy : 100% * {} / {} = {}%".format(n10, total, p10))
  print("Top 100 accuracy: 100% * {} / {} = {}%".format(n100, total, p100))
  return n1, n10, n100, total

In [None]:
def _predict_vector_for_definition(definition, tokenizer):
  """
  Encode one free-text definition and return the model's predicted vector for it (1-D tensor).
  Raises ValueError if the definition is too long to encode (encode_data drops such inputs).
  """
  encoded = encode_data([["", definition]], tokenizer)
  if encoded[0].shape[0] == 0:
    raise ValueError("Definition is too long to encode: {!r}".format(definition[:80]))
  defn = encoded[0].to(device)
  mask = encoded[1].to(device)
  with torch.no_grad():  # inference only
    outputs = bert_lstm_model(input_ids=defn, attention_mask=mask)
  return outputs[0]

def get_word_from_single_def(definition, tokenizer, use_bert_lstm=True):
  """
  Predict the GloVe word that best matches a free-text definition.
  `use_bert_lstm` is unused and kept for backward compatibility.

  Returns:
    A one-element list [word].
  """
  return [get_closest_word(_predict_vector_for_definition(definition, tokenizer))]

def get_k_closest_words_from_single_def(definition, tokenizer, use_bert_lstm=True, k=5):
  """
  Predict the `k` GloVe words (default 5) nearest to the model's vector for a free-text
  definition, closest first. `use_bert_lstm` is unused and kept for backward compatibility.
  """
  return get_k_closest_words(_predict_vector_for_definition(definition, tokenizer), k=k)

In [None]:
# Resumable evaluation of the whole test split. Running counts "examples hits@1 hits@10 hits@100"
# are stored on Drive and updated after every chunk of 64 examples. The cell can therefore be
# re-run after an interruption and continues where it stopped (test_dataloader has a fixed order).
ACCURACY_FILE = '/content/gdrive/MyDrive/accuracy.txt'
if not os.path.isfile(ACCURACY_FILE):
  # First run: start from zero counts instead of failing on a missing file.
  with open(ACCURACY_FILE, 'w') as f:
    f.write("0 0 0 0\n")
while True:
  with open(ACCURACY_FILE) as f:
    x = [int(i) for i in f.readlines()[0].split()]
  # >= also covers a stale file whose count exceeds the current test-set size.
  if x[0] >= len(test_dataset):
    break
  with torch.no_grad():
    n1, n10, n100, total = get_testing_accuracies(test_dataloader, x[0], 64)
  if total == 0:
    # No progress possible from this offset; stop instead of looping forever.
    break
  x[0] += total
  x[1] += n1
  x[2] += n10
  x[3] += n100
  with open(ACCURACY_FILE, 'w+') as f:
    f.write("{} {} {} {}\n".format(*x))
  print(*x, sep=' ')
print("Test Dataset Accuracy:")
if x[0] == 0:
  print("No test examples were evaluated.")
else:
  print("Top 1: {}%".format(100.0*x[1]/x[0]))
  print("Top 10: {}%".format(100.0*x[2]/x[0]))
  print("Top 100: {}%".format(100.0*x[3]/x[0]))

Test Dataset Accuracy:
Top 1: 48.70627199337611%
Top 10: 58.9525978058373%
Top 100: 65.61788449596357%


In [None]:
# Qualitative check: definitions from the first test batch with predicted vs. real words.
print_examples()

Some examples:
Definition:  in a mutual or shared manner
Our model's output: mutually
Real word: mutually
Definition:  morally reprehensible
Our model's output: pathetic
Real word: slimy
Definition:  a list of divisions ( chapters or articles ) and the pages on which they start
Our model's output: contents
Real word: contents
Definition:  not working properly
Our model's output: defective
Real word: bad
Definition:  not influenced or affected
Our model's output: unswayed
Real word: uninfluenced
Definition:  cause to sense ; make sensitive
Our model's output: sensitize
Real word: sensitize
Definition:  small fishes found in great schools along coasts of europe ; smaller and rounder than herring
Our model's output: pilchard
Real word: pilchard
Definition:  keep from happening or arising ; make impossible
Our model's output: preclude
Real word: prevent
Definition:  cut or tear along an irregular line so that the parts can later be matched for authentication
Our model's output: indent
Real

In [None]:
# Reverse-dictionary demo: predict the single closest word for a hand-written definition.
definition = "sport which uses bat and ball."
with torch.no_grad():
  print(get_word_from_single_def(definition, bert_tokenizer)[0])

bat


In [None]:
# Edge case: an empty definition; prints the 5 words nearest to whatever the model predicts for it.
definition = ""
with torch.no_grad():
  print(get_k_closest_words_from_single_def(definition, bert_tokenizer))

['invocation', 'benediction', 'paean', 'denunciation', 'ooooooooooooooooooooooooooooooooooooooo']
