# Step 3: PyTorch Dataset & DataLoader

In [1]:

import os                                 
import pandas as pd                      
from PIL import Image                     
import torch                             
from torch.utils.data import Dataset, DataLoader  
from torchvision import transforms        
from ImageDataset import CYPImageDataset

In [2]:
# Define Transforms 
# We resize, center‑crop to 224×224, convert to tensor, and normalize.

# Use CLIP normalization
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD  = (0.26862954, 0.26130258, 0.27577711)

no_aug_transforms = transforms.Compose([
    transforms.Resize(256),                            # ensure min side=256
    transforms.CenterCrop(224),                        # crop to 224×224
    transforms.ToTensor(),                             # PIL→FloatTensor, scales to [0,1]
    transforms.Normalize(CLIP_MEAN, CLIP_STD)          # Use CLIP normalization
])

In [3]:
# Instantiate DataLoaders  (for CYP1A2)

from ImageDataset import CYPImageDataset

# Common loader settings
batch_size  = 32
num_workers = 4

# Training set
train_ds = CYPImageDataset(
    csv_file="../data/processed/1A2_train.csv",
    image_dir="../images/1A2/train/clean",
    transform=no_aug_transforms
)
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,        # shuffle for training
    num_workers=num_workers,
    pin_memory=True      # speeds up host→GPU transfer
)

# Validation set
val_ds = CYPImageDataset(
    csv_file="../data/processed/1A2_val.csv",
    image_dir="../images/1A2/val/clean",
    transform=no_aug_transforms
)
val_loader = DataLoader(
    val_ds,
    batch_size=batch_size,
    shuffle=False,       
    num_workers=num_workers,
    pin_memory=True
)

# Test set
test_ds = CYPImageDataset(
    csv_file="../data/processed/1A2_test.csv",
    image_dir="../images/1A2/test/clean",
    transform=no_aug_transforms
)
test_loader = DataLoader(
    test_ds,
    batch_size=batch_size,
    shuffle=False,      
    num_workers=num_workers,
    pin_memory=True
)

# Quick sanity check
print(f"Train batches: {len(train_loader)},  Val batches: {len(val_loader)},  Test batches: {len(test_loader)}")


Train batches: 315,  Val batches: 40,  Test batches: 40


# Step 4: Model Definition & Transfer‑Learning Setup

1. first load the standard CLIP ViT via the clip library (e.g. clip.load("ViT‑B/16")).

2. wrap and rename its visual submodule as backbone.

3. then load the MoleCLIP checkpoint (the .pth file) and filter for all keys beginning with visual.

4. call backbone.load_state_dict(moleclip_state, strict=False), which replaces the vanilla CLIP weights with the MoleCLIP‑fine‑tuned weights.

5. Finally, freeze those backbone parameters and add  new CYPHead on top.

In [4]:
# === Robust MoleCLIP backbone loading + fine-tune-friendly setup ===
import os
import torch
import clip
from torch import nn

In [5]:
# 0) Settings you can tweak
MOLECLIP_CKPT = "checkpoints/MoleCLIP - Primary.pth"   # path to MoleCLIP checkpoint
CLIP_MODEL_NAME = "ViT-B/16"                          # CLIP variant to match
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
FINETUNE_LAST_N = 3    # how many last transformer blocks to un-freeze

