<a href="https://colab.research.google.com/github/taupork/DSA/blob/main/SimSiamSSL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import pandas as pd

# Dataset index (one row per image file) kept on the mounted Google Drive.
CSV_PATH = "/content/drive/MyDrive/DSA_Comp/dataset.csv"
df = pd.read_csv(filepath_or_buffer=CSV_PATH)
df  # last expression of the cell: rendered as the notebook output

Unnamed: 0,Filename,Hb,Ethnicity,IndividualID,Image No,Extension
0,/content/drive/MyDrive/DSA_Comp/Dataset/Random...,7.0,,100,0,jpeg
1,/content/drive/MyDrive/DSA_Comp/Dataset/HgB_10...,10.7,,7,0,heic
2,/content/drive/MyDrive/DSA_Comp/Dataset/HgB_17...,17.3,,5,2,jpeg
3,/content/drive/MyDrive/DSA_Comp/Dataset/HgB_12...,12.0,,2,4,jpg
4,/content/drive/MyDrive/DSA_Comp/Dataset/HgB_10...,10.7,,9,0,heic
5,/content/drive/MyDrive/DSA_Comp/Dataset/Random...,13.7,MiddleEasternOrigin,101,0,jpg
6,/content/drive/MyDrive/DSA_Comp/Dataset/HgB_8....,8.0,,3,1,jpg
7,/content/drive/MyDrive/DSA_Comp/Dataset/HgB_10...,10.7,,2,0,heic
8,/content/drive/MyDrive/DSA_Comp/Dataset/HgB_12...,12.0,,2,1,jpg
9,/content/drive/MyDrive/DSA_Comp/Dataset/HgB_17...,17.3,,5,1,jpeg


In [None]:
# ---------------------------
# modular_pipeline.py
# ---------------------------
import os
import copy
from datetime import datetime
import joblib
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import LeaveOneOut, KFold
from sklearn.metrics import mean_squared_error, mean_absolute_error
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
import xgboost as xgb

# ---------------------------
# ssl_models.py
# ---------------------------
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models

