# OSFT Dataset Scaling Guide: Adapting Hyperparameters to Data Size

This notebook demonstrates how to adapt OSFT (Orthogonal Subspace Fine-Tuning) hyperparameters to your dataset size. OSFT can be used on datasets of any scale, from small domain-specific sets to large instruction-tuning datasets, so knowing how to scale your hyperparameters with the data is crucial for good performance.

## Key Principle: Scale Your Batch Size with Your Data

One of the most important hyperparameters to adjust based on dataset size is the **batch size**. Larger datasets generally benefit from larger batch sizes for several reasons:

1. **Training Efficiency**: Larger batches better utilize GPU resources with bigger datasets
2. **Gradient Stability**: More samples per update provide more stable gradient estimates
3. **Convergence**: A batch size matched to the dataset keeps the number of updates per epoch in a workable range, neither too few to learn from nor so many that training drags
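
To make the gradient-stability point concrete, here is a minimal simulation sketch. It assumes each per-sample gradient is a true value of 1.0 plus unit Gaussian noise, which is a toy model rather than a description of real gradients, and `batch_mean_std` is a hypothetical helper, not part of any OSFT tooling. Under that assumption, the spread of the batch-averaged gradient shrinks roughly as 1/sqrt(batch_size).

```python
import random
import statistics


def batch_mean_std(batch_size, n_trials=500, seed=0):
    """Standard deviation of the batch-averaged toy gradient over many batches."""
    rng = random.Random(seed)
    means = []
    for _ in range(n_trials):
        # Toy model (assumption): true gradient 1.0 plus unit Gaussian noise per sample.
        batch = [1.0 + rng.gauss(0.0, 1.0) for _ in range(batch_size)]
        means.append(sum(batch) / batch_size)
    return statistics.pstdev(means)


for bs in (16, 128, 1024):
    print(f"batch={bs:<5} std of batch-mean gradient ~ {batch_mean_std(bs):.3f}")
```

Going from 16 to 1024 samples per update is a 64x larger batch, so the noise in this toy model drops by roughly 8x.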

**⚠️ Important**: The configurations shown here are **illustrative examples**, not prescriptions. Finding the optimal hyperparameters for your specific use case requires experimentation and iterative refinement.


## Why Batch Size Matters at Different Scales

### Small Datasets (1K samples)
- **Risk**: A large batch covers the entire dataset in just a few steps (a batch of 512 sees all 1K samples in two updates)
- **Solution**: Use smaller batch sizes (e.g., 16) for more gradient updates per epoch
- **Benefit**: Model gets more opportunities to learn from limited data

### Medium Datasets (10K samples)  
- **Balance**: Need efficiency without overfitting
- **Solution**: Moderate batch sizes (e.g., 128) provide good gradient estimates
- **Benefit**: Efficient training with stable convergence

### Large Datasets (100K+ samples)
- **Challenge**: Training time becomes a major factor
- **Solution**: Large batch sizes (e.g., 1024) maximize throughput
- **Benefit**: Faster training with stable gradients from diverse samples
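
The trade-off across the three scales comes down to simple arithmetic on updates per epoch. The sketch below uses a hypothetical helper, `steps_per_epoch`, and assumes the final partial batch is kept unless `drop_last=True`; with `drop_last=True` it matches the floor division used in the examples later in this notebook.

```python
import math


def steps_per_epoch(num_samples, effective_batch_size, drop_last=False):
    """Optimizer steps needed to see every sample once."""
    if drop_last:
        # Discard the final partial batch (floor division).
        return num_samples // effective_batch_size
    # Keep the final partial batch as its own step.
    return math.ceil(num_samples / effective_batch_size)


for n in (1_000, 10_000, 100_000):
    row = {bs: steps_per_epoch(n, bs) for bs in (16, 128, 1024)}
    print(f"{n:>7,} samples -> steps/epoch by batch size: {row}")
```

A batch size of 1024 on 1K samples yields a single update per epoch, while a batch size of 16 on 100K samples yields 6,250, which illustrates why neither extreme is a good fit.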


## Common Configuration

First, let's define the parameters that remain constant across different dataset sizes.


In [1]:
# Model configuration - using Llama 3.1 8B Instruct
MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# OSFT-specific parameters (constant across dataset sizes)
UNFREEZE_RANK_RATIO = 0.3  # Balanced preservation vs adaptation
MAX_SEQ_LEN = 8192         # Llama 3.1 supports up to 128K, but 8K is practical
LEARNING_RATE = 5e-6       # Standard learning rate for fine-tuning
NUM_EPOCHS = 3             # Adjust based on convergence

# Hardware constraints (adjust based on your GPU)
MAX_TOKENS_PER_GPU = 10000  # For A100 40GB or similar

# Distributed training setup (single node, 8 GPUs)
NPROC_PER_NODE = 8

print("Common Configuration:")
print(f"  Model: {MODEL_PATH}")
print(f"  OSFT Unfreeze Ratio: {UNFREEZE_RANK_RATIO}")
print(f"  Max Sequence Length: {MAX_SEQ_LEN:,}")
print(f"  Learning Rate: {LEARNING_RATE}")
print(f"  Training Epochs: {NUM_EPOCHS}")
print(f"  GPUs: {NPROC_PER_NODE}")


Common Configuration:
  Model: meta-llama/Meta-Llama-3.1-8B-Instruct
  OSFT Unfreeze Ratio: 0.3
  Max Sequence Length: 8,192
  Learning Rate: 5e-06
  Training Epochs: 3
  GPUs: 8


## Example 1: Small Dataset (1K samples)

For small, specialized datasets, we use smaller batch sizes to ensure the model sees sufficient gradient updates.


In [2]:
# Configuration for 1K sample dataset
small_dataset_config = {
    "dataset_size": "1K samples",
    "data_path": "/path/to/your/small_dataset_1k_samples.jsonl",  # Replace with your path
    "effective_batch_size": 16,  # Small batch size for more gradient updates
    "warmup_steps": 15,          # ~8% of the ~186 total steps computed below
    "use_case": "Domain-specific terminology or specialized knowledge"
}

# Calculate training dynamics
steps_per_epoch_1k = 1000 // small_dataset_config["effective_batch_size"]
total_steps_1k = steps_per_epoch_1k * NUM_EPOCHS

print("🔬 Small Dataset Configuration (1K samples):")
print(f"  Effective Batch Size: {small_dataset_config['effective_batch_size']}")
print(f"  Steps per Epoch: ~{steps_per_epoch_1k}")
print(f"  Total Training Steps: ~{total_steps_1k}")
print(f"  Use Case: {small_dataset_config['use_case']}")
print()
print("💡 Rationale: Small batch size ensures sufficient gradient updates")
print("   despite limited data, helping the model learn nuanced patterns.")


🔬 Small Dataset Configuration (1K samples):
  Effective Batch Size: 16
  Steps per Epoch: ~62
  Total Training Steps: ~186
  Use Case: Domain-specific terminology or specialized knowledge

💡 Rationale: Small batch size ensures sufficient gradient updates
   despite limited data, helping the model learn nuanced patterns.


## Example 2: Medium Dataset (10K samples)

For medium-sized datasets, we increase the batch size to balance training efficiency with learning effectiveness.


In [3]:
# Configuration for 10K sample dataset
medium_dataset_config = {
    "dataset_size": "10K samples",
    "data_path": "/path/to/your/medium_dataset_10k_samples.jsonl",  # Replace with your path
    "effective_batch_size": 128,  # Moderate batch size for efficiency
    "warmup_steps": 20,           # ~8-9% of the ~234 total steps computed below
    "use_case": "Domain adaptation or moderate-scale instruction tuning"
}

# Calculate training dynamics
steps_per_epoch_10k = 10000 // medium_dataset_config["effective_batch_size"]
total_steps_10k = steps_per_epoch_10k * NUM_EPOCHS

print("📊 Medium Dataset Configuration (10K samples):")
print(f"  Effective Batch Size: {medium_dataset_config['effective_batch_size']}")
print(f"  Steps per Epoch: ~{steps_per_epoch_10k}")
print(f"  Total Training Steps: ~{total_steps_10k}")
print(f"  Use Case: {medium_dataset_config['use_case']}")
print()
print("💡 Rationale: Moderate batch size balances training efficiency")
print("   with gradient quality, suitable for most domain adaptation tasks.")


📊 Medium Dataset Configuration (10K samples):
  Effective Batch Size: 128
  Steps per Epoch: ~78
  Total Training Steps: ~234
  Use Case: Domain adaptation or moderate-scale instruction tuning

💡 Rationale: Moderate batch size balances training efficiency
   with gradient quality, suitable for most domain adaptation tasks.


## Example 3: Large Dataset (100K samples)

For large-scale datasets, we use larger batch sizes to maximize training efficiency and throughput.


In [4]:
# Configuration for 100K sample dataset
large_dataset_config = {
    "dataset_size": "100K samples",
    "data_path": "/path/to/your/large_dataset_100k_samples.jsonl",  # Replace with your path
    "effective_batch_size": 1024,  # Large batch size for efficiency
    "warmup_steps": 25,            # ~8-9% of the ~291 total steps computed below; warmup must stay below total steps
    "use_case": "Large-scale instruction tuning or comprehensive domain coverage"
}

# Calculate training dynamics
steps_per_epoch_100k = 100000 // large_dataset_config["effective_batch_size"]
total_steps_100k = steps_per_epoch_100k * NUM_EPOCHS

print("📈 Large Dataset Configuration (100K samples):")
print(f"  Effective Batch Size: {large_dataset_config['effective_batch_size']}")
print(f"  Steps per Epoch: ~{steps_per_epoch_100k}")
print(f"  Total Training Steps: ~{total_steps_100k}")
print(f"  Use Case: {large_dataset_config['use_case']}")
print()
print("💡 Rationale: Large batch size maximizes GPU utilization and")
print("   training throughput while maintaining stable gradients.")


📈 Large Dataset Configuration (100K samples):
  Effective Batch Size: 1024
  Steps per Epoch: ~97
  Total Training Steps: ~291
  Use Case: Large-scale instruction tuning or comprehensive domain coverage

💡 Rationale: Large batch size maximizes GPU utilization and
   training throughput while maintaining stable gradients.


## Batch Size Scaling Summary

Here's a visual summary of how batch size scales with dataset size:


In [5]:
print("📊 Batch Size Scaling Summary:")
print("="*60)
print(f"{'Dataset Size':<15} {'Batch Size':<12} {'Steps/Epoch':<12} {'Total Steps':<12}")
print("="*60)
# Read batch sizes from the configs above so the table cannot drift from them
print(f"{'1K samples':<15} {small_dataset_config['effective_batch_size']:<12} {steps_per_epoch_1k:<12} {total_steps_1k:<12}")
print(f"{'10K samples':<15} {medium_dataset_config['effective_batch_size']:<12} {steps_per_epoch_10k:<12} {total_steps_10k:<12}")
print(f"{'100K samples':<15} {large_dataset_config['effective_batch_size']:<12} {steps_per_epoch_100k:<12} {total_steps_100k:<12}")
print("="*60)
print()
print("📈 Scaling Pattern:")
print("   As dataset size increases 10x → batch size increases ~8x")
print("   This maintains a reasonable number of gradient updates")


📊 Batch Size Scaling Summary:
============================================================
Dataset Size    Batch Size   Steps/Epoch  Total Steps 
============================================================
1K samples      16           62           186         
10K samples     128          78           234         
100K samples    1024         97           291         
============================================================

📈 Scaling Pattern:
   As dataset size increases 10x → batch size increases ~8x
   This maintains a reasonable number of gradient updates
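
Each per-dataset dictionary only becomes a full training configuration once it is combined with the common settings. The sketch below shows one way to merge them; `build_run_config` and the lowercase key names are hypothetical, chosen to mirror this notebook's constants, and whether your training entry point accepts exactly these names is an assumption you should check. The dataset dictionary is abbreviated for brevity.

```python
# Lowercase keys mirror the notebook's constants (assumed names, verify
# against the training entry point you actually use).
COMMON_CONFIG = {
    "model_path": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "unfreeze_rank_ratio": 0.3,
    "max_seq_len": 8192,
    "learning_rate": 5e-6,
    "num_epochs": 3,
    "max_tokens_per_gpu": 10000,
    "nproc_per_node": 8,
}

# Keys that only document the example and are not training settings.
DESCRIPTIVE_KEYS = {"dataset_size", "use_case"}


def build_run_config(common, dataset_config):
    """Merge common settings with dataset-specific ones, dropping descriptive keys."""
    run = dict(common)  # copy so the shared dict is never mutated
    run.update({k: v for k, v in dataset_config.items() if k not in DESCRIPTIVE_KEYS})
    return run


small = {
    "dataset_size": "1K samples",
    "data_path": "/path/to/your/small_dataset_1k_samples.jsonl",
    "effective_batch_size": 16,
    "use_case": "Domain-specific terminology or specialized knowledge",
}

run_config = build_run_config(COMMON_CONFIG, small)
for key, value in run_config.items():
    print(f"{key}: {value}")
```

Keeping the shared values in one place means only `effective_batch_size`, `data_path`, and similar dataset-dependent settings change between runs.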


## Other Hyperparameters to Consider

While batch size is a primary scaling factor, other hyperparameters may also need adjustment:

### 1. Learning Rate
- **Small datasets**: Consider an LR at the low end of the fine-tuning range (e.g., 5e-6 to 1e-5) to avoid overfitting
- **Large datasets**: Can use standard or slightly higher LR (e.g., 2e-5 to 5e-5)
- **With large batches**: May need to scale the LR up with the batch size, either by the square root of the batch-size ratio or linearly
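
The two scaling rules mentioned for large batches can be written down directly. This is a minimal sketch: `scale_learning_rate` is a hypothetical helper, and treating 5e-6 at a batch size of 128 as the reference point is an assumption made only for illustration.

```python
import math


def scale_learning_rate(base_lr, base_batch_size, new_batch_size, rule="sqrt"):
    """Scale a reference learning rate when the effective batch size changes."""
    ratio = new_batch_size / base_batch_size
    if rule == "linear":
        return base_lr * ratio
    if rule == "sqrt":
        return base_lr * math.sqrt(ratio)
    raise ValueError(f"unknown rule: {rule!r}")


# Assumed reference point: LR 5e-6 tuned at batch size 128, moving to 1024.
for rule in ("sqrt", "linear"):
    lr = scale_learning_rate(5e-6, 128, 1024, rule=rule)
    print(f"{rule:<6} scaling -> {lr:.2e}")
```

For an 8x larger batch, square-root scaling gives about 1.4e-5 and linear scaling gives 4e-5. Treat either value as a candidate to validate, not a guaranteed setting.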

### 2. Number of Epochs
- **Small datasets**: More epochs (3-5) might be beneficial
- **Large datasets**: Fewer epochs (1-3) often sufficient
- Monitor validation metrics to avoid overfitting
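
One way to act on validation monitoring is a patience-based stopping check. The sketch below is an assumption about how you might implement it yourself: `should_stop` is a hypothetical helper, and the loss values are made-up illustrative numbers, not measured results.

```python
def should_stop(val_losses, patience=2, min_delta=0.0):
    """True when none of the last `patience` evaluations beat the best earlier loss."""
    if len(val_losses) <= patience:
        return False  # not enough history to judge yet
    best_before = min(val_losses[:-patience])
    recent = val_losses[-patience:]
    # An evaluation counts as an improvement only if it beats the earlier
    # best by at least `min_delta`.
    return all(loss > best_before - min_delta for loss in recent)


# Illustrative (made-up) validation losses, one per epoch.
still_improving = [1.00, 0.80, 0.75, 0.74]
overfitting = [1.00, 0.80, 0.75, 0.76, 0.78]

print(should_stop(still_improving))
print(should_stop(overfitting))
```

With a small dataset and 3-5 epochs, evaluating once per epoch and stopping when this check fires is a simple guard against training past the point of best validation loss.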

### 3. Warmup Steps
- Scale with batch size: larger batches often benefit from longer warmup
- Rule of thumb: 5-10% of total training steps
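
The 5-10% rule of thumb is easy to apply once the total step count is known. In this sketch `warmup_steps_for` is a hypothetical helper, and the step totals are the ones computed in the three examples above.

```python
def warmup_steps_for(total_steps, fraction=0.05, minimum=1):
    """Warmup length as a fraction of total training steps."""
    if not 0.0 <= fraction <= 1.0:
        raise ValueError("fraction must be between 0 and 1")
    return max(minimum, round(total_steps * fraction))


# Total steps from the 1K, 10K, and 100K examples in this notebook.
for label, total in (("1K", 186), ("10K", 234), ("100K", 291)):
    low = warmup_steps_for(total, 0.05)
    high = warmup_steps_for(total, 0.10)
    print(f"{label:<5} total={total:<4} warmup range: {low}-{high} steps")
```

Because the batch size grows with the dataset in these examples, the total step count stays in the low hundreds, and so does the warmup range. A warmup longer than the total step count would mean the learning rate never reaches its target value.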

### 4. OSFT-Specific: unfreeze_rank_ratio
- Generally consistent across dataset sizes (0.25-0.35)
- Depends more on task complexity than dataset size
- May slightly increase for very large, diverse datasets



## Important Considerations

### ⚠️ These Are Starting Points, Not Rules

The configurations shown are **illustrative examples** to demonstrate the scaling principle. Your optimal settings will depend on:

1. **Data Characteristics**:
   - Domain complexity
   - Sample diversity
   - Average sequence length
   - Task difficulty

2. **Model Factors**:
   - Model size (7B, 13B, 70B, etc.)
   - Pre-training quality
   - Current capabilities

3. **Hardware Constraints**:
   - GPU memory
   - Number of GPUs
   - Training time budget

4. **Quality Requirements**:
   - Acceptable performance threshold
   - Preservation vs adaptation balance
   - Downstream task needs

### 🔬 Experimentation is Key

Always validate your hyperparameters through:
- Small-scale experiments first
- Monitoring training metrics
- Validation set performance
- Downstream task evaluation


## Practical Experimentation Strategy

Here's a suggested approach for finding optimal hyperparameters:


🔬 **Hyperparameter Tuning Strategy:**

1. **Start with a subset of your data (10-20%)**
2. **Try 3 different batch sizes:**
   - Conservative: `dataset_size / 100`
   - Moderate: `dataset_size / 50`
   - Aggressive: `dataset_size / 25`

3. **For each batch size, monitor:**
   - Training loss curve
   - Validation performance
   - Training time per epoch
   - GPU memory utilization

4. **Select the configuration that balances:**
   - Best validation performance
   - Reasonable training time
   - Stable convergence

5. **Scale to full dataset with chosen parameters**
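
The steps above can be sketched as a small search loop. Everything here is illustrative: `candidate_batch_sizes`, `run_grid`, and `placeholder_train_and_evaluate` are hypothetical names, the batch sizes are derived from the size of the data passed in (here a 2,000-sample subset, which is an assumption about how to read `dataset_size`), and the scoring function returns synthetic values so the sketch runs. Replace it with a real training and validation run.

```python
def candidate_batch_sizes(dataset_size):
    """Three starting points derived from the size of the data being tuned on."""
    divisors = {"conservative": 100, "moderate": 50, "aggressive": 25}
    return {label: max(1, dataset_size // d) for label, d in divisors.items()}


def run_grid(dataset_size, train_and_evaluate):
    """Try each candidate and return the label with the lowest validation loss."""
    results = {}
    for label, batch_size in candidate_batch_sizes(dataset_size).items():
        results[label] = {
            "batch_size": batch_size,
            "val_loss": train_and_evaluate(batch_size),
        }
    best = min(results, key=lambda label: results[label]["val_loss"])
    return best, results


def placeholder_train_and_evaluate(batch_size):
    # Synthetic score so the sketch runs end to end; NOT a real measurement.
    return abs(batch_size - 40) / 100


best_label, all_results = run_grid(2_000, placeholder_train_and_evaluate)
print(candidate_batch_sizes(2_000))
print("best:", best_label, all_results[best_label])
```

In practice the callback would also record training time and memory use, so the final choice can weigh them against validation performance as described in step 4.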

💡 *Pro tip: Create a simple grid search script to automate this process!*

## Conclusion

OSFT's ability to handle datasets of any scale makes it a versatile replacement for traditional LAB multiphase approaches. The key insight is that **batch size should scale with dataset size** to maintain training efficiency and model quality.

### Key Takeaways:

1. **Small datasets (1K)**: Use small batch sizes (16-32) for sufficient gradient updates
2. **Medium datasets (10K)**: Use moderate batch sizes (128-256) for balanced training
3. **Large datasets (100K+)**: Use large batch sizes (512-2048) for efficiency
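
If your dataset falls between these tiers, one simple option is to start from the nearest tier on a log scale. The sketch below encodes the three ranges listed above; `starting_batch_range` is a hypothetical helper, and picking the nearest tier by log distance is an assumption, not a rule from OSFT itself.

```python
import math

# (reference dataset size, suggested batch-size range) from the takeaways above.
TIERS = [
    (1_000, (16, 32)),
    (10_000, (128, 256)),
    (100_000, (512, 2048)),
]


def starting_batch_range(num_samples):
    """Return the batch-size range of the tier closest in order of magnitude."""
    if num_samples <= 0:
        raise ValueError("num_samples must be positive")
    nearest = min(
        TIERS,
        key=lambda tier: abs(math.log10(num_samples) - math.log10(tier[0])),
    )
    return nearest[1]


for n in (1_200, 3_000, 50_000, 500_000):
    low, high = starting_batch_range(n)
    print(f"{n:>7,} samples -> start with batch size {low}-{high}")
```

The returned range is only a first guess for the experiments described earlier.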

### Remember:

- These are **starting points**, not prescriptions
- **Experimentation** is essential for finding optimal settings
- **Monitor metrics** to guide your hyperparameter choices
- **Document what works** for future reference

### Next Steps:

1. Prepare your dataset in the required JSONL format
2. Start with the suggested batch size for your data scale
3. Run initial experiments with a data subset
4. Iterate and refine based on results
5. Scale to full training once parameters are optimized

Happy training with OSFT! 🚀


### Note on `effective_batch_size` vs `max_tokens_per_gpu`

- `effective_batch_size`: controls how many samples are aggregated per optimization step (including any gradient accumulation). This directly impacts the number of updates per epoch and training dynamics.
- `max_tokens_per_gpu`: a hardware-capacity setting that limits how many tokens are placed on a single GPU at once to avoid OOM. It constrains per-step memory usage but does not change the `effective_batch_size`.

Put simply: adjusting `max_tokens_per_gpu` helps you fit training into memory; it does not increase or decrease the `effective_batch_size`. If you need a larger `effective_batch_size`, increase the batch size and/or use gradient accumulation; if you hit memory limits, reduce `max_tokens_per_gpu`, the per-device micro-batch, or sequence length.
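
The relationship can be illustrated with a little arithmetic. This sketch assumes a simple model in which each GPU holds as many average-length samples as fit under `max_tokens_per_gpu`, and the framework runs enough forward/backward passes to reach the `effective_batch_size`; `plan_step` and the 2,000-token average sample length are hypothetical, and a real training framework may pack samples differently.

```python
import math


def plan_step(effective_batch_size, max_tokens_per_gpu, avg_tokens_per_sample, num_gpus):
    """Estimate how one optimizer step is split into memory-sized passes."""
    # Samples that fit on one GPU at once under the token cap (at least one).
    micro_batch = max(1, max_tokens_per_gpu // avg_tokens_per_sample)
    samples_per_pass = micro_batch * num_gpus
    # Passes accumulated before a single optimizer update.
    passes = math.ceil(effective_batch_size / samples_per_pass)
    return {
        "micro_batch_per_gpu": micro_batch,
        "samples_per_pass": samples_per_pass,
        "accumulation_passes": passes,
    }


# Same effective batch size, two different memory caps (8 GPUs, ~2,000 tokens/sample assumed).
for cap in (10_000, 5_000):
    print(f"max_tokens_per_gpu={cap}: {plan_step(128, cap, 2_000, 8)}")
```

Halving the token cap doubles the number of accumulation passes in this model, but each optimizer step still aggregates 128 samples, so the updates per epoch are unchanged.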

