In [1]:
from pathlib import Path
from typing import Literal
import pandas as pd
import torch
from torch import nn, optim
from en_indic_transformer import Transformer, Tokenizer, Trainer, TranslationDataLoader, TranslationDataset

Create the various values that are used throughout the rest of the notebook.

Get the data from the data directory and create a dataframe

In [2]:
# Tokenizer created from the 'gpt2' encoding and extended with one language-tag token per language.
# NOTE(review): the extra tokens are passed as a set, which has no stable iteration order across
# interpreter runs; if Tokenizer assigns ids in iteration order, the tag ids may differ between
# runs — confirm against the Tokenizer implementation.
tokenizer = Tokenizer('gpt2', extend_base_encoder={'<|english|>','<|hindi|>', '<|kannada|>' }) # kannada tag is added now for later use
# Language tags handed to the datasets as src_prepend_value / target_prepend_value.
src_prepend_value = '<|english|>'
target_prepend_value = '<|hindi|>'

batch_size = 16
random_seed = 42 # for reproducibility
device: Literal['cpu', 'cuda'] = 'cuda' if torch.cuda.is_available() else 'cpu' # device for training; falls back to cpu when CUDA is unavailable.

# apply random_seed to the torch CPU and CUDA generators
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)

# transformer details (passed unchanged to the Transformer constructor)
context_length = 4096
vocab_size = tokenizer.n_vocab # taken from the extended tokenizer; 50260 in the recorded model summary
emb_dim = 768
enc_layers = 5
dec_layers = 5
num_heads = 12
dropout = 0.1
# NOTE(review): in the recorded model summary the attention projections show bias=False while
# the MLP Linear layers show bias=True — confirm which layers this flag is meant to control.
bias = False

# training details
epochs = 5
# NOTE(review): 0.01 is a large learning rate for Adam on a transformer; the recorded losses
# flatten around 3.15 after the second epoch — consider trying a smaller value (e.g. 1e-4).
lr = 0.01

In [3]:
# Parent of the current working directory (absolute path).
path = Path.cwd().parent

In [4]:
# Location of the English-Hindi CSV file (a file path, despite the variable's name).
data_dir = path.joinpath('data/eng_hindi.csv')

In [5]:
# Read the CSV located at data_dir into a DataFrame.
df = pd.read_csv(data_dir)

In [6]:
# Display the loaded frame; the recorded output is a truncated preview of its first and last rows.
df

Unnamed: 0,english_sentence,hindi_sentence
0,"However, Paes, who was partnering Australia's ...",आस्ट्रेलिया के पाल हेनली के साथ जोड़ी बनाने वाल...
1,"Whosoever desires the reward of the world, wit...",और जो शख्स (अपने आमाल का) बदला दुनिया ही में च...
2,The value of insects in the biosphere is enorm...,"जैव-मंडल में कीड़ों का मूल्य बहुत है, क्योंकि ..."
3,Mithali To Anchor Indian Team Against Australi...,आस्ट्रेलिया के खिलाफ वनडे टीम की कमान मिताली को
4,After the assent of the Honble President on 8t...,"8 सितम्‍बर, 2016 को माननीय राष्‍ट्रपति की स्‍व..."
...,...,...
127700,Examples of art deco construction can be found...,आर्ट डेको शैली के निर्माण मैरीन ड्राइव और ओवल ...
127701,and put it in our cheeks.,और अपने गालों में डाल लेते हैं।
127702,"As for the other derivatives of sulphur , the ...","जहां तक गंधक के अन्य उत्पादों का प्रश्न है , द..."
127703,its complicated functioning is defined thus in...,Zरचना-प्रकिया को उसने एक पहेली में यों बांधा है .


There are 127705 rows in the dataset. Use 300 rows for training and 50 for validation. I am running on CPU for now and will switch to a GPU later.

In [7]:
# Sizes of the two consecutive, non-overlapping slices taken from the top of df.
train_len = 300
test_len = 50

