# In [None]:
import os
import re
import json
from datetime import datetime
import warnings
from pathlib import Path
warnings.filterwarnings("ignore")

import torch
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader
from scipy import ndimage

from transformers import AutoProcessor, Qwen3VLForConditionalGeneration

import sys
sys.path.append('/home/jjs2403/2026_bootcamp_02/models/sam2')
from sam2.sam2_image_predictor import SAM2ImagePredictor
from omegaconf import OmegaConf
from hydra.utils import instantiate


# =============================================================================
# 0) Seed / Device
# =============================================================================
def set_seed(seed=42):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices),
    switches cuDNN to deterministic mode with autotuning disabled, and
    records the seed in the ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: Integer seed applied to all generators.
    """
    import random

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(seed)

# Seed all RNGs once, before any model or data object is created.
set_seed(42)
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


# =============================================================================
# 1) Paths (left unchanged, as requested)
# =============================================================================
# Project root; the dataset and the model weights are resolved relative to it.
BASE_DIR = Path('/home/jjs2403/2026_bootcamp_02')
DATASET_DIR = BASE_DIR / 'dataset/challenge_datasets/challenge2'
# Folder holding the test images that are segmented below.
TEST_IMG_DIR = DATASET_DIR / 'input_images'

# All artifacts (submission CSV, debug JSON, visualizations) are written under
# the current user's home directory; the folder is created if it is missing.
OUTPUT_DIR = Path.home() / 'challenge2_outputs'
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Timestamp of this run; save_viz() prints it in the first panel's title.
RUN_TAG = datetime.now().strftime("%Y%m%d_%H%M%S")
# When True, save_viz() deletes an existing visualization file before saving.
VIZ_FORCE_OVERWRITE = True


# =============================================================================
# 2) Load Models (Qwen3 VL / SAM2.1)
# =============================================================================
# Vision-language model used to propose the ROI box and the point prompts.
MODEL_PATH = BASE_DIR / 'models/Qwen3-VL-8B-Instruct'
vl_model = Qwen3VLForConditionalGeneration.from_pretrained(
    str(MODEL_PATH),
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    device_map='auto',
    # FlashAttention-2 needs a CUDA device and half precision; on the
    # CPU/float32 path fall back to PyTorch's built-in SDPA attention so that
    # loading does not fail.
    attn_implementation='flash_attention_2' if torch.cuda.is_available() else 'sdpa',
)
vl_processor = AutoProcessor.from_pretrained(str(MODEL_PATH))

# SAM 2.1 checkpoint and the Hydra/OmegaConf config describing its architecture.
SAM2_1_MODEL_PATH = BASE_DIR / 'models/SAM2.1/weights/sam2.1_hiera_large.pt'
SAM2_1_MODEL_CFG_PATH = BASE_DIR / 'models/SAM2.1/weights/sam2.1_hiera_l.yaml'

# weights_only=True restricts unpickling to tensors and plain containers, so
# loading the checkpoint cannot execute arbitrary pickled code.
checkpoint = torch.load(SAM2_1_MODEL_PATH, map_location=DEVICE, weights_only=True)
cfg = OmegaConf.load(SAM2_1_MODEL_CFG_PATH)
# Build the network from the config, then load the weights stored under the
# checkpoint's 'model' key.
sam_model = instantiate(cfg.model, _recursive_=True)
sam_model.load_state_dict(checkpoint['model'])
sam_model = sam_model.to(DEVICE).eval()
sam_predictor = SAM2ImagePredictor(sam_model)


# =============================================================================
# 3) Dataset / Loader
# =============================================================================
class MarineTestDataset(Dataset):
    """Dataset yielding ``(PIL RGB image, image id)`` pairs from one folder.

    Only files whose name ends in ``.jpg``, ``.png`` or ``.jpeg``
    (case-insensitive) are listed, in sorted filename order. The image id is
    the filename without its extension.
    """

    def __init__(self, img_dir):
        """Index the image files found directly inside ``img_dir``."""
        self.img_dir = img_dir
        self.img_files = sorted([
            f for f in os.listdir(img_dir)
            if f.lower().endswith(('.jpg', '.png', '.jpeg'))
        ])

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.img_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return ``(image, img_id)``."""
        img_name = self.img_files[idx]
        img_path = os.path.join(self.img_dir, img_name)
        # convert() returns an independent in-memory image, so the file handle
        # can be released right away instead of lingering until GC.
        with Image.open(img_path) as handle:
            image = handle.convert('RGB')
        img_id = os.path.splitext(img_name)[0]
        return image, img_id

