
Exploration of a "number command" transformer.

Depending on which command token is after a sequence of digit's, the network is asked to perform different tasks.

example sequences:
```
<sos> 33358 <sum> 22 <eos>
<sos> 00487 <sum> 19 <eos>
<sos> 840 <sum> 12 <eos>
<sos> 70996 <sum> 31 <eos>
<sos> 2778 <sum> 24 <eos>
```

The transformer model is based on Adrej Karpathy's NanoGPT
https://github.com/karpathy/nanoGPT/tree/master



Notebook by Justin Thomas



In [9]:
import random

def generate_random_numbers_and_sum():
    """
    Generate 3–5 random digits and output:
    <sos> digits <sum> total <eos>
    """
    # Pick a random length between 3 and 5
    length = random.randint(3, 5)
    
    # Generate that many random digits
    numbers = [random.randint(0, 9) for _ in range(length)]
    
    # Calculate the sum of the numbers
    total_sum = sum(numbers)
    
    # Join the numbers into a single string without commas
    numbers_str = ''.join(map(str, numbers))
    
    # Format the output
    output = f"<sos> {numbers_str} <sum> {total_sum} <eos>"
    return output

# Example usage
for _ in range(5):
    print(generate_random_numbers_and_sum())


<sos> 33358 <sum> 22 <eos>
<sos> 00487 <sum> 19 <eos>
<sos> 840 <sum> 12 <eos>
<sos> 70996 <sum> 31 <eos>
<sos> 2778 <sum> 24 <eos>


In [8]:
import random

def generate_sequence():
    """
    Generate a sequence like:
    <sos> 1237 <even> 137137137 <eos>
    - First part is a random digit string (length 3–5).
    - Second part is a repeated subsequence of the first part.
    """
    # Generate 3–5 random digits for the first sequence
    numbers = [random.randint(0, 9) for _ in range(random.randint(3, 5))]
    numbers_str = ''.join(map(str, numbers))
    
    # Pick a random subsequence (at least 2 digits long)
    start = random.randint(0, len(numbers) - 2)
    end = random.randint(start + 1, len(numbers))
    subseq = numbers_str[start:end]
    
    # Repeat the subsequence 3 times
    repeated = subseq * 3
    
    # Format the output
    output = f"<sos> {numbers_str} <even> {repeated} <eos>"
    return output

# Example usage
for _ in range(5):
    print(generate_sequence())


<sos> 204 <even> 222 <eos>
<sos> 528 <even> 282828 <eos>
<sos> 0972 <even> 000 <eos>
<sos> 0503 <even> 000 <eos>
<sos> 4617 <even> 444 <eos>


In [10]:
import random

def generate_random_numbers_and_reverse():
    """
    Generate 3–5 random digits and output:
    <sos> digits <reverse> reversed_digits <eos>
    """
    # Pick a random length between 3 and 5
    length = random.randint(3, 5)
    
    # Generate that many random digits
    numbers = [random.randint(0, 9) for _ in range(length)]
    
    # Join the numbers into a string
    numbers_str = ''.join(map(str, numbers))
    
    # Reverse the string
    reversed_str = numbers_str[::-1]
    
    # Format the output
    output = f"<sos> {numbers_str} <reverse> {reversed_str} <eos>"
    return output

# Example usage
for _ in range(5):
    print(generate_random_numbers_and_reverse())


<sos> 4355 <reverse> 5534 <eos>
<sos> 16505 <reverse> 50561 <eos>
<sos> 78667 <reverse> 76687 <eos>
<sos> 7467 <reverse> 7647 <eos>
<sos> 80364 <reverse> 46308 <eos>


In [11]:
import random

def generate_random_numbers_and_sort():
    """
    Generate 3–5 random digits and output either:
    <sos> digits <ascending> sorted_digits <eos>
    or
    <sos> digits <descending> sorted_digits <eos>
    """
    # Pick a random length between 3 and 5
    length = random.randint(3, 5)
    
    # Generate that many random digits
    numbers = [random.randint(0, 9) for _ in range(length)]
    numbers_str = ''.join(map(str, numbers))
    
    # Randomly choose ascending or descending
    if random.choice([True, False]):
        sorted_digits = ''.join(map(str, sorted(numbers)))
        command = "ascending"
    else:
        sorted_digits = ''.join(map(str, sorted(numbers, reverse=True)))
        command = "descending"
    
    # Format the output
    output = f"<sos> {numbers_str} <{command}> {sorted_digits} <eos>"
    return output