In [6]:
# 1) Load CLIP ViT model (only image encoder used)
#    Load on CPU first — we will move model to DEVICE after weight loading.
clip_model, preprocess = clip.load(CLIP_MODEL_NAME, device="cpu")
class Backbone(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        # Use only the image encoder (named "visual" in the CLIP wrapper)
        self.visual = clip_model.visual
    def forward(self, x):
        return self.visual(x)

backbone = Backbone(clip_model)  # currently on CPU

In [7]:
# 2) Load MoleCLIP checkpoint (on CPU) and inspect keys
ck = torch.load(MOLECLIP_CKPT, map_location="cpu")
ck_model = ck["model"] if isinstance(ck, dict) and "model" in ck else ck

  ck = torch.load(MOLECLIP_CKPT, map_location="cpu")


In [8]:
# 3) Auto-detect best prefix that contains the image-encoder keys
#    (we search for a prefix such that many keys start with it)
possible_prefixes = [
    "model_image.module.model.visual.",
    "model_image.model.visual.",
    "model.visual.", "visual.",
    "module.visual.", "backbone.visual.",
    "clip.visual.", "visual"
]

# fallback: also try to extract any prefix up to 'visual.' occurrence
extra_prefixes = set()
for k in ck_model.keys():
    if "visual." in k:
        # take substring upto and including 'visual.'
        idx = k.find("visual.") + len("visual.")
        # find prefix that ends right before 'visual.' occurrence
        prefix = k[:k.find("visual.")+len("visual.")]
        extra_prefixes.add(prefix)
possible_prefixes = list(possible_prefixes) + list(sorted(extra_prefixes))

# function to count how many keys under a given prefix match backbone.visual keys
backbone_sd = backbone.visual.state_dict()
backbone_keys = set(backbone_sd.keys())

best_prefix = None
best_count = -1
best_mapped = {}

for prefix in possible_prefixes:
    prefixed = {k: v for k, v in ck_model.items() if k.startswith(prefix)}
    # mapped = stripped prefix names
    mapped = {k[len(prefix):]: v for k, v in prefixed.items()}
    common = set(mapped.keys()).intersection(backbone_keys)
    if len(common) > best_count:
        best_count = len(common)
        best_prefix = prefix
        best_mapped = mapped

if best_prefix is None or best_count == 0:
    raise RuntimeError("Could not find matching visual/encoder block in MoleCLIP checkpoint. "
                       "Inspect checkpoint keys manually.")

print(f"Auto-detected checkpoint prefix for visual weights: '{best_prefix}'")
print(f"Checkpoint provides {len([k for k in ck_model.keys() if k.startswith(best_prefix)])} keys under that prefix.")
print(f"backbone.visual expects {len(backbone_keys)} keys; {best_count} keys match by name under prefix.")

Auto-detected checkpoint prefix for visual weights: 'model_image.module.model.visual.'
Checkpoint provides 152 keys under that prefix.
backbone.visual expects 152 keys; 152 keys match by name under prefix.


In [9]:
# 4) Select only keys that both exist in backbone.visual and have matching shapes
to_load = {}
bad_shapes = []
for k, v in best_mapped.items():
    if k in backbone_sd:
        if tuple(v.shape) == tuple(backbone_sd[k].shape):
            to_load[k] = v
        else:
            bad_shapes.append((k, tuple(v.shape), tuple(backbone_sd[k].shape)))

print(f"Matching tensors to load: {len(to_load)}")
print(f"Mismatched-shape tensors (will NOT be loaded): {len(bad_shapes)}")
if bad_shapes:
    print("Some mismatches (first 20):")
    for i, item in enumerate(bad_shapes[:20]):
        print(" ", i, item)

Matching tensors to load: 152
Mismatched-shape tensors (will NOT be loaded): 0


In [10]:
# 5) Load matching tensors into backbone.visual state dict (safe update)
new_sd = backbone_sd.copy()
new_sd.update(to_load)
backbone.visual.load_state_dict(new_sd)   # we provide a filled state-dict (no unexpected keys)
print("Loaded MoleCLIP visual weights into backbone.visual (matching keys updated).")

Loaded MoleCLIP visual weights into backbone.visual (matching keys updated).


In [11]:
# 6) Move backbone to DEVICE and inspect dtype/device
backbone = backbone.to(DEVICE)
first_param = next(backbone.visual.parameters())
print("backbone.visual device:", first_param.device, "dtype:", first_param.dtype)

backbone.visual device: cuda:0 dtype: torch.float32


In [12]:
# 7) Instantiate VisCYPNet using this backbone
from model_viscypnet import VisCYPNet
model = VisCYPNet(
    backbone=backbone,
    head_hidden_dims=[256,64],
    head_dropout=0.2,
    device=DEVICE
).to(DEVICE)

In [13]:
# 8) By default freeze the entire backbone (common pattern when fine-tuning)
for p in model.backbone.visual.parameters():
    p.requires_grad = False

In [14]:
# 9) Un-freeze the last FINETUNE_LAST_N transformer blocks (if available)
#    This accesses: model.backbone.visual.transformer.resblocks which is a nn.Sequential
resblocks = model.backbone.visual.transformer.resblocks
num_blocks = len(list(resblocks))
N = min(FINETUNE_LAST_N, num_blocks)
print(f"Total transformer resblocks: {num_blocks}; unfreezing last {N} blocks.")

for block in list(resblocks)[-N:]:
    for p in block.parameters():
        p.requires_grad = True

Total transformer resblocks: 12; unfreezing last 3 blocks.


In [15]:
# 10) Check which parameters are trainable (sanity)
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(f"Number of trainable parameter tensors: {len(trainable)}")
# optional: print first few names for inspection
for nm in trainable[:40]:
    print("  -", nm)

Number of trainable parameter tensors: 44
  - backbone.visual.transformer.resblocks.9.attn.in_proj_weight
  - backbone.visual.transformer.resblocks.9.attn.in_proj_bias
  - backbone.visual.transformer.resblocks.9.attn.out_proj.weight
  - backbone.visual.transformer.resblocks.9.attn.out_proj.bias
  - backbone.visual.transformer.resblocks.9.ln_1.weight
  - backbone.visual.transformer.resblocks.9.ln_1.bias
  - backbone.visual.transformer.resblocks.9.mlp.c_fc.weight
  - backbone.visual.transformer.resblocks.9.mlp.c_fc.bias
  - backbone.visual.transformer.resblocks.9.mlp.c_proj.weight
  - backbone.visual.transformer.resblocks.9.mlp.c_proj.bias
  - backbone.visual.transformer.resblocks.9.ln_2.weight
  - backbone.visual.transformer.resblocks.9.ln_2.bias
  - backbone.visual.transformer.resblocks.10.attn.in_proj_weight
  - backbone.visual.transformer.resblocks.10.attn.in_proj_bias
  - backbone.visual.transformer.resblocks.10.attn.out_proj.weight
  - backbone.visual.transformer.resblocks.10.attn.

In [16]:
# 11) Build optimizer with two parameter groups:
#     - backbone params (those that require grad) at a small LR
#     - head params (usually model.head) at a larger LR
backbone_params = [p for n, p in model.backbone.named_parameters() if p.requires_grad]
# attempt to locate the head parameters robustly
if hasattr(model, "head"):
    head_params = [p for p in model.head.parameters()]
else:
    # fallback: treat any parameter not in backbone as 'head'
    backbone_names = {n for n, p in model.backbone.named_parameters()}
    head_params = [p for n, p in model.named_parameters() if n.split(".")[0] not in backbone_names]

print(f"Backbone trainable tensors: {len(backbone_params)}; head params approx: {len(head_params)}")

optimizer = torch.optim.AdamW([
    {"params": backbone_params, "lr": 1e-6, "weight_decay": 1e-6},
    {"params": head_params, "lr": 1e-4,     "weight_decay": 1e-4},
])

print("Optimizer created with 2 parameter groups (backbone small LR, head larger LR).")

Backbone trainable tensors: 36; head params approx: 8
Optimizer created with 2 parameter groups (backbone small LR, head larger LR).


In [17]:
# 12) Small verification: dummy forward (optional, but recommended)
try:
    dummy = torch.randn(1, 3, 224, 224, device=first_param.device, dtype=first_param.dtype)
    with torch.no_grad():
        out = model.backbone.visual(dummy)
    print("Dummy forward OK. Visual output shape:", getattr(out, "shape", None))
except Exception as e:
    print("Warning: dummy forward failed:", e)
    # If it fails, inspect dtype/device mismatch between dummy and model parameters.


Dummy forward OK. Visual output shape: torch.Size([1, 512])


In [18]:
model

VisCYPNet(
  (backbone): Backbone(
    (visual): VisionTransformer(
      (conv1): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), bias=False)
      (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (transformer): Transformer(
        (resblocks): Sequential(
          (0): ResidualAttentionBlock(
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): QuickGELU()
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          )
          (1): ResidualAttentionBlock(
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuant

# Step 5: Training Loop & Hyperparameter Random Search

In [19]:
import copy, random, os
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score
from itertools import product

# ensure you have "preprocess" from clip.load earlier in the notebook:
# clip_model, preprocess = clip.load("ViT-B/16", device="cpu")  # done earlier

# Use CPU/GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True

In [20]:

# ===== Loader factory (use CLIP preprocess as transform) =====
from ImageDataset import CYPImageDataset

def make_loaders(batch_size):
    # pass CLIP preprocess (returns tensor normalized for the backbone)
    train_ds = CYPImageDataset(
        csv_file="../data/processed/1A2_train.csv",
        image_dir="../images/1A2/train/clean",
        transform=preprocess
    )
    val_ds = CYPImageDataset(
        csv_file="../data/processed/1A2_val.csv",
        image_dir="../images/1A2/val/clean",
        transform=preprocess
    )
    test_ds = CYPImageDataset(
        csv_file="../data/processed/1A2_test.csv",
        image_dir="../images/1A2/test/clean",
        transform=preprocess
    )

    # num_workers tuned to your machine; pin_memory True if using CUDA
    pin = True if device.type == "cuda" else False
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=4, pin_memory=pin)
    val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False,
                              num_workers=2, pin_memory=pin)
    test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False,
                              num_workers=2, pin_memory=pin)

    return train_loader, val_loader, test_loader

In [21]:


# ===== Training helpers (handle label shapes & optional AMP) =====
use_amp = True if device.type == "cuda" else False
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

def train_one_epoch(model, loader, optimizer, loss_fn, device, amp=scaler):
    model.train()
    total_loss = 0.0
    n = 0
    for imgs, lbls in loader:
        # imgs expected as tensor (C,H,W) after preprocess
        imgs = imgs.to(device, non_blocking=True)
        # ensure labels are float and same shape as logits later
        lbls = lbls.float().to(device, non_blocking=True)

        optimizer.zero_grad()
        if use_amp:
            with torch.cuda.amp.autocast():
                logits = model(imgs)                              # logits shape (B,1) or (B,)
                # If logits are (B,1), make labels (B,1)
                if logits.dim() == 2 and logits.shape[1] == 1:
                    lbl = lbls.view(-1,1)
                else:
                    lbl = lbls.view(-1)
                loss = loss_fn(logits, lbl)
            amp.scale(loss).backward()
            amp.step(optimizer)
            amp.update()
        else:
            logits = model(imgs)
            if logits.dim() == 2 and logits.shape[1] == 1:
                lbl = lbls.view(-1,1)
            else:
                lbl = lbls.view(-1)
            loss = loss_fn(logits, lbl)
            loss.backward()
            optimizer.step()

        total_loss += loss.item() * imgs.size(0)
        n += imgs.size(0)
    return total_loss / max(1, n)

def compute_auc(model, loader, device):
    model.eval()
    all_probs, all_lbls = [], []
    with torch.no_grad():
        for imgs, lbls in loader:
            imgs = imgs.to(device, non_blocking=True)
            lbls = lbls.cpu().numpy().flatten()
            # forward: handle mixed precision if enabled
            if use_amp:
                with torch.cuda.amp.autocast():
                    logits = model(imgs)
            else:
                logits = model(imgs)
            probs = torch.sigmoid(logits).cpu().numpy().flatten()
            all_probs.extend(probs.tolist())
            all_lbls.extend(lbls.tolist())
    try:
        return roc_auc_score(all_lbls, all_probs)
    except Exception:
        return 0.5

  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


In [22]:
# ===== Hyperparameter search setup  =====
search_space = {
    'head_lr':     [1e-3, 1e-4, 1e-5],
    'batch_size': [16, 32, 64],
    'dropout':    [0.0, 0.2, 0.5],
    'weight_decay':[1e-4, 1e-3, 1e-2],
}
param_grid = list(product(
    search_space['head_lr'],
    search_space['batch_size'],
    search_space['dropout'],
    search_space['weight_decay'],
))
random.shuffle(param_grid)
max_trials = min(10, len(param_grid))

# fix seeds
random.seed(42); np.random.seed(42); torch.manual_seed(42)
if device.type == "cuda":
    torch.cuda.manual_seed_all(42)

best_auc    = 0.0
best_params = None
best_state  = None   # store state_dict of best model to save memory

In [23]:
# ===== Run Random Search =====
from model_viscypnet import VisCYPNet

for i, (lr, bs, do, wd) in enumerate(param_grid[:max_trials], 1):
    print(f"\nTrial {i}/{max_trials}: lr={lr}, bs={bs}, dropout={do}, wd={wd}")

    # loaders
    train_loader, val_loader, test_loader = make_loaders(bs)

    # instantiate fresh model; ensure backbone is used (backbone already loaded earlier)
    model = VisCYPNet(
        backbone=backbone,               # MoleCLIP backbone loaded earlier
        head_hidden_dims=[256,64],
        head_dropout=do,
        device=device
    ).to(device)

    # Force model dtype to match backbone's first param dtype (avoid float16/float32 mismatch)
    first_param = next(model.backbone.visual.parameters())
    model = model.to(device=device, dtype=first_param.dtype)

    # Freeze backbone then unfreeze last N blocks
    for p in model.backbone.visual.parameters():
        p.requires_grad = False

    N = 3
    resblocks = model.backbone.visual.transformer.resblocks
    for block in list(resblocks)[-N:]:
        for p in block.parameters():
            p.requires_grad = True

    # Build optimizer (must collect parameters after updating requires_grad)
    backbone_params = [p for p in model.backbone.parameters() if p.requires_grad]
    head_params = [p for p in model.head.parameters()]

    optimizer = torch.optim.AdamW([
        {'params': backbone_params, 'lr': 5e-7, 'weight_decay': 1e-6},
        {'params': head_params,     'lr': lr,   'weight_decay': wd}
    ])

    loss_fn = nn.BCEWithLogitsLoss()

    # quick training (5 epochs)
    for epoch in range(1, 6):
        tr_loss = train_one_epoch(model, train_loader, optimizer, loss_fn, device)
        val_auc = compute_auc(model, val_loader, device)
        print(f"  Epoch {epoch}: train_loss={tr_loss:.4f}, val_auc={val_auc:.4f}")

    # Track best: save state_dict (less memory than deepcopy)
    if val_auc > best_auc:
        best_auc = val_auc
        best_params = {'lr': lr, 'bs': bs, 'dropout': do, 'wd': wd}
        best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}  # keep on CPU

# After search, save best model weights
if best_state is not None:
    torch.save(best_state, "best_viscypnet_moleclip.pth")
    print("Saved best model -> best_viscypnet_moleclip.pth with params:", best_params, "AUC:", best_auc)
else:
    print("No model improved on initial best_auc.")



Trial 1/10: lr=0.0001, bs=16, dropout=0.2, wd=0.01


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 1: train_loss=0.4729, val_auc=0.9000


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 2: train_loss=0.3843, val_auc=0.9116


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 3: train_loss=0.3440, val_auc=0.9200


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 4: train_loss=0.3086, val_auc=0.9198


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 5: train_loss=0.2717, val_auc=0.9185

Trial 2/10: lr=0.001, bs=32, dropout=0.2, wd=0.0001


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 1: train_loss=0.3314, val_auc=0.9209


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 2: train_loss=0.2665, val_auc=0.9193


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 3: train_loss=0.2267, val_auc=0.9119


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 4: train_loss=0.1926, val_auc=0.9140


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 5: train_loss=0.1527, val_auc=0.9039

Trial 3/10: lr=0.0001, bs=16, dropout=0.0, wd=0.01


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 1: train_loss=0.3293, val_auc=0.9227


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 2: train_loss=0.2031, val_auc=0.9234


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 3: train_loss=0.1545, val_auc=0.9152


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 4: train_loss=0.1223, val_auc=0.9165


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 5: train_loss=0.0948, val_auc=0.9145

Trial 4/10: lr=1e-05, bs=32, dropout=0.2, wd=0.001


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 1: train_loss=0.5831, val_auc=0.8951


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 2: train_loss=0.4003, val_auc=0.9190


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 3: train_loss=0.3157, val_auc=0.9284


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 4: train_loss=0.2588, val_auc=0.9314


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 5: train_loss=0.2148, val_auc=0.9313

Trial 5/10: lr=0.001, bs=32, dropout=0.5, wd=0.01


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 1: train_loss=0.1819, val_auc=0.9205


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 2: train_loss=0.1171, val_auc=0.9152


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 3: train_loss=0.0868, val_auc=0.9136


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 4: train_loss=0.0745, val_auc=0.9108


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 5: train_loss=0.0638, val_auc=0.9076

Trial 6/10: lr=0.001, bs=64, dropout=0.5, wd=0.01


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 1: train_loss=0.1267, val_auc=0.9182


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 2: train_loss=0.0555, val_auc=0.9128


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 3: train_loss=0.0426, val_auc=0.9042


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 4: train_loss=0.0383, val_auc=0.9022


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 5: train_loss=0.0338, val_auc=0.9053

Trial 7/10: lr=1e-05, bs=64, dropout=0.5, wd=0.001


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 1: train_loss=0.6085, val_auc=0.9155


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 2: train_loss=0.4512, val_auc=0.9238


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 3: train_loss=0.3543, val_auc=0.9313


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 4: train_loss=0.2933, val_auc=0.9346


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 5: train_loss=0.2425, val_auc=0.9346

Trial 8/10: lr=0.0001, bs=64, dropout=0.2, wd=0.01


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 1: train_loss=0.2466, val_auc=0.9312


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 2: train_loss=0.0850, val_auc=0.9259


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 3: train_loss=0.0498, val_auc=0.9221


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 4: train_loss=0.0339, val_auc=0.9222


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 5: train_loss=0.0252, val_auc=0.9170

Trial 9/10: lr=0.0001, bs=16, dropout=0.2, wd=0.001


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 1: train_loss=0.1148, val_auc=0.9230


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 2: train_loss=0.0303, val_auc=0.9201


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 3: train_loss=0.0217, val_auc=0.9125


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 4: train_loss=0.0173, val_auc=0.9030


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 5: train_loss=0.0166, val_auc=0.9054

Trial 10/10: lr=1e-05, bs=16, dropout=0.2, wd=0.001


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 1: train_loss=0.3670, val_auc=0.9359


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 2: train_loss=0.1203, val_auc=0.9319


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 3: train_loss=0.0539, val_auc=0.9315


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 4: train_loss=0.0279, val_auc=0.9244


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


  Epoch 5: train_loss=0.0188, val_auc=0.9293
Saved best model -> best_viscypnet_moleclip.pth with params: {'lr': 1e-05, 'bs': 64, 'dropout': 0.5, 'wd': 0.001} AUC: 0.934586302424399


# Step 6 & Step 7: Final Training, Validation Check & Test Evaluation

We rebuild the loaders using the chosen best_bs.

We train for 20 epochs, tracking validation ROC AUC and saving the best model state.

After training, we save the best weights to models/Without_Augmentation_CYP1A2.pth.

We define compute_metrics to calculate ROC AUC, BA, MCC, PRE, REC, and F1 at a 0.5 threshold.

Finally, we load the best weights back into model and print all metrics for train, val, and test.

In [24]:
import os, copy
import torch, numpy as np
from torch import nn
from torch.utils.data import DataLoader
from sklearn.metrics import (balanced_accuracy_score, matthews_corrcoef,
                             precision_score, recall_score, f1_score, roc_auc_score)
from ImageDataset import CYPImageDataset
from model_viscypnet import VisCYPNet

In [25]:
# -------------------------
# 1) Best hyperparameters (from Step 5)
# -------------------------
#  {'lr': 1e-05, 'bs': 64, 'dropout': 0.5, 'wd': 0.001}
best_lr = 1e-05
best_bs = 64
best_do = 0.5
best_wd = 0.01

In [26]:
# -------------------------
# 2) Transform & DataLoaders
# -------------------------
# prefer the CLIP 'preprocess' used earlier; if not available fall back to no_aug_transforms
try:
    preprocess  # noqa: F821
    transform_for_dataset = preprocess
except NameError:
    # fallback: define minimal no-aug transforms (assumes torchvision.transforms imported earlier)
    import torchvision.transforms as transforms
    CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
    CLIP_STD  = (0.26862954, 0.26130258, 0.27577711)
    no_aug_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(CLIP_MEAN, CLIP_STD)
    ])
    transform_for_dataset = no_aug_transforms

