In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

## Causal Self-Attention

In [None]:
class CausalSelfAttention(nn.Module):
  """Multi-head self-attention in which a position can only attend to itself
  and to earlier positions.

  Args:
    d_model: Size of the last dimension of the input and output.
    num_heads: Number of attention heads; must divide ``d_model`` evenly.
    dropout: Drop probability used for both the attention weights and the
      projected output. Defaults to 0.1.
  """

  def __init__(self, d_model, num_heads, dropout=0.1):
    super().__init__()
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    self.num_heads = num_heads
    self.head_dim = d_model // num_heads

    # Query / key / value projections and the final output projection.
    self.W_q = nn.Linear(d_model, d_model)
    self.W_k = nn.Linear(d_model, d_model)
    self.W_v = nn.Linear(d_model, d_model)
    self.W_o = nn.Linear(d_model, d_model)

    self.attn_dropout = nn.Dropout(dropout)
    self.resid_dropout = nn.Dropout(dropout)

  def forward(self, x):
    """Apply causal self-attention.

    Args:
      x: Tensor of shape (B, T, D) where D equals ``d_model``.

    Returns:
      Tensor of shape (B, T, D).
    """
    B, T, D = x.shape

    def split_heads(t):
      # (B, T, D) -> (B, num_heads, T, head_dim)
      return t.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

    Q = split_heads(self.W_q(x))
    K = split_heads(self.W_k(x))
    V = split_heads(self.W_v(x))

    # Scaled dot-product scores, shape (B, num_heads, T, T).
    scores = Q @ K.transpose(-2, -1) / math.sqrt(self.head_dim)

    # True strictly above the diagonal, i.e. at "future" key positions.
    # Built as a bool tensor directly on the input's device so that no
    # host-to-device copy happens on every forward call.
    causal_mask = torch.triu(
        torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1
    )
    # -inf scores become exactly zero weight after the softmax.
    scores = scores.masked_fill(causal_mask, float('-inf'))

    attn = F.softmax(scores, dim=-1)
    attn = self.attn_dropout(attn)

    out = attn @ V
    # (B, num_heads, T, head_dim) -> (B, T, D): merge the heads back together.
    out = out.transpose(1, 2).contiguous().view(B, T, D)

    return self.resid_dropout(self.W_o(out))

## Transformer Block (GPT-style)

In [None]:
class TransformerBlock(nn.Module):
  """One transformer layer: causal self-attention followed by a two-layer
  feed-forward network.

  LayerNorm is applied to the input of each sub-layer, and the sub-layer's
  output is added back onto the un-normalised input (residual connection).

  Args:
    d_model: Feature size of the input and output.
    num_heads: Number of attention heads passed to CausalSelfAttention.
    ff_hidden: Width of the hidden layer of the feed-forward network.
  """

  def __init__(self, d_model, num_heads, ff_hidden):
    super().__init__()
    self.attn = CausalSelfAttention(d_model, num_heads)
    self.norm1 = nn.LayerNorm(d_model)

    # Expand to ff_hidden, apply ReLU, project back, then dropout (p=0.1).
    feed_forward_layers = [
        nn.Linear(d_model, ff_hidden),
        nn.ReLU(),
        nn.Linear(ff_hidden, d_model),
        nn.Dropout(0.1),
    ]
    self.ff = nn.Sequential(*feed_forward_layers)
    self.norm2 = nn.LayerNorm(d_model)

  def forward(self, x):
    """Return the block output; same shape as ``x``."""
    attended = self.attn(self.norm1(x))
    hidden = x + attended

    transformed = self.ff(self.norm2(hidden))
    return hidden + transformed


## Tiny GPT Model

In [None]:
class TinyGPT(nn.Module):
  """Small decoder-only transformer language model.

  Token and learned position embeddings are summed, passed through a stack of
  TransformerBlock layers and a final LayerNorm, then projected to vocabulary
  logits. The output projection shares its weight matrix with the token
  embedding.

  Args:
    vocab_size: Number of distinct token ids.
    d_model: Embedding / hidden size.
    num_heads: Attention heads per block.
    ff_hidden: Hidden width of each block's feed-forward network.
    num_layers: Number of TransformerBlock layers.
    block_size: Maximum sequence length (rows in the position table).
  """

  def __init__(self, vocab_size, d_model, num_heads, ff_hidden, num_layers, block_size):
    super().__init__()
    self.token_emb = nn.Embedding(vocab_size, d_model)
    self.pos_emb = nn.Embedding(block_size, d_model)

    # NOTE: the attribute name (spelling included) is part of the state_dict
    # keys, so it is kept as-is to stay compatible with saved checkpoints.
    self.blockes = nn.ModuleList([
        TransformerBlock(d_model, num_heads, ff_hidden)
        for _ in range(num_layers)
    ])

    self.ln_f = nn.LayerNorm(d_model)
    self.head = nn.Linear(d_model, vocab_size)
    # Weight tying: the output projection reuses the token embedding matrix.
    self.head.weight = self.token_emb.weight

    # nn.Embedding initialises from N(0, 1). Because that matrix is also the
    # output projection, the initial logits are huge and the starting loss is
    # far above the ln(vocab_size) of a uniform guess. Use a small-std init
    # (as in GPT-2) and a zero output bias so training starts near uniform.
    nn.init.normal_(self.token_emb.weight, mean=0.0, std=0.02)
    nn.init.normal_(self.pos_emb.weight, mean=0.0, std=0.02)
    nn.init.zeros_(self.head.bias)

  def forward(self, x):
    """Compute next-token logits.

    Args:
      x: Integer tensor of token ids with shape (B, T), T <= block_size.

    Returns:
      Tensor of shape (B, T, vocab_size).

    Raises:
      IndexError: If T exceeds the number of positions in the position table.
    """
    _, T = x.shape
    max_len = self.pos_emb.num_embeddings
    if T > max_len:
      # Fail early with a clear message instead of an opaque embedding
      # index error (a device-side assert when running on CUDA).
      raise IndexError(
          f"sequence length {T} exceeds block_size {max_len}"
      )

    # Position ids are created directly on the input's device.
    pos = torch.arange(T, device=x.device)

    # (B, T, d_model) + (T, d_model): positions broadcast over the batch.
    x = self.token_emb(x) + self.pos_emb(pos)

    for block in self.blockes:
      x = block(x)

    x = self.ln_f(x)
    return self.head(x)


