In [122]:
import numpy as np
import torch
import torch.nn as nn
import math
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F

## Input Embeddings and Positional encoding - Encoder

In [123]:
d_model = 512

In [124]:
SPECIAL_TOKENS = {
    '<pad>': 0,
    '<unk>': 1,
    '<sos>': 2,
    '<eos>': 3,
}
source_sentences = ["We are friends"]
target_sentences = ["हम दोस्त हैं"]
sentences = source_sentences + target_sentences

def build_vocab(sentences):
    vocab = dict(SPECIAL_TOKENS)
    idx = len(vocab)
    
    for s in sentences:
        tokens = s.lower().split()
        for token in tokens:
            if token not in vocab:
                vocab[token] = idx
                idx +=1
    return vocab

vocab = build_vocab(sentences)
print(vocab)

{'<pad>': 0, '<unk>': 1, '<sos>': 2, '<eos>': 3, 'we': 4, 'are': 5, 'friends': 6, 'हम': 7, 'दोस्त': 8, 'हैं': 9}


In [125]:
input = "We are friends"

def tokenize(sentence, vocab):
    tokens = sentence.lower().split()
    ids = [vocab.get(t, vocab["<unk>"]) for t in tokens]
    input_ids = torch.tensor(ids).unsqueeze(0) # (3,) -> (1,3)
    return input_ids
    
input_ids = tokenize(input, vocab)
print(input_ids,"\n",input_ids.shape)

tensor([[4, 5, 6]]) 
 torch.Size([1, 3])


In [126]:
embedding = nn.Embedding(num_embeddings=len(vocab), embedding_dim=d_model)
input_embeddings = embedding(input_ids) * math.sqrt(d_model)
print(input_embeddings)
print(input_embeddings.shape)

tensor([[[  5.9759, -42.0194, -15.8839,  ...,  28.2713,  38.8521, -24.3929],
         [ 18.0416, -14.5304, -35.0097,  ...,   8.5690,   1.9877,   7.9900],
         [  4.7285,  -9.1166, -12.1504,  ..., -25.0507, -21.2639,  24.7698]]],
       grad_fn=<MulBackward0>)
torch.Size([1, 3, 512])


In [127]:
def positionalEncoding(input_embeddings, d_model=512):
    seq_length = input_embeddings.shape[1] if input_embeddings.dim() == 3 else input_embeddings.shape[0]   # 3
    positional_encoding = torch.zeros(seq_length, d_model, device=input_embeddings.device)
    
    for pos in range(seq_length):
        for i in range(0, d_model, 2):  # [sin, cos, sin, cos , ...,sin, cos] for 512 dimensions
            PE_sin = math.sin(pos / 10000**(2*i/d_model))
            PE_cos = math.cos(pos / 10000**(2*i/d_model))
            positional_encoding[pos, i] = PE_sin
            positional_encoding[pos, i+1] = PE_cos

    return positional_encoding.unsqueeze(0)

positional_encoding = positionalEncoding(input_embeddings)
print(positional_encoding)
print(positional_encoding.shape)   # (3, 512)


def positionalEncoding2(input_embeddings, d_model=512):
    """
    x: (B, seq_len, d_model)
    Returns: (1, seq_len, d_model) on same device as x
    """
    seq_len = input_embeddings.shape[1]
    device = input_embeddings.device
    position = torch.arange(0, seq_len, dtype=torch.float, device=device).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float, device=device) * 
                         (-math.log(10000.0) / d_model))
    
    pe = torch.zeros(seq_len, d_model, device=device)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    
    return pe.unsqueeze(0)  # (1, seq_len, d_model)

tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
           0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  8.0196e-01,  ...,  1.0000e+00,
           1.0746e-08,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  9.5814e-01,  ...,  1.0000e+00,
           2.1492e-08,  1.0000e+00]]])
torch.Size([1, 3, 512])


In [128]:
def make_padding_mask(input_ids, pad_idx=0):
    # (b, l) -> (b, 1, 1, l)
    return (input_ids != pad_idx).unsqueeze(1).unsqueeze(1)

def make_causal_mask(tgt_len, device):  # (1, tgt_len, tgt_len)
    return torch.tril(torch.ones((tgt_len, tgt_len), dtype=torch.bool, device=device)).unsqueeze(0) 

