In [1]:
import torch
import torch.nn as nn
import math
import random
import os
from pathlib import Path
import tomllib
import numpy as np

In [2]:
def set_seed(seed=313):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed()

DEVICE = 'cuda'
CONFIG_PATH = Path.cwd().parent / "config.toml"

with open(CONFIG_PATH, "rb") as f:
    cfg = tomllib.load(f)
    assert divmod(cfg["d_model"],cfg["n_heads"])[1]==0, "d_model should be divisble by n_heads"
    assert divmod(cfg["n_heads"],cfg["n_kv_heads"])[1]==0, "n_heads should be divisble by n_kv_heads"

    cfg["d_head"] = cfg["d_model"] // cfg["n_heads"]
    cfg["kv_d_head"] = cfg["d_model"] // cfg["n_kv_heads"]

In [None]:
class Embedding(nn.Module):
    def __init__(self,vocab_size,d_model):
        super().__init__()

        self.emb = nn.Embedding(vocab_size,d_model)
        nn.init.normal_(self.emb.weight,mean=0,std=(d_model)**-0.5)

    def forward(self, x):
        return self.emb(x)

class RMSNorm(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        den = ((x**2).sum(dim=-1,keepdim=True)/x.size(-1) +1e-6)**0.5
        return (x/den) * self.gamma

# # TODO: apply it wherever needed
# class Dropout(nn.Module):
#     def __init__(self, p):
#         super().__init__()
#         self.p = p

#     def forward(self, x):
#         if self.training and self.p > 0:
#             mask = (torch.rand_like(x) > self.p).float()
#             x = mask * x
#             x /= 1 - self.p
#         return x


class MultiHeadAttention(nn.Module):
    def __init__(self,d_model,d_head,n_heads,n_kv_heads,kv_d_head):
        super().__init__()

        self.d_model = d_model
        self.d_head = d_head
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.kv_d_head = kv_d_head

        self.w_q = nn.Linear(self.d_model,self.n_heads * self.d_head)
        self.w_k = nn.Linear(self.d_model,self.n_kv_heads * self.d_head)
        self.w_v = nn.Linear(self.d_model,self.n_kv_heads * self.d_head)

        self.proj_out = nn.Linear(self.d_model,self.d_model)
    
    def _apply_rope(x,base=10_000):
        """
        For each pair of dimensions (2i, 2i+1):
        q_rotated[2i]   = q[2i] * cos(m*θᵢ) - q[2i+1] * sin(m*θᵢ)
        q_rotated[2i+1] = q[2i] * sin(m*θᵢ) + q[2i+1] * cos(m*θᵢ)

        Where theta_i = 10000^{-2i/d}
        """
        s,dim = x.shape[1:3]
        x_rotated = torch.zeros_like(x)

        m = torch.arange(s).unsqueeze(1) # pos
        i = torch.arange(0,dim,2) # dim idx

        theta = base**(-i / dim)
        
        x1, x2 = x[:,:,0::2], x[:,:,1::2]

        x_rotated[:,:,0::2] = x1*torch.cos(m*theta) - x2*torch.sin(m*theta)
        x_rotated[:,:,1::2] = x1*torch.sin(m*theta) + x2*torch.cos(m*theta)

        return x_rotated

    def forward(self, idx, x):
        b, s = x.shape[:2]

        q = self.w_q(x) # [b,s,d_model]
        k = self.w_k(x) # [b,s,n_kv_heads * d_head]
        v = self.w_v(x) # [b,s,n_kv_heads * d_head]

        if (idx+1)%4:
            q = self._apply_rope(q,q.size(-1)) # rotated
            k = self._apply_rope(k,k.size(-1)) # rotated
            
        q = q.contiguous().view(b, s, self.n_heads, self.d_head) # [b,s,n_heads,d_head]
        k = k.contiguous().view(b, s, self.n_kv_heads, self.d_head) # [b,s,n_kv_heads,d_head]
        v = v.contiguous().view(b, s, self.n_kv_heads, self.d_head) # [b,s,n_kv_heads,d_head]

        # [b,s,n_heads,d_head] @ [b,s,n_kv_heads,d_head] => tranpose k (-1,-2) => [b,s,n_heads,d_head] @ [b,s,d_head,n_kv_heads] => [b,s,n_heads,n_kv_heads]

        scores = q @ k.transpose(-1,-2) # [b,s,n_heads,n_kv_heads]
        
        attn_weights = torch.softmax(
            scores / (self.n_kv_heads ** 0.5), dim=-1
        )  # [b,s,n_heads,n_kv_heads]

        # [b,s,n_heads,n_kv_heads] @ [b,s,n_kv_heads,d_head] => [b,s,n_heads,d_head] => [b,s,d_model]

        output = attn_weights @ v # [b,s,n_heads,d_head]
        output = output.contiguous().view(b,s,self.d_model)

        return self.proj_out(output) # [b,s,d_model]  

class FNN
class Decoder(nn.Module):
    def __init__(self,d_model,d_head,n_heads,n_kv_heads,kv_d_head):
        super().__init__()
        self.mha = MultiHeadAttention(d_model,d_head,n_heads,n_kv_heads,kv_d_head)
        self.fnn = FFN()
    def forward(self,idx,x):
        x = self.mha(idx,x)
        x = self.ffn(x)
        return x

class Llama(nn.Module):
    def __init__(self,vocab_size,n_layers,d_model,d_head,n_heads,n_kv_heads,kv_d_head):
        super().__init__()

        self.emb = Embedding(vocab_size,d_model)
        self.rms_norm = RMSNorm(d_model)
        self.decoder_layers = nn.ModuleList([
            Decoder(d_model,d_head,n_heads,n_kv_heads,kv_d_head) for _ in range(n_layers)
        ])
    
    def forward(self,x):
        print(x.shape)
        out = self.emb(x)
        
        for i,decoder in enumerate(self.decoder_layers):
            out = decoder(i,out)

        return out

In [4]:
x = torch.randint(0,900,(2,5))
x.shape

torch.Size([2, 5])

In [5]:
model = Llama(
    vocab_size=cfg['vocab_size'],
    n_layers=cfg['n_layers'],
    d_model=cfg['d_model'],
    d_head=cfg['d_head'],
    n_heads=cfg['n_heads'],
    n_kv_heads=cfg['n_kv_heads'],
    kv_d_head=cfg['kv_d_head']
)
model

Llama(
  (emb): Embedding(
    (emb): Embedding(32000, 768)
  )
  (rms_norm): RMSNorm()
  (decoder_layers): ModuleList(
    (0-11): 12 x Decoder(
      (mha): MultiHeadAttention(
        (w_q): Linear(in_features=768, out_features=768, bias=True)
        (w_k): Linear(in_features=768, out_features=256, bias=True)
        (w_v): Linear(in_features=768, out_features=256, bias=True)
        (proj_out): Linear(in_features=768, out_features=768, bias=True)
      )
    )
  )
)

In [6]:
model(x).shape

torch.Size([2, 5])


torch.Size([2, 5, 768])

In [None]:
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1_block = nn.Sequential(
            nn.Linear(cfg["d_model"], 4 * cfg["d_model"]), nn.GELU(), Dropout(0.1)
        )

        self.fc2_block = nn.Sequential(
            nn.Linear(4 * cfg["d_model"], cfg["d_model"]), Dropout(0.1)
        )

    def forward(self, x):
        x = self.fc1_block(x)  # [b,p,4*d_model]
        x = self.fc2_block(x)  # [b,p,d_model]
        return x


class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_norm1 = LayerNorm(cfg["d_model"])
        self.layer_norm2 = LayerNorm(cfg["d_model"])

        self.dropout = Dropout
        self.mha = MultiHeadAttention()
        self.mlp = MLP()

    def forward(self, x):
        # Attention block with residual
        x_residual = x
        x = self.layer_norm1(x)
        x = self.mha(x)
        x = x + x_residual

        # MLP block with residual
        x_residual = x
        x = self.layer_norm2(x)
        x = self.mlp(x)
        x = x + x_residual

        return x


class TransformerLM(nn.Module):
    def __init__(self, patch_size=4, img_size=32, in_channels=3, num_classes=100):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2

        self.embedding = Embedding(patch_size, img_size, in_channels)
        self.encoder_layers = nn.ModuleList(
            [Encoder() for _ in range(cfg["num_encoder_layer"])]
        )
        self.norm = LayerNorm(cfg["d_model"])
        self.head = nn.Linear(cfg["d_model"], num_classes)

    def _extract_patches(self, x):
        # x: [b, c, h, w]
        b, c = x.shape[:2]
        p = self.patch_size

        # Unfold to patches
        x = x.unfold(2, p, p).unfold(3, p, p)  # [b, c, n_h, n_w, p, p]
        x = x.contiguous().view(b, c, -1, p, p)  # [b, c, num_patches, p, p]
        x = x.permute(0, 2, 1, 3, 4)  # [b, num_patches, c, p, p]
        x = x.contiguous().view(b, -1, c * p * p)  # [b, num_patches, c*p*p]

        return x

    def forward(self, x):
        # Extract patches
        x = self._extract_patches(x)  # [b,c,h,w] -> [b,num_patches,c*p*p]

        # Embedding
        x = self.embedding(x)  # [b,num_patches,d_model]

        # Transformer encoder
        for layer in self.encoder_layers:
            x = layer(x)

        # Classification head (use [CLS] token or global average pooling)
        x = self.norm(x)
        x = x.mean(
            dim=1
        )  # Global average pooling [b,num_patches,d_model] -> [b,d_model]
        out = self.head(x)  # [b,num_classes]

        return out


if __name__ == "__main__":
    # Test the model
    model = TransformerLM(patch_size=4, img_size=32, in_channels=3, num_classes=100)

    # Test forward pass
    x = torch.randn(1, 3, 32, 32)  # [batch, channels, height, width]
    output = model(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")  # Should be [1, 100]


TypeError: Embedding.__init__() takes 3 positional arguments but 4 were given

In [None]:
# n = int(0.9 * len(data))
# train_data = data[:n]
# val_data = data[n:]


def get_batch(split, batch_size, block_size):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = []
    y = []
    pad_id = tokenizer.token_to_id("[PAD]")
    for i in ix:
        seq_x = data[i:i+block_size]
        seq_y = data[i+1:i+block_size+1]
        if len(seq_x) < block_size:
            seq_x = torch.cat([seq_x, torch.full((block_size - len(seq_x),), pad_id, dtype=torch.long, device=device)])
            seq_y = torch.cat([seq_y, torch.full((block_size - len(seq_y),), pad_id, dtype=torch.long, device=device)])
        x.append(seq_x)
        y.append(seq_y)
    return torch.stack(x).to(device), torch.stack(y).to(device)


In [None]:
class Embeddding(nn.Module):
    """
    Embedding layer combining token embeddings and positional encodings.
    Converts token IDs to dense vectors and adds position information.
    """
    def __init__(self, vocab_size, d_model, max_seq_len, dropout=0.1):
        super().__init__()
        # Token embedding: maps vocab_size tokens to d_model dimensions
        self.token_embedding = nn.Embedding(vocab_size, d_model)  # (vocab_size, d_model)
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.dropout = nn.Dropout(dropout)

        # Positional encoding: sinusoidal position embeddings
        pos = torch.arange(max_seq_len).unsqueeze(1)  # (max_seq_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))  # (d_model/2,)
        pe = torch.zeros(max_seq_len, d_model)  # (max_seq_len, d_model)
        pe[:, 0::2] = torch.sin(pos * div_term)  # Even indices: sine
        pe[:, 1::2] = torch.cos(pos * div_term)  # Odd indices: cosine
        self.register_buffer('pos_embedding', pe)  # Register as buffer (not a parameter)

    def forward(self, x, offset=0):
        """
        Args:
            x: Token IDs, shape (batch_size, seq_length)
            offset: Position offset for KV caching (default=0)
        Returns:
            Embedded vectors: (batch_size, seq_length, d_model)
        """
        batch_size, seq_length = x.shape  # x: (B, L)
        x = self.token_embedding(x)  # (B, L, d_model)
        # Add positional embeddings starting from offset
        x = x + self.pos_embedding[offset:offset+seq_length, :].unsqueeze(0)  # (B, L, d_model)
        return self.dropout(x)  # (B, L, d_model)


class MaskedMultiHeadAttention(nn.Module):
    """
    Masked Multi-Head Self-Attention for autoregressive generation.
    Supports optional KV caching for efficient inference.
    """
    def __init__(self, d_model, num_head, dropout=0.1):
        super().__init__()
        assert d_model % num_head == 0, "d_model must be divisible by num_head"

        # Linear projections for Query, Key, Value
        self.W_q = nn.Linear(d_model, d_model)  # (d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)  # (d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)  # (d_model, d_model)

        self.d_model = d_model
        self.num_head = num_head
        self.d_head = d_model // num_head  # Dimension per head
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, kv_cache=None):
        """
        Args:
            x: Input tensor, shape (batch_size, seq_length, d_model)
            kv_cache: Optional tuple (K_cache, V_cache) from previous steps
                     K_cache: (B, num_head, cached_len, d_head)
                     V_cache: (B, num_head, cached_len, d_head)
        Returns:
            output: (batch_size, seq_length, d_model)
            new_kv_cache: Updated (K, V) cache for next step
        """
        batch_size, seq_length, d_model = x.shape  # x: (B, L, d_model)

        # Compute Q, K, V projections
        Q = self.W_q(x)  # (B, L, d_model)
        K = self.W_k(x)  # (B, L, d_model)
        V = self.W_v(x)  # (B, L, d_model)

        # Reshape for multi-head attention: (B, num_head, L, d_head)
        Q = Q.view(batch_size, seq_length, self.num_head, self.d_head).transpose(1, 2)  # (B, num_head, L, d_head)
        K = K.view(batch_size, seq_length, self.num_head, self.d_head).transpose(1, 2)  # (B, num_head, L, d_head)
        V = V.view(batch_size, seq_length, self.num_head, self.d_head).transpose(1, 2)  # (B, num_head, L, d_head)

        # If KV cache exists, concatenate with current K, V
        if kv_cache is not None:
            K = torch.cat([kv_cache[0], K], dim=2)  # (B, num_head, cached_len+L, d_head)
            V = torch.cat([kv_cache[1], V], dim=2)  # (B, num_head, cached_len+L, d_head)

        # Store updated cache
        new_kv_cache = (K, V)

        # Scaled dot-product attention
        attn_scores = (Q @ K.transpose(-1, -2)) / math.sqrt(self.d_head)  # (B, num_head, L, cached_len+L)

        # Causal mask: prevent attending to future tokens
        # mask shape: (L, cached_len+L) - lower triangular for current sequence
        mask = torch.tril(torch.ones(seq_length, K.shape[2], device=x.device)).unsqueeze(0).unsqueeze(0)
        attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))  # (B, num_head, L, cached_len+L)

        # Softmax and dropout
        attn_weight = torch.softmax(attn_scores, dim=-1)  # (B, num_head, L, cached_len+L)
        attn_weight = self.dropout(attn_weight)

        # Apply attention to values
        attn = (attn_weight @ V).transpose(1, 2)  # (B, L, num_head, d_head)
        attn = attn.contiguous().view(batch_size, seq_length, d_model)  # (B, L, d_model)

        # Residual connection + LayerNorm
        return self.norm(x + attn), new_kv_cache  # (B, L, d_model), (K, V)


