# Transformer
We will build an English-to-German translation model.

In [90]:
import math

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

In [91]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


### Multi-head Attention 

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are each passed through their own linear
    projection, split into ``num_heads`` heads of width ``d_k``, attended
    per head, merged back to ``d_model`` features and projected once more.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        """Create the four projections and the attention-weight dropout.

        Args:
            d_model: Feature size of the inputs and outputs.
            num_heads: Number of heads; must divide ``d_model`` evenly.
            dropout: Dropout probability applied to the attention weights.
        """
        super().__init__()
        assert d_model % num_heads == 0, 'd_model must be divisible by num_heads'
        self.d_model, self.num_heads = d_model, num_heads
        self.d_k = d_model // num_heads  # feature width of a single head

        # Projections for queries, keys, values and the merged output.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Attend over ``V`` using softmax(Q K^T / sqrt(d_k)).

        Positions where ``mask == 0`` get a score of -1e9 before the softmax,
        so they receive (numerically) zero attention weight.
        """
        scale = math.sqrt(self.d_k)
        scores = (Q @ K.transpose(-2, -1)) / scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = self.dropout(scores.softmax(dim=-1))
        return weights @ V

    def split_heads(self, x):
        """Reshape (batch, seq, d_model) into (batch, heads, seq, d_k)."""
        batch_size, seq_len, _ = x.size()
        per_head = x.view(batch_size, seq_len, self.num_heads, self.d_k)
        return per_head.permute(0, 2, 1, 3)

    def combine_heads(self, x):
        """Inverse of ``split_heads``: (batch, heads, seq, d_k) -> (batch, seq, d_model)."""
        batch_size, _, seq_len, _ = x.size()
        merged = x.permute(0, 2, 1, 3).contiguous()
        return merged.view(batch_size, seq_len, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """Project the inputs, attend per head and return the merged projection."""
        queries = self.split_heads(self.W_q(Q))
        keys = self.split_heads(self.W_k(K))
        values = self.split_heads(self.W_v(V))

        context = self.scaled_dot_product_attention(queries, keys, values, mask)
        return self.W_o(self.combine_heads(context))


### Position-wise Feed-Forward Networks

In [93]:
class FeedForward(nn.Module):
    """Pre-norm position-wise feed-forward sub-layer with a residual connection.

    Computes ``x + Dropout(W2(Dropout(GELU(W1(LayerNorm(x))))))``.
    """

    def __init__(self, d_model, d_ff=None, dropout=0.1):
        """Build the two linear layers, the normalisation and the dropout.

        Args:
            d_model: Feature size of the input and output.
            d_ff: Hidden width of the inner layer. Defaults to ``4 * d_model``
                when omitted (previously ``None`` was forwarded to
                ``nn.Linear`` and failed).
            dropout: Dropout probability used after the activation and after
                the second linear layer.
        """
        super().__init__()
        self.d_model = d_model
        if d_ff is None:
            d_ff = 4 * d_model
        self.d_ff = d_ff

        # GELU tends to perform better but is more computationally expensive
        # than ReLU; swap in nn.ReLU() here for a cheaper activation.
        self.gelu = nn.GELU()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.W1 = nn.Linear(d_model, d_ff, bias=True)
        self.W2 = nn.Linear(d_ff, d_model, bias=True)

    def forward(self, x):
        """Apply the feed-forward transform and add the untouched input back."""
        residual = x
        x = self.norm(x)  # pre-norm: normalise before the transform
        x = self.W1(x)
        x = self.gelu(x)
        x = self.dropout(x)
        x = self.W2(x)
        x = self.dropout(x)

        return x + residual

## Attention blocks
Each attention block applies LayerNorm first (pre-norm), then a multi-head attention layer followed by dropout, and finally adds the block's input back through a residual connection.

### Self-Attention block

In [None]:
class SelfAttentionBlock(nn.Module):
    """Pre-norm self-attention sub-layer: LayerNorm -> attention -> dropout -> residual add."""

    def __init__(self, d_model, num_heads, dropout=0.1):
        """Create the LayerNorm, the multi-head attention and the output dropout."""
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Attend from the normalised input to itself and add the raw input back."""
        normed = self.norm(x)
        # Query, key and value are all the same normalised sequence.
        attended = self.attn(normed, normed, normed, mask)
        return self.dropout(attended) + x

### Cross-Attention block

