In [1]:
# train_rag_unsloth.py
import os
import random
from dataclasses import dataclass
from typing import Dict, Any, List

from datasets import load_dataset
from transformers import (
    TrainingArguments,
    TrainerCallback,
    TrainerState,
    TrainerControl,
)
from transformers.trainer_utils import get_last_checkpoint
import numpy as np
import torch
import wandb

from unsloth import FastModel, is_bfloat16_supported
from trl import SFTTrainer, SFTConfig
from unsloth.chat_templates import get_chat_template

from google.cloud import storage


# =========================
# CONFIG
# =========================
SEED = 3407
# set random seed for reproducibility

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

MODEL_NAME = "unsloth/gemma-3-4b-it"  # Base model
DATASET_NAME = "FreedomIntelligence/RAG-Instruct"  # HF dataset
OUTPUT_DIR = "./outputs/gemma3-4b-rag"

BUCKET_NAME = "model-finetune-1"  # GCP bucket for finetune checkpoints
GCS_PREFIX = "checkpoints/gemma3-4b-rag"  # optional path inside bucket
ENABLE_GCS_SYNC = True

MAX_SEQ_LENGTH = 2048  # safe for 4B on L4 with packing
MAX_DOCS = 2  # number of documents per sample
MAX_DOC_CHARS = 4000  # truncate context to avoid huge prompts

# Budget controls â€“ tune these first
MAX_TRAIN_SAMPLES = None  # set to None for full dataset
NUM_TRAIN_EPOCHS = 1
LEARNING_RATE = 2e-4

BATCH_SIZE = 1  # per-device batch size
GRAD_ACCUM_STEPS = 4  # effective batch size = BATCH_SIZE * GRAD_ACCUM_STEPS

USE_WANDB = True  # set True and export WANDB_PROJECT if you want logging
WANDB_PROJECT = "instruct-rag-finetune"


if USE_WANDB:
    run = wandb.init(
        # Set the wandb entity where your project will be logged (generally your team name).
        entity="pareek-ml-personal",
        # Set the wandb project where this run will be logged.
        project="instruct-rag-finetune",
        # Track hyperparameters and run metadata.
        config={
            "model_name": MODEL_NAME,
            "dataset_name": DATASET_NAME,
            "max_train_samples": MAX_TRAIN_SAMPLES,
            "num_train_epochs": NUM_TRAIN_EPOCHS,
            "learning_rate": LEARNING_RATE,
            "batch_size": BATCH_SIZE,
            "grad_accum_steps": GRAD_ACCUM_STEPS,
        },
    )

  from .autonotebook import tqdm as notebook_tqdm

Please restructure your imports with 'import unsloth' at the top of your file.
  from unsloth import FastModel, is_bfloat16_supported


ðŸ¦¥ Unsloth: Will patch your computer to enable 2x faster free finetuning.
ðŸ¦¥ Unsloth Zoo will now patch everything to make training faster!