class FNN(nn.Module):
    """
    Feed-Forward Network with residual connection.
    Expands to 4*d_model and projects back to d_model.
    """
    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        self.fnn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),  # Expansion: (d_model, 4*d_model)
            nn.ReLU(),
            nn.Linear(4 * d_model, d_model),  # Projection: (4*d_model, d_model)
            nn.Dropout(dropout)
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """
        Args:
            x: (batch_size, seq_length, d_model)
        Returns:
            (batch_size, seq_length, d_model)
        """
        out = self.fnn(x)  # (B, L, d_model)
        return self.norm(x + out)  # Residual + LayerNorm: (B, L, d_model)


class Decoder(nn.Module):
    """
    Single Decoder layer: Masked Multi-Head Attention + Feed-Forward Network.
    """
    def __init__(self, d_model, num_head, dropout=0.1):
        super().__init__()
        self.mmha = MaskedMultiHeadAttention(d_model, num_head, dropout)
        self.fnn = FNN(d_model, dropout)

    def forward(self, x, kv_cache=None):
        """
        Args:
            x: (batch_size, seq_length, d_model)
            kv_cache: Optional KV cache from previous step
        Returns:
            output: (batch_size, seq_length, d_model)
            new_kv_cache: Updated cache
        """
        attn_out, new_kv_cache = self.mmha(x, kv_cache)  # (B, L, d_model), cache
        return self.fnn(attn_out), new_kv_cache  # (B, L, d_model), cache