def combine_padding_and_causal(tgt_pad_mask, causal_mask):    # tgt_pad_mask: (B, 1, 1, L) -> convert to (B, L) valid positions
    batch_size = tgt_pad_mask.shape[0]
    L = tgt_pad_mask.shape[-1]
    valid = tgt_pad_mask.squeeze(1).squeeze(1)   # (batch_size, L)
    causal_b = causal_mask.expand(batch_size, -1, -1)   # (batch_size, L, L)
    valid_src = valid.unsqueeze(1).expand(-1, L, -1)   # (batch_size, L, L)
    combined = causal_b & valid_src
    return combined.unsqueeze(1)  # (batch, 1, L, L)

In [129]:
x = input_embeddings + positional_encoding
print(x)
print(x.shape)


tensor([[[  5.9759, -41.0194, -15.8839,  ...,  29.2713,  38.8521, -23.3929],
         [ 18.8831, -13.9901, -34.2077,  ...,   9.5690,   1.9877,   8.9900],
         [  5.6378,  -9.5328, -11.1922,  ..., -24.0507, -21.2639,  25.7698]]],
       grad_fn=<AddBackward0>)
torch.Size([1, 3, 512])


## Encoder Sub-layer 1 = Multi-Head Attention

In [130]:
# class ScaledDotProductAttention(nn.Module):
#     def __init__(self, d_k):
#         super().__init__()
# W_Q = torch.randn(x.shape[2], x.shape[2])   # (512, 512)
# W_K = torch.randn(x.shape[2], x.shape[2])
# W_V = torch.randn(x.shape[2], x.shape[2])

# Q = x @ W_Q  # (1, 3, 512)
# K = x @ W_K
# V = x @ W_V

class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k):
        super().__init__()
        self.d_k = d_k
    
    def forward(self, Q, K, V, mask=None):
        attention = torch.matmul(Q, K.transpose(-2,-1)) / math.sqrt(self.d_k)
        if (mask is not None):
            attention = attention.masked_fill(~mask, float('-inf'))
        
        attention_weights = torch.softmax(attention, dim=-1)
        output = torch.matmul(attention_weights, V)
        return output, attention_weights
    

class MultiHeadAttention(nn.Module):
    def __init__(self,d_model=512, heads=8, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.heads = heads
        self.d_k = d_model // heads    # 512//8 = 64 because they are concatenated later
        
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.attention = ScaledDotProductAttention(self.d_k)
        self.z = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, q, k, v, mask=None):
        # Linear
        Q = self.W_Q(q)   # (1, 3, 512)
        K = self.W_K(k)
        V = self.W_V(v)
        
        batch_size, tgt_len, _ = Q.shape
        _, src_len, _ = K.shape
        # heads
        Q = Q.view(batch_size, tgt_len, self.heads, self.d_k)   # (1, 3, 8, 64)
        K = K.view(batch_size, src_len, self.heads, self.d_k)
        V = V.view(batch_size, src_len, self.heads, self.d_k)
        
        Q = Q.transpose(1,2)    # (1, 8, 3, 64)
        K = K.transpose(1,2)
        V = V.transpose(1,2)
        
        # attention
        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(0).unsqueeze(0)    # (1, 1, tgt_len, src_len) = (1, 1, 3, 3)
            elif mask.dim() == 3:
                mask = mask.unsqueeze(1)     # (1, 1, 3, 3)
            elif mask.dim() == 4:
                pass
            mask = mask.to(torch.bool)
                
                
        output, attention_weights = self.attention(Q, K, V, mask=mask)    # (1, 8, 3, 64) ; (1, 8, 3, 3)
        # concat
        concat_output = output.transpose(1,2)     # (1, 3, 8, 64)
        concat_output = concat_output.contiguous().view(batch_size, tgt_len, self.d_model)   # (1, 3, 512)
        # linear
        output = self.z(concat_output)
        
        output = self.dropout(output)
        
        return output
        
heads = 8
mha = MultiHeadAttention(d_model, heads)
mha_output = mha(x, x, x)   # (1, 3, 512)
print(mha_output)
print(mha_output.shape)

