# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import math
from torch.utils.data import Dataset, DataLoader
from collections import Counter

# Positional Encoding with dropout
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to embeddings, then apply dropout.

    A table of shape ``(1, max_len, d_model)`` is precomputed and stored as the
    non-trainable buffer ``pe``: even feature columns hold sines and odd
    feature columns hold cosines of ``position * div_term``.

    Args:
        d_model: Size of the feature (last) dimension. May be odd.
        dropout: Dropout probability applied to the summed output.
        max_len: Number of positions in the precomputed table.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -math.log(10000.0) / d_model)
        angles = position * div_term
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column, so
        # trim the angle table to the number of cosine slots.
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:, :x.size(1)])``.

        Args:
            x: Tensor whose dimension 1 is the sequence position and whose
                last dimension is ``d_model``.

        Raises:
            ValueError: If the sequence is longer than the precomputed table.
        """
        seq_len = x.size(1)
        table_len = self.pe.size(1)
        if seq_len > table_len:
            # Fail early with a clear message instead of a broadcast error
            # (or a silent broadcast when the table has a single row).
            raise ValueError(
                f"sequence length {seq_len} exceeds positional encoding max_len {table_len}"
            )
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)

# Scaled Dot-Product Attention
def scaled_dot_product_attention(query, key, value, mask=None):
    """Compute softmax(Q K^T / sqrt(d_k)) V.

    Args:
        query, key, value: Tensors whose last two dimensions are
            (positions, features); ``d_k`` is taken from ``query``'s last dim.
        mask: Optional tensor broadcastable to the score matrix. Positions
            where ``mask == 0`` are filled with -1e9 before the softmax.

    Returns:
        Tuple ``(output, attn)``: the weighted values and the attention weights.
    """
    scale = math.sqrt(query.size(-1))
    logits = (query @ key.transpose(-2, -1)) / scale
    if mask is not None:
        logits = logits.masked_fill(mask == 0, -1e9)
    weights = logits.softmax(dim=-1)
    return weights @ value, weights

# Multi-Head Attention (corrected)
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``scaled_dot_product_attention``.

    Inputs are projected with separate linear layers, split into
    ``num_heads`` heads of width ``d_k = d_model // num_heads``, attended
    per head, merged back to ``d_model`` and passed through ``w_o``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        # Projection layers are created in q, k, v, o order.
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def _split_heads(self, projected, batch_size, length):
        """Reshape (batch, length, d_model) into (batch, heads, length, d_k)."""
        per_head = projected.view(batch_size, length, self.num_heads, self.d_k)
        return per_head.transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Return ``(projected_output, attention_weights)``.

        A 5-D mask has its dimension 1 squeezed; a 4-D mask whose head
        dimension differs from ``num_heads`` is expanded along that dimension.
        Any other mask is passed through unchanged.
        """
        batch_size, seq_len = query.size(0), query.size(1)
        q = self._split_heads(self.w_q(query), batch_size, seq_len)
        k = self._split_heads(self.w_k(key), batch_size, -1)
        v = self._split_heads(self.w_v(value), batch_size, -1)

        if mask is not None:
            rank = mask.dim()
            if rank == 5:
                # (batch, 1, 1, len_q, len_k) -> drop the extra dimension
                mask = mask.squeeze(1)
            elif rank == 4 and mask.size(1) != self.num_heads:
                mask = mask.expand(-1, self.num_heads, -1, -1)

        context, attn = scaled_dot_product_attention(q, k, v, mask)
        merged = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.w_o(merged), attn

# Feed Forward Network
class FeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear(d_model -> d_ff), ReLU, Linear(d_ff -> d_model)."""

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        """Apply the expand / ReLU / project sequence to ``x``."""
        hidden = self.linear1(x)
        activated = hidden.relu()
        return self.linear2(activated)

# Encoder Layer
class EncoderLayer(nn.Module):
    """One encoder block: self-attention and feed-forward, each followed by
    dropout, a residual connection and LayerNorm (post-norm ordering)."""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run self-attention then the feed-forward sublayer on ``x``."""
        # Attention weights (second return value) are not needed here.
        attended = self.mha(x, x, x, mask)[0]
        x = self.norm1(x + self.dropout(attended))
        return self.norm2(x + self.dropout(self.ff(x)))

