In [1]:
import torch
import sys

sys.path.append('../src/')

from data_loader import DatasetLanguage, collate_fn
from torch.utils.data import DataLoader
from transformers import Transformer
import json

Paths

In [2]:
training_path = '../data/train.json'
testing_path = '../data/test.json'
validation_path = '../data/validation.json'
x_vocab = '../data/english_map.json'
y_vocab = '../data/hindi_map.json'

Vocab Sizes

x: 1802939

y: 2180936

In [3]:
with open(x_vocab, 'r', encoding='utf-8') as f:
    data = json.load(f)
    total_vocab = len(data)
    print(f'x: {total_vocab}')
    
with open(y_vocab, 'r', encoding='utf-8') as f:
    data = json.load(f)
    total_vocab = len(data)
    print(f'y: {total_vocab}')

x: 31018
y: 27473


Data Loader

In [4]:
def get_loders(data_path, batch_size):
    dataset = DatasetLanguage(data_path=data_path,
                                x_vocab=x_vocab,
                                y_vocab=y_vocab)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

In [5]:
train_dataloader = get_loders(training_path, batch_size=2)

**Training**

Model

In [6]:
vocab_size_input = 31018
vocab_size_output = 27473

transformer_model = Transformer(
    num_blocks=1,
    d_model=512,
    num_heads=8,
    vocab_size_input=vocab_size_input,
    vocab_size_output=vocab_size_output
)

In [7]:
transformer_model.parameters()

<generator object Module.parameters at 0x0000012507F7BCA0>

In [None]:
optimizer = torch.optim.Adam(transformer_model.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)
size = len(train_dataloader.dataset)
batch_size = 2
epochs = 2

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    transformer_model.train()
    for batch, (X, y) in enumerate(train_dataloader):
        pred = transformer_model(X, y)
        
        y_one_hot = torch.nn.functional.one_hot(y, num_classes=vocab_size_output).to(torch.float)
        loss = loss_fn(pred, y_one_hot)
        
        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        if batch % 100 == 0:
            loss, current = loss.item(), batch * batch_size + len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

Epoch 1
-------------------------------
loss: 0.001615  [    2/ 7000]
loss: 0.001248  [  202/ 7000]
loss: 0.001113  [  402/ 7000]
loss: 0.002844  [  602/ 7000]
loss: 0.001372  [  802/ 7000]
loss: 0.000975  [ 1002/ 7000]
loss: 0.010503  [ 1202/ 7000]
loss: 0.001768  [ 1402/ 7000]
loss: 0.001586  [ 1602/ 7000]


In [10]:
torch.nn.functional.one_hot(y, num_classes=vocab_size_output)

tensor([[[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [1, 0, 0,  ..., 0, 0, 0],
         [1, 0, 0,  ..., 0, 0, 0],
         [1, 0, 0,  ..., 0, 0, 0]]])

In [11]:
pred

tensor([[[1.8621e-05, 2.7510e-05, 1.0490e-05,  ..., 2.6279e-05,
          1.9552e-05, 6.6343e-05],
         [1.6636e-05, 2.4985e-05, 1.7356e-05,  ..., 2.7062e-05,
          1.9494e-05, 2.0020e-05],
         [1.5191e-05, 4.7449e-05, 1.3954e-05,  ..., 2.3909e-05,
          5.9029e-05, 3.1964e-05],
         ...,
         [2.4853e-05, 5.6504e-05, 1.0780e-05,  ..., 4.0894e-05,
          3.7860e-05, 1.0435e-05],
         [4.9898e-05, 5.6150e-05, 2.6220e-05,  ..., 7.2906e-05,
          2.0828e-05, 3.4885e-05],
         [3.1868e-05, 6.7873e-05, 1.5580e-05,  ..., 2.0059e-05,
          5.1960e-05, 2.0337e-05]],

        [[1.7727e-05, 2.8221e-05, 1.2633e-05,  ..., 2.9488e-05,
          1.8888e-05, 6.6712e-05],
         [3.6069e-05, 4.0887e-05, 3.1571e-05,  ..., 6.8090e-05,
          3.2708e-05, 2.8404e-05],
         [5.3294e-05, 2.8720e-05, 1.6916e-05,  ..., 3.4922e-05,
          4.0616e-05, 3.6511e-05],
         ...,
         [2.2892e-05, 8.3273e-05, 1.1395e-05,  ..., 7.3501e-05,
          2.614