# 🤖 CLaRa Training on Colab

[![Paper](https://img.shields.io/badge/Paper-Arxiv-green)](https://arxiv.org/abs/2511.18659)
[![GitHub](https://img.shields.io/badge/GitHub-ml--clara-blue)](https://github.com/apple/ml-clara)

**CLaRa: Bridging Retrieval and Generation with Continuous Latent Reasoning**

This notebook provides a complete training pipeline for the CLaRa model on Google Colab.

---

## 📋 Training Pipeline

1. **Environment Setup** - Check GPU, install dependencies
2. **Code & Data Preparation** - Clone repository, prepare training data
3. **Stage 1: Compression Pretraining** - Train compressor with KPCP framework
4. **Stage 2: Instruction Tuning** - Fine-tune on instruction-following tasks
5. **Stage 3: End-to-End Training** - Joint training of reranker and generator
6. **Model Inference & Export** - Test model and save checkpoints

---

### ⚙️ Configuration

**Hardware Requirements:**
- GPU: T4 (16GB) or better (A100 recommended)
- RAM: High RAM (25GB+)
- Runtime: GPU with High RAM

**Training Settings:**
- Base Model: `mistralai/Mistral-7B-Instruct-v0.2`
- Compression Rate: 32x
- Training Framework: OpenRLHF + DeepSpeed ZeRO-2
- Batch Size: Adaptive based on available GPU memory


---
## 1️⃣ Environment Setup

Check GPU availability and system information.

In [None]:
# Check GPU and CUDA
!nvidia-smi
print('\n' + '='*60)
import torch
print(f'PyTorch Version: {torch.__version__}')
print(f'CUDA Available: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'CUDA Version: {torch.version.cuda}')
    print(f'GPU Device: {torch.cuda.get_device_name(0)}')
    print(f'GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB')
print('='*60)

---
## 2️⃣ Install Dependencies

Install required packages. This may take 5-10 minutes.
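
Because Colab preinstalls some of these packages, it can be worth confirming after installation that the pinned versions are the ones actually active. The following is a minimal sketch using only the standard library; the helper name `check_pins` is hypothetical, and the pins are copied from the install cell in this section.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the install cell; adjust them if you change that cell.
PINS = {
    "accelerate": "1.10.1",
    "transformers": "4.56.2",
    "datasets": "3.2.0",
    "peft": "0.17.1",
    "deepspeed": "0.18.1",
}

def check_pins(pins):
    """Return {package: (expected, found)} for every pin that does not match.

    `found` is None when the package is not installed at all.
    """
    mismatches = {}
    for name, expected in pins.items():
        try:
            found = version(name)
        except PackageNotFoundError:
            found = None
        if found != expected:
            mismatches[name] = (expected, found)
    return mismatches

# Report anything that differs from the pins (prints nothing if all match).
for name, (expected, found) in check_pins(PINS).items():
    print(f"{name}: expected {expected}, found {found or 'not installed'}")
```

If a mismatch shows up after a fresh install, restarting the Colab runtime is usually needed before the new version is imported.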

In [None]:
%%time
# Install core dependencies
print('📦 Installing core dependencies...')

# Install basic packages first
!pip install -q accelerate==1.10.1 transformers==4.56.2 datasets==3.2.0 \
    peft==0.17.1 einops==0.8.1 sentencepiece==0.2.0 tiktoken==0.11.0

print('✅ Core packages installed')

# Fix fsspec version conflict with gcsfs
print('\n📦 Fixing fsspec version conflict...')
!pip install -q --upgrade fsspec==2025.3.0
print('✅ fsspec upgraded to match gcsfs requirements')

# Install DeepSpeed (may fail on some systems).
# A failing `!pip` command does not raise in Python; the import below is what
# detects a broken or missing installation.
print('\n📦 Installing DeepSpeed...')
try:
    !pip install -q deepspeed==0.18.1
    import deepspeed
    print(f'✅ DeepSpeed {deepspeed.__version__} installed')
except Exception as e:
    print(f'⚠️  DeepSpeed installation failed: {e}')
    print('   The training cells below use DeepSpeed ZeRO-2 and will likely fail without it')

# Install WandB (optional)
print('\n📦 Installing WandB (optional)...')
!pip install -q wandb==0.22.2
print('✅ WandB installed')

print('\n🎉 Dependencies installation complete!')

### Optional: Install Flash Attention (Recommended for speed)

Flash Attention can speed up training by ~15%. Skip if installation fails.

In [None]:
# Option 1: Try precompiled version (fast)
# !pip install flash-attn --no-build-isolation

# Option 2: Skip flash-attn (training still works)
print('⚠️ Skipping flash-attn installation')
print('Training will use standard attention (slightly slower but fully functional)')
USE_FLASH_ATTN = False

---
## 3️⃣ Download Code and Data

Clone the CLaRa repository, which bundles the OpenRLHF framework.

**Note:** This notebook uses a custom fork with the following fixes:
- ✅ Flash Attention made optional (with fallback implementations)
- ✅ Dependency conflicts resolved (fsspec, gcsfs)
- ✅ Colab-optimized training pipeline

The custom fork ensures training works smoothly without requiring flash_attn installation.

In [None]:
%%time
import os

# Clone CLaRa repository (with flash_attn fallback fixes and complete OpenRLHF integration)
if not os.path.exists('ml-clara'):
    # Use custom fork with fixes instead of official repo
    !git clone https://github.com/xucheng/ml-clara-rag.git ml-clara
    print('✅ CLaRa repository cloned (with fixes)')
    print('   - Includes flash_attn fallback implementation')
    print('   - Includes complete OpenRLHF framework')
else:
    print('✅ CLaRa repository already exists')
    # Pull latest changes if repository exists
    print('Pulling latest changes...')
    !cd ml-clara && git pull origin main
    print('✅ Repository updated')

# Verify OpenRLHF is included
print('\n📦 Verifying OpenRLHF framework...')
if os.path.exists('ml-clara/openrlhf'):
    import subprocess
    file_count = subprocess.check_output(
        'find ml-clara/openrlhf -type f -name "*.py" | wc -l',
        shell=True
    ).decode().strip()
    print(f'✅ OpenRLHF framework ready ({file_count} Python files)')
else:
    print('❌ OpenRLHF not found - please check repository')

# Change to project directory
%cd ml-clara
!pwd

---
## 4️⃣ Data Preparation

You have three options:
1. **Use example data** (provided in the repository) - Quick start
2. **Load from Google Drive** - For data already stored in your Drive
3. **Upload files or use local paths** - For custom training data

### Data Format

**Stage 1 (Pretraining)**: `pretrain_data.jsonl`
```json
{"data_type": "qa", "question": ["Q1"], "answers": ["A1"], "docs": ["doc1"]}
```

**Stage 2 (Instruction Tuning)**: `instruction_data.jsonl`
```json
{"question": "Q1", "docs": ["doc1"], "gold_answer": "A1"}
```

**Stage 3 (End-to-End)**: `end_to_end_data.jsonl` (same as Stage 2)
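
Before starting a long training run, it may help to check that every line of a data file matches the format shown above. The sketch below is an assumption-laden helper, not part of the CLaRa codebase: the names `SCHEMAS`, `validate_line`, and `validate_file` are hypothetical, and the required keys and types are inferred only from the two example records in this section, so the real data loader may accept or require more.

```python
import json

# Required keys and expected Python types, inferred from the examples above.
SCHEMAS = {
    "stage1": {"data_type": str, "question": list, "answers": list, "docs": list},
    "stage2": {"question": str, "docs": list, "gold_answer": str},
}
SCHEMAS["stage3"] = SCHEMAS["stage2"]  # Stage 3 reuses the Stage 2 format

def validate_line(line, stage):
    """Return a list of problems found in one JSONL line (empty if none)."""
    try:
        record = json.loads(line)
    except json.JSONDecodeError as exc:
        return [f"invalid JSON: {exc.msg}"]
    if not isinstance(record, dict):
        return ["record is not a JSON object"]
    problems = []
    for key, expected_type in SCHEMAS[stage].items():
        if key not in record:
            problems.append(f"missing key '{key}'")
        elif not isinstance(record[key], expected_type):
            problems.append(f"'{key}' should be {expected_type.__name__}")
    return problems

def validate_file(path, stage):
    """Return {line_number: problems} for every bad line in a JSONL file."""
    bad = {}
    with open(path, encoding="utf-8") as handle:
        for number, line in enumerate(handle, start=1):
            if line.strip():  # ignore blank lines
                problems = validate_line(line, stage)
                if problems:
                    bad[number] = problems
    return bad
```

For example, `validate_file('example/pretrain_data.jsonl', 'stage1')` returns an empty dictionary when every record has the expected keys.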

In [None]:
# Check example data
!ls -lh example/*.jsonl
print('\n📊 Example data statistics:')
!wc -l example/*.jsonl

### Option A: Use Example Data (Recommended for first run)

The repository includes small example datasets for quick testing.

In [None]:
# Use example data (already in repository)
DATA_MODE = 'example'  # or 'custom'

if DATA_MODE == 'example':
    PRETRAIN_DATA = 'example/pretrain_data.jsonl'
    INSTRUCTION_DATA = 'example/instruction_data.jsonl'
    END_TO_END_DATA = 'example/end_to_end_data.jsonl'
    print('✅ Using example data from repository')
    print(f'  - Pretraining: {PRETRAIN_DATA}')
    print(f'  - Instruction: {INSTRUCTION_DATA}')
    print(f'  - End-to-End: {END_TO_END_DATA}')

### Option B: Load from Google Drive

Mount Google Drive and use data files stored there.

**Example paths in Google Drive:**
- `/content/drive/MyDrive/Colab Notebooks/data/ml-clara/pretrain_data.jsonl`
- `/content/drive/MyDrive/data/ml-clara/instruction_data.jsonl`

Run the cell below to mount Drive and set paths.

In [None]:
import os

# Detect environment
try:
    from google.colab import drive
    IS_COLAB = True
except ImportError:
    IS_COLAB = False

if IS_COLAB:
    # Mount Google Drive
    print('📂 Mounting Google Drive...')
    drive.mount('/content/drive')
    print('✅ Google Drive mounted at /content/drive')

    # Modify these paths to match your Drive folder structure
    # Example: If your files are in "Colab Notebooks/data/ml-clara/"
    DRIVE_BASE = '/content/drive/MyDrive/Colab Notebooks/data/ml-clara'

    PRETRAIN_DATA = f'{DRIVE_BASE}/pretrain_data.jsonl'
    INSTRUCTION_DATA = f'{DRIVE_BASE}/instruction_data.jsonl'
    END_TO_END_DATA = f'{DRIVE_BASE}/end_to_end_data.jsonl'

    print(f'\n📁 Looking for data in: {DRIVE_BASE}')

    # Verify files exist
    all_found = True
    for name, path in [('Pretrain', PRETRAIN_DATA),
                       ('Instruction', INSTRUCTION_DATA),
                       ('End-to-End', END_TO_END_DATA)]:
        if os.path.exists(path):
            file_size = os.path.getsize(path) / 1024  # KB
            print(f'✅ {name}: {path} ({file_size:.1f} KB)')
        else:
            print(f'❌ {name}: {path} (NOT FOUND)')
            all_found = False

    if all_found:
        DATA_MODE = 'drive'
        print(f'\n✅ All data files found in Google Drive')
    else:
        print(f'\n⚠️  Some files not found. Please check:')
        print(f'   1. Files are uploaded to: {DRIVE_BASE}')
        print(f'   2. Folder path is correct (including spaces)')
        print(f'   3. File names match exactly')
        print(f'\n💡 To fix: Update DRIVE_BASE path in this cell')
else:
    print('⚠️  Not in Google Colab environment')
    print('This cell is designed for Google Colab with Drive mounting')
    print('Use Option A (example data) or Option C (local paths) instead')

### Option C: Upload Files or Use Local Paths

This cell automatically detects your environment:

**In Google Colab:**
- Uses the upload widget (`files.upload()`)
- Simply run the cell and select files when prompted

**In Local/VS Code:**
- Uses file paths instead
- Modify the paths to point to your local data files

Run the cell to load custom data.

In [None]:
import os
import sys

# Detect environment
try:
    from google.colab import files
    IS_COLAB = True
except ImportError:
    IS_COLAB = False

print(f'Environment: {"Google Colab" if IS_COLAB else "Local/VS Code"}')

# Option 1: For Google Colab - Use upload widget
if IS_COLAB:
    print('\n📤 Upload your custom data files:')

    print('\n1️⃣ Upload pretrain_data.jsonl:')
    uploaded = files.upload()

    print('\n2️⃣ Upload instruction_data.jsonl:')
    uploaded = files.upload()

    print('\n3️⃣ Upload end_to_end_data.jsonl:')
    uploaded = files.upload()

    # Move to data directory
    !mkdir -p data
    !mv pretrain_data.jsonl instruction_data.jsonl end_to_end_data.jsonl data/

    DATA_MODE = 'custom'
    PRETRAIN_DATA = 'data/pretrain_data.jsonl'
    INSTRUCTION_DATA = 'data/instruction_data.jsonl'
    END_TO_END_DATA = 'data/end_to_end_data.jsonl'
    print('\n✅ Custom data uploaded')

# Option 2: For Local/VS Code - Specify file paths
else:
    print('\n📁 Using local file paths')
    print('Please modify the paths below to point to your data files:\n')

    # Modify these paths to your actual data locations
    PRETRAIN_DATA = 'example/pretrain_data.jsonl'  # Change this
    INSTRUCTION_DATA = 'example/instruction_data.jsonl'  # Change this
    END_TO_END_DATA = 'example/end_to_end_data.jsonl'  # Change this

    # Verify files exist
    missing_files = []
    for name, path in [('Pretrain', PRETRAIN_DATA),
                       ('Instruction', INSTRUCTION_DATA),
                       ('End-to-End', END_TO_END_DATA)]:
        if os.path.exists(path):
            print(f'✅ {name}: {path}')
        else:
            print(f'❌ {name}: {path} (NOT FOUND)')
            missing_files.append(path)

    if missing_files:
        print(f'\n⚠️  Warning: {len(missing_files)} file(s) not found')
        print('Please update the file paths in this cell or use example data')
    else:
        DATA_MODE = 'custom'
        print('\n✅ All custom data files found')

---
## 5️⃣ Training Configuration

Set up training parameters. Adjust based on your GPU memory.
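
When adjusting the batch sizes, keep in mind that DeepSpeed-style trainers reach the global batch size by accumulating gradients: the global batch equals the micro-batch per GPU times the number of GPUs times the accumulation steps. The sketch below only illustrates that arithmetic for the presets in the configuration cell; the helper name `accumulation_steps` is hypothetical, and how OpenRLHF derives the value internally is not shown here.

```python
def accumulation_steps(train_batch_size, micro_batch_size, num_gpus=1):
    """Gradient accumulation steps implied by a global and a micro batch size."""
    per_step = micro_batch_size * num_gpus  # samples processed per forward/backward pass
    if train_batch_size % per_step != 0:
        raise ValueError(
            f"train_batch_size={train_batch_size} is not divisible by "
            f"micro_batch_size * num_gpus = {per_step}"
        )
    return train_batch_size // per_step

# The three presets from the configuration cell: (label, global batch, micro batch).
for label, batch, micro in [("T4", 32, 1), ("V100/A100-40GB", 64, 2), ("A100-80GB", 128, 2)]:
    print(f"{label}: {accumulation_steps(batch, micro)} accumulation steps")
```

Lowering `MICRO_BATCH_SIZE` reduces peak memory while leaving the effective batch unchanged, as long as `TRAIN_BATCH_SIZE` stays divisible by it.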

In [None]:
import torch

# Detect GPU memory and set batch sizes
if torch.cuda.is_available():
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f'GPU Memory: {gpu_memory:.1f} GB')

    if gpu_memory < 20:  # T4 (16GB)
        TRAIN_BATCH_SIZE = 32
        MICRO_BATCH_SIZE = 1
        NUM_GPUS = 1
        MAX_SAMPLES = 200
        print('⚙️ Using T4 config (16GB)')
    elif gpu_memory < 50:  # V100 or A100-40GB
        TRAIN_BATCH_SIZE = 64
        MICRO_BATCH_SIZE = 2
        NUM_GPUS = 1
        MAX_SAMPLES = 500
        print('⚙️ Using V100/A100-40GB config')
    else:  # A100-80GB
        TRAIN_BATCH_SIZE = 128
        MICRO_BATCH_SIZE = 2
        NUM_GPUS = 1
        MAX_SAMPLES = 1000
        print('⚙️ Using A100-80GB config')
else:
    raise RuntimeError('❌ No GPU available. Please enable GPU runtime.')

# Model and checkpoint paths
MODEL_PATH = 'mistralai/Mistral-7B-Instruct-v0.2'
CHECKPOINT_DIR = '/content/checkpoints'

# Training settings
LEARNING_RATE = 1e-4
MAX_EPOCHS = 1
COMPRESS_RATE = 32
DOC_MAX_LENGTH = 256
MAX_LEN = 2048

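# Flash attention flag (defaults to off if the optional flash-attn cell was skipped)
USE_FLASH_ATTN = globals().get('USE_FLASH_ATTN', False)
FLASH_ATTN_FLAG = '--flash_attn' if USE_FLASH_ATTN else ''

print(f'\n📝 Training Configuration:')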
print(f'  Model: {MODEL_PATH}')
print(f'  Batch Size: {TRAIN_BATCH_SIZE}')
print(f'  Micro Batch Size: {MICRO_BATCH_SIZE}')
print(f'  Max Samples: {MAX_SAMPLES}')
print(f'  Learning Rate: {LEARNING_RATE}')
print(f'  Compress Rate: {COMPRESS_RATE}x')
print(f'  Flash Attention: {USE_FLASH_ATTN}')

---
## 6️⃣ Stage 1: Compression Pretraining

Train the compressor using the KPCP framework with QA pairs and paraphrases.

**What happens:**
- Compress documents into continuous latent representations
- Learn semantic compression through QA-based supervision
- Support compression rates from 1x to 256x

**Expected time:** 10-30 minutes (depends on data size and GPU)
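
To make the compression rate concrete, the sketch below estimates how many latent vectors represent one document. It assumes the rate is the ratio of document tokens to latent vectors and that the result is rounded up; both the rounding rule and the helper name `latent_vectors` are assumptions for illustration, not taken from the CLaRa code.

```python
import math

def latent_vectors(doc_tokens, compress_rate):
    """Approximate number of latent vectors for a document of `doc_tokens` tokens."""
    if doc_tokens < 1 or compress_rate < 1:
        raise ValueError("doc_tokens and compress_rate must both be >= 1")
    # Round up so that even a very short document keeps at least one vector.
    return max(1, math.ceil(doc_tokens / compress_rate))

# DOC_MAX_LENGTH is 256 in this notebook; compare a few compression rates.
for rate in (1, 4, 32, 256):
    print(f"{rate:>3}x -> {latent_vectors(256, rate)} latent vector(s) per 256-token document")
```

Under these assumptions, the notebook defaults (`DOC_MAX_LENGTH = 256`, `COMPRESS_RATE = 32`) give 8 latent vectors per document.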

In [None]:
%%time
# Stage 1: Compression Pretraining
print('🚀 Starting Stage 1: Compression Pretraining\n')

!torchrun --nproc_per_node={NUM_GPUS} \
    --master_port=29500 \
    -m openrlhf.cli.train_sft \
    --max_len {MAX_LEN} \
    --dataset {PRETRAIN_DATA} \
    --pretrain {MODEL_PATH} \
    --train_batch_size {TRAIN_BATCH_SIZE} \
    --micro_train_batch_size {MICRO_BATCH_SIZE} \
    --max_samples {MAX_SAMPLES} \
    --save_path {CHECKPOINT_DIR}/clara_stage1 \
    --save_steps -2 \
    --logging_steps 5 \
    --eval_steps -1 \
    --zero_stage 2 \
    --max_epochs {MAX_EPOCHS} \
    --bf16 \
    {FLASH_ATTN_FLAG} \
    --learning_rate {LEARNING_RATE} \
    --stage stage1 \
    --generation_top_k 1 \
    --qa_loss \
    --doc_max_length {DOC_MAX_LENGTH} \
    --compress_rate {COMPRESS_RATE} \
    --mse_loss \
    --gradient_checkpointing

print('\n✅ Stage 1 completed!')
print(f'Checkpoint saved to: {CHECKPOINT_DIR}/clara_stage1')

### Check Stage 1 Checkpoint

In [None]:
# Verify checkpoint
!ls -lh {CHECKPOINT_DIR}/clara_stage1/
!du -sh {CHECKPOINT_DIR}/clara_stage1/

---
## 7️⃣ Stage 2: Instruction Tuning

Fine-tune the compressor on instruction-following tasks.

**What happens:**
- Load Stage 1 checkpoint
- Fine-tune on downstream QA tasks
- Ensure compressed representations retain sufficient semantics

**Expected time:** 10-30 minutes

In [None]:
%%time
# Stage 2: Instruction Tuning
print('🚀 Starting Stage 2: Instruction Tuning\n')

!torchrun --nproc_per_node={NUM_GPUS} \
    --master_port=29500 \
    -m openrlhf.cli.train_sft \
    --max_len {MAX_LEN} \
    --dataset {INSTRUCTION_DATA} \
    --pretrain {MODEL_PATH} \
    --ckpt_path {CHECKPOINT_DIR}/clara_stage1 \
    --train_batch_size {TRAIN_BATCH_SIZE} \
    --micro_train_batch_size {MICRO_BATCH_SIZE} \
    --max_samples {MAX_SAMPLES} \
    --save_path {CHECKPOINT_DIR}/clara_stage2 \
    --save_steps -2 \
    --logging_steps 5 \
    --eval_steps -1 \
    --zero_stage 2 \
    --max_epochs {MAX_EPOCHS} \
    --bf16 \
    {FLASH_ATTN_FLAG} \
    --learning_rate {LEARNING_RATE} \
    --stage stage2 \
    --generation_top_k 1 \
    --doc_max_length {DOC_MAX_LENGTH} \
    --compress_rate {COMPRESS_RATE} \
    --gradient_checkpointing

print('\n✅ Stage 2 completed!')
print(f'Checkpoint saved to: {CHECKPOINT_DIR}/clara_stage2')

In [None]:
# Verify checkpoint
!ls -lh {CHECKPOINT_DIR}/clara_stage2/
!du -sh {CHECKPOINT_DIR}/clara_stage2/

---
## 8️⃣ Stage 3: End-to-End Fine-tuning

Jointly train the reranker and the generator.

**What happens:**
- Load Stage 2 checkpoint
- Unify retrieval and generation in shared continuous space
- Use differentiable top-k estimator
- Train via single language modeling loss

**Expected time:** 15-40 minutes
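
The idea behind a differentiable top-k can be illustrated with one well-known relaxation: apply a temperature-scaled softmax k times, each time softly "removing" the item that was just selected. The sketch below is only a forward pass in plain Python, and it is an assumption that this resembles what CLaRa does; the estimator actually used in Stage 3 may differ. The names `softmax` and `relaxed_top_k` are hypothetical helpers.

```python
import math

def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def relaxed_top_k(scores, k, temperature=0.1):
    """Soft k-hot selection weights obtained from k rounds of softmax.

    Each round softly picks one item, then adds log(1 - p) to every logit so
    that an item already picked (p close to 1) is suppressed in later rounds.
    """
    logits = list(scores)
    weights = [0.0] * len(scores)
    for _ in range(k):
        probs = softmax([logit / temperature for logit in logits])
        weights = [w + p for w, p in zip(weights, probs)]
        # Clamp to avoid log(0) when a probability saturates at 1.
        logits = [
            logit + math.log(max(1.0 - p, 1e-9))
            for logit, p in zip(logits, probs)
        ]
    return weights

# Four candidate documents with retrieval scores; softly select the best two.
weights = relaxed_top_k([2.0, 0.5, 1.5, -1.0], k=2)
print([round(w, 3) for w in weights])
```

The weights sum to k, and a lower temperature pushes them toward a hard 0/1 selection. Written with tensors instead of lists, every operation here is differentiable, which is what lets a language modeling loss on the generator send gradients back to the retrieval scores.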

In [None]:
%%time
# Stage 3: End-to-End Training
print('🚀 Starting Stage 3: End-to-End Fine-tuning\n')

!torchrun --nproc_per_node={NUM_GPUS} \
    --master_port=29500 \
    -m openrlhf.cli.train_sft \
    --max_len {MAX_LEN} \
    --dataset {END_TO_END_DATA} \
    --pretrain {MODEL_PATH} \
    --ckpt_path {CHECKPOINT_DIR}/clara_stage2 \
    --train_batch_size {TRAIN_BATCH_SIZE} \
    --micro_train_batch_size {MICRO_BATCH_SIZE} \
    --max_samples {MAX_SAMPLES} \
    --save_path {CHECKPOINT_DIR}/clara_stage3_final \
    --save_steps -2 \
    --logging_steps 5 \
    --eval_steps -1 \
    --zero_stage 2 \
    --max_epochs {MAX_EPOCHS} \
    --bf16 \
    {FLASH_ATTN_FLAG} \
    --learning_rate {LEARNING_RATE} \
    --stage stage3 \
    --generation_top_k 5 \
    --doc_max_length {DOC_MAX_LENGTH} \
    --compress_rate {COMPRESS_RATE} \
    --gradient_checkpointing

print('\n✅ Stage 3 completed!')
print(f'Final model saved to: {CHECKPOINT_DIR}/clara_stage3_final')

In [None]:
# Verify final checkpoint
!ls -lh {CHECKPOINT_DIR}/clara_stage3_final/
!du -sh {CHECKPOINT_DIR}/clara_stage3_final/

print('\n🎉 Training pipeline completed!')
print('\n📁 All checkpoints:')
!ls -lh {CHECKPOINT_DIR}/

---
## 9️⃣ Model Inference Test

Test the trained model with a sample query.

In [None]:
# Load trained model for inference
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load model
model_path = f'{CHECKPOINT_DIR}/clara_stage3_final'
print(f'Loading model from: {model_path}')

try:
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map='auto'
    )
    print('✅ Model loaded successfully')

    # Test inference
    test_query = "What is CLaRa?"
    test_doc = "CLaRa is a framework that bridges retrieval and generation with continuous latent reasoning."

    prompt = f"Document: {test_doc}\n\nQuestion: {test_query}\n\nAnswer:"
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        temperature=0.7,
        do_sample=True
    )

    # Decode only the newly generated tokens so the prompt is not echoed back
    prompt_length = inputs['input_ids'].shape[1]
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    print(f'\n📝 Test Query: {test_query}')
    print('🤖 Model Response:')
    print(response)

except Exception as e:
    print(f'❌ Error during model loading or generation: {e}')
    print('This is expected if training was skipped or checkpoint format needs adjustment.')

---
## 🔟 Export and Save Model

Download your trained model to your local machine or save it to Google Drive.

### Option A: Download to Local Machine

In [None]:
# Create a zip archive of the final model
!apt-get install -y zip
!cd {CHECKPOINT_DIR} && zip -r clara_stage3_final.zip clara_stage3_final/

# Download
from google.colab import files
# files.download(f'{CHECKPOINT_DIR}/clara_stage3_final.zip')
print(f'Model archived to: {CHECKPOINT_DIR}/clara_stage3_final.zip')
print('Uncomment the download line above to download to your local machine')

### Option B: Save to Google Drive

In [None]:
# Mount Google Drive
from google.colab import drive
# drive.mount('/content/drive')

# Copy checkpoint to Drive
# !cp -r {CHECKPOINT_DIR}/clara_stage3_final /content/drive/MyDrive/
# print('✅ Model saved to Google Drive')

print('Uncomment the lines above to save to Google Drive')

---
## ✅ Training Summary

### Checkpoints Created:
1. **Stage 1**: `/content/checkpoints/clara_stage1` - Compression pretraining
2. **Stage 2**: `/content/checkpoints/clara_stage2` - Instruction tuning
3. **Stage 3**: `/content/checkpoints/clara_stage3_final` - Final end-to-end model

### Next Steps:
1. **Evaluation**: Use the evaluation scripts in `scripts/evaluation_*.sh`
2. **Fine-tuning**: Continue training with your own data
3. **Deployment**: Export model for inference

### Useful Resources:
- 📄 [Paper](https://arxiv.org/abs/2511.18659)
- 💻 [GitHub](https://github.com/apple/ml-clara)
- 🤗 [HuggingFace Models](https://huggingface.co/probejie)

---

### 📊 Troubleshooting

**Out of Memory (OOM):**
- Reduce `TRAIN_BATCH_SIZE` and `MICRO_BATCH_SIZE`
- Decrease `MAX_SAMPLES`
- Use gradient checkpointing (already enabled)

**Training too slow:**
- Install flash-attn (see cell above)
- Use A100 GPU instead of T4
- Reduce data size for testing

**Checkpoint loading errors:**
- Verify checkpoint path exists
- Check disk space
- Ensure previous stage completed successfully
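
The three checks above can be combined into a small helper that is run before starting the next stage. This is a minimal sketch using only the standard library; the name `checkpoint_report` is hypothetical, and the default threshold of 20 GB free is an assumption (a 7B-parameter model stored in bf16 takes roughly 14 to 15 GB), so adjust it to your setup.

```python
import os
import shutil

def checkpoint_report(path, min_free_gb=20.0):
    """Collect basic facts about a checkpoint directory as a dict."""
    exists = os.path.isdir(path)
    files = sorted(os.listdir(path)) if exists else []

    # Measure free space on the nearest existing parent directory.
    probe = os.path.abspath(path)
    while not os.path.exists(probe):
        parent = os.path.dirname(probe)
        if parent == probe:  # reached the filesystem root
            break
        probe = parent
    free_gb = shutil.disk_usage(probe).free / 1024**3

    return {
        "exists": exists,              # did the previous stage write anything here?
        "file_count": len(files),      # an empty directory suggests a failed save
        "free_gb": round(free_gb, 1),  # space left for the next checkpoint
        "enough_space": free_gb >= min_free_gb,
    }

# Example: inspect the Stage 1 checkpoint before launching Stage 2.
print(checkpoint_report('/content/checkpoints/clara_stage1'))
```

A report with `exists` set to `False` or `file_count` equal to 0 usually means the previous stage did not finish, so its log output is the first place to look.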

---

**Made with ❤️ by the CLaRa Team**

If you use this code, please cite:
```bibtex
@article{zhao2024clara,
  title={CLaRa: Bridging Retrieval and Generation with Continuous Latent Reasoning},
  author={Zhao, Zhihao and others},
  journal={arXiv preprint arXiv:2511.18659},
  year={2025}
}
```