class BaseSSL:
    """Abstract interface for a self-supervised (SSL) image backbone.

    Subclasses assign ``self.backbone`` (an ``nn.Module`` mapping an image
    batch to features) and implement :meth:`pretrain`.
    """

    # Feature width used only for the empty-dataset fallback of
    # extract_embeddings (the ResNet-18 backbones in this file emit 512 values).
    EMBED_DIM = 512

    def __init__(self, device='cpu'):
        self.device = device
        self.backbone = None

    def pretrain(self, dataset, epochs=10, batch_size=4, lr=1e-3, **kwargs):
        """Pretrain ``self.backbone`` on ``dataset``; must be overridden.

        Extra keyword arguments (e.g. ``drop_last``) are model-specific.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement pretrain()")

    def extract_embeddings(self, dataset, batch_size=8):
        """Embed every sample of ``dataset`` with the backbone, in dataset order.

        Args:
            dataset: map-style dataset yielding ``(image_tensor, label)`` pairs.
            batch_size: inference batch size.

        Returns:
            ``(X, y)``. ``X`` is a ``(n_samples, feat_dim)`` NumPy array of
            flattened backbone outputs. ``y`` is a 1-D array of the labels.
            An empty dataset yields ``X`` with shape ``(0, EMBED_DIM)``.

        The backbone runs in eval mode under ``torch.no_grad()``. Its previous
        train/eval mode is restored afterwards, so a later training call is not
        silently affected.
        """
        was_training = self.backbone.training
        self.backbone.eval()
        feats, targets = [], []
        loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
        try:
            with torch.no_grad():
                for imgs, labels in loader:
                    emb = self.backbone(imgs.to(self.device))
                    feats.append(emb.reshape(imgs.size(0), -1).cpu().numpy())
                    targets.extend(torch.as_tensor(labels).cpu().numpy())
        finally:
            self.backbone.train(was_training)
        X = np.vstack(feats) if feats else np.zeros((0, self.EMBED_DIM))
        return X, np.array(targets)


# ---------------------------
# SimSiam
# ---------------------------
class SimSiamSSL(BaseSSL):
    """SimSiam self-supervised learner on a ResNet-18 backbone.

    Every ``BatchNorm2d`` in the ResNet is replaced by a single-group
    ``GroupNorm`` and the final ``fc`` layer is dropped. ``self.backbone``
    therefore maps an image batch to pooled ResNet-18 features.
    """

    def __init__(self, device='cpu', pretrained=True):
        super().__init__(device)
        resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)
        # The new GroupNorm layers start from default affine parameters; the
        # ImageNet BatchNorm weights and running statistics are discarded.
        self.replace_bn_with_gn(resnet)
        self.backbone = nn.Sequential(*list(resnet.children())[:-1]).to(device)
        self.projector = ProjectionHead().to(device)
        self.predictor = Predictor().to(device)

    def replace_bn_with_gn(self, module):
        """Recursively replace every ``nn.BatchNorm2d`` in ``module`` with ``nn.GroupNorm(1, C)``, in place."""
        for name, child in list(module.named_children()):
            if isinstance(child, nn.BatchNorm2d):
                setattr(module, name, nn.GroupNorm(1, child.num_features))
            else:
                self.replace_bn_with_gn(child)

    def _last_stage_parameters(self):
        """Return the parameters of the backbone's last child that has any (ResNet ``layer4``)."""
        stages = [child for child in self.backbone.children()
                  if next(child.parameters(), None) is not None]
        return list(stages[-1].parameters()) if stages else []

    def pretrain(self, dataset, epochs=10, batch_size=4, lr=1e-3, freeze_early=True, drop_last=True):
        """Train backbone, projector and predictor with the symmetric SimSiam loss.

        Args:
            dataset: yields ``((view1, view2), label)``; labels are ignored.
            epochs: passes over ``dataset``.
            batch_size: training batch size.
            lr: Adam learning rate.
            freeze_early: if True, freeze every backbone parameter except
                those of the last ResNet stage (``layer4``).
            drop_last: drop the final incomplete batch.

        Returns:
            ``self.backbone``, trained in place.
        """
        loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)
        if freeze_early:
            # Names inside the nn.Sequential are positional ("7.0.conv1.weight"),
            # so a substring test for "layer4" never matches and would freeze
            # the whole backbone. Select the last stage by module instead.
            trainable = {id(p) for p in self._last_stage_parameters()}
            for p in self.backbone.parameters():
                p.requires_grad = id(p) in trainable
        modules = (self.backbone, self.projector, self.predictor)
        optimizer = optim.Adam(
            [p for m in modules for p in m.parameters() if p.requires_grad],
            lr=lr
        )
        # The backbone may have been switched to eval mode for inference.
        for m in modules:
            m.train()
        for epoch in range(epochs):
            for (v1, v2), _ in loader:
                v1, v2 = v1.to(self.device), v2.to(self.device)
                f1 = self.backbone(v1).view(v1.size(0), -1)
                f2 = self.backbone(v2).view(v2.size(0), -1)
                z1, z2 = self.projector(f1), self.projector(f2)
                p1, p2 = self.predictor(z1), self.predictor(z2)
                # Stop-gradient on the target branch, as SimSiam requires. This
                # is redundant but harmless if negative_cosine_similarity
                # already detaches its second argument.
                loss = (0.5 * negative_cosine_similarity(p1, z2.detach())
                        + 0.5 * negative_cosine_similarity(p2, z1.detach()))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        return self.backbone


