<a href="https://colab.research.google.com/github/tousifo/ml_notebooks/blob/main/pimcl_mil_contrastive.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Preprocess

In [1]:
!pip install SimpleITK torchio nilearn



In [2]:
%%writefile preprocess.py
import os
import zipfile
import numpy as np
import subprocess
import sys
import argparse
from glob import glob
import SimpleITK as sitk
import nibabel as nib
from nilearn import datasets

# Fallback dependency installation
for pkg in ['simpleitk', 'torchio', 'nilearn']:
    try:
        __import__(pkg)
    except ImportError:
        print(f"Installing {pkg}...")
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', pkg])

# Hyperparameter Configuration
parser = argparse.ArgumentParser(description='PMICL Training')
parser.add_argument('--emb_dim', type=int, default=256)
parser.add_argument('--protos_per_class', type=int, default=3)
parser.add_argument('--lambda_proto', type=float, default=0.5)
parser.add_argument('--patch_size', type=int, default=32)
parser.add_argument('--stride', type=int, default=32)
parser.add_argument('--N', type=int, default=64)
parser.add_argument('--max_files', type=int, default=None)
parser.add_argument('--epochs', type=int, default=20)
parser.add_argument('--batch_size', type=int, default=2)
args, _ = parser.parse_known_args()

# Paths
ZIP_PATH = '/content/oaisis.zip'
RAW_DIR = '/content/oasis_raw'
PREPROC_DIR = '/content/oasis_preproc'
MNI_TEMPLATE = '/content/icbm152_2009/mni_icbm152_nlin_sym_09a/mni_icbm152_t1_tal_nlin_sym_09a.nii'

# Verify zip file
if not os.path.exists(ZIP_PATH):
    print(f"Error: Zip file {ZIP_PATH} not found. Please upload it to /content/")
    sys.exit(1)

# List zip contents
print("Zip file contents:")
with zipfile.ZipFile(ZIP_PATH, 'r') as z:
    zip_files = z.namelist()
    for fname in zip_files:
        print(f" - {fname}")
    data_files = [f for f in zip_files if f.startswith('Data/') and f.endswith(('.nii', '.nii.gz'))]
    print(f"Found {len(data_files)} .nii/.nii.gz files in zip under Data/")

# Unzip
os.makedirs(RAW_DIR, exist_ok=True)
with zipfile.ZipFile(ZIP_PATH, 'r') as z:
    z.extractall(RAW_DIR)
print(f"Unzipped to {RAW_DIR}")

# Verify unzipped files
print("Unzipped directory contents:")
for cls in ['Mild Dementia', 'Moderate Dementia', 'Non Demented', 'Very mild Dementia']:
    cls_dir = os.path.join(RAW_DIR, 'Data', cls)
    if os.path.exists(cls_dir):
        nii_files = glob(os.path.join(cls_dir, '*.nii*'))
        print(f" - {cls_dir}: {len(nii_files)} .nii/.nii.gz files")
        for f in nii_files:
            print(f"   - {f} (size: {os.path.getsize(f)} bytes)")
    else:
        print(f" - {cls_dir}: Directory not found")

# Fetch MNI template if not present
if not os.path.exists(MNI_TEMPLATE):
    icbm = datasets.fetch_icbm152_2009(data_dir='/content')
    MNI_TEMPLATE = icbm.t1
    print(f"Fetched MNI template to {MNI_TEMPLATE}")

