In [None]:

# Installation Cell - Run this FIRST
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --no-deps xformers<0.0.27 trl<0.9.0 peft accelerate bitsandbytes
!pip install torch torchvision pillow datasets

In [None]:
# Provides process_vision_info, which the inference step of this notebook imports.
!pip install qwen-vl-utils

In [None]:
# =======================
# Load dataset from Drive
# =======================
import json

from datasets import Dataset
from google.colab import drive

drive.mount('/content/drive')

# Update these paths to match your Drive structure
DATASET_JSON = "/content/drive/MyDrive/PlantVillageVQA_subset_1000/PlantVillageVQA_1000.json"
IMAGE_ROOT   = "/content/drive/MyDrive/PlantVillageVQA_subset_1000/Images"   # folder where images/ exist

# The annotation file is a dict with one entry per image
with open(DATASET_JSON, "r") as f:
    raw_data = json.load(f)

# Flatten to one record per (image, question, answer) triple
expanded = []
for img_key, entry in raw_data.items():
    image_path = entry["image_path"]  # e.g., "images/train/xxx.jpg"
    expanded.extend(
        {"image": image_path, "question": qa["question"], "answer": qa["answer"]}
        for qa in entry["questions"]
    )

print("Total QA pairs:", len(expanded))
print(expanded[0])

dataset = Dataset.from_list(expanded)
print(dataset)



In [None]:
# ------------------------------------------------------------
# STEP 1: Mount Google Drive and locate the dataset
# ------------------------------------------------------------
from google.colab import drive
drive.mount('/content/drive')

import os
import json
from datasets import Dataset
from PIL import Image

BASE_PATH = "/content/drive/MyDrive/PlantVillageVQA_subset_1000/"  # root folder
JSON_PATH = f"{BASE_PATH}PlantVillageVQA_1000.json"

print("Loading JSON:", JSON_PATH)

# ------------------------------------------------------------
# STEP 2: Read the raw JSON (a dict with one entry per image)
# ------------------------------------------------------------
with open(JSON_PATH, "r") as f:
    raw_data = json.load(f)

print("Top-level entries:", len(raw_data))


# ------------------------------------------------------------
# STEP 3: Flatten into a list of {image, question, answer} records
# ------------------------------------------------------------
expanded = []

for img_key, entry in raw_data.items():
    image_path = entry["image_path"]                     # e.g., images/train/image_0001.JPG
    expanded.extend(
        {
            "image": image_path,
            "question": qa["question"],
            "answer": qa["answer"],
        }
        for qa in entry["questions"]
    )

print("Total expanded QA pairs:", len(expanded))


# ------------------------------------------------------------
# STEP 4: Cache decoded images so each file is opened only once
# ------------------------------------------------------------
image_cache = {}

def load_image_once(rel_path):
    """Return the RGB image for a JSON-relative path, decoding it at most once.

    Args:
        rel_path: Image path as stored in the JSON, e.g. "images/train/xxx.jpg".

    Returns:
        The image converted to RGB. Repeated calls with the same ``rel_path``
        return the same cached object.

    Raises:
        FileNotFoundError: If the file does not exist as given, nor with its
            extension lower-cased or upper-cased.
    """
    if rel_path in image_cache:
        return image_cache[rel_path]

    # Fix path: JSON uses "images/<split>/...", actual folder is "Images/train/..."
    corrected_path = rel_path
    for split in ("train", "test", "val"):
        corrected_path = corrected_path.replace("images/" + split, "Images/train")

    # join() works whether or not BASE_PATH ends with a slash; the lstrip keeps
    # a leading "/" in the JSON path from discarding BASE_PATH.
    full_path = os.path.join(BASE_PATH, corrected_path.lstrip("/"))

    # Fix .jpg vs .JPG mismatch by changing the case of the extension only,
    # so a ".jpg"/".JPG" substring elsewhere in the path is never rewritten.
    if not os.path.exists(full_path):
        stem, ext = os.path.splitext(full_path)
        existing = [
            candidate
            for candidate in (stem + ext.lower(), stem + ext.upper())
            if os.path.exists(candidate)
        ]
        if not existing:
            raise FileNotFoundError(f"Cannot find image: {full_path}")
        full_path = existing[0]

    # Load once → store in cache; the context manager closes the file handle
    # as soon as the RGB copy has been made.
    with Image.open(full_path) as source:
        img = source.convert("RGB")
    image_cache[rel_path] = img
    return img


# ------------------------------------------------------------
# STEP 5: Wrap each sample in the Qwen2-VL "messages" structure
# ------------------------------------------------------------
def convert_to_messages(sample):
    """Build a two-turn chat record from one flattened QA sample.

    The user turn carries the question text followed by the image; the
    assistant turn carries the answer coerced to a string.
    """
    image = load_image_once(sample["image"])

    user_turn = {
        "role": "user",
        "content": [
            {"type": "text",  "text": sample["question"]},
            {"type": "image", "image": image},
        ],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [
            {"type": "text",  "text": str(sample["answer"])},
        ],
    }
    return {"messages": [user_turn, assistant_turn]}