class DecoderBlock(nn.Module):
    """
    Full Decoder-only Transformer (GPT-style architecture).
    Embedding -> N x Decoder Layers -> LayerNorm -> Output projection
    """
    def __init__(self, n_layers, n_vocab, d_model, num_head, max_seq_len, dropout=0.1):
        super().__init__()
        self.embedding = Embeddding(n_vocab, d_model, max_seq_len, dropout)
        self.decoder_layers = nn.ModuleList(
            [Decoder(d_model, num_head, dropout) for _ in range(n_layers)]
        )
        self.norm = nn.LayerNorm(d_model)
        self.fc = nn.Linear(d_model, n_vocab)  # Output projection to vocabulary

    def forward(self, x, kv_caches=None, offset=0):
        """
        Args:
            x: Token IDs, shape (batch_size, seq_length)
            kv_caches: Optional list of KV caches (one per layer)
            offset: Position offset for positional encoding (for KV caching)
        Returns:
            logits: (batch_size, seq_length, n_vocab)
            new_kv_caches: Updated list of KV caches
        """
        x = self.embedding(x, offset=offset)  # (B, L, d_model)

        new_kv_caches = []
        for i, layer in enumerate(self.decoder_layers):
            # Get cache for this layer if available
            kv_cache = kv_caches[i] if kv_caches is not None else None
            x, new_kv_cache = layer(x, kv_cache)  # (B, L, d_model), cache
            new_kv_caches.append(new_kv_cache)

        x = self.norm(x)  # (B, L, d_model)
        logits = self.fc(x)  # (B, L, n_vocab)
        return logits, new_kv_caches


