In [None]:
!pip install 'git+https://github.com/facebookresearch/detectron2.git'


In [None]:
import multiprocessing
multiprocessing.set_start_method("spawn", force=True)

import os
import random
import json
import time
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
from dataclasses import dataclass
from typing import List, Dict, Optional, Union, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.metrics import mean_absolute_error, r2_score
from scipy.stats import ks_2samp, wasserstein_distance

# Mount Google Drive (ensure JSON files and images are stored in Drive)
from google.colab import drive
drive.mount('/content/drive')

# -------------------------------
# Global device configuration
# -------------------------------
# Prefer the first CUDA GPU when one is available; otherwise fall back to the CPU.
# The training loop and model construction below move tensors/modules to this device.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Current device:", DEVICE)

# -------------------------------
# 1. Load the Detectron2 model (for human detection)
# -------------------------------
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2 import model_zoo

# Build the Detectron2 predictor once at module level; compute_view_features reuses it
# for every image.
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # Detection threshold
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
# Detectron2's config defaults to CUDA. Follow the script-wide DEVICE choice so that
# the predictor can also be constructed on a CPU-only runtime, matching the CPU
# fallback used for DEVICE and the Grounding DINO pipeline.
cfg.MODEL.DEVICE = DEVICE.type
person_predictor = DefaultPredictor(cfg)

# -------------------------------
# 2. Load the Grounding DINO model (for reference object detection without using SAM)
# -------------------------------
from transformers import pipeline
# Initialize globally to avoid repeated instantiation
# Shared zero-shot object-detection pipeline used by grounded_segmentation().
# The device argument is GPU index 0 when CUDA is available, otherwise -1
# (which the transformers pipeline API treats as CPU).
GLOBAL_GROUNDING_DETECTOR = pipeline(
    model="IDEA-Research/grounding-dino-tiny",
    task="zero-shot-object-detection",
    device=0 if torch.cuda.is_available() else -1
)
print("Grounding Dino global initialization complete.")

# -------------------------------
# 3. Define data structures
# -------------------------------
@dataclass
class BoundingBox:
    # Corner coordinates of an axis-aligned box (presumably image pixels; the
    # values come from the detector output parsed in DetectionResult.from_dict).
    xmin: int
    ymin: int
    xmax: int
    ymax: int

    @property
    def xyxy(self) -> List[float]:
        # Corners in [xmin, ymin, xmax, ymax] order, returned as a fresh list.
        corners = (self.xmin, self.ymin, self.xmax, self.ymax)
        return list(corners)

@dataclass
class DetectionResult:
    score: float   # detector confidence for this detection
    label: str     # text label reported by the detector
    box: BoundingBox
    mask: Optional[np.ndarray] = None   # SAM mask is not used here

    @classmethod
    def from_dict(cls, detection_dict: Dict) -> 'DetectionResult':
        # Build a DetectionResult from one raw detector entry. The entry must
        # provide 'score', 'label' and a 'box' mapping with xmin/ymin/xmax/ymax;
        # box values are truncated to int. The mask is left at its default (None).
        score = detection_dict['score']
        label = detection_dict['label']
        corners = {
            name: int(detection_dict['box'][name])
            for name in ('xmin', 'ymin', 'xmax', 'ymax')
        }
        return cls(score=score, label=label, box=BoundingBox(**corners))

# -------------------------------
# Reference object detection function (using Grounding DINO only, without SAM)
# -------------------------------
def grounded_segmentation(image: Union[Image.Image, str],
                          labels: List[str],
                          threshold: float = 0.7,
                          polygon_refinement: bool = False,
                          detector_id: Optional[str] = None) -> Tuple[np.ndarray, List[DetectionResult]]:
    """
    Detect the requested reference objects with the shared Grounding DINO pipeline (SAM is not called).

    `image` may be a PIL image or a path; a path is opened and converted to RGB.
    The image is passed through `ImageOps.exif_transpose` before detection.
    Each label is given a trailing '.' if it does not already end with one.
    `polygon_refinement` and `detector_id` are accepted but not used by this function.

    Returns the EXIF-transposed image as a numpy array together with the list of
    detections (bounding boxes only; every mask is None).
    """
    pil_image = Image.open(image).convert("RGB") if isinstance(image, str) else image
    pil_image = ImageOps.exif_transpose(pil_image)

    prompts = []
    for text in labels:
        prompts.append(text if text.endswith('.') else text + '.')

    raw_results = GLOBAL_GROUNDING_DETECTOR(pil_image, candidate_labels=prompts, threshold=threshold)
    parsed = list(map(DetectionResult.from_dict, raw_results))
    return np.array(pil_image), parsed

