In [None]:
import os
import pathlib

# Deep Learning
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision


# General
from tqdm import tqdm
import pandas as pd
import numpy as np

# Data visualization
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings('ignore')

# Device Setting
if torch.cuda.is_available():
  device = torch.device("cuda:0")
  print("GPU")
else:
  device = torch.device("cpu")
  print("CPU")

GPU


In [None]:
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
!pip install datasets tokenizers

# Huggingface
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
from transformers import BertTokenizer
from datasets import load_dataset


Collecting datasets
  Downloading datasets-2.18.0-py3-none-any.whl (510 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/510.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━[0m [32m307.2/510.5 kB[0m [31m8.9 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m510.5/510.5 kB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m8.5 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess (from datasets)
  Downloading multiprocess-0.7

# Bert Tokenizer

In [None]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
vocab_size = tokenizer.vocab_size

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

# Configuration

In [None]:
def get_config(num_epochs, vocab_size):
  return{
      "train_batch_size": 8,
      "test_batch_size": 1,
      "num_epochs": num_epochs,
      "lr": 3e-4,
      "lang_src": 'en',
      "lang_tgt": 'fr',
      "seq_len": 256,
      "d_model": 256,
      "h": 2,
      "depth" : 1,
      "dropout": 0.2,
      "num_classes": vocab_size
      }

config = get_config(num_epochs=10, vocab_size=50265)


# Dataset

In [None]:
raw_dataset = load_dataset('opus_books', f'{config["lang_src"]}-{config["lang_tgt"]}', split='train')

train_size = int(0.2 * len(raw_dataset))
val_size = len(raw_dataset) - train_size
train, val = random_split(raw_dataset, [train_size, val_size])



In [None]:

class BilingualDataset(Dataset):
  def __init__(self, data, src_tokenizer, tgt_tokenizer, config):
    super().__init__()

    self.data = data
    self.src_tokenizer = src_tokenizer
    self.tgt_tokenizer = tgt_tokenizer

    self.src_lang = config['lang_src']
    self.tgt_lang = config['lang_tgt']

    self.seq_len = config['seq_len']
    self.h = config['h']

  def __len__(self):
    return len(self.data)

  def __getitem__(self, idx):
    pairs = self.data[idx]

    source = pairs['translation'][self.src_lang]
    target = pairs['translation'][self.tgt_lang]

    h = self.h
    seq_len = self.seq_len

    # tokenized_source
    tokenized_source = self.src_tokenizer(source,
                                     max_length = self.seq_len,
                                     padding='max_length',
                                     truncation = True
                                     )

    # source ids
    source_ids = torch.tensor(tokenized_source['input_ids'], dtype= torch.long)

    # source masks
    source_masks = torch.tensor(tokenized_source['attention_mask'], dtype= torch.long).unsqueeze(0)
    source_masks = source_masks.repeat(1, h, 1)
    source_masks = source_masks.expand(seq_len, h, seq_len)
    source_masks = source_masks.transpose(0,1).contiguous()

    # tokenized_target
    tokenized_target = self.tgt_tokenizer(target,
                                     max_length = self.seq_len,
                                     padding='max_length',
                                     truncation = True
                                     )

    # target ids
    target_ids = torch.tensor(tokenized_target['input_ids'], dtype= torch.long)

    # target_masks
    target_masks = torch.tensor(tokenized_target['attention_mask'], dtype= torch.long).unsqueeze(0)
    target_masks = target_masks.repeat(1, h, 1)
    target_masks = target_masks.expand(seq_len, h, seq_len)
    target_masks = target_masks.transpose(0,1).contiguous().type(torch.int)
    # target_masks must be causally masked
    target_masks = target_masks & torch.tril(torch.ones(h, seq_len, seq_len)).type(torch.int)

    item = {
          'source_txt': source,
          'target_txt': target,
          'source_ids': source_ids,
          'source_masks': source_masks,
          'target_ids': target_ids,
          'target_masks': target_masks
    }

    return item

In [None]:
config = get_config(1, vocab_size)

src_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tgt_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

#  def __init__(self, data, src_tokenizer, tgt_tokenizer, config):

train_loader = DataLoader(
    BilingualDataset(train, src_tokenizer, tgt_tokenizer, config),
    batch_size = config['train_batch_size'],
    shuffle=True,
    pin_memory = True
)


val_loader = DataLoader(
    BilingualDataset(val, src_tokenizer, tgt_tokenizer, config),
    batch_size = config['test_batch_size'],
    shuffle=True,
    pin_memory = True
)

dataiter = iter(train_loader)
batch = next(dataiter)

for idx in range(len(batch['source_ids'])):
  print(f"source text: {batch['source_txt'][idx]}")
  print(f"target text: {batch['target_txt'][idx]}")
  print(f"source ids: {batch['source_ids'][idx]}")
  print(f"source masks: {batch['source_masks'].size()}")
  print(f"target ids: {batch['target_ids'][idx]}")
  print(f"target masks: {batch['target_masks'].size()}")


dataiter = iter(val_loader)
batch = next(dataiter)

for idx in range(len(batch['source_ids'])):
  print(f"source text: {batch['source_txt'][idx]}")
  print(f"target text: {batch['target_txt'][idx]}")
  print(f"source ids: {batch['source_ids'][idx]}")
  print(f"source masks: {batch['source_masks'].size()}")
  print(f"target ids: {batch['target_ids'][idx]}")
  print(f"target masks: {batch['target_masks'].size()}")


source text: 'In reality, that is my husband,' she said to herself; 'if I return in sincerity to the standards of prudence, it is obviously he that I ought to marry.'
target text: Au vrai, c’est là mon mari, se dit-elle ; si je reviens de bonne foi aux idées de sagesse, c’est évidemment lui que je dois épouser.
source ids: tensor([  101,  1005,  1999,  4507,  1010,  2008,  2003,  2026,  3129,  1010,
         1005,  2016,  2056,  2000,  2841,  1025,  1005,  2065,  1045,  2709,
         1999, 23997,  2000,  1996,  4781,  1997, 10975, 29424,  1010,  2009,
         2003,  5525,  2002,  2008,  1045, 11276,  2000,  5914,  1012,  1005,
          102,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,  

In [None]:
import math

class InputEmbedding(nn.Module):
  def __init__(self, vocab_size, d_model):
    super().__init__()
    self.vocab_size = vocab_size
    self.d_model = d_model
    # nn.Embedding is a dictionary kind of a layer that just maps number to the vector every time and this vector is learned by the model
    self.embedding = nn.Embedding(vocab_size, d_model)

  def forward(self, x):
    # According to "Attention is all you need", in the embedding layers, we need to multiply those weights by sqrt(emb_dim)
    return self.embedding(x) * (self.d_model ** 0.5)


class PositionalEmbedding(nn.Module):
  def __init__(self, d_model, seq_len):
    super().__init__()

    # pe = torch.zeros(seq_len, d_model)
    pe = torch.zeros(seq_len, d_model)

    pos = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
    div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000) / d_model))

    pe[:,0::2] = torch.sin(pos*div)
    pe[:,1::2] = torch.cos(pos*div)
    pe = pe.unsqueeze(0)
    self.register_buffer('pe', pe)

  def forward(self, x):
    # x = x + self.pe.requires_grad_(False)
    x = x + self.pe[:, :x.shape[1],:].requires_grad_(False)
    # return x.unsqueeze(0)
    return x


In [None]:

input_embedding = InputEmbedding(tokenizer.vocab_size, config['d_model'])
positional_encoding = PositionalEmbedding(config['d_model'], config['seq_len'])

train_ds = BilingualDataset(train, src_tokenizer, tgt_tokenizer, config)
print(train_ds.__getitem__(1)['source_txt'])
print(train_ds.__getitem__(1)['target_txt'])
print(train_ds.__getitem__(1)['source_ids'].size())

src_enc = input_embedding(train_ds.__getitem__(1)['source_ids'])
print(f"\nSource after embedded size: {src_enc.size()}")
src_pe = positional_encoding(src_enc)
print(f"Source after positional embedded size: {src_pe.size()}\n")


tgt_enc = input_embedding(train_ds.__getitem__(1)['source_ids'])
print(f"Target after embedded size: {tgt_enc.size()}")
tgt_pe = positional_encoding(tgt_enc)
print(f"Target after positional embedded size: {tgt_pe.size()}")


"Is he living with you?"
-- Est-ce qu'il demeure avec vous?
torch.Size([256])

Source after embedded size: torch.Size([256, 256])
Source after positional embedded size: torch.Size([1, 256, 256])

Target after embedded size: torch.Size([256, 256])
Target after positional embedded size: torch.Size([1, 256, 256])


In [None]:
class FeedForward(nn.Module):
  def __init__(self, config):
    super().__init__()
    self.d_model = config['d_model']
    self.d_ff = self.d_model * 4

    self.ffn = nn.Sequential(
        nn.Linear(self.d_model, self.d_ff),
        nn.ReLU(),
        nn.Linear(self.d_ff, self.d_model),
        nn.Dropout(config['dropout'])
    )

  def forward(self, x):
    return self.ffn(x)


In [None]:
class MultiHeadAttention(nn.Module):
  def __init__(self, config):
    super(MultiHeadAttention, self).__init__()

    self.d_model = config['d_model']
    self.h = config['h']

    assert self.d_model % self.h == 0, "d_model must be divisible by h"


    self.d_k = self.d_model // self.h
    # Query, key, value
    self.W_q = nn.Linear(self.d_model, self.d_model, bias=False)
    self.W_k = nn.Linear(self.d_model, self.d_model, bias=False)
    self.W_v = nn.Linear(self.d_model, self.d_model, bias=False)

    # Last Layer
    self.W_o = nn.Linear(self.d_model, self.d_model, bias=False)

    # dropout
    self.dropout = nn.Dropout(config['dropout'])

  def forward(self, q, k, v, mask):

    h = self.h
    d_k = self.d_k
    d_model = self.d_model

    query = self.W_q(q)
    key = self.W_k(k)
    value = self.W_v(v)

    # It's necessary since decode process will be (target, src_output, src_output, mask)
    # which means, q_seq_len and k_seq_len will be different.
    q_B, q_seq_len, _ = query.size()
    k_B, k_seq_len, _ = key.size()
    v_B, v_seq_len, _ = value.size()
    # size: (Batch, Seq_len, d_model) -> (Batch, Seq_len, h, d_model // h)
    query = query.view(q_B, q_seq_len, h, d_k)
    key = key.view(k_B, k_seq_len, h, d_k)
    value = value.view(v_B, v_seq_len, h, d_k)

    # (Batch, Seq_len, h, d_k) -> (Batch, h, Seq_len, d_k)
    # (Batch, h, Seq_len, d_k) -> (Batch * h, Seq_len, d_k)
    query = query.transpose(1,2).contiguous().view(q_B * h, q_seq_len, d_k)
    key = key.transpose(1,2).contiguous().view(k_B * h, k_seq_len, d_k)
    value = value.transpose(1,2).contiguous().view(v_B * h, v_seq_len, d_k)

    # Attention: W
    # paying attention to each sequences, therefore size should be (Batch *h, Seq_len, Seq_len)
    W = query @ key.transpose(1,2)
    W = W / (d_model ** 0.5)

    # If there is a mask, make masked spots -INF
    # seq_len must be equal to query's sequence length.
    if mask is not None:
      mask = mask.view(k_B * h, k_seq_len, k_seq_len) # (B, h, Seq_len, Seq_len) => (B * h, Seq_len, Seq_len)
      if q_seq_len != k_seq_len:
        mask = mask[:,:q_seq_len,:]
      W = W.masked_fill_(mask == 0, float('-inf'))

    W = W.softmax(dim = -1)
    # drop out
    W = self.dropout(W)

    out = W @ value # (B * h, seq_len, d_k)
    B, Seq_len, d_k = out.size()
    B = B // h
    out = out.view(B, h, Seq_len, d_k)
    out = out.transpose(1,2).contiguous().view(B, Seq_len, h * d_k)
    return self.W_o(out)


In [None]:
class EncoderBlock(nn.Module):
  def __init__(self, config):
    super(EncoderBlock, self).__init__()

    self.d_model = config['d_model']

    self.MultiHeadAttention = MultiHeadAttention(config)

    self.ln_1 = nn.LayerNorm(self.d_model)
    self.ln_2 = nn.LayerNorm(self.d_model)
    self.FeedForward = FeedForward(config)

  def forward(self, x, src_mask):
    x = x + self.MultiHeadAttention(x, x, x, src_mask)
    x = self.ln_1(x)
    x = x + self.FeedForward(x)
    x = self.ln_2(x)
    return x


In [None]:
class Encoder(nn.Module):
  def __init__(self, config):
    super(Encoder, self).__init__()

    self.depth = config['depth']
    # Encoder: blocks of encoder blocks
    self.blocks = nn.ModuleList([
        EncoderBlock(config) for _ in range(self.depth)
    ])
    self.blocks = nn.Sequential(*self.blocks)

  def forward(self, x, src_mask):
    for block in self.blocks:
      x = block(x, src_mask)
    return x




In [None]:
class DecoderBlock(nn.Module):
  def __init__(self, config):
    super(DecoderBlock, self).__init__()

    self.d_model = config['d_model']

    self.SelfHeadAttention = MultiHeadAttention(config)
    self.CrossHeadAttention = MultiHeadAttention(config)

    self.ln_1 = nn.LayerNorm(self.d_model)
    self.ln_2 = nn.LayerNorm(self.d_model)
    self.ln_3 = nn.LayerNorm(self.d_model)

    self.FeedForward = FeedForward(config)

  def forward(self, x, encoder_out, src_mask, tgt_mask):
    # x: target, in our case positively-toned comment
    x = x + self.SelfHeadAttention(x, x, x, tgt_mask)
    x = self.ln_1(x)
    x = x + self.CrossHeadAttention(x, encoder_out, encoder_out, src_mask)
    x = self.ln_2(x)
    x = x + self.FeedForward(x)
    x = self.ln_3(x)
    return x


In [None]:
class Decoder(nn.Module):
  def __init__(self, config):
    super(Decoder, self).__init__()

    self.depth = config['depth']

    self.blocks = nn.ModuleList([
        DecoderBlock(config) for _ in range(self.depth)
    ])
    self.blocks = nn.Sequential(*self.blocks)

  def forward(self, x, encoder_out, src_mask, tgt_mask):
    for block in self.blocks:
      x = block(x, encoder_out, src_mask, tgt_mask)
    return x

In [None]:
class Transformer(nn.Module):
  def __init__(self, config, vocab_size):
    super(Transformer, self).__init__()

    self.encoder = Encoder(config)
    self.decoder = Decoder(config)

    self.d_model = config['d_model']
    self.seq_len = config['seq_len']

    # Input Embedding for source and target
    self.src_embedding = InputEmbedding(vocab_size, self.d_model)
    self.tgt_embedding = InputEmbedding(vocab_size, self.d_model)

    # Positional Embedding for source and target
    self.src_pos_embedding = PositionalEmbedding(self.d_model, self.seq_len)
    self.tgt_pos_embedding = PositionalEmbedding(self.d_model, self.seq_len)

    self.norm = nn.LayerNorm(self.d_model)
    self.projection = nn.Linear(self.d_model, vocab_size)

  def encode(self, source, src_mask):
    source = self.src_embedding(source)
    source = self.src_pos_embedding(source)
    return self.encoder(source, src_mask)

  def decode(self, target, encoder_out, src_mask, tgt_mask):
    target = self.tgt_embedding(target)
    target = self.tgt_pos_embedding(target)
    return self.decoder(target, encoder_out, src_mask, tgt_mask)

  def forward(self, decoder_out):
    out = self.norm(decoder_out)
    out = self.projection(out)
    return torch.log_softmax(out, dim=-1)


In [None]:


# path for saving model
path = "./Translator"
pathlib.Path(f"./{path}/").mkdir(parents=True, exist_ok=True)


# hyperparameters
lr = config['lr']
num_epochs = 1

# configuration
vocab_size = tokenizer.vocab_size
config = get_config(1, vocab_size)

# model, optim, loss (cross entropy loss)
model = Transformer(config, vocab_size).to(device)
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])
ce_loss = nn.CrossEntropyLoss(ignore_index = 0, label_smoothing=0.1).to(device)

# training
for epoch in range(num_epochs):
  torch.cuda.empty_cache()
  model.train()
  batch_iterator = tqdm(train_loader, desc = f'Processing epoch {epoch:02d}')
  for batch in batch_iterator:
    source = batch['source_ids'].to(device)
    target = batch['target_ids'].to(device)
    src_mask = batch['source_masks'].to(device)
    tgt_mask = batch['target_masks'].to(device)
    B, seq_len = source.size()

    # forward pass
    # def encode(self, source, src_mask):
    encoder_out = model.encode(source, None)
    # def decode(self, target, encoder_out, src_mask, tgt_mask):
    decoder_out = model.decode(target, encoder_out, None, None)
    out = model.forward(decoder_out) # size:(Batch, Seq_len, tgt_vocab_size)

    out = out.view(B*seq_len, vocab_size)
    target = target.view(B*seq_len)

    loss = ce_loss(out, target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    batch_iterator.set_postfix(loss=f"{loss.item():6.3f}")
  torch.save(model.state_dict(), f'{path}/{epoch}.pth')

torch.save(model.state_dict(), f'{path}/final_model.pth')



Transformer(
  (encoder): Encoder(
    (blocks): Sequential(
      (0): EncoderBlock(
        (MultiHeadAttention): MultiHeadAttention(
          (W_q): Linear(in_features=256, out_features=256, bias=False)
          (W_k): Linear(in_features=256, out_features=256, bias=False)
          (W_v): Linear(in_features=256, out_features=256, bias=False)
          (W_o): Linear(in_features=256, out_features=256, bias=False)
          (dropout): Dropout(p=0.2, inplace=False)
        )
        (ln_1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (ln_2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (FeedForward): FeedForward(
          (ffn): Sequential(
            (0): Linear(in_features=256, out_features=1024, bias=True)
            (1): ReLU()
            (2): Linear(in_features=1024, out_features=256, bias=True)
            (3): Dropout(p=0.2, inplace=False)
          )
        )
      )
    )
  )
  (decoder): Decoder(
    (blocks): Sequential(
      (0): De

Processing epoch 00: 100%|██████████| 3178/3178 [04:52<00:00, 10.85it/s, loss=1.374]


In [None]:
#validation
with torch.no_grad():
  model.eval()
  batch_iterator = tqdm(val_loader)

  correct = 0
  total = 0
  for idx, batch in enumerate(batch_iterator, 0):
    source = batch['source_ids'].to(device)
    target = batch['target_ids'].to(device)
    src_mask = batch['source_masks'].to(device)
    tgt_mask = batch['target_masks'].to(device)

    B, seq_len = source.size()

    # forward pass
    # def encode(self, source, src_mask):
    encoder_out = model.encode(source, src_mask)
    # def decode(self, target, encoder_out, src_mask, tgt_mask):
    decoder_out = model.decode(target, encoder_out, src_mask, tgt_mask)
    out = model.forward(decoder_out)
    out = out.view(B*seq_len, vocab_size)

    target = target.view(B*seq_len)

    loss = ce_loss(out, target)

    pred = torch.max(out, dim=-1).indices

    total += target.shape[0]
    correct += sum(pred == target)
    batch_iterator.set_postfix(loss=f"{loss.item():6.3f}")

    if idx % 1000 ==0:
      txt_pred = tokenizer.decode(pred).replace("[PAD]", "")
      print(f"\nOriginal comments: {batch['source_txt']}\n")
      print(f"Generated comments: {txt_pred}\n")
      print(f"Target comments: {batch['target_txt']}\n")

print(f'\nAccuracy: {100 * (correct / total)}%')

  0%|          | 17/25417 [00:00<05:09, 82.07it/s, loss=0.000]


Original comments: ['"By no means.']

Generated comments: [CLS] – surtout pas! [SEP]                                                                                                                                                                                                                                                        

Target comments: ['– Surtout pas !']



  4%|▍         | 1013/25417 [00:12<04:50, 84.10it/s, loss=0.000]


Original comments: ['A pretty woman, on my word! and who must needs love me madly to have taken me in that fashion. By the way," said he, rising suddenly, with that sentiment of the true which formed the foundation of his character and his philosophy, "I don\'t know very well how it happens, but I am her husband!"']

Generated comments: [CLS] mon mauvais genie! mon bon ange! – une jolie femme, sur ma parole! – et qui doit m ’ aimer a la folie pour m ’ avoir pris de la sorte. – a propos, dit - il en se levant tout a coup avec ce sentiment du vrai qui faisait le fond de son caractere et de sa philosophie, je ne sais trop comment cela se fait, mais je suis son mari! » [SEP]                                                                                                                                              

Target comments: ['Mon mauvais génie ! mon bon ange ! – Une jolie femme, sur ma parole ! – et qui doit m’aimer à la folie pour m’avoir pris de la sorte. – À propos, dit-il en s

  8%|▊         | 2011/25417 [00:25<04:38, 84.06it/s, loss=0.000]


Original comments: ['Some Brahmins, clad in all the sumptuousness of Oriental apparel, and leading a woman who faltered at every step, followed.']

Generated comments: [CLS] derriere eux, quelques brahmanes, dans toute la somptuosite de leur costume oriental, trainaient une femme qui se soutenait a peine. [SEP]                                                                                                                                                                                                                   

Target comments: ['Derrière eux, quelques brahmanes, dans toute la somptuosité de leur costume oriental, traînaient une femme qui se soutenait à peine.']



 12%|█▏        | 3014/25417 [00:37<04:21, 85.57it/s, loss=0.000]


Original comments: ['We may take it, therefore, that the letter was composed by an educated man who wished to pose as an uneducated one, and his effort to conceal his own writing suggests that that writing might be known, or come to be known, by you.']

Generated comments: [CLS] nous pouvons donc deduire que ce message a ete compose par un individu instruit qui voulait passer pour un homme du peuple : et le fait qu ’ il a voulu deguiser sa propre ecriture suggere que cette ecriture pouvait vous etre connue, ou vous devenir connue. [SEP]                                                                                                                                                                   

Target comments: ['Nous pouvons donc déduire que ce message a été composé par un individu instruit qui voulait passer pour un homme du peuple : et le fait qu’il a voulu déguiser sa propre écriture suggère que cette écriture pouvait vous être connue, ou vous devenir connue.']



 16%|█▌        | 4012/25417 [00:50<04:13, 84.44it/s, loss=0.000]


Original comments: ['He was not better dressed than of old, for I well knew the old brown suit that he wore.']

Generated comments: [CLS] ce netait pas qu'il fut mieux habille que jadis, car je reconnus le vieux costume brun qu'il portait. [SEP]                                                                                                                                                                                                                         

Target comments: ["Ce n\x92était pas qu'il fût mieux habillé que jadis, car je reconnus le vieux costume brun qu'il portait."]



 20%|█▉        | 5008/25417 [01:02<04:12, 80.85it/s, loss=0.000]


Original comments: ["'This is a bold and healthy mind,' he said to himself, 'but _corpus debile_ (a frail body)."]

Generated comments: [CLS] voila un esprit hardi et sain, se disait - il, mais corpus debile ( le corps est faible ). [SEP]                                                                                                                                                                                                                             

Target comments: ['Voilà un esprit hardi et sain, se disait-il, mais corpus debile (le corps est faible).']



 24%|██▎       | 6016/25417 [01:15<03:52, 83.30it/s, loss=0.000]


Original comments: ["'Brother,' says she, 'you may come if you please.'"]

Generated comments: [CLS] - - mon frere, dit - elle, tu peux rentrer s'il te plait. [SEP]                                                                                                                                                                                                                                      

Target comments: ["--Mon frère, dit-elle, tu peux rentrer s'il te plaît."]



 28%|██▊       | 7010/25417 [01:28<03:52, 79.09it/s, loss=0.000]


Original comments: ['She stooped to open the small iron gate, and she made haste to inspect anxiously the lonely spot.']

Generated comments: [CLS] elle ouvrit, en se penchant, une petite grille, et se hata d ’ inspecter avec inquietude le lieu solitaire. [SEP]                                                                                                                                                                                                                            

Target comments: ['Elle ouvrit, en se penchant, une petite grille, et se hâta d’inspecter avec inquiétude le lieu solitaire.']



 32%|███▏      | 8010/25417 [01:40<03:23, 85.69it/s, loss=0.007]


Original comments: ['I roused, and interested you, because I was so unlike _them_. Had you not been really amiable, you would have hated me for it; but in spite of the pains you took to disguise yourself, your feelings were always noble and just; and in your heart, you thoroughly despised the persons who so assiduously courted you.']

Generated comments: [CLS] vous etiez fatigue de ces femmes qui ne faisaient rien que pour obtenir votre approbation. c ’ est parce que je leur ressemblais si peu que j ’ ai eveille votre interet. [SEP]                                                                                                                                                                                                      

Target comments: ['Vous étiez fatigué de ces femmes qui ne faisaient rien que pour obtenir votre approbation. C’est parce que je leur ressemblais si peu que j’ai éveillé votre intéret.']



 35%|███▌      | 9013/25417 [01:52<03:19, 82.31it/s, loss=0.000]


Original comments: ['The link of the chain, forced open by him in circumstances, alas, so different, had not been mended.']

Generated comments: [CLS] le chainon, jadis force par lui en des circonstances, helas! si differentes, n ’ avait point ete raccommode. [SEP]                                                                                                                                                                                                                          

Target comments: ['Le chaînon, jadis forcé par lui en des circonstances, hélas ! si différentes, n’avait point été raccommodé.']



 39%|███▉      | 10012/25417 [02:05<03:03, 84.17it/s, loss=0.000]


Original comments: ['Doubtless he was accustomed to such reproaches, for he listened to me calm and smiling, with his arms crossed over his breast. Then, when he thought I had said all, he advanced toward me; I sprang toward the table, I seized a knife, I placed it to my breast.']

Generated comments: [CLS] « tout ce que le coeur d'une femme peut contenir de superbe mepris et de paroles dedaigneuses, je le versai sur cet homme ; sans doute, il etait habitue a de pareils reproches ; car il m'ecouta calme, souriant, et les bras croises sur la poitrine ; puis, lorsqu'il crut que j'avais tout dit, il s'avanca vers moi ; je bondis vers la table, je saisis un couteau, je l'appuyai sur ma poitrine. [SEP]                                                                                                                

Target comments: ["«Tout ce que le coeur d'une femme peut contenir de superbe mépris et de paroles dédaigneuses, je le versai sur cet homme; sans doute, il était habitué à de pare

 43%|████▎     | 11010/25417 [02:18<02:54, 82.55it/s, loss=0.000]


Original comments: ['It was remarkable, too, I had but three subjects, and they were of three different religions—my man Friday was a Protestant, his father was a Pagan and a cannibal, and the Spaniard was a Papist.']

Generated comments: [CLS] chose surtout remarquable! je n'avais que trois sujets et ils etaient de trois religions differentes : mon homme vendredi etait protestant, son pere etait idolatre et cannibale, et l'espagnol etait papiste. [SEP]                                                                                                                                                                                            

Target comments: ["Chose surtout remarquable! je n'avais que trois sujets et ils étaient de trois religions différentes: Mon homme Vendredi était protestant, son père était idolâtre et cannibale, et l'Espagnol était papiste."]



 47%|████▋     | 12012/25417 [02:31<02:46, 80.69it/s, loss=0.000]


Original comments: ['She replied in her cheerful way, without blushing:']

Generated comments: [CLS] elle repondit de son air gai, sans rougeur : [SEP]                                                                                                                                                                                                                                                

Target comments: ['Elle répondit de son air gai, sans rougeur:']



 51%|█████     | 13014/25417 [02:43<02:29, 83.02it/s, loss=0.000]


Original comments: ["Like certain flocks of birds, whose speed they equal, these tuna swim in triangle formation, which prompted the ancients to say they'd boned up on geometry and military strategy."]

Generated comments: [CLS] ils nageaient en triangle, comme certaines troupes d'oiseaux dont ils egalaient la rapidite, ce qui faisait dire aux anciens que la geometrie et la strategie leur etaient familieres. [SEP]                                                                                                                                                                                                 

Target comments: ["Ils nageaient en triangle, comme certaines troupes d'oiseaux dont ils égalaient la rapidité, ce qui faisait dire aux anciens que la géométrie et la stratégie leur étaient familières."]



 55%|█████▌    | 14012/25417 [02:56<02:25, 78.35it/s, loss=0.000]


Original comments: ['Why?']

Generated comments: [CLS] dans quel but? [SEP]                                                                                                                                                                                                                                                         

Target comments: ['Dans quel but?']



 59%|█████▉    | 15011/25417 [03:08<02:01, 85.55it/s, loss=0.000]


Original comments: ['Mais ils étaient toujours sous la juridiction des Saints et ils ne tarderent pas a en avoir la preuve.']

Generated comments: [CLS] they soon had a proof, however, that they were still within the rifle of the saints. [SEP]                                                                                                                                                                                                                                           

Target comments: ['They soon had a proof, however, that they were still within the jurisdiction of the Saints.']



 63%|██████▎   | 16015/25417 [03:21<01:53, 82.62it/s, loss=0.000]


Original comments: ["A contract which binds me without putting you under any obligation is unfair, I must decline.'"]

Generated comments: [CLS] un engagement qui me lie sans vous obliger a rien n ’ est point egal, je le refuse. [SEP]                                                                                                                                                                                                                                     

Target comments: ['Un engagement qui me lie sans vous obliger à rien n’est point égal, je le refuse.']



 67%|██████▋   | 17009/25417 [03:34<01:41, 82.62it/s, loss=0.000]


Original comments: ['"It is the custom in war," said d’Artagnan, "why should it not be the custom in a duel?"']

Generated comments: [CLS] - - c'est l'habitude a la guerre, dit d'artagnan ; pourquoi ne serait - ce pas l'habitude dans un duel? [SEP]                                                                                                                                                                                                                        

Target comments: ["-- C'est l'habitude à la guerre, dit d'Artagnan; pourquoi ne serait-ce pas l'habitude dans un duel?"]



 71%|███████   | 18010/25417 [03:46<01:33, 78.84it/s, loss=0.000]


Original comments: ['It was impossible, therefore, to return every day to the Chimneys, and it was agreed that the little colony should camp under a hut of branches, so that the important operation could be followed night and day.']

Generated comments: [CLS] il ne fallait donc pas songer a revenir chaque jour aux cheminees, et il fut convenu que la petite colonie camperait sous une hutte de branchages, de maniere que l'importante operation fut suivie nuit et jour. [SEP]                                                                                                                                                                                             

Target comments: ["Il ne fallait donc pas songer à revenir chaque jour aux Cheminées, et il fut convenu que la petite colonie camperait sous une hutte de branchages, de manière que l'importante opération fût suivie nuit et jour."]



 75%|███████▍  | 19014/25417 [03:59<01:21, 78.12it/s, loss=0.000]


Original comments: ['Never had he been so madly in love.']

Generated comments: [CLS] jamais il n ’ avait ete aussi fou d ’ amour. [SEP]                                                                                                                                                                                                                                             

Target comments: ['Jamais il n’avait été aussi fou d’amour.']



 79%|███████▊  | 20011/25417 [04:12<01:05, 82.98it/s, loss=0.000]


Original comments: ['"Then I was jealous of Jean," thought he.']

Generated comments: [CLS] - - donc j'ai ete jaloux de jean, pensait - il. [SEP]                                                                                                                                                                                                                                         

Target comments: ["--Donc j'ai été jaloux de Jean, pensait-il."]



 83%|████████▎ | 21015/25417 [04:24<00:54, 81.15it/s, loss=0.000]


Original comments: ['"Oh, when I said I was alone," said Milady, hoping to make the novice talk by talking of herself, "it is not for want of friends in high places; but these friends themselves tremble before the cardinal. The queen herself does not dare to oppose the terrible minister.']

Generated comments: [CLS] - - oh! quand j'ai dit que j'etais seule, dit milady, esperant faire parler la novice en parlant d'elle - meme, ce n'est pas faute d'avoir aussi quelques connaissances haut placees ; mais ces connaissances tremblent elles - memes devant le cardinal : la reine elle - meme n'ose pas soutenir contre le terrible ministre ; j'ai la preuve que sa majeste, malgre son excellent coeur, a plus d'une fois ete obligee d'abandonner a la colere de son eminence les personnes qui l'avaient servie. [SEP]                                                                                       

Target comments: ["-- Oh! quand j'ai dit que j'étais seule, dit Milady, espérant faire parler la nov

 87%|████████▋ | 22014/25417 [04:37<00:39, 85.31it/s, loss=0.000]


Original comments: ['"I am as innocent as you are; and I can prove it."']

Generated comments: [CLS] je suis aussi innocent que vous et je puis le prouver. [SEP]                                                                                                                                                                                                                                            

Target comments: ['Je suis aussi innocent que vous et je puis le prouver.']



 91%|█████████ | 23017/25417 [04:49<00:29, 82.48it/s, loss=0.000]


Original comments: ['"Need I go further where every word is an agony?"']

Generated comments: [CLS] ai - je besoin d'en dire davantage, quand chaque mot est un supplice pour moi? [SEP]                                                                                                                                                                                                                                  

Target comments: ["Ai-je besoin d'en dire davantage, quand chaque mot est un supplice pour moi?"]



 94%|█████████▍| 24015/25417 [05:02<00:17, 82.02it/s, loss=0.000]


Original comments: ["The following morning at nine o'clock, when Julien came down from his prison to enter the great hall of the Law Courts, it was with the utmost difficulty that the gendarmes succeeded in clearing a passage through the immense crowd that packed the courtyard."]

Generated comments: [CLS] le lendemain a neuf heures, quand julien descendit de sa prison pour aller dans la grande salle du palais de justice, ce fut avec beaucoup de peine que les gendarmes parvinrent a ecarter la foule immense entassee dans la cour. [SEP]                                                                                                                                                                                             

Target comments: ['Le lendemain à neuf heures, quand Julien descendit de sa prison pour aller dans la grande salle du Palais de Justice, ce fut avec beaucoup de peine que les gendarmes parvinrent à écarter la foule immense entassée dans la cour.']



 98%|█████████▊| 25014/25417 [05:14<00:04, 84.70it/s, loss=0.000]


Original comments: ['Thus, we never see the true state of our condition till it is illustrated to us by its contraries, nor know how to value what we enjoy, but by the want of it.']

Generated comments: [CLS] ainsi nous ne voyons jamais le veritable etat de notre position avant qu'il n'ait ete rendu evident par des fortunes contraires, et nous n'apprecions nos jouissances qu'apres que nous les avons perdues. [SEP]                                                                                                                                                                                            

Target comments: ["Ainsi nous ne voyons jamais le véritable état de notre position avant qu'il n'ait été rendu évident par des fortunes contraires, et nous n'apprécions nos jouissances qu'après que nous les avons perdues."]



100%|██████████| 25417/25417 [05:20<00:00, 79.39it/s, loss=0.000]



Accuracy: 99.97456359863281%


In [None]:
dataiter = iter(val_loader)
batch = next(dataiter)
source = batch['source_ids'].to(device)
src_mask = batch['source_masks'].to(device)
encoder_out = model.encode(source, None)
src_txt = batch['source_txt']
print(batch['source_txt'])
print(batch['target_txt'])


['"What?"']
['-- Laquelle?']


In [None]:
batch = next(dataiter)

target = batch['target_ids'].to(device)
tgt_mask = batch['target_masks'].to(device)
print(batch['source_txt'])
print(batch['target_txt'])

B, seq_len = source.size()

# forward pass
# def encode(self, source, src_mask):

# def decode(self, target, encoder_out, src_mask, tgt_mask):
decoder_out = model.decode(target, encoder_out, None, None)
out = model.forward(decoder_out)
out = out.view(B*seq_len, vocab_size)

target = target.view(B*seq_len)

loss = ce_loss(out, target)

pred = torch.max(out, dim=-1).indices

total += target.shape[0]
correct += sum(pred == target)
batch_iterator.set_postfix(loss=f"{loss.item():6.3f}")

txt_pred = tokenizer.decode(pred).replace("[PAD]", "")
print(f"\nOriginal comments: {src_txt}\n")
print(f"Generated comments: {txt_pred}\n")
print(f"Target comments: {batch['target_txt']}\n")

['For a moment he stood panting, wiping his forehead, calming the bounds of his heart.']
["Il resta un instant encore, haletant, a s'essuyer le front, a calmer les bonds de son coeur."]

Original comments: ['"What?"']

Generated comments: [CLS] il resta un instant encore, haletant, a s'essuyer le front, a calmer les bonds de son coeur. [SEP] herb occupationpie herb herb herb herb occupationpie herb herb occupation occupation herb occupation herb herb herb herbpie herb occupation occupationpie occupation occupation herb herb herb occupation herb herb occupation herb herb herb herb herbpie herb herb herb herb herb occupation herb herb herbpie herb herb occupation herbast herb herb occupation herb herb herb herb herb herb herbpiepie herb occupation herbpieast herb herbpie herb herb herb occupationpie herb herb herb herb herb occupation herb herb herb occupation herb occupation herb herb herb herb herb occupation herb herbpie herb herb occupation herb herb occupation occupation occupation 

## INPUT EMBEDDING + POS EMBEDDING DIMENSION CHECK

In [None]:
import math

d_model = 256
seq_len = 256
d_k = d_model // h

# When generating target...
x = torch.empty(1,1).fill_(101).type_as(source).to(device)
print(f"source size: {x.size()}\n")

embedding = nn.Embedding(vocab_size, d_model).to(device)
x = embedding(x) * (d_model ** 0.5)
print(f"after embedding: {x.size()}\n")

pe = torch.zeros(seq_len, d_model).to(device)
pos = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000) / d_model))
pe[:,0::2] = torch.sin(pos*div)
pe[:,1::2] = torch.cos(pos*div)
pe = pe.unsqueeze(0)

x = x + pe[:, :x.shape[1],:]
print(f"after positional embedding: {x.size()}\n")

# when training...

print("TRAINING PROCESS")
dataiter = iter(train_loader)
batch = next(dataiter)

x = batch['source_ids'].to(device)
print(f"source size: {x.size()}\n")

embedding = nn.Embedding(vocab_size, d_model).to(device)
x = embedding(x) * (d_model ** 0.5)
print(f"after embedding: {x.size()}\n")

pe = torch.zeros(seq_len, d_model).to(device)
pos = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000) / d_model))
pe[:,0::2] = torch.sin(pos*div)
pe[:,1::2] = torch.cos(pos*div)
pe = pe.unsqueeze(0)

x = x + pe[:, :x.shape[1],:]
print(f"after positional embedding: {x.size()}\n")


# when validating...

print("VALIDATING PROCESS")
dataiter = iter(val_loader)
batch = next(dataiter)

x = batch['source_ids'].to(device)
print(f"source size: {x.size()}\n")

embedding = nn.Embedding(vocab_size, d_model).to(device)
x = embedding(x) * (d_model ** 0.5)
print(f"after embedding: {x.size()}\n")

pe = torch.zeros(seq_len, d_model).to(device)
pos = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000) / d_model))
pe[:,0::2] = torch.sin(pos*div)
pe[:,1::2] = torch.cos(pos*div)
pe = pe.unsqueeze(0)

x = x + pe[:, :x.shape[1],:]
print(f"after positional embedding: {x.size()}\n")



source size: torch.Size([1, 1])

after embedding: torch.Size([1, 1, 256])

after positional embedding: torch.Size([1, 1, 256])

TRAINING PROCESS
source size: torch.Size([4, 256])

after embedding: torch.Size([4, 256, 256])

after positional embedding: torch.Size([4, 256, 256])

VALIDATING PROCESS
source size: torch.Size([1, 256])

after embedding: torch.Size([1, 256, 256])

after positional embedding: torch.Size([1, 256, 256])



In [None]:
d_model = 256
seq_len = 256
d_k = d_model // h

ln1 = nn.LayerNorm(d_model).to(device)

x = torch.empty(1,5).fill_(101).type_as(source).to(device)
print(f"source size: {x.size()}\n")


# Positional Embedding for source and target
tgt_embedding = InputEmbedding(vocab_size, d_model).to(device)
tgt_pos_embedding = PositionalEmbedding(d_model, seq_len).to(device)


x = tgt_embedding(x)
print(f"after input embedding: {x.size()}\n")
x = tgt_pos_embedding(x)
print(f"after positional embedding: {x.size()}\n")


SHA_W_q = nn.Linear(d_model, d_model, bias=False).to(device)
SHA_W_k = nn.Linear(d_model, d_model, bias=False).to(device)
SHA_W_v = nn.Linear(d_model, d_model, bias=False).to(device)
SHA_W_o = nn.Linear(d_model, d_model, bias=False).to(device)

query = SHA_W_q(x)
key = SHA_W_k(x)
value = SHA_W_v(x)
print(f"query size: {query.size()}\n")
print(f"key size: {key.size()}\n")
print(f"value size: {value.size()}\n")

q_B, q_seq_len, _ = query.size()
k_B, k_seq_len, _ = key.size()
v_B, v_seq_len, _ = value.size()

query = query.view(q_B, q_seq_len, h, d_k)
key = key.view(k_B, k_seq_len, h, d_k)
value = value.view(v_B, v_seq_len, h, d_k)
print(f"query size: {query.size()}\n")
print(f"key size: {key.size()}\n")
print(f"value size: {value.size()}\n")

# size: (Batch, Seq_len, h, d_model // h) -> (B*h, seq_len, d_model//h)
query = query.transpose(1,2).contiguous().view(q_B * h, q_seq_len, d_k)
key = key.transpose(1,2).contiguous().view(k_B * h, k_seq_len, d_k)
value = value.transpose(1,2).contiguous().view(v_B * h, v_seq_len, d_k)

print(f"query size: {query.size()}\n")
print(f"key size: {key.size()}\n")
print(f"value size: {value.size()}\n")


W = query @ key.transpose(1, 2) / math.sqrt(d_k)
# print(f"W before masking:\n {W}\n")
print(f"W before masking size: {W.size()}\n")

decoder_mask = torch.tril(torch.ones(h, x.size(1), x.size(1))).unsqueeze(0).type(torch.int).to(device)
print(f"Decoder mask size: {decoder_mask.size()}\n")
decoder_mask = decoder_mask.view(q_B * h, q_seq_len, q_seq_len)
print(f"Mask:\n {decoder_mask}\n")
print(f"Decoder mask size: {decoder_mask.size()}\n")
W = W.masked_fill_(decoder_mask == 0, -1e9)
print(f"W after masking:\n {W}\n")


W = W.softmax(dim = -1)
out = W @ value
print(f"out size: {out.size()}\n")
B, seq_len, d_k = out.size()
B = B //h
out = out.view(B, h, seq_len, d_k)
out = out.transpose(1, 2).contiguous().view(B, seq_len, h * d_k)
print(f"out size: {out.size()}\n")

sha_out = SHA_W_o(out)
print(f"Self Head Attention out size: {sha_out.size()}\n")

x = x + sha_out
x = ln1(x)
print(f"x size: {x.size()}\n")

source size: torch.Size([1, 5])

after input embedding: torch.Size([1, 5, 256])

after positional embedding: torch.Size([1, 5, 256])

query size: torch.Size([1, 5, 256])

key size: torch.Size([1, 5, 256])

value size: torch.Size([1, 5, 256])

query size: torch.Size([1, 5, 2, 128])

key size: torch.Size([1, 5, 2, 128])

value size: torch.Size([1, 5, 2, 128])

query size: torch.Size([2, 5, 128])

key size: torch.Size([2, 5, 128])

value size: torch.Size([2, 5, 128])

W before masking size: torch.Size([2, 5, 5])

Decoder mask size: torch.Size([1, 2, 5, 5])

Mask:
 tensor([[[1, 0, 0, 0, 0],
         [1, 1, 0, 0, 0],
         [1, 1, 1, 0, 0],
         [1, 1, 1, 1, 0],
         [1, 1, 1, 1, 1]],

        [[1, 0, 0, 0, 0],
         [1, 1, 0, 0, 0],
         [1, 1, 1, 0, 0],
         [1, 1, 1, 1, 0],
         [1, 1, 1, 1, 1]]], device='cuda:0', dtype=torch.int32)

Decoder mask size: torch.Size([2, 5, 5])

W after masking:
 tensor([[[ 5.4661e+01, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+

## CHECK CROSS HEAD ATTENTION

In [None]:

# Source id size: torch.Size([1, 256])
# Source mask size: torch.Size([1, 4, 256, 256])

# Decoder input size: torch.Size([1, 1])
# Encoder output size: torch.Size([1, 256, 256])

ln2 = nn.LayerNorm(d_model).to(device)

# Cross-Head Attention
# self.CrossHeadAttention(x, encoder_out, encoder_out, src_mask)
CHA_W_q = nn.Linear(d_model, d_model, bias=False).to(device)
CHA_W_k = nn.Linear(d_model, d_model, bias=False).to(device)
CHA_W_v = nn.Linear(d_model, d_model, bias=False).to(device)
CHA_W_o = nn.Linear(d_model, d_model, bias=False).to(device)

query = CHA_W_q(x)
key = CHA_W_k(encoder_out)
value = CHA_W_v(encoder_out)

print(f"query size: {query.size()}\n")
print(f"key size: {key.size()}\n")
print(f"value size: {value.size()}\n")

q_B, q_seq_len, _ = query.size()
k_B, k_seq_len, _ = key.size()
v_B, v_seq_len, _ = value.size()
# size: (Batch, Seq_len, d_model) -> (Batch, Seq_len, h, d_model // h)
query = query.view(q_B, q_seq_len, h, d_k)
key = key.view(k_B, k_seq_len, h, d_k)
value = value.view(v_B, v_seq_len, h, d_k)
print(f"query size: {query.size()}\n")
print(f"key size: {key.size()}\n")
print(f"value size: {value.size()}\n")

# size: (Batch, Seq_len, h, d_model // h) -> (B*h, seq_len, d_model//h)
query = query.transpose(1,2).contiguous().view(q_B * h, q_seq_len, d_k)
key = key.transpose(1,2).contiguous().view(k_B * h, k_seq_len, d_k)
value = value.transpose(1,2).contiguous().view(v_B * h, v_seq_len, d_k)

print(f"query size: {query.size()}\n")
print(f"key size: {key.size()}\n")
print(f"value size: {value.size()}\n")

W = query @ key.transpose(1, 2) / math.sqrt(d_k)
print(f"W before masking size: {W.size()}\n")

src_mask = source_mask.view(k_B * h, k_seq_len, k_seq_len)
src_mask = src_mask[:,:q_seq_len,:]
print(f"Mask:\n {src_mask}\n")
print(f"Source mask size: {src_mask[:,:q_seq_len,:].size()}\n")
W = W.masked_fill_(src_mask == 0, -1e9)
print(f"W after masking size:\n {W.size()}\n")

W = W.softmax(dim = -1)
out = W @ value
print(f"out size: {out.size()}\n")
B, seq_len, d_k = out.size()
B = B //h
out = out.view(B, h, seq_len, d_k)
out = out.transpose(1, 2).contiguous().view(B, seq_len, h * d_k)
print(f"out size: {out.size()}\n")

cha_out = CHA_W_o(out)
print(f"Cross Head Attention out size: {cha_out.size()}\n")

x = x + cha_out
x = ln2(x)
print(f"x size: {x.size()}\n")

# # x = self.ln_2(x)
# # x = x + self.FeedForward(x)
# # x = self.ln_3(x)

query size: torch.Size([1, 2, 256])

key size: torch.Size([1, 256, 256])

value size: torch.Size([1, 256, 256])

query size: torch.Size([1, 2, 2, 128])

key size: torch.Size([1, 256, 2, 128])

value size: torch.Size([1, 256, 2, 128])

query size: torch.Size([2, 2, 128])

key size: torch.Size([2, 256, 128])

value size: torch.Size([2, 256, 128])

W before masking size: torch.Size([2, 2, 256])

Mask:
 tensor([[[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]],

        [[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]], device='cuda:0')

Source mask size: torch.Size([2, 2, 256])

W after masking size:
 torch.Size([2, 2, 256])

out size: torch.Size([2, 2, 128])

out size: torch.Size([1, 2, 256])

Cross Head Attention out size: torch.Size([1, 2, 256])

x size: torch.Size([1, 2, 256])



# FEEDFORWARD

In [None]:
ff = FeedForward(config).to(device)
ln3 = nn.LayerNorm(d_model).to(device)

x = x + ff(x)
x = ln3(x)
print(f"x size: {x.size()}\n")


x size: torch.Size([1, 2, 256])



#CHECK FOR TRAINING

In [None]:

print("TRAINING PROCESS")
dataiter = iter(train_loader)
batch = next(dataiter)

d_model = 256
seq_len = 256
d_k = d_model // h

ln1 = nn.LayerNorm(d_model).to(device)

source = batch['source_ids'].to(device)
target = batch['target_ids'].to(device)
src_mask = batch['source_masks'].to(device)
tgt_mask = batch['target_masks'].to(device)

# Positional Embedding for source and target
tgt_embedding = InputEmbedding(vocab_size, d_model).to(device)
tgt_pos_embedding = PositionalEmbedding(d_model, seq_len).to(device)


x = tgt_embedding(target)
print(f"after input embedding: {x.size()}\n")
x = tgt_pos_embedding(x)
print(f"after positional embedding: {x.size()}\n")


SHA_W_q = nn.Linear(d_model, d_model, bias=False).to(device)
SHA_W_k = nn.Linear(d_model, d_model, bias=False).to(device)
SHA_W_v = nn.Linear(d_model, d_model, bias=False).to(device)
SHA_W_o = nn.Linear(d_model, d_model, bias=False).to(device)

query = SHA_W_q(x)
key = SHA_W_k(x)
value = SHA_W_v(x)
print(f"query size: {query.size()}\n")
print(f"key size: {key.size()}\n")
print(f"value size: {value.size()}\n")

q_B, q_seq_len, _ = query.size()
k_B, k_seq_len, _ = key.size()
v_B, v_seq_len, _ = value.size()

query = query.view(q_B, q_seq_len, h, d_k)
key = key.view(k_B, k_seq_len, h, d_k)
value = value.view(v_B, v_seq_len, h, d_k)
print(f"query size: {query.size()}\n")
print(f"key size: {key.size()}\n")
print(f"value size: {value.size()}\n")

# size: (Batch, Seq_len, h, d_model // h) -> (B*h, seq_len, d_model//h)
query = query.transpose(1,2).contiguous().view(q_B * h, q_seq_len, d_k)
key = key.transpose(1,2).contiguous().view(k_B * h, k_seq_len, d_k)
value = value.transpose(1,2).contiguous().view(v_B * h, v_seq_len, d_k)

print(f"query size: {query.size()}\n")
print(f"key size: {key.size()}\n")
print(f"value size: {value.size()}\n")


W = query @ key.transpose(1, 2) / math.sqrt(d_k)
print(f"W before masking:\n {W}\n")
print(f"W size: {W.size()}\n")

print(f"Mask:\n {tgt_mask}\n")
print(f"Decoder mask size: {tgt_mask.size()}\n")
tgt_mask = tgt_mask.view(q_B* h, k_seq_len, k_seq_len) # (B, h, Seq_len, Seq_len) => (B * h, Seq_len, Seq_len)

W = W.masked_fill_(tgt_mask == 0, -1e9)
print(f"W after masking:\n {W}\n")


W = W.softmax(dim = -1)
out = W @ value
print(f"out size: {out.size()}\n")
B, seq_len, d_k = out.size()
B = B //h
out = out.view(B, h, seq_len, d_k)
out = out.transpose(1, 2).contiguous().view(B, seq_len, h * d_k)
print(f"out size: {out.size()}\n")

sha_out = SHA_W_o(out)
print(f"Self Head Attention out size: {sha_out.size()}\n")

x = x + sha_out
x = ln1(x)
print(f"x size: {x.size()}\n")

TRAINING PROCESS
after input embedding: torch.Size([4, 256, 256])

after positional embedding: torch.Size([4, 256, 256])

query size: torch.Size([4, 256, 256])

key size: torch.Size([4, 256, 256])

value size: torch.Size([4, 256, 256])

query size: torch.Size([4, 256, 2, 128])

key size: torch.Size([4, 256, 2, 128])

value size: torch.Size([4, 256, 2, 128])

query size: torch.Size([8, 256, 128])

key size: torch.Size([8, 256, 128])

value size: torch.Size([8, 256, 128])

W before masking:
 tensor([[[  78.6641,  132.3236, -138.1175,  ...,  -15.7316,  -16.5606,
           -16.8900],
         [  54.0509,   79.4072,   51.2084,  ...,  -52.3254,  -55.0169,
           -57.6886],
         [  51.9549,  118.6219,  163.1727,  ...,  -63.9407,  -65.0058,
           -65.7316],
         ...,
         [  42.6520,   24.7375,  -61.6832,  ...,  -14.1162,  -15.2030,
           -16.4563],
         [  42.9732,   23.9359,  -61.8729,  ...,  -13.5574,  -14.6552,
           -15.9211],
         [  43.3201,   23.