# Example usage
for _ in range(5):
    print(generate_random_numbers_and_sort())


<sos> 19425 <descending> 95421 <eos>
<sos> 5554 <ascending> 4555 <eos>
<sos> 465 <descending> 654 <eos>
<sos> 418 <ascending> 148 <eos>
<sos> 67816 <ascending> 16678 <eos>


In [2]:
import random
import pickle
import numpy as np
from typing import List, Tuple
import os

class CustomTokenizer:
    def __init__(self):
        # Define vocabulary 
        self.tokens = [
            "<sos>", "<eos>", "<sum>", "<reverse>", "<ascending>", "<descending>", "<even>"
        ] + [str(i) for i in range(10)]
        
        # Token ↔ ID mappings
        self.token2id = {tok: idx for idx, tok in enumerate(self.tokens)}
        self.id2token = {idx: tok for tok, idx in self.token2id.items()}
        self.vocab_size = len(self.tokens)
    
    def encode(self, text: str) -> List[int]:
        """
        Convert string to list of token IDs.
        - Matches tags and digits
        - Ignores spaces
        """
        tokens = []
        i = 0
        while i < len(text):
            if text[i] == " ":
                i += 1
                continue
            
            if text[i] == "<":  # possible special token
                j = text.find(">", i)
                if j != -1:
                    tok = text[i:j+1]
                    if tok in self.token2id:
                        tokens.append(self.token2id[tok])
                        i = j + 1
                        continue
            
            if text[i].isdigit():
                tokens.append(self.token2id[text[i]])
                i += 1
            else:
                raise ValueError(f"Unexpected character: {text[i]}")
        
        return tokens
    
    def decode(self, token_ids: List[int]) -> str:
        """Convert list of IDs back to string, preserving original format."""
        tokens = [self.id2token[i] for i in token_ids]
        result = ""
        current_number = ""
        
        for i, token in enumerate(tokens):
            if token.startswith("<") and token.endswith(">"):
                # Finish any current number
                if current_number:
                    if result and not result.endswith(" "):
                        result += " "
                    result += current_number
                    current_number = ""
                
                # Add special token with spaces
                if result and not result.endswith(" "):
                    result += " "
                result += token
                
            else:
                # Accumulate digits into current number
                current_number += token
        
        # Add final number if exists
        if current_number:
            if result and not result.endswith(" "):
                result += " "
            result += current_number
        
        return result

def generate_sum_example():
    """Generate sum command example"""
    length = random.randint(3, 5)
    numbers = [random.randint(0, 9) for _ in range(length)]
    total_sum = sum(numbers)
    numbers_str = ''.join(map(str, numbers))
    return f"<sos> {numbers_str} <sum> {total_sum} <eos>"

def generate_reverse_example():
    """Generate reverse command example"""
    length = random.randint(3, 5)
    numbers = [random.randint(0, 9) for _ in range(length)]
    numbers_str = ''.join(map(str, numbers))
    reversed_str = numbers_str[::-1]
    return f"<sos> {numbers_str} <reverse> {reversed_str} <eos>"

def generate_sort_example():
    """Generate ascending or descending sort example"""
    length = random.randint(3, 5)
    numbers = [random.randint(0, 9) for _ in range(length)]
    numbers_str = ''.join(map(str, numbers))
    
    if random.choice([True, False]):
        sorted_digits = ''.join(map(str, sorted(numbers)))
        command = "ascending"
    else:
        sorted_digits = ''.join(map(str, sorted(numbers, reverse=True)))
        command = "descending"
    
    return f"<sos> {numbers_str} <{command}> {sorted_digits} <eos>"

def generate_even_example():
    """Generate even (repetition) command example"""
    numbers = [random.randint(0, 9) for _ in range(random.randint(3, 5))]
    numbers_str = ''.join(map(str, numbers))
    
    # Pick a random subsequence (at least 1 digit long)
    start = random.randint(0, len(numbers) - 1)
    end = random.randint(start + 1, len(numbers))
    subseq = numbers_str[start:end]
    
    # Repeat the subsequence 3 times
    repeated = subseq * 3
    
    return f"<sos> {numbers_str} <even> {repeated} <eos>"

