In [2]:
import numpy as np
import torch 
import torch.nn as nn 
from transformers import Transformer
from transformer_train import Batch, NoamOptimizer

In [2]:
transformer = Transformer(input_vocab=10, output_vocab=10)

Try using the transformer with None mask

In [3]:
input = torch.randint(10, (60,5))
output = torch.randint(10, (60,10))
transformer(input, output, None, None)

tensor([[[-4.3559, -5.2487, -0.7497,  ..., -5.3695, -1.3660, -3.9769],
         [-5.2369, -4.1565, -0.7808,  ..., -3.6044, -2.5541, -4.1441],
         [-3.7091, -3.8110, -1.6015,  ..., -3.8261, -1.2818, -4.3744],
         ...,
         [-4.5364, -5.4882, -0.9958,  ..., -4.4025, -2.8058, -4.4474],
         [-2.2326, -3.7237, -0.6873,  ..., -3.8603, -2.7467, -4.3089],
         [-2.2533, -3.8377, -0.9976,  ..., -4.6721, -1.2523, -4.1875]],

        [[-4.1219, -3.8285, -2.2555,  ..., -1.3040, -2.0745, -4.4725],
         [-6.8820, -4.5870, -1.6347,  ..., -2.0513, -3.3394, -4.2462],
         [-4.2103, -3.5013, -2.2617,  ..., -1.4769, -1.9338, -4.4816],
         ...,
         [-5.5928, -4.4443, -2.1528,  ..., -2.3184, -2.2187, -4.1659],
         [-6.3742, -4.5327, -2.5778,  ..., -1.4635, -2.0110, -4.7276],
         [-6.3213, -4.5722, -2.4557,  ..., -1.4363, -2.0143, -4.7555]],

        [[-3.9285, -4.4695, -1.3114,  ..., -2.3223, -2.6220, -1.6668],
         [-2.4939, -3.0806, -2.1883,  ..., -2

Try using the transformer with real mask

In [4]:
input = torch.randint(10, (60,5))
output = torch.randint(10, (60,10))
input_mask = torch.randint(2, (60, 1, 5))
output_mask = torch.randint(2, (60, 10, 10))
transformer(input, output, input_mask, output_mask)

tensor([[[-4.3160, -3.0536, -1.9481,  ..., -1.7658, -3.3523, -1.4735],
         [-3.1745, -3.1472, -2.1283,  ..., -1.4101, -4.1834, -1.0412],
         [-3.3421, -3.7672, -1.9132,  ..., -1.8864, -4.0480, -0.7068],
         ...,
         [-5.2904, -3.5835, -1.4576,  ..., -2.8472, -4.2545, -1.5367],
         [-7.2072, -3.5299, -1.8865,  ..., -1.4819, -3.2552, -1.6378],
         [-3.1346, -3.7828, -0.9886,  ..., -3.2449, -3.2661, -1.0610]],

        [[-4.7001, -6.8197, -1.6232,  ..., -4.6834, -1.3834, -5.2095],
         [-2.7843, -4.9023, -1.4748,  ..., -3.3052, -1.9544, -4.9664],
         [-3.1717, -5.1768, -2.4111,  ..., -4.1504, -0.8286, -4.9793],
         ...,
         [-6.6173, -6.1197, -0.8461,  ..., -3.7754, -2.9967, -4.7263],
         [-3.9399, -5.9752, -1.5434,  ..., -3.3832, -1.5053, -4.6520],
         [-5.2750, -5.6847, -0.7923,  ..., -3.9814, -2.6309, -4.5875]],

        [[-3.8285, -5.2046, -1.5694,  ..., -2.7178, -3.5082, -2.8018],
         [-4.0688, -5.0071, -1.9951,  ..., -1

## Creating and training transformer for copy task

In [5]:
num_vocab = 11
seq_len = 10
model_dim = 512

In [6]:
def copy_data_generator(num_batches=30, batch_size=20):
    for _ in range(num_batches):
        data = np.random.randint(num_vocab, size=(batch_size, seq_len))
        src = torch.from_numpy(data)
        tgt = torch.from_numpy(data.copy())
        yield Batch(src, tgt, 0)

In [7]:
transformer = Transformer(input_vocab=num_vocab, output_vocab=num_vocab, model_dim=model_dim, num_coder=2)

In [8]:
adam_optimizer = torch.optim.Adam(transformer.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
noam_optimizer = NoamOptimizer(adam_optimizer, model_dim, 1, 400)

In [9]:
loss_fn = nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)

for epoch in range(10):
    transformer.train()
    data_iter = copy_data_generator()

    losses = 0.
    cnt = 0

    for i, batch in enumerate(data_iter):
        out = transformer(batch.src, batch.trg, batch.src_mask, batch.trg_mask)

        out = out.reshape(-1, out.shape[-1])
        labels = batch.trg_y.reshape(-1)

        loss = loss_fn(out, labels)
        
        loss.backward()
        noam_optimizer.step()
        noam_optimizer.optimizer.zero_grad()

        losses += loss.item()
        cnt += 1

    print(epoch, losses / cnt)

0 2.3245237350463865
1 1.82636163632075
2 1.6100578824679057
3 1.2480957667032877
4 0.9334281404813131
5 0.818869560956955
6 0.8302984317143758
7 0.7845571974913279
8 0.7216464360555013
9 0.7322157025337219


## Testing transformer for copy task

In [10]:
input = np.random.randint(num_vocab, size=(10, seq_len))

expected_output = torch.from_numpy(input.copy()) 
output_init = torch.from_numpy(input[:,0].reshape(-1,1))
input = torch.from_numpy(input)
input_mask = torch.ones(10, 1, seq_len)

In [11]:
output = transformer.greedy_decode(input, output_init, seq_len, input_mask)

In [12]:
expected_output

tensor([[ 7,  3, 10,  2,  1,  7,  7,  1,  9,  9],
        [ 4,  0,  1,  0,  5,  0,  9,  0, 10,  1],
        [ 9, 10,  8,  0,  1,  1,  2,  3,  4,  3],
        [10,  8,  9,  2,  7,  0, 10, 10,  7, 10],
        [ 1,  7,  2,  3,  7,  8,  5,  1,  7,  9],
        [ 1,  6, 10,  3,  4,  6,  7,  0,  9,  1],
        [ 6,  7,  8,  4,  5,  0,  7,  2,  1,  1],
        [ 2, 10, 10,  3,  0,  1,  6,  3, 10,  5],
        [ 0,  5,  2,  6, 10,  6,  6,  1,  3,  7],
        [ 8, 10,  6,  1,  4,  2,  9,  3,  1,  7]])

In [13]:
output

tensor([[ 7,  3, 10,  2,  1,  7,  1,  7,  9,  9],
        [ 4,  3,  1,  1,  5,  6,  9,  3, 10,  1],
        [ 9, 10,  8,  8,  1,  8,  2,  3,  4,  3],
        [10,  8,  9,  2,  7,  2, 10, 10,  7, 10],
        [ 1,  7,  2,  3,  7,  8,  5,  1,  7,  9],
        [ 1,  6, 10,  3,  4,  6,  7,  7,  9,  1],
        [ 6,  7,  8,  4,  5,  5,  7,  2,  1,  6],
        [ 2, 10, 10,  3, 10,  3,  6,  3, 10,  5],
        [ 0,  5,  2,  6, 10,  6,  6,  1,  3,  7],
        [ 8, 10,  6,  1,  4,  2,  9,  3,  3,  7]])

In [14]:
# get accuracy 
acc = (output == expected_output).sum() / (10 * seq_len)
acc.item()

0.8500000238418579

## Creating and training transformer for sequence reversal

In [15]:
num_vocab = 11
seq_len = 10
model_dim = 512

In [16]:
def reverse_data_generator(num_batches=30, batch_size=20):
    for _ in range(num_batches):
        data = np.random.randint(num_vocab, size=(batch_size, seq_len))
        src = torch.from_numpy(data)
        tgt = torch.from_numpy(data[:, ::-1].copy())
        yield Batch(src, tgt, 0)

In [17]:
transformer = Transformer(input_vocab=num_vocab, output_vocab=num_vocab, model_dim=model_dim, num_coder=2)

In [18]:
adam_optimizer = torch.optim.Adam(transformer.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
noam_optimizer = NoamOptimizer(adam_optimizer, model_dim, 1, 400)

In [19]:
loss_fn = nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)

for epoch in range(10):
    transformer.train()
    data_iter = reverse_data_generator()

    losses = 0.
    cnt = 0

    for i, batch in enumerate(data_iter):
        out = transformer(batch.src, batch.trg, batch.src_mask, batch.trg_mask)

        out = out.reshape(-1, out.shape[-1])
        labels = batch.trg_y.reshape(-1)

        loss = loss_fn(out, labels)
        
        loss.backward()
        noam_optimizer.step()
        noam_optimizer.optimizer.zero_grad()

        losses += loss.item()
        cnt += 1

    print(epoch, losses / cnt)

0 2.3368816216786703
1 1.8293853362401327
2 1.7079774181048075
3 1.3476770341396331
4 0.9670732994874318
5 0.8006935338179271
6 0.7373670717080434
7 0.7913566052913665
8 0.8067436377207439
9 0.8073766867319743


## Testing transformer for sequence reversal

In [20]:
input = np.random.randint(num_vocab, size=(10, seq_len))

expected_output = torch.from_numpy(input[:, ::-1].copy()) 
output_init = torch.from_numpy(input[:,-1].reshape(-1,1))
input = torch.from_numpy(input)
input_mask = torch.ones(10, 1, seq_len)

In [21]:
output = transformer.greedy_decode(input, output_init, seq_len, input_mask)

In [22]:
expected_output

tensor([[ 9,  0,  8,  7,  7,  7, 10,  8,  3,  8],
        [ 0,  5,  5,  8,  6,  8,  8,  8,  6,  7],
        [ 5,  9,  9,  9,  0,  3,  0,  0,  0,  5],
        [ 0,  8,  2,  0,  9,  5,  1,  1,  0,  2],
        [ 9,  5,  1,  4,  1,  7,  1,  6,  3,  1],
        [ 3,  7,  0, 10,  1, 10,  9,  9,  6,  4],
        [ 3,  0,  6,  2,  3,  4,  1,  9,  4, 10],
        [ 2,  0,  0,  6,  7,  2,  1,  4,  4,  2],
        [ 0, 10,  4,  0,  1,  8, 10,  3,  2,  6],
        [ 9,  8,  1,  8,  8,  5,  6,  7,  9,  0]])

In [23]:
output

tensor([[ 9,  8,  7,  8,  7,  7, 10,  8,  3,  8],
        [ 0,  5,  5,  8,  6,  8,  8,  8,  6,  7],
        [ 5,  9,  9,  9,  9,  3,  9,  7,  5,  5],
        [ 0,  8,  2,  9,  8,  9,  1,  1,  9,  5],
        [ 9,  5,  1,  1,  4,  7,  1,  6,  3,  1],
        [ 3,  7,  8, 10,  1, 10,  9,  9,  6,  4],
        [ 3,  6,  6,  2,  3,  4,  1,  9,  4, 10],
        [ 2,  8,  7,  6,  7,  2,  1,  4,  4,  2],
        [ 0, 10, 10,  4,  1,  8, 10,  3,  3,  2],
        [ 9,  8,  1,  8,  8,  5,  6,  7,  9,  9]])

In [24]:
# get accuracy 
acc = (output == expected_output).sum() / (10 * seq_len)
acc.item()

0.7699999809265137