## LSTM WITH ATTENTION (CODED FROM SCRATCH)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

#########################################
# 1) Toy Data / Vocabulary (Single Example)
#########################################

# A single parallel sentence pair is the entire training set.
src_sentence = "i am student"
tgt_sentence = "je suis etudiant"

# Vocabularies. The first three entries are the special tokens, so
# <pad>=0, <sos>=1 and <eos>=2 in both languages.
SRC_WORDS = [
    "<pad>", "<sos>", "<eos>",
    "i", "am", "a", "student",
    "you", "are", "teacher", "he", "is", "happy",
]
TGT_WORDS = [
    "<pad>", "<sos>", "<eos>",
    "je", "suis", "etudiant",
    "tu", "es", "professeur", "il", "est", "content",
]

# word -> index ("stoi") and index -> word ("itos") lookup tables.
src_stoi = dict(zip(SRC_WORDS, range(len(SRC_WORDS))))
tgt_stoi = dict(zip(TGT_WORDS, range(len(TGT_WORDS))))
src_itos = dict(enumerate(SRC_WORDS))
tgt_itos = dict(enumerate(TGT_WORDS))

def encode(sentence, stoi_dict):
    """Map each whitespace-separated word of *sentence* to its vocabulary id.

    Raises KeyError for any word missing from *stoi_dict*.
    """
    return list(map(stoi_dict.__getitem__, sentence.split()))

def add_sos_eos(seq, sos_idx, eos_idx):
    """Return a new list: *sos_idx*, then every id in *seq*, then *eos_idx*.

    *seq* may be any iterable of token ids (list, tuple, generator, ...);
    it is not modified.
    """
    return [sos_idx, *seq, eos_idx]

# Wrap each sentence as <sos> ... <eos> and map it to vocabulary ids.
src_ids = add_sos_eos(
    encode(src_sentence, src_stoi), src_stoi["<sos>"], src_stoi["<eos>"]
)
tgt_ids = add_sos_eos(
    encode(tgt_sentence, tgt_stoi), tgt_stoi["<sos>"], tgt_stoi["<eos>"]
)

# Nesting the id list inside another list adds the leading batch axis of
# size 1, so both tensors are [1, seq_len].
src_tensor = torch.tensor([src_ids], dtype=torch.long)
tgt_tensor = torch.tensor([tgt_ids], dtype=torch.long)

for label, value in (
    ("Source Sentence:", src_sentence),
    ("Encoded Source IDs:", src_ids),
    ("Target Sentence:", tgt_sentence),
    ("Encoded Target IDs:", tgt_ids),
):
    print(label, value)


Source Sentence: i am student
Encoded Source IDs: [1, 3, 4, 6, 2]
Target Sentence: je suis etudiant
Encoded Target IDs: [1, 3, 4, 5, 2]


In [None]:
#########################################
# 2) Unidirectional LSTM Encoder
#########################################

class Encoder(nn.Module):
    """Single-layer, unidirectional LSTM encoder over token ids.

    Args:
        input_dim: source vocabulary size.
        embed_dim: width of the token embeddings.
        hidden_dim: LSTM hidden-state size.
    """

    def __init__(self, input_dim, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, embed_dim)
        self.rnn = nn.LSTM(
            embed_dim,
            hidden_dim,
            batch_first=True,     # tensors are laid out [batch, seq, feature]
            bidirectional=False,  # one direction -> outputs are hidden_dim wide
        )

    def forward(self, src):
        """Encode *src* of shape [batch, src_len] (batch is 1 in this notebook).

        Returns:
            outputs: [batch, src_len, hidden_dim], the hidden state at every
                source position.
            (h, c): final hidden and cell states, each [1, batch, hidden_dim].
        """
        states, final_state = self.rnn(self.embedding(src))
        return states, final_state

#########################################
# 3) Bahdanau Attention (for Unidirectional Encoder)
#########################################

class BahdanauAttention(nn.Module):
    """Additive (Bahdanau-style) attention over encoder outputs.

    score_j = v^T tanh(W [s ; h_j] + b), where s is the decoder hidden state
    and h_j is the j-th encoder output. The scores are softmax-normalized
    over the source positions.

    Args:
        enc_hid_dim: feature size of each encoder output.
        dec_hid_dim: decoder hidden-state size, also used as the width of
            the intermediate energy layer.
    """

    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        self.attn = nn.Linear(enc_hid_dim + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias=False)

    def forward(self, decoder_hidden, encoder_outputs):
        """Return attention weights over the source positions.

        decoder_hidden: [batch, dec_hid_dim]
        encoder_outputs: [batch, src_len, enc_hid_dim]
        Returns: [batch, src_len], each row summing to 1.
        """
        src_len = encoder_outputs.shape[1]
        # Broadcast the decoder state to every source position (view, no copy).
        query = decoder_hidden.unsqueeze(1).expand(-1, src_len, -1)
        # Concatenation order matters for self.attn: decoder state first.
        features = torch.cat([query, encoder_outputs], dim=-1)
        scores = self.v(torch.tanh(self.attn(features))).squeeze(-1)
        return scores.softmax(dim=-1)

