In [None]:
!pip install 'git+https://github.com/facebookresearch/detectron2.git'


In [None]:
import multiprocessing
# Select the "spawn" start method for any processes created through
# multiprocessing. force=True replaces a start method that was already set,
# so re-running this cell does not raise.
# NOTE(review): presumably chosen so CUDA can be used from worker processes;
# the DataLoaders in this file use num_workers=0 — confirm this is still needed.
multiprocessing.set_start_method("spawn", force=True)

import os
import random
import json
import time
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
from dataclasses import dataclass
from typing import List, Dict, Optional, Union, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.metrics import mean_absolute_error, r2_score
from scipy.stats import ks_2samp, wasserstein_distance
from transformers import pipeline
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2 import model_zoo
from google.colab import drive

# Mount Google Drive at /content/drive. The sample list, the feature cache and
# the ResNet weights referenced elsewhere in this file all live under this mount.
drive.mount('/content/drive')

# Device setup: first CUDA GPU when one is available, otherwise the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("DEVICE:", DEVICE)

# Configure Detectron2 for person detection: start from the default config and
# merge the model-zoo COCO instance-segmentation Mask R-CNN (R50-FPN, 3x) file.
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(
    "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
# Score threshold applied by the ROI heads at test time.
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
# Checkpoint URL matching the config merged above.
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(
    "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
# NOTE(review): cfg.MODEL.DEVICE is never set from DEVICE, so the predictor
# runs on the config's default device — confirm this works on CPU-only hosts.
person_predictor = DefaultPredictor(cfg)

# Initialize Grounding DINO pipeline (zero-shot object detection).
# `device` is 0 when CUDA is available and -1 otherwise.
GLOBAL_GROUNDING_DETECTOR = pipeline(
    model="IDEA-Research/grounding-dino-tiny",
    task="zero-shot-object-detection",
    device=0 if torch.cuda.is_available() else -1
)

@dataclass
class BoundingBox:
    """Axis-aligned box described by its integer corner coordinates."""

    xmin: int  # smallest x coordinate
    ymin: int  # smallest y coordinate
    xmax: int  # largest x coordinate
    ymax: int  # largest y coordinate

    @property
    def xyxy(self) -> List[int]:
        """Return the corners as ``[xmin, ymin, xmax, ymax]``.

        The values are returned exactly as stored (the fields are declared
        ``int``), so the return annotation is ``List[int]``.
        """
        return [self.xmin, self.ymin, self.xmax, self.ymax]

@dataclass
class DetectionResult:
    """A single detection: score, label, bounding box and an optional mask."""

    score: float
    label: str
    box: BoundingBox
    # Not populated by from_dict; stays None unless a caller sets it.
    mask: Optional[np.ndarray] = None

    @classmethod
    def from_dict(cls, d: Dict) -> 'DetectionResult':
        """Build a DetectionResult from a mapping.

        The mapping must provide ``'score'``, ``'label'`` and a ``'box'``
        sub-mapping with ``'xmin'``, ``'ymin'``, ``'xmax'`` and ``'ymax'``.
        Each box coordinate is converted with ``int()``.
        """
        score = d['score']
        label = d['label']
        raw_box = d['box']
        corners = {
            name: int(raw_box[name])
            for name in ('xmin', 'ymin', 'xmax', 'ymax')
        }
        return cls(score=score, label=label, box=BoundingBox(**corners))

def grounded_segmentation(image: Union[Image.Image, str],
                          labels: List[str],
                          threshold: float = 0.7) -> Tuple[np.ndarray, List[DetectionResult]]:
    """Run the global Grounding DINO pipeline on one image.

    Args:
        image: A PIL image, or a path string that is opened and converted
            to RGB.
        labels: Text prompts; a trailing ``'.'`` is appended to any prompt
            that does not already end with one.
        threshold: Passed through to the detector as ``threshold``.

    Returns:
        A tuple of the EXIF-transposed image as a NumPy array and the list
        of detections wrapped as ``DetectionResult`` objects.
    """
    pil_image = Image.open(image).convert("RGB") if isinstance(image, str) else image
    pil_image = ImageOps.exif_transpose(pil_image)

    prompts = []
    for text in labels:
        prompts.append(text if text.endswith('.') else text + '.')

    raw_results = GLOBAL_GROUNDING_DETECTOR(
        pil_image, candidate_labels=prompts, threshold=threshold)
    detections = list(map(DetectionResult.from_dict, raw_results))
    return np.array(pil_image), detections

def compute_view_features(img: Image.Image, ref_labels: List[str]) -> np.ndarray:
    """Compute the per-view feature vector for one image.

    For each entry of ``ref_labels`` two values are emitted: a presence flag
    (1.0 when a detection with that label exists, else 0.0) and the area of
    the highest-scoring matching box divided by the image area. One final
    value is appended: the fraction of the image covered by the
    highest-scoring class-0 instance returned by ``person_predictor`` (mask
    pixel count when masks are present, box area otherwise; 0.0 when there
    is no such instance).

    Args:
        img: PIL image of the view.
        ref_labels: Text labels of the reference objects to look for.

    Returns:
        A float32 array of length ``2 * len(ref_labels) + 1``.
    """
    W, H = img.width, img.height
    image_area = W * H
    _, detections = grounded_segmentation(img, ref_labels, threshold=0.7)

    feats_ref = []
    for label in ref_labels:
        # Labels are compared case-insensitively with surrounding '.' removed.
        wanted = label.strip('.').lower()
        matched = [d for d in detections if d.label.strip('.').lower() == wanted]
        if matched:
            best = max(matched, key=lambda d: d.score)
            bx, by, xx, yy = best.box.xyxy
            area = (xx - bx) * (yy - by)
            feats_ref += [1.0, area / image_area]
        else:
            feats_ref += [0.0, 0.0]

    # NOTE(review): grounded_segmentation applies ImageOps.exif_transpose before
    # detecting, while the person predictor below receives `img` as passed in —
    # confirm whether both detectors should see the same orientation.
    img_np = np.array(img)
    outputs = person_predictor(cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR))
    inst = outputs["instances"]
    human_ratio = 0.0
    if len(inst) > 0:
        # as_tuple=True always yields a 1-D index tensor, so zero, one and many
        # class-0 instances all go through the same code path.
        person_idxs = (inst.pred_classes == 0).nonzero(as_tuple=True)[0]
        if person_idxs.numel() > 0:
            best_pos = inst.scores[person_idxs].argmax()
            # Convert to a plain int: indexing pred_boxes with a 0-dim tensor
            # does not give the (1, 4) result the box fallback below unpacks.
            i = int(person_idxs[best_pos])
            if inst.has("pred_masks") and len(inst.pred_masks) > i:
                m = inst.pred_masks[i].cpu().numpy().astype(np.uint8)
                human_ratio = m.sum() / image_area
            else:
                bx0, by0, bx1, by1 = inst.pred_boxes[i].tensor.cpu().numpy()[0]
                human_ratio = ((bx1 - bx0) * (by1 - by0)) / image_area
    return np.array(feats_ref + [human_ratio], dtype=np.float32)

def _open_rgb_or_none(path):
    """Return the image at ``path`` converted to RGB, or None if unavailable."""
    if not os.path.exists(path):
        return None
    try:
        return Image.open(path).convert("RGB")
    except Exception:
        # Best effort: an unreadable file is treated like a missing one.
        # Unlike a bare ``except:``, KeyboardInterrupt/SystemExit still propagate.
        return None


def offline_precompute_features_chunked(samples: List[dict],
                                        cache_path: str,
                                        ref_labels: Optional[List[str]] = None,
                                        chunk_size: int = 30):
    """Compute per-sample features for three views and save them with np.save.

    Each sample contributes the concatenation of three per-view vectors, one
    per entry of ``sample["img_paths"][0:3]``. A view whose image is missing
    or unreadable contributes an all-zero vector.

    Args:
        samples: Sample dicts; each must have an ``"img_paths"`` sequence with
            at least three entries.
        cache_path: Destination passed to ``np.save``.
        ref_labels: Reference-object labels forwarded to
            ``compute_view_features``. Defaults to
            ``["a black paper", "a white ball"]``.
        chunk_size: Number of samples whose images are loaded together before
            feature extraction.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1 (previously this made
            the loop spin forever).
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be >= 1")
    if ref_labels is None:
        ref_labels = ["a black paper", "a white ball"]

    num_views = 3
    # compute_view_features yields two values per label plus one person ratio.
    per_view_dim = 2 * len(ref_labels) + 1
    cache = np.zeros((len(samples), num_views * per_view_dim), dtype=np.float32)

    for start in range(0, len(samples), chunk_size):
        chunk = samples[start:start + chunk_size]
        # Load every image of the chunk first, then extract features.
        chunk_images = [
            [_open_rgb_or_none(s["img_paths"][j]) for j in range(num_views)]
            for s in chunk
        ]
        for offset, view_images in enumerate(chunk_images):
            view_feats = [
                np.zeros(per_view_dim, dtype=np.float32) if img is None
                else compute_view_features(img, ref_labels)
                for img in view_images
            ]
            cache[start + offset] = np.concatenate(view_feats, axis=0)

    np.save(cache_path, cache)
    print("Features cached to", cache_path)

class LetterboxResize:
    """Resize an image to fit inside ``size`` and centre it on a filled canvas.

    The aspect ratio is kept: the image is scaled by the largest factor at
    which both sides still fit, then pasted in the middle of a new RGB canvas.

    Args:
        size: Target ``(width, height)`` of the output canvas.
        fill: Canvas colour as three fractions; each is multiplied by 255 and
            truncated to an int.
    """

    def __init__(self, size=(224,224), fill=(0.485,0.456,0.406)):
        self.size, self.fill = size, fill

    def __call__(self, img):
        """Return a new RGB image of ``self.size`` containing the resized ``img``."""
        img = ImageOps.exif_transpose(img)
        iw, ih = img.size
        w, h = self.size
        scale = min(w / iw, h / ih)
        # Clamp to one pixel: for very elongated inputs int() truncates the
        # short side to 0, which resize() rejects.
        nw = max(1, int(iw * scale))
        nh = max(1, int(ih * scale))
        img = img.resize((nw, nh), Image.BICUBIC)
        bg = Image.new("RGB", self.size, tuple(int(c * 255) for c in self.fill))
        bg.paste(img, ((w - nw) // 2, (h - nh) // 2))
        return bg

class MultiViewBMIWithRefDataset(Dataset):
    """Dataset yielding three view images, a reference-feature vector and BMI.

    Args:
        samples: Sequence of dicts with ``"img_paths"`` (at least three paths)
            and ``"bmi"``.
        feats: Indexable by sample index; ``feats[idx]`` is converted to a
            float32 tensor.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, samples, feats, transform=None):
        self.samples, self.feats, self.transform = samples, feats, transform

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _load_view(path):
        """Load ``path`` as RGB, or return a black 224x224 placeholder."""
        img = None
        if os.path.exists(path):
            try:
                img = Image.open(path).convert("RGB")
            except Exception:
                # Unreadable file: use the placeholder below. Unlike a bare
                # ``except:``, KeyboardInterrupt/SystemExit still propagate.
                img = None
        if img is None:
            img = Image.new("RGB", (224,224), (0,0,0))
        return img

    def __getitem__(self, idx):
        """Return ``(view0, view1, view2, ref_features, bmi)`` for ``idx``."""
        s = self.samples[idx]
        imgs = []
        for j in range(3):
            img = self._load_view(s["img_paths"][j])
            if self.transform:
                img = self.transform(img)
            imgs.append(img)
        ref = torch.tensor(self.feats[idx], dtype=torch.float32)
        bmi = torch.tensor(float(s["bmi"]), dtype=torch.float32)
        return imgs[0], imgs[1], imgs[2], ref, bmi

