In [None]:
!pip install transformers accelerate peft bitsandbytes sentencepiece
!pip install qwen-vl-utils


Collecting bitsandbytes
  Downloading bitsandbytes-0.49.0-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Downloading bitsandbytes-0.49.0-py3-none-manylinux_2_24_x86_64.whl (59.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 MB[0m [31m13.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.49.0
Collecting qwen-vl-utils
  Downloading qwen_vl_utils-0.0.14-py3-none-any.whl.metadata (9.0 kB)
Collecting av (from qwen-vl-utils)
  Downloading av-16.0.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (4.6 kB)
Downloading qwen_vl_utils-0.0.14-py3-none-any.whl (8.1 kB)
Downloading av-16.0.1-cp312-cp312-manylinux_2_28_x86_64.whl (40.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.5/40.5 MB[0m [31m19.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: av, qwen-vl-utils
Successfully installed av-16.0.1 qwen-vl-utils-0.0.14


In [None]:
!pip install transformers accelerate peft bitsandbytes qwen-vl-utils




In [None]:
!pip uninstall -y bitsandbytes
!pip install --upgrade pip
!pip install unsloth transformers accelerate pillow datasets


Found existing installation: bitsandbytes 0.49.0
Uninstalling bitsandbytes-0.49.0:
  Successfully uninstalled bitsandbytes-0.49.0
Collecting pip
  Downloading pip-25.3-py3-none-any.whl.metadata (4.7 kB)
Downloading pip-25.3-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m36.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.1.2
    Uninstalling pip-24.1.2:
      Successfully uninstalled pip-24.1.2
Successfully installed pip-25.3
Collecting unsloth
  Downloading unsloth-2025.12.5-py3-none-any.whl.metadata (65 kB)
Collecting unsloth_zoo>=2025.12.4 (from unsloth)
  Downloading unsloth_zoo-2025.12.4-py3-none-any.whl.metadata (32 kB)
Collecting tyro (from unsloth)
  Downloading tyro-1.0.1-py3-none-any.whl.metadata (11 kB)
Collecting xformers>=0.0.27.post2 (from unsloth)
  Downloading xformers-0.0.33.post2-cp39-abi3-manylinux_2_28_x86_64.whl.me

In [None]:
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"


In [None]:
from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
from huggingface_hub import login
login("hugging_face_token")


LORA BASED GRPO TRAINING

In [None]:
# Patched GRPO for Qwen2-VL + LoRA (hybrid precision loader fixed)
# - Auto device_map then force vision modules to GPU
# - Larger LoRA coverage and rank for meaningful updates
# - Debug mode (run_few_steps) to verify gradients
#
# Edit PATHS at the top (json_path, IMAGE_ROOT, save_path) before running.

import os, json, time, gc, re, difflib
from collections import defaultdict
import torch
import torch.cuda.amp as amp
from torch.utils.data import Dataset, DataLoader
from PIL import Image

from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, PeftModel, prepare_model_for_kbit_training
from safetensors.torch import save_file as safe_save

# -------------------------
# PATHS (EDIT)
# -------------------------
json_path = "/content/drive/MyDrive/plantfolder100/plant100.json"
IMAGE_ROOT = "/content/drive/MyDrive/plantfolder100/Images/train"
existing_lora = None   # set path to continue from an existing adapter if desired
save_path = "/content/drive/MyDrive/qwen2vl_grpo_lora_improved_hybrid_fixed"
base_model_name = "Qwen/Qwen2-VL-2B-Instruct"

os.makedirs(save_path, exist_ok=True)

# -------------------------
# HYPERPARAMETERS (T4-friendly)
# -------------------------
train_batch_size = 1
gradient_accumulation_steps = 8
lr = 3e-5
num_epochs = 3
gen_max_tokens = 48
kl_coeff = 0.02
gamma_baseline = 0.99
max_grad_norm = 1.0
log_interval = 10
max_seq_length = 512

# LoRA config (bigger capacity)
LORA_R = 64
LORA_ALPHA = 32
LORA_DROPOUT = 0.05

use_4bit = True
use_gradient_checkpointing = True
cache_vision_embeddings = True

use_amp = False
scaler = None

use_gpu = torch.cuda.is_available()
device = "cuda" if use_gpu else "cpu"

# DEBUG mode: run only a few steps to verify gradients (set False for full training)
run_few_steps = False
few_steps = 5

# -------------------------
# UTILITIES
# -------------------------
def clear_cache():
    if use_gpu:
        try: torch.cuda.empty_cache()
        except: pass
    gc.collect()

def extract_tag(text, tag):
    if not text: return ""
    m = re.search(fr"<{tag}>(.*?)</{tag}>", text, re.S | re.I)
    return m.group(1).strip() if m else ""

def similarity(a,b):
    a = (a or "").strip().lower(); b=(b or "").strip().lower()
    if not a or not b: return 0.0
    return difflib.SequenceMatcher(None,a,b).ratio()

# -------------------------
# LOAD DATA + PROCESSOR
# -------------------------
with open(json_path, "r") as f:
    raw_data = json.load(f)

processor = AutoProcessor.from_pretrained(base_model_name, trust_remote_code=True)
# reduce image size for memory
try:
    if hasattr(processor, "image_processor"):
        processor.image_processor.size = {"height": 448, "width": 448}
except Exception:
    pass

# -------------------------
# DATASET (PlantDataset)
# -------------------------
class PlantDataset(Dataset):
    def __init__(self,data,processor,root,max_length=512):
        self.data = data
        self.processor = processor
        self.root = root
        self.max_length = max_length
        self.tokenized_cache = {}
    def __len__(self): return len(self.data)
    def __getitem__(self, idx):
        ex = self.data[idx]
        img_name = os.path.basename(ex.get("image",""))
        if idx in self.tokenized_cache:
            cached = self.tokenized_cache[idx].copy(); cached["img_name"]=img_name; return cached
        img_path = os.path.join(self.root, img_name)
        try:
            img = Image.open(img_path).convert("RGB")
        except Exception as e:
            raise RuntimeError(f"Failed to open {img_path}: {e}")
        convs = ex.get("conversations", [])
        if len(convs) < 2:
            user = ex.get("question", "<image>\nQuestion: (no question found)")
            assistant = ex.get("answer", "<answer>(no answer)</answer>")
        else:
            user = convs[0].get("value","")
            assistant = convs[1].get("value","")
        gold_perception = extract_tag(assistant, "visual_perception")
        gold_answer = extract_tag(assistant, "answer")
        if "<image>" not in user:
            user = "<image>\n" + user
        messages = [
            {"role":"user","content":[{"type":"text","text":user},{"type":"image","image":img}]},
            {"role":"assistant","content":[{"type":"text","text":assistant}]}
        ]
        encoded = self.processor.apply_chat_template(
            messages, tokenize=True, return_tensors="pt",
            add_generation_prompt=False, return_dict=True, max_length=self.max_length, truncation=True
        )
        input_ids = encoded["input_ids"]
        labels = input_ids.clone(); labels[:] = -100
        eos = self.processor.tokenizer.eos_token_id
        pos = (input_ids==eos).nonzero(as_tuple=True)
        if pos[0].numel()>0:
            start = pos[1][0].item()+1
            labels[0,start:] = input_ids[0,start:]
        else:
            labels[:] = -100
        encoded["labels"]=labels
        encoded["user_text"]=user; encoded["gold_answer"]=gold_answer; encoded["gold_perception"]=gold_perception; encoded["img_name"]=img_name
        result = {k:(v.squeeze(0) if torch.is_tensor(v) else v) for k,v in encoded.items()}
        cache_entry = {k:(v.detach().cpu() if torch.is_tensor(v) else v) for k,v in result.items() if k!="img_name"}
        self.tokenized_cache[idx] = cache_entry
        return result

def collate_fn(batch):
    out={}
    keys=batch[0].keys()
    for k in keys:
        vals=[b[k] for b in batch]
        if torch.is_tensor(vals[0]):
            out[k]=torch.stack(vals)
        else:
            out[k]=vals
    return out

train_dataset = PlantDataset(raw_data, processor, IMAGE_ROOT, max_seq_length)
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=0, pin_memory=True)

# -------------------------
# TARGET MODULES - broadened
# -------------------------
target_modules = [
    "q_proj","k_proj","v_proj","o_proj",
    "down_proj","up_proj","dense","linear",
    "proj_in","proj_out","wq","wk","wv","wo",
    "gated_act_proj","mlp_dense_h_to_4h","mlp_dense_4h_to_h"
]
target_modules = list(dict.fromkeys(target_modules))
print("Target modules:", target_modules)

# -------------------------
# HYBRID MODEL LOADING: FIXED DEVICE MAP (auto -> force vision GPU)
# -------------------------
print("STEP 1: Loading model with automatic device_map to build a safe map...")

if use_4bit:
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        llm_int8_enable_fp32_cpu_offload=True,
    )
else:
    bnb_config = None

# 1) Temporary load to get an auto device map
tmp_model = Qwen2VLForConditionalGeneration.from_pretrained(
    base_model_name,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
)
auto_map = getattr(tmp_model, "hf_device_map", None) or getattr(tmp_model, "device_map", None) or {}
del tmp_model
clear_cache()

# 2) Force certain vision modules to cuda if present in auto_map keys
forced_gpu_keys = [
    "vision_tower",
    "vision_tower.vision_model",
    "vision_tower.vision_model.encoder",
    "vision_tower.vision_model.embeddings",
    "vision_proj",
    "multi_modal_projector",
    "multi_modal_projector.proj",
]

for key in list(auto_map.keys()):
    for fk in forced_gpu_keys:
        if fk in key:
            auto_map[key] = "cuda"

print("Corrected device_map (sample):")
for k,v in list(auto_map.items())[:30]:
    print(k, "->", v)
print("Loading final model with corrected device_map...")

policy_base = Qwen2VLForConditionalGeneration.from_pretrained(
    base_model_name,
    quantization_config=bnb_config,
    device_map=auto_map,
    torch_dtype=torch.float16,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
)

# Prepare for kbit training if using 4-bit
if use_4bit:
    try:
        policy_base = prepare_model_for_kbit_training(policy_base)
    except Exception:
        pass

# Ensure use_cache False and enable gradient checkpointing
try:
    policy_base.config.use_cache = False
    if use_gradient_checkpointing:
        policy_base.gradient_checkpointing_enable()
        print("Enabled gradient checkpointing")
except Exception:
    pass

clear_cache()

# -------------------------
# Attach LoRA (higher capacity)
# -------------------------
print("Attaching higher-capacity LoRA...")
peft_cfg = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=target_modules,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
)

policy_model = get_peft_model(policy_base, peft_cfg)

# Optionally continue from existing adapter
if existing_lora and os.path.isdir(existing_lora):
    try:
        policy_model = PeftModel.from_pretrained(policy_model, existing_lora, is_trainable=True)
        print("Loaded existing LoRA adapter for continued training.")
    except Exception as e:
        print("Could not load existing LoRA (continuing from scratch):", e)

# Count trainable params
trainable_params = sum(p.numel() for p in policy_model.parameters() if p.requires_grad)
all_params = sum(p.numel() for p in policy_model.parameters())
print(f"Trainable params: {trainable_params:,} || All params: {all_params:,} || Trainable%: {100.0*trainable_params/all_params:.6f}%")

# -------------------------
# Optimizer
# -------------------------
try:
    import bitsandbytes as bnb
    optimizer = bnb.optim.AdamW8bit([p for p in policy_model.parameters() if p.requires_grad], lr=lr, betas=(0.9,0.95))
    print("Using 8-bit AdamW")
except Exception:
    optimizer = torch.optim.AdamW([p for p in policy_model.parameters() if p.requires_grad], lr=lr)
    print("Using standard AdamW")

clear_cache()

# -------------------------
# Reference model for KL
# -------------------------
ref_model = Qwen2VLForConditionalGeneration.from_pretrained(
    base_model_name,
    quantization_config=bnb_config,
    device_map=auto_map,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
ref_model.eval()
for p in ref_model.parameters(): p.requires_grad=False

# -------------------------
# Vision cache (optional)
# -------------------------
vision_cache = {}
def precompute_vision_embeddings():
    policy_model.eval()
    unique_images = set(os.path.basename(ex.get("image","")) for ex in raw_data)
    print(f"Caching {len(unique_images)} images...")
    for i, img_name in enumerate(unique_images,1):
        p = os.path.join(IMAGE_ROOT, img_name)
        try:
            im = Image.open(p).convert("RGB")
        except Exception:
            continue
        msg = [{"role":"user","content":[{"type":"text","text":"<image>"},{"type":"image","image":im}]}]
        enc = processor.apply_chat_template(msg, tokenize=True, return_tensors="pt", add_generation_prompt=False, return_dict=True)
        with torch.no_grad():
            pv = enc.get("pixel_values"); ig = enc.get("image_grid_thw")
            vision_cache[img_name] = {"pixel_values": pv.detach().cpu(), "image_grid_thw": ig.detach().cpu() if ig is not None else None}
        if i % 100 == 0: clear_cache()
    print("Cached images:", len(vision_cache))
    clear_cache()

if cache_vision_embeddings:
    precompute_vision_embeddings()

# -------------------------
# Fast generate helper
# -------------------------
@torch.inference_mode()
def fast_generate(model, input_ids, attention_mask, pixel_values, image_grid_thw, max_new_tokens):
    return model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        pixel_values=pixel_values,
        image_grid_thw=image_grid_thw,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        num_beams=1,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
    )

# -------------------------
# Reward helpers
# -------------------------
def rouge_l_simple(a, b):
    if not a or not b: return 0.0
    a,b=a.strip().lower(),b.strip().lower()
    la,lb=len(a),len(b)
    dp=[[0]*(lb+1) for _ in range(la+1)]
    for i in range(1,la+1):
        for j in range(1,lb+1):
            if a[i-1]==b[j-1]:
                dp[i][j]=dp[i-1][j-1]+1
            else:
                dp[i][j]=max(dp[i-1][j], dp[i][j-1])
    lcs = dp[la][lb]
    return (2.0*lcs)/(la+lb+1e-12)

def jaccard(a,b):
    a_set=set((a or "").lower().split()); b_set=set((b or "").lower().split())
    if not a_set and not b_set: return 0.0
    return len(a_set & b_set)/max(1,len(a_set|b_set))

# -------------------------
# TRAIN LOOP (GRPO)
# -------------------------
running_baseline = 0.0
baseline_init = False
global_step = 0
start_time = time.time()
optimizer.zero_grad()

print("\nSTARTING IMPROVED HYBRID GRPO (FIXED LOADER)\n")

for epoch in range(num_epochs):
    for batch_idx, batch in enumerate(train_loader):
        global_step += 1
        tensors = {k:v for k,v in batch.items() if torch.is_tensor(v)}
        gold_answer = batch["gold_answer"][0]
        gold_perception = batch["gold_perception"][0]
        img_name = batch["img_name"][0]
        user_text = batch.get("user_text",[None])[0]

        # vision inputs (from cache or on-the-fly)
        if img_name in vision_cache:
            vis = vision_cache[img_name]
            pixel_values = vis["pixel_values"].to(device)
            image_grid_thw = vis.get("image_grid_thw")
            if image_grid_thw is not None: image_grid_thw = image_grid_thw.to(device)
        else:
            pixel_values = tensors["pixel_values"].to(device)
            image_grid_thw = tensors.get("image_grid_thw", None)
            if image_grid_thw is not None: image_grid_thw = image_grid_thw.to(device)

        input_ids = tensors["input_ids"].to(device)
        attn = tensors["attention_mask"].to(device)

        # ========== ROLLOUT PHASE ==========
        policy_model.eval()
        with torch.no_grad(), torch.amp.autocast(device_type='cuda' if use_gpu else 'cpu', enabled=use_amp):
            gen_ids = fast_generate(policy_model, input_ids, attn, pixel_values, image_grid_thw, gen_max_tokens)

        gen_text = processor.tokenizer.batch_decode(gen_ids, skip_special_tokens=True)[0]
        pred_perception = extract_tag(gen_text, "visual_perception")
        pred_answer1 = extract_tag(gen_text, "answer")

        # self-rollout (no image) using only perception
        if user_text:
            qm = re.search(r"Question\s*:\s*(.*)", user_text, re.S | re.I)
            question_text = qm.group(1).strip() if qm else user_text
        else:
            prompt_text = processor.tokenizer.decode(input_ids[0], skip_special_tokens=True)
            qm = re.search(r"Question\s*:\s*(.*)", prompt_text, re.S | re.I)
            question_text = qm.group(1).strip() if qm else prompt_text

        perception_prompt = (
            "You previously saw the image and described it as:\n"
            f"<visual_perception>{pred_perception}</visual_perception>\n\n"
            "Using ONLY the above perception, answer the question accurately. Do NOT invent new perception.\n"
            f"Question: {question_text}\n"
            "Give final answer inside <answer> tags."
        )

        msg2 = [{"role":"user","content":[{"type":"text","text":perception_prompt}]}]
        enc2 = processor.apply_chat_template(msg2, tokenize=True, return_tensors="pt", add_generation_prompt=True, return_dict=True, max_length=max_seq_length, truncation=True)
        input_ids2 = enc2["input_ids"].to(device)
        attn2 = enc2["attention_mask"].to(device)

        with torch.no_grad(), torch.amp.autocast(device_type='cuda' if use_gpu else 'cpu', enabled=use_amp):
            gen2_ids = policy_model.generate(input_ids=input_ids2, attention_mask=attn2, max_new_tokens=gen_max_tokens, do_sample=False, num_beams=1, pad_token_id=processor.tokenizer.pad_token_id)

        gen2_text = processor.tokenizer.batch_decode(gen2_ids, skip_special_tokens=True)[0]
        pred_answer2 = extract_tag(gen2_text, "answer")

        # ------------------ COMPUTE RICH REWARD ------------------
        ans1_score = 0.6 * rouge_l_simple(pred_answer1, gold_answer) + 0.4 * jaccard(pred_answer1, gold_answer)
        ans2_score = 0.75 * rouge_l_simple(pred_answer2, gold_answer) + 0.25 * jaccard(pred_answer2, gold_answer)
        perc_score = 0.7 * rouge_l_simple(pred_perception, gold_perception) + 0.3 * jaccard(pred_perception, gold_perception)
        len_penalty = -0.02 * max(0, len(pred_answer1.split()) - len(gold_answer.split()) - 10)

        reward = (0.25*ans1_score + 0.55*ans2_score + 0.18*perc_score + 0.02*len_penalty)
        reward = float(max(0.0, min(1.0, reward)))

        # baseline & advantage
        if not baseline_init:
            running_baseline = reward; baseline_init=True
        else:
            running_baseline = gamma_baseline * running_baseline + (1.0 - gamma_baseline) * reward
        advantage = float(max(min(reward - running_baseline, 10.0), -10.0))

        # ========== COMPUTE POLICY LOSS ==========
        prompt_len = input_ids.shape[1]
        cont_len = gen_ids.shape[1] - prompt_len
        if cont_len <= 0:
            print(f"[Warning] cont_len <= 0; skipping step.")
            continue

        labels_pol = gen_ids.clone().to(device)
        labels_pol[0, :prompt_len] = -100

        pad_id = processor.tokenizer.pad_token_id
        attn_full = (gen_ids != pad_id).long().to(device)

        policy_model.train()
        # IMPORTANT: enable autocast as appropriate
        with torch.amp.autocast(device_type='cuda' if use_gpu else 'cpu', enabled=use_amp):
            out_pol = policy_model(input_ids=gen_ids.to(device), attention_mask=attn_full, pixel_values=pixel_values if pixel_values is not None else None, image_grid_thw=image_grid_thw if image_grid_thw is not None else None, labels=labels_pol)
            avg_neglog_pol = out_pol.loss

            with torch.no_grad():
                labels_ref = gen_ids.clone().to(device); labels_ref[0, :prompt_len] = -100
                out_ref = ref_model(input_ids=gen_ids.to(device), attention_mask=attn_full, pixel_values=pixel_values if pixel_values is not None else None, image_grid_thw=image_grid_thw if image_grid_thw is not None else None, labels=labels_ref)
                avg_neglog_ref = out_ref.loss

            sum_logprob_pol = -avg_neglog_pol * cont_len
            policy_loss = -(advantage * sum_logprob_pol)
            kl_pen = kl_coeff * (avg_neglog_pol - avg_neglog_ref)
            total_loss = (policy_loss + kl_pen) / gradient_accumulation_steps

        # backward
        total_loss.backward()

        # gradient accumulation step
        if (batch_idx + 1) % gradient_accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_([p for p in policy_model.parameters() if p.requires_grad], max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()

        # logging
        if global_step % log_interval == 0:
            elapsed = time.time() - start_time
            try:
                gpu_mem = torch.cuda.memory_allocated()/1e9 if use_gpu else 0
            except Exception:
                gpu_mem = 0
            print(f"[Step {global_step}] reward={reward:.3f} ans1={ans1_score:.3f} self={ans2_score:.3f} perc={perc_score:.3f} adv={advantage:.3f} loss={(total_loss.item()*gradient_accumulation_steps):.4f} | GPU {gpu_mem:.2f}GB")

        if global_step % 50 == 0: clear_cache()

        # debug quick exit
        if run_few_steps and global_step >= few_steps:
            print("Completed debug few steps, exiting training loop.")
            break

    if run_few_steps and global_step >= few_steps:
        break

    # epoch-end leftover grads
    any_grads = any((p.grad is not None) for p in policy_model.parameters() if p.requires_grad)
    if any_grads:
        torch.nn.utils.clip_grad_norm_([p for p in policy_model.parameters() if p.requires_grad], max_grad_norm)
        optimizer.step(); optimizer.zero_grad()

# -------------------------
# SAVE LoRA (safetensors) + processor
# -------------------------
print("Saving LoRA adapter and processor...")
adapter_state = {k:v.cpu() for k,v in policy_model.state_dict().items() if "lora" in k or "alpha" in k}
adapter_config = {"r": LORA_R, "lora_alpha": LORA_ALPHA, "target_modules": target_modules, "lora_dropout": LORA_DROPOUT}
safe_save(adapter_state, os.path.join(save_path, "adapter_model.safetensors"))
with open(os.path.join(save_path, "adapter_config.json"), "w") as f:
    json.dump(adapter_config, f, indent=2)
processor.save_pretrained(save_path)
print("Saved to:", save_path)

print("Training finished.")


Target modules: ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'down_proj', 'up_proj', 'dense', 'linear', 'proj_in', 'proj_out', 'wq', 'wk', 'wv', 'wo', 'gated_act_proj', 'mlp_dense_h_to_4h', 'mlp_dense_4h_to_h']
STEP 1: Loading model with automatic device_map to build a safe map...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Corrected device_map (sample):
 -> 0
Loading final model with corrected device_map...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Enabled gradient checkpointing
Attaching higher-capacity LoRA...
Trainable params: 55,050,240 || All params: 1,277,325,824 || Trainable%: 4.309804%
Using 8-bit AdamW


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Caching 100 images...
Cached images: 100

STARTING IMPROVED HYBRID GRPO (FIXED LOADER)

[Step 10] reward=0.482 ans1=1.000 self=0.094 perc=1.000 adv=-0.050 loss=-0.9517 | GPU 7.36GB
[Step 20] reward=0.430 ans1=1.000 self=0.000 perc=1.000 adv=-0.095 loss=-1.9090 | GPU 7.35GB
[Step 30] reward=0.494 ans1=1.000 self=0.117 perc=1.000 adv=-0.025 loss=-0.6106 | GPU 7.39GB
[Step 40] reward=0.503 ans1=1.000 self=0.132 perc=1.000 adv=-0.011 loss=-0.1897 | GPU 7.13GB
[Step 50] reward=0.430 ans1=1.000 self=0.000 perc=1.000 adv=-0.079 loss=-0.7617 | GPU 7.36GB
[Step 60] reward=0.430 ans1=1.000 self=0.000 perc=1.000 adv=-0.074 loss=-2.0508 | GPU 7.37GB
[Step 70] reward=0.477 ans1=1.000 self=0.085 perc=1.000 adv=-0.022 loss=-0.5734 | GPU 7.39GB
[Step 80] reward=0.430 ans1=1.000 self=0.000 perc=1.000 adv=-0.065 loss=-0.5852 | GPU 7.14GB
[Step 90] reward=0.430 ans1=1.000 self=0.000 perc=1.000 adv=-0.061 loss=-2.0972 | GPU 7.36GB
[Step 100] reward=0.488 ans1=1.000 self=0.106 perc=1.000 adv=-0.001 loss=-0

In [None]:
from transformers import Qwen2VLForConditionalGeneration, BitsAndBytesConfig

base_model_name = "Qwen/Qwen2-VL-2B-Instruct"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype="float16",
    bnb_4bit_use_double_quant=True,
)

policy_base = Qwen2VLForConditionalGeneration.from_pretrained(
    base_model_name,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
count = 0
for name, module in policy_base.named_modules():
    print(name)
    count += 1
    if count >= 300:
        break



model
model.visual
model.visual.patch_embed
model.visual.patch_embed.proj
model.visual.rotary_pos_emb
model.visual.blocks
model.visual.blocks.0
model.visual.blocks.0.norm1
model.visual.blocks.0.norm2
model.visual.blocks.0.attn
model.visual.blocks.0.attn.qkv
model.visual.blocks.0.attn.proj
model.visual.blocks.0.mlp
model.visual.blocks.0.mlp.fc1
model.visual.blocks.0.mlp.act
model.visual.blocks.0.mlp.fc2
model.visual.blocks.1
model.visual.blocks.1.norm1
model.visual.blocks.1.norm2
model.visual.blocks.1.attn
model.visual.blocks.1.attn.qkv
model.visual.blocks.1.attn.proj
model.visual.blocks.1.mlp
model.visual.blocks.1.mlp.fc1
model.visual.blocks.1.mlp.act
model.visual.blocks.1.mlp.fc2
model.visual.blocks.2
model.visual.blocks.2.norm1
model.visual.blocks.2.norm2
model.visual.blocks.2.attn
model.visual.blocks.2.attn.qkv
model.visual.blocks.2.attn.proj
model.visual.blocks.2.mlp
model.visual.blocks.2.mlp.fc1
model.visual.blocks.2.mlp.act
model.visual.blocks.2.mlp.fc2
model.visual.blocks.3
mod

In [None]:
import torch
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
import safetensors.torch as st
from PIL import Image
from google.colab import files


# ---------------------------------------------------------
# PATHS
# ---------------------------------------------------------
base_model_name = "Qwen/Qwen2-VL-2B-Instruct"
lora_path = "/content/drive/MyDrive/qwen2vl_grpo_lora_improved_hybrid_fixed"
lora_weights = f"{lora_path}/adapter_model.safetensors"   # Auto-saved by your script


# ---------------------------------------------------------
# LOAD PROCESSOR
# ---------------------------------------------------------
processor = AutoProcessor.from_pretrained(base_model_name, trust_remote_code=True)


# ---------------------------------------------------------
# 4-BIT QUANTIZATION CONFIG (T4 Friendly)
# ---------------------------------------------------------
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)


# ---------------------------------------------------------
# BASE MODEL LOADER
# ---------------------------------------------------------
def load_base():
    print("\nLoading BASE model...")
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        base_model_name,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    model.eval()
    print("Base model loaded.\n")
    return model


# ---------------------------------------------------------
# MANUAL LORA LOADER (FIX FOR MISSING peft_type IN CONFIG)
# ---------------------------------------------------------
def load_lora():
    print("\nLoading LoRA model (manual load)...")

    # 1) Load base model again
    base = Qwen2VLForConditionalGeneration.from_pretrained(
        base_model_name,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )

    # 2) Recreate EXACT LoRA config used in your training
    lora_cfg = LoraConfig(
        peft_type="LORA",
        task_type="CAUSAL_LM",
        r=64,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=[
            "q_proj","k_proj","v_proj","o_proj",
            "down_proj","up_proj","dense","linear",
            "proj_in","proj_out","wq","wk","wv","wo",
            "gated_act_proj","mlp_dense_h_to_4h","mlp_dense_4h_to_h"
        ]
    )

    model = get_peft_model(base, lora_cfg)

    # 3) Load LoRA weights manually
    print("Loading LoRA weights from:", lora_weights)
    weights = st.load_file(lora_weights, device="cpu")

    missing, unexpected = model.load_state_dict(weights, strict=False)
    print("Loaded LoRA weights.")
    print("Missing keys:", len(missing))
    print("Unexpected keys:", len(unexpected))

    model.eval()
    return model


# ---------------------------------------------------------
# PROMPT TEMPLATE FOR REASONING
# ---------------------------------------------------------
def build_prompt(question):
    return (
        "Give a short reasoning summary (not full chain-of-thought).\n"
        "Then give the final answer inside <answer> tags.\n\n"
        f"Question: {question}\n"
    )


# ---------------------------------------------------------
# ASK FUNCTION
# ---------------------------------------------------------
def ask(model, img, question):
    prompt = build_prompt(question)

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": img},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    enc = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )

    # Ensure tensors have batch dimension
    def ensure_batch(x):
        if x is None:
            return None
        if isinstance(x, list):
            x = x[0]
        if x.dim() == 1:
            x = x.unsqueeze(0)
        return x.to(model.device)

    input_ids = ensure_batch(enc["input_ids"])
    attention_mask = ensure_batch(enc["attention_mask"])
    pixel_values = ensure_batch(enc.get("pixel_values"))
    image_grid_thw = ensure_batch(enc.get("image_grid_thw"))

    output_ids = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        pixel_values=pixel_values,
        image_grid_thw=image_grid_thw,
        max_new_tokens=1024,
    )

    return processor.tokenizer.decode(output_ids[0], skip_special_tokens=True)


# ---------------------------------------------------------
# UPLOAD IMAGE
# ---------------------------------------------------------
uploaded = files.upload()
image_path = list(uploaded.keys())[0]
img = Image.open(image_path).convert("RGB")
print("\nImage loaded:", image_path)


question = "What pest is present in the image? Explain reasoning and prevention."


# ---------------------------------------------------------
# RUN BASE MODEL
# ---------------------------------------------------------
base_model = load_base()
base_output = ask(base_model, img, question)

print("\n================ BASE MODEL OUTPUT ================\n")
print(base_output)


# ---------------------------------------------------------
# RUN LORA MODEL
# ---------------------------------------------------------
lora_model = load_lora()
lora_output = ask(lora_model, img, question)

print("\n================ LORA MODEL OUTPUT ================\n")
print(lora_output)


Saving plant4.jpg to plant4.jpg

Image loaded: plant4.jpg

Loading BASE model...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Base model loaded.



system
You are a helpful assistant.
user
Give a short reasoning summary (not full chain-of-thought).
Then give the final answer inside <answer> tags.

Question: What pest is present in the image? Explain reasoning and prevention.

assistant
The pest present in the image is likely a leaf miner. Leaf miners are insects that feed on the leaves of plants, causing them to become distorted and eventually die. To prevent this pest, it is important to monitor the plant for signs of damage and to take appropriate measures to control the pest. This may include using insecticides or other pest control methods to prevent the pest from spreading to other plants.

Loading LoRA model (manual load)...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading LoRA weights from: /content/drive/MyDrive/qwen2vl_grpo_lora_improved_hybrid_fixed/adapter_model.safetensors
Loaded LoRA weights.
Missing keys: 730
Unexpected keys: 314


system
You are a helpful assistant.
user
Give a short reasoning summary (not full chain-of-thought).
Then give the final answer inside <answer> tags.

Question: What pest is present in the image? Explain reasoning and prevention.

assistant
## Answer:

## Image Description:

The image shows a close-up of a plant leaf with several brown spots scattered across it. The leaves appear healthy, but the spots indicate a potential pest infestation.

## Analysis:

### Pest Identification:

1. **Observation**: The brown spots on the leaves are indicative of a pest infestation. These spots are often caused by insects or diseases that affect plant health.

2. **Possible Pest**: Given the presence of brown spots, it is likely that a pest such as a aphid or a scale insect is present. These pests can cause significant damage 

In [None]:
import torch
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
import safetensors.torch as st
from PIL import Image
from google.colab import files


# ---------------------------------------------------------
# PATHS
# ---------------------------------------------------------
base_model_name = "Qwen/Qwen2-VL-2B-Instruct"
lora_path = "/content/drive/MyDrive/qwen2vl_grpo_lora_improved_hybrid_fixed"
lora_weights = f"{lora_path}/adapter_model.safetensors"   # Auto-saved by your script


# ---------------------------------------------------------
# LOAD PROCESSOR
# ---------------------------------------------------------
processor = AutoProcessor.from_pretrained(base_model_name, trust_remote_code=True)


# ---------------------------------------------------------
# 4-BIT QUANTIZATION CONFIG (T4 Friendly)
# ---------------------------------------------------------
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)


