# 1. Setup and Installation

In [None]:
%%capture
# Pin Unsloth and unsloth_zoo to the same release.
# NOTE(review): transformers/timm/peft are upgraded without version pins; pin them too for reproducible runs.
!pip install unsloth==2025.7.3 
!pip install unsloth_zoo==2025.7.3
!pip install --upgrade transformers timm peft

# 2. Importing Libraries

In [None]:
import unsloth
from unsloth import FastVisionModel
import torch
from datasets import Dataset
from transformers import TrainingArguments, Trainer
from PIL import Image
import re
from huggingface_hub import login, HfApi

import os
import json
import time
import logging
import random
import numpy as np
from getpass import getpass
from pathlib import Path

# Module-wide logging at INFO level so progress messages from every cell are shown.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def set_seed(seed=3407):
    """Seed Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs and force deterministic cuDNN.

    Args:
        seed: integer seed applied to every generator.
    """
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


set_seed(3407)

# 3. Configuration

In [None]:
# --- Model and local output locations ---
model_id = "unsloth/gemma-3n-e2b-it-unsloth-bnb-4bit"  # Unsloth's pre-quantized 4-bit checkpoint
output_dir = "./gemma-qlora-finetuned-cddm"

# --- Hugging Face Hub targets ---
hub_model_id = "shreyansh24/gemma3n-e2b-cddm-finetune"
hub_private_repo = False  # flip to True for a private model repository
# Dataset repo id; not referenced by the cells visible here.
hub_dataset_name = "shreyansh24/Crop-Disease-VQA"

# --- Credentials: environment variable first, interactive prompt as fallback ---
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    hf_token = getpass("Enter Hugging Face token: ")
if not hf_token:
    raise ValueError("HF_TOKEN environment variable or token input required")

login(token=hf_token)  # authenticate this session against the Hub

# 4. Data Preparation

In [None]:
# Make paths work both on Kaggle and locally: prefer the Kaggle input mount,
# otherwise fall back to a copy of the data next to the notebook.
if os.path.exists("/kaggle/input/crop-disease-data/"):
    KAGGLE_INPUT_DIR = "/kaggle/input/crop-disease-data/"
else:
    KAGGLE_INPUT_DIR = "./crop-disease-data/"
    logger.info(f"Using local data directory: {KAGGLE_INPUT_DIR}")

# Annotation file (Qwen-VL style conversations) and the system prompt injected into every sample.
dataset_json_path = os.path.join(KAGGLE_INPUT_DIR, "dataset/Crop_Disease_train_qwenvl.json")
system_message = (
    "You are a helpful AI assistant specialized in crop disease diagnosis. "
    "Provide concise and accurate information."
)

In [None]:
# Helper that pulls the PIL images back out of TRL-formatted chat messages;
# `collate_fn` uses it to build the per-sample image lists for the processor.
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    """Collect every PIL image embedded in TRL-style chat messages.

    A message's ``content`` may be a list of parts or a single part. Every part of
    the form ``{"type": "image", "image": <PIL.Image.Image>}`` is converted to RGB
    and returned in message order. A part whose conversion raises is logged and skipped.
    """
    collected = []
    for message in messages:
        parts = message.get("content", [])
        parts = parts if isinstance(parts, list) else [parts]
        for part in parts:
            if not (isinstance(part, dict) and part.get("type") == "image"):
                continue
            candidate = part.get("image")
            if not isinstance(candidate, Image.Image):
                continue
            try:
                collected.append(candidate.convert("RGB"))
            except Exception as e:
                logger.warning(f"Failed to convert image to RGB: {e}")
    return collected

The raw CDDM dataset uses a specific format with `<img>` tags inside the text. The `format_data_for_trl` function is a custom parser that transforms this into the standard TRL multimodal chat format. It extracts the image path, loads the PIL Image object, and constructs the conversation turns that Unsloth and Gemma can understand.

In [None]:
# Function to transform CDDM format to TRL multimodal chat format (Gemma style)

# `<img>path</img>` tag used by the Qwen-VL style source data.
_IMG_TAG_PATTERN = re.compile(r'<img>(.*?)</img>')
# Leading "Picture N:" label that precedes the image tag in user turns.
_PICTURE_PREFIX_PATTERN = re.compile(r'^(Picture\s+\d+:\s*)')
# Speaker labels in the source data that denote the user; any other label becomes "assistant".
_USER_ROLES = ("user", "human")


def _load_rgb_image(path):
    """Open the image at ``path`` and return an RGB copy, closing the file handle.

    The context manager matters because the dataset may hold a very large number of
    images; lazily opened PIL files could otherwise keep file descriptors alive.
    """
    with Image.open(path) as img:
        return img.convert("RGB")


def format_data_for_trl(sample):
    """Transform one CDDM (Qwen-VL style) sample into TRL multimodal chat format.

    Args:
        sample: dict with a ``"conversations"`` list. Each turn has ``"from"`` and
            ``"value"`` keys. A user turn may embed an image as ``<img>path</img>``,
            optionally preceded by a ``Picture N:`` label. The path is resolved
            relative to ``KAGGLE_INPUT_DIR``.

    Returns:
        ``{"messages": [...]}``, always starting with the system prompt. Missing or
        unreadable images are replaced by a text placeholder. If processing fails,
        the error is logged and a system-only message list is returned.
    """
    trl_messages = [{"role": "system", "content": [{"type": "text", "text": system_message}]}]

    try:
        for turn in sample['conversations']:
            # Map Qwen-VL roles to Gemma roles.
            current_role = "user" if turn["from"] in _USER_ROLES else "assistant"
            content_value = turn["value"]
            content_parts = []

            img_match = _IMG_TAG_PATTERN.search(content_value)

            if img_match and current_role == "user":
                # Image paths in the JSON are root-relative; anchor them under the data dir.
                corrected_image_path = os.path.join(KAGGLE_INPUT_DIR, img_match.group(1).lstrip('/'))

                if os.path.exists(corrected_image_path):
                    try:
                        content_parts.append({"type": "image", "image": _load_rgb_image(corrected_image_path)})
                    except Exception as e:
                        logger.warning(f"Could not load image {corrected_image_path}: {e}")
                        content_parts.append({"type": "text", "text": f"Error loading image: {corrected_image_path}"})
                else:
                    logger.warning(f"Image file not found at {corrected_image_path}")
                    content_parts.append({"type": "text", "text": f"Image not found: {corrected_image_path}"})

                # Keep only the question text: drop every <img> tag and the "Picture X:" label.
                text_part = _IMG_TAG_PATTERN.sub("", content_value)
                text_part = _PICTURE_PREFIX_PATTERN.sub('', text_part).strip()
                if text_part:
                    content_parts.append({"type": "text", "text": text_part})
            else:
                # All other turns are text-only.
                cleaned_text = content_value.strip()
                if cleaned_text:
                    content_parts.append({"type": "text", "text": cleaned_text})
                elif current_role == "assistant":
                    # Keep empty assistant turns so the conversation structure is preserved.
                    content_parts.append({"type": "text", "text": ""})

            if content_parts:
                trl_messages.append({"role": current_role, "content": content_parts})

        return {"messages": trl_messages}

    except Exception as e:
        logger.error(f"Error processing sample {sample.get('id', 'N/A')}: {e}")
        return {"messages": [{"role": "system", "content": [{"type": "text", "text": system_message}]}]}

In [None]:
# Load the dataset from the JSON file with error handling.
# The encoding is explicit so the load does not depend on the platform's default codec.
try:
    with open(dataset_json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    logger.info(f"Successfully loaded {len(data)} samples from dataset")
except FileNotFoundError:
    logger.error(f"Dataset file not found: {dataset_json_path}")
    raise
except json.JSONDecodeError:
    logger.error(f"Invalid JSON format in: {dataset_json_path}")
    raise

In [None]:
# Optional: restrict to a slice of the raw data for a quick test run,
# e.g. SUBSET_RANGE = (110000, 120000). Leave as None for full training.
SUBSET_RANGE = None
raw_samples = data[slice(*SUBSET_RANGE)] if SUBSET_RANGE else data


def _has_assistant_turn(example):
    """Return True if a formatted example contains at least one assistant message.

    ``collate_fn`` masks everything except assistant responses, so an example without
    one (e.g. the system-only fallback from ``format_data_for_trl``) has no trainable tokens.
    """
    return any(msg["role"] == "assistant" for msg in example["messages"])


# Format the dataset
logger.info("Formatting dataset for training...")
formatted_samples = [format_data_for_trl(sample) for sample in raw_samples]
train_dataset = [example for example in formatted_samples if _has_assistant_turn(example)]
dropped_count = len(formatted_samples) - len(train_dataset)
if dropped_count:
    logger.warning(f"Dropped {dropped_count} samples without an assistant turn (nothing to train on)")
logger.info(f"Formatted {len(train_dataset)} training samples")

# Show first sample
if train_dataset:
    logger.info("Sample training data structure:")
    print(train_dataset[0])

In [None]:
# Find the first user turn that carries a loaded PIL image and print it.
print("\nExample of image loading for a user turn:")
first_user_image = next(
    (
        part["image"]
        for sample in train_dataset
        for msg in sample["messages"]
        if msg["role"] == "user"
        for part in msg["content"]
        if part["type"] == "image"
    ),
    None,
)
image_sample_found = first_user_image is not None
if image_sample_found:
    print(f"  Image element found. PIL Image object: {first_user_image}")
else:
    # The search covers every formatted sample, not just the first few.
    print("No image found in a user turn of any formatted sample to demonstrate PIL object loading.")

# 5. Model Loading and QLoRA Configuration

In [None]:
logger.info("Loading base model...")
# Unsloth returns the 4-bit model together with its multimodal processor.
base_model_kwargs = {
    "model_name": model_id,
    "max_seq_length": 2048,
    "dtype": torch.float16,
    "load_in_4bit": True,
    "trust_remote_code": True,
}
try:
    model, processor = FastVisionModel.from_pretrained(**base_model_kwargs)
    logger.info("Base model loaded successfully")
except Exception as e:
    logger.error(f"Failed to load base model: {e}")
    raise

# The text tokenizer lives inside the processor; keep a shortcut for later cells.
tokenizer = processor.tokenizer

In [None]:
logger.info("Applying QLoRA configuration...")
# LoRA adapters: rank 16 / alpha 32 on all linear layers, biases left frozen.
lora_settings = dict(
    target_modules="all-linear",
    r=16,
    lora_alpha=32,
    bias="none",
    use_gradient_checkpointing=True,
    random_state=3407,
    max_seq_length=2048,
)
try:
    model = FastVisionModel.get_peft_model(model, **lora_settings)
    logger.info("QLoRA configuration applied successfully")
except Exception as e:
    logger.error(f"Failed to apply QLoRA: {e}")
    raise

In [None]:
# Print the trainable vs. total parameter counts reported by the PEFT-wrapped model.
logger.info("Trainable parameters:")
model.print_trainable_parameters()

# 6. Training

The `collate_fn` is the heart of the data preparation for training. It takes a batch of formatted samples and does two critical things:

- **Processes data:** it uses the `processor` to tokenize the text and handle the images for the entire batch at once.
- **Creates labels:** it masks everything that is not an assistant answer (system prompt, user questions, image tokens and padding), so the model only learns to predict the assistant's responses. This is the standard way to fine-tune a conversational model.

In [None]:
# Create a data collator to encode text and image pairs
def _find_trainable_spans(token_ids, start_tokens, end_of_turn_id):
    """Yield inclusive ``(start, end)`` index pairs covering each assistant response.

    A response begins right after an occurrence of ``start_tokens`` (the encoded
    ``<start_of_turn>model\\n`` header). It ends at the next ``end_of_turn_id``,
    which is included so the model learns to close its turn. If no end token
    follows, the span runs to the last token.
    """
    n_start = len(start_tokens)
    j = 0
    while j + n_start <= len(token_ids):
        if token_ids[j : j + n_start] == start_tokens:
            span_start = j + n_start
            try:
                span_end = token_ids.index(end_of_turn_id, span_start)
            except ValueError:
                span_end = len(token_ids) - 1
            yield span_start, span_end
            j = span_end + 1  # resume scanning after this response
        else:
            j += 1


def collate_fn(examples):
    """Encode a batch of chat examples and build loss labels for assistant turns only.

    Args:
        examples: list of ``{"messages": [...]}`` dicts as produced by ``format_data_for_trl``.

    Returns:
        The processor's batch plus a ``labels`` tensor. Every token outside an
        assistant response (system prompt, user text, image tokens, padding) is
        set to -100 in ``labels``.

    Raises:
        Re-raises any exception from templating/processing after logging it.
    """
    try:
        # 1. Render each conversation to text and gather its images (one list per sample).
        all_texts = []
        all_images = []
        for ex in examples:
            all_images.append(process_vision_info(ex["messages"]))
            all_texts.append(
                processor.apply_chat_template(
                    ex["messages"], add_generation_prompt=False, tokenize=False
                ).strip()
            )

        # 2. Call the processor ONCE on the entire batch.
        batch = processor(
            text=all_texts,
            images=all_images,
            return_tensors="pt",
            padding=True,
        )

        # 3. Locate assistant responses by Gemma's turn markers.
        tok = processor.tokenizer
        model_prompt_start_tokens = tok.encode("<start_of_turn>model\n", add_special_tokens=False)
        # Gemma's chat template closes every turn with <end_of_turn>, not necessarily with
        # the tokenizer's EOS token. Searching for EOS when it is absent would stretch each
        # response span to the end of the sequence, leaking later user turns into the loss.
        end_of_turn_id = tok.convert_tokens_to_ids("<end_of_turn>")
        if end_of_turn_id is None or end_of_turn_id == tok.unk_token_id:
            end_of_turn_id = tok.eos_token_id

        labels = batch["input_ids"].clone()
        for i in range(labels.shape[0]):
            # Ignore everything by default; unmask only assistant responses.
            ignore_mask = torch.ones_like(labels[i], dtype=torch.bool)
            for span_start, span_end in _find_trainable_spans(
                labels[i].tolist(), model_prompt_start_tokens, end_of_turn_id
            ):
                ignore_mask[span_start : span_end + 1] = False
            labels[i][ignore_mask] = -100

        # Never train on padding, even if an unterminated response span reached it.
        if "attention_mask" in batch:
            labels[batch["attention_mask"] == 0] = -100

        batch["labels"] = labels
        return batch

    except Exception as e:
        logger.error(f"Error in collate_fn: {e}")
        raise

In [None]:
# Set up the training arguments
logger.info("Setting up training arguments...")

# Unsloth vision models are switched between modes explicitly. Put the model in
# training mode before building the Trainer; this is also required if the inference
# cell below has already called `FastVisionModel.for_inference` in this kernel.
FastVisionModel.for_training(model)

args = TrainingArguments(
    output_dir=output_dir,  # local directory for checkpoints
    num_train_epochs=1,  # ignored while max_steps > 0
    per_device_train_batch_size=1,  # batch size per device during training
    gradient_accumulation_steps=8,  # micro-batches accumulated before each optimizer step
    max_steps=14250,  # total optimizer steps; overrides num_train_epochs
    gradient_checkpointing=True,  # use gradient checkpointing to save memory
    gradient_checkpointing_kwargs={"use_reentrant": False},  # non-reentrant checkpointing
    optim="adamw_8bit",  # 8-bit AdamW for memory efficiency with Unsloth
    logging_steps=25,  # log every 25 steps
    save_strategy="steps",  # save checkpoints every `save_steps`
    save_steps=250,  # number of steps between saves
    save_total_limit=2,  # keep at most 2 local checkpoints
    save_only_model=False,  # also keep optimizer/scheduler state so training can resume
    learning_rate=2e-4,  # learning rate, based on QLoRA paper
    fp16=True,  # fp16 mixed precision (matches the float16 base model dtype)
    bf16=False,  # explicitly disable bf16
    max_grad_norm=0.3,  # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,  # warmup ratio based on QLoRA paper
    # "constant" ignores warmup_ratio entirely; "constant_with_warmup" applies the
    # warmup above and then keeps the learning rate flat.
    lr_scheduler_type="constant_with_warmup",
    push_to_hub=True,  # push to HF Hub
    hub_model_id=hub_model_id,  # Hugging Face Hub repository name
    hub_strategy="checkpoint",  # also push the latest checkpoint so runs can resume from the Hub
    hub_token=hf_token,
    hub_private_repo=hub_private_repo,  # set to True for a private repo
    report_to="tensorboard",  # report metrics to tensorboard
    remove_unused_columns=False,  # keep the raw "messages" field that collate_fn reads
)

# Initialize the Trainer
logger.info("Initializing Trainer...")
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    data_collator=collate_fn,
)

In [None]:
# Resume from the newest local checkpoint, if a previous run left one behind.
latest_checkpoint = None
if os.path.isdir(args.output_dir):
    checkpoint_steps = {}
    for entry in os.listdir(args.output_dir):
        # Only directories named exactly "checkpoint-<step>" are Trainer checkpoints;
        # anything else (stray files, "checkpoint-foo") would break the int() parse.
        step_match = re.fullmatch(r"checkpoint-(\d+)", entry)
        if step_match and os.path.isdir(os.path.join(args.output_dir, entry)):
            checkpoint_steps[int(step_match.group(1))] = entry
    if checkpoint_steps:
        latest_checkpoint = os.path.join(args.output_dir, checkpoint_steps[max(checkpoint_steps)])
        logger.info(f"Resuming training from latest checkpoint: {latest_checkpoint}")

# Start the fine-tuning process
logger.info("Starting QLoRA fine-tuning...")
try:
    trainer_stats = trainer.train(resume_from_checkpoint=latest_checkpoint)
except Exception as e:
    logger.error(f"Training failed: {e}")
    raise
logger.info("Training completed successfully")

# `trainer.train` returns TrainOutput(global_step, training_loss, metrics).
# `training_loss` is a single float (mean over the run), not a list, and runtime and
# FLOs are keys of the `metrics` dict rather than attributes.
training_loss = getattr(trainer_stats, "training_loss", None)
if training_loss is not None:
    logger.info(f"Average training loss: {training_loss:.4f}")
train_metrics = getattr(trainer_stats, "metrics", None) or {}
if "train_runtime" in train_metrics:
    logger.info(f"Total training time: {train_metrics['train_runtime'] / 60:.1f} min")
if "total_flos" in train_metrics:
    logger.info(f"Total compute: {train_metrics['total_flos'] / 1e12:.2f} TFLOPs")

# 7. Save the Final Model

In [None]:
logger.info("Saving final model...")
# Use Kaggle's writable output dir when available, otherwise a local folder.
final_folder = (
    "/kaggle/working/full_checkpoint"
    if os.path.isdir("/kaggle/working")
    else "./full_checkpoint"
)
os.makedirs(final_folder, exist_ok=True)

try:
    # For the PEFT-wrapped model this writes the adapter weights and config.
    trainer.save_model(final_folder)
    # Save the whole processor (tokenizer + image-processor config), not only the
    # tokenizer, so the folder can be reloaded as a vision-language checkpoint.
    processor.save_pretrained(final_folder)
    # `save_state()` always writes trainer_state.json into args.output_dir; also put
    # a copy in final_folder so the uploaded folder is self-contained.
    trainer.save_state()
    trainer.state.save_to_json(os.path.join(final_folder, "trainer_state.json"))

    # Save optimizer & scheduler state dicts (they only exist once training has run).
    if trainer.optimizer is not None:
        torch.save(trainer.optimizer.state_dict(), os.path.join(final_folder, "optimizer_state.pt"))
    if trainer.lr_scheduler is not None:
        torch.save(trainer.lr_scheduler.state_dict(), os.path.join(final_folder, "scheduler_state.pt"))

    logger.info("Model saved successfully")
except Exception as e:
    logger.error(f"Failed to save model: {e}")
    raise

In [None]:
# Upload to Hugging Face Hub
logger.info(f"Uploading model to Hugging Face Hub: {hub_model_id}")
try:
    # Pass the token explicitly instead of relying on the cached login.
    api = HfApi(token=hf_token)
    api.upload_folder(
        # Reuse the folder the save cell wrote to, rather than repeating its path.
        folder_path=final_folder,
        repo_id=hub_model_id,
        repo_type="model",
        path_in_repo="full_checkpoint",
    )
    logger.info("Model uploaded successfully to Hugging Face Hub")
except Exception as e:
    logger.error(f"Failed to upload model to Hub: {e}")
    raise

# 8. Inference
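After fine-tuning, let's test our model's new capabilities. This section uses the fine-tuned model still in memory to run a sample inference on an image from our dataset. The image comes from the training set, so this is a quick sanity check of its diagnostic abilities rather than a held-out evaluation.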

In [None]:
logger.info("Setting up inference...")
try:
    from transformers import TextStreamer

    # Switch Unsloth's patched model into generation mode.
    FastVisionModel.for_inference(model)

    # process_vision_info returns a *list* of images; the chat message needs a single image.
    sample_images = process_vision_info(train_dataset[0]["messages"])
    if not sample_images:
        raise ValueError("train_dataset[0] contains no image to run inference on")
    img = sample_images[0]

    # Build the chat with the same system prompt used during fine-tuning.
    messages = [
        {"role": "system", "content": [{"type": "text", "text": system_message}]},
        {"role": "user", "content": [
            {"type": "image", "image": img},
            {"type": "text", "text": "what disease does this leaf has?"},
        ]},
    ]

    # Process inputs: one sample with one image, nested the same way collate_fn does.
    inputs = processor(
        text=processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False),
        images=[[img]],
        return_tensors="pt",
    ).to(model.device)

    # Generate output; skip_prompt avoids echoing the templated prompt (incl. image tokens).
    logger.info("Generating inference...")
    streamer = TextStreamer(processor.tokenizer, skip_prompt=True)
    _ = model.generate(
        **inputs,
        streamer=streamer,
        max_new_tokens=128,
        do_sample=True,  # temperature / top_k / top_p only take effect when sampling
        temperature=1.0,
        top_k=64,
        top_p=0.95,
    )

    logger.info("Inference completed successfully")

except Exception as e:
    logger.error(f"Inference failed: {e}")
    raise

In [None]:
# Display sample image for verification
logger.info("Displaying sample image...")
try:
    import matplotlib.pyplot as plt

    # process_vision_info returns a list; show the first image of the first sample.
    sample_images = process_vision_info(train_dataset[0]["messages"])
    if sample_images:
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.imshow(sample_images[0])
        ax.set_title("Sample Training Image")
        ax.axis('off')
        plt.show()
        logger.info("Sample image displayed successfully")
    else:
        logger.warning("No image found to display")

except Exception as e:
    # Displaying the image is best-effort; a failure here should not abort the pipeline.
    logger.error(f"Failed to display image: {e}")

logger.info("🎉 Fine-tuning pipeline completed successfully!")