def make_loaders(batch_size):
    ds_train = CYPImageDataset(csv_file="../data/processed/1A2_train.csv",
                               image_dir="../images/1A2/train/clean",
                               transform=transform_for_dataset)
    ds_val   = CYPImageDataset(csv_file="../data/processed/1A2_val.csv",
                               image_dir="../images/1A2/val/clean",
                               transform=transform_for_dataset)
    ds_test  = CYPImageDataset(csv_file="../data/processed/1A2_test.csv",
                               image_dir="../images/1A2/test/clean",
                               transform=transform_for_dataset)
    pin = True if torch.cuda.is_available() else False
    train_loader = DataLoader(ds_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=pin)
    val_loader   = DataLoader(ds_val,   batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=pin)
    test_loader  = DataLoader(ds_test,  batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=pin)
    return train_loader, val_loader, test_loader

train_loader, val_loader, test_loader = make_loaders(best_bs)

In [27]:
# -------------------------
# 3) Instantiate model & align device + dtype with backbone
# -------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# IMPORTANT: 'backbone' should already be loaded with MoleCLIP visual weights
# (from your earlier cell). We'll use the dtype/device of backbone.visual's first param.
first_backbone_param = next(backbone.visual.parameters())
backbone_device = first_backbone_param.device
backbone_dtype  = first_backbone_param.dtype

