In [None]:
!pip install 'git+https://github.com/facebookresearch/detectron2.git'


In [None]:
#!/usr/bin/env python
# two_view_bmi_res101_res18_letterbox_no_proj_with_buckets_meanstd.py

# Colab-only setup: mount Google Drive at /content/drive so that the paths
# built from BASE_DIR below resolve. force_remount=True asks for a fresh mount
# even if the drive appears to be mounted already.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

import os
import json
import time
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from typing import List, Dict
from PIL import Image, ImageOps
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from sklearn.metrics import r2_score, mean_absolute_error
from scipy.stats import ks_2samp, wasserstein_distance
import matplotlib.pyplot as plt
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2 import model_zoo
from transformers import pipeline

# ----------------------------- CONFIG -------------------------------
# Root folder on the mounted Drive; every path below is built from it.
BASE_DIR         = "/content/drive/MyDrive"
# JSON list of samples; the code in this file reads the "img_paths", "bmi"
# and "split" keys of each entry.
SAMPLE_JSON_PATH = os.path.join(BASE_DIR, "sample_list.json")
# Local checkpoint files loaded into the ResNet backbones by MultiBranchNoProj.
RESNET101_PATH   = os.path.join(BASE_DIR, "resnet101-63fe2227.pth")
RESNET18_PATH    = os.path.join(BASE_DIR, "resnet18-5c106cde.pth")
# Mask caches: written by build_mask_caches, read by TwoViewMaskDataset.
HUMAN_NPZ_PATH   = os.path.join(BASE_DIR, "human_mask_cache_224.npz")
REF_NPZ_PATH     = os.path.join(BASE_DIR, "ref_mask_cache_224.npz")
# Text prompts handed to the reference-object detector; one mask channel is
# produced per label, so the list length is also the ref-branch channel count.
REF_LABELS       = ["a dark paper", "a white ball"]
# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE           = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Enable cuDNN's benchmark mode (algorithm auto-selection for convolutions).
torch.backends.cudnn.benchmark = True
# Spatial size (width, height) of the letterboxed masks and resized RGB inputs.
OUT_W, OUT_H     = 224, 224

# ---------------------- LETTERBOX MASK -----------------------------
def letterbox_mask(mask: np.ndarray, new_size=(OUT_W, OUT_H)) -> np.ndarray:
    """
    Resize a binary mask to new_size with letterbox padding to preserve aspect ratio.

    Args:
        mask: 2-D array of shape (height, width).
        new_size: Target size as (width, height).

    Returns:
        Array of shape (target height, target width) with the dtype of
        ``mask``: the nearest-neighbour-resized mask centred on a zero
        background. An empty ``mask`` (or an empty target) yields an all-zero
        canvas instead of an OpenCV error.
    """
    h, w = mask.shape
    nw, nh = new_size
    output = np.zeros((nh, nw), dtype=mask.dtype)
    # Nothing can be placed for an empty source or an empty canvas, and
    # cv2.resize rejects zero-sized arrays.
    if 0 in (h, w, nw, nh):
        return output
    scale = min(nw / w, nh / h)
    # int() truncation can produce 0 for extremely elongated masks (e.g. a
    # 1-pixel-wide input), which would make cv2.resize fail; keep at least one
    # pixel per axis and never exceed the canvas.
    rw = min(nw, max(1, int(w * scale)))
    rh = min(nh, max(1, int(h * scale)))
    # Nearest-neighbour keeps the mask strictly binary (no interpolated values).
    resized = cv2.resize(mask, (rw, rh), interpolation=cv2.INTER_NEAREST)
    # Centre the resized mask; the remaining border stays zero.
    top = (nh - rh) // 2
    left = (nw - rw) // 2
    output[top:top+rh, left:left+rw] = resized
    return output