# ---------------------------------------------------------
# BASE MODEL LOADER
# ---------------------------------------------------------
def load_base():
    print("\nLoading BASE model...")
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        base_model_name,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    model.eval()
    print("Base model loaded.\n")
    return model


# ---------------------------------------------------------
# MANUAL LORA LOADER (FIX FOR MISSING peft_type IN CONFIG)
# ---------------------------------------------------------
def load_lora():
    print("\nLoading LoRA model (manual load)...")

    # 1) Load base model again
    base = Qwen2VLForConditionalGeneration.from_pretrained(
        base_model_name,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )

    # 2) Recreate EXACT LoRA config used in your training
    lora_cfg = LoraConfig(
        peft_type="LORA",
        task_type="CAUSAL_LM",
        r=64,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=[
            "q_proj","k_proj","v_proj","o_proj",
            "down_proj","up_proj","dense","linear",
            "proj_in","proj_out","wq","wk","wv","wo",
            "gated_act_proj","mlp_dense_h_to_4h","mlp_dense_4h_to_h"
        ]
    )

    model = get_peft_model(base, lora_cfg)

    # 3) Load LoRA weights manually
    print("Loading LoRA weights from:", lora_weights)
    weights = st.load_file(lora_weights, device="cpu")

    missing, unexpected = model.load_state_dict(weights, strict=False)
    print("Loaded LoRA weights.")
    print("Missing keys:", len(missing))
    print("Unexpected keys:", len(unexpected))

    model.eval()
    return model


