<a href="https://colab.research.google.com/github/taupork/DSA/blob/main/FINAL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import pandas as pd

# Location of the labelled dataset CSV (the path looks like a Colab Google Drive mount).
CSV_PATH = "/content/drive/MyDrive/DSA_Comp/dataset.csv"
# Load the CSV once; later cells pass this DataFrame around as `df`.
df = pd.read_csv(CSV_PATH)
# Bare expression so the notebook renders the DataFrame as the cell output.
df

### Training SSL Model


In [None]:
import os
import random
import time
from datetime import datetime
import joblib

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import xgboost as xgb
import optuna

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import pillow_heif
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import mean_squared_error, mean_absolute_error
import matplotlib.pyplot as plt

# Register the HEIF opener so PIL's Image.open can read HEIC/HEIF images.
pillow_heif.register_heif_opener()

# ---------------------------
# Configuration / Reproducibility
# ---------------------------
# Single seed shared by NumPy, PyTorch, Python's `random`, and (further down)
# KFold / Optuna / XGBoost.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
# Also seed every CUDA device when a GPU is present.
# NOTE(review): no cuDNN determinism flags are set here, so GPU runs may still
# differ slightly between executions - confirm whether that matters.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)

# Global device used by every model/tensor `.to(device)` call in this notebook.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ---------------------------
# Utilities
# ---------------------------
def make_run_dir(base="models", prefix="run"):
    """Create and return a fresh, timestamped output directory.

    The directory is named ``<prefix>_<YYYY-MM-DD_HH-MM-SS>`` (local time) and
    is created under ``base``; missing parent directories are created too.

    Args:
        base: Parent directory that collects all runs.
        prefix: Leading part of the run directory name.

    Returns:
        The path of the (now existing) run directory.
    """
    folder_name = f"{prefix}_{datetime.now():%Y-%m-%d_%H-%M-%S}"
    target = os.path.join(base, folder_name)
    # exist_ok: two calls within the same second simply share the directory.
    os.makedirs(target, exist_ok=True)
    return target

def safe_save_json(df, path):
    """Write ``df`` as JSON Lines, falling back to CSV if that fails.

    The primary format is one JSON record per line
    (``orient="records", lines=True``). If JSON serialisation raises, the data
    is written as CSV (without the index) next to the intended file, and a
    warning is printed so the fallback is not silent.

    Args:
        df: Object exposing pandas-style ``to_json`` / ``to_csv`` methods.
        path: Destination path for the JSON file.
    """
    try:
        df.to_json(path, orient="records", lines=True)
    except Exception as exc:  # best-effort export: any failure triggers the CSV fallback
        # Swap only the trailing extension. A plain str.replace would also
        # rewrite ".json" inside directory names (e.g. "runs.json.d/out.json").
        root, ext = os.path.splitext(path)
        csv_path = root + ".csv" if ext.lower() == ".json" else path
        print(f"[WARN] JSON export to {path} failed ({exc!r}); writing CSV to {csv_path} instead.")
        df.to_csv(csv_path, index=False)

# ---------------------------
# Dataset Classes
# ---------------------------
from torch.utils.data import Dataset
from PIL import Image
import torch
import os

class HbImageDataset(Dataset):
    """Labelled image dataset backed by a DataFrame.

    Each row supplies an image path (``path_col``) and a regression target
    (``target_col``). Indexing yields ``(views, target)`` where ``views`` stacks
    ``n_views`` independently transformed copies of the image along a new
    leading dimension and ``target`` is a float32 scalar tensor.

    ``transform`` must be a callable returning a tensor; the default of
    ``None`` only works as long as the dataset is never indexed.
    """

    def __init__(self, df, transform=None, path_col="Filename", target_col="Hb", n_views=2):
        # Positional row access (iloc) is used below, so normalise the index.
        self.df = df.reset_index(drop=True)
        self.transform = transform
        self.path_col, self.target_col = path_col, target_col
        self.n_views = n_views

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        image = Image.open(record[self.path_col]).convert("RGB")
        if self.n_views == 1:
            # Single view: add the leading "view" dimension by hand.
            stacked = self.transform(image).unsqueeze(0)
        else:
            # Each call to the transform may draw new random augmentations.
            stacked = torch.stack([self.transform(image) for _ in range(self.n_views)])
        label = torch.tensor(record[self.target_col], dtype=torch.float32)
        return stacked, label

