In [1]:
import numpy as np
import torch 
import torch.nn as nn 
from transformers import Transformer
from transformer_train import Batch, NoamOptimizer

In [2]:
transformer = Transformer(input_vocab=10, output_vocab=10)

Try using the transformer with None mask

In [3]:
input = torch.randint(10, (60,5))
output = torch.randint(10, (60,10))
transformer(input, output, None, None)

tensor([[[-7.7241, -1.6738, -1.7825,  ..., -2.9035, -2.7298, -2.9162],
         [-6.3393, -3.1287, -1.2187,  ..., -2.8651, -3.8741, -1.5978],
         [-4.8164, -3.8533, -2.7677,  ..., -0.8371, -4.2480, -2.8280],
         ...,
         [-6.1315, -3.0389, -1.4388,  ..., -3.3017, -3.7573, -1.2052],
         [-7.3743, -1.7026, -1.8243,  ..., -3.2567, -2.4396, -2.4312],
         [-4.6603, -4.1607, -2.3987,  ..., -1.8191, -4.1767, -2.0908]],

        [[-5.2516, -2.9991, -1.8385,  ..., -1.4333, -3.4890, -1.8554],
         [-6.9604, -1.7433, -1.2655,  ..., -2.5555, -2.1312, -3.8134],
         [-5.7765, -3.1944, -0.8544,  ..., -3.0560, -3.6236, -1.5050],
         ...,
         [-6.7821, -1.5662, -1.4108,  ..., -3.0513, -1.9618, -3.5587],
         [-5.5438, -4.5060, -1.9625,  ..., -2.8324, -2.9298, -1.4453],
         [-4.1153, -3.5120, -2.3348,  ..., -1.4389, -3.8671, -2.1483]],

        [[-5.3493, -4.4682, -1.0783,  ..., -2.1520, -3.8770, -2.9099],
         [-5.3334, -4.5377, -1.0115,  ..., -2

Try using the transformer with real mask

In [4]:
input = torch.randint(10, (60,5))
output = torch.randint(10, (60,10))
input_mask = torch.randint(2, (60, 1, 5))
output_mask = torch.randint(2, (60, 10, 10))
transformer(input, output, input_mask, output_mask)

tensor([[[-2.1784, -2.4380, -2.8516,  ..., -4.0840, -2.8811, -0.7360],
         [-1.9744, -3.6604, -3.3502,  ..., -4.2985, -3.4730, -0.7521],
         [-3.0202, -2.5104, -4.3819,  ..., -3.2995, -2.4610, -0.8855],
         ...,
         [-2.3645, -3.5712, -3.7282,  ..., -4.3372, -3.1454, -0.6048],
         [-1.4906, -3.3852, -3.6559,  ..., -4.3083, -3.0591, -1.2246],
         [-1.9915, -2.1084, -4.4619,  ..., -4.5640, -3.3588, -0.6902]],

        [[-4.9137, -3.9449, -3.0051,  ..., -2.7996, -3.9877, -2.0667],
         [-5.9068, -5.1644, -3.5201,  ..., -2.7560, -2.9781, -2.9157],
         [-4.4664, -4.0871, -2.9414,  ..., -2.9199, -4.0902, -1.9247],
         ...,
         [-5.3064, -3.7297, -2.2830,  ..., -2.4822, -2.1513, -2.8990],
         [-4.5457, -7.2209, -3.0816,  ..., -3.5740, -5.1189, -2.0954],
         [-5.9209, -5.2169, -3.8749,  ..., -1.6439, -4.0117, -2.6541]],

        [[-6.5225, -3.5297, -3.5759,  ..., -1.4199, -3.2896, -2.6221],
         [-6.0816, -4.6127, -3.5266,  ..., -2

## Creating and training transformer for copy task

In [5]:
num_vocab = 11
seq_len = 10
model_dim = 512

In [6]:
def copy_data_generator(num_batches=30, batch_size=20):
    for _ in range(num_batches):
        data = np.random.randint(num_vocab, size=(batch_size, seq_len))
        src = torch.from_numpy(data)
        tgt = torch.from_numpy(data.copy())
        yield Batch(src, tgt, 0)

In [7]:
transformer = Transformer(input_vocab=num_vocab, output_vocab=num_vocab, model_dim=model_dim, num_coder=2)

In [8]:
adam_optimizer = torch.optim.Adam(transformer.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
noam_optimizer = NoamOptimizer(adam_optimizer, model_dim, 1, 400)

In [9]:
loss_fn = nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)

for epoch in range(10):
    transformer.train()
    data_iter = copy_data_generator()

    losses = 0.
    cnt = 0

    for i, batch in enumerate(data_iter):
        out = transformer(batch.src, batch.trg, batch.src_mask, batch.trg_mask)

        out = out.reshape(-1, out.shape[-1])
        labels = batch.trg_y.reshape(-1)

        loss = loss_fn(out, labels)
        
        loss.backward()
        noam_optimizer.step()
        noam_optimizer.optimizer.zero_grad()

        losses += loss.item()
        cnt += 1

    print(epoch, losses / cnt)

0 2.2781896273295086
1 1.811338988939921
2 1.6485260645548503
3 1.2655906836191813
4 1.0189296901226044
5 0.8298549135526021
6 0.7520816405614217
7 0.7098671356836955
8 0.8333705087502797
9 0.8598594109217326


## Testing transformer for copy task

In [10]:
input = np.random.randint(num_vocab, size=(10, seq_len))

expected_output = torch.from_numpy(input.copy()) 
output_init = torch.from_numpy(input[:,0].reshape(-1,1))
input = torch.from_numpy(input)
input_mask = torch.ones(10, 1, seq_len)

In [11]:
output = transformer.greedy_decode(input, output_init, seq_len, input_mask)

In [12]:
expected_output

tensor([[ 0,  4,  2,  9,  4,  3,  6,  3,  1,  1],
        [ 4,  2,  2,  2,  6, 10,  6,  7,  5,  9],
        [10,  7,  0,  3,  5,  8,  5,  7,  9,  3],
        [ 7,  9,  6,  8,  7,  8, 10,  5,  8,  6],
        [ 5, 10,  3, 10,  1,  9,  7,  4,  9,  4],
        [ 1,  0,  3,  7,  6,  9,  3,  6,  2,  3],
        [ 3, 10,  9,  6,  9,  2, 10,  6,  5,  2],
        [ 9,  3,  1,  2,  5,  4,  4,  2,  9,  3],
        [ 0,  3,  1,  0,  7,  5,  2,  6,  2,  6],
        [ 4,  0,  5,  9,  6,  8, 10,  2,  7,  0]])

In [13]:
output

tensor([[ 0,  4,  2,  9,  4,  3,  6,  1,  3,  1],
        [ 4,  2,  6,  2,  6, 10,  6,  7,  5,  9],
        [10,  7,  6,  3,  5,  8,  5,  7,  9,  3],
        [ 7,  9,  6,  7,  8, 10,  8,  5,  8,  6],
        [ 5, 10,  3, 10,  1,  9,  7,  4,  9,  4],
        [ 1,  3,  7,  6,  6,  9,  3,  6,  3,  2],
        [ 3, 10,  9,  6,  9, 10,  2,  6,  5,  2],
        [ 9,  3,  1,  5,  4,  4,  4,  2,  9,  3],
        [ 0,  3,  1,  6,  7,  5,  2,  6,  2,  6],
        [ 4,  7,  5,  9,  6,  8, 10,  2,  7, 10]])

In [14]:
# get accuracy 
acc = (output == expected_output).sum() / (10 * seq_len)
acc.item()

0.800000011920929

## Creating and training transformer for sequence reversal

In [15]:
num_vocab = 11
seq_len = 10
model_dim = 512

In [16]:
def reverse_data_generator(num_batches=30, batch_size=20):
    for _ in range(num_batches):
        data = np.random.randint(num_vocab, size=(batch_size, seq_len))
        src = torch.from_numpy(data)
        tgt = torch.from_numpy(data[:, ::-1].copy())
        yield Batch(src, tgt, 0)

In [17]:
transformer = Transformer(input_vocab=num_vocab, output_vocab=num_vocab, model_dim=model_dim, num_coder=2)

In [18]:
adam_optimizer = torch.optim.Adam(transformer.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
noam_optimizer = NoamOptimizer(adam_optimizer, model_dim, 1, 400)

In [19]:
loss_fn = nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)

for epoch in range(10):
    transformer.train()
    data_iter = reverse_data_generator()

    losses = 0.
    cnt = 0

    for i, batch in enumerate(data_iter):
        out = transformer(batch.src, batch.trg, batch.src_mask, batch.trg_mask)

        out = out.reshape(-1, out.shape[-1])
        labels = batch.trg_y.reshape(-1)

        loss = loss_fn(out, labels)
        
        loss.backward()
        noam_optimizer.step()
        noam_optimizer.optimizer.zero_grad()

        losses += loss.item()
        cnt += 1

    print(epoch, losses / cnt)

0 2.3041552821795146
1 1.8227254152297974
2 1.6300251364707947
3 1.307507570584615
4 0.973494819800059
5 0.8183308084805806
6 0.7434642116228739
7 0.7893629471460978
8 0.7124856233596801
9 0.7884847939014434


## Testing transformer for sequence reversal

In [20]:
input = np.random.randint(num_vocab, size=(10, seq_len))

expected_output = torch.from_numpy(input[:, ::-1].copy()) 
output_init = torch.from_numpy(input[:,-1].reshape(-1,1))
input = torch.from_numpy(input)
input_mask = torch.ones(10, 1, seq_len)

In [21]:
output = transformer.greedy_decode(input, output_init, seq_len, input_mask)

In [22]:
expected_output

tensor([[ 4,  6, 10,  1,  5,  6,  5,  2,  5, 10],
        [ 0, 10,  0,  2,  8, 10,  3,  5,  4,  3],
        [10,  5,  8,  8, 10, 10,  0, 10,  5,  8],
        [10,  1,  4,  8,  1,  5,  8,  5,  4,  5],
        [10,  2,  4,  9,  7,  0,  3,  9,  9,  8],
        [10, 10,  3,  2,  9,  2,  1, 10,  5,  2],
        [ 7,  4,  4,  1,  4,  1,  0,  1,  6,  9],
        [ 4,  8,  6,  1,  8,  0,  3,  2,  0,  1],
        [ 7,  3,  1,  4,  8,  4,  9, 10,  5,  8],
        [ 7,  5,  8,  3,  4, 10,  9,  7,  4,  4]])

In [23]:
output

tensor([[ 4,  6, 10,  1,  5,  6,  5,  2,  5, 10],
        [ 0, 10,  5, 10,  2,  8,  5,  5,  4,  3],
        [10,  5,  8,  8, 10,  5, 10,  5, 10,  8],
        [10,  4,  1,  8,  1,  5,  8,  5,  4,  5],
        [10,  2,  4,  9,  7,  5,  3,  9,  9,  9],
        [10, 10,  3,  2,  9,  2,  1, 10,  5,  2],
        [ 7,  4,  7,  1,  4,  1,  1,  1,  6,  6],
        [ 4,  8,  6,  1,  8,  5,  3,  2,  1,  1],
        [ 7,  3,  1,  4,  8,  4, 10,  5,  7,  8],
        [ 7,  5,  7,  3,  4, 10,  9,  7,  4,  4]])

In [24]:
# get accuracy 
acc = (output == expected_output).sum() / (10 * seq_len)
acc.item()

0.7799999713897705