# Train Custom LLM from Scratch

This notebook trains a language model from scratch with configurable:
- **Model size** (125M to 7B parameters)
- **Model type** (Reasoning Agent, Code Assistant, General Purpose, etc.)
- **Dataset selection** (automatically matched to model type)

## Requirements
- **GPU**: A100 40GB+ (80GB recommended for 3B+)
- **Storage**: Google Drive for checkpoints
- **Time**: Varies by model size and type

---
## Step 0: Environment Setup

In [None]:
#@title ### 0.1 Mount Google Drive
import os
import sys

# The session counts as Colab when the `google.colab` package is already
# loaded in this interpreter; later cells branch on this flag.
IN_COLAB = 'google.colab' in sys.modules

if not IN_COLAB:
    print("Running locally")
else:
    # Imported lazily because the package only exists inside Colab.
    from google.colab import drive
    drive.mount('/content/drive')
    print("Google Drive mounted")

In [None]:
#@title ### 0.2 Clone Repository & Install Dependencies
REPO_URL = "https://github.com/rmarnold/llm-training-pipeline.git"  #@param {type:"string"}
BRANCH = "main"  #@param {type:"string"}

REPO_DIR = "/content/llm-training-pipeline"

if IN_COLAB:
    if os.path.exists(REPO_DIR):
        %cd {REPO_DIR}
        !git pull origin {BRANCH}
    else:
        !git clone -b {BRANCH} {REPO_URL} {REPO_DIR}
        %cd {REPO_DIR}
    
    print("\nInstalling dependencies...")
    !pip install -q -e ".[colab]"
    !pip install -q flash-attn --no-build-isolation 2>/dev/null || true
    !pip install -q liger-kernel bitsandbytes
    
    # Install GPU acceleration libraries (RAPIDS + NeMo Curator)
    print("\nInstalling GPU acceleration libraries...")
    
    # RAPIDS cuDF for GPU text cleaning (100-150x faster)
    !pip install -q --extra-index-url=https://pypi.nvidia.com cudf-cu12 2>/dev/null || echo "RAPIDS cuDF not available"
    
    # NeMo Curator for GPU deduplication (16-107x faster)
    # Requires dask-cuda for GPU cluster management
    !pip install -q dask-cuda 2>/dev/null || echo "dask-cuda not available"
    !pip install -q nemo-curator 2>/dev/null || echo "NeMo Curator not available"
    
    # Verify installations
    print("\n" + "=" * 50)
    print("GPU Library Verification:")
    print("=" * 50)
    
    # Check RAPIDS cuDF
    try:
        import cudf
        print(f"✓ RAPIDS cuDF: OK (version {cudf.__version__})")
    except ImportError as e:
        print(f"✗ RAPIDS cuDF: NOT AVAILABLE ({e})")
    
    # Check NeMo Curator - discover and print actual API structure
    print("\nNeMo Curator API Discovery:")
    nemo_ok = False
    nemo_api = None
    
    try:
        import nemo_curator
        nemo_version = getattr(nemo_curator, '__version__', 'unknown')
        print(f"  Package version: {nemo_version}")
        
        # List top-level modules
        top_level = [x for x in dir(nemo_curator) if not x.startswith('_')]
        print(f"  Top-level: {top_level}")
        
        # Check stages module structure
        if hasattr(nemo_curator, 'stages'):
            import nemo_curator.stages as stages
            stages_contents = [x for x in dir(stages) if not x.startswith('_')]
            print(f"  stages: {stages_contents}")
            
            # Check for deduplication in stages
            if hasattr(stages, 'deduplication'):
                import nemo_curator.stages.deduplication as dedup
                dedup_contents = [x for x in dir(dedup) if not x.startswith('_')]
                print(f"  stages.deduplication: {dedup_contents}")
                
                # Check for fuzzy module
                if hasattr(dedup, 'fuzzy'):
                    import nemo_curator.stages.deduplication.fuzzy as fuzzy
                    fuzzy_contents = [x for x in dir(fuzzy) if not x.startswith('_')]
                    print(f"  stages.deduplication.fuzzy: {fuzzy_contents}")
                    
                    # Check for workflow
                    if hasattr(fuzzy, 'workflow'):
                        import nemo_curator.stages.deduplication.fuzzy.workflow as workflow
                        workflow_contents = [x for x in dir(workflow) if not x.startswith('_')]
                        print(f"  stages.deduplication.fuzzy.workflow: {workflow_contents}")
                        
                        if 'FuzzyDeduplicationWorkflow' in workflow_contents:
                            nemo_ok = True
                            nemo_api = "stages.deduplication.fuzzy.workflow.FuzzyDeduplicationWorkflow"
                            print(f"  ✓ Found: {nemo_api}")
            
            # Also check for text module (alternative path)
            if hasattr(stages, 'text') and not nemo_ok:
                import nemo_curator.stages.text as text
                text_contents = [x for x in dir(text) if not x.startswith('_')]
                print(f"  stages.text: {text_contents}")
        
        # Check tasks module
        if hasattr(nemo_curator, 'tasks') and not nemo_ok:
            import nemo_curator.tasks as tasks
            tasks_contents = [x for x in dir(tasks) if not x.startswith('_')]
            print(f"  tasks: {tasks_contents}")
        
        # Try direct import as fallback
        if not nemo_ok:
            try:
                from nemo_curator.stages.deduplication.fuzzy.workflow import FuzzyDeduplicationWorkflow
                nemo_ok = True
                nemo_api = "workflow (direct import)"
                print(f"  ✓ Direct import succeeded: FuzzyDeduplicationWorkflow")
            except ImportError as e:
                print(f"  ✗ Direct import failed: {e}")
                
    except ImportError as e:
        print(f"  ✗ Cannot import nemo_curator: {e}")
    except Exception as e:
        print(f"  ✗ Error during discovery: {e}")
    
    # Check dask-cuda
    try:
        from dask_cuda import LocalCUDACluster
        dask_cuda_ok = True
        print(f"\n✓ dask-cuda: OK")
    except ImportError:
        dask_cuda_ok = False
        print(f"\n✗ dask-cuda: NOT AVAILABLE")
    
    # Summary
    print("\n" + "=" * 50)
    if nemo_ok:
        print(f"✓ NeMo Curator: READY (API: {nemo_api})")
    else:
        print("✗ NeMo Curator: Deduplication API not found")
        print("  Will use CPU datasketch fallback for deduplication")
    print("=" * 50)
    
    PROJECT_ROOT = REPO_DIR
