# 🤖 CLaRa Training on Colab

[![Paper](https://img.shields.io/badge/Paper-Arxiv-green)](https://arxiv.org/abs/2511.18659)
[![GitHub](https://img.shields.io/badge/GitHub-ml--clara-blue)](https://github.com/apple/ml-clara)

**CLaRa: Bridging Retrieval and Generation with Continuous Latent Reasoning**

This notebook provides a complete training pipeline for the CLaRa model on Google Colab.

---

## 📋 Training Pipeline

1. **Environment Setup** - Check GPU, install dependencies
2. **Code & Data Preparation** - Clone repository, prepare training data
3. **Stage 1: Compression Pretraining** - Train compressor with KPCP framework
4. **Stage 2: Instruction Tuning** - Fine-tune on instruction-following tasks
5. **Stage 3: End-to-End Training** - Joint training of reranker and generator
6. **Model Inference & Export** - Test model and save checkpoints

---

### ⚙️ Configuration

**Hardware Requirements:**
- GPU: T4 (16GB) or better (A100 recommended)
- RAM: High RAM (25GB+)
- Runtime: GPU with High RAM

**Training Settings:**
- Base Model: `mistralai/Mistral-7B-Instruct-v0.2`
- Compression Rate: 32x
- Training Framework: OpenRLHF + DeepSpeed ZeRO-2
- Batch Size: Adaptive based on available GPU memory


---
## 1️⃣ Environment Setup

Check GPU availability and system information.

In [None]:
# Check GPU and CUDA
import torch

!nvidia-smi
print('\n' + '='*60)
print(f'PyTorch Version: {torch.__version__}')
print(f'CUDA Available: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'CUDA Version: {torch.version.cuda}')
    print(f'GPU Device: {torch.cuda.get_device_name(0)}')
    print(f'GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB')
print('='*60)

---
## 2️⃣ Install Dependencies

Install required packages. This may take 5-10 minutes.

In [None]:
%%time
# Install core dependencies
!pip install -q accelerate==1.10.1 transformers==4.56.2 datasets==3.2.0 \
    peft==0.17.1 deepspeed==0.18.1 wandb==0.22.2 \
    einops==0.8.1 sentencepiece==0.2.0 tiktoken==0.11.0

print('✅ Core dependencies installed')

# Check DeepSpeed
import deepspeed
print(f'DeepSpeed version: {deepspeed.__version__}')

### Optional: Install Flash Attention (Recommended for speed)

Flash Attention can speed up training by ~15%. Skip if installation fails.

In [None]:
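# Option 1: Try installing flash-attn (pip uses a prebuilt wheel when one matches your
# CUDA/PyTorch versions; otherwise it compiles from source, which can be slow)
# !pip install flash-attn --no-build-isolation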

# Option 2: Skip flash-attn (training still works)
print('⚠️ Skipping flash-attn installation')
print('Training will use standard attention (slightly slower but fully functional)')
USE_FLASH_ATTN = False
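
Instead of hard-coding the flag, it can be derived from whether the package is installed. A minimal sketch, assuming the package's import name is `flash_attn`; `module_available` is a hypothetical helper, not part of the repository:

```python
import importlib.util


def module_available(name: str) -> bool:
    """Return True if a top-level module can be located, without importing it."""
    return importlib.util.find_spec(name) is not None


# Enable the flag only when the optional package is actually present.
USE_FLASH_ATTN = module_available('flash_attn')
print(f'flash-attn available: {USE_FLASH_ATTN}')
```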

---
## 3️⃣ Download Code and Data

Clone the CLaRa repository and the OpenRLHF framework.

In [None]:
%%time
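import os

# Start from the Colab working directory so that re-running this cell
# does not clone a second copy inside ml-clara/
%cd /content

# Clone CLaRa repository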
if not os.path.exists('ml-clara'):
    !git clone https://github.com/apple/ml-clara.git
    print('✅ CLaRa repository cloned')
else:
    print('✅ CLaRa repository already exists')

# Clone OpenRLHF
if not os.path.exists('OpenRLHF_repo'):
    !git clone https://github.com/OpenRLHF/OpenRLHF.git OpenRLHF_repo
    !cp -R -n OpenRLHF_repo/openrlhf/* ml-clara/openrlhf/
    print('✅ OpenRLHF framework integrated')
else:
    print('✅ OpenRLHF already integrated')

# Change to project directory
%cd ml-clara
!pwd

---
## 4️⃣ Data Preparation

You have two options:
1. **Use example data** (provided in repository) - Quick start
2. **Upload your own data** - For custom training

### Data Format

**Stage 1 (Pretraining)**: `pretrain_data.jsonl`
```json
{"data_type": "qa", "question": ["Q1"], "answers": ["A1"], "docs": ["doc1"]}
```

**Stage 2 (Instruction Tuning)**: `instruction_data.jsonl`
```json
{"question": "Q1", "docs": ["doc1"], "gold_answer": "A1"}
```

**Stage 3 (End-to-End)**: `end_to_end_data.jsonl` (same format as Stage 2)
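
Before training, it can help to confirm that every line parses and carries the fields shown above. A minimal sketch using only the standard library; `REQUIRED_KEYS` and `validate_jsonl_lines` are hypothetical names, and the field sets are inferred from the two example records rather than taken from the repository's data loader:

```python
import json

# Field sets inferred from the example records (an assumption, not the loader's schema).
REQUIRED_KEYS = {
    'stage1': {'data_type', 'question', 'answers', 'docs'},
    'stage2': {'question', 'docs', 'gold_answer'},
    'stage3': {'question', 'docs', 'gold_answer'},
}


def validate_jsonl_lines(lines, stage):
    """Return (line_number, problem) tuples; an empty list means no issues were found."""
    problems = []
    for number, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # ignore blank lines
        try:
            record = json.loads(line)
        except json.JSONDecodeError as err:
            problems.append((number, f'invalid JSON: {err.msg}'))
            continue
        if not isinstance(record, dict):
            problems.append((number, 'record is not a JSON object'))
            continue
        missing = REQUIRED_KEYS[stage] - set(record)
        if missing:
            problems.append((number, f'missing keys: {sorted(missing)}'))
    return problems


sample = [
    '{"question": "Q1", "docs": ["doc1"], "gold_answer": "A1"}',
    '{"question": "Q2", "docs": ["doc2"]}',
]
print(validate_jsonl_lines(sample, 'stage2'))
```

An open file object can be passed in place of `sample`, since iterating over a file yields its lines.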

In [None]:
# Check example data
!ls -lh example/*.jsonl
print('\n📊 Example data statistics:')
!wc -l example/*.jsonl

### Option A: Use Example Data (Recommended for first run)

The repository includes small example datasets for quick testing.

In [None]:
# Use example data (already in repository)
DATA_MODE = 'example'  # or 'custom'

if DATA_MODE == 'example':
    PRETRAIN_DATA = 'example/pretrain_data.jsonl'
    INSTRUCTION_DATA = 'example/instruction_data.jsonl'
    END_TO_END_DATA = 'example/end_to_end_data.jsonl'
    print('✅ Using example data from repository')
    print(f'  - Pretraining: {PRETRAIN_DATA}')
    print(f'  - Instruction: {INSTRUCTION_DATA}')
    print(f'  - End-to-End: {END_TO_END_DATA}')

### Option B: Upload Your Own Data

Upload your custom training data files. Uncomment and run if needed.

In [None]:
# Uncomment to upload custom data
# from google.colab import files
# 
# print('Upload pretrain_data.jsonl:')
# uploaded = files.upload()
# !mkdir -p data
# !mv pretrain_data.jsonl data/
# 
# print('Upload instruction_data.jsonl:')
# uploaded = files.upload()
# !mv instruction_data.jsonl data/
# 
# print('Upload end_to_end_data.jsonl:')
# uploaded = files.upload()
# !mv end_to_end_data.jsonl data/
# 
# DATA_MODE = 'custom'
# PRETRAIN_DATA = 'data/pretrain_data.jsonl'
# INSTRUCTION_DATA = 'data/instruction_data.jsonl'
# END_TO_END_DATA = 'data/end_to_end_data.jsonl'
# print('✅ Custom data uploaded')

---
## 5️⃣ Training Configuration

Set up training parameters. Adjust based on your GPU memory.

In [None]:
import torch

# Detect GPU memory and set batch sizes
if torch.cuda.is_available():
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f'GPU Memory: {gpu_memory:.1f} GB')
    
    if gpu_memory < 20:  # T4 (16GB)
        TRAIN_BATCH_SIZE = 32
        MICRO_BATCH_SIZE = 1
        NUM_GPUS = 1
        MAX_SAMPLES = 200
        print('⚙️ Using T4 config (16GB)')
    elif gpu_memory < 50:  # V100 or A100-40GB
        TRAIN_BATCH_SIZE = 64
        MICRO_BATCH_SIZE = 2
        NUM_GPUS = 1
        MAX_SAMPLES = 500
        print('⚙️ Using V100/A100-40GB config')
    else:  # A100-80GB
        TRAIN_BATCH_SIZE = 128
        MICRO_BATCH_SIZE = 2
        NUM_GPUS = 1
        MAX_SAMPLES = 1000
        print('⚙️ Using A100-80GB config')
else:
    raise RuntimeError('❌ No GPU available. Please enable GPU runtime.')

# Model and checkpoint paths
MODEL_PATH = 'mistralai/Mistral-7B-Instruct-v0.2'
CHECKPOINT_DIR = '/content/checkpoints'

# Training settings
LEARNING_RATE = 1e-4
MAX_EPOCHS = 1
COMPRESS_RATE = 32
DOC_MAX_LENGTH = 256
MAX_LEN = 2048

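# Flash attention flag (defaults to off if the optional flash-attn cell was skipped)
USE_FLASH_ATTN = globals().get('USE_FLASH_ATTN', False)
FLASH_ATTN_FLAG = '--flash_attn' if USE_FLASH_ATTN else ''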

print(f'\n📝 Training Configuration:')
print(f'  Model: {MODEL_PATH}')
print(f'  Batch Size: {TRAIN_BATCH_SIZE}')
print(f'  Micro Batch Size: {MICRO_BATCH_SIZE}')
print(f'  Max Samples: {MAX_SAMPLES}')
print(f'  Learning Rate: {LEARNING_RATE}')
print(f'  Compress Rate: {COMPRESS_RATE}x')
print(f'  Flash Attention: {USE_FLASH_ATTN}')
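
The two batch sizes are linked through gradient accumulation: under DeepSpeed, the global batch size equals the per-GPU micro batch size times the number of accumulation steps times the number of GPUs. The sketch below works through that arithmetic for the three presets above; `accumulation_steps` is a hypothetical helper used only for illustration:

```python
def accumulation_steps(train_batch_size: int, micro_batch_size: int, num_gpus: int) -> int:
    """Gradient-accumulation steps implied by a global batch size.

    Uses train_batch_size = micro_batch_size * accumulation_steps * num_gpus.
    """
    per_step = micro_batch_size * num_gpus
    if train_batch_size % per_step != 0:
        raise ValueError('train_batch_size must be divisible by micro_batch_size * num_gpus')
    return train_batch_size // per_step


# (preset name, TRAIN_BATCH_SIZE, MICRO_BATCH_SIZE) as set in the cell above
presets = [('T4', 32, 1), ('V100/A100-40GB', 64, 2), ('A100-80GB', 128, 2)]
for name, train_bs, micro_bs in presets:
    steps = accumulation_steps(train_bs, micro_bs, num_gpus=1)
    print(f'{name}: {steps} accumulation steps per optimizer update')
```

Lowering `MICRO_BATCH_SIZE` alone raises the number of accumulation steps and reduces peak activation memory while leaving the effective batch size unchanged.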

---
## 6️⃣ Stage 1: Compression Pretraining

Train the compressor using the KPCP framework with QA pairs and paraphrases.

**What happens:**
- Compress documents into continuous latent representations
- Learn semantic compression through QA-based supervision
- Support compression rates from 1x to 256x

**Expected time:** 10-30 minutes (depends on data size and GPU)
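
To make the compression rate concrete, the sketch below assumes the rate is the ratio between the document token budget and the number of latent vectors kept per document. That reading of how `--compress_rate` and `--doc_max_length` interact is an assumption, not something this notebook confirms, and `latent_vectors_per_doc` is a hypothetical helper:

```python
def latent_vectors_per_doc(doc_max_length: int, compress_rate: int) -> int:
    """Latent vectors per document, ASSUMING rate = document tokens / latent vectors."""
    if compress_rate < 1:
        raise ValueError('compress_rate must be at least 1')
    # Keep at least one vector even when the rate exceeds the document length.
    return max(1, doc_max_length // compress_rate)


DOC_MAX_LENGTH = 256  # value used in this notebook
for rate in (1, 32, 256):
    print(f'{rate}x -> {latent_vectors_per_doc(DOC_MAX_LENGTH, rate)} latent vectors per document')
```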

In [None]:
%%time
# Stage 1: Compression Pretraining
print('🚀 Starting Stage 1: Compression Pretraining\n')

!torchrun --nproc_per_node={NUM_GPUS} \
    --master_port=29500 \
    -m openrlhf.cli.train_sft \
    --max_len {MAX_LEN} \
    --dataset {PRETRAIN_DATA} \
    --pretrain {MODEL_PATH} \
    --train_batch_size {TRAIN_BATCH_SIZE} \
    --micro_train_batch_size {MICRO_BATCH_SIZE} \
    --max_samples {MAX_SAMPLES} \
    --save_path {CHECKPOINT_DIR}/clara_stage1 \
    --save_steps -2 \
    --logging_steps 5 \
    --eval_steps -1 \
    --zero_stage 2 \
    --max_epochs {MAX_EPOCHS} \
    --bf16 \
    {FLASH_ATTN_FLAG} \
    --learning_rate {LEARNING_RATE} \
    --stage stage1 \
    --generation_top_k 1 \
    --qa_loss \
    --doc_max_length {DOC_MAX_LENGTH} \
    --compress_rate {COMPRESS_RATE} \
    --mse_loss \
    --gradient_checkpointing

print('\n✅ Stage 1 completed!')
print(f'Checkpoint saved to: {CHECKPOINT_DIR}/clara_stage1')

### Check Stage 1 Checkpoint

In [None]:
# Verify checkpoint
!ls -lh {CHECKPOINT_DIR}/clara_stage1/
!du -sh {CHECKPOINT_DIR}/clara_stage1/

---
## 7️⃣ Stage 2: Instruction Tuning

Fine-tune the compressor on instruction-following tasks.

**What happens:**
- Load Stage 1 checkpoint
- Fine-tune on downstream QA tasks
- Ensure compressed representations retain sufficient semantics

**Expected time:** 10-30 minutes

In [None]:
%%time
# Stage 2: Instruction Tuning
print('🚀 Starting Stage 2: Instruction Tuning\n')

!torchrun --nproc_per_node={NUM_GPUS} \
    --master_port=29500 \
    -m openrlhf.cli.train_sft \
    --max_len {MAX_LEN} \
    --dataset {INSTRUCTION_DATA} \
    --pretrain {MODEL_PATH} \
    --ckpt_path {CHECKPOINT_DIR}/clara_stage1 \
    --train_batch_size {TRAIN_BATCH_SIZE} \
    --micro_train_batch_size {MICRO_BATCH_SIZE} \
    --max_samples {MAX_SAMPLES} \
    --save_path {CHECKPOINT_DIR}/clara_stage2 \
    --save_steps -2 \
    --logging_steps 5 \
    --eval_steps -1 \
    --zero_stage 2 \
    --max_epochs {MAX_EPOCHS} \
    --bf16 \
    {FLASH_ATTN_FLAG} \
    --learning_rate {LEARNING_RATE} \
    --stage stage2 \
    --generation_top_k 1 \
    --doc_max_length {DOC_MAX_LENGTH} \
    --compress_rate {COMPRESS_RATE} \
    --gradient_checkpointing

print('\n✅ Stage 2 completed!')
print(f'Checkpoint saved to: {CHECKPOINT_DIR}/clara_stage2')

In [None]:
# Verify checkpoint
!ls -lh {CHECKPOINT_DIR}/clara_stage2/
!du -sh {CHECKPOINT_DIR}/clara_stage2/

---
## 8️⃣ Stage 3: End-to-End Fine-tuning

Jointly train the reranker and generator.

**What happens:**
- Load the Stage 2 checkpoint
- Unify retrieval and generation in a shared continuous space
- Use a differentiable top-k estimator
- Train via a single language modeling loss

**Expected time:** 15-40 minutes
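
For intuition about what a differentiable top-k can look like, here is one simple relaxation: k successive softmaxes, each followed by a soft mask on what was just selected. This is an illustrative stand-in and not necessarily the estimator CLaRa uses; `soft_top_k` and its `temperature` parameter are hypothetical names. The function is smooth in `scores`, so the same arithmetic written in an autograd framework would pass gradients back to the retrieval scores.

```python
import math


def softmax(logits, temperature):
    """Numerically stable softmax of logits / temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def soft_top_k(scores, k, temperature=0.1):
    """Relaxed k-hot weights: the sum of k successive softmaxes with soft masking."""
    logits = list(scores)
    weights = [0.0] * len(scores)
    for _ in range(k):
        probs = softmax(logits, temperature)
        weights = [w + p for w, p in zip(weights, probs)]
        # Softly suppress what was just picked so the next round favours the runner-up.
        logits = [x + math.log(max(1.0 - p, 1e-12)) for x, p in zip(logits, probs)]
    return weights


scores = [2.0, 0.5, 1.5, -1.0]  # e.g. query-document similarity scores
print([round(w, 3) for w in soft_top_k(scores, k=2, temperature=0.05)])
```

The weights always sum to k; a low temperature pushes them toward a hard k-hot selection, while a higher one spreads mass across more documents.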

In [None]:
%%time
# Stage 3: End-to-End Training
print('🚀 Starting Stage 3: End-to-End Fine-tuning\n')

!torchrun --nproc_per_node={NUM_GPUS} \
    --master_port=29500 \
    -m openrlhf.cli.train_sft \
    --max_len {MAX_LEN} \
    --dataset {END_TO_END_DATA} \
    --pretrain {MODEL_PATH} \
    --ckpt_path {CHECKPOINT_DIR}/clara_stage2 \
    --train_batch_size {TRAIN_BATCH_SIZE} \
    --micro_train_batch_size {MICRO_BATCH_SIZE} \
    --max_samples {MAX_SAMPLES} \
    --save_path {CHECKPOINT_DIR}/clara_stage3_final \
    --save_steps -2 \
    --logging_steps 5 \
    --eval_steps -1 \
    --zero_stage 2 \
    --max_epochs {MAX_EPOCHS} \
    --bf16 \
    {FLASH_ATTN_FLAG} \
    --learning_rate {LEARNING_RATE} \
    --stage stage3 \
    --generation_top_k 5 \
    --doc_max_length {DOC_MAX_LENGTH} \
    --compress_rate {COMPRESS_RATE} \
    --gradient_checkpointing

print('\n✅ Stage 3 completed!')
print(f'Final model saved to: {CHECKPOINT_DIR}/clara_stage3_final')

In [None]:
# Verify final checkpoint
!ls -lh {CHECKPOINT_DIR}/clara_stage3_final/
!du -sh {CHECKPOINT_DIR}/clara_stage3_final/

print('\n🎉 Training pipeline completed!')
print('\n📁 All checkpoints:')
!ls -lh {CHECKPOINT_DIR}/

---
## 9️⃣ Model Inference Test

Test the trained model with a sample query.

In [None]:
# Load trained model for inference
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load model
model_path = f'{CHECKPOINT_DIR}/clara_stage3_final'
print(f'Loading model from: {model_path}')

try:
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map='auto'
    )
    print('✅ Model loaded successfully')
    
    # Test inference
    test_query = "What is CLaRa?"
    test_doc = "CLaRa is a framework that bridges retrieval and generation with continuous latent reasoning."
    
    prompt = f"Document: {test_doc}\n\nQuestion: {test_query}\n\nAnswer:"
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        temperature=0.7,
        do_sample=True
    )
    
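    # Decode only the newly generated tokens, not the echoed prompt
    new_tokens = outputs[0][inputs['input_ids'].shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)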
    print(f'\n📝 Test Query: {test_query}')
    print(f'🤖 Model Response:')
    print(response)
    
except Exception as e:
    print(f'❌ Error during model loading or generation: {e}')
    print('This is expected if training was skipped or the checkpoint format needs adjustment.')

---
## 🔟 Export and Save Model

Download your trained model to your local machine or save it to Google Drive.

### Option A: Download to Local Machine

In [None]:
# Create a zip archive of the final model
!apt-get install -y zip
!cd {CHECKPOINT_DIR} && zip -r clara_stage3_final.zip clara_stage3_final/

# Download
from google.colab import files
# files.download(f'{CHECKPOINT_DIR}/clara_stage3_final.zip')
print(f'Model archived to: {CHECKPOINT_DIR}/clara_stage3_final.zip')
print('Uncomment the download line above to download to your local machine')
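
If the `zip` binary is unavailable, the archive can also be produced from Python. A minimal sketch using the standard library's `shutil.make_archive`; `archive_checkpoint` is a hypothetical helper, and the directory layout mirrors the paths used above:

```python
import os
import shutil


def archive_checkpoint(checkpoint_dir: str, name: str) -> str:
    """Zip checkpoint_dir/name into checkpoint_dir/name.zip and return the archive path."""
    source = os.path.join(checkpoint_dir, name)
    if not os.path.isdir(source):
        raise FileNotFoundError(f'No checkpoint directory at {source}')
    # root_dir/base_dir keep the top-level folder name inside the archive.
    return shutil.make_archive(
        base_name=os.path.join(checkpoint_dir, name),
        format='zip',
        root_dir=checkpoint_dir,
        base_dir=name,
    )
```

For the final model this would be called as `archive_checkpoint(CHECKPOINT_DIR, 'clara_stage3_final')`.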

### Option B: Save to Google Drive

In [None]:
# Mount Google Drive
from google.colab import drive
# drive.mount('/content/drive')

# Copy checkpoint to Drive
# !cp -r {CHECKPOINT_DIR}/clara_stage3_final /content/drive/MyDrive/
# print('✅ Model saved to Google Drive')

print('Uncomment the lines above to save to Google Drive')

---
## ✅ Training Summary

### Checkpoints Created:
1. **Stage 1**: `/content/checkpoints/clara_stage1` - Compression pretraining
2. **Stage 2**: `/content/checkpoints/clara_stage2` - Instruction tuning
3. **Stage 3**: `/content/checkpoints/clara_stage3_final` - Final end-to-end model
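
The three training commands differ only in a handful of flags. The sketch below makes those differences explicit by assembling the stage-specific arguments in Python. It only builds strings from flags already used in the commands above and launches nothing; `STAGE_OVERRIDES` and `build_stage_args` are hypothetical names, and the dataset paths are the example files.

```python
# Stage-specific settings, copied from the three torchrun commands in this notebook.
STAGE_OVERRIDES = {
    'stage1': {'dataset': 'example/pretrain_data.jsonl', 'save': 'clara_stage1',
               'load': None, 'top_k': 1, 'extra': ['--qa_loss', '--mse_loss']},
    'stage2': {'dataset': 'example/instruction_data.jsonl', 'save': 'clara_stage2',
               'load': 'clara_stage1', 'top_k': 1, 'extra': []},
    'stage3': {'dataset': 'example/end_to_end_data.jsonl', 'save': 'clara_stage3_final',
               'load': 'clara_stage2', 'top_k': 5, 'extra': []},
}


def build_stage_args(stage, checkpoint_dir='/content/checkpoints'):
    """Return only the flags that change between stages, as command-line tokens."""
    cfg = STAGE_OVERRIDES[stage]
    args = [
        '--stage', stage,
        '--dataset', cfg['dataset'],
        '--save_path', f"{checkpoint_dir}/{cfg['save']}",
        '--generation_top_k', str(cfg['top_k']),
    ]
    if cfg['load'] is not None:
        # Stages 2 and 3 resume from the previous stage's checkpoint.
        args += ['--ckpt_path', f"{checkpoint_dir}/{cfg['load']}"]
    return args + cfg['extra']


for stage in STAGE_OVERRIDES:
    print(stage, ' '.join(build_stage_args(stage)))
```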

### Next Steps:
1. **Evaluation**: Use the evaluation scripts in `scripts/evaluation_*.sh`
2. **Fine-tuning**: Continue training with your own data
3. **Deployment**: Export model for inference

### Useful Resources:
- 📄 [Paper](https://arxiv.org/abs/2511.18659)
- 💻 [GitHub](https://github.com/apple/ml-clara)
- 🤗 [HuggingFace Models](https://huggingface.co/probejie)

---

### 📊 Troubleshooting

**Out of Memory (OOM):**
- Reduce `TRAIN_BATCH_SIZE` and `MICRO_BATCH_SIZE`
- Decrease `MAX_SAMPLES`
- Use gradient checkpointing (already enabled)

**Training too slow:**
- Install flash-attn (see cell above)
- Use A100 GPU instead of T4
- Reduce data size for testing

**Checkpoint loading errors:**
- Verify checkpoint path exists
- Check disk space
- Ensure previous stage completed successfully
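
The first two checks can be scripted and run before launching the next stage. A minimal sketch using only the standard library; `check_checkpoint` and its `min_free_gb` threshold are hypothetical and not part of the repository:

```python
import os
import shutil


def check_checkpoint(path: str, min_free_gb: float = 0.0) -> list:
    """Return human-readable problems; an empty list means the basic checks passed."""
    problems = []
    if not os.path.isdir(path):
        problems.append(f'missing checkpoint directory: {path}')
    elif not os.listdir(path):
        problems.append(f'checkpoint directory is empty: {path}')

    # Measure free space on the nearest existing ancestor of the path.
    probe = os.path.abspath(path)
    while not os.path.exists(probe):
        probe = os.path.dirname(probe)  # terminates at the filesystem root
    free_gb = shutil.disk_usage(probe).free / 1024**3
    if free_gb < min_free_gb:
        problems.append(f'only {free_gb:.1f} GB free, wanted {min_free_gb:.1f} GB')
    return problems
```

For example, `check_checkpoint('/content/checkpoints/clara_stage1', min_free_gb=30)` could run before Stage 2; the 30 GB figure is illustrative, not a measured requirement.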

---

**Made with ❤️ by the CLaRa Team**

If you use this code, please cite:
```bibtex
@article{zhao2024clara,
  title={CLaRa: Bridging Retrieval and Generation with Continuous Latent Reasoning},
  author={Zhao, Zhihao and others},
  journal={arXiv preprint arXiv:2511.18659},
  year={2025}
}
```