# -------------------------------
# 4. "View Feature Extraction" Function (2-view, reference feature output dimension 5*2=10)
# -------------------------------
def compute_view_features(img_obj: Image.Image, ref_labels: List[str]) -> np.ndarray:
    """
    Build the per-image feature vector from reference-object and human detections.

    For each entry of `ref_labels` (in order) two values are emitted:
      - flag: 1.0 if Grounding DINO returned a detection whose label matches
        (case-insensitive, ignoring leading/trailing '.'), else 0.0;
      - ratio: (bounding-box area of the highest-scoring match) / (image area), else 0.0.
    A final value holds the human area ratio from Detectron2: the highest-scoring
    class-0 instance, measured by its mask when masks are present, otherwise by
    its bounding box; 0.0 when no such instance is found.

    Returns a float32 vector of length 2 * len(ref_labels) + 1
    (5 for the default two reference labels).
    """
    W, H = img_obj.width, img_obj.height
    image_area = float(W * H)

    # Reference object detection (using DINO only, returning bounding boxes)
    # NOTE(review): grounded_segmentation EXIF-transposes the image while the
    # Detectron2 pass below sees it as given; the area is unchanged by that
    # transpose, but confirm the orientation mismatch is acceptable.
    _, detections = grounded_segmentation(img_obj, ref_labels, threshold=0.7, detector_id="IDEA-Research/grounding-dino-tiny")
    feats_ref = []
    for label in ref_labels:
        wanted = label.strip('.').lower()
        matched = [d for d in detections if d.label.strip('.').lower() == wanted]
        if matched:
            best = max(matched, key=lambda d: d.score)
            xmin, ymin, xmax, ymax = best.box.xyxy
            feats_ref.extend([1.0, (xmax - xmin) * (ymax - ymin) / image_area])
        else:
            feats_ref.extend([0.0, 0.0])

    # Human detection using Detectron2 (expects a BGR array)
    img_cv = cv2.cvtColor(np.array(img_obj), cv2.COLOR_RGB2BGR)
    instances = person_predictor(img_cv)["instances"]
    human_ratio = 0.0
    if len(instances) > 0:
        # as_tuple=True always yields a 1-D index tensor, so one and many
        # matches are handled by the same code path.
        person_idxs = (instances.pred_classes == 0).nonzero(as_tuple=True)[0]
        if person_idxs.numel() > 0:
            # Convert to a plain int: indexing pred_boxes with a 0-d tensor does
            # not select a single (1, 4) box the way an int index does.
            idx = int(person_idxs[instances.scores[person_idxs].argmax()])
            if instances.has("pred_masks") and len(instances.pred_masks) > idx:
                person_mask = instances.pred_masks[idx].cpu().numpy().astype(np.uint8)
                human_ratio = float(np.sum(person_mask)) / image_area
            else:
                bbox = instances.pred_boxes[idx].tensor.cpu().numpy()[0]
                human_ratio = float((bbox[2] - bbox[0]) * (bbox[3] - bbox[1])) / image_area

    return np.array(feats_ref + [human_ratio], dtype=np.float32)

# -------------------------------
# 5. Offline Precomputation of Features (each sample has 2 views → total dimension 10)
# -------------------------------
def _load_rgb_or_none(path):
    """Open `path` as an RGB PIL image; log and return None if it is missing or unreadable."""
    if not os.path.exists(path):
        print(f"❌ Image not found: {path}")
        return None
    try:
        with Image.open(path) as raw:
            return raw.convert("RGB")
    except Exception as e:
        print(f"❌ Unable to open image: {path} | Error: {e}")
        return None