# Preprocess
os.makedirs(PREPROC_DIR, exist_ok=True)
classes = ['Mild Dementia', 'Moderate Dementia', 'Non Demented', 'Very mild Dementia']
for cls in classes:
    src = os.path.join(RAW_DIR, 'Data', cls)
    dst = os.path.join(PREPROC_DIR, cls)
    os.makedirs(dst, exist_ok=True)
    files = sorted(glob(os.path.join(src, '*.nii*')))
    if args.max_files:
        files = files[:args.max_files]
    if not files:
        print(f"No .nii files found in {src}")
        continue
    print(f"Processing {len(files)} files in {src}")
    for fpath in files:
        try:
            print(f"Processing {fpath}")
            img = sitk.ReadImage(fpath)
            print(f" - Loaded image: {fpath}, size={img.GetSize()}")
            img = sitk.N4BiasFieldCorrectionImageFilter().Execute(img)
            mask = sitk.OtsuThreshold(img, 0, 1)
            img = sitk.Mask(img, mask)
            template_img = sitk.ReadImage(MNI_TEMPLATE)
            reg = sitk.ImageRegistrationMethod()
            reg.SetMetricAsMeanSquares()
            reg.SetOptimizerAsGradientDescent(1.0, 100)
            init_t = sitk.CenteredTransformInitializer(template_img, img, sitk.Euler3DTransform())
            reg.SetInitialTransform(init_t)
            trans = reg.Execute(template_img, img)
            img = sitk.Resample(img, template_img, trans, sitk.sitkLinear, 0.0, img.GetPixelID())
            arr = sitk.GetArrayFromImage(img).astype(np.float32)
            arr = (arr - arr.min()) / (arr.max() - arr.min() + 1e-6)
            out_path = os.path.join(dst, os.path.basename(fpath).replace('.nii', '.npy'))
            np.save(out_path, arr)
            print(f"Saved {out_path} (size: {os.path.getsize(out_path)} bytes)")
        except Exception as e:
            print(f"Error processing {fpath}: {str(e)}")

# Verify output
print("Preprocessing output:")
for cls in classes:
    dst = os.path.join(PREPROC_DIR, cls)
    npy_files = glob(os.path.join(dst, '*.npy'))
    print(f" - {dst}: {len(npy_files)} .npy files")
    for f in npy_files:
        print(f"   - {f} (size: {os.path.getsize(f)} bytes)")

# Load atlas
atlas_img = '/content/gray_matter_mask.nii.gz'
try:
    atlas_data = nib.load(atlas_img).get_fdata()
    print(f"Loaded gray matter mask from {atlas_img}")
except (FileNotFoundError, nib.filebasedimages.ImageFileError):
    from nilearn.datasets import fetch_atlas_harvard_oxford
    atlas_res = fetch_atlas_harvard_oxford('cort-maxprob-thr25-1mm')
    maps = atlas_res.maps
    if isinstance(maps, str):
        atlas_data = nib.load(maps).get_fdata() > 0
    else:
        atlas_data = maps.get_fdata() > 0
    print(f"Fetched Harvard-Oxford atlas mask from: {maps}")
atlas = atlas_data.astype(float).flatten()

# Save atlas
np.save('/content/atlas.npy', atlas)
print("Saved atlas to /content/atlas.npy")

Overwriting preprocess.py


#-Dataset -

In [3]:
%%writefile dataset.py
import os
import numpy as np
import torch
import torchio as tio
import SimpleITK as sitk
from torch.utils.data import Dataset

