In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

# --------------------------
# 1. A Tiny Manual RNN Encoder
# --------------------------
class TinyEncoder(nn.Module):
    """Hand-written single-layer tanh RNN that encodes one token sequence.

    Each token is embedded and folded into the hidden state with
    ``h = tanh(W_h @ h + W_x @ x_t + b)``. Only the final hidden state is
    returned.

    Args:
        input_vocab_size: Number of distinct source token ids.
        embed_size: Width of the token embedding vectors.
        hidden_size: Width of the recurrent hidden state.
    """

    def __init__(self, input_vocab_size, embed_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(input_vocab_size, embed_size)

        # RNN parameters (small random weights, zero bias)
        self.hidden_size = hidden_size
        self.W_h = nn.Parameter(torch.randn(hidden_size, hidden_size)*0.1)
        self.W_x = nn.Parameter(torch.randn(hidden_size, embed_size)*0.1)
        self.b   = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, src_tokens):
        """Run the recurrence over ``src_tokens`` and return the last state.

        Args:
            src_tokens: 1-D integer tensor of token ids, shape (src_len,).

        Returns:
            Final hidden state of shape (hidden_size,). For an empty
            sequence this is the all-zero initial state.
        """
        # Allocate the initial state from a parameter so it inherits the
        # module's dtype and device; a bare torch.zeros() is always
        # float32-on-CPU and breaks torch.mv after .double() / .to(device).
        h = self.W_h.new_zeros(self.hidden_size)

        for token_id in src_tokens:
            x_t = self.embedding(token_id)  # (embed_size,)
            h = torch.tanh(
                torch.mv(self.W_h, h) +
                torch.mv(self.W_x, x_t) +
                self.b
            )

        return h


# -------------------------
# 2. A Tiny Manual RNN Decoder
# -------------------------
class TinyDecoder(nn.Module):
    """Hand-written single-layer tanh RNN decoder with a linear output layer.

    At every step the input token is embedded, the hidden state is updated
    with ``h = tanh(W_h @ h + W_x @ x_t + b)``, and vocabulary logits are
    produced with ``W_out @ h + b_out``.

    Args:
        output_vocab_size: Number of distinct target token ids.
        embed_size: Width of the token embedding vectors.
        hidden_size: Width of the recurrent hidden state.
    """

    def __init__(self, output_vocab_size, embed_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(output_vocab_size, embed_size)

        self.hidden_size = hidden_size
        self.W_h = nn.Parameter(torch.randn(hidden_size, hidden_size)*0.1)
        self.W_x = nn.Parameter(torch.randn(hidden_size, embed_size)*0.1)
        self.b   = nn.Parameter(torch.zeros(hidden_size))

        # Output projection
        self.W_out = nn.Parameter(torch.randn(output_vocab_size, hidden_size)*0.1)
        self.b_out = nn.Parameter(torch.zeros(output_vocab_size))

    def step(self, token_id, h):
        """Advance the decoder by one token.

        Args:
            token_id: 0-D integer tensor holding the input token id.
            h: Previous hidden state, shape (hidden_size,).

        Returns:
            Tuple ``(logits_t, h_next)`` with shapes (output_vocab_size,)
            and (hidden_size,).
        """
        x_t = self.embedding(token_id)
        h_next = torch.tanh(
            torch.mv(self.W_h, h) +
            torch.mv(self.W_x, x_t) +
            self.b
        )
        logits_t = torch.mv(self.W_out, h_next) + self.b_out
        return logits_t, h_next

    def forward(self, dec_tokens, init_hidden):
        """Decode a whole input sequence (teacher forcing).

        Args:
            dec_tokens: 1-D integer tensor of decoder input ids, shape
                (tgt_len,).
            init_hidden: Initial hidden state, shape (hidden_size,).

        Returns:
            Logits of shape (tgt_len, output_vocab_size). An empty
            ``dec_tokens`` yields an empty (0, output_vocab_size) tensor
            instead of raising from concatenating an empty list.
        """
        h = init_hidden
        logits_list = []

        for token_id in dec_tokens:
            logits_t, h = self.step(token_id, h)
            logits_list.append(logits_t)

        if not logits_list:
            # Nothing to decode: keep the (steps, vocab) layout with 0 steps.
            return self.b_out.new_zeros((0, self.b_out.shape[0]))

        return torch.stack(logits_list, dim=0)


# -------------------------------------
# 3. Example Data: "I go <EOS>" -> "मैं जाता हूँ <EOS>"
# -------------------------------------
# English side of the toy pair: I=0, go=1, <EOS>=2
ENG_VOCAB_SIZE = 3

# Hindi id -> word table (used when printing decoded output).
# <GO>=0, मैं=1, जाता=2, हूँ=3, <EOS>=4
HIN_ID2WORD = dict(enumerate(["<GO>", "मैं", "जाता", "हूँ", "<EOS>"]))
HIN_VOCAB_SIZE = len(HIN_ID2WORD)  # 5

# Deliberately tiny model dimensions.
EMBED_SIZE, HIDDEN_SIZE = 1, 2

# Build the encoder first, then the decoder (same construction order, so
# random initialisation draws happen in the same sequence).
encoder = TinyEncoder(ENG_VOCAB_SIZE, EMBED_SIZE, HIDDEN_SIZE)
decoder = TinyDecoder(HIN_VOCAB_SIZE, EMBED_SIZE, HIDDEN_SIZE)

# Source sentence "I go <EOS>" as token ids.
encoder_input = torch.tensor([0, 1, 2])

# Target sentence "मैं जाता हूँ <EOS>" is ids [1, 2, 3, 4]. For teacher
# forcing the decoder is fed the target shifted right by one, starting
# with <GO>, and is trained to predict the unshifted target.
decoder_input = torch.tensor([0, 1, 2, 3])   # <GO>, मैं, जाता, हूँ
decoder_target = torch.tensor([1, 2, 3, 4])  # मैं, जाता, हूँ, <EOS>

# ----------------------------------
# 4. Training Loop (Cross Entropy)
# ----------------------------------
# Cross-entropy between the per-step logits and the target ids, optimised
# with plain SGD over the parameters of both networks.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    [*encoder.parameters(), *decoder.parameters()],
    lr=0.5,
)