def generate_dataset(num_examples: int = 50000) -> List[str]:
    """Generate a balanced dataset with all command types"""
    examples = []
    generators = [generate_sum_example, generate_reverse_example, 
                 generate_sort_example, generate_even_example]
    
    examples_per_type = num_examples // len(generators)
    
    for generator in generators:
        for _ in range(examples_per_type):
            examples.append(generator())
    
    # Add remaining examples to reach exact count
    remaining = num_examples - len(examples)
    for _ in range(remaining):
        generator = random.choice(generators)
        examples.append(generator())
    
    # Shuffle the dataset
    random.shuffle(examples)
    return examples

def create_training_data(examples: List[str], tokenizer: CustomTokenizer) -> Tuple[np.ndarray, dict]:
    """Convert examples to tokenized training data"""
    tokenized_examples = []
    max_length = 0
    
    # Tokenize all examples and find max length
    for example in examples:
        tokens = tokenizer.encode(example)
        tokenized_examples.append(tokens)
        max_length = max(max_length, len(tokens))
    
    print(f"Maximum sequence length: {max_length}")
    print(f"Vocabulary size: {tokenizer.vocab_size}")
    
    # Convert to numpy array (pad sequences to max length)
    data = np.full((len(examples), max_length), tokenizer.token2id["<eos>"], dtype=np.int64)
    
    for i, tokens in enumerate(tokenized_examples):
        data[i, :len(tokens)] = tokens
    
    # Create metadata
    meta = {
        'vocab_size': tokenizer.vocab_size,
        'max_length': max_length,
        'num_examples': len(examples),
        'token2id': tokenizer.token2id,
        'id2token': tokenizer.id2token
    }
    
    return data, meta

def test_tokenizer():
    """Test the tokenizer with sample examples"""
    print("Testing Tokenizer...")
    print("=" * 50)
    
    tokenizer = CustomTokenizer()
    
    # Test with each type of example
    test_examples = [
        generate_sum_example(),
        generate_reverse_example(), 
        generate_sort_example(),
        generate_even_example()
    ]
    
    for i, example in enumerate(test_examples):
        print(f"\nTest {i+1}:")
        print(f"Original: {example}")
        
        # Encode
        encoded = tokenizer.encode(example)
        print(f"Encoded:  {encoded}")
        
        # Decode
        decoded = tokenizer.decode(encoded)
        print(f"Decoded:  {decoded}")
        
        # Check if round-trip works (normalize whitespace for comparison)
        original_normalized = " ".join(example.split())
        decoded_normalized = " ".join(decoded.split())
        success = original_normalized == decoded_normalized
        print(f"Round-trip successful: {success}")
        
        if not success:
            print(f"ERROR: Round-trip failed!")
            print(f"Expected: '{original_normalized}'")
            print(f"Got:      '{decoded_normalized}'")
    
    print(f"\nVocabulary size: {tokenizer.vocab_size}")
    print(f"Tokens: {tokenizer.tokens}")

def save_dataset(train_data: np.ndarray, val_data: np.ndarray, meta: dict, data_dir: str = "data/number_commands"):
    """Save dataset to files"""
    os.makedirs(data_dir, exist_ok=True)
    
    # Save binary data
    train_data.astype(np.uint16).tofile(os.path.join(data_dir, 'train.bin'))
    val_data.astype(np.uint16).tofile(os.path.join(data_dir, 'val.bin'))
    
    # Save metadata
    with open(os.path.join(data_dir, 'meta.pkl'), 'wb') as f:
        pickle.dump(meta, f)
    
    print(f"Dataset saved to {data_dir}")
    print(f"Train examples: {len(train_data)}")
    print(f"Val examples: {len(val_data)}")



In [3]:
# Example usage
tokenizer = CustomTokenizer()

sample = "<sos> 482 <ascending> 248 <eos>"
encoded = tokenizer.encode(sample)
decoded = tokenizer.decode(encoded)

print("Original:", sample)
print("Encoded:", encoded)
print("Decoded:", decoded)

Original: <sos> 482 <ascending> 248 <eos>
Encoded: [0, 11, 15, 9, 4, 9, 11, 15, 1]
Decoded: <sos> 482 <ascending> 248 <eos>