class OASIS3DDataset(Dataset):
    LABELS = {
        'Mild Dementia': 0,
        'Moderate Dementia': 1,
        'Non Demented': 2,
        'Very mild Dementia': 3
    }

    def __init__(self, preproc_dir, atlas, patch_size, stride, N, augment=True, max_files=None):
        self.ps = patch_size
        self.stride = stride
        self.N = N
        self.samples = []
        self.preproc_dir = preproc_dir

        for cls, lab in self.LABELS.items():
            folder = os.path.join(preproc_dir, cls)
            if not os.path.isdir(folder):
                print(f"Warning: missing folder '{cls}' -> {folder}")
                continue
            npy_files = sorted([f for f in os.listdir(folder) if f.endswith('.npy')])
            nii_files = sorted([f for f in os.listdir(folder) if f.endswith(('.nii', '.nii.gz'))])
            if npy_files:
                use_npy = True
                files = npy_files
            else:
                use_npy = False
                files = nii_files
            if max_files:
                files = files[:max_files]
            if not files:
                print(f"Warning: no .npy or .nii in {folder}")
                continue
            print(f"Found {len(files)} files for class '{cls}' using {'npy' if use_npy else 'nii'}")
            for fname in files:
                self.samples.append((os.path.join(folder, fname), lab, use_npy))

        if not self.samples:
            raise ValueError("No samples found—did you run preprocessing?")

        self.atlas = atlas.astype(float) / (atlas.sum() + 1e-6)
        self.aug = tio.Compose([
            tio.RandomAffine(scales=(0.9, 1.1), degrees=10),
            tio.RandomNoise()
        ]) if augment else None

    def __len__(self):
        return len(self.samples)

    def sample_centers(self, vol_shape):
        D, H, W = vol_shape
        idxs = np.random.choice(D * H * W, self.N, p=self.atlas)
        coords = []
        for idx in idxs:
            z = idx // (H * W)
            rem = idx % (H * W)
            y = rem // W
            x = rem % W
            coords.append((z, y, x))
        return coords

    def __getitem__(self, idx):
        fpath, label, use_npy = self.samples[idx]
        vol = np.load(fpath) if use_npy else sitk.GetArrayFromImage(sitk.ReadImage(fpath))
        if self.aug:
            vol = self.aug(tio.ScalarImage(tensor=torch.tensor(vol)[None])['data']).numpy()[0]
        centers = self.sample_centers(vol.shape)
        patches = []
        for z, y, x in centers:
            patch = vol[z:z+self.ps, y:y+self.ps, x:x+self.ps]
            if patch.shape != (self.ps, self.ps, self.ps):
                patch = np.pad(patch, [(0, max(0, self.ps - patch.shape[i])) for i in range(3)], mode='constant')
            patches.append(patch)
        patches = torch.tensor(np.stack(patches), dtype=torch.float32).unsqueeze(1)
        return patches, torch.tensor(label, dtype=torch.long)

Overwriting dataset.py


#Model

In [4]:
%%writefile model.py
import torch
import torch.nn as nn
import numpy as np
import os
from sklearn.cluster import KMeans
from torch.utils.data import DataLoader
from dataset import OASIS3DDataset

# Verify atlas exists
if not os.path.exists('/content/atlas.npy'):
    raise FileNotFoundError("Atlas file /content/atlas.npy not found. Run preprocess.py first.")

# Load atlas
atlas = np.load('/content/atlas.npy')

# Define dataset
train_dataset = OASIS3DDataset(
    preproc_dir='/content/oasis_preproc',
    atlas=atlas,
    patch_size=32,
    stride=32,
    N=64,
    augment=True,
    max_files=None
)

# Define kmeans_loader
kmeans_loader = DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True
)