num_epochs = 5000
for epoch in range(num_epochs):
    # Teacher-forced training step: encode the source, decode against the
    # shifted target, then backpropagate through both networks.
    optimizer.zero_grad()
    enc_hidden = encoder(encoder_input)          # final encoder state
    logits = decoder(decoder_input, enc_hidden)  # one row of logits per step
    loss = criterion(logits, decoder_target)
    loss.backward()
    optimizer.step()

    # Loss report on every 5th epoch.
    if (epoch+1) % 5 == 0:
        print(f"Epoch {epoch+1}/{num_epochs}, Loss = {loss.item():.4f}")

    # A greedy-decoded sample is shown on every 20th epoch only.
    if (epoch+1) % 20 != 0:
        continue

    print(f"\n--- Decoding after epoch {epoch+1} ---")
    with torch.no_grad():
        # Fresh encoding with the just-updated weights.
        enc_hidden = encoder(encoder_input)
        h = enc_hidden.clone()
        current_token = torch.tensor(0)  # <GO>
        generated_tokens = []

        # Emit at most 6 tokens, feeding each argmax back in as the next
        # input and stopping early once <EOS> is produced.
        while len(generated_tokens) < 6:
            x_t = decoder.embedding(current_token)
            h = torch.tanh(
                decoder.W_h.mv(h) +
                decoder.W_x.mv(x_t) +
                decoder.b
            )
            logits_t = decoder.W_out.mv(h) + decoder.b_out

            next_token = logits_t.argmax().item()
            generated_tokens.append(next_token)
            if next_token == 4:  # <EOS>
                break
            current_token = torch.tensor(next_token)

        # Translate ids back to words for display.
        generated_words = list(map(HIN_ID2WORD.__getitem__, generated_tokens))
        print("Generated tokens:", generated_words)
    print("-----------------------------------\n")




Epoch 5/5000, Loss = 1.5082
Epoch 10/5000, Loss = 1.3279
Epoch 15/5000, Loss = 1.0741
Epoch 20/5000, Loss = 0.9155

--- Decoding after epoch 20 ---
Generated tokens: ['मैं', 'हूँ', 'मैं', 'हूँ', 'मैं', 'हूँ']
-----------------------------------

Epoch 25/5000, Loss = 0.8037
Epoch 30/5000, Loss = 0.6522
Epoch 35/5000, Loss = 0.4441
Epoch 40/5000, Loss = 0.2977

--- Decoding after epoch 40 ---
Generated tokens: ['मैं', 'जाता', 'हूँ', '<EOS>']
-----------------------------------

Epoch 45/5000, Loss = 0.2163
Epoch 50/5000, Loss = 0.1686
Epoch 55/5000, Loss = 0.1381
Epoch 60/5000, Loss = 0.1170

--- Decoding after epoch 60 ---
Generated tokens: ['मैं', 'जाता', 'हूँ', '<EOS>']
-----------------------------------

Epoch 65/5000, Loss = 0.1016
Epoch 70/5000, Loss = 0.0898
Epoch 75/5000, Loss = 0.0805
Epoch 80/5000, Loss = 0.0730

--- Decoding after epoch 80 ---
Generated tokens: ['मैं', 'जाता', 'हूँ', '<EOS>']
-----------------------------------

Epoch 85/5000, Loss = 0.0667
Epoch 90/5000, Lo