<a href="https://colab.research.google.com/github/rajveer43/unsloth_notebooks/blob/master/Unsloth_AI_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Step 1

In [1]:
import os
os.environ["TORCHDYNAMO_VERBOSE"] = "1"
os.environ["TORCHINDUCTOR_FORCE_DISABLE_CACHES"] = "1"
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"


## Step 2

In [10]:
import logging
import torch

torch._inductor.config.debug = True
torch._logging.set_logs(
    dynamo=logging.DEBUG,
    inductor=logging.DEBUG,
    graph_breaks=True,
    recompiles=True,
    recompiles_verbose=True,
    compiled_autograd_verbose=True,
)

## Step 3

In [15]:
torch_compile_options = {
    "epilogue_fusion": True,
    "max_autotune": True,
    "shape_padding": True,
    "trace.enabled": True,
    "triton.cudagraphs": True,
    "fullgraph": True,
}


## Step 4 : MLP COmpilation

In [17]:
import torch
import transformers.models.llama.modeling_llama

# Configure torch.compile options
torch_compile_options = {
    "fullgraph": True,  # Ensure the entire function is compiled into a single graph
    "dynamic": True,    # Enable dynamic shapes support
}

# Enable additional optimizations
torch._inductor.config.max_autotune = True
torch._inductor.config.triton.cudagraphs = True

# Define the compiled function
@torch.compile(**torch_compile_options)
def compiled_llama_mlp(self, x):
    down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
    return down_proj

# Monkey-patch the LlamaMLP.forward method
transformers.models.llama.modeling_llama.LlamaMLP.forward = compiled_llama_mlp

In [18]:
@torch.compile(**torch_compile_options)
def compiled_llama_attention(self, hidden_states, attention_mask=None, **kwargs):
    return self.attn(hidden_states, attention_mask, **kwargs)

transformers.models.llama.modeling_llama.LlamaAttention.forward = compiled_llama_attention

@torch.compile(**torch_compile_options)
def compiled_layernorm(self, x):
    return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias)

torch.nn.LayerNorm.forward = compiled_layernorm

In [19]:
def compile_region(module):
    for name, submodule in module.named_children():
        if isinstance(submodule, (torch.nn.Linear, torch.nn.LayerNorm)):
            module._modules[name] = torch.compile(submodule, fullgraph=True, dynamic=True, options=torch_compile_options)
        else:
            compile_region(submodule)

compile_region(model)


NameError: name 'model' is not defined

---
---
---
<a name="COMPILE"></a>
## C) Make `torch.compile` work without graph breaks for QLoRA [Difficulty: Easy to Medium] [Max points: 9]

1. Goal: Write a single Python script like task B), except the goal is to `torch.compile` all modules if possible.

2. There must NOT be graph breaks, and excessive re-compilations should not be seen.

3. You should have say max 30 compilations. Over 60 is definitely wrong.

4. The loss must match with the non compiled module.

5. Utilize patching as much as possible.

6. Think about which areas might need disabling for compilation. Think about regional compilation. How do we compile sections efficiently?

7. Log memory / VRAM usage, and monitor speedups as well.

8. Must work for QLoRA.

We provided a script below, and showcased how to detect if graph breaks are seen. We also torch compiled the MLP for Llama:

In [2]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.3.2-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.3.2-py3-none-any.whl (485 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.4/485.4 kB[0m [31m32.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading multiprocess-0.70.16-py311-none-any.whl (143 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m15.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading 

In [3]:
!pip install trl

Collecting trl
  Downloading trl-0.15.1-py3-none-any.whl.metadata (11 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0.0->accelerate>=0.34.0->trl)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0.0->accelerate>=0.34.0->trl)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.0.0->accelerate>=0.34.0->trl)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.0.0->accelerate>=0.34.0->trl)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=2.0.0->accelerate>=0.34.0->trl)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Colle

In [4]:
!pip install -U bitsandbytes