model = VisCYPNet(backbone=backbone, head_hidden_dims=[256,64], head_dropout=best_do, device=device)

# Move entire model to device and to the same dtype as backbone visual params to prevent mismatch
model = model.to(device=backbone_device, dtype=backbone_dtype)

In [28]:
# -------------------------
# 4) Freeze backbone and unfreeze last N transformer blocks, BEFORE building optimizer
# -------------------------
# Freeze all backbone visual params
for p in model.backbone.visual.parameters():
    p.requires_grad = False

# Un-freeze last N transformer blocks
N = 3
resblocks = model.backbone.visual.transformer.resblocks
num_blocks = len(list(resblocks))
N = min(N, num_blocks)
for block in list(resblocks)[-N:]:
    for p in block.parameters():
        p.requires_grad = True

# Also make sure the head parameters require grad
for p in model.head.parameters():
    p.requires_grad = True


In [29]:
# -------------------------
# 5) Build optimizer with two parameter groups (after requires_grad set)
# -------------------------
backbone_trainable_params = [p for p in model.backbone.parameters() if p.requires_grad]
head_params = [p for p in model.head.parameters() if p.requires_grad]

optimizer = torch.optim.AdamW([
    {"params": backbone_trainable_params, "lr": 5e-7, "weight_decay": 1e-6},
    {"params": head_params,               "lr": best_lr, "weight_decay": best_wd}
])