class UnlabelledImageDataset(Dataset):
    """Unlabelled image dataset for SSL pretraining, including subfolders.

    Recursively collects every file under ``root_dir`` whose name ends with one
    of the supported image extensions (case-insensitive). Indexing yields a
    tensor that stacks ``n_views`` independently transformed copies of the
    image along a new leading dimension.

    ``transform`` must be a callable returning a tensor; the default of
    ``None`` only works as long as the dataset is never indexed.
    """

    # File suffixes (lower-case) that are treated as images.
    _IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".heic", ".heif")

    def __init__(self, root_dir, transform=None, n_views=2):
        self.image_paths = [
            os.path.join(dirpath, fname)
            for dirpath, _, filenames in os.walk(root_dir)
            for fname in filenames
            if fname.lower().endswith(self._IMAGE_EXTENSIONS)
        ]
        # os.walk returns entries in arbitrary filesystem order. Sorting makes
        # the index -> image mapping (and therefore seeded shuffling)
        # reproducible across machines and runs.
        self.image_paths.sort()
        self.transform = transform
        self.n_views = n_views

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert("RGB")
        if self.n_views == 1:
            # Single view: add the leading "view" dimension by hand.
            return self.transform(img).unsqueeze(0)
        # Each call to the transform may draw new random augmentations.
        views = [self.transform(img) for _ in range(self.n_views)]
        return torch.stack(views)

class CombinedImageDataset(Dataset):
    """
    Combines a labeled DataFrame dataset and an unlabelled image directory for SSL.
    Supports multi-view transformations for contrastive learning.

    Indices ``0 .. labelled_len - 1`` map to the labelled images, the remaining
    indices map to the unlabelled images. Every item is
    ``(views, target, is_labeled)``; unlabelled items carry a dummy target of
    ``-1.0`` and ``is_labeled=False``.

    Logs the number of labeled and unlabeled images loaded.

    Raises:
        IndexError: from ``__getitem__`` when the index is past the end of the
            combined dataset (or the dataset is empty).
    """
    def __init__(self, labelled_df=None, unlabelled_dir=None, transform=None, path_col="Filename", target_col="Hb", n_views=2):
        self.labelled_dataset = None
        self.unlabelled_dataset = None

        if labelled_df is not None:
            self.labelled_dataset = HbImageDataset(
                labelled_df, transform=transform, path_col=path_col, target_col=target_col, n_views=n_views
            )
            print(f"[INFO] Loaded {len(self.labelled_dataset)} labeled images for SSL pretraining.")

        if unlabelled_dir is not None:
            self.unlabelled_dataset = UnlabelledImageDataset(
                unlabelled_dir, transform=transform, n_views=n_views
            )
            print(f"[INFO] Loaded {len(self.unlabelled_dataset)} unlabeled images for SSL pretraining.")

        # Compute total length
        self.labelled_len = len(self.labelled_dataset) if self.labelled_dataset is not None else 0
        self.unlabelled_len = len(self.unlabelled_dataset) if self.unlabelled_dataset is not None else 0
        self.total_len = self.labelled_len + self.unlabelled_len
        print(f"[INFO] Total images in combined dataset: {self.total_len}")

    def __len__(self):
        return self.total_len

    def __getitem__(self, idx):
        # Fail with IndexError (not a TypeError from indexing ``None``) so that
        # plain iteration and other sequence-protocol users terminate cleanly.
        if self.total_len == 0 or idx >= self.total_len:
            raise IndexError(
                f"index {idx} is out of range for CombinedImageDataset of size {self.total_len}"
            )

        if self.labelled_len and idx < self.labelled_len:
            views, target = self.labelled_dataset[idx]
            return views, target, True  # True indicates labeled

        # Past the labelled block: shift into the unlabelled dataset's index space.
        unlabelled_idx = idx - self.labelled_len
        views = self.unlabelled_dataset[unlabelled_idx]
        dummy_target = torch.tensor(-1.0)  # Dummy target for unlabeled
        return views, dummy_target, False  # False indicates unlabelled



# ---------------------------
# Transforms
# ---------------------------
# Augmentation pipeline applied independently per view for contrastive (NT-Xent)
# training; it is also the default `transform` of pretrain_ssl.
ssl_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    # A stronger jitter, ColorJitter(0.4, 0.4, 0.4, 0.1), is left disabled because it
    # may be too strong for Hb predictions; only mild brightness/contrast jitter is used.
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    # Random crop covering 80-100% of the (already resized) image, scaled back to 224x224.
    transforms.RandomResizedCrop(224, scale=(0.8,1.0)),
    transforms.ToTensor(),
    # Same channel mean/std as val_transform; these match the values conventionally
    # used with ImageNet-pretrained torchvision models.
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