#########################################
# 4) LSTM Decoder with Attention (No Teacher Forcing)
#########################################

class Decoder(nn.Module):
    """One-step LSTM decoder with attention over the encoder outputs.

    Each call does the following:
    - embeds the current token;
    - attends over ``encoder_outputs`` using the previous hidden state;
    - runs one LSTM step on [embedding ; context];
    - projects [lstm_output ; context ; embedding] to vocabulary logits.

    Args:
        output_dim: target vocabulary size (number of logits).
        embed_dim: target embedding width.
        enc_hid_dim: feature size of each encoder output.
        dec_hid_dim: decoder LSTM hidden size.
        attention: callable ``(hidden [batch, dec_hid_dim], encoder_outputs)
            -> weights [batch, src_len]``.
    """

    def __init__(self, output_dim, embed_dim, enc_hid_dim, dec_hid_dim, attention):
        super().__init__()
        self.output_dim = output_dim
        self.attention = attention

        self.embedding = nn.Embedding(output_dim, embed_dim)
        # LSTM input per step is [embedding ; context].
        self.rnn = nn.LSTM(embed_dim + enc_hid_dim, dec_hid_dim, batch_first=True)
        # Output projection sees [lstm_output ; context ; embedding].
        self.fc_out = nn.Linear(dec_hid_dim + enc_hid_dim + embed_dim, output_dim)

    def forward(self, input_token, hidden, cell, encoder_outputs):
        """Run a single decoding step.

        input_token: [batch] token ids
        hidden, cell: [1, batch, dec_hid_dim]
        encoder_outputs: [batch, src_len, enc_hid_dim]
        Returns (logits [batch, output_dim], (hidden, cell)).
        """
        emb = self.embedding(input_token).unsqueeze(1)                # [batch, 1, embed_dim]

        # Attention is driven by the state *before* this step's LSTM update.
        weights = self.attention(hidden.squeeze(0), encoder_outputs)  # [batch, src_len]
        context = weights.unsqueeze(1).bmm(encoder_outputs)           # [batch, 1, enc_hid_dim]

        step_in = torch.cat([emb, context], dim=2)
        rnn_out, (hidden, cell) = self.rnn(step_in, (hidden, cell))

        # Concatenate features while the length-1 time axis is present,
        # then drop it.
        features = torch.cat([rnn_out, context, emb], dim=2).squeeze(1)
        return self.fc_out(features), (hidden, cell)

#########################################
# 5) Seq2Seq Model (No Teacher Forcing)
#########################################

class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes the target one token at a time.

    By default the decoder is fed its own greedy (argmax) prediction at
    every step, i.e. no teacher forcing. ``teacher_forcing_ratio`` can
    optionally feed the ground-truth token instead, with that probability
    per step.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        # Stored for callers; forward() itself does not use it.
        self.device = device

    def forward(self, src, tgt, teacher_forcing_ratio=0.0):
        """
        src: [batch, src_len]  (batch is 1 in this notebook)
        tgt: [batch, tgt_len]  includes <sos> at index 0; tgt_len must be >= 2
        teacher_forcing_ratio: probability in [0, 1] of feeding the
            ground-truth token tgt[:, t] as the next decoder input instead of
            the model's own argmax prediction. 0.0 (the default) is pure
            greedy decoding; 1.0 is full teacher forcing.
        Returns outputs: [batch, tgt_len-1, output_dim]
        """
        encoder_outputs, (h, c) = self.encoder(src)
        # The encoder's final (h, c) seeds the decoder directly, so the
        # encoder and decoder hidden sizes must match.

        input_token = tgt[:, 0]  # <sos>
        outputs = []
        for t in range(1, tgt.size(1)):
            logits, (h, c) = self.decoder(input_token, h, c, encoder_outputs)
            outputs.append(logits.unsqueeze(1))  # [batch, 1, output_dim]
            # Draw a random number only when teacher forcing is enabled. The
            # default path then consumes no RNG state and is exactly the
            # original greedy decoding.
            use_teacher = (
                teacher_forcing_ratio > 0.0
                and torch.rand(1).item() < teacher_forcing_ratio
            )
            input_token = tgt[:, t] if use_teacher else logits.argmax(dim=1)

        return torch.cat(outputs, dim=1)  # [batch, tgt_len-1, output_dim]

#########################################
# 6) Instantiate Model, Optimizer, Loss
#########################################

# Model hyper-parameters. Seq2Seq hands the encoder's final (h, c) straight
# to the decoder LSTM, so ENC_HID_DIM and DEC_HID_DIM must stay equal.
ENC_EMB_DIM = 16
DEC_EMB_DIM = 16
ENC_HID_DIM = 32
DEC_HID_DIM = 32

device = torch.device("cpu")

attention = BahdanauAttention(enc_hid_dim=ENC_HID_DIM, dec_hid_dim=DEC_HID_DIM)
encoder = Encoder(input_dim=len(SRC_WORDS), embed_dim=ENC_EMB_DIM, hidden_dim=ENC_HID_DIM)
decoder = Decoder(
    output_dim=len(TGT_WORDS),
    embed_dim=DEC_EMB_DIM,
    enc_hid_dim=ENC_HID_DIM,
    dec_hid_dim=DEC_HID_DIM,
    attention=attention,
)
model = Seq2Seq(encoder=encoder, decoder=decoder, device=device).to(device)

optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
# Target positions holding <pad> contribute nothing to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=tgt_stoi["<pad>"])

#########################################
# 7) Training (No Teacher Forcing, Single Example)
#########################################

# Train for NUM_EPOCHS passes over the single example and log the loss every
# LOG_EVERY epochs. Nothing in the loop switches the model to eval mode, so
# train() is set once up front.
NUM_EPOCHS = 1000
LOG_EVERY = 10

model.train()
for epoch in range(NUM_EPOCHS):
    optimizer.zero_grad()

    outputs = model(src_tensor, tgt_tensor)  # [batch, tgt_len-1, output_dim]
    # Flatten batch and time so CrossEntropyLoss gets [N, output_dim] logits
    # and [N] class ids. Unlike squeeze(0), this is also correct for batch > 1.
    outputs = outputs.reshape(-1, outputs.size(-1))
    target = tgt_tensor[:, 1:].reshape(-1)  # drop <sos>: predictions start at t=1

    loss = criterion(outputs, target)
    loss.backward()
    optimizer.step()

    if epoch % LOG_EVERY == 0:
        print(f"Epoch {epoch} | Loss: {loss.item():.3f}")

#########################################
# 8) Greedy Inference (Decoding)
#########################################

def translate(model, src_sentence, max_len=10):
    """Greedily translate *src_sentence* with *model*.

    The sentence is encoded with the module-level ``src_stoi`` vocabulary, so
    every word must be in it; otherwise ``encode`` raises KeyError. It is
    wrapped in <sos>/<eos> and decoded one token at a time for at most
    *max_len* steps, stopping early when <eos> is predicted.

    Returns the predicted target words joined by single spaces, without
    <eos>. The model is left in eval mode.
    """
    model.eval()
    # Build inputs on the same device as the model's weights, so this still
    # works after e.g. model.to("cuda"). Previously tensors were always on CPU.
    device = next(model.parameters()).device
    eos_idx = tgt_stoi["<eos>"]

    with torch.no_grad():
        src_ids = add_sos_eos(encode(src_sentence, src_stoi), src_stoi["<sos>"], src_stoi["<eos>"])
        src_tensor = torch.tensor(src_ids, dtype=torch.long, device=device).unsqueeze(0)
        encoder_outputs, (h, c) = model.encoder(src_tensor)

        input_token = torch.tensor([tgt_stoi["<sos>"]], dtype=torch.long, device=device)
        translation = []

        for _ in range(max_len):
            logits, (h, c) = model.decoder(input_token, h, c, encoder_outputs)
            next_token = logits.argmax(dim=1).item()
            if next_token == eos_idx:
                break
            translation.append(next_token)
            input_token = torch.tensor([next_token], dtype=torch.long, device=device)

    # Convert token indices to words.
    return " ".join(tgt_itos[idx] for idx in translation)

# Run the trained model on a sentence that differs from the training pair.
test_sentence = "you are a student"
translated = translate(model, test_sentence)
print(f"\nEnglish: {test_sentence}\nFrench:  {translated}")


Epoch 0 | Loss: 2.652
Epoch 10 | Loss: 2.355
Epoch 20 | Loss: 2.202
Epoch 30 | Loss: 1.703
Epoch 40 | Loss: 1.183
Epoch 50 | Loss: 0.861
Epoch 60 | Loss: 0.639
Epoch 70 | Loss: 0.472
Epoch 80 | Loss: 0.343
Epoch 90 | Loss: 0.251
Epoch 100 | Loss: 0.188
Epoch 110 | Loss: 0.144
Epoch 120 | Loss: 0.113
Epoch 130 | Loss: 0.091
Epoch 140 | Loss: 0.075
Epoch 150 | Loss: 0.062
Epoch 160 | Loss: 0.053
Epoch 170 | Loss: 0.046
Epoch 180 | Loss: 0.040
Epoch 190 | Loss: 0.036
Epoch 200 | Loss: 0.032
Epoch 210 | Loss: 0.028
Epoch 220 | Loss: 0.026
Epoch 230 | Loss: 0.023
Epoch 240 | Loss: 0.021
Epoch 250 | Loss: 0.019
Epoch 260 | Loss: 0.018
Epoch 270 | Loss: 0.017
Epoch 280 | Loss: 0.015
Epoch 290 | Loss: 0.014
Epoch 300 | Loss: 0.014
Epoch 310 | Loss: 0.013
Epoch 320 | Loss: 0.012
Epoch 330 | Loss: 0.011
Epoch 340 | Loss: 0.011
Epoch 350 | Loss: 0.010
Epoch 360 | Loss: 0.010
Epoch 370 | Loss: 0.009
Epoch 380 | Loss: 0.009
Epoch 390 | Loss: 0.008
Epoch 400 | Loss: 0.008
Epoch 410 | Loss: 0.008
Epo