# ---------------------- BUILD MASK CACHES --------------------------
def build_mask_caches(samples: List[Dict],
                      person_pred: DefaultPredictor,
                      ref_det,
                      ref_labels: List[str]) -> None:
    """
    Generate and cache letterboxed human and reference masks for the first two views.

    Args:
        samples: Sample dicts; only ``sample["img_paths"][:2]`` is read.
        person_pred: Detectron2 predictor; called on a BGR array and expected
            to return a dict with an ``"instances"`` entry.
        ref_det: Zero-shot object-detection callable; called as
            ``ref_det(img, candidate_labels=..., threshold=0.7)`` and expected
            to return dicts with ``"label"``, ``"score"`` and ``"box"``.
        ref_labels: Text labels; one reference mask is stored per label.

    Side effects:
        Writes HUMAN_NPZ_PATH (key: image path) and REF_NPZ_PATH
        (key: ``"<path>|<label>"``). Returns immediately, without recomputing,
        when both files already exist. Image paths that do not exist on disk
        are skipped and therefore get no cache entry.
    """
    if os.path.exists(HUMAN_NPZ_PATH) and os.path.exists(REF_NPZ_PATH):
        return
    human_cache, ref_cache = {}, {}
    for sample in tqdm(samples, desc="Building mask caches"):
        for path in sample["img_paths"][:2]:
            if path in human_cache or not os.path.exists(path):
                continue
            # Close the file handle as soon as the pixels are decoded.
            with Image.open(path) as raw:
                img = ImageOps.exif_transpose(raw.convert("RGB"))
            arr = np.array(img)
            H, W = arr.shape[:2]

            # human mask via Detectron2
            inst = person_pred(cv2.cvtColor(arr, cv2.COLOR_RGB2BGR))["instances"]
            hm = np.zeros((H, W), dtype=np.uint8)
            if len(inst):
                cls = inst.pred_classes.cpu().numpy()
                # class 0 is treated as "person" here
                ids = np.where(cls == 0)[0]
                if len(ids):
                    # Highest-scoring person; cast to a plain int because
                    # Detectron2's Boxes indexing special-cases `int` only.
                    idx = int(ids[inst.scores.cpu().numpy()[ids].argmax()])
                    if inst.has("pred_masks"):
                        hm = inst.pred_masks[idx].cpu().numpy().astype(np.uint8)
                    else:
                        # Detectron2 boxes are (x0, y0, x1, y1); rows of the
                        # mask are indexed by y, columns by x.
                        x0, y0, x1, y1 = inst.pred_boxes[idx].tensor.cpu().numpy()[0]
                        x0, y0 = max(0, int(x0)), max(0, int(y0))
                        x1, y1 = max(0, int(x1)), max(0, int(y1))
                        hm[y0:y1, x0:x1] = 1
            human_cache[path] = letterbox_mask(hm)

            # reference mask via Grounding‑DINO
            dets = ref_det(img, candidate_labels=ref_labels, threshold=0.7)
            # A label can be detected more than once; keep the most confident
            # detection per (lower-cased) label rather than whichever happens
            # to come last in the list.
            best = {}
            for d in dets:
                name = d["label"].lower()
                if name not in best or d["score"] > best[name]["score"]:
                    best[name] = d
            for lbl in ref_labels:
                key = f"{path}|{lbl}"
                rm = np.zeros((H, W), dtype=np.uint8)
                # Keys were lower-cased above, so lower-case the lookup too.
                d = best.get(lbl.lower())
                if d:
                    box = d["box"]
                    # Clamp to 0: a negative start would wrap around when
                    # slicing and paint the wrong region.
                    x0, y0 = max(0, int(box["xmin"])), max(0, int(box["ymin"]))
                    x1, y1 = max(0, int(box["xmax"])), max(0, int(box["ymax"]))
                    rm[y0:y1, x0:x1] = 1
                ref_cache[key] = letterbox_mask(rm)

    np.savez_compressed(HUMAN_NPZ_PATH, **human_cache)
    np.savez_compressed(REF_NPZ_PATH,   **ref_cache)