# Deterministic pipeline (no random augmentation) used when extracting embeddings.
val_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

# ---------------------------
# ResNet Backbone + Projection Head
# ---------------------------
class ProjectionHead(nn.Module):
    """Two-layer MLP mapping backbone features to the contrastive space.

    Architecture: ``Linear(input_dim, hidden_dim) -> ReLU -> Linear(hidden_dim,
    output_dim)``. The layers live in ``self.net`` so that state-dict keys stay
    ``net.0.*`` / ``net.2.*``, matching checkpoints saved elsewhere.
    """

    def __init__(self, input_dim=512, hidden_dim=256, output_dim=128):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.net(x)
        return projected

def get_ssl_backbone(pretrained=True):
    """Build a ResNet-18 feature extractor on the global ``device``.

    The network is torchvision's ResNet-18 with its last child module removed
    (the fully connected classifier in torchvision's implementation), wrapped
    in an ``nn.Sequential``.

    Args:
        pretrained: When true, start from the ``IMAGENET1K_V1`` weights;
            otherwise use random initialisation.

    Returns:
        The truncated network, already moved to ``device``.
    """
    if pretrained:
        initial_weights = models.ResNet18_Weights.IMAGENET1K_V1
    else:
        initial_weights = None
    full_model = models.resnet18(weights=initial_weights)
    trunk_layers = list(full_model.children())[:-1]
    return nn.Sequential(*trunk_layers).to(device)

# ---------------------------
# NT-Xent Loss
# ---------------------------
def nt_xent_loss(z_i, z_j, temperature=0.5):
    """Normalised temperature-scaled cross-entropy (NT-Xent) loss.

    ``z_i[k]`` and ``z_j[k]`` are projections of two views of the same sample
    and form the positive pair; every other row in the concatenated batch acts
    as a negative.

    Args:
        z_i: Projections of the first view, shape ``(batch, dim)``.
        z_j: Projections of the second view, same shape as ``z_i``.
        temperature: Divisor applied to the cosine similarities.

    Returns:
        Scalar tensor: cross-entropy averaged over all ``2 * batch`` rows.
    """
    n = z_i.size(0)
    # L2-normalise so the dot products below are cosine similarities.
    embeddings = torch.cat(
        [nn.functional.normalize(z_i, dim=1), nn.functional.normalize(z_j, dim=1)],
        dim=0,
    )
    logits = (embeddings @ embeddings.T) / temperature

    # A sample must never be its own candidate: push the diagonal to ~-inf.
    self_pairs = torch.eye(2 * n, device=z_i.device).bool()
    logits = logits.masked_fill(self_pairs, -1e9)

    # Row k (first view) is matched with row k + n (second view) and vice
    # versa, i.e. targets are [n, ..., 2n-1, 0, ..., n-1].
    positives = torch.arange(2 * n, device=z_i.device).roll(n)
    return nn.functional.cross_entropy(logits, positives)

