In [None]:
# Cell 1: setup — imports, mount, classes, dataloaders

# Select the "spawn" start method for any processes created through
# multiprocessing later on. force=True lets this call succeed even when a start
# method was already chosen, so re-running the cell does not raise.
import multiprocessing
multiprocessing.set_start_method("spawn", force=True)

import json
import os
import time
import warnings

import numpy as np
from PIL import Image, ImageOps

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models

# Mount Google Drive at /content/drive. The pretrained weights, the sample list
# and the saved checkpoint referenced in this notebook all live under
# /content/drive/MyDrive. Only available inside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

# --- transforms ---
class LetterboxResize:
    """Fit a PIL image inside a fixed canvas, preserving its aspect ratio.

    The image is scaled by a single factor so that it fits within ``size``,
    then pasted centred onto a solid-colour RGB canvas of exactly ``size``.

    Args:
        size: ``(width, height)`` of the output canvas in pixels.
        fill: Per-channel padding colour. Each component is multiplied by 255
            and truncated to an int before being used as the canvas colour.
    """

    def __init__(self, size=(224, 224), fill=(0.485, 0.456, 0.406)):
        self.size, self.fill = size, fill

    def __call__(self, img):
        """Return a new RGB image of ``self.size`` containing the letterboxed ``img``."""
        # Apply the EXIF orientation tag (if any) before measuring the image.
        img = ImageOps.exif_transpose(img)
        iw, ih = img.size
        w, h = self.size
        bg = Image.new('RGB', self.size, tuple(int(c*255) for c in self.fill))

        # A zero-sized input has nothing to paste and would divide by zero below.
        if iw == 0 or ih == 0:
            return bg

        scale = min(w/iw, h/ih)
        # For extreme aspect ratios the short side truncates to 0, which
        # Image.resize rejects; keep every side within [1, canvas side].
        nw = min(w, max(1, int(iw*scale)))
        nh = min(h, max(1, int(ih*scale)))

        img = img.resize((nw, nh), Image.BICUBIC)
        bg.paste(img, ((w-nw)//2, (h-nh)//2))
        return bg

# --- dataset ---
class MultiViewBMIDataset(Dataset):
    """Dataset yielding two image views and a scalar BMI target per sample.

    Each entry of ``samples`` is indexed with ``'img_paths'`` (a sequence of
    file paths, of which only the first two are used) and ``'bmi'`` (converted
    with ``float``). A view that is missing, unreadable, or absent from
    ``'img_paths'`` is replaced by a black 224x224 RGB image and a warning is
    emitted, so one bad file does not abort an epoch.

    Args:
        samples: Sequence of sample records as described above.
        transform: Optional callable applied to every view (including blank
            placeholder views) before it is returned.
    """

    _NUM_VIEWS = 2
    _BLANK_SIZE = (224, 224)

    def __init__(self, samples, transform=None):
        self.samples, self.transform = samples, transform

    def __len__(self):
        return len(self.samples)

    def _blank(self):
        """Return a fresh black placeholder image."""
        return Image.new('RGB', self._BLANK_SIZE, (0,0,0))

    def _load_view(self, path):
        """Load ``path`` as an RGB image, falling back to a blank placeholder."""
        if not os.path.exists(path):
            warnings.warn(f"image not found, using blank view: {path}")
            return self._blank()
        try:
            # The context manager closes the file handle; convert() returns an
            # independent, fully loaded copy that outlives it.
            with Image.open(path) as src:
                return src.convert('RGB')
        except Exception as err:
            # Deliberately best-effort, but unlike a bare `except:` this lets
            # KeyboardInterrupt / SystemExit propagate.
            warnings.warn(f"failed to load image, using blank view: {path} ({err!r})")
            return self._blank()

    def __getitem__(self, idx):
        """Return ``(view_1, view_2, bmi)`` for the sample at ``idx``."""
        record = self.samples[idx]
        paths = list(record['img_paths'][:self._NUM_VIEWS])
        imgs = [self._load_view(p) for p in paths]

        # Records with fewer than two paths used to raise IndexError below.
        missing = self._NUM_VIEWS - len(imgs)
        if missing > 0:
            warnings.warn(f"sample {idx} has only {len(imgs)} image path(s); padding with blank views")
            imgs.extend(self._blank() for _ in range(missing))

        if self.transform:
            imgs = [self.transform(im) for im in imgs]
        return imgs[0], imgs[1], float(record['bmi'])

# --- model ---
class MultiViewResNet101Baseline(nn.Module):
    """Two-view baseline: one shared ResNet-101 trunk, concatenated features, two heads.

    Both views pass through the same feature extractor (ResNet-101 without
    its final child layer). The two flattened feature vectors are concatenated
    and fed to a regression head (1 output) and a classification head
    (``num_classes`` outputs).

    Args:
        num_classes: Output width of the classification head.
        pretrained_ckpt: Path to a ResNet-101 ``state_dict`` used to initialise
            the trunk. Defaults to the Drive location this notebook has always
            used. Pass ``None`` to skip loading, e.g. when the full model's
            weights are restored right afterwards with ``load_state_dict``.
    """

    DEFAULT_CKPT = '/content/drive/MyDrive/resnet101-63fe2227.pth'

    def __init__(self, num_classes=3, pretrained_ckpt=DEFAULT_CKPT):
        super().__init__()
        # weights=None builds the architecture only; any pretrained weights
        # come from the local checkpoint file below.
        base = models.resnet101(weights=None)
        if pretrained_ckpt is not None:
            # NOTE(security): torch.load unpickles the file; only point this at
            # checkpoints from a trusted source.
            base.load_state_dict(torch.load(pretrained_ckpt, map_location='cpu'))
        # Drop the last child so the trunk outputs pooled features.
        self.fe = nn.Sequential(*list(base.children())[:-1])
        self.fc_reg = nn.Linear(2048*2, 1)
        self.fc_cls = nn.Linear(2048*2, num_classes)

    def _embed(self, x):
        """Run one view through the shared trunk and flatten to (batch, features)."""
        feats = self.fe(x)
        return feats.view(feats.size(0), -1)

    def forward(self, x1, x2):
        """Return ``(regression_output, class_logits)`` for a pair of view batches."""
        cat = torch.cat([self._embed(x1), self._embed(x2)], dim=1)
        return self.fc_reg(cat), self.fc_cls(cat)

# --- data loaders ---
# Read the sample manifest from Drive; every record carries a 'split' field.
with open('/content/drive/MyDrive/sample_list.json','r') as f:
    all_samples = json.load(f)


def _samples_in_split(split_name):
    """Return the records of all_samples whose 'split' equals split_name."""
    return [rec for rec in all_samples if rec['split']==split_name]


train_samples = _samples_in_split('Training')
val_samples   = _samples_in_split('Validation')

# Shared preprocessing for both splits: letterbox to 224x224, convert to a
# tensor, then normalise each channel with the mean/std below.
_NORM_MEAN = [0.485,0.456,0.406]
_NORM_STD  = [0.229,0.224,0.225]
transform = transforms.Compose([
    LetterboxResize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])


def _make_loader(samples, shuffle):
    """Wrap samples in a MultiViewBMIDataset and a single-process DataLoader."""
    return DataLoader(
        MultiViewBMIDataset(samples, transform),
        batch_size=8,
        shuffle=shuffle,
        num_workers=0,
        pin_memory=True,
    )


# Only the training split is shuffled.
train_loader = _make_loader(train_samples, shuffle=True)
val_loader   = _make_loader(val_samples, shuffle=False)

print(f"train: {len(train_loader.dataset)}, val: {len(val_loader.dataset)}")


In [None]:
# Cell 2: train once & save checkpoint

# Seed torch's global RNG so head initialisation and the shuffle order of
# train_loader are repeatable across runs.
SEED = 42
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model  = MultiViewResNet101Baseline(num_classes=3).to(device)
opt    = optim.Adam(model.parameters(), lr=1e-4)
loss_fn= nn.MSELoss()
epochs = 20

for e in range(epochs):
    model.train()
    total = 0.0   # sum of per-sample losses, for a dataset-level mean
    t0 = time.time()
    for x1, x2, y in train_loader:
        x1, x2 = x1.to(device), x2.to(device)
        # The dataset returns Python floats, which the default collate function
        # batches as float64. Cast to float32 so the target dtype matches the
        # model output, then add a trailing dim to get shape (batch, 1).
        y = y.to(device=device, dtype=torch.float32).unsqueeze(1)
        opt.zero_grad()
        pred, _ = model(x1, x2)   # only the regression head is trained here
        loss    = loss_fn(pred, y)
        loss.backward()
        opt.step()
        # loss is a batch mean; weight it by batch size so the last (possibly
        # smaller) batch is not over-counted.
        total += loss.item()*y.size(0)
    print(f"Epoch {e+1}/{epochs} loss={total/len(train_loader.dataset):.4f} time={time.time()-t0:.1f}s")

# save
ckpt_path = '/content/drive/MyDrive/multi2view_resnet101.pth'
torch.save(model.state_dict(), ckpt_path)
print("checkpoint saved to", ckpt_path)


In [None]:
# Cell 3: downstream tasks (reuses the in-memory model, or rebuilds it from the
# saved checkpoint after a kernel restart)

from sklearn.metrics import mean_absolute_error, r2_score

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Declared here as well so this cell does not depend on the training cell
# having been run in the current kernel session.
ckpt_path = '/content/drive/MyDrive/multi2view_resnet101.pth'
if 'model' not in globals():
    model = MultiViewResNet101Baseline(num_classes=3)
model = model.to(device)

# Load checkpoint
model.load_state_dict(torch.load(ckpt_path, map_location=device))
model.eval()

# Downstream Task 1: Evaluation
# Collect regression predictions and targets over the validation split.
all_pred, all_true = [], []
with torch.no_grad():
    for x1, x2, y in val_loader:
        x1, x2 = x1.to(device), x2.to(device)
        p, _   = model(x1, x2)   # the classification logits are not used here
        all_pred += p.cpu().flatten().tolist()
        all_true += y.flatten().tolist()
mae = mean_absolute_error(all_true, all_pred)
r2  = r2_score(all_true, all_pred)
print(f"Eval MAE={mae:.4f}, R2={r2:.4f}")

# Downstream Task 2: TODO - add further tasks here, e.g. determine whether a
# 3-view variant performs better than this 2-view baseline.
