In [None]:
import torch

import torch.nn as nn 

import torch.optim as optim

In [None]:
# ----- Setup -----



In [None]:


# ----- Setup -----
# Run on Apple's Metal (MPS) backend when PyTorch reports it as available,
# otherwise fall back to the CPU.
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

# Vocabulary plus lookup tables in both directions (character <-> integer index).
chars = ['A', 'B', 'C', 'D']
char_to_idx = dict(zip(chars, range(len(chars))))
idx_to_char = dict(enumerate(chars))

# Sample input and target
# The target is the character the model should emit at each input position.
input_indices = torch.tensor([char_to_idx[ch] for ch in "CBAA"], dtype=torch.long)   # C B A A
target_indices = torch.tensor([char_to_idx[ch] for ch in "BAAD"], dtype=torch.long)  # B A A D

# One-hot encode inputs
def one_hot_encode(index, vocab_size: int) -> torch.Tensor:
    """Return a 1-D float tensor of length ``vocab_size`` that is 1.0 at ``index``.

    Args:
        index: Position to switch on. Anything accepted by tensor item
            assignment works; the notebook passes 0-d integer tensors.
        vocab_size: Length of the returned vector.

    Returns:
        A new tensor of zeros (default dtype, CPU) with ``index`` set to 1.0.
    """
    encoded = torch.zeros(vocab_size)
    encoded[index] = 1.0
    return encoded

vocab_size = len(chars)
# Encode every input index, stack the rows, then insert a batch axis of size 1
# and move the result to the selected device.
input_one_hot = (
    torch.stack([one_hot_encode(idx, vocab_size) for idx in input_indices])
    .unsqueeze(1)   # [seq_len, 1, vocab_size]
    .to(device)
)
# Targets stay as class indices, one per sequence position.
target_tensor = target_indices.to(device)

# ----- Model -----
class VanillaRNN(nn.Module):
    """Single-layer ``nn.RNN`` followed by a linear read-out and log-softmax.

    The module returns log-probabilities, so it is meant to be paired with
    ``nn.NLLLoss``.

    Args:
        input_size: Number of features per input step.
        hidden_size: Width of the recurrent hidden state.
        output_size: Number of output classes.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)
        # Normalise over the last (class) axis. A fixed dim=1 is only right
        # when forward() has squeezed the output down to [batch, classes];
        # for seq_len > 1 it would normalise over the batch axis instead, and
        # for an unbatched single step it would raise on the 1-D output.
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, input_seq, hidden):
        """Run the RNN over ``input_seq`` and return per-step log-probabilities.

        Args:
            input_seq: Input passed straight to ``nn.RNN``; the notebook feeds
                ``[seq_len, batch, input_size]`` tensors.
            hidden: Initial hidden state passed straight to ``nn.RNN``.

        Returns:
            ``(log_probs, hidden)``. ``log_probs`` has the class axis last; its
            leading sequence axis is dropped only when that axis has length 1
            (so a single step yields ``[batch, output_size]``). ``hidden`` is
            the hidden state returned by ``nn.RNN``.
        """
        out, hidden = self.rnn(input_seq, hidden)
        # squeeze(0) drops the leading axis only if it has length 1.
        out = self.fc(out.squeeze(0))
        out = self.softmax(out)
        return out, hidden

hidden_size = 8
# Input and output widths both equal the vocabulary size: one-hot characters
# go in, one score per character comes out.
model = VanillaRNN(
    input_size=vocab_size,
    hidden_size=hidden_size,
    output_size=vocab_size,
).to(device)

# ----- Training -----
# NLLLoss is applied to the log-probabilities the model returns.
loss_fn = nn.NLLLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.01)

n_epochs = 100
for epoch in range(n_epochs):
    # Each epoch replays the sequence starting from an all-zero hidden state.
    hidden = torch.zeros(1, 1, hidden_size).to(device)
    total_loss = 0

    for i in range(input_one_hot.shape[0]):
        inp = input_one_hot[i:i + 1]  # [1, 1, vocab_size]
        optimizer.zero_grad()
        # detach() cuts the graph at the previous step, so every update
        # back-propagates through the current time step only.
        out, hidden = model(inp, hidden.detach())
        loss = loss_fn(out, target_tensor[i:i + 1])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Report the summed per-step loss every 20th epoch.
    if not (epoch + 1) % 20:
        print(f"Epoch {epoch+1}/{n_epochs}, Loss: {total_loss:.4f}")

# ----- Inference -----
print("\n🧠 Inference after training:")
hidden = torch.zeros(1, 1, hidden_size).to(device)
predicted_indices = []

# Prediction needs no gradients. Without no_grad() every forward pass would
# extend an autograd graph through `hidden` that is never back-propagated.
with torch.no_grad():
    for i in range(len(input_one_hot)):
        inp = input_one_hot[i].unsqueeze(0)  # [1, 1, vocab_size]
        output, hidden = model(inp, hidden)
        # Greedy decoding: take the class with the highest log-probability.
        pred_idx = torch.argmax(output, dim=1).item()
        predicted_indices.append(pred_idx)

# Map class indices back to characters for display.
predicted_chars = [idx_to_char[i] for i in predicted_indices]
target_chars = [idx_to_char[i.item()] for i in target_tensor]

print("Target Characters   :", target_chars)
print("Predicted Characters:", predicted_chars)


Epoch 20/100, Loss: 1.7284
Epoch 40/100, Loss: 0.3393
Epoch 60/100, Loss: 0.1370
Epoch 80/100, Loss: 0.0765
Epoch 100/100, Loss: 0.0499

🧠 Inference after training:
Target Characters   : ['B', 'A', 'A', 'D']
Predicted Characters: ['B', 'A', 'A', 'D']
