In [None]:
import argparse
import json
import os
import time
from typing import List, Optional, Dict, Any, Union

import torch
import torch.nn.functional as F
import numpy as np

from transformer_blocks import TransformerLM, softmax
from bpe_main import BPETokenizer

def load_model_and_tokenizer(checkpoint_path: str, config_path: Optional[str] = None, 
                           tokenizer_vocab_path: str = None, tokenizer_merges_path: str = None,
                           special_tokens: List[str] = None) -> tuple:
    """
    Load trained model and tokenizer from checkpoint and config files.
    
    Args:
        checkpoint_path: Path to model checkpoint
        config_path: Path to training config (optional, will try to infer)
        tokenizer_vocab_path: Path to tokenizer vocab file
        tokenizer_merges_path: Path to tokenizer merges file
        special_tokens: List of special tokens
        
    Returns:
        Tuple of (model, tokenizer, config)
    """
    # Load checkpoint
    print(f"Loading checkpoint from {checkpoint_path}...")
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    
    # Load config
    if config_path is None:
        # Try to find config in same directory as checkpoint
        checkpoint_dir = os.path.dirname(checkpoint_path)
        config_path = os.path.join(checkpoint_dir, "config.json")
    
    if os.path.exists(config_path):
        with open(config_path, 'r') as f:
            config = json.load(f)
        print(f"Loaded config from {config_path}")
    else:
        # Fallback default config
        print("Warning: No config file found, using default parameters")
        config = {
            'vocab_size': 50257,
            'context_length': 1024,
            'd_model': 768,
            'num_heads': 12,
            'num_layers': 12,
            'rope_theta': 10000.0
        }
    
    # Initialize model
    print("Initializing model...")
    model = TransformerLM(
        vocab_size=config['vocab_size'],
        context_length=config['context_length'],
        d_model=config['d_model'],
        num_heads=config['num_heads'],
        num_layers=config['num_layers'],
        rope_theta=config.get('rope_theta', 10000.0)
    )
    
    # Load model weights
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    
    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model loaded: {total_params:,} parameters")
    
    # Load tokenizer
    if tokenizer_vocab_path and tokenizer_merges_path:
        print("Loading BPE tokenizer...")
        tokenizer = BPETokenizer.from_files(
            vocab_path=tokenizer_vocab_path,
            merges_path=tokenizer_merges_path,
            special_tokens=special_tokens or ["<|endoftext|>"]
        )
        print(f"Tokenizer loaded with vocab size: {len(tokenizer.vocab)}")
    else:
        print("Warning: No tokenizer paths provided, tokenizer will be None")
        tokenizer = None
    
    return model, tokenizer, config

In [2]:
checkpoint_path = "checkpoints/final.pt"
config_path = "checkpoints/config.json"
tokenizer_vocab_path = "../artifacts/ts_train/vocab.json"
tokenizer_merges_path = "../artifacts/ts_train/merges.json"
special_tokens = None

In [3]:
model, tokenizer, config = load_model_and_tokenizer(
    checkpoint_path=checkpoint_path,
    config_path=config_path,
    tokenizer_vocab_path=tokenizer_vocab_path,
    tokenizer_merges_path=tokenizer_merges_path,
    special_tokens=special_tokens
)

Loading checkpoint from checkpoints/final.pt...
Loaded config from checkpoints/config.json
Initializing model...
Model loaded: 100,313,856 parameters
Loading BPE tokenizer...
Loading vocab from ../artifacts/ts_train/vocab.json...
Vocab loaded in 0.00s - 10000 entries
Merges loaded in 0.00s - 9743 rules
Tokenizer loaded with vocab size: 10000


In [61]:
prompt = ("In France there are a").strip()

In [62]:
# Tokenize prompt
if prompt:
    input_ids = tokenizer.encode(prompt)
    if not input_ids:
        print("Warning: Empty prompt after tokenization")
        input_ids = []
else:
    input_ids = []

In [63]:
input_ids