loss_fn = nn.BCEWithLogitsLoss()

In [30]:
# -------------------------
# 6) Train with AMP if CUDA (safe dtype handling)
# -------------------------
use_amp = True if (torch.cuda.is_available()) else False
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

num_epochs = 10
best_val_auc = -1.0
best_state = None

for epoch in range(1, num_epochs+1):
    # Train
    model.train()
    running_loss = 0.0
    n_samples = 0
    for imgs, lbls in train_loader:
        # imgs should already be tensors (preprocess) with correct normalization; move to backbone device/dtype
        imgs = imgs.to(device=backbone_device, dtype=backbone_dtype, non_blocking=True)
        lbls = lbls.float().to(device=backbone_device, non_blocking=True)

        optimizer.zero_grad()
        if use_amp:
            with torch.cuda.amp.autocast():
                logits = model(imgs)  # shape: (B,1) or (B,)
                if logits.dim()==2 and logits.shape[1]==1:
                    lbl = lbls.view(-1,1)
                else:
                    lbl = lbls.view(-1)
                loss = loss_fn(logits, lbl)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            logits = model(imgs)
            if logits.dim()==2 and logits.shape[1]==1:
                lbl = lbls.view(-1,1)
            else:
                lbl = lbls.view(-1)
            loss = loss_fn(logits, lbl)
            loss.backward()
            optimizer.step()

        running_loss += loss.item() * imgs.size(0)
        n_samples += imgs.size(0)

    train_loss = running_loss / max(1, n_samples)

    # Validate
    model.eval()
    all_logits = []
    all_lbls = []
    with torch.no_grad():
        for imgs, lbls in val_loader:
            imgs = imgs.to(device=backbone_device, dtype=backbone_dtype, non_blocking=True)
            lbls_np = lbls.numpy().flatten()
            if use_amp:
                with torch.cuda.amp.autocast():
                    logits = model(imgs)
            else:
                logits = model(imgs)

            logits_np = logits.detach().cpu().numpy().flatten()
            all_logits.extend(logits_np.tolist())
            all_lbls.extend(lbls_np.tolist())

    probs = 1.0 / (1.0 + np.exp(-np.array(all_logits)))
    try:
        val_auc = roc_auc_score(all_lbls, probs)
    except Exception:
        val_auc = 0.5

    print(f"Epoch {epoch:02d} | Train Loss: {train_loss:.4f} | Val ROC AUC: {val_auc:.4f}")

    # Checkpoint best model (save state_dict to CPU)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        # copy state_dict to CPU tensors to avoid GPU memory pinning
        best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