def offline_precompute_features_chunked(samples, cache_path, ref_labels=None, chunk_size=30):
    """
    Precompute the per-sample reference features and save them to `cache_path` with np.save.

    For every sample the first 2 entries of sample["img_paths"] are processed with
    compute_view_features and the two per-image vectors are concatenated
    (5 + 5 = 10 values with the default reference labels). An image that is
    missing or cannot be opened contributes an all-zero vector.

    Args:
        samples: sequence of dicts, each with an "img_paths" list of at least 2 paths.
        cache_path: destination file for the (len(samples), feature_dim) float32 array.
        ref_labels: reference-object labels; defaults to ["a black paper", "a white ball"].
        chunk_size: number of samples per progress chunk (must be >= 1).

    Raises:
        ValueError: if chunk_size < 1 or any sample has fewer than 2 image paths.
            Both are checked before any feature is computed.
    """
    if ref_labels is None:
        ref_labels = ["a black paper", "a white ball"]
    if chunk_size < 1:
        # A non-positive chunk size would never advance through the samples.
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    # Validate everything up front so a malformed sample cannot abort a long run midway.
    for s in samples:
        paths = s["img_paths"]
        if len(paths) < 2:
            raise ValueError(f"Each sample must have at least 2 images, but found {len(paths)} in {paths}")

    num_samples = len(samples)
    print(f"\n[Offline Precomputation - Chunked] Total samples: {num_samples}")
    print(f"Cache file: {cache_path}")
    print(f"chunk_size={chunk_size}, ref_labels={ref_labels}\n")

    # compute_view_features emits (flag, ratio) per reference label plus one human ratio.
    feat_dim_per_img = 2 * len(ref_labels) + 1
    total_feat_dim = feat_dim_per_img * 2  # 2 views
    # Rows start as zeros, so images that fail to load need no explicit write.
    feature_cache = np.zeros((num_samples, total_feat_dim), dtype=np.float32)

    for start_idx in range(0, num_samples, chunk_size):
        end_idx = min(start_idx + chunk_size, num_samples)
        print(f"=== Processing sample indices [{start_idx}, {end_idx}) ===")
        for sample_idx in range(start_idx, end_idx):
            paths = samples[sample_idx]["img_paths"]
            for view in range(2):  # only the first 2 views are used
                img_obj = _load_rgb_or_none(paths[view])
                if img_obj is None:
                    continue
                col = view * feat_dim_per_img
                feature_cache[sample_idx, col:col + feat_dim_per_img] = compute_view_features(img_obj, ref_labels)

    np.save(cache_path, feature_cache)
    print(f"\n[Chunk Extraction Complete] Features saved to: {cache_path}")