# ---------------------------------------------------------
# PROMPT TEMPLATE FOR REASONING
# ---------------------------------------------------------
def build_prompt(question):
    return (
        "Give a short reasoning summary (not full chain-of-thought).\n"
        "Then give the final answer inside <answer> tags.\n\n"
        f"Question: {question}\n"
    )


# ---------------------------------------------------------
# ASK FUNCTION
# ---------------------------------------------------------
def ask(model, img, question):
    prompt = build_prompt(question)

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": img},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    enc = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )

    # Ensure tensors have batch dimension
    def ensure_batch(x):
        if x is None:
            return None
        if isinstance(x, list):
            x = x[0]
        if x.dim() == 1:
            x = x.unsqueeze(0)
        return x.to(model.device)

    input_ids = ensure_batch(enc["input_ids"])
    attention_mask = ensure_batch(enc["attention_mask"])
    pixel_values = ensure_batch(enc.get("pixel_values"))
    image_grid_thw = ensure_batch(enc.get("image_grid_thw"))

    output_ids = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        pixel_values=pixel_values,
        image_grid_thw=image_grid_thw,
        max_new_tokens=128,
    )

    return processor.tokenizer.decode(output_ids[0], skip_special_tokens=True)


# ---------------------------------------------------------
# UPLOAD IMAGE
# ---------------------------------------------------------
uploaded = files.upload()
image_path = list(uploaded.keys())[0]
img = Image.open(image_path).convert("RGB")
print("\nImage loaded:", image_path)