# Save best model
os.makedirs("models", exist_ok=True)
path = "models/VisCYPNet_CYP1A2.pth"
if best_state is not None:
    torch.save(best_state, path)
    print(f"Saved best model state_dict (Val AUC={best_val_auc:.4f}) -> {path}")
else:
    print("No model improved during training; nothing saved.")


  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


Epoch 01 | Train Loss: 0.5694 | Val ROC AUC: 0.9324


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


Epoch 02 | Train Loss: 0.3525 | Val ROC AUC: 0.9373


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


Epoch 03 | Train Loss: 0.2243 | Val ROC AUC: 0.9377


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


Epoch 04 | Train Loss: 0.1447 | Val ROC AUC: 0.9358


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


Epoch 05 | Train Loss: 0.0936 | Val ROC AUC: 0.9343


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


Epoch 06 | Train Loss: 0.0644 | Val ROC AUC: 0.9336


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


Epoch 07 | Train Loss: 0.0442 | Val ROC AUC: 0.9323


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


Epoch 08 | Train Loss: 0.0341 | Val ROC AUC: 0.9322


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


Epoch 09 | Train Loss: 0.0260 | Val ROC AUC: 0.9324


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


Epoch 10 | Train Loss: 0.0213 | Val ROC AUC: 0.9317
Saved best model state_dict (Val AUC=0.9377) -> models/VisCYPNet_CYP1A2.pth