## Load and Prepare Shakespeare Data

In [None]:
# Read the whole corpus into memory as a single string.
with open("/content/tiny-shakespeare.txt", "r", encoding="utf-8") as f:
  text = f.read()

# Character-level vocabulary: every distinct character, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)

# Lookup tables: character -> integer id, and integer id -> character.
stoi = dict(zip(chars, range(vocab_size)))
itos = dict(enumerate(chars))


def encode(s):
  """Map a string to the list of integer ids of its characters."""
  return [stoi[c] for c in s]


def decode(x):
  """Map an iterable of integer ids back to a string."""
  return ''.join(itos[i] for i in x)


# The full corpus as a 1-D tensor of int64 token ids.
data = torch.tensor(encode(text), dtype=torch.long)

## Batch Sampling

In [None]:
block_size = 64   # number of tokens in each training sequence
batch_size = 32   # number of sequences per batch

def get_batch():
  """Draw a random batch of (input, target) sequences from ``data``.

  Returns:
    A pair ``(x, y)`` of tensors, each of shape (batch_size, block_size).
    ``y`` holds the same windows as ``x`` shifted one position to the right,
    so ``y[b, t]`` is the token that follows ``x[b, t]`` in ``data``.
  """
  # Random window start offsets; the upper bound leaves room for the
  # one-step-shifted target window.
  starts = torch.randint(len(data) - block_size, (batch_size,))

  # Broadcast to a (batch_size, block_size) grid of absolute positions.
  positions = starts.unsqueeze(1) + torch.arange(block_size)

  inputs = data[positions]
  targets = data[positions + 1]
  return inputs, targets

## Training Loop

In [None]:
# Train on the GPU when one is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = TinyGPT(
    vocab_size,
    128,         # d_model
    4,           # num_heads
    512,         # ff_hidden
    4,           # num_layers
    block_size,
).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for step in range(2000):
  # Sample a fresh random batch and move it to the training device.
  x, y = get_batch()
  x = x.to(device)
  y = y.to(device)

  # Clear gradients left over from the previous step before the new backward.
  optimizer.zero_grad()

  # Flatten (B, T, vocab) logits and (B, T) targets so that cross-entropy
  # treats every position as an independent classification example.
  logits = model(x)
  loss = F.cross_entropy(logits.flatten(0, 1), y.flatten())

  loss.backward()
  optimizer.step()

  # Report the batch loss every 200 steps.
  if not step % 200:
    print(f'step {step} | loss {loss.item():.4f}')

step 0 | loss 81.7916
step 200 | loss 4.0322
step 400 | loss 3.4443
step 600 | loss 3.1089
step 800 | loss 2.9156
step 1000 | loss 2.8356
step 1200 | loss 2.7423
step 1400 | loss 2.6017
step 1600 | loss 2.5273
step 1800 | loss 2.4535


## Text Generation

In [None]:
def generate(model, start, max_new_tokens=200):
    """Sample a continuation of ``start`` from ``model``, one token at a time.

    Args:
        model: Module mapping a (1, T) tensor of token ids to
            (1, T, vocab_size) logits.
        start: Non-empty prompt string; every character must be encodable
            by ``encode``.
        max_new_tokens: Number of tokens to sample and append.

    Returns:
        The prompt followed by the sampled tokens, decoded to a string.

    Raises:
        ValueError: If ``start`` is empty (there is nothing to condition on).
    """
    if not start:
        raise ValueError("start must contain at least one character")

    # Put the prompt on the model's own device rather than relying on a
    # global, and force an integer dtype for the embedding lookup.
    model_device = next(model.parameters()).device
    idx = torch.tensor(encode(start), dtype=torch.long, device=model_device)[None, :]

    # Switch to eval mode (disables dropout) for sampling only, and restore
    # the caller's mode afterwards so later training is unaffected.
    was_training = model.training
    model.eval()
    try:
        # No gradients are needed for sampling; skipping autograd avoids
        # building a graph for every generated token.
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Feed at most the last block_size tokens: that is the
                # longest context used when the model was built and trained.
                idx_cond = idx[:, -block_size:]

                logits = model(idx_cond)
                # Distribution over the next token, from the last position.
                probs = F.softmax(logits[:, -1, :], dim=-1)
                next_id = torch.multinomial(probs, 1)

                idx = torch.cat([idx, next_id], dim=1)
    finally:
        model.train(was_training)

    return decode(idx[0].tolist())


# Sample a continuation of the prompt "ROMEO" (200 new characters by default).
# Tokens are drawn with torch.multinomial and no seed is set in the visible
# cells, so the generated text is expected to differ between runs.
print(generate(model, "ROMEO"))


ROMEOY: wir 'sic doy, whatusend woubupread fim to wars as od coce hen hes ist pa hy wes be berhat' ED:
Thincame ad bl shithat farase, waits wiveastregn ss se thereatin
Thow ates, and you wicerdiour ther ye