question = "What pest is present in the image? Explain reasoning and prevention."


# ---------------------------------------------------------
# RUN BASE MODEL
# ---------------------------------------------------------
base_model = load_base()
base_output = ask(base_model, img, question)

print("\n================ BASE MODEL OUTPUT ================\n")
print(base_output)


# ---------------------------------------------------------
# RUN LORA MODEL
# ---------------------------------------------------------
lora_model = load_lora()
lora_output = ask(lora_model, img, question)

print("\n================ LORA MODEL OUTPUT ================\n")
print(lora_output)


Saving pest4.jpg to pest4.jpg

Image loaded: pest4.jpg

Loading BASE model...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Base model loaded.



system
You are a helpful assistant.
user
Give a short reasoning summary (not full chain-of-thought).
Then give the final answer inside <answer> tags.

Question: What pest is present in the image? Explain reasoning and prevention.

assistant
The pest present in the image is a leaf miner. Leaf miners are insects that feed on the leaves of plants, causing damage to the plant's structure and reducing its overall health. To prevent leaf miner damage, it is important to monitor the plant for signs of damage and to take appropriate measures to control the pest. This may include using insecticides or other pest control methods to eliminate the pest and prevent further damage to the plant.

Loading LoRA model (manual load)...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading LoRA weights from: /content/drive/MyDrive/qwen2vl_grpo_lora_improved_hybrid_fixed/adapter_model.safetensors
Loaded LoRA weights.
Missing keys: 730
Unexpected keys: 314