# Fail loudly if the frame is too short: iloc slicing would otherwise silently return fewer rows.
assert len(df) >= train_len + test_len, "dataset has fewer rows than train_len + test_len"

# First train_len rows for training, the next test_len rows for validation.
train_df = df.iloc[:train_len]
test_df = df.iloc[train_len:train_len + test_len]

In [8]:
# Display the training slice (rows taken from the top of df).
train_df

Unnamed: 0,english_sentence,hindi_sentence
0,"However, Paes, who was partnering Australia's ...",आस्ट्रेलिया के पाल हेनली के साथ जोड़ी बनाने वाल...
1,"Whosoever desires the reward of the world, wit...",और जो शख्स (अपने आमाल का) बदला दुनिया ही में च...
2,The value of insects in the biosphere is enorm...,"जैव-मंडल में कीड़ों का मूल्य बहुत है, क्योंकि ..."
3,Mithali To Anchor Indian Team Against Australi...,आस्ट्रेलिया के खिलाफ वनडे टीम की कमान मिताली को
4,After the assent of the Honble President on 8t...,"8 सितम्‍बर, 2016 को माननीय राष्‍ट्रपति की स्‍व..."
...,...,...
295,Regular physical activity can help control you...,नियमित शारीरिक कसरतों से वजन नियंत्रण में रहता...
296,Other users can't use login in YouTube account...,एप्पल टी वि संस्करण से अलग प्रयोगकर्ता अपने यू...
297,"And, being a little intrigued, I went to go me...","और क्योंकि मुझे ये अजीब लगा था, मैं उनसे मिलने..."
298,Somebody's storytelling. Interactive art. You ...,यह कहानी चल रही है. आदान-प्रदान की कला.आप मुझ ...


In [9]:
# Display the validation slice (the rows immediately following the training slice).
test_df

Unnamed: 0,english_sentence,hindi_sentence
300,The 10th day of waxing moon of Ashwin month is...,आश्विन शुक्ल दशमी को विजयादशमी का त्यौहार मनाय...
301,"actually near Afghanistan,","जो अफगानिस्तान के निकट है, की अोर संकेत करते हैं"
302,These persons are appointed by the central gov...,इन व्यक्तियों की नियुक्ति केंद्र सरकार द्वारा ...
303,"63. Udaipur was the capital of Mewar, and is c...",63. उदयपुर मेवाड़ के प्राचीन राज्य की ऐतिहासिक...
304,Some videos are for people above 18 years only...,कुछ वीडियॊ उन्हीं उपयॊगकर्तॊं के लिए हैं जिनकी...
305,It is meet tosh river in upper area of Himalay...,हिमालय के ऊपरी भाग में इसमें टोंस तथा बाद में ...
306,Clinical trials still need to be carried out .,इसके लिए अभी और चिकित्सकीय परीक्षणों की जरूरत ...
307,Video seeing,वीडियो रैंकिंग
308,"And depressing ones, such as the fact that","और कुछ अवसादपूर्ण हैं, जैसे नाइजीरिया में"
309,the results can be somewhat comical.,परिणाम कुछ हास्यकारक हो सकते हैं.


Create lists of source and target sentences for training and validation sets

In [10]:
# train
source_train = train_df['english_sentence'].tolist()
target_train = train_df['hindi_sentence'].tolist()

# test
source_test = test_df['english_sentence'].tolist()
target_test = test_df['hindi_sentence'].tolist()

Create training and testing data loaders

In [11]:
# train dataset
train_dataset = TranslationDataset(src=source_train, target=target_train,tokenizer=tokenizer, src_prepend_value=src_prepend_value, target_prepend_value=target_prepend_value)

# test dataset
test_dataset = TranslationDataset(src=source_test, target=target_test,tokenizer=tokenizer, src_prepend_value=src_prepend_value, target_prepend_value=target_prepend_value)

In [12]:
# train dataloader
train_dataloader = TranslationDataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# test dataloader
test_dataloader = TranslationDataLoader(test_dataset, batch_size=batch_size, shuffle=True)

Create the model for training