In [19]:
# Test tokenizer first
test_tokenizer()

print("\n" + "="*50)
print("Generating Dataset...")
print("="*50)

# Set random seed for reproducibility
random.seed(42)
np.random.seed(42)

# Generate dataset
print("Generating examples...")
all_examples = generate_dataset(num_examples=60000)

# Split into train/val
split_idx = int(0.9 * len(all_examples))
train_examples = all_examples[:split_idx]
val_examples = all_examples[split_idx:]

print(f"Generated {len(all_examples)} total examples")
print(f"Train: {len(train_examples)}, Val: {len(val_examples)}")

# Create tokenizer and convert to training data
tokenizer = CustomTokenizer()

print("\nTokenizing training data...")
train_data, meta = create_training_data(train_examples, tokenizer)

print("Tokenizing validation data...")
val_data, _ = create_training_data(val_examples, tokenizer)

# Save dataset
save_dataset(train_data, val_data, meta)

# Show some statistics
print("\nDataset Statistics:")
print(f"Vocabulary size: {meta['vocab_size']}")
print(f"Max sequence length: {meta['max_length']}")
print(f"Train data shape: {train_data.shape}")
print(f"Val data shape: {val_data.shape}")

# Show sample training examples
print("\nSample training examples:")
for i in range(5):
    tokens = train_data[i]
    # Remove padding
    tokens = tokens[tokens != tokenizer.token2id["<eos>"]]
    tokens = np.append(tokens, tokenizer.token2id["<eos>"])  # Add back one EOS
    decoded = tokenizer.decode(tokens.tolist())
    print(f"{i+1}: {decoded}")


Testing Tokenizer...

Test 1:
Original: <sos> 8704 <sum> 19 <eos>
Encoded:  [0, 15, 14, 7, 11, 2, 8, 16, 1]
Decoded:  <sos> 8704 <sum> 19 <eos>
Round-trip successful: True

Test 2:
Original: <sos> 2626 <reverse> 6262 <eos>
Encoded:  [0, 9, 13, 9, 13, 3, 13, 9, 13, 9, 1]
Decoded:  <sos> 2626 <reverse> 6262 <eos>
Round-trip successful: True

Test 3:
Original: <sos> 590 <ascending> 059 <eos>
Encoded:  [0, 12, 16, 7, 4, 7, 12, 16, 1]
Decoded:  <sos> 590 <ascending> 059 <eos>
Round-trip successful: True

Test 4:
Original: <sos> 210 <even> 111 <eos>
Encoded:  [0, 9, 8, 7, 6, 8, 8, 8, 1]
Decoded:  <sos> 210 <even> 111 <eos>
Round-trip successful: True