# Stop immediately if the test image folder is missing.
assert TEST_IMG_DIR.exists(), f"TEST_IMG_DIR not found: {TEST_IMG_DIR}"

test_dataset = MarineTestDataset(img_dir=str(TEST_IMG_DIR))
# One image per step, in dataset order. The collate_fn unwraps the
# single-element batch so that iterating the loader yields the dataset's
# (image, img_id) tuple directly.
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=lambda x: x[0]
)


# =============================================================================
# 4) Utils: clamp / robust JSON extraction
# =============================================================================
def clamp_int(v, lo, hi):
    """Round ``v`` to the nearest integer and limit it to ``[lo, hi]``.

    The upper bound is applied first and the lower bound last, so ``lo``
    wins if the two bounds are ever inverted.
    """
    value = int(round(v))
    if value > hi:
        value = hi
    if value < lo:
        value = lo
    return int(value)

def clamp_bbox_xyxy(bbox, w, h):
    """Clamp an ``[x1, y1, x2, y2]`` box to a ``w`` x ``h`` image.

    Each coordinate is rounded and limited to the valid pixel index range
    (``0..w-1`` for x, ``0..h-1`` for y). Corners are then ordered so the
    returned list satisfies ``x1 <= x2`` and ``y1 <= y2``.
    """
    raw_x1, raw_y1, raw_x2, raw_y2 = bbox
    max_x, max_y = w - 1, h - 1

    xa = clamp_int(raw_x1, 0, max_x)
    ya = clamp_int(raw_y1, 0, max_y)
    xb = clamp_int(raw_x2, 0, max_x)
    yb = clamp_int(raw_y2, 0, max_y)

    left, right = sorted((xa, xb))
    top, bottom = sorted((ya, yb))
    return [left, top, right, bottom]

def _strip_role_prefix(text: str) -> str:
    t = text.strip()
    if "assistant" in t:
        t = t.split("assistant", 1)[-1].strip()
    t = re.sub(r"^\s*(user|system)\s*:?\s*", "", t, flags=re.IGNORECASE).strip()
    return t

def _extract_first_json_like(text: str) -> str:
    t = text.strip()
    t = re.sub(r"^```(?:json)?", "", t, flags=re.IGNORECASE).strip()
    t = re.sub(r"```$", "", t).strip()

    m_obj = re.search(r"\{.*\}", t, flags=re.DOTALL)
    m_arr = re.search(r"\[.*\]", t, flags=re.DOTALL)

    if m_obj and m_arr:
        cand = m_obj.group(0) if m_obj.start() < m_arr.start() else m_arr.group(0)
    elif m_obj:
        cand = m_obj.group(0)
    elif m_arr:
        cand = m_arr.group(0)
    else:
        return ""

    cand = re.sub(r",\s*([}\]])", r"\1", cand)
    return cand.strip()


# =============================================================================
# 5) Qwen Turn-1: ROI BBox
# =============================================================================
# Prompt for the first Qwen turn: ask for a single, generous ROI box around
# the floating debris region, returned as {"bbox_2d":[x1,y1,x2,y2]}.
# NOTE(review): the prompt requests absolute pixel coordinates and
# parse_bbox_from_text() clamps the reply as pixels; confirm the model really
# answers in pixels rather than on some normalized scale.
BBOX_PROMPT = """You are given an image.
Task: Find ONE large ROI bounding box that covers the main floating debris region on the water surface.
This ROI should include the straw/debris mat area broadly (even if straw is not sharply visible).
Include plastics / styrofoam / branches that are embedded in or near the straw mat.
Exclude: sky, upper land/buildings, distant background.

Important:
- Make the box LARGE enough to cover the entire debris/straw region, not just a small object.
- Prefer over-coverage rather than missing straw mat.

Output ONLY one JSON object:
{"bbox_2d":[x1,y1,x2,y2]}
Coordinates must be in PIXELS (top-left origin).
No extra text.
"""

