# VQA Generation and Finetuning using ABO Dataset
We used Amazon's ABO (Amazon Berkeley Objects) dataset to generate a Visual Question Answering (VQA) dataset by prompting a multimodal LLM that accepts images as input. We experimented with the Gemini API as well as on-device LLaVA models (via Ollama). Based on the available compute and the quality of the generated results, we chose the Gemini API to generate the dataset, handling exceptions during API calls. We then evaluated a pretrained BLIP model and fine-tuned it on this dataset using Low-Rank Adaptation (LoRA). More details and results can be found in the following [report](https://drive.google.com/file/d/1WTs4sVgsXIqaiwg7BRb3cV2hOJgegmXe/view?usp=sharing).

## Dataset generation

In [None]:
# Download and extract the small-resolution ABO image archive
# (later cells read from images/small/ and images/metadata/).
!wget https://amazon-berkeley-objects.s3.amazonaws.com/archives/abo-images-small.tar
!tar -xf abo-images-small.tar

In [None]:
# get abo-listings.tar (product listing metadata; later cells read listings/metadata/)
!wget https://amazon-berkeley-objects.s3.amazonaws.com/archives/abo-listings.tar
!tar -xf abo-listings.tar

### Unzipping

In [None]:
# extract the images/metadata/images.csv.gz file
# (the image_id -> relative path table used to locate image files below)
!gzip -d images/metadata/images.csv.gz

