In [21]:
import os
import torch
import importlib
import model as model_module

# Re-execute model.py so edits made on disk are picked up without restarting
# the kernel.
importlib.reload(model_module)
# Import the names only after the reload so they are bound to the objects
# created by the freshly executed module, not to the pre-reload ones.
from model import SimpleGPTPredictor, device

# Show which file was actually imported (guards against picking up a
# different model.py from elsewhere on the import path).
print(f"Using model from {model_module.__file__}")


Using model from /data/home/ayumu/dev/llm_scratch/model.py


In [22]:
# Rebuild the character vocabulary. The id assignment MUST mirror the one used
# during training in main.py: unique corpus characters, sorted, numbered from 0.
with open('inputLearnText.txt', 'r', encoding='utf-8') as f:
    text = f.read()
chars = sorted(set(text))  # sorting pins each character to a stable id
id_to_char = dict(enumerate(chars))
char_to_id = {ch: idx for idx, ch in id_to_char.items()}

def text_to_ids(text):
    """Encode a string as the list of vocabulary ids of its characters.

    Looks each character up in ``char_to_id``; a character that is not in the
    vocabulary raises ``KeyError``.
    """
    encoded = []
    for symbol in text:
        encoded.append(char_to_id[symbol])
    return encoded

def ids_to_text(ids):
    """Decode an iterable of vocabulary ids into the string they spell.

    Looks each id up in ``id_to_char``; an unknown id raises ``KeyError``.
    """
    pieces = (id_to_char[token_id] for token_id in ids)
    return ''.join(pieces)

# len(chars) is later passed to the model as vocab_size, so this number should
# equal the vocabulary size the checkpoint was trained with.
print(f"Vocab size: {len(chars)}")

Vocab size: 120


In [23]:
# Model hyperparams (must match main.py training config)
# These are forwarded unchanged to the SimpleGPTPredictor constructor as
# embed_size / num_heads / max_len.
EMBED_SIZE = 32
NUM_HEADS = 4
MAX_LEN = 100  # presumably the longest sequence the model accepts — TODO confirm in model.py

# Pick the checkpoint to load
# (relative path: resolved against the notebook's current working directory)
MODEL_PATH = "model/model_11.pth"

# Load weights first so we can infer number of layers
# map_location=device remaps the stored tensors onto the active device;
# weights_only=True restricts unpickling to tensors and plain containers.
state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)

def _infer_num_layers(sd, prefix="encoder.layers."):
    layer_ids = {int(k.split('.')[2]) for k in sd if k.startswith(prefix)}
    return (max(layer_ids) + 1) if layer_ids else 0

# Derive the layer count from the checkpoint's key names instead of
# hard-coding it, so checkpoints of different depth can be loaded.
NUM_LAYERS = _infer_num_layers(state_dict)
print(f"Inferred NUM_LAYERS = {NUM_LAYERS}")

# Instantiate model with same architecture as training
model = SimpleGPTPredictor(
    vocab_size=len(chars),
    embed_size=EMBED_SIZE,
    num_heads=NUM_HEADS,
    max_len=MAX_LEN,
    num_layers=NUM_LAYERS,
)

# Load weights (strict=False allows old checkpoints without "pe")
# With strict=False key mismatches do not raise; they are returned and
# printed below, so an empty output here means every key matched.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if missing:
    print(f"Missing keys: {missing}")
if unexpected:
    print(f"Unexpected keys: {unexpected}")

# Move the parameters to the active device and switch to inference mode
# (assumes SimpleGPTPredictor is a torch.nn.Module — TODO confirm in model.py).
model.to(device)
model.eval()

print(f"Model loaded from {MODEL_PATH} on {device}")


Inferred NUM_LAYERS = 10
Model loaded from model/model_11.pth on cuda


In [24]:
def test_prediction(model: SimpleGPTPredictor, input_text, temperature=1.0):
    """
    Predict the next character with temperature control.

    Args:
        model: The model to use for prediction
        input_text: Input text string; every character must be in the
            vocabulary (``text_to_ids`` raises ``KeyError`` otherwise)
        temperature: Controls randomness. 1.0 samples from the model's own
            distribution, >1.0 is more random, <1.0 is more focused, and
            0 selects greedy (deterministic) decoding.

    Returns:
        The predicted next character.

    Raises:
        ValueError: If ``temperature`` is negative. Dividing by a negative
            value would invert the ranking of the scores.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")

    input_ids = text_to_ids(input_text)
    input_tensor = torch.tensor([input_ids], device=device)

    with torch.no_grad():
        # The same tensor is fed to both arguments of the model call.
        output = model(input_tensor, input_tensor)
        # Scores for the last position of the only batch element.
        last_scores = output[0, -1, :]

        if temperature == 0:
            # Greedy decoding: take the argmax of the raw scores directly.
            # Dividing by zero first would turn them into inf/nan and make
            # the argmax meaningless.
            char_id = torch.argmax(last_scores).item()
        else:
            # Temperature scaling, then sample from the resulting distribution
            # (instead of always picking top-1).
            probs = torch.softmax(last_scores / temperature, dim=-1)
            char_id = torch.multinomial(probs, num_samples=1).item()

    return id_to_char[char_id]

def generateSeq(model, text, max_length=20, temperature=1.0):
    """
    Extend ``text`` one predicted token at a time.

    Each step passes the full sequence built so far to ``test_prediction``
    and appends the token it returns.

    Args:
        model: The model to use
        text: Starting text
        max_length: Number of tokens to append to the starting text
        temperature: Sampling temperature forwarded to ``test_prediction``
            (default 1.0)

    Returns:
        The starting text followed by the generated tokens.
    """
    sequence = text
    for _step in range(max_length):
        next_token = test_prediction(model, sequence, temperature=temperature)
        sequence = sequence + next_token
    return sequence

In [25]:
# Seed text for generation. It should contain only characters present in the
# vocabulary, because text_to_ids raises KeyError on anything else.
prompt = "But since what may prove "

# Try different temperatures
temperatures = [0.5, 0.8, 1.0, 1.2, 1.5]

# One independent 50-token continuation of the same prompt per temperature.
# NOTE(review): no RNG seed is set in the visible cells, so the sampled
# completions will presumably differ on every run — confirm that is intended.
for temp in temperatures:
    completion = generateSeq(model, prompt, max_length=50, temperature=temp)
    print(f"\n=== Temperature: {temp} ===")
    print(f"Input:  {prompt}")
    print(f"Output: {completion}")


=== Temperature: 0.5 ===
Input:  But since what may prove 
Output: But since what may prove  e       i o  t d  d t  etto    e ee rn  et     l 

=== Temperature: 0.8 ===
Input:  But since what may prove 
Output: But since what may prove yhy surns  e ih ,no a.au  smie.c;sfct o
 en ld loa

=== Temperature: 1.0 ===
Input:  But since what may prove 
Output: But since what may prove eatBamteh”em rph4eeaepοne;t2.eone ehifo;ioeibyAhhs

=== Temperature: 1.2 ===
Input:  But since what may prove 
Output: But since what may prove sdPrlc bllμ wwoe νgreᾶia7e iyngx) kat]έ
pa7lott wa

=== Temperature: 1.5 ===
Input:  But since what may prove 
Output: But since what may prove wve)kώεōvח,αCtECκe kiuh”wuunsgc  oc ehJiSihlςόydat


## Model Architecture

### Key Changes:
1. **batch_first=True** - Transformer layers now use (batch, seq, embed) format
2. **No transpose operations** - Simpler code, more efficient
3. **Vocab Consistency** - Both training and inference use sorted() vocab for consistent token IDs
4. **Correct Mask Dimensions** - Uses `tgt.size(1)` for batch-first format

### Architecture:
- Encoder-Decoder Transformer
- Embedding size: 32
- Attention heads: 4
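- Layers: `num_layers` is inferred from the checkpoint's `encoder.layers.*` keys at load time (10 for `model/model_11.pth`), not fixed at 2 (encoder) + 2 (decoder)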
- Causal masking for autoregressive generation