def parse_bbox_from_text(text: str, w: int, h: int):
    """Extract an ``[x1, y1, x2, y2]`` box from model output.

    First tries to parse a JSON object carrying a ``"bbox_2d"`` list; if that
    fails, falls back to the first bracketed list of four numbers found in
    the text. The resulting box is clamped to the ``w`` x ``h`` image.

    Args:
        text: Raw (or already cleaned) model reply.
        w: Image width in pixels.
        h: Image height in pixels.

    Returns:
        The clamped box as a list of four ints, or ``None`` when no box can
        be recovered.
    """
    t = _strip_role_prefix(text)
    js = _extract_first_json_like(t)
    if not js:
        return None

    try:
        obj = json.loads(js)
        raw = obj.get("bbox_2d") if isinstance(obj, dict) else None
        if isinstance(raw, (list, tuple)) and len(raw) >= 4:
            return clamp_bbox_xyxy([float(x) for x in raw[:4]], w, h)
    except (ValueError, TypeError, OverflowError):
        # Invalid JSON, non-numeric entries, or NaN/Infinity values that
        # cannot be rounded: fall through to the regex-based fallback.
        pass

    # Only well-formed decimals are accepted, so float() below cannot raise
    # on tokens such as "1.2.3" or a lone ".". A leading minus sign is
    # allowed; negative values are clamped to 0 like in the JSON path.
    number = r"(-?(?:\d+\.?\d*|\.\d+))"
    pattern = r"\[\s*" + r"\s*,\s*".join([number] * 4) + r"\s*\]"
    m = re.search(pattern, t)
    if m:
        return clamp_bbox_xyxy([float(g) for g in m.groups()], w, h)

    return None

@torch.no_grad()
def qwen_predict_roi_bbox(image: Image.Image):
    """Ask the vision-language model for one ROI box around the debris.

    Args:
        image: RGB image to analyse.

    Returns:
        ``(bbox, raw_text)`` where ``bbox`` is ``[x1, y1, x2, y2]`` in pixel
        coordinates clamped to the image and ``raw_text`` is the model's
        decoded answer. If no box can be parsed, the lower 60% of the image
        is used as the ROI.
    """
    messages = [{
        'role': 'user',
        'content': [
            {'type': 'image', 'image': image},
            {'type': 'text', 'text': BBOX_PROMPT}
        ]
    }]

    text = vl_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = vl_processor(text=[text], images=[image], return_tensors='pt')

    if torch.cuda.is_available():
        inputs = {k: v.to('cuda') if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

    # Greedy decoding keeps the answer deterministic.
    output_ids = vl_model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False,
        num_beams=1
    )

    # generate() returns prompt + continuation. Decode only the newly
    # generated tokens so the echoed prompt can never leak into parsing and
    # no string-based role stripping is needed.
    prompt_len = inputs['input_ids'].shape[1]
    out = vl_processor.batch_decode(
        output_ids[:, prompt_len:], skip_special_tokens=True
    )[0].strip()

    w, h = image.size
    bbox = parse_bbox_from_text(out, w, h)
    if bbox is None:
        # Fallback ROI: full width, from 40% of the height down to the bottom.
        bbox = [0, int(h * 0.40), w - 1, h - 1]
    return bbox, out


# =============================================================================
# 6) Qwen Turn-2: Points (pos/neg) inside ROI
# =============================================================================
# Extra instruction appended to the points prompt on the retry pass, which the
# main loop triggers when the first mask covers 30% or more of the image.
LOW_COVERAGE_EXTRA = """
ADDITIONAL CONSTRAINT (SECOND PASS ONLY):
- The final mask must cover LESS THAN 30% of the entire image.
- Avoid placing POS points that could make the mask overly large.
- Use NEG points to block broad water areas and large non-target regions.
- Focus POS on compact, dense straw/debris clusters only.
"""