In [13]:
# Re-seed immediately before construction so the initial weights are the same on every run.
torch.manual_seed(random_seed)

# Build the encoder-decoder model from the configuration values and move it to the chosen device.
# Assumes Transformer is an nn.Module, whose .to() returns the module itself.
model = Transformer(
    vocab_size=vocab_size,
    context_length=context_length,
    emb_dim=emb_dim,
    enc_layers=enc_layers,
    dec_layers=dec_layers,
    num_heads=num_heads,
    dropout=dropout,
    bias=bias,
).to(device)

# Show the architecture as the cell output.
model

Transformer(
  (encoder): Encoder(
    (token_embeddings): Embedding(50260, 768)
    (pos_embeddings): Embedding(4096, 768)
    (encoder_layers): ModuleList(
      (0-4): 5 x EncoderLayer(
        (mlp): MLP(
          (mlp): Sequential(
            (0): Linear(in_features=768, out_features=3072, bias=True)
            (1): GELU(approximate='none')
            (2): Linear(in_features=3072, out_features=768, bias=True)
          )
        )
        (attn): MultiHeadAttention(
          (wq): Linear(in_features=768, out_features=768, bias=False)
          (wk): Linear(in_features=768, out_features=768, bias=False)
          (wv): Linear(in_features=768, out_features=768, bias=False)
          (proj): Linear(in_features=768, out_features=768, bias=False)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (norm1): LayerNorm()
        (norm2): LayerNorm()
      )
    )
  )
  (decoder): Decoder(
    (token_embeddings): Embedding(50260, 768)
    (pos_embeddings): Embedding(4

Create an optimizer and loss function

Using Adam optimizer here.

In [14]:
# Cross-entropy loss with default settings.
# NOTE(review): no ignore_index is set; confirm how padded positions (if any) are handled.
loss_fn = nn.CrossEntropyLoss()

# Adam over all model parameters, using the learning rate from the configuration cell.
optimizer = optim.Adam(params=model.parameters(), lr=lr)

Create the trainer instance for training the model

In [15]:
# Bundle the model with its loss function and optimizer; training is started in the next cell.
trainer = Trainer(model=model, loss_fn=loss_fn, optimizer=optimizer)

In [16]:
# Run training and keep the returned loss history (a dict with 'train_loss' and 'test_loss'
# lists in the recorded output) so later cells can reuse it without retraining.
history = trainer.train(
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    epochs=epochs,
    device=device,
)

# Display the history as the cell output.
history

----- Epoch 0 -----
Batch 0 complete
Batch 1 complete
Batch 2 complete
Batch 3 complete
Batch 4 complete
Batch 5 complete
Batch 6 complete
Batch 7 complete
Batch 8 complete
Batch 9 complete
Batch 10 complete
Batch 11 complete
Batch 12 complete
Batch 13 complete
Batch 14 complete
Batch 15 complete
Batch 16 complete
Batch 17 complete
Batch 18 complete
Training Loss: 121.62535524368286, Test Loss: 3.9993903636932373
-----x-----
----- Epoch 1 -----
Batch 0 complete
Batch 1 complete
Batch 2 complete
Batch 3 complete
Batch 4 complete
Batch 5 complete
Batch 6 complete
Batch 7 complete
Batch 8 complete
Batch 9 complete
Batch 10 complete
Batch 11 complete
Batch 12 complete
Batch 13 complete
Batch 14 complete
Batch 15 complete
Batch 16 complete
Batch 17 complete
Batch 18 complete
Training Loss: 66.12910318374634, Test Loss: 3.3009665608406067
-----x-----
----- Epoch 2 -----
Batch 0 complete
Batch 1 complete
Batch 2 complete
Batch 3 complete
Batch 4 complete
Batch 5 complete
Batch 6 complete
Batc

{'train_loss': [6.401334486509624,
  3.4804791149340177,
  3.1953246342508415,
  3.153182945753399,
  3.146508944661994],
 'test_loss': [3.9993903636932373,
  3.3009665608406067,
  3.169841170310974,
  3.1647984981536865,
  3.264113485813141]}