In [None]:
import multiprocessing
# Select the "spawn" start method for any worker processes created later.
# force=True overrides a start method that has already been set instead of
# raising RuntimeError.
# NOTE(review): the DataLoaders in this file are built with num_workers=0, so
# no loader worker processes are started here -- confirm this is still needed.
multiprocessing.set_start_method("spawn", force=True)

import os
import json
import time
import random
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
from dataclasses import dataclass
from typing import List, Dict, Optional, Union, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.metrics import mean_absolute_error, r2_score
from scipy.stats import ks_2samp, wasserstein_distance

from google.colab import drive
# Mount Google Drive at /content/drive. The checkpoint and sample-list paths
# used further down in this file point into /content/drive/MyDrive, so this
# has to run before they are read.
drive.mount('/content/drive')

# Device used for the model and every batch: the first CUDA GPU when one is
# available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

class LetterboxResize:
    """Resize an image to fit inside a fixed canvas and pad the remainder.

    The image is scaled by a single factor (so its aspect ratio is kept) until
    it fits inside ``size``, then pasted centred onto a new RGB canvas of
    exactly ``size`` that is pre-filled with ``fill``.

    Args:
        size: Canvas size as ``(width, height)`` in pixels.
        fill: Padding colour as three RGB fractions; each channel is
            multiplied by 255 and truncated to an int before use.
    """

    def __init__(self, size=(224, 224), fill=(0.485, 0.456, 0.406)):
        self.size = size
        self.fill = fill

    def _scaled_size(self, iw, ih):
        """Return the ``(width, height)`` an ``iw`` x ``ih`` image is resized to."""
        w, h = self.size
        scale = min(w / iw, h / ih)
        # Clamp each side to at least 1 px: for extreme aspect ratios the
        # short side truncates to 0, and PIL's resize() rejects a zero
        # width or height.
        return max(1, int(iw * scale)), max(1, int(ih * scale))

    def __call__(self, img):
        """Return a new RGB image of ``self.size`` holding the letterboxed ``img``."""
        # Apply the EXIF orientation first so img.size reflects the image as
        # it is meant to be displayed.
        img = ImageOps.exif_transpose(img)
        w, h = self.size
        nw, nh = self._scaled_size(*img.size)
        img = img.resize((nw, nh), Image.BICUBIC)
        new_img = Image.new("RGB", self.size, tuple(int(c * 255) for c in self.fill))
        # Integer division centres the resized image; any odd leftover pixel
        # ends up on the right/bottom edge.
        new_img.paste(img, ((w - nw) // 2, (h - nh) // 2))
        return new_img

class MultiViewBMIDataset(Dataset):
    """Dataset that yields three view images and one BMI value per sample.

    Every entry of ``sample_list`` is read as a mapping with an
    ``"img_paths"`` sequence and a ``"bmi"`` value.  A path that does not
    exist, or that fails to open, is replaced by a black 224x224 RGB
    placeholder.  ``transform`` (if given) is applied to every view,
    placeholders included.
    """

    def __init__(self, sample_list, transform=None):
        self.samples = sample_list
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def _load_view(self, path):
        """Load one view as RGB (or a black placeholder) and apply the transform."""
        view = None
        if os.path.exists(path):
            try:
                view = Image.open(path).convert("RGB")
            except Exception as e:
                # Unreadable files are reported and then replaced below.
                print(f"Error opening image: {path} | {e}")
        if view is None:
            view = Image.new("RGB", (224, 224), (0, 0, 0))
        if self.transform is not None:
            view = self.transform(view)
        return view

    def __getitem__(self, idx):
        """Return ``(view1, view2, view3, bmi)`` for the sample at ``idx``."""
        record = self.samples[idx]
        views = [self._load_view(p) for p in record["img_paths"]]
        bmi = float(record["bmi"])
        # Only the first three views are returned; fewer than three paths
        # raises IndexError, additional paths are loaded but dropped.
        return views[0], views[1], views[2], bmi

class MultiViewResNet101Baseline(nn.Module):
    """Three-view ResNet-101 baseline with a regression and a classification head.

    One shared ResNet-101 backbone (all children of the torchvision model
    except the last one) is applied to each of the three input images.  The
    flattened features are concatenated and fed to two linear heads.

    Args:
        num_classes: Output width of the classification head
            (e.g., 3 BMI bins).
        weights_path: Path to a ResNet-101 ``state_dict`` checkpoint that is
            loaded into the backbone before the final layer is removed.
    """

    def __init__(self, num_classes=3,
                 weights_path="/content/drive/MyDrive/resnet101-63fe2227.pth"):
        super().__init__()
        print("Loading ResNet101 weights...")
        base = models.resnet101(weights=None)
        # NOTE: torch.load unpickles the file -- only point this at trusted
        # checkpoints.  The weights are loaded onto the CPU because the module
        # is built on the CPU; the caller moves the finished model with .to().
        state_dict = torch.load(weights_path, map_location="cpu")
        base.load_state_dict(state_dict)
        print("Weights loaded successfully!")
        # Width of the features entering the original classifier (2048 for
        # ResNet-101); read from the model instead of being hard-coded.
        feat_dim = base.fc.in_features
        # Drop the last child (the final fully connected layer) and keep the
        # rest as a feature extractor.
        self.backbone = nn.Sequential(*list(base.children())[:-1])
        self.fc_reg = nn.Linear(feat_dim * 3, 1)
        self.fc_cls = nn.Linear(feat_dim * 3, num_classes)

    def forward(self, img1, img2, img3):
        """Return ``(out_reg, out_cls)`` for one batch of three views.

        Each view is passed through the shared backbone in turn and flattened
        from dim 1 onwards; the three feature tensors are concatenated along
        dim 1.  ``out_reg`` has one value per sample, ``out_cls`` has
        ``num_classes`` values per sample.
        """
        # presumably each view is an (N, 3, H, W) image batch -- confirm with caller
        feats = [torch.flatten(self.backbone(view), 1) for view in (img1, img2, img3)]
        fused = torch.cat(feats, dim=1)
        return self.fc_reg(fused), self.fc_cls(fused)

if __name__ == "__main__":
    # Train a fresh model `num_runs` times for `num_epochs` epochs each,
    # evaluate every run on the validation split, and report the mean and
    # standard deviation of each metric over all runs.
    sample_json = "/content/drive/MyDrive/sample_list.json"
    num_epochs = 20
    num_runs = 10
    tolerance = 1.0  # |prediction - target| threshold for the tolerance rate

    with open(sample_json, "r", encoding="utf-8") as f:
        samples = json.load(f)
    train_samples = [s for s in samples if s["split"] == "Training"]
    val_samples = [s for s in samples if s["split"] == "Validation"]
    if not train_samples or not val_samples:
        # Fail before any training: the epoch loss and the validation metrics
        # below cannot be computed on an empty split.
        raise ValueError(
            f"Empty split in {sample_json}: "
            f"{len(train_samples)} training / {len(val_samples)} validation samples"
        )

    # The same deterministic preprocessing is used for training and validation.
    transform = transforms.Compose([
        LetterboxResize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    train_dataset = MultiViewBMIDataset(train_samples, transform=transform)
    val_dataset = MultiViewBMIDataset(val_samples, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=0, pin_memory=True)

    print(f"Train Dataset Size: {len(train_dataset)}")
    print(f"Val Dataset Size: {len(val_dataset)}")

    results = []

    for run in range(num_runs):
        print(f"\n========== Run {run+1} ==========")
        # Every run starts from a newly constructed model and optimizer.
        model = MultiViewResNet101Baseline(num_classes=3).to(DEVICE)
        criterion_reg = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr=1e-4)

        # ---- Training: only the regression output contributes to the loss.
        for epoch in range(1, num_epochs + 1):
            model.train()
            running_loss = 0.0
            start_time = time.time()
            for img1, img2, img3, bmi in train_loader:
                img1, img2, img3 = img1.to(DEVICE), img2.to(DEVICE), img3.to(DEVICE)
                # Reshape targets to (batch, 1) to match the regression output.
                targets = bmi.to(DEVICE).float().unsqueeze(1)
                optimizer.zero_grad()
                out_reg, _ = model(img1, img2, img3)
                loss = criterion_reg(out_reg, targets)
                loss.backward()
                optimizer.step()
                # Weight the batch mean by batch size so the epoch loss is a
                # per-sample average even when the last batch is smaller.
                running_loss += loss.item() * targets.size(0)
            epoch_loss = running_loss / len(train_dataset)
            print(f"[Run {run+1}] Epoch {epoch}/{num_epochs}, Loss: {epoch_loss:.4f}, Time: {time.time()-start_time:.2f}s")

        # ---- Validation after the final epoch of this run.
        model.eval()
        all_preds, all_targets = [], []
        with torch.no_grad():
            for img1, img2, img3, bmi in val_loader:
                img1, img2, img3 = img1.to(DEVICE), img2.to(DEVICE), img3.to(DEVICE)
                out_reg, _ = model(img1, img2, img3)
                all_preds.extend(out_reg.flatten().tolist())
                # Targets never need to visit the GPU here; they are cast to
                # float32 exactly like the training targets.
                all_targets.extend(bmi.float().tolist())
        gt_arr = np.array(all_targets)
        pred_arr = np.array(all_preds)

        try:
            r2_val = r2_score(gt_arr, pred_arr)
        except Exception as e:
            # Keep the run going, but make the substituted value visible.
            print(f"Warning: r2_score failed ({e}); recording 0.0")
            r2_val = 0.0
        mae_val = mean_absolute_error(gt_arr, pred_arr)
        error = pred_arr - gt_arr
        tol_rate = np.mean(np.abs(error) <= tolerance)
        mean_bias = np.mean(error)
        error_std = np.std(error)
        try:
            ks_val, _ = ks_2samp(gt_arr, pred_arr)
            wass_val = wasserstein_distance(gt_arr, pred_arr)
        except Exception as e:
            print(f"Warning: distribution metrics failed ({e}); recording 0.0")
            ks_val = wass_val = 0.0

        print(f"\n[Run {run+1}] Validation Metrics:")
        print(f"R²: {r2_val:.4f}")
        print(f"MAE: {mae_val:.4f}")
        print(f"Tolerance Rate (±{tolerance}): {tol_rate*100:.2f}%")
        print(f"Mean Bias: {mean_bias:.4f}")
        print(f"Error Std: {error_std:.4f}")
        print(f"KS: {ks_val:.4f}")
        print(f"Wasserstein: {wass_val:.4f}\n")

        results.append({
            'r2': r2_val,
            'mae': mae_val,
            'tol_rate': tol_rate,
            'mean_bias': mean_bias,
            'error_std': error_std,
            'ks': ks_val,
            'wasserstein': wass_val
        })

    # ---- Aggregate every metric over all runs.
    r2_vals = [res['r2'] for res in results]
    mae_vals = [res['mae'] for res in results]
    tol_vals = [res['tol_rate'] for res in results]
    mean_bias_vals = [res['mean_bias'] for res in results]
    err_std_vals = [res['error_std'] for res in results]
    ks_vals = [res['ks'] for res in results]
    wass_vals = [res['wasserstein'] for res in results]

    # The header is derived from the settings above so it cannot go stale when
    # num_runs or num_epochs is changed.
    print(f"\n========== Overall Results ({num_runs} runs, {num_epochs} epochs each) ==========")
    print(f"R²: Mean = {np.mean(r2_vals):.4f}, Std = {np.std(r2_vals):.4f}")
    print(f"MAE: Mean = {np.mean(mae_vals):.4f}, Std = {np.std(mae_vals):.4f}")
    print(f"Tolerance Rate: Mean = {np.mean(tol_vals)*100:.2f}%, Std = {np.std(tol_vals)*100:.2f}%")
    print(f"Mean Bias: Mean = {np.mean(mean_bias_vals):.4f}, Std = {np.std(mean_bias_vals):.4f}")
    print(f"Error Std: Mean = {np.mean(err_std_vals):.4f}, Std = {np.std(err_std_vals):.4f}")
    print(f"KS: Mean = {np.mean(ks_vals):.4f}, Std = {np.std(ks_vals):.4f}")
    print(f"Wasserstein: Mean = {np.mean(wass_vals):.4f}, Std = {np.std(wass_vals):.4f}")