class FusedMultiViewBMIModel(nn.Module):
    """Three-view ResNet-101 model fused with projected reference features.

    One shared backbone (ResNet-101 without its final layer) encodes each of
    the three views. The three flattened outputs are concatenated with a
    linear+ReLU projection of the reference features and fed to two linear
    heads.

    Args:
        ref_dim: Output width of the reference-feature projection.
        num_classes: Output width of the second head (``fc_cls``).
        ref_in_dim: Length of the reference-feature vector given to
            ``forward``; defaults to the 15 values this file caches per sample.
        weights_path: File holding the ResNet-101 state dict to load; defaults
            to the location on the mounted Drive.
    """

    def __init__(self, ref_dim=256, num_classes=4, ref_in_dim=15,
                 weights_path="/content/drive/MyDrive/resnet101-63fe2227.pth"):
        super().__init__()
        base = models.resnet101(weights=None)
        base.load_state_dict(torch.load(weights_path, map_location=DEVICE))
        # Drop the last child so the backbone outputs pooled features.
        self.backbone = nn.Sequential(*list(base.children())[:-1])
        self.ref_proj = nn.Sequential(nn.Linear(ref_in_dim, ref_dim), nn.ReLU())
        fused_dim = 2048*3 + ref_dim
        self.fc_reg = nn.Linear(fused_dim, 1)
        self.fc_cls = nn.Linear(fused_dim, num_classes)

    def forward(self, x0, x1, x2, ref):
        """Return ``(fc_reg output, fc_cls output)`` for one batch.

        Each view goes through the backbone separately, in the order
        x0, x1, x2.
        """
        view_feats = [torch.flatten(self.backbone(x), 1) for x in (x0, x1, x2)]
        cnn = torch.cat(view_feats, dim=1)
        r = self.ref_proj(ref)
        allf = torch.cat([cnn, r], dim=1)
        return self.fc_reg(allf), self.fc_cls(allf)