# ------------------------------------------------------------
# STEP 6: Convert every expanded QA record
# ------------------------------------------------------------
converted_dataset = list(map(convert_to_messages, expanded))

print("Converted dataset size:", len(converted_dataset))
print("Example:", converted_dataset[0])


In [None]:
"""
FULL SCRIPT: Qwen2-VL 2B Fine-tuning on PlantVillageVQA_subset_1000
Optimized for Google Colab T4 using Unsloth + QLoRA
"""

# =====================================================================
# STEP 1: Install packages (RUN ONE TIME, THEN RESTART RUNTIME)
# =====================================================================
# !pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
# !pip install --no-deps "xformers<0.0.27" "trl<0.9.0" peft accelerate bitsandbytes
# !pip install torch torchvision pillow datasets qwen-vl-utils

# After installing: Runtime -> Restart runtime


# =====================================================================
# STEP 2: Imports
# =====================================================================
import torch
from unsloth import FastVisionModel, is_bfloat16_supported
from unsloth.trainer import UnslothVisionDataCollator

from PIL import Image
import json, os
from datasets import Dataset
from trl import SFTConfig, SFTTrainer


# =====================================================================
# STEP 3: Load Qwen2-VL in 4-bit and attach LoRA adapters
# =====================================================================
model, tokenizer = FastVisionModel.from_pretrained(
    "unsloth/Qwen2-VL-2B-Instruct",
    load_in_4bit=True,
    use_gradient_checkpointing="unsloth",
)

# LoRA settings: adapters on vision and language layers, attention and MLP modules
lora_settings = dict(
    finetune_vision_layers=True,
    finetune_language_layers=True,
    finetune_attention_modules=True,
    finetune_mlp_modules=True,
    r=16,
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    random_state=3407,
)
model = FastVisionModel.get_peft_model(model, **lora_settings)

# Switch the model into training mode before building the trainer
FastVisionModel.for_training(model)


# =====================================================================
# STEP 4: LOAD PLANTVILLAGE DATASET
# =====================================================================
from google.colab import drive
drive.mount('/content/drive')

BASE_PATH = "/content/drive/MyDrive/PlantVillageVQA_subset_1000/"
JSON_PATH = f"{BASE_PATH}PlantVillageVQA_1000.json"

print("Loading JSON:", JSON_PATH)

with open(JSON_PATH, "r") as f:
    raw_data = json.load(f)

print("Top-level entries:", len(raw_data))

# Flatten to one record per (image, question, answer) triple
expanded = []
for img_key, entry in raw_data.items():
    image_path = entry["image_path"]
    expanded.extend(
        {
            "image": image_path,
            "question": qa["question"],
            "answer": qa["answer"],
        }
        for qa in entry["questions"]
    )

print("Total expanded QA pairs:", len(expanded))


# =====================================================================
# STEP 5: CONVERT TO QWEN2-VL FORMAT (Optimized)
# =====================================================================

# ----------- image cache: each file is decoded only once ------------ #
image_cache = {}

def load_image_once(rel_path):
    """Return the RGB image for a JSON-relative path, decoding it at most once.

    Args:
        rel_path: Image path as stored in the JSON, e.g. "images/train/xxx.jpg".

    Returns:
        The image converted to RGB. Repeated calls with the same ``rel_path``
        return the same cached object.

    Raises:
        FileNotFoundError: If the file does not exist as given, nor with its
            extension lower-cased or upper-cased.
    """
    if rel_path in image_cache:
        return image_cache[rel_path]

    # The JSON uses "images/<split>/..." while the files live under "Images/train/"
    corrected = rel_path
    for split in ("train", "test", "val"):
        corrected = corrected.replace("images/" + split, "Images/train")

    # join() works whether or not BASE_PATH ends with a slash; the lstrip keeps
    # a leading "/" in the JSON path from discarding BASE_PATH.
    full_path = os.path.join(BASE_PATH, corrected.lstrip("/"))

    # JPG/jpg fix: change the case of the extension only, so a ".jpg"/".JPG"
    # substring elsewhere in the path is never rewritten.
    if not os.path.exists(full_path):
        stem, ext = os.path.splitext(full_path)
        existing = [
            candidate
            for candidate in (stem + ext.lower(), stem + ext.upper())
            if os.path.exists(candidate)
        ]
        if not existing:
            raise FileNotFoundError("Missing image: " + full_path)
        full_path = existing[0]

    # The context manager closes the file handle once the RGB copy exists
    with Image.open(full_path) as source:
        img = source.convert("RGB")
    image_cache[rel_path] = img
    return img