else:
    PROJECT_ROOT = os.getcwd()

os.chdir(PROJECT_ROOT)
print(f"\nProject root: {PROJECT_ROOT}")

In [None]:
#@title ### 0.3 Choose Model Type & Size { run: "auto" }
#@markdown ---
#@markdown ### Model Configuration

# User-editable settings. model_type, model_size and data_prep_speed are used
# later in this cell as keys into MODEL_TYPE_CONFIGS, SIZE_CONFIGS and
# DATA_PREP_CONFIGS; the remaining values are copied into CONFIG directly.
model_type = "reasoning_agent"  #@param ["reasoning_agent", "code_assistant", "general_assistant", "chat_model"]
#@markdown **Model Types:**
#@markdown - `reasoning_agent`: Math, logic, function calling, tool use
#@markdown - `code_assistant`: Code generation, debugging, explanation
#@markdown - `general_assistant`: Balanced instruction following
#@markdown - `chat_model`: Conversational, helpful responses

model_size = "1b"  #@param ["125m", "350m", "1b", "3b", "7b"]
#@markdown **Model Sizes:**
#@markdown - `125m`: Fast training, testing (~2-4 hours pretrain)
#@markdown - `350m`: Small but capable (~4-8 hours)
#@markdown - `1b`: Good balance (~12-16 hours)
#@markdown - `3b`: Strong performance (~30-40 hours)
#@markdown - `7b`: Full capability (~60-80 hours)

#@markdown ---
#@markdown ### Training Parameters
# context_length is the per-sequence token count used to compute tokens per
# optimizer step; pretrain_tokens_billions is the total pretraining budget.
context_length = 2048  #@param {type:"integer"}
pretrain_tokens_billions = 50  #@param {type:"number"}
#@markdown *Recommended: 20x model params (1B model = 20B tokens minimum)*

#@markdown ---
#@markdown ### Data Preparation Settings
data_prep_speed = "fast"  #@param ["fast", "balanced", "thorough"]
#@markdown **Data Prep Quality Modes:**
#@markdown - `fast`: Skip quality filters, just clean + dedup (~5 min for 12M docs)
#@markdown - `balanced`: Basic quality filter, no toxicity (~30 min)
#@markdown - `thorough`: All quality filters + toxicity (~2-3 hours)

use_gpu_data_prep = True  #@param {type:"boolean"}
#@markdown **GPU Data Prep (Recommended for A100/H100):**
#@markdown - `True`: Use RAPIDS cuDF + NeMo Curator (10-30x faster)
#@markdown - `False`: Use CPU-only pipeline (slower but more compatible)

# ============================================================
# MODEL TYPE CONFIGURATIONS
# ============================================================
# One entry per selectable model type. Each entry lists the datasets named for
# the pretraining / SFT / DPO phases, the SFT data-prep script, the evaluation
# benchmarks, and an `sft_focus` tag that later cells branch on.