In [31]:
# -------------------------
# 7) Load best_state and evaluate on Train/Val/Test with final metrics
# -------------------------
if best_state is not None:
    model.load_state_dict(best_state)   # ensure model has best weights
model = model.to(device=backbone_device, dtype=backbone_dtype)
model.eval()

def compute_metrics(model, loader, device_dtype):
    model.eval()
    logits_list, labels_list = [], []
    with torch.no_grad():
        for imgs, lbls in loader:
            imgs = imgs.to(device=backbone_device, dtype=backbone_dtype, non_blocking=True)
            if use_amp:
                with torch.cuda.amp.autocast():
                    logits = model(imgs).detach().cpu().numpy().flatten()
            else:
                logits = model(imgs).detach().cpu().numpy().flatten()
            logits_list.extend(logits.tolist())
            labels_list.extend(lbls.numpy().flatten().tolist())

    probs = 1.0 / (1.0 + np.exp(-np.array(logits_list)))
    preds = (probs >= 0.5).astype(int)
    # metrics (safe wrt single-class / degenerate cases)
    metrics = {}
    try:
        metrics["ROC_AUC"] = roc_auc_score(labels_list, probs)
    except Exception:
        metrics["ROC_AUC"] = float("nan")
    try:
        metrics["BA"] = balanced_accuracy_score(labels_list, preds)
    except Exception:
        metrics["BA"] = float("nan")
    try:
        metrics["MCC"] = matthews_corrcoef(labels_list, preds)
    except Exception:
        metrics["MCC"] = float("nan")
    metrics["PRE"] = precision_score(labels_list, preds, zero_division=0)
    metrics["REC"] = recall_score(labels_list, preds, zero_division=0)
    metrics["F1"]  = f1_score(labels_list, preds, zero_division=0)
    metrics["ACC"] = np.mean(preds == np.array(labels_list))
    return metrics

print("\n=== Final Metrics ===")
for split, loader in [("Train", train_loader), ("Val", val_loader), ("Test", test_loader)]:
    m = compute_metrics(model, loader, backbone_dtype)
    line = ", ".join(f"{k}={v:.4f}" if isinstance(v, (float, np.floating)) else f"{k}={v}" for k,v in m.items())
    print(f"{split}: {line}")



=== Final Metrics ===


  with torch.cuda.amp.autocast():


Train: ROC_AUC=0.9977, BA=0.9841, MCC=0.9674, PRE=0.9772, REC=0.9873, F1=0.9822, ACC=0.9838


  with torch.cuda.amp.autocast():


Val: ROC_AUC=0.9377, BA=0.8509, MCC=0.7027, PRE=0.8414, REC=0.8717, F1=0.8563, ACC=0.8512


  with torch.cuda.amp.autocast():


Test: ROC_AUC=0.9378, BA=0.8689, MCC=0.7380, PRE=0.8748, REC=0.8594, F1=0.8670, ACC=0.8689
