In [1]:
import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import scipy.ndimage
from monai.networks.nets import resnet18
from torch.utils.data import Dataset, DataLoader
from sklearn.utils.class_weight import compute_class_weight

In [2]:
def preprocess_nifti(nifti_path, target_shape=(128, 128, 128)):
    # Normalize intensity to [0,1]
    img = (img - np.min(img)) / (np.max(img) - np.min(img) + 1e-8)
    # Resize to target shape: 
    img_resized = scipy.ndimage.zoom(img, np.array(target_shape) / np.array(img.shape), order=1)
    return img_resized

In [3]:
import os

def find_files_with_substring(directory, substring):
    matching_files = [f for f in os.listdir(directory) if substring in f]
    return matching_files

def get_nib_image(adni_file_name):
    return nib.load(adni_file_name).get_fdata()

def visualize_image(nib_image):
    plt.imshow(nib_image[:,:,nib_image.shape[2]//2])
    plt.show()

In [4]:
def get_image_file_names_for_subject(subject_id, date=None):
    os.path.expanduser("~/adni_flat_dataset")
    dir_ = "/home/rittikar-s/adni_flat_dataset"
    files = find_files_with_substring(dir_, subject_id)
    if date:
        files = [file for file in files if date in file]
    file_paths = [f"{dir_}/{file}" for file in files]
    return file_paths

In [5]:
import pandas as pd

df = pd.read_csv("ADNI1_Complete_1Yr_1.5T_1_26_2025.csv")

In [6]:
from monai.networks.nets.vitautoenc import ViTAutoEnc

vit_model = ViTAutoEnc(in_channels=1, patch_size=(16,16,16), img_size=(128,128,128))

def get_vit_embedding(img):
    return vit_model(img)

In [7]:
class NiftiDataset(Dataset):
    def __init__(self, image_paths, labels, target_shape=(128, 128, 128)):
        self.image_paths = image_paths
        self.labels = labels
        self.target_shape = target_shape

    def __len__(self):
        return len(self.image_paths)

    def preprocess_nifti(self, nifti_path):
        img = nib.load(nifti_path).get_fdata()
        
        # Normalize intensity to [0,1]
        img = (img - np.min(img)) / (np.max(img) - np.min(img) + 1e-8)
        
        # Resize to target shape
        img_resized = scipy.ndimage.zoom(img, np.array(self.target_shape) / np.array(img.shape), order=1)
        
        return torch.tensor(img_resized, dtype=torch.float32).unsqueeze(0)  # Add channel dim

    def __getitem__(self, idx):
        image = self.preprocess_nifti(self.image_paths[idx])
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        embedding = get_vit_embedding(image.reshape(1,1,128,128,128))
        return embedding, label

In [8]:
class_to_label = {
    "CN": 0,
    "MCI": 1,
    "AD": 2
}
image_paths = []
labels = []

for i in range(len(df)):
    row = df.iloc[i]
    subject = row["Subject"]
    date = row["Acq Date"]
    date = date.replace("/", "-")
    image_path = get_image_file_names_for_subject(subject, date)[0]
    image_paths.append(image_path)
    labels.append(class_to_label[row["Group"]])

In [9]:
len(image_paths)

2294

In [10]:
len(labels)

2294

In [11]:
from sklearn.model_selection import train_test_split
train_paths, test_paths, train_labels, test_labels = train_test_split(image_paths, labels, test_size=0.3, random_state=42)

In [12]:
# Create train & test datasets
train_dataset = NiftiDataset(train_paths, train_labels)
test_dataset = NiftiDataset(test_paths, test_labels)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, pin_memory=True)

print(f"Train Batches: {len(train_loader)}, Test Batches: {len(test_loader)}")

Train Batches: 402, Test Batches: 173


In [20]:
import h5py

def save_embeddings_hdf5(dataloader, filename):
    """Save embeddings (from list format) and labels incrementally to an HDF5 file."""
    with h5py.File(filename, "w") as f:
        first_batch = True
        for i, (embedding_list, label) in enumerate(dataloader):
            # Extract the last layer embeddings
            embedding_tensor = embedding_list[1][-1]  # Extract final layer embeddings
            embedding_tensor = embedding_tensor.cpu().detach()  # Move to CPU
            
            embedding_numpy = embedding_tensor.numpy()  # Convert to NumPy
            label_numpy = label.cpu().numpy()

            # Reshape embeddings if needed
            embedding_numpy = embedding_numpy.reshape(embedding_numpy.shape[0], -1)  # (4, 1, 512, 768) → (4, 512 * 768)

            if first_batch:
                # Create expandable datasets with correct shape
                f.create_dataset("embeddings", data=embedding_numpy, 
                                 maxshape=(None, embedding_numpy.shape[1]))  # Now 2D
                f.create_dataset("labels", data=label_numpy, maxshape=(None,))
                first_batch = False
            else:
                # Resize and append new embeddings
                f["embeddings"].resize((f["embeddings"].shape[0] + embedding_numpy.shape[0]), axis=0)
                f["embeddings"][-embedding_numpy.shape[0]:] = embedding_numpy

                f["labels"].resize((f["labels"].shape[0] + label_numpy.shape[0]), axis=0)
                f["labels"][-label_numpy.shape[0]:] = label_numpy
            
            print(f"Saved embeddings for batch: {i+1}")

In [21]:
save_embeddings_hdf5(train_loader, "train_embeddings.h5")

Saved embeddings for batch: 1
Saved embeddings for batch: 2
Saved embeddings for batch: 3
Saved embeddings for batch: 4
Saved embeddings for batch: 5
Saved embeddings for batch: 6
Saved embeddings for batch: 7
Saved embeddings for batch: 8
Saved embeddings for batch: 9
Saved embeddings for batch: 10
Saved embeddings for batch: 11
Saved embeddings for batch: 12
Saved embeddings for batch: 13
Saved embeddings for batch: 14
Saved embeddings for batch: 15
Saved embeddings for batch: 16
Saved embeddings for batch: 17
Saved embeddings for batch: 18
Saved embeddings for batch: 19
Saved embeddings for batch: 20
Saved embeddings for batch: 21
Saved embeddings for batch: 22
Saved embeddings for batch: 23
Saved embeddings for batch: 24
Saved embeddings for batch: 25
Saved embeddings for batch: 26
Saved embeddings for batch: 27
Saved embeddings for batch: 28
Saved embeddings for batch: 29
Saved embeddings for batch: 30
Saved embeddings for batch: 31
Saved embeddings for batch: 32
Saved embeddings 

In [46]:
save_embeddings_hdf5(test_loader, "test_embeddings.h5")

Saved embeddings for batch: 1
Saved embeddings for batch: 2
Saved embeddings for batch: 3
Saved embeddings for batch: 4
Saved embeddings for batch: 5
Saved embeddings for batch: 6
Saved embeddings for batch: 7


KeyboardInterrupt: 

In [13]:
!python3 -V

Python 3.10.16


In [41]:
def load_embeddings_hdf5(filename, batch_size=32):
    with h5py.File(filename, "r") as f:
        num_samples = f["embeddings"].shape[0]  # Total samples
        for i in range(0, num_samples, batch_size):
            X_batch = torch.tensor(f["embeddings"][i : i + batch_size], dtype=torch.float32)
            y_batch = torch.tensor(f["labels"][i : i + batch_size], dtype=torch.long)
            yield X_batch, y_batch

In [25]:
from sklearn.ensemble import RandomForestClassifier

clf = RandomForestClassifier(n_estimators=100, random_state=42)

In [26]:
for X_batch, y_batch in load_embeddings_hdf5("train_embeddings.h5"):
    clf.fit(X_batch, y_batch)

In [27]:
y_true, y_pred = [], []
for X_batch, y_batch in load_embeddings_hdf5("test_embeddings.h5"):
    y_pred.extend(clf.predict(X_batch))
    y_true.extend(y_batch)

In [29]:
from sklearn.metrics import classification_report

print(classification_report(y_true, y_pred, digits=4))

              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000       213
           1     0.4822    0.9790    0.6462       333
           2     0.3077    0.0280    0.0513       143

    accuracy                         0.4790       689
   macro avg     0.2633    0.3357    0.2325       689
weighted avg     0.2969    0.4790    0.3230       689



Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.


In [34]:
import torch.nn as nn
import torch.optim as optim

class MLPClassifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(MLPClassifier, self).__init__()
        self.fc1 = nn.Linear(input_dim, 512)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, num_classes)
    
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