# ---------------------------
# SSL Pretraining (supports unlabeled or labeled)
# ---------------------------
def pretrain_ssl(labelled_df=None,
                 unlabelled_dir=None,
                 transform=ssl_transform,
                 epochs=20,
                 batch_size=8,
                 lr=1e-3,
                 num_workers=2,
                 run_dir=None):
    """Contrastively pretrain a ResNet-18 backbone plus projection head.

    Labelled and/or unlabelled images are merged into one
    ``CombinedImageDataset`` with two augmented views per image, and the
    backbone (initialised from ImageNet weights) and a fresh ``ProjectionHead``
    are trained jointly with the NT-Xent loss. Targets are ignored here.

    Args:
        labelled_df: Optional DataFrame of labelled images (paths are read
            from the dataset's default path column).
        unlabelled_dir: Optional directory that is searched recursively for
            unlabelled images.
        transform: Augmentation pipeline used to create each view.
        epochs: Number of passes over the combined dataset.
        batch_size: Images per batch (each contributes two views).
        lr: Adam learning rate.
        num_workers: DataLoader worker processes.
        run_dir: If given, the backbone / projection-head state dicts and the
            loss history CSV are written there (the directory is created if
            it does not exist yet).

    Returns:
        ``(backbone, proj_head, ssl_losses)`` where ``ssl_losses`` holds the
        mean loss of each epoch (``nan`` for an epoch without batches).
    """
    dataset = CombinedImageDataset(labelled_df=labelled_df,
                                   unlabelled_dir=unlabelled_dir,
                                   transform=transform,
                                   n_views=2)

    # Pinned host memory only helps host->GPU copies; skip it on CPU-only runs.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        num_workers=num_workers, pin_memory=(device.type == "cuda"))

    backbone = get_ssl_backbone(pretrained=True)
    proj_head = ProjectionHead().to(device)
    optimizer = torch.optim.Adam(list(backbone.parameters()) + list(proj_head.parameters()), lr=lr)

    backbone.train()
    proj_head.train()
    ssl_losses = []

    for epoch in range(epochs):
        total_loss = 0.0
        # Targets and the is-labelled flag are not needed for the contrastive loss.
        for views, _, _ in loader:
            # views has one entry per augmented view along dim 1.
            v1, v2 = views[:,0].to(device), views[:,1].to(device)

            # Flatten the backbone output to (batch, features) for the MLP head.
            feats1 = backbone(v1).view(v1.size(0), -1)
            feats2 = backbone(v2).view(v2.size(0), -1)
            z1 = proj_head(feats1)
            z2 = proj_head(feats2)
            loss = nt_xent_loss(z1, z2)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(loader) if len(loader) > 0 else float("nan")
        ssl_losses.append(avg_loss)
        print(f"Epoch {epoch+1}/{epochs} - SSL Loss: {avg_loss:.4f}")

    if run_dir:
        # Saving would fail on a missing directory after all epochs have
        # already been paid for, so make sure it exists first.
        os.makedirs(run_dir, exist_ok=True)
        torch.save(backbone.state_dict(), os.path.join(run_dir, "ssl_backbone_state_dict.pth"))
        torch.save(proj_head.state_dict(), os.path.join(run_dir, "ssl_projection_head_state_dict.pth"))
        pd.DataFrame({"ssl_loss": ssl_losses}).to_csv(os.path.join(run_dir, "ssl_loss_history.csv"), index=False)

    return backbone, proj_head, ssl_losses


Using device: cuda


#### Train SSL Step

In [None]:
# Create a timestamped output directory "models/ssl_combined_<timestamp>".
run_dir = make_run_dir(base="models", prefix="ssl_combined")

# Contrastive pretraining on the labelled images plus the unlabelled folder;
# checkpoints and the loss history are written into run_dir.
backbone, proj_head, ssl_losses = pretrain_ssl(
    labelled_df=df,       # your labelled DataFrame
    unlabelled_dir="/content/drive/MyDrive/DSA_Comp/Lip Images",  # path to unlabelled images
    transform=ssl_transform,
    epochs=20,
    batch_size=16,
    lr=1e-3,
    num_workers=4,
    run_dir=run_dir
)

### Training + Validating Regressor Step

In [None]:
# ---------------------------
# Feature Extraction
# ---------------------------
def extract_embeddings(df, backbone, transform, path_col="Filename", target_col="Hb", batch_size=8, num_workers=2):
    """Run every image in ``df`` through ``backbone`` and collect the features.

    Images are loaded in DataFrame order (no shuffling) with a single view
    each. The backbone is switched to eval mode and left in that mode.

    Args:
        df: DataFrame with image paths and targets.
        backbone: Feature extractor; its output is flattened per image.
        transform: Preprocessing applied to each image.
        path_col: Column holding the image paths.
        target_col: Column holding the regression target.
        batch_size: Images per forward pass.
        num_workers: DataLoader worker processes.

    Returns:
        ``(features, targets)`` as NumPy arrays; for an empty ``df`` this is a
        ``(0, 512)`` zero array and an empty array.
    """
    single_view_ds = HbImageDataset(df, transform, n_views=1, path_col=path_col, target_col=target_col)
    batches = DataLoader(
        single_view_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )
    backbone.eval()

    embedding_chunks = []
    hb_values = []
    with torch.no_grad():
        for view_batch, hb_batch in batches:
            # Only one view per image here; drop the view dimension.
            images = view_batch[:, 0].to(device)
            flat = backbone(images).view(images.size(0), -1)
            embedding_chunks.append(flat.cpu().numpy())
            hb_values.extend(hb_batch.cpu().numpy())

    if not embedding_chunks:
        return np.zeros((0,512)), np.array([])
    return np.vstack(embedding_chunks), np.array(hb_values)