MODEL_TYPE_CONFIGS = dict(
    reasoning_agent=dict(
        description="Optimized for math, logic, and tool use",
        pretraining_datasets=[
            "slimpajama",
            "wikipedia",
            "the-stack-python",
            "openwebtext",
            "arxiv",
            "stackexchange",
        ],
        sft_datasets=[
            "gsm8k",
            "orca-math",
            "openorca",
            "cot-collection",
            "metamath",
            "glaive-function-calling",
            "hermes-function-calling",
            "toolbench",
            "logiqa",
        ],
        dpo_datasets=["hh-rlhf", "ultrafeedback"],
        data_prep_script="06_prepare_reasoning_data.py",
        eval_benchmarks=["gsm8k", "arc", "function_calling", "safety"],
        sft_focus="reasoning",
    ),
    code_assistant=dict(
        description="Optimized for code generation and understanding",
        pretraining_datasets=["the-stack", "slimpajama", "wikipedia", "arxiv"],
        sft_datasets=[
            "code-alpaca",
            "python-code-instructions",
            "evol-instruct-code",
            "glaive-function-calling",
            "oasst1",
        ],
        dpo_datasets=["hh-rlhf"],
        data_prep_script="prepare_lora_data.py",
        eval_benchmarks=["humaneval", "mbpp", "safety"],
        sft_focus="code",
    ),
    general_assistant=dict(
        description="Balanced instruction following",
        pretraining_datasets=["slimpajama", "wikipedia", "openwebtext", "pg19"],
        sft_datasets=["oasst1", "dolly-15k", "alpaca-cleaned", "openorca"],
        dpo_datasets=["hh-rlhf", "ultrafeedback"],
        data_prep_script="06_prepare_sft_data.py",
        eval_benchmarks=["mmlu", "hellaswag", "safety"],
        sft_focus="instruction",
    ),
    chat_model=dict(
        description="Conversational and helpful",
        pretraining_datasets=["slimpajama", "wikipedia", "openwebtext"],
        sft_datasets=["oasst1", "dolly-15k", "sharegpt"],
        dpo_datasets=["hh-rlhf", "ultrafeedback"],
        data_prep_script="06_prepare_sft_data.py",
        eval_benchmarks=["mt-bench", "safety"],
        sft_focus="chat",
    ),
)

# Data prep speed configurations.
# `flags` are appended to the CPU cleaning script's command line and
# `gpu_flags` to the GPU cleaning script's; `description` is shown to the user.
DATA_PREP_CONFIGS = {
    "fast": {
        # NOTE(review): the description says quality filters are skipped, but
        # the CPU flags pass --fast-quality (the GPU flags use --skip-quality).
        # Confirm whether the CPU script has a skip option.
        "flags": "--native-pipeline --fast-quality --no-toxicity --fresh",
        "gpu_flags": "--skip-quality --no-toxicity",
        "description": "Fastest: Skip quality filters, just clean + dedup (~5 min)",
    },
    "balanced": {
        # --no-toxicity added to the CPU flags so they agree with this mode's
        # description and with its gpu_flags; previously the CPU path still
        # ran the toxicity filter in "balanced" mode.
        "flags": "--native-pipeline --fast-quality --no-toxicity --fresh",
        "gpu_flags": "--fast-quality --no-toxicity",
        "description": "Balanced: Basic quality filter, no toxicity (~30 min)",
    },
    "thorough": {
        "flags": "--native-pipeline --fresh",
        "gpu_flags": "",
        "description": "Thorough: All quality filters + toxicity (~2-3 hours)",
    },
}

# ============================================================
# OPTIMIZED SIZE CONFIGURATIONS (A100 40GB/80GB)
# ============================================================
# Per-size training hyperparameters and time estimates, kept as one row per
# model size. Column order matches _SIZE_FIELDS.
_SIZE_FIELDS = (
    "batch_size",
    "grad_accum",
    "learning_rate",
    "warmup_ratio",
    "use_torch_compile",
    "use_8bit_optim",
    "dataloader_workers",
    "pretrain_hours",
    "sft_hours",
    "dpo_hours",
    "expected_throughput",
)

_SIZE_TABLE = (
    # size   batch accum lr    warmup compile 8bit  workers pre  sft  dpo   throughput
    ("125m", (64,  1,    1e-4, 0.05,  False,  False, 16,    3,   0.5, 0.25, "8-12 it/s")),
    ("350m", (32,  2,    1e-4, 0.05,  False,  False, 16,    6,   1,   0.5,  "5-8 it/s")),
    ("1b",   (16,  4,    3e-4, 0.03,  True,   True,  12,    15,  4,   1.5,  "3-5 it/s")),
    ("3b",   (8,   8,    3e-4, 0.03,  True,   True,  8,     35,  10,  4,    "1.5-2.5 it/s")),
    ("7b",   (4,   16,   3e-4, 0.03,  True,   True,  8,     70,  20,  8,    "0.8-1.2 it/s")),
)

SIZE_CONFIGS = {
    size_name: dict(zip(_SIZE_FIELDS, row))
    for size_name, row in _SIZE_TABLE
}

