# OSFT Multi-Phase Training Tutorial

This notebook demonstrates how to perform multi-phase training using the OSFT (Orthogonal Subspace Fine-Tuning) algorithm. The key innovation is that **OSFT eliminates the need for replay buffers** while still preserving previously learned capabilities across training phases.

## The Two-Phase OSFT Process:

1. **Phase 1 - Knowledge Tuning (Phase07)**: Training on knowledge-heavy data to build foundational understanding
2. **Phase 2 - Skills Training (Phase10)**: Training on skills data with a **reduced unfreeze_rank_ratio** to preserve Phase 1 knowledge

## Key Advantages of OSFT Multi-Phase Training:

- ✅ **No replay buffers needed** - OSFT naturally preserves prior knowledge
- ✅ **Simpler data pipeline** - Just use your knowledge and skills data directly
- ✅ **Better preservation** - Reduce `unfreeze_rank_ratio` during skills training to retain more Phase 1 knowledge
- ✅ **No catastrophic forgetting** - Built into the algorithm
- ✅ **Replaces traditional LAB workflows** - More efficient than LAB multi-phase training

## The Unfreeze Ratio Strategy:

- **Phase 1**: Use standard ratio (e.g., 0.3) for knowledge acquisition
- **Phase 2**: Reduce by about 0.1 in absolute terms (e.g., 0.3 → 0.2) to preserve knowledge while adding skills

This progressive reduction lets each phase build on the previous one while limiting how much earlier learning gets overwritten.

## How OSFT Multi-Phase Replaces LAB Workflows:

Traditional LAB (Large-scale Alignment for chatBots) multi-phase training requires complex replay buffers and data mixing. **OSFT Multi-Phase training provides the same benefits with a much simpler approach** - no replay data needed, just progressive unfreeze ratio reduction.

## Setup and Imports

First, let's import the necessary libraries and set up our training environment.


In [None]:
# Import training_hub for OSFT training
from training_hub import osft

# Standard library imports
import os
import time
import logging
import sys
from datetime import datetime
from contextlib import redirect_stdout, redirect_stderr
from io import StringIO


## Logging Configuration

Set up logging to track progress while preventing notebook crashes from excessive output.

**Note:** For production workflows or long-running jobs, we recommend using the script version at `scripts/osft_multiphase_training.py` for better logging consistency and resumption capabilities.

**Quick script usage:**
```bash
python scripts/osft_multiphase_training.py \
  --base-model-path /path/to/model \
  --phase07-data-path /path/to/knowledge.jsonl \
  --phase10-data-path /path/to/skills.jsonl \
  --ckpt-output-base-dir /path/to/checkpoints
```

In [None]:
# Configure logging to show only essential information
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)

# Suppress verbose logging from transformers and other libraries
logging.getLogger("transformers").setLevel(logging.WARNING)
logging.getLogger("datasets").setLevel(logging.WARNING)
logging.getLogger("torch").setLevel(logging.WARNING)

print("✅ Logging configured for notebook environment")


## Utility Functions

Let's define some helper functions for checkpoint management.


In [None]:
import glob

def find_most_recent_checkpoint(output_dir):
    """
    Find the most recent checkpoint in the training output directory.
    
    Args:
        output_dir (str): Training output directory containing hf_format/ subdirectory
        
    Returns:
        str: Path to the most recent checkpoint
        
    Raises:
        ValueError: If no checkpoints are found
    """
    # Get all checkpoint directories under hf_format
    checkpoint_pattern = os.path.join(output_dir, "hf_format", "samples_*.0")
    checkpoint_dirs = glob.glob(checkpoint_pattern)
    
    if not checkpoint_dirs:
        raise ValueError(f"No checkpoints found in {os.path.join(output_dir, 'hf_format')}")
    
    # Find the most recently created checkpoint
    most_recent_checkpoint = max(checkpoint_dirs, key=os.path.getctime)
    
    return most_recent_checkpoint

print("✅ Checkpoint utility functions defined")
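
`os.path.getctime` depends on filesystem timestamps. These can mislead if checkpoint directories are copied or touched after training. The glob pattern above already assumes names of the form `samples_<N>.0`, so an alternative is to pick the checkpoint with the largest sample count parsed from the name. The sketch below does that. The helper name `find_latest_checkpoint_by_samples` is hypothetical, and it relies on the same naming assumption.

```python
import glob
import os
import re

# Matches names like "samples_12345.0"; this scheme is assumed from the glob
# pattern used in find_most_recent_checkpoint above.
_SAMPLES_RE = re.compile(r"^samples_(\d+(?:\.\d+)?)$")


def find_latest_checkpoint_by_samples(output_dir):
    """Return the hf_format checkpoint with the highest sample count (hypothetical helper)."""
    hf_dir = os.path.join(output_dir, "hf_format")
    candidates = []
    for path in glob.glob(os.path.join(hf_dir, "samples_*")):
        match = _SAMPLES_RE.match(os.path.basename(path))
        if match and os.path.isdir(path):
            # Parse the count numerically so ordering ignores string length
            candidates.append((float(match.group(1)), path))
    if not candidates:
        raise ValueError(f"No checkpoints found in {hf_dir}")
    # Highest sample count = furthest point in training, independent of timestamps
    return max(candidates)[1]
```

The numeric parse matters. Compared as plain strings, `samples_350.0` would sort after `samples_2000.0`.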


## Understanding Multi-Phase Data Requirements

OSFT Multi-Phase training requires carefully curated datasets for each phase:

### Phase07 (Knowledge) Data
- Focus on factual knowledge, domain expertise, and foundational understanding
- Examples: technical documentation, educational content, reference materials
- Format: JSONL, one JSON object per line containing a `messages` field

### Phase10 (Skills) Data  
- Focus on task completion, instruction following, and practical applications
- Examples: coding tasks, problem-solving, conversational skills
- Format: JSONL, one JSON object per line containing a `messages` field
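
Both phases use the same record layout, so a single check can cover both files before training. The sketch below validates a JSONL file line by line. It assumes each record's `messages` list holds chat turns with string `role` and `content` fields. That is the common chat format, but the exact schema training_hub accepts is an assumption here. The function names `validate_messages_jsonl` and `_record_problem` are hypothetical.

```python
import json


def _record_problem(line):
    """Describe what is wrong with one JSONL line, or return None if it looks valid."""
    if not line.strip():
        return "blank line"
    try:
        record = json.loads(line)
    except json.JSONDecodeError as exc:
        return f"invalid JSON: {exc.msg}"
    messages = record.get("messages") if isinstance(record, dict) else None
    if not isinstance(messages, list) or not messages:
        return "missing or empty 'messages' list"
    for turn in messages:
        # Assumed schema: each turn is {"role": str, "content": str}
        if not (isinstance(turn, dict)
                and isinstance(turn.get("role"), str)
                and isinstance(turn.get("content"), str)):
            return "a turn lacks string 'role'/'content' fields"
    return None


def validate_messages_jsonl(path, max_errors=10):
    """Return up to max_errors (line_number, problem) tuples; empty means no issues found."""
    problems = []
    with open(path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            problem = _record_problem(line)
            if problem is not None:
                problems.append((line_no, problem))
                if len(problems) >= max_errors:
                    break
    return problems
```

For example, run `validate_messages_jsonl(PHASE07_DATA_PATH)` and `validate_messages_jsonl(PHASE10_DATA_PATH)` before launching a phase. Each should return an empty list.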

### The OSFT Advantage Over Traditional LAB Workflows
**With traditional LAB SFT**, Phase10 would need:
- Skills data
- \+ Phase07 knowledge data (replay)
- \+ Base model instruction data (replay)
- = Complex data mixing and large datasets

**With OSFT Multi-Phase**, Phase10 only needs:
- Skills data
- That's it! OSFT preserves prior knowledge automatically

This makes OSFT Multi-Phase training a superior replacement for LAB workflows, eliminating the complexity of replay buffer management while providing the same multi-phase training benefits.

## Configuration: Model and Data Paths

Configure your base model and data paths for the two-phase training.


In [None]:
# =============================================================================
# MODEL AND DATA CONFIGURATION
# =============================================================================

# Base model configuration
BASE_MODEL_PATH = "meta-llama/Llama-3.1-8B-Instruct"  # Or your preferred base model
# BASE_MODEL_PATH = "Qwen/Qwen2.5-7B-Instruct"
# BASE_MODEL_PATH = "microsoft/Phi-4-mini-instruct"

# Data paths for each phase
PHASE07_DATA_PATH = "/path/to/your/phase07_knowledge_data.jsonl"  # Knowledge data
PHASE10_DATA_PATH = "/path/to/your/phase10_skills_data.jsonl"     # Skills data ONLY (no replay needed!)

# Output configuration
CHECKPOINT_BASE_DIR = "/path/to/checkpoints"
EXPERIMENT_PREFIX = "osft_multiphase_experiment"

# Timestamped experiment name, used as the prefix for each phase's output directory
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
experiment_name = f"{EXPERIMENT_PREFIX}_{timestamp}"

print("📋 OSFT Multi-Phase Configuration")
print("=" * 50)
print(f"Base Model: {BASE_MODEL_PATH}")
print(f"Phase07 Data: {PHASE07_DATA_PATH}")
print(f"Phase10 Data: {PHASE10_DATA_PATH}")
print(f"Output Directories: {CHECKPOINT_BASE_DIR}/{experiment_name}_phase07")
print(f"                    {CHECKPOINT_BASE_DIR}/{experiment_name}_phase10")
print()
print("✨ Key Difference from Traditional LAB SFT:")
print("  Phase10 only needs skills data - no replay buffers!")
print("  OSFT preserves Phase07 knowledge automatically.")
print("  This workflow replaces complex LAB multi-phase training.")

## OSFT-Specific Parameters: The Progressive Unfreeze Strategy

The key to successful multi-phase training with OSFT is progressively reducing the `unfreeze_rank_ratio` in each phase.


In [None]:
# =============================================================================
# OSFT PROGRESSIVE UNFREEZE STRATEGY
# =============================================================================

# Phase07: Initial knowledge acquisition
PHASE07_UNFREEZE_RATIO = 0.3  # Standard ratio for knowledge learning

# Phase10: Reduced ratio for better preservation
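UNFREEZE_REDUCTION = 0.1  # Absolute reduction per subsequent phase (0.3 -> 0.2)
# round() avoids floating-point noise: 0.3 - 0.1 == 0.19999999999999998 in Python
PHASE10_UNFREEZE_RATIO = round(max(0.1, PHASE07_UNFREEZE_RATIO - UNFREEZE_REDUCTION), 4)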

print("🎯 OSFT Progressive Unfreeze Strategy")
print("=" * 50)
print(f"Phase07 (Knowledge): unfreeze_rank_ratio = {PHASE07_UNFREEZE_RATIO}")
print(f"Phase10 (Skills):    unfreeze_rank_ratio = {PHASE10_UNFREEZE_RATIO}")
print(f"Reduction:           -{UNFREEZE_REDUCTION} per phase")
print()
print("📊 Strategy Explanation:")
print(f"  • Phase07 ({PHASE07_UNFREEZE_RATIO}): More freedom to acquire new knowledge")
print(f"  • Phase10 ({PHASE10_UNFREEZE_RATIO}): Reduced to preserve Phase07 learning")
print()
print("💡 Guidelines:")
print("  • Start with 0.25-0.35 for Phase07")
print("  • Reduce by 0.05-0.15 for each subsequent phase")
print("  • Never go below 0.1 (too restrictive)")
print("  • Adjust based on your preservation needs")
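
The same rule extends to more than two phases. The sketch below computes one ratio per phase. It subtracts a fixed absolute reduction, clamps at the 0.1 floor suggested above, and rounds to avoid floating-point noise. The function name `unfreeze_schedule` is hypothetical.

```python
def unfreeze_schedule(start_ratio, reduction, num_phases, floor=0.1, ndigits=4):
    """Return one unfreeze_rank_ratio per phase, reducing by a fixed absolute amount."""
    if num_phases < 1:
        raise ValueError("num_phases must be >= 1")
    if not 0 < floor <= start_ratio <= 1:
        raise ValueError("expected 0 < floor <= start_ratio <= 1")
    schedule = []
    for phase in range(num_phases):
        # Linear reduction, never dropping below the floor
        ratio = max(floor, start_ratio - phase * reduction)
        # Rounding removes float artifacts such as 0.19999999999999998
        schedule.append(round(ratio, ndigits))
    return schedule
```

With the defaults used here, `unfreeze_schedule(0.3, 0.1, 2)` gives `[0.3, 0.2]`. A four-phase run with the same reduction hits the floor: `[0.3, 0.2, 0.1, 0.1]`.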


## Training Hyperparameters

Configure training parameters for both phases. Note that we can use similar settings for both phases since OSFT handles preservation automatically.


In [None]:
# =============================================================================
# TRAINING HYPERPARAMETERS
# =============================================================================

# Common parameters for both phases
MAX_SEQ_LEN = 8_192                 # Maximum sequence length
MAX_TOKENS_PER_GPU = 10_000         # Memory limit per GPU
NUM_EPOCHS = 2                      # Training epochs per phase
WARMUP_STEPS = 0                    # Warmup for Phase07
USE_LIGER = True                    # Enable Liger kernels for efficiency

# Phase07 specific parameters
PHASE07_BATCH_SIZE = 128            # Batch size for knowledge training
PHASE07_LEARNING_RATE = 5e-6        # Use low learning rate for better learning quality

# Phase10 specific parameters  
PHASE10_BATCH_SIZE = 128            # Can use same batch size (no replay data!)
PHASE10_LEARNING_RATE = 5e-6        # Use low learning rate for better learning quality
PHASE10_WARMUP_STEPS = 0            # No warmup

# Distributed training configuration
NPROC_PER_NODE = 8                  # Number of GPUs per node
NNODES = 1                          # Number of nodes
NODE_RANK = 0                       # Rank of this node
RDZV_ID = 47                        # Unique job ID
RDZV_ENDPOINT = "127.0.0.1:29500"   # Rendezvous endpoint

print("⚙️  Training Hyperparameters")
print("=" * 50)
print(f"Max Sequence Length: {MAX_SEQ_LEN:,}")
print(f"Max Tokens per GPU: {MAX_TOKENS_PER_GPU:,}")
print(f"Epochs per Phase: {NUM_EPOCHS}")
print()
print("Phase07 (Knowledge):")
print(f"  • Batch Size: {PHASE07_BATCH_SIZE}")
print(f"  • Learning Rate: {PHASE07_LEARNING_RATE}")
print(f"  • Warmup Steps: {WARMUP_STEPS}")
print()
print("Phase10 (Skills):")
print(f"  • Batch Size: {PHASE10_BATCH_SIZE}")
print(f"  • Learning Rate: {PHASE10_LEARNING_RATE}")
print(f"  • Warmup Steps: {PHASE10_WARMUP_STEPS}")
print()
print(f"Distributed: {NPROC_PER_NODE} GPUs × {NNODES} nodes = {NPROC_PER_NODE * NNODES} total GPUs")
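
When choosing warmup steps or reading the cosine schedule, it helps to know roughly how many optimizer steps each phase runs. The back-of-the-envelope sketch below assumes one optimizer step per `effective_batch_size` samples, with a partial final batch counted as one step. The actual trainer may pack or batch samples differently, so treat the result as an estimate. The dataset size is a placeholder.

```python
import math


def estimate_optimizer_steps(num_samples, effective_batch_size, num_epochs):
    """Rough total optimizer steps, assuming one step per effective batch of samples."""
    if num_samples <= 0 or effective_batch_size <= 0 or num_epochs <= 0:
        raise ValueError("all arguments must be positive")
    steps_per_epoch = math.ceil(num_samples / effective_batch_size)
    return steps_per_epoch * num_epochs


# Placeholder dataset size; replace with the number of records in your JSONL file.
example_samples = 50_000
total_steps = estimate_optimizer_steps(example_samples, effective_batch_size=128, num_epochs=2)
print(f"~{total_steps} optimizer steps")
```

With 50,000 samples, a batch size of 128 and 2 epochs, this gives 391 steps per epoch, or 782 in total.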


## Phase 1: Knowledge Training (Phase07)

First, we train on knowledge data to build foundational understanding. This phase uses the standard unfreeze_rank_ratio.


In [None]:
# =============================================================================
# PHASE 1 (PHASE07): KNOWLEDGE TRAINING
# =============================================================================

phase07_output_dir = os.path.join(CHECKPOINT_BASE_DIR, f"{experiment_name}_phase07")

print("📚 Phase 1: Knowledge Training with OSFT")
print("=" * 60)
print(f"Starting from: {BASE_MODEL_PATH}")
print(f"Training data: {PHASE07_DATA_PATH}")
print(f"Output directory: {phase07_output_dir}")
print(f"Unfreeze ratio: {PHASE07_UNFREEZE_RATIO}")
print()

# Capture output to prevent notebook crashes
output_buffer = StringIO()
error_buffer = StringIO()

phase07_start_time = time.time()

try:
    with redirect_stdout(output_buffer), redirect_stderr(error_buffer):
        # Phase07 OSFT training
        phase07_result = osft(
            # Model and data
            model_path=BASE_MODEL_PATH,
            data_path=PHASE07_DATA_PATH,
            ckpt_output_dir=phase07_output_dir,
            
            # OSFT-specific
            unfreeze_rank_ratio=PHASE07_UNFREEZE_RATIO,
            
            # Training parameters
            num_epochs=NUM_EPOCHS,
            effective_batch_size=PHASE07_BATCH_SIZE,
            learning_rate=PHASE07_LEARNING_RATE,
            max_seq_len=MAX_SEQ_LEN,
            max_tokens_per_gpu=MAX_TOKENS_PER_GPU,
            
            # Data processing
            data_output_dir=os.path.join(phase07_output_dir, "data_processing"),
            warmup_steps=WARMUP_STEPS,
            
            # Optimization
            use_liger=USE_LIGER,
            seed=42,
            lr_scheduler="cosine",
            
            # Checkpointing
            checkpoint_at_epoch=True,
            save_final_checkpoint=True,
            
            # Distributed training
            nproc_per_node=NPROC_PER_NODE,
            nnodes=NNODES,
            node_rank=NODE_RANK,
            rdzv_id=RDZV_ID,
            rdzv_endpoint=RDZV_ENDPOINT,
        )
    
    phase07_duration = time.time() - phase07_start_time
    
    print(f"✅ Phase07 completed successfully in {phase07_duration/3600:.2f} hours!")
    print(f"📁 Checkpoint saved to: {phase07_output_dir}")
    print()
    print("📊 Phase07 Achievements:")
    print("  • Base model capabilities: ✅ Preserved")
    print("  • New knowledge integrated: ✅ Complete")
    print("  • Ready for Phase10: ✅ Yes")
    
    # Find the most recent checkpoint for Phase10
    PHASE07_CHECKPOINT = find_most_recent_checkpoint(phase07_output_dir)
    print(f"📁 Found most recent Phase07 checkpoint: {PHASE07_CHECKPOINT}")
    print(f"📁 Ready for Phase10 training!")
    
except Exception as e:
    print(f"❌ Phase07 training failed: {e}")
    print("\nError details:")
    print(error_buffer.getvalue())
    raise
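
On failure, the `except` block prints the entire stderr buffer, which can itself flood the notebook. The sketch below prints only the last lines of a captured buffer instead. `buffer_tail` and `print_buffer_tail` are hypothetical helpers that work on any `StringIO`.

```python
from io import StringIO


def buffer_tail(buffer, max_lines=50):
    """Return the last max_lines lines captured in a StringIO buffer."""
    if max_lines <= 0:
        return ""  # lines[-0:] would return everything, so handle 0 explicitly
    lines = buffer.getvalue().splitlines()
    return "\n".join(lines[-max_lines:])


def print_buffer_tail(label, buffer, max_lines=50):
    """Print a labeled tail of a captured buffer, noting how many lines were skipped."""
    max_lines = max(0, max_lines)
    total = len(buffer.getvalue().splitlines())
    shown = min(max_lines, total)
    print(f"--- {label} (last {shown} of {total} lines) ---")
    if total > shown:
        print(f"... {total - shown} earlier lines omitted ...")
    print(buffer_tail(buffer, max_lines))
```

Inside the `except` block you could call `print_buffer_tail("stderr", error_buffer)` and `print_buffer_tail("stdout", output_buffer)` in place of printing `error_buffer.getvalue()` in full.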

## Phase 2: Skills Training (Phase10)

Now we train on skills data with a **reduced unfreeze_rank_ratio** to preserve Phase07 knowledge while adding new capabilities.
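
Phase10 depends on `PHASE07_CHECKPOINT`, which is set in the previous cell. If the kernel was restarted or Phase07 failed, that variable may be missing or may point to an incomplete directory. The pre-flight sketch below checks the path before launching. It assumes a Hugging Face format checkpoint contains at least `config.json`, as `save_pretrained` output normally does. `check_hf_checkpoint` is a hypothetical helper.

```python
import os


def check_hf_checkpoint(path, required_files=("config.json",)):
    """Raise if path is not a directory containing the expected Hugging Face files."""
    if not path or not os.path.isdir(path):
        raise FileNotFoundError(f"Checkpoint directory not found: {path!r}")
    missing = [name for name in required_files
               if not os.path.isfile(os.path.join(path, name))]
    if missing:
        raise FileNotFoundError(f"Checkpoint {path!r} is missing: {', '.join(missing)}")
    return path
```

Calling `check_hf_checkpoint(PHASE07_CHECKPOINT)` at the top of the Phase10 cell fails fast with a clear message. Otherwise the error would surface deep inside distributed startup.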


In [None]:
# =============================================================================
# PHASE 2 (PHASE10): SKILLS TRAINING
# =============================================================================

phase10_output_dir = os.path.join(CHECKPOINT_BASE_DIR, f"{experiment_name}_phase10")

print("🎯 Phase 2: Skills Training with OSFT")
print("=" * 60)
print(f"Starting from: {PHASE07_CHECKPOINT} (Phase07 checkpoint)")
print(f"Training data: {PHASE10_DATA_PATH}")
print(f"Output directory: {phase10_output_dir}")
print(f"Unfreeze ratio: {PHASE10_UNFREEZE_RATIO} (reduced from {PHASE07_UNFREEZE_RATIO})")
print()
print("💡 Key Innovation:")
print("  NO replay buffer needed! OSFT preserves Phase07 knowledge automatically.")
print("  The reduced unfreeze_rank_ratio ensures better preservation.")
print()

# Capture output to prevent notebook crashes
output_buffer = StringIO()
error_buffer = StringIO()

phase10_start_time = time.time()

try:
    with redirect_stdout(output_buffer), redirect_stderr(error_buffer):
        # Phase10 OSFT training
        phase10_result = osft(
            # Start from Phase07 checkpoint
            model_path=PHASE07_CHECKPOINT,
            data_path=PHASE10_DATA_PATH,
            ckpt_output_dir=phase10_output_dir,
            
            # OSFT-specific: REDUCED ratio for preservation
            unfreeze_rank_ratio=PHASE10_UNFREEZE_RATIO,
            
            # Training parameters
            num_epochs=NUM_EPOCHS,
            effective_batch_size=PHASE10_BATCH_SIZE,
            learning_rate=PHASE10_LEARNING_RATE,  # Same default as Phase07; lower it if Phase07 knowledge degrades
            max_seq_len=MAX_SEQ_LEN,
            max_tokens_per_gpu=MAX_TOKENS_PER_GPU,
            
            # Data processing
            data_output_dir=os.path.join(phase10_output_dir, "data_processing"),
            warmup_steps=PHASE10_WARMUP_STEPS,  # 0 by default, matching Phase07
            
            # Optimization
            use_liger=USE_LIGER,
            seed=42,  
            lr_scheduler="cosine",
            
            # Checkpointing
            checkpoint_at_epoch=True,
            save_final_checkpoint=True,
            
            # Distributed training
            nproc_per_node=NPROC_PER_NODE,
            nnodes=NNODES,
            node_rank=NODE_RANK,
            rdzv_id=RDZV_ID + 1,  # Different ID for Phase10
            rdzv_endpoint=RDZV_ENDPOINT,
        )
    
    phase10_duration = time.time() - phase10_start_time
    
    print(f"✅ Phase10 completed successfully in {phase10_duration/3600:.2f} hours!")
    print(f"📁 Final checkpoint saved to: {phase10_output_dir}")
    print()
    print("📊 Phase10 Achievements:")
    print("  • Base model capabilities: ✅ Preserved")
    print("  • Phase07 knowledge: ✅ Retained")  
    print("  • New skills integrated: ✅ Complete")
    
    # Find the most recent checkpoint from Phase10 training
    FINAL_CHECKPOINT = find_most_recent_checkpoint(phase10_output_dir)
    print(f"📁 Final model checkpoint: {FINAL_CHECKPOINT}")
    
except Exception as e:
    print(f"❌ Phase10 training failed: {e}")
    print("\nError details:")
    print(error_buffer.getvalue())
    raise


## Final Analysis and Summary

Let's analyze the complete two-phase training results and understand what we've achieved with OSFT.


In [None]:
# =============================================================================
# FINAL ANALYSIS AND SUMMARY
# =============================================================================

total_duration = (phase07_duration + phase10_duration) / 3600

print("🎉 OSFT Multi-Phase Training Complete!")
print("=" * 60)
print(f"Total training time: {total_duration:.2f} hours")
print(f"Final model: {FINAL_CHECKPOINT}")
print()

# Training summary
print("📊 Training Summary:")
print("-" * 50)
print("Phase 1 (Knowledge - Phase07):")
print(f"  • Duration: {phase07_duration/3600:.2f} hours")
print(f"  • Unfreeze ratio: {PHASE07_UNFREEZE_RATIO}")
print(f"  • Batch size: {PHASE07_BATCH_SIZE}")
print(f"  • Learning rate: {PHASE07_LEARNING_RATE}")
print(f"  • Checkpoint: {PHASE07_CHECKPOINT}")
print()
print("Phase 2 (Skills - Phase10):")
print(f"  • Duration: {phase10_duration/3600:.2f} hours")
print(f"  • Unfreeze ratio: {PHASE10_UNFREEZE_RATIO} (reduced by {UNFREEZE_REDUCTION})")
print(f"  • Batch size: {PHASE10_BATCH_SIZE}")
print(f"  • Learning rate: {PHASE10_LEARNING_RATE}")
print(f"  • Checkpoint: {FINAL_CHECKPOINT}")
print()

# Model capabilities
print("🚀 Your Model Now Has:")
print("-" * 50)
print("1. ✅ Original base model capabilities (preserved)")
print("2. ✅ New knowledge from Phase07 (integrated)")
print("3. ✅ Task-specific skills from Phase10 (acquired)")
print("4. ❌ Catastrophic forgetting (none!)")
print()

# How to use the model
print("💻 Using Your Trained Model:")
print("-" * 50)
print("```python")
print("from transformers import AutoModelForCausalLM, AutoTokenizer")
print("")
print("# Load your OSFT Multi-Phase trained model")
print(f"model = AutoModelForCausalLM.from_pretrained('{FINAL_CHECKPOINT}')")
print(f"tokenizer = AutoTokenizer.from_pretrained('{FINAL_CHECKPOINT}')")
print("")
print("# Expected strengths (verify with your own evaluations):")
print("# 1. General instruction following (preserved from base)")
print("# 2. Domain knowledge (from Phase07)")
print("# 3. Specific skills (from Phase10)")
print("")
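print("# Test it (instruct models expect their chat template):")
print("messages = [{'role': 'user', 'content': 'Your domain-specific question here'}]")
print("inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True,")
print("                                       return_tensors='pt', return_dict=True)")
print("outputs = model.generate(**inputs, max_new_tokens=200)")
print("new_tokens = outputs[0][inputs['input_ids'].shape[-1]:]")
print("response = tokenizer.decode(new_tokens, skip_special_tokens=True)")
print("print(response)")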
print("```")

## OSFT Multi-Phase vs LAB Multi-Phase: Key Differences

Understanding how OSFT Multi-Phase training simplifies and improves upon traditional LAB multi-phase training workflows.

In [None]:
# =============================================================================
# OSFT MULTI-PHASE vs LAB MULTI-PHASE COMPARISON
# =============================================================================

print("📊 OSFT Multi-Phase vs LAB Multi-Phase Comparison")
print("=" * 70)
print()
print("| Aspect                  | LAB Multi-Phase (SFT)       | OSFT Multi-Phase            |")
print("|-------------------------|-----------------------------|-----------------------------|")
print("| **Phase07 Data**        | Knowledge data only         | Knowledge data only         |")
print("| **Phase10 Data**        | Skills + Phase07 replay     | Skills data ONLY ✨         |")
print("|                         | + Base model replay         |                             |")
print("| **Data Complexity**     | Complex mixing ratios       | Simple, direct              |")
print("| **Data Storage**        | Larger (includes replays)   | Smaller (new data only)     |")
print("| **Preservation Method** | Data replay buffers         | Algorithm (unfreeze ratio)  |")
print("| **Configuration**       | Complex replay ratios       | Simple ratio reduction      |")
print("| **Forgetting Risk**     | If replay ratios wrong      | Minimal by design           |")
print("| **Training Time**       | Longer (more data)          | Shorter (less data)         |")
print("| **Task Performance**    | Full model fitting          | Similar (slight trade-off)  |")
print("| **Capability Preserv.** | Depends on replay quality   | Built into the algorithm    |")
print()
print("✨ Key OSFT Multi-Phase Advantages:")
print("  1. No need to store or manage replay buffers")
print("  2. Simpler data pipeline - just your new data")
print("  3. Progressive unfreeze strategy ensures preservation")
print("  4. Reduced training time (less data to process)")
print("  5. Preservation built into the algorithm, not the data mix")
print("  6. Direct replacement for LAB workflows with better efficiency")
print()
print("⚖️ Trade-offs:")
print("  • OSFT may achieve slightly lower task-specific performance")
print("  • But preserves other capabilities that SFT would degrade")
print("  • Overall: a better balance between capability preservation and task fitting")

## Best Practices and Recommendations

Key guidelines for successful OSFT Multi-Phase training that replaces traditional LAB workflows.

In [None]:
# =============================================================================
# BEST PRACTICES AND RECOMMENDATIONS
# =============================================================================

print("📚 Best Practices for OSFT Multi-Phase Training")
print("=" * 60)
print()
print("1️⃣  **Unfreeze Ratio Strategy:**")
print("   • Start with 0.25-0.35 for Phase07")
print("   • Reduce by 0.05-0.15 for Phase10")
print("   • Never go below 0.1 (too restrictive)")
print("   • If you observe forgetting, reduce the Phase10 ratio further")
print()
print("2️⃣  **Data Quality:**")
print("   • Phase07: Focus on high-quality knowledge/facts")
print("   • Phase10: Focus on diverse skills and tasks")
print("   • No need for replay data - OSFT handles preservation!")
print()
print("3️⃣  **Learning Rate Strategy:**")
print("   • Phase07: Standard LR (e.g., 5e-6)")
print("   • Phase10: Standard LR (e.g., 5e-6) for stability")
print("   • Use cosine scheduler for smooth convergence")
print("   • If model isn't learning enough, try increasing LR or number of epochs;")
print("     the optimal settings will vary by model and data")
print()
print("4️⃣  **Monitoring:**")
print("   • Track loss curves for both phases")
print("   • Test preservation after each phase")
print("   • Evaluate on held-out sets for each capability")

print("🎯 Next Steps:")
print("-" * 50)
print("1. Test your model on evaluation sets")
print("2. Compare with base model to verify improvements")
print("3. Fine-tune unfreeze ratios if needed")
print("4. Deploy once evaluations confirm no capability regression")
print()
print("📝 For production use, remember to use the script:")
print("   scripts/osft_multiphase_training.py")
print()
print("🔄 Replacing LAB Workflows:")
print("   This OSFT Multi-Phase approach can directly replace")
print("   traditional LAB multi-phase training with better efficiency!")
print()
print("⚖️ Performance Expectations:")
print("   • Task-specific performance: Similar to LAB (may be slightly lower)")
print("   • Capability preservation: Stronger than SFT (minimal degradation)")
print("   • OSFT trades some task fitting for capability preservation")
print()
print("Happy training! 🚀")
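
The monitoring advice above (test preservation after each phase, evaluate held-out sets per capability) can be made concrete with a small comparison. The sketch below assumes you already have one score per capability from your evaluation harness, on a shared scale where higher is better. `find_regressions` is a hypothetical helper. The capability names and numbers are placeholders, not measured results.

```python
def find_regressions(baseline, candidate, tolerance=0.01):
    """Return {capability: (baseline, candidate)} where candidate dropped by more than tolerance.

    Scores are assumed to share one scale with higher meaning better.
    Capabilities missing from candidate are reported with candidate=None.
    """
    regressions = {}
    for name, base_score in baseline.items():
        new_score = candidate.get(name)
        if new_score is None or base_score - new_score > tolerance:
            regressions[name] = (base_score, new_score)
    return regressions


# Placeholder scores for illustration only (not real evaluation results)
base_scores = {"general_instructions": 0.72, "domain_qa": 0.41}
phase10_scores = {"general_instructions": 0.70, "domain_qa": 0.58, "skills_tasks": 0.65}
print(find_regressions(base_scores, phase10_scores, tolerance=0.01))
```

Run the same comparison for base vs Phase07 and for Phase07 vs Phase10 to see which phase introduced a drop. A drop after Phase10 points to lowering its `unfreeze_rank_ratio`.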

## Key Concepts Explained

### OSFT Multi-Phase Training Benefits:

1. **Knowledge → Skills with Progressive Unfreeze**: Phase07 builds foundational knowledge, Phase10 adds task-specific capabilities using reduced unfreeze ratio
2. **No Replay Buffers Needed**: OSFT's orthogonal subspace approach preserves prior learning algorithmically
3. **Simple Data Pipeline**: Just your knowledge and skills data - no complex mixing or replay datasets
4. **Progressive Preservation**: Reducing unfreeze_rank_ratio in each phase ensures cumulative learning
5. **Forgetting Mitigated by Design**: Orthogonal weight updates limit interference with previously learned capabilities
6. **LAB Workflow Replacement**: Provides all benefits of LAB multi-phase training with superior efficiency

### OSFT Algorithm Advantages:

1. **Algorithmic Preservation**: Prior knowledge preserved through math, not data replay
2. **Simplified Training**: No need to manage replay buffers or mixing ratios
3. **Reduced Storage**: No need to store ~370k sample replay buffers
4. **Streamlined Training**: No replay data processing required
5. **Predictable Behavior**: Preservation comes from the algorithm rather than from the composition of replay data
6. **Superior to LAB**: Replaces complex LAB workflows with simpler, more efficient approach

### OSFT Multi-Phase Training Strategy:

- **Phase07 Focus**: Knowledge acquisition with standard unfreeze ratio (0.25-0.35)
- **Phase10 Focus**: Skills training with reduced ratio (0.15-0.25) for preservation
- **Progressive Reduction**: Each phase reduces unfreeze_rank_ratio by 0.05-0.15
- **Memory Management**: Same `max_tokens_per_gpu` strategy as SFT
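- **Data Processing Location**: This notebook writes processed data to each phase's `data_processing/` subdirectory; pointing `data_output_dir` at `/dev/shm` can speed up data loading when enough shared memory is available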
- **Simple Checkpointing**: Standard checkpointing, no special requirements

### Data Requirements Comparison:

**Traditional LAB Multi-Phase SFT**:
- Phase07: Knowledge data
- Phase10: Skills data + Knowledge replay + ~370k sample replay buffer

**OSFT Multi-Phase (Our Approach)**:
- Phase07: Knowledge data only  
- Phase10: Skills data only
- Result: Eliminates the replay buffer requirement entirely!

**Why OSFT Multi-Phase is Superior**: Provides the same multi-phase training benefits as LAB workflows but with dramatically reduced complexity, storage requirements, and training time.

In [None]:
# Visual comparison of training approaches
print("📊 Training Approach Comparison")
print("=" * 80)
print()
print("Traditional LAB Multi-Phase SFT:")
print("  Phase07: [Knowledge Data] → Model_v1")
print("  Phase10: [Skills Data] + [Knowledge Replay] + [~370k Replay Buffer] → Model_v2")
print("           ~~~~~~~~~~~~~   ^^^^^^^^^^^^^^^^^^   ^^^^^^^^^^^^^^^^^^^^^")
print("                               Requires additional replay buffers")
print()
print("OSFT Multi-Phase (This Notebook):")
print("  Phase07: [Knowledge Data] → Model_v1 (unfreeze_rank_ratio=0.3)")
print("  Phase10: [Skills Data] → Model_v2 (unfreeze_rank_ratio=0.2)")
print("           ~~~~~~~~~~~~~")
print("        No replay buffers needed!")
print()
print("✨ Key Benefit: Eliminates the ~370k sample replay buffer requirement!")
print("🔄 This OSFT Multi-Phase workflow directly replaces LAB multi-phase training!")

## Troubleshooting

### Common Issues:

1. **Out of Memory (OOM)**:
   - Reduce `max_tokens_per_gpu`
   - Set `use_liger` to True
   - Reduce `unfreeze_rank_ratio`
   - Check GPU memory usage with `nvidia-smi`

2. **Model Not Learning Well**:
   - Check if `unfreeze_rank_ratio` is too low (try increasing slightly)
   - Verify data quality and format
   - Consider increasing learning rate or epochs
   - Ensure warmup steps are appropriate for your dataset size

3. **Knowledge Forgetting in Phase10**:
   - Reduce Phase10's `unfreeze_rank_ratio` further (e.g., from 0.2 to 0.15)
   - Consider using a lower learning rate in Phase10
   - Verify you're loading the correct Phase07 checkpoint

4. **Checkpoint Not Found**:
   - Verify Phase07 completed successfully
   - Check `ckpt_output_dir` permissions
   - Look for error messages in training logs
   - Ensure sufficient disk space

5. **Distributed Training Issues**:
   - Verify network connectivity between nodes
   - Check `rdzv_endpoint` accessibility
   - Ensure consistent environment across nodes
   - Try with single node first to isolate issues

6. **Data Loading Errors**:
   - Verify JSONL format (each line must be valid JSON with 'messages' field)
   - Check file paths and permissions
   - If `data_output_dir` points to `/dev/shm`, ensure it has enough free space
   - Validate no corrupted entries in data
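
Several of the issues above come down to free space, including failed checkpoint writes and data processing in `/dev/shm`. The quick sketch below checks free space with `shutil.disk_usage` from the standard library. `check_free_space` is a hypothetical helper, and the 50 GiB threshold is an arbitrary placeholder rather than a measured requirement.

```python
import shutil


def check_free_space(path, min_free_gb):
    """Return free space at path in GiB; raise RuntimeError if it is below min_free_gb."""
    free_gb = shutil.disk_usage(path).free / 1024**3
    if free_gb < min_free_gb:
        raise RuntimeError(f"Only {free_gb:.1f} GiB free at {path}; need at least {min_free_gb} GiB")
    return free_gb


# Example: check the shared-memory mount before processing data there (50 GiB is a placeholder)
try:
    print(f"/dev/shm: {check_free_space('/dev/shm', min_free_gb=50):.1f} GiB free")
except (RuntimeError, FileNotFoundError) as exc:
    print(f"Warning: {exc}")
```

The same call works for `CHECKPOINT_BASE_DIR` before a long run, since each epoch checkpoint adds a full copy of the model weights.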

### OSFT Multi-Phase Specific Tips:

- **Finding Optimal Unfreeze Ratios**: Start with 0.3 for Phase07, reduce by 0.1 for Phase10. Adjust based on your preservation needs.
- **Balancing Learning vs Preservation**: Higher ratios = more learning, lower ratios = more preservation
- **Multi-Phase Beyond Two**: For 3+ phases, continue reducing ratio by 0.05-0.1 per phase, but don't go below 0.1
- **Comparing with LAB**: Expect task performance similar to LAB multi-phase training (possibly slightly lower), with stronger capability preservation and a much simpler setup