In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cpu


## Transformer Architecture

In [6]:
class MultiHeadAttention(nn.Module):

  def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
    super().__init__()
    assert d_model % n_heads == 0

    self.d_model = d_model
    self.n_heads = n_heads
    self.d_k = d_model // n_heads

    self.q_linear = nn.Linear(d_model, self.d_k * n_heads)
    self.k_linear = nn.Linear(d_model, self.d_k * n_heads)
    self.v_linear = nn.Linear(d_model, self.d_k * n_heads) # Because d_v = d_k
    self.out = nn.Linear(d_model, self.d_k * n_heads)

    self.dropout = nn.Dropout(dropout)

  def forward(self, query, key, value, mask = None):

    batch_size = query.size(0)

    Q = self.q_linear(query).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2) # batch_size, n_head, seq_len, d_k
    K = self.k_linear(key).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2) # batch_size, n_head, seq_len, d_k
    V = self.v_linear(value).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2) # batch_size, n_head, seq_len, d_k

    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k) # batch_size, n_head, seq_len, seq_len

    if mask is not None:
      scores = scores.masked_fill(mask == 0, -1e9)

    attention_weights = F.softmax(scores, dim = -1) # batch_size, n_head, seq_len, seq_len
    attention_weights = self.dropout(attention_weights)

    # Concatenate heads and put through final linear layer
    context = torch.matmul(attention_weights, V) # batch_size, n_head, seq_len, d_k

    context = context.transpose(1, 2).contiguous().view(
        batch_size, -1, self.d_model
    )

    output = self.out(context) # batch_size, seq_len, d_model

    return output, attention_weights



In [7]:
class PositionalEncoding(nn.Module):

  def __init__(self, d_model: int, max_len: int):
    super().__init__()

    pe = torch.zeros(max_len, d_model)

    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

    div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                         (-math.log(10000.0)/d_model))

    pe[:,0::2] = torch.sin(position * div_term)
    pe[:,1::2] = torch.cos(position * div_term)

    pe = pe.unsqueeze(0).transpose(0, 1)

    self.register_buffer('pe', pe)

  def forward(self, x):
    return x + self.pe[:x.size(0), :]

In [8]:
class FeedForward(nn.Module):

  def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
    super().__init__()

    self.linear1 = nn.Linear(d_model, d_ff)
    self.linear2 = nn.Linear(d_ff, d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    return self.linear2(self.dropout(F.relu(self.linear1(x))))

In [9]:
class TransformerBlock(nn.Module):

  def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
    super().__init__()

    self.attention = MultiHeadAttention(d_model, n_heads, dropout)
    self.feed_forward = FeedForward(d_model, d_ff, dropout)
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, mask = None):

    attn_output, attention_weights = self.attention(x, x, x, mask)
    x = self.norm1(x + self.dropout(attn_output))

    ff_output = self.feed_forward(x)

    x = self.norm2(x + self.dropout(ff_output))

    return x, attention_weights

In [10]:
class TinyTransformer(nn.Module):

  def __init__(self, vocab_size: int, d_model: int, n_heads: int,
               n_layers: int, d_ff: int, max_len: int, dropout: float = 0.1):

    super().__init__()

    self.d_model = d_model
    self.embedding = nn.Embedding(vocab_size, d_model)
    self.pos_encoding = PositionalEncoding(d_model, max_len)

    self.transformer_blocks = nn.ModuleList([
        TransformerBlock(d_model, n_heads, d_ff, dropout)
        for _ in range(n_layers)
    ])

    self.ln_f = nn.LayerNorm(d_model)
    self.head = nn.Linear(d_model, vocab_size)
    self.dropout = nn.Dropout(dropout)

    self.apply(self._init_weights)

  def _init_weights(self, module):

    if isinstance(module, nn.Linear):
      torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
      if module.bias is not None:
        torch.nn.init.zeros_(module.bias)

    elif isinstance(module, nn.Embedding):
      torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    elif isinstance(module, nn.LayerNorm):
      torch.nn.init.zeros_(module.bias)
      torch.nn.init.ones_(module.weight)

  def create_causal_mask(self, size):
    mask = torch.tril(torch.ones(size, size)).unsqueeze(0).unsqueeze(0)
    return mask

  def forward(self, x, targets=None):

    batch_size, seq_len = x.size()

    mask = self.create_causal_mask(seq_len).to(x.device)

    # Token embedding + positional encoding
    x = self.embedding(x) * math.sqrt(self.d_model)
    x = self.pos_encoding(x)
    x = self.dropout(x)

    # Pass through transformer blocks
    attention_weights = []
    for block in self.transformer_blocks:
      x, atten = block(x, mask)
      attention_weights.append(atten)

    x = self.ln_f(x)
    logits = self.head(x)

    loss = None
    if targets is not None:
      loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

    return logits, loss, attention_weights

In [11]:
class SimpleTextDataset(Dataset):

  def __init__(self, text: str, seq_len: int):

    self.chars = sorted(list(set(text)))
    self.char_to_idx = {ch: i for i, ch in enumerate(self.chars)}
    self.idx_to_char = {i:ch for i, ch in enumerate(self.chars)}

    self.data = [self.char_to_idx[ch] for ch in text]
    self.seq_len = seq_len

  def __len__(self):
    return len(self.data) - self.seq_len

  def __getitem__(self, idx):
    x = torch.tensor(self.data[idx:idx + self.seq_len], dtype=torch.long)
    y = torch.tensor(self.data[idx+1:idx+self.seq_len+1], dtype = torch.long)
    return x, y

  @property
  def vocab_size(self):
    return len(self.chars)

