[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/IXZZZ9/GradES/blob/main/examples/unsloth_lora_grades.ipynb)

# 🦙 Unsloth + GradES: LoRA Fine-tuning with 40-50% Speed Up

This notebook demonstrates how to use **GradES** (Gradient-based Early Stopping) with **Unsloth** for ultra-fast LoRA fine-tuning of large language models.

## Key Benefits:
- 🚀 **40-50% computational savings** compared to standard fine-tuning
- 🎯 **Maintains or improves** model performance
- 🔧 **Easy integration** with existing Unsloth workflows
- 📊 **Real-time monitoring** with WandB integration

## Paper:
📖 [GradES: Significantly Faster Training in Transformers with Gradient-Based Early Stopping](https://arxiv.org/abs/2509.01842)

## 🔧 Installation

In [None]:
# Install GradES and Unsloth
!pip install grades
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
# SFTConfig (imported in the training cell) requires trl 0.9 or newer, so trl is not pinned below 0.9.0
!pip install --no-deps trl peft accelerate bitsandbytes

## 📚 Setup Model and Tokenizer

In [None]:
from unsloth import FastLanguageModel
import torch

max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

# 4x faster downloads and no OOMs with max_seq_length=2048!
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/llama-3-8b-bnb-4bit", # Choose ANY! eg teknium/OpenHermes-2.5-Mistral-7B
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

# Do model patching and add fast LoRA weights
model = FastLanguageModel.get_peft_model(
    model,
    r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # Rank stabilized LoRA
    loftq_config = None, # LoftQ
)
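
To get a feel for what `r = 16` means in practice, the following sketch counts the parameters that LoRA adds for the seven target modules listed above. It assumes the published Llama-3-8B layer shapes (hidden size 4096, 8 key/value heads of 128 dimensions, MLP width 14336, 32 layers); the constant and function names are illustrative and not part of Unsloth or PEFT.

```python
# Parameter count added by LoRA for r = 16 with all seven projections targeted.
# The shapes below are assumed to match Llama-3-8B.
HIDDEN = 4096          # model (residual stream) width
KV_DIM = 1024          # 8 key/value heads x 128 dims (grouped-query attention)
INTERMEDIATE = 14336   # MLP inner width
NUM_LAYERS = 32
RANK = 16

# (in_features, out_features) of each targeted linear layer
TARGET_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

def lora_params(in_features, out_features, rank):
    """LoRA adds A (rank x in_features) and B (out_features x rank)."""
    return rank * (in_features + out_features)

per_layer = sum(lora_params(i, o, RANK) for i, o in TARGET_SHAPES.values())
total = per_layer * NUM_LAYERS
print(f"LoRA parameters per layer: {per_layer:,}")
print(f"LoRA parameters in total:  {total:,}")
```

Doubling `r` doubles this count, which is one reason the suggested ranks are spaced as powers of two.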

## 📊 Setup Dataset

In [None]:
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""

EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
def formatting_prompts_func(examples):
    instructions = examples["instruction"]
    inputs       = examples["input"]
    outputs      = examples["output"]
    texts = []
    for instruction, input_text, output in zip(instructions, inputs, outputs):
        # Append EOS_TOKEN so the model learns where a response ends; without it, generation may never stop
        text = alpaca_prompt.format(instruction, input_text, output) + EOS_TOKEN
        texts.append(text)
    return { "text" : texts, }

from datasets import load_dataset
dataset = load_dataset("yahma/alpaca-cleaned", split = "train")
dataset = dataset.map(formatting_prompts_func, batched = True,)
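
With `batched = True`, `map` hands the function a dictionary of columns (one list per field) and expects a dictionary of lists back. The stand-alone sketch below reproduces that contract on a two-row batch without loading a model. The shortened template, the `<eos>` placeholder, and the `format_batch` name are assumptions made for illustration only.

```python
# Stand-alone illustration of the batched formatting step (no model needed).
DEMO_EOS = "<eos>"  # hypothetical placeholder; the notebook uses tokenizer.eos_token

demo_prompt = """### Instruction:
{}

### Input:
{}

### Response:
{}"""

def format_batch(examples, eos_token=DEMO_EOS):
    """Turn a column-oriented batch into a list of prompt strings."""
    texts = [
        demo_prompt.format(instruction, context, response) + eos_token
        for instruction, context, response in zip(
            examples["instruction"], examples["input"], examples["output"]
        )
    ]
    return {"text": texts}

demo_batch = {
    "instruction": ["Add the numbers.", "Name a primary color."],
    "input": ["2 and 3", ""],
    "output": ["5", "Red"],
}
formatted = format_batch(demo_batch)
print(formatted["text"][0])
```

Note that rows with an empty `input` still render the `### Input:` header, which is also how the full template above behaves.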

## 🎯 Setup GradES Callback

In [None]:
from grades import GradEarlyStoppingCallback, GradEarlyStoppingConfig

# Configure GradES for LoRA fine-tuning
config = GradEarlyStoppingConfig(
    tau=1e-10,                   # Convergence threshold (very low for LoRA)
    alpha=0.1,                   # Allow freezing after 10% of training
    enable_wandb_logging=True,   # Enable WandB logging for monitoring
    log_interval=5,              # Log every 5 steps
    save_stats=True,             # Save component statistics
)

# Create the callback
grades_callback = GradEarlyStoppingCallback(config)

print("🎯 GradES configured successfully!")
print(f"   Convergence threshold (tau): {config.tau}")
print(f"   Minimum training progress (alpha): {config.alpha}")
print(f"   WandB logging: {config.enable_wandb_logging}")
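
The two settings interact as follows: `alpha` defines a grace period at the start of training, and `tau` is the threshold a component's gradient magnitude must fall below before it can be frozen. The sketch below illustrates that rule on made-up numbers. It is not the `grades` implementation; the function names, component names, and gradient norms are all hypothetical, and the real library may measure gradients differently.

```python
def may_freeze(step, total_steps, alpha):
    """True once the grace period (alpha * total_steps) has passed."""
    return step >= alpha * total_steps

def update_frozen(frozen, grad_norms, step, total_steps, tau, alpha):
    """Return the set of frozen components after observing this step's norms."""
    if not may_freeze(step, total_steps, alpha):
        return set(frozen)
    newly_frozen = {name for name, norm in grad_norms.items() if norm < tau}
    return set(frozen) | newly_frozen

# Hypothetical gradient norms for two components over a 60-step run
history = {
    3:  {"layer0.q_proj": 1e-12, "layer0.down_proj": 5e-3},   # inside the grace period
    10: {"layer0.q_proj": 1e-12, "layer0.down_proj": 5e-3},
    40: {"layer0.q_proj": 1e-12, "layer0.down_proj": 2e-11},
}

frozen = set()
for step in sorted(history):
    frozen = update_frozen(frozen, history[step], step,
                           total_steps=60, tau=1e-10, alpha=0.1)
    print(step, sorted(frozen))
```

In this toy run nothing freezes at step 3 even though one norm is already below `tau`, because the grace period of `0.1 * 60 = 6` steps has not elapsed.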

## 🚀 Training with GradES

In [None]:
from trl import SFTTrainer, SFTConfig
from unsloth import is_bfloat16_supported

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    eval_dataset = None, # Can set up evaluation!
    callbacks = [grades_callback],  # 🎯 Add GradES callback here!
    args = SFTConfig(
        dataset_text_field = "text",
        per_device_train_batch_size = 16,
        gradient_accumulation_steps = 4, # Effective batch size = 16 * 4 = 64 per device
        warmup_ratio = 0.05,
        max_steps = 60,              # Reduced steps for demo
        learning_rate = 2e-4,        # Higher LR for LoRA
        logging_steps = 1,
        optim = "adamw_torch",
        weight_decay = 0.01,
        lr_scheduler_type = "cosine",
        seed = 3407,
        report_to = "wandb",         # Set to "none" to disable experiment tracking
        # Memory optimizations
        gradient_checkpointing = True,  # Enable gradient checkpointing
        dataloader_drop_last = True,    # Drop incomplete batches
        # Use bf16 only where the GPU supports it (Ampere+); fall back to fp16 on T4/V100
        bf16 = is_bfloat16_supported(),
        fp16 = not is_bfloat16_supported(),
        dataloader_pin_memory = True,
        dataloader_num_workers = 0,
        remove_unused_columns = False,
    ),
)

print("🚀 Starting training with GradES...")
print("   Watch for component freezing messages!")
print("   40-50% is the headline speedup; a 60-step demo run may show less.")
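
The training arguments above imply a few derived quantities that are worth checking before a run. The sketch below works them out, assuming a single GPU and assuming that `lr_scheduler_type = "cosine"` combined with `warmup_ratio` means linear warmup followed by cosine decay to zero, which is how I understand the Hugging Face scheduler to behave. The helper `lr_at` is illustrative, not a library function.

```python
import math

PER_DEVICE_BATCH = 16
GRAD_ACCUM = 4
MAX_STEPS = 60
WARMUP_RATIO = 0.05
PEAK_LR = 2e-4

effective_batch = PER_DEVICE_BATCH * GRAD_ACCUM   # sequences per optimizer step (one GPU)
examples_seen = effective_batch * MAX_STEPS
warmup_steps = math.ceil(MAX_STEPS * WARMUP_RATIO)

def lr_at(step):
    """Linear warmup to PEAK_LR, then cosine decay to zero at MAX_STEPS."""
    if step < warmup_steps:
        return PEAK_LR * step / warmup_steps
    progress = (step - warmup_steps) / (MAX_STEPS - warmup_steps)
    return PEAK_LR * 0.5 * (1.0 + math.cos(math.pi * progress))

print(f"effective batch size: {effective_batch}")
print(f"examples seen in {MAX_STEPS} steps: {examples_seen}")
print(f"warmup steps: {warmup_steps}")
for step in (0, 3, 30, 60):
    print(f"step {step:>2}: lr = {lr_at(step):.2e}")
```

With these settings the demo touches 3,840 examples, a small fraction of the dataset, so the timing it produces is a smoke test rather than a benchmark.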

In [None]:
# Start training!
trainer_stats = trainer.train()

## 📊 Training Results Analysis

In [None]:
# Print training statistics
print("🎉 Training completed!")
print(f"Total training time: {trainer_stats.metrics['train_runtime']:.2f} seconds")
print(f"Training samples per second: {trainer_stats.metrics['train_samples_per_second']:.2f}")

# Access GradES statistics if available
if hasattr(grades_callback, 'component_stats'):
    print("\n📈 GradES Component Statistics:")
    total_components = len(grades_callback.component_stats)
    frozen_components = sum(1 for stats in grades_callback.component_stats.values() if stats.is_frozen)
    print(f"   Total components: {total_components}")
    print(f"   Frozen components: {frozen_components}")
    if total_components > 0:  # Avoid dividing by zero when no components were tracked
        print(f"   Freeze ratio: {frozen_components/total_components*100:.1f}%")
    
    if frozen_components > 0:
        print(f"\n🎯 GradES successfully froze {frozen_components} components!")
        print("   This contributes to the computational savings.")

## 🧪 Test the Model

In [None]:
# Reuses alpaca_prompt from the dataset cell above
FastLanguageModel.for_inference(model) # Enable native 2x faster inference
inputs = tokenizer(
[
    alpaca_prompt.format(
        "Continue the fibonacci sequence.", # instruction
        "1, 1, 2, 3, 5, 8", # input
        "", # output - leave this blank for generation!
    )
], return_tensors = "pt").to("cuda")

outputs = model.generate(**inputs, max_new_tokens = 64, use_cache = True)
tokenizer.batch_decode(outputs)
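
For a decoder-only model, `generate` returns the prompt tokens followed by the continuation, so the decoded string repeats the whole Alpaca template. A small helper can isolate the answer. The sketch below is one way to do it with plain string handling; the `extract_response` name, the sample string, and the `<eos>` placeholder are hypothetical.

```python
RESPONSE_MARKER = "### Response:"

def extract_response(decoded, eos_token=None):
    """Return the text after the last response marker, minus a trailing EOS."""
    _, found, tail = decoded.rpartition(RESPONSE_MARKER)
    if not found:
        return decoded.strip()  # marker missing: fall back to the whole string
    if eos_token and eos_token in tail:
        tail = tail.split(eos_token, 1)[0]
    return tail.strip()

# Hypothetical decoded output, shaped like the Alpaca template used above
sample = (
    "### Instruction:\nContinue the fibonacci sequence.\n\n"
    "### Input:\n1, 1, 2, 3, 5, 8\n\n"
    "### Response:\n13, 21, 34<eos>"
)
print(extract_response(sample, eos_token="<eos>"))
```

In the notebook this could be applied to each string returned by `tokenizer.batch_decode(outputs)`, passing `tokenizer.eos_token` as the second argument.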

## 💾 Save the Model

In [None]:
# Save the model
model.save_pretrained("lora_model") # Local saving
tokenizer.save_pretrained("lora_model")

# Merge the LoRA adapters into the base weights and save the result in 16-bit precision
model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit",)

print("💾 Model saved successfully!")
print("   LoRA model: ./lora_model")
print("   Merged model (16bit): ./model")

## 🎉 Summary

Congratulations! You've successfully fine-tuned a language model using **GradES + Unsloth** with the following benefits:

- 🚀 **40-50% faster training** through gradient-based component freezing
- 🎯 **Maintained performance** with intelligent convergence detection
- 📊 **Real-time monitoring** of component freezing via WandB
- 🔧 **Easy integration** with existing Unsloth workflows

### Next Steps:
1. **Experiment with hyperparameters**: Try different `tau` and `alpha` values
2. **Scale up**: Use larger models or longer training runs
3. **Monitor results**: Check WandB for detailed component freezing patterns
4. **Share your results**: We'd love to hear about your experience!
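
For the first step, a small grid over `tau` and `alpha` is an easy starting point. The sketch below only builds the list of settings; the candidate values are hypothetical and not recommendations from the paper.

```python
from itertools import product

# Hypothetical candidate values; pick ranges that suit your own model and data.
TAU_VALUES = [1e-10, 1e-8, 1e-6]
ALPHA_VALUES = [0.1, 0.3, 0.5]

def build_grid(taus, alphas):
    """One keyword-argument dict per (tau, alpha) combination."""
    return [{"tau": tau, "alpha": alpha} for tau, alpha in product(taus, alphas)]

grid = build_grid(TAU_VALUES, ALPHA_VALUES)
for run_id, params in enumerate(grid):
    print(f"run {run_id}: tau={params['tau']:.0e}, alpha={params['alpha']}")
```

Each dictionary can then be unpacked into `GradEarlyStoppingConfig(**params)` together with the logging options used earlier, with one training run per entry.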

### Resources:
- 📖 **Paper**: [arXiv:2509.01842](https://arxiv.org/abs/2509.01842)
- 🐙 **GitHub**: [IXZZZ9/GradES](https://github.com/IXZZZ9/GradES)
- 📦 **PyPI**: `pip install grades`

---
**Made with ❤️ by the GradES Team**