# In [None]:
import os
import re
import json
import torch
import pandas as pd
from pathlib import Path
from PIL import Image
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
import tqdm
# =========================
# CONFIG
# =========================
# Hugging Face checkpoint loaded in the MODEL LOAD section.
model_id = "Qwen/Qwen3-VL-8B-Instruct"

# Upper bound on tokens generated per panel.
max_new_tokens = 1024

# Years of the sub-images in each panel, in left-to-right order
# (interpolated into the prompt and used for the roi_<year> columns).
years = [
    2014,
    2016,
    2018,
    2020,
    2022,
    2024,
    2025,
]

# Cap on the number of panels to process; use a small number (e.g. 20)
# for a test run, or None to process everything.
limit = None

# Input: directory of panel PNGs. Output: CSV with one row per panel.
panels_dir = Path("/home/rishabh.mondal/Brick-Kilns-project/ijcai_2025_kilns/temporal-analysis/data/panels_with_polygon")
out_csv = "data/vlm_kiln_appearance_results.csv"

# =========================
# MODEL LOAD
# =========================
# Processor for the same checkpoint; run_one uses it to build the chat-template
# inputs and to decode the generated tokens.
processor = AutoProcessor.from_pretrained(model_id)

# Model weights; both dtype and device placement are left to "auto".
model = Qwen3VLForConditionalGeneration.from_pretrained(
    model_id,
    dtype="auto",
    device_map="auto",
)

# =========================
# PROMPT TEMPLATE
# =========================
# One schema line per entry in `years`, so the JSON keys requested from the
# model always match the years listed in the prompt and the roi_<year> columns
# read back in the main loop. Entries are comma-separated, no trailing comma.
_roi_schema_lines = ",\n".join(
    f'     "{year}": "<present|absent|unclear>"' for year in years
)

# Instruction text sent with every panel. Doubled braces are literal braces in
# the f-string; single-brace fields are filled from `years`.
PROMPT = f"""
You are analyzing a temporal panel of satellite images of the SAME location.
The panel contains {len(years)} sub-images in ONE ROW (left to right) for years:
{years}.

Focus ONLY on the kiln area inside/along the red boundary (treat it as the ROI).
Ignore surroundings outside the red boundary.

Task:
1) Compare each year with the previous year within the ROI.
2) Identify the earliest year where a kiln structure (oval/rectangular; red/brown/gray kiln-like morphology) first appears in the ROI.
3) If no kiln appears in any year in the ROI, return "no kiln present".

Output STRICTLY as JSON:
{{
  "appearance_year": <year or "no kiln present">,
  "roi_state_by_year": {{
{_roi_schema_lines}
  }},
  "confidence": "<high|medium|low>"
}}

Rules:
Use only visual evidence from the ROI.
If uncertain, use "unclear" and lower confidence.
"""

# =========================
# JSON EXTRACTION (robust)
# =========================
def extract_json(text: str):
    """Extract the first JSON object embedded in a model response.

    The response may wrap the JSON in prose or markdown fences, and may be
    followed by further text that itself contains braces. Each ``{`` is tried
    in order as the start of an object, and parsing stops at that object's
    own closing brace, so trailing text cannot spoil the match.

    For every start position the text is tried verbatim first, then with
    newlines replaced by spaces (which repairs raw line breaks inside string
    values). An outer object is therefore always preferred over a nested one.

    Args:
        text: Raw decoded model output.

    Returns:
        The parsed ``dict``, or ``None`` when no JSON object can be decoded.
        Non-object JSON (a bare string, number or list) is never returned,
        so callers can safely use ``.get`` on the result.
    """
    text = text.strip()
    # Same length as `text` (one char swapped for one char), so brace
    # positions found in `text` are valid offsets into `cleaned` as well.
    cleaned = text.replace("\n", " ")
    decoder = json.JSONDecoder()

    start = text.find("{")
    while start != -1:
        for candidate in (text, cleaned):
            try:
                # raw_decode parses one value at `start` and ignores whatever
                # follows it; starting on "{" it can only yield a dict.
                obj, _ = decoder.raw_decode(candidate, start)
            except (ValueError, RecursionError):
                continue
            return obj
        start = text.find("{", start + 1)
    return None

# =========================
# RUN ONE IMAGE
# =========================
def run_one(panel_path: Path):
    """Run the vision-language model on one panel image and parse its answer.

    Args:
        panel_path: Path to the panel image file.

    Returns:
        A ``(out_text, data)`` tuple: ``out_text`` is the decoded model
        output and ``data`` is whatever ``extract_json`` returns for it
        (``None`` when no JSON could be parsed).
    """
    # convert() returns an independent copy, so the source file can be closed
    # immediately instead of staying open until garbage collection.
    with Image.open(panel_path) as src:
        img = src.convert("RGB")

    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": img},
            {"type": "text", "text": PROMPT},
        ],
    }]

    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    # Move every input tensor onto the device the model is on.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.inference_mode():
        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # Drop the leading prompt-length prefix from each generated sequence so
    # only the newly generated tokens are decoded.
    trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)]
    out_text = processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    data = extract_json(out_text)
    return out_text, data

# =========================
# MAIN LOOP
# =========================
# Collect the panels in a stable (sorted) order, optionally capped by `limit`.
panel_files = sorted(panels_dir.glob("*.png"))

if limit is not None:
    panel_files = panel_files[:limit]

# Create the output directory before the expensive inference loop so that a
# missing folder cannot make the final to_csv fail and discard every result.
Path(out_csv).parent.mkdir(parents=True, exist_ok=True)

rows = []
for p in tqdm.tqdm(panel_files):
    lat_lon = p.stem  # filename without .png
    raw = ""  # bound before the try so the except branch can keep any output

    try:
        raw, data = run_one(p)

        # Anything that is not a JSON object cannot be flattened below;
        # record it as a parse failure together with the raw text.
        if not isinstance(data, dict):
            rows.append({
                "lat_lon": lat_lon,
                "panel_path": str(p),
                "appearance_year": "",
                "confidence": "",
                "status": "parse_fail",
                "raw_output": raw
            })
            continue

        appearance_year = data.get("appearance_year", "")
        confidence = data.get("confidence", "")

        # flatten roi_state_by_year into one roi_<year> column per year
        roi = data.get("roi_state_by_year", {}) or {}
        row = {
            "lat_lon": lat_lon,
            "panel_path": str(p),
            "appearance_year": appearance_year,
            "confidence": confidence,
            "status": "ok",
            "raw_output": raw
        }
        for y in years:
            row[f"roi_{y}"] = roi.get(str(y), "")

        # keep change notes as json string; "[]" when the model's answer has
        # no "change_notes" key
        row["change_notes"] = json.dumps(data.get("change_notes", []), ensure_ascii=False)

        rows.append(row)

    except Exception as e:
        # Best-effort batch run: one bad panel must not abort the rest.
        # `raw` is still "" if the failure happened inside run_one, and holds
        # the model output if it happened while post-processing it.
        rows.append({
            "lat_lon": lat_lon,
            "panel_path": str(p),
            "appearance_year": "",
            "confidence": "",
            "status": "exception",
            "error": repr(e),
            "raw_output": raw
        })

df = pd.DataFrame(rows)
df.to_csv(out_csv, index=False)
print("Saved:", out_csv)
if df.empty:
    # With no rows there is no "status" column to summarize.
    print("No panels found in:", panels_dir)
else:
    print(df["status"].value_counts(dropna=False))