In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn import TransformerDecoderLayer, TransformerDecoder

# ---------------- Vocabulary ----------------
def build_vocab(text):
    chars = sorted(set(text))
    stoi = {c: i for i, c in enumerate(chars)}
    itos = {i: c for c, i in stoi.items()}
    return stoi, itos

def encode(text, stoi):
    return [stoi[c] for c in text]

def decode(tokens, itos):
    return "".join([itos[t] for t in tokens])

# ---------------- Dataset ----------------
class TextDataset(Dataset):
    def __init__(self, tokens, seq_len):
        self.tokens = tokens
        self.seq_len = seq_len

    def __len__(self):
        return max(0, len(self.tokens) - self.seq_len - 1)

    def __getitem__(self, idx):
        x = self.tokens[idx:idx+self.seq_len]
        y = self.tokens[idx+1:idx+self.seq_len+1]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)

# ---------------- Model ----------------
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=256):
        super().__init__()
        self.pos = nn.Embedding(max_len, d_model)

    def forward(self, x):
        T = x.size(1)
        return self.pos(torch.arange(T, device=x.device)[None, :])

class TinyLLM(nn.Module):
    def __init__(self, vocab_size, d_model=64, nhead=2, num_layers=2, max_len=256):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.posenc = PositionalEncoding(d_model, max_len)
        layer = TransformerDecoderLayer(d_model=d_model, nhead=nhead, batch_first=False)
        self.decoder = TransformerDecoder(layer, num_layers=num_layers)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, ids):
        x = self.embed(ids) + self.posenc(ids)   # [B,T,E]
        h = x.transpose(0,1)                     # [T,B,E]
        T = h.size(0)
        mask = torch.triu(torch.ones(T,T, device=h.device), diagonal=1).bool()
        out = self.decoder(h, h, tgt_mask=mask)  # [T,B,E]
        return self.lm_head(out.transpose(0,1))  # [B,T,V]

    @torch.no_grad()
    def generate(self, start_ids, max_new_tokens=30):
        self.eval()
        ids = start_ids
        for _ in range(max_new_tokens):
            logits = self.forward(ids)[:,-1,:]  # last token logits
            next_id = torch.argmax(logits, dim=-1, keepdim=True)
            ids = torch.cat([ids, next_id], dim=1)
        return ids

# ---------------- Main ----------------
def main():
    # Training text
    train_text = "I am John Doe and I am instructor at Code-You."
    # Prompts we will ask later
    prompts = ["Who I am", "What is my profession"]

    # Build vocab from both training text + prompts
    combined = train_text + "\n" + "\n".join(prompts)
    stoi, itos = build_vocab(combined)

    # Encode training text
    tokens = encode(train_text, stoi)

    # Dataset + DataLoader
    seq_len = 16
    dataset = TextDataset(tokens, seq_len)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True, drop_last=True)

    # Model
    vocab_size = len(stoi)
    model = TinyLLM(vocab_size)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    # Train
    for epoch in range(10):
        for x, y in dataloader:
            logits = model(x[:, :-1])
            loss = criterion(logits.reshape(-1, vocab_size), y[:, 1:].reshape(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1} - loss {loss.item():.4f}")

    # Generate from prompts
    for p in prompts:
        start = torch.tensor([encode(p, stoi)], dtype=torch.long)
        out = model.generate(start, max_new_tokens=30)[0].tolist()
        print(f"\nPrompt: {p}")
        print("Generated:", decode(out, itos))

if __name__ == "__main__":
    main()

Epoch 1 - loss 2.2352
Epoch 2 - loss 1.2889
Epoch 3 - loss 1.0859
Epoch 4 - loss 0.7004
Epoch 5 - loss 0.5918
Epoch 6 - loss 0.2580
Epoch 7 - loss 0.2193
Epoch 8 - loss 0.1750
Epoch 9 - loss 0.2552
Epoch 10 - loss 0.1513

Prompt: Who I am
Generated: Who I amisrco oeadIadIadIadIadIadIa nm

Prompt: What is my profession
Generated: What is my professiontutCdIadIadIa ntCdIa ntutdIadI