system
You are a helpful assistant.
user
Give a short reasoning summary (not full chain-of-thought).
Then give the final answer inside <answer> tags.

Question: What pest is present in the image? Explain reasoning and prevention.

assistant
## Answer:

## Image Description:

The image shows a close-up of a leaf with a pest on it. The pest appears to be a small, brown insect with a distinct shape and size. The leaf is green and appears to be healthy, but the presence of the pest indicates potential damage to the plant.

## Pest Identification:

The pest identified in the image is likely a type of aphid, which is a common pest on various plants. Aphids are small insects that feed on plant sap and can cause significant damage to plants by sucking the nutrients from the leaves. They are often found on the underside

COMPARISON BETWEEN BASE MODEL AND FINE TUNED MODEL

In [None]:
import torch
import re
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
import safetensors.torch as st
from PIL import Image
from google.colab import files

# =========================================================
# PATHS
# =========================================================
BASE_MODEL = "Qwen/Qwen2-VL-2B-Instruct"
LORA_PATH = "/content/drive/MyDrive/qwen2vl_grpo_lora_improved_hybrid_fixed"
LORA_WEIGHTS = f"{LORA_PATH}/adapter_model.safetensors"

# =========================================================
# LOAD PROCESSOR
# =========================================================
processor = AutoProcessor.from_pretrained(
    BASE_MODEL,
    trust_remote_code=True
)

# =========================================================
# 4-BIT CONFIG (T4 SAFE)
# =========================================================
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# =========================================================
# LOAD BASE MODEL
# =========================================================
def load_base():
    print("\nLoading BASE model...")
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        BASE_MODEL,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    model.eval()
    return model

# =========================================================
# LOAD BASE + LoRA MODEL
# =========================================================
def load_lora():
    print("\nLoading BASE + LoRA model...")

    base = Qwen2VLForConditionalGeneration.from_pretrained(
        BASE_MODEL,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )

    lora_cfg = LoraConfig(
        peft_type="LORA",
        task_type="CAUSAL_LM",
        r=64,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=[
            "q_proj","k_proj","v_proj","o_proj",
            "down_proj","up_proj","dense","linear",
            "proj_in","proj_out","wq","wk","wv","wo",
            "gated_act_proj","mlp_dense_h_to_4h","mlp_dense_4h_to_h"
        ],
    )

    model = get_peft_model(base, lora_cfg)

    weights = st.load_file(LORA_WEIGHTS, device="cpu")
    model.load_state_dict(weights, strict=False)

    model.eval()
    return model

# =========================================================
# CLEAN OUTPUT (ABSOLUTE GUARANTEE)
# =========================================================
def clean_text(text):
    text = text.split("assistant")[-1]

    # Remove markdown headings / tags / bullets
    text = re.sub(r"<.*?>", "", text)
    text = re.sub(r"^#+.*$", "", text, flags=re.MULTILINE)
    text = re.sub(r"[-•*]\s*", "", text)

    # Remove repeated labels
    text = re.sub(
        r"(?i)(image description|pest description|pest identification|prevention methods|analysis).*?:",
        "",
        text
    )

    # Normalize spaces
    text = re.sub(r"\s+", " ", text).strip()

    return text