# Decoder Layer
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, cross-attention over the
    encoder output, and a feed-forward sublayer. Each sublayer is followed by
    dropout, a residual connection and LayerNorm."""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_mha = MultiHeadAttention(d_model, num_heads)
        self.cross_mha = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        """Apply the three sublayers in order and return the result.

        ``tgt_mask`` restricts self-attention; ``src_mask`` restricts
        cross-attention, where queries come from the decoder state and
        keys/values from ``enc_output``.
        """
        self_ctx = self.self_mha(x, x, x, tgt_mask)[0]
        x = self.norm1(x + self.dropout(self_ctx))
        cross_ctx = self.cross_mha(x, enc_output, enc_output, src_mask)[0]
        x = self.norm2(x + self.dropout(cross_ctx))
        return self.norm3(x + self.dropout(self.ff(x)))

# Transformer Model
class Transformer(nn.Module):
    """Encoder-decoder Transformer mapping source token ids to target-vocabulary logits.

    Source and target have separate embeddings but share one
    ``PositionalEncoding`` module; ``max_len`` is the length of its table.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=128, num_heads=8, num_layers=4, d_ff=256, max_len=15, dropout=0.1):
        super().__init__()
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, dropout, max_len)
        # All encoder layers are built before any decoder layer.
        encoder_stack = [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        self.encoder_layers = nn.ModuleList(encoder_stack)
        decoder_stack = [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        self.decoder_layers = nn.ModuleList(decoder_stack)
        self.linear = nn.Linear(d_model, tgt_vocab_size)
        self.d_model = d_model

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Encode ``src``, decode ``tgt`` against it, and return per-position logits."""
        scale = math.sqrt(self.d_model)
        src_emb = self.src_embedding(src) * scale
        tgt_emb = self.tgt_embedding(tgt) * scale
        # Positional encoding (with dropout) is applied to source first, then target.
        memory = self.pos_enc(src_emb)
        hidden = self.pos_enc(tgt_emb)

        for enc_layer in self.encoder_layers:
            memory = enc_layer(memory, src_mask)

        for dec_layer in self.decoder_layers:
            hidden = dec_layer(hidden, memory, src_mask, tgt_mask)

        return self.linear(hidden)

# Mask generation functions
def generate_square_subsequent_mask(sz, device):
    """Return a (sz, sz) lower-triangular matrix of ones on ``device``.

    Entry (i, j) is 1 when j <= i and 0 otherwise.
    """
    return torch.ones(sz, sz, device=device).tril()

def create_padding_mask(seq, pad_idx):
    """Return a boolean mask that is True where ``seq`` differs from ``pad_idx``.

    Two singleton dimensions are inserted after dimension 0, so a
    (batch, length) input yields a (batch, 1, 1, length) mask.
    """
    keep = seq.ne(pad_idx)
    return keep[:, None, None]

# Toy data for simple sentence-pair data (English to French)
# Parallel (English, French) sentence pairs used for both training and evaluation.
data = [
    ("hello", "bonjour"),
    ("thank you", "merci"),
    ("good morning", "bon matin"),
    ("good night", "bonne nuit"),
    ("yes", "oui"),
    ("no", "non"),
    ("how are you", "comment ca va"),
    ("i am fine", "je vais bien"),
    ("what is your name", "comment tu t'appelles"),
    ("my name is", "je m'appelle"),
    ("goodbye", "au revoir"),
    ("please", "s'il vous plait"),
    ("excuse me", "excusez moi"),
    ("i love you", "je t'aime"),
    ("see you later", "a bientot"),
]

# Build vocabularies
special_tokens = ['<pad>', '<sos>', '<eos>', '<unk>']
all_src_words = []
all_tgt_words = []
for src_sent, tgt_sent in data:
    all_src_words.extend(src_sent.split())
    all_tgt_words.extend(tgt_sent.split())


def _build_vocab(words, specials):
    """Map tokens to ids: specials first, then words in first-occurrence order.

    Iterating the word list directly (instead of a ``set``) keeps the ids
    identical across interpreter runs; set order of strings depends on hash
    randomisation, which would make saved models and printed ids
    irreproducible.
    """
    vocab = {token: idx for idx, token in enumerate(specials)}
    for word in words:
        # setdefault leaves the id of an already-seen word untouched.
        vocab.setdefault(word, len(vocab))
    return vocab


src_vocab = _build_vocab(all_src_words, special_tokens)
tgt_vocab = _build_vocab(all_tgt_words, special_tokens)
src_vocab_size = len(src_vocab)
tgt_vocab_size = len(tgt_vocab)
# Special-token ids are read from the target vocabulary; both vocabularies
# start with the same special_tokens list, so the ids coincide for the source.
pad_idx = tgt_vocab['<pad>']
sos_idx = tgt_vocab['<sos>']
eos_idx = tgt_vocab['<eos>']
unk_idx = tgt_vocab['<unk>']
print(f"Source vocabulary size: {src_vocab_size}")
print(f"Target vocabulary size: {tgt_vocab_size}")

# Dataset
class TranslationDataset(Dataset):
    """Dataset of (source, target) sentence pairs encoded as fixed-length id tensors.

    Each sentence is whitespace-tokenised, wrapped in ``<sos>`` / ``<eos>``,
    mapped through its vocabulary (unknown words become ``<unk>``) and
    right-padded with ``<pad>`` to ``max_len``.

    Args:
        data: Sequence of ``(source_sentence, target_sentence)`` string pairs.
        src_vocab: Token -> id mapping for the source side.
        tgt_vocab: Token -> id mapping for the target side.
        max_len: Length of every returned tensor.
    """

    def __init__(self, data, src_vocab, tgt_vocab, max_len=15):
        self.data = data
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def _encode(self, sentence, vocab):
        """Encode one sentence into a ``max_len``-long LongTensor of ids."""
        # Truncate the words *before* adding the markers so that an
        # over-long sentence still ends with <eos> instead of losing it.
        words = sentence.split()[: max(self.max_len - 2, 0)]
        tokens = ['<sos>', *words, '<eos>']
        unk = vocab['<unk>']
        ids = [vocab.get(token, unk) for token in tokens]
        ids.extend([vocab['<pad>']] * (self.max_len - len(ids)))
        # Final slice only matters for max_len < 2; it keeps the length contract.
        return torch.tensor(ids[: self.max_len], dtype=torch.long)

    def __getitem__(self, idx):
        """Return ``(src_ids, tgt_ids)`` for the pair at ``idx``."""
        src, tgt = self.data[idx]
        return self._encode(src, self.src_vocab), self._encode(tgt, self.tgt_vocab)

# Wrap the toy corpus; every example is padded/truncated to 15 token ids.
dataset = TranslationDataset(data, src_vocab, tgt_vocab, max_len=15)
# Batches of 4 pairs, reshuffled each epoch (the last batch may be smaller).
loader = DataLoader(dataset, batch_size=4, shuffle=True)

# Initialize model
# max_len here sizes the positional-encoding table and matches the dataset's
# max_len above, so no input sequence can be longer than the table.
model = Transformer(
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    d_model=128,
    num_heads=8,
    num_layers=4,
    d_ff=256,
    max_len=15,
    dropout=0.1
)
# Prefer the GPU when one is available; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
model.to(device)

# Loss function and optimizer
# Padded target positions are excluded from the loss via ignore_index.
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
optimizer = optim.Adam(model.parameters(), lr=0.0005, betas=(0.9, 0.98), eps=1e-9)
# scheduler.step() is called once per epoch below, so the learning rate is
# multiplied by 0.5 every 30 epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)

# Training loop
num_epochs = 100
print("Starting training...")
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch_idx, (src, tgt) in enumerate(loader):
        src = src.to(device)
        tgt = tgt.to(device)
        # Teacher forcing: the decoder sees positions [0, L-1) and is trained
        # to predict positions [1, L), i.e. the next token at every step.
        tgt_input = tgt[:, :-1]
        tgt_output = tgt[:, 1:]
        tgt_seq_len = tgt_input.size(1)
        tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device)
        # pad_idx comes from the target vocabulary; it is also valid for the
        # source because both vocabularies start with the same special tokens.
        src_padding_mask = create_padding_mask(src, pad_idx).to(device)
        tgt_padding_mask = create_padding_mask(tgt_input, pad_idx).to(device)
        # (1, L, L) lower-triangular mask times (B, 1, 1, L) padding mask
        # broadcasts to (B, 1, L, L): a key is visible only if it is not in
        # the future and is not padding.
        tgt_mask = tgt_mask.unsqueeze(0) * tgt_padding_mask
        output = model(src, tgt_input, src_padding_mask, tgt_mask)
        # Flatten batch and time so each position is one classification example.
        loss = criterion(output.reshape(-1, tgt_vocab_size), tgt_output.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm to 1.0 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        total_loss += loss.item()
    scheduler.step()
    # Mean of the per-batch losses for this epoch.
    avg_loss = total_loss / len(loader)
    # Report progress every 10th epoch.
    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, LR: {scheduler.get_last_lr()[0]:.6f}')

# Sample Scaled Dot-Product Attention Demonstration
# These demos are independent of the trained model. Note that they rebind the
# module-level names `output` and `attn`.
print("\n" + "="*50)
print("ATTENTION MECHANISM DEMONSTRATIONS")
print("="*50)
# One batch of two 2-dimensional one-hot vectors, used as query, key and value
# at once, so each position scores highest against itself.
dummy_q = torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])
dummy_k = dummy_q
dummy_v = dummy_q
output, attn = scaled_dot_product_attention(dummy_q, dummy_k, dummy_v)
print('Sample Attention weights:')
print(attn)

# Sample Multi-Head Attention Demonstration
# A freshly initialised (untrained) module, left on the CPU, applied as
# self-attention to random input of shape (batch=2, seq=5, d_model=128).
mha = MultiHeadAttention(d_model=128, num_heads=8)
dummy_x = torch.rand(2, 5, 128)
output, attn = mha(dummy_x, dummy_x, dummy_x)
print(f'\nSample Multi-Head Attention output shape: {output.shape}')
print(f'Sample Multi-Head Attention weights shape: {attn.shape}')

# Evaluation Examples
# Runs the trained model on the first few training pairs. The decoder is fed
# the ground-truth target (teacher forcing), so this is not free-running
# generation; see translate() for greedy decoding.
print("\n" + "="*50)
print("EVALUATION RESULTS")
print("="*50)
# Reverse lookups: id -> token.
inv_src_vocab = {v: k for k, v in src_vocab.items()}
inv_tgt_vocab = {v: k for k, v in tgt_vocab.items()}
# eval() disables dropout; no_grad() skips gradient tracking.
model.eval()
with torch.no_grad():
    for test_idx in range(min(5, len(dataset))):
        src_tensor, tgt_tensor = dataset[test_idx]
        # Add a batch dimension of size 1.
        src = src_tensor.unsqueeze(0).to(device)
        tgt = tgt_tensor.unsqueeze(0).to(device)
        tgt_input = tgt[:, :-1]
        tgt_seq_len = tgt_input.size(1)
        # Masks are built exactly as in the training loop.
        tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device)
        src_padding_mask = create_padding_mask(src, pad_idx).to(device)
        tgt_padding_mask = create_padding_mask(tgt_input, pad_idx).to(device)
        tgt_mask = tgt_mask.unsqueeze(0) * tgt_padding_mask
        output = model(src, tgt_input, src_padding_mask, tgt_mask)
        # Most likely token at every decoder position of the single example.
        pred = output.argmax(dim=-1)[0]
        # Only <pad> ids are dropped, so <sos>/<eos> remain in the printed text.
        # NOTE(review): predictions at padded target positions are not filtered
        # unless the model happens to predict <pad> there, so extra tokens may
        # appear after <eos> in "Predicted".
        src_words = [inv_src_vocab.get(id.item(), '<unk>') for id in src[0] if id.item() not in [pad_idx]]
        tgt_words = [inv_tgt_vocab.get(id.item(), '<unk>') for id in tgt[0] if id.item() not in [pad_idx]]
        pred_words = [inv_tgt_vocab.get(id.item(), '<unk>') for id in pred if id.item() not in [pad_idx]]
        print(f"\nExample {test_idx + 1}:")
        print(f"Source: {' '.join(src_words)}")
        print(f"Target: {' '.join(tgt_words)}")
        print(f"Predicted: {' '.join(pred_words)}")

# Inference function
def translate(model, src_sentence, src_vocab, tgt_vocab, device, max_len=15):
    """Greedily decode a translation of ``src_sentence``.

    The source is whitespace-tokenised, wrapped in ``<sos>`` / ``<eos>``,
    mapped through ``src_vocab`` (unknown words become ``<unk>``) and
    padded/truncated to ``max_len``. Decoding starts from ``<sos>`` and
    appends the arg-max token until ``<eos>`` is produced or ``max_len - 1``
    steps have run.

    Args:
        model: Callable as ``model(src, tgt, src_mask, tgt_mask)`` returning
            logits of shape (batch, tgt_len, vocab). It is switched to eval mode.
        src_sentence: Source text.
        src_vocab: Source token -> id mapping containing the special tokens.
        tgt_vocab: Target token -> id mapping containing the special tokens.
        device: Device on which tensors are created.
        max_len: Padded source length and cap on generated length.

    Returns:
        The generated target tokens joined by single spaces, without
        ``<sos>`` / ``<eos>``.
    """
    model.eval()
    inv_tgt_vocab = {idx: token for token, idx in tgt_vocab.items()}

    src_unk = src_vocab['<unk>']
    # Use the pad id of the vocabulary that was passed in rather than a
    # module-level global, so the function works with any vocabulary.
    src_pad = src_vocab['<pad>']
    src_tokens = ['<sos>'] + src_sentence.split() + ['<eos>']
    src_ids = [src_vocab.get(token, src_unk) for token in src_tokens]
    src_ids += [src_pad] * (max_len - len(src_ids))
    src_tensor = torch.tensor(src_ids[:max_len], dtype=torch.long).unsqueeze(0).to(device)
    # The source never changes during decoding, so build its mask once.
    src_padding_mask = create_padding_mask(src_tensor, src_pad).to(device)

    eos_id = tgt_vocab['<eos>']
    tgt_ids = [tgt_vocab['<sos>']]
    with torch.no_grad():
        for _ in range(max_len - 1):
            tgt_tensor = torch.tensor(tgt_ids, dtype=torch.long).unsqueeze(0).to(device)
            # No target padding exists here, so the causal mask alone suffices.
            tgt_mask = generate_square_subsequent_mask(len(tgt_ids), device)
            output = model(src_tensor, tgt_tensor, src_padding_mask, tgt_mask)
            # Only the logits of the last position decide the next token.
            next_token = output[0, -1].argmax().item()
            if next_token == eos_id:
                break
            tgt_ids.append(next_token)

    pred_words = [inv_tgt_vocab.get(token_id, '<unk>') for token_id in tgt_ids[1:]]
    return ' '.join(pred_words)

# Greedy-decoding demo on a few sentences taken from the training corpus.
print("\n" + "="*50)
print("INTERACTIVE TRANSLATION EXAMPLES")
print("="*50)
test_sentences = ["hello", "thank you", "good morning", "how are you"]
for sent in test_sentences:
    # Only translate sentences whose every word is in the source vocabulary;
    # anything else is reported rather than decoded through <unk>.
    if all(word in src_vocab for word in sent.split()):
        translation = translate(model, sent, src_vocab, tgt_vocab, device)
        print(f"'{sent}' -> '{translation}'")
    else:
        print(f"'{sent}' -> (contains unknown words)")

print("\nTraining completed!")