# -------------------------------
# 6. Datasets and Model (2-view version)
# -------------------------------
class LetterboxResize:
    """
    Resize a PIL image to fit inside `size` while keeping its aspect ratio,
    then paste it centred on a canvas of exactly `size`.

    `size` is (width, height). `fill` is the canvas colour given as per-channel
    fractions; each component is multiplied by 255 and truncated to int.
    """

    def __init__(self, size=(224, 224), fill=(0.485, 0.456, 0.406)):
        self.size = size
        self.fill = fill

    @staticmethod
    def _fitted_size(src_size, dst_size):
        """Return the (width, height) of `src_size` scaled to fit inside `dst_size`."""
        iw, ih = src_size
        w, h = dst_size
        scale = min(w / iw, h / ih)
        # Very elongated inputs would truncate the short side to 0, which
        # Image.resize rejects; keep at least one pixel per side.
        return max(1, int(iw * scale)), max(1, int(ih * scale))

    def __call__(self, img):
        w, h = self.size
        nw, nh = self._fitted_size(img.size, self.size)
        img = img.resize((nw, nh), Image.BICUBIC)
        new_img = Image.new("RGB", self.size, tuple(int(c * 255) for c in self.fill))
        new_img.paste(img, ((w - nw) // 2, (h - nh) // 2))
        return new_img

class MultiViewBMIWithRefDataset(Dataset):
    """
    Two-view dataset: each item is (view 0, view 1, reference features, BMI).

    `samples[i]` supplies "img_paths" and "bmi"; `feature_cache[i]` supplies the
    precomputed reference features for the same index (10 values per the caller's
    cache). `transform`, when given, is applied to each view.
    """

    def __init__(self, samples, feature_cache, transform=None):
        self.samples = samples
        self.feature_cache = feature_cache  # shape: (number of samples, 10)
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def _load_view(self, path):
        # A missing or unreadable file is replaced by a black 224x224 RGB image;
        # only the unreadable case is logged.
        if os.path.exists(path):
            try:
                return Image.open(path).convert("RGB")
            except Exception as e:
                print(f"❌ [Dataset] Unable to open image: {path} | Error: {e}")
        return Image.new("RGB", (224, 224), (0, 0, 0))

    def __getitem__(self, idx):
        record = self.samples[idx]
        view_paths = record["img_paths"]
        target = float(record["bmi"])

        views = []
        for view_idx in (0, 1):  # only the first 2 views are used
            view = self._load_view(view_paths[view_idx])
            if self.transform is not None:
                view = self.transform(view)
            views.append(view)

        ref_feats_tensor = torch.tensor(self.feature_cache[idx], dtype=torch.float32)
        bmi_tensor = torch.tensor(target, dtype=torch.float32)
        first_view, second_view = views
        return first_view, second_view, ref_feats_tensor, bmi_tensor

class FusedMultiViewBMIModel(nn.Module):
    """
    Two-view BMI model: a shared ResNet101 backbone encodes both views, the
    precomputed reference features are projected to `ref_dim`, and everything is
    concatenated before a regression head and a classification head.

    Args:
        ref_dim: output width of the reference-feature projection.
        num_classes: number of outputs of the classification head.
        ref_in_dim: width of the incoming reference feature vector (10 for 2 views x 5).
        resnet_path: path to the ResNet101 state-dict checkpoint loaded into the backbone.
    """

    def __init__(self, ref_dim=256, num_classes=4, ref_in_dim=10,
                 resnet_path="/content/drive/MyDrive/resnet101-63fe2227.pth"):
        super().__init__()
        # Load ResNet101 (excluding the fully connected layers)
        base = models.resnet101(weights=None)
        # Load onto the CPU: the parameters are copied into `base` (still on CPU)
        # and the caller moves the whole module to its device afterwards.
        # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
        state_dict = torch.load(resnet_path, map_location="cpu")
        base.load_state_dict(state_dict)
        cnn_dim = base.fc.in_features  # width of the pooled backbone output (2048 for ResNet101)
        self.backbone = nn.Sequential(*list(base.children())[:-1])
        # Reference feature projection: input is ref_in_dim, output is ref_dim
        self.ref_projection = nn.Sequential(
            nn.Linear(ref_in_dim, ref_dim),
            nn.ReLU()
        )
        # 2 views concatenated, then the projected reference features appended
        fused_dim = 2 * cnn_dim + ref_dim
        self.fc_reg = nn.Linear(fused_dim, 1)
        self.fc_cls = nn.Linear(fused_dim, num_classes)

    def forward(self, img0, img1, ref_features):
        """Return (regression output, classification output) for a batch of two-view inputs."""
        x0 = self.backbone(img0).flatten(1)
        x1 = self.backbone(img1).flatten(1)
        fused_cnn = torch.cat([x0, x1], dim=1)
        ref_proj = self.ref_projection(ref_features)
        fused_all = torch.cat([fused_cnn, ref_proj], dim=1)
        out_reg = self.fc_reg(fused_all)
        out_cls = self.fc_cls(fused_all)
        return out_reg, out_cls

# -------------------------------
# Main Process: Offline feature precomputation → Data loading → Model training and validation
# -------------------------------
def _validation_metrics(gt_arr, pred_arr, tolerance=1.0):
    """Return the checkpoint metrics for ground-truth vs. predicted BMI arrays as a dict."""
    try:
        r2_val = r2_score(gt_arr, pred_arr)
    except Exception as e:
        # Best-effort: keep the run going, but make the fallback visible.
        print(f"[Warning] r2_score failed, recording 0.0 | Error: {e}")
        r2_val = 0.0
    error = pred_arr - gt_arr
    try:
        ks_val, _ = ks_2samp(gt_arr, pred_arr)
        wass_val = wasserstein_distance(gt_arr, pred_arr)
    except Exception as e:
        print(f"[Warning] KS/Wasserstein failed, recording 0.0 | Error: {e}")
        ks_val = wass_val = 0.0
    return {
        'r2': r2_val,
        'mae': mean_absolute_error(gt_arr, pred_arr),
        'tol_rate': np.mean(np.abs(gt_arr - pred_arr) <= tolerance),
        'mean_bias': np.mean(error),
        'error_std': np.std(error),
        'ks': ks_val,
        'wasserstein': wass_val,
    }


def _plot_validation(gt_arr, pred_arr, run_no, epoch):
    """Show the scatter plot and the distribution histogram for one checkpoint."""
    cp_min = min(np.min(gt_arr), np.min(pred_arr))
    cp_max = max(np.max(gt_arr), np.max(pred_arr))

    plt.figure(figsize=(8, 6))
    plt.scatter(gt_arr, pred_arr, alpha=0.6, label="Predictions")
    plt.plot([cp_min, cp_max], [cp_min, cp_max], "r--", label="Ideal")
    plt.xlabel("Ground Truth BMI")
    plt.ylabel("Predicted BMI")
    plt.title(f"Run {run_no} Epoch {epoch}: Scatter Plot")
    plt.legend()
    plt.show()
    plt.close()  # release the figure; many checkpoints are plotted per script run

    bins = np.linspace(cp_min, cp_max, 20)
    plt.figure(figsize=(8, 6))
    plt.hist(gt_arr, bins=bins, alpha=0.5, label="Ground Truth")
    plt.hist(pred_arr, bins=bins, alpha=0.5, label="Predictions")
    plt.xlabel("BMI")
    plt.ylabel("Count")
    plt.title(f"Run {run_no} Epoch {epoch}: Distribution Histogram")
    plt.legend()
    plt.show()
    plt.close()


if __name__ == "__main__":
    sample_json = "/content/drive/MyDrive/sample_list.json"
    cache_npy = "/content/drive/MyDrive/features_cache_2view.npy"  # Offline features: 10 dimensions
    ref_labels = ["a black paper", "a white ball"]

    with open(sample_json, "r") as f:
        samples = json.load(f)

    # Check if the npy cache file exists; if so, skip recomputation
    if not os.path.exists(cache_npy):
        print(f"Cache file {cache_npy} does not exist, starting offline feature computation...")
        offline_precompute_features_chunked(
            samples=samples,
            cache_path=cache_npy,
            ref_labels=ref_labels,
            chunk_size=30
        )
    else:
        print(f"Cache file {cache_npy} found, skipping offline feature computation.")
    feature_cache = np.load(cache_npy)  # shape: (number of samples, 10)

    # Rows of the cache are matched to samples purely by index, so a cache built
    # from a different sample list would silently pair features with the wrong images.
    if feature_cache.ndim != 2 or feature_cache.shape[0] != len(samples):
        raise ValueError(
            f"Feature cache {cache_npy} has shape {feature_cache.shape}, which does not match "
            f"{len(samples)} samples; delete the cache file to recompute it."
        )

    train_indices = [i for i, s in enumerate(samples) if s["split"] == "Training"]
    val_indices = [i for i, s in enumerate(samples) if s["split"] == "Validation"]
    train_samples = [samples[i] for i in train_indices]
    val_samples = [samples[i] for i in val_indices]
    # Fail before any training starts rather than at the first checkpoint.
    if not train_samples or not val_samples:
        raise ValueError(
            f"Need non-empty Training and Validation splits, got "
            f"{len(train_samples)} training and {len(val_samples)} validation samples."
        )

    train_features = feature_cache[train_indices]
    val_features = feature_cache[val_indices]

    transform = transforms.Compose([
        LetterboxResize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    train_dataset = MultiViewBMIWithRefDataset(train_samples, train_features, transform=transform)
    val_dataset = MultiViewBMIWithRefDataset(val_samples, val_features, transform=transform)

    # Pinned host memory only helps host-to-GPU copies.
    pin = DEVICE.type == "cuda"
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0, pin_memory=pin)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=0, pin_memory=pin)

    print(f"Train Dataset Size: {len(train_dataset)}")
    print(f"Val Dataset Size: {len(val_dataset)}")
    print(f"Train Loader Batch Count: {len(train_loader)}")
    print(f"Val Loader Batch Count: {len(val_loader)}")

    # Training strategy: train continuously for each experiment until max_epoch (max checkpoint epoch),
    # recording validation metrics at checkpoints (20, 25, 30, 35)
    checkpoint_epochs = [20, 25, 30, 35]
    max_epoch = max(checkpoint_epochs)
    num_runs = 5
    tolerance = 1.0  # absolute BMI error counted as "within tolerance"

    # 'results' stores metrics from multiple experiments for each checkpoint
    results = {ep: {'r2': [], 'mae': [], 'tol_rate': [], 'mean_bias': [], 'error_std': [], 'ks': [], 'wasserstein': []}
               for ep in checkpoint_epochs}

    for run in range(num_runs):
        print(f"\n==== Running experiment {run+1} ====")
        model = FusedMultiViewBMIModel(ref_dim=256, num_classes=4).to(DEVICE)
        criterion_reg = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr=1e-4)
        # Train continuously for max_epoch epochs in each experiment
        for epoch in range(1, max_epoch + 1):
            start_time = time.time()
            model.train()
            running_loss = 0.0
            for (img0, img1, ref_feats, bmi) in train_loader:
                img0, img1 = img0.to(DEVICE), img1.to(DEVICE)
                ref_feats = ref_feats.to(DEVICE)
                bmi = bmi.to(DEVICE).unsqueeze(1)
                optimizer.zero_grad()
                # Only the regression head is trained; the classification output is ignored.
                out_reg, _ = model(img0, img1, ref_feats)
                loss = criterion_reg(out_reg, bmi)
                loss.backward()
                optimizer.step()
                # Weight by batch size so the epoch loss is a per-sample mean.
                running_loss += loss.item() * bmi.size(0)
            epoch_loss = running_loss / len(train_dataset)
            print(f"[Run {run+1}] Epoch {epoch}/{max_epoch}, Loss: {epoch_loss:.4f}, Time: {time.time()-start_time:.2f}s")

            # If the current epoch is a checkpoint, perform validation and record metrics
            if epoch not in checkpoint_epochs:
                continue

            model.eval()
            all_preds, all_targets = [], []
            with torch.no_grad():
                for (img0, img1, ref_feats, bmi) in val_loader:
                    img0, img1 = img0.to(DEVICE), img1.to(DEVICE)
                    ref_feats = ref_feats.to(DEVICE)
                    out_reg, _ = model(img0, img1, ref_feats)
                    all_preds.extend(out_reg.cpu().numpy().flatten().tolist())
                    all_targets.extend(bmi.numpy().flatten().tolist())
            gt_arr = np.array(all_targets)
            pred_arr = np.array(all_preds)

            metrics = _validation_metrics(gt_arr, pred_arr, tolerance=tolerance)

            print(f"\n[Run {run+1}] Epoch {epoch} Validation Metrics:")
            print(f"R²: {metrics['r2']:.4f}")
            print(f"MAE: {metrics['mae']:.4f}")
            print(f"Tolerance Rate (±{tolerance}): {metrics['tol_rate']*100:.2f}%")
            print(f"Mean Bias: {metrics['mean_bias']:.4f}")
            print(f"Error Std: {metrics['error_std']:.4f}")
            print(f"KS: {metrics['ks']:.4f}")
            print(f"Wasserstein: {metrics['wasserstein']:.4f}\n")

            # Save the metrics for this checkpoint
            for key, value in metrics.items():
                results[epoch][key].append(value)

            # Plot validation result figures
            _plot_validation(gt_arr, pred_arr, run + 1, epoch)

    # Output the aggregated statistics over all runs for each checkpoint
    print("\n========== Overall Results Summary ==========")
    for ep in sorted(results.keys()):
        print(f"\nCheckpoint Epoch: {ep}")
        for key, values in results[ep].items():
            m = np.mean(values)
            s = np.std(values)
            print(f"{key.upper()}: Mean = {m:.4f}, Std = {s:.4f}")