[34m[1mwandb[0m: Currently logged in as: [33mpareek-ml[0m ([33mpareek-ml-personal[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Detected [huggingface_hub.inference] in use.
[34m[1mwandb[0m: Use W&B Weave for improved LLM call tracing. Install Weave with `pip install weave` then add `import weave` to the top of your script.
[34m[1mwandb[0m: For more information, check out the docs at: https://weave-docs.wandb.ai/


In [2]:
# =========================
# GCS SYNC CALLBACK
# =========================

class GCSSyncCallback(TrainerCallback):
    def __init__(self, local_dir: str, bucket_name: str, prefix: str):
        self.local_dir = os.path.abspath(local_dir)
        self.bucket_name = bucket_name
        self.prefix = prefix.rstrip("/") if prefix else ""
        self.client = storage.Client()
        self.bucket = self.client.bucket(bucket_name)

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        # Sync only the checkpoint that was just written
        if not os.path.exists(self.local_dir):
            return

        print(f"[GCSSync] Syncing {self.local_dir} to gs://{self.bucket_name}/{self.prefix} ...")
        self._sync_directory(self.local_dir, self.prefix)
        print("[GCSSync] Sync complete.")
        return control

    def _sync_directory(self, local_dir: str, gcs_prefix: str):
        for root, dirs, files in os.walk(local_dir):
            for fname in files:
                local_path = os.path.join(root, fname)
                rel_path = os.path.relpath(local_path, local_dir)
                blob_name = f"{gcs_prefix}/{rel_path}" if gcs_prefix else rel_path
                blob = self.bucket.blob(blob_name)
                blob.upload_from_filename(local_path)


In [3]:
# =========================
# MODEL & TOKENIZER
# =========================

def load_model_and_tokenizer():
    # If you need HF token for the model, set: os.environ["HF_TOKEN"] = "hf_xxx"
    print("Loading base model and tokenizer...")
    model, tokenizer = FastModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=MAX_SEQ_LENGTH,
        load_in_4bit=True,
        load_in_8bit=False,
        full_finetuning=False,
    )

    print("Applying LoRA (QLoRA) configuration...")
    model = FastModel.get_peft_model(
        model,
        finetune_vision_layers=False,
        finetune_language_layers=True,
        finetune_attention_modules=True,
        finetune_mlp_modules=True,
        r=8,            # LoRA rank
        lora_alpha=8,   # usually >= r
        lora_dropout=0,
        bias="none",
        random_state=3407,
    )

    # Prepare for training (gradient checkpointing etc.)
    FastModel.for_training(model)

    # Attach Gemma-3 chat template
    tokenizer = get_chat_template(
        tokenizer,
        chat_template="gemma-3",
    )

    return model, tokenizer




In [4]:
# =========================
# DATASET + PROMPTING
# =========================

def join_docs(docs: List[str]) -> str:
    """Join top-k docs and truncate to avoid crazy-long contexts."""
    if not docs:
        return ""
    text = "\n\n".join(docs[:MAX_DOCS])
    return text[:MAX_DOC_CHARS]


def make_prompt(question: str, docs: str) -> str:
    """User prompt: question + retrieved documents."""
    return (
        "You are a helpful assistant. Use ONLY the provided documents to answer the question.\n\n"
        "QUESTION:\n"
        f"{question}\n\n"
        "DOCUMENTS:\n"
        f"{docs}\n"
    )


def formatting_single_example(example: Dict[str, Any], tokenizer) -> Dict[str, str]:
    question = example["question"]
    answer = example["answer"]
    docs = join_docs(example["documents"])

    user_prompt = make_prompt(question, docs)

    messages = [
        {"role": "user", "content": user_prompt},
        {"role": "assistant", "content": answer},
    ]

    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
    )

    bos = tokenizer.bos_token or ""
    if bos and text.startswith(bos):
        text = text[len(bos):]

    return {"text": text}


def load_and_prepare_dataset(tokenizer):
    print(f"Loading dataset: {DATASET_NAME}")
    dataset = load_dataset(DATASET_NAME, split="train")

    if MAX_TRAIN_SAMPLES is not None:
        n = min(MAX_TRAIN_SAMPLES, len(dataset))
        print(f"Subsampling dataset to {n} examples for budget.")
        dataset = dataset.select(range(n))

    def _map_fn(batch):
        questions = batch["question"]
        answers = batch["answer"]
        documents = batch["documents"]

        texts = []
        for q, a, docs in zip(questions, answers, documents):
            ex = {"question": q, "answer": a, "documents": docs}
            out = formatting_single_example(ex, tokenizer)
            texts.append(out["text"])

        return {"text": texts}

    print("Formatting dataset into chat-style text...")
    dataset = dataset.map(
        _map_fn,
        batched=True,
        remove_columns=dataset.column_names,
        desc="Formatting prompts",
    )

    return dataset




In [5]:
model, tokenizer = load_model_and_tokenizer()
dataset = load_and_prepare_dataset(tokenizer)

# Only use 80% of data for training, rest for eval
train_size = int(0.8 * len(dataset))
train_dataset = dataset.select(range(train_size))
eval_dataset = dataset.select(range(train_size, len(dataset)))

Loading base model and tokenizer...
==((====))==  Unsloth 2025.11.3: Fast Gemma3 patching. Transformers: 4.57.1.
   \\   /|    NVIDIA L4. Num GPUs = 1. Max memory: 22.034 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.9. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Gemma3 does not support SDPA - switching to fast eager.
Applying LoRA (QLoRA) configuration...
Unsloth: Making `model.base_model.model.model.language_model` require gradients
Loading dataset: FreedomIntelligence/RAG-Instruct
Formatting dataset into chat-style text...


Formatting prompts: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 40541/40541 [00:05<00:00, 7632.88 examples/s]


In [8]:
os.makedirs(OUTPUT_DIR, exist_ok=True)
bf16 = is_bfloat16_supported()
print(f"bfloat16 supported: {bf16}")
report_to = "wandb" if USE_WANDB else "none"

bfloat16 supported: True


In [9]:
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    learning_rate=LEARNING_RATE,
    logging_steps=200,
    save_steps=100,
    save_total_limit=3,
    bf16=bf16,
    fp16=not bf16,
    optim="paged_adamw_8bit",
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    report_to=report_to,
)

callbacks = []
if ENABLE_GCS_SYNC:
    gcs_callback = GCSSyncCallback(
        local_dir=OUTPUT_DIR,
        bucket_name=BUCKET_NAME,
        prefix=GCS_PREFIX,
    )
    callbacks.append(gcs_callback)
    print("GCS sync callback enabled.")

print("Starting SFT training...")
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    dataset_text_field="text",
    max_seq_length=MAX_SEQ_LENGTH,
    packing=True,  # packs multiple samples per sequence -> good for throughput
    args=training_args,
    callbacks=callbacks,
)

trainer.train()

GCS sync callback enabled.
Starting SFT training...


The model is already on multiple devices. Skipping the move to device specified in `args`.
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 32,432 | Num Epochs = 1 | Total steps = 8,108
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 4 x 1) = 4
 "-____-"     Trainable parameters = 14,901,248 of 4,314,980,720 (0.35% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss
200,2.6737
400,2.2438
600,2.2174
800,2.1947
1000,2.1914
1200,2.1659
1400,2.1486
1600,2.1688
1800,2.1456
2000,2.1543


[GCSSync] Syncing /home/yashpareek_workmail/unsloth_finetuning/notebooks/outputs/gemma3-4b-rag to gs://model-finetune-1/checkpoints/gemma3-4b-rag ...
[GCSSync] Sync complete.
[GCSSync] Syncing /home/yashpareek_workmail/unsloth_finetuning/notebooks/outputs/gemma3-4b-rag to gs://model-finetune-1/checkpoints/gemma3-4b-rag ...
[GCSSync] Sync complete.
[GCSSync] Syncing /home/yashpareek_workmail/unsloth_finetuning/notebooks/outputs/gemma3-4b-rag to gs://model-finetune-1/checkpoints/gemma3-4b-rag ...
[GCSSync] Sync complete.
[GCSSync] Syncing /home/yashpareek_workmail/unsloth_finetuning/notebooks/outputs/gemma3-4b-rag to gs://model-finetune-1/checkpoints/gemma3-4b-rag ...
[GCSSync] Sync complete.
[GCSSync] Syncing /home/yashpareek_workmail/unsloth_finetuning/notebooks/outputs/gemma3-4b-rag to gs://model-finetune-1/checkpoints/gemma3-4b-rag ...
[GCSSync] Sync complete.
[GCSSync] Syncing /home/yashpareek_workmail/unsloth_finetuning/notebooks/outputs/gemma3-4b-rag to gs://model-finetune-1/check

TrainOutput(global_step=8108, training_loss=2.1408763635199706, metrics={'train_runtime': 27333.9796, 'train_samples_per_second': 1.187, 'train_steps_per_second': 0.297, 'total_flos': 3.1683510571566246e+17, 'train_loss': 2.1408763635199706, 'epoch': 1.0})

In [10]:
from pathlib import Path

save_dir = Path("rag-instruct-gemma-3-finetuned")

save_dir.mkdir(exist_ok=True)

model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)

print("Saved to", save_dir.resolve())

Saved to /home/yashpareek_workmail/unsloth_finetuning/notebooks/rag-instruct-gemma-3-finetuned


In [None]:
from unsloth import FastModel
from transformers import AutoTokenizer

base_model_name = "unsloth/gemma-3-4b-it"
adapter_dir = "rag-instruct-gemma-3-finetuned"

base_model, base_tokenizer = FastModel.from_pretrained(
    base_model_name,
    max_seq_length=2048,
    load_in_4bit=True,
    load_in_8bit=False,
    full_finetuning=False,
)

from peft import PeftModel
model = PeftModel.from_pretrained(base_model, adapter_dir)
tokenizer = base_tokenizer

## To get checkpoint to train again

In [None]:
last_ckpt = get_last_checkpoint(OUTPUT_DIR)
if last_ckpt is not None:
    print(f"Resuming from checkpoint: {last_ckpt}")
    trainer.train(resume_from_checkpoint=last_ckpt)
else:
    print("No checkpoint found. Starting fresh.")
    trainer.train()

print("Saving final adapter + tokenizer...")
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)