In [None]:
!pip install 'git+https://github.com/facebookresearch/detectron2.git'


In [None]:
#!/usr/bin/env python
# two_view_human_ref_res18_letterbox_with_cls_and_buckets.py
# Extends two_view_human_ref_res18_letterbox.py: adds a sweep over
# classification-loss class-weight factors, per-BMI-bucket MAE, classification
# accuracy (cls_acc), diagnostic plots, and multi-run / per-checkpoint summaries.

# Google Colab only: mount Google Drive at /content/drive so that the
# BASE_DIR-relative data, weight and cache paths in the configuration resolve.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

import os, json, cv2, numpy as np, torch
from typing import List, Dict
from PIL import Image, ImageOps
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from sklearn.metrics import r2_score, mean_absolute_error
from scipy.stats import ks_2samp, wasserstein_distance
import matplotlib.pyplot as plt
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2 import model_zoo
from transformers import pipeline

# ---- Configuration ---------------------------------------------------------
# Every input, weight and cache file lives directly under the mounted Drive root.
BASE_DIR = "/content/drive/MyDrive"
SAMPLE_JSON = os.path.join(BASE_DIR, "sample_list.json")
RESNET18_PATH = os.path.join(BASE_DIR, "resnet18-5c106cde.pth")
HUMAN_NPZ_PATH = os.path.join(BASE_DIR, "human_mask_cache_224.npz")
REF_NPZ_PATH = os.path.join(BASE_DIR, "ref_mask_cache_224.npz")

# Text prompts handed to the zero-shot reference-object detector; one mask
# channel per prompt is cached and fed to the model.
REF_LABELS = ["a white ball", "a dark paper"]

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
torch.backends.cudnn.benchmark = True

# Width and height of the square letterboxed mask canvas.
OUT_W = OUT_H = 224

def letterbox_mask(mask: np.ndarray, new_size=(OUT_W, OUT_H)) -> np.ndarray:
    """Fit a 2-D mask into a fixed canvas while preserving its aspect ratio.

    The mask is scaled by the largest factor that fits inside ``new_size``,
    resized with nearest-neighbour interpolation (so mask values are never
    blended), and pasted centred onto a zero-filled canvas.

    Args:
        mask: 2-D array of shape (H, W).
        new_size: Target canvas size as ``(width, height)``.

    Returns:
        Array of shape ``(height, width)`` with the same dtype as ``mask``.
    """
    h, w = mask.shape
    nw, nh = new_size
    scale = min(nw / w, nh / h)
    # Round rather than truncate: float error can make w*scale == 223.999...,
    # which int() would turn into a lost row/column. Clamp to >= 1 because a
    # very thin input would otherwise produce a 0-pixel side, which
    # cv2.resize rejects; clamp to the canvas so the paste always fits.
    rw = min(nw, max(1, round(w * scale)))
    rh = min(nh, max(1, round(h * scale)))
    resized = cv2.resize(mask, (rw, rh), interpolation=cv2.INTER_NEAREST)
    top, left = (nh - rh) // 2, (nw - rw) // 2
    out = np.zeros((nh, nw), dtype=mask.dtype)
    out[top:top + rh, left:left + rw] = resized
    return out

def _box_mask(H: int, W: int, x0, y0, x1, y1) -> np.ndarray:
    """Rasterize an axis-aligned box into an (H, W) uint8 0/1 mask, clipped to the image.

    Clipping matters: a negative coordinate would otherwise act as a
    from-the-end NumPy index and paint the wrong region.
    """
    m = np.zeros((H, W), dtype=np.uint8)
    cx0, cx1 = (int(min(max(float(v), 0.0), W)) for v in (x0, x1))
    cy0, cy1 = (int(min(max(float(v), 0.0), H)) for v in (y0, y1))
    m[cy0:cy1, cx0:cx1] = 1
    return m