# -------------------------- DATASET -------------------------------
class TwoViewMaskDataset(Dataset):
    """
    Dataset that returns two RGB images plus their human/ref masks and BMI.

    Each item is the tuple ``(rgb1, rgb2, m1, m2, rm1, rm2, bmi)`` where the
    masks come from the caches at HUMAN_NPZ_PATH / REF_NPZ_PATH and ``bmi`` is
    a float32 tensor of shape (1,).

    Args:
        samples: Sample dicts; ``"img_paths"`` (first two entries) and
            ``"bmi"`` are read.
        resize: Callable applied to each PIL image before ``tf``.
        tf: Callable turning the resized PIL image into a tensor.
        ref_labels: Labels whose cached reference masks are stacked, in this
            order, into the channels of ``rm1`` / ``rm2``.

    Raises:
        KeyError: from ``__getitem__`` when an image path (or path|label key)
            has no entry in the caches.
    """
    def __init__(self, samples, resize, tf, ref_labels):
        self.samples    = samples
        self.resize     = resize
        self.tf         = tf
        self.ref_labels = ref_labels
        # Materialise every cached array up front, then close the archive so
        # no open zip handle lingers on the dataset object.
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays;
        # the caches written by build_mask_caches are plain uint8 arrays, so
        # this could likely be False — confirm before loading untrusted files.
        with np.load(HUMAN_NPZ_PATH, allow_pickle=True) as cache:
            self.h_masks = dict(cache)
        with np.load(REF_NPZ_PATH, allow_pickle=True) as cache:
            self.r_masks = dict(cache)

    def __len__(self):
        return len(self.samples)

    def _load_rgb(self, path):
        """Load one image as an upright RGB tensor via ``resize`` then ``tf``."""
        with Image.open(path) as raw:
            # The cached masks were computed on EXIF-transposed images, so the
            # RGB input must be transposed the same way or the two are rotated
            # relative to each other.
            img = ImageOps.exif_transpose(raw.convert("RGB"))
        return self.tf(self.resize(img))

    def _ref_stack(self, path):
        """Stack the cached reference masks of ``path``, one channel per label."""
        stacked = np.stack(
            [self.r_masks[f"{path}|{lbl}"] for lbl in self.ref_labels], 0)
        return torch.from_numpy(stacked).float()

    def __getitem__(self, idx):
        s = self.samples[idx]
        p1, p2 = s["img_paths"][:2]
        rgb1 = self._load_rgb(p1)
        rgb2 = self._load_rgb(p2)
        # Human masks get a leading channel axis.
        m1 = torch.from_numpy(self.h_masks[p1]).unsqueeze(0).float()
        m2 = torch.from_numpy(self.h_masks[p2]).unsqueeze(0).float()
        rm1 = self._ref_stack(p1)
        rm2 = self._ref_stack(p2)
        bmi = torch.tensor(float(s["bmi"]), dtype=torch.float32).unsqueeze(0)
        return rgb1, rgb2, m1, m2, rm1, rm2, bmi

# --------------------------- MODEL -------------------------------
class MultiBranchNoProj(nn.Module):
    """
    Multi-branch model:
      - RGB branch: ResNet-101
      - Human-mask branch: ResNet-18
      - Ref-mask branch: ResNet-18
      Outputs regression and classification heads.

    Each backbone is the torchvision ResNet with its last child module
    dropped. The same three backbones process both views; the per-view
    features are concatenated and fed to two linear heads.

    Args:
        num_ref: Number of input channels of the reference-mask branch.
    """
    def __init__(self, num_ref=2):
        super().__init__()
        # RGB backbone
        r101 = models.resnet101(weights=None)
        r101.load_state_dict(self._read_checkpoint(RESNET101_PATH), strict=False)
        self.rgb_backbone = self._headless(r101)

        # Read the ResNet-18 checkpoint once; load_state_dict copies the
        # values into each model, so both mask branches can share this dict
        # without sharing parameters.
        r18_state = self._read_checkpoint(RESNET18_PATH)

        # Human-mask backbone (1 input channel)
        h18 = models.resnet18(weights=None)
        h18.load_state_dict(r18_state, strict=False)
        self._adapt_first_conv(h18, 1)
        self.h_backbone = self._headless(h18)

        # Ref-mask backbone (num_ref input channels)
        r18 = models.resnet18(weights=None)
        r18.load_state_dict(r18_state, strict=False)
        self._adapt_first_conv(r18, num_ref)
        self.r_backbone = self._headless(r18)

        # Heads: (RGB 2048 + human 512 + ref 512) features per view, two views.
        in_dim = (2048 + 512 + 512) * 2
        self.fc_reg = nn.Linear(in_dim, 1)
        self.fc_cls = nn.Linear(in_dim, 3)

    @staticmethod
    def _read_checkpoint(path):
        """Load a checkpoint file onto the CPU."""
        # NOTE(review): weights_only=False permits arbitrary unpickling. If
        # these files are plain state dicts, weights_only=True should work and
        # is safer — confirm before switching.
        return torch.load(path, map_location="cpu", weights_only=False)

    @staticmethod
    def _headless(net):
        """Return ``net`` as a Sequential without its last child module."""
        return nn.Sequential(*list(net.children())[:-1])

    @staticmethod
    def _adapt_first_conv(net, in_channels):
        """Swap ``net.conv1`` for an ``in_channels``-input conv.

        The new kernel is the channel-mean of the loaded conv1 weights,
        repeated across every input channel.
        """
        mean_kernel = net.conv1.weight.data.mean(dim=1, keepdim=True)
        net.conv1 = nn.Conv2d(in_channels, 64, 7, 2, 3, bias=False)
        net.conv1.weight.data = mean_kernel.repeat(1, in_channels, 1, 1)

    def forward(self, rgb1, rgb2, hm1, hm2, rm1, rm2):
        """
        Run both views through the three branches.

        Returns:
            Tuple ``(regression_output, classification_logits)`` produced by
            ``fc_reg`` and ``fc_cls`` on the fused feature vector.
        """
        def feat(backbone, x): return backbone(x).flatten(1)
        r1 = feat(self.rgb_backbone, rgb1)
        r2 = feat(self.rgb_backbone, rgb2)
        h1 = feat(self.h_backbone,   hm1)
        h2 = feat(self.h_backbone,   hm2)
        f1 = feat(self.r_backbone,    rm1)
        f2 = feat(self.r_backbone,    rm2)
        # Per-view feature order: RGB, human mask, reference masks.
        v1 = torch.cat([r1,h1,f1], dim=1)
        v2 = torch.cat([r2,h2,f2], dim=1)
        fused = torch.cat([v1,v2], dim=1)
        return self.fc_reg(fused), self.fc_cls(fused)