# ---------------------------
# SimCLR
# ---------------------------
class SimCLRSSL(BaseSSL):
    """SimCLR self-supervised learner on a ResNet-18 backbone.

    Every ``BatchNorm2d`` in the ResNet is replaced by a single-group
    ``GroupNorm``, the final ``fc`` layer is dropped, and a 2-layer MLP
    projection head maps the 512-d pooled features to ``out_dim``.
    """

    def __init__(self, device='cpu', pretrained=True, out_dim=512):
        super().__init__(device)
        resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)
        self.replace_bn_with_gn(resnet)
        self.backbone = nn.Sequential(*list(resnet.children())[:-1]).to(device)
        self.projection = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, out_dim)
        ).to(device)

    def replace_bn_with_gn(self, module):
        """Recursively replace every ``nn.BatchNorm2d`` in ``module`` with ``nn.GroupNorm(1, C)``, in place."""
        for name, child in list(module.named_children()):
            if isinstance(child, nn.BatchNorm2d):
                setattr(module, name, nn.GroupNorm(1, child.num_features))
            else:
                self.replace_bn_with_gn(child)

    @staticmethod
    def _nt_xent_loss(h1, h2, temperature=0.5):
        """NT-Xent (normalized temperature-scaled cross-entropy) loss.

        Row ``i`` of ``h1`` and row ``i`` of ``h2`` form the positive pair. The
        remaining ``2N - 2`` rows of the concatenated batch act as negatives.
        """
        n = h1.size(0)
        z = nn.functional.normalize(torch.cat([h1, h2], dim=0), dim=1)
        logits = z @ z.t() / temperature
        # A sample must not be scored against itself.
        self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
        logits = logits.masked_fill(self_mask, float('-inf'))
        # Row i's positive is row i + n, and vice versa.
        targets = torch.cat([torch.arange(n, 2 * n, device=z.device),
                             torch.arange(0, n, device=z.device)])
        return nn.functional.cross_entropy(logits, targets)

    def pretrain(self, dataset, epochs=10, batch_size=4, lr=1e-3,
                 freeze_early=False, drop_last=True, temperature=0.5):
        """Train backbone and projection head with the NT-Xent contrastive loss.

        Args:
            dataset: yields ``((view1, view2), label)``; labels are ignored.
            epochs: passes over ``dataset``.
            batch_size: training batch size.
            lr: Adam learning rate.
            freeze_early: if True, freeze every backbone parameter except
                those of the last ResNet stage (``layer4``).
            drop_last: drop the final incomplete batch. Accepted so this
                matches the call made by ``run_pipeline``.
            temperature: NT-Xent softmax temperature.

        Returns:
            ``self.backbone``, trained in place.
        """
        loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)
        if freeze_early:
            # Names inside the nn.Sequential are positional, so select the last
            # parameterised child (layer4) by module rather than by name.
            stages = [child for child in self.backbone.children()
                      if next(child.parameters(), None) is not None]
            keep = {id(p) for p in stages[-1].parameters()} if stages else set()
            for p in self.backbone.parameters():
                p.requires_grad = id(p) in keep
        modules = (self.backbone, self.projection)
        optimizer = optim.Adam(
            [p for m in modules for p in m.parameters() if p.requires_grad],
            lr=lr
        )
        # The backbone may have been switched to eval mode for inference.
        for m in modules:
            m.train()
        for epoch in range(epochs):
            for (v1, v2), _ in loader:
                v1, v2 = v1.to(self.device), v2.to(self.device)
                h1 = self.projection(self.backbone(v1).view(v1.size(0), -1))
                h2 = self.projection(self.backbone(v2).view(v2.size(0), -1))
                loss = self._nt_xent_loss(h1, h2, temperature=temperature)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        return self.backbone

# ---------------------------
# SSL Factory
# ---------------------------
def get_ssl_model(name="simsiam", device='cpu', **kwargs):
    """Build an SSL learner by name (case-insensitive).

    Args:
        name: ``"simsiam"`` or ``"simclr"``.
        device: torch device the model is moved to.
        **kwargs: forwarded to the model constructor (e.g. ``pretrained=False``,
            or ``out_dim`` for SimCLR).

    Raises:
        ValueError: if ``name`` is not a known model.
    """
    key = name.lower()
    if key == "simsiam":
        return SimSiamSSL(device, **kwargs)
    if key == "simclr":
        return SimCLRSSL(device, **kwargs)
    raise ValueError(f"Unknown SSL model: {name}")