# ---------------------------
# Combine Metadata
# ---------------------------
def combine_metadata(features, df, cols_to_include=None):
    """Append numeric metadata columns from ``df`` to a feature matrix.

    Args:
        features: 2-D array with one row per sample.
        df: DataFrame whose rows correspond to the rows of ``features``.
        cols_to_include: Columns to append. Defaults to whichever of
            ``IndividualID`` and ``ImageNo`` exist in ``df``.

    Returns:
        ``features`` with the metadata columns (cast to float32) stacked on the
        right. ``features`` itself is returned unchanged when there are no
        columns to add, the row counts differ, or the metadata cannot be cast
        to float32 (the latter two cases print a message).
    """
    frame = df.reset_index(drop=True)
    if cols_to_include is None:
        cols_to_include = [name for name in ("IndividualID", "ImageNo") if name in frame.columns]
    if not cols_to_include:
        return features

    extra = frame[cols_to_include].values
    n_meta, n_feat = len(extra), len(features)
    if n_meta != n_feat:
        print(f"Warning: metadata rows ({n_meta}) != feature rows ({n_feat}) -> ignoring metadata")
        return features

    try:
        extra = extra.astype(np.float32)
    except Exception:
        print("Metadata not numeric; skipping metadata concatenation.")
        return features

    return np.hstack([features, extra])

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from torch.utils.data import DataLoader