In [42]:
input_dim = next(load_embeddings_hdf5("train_embeddings.h5"))[0].shape[1]  # Get feature size
num_classes = 3  # Adjust based on labels
model = MLPClassifier(input_dim, num_classes)

In [43]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [45]:
for epoch in range(10):
    for X_batch, y_batch in load_embeddings_hdf5("train_embeddings.h5", batch_size=4):
        optimizer.zero_grad()
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()
    
    print(f"Epoch [{epoch+1}/{10}], Loss: {loss.item():.4f}")

Epoch [1/10], Loss: 0.7554
Epoch [2/10], Loss: 0.7504
Epoch [3/10], Loss: 0.7506
Epoch [4/10], Loss: 0.7511
Epoch [5/10], Loss: 0.7516
Epoch [6/10], Loss: 0.7522
Epoch [7/10], Loss: 0.7527
Epoch [8/10], Loss: 0.7532
Epoch [9/10], Loss: 0.7537
Epoch [10/10], Loss: 0.7542


In [47]:
# Inference
y_true, y_pred = [], []

with torch.no_grad():
    for X_batch, y_batch in load_embeddings_hdf5("test_embeddings.h5", batch_size=32):
        outputs = model(X_batch)
        predicted_labels = torch.argmax(outputs, dim=1)

        y_true.extend(y_batch.numpy())
        y_pred.extend(predicted_labels.numpy())

# Compute accuracy
from sklearn.metrics import accuracy_score
accuracy = accuracy_score(y_true, y_pred)
print(f"Test Accuracy: {accuracy:.4f}")

Test Accuracy: 0.4286