def train(model, config, train_data, val_data, device, tokenizer, epochs=20, batch_size=64, patience=3):
    """Train the model with gradient clipping and early stopping."""
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.token_to_id("[PAD]"))
    scaler = GradScaler()
    best_val_loss = float('inf')
    epochs_no_improve = 0
    checkpoint_dir = "model_checkpoints"
    checkpoint_path = os.path.join(checkpoint_dir, "best_model.pt")

    for epoch in range(epochs):
        total_train_loss = 0
        model.train()
        for _ in range(config.batches_per_epoch):
            x, y = get_batch('train', batch_size, config.block_size)
            optimizer.zero_grad()
            with autocast():
                logits, _ = model(x)  # Ignore KV caches during training
                loss = criterion(logits.view(-1, config.n_vocab), y.view(-1))
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            total_train_loss += loss.item()
        avg_train_loss = total_train_loss / config.batches_per_epoch

        total_val_loss = 0
        model.eval()
        valid_val_loss = len(val_data) > config.block_size
        if valid_val_loss:
            with torch.no_grad():
                for _ in range(config.batches_per_epoch // 10):
                    x, y = get_batch('val', batch_size, config.block_size)
                    with autocast():
                        logits, _ = model(x)  # Ignore KV caches during validation
                        loss = criterion(logits.view(-1, config.n_vocab), y.view(-1))
                    total_val_loss += loss.item()
            avg_val_loss = total_val_loss / (config.batches_per_epoch // 10)
            perplexity = math.exp(avg_val_loss)
            scheduler.step(avg_val_loss)
        else:
            print("Warning: Validation data too short or empty, skipping validation")
            avg_val_loss = float('inf')
            perplexity = float('inf')

        print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}" +
              (f", Val Loss: {avg_val_loss:.4f}, Perplexity: {perplexity:.2f}" if valid_val_loss else ""))

        if valid_val_loss and avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            try:
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                if not os.access(checkpoint_dir, os.W_OK):
                    raise PermissionError(f"Directory '{checkpoint_dir}' is not writable")
                checkpoint = {
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'epoch': epoch,
                    'best_val_loss': best_val_loss
                }
                torch.save(checkpoint, checkpoint_path)
                print(f"Saved best model with Val Loss: {best_val_loss:.4f}")
            except (FileNotFoundError, PermissionError, OSError) as e:
                print(f"Error saving model checkpoint: {e}")
                raise
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"Early stopping triggered after {epoch+1} epochs")
                break