def _person_mask(arr: np.ndarray, human_pred) -> np.ndarray:
    """Return an (H, W) uint8 0/1 mask of the best-scoring class-0 instance in RGB ``arr``.

    Uses the predicted instance mask when available, otherwise the instance's
    box; returns an all-zero mask when no class-0 instance is detected.
    """
    H, W = arr.shape[:2]
    # The detectron2 predictor is fed BGR, so convert from RGB.
    inst = human_pred(cv2.cvtColor(arr, cv2.COLOR_RGB2BGR))["instances"]
    if len(inst) == 0:
        return np.zeros((H, W), dtype=np.uint8)
    ids = np.where(inst.pred_classes.cpu().numpy() == 0)[0]  # class 0 = COCO "person"
    if ids.size == 0:
        return np.zeros((H, W), dtype=np.uint8)
    # Select on host arrays instead of indexing a (possibly CUDA) tensor with a NumPy index.
    scores = inst.scores.cpu().numpy()
    i = int(ids[scores[ids].argmax()])
    if inst.has("pred_masks"):
        return inst.pred_masks[i].cpu().numpy().astype(np.uint8)
    # detectron2 Boxes are stored as (x0, y0, x1, y1).
    x0, y0, x1, y1 = inst.pred_boxes[i].tensor.cpu().numpy()[0]
    return _box_mask(H, W, x0, y0, x1, y1)


def _norm_label(label: str) -> str:
    """Canonical form used to match detector labels against prompts."""
    return label.rstrip('.').lower()


def _ref_masks(img, H: int, W: int, ref_det, ref_labels: List[str]) -> Dict[str, np.ndarray]:
    """Return ``{prompt: (H, W) uint8 box mask}`` for every prompt in ``ref_labels``.

    For each prompt, the highest-scoring detection returned by ``ref_det``
    (called with threshold=0.7) is rasterized. Prompts with no detection get
    an all-zero mask.
    """
    dets = ref_det(img, candidate_labels=ref_labels, threshold=0.7)
    best = {}
    for d in dets:
        key = _norm_label(d["label"])
        # Keep the most confident box per label, independent of result order.
        if key not in best or d["score"] > best[key]["score"]:
            best[key] = d
    masks = {}
    for lbl in ref_labels:
        d = best.get(_norm_label(lbl))
        if d is None:
            masks[lbl] = np.zeros((H, W), dtype=np.uint8)
        else:
            b = d["box"]
            masks[lbl] = _box_mask(H, W, b["xmin"], b["ymin"], b["xmax"], b["ymax"])
    return masks


def build_mask_caches(samples: List[Dict], human_pred: DefaultPredictor, ref_det, ref_labels: List[str]):
    """Compute and save letterboxed human / reference-object mask caches.

    For the first two ``img_paths`` of every sample, builds a person mask with
    ``human_pred`` and one box mask per prompt in ``ref_labels`` with
    ``ref_det``. Each mask is letterboxed with ``letterbox_mask``.

    Human masks are keyed by image path and reference masks by
    ``"<path>|<label>"``. They are saved to HUMAN_NPZ_PATH and REF_NPZ_PATH
    respectively.

    Does nothing if both cache files already exist. Image paths that do not
    exist are skipped.
    """
    if os.path.exists(HUMAN_NPZ_PATH) and os.path.exists(REF_NPZ_PATH):
        return
    human_cache, ref_cache = {}, {}
    for s in tqdm(samples, desc="Building mask cache"):
        for p in s["img_paths"][:2]:
            if p in human_cache or not os.path.exists(p):
                continue
            # Apply EXIF orientation before the mode conversion; pixels are
            # loaded inside the `with`, so the file handle can be closed.
            with Image.open(p) as raw:
                img = ImageOps.exif_transpose(raw).convert("RGB")
            arr = np.array(img)
            H, W = arr.shape[:2]
            human_cache[p] = letterbox_mask(_person_mask(arr, human_pred))
            for lbl, rm in _ref_masks(img, H, W, ref_det, ref_labels).items():
                ref_cache[f"{p}|{lbl}"] = letterbox_mask(rm)
    np.savez_compressed(HUMAN_NPZ_PATH, **human_cache)
    np.savez_compressed(REF_NPZ_PATH, **ref_cache)