def convert_to_conversation(sample):
    """Build a two-turn chat record from one flattened QA sample.

    The user turn carries the question text followed by the image; the
    assistant turn carries the answer coerced to a string.
    """
    image = load_image_once(sample["image"])

    user_turn = {
        "role": "user",
        "content": [
            {"type": "text", "text": sample["question"]},
            {"type": "image", "image": image},
        ],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [
            {"type": "text", "text": str(sample["answer"])},
        ],
    }
    return {"messages": [user_turn, assistant_turn]}


print("Converting PlantVillage dataset...")
# One chat-format record per expanded QA pair
converted_dataset = list(map(convert_to_conversation, expanded))

print("Converted dataset size:", len(converted_dataset))
print("Example:", converted_dataset[0])


# =====================================================================
# STEP 6: TRAINING CONFIGURATION
# =====================================================================
# Use bf16 when the GPU supports it, otherwise fall back to fp16
use_bf16 = is_bfloat16_supported()

training_args = SFTConfig(
    # --- output / logging ---
    output_dir="./qwen2vl-finetuned-plantvillage",
    report_to="none",
    logging_steps=5,

    # --- batching: 2 per device x 4 accumulation steps = effective batch of 8 ---
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    num_train_epochs=1,

    # --- optimizer ---
    learning_rate=2e-4,
    warmup_ratio=0.1,
    weight_decay=0.01,
    optim="adamw_8bit",

    # --- precision ---
    fp16=not use_bf16,
    bf16=use_bf16,

    # --- memory ---
    max_seq_length=512,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},

    # --- checkpointing ---
    save_strategy="steps",
    save_steps=200,
    save_total_limit=2,

    seed=3407,

    # --- dataset handling: records are passed to the collator as-is ---
    remove_unused_columns=False,          # REQUIRED
    dataset_text_field="",                # unused
    dataset_kwargs={"skip_prepare_dataset": True},
)


# =====================================================================
# STEP 7: CREATE TRAINER
# =====================================================================
# The vision collator is built from the model and its tokenizer/processor
data_collator = UnslothVisionDataCollator(model, tokenizer)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=converted_dataset,
    data_collator=data_collator,
    args=training_args,
)


# =====================================================================
# STEP 8: TRAIN
# =====================================================================
banner = "=" * 60

print("\n" + banner)
print("Starting training...")
print("Samples:", len(converted_dataset))
print("GPU:", torch.cuda.get_device_name(0))
print(banner, "\n")

trainer_stats = trainer.train()

print("\nTraining completed.")
print(trainer_stats.metrics)


# =====================================================================
# STEP 9: SAVE MODEL
# =====================================================================
# Model first, then tokenizer, both into the same directory
for artifact in (model, tokenizer):
    artifact.save_pretrained("qwen2vl_lora_plantvillage")
print("✓ Saved at: qwen2vl_lora_plantvillage/")


# =====================================================================
# STEP 10: INFERENCE
# =====================================================================
from qwen_vl_utils import process_vision_info
FastVisionModel.for_inference(model)

def test_inference(image, question):
    """Ask the model a question about one image and return its reply text.

    Reads the module-level ``model`` and ``tokenizer`` at call time, so it
    uses whichever objects those names are currently bound to.

    Args:
        image: Image placed in the user turn (e.g. a PIL image).
        question: Question text placed after the image in the user turn.

    Returns:
        The decoded text of the newly generated tokens only (the prompt is
        not included). Sampling is enabled, so replies can differ between calls.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": question},
            ],
        }
    ]

    image_inputs, video_inputs = process_vision_info(messages)

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    inputs = tokenizer(
        text=[prompt],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to("cuda")

    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=128,
            temperature=0.7,
            do_sample=True,
        )

    # generate() returns the prompt tokens followed by the completion; drop the
    # prompt so the caller gets the answer rather than the whole chat transcript.
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)


In [None]:
# Copy the saved LoRA directory into Google Drive; the inference cell loads the adapter from this location.
!cp -r qwen2vl_lora_plantvillage "/content/drive/MyDrive/"


In [None]:
# 🟩 Upload image
from google.colab import files
uploaded = files.upload()
image_path = list(uploaded.keys())[0]

# 🟩 Load the image
from PIL import Image
image = Image.open(image_path).convert("RGB")

# 🟩 Load model + LoRA
from unsloth import FastVisionModel

model, tokenizer = FastVisionModel.from_pretrained(
    "unsloth/Qwen2-VL-2B-Instruct",
    load_in_4bit=True,
)
model.load_adapter("/content/drive/MyDrive/qwen2vl_lora_plantvillage")
FastVisionModel.for_inference(model)

# 🟩 Test inference
question = "What pest or disease is affecting this leaf and how can the infestation be reduced?"
answer = test_inference(image, question)

print("Answer:", answer)
