# In [None]:
import multiprocessing
multiprocessing.set_start_method("spawn", force=True)

import os
import json
import time
import random
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
from dataclasses import dataclass
from typing import List, Dict, Optional, Union, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.metrics import mean_absolute_error, r2_score
from scipy.stats import ks_2samp, wasserstein_distance

# Colab only: mount Google Drive. The sample-list JSON, the images and the
# ResNet101 checkpoint used below are all read from /content/drive paths.
from google.colab import drive

drive.mount('/content/drive')

# Single device for the model, every batch and the class-weight tensor:
# the first CUDA GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")
print("Current device:", DEVICE)

# Data Preprocessing: LetterboxResize (consistent with experimental group)
class LetterboxResize:
    """Aspect-preserving resize followed by padding to a fixed size ("letterbox").

    The image is scaled so it fits inside ``size`` without distortion, then
    pasted centred on a canvas of exactly ``size`` filled with ``fill``.

    Args:
        size: output ``(width, height)`` in pixels.
        fill: padding colour as RGB floats in [0, 1]; scaled by 255 and
            truncated to ints. The default is the ImageNet channel mean, so
            with an ImageNet ``Normalize`` afterwards the border lands near 0.
    """

    def __init__(self, size=(224, 224), fill=(0.485, 0.456, 0.406)):
        self.size = size
        self.fill = fill

    def __call__(self, img):
        """Return a new RGB PIL image of exactly ``self.size`` containing ``img``.

        Raises:
            ValueError: if ``img`` has zero width or height.
        """
        # Apply the EXIF orientation tag so rotated photos are not letterboxed sideways.
        img = ImageOps.exif_transpose(img)
        iw, ih = img.size
        w, h = self.size
        if iw <= 0 or ih <= 0:
            raise ValueError(f"Cannot letterbox an empty image of size {img.size}")
        scale = min(w / iw, h / ih)
        # Clamp to >= 1 px: for very elongated inputs (e.g. 1000x2) the short side
        # would otherwise truncate to 0 and PIL's resize() would raise.
        nw = max(1, int(iw * scale))
        nh = max(1, int(ih * scale))
        img = img.resize((nw, nh), Image.BICUBIC)
        # int() truncation kept on purpose so the pad colour matches earlier runs.
        pad_color = tuple(int(c * 255) for c in self.fill)
        new_img = Image.new("RGB", self.size, pad_color)
        new_img.paste(img, ((w - nw) // 2, (h - nh) // 2))
        return new_img

# Dataset definition: MultiViewBMIDataset (2-view version)
class MultiViewBMIDataset(Dataset):
    """Two-view BMI dataset returning ``(view1, view2, bmi)`` per item.

    Each element of ``sample_list`` is expected to be a dict with an
    ``"img_paths"`` list (only the first two entries are used) and a ``"bmi"``
    value accepted by ``float()``. Views are PIL RGB images, or whatever
    ``transform`` returns for them.

    A view that is unavailable is replaced by a black 224x224 RGB placeholder,
    so one bad sample does not abort an epoch. A view counts as unavailable
    when its path does not exist, the file fails to open or decode, or the
    sample lists fewer than two paths.

    Args:
        sample_list: list of sample dicts as described above.
        transform: optional callable applied to every view.
    """

    def __init__(self, sample_list, transform=None):
        self.samples = sample_list
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _placeholder():
        """Black 224x224 RGB image used in place of an unavailable view."""
        return Image.new("RGB", (224, 224), (0, 0, 0))

    @classmethod
    def _load_view(cls, path):
        """Open ``path`` as an RGB PIL image, or return the placeholder."""
        if not os.path.exists(path):
            return cls._placeholder()
        try:
            # ``with`` closes the file handle deterministically; convert() returns a
            # new, fully decoded image, so it stays usable after the file closes.
            with Image.open(path) as raw:
                return raw.convert("RGB")
        except Exception as e:  # deliberate best-effort: log and fall back
            print(f"Error opening image: {path} | {e}")
            return cls._placeholder()

    def __getitem__(self, idx):
        sample = self.samples[idx]
        paths = sample["img_paths"][:2]
        views = []
        for view_idx in range(2):
            # A sample with fewer than two paths used to raise IndexError here;
            # pad it with the same placeholder used for missing files instead.
            if view_idx < len(paths):
                img = self._load_view(paths[view_idx])
            else:
                img = self._placeholder()
            if self.transform is not None:
                img = self.transform(img)
            views.append(img)
        bmi = float(sample["bmi"])
        return views[0], views[1], bmi

# BMI classification function (3 classes)
def get_bmi_class_new(bmi):
    """Bucket a BMI value into one of three classes.

    Returns:
        0 if ``bmi < 17.28``; 1 if ``17.28 <= bmi < 25.57``; 2 otherwise
        (NaN also lands in 2, since every ``<`` comparison with NaN is False).
    """
    upper_bounds = (17.28, 25.57)  # exclusive upper bounds of classes 0 and 1
    for label, upper in enumerate(upper_bounds):
        if bmi < upper:
            return label
    return len(upper_bounds)

# Model definition: FusedMultiViewBMIModel_2view
class FusedMultiViewBMIModel_2view(nn.Module):
    """Two-view BMI model: shared ResNet101 backbone plus two linear heads.

    Both views pass through the same backbone (shared weights). Each yields a
    pooled feature vector of ``base.fc.in_features`` dims (2048 for ResNet101).
    The two vectors are concatenated and fed to a regression head (one BMI
    value) and a classification head (``num_classes`` logits).

    Args:
        num_classes: number of BMI classes for the classification head.
        resnet_path: path to a ResNet101 ``state_dict`` checkpoint.
    """

    def __init__(self, num_classes=3, resnet_path="/content/drive/MyDrive/resnet101-63fe2227.pth"):
        super().__init__()
        print("Loading ResNet101 weights...")
        base = models.resnet101(weights=None)
        # The freshly built module lives on the CPU, so load the checkpoint there
        # too; mapping it to the GPU would only create a transient extra copy.
        state_dict = torch.load(resnet_path, map_location="cpu")
        base.load_state_dict(state_dict)
        print("Weights loaded successfully!")
        feat_dim = base.fc.in_features  # 2048 for ResNet101
        # Drop the final ImageNet fc layer; keep everything up to global pooling.
        self.backbone = nn.Sequential(*list(base.children())[:-1])
        fused_dim = 2 * feat_dim  # two views concatenated
        self.fc_reg = nn.Linear(fused_dim, 1)
        self.fc_cls = nn.Linear(fused_dim, num_classes)

    def _extract(self, x):
        """Backbone features for one view, flattened to ``[B, feat_dim]``."""
        return torch.flatten(self.backbone(x), 1)

    def forward(self, img1, img2):
        """Return ``(bmi_pred [B, 1], class_logits [B, num_classes])``."""
        fused = torch.cat([self._extract(img1), self._extract(img2)], dim=1)
        return self.fc_reg(fused), self.fc_cls(fused)

def _bmi_class_targets(bmi):
    """Map a 1-D batch tensor of BMI values to class indices.

    Uses ``get_bmi_class_new`` and returns a ``torch.long`` tensor on ``DEVICE``.
    """
    return torch.tensor([get_bmi_class_new(b) for b in bmi.tolist()],
                        dtype=torch.long, device=DEVICE)


def _try_metric(name, compute):
    """Evaluate ``compute()`` as a float; on failure log the error and return NaN.

    NaN (not 0.0) is used so a failed metric can never pass for a real score.
    A KS or Wasserstein value of 0.0 would read as a *perfect* match.
    """
    try:
        return float(compute())
    except Exception as e:
        print(f"Could not compute {name}: {e}")
        return float("nan")


# Training and validation process (includes classification loss)
if __name__ == "__main__":
    # Re-applied before every weight-factor run so the runs share the same head
    # initialisation and shuffle order and differ only in the loss weighting
    # (cuDNN kernels may still add some nondeterminism).
    SEED = 42

    # Load sample list JSON
    sample_json = "/content/drive/MyDrive/sample_list.json"
    with open(sample_json, "r") as f:
        samples = json.load(f)
    # Split samples based on "split" field (case-sensitive)
    train_samples = [s for s in samples if s["split"] == "Training"]
    val_samples = [s for s in samples if s["split"] == "Validation"]
    # Fail fast on an empty split (e.g. a capitalisation mismatch). Otherwise an
    # empty train set dies with ZeroDivisionError, and an empty validation set
    # only fails inside sklearn after a full training run.
    if not train_samples or not val_samples:
        raise ValueError(
            f"Empty split in {sample_json}: {len(train_samples)} 'Training' and "
            f"{len(val_samples)} 'Validation' samples (matching is case-sensitive)"
        )

    transform = transforms.Compose([
        LetterboxResize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    train_dataset = MultiViewBMIDataset(train_samples, transform=transform)
    val_dataset = MultiViewBMIDataset(val_samples, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=0, pin_memory=True)

    print(f"Train Dataset Size: {len(train_dataset)}")
    print(f"Val Dataset Size: {len(val_dataset)}")

    # Experiment settings: different class loss weight factors (vector: [wf, 1, wf])
    weight_factors = [1, 2, 3, 5, 8]
    num_epochs = 20

    # Run experiments for different weight factors
    for wf in weight_factors:
        print(f"\n====== Classification Loss Weight Factor: {wf} (multiply head/tail by {wf}, middle is 1) ======")
        random.seed(SEED)
        np.random.seed(SEED)
        torch.manual_seed(SEED)  # also seeds CUDA and the DataLoader shuffle

        # Up-weight the two tail classes (0 and 2) relative to the middle class.
        class_weight = torch.tensor([wf, 1.0, wf], dtype=torch.float32, device=DEVICE)
        criterion_reg = nn.MSELoss()
        criterion_cls = nn.CrossEntropyLoss(weight=class_weight)
        model = FusedMultiViewBMIModel_2view(num_classes=3).to(DEVICE)
        optimizer = optim.Adam(model.parameters(), lr=1e-4)

        # Training phase: loss = MSE(BMI) + weighted CE(BMI class), unscaled sum.
        for epoch in range(1, num_epochs + 1):
            model.train()
            running_loss = 0.0
            start_time = time.time()
            for img1, img2, bmi in train_loader:
                img1, img2 = img1.to(DEVICE), img2.to(DEVICE)
                targets = bmi.to(DEVICE).float().unsqueeze(1)
                cls_targets = _bmi_class_targets(bmi)

                optimizer.zero_grad()
                out_reg, out_cls = model(img1, img2)
                loss_reg = criterion_reg(out_reg, targets)
                loss_cls = criterion_cls(out_cls, cls_targets)
                loss = loss_reg + loss_cls
                loss.backward()
                optimizer.step()
                # Batch-size-weighted so the epoch average is per sample.
                running_loss += loss.item() * targets.size(0)
            epoch_loss = running_loss / len(train_dataset)
            print(f"Epoch {epoch}/{num_epochs}, Loss: {epoch_loss:.4f}, Time: {time.time()-start_time:.2f}s")

        # Validation phase (evaluated once, after the final epoch)
        model.eval()
        all_preds, all_targets, all_cls_preds, all_cls_targets = [], [], [], []
        with torch.no_grad():
            for img1, img2, bmi in val_loader:
                img1, img2 = img1.to(DEVICE), img2.to(DEVICE)
                out_reg, out_cls = model(img1, img2)
                all_preds.extend(out_reg.flatten().tolist())
                # float32 targets, matching the precision used in training.
                all_targets.extend(bmi.float().tolist())
                all_cls_preds.extend(torch.argmax(out_cls, dim=1).tolist())
                all_cls_targets.extend(_bmi_class_targets(bmi).tolist())

        gt_arr = np.array(all_targets)
        pred_arr = np.array(all_preds)
        error = pred_arr - gt_arr
        r2_val = _try_metric("R²", lambda: r2_score(gt_arr, pred_arr))
        mae_val = mean_absolute_error(gt_arr, pred_arr)
        tol_rate = np.mean(np.abs(error) <= 1.0)  # share within ±1 BMI unit
        mean_bias = np.mean(error)
        error_std = np.std(error)
        # Distribution-level agreement between predicted and true BMI values.
        ks_val = _try_metric("KS", lambda: ks_2samp(gt_arr, pred_arr)[0])
        wass_val = _try_metric("Wasserstein", lambda: wasserstein_distance(gt_arr, pred_arr))
        cls_acc = np.mean(np.array(all_cls_preds) == np.array(all_cls_targets))

        print("\nValidation Metrics:")
        print(f"R²: {r2_val:.4f}")
        print(f"MAE: {mae_val:.4f}")
        print(f"Tolerance Rate (±1.0): {tol_rate*100:.2f}%")
        print(f"Mean Bias: {mean_bias:.4f}")
        print(f"Error Std: {error_std:.4f}")
        print(f"KS: {ks_val:.4f}")
        print(f"Wasserstein: {wass_val:.4f}")
        print(f"Classification Accuracy: {cls_acc*100:.2f}%")