In [12]:
def generate_text(model, dataset, start_text: str, max_new_tokens: int = 100,
                  temperature: float = 1.0):
  model.eval()
  with torch.no_grad():
    context = [dataset.char_to_idx[ch] for ch in start_text]
    context = torch.tensor(context, dtype=torch.long).unsqueeze(0).to(device)

    generate = []
    for _ in range(max_new_tokens):

      logits, _, _ = model(context)

      # Focus on the last time step
      logits = logits[:, -1, :] / temperature

      # Sample from the distribution
      probs = F.softmax(logits, dim=-1)
      next_token = torch.multinomial(probs, num_samples=1)

      # Append to context and generated sequence
      context = torch.cat([context, next_token], dim=1)
      generate.append(next_token.item())

      # Keep context length manageable
      if context.size(1) > 100:
        context = context[:, -100:]

    # Decode the generated tokens
    generated_text = ''.join([dataset.idx_to_char[token] for token in generate])

    return start_text + generated_text


In [15]:
def train_model():

  sample_text = """
    In the beginning was the Word, and the Word was with God, and the Word was God.
    All things were made through him, and without him was not any thing made that was made.
    In him was life, and the life was the light of men.
    The light shines in the darkness, and the darkness has not overcome it.
    There was a man sent from God, whose name was John.
    He came as a witness, to bear witness about the light, that all might believe through him.
    He was not the light, but came to bear witness about the light.
    The true light, which gives light to everyone, was coming into the world.
    He was in the world, and the world was made through him, yet the world did not know him.
    He came to his own, and his own people did not receive him.
    But to all who did receive him, who believed in his name, he gave the right to become children of God.
  """

  # Hyperparameters
  seq_len = 64
  d_model = 128
  n_heads = 8
  n_layers = 4
  d_ff = 512
  batch_size = 32
  learning_rate = 0.001
  num_epochs = 50

  dataset = SimpleTextDataset(sample_text, seq_len)
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

  model = TinyTransformer(
      vocab_size=dataset.vocab_size,
      d_model = d_model,
      n_heads = n_heads,
      n_layers = n_layers,
      d_ff = d_ff,
      max_len = seq_len * 2
  ).to(device)

  # Count parameters
  total_params = sum(p.numel() for p in model.parameters())
  print(f"Total parameters: {total_params}")

  # Optimizer
  optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)

  # Training loop
  losses = []
  model.train()

  print("\nStarting training...")
  for epoch in range(num_epochs):
    epoch_loss = 0
    num_batches = 0

    for batch_idx, (x, y) in enumerate(dataloader):
      x, y = x.to(device), y.to(device)

      optimizer.zero_grad()

      logits, loss, _ = model(x, y)
      loss.backward()

      # Gradient clipping
      torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

      optimizer.step()

      epoch_loss += loss.item()
      num_batches += 1

      if batch_idx % 10 == 0:
        print(f"Epoch {epoch + 1}/{num_epochs}, Batch {batch_idx}, Loss: {loss.item():.4f}")

    avg_loss = epoch_loss / num_batches
    losses.append(avg_loss)

    # Generate sample text every 10 epochs
    if (epoch + 1) % 10 == 0:
      print("\n" + "="*50)
      print("Generated text sample: ")
      generated = generate_text(model, dataset, "At the end of the day", max_new_tokens=100)
      print(generated)
      print("="*50 + "\n")
      model.train()

  torch.save({
    'model_state_dict': model.state_dict(),
    'vocab_size': dataset.vocab_size,
    'char_to_idx': dataset.char_to_idx,
    'idx_to_char': dataset.idx_to_char,
    'hyperparameters': {
        'd_model': d_model,
        'n_heads': n_heads,
        'n_layers': n_layers,
        'd_ff': d_ff,
        'seq_len': seq_len
    }
  }, 'tiny_transformer.pth')

  return model, dataset




In [16]:
model, dataset = train_model()

Total parameters: 802082

Starting training...
Epoch 1/50, Batch 0, Loss: 3.5463
Epoch 1/50, Batch 10, Loss: 2.7826
Epoch 1/50, Batch 20, Loss: 2.3042
Epoch 2/50, Batch 0, Loss: 2.1719
Epoch 2/50, Batch 10, Loss: 2.0383
Epoch 2/50, Batch 20, Loss: 1.9541
Epoch 3/50, Batch 0, Loss: 1.9228
Epoch 3/50, Batch 10, Loss: 1.8023
Epoch 3/50, Batch 20, Loss: 1.7883
Epoch 4/50, Batch 0, Loss: 1.7654
Epoch 4/50, Batch 10, Loss: 1.6823
Epoch 4/50, Batch 20, Loss: 1.5630
Epoch 5/50, Batch 0, Loss: 1.5140
Epoch 5/50, Batch 10, Loss: 1.5284
Epoch 5/50, Batch 20, Loss: 1.4052
Epoch 6/50, Batch 0, Loss: 1.3606
Epoch 6/50, Batch 10, Loss: 1.2273
Epoch 6/50, Batch 20, Loss: 1.0820
Epoch 7/50, Batch 0, Loss: 1.1037
Epoch 7/50, Batch 10, Loss: 0.9935
Epoch 7/50, Batch 20, Loss: 0.8738
Epoch 8/50, Batch 0, Loss: 0.8610
Epoch 8/50, Batch 10, Loss: 0.7918
Epoch 8/50, Batch 20, Loss: 0.7245
Epoch 9/50, Batch 0, Loss: 0.7086
Epoch 9/50, Batch 10, Loss: 0.6231
Epoch 9/50, Batch 20, Loss: 0.5964
Epoch 10/50, Batc