def run_ssl_pipeline_kfold_with_r2(labeled_df,
                                   unlabelled_dir=None,
                                   ssl_epochs=20,
                                   ssl_batch=8,
                                   fine_tune_backbone=True,
                                   fine_tune_epochs=10,
                                   use_metadata=False,
                                   optuna_trials=20,
                                   run_base_dir="models",
                                   load_run_dir=None,
                                   n_splits=5):
    """Evaluate SSL embeddings + XGBoost with K-fold CV (RMSE, MAE, R²).

    Steps:
      1. Load the SSL backbone / projection head from ``load_run_dir`` if a
         backbone checkpoint exists there, otherwise pretrain them.
      2. For every fold: optionally fine-tune the backbone contrastively on
         the training split, extract embeddings, tune an XGBoost regressor
         with Optuna (inner K-fold on the training split, minimising MAE),
         refit on the whole training split and score on the test split.
      3. Save the metrics CSV, a metrics plot and the final backbone /
         projection-head weights into the run directory.

    Args:
        labeled_df: DataFrame with image paths and Hb targets.
        unlabelled_dir: If set (and no checkpoint is loaded), SSL pretraining
            uses the images in this directory; otherwise it uses ``labeled_df``.
        ssl_epochs: Epochs for SSL pretraining.
        ssl_batch: Batch size for pretraining and per-fold fine-tuning.
        fine_tune_backbone: Whether to fine-tune the backbone in each fold.
        fine_tune_epochs: Contrastive fine-tuning epochs per fold.
        use_metadata: If true, append the default metadata columns (see
            ``combine_metadata``) to the embeddings before regression.
        optuna_trials: Number of Optuna trials per fold.
        run_base_dir: Parent directory for a new run directory.
        load_run_dir: Existing run directory to load checkpoints from and to
            write results into.
        n_splits: Number of outer CV folds.

    Returns:
        Dict with ``backbone``, ``proj_head``, ``metrics_df`` (one row per
        fold plus a ``"Mean"`` row) and ``run_dir``.
    """
    run_dir = load_run_dir if load_run_dir else make_run_dir(run_base_dir)
    os.makedirs(run_dir, exist_ok=True)
    print(f"Run directory: {run_dir}")

    backbone_ckpt = os.path.join(run_dir, "ssl_backbone_state_dict.pth")
    head_ckpt = os.path.join(run_dir, "ssl_projection_head_state_dict.pth")

    # -------------------
    # Load or pretrain SSL backbone
    # -------------------
    backbone = get_ssl_backbone(pretrained=False)
    proj_head = ProjectionHead().to(device)

    if load_run_dir and os.path.exists(backbone_ckpt):
        backbone.load_state_dict(torch.load(backbone_ckpt, map_location=device))
        proj_head.load_state_dict(torch.load(head_ckpt, map_location=device))
        backbone.to(device).eval()
        proj_head.to(device).eval()
        print("Loaded pretrained backbone and projection head from saved run.")
    else:
        print("=== SSL Pretraining ===")
        # pretrain_ssl takes (labelled_df, unlabelled_dir, transform, ...) and
        # has no `unlabelled` flag, so pass everything by keyword. As before,
        # an unlabelled directory takes precedence over the labelled frame.
        backbone, proj_head, _ = pretrain_ssl(
            labelled_df=None if unlabelled_dir else labeled_df,
            unlabelled_dir=unlabelled_dir if unlabelled_dir else None,
            transform=ssl_transform,
            epochs=ssl_epochs,
            batch_size=ssl_batch,
            run_dir=run_dir,
        )

    # -------------------
    # K-Fold CV
    # -------------------
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=RANDOM_SEED)
    fold_metrics = []

    for fold_idx, (train_idx, test_idx) in enumerate(kf.split(labeled_df)):
        print(f"\n--- K-Fold {fold_idx+1}/{n_splits} ---")
        train_df, test_df = labeled_df.iloc[train_idx], labeled_df.iloc[test_idx]

        # Fine-tune backbone if required
        # NOTE(review): the same backbone object is fine-tuned fold after fold,
        # so from fold 2 onwards it has already seen (unlabelled views of) the
        # current test images. Consider restoring the pretrained weights at the
        # start of each fold if strictly independent folds are required.
        if fine_tune_backbone:
            dataset = HbImageDataset(train_df, transform=ssl_transform, n_views=2)
            loader = DataLoader(dataset, batch_size=ssl_batch, shuffle=True, num_workers=2, pin_memory=True)
            backbone.train()
            proj_head.train()
            optimizer = torch.optim.Adam(list(backbone.parameters()) + list(proj_head.parameters()), lr=1e-4)
            for epoch in range(fine_tune_epochs):
                total_loss = 0.0
                for views, _ in loader:
                    v1, v2 = views[:,0].to(device), views[:,1].to(device)
                    feats1, feats2 = backbone(v1).view(v1.size(0), -1), backbone(v2).view(v2.size(0), -1)
                    z1, z2 = proj_head(feats1), proj_head(feats2)
                    loss = nt_xent_loss(z1, z2)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()
                print(f"Fold {fold_idx+1} Epoch {epoch+1} Loss: {total_loss/len(loader):.4f}")

        # Extract embeddings (extract_embeddings switches the backbone to eval mode)
        X_train, y_train = extract_embeddings(train_df, backbone, val_transform)
        X_test, y_test = extract_embeddings(test_df, backbone, val_transform)

        # Honour the use_metadata flag (previously accepted but ignored).
        if use_metadata:
            X_train = combine_metadata(X_train, train_df)
            X_test = combine_metadata(X_test, test_df)

        # Optuna hyperparameter tuning
        def objective(trial):
            params = {
                "n_estimators": trial.suggest_int("n_estimators",50,300),
                "max_depth": trial.suggest_int("max_depth",2,10),
                "learning_rate": trial.suggest_float("learning_rate",0.01,0.3,log=True),
                "subsample": trial.suggest_float("subsample",0.5,1.0),
                "colsample_bytree": trial.suggest_float("colsample_bytree",0.5,1.0),
                "random_state": RANDOM_SEED,
                "verbosity": 0,
                "n_jobs": 1,
                "tree_method": "hist"
            }
            # Inner CV on the training split only; the test split stays untouched.
            kf_inner = KFold(n_splits=min(5, len(train_df)), shuffle=True, random_state=RANDOM_SEED)
            maes = []
            for tr_idx, val_idx in kf_inner.split(X_train):
                model = xgb.XGBRegressor(**params)
                model.fit(X_train[tr_idx], y_train[tr_idx])
                y_pred = model.predict(X_train[val_idx])
                maes.append(mean_absolute_error(y_train[val_idx], y_pred))
            return np.mean(maes)

        study = optuna.create_study(direction="minimize", sampler=optuna.samplers.TPESampler(seed=RANDOM_SEED))
        study.optimize(objective, n_trials=optuna_trials, show_progress_bar=False)
        # best_params only contains the values suggested via `trial`, so the
        # fixed settings are passed again explicitly below.
        best_params = study.best_params
        print(f"Best params fold {fold_idx+1}: {best_params}")

        # Train final model
        final_model = xgb.XGBRegressor(**best_params, random_state=RANDOM_SEED, n_jobs=-1, tree_method="hist")
        final_model.fit(X_train, y_train)
        y_pred = final_model.predict(X_test)

        # Compute metrics including R²
        rmse = np.sqrt(mean_squared_error(y_test, y_pred))
        mae = mean_absolute_error(y_test, y_pred)
        r2 = r2_score(y_test, y_pred)
        print(f"[Fold {fold_idx+1}] RMSE: {rmse:.4f}, MAE: {mae:.4f}, R²: {r2:.4f}")

        fold_metrics.append({"fold": fold_idx+1, "RMSE": rmse, "MAE": mae, "R2": r2})

    # -------------------
    # Summary metrics
    # -------------------
    metrics_df = pd.DataFrame(fold_metrics)
    # Extra row labelled "Mean"; its `fold` entry is just the mean fold number.
    metrics_df.loc["Mean"] = metrics_df.mean()
    print("\n=== Fold Metrics Summary ===")
    print(metrics_df)

    # Save metrics and models
    # NOTE(review): index=False drops the "Mean" label from the CSV, and the
    # saves below overwrite the SSL checkpoints in run_dir (including a
    # load_run_dir) with the fine-tuned weights - confirm this is intended.
    metrics_df.to_csv(os.path.join(run_dir, "kfold_metrics.csv"), index=False)
    torch.save(backbone.state_dict(), backbone_ckpt)
    torch.save(proj_head.state_dict(), head_ckpt)

    # Visualization ([:-1] skips the trailing "Mean" row; means are drawn as dashed lines)
    plt.figure(figsize=(10,6))
    plt.plot(metrics_df['fold'][:-1], metrics_df['RMSE'][:-1], marker='o', label='RMSE')
    plt.plot(metrics_df['fold'][:-1], metrics_df['MAE'][:-1], marker='s', label='MAE')
    plt.plot(metrics_df['fold'][:-1], metrics_df['R2'][:-1], marker='^', label='R²')
    plt.axhline(metrics_df.loc["Mean", 'RMSE'], color='blue', linestyle='--', alpha=0.5)
    plt.axhline(metrics_df.loc["Mean", 'MAE'], color='orange', linestyle='--', alpha=0.5)
    plt.axhline(metrics_df.loc["Mean", 'R2'], color='green', linestyle='--', alpha=0.5)
    plt.xlabel("Fold")
    plt.ylabel("Metric Value")
    plt.title("K-Fold Regression Metrics (RMSE, MAE, R²)")
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(run_dir, "kfold_metrics_plot.png"))
    plt.show()

    return {
        "backbone": backbone,
        "proj_head": proj_head,
        "metrics_df": metrics_df,
        "run_dir": run_dir
    }



