In [1]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import dataset

import numpy as np
import matplotlib.pyplot as plt


In [2]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention with optional causal masking.

  Queries, keys and values are each projected to ``n_heads`` heads of size
  ``d_k`` (the value size is assumed equal to ``d_k``). Attention is computed
  independently per head and the concatenated heads are projected back to
  ``d_model``.

  Args:
    d_k: size of each head's query/key/value vectors.
    d_model: size of the input and output feature vectors.
    n_heads: number of attention heads.
    max_len: side length of the pre-computed causal mask (only used when
      ``causal=True``).
    causal: if True, query position ``i`` may only attend to key positions
      ``<= i``.
  """

  def __init__(self, d_k, d_model, n_heads, max_len, causal=False):
    super().__init__()

    # Assume d_v = d_k
    self.d_k = d_k
    self.n_heads = n_heads

    self.key = nn.Linear(d_model, d_k * n_heads)
    self.query = nn.Linear(d_model, d_k * n_heads)
    self.value = nn.Linear(d_model, d_k * n_heads)

    # final linear layer
    self.fc = nn.Linear(d_k * n_heads, d_model)

    # causal mask: lower-triangular matrix of ones. torch.tril keeps the
    # diagonal, so every position can attend to itself and to earlier
    # positions; the caller is responsible for shifting inputs vs. targets.
    self.causal = causal
    if causal:
      cm = torch.tril(torch.ones(max_len, max_len))
      self.register_buffer(
          "causal_mask",
          cm.view(1, 1, max_len, max_len)
      )

  def forward(self, q, k, v, pad_mask=None):
    """Compute attention of ``q`` over ``k``/``v``.

    Args:
      q: query inputs, shape (N, T_output, d_model).
      k: key inputs, shape (N, T_input, d_model).
      v: value inputs, shape (N, T_input, d_model).
      pad_mask: optional (N, T_input) mask; key positions where it equals 0
        are excluded from attention.

    Returns:
      Tensor of shape (N, T_output, d_model).

    Raises:
      ValueError: if the module is causal and a sequence is longer than the
        ``max_len`` the causal mask was built for.

    Note:
      A query whose keys are all masked receives NaN weights, because the
      softmax is taken over a row containing only ``-inf``.
    """
    q = self.query(q) # N x T x (hd_k)
    k = self.key(k)   # N x T x (hd_k)
    v = self.value(v) # N x T x (hd_v)

    N = q.shape[0]
    T_output = q.shape[1]
    T_input = k.shape[1]

    if self.causal:
      # Fail early with a clear message instead of an opaque broadcasting
      # error from masked_fill further down.
      max_len = self.causal_mask.shape[-1]
      if T_output > max_len or T_input > max_len:
        raise ValueError(
            f"sequence length ({max(T_output, T_input)}) exceeds the "
            f"max_len ({max_len}) of the causal mask")

    # change the shape to:
    # (N, T, h, d_k) -> (N, h, T, d_k)
    # in order for matrix multiply to work properly
    q = q.view(N, T_output, self.n_heads, self.d_k).transpose(1, 2)
    k = k.view(N, T_input, self.n_heads, self.d_k).transpose(1, 2)
    v = v.view(N, T_input, self.n_heads, self.d_k).transpose(1, 2)

    # compute attention weights
    # (N, h, T, d_k) x (N, h, d_k, T) --> (N, h, T, T)
    attn_scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)
    if pad_mask is not None:
      # broadcast the per-key mask over heads and query positions
      attn_scores = attn_scores.masked_fill(
          pad_mask[:, None, None, :] == 0, float('-inf'))
    if self.causal:
      attn_scores = attn_scores.masked_fill(
          self.causal_mask[:, :, :T_output, :T_input] == 0, float('-inf'))
    attn_weights = F.softmax(attn_scores, dim=-1)

    # compute attention-weighted values
    # (N, h, T, T) x (N, h, T, d_k) --> (N, h, T, d_k)
    A = attn_weights @ v

    # reshape it back before final linear layer
    A = A.transpose(1, 2) # (N, T, h, d_k)
    A = A.contiguous().view(N, T_output, self.d_k * self.n_heads) # (N, T, h*d_k)

    # projection
    return self.fc(A)


In [3]:
class EncoderBlock(nn.Module):
  """One encoder layer: self-attention then a feed-forward network.

  Each sub-layer is wrapped as ``LayerNorm(x + sublayer(x))`` and dropout is
  applied to the output of the block.
  """

  def __init__(self, d_k, d_model, n_heads, max_len, dropout_prob=0.1):
    super().__init__()
    hidden_dim = d_model * 4  # width of the feed-forward hidden layer

    self.ln1 = nn.LayerNorm(d_model)
    self.ln2 = nn.LayerNorm(d_model)
    # encoder self-attention sees the whole sequence, hence no causal mask
    self.mha = MultiHeadAttention(
        d_k, d_model, n_heads, max_len, causal=False)
    feed_forward_layers = [
        nn.Linear(d_model, hidden_dim),
        nn.GELU(),
        nn.Linear(hidden_dim, d_model),
        nn.Dropout(dropout_prob),
    ]
    self.ann = nn.Sequential(*feed_forward_layers)
    self.dropout = nn.Dropout(p=dropout_prob)

  def forward(self, x, pad_mask=None):
    """Apply the block to ``x``; ``pad_mask`` is forwarded to the attention."""
    attended = self.mha(x, x, x, pad_mask)
    hidden = self.ln1(x + attended)
    hidden = self.ln2(hidden + self.ann(hidden))
    return self.dropout(hidden)


In [4]:
class DecoderBlock(nn.Module):
  """One decoder layer.

  Runs causal self-attention over the decoder input, then attention over the
  encoder output, then a feed-forward network. Each sub-layer is wrapped as
  ``LayerNorm(x + sublayer(x))`` and dropout is applied to the block output.
  """

  def __init__(self, d_k, d_model, n_heads, max_len, dropout_prob=0.1):
    super().__init__()
    hidden_dim = d_model * 4  # width of the feed-forward hidden layer

    self.ln1 = nn.LayerNorm(d_model)
    self.ln2 = nn.LayerNorm(d_model)
    self.ln3 = nn.LayerNorm(d_model)
    # mha1: causal self-attention; mha2: attention over the encoder output
    self.mha1 = MultiHeadAttention(
        d_k, d_model, n_heads, max_len, causal=True)
    self.mha2 = MultiHeadAttention(
        d_k, d_model, n_heads, max_len, causal=False)
    feed_forward_layers = [
        nn.Linear(d_model, hidden_dim),
        nn.GELU(),
        nn.Linear(hidden_dim, d_model),
        nn.Dropout(dropout_prob),
    ]
    self.ann = nn.Sequential(*feed_forward_layers)
    self.dropout = nn.Dropout(p=dropout_prob)

  def forward(self, enc_output, dec_input, enc_mask=None, dec_mask=None):
    """Apply the block.

    ``dec_mask`` is used for the self-attention keys and ``enc_mask`` for the
    encoder-output keys.
    """
    # self-attention on decoder input
    self_attended = self.mha1(dec_input, dec_input, dec_input, dec_mask)
    hidden = self.ln1(dec_input + self_attended)

    # queries come from the decoder, keys/values from the encoder output
    cross_attended = self.mha2(hidden, enc_output, enc_output, enc_mask)
    hidden = self.ln2(hidden + cross_attended)

    hidden = self.ln3(hidden + self.ann(hidden))
    return self.dropout(hidden)


In [5]:
class PositionalEncoding(nn.Module):
  """Adds fixed sinusoidal position encodings to a batch of embeddings.

  Even feature indices hold ``sin(pos * w_i)`` and odd indices hold
  ``cos(pos * w_i)`` with ``w_i = 10000 ** (-2i / d_model)``. The table is
  computed once for ``max_len`` positions and stored as the buffer ``pe``.

  Args:
    d_model: feature size of the embeddings (may be even or odd).
    max_len: number of positions to pre-compute.
    dropout_prob: dropout probability applied after adding the encodings.
  """

  def __init__(self, d_model, max_len=2048, dropout_prob=0.1):
    super().__init__()
    self.dropout = nn.Dropout(p=dropout_prob)

    position = torch.arange(max_len).unsqueeze(1)
    exp_term = torch.arange(0, d_model, 2)
    div_term = torch.exp(exp_term * (-math.log(10000.0) / d_model))
    angles = position * div_term  # (max_len, ceil(d_model / 2))
    pe = torch.zeros(1, max_len, d_model)
    pe[0, :, 0::2] = torch.sin(angles)
    # For odd d_model there is one fewer odd (cosine) column than even (sine)
    # columns, so trim the angles to avoid a shape mismatch.
    pe[0, :, 1::2] = torch.cos(angles[:, :d_model // 2])
    self.register_buffer('pe', pe)

  def forward(self, x):
    """Return ``dropout(x + pe[:, :T])`` for ``x`` of shape (N, T, D).

    Raises:
      ValueError: if ``T`` exceeds the ``max_len`` the table was built for.
    """
    # x.shape: N x T x D
    seq_len = x.size(1)
    if seq_len > self.pe.size(1):
      raise ValueError(
          f"sequence length ({seq_len}) exceeds max_len ({self.pe.size(1)})")
    x = x + self.pe[:, :seq_len, :]
    return self.dropout(x)


In [6]:
class Encoder(nn.Module):
  """Transformer encoder: embedding, positional encoding, encoder blocks.

  Returns one ``d_model``-sized vector per input token after a final layer
  normalisation; there is no classification head.
  """

  def __init__(self, vocab_size, max_len, d_k, d_model, n_heads, n_layers,
               dropout_prob):
    super().__init__()

    self.embedding = nn.Embedding(vocab_size, d_model)
    self.pos_encoding = PositionalEncoding(d_model, max_len, dropout_prob)

    layers = []
    for _ in range(n_layers):
      layers.append(
          EncoderBlock(d_k, d_model, n_heads, max_len, dropout_prob))
    # Only used as a container: forward() iterates over the blocks itself
    # because each block also needs the padding mask.
    self.transformer_blocks = nn.Sequential(*layers)
    self.ln = nn.LayerNorm(d_model)

  def forward(self, x, pad_mask=None):
    """Encode token ids ``x`` (N x T) into features (N x T x d_model)."""
    hidden = self.pos_encoding(self.embedding(x))
    for layer in self.transformer_blocks:
      hidden = layer(hidden, pad_mask)
    return self.ln(hidden)


In [7]:
class Decoder(nn.Module):
  """Transformer decoder: embedding, positional encoding, decoder blocks.

  After a final layer normalisation, a linear layer maps every position to
  ``vocab_size`` logits (many-to-many).
  """

  def __init__(self, vocab_size, max_len, d_k, d_model, n_heads, n_layers,
               dropout_prob):
    super().__init__()

    self.embedding = nn.Embedding(vocab_size, d_model)
    self.pos_encoding = PositionalEncoding(d_model, max_len, dropout_prob)

    layers = []
    for _ in range(n_layers):
      layers.append(
          DecoderBlock(d_k, d_model, n_heads, max_len, dropout_prob))
    # Only used as a container: forward() iterates over the blocks itself
    # because each block takes several arguments.
    self.transformer_blocks = nn.Sequential(*layers)
    self.ln = nn.LayerNorm(d_model)
    self.fc = nn.Linear(d_model, vocab_size)

  def forward(self, enc_output, dec_input, enc_mask=None, dec_mask=None):
    """Return logits of shape (N x T x vocab_size) for ``dec_input`` (N x T)."""
    hidden = self.pos_encoding(self.embedding(dec_input))
    for layer in self.transformer_blocks:
      hidden = layer(enc_output, hidden, enc_mask, dec_mask)
    return self.fc(self.ln(hidden))  # many-to-many


In [8]:
class Transformer(nn.Module):
  """Thin wrapper that chains an encoder and a decoder module."""

  def __init__(self, encoder, decoder):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder

  def forward(self, enc_input, dec_input, enc_mask, dec_mask):
    """Encode ``enc_input``, then decode ``dec_input`` against the result.

    ``enc_mask`` is passed to both the encoder and the decoder; ``dec_mask``
    is passed to the decoder only.
    """
    memory = self.encoder(enc_input, enc_mask)
    return self.decoder(memory, dec_input, enc_mask, dec_mask)


In [9]:
# Small throw-away model used only to smoke-test the architecture on random
# data; the vocabulary sizes are arbitrary. `encoder`, `decoder` and
# `transformer` are re-created further down once the tokenizer is known.
encoder = Encoder(vocab_size=20_000,
                  max_len=512,
                  d_k=16,
                  d_model=64,
                  n_heads=4,
                  n_layers=2,
                  dropout_prob=0.1)
decoder = Decoder(vocab_size=10_000,
                  max_len=512,
                  d_k=16,
                  d_model=64,
                  n_heads=4,
                  n_layers=2,
                  dropout_prob=0.1)
transformer = Transformer(encoder, decoder)


In [10]:
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# `transformer` holds these same two module objects, so moving them in place
# moves every parameter of the full model.
encoder.to(device)
decoder.to(device)


cuda:0


Decoder(
  (embedding): Embedding(10000, 64)
  (pos_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_blocks): Sequential(
    (0): DecoderBlock(
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln3): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mha1): MultiHeadAttention(
        (key): Linear(in_features=64, out_features=64, bias=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)
      )
      (mha2): MultiHeadAttention(
        (key): Linear(in_features=64, out_features=64, bias=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)
 

In [11]:
# Random encoder token ids: batch of 8 sequences of length 512.
xe = np.random.randint(0, 20_000, size=(8, 512))
xe_t = torch.tensor(xe).to(device)

# Random decoder token ids: batch of 8 sequences of length 256.
xd = np.random.randint(0, 10_000, size=(8, 256))
xd_t = torch.tensor(xd).to(device)

# Encoder padding mask: second half of every sequence is marked as padding (0).
maske = np.ones((8, 512))
maske[:, 256:] = 0
maske_t = torch.tensor(maske).to(device)

# Decoder padding mask: second half of every sequence is marked as padding (0).
maskd = np.ones((8, 256))
maskd[:, 128:] = 0
maskd_t = torch.tensor(maskd).to(device)

# Full forward pass; the displayed shape is (batch, decoder length, decoder vocab).
out = transformer(xe_t, xd_t, maske_t, maskd_t)
out.shape


torch.Size([8, 256, 10000])

In [12]:
out


tensor([[[-0.0838,  0.4689, -0.8988,  ..., -0.7771, -1.2173, -0.3510],
         [-1.3486,  0.1689,  0.5149,  ..., -0.5438, -1.0344, -0.8729],
         [-0.5519, -1.0766, -0.0190,  ...,  0.1202,  0.0863, -0.5699],
         ...,
         [-0.7001, -0.6699, -0.3183,  ...,  0.1169, -0.0726, -0.1699],
         [ 0.1086, -0.8313, -0.0122,  ...,  0.2362, -0.0886,  0.7051],
         [ 0.0644,  0.6151, -1.3603,  ...,  0.7313, -0.3213,  1.4813]],

        [[-0.0974,  1.6530, -0.6403,  ..., -0.4491, -0.8697,  0.5653],
         [-0.2173,  0.4646, -0.6453,  ..., -0.1123, -0.2016,  0.5505],
         [ 0.3307,  0.5438, -0.5280,  ..., -0.3066, -0.3113,  0.0534],
         ...,
         [ 1.0794, -0.0807, -1.2905,  ...,  0.8979,  0.3465,  0.5867],
         [-0.2698,  0.1874, -0.6488,  ...,  1.1859,  1.0347,  0.4326],
         [ 0.3200,  0.2562, -0.4384,  ...,  0.2891, -0.5363,  0.5783]],

        [[-0.7660,  0.3287, -0.3151,  ...,  0.3258, -0.8193,  0.1250],
         [ 0.6688,  1.7092,  0.4422,  ...,  0

In [14]:
# Download the English-Spanish sentence-pair file; -nc (no-clobber) skips the
# download when spa.txt already exists locally.
!wget -nc https://lazyprogrammer.me/course_files/nlp3/spa.txt


--2025-08-04 17:16:08--  https://lazyprogrammer.me/course_files/nlp3/spa.txt
Resolving lazyprogrammer.me (lazyprogrammer.me)... 172.64.80.1
Connecting to lazyprogrammer.me (lazyprogrammer.me)|172.64.80.1|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [text/plain]
Saving to: 'spa.txt'

     0K .......... .......... .......... .......... .......... 7.57M
    50K .......... .......... .......... .......... ..........  445K
   100K .......... .......... .......... .......... ..........  383K
   150K .......... .......... .......... .......... .......... 20.1M
   200K .......... .......... .......... .......... .......... 13.9M
   250K .......... .......... .......... .......... ..........  430K
   300K .......... .......... .......... .......... .......... 10.8M
   350K .......... .......... .......... .......... ..........  422K
   400K .......... .......... .......... .......... ..........  234M
   450K .......... .......... .......... .......... ..

In [15]:
import pandas as pd
# The file is tab-separated with no header row; column 0 holds the English
# sentence and column 1 the Spanish one (named 'en'/'es' in a later cell).
df = pd.read_csv('spa.txt', sep="\t", header=None)
df.head()


Unnamed: 0,0,1
0,Go.,Ve.
1,Go.,Vete.
2,Go.,Vaya.
3,Hi.,Hola.
4,Run!,¡Corre!


In [19]:
df.shape


(30000, 2)

In [17]:
df = df.iloc[:30_000] # keep only the first 30,000 pairs; training on the full file takes too long


In [18]:
# Name the columns and write a CSV (without the index) so the data can be
# loaded through the `datasets` CSV loader.
df.columns = ['en', 'es']
df.to_csv('spa.csv', index=None)


In [20]:
from datasets import load_dataset
# Loads spa.csv as a DatasetDict with a single 'train' split.
raw_dataset = load_dataset('csv', data_files='spa.csv')


  from .autonotebook import tqdm as notebook_tqdm
  return pd.read_csv(xopen(filepath_or_buffer, "rb", download_config=download_config), **kwargs)
Generating train split: 30000 examples [00:00, 426442.54 examples/s]


In [21]:
raw_dataset

DatasetDict({
    train: Dataset({
        features: ['en', 'es'],
        num_rows: 30000
    })
})

In [22]:
# 70/30 train/test split; the fixed seed makes the split reproducible.
split = raw_dataset['train'].train_test_split(test_size=0.3, seed=42)
split


DatasetDict({
    train: Dataset({
        features: ['en', 'es'],
        num_rows: 21000
    })
    test: Dataset({
        features: ['en', 'es'],
        num_rows: 9000
    })
})

In [23]:
from transformers import AutoTokenizer

# Only the tokenizer of this pretrained English->Spanish checkpoint is loaded
# here; the transformer itself is built and trained from scratch below.
model_checkpoint = "Helsinki-NLP/opus-mt-en-es"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)




In [24]:
# Tokenize one training pair to inspect what the tokenizer produces.
en_sentence = split["train"][0]["en"]
es_sentence = split["train"][0]["es"]

# Source text goes in positionally; target text is passed via `text_target`.
inputs = tokenizer(en_sentence)
targets = tokenizer(text_target=es_sentence)

# Show the target pieces; note the end-of-sentence token '</s>' appended last.
tokenizer.convert_ids_to_tokens(targets['input_ids'])


['▁Yo', '▁puedo', '▁arreglarlo', '.', '</s>']

In [25]:
es_sentence

'Yo puedo arreglarlo.'

In [26]:
# Longest source / target sequence (in tokens) kept after truncation.
max_input_length = 128
max_target_length = 128


def preprocess_function(batch):
    """Tokenize a batch of sentence pairs.

    ``batch['en']`` is tokenized as source text and ``batch['es']`` as target
    text. The returned object is the source encoding with the target token
    ids stored under the key ``"labels"``.
    """
    encoded = tokenizer(
        batch['en'],
        max_length=max_input_length,
        truncation=True,
    )

    # Targets are passed through `text_target`
    target_encoding = tokenizer(
        text_target=batch['es'],
        max_length=max_target_length,
        truncation=True,
    )

    encoded["labels"] = target_encoding["input_ids"]
    return encoded


In [27]:
# Tokenize both splits in batches and drop the raw text columns, leaving only
# 'input_ids', 'attention_mask' and 'labels'.
tokenized_datasets = split.map(
    preprocess_function,
    batched=True,
    remove_columns=split["train"].column_names,
)


Map: 100%|██████████| 21000/21000 [00:02<00:00, 7118.80 examples/s]
Map: 100%|██████████| 9000/9000 [00:01<00:00, 6751.29 examples/s]


In [28]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 21000
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 9000
    })
})

In [29]:
from transformers import DataCollatorForSeq2Seq

# Pads every field of a batch to a common length; as the sample batch below
# shows, inputs are padded with id 65000, masks with 0 and labels with -100.
data_collator = DataCollatorForSeq2Seq(tokenizer)


In [30]:
# Collate the first five training examples to inspect the padded batch.
batch = data_collator([tokenized_datasets["train"][i] for i in range(0, 5)])
batch.keys()


KeysView({'input_ids': tensor([[   33,    88,  9222,    48,     3,     0, 65000, 65000],
        [  552, 11490,     9,   310,   255,     3,     0, 65000],
        [  143,    31,   125,  1208,     3,     0, 65000, 65000],
        [ 1093,   220,  1890,    23,    48,     3,     0, 65000],
        [  124,    20,   100, 18422,    48,   141,     3,     0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([[  711,  1039, 44159,     3,     0,  -100,  -100,  -100],
        [ 2722, 18663,   239,   212,     3,     0,  -100,  -100],
        [  539,    43,   155,   960,     3,     0,  -100,  -100],
        [15165,  1250,   380,  3564,    36,  1016,     3,     0],
        [  350,     8, 19153,    29, 31326,     3,     0,  -100]])})

In [31]:
batch['input_ids']


tensor([[   33,    88,  9222,    48,     3,     0, 65000, 65000],
        [  552, 11490,     9,   310,   255,     3,     0, 65000],
        [  143,    31,   125,  1208,     3,     0, 65000, 65000],
        [ 1093,   220,  1890,    23,    48,     3,     0, 65000],
        [  124,    20,   100, 18422,    48,   141,     3,     0]])

In [32]:
batch['attention_mask']


tensor([[1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1]])

In [33]:
batch['labels']


tensor([[  711,  1039, 44159,     3,     0,  -100,  -100,  -100],
        [ 2722, 18663,   239,   212,     3,     0,  -100,  -100],
        [  539,    43,   155,   960,     3,     0,  -100,  -100],
        [15165,  1250,   380,  3564,    36,  1016,     3,     0],
        [  350,     8, 19153,    29, 31326,     3,     0,  -100]])

In [34]:
tokenizer.all_special_ids


[0, 1, 65000]

In [35]:
tokenizer.all_special_tokens


['</s>', '<unk>', '<pad>']

In [36]:
tokenizer('<pad>')


{'input_ids': [65000, 0], 'attention_mask': [1, 1]}

In [37]:
from torch.utils.data import DataLoader

# Training batches are reshuffled every epoch; validation keeps dataset order.
# The collator pads each batch on the fly.
train_loader = DataLoader(
    tokenized_datasets["train"],
    shuffle=True,
    batch_size=32,
    collate_fn=data_collator
)
valid_loader = DataLoader(
    tokenized_datasets["test"],
    batch_size=32,
    collate_fn=data_collator
)


In [38]:
# Sanity check: print the shape of every tensor in the first training batch.
# Inputs and labels are padded independently, so their lengths can differ.
for batch in train_loader:
  for k, v in batch.items():
    print("k:", k, "v.shape:", v.shape)
  break


k: input_ids v.shape: torch.Size([32, 9])
k: attention_mask v.shape: torch.Size([32, 9])
k: labels v.shape: torch.Size([32, 10])


In [39]:
tokenizer.vocab_size


65001

In [40]:
tokenizer.decode([60000])


'ѕэр'

In [45]:
for i in range(1000):
    print(tokenizer.decode([60000 + i]))


ѕэр
֖
אהוד
אמת
ישראל
שאול
؋
أ،
أحد
اسیدپاشی
الأحرار
الارض
الدول
الشعب
العرب
اله
الوطني
بازخوری
بالموئل
بصمت
دونم
رئيسية
ريم
سورة
سوريا
سکائی
سید
شریف
صالح
ضد
طلال
عبيد
فضای
قاموس
قبل
لبيك
للحجب
لله
متعلقة
مثلي
محسن
مدني
مسجد
مسعود
مصطلحات
معنية
مواضيع
ناصر
نزار
نواز
هل
هيثم
وادي
پرتیبھا
پکیج
چین
کم
বিজয়
ਭਾਰਤ
คิง
ེ
ဪဪ
ᄆ
ᠱ
ᠶ
ᢲ
ᴘʀᴏ
ṇḍ
ṓ
Ả
ἀήρ
ἐκεῖνος
‍✈️
’à
††††††††
′û
₴
←←←
←→
∎
−∞
∘
√§
√√
∠
∮
⊱
⎢
⎼
├Ą
▀
▣
▦
▬
▲一
►►►
◂
♦♦♦
♪♪♪♪
⚜
⚪️
⚽
✆
✩✩
✯
❄
❶
➍
⦶
もっと知る
もっと見る
よく
ク
ベド
マップ
三亚
不倫
与合作伙伴们共同开发的
丝
个
久保田智広
么
书
人
什么
作業療法
値
像
军队应该允许妇女在在战斗岗位中服役吗
到
包子
即时中国大陆映像
古城
另外
回到顶部
国家民族事务委员会经济发展司
国家统计局人口和社会科技统计司
多
大
天天看高清影视
奇摩輸入法
如
如意淘
学生宿舍
它应该是非法烧我们的国旗
当店へのアクセス
您支持单一支付者医保系统吗
抗原
折角
拼音输入法
支持机构
文敏
明故宫
更多范文
木
本平台是由
松田聖子
林
枚
毛泽东
汉语
注意
洪門
清
游戏
游泳
烤鸭
猴
玄師
电子邮箱
米津
納豆
繁體
结果
羊肉串
美国是否应该留在联合国
茜
虎
跟着我们
路
这样
重庆
金士顿技术支持
鈴木雅之
钅
钟
长城
陳典桓
隐私政策
青岛
面
韩文强
頨
饺子
鮥
麻辣烫
동현
되자
세계
안상홍님
어머니와
영생
예루살렘
예를
제
중국
집
찾으라
천국
코맨
하나님
한국어Č













##########
#^
#Євромайдан
#كلنا
#†
#社畜死亡かるた
$###
$±
+£
+»
==<
=Í
=’
>>¡
\\$
\\\\

In [41]:
# Register "<s>" as an extra special token; it is used as the decoder's
# start-of-sequence token and receives id 65001 (see the next cell).
tokenizer.add_special_tokens({"cls_token": "<s>"})


1

In [42]:
tokenizer("<s>")


{'input_ids': [65001, 0], 'attention_mask': [1, 1]}

In [43]:
tokenizer.vocab_size


65001

In [46]:
# Real model. `tokenizer.vocab_size` still reports 65001 after "<s>" was
# added (its id is 65001), hence the "+ 1" so the embeddings and the output
# layer cover that id as well.
encoder = Encoder(vocab_size=tokenizer.vocab_size + 1,
                  max_len=512,
                  d_k=16,
                  d_model=64,
                  n_heads=4,
                  n_layers=2,
                  dropout_prob=0.1)
decoder = Decoder(vocab_size=tokenizer.vocab_size + 1,
                  max_len=512,
                  d_k=16,
                  d_model=64,
                  n_heads=4,
                  n_layers=2,
                  dropout_prob=0.1)
transformer = Transformer(encoder, decoder)


In [47]:
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# `transformer` holds these same two module objects, so moving them in place
# moves every parameter of the full model.
encoder.to(device)
decoder.to(device)


cuda:0


Decoder(
  (embedding): Embedding(65002, 64)
  (pos_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_blocks): Sequential(
    (0): DecoderBlock(
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln3): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mha1): MultiHeadAttention(
        (key): Linear(in_features=64, out_features=64, bias=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)
      )
      (mha2): MultiHeadAttention(
        (key): Linear(in_features=64, out_features=64, bias=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)
 

In [48]:
# Loss and optimizer
# ignore_index=-100 matches the value the data collator pads labels with, so
# padded target positions do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=-100)
# Adam with its default hyper-parameters over all encoder + decoder weights.
optimizer = torch.optim.Adam(transformer.parameters())


In [49]:
from datetime import datetime
# A function to encapsulate the training loop
def _make_decoder_inputs(targets, start_token_id, pad_token_id):
  """Build teacher-forcing decoder inputs and their padding mask from labels.

  The labels are shifted one step to the right, `start_token_id` is placed in
  the first column, and every -100 (label padding) becomes `pad_token_id`.
  The mask is 0 wherever the decoder input equals `pad_token_id`, 1 elsewhere.
  """
  # torch.roll returns a new tensor, so `targets` itself is left untouched.
  dec_input = torch.roll(targets, shifts=1, dims=1)
  dec_input[:, 0] = start_token_id

  # also convert all -100 to pad token id
  dec_input = dec_input.masked_fill(dec_input == -100, pad_token_id)

  # make decoder input mask
  dec_mask = torch.ones_like(dec_input)
  dec_mask = dec_mask.masked_fill(dec_input == pad_token_id, 0)
  return dec_input, dec_mask


def _seq2seq_loss(model, criterion, batch, start_token_id):
  """Move one collated batch to `device`, run the model, return the loss."""
  # move data to the target device (enc_input, enc_mask, translation)
  batch = {k: v.to(device) for k, v in batch.items()}
  targets = batch['labels']
  dec_input, dec_mask = _make_decoder_inputs(
      targets, start_token_id, tokenizer.pad_token_id)

  outputs = model(
      batch['input_ids'], dec_input, batch['attention_mask'], dec_mask)
  # criterion expects N x V x T, the model returns N x T x V
  return criterion(outputs.transpose(2, 1), targets)


# A function to encapsulate the training loop
def train(model, criterion, optimizer, train_loader, valid_loader, epochs,
          start_token_id=65_001):
  """Train `model` and evaluate it on `valid_loader` after every epoch.

  Reads the module-level `device` and `tokenizer` (for `pad_token_id`).

  Args:
    model: called as ``model(enc_input, dec_input, enc_mask, dec_mask)`` and
      expected to return logits of shape N x T x V.
    criterion: loss called as ``criterion(logits_N_V_T, labels)``.
    optimizer: optimizer over the parameters of `model`.
    train_loader: iterable of dict batches with the keys 'input_ids',
      'attention_mask' and 'labels'.
    valid_loader: same format as `train_loader`, used for evaluation only.
    epochs: number of passes over `train_loader`.
    start_token_id: id placed at the first decoder input position; the
      default 65_001 is the id of the "<s>" token added to the tokenizer.

  Returns:
    Tuple ``(train_losses, test_losses)`` of NumPy arrays holding the mean
    batch loss per epoch. The model is left in eval mode on return.
  """
  train_losses = np.zeros(epochs)
  test_losses = np.zeros(epochs)

  for it in range(epochs):
    model.train()
    t0 = datetime.now()
    train_loss = []
    for batch in train_loader:
      # zero the parameter gradients
      optimizer.zero_grad()

      # Forward pass
      loss = _seq2seq_loss(model, criterion, batch, start_token_id)

      # Backward and optimize
      loss.backward()
      optimizer.step()
      train_loss.append(loss.item())

    # Get train loss and test loss
    train_loss = np.mean(train_loss)

    model.eval()
    test_loss = []
    # Evaluation needs no gradients; skipping the autograd graph saves
    # memory and time.
    with torch.no_grad():
      for batch in valid_loader:
        loss = _seq2seq_loss(model, criterion, batch, start_token_id)
        test_loss.append(loss.item())
    test_loss = np.mean(test_loss)

    # Save losses
    train_losses[it] = train_loss
    test_losses[it] = test_loss
    
    dt = datetime.now() - t0
    print(f'Epoch {it+1}/{epochs}, Train Loss: {train_loss:.4f}, \
      Test Loss: {test_loss:.4f}, Duration: {dt}')
  
  return train_losses, test_losses


In [50]:
# Train for 15 epochs; returns the per-epoch mean train and test losses.
train_losses, test_losses = train(
    transformer, criterion, optimizer, train_loader, valid_loader, epochs=15)


Epoch 1/15, Train Loss: 4.8972,       Test Loss: 3.8837, Duration: 0:00:18.651943
Epoch 2/15, Train Loss: 3.5531,       Test Loss: 3.3775, Duration: 0:00:18.949467
Epoch 3/15, Train Loss: 3.0734,       Test Loss: 3.0775, Duration: 0:00:21.209332
Epoch 4/15, Train Loss: 2.7223,       Test Loss: 2.8326, Duration: 0:00:21.984545
Epoch 5/15, Train Loss: 2.4304,       Test Loss: 2.6874, Duration: 0:00:21.848515
Epoch 6/15, Train Loss: 2.1925,       Test Loss: 2.5589, Duration: 0:00:22.421354
Epoch 7/15, Train Loss: 1.9998,       Test Loss: 2.4786, Duration: 0:00:19.781621
Epoch 8/15, Train Loss: 1.8387,       Test Loss: 2.4408, Duration: 0:00:19.436924
Epoch 9/15, Train Loss: 1.7099,       Test Loss: 2.4002, Duration: 0:00:18.441274
Epoch 10/15, Train Loss: 1.5984,       Test Loss: 2.3828, Duration: 0:00:17.535151
Epoch 11/15, Train Loss: 1.5056,       Test Loss: 2.3771, Duration: 0:00:17.470578
Epoch 12/15, Train Loss: 1.4296,       Test Loss: 2.3712, Duration: 0:00:17.582448
Epoch 13/15, 

In [51]:
# Try the trained model on one held-out English sentence from the test split.

input_sentence = split['test'][10]['en']
input_sentence


'Can I take a day off?'

In [52]:
# return_tensors='pt' yields PyTorch tensors with a leading batch dimension of 1.
enc_input = tokenizer(input_sentence, return_tensors='pt')
enc_input

{'input_ids': tensor([[1283,   33,  273,    8,  502,  843,   21,    0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}

In [53]:
# Decoding starts from the start-of-sequence token only.
dec_input_str = '<s>'

# The tokenizer appends '</s>' (id 0) after '<s>'; later cells slice it off
# with [:, :-1].
dec_input = tokenizer(text_target=dec_input_str, return_tensors='pt')
dec_input


{'input_ids': tensor([[65001,     0]]), 'attention_mask': tensor([[1, 1]])}

In [54]:
# NOTE(review): the results of .to() are not reassigned, so this relies on
# BatchEncoding.to() moving the tensors in place — confirm for the installed
# transformers version.
enc_input.to(device)
dec_input.to(device)
# One forward pass with only '<s>' as decoder input; [:, :-1] drops the
# trailing '</s>' that the tokenizer appended.
output = transformer(
    enc_input['input_ids'],
    dec_input['input_ids'][:, :-1],
    enc_input['attention_mask'],
    dec_input['attention_mask'][:, :-1],
)
output


tensor([[[ 2.4291, -7.1860,  5.8436,  ..., -9.2412, -9.2837, -7.8867]]],
       device='cuda:0', grad_fn=<ViewBackward0>)

In [55]:
output.shape # N x T x V


torch.Size([1, 1, 65002])

In [56]:
# Run the encoder on its own: its output depends only on the source sentence,
# so it can be computed once and reused for every decoding step.
enc_output = encoder(enc_input['input_ids'], enc_input['attention_mask'])
enc_output.shape


torch.Size([1, 8, 64])

In [57]:
# Run the decoder separately on the cached encoder output; the next cell
# checks that this matches the full `transformer` forward pass.
dec_output = decoder(
    enc_output,
    dec_input['input_ids'][:, :-1],
    enc_input['attention_mask'],
    dec_input['attention_mask'][:, :-1],
)
dec_output.shape


torch.Size([1, 1, 65002])

In [58]:
torch.allclose(output, dec_output)


True

In [59]:
# Start from '<s>' only ([:, :-1] drops the '</s>' appended by the tokenizer).
dec_input_ids = dec_input['input_ids'][:, :-1]
dec_attn_mask = dec_input['attention_mask'][:, :-1]

# Greedy decoding: generate at most 32 tokens, one per decoder pass.
# This is inference only, so no autograd graph is needed for the decoder calls.
with torch.no_grad():
  for _ in range(32):
    dec_output = decoder(
        enc_output,
        dec_input_ids,
        enc_input['attention_mask'],
        dec_attn_mask,
    )

    # choose the best value (or sample)
    prediction_id = torch.argmax(dec_output[:, -1, :], axis=-1)

    # append to decoder input
    dec_input_ids = torch.hstack((dec_input_ids, prediction_id.view(1, 1)))

    # recreate mask
    dec_attn_mask = torch.ones_like(dec_input_ids)

    # exit when reach </s>
    if prediction_id == 0:
      break


In [60]:
tokenizer.decode(dec_input_ids[0])


'<s> ¿Puedo tomar un día para galletas?</s>'

In [61]:
split['test'][10]['es']


'¿Puedo tomarme un día libre?'

In [62]:
def translate(input_sentence, max_new_tokens=32, start_token_id=65_001,
              eos_token_id=0):
  """Greedily translate `input_sentence` and print the result.

  Uses the module-level `tokenizer`, `encoder`, `decoder` and `device`. The
  encoder runs once; the decoder is then called repeatedly, each time
  appending the highest-scoring next token, until `eos_token_id` is produced
  or `max_new_tokens` tokens have been generated.

  Args:
    input_sentence: source text to translate.
    max_new_tokens: maximum number of tokens to generate.
    start_token_id: id the decoder input starts with; the default 65_001 is
      the id of the "<s>" token added to the tokenizer.
    eos_token_id: id that ends generation; the default 0 is the id of "</s>".

  Returns:
    None. The decoded translation (without the start token) is printed.

  Note:
    The train/eval mode of the models is not changed here; call `.eval()`
    first if dropout should be disabled.
  """
  # Inference only: skip building the autograd graph.
  with torch.no_grad():
    # get encoder output first
    enc_input = tokenizer(input_sentence, return_tensors='pt').to(device)
    enc_output = encoder(enc_input['input_ids'], enc_input['attention_mask'])

    # setup initial decoder input
    dec_input_ids = torch.tensor([[start_token_id]], device=device)
    dec_attn_mask = torch.ones_like(dec_input_ids)

    # now do the decoder loop
    for _ in range(max_new_tokens):
      dec_output = decoder(
          enc_output,
          dec_input_ids,
          enc_input['attention_mask'],
          dec_attn_mask,
      )

      # choose the best value (or sample)
      prediction_id = torch.argmax(dec_output[:, -1, :], dim=-1)

      # append to decoder input
      dec_input_ids = torch.hstack((dec_input_ids, prediction_id.view(1, 1)))

      # recreate mask
      dec_attn_mask = torch.ones_like(dec_input_ids)

      # exit when the end-of-sentence token is produced
      if prediction_id.item() == eos_token_id:
        break

  # drop the leading start token before decoding
  translation = tokenizer.decode(dec_input_ids[0, 1:])
  print(translation)


In [63]:
# Quick check of the helper on a short sentence; the translation is printed.
translate("How are you?")


¿Cómo estás?</s>