# Build configuration
# Reject values that would otherwise fail later with an opaque error
# (division by zero below) or silently request a zero-step training run.
if context_length <= 0:
    raise ValueError(f"context_length must be positive, got {context_length}")
if pretrain_tokens_billions <= 0:
    raise ValueError(
        f"pretrain_tokens_billions must be positive, got {pretrain_tokens_billions}"
    )

type_config = MODEL_TYPE_CONFIGS[model_type]
size_config = SIZE_CONFIGS[model_size]
data_prep_config = DATA_PREP_CONFIGS[data_prep_speed]

# Single dictionary read by every later cell.
CONFIG = {
    'model_type': model_type,
    'model_size': model_size,
    'context_length': context_length,
    'pretrain_tokens_b': pretrain_tokens_billions,
    'batch_size': size_config['batch_size'],
    'grad_accum': size_config['grad_accum'],
    'learning_rate': size_config['learning_rate'],
    'warmup_ratio': size_config['warmup_ratio'],
    'use_torch_compile': size_config['use_torch_compile'],
    'use_8bit_optim': size_config['use_8bit_optim'],
    'dataloader_workers': size_config['dataloader_workers'],
    'pretraining_datasets': type_config['pretraining_datasets'],
    'sft_datasets': type_config['sft_datasets'],
    'dpo_datasets': type_config['dpo_datasets'],
    'data_prep_script': type_config['data_prep_script'],
    'eval_benchmarks': type_config['eval_benchmarks'],
    'sft_focus': type_config['sft_focus'],
    'data_prep_flags': data_prep_config['flags'],
    'gpu_data_prep_flags': data_prep_config['gpu_flags'],
    'data_prep_speed': data_prep_speed,
    'use_gpu_data_prep': use_gpu_data_prep,
}

# Calculate training steps: one optimizer step consumes
# batch_size * grad_accum sequences of context_length tokens.
effective_batch = CONFIG['batch_size'] * CONFIG['grad_accum']
tokens_per_step = effective_batch * context_length
total_tokens = int(pretrain_tokens_billions * 1e9)
pretrain_steps = total_tokens // tokens_per_step
if pretrain_steps < 1:
    raise ValueError(
        f"Token budget ({total_tokens:,}) is smaller than one optimizer step "
        f"({tokens_per_step:,} tokens); increase pretrain_tokens_billions"
    )
CONFIG['pretrain_steps'] = pretrain_steps
CONFIG['effective_batch_size'] = effective_batch

# Estimate time: the per-size pretrain_hours figure is scaled linearly
# against a 50B-token run.
REFERENCE_TOKENS_B = 50
time_scale = pretrain_tokens_billions / REFERENCE_TOKENS_B
pretrain_hours = size_config['pretrain_hours'] * time_scale
total_hours = pretrain_hours + size_config['sft_hours'] + size_config['dpo_hours']

# Per-run folder on Drive, e.g. llm-1b-reasoning-agent.
DRIVE_BASE = f"/content/drive/MyDrive/llm-{model_size}-{model_type.replace('_', '-')}"

print("=" * 60)
print("MODEL CONFIGURATION")
print("=" * 60)
print(f"\nType: {model_type.upper().replace('_', ' ')}")
print(f"Size: {model_size.upper()}, Context: {context_length} tokens")
print(f"\nData Prep: {data_prep_speed.upper()} - {data_prep_config['description']}")
print(f"GPU Acceleration: {'ENABLED' if use_gpu_data_prep else 'DISABLED'}")
print(f"\nPretraining: {pretrain_tokens_billions}B tokens, {pretrain_steps:,} steps")
print(f"Expected time: ~{total_hours:.0f} hours ({total_hours/24:.1f} days)")
print("=" * 60)

In [None]:
#@title ### 0.4 Set Up Persistent Storage
# In Colab: create the per-run folder tree on Drive and replace the project's
# checkpoints/data/logs/evals directories with symlinks into it.
# Locally: just make sure those four directories exist.
import shutil
import subprocess

if IN_COLAB:
    print(f"Setting up storage at: {DRIVE_BASE}")

    # Create Drive directories
    for subdir in ['checkpoints', 'data', 'data/raw', 'data/packed', 'data/sft', 'data/dpo', 'logs', 'evals']:
        os.makedirs(os.path.join(DRIVE_BASE, subdir), exist_ok=True)

    # Create symlinks
    for dir_name in ['checkpoints', 'data', 'logs', 'evals']:
        local_path = os.path.join(PROJECT_ROOT, dir_name)
        drive_path = os.path.join(DRIVE_BASE, dir_name)

        if os.path.islink(local_path):
            # Stale (possibly dangling) link from an earlier session.
            os.unlink(local_path)
        elif os.path.isdir(local_path):
            # Move existing contents to Drive before removing the directory.
            # `<dir>/.` also copies dotfiles, and check=True aborts the cell
            # before anything is deleted if the copy fails. The previous
            # `cp ... || true; rm -rf` removed the data even when the copy
            # had failed.
            subprocess.run(
                ["cp", "-r", os.path.join(local_path, "."), drive_path],
                check=True,
            )
            shutil.rmtree(local_path)
        elif os.path.exists(local_path):
            # A plain file is in the way: keep a copy on Drive, then remove it.
            shutil.copy2(local_path, drive_path)
            os.remove(local_path)

        os.symlink(drive_path, local_path)
        print(f"  {dir_name} -> Drive")

    print("\nStorage ready!")