# =========================================================
# ASK FUNCTION (FINAL, STABLE)
# =========================================================
def ask(model, img, question):

    prompt = (
        "Answer in ONE short paragraph.\n"
        "Use plain text only.\n"
        "Do not use headings, bullet points, markdown, or labels.\n"
        "Stop after completing the answer.\n\n"
        f"Question: {question}"
    )

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": img},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    )

    inputs = {k: v.to(model.device) for k, v in inputs.items() if v is not None}

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=False,
            temperature=0.0,
            repetition_penalty=1.15,   # PREVENT HEADER LOOPS
            length_penalty=1.1,
            early_stopping=True,
        )

    decoded = processor.tokenizer.decode(
        output_ids[0],
        skip_special_tokens=True
    )

    return clean_text(decoded)

# =========================================================
# UPLOAD IMAGE
# =========================================================
uploaded = files.upload()
image_path = list(uploaded.keys())[0]
img = Image.open(image_path).convert("RGB")

# =========================================================
# QUESTION
# =========================================================
question = "What pest is present in the image and how can it be prevented?"

# =========================================================
# RUN BASE MODEL
# =========================================================
base_model = load_base()
print("\n🟦 BASE MODEL OUTPUT")
print(ask(base_model, img, question))

# =========================================================
# RUN LoRA MODEL
# =========================================================
lora_model = load_lora()
print("\n🟩 LORA MODEL OUTPUT")
print(ask(lora_model, img, question))


Saving pest1.jpg to pest1.jpg

Loading BASE model...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]


🟦 BASE MODEL OUTPUT
The pest shown in the image appears to be aphids. To prevent aphid infestations, you should regularly inspect your plants for signs of damage such as small holes or leaves with brown spots. If you find aphids on your plants, try using an insecticide specifically designed for controlling aphids. Additionally, maintaining healthy soil conditions by fertilizing properly and providing adequate water will help reduce the risk of aphid infestation.

Loading BASE + LoRA model...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]


🟩 LORA MODEL OUTPUT
The image shows several black aphids on a green leaf. These insects are commonly known as aphids due to their small size and distinctive appearance. They are often found on plants and other plantbased organisms, including vegetables like tomatoes and cucumbers. 1. Control Aphid Populations: Regularly inspect your garden for signs of aphid infestation. If you see them, apply insecticides specifically designed for controlling aphids. Ensure that these products are applied according to the manufacturer's instructions. 2. Planting Variety Resistance: Select varieties of crops that naturally resist aphid damage. This includes using plants with natural resistance mechanisms such as those containing compounds called "aphicide" which help protect against aph


INFERENCE CODE

In [None]:
import torch
import re
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
import safetensors.torch as st
from PIL import Image
from google.colab import files

# =========================================================
# PATHS
# =========================================================
BASE_MODEL = "Qwen/Qwen2-VL-2B-Instruct"
LORA_PATH = "/content/drive/MyDrive/qwen2vl_grpo_lora_improved_hybrid_fixed"
LORA_WEIGHTS = f"{LORA_PATH}/adapter_model.safetensors"

# =========================================================
# LOAD PROCESSOR
# =========================================================
processor = AutoProcessor.from_pretrained(
    BASE_MODEL,
    trust_remote_code=True
)

# =========================================================
# 4-BIT CONFIG (T4 SAFE)
# =========================================================
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# =========================================================
# LOAD BASE + LoRA MODEL (ONLY MODEL USED)
# =========================================================
def load_lora_model():
    print("\nLoading YOUR LoRA-trained model...")

    base = Qwen2VLForConditionalGeneration.from_pretrained(
        BASE_MODEL,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )

    lora_cfg = LoraConfig(
        peft_type="LORA",
        task_type="CAUSAL_LM",
        r=64,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=[
            "q_proj","k_proj","v_proj","o_proj",
            "down_proj","up_proj","dense","linear",
            "proj_in","proj_out","wq","wk","wv","wo",
            "gated_act_proj","mlp_dense_h_to_4h","mlp_dense_4h_to_h"
        ],
    )

    model = get_peft_model(base, lora_cfg)

    weights = st.load_file(LORA_WEIGHTS, device="cpu")
    model.load_state_dict(weights, strict=False)

    model.eval()
    return model

# =========================================================
# CLEAN OUTPUT
# =========================================================
def clean_text(text):
    text = text.split("assistant")[-1]
    text = re.sub(r"<.*?>", "", text)
    text = re.sub(r"^#+.*$", "", text, flags=re.MULTILINE)
    text = re.sub(r"[-•*]\s*", "", text)
    text = re.sub(
        r"(?i)(image description|pest description|pest identification|prevention methods|analysis).*?:",
        "",
        text
    )
    text = re.sub(r"\s+", " ", text).strip()
    return text

# =========================================================
# ASK FUNCTION
# =========================================================
def ask(model, img, question):

    prompt = (
        "Answer in ONE short paragraph.\n"
        "Use plain text only.\n"
        "Do not use headings, bullet points, markdown, or labels.\n"
        "Stop after completing the answer.\n\n"
        f"Question: {question}"
    )

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": img},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    )

    inputs = {k: v.to(model.device) for k, v in inputs.items() if v is not None}

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=False,
            temperature=0.0,
            repetition_penalty=1.15,
            length_penalty=1.1,
            early_stopping=True,
        )

    decoded = processor.tokenizer.decode(
        output_ids[0],
        skip_special_tokens=True
    )

    return clean_text(decoded)

# =========================================================
# UPLOAD IMAGE
# =========================================================
uploaded = files.upload()
image_path = list(uploaded.keys())[0]
img = Image.open(image_path).convert("RGB")

# =========================================================
# QUESTION
# =========================================================
question = "What pest is present in the image and how can it be prevented?"

# =========================================================
# RUN ONLY YOUR MODEL
# =========================================================
model = load_lora_model()

print(ask(model, img, question))


Saving pest5.jpg to pest5.jpg

Loading YOUR LoRA-trained model...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k', 'early_stopping', 'length_penalty']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



🟩 YOUR LoRA MODEL OUTPUT
The pest present in the image is a Colorado potato beetle (Leptinotarsa decemlineata). To prevent its spread, you should: 1. Control Crop Rotation: Rotate crops to avoid the same crop being grown year after year on the same land. This helps reduce the population of pests like Colorado potato beetles that thrive on specific plant species. 2. Crop Rotation with NonPesticidal Methods: Use nonchemical methods such as biological control agents, physical barriers, and cultural practices to manage pests without using pesticides. 3. Insecticide Application: Apply insecticides specifically designed for controlling Colorado potato beetles if necessary but only when absolutely necessary due to their environmental impact. 4. Educate Farmers:


EVALUATION

In [None]:
import json
import re
from PIL import Image
from sentence_transformers import SentenceTransformer, util
from sklearn.metrics import f1_score
import numpy as np

# --------------------------------------------------
# LOAD SBERT FOR SEMANTIC SIMILARITY
# --------------------------------------------------
sbert = SentenceTransformer("all-MiniLM-L6-v2")

# --------------------------------------------------
# CLEAN TEXT (GROUND TRUTH + PREDICTION)
# --------------------------------------------------
def clean_text(text):
    text = re.sub(r"<.*?>", "", text)   # remove tags
    text = re.sub(r"\s+", " ", text)
    return text.strip().lower()

# --------------------------------------------------
# EXTRACT GROUND TRUTH FROM DATASET
# --------------------------------------------------
def extract_ground_truth(conversations):
    for turn in conversations:
        if turn["from"] == "assistant":
            return clean_text(turn["value"])
    return "unknown"

# --------------------------------------------------
# EXTRACT PREDICTED LABEL (KEYWORD MATCH)
# --------------------------------------------------
LABELS = [
    "aphid", "leaf miner", "whitefly", "thrips",
    "mite", "beetle", "caterpillar",
    "tomato yellow leaf curl virus", "powdery mildew",
    "leaf spot", "rust", "blight", "mosaic virus"
]

def extract_label(text):
    text = text.lower()
    for label in LABELS:
        if label in text:
            return label
    return "unknown"