# Script entry point: load the sample list, build or reuse the cached reference
# features, then train and evaluate the fused model over several independent runs.
if __name__ == "__main__":
    sample_json = "/content/drive/MyDrive/sample_list.json"
    cache_npy   = "/content/drive/MyDrive/features_cache_3view.npy"
    ref_labels  = ["a black paper", "a white ball"]

    with open(sample_json) as f:
        samples = json.load(f)

    # Features are computed only when no cache file exists yet.
    # NOTE(review): an existing cache is not checked against `samples`; if the
    # sample list changed since it was written, rows may no longer line up.
    if not os.path.exists(cache_npy):
        offline_precompute_features_chunked(samples, cache_npy, ref_labels, chunk_size=30)
    feats = np.load(cache_npy)

    # Split by the "split" field. The index lists use positions in the full
    # `samples` list so that cache rows stay aligned with the filtered samples.
    train_s = [s for s in samples if s["split"]=="Training"]
    val_s   = [s for s in samples if s["split"]=="Validation"]
    train_idx = [i for i, s in enumerate(samples) if s["split"]=="Training"]
    val_idx   = [i for i, s in enumerate(samples) if s["split"]=="Validation"]
    train_f = feats[train_idx]
    val_f   = feats[val_idx]

    # Same preprocessing for training and validation: letterbox to 224x224,
    # convert to tensor, then normalise per channel.
    transform = transforms.Compose([
        LetterboxResize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ])

    train_loader = DataLoader(
        MultiViewBMIWithRefDataset(train_s, train_f, transform),
        batch_size=8, shuffle=True, num_workers=0, pin_memory=True
    )
    val_loader = DataLoader(
        MultiViewBMIWithRefDataset(val_s, val_f, transform),
        batch_size=8, shuffle=False, num_workers=0, pin_memory=True
    )

    print(f"Train: {len(train_loader.dataset)}, Val: {len(val_loader.dataset)}")

    # Validation metrics are recorded at these epochs; training stops at the
    # largest one.
    checkpoint_epochs = [20,25,30,35]
    max_epoch = max(checkpoint_epochs)
    num_runs  = 5

    # results[epoch][metric] collects one value per run.
    results = {ep: {k:[] for k in ['r2','mae','tol_rate','mean_bias','error_std','ks','wasserstein']}
               for ep in checkpoint_epochs}

    # Each run builds a fresh model and optimizer.
    # NOTE(review): no random seed is set in this block, so runs are not
    # reproducible as written — confirm whether that is intended.
    for run in range(1, num_runs+1):
        print(f"\n=== Run {run}/{num_runs} ===")
        model = FusedMultiViewBMIModel(ref_dim=256, num_classes=4).to(DEVICE)
        opt   = optim.Adam(model.parameters(), lr=1e-4)
        loss_fn = nn.MSELoss()

        for epoch in range(1, max_epoch+1):
            t0, tloss = time.time(), 0.0
            model.train()
            for x0, x1, x2, ref, bmi in train_loader:
                x0, x1, x2 = x0.to(DEVICE), x1.to(DEVICE), x2.to(DEVICE)
                # unsqueeze(1) makes the target (batch, 1) to match the regression output.
                ref, bmi   = ref.to(DEVICE), bmi.to(DEVICE).unsqueeze(1)
                opt.zero_grad()
                # Only the regression head is trained; the second output is discarded.
                pred, _ = model(x0, x1, x2, ref)
                loss = loss_fn(pred, bmi)
                loss.backward()
                opt.step()
                # Weight the batch-mean loss by batch size so the epoch figure
                # below is a per-sample average.
                tloss += loss.item() * bmi.size(0)
            print(f"Epoch {epoch}/{max_epoch} loss={tloss/len(train_loader.dataset):.4f} time={time.time()-t0:.1f}s")

            # Evaluate on the validation set at checkpoint epochs only.
            if epoch in checkpoint_epochs:
                model.eval()
                preds, trues = [], []
                with torch.no_grad():
                    for x0, x1, x2, ref, bmi in val_loader:
                        x0, x1, x2 = x0.to(DEVICE), x1.to(DEVICE), x2.to(DEVICE)
                        ref, bmi   = ref.to(DEVICE), bmi.to(DEVICE).unsqueeze(1)
                        p, _ = model(x0, x1, x2, ref)
                        preds += p.cpu().flatten().tolist()
                        trues += bmi.cpu().flatten().tolist()
                gt = np.array(trues)
                pr = np.array(preds)
                r2v  = r2_score(gt, pr)
                maev = mean_absolute_error(gt, pr)
                # Fraction of predictions whose absolute error is at most 1.0.
                tol  = np.mean(np.abs(gt-pr) <= 1.0)
                # Mean and standard deviation of the signed error (prediction - target).
                bias = np.mean(pr-gt)
                std  = np.std(pr-gt)
                # First element of the ks_2samp result (the test statistic),
                # comparing the target and prediction samples.
                ks_v = ks_2samp(gt, pr)[0]
                ws   = wasserstein_distance(gt, pr)
                print(f"[cp {epoch}] R2={r2v:.4f}, MAE={maev:.4f}, tol_rate={tol:.2%}")
                res = results[epoch]
                res['r2'].append(r2v)
                res['mae'].append(maev)
                res['tol_rate'].append(tol)
                res['mean_bias'].append(bias)
                res['error_std'].append(std)
                res['ks'].append(ks_v)
                res['wasserstein'].append(ws)

    # Per checkpoint epoch, report the mean and standard deviation of every
    # metric across the runs.
    print("\n=== Summary ===")
    for ep in checkpoint_epochs:
        print(f"\nEpoch {ep}:")
        for k, v in results[ep].items():
            print(f"  {k}: mean={np.mean(v):.4f}, std={np.std(v):.4f}")
