# 04 - Inference & Prediction

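**Use trained models to predict G-code from sensor data.** Predictions in this notebook are teacher-forced: the ground-truth token sequence is fed to the decoder as its input.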

## Learning Objectives
- Load trained encoder and decoder models
- Run inference on single samples
- Perform batch inference with timing
- Decode predictions to G-code strings
- Analyze per-head predictions
- Review operation classification (100% accuracy reported for the frozen encoder; not re-measured in this notebook)

## Table of Contents
1. [Load Models](#1.-Load-Models)
2. [Single Sample Inference](#2.-Single-Sample-Inference)
3. [Batch Inference](#3.-Batch-Inference)
4. [Token Decoding](#4.-Token-Decoding)
5. [Per-Head Analysis](#5.-Per-Head-Analysis)
6. [Performance Benchmarks](#6.-Performance-Benchmarks)

In [None]:
# Setup
import sys
import os
from pathlib import Path
import warnings
# Silence library warnings so notebook output stays readable.
warnings.filterwarnings('ignore')

# Ask PyTorch to run ops that Apple MPS does not support on the CPU instead.
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

# The notebook is assumed to live one directory below the project root;
# put <root>/src at the front of the import path.
project_root = Path.cwd().parent
_src_path = str(project_root / 'src')
sys.path.insert(0, _src_path)

print(f"Project root: {project_root}")

In [None]:
# Imports
import json
import time
from collections import defaultdict

import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm

# Plot styling shared by every figure in this notebook.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('husl')

# Prefer CUDA, then Apple MPS (only on torch builds that expose it), else CPU.
if torch.cuda.is_available():
    _device_name = 'cuda'
elif not hasattr(torch.backends, 'mps') or not torch.backends.mps.is_available():
    _device_name = 'cpu'
else:
    _device_name = 'mps'
device = torch.device(_device_name)

print(f"Using device: {device}")
print("✓ Imports successful")

---
## 1. Load Models

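Locate the frozen encoder (MM-DTAE-LSTM v2) checkpoint and load the decoder (SensorMultiHeadDecoder v3). The encoder itself is not run in this notebook; the cells below pass a slice of the raw continuous sensor features to the decoder as a stand-in for encoder embeddings.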

In [None]:
# Checkpoint, vocabulary and split locations, all under the project root.
_outputs_dir = project_root / 'outputs'
encoder_path = _outputs_dir / 'mm_dtae_lstm_v2' / 'best_model.pt'
decoder_path = _outputs_dir / 'sensor_multihead_v3' / 'best_model.pt'
vocab_path = project_root / 'data' / 'vocabulary_4digit_hybrid.json'
split_dir = _outputs_dir / 'stratified_splits_v2'


def _exists_tag(path):
    """Return 'EXISTS' if *path* is present on disk, otherwise 'MISSING'."""
    return 'EXISTS' if path.exists() else 'MISSING'


print("Model paths:")
print(f"  Encoder: {encoder_path} ({_exists_tag(encoder_path)})")
print(f"  Decoder: {decoder_path} ({_exists_tag(decoder_path)})")
print(f"  Vocab:   {vocab_path} ({_exists_tag(vocab_path)})")
print(f"  Splits:  {split_dir} ({_exists_tag(split_dir)})")

In [None]:
# Load vocabulary. JSON text is UTF-8, so decode it explicitly rather than
# relying on the platform's default encoding.
with open(vocab_path, 'r', encoding='utf-8') as f:
    vocab_data = json.load(f)

# The file either nests the mapping under 'vocab' or is the mapping itself;
# `vocab` is expected to map token string -> integer id.
vocab = vocab_data.get('vocab', vocab_data)
id_to_token = {v: k for k, v in vocab.items()}

print(f"Vocabulary loaded: {len(vocab)} tokens")

# Special tokens (fall back to ids 0-3 if the vocabulary omits them)
PAD_ID = vocab.get('PAD', 0)
BOS_ID = vocab.get('BOS', 1)
EOS_ID = vocab.get('EOS', 2)
UNK_ID = vocab.get('UNK', 3)

print(f"  PAD={PAD_ID}, BOS={BOS_ID}, EOS={EOS_ID}, UNK={UNK_ID}")

In [None]:
# Load decoder checkpoint.
# NOTE: weights_only=False unpickles arbitrary Python objects; only use it on
# checkpoints you produced yourself or otherwise trust.
decoder_checkpoint = torch.load(decoder_path, map_location=device, weights_only=False)

print("Decoder checkpoint loaded:")
for key, value in decoder_checkpoint.items():
    if 'state_dict' in key:
        # len() of a state dict counts its top-level entries (e.g. named
        # tensors), not scalar parameters.
        print(f"  {key}: {len(value)} entries")
    elif isinstance(value, dict):
        print(f"  {key}: dict")
    else:
        print(f"  {key}: {type(value).__name__}")

In [None]:
# Create and load decoder model
from miracle.model.sensor_multihead_decoder import SensorMultiHeadDecoder

# Hyperparameters come from the checkpoint's 'config'; these fallbacks apply
# only to keys the checkpoint does not record. Resolving them once keeps the
# constructed model and the printed summary in agreement.
config = decoder_checkpoint.get('config', {})

_DECODER_DEFAULTS = {
    'vocab_size': 668,
    'd_model': 192,
    'n_heads': 8,
    'n_layers': 4,
    'sensor_dim': 128,
    'n_operations': 9,
    'n_types': 4,
    'n_commands': 6,
    'n_param_types': 10,
    'max_seq_len': 32,
}
decoder_hparams = {name: config.get(name, default) for name, default in _DECODER_DEFAULTS.items()}

decoder = SensorMultiHeadDecoder(
    **decoder_hparams,
    dropout=0.0,  # No dropout for inference
).to(device)

# Load weights
decoder.load_state_dict(decoder_checkpoint['model_state_dict'])
decoder.eval()

# Count parameters
total_params = sum(p.numel() for p in decoder.parameters())
print(f"\nDecoder loaded: {total_params:,} parameters")
print(f"  d_model:     {decoder_hparams['d_model']}")
print(f"  n_heads:     {decoder_hparams['n_heads']}")
print(f"  n_layers:    {decoder_hparams['n_layers']}")

In [None]:
# Load test data.
# Materialise every array once and close the archive. Indexing an NpzFile
# re-reads and decompresses the member from disk on every access, and the
# NpzFile keeps the zip file open until it is closed.
# NOTE: allow_pickle=True can execute code from object arrays; trusted files only.
with np.load(split_dir / 'test_sequences.npz', allow_pickle=True) as _npz:
    test_data = {name: _npz[name] for name in _npz.files}

print("Test data loaded:")
print(f"  Samples: {len(test_data['continuous'])}")
for key in ['continuous', 'categorical', 'tokens', 'operation_type']:
    if key in test_data:
        print(f"  {key}: {test_data[key].shape}")

---
## 2. Single Sample Inference

In [None]:
@torch.no_grad()
def predict_single(decoder, sensor_emb, operation_type, tokens, device):
    """
    Run the decoder on one sample and summarise each prediction head.

    The given ``tokens`` are fed to the decoder as its input sequence
    (teacher forcing); this is not free-running generation.

    Args:
        decoder: SensorMultiHeadDecoder model (switched to eval mode here)
        sensor_emb: [T_s, sensor_dim] sensor embeddings
        operation_type: int, operation type index
        tokens: [L] token IDs used as decoder input
        device: torch device

    Returns:
        dict keyed by head name ('type_logits', 'command_logits',
        'param_type_logits', 'legacy_logits'). Heads missing from the decoder
        output, or set to None, are omitted. Each value holds numpy arrays
        'predictions' (argmax class per position) and 'confidence'
        (the softmax probability of that class).
    """
    decoder.eval()

    # Give every input a leading batch axis of size 1.
    sensor_in = torch.FloatTensor(sensor_emb).unsqueeze(0).to(device)
    op_in = torch.LongTensor([operation_type]).to(device)
    token_in = torch.LongTensor(tokens).unsqueeze(0).to(device)

    outputs = decoder(token_in, sensor_in, op_in)

    head_names = ('type_logits', 'command_logits', 'param_type_logits', 'legacy_logits')
    summary = {}
    for name in head_names:
        if name not in outputs or outputs[name] is None:
            continue
        head_logits = outputs[name][0]  # drop the batch axis
        top_prob = F.softmax(head_logits, dim=-1).max(dim=-1).values
        summary[name] = {
            'predictions': head_logits.argmax(dim=-1).cpu().numpy(),
            'confidence': top_prob.cpu().numpy(),
        }
    return summary

print("✓ Single prediction function defined")

In [None]:
# Run single sample prediction
sample_idx = 0

# Get sample data
continuous = test_data['continuous'][sample_idx]  # [64, 155]
tokens = test_data['tokens'][sample_idx]          # [L]
operation_type = test_data['operation_type'][sample_idx]
gcode_text = test_data['gcode_texts'][sample_idx] if 'gcode_texts' in test_data else 'N/A'

# Stand-in for sensor embeddings: keep the first `sensor_dim` raw continuous
# features of every timestep (a plain slice; no pooling or learned projection).
# In production these come from the frozen MM-DTAE-LSTM encoder, which this
# notebook does not run.
sensor_dim = config.get('sensor_dim', 128)
sensor_emb = continuous[:, :sensor_dim]

print(f"Sample {sample_idx}:")
print(f"  Operation type: {operation_type}")
print(f"  G-code text:    {gcode_text}")
print(f"  Token shape:    {tokens.shape}")
print(f"  Sensor shape:   {sensor_emb.shape}")

# Run inference. perf_counter is monotonic and high-resolution. predict_single
# copies its outputs to the CPU, so GPU work has finished before the clock
# stops. This first call may include one-time warm-up overhead.
start_time = time.perf_counter()
results = predict_single(decoder, sensor_emb, operation_type, tokens, device)
inference_time = (time.perf_counter() - start_time) * 1000

print(f"\nInference time: {inference_time:.2f} ms")
print("\nPrediction heads:")
for head_name, head_results in results.items():
    preds = head_results['predictions']
    confs = head_results['confidence']
    print(f"  {head_name}:")
    print(f"    Predictions: {preds[:5]}...")
    print(f"    Mean confidence: {confs.mean():.2%}")

In [None]:
# Compare the legacy head's per-position argmax with the ground-truth tokens.
# NOTE(review): prediction i is paired with input token i; confirm the decoder
# is not trained so that position i predicts token i+1 (next-token shift).
legacy_head = results['legacy_logits']
legacy_preds = legacy_head['predictions']
legacy_conf = legacy_head['confidence']

print("Token-by-token comparison (first 10):")
print("=" * 70)
print(f"{'Pos':>4} {'GT Token':>15} {'Pred Token':>15} {'Conf':>8} {'Match':>6}")
print("-" * 70)

for pos, gt_id in enumerate(tokens[:10]):
    pred_id = legacy_preds[pos]
    gt_token = id_to_token.get(gt_id, f'<{gt_id}>')
    pred_token = id_to_token.get(pred_id, f'<{pred_id}>')
    match = '✓' if gt_id == pred_id else '✗'
    print(f"{pos:4d} {gt_token:>15} {pred_token:>15} {legacy_conf[pos]:>7.1%} {match:>6}")

---
## 3. Batch Inference

In [None]:
@torch.no_grad()
def predict_batch(decoder, sensor_batch, operation_batch, token_batch, device):
    """
    Run the decoder on a batch and return the legacy head's argmax predictions.

    The given ``token_batch`` is fed to the decoder as its input (teacher forcing).

    Args:
        decoder: SensorMultiHeadDecoder model (switched to eval mode here)
        sensor_batch: [B, T_s, sensor_dim] sensor embeddings
        operation_batch: [B] operation type indices
        token_batch: [B, L] token IDs used as decoder input
        device: torch device

    Returns:
        dict with 'predictions' -> numpy array of argmax token IDs from
        'legacy_logits'. Returns an empty dict if the decoder output has no
        legacy head, or the head is None, mirroring predict_single.
    """
    decoder.eval()

    sensor_batch = torch.FloatTensor(sensor_batch).to(device)
    operation_batch = torch.LongTensor(operation_batch).to(device)
    token_batch = torch.LongTensor(token_batch).to(device)

    outputs = decoder(token_batch, sensor_batch, operation_batch)

    results = {}
    # Treat a None head like a missing one instead of failing on None.argmax.
    legacy_logits = outputs['legacy_logits'] if 'legacy_logits' in outputs else None
    if legacy_logits is not None:
        results['predictions'] = legacy_logits.argmax(dim=-1).cpu().numpy()

    return results

print("✓ Batch prediction function defined")

In [None]:
# Evaluate on full test set (teacher-forced: ground-truth tokens are the decoder input).
# Fetch each array from test_data once, outside the loop. Indexing an np.load()
# NpzFile re-reads and decompresses the whole member on every access, so slicing
# it per batch would reload the full arrays once per batch.
sensor_dim = config.get('sensor_dim', 128)
continuous_all = test_data['continuous'][:, :, :sensor_dim]  # stand-in for encoder embeddings
tokens_all = test_data['tokens']
ops_all = test_data['operation_type']

n_test = len(continuous_all)
batch_size = 32

all_preds = []
all_targets = []
all_ops_target = []

print(f"Evaluating {n_test} test samples...")

for start in tqdm(range(0, n_test, batch_size)):
    stop = start + batch_size  # slicing clamps at n_test
    token_batch = tokens_all[start:stop]
    op_batch = ops_all[start:stop]

    # Predict
    results = predict_batch(decoder, continuous_all[start:stop], op_batch, token_batch, device)

    all_preds.extend(results['predictions'].tolist())
    all_targets.extend(token_batch.tolist())
    all_ops_target.extend(op_batch.tolist())

print(f"\n✓ Evaluation complete")

In [None]:
# Token accuracy over non-padding positions. This is teacher-forced: each
# position is predicted given the ground-truth context, not by free generation.
correct = 0
total = 0

for preds, targets in zip(all_preds, all_targets):
    for pred, target in zip(preds, targets):
        if target == PAD_ID:
            continue  # padding is not scored
        total += 1
        if pred == target:
            correct += 1

token_accuracy = correct / total if total > 0 else 0.0

print("Test Set Results:")
print("="*50)
print(f"  Token Accuracy:      {token_accuracy:.2%}")
print(f"  Correct tokens:      {correct:,}")
print(f"  Total tokens:        {total:,}")
print("")
# The encoder is not run in this notebook, so operation accuracy is quoted,
# not measured; say so instead of printing it as a computed result.
print("  Operation Accuracy:  100.00% (reported for encoder; not measured here)")
print("")
print("Note: Operation classification is handled by the frozen")
print("      MM-DTAE-LSTM encoder (reported at 100% accuracy).")
print("      The encoder is not run here, so that figure is not re-measured.")

---
## 4. Token Decoding

In [None]:
def decode_tokens_to_gcode(token_ids, id_to_token):
    """
    Convert token IDs to a space-separated G-code string.

    Token types:
    - SPECIAL: PAD, BOS, EOS, UNK, MASK (dropped)
    - COMMAND: G0, G1, G2, G3, G53, M30 (any token starting with 'G' or 'M'
      is emitted verbatim)
    - PARAM: X, Y, Z, F, R, S, I, J, K (held until the next numeric token)
    - NUMERIC: NUM_<P>_<int>; the trailing integer is divided by 100 and
      emitted as ``<param><value:.2f>`` using the pending PARAM token.
      NOTE(review): assumes a fixed two-decimal scale; confirm against the
      vocabulary builder.

    IDs missing from ``id_to_token``, NUMERIC tokens with no pending PARAM,
    and unrecognised tokens are skipped. A NUMERIC token whose suffix is not
    an integer is emitted as ``<param>?``.

    Args:
        token_ids: iterable of token IDs
        id_to_token: mapping from token ID to token string

    Returns:
        The reconstructed G-code line (possibly empty).
    """
    special_tokens = {'PAD', 'BOS', 'EOS', 'UNK', 'MASK'}
    param_letters = {'X', 'Y', 'Z', 'F', 'R', 'S', 'I', 'J', 'K'}

    gcode_parts = []
    current_param = None

    for tid in token_ids:
        token = id_to_token.get(tid, f'<{tid}>')
        if token in special_tokens:
            continue
        if token.startswith(('G', 'M')):
            gcode_parts.append(token)
            current_param = None
        elif token in param_letters:
            current_param = token
        elif token.startswith('NUM_') and current_param:
            # Extract value from NUM_X_1234 format; only int() can fail here.
            try:
                value = int(token.split('_')[-1]) / 100.0
            except ValueError:
                gcode_parts.append(f"{current_param}?")
            else:
                gcode_parts.append(f"{current_param}{value:.2f}")
            current_param = None

    return ' '.join(gcode_parts)

print("✓ Token decoding function defined")

In [None]:
# Decode sample predictions
print("Sample Predictions vs Ground Truth:")
print("="*70)

for i in range(min(5, len(all_preds))):
    gt_tokens = all_targets[i]
    pred_tokens = all_preds[i]
    op_type = all_ops_target[i]

    gt_gcode = decode_tokens_to_gcode(gt_tokens, id_to_token)
    pred_gcode = decode_tokens_to_gcode(pred_tokens, id_to_token)

    # Per-sample accuracy over non-padding positions. Sample-local names keep
    # the test-set totals `correct` / `total` from being overwritten.
    n_correct = sum(1 for p, t in zip(pred_tokens, gt_tokens) if p == t and t != PAD_ID)
    n_scored = sum(1 for t in gt_tokens if t != PAD_ID)
    acc = n_correct / n_scored if n_scored > 0 else 0

    print(f"\nSample {i} (Op={op_type}, Acc={acc:.1%}):")
    print(f"  GT:   {gt_gcode}")
    print(f"  Pred: {pred_gcode}")

---
## 5. Per-Head Analysis

In [None]:
# Analyze per-head performance
@torch.no_grad()
def analyze_heads(decoder, sample_data, device, sensor_dim=128):
    """
    Collect per-position predictions and confidences for the auxiliary heads.

    Args:
        decoder: SensorMultiHeadDecoder model (called as-is; not put in eval mode here)
        sample_data: dict with 'continuous' (2-D per-timestep features),
            'tokens' ([L] token IDs used as decoder input) and
            'operation_type' (int) for one sample
        device: torch device
        sensor_dim: number of leading continuous features passed as the
            sensor input (defaults to 128, the previously hard-coded value)

    Returns:
        dict mapping 'type' / 'command' / 'param_type' to
        {'predictions': argmax class per position,
         'confidence': softmax probability of that class} (numpy arrays).
        Heads missing from the decoder output, or set to None, are skipped.
    """
    continuous = torch.FloatTensor(sample_data['continuous'][:, :sensor_dim]).unsqueeze(0).to(device)
    tokens = torch.LongTensor(sample_data['tokens']).unsqueeze(0).to(device)
    op_type = torch.LongTensor([sample_data['operation_type']]).to(device)

    outputs = decoder(tokens, continuous, op_type)

    # (decoder output key, analysis key). Per the model config defaults:
    # type = 4 classes (SPECIAL, COMMAND, PARAM, NUMERIC), command = 6, param_type = 10.
    heads = (
        ('type_logits', 'type'),
        ('command_logits', 'command'),
        ('param_type_logits', 'param_type'),
    )

    analysis = {}
    for out_key, name in heads:
        # Skip missing or None heads, as predict_single does.
        if out_key not in outputs or outputs[out_key] is None:
            continue
        probs = F.softmax(outputs[out_key], dim=-1)[0]  # drop batch axis
        analysis[name] = {
            'predictions': probs.argmax(dim=-1).cpu().numpy(),
            'confidence': probs.max(dim=-1).values.cpu().numpy(),
        }

    return analysis

# Analyze the first test sample.
sample = {key: test_data[key][0] for key in ('continuous', 'tokens', 'operation_type')}

head_analysis = analyze_heads(decoder, sample, device)

print("Per-Head Analysis:")
print("=" * 50)
for head_name, data in head_analysis.items():
    preview = data['predictions'][:10]
    mean_conf = data['confidence'].mean()
    print(f"\n{head_name.upper()} Head:")
    print(f"  Predictions: {preview}...")
    print(f"  Mean confidence: {mean_conf:.2%}")

In [None]:
# Visualize per-head confidence: one bar chart per head returned by analyze_heads.
n_plot_heads = len(head_analysis)
if n_plot_heads == 0:
    print("No per-head outputs to plot.")
else:
    # One panel per head present (not a fixed 3), so no empty axes are left over.
    # squeeze=False keeps `axes` 2-D even when only one head is present.
    fig, axes = plt.subplots(1, n_plot_heads, figsize=(5 * n_plot_heads, 4), squeeze=False)

    for ax, (head_name, data) in zip(axes[0], head_analysis.items()):
        conf = data['confidence']
        ax.bar(range(len(conf)), conf, color='steelblue', alpha=0.7)
        ax.axhline(conf.mean(), color='red', linestyle='--', label=f'Mean: {conf.mean():.1%}')
        ax.set_xlabel('Token Position')
        ax.set_ylabel('Confidence')
        ax.set_title(f'{head_name.title()} Head', fontweight='bold')
        ax.set_ylim(0, 1)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

---
## 6. Performance Benchmarks

In [None]:
# Benchmark inference speed (teacher-forced forward pass + argmax + copy to host).
batch_sizes = [1, 4, 8, 16, 32]
n_runs = 20  # timed repetitions per batch size
sensor_dim = config.get('sensor_dim', 128)
n_available = len(test_data['tokens'])
timing_results = []

print("Benchmarking inference speed...")
print("="*60)

# Loop variable is `bs` so the global `batch_size` used for evaluation is not overwritten.
for bs in batch_sizes:
    if bs > n_available:
        # Slicing past the end would silently time a smaller batch and
        # overstate throughput, so skip sizes the test split cannot fill.
        print(f"Batch {bs:3d}: skipped (only {n_available} test samples)")
        continue

    # Prepare batch (same raw-feature stand-in for encoder embeddings as above)
    continuous_batch = test_data['continuous'][:bs, :, :sensor_dim]
    token_batch = test_data['tokens'][:bs]
    op_batch = test_data['operation_type'][:bs]

    # Warm-up call, excluded from timing, to absorb one-time setup costs.
    _ = predict_batch(decoder, continuous_batch, op_batch, token_batch, device)

    # Timed runs. perf_counter is monotonic and high-resolution; predict_batch
    # copies its predictions to the CPU, so GPU work is complete when it returns.
    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        _ = predict_batch(decoder, continuous_batch, op_batch, token_batch, device)
        times.append((time.perf_counter() - start) * 1000)

    avg_time = np.mean(times)
    std_time = np.std(times)
    throughput = (bs / avg_time) * 1000  # samples per second

    timing_results.append({
        'batch_size': bs,
        'avg_time_ms': avg_time,
        'std_time_ms': std_time,
        'throughput': throughput
    })

    print(f"Batch {bs:3d}: {avg_time:6.2f} ± {std_time:.2f} ms | {throughput:6.1f} samples/sec")

In [None]:
# Visualize benchmarks: latency (mean ± std) and throughput per batch size.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

batch_sizes = [r['batch_size'] for r in timing_results]
latencies = [r['avg_time_ms'] for r in timing_results]
stds = [r['std_time_ms'] for r in timing_results]
throughputs = [r['throughput'] for r in timing_results]

# Latency
ax1 = axes[0]
ax1.errorbar(batch_sizes, latencies, yerr=stds, fmt='o-', capsize=5,
             color='steelblue', linewidth=2, markersize=8)
ax1.set_xlabel('Batch Size', fontsize=12)
ax1.set_ylabel('Latency (ms)', fontsize=12)
ax1.set_title('Inference Latency', fontsize=14, fontweight='bold')
ax1.grid(True, alpha=0.3)

# Throughput
ax2 = axes[1]
bars = ax2.bar([str(b) for b in batch_sizes], throughputs, color='coral', edgecolor='black')
ax2.set_xlabel('Batch Size', fontsize=12)
ax2.set_ylabel('Throughput (samples/sec)', fontsize=12)
ax2.set_title('Inference Throughput', fontsize=14, fontweight='bold')
ax2.grid(axis='y', alpha=0.3)

for bar, tp in zip(bars, throughputs):
    # Offset the label in points, not data units, so the gap above each bar
    # stays the same whatever the throughput scale.
    ax2.annotate(f'{tp:.0f}',
                 xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
                 xytext=(0, 3), textcoords='offset points',
                 ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.show()

---
## Summary

In this notebook, you learned:

1. **Model Loading**: Load decoder from checkpoint
2. **Single Inference**: Predict G-code for individual samples
3. **Batch Inference**: Process multiple samples efficiently
4. **Token Decoding**: Convert predictions to G-code strings
5. **Per-Head Analysis**: Analyze type, command, param_type predictions
6. **Performance**: Benchmark inference speed

### Key Results

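| Metric | Value |
|--------|-------|
| Operation Accuracy | **100%** (reported for the encoder; not measured in this notebook) |
| Token Accuracy | **~90.2%** (decoder, teacher-forced) |
| Single sample latency | ~5-15 ms (hardware-dependent) |
| Batch throughput | 100-500 samples/sec (hardware-dependent) |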

---

**Navigation:**
← [Previous: 03_training_models](03_training_models.ipynb) |
[Next: 05_api_usage](05_api_usage.ipynb) →

**Related:** [08_model_evaluation](08_model_evaluation.ipynb)