# --------------------------------------------------
# SEMANTIC SIMILARITY
# --------------------------------------------------
def semantic_similarity(pred, gt):
    e1 = sbert.encode(pred, convert_to_tensor=True)
    e2 = sbert.encode(gt, convert_to_tensor=True)
    return util.cos_sim(e1, e2).item()

# --------------------------------------------------
# MAIN EVALUATION FUNCTION
# --------------------------------------------------
def evaluate_model(model, dataset):
    y_true, y_pred = [], []
    sim_scores = []

    for sample in dataset:
        image_path = sample["image"]
        conversations = sample["conversations"]

        # Extract GT
        gt_text = extract_ground_truth(conversations)
        gt_label = extract_label(gt_text)

        # Extract question
        user_turn = conversations[0]["value"]
        question = re.sub(r"<image>\s*Question:\s*", "", user_turn)

        # Load image
        img = Image.open(image_path).convert("RGB")

        # Model inference
        pred_text = ask(model, img, question)
        pred_text_clean = clean_text(pred_text)
        pred_label = extract_label(pred_text_clean)

        # Collect
        y_true.append(gt_label)
        y_pred.append(pred_label)
        sim_scores.append(semantic_similarity(pred_text_clean, gt_text))

    # Compute F1
    labels = sorted(set(y_true + y_pred))
    f1 = f1_score(y_true, y_pred, labels=labels, average="macro", zero_division=0)

    return {
        "f1_score": f1,
        "semantic_similarity": float(np.mean(sim_scores))
    }

# --------------------------------------------------
# LOAD DATASET
# --------------------------------------------------
with open("/content/drive/MyDrive/PlantVillage_TestSet/test_dataset50.json") as f:
    test_data = json.load(f)

# --------------------------------------------------
# RUN BASE VS LORA
# --------------------------------------------------
print("Evaluating BASE model...")
base_results = evaluate_model(base_model, test_data)

print("Evaluating LoRA model...")
lora_results = evaluate_model(lora_model, test_data)

print("\n===== RESULTS =====")
print("BASE:", base_results)
print("LORA:", lora_results)


Evaluating BASE model...
Evaluating LoRA model...

===== RESULTS =====
BASE: {'f1_score': 0.10660018993352327, 'semantic_similarity': 0.2564465381205082}
LORA: {'f1_score': 0.09666666666666666, 'semantic_similarity': 0.2596837501972914}


In [None]:
import re
import json
import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, BitsAndBytesConfig
from peft import PeftModel
import nltk

nltk.download("punkt")

# ---------------------------------------------------------
# PATHS
# ---------------------------------------------------------
BASE_MODEL = "Qwen/Qwen2-VL-2B-Instruct"
LORA_PATH = "/content/drive/MyDrive/qwen2vl_grpo_lora_improved_hybrid_fixed"
TEST_JSON = "/content/drive/MyDrive/PlantVillage_TestSet/test_dataset50.json"

DEVICE = "cuda"

# ---------------------------------------------------------
# LOAD PROCESSOR
# ---------------------------------------------------------
processor = AutoProcessor.from_pretrained(BASE_MODEL, trust_remote_code=True)

# ---------------------------------------------------------
# QUANT CONFIG
# ---------------------------------------------------------
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# ---------------------------------------------------------
# LOAD MODELS
# ---------------------------------------------------------
def load_base():
    print("Loading BASE model...")
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        BASE_MODEL,
        device_map="auto",
        quantization_config=bnb_config,
        trust_remote_code=True,
    )
    model.eval()
    return model


def load_lora():
    print("Loading LoRA model...")
    base = Qwen2VLForConditionalGeneration.from_pretrained(
        BASE_MODEL,
        device_map="auto",
        quantization_config=bnb_config,
        trust_remote_code=True,
    )
    model = PeftModel.from_pretrained(base, LORA_PATH)
    model.eval()
    return model


# ---------------------------------------------------------
# ASK FUNCTION (CLEAN OUTPUT)
# ---------------------------------------------------------
def ask(model, image, question):
    prompt = (
        "Answer based only on what is visible in the image.\n"
        "Answer the question.\n"
        "Give a short, factual answer.\n\n"
        f"Question: {question}"
    )

    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt},
        ]
    }]

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    )

    inputs = {k: v.to(model.device) for k, v in inputs.items() if v is not None}

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=100,
            do_sample=False
        )

    text = processor.tokenizer.decode(output[0], skip_special_tokens=True)
    return text.split("assistant")[-1].strip()


# ---------------------------------------------------------
# FAITHFULNESS SCORER
# ---------------------------------------------------------
HALLUCINATION_TERMS = [
    "definitely", "clearly", "confirmed", "diagnosis is",
    "virus", "fungal", "bacterial", "tylcv", "blight"
]

HEDGING_TERMS = [
    "appears", "likely", "may", "could", "suggests", "based on visible"
]


def faithfulness_score(answer):
    answer = answer.lower()

    strong_claims = sum(t in answer for t in HALLUCINATION_TERMS)
    hedges = sum(t in answer for t in HEDGING_TERMS)

    if strong_claims >= 2 and hedges == 0:
        return 0
    elif strong_claims >= 1 and hedges == 0:
        return 1
    elif strong_claims >= 1 and hedges >= 1:
        return 2
    else:
        return 3


# ---------------------------------------------------------
# LOAD TEST DATA
# ---------------------------------------------------------
with open(TEST_JSON) as f:
    test_data = json.load(f)

QUESTION = "What is affecting the plant in the image?"

# ---------------------------------------------------------
# EVALUATION
# ---------------------------------------------------------
base_model = load_base()
lora_model = load_lora()

base_scores = []
lora_scores = []

for item in tqdm(test_data[:50]):
    image_path = item["image"]
    image = Image.open(image_path).convert("RGB")

    base_ans = ask(base_model, image, QUESTION)
    lora_ans = ask(lora_model, image, QUESTION)

    base_scores.append(faithfulness_score(base_ans))
    lora_scores.append(faithfulness_score(lora_ans))

# ---------------------------------------------------------
# RESULTS
# ---------------------------------------------------------
def summarize(scores):
    return {
        "average": sum(scores) / len(scores),
        "distribution": {i: scores.count(i) for i in range(4)}
    }

print("\n===== FAITHFULNESS RESULTS =====")
print("BASE:", summarize(base_scores))
print("LORA:", summarize(lora_scores))


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Loading BASE model...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading LoRA model...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

100%|██████████| 50/50 [02:27<00:00,  2.95s/it]


===== FAITHFULNESS RESULTS =====
BASE: {'average': 2.96, 'distribution': {0: 0, 1: 1, 2: 0, 3: 49}}
LORA: {'average': 2.96, 'distribution': {0: 0, 1: 1, 2: 0, 3: 49}}





In [None]:
import os
import json
import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
import safetensors.torch as st
import re

# ---------------------------------------------------------
# PATHS
# ---------------------------------------------------------
BASE_MODEL = "Qwen/Qwen2-VL-2B-Instruct"
LORA_PATH = "/content/drive/MyDrive/qwen2vl_grpo_lora_improved_hybrid_fixed"
LORA_WEIGHTS = f"{LORA_PATH}/adapter_model.safetensors"
TEST_JSON = "/content/drive/MyDrive/PlantVillage_TestSet/test_dataset50.json"

# ---------------------------------------------------------
# HALLUCINATION TERMS (NOT DIRECTLY VISIBLE)
# ---------------------------------------------------------
HALLUCINATION_TERMS = [
    "tree", "branch", "soil", "field", "farm",
    "fungal", "bacterial", "viral",
    "tylcv", "mosaic", "blight", "rust",
    "chlorosis", "pathogen", "infection"
]

# ---------------------------------------------------------
# LOAD PROCESSOR
# ---------------------------------------------------------
processor = AutoProcessor.from_pretrained(BASE_MODEL, trust_remote_code=True)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# ---------------------------------------------------------
# LOAD BASE MODEL
# ---------------------------------------------------------
def load_base():
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        BASE_MODEL,
        quantization_config=bnb_config,
        device_map="auto",
        dtype=torch.float16,
        trust_remote_code=True,
    )
    model.eval()
    return model

