In [1]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Not connected to a GPU


In [2]:
!pip install datasets
!python -m spacy download en_core_web_lg
import torch
import numpy as np
import spacy
import pandas as pd
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from datasets import load_dataset
from nltk.tokenize import RegexpTokenizer
import time
import math
from tqdm import tqdm
import os

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.5.2-py3-none-any.whl (432 kB)
[K     |████████████████████████████████| 432 kB 5.1 MB/s 
Collecting multiprocess
  Downloading multiprocess-0.70.13-py37-none-any.whl (115 kB)
[K     |████████████████████████████████| 115 kB 51.4 MB/s 
Collecting huggingface-hub<1.0.0,>=0.2.0
  Downloading huggingface_hub-0.10.0-py3-none-any.whl (163 kB)
[K     |████████████████████████████████| 163 kB 49.0 MB/s 
Collecting responses<0.19
  Downloading responses-0.18.0-py3-none-any.whl (38 kB)
Collecting xxhash
  Downloading xxhash-3.0.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)
[K     |████████████████████████████████| 212 kB 52.7 MB/s 
Collecting urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1
  Downloading urllib3-1.25.11-py2.py3-none-any.whl (127 kB)
[K     |████████████████████████████████| 127 kB 45.3 MB/s 
Installing collected packag

In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [14]:
dataset = load_dataset("europarl_bilingual", lang1="en", lang2="fr")



  0%|          | 0/1 [00:00<?, ?it/s]

In [5]:
def process_europarl_corpus(dataset):
  en_fr_sents = []
  for i,sent in enumerate(dataset["train"]["translation"]):
    en_fr_sents.append((sent["en"], sent["fr"]))
    # if i%1000==0:
    #   print(f"Currently on iteration {i}")
  return en_fr_sents

In [6]:
en_fr_sents = process_europarl_corpus(dataset)

In [7]:
en_fr_df = pd.DataFrame(en_fr_sents, columns=["en", "fr"])
en_fr_df = en_fr_df.loc[:10000]

In [8]:
en_fr_df.head()

Unnamed: 0,en,fr
0,Resumption of the session,Reprise de la session
1,I declare resumed the session of the European ...,Je déclare reprise la session du Parlement eur...
2,"Although, as you will have seen, the dreaded '...","Comme vous avez pu le constater, le grand ""bog..."
3,You have requested a debate on this subject in...,Vous avez souhaité un débat à ce sujet dans le...
4,"In the meantime, I should like to observe a mi...","En attendant, je souhaiterais, comme un certai..."


In [9]:
nlp = spacy.load("en_core_web_lg")

In [10]:
def preprocess(text, nltk_tokenizer):
  #cur_sent = " ".join(["cls"] + nltk_tokenizer.tokenize(text.lower()) + ["sep"])
  cur_sent = " ".join(["cls"] + [x for x in text.split() if x !=""] + ["sep"])
  return cur_sent

In [11]:
nltk_tokenizer = RegexpTokenizer(r"[\w\d'\s]+")
en_fr_df["en"] = en_fr_df["en"].apply(lambda x: preprocess(x, nltk_tokenizer))
en_fr_df["fr"] = en_fr_df["fr"].apply(lambda x: preprocess(x, nltk_tokenizer))

In [12]:
def generate_mappings(en_fr_df, col_name, batch_size=10000):
  sents = list(set(en_fr_df[col_name].tolist()))
  sents.sort()
  latest_token_index = 0
  sent_to_tokens = {}
  token_to_index = {}
  token_vector_list = []
  seen_tokens = set()
  cls_doc = nlp("cls", disable=["parser","ner"])
  seen_tokens.add(cls_doc[0].text)
  token_to_index[cls_doc[0].text] = latest_token_index
  latest_token_index += 1
  token_vector_list.append(torch.Tensor(cls_doc[0].vector))
  sep_doc = nlp("sep", disable=["parser","ner"])
  seen_tokens.add(sep_doc[0].text)
  token_to_index[sep_doc[0].text] = latest_token_index
  latest_token_index += 1
  token_vector_list.append(torch.Tensor(sep_doc[0].vector))
  start_time = time.time()
  num_sents = len(sents)
  num_batches = math.ceil(num_sents/batch_size)
  print(f"Number of sentences: {num_sents}")
  print(f"Number of batches: {num_batches}")
  for batch in range(1,num_batches+1):
    batch_start_index = (batch-1)*batch_size
    docs = nlp.pipe(sents[batch_start_index:min(batch*batch_size, num_sents)], disable=["parser", "ner"])
    for i,doc in enumerate(docs):
      cur_sent_tokens = []
      for token in doc:
        cur_sent_tokens.append(token)
        if token.text not in seen_tokens:
          seen_tokens.add(token.text)
          token_to_index[token.text] = latest_token_index
          latest_token_index += 1
          token_vector_list.append(torch.Tensor(token.vector))
        sent_to_tokens[sents[batch_start_index+i]] = cur_sent_tokens
    end_time = time.time()
    total_time = end_time-start_time
    print(f"Total time to finish batch {batch} is: {total_time/60} minutes")
    start_time = time.time()
  token_vectors = torch.stack(token_vector_list)
  index_to_token = {i:token for token,i in token_to_index.items()}
  return sent_to_tokens,token_to_index,index_to_token,token_vectors

In [None]:
fr_sent_to_tokens,fr_token_to_index,fr_index_to_token,fr_token_vectors = generate_mappings(en_fr_df, "fr", batch_size=100000)
en_sent_to_tokens,en_token_to_index,en_index_to_token,en_token_vectors = generate_mappings(en_fr_df, "en", batch_size=100000)

Number of sentences: 9903
Number of batches: 1
Total time to finish batch 1 is: 0.5529534180959066 minutes
Number of sentences: 9902
Number of batches: 1
Total time to finish batch 1 is: 0.3870299736658732 minutes


In [None]:
# x=[i for i in fr_sent_to_tokens.keys() if "madame" in i and "lynne" in i][0]
# print(x)
# fr_sent_to_tokens[x]

In [None]:
fr_token_vectors.size(),len(fr_token_to_index),len(fr_sent_to_tokens),en_token_vectors.size(),len(en_token_to_index),len(en_sent_to_tokens)

(torch.Size([16478, 300]), 16478, 9903, torch.Size([11651, 300]), 11651, 9902)

In [None]:
class MTDataset(Dataset):
  def __init__(self,df,max_len,fr_sent_to_tokens,fr_token_to_index,en_sent_to_tokens,en_token_to_index,fr_token_vectors,en_token_vectors):
    self.df = df
    self.max_len = max_len
    self.fr_sent_to_tokens = fr_sent_to_tokens
    self.fr_token_to_index = fr_token_to_index
    self.en_sent_to_tokens = en_sent_to_tokens
    self.en_token_to_index = en_token_to_index
    self.fr_token_vectors = fr_token_vectors
    self.en_token_vectors = en_token_vectors
  def __len__(self):
    return len(self.df)
  def __getitem__(self, idx):
    cur_en_sent = self.df.loc[idx,"en"]
    cur_fr_sent = self.df.loc[idx,"fr"]
    cur_en_tokens = self.en_sent_to_tokens[cur_en_sent]
    cur_fr_tokens = self.fr_sent_to_tokens[cur_fr_sent]
    if (len(cur_en_tokens))>self.max_len or (len(cur_fr_tokens))>self.max_len:
      raise Exception("The input or target sentence is more than max len tokens")
    inputs = torch.stack([torch.Tensor(self.en_token_vectors[self.en_token_to_index[token.text]]) for token in cur_en_tokens])
    targets = torch.stack([torch.Tensor(self.fr_token_vectors[self.fr_token_to_index[token.text]]) for token in cur_fr_tokens[:-1]])
    labels = torch.LongTensor([self.fr_token_to_index[token.text] for token in cur_fr_tokens[1:]])
    return {
        "inputs":inputs,
        "targets": targets,
        "input_seq_len": torch.LongTensor([len(inputs)]),
        "target_seq_len": torch.LongTensor([len(targets)]),
        "labels": labels
    }

In [None]:
def collate_fn(batch):
  ignore_index = -1
  inputs = nn.utils.rnn.pad_sequence([batch[i]["inputs"] for i in range(len(batch))], batch_first=True)
  input_seq_lens = torch.stack([batch[i]["input_seq_len"] for i in range(len(batch))]).squeeze()
  targets = nn.utils.rnn.pad_sequence([batch[i]["targets"] for i in range(len(batch))], batch_first=True)
  target_seq_lens = torch.stack([batch[i]["target_seq_len"] for i in range(len(batch))]).squeeze()
  labels = nn.utils.rnn.pad_sequence([batch[i]["labels"] for i in range(len(batch))], batch_first=True, padding_value=ignore_index)
  return inputs,targets,input_seq_lens,target_seq_lens,labels

In [None]:
def auto_reg_collate_fn(batch):
  max_len=512
  ignore_index = -1
  inputs = nn.utils.rnn.pad_sequence([batch[i]["inputs"] for i in range(len(batch))], batch_first=True)
  input_seq_lens = torch.stack([batch[i]["input_seq_len"] for i in range(len(batch))]).squeeze()
  targets = torch.stack([batch[i]["targets"][0] for i in range(len(batch))])
  labels = nn.utils.rnn.pad_sequence([batch[i]["labels"] for i in range(len(batch))], batch_first=True, padding_value=ignore_index)
  ignore_index_vals = torch.full((1,max_len-1-labels.size()[-1]), ignore_index).expand(len(batch), -1)
  labels = torch.cat([labels, ignore_index_vals], dim=-1)
  return inputs,targets,input_seq_lens,labels

In [None]:
max_len = 512
batch_size = 64
mtds = MTDataset(en_fr_df,max_len,fr_sent_to_tokens,fr_token_to_index,en_sent_to_tokens,en_token_to_index,fr_token_vectors,en_token_vectors)
train_samples = int(0.8*len(mtds))
valid_samples = int(0.1*len(mtds))
test_samples = len(mtds) - (train_samples+valid_samples)
train_ds,valid_ds,test_ds = random_split(mtds, [train_samples, valid_samples, test_samples])
train_dataloader = DataLoader(train_ds, batch_size=batch_size, collate_fn=collate_fn)
valid_dataloader = DataLoader(valid_ds, batch_size=batch_size, collate_fn=auto_reg_collate_fn)
test_dataloader = DataLoader(test_ds, batch_size=batch_size, collate_fn=auto_reg_collate_fn)

In [None]:
cur_batch = next(iter(valid_dataloader))
print(cur_batch[0].size(),cur_batch[1].size(),cur_batch[2].size(),cur_batch[3].size())

torch.Size([64, 83, 300]) torch.Size([64, 300]) torch.Size([64]) torch.Size([64, 511])


In [None]:
train_ds[0]

{'inputs': tensor([[  0.1240,  -0.0542,   0.1174,  ...,   0.0663,  -0.1964,   0.1107],
         [ -9.3107, -10.8970,   1.0395,  ...,   1.5233,   6.7447, -15.8290],
         [ -0.7692,  -1.6331,   1.6490,  ...,   0.0795,  -5.9882,   5.2371],
         ...,
         [ -1.5490,  -0.8127,  -3.4127,  ...,  -4.9873,   1.1062,   3.6366],
         [ -0.0765,  -4.6896,  -4.0431,  ...,   1.3040,  -0.5270,  -1.3622],
         [ -0.1954,  -1.7745,  -2.6836,  ...,  -2.3685,   0.7443,  -2.5710]]),
 'targets': tensor([[ 0.1240, -0.0542,  0.1174,  ...,  0.0663, -0.1964,  0.1107],
         [-0.2036, -2.4590,  0.2916,  ...,  2.6948, -0.2805,  0.3046],
         [-1.0661, -1.1376, -0.2675,  ..., -0.2280, -1.4617,  2.5464],
         ...,
         [-1.4330, -2.0019,  0.9980,  ..., -0.7067, -2.3276,  2.4603],
         [-0.9714, -0.5238,  0.5012,  ...,  1.1822,  0.1077, -0.1339],
         [-0.0765, -4.6896, -4.0431,  ...,  1.3040, -0.5270, -1.3622]]),
 'input_seq_len': tensor([16]),
 'target_seq_len': tensor([

In [None]:
class Encoder(nn.Module):
  def __init__(self, emb_dim, enc_hidden_dim, en_token_vectors, train_emb=False):
    super().__init__()
    # self.embedding_layer = nn.Embedding.from_pretrained(en_token_vectors)
    # if not train_emb:
    #   self.embedding_layer.weight.requires_grad = False
    self.GRU = nn.GRU(emb_dim, enc_hidden_dim)
    self.layernorm_layer = nn.LayerNorm(enc_hidden_dim)
  def forward(self, inputs, input_seq_lens):
    # x = self.embedding_layer(inputs)
    inputs = nn.utils.rnn.pack_padded_sequence(inputs, input_seq_lens.cpu(), batch_first=True, enforce_sorted=False)
    outputs, hidden = self.GRU(inputs)
    hidden = self.layernorm_layer(hidden)
    return outputs, hidden

In [None]:
class Decoder(nn.Module):
  def __init__(self, emb_dim, enc_hidden_dim, dec_hidden_dim, fr_token_vectors, train_emb=False):
    super().__init__()
    # self.embedding_layer = nn.Embedding.from_pretrained(fr_token_vectors)
    # if not train_emb:
    #   self.embedding_layer.weight.requires_grad = False
    self.GRU = nn.GRU(emb_dim+enc_hidden_dim, dec_hidden_dim)
    self.layernorm_layer = nn.LayerNorm(dec_hidden_dim)
    self.dense_layer = nn.Linear(dec_hidden_dim, len(fr_token_vectors))
    self.softmax_layer = nn.Softmax(dim=-1)
  def forward(self, context_vector, init_hidden_state, targets, target_seq_lens):
    # x = self.embedding_layer(targets)
    targets_with_context = torch.cat([targets, context_vector.unsqueeze(dim=1).expand(-1, targets.size()[1], -1)], dim=-1)
    #print(targets_with_context.size())
    targets_with_context = nn.utils.rnn.pack_padded_sequence(targets_with_context, target_seq_lens.cpu(), batch_first=True, enforce_sorted=False)
    # print(targets_with_context[0].size(), targets_with_context[1].size(), targets_with_context[0], targets_with_context[1])
    # print("after pack padded sequence")
    outputs, hidden = self.GRU(targets_with_context, init_hidden_state)
    # print("after gru")
    outputs, seq_lens = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
    # print(outputs.size(),hidden.size(),seq_lens)
    outputs = self.layernorm_layer(outputs)
    x = self.dense_layer(outputs)
    x = self.softmax_layer(x)
    return hidden,x

In [None]:
emb_dim = 300
enc_hidden_dim = 128
dec_hidden_dim = 128
enc = Encoder(emb_dim, enc_hidden_dim, en_token_vectors)
dec = Decoder(emb_dim, enc_hidden_dim, dec_hidden_dim, fr_token_vectors)
inputs,targets,input_seq_lens,target_seq_lens,labels = next(iter(train_dataloader))
print(inputs.size(),targets.size(),input_seq_lens.size(),target_seq_lens.size(),labels.size())
_, context = enc(inputs, input_seq_lens)
print(context.squeeze().size(), targets.size())
hidden_state,decoder_output = dec(context.squeeze(), context, targets, target_seq_lens)
decoder_output = decoder_output.permute(0,-1,1)
print(hidden_state.size(), decoder_output.size())

torch.Size([64, 72, 300]) torch.Size([64, 73, 300]) torch.Size([64]) torch.Size([64]) torch.Size([64, 73])
torch.Size([64, 128]) torch.Size([64, 73, 300])
torch.Size([1, 64, 128]) torch.Size([64, 16478, 73])


In [None]:
emb_dim = 300
enc_hidden_dim = 128
dec_hidden_dim = 128
enc = Encoder(emb_dim, enc_hidden_dim, en_token_vectors)
dec = Decoder(emb_dim, enc_hidden_dim, dec_hidden_dim, fr_token_vectors)
enc.to(device)
dec.to(device)
loss_fn = nn.CrossEntropyLoss(ignore_index=-1)
encoder_optimizer = optim.Adam(enc.parameters(), lr=1e-3)
decoder_optimizer = optim.Adam(dec.parameters(), lr=1e-3)
print(enc.parameters)
print(dec.parameters)

<bound method Module.parameters of Encoder(
  (GRU): GRU(300, 128)
  (layernorm_layer): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
)>
<bound method Module.parameters of Decoder(
  (GRU): GRU(428, 128)
  (layernorm_layer): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (dense_layer): Linear(in_features=128, out_features=16478, bias=True)
  (softmax_layer): Softmax(dim=-1)
)>


In [None]:
if os.path.exists("/content/sample_data/log.txt"):
    os.remove("/content/sample_data/log.txt")
teacher_forcing = True
epochs = 10
print(f"number of batches: {len(train_ds)/batch_size}")
train_losses = []
valid_losses = []
min_valid_epoch_loss = float("inf")
for epoch_num in range(epochs):
  print(f"EPOCH {epoch_num}")
  running_loss = 0
  loss = 0
  batch_num = 0
  epoch_start_time = time.time()
  batch_start_time = time.time()
  cur_batch_start_time = time.time()
  enc.train()
  dec.train()
  for batch_num,batch in enumerate(tqdm(train_dataloader)):
    inputs,targets,input_seq_lens,target_seq_lens,labels = next(iter(train_dataloader))
    inputs = inputs.to(device)
    targets = targets.to(device)
    input_seq_lens = input_seq_lens.to(device)
    target_seq_lens = target_seq_lens.to(device)
    labels = labels.to(device)
    _, context = enc(inputs, input_seq_lens)
    hidden_state,decoder_output = dec(context.squeeze(dim=0), context, targets, target_seq_lens)
    decoder_output = decoder_output.permute(0,-1,1)
    loss = loss_fn(decoder_output, labels)
    loss.backward()
    encoder_optimizer.step()
    decoder_optimizer.step()
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    if batch_num%500==0:
        batch_end_time = time.time()
        batch_total_time = batch_end_time-batch_start_time
        # print(f"BATCH_END,epoch:{epoch_num},batch:{batch_num},loss:{loss},time:{batch_total_time/60}")
        # print(decoder_output.size(),labels.size())
        with open("/content/sample_data/log.txt", "a+") as f:
          f.write(f"BATCH_END,epoch:{epoch_num},batch:{batch_num},loss:{loss},time:{batch_total_time/60}\n")
        batch_start_time = time.time()
    #print(f"Finished batch {batch_num} in time {(time.time()-cur_batch_start_time)/60} minutes")
    cur_batch_start_time = time.time()
    running_loss += loss
  train_epoch_loss = running_loss/len(train_dataloader)
  train_losses.append(train_epoch_loss)
  epoch_end_time = time.time()
  epoch_total_time = epoch_end_time-epoch_start_time
  print(f"EPOCH_END,epoch:{epoch_num},loss:{train_epoch_loss},time:{epoch_total_time/60}")
  with open("/content/sample_data/log.txt", "a+") as f:
    f.write(f"EPOCH_END,epoch:{epoch_num},loss:{train_epoch_loss},time:{epoch_total_time/60}\n")
  with torch.no_grad():
    enc.eval()
    dec.eval()
    print("Running Validation...")
    running_loss = 0
    epoch_start_time = time.time()
    for batch_num,batch in enumerate(tqdm(valid_dataloader)):
      inputs,targets,input_seq_lens,labels = batch
      inputs = inputs.to(device)
      targets = targets.to(device)
      input_seq_lens = input_seq_lens.to(device)
      labels = labels.to(device)
      _, context = enc(inputs, input_seq_lens)
      decoder_final_outputs = []
      for decoder_start_index in range(len(targets)):
        num_decoder_output_tokens = 1
        cur_predicted_token = None
        cur_context_vector = context[:,decoder_start_index,:].unsqueeze(dim=1)
        decoder_start_vector = targets[decoder_start_index].unsqueeze(0).unsqueeze(0)
        cur_input_vector = decoder_start_vector
        prev_hidden_state = cur_context_vector
        decoder_sequence_outputs = []
        while cur_predicted_token!="sep" and num_decoder_output_tokens<max_len:
          hidden_state,decoder_output = dec(cur_context_vector.squeeze(dim=0), prev_hidden_state, cur_input_vector, torch.LongTensor([1]))
          cur_predicted_token_index = int(torch.argmax(decoder_output[0,0,:]))
          cur_predicted_token = fr_index_to_token[cur_predicted_token_index]
          prev_hidden_state = hidden_state
          cur_input_vector = fr_token_vectors[cur_predicted_token_index].unsqueeze(0).unsqueeze(0)
          cur_input_vector.to(device)
          num_decoder_output_tokens += 1
          decoder_sequence_outputs.append(decoder_output)
        decoder_sequence_output = torch.cat(decoder_sequence_outputs, dim=1)
        pad_matrix = torch.full((1,max_len-num_decoder_output_tokens,decoder_sequence_output.size()[-1]), 0)
        decoder_sequence_output = torch.cat([decoder_sequence_output,pad_matrix], dim=1)
        decoder_final_outputs.append(decoder_sequence_output)
      decoder_final_output = torch.cat(decoder_final_outputs, dim=0)
      decoder_final_output = decoder_final_output.permute(0,-1,1)
      running_loss += loss_fn(decoder_final_output, labels)
    valid_epoch_loss = running_loss/len(valid_dataloader)
    valid_losses.append(valid_epoch_loss)
    if valid_epoch_loss < min_valid_epoch_loss:
      min_valid_epoch_loss = valid_epoch_loss
      print(f"Saving model after training epoch {epoch_num} with loss {min_valid_epoch_loss}...")
      with open("/content/sample_data/log.txt", "a+") as f:
        f.write(f"Saving model after training epoch {epoch_num} with loss {min_valid_epoch_loss}...\n")
      torch.save(enc.state_dict(), f"/content/sample_data/best_enc_{epoch_num}_{min_valid_epoch_loss}.bin")
      torch.save(dec.state_dict(), f"/content/sample_data/best_dec_{epoch_num}_{min_valid_epoch_loss}.bin")
    epoch_end_time = time.time()
    epoch_total_time = epoch_end_time-epoch_start_time
    print(f"VALID_EPOCH_END,loss:{valid_epoch_loss},time:{epoch_total_time/60}")
    with open("/content/sample_data/log.txt", "a+") as f:
      f.write(f"VALID_EPOCH_END,loss:{valid_epoch_loss},time:{epoch_total_time/60}\n")

number of batches: 125.0
EPOCH 0


 90%|████████▉ | 112/125 [12:58<01:29,  6.87s/it]

In [None]:
print(len(valid_ds), len(valid_dataloader))
with torch.no_grad():
    enc.eval()
    dec.eval()
    for batch_num,batch in enumerate(tqdm(valid_dataloader)):
      inputs,targets,input_seq_lens,labels = batch
      inputs = inputs.to(device)
      targets = targets.to(device)
      input_seq_lens = input_seq_lens.to(device)
      labels = labels.to(device)
      _, context = enc(inputs, input_seq_lens)
      decoder_final_outputs = []
      for decoder_start_index in range(len(targets)):
        num_decoder_output_tokens = 1
        cur_predicted_token = None
        cur_context_vector = context[:,decoder_start_index,:].unsqueeze(dim=1)
        decoder_start_vector = targets[decoder_start_index].unsqueeze(0).unsqueeze(0)
        cur_input_vector = decoder_start_vector
        prev_hidden_state = cur_context_vector
        decoder_sequence_outputs = []
        while cur_predicted_token!="sep" and num_decoder_output_tokens<max_len:
          hidden_state,decoder_output = dec(cur_context_vector.squeeze(dim=0), prev_hidden_state, cur_input_vector, torch.LongTensor([1]))
          cur_predicted_token_index = int(torch.argmax(decoder_output[0,0,:]))
          cur_predicted_token = fr_index_to_token[cur_predicted_token_index]
          prev_hidden_state = hidden_state
          cur_input_vector = fr_token_vectors[cur_predicted_token_index].unsqueeze(0).unsqueeze(0)
          cur_input_vector.to(device)
          num_decoder_output_tokens += 1
          decoder_sequence_outputs.append(decoder_output)
        decoder_sequence_output = torch.cat(decoder_sequence_outputs, dim=1)
        pad_matrix = torch.full((1,max_len-num_decoder_output_tokens,decoder_sequence_output.size()[-1]), 0)
        decoder_sequence_output = torch.cat([decoder_sequence_output,pad_matrix], dim=1)
        decoder_final_outputs.append(decoder_sequence_output)
      decoder_final_output = torch.cat(decoder_final_outputs, dim=0)
      print(decoder_final_output.size(), labels.size())

100 2


 50%|█████     | 1/2 [00:21<00:21, 21.91s/it]

torch.Size([64, 511, 4648]) torch.Size([64, 511])


100%|██████████| 2/2 [00:34<00:00, 17.00s/it]

torch.Size([36, 511, 4648]) torch.Size([36, 511])





In [None]:
teacher_forcing = True
epochs = 10
batch_size = 128
print(f"number of batches: {len(train_ds)/batch_size}")
for epoch in range(epochs):
  encoder_optimizer.zero_grad()
  decoder_optimizer.zero_grad()
  running_loss = 0
  loss = 0
  batch_num = 0
  epoch_start_time = time.time()
  batch_start_time = time.time()
  cur_batch_start_time = time.time()
  for i in range(len(train_ds)):
    inputs = train_ds[i]['inputs']
    targets = train_ds[i]['targets']
    labels = train_ds[i]['labels']
    inputs = inputs.to(device)
    targets = targets.to(device)
    labels = labels.to(device)
    #print(inputs, target_ids, labels)
    #print("===========")
    _, context = enc(inputs)
    hidden_state,decoder_output = dec(context, context, targets)
    #print("=======================")
    #print(decoder_outputs.size(), labels.size())
    #print(target_ids, torch.argmax(decoder_outputs, dim=-1),  labels)
    #print("=======================")
    #print(decoder_outputs.size(), labels.size())
    # if i % 1000==0:
    #   print(torch.gather(decoder_outputs, -1, labels.unsqueeze(dim=-1)))
    loss += loss_fn(decoder_output, labels)
    #loss.backward()
    # print("========BEFORE STEP")
    # print(enc.GRU.weight_ih_l0[280,50], enc.GRU.weight_ih_l0.grad[280,50])
    # print(dec.GRU.weight_ih_l0[280,50], dec.GRU.weight_ih_l0.grad[280,50])
    # if i%1000==0:
    #   print(hidden_state)
    # max_grad = torch.max(dec.dense_layer.weight.grad)
    # max_grad_pos = (dec.dense_layer.weight.grad==max_grad).nonzero()
    # max_grad_x = int(max_grad_pos[0][0])
    # max_grad_y = int(max_grad_pos[0][1])
    # print(dec.dense_layer.weight[max_grad_x,max_grad_y],dec.dense_layer.weight.grad[max_grad_x,max_grad_y], max_grad, max_grad_pos[0])
    # encoder_optimizer.step()
    # decoder_optimizer.step()
    # print("========AFTER STEP")
    # print(enc.GRU.weight_ih_l0[280,50], enc.GRU.weight_ih_l0.grad[280,50])
    # print(dec.GRU.weight_ih_l0[280,50], dec.GRU.weight_ih_l0.grad[280,50])
    # print(dec.dense_layer.weight[max_grad_x,max_grad_y],dec.dense_layer.weight.grad[max_grad_x,max_grad_y], max_grad, max_grad_pos[0])
    if (i+1) % batch_size == 0:
      running_loss += loss
      loss.backward()
      # print("========BEFORE STEP")
      # print(enc.GRU.weight_ih_l0[280,50], enc.GRU.weight_ih_l0.grad[280,50])
      # print(dec.GRU.weight_ih_l0[280,50], dec.GRU.weight_ih_l0.grad[280,50])
      encoder_optimizer.step()
      decoder_optimizer.step()
      # print("========AFTER STEP")
      # print(enc.GRU.weight_ih_l0[280,50], enc.GRU.weight_ih_l0.grad[280,50])
      # print(dec.GRU.weight_ih_l0[280,50], dec.GRU.weight_ih_l0.grad[280,50])
      encoder_optimizer.zero_grad()
      decoder_optimizer.zero_grad()
      if batch_num%100==0:
        batch_end_time = time.time()
        batch_total_time = batch_end_time-batch_start_time
        print(f"End of batch {batch_num}, current loss is {loss/batch_size}, total time taken: {batch_total_time/60}")
      # print(f"Finished batch {batch_num} in time {(time.time()-cur_batch_start_time)/60} minutes")
      cur_batch_start_time = time.time()
      loss = 0
      batch_num +=1
  running_loss += loss
  loss.backward()
  encoder_optimizer.step()
  decoder_optimizer.step()
  epoch_end_time = time.time()
  epoch_total_time = epoch_end_time-epoch_start_time
  print(f"End of epoch {epoch}, current loss is {running_loss/len(train_ds)}, total time taken: {epoch_total_time/60}")

number of batches: 1875.0
End of batch 0, current loss is 11.211583137512207, total time taken: 0.00957711140314738
End of batch 100, current loss is 11.079082489013672, total time taken: 0.7673563559850057
End of batch 200, current loss is 11.04337215423584, total time taken: 1.5181862950325011


KeyboardInterrupt: ignored

In [None]:
teacher_forcing = True
epochs = 10
batch_size = 256
print(f"number of batches: {len(train_ds)/batch_size}")
for epoch in range(epochs):
  encoder_optimizer.zero_grad()
  decoder_optimizer.zero_grad()
  running_loss = 0
  loss = 0
  batch_num = 0
  epoch_start_time = time.time()
  batch_start_time = time.time()
  cur_batch_start_time = time.time()
  for i in range(len(train_ds)):
    inputs = train_ds[i]['input_ids']
    target_ids = train_ds[i]['target_ids']
    labels = train_ds[i]['labels']
    inputs = inputs.to(device)
    target_ids = target_ids.to(device)
    labels = labels.to(device)
    #print(inputs, target_ids, labels)
    #print("===========")
    _, context = enc(inputs)
    prev_hidden_state = context
    prev_label = None
    decoder_outputs = []
    for j in range(len(target_ids)):
      cur_label = torch.Tensor(target_ids[j])
      hidden_state,decoder_output = dec(context, prev_hidden_state, cur_label)
      prev_hidden_state = hidden_state
      prev_label = torch.argmax(decoder_output)
      decoder_outputs.append(decoder_output)
    decoder_outputs = torch.cat(decoder_outputs, dim=0)
    #print("=======================")
    #print(decoder_outputs.size(), labels.size())
    #print(target_ids, torch.argmax(decoder_outputs, dim=-1),  labels)
    #print("=======================")
    #print(decoder_outputs.size(), labels.size())
    # if i % 1000==0:
    #   print(torch.gather(decoder_outputs, -1, labels.unsqueeze(dim=-1)))
    loss += loss_fn(decoder_outputs, labels)
    #loss.backward()
    # print("========BEFORE STEP")
    # print(enc.GRU.weight_ih_l0[280,50], enc.GRU.weight_ih_l0.grad[280,50])
    # print(dec.GRU.weight_ih_l0[280,50], dec.GRU.weight_ih_l0.grad[280,50])
    # if i%1000==0:
    #   print(hidden_state)
    # max_grad = torch.max(dec.dense_layer.weight.grad)
    # max_grad_pos = (dec.dense_layer.weight.grad==max_grad).nonzero()
    # max_grad_x = int(max_grad_pos[0][0])
    # max_grad_y = int(max_grad_pos[0][1])
    # print(dec.dense_layer.weight[max_grad_x,max_grad_y],dec.dense_layer.weight.grad[max_grad_x,max_grad_y], max_grad, max_grad_pos[0])
    # encoder_optimizer.step()
    # decoder_optimizer.step()
    # print("========AFTER STEP")
    # print(enc.GRU.weight_ih_l0[280,50], enc.GRU.weight_ih_l0.grad[280,50])
    # print(dec.GRU.weight_ih_l0[280,50], dec.GRU.weight_ih_l0.grad[280,50])
    # print(dec.dense_layer.weight[max_grad_x,max_grad_y],dec.dense_layer.weight.grad[max_grad_x,max_grad_y], max_grad, max_grad_pos[0])
    if (i+1) % batch_size == 0:
      running_loss += loss
      loss.backward()
      # print("========BEFORE STEP")
      # print(enc.GRU.weight_ih_l0[280,50], enc.GRU.weight_ih_l0.grad[280,50])
      # print(dec.GRU.weight_ih_l0[280,50], dec.GRU.weight_ih_l0.grad[280,50])
      encoder_optimizer.step()
      decoder_optimizer.step()
      # print("========AFTER STEP")
      # print(enc.GRU.weight_ih_l0[280,50], enc.GRU.weight_ih_l0.grad[280,50])
      # print(dec.GRU.weight_ih_l0[280,50], dec.GRU.weight_ih_l0.grad[280,50])
      encoder_optimizer.zero_grad()
      decoder_optimizer.zero_grad()
      if batch_num%100==0:
        batch_end_time = time.time()
        batch_total_time = batch_end_time-batch_start_time
        print(f"End of batch {batch_num}, current loss is {loss/batch_size}, total time taken: {batch_total_time/60}")
      print(f"Finished batch {batch_num} in time {(time.time()-cur_batch_start_time)/60} minutes")
      cur_batch_start_time = time.time()
      loss = 0
      batch_num +=1
  running_loss += loss
  loss.backward()
  encoder_optimizer.step()
  decoder_optimizer.step()
  epoch_end_time = time.time()
  epoch_total_time = epoch_end_time-epoch_start_time
  print(f"End of epoch {epoch}, current loss is {running_loss/len(train_ds)}, total time taken: {epoch_total_time/60}")

number of batches: 937.5
End of batch 0, current loss is 11.209644317626953, total time taken: 0.16027901967366537
Finished batch 0 in time 0.1602869470914205 minutes
Finished batch 1 in time 0.1357125719388326 minutes
Finished batch 2 in time 0.1436716914176941 minutes
Finished batch 3 in time 0.14378207127253215 minutes
Finished batch 4 in time 0.15283559958140056 minutes
Finished batch 5 in time 0.14501917362213135 minutes


KeyboardInterrupt: ignored

In [None]:
class LuongAttnDecoder(nn.Module):
  def __init__(self, emb_dim, enc_hidden_dim, dec_hidden_dim, fr_token_vectors, train_emb=False):
    super().__init__()
    self.embedding_layer = nn.Embedding.from_pretrained(fr_token_vectors)
    if not train_emb:
      self.embedding_layer.weight.requires_grad = False
    self.GRU = nn.GRU(emb_dim, dec_hidden_dim)
    self.dense_layer = nn.Linear(dec_hidden_dim+enc_hidden_dim, len(fr_token_vectors))
    self.attn = nn.Linear(enc_hidden_dim, enc_hidden_dim)
    self.softmax_layer = nn.Softmax(dim=-1)
  def forward(self, context_vectors, init_hidden_state, targets):
    x = self.embedding_layer(targets)
    _, hidden = self.GRU(x.unsqueeze(dim=0), init_hidden_state)
    context_vector_lin_out = self.attn(context_vectors)
    alignment_scores = context_vector_lin_out @ hidden.T
    context_vector = alignment_scores.T @ context_vectors
    x = torch.tanh(self.dense_layer(torch.cat([context_vector, hidden], dim=1)))
    x = self.softmax_layer(x)
    return hidden,x

In [None]:
class MultiHeadAttention(nn.Module):
  def __init__(self,emb_dim,n_heads):
    super().__init__()
    self.emb_dim = emb_dim
    self.n_heads = n_heads
    self.heads = []
    self.head_size = self.emb_dim/self.n_heads
    for i in torch.range(self.n_heads):
      cur_heads = [nn.Linear(self.emb_dim,self.head_size)]*3
      self.heads.append(cur_heads)
    self.softmax_layer = nn.Softmax(dim=-1)
    self.output_linear_layer = nn.Linear(emb_dim,emb_dim)
  def forward(self, Q, K, V, attn_mask):
    head_outputs = []
    for head in self.heads:
      Q_attn = self.head[0](Q)
      K_attn = self.head[1](K)
      V_attn = self.head[2](V)
      attn_scores = self.softmax_layer(((Q_attn @ K_attn.T) + attn_mask)/torch.sqrt(Q.size()[-1]))
      head_output = attn_scores @ V_attn
      head_outputs.append(head_output)
    head_result = torch.cat(head_outputs, dim=-1)
    x = self.output_linear_layer(head_result)
    return x

In [None]:
class PosEncoding(nn.Module):
  def __init__(self, n, d):
    super().__init__()
    self.n = n
    self.d = d
  def pos_encoding_denom(self,i):
    return self.n**(2*i/self.d)
  def forward(self, inputs):
    x_k = torch.arange(inputs.size()[0]).unsqueeze(dim=-1)
    x_k = x_k.expand(-1,self.emb_dim)
    x_i = torch.arange(inputs.size()[1]).unsqueeze(dim=0)
    x_i = torch.Tensor(list(map(self.pos_encoding_denom, torch.arange(inputs.size()[1])))).unsqueeze(dim=0)
    x_i = x_i.expand(inputs.size()[0],-1)
    x_pos = x_k/x_i
    even_indices = torch.arange(0,inputs.size()[1], step=2)
    odd_indices = torch.arange(1,inputs.size()[1], step=2)
    x_pos[:,even_indices] = torch.sin(x_pos[:,even_indices])
    x_pos[:,odd_indices] = torch.cos(x_pos[:,odd_indices])
    return inputs+x_pos

In [None]:
class AttnEncoder(nn.Module):
  def __init__(self, emb_dim):
    super().__init__()
    self.emb_dim = emb_dim
    self.pos_encoding = PosEncoding(10000,emb_dim)
    self.multi_head_attn = nn.MultiheadAttention(emb_dim, 3)
    self.layer_norm1 = nn.LayerNorm(emb_dim)
    self.linear_layer = nn.Linear(emb_dim, emb_dim)
    self.relu = nn.ReLU()
    self.layer_norm2 = nn.LayerNorm(emb_dim)
  def forward(self, inputs):
    inputs = self.pos_encoding(inputs)
    x = self.multi_head_attn(inputs, inputs, inputs)[0]
    x = self.layer_norm1(inputs+x)
    linear_output = self.linear_layer(x)
    feed_forward_output = self.relu(linear_output)
    x = self.layer_norm2(feed_forward_output+x)
    return x

In [None]:
class AttnDecoder(nn.Module):
  def __init__(self, emb_dim, encoder_output):
    self.emb_dim = emb_dim
    self.encoder_output = encoder_output
    self.pos_encoding = PosEncoding(10000,emb_dim)
    self.masked_multi_head_attn = nn.MultiheadAttention(emb_dim, 3)
    self.layer_norm1 = nn.LayerNorm(emb_dim)
    self.encoder_decoder_attn = nn.MultiheadAttention(emb_dim, 3)
    self.layer_norm2 = nn.LayerNorm(emb_dim)
    self.linear_layer1 = nn.Linear(emb_dim, emb_dim)
    self.relu = nn.ReLU()
    self.layer_norm3 = nn.LayerNorm(emb_dim)
    self.linear_layer2 = nn.Linear(emb_dim)
  def forward(self, inputs, attn_mask, timestep):
    inputs = self.pos_encoding(inputs)
    
    

SyntaxError: ignored

In [None]:
attn_enc = AttnEncoder(300)
x = attn_enc(torch.randn(10,300))
print(x.size())

torch.Size([10, 300])


In [None]:
emb_dim = 300
enc_hidden_dim = 128
dec_hidden_dim = 128
enc = Encoder(emb_dim, enc_hidden_dim, en_token_vectors)
dec = LuongAttnDecoder(emb_dim, enc_hidden_dim, dec_hidden_dim, fr_token_vectors)
for i in range(1):
    inputs = train_ds[i]['input_ids']
    target_ids = train_ds[i]['target_ids']
    labels = train_ds[i]['labels']
    inputs = inputs.to(device)
    target_ids = target_ids.to(device)
    labels = labels.to(device)
    outputs, context = enc(inputs)
    prev_hidden_state = context
    prev_label = None
    decoder_outputs = []
    for j in range(1):
      cur_label = torch.Tensor(target_ids[j])
      hidden_state,decoder_output = dec(outputs, prev_hidden_state, cur_label)
      print(decoder_output.size())

torch.Size([1, 896])