else:
    for d in ['checkpoints', 'data', 'logs', 'evals']:
        os.makedirs(d, exist_ok=True)

In [None]:
#@title ### 0.5 Copy Data to Local SSD (Faster I/O)
#@markdown Copies data from Google Drive to local NVMe SSD for 5-10x faster I/O during training.
#@markdown This significantly speeds up data prep and training.

import shutil

# Local SSD paths (much faster than Drive)
LOCAL_DATA = "/content/local_data"
LOCAL_RAW = f"{LOCAL_DATA}/raw"
LOCAL_PROCESSED = f"{LOCAL_DATA}/processed"
LOCAL_PACKED = f"{LOCAL_DATA}/packed"
LOCAL_CACHE = "/content/.gpu_cache"
LOCAL_CHECKPOINTS = "/content/local_checkpoints"

# Drive paths (persistent)
DRIVE_RAW = f"{DRIVE_BASE}/data/raw"
DRIVE_PACKED = f"{DRIVE_BASE}/data/packed"
DRIVE_CHECKPOINTS = f"{DRIVE_BASE}/checkpoints"

if IN_COLAB:
    print("Setting up local SSD storage for faster I/O...")
    print("=" * 50)
    
    # Create local directories
    for d in [LOCAL_DATA, LOCAL_RAW, LOCAL_PROCESSED, LOCAL_PACKED, LOCAL_CACHE, LOCAL_CHECKPOINTS]:
        os.makedirs(d, exist_ok=True)
    
    # Check if we have existing data on Drive to copy
    drive_raw_files = []
    if os.path.exists(DRIVE_RAW):
        drive_raw_files = [f for f in os.listdir(DRIVE_RAW) if f.endswith('.parquet')]
    
    if drive_raw_files:
        print(f"\nFound {len(drive_raw_files)} raw data files on Drive")
        print("Copying to local SSD for faster processing...")
        for f in drive_raw_files:
            src = os.path.join(DRIVE_RAW, f)
            dst = os.path.join(LOCAL_RAW, f)
            if not os.path.exists(dst):
                size_mb = os.path.getsize(src) / (1024*1024)
                print(f"  Copying {f} ({size_mb:.0f} MB)...")
                shutil.copy2(src, dst)
            else:
                print(f"  {f} already on local SSD")
        print("Raw data copied!")
    else:
        print("\nNo existing raw data on Drive (will download fresh)")
    
    # Check for existing packed data
    drive_packed_files = []
    if os.path.exists(DRIVE_PACKED):
        drive_packed_files = list(os.listdir(DRIVE_PACKED))
    
    if drive_packed_files:
        print(f"\nFound packed data on Drive - copying to local SSD...")
        !cp -r {DRIVE_PACKED}/* {LOCAL_PACKED}/ 2>/dev/null || true
        print("Packed data copied!")
    
    # Store paths in CONFIG for later use
    CONFIG['local_raw'] = LOCAL_RAW
    CONFIG['local_processed'] = LOCAL_PROCESSED
    CONFIG['local_packed'] = LOCAL_PACKED
    CONFIG['local_cache'] = LOCAL_CACHE
    CONFIG['local_checkpoints'] = LOCAL_CHECKPOINTS
    CONFIG['drive_raw'] = DRIVE_RAW
    CONFIG['drive_packed'] = DRIVE_PACKED
    CONFIG['drive_checkpoints'] = DRIVE_CHECKPOINTS
    CONFIG['use_local_ssd'] = True
    
    # Check local disk space
    import subprocess
    result = subprocess.run(['df', '-h', '/content'], capture_output=True, text=True)
    print(f"\nLocal SSD status:")
    print(result.stdout.split('\n')[1])
    
    print("\nLocal SSD paths configured:")
    print(f"  Raw data: {LOCAL_RAW}")
    print(f"  Processed: {LOCAL_PROCESSED}")
    print(f"  Packed: {LOCAL_PACKED}")
    print(f"  Cache: {LOCAL_CACHE}")
    print("=" * 50)
else:
    CONFIG['use_local_ssd'] = False
    print("Running locally - no SSD optimization needed")

In [None]:
#@title ### 0.6 Check GPU
import torch

# Records CONFIG['use_fp8'] and warns when the GPU looks too small for the
# selected model size.
if not torch.cuda.is_available():
    print("No GPU detected!")
    CONFIG['use_fp8'] = False
else:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    capability = torch.cuda.get_device_capability()

    # FP8 is enabled only for compute capability major version 9 or higher.
    compute_major = capability[0]
    CONFIG['use_fp8'] = compute_major >= 9

    fp8_status = 'Available' if CONFIG['use_fp8'] else 'Not available'
    print(f"GPU: {gpu_name} ({gpu_memory:.0f} GB)")
    print(f"FP8: {fp8_status}")

    # Memory check for model size (thresholds in GB)
    min_memory = {'125m': 8, '350m': 16, '1b': 24, '3b': 40, '7b': 70}
    required_gb = min_memory[CONFIG['model_size']]
    if gpu_memory < required_gb:
        print(f"\nWARNING: {CONFIG['model_size']} may need {required_gb}+ GB")

---
## Step 1: Download Data

Downloads datasets matched to your model type.

In [None]:
#@title ### 1.1 Download Pretraining Data
#@markdown Downloads pretraining corpora for your model type.

# The dataset list printed here is informational only: the download script is
# invoked with just `--phases pretraining`, and no CONFIG value is passed to it.
print(f"Downloading pretraining data for: {CONFIG['model_type']}")
print(f"Datasets: {CONFIG['pretraining_datasets']}")
print("="*50)

!python scripts/01_download_data.py --phases pretraining

---
## Step 2: Prepare Data

In [None]:
#@title ### 2.1 Clean & Tokenize Pretraining Data
#@markdown Uses the data prep speed setting from configuration above.
#@markdown GPU mode uses RAPIDS cuDF (150x faster text cleaning) + NeMo Curator (16x faster dedup).
#@markdown All processing happens on local SSD with incremental backup to Drive.

use_gpu = CONFIG.get('use_gpu_data_prep', False)
data_prep_speed = CONFIG['data_prep_speed']
use_local_ssd = CONFIG.get('use_local_ssd', False)

print(f"Data prep mode: {data_prep_speed.upper()}")
print(f"GPU Acceleration: {'ENABLED' if use_gpu else 'DISABLED'}")
print(f"Local SSD: {'ENABLED (5-10x faster I/O)' if use_local_ssd else 'DISABLED'}")
print("=" * 50)

if use_local_ssd:
    local_raw = CONFIG['local_raw']
    local_processed = CONFIG['local_processed']
    local_packed = CONFIG['local_packed']
    local_cache = CONFIG['local_cache']
    drive_packed = CONFIG['drive_packed']
    
    # Create backup dir on Drive for incremental checkpoint
    drive_cache = f"{DRIVE_BASE}/data/.gpu_cache"
    os.makedirs(drive_cache, exist_ok=True)
    CONFIG['drive_cache'] = drive_cache
    
    if use_gpu:
        # GPU-accelerated pipeline on local SSD with Drive backup
        gpu_flags = CONFIG['gpu_data_prep_flags']
        print(f"Using GPU pipeline on local SSD...")
        print(f"  Input: {local_raw}")
        print(f"  Output: {local_processed}")
        print(f"  Cache: {local_cache}")
        print(f"  Backup: {drive_cache} (incremental sync)")
        !python scripts/02_gpu_clean_deduplicate.py \
            --input {local_raw} \
            --output {local_processed} \
            --cache {local_cache} \
            --backup-dir {drive_cache} \
            {gpu_flags}
    else:
        # CPU pipeline on local SSD
        data_prep_flags = CONFIG['data_prep_flags']
        print(f"Using CPU pipeline on local SSD...")
        !python scripts/02_clean_deduplicate_optimized.py \
            --input {local_raw} \
            --output {local_processed} \
            {data_prep_flags}
    
    # Tokenize and pack on local SSD
    print(f"\nTokenizing and packing...")
    !python scripts/03_tokenize_and_pack.py \
        --input-dir {local_processed} \
        --output-dir {local_packed}
    
    # Backup packed data to Drive for persistence
    print(f"\nBacking up packed data to Drive...")
    !cp -r {local_packed}/* {drive_packed}/ 2>/dev/null || true
    print("Backup complete!")
    
else:
    # Original paths (Drive-linked)
    if use_gpu:
        gpu_flags = CONFIG['gpu_data_prep_flags']
        print("Using GPU pipeline...")
        !python scripts/02_gpu_clean_deduplicate.py --input data/raw --output data/processed {gpu_flags}
    else:
        data_prep_flags = CONFIG['data_prep_flags']
        print(f"Using CPU pipeline with flags: {data_prep_flags}")
        !python scripts/02_clean_deduplicate_optimized.py {data_prep_flags}
    
    !python scripts/03_tokenize_and_pack.py

In [None]:
#@title ### 1.2 Download SFT & DPO Data
#@markdown Downloads fine-tuning data matched to your model type.

print(f"Downloading SFT data for: {CONFIG['model_type']}")
print(f"SFT datasets: {CONFIG['sft_datasets']}")
print(f"DPO datasets: {CONFIG['dpo_datasets']}")
print("="*50)

# Download based on model type
if CONFIG['sft_focus'] == 'reasoning':
    !python scripts/01_download_data.py --phases reasoning function_calling logic instruction_tuning preference_data
elif CONFIG['sft_focus'] == 'code':
    !python scripts/01_download_data.py --phases instruction_tuning preference_data
    # Code data handled by prepare_lora_data.py
else:
    !python scripts/01_download_data.py --phases instruction_tuning preference_data

In [None]:
#@title ### 2.2 Prepare SFT Data
#@markdown Uses the appropriate script for your model type.

# CONFIG['data_prep_script'] comes from the selected model type's entry in
# MODEL_TYPE_CONFIGS; the script is run from scripts/ with no arguments.
prep_script = CONFIG['data_prep_script']
print(f"Using: {prep_script} (optimized for {CONFIG['model_type']})")
print("="*50)

!python scripts/{prep_script}

In [None]:
#@title ### 2.3 Prepare DPO Data

# Runs the DPO data preparation script with no arguments.
!python scripts/08_prepare_dpo_data.py

---
## Step 3: Initialize Model

In [None]:
#@title ### 3.1 Initialize Model

# Size and context length are read from CONFIG and forwarded to the init
# script as --size / --context-length.
model_size = CONFIG['model_size']
context = CONFIG['context_length']

print(f"Initializing {model_size.upper()} model with {context} context...")

!python scripts/04_init_model.py --size {model_size} --context-length {context}

In [None]:
#@title ### 3.2 Verify Setup
import os


def path_ready(path):
    """Return True if `path` exists and, when it is a directory, is non-empty.

    A bare existence test is not enough here: the storage setup cell
    pre-creates directories such as data/packed, so they exist before any
    data has been produced.
    """
    if not os.path.exists(path):
        return False
    if os.path.isdir(path):
        with os.scandir(path) as entries:
            return next(entries, None) is not None
    return True


# (label, path relative to the project root) for each artifact needed
# before training.
checks = [
    ('Model', 'checkpoints/init'),
    ('Tokenizer', 'configs/tokenizer'),
    ('Pretrain data', 'data/packed'),
    ('SFT data', 'data/sft/train'),
    ('DPO data', 'data/dpo/train'),
]

print("Setup verification:")
all_ok = True
for name, path in checks:
    ok = path_ready(path)
    print(f"  {name}: {'OK' if ok else 'MISSING'}")
    all_ok = all_ok and ok

if all_ok:
    print("\nReady for training!")
else:
    print("\nFix missing items before proceeding.")

---
## Step 4: Pretraining

In [None]:
#@title ### 4.1 Start Pretraining
#@markdown Uses auto-optimized parameters based on model size.
#@markdown Training uses local SSD for data loading (5-10x faster I/O).
#@markdown Checkpoints are saved locally and backed up to Drive periodically.

steps = CONFIG['pretrain_steps']
use_fp8 = CONFIG.get('use_fp8', False)
use_local_ssd = CONFIG.get('use_local_ssd', False)

# Build optimized command with size-specific parameters
cmd = f"python scripts/05_pretrain.py --max_steps {steps}"

# Local SSD paths for faster I/O
if use_local_ssd:
    local_packed = CONFIG['local_packed']
    local_checkpoints = CONFIG['local_checkpoints']
    cmd += f" --train_data_path {local_packed}"
    cmd += f" --output_dir {local_checkpoints}/pretrain"

# Always use Liger Kernel (compatible with all sizes)
cmd += " --use-liger-kernel"

# Size-specific optimizations
cmd += f" --per_device_train_batch_size {CONFIG['batch_size']}"
cmd += f" --gradient_accumulation_steps {CONFIG['grad_accum']}"
cmd += f" --learning_rate {CONFIG['learning_rate']}"
cmd += f" --warmup_ratio {CONFIG['warmup_ratio']}"
cmd += f" --dataloader_num_workers {CONFIG['dataloader_workers']}"

# torch.compile: Enable only for larger models where benefit > overhead
if not CONFIG['use_torch_compile']:
    cmd += " --no-torch-compile"

# 8-bit optimizer: Enable for larger models to save memory
if CONFIG['use_8bit_optim']:
    cmd += " --optim adamw_bnb_8bit"
else:
    cmd += " --optim adamw_torch_fused"

# FP8 for H100
if use_fp8:
    cmd += " --fp8"

# OOM recovery (always useful)
cmd += " --enable-oom-recovery"

print(f"Model: {CONFIG['model_size'].upper()} {CONFIG['model_type']}")
print(f"Steps: {steps:,}")
print(f"\nOptimized Parameters:")
print(f"  Batch: {CONFIG['batch_size']} x {CONFIG['grad_accum']} = {CONFIG['effective_batch_size']}")
print(f"  Learning rate: {CONFIG['learning_rate']:.0e}")
print(f"  torch.compile: {'ON' if CONFIG['use_torch_compile'] else 'OFF'}")
print(f"  8-bit optimizer: {'ON' if CONFIG['use_8bit_optim'] else 'OFF'}")
if use_local_ssd:
    print(f"\nLocal SSD I/O:")
    print(f"  Data: {CONFIG['local_packed']}")
    print(f"  Checkpoints: {CONFIG['local_checkpoints']}/pretrain")
print(f"\nCommand: {cmd}")
print("=" * 50)

!{cmd}

# Backup checkpoints to Drive after training
if use_local_ssd:
    drive_checkpoints = CONFIG['drive_checkpoints']
    local_checkpoints = CONFIG['local_checkpoints']
    print(f"\nBacking up checkpoints to Drive...")
    !cp -r {local_checkpoints}/pretrain/* {drive_checkpoints}/pretrain/ 2>/dev/null || true
    print("Checkpoint backup complete!")

---
## Step 5: SFT Training

In [None]:
#@title ### 5.1 Start SFT

cmd = "python scripts/07_sft.py --use-liger-kernel --enable-oom-recovery"
if CONFIG.get('use_fp8', False):
    cmd += " --fp8"

print(f"SFT focus: {CONFIG['sft_focus']}")
print(f"Datasets: {CONFIG['sft_datasets']}")
print("="*50)

!{cmd}

---
## Step 6: DPO Alignment

In [None]:
#@title ### 6.1 Start DPO

cmd = "python scripts/09_dpo.py --enable-oom-recovery"
if CONFIG.get('use_fp8', False):
    cmd += " --fp8"

!{cmd}

---
## Step 7: Evaluation

In [None]:
#@title ### 7.1 Run Evaluation

# The benchmark list printed here is informational: the evaluation script is
# given only the checkpoint path checkpoints/dpo_final.
print(f"Benchmarks for {CONFIG['model_type']}: {CONFIG['eval_benchmarks']}")
print("="*50)

!python scripts/11_evaluate.py checkpoints/dpo_final

In [None]:
#@title ### 7.2 Check Gates

# Runs the gate-check script with the single argument `dpo`.
!python scripts/12_check_gates.py dpo

---
## Step 8: Test Model

In [None]:
#@title ### 8.1 Load Model
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

MODEL_PATH = "checkpoints/dpo_final"

tokenizer = AutoTokenizer.from_pretrained("configs/tokenizer")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
print("Model loaded!")

In [None]:
#@title ### 8.2 Generate Response

def generate(prompt, max_tokens=512):
    """Sample a completion for `prompt` and return only the newly generated text.

    The prompt is wrapped in the `<|user|>` / `<|assistant|>` chat template
    before generation. Uses the `tokenizer` and `model` loaded in the
    previous cell.

    Args:
        prompt: User message to send to the model.
        max_tokens: Upper bound on the number of newly generated tokens.

    Returns:
        The decoded completion, without the prompt and with special tokens
        removed.
    """
    formatted = f"<|user|>\n{prompt}\n<|assistant|>\n"
    inputs = tokenizer(formatted, return_tensors="pt").to(model.device)
    # A causal LM's generate() output begins with the input tokens, so record
    # the prompt length and slice it off before decoding. Decoding the whole
    # sequence echoed the prompt back as part of the "response".
    prompt_len = inputs["input_ids"].shape[1]
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=max_tokens,
                                 temperature=0.7, do_sample=True, top_p=0.9)
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

# One smoke-test prompt per model type
TEST_PROMPTS = dict(
    reasoning_agent="Solve step by step: If a train travels 60 mph for 2.5 hours, how far does it go?",
    code_assistant="Write a Python function to find the nth Fibonacci number.",
    general_assistant="Explain the difference between machine learning and deep learning.",
    chat_model="Hello! How are you today? What can you help me with?",
)

# Unknown model types fall back to a generic greeting.
prompt = TEST_PROMPTS.get(CONFIG['model_type'], "Hello, how are you?")
print(f"Test for {CONFIG['model_type']}:\nPrompt: {prompt}\n\nResponse:")
print(generate(prompt))

In [None]:
#@title ### 8.3 Custom Prompt

# Free-form prompt; it is sent through generate() defined in cell 8.2.
CUSTOM_PROMPT = """Your prompt here"""  #@param {type:"string"}

print(generate(CUSTOM_PROMPT))

---
## Training Complete!

Your model is trained and ready to use. Check the evaluation results above to see how it performs on relevant benchmarks.