In [None]:
#!/usr/bin/env python
import multiprocessing
# Use the "spawn" start method for any child processes (presumably so CUDA can be
# used safely in DataLoader workers). force=True avoids the RuntimeError that
# set_start_method raises when a start method has already been set.
# NOTE(review): both DataLoaders in the __main__ section use num_workers=0, so no
# worker processes are started and this setting currently has no practical effect.
multiprocessing.set_start_method("spawn", force=True)

import os
import json
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from sklearn.metrics import r2_score, mean_absolute_error
from scipy.stats import ks_2samp, wasserstein_distance
import matplotlib.pyplot as plt

# Colab-only setup: mount Google Drive at /content/drive. The sample list JSON and
# the ResNet-101 weight file used below are read from /content/drive/MyDrive.
# force_remount=True remounts even if Drive is already mounted.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Run on CUDA when available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LetterboxResize:
    """Scale a PIL image to fit inside ``size`` keeping its aspect ratio, then pad.

    The image is resized (up or down) by the largest factor that keeps it inside
    the target box, then pasted centred on a new RGB canvas of exactly ``size``
    filled with ``fill``.

    Args:
        size: target ``(width, height)``, in the same order as PIL's ``Image.size``.
        fill: RGB colour of the padding.
    """

    def __init__(self, size=(224, 224), fill=(0, 0, 0)):
        self.size, self.fill = size, fill

    @staticmethod
    def _fit(src_size, dst_size):
        """Return ``((new_w, new_h), (left, top))`` for letterboxing src into dst."""
        iw, ih = src_size
        w, h = dst_size
        scale = min(w / iw, h / ih)
        # round() rather than int(): products such as 49 * (1/49) evaluate to
        # 0.999..., which int() truncates one pixel short (or to zero). Clamping
        # to at least 1 keeps extreme aspect ratios (e.g. 1000x1) from requesting
        # a zero-sized resize, which PIL rejects; clamping to the box size guards
        # against float overshoot.
        nw = min(w, max(1, round(iw * scale)))
        nh = min(h, max(1, round(ih * scale)))
        return (nw, nh), ((w - nw) // 2, (h - nh) // 2)

    def __call__(self, img):
        """Return a new RGB image of ``self.size`` containing the letterboxed ``img``."""
        (nw, nh), offset = self._fit(img.size, self.size)
        resized = img.resize((nw, nh), Image.BICUBIC)
        canvas = Image.new("RGB", self.size, self.fill)
        canvas.paste(resized, offset)
        return canvas

class MultiViewBMIDataset3(Dataset):
    """Dataset yielding three image views and a float32 BMI target per sample.

    Each entry of ``samples`` is a mapping with an ``"img_paths"`` sequence (the
    first three paths are used, in order) and a ``"bmi"`` value accepted by
    ``float()``. Images that are missing or cannot be decoded are replaced by a
    black 224x224 RGB placeholder, so one bad file does not abort an epoch.

    Args:
        samples: list of sample mappings as described above.
        transform: optional callable applied to every PIL image, placeholders
            included.
    """

    def __init__(self, samples, transform=None):
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` as an RGB image, or return a black 224x224 placeholder."""
        try:
            # The context manager closes the underlying file handle; convert()
            # returns an independent image, so it stays valid after the close.
            with Image.open(path) as im:
                return im.convert("RGB")
        except (OSError, ValueError, Image.DecompressionBombError):
            # OSError covers missing files / directories as well as PIL's
            # UnidentifiedImageError and truncated-image errors; ValueError covers
            # unsupported mode conversions. Narrowed from a bare ``except:`` so
            # KeyboardInterrupt and SystemExit are no longer swallowed.
            return Image.new("RGB", (224, 224), (0, 0, 0))

    def __getitem__(self, idx):
        """Return ``(view0, view1, view2, bmi)`` for sample ``idx``.

        Raises:
            IndexError: if the sample lists fewer than three image paths.
        """
        s = self.samples[idx]
        imgs = []
        # Only three views are returned, so don't decode/transform any extras.
        for p in s["img_paths"][:3]:
            img = self._load_rgb(p)
            if self.transform:
                img = self.transform(img)
            imgs.append(img)
        bmi = float(s["bmi"])
        return imgs[0], imgs[1], imgs[2], torch.tensor(bmi, dtype=torch.float32)

def get_bmi_class(b):
    """Map a BMI value to a coarse class index.

    Returns 0 when ``b < 17.28``, 1 when ``17.28 <= b < 25.57`` and 2 otherwise.
    A value that compares false against both cut-offs (e.g. NaN) therefore
    lands in class 2.
    """
    # NOTE(review): these cut-offs differ from the WHO 18.5 / 25.0 boundaries;
    # presumably chosen for this dataset — confirm.
    under_cut, over_cut = 17.28, 25.57
    return 0 if b < under_cut else 1 if b < over_cut else 2

class FusedMultiViewBMIModel3(nn.Module):
    """Shared ResNet-101 encoder over selected views, fused by a linear BMI head.

    Every selected view is passed through the same backbone (shared weights).
    The pooled 2048-d features are concatenated in ``view_indices`` order and
    mapped to one regression output per sample.

    Args:
        view_indices: which of the three ``forward`` inputs to use (values 0-2).
        weights_path: file holding a ResNet-101 state dict for the backbone.

    Raises:
        ValueError: if ``view_indices`` is empty or contains a value outside 0-2.
    """

    def __init__(
        self,
        view_indices,
        weights_path="/content/drive/MyDrive/resnet101-63fe2227.pth",
    ):
        super().__init__()
        view_indices = list(view_indices)
        # Validate up front: an empty list would otherwise only fail on an empty
        # torch.cat in forward(), and a bad index on an IndexError there.
        if not view_indices or any(i not in (0, 1, 2) for i in view_indices):
            raise ValueError(
                f"view_indices must be a non-empty selection from (0, 1, 2), "
                f"got {view_indices!r}"
            )
        base = models.resnet101(weights=None)
        # Load onto the CPU: the state is copied into the (CPU) module anyway and
        # the caller moves the whole model to its device afterwards, so mapping to
        # the GPU here only created a transient extra copy of the weights.
        sd = torch.load(weights_path, map_location="cpu")
        base.load_state_dict(sd)
        # Drop the final fc layer; the remaining stages end in global average
        # pooling, so the backbone output flattens to 2048 features per view.
        self.backbone = nn.Sequential(*list(base.children())[:-1])
        self.view_indices = view_indices
        dim = 2048 * len(view_indices)
        self.fc = nn.Linear(dim, 1)

    def forward(self, img1, img2, img3):
        """Return a ``(batch, 1)`` tensor of BMI predictions from the selected views."""
        views = (img1, img2, img3)
        # Views are encoded in separate backbone calls, so in train mode each
        # view's batch is normalised by BatchNorm on its own.
        feats = [self.backbone(views[i]).flatten(1) for i in self.view_indices]
        return self.fc(torch.cat(feats, dim=1))

def compute_metrics(gt, pred):
    """Compare predictions with ground truth and return a dict of metrics.

    Both arguments are array-likes of equal length. Keys, in insertion order:

    - ``r2``: coefficient of determination (sklearn ``r2_score``).
    - ``mae``: mean absolute error.
    - ``tol_rate``: fraction of predictions within 1.0 of the target.
    - ``mean_bias``: mean of ``pred - gt`` (positive means over-prediction).
    - ``error_std``: population std (ddof=0) of ``pred - gt``.
    - ``ks``: two-sample Kolmogorov-Smirnov statistic between the value sets.
    - ``wasserstein``: 1-D Wasserstein distance between the value sets.
    """
    y_true = np.asarray(gt)
    y_pred = np.asarray(pred)
    residuals = y_pred - y_true
    within_tolerance = np.abs(residuals) <= 1.0

    metrics = {}
    metrics["r2"] = r2_score(y_true, y_pred)
    metrics["mae"] = mean_absolute_error(y_true, y_pred)
    metrics["tol_rate"] = np.mean(within_tolerance)
    metrics["mean_bias"] = residuals.mean()
    metrics["error_std"] = residuals.std()
    # Index 0 of the KS result is the test statistic (index 1 is the p-value).
    metrics["ks"] = ks_2samp(y_true, y_pred)[0]
    metrics["wasserstein"] = wasserstein_distance(y_true, y_pred)
    return metrics

if __name__ == "__main__":
    # ---- Data ---------------------------------------------------------------
    # The sample list is a JSON array of objects; this script reads "split"
    # ("Training" / "Validation"), "img_paths" and "bmi" from each entry.
    sample_json = "/content/drive/MyDrive/sample_list.json"
    with open(sample_json, encoding="utf-8") as f:
        samples = json.load(f)
    train = [s for s in samples if s["split"]=="Training"]
    val   = [s for s in samples if s["split"]=="Validation"]
    # Fail fast. Otherwise an empty validation split only surfaces as an error
    # from the metric functions after the first run has trained for EPOCHS
    # epochs, and an empty training split silently evaluates untrained models.
    if not train or not val:
        raise ValueError(
            f"{sample_json}: expected non-empty Training and Validation splits, "
            f"got {len(train)} training / {len(val)} validation samples"
        )

    # Letterbox to 224x224 (aspect ratio preserved), then apply the standard
    # ImageNet mean/std normalisation used with pretrained ResNets.
    transform = transforms.Compose([
        LetterboxResize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ])

    train_ds = MultiViewBMIDataset3(train, transform=transform)
    val_ds   = MultiViewBMIDataset3(val,   transform=transform)
    # Pinned host memory only helps host->GPU copies; on a CPU-only runtime
    # requesting it just produces a warning.
    use_pin_memory = DEVICE.type == "cuda"
    train_loader = DataLoader(train_ds, batch_size=8, shuffle=True,  num_workers=0, pin_memory=use_pin_memory)
    val_loader   = DataLoader(val_ds,   batch_size=8, shuffle=False, num_workers=0, pin_memory=use_pin_memory)

    # (label, indices of the three dataset views fed to the model)
    experiments = [
        ("Single [0]", [0]),
        ("Single [1]", [1]),
        ("Single [2]", [2]),
        ("Dual [0,1]", [0,1]),
        ("Dual [0,2]", [0,2]),
        ("Dual [1,2]", [1,2]),
        ("Triple [0,1,2]", [0,1,2]),
    ]

    EPOCHS = 20
    LR = 1e-4
    NUM_RUNS = 4  # independent re-trainings per experiment (no seed is fixed, so runs differ)

    for name, views in experiments:
        print(f"\n=== Experiment: {name} ===")
        # metric name -> one value per run. Keys are taken from compute_metrics()
        # itself so this table cannot drift out of sync with it.
        summary = {}
        # BMI class index (get_bmi_class) -> per-run MAE within that class.
        bucket_maes = {0:[],1:[],2:[]}

        for run in range(1, NUM_RUNS+1):
            # Drop the previous run's model, optimizer state and last training loss
            # (whose autograd graph can still reference the old parameters) so the
            # old weights can be freed before a fresh ResNet-101 is allocated.
            model = opt = loss = None
            model = FusedMultiViewBMIModel3(views).to(DEVICE)
            opt   = optim.Adam(model.parameters(), lr=LR)
            mse   = nn.MSELoss()
            # train
            for _ in range(EPOCHS):
                model.train()
                for i1,i2,i3,bmi in train_loader:
                    i1,i2,i3,bmi = [x.to(DEVICE) for x in (i1,i2,i3,bmi)]
                    opt.zero_grad()
                    out = model(i1,i2,i3)
                    # out is (B, 1); unsqueeze the (B,) targets to match.
                    loss = mse(out, bmi.unsqueeze(1))
                    loss.backward()
                    opt.step()
            # eval
            model.eval()
            preds, gts = [],[]
            with torch.no_grad():
                for i1,i2,i3,bmi in val_loader:
                    i1,i2,i3 = [x.to(DEVICE) for x in (i1,i2,i3)]
                    out = model(i1,i2,i3)
                    preds += out.cpu().flatten().tolist()
                    gts   += bmi.flatten().tolist()
            m = compute_metrics(gts, preds)
            for k, v in m.items():
                summary.setdefault(k, []).append(v)
            # bucket MAE: absolute error averaged within each ground-truth BMI class
            errs = np.abs(np.array(preds)-np.array(gts))
            classes = np.array([get_bmi_class(x) for x in gts])
            for b in (0,1,2):
                sel = errs[classes==b]
                bucket_maes[b].append(sel.mean() if sel.size>0 else np.nan)
            print(f" Run {run}/{NUM_RUNS} — " +
                  ", ".join(f"{k}={m[k]:.4f}" for k in summary.keys()))

        # print summary: mean ± population std across runs
        print(f"\n--- Summary for {name} ---")
        for k, vals in summary.items():
            a = np.array(vals)
            print(f"  {k.upper():10s}: {a.mean():.4f} ± {a.std():.4f}")
        # print & plot bucket MAE (NaN entries from runs with an empty bucket are ignored)
        names = ["Under","Normal","Over"]
        mean_mae = [np.nanmean(bucket_maes[b]) for b in (0,1,2)]
        std_mae  = [np.nanstd(bucket_maes[b])  for b in (0,1,2)]
        for b,nm in enumerate(names):
            print(f"  Bucket {nm:6s} MAE: {mean_mae[b]:.4f} ± {std_mae[b]:.4f}")
        plt.figure()
        plt.bar(names, mean_mae, yerr=std_mae, capsize=5)
        plt.ylabel("MAE"); plt.title(f"Bucket MAE: {name}")
        plt.show()