Collecting bitsandbytes
  Downloading bitsandbytes-0.45.2-py3-none-manylinux_2_24_x86_64.whl.metadata (5.8 kB)
Downloading bitsandbytes-0.45.2-py3-none-manylinux_2_24_x86_64.whl (69.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.7/69.7 MB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.45.2


In [10]:
!pip install -U transformers

Collecting transformers
  Downloading transformers-4.49.0-py3-none-any.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
Downloading transformers-4.49.0-py3-none-any.whl (10.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.0/10.0 MB[0m [31m49.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: transformers
  Attempting uninstall: transformers
    Found existing installation: transformers 4.48.3
    Uninstalling transformers-4.48.3:
      Successfully uninstalled transformers-4.48.3
Successfully installed transformers-4.49.0


In [18]:
#!/usr/bin/env python
"""
Torch Compile QLoRA Script
--------------------------
This script loads a QLoRA model, patches key modules with torch.compile (using regional compilation),
monitors VRAM usage and speedups, and confirms that the compiled model’s loss matches the non‐compiled version.
It is meant to avoid graph breaks and limit total compilations (max ~30) while working for QLoRA.
"""

import os
import time
import logging
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import get_peft_model, LoraConfig, TaskType
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig

# -----------------------------------------------------------------------------
# Environment Setup & Logging
# -----------------------------------------------------------------------------
# Enable detailed dynamo/inductor logging to catch graph breaks and recompilations.
os.environ["TORCHDYNAMO_VERBOSE"] = "1"
os.environ["TORCHINDUCTOR_FORCE_DISABLE_CACHES"] = "1"
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

torch._inductor.config.debug = True
torch._logging.set_logs(
    dynamo=logging.DEBUG,
    inductor=logging.DEBUG,
    graph_breaks=True,
    recompiles=True,
    recompiles_verbose=True,
    compiled_autograd_verbose=True,
)
torch._dynamo.config.verbose = True
torch._dynamo.config.suppress_errors = False

# -----------------------------------------------------------------------------
# Torch Compile Options (regional compilation)
# -----------------------------------------------------------------------------
# torch_compile_options = {
#     "epilogue_fusion": True,
#     "max_autotune": True,
#     "shape_padding": True,
#     "trace.enabled": True,
#     # Triton cudagraphs can sometimes trigger extra recompilations;
#     # set to True if your environment is stable with it.
#     "triton.cudagraphs": True,
#     "dynamic_shapes": True,
#     "fullgraph": True,
# }

torch_compile_options = {
    "fullgraph": True,  # Ensure the entire function is compiled into a single graph
    "dynamic": True,    # Enable dynamic shapes support
}

# Enable additional optimizations
torch._inductor.config.max_autotune = True
torch._inductor.config.triton.cudagraphs = True

# -----------------------------------------------------------------------------
# Patching: Compile key modules for QLoRA (using patching)
# -----------------------------------------------------------------------------
# Import LLaMA model components from transformers
import transformers.models.llama.modeling_llama as llama_mod

# Patch LlamaMLP forward to compile its inner operations.
@torch.compile(**torch_compile_options)
def compiled_llama_mlp(self, x):
    # Combine gate, up, and down projections with activation
    # (this should be equivalent to the original LlamaMLP.forward)
    act = self.act_fn(self.gate_proj(x))
    up = self.up_proj(x)
    down = self.down_proj(act * up)
    return down

llama_mod.LlamaMLP.forward = compiled_llama_mlp
original_llama_attention_forward = llama_mod.LlamaAttention.forward


# Patch LlamaAttention forward if available.
if hasattr(llama_mod, "LlamaAttention"):
    @torch.compile(**torch_compile_options)
    def compiled_llama_attention(self, hidden_states, attention_mask=None, **kwargs):
        # Call the original attention function (usually self.attn)
        return original_llama_attention_forward(self, hidden_states, attention_mask, **kwargs)
    llama_mod.LlamaAttention.forward = compiled_llama_attention

# Patch torch.nn.LayerNorm to avoid graph breaks in normalization
@torch.compile(**torch_compile_options)
def compiled_layernorm(self, x):
    return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias)
torch.nn.LayerNorm.forward = compiled_layernorm

# Optionally, if other submodules (like q_proj, k_proj, etc.) cause issues,
# you can patch their forward methods similarly. For brevity, we assume that patching the MLP,
# attention, and layernorm is sufficient for our current QLoRA model.

# -----------------------------------------------------------------------------
# Model & Tokenizer Setup (QLoRA)
# -----------------------------------------------------------------------------
# Set visible GPUs and default dtype
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
torch.set_default_dtype(torch.float16)

max_seq_length = 1024
model_name = "unsloth/Llama-3.2-1B-Instruct-bnb-4bit"
dtype = torch.float16

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=dtype,
)

# Load the base model (with quantization configuration)
print("Loading base model...")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    attn_implementation="sdpa",
    quantization_config=bnb_config,
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "right"

# -----------------------------------------------------------------------------
# Setup LoRA with PEFT
# -----------------------------------------------------------------------------
lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

print("Wrapping model with LoRA...")
model = get_peft_model(model, lora_config)

# Enable gradients only for LoRA parameters; freeze base parameters.
with torch.no_grad():
    for name, param in model.named_parameters():
        if ".lora_A." in name or ".lora_B." in name:
            param.requires_grad_(True)
        else:
            param.requires_grad_(False)

# Disable gradient checkpointing (it interferes with torch.compile)
# model.gradient_checkpointing_enable()  <-- NOT enabled for compilation
model.enable_input_require_grads()

# -----------------------------------------------------------------------------
# Dataset Setup
# -----------------------------------------------------------------------------
# Load a small portion of the dataset for testing


Loading base model...




Wrapping model with LoRA...


In [2]:
def compile_region(module):
    for name, submodule in module.named_children():
        if isinstance(submodule, (torch.nn.Linear, torch.nn.LayerNorm)):
            module._modules[name] = torch.compile(submodule,**torch_compile_options)
        else:
            compile_region(submodule)

compile_region(model)

In [3]:
import torch
import transformers.models.llama.modeling_llama as llama_mod

# Save the original forward method before patching
original_llama_attention_forward = llama_mod.LlamaAttention.forward

@torch.compile(**torch_compile_options)
def compiled_llama_attention(self, hidden_states, attention_mask=None, **kwargs):
    # Call the original forward method instead of a non-existent self.attn
    return original_llama_attention_forward(self, hidden_states, attention_mask, **kwargs)

# Patch the LlamaAttention forward with the compiled version
llama_mod.LlamaAttention.forward = compiled_llama_attention


In [22]:
url = "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"
print("Loading dataset...")
dataset = load_dataset("json", data_files={"train": url}, split="train[:10%]")

import torch._dynamo
torch._dynamo.config.suppress_errors = True

# -----------------------------------------------------------------------------
# Loss Verification: Non-Compiled vs. Compiled
# -----------------------------------------------------------------------------
def compute_loss(model_instance, input_ids):
    outputs = model_instance(input_ids=input_ids, labels=input_ids)
    return outputs.loss

# Prepare a small test batch
sample_text = "This is a test input to check loss consistency."
input_ids = tokenizer(sample_text, return_tensors="pt").input_ids.to("cuda")

model.eval()
with torch.no_grad():
    loss_noncompiled = compute_loss(model, input_ids).item()

# Now compile the whole model (regional compilation already patched key submodules)
print("Compiling full model...")
compiled_model = torch.compile(model, **torch_compile_options)
compiled_model.eval()
with torch.no_grad():
    loss_compiled = compute_loss(compiled_model, input_ids).item()

print(f"Non-compiled loss: {loss_noncompiled:.6f}")
print(f"Compiled loss:     {loss_compiled:.6f}")
assert abs(loss_noncompiled - loss_compiled) < 1e-4, "Loss mismatch between compiled and non-compiled models!"

# -----------------------------------------------------------------------------
# VRAM & Speedup Logging
# -----------------------------------------------------------------------------
# Reset GPU memory stats
torch.cuda.reset_peak_memory_stats()
start_mem = torch.cuda.memory_allocated()

# Time non-compiled inference
n_runs = 10
print("Timing non-compiled inference...")
start_time = time.time()
for _ in range(n_runs):
    _ = model(input_ids=input_ids)
noncompiled_time = time.time() - start_time

# Time compiled inference
print("Timing compiled inference...")
start_time = time.time()
for _ in range(n_runs):
    _ = compiled_model(input_ids=input_ids)
compiled_time = time.time() - start_time

end_mem = torch.cuda.memory_allocated()
peak_mem = torch.cuda.max_memory_allocated()

print(f"Non-compiled inference time over {n_runs} runs: {noncompiled_time:.4f} sec")
print(f"Compiled inference time over {n_runs} runs:     {compiled_time:.4f} sec")
print(f"Speedup: {noncompiled_time / compiled_time:.2f}x")
print(f"Memory usage delta: {end_mem - start_mem} bytes; Peak memory: {peak_mem} bytes")

# -----------------------------------------------------------------------------
# Training Setup using SFTTrainer
# -----------------------------------------------------------------------------
trainer = SFTTrainer(
    model=compiled_model,
    train_dataset=dataset,
    processing_class=tokenizer,
    args=SFTConfig(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=2,
        warmup_steps=1,
        max_steps=10,
        logging_steps=1,
        output_dir="outputs",
        seed=3407,
        max_seq_length=max_seq_length,
        fp16=True,
        report_to="none",  # disable external logging (e.g. W&B)
        dataset_num_proc=4,
    ),
)

print("Starting training...")
trainer.train()

print("Training finished.")


Loading dataset...


V0223 10:50:42.861000 5897 torch/_dynamo/convert_frame.py:864] [6/2] torchdynamo start compiling compiled_llama_attention <ipython-input-18-855f0b57ff5a>:85, stack (elided 4 frames):
V0223 10:50:42.861000 5897 torch/_dynamo/convert_frame.py:864] [6/2]   File "<frozen runpy>", line 198, in _run_module_as_main
V0223 10:50:42.861000 5897 torch/_dynamo/convert_frame.py:864] [6/2]   File "<frozen runpy>", line 88, in _run_code
V0223 10:50:42.861000 5897 torch/_dynamo/convert_frame.py:864] [6/2]   File "/usr/local/lib/python3.11/dist-packages/colab_kernel_launcher.py", line 37, in <module>
V0223 10:50:42.861000 5897 torch/_dynamo/convert_frame.py:864] [6/2]     ColabKernelApp.launch_instance()
V0223 10:50:42.861000 5897 torch/_dynamo/convert_frame.py:864] [6/2]   File "/usr/local/lib/python3.11/dist-packages/traitlets/config/application.py", line 992, in launch_instance
V0223 10:50:42.861000 5897 torch/_dynamo/convert_frame.py:864] [6/2]     app.start()
V0223 10:50:42.861000 5897 torch/_dyna

Unsupported: call_method BuiltinVariable(object) __getattribute__ [UnspecializedNNModuleVariable(LlamaAttention), ConstantVariable()] {}

from user code:
   File "<ipython-input-18-855f0b57ff5a>", line 88, in compiled_llama_attention
    return original_llama_attention_forward(self, hidden_states, attention_mask, **kwargs)
  File "<ipython-input-16-546be4492c69>", line 96, in compiled_llama_attention
    return call_orig_forward(self, hidden_states, attention_mask, **kwargs)
  File "<ipython-input-16-546be4492c69>", line 90, in call_orig_forward
    return object.__getattribute__(self, "_orig_forward")(hidden_states, attention_mask, **kwargs)


In [9]:
import transformers

# Get the version of bitsandbytes
print(f"bitsandbytes version: {transformers.__version__}")


bitsandbytes version: 4.48.3
