<a href="https://colab.research.google.com/github/tanyag/tiny_llm_colab/blob/main/tiny_llm_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [4]:
# Toy corpus: one short sentence repeated 100 times.
text = "hello world. this is a test of a tiny language model. " * 100

# Character-level vocabulary: every distinct character in the corpus, sorted
# so that the character <-> index assignment is deterministic.
chars = sorted(set(text))
vocab_size = len(chars)

# Lookup tables in both directions.
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = dict(enumerate(chars))


def encode(s):
    """Map a string to a list of vocabulary indices, one per character.

    Raises KeyError if `s` contains a character that is not in the corpus.
    """
    return [char_to_idx[c] for c in s]


def decode(indices):
    """Map an iterable of vocabulary indices back to the string they spell."""
    return ''.join(idx_to_char[i] for i in indices)


# The whole corpus as a 1-D tensor of int64 token ids.
data = torch.tensor(encode(text), dtype=torch.long)


In [5]:
block_size = 8   # number of tokens in each training window
batch_size = 4   # number of windows sampled per batch


def get_batch():
    """Sample a random batch of (input, target) windows from `data`.

    Returns a pair `(x, y)` of tensors, each of shape
    `(batch_size, block_size)`. Row k of `x` is a contiguous slice of `data`
    starting at a random offset; row k of `y` is the same slice shifted one
    position to the right, i.e. the next-token target for every input token.
    """
    # Start offsets are drawn from [0, len(data) - block_size), which leaves
    # room for the one-step-shifted target window.
    starts = torch.randint(len(data) - block_size, (batch_size,))
    # Row k of `window` holds the positions starts[k], ..., starts[k] + block_size - 1.
    window = starts.unsqueeze(1) + torch.arange(block_size)
    return data[window], data[window + 1]


In [6]:
class TinyLLM(nn.Module):
    """Minimal next-token model: an embedding table followed by a linear head.

    Each position is processed independently (there is no mixing across
    positions), so the logits for a position depend only on the token at
    that position.
    """

    def __init__(self, vocab_size, n_embed=32):
        """Create the layers.

        Args:
            vocab_size: number of distinct tokens; also the size of the
                output logit vector.
            n_embed: width of the token embedding.
        """
        super().__init__()
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=n_embed)
        self.fc = nn.Linear(in_features=n_embed, out_features=vocab_size)

    def forward(self, idx):
        """Return unnormalised next-token logits for every position in `idx`.

        Args:
            idx: integer tensor of token ids.

        Returns:
            Float tensor with the shape of `idx` plus a trailing
            `vocab_size` dimension.
        """
        return self.fc(self.embed(idx))


In [7]:
# Fresh model and optimizer every time this cell runs.
model = TinyLLM(vocab_size)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for step in range(500):
    # Clear gradients left over from the previous step before computing new ones.
    optimizer.zero_grad()

    x, y = get_batch()
    logits = model(x)

    # cross_entropy expects (N, C) logits and (N,) targets, so collapse the
    # batch and time dimensions into one.
    loss = F.cross_entropy(logits.flatten(0, 1), y.flatten())

    loss.backward()
    optimizer.step()

    # Report the training loss every 100 steps.
    if step % 100 == 0:
        print(f"Step {step}, Loss: {loss.item():.4f}")


Step 0, Loss: 3.0763
Step 100, Loss: 2.1897
Step 200, Loss: 1.7439
Step 300, Loss: 1.3745
Step 400, Loss: 1.2169


In [8]:
# Start the sequence from the single character 'h' (shape: 1 x 1).
context = torch.tensor([[char_to_idx['h']]], dtype=torch.long)
generated = context
model.eval()

# Sampling never back-propagates, so disable autograd: without this every
# forward pass records a computation graph that is immediately thrown away.
with torch.no_grad():
    for _ in range(100):
        logits = model(generated)
        # Only the logits at the last position decide the next character.
        probs = F.softmax(logits[:, -1, :], dim=-1)
        # Sample from the distribution rather than taking the argmax.
        next_token = torch.multinomial(probs, num_samples=1)
        # Append the sampled token along the sequence dimension.
        generated = torch.cat((generated, next_token), dim=1)

print("\nGenerated Text:")
print(decode(generated[0].tolist()))



Generated Text:
he elld. tis od. ll. morlodes tel tdelldory a te wo f teldis mof inorldua a tuagf ti wod. nisodeloes 