Vocabulary size: 17
Tokens: ['<sos>', '<eos>', '<sum>', '<reverse>', '<ascending>', '<descending>', '<even>', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

Generating Dataset...
Generating examples...
Generated 60000 total examples
Train: 54000, Val: 6000

Tokenizing training data...
Maximum sequence length: 23
Vocabulary size: 17
Tokenizi

In [3]:
import math
import inspect
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F

class LayerNorm(nn.Module):
    """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)

class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y

class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu    = nn.GELU()
        self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

class Block(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

@dataclass
class NumberCommandConfig:
    block_size: int = 32  # Small sequences for number commands
    vocab_size: int = 17  # Will be set from dataset
    n_layer: int = 4      # Smaller model for this simple task
    n_head: int = 4       # Fewer attention heads
    n_embd: int = 128     # Smaller embedding dimension
    dropout: float = 0.1  # Some dropout for regularization
    bias: bool = True     # Keep bias terms

class NumberCommandTransformer(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        
        # Weight tying - share embeddings with output layer
        self.transformer.wte.weight = self.lm_head.weight

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        # report number of parameters
        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
            loss = None

        return logits, loss

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")

        return optimizer

    def estimate_mfu(self, fwdbwd_per_iter, dt):
        """Estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS"""
        # First estimate the number of flops we do per iteration.
        # See PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
        N = self.get_num_params()
        cfg = self.config
        L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
        flops_per_token = 6*N + 12*L*H*Q*T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        # Express our flops throughput as ratio of A100 bfloat16 peak flops
        flops_achieved = flops_per_iter * (1.0/dt)  # per second
        flops_promised = 312e12  # A100 GPU bfloat16 peak flops is 312 TFLOPS
        mfu = flops_achieved / flops_promised
        return mfu

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, eos_token_id=None):
        batch_size = idx.size(0)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=idx.device)
    
        for _ in range(max_new_tokens):
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
    
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
    
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
    
            # Append only for unfinished sequences
            idx = torch.cat((idx, idx_next), dim=1)
    
            if eos_token_id is not None:
                # Update finished mask
                finished = finished | (idx_next.squeeze(1) == eos_token_id)
                # If all sequences finished, we can break early
                if finished.all():
                    break
    
            # Optionally, you could replace tokens in finished sequences with eos_token_id to pad
            idx_next[finished.unsqueeze(1)] = eos_token_id
    
        return idx


In [4]:
# ALL-IN-ONE TRAINING SCRIPT FOR NUMBER COMMAND TRANSFORMER
# 

import os
import time
import math
import pickle
import platform
from contextlib import nullcontext
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("CUDA device:", torch.cuda.get_device_name())

# Check if we have the required data
data_dir = 'data/number_commands'
if not os.path.exists(os.path.join(data_dir, 'train.bin')):
    raise FileNotFoundError(f"Training data not found at {data_dir}. Please run dataset generation first.")

print("✓ Training data found")

# Load metadata
meta_path = os.path.join(data_dir, 'meta.pkl')
with open(meta_path, 'rb') as f:
    meta = pickle.load(f)

vocab_size = meta['vocab_size']
max_length = meta['max_length']

print(f"✓ Vocab size: {vocab_size}")
print(f"✓ Max sequence length: {max_length}")

# Configuration
config = {
    'vocab_size': vocab_size,
    'block_size': min(max_length, 32),  # Use actual max length or 32, whichever is smaller
    'n_layer': 4,
    'n_head': 4, 
    'n_embd': 8, #worked iwth 128, let's try smaller for our massive 17 token vocab
    'dropout': 0.1,
    'bias': True,
    'batch_size': 64,
    'learning_rate': 3e-4,
    'max_iters': 5000,
    'eval_interval': 500,
    'log_interval': 50,
    'warmup_iters': 200,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'compile': torch.cuda.is_available() and platform.system() != 'Windows',
    'out_dir': 'out_number_commands'
}

print(f"✓ Using device: {config['device']}")
print(f"✓ Model compilation: {config['compile']}")

# Set up device and dtype
device = config['device']
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device == 'cpu' else torch.amp.autocast(device_type=device, dtype=ptdtype)

print(f"✓ Using dtype: {dtype}")

# Create output directory
os.makedirs(config['out_dir'], exist_ok=True)

# Set seeds
torch.manual_seed(1337)
if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# Data loading function
def get_batch(split, batch_size, block_size):
    if split == 'train':
        data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
    else:
        data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
    
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
    
    if device == 'cuda':
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y

# Test data loading
print("✓ Testing data loading...")
test_x, test_y = get_batch('train', 2, config['block_size'])
print(f"✓ Batch shape: {test_x.shape}, {test_y.shape}")

# Create and initialize model (model classes should be defined already)
try:
    model_config = NumberCommandConfig(
        vocab_size=config['vocab_size'],
        block_size=config['block_size'],
        n_layer=config['n_layer'],
        n_head=config['n_head'],
        n_embd=config['n_embd'],
        dropout=config['dropout'],
        bias=config['bias']
    )
    
    model = NumberCommandTransformer(model_config)
    model.to(device)
    
    print(f"✓ Model created with {model.get_num_params()/1e6:.2f}M parameters")
    
except NameError:
    print("❌ Model classes not found! Please run the model definition cell first.")
    raise

# Initialize training components
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
optimizer = model.configure_optimizers(
    weight_decay=1e-2,
    learning_rate=config['learning_rate'], 
    betas=(0.9, 0.95),
    device_type=device
)

# Compile model with error handling
if config['compile']:
    try:
        print("Compiling model...")
        unoptimized_model = model
        model = torch.compile(model)
        print("✓ Model compilation successful!")
    except Exception as e:
        print(f"⚠ Model compilation failed: {e}")
        print("Continuing without compilation...")

# Loss estimation function
@torch.no_grad()
def estimate_loss(eval_iters=100):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, config['batch_size'], config['block_size'])
            with ctx:
                logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

# Learning rate scheduler
def get_lr(it):
    if it < config['warmup_iters']:
        return config['learning_rate'] * (it + 1) / (config['warmup_iters'] + 1)
    if it > config['max_iters']:
        return config['learning_rate'] * 0.1
    decay_ratio = (it - config['warmup_iters']) / (config['max_iters'] - config['warmup_iters'])
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return config['learning_rate'] * 0.1 + coeff * (config['learning_rate'] - config['learning_rate'] * 0.1)

# Training loop
print("\n" + "="*60)
print("STARTING TRAINING")
print("="*60)

iter_num = 0
best_val_loss = float('inf')
t0 = time.time()

# Initial evaluation
losses = estimate_loss()
print(f"Initial: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

while iter_num < config['max_iters']:
    # Set learning rate
    lr = get_lr(iter_num)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    
    # Evaluate and save checkpoints
    if iter_num % config['eval_interval'] == 0:
        losses = estimate_loss()
        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, lr {lr:.2e}")
        
        if losses['val'] < best_val_loss:
            best_val_loss = losses['val']
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'model_args': config,
                'iter_num': iter_num,
                'best_val_loss': best_val_loss,
            }
            torch.save(checkpoint, os.path.join(config['out_dir'], 'ckpt.pt'))
            print(f"✓ Saved checkpoint (val_loss: {best_val_loss:.4f})")
    
    # Training step
    X, Y = get_batch('train', config['batch_size'], config['block_size'])
    
    with ctx:
        logits, loss = model(X, Y)
    
    # Backward pass
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)
    
    # Logging
    if iter_num % config['log_interval'] == 0:
        t1 = time.time()
        dt = t1 - t0
        t0 = t1
        print(f"iter {iter_num}: loss {loss.item():.4f}, time {dt*1000:.1f}ms")
    
    iter_num += 1

# Final evaluation
losses = estimate_loss()
print(f"\nFinal: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
print(f"Training completed! Best val loss: {best_val_loss:.4f}")
print(f"Model saved to: {config['out_dir']}/ckpt.pt")

PyTorch version: 2.8.0+cpu
CUDA available: False
✓ Training data found
✓ Vocab size: 17
✓ Max sequence length: 23
✓ Using device: cpu
✓ Model compilation: False
✓ Using dtype: float16
✓ Testing data loading...
✓ Batch shape: torch.Size([2, 23]), torch.Size([2, 23])
number of parameters: 0.00M
✓ Model created with 0.00M parameters
num decayed parameter tensors: 18, with 3,392 parameters
num non-decayed parameter tensors: 34, with 432 parameters
using fused AdamW: False

STARTING TRAINING


  scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))


Initial: train loss 2.7825, val loss 2.7827
step 0: train loss 2.7826, val loss 2.7827, lr 1.49e-06
✓ Saved checkpoint (val_loss: 2.7827)
iter 0: loss 2.7831, time 4895.1ms
iter 50: loss 2.7581, time 2883.7ms
iter 100: loss 2.7063, time 2912.5ms
iter 150: loss 2.6333, time 2926.0ms
iter 200: loss 2.5404, time 2893.1ms
iter 250: loss 2.3827, time 2912.6ms
iter 300: loss 2.2332, time 2904.1ms
iter 350: loss 2.0792, time 2886.0ms
iter 400: loss 2.0018, time 2915.3ms
iter 450: loss 1.8971, time 2939.1ms
step 500: train loss 1.7412, val loss 1.7450, lr 2.97e-04
✓ Saved checkpoint (val_loss: 1.7450)
iter 500: loss 1.7572, time 5273.7ms
iter 550: loss 1.7201, time 2926.8ms
iter 600: loss 1.5668, time 2836.5ms
iter 650: loss 1.4624, time 2865.6ms
iter 700: loss 1.4315, time 2859.4ms
iter 750: loss 1.3545, time 2883.5ms
iter 800: loss 1.3681, time 2850.3ms
iter 850: loss 1.2914, time 2863.8ms
iter 900: loss 1.3104, time 2882.0ms
iter 950: loss 1.3066, time 2871.1ms
step 1000: train loss 1.2771,

In [7]:
import torch
import pickle
import os
import numpy as np
from typing import List

# Load model and tokenizer
def load_trained_model(checkpoint_path: str = 'out_number_commands/ckpt.pt'):
    """Load the trained model and tokenizer"""
    
    try:
        # Load metadata first
        meta_path = 'data/number_commands/meta.pkl'
        with open(meta_path, 'rb') as f:
            meta = pickle.load(f)
    except FileNotFoundError:
        raise FileNotFoundError("Dataset metadata not found. Please run dataset generation first.")
    
    # Recreate tokenizer
    class CustomTokenizer:
        def __init__(self, token2id, id2token):
            self.token2id = token2id
            self.id2token = id2token
            self.vocab_size = len(token2id)
        
        def encode(self, text: str) -> List[int]:
            tokens = []
            i = 0
            while i < len(text):
                if text[i] == " ":
                    i += 1
                    continue
                
                if text[i] == "<":
                    j = text.find(">", i)
                    if j != -1:
                        tok = text[i:j+1]
                        if tok in self.token2id:
                            tokens.append(self.token2id[tok])
                            i = j + 1
                            continue
                
                if text[i].isdigit():
                    tokens.append(self.token2id[text[i]])
                    i += 1
                else:
                    raise ValueError(f"Unexpected character: {text[i]}")
            
            return tokens
        
        def decode(self, token_ids: List[int]) -> str:
            tokens = [self.id2token[i] for i in token_ids]
            result = ""
            current_number = ""
            
            for i, token in enumerate(tokens):
                if token.startswith("<") and token.endswith(">"):
                    # Finish any current number
                    if current_number:
                        if result and not result.endswith(" "):
                            result += " "
                        result += current_number
                        current_number = ""
                    
                    # Add special token with spaces
                    if result and not result.endswith(" "):
                        result += " "
                    result += token
                    
                else:
                    # Accumulate digits into current number
                    current_number += token
            
            # Add final number if exists
            if current_number:
                if result and not result.endswith(" "):
                    result += " "
                result += current_number
            
            return result
    
    tokenizer = CustomTokenizer(meta['token2id'], meta['id2token'])
    
    # Load model
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    try:
        checkpoint = torch.load(checkpoint_path, map_location=device)
    except FileNotFoundError:
        raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}. Please train the model first.")
    
    model_args = checkpoint['model_args']
    
    # Handle both old and new checkpoint formats
    if isinstance(model_args, dict):
        # New format - direct config dict
        try:
            config = NumberCommandConfig(**model_args)
        except TypeError:
            # Fallback for missing config class
            print("Warning: Using fallback model configuration")
            config = type('Config', (), model_args)()
    else:
        # Old format - might be a config object
        config = model_args
    
    try:
        model = NumberCommandTransformer(config)
    except NameError:
        raise NameError("Model classes not found. Please run the model definition cell first.")
    
    # Remove potential compilation prefix
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k,v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    
    print(f"✓ Model loaded from {checkpoint_path}")
    print(f"✓ Model has {model.get_num_params()/1e6:.2f}M parameters")
    print(f"✓ Using device: {device}")
    
    return model, tokenizer, device

def test_model_completion(model, tokenizer, device, prompt: str, max_new_tokens: int = 15, temperature: float = 0.1):
    """Test model completion given a prompt"""
    
    # Encode prompt
    prompt_tokens = tokenizer.encode(prompt)
    prompt_tensor = torch.tensor(prompt_tokens, dtype=torch.long, device=device).unsqueeze(0)
    
    # Generate completion
    with torch.no_grad():
        generated = model.generate(prompt_tensor, max_new_tokens=max_new_tokens, temperature=temperature)
    
    # Decode result
    generated_tokens = generated[0].tolist()
    full_response = tokenizer.decode(generated_tokens)
    
    return full_response, generated_tokens

def calculate_expected_result(input_str: str) -> str:
    """Calculate what the expected output should be"""
    
    # Parse the input
    parts = input_str.strip().split()
    if len(parts) < 3 or parts[0] != '<sos>':
        return "Invalid format"
    
    numbers_str = parts[1]
    command = parts[2]
    
    # Convert string to list of digits
    try:
        digits = [int(d) for d in numbers_str]
    except ValueError:
        return "Invalid numbers"
    
    if command == '<sum>':
        result = str(sum(digits))
    elif command == '<reverse>':
        result = numbers_str[::-1]
    elif command == '<ascending>':
        result = ''.join(map(str, sorted(digits)))
    elif command == '<descending>':
        result = ''.join(map(str, sorted(digits, reverse=True)))
    elif command == '<even>':
        # For <even>, we need to find a subsequence and repeat it
        # This is harder to predict exactly, so we'll just note it
        result = "SUBSEQUENCE_REPEATED"
    else:
        return "Unknown command"
    
    return f"<sos> {numbers_str} {command} {result} <eos>"

def run_comprehensive_test(model, tokenizer, device):
    """Run comprehensive tests on the model"""
    
    test_cases = [
        # Sum tests
        "<sos> 123 <sum>",
        "<sos> 1232 <sum>",
        "<sos> 92999 <sum>",
        
        # Reverse tests
        "<sos> 1765 <reverse>",
        "<sos> 9235 <reverse>",
        "<sos> 505 <reverse>",
        
        # Ascending sort tests
        "<sos> 5132 <ascending>",
        "<sos> 19283 <ascending>",
        "<sos> 5231 <ascending>",
        
        # Descending sort tests
        "<sos> 1562 <descending>",
        "<sos> 91735 <descending>",
        "<sos> 482 <descending>",
    ]
    
    print("=" * 80)
    print("COMPREHENSIVE MODEL TESTING")
    print("=" * 80)
    
    correct = 0
    total = 0
    
    for i, test_case in enumerate(test_cases):
        print(f"\nTest {i+1}: {test_case}")
        
        # Get model prediction
        prediction, tokens = test_model_completion(model, tokenizer, device, test_case)
        print(f"Model output: {prediction}")
        
        # Calculate expected result
        expected = calculate_expected_result(test_case)
        print(f"Expected:     {expected}")
        
        # Check if correct (for non-even commands)
        if not '<even>' in test_case:
            is_correct = prediction.strip() == expected.strip()
            print(f"Correct: {'✓' if is_correct else '✗'}")
            if is_correct:
                correct += 1
            total += 1
        else:
            print("Correct: (even command - manual check needed)")
        
        print("-" * 40)
    
    if total > 0:
        accuracy = correct / total * 100
        print(f"\nOverall Accuracy: {correct}/{total} = {accuracy:.1f}%")
    
    return correct, total

def interactive_test(model, tokenizer, device):
    """Interactive testing interface"""
    
    print("\n" + "="*60)
    print("INTERACTIVE TESTING")
    print("Enter prompts like: <sos> 1234 <sum>")
    print("Type 'quit' to exit")
    print("="*60)
    
    while True:
        try:
            prompt = input("\nEnter prompt: ").strip()
            if prompt.lower() == 'quit':
                break
            
            if not prompt:
                continue
            
            # Test the prompt
            prediction, tokens = test_model_completion(model, tokenizer, device, prompt)
            print(f"Model output: {prediction}")
            
            # Show expected if possible
            try:
                expected = calculate_expected_result(prompt)
                if "SUBSEQUENCE" not in expected and "Invalid" not in expected and "Unknown" not in expected:
                    print(f"Expected:     {expected}")
            except:
                pass
                
        except KeyboardInterrupt:
            print("\nExiting...")
            break
        except Exception as e:
            print(f"Error: {e}")



In [8]:
try:
    # Load the trained model
    model, tokenizer, device = load_trained_model()
    
    # Run comprehensive tests
    correct, total = run_comprehensive_test(model, tokenizer, device)
    
    # Start interactive testing
    interactive_test(model, tokenizer, device)
    
except FileNotFoundError as e:
    print(f"Error: Could not find model files. Make sure you've trained the model first.")
    print("Expected files:")
    print("  - out_number_commands/ckpt.pt")
    print("  - data/number_commands/meta.pkl")
except Exception as e:
    print(f"Error loading model: {e}")


number of parameters: 0.00M
✓ Model loaded from out_number_commands/ckpt.pt
✓ Model has 0.00M parameters
✓ Using device: cpu
COMPREHENSIVE MODEL TESTING

Test 1: <sos> 123 <sum>
Error loading model: can't assign a NoneType to a torch.LongTensor
