# LAB Multi-Phase Training with Training Hub

This notebook demonstrates how to perform LAB (Large-scale Alignment for chatBots) multi-phase training using the `training_hub` library. We'll walk through the two-phase LAB training process:

1. **Phase 1 - Knowledge Tuning (Phase07)**: Training on knowledge-heavy data to build foundational understanding
2. **Phase 2 - Skills + Replay Training (Phase10)**: Training on skills data with replay of both Phase07 knowledge data AND the base model's original instruction tuning data to maintain all capabilities

This LAB multi-phase approach is designed for instruction tuning: you first establish additional knowledge foundations, then add task-specific skills. Replay in the second phase guards against forgetting the newly learned knowledge and preserves the base model's original instruction-following capabilities.

## Setup and Imports

First, let's import the necessary libraries and set up our training environment.

In [None]:
# Import training_hub for SFT training
from training_hub import sft

# Standard library imports
import os
import time
import logging
import sys
from datetime import datetime
from contextlib import redirect_stdout, redirect_stderr
from io import StringIO

## Logging Configuration

Set up logging to prevent notebook crashes from excessive output while still showing essential progress and error information.

**Note:** This notebook breaks down each step and contains the end-to-end pipeline. For long-running jobs, we also provide an example script, which offers better reproducibility, flexibility, and consistent logging if the notebook disconnects. You can find it at `scripts/lab_multiphase_training.py`.

**Quick script usage:**
```bash
python scripts/lab_multiphase_training.py \
  --base-model-path /path/to/model \
  --phase07-data-path /path/to/knowledge.jsonl \
  --phase10-data-path /path/to/skills_replay.jsonl \
  --ckpt-output-base-dir /path/to/checkpoints
```

In [None]:
# Configure logging to prevent notebook crashes from excessive output
# while still showing essential progress and error information

def setup_training_logging():
    """Set up logging configuration optimized for notebook environments."""
    # Reduce logging level for common noisy loggers
    logging.getLogger("transformers").setLevel(logging.WARNING)
    logging.getLogger("torch").setLevel(logging.WARNING)
    logging.getLogger("accelerate").setLevel(logging.WARNING)
    
    # Set up a custom logger that shows progress without overwhelming the notebook
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    
    print("✅ Logging configured for notebook environment")

def run_training_with_managed_output(training_func, description="Training"):
    """
    Run training with balanced output showing progress without overwhelming the notebook.
    Shows essential progress, errors, and key milestones while filtering excessive logs.
    """
    print(f"🚀 Starting {description}...")
    print("📝 Showing essential progress and key training milestones")
    print("⏳ This may take a while. Training progress will appear below:")
    print("-" * 60)
    
    start_time = time.time()
    
    try:
        # Run training without redirecting output so subprocess logs stay visible,
        # but reduce the verbosity of the chattiest components
        # (os is already imported in the setup cell above)
        
        # Set environment variables to reduce some verbose output
        old_env = {}
        env_settings = {
            'TRANSFORMERS_VERBOSITY': 'warning',
            'TOKENIZERS_PARALLELISM': 'false',  # Reduces tokenizer warnings
        }
        
        for key, value in env_settings.items():
            old_env[key] = os.environ.get(key)
            os.environ[key] = value
        
        try:
            result = training_func()
        finally:
            # Restore environment
            for key, old_value in old_env.items():
                if old_value is None:
                    os.environ.pop(key, None)
                else:
                    os.environ[key] = old_value
        
        end_time = time.time()
        duration = end_time - start_time
        
        print("-" * 60)
        print(f"✅ {description} completed successfully!")
        print(f"⏱️  Duration: {duration/3600:.2f} hours")
        
        return result
        
    except Exception as e:
        end_time = time.time()
        duration = end_time - start_time
        
        print("-" * 60)
        print(f"❌ {description} failed after {duration/60:.1f} minutes")
        print(f"Error: {e}")
        print("\n💡 The error may have occurred in the distributed training subprocess.")
        print("   Check the training logs above for more context about the failure.")
        print("   Common issues include: data path problems, memory issues, or model loading errors.")
        
        raise

# Set up logging
setup_training_logging()

## Configuration

Let's define our training configuration. You'll need to adjust these paths to match your environment.

In [None]:
# LAB Multi-Phase Training Configuration
experiment_prefix = "lab_multiphase_training_demo"
ckpt_output_base_dir = "/path/to/your/checkpoints"  # Update this path

# Model and data paths - Update these to your actual paths
base_model_path = "/path/to/your/base/model"  # e.g., granite-3.1-8b-starter-v2.1
phase07_data_path = "/path/to/knowledge_data.jsonl"  # Knowledge/facts data for Phase07
phase10_data_path = "/path/to/skills_plus_replay_data.jsonl"  # Skills + replay data for Phase10
# Note: Phase10 data should include:
# - New skills/task data
# - Replay of Phase07 knowledge data  
# - Replay of base model's original instruction tuning data

# Training hyperparameters
max_tokens_per_gpu = 25_000  # Memory limit per GPU (reduce if hitting OOM errors)
max_seq_len = 20_000         # Maximum sequence length

# Distributed training setup (adjust for your hardware)
nproc_per_node = 8  # Number of GPUs per node
nnodes = 1          # Number of nodes
node_rank = 0       # This node's rank
rdzv_id = 47        # Rendezvous ID
rdzv_endpoint = "0.0.0.0:12345"  # Master endpoint (host:port); for multi-node runs, use the master node's reachable address

print(f"LAB Multi-Phase Experiment: {experiment_prefix}")
print(f"Output directory: {ckpt_output_base_dir}")
print(f"GPUs per node: {nproc_per_node}")
print(f"Max tokens per GPU: {max_tokens_per_gpu:,}")
print(f"\nData composition:")
print(f"  Phase07: Knowledge data only")
print(f"  Phase10: Skills + Phase07 replay + Base model instruction replay")
print(f"\n💡 Note: If you encounter OOM (Out of Memory) errors, reduce max_tokens_per_gpu")

## Phase 1: Knowledge Tuning (Phase07)

In Phase07, we train the model on knowledge-heavy data to establish foundational understanding. This phase focuses on factual information, domain knowledge, and core concepts that the model needs to master before learning specific skills.

In [None]:
# Phase07 (Knowledge Tuning) configuration
experiment_prefix_phase07 = experiment_prefix + "_phase07"
experiment_name_phase07 = experiment_prefix_phase07 + "_" + datetime.now().strftime("%Y%m%d_%H%M%S")
phase07_ckpt_output_dir = os.path.join(ckpt_output_base_dir, experiment_prefix_phase07)

print(f"Phase07 (Knowledge Tuning) Configuration:")
print(f"  Experiment name: {experiment_name_phase07}")
print(f"  Input model: {base_model_path}")
print(f"  Data path: {phase07_data_path}")
print(f"  Output directory: {phase07_ckpt_output_dir}")

In [None]:
# Run Phase07 training with managed output
def phase07_training():
    """Execute Phase07 training with all parameters."""
    return sft(
        # Required parameters
        model_path=base_model_path,           # Path to the model to fine-tune
        data_path=phase07_data_path,          # Path to the training data
        ckpt_output_dir=phase07_ckpt_output_dir,  # Directory to save checkpoints
        
        # Core training parameters
        num_epochs=7,                         # Number of training epochs
        effective_batch_size=128,             # Effective batch size for training (smaller due to smaller knowledge dataset)
        learning_rate=2e-5,                   # Learning rate for training
        max_seq_len=max_seq_len,              # Maximum sequence length
        max_tokens_per_gpu=max_tokens_per_gpu, # Maximum tokens per GPU in a mini-batch (hard-cap for memory to avoid OOMs)
        
        # Data and checkpointing parameters
        data_output_dir="/dev/shm",           # Directory to save processed data (using RAM for faster data processing)
        warmup_steps=0,                       # Number of warmup steps
        save_samples=0,                       # Number of samples to save after training (0 disables saving based on sample count)
        checkpoint_at_epoch=True,             # Whether to checkpoint at each epoch (default value, shown for clarity)
        accelerate_full_state_at_epoch=False, # Whether to save full state at epoch for automatic checkpoint resumption (override default to save space)
        
        # Distributed training parameters
        nproc_per_node=nproc_per_node,        # Number of processes (GPUs) per node for distributed training
        nnodes=nnodes,                        # Total number of nodes for distributed training
        node_rank=node_rank,                  # Rank of this node (0 to nnodes-1) for distributed training
        rdzv_id=rdzv_id,                      # Unique job ID for rendezvous in distributed training
        rdzv_endpoint=rdzv_endpoint,          # Master node endpoint for multi-node training
    )

# Execute Phase07 training with managed output to prevent notebook crashes
try:
    result = run_training_with_managed_output(phase07_training, "Phase07 (Knowledge Tuning)")
    print("🎯 Phase07 training completed successfully!")
except Exception as e:
    print(f"💥 Phase07 training failed: {e}")
    print("🔍 Check the error details above for troubleshooting")
    raise

## Checkpoint Discovery

After Phase07 completes, we need to find the most recent checkpoint to use as input for Phase10.
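
The same lookup can be expressed as a small reusable helper. This is a minimal sketch, assuming each checkpoint is saved as a subdirectory of `hf_format` and that the most recently modified directory is the one you want. The function name `find_latest_checkpoint` is hypothetical and not part of `training_hub`.

```python
import os
from typing import Optional


def find_latest_checkpoint(hf_format_dir: str) -> Optional[str]:
    """Return the most recently modified checkpoint subdirectory, or None.

    Hypothetical helper: assumes checkpoints are subdirectories of
    `hf_format_dir` and that modification time reflects training order.
    """
    if not os.path.isdir(hf_format_dir):
        return None

    # Only directories count as checkpoints; stray files are ignored.
    candidates = [entry.path for entry in os.scandir(hf_format_dir) if entry.is_dir()]
    if not candidates:
        return None

    # Modification time is used instead of st_ctime, which on Linux is the
    # inode-change time rather than a creation time.
    return max(candidates, key=os.path.getmtime)
```

Returning `None` instead of raising lets the caller decide whether a missing checkpoint should stop the pipeline or just print a warning.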

In [None]:
# Find the most recent checkpoint from Phase07
phase07_checkpoint_location = f"{phase07_ckpt_output_dir}/hf_format"

print(f"Looking for Phase07 checkpoints in: {phase07_checkpoint_location}")

if not os.path.exists(phase07_checkpoint_location):
    print(f"❌ Checkpoint directory not found: {phase07_checkpoint_location}")
    print("   Make sure Phase07 completed successfully")
else:
    phase07_checkpoints = os.listdir(phase07_checkpoint_location)
    
    if not phase07_checkpoints:
        print(f"❌ No checkpoints found in {phase07_checkpoint_location}")
    else:
        print(f"Found {len(phase07_checkpoints)} checkpoint(s):")
        for ckpt in phase07_checkpoints:
            print(f"  - {ckpt}")
        
        # Find the most recently modified checkpoint directory
        # (st_ctime is inode-change time on Linux, not creation time, so use mtime)
        most_recent_checkpoint, most_recent_time = None, 0
        
        for checkpoint in phase07_checkpoints:
            full_ckpt_path = os.path.join(phase07_checkpoint_location, checkpoint)
            if os.path.isdir(full_ckpt_path):
                ckpt_time = os.path.getmtime(full_ckpt_path)
                if ckpt_time > most_recent_time:
                    most_recent_checkpoint = full_ckpt_path
                    most_recent_time = ckpt_time
        
        if most_recent_checkpoint:
            print(f"\n✅ Most recent Phase07 checkpoint: {most_recent_checkpoint}")
            print(f"   Last modified: {datetime.fromtimestamp(most_recent_time)}")
        else:
            print("❌ No valid checkpoint directories found")

## Phase 2: Skills + Replay Training (Phase10)

In Phase10, we continue training from the Phase07 checkpoint using a comprehensive dataset that includes:

1. **New Skills Data**: Task instructions, problem-solving examples, and specific capabilities
2. **Phase07 Knowledge Replay**: Replay of the knowledge data from Phase07 to prevent knowledge forgetting  
3. **Base Model Instruction Replay**: Replay of the base model's original instruction tuning data to preserve foundational instruction-following capabilities

This comprehensive replay strategy ensures that the model maintains both its original instruction-following abilities and the newly acquired knowledge from Phase07, while learning new skills in Phase10.
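
If your three data sources live in separate JSONL files, one way to assemble the single Phase10 file is sketched below. This is an illustration only: the helper name `build_phase10_mixture` is hypothetical, and it uses plain concatenation plus a seeded shuffle. The notebook does not prescribe mixing ratios, so any balancing between the sources is left to you.

```python
import json
import random


def build_phase10_mixture(source_paths, output_path, seed=42):
    """Concatenate several JSONL files into one shuffled JSONL file.

    Hypothetical helper: no sampling or re-weighting is applied, every
    record from every source is kept exactly once.
    Returns the number of records written.
    """
    records = []
    for path in source_paths:
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:  # skip blank lines
                    records.append(json.loads(line))

    # A seeded shuffle keeps the mixture reproducible between runs.
    random.Random(seed).shuffle(records)

    with open(output_path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
    return len(records)
```

You would call it with your own skills, Phase07 knowledge, and base-model instruction files, then point `phase10_data_path` at the output file.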

In [None]:
# Phase10 (Skills + Replay Training) configuration
if 'most_recent_checkpoint' not in locals() or most_recent_checkpoint is None:
    print("❌ Cannot proceed with Phase10: No checkpoint from Phase07")
    print("   Please ensure Phase07 completed successfully")
else:
    phase10_input_model = most_recent_checkpoint
    experiment_prefix_phase10 = experiment_prefix + "_phase10"
    experiment_name_phase10 = experiment_prefix_phase10 + "_" + datetime.now().strftime("%Y%m%d_%H%M%S")
    phase10_ckpt_output_dir = os.path.join(ckpt_output_base_dir, experiment_prefix_phase10)
    
    print(f"Phase10 (Skills + Replay Training) Configuration:")
    print(f"  Experiment name: {experiment_name_phase10}")
    print(f"  Input model (from Phase07): {phase10_input_model}")
    print(f"  Data path: {phase10_data_path}")
    print(f"  Output directory: {phase10_ckpt_output_dir}")
    print(f"  Training on skills + comprehensive replay data...")
    print(f"  ↳ Skills data + Phase07 knowledge replay + Base model instruction replay")

In [None]:
# Run Phase10 training with managed output
if 'most_recent_checkpoint' not in locals() or most_recent_checkpoint is None:
    print("❌ Cannot proceed with Phase10: No checkpoint from Phase07")
    print("   Please ensure Phase07 completed successfully")
else:
    def phase10_training():
        """Execute Phase10 training with all parameters."""
        return sft(
            # Required parameters
            model_path=phase10_input_model,       # Path to the model to fine-tune (from Phase07 checkpoint)
            data_path=phase10_data_path,          # Path to the training data (skills + replay data)
            ckpt_output_dir=phase10_ckpt_output_dir,  # Directory to save checkpoints
            
            # Core training parameters
            num_epochs=7,                         # Number of training epochs
            effective_batch_size=3840,            # Effective batch size for training (larger due to larger skills + replay dataset)
            learning_rate=2e-5,                   # Learning rate for training
            max_seq_len=max_seq_len,              # Maximum sequence length
            max_tokens_per_gpu=max_tokens_per_gpu, # Maximum tokens per GPU in a mini-batch (hard-cap for memory to avoid OOMs)
            
            # Data and checkpointing parameters
            data_output_dir="/dev/shm",           # Directory to save processed data (using RAM for faster data processing)
            warmup_steps=0,                       # Number of warmup steps
            save_samples=0,                       # Number of samples to save after training (0 disables saving based on sample count)
            checkpoint_at_epoch=True,             # Whether to checkpoint at each epoch (default value, shown for clarity)
            accelerate_full_state_at_epoch=True,  # Whether to save full state at epoch for automatic checkpoint resumption (default value, enable for final model)
            
            # Distributed training parameters
            nproc_per_node=nproc_per_node,        # Number of processes (GPUs) per node for distributed training
            nnodes=nnodes,                        # Total number of nodes for distributed training
            node_rank=node_rank,                  # Rank of this node (0 to nnodes-1) for distributed training
            rdzv_id=rdzv_id,                      # Unique job ID for rendezvous in distributed training
            rdzv_endpoint=rdzv_endpoint,          # Master node endpoint for multi-node training
        )

    # Execute Phase10 training with managed output to prevent notebook crashes
    try:
        result = run_training_with_managed_output(phase10_training, "Phase10 (Skills + Replay Training)")
        print("🎯 Phase10 training completed successfully!")
    except Exception as e:
        print(f"💥 Phase10 training failed: {e}")
        print("🔍 Check the error details above for troubleshooting")
        raise

## Training Summary

Let's summarize what we accomplished and where to find the final model.

In [None]:
print("🎉 LAB Multi-Phase Training Summary")
print("=" * 50)

if 'phase07_ckpt_output_dir' in locals():
    print(f"📁 Phase07 (Knowledge Tuning) Output: {phase07_ckpt_output_dir}")
    
if 'phase10_ckpt_output_dir' in locals():
    print(f"📁 Phase10 (Skills + Replay) Output: {phase10_ckpt_output_dir}")
    print(f"\n🎯 Final trained model location:")
    print(f"   {phase10_ckpt_output_dir}/hf_format/[latest_checkpoint]")
    
    # List final checkpoints if available
    final_ckpt_dir = f"{phase10_ckpt_output_dir}/hf_format"
    if os.path.exists(final_ckpt_dir):
        final_checkpoints = [d for d in os.listdir(final_ckpt_dir) if os.path.isdir(os.path.join(final_ckpt_dir, d))]
        if final_checkpoints:
            print(f"\n📋 Available final checkpoints:")
            for ckpt in sorted(final_checkpoints):
                print(f"   - {ckpt}")

print(f"\n🔧 LAB Training Configuration Used:")
print(f"   - Max tokens per GPU: {max_tokens_per_gpu:,}")
print(f"   - Max sequence length: {max_seq_len:,}")
print(f"   - GPUs per node: {nproc_per_node}")
print(f"   - Phase07 batch size: 128 (smaller knowledge dataset)")
print(f"   - Phase10 batch size: 3840 (larger skills + replay dataset)")
print(f"   - Learning rate: 2e-5")
print(f"   - Epochs per phase: 7")

print(f"\n📊 Data Composition:")
print(f"   - Phase07: Knowledge data only")
print(f"   - Phase10: Skills + Phase07 replay + Base model instruction replay")

print(f"\n💡 Next Steps:")
print(f"   1. Evaluate your model on relevant benchmarks")
print(f"   2. Test with sample prompts to verify training quality")
print(f"   3. Check knowledge retention from Phase07")
print(f"   4. Verify new skills acquisition from Phase10")
print(f"   5. Confirm base model instruction-following capabilities are preserved")
print(f"   6. Deploy for inference using your preferred serving framework")

## Key Concepts Explained

### LAB Multi-Phase Training Benefits:

1. **Knowledge → Skills + Comprehensive Replay**: Phase07 builds foundational knowledge, Phase10 adds task-specific capabilities while replaying both Phase07 knowledge AND base model instruction data
2. **Comprehensive Replay Strategy**: Prevents both knowledge forgetting (Phase07) and capability regression (base model instruction-following)
3. **Dataset-Appropriate Batch Sizes**: Smaller batch size for smaller knowledge datasets, larger batch size for larger skills + replay datasets
4. **Checkpoint Continuity**: Seamlessly continue from Phase07 results into Phase10
5. **Multi-Level Preservation**: Maintains original instruction capabilities, Phase07 knowledge, and adds new Phase10 skills
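
To see why batch size should track dataset size, it helps to count optimizer steps per epoch. The sketch below assumes one optimizer step consumes `effective_batch_size` samples and that a final partial batch still counts as a step (whether the trainer drops it is not specified here). The dataset sizes are made-up numbers for illustration.

```python
import math


def optimizer_steps_per_epoch(num_samples: int, effective_batch_size: int) -> int:
    """Number of optimizer steps needed to see every sample once."""
    return math.ceil(num_samples / effective_batch_size)


# Hypothetical dataset sizes, chosen only to illustrate the arithmetic.
small_knowledge_set = 10_000
large_skills_set = 500_000

print(optimizer_steps_per_epoch(small_knowledge_set, 128))   # 79
print(optimizer_steps_per_epoch(small_knowledge_set, 3840))  # 3
print(optimizer_steps_per_epoch(large_skills_set, 3840))     # 131
```

With the large batch size, the small dataset would get only a handful of weight updates per epoch, which is why a smaller batch size suits the knowledge phase.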

### training_hub Advantages:

1. **Simplified Interface**: Single `sft()` function instead of separate argument objects
2. **Clear Parameter Organization**: Logical grouping of training, distributed, and advanced options
3. **Backend Flexibility**: Easy to switch between different training backends
4. **Better Documentation**: Clear parameter names like `max_tokens_per_gpu`

### LAB Training Strategy:

- **Phase07 Focus**: Knowledge acquisition on typically smaller, focused knowledge datasets
- **Phase10 Focus**: Skills training with comprehensive replay on larger combined datasets
- **Dual Replay Mechanism**: Prevents both knowledge drift and instruction capability loss
- **Memory Management**: `max_tokens_per_gpu` prevents OOM while maintaining throughput
- **Fast Data Loading**: Using `/dev/shm` for data processing
- **Checkpointing Strategy**: Different strategies for intermediate vs final models
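
To build intuition for how a per-GPU token cap bounds memory, here is a simplified greedy packing illustration. It is not the batching algorithm used by `training_hub` or its backends; the function `pack_by_token_budget` is hypothetical and only shows the idea that samples are grouped so that no mini-batch exceeds the token budget.

```python
def pack_by_token_budget(sample_lengths, max_tokens):
    """Greedily group sample lengths so each group's total stays <= max_tokens.

    Illustration only: samples are taken in order, and a new mini-batch is
    started whenever adding the next sample would exceed the budget.
    """
    batches, current, current_tokens = [], [], 0
    for length in sample_lengths:
        if length > max_tokens:
            raise ValueError(
                f"Sample of {length} tokens cannot fit in a budget of {max_tokens}"
            )
        if current and current_tokens + length > max_tokens:
            batches.append(current)
            current, current_tokens = [], 0
        current.append(length)
        current_tokens += length
    if current:
        batches.append(current)
    return batches


print(pack_by_token_budget([10_000, 12_000, 8_000, 20_000, 4_000], 25_000))
# [[10000, 12000], [8000], [20000, 4000]]
```

The sketch also shows why a sample longer than the cap is a problem: it can never be placed in any mini-batch, which is consistent with this notebook keeping `max_seq_len` below `max_tokens_per_gpu`.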

### Data Composition Details:

**Phase07**: Pure knowledge data for focused learning (typically a smaller dataset)

**Phase10**: A carefully balanced mixture (typically a much larger combined dataset) of:
- New skills/task data (primary learning objective)
- Phase07 knowledge replay (prevents knowledge forgetting)
- Base model instruction replay (preserves original capabilities)

## Troubleshooting

### Common Issues:

1. **Out of Memory (OOM)**:
   - Reduce `max_tokens_per_gpu`
   - Reduce `effective_batch_size`
   - Check GPU memory usage

2. **Checkpoint Not Found**:
   - Verify Phase07 (Phase 1) completed successfully
   - Check `ckpt_output_dir` permissions
   - Look for error messages in training logs

3. **Distributed Training Issues**:
   - Verify network connectivity between nodes
   - Check `rdzv_endpoint` accessibility
   - Ensure consistent environment across nodes

4. **Data Loading Errors**:
   - Verify data file paths exist
   - Check JSONL format validity
   - Ensure sufficient disk space in `/dev/shm`
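
Some of the data-loading checks above can be run before launching a job. The following is a minimal preflight sketch using only the standard library. The helper names `validate_jsonl` and `free_gib` are hypothetical, and the check assumes each JSONL line should be a JSON object. It does not validate the fields your training data needs.

```python
import json
import shutil


def validate_jsonl(path):
    """Return (valid_count, errors) for a JSONL file.

    `errors` is a list of (line_number, message) tuples. Blank lines are
    skipped; each remaining line is expected to be a JSON object.
    """
    valid, errors = 0, []
    with open(path, "r", encoding="utf-8") as f:
        for line_number, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as exc:
                errors.append((line_number, exc.msg))
                continue
            if isinstance(record, dict):
                valid += 1
            else:
                errors.append((line_number, "not a JSON object"))
    return valid, errors


def free_gib(path):
    """Free space, in GiB, on the filesystem containing `path`."""
    return shutil.disk_usage(path).free / 1024**3
```

For example, `validate_jsonl(phase07_data_path)` reports malformed lines by number, and `free_gib("/dev/shm")` shows how much room is left for processed data on systems where `/dev/shm` exists.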