class MaskOnlyDataset(Dataset):
    """Two-view dataset yielding cached masks plus the BMI target.

    Each item is ``(hm1, hm2, rm1, rm2, bmi)``, all float32 tensors:

    * ``hm1`` / ``hm2``: the cached human mask of view 1 / 2 with a leading
      channel axis, i.e. shape ``(1, *mask.shape)``.
    * ``rm1`` / ``rm2``: one cached reference mask per entry of
      ``ref_labels``, stacked along axis 0.
    * ``bmi``: shape ``(1,)``.

    Views are the first two entries of ``sample["img_paths"]``.
    """

    def __init__(self, samples: List[Dict], ref_labels: List[str]):
        self.samples, self.ref_labels = samples, ref_labels
        self.h_masks = self._load_npz(HUMAN_NPZ_PATH)
        self.r_masks = self._load_npz(REF_NPZ_PATH)

    @staticmethod
    def _load_npz(path: str) -> Dict[str, np.ndarray]:
        """Load every array of an .npz archive into a dict and close the archive."""
        # The caches hold plain numeric arrays, so np.load's default
        # allow_pickle=False is kept (no arbitrary-object unpickling).
        with np.load(path) as npz:
            return dict(npz)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        s = self.samples[idx]
        p1, p2 = s["img_paths"][:2]
        hm1 = torch.from_numpy(self.h_masks[p1]).unsqueeze(0).float()
        hm2 = torch.from_numpy(self.h_masks[p2]).unsqueeze(0).float()
        rm1 = np.stack([self.r_masks[f"{p1}|{lbl}"] for lbl in self.ref_labels], 0)
        rm2 = np.stack([self.r_masks[f"{p2}|{lbl}"] for lbl in self.ref_labels], 0)
        rm1, rm2 = torch.from_numpy(rm1).float(), torch.from_numpy(rm2).float()
        bmi = torch.tensor(float(s["bmi"]), dtype=torch.float32).unsqueeze(0)
        return hm1, hm2, rm1, rm2, bmi

def get_bmi_class_new(bmi: float) -> int:
    """Map a BMI value to one of three class indices.

    Returns 0 for bmi < 17.28, 1 for 17.28 <= bmi < 25.57, and 2 otherwise.
    A NaN input also maps to 2, because every ``<`` comparison with NaN is False.
    """
    for class_idx, upper_bound in enumerate((17.28, 25.57)):
        if bmi < upper_bound:
            return class_idx
    return 2

class TwoViewHumanRefRes18(nn.Module):
    """Two-view model over human masks and reference-object masks.

    Two headless ResNet-18 trunks are initialised from the checkpoint at
    RESNET18_PATH:

    * ``h_backbone`` takes a 1-channel human mask.
    * ``r_backbone`` takes a ``num_ref``-channel stack of reference masks.

    Each trunk's first conv is seeded with the channel-mean of the pretrained
    RGB stem kernel, tiled once per input channel. For each view the human and
    reference features are concatenated. Both views are then concatenated and
    fed to a regression head (1 output) and a classification head
    (``num_classes`` logits).
    """

    def __init__(self, num_ref: int, num_classes: int = 3):
        super().__init__()
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # objects; keep RESNET18_PATH pointed at a trusted file, or consider
        # weights_only=True if it is a plain state_dict.
        state = torch.load(RESNET18_PATH, map_location="cpu", weights_only=False)
        self.h_backbone = self._mask_trunk(state, in_channels=1)
        self.r_backbone = self._mask_trunk(state, in_channels=num_ref)
        # 2 views x (512-d human feature + 512-d reference feature).
        feat_dim = 2 * (512 + 512)
        self.fc_reg = nn.Linear(feat_dim, 1)
        self.fc_cls = nn.Linear(feat_dim, num_classes)

    @staticmethod
    def _mask_trunk(state_dict, in_channels: int) -> nn.Sequential:
        """Build a ResNet-18 without its fc layer whose stem accepts ``in_channels`` channels."""
        net = models.resnet18(weights=None)
        net.load_state_dict(state_dict, strict=False)
        # Collapse the pretrained RGB stem kernel to one channel, then tile it
        # across the new input channels.
        kernel = net.conv1.weight.data.mean(dim=1, keepdim=True)
        net.conv1 = nn.Conv2d(in_channels, 64, 7, 2, 3, bias=False)
        net.conv1.weight.data = kernel.repeat(1, in_channels, 1, 1)
        return nn.Sequential(*list(net.children())[:-1])

    def forward(self, hm1, hm2, rm1, rm2):
        """Return ``(regression_output, class_logits)`` for a pair of views."""
        per_view = [
            torch.cat(
                [self.h_backbone(hm).flatten(1), self.r_backbone(rm).flatten(1)],
                dim=1,
            )
            for hm, rm in ((hm1, rm1), (hm2, rm2))
        ]
        fused = torch.cat(per_view, dim=1)
        return self.fc_reg(fused), self.fc_cls(fused)

