# 🌱 AgriVision-Gemma: Vision‑LLM Fine‑Tuning Notebook  
**Fine‑tuning the Gemma‑3n‑e2b‑it model for crop disease diagnosis**  

## 🔖 Table of Contents
1. 🧰 [Setup and Installation](#1-setup-and-installation)  
2. 📦 [Importing Libraries](#2-importing-libraries)  
3. ⚙️ [Configuration](#3-configuration)  
4. 🗂️ [Data Preparation](#4-data-preparation)  
5. 🤖 [Model + QLoRA Configuration](#5-model-loading-and-qlora-configuration)  
6. 🚀 [Training](#6-training)  
7. 💾 [Saving Model](#7-saving-model)  
8. 🧪 [Inference](#8-inference)

<a id="1-setup-and-installation"></a>
# 1. Setup and Installation

In [None]:
%%capture
!pip install --upgrade unsloth
!pip install torch torchvision torchaudio xformers --index-url https://download.pytorch.org/whl/cu124
!pip install transformers==4.53.0
!pip install --no-deps --upgrade timm

<a id="2-importing-libraries"></a>

# 2. Importing Libraries

In [2]:
import unsloth
from unsloth import FastModel
import os
import zipfile
import json
import torch
from datasets import Dataset 
from transformers import TrainingArguments 
from transformers import AutoProcessor, AutoModelForImageTextToText 
from PIL import Image
import re
import requests
from huggingface_hub import login

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


2025-07-12 02:50:19.776341: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1752288619.962496      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1752288620.014335      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


🦥 Unsloth Zoo will now patch everything to make training faster!


<a id="3-configuration"></a>

# 3. Configuration


In [None]:
model_id = "unsloth/gemma-3n-e2b-it-unsloth-bnb-4bit"
output_dir = "/kaggle/working/gemma-qlora-finetuned-cddm"

hub_model_id = "shreyansh24/gemma3n-e2b-cddm-finetune"
hub_private_repo = False

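# Prefer an environment variable (or Kaggle Secrets) over hard-coding the token
hf_token = os.environ.get("HF_TOKEN", "YOUR_HF_TOKEN")
login(token=hf_token)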

<a id="4-data-preparation"></a>

# 4. Data Preparation

In [4]:
KAGGLE_INPUT_DIR = "/kaggle/input/crop-disease-data/"
dataset_json_path = os.path.join(KAGGLE_INPUT_DIR, "dataset/Crop_Disease_train_qwenvl.json")
system_message = "You are a helpful AI assistant specialized in crop disease diagnosis. Provide concise and accurate information."

In [5]:
# Helper function to extract PIL Images from the TRL-formatted messages
# This is used by the `collate_fn`.
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    image_inputs = []
    for msg in messages:
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]

        for element in content:
            if isinstance(element, dict) and element.get("type") == "image":
                if "image" in element and isinstance(element["image"], Image.Image): # Check if it's actually a PIL Image object
                    image_inputs.append(element["image"].convert("RGB"))
    return image_inputs

In [6]:
# Function to transform CDDM format to TRL multimodal chat format (Gemma style)
def format_data_for_trl(sample):
    trl_messages = []

    # Prepend the system message to every conversation
    trl_messages.append({
        "role": "system",
        "content": [{"type": "text", "text": system_message}]
    })

    # Process each turn in the conversation
    for turn in sample['conversations']:
        # Map Qwen-VL roles to chat roles; Gemma's chat template renders "assistant" as "model"
        current_role = "user" if turn["from"] == "user" else "assistant"
        content_value = turn["value"]

        content_parts = []

        img_pattern = r'<img>(.*?)</img>'
        img_match = re.search(img_pattern, content_value)

        # Handle the case where an image is present in the user's turn
        if img_match and current_role == "user": # An image should only appear in a user turn
            image_path_in_text = img_match.group(1) # Path inside <img> tag (e.g., "/dataset/images/Rice,Blast/plant_121175.jpg")

            # Build the absolute image path under the Kaggle input directory.
            # Strip the leading "/" first: os.path.join discards every earlier component
            # when a later one is absolute, which would drop KAGGLE_INPUT_DIR.
            cleaned_image_path_in_text = image_path_in_text.lstrip('/')
            # Join the Kaggle input root with the cleaned relative path
            corrected_image_path = os.path.join(KAGGLE_INPUT_DIR, cleaned_image_path_in_text)


            loaded_image = None
            # Use the correctly constructed path to check if the file exists
            if os.path.exists(corrected_image_path):
                try:
                    # Try opening the image file
                    loaded_image = Image.open(corrected_image_path).convert("RGB")
                    content_parts.append({"type": "image", "image": loaded_image}) # Store the PIL Image object
                except Exception as e:
                    # This warning indicates the file exists but couldn't be opened as an image
                    print(f"Warning: Could not load image {corrected_image_path}: {e}. Adding path as text for sample ID: {sample.get('id', 'N/A')}")
                    content_parts.append({"type": "text", "text": f"Error loading image: {corrected_image_path}"})
            else:
                # This warning indicates the file was not found at the expected path
                print(f"Warning: Image file not found at {corrected_image_path}. Adding path as text for sample ID: {sample.get('id', 'N/A')}")
                content_parts.append({"type": "text", "text": f"Image not found: {corrected_image_path}"})

            # Keep the prompt text: drop the <img> tag and the leading "Picture N:" prefix
            text_part = re.sub(img_pattern, "", content_value)
            text_part = re.sub(r'^(Picture\s+\d+:\s*)', '', text_part).strip()

            if text_part:
                content_parts.append({"type": "text", "text": text_part})
            # Defensive fallback: never emit a user turn with empty content
            elif not content_parts and loaded_image is None:
                content_parts.append({"type": "text", "text": ""})


        else:
            # All other turns are text-only
            cleaned_text = turn["value"].strip()
            if cleaned_text:
                content_parts.append({"type": "text", "text": cleaned_text})
            elif current_role == "assistant":
                content_parts.append({"type": "text", "text": ""})


        if content_parts:
            trl_messages.append({"role": current_role, "content": content_parts})

    # Defensive fallback: the list always holds at least the system message
    if not trl_messages:
        return {"messages": [{"role": "system", "content": [{"type": "text", "text": system_message}]}]}

    return {"messages": trl_messages}
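The image-path handling above is easy to get wrong, so here is a minimal standard-library sketch of the same steps that can be checked without the dataset: pull the path out of the `<img>` tag, strip the leading `/`, join it onto the Kaggle input root, then remove the tag and the `Picture N:` prefix from the prompt. The helper name `split_image_turn` and the sample turn are hypothetical; the turn only mimics the path format noted in the code comments. The printed path assumes a POSIX system (Kaggle runs Linux).

```python
import os
import re
from typing import Optional, Tuple

IMG_PATTERN = r"<img>(.*?)</img>"
PICTURE_PREFIX = r"^(Picture\s+\d+:\s*)"

def split_image_turn(value: str, input_root: str) -> Tuple[Optional[str], str]:
    """Return (absolute image path or None, remaining prompt text) for one turn."""
    match = re.search(IMG_PATTERN, value)
    image_path = None
    if match:
        # Paths are stored as "/dataset/images/...": strip the leading "/" because
        # os.path.join drops input_root when a later component is absolute.
        image_path = os.path.join(input_root, match.group(1).lstrip("/"))
    # Remove the tag, then the "Picture N:" prefix (the \s* also eats the newline)
    text = re.sub(IMG_PATTERN, "", value)
    text = re.sub(PICTURE_PREFIX, "", text).strip()
    return image_path, text

# Hypothetical user turn in the Qwen-VL style of the CDDM JSON
turn = "Picture 1: <img>/dataset/images/Rice,Blast/plant_121175.jpg</img>\nWhat disease does this leaf have?"
path, prompt = split_image_turn(turn, "/kaggle/input/crop-disease-data/")
print(path)    # /kaggle/input/crop-disease-data/dataset/images/Rice,Blast/plant_121175.jpg
print(prompt)  # What disease does this leaf have?
```

Text-only turns pass through unchanged with `None` as the path, which matches how the notebook treats follow-up questions.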

In [7]:
# Load the dataset from the JSON file
with open(dataset_json_path, 'r') as f:
    data = json.load(f)

In [None]:
# Train on a 10,000-conversation slice of the dataset (indices 66,000 to 75,999)
start_subset = 66000
end_subset = 76000

train_dataset = [format_data_for_trl(sample) for sample in data[start_subset:end_subset]]
train_dataset[0]

{'messages': [{'role': 'system',
   'content': [{'type': 'text',
     'text': 'You are a helpful AI assistant specialized in crop disease diagnosis. Provide concise and accurate information.'}]},
  {'role': 'user',
   'content': [{'type': 'image',
     'image': <PIL.Image.Image image mode=RGB size=256x256>},
    {'type': 'text', 'text': 'What are the features of this picture?'}]},
  {'role': 'assistant',
   'content': [{'type': 'text',
     'text': 'This image shows a tomato leaf that exhibits symptoms of Yellow Leaf Curl Virus, characterized by yellowing and curling of the leaves.'}]},
  {'role': 'user',
   'content': [{'type': 'text', 'text': "What plant's leaf is this?"}]},
  {'role': 'assistant',
   'content': [{'type': 'text', 'text': 'This is a tomato leaf.'}]},
  {'role': 'user',
   'content': [{'type': 'text', 'text': 'Is this crop diseased?'}]},
  {'role': 'assistant',
   'content': [{'type': 'text',
     'text': 'Yes, this tomato leaf is afflicted with Yellow Leaf Curl Virus.

In [9]:
# Find a sample with an image in a user turn and print its content
print("\nExample of image loading for a user turn:")
image_sample_found = False
for sample in train_dataset:
    for msg in sample['messages']:
        if msg['role'] == 'user':
            for content_elem in msg['content']:
                if content_elem['type'] == 'image':
                    print(f"  Image element found. PIL Image object: {content_elem['image']}")
                    image_sample_found = True
                    break
        if image_sample_found:
            break
    if image_sample_found:
        break
if not image_sample_found:
    print("No image found in any user turn of train_dataset.")


Example of image loading for a user turn:
  Image element found. PIL Image object: <PIL.Image.Image image mode=RGB size=256x256 at 0x7EA0B6C9DB10>


<a id="5-model-loading-and-qlora-configuration"></a>

# 5. Model Loading and QLoRA Configuration

In [10]:
# --- Unsloth Model Loading ---
# Unsloth simplifies model loading and quantization.
# It automatically sets up 4-bit quantization (QLoRA) and prepares the model for training.
model, processor = FastModel.from_pretrained(
    model_name = model_id,
    max_seq_length = 2048,
    dtype = torch.float16,  # fp16 compute: the Tesla T4 has no bfloat16 support
    load_in_4bit = True,    # 4-bit quantized base weights (QLoRA)
    # processor_class = AutoProcessor, # Unsloth can infer this for common models
    # model_class = AutoModelForImageTextToText, # Unsloth can infer this for common models
    trust_remote_code=True
)

Are you certain you want to do remote code execution?
==((====))==  Unsloth 2025.7.3: Fast Gemma3N patching. Transformers: 4.53.0.
   \\   /|    Tesla T4. Num GPUs = 2. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Gemma3N does not support SDPA - switching to eager!



In [11]:
# Unsloth handles modules_to_save internally for `lm_head` and `embed_tokens` for multimodal models.
model = FastModel.get_peft_model(
    model,
    # finetune_vision_layers     = True,  # Turn off for just text!
    # finetune_language_layers   = True,  # Should leave on!
    # finetune_attention_modules = True,  # Attention good for GRPO
    # finetune_mlp_modules       = True,  # Should leave on always!
    target_modules = "all-linear",
    r = 16,
    lora_alpha = 32,
    bias = "none",
    use_gradient_checkpointing = True,
    random_state = 3407,
    max_seq_length = 2048,
)

Unsloth: Making `model.base_model.model.model.language_model` require gradients


In [12]:
model.print_trainable_parameters()  

trainable params: 21,135,360 || all params: 5,460,573,632 || trainable%: 0.3871
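LoRA adds two low-rank matrices to each adapted linear layer, A with shape (r, d_in) and B with shape (d_out, r), so each layer contributes r × (d_in + d_out) trainable parameters, and in standard LoRA the update B·A is scaled by `lora_alpha / r` (here 32 / 16 = 2). The sketch below only does that arithmetic. The layer shapes are hypothetical and are not Gemma 3n's real dimensions, so it does not reproduce the 21,135,360 total; the last line just re-derives the percentage printed above.

```python
def lora_params(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters LoRA adds to one bias-free d_in -> d_out linear layer."""
    # A has shape (r, d_in), B has shape (d_out, r)
    return r * d_in + d_out * r

def lora_scaling(lora_alpha: float, r: int) -> float:
    """Standard LoRA scaling applied to the B @ A update (not rsLoRA)."""
    return lora_alpha / r

r, alpha = 16, 32  # values passed to get_peft_model above

# Hypothetical layer shapes, chosen only to illustrate the arithmetic
for name, (d_in, d_out) in {"square_proj": (2048, 2048), "wide_proj": (2048, 8192)}.items():
    print(name, lora_params(d_in, d_out, r))  # 65536, then 163840
print("update scaling:", lora_scaling(alpha, r))  # 2.0

# Percentage reported by print_trainable_parameters() above
print(f"trainable%: {100 * 21_135_360 / 5_460_573_632:.4f}")  # 0.3871
```

Doubling `r` doubles the adapter size, while keeping `lora_alpha / r` fixed keeps the update scale unchanged.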


<a id="6-training"></a>

# 6. Training

In [1]:
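# Optional: after a session restart, restore the "last-checkpoint" folder that
# hub_strategy="checkpoint" pushes to the Hub, so that
# trainer.train(resume_from_checkpoint=True) can resume from it.
# from huggingface_hub import snapshot_download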
# import os

# base = snapshot_download(
#     repo_id=hub_model_id,               # e.g. "sdvsvs/gemma3n-e2b-cddm-finetune"
#     allow_patterns=["last-checkpoint/*"], 
#     token=hf_token
# )
# src = os.path.join(base, "last-checkpoint")
# dst = "/kaggle/working/gemma-qlora-finetuned-cddm/checkpoint-5500"

# import shutil
# if os.path.isdir(dst):
#     shutil.rmtree(dst)
# shutil.copytree(src, dst)
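With `resume_from_checkpoint=True`, the Trainer looks up the most recent `checkpoint-<step>` folder in `output_dir` (transformers exposes this as `transformers.trainer_utils.get_last_checkpoint`), which is why the restore cell copies the Hub's `last-checkpoint` into a `checkpoint-<step>` directory. The standard-library sketch below approximates that lookup so you can check which folder will be picked before a long run; `latest_checkpoint` is a hypothetical helper, not the library's exact implementation.

```python
import os
import re
import tempfile
from typing import Optional

CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")

def latest_checkpoint(output_dir: str) -> Optional[str]:
    """Return the checkpoint-<step> subfolder with the highest step, or None."""
    candidates = []
    for name in os.listdir(output_dir):
        match = CHECKPOINT_RE.match(name)
        if match and os.path.isdir(os.path.join(output_dir, name)):
            candidates.append((int(match.group(1)), name))
    if not candidates:
        return None
    # Compare steps numerically: as strings, "checkpoint-750" would beat "checkpoint-5500"
    return os.path.join(output_dir, max(candidates)[1])

# Demo on a throwaway directory standing in for output_dir
with tempfile.TemporaryDirectory() as tmp:
    for step in (500, 5500, 750):
        os.makedirs(os.path.join(tmp, f"checkpoint-{step}"))
    os.makedirs(os.path.join(tmp, "runs"))  # e.g. TensorBoard logs, ignored
    print(os.path.basename(latest_checkpoint(tmp)))  # checkpoint-5500
```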

In [13]:
# Create a data collator to encode text and image pairs
def collate_fn(examples):
    # 1. Extract all texts and all images from the batch of examples.
    all_texts = []
    all_images = []
    for ex in examples:
        # Get the list of PIL images for the current example.
        images = process_vision_info(ex["messages"])
        all_images.append(images)

        # Apply the chat template to convert the messages list into a single string.
        text = processor.apply_chat_template(
            ex["messages"], add_generation_prompt=False, tokenize=False
        ).strip()
        all_texts.append(text)

    # 2. Call the processor once on the entire batch so tokenization, padding,
    #    and image preprocessing are handled consistently across all examples.
    batch = processor(
        text=all_texts,
        images=all_images,
        return_tensors="pt",
        padding=True
    )

    # 3. Create the labels tensor for calculating loss.
    labels = batch["input_ids"].clone()

    # Token IDs marking the start of a model turn and the end of any turn.
    # Gemma closes every turn with <end_of_turn>; the tokenizer's EOS token may be a
    # different token (<eos>), so look the end-of-turn ID up explicitly.
    model_prompt_start_tokens = processor.tokenizer.encode("<start_of_turn>model\n", add_special_tokens=False)
    end_of_turn_id = processor.tokenizer.convert_tokens_to_ids("<end_of_turn>")

    # Iterate through each sequence in the batch to apply the loss mask.
    for i in range(len(labels)):
        # By default, we want to ignore all tokens when calculating the loss.
        ignore_mask = torch.ones_like(labels[i], dtype=torch.bool)
        
        current_sequence_list = labels[i].tolist()

        # Find all occurrences of the model's turn start sequence.
        for j in range(len(current_sequence_list)):
            if current_sequence_list[j : j + len(model_prompt_start_tokens)] == model_prompt_start_tokens:
                # We've found the start of an assistant's response.
                response_start_index = j + len(model_prompt_start_tokens)
                
                # Now, find the end of this response.
                try:
                    # Search for the next end-of-turn token *after* the response starts.
                    response_end_index = current_sequence_list.index(end_of_turn_id, response_start_index)
                except ValueError:
                    # No end-of-turn token after this point (e.g. truncated sequence): unmask to the end.
                    response_end_index = len(current_sequence_list) - 1
                
                # Unmask the region corresponding to the assistant's response.
                # We calculate loss from the start of the response content to the end token (inclusive).
                ignore_mask[response_start_index : response_end_index + 1] = False

        # Apply the final mask. All tokens marked True in ignore_mask will be set to -100.
        labels[i][ignore_mask] = -100

    # Add the correctly masked labels to our batch dictionary.
    batch["labels"] = labels
    
    return batch
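The loss-masking rule in `collate_fn` can be checked without loading the model. This pure-Python sketch applies the same rule to a toy list of token IDs: every label is `-100` except the span after each model-turn marker, up to and including the next end-of-turn token. The integer IDs and the helper `mask_non_assistant` are made up for illustration; in the notebook the real IDs come from `processor.tokenizer`.

```python
from typing import List

IGNORE_INDEX = -100  # label value ignored by the cross-entropy loss in transformers

def mask_non_assistant(ids: List[int], start_marker: List[int], end_id: int) -> List[int]:
    """Keep labels only for assistant responses (marker excluded, end token included)."""
    labels = [IGNORE_INDEX] * len(ids)
    m = len(start_marker)
    for j in range(len(ids) - m + 1):
        if ids[j:j + m] == start_marker:
            start = j + m
            try:
                end = ids.index(end_id, start)
            except ValueError:
                end = len(ids) - 1  # no end-of-turn token: keep labels to the end
            labels[start:end + 1] = ids[start:end + 1]
    return labels

# Hypothetical IDs: [7, 7] = "<start_of_turn>model\n", 9 = "<end_of_turn>", 5 = user-turn start
# user: 5 1 2 9 | model: 7 7 3 4 9 | user: 5 6 9 | model: 7 7 8 9
ids = [5, 1, 2, 9, 7, 7, 3, 4, 9, 5, 6, 9, 7, 7, 8, 9]
print(mask_non_assistant(ids, [7, 7], 9))

# With an end ID that never occurs, later user turns leak into the loss
print(mask_non_assistant(ids, [7, 7], 42))
```

The second call shows why the end-of-turn ID has to match the token the chat template actually emits: otherwise everything after the first model turn is trained on.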

In [14]:
# Set up the training arguments with transformers.TrainingArguments
# (used instead of trl.SFTConfig; Unsloth patches the standard Trainer).
args = TrainingArguments(
    output_dir=output_dir,                # local directory for checkpoints and logs
    num_train_epochs=1,                   # ignored because max_steps is set
    per_device_train_batch_size=1,        # batch size per device during training
    gradient_accumulation_steps=8,        # accumulate gradients over 8 batches per optimizer step
    max_steps=8750,                       # total optimizer steps; overrides num_train_epochs
    gradient_checkpointing=True,          # trade extra compute for lower memory
    optim="adamw_8bit",                   # 8-bit AdamW (bitsandbytes) to save optimizer memory
    logging_steps=25,                     # log the training loss every 25 steps
    save_strategy="steps",                # save checkpoints every `save_steps`
    save_steps=250,                       # steps between checkpoints
    save_total_limit=2,                   # keep at most two checkpoints in output_dir
    save_only_model=False,                # also save optimizer/scheduler/RNG state so runs can resume
    learning_rate=2e-4,                   # learning rate based on the QLoRA paper
    fp16=True,                            # fp16 mixed precision (the T4 has no bf16 support)
    bf16=False,
    max_grad_norm=0.3,                    # gradient clipping, based on the QLoRA paper
    warmup_ratio=0.03,                    # note: ignored by the "constant" scheduler below
    lr_scheduler_type="constant",         # use "constant_with_warmup" to apply warmup_ratio
    push_to_hub=True,                     # push checkpoints to the Hugging Face Hub
    hub_model_id=hub_model_id,            # Hub repository name
    hub_strategy="checkpoint",            # also push the latest checkpoint to "last-checkpoint/"
    hub_token=hf_token,
    hub_private_repo=hub_private_repo,    # True for a private repo
    report_to="tensorboard",              # log metrics to TensorBoard
    gradient_checkpointing_kwargs={"use_reentrant": False},  # non-reentrant checkpointing
    remove_unused_columns=False,          # keep the raw "messages" field for collate_fn
)

# Initialize the Trainer (from transformers)
from transformers import Trainer # Explicitly import Trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset, 
    data_collator=collate_fn,  
)
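Because `max_steps` is set, it overrides `num_train_epochs`. The sketch below works out what these arguments imply, assuming a single training process on one GPU and a run starting from step 0 on the 10,000-conversation subset: each optimizer step consumes `per_device_train_batch_size × gradient_accumulation_steps` conversations. It also mirrors how transformers turns `warmup_ratio` into a step count (`ceil(max_steps × warmup_ratio)`) and shows that the plain `"constant"` schedule ignores it while `"constant_with_warmup"` ramps up linearly. The two schedule functions are simplified re-implementations for illustration, not the library code.

```python
import math

num_examples = 76_000 - 66_000                       # size of data[start_subset:end_subset]
per_device_batch, grad_accum, num_devices = 1, 8, 1  # assumes one GPU / one process
max_steps, warmup_ratio, base_lr = 8750, 0.03, 2e-4

effective_batch = per_device_batch * grad_accum * num_devices
steps_per_epoch = math.ceil(num_examples / effective_batch)
passes_over_subset = max_steps / steps_per_epoch
warmup_steps = math.ceil(max_steps * warmup_ratio)

def lr_constant(step: int) -> float:
    """'constant': the learning rate never changes, so warmup_ratio has no effect."""
    return base_lr

def lr_constant_with_warmup(step: int) -> float:
    """'constant_with_warmup': linear ramp from 0 over warmup_steps, then flat."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr

print(effective_batch, steps_per_epoch, passes_over_subset, warmup_steps)
print(lr_constant(0), lr_constant_with_warmup(0), lr_constant_with_warmup(warmup_steps))
```

Under these assumptions, 8,750 steps correspond to several passes over the same 10,000 conversations, which is worth keeping in mind when reading the low training loss.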

In [18]:
#@title Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

GPU = Tesla T4. Max memory = 14.741 GB.
10.988 GB of memory reserved.


In [15]:
# Start the fine-tuning process
print("\nStarting QLoRA fine-tuning...")
trainer_stats = trainer.train(resume_from_checkpoint=True)


Starting QLoRA fine-tuning...


	per_device_train_batch_size: 1 (from args) != 0 (from trainer_state.json)
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


| Step | Training Loss |
|-----:|--------------:|
| 7525 | 0.0581 |
| 7550 | 0.1107 |
| 7575 | 0.085 |
| 7600 | 0.204 |
| 7625 | 0.1066 |
| 7650 | 0.139 |
| 7675 | 0.0398 |
| 7700 | 0.0725 |
| 7725 | 0.1074 |
| 7750 | 0.092 |
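The loss is logged every 25 steps and the individual values are noisy (roughly 0.04 to 0.20 in the ten logged values above), so a trailing mean is easier to read than single points. A small sketch that parses those ten (step, loss) pairs, copied verbatim, and smooths them; the helper `trailing_mean` and the window size of 4 are arbitrary choices for illustration.

```python
import csv
import io
from statistics import mean

# The ten (step, loss) pairs logged above
LOG = """Step,Training Loss
7525,0.0581
7550,0.1107
7575,0.085
7600,0.204
7625,0.1066
7650,0.139
7675,0.0398
7700,0.0725
7725,0.1074
7750,0.092"""

rows = [(int(r["Step"]), float(r["Training Loss"])) for r in csv.DictReader(io.StringIO(LOG))]
losses = [loss for _, loss in rows]

def trailing_mean(values, window):
    """Mean of the last `window` values at each position (shorter window at the start)."""
    return [mean(values[max(0, i - window + 1): i + 1]) for i in range(len(values))]

print(f"steps {rows[0][0]}-{rows[-1][0]}: mean loss {mean(losses):.4f}")
print([round(v, 4) for v in trailing_mean(losses, window=4)])
```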



<a id="7-saving-model"></a>
# 7. Save the Final Model

In [None]:
final_folder = "/kaggle/working/full_checkpoint_22k_sample"
os.makedirs(final_folder, exist_ok=True)

# 7a) Save the LoRA adapter (weights + config) and the processor (tokenizer + image processor)
trainer.save_model(final_folder)
processor.save_pretrained(final_folder)
# Note: save_state() writes trainer_state.json to args.output_dir, not to final_folder
trainer.save_state()

# 7b) Optionally save optimizer & scheduler state dicts
torch.save(trainer.optimizer.state_dict(), os.path.join(final_folder, "optimizer_state.pt"))
torch.save(trainer.lr_scheduler.state_dict(), os.path.join(final_folder, "scheduler_state.pt"))
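Before uploading, it can help to confirm the folder actually contains an adapter. For a PEFT/LoRA model, `save_pretrained` normally writes `adapter_config.json` plus `adapter_model.safetensors` (or `adapter_model.bin` when safetensors serialization is off); the check below assumes those default filenames and only inspects the directory, it does not load the model. The helpers `missing_adapter_files` and `adapter_rank` are hypothetical, and the demo uses a temporary directory in place of `final_folder`.

```python
import json
import os
import tempfile
from typing import List

def missing_adapter_files(folder: str) -> List[str]:
    """Return which expected LoRA adapter files (assumed PEFT defaults) are absent."""
    missing = []
    if not os.path.isfile(os.path.join(folder, "adapter_config.json")):
        missing.append("adapter_config.json")
    weights = ("adapter_model.safetensors", "adapter_model.bin")
    if not any(os.path.isfile(os.path.join(folder, w)) for w in weights):
        missing.append("adapter_model.safetensors|adapter_model.bin")
    return missing

def adapter_rank(folder: str) -> int:
    """Read the LoRA rank `r` recorded in adapter_config.json."""
    with open(os.path.join(folder, "adapter_config.json")) as f:
        return json.load(f)["r"]

# Demo with a stand-in folder; in the notebook, pass final_folder instead
with tempfile.TemporaryDirectory() as tmp:
    print(missing_adapter_files(tmp))  # both entries missing
    with open(os.path.join(tmp, "adapter_config.json"), "w") as f:
        json.dump({"r": 16, "lora_alpha": 32}, f)
    open(os.path.join(tmp, "adapter_model.safetensors"), "wb").close()
    print(missing_adapter_files(tmp), adapter_rank(tmp))  # [] 16
```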

In [None]:
from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    folder_path=final_folder,
    repo_id=hub_model_id,
    repo_type="model",
    path_in_repo=os.path.basename(final_folder),  # "full_checkpoint_22k_sample"
)


<a id="8-inference"></a>
# 8. Inference


In [None]:
import torch
from transformers import TextStreamer
from unsloth import FastVisionModel

# 1️⃣ Enable fast vision inference
FastVisionModel.for_inference(model)

# 2️⃣ Prepare a single example
# process_vision_info returns a list of PIL images; take the first one
img = process_vision_info(train_dataset[0]["messages"])[0]
# Build the chat messages, reusing the system prompt from training
messages = [
    {"role": "system", "content": [{"type": "text", "text": system_message}]},
    {"role": "user", "content": [
        {"type": "image", "image": img},
        {"type": "text",  "text": "What disease does this leaf have?"}
    ]}
]

# 3️⃣ Process inputs
inputs = processor(
    text=processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False),
    images=[img],
    return_tensors="pt",
).to(model.device)

# 4️⃣ Generate output (skip_prompt=True streams only the newly generated text)
streamer = TextStreamer(processor.tokenizer, skip_prompt=True)
_ = model.generate(
    **inputs,
    streamer=streamer,
    max_new_tokens=128,
    do_sample=True,
    temperature=1.0,
    top_k=64,
    top_p=0.95,
)
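The `generate` call samples with `temperature=1.0`, `top_k=64` and `top_p=0.95`. Conceptually, the logits are divided by the temperature, only the `top_k` most likely tokens are kept, and that set is cut further to the smallest prefix whose cumulative probability reaches `top_p` before one token is drawn. Below is a simplified pure-Python sketch of that filtering over a toy five-token distribution; `filter_top_k_top_p` and the toy logits are hypothetical, and transformers applies equivalent logits processors on tensors, so edge-case details (ties, minimum tokens kept) may differ.

```python
import math
import random
from typing import Dict

def filter_top_k_top_p(logits: Dict[str, float], temperature: float,
                       top_k: int, top_p: float) -> Dict[str, float]:
    """Return the renormalized distribution left after temperature, top-k and top-p."""
    scaled = {tok: value / temperature for tok, value in logits.items()}
    # top-k: keep the k highest-scoring tokens
    kept = sorted(scaled.items(), key=lambda kv: kv[1], reverse=True)[:top_k]
    # softmax over the survivors (subtract the max for numerical stability)
    peak = kept[0][1]
    weights = [(tok, math.exp(v - peak)) for tok, v in kept]
    total = sum(w for _, w in weights)
    probs = [(tok, w / total) for tok, w in weights]
    # top-p: smallest prefix whose cumulative probability reaches top_p
    nucleus, cumulative = [], 0.0
    for tok, p in probs:
        nucleus.append((tok, p))
        cumulative += p
        if cumulative >= top_p:
            break
    norm = sum(p for _, p in nucleus)
    return {tok: p / norm for tok, p in nucleus}

toy_logits = {"blast": 4.0, "brown spot": 2.5, "healthy": 1.0, "rust": -1.0, "mosaic": -3.0}
dist = filter_top_k_top_p(toy_logits, temperature=1.0, top_k=3, top_p=0.95)
print({tok: round(p, 3) for tok, p in dist.items()})
print(random.choices(list(dist), weights=list(dist.values()))[0])
```

With these toy numbers, top-k keeps three candidates but top-p drops "healthy", because "blast" and "brown spot" already cover more than 95% of the probability mass.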


In [None]:
train_dataset[0]

In [None]:
import matplotlib.pyplot as plt

# process_vision_info returns a list of PIL images; show the first one
img = process_vision_info(train_dataset[0]["messages"])[0]
plt.imshow(img)
plt.axis("off")
plt.show()