# Training SmolLM2-135M

This notebook demonstrates how to train the custom `SmolLM2` model (imported from `model.py`) from scratch on a local plain-text corpus (`input-1.txt`). It covers:
1. Training for 5000 steps.
2. Generating text every 500 steps.
3. Saving a checkpoint.
4. Resuming training from the checkpoint.

In [None]:
# Install PyTorch with CUDA support for Windows (assuming CUDA 12.1)
# If this fails or you have a different CUDA version, check https://pytorch.org/get-started/locally/
!pip uninstall -y torch torchvision
!pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
!pip install transformers datasets

> **IMPORTANT**: After running the cell above, you **MUST** restart the Jupyter Kernel for the changes to take effect. Go to **Kernel > Restart Kernel** in the menu.

In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoConfig
from datasets import load_dataset
from model import SmolLM2  # Import our custom model
import os

# Report the runtime up front so a CPU-only or mismatched CUDA build is obvious before training.
has_cuda = torch.cuda.is_available()
print(f"PyTorch Version: {torch.__version__}")
print(f"CUDA Available: {has_cuda}")

if has_cuda:
    print(f"CUDA Device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version: {torch.version.cuda}")
    device = "cuda"
else:
    device = "cpu"

print(f"Using device: {device}")

if not has_cuda:
    print("WARNING: You are running on CPU. Training will be very slow. Please ensure you have a GPU and the correct PyTorch version installed.")

## 1. Load Model and Tokenizer

In [None]:
# Directory where train_tokenizer.ipynb is expected to save the custom tokenizer.
tokenizer_path = "./custom_tokenizer"
# Reference checkpoint: its config defines the architecture, and its tokenizer is the fallback.
model_id = "HuggingFaceTB/SmolLM2-135M"

if os.path.exists(tokenizer_path):
    print(f"Loading custom tokenizer from {tokenizer_path}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
else:
    # Keep the notebook runnable without the custom tokenizer, but make the fallback loud.
    print(f"Custom tokenizer not found at {tokenizer_path}. Please run train_tokenizer.ipynb first.")
    print(f"Falling back to default tokenizer ({model_id})...")
    tokenizer = AutoTokenizer.from_pretrained(model_id)

# Some tokenizers ship without a pad token; reuse EOS so padding calls do not fail.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    print(f"Padding token set to: {tokenizer.pad_token}")

# Reuse the reference architecture config, with vocab_size matched to our tokenizer.
config = AutoConfig.from_pretrained(model_id)
config.vocab_size = len(tokenizer)
print(f"Model vocab size updated to: {config.vocab_size}")

# Initialize the model from scratch (no pretrained weights are loaded here).
model = SmolLM2(config).to(device)
print("Model initialized.")

## 2. Prepare Dataset (Chunked)
The whole corpus is tokenized as a single stream and split into fixed-length chunks of `block_size` tokens, so the model learns context that spans line boundaries.

In [None]:
from datasets import Dataset

# Read the corpus in one call so newlines and formatting survive into the token stream.
with open("input-1.txt", "r", encoding="utf-8") as corpus_file:
    full_text = corpus_file.read()

# A single row holding the entire document; chunking happens after tokenization.
dataset = Dataset.from_dict({"text": [full_text]})

# Guard for running this cell without the earlier pad-token fix: fall back to EOS.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Tokens per training chunk (the context window used for training).
block_size = 256

def group_texts(examples, chunk_size=None):
    """Concatenate a batch of tokenized examples and re-split it into fixed-size chunks.

    Intended for ``Dataset.map(..., batched=True)``: ``examples`` maps each column
    name (e.g. ``input_ids``, ``attention_mask``) to a list of per-example lists.

    Args:
        examples: Column name -> list of token lists for the current batch. Must
            contain ``input_ids``.
        chunk_size: Chunk length in tokens. Defaults to the notebook-level
            ``block_size`` when omitted, which is how ``Dataset.map`` calls it.

    Returns:
        A dict with the same columns, each a list of chunks, plus ``labels``
        (independent copies of the ``input_ids`` chunks).

    Raises:
        ValueError: if the chunk size is not positive.
    """
    from itertools import chain

    size = block_size if chunk_size is None else chunk_size
    if size <= 0:
        raise ValueError(f"chunk size must be positive, got {size}")

    # chain.from_iterable is linear; sum(lists, []) re-copies the accumulator on every add.
    concatenated = {key: list(chain.from_iterable(rows)) for key, rows in examples.items()}
    total_length = len(next(iter(concatenated.values()), []))

    # Drop the trailing partial chunk. A stream shorter than one chunk is kept whole
    # as a single short chunk rather than being discarded.
    if total_length >= size:
        total_length = (total_length // size) * size

    result = {
        key: [tokens[start:start + size] for start in range(0, total_length, size)]
        for key, tokens in concatenated.items()
    }
    # Copy each chunk so labels never alias input_ids rows.
    result["labels"] = [list(chunk) for chunk in result["input_ids"]]
    return result

def tokenize_function(examples):
    """Tokenize the ``"text"`` column of a batched ``datasets`` example dict."""
    text_batch = examples["text"]
    return tokenizer(text_batch)

# Tokenize the document; the raw "text" column is not needed afterwards.
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"],
)

# Re-pack the token stream into block_size chunks and expose the columns as torch tensors.
# num_proc is deliberately left unset to avoid Windows multiprocessing 'spawn' issues
# with this notebook's global variables.
lm_dataset = tokenized_dataset.map(group_texts, batched=True, batch_size=1000).with_format("torch")

train_dataloader = DataLoader(lm_dataset, shuffle=True, batch_size=4)
print(f"Dataset prepared. Number of chunks: {len(lm_dataset)}")

## 3. Training Loop

In [None]:
# TF32 trades a little matmul precision for speed on Ampere+ GPUs.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# torch.compile (PyTorch 2.0+) is skipped on Windows ('nt'), where it has been unreliable.
if os.name != "nt":
    print("Compiling model with torch.compile...")
    model = torch.compile(model)
else:
    print("Skipping torch.compile on Windows to avoid potential compatibility issues.")

# Request the fused AdamW kernel only when a GPU is available.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, fused=torch.cuda.is_available())
loss_fn = torch.nn.CrossEntropyLoss()

# Loss scaling is only needed for CUDA fp16 autocast; disable it explicitly on CPU
# instead of relying on GradScaler's "CUDA not available" warn-and-disable fallback.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

def generate_text(model, tokenizer, prompt="The meaning of life is", max_new_tokens=50, temperature=0.7, top_k=50):
    """Sample a continuation of ``prompt`` token by token and print it.

    Args:
        model: Causal LM called as ``model(input_ids)``; its output is indexed as
            ``logits[:, -1, :]``, i.e. the last dimension is treated as the vocabulary.
        tokenizer: Hugging Face-style tokenizer (``__call__`` with
            ``return_tensors="pt"``, ``decode`` and ``eos_token_id``).
        prompt: Text to condition on.
        max_new_tokens: Upper bound on the number of tokens appended.
        temperature: Softmax temperature; ``<= 0`` selects greedy (argmax) decoding.
        top_k: Sample only among the ``top_k`` highest logits (clamped to
            ``[1, vocab_size]``).

    Returns nothing; the decoded text is printed. The model's train/eval mode is
    restored on exit, even if generation raises.
    """
    was_training = model.training
    model.eval()

    # Run on the device holding the weights; fall back to the notebook-level `device`.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device(device)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(run_device)
    # Mixed precision only helps on CUDA; elsewhere run in full precision.
    use_amp = run_device.type == "cuda"

    try:
        with torch.no_grad():
            for _ in range(max_new_tokens):
                with torch.autocast(device_type="cuda" if use_amp else "cpu", enabled=use_amp):
                    logits = model(input_ids)
                # Sample in float32 for a numerically stable softmax.
                next_token_logits = logits[:, -1, :].float()

                if temperature <= 0:
                    next_token = next_token_logits.argmax(dim=-1, keepdim=True)
                else:
                    next_token_logits = next_token_logits / temperature
                    k = max(1, min(top_k, next_token_logits.size(-1)))
                    top_k_logits, top_k_indices = torch.topk(next_token_logits, k, dim=-1)
                    probs = torch.softmax(top_k_logits, dim=-1)
                    # multinomial picks a position inside the top-k slice; map it back to a vocab id.
                    next_token = top_k_indices.gather(-1, torch.multinomial(probs, num_samples=1))

                input_ids = torch.cat([input_ids, next_token], dim=1)

                # .item() assumes a single prompt (batch size 1), which is all this helper builds.
                if next_token.item() == tokenizer.eos_token_id:
                    break
    finally:
        model.train(was_training)

    print(f"Generated: {tokenizer.decode(input_ids[0], skip_special_tokens=True)}")

steps = 0
max_steps = 5000
# Derive the file name from max_steps so the two cannot drift apart.
save_path = f"checkpoint_{max_steps}.pt"

# Fail fast: with zero batches the `while` loop below would never terminate.
if len(train_dataloader) == 0:
    raise ValueError("train_dataloader yields no batches; check the corpus size and block_size.")

# After torch.compile, `model` is a wrapper whose state_dict keys are prefixed with
# "_orig_mod."; `base_model` is the underlying module (same parameters). It is used for
# checkpointing, so the file loads into a plain SmolLM2, and for sampling, where eval mode
# and a growing sequence length may otherwise trigger recompilation of the wrapper.
base_model = getattr(model, "_orig_mod", model)
# fp16 autocast is only used on CUDA.
use_amp = device == "cuda"

model.train()
print("Starting training...")

# Cycle over the dataloader (possibly several epochs) until max_steps updates are done.
while steps < max_steps:
    for batch in train_dataloader:
        if steps >= max_steps:
            break

        input_ids = batch["input_ids"].to(device)
        labels = input_ids.clone()

        optimizer.zero_grad()

        # Mixed-precision forward pass and loss.
        with torch.autocast(device_type=device, enabled=use_amp):
            logits = model(input_ids)

            # Shift so that tokens < n predict token n (causal LM objective).
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss = loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        # Scaled backward pass; scaler.step unscales the gradients before the optimizer step.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        steps += 1

        if steps % 100 == 0:
            print(f"Step {steps}: Loss {loss.item()}")

        if steps % 500 == 0:
            print(f"\n--- Step {steps} Generation ---")
            generate_text(base_model, tokenizer)
            print("-----------------------------\n")

# Save the unwrapped module's weights so keys match a freshly constructed SmolLM2.
torch.save(base_model.state_dict(), save_path)
print(f"Checkpoint saved to {save_path}")

## 4. Resume Training
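Next we simulate an interruption: a fresh model is loaded from the checkpoint and trained for 50 more steps. Only the model weights are checkpointed, so the optimizer state (AdamW moment estimates) starts over and the dataloader restarts from the beginning.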

In [None]:
print("Resuming training...")
model = SmolLM2(config).to(device)

# weights_only=True restricts unpickling to tensors and plain containers (no arbitrary
# objects); map_location lets a checkpoint written on GPU load on a CPU-only machine.
state_dict = torch.load(save_path, map_location=device, weights_only=True)
# A checkpoint saved from a torch.compile'd model has keys prefixed with "_orig_mod.";
# strip it so the weights load into a plain SmolLM2.
compiled_prefix = "_orig_mod."
state_dict = {
    (key[len(compiled_prefix):] if key.startswith(compiled_prefix) else key): value
    for key, value in state_dict.items()
}
model.load_state_dict(state_dict)
model.train()

# The checkpoint holds model weights only, so AdamW's moment estimates restart from zero.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

extra_steps = 50
current_step = 0

# Simplification: the dataloader restarts from the beginning (reshuffled) instead of
# resuming mid-epoch; a real resume would save the sampler state or skip `steps` batches.
train_dataloader_resume = DataLoader(lm_dataset, batch_size=4, shuffle=True)
if len(train_dataloader_resume) == 0:
    raise ValueError("train_dataloader_resume yields no batches; check the corpus size and block_size.")

# Cycle through epochs so exactly `extra_steps` updates run even on a small dataset.
while current_step < extra_steps:
    for batch in train_dataloader_resume:
        if current_step >= extra_steps:
            break

        input_ids = batch["input_ids"].to(device)
        labels = input_ids.clone()

        optimizer.zero_grad()
        logits = model(input_ids)

        # Shift so that tokens < n predict token n.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        loss = loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        loss.backward()
        optimizer.step()

        current_step += 1
        if current_step % 10 == 0:
            print(f"Resume Step {current_step}: Loss {loss.item()}")

print(f"\n--- Step {steps + current_step} Generation ---")
generate_text(model, tokenizer, prompt="First Citizen", max_new_tokens=1024)
print("-----------------------------\n")

print("Resumed training completed.")