In [None]:
# extract listings/metadata/*.json.gz (the listings_<i>.json files read by the loaders below)
!gzip -d listings/metadata/*.json.gz

In [None]:
# display the first lines of images.csv to inspect its columns
!head images/metadata/images.csv

### Preprocessing

In [None]:
import json

# Load the file content

# Load the file content. The file is decoded as a stream of concatenated JSON
# objects (one listing each), so they are read one at a time with raw_decode.
with open('/content/listings/metadata/listings_0.json', 'r', encoding='utf-8') as f:
    file_content = f.read()

# Decode JSON objects iteratively. A moving index is used instead of
# re-slicing the remaining text after each object: every slice copies the
# rest of the string, which makes the loop quadratic in the file size.
data = []
decoder = json.JSONDecoder()
pos = 0
end = len(file_content)
while True:
    # raw_decode does not skip leading whitespace itself.
    while pos < end and file_content[pos].isspace():
        pos += 1
    if pos >= end:
        break
    try:
        obj, pos = decoder.raw_decode(file_content, pos)
    except json.JSONDecodeError as e:
        # Stop at the first malformed object; everything decoded so far is kept.
        print(f"JSONDecodeError: {e}")
        break
    data.append(obj)

print(f"Decoded {len(data)} listings")
# 'data' now contains a list of decoded JSON objects from the file.
# Preview only the first one; dumping every listing would flood the output.
for item in data[:1]:
    print(json.dumps(item, indent=4))

In [None]:
import pandas as pd

# Image metadata table; the 'image_id' and 'path' columns are used below to
# locate image files. The last expression displays the first 10 rows.
metadata = pd.read_csv('images/metadata/images.csv')
metadata.head(10)

In [None]:
# Map image_id -> image path relative to /content/images/small/.
# Built column-wise with zip: DataFrame.iterrows creates a Series per row and
# is slow on large tables. For duplicate ids the last row wins, as before.
id_path_dict = dict(zip(metadata['image_id'], metadata['path']))

In [None]:
# Image ids of the first listing: the main image followed by its other images.
([data[0]["main_image_id"]] + data[0]["other_image_id"])

In [None]:
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

# Open the main image followed by every other image of the first listing.
first_listing = data[0]
listing_image_ids = [first_listing["main_image_id"], *first_listing["other_image_id"]]
images = []
for image_id in listing_image_ids:
    images.append(Image.open('/content/images/small/' + id_path_dict[image_id]))

# Show each image in its own figure.
for img in images:
    plt.imshow(img)
    plt.show()

In [None]:
import json

def QAParser(qa_text):
    """Decode the JSON payload of an LLM response.

    The payload is normally wrapped in a Markdown fence (```json ... ```).
    An unlabelled ``` fence, or no fence at all, is also accepted, so a
    response without the ``json`` tag no longer raises IndexError.

    Args:
        qa_text: Raw response text.

    Returns:
        The decoded JSON value (typically a list of QA dictionaries).

    Raises:
        json.JSONDecodeError: If the extracted text is not valid JSON.
    """
    if "```json" in qa_text:
        # Keep only what lies between ```json and the next closing fence.
        qa_text = qa_text.split("```json", 1)[1].split("```", 1)[0]
    elif qa_text.count("```") >= 2:
        # Unlabelled fence: take the content of the first fenced section.
        qa_text = qa_text.split("```")[1]
    return json.loads(qa_text.strip())

In [None]:
# # prompts:

# content_raw = [
#         images,
#         "Generate around 20 diverse questions about this item and its metadata, each followed by a single-word (valid words) answer \
#         and a difficulty category ('easy', 'medium', or 'hard').\
#         Return the result as a JSON array \
#         of dictionaries with the keys 'question', 'answer', and 'category'.\
#         "
#     ]

# content_filtered = [
#         images,
#         json.dumps(data[0]),
#         json.dumps(parsed_qas),
#         "Filter out the incorrect question answers based on the metadata and images provided and provide 5 confidently correct question and answers.\
#         Provide the output in a json format."
#     ]

#### Using GEMINI

In [None]:


# Gemini API keys. Several keys may be listed: the driver code rotates to the
# next client whenever a request fails.
GEMINI_API_KEY = [
    # insert your API Keys here
]

import os

if not GEMINI_API_KEY:
    # Fall back to the GEMINI_API_KEY environment variable (comma-separated for
    # several keys) so keys need not be pasted into the notebook.
    GEMINI_API_KEY = [
        key.strip()
        for key in os.environ.get("GEMINI_API_KEY", "").split(",")
        if key.strip()
    ]
if not GEMINI_API_KEY:
    raise RuntimeError(
        "No Gemini API key configured: add keys to GEMINI_API_KEY or set the "
        "GEMINI_API_KEY environment variable."
    )

from google import genai

client_arr = [genai.Client(api_key=key) for key in GEMINI_API_KEY]
client = client_arr[0]

In [None]:
# Prompt 1: generate candidate QA pairs for the first listing's images.
# Joined from explicit lines so the bullet list keeps its newlines; a
# backslash-continued literal drops them and embeds the source indentation.
generation_prompt = "\n".join([
    "You are shown a set of related images and associated metadata.",
    "Your task is to generate question-answer pairs that are:",
    "- Answerable from **only one or a few images**",
    "- Varied in what they ask about",
    "Return the result as a JSON array of dictionaries with the keys",
    "'question', 'answer', and 'category', where 'category' is one of",
    "'easy', 'medium' or 'hard'.",
])

response = client.models.generate_content(
    model="gemini-2.0-flash",
    contents=[
        images,
        # The prompt refers to "associated metadata", so actually send it.
        json.dumps(data[0]),
        generation_prompt,
    ],
)
print(response.text)

In [None]:
# Decode the fenced JSON list of generated QAs from the previous response.
parsed_qas = QAParser(response.text)

In [None]:
# . Refrain from making as many assumptions as possible
# NOTE(review): the line above looks like a leftover prompt fragment; it is not sent to the model.

# Prompt 2 (filtering): send the images, the first listing's metadata and the
# QAs parsed from the previous response, and ask for 5 confidently correct ones.
response = client.models.generate_content(
    model="gemini-2.0-flash", contents=[
        images,
        json.dumps(data[0]),
        json.dumps(parsed_qas),
        "Filter out the incorrect question answers based on the metadata and images provided and provide 5 confidently correct question and answers.\
        Provide the output in a json format."
    ]
)
print(response.text)

#### Using Ollama

In [None]:
# Install Ollama using the install script from ollama.ai.
!curl -fsSL https://ollama.ai/install.sh | sh

In [None]:
# Start the Ollama server in the background; later cells call its HTTP API on localhost:11434.
!nohup ollama serve &

In [None]:
!ollama pull llava:5b

In [None]:
# List the models available to the local Ollama server (confirms the pull above).
!curl http://localhost:11434/api/tags

In [None]:
import requests
import json
import base64
from io import BytesIO
from PIL import Image
import matplotlib.pyplot as plt

def query_ollama(model, prompt, images=None, timeout=600):
    """Query the local Ollama server with a prompt and optional images.

    Each image is JPEG-encoded and sent base64-encoded in the request's
    "images" field; streaming is disabled so the whole answer arrives in one
    JSON response.

    Args:
        model: Ollama model tag, e.g. "llava:7b".
        prompt: Text prompt.
        images: Optional iterable of PIL images.
        timeout: Seconds to wait for the server (``None`` waits forever).
            A finite default keeps a stalled server from hanging the notebook.

    Returns:
        The reply's "response" field, or "" if it is missing.

    Raises:
        requests.HTTPError: If the server returns an error status.
        requests.Timeout: If no reply arrives within ``timeout`` seconds.
    """
    if images is None:
        images = []

    encoded_images = []
    for img in images:
        # JPEG cannot store alpha/palette modes (RGBA, P, ...); convert first.
        if img.mode not in ("RGB", "L"):
            img = img.convert("RGB")
        buffered = BytesIO()
        img.save(buffered, format="JPEG")
        encoded_images.append(base64.b64encode(buffered.getvalue()).decode('utf-8'))

    url = "http://localhost:11434/api/generate"
    payload = {
        "model": model,
        "prompt": prompt,
        "images": encoded_images,
        "stream": False
    }
    resp = requests.post(url, json=payload, timeout=timeout)
    resp.raise_for_status()
    return resp.json().get('response', '')

# Load and show images (only the main image; other images are commented out).
listing = data[0]
image_ids = [listing["main_image_id"]] #+ listing.get("other_image_id", [])
images = []
for img_id in image_ids:
    try:
        # The id lookup is inside the try so an id missing from images.csv is reported, not fatal.
        img_path = f"/content/images/small/{id_path_dict[img_id]}"
        img = Image.open(img_path)
        images.append(img)
    except Exception as e:
        print(f"Couldn't load image {img_id}: {e}")

for img in images:
    plt.imshow(img)
    plt.axis('off')
    plt.show()

# Use the listing metadata to guide QA generation. It lives in `listing`
# rather than `metadata` so the pandas image table loaded earlier is not
# overwritten, and the fallback refers to this listing rather than a stale
# `item` variable left over from another cell.
model_name = "llava:7b"

prompt = f"""
You are given 1–3 product images and a short metadata description.
Your task is to generate **5 question–answer pairs** to train a visual question answering (VQA) model.
Constraints:
1. Each answer must be **visually inferable** — do not rely on metadata.
2. Each question should be **inspired by metadata**, but **must be answerable from the image(s) alone**.
3. Each answer must be a **single English word** (alphabetical only).
4. Cover **varied aspects**: color, material, shape, style, components, etc.
5. Label each question as either: "easy", "medium", or "hard" (w.r.t. model effort).
Metadata (not visible to the model at inference time): {listing.get('item_keywords', listing)}
Return only valid JSON as a list of dictionaries:
[
{{ "question": "...", "answer": "...", "category": "easy" }},
...
]
"""
try:
    result = query_ollama(model_name, prompt, images)
    print(result)
except Exception as e:
    print(f"Failed to query item {listing.get('item_id')}: {e}")


#### Driver code

In [None]:
# index of listings.json: selects listings_{i}.json as input and dataset_{i}.csv as output in the cells below
i = "2"

In [None]:
import json

# Load the file content (a stream of concatenated JSON listing objects).
with open(f'/content/listings/metadata/listings_{i}.json', 'r', encoding='utf-8') as f:
    file_content = f.read()

# Decode JSON objects iteratively. A moving index is used instead of
# re-slicing the remaining text after each object: every slice copies the
# rest of the string, which makes the loop quadratic in the file size.
data = []
decoder = json.JSONDecoder()
pos = 0
end = len(file_content)
while True:
    # raw_decode does not skip leading whitespace itself.
    while pos < end and file_content[pos].isspace():
        pos += 1
    if pos >= end:
        break
    try:
        obj, pos = decoder.raw_decode(file_content, pos)
    except json.JSONDecodeError as e:
        # Stop at the first malformed object; everything decoded so far is kept.
        print(f"JSONDecodeError: {e}")
        break
    data.append(obj)


In [None]:
# Number of listings decoded from the selected listings file.
len(data)

In [None]:
import os
import json
import time
import pandas as pd
import re
from PIL import Image
from tqdm import tqdm
import requests
import base64
from io import BytesIO
import matplotlib.pyplot as plt

# Start on the first Gemini client; tryPromptGemini rotates `ci` / `client` on errors.
ci = 0
client = client_arr[ci]

# Output CSV for this listings file. Set `flag` to False to ignore a previous
# partial run and start from an empty dataset.
file_name = f'/content/dataset_{i}.csv'
flag = True
resume = flag and os.path.exists(file_name)
dataset = pd.read_csv(file_name).to_dict('records') if resume else []
if resume:
    print(f"Loaded {len(dataset)} items from {file_name}")


def QAParser(response_text):
    """Extract and sanitise a list of QA dicts from an LLM response.

    Strips a surrounding Markdown code fence, takes the first JSON array of
    objects, inserts a missing comma between an "answer" and a following
    "category", decodes it, and keeps the well-formed entries.

    Each kept entry has a stripped "question", an "answer" reduced to its
    first word and capitalised, and a lower-cased "category". Entries that
    are not dicts, lack a key, or have an empty question/answer are dropped
    one by one instead of discarding the whole response; non-string values
    (e.g. a numeric answer) are converted with str().

    Args:
        response_text: Raw model output.

    Returns:
        List of sanitised QA dicts, or [] if no JSON array could be decoded
        (the error and the start of the raw text are printed).
    """
    try:
        # Remove markdown code fences (```json ... ```)
        text = re.sub(r"^```(?:json)?\s*|\s*```$", "", response_text.strip(), flags=re.IGNORECASE)

        # Extract the first valid JSON array
        match = re.search(r'(\[\s*\{.*?\}\s*\])', text, re.DOTALL)
        if not match:
            raise ValueError("No valid JSON array found.")
        arr = match.group(1)

        # Fix missing commas if needed
        arr = re.sub(r'("answer":\s*"[^"]+")\s*("category":)', r'\1, \2', arr)

        qas = json.loads(arr)
    except Exception as e:
        print(f"[Parsing Error] {e}")
        print("[Raw]", repr(response_text)[:200])
        return []

    # Validate and sanitize each entry independently.
    valid = []
    for qa in qas:
        if not isinstance(qa, dict) or not all(k in qa for k in ("question", "answer", "category")):
            continue
        question = str(qa["question"]).strip()
        answer_words = str(qa["answer"]).split()
        if not question or not answer_words:
            continue  # nothing usable to train on
        valid.append({
            "question": question,
            "answer": answer_words[0].capitalize(),
            "category": str(qa["category"]).strip().lower()
        })
    return valid


def query_ollama(model, prompt, images=None, timeout=600):
    """Send a prompt (and optional PIL images) to the local Ollama server.

    Images are JPEG-encoded, base64-encoded and sent in the "images" field of
    a non-streaming /api/generate request.

    Args:
        model: Ollama model tag, e.g. "llava:7b".
        prompt: Text prompt.
        images: Optional iterable of PIL images.
        timeout: Seconds to wait for the server (``None`` waits forever), so a
            stalled server cannot hang the processing loop indefinitely.

    Returns:
        The reply's "response" field, or "" if it is missing.

    Raises:
        requests.HTTPError: On an error status from the server.
        requests.Timeout: If the server does not answer within ``timeout``.
    """
    if images is None:
        images = []
    encoded = []
    for img in images:
        # JPEG cannot store alpha/palette modes (RGBA, P, ...); convert first.
        if img.mode not in ("RGB", "L"):
            img = img.convert("RGB")
        buf = BytesIO()
        img.save(buf, format="JPEG")
        encoded.append(base64.b64encode(buf.getvalue()).decode())
    resp = requests.post(
        "http://localhost:11434/api/generate",
        json={"model": model, "prompt": prompt, "images": encoded, "stream": False},
        timeout=timeout,
    )
    resp.raise_for_status()
    return resp.json().get("response", "")

def tryPromptGemini(imgs, item, max_retries=10, wait=10):
    """Ask Gemini for 5 VQA pairs about one listing, retrying on API errors.

    After every failed call, the module-level ``client`` is switched to the
    next entry of ``client_arr`` (advancing ``ci``). The call is then retried
    after ``wait`` seconds. Retries use a loop, so the text of a successful
    retry is returned (instead of being lost) and persistent failures cannot
    exhaust the recursion limit.

    Args:
        imgs: List of PIL images of the listing.
        item: Listing metadata dict; its 'item_keywords' (or the whole dict
            when that key is missing) is embedded in the prompt.
        max_retries: Retries allowed after the first failure before giving
            up; ``None`` retries forever.
        wait: Seconds to sleep between attempts.

    Returns:
        The response text, or None if every attempt failed (the caller then
        records the item in ``Skippings``).
    """
    global client, ci

    prompt = f"""
You are given 1–3 product images and a short metadata description.
Your task is to generate **5 question–answer pairs** to train a visual question answering (VQA) model.
Constraints:
1. Each answer must be **visually inferable** — do not rely on metadata.
2. Each question should be **inspired by metadata**, but **must be answerable from the image(s) alone**.
3. Each answer must be a **single English word** (alphabetical only).
4. Cover **varied aspects**: color, material, shape, style, components, etc.
5. Label each question as either: "easy", "medium", or "hard" (w.r.t. model effort).
Metadata (not visible to the model at inference time): {item.get('item_keywords', item)}
Return only valid JSON as a list of dictionaries:
[
{{ "question": "...", "answer": "...", "category": "easy" }},
...
]
"""

    retries = 0
    while True:
        try:
            response = client.models.generate_content(
                model="gemini-2.0-flash",
                contents=[imgs, prompt],
            )
            return response.text
        except Exception as e:
            # Rotate to the next API key (e.g. after a quota error).
            ci = (ci + 1) % len(client_arr)
            client = client_arr[ci]
            if max_retries is not None and retries >= max_retries:
                print(f"[Error] {e}, giving up after {retries + 1} attempts")
                return None
            retries += 1
            print(f"[Error] {e}, redoing after {wait} seconds")
            time.sleep(wait)

# Process items. Resume by item_id rather than by position: items that are
# skipped (no image, API failure, unparseable reply) never reach `dataset`,
# so len(dataset) is not a reliable position in `data`.
Skippings = []
done_item_ids = {row["item_id"] for row in dataset}
for item in tqdm(data, desc="Processing items"):
    if item.get("item_id") in done_item_ids:
        continue

    imgs = []
    for img_id in [item.get("main_image_id")]: # + item.get("other_image_id", []):
        if img_id is None:
            continue
        rel_path = id_path_dict.get(img_id)
        if rel_path is None:
            continue  # image id missing from images.csv
        try:
            imgs.append(Image.open(f'/content/images/small/{rel_path}'))
        except OSError:
            # Missing or unreadable image file.
            continue
    if not imgs:
        continue

    resp = tryPromptGemini(imgs, item)
    if resp is None:
        Skippings.append(item)
        continue

    try:
        qas = QAParser(resp)
        if not qas:
            print(f"[Skip] no valid QAs for item {item.get('item_id')}")
            continue
        dataset.append({
            "item_id": item["item_id"],
            "qas": qas,
            "image_id": item["main_image_id"]
        })
        done_item_ids.add(item["item_id"])
    except Exception as e:
        print(f"[Error] {e}")
        continue
    # Periodic checkpoint so an interrupted run can resume.
    if len(dataset) % 50 == 0:
        pd.DataFrame(dataset).to_csv(file_name, index=False)

# Save
pd.DataFrame(dataset).to_csv(file_name, index=False)
print("Done")

In [None]:
# Reload whatever has been written for listings file `i` so far.
file_name = f'/content/dataset_{i}.csv'
if os.path.exists(file_name):
    dataset = pd.read_csv(file_name).to_dict('records')
    print(f"Loaded {len(dataset)} items from {file_name}")
else:
    dataset = []

In [None]:
# download the generated dataset CSV to the local machine (uses google.colab, so Colab only)
from google.colab import files
files.download(f"/content/dataset_{i}.csv")

In [None]:
# bert's code

# response = client.models.generate_content(
            #     model="gemini-2.0-flash",
                # contents=[
                #     images,
                #     json.dumps(item),
                #     json.dumps(parsed_qas),
                #     "You have:\n"
                #     "- Product images\n"
                #     "- Metadata JSON\n"
                #     "- 20 generated Q&A pairs\n\n"
                #     "Task: From those, select 5 that are clearly and confidently answerable using the image and metadata alone.\n"
                #     "Only include Q&A pairs with correct, unambiguous, single-word answers.\n"
                #     "Return a JSON array of 5 objects with keys 'question', 'answer', and 'category'."
                # ]
            # )

# # Prompt 2: QA Filtering (Only for Ollama)
        # #                     "The reselt of this will be used to train and evaluate vision language models solely based on the images.\n"

        # try:

        #     response = query_ollama(
        #         model="llava:7b",
        #         prompt=(
        #             "You have:\n"
        #             "- Product image\n"
        #             f"- Metadata JSON:\n{json.dumps(item, indent=2)}\n"
        #             f"- generated Q&A pairs:\n{json.dumps(parsed_qas, indent=2)}\n\n"
        #             "Task: From those, filter 5 that are clearly and confidently answerable using the image and metadata alone.\n"
        #             "The QAs must be general for all images and not specific to any one."
        #             "Only filter Q&A pairs with correct, unambiguous, single-word answers based on the metadata and images.\n"
        #             "Do not generate own QAs"
        #             "Return a JSON array of 5 objects with keys 'question', 'answer', and 'category' ie, ```json <message> ``` \n"
        #         ),
        #         images=images
        #     )


        #     filtered_qas = QAParser(response)
        #     # print(json.dumps(parsed_qas, indent=4))
        #     print(json.dumps(filtered_qas, indent=4))
        # except Exception as e:
        #     print(f"[Filter Error] Skipping filtering for item {item['item_id']}: {e}")
        #     continue


## Evaluating Pretrained Models

In [None]:
# Dependencies for the BLIP evaluation below (transformers pipeline + BERTScore).
!pip install transformers accelerate pillow
!pip install bert_score

In [None]:
# load the dataset generated for listings file i (rows: item_id, qas, image_id)
i = '7'
import pandas as pd
eval_dataset = pd.read_csv(f'/content/dataset_{i}.csv').to_dict('records')

In [None]:
import pandas as pd
import ast
from transformers import pipeline
import torch
from tqdm import tqdm
from bert_score import score
from PIL import Image
import matplotlib.pyplot as plt

# Load the BLIP VQA model
# Using the smaller 'Salesforce/blip-vqa-base' for faster evaluation
vqa_pipeline = pipeline("visual-question-answering", model="Salesforce/blip-vqa-base", device=0 if torch.cuda.is_available() else -1)


def _token_overlap_scores(prediction, reference):
    """Return SQuAD-style (precision, recall, f1) of word overlap between two answers.

    Both strings are lower-cased and split on whitespace. Two empty answers
    count as a perfect match; a single empty answer counts as a complete miss.
    """
    from collections import Counter

    pred_tokens = str(prediction).lower().split()
    ref_tokens = str(reference).lower().split()
    if not pred_tokens or not ref_tokens:
        same = float(pred_tokens == ref_tokens)
        return same, same, same
    common = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if common == 0:
        return 0.0, 0.0, 0.0
    precision = common / len(pred_tokens)
    recall = common / len(ref_tokens)
    return precision, recall, 2 * precision * recall / (precision + recall)


# Prepare data for evaluation
eval_results = []
for item_data in tqdm(eval_dataset, desc="Evaluating with BLIP"):
    item_id = item_data['item_id']
    image_id = item_data['image_id']
    # Use ast.literal_eval to safely parse the string representation of the list
    try:
        qas = ast.literal_eval(item_data['qas'])
    except (ValueError, SyntaxError) as e:
        print(f"Error parsing QAs for item {item_id}: {e}")
        print(f"Problematic string: {item_data['qas']}")
        continue

    # Load the image (an id missing from id_path_dict is reported, not fatal)
    try:
        img_path = f'/content/images/small/{id_path_dict[image_id]}'
        img = Image.open(img_path).convert("RGB")
    except Exception as e:
        print(f"Could not load image {image_id}: {e}")
        continue

    # Evaluate each Q&A pair
    for qa in qas:
        question = qa['question']
        ground_truth_answer = qa['answer']
        category = qa['category']

        try:
            # The pipeline returns a list of dicts; take the highest-scoring answer.
            predicted_answer = vqa_pipeline(image=img, question=question)[0]['answer']
        except Exception as e:
            print(f"Error evaluating question for item {item_id}: {e}")
            predicted_answer = "ERROR"

        eval_results.append({
            "item_id": item_id,
            "image_id": image_id,
            "question": question,
            "ground_truth_answer": ground_truth_answer,
            "predicted_answer_blip": predicted_answer,
            "category": category
        })

if not eval_results:
    raise RuntimeError("No QA pairs were evaluated; check eval_dataset and the image paths.")

eval_df = pd.DataFrame(eval_results)

# BERTScore on "question + answer" strings to give the answers context.
# Both sides share the question text, which pushes the scores up; compare
# them between models rather than reading them in absolute terms.
preds = (eval_df['question'].astype(str) + " " + eval_df['predicted_answer_blip'].astype(str)).tolist()
refs = (eval_df['question'].astype(str) + " " + eval_df['ground_truth_answer'].astype(str)).tolist()

# Compute BERTScore
P, R, F1 = score(preds, refs, lang="en", verbose=True)
eval_df['bertscore_f1_with_question'] = F1.tolist()
print(f"\nBLIP VQA Score:\nPrecision={P.mean():.4f}, Recall={R.mean():.4f}, F1={F1.mean():.4f}")

# Normalize answers (case- and surrounding-whitespace-insensitive comparison)
eval_df['ground_truth_answer_clean'] = eval_df['ground_truth_answer'].astype(str).str.lower().str.strip()
eval_df['predicted_answer_blip_clean'] = eval_df['predicted_answer_blip'].astype(str).str.lower().str.strip()

# Exact string match on the normalized answers; accuracy = share of matches
eval_df['exact_match'] = eval_df['predicted_answer_blip_clean'] == eval_df['ground_truth_answer_clean']
accuracy = eval_df['exact_match'].mean()
print(f"\nExact Match Accuracy: {accuracy:.4f}")

# Word-overlap precision/recall/F1 per QA pair, averaged over the set.
# (Scoring exact_match against an all-True target is degenerate: precision is
# 1.0 whenever anything matches and recall just equals accuracy.)
overlap = [
    _token_overlap_scores(pred, ref)
    for pred, ref in zip(eval_df['predicted_answer_blip_clean'], eval_df['ground_truth_answer_clean'])
]
eval_df['token_precision'] = [p for p, _, _ in overlap]
eval_df['token_recall'] = [r for _, r, _ in overlap]
eval_df['token_f1'] = [f for _, _, f in overlap]
print(
    "Token-overlap scores:\n"
    f"Precision={eval_df['token_precision'].mean():.4f}, "
    f"Recall={eval_df['token_recall'].mean():.4f}, "
    f"F1={eval_df['token_f1'].mean():.4f}"
)

# Save the evaluation results once every metric column has been added, so the
# CSV matches the printed metrics.
eval_output_filename = f'/content/blip_eval_results_{i}.csv'
eval_df.to_csv(eval_output_filename, index=False)
print(f"Evaluation results saved to {eval_output_filename}")

# Display some sample results
print("\nSample Evaluation Results:")
print(eval_df.head())

In [None]:
# Uninstall existing torch installation
# NOTE(review): this swaps the session's torch for 2.0.1 (cu118 wheels); a runtime
# restart is presumably needed before the new version is imported — confirm.
!pip uninstall -y torch torchvision torchaudio

# Install a specific version of torch and torchvision that should be compatible
!pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118

# Reinstall the required libraries
!pip install transformers accelerate pillow bert_score


## Fine Tuning

In [None]:
# Fine-tuning dependencies (peft provides LoRA); datasets/bitsandbytes are not referenced in the visible cells.
!pip install peft accelerate transformers datasets bitsandbytes

In [None]:
import os
import pandas as pd
from PIL import Image
from transformers import BlipProcessor, BlipForQuestionAnswering, Trainer, TrainingArguments
from transformers import AutoProcessor, AutoModelForVision2Seq
import torch
from accelerate import Accelerator
#from sklearn.model_selection import train_test_split
from peft import LoraConfig, get_peft_model
from transformers.data.data_collator import default_data_collator

# === CONFIGURATION ===
BASE_IMAGE_PATH = 'images/small'  # Adjust this to match your images/small directory
CSV_PATH = 'merged.csv'  # Path to your CSV file
METADATA_PATH = 'images/metadata/images.csv'  # Path to the metadata CSV
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Blip
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base", use_fast=True)

# # Smol-256
# processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct")
# model = AutoModelForVision2Seq.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct").to(device)


# === LOAD YOUR CURATED CSV ===
# The next cells read its 'image_id' and 'qas' (stringified list of QA dicts) columns.
df = pd.read_csv(CSV_PATH)
print(f"Loaded custom dataset with {len(df)} entries.")

# Load image metadata to map image_id to file paths
try:
    metadata_df = pd.read_csv(METADATA_PATH)
    print(f"Loaded metadata with {len(metadata_df)} images.")
    # image_id -> path relative to BASE_IMAGE_PATH, built column-wise
    # (iterrows is slow on a large metadata table).
    image_id_to_path = dict(zip(metadata_df['image_id'], metadata_df['path']))
except Exception as e:
    print(f"Error loading metadata: {e}")
    # Fallback: empty mapping; VQADataset then guesses each image's path from its id.
    image_id_to_path = {}


In [None]:
import pandas as pd
import ast

# Flatten one row per listing (its 'qas' column holds a stringified list of
# QA dicts) into one row per (image_id, question, answer).
rows = []

for _, row in df.iterrows():
    image_id = row['image_id']
    try:
        qas_list = ast.literal_eval(row['qas'])  # Safer than eval, accepts Python-style lists
    except (ValueError, SyntaxError, TypeError) as e:
        print(f"Failed to parse qas for image_id {image_id}: {e}")
        continue
    if not isinstance(qas_list, (list, tuple)):
        print(f"Failed to parse qas for image_id {image_id}: expected a list, got {type(qas_list).__name__}")
        continue
    for qa in qas_list:
        if not isinstance(qa, dict):
            continue
        # None becomes ''; str() tolerates non-string values such as numeric answers.
        raw_question = qa.get('question')
        raw_answer = qa.get('answer')
        question = '' if raw_question is None else str(raw_question).strip()
        answer = '' if raw_answer is None else str(raw_answer).strip()
        if not question or not answer:
            continue  # nothing to learn from an empty question or answer
        rows.append({
            'image_id': image_id,
            'question': question,
            'answer': answer
        })

flattened_df = pd.DataFrame(rows)

# Only run this if there are actually rows
if not flattened_df.empty:
    flattened_df['answer'] = flattened_df['answer'].fillna('unknown').astype(str)
    flattened_df['image_id'] = flattened_df['image_id'].astype(str)
    print(flattened_df.head())
else:
    print("No valid QAs parsed. Please check the input format.")



In [None]:
# take subset of dataset: disjoint train/eval subsets.
# Sampling twice with the same frac and random_state returns the same rows,
# so the subsets are drawn from disjoint sets of images instead. No image
# (and none of its QAs) appears in both, which avoids train/eval leakage.
unique_image_ids = flattened_df['image_id'].drop_duplicates()
train_image_ids = unique_image_ids.sample(frac=0.1, random_state=42)
remaining_image_ids = unique_image_ids.drop(train_image_ids.index)
eval_image_ids = remaining_image_ids.sample(
    n=min(len(train_image_ids), len(remaining_image_ids)), random_state=42
)
train_df = flattened_df[flattened_df['image_id'].isin(train_image_ids)].copy()
eval_df = flattened_df[flattened_df['image_id'].isin(eval_image_ids)].copy()
print(f"Train QAs: {len(train_df)} | Eval QAs: {len(eval_df)}")

In [None]:

# ============================ ACCELERATOR INIT ==================== BLIP ===================================
accelerator = Accelerator()

# Ensure proper types: answers as strings (NaN -> 'unknown') and image ids as
# strings, presumably to match the keys of image_id_to_path — confirm.
train_df['answer'] = train_df['answer'].fillna('unknown').astype(str)
train_df['image_id'] = train_df['image_id'].astype(str)

# === TRAIN-TEST SPLIT ===
# No split happens in this cell: train_df comes from the sampling cell above;
# only its size is reported here.
print(f"Train size: {len(train_df)}")# | Test size: {len(test_df)}")

# === DEFINE CUSTOM DATASET ===
class VQADataset(torch.utils.data.Dataset):
    """Torch dataset of BLIP encodings for (image, question, answer) rows.

    Each item is the processor's encoding of the image and question (text
    padded/truncated to 128 tokens) with the batch dimension squeezed, plus
    ``labels``: the answer's token ids padded/truncated to 32 tokens.

    Args:
        df: DataFrame with 'image_id', 'question' and 'answer' columns.
        processor: Processor called with images/text; must expose ``.tokenizer``.
        image_base_path: Root directory of the images.
        image_id_to_path: Optional mapping image_id -> path relative to
            ``image_base_path``.
    """

    def __init__(self, df, processor, image_base_path, image_id_to_path=None):
        self.df = df
        self.processor = processor
        self.image_base_path = image_base_path
        self.image_id_to_path = image_id_to_path or {}

    def __len__(self):
        return len(self.df)

    def _resolve_image_path(self, image_id):
        """Return the path to load for ``image_id``.

        Tries, in order: the metadata mapping; the guess
        ``<base>/<first two chars of id>/<id>.jpg``; a recursive search for a
        file whose name contains the id (slow on a large image tree). Falls
        back to the guess when nothing is found, so the caller's load fails
        and substitutes its placeholder image.
        """
        if image_id in self.image_id_to_path:
            return os.path.join(self.image_base_path, self.image_id_to_path[image_id])

        # Guess based on the first two characters of the id
        # (adjust to the actual naming convention if needed).
        guess = os.path.join(self.image_base_path, image_id[:2], f"{image_id}.jpg")
        if os.path.exists(guess):
            return guess

        # Last resort: recursive search; stop at the first match.
        for root, _, files in os.walk(self.image_base_path):
            for file_name in files:
                if image_id in file_name:
                    return os.path.join(root, file_name)
        return guess

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_id = row['image_id']
        full_image_path = self._resolve_image_path(image_id)

        try:
            image = Image.open(full_image_path).convert("RGB")
        except Exception as e:
            print(f"Failed to load {full_image_path} for image_id {image_id}: {e}")
            image = Image.new("RGB", (224, 224), (0, 0, 0))  # Fallback image

        encoding = self.processor(
            images=image,
            text=row['question'],
            padding="max_length",
            max_length=128,
            truncation=True,
            return_tensors="pt",
            return_attention_mask=True
        )

        labels = self.processor.tokenizer(
            row['answer'],
            padding="max_length",
            truncation=True,
            max_length=32,
            return_tensors="pt"
        )["input_ids"]

        # Drop the batch dimension added by return_tensors="pt".
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}
        encoding["labels"] = labels.squeeze(0)
        return encoding

# Add verification functions here
def verify_dataset_images(dataset, num_samples=5):
    """Print a sanity report for a few randomly chosen dataset samples.

    For each sample, prints the question/answer/image id from ``dataset.df``
    and the shapes of 'pixel_values', 'input_ids' and 'labels' in the
    processed item. Errors are printed per sample rather than raised. The
    indices come from a private RNG seeded with 42: they are reproducible and
    identical to seeding the global ``random`` module, but the global random
    state is left untouched.

    Args:
        dataset: Object with ``df`` (supporting ``.iloc[idx]``), ``__len__``
            and ``__getitem__`` returning a dict of tensors.
        num_samples: Maximum number of samples to inspect.

    Returns:
        True (always; the outcome is only printed).
    """
    import random

    print("\n=== DATASET VERIFICATION ===")
    print(f"Dataset contains {len(dataset)} samples")

    # Check a few random samples (reproducible; global RNG not reseeded)
    rng = random.Random(42)
    sample_indices = rng.sample(range(len(dataset)), min(num_samples, len(dataset)))

    for i, idx in enumerate(sample_indices):
        try:
            # Get the original data row
            row = dataset.df.iloc[idx]
            print(f"\nSample {i+1}/{len(sample_indices)}:")
            print(f"  Question: {row['question']}")
            print(f"  Answer: {row['answer']}")
            print(f"  Image ID: {row['image_id']}")

            # Try to get the processed item
            item = dataset[idx]
            if 'pixel_values' in item:
                pixel_shape = item['pixel_values'].shape
                print(f"  Image loaded successfully with shape: {pixel_shape}")
            else:
                print("  Warning: No pixel_values in processed item")

            if 'input_ids' in item:
                input_length = item['input_ids'].shape[0]
                print(f"  Question tokenized to {input_length} tokens")
            else:
                print("  Warning: No input_ids in processed item")

            if 'labels' in item:
                label_length = item['labels'].shape[0]
                print(f"  Answer tokenized to {label_length} tokens")
            else:
                print("  Warning: No labels in processed item")

            print("  Sample loaded successfully!")
        except Exception as e:
            print(f"  Error processing sample {idx}: {e}")

    print("\n=== VERIFICATION COMPLETE ===\n")
    return True

# === CREATE DATASET INSTANCE ===
train_dataset = VQADataset(train_df, processor, BASE_IMAGE_PATH, image_id_to_path)

# Verify that the dataset is working properly
verify_dataset_images(train_dataset)

# === APPLY LoRA TO MODEL ===
# Rank-16 adapters (alpha 32, dropout 0.1) on the modules named "query" and "value".
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["query", "value"],
    lora_dropout=0.1,
    bias="none"
)
model = get_peft_model(model, lora_config)
print("LoRA applied.")

# === PREPARE MODEL FOR ACCELERATION ===
# NOTE(review): Trainer sets up its own Accelerator; preparing the model here
# may be redundant — confirm before removing.
model = accelerator.prepare(model)

# === DEFINE TRAINING ARGUMENTS ===
# Effective batch per device: 32 samples x 4 accumulation steps = 128.
training_args = TrainingArguments(
    output_dir="./results",
    run_name="blip_vqa_lora_finetune_curated",
    num_train_epochs=3,
    per_device_train_batch_size=32,
    gradient_accumulation_steps=4,
    learning_rate=5e-4,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    save_strategy="epoch",
    # fp16 mixed precision requires a CUDA device; use fp32 on CPU-only runtimes.
    fp16=torch.cuda.is_available(),
    remove_unused_columns=False,
    report_to="none"
)

# === DEFINE A VALIDATION CALLBACK ===
from transformers import TrainerCallback

class ValidationCallback(TrainerCallback):
    """Every ``interval`` optimizer steps, print a sample generation and the latest loss.

    This is a smoke test rather than a real validation pass: it always asks the
    same question about a synthetic 224x224 solid-colour image, so it only shows
    that generation still runs and how the model currently answers.

    Args:
        processor: Processor used to encode the image/prompt and decode the
            generated ids (presumably the BLIP processor used for training).
        interval: Run the check whenever ``state.global_step`` is a positive
            multiple of this value.
    """

    def __init__(self, processor, interval=500):
        self.processor = processor
        self.interval = interval

    def on_step_end(self, args, state, control, **kwargs):
        """Run the generation smoke test at positive multiples of ``interval``."""
        if state.global_step <= 0 or state.global_step % self.interval != 0:
            return
        model = kwargs.get('model', None)
        if model is None:
            return

        model.eval()
        try:
            with torch.no_grad():
                # Generate a prediction for a simple example
                prompt = "What color is the object in the image?"
                # Use the processor this callback was built with, not whatever
                # global `processor` happens to be bound when the step ends.
                inputs = self.processor(images=Image.new("RGB", (224, 224), (100, 150, 200)),
                                        text=prompt, return_tensors="pt")
                inputs = {k: v.to(model.device) for k, v in inputs.items()}

                # Generate output
                generated_ids = model.generate(**inputs, max_length=20)
                generated_text = self.processor.decode(generated_ids[0], skip_special_tokens=True)

                # The most recent log entry is not guaranteed to exist or to
                # carry a "loss" key, so search backwards for the last one.
                last_loss = next(
                    (entry["loss"] for entry in reversed(state.log_history) if "loss" in entry),
                    None,
                )

                print(f"\n=== VALIDATION AT STEP {state.global_step} ===")
                print(f"Q: {prompt}")
                print(f"A: {generated_text}")
                if last_loss is None:
                    print("Current training loss: n/a")
                else:
                    print(f"Current training loss: {last_loss:.4f}")
                print("=== END VALIDATION ===\n")
        finally:
            # Always hand the model back to the Trainer in training mode,
            # even if generation raised.
            model.train()

# === TRAINER SETUP ===
# Periodic generation smoke test (default interval: 500 steps) driven by the
# same processor used for the dataset.
validation_callback = ValidationCallback(processor)

# Batches are collated with transformers' default_data_collator, which stacks
# the per-sample tensors produced by VQADataset.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=default_data_collator,
    train_dataset=train_dataset,
    callbacks=[validation_callback],
)

# # === GPU INFO ===
# if torch.cuda.is_available():
#     print("GPU GPU Memory Usage Before Training:")
#     print(torch.cuda.memory_summary())

# # === START TRAINING ===
# trainer.train()

# # === SAVE MODEL ===
# trainer.save_model("./blip_vqa_lora_r_16")
# print("Model saved to './blip_vqa_lora_r_16'")


In [None]:
# # ========================= SMOL PART (IGNORE) ======================================

# import os
# import pandas as pd
# from PIL import Image
# from transformers import BlipProcessor, BlipForQuestionAnswering, Trainer, TrainingArguments
# from transformers import AutoProcessor, AutoModelForVision2Seq
# import torch
# from accelerate import Accelerator
# from peft import LoraConfig, get_peft_model
# from transformers.data.data_collator import default_data_collator
# import ast
# import re
# import time
# from tqdm import tqdm
# import requests
# import base64
# from io import BytesIO
# import matplotlib.pyplot as plt
# from torch.nn.utils.rnn import pad_sequence
# from transformers.data.data_collator import DataCollatorMixin
# from transformers import TrainerCallback


# # === ACCELERATOR INIT ===
# accelerator = Accelerator()

# # Ensure proper types
# train_df['answer'] = train_df['answer'].fillna('unknown').astype(str)
# train_df['image_id'] = train_df['image_id'].astype(str)

# # === TRAIN-TEST SPLIT ===
# print(f"Train size: {len(train_df)}")# | Test size: {len(test_df)}")

# # === DEFINE CUSTOM DATASET ===
# class VQADataset(torch.utils.data.Dataset):
#     def __init__(self, df, processor, image_base_path, image_id_to_path=None):
#         self.df = df
#         self.processor = processor
#         self.image_base_path = image_base_path
#         self.image_id_to_path = image_id_to_path or {}

#     def __len__(self):
#         return len(self.df)

#     def __getitem__(self, idx):
#         row = self.df.iloc[idx]
#         image_id = str(row['image_id'])
#         question = row['question'] or ""
#         answer   = row['answer'] or ""

#         # 1) Locate & load the image
#         if image_id in self.image_id_to_path:
#             rel = self.image_id_to_path[image_id]
#             img_path = os.path.join(self.image_base_path, rel)
#         else:
#             prefix = image_id[:2]
#             img_path = os.path.join(self.image_base_path, prefix, f"{image_id}.jpg")
#             if not os.path.exists(img_path):
#                 for root, _, files in os.walk(self.image_base_path):
#                     for fname in files:
#                         if image_id in fname:
#                             img_path = os.path.join(root, fname)
#                             break

#         try:
#             image = Image.open(img_path).convert("RGB")
#         except:
#             # fallback black image
#             # Use a reasonable default size, e.g., 224x224 or 512x512 based on typical VLM inputs
#             image = Image.new("RGB", (512, 512), (0, 0, 0))

#         # 2) Ensure prompt has one <image> token
#         prompt = f"<image> {question.strip()}"

#         # 3) Encode image + text together
#         # The processor handles resizing/patching. We return the resulting tensor.
#         encoding = self.processor(
#             images=image,
#             text=prompt,
#             padding="max_length",
#             truncation=True,
#             max_length=128,
#             return_tensors="pt",
#             return_attention_mask=True
#         )

#         # 4) Encode the answer separately as labels
#         labels = self.processor.tokenizer(
#             answer,
#             padding="max_length",
#             truncation=True,
#             max_length=32,
#             return_tensors="pt"
#         )["input_ids"]

#         # 5) Squeeze batch dimension
#         item = {k: v.squeeze(0) for k, v in encoding.items()}
#         item["labels"] = labels.squeeze(0)

#         return item


# # Add verification functions here
# def verify_dataset_images(dataset, num_samples=5):
#     """Verify that images are being loaded correctly by checking a few samples"""
#     print("\n=== DATASET VERIFICATION ===")
#     print(f"Dataset contains {len(dataset)} samples")

#     # Check a few random samples
#     import random
#     random.seed(42)  # For reproducibility
#     sample_indices = random.sample(range(len(dataset)), min(num_samples, len(dataset)))

#     for i, idx in enumerate(sample_indices):
#         try:
#             # Get the original data row
#             row = dataset.df.iloc[idx]
#             print(f"\nSample {i+1}/{len(sample_indices)}:")
#             print(f"  Question: {row['question']}")
#             print(f"  Answer: {row['answer']}")
#             print(f"  Image ID: {row['image_id']}")

#             # Try to get the processed item
#             item = dataset[idx]
#             if 'pixel_values' in item:
#                 pixel_shape = item['pixel_values'].shape
#                 print(f"  Image loaded successfully with shape: {pixel_shape}")
#             else:
#                 print("  Warning: No pixel_values in processed item")

#             if 'input_ids' in item:
#                 input_length = item['input_ids'].shape[0]
#                 print(f"  Question tokenized to {input_length} tokens")
#             else:
#                 print("  Warning: No input_ids in processed item")

#             if 'labels' in item:
#                 label_length = item['labels'].shape[0]
#                 print(f"  Answer tokenized to {label_length} tokens")
#             else:
#                 print("  Warning: No labels in processed item")

#             print("  Sample loaded successfully!")
#         except Exception as e:
#             print(f"  Error processing sample {idx}: {e}")

#     print("\n=== VERIFICATION COMPLETE ===\n")
#     return True

# # === CREATE DATASET INSTANCE ===
# train_dataset = VQADataset(train_df, processor, BASE_IMAGE_PATH, image_id_to_path)

# # Verify that the dataset is working properly
# verify_dataset_images(train_dataset)

# # === APPLY LoRA TO MODEL ===
# lora_config = LoraConfig(
#     r=16,
#     lora_alpha=32,
#     target_modules=["q_proj", "v_proj"], # for Smol
#     lora_dropout=0.1,
#     bias="none"
# )
# model = get_peft_model(model, lora_config)
# print("LoRA applied.")

# # === PREPARE MODEL FOR ACCELERATION ===
# model = accelerator.prepare(model)

# # === DEFINE TRAINING ARGUMENTS ===
# training_args = TrainingArguments(
#     output_dir="./results",
#     run_name="smol256_vqa_lora_finetune_curated", # Changed run name
#     num_train_epochs=3,
#     per_device_train_batch_size=2,
#     gradient_accumulation_steps=1,
#     learning_rate=5e-4,
#     weight_decay=0.01,
#     logging_dir='./logs',
#     logging_steps=10,
#     save_strategy="epoch",
#     fp16=True,
#     remove_unused_columns=False,
#     report_to="none"
# )


# # === DEFINE COLLATOR ===
# from torch.nn.utils.rnn import pad_sequence
# from transformers.data.data_collator import default_data_collator

# class VQACollator:
#     def __init__(self, processor):
#         self.processor = processor

#     def __call__(self, features):
#         # 1) Stack all pixel_values into (batch, C, H, W)
#         pixel_values = torch.stack([f.pop("pixel_values") for f in features])

#         # 2) Pad input_ids and attention_mask to the same length in the batch
#         input_ids      = pad_sequence(
#                              [f.pop("input_ids")      for f in features],
#                              batch_first=True,
#                              padding_value=self.processor.tokenizer.pad_token_id
#                          )
#         attention_mask = pad_sequence(
#                              [f.pop("attention_mask") for f in features],
#                              batch_first=True,
#                              padding_value=0
#                          )

#         # 3) Pad labels (and use -100 to ignore them in loss)
#         labels = pad_sequence(
#                      [f.pop("labels") for f in features],
#                      batch_first=True,
#                      padding_value=-100
#                  )

#         return {
#             "pixel_values":   pixel_values,
#             "input_ids":      input_ids,
#             "attention_mask": attention_mask,
#             "labels":         labels,
#         }


# # === DEFINE A VALIDATION CALLBACK ===
# class ValidationCallback(TrainerCallback):
#     def __init__(self, processor, interval=500):
#         self.processor = processor
#         self.interval = interval

#     def on_step_end(self, args, state, control, **kwargs):
#         if state.global_step % self.interval == 0 and state.global_step > 0:
#             model = kwargs.get('model', None)
#             if model is None:
#                 return

#             model.eval()
#             with torch.no_grad():
#                 # Generate a prediction for a simple example
#                 prompt = "What color is the object in the image?"
#                 # Use a consistent dummy image size that reflects what the model expects after processing
#                 dummy_image = Image.new("RGB", (512, 512), (100, 150, 200))
#                 inputs = self.processor(images=dummy_image,
#                                   text=prompt, return_tensors="pt")
#                 inputs = {k: v.to(model.device) for k, v in inputs.items()}

#                 # Generate output
#                 generated_ids = model.generate(**inputs, max_length=20)
#                 generated_text = self.processor.decode(generated_ids[0], skip_special_tokens=True)

#                 print(f"\n=== VALIDATION AT STEP {state.global_step} ===")
#                 print(f"Q: {prompt}")
#                 print(f"A: {generated_text}")
#                 print(f"Current training loss: {state.log_history[-1]['loss']:.4f}")
#                 print(f"=== END VALIDATION ===\n")

#             model.train()

# # === TRAINER SETUP ===
# validation_callback = ValidationCallback(processor)
# trainer = Trainer(
#     model=model,
#     args=training_args,
#     train_dataset=train_dataset,
#     data_collator= VQACollator(processor), # Use the custom collator instance
#     callbacks=[validation_callback]
# )

# # === GPU INFO ===
# if torch.cuda.is_available():
#     print("GPU GPU Memory Usage Before Training:")
#     print(torch.cuda.memory_summary())

# # === START TRAINING ===
# trainer.train()

# # === SAVE MODEL ===
# trainer.save_model("./smol256_vqa_lora_r_16")
# print("Model saved to './smol256_vqa_lora_r_16'")

In [None]:
from transformers import BlipProcessor, BlipForQuestionAnswering
from peft import PeftModel
from tqdm import tqdm
import pandas as pd
from torch.utils.data import DataLoader


# === Load processor and base model ===
base_model = "Salesforce/blip-vqa-base"
processor = BlipProcessor.from_pretrained(base_model)
base = BlipForQuestionAnswering.from_pretrained(base_model)

# === Load the LoRA-adapted model ===
# Adapter weights are read from the directory written by the (currently
# commented-out) trainer.save_model("./blip_vqa_lora_r_16") call above.
model = PeftModel.from_pretrained(base, "./blip_vqa_lora_r_16")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

# === LOAD EVALUATION DATA ===
# eval_df must already exist with 'image_id', 'question', 'answer' columns,
# built the same way as train_df. Example:
# eval_df = pd.read_csv("path/to/eval.csv")

eval_df['answer'] = eval_df['answer'].fillna('unknown').astype(str)
eval_df['image_id'] = eval_df['image_id'].astype(str)
# Questions are passed straight to the processor in the evaluation loop; a
# missing (NaN) question would break tokenization, so coerce them to strings.
eval_df['question'] = eval_df['question'].fillna('').astype(str)


# === INIT DATASET & DATALOADER ===
# NOTE(review): eval_dataset is not consumed by the evaluation loop below,
# which loads images and encodes inputs per row itself; kept for ad-hoc use.
eval_dataset = VQADataset(eval_df, processor, BASE_IMAGE_PATH, image_id_to_path)

from PIL import Image
from tqdm import tqdm
import os

def _resolve_image_path(image_id, base_path, id_to_path):
    """Return the on-disk path to use for *image_id*.

    Lookup order (same as the original inline logic of this loop):
      1. ``id_to_path`` mapping (path relative to ``base_path``);
      2. ``<base_path>/<first two chars of id>/<id>.jpg`` if that file exists;
      3. otherwise the first file under ``base_path`` whose name contains
         ``image_id``.
    If nothing matches, the (non-existent) path from step 2 is returned, so
    the caller's image loading fails and falls back to a placeholder image.
    """
    if image_id in id_to_path:
        return os.path.join(base_path, id_to_path[image_id])

    candidate = os.path.join(base_path, image_id[:2], f"{image_id}.jpg")
    if os.path.exists(candidate):
        return candidate

    # Recursive fallback. Return on the first match: a bare `break` in the
    # nested loop only left the inner loop, so the walk continued over the
    # whole tree and later matches silently overwrote the first one.
    for root, _, files in os.walk(base_path):
        for fname in files:
            if image_id in fname:
                return os.path.join(root, fname)
    return candidate


predictions = []
ground_truths = []
image_ids = []

for _, row in tqdm(eval_df.iterrows(), total=len(eval_df), desc="Evaluating"):
    image_id = row["image_id"]
    question = row["question"]
    answer = row["answer"]

    image_path = _resolve_image_path(image_id, BASE_IMAGE_PATH, image_id_to_path)

    # Load image; a missing or unreadable file is reported and replaced by a
    # black placeholder so one bad row does not abort the whole evaluation.
    try:
        image = Image.open(image_path).convert("RGB")
    except Exception as e:
        print(f"Failed to load image {image_id}: {e}")
        image = Image.new("RGB", (224, 224), (0, 0, 0))  # fallback

    # Preprocess and generate prediction
    inputs = processor(images=image, text=question, return_tensors="pt").to(device)

    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_length=20)
        pred_answer = processor.decode(generated_ids[0], skip_special_tokens=True).strip()

    predictions.append(pred_answer)
    ground_truths.append(answer.strip())
    image_ids.append(image_id)


In [None]:

from sklearn.metrics import precision_recall_fscore_support
import pandas as pd

# Normalize text
def normalize(text):
    """Canonicalise an answer string for exact-match comparison.

    Lower-cases *text* and trims leading/trailing whitespace; inner
    whitespace and punctuation are left untouched.
    """
    lowered = text.lower()
    return lowered.strip()

from collections import Counter

# Normalize once; both exact match and token overlap use the same strings
# (case-insensitive, surrounding whitespace ignored).
y_true = [normalize(g) for g in ground_truths]
y_pred = [normalize(p) for p in predictions]

# Exact match
exact_match = [int(p == g) for p, g in zip(y_pred, y_true)]
# Guard against an empty evaluation set (e.g. the loop was interrupted before
# any prediction was recorded) instead of raising ZeroDivisionError.
accuracy = sum(exact_match) / len(exact_match) if exact_match else 0.0


def _token_prf(pred, gold):
    """Return (precision, recall, f1) of whitespace-token overlap between two normalized answers."""
    pred_tokens = pred.split()
    gold_tokens = gold.split()
    if not pred_tokens or not gold_tokens:
        # Both empty is a perfect match; exactly one empty is a total miss.
        score = float(pred_tokens == gold_tokens)
        return score, score, score
    common = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if common == 0:
        return 0.0, 0.0, 0.0
    precision = common / len(pred_tokens)
    recall = common / len(gold_tokens)
    return precision, recall, 2 * precision * recall / (precision + recall)


# Precision, Recall, F1 Score: computed per sample on token overlap (gives
# partial credit for answers such as "dark blue" vs "blue"), then averaged
# over samples.
token_scores = [_token_prf(p, g) for p, g in zip(y_pred, y_true)]
num_scored = len(token_scores)
mean_precision = sum(s[0] for s in token_scores) / num_scored if num_scored else 0.0
mean_recall = sum(s[1] for s in token_scores) / num_scored if num_scored else 0.0
mean_f1 = sum(s[2] for s in token_scores) / num_scored if num_scored else 0.0

# Create the DataFrame. The question column is converted to a plain list so
# every column is aligned positionally (the other columns are lists built in
# eval_df row order).
results_df = pd.DataFrame({
    "image_id": image_ids,
    "question": eval_df["question"].tolist(),
    "ground_truth": ground_truths,
    "prediction": predictions,
    "exact_match": exact_match,
    "token_precision": [s[0] for s in token_scores],
    "token_recall": [s[1] for s in token_scores],
    "token_f1": [s[2] for s in token_scores],
})

# Print summary metrics
print(f"\nExact Match Accuracy: {accuracy * 100:.2f}%")
print(f"Token Precision: {mean_precision * 100:.2f}% | "
      f"Token Recall: {mean_recall * 100:.2f}% | "
      f"Token F1: {mean_f1 * 100:.2f}%")

# Print and save results
print(results_df.head())
results_df.to_csv("vqa_eval_results.csv", index=False)