### R2 Kfold

#### K-Fold Validation

#### Current Best (SSL + Finetuning) High R2 > 0.5

In [None]:
# Evaluate the saved "ssl_740" run: checkpoints are loaded from load_run_dir
# (if the backbone checkpoint exists there), then the backbone is fine-tuned
# for 25 epochs in every fold before embeddings are extracted.
loaded_results = run_ssl_pipeline_kfold_with_r2(
    labeled_df=df,
    load_run_dir="/content/drive/MyDrive/DSA_Comp/ssl_models/ssl_740",
    fine_tune_epochs = 25
)

#### SSL MultiTask Only

In [None]:
# Evaluate the saved "ssl_multi" run without fine-tuning: with
# fine_tune_epochs = 0 the per-fold fine-tuning loop runs zero epochs, so the
# loaded weights are used as-is for embedding extraction.
loaded_results = run_ssl_pipeline_kfold_with_r2(
    labeled_df=df,
    load_run_dir="/content/drive/MyDrive/DSA_Comp/ssl_models/ssl_multi",
    fine_tune_epochs = 0
)

#### SSL Multi Task + Finetuning == Horrible R2 <0

In [None]:
# Same "ssl_multi" checkpoint, but with 25 contrastive fine-tuning epochs per fold.
# NOTE(review): the pipeline writes the fine-tuned weights back into
# load_run_dir, so re-running this cell starts from already fine-tuned weights.
loaded_results = run_ssl_pipeline_kfold_with_r2(
    labeled_df=df,
    load_run_dir="/content/drive/MyDrive/DSA_Comp/ssl_models/ssl_multi",
    fine_tune_epochs = 25
)