# Prompt for the second Qwen turn. It is filled via str.format with the ROI
# corners (x1, y1, x2, y2); the doubled braces in the OUTPUT line are literal
# braces after formatting. The reply is expected to hold 4 positive and 6
# negative points in full-image pixel coordinates.
POINTS_PROMPT_TEMPLATE = """You are given an image and an ROI bounding box in pixels: ROI=[{x1},{y1},{x2},{y2}].

Inside this ROI, choose points for segmentation guidance for a segmentation model (SAM).

TARGET (POS): the straw/debris mat floating on the water surface (broad region), including embedded small debris.
- If pure-white rigid debris (styrofoam blocks / white plastic chunks) is present, you MUST place at least 2 POS points on it (required). These white objects are NOT glare and must be included in the final mask. If none are visible, do not invent them.
- Provide POS points on the thick interior of the largest continuous straw/debris mat region.
- Prefer dense straw-like texture (not smooth water reflection).
- If possible, keep POS points ~50-80 pixels away from the straw-mat boundary.
- Avoid mixed boundary zones (water + straw edges) and avoid isolated tiny debris pieces.

NON-TARGET (NEG): areas that should NOT be segmented as straw/debris mat.
PRIORITY (most important): SUN-GLARE / SPECULAR REFLECTION WATER
- Place most NEG points on shiny glare patches on water (bright reflection streaks, mirror-like highlights).
- These glare regions are smooth reflective water, NOT straw texture.

SECONDARY NEG (only if needed and clearly non-target):
- Sea foam / bubbly water patches (foam texture)
- Dark shadow water near walls/rocks
- Wall/tires/rocks texture (if inside ROI)

IMPORTANT EXCLUSION for NEG:
- Do NOT place NEG points on pure-white rigid debris objects such as styrofoam blocks or white plastic chunks.
  (They are debris and should remain segmentable; if visible, POS on them is REQUIRED and the final mask must include them.)

COUNTS:
- Provide exactly 4 POS points.
- Provide exactly 6 NEG points.
- List neg_points in priority order (most important first).
- All points must be inside the ROI.
- Use full-image pixel coordinates (not cropped coordinates).

OUTPUT (STRICT JSON ONLY, no extra text):
{{"pos_points":[[x,y],[x,y],[x,y],[x,y]], "neg_points":[[x,y],[x,y],[x,y],[x,y],[x,y],[x,y]]}}
"""

def parse_points_from_text(text: str, bbox, w: int, h: int):
    """Parse positive/negative prompt points from model output.

    Expects a JSON object with ``"pos_points"`` and ``"neg_points"`` lists of
    ``[x, y]`` pairs. Each usable point is rounded and clamped into the ROI.
    Entries that are not two-element pairs of numbers are skipped
    individually, so one malformed point no longer discards the valid ones.
    Missing points are filled with random points inside the ROI drawn from a
    generator seeded with 42, and each list is cut to its fixed length.

    Args:
        text: Raw (or already cleaned) model reply.
        bbox: ROI as ``[x1, y1, x2, y2]`` with integer, ordered corners.
        w: Image width (accepted for interface compatibility; unused).
        h: Image height (accepted for interface compatibility; unused).

    Returns:
        ``(pos_points, neg_points)`` with exactly 4 and 6 ``[x, y]`` int
        pairs, or ``(None, None)`` when the reply holds no usable JSON object.
    """
    x1, y1, x2, y2 = bbox
    t = _strip_role_prefix(text)
    js = _extract_first_json_like(t)
    if not js:
        return None, None

    try:
        obj = json.loads(js)
    except ValueError:
        return None, None
    if not isinstance(obj, dict):
        return None, None

    pos = obj.get("pos_points", [])
    neg = obj.get("neg_points", [])
    if not (isinstance(pos, list) and isinstance(neg, list)):
        return None, None

    def _clamp_points(raw_points):
        """Keep well-formed [x, y] pairs, clamped into the ROI."""
        kept = []
        for p in raw_points:
            if not (isinstance(p, (list, tuple)) and len(p) == 2):
                continue
            try:
                kept.append([clamp_int(p[0], x1, x2), clamp_int(p[1], y1, y2)])
            except (TypeError, ValueError, OverflowError):
                # Non-numeric, NaN or infinite coordinate: drop this point only.
                continue
        return kept

    pos2 = _clamp_points(pos)
    neg2 = _clamp_points(neg)

    # Fixed seed so the padding is reproducible; positives are padded first,
    # then negatives, from the same generator.
    rng = np.random.default_rng(42)

    def _random_pt():
        return [int(rng.integers(x1, x2 + 1)), int(rng.integers(y1, y2 + 1))]

    while len(pos2) < 4:
        pos2.append(_random_pt())
    pos2 = pos2[:4]

    while len(neg2) < 6:
        neg2.append(_random_pt())
    neg2 = neg2[:6]

    return pos2, neg2