# ---------------------------------------------------------
# LOAD LORA MODEL
# ---------------------------------------------------------
def load_lora():
    base = Qwen2VLForConditionalGeneration.from_pretrained(
        BASE_MODEL,
        quantization_config=bnb_config,
        device_map="auto",
        dtype=torch.float16,
        trust_remote_code=True,
    )

    lora_cfg = LoraConfig(
        peft_type="LORA",
        task_type="CAUSAL_LM",
        r=64,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=[
            "q_proj","k_proj","v_proj","o_proj",
            "down_proj","up_proj","dense","linear",
            "proj_in","proj_out","wq","wk","wv","wo",
            "gated_act_proj","mlp_dense_h_to_4h","mlp_dense_4h_to_h"
        ]
    )

    model = get_peft_model(base, lora_cfg)
    weights = st.load_file(LORA_WEIGHTS, device="cpu")
    model.load_state_dict(weights, strict=False)
    model.eval()
    return model

# ---------------------------------------------------------
# ASK FUNCTION (CLEAN, DETERMINISTIC)
# ---------------------------------------------------------
def ask(model, image, question):
    prompt = (
        "Answer only based on what is visible in the image. "
        "Do not infer causes or name diseases unless directly visible.\n\n"
        f"Question: {question}"
    )

    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt}
        ]
    }]

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True
    )

    inputs = {k: v.to(model.device) for k, v in inputs.items() if v is not None}

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=80,
            do_sample=False
        )

    text = processor.tokenizer.decode(output[0], skip_special_tokens=True)
    return text.split("assistant")[-1].strip().lower()

# ---------------------------------------------------------
# HALLUCINATION CHECK
# ---------------------------------------------------------
def is_hallucinated(answer):
    return any(term in answer for term in HALLUCINATION_TERMS)

# ---------------------------------------------------------
# LOAD TEST DATA
# ---------------------------------------------------------
with open(TEST_JSON, "r") as f:
    test_data = json.load(f)

# ---------------------------------------------------------
# EVALUATION
# ---------------------------------------------------------
base_model = load_base()
lora_model = load_lora()

base_hallucinations = 0
lora_hallucinations = 0

QUESTION = "What pest or disease is present in the image?"

for sample in tqdm(test_data[:50]):
    image_path = sample["image"]
    image = Image.open(image_path).convert("RGB")

    base_ans = ask(base_model, image, QUESTION)
    lora_ans = ask(lora_model, image, QUESTION)

    if is_hallucinated(base_ans):
        base_hallucinations += 1
    if is_hallucinated(lora_ans):
        lora_hallucinations += 1

# ---------------------------------------------------------
# RESULTS
# ---------------------------------------------------------
n = 50
print("\n===== VISUAL HALLUCINATION RATE =====")
print(f"BASE  : {base_hallucinations/n:.3f}")
print(f"LORA  : {lora_hallucinations/n:.3f}")


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

100%|██████████| 50/50 [11:26<00:00, 13.72s/it]


===== VISUAL HALLUCINATION RATE =====
BASE  : 0.240
LORA  : 0.180





In [None]:
import torch, json, re
from PIL import Image
from tqdm import tqdm
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
import safetensors.torch as st

# =====================================================
# PATHS
# =====================================================
BASE_MODEL = "Qwen/Qwen2-VL-2B-Instruct"
LORA_PATH = "/content/drive/MyDrive/qwen2vl_grpo_lora_improved_hybrid_fixed"
LORA_WEIGHTS = f"{LORA_PATH}/adapter_model.safetensors"
TEST_JSON = "/content/drive/MyDrive/PlantVillage_TestSet/test_dataset50.json"

# =====================================================
# PROCESSOR
# =====================================================
processor = AutoProcessor.from_pretrained(BASE_MODEL, trust_remote_code=True)

# =====================================================
# QUANT CONFIG (T4 SAFE)
# =====================================================
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# =====================================================
# LOAD BASE MODEL
# =====================================================
def load_base():
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        BASE_MODEL,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True
    )
    model.eval()
    return model

# =====================================================
# LOAD LORA MODEL
# =====================================================
def load_lora():
    base = Qwen2VLForConditionalGeneration.from_pretrained(
        BASE_MODEL,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True
    )

    lora_cfg = LoraConfig(
        peft_type="LORA",
        task_type="CAUSAL_LM",
        r=64,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=[
            "q_proj","k_proj","v_proj","o_proj",
            "down_proj","up_proj","dense","linear",
            "proj_in","proj_out","wq","wk","wv","wo",
            "gated_act_proj","mlp_dense_h_to_4h","mlp_dense_4h_to_h"
        ],
    )

    model = get_peft_model(base, lora_cfg)
    weights = st.load_file(LORA_WEIGHTS, device="cpu")
    model.load_state_dict(weights, strict=False)
    model.eval()
    return model

# =====================================================
# QUESTION EXTRACTION
# =====================================================
def extract_question(conversations):
    for turn in conversations:
        if turn["from"] == "user":
            text = turn["value"]
            text = text.replace("<image>", "").strip()
            text = re.sub(r"Question:\s*", "", text)
            return text.strip()
    return ""

# =====================================================
# ANSWER GENERATION
# =====================================================
def ask(model, img, question):
    prompt = (
        "Answer in one short paragraph.\n"
        "If the image does not provide enough visual evidence, say so clearly.\n\n"
        f"Question: {question}"
    )

    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": img},
            {"type": "text", "text": prompt},
        ],
    }]

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True
    )

    inputs = {k: v.to(model.device) for k, v in inputs.items() if v is not None}

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=120,
            do_sample=False,
            temperature=0.0
        )

    return processor.tokenizer.decode(output[0], skip_special_tokens=True).lower()

# =====================================================
# METRICS
# =====================================================
UNSUPPORTED = ["virus", "bacteria", "fungal", "caused by", "tylcv", "tmv"]
REFUSALS = [
    "cannot determine",
    "not enough visual evidence",
    "cannot be identified",
    "unclear from the image"
]
SPECIFIC = ["virus", "bacteria", "fungus", "tylcv", "tmv", "leaf miner", "aphid"]

def visual_grounded(ans):
    return not any(t in ans for t in UNSUPPORTED)

def is_refusal(ans):
    return any(t in ans for t in REFUSALS)

def osp(ans):
    words = ans.split()
    return sum(t in ans for t in SPECIFIC) / max(len(words), 1)

# =====================================================
# AMBIGUITY HEURISTIC
# =====================================================
def is_ambiguous(question):
    q = question.lower()
    return any(
        kw in q for kw in [
            "which agent",
            "trace the cause",
            "diagnosis",
            "disease",
            "virus"
        ]
    )

# =====================================================
# EVALUATION LOOP
# =====================================================
def evaluate(model, dataset):
    vga_score, osp_score = 0, 0
    refusal_hits, ambiguous_count = 0, 0

    for item in tqdm(dataset):
        img = Image.open(item["image"]).convert("RGB")
        question = extract_question(item["conversations"])

        ans = ask(model, img, question)

        vga_score += visual_grounded(ans)
        osp_score += osp(ans)

        if is_ambiguous(question):
            ambiguous_count += 1
            refusal_hits += is_refusal(ans)

    return {
        "VGA": vga_score / len(dataset),
        "OSP": osp_score / len(dataset),
        "RCR": refusal_hits / max(ambiguous_count, 1)
    }

# =====================================================
# RUN EVALUATION
# =====================================================
data = json.load(open(TEST_JSON))

print("\nLoading BASE model...")
base_model = load_base()

print("\nLoading LoRA model...")
lora_model = load_lora()

print("\nEvaluating BASE model...")
base_metrics = evaluate(base_model, data)

print("\nEvaluating LoRA model...")
lora_metrics = evaluate(lora_model, data)

print("\n=========== FINAL RESULTS ===========")
for k in base_metrics:
    print(f"{k:<4} | BASE: {base_metrics[k]:.3f} | LORA: {lora_metrics[k]:.3f}")



Loading BASE model...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]


Loading LoRA model...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]


Evaluating BASE model...


100%|██████████| 50/50 [01:25<00:00,  1.70s/it]



Evaluating LoRA model...


100%|██████████| 50/50 [01:43<00:00,  2.08s/it]


VGA  | BASE: 0.780 | LORA: 0.900
OSP  | BASE: 0.003 | LORA: 0.001
RCR  | BASE: 0.000 | LORA: 0.000