# ---------------------- UTILITIES -------------------------------
def get_bmi_class_new(bmi: float) -> int:
    """Map a BMI value to its bucket index.

    Returns 0 below 17.28, 1 from 17.28 up to (but excluding) 25.57, and 2
    otherwise.
    """
    upper_bounds = (17.28, 25.57)
    for bucket, bound in enumerate(upper_bounds):
        if bmi < bound:
            return bucket
    # Not below any bound: last bucket.
    return len(upper_bounds)

def compute_metrics(gt, pred):
    """Compute regression and distribution metrics for predictions vs. ground truth.

    Args:
        gt: Ground-truth values (array-like).
        pred: Predicted values (array-like), aligned with ``gt``.

    Returns:
        Dict with the keys ``r2``, ``mae``, ``tol_rate`` (share of predictions
        within 1.0 of the truth), ``mean_bias`` and ``error_std`` (mean and
        standard deviation of ``pred - gt``), ``ks`` (two-sample KS statistic)
        and ``wasserstein``.
    """
    truth = np.asarray(gt)
    estimate = np.asarray(pred)
    residual = estimate - truth
    ks_stat, _pvalue = ks_2samp(truth, estimate)

    report = {}
    report["r2"] = r2_score(truth, estimate)
    report["mae"] = mean_absolute_error(truth, estimate)
    report["tol_rate"] = np.mean(np.abs(residual) <= 1.0)
    report["mean_bias"] = residual.mean()
    report["error_std"] = residual.std()
    report["ks"] = ks_stat
    report["wasserstein"] = wasserstein_distance(truth, estimate)
    return report

# -------------------------- MAIN -------------------------------
if __name__ == "__main__":
    # Script entry point: load the sample list, optionally build the mask
    # caches, then train NUM_RUNS independent models and report validation
    # metrics (overall and per BMI bucket) at the CHECKPOINTS epochs.
    with open(SAMPLE_JSON_PATH) as f:
        samples = json.load(f)

    # build caches on first run
    BUILD_CACHE = False
    if BUILD_CACHE:
        cfg = get_cfg()
        cfg.merge_from_file(model_zoo.get_config_file(
            "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
        cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(
            "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
        human_pred = DefaultPredictor(cfg)
        ref_det    = pipeline("zero-shot-object-detection",
                              model="IDEA-Research/grounding-dino-tiny",
                              device=0 if torch.cuda.is_available() else -1)
        build_mask_caches(samples, human_pred, ref_det, REF_LABELS)
        # Cache building is a separate pass; stop here instead of training.
        exit(0)

    # split data
    train = [s for s in samples if s["split"] == "Training"]
    val   = [s for s in samples if s["split"] == "Validation"]

    # data loaders
    # NOTE(review): the cached masks are letterboxed (aspect ratio kept, zero
    # padding) while this Resize maps the RGB image straight to OUT_H x OUT_W;
    # for non-square photos mask and RGB pixels would not line up — confirm
    # this is intended.
    resize = transforms.Resize((OUT_H, OUT_W))
    tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ])
    train_loader = DataLoader(
        TwoViewMaskDataset(train, resize, tf, REF_LABELS),
        batch_size=8, shuffle=True,  num_workers=2, pin_memory=True)
    val_loader = DataLoader(
        TwoViewMaskDataset(val,   resize, tf, REF_LABELS),
        batch_size=8, shuffle=False, num_workers=2, pin_memory=True)

    # training settings
    CHECKPOINTS = [20, 25, 30, 35]
    NUM_RUNS    = 3
    EPOCHS      = 35
    BUCKET_IDS   = (0, 1, 2)
    BUCKET_NAMES = ["Underweight","Normal","Overweight"]

    # prepare accumulators for metrics and bucket MAE
    # The dummy call only supplies the metric names; two points are used so
    # that every metric (R² included) is well defined for it.
    metric_names    = list(compute_metrics([0, 1], [0, 1]))
    overall_results = {ep: {k: [] for k in metric_names} for ep in CHECKPOINTS}
    bucket_results  = {ep: {b: [] for b in BUCKET_IDS} for ep in CHECKPOINTS}

    for run in range(1, NUM_RUNS+1):
        print(f"\n===== RUN {run} =====")
        model     = MultiBranchNoProj(len(REF_LABELS)).to(DEVICE)
        optimizer = optim.Adam(model.parameters(), lr=1e-4)
        mse_loss  = nn.MSELoss()

        # train for EPOCHS
        for epoch in range(1, EPOCHS+1):
            model.train()
            for rgb1,rgb2,hm1,hm2,rm1,rm2,bmi in train_loader:
                rgb1,rgb2,hm1,hm2,rm1,rm2,bmi = [
                    x.to(DEVICE) for x in (rgb1,rgb2,hm1,hm2,rm1,rm2,bmi)
                ]
                optimizer.zero_grad()
                # Only the regression head is trained: the loss below never
                # uses the classification logits, so no class labels are
                # built here (doing so cost one .item() call per sample).
                out_reg, _ = model(rgb1,rgb2,hm1,hm2,rm1,rm2)
                loss = mse_loss(out_reg, bmi)
                loss.backward()
                optimizer.step()

            # evaluate at checkpoints
            if epoch in CHECKPOINTS:
                model.eval()
                all_preds, all_gts, all_buckets = [], [], []
                with torch.no_grad():
                    for rgb1,rgb2,hm1,hm2,rm1,rm2,bmi in val_loader:
                        rgb1,rgb2,hm1,hm2,rm1,rm2 = [
                            x.to(DEVICE) for x in (rgb1,rgb2,hm1,hm2,rm1,rm2)
                        ]
                        out_reg, _ = model(rgb1,rgb2,hm1,hm2,rm1,rm2)
                        preds = out_reg.cpu().flatten().tolist()
                        # bmi stays on the CPU; it is only needed as a list.
                        gts   = bmi.flatten().tolist()
                        buckets = [get_bmi_class_new(float(x)) for x in gts]
                        all_preds   += preds
                        all_gts     += gts
                        all_buckets += buckets

                # compute overall metrics
                metrics = compute_metrics(all_gts, all_preds)
                print(f"\n--- Epoch {epoch} metrics ---")
                for k,v in metrics.items():
                    print(f"{k}: {v:.4f}", end="  ")
                print()

                # append to overall results
                for k,v in metrics.items():
                    overall_results[epoch][k].append(v)

                # compute and append bucket MAEs (NaN when a bucket has no
                # validation samples)
                errors = np.abs(np.array(all_preds) - np.array(all_gts))
                arr_buckets = np.array(all_buckets)
                for b in BUCKET_IDS:
                    mask = (arr_buckets == b)
                    mae = errors[mask].mean() if mask.sum()>0 else float('nan')
                    bucket_results[epoch][b].append(mae)

                # visualize for this run/checkpoint
                bucket_mae = [bucket_results[epoch][b][-1] for b in BUCKET_IDS]
                plt.figure(figsize=(6,4))
                plt.bar(BUCKET_NAMES, bucket_mae)
                plt.ylabel("MAE")
                plt.title(f"MAE by BMI Bucket @ epoch {epoch}")
                plt.show()

    # at end, print mean±std for both overall and buckets
    print("\n===== OVERALL METRIC SUMMARY =====")
    for ep in CHECKPOINTS:
        print(f"\nCheckpoint {ep}:")
        for k, vals in overall_results[ep].items():
            arr = np.array(vals)
            print(f"  {k.upper():12s} {arr.mean():.4f} ± {arr.std():.4f}")
        print("  BUCKET MAE:")
        for b,name in zip(BUCKET_IDS, BUCKET_NAMES):
            arr = np.array(bucket_results[ep][b])
            print(f"    {name:12s}: {arr.mean():.4f} ± {arr.std():.4f}")