def generate(model, prompt, max_tokens=50, temperature=1.0, use_kv_cache=True, return_ids=False):
    """
    Generate text from the model.

    Args:
        model: The DecoderBlock model
        prompt: Input text string
        max_tokens: Number of tokens to generate
        temperature: Sampling temperature (higher = more random)
        use_kv_cache: If True, use KV caching for faster generation
        return_ids: If True, return token IDs instead of decoded text

    Returns:
        Generated text (or token IDs if return_ids=True)
    """
    if not prompt.strip():
        print("Warning: Empty prompt provided, using default")
        prompt = "[CLS]"

    model.eval()
    tokens = encode(prompt)
    if not tokens:
        tokens = [tokenizer.token_to_id("[CLS]")]
    tokens = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)  # (1, prompt_len)

    if use_kv_cache:
        # KV caching mode: reuse previous computations
        kv_caches = None

        with torch.no_grad():
            for i in range(max_tokens):
                if i == 0:
                    # First step: process entire prompt
                    input_tokens = tokens  # (1, prompt_len)
                    offset = 0
                else:
                    # Subsequent steps: process only last token
                    input_tokens = tokens[:, -1:]  # (1, 1)
                    offset = tokens.shape[1] - 1  # Current position

                with autocast():
                    logits, kv_caches = model(input_tokens, kv_caches=kv_caches, offset=offset)

                logits = logits[:, -1, :] / temperature  # (1, n_vocab)
                probs = torch.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)  # (1, 1)
                tokens = torch.cat([tokens, next_token], dim=1)  # (1, prompt_len+i+1)
    else:
        # Standard mode: recompute everything each step
        for _ in range(max_tokens):
            input_tokens = tokens[:, -config.block_size:]  # (1, min(len, block_size))
            with torch.no_grad():
                with autocast():
                    logits, _ = model(input_tokens)

            logits = logits[:, -1, :] / temperature  # (1, n_vocab)
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # (1, 1)
            tokens = torch.cat([tokens, next_token], dim=1)  # (1, prompt_len+i+1)

    token_ids = tokens[0].tolist()
    if return_ids:
        return token_ids
    return decode(token_ids)


if __name__ == "__main__":
    try:
        model = DecoderBlock(
            config.n_layers, config.n_vocab, config.d_model, config.num_head, config.max_seq_length, dropout=0.1
        ).to(device)
        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Model size: {total_params:,} parameters ({total_params / 1e6:.2f}M)")
        train(model, config, train_data, val_data, device, tokenizer, patience=3)
        print(generate(model, "Music", max_tokens=50, use_kv_cache=True))
    except Exception as e:
        print(f"Error during execution: {e}")
        raise