tensor([[[ -5.7154,  -3.6326,  -2.1424,  ...,   0.6965,   4.7687,  11.3503],
         [  6.5223,  -0.0000, -11.3013,  ...,  -5.8819,  -1.1839,  -2.2014],
         [ -3.8892,  -3.0986,  -5.4991,  ...,   7.6042,   7.8150,  11.6261]]],
       grad_fn=<MulBackward0>)
torch.Size([1, 3, 512])


## Feed Forward

In [131]:
class FeedForward(nn.Module):
    def __init__(self, d_model, inner_dim=4*d_model, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.inner_dim = inner_dim
        self.l1 = nn.Linear(d_model, inner_dim)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(inner_dim, d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        return self.dropout(self.l2(self.relu(self.l1(x))))

## Encoder

In [132]:
class Encoder(nn.Module):
    def __init__(self, d_model, heads, inner_dim, dropout=0.1):
        super().__init__()
        
        self.d_model = d_model
        self.heads = heads
        self.inner_dim = inner_dim

        self.attention = MultiHeadAttention(d_model=self.d_model, heads=self.heads)
        self.layer_norm1 = nn.LayerNorm(self.d_model)
        self.layer_norm2 = nn.LayerNorm(self.d_model)
        self.ffn = FeedForward(self.d_model, inner_dim=4*self.d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, src_mask=None):
        y1 = self.attention(x, x, x, mask = src_mask)
        x = x + y1
        x = self.layer_norm1(x)
        
        y2 = self.ffn(x)
        x = x + y2
        encoder_output = self.layer_norm2(x)
        
        return encoder_output

input = "We are friends"
input_ids = tokenize(input, vocab)
input_embeddings = embedding(input_ids) * math.sqrt(d_model)
positional_encoding = positionalEncoding(input_embeddings, d_model)
x = input_embeddings + positional_encoding
dropout = nn.Dropout(p=0.1)
x = dropout(x)

encoder = Encoder(d_model=512, heads=8, inner_dim=2048)
encoder_output = encoder(x)   # (1, 3, 512)
print(encoder_output)
print(encoder_output.shape)

tensor([[[ 0.1140, -1.6504, -1.0796,  ...,  0.8825, -0.0107, -0.7957],
         [ 0.5929, -1.0266, -2.0543,  ...,  0.0375, -0.2036, -0.0611],
         [ 0.0303, -0.9900, -0.9896,  ..., -0.7503,  0.0853,  0.7605]]],
       grad_fn=<NativeLayerNormBackward0>)
torch.Size([1, 3, 512])


## Decoder block

In [133]:
target_sentence = "हम दोस्त हैं"
target_ids = tokenize(target_sentence, vocab)

def shift_right(target_ids):
    batch_size = target_ids.shape[0]
    sos = vocab['<sos>']
    sos_tensor = torch.full((batch_size, 1), sos, device=target_ids.device, dtype=torch.long)
    return torch.cat([sos_tensor, target_ids[:, :-1]], dim=1)

decoder_input = shift_right(target_ids)
decoder_labels = target_ids

print(decoder_input)   # <sos> हम दोस्त
print(decoder_labels)  # हम दोस्त हैं

tensor([[2, 7, 8]])
tensor([[7, 8, 9]])


In [134]:
class Decoder(nn.Module):
    def __init__(self, d_model, heads, inner_dim=4*d_model, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.heads = heads
        self.inner_dim = inner_dim
        
        self.masked_attention = MultiHeadAttention(d_model=self.d_model, heads=self.heads)
        self.layer_norm1 = nn.LayerNorm(self.d_model)
        self.cross_attention = MultiHeadAttention(self.d_model, self.heads)
        self.layer_norm2 = nn.LayerNorm(self.d_model)
        self.ffn = FeedForward(d_model=self.d_model)
        self.layer_norm3 = nn.LayerNorm(self.d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, encoder_output, tgt_mask=None, src_mask=None):
        y1 = self.masked_attention(x, x, x, mask=tgt_mask)
        x = x + y1
        x = self.layer_norm1(x)
        
        y2 = self.cross_attention(x, encoder_output, encoder_output, mask=src_mask)
        x = x + y2
        x = self.layer_norm2(x)
        
        y3 = self.ffn(x)
        x = x + y3
        decoder_output = self.layer_norm3(x)
        
        return decoder_output

decoder_input_embeddings = embedding(decoder_input) * math.sqrt(d_model)
decoder_positional_encoding = positionalEncoding(decoder_input_embeddings, d_model)
x = decoder_input_embeddings + decoder_positional_encoding
x = dropout(x)

decoder = Decoder(d_model, heads)
decoder_output = decoder(x, encoder_output)
print(decoder_output)
print(decoder_output.shape)

tensor([[[ 0.2399, -0.0883, -1.1978,  ...,  0.7475, -0.8898,  0.0947],
         [ 0.1491, -1.6535, -0.3827,  ...,  1.7874,  0.3774,  0.7398],
         [-1.0974, -2.1587,  2.0471,  ...,  1.8877,  0.7828, -1.1030]]],
       grad_fn=<NativeLayerNormBackward0>)
torch.Size([1, 3, 512])


## Transformer

In [135]:
class Transformer(nn.Module):
    def __init__(self, d_model, heads, dropout, vocab_size, num_layers, embedding=None):
        super().__init__()
        self.d_model = d_model
        self.heads = heads
        self.dropout = dropout
        self.vocab_size = vocab_size
        self.inner_dim = 4*self.d_model
        self.num_layers = num_layers
        # self.embedding = self.embedding
        
        self.encoder_layers = nn.ModuleList([
            Encoder(self.d_model, self.heads, self.inner_dim, dropout=self.dropout)
            for _ in range(self.num_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            Decoder(self.d_model, self.heads, self.inner_dim, dropout=self.dropout) 
            for _ in range(self.num_layers)
        ])      
        self.linear = nn.Linear(self.d_model, self.vocab_size, bias=True)
        
        if embedding is not None:  # tying weighs of linear with embedding
            assert tuple(embedding.weight.shape) == (self.vocab_size, self.d_model), \
                f"Embedding weight shape {tuple(embedding.weight.shape)} doesn't match (vocab_size, d_model)"
            
            self.linear.weight = embedding.weight
            if self.linear.bias is not None:
                nn.init.zeros_(self.linear.bias)
            
    def forward(self, x_encoder, x_decoder, tgt_mask=None, src_mask=None):
        
        # Encoder stack
        x = x_encoder
        for layer in self.encoder_layers:
            x = layer(x, src_mask=src_mask)
        encoder_output = x
        
        # decoder stack
        y = x_decoder
        for layer in self.decoder_layers:
            y = layer(y, encoder_output, tgt_mask=tgt_mask, src_mask=src_mask)
        decoder_output = y
        
        logits = self.linear(decoder_output)
        
        return logits

d_model = 512
heads = 8
dropout = 0.1
vocab_size = len(vocab)
num_layers = 6 
embedding=embedding
model = Transformer(d_model, heads, dropout, vocab_size, num_layers, embedding)
print(model)


Transformer(
  (encoder_layers): ModuleList(
    (0-5): 6 x Encoder(
      (attention): MultiHeadAttention(
        (W_Q): Linear(in_features=512, out_features=512, bias=True)
        (W_K): Linear(in_features=512, out_features=512, bias=True)
        (W_V): Linear(in_features=512, out_features=512, bias=True)
        (attention): ScaledDotProductAttention()
        (z): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (layer_norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (ffn): FeedForward(
        (l1): Linear(in_features=512, out_features=2048, bias=True)
        (relu): ReLU()
        (l2): Linear(in_features=2048, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (decoder_layers): ModuleList(
    (0-5): 6 x Decoder(
      (mask

In [136]:
%pip install -q torchinfo

In [137]:
from torchinfo import summary
# encoder and decoder inputs: (batch, seq_len, d_model)
summary(model, input_size=[(1, 3, 512), (1, 3, 512)], col_names=("input_size","output_size","num_params"))

Layer (type:depth-idx)                                  Input Shape               Output Shape              Param #
Transformer                                             [1, 3, 512]               [1, 3, 10]                --
├─ModuleList: 1-1                                       --                        --                        --
│    └─Encoder: 2-1                                     [1, 3, 512]               [1, 3, 512]               --
│    │    └─MultiHeadAttention: 3-1                     [1, 3, 512]               [1, 3, 512]               1,050,624
│    │    └─LayerNorm: 3-2                              [1, 3, 512]               [1, 3, 512]               1,024
│    │    └─FeedForward: 3-3                            [1, 3, 512]               [1, 3, 512]               2,099,712
│    │    └─LayerNorm: 3-4                              [1, 3, 512]               [1, 3, 512]               1,024
│    └─Encoder: 2-2                                     [1, 3, 512]               [1, 3

## Training

In [138]:
def train_one_epoch(model, iterator, loss_fn, optimizer, scheduler, d_model, dropout, device):
    model.train()
    epoch_loss = 0.0
    num_batches = 0
    
    for batch in iterator:
        src_input_ids, tgt_input_ids = batch
    
        src_input_ids = src_input_ids.to(device)
        tgt_input_ids = tgt_input_ids.to(device)
        # src_input_ids = torch.cat([tokenize(s, vocab) for s in src_sentences], dim=0).to(device)  # (batch, seq_len)
        # tgt_input_ids = torch.cat([tokenize(t, vocab) for t in tgt_sentences], dim=0).to(device)
        # src_input_ids = tokenize(src_sentences, vocab).to(device)
        # tgt_input_ids = tokenize(tgt_sentences, vocab).to(device)
        
        tgt_labels = tgt_input_ids.clone()
        
        src_input_embeddings = embedding(src_input_ids) * math.sqrt(d_model)
        src_positional_encoding = positionalEncoding2(src_input_embeddings, d_model)
        x_encoder = src_input_embeddings + src_positional_encoding
        x_encoder = F.dropout(x_encoder, p=dropout, training=model.training)
        
        tgt_input_ids = shift_right(tgt_input_ids)
        tgt_input_embeddings = embedding(tgt_input_ids) * math.sqrt(d_model)
        tgt_positional_encoding = positionalEncoding2(tgt_input_embeddings, d_model)
        x_decoder = tgt_input_embeddings + tgt_positional_encoding
        x_decoder = F.dropout(x_decoder, p=dropout, training=model.training)
        
        src_mask = make_padding_mask(src_input_ids, pad_idx=vocab["<pad>"])
        tgt_pad_mask = make_padding_mask(tgt_input_ids, pad_idx=vocab["<pad>"])
        causal = make_causal_mask(tgt_input_ids.shape[1], device=device)
        tgt_mask = combine_padding_and_causal(tgt_pad_mask, causal)
        
        optimizer.zero_grad()
        logits = model(x_encoder, x_decoder, tgt_mask=tgt_mask, src_mask=src_mask)
        
        loss = loss_fn(logits.view(-1, logits.size(-1)), tgt_labels.view(-1))
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        
        epoch_loss += loss.item()
        num_batches += 1
        
    return epoch_loss / num_batches


In [139]:
# source_sentences = ["We are friends","I love you","Hello how are you","What is your name","The cat is on the mat","I am a student","She is my sister","The book is on the table","Good morning","Thank you very much","I am from India","How old are you","Where do you live","This is my house","The weather is nice today","I want to eat food","She is reading a book","They are playing cricket","We go to school","He is a doctor","My father is tall","The dog is barking","Please help me","I am happy","See you later"]
# target_sentences = ["हम दोस्त हैं <eos>","मैं तुमसे प्यार करता हूं <eos>","नमस्ते आप कैसे हैं <eos>","आपका नाम क्या है <eos>","बिल्ली चटाई पर है <eos>","मैं एक छात्र हूं <eos>","वह मेरी बहन है <eos>","किताब मेज पर है <eos>","शुभ प्रभात <eos>","धन्यवाद बहुत बहुत <eos>","मैं भारत से हूं <eos>","आप कितने साल के हैं <eos>","आप कहां रहते हैं <eos>","यह मेरा घर है <eos>","आज मौसम अच्छा है <eos>","मैं खाना खाना चाहता हूं <eos>","वह किताब पढ़ रही है <eos>","वे क्रिकेट खेल रहे हैं <eos>","हम स्कूल जाते हैं <eos>","वह डॉक्टर है <eos>","मेरे पिता लंबे हैं <eos>","कुत्ता भौंक रहा है <eos>","कृपया मेरी मदद करें <eos>","मैं खुश हूं <eos>","बाद में मिलते हैं <eos>"]

In [140]:
from datasets import load_dataset
dataset = load_dataset("cfilt/iitb-english-hindi")          # or: load_dataset("cfilt/iitb-english-hindi", "default")
# Take first 2000 clean pairs
small_data = dataset["train"].select(range(2000))
source_sentences = [ex["translation"]["en"] for ex in small_data]
target_sentences = [ex["translation"]["hi"] + " <eos>" for ex in small_data]

In [141]:
sentences = source_sentences + target_sentences
vocab = build_vocab(sentences)
vocab_size = len(vocab)
id2word = {idx: word for word, idx in vocab.items()}

embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
print(f"Vocab size: {vocab_size} | Embedding: {embedding.weight.shape}")


Vocab size: 1542 | Embedding: torch.Size([1542, 512])


In [142]:
class SentencePairDataset(Dataset):
    def __init__(self, src_sentences, tgt_sentences):
        assert len(src_sentences) == len(tgt_sentences)
        self.pairs = list(zip(src_sentences, tgt_sentences))
    
    def __len__(self):
        return len(self.pairs)
    
    def __getitem__(self, idx):
        return self.pairs[idx]

def sentence_collate(batch):
    # return tuple(zip(*batch))    # (list(src), list(tgt))
    src_list, tgt_list = zip(*batch)
    
    # Pad src
    src_lens = [len(tokenize(s, vocab)[0]) for s in src_list]  # lengths
    max_src = max(src_lens)
    src_padded = torch.full((len(batch), max_src), vocab['<pad>'], dtype=torch.long)
    for i, s in enumerate(src_list):
        ids = tokenize(s, vocab)[0]  # (seq_len,)
        src_padded[i, :len(ids)] = ids
    
    # Pad tgt (with <eos> already in strings)
    tgt_lens = [len(tokenize(t, vocab)[0]) for t in tgt_list]
    max_tgt = max(tgt_lens)
    tgt_padded = torch.full((len(batch), max_tgt), vocab['<pad>'], dtype=torch.long)
    for i, t in enumerate(tgt_list):
        ids = tokenize(t, vocab)[0]
        tgt_padded[i, :len(ids)] = ids
    
    return src_padded, tgt_padded  # now returns tensors

In [143]:
class NoamScheduler:
    def __init__(self, optimizer, d_model, warmup_steps=4000):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup = warmup_steps
        self._step = 0

    def step(self):
        self._step += 1
        lr = (self.d_model ** -0.5) * min(self._step ** -0.5,
                                         self._step * (self.warmup ** -1.5))
        for p in self.optimizer.param_groups:
            p['lr'] = lr
        return lr

In [144]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

dataset = SentencePairDataset(source_sentences, target_sentences)
iterator = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=sentence_collate
)

loss_fn = nn.CrossEntropyLoss(ignore_index=vocab['<pad>'], label_smoothing=0.1)
dropout = 0.1
d_model = 512
heads = 8
num_layers = 6
model = Transformer(d_model, heads, dropout, vocab_size=len(vocab), num_layers=num_layers, embedding=embedding)
model = model.to(device)

embedding = embedding.to(device)
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
if embedding.weight.dim() > 1:
    nn.init.xavier_uniform_(embedding.weight)
    
optimizer = torch.optim.Adam(
    list(model.parameters()) + list(embedding.parameters()),
    lr=0.0, betas=(0.9, 0.98), eps=1e-9
)
scheduler = NoamScheduler(optimizer, d_model=d_model, warmup_steps=2000)

N_EPOCHS = 100

for epoch in range(N_EPOCHS):
    train_loss = train_one_epoch(
        model, iterator, 
        loss_fn=loss_fn, optimizer=optimizer, scheduler=scheduler,
        d_model=d_model, dropout=dropout, 
        device=device)
    print(f'Epoch: {epoch+1:02} | Train Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')

  return disable_fn(*args, **kwargs)


Epoch: 01 | Train Loss: 6.419 | Train PPL: 613.659
Epoch: 02 | Train Loss: 5.216 | Train PPL: 184.108
Epoch: 03 | Train Loss: 4.092 | Train PPL:  59.868
Epoch: 04 | Train Loss: 3.363 | Train PPL:  28.887
Epoch: 05 | Train Loss: 2.963 | Train PPL:  19.363
Epoch: 06 | Train Loss: 2.726 | Train PPL:  15.265
Epoch: 07 | Train Loss: 2.588 | Train PPL:  13.303
Epoch: 08 | Train Loss: 2.452 | Train PPL:  11.609
Epoch: 09 | Train Loss: 2.403 | Train PPL:  11.056
Epoch: 10 | Train Loss: 2.376 | Train PPL:  10.763
Epoch: 11 | Train Loss: 2.324 | Train PPL:  10.219
Epoch: 12 | Train Loss: 2.264 | Train PPL:   9.626
Epoch: 13 | Train Loss: 2.356 | Train PPL:  10.553
Epoch: 14 | Train Loss: 2.287 | Train PPL:   9.850
Epoch: 15 | Train Loss: 2.268 | Train PPL:   9.660
Epoch: 16 | Train Loss: 2.235 | Train PPL:   9.350
Epoch: 17 | Train Loss: 2.221 | Train PPL:   9.216
Epoch: 18 | Train Loss: 2.296 | Train PPL:   9.939
Epoch: 19 | Train Loss: 2.354 | Train PPL:  10.524
Epoch: 20 | Train Loss: 2.581 |

In [145]:
torch.save(model.state_dict(), "pt2.pt")

## Inference

In [146]:


model.eval()

def translate(src_sentence, model, embedding, vocab, positionalEncoding, 
              make_padding_mask, make_causal_mask, combine_padding_and_causal, 
              device, max_len=50):
    
    with torch.no_grad():
        # 1. Encode source (once)
        src_ids = tokenize(src_sentence, vocab).to(device)
        src_emb = embedding(src_ids) * math.sqrt(d_model)
        src_pe = positionalEncoding(src_emb, d_model).to(device)
        x_encoder = src_emb + src_pe
        src_mask = make_padding_mask(src_ids, pad_idx=vocab["<pad>"])

        # 2. Start decoding with <sos>
        tgt_ids = torch.tensor([[vocab['<sos>']]], device=device)

        for step in range(max_len):
            # Decoder input
            tgt_emb = embedding(tgt_ids) * math.sqrt(d_model)
            tgt_pe = positionalEncoding(tgt_emb, d_model).to(device)
            x_decoder = tgt_emb + tgt_pe

            # Masks (using your exact functions)
            tgt_pad_mask = make_padding_mask(tgt_ids, pad_idx=vocab["<pad>"])
            causal = make_causal_mask(tgt_ids.shape[1], device=device)
            tgt_mask = combine_padding_and_causal(tgt_pad_mask, causal)

            # Forward
            logits = model(x_encoder, x_decoder, tgt_mask=tgt_mask, src_mask=src_mask)

            # Softmax + argmax (exactly as in the paper diagram)
            probs = torch.softmax(logits[0, -1, :], dim=-1)      # (vocab_size,)
            next_id = torch.argmax(probs).item()
            next_token = torch.tensor([[next_id]], device=device)

            # Optional nice debug print
            # print(f"Step {step+1:2d} → {id2word[next_id]:<10} | Conf: {probs[next_id]:.4f}")

            # Append
            tgt_ids = torch.cat([tgt_ids, next_token], dim=1)

            if next_id == vocab['<eos>']:
                break

        # Convert to sentence
        translated_tokens = [id2word[tok.item()] for tok in tgt_ids[0][1:]]  # skip <sos>
        translated = ' '.join(translated_tokens).replace('<eos>', '').strip()

        return translated

test_sentence = "We are friends"
print("Source      :", test_sentence)
print("Translation :", translate(test_sentence, model, embedding, vocab, 
                                 positionalEncoding2, make_padding_mask, 
                                 make_causal_mask, combine_padding_and_causal, 
                                 device, max_len=30))

Source      : We are friends
Translation : पिछली पिछली पिछली पिछली पिछली पिछली पिछली पिछली पिछली पिछली पिछली पिछली पिछली पिछली पिछली पिछली पिछली पिछली पिछली पिछली पिछली पिछली पिछली पिछली पिछली पिछली पिछली पिछली पिछली पिछली


In [None]:
checkpoint = torch.load("pt2.pt", map_location=device)
model.load_state_dict(checkpoint)
model = model.to(device)