# OSFT Continual Learning Demo

Fine-tuning language models is hard: you need good data and substantial compute, and even small updates can degrade what the model already does well. That makes it difficult to add new abilities to a model without losing existing ones. Learning new tasks while retaining old ones is known as **continual learning**, and it is the problem that our new training technique, orthogonal subspace fine-tuning (OSFT), is designed to address.

This notebook presents a hands-on example where we enhance `meta-llama/Llama-3.2-1B-Instruct` (the `BASE_MODEL_NAME` configured below) by teaching it to produce JSON-only output when requested.

By the end of this notebook, you will learn:
- ✅ How Llama can be fine-tuned without destroying its existing capabilities
- ✅ How to enhance your own LLMs with OSFT
- ✅ Best practices when fine-tuning models
- ❌ Why training on new data with OSFT does NOT destroy your existing model


## Setup Paths and Directories

In [None]:
from pathlib import Path


WORKSPACE = Path.cwd().parent  # Path to the workspace directory

OUTPUT_DIR = WORKSPACE / "output" / "step_04"

OUTPUT_DIR.mkdir(
    parents=True, exist_ok=True
)  # Create output directory if it doesn't exist

KNOWLEDGE_MIXED_DATASET_PATH = WORKSPACE / "output" / "step_03" / "training_mix"

################################################################################
# 🤖 Model + Data Paths                                                        #
################################################################################
BASE_MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
DATASET_PATH = KNOWLEDGE_MIXED_DATASET_PATH / "combined_cut_5x.jsonl" # Path to the Training Dataset from Step 03
CHECKPOINTS_PATH = OUTPUT_DIR / "checkpoint"
DATA_OUTPUT_PATH = OUTPUT_DIR / "dev" / "shm"  # for quicker multi-process loading of datasets



BASE_MODEL_PATH = OUTPUT_DIR / "base_model" / BASE_MODEL_NAME.split("/")[-1]


# Authenticate to Hugging Face if required to pull your base model
# from huggingface_hub import login
# HF_TOKEN = "" # Insert your API Token
# login(HF_TOKEN)

In [None]:
# SAVE THE MODEL LOCALLY

if not BASE_MODEL_PATH.exists():
    print("Model not available locally, Downloading the model locally ")
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Save the model
    print(f"Loading model {BASE_MODEL_NAME}")
    model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_NAME)
    model.save_pretrained(BASE_MODEL_PATH)
    print(f"Model saved to {BASE_MODEL_PATH}")

    # Save the tokenizer
    print(f"Loading tokenizer {BASE_MODEL_NAME}")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
    tokenizer.save_pretrained(BASE_MODEL_PATH)
    print(f"Tokenizer saved to {BASE_MODEL_PATH}")
else:
    print(f"Model Available locally : {BASE_MODEL_PATH}")

## Setup and Imports

First, let's import and configure everything that we need to run this notebook.

In [None]:
# IMPORTANT: Set these env variables so we can properly clear the memory after training
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# We have to enable this so we can properly clear the model from memory after inferencing
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:4096,expandable_segments:True"


In [None]:
# Import training_hub for OSFT training
from training_hub import osft

# Standard library imports
import os
import time
import logging
import sys
from contextlib import redirect_stdout, redirect_stderr
from io import StringIO
import shutil


## Logging Configuration

Set up logging to track progress while preventing notebook crashes from excessive output.
<!--
**Note:** For production workflows or long-running jobs, we recommend using the script version at `scripts/osft-training.py` for better logging consistency and resumption capabilities.

**Quick script usage:**
```bash
python scripts/lab_multiphase_osft_training.py \
  --base-model-path /path/to/model \
  --phase07-data-path /path/to/knowledge.jsonl \
  --phase10-data-path /path/to/skills.jsonl \
  --ckpt-output-base-dir /path/to/checkpoints
```
-->


In [None]:
# Configure logging to show only essential information
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)

# Suppress verbose logging from transformers and other libraries
logging.getLogger("transformers").setLevel(logging.WARNING)
logging.getLogger("datasets").setLevel(logging.WARNING)
logging.getLogger("torch").setLevel(logging.WARNING)

print("✅ Logging configured for notebook environment")


## Utility Functions

Let's define some helper functions for checkpoint management.
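
The `find_most_recent_checkpoint` helper in this section picks a checkpoint by filesystem timestamp. As a hedged alternative, the sketch below sorts by the number embedded in the directory name instead. It assumes that the `samples_<N>.0` directories encode how many training samples had been seen when the checkpoint was written; `checkpoint_sample_count` and `latest_checkpoint_by_samples` are hypothetical names, not part of `training_hub`.

```python
import re
from pathlib import Path

# Assumption: checkpoint directories are named "samples_<N>.0", where <N>
# is the number of training samples seen when the checkpoint was written.
_SAMPLES_RE = re.compile(r"^samples_(\d+(?:\.\d+)?)$")


def checkpoint_sample_count(checkpoint_dir):
    """Return the sample count encoded in a checkpoint directory name."""
    match = _SAMPLES_RE.match(Path(checkpoint_dir).name)
    if match is None:
        raise ValueError(f"Unexpected checkpoint name: {checkpoint_dir}")
    return float(match.group(1))


def latest_checkpoint_by_samples(checkpoint_dirs):
    """Pick the checkpoint with the largest sample count (hypothetical helper)."""
    if not checkpoint_dirs:
        raise ValueError("No checkpoints given")
    return max(checkpoint_dirs, key=checkpoint_sample_count)
```

Sorting by name is independent of filesystem timestamps, which can change when a directory is copied or touched. If the naming assumption does not hold for your run, stick with the timestamp-based helper.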

In [None]:
import glob


def find_most_recent_checkpoint(output_dir):
    """
    Find the most recent checkpoint in the training output directory.
    
    Args:
        output_dir (str): Training output directory containing hf_format/ subdirectory
        
    Returns:
        str: Path to the most recent checkpoint
        
    Raises:
        ValueError: If no checkpoints are found
    """
    # Get all checkpoint directories under hf_format
    checkpoint_pattern = os.path.join(output_dir, "hf_format", "samples_*.0")
    checkpoint_dirs = glob.glob(checkpoint_pattern)
    
    if not checkpoint_dirs:
        raise ValueError(f"No checkpoints found in {os.path.join(output_dir, 'hf_format')}")
    
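    # Pick the checkpoint with the latest ctime. Note that ctime is the creation
    # time on Windows but the last metadata change on Unix-like systems.
    most_recent_checkpoint = max(checkpoint_dirs, key=os.path.getctime)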
    
    return most_recent_checkpoint

def save_best_model(source_directory):
    """This function copies the final model from the checkpoints dir to the base output dir for easier access.

    Args:
        source_directory: Path to the recent checkpoint
    """
    FINAL_FINE_TUNED_MODEL_PATH = OUTPUT_DIR / "fine_tuned_model" / BASE_MODEL_NAME.split('/')[-1]
    FINAL_FINE_TUNED_MODEL_PATH.mkdir(exist_ok=True,parents=True) # Create the directory if not available

    # Iterate through all files/folders in source and copy them
    for item in os.listdir(source_directory):
        src_path = os.path.join(source_directory, item)
        dst_path = os.path.join(FINAL_FINE_TUNED_MODEL_PATH, item)

        # Copy directories or files appropriately
        if os.path.isdir(src_path):
            shutil.copytree(src_path, dst_path, dirs_exist_ok=True)
        else:
            shutil.copy2(src_path, dst_path)

    print(f"✅ Final finetuned model copied\n\t Path :{FINAL_FINE_TUNED_MODEL_PATH}")
    return FINAL_FINE_TUNED_MODEL_PATH

def cleanup_model_memory(*objects):
    """
    Clean up GPU memory by deleting arbitrary objects and clearing CUDA cache.
    
    Args:
        *objects: Variable number of objects to clean up (models, tokenizers, etc.)
    """
    import torch
    import gc
    
    # Delete all provided objects
    # Delete objects from global namespace if they exist there
    for obj in objects:
        if obj is not None:
            # Find the variable name in globals that references this object
            for var_name, var_obj in list(globals().items()):
                if var_obj is obj:
                    print(f"🗑️ Deleting global variable: {var_name}")
                    del globals()[var_name]
                    break
            # Also delete the local reference
            del obj
    
    # Force garbage collection
    gc.collect()
    
    # Clear CUDA cache if CUDA is available
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    
    print("✅ Model memory cleaned up and CUDA cache cleared")


print("✅ Checkpoint utility functions defined")



## Configuration

For this example, I will be running training on an 8xA100 box, but these hyperparameters can be adjusted for any machine, provided it is capable of running OSFT.


🚨 Ensure you have set `NUM_GPUS` in the cell below to the number of GPUs available on the system
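
If you are unsure how many GPUs are visible, a small check like the following can help before you set `NUM_GPUS`. This is a minimal sketch: `detect_gpu_count` and `resolve_num_gpus` are hypothetical helpers, and the only library calls used are `torch.cuda.is_available()` and `torch.cuda.device_count()`.

```python
def detect_gpu_count():
    """Return the number of CUDA devices visible to PyTorch (0 if none)."""
    try:
        import torch
    except ImportError:
        return 0  # PyTorch is not installed, so no GPUs can be used
    return torch.cuda.device_count() if torch.cuda.is_available() else 0


def resolve_num_gpus(requested, available):
    """Clamp the requested GPU count to what is available (hypothetical helper)."""
    if requested < 1:
        raise ValueError("requested must be at least 1")
    if available < 1:
        raise RuntimeError("No CUDA GPUs detected")
    return min(requested, available)
```

For example, `NUM_GPUS = resolve_num_gpus(8, detect_gpu_count())` would keep the 8-GPU default on an 8xA100 box and fall back to a smaller count elsewhere.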

In [None]:
################################################################################
# 🏋️‍♀️ Training Hyperparameters                                                  #
################################################################################
# Important for OSFT
UNFREEZE_RANK_RATIO = 0.25 

# Standard parameters
BATCH_SIZE = 128
LEARNING_RATE = 5e-6
NUM_EPOCHS=2
LR_SCHEDULER="cosine"
WARMUP_STEPS=0
SEED=42


################################################################################
# 🏎️ Performance Hyperparameters                                               #
################################################################################
USE_LIGER = True
MAX_TOKENS_PER_GPU=10_000
MAX_SEQ_LEN=8192

################################################################################
# 💾 Checkpointing Settings                                                    #
################################################################################
# Here we only want to save the very last checkpoint
SAVE_FINAL_CHECKPOINT = True
CHECKPOINT_AT_EPOCH = False 

################################################################################
# 🔥 TORCHRUN SETTINGS                                                         #
################################################################################
NUM_GPUS=8
NUM_NODES=1
NODE_RANK=0
RDZV_ID=23
RDZV_ENDPOINT='localhost:1738'


print("⚙️  Training Hyperparameters")
print("=" * 50)
print(f"Base Model: {BASE_MODEL_NAME}")
print(f"Dataset Path: {DATASET_PATH}")
print(f"Checkpoints Path: {CHECKPOINTS_PATH}")
print(f"Data Output Path: {DATA_OUTPUT_PATH}")
print()
print(f"Unfreeze Rank Ratio: {UNFREEZE_RANK_RATIO}")
print(f"Batch Size: {BATCH_SIZE}")
print(f"Learning Rate: {LEARNING_RATE}")
print(f"Number of Epochs: {NUM_EPOCHS}")
print(f"LR Scheduler: {LR_SCHEDULER}")
print(f"Warmup Steps: {WARMUP_STEPS}")
print(f"Seed: {SEED}")
print()
print(f"Use Liger: {USE_LIGER}")
print(f"Max Tokens per GPU: {MAX_TOKENS_PER_GPU:,}")
print(f"Max Sequence Length: {MAX_SEQ_LEN:,}")
print()
print(f"Save Final Checkpoint: {SAVE_FINAL_CHECKPOINT}")
print(f"Checkpoint at Epoch: {CHECKPOINT_AT_EPOCH}")
print()
print(f"Distributed: {NUM_GPUS} GPUs × {NUM_NODES} nodes = {NUM_GPUS * NUM_NODES} total GPUs")
print(f"Node Rank: {NODE_RANK}")
print(f"RDZV ID: {RDZV_ID}")
print(f"RDZV Endpoint: {RDZV_ENDPOINT}")

## Clean up GPU Memory

Make sure to clean up all excess memory used earlier so we can train without OOM errors.

In [None]:
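import gc

import torch

gc.collect()
# Guard the CUDA calls so this cell does not fail on a machine without GPUs
if torch.cuda.is_available():
    print("💾 Clearing CUDA cache...")
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    torch.cuda.synchronize()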


## Fine-tuning the model using OSFT

Since prompting the model doesn't work, our next step is to modify the model itself.
Normally, this wouldn't be a great solution, since many fine-tuning methods can cause the model to forget important capabilities.
However, OSFT lets us adjust the less critical parts of the model while keeping the crucial parts intact, which is perfect for our use case 😃.
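
To build intuition for "keeping the crucial parts intact", here is a toy sketch of the idea the name suggests: an update is made orthogonal to a set of protected directions by removing its components along them. This is an illustration under simplifying assumptions (hand-picked orthonormal directions, plain Python lists), not the `training_hub` implementation; `dot` and `project_out` are hypothetical names.

```python
def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def project_out(update, frozen_directions):
    """Remove from `update` any component along the frozen directions.

    `frozen_directions` must be a list of orthonormal vectors.
    """
    result = list(update)
    for direction in frozen_directions:
        coeff = dot(result, direction)
        result = [r - coeff * d for r, d in zip(result, direction)]
    return result


# Toy example: protect the first two axes of a 3-D space.
frozen = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
raw_update = [0.5, -0.2, 0.3]
safe_update = project_out(raw_update, frozen)
print(safe_update)  # only the component along the third axis remains
```

After projection, the update has zero inner product with every protected direction, so moving along it leaves the model's behavior along those directions unchanged in this toy setting.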


## Preparing to train

We have to unset the memory settings we enabled earlier so that we do not run into issues when training.

In [None]:
# IMPORTANT: Unset the env variables we set earlier so they do not interfere with training
import os

# pop() with a default removes the variable if present and is a no-op otherwise
os.environ.pop("CUDA_LAUNCH_BLOCKING", None)
os.environ.pop("PYTORCH_CUDA_ALLOC_CONF", None)


## Training with OSFT

With our hyperparameters configured, now we launch a training job and sit back while it enhances our new model 😎🍿
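
Because a training job can take a long time to fail, a quick preflight check on the dataset may save a wasted launch. The sketch below only assumes what the `.jsonl` extension implies, namely one JSON object per line; it does not check any particular record schema. `validate_jsonl` is a hypothetical helper.

```python
import json
from pathlib import Path


def validate_jsonl(path, max_errors=5):
    """Check that every non-empty line of a JSONL file parses as a JSON object.

    Returns (num_records, errors), where errors is a list of
    (line_number, message) tuples.
    """
    num_records = 0
    errors = []
    with Path(path).open("r", encoding="utf-8") as handle:
        for line_number, line in enumerate(handle, start=1):
            if not line.strip():
                continue  # tolerate blank lines
            try:
                record = json.loads(line)
            except json.JSONDecodeError as exc:
                errors.append((line_number, str(exc)))
            else:
                if isinstance(record, dict):
                    num_records += 1
                else:
                    errors.append((line_number, "expected a JSON object"))
            if len(errors) >= max_errors:
                break  # stop early; the file needs fixing anyway
    return num_records, errors
```

Calling `validate_jsonl(DATASET_PATH)` before the training cell returns the record count and the first few problems, if any.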

In [None]:
print("🚀 Starting OSFT Continual Learning Training")
print("=" * 60)
print(f"Starting from: {BASE_MODEL_NAME}")
print(f"Training data: {DATASET_PATH}")
print(f"Output directory: {CHECKPOINTS_PATH}")
print(f"Unfreeze ratio: {UNFREEZE_RANK_RATIO}")
print()

# Capture output to prevent notebook crashes
output_buffer = StringIO()
error_buffer = StringIO()

training_start_time = time.time()

try:
    with redirect_stdout(output_buffer), redirect_stderr(error_buffer):
        # OSFT training
        training_result = osft(
            # Model and data
            model_path=str(BASE_MODEL_PATH),
            data_path=str(DATASET_PATH),
            ckpt_output_dir=str(CHECKPOINTS_PATH),
            
            # OSFT-specific
            unfreeze_rank_ratio=UNFREEZE_RANK_RATIO,
            
            # Training parameters
            num_epochs=NUM_EPOCHS,
            effective_batch_size=BATCH_SIZE,
            learning_rate=LEARNING_RATE,
            max_seq_len=MAX_SEQ_LEN,
            max_tokens_per_gpu=MAX_TOKENS_PER_GPU,
            
            # Data processing
            data_output_dir=str(DATA_OUTPUT_PATH),
            warmup_steps=WARMUP_STEPS,
            
            # Optimization
            use_liger=USE_LIGER,
            seed=SEED,
            lr_scheduler=LR_SCHEDULER,
            
            # Checkpointing
            checkpoint_at_epoch=CHECKPOINT_AT_EPOCH,
            save_final_checkpoint=SAVE_FINAL_CHECKPOINT,
            
            # Distributed training
            nproc_per_node=NUM_GPUS,
            nnodes=NUM_NODES,
            node_rank=NODE_RANK,
            rdzv_id=RDZV_ID,
            rdzv_endpoint=RDZV_ENDPOINT,
        )
    
    training_duration = time.time() - training_start_time

    # Find the most recent checkpoint produced by this training run
    final_checkpoint = find_most_recent_checkpoint(CHECKPOINTS_PATH)
    print("\n\n\n✅ Model Training Completed")
    print("="*60)
    print(f"📁 Final model checkpoint: {final_checkpoint}")

    fine_tuned_model_path = save_best_model(final_checkpoint)
    
    print(f"✅ OSFT training completed in {training_duration/3600:.2f} hours!")
    print()
    print("📊 Training Summary (confirm in the Evaluation notebook):")
    print("  • Base model capabilities: expected to be preserved")
    print("  • New knowledge integrated: training run finished")
    print("  • Continual learning: verify with evaluation")
    
except Exception as e:
    print(f"❌ OSFT training failed: {e}")
    print("\nError details:")
    print(error_buffer.getvalue())
    raise


## Best Practices: Choosing Your Unfreeze Rank Ratio

The `unfreeze_rank_ratio` is your key control for balancing learning vs. preservation. Here's how to choose the right value for your use case.
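
As a rough sketch of the sequential-training rule from the guide in this section (reduce the ratio by 0.05 per phase, and stay between 0.1 and 0.5), the helper below generates one ratio per phase. `sequential_unfreeze_ratios` is a hypothetical helper; clamping at the lower bound is an assumption about how to combine the two rules.

```python
MIN_RATIO = 0.1  # guidance: lower values are too restrictive for learning
MAX_RATIO = 0.5  # guidance: higher values risk forgetting


def sequential_unfreeze_ratios(start_ratio, num_phases, step=0.05):
    """Return one unfreeze_rank_ratio per phase, reduced by `step` each time."""
    if not MIN_RATIO <= start_ratio <= MAX_RATIO:
        raise ValueError(
            f"start_ratio must be between {MIN_RATIO} and {MAX_RATIO}"
        )
    if num_phases < 1:
        raise ValueError("num_phases must be at least 1")
    ratios = []
    for phase in range(num_phases):
        # Assumption: never drop below MIN_RATIO, even in late phases
        ratio = max(MIN_RATIO, start_ratio - phase * step)
        ratios.append(round(ratio, 2))  # round away floating-point noise
    return ratios


print(sequential_unfreeze_ratios(0.35, 3))  # [0.35, 0.3, 0.25]
```

Each value can then be passed as `unfreeze_rank_ratio` for the corresponding training phase.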


In [None]:
# =============================================================================
# PRACTICAL GUIDE: UNFREEZE RANK RATIO SELECTION
# =============================================================================

print("📚 Choosing the Right Unfreeze Rank Ratio")
print("=" * 60)
print()

# Scenario 1: Small behavior tweaks
print("1️⃣  **Small Behavior Tweaks**")
print("   When to use: Adjusting specific model behaviors without major changes")
print("   Examples: Output formatting, response style, minor corrections")
print()
print("   🎯 Strategy: Start SMALL")
print("   • unfreeze_rank_ratio = 0.1 - 0.15")
print("   • Why: Minimal modification preserves most model behavior")
print("   • Result: Targeted changes without broad impact")
print()

# Scenario 2: Major new capabilities
print("2️⃣  **Major New Capabilities**")
print("   When to use: Adding entirely new skills or knowledge domains")
print("   Examples: New language, coding ability, domain expertise")
print()
print("   🎯 Strategy: Start STANDARD")
print("   • unfreeze_rank_ratio = 0.3 - 0.35")
print("   • Why: More freedom to learn complex new patterns")
print("   • Result: Robust new capabilities while preserving base model")
print()

# Scenario 3: Sequential task learning
print("3️⃣  **Sequential Task Learning (Task 1 → Task 2)**")
print("   When to use: Training on multiple tasks in sequence")
print("   Examples: Knowledge → Skills, General → Specialized")
print()
print("   🎯 Strategy: PROGRESSIVELY REDUCE")
print("   • Task 1: unfreeze_rank_ratio = 0.35")
print("   • Task 2: unfreeze_rank_ratio = 0.30 (reduce by 0.05)")
print("   • Task 3: unfreeze_rank_ratio = 0.25 (reduce by 0.05)")
print("   • Why: Each reduction preserves previous learning")
print("   • Result: Accumulate capabilities without forgetting")
print()

# Golden rules
print("🌟 Golden Rules:")
print("-" * 50)
print("• Never go below 0.1 (too restrictive for learning)")
print("• Never go above 0.5 (risks forgetting)")
print("• When in doubt, start smaller - you can always increase")
print("• Test preservation after each training phase")
print()

# Quick reference table
print("📊 Quick Reference:")
print("-" * 50)
print("| Use Case               | Recommended Ratio | Notes                    |")
print("|------------------------|-------------------|--------------------------|")
print("| Format tweaks          | 0.10 - 0.15       | Minimal changes          |")
print("| Style adjustments      | 0.15 - 0.20       | Moderate refinement      |")
print("| New domain knowledge   | 0.25 - 0.35       | Major capability         |")
print("| New task type          | 0.30 - 0.35       | Significant learning     |")
print("| Sequential phase 2+    | Previous - 0.05   | Preserve prior phases    |")


## Next Steps: Apply OSFT to Your Use Case


In [None]:
# =============================================================================
# NEXT STEPS AND PRACTICAL APPLICATION
# =============================================================================

print("🚀 Ready to Use OSFT for Your Own Tasks!")
print("=" * 60)
print()

# Code template for your use case
print("💻 Quick Start Template:")
print("-" * 50)
print("```python")
print("from training_hub import osft")
print("")
print("# Your OSFT training configuration")
print("result = osft(")
print("    # Model and data")
print("    model_path='meta-llama/Meta-Llama-3-8B-Instruct',")
print("    data_path='your_task_data.jsonl',")
print("    ckpt_output_dir='./checkpoints/your_experiment',")
print("    ")
print("    # Choose based on your use case:")
print("    # - Small tweaks: 0.10-0.15")
print("    # - Major capability: 0.30-0.35") 
print("    # - Sequential training: reduce by 0.05 each phase")
print("    unfreeze_rank_ratio=0.2,  # Adjust based on guidance above")
print("    ")
print("    # Standard training parameters")
print("    num_epochs=1,")
print("    effective_batch_size=64,")
print("    learning_rate=5e-6,")
print("    max_seq_len=8192,")
print("    max_tokens_per_gpu=10000,")
print(")")
print("```")
print()

# Testing preservation
print("🧪 Testing Model Preservation:")
print("-" * 50)
print("After training, always test that original capabilities remain:")
print()
print("1. Test on original model's strong areas (general knowledge, reasoning)")
print("2. Test on your newly trained capability")
print("3. Compare outputs to ensure both work well")
print()

# Common use cases
print("💡 Common OSFT Use Cases:")
print("-" * 50)
print("• Adding structured output formats (JSON, XML, tables)")
print("• Teaching domain-specific knowledge without losing general ability")
print("• Adding new language support while preserving others")
print("• Sequential skill building (basic → intermediate → advanced)")
print("• Customizing response style without breaking functionality")
print()

# Final recommendations
print("🎯 Final Recommendations:")
print("-" * 50)
print("1. Start with our recommended ratios - they work well")
print("2. Use small datasets first to test your approach")
print("3. Always evaluate preservation alongside new capabilities")
print("4. For production: use the script version for better control")
print("5. For sequential training - produce multiple candidate models and advance only the best performing one to the next phase")
print()
print("Happy training with OSFT - where your model learns without forgetting! 🚀")


## Next Steps

Model training has been completed successfully!
You can now proceed to the [Evaluation](../05_Evaluation/Evaluation.ipynb) notebook to assess the performance of your fine-tuned model.
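
Before opening the evaluation notebook, a quick smoke test along the lines of the preservation checklist above can catch obvious regressions. The sketch below is model-agnostic: `generate` is a placeholder for however you call the model (any function that maps a prompt string to a response string), and `is_valid_json` and `smoke_test` are hypothetical helpers.

```python
import json


def is_valid_json(text):
    """Return True if `text` parses as JSON."""
    try:
        json.loads(text)
    except (TypeError, ValueError):  # JSONDecodeError subclasses ValueError
        return False
    return True


def smoke_test(generate, general_prompts, json_prompts):
    """Run a generation callable over two prompt sets and summarize the results.

    `general_prompts` probe existing capabilities; `json_prompts` probe the
    newly trained JSON-output behavior.
    """
    general_outputs = {prompt: generate(prompt) for prompt in general_prompts}
    json_ok = sum(is_valid_json(generate(prompt)) for prompt in json_prompts)
    return {
        "general_outputs": general_outputs,  # inspect or compare by hand
        "json_valid": json_ok,
        "json_total": len(json_prompts),
    }
```

Running `smoke_test` once with the base model and once with the fine-tuned model gives two summaries to compare side by side. It is not a substitute for the full evaluation.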