In [95]:
class CrossAttentionBlock(nn.Module):
    """Pre-norm cross-attention sub-layer: queries from ``x``, keys/values from the encoder."""

    def __init__(self, d_model, num_heads, dropout=0.1):
        """Create the LayerNorm, the multi-head attention and the output dropout."""
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output, mask=None):
        """Attend from normalised ``x`` over ``encoder_output`` and add ``x`` back.

        Only ``x`` goes through the LayerNorm; ``encoder_output`` is used
        as keys and values exactly as passed in.
        """
        queries = self.norm(x)
        attended = self.attn(queries, encoder_output, encoder_output, mask)
        attended = self.dropout(attended)
        return x + attended


## Decoder Layer and Encoder Layer

In [96]:
class DecoderLayer(nn.Module):
    """One decoder layer: self-attention, cross-attention, then feed-forward."""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        """Instantiate the three sub-layers with a shared dropout probability."""
        super().__init__()
        self.self_attn = SelfAttentionBlock(d_model, num_heads, dropout)
        self.cross_attn = CrossAttentionBlock(d_model, num_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        """Run ``x`` through the sub-layers in order.

        ``tgt_mask`` is passed to the self-attention and ``src_mask`` to the
        cross-attention over ``encoder_output``.
        """
        hidden = self.self_attn(x, tgt_mask)
        hidden = self.cross_attn(hidden, encoder_output, src_mask)
        return self.ffn(hidden)

class EncoderLayer(nn.Module):
    """One encoder layer: self-attention followed by the feed-forward sub-layer."""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        """Instantiate both sub-layers with a shared dropout probability."""
        super().__init__()
        self.self_attn = SelfAttentionBlock(d_model, num_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)

    def forward(self, x, mask=None):
        """Apply masked self-attention, then the feed-forward sub-layer."""
        attended = self.self_attn(x, mask)
        return self.ffn(attended)

### Positional Encoding

In [97]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a sequence of embeddings.

    Even feature indices hold ``sin(pos * w_i)`` and odd indices hold
    ``cos(pos * w_i)`` with ``w_i = 10000 ** (-2i / d_model)``. The table is
    stored as the non-trainable buffer ``pe`` of shape
    ``(1, max_seq_len, d_model)``.
    """

    def __init__(self, d_model, max_seq_len):
        """Precompute the encoding table.

        Args:
            d_model: Feature size of the embeddings; may be even or odd.
            max_seq_len: Longest sequence length the table covers.
        """
        super().__init__()
        position = torch.arange(max_seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_seq_len, ceil(d_model / 2))

        pe = torch.zeros(max_seq_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # An odd d_model has one fewer odd column than even columns, so trim
        # the angle table to the number of cosine slots actually available.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Raises:
            ValueError: If the sequence is longer than the precomputed table.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f'sequence length {seq_len} exceeds max_seq_len {self.pe.size(1)}'
            )
        return x + self.pe[:, :seq_len]

## Full Model

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer mapping source token ids to target-vocabulary logits."""

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, num_heads=8,
                 num_layers=6, d_ff=2048, max_seq_len=512, dropout=0.1, pad_idx=0):
        """Build embeddings, positional encoding, layer stacks and output projection.

        Args:
            src_vocab_size: Number of rows in the encoder embedding table.
            tgt_vocab_size: Number of rows in the decoder embedding table and
                width of the output logits.
            d_model: Embedding / hidden feature size.
            num_heads: Attention heads per attention block.
            num_layers: Number of encoder layers and of decoder layers.
            d_ff: Hidden width of each feed-forward sub-layer.
            max_seq_len: Longest sequence the positional encoding covers.
            dropout: Dropout probability used throughout the model.
            pad_idx: Token id treated as padding when building masks.
        """
        super().__init__()
        # Kept on the instance so forward() can scale embeddings without
        # relying on a module-level variable of the same name.
        self.d_model = d_model
        self.pad_idx = pad_idx

        self.encoder_embed = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embed = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_len)

        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )

        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )

        self.fc_out = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def _init_weights(self):
        """Xavier-initialise every parameter with more than one dimension.

        This is not invoked by ``__init__``; call it explicitly if wanted.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def generate_mask(self, src, tgt):
        """Build boolean attention masks from 2-D ``(batch, seq)`` token ids.

        Returns:
            ``(src_mask, tgt_mask)``. ``src_mask`` has shape
            ``(batch, 1, 1, src_len)`` and is False at source padding.
            ``tgt_mask`` has shape ``(batch, 1, tgt_len, tgt_len)``: it is
            lower-triangular (causal), and rows belonging to padded target
            positions are entirely False.
        """
        src_mask = (src != self.pad_idx).unsqueeze(1).unsqueeze(2)
        tgt_pad_mask = (tgt != self.pad_idx).unsqueeze(1).unsqueeze(3)
        seq_len = tgt.size(1)
        # Create the causal mask on the input's device instead of a global one.
        causal_mask = torch.tril(
            torch.ones(1, seq_len, seq_len, device=tgt.device)
        ).bool()
        return src_mask, tgt_pad_mask & causal_mask

    def forward(self, src, tgt):
        """Encode ``src``, decode ``tgt`` against it and return vocabulary logits."""
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        scale = math.sqrt(self.d_model)

        src_emb = self.encoder_embed(src) * scale
        src_emb = self.pos_encoding(src_emb)
        src_emb = self.dropout(src_emb)

        tgt_emb = self.decoder_embed(tgt) * scale
        tgt_emb = self.pos_encoding(tgt_emb)
        tgt_emb = self.dropout(tgt_emb)

        enc_output = src_emb
        for layer in self.encoder_layers:
            enc_output = layer(enc_output, src_mask)

        dec_output = tgt_emb
        for layer in self.decoder_layers:
            dec_output = layer(dec_output, enc_output, src_mask, tgt_mask)

        return self.fc_out(dec_output)
            


## Data

In [99]:
from datasets import load_dataset
import pandas as pd

# Load each IWSLT2017 English-German split separately, in train/validation/test order.
train_dataset, val_dataset, test_dataset = (
    load_dataset('iwslt2017', 'iwslt2017-en-de', trust_remote_code=True, split=split_name)
    for split_name in ('train', 'validation', 'test')
)

In [100]:
# Print a one-line summary (features and row count) for every split.
for label, split in (
    ("Train", train_dataset),
    ("Validation", val_dataset),
    ("Test", test_dataset),
):
    print(f"{label} dataset: {split}")

Train dataset: Dataset({
    features: ['translation'],
    num_rows: 206112
})
Validation dataset: Dataset({
    features: ['translation'],
    num_rows: 888
})
Test dataset: Dataset({
    features: ['translation'],
    num_rows: 8079
})


In [101]:
def prepare_data():
    """Load the IWSLT2017 en-de dataset and return its splits as DataFrames.

    Returns:
        A ``(train_df, valid_df, test_df)`` tuple, each built from the
        ``'translation'`` column of the corresponding split.
    """
    splits = load_dataset('iwslt2017', 'iwslt2017-en-de')
    frames = [
        pd.DataFrame(splits[split_name]['translation'])
        for split_name in ('train', 'validation', 'test')
    ]
    return tuple(frames)

# Materialise the train / validation / test splits as DataFrames for inspection.
train_df, val_df, test_df = prepare_data()

In [102]:
# Preview the first rows of the training split.
train_df.head()

Unnamed: 0,de,en
0,"Vielen Dank, Chris.","Thank you so much, Chris."
1,"Es ist mir wirklich eine Ehre, zweimal auf die...",And it's truly a great honor to have the oppor...
2,Ich bin wirklich begeistert von dieser Konfere...,"I have been blown away by this conference, and..."
3,"Das meine ich ernst, teilweise deshalb -- weil...","And I say that sincerely, partly because I ne..."
4,Versetzen Sie sich mal in meine Lage!,Put yourselves in my position.


In [103]:
class TranslationDataset(Dataset):
    """Dataset of tokenised English -> German sentence pairs.

    Each item is read from the ``'en'`` (source) and ``'de'`` (target)
    columns of a DataFrame and tokenised on access.
    """

    def __init__(self, dataframe, tokenizer, max_len):
        """Store the data and tokenisation settings.

        Args:
            dataframe: DataFrame with ``'en'`` and ``'de'`` text columns.
            tokenizer: Callable invoked as ``tokenizer(text, max_length=...,
                padding='max_length', truncation=True, return_tensors='pt')``
                whose result provides ``'input_ids'`` and ``'attention_mask'``.
            max_len: Length every sequence is padded or truncated to.
        """
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of sentence pairs."""
        return len(self.data)

    def _encode(self, text):
        """Tokenise one string to a fixed length of ``max_len``."""
        return self.tokenizer(
            text,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

    def __getitem__(self, idx):
        """Return token ids and attention masks for the pair at position ``idx``."""
        row = self.data.iloc[idx]
        src_enc = self._encode(row['en'])
        tgt_enc = self._encode(row['de'])

        # Drop only the leading axis (assumed to be a batch axis of size 1 —
        # TODO confirm for the tokenizer in use). A bare squeeze() would also
        # collapse the sequence axis when max_len == 1.
        return {
            'src_ids': src_enc['input_ids'].squeeze(0),
            'tgt_ids': tgt_enc['input_ids'].squeeze(0),
            'src_mask': src_enc['attention_mask'].squeeze(0),
            'tgt_mask': tgt_enc['attention_mask'].squeeze(0),
        }

In [104]:
from transformers import AutoTokenizer

# Load the pretrained "t5-small" tokenizer.
tokenizer = AutoTokenizer.from_pretrained("t5-small")

In [None]:
def load_data(batch_size=32, max_len=128):
    """Build DataLoaders for the IWSLT2017 en-de train/validation/test splits.

    Args:
        batch_size: Batch size used by all three loaders.
        max_len: Length every tokenised sentence is padded or truncated to.

    Returns:
        ``(train_loader, val_loader, test_loader, tokenizer)``. Only the
        training loader shuffles its data.
    """
    dataset = load_dataset('iwslt2017', 'iwslt2017-en-de')
    tokenizer = AutoTokenizer.from_pretrained('t5-small')

    # NOTE(review): an earlier revision wrapped the German targets in
    # '<s> ... </s>' markers before tokenising; that step is disabled here.
    # Confirm the tokenizer supplies whatever start/end markers training needs.
    loaders = []
    for split in ('train', 'validation', 'test'):
        frame = pd.DataFrame(dataset[split]['translation'])
        split_ds = TranslationDataset(frame, tokenizer, max_len)
        loaders.append(
            DataLoader(split_ds, batch_size=batch_size, shuffle=(split == 'train'))
        )

    train_loader, val_loader, test_loader = loaders
    return train_loader, val_loader, test_loader, tokenizer

## Training and testing

In [None]:
def _seq2seq_loss(model, criterion, batch, run_device):
    """Return the teacher-forced loss of ``model`` on one batch."""
    src_ids = batch['src_ids'].to(run_device)
    tgt_ids = batch['tgt_ids'].to(run_device)

    # The decoder is fed every target token except the last and is scored
    # against the same sequence shifted left by one position.
    logits = model(src_ids, tgt_ids[:, :-1])
    return criterion(
        logits.contiguous().view(-1, logits.size(-1)),
        tgt_ids[:, 1:].contiguous().view(-1),
    )


def train_model(model, train_loader, val_loader, num_epochs=10, lr=1e-5):
    """Train ``model`` with AdamW, reporting train and validation loss per epoch.

    Args:
        model: Module called as ``model(src_ids, tgt_ids)`` that returns
            logits whose last dimension is the vocabulary.
        train_loader: Iterable of batches with ``'src_ids'`` and ``'tgt_ids'``.
        val_loader: Iterable of validation batches in the same format.
        num_epochs: Number of passes over ``train_loader``.
        lr: AdamW learning rate.

    Returns:
        The AdamW optimizer, so the caller can checkpoint its state.
    """
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # targets equal to 0 are not scored
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    # Move batches to wherever the model lives instead of a module-level global.
    run_device = next(model.parameters()).device

    step = 1  # running batch counter across all epochs, used for progress logs
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for batch in train_loader:
            loss = _seq2seq_loss(model, criterion, batch, run_device)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            train_loss += loss.item()

            if step % 100 == 0:
                print(f'     Batch {step}, Train Loss: {loss.item()}')
            step += 1

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        print(f'Epoch {epoch+1}, Train Loss: {train_loss/max(len(train_loader), 1):.4f}')

        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                val_loss += _seq2seq_loss(model, criterion, batch, run_device).item()

        print(f'Validation Loss: {val_loss/max(len(val_loader), 1):.4f}\n')

    return optimizer

     Batch 100, Train Loss: nan


In [None]:
# main training loop
if __name__ == '__main__':
    train_loader, val_loader, _, tokenizer = load_data()

    model = Transformer(
        src_vocab_size=tokenizer.vocab_size,
        tgt_vocab_size=tokenizer.vocab_size,
        d_model=512,
        num_heads=8,
        num_layers=6,
        d_ff=2048,
        max_seq_len=128
    ).to(device)

    # The optimizer is created inside train_model, so no `optimizer` name
    # exists in this scope unless train_model hands it back.
    optimizer = train_model(model, train_loader, val_loader)

    checkpoint = {'model_state_dict': model.state_dict()}
    if optimizer is not None:
        # Only checkpoint optimizer state when one was actually returned.
        checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    torch.save(checkpoint, './translator.pth')

