
# Exploration of a "number command" transformer.

### Depending on which command token follows a sequence of numbers, the transformer is asked to perform a different task.



The idea is to first test whether the transformer can perform each task when given near-complete demonstrations. When trained on examples of several tasks, can the different task tokens reliably signal the transformer to change its behaviour?
I also hope to explore how far the parameter count can be reduced before performance degrades completely.
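
As a rough guide for the parameter-reduction experiments, the parameter count can be estimated directly from the configuration. The sketch below assumes the nanoGPT-style layout defined later in this notebook (token embedding tied to the output head, bias terms enabled, a 4x MLP) and the 1011-token vocabulary reported by the dataset cell. The function name `count_params` is hypothetical.

```python
def count_params(vocab_size, block_size, n_layer, n_embd, bias=True):
    """Estimate the parameter count of a nanoGPT-style decoder.

    Assumes the token embedding is tied to the output head (counted once)
    and that each block holds two LayerNorms, attention, and a 4x MLP.
    """
    b = 1 if bias else 0
    wte = vocab_size * n_embd            # token embedding (shared with lm_head)
    wpe = block_size * n_embd            # learned position embedding
    layer_norm = n_embd + b * n_embd     # weight plus optional bias
    # c_attn (n_embd -> 3*n_embd) and c_proj (n_embd -> n_embd)
    attn = (n_embd * 3 * n_embd + b * 3 * n_embd) + (n_embd * n_embd + b * n_embd)
    # c_fc (n_embd -> 4*n_embd) and c_proj (4*n_embd -> n_embd)
    mlp = (n_embd * 4 * n_embd + b * 4 * n_embd) + (4 * n_embd * n_embd + b * n_embd)
    block = 2 * layer_norm + attn + mlp
    return wte + wpe + n_layer * block + layer_norm  # last term is ln_f


# Compare a few candidate widths while keeping everything else fixed.
for n_embd in (128, 64, 32):
    total = count_params(vocab_size=1011, block_size=32, n_layer=4, n_embd=n_embd)
    print(f"n_embd={n_embd}: {total:,} parameters")
```

With a vocabulary this large relative to the model width, the embedding table becomes a growing share of the total as `n_embd` shrinks, which is worth keeping in mind when comparing reduced models.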

I varied the input sequence length across functions, so some functions only see shorter sequences. This was intentional, to test what happens when a function trained on shorter inputs is prompted with a sequence that is 7 numbers long.
To challenge the transformer further, I might see what it does if the training data contains gaps in sequence length. For example, one task could only have sequences of length 1, 3, and 5, while another has lengths 2, 4, and 7.
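
A gapped-length training set could be produced with a generator that draws its input length from an explicit list. This is only a sketch of the idea using the `<reverse>` task; the function name `generate_reverse_with_lengths` and its parameters are hypothetical and not part of the notebook's generators.

```python
import random


def generate_reverse_with_lengths(allowed_lengths, max_value=99, rng=random):
    """Build one <reverse> example whose input length is drawn from allowed_lengths."""
    length = rng.choice(list(allowed_lengths))
    numbers = [rng.randint(0, max_value) for _ in range(length)]
    numbers_str = ",".join(map(str, numbers))
    reversed_str = ",".join(map(str, reversed(numbers)))
    return f"<sos> {numbers_str} <reverse> {reversed_str} <eos>"


# Training data only ever sees lengths 1, 3 and 5; length 4 is held out for testing.
train_examples = [generate_reverse_with_lengths([1, 3, 5]) for _ in range(5)]
held_out_examples = [generate_reverse_with_lengths([4]) for _ in range(2)]
for example in train_examples + held_out_examples:
    print(example)
```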


What happens when two commands are entered in succession? Will it do both?

Which commands conflict with each other?

## example sequences:

```
<sos> 55,50,14,9,10,36,56,53,47,21 <sum> 351 <eos>
<sos> 44,30,61 <even_repeat> 44,30,44,30,44,30,44,30,44,30 <eos>
<sos> 50,82,41 <odd_repeat> 41,41,41,41,41,41,41,41,41,41 <eos>
<sos> 82,77,9,57,84,5,42,86 <reverse> 86,42,5,84,57,9,77,82 <eos>
<sos> 425,616,162,221,244 <descending> 616,425,244,221,162 <eos>
<sos> 190,676,440 <ascending> 190,440,676 <eos>
```

### Functions are bounded so that no number in the training data, input or output, exceeds 999.
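
The bound matters most for `<sum>`, the only task whose output can be larger than its inputs. A quick worst-case check, using the limits of the sum generator defined later in this notebook (3 to 10 numbers, each 0 to 99), looks like this; the constant names are illustrative.

```python
MAX_TOKEN_VALUE = 999  # largest number the vocabulary can represent

# Worst case for <sum>: the longest allowed input with every number at its maximum.
max_sum_length = 10
max_sum_operand = 99
worst_case_sum = max_sum_length * max_sum_operand

assert worst_case_sum <= MAX_TOKEN_VALUE
print(worst_case_sum)  # 990
```

The other tasks only rearrange or repeat their inputs, so their outputs stay within the input range.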


The transformer model used is based on Andrej Karpathy's nanoGPT:
https://github.com/karpathy/nanoGPT/tree/master




Notebook by Justin Thomas, August 2025

[Current Status] It trains, but I'm having an issue with testing/inference. This is version 2.0 of this notebook. Version 1.0 trained more or less properly, using a number range of 1-9.


## GOALS:



Ideas for improvements:

done - increase the range of numbers to double- and triple-digit numbers.

done - add commas to separate the values.

done - wider range of sequence lengths

done - ensure there are no duplicate examples, and that the test set includes examples that do not occur in the training data.
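
A minimal sketch of how the last point can be guaranteed: deduplicate first, then split, so that no example can land in both sets. The helper name `split_unique` and the toy input are hypothetical; the notebook's own pipeline does the equivalent by collecting examples in a `set` before slicing.

```python
import random


def split_unique(examples, val_fraction=0.1, seed=0):
    """Deduplicate examples, shuffle them, and split into disjoint train/val lists."""
    unique = sorted(set(examples))          # sorted for a reproducible starting order
    random.Random(seed).shuffle(unique)     # seeded shuffle, independent of global state
    split_idx = int((1 - val_fraction) * len(unique))
    return unique[:split_idx], unique[split_idx:]


examples = ["a", "b", "b", "c", "d", "e", "f", "g", "h", "i", "j"]
train, val = split_unique(examples)
assert set(train).isdisjoint(val)           # no example appears in both splits
print(len(train), len(val))
```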

## extra task ideas

Might need to create examples systematically rather than randomly, such as by iterating through each possibility. But once I add the appropriate complexity, enumerating every combination should be infeasible.

Could have tasks where the answer should be separated by commas.



Operate on floating-point numbers, using a special token for how many decimal places to include.

## Tricky ideas for testing generalization

Train on input sequences of length 1-5, then test with a 6th input number.
Train on input sequences of length 1-5 and 7, then test on a six-number sequence.
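
One way to set up such a held-out-length test is to partition already generated example strings by how many input numbers they contain. The helpers `input_length` and `split_by_length` below are hypothetical names, and the sketch assumes the `<sos> numbers <command> answer <eos>` format shown earlier.

```python
def input_length(example):
    """Count the numbers between <sos> and the command token."""
    prompt = example.split("<sos>", 1)[1].split("<", 1)[0]
    return len([n for n in prompt.split(",") if n.strip()])


def split_by_length(examples, held_out_length):
    """Keep examples of held_out_length out of training and use them for testing."""
    train = [e for e in examples if input_length(e) != held_out_length]
    test = [e for e in examples if input_length(e) == held_out_length]
    return train, test


examples = [
    "<sos> 190,676,440 <ascending> 190,440,676 <eos>",
    "<sos> 1,2,3,4,5,6 <reverse> 6,5,4,3,2,1 <eos>",
    "<sos> 82,77,9,57,84,5,42,86 <reverse> 86,42,5,84,57,9,77,82 <eos>",
]
train, test = split_by_length(examples, held_out_length=6)
print(len(train), "training examples,", len(test), "held-out examples")
```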


# vocab expansion - up to 999 for outputs

## Todo

Fix the testing script.

Ensure the EOS token is handled correctly by the tokenizer

Add table of




GPT suggestion:

Got it 👍 Thanks for pasting everything — I see exactly what’s happening.

Your error:

```
Error during testing: can't assign a NoneType to a torch.cuda.LongTensor
```

comes from this line in your `generate` method:

```python
idx_next[finished.unsqueeze(1)] = eos_token_id
```

---

### Why it happens

* You only pass `eos_token_id=None` into `generate` (since you didn’t specify it when calling).
* Then in the code, it still tries to do

  ```python
  idx_next[finished.unsqueeze(1)] = eos_token_id
  ```

  even when `eos_token_id` is `None`.
* PyTorch won’t let you assign `None` into a tensor, hence the crash.

---

### Fix

We just need to make sure we **only assign EOS tokens when one is actually provided**. Modify your `generate` method like this:

```python
@torch.no_grad()
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, eos_token_id=None):
    batch_size = idx.size(0)
    finished = torch.zeros(batch_size, dtype=torch.bool, device=idx.device)

    for _ in range(max_new_tokens):
        idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
        logits, _ = self(idx_cond)
        logits = logits[:, -1, :] / temperature

        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')

        probs = F.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)

        if eos_token_id is not None:
            # Sequences that already finished keep emitting EOS as padding
            idx_next[finished.unsqueeze(1)] = eos_token_id
            # Update finished mask with sequences that just produced EOS
            finished = finished | (idx_next.squeeze(1) == eos_token_id)

        # Append token
        idx = torch.cat((idx, idx_next), dim=1)

        # If all sequences finished, break early
        if eos_token_id is not None and finished.all():
            break

    return idx
```

---

### Extra note

Since your dataset clearly has `<eos>`, you should probably **tell the model what token ID that is**.
Right now you’re not giving `eos_token_id` when calling:

```python
generated = model.generate(prompt_tensor, max_new_tokens=max_new_tokens, temperature=temperature)
```

Change it to:

```python
eos_id = tokenizer.token2id.get("<eos>", None)
generated = model.generate(prompt_tensor, max_new_tokens=max_new_tokens, temperature=temperature, eos_token_id=eos_id)
```

That way the model stops at `<eos>` properly.

---

✅ So the **two-part fix** is:

1. Patch `generate` to only assign `eos_token_id` if it’s not `None`.
2. Pass the actual `<eos>` ID from your tokenizer when you call `generate`.
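
Even with both fixes, a batch can contain rows that finished early and were padded with extra `<eos>` tokens. A small helper, sketched here under the hypothetical name `trim_at_eos`, cuts a generated ID list at the first `<eos>` before decoding. The example assumes `<eos>` has ID 1, which follows from the token order in the `CustomTokenizer` defined later in this notebook.

```python
def trim_at_eos(token_ids, eos_id):
    """Return token_ids up to and including the first eos_id (all of them if absent)."""
    trimmed = []
    for token_id in token_ids:
        trimmed.append(token_id)
        if token_id == eos_id:
            break
    return trimmed


# 0 = <sos>, 1 = <eos>; everything after the first <eos> is padding.
print(trim_at_eos([0, 5, 1, 1, 1], eos_id=1))  # [0, 5, 1]
```

Applied to each row of `generated.tolist()` before `tokenizer.decode`, this keeps padding out of the comparison against the expected answer.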

---



In [1]:
import random

#Revised Version

def generate_random_numbers_and_sum():
    """
    Generate 3–10 random numbers, each between 0 and 99, and output:
    <sos> (numbers separated by commas) <sum> total <eos>
    """
    # Pick a random length between 3 and 10
    length = random.randint(3, 10)

    # Generate that many random numbers, each between 0 and 99
    numbers = [random.randint(0, 99) for _ in range(length)]

    # Calculate the sum of the numbers
    total_sum = sum(numbers)

    # Join the numbers into a single string separated by commas
    numbers_str = ','.join(map(str, numbers))

    # Format the output
    output = f"<sos> {numbers_str} <sum> {total_sum} <eos>"
    return output

for _ in range(25):
    print(generate_random_numbers_and_sum())


<sos> 30,65,70,79,10,29,3,67,37 <sum> 390 <eos>
<sos> 75,94,21,11,31,53,43,96,34 <sum> 458 <eos>
<sos> 22,20,28,91,56,96,52,14 <sum> 379 <eos>
<sos> 47,70,73 <sum> 190 <eos>
<sos> 59,37,65 <sum> 161 <eos>
<sos> 35,75,61,37,16,85,67,22,65 <sum> 463 <eos>
<sos> 26,13,1,10,56,49,21,0 <sum> 176 <eos>
<sos> 58,15,76,73 <sum> 222 <eos>
<sos> 64,88,59,82,37,10,52 <sum> 392 <eos>
<sos> 31,33,86,76,98,58,53,71 <sum> 506 <eos>
<sos> 7,38,4,67,33,33,46,49,39 <sum> 316 <eos>
<sos> 78,80,96,16,30 <sum> 300 <eos>
<sos> 9,21,45 <sum> 75 <eos>
<sos> 66,14,88,63,63 <sum> 294 <eos>
<sos> 27,43,43,49,43,70,51,22,83,24 <sum> 455 <eos>
<sos> 86,5,72,57,68,78,98 <sum> 464 <eos>
<sos> 55,8,60,72,68,75,7,36,32,63 <sum> 476 <eos>
<sos> 22,55,61 <sum> 138 <eos>
<sos> 35,49,11,88,79,55,68,94,80 <sum> 559 <eos>
<sos> 44,42,90,63,43,72,54,97,26 <sum> 531 <eos>
<sos> 77,93,32,55,87,27 <sum> 371 <eos>
<sos> 16,62,38,14 <sum> 130 <eos>
<sos> 0,80,7,66,40,83 <sum> 276 <eos>
<sos> 67,76,26,1,73,54 <sum> 297 <eos>
<sos>

# Functions defined

generate_random_numbers_and_sum()

generate_even_repeat_sequence()

generate_odd_repeat_sequence()

generate_random_numbers_and_reverse()

generate_ascending_sort()

generate_descending_sort()
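
Each of the six generators listed above pairs a command token with a deterministic rule, so a single reference solver can compute the expected answer for any prompt. This could help when scoring model output in the testing script. The names `solve` and `format_example` are hypothetical; the rules mirror the generator cells in this notebook.

```python
def solve(numbers, command):
    """Return the expected output numbers for one command token."""
    if command == "<sum>":
        return [sum(numbers)]
    if command == "<reverse>":
        return numbers[::-1]
    if command == "<ascending>":
        return sorted(numbers)
    if command == "<descending>":
        return sorted(numbers, reverse=True)
    if command in ("<even_repeat>", "<odd_repeat>"):
        parity = 0 if command == "<even_repeat>" else 1
        kept = [n for n in numbers if n % 2 == parity]
        # Cycle through the kept numbers to fill 10 positions; empty if none match.
        return [kept[i % len(kept)] for i in range(10)] if kept else []
    raise ValueError(f"Unknown command: {command}")


def format_example(numbers, command):
    """Render a full training-style string for the given input and command."""
    answer = ",".join(map(str, solve(numbers, command)))
    return f"<sos> {','.join(map(str, numbers))} {command} {answer} <eos>"


print(format_example([44, 30, 61], "<even_repeat>"))
print(format_example([190, 676, 440], "<ascending>"))
```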


# Functions are bounded to produce combinations that do not exceed 999.

# examples:

```

<sos> 55,50,14,9,10,36,56,53,47,21 <sum> 351 <eos>
<sos> 44,30,61 <even_repeat> 44,30,44,30,44,30,44,30,44,30 <eos>
<sos> 50,82,41 <odd_repeat> 41,41,41,41,41,41,41,41,41,41 <eos>
<sos> 82,77,9,57,84,5,42,86 <reverse> 86,42,5,84,57,9,77,82 <eos>
<sos> 425,616,162,221,244 <descending> 616,425,244,221,162 <eos>
<sos> 190,676,440 <ascending> 190,440,676 <eos>

```

In [2]:
import random

def generate_even_repeat_sequence():
    """
    Generate a sequence like:
    <sos> 12,37,84,96 <even_repeat> 12,84,96,12,84,96,12,84,96,12 <eos>
    - First part is random numbers 0-99 with commas (length 3-5)
    - Second part repeats only even numbers up to 10 total numbers
    """
    # Generate 3-5 random numbers from 0-99
    numbers = [random.randint(0, 99) for _ in range(random.randint(3, 5))]
    numbers_str = ','.join(map(str, numbers))

    # Filter for even numbers
    even_numbers = [num for num in numbers if num % 2 == 0]

    # If no even numbers, return sequence with empty repeat section
    if not even_numbers:
        output = f"<sos> {numbers_str} <even_repeat>  <eos>"
        return output

    # Repeat even numbers to fill up to 10 positions
    repeated_evens = []
    for i in range(10):
        repeated_evens.append(even_numbers[i % len(even_numbers)])

    repeated_str = ','.join(map(str, repeated_evens))

    # Format the output
    output = f"<sos> {numbers_str} <even_repeat> {repeated_str} <eos>"
    return output

def generate_odd_repeat_sequence():
    """
    Generate a sequence like:
    <sos> 12,37,84,96 <odd_repeat> 37,37,37,37,37,37,37,37,37,37 <eos>
    - First part is random numbers 0-99 with commas (length 3-5)
    - Second part repeats only odd numbers up to 10 total numbers
    """
    # Generate 3-5 random numbers from 0-99
    numbers = [random.randint(0, 99) for _ in range(random.randint(3, 5))]
    numbers_str = ','.join(map(str, numbers))

    # Filter for odd numbers
    odd_numbers = [num for num in numbers if num % 2 == 1]

    # If no odd numbers, return sequence with empty repeat section
    if not odd_numbers:
        output = f"<sos> {numbers_str} <odd_repeat>  <eos>"
        return output

    # Repeat odd numbers to fill up to 10 positions
    repeated_odds = []
    for i in range(10):
        repeated_odds.append(odd_numbers[i % len(odd_numbers)])

    repeated_str = ','.join(map(str, repeated_odds))

    # Format the output
    output = f"<sos> {numbers_str} <odd_repeat> {repeated_str} <eos>"
    return output

# Example usage
print("Even repeat examples:")
for _ in range(25):
    print(generate_even_repeat_sequence())

print("\nOdd repeat examples:")
for _ in range(25):
    print(generate_odd_repeat_sequence())

Even repeat examples:
<sos> 35,55,35 <even_repeat>  <eos>
<sos> 78,76,42,93,55 <even_repeat> 78,76,42,78,76,42,78,76,42,78 <eos>
<sos> 6,35,42 <even_repeat> 6,42,6,42,6,42,6,42,6,42 <eos>
<sos> 23,77,76,75,99 <even_repeat> 76,76,76,76,76,76,76,76,76,76 <eos>
<sos> 1,3,4,1 <even_repeat> 4,4,4,4,4,4,4,4,4,4 <eos>
<sos> 27,94,47 <even_repeat> 94,94,94,94,94,94,94,94,94,94 <eos>
<sos> 62,74,20 <even_repeat> 62,74,20,62,74,20,62,74,20,62 <eos>
<sos> 15,25,40,3 <even_repeat> 40,40,40,40,40,40,40,40,40,40 <eos>
<sos> 73,93,93,24,15 <even_repeat> 24,24,24,24,24,24,24,24,24,24 <eos>
<sos> 50,40,3 <even_repeat> 50,40,50,40,50,40,50,40,50,40 <eos>
<sos> 61,26,7 <even_repeat> 26,26,26,26,26,26,26,26,26,26 <eos>
<sos> 33,66,2,26,84 <even_repeat> 66,2,26,84,66,2,26,84,66,2 <eos>
<sos> 75,89,92,69,85 <even_repeat> 92,92,92,92,92,92,92,92,92,92 <eos>
<sos> 69,85,4,81,51 <even_repeat> 4,4,4,4,4,4,4,4,4,4 <eos>
<sos> 69,78,63,49,91 <even_repeat> 78,78,78,78,78,78,78,78,78,78 <eos>
<sos> 63,73,5,85 <even

In [3]:
import random

#Revised

def generate_random_numbers_and_reverse():
    """
    Generate a sequence like:
    <sos> 12,37,84,96 <reverse> 96,84,37,12 <eos>
    - First part is random numbers 0-99 with commas (length 3-8)
    - Second part reverses the order of the numbers
    """
    # Pick a random length between 3 and 8
    length = random.randint(3, 8)

    # Generate that many random numbers from 0-99
    numbers = [random.randint(0, 99) for _ in range(length)]

    # Join the numbers into a comma-separated string
    numbers_str = ','.join(map(str, numbers))

    # Reverse the list and join
    reversed_str = ','.join(map(str, numbers[::-1]))

    # Format the output
    output = f"<sos> {numbers_str} <reverse> {reversed_str} <eos>"
    return output

# Example usage
for _ in range(25):
    print(generate_random_numbers_and_reverse())

<sos> 87,96,29,89,58 <reverse> 58,89,29,96,87 <eos>
<sos> 73,65,20,55,45,39,20 <reverse> 20,39,45,55,20,65,73 <eos>
<sos> 76,52,87,11 <reverse> 11,87,52,76 <eos>
<sos> 96,23,36,10,93 <reverse> 93,10,36,23,96 <eos>
<sos> 60,67,94,5,16,35,55,3 <reverse> 3,55,35,16,5,94,67,60 <eos>
<sos> 68,65,65,13,35,46,78,94 <reverse> 94,78,46,35,13,65,65,68 <eos>
<sos> 49,44,14 <reverse> 14,44,49 <eos>
<sos> 80,77,4 <reverse> 4,77,80 <eos>
<sos> 57,4,3,17,10,0,33 <reverse> 33,0,10,17,3,4,57 <eos>
<sos> 14,59,45 <reverse> 45,59,14 <eos>
<sos> 22,19,81 <reverse> 81,19,22 <eos>
<sos> 47,49,10,76,21,94 <reverse> 94,21,76,10,49,47 <eos>
<sos> 58,51,88,2,32,6,25,28 <reverse> 28,25,6,32,2,88,51,58 <eos>
<sos> 76,44,27 <reverse> 27,44,76 <eos>
<sos> 35,34,8,25,60 <reverse> 60,25,8,34,35 <eos>
<sos> 40,44,33,54,38,83 <reverse> 83,38,54,33,44,40 <eos>
<sos> 63,13,71,59,51,50 <reverse> 50,51,59,71,13,63 <eos>
<sos> 35,36,9,35,85,9,74 <reverse> 74,9,85,35,9,36,35 <eos>
<sos> 21,48,12,48 <reverse> 48,12,48,21 <eos

In [4]:

import random

def generate_ascending_sort():
    """
    Generate 3–5 random numbers (0-999) and output:
    <sos> numbers <ascending> sorted_numbers <eos>
    """
    # Pick a random length between 3 and 5
    length = random.randint(3, 5)

    # Generate that many random numbers from 0-999
    numbers = [random.randint(0, 999) for _ in range(length)]
    numbers_str = ','.join(map(str, numbers))

    # Sort ascending
    sorted_numbers = sorted(numbers)
    sorted_str = ','.join(map(str, sorted_numbers))

    # Format the output
    output = f"<sos> {numbers_str} <ascending> {sorted_str} <eos>"
    return output

def generate_descending_sort():
    """
    Generate 3–5 random numbers (0-999) and output:
    <sos> numbers <descending> sorted_numbers <eos>
    """
    # Pick a random length between 3 and 5
    length = random.randint(3, 5)

    # Generate that many random numbers from 0-999
    numbers = [random.randint(0, 999) for _ in range(length)]
    numbers_str = ','.join(map(str, numbers))

    # Sort descending
    sorted_numbers = sorted(numbers, reverse=True)
    sorted_str = ','.join(map(str, sorted_numbers))

    # Format the output
    output = f"<sos> {numbers_str} <descending> {sorted_str} <eos>"
    return output

# Example usage
for _ in range(5):
    print(generate_ascending_sort())

for _ in range(5):
    print(generate_descending_sort())

<sos> 134,315,544 <ascending> 134,315,544 <eos>
<sos> 530,280,513,57,846 <ascending> 57,280,513,530,846 <eos>
<sos> 4,86,828,836 <ascending> 4,86,828,836 <eos>
<sos> 117,824,416 <ascending> 117,416,824 <eos>
<sos> 330,236,666 <ascending> 236,330,666 <eos>
<sos> 560,166,450,169 <descending> 560,450,169,166 <eos>
<sos> 404,63,447,397,774 <descending> 774,447,404,397,63 <eos>
<sos> 678,756,404,624 <descending> 756,678,624,404 <eos>
<sos> 517,548,392 <descending> 548,517,392 <eos>
<sos> 642,468,211,753 <descending> 753,642,468,211 <eos>


In [5]:
import random
import pickle
import numpy as np
from typing import List, Tuple, Set
import os

class CustomTokenizer:
    def __init__(self):
        # Define vocabulary
        self.tokens = [
            "<sos>", "<eos>", "<sum>", "<reverse>", "<ascending>", "<descending>",
            "<even>", "<odd>", "<even_repeat>", "<odd_repeat>", ","
        ] + [str(i) for i in range(1000)]  # 0-999

        # Token ↔ ID mappings
        self.token2id = {tok: idx for idx, tok in enumerate(self.tokens)}
        self.id2token = {idx: tok for tok, idx in self.token2id.items()}
        self.vocab_size = len(self.tokens)

    def encode(self, text: str) -> List[int]:
        """
        Convert string to list of token IDs.
        - Matches tags, numbers (0-999), and commas
        - Ignores spaces
        """
        tokens = []
        i = 0
        while i < len(text):
            if text[i] == " ":
                i += 1
                continue

            if text[i] == "<":  # special token
                j = text.find(">", i)
                if j != -1:
                    tok = text[i:j+1]
                    if tok in self.token2id:
                        tokens.append(self.token2id[tok])
                        i = j + 1
                        continue

            if text[i] == ",":  # comma
                tokens.append(self.token2id[","])
                i += 1
                continue

            if text[i].isdigit():  # start of number
                # Extract full number (could be 1-3 digits)
                j = i
                while j < len(text) and text[j].isdigit():
                    j += 1
                number_str = text[i:j]
                number = int(number_str)

                if number <= 999:
                    tokens.append(self.token2id[str(number)])
                    i = j
                else:
                    raise ValueError(f"Number out of range: {number}")
            else:
                raise ValueError(f"Unexpected character: {text[i]}")

        return tokens

    def decode(self, token_ids: List[int]) -> str:
        """Convert list of IDs back to string with proper comma formatting."""
        tokens = [self.id2token[i] for i in token_ids]
        result = ""

        for i, token in enumerate(tokens):
            if token.startswith("<") and token.endswith(">"):
                # Add special token with spaces
                if result and not result.endswith(" "):
                    result += " "
                result += token
                if i < len(tokens) - 1:  # not last token
                    result += " "
            elif token == ",":
                result += ","
            else:
                # Number token
                result += token

        return result

# Note: These functions should be imported from your existing modules
# generate_random_numbers_and_sum()
# generate_even_repeat_sequence()
# generate_odd_repeat_sequence()
# generate_random_numbers_and_reverse()
# generate_ascending_sort()
# generate_descending_sort()

def generate_dataset(num_examples: int = 50000) -> List[str]:
    """Generate a balanced dataset with all command types, ensuring no duplicates"""
    examples = set()  # Use set to automatically handle duplicates
    generators = [
        generate_random_numbers_and_sum, generate_random_numbers_and_reverse,
        generate_ascending_sort, generate_descending_sort,
        generate_even_repeat_sequence, generate_odd_repeat_sequence
    ]

    examples_per_type = num_examples // len(generators)

    # Generate examples for each type
    for generator in generators:
        attempts = 0
        generated_for_this_type = 0

        while generated_for_this_type < examples_per_type and attempts < examples_per_type * 3:
            example = generator()
            if example not in examples:
                examples.add(example)
                generated_for_this_type += 1
            attempts += 1

    # Fill remaining examples, bounding the number of attempts so the loop
    # terminates even if unique examples become hard to find
    attempts = 0
    max_attempts = num_examples * 3
    while len(examples) < num_examples and attempts < max_attempts:
        generator = random.choice(generators)
        examples.add(generator())
        attempts += 1

    # Convert back to list and shuffle
    examples_list = list(examples)
    random.shuffle(examples_list)

    print(f"Generated {len(examples_list)} unique examples (requested {num_examples})")
    return examples_list

def create_training_data(examples: List[str], tokenizer: CustomTokenizer) -> Tuple[np.ndarray, dict]:
    """Convert examples to tokenized training data"""
    tokenized_examples = []
    max_length = 0

    # Tokenize all examples and find max length
    for example in examples:
        tokens = tokenizer.encode(example)
        tokenized_examples.append(tokens)
        max_length = max(max_length, len(tokens))

    print(f"Maximum sequence length: {max_length}")
    print(f"Vocabulary size: {tokenizer.vocab_size}")

    # Convert to numpy array (pad sequences to max length)
    data = np.full((len(examples), max_length), tokenizer.token2id["<eos>"], dtype=np.int64)

    for i, tokens in enumerate(tokenized_examples):
        data[i, :len(tokens)] = tokens

    # Create metadata
    meta = {
        'vocab_size': tokenizer.vocab_size,
        'max_length': max_length,
        'num_examples': len(examples),
        'token2id': tokenizer.token2id,
        'id2token': tokenizer.id2token
    }

    return data, meta

def test_tokenizer():
    """Test the tokenizer with sample examples"""
    print("Testing Tokenizer...")
    print("=" * 50)

    tokenizer = CustomTokenizer()

    # Test with each type of example
    test_examples = [
        generate_random_numbers_and_sum(),
        generate_random_numbers_and_reverse(),
        generate_ascending_sort(),
        generate_descending_sort(),
        generate_even_repeat_sequence(),
        generate_odd_repeat_sequence()
    ]

    for i, example in enumerate(test_examples):
        print(f"\nTest {i+1}:")
        print(f"Original: {example}")

        # Encode
        encoded = tokenizer.encode(example)
        print(f"Encoded:  {encoded}")

        # Decode
        decoded = tokenizer.decode(encoded)
        print(f"Decoded:  {decoded}")

        # Check if round-trip works
        success = example.replace(" ", "") == decoded.replace(" ", "")
        print(f"Round-trip successful: {success}")

def save_dataset(train_data: np.ndarray, val_data: np.ndarray, meta: dict, data_dir: str = "data/number_commands"):
    """Save dataset to files"""
    os.makedirs(data_dir, exist_ok=True)

    # Save binary data
    train_data.astype(np.uint16).tofile(os.path.join(data_dir, 'train.bin'))
    val_data.astype(np.uint16).tofile(os.path.join(data_dir, 'val.bin'))

    # Save metadata
    with open(os.path.join(data_dir, 'meta.pkl'), 'wb') as f:
        pickle.dump(meta, f)

    print(f"Dataset saved to {data_dir}")
    print(f"Train examples: {len(train_data)}")
    print(f"Val examples: {len(val_data)}")



In [6]:
# Test the tokenizer
test_tokenizer()



Testing Tokenizer...

Test 1:
Original: <sos> 88,30,71,58,30,2,91 <sum> 370 <eos>
Encoded:  [0, 99, 10, 41, 10, 82, 10, 69, 10, 41, 10, 13, 10, 102, 2, 381, 1]
Decoded:  <sos> 88,30,71,58,30,2,91 <sum> 370 <eos>
Round-trip successful: True

Test 2:
Original: <sos> 31,46,61,15 <reverse> 15,61,46,31 <eos>
Encoded:  [0, 42, 10, 57, 10, 72, 10, 26, 3, 26, 10, 72, 10, 57, 10, 42, 1]
Decoded:  <sos> 31,46,61,15 <reverse> 15,61,46,31 <eos>
Round-trip successful: True

Test 3:
Original: <sos> 249,982,420 <ascending> 249,420,982 <eos>
Encoded:  [0, 260, 10, 993, 10, 431, 4, 260, 10, 431, 10, 993, 1]
Decoded:  <sos> 249,982,420 <ascending> 249,420,982 <eos>
Round-trip successful: True

Test 4:
Original: <sos> 806,553,285,497,136 <descending> 806,553,497,285,136 <eos>
Encoded:  [0, 817, 10, 564, 10, 296, 10, 508, 10, 147, 5, 817, 10, 564, 10, 508, 10, 296, 10, 147, 1]
Decoded:  <sos> 806,553,285,497,136 <descending> 806,553,497,285,136 <eos>
Round-trip successful: True

Test 5:
Original: <sos> 29
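
The encoded IDs in the output above follow a simple pattern: the 11 special tokens take IDs 0 to 10, so a number `n` maps to ID `11 + n`. The sketch below reproduces that mapping on its own, with hypothetical helper names, as a quick way to sanity-check encoded sequences by hand.

```python
SPECIAL_TOKENS = [
    "<sos>", "<eos>", "<sum>", "<reverse>", "<ascending>", "<descending>",
    "<even>", "<odd>", "<even_repeat>", "<odd_repeat>", ",",
]


def number_to_id(n):
    """Map a number 0-999 to its token ID (special tokens come first)."""
    if not 0 <= n <= 999:
        raise ValueError(f"Number out of range: {n}")
    return len(SPECIAL_TOKENS) + n


def id_to_number(token_id):
    """Inverse of number_to_id; rejects IDs that belong to special tokens."""
    n = token_id - len(SPECIAL_TOKENS)
    if not 0 <= n <= 999:
        raise ValueError(f"Not a number token: {token_id}")
    return n


print(number_to_id(88), number_to_id(370))  # 99 381, as in Test 1 above
print(len(SPECIAL_TOKENS) + 1000)           # 1011, the reported vocabulary size
```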

In [7]:
# Generate and process dataset
print("\n" + "="*50)
print("Generating Dataset...")

examples = generate_dataset(600000)
tokenizer = CustomTokenizer()

# Split into train/val
split_idx = int(0.9 * len(examples))
train_examples = examples[:split_idx]
val_examples = examples[split_idx:]

# Create training data
train_data, meta = create_training_data(train_examples, tokenizer)
val_data, _ = create_training_data(val_examples, tokenizer)

# Save dataset
save_dataset(train_data, val_data, meta)


Generating Dataset...
Generated 600000 unique examples (requested 600000)
Maximum sequence length: 33
Vocabulary size: 1011
Maximum sequence length: 33
Vocabulary size: 1011
Dataset saved to data/number_commands
Train examples: 540000
Val examples: 60000


In [8]:
# Show some statistics
print("\nDataset Statistics:")
print(f"Vocabulary size: {meta['vocab_size']}")
print(f"Max sequence length: {meta['max_length']}")
print(f"Train data shape: {train_data.shape}")
print(f"Val data shape: {val_data.shape}")

# Show sample training examples
print("\nSample training examples:")
for i in range(5):
    tokens = train_data[i]
    # Remove padding
    tokens = tokens[tokens != tokenizer.token2id["<eos>"]]
    tokens = np.append(tokens, tokenizer.token2id["<eos>"])  # Add back one EOS
    decoded = tokenizer.decode(tokens.tolist())
    print(f"{i+1}: {decoded}")



Dataset Statistics:
Vocabulary size: 1011
Max sequence length: 33
Train data shape: (540000, 33)
Val data shape: (60000, 33)

Sample training examples:
1: <sos> 89,76,32 <odd_repeat> 89,89,89,89,89,89,89,89,89,89 <eos>
2: <sos> 103,271,900,565,876 <ascending> 103,271,565,876,900 <eos>
3: <sos> 819,455,196,254 <ascending> 196,254,455,819 <eos>
4: <sos> 771,320,964,435,932 <descending> 964,932,771,435,320 <eos>
5: <sos> 21,138,911,845,783 <descending> 911,845,783,138,21 <eos>


In [9]:
import math
import inspect
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F

class LayerNorm(nn.Module):
    """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)

class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y

class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu    = nn.GELU()
        self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

class Block(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

@dataclass
class NumberCommandConfig:
    block_size: int = 32  
    vocab_size: int = 17  # Will be set from dataset
    n_layer: int = 4      
    n_head: int = 4      
    n_embd: int = 128    
    dropout: float = 0.1  # Some dropout for regularization
    bias: bool = True     # Keep bias terms

class NumberCommandTransformer(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying - share embeddings with output layer
        self.transformer.wte.weight = self.lm_head.weight

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        # report number of parameters
        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
            loss = None

        return logits, loss

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameter that is at least 2D will be weight decayed, otherwise not.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")

        return optimizer

    def estimate_mfu(self, fwdbwd_per_iter, dt):
        """Estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS"""
        # First estimate the number of flops we do per iteration.
        # See PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
        N = self.get_num_params()
        cfg = self.config
        L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
        flops_per_token = 6*N + 12*L*H*Q*T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        # Express our flops throughput as ratio of A100 bfloat16 peak flops
        flops_achieved = flops_per_iter * (1.0/dt)  # per second
        flops_promised = 312e12  # A100 GPU bfloat16 peak flops is 312 TFLOPS
        mfu = flops_achieved / flops_promised
        return mfu

    @torch.no_grad()
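    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, eos_token_id=None):
        """
        Sample up to max_new_tokens tokens autoregressively, starting from idx of shape (b, t).
        temperature must be > 0; small values approach greedy decoding.
        If eos_token_id is given, sequences that have already emitted it are padded with EOS,
        and generation stops early once every sequence in the batch has finished.
        """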
        batch_size = idx.size(0)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=idx.device)
    
        for _ in range(max_new_tokens):
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
    
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
    
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
    
            # Handle EOS logic BEFORE appending to sequence
            if eos_token_id is not None:
                # For sequences that were already finished, force them to generate EOS tokens
                idx_next[finished.unsqueeze(1)] = eos_token_id
    
            # Now append the (possibly modified) tokens
            idx = torch.cat((idx, idx_next), dim=1)
    
            # Update finished mask and check for early termination
            if eos_token_id is not None:
                # Update finished mask based on what we just added
                finished = finished | (idx_next.squeeze(1) == eos_token_id)
                # If all sequences finished, break early
                if finished.all():
                    break
    
        return idx
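
As a cross-check on `get_num_params` and on the decay / no-decay split in `configure_optimizers`, the parameter count can be worked out by hand. This is a minimal sketch, assuming `Block` follows the standard nanoGPT layout (two LayerNorms, a fused QKV projection `c_attn`, an attention `c_proj`, and a 4x MLP). `gpt_param_counts` is a hypothetical helper, not part of the notebook, and the values passed in are the ones used by the training cell further down.

```python
def gpt_param_counts(vocab_size, block_size, n_layer, n_embd, bias=True):
    """Parameter counts for a nanoGPT-style model with tied wte / lm_head weights."""
    d = n_embd
    b = 1 if bias else 0
    wte = vocab_size * d                 # token embedding, shared with lm_head
    wpe = block_size * d                 # position embedding
    layer_norm = d + b * d               # LayerNorm weight (+ optional bias)
    # 2D weight matrices per block (assumed layout):
    # c_attn (d x 3d), attention c_proj (d x d), mlp c_fc (d x 4d), mlp c_proj (4d x d)
    block_weights = 3 * d * d + d * d + 4 * d * d + 4 * d * d
    # matching bias vectors, present only when bias=True
    block_biases = b * (3 * d + d + 4 * d + d)
    per_block = block_weights + block_biases + 2 * layer_norm
    total = wte + wpe + n_layer * per_block + layer_norm   # last term is ln_f
    decayed = wte + wpe + n_layer * block_weights           # tensors with dim >= 2
    return {
        "total": total,
        "non_embedding": total - wpe,    # what get_num_params() reports by default
        "per_block": per_block,
        "decayed": decayed,
        "non_decayed": total - decayed,
    }


counts = gpt_param_counts(vocab_size=1011, block_size=32, n_layer=4, n_embd=128)
for name, value in counts.items():
    print(f"{name:>14}: {value:,}")
```

With these settings the decayed and non-decayed counts come to 919,936 and 6,912, the same figures the training run prints, and the non-embedding total of 922,752 is the 0.92M reported by `get_num_params`. Almost all of it sits in the four blocks, so `n_layer` and `n_embd` are the obvious knobs when shrinking the model.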

In [10]:
import torch
print(torch.__version__)


2.7.1+cu118


In [None]:
# ALL-IN-ONE TRAINING SCRIPT FOR NUMBER COMMAND TRANSFORMER
#

import os
import time
import math
import pickle
import platform
from contextlib import nullcontext
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("CUDA device:", torch.cuda.get_device_name())

# Check if we have the required data
data_dir = 'data/number_commands'
if not os.path.exists(os.path.join(data_dir, 'train.bin')):
    raise FileNotFoundError(f"Training data not found at {data_dir}. Please run dataset generation first.")

print("✓ Training data found")

# Load metadata
meta_path = os.path.join(data_dir, 'meta.pkl')
with open(meta_path, 'rb') as f:
    meta = pickle.load(f)

vocab_size = meta['vocab_size']
max_length = meta['max_length']

print(f"✓ Vocab size: {vocab_size}")
print(f"✓ Max sequence length: {max_length}")

# Configuration
config = {
    'vocab_size': vocab_size,
    'block_size': min(max_length, 32),  # Use actual max length or 32, whichever is smaller
    'n_layer': 4,
    'n_head': 4,
    'n_embd': 128,  # worked with 128; could try smaller given our large 1011-token vocab
    'dropout': 0.1,
    'bias': True,
    'batch_size': 256,
    'learning_rate': 3e-4,
    'max_iters': 50000,
    'eval_interval': 500,
    'log_interval': 50,
    'warmup_iters': 200,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'compile': torch.cuda.is_available() and platform.system() != 'Windows',
    'out_dir': 'out_number_commands'
}

print(f"✓ Using device: {config['device']}")
print(f"✓ Model compilation: {config['compile']}")

# Set up device and dtype
device = config['device']
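# bfloat16 where the GPU supports it, float16 on other GPUs, float32 on CPU (no autocast there)
if not torch.cuda.is_available():
    dtype = 'float32'
else:
    dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float16'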
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device == 'cpu' else torch.amp.autocast(device_type=device, dtype=ptdtype)

print(f"✓ Using dtype: {dtype}")

# Create output directory
os.makedirs(config['out_dir'], exist_ok=True)

# Set seeds
torch.manual_seed(1337)
if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# Data loading function
def get_batch(split, batch_size, block_size):
    if split == 'train':
        data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
    else:
        data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')

    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])

    if device == 'cuda':
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y

# Test data loading
print("✓ Testing data loading...")
test_x, test_y = get_batch('train', 2, config['block_size'])
print(f"✓ Batch shape: {test_x.shape}, {test_y.shape}")

# Create and initialize model (model classes should be defined already)
try:
    model_config = NumberCommandConfig(
        vocab_size=config['vocab_size'],
        block_size=config['block_size'],
        n_layer=config['n_layer'],
        n_head=config['n_head'],
        n_embd=config['n_embd'],
        dropout=config['dropout'],
        bias=config['bias']
    )

    model = NumberCommandTransformer(model_config)
    model.to(device)

    print(f"✓ Model created with {model.get_num_params()/1e6:.2f}M parameters")

except NameError:
    print("❌ Model classes not found! Please run the model definition cell first.")
    raise

# Initialize training components
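# torch.cuda.amp.GradScaler is deprecated in recent PyTorch; loss scaling is only needed for float16 on GPU
scaler = torch.amp.GradScaler('cuda', enabled=(device == 'cuda' and dtype == 'float16'))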
optimizer = model.configure_optimizers(
    weight_decay=1e-2,
    learning_rate=config['learning_rate'],
    betas=(0.9, 0.95),
    device_type=device
)

# Compile model with error handling
if config['compile']:
    try:
        print("Compiling model...")
        unoptimized_model = model
        model = torch.compile(model)
        print("✓ Model compilation successful!")
    except Exception as e:
        print(f"⚠ Model compilation failed: {e}")
        print("Continuing without compilation...")

# Loss estimation function
@torch.no_grad()
def estimate_loss(eval_iters=100):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, config['batch_size'], config['block_size'])
            with ctx:
                logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

# Learning rate scheduler
def get_lr(it):
    if it < config['warmup_iters']:
        return config['learning_rate'] * (it + 1) / (config['warmup_iters'] + 1)
    if it > config['max_iters']:
        return config['learning_rate'] * 0.1
    decay_ratio = (it - config['warmup_iters']) / (config['max_iters'] - config['warmup_iters'])
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return config['learning_rate'] * 0.1 + coeff * (config['learning_rate'] - config['learning_rate'] * 0.1)

# Training loop
print("\n" + "="*60)
print("STARTING TRAINING")
print("="*60)

iter_num = 0
best_val_loss = float('inf')
t0 = time.time()

# Initial evaluation
losses = estimate_loss()
print(f"Initial: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

while iter_num < config['max_iters']:
    # Set learning rate
    lr = get_lr(iter_num)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # Evaluate and save checkpoints
    if iter_num % config['eval_interval'] == 0:
        losses = estimate_loss()
        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, lr {lr:.2e}")

        if losses['val'] < best_val_loss:
            best_val_loss = losses['val']
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'model_args': config,
                'iter_num': iter_num,
                'best_val_loss': best_val_loss,
            }
            torch.save(checkpoint, os.path.join(config['out_dir'], 'ckpt.pt'))
            print(f"✓ Saved checkpoint (val_loss: {best_val_loss:.4f})")

    # Training step
    X, Y = get_batch('train', config['batch_size'], config['block_size'])

    with ctx:
        logits, loss = model(X, Y)

    # Backward pass
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)

    # Logging
    if iter_num % config['log_interval'] == 0:
        t1 = time.time()
        dt = t1 - t0
        t0 = t1
        print(f"iter {iter_num}: loss {loss.item():.4f}, time {dt*1000:.1f}ms")

    iter_num += 1

# Final evaluation
losses = estimate_loss()
print(f"\nFinal: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
print(f"Training completed! Best val loss: {best_val_loss:.4f}")
print(f"Model saved to: {config['out_dir']}/ckpt.pt")

PyTorch version: 2.7.1+cu118
CUDA available: True
CUDA device: NVIDIA GeForce RTX 3060
✓ Training data found
✓ Vocab size: 1011
✓ Max sequence length: 33
✓ Using device: cuda
✓ Model compilation: False
✓ Using dtype: bfloat16
✓ Testing data loading...
✓ Batch shape: torch.Size([2, 32]), torch.Size([2, 32])
number of parameters: 0.92M
✓ Model created with 0.92M parameters
num decayed parameter tensors: 18, with 919,936 parameters
num non-decayed parameter tensors: 34, with 6,912 parameters


  scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))


using fused AdamW: True

STARTING TRAINING
Initial: train loss 6.7178, val loss 6.7198
step 0: train loss 6.7162, val loss 6.7191, lr 1.49e-06
✓ Saved checkpoint (val_loss: 6.7191)
iter 0: loss 6.7268, time 11528.5ms
iter 50: loss 5.3462, time 2125.5ms
iter 100: loss 4.4233, time 2123.1ms
iter 150: loss 3.1653, time 2109.8ms
iter 200: loss 2.3094, time 1889.8ms
iter 250: loss 2.0802, time 1953.2ms
iter 300: loss 1.9363, time 1855.9ms
iter 350: loss 1.8345, time 1912.5ms
iter 400: loss 1.9217, time 1918.9ms
iter 450: loss 1.8650, time 1910.8ms
step 500: train loss 1.8212, val loss 1.8244, lr 3.00e-04
✓ Saved checkpoint (val_loss: 1.8244)
iter 500: loss 1.7950, time 7557.7ms
iter 550: loss 1.8098, time 2017.1ms
iter 600: loss 1.7800, time 1772.0ms
iter 650: loss 1.7448, time 1819.2ms
iter 700: loss 1.7306, time 1953.4ms
iter 750: loss 1.6593, time 1953.8ms
iter 800: loss 1.6521, time 1976.7ms
iter 850: loss 1.6071, time 1863.8ms
iter 900: loss 1.6501, time 1999.6ms
iter 950: loss 1.5975,
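
The learning-rate values in the log above follow from the warmup-plus-cosine schedule in `get_lr`. Here is a standalone sketch of the same formula so it can be inspected without starting a training run. `lr_at` is a hypothetical name, and its defaults are copied from the `config` dict in the training cell.

```python
import math


def lr_at(it, base_lr=3e-4, warmup_iters=200, max_iters=50000, min_ratio=0.1):
    """Linear warmup followed by cosine decay down to min_ratio * base_lr."""
    min_lr = base_lr * min_ratio
    if it < warmup_iters:
        # linear ramp; the +1 terms keep the very first step above zero
        return base_lr * (it + 1) / (warmup_iters + 1)
    if it > max_iters:
        return min_lr
    decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes from 1 to 0
    return min_lr + coeff * (base_lr - min_lr)


for step in (0, 200, 500, 25100, 50000):
    print(f"step {step:>6}: lr {lr_at(step):.2e}")
```

Steps 0 and 500 reproduce the `1.49e-06` and `3.00e-04` shown in the log. The rate reaches its peak at step 200, sits halfway between peak and floor at step 25100, and ends at one tenth of the peak.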

# Now for Some Inference!

In [None]:
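import pickle
from typing import List  # needed by the type hints in CustomTokenizer below

import torch


def load_model_and_tokenizer():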
    """Load the trained model and recreate tokenizer from metadata"""
    
    # Load metadata
    meta_path = 'data/number_commands/meta.pkl'
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    
    # Recreate tokenizer
    class CustomTokenizer:
        def __init__(self, token2id, id2token):
            self.token2id = token2id
            self.id2token = id2token
            self.vocab_size = len(token2id)

        def encode(self, text: str) -> List[int]:
            tokens = []
            i = 0
            while i < len(text):
                if text[i] == " ":
                    i += 1
                    continue

                if text[i] == "<":
                    j = text.find(">", i)
                    if j != -1:
                        tok = text[i:j+1]
                        if tok in self.token2id:
                            tokens.append(self.token2id[tok])
                            i = j + 1
                            continue

                if text[i] == ",":
                    tokens.append(self.token2id[","])
                    i += 1
                    continue

                if text[i].isdigit():
                    j = i
                    while j < len(text) and text[j].isdigit():
                        j += 1
                    number_str = text[i:j]
                    number = int(number_str)

                    if str(number) in self.token2id:
                        tokens.append(self.token2id[str(number)])
                        i = j
                    else:
                        raise ValueError(f"Number out of range: {number}")
                else:
                    raise ValueError(f"Unexpected character: {text[i]}")

            return tokens

        def decode(self, token_ids: List[int]) -> str:
            tokens = [self.id2token[i] for i in token_ids]
            result = ""

            for i, token in enumerate(tokens):
                if token.startswith("<") and token.endswith(">"):
                    if result and not result.endswith(" "):
                        result += " "
                    result += token
                    if i < len(tokens) - 1:
                        result += " "
                elif token == ",":
                    result += ","
                else:
                    result += token

            return result

    tokenizer = CustomTokenizer(meta['token2id'], meta['id2token'])
    
    # Load model
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    checkpoint_path = 'out_number_commands/ckpt.pt'
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    # Create model config
    model_args = checkpoint['model_args']
    config_params = {
        'block_size': model_args['block_size'],
        'vocab_size': model_args['vocab_size'], 
        'n_layer': model_args['n_layer'],
        'n_head': model_args['n_head'],
        'n_embd': model_args['n_embd'],
        'dropout': model_args['dropout'],
        'bias': model_args['bias']
    }
    
    config = NumberCommandConfig(**config_params)
    model = NumberCommandTransformer(config)
    
    # Load state dict (remove compilation prefix if exists)
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    
    return model, tokenizer, device

def interactive_model_testing():
    """Interactive testing interface for the number command transformer"""
    
    try:
        # Load model and tokenizer
        model, tokenizer, device = load_model_and_tokenizer()
        eos_id = tokenizer.token2id.get("<eos>", None)
        
        print("=" * 60)
        print("INTERACTIVE NUMBER COMMAND TRANSFORMER TESTING")
        print("=" * 60)
        print("Enter prompts like:")
        print("  <sos> 55,50,14 <sum>")
        print("  <sos> 82,77,9,57 <reverse>") 
        print("  <sos> 425,616,162 <ascending>")
        print("  <sos> 44,30,61 <even_repeat>")
        print("\nType 'quit' to exit")
        print("=" * 60)
        
    except Exception as e:
        print(f"Error loading model: {e}")
        print("Make sure the model is trained and saved properly.")
        return
    
    while True:
        try:
            # Get user input
            user_input = input("\nEnter prompt: ").strip()
            
            if user_input.lower() in ['quit', 'exit', 'q']:
                print("Exiting interactive testing...")
                break
                
            if not user_input:
                print("Please enter a valid prompt.")
                continue
            
            print(f"Input: {user_input}")
            
            # Encode the input
            try:
                input_tokens = tokenizer.encode(user_input)
                input_tensor = torch.tensor(input_tokens, dtype=torch.long, device=device).unsqueeze(0)
            except Exception as e:
                print(f"Error encoding input: {e}")
                print("Make sure your input follows the format: <sos> numbers <command>")
                continue
            
            # Generate completion
            with torch.no_grad():
                generated = model.generate(
                    input_tensor, 
                    max_new_tokens=25, 
                    temperature=0.1,  # Low temperature for more deterministic output
                    eos_token_id=eos_id
                )
            
            # Decode and display result
            generated_tokens = generated[0].tolist()
            model_output = tokenizer.decode(generated_tokens)
            
            print(f"Model output: {model_output}")
            
            # Show some additional info
            print(f"Generated {len(generated_tokens) - len(input_tokens)} new tokens")
            
            # Calculate expected result if possible
            try:
                expected = calculate_expected_completion(user_input)
                if expected != "Unknown command" and expected != "Invalid format":
                    print(f"Expected: {expected}")
                    is_correct = model_output.strip() == expected.strip()
                    print(f"Matches expected: {'✅' if is_correct else '❌'}")
            except Exception:
                # The expected-value check is best-effort; skip it if the prompt cannot be parsed
                pass
                
        except KeyboardInterrupt:
            print("\nExiting...")
            break
        except Exception as e:
            print(f"Error during generation: {e}")

def calculate_expected_completion(prompt: str) -> str:
    """Calculate what the expected completion should be for a given prompt"""
    
    # Parse the prompt to extract numbers and command
    parts = prompt.strip().split()
    if len(parts) < 3 or parts[0] != '<sos>':
        return "Invalid format"
    
    numbers_str = parts[1]
    command = parts[2]
    
    # Convert comma-separated string to list of numbers
    try:
        numbers = [int(n) for n in numbers_str.split(',')]
    except ValueError:
        return "Invalid format"
    
    # Calculate result based on command
    if command == '<sum>':
        result = str(sum(numbers))
    elif command == '<reverse>':
        result = ','.join(map(str, numbers[::-1]))
    elif command == '<ascending>':
        result = ','.join(map(str, sorted(numbers)))
    elif command == '<descending>':
        result = ','.join(map(str, sorted(numbers, reverse=True)))
    elif command == '<even_repeat>':
        even_numbers = [num for num in numbers if num % 2 == 0]
        if not even_numbers:
            result = ""
        else:
            repeated_evens = []
            for i in range(10):
                repeated_evens.append(even_numbers[i % len(even_numbers)])
            result = ','.join(map(str, repeated_evens))
    elif command == '<odd_repeat>':
        odd_numbers = [num for num in numbers if num % 2 == 1]
        if not odd_numbers:
            result = ""
        else:
            repeated_odds = []
            for i in range(10):
                repeated_odds.append(odd_numbers[i % len(odd_numbers)])
            result = ','.join(map(str, repeated_odds))
    else:
        return "Unknown command"
    
    # Rebuild from the parsed parts so the spacing matches tokenizer.decode()
    return f"<sos> {numbers_str} {command} {result} <eos>"

# Run interactive testing
interactive_model_testing()
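
The exact-match check above only works if the prompt text and the decoded model output share the same spacing. Below is a small string-level sketch of the tokenizer's splitting and joining rules, swapping in a regular expression for the character loop in `CustomTokenizer`. The names `split_tokens` and `join_tokens` are hypothetical, no vocabulary lookup is done, and the pattern assumes special tokens contain only lowercase letters and underscores.

```python
import re

# special tokens like <sos> or <even_repeat>, runs of digits, or a comma
TOKEN_PATTERN = re.compile(r"<[a-z_]+>|\d+|,")


def split_tokens(text):
    """Split a prompt into special tokens, numbers, and commas."""
    tokens = TOKEN_PATTERN.findall(text)
    # Anything other than recognised tokens and spaces is an error
    if "".join(tokens) != text.replace(" ", ""):
        raise ValueError(f"Unexpected characters in: {text!r}")
    # Normalise numbers the same way encode() does (int -> str drops leading zeros)
    return [str(int(t)) if t.isdigit() else t for t in tokens]


def join_tokens(tokens):
    """Rebuild the canonical string: special tokens are space-separated, numbers and commas are not."""
    out = ""
    for tok in tokens:
        if tok.startswith("<"):
            if out and not out.endswith(" "):
                out += " "
            out += tok + " "
        else:
            out += tok
    return out.rstrip()


print(split_tokens("<sos> 55,50 <sum>"))
print(join_tokens(split_tokens("<sos>55 , 50<sum>")))  # prints "<sos> 55,50 <sum>"
```

Passing both the prompt and the model output through the same split-and-join step before comparing them would make the match insensitive to how the user typed the spaces.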

In [None]:
import torch
import pickle
import os
import numpy as np
from typing import List

# Load model and tokenizer
def load_trained_model(checkpoint_path: str = 'out_number_commands/ckpt.pt'):
    """Load the trained model and tokenizer"""

    try:
        # Load metadata first
        meta_path = 'data/number_commands/meta.pkl'
        with open(meta_path, 'rb') as f:
            meta = pickle.load(f)
    except FileNotFoundError:
        raise FileNotFoundError("Dataset metadata not found. Please run dataset generation first.")

    # Recreate tokenizer
    class CustomTokenizer:
        def __init__(self, token2id, id2token):
            self.token2id = token2id
            self.id2token = id2token
            self.vocab_size = len(token2id)

        def encode(self, text: str) -> List[int]:
            tokens = []
            i = 0
            while i < len(text):
                if text[i] == " ":
                    i += 1
                    continue

                if text[i] == "<":
                    j = text.find(">", i)
                    if j != -1:
                        tok = text[i:j+1]
                        if tok in self.token2id:
                            tokens.append(self.token2id[tok])
                            i = j + 1
                            continue

                if text[i] == ",":
                    tokens.append(self.token2id[","])
                    i += 1
                    continue

                if text[i].isdigit():
                    j = i
                    while j < len(text) and text[j].isdigit():
                        j += 1
                    number_str = text[i:j]
                    number = int(number_str)

                    if str(number) in self.token2id:
                        tokens.append(self.token2id[str(number)])
                        i = j
                    else:
                        raise ValueError(f"Number out of range: {number}")
                else:
                    raise ValueError(f"Unexpected character: {text[i]}")

            return tokens

        def decode(self, token_ids: List[int]) -> str:
            tokens = [self.id2token[i] for i in token_ids]
            result = ""

            for i, token in enumerate(tokens):
                if token.startswith("<") and token.endswith(">"):
                    if result and not result.endswith(" "):
                        result += " "
                    result += token
                    if i < len(tokens) - 1:
                        result += " "
                elif token == ",":
                    result += ","
                else:
                    result += token

            return result

    tokenizer = CustomTokenizer(meta['token2id'], meta['id2token'])

    # Load model
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    try:
        checkpoint = torch.load(checkpoint_path, map_location=device)
    except FileNotFoundError:
        raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}. Please train the model first.")

    model_args = checkpoint['model_args']

    # Handle both old and new checkpoint formats
    if isinstance(model_args, dict):
        # New format: the training cell saves its whole config dict (batch_size, out_dir, ...),
        # so pass only the model hyperparameters to NumberCommandConfig
        model_keys = ('block_size', 'vocab_size', 'n_layer', 'n_head', 'n_embd', 'dropout', 'bias')
        config = NumberCommandConfig(**{k: model_args[k] for k in model_keys})
    else:
        # Old format - might be a config object
        config = model_args

    try:
        model = NumberCommandTransformer(config)
    except NameError:
        raise NameError("Model classes not found. Please run the model definition cell first.")

    # Remove potential compilation prefix
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k,v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    print(f"Model loaded from {checkpoint_path}")
    print(f"Model has {model.get_num_params()/1e6:.2f}M parameters")
    print(f"Using device: {device}")

    return model, tokenizer, device

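def test_model_completion(model, tokenizer, device, prompt: str, max_new_tokens: int = 25, temperature: float = 0.1):
    # max_new_tokens must cover the longest completion: a 10-number repeat needs
    # 10 numbers + 9 commas + <eos> = 20 tokens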
    # Encode prompt
    prompt_tokens = tokenizer.encode(prompt)
    prompt_tensor = torch.tensor(prompt_tokens, dtype=torch.long, device=device).unsqueeze(0)

    # Get the actual EOS token ID
    eos_id = tokenizer.token2id.get("<eos>", None)

    # Generate completion
    with torch.no_grad():
        generated = model.generate(prompt_tensor, max_new_tokens=max_new_tokens, 
                                 temperature=temperature, eos_token_id=eos_id)

    # Decode result
    generated_tokens = generated[0].tolist()
    full_response = tokenizer.decode(generated_tokens)

    return full_response, generated_tokens

def calculate_expected_result(input_str: str) -> str:
    """Calculate what the expected output should be"""

    # Parse the input
    parts = input_str.strip().split()
    if len(parts) < 3 or parts[0] != '<sos>':
        return "Invalid format"

    numbers_str = parts[1]
    command = parts[2]

    # Convert comma-separated string to list of numbers
    try:
        numbers = [int(n) for n in numbers_str.split(',')]
    except ValueError:
        return "Invalid numbers"

    if command == '<sum>':
        result = str(sum(numbers))
    elif command == '<reverse>':
        result = ','.join(map(str, numbers[::-1]))
    elif command == '<ascending>':
        result = ','.join(map(str, sorted(numbers)))
    elif command == '<descending>':
        result = ','.join(map(str, sorted(numbers, reverse=True)))
    elif command == '<even_repeat>':
        even_numbers = [num for num in numbers if num % 2 == 0]
        if not even_numbers:
            result = ""
        else:
            repeated_evens = []
            for i in range(10):
                repeated_evens.append(even_numbers[i % len(even_numbers)])
            result = ','.join(map(str, repeated_evens))
    elif command == '<odd_repeat>':
        odd_numbers = [num for num in numbers if num % 2 == 1]
        if not odd_numbers:
            result = ""
        else:
            repeated_odds = []
            for i in range(10):
                repeated_odds.append(odd_numbers[i % len(odd_numbers)])
            result = ','.join(map(str, repeated_odds))
    else:
        return "Unknown command"

    return f"<sos> {numbers_str} {command} {result} <eos>"

def run_comprehensive_test(model, tokenizer, device):
    """Run comprehensive tests on the model"""

    test_cases = [
        # Sum tests
        "<sos> 55,50,14 <sum>",
        "<sos> 26,55,99 <sum>",
        "<sos> 92,99,1 <sum>",

        # Reverse tests
        "<sos> 82,77,9,57 <reverse>",
        "<sos> 190,676,440 <reverse>",
        "<sos> 505,100 <reverse>",

        # Ascending sort tests
        "<sos> 425,616,162 <ascending>",
        "<sos> 190,676,440 <ascending>",
        "<sos> 5,231,44 <ascending>",

        # Descending sort tests
        "<sos> 425,616,162,221,244 <descending>",
        "<sos> 917,35,88 <descending>",
        "<sos> 48,2,99 <descending>",

        # Even repeat tests
        "<sos> 44,30,61 <even_repeat>",
        "<sos> 12,37,84 <even_repeat>",

        # Odd repeat tests
        "<sos> 50,82,41 <odd_repeat>",
        "<sos> 13,28,55,78 <odd_repeat>",
    ]

    print("=" * 80)
    print("COMPREHENSIVE MODEL TESTING")
    print("=" * 80)

    correct = 0
    total = 0

    for i, test_case in enumerate(test_cases):
        print(f"\nTest {i+1}: {test_case}")

        # Get model prediction
        prediction, tokens = test_model_completion(model, tokenizer, device, test_case)
        print(f"Model output: {prediction}")

        # Calculate expected result
        expected = calculate_expected_result(test_case)
        print(f"Expected:     {expected}")

        # Check if correct
        is_correct = prediction.strip() == expected.strip()
        print(f"Correct: {'YES' if is_correct else 'NO'}")
        if is_correct:
            correct += 1
        total += 1

        print("-" * 40)

    if total > 0:
        accuracy = correct / total * 100
        print(f"\nOverall Accuracy: {correct}/{total} = {accuracy:.1f}%")

    return correct, total
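
How many tokens a completion needs depends on the command, which matters when choosing `max_new_tokens`: a budget shorter than the expected completion can never produce an exact match. Here is a rough sketch of that count, assuming (as in `CustomTokenizer`) that each number and each comma is a single token and that the repeat commands always emit 10 numbers. `completion_token_count` is a hypothetical helper.

```python
def completion_token_count(numbers, command, repeat_len=10):
    """Tokens the model has to generate after the command, including the final <eos>."""
    if command == "<sum>":
        n_out = 1
    elif command in ("<reverse>", "<ascending>", "<descending>"):
        n_out = len(numbers)
    elif command == "<even_repeat>":
        n_out = repeat_len if any(x % 2 == 0 for x in numbers) else 0
    elif command == "<odd_repeat>":
        n_out = repeat_len if any(x % 2 == 1 for x in numbers) else 0
    else:
        raise ValueError(f"Unknown command: {command}")
    commas = max(n_out - 1, 0)      # one comma between each pair of numbers
    return n_out + commas + 1       # +1 for <eos>


print(completion_token_count([55, 50, 14], "<sum>"))
print(completion_token_count([82, 77, 9, 57], "<reverse>"))
print(completion_token_count([44, 30, 61], "<even_repeat>"))
```

For the test cases above this gives 2 tokens for a sum, 8 for reversing four numbers, and 20 for either repeat command, so the repeat tasks set the minimum generation budget.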

In [None]:
try:
    # Load the trained model
    model, tokenizer, device = load_trained_model()

    # Run comprehensive tests
    correct, total = run_comprehensive_test(model, tokenizer, device)

    # Start interactive testing
    #interactive_test(model, tokenizer, device)

except FileNotFoundError as e:
    print(f"Error: {e}")
    print("Make sure you've trained the model first.")
    print("Expected files:")
    print("  - out_number_commands/ckpt.pt")
    print("  - data/number_commands/meta.pkl")
except Exception as e:
    print(f"Error loading model: {e}")
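
`run_comprehensive_test` reports a single overall accuracy, but the question of which commands interfere with each other needs a per-command breakdown. Below is a sketch of one way to group results by command token. `accuracy_by_command` is a hypothetical helper, and the sample rows are hand-written placeholders for illustration, not actual model output.

```python
from collections import defaultdict


def accuracy_by_command(results):
    """results: iterable of (prompt, prediction, expected) strings.

    Returns {command: (correct, total, accuracy)}.
    """
    stats = defaultdict(lambda: [0, 0])    # command -> [correct, total]
    for prompt, prediction, expected in results:
        command = prompt.split()[-1]        # prompts end with the command token
        stats[command][1] += 1
        if prediction.strip() == expected.strip():
            stats[command][0] += 1
    return {cmd: (c, t, c / t) for cmd, (c, t) in stats.items()}


# Placeholder rows (made up): one correct sum, one wrong reverse
sample = [
    ("<sos> 1,2 <sum>", "<sos> 1,2 <sum> 3 <eos>", "<sos> 1,2 <sum> 3 <eos>"),
    ("<sos> 1,2 <reverse>", "<sos> 1,2 <reverse> 1,2 <eos>", "<sos> 1,2 <reverse> 2,1 <eos>"),
]
for cmd, (c, t, acc) in accuracy_by_command(sample).items():
    print(f"{cmd:>14}: {c}/{t} = {acc:.0%}")
```

The triples could be collected inside the test loop from `test_case`, `prediction`, and `expected`, which would show whether failures cluster on one command rather than being spread evenly.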


# Trying Again to Load the Model

In [None]:
#

import torch
import pickle
import os
import numpy as np
from typing import List

# Load model and tokenizer
def load_trained_model(checkpoint_path: str = 'out_number_commands/ckpt.pt'):
    """Load the trained model and tokenizer"""

    try:
        # Load metadata first
        meta_path = 'data/number_commands/meta.pkl'
        with open(meta_path, 'rb') as f:
            meta = pickle.load(f)
    except FileNotFoundError:
        raise FileNotFoundError("Dataset metadata not found. Please run dataset generation first.")

    # Recreate tokenizer
    class CustomTokenizer:
        def __init__(self, token2id, id2token):
            self.token2id = token2id
            self.id2token = id2token
            self.vocab_size = len(token2id)

        def encode(self, text: str) -> List[int]:
            tokens = []
            i = 0
            while i < len(text):
                if text[i] == " ":
                    i += 1
                    continue

                if text[i] == "<":
                    j = text.find(">", i)
                    if j != -1:
                        tok = text[i:j+1]
                        if tok in self.token2id:
                            tokens.append(self.token2id[tok])
                            i = j + 1
                            continue

                if text[i] == ",":
                    tokens.append(self.token2id[","])
                    i += 1
                    continue

                if text[i].isdigit():
                    # Extract full number (could be 1-3 digits)
                    j = i
                    while j < len(text) and text[j].isdigit():
                        j += 1
                    number_str = text[i:j]
                    number = int(number_str)

                    if str(number) in self.token2id:
                        tokens.append(self.token2id[str(number)])
                        i = j
                    else:
                        raise ValueError(f"Number out of range: {number}")
                else:
                    raise ValueError(f"Unexpected character: {text[i]}")

            return tokens

        def decode(self, token_ids: List[int]) -> str:
            tokens = [self.id2token[i] for i in token_ids]
            result = ""

            for i, token in enumerate(tokens):
                if token.startswith("<") and token.endswith(">"):
                    # Add special token with spaces
                    if result and not result.endswith(" "):
                        result += " "
                    result += token
                    if i < len(tokens) - 1:  # not last token
                        result += " "
                elif token == ",":
                    result += ","
                else:
                    # Number token
                    result += token

            return result

    tokenizer = CustomTokenizer(meta['token2id'], meta['id2token'])

    # Load model
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    try:
        checkpoint = torch.load(checkpoint_path, map_location=device)
    except FileNotFoundError:
        raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}. Please train the model first.")

    model_args = checkpoint['model_args']

    # Handle both old and new checkpoint formats
    if isinstance(model_args, dict):
        # New format - direct config dict
        try:
            config = NumberCommandConfig(**model_args)
        except TypeError:
            # Fallback when model_args contains keys that NumberCommandConfig does not accept
            print("Warning: Using fallback model configuration")
            config = type('Config', (), model_args)()
    else:
        # Old format - might be a config object
        config = model_args

    try:
        model = NumberCommandTransformer(config)
    except NameError:
        raise NameError("Model classes not found. Please run the model definition cell first.")

    # Remove potential compilation prefix
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k,v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    print(f"✓ Model loaded from {checkpoint_path}")
    print(f"✓ Model has {model.get_num_params()/1e6:.2f}M parameters")
    print(f"✓ Using device: {device}")

    return model, tokenizer, device

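def test_model_completion(model, tokenizer, device, prompt: str, max_new_tokens: int = 20, temperature: float = 0.1):
    """Test model completion given a prompt"""

    # Encode prompt
    prompt_tokens = tokenizer.encode(prompt)
    prompt_tensor = torch.tensor(prompt_tokens, dtype=torch.long, device=device).unsqueeze(0)

    # Generate completion. The default of 20 new tokens covers the longest
    # expected output: 10 numbers, 9 commas and <eos>.
    with torch.no_grad():
        generated = model.generate(prompt_tensor, max_new_tokens=max_new_tokens, temperature=temperature)

    # Cut at the first <eos> after the prompt, in case generate() keeps
    # sampling past it; otherwise the exact-match check can never succeed.
    generated_tokens = generated[0].tolist()
    eos_id = tokenizer.token2id.get("<eos>")
    if eos_id in generated_tokens[len(prompt_tokens):]:
        end = generated_tokens.index(eos_id, len(prompt_tokens))
        generated_tokens = generated_tokens[:end + 1]

    # Decode result
    full_response = tokenizer.decode(generated_tokens)

    return full_response, generated_tokens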

def calculate_expected_result(input_str: str) -> str:
    """Calculate what the expected output should be"""

    # Parse the input
    parts = input_str.strip().split()
    if len(parts) < 3 or parts[0] != '<sos>':
        return "Invalid format"

    numbers_str = parts[1]
    command = parts[2]

    # Convert comma-separated string to list of numbers
    try:
        numbers = [int(n) for n in numbers_str.split(',')]
    except ValueError:
        return "Invalid numbers"

    if command == '<sum>':
        result = str(sum(numbers))
    elif command == '<reverse>':
        result = ','.join(map(str, numbers[::-1]))
    elif command == '<ascending>':
        result = ','.join(map(str, sorted(numbers)))
    elif command == '<descending>':
        result = ','.join(map(str, sorted(numbers, reverse=True)))
    elif command == '<even_repeat>':
        even_numbers = [num for num in numbers if num % 2 == 0]
        if not even_numbers:
            result = ""
        else:
            repeated_evens = []
            for i in range(10):
                repeated_evens.append(even_numbers[i % len(even_numbers)])
            result = ','.join(map(str, repeated_evens))
    elif command == '<odd_repeat>':
        odd_numbers = [num for num in numbers if num % 2 == 1]
        if not odd_numbers:
            result = ""
        else:
            repeated_odds = []
            for i in range(10):
                repeated_odds.append(odd_numbers[i % len(odd_numbers)])
            result = ','.join(map(str, repeated_odds))
    else:
        return "Unknown command"

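    # Join only non-empty parts so an empty result does not leave a double
    # space, which would never match the tokenizer's decoded output.
    return " ".join(p for p in ["<sos>", numbers_str, command, result, "<eos>"] if p)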

def run_comprehensive_test(model, tokenizer, device):
    """Run comprehensive tests on the model"""

    test_cases = [
        # Sum tests
        "<sos> 55,50,14 <sum>",
        "<sos> 123,456,789 <sum>",
        "<sos> 92,99,1 <sum>",

        # Reverse tests
        "<sos> 82,77,9,57 <reverse>",
        "<sos> 190,676,440 <reverse>",
        "<sos> 505,100 <reverse>",

        # Ascending sort tests
        "<sos> 425,616,162 <ascending>",
        "<sos> 190,676,440 <ascending>",
        "<sos> 5,231,44 <ascending>",

        # Descending sort tests
        "<sos> 425,616,162,221,244 <descending>",
        "<sos> 917,35,88 <descending>",
        "<sos> 48,2,99 <descending>",

        # Even repeat tests
        "<sos> 44,30,61 <even_repeat>",
        "<sos> 12,37,84 <even_repeat>",

        # Odd repeat tests
        "<sos> 50,82,41 <odd_repeat>",
        "<sos> 13,28,55 <odd_repeat>",
    ]

    print("=" * 80)
    print("COMPREHENSIVE MODEL TESTING")
    print("=" * 80)

    correct = 0
    total = 0

    for i, test_case in enumerate(test_cases):
        print(f"\nTest {i+1}: {test_case}")

        # Get model prediction
        prediction, tokens = test_model_completion(model, tokenizer, device, test_case)
        print(f"Model output: {prediction}")

        # Calculate expected result
        expected = calculate_expected_result(test_case)
        print(f"Expected:     {expected}")

        # Check if correct
        is_correct = prediction.strip() == expected.strip()
        print(f"Correct: {'✓' if is_correct else '✗'}")
        if is_correct:
            correct += 1
        total += 1

        print("-" * 40)

    accuracy = correct / total * 100
    print(f"\nOverall Accuracy: {correct}/{total} = {accuracy:.1f}%")

    return correct, total

def test_with_generated_examples(model, tokenizer, device, num_tests=20):
    """Test model using your existing generator functions"""

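    # Placeholder import: replace `your_module` with the module that defines
    # the generators, or drop the import if they are already defined in an
    # earlier notebook cell.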
    from your_module import (
        generate_random_numbers_and_sum,
        generate_even_repeat_sequence,
        generate_odd_repeat_sequence,
        generate_random_numbers_and_reverse,
        generate_ascending_sort,
        generate_descending_sort
    )

    generators = [
        generate_random_numbers_and_sum,
        generate_even_repeat_sequence,
        generate_odd_repeat_sequence,
        generate_random_numbers_and_reverse,
        generate_ascending_sort,
        generate_descending_sort
    ]

    print("=" * 80)
    print("TESTING WITH GENERATED EXAMPLES")
    print("=" * 80)

    correct = 0
    total = 0

    for i in range(num_tests):
        # Cycle through the generators in order (round-robin, not random)
        generator = generators[i % len(generators)]

        # Generate full example
        full_example = generator()

        # Split into prompt and expected
        parts = full_example.split(' <eos>')
        prompt_part = parts[0]  # Everything before <eos>

        # Find where the command ends to create prompt
        command_tokens = ['<sum>', '<reverse>', '<ascending>', '<descending>', '<even_repeat>', '<odd_repeat>']
        prompt = None
        expected = full_example

        for cmd in command_tokens:
            if cmd in prompt_part:
                cmd_index = prompt_part.find(cmd)
                prompt = prompt_part[:cmd_index + len(cmd)]
                break

        if prompt is None:
            continue

        print(f"\nTest {i+1}:")
        print(f"Prompt: {prompt}")

        # Get model prediction
        prediction, tokens = test_model_completion(model, tokenizer, device, prompt, max_new_tokens=20)
        print(f"Model output: {prediction}")
        print(f"Expected:     {expected}")

        # Check if correct
        is_correct = prediction.strip() == expected.strip()
        print(f"Correct: {'✓' if is_correct else '✗'}")
        if is_correct:
            correct += 1
        total += 1

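    # Guard against division by zero when every example was skipped
    if total > 0:
        accuracy = correct / total * 100
        print(f"\nGenerated Examples Accuracy: {correct}/{total} = {accuracy:.1f}%")
    else:
        print("\nNo valid generated examples to test.")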

    return correct, total

def interactive_test(model, tokenizer, device):
    """Interactive testing interface"""

    print("\n" + "="*60)
    print("INTERACTIVE TESTING")
    print("Enter prompts like: <sos> 55,50,14 <sum>")
    print("Type 'quit' to exit")
    print("="*60)

    while True:
        try:
            prompt = input("\nEnter prompt: ").strip()
            if prompt.lower() == 'quit':
                break

            if not prompt:
                continue

            # Test the prompt
            prediction, tokens = test_model_completion(model, tokenizer, device, prompt)
            print(f"Model output: {prediction}")

            # Show expected if possible
            try:
                expected = calculate_expected_result(prompt)
                if "Invalid" not in expected and "Unknown" not in expected:
                    print(f"Expected:     {expected}")
            except Exception:
                pass

        except KeyboardInterrupt:
            print("\nExiting...")
            break
        except Exception as e:
            print(f"Error: {e}")

# Main testing function
def main():
    """Run all tests"""
    try:
        # Load model
        model, tokenizer, device = load_trained_model()

        # Run comprehensive tests
        run_comprehensive_test(model, tokenizer, device)

        # Uncomment to test with generated examples (requires importing your functions)
        # test_with_generated_examples(model, tokenizer, device)

        # Run interactive test
        interactive_test(model, tokenizer, device)

    except Exception as e:
        print(f"Error during testing: {e}")

if __name__ == "__main__":
    main()
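
The comprehensive test reports a single overall accuracy. Since the notebook asks which commands interfere with each other, a per-command breakdown may be more informative. A minimal sketch, assuming prompts follow the `<sos> numbers <command>` layout used above; `per_command_accuracy` is a hypothetical helper that takes `(prompt, is_correct)` pairs collected during a test run:

```python
from collections import defaultdict

def per_command_accuracy(results):
    """results: iterable of (prompt, is_correct) pairs.

    Returns {command: (correct, total)}.
    """
    tally = defaultdict(lambda: [0, 0])
    for prompt, is_correct in results:
        parts = prompt.split()
        # The command is the third whitespace-separated field
        command = parts[2] if len(parts) >= 3 else "<unknown>"
        tally[command][0] += int(is_correct)
        tally[command][1] += 1
    return {cmd: tuple(counts) for cmd, counts in tally.items()}

# Illustrative outcomes, not real model results
demo = [
    ("<sos> 55,50,14 <sum>", True),
    ("<sos> 92,99,1 <sum>", False),
    ("<sos> 505,100 <reverse>", True),
]
for cmd, (ok, n) in per_command_accuracy(demo).items():
    print(f"{cmd}: {ok}/{n}")
```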

In [None]:

# Minimal testing

import torch
import pickle
import os
import numpy as np

def evaluate_on_validation_set():
    """Evaluate model on validation set"""

    # Load metadata
    data_dir = 'data/number_commands'
    meta_path = os.path.join(data_dir, 'meta.pkl')
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)

    # Load model
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    checkpoint_path = 'out_number_commands/ckpt.pt'
    checkpoint = torch.load(checkpoint_path, map_location=device)

    model_args = checkpoint['model_args']

    # Filter to only NumberCommandConfig parameters
    config_params = {
        'block_size': model_args['block_size'],
        'vocab_size': model_args['vocab_size'],
        'n_layer': model_args['n_layer'],
        'n_head': model_args['n_head'],
        'n_embd': model_args['n_embd'],
        'dropout': model_args['dropout'],
        'bias': model_args['bias']
    }

    config = NumberCommandConfig(**config_params)
    model = NumberCommandTransformer(config)

    # Remove compilation prefix if exists
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k,v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    # Load validation data
    val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
    block_size = config.block_size

    print(f"Model loaded: {model.get_num_params()/1e6:.2f}M parameters")
    print(f"Validation data size: {len(val_data)} tokens")

    # Evaluate loss on validation set
    total_loss = 0
    num_batches = 0
    batch_size = 32

    with torch.no_grad():
        for i in range(0, len(val_data) - block_size, batch_size * block_size):
            # Get batch
            batch_end = min(i + batch_size * block_size, len(val_data) - block_size)
            batch_indices = range(i, batch_end, block_size)

            if len(batch_indices) == 0:
                break

            x = torch.stack([
                torch.from_numpy(val_data[j:j+block_size].astype(np.int64))
                for j in batch_indices
            ])
            y = torch.stack([
                torch.from_numpy(val_data[j+1:j+1+block_size].astype(np.int64))
                for j in batch_indices
            ])

            x, y = x.to(device), y.to(device)

            # Forward pass
            logits, loss = model(x, y)
            total_loss += loss.item()
            num_batches += 1

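    # Guard against an empty loop (validation data shorter than block_size)
    if num_batches == 0:
        print("Validation data is too short for one block; nothing to evaluate.")
        return

    avg_loss = total_loss / num_batches
    perplexity = torch.exp(torch.tensor(avg_loss))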

    print(f"Validation Loss: {avg_loss:.4f}")
    print(f"Perplexity: {perplexity:.2f}")
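
The loop above averages the per-batch losses with equal weight, so a smaller final batch counts as much as a full one. Assuming each batch loss is itself a mean over that batch, a size-weighted average gives the mean over all sequences, and perplexity is its exponential. A minimal sketch with a hypothetical `weighted_perplexity` helper:

```python
import math

def weighted_perplexity(batch_losses, batch_sizes):
    """Perplexity from per-batch mean losses, weighted by batch size."""
    total = sum(batch_sizes)
    if total == 0:
        raise ValueError("no batches to average")
    mean_loss = sum(loss * n for loss, n in zip(batch_losses, batch_sizes)) / total
    return math.exp(mean_loss)

# Illustrative numbers: 32 sequences at loss 0.5, then a final batch of 8 at loss 1.0
print(f"{weighted_perplexity([0.5, 1.0], [32, 8]):.3f}")
```

With equally sized batches this reduces to the plain average used in `evaluate_on_validation_set`.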


In [None]:

evaluate_on_validation_set()