# Global seeding for NumPy and PyTorch; RANDOM_SEED is also given to XGBoost.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
# Use the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# ---------------------------
# PCA + Scaler
# ---------------------------
def pca_and_scale(X_train, X_val, n_components=64):
    """Standardize features, then project them onto principal components.

    The scaler and PCA are fit on ``X_train`` only and then applied to both
    splits, so nothing from ``X_val`` leaks into the transform.

    ``n_components`` is capped at ``min(n_samples, n_features)`` of the
    training split. PCA cannot return more components than that, and LOOCV
    training folds are small: 31 images give 30 training rows, below the
    512-d embedding width.

    Returns:
        ``(X_train_p, X_val_p, scaler, pca)``: projected splits and the fitted
        transformers.
    """
    scaler = StandardScaler()
    X_train_s = scaler.fit_transform(X_train)
    X_val_s = scaler.transform(X_val)
    n_samples, n_features = X_train_s.shape
    pca = PCA(n_components=min(n_components, n_samples, n_features))
    X_train_p = pca.fit_transform(X_train_s)
    X_val_p = pca.transform(X_val_s)
    return X_train_p, X_val_p, scaler, pca


# ---------------------------
# LOOCV evaluation with XGBoost
# ---------------------------
def evaluate_loocv_with_xgb(X_all, y_all, df_all, backbone,
                            run_dir, fine_tune=False, ft_epochs=5, ft_lr=1e-5,
                            pca_n=32, optuna_trials=10):
    """Leave-one-out CV of an XGBoost regressor on SSL embeddings.

    For each held-out sample the function does the following:
    - optionally fine-tune the backbone on the training fold and re-extract
      embeddings;
    - standardize and apply PCA (both fit on the training fold only);
    - fit ``XGBRegressor`` and predict the held-out target.

    Each fold's model is saved to ``run_dir/fold_<k>/xgb_model.json``
    (best-effort; a failed save only emits a warning).

    Args:
        X_all, y_all: embeddings ``(n, d)`` and targets ``(n,)``.
        df_all: rows aligned with ``X_all``; used only when ``fine_tune``.
        backbone: network passed to the fine-tuning helper.
        run_dir: directory for per-fold artifacts.
        fine_tune, ft_epochs: optional per-fold fine-tuning settings.
        ft_lr, optuna_trials: accepted but currently unused.
        pca_n: requested number of PCA components.

    Returns:
        ``(summary, true_all, preds_all)``. ``summary`` has the keys
        ``MAE_mean``, ``MAE_std``, ``RMSE_mean``, ``RMSE_std`` and ``pearson_r``.
        ``RMSE_mean`` is the pooled RMSE over all LOO predictions.
    """
    import warnings

    loo = LeaveOneOut()
    maes, rmses = [], []
    preds_all, true_all = [], []
    for fold, (train_idx, val_idx) in enumerate(loo.split(X_all), start=1):
        X_tr_raw, X_val_raw = X_all[train_idx], X_all[val_idx]
        y_tr, y_val = y_all[train_idx], y_all[val_idx]

        # Optional fine-tuning
        if fine_tune:
            df_train_fold = df_all.iloc[train_idx].reset_index(drop=True)
            df_val_fold = df_all.iloc[val_idx].reset_index(drop=True)
            # NOTE(review): ft_lr is not forwarded here. Confirm the
            # learning-rate keyword of fine_tune_backbone_for_regression
            # before passing it.
            backbone_ft = fine_tune_backbone_for_regression(backbone, df_train_fold, df_val_fold,
                                                           epochs=ft_epochs)
            X_tr_raw, _ = extract_embeddings(backbone_ft, df_train_fold, val_transform)
            X_val_raw, _ = extract_embeddings(backbone_ft, df_val_fold, val_transform)

        X_tr, X_val, _scaler, _pca = pca_and_scale(X_tr_raw, X_val_raw, n_components=pca_n)

        model = xgb.XGBRegressor(random_state=RANDOM_SEED, n_jobs=1, tree_method="hist")
        model.fit(X_tr, y_tr)
        y_pred = model.predict(X_val)

        maes.append(mean_absolute_error(y_val, y_pred))
        rmses.append(np.sqrt(mean_squared_error(y_val, y_pred)))
        preds_all.append(float(y_pred[0]))
        true_all.append(float(y_val[0]))

        # Save fold artifacts
        fold_dir = os.path.join(run_dir, f"fold_{fold}")
        os.makedirs(fold_dir, exist_ok=True)
        try:
            model.save_model(os.path.join(fold_dir, "xgb_model.json"))
        except Exception as err:
            # Best-effort artifact: keep evaluating, but don't hide the failure.
            warnings.warn(f"fold {fold}: could not save XGBoost model: {err}")

    maes = np.asarray(maes)
    rmses = np.asarray(rmses)
    preds_all = np.asarray(preds_all)
    true_all = np.asarray(true_all)
    corr = pearsonr(true_all, preds_all)[0] if len(true_all) > 1 else float("nan")

    # With one held-out sample per fold, per-fold RMSE equals the absolute
    # error, so averaging it merely reproduces MAE. Report the pooled RMSE
    # over all LOO predictions instead.
    pooled_rmse = float(np.sqrt(np.mean((true_all - preds_all) ** 2))) if len(true_all) else float("nan")

    summary = {
        "MAE_mean": float(maes.mean()),
        "MAE_std": float(maes.std(ddof=1) if len(maes) > 1 else 0.0),
        "RMSE_mean": pooled_rmse,
        "RMSE_std": float(rmses.std(ddof=1) if len(rmses) > 1 else 0.0),
        "pearson_r": float(corr)
    }
    return summary, true_all, preds_all