def compute_metrics(gt, pred):
    """Summarise how well ``pred`` matches ``gt``.

    Returns a dict with these keys, in this order:

    * ``r2``: R^2.
    * ``mae``: mean absolute error.
    * ``tol_rate``: fraction of predictions within +/-1.0 of the target.
    * ``mean_bias`` and ``error_std``: mean and std of the signed error
      ``pred - gt``.
    * ``ks``: the two-sample Kolmogorov-Smirnov statistic.
    * ``wasserstein``: the 1-D Wasserstein distance between the two samples.
    """
    y_true = np.asarray(gt)
    y_pred = np.asarray(pred)
    residual = y_pred - y_true
    ks_stat = ks_2samp(y_true, y_pred).statistic
    metrics = {}
    metrics["r2"] = r2_score(y_true, y_pred)
    metrics["mae"] = mean_absolute_error(y_true, y_pred)
    metrics["tol_rate"] = np.mean(np.abs(residual) <= 1.0)
    metrics["mean_bias"] = residual.mean()
    metrics["error_std"] = residual.std()
    metrics["ks"] = ks_stat
    metrics["wasserstein"] = wasserstein_distance(y_true, y_pred)
    return metrics

if __name__ == "__main__":
    with open(SAMPLE_JSON, "r") as f:
        samples = json.load(f)

    # Set to True to (re)build the mask caches instead of training.
    BUILD_CACHE = False
    if BUILD_CACHE:
        cfg = get_cfg()
        cfg.merge_from_file(model_zoo.get_config_file(
            "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
        cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(
            "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
        human_pred = DefaultPredictor(cfg)
        ref_det = pipeline("zero-shot-object-detection",
                           model="IDEA-Research/grounding-dino-tiny",
                           device=0 if torch.cuda.is_available() else -1)
        build_mask_caches(samples, human_pred, ref_det, REF_LABELS)
        # Stop here with SystemExit rather than exit(): in an IPython/Jupyter
        # kernel, exit() can shut the kernel down, whereas SystemExit only
        # aborts the current cell (and exits with status 0 as a script).
        raise SystemExit(0)

    train = [s for s in samples if s["split"] == "Training"]
    val = [s for s in samples if s["split"] == "Validation"]
    train_loader = DataLoader(MaskOnlyDataset(train, REF_LABELS),
                              batch_size=8, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(MaskOnlyDataset(val, REF_LABELS),
                            batch_size=8, shuffle=False, num_workers=2, pin_memory=True)

    CHECKPOINTS = [20, 25, 30, 35]   # epochs after which the model is evaluated on val
    NUM_RUNS = 3                     # independent trainings per weight factor
    EPOCHS = 35
    weight_factors = [1, 2, 3, 5, 8] # CE weight for classes 0 and 2 (class 1 stays 1.0)
    names = ["Underweight", "Normal", "Overweight"]  # display names for classes 0/1/2

    # Metric names in print order. A two-point call yields compute_metrics'
    # keys without sklearn's "R^2 not well-defined with < 2 samples" warning.
    metric_keys = list(compute_metrics([0.0, 1.0], [0.0, 1.0]).keys()) + [
        "cls_acc", "bucket_0_mae", "bucket_1_mae", "bucket_2_mae"]

    for wf in weight_factors:
        print(f"\n=== CLS weight factor = {wf} ===")
        results = {ep: {k: [] for k in metric_keys} for ep in CHECKPOINTS}

        for run in range(1, NUM_RUNS + 1):
            print(f"\n--- Run {run} ---")
            model = TwoViewHumanRefRes18(num_ref=len(REF_LABELS), num_classes=3).to(DEVICE)
            opt = optim.Adam(model.parameters(), lr=1e-4)
            mse = nn.MSELoss()
            cw = torch.tensor([wf, 1.0, wf], device=DEVICE)
            ce = nn.CrossEntropyLoss(weight=cw)

            for ep in range(1, EPOCHS + 1):
                model.train()
                for hm1, hm2, rm1, rm2, bmi in train_loader:
                    # Derive class targets from the CPU batch with one .tolist()
                    # instead of a device sync per element via .item().
                    labels = torch.tensor([get_bmi_class_new(b) for b in bmi.flatten().tolist()],
                                          device=DEVICE, dtype=torch.long)
                    hm1, hm2, rm1, rm2, bmi = [x.to(DEVICE) for x in (hm1, hm2, rm1, rm2, bmi)]
                    opt.zero_grad()
                    out_reg, out_cls = model(hm1, hm2, rm1, rm2)
                    loss = mse(out_reg, bmi) + ce(out_cls, labels)
                    loss.backward()
                    opt.step()

                if ep in CHECKPOINTS:
                    model.eval()
                    all_p, all_t, all_c = [], [], []
                    with torch.no_grad():
                        for hm1, hm2, rm1, rm2, bmi in val_loader:
                            hm1, hm2, rm1, rm2 = [x.to(DEVICE) for x in (hm1, hm2, rm1, rm2)]
                            p_reg, p_cls = model(hm1, hm2, rm1, rm2)
                            all_p += p_reg.cpu().flatten().tolist()
                            all_t += bmi.flatten().tolist()
                            all_c += torch.argmax(p_cls, dim=1).cpu().tolist()

                    m = compute_metrics(all_t, all_p)
                    # True class per validation sample; used for accuracy and bucketing.
                    buckets = np.array([get_bmi_class_new(x) for x in all_t])
                    acc = np.mean(np.array(all_c) == buckets)
                    print(f"Epoch {ep}:", ", ".join(f"{k}={v:.4f}" for k, v in m.items()),
                          f"cls_acc={acc*100:.2f}%")
                    for k, v in m.items():
                        results[ep][k].append(v)
                    results[ep]["cls_acc"].append(acc)

                    errors = np.abs(np.array(all_p) - np.array(all_t))
                    bucket_mae = [(errors[buckets == b].mean() if (buckets == b).sum() > 0 else np.nan)
                                  for b in (0, 1, 2)]
                    for b in (0, 1, 2):
                        results[ep][f"bucket_{b}_mae"].append(bucket_mae[b])

                    plt.figure(figsize=(6, 4))
                    plt.bar(names, bucket_mae)
                    plt.ylabel("MAE"); plt.title(f"Bucket MAE @ wf={wf}, ep={ep}")
                    plt.show()

                    data = [errors[buckets == b] for b in (0, 1, 2)]
                    plt.figure(figsize=(6, 4))
                    # Label boxes via xticks: boxplot's `labels=` keyword was
                    # renamed `tick_labels` in Matplotlib 3.9 (old name
                    # deprecated); xticks works on every version.
                    plt.boxplot(data)
                    plt.xticks([1, 2, 3], names)
                    plt.ylabel("Absolute Error"); plt.title(f"Error Dist by Bucket @ wf={wf}, ep={ep}")
                    plt.show()

                    t_arr, p_arr = np.array(all_t), np.array(all_p)
                    plt.figure(figsize=(7, 6))
                    for b, c in zip((0, 1, 2), ("C0", "C1", "C2")):
                        mask = buckets == b
                        plt.scatter(t_arr[mask], p_arr[mask], alpha=0.6, label=names[b], c=c)
                    mval, Mval = min(all_t + all_p), max(all_t + all_p)
                    plt.plot([mval, Mval], [mval, Mval], "k--")
                    plt.xlabel("True BMI"); plt.ylabel("Pred BMI")
                    plt.legend(); plt.title(f"Pred vs True @ wf={wf}, ep={ep}")
                    plt.show()

        print(f"\n--- Summary for wf={wf} ---")
        for ep in CHECKPOINTS:
            print(f"\nCheckpoint {ep}:")
            for k, vals in results[ep].items():
                arr = np.array(vals)
                # 14-wide label column: "BUCKET_0_MAE" is 12 characters, so a
                # width of 12 left no gap before the value.
                if k == "cls_acc":
                    print(f"  {k:14s}{arr.mean()*100:.2f}% ± {arr.std()*100:.2f}%")
                else:
                    print(f"  {k.upper():14s}{arr.mean():.4f} ± {arr.std():.4f}")