[1463, 3532, 1638, 401, 483, 259]

In [64]:
if input_ids:
    tokens = torch.tensor([input_ids], dtype=torch.long, device='cpu')
else:
    # Start with empty sequence - model should handle this
    tokens = torch.empty((1, 0), dtype=torch.long, device='cpu')

In [65]:
stop_token: str = "<|endoftext|>"

stop_token_id = None
if stop_token:
    stop_tokens = tokenizer.encode(stop_token)
    if stop_tokens:
        stop_token_id = stop_tokens[0]
stop_token_id

0

In [66]:
generated_tokens = []
# start_time = time.time()

In [21]:
model.eval() 
with torch.no_grad(): 
    logits = model(tokens, return_logits=True)

logits

tensor([[[-0.1438,  0.3650, -0.1430,  ...,  0.5692, -0.3380,  0.3288],
         [-0.5015,  0.0610, -0.0039,  ...,  0.5530, -0.1796,  0.0252],
         [-0.3390,  0.3696,  0.2635,  ...,  0.4191, -0.2637, -0.3357],
         ...,
         [-0.2940,  0.0408,  0.3533,  ...,  0.4481, -0.4130, -0.7079],
         [-0.5766, -0.3995,  0.0820,  ...,  0.7355, -0.3799, -0.4130],
         [-0.3247,  0.2963,  0.5471,  ...,  0.5115, -0.3120, -0.0507]]])

In [22]:
logits.shape

torch.Size([1, 9, 10000])

In [24]:
next_token_logits = logits[:, -1, :]

next_token_logits.shape

torch.Size([1, 10000])

In [37]:
next_token = torch.argmax(next_token_logits)
next_token = next_token.unsqueeze(-1)

In [38]:
tokens = torch.cat([tokens, next_token.unsqueeze(-1)], dim=-1)
generated_tokens.append(next_token.item())

In [39]:
generated_tokens

[6031]

In [41]:
generated_text = tokenizer.decode(generated_tokens)
generated_text

' cuts'

In [42]:
tokens

tensor([[  81,  820, 1010,   59,  285,  496, 1311,  911, 1854, 6031]])

In [None]:
def softmax_with_temperature(logits, temperature):
    
    scaled_logits = logits / temperature
    
    probs = F.softmax(scaled_logits, dim=-1)
    
    return probs

In [94]:
def top_p_sampling(probs, top_p):
    pass

In [88]:
def sample_next_token(logits: torch.Tensor, temperature: float = 1.0, 
                     top_p: float = 1.0, deterministic: bool = False) -> torch.Tensor:
    """
    Sample next token from model logits using temperature and top-p sampling.
    
    Args:
        logits: Model output logits of shape (..., vocab_size)
        temperature: Temperature for scaling
        top_p: Top-p threshold for nucleus sampling
        deterministic: If True, always select most probable token (greedy)
    
    Returns:
        Sampled token indices
    """
    if deterministic:
        return torch.argmax(logits, dim=-1)
    
    # Apply temperature scaling
    probs = softmax_with_temperature(logits, temperature)
    
    # Apply top-p filtering
    if top_p < 1.0:
        probs = top_p_sampling(probs, top_p)
    
    # Sample from the distribution
    next_token = torch.multinomial(probs, num_samples=1)
    
    return next_token.squeeze(-1)

In [92]:
# generated_tokens = []
# start_time = time.time()

# model.eval() 
with torch.no_grad(): 
    logits = model(tokens, return_logits=True)

next_token_logits = logits[:, -1, :]
next_token = sample_next_token(next_token_logits, deterministic=True)

tokens = torch.cat([tokens, next_token.unsqueeze(-1)], dim=-1)
generated_tokens.append(next_token.item())

generated_text = tokenizer.decode(generated_tokens)
generated_text

' big Bun also Sally floating unh unh unh trip a big a big a big, ".\n<|endoftext|>\n<|endoftext|>\n<|endoftext|>'