# ---------------------------
# Full pipeline
# ---------------------------
def run_pipeline(df, ssl_model_name="simsiam",
                 path_col="Filename",
                 target_col="Hb",
                 run_base_dir="models",
                 simsiam_epochs=10,
                 simsiam_batch=4,
                 freeze_early=True,
                 pca_n=30,
                 fine_tune=False,
                 ft_epochs=5,
                 ft_lr=1e-5,
                 loocv_optuna_trials=10,
                 drop_last=True):
    """Run SSL pretraining, embedding extraction and LOOCV XGBoost end to end.

    Writes ``df.csv``, ``X_all_raw.npy``, ``y_all.npy``,
    ``loocv_summary.json`` and ``loocv_scatter.png`` into a fresh run
    directory obtained from ``make_run_dir``. Per-fold artifacts are written
    by ``evaluate_loocv_with_xgb``.

    NOTE(review): ``path_col`` and ``target_col`` are accepted but not used in
    this body (HbImageDataset is called without them); confirm that is intended.
    NOTE(review): pretraining builds both views with ``val_transform``. If that
    transform is deterministic, the two SSL views are identical; confirm.

    Returns:
        dict with keys ``"run_dir"``, ``"loocv_summary"``, ``"true"`` and
        ``"preds"``.
    """
    # One consistently indexed frame for the saved CSV and every dataset built below.
    df = df.reset_index(drop=True)
    run_dir = make_run_dir(base=run_base_dir)
    print("Run directory:", run_dir)
    df.to_csv(os.path.join(run_dir, "df.csv"), index=False)

    # ---------------------------
    # Pretrain SSL
    # ---------------------------
    dataset = HbImageDataset(df, transform=val_transform, n_views=2)
    ssl_model: BaseSSL = get_ssl_model(ssl_model_name, device=device)
    print(f">>> Pretraining SSL model: {ssl_model_name}")
    # freeze_early used to be accepted here but never forwarded.
    ssl_model.pretrain(dataset, epochs=simsiam_epochs, batch_size=simsiam_batch,
                       freeze_early=freeze_early, drop_last=drop_last)

    # ---------------------------
    # Extract embeddings
    # ---------------------------
    print(">>> Extract embeddings for full dataset")
    X_all, y_all = ssl_model.extract_embeddings(HbImageDataset(df, transform=val_transform, n_views=1))
    print("Embeddings shape:", X_all.shape)
    np.save(os.path.join(run_dir, "X_all_raw.npy"), X_all)
    np.save(os.path.join(run_dir, "y_all.npy"), y_all)

    # ---------------------------
    # LOOCV evaluation
    # ---------------------------
    print(">>> LOOCV downstream evaluation")
    summary, true_all, preds_all = evaluate_loocv_with_xgb(X_all, y_all, df, ssl_model.backbone,
                                                          run_dir, fine_tune=fine_tune,
                                                          ft_epochs=ft_epochs, ft_lr=ft_lr,
                                                          pca_n=pca_n,
                                                          optuna_trials=loocv_optuna_trials)

    # Summary & scatter plot
    print("=== LOOCV Summary ===")
    print(f"MAE: {summary['MAE_mean']:.4f} ± {summary['MAE_std']:.4f}")
    print(f"RMSE: {summary['RMSE_mean']:.4f} ± {summary['RMSE_std']:.4f}")
    print(f"Pearson r: {summary['pearson_r']:.4f}")
    pd.Series(summary).to_json(os.path.join(run_dir, "loocv_summary.json"))

    plt.figure(figsize=(6, 6))
    try:
        plt.scatter(true_all, preds_all, alpha=0.7)
        mn = min(true_all.min(), preds_all.min())
        mx = max(true_all.max(), preds_all.max())
        plt.plot([mn, mx], [mn, mx], 'r--')  # identity line: perfect prediction
        plt.xlabel("True Hb")
        plt.ylabel("Predicted Hb")
        plt.title(f"LOOCV preds (MAE {summary['MAE_mean']:.3f} ± {summary['MAE_std']:.3f})")
        plt.tight_layout()
        plt.savefig(os.path.join(run_dir, "loocv_scatter.png"), dpi=300)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close()

    return {
        "run_dir": run_dir,
        "loocv_summary": summary,
        "true": true_all,
        "preds": preds_all
    }