@torch.no_grad()
def qwen_predict_points(image: Image.Image, bbox, extra_instruction: str = None):
    """Ask the vision-language model for SAM prompt points inside the ROI.

    Args:
        image: RGB image to analyse.
        bbox: ROI as ``[x1, y1, x2, y2]`` in pixel coordinates.
        extra_instruction: Optional text appended to the prompt (used for the
            low-coverage retry pass).

    Returns:
        ``(pos_points, neg_points, raw_text)``: 4 positive and 6 negative
        ``[x, y]`` points plus the model's decoded answer. If the answer
        cannot be parsed, all points are drawn at random inside the ROI from
        a generator seeded with 42.
    """
    w, h = image.size
    x1, y1, x2, y2 = bbox

    prompt = POINTS_PROMPT_TEMPLATE.format(x1=x1, y1=y1, x2=x2, y2=y2)
    if extra_instruction:
        prompt = prompt + "\n\n" + extra_instruction

    messages = [{
        'role': 'user',
        'content': [
            {'type': 'image', 'image': image},
            {'type': 'text', 'text': prompt}
        ]
    }]

    text = vl_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = vl_processor(text=[text], images=[image], return_tensors='pt')

    if torch.cuda.is_available():
        inputs = {k: v.to('cuda') if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

    # Greedy decoding keeps the answer deterministic.
    output_ids = vl_model.generate(
        **inputs,
        max_new_tokens=384,
        do_sample=False,
        num_beams=1
    )

    # generate() returns prompt + continuation. Decode only the newly
    # generated tokens: the prompt itself contains bracketed lists and
    # caller-supplied text, none of which should ever reach the parser.
    prompt_len = inputs['input_ids'].shape[1]
    out_clean = vl_processor.batch_decode(
        output_ids[:, prompt_len:], skip_special_tokens=True
    )[0].strip()

    pos, neg = parse_points_from_text(out_clean, bbox, w, h)
    if pos is None or neg is None:
        # Unparseable answer: fall back to reproducible random points in the ROI.
        rng = np.random.default_rng(42)
        pos = [[int(rng.integers(x1, x2 + 1)), int(rng.integers(y1, y2 + 1))] for _ in range(4)]
        neg = [[int(rng.integers(x1, x2 + 1)), int(rng.integers(y1, y2 + 1))] for _ in range(6)]

    return pos, neg, out_clean


# =============================================================================
# 7) SAM: box + points -> multimask, select best (area gate + point consistency)
# =============================================================================
def mask_point_hits(mask: np.ndarray, pts):
    """Return the fraction of ``pts`` that fall on truthy mask pixels.

    Points are ``(x, y)`` pairs; each is rounded and clamped into the mask
    bounds before lookup. An empty point list yields ``0.0``.
    """
    total = len(pts)
    if total == 0:
        return 0.0

    height, width = mask.shape
    last_col, last_row = width - 1, height - 1

    inside = 0
    for px, py in pts:
        col = clamp_int(px, 0, last_col)
        row = clamp_int(py, 0, last_row)
        if mask[row, col]:
            inside += 1
    return inside / max(total, 1)

@torch.no_grad()
def sam_segment_box_points(
    image: Image.Image, bbox, pos_pts, neg_pts,
    min_area_ratio=0.003, max_area_ratio=0.95,
    pos_miss_weight=2.0, neg_hit_weight=6.0,
    min_pos_hit=1.0, force_pos=True
):
    """Segment with SAM using a box plus point prompts and pick one mask.

    The predictor is asked for multiple candidate masks. Candidates are
    filtered by how much of the ROI they fill and then ranked by the
    predictor score penalised for missed positive and covered negative points.

    Args:
        image: Input image; converted to a uint8 array for the predictor.
        bbox: ROI as ``[x1, y1, x2, y2]`` (used both as the box prompt and for
            slicing the mask, so the values must be usable as indices).
        pos_pts: List of positive ``[x, y]`` points (label 1).
        neg_pts: List of negative ``[x, y]`` points (label 0).
        min_area_ratio: Candidates whose in-ROI area / ROI area is below this
            are skipped.
        max_area_ratio: Candidates whose in-ROI area / ROI area is above this
            are skipped.
        pos_miss_weight: Penalty per unit fraction of positive points missed.
        neg_hit_weight: Penalty per unit fraction of negative points covered.
        min_pos_hit: Minimum fraction of positive points a candidate must
            cover when ``force_pos`` is True.
        force_pos: Whether to enforce ``min_pos_hit``.

    Returns:
        A boolean mask. Fallback order: best-ranked candidate; otherwise the
        area-valid candidate covering the most positive points; otherwise the
        candidate with the highest predictor score.
    """
    img = np.array(image, dtype=np.uint8)
    # NOTE: H and W are not used further down in this function.
    H, W = img.shape[:2]
    sam_predictor.set_image(img)

    x1, y1, x2, y2 = bbox
    # Inclusive pixel area of the ROI, floored at 1 to avoid division by zero.
    roi_area = float(max(1, (x2 - x1 + 1) * (y2 - y1 + 1)))

    # Positive points first, then negative, with matching 1/0 labels.
    all_pts = np.array(pos_pts + neg_pts, dtype=np.float32)
    all_labels = np.array([1] * len(pos_pts) + [0] * len(neg_pts), dtype=np.int32)
    # A leading axis of size 1 is added to the box, points and labels.
    box = np.array(bbox, dtype=np.float32)[None, :]

    masks, scores, _ = sam_predictor.predict(
        box=box,
        point_coords=all_pts[None, :, :],
        point_labels=all_labels[None, :],
        multimask_output=True
    )

    # best_mask / best_val: winner of the penalised ranking.
    # best_pos_*: area-valid candidate with the highest positive-hit fraction
    # (ties broken by predictor score), kept as a fallback.
    best_mask = None
    best_val = -1e18
    best_pos_mask = None
    best_pos_hit = -1.0
    best_pos_score = -1e18

    # Iterates over the first axis of `masks`, one candidate per index
    # (presumably the multimask candidates — TODO confirm the returned shape).
    for i in range(masks.shape[0]):
        m = masks[i].astype(bool)
        s = float(scores[i])

        # Area gate: measure the mask only inside the ROI (inclusive corners).
        roi_m = m[y1:y2 + 1, x1:x2 + 1]
        area_ratio = float(roi_m.sum()) / roi_area

        if area_ratio < min_area_ratio or area_ratio > max_area_ratio:
            continue

        pos_hit = mask_point_hits(m, pos_pts)
        neg_hit = mask_point_hits(m, neg_pts)
        pos_miss = 1.0 - pos_hit

        # Track the fallback before the positive-hit filter below, so it is
        # available even when every candidate fails that filter.
        if pos_hit > best_pos_hit or (pos_hit == best_pos_hit and s > best_pos_score):
            best_pos_hit = pos_hit
            best_pos_score = s
            best_pos_mask = m

        if force_pos and pos_hit < min_pos_hit:
            continue

        # Predictor score minus penalties for missed POS and covered NEG points.
        val = s - pos_miss_weight * pos_miss - neg_hit_weight * neg_hit
        if val > best_val:
            best_val = val
            best_mask = m

    # No candidate passed every filter: use the best positive-coverage one.
    if best_mask is None and best_pos_mask is not None:
        best_mask = best_pos_mask

    # Every candidate failed the area gate: take the top-scoring mask as is.
    if best_mask is None:
        idx = int(np.argmax(scores))
        best_mask = masks[idx].astype(bool)

    return best_mask


# =============================================================================
# 8) Post-processing
# =============================================================================
def clean_mask_components(mask: np.ndarray, min_area=40):
    """Remove connected components smaller than ``min_area`` pixels.

    Components are found with ``scipy.ndimage.label`` (default connectivity).
    All component sizes are computed in one pass with ``np.bincount`` instead
    of scanning the full image once per component, which keeps the cost
    independent of how many small fragments the mask contains.

    Args:
        mask: Array whose non-zero entries mark foreground.
        min_area: Minimum pixel count a component needs to be kept.

    Returns:
        Boolean array with the same shape as ``mask`` holding only the
        components that are large enough.
    """
    labeled, _ = ndimage.label(mask)
    # sizes[k] is the pixel count of label k; label 0 is the background.
    sizes = np.bincount(labeled.ravel(), minlength=1)
    keep = sizes >= int(min_area)
    keep[0] = False  # never keep the background
    # Look up the keep/drop decision for every pixel's label.
    return keep[labeled]


# =============================================================================
# 9) RLE Encoding
# =============================================================================
def rle_encode(mask):
    """Run-length encode a binary mask as a ``"start length ..."`` string.

    The mask is flattened in row-major order; starts are 1-indexed. Any
    non-zero value counts as foreground, so masks stored as 0/255 or with
    other non-binary values encode the same as their boolean form instead of
    producing a run boundary at every value change.

    Args:
        mask: Torch tensor, NumPy array or nested sequence.

    Returns:
        Space-separated start/length pairs; an empty string for an
        all-background mask.
    """
    if torch.is_tensor(mask):
        mask = mask.detach().cpu().numpy()
    pixels = np.asarray(mask).astype(bool).ravel()
    # Pad with background on both sides so every run has a start and an end.
    padded = np.concatenate(([False], pixels, [False]))
    # 1-indexed positions where the value flips: start, end, start, end, ...
    runs = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Turn each end position into a run length.
    runs[1::2] -= runs[::2]
    return ' '.join(map(str, runs.tolist()))


# =============================================================================
# 10) Visualization (No OpenCV)
# =============================================================================
def save_viz(image: Image.Image, img_id: str, bbox, pos_pts, neg_pts, mask: np.ndarray):
    """Save a three-panel diagnostic figure for one image.

    Panels: (1) the image with the ROI box and the POS/NEG points, (2) the
    mask overlaid on the image, (3) the mask alone, titled with its coverage
    in percent of all pixels.

    Args:
        image: Source image.
        img_id: Identifier used in the title and in the output filename.
        bbox: ROI as ``[x1, y1, x2, y2]``.
        pos_pts: Positive ``[x, y]`` points.
        neg_pts: Negative ``[x, y]`` points.
        mask: Segmentation mask for the image.

    Returns:
        Path of the written PNG (``OUTPUT_DIR / 'viz_<img_id>.png'``).
    """
    import matplotlib.patches as patches

    img = np.array(image)
    x1, y1, x2, y2 = bbox
    coverage = float(mask.sum()) / float(mask.size) * 100.0

    fig, axes = plt.subplots(1, 3, figsize=(24, 8))
    # The figure is closed in `finally`, so a failure while drawing or saving
    # does not leave open figures piling up across the processing loop.
    try:
        axes[0].imshow(img)
        title = f'{img_id} - ROI bbox + points'
        if RUN_TAG:
            title += f'\nrun {RUN_TAG}'
        axes[0].set_title(title, fontsize=14, fontweight='bold')
        axes[0].axis('off')

        # Width/height use +1 because the box corners are inclusive pixels.
        rect = patches.Rectangle((x1, y1), (x2 - x1 + 1), (y2 - y1 + 1),
                                 linewidth=2.5, edgecolor='lime', facecolor='none')
        axes[0].add_patch(rect)

        if len(pos_pts) > 0:
            px = [p[0] for p in pos_pts]
            py = [p[1] for p in pos_pts]
            axes[0].scatter(px, py, s=60, c='dodgerblue', marker='o', label='POS')
        if len(neg_pts) > 0:
            nx = [p[0] for p in neg_pts]
            ny = [p[1] for p in neg_pts]
            axes[0].scatter(nx, ny, s=60, c='orange', marker='x', label='NEG')

        # Only draw a legend when there is at least one labelled artist.
        if len(pos_pts) > 0 or len(neg_pts) > 0:
            axes[0].legend(loc='lower right')

        axes[1].imshow(img)
        axes[1].imshow(mask, alpha=0.55)
        axes[1].set_title('SAM Segmentation (overlay)', fontsize=14)
        axes[1].axis('off')

        axes[2].imshow(mask, cmap='gray')
        axes[2].set_title(f'Mask ({coverage:.2f}%)', fontsize=14, fontweight='bold')
        axes[2].axis('off')

        fig.tight_layout()
        out_path = OUTPUT_DIR / f'viz_{img_id}.png'
        if VIZ_FORCE_OVERWRITE and out_path.exists():
            out_path.unlink()
        fig.savefig(out_path, dpi=150, bbox_inches='tight')
    finally:
        plt.close(fig)
    return out_path


# =============================================================================
# 11) Main loop: ROI BBox -> Points -> SAM -> Postproc -> Viz
# =============================================================================
import traceback

mask_results = []   # one {'ID', 'mask'} entry per image, in loader order
debug_rows = []     # one diagnostics dict per image, dumped to JSON later
viz_paths = []      # paths of the visualizations that were written

for image, img_id in tqdm(test_loader, desc='Segmentation'):
    w, h = image.size
    try:
        # 1) ROI bbox
        bbox, bbox_raw = qwen_predict_roi_bbox(image)

        # 2) Points inside ROI
        pos_pts, neg_pts, points_raw = qwen_predict_points(image, bbox)

        # 3) SAM segment with box+points
        mask = sam_segment_box_points(
            image=image,
            bbox=bbox,
            pos_pts=pos_pts,
            neg_pts=neg_pts,
            # ---- tuning knobs ----
            min_area_ratio=0.003,
            max_area_ratio=0.80,
            pos_miss_weight=2.0,
            neg_hit_weight=2.5,
            min_pos_hit=1.0,
            force_pos=True
        )
        mask = clean_mask_components(mask, min_area=40)

        # Share of all image pixels covered by the mask, in percent.
        coverage = float(mask.sum()) / float(mask.size) * 100.0

        # Retry once if coverage is too high
        retry_used = False
        coverage_retry = None
        points_raw_retry = None

        if coverage >= 30.0:
            retry_used = True
            # Ask for new points with the extra low-coverage instruction and
            # penalise covered negative points more strongly. The retry result
            # replaces the first mask unconditionally.
            pos_pts, neg_pts, points_raw_retry = qwen_predict_points(
                image, bbox, extra_instruction=LOW_COVERAGE_EXTRA
            )
            mask = sam_segment_box_points(
                image=image,
                bbox=bbox,
                pos_pts=pos_pts,
                neg_pts=neg_pts,
                # ---- tuning knobs ----
                min_area_ratio=0.003,
                max_area_ratio=0.97,
                pos_miss_weight=2.0,
                neg_hit_weight=6.0,
                min_pos_hit=1.0,
                force_pos=True
            )
            mask = clean_mask_components(mask, min_area=40)
            coverage = float(mask.sum()) / float(mask.size) * 100.0
            coverage_retry = coverage

    except Exception:
        # Segmentation failed for this image: log it, record an empty mask and
        # move on so a single bad image cannot abort the whole run.
        traceback.print_exc()
        mask_results.append({'ID': img_id, 'mask': np.zeros((h, w), dtype=bool)})
        debug_rows.append({'ID': img_id, 'bbox': None, 'pos_pts': None, 'neg_pts': None, 'coverage_pct': 0.0})
        continue

    mask_results.append({'ID': img_id, 'mask': mask})

    debug_rows.append({
        'ID': img_id,
        'bbox': bbox,
        'pos_pts': pos_pts,
        'neg_pts': neg_pts,
        'coverage_pct': coverage,
        'retry_used': retry_used,
        'coverage_retry': coverage_retry,
        'bbox_raw_head': bbox_raw[:180],
        'points_raw_head': points_raw[:180],
        'points_raw_head_retry': (points_raw_retry[:180] if points_raw_retry else None),
    })

    # Visualization is best-effort and handled separately: a plotting failure
    # must not add a second, empty result for an image whose mask was already
    # stored (which would duplicate its ID in the submission).
    try:
        vp = save_viz(image, img_id, bbox, pos_pts, neg_pts, mask)
        viz_paths.append(str(vp))
    except Exception:
        traceback.print_exc()

# Lookup table from image id to its final mask.
mask_dict = {r['ID']: r['mask'] for r in mask_results}


# =============================================================================
# 12) Submission + Debug dump + Stats
# =============================================================================
# Build the submission rows: one run-length-encoded mask per image id.
submission_data = []
for r in mask_results:
    img_id = r['ID']
    mask = r['mask']
    submission_data.append({'ID': img_id, 'Label': rle_encode(mask)})

submission_df = pd.DataFrame(submission_data, columns=['ID', 'Label'])
submission_path = OUTPUT_DIR / 'submission_qwen_roi_points_sam.csv'
submission_df.to_csv(submission_path, index=False)

# Dump the per-image diagnostics (boxes, points, coverage, raw model text).
debug_path = OUTPUT_DIR / 'debug_qwen_roi_points.json'
with open(debug_path, 'w', encoding='utf-8') as f:
    json.dump(debug_rows, f, ensure_ascii=False, indent=2)

# Coverage statistics over all images with a numeric coverage value.
# NOTE: _avg, _min and _max are computed here but not printed or used below.
coverages = [d.get('coverage_pct', 0.0) for d in debug_rows if isinstance(d.get('coverage_pct', None), (int, float))]
if coverages:
    _avg = float(np.mean(coverages))
    _min = float(np.min(coverages))
    _max = float(np.max(coverages))

# (Minimal logging: print only the lines below to confirm the output files.)
print(f"Saved: {submission_path}")
print(f"Saved: {debug_path}")