class PatchEncoder3D(nn.Module):
    def __init__(self, emb_dim):
        super(PatchEncoder3D, self).__init__()
        self.net = nn.Sequential(
            nn.Conv3d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(2),
            nn.Conv3d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(2),
            nn.Conv3d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(2),
            nn.Conv3d(128, emb_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool3d(1)
        )

    def forward(self, x):
        out = self.net(x)
        return out.view(x.size(0), -1)

class AttentionMIL(nn.Module):
    def __init__(self, emb_dim):
        super(AttentionMIL, self).__init__()
        self.attn = nn.Sequential(
            nn.Linear(emb_dim, emb_dim // 2),
            nn.Tanh(),
            nn.Linear(emb_dim // 2, 1)
        )

    def forward(self, H):
        w = torch.softmax(self.attn(H), dim=1)
        bag = (w * H).sum(dim=1)
        return bag, w.squeeze(-1)

class PMICL3D(nn.Module):
    def __init__(self, emb_dim, num_classes, proto_tau, protos_per_class, init_kmeans=False, kmeans_loader=None):
        super(PMICL3D, self).__init__()
        self.encoder = PatchEncoder3D(emb_dim)
        self.pool = AttentionMIL(emb_dim)
        self.classifier = nn.Linear(emb_dim, num_classes)
        self.tau = proto_tau
        self.num_classes = num_classes
        self.K = protos_per_class
        self.prototypes = nn.Parameter(torch.randn(num_classes * protos_per_class, emb_dim))
        if init_kmeans and kmeans_loader is not None:
            self._init_prototypes(kmeans_loader)

    def _init_prototypes(self, loader):
        feats_list, labs_list = [], []
        with torch.no_grad():
            for bags, y in loader:
                bags = bags.to(self.encoder.net[0].weight.device)
                B, N, C, D, H, W = bags.shape
                bags_flat = bags.view(B * N, C, D, H, W)
                H_flat = self.encoder(bags_flat)
                feats_list.append(H_flat.cpu().numpy())
                y_repeated = y.repeat_interleave(N).cpu().numpy()
                labs_list.append(y_repeated)
        feats = np.concatenate(feats_list, axis=0)
        labs = np.concatenate(labs_list, axis=0)
        centers = []
        for c in range(self.num_classes):
            feats_c = feats[labs == c]
            if len(feats_c) >= self.K:
                km = KMeans(n_clusters=self.K).fit(feats_c)
                centers.append(km.cluster_centers_)
            else:
                reps = np.tile(feats_c, (int(np.ceil(self.K / len(feats_c))), 1))
                centers.append(reps[:self.K])
        all_centers = np.vstack(centers)
        self.prototypes.data.copy_(torch.tensor(all_centers, dtype=torch.float32).to(self.prototypes.device))

    def forward(self, X):
        B = X.size(0) // self.K
        H_flat = self.encoder(X)
        H = H_flat.view(B, self.K, -1)
        bag, attn = self.pool(H)
        logits = self.classifier(bag)
        return logits, H, attn

    def prototype_loss(self, H, y):
        B, K, D = H.shape
        Hf = H.view(B * K, D)
        sims = torch.matmul(Hf, self.prototypes.t()) / self.tau
        targets = []
        for b in range(B):
            cls = y[b].item()
            idx_start = cls * self.K
            idxs = list(range(idx_start, idx_start + self.K))
            sims_b = torch.matmul(H[b], self.prototypes[idxs].t()) / self.tau
            nearest = sims_b.argmax(dim=1)
            targets += [idxs[n.item()] for n in nearest]
        targets = torch.tensor(targets, device=sims.device)
        return nn.CrossEntropyLoss()(sims, targets)

# Instantiate model
model = PMICL3D(
    emb_dim=256,
    num_classes=4,
    proto_tau=0.5,
    protos_per_class=3,
    init_kmeans=True,
    kmeans_loader=kmeans_loader
).to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

Overwriting model.py


#Training Loop

In [5]:
%%writefile train.py
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader
from dataset import OASIS3DDataset
from model import PMICL3D

# Define hyperparameters (matching preprocess.py)
class Args:
    emb_dim = 256
    protos_per_class = 3
    lambda_proto = 0.5
    patch_size = 32
    stride = 32
    N = 64
    max_files = None
    epochs = 20
    batch_size = 2

args = Args()

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load atlas
if not os.path.exists('/content/atlas.npy'):
    raise FileNotFoundError("Atlas file /content/atlas.npy not found. Run preprocess.py first.")
atlas = np.load('/content/atlas.npy')

# Define dataset
train_dataset = OASIS3DDataset(
    preproc_dir='/content/oasis_preproc',
    atlas=atlas,
    patch_size=args.patch_size,
    stride=args.stride,
    N=args.N,
    augment=True,
    max_files=args.max_files
)

# Define train_loader
train_loader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True
)

# Instantiate model
model = PMICL3D(
    emb_dim=args.emb_dim,
    num_classes=4,
    proto_tau=args.lambda_proto,
    protos_per_class=args.protos_per_class,
    init_kmeans=True,
    kmeans_loader=train_loader  # Reuse train_loader for KMeans initialization
).to(device)

# Training loop
optimizer = optim.Adam(model.parameters(), lr=1e-4)
for ep in range(1, args.epochs + 1):
    model.train()
    total_loss = 0
    all_preds, all_labels = [], []
    for X, y in train_loader:
        X, y = X.to(device), y.to(device)
        logits, H, attn = model(X)
        ce = nn.CrossEntropyLoss()(logits, y)
        proto = model.prototype_loss(H, y)
        loss = ce + args.lambda_proto * proto
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        all_preds += logits.argmax(1).cpu().tolist()
        all_labels += y.cpu().tolist()
    avg_loss = total_loss / len(train_loader)
    acc = accuracy_score(all_labels, all_preds)
    print(f"Epoch {ep}: Loss={avg_loss:.3f}, Acc={acc:.3f}")
    # Save checkpoint
    torch.save(model.state_dict(), f'pmicl_epoch{ep}.pth')

Overwriting train.py


In [6]:
!python preprocess.py

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
 - Data/Very mild Dementia/OAS1_0263_MR1_mpr-1_115.jpg
 - Data/Very mild Dementia/OAS1_0263_MR1_mpr-1_116.jpg
 - Data/Very mild Dementia/OAS1_0263_MR1_mpr-1_117.jpg
 - Data/Very mild Dementia/OAS1_0263_MR1_mpr-1_118.jpg
 - Data/Very mild Dementia/OAS1_0263_MR1_mpr-1_119.jpg
 - Data/Very mild Dementia/OAS1_0263_MR1_mpr-1_120.jpg
 - Data/Very mild Dementia/OAS1_0263_MR1_mpr-1_121.jpg
 - Data/Very mild Dementia/OAS1_0263_MR1_mpr-1_122.jpg
 - Data/Very mild Dementia/OAS1_0263_MR1_mpr-1_123.jpg
 - Data/Very mild Dementia/OAS1_0263_MR1_mpr-1_124.jpg
 - Data/Very mild Dementia/OAS1_0263_MR1_mpr-1_125.jpg
 - Data/Very mild Dementia/OAS1_0263_MR1_mpr-1_126.jpg
 - Data/Very mild Dementia/OAS1_0263_MR1_mpr-1_127.jpg
 - Data/Very mild Dementia/OAS1_0263_MR1_mpr-1_128.jpg
 - Data/Very mild Dementia/OAS1_0263_MR1_mpr-1_129.jpg
 - Data/Very mild Dementia/OAS1_0263_MR1_mpr-1_130.jpg
 - Data/Very mild Dementia/OAS1_0263_MR1_mpr-1_131.jpg


#Evaluation

In [7]:
model.eval(); preds,probs,labs=[],[],[]
with torch.no_grad():
    for X,y in val_loader:
        X,y=X.to(device),y.to(device)
        logits,H,attn=model(X)
        p=torch.softmax(logits,1).cpu().numpy()
        preds+=logits.argmax(1).cpu().tolist(); probs+=p.tolist(); labs+=y.cpu().tolist()
print("Val Acc",accuracy_score(labs,preds))
print("ROC AUC",roc_auc_score(np.eye(model.num_classes)[labs],probs,multi_class='ovr'))

NameError: name 'model' is not defined

#Interpretability

In [None]:
import matplotlib.pyplot as plt

def visualize_attention_3d(volume,attn,centers):
    heatmap=np.zeros_like(volume)
    for (z,y,x),a in zip(centers,attn): heatmap[z:z+args.patch_size,y:y+args.patch_size,x:x+args.patch_size]+=a
    # visualize central slice
    plt.imshow(heatmap[heatmap.shape[0]//2], cmap='hot'); plt.title('3D Attention Heatmap'); plt.show()