# ResNet-34 performed worse than ResNet-18 here and was much slower, so ResNet-18 is used.


In [None]:
run_pipeline(df, ssl_model_name="simsiam")  # default settings; the returned dict is the cell output

Run directory: models/run_2025-10-06_15-42-11
>>> Pretraining SSL model: simsiam
>>> Extract embeddings for full dataset
Embeddings shape: (31, 512)
>>> LOOCV downstream evaluation
=== LOOCV Summary ===
MAE: 1.8427 ± 2.2773
RMSE: 1.8427 ± 2.2773
Pearson r: 0.5055


{'run_dir': 'models/run_2025-10-06_15-42-11',
 'loocv_summary': {'MAE_mean': 1.842696374462497,
  'MAE_std': 2.2773128005820498,
  'RMSE_mean': 1.8426963602229682,
  'RMSE_std': 2.277312772222579,
  'pearson_r': 0.505483082578548},
 'true': array([ 7.        , 10.69999981, 17.29999924, 12.        , 10.69999981,
        13.69999981,  8.        , 10.69999981, 12.        , 17.29999924,
        17.29999924,  8.89999962, 10.69999981, 11.60000038, 10.69999981,
        17.29999924,  8.89999962, 10.69999981,  7.80000019, 10.69999981,
        12.        , 10.69999981, 10.69999981, 10.69999981, 10.69999981,
        10.69999981, 17.29999924, 11.89999962,  4.0999999 , 16.        ,
        12.        ]),
 'preds': array([10.99869633, 10.686306  , 12.37405682, 11.69604588, 10.87406349,
        11.5959549 , 10.85046959,  8.86496544, 13.014328  , 17.22431755,
        17.30682564, 13.49284935, 10.70213318, 12.51898098, 10.7281332 ,
        13.79051971, 10.79702473, 10.40252781, 17.27773094, 10.70181847