In [1]:
# ==========================================
# CELL 1: ENVIRONMENT SETUP
# ==========================================
import sys
import os
import shutil
import glob
import zipfile
import torch
import pandas as pd
from PIL import Image
from tqdm.notebook import tqdm # Use notebook version for better UI
from google.colab import drive

# Install dependencies (only if missing)
# qwen_vl_utils serves as the sentinel package: if it cannot be imported, the
# whole dependency set below is installed in one pip call.
# NOTE(review): transformers is installed unpinned from the GitHub main branch,
# and torch has already been imported above — a runtime restart may be needed
# before upgraded packages take effect; confirm on a fresh runtime.
try:
    import qwen_vl_utils
except ImportError:
    print("üì¶ Installing Libraries...")
    !pip install -q git+https://github.com/huggingface/transformers accelerate bitsandbytes qwen-vl-utils pandas pillow

# Mount Drive
# Mounting is skipped when the mount point directory already exists.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# Drive folder that later cells read the dataset zips and admire_dataset.py
# from, and write the results TSV to.
project_path = '/content/drive/MyDrive/AdMIRe_Project'
print("‚úÖ Environment Ready.")

üì¶ Installing Libraries...
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m59.1/59.1 MB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m521.0/521.0 kB[0m [31m42.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m40.5/40.5 MB[0m [31m21.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for transformers (pyproject.toml) ... [?25l[?25hdone
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behavi

In [2]:
# ==========================================
# CELL 2: DATA PREPARATION & IMPORT (CORRECTED)
# ==========================================
import pandas as pd
import os
import glob
import zipfile
import shutil

# 1. Helper to unzip data
def prepare_data(drive_folder, extract_to):
    """Make sure the dataset is unpacked under ``extract_to``.

    If any ``.tsv`` file already exists somewhere below ``extract_to`` the
    function assumes the data is in place and does nothing else. Otherwise
    every ``*.zip`` directly inside ``drive_folder`` is extracted into
    ``extract_to``.

    Args:
        drive_folder: Directory that holds the dataset ``.zip`` archives.
        extract_to: Directory to extract into; created if it does not exist.

    Returns:
        ``extract_to``, unchanged, so the call can be chained.

    Raises:
        FileNotFoundError: No TSV is present yet and ``drive_folder`` contains
            no ``.zip`` file.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(extract_to, exist_ok=True)

    # glob.escape stops "[", "]", "*" and "?" in the folder names from being
    # interpreted as wildcards, which would silently hide existing files.
    if glob.glob(f"{glob.escape(extract_to)}/**/*.tsv", recursive=True):
        print(f"‚úÖ Data found in {extract_to}")
        return extract_to

    # Sorted so that, when archives overlap, the extraction order (and hence
    # which copy of a file wins) is the same on every run.
    zips = sorted(glob.glob(f"{glob.escape(drive_folder)}/*.zip"))
    if not zips:
        raise FileNotFoundError(f"‚ö†Ô∏è No .zip files in {drive_folder}")

    print(f"üì¶ Extracting {len(zips)} zip files...")
    for z in zips:
        with zipfile.ZipFile(z, 'r') as zip_ref:
            zip_ref.extractall(extract_to)

    return extract_to

# 2. Data Root Finder (The Fix)
def find_correct_root(base_path):
    """
    Confirm that dataset TSV files exist under ``base_path`` and return it.

    The function searches ``base_path`` recursively for ``.tsv`` files,
    ignoring any whose path *below* ``base_path`` contains "result"
    (case-insensitive). The returned value is always ``base_path`` itself:
    per the original author's note, the reader takes the extraction root as
    'data_root_path' and locates the files with its own recursive search.

    Args:
        base_path: Directory the archives were extracted into.

    Returns:
        ``base_path``, unchanged.

    Raises:
        FileNotFoundError: No non-result TSV file exists under ``base_path``.
    """
    print(f"üîç Searching for TSV root in {base_path}...")

    # Escape the root so glob metacharacters in the folder name are literal.
    candidates = glob.glob(f"{glob.escape(base_path)}/**/*.tsv", recursive=True)

    # Apply the "result" filter to the path relative to base_path. Testing the
    # full path would reject every file whenever base_path itself happens to
    # contain "result" (e.g. /content/results_data).
    tsvs = [
        t for t in candidates
        if "result" not in os.path.relpath(t, base_path).lower()
    ]

    if not tsvs:
        raise FileNotFoundError("‚ùå CRITICAL: No TSV files found after extraction.")

    print(f"‚úÖ Found {len(tsvs)} TSV files. Using base path: {base_path}")
    return base_path

# 3. Unpack the dataset if needed, then confirm the TSVs are reachable.
raw_data_root = prepare_data(project_path, '/content/admire_data')
data_root = find_correct_root(raw_data_root)  # later cells should read this variable

# 4. Copy the custom reader module from Drive into the working directory and import it.
if not os.path.exists(f"{project_path}/admire_dataset.py"):
    raise FileNotFoundError("‚ùå admire_dataset.py not found in Drive!")

shutil.copy(f"{project_path}/admire_dataset.py", ".")
from admire_dataset import AdMIReReader
print("‚úÖ AdMIReReader imported successfully.")

üì¶ Extracting 1 zip files...
üîç Searching for TSV root in /content/admire_data...
‚úÖ Found 1 TSV files. Using base path: /content/admire_data
‚úÖ AdMIReReader imported successfully.


In [6]:
# ==========================================
# CELL 3: MODEL DEFINITION (ADJUSTED FOR LENGTH)
# ==========================================
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
from qwen_vl_utils import process_vision_info
import torch

class QwenVLM:
    """Wrapper around a 4-bit quantized Qwen-VL checkpoint for idiom/image selection.

    Attributes:
        device: "cuda" when CUDA is available, otherwise "cpu". Informational
            only; inference tensors are moved to ``self.model.device``.
        hf_id: Hugging Face model id that was loaded.
        model: The loaded vision-language model.
        processor: The matching processor (chat template, tokenizer, image processor).
    """

    def __init__(self, model_name="qwen3"):
        """Load the model and processor.

        Args:
            model_name: Any string containing "qwen3" selects
                ``Qwen/Qwen3-VL-4B-Instruct``; anything else selects
                ``Qwen/Qwen2.5-VL-3B-Instruct``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.hf_id = "Qwen/Qwen3-VL-4B-Instruct" if "qwen3" in model_name else "Qwen/Qwen2.5-VL-3B-Instruct"

        print(f"ü§ñ Loading {self.hf_id}...")
        # NF4 4-bit weights with float16 compute.
        bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16)

        self.model = AutoModelForVision2Seq.from_pretrained(
            self.hf_id, quantization_config=bnb_config, device_map="auto", trust_remote_code=True
        )
        self.processor = AutoProcessor.from_pretrained(self.hf_id, trust_remote_code=True)
        # Pixel budget per image.
        # NOTE(review): whether assigning these attributes after construction is
        # honoured depends on the installed transformers version — confirm.
        self.processor.image_processor.max_pixels = 336 * 336
        self.processor.image_processor.min_pixels = 224 * 224

    def predict_winner(self, sentence, images, max_new_tokens=300):
        """Ask the model which image best matches the figurative meaning of ``sentence``.

        Args:
            sentence: The idiom / sentence inserted into the prompt.
            images: Sequence of images; each is followed in the prompt by a
                "[Image N]" text tag (N starting at 1).
            max_new_tokens: Generation budget. Defaults to 300, the value the
                notebook previously hard-coded.

        Returns:
            The model's reply only (prompt removed), stripped of surrounding
            whitespace. The prompt asks the model to end with ``[[number]]``.
        """
        content = []
        for i, img in enumerate(images):
            content.append({"type": "image", "image": img})
            content.append({"type": "text", "text": f"[Image {i+1}] "})

        # --- CONCISE PROMPT ---
        prompt = (
            f"\nIdiom: \"{sentence}\"\n"
            "Task: Select the image that best represents the METAPHORICAL meaning.\n"
            "Constraints: 1. Be concise (1-2 sentences per step). 2. Avoid literal traps.\n"
            "Steps:\n"
            "1. Define the abstract meaning.\n"
            "2. Pick the best image.\n"
            "3. Final Answer: MUST end with [[number]], e.g., [[3]]."
        )
        content.append({"type": "text", "text": prompt})

        # Build the single-turn conversation once and reuse it for both helpers.
        conversation = [{"role": "user", "content": content}]

        text = self.processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(conversation)

        inputs = self.processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)

        # generate() returns prompt + completion. Drop the prompt tokens before
        # decoding so the prompt's own "[[3]]" example and "[Image N]" tags can
        # never be mistaken for the answer. The previous approach split the
        # decoded text on "assistant\n" and returned the whole prompt whenever
        # that marker was missing.
        completion_ids = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
        ]

        output = self.processor.batch_decode(completion_ids, skip_special_tokens=True)[0]
        return output.strip()

# Load the checkpoint once ("qwen3" selects Qwen/Qwen3-VL-4B-Instruct); the
# evaluation cell calls predict_winner on this module-level `model`.
model = QwenVLM(model_name="qwen3")

ü§ñ Loading Qwen/Qwen3-VL-4B-Instruct...




Loading weights:   0%|          | 0/713 [00:00<?, ?it/s]

In [7]:
# ==========================================
# CELL 4: EXECUTION (UNIVERSAL PARSER)
# ==========================================
import time
import re
import pandas as pd
from tqdm.notebook import tqdm

# Configuration
LIMIT_ITEMS = None  # None = evaluate everything; an int N = stop after the first N items


def _load_images(image_paths):
    """Open every path as an RGB image, substituting placeholders for bad entries.

    Returns:
        (images, had_missing): ``images`` has one entry per input path. A path
        equal to "MISSING", absent on disk, or unreadable is replaced by a
        224x224 black image and sets ``had_missing`` to True.
    """
    images = []
    had_missing = False
    for path in image_paths:
        if path != "MISSING" and os.path.exists(path):
            try:
                # The context manager releases the file handle; convert()
                # returns an independent in-memory copy.
                with Image.open(path) as img:
                    images.append(img.convert("RGB"))
                continue
            except Exception as exc:
                # Best effort as before, but report why instead of hiding it.
                print(f"Could not read image {path}: {exc}")
        images.append(Image.new('RGB', (224, 224), color='black'))
        had_missing = True
    return images, had_missing


def _parse_choice(reply):
    """Extract the 0-based image index from the model's reply.

    Priority: (1) first ``[[x]]``; (2) last "image/option/choice <digit>";
    (3) last number anywhere; (4) fall back to index 0 when no number exists.
    """
    match = re.search(r"\[\[(\d+)\]\]", reply)
    if match:
        return int(match.group(1)) - 1

    # We hunt for "Image 4", "Option 2", etc. and take the LAST one mentioned
    # (usually the conclusion).
    explicit_matches = re.findall(r"(?:image|option|choice)\s*(\d)", reply.lower())
    if explicit_matches:
        return int(explicit_matches[-1]) - 1

    nums = re.findall(r'\d+', reply)
    if nums:
        return int(nums[-1]) - 1

    return 0


# Initialize Reader
try:
    reader = AdMIReReader(data_root_path=data_root, split="Train", mode="qwen")
except Exception as exc:
    # Fall back to the default extraction folder, but say why: a bare except
    # here would also swallow KeyboardInterrupt and hide the real cause.
    print(f"Reader init with data_root failed ({exc!r}); retrying with /content/admire_data")
    reader = AdMIReReader(data_root_path='/content/admire_data', split="Train", mode="qwen")

print(f"üöÄ Processing {len(reader)} items...")

results = []
report_data = []
save_path = f"{project_path}/results_qwen_train.tsv"

count = 0
for item in tqdm(reader):
    # "is not None" so that LIMIT_ITEMS = 0 means "process nothing" rather
    # than being treated as "no limit".
    if LIMIT_ITEMS is not None and count >= LIMIT_ITEMS:
        break

    start_time = time.time()

    pil_images, missing_flag = _load_images(item['image_paths'])

    # Predict
    prediction_raw = "ERROR"
    top_pred_idx = -1  # stays -1 when an image is missing or prediction fails

    try:
        if not missing_flag:
            prediction_raw = model.predict_winner(item['text'], pil_images)
            top_pred_idx = _parse_choice(prediction_raw)
    except Exception as e:
        print(f"‚ùå Error: {e}")

    # Grade
    is_correct = (top_pred_idx == item['label'])
    elapsed = time.time() - start_time

    # Report columns are 1-based; -1 marks "no prediction".
    report_data.append({
        "True_Answer": item['label'] + 1,
        "Model_Prediction": top_pred_idx + 1 if top_pred_idx != -1 else -1,
        "Result": "TRUE" if is_correct else "FALSE",
        "Time_Sec": round(elapsed, 2)
    })

    results.append({
        "text": item['text'],
        "prediction_raw": prediction_raw,
        "is_correct": is_correct
    })

    # Save checkpoint
    if count % 10 == 0:
        pd.DataFrame(results).to_csv(save_path, sep='\t', index=False)

    count += 1

# Report
df_report = pd.DataFrame(report_data)
print("\n" + "="*40)
print(f"üìä REPORT ({len(df_report)} Items)")
print("="*40)
print(df_report.to_string(index=True))
print("="*40)

if df_report.empty:
    # An empty frame has no "Result" column; indexing it would raise KeyError.
    print("No items were evaluated; accuracy is undefined.")
else:
    accuracy = (df_report["Result"] == "TRUE").mean() * 100
    print(f"üèÜ Final Accuracy: {accuracy:.2f}%")
pd.DataFrame(results).to_csv(save_path, sep='\t', index=False)

üïµÔ∏è Scanning /content/admire_data for Train TSV...
‚úÖ Loaded TSV: /content/admire_data/train/subtask_a_train.tsv
‚úÖ Images located at: /content/admire_data/train/
üöÄ Processing 70 items...


  0%|          | 0/70 [00:00<?, ?it/s]


üìä REPORT (70 Items)
    True_Answer  Model_Prediction Result  Time_Sec
0             1                 2  FALSE     28.94
1             2                 2   TRUE     25.96
2             3                 4  FALSE     26.93
3             4                 4   TRUE     28.92
4             4                 4   TRUE     33.27
5             2                 1  FALSE     33.41
6             5                 4  FALSE     25.04
7             1                 4  FALSE     23.83
8             2                 1  FALSE     29.67
9             4                 3  FALSE     27.68
10            1                 2  FALSE     33.45
11            3                 2  FALSE     32.20
12            5                 3  FALSE     25.83
13            5                 4  FALSE     25.17
14            3                 2  FALSE     25.14
15            1                 3  FALSE     26.69
16            2                 2   TRUE     25.24
17            5                 2  FALSE     26.64
18     