# Tutorial 1: Basic Model Training

This tutorial demonstrates how to train a G-code fingerprinting model from scratch.

**What you'll learn:**
- Load and prepare preprocessed data
- Configure model hyperparameters
- Train a multi-head transformer
- Evaluate model performance
- Save and load checkpoints

**Prerequisites:**
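- Preprocessed data in `data/preprocessed/` at the project root (the notebook reads it through the relative path `../data/preprocessed`, so run it from a directory one level below the root)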
- Python environment with all dependencies installed

## 1. Setup and Imports

In [None]:
import os
import sys
import json
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import seaborn as sns

# Add src to path
sys.path.insert(0, str(Path.cwd().parent / 'src'))

from miracle.data.dataset import GCodeDataset
from miracle.data.augmentation import DataAugmenter
from miracle.model.multihead_transformer import MultiHeadGCodeTransformer
from miracle.training.trainer import Trainer
from miracle.utils.target_utils import TargetDecomposer
from miracle.tokenization.gcode_tokenizer import GCodeTokenizer, TokenizerConfig

# Plot styling: seaborn grid theme plus a wide default figure size
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

# Report the environment this notebook is running in
device_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'
print("✓ Imports successful")
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {device_name}")

## 2. Configure Training

Set hyperparameters for training. These are reasonable defaults for quick experimentation.

In [None]:
# All settings for this run live in one flat dict so they can be dumped to
# JSON next to the outputs and stored inside checkpoints.
config = {
    # --- Data ---
    "data_dir": "../data/preprocessed",
    "batch_size": 8,
    "num_workers": 4,

    # --- Model architecture ---
    "d_model": 128,
    "nhead": 8,
    "num_encoder_layers": 2,
    "num_decoder_layers": 2,
    "dim_feedforward": 512,
    "dropout": 0.1,

    # --- Training ---
    "epochs": 30,
    "learning_rate": 0.001,
    "weight_decay": 0.01,
    "warmup_epochs": 5,

    # --- Data augmentation ---
    "augmentation": True,
    "oversampling_factor": 3,
    "noise_level": 0.02,

    # --- Per-head loss weights ---
    "type_weight": 1.0,
    "command_weight": 3.0,
    "param_type_weight": 2.0,
    "param_value_weight": 2.0,

    # --- Output ---
    "output_dir": "outputs/tutorial_01",
    "save_every": 5,
}

# Make sure the output directory exists before anything is written to it
os.makedirs(config["output_dir"], exist_ok=True)

# Serialise once; the same text goes to disk and to the cell output
config_json = json.dumps(config, indent=2)
with open(f"{config['output_dir']}/config.json", 'w') as f:
    f.write(config_json)

print("✓ Configuration set")
print(config_json)

## 3. Load Tokenizer and Vocabulary

In [None]:
# Build the tokenizer with 2-digit bucketing
tokenizer_config = TokenizerConfig(bucket_digits=2)
tokenizer = GCodeTokenizer(tokenizer_config)

# The vocabulary file sits alongside the preprocessed data
vocab_path = Path(config["data_dir"]) / "vocabulary_v2.json"
tokenizer.load(vocab_path)

# vocab_size is passed to the model constructor later in the notebook
vocab_size = len(tokenizer.vocab)
print(f"✓ Loaded vocabulary: {vocab_size} tokens")
print(f"  Bucket digits: {tokenizer.config.bucket_digits}")
print(f"  Special tokens: PAD={tokenizer.pad_token_id}, UNK={tokenizer.unk_token_id}")

# Peek at the first ten vocabulary keys without copying the whole key list
sample_tokens = [token for _, token in zip(range(10), tokenizer.vocab.keys())]
print(f"\nSample tokens: {sample_tokens}")

## 4. Create Data Loaders

In [None]:
# One decomposer, built from the tokenizer vocabulary, is shared by all splits
decomposer = TargetDecomposer(tokenizer.vocab)

# Arguments that are identical for the train / val / test datasets
dataset_kwargs = dict(
    data_dir=config["data_dir"],
    vocab=tokenizer.vocab,
    decomposer=decomposer,
)

train_dataset = GCodeDataset(split="train", **dataset_kwargs)
val_dataset = GCodeDataset(split="val", **dataset_kwargs)
test_dataset = GCodeDataset(split="test", **dataset_kwargs)

print("✓ Datasets loaded:")
print(f"  Train: {len(train_dataset)} samples")
print(f"  Val:   {len(val_dataset)} samples")
print(f"  Test:  {len(test_dataset)} samples")

# Arguments that are identical for all three loaders
loader_kwargs = dict(
    batch_size=config["batch_size"],
    num_workers=config["num_workers"],
)

# Only the training data is shuffled; val/test keep a fixed order
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

print("\n✓ Data loaders created")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches:   {len(val_loader)}")
print(f"  Test batches:  {len(test_loader)}")

## 5. Inspect Sample Batch

In [None]:
# Pull a single batch from the training loader
batch = next(iter(train_loader))

# Report the shape of every tensor the later cells read from a batch
print("Sample batch contents:")
for field in (
    'continuous',
    'categorical',
    'tokens',
    'type_ids',
    'command_ids',
    'param_type_ids',
    'param_value_ids',
):
    print(f"  {field} shape: {batch[field].shape}")

# First sample of the batch as a NumPy array.
# Indexed below as [time step, channel] — presumably (T, num_channels); TODO confirm
continuous_sample = batch['continuous'][0].numpy()

fig, axes = plt.subplots(2, 1, figsize=(14, 6))
line_ax, heat_ax = axes

# Top panel: line plot of the first 20 channels
line_ax.plot(continuous_sample[:, :20])
line_ax.set(
    title="Sample Sensor Data (First 20 Channels)",
    xlabel="Time Step",
    ylabel="Normalized Value",
)
line_ax.legend([f"Ch{i}" for i in range(20)], ncol=10, loc='upper right', fontsize=8)

# Bottom panel: every channel as a heatmap (transposed so channels run down the y-axis)
im = heat_ax.imshow(continuous_sample.T, aspect='auto', cmap='viridis')
heat_ax.set(
    title="All Sensor Channels (Heatmap)",
    xlabel="Time Step",
    ylabel="Channel",
)
plt.colorbar(im, ax=heat_ax)

plt.tight_layout()
plt.show()

print("\n✓ Sample batch inspected")

## 6. Create Model

In [None]:
# Prefer the GPU when one is available
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Architecture hyperparameters are forwarded from the config under the same names
architecture_keys = (
    "d_model",
    "nhead",
    "num_encoder_layers",
    "num_decoder_layers",
    "dim_feedforward",
    "dropout",
)
model = MultiHeadGCodeTransformer(
    vocab_size=vocab_size,
    **{key: config[key] for key in architecture_keys},
)
model = model.to(device)

# Walk the parameters once, remembering each tensor's size and trainability
param_counts = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in param_counts)
trainable_params = sum(count for count, trainable in param_counts if trainable)

print("\n✓ Model created")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
# 4 bytes per parameter, i.e. the size if every weight is stored as FP32
print(f"  Model size: {total_params * 4 / (1024**2):.2f} MB (FP32)")

# Full module tree
print("\nModel architecture:")
print(model)

## 7. Setup Optimizer and Scheduler

In [None]:
from torch.optim.lr_scheduler import CosineAnnealingLR

# AdamW over every model parameter
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config["learning_rate"],
    weight_decay=config["weight_decay"],
)

# Cosine annealing spans only the epochs that follow the warmup period.
# CosineAnnealingLR itself has no warmup phase; the training loop decides
# when to start stepping it.
decay_epochs = config["epochs"] - config["warmup_epochs"]
scheduler = CosineAnnealingLR(optimizer, T_max=decay_epochs)

print("✓ Optimizer and scheduler created")
print("  Optimizer: AdamW")
print(f"  Learning rate: {config['learning_rate']}")
print(f"  Weight decay: {config['weight_decay']}")
print("  Scheduler: CosineAnnealing")
print(f"  Warmup epochs: {config['warmup_epochs']}")

## 8. Define Loss Function

In [None]:
def compute_loss(logits, targets, weights):
    """
    Compute the weighted sum of the four per-head cross-entropy losses.

    Each head's logits are flattened to ``(B * T, num_classes)`` and its
    targets to ``(B * T,)`` before being passed to ``nn.CrossEntropyLoss``
    (default settings: mean reduction, no ignored index).

    Args:
        logits: Tuple of (type_logits, command_logits, param_type_logits,
            param_value_logits). Each must hold ``B * T`` rows of class scores.
        targets: Tuple of (type_ids, command_ids, param_type_ids,
            param_value_ids). ``type_ids`` must be 2-D; its shape defines
            ``B`` and ``T`` for every head.
        weights: Dictionary with the keys ``'type_weight'``,
            ``'command_weight'``, ``'param_type_weight'`` and
            ``'param_value_weight'``.

    Returns:
        Tuple ``(total_loss, losses)``: ``total_loss`` is the weighted sum as
        a tensor (usable for ``backward()``); ``losses`` is a dict of Python
        floats with keys ``'total'``, ``'type'``, ``'command'``,
        ``'param_type'`` and ``'param_value'``.
    """
    type_logits, command_logits, param_type_logits, param_value_logits = logits
    type_ids, command_ids, param_type_ids, param_value_ids = targets

    criterion = nn.CrossEntropyLoss()

    # Every head shares the (batch, time) layout of the type targets
    B, T = type_ids.shape

    def head_loss(head_logits, head_ids):
        # reshape() instead of view(): view() raises on non-contiguous tensors
        # (for example a transposed model output), reshape() copies if needed.
        return criterion(
            head_logits.reshape(B * T, -1),
            head_ids.reshape(B * T),
        )

    type_loss = head_loss(type_logits, type_ids)
    command_loss = head_loss(command_logits, command_ids)
    param_type_loss = head_loss(param_type_logits, param_type_ids)
    param_value_loss = head_loss(param_value_logits, param_value_ids)

    # Weighted sum
    total_loss = (
        weights['type_weight'] * type_loss +
        weights['command_weight'] * command_loss +
        weights['param_type_weight'] * param_type_loss +
        weights['param_value_weight'] * param_value_loss
    )

    # Detached Python floats for logging
    losses = {
        'total': total_loss.item(),
        'type': type_loss.item(),
        'command': command_loss.item(),
        'param_type': param_type_loss.item(),
        'param_value': param_value_loss.item(),
    }

    return total_loss, losses

print("✓ Loss function defined")

## 9. Training Loop

Train the model for the specified number of epochs.

In [None]:
from tqdm.notebook import tqdm

# Per-epoch metrics, stored as plain Python floats
history = {
    'train_loss': [],
    'val_loss': [],
    'learning_rate': [],
}

best_val_loss = float('inf')

# Per-head loss weights, taken from the training configuration
loss_weights = {
    key: config[key]
    for key in ('type_weight', 'command_weight', 'param_type_weight', 'param_value_weight')
}

# Linear warmup: the learning rate ramps up to its configured value over the
# first `warmup_epochs` epochs; afterwards the cosine scheduler takes over.
base_lr = config['learning_rate']
warmup_epochs = config['warmup_epochs']


def _batch_to_device(batch):
    """Return ``((continuous, categorical), targets)`` for *batch*, all moved to ``device``."""
    inputs = (batch['continuous'].to(device), batch['categorical'].to(device))
    targets = tuple(
        batch[key].to(device)
        for key in ('type_ids', 'command_ids', 'param_type_ids', 'param_value_ids')
    )
    return inputs, targets


def _save_checkpoint(filename, epoch, val_loss):
    """Write the current training state to ``<output_dir>/<filename>``."""
    checkpoint = {
        'epoch': epoch,
        # NOTE: the whole module object is pickled here (not just its
        # state_dict). Loading it needs the model class to be importable and,
        # on PyTorch >= 2.6, `torch.load(..., weights_only=False)`.
        'model': model,
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'val_loss': val_loss,
        'config': config,
    }
    torch.save(checkpoint, f"{config['output_dir']}/{filename}")


print("Starting training...\n")

for epoch in range(config['epochs']):
    # Warmup phase: set the learning rate by hand. The ramp reaches base_lr in
    # the last warmup epoch, so the scheduler starts from its expected value.
    if epoch < warmup_epochs:
        warmup_lr = base_lr * (epoch + 1) / warmup_epochs
        for group in optimizer.param_groups:
            group['lr'] = warmup_lr

    # Learning rate actually used during this epoch (read before the scheduler steps)
    lr = optimizer.param_groups[0]['lr']

    # Training phase
    model.train()
    train_losses = []

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']} [Train]")

    for batch in pbar:
        (continuous, categorical), targets = _batch_to_device(batch)

        # Forward pass
        optimizer.zero_grad()
        logits = model(continuous, categorical)

        # Compute loss
        loss, _ = compute_loss(logits, targets, loss_weights)

        # Backward pass
        loss.backward()
        optimizer.step()

        # .item() synchronises with the device, so call it once per step
        batch_loss = loss.item()
        train_losses.append(batch_loss)
        pbar.set_postfix({'loss': f"{batch_loss:.4f}"})

    # Validation phase
    model.eval()
    val_losses = []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{config['epochs']} [Val]  "):
            (continuous, categorical), targets = _batch_to_device(batch)

            logits = model(continuous, categorical)
            loss, _ = compute_loss(logits, targets, loss_weights)
            val_losses.append(loss.item())

    # Cosine decay only starts once warmup is over
    if epoch >= warmup_epochs:
        scheduler.step()

    # Epoch metrics as plain floats, so the history and the checkpoints do not
    # carry NumPy scalar objects
    train_loss = float(np.mean(train_losses))
    val_loss = float(np.mean(val_losses))

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['learning_rate'].append(lr)

    # Print epoch summary
    print(f"\nEpoch {epoch+1}/{config['epochs']}:")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Loss:   {val_loss:.4f}")
    print(f"  LR:         {lr:.6f}")

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        _save_checkpoint("checkpoint_best.pt", epoch, val_loss)
        print(f"  ✓ Saved best model (val_loss={val_loss:.4f})")

    # Save periodic checkpoint
    if (epoch + 1) % config['save_every'] == 0:
        _save_checkpoint(f"checkpoint_epoch{epoch+1}.pt", epoch, val_loss)

print("\n✓ Training complete!")
print(f"  Best validation loss: {best_val_loss:.4f}")

## 10. Plot Training History

In [None]:
# Two side-by-side panels: loss curves on the left, learning rate on the right
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
loss_ax, lr_ax = axes

# Epochs are numbered from 1 on the x-axis
epochs = range(1, len(history['train_loss']) + 1)

# Left panel: training vs. validation loss
for series, label in (('train_loss', 'Train Loss'), ('val_loss', 'Val Loss')):
    loss_ax.plot(epochs, history[series], label=label, linewidth=2)
loss_ax.set(xlabel='Epoch', ylabel='Loss', title='Training and Validation Loss')
loss_ax.legend()
loss_ax.grid(True)

# Right panel: learning rate recorded for each epoch
lr_ax.plot(epochs, history['learning_rate'], color='green', linewidth=2)
lr_ax.set(xlabel='Epoch', ylabel='Learning Rate', title='Learning Rate Schedule')
lr_ax.grid(True)

plt.tight_layout()
# Save before show() so the written file contains the rendered figure
plt.savefig(f"{config['output_dir']}/training_history.png", dpi=150)
plt.show()

print("✓ Training curves saved")

## 11. Evaluate on Test Set

In [None]:
# Load the best checkpoint written by the training loop.
# - map_location=device lets a checkpoint saved on a GPU be opened on a
#   CPU-only machine (and vice versa).
# - weights_only=False is required because the checkpoint stores the pickled
#   model object itself, not just tensors; PyTorch >= 2.6 refuses such files by
#   default. Unpickling can execute arbitrary code, so only load checkpoints
#   you created yourself.
checkpoint = torch.load(
    f"{config['output_dir']}/checkpoint_best.pt",
    map_location=device,
    weights_only=False,
)
model = checkpoint['model'].to(device)
model.eval()

print(f"Loaded best model from epoch {checkpoint['epoch']+1}")
print(f"Validation loss: {checkpoint['val_loss']:.4f}\n")

# Accumulators for the test loss and the per-position predictions/targets
test_losses = []
all_type_preds = []
all_type_targets = []
all_command_preds = []
all_command_targets = []

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Evaluating on test set"):
        continuous = batch['continuous'].to(device)
        categorical = batch['categorical'].to(device)

        targets = (
            batch['type_ids'].to(device),
            batch['command_ids'].to(device),
            batch['param_type_ids'].to(device),
            batch['param_value_ids'].to(device),
        )

        # Forward pass
        logits = model(continuous, categorical)
        loss, _ = compute_loss(logits, targets, loss_weights)
        test_losses.append(loss.item())

        # Predicted class per position = arg-max over the last (class) axis
        type_preds = torch.argmax(logits[0], dim=-1)
        command_preds = torch.argmax(logits[1], dim=-1)

        # Move results to the CPU so device memory is not held across batches
        all_type_preds.append(type_preds.cpu())
        all_type_targets.append(targets[0].cpu())
        all_command_preds.append(command_preds.cpu())
        all_command_targets.append(targets[1].cpu())

# Mean of the per-batch losses, as a plain float
test_loss = float(np.mean(test_losses))

all_type_preds = torch.cat(all_type_preds).flatten()
all_type_targets = torch.cat(all_type_targets).flatten()
all_command_preds = torch.cat(all_command_preds).flatten()
all_command_targets = torch.cat(all_command_targets).flatten()

# Accuracy over every position in every sequence, as plain Python floats (percent).
# NOTE(review): if target sequences contain padding positions, they are counted
# here as well — confirm against the dataset.
type_accuracy = (all_type_preds == all_type_targets).float().mean().item() * 100
command_accuracy = (all_command_preds == all_command_targets).float().mean().item() * 100

print("\n" + "="*50)
print("TEST SET RESULTS")
print("="*50)
print(f"Test Loss:         {test_loss:.4f}")
print(f"Type Accuracy:     {type_accuracy:.2f}%")
print(f"Command Accuracy:  {command_accuracy:.2f}%")
print("="*50)

## 12. Summary

In this tutorial, you learned how to:

✓ Load preprocessed sensor data and G-code tokens  
✓ Configure a multi-head transformer model  
✓ Train with weighted multi-task loss  
✓ Monitor training progress and save checkpoints  
✓ Evaluate model performance on test set  

### Next Steps

- **Tutorial 2**: Create custom datasets
- **Tutorial 3**: Advanced data augmentation
- **Tutorial 4**: Export models to ONNX
- **Tutorial 5**: Deploy with Docker

### Improving Results

To improve model performance:
1. Increase model size (`d_model=256`, more layers)
2. Train for more epochs (100+)
3. Run hyperparameter sweeps (see HYPERPARAMETER_TUNING.md)
4. Use a finer-grained vocabulary (3-digit bucketing instead of the 2-digit bucketing used here)
5. Tune loss weights for parameter heads

See the documentation for detailed guides!