In [1]:
import nibabel as nib
import numpy as np
import cv2
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import scipy.ndimage
from monai.networks.nets import resnet18
from torch.utils.data import Dataset, DataLoader
from sklearn.utils.class_weight import compute_class_weight

# Utility fns

In [2]:
import os

def find_files_with_substring(directory, substring):
    matching_files = [f for f in os.listdir(directory) if substring in f]
    return matching_files

def get_nib_image(adni_file_name):
    return nib.load(adni_file_name).get_fdata()

def visualize_image(nib_image):
    plt.imshow(nib_image[:,:,nib_image.shape[2]//2])
    plt.show()

In [3]:
def get_image_file_names_for_subject(subject_id, date=None):
    os.path.expanduser("~/adni_flat_dataset/adni_flat_dataset")
    dir_ = "/home/rittikar-s/adni_flat_dataset/adni_flat_dataset"
    files = find_files_with_substring(dir_, subject_id)
    if date:
        files = [file for file in files if date in file]
    file_paths = [f"{dir_}/{file}" for file in files]
    return file_paths

# Define 3D ViT Model

In [4]:
from monai.networks.nets.vitautoenc import ViTAutoEnc

vit_model = ViTAutoEnc(in_channels=1, patch_size=(16,16,16), img_size=(128,128,128))

def get_vit_embedding(img):
    return vit_model(img)

# Define Nifti Dataset for ViT

In [5]:
class NiftiDataset(Dataset):
    def __init__(self, image_paths, labels, image_ids, target_shape=(128, 128, 128)):
        self.image_paths = image_paths
        self.labels = labels
        self.target_shape = target_shape
        self.image_ids = image_ids

    def __len__(self):
        return len(self.image_paths)

    def apply_multi_bit_plane_slicing(self, img_3d, bit_planes=[6, 7]):
        """
        Apply multi-bit-plane slicing to a 3D MRI image.
        
        Args:
            img_3d (torch.Tensor or np.ndarray): A 3D image (shape: [depth, height, width]).
            bit_planes (list of int): List of bit-planes to extract (0 = LSB, 7 = MSB).
        
        Returns:
            numpy.ndarray: The combined bit-plane image.
        """
        X, Y, Z = img_3d.shape
        processed_img = np.zeros_like(img_3d, dtype=np.uint8)
        
        if isinstance(img_3d, torch.Tensor):
            img_3d = img_3d.cpu().numpy()  # Convert to NumPy if it's a tensor
    
        for z in range(Z):  # Iterate through each slice dynamically
            slice_img = img_3d[:, :, z]  # Extract the 2D slice
            
            # Ensure the slice is in 8-bit format
            slice_img = (slice_img / np.max(slice_img) * 255).astype(np.uint8)  
            
            bit_sliced = np.zeros_like(slice_img, dtype=np.uint8)
    
            # Combine selected bit planes
            for bit in bit_planes:
                bit_sliced |= ((slice_img >> bit) & 1) << bit  
    
            processed_img[:, :, z] = bit_sliced  # Store processed slice back
    
        return processed_img
    
    def preprocess_nifti(self, nifti_path, threshold_value=80):
        # Load the NIfTI file
        img = nib.load(nifti_path).get_fdata()
    
        # Resize the image to the target shape
        img_resized = scipy.ndimage.zoom(img, np.array(self.target_shape) / np.array(img.shape), order=1)
    
        # Normalize intensity to [0, 1] for neural network
        img_normalized = (img_resized - np.min(img_resized)) / (np.max(img_resized) - np.min(img_resized) + 1e-8)
        # img_normalized = (img_normalized * 255).astype(np.uint8)  # You can adjust this based on model input needs
    
        # Create an empty array to store the processed result
        processed_img = self.apply_multi_bit_plane_slicing(img_normalized, [5, 6, 7])
    
        # Convert the processed image to a tensor and add a channel dimension (assuming 1 channel)
        return torch.tensor(processed_img, dtype=torch.float32).unsqueeze(0)  # Add batch dimension
    
    def __getitem__(self, idx):
        image = self.preprocess_nifti(self.image_paths[idx])
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        embedding = get_vit_embedding(image.reshape(1,1,128,128,128))
        return embedding, label, self.image_ids

In [6]:
# get_image_file_for_image_id

In [7]:
import pandas as pd

df = pd.read_csv("~/dip_project/ADNI1_Final_With_Biomarkers.csv")

In [8]:
df.head()

Unnamed: 0,Image Data ID,Subject,Group,Sex,Age,Visit,Modality,Description,Type,Acq Date,...,VISCODE_y.1,HMSCORE,VISCODE_x.2,NPISCORE,VISCODE_y.2,GDTOTAL,VISCODE2,ABETA42,TAU,PTAU
0,I97327,941_S_1311,MCI,M,69,sc,MRI,MPR; GradWarp; B1 Correction; N3; Scaled,Processed,3/02/2007,...,sc,1.0,,,sc,1.0,,,,
1,I112538,941_S_1311,MCI,M,70,m12,MRI,MPR; GradWarp; B1 Correction; N3; Scaled,Processed,6/01/2008,...,,,m12,4.0,m12,3.0,,,,
2,I97341,941_S_1311,MCI,M,70,m06,MRI,MPR-R; GradWarp; B1 Correction; N3; Scaled,Processed,9/27/2007,...,,,m06,3.0,,,,,,
3,I63874,941_S_1202,CN,M,78,sc,MRI,MPR-R; GradWarp; B1 Correction; N3; Scaled,Processed,1/30/2007,...,sc,0.0,,,sc,0.0,,,,
4,I75150,941_S_1202,CN,M,78,m06,MRI,MPR; GradWarp; B1 Correction; N3; Scaled,Processed,8/24/2007,...,,,m06,2.0,,,,,,


In [9]:
class_to_label = {
    "CN": 0,
    "MCI": 1,
    "AD": 2
}
image_paths = []
labels = []
image_ids = []

for i in range(len(df)):
    row = df.iloc[i]
    subject = row["Subject"]
    date = row["Acq Date"]
    date = date.replace("/", "-")
    image_id = row["Image Data ID"]
    image_path = get_image_file_names_for_subject(subject, date)[0]
    image_paths.append(image_path)
    image_ids.append(image_id)
    labels.append(class_to_label[row["Group"]])

In [10]:
len(image_paths)

2294

In [11]:
len(labels)

2294

In [12]:
len(image_ids)

2294

In [13]:
from sklearn.model_selection import train_test_split
train_paths, test_paths, train_labels, test_labels, train_image_ids, test_image_ids = train_test_split(image_paths, labels, image_ids, test_size=0.3, random_state=42, stratify=labels)
val_paths, test_paths, val_labels, test_labels, val_image_ids, test_image_ids = train_test_split(test_paths, test_labels, test_image_ids, test_size=0.5, random_state=42, stratify=test_labels)

In [14]:
# Create train & test datasets
train_dataset = NiftiDataset(train_paths, train_labels, train_image_ids)
val_dataset = NiftiDataset(val_paths, val_labels, val_image_ids)
test_dataset = NiftiDataset(test_paths, test_labels, test_image_ids)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, pin_memory=True)

print(f"Train Batches: {len(train_loader)}, Val Batches: {len(val_loader)}, Test Batches: {len(test_loader)}")

Train Batches: 402, Val Batches: 86, Test Batches: 87


## ViT-BPS 3D

In [15]:
import h5py

def save_embeddings_hdf5(dataloader, filename):
    """Save embeddings (from list format) and labels incrementally to an HDF5 file."""
    with h5py.File(filename, "w") as f:
        first_batch = True
        for i, (embedding_list, label, image_id) in enumerate(dataloader):
            # Extract the last layer embeddings
            embedding_tensor = embedding_list[1][-1]  # Extract final layer embeddings
            embedding_tensor = embedding_tensor.cpu().detach()  # Move to CPU
            
            embedding_numpy = embedding_tensor.numpy()  # Convert to NumPy
            label_numpy = label.cpu().numpy()

            # Reshape embeddings if needed
            embedding_numpy = embedding_numpy.reshape(embedding_numpy.shape[0], -1)  # (4, 1, 512, 768) → (4, 512 * 768)

            if first_batch:
                # Create expandable datasets with correct shape
                f.create_dataset("embeddings", data=embedding_numpy, 
                                 maxshape=(None, embedding_numpy.shape[1]))  # Now 2D
                f.create_dataset("labels", data=label_numpy, maxshape=(None,))
                first_batch = False
            else:
                # Resize and append new embeddings
                f["embeddings"].resize((f["embeddings"].shape[0] + embedding_numpy.shape[0]), axis=0)
                f["embeddings"][-embedding_numpy.shape[0]:] = embedding_numpy

                f["labels"].resize((f["labels"].shape[0] + label_numpy.shape[0]), axis=0)
                f["labels"][-label_numpy.shape[0]:] = label_numpy
            
            print(f"Saved embeddings for batch: {i+1}")

In [16]:
save_embeddings_hdf5(train_loader, "train_embeddings.h5")
save_embeddings_hdf5(val_loader, "val_embeddings.h5")
save_embeddings_hdf5(test_loader, "test_embeddings.h5")

invalid value encountered in divide
invalid value encountered in cast


Saved embeddings for batch: 1
Saved embeddings for batch: 2
Saved embeddings for batch: 3
Saved embeddings for batch: 4
Saved embeddings for batch: 5
Saved embeddings for batch: 6
Saved embeddings for batch: 7
Saved embeddings for batch: 8
Saved embeddings for batch: 9
Saved embeddings for batch: 10
Saved embeddings for batch: 11
Saved embeddings for batch: 12
Saved embeddings for batch: 13
Saved embeddings for batch: 14
Saved embeddings for batch: 15
Saved embeddings for batch: 16
Saved embeddings for batch: 17
Saved embeddings for batch: 18
Saved embeddings for batch: 19
Saved embeddings for batch: 20
Saved embeddings for batch: 21
Saved embeddings for batch: 22
Saved embeddings for batch: 23
Saved embeddings for batch: 24
Saved embeddings for batch: 25
Saved embeddings for batch: 26
Saved embeddings for batch: 27
Saved embeddings for batch: 28
Saved embeddings for batch: 29
Saved embeddings for batch: 30
Saved embeddings for batch: 31
Saved embeddings for batch: 32
Saved embeddings 

In [19]:
DEVICE = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

In [20]:
def load_embeddings_hdf5(filename, batch_size=4):
    with h5py.File(filename, "r") as f:
        num_samples = f["embeddings"].shape[0]  # Total samples
        for i in range(0, num_samples, batch_size):
            X_batch = torch.tensor(f["embeddings"][i : i + batch_size], dtype=torch.float32).to(DEVICE)
            y_batch = torch.tensor(f["labels"][i : i + batch_size], dtype=torch.long).to(DEVICE)
            yield X_batch, y_batch

In [21]:
import torch.nn as nn
import torch.optim as optim

class MLPClassifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(MLPClassifier, self).__init__()
        self.fc1 = nn.Linear(input_dim, 1024)
        self.batch_norm1 = nn.BatchNorm1d(1024)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.5)
        
        self.fc2 = nn.Linear(1024, 512)
        self.batch_norm2 = nn.BatchNorm1d(512)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.5)

        self.fc3 = nn.Linear(512, 256)
        self.batch_norm3 = nn.BatchNorm1d(256)
        self.relu3 = nn.ReLU()
        self.dropout3 = nn.Dropout(0.5)
        
        self.fc4 = nn.Linear(256, num_classes)  # Final layer

    def forward(self, x, return_embeddings=False):
        x = self.fc1(x)
        if x.shape[0] > 1:  # Apply BatchNorm only if batch size > 1
            x = self.batch_norm1(x)
        x = self.relu1(x)
        x = self.dropout1(x)
    
        x = self.fc2(x)
        if x.shape[0] > 1:
            x = self.batch_norm2(x)
        x = self.relu2(x)
        x = self.dropout2(x)
    
        x = self.fc3(x)
        if x.shape[0] > 1:
            x = self.batch_norm3(x)
        x = self.relu3(x)
        x = self.dropout3(x)

        if return_embeddings:
            return x
        
        x = self.fc4(x)
        return x

In [22]:
input_dim = next(load_embeddings_hdf5("train_embeddings.h5"))[0].shape[1]  # Get feature size
num_classes = 3  # Adjust based on labels
model = MLPClassifier(input_dim, num_classes).to(DEVICE)

In [23]:
model

MLPClassifier(
  (fc1): Linear(in_features=393216, out_features=1024, bias=True)
  (batch_norm1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): ReLU()
  (dropout1): Dropout(p=0.5, inplace=False)
  (fc2): Linear(in_features=1024, out_features=512, bias=True)
  (batch_norm2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu2): ReLU()
  (dropout2): Dropout(p=0.5, inplace=False)
  (fc3): Linear(in_features=512, out_features=256, bias=True)
  (batch_norm3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu3): ReLU()
  (dropout3): Dropout(p=0.5, inplace=False)
  (fc4): Linear(in_features=256, out_features=3, bias=True)
)

In [24]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn.functional as F
from tqdm import tqdm  # Progress tracking

# Compute class weights
def compute_class_weights(y_train):
    class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(y_train), y=y_train)
    return torch.tensor(class_weights, dtype=torch.float).to(DEVICE)

y_train = np.concatenate([y.cpu().numpy() for _, y in load_embeddings_hdf5("train_embeddings.h5", batch_size=32)])
class_weights = compute_class_weights(y_train)

# Define optimizer, scheduler, and scaler
optimizer = optim.Adam(model.parameters(), lr=1e-5)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)
scaler = torch.cuda.amp.GradScaler()

NUM_GPUS = 1
num_epochs = 100
patience = 15  # Early stopping patience
best_val_loss = float("inf")
epochs_without_improvement = 0

for epoch in range(num_epochs):
    model.train()
    train_loss, correct, total = 0.0, 0, 0

    # Training loop
    train_loader = load_embeddings_hdf5("train_embeddings.h5", batch_size=16 * NUM_GPUS)  # Adjust batch size
    train_bar = tqdm(train_loader, total=len(y_train) // (16 * NUM_GPUS), desc=f"Epoch {epoch+1}/{num_epochs} Train")
    for X_batch, y_batch in train_bar:
        X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast():  # Mixed precision
            outputs = model(X_batch)
            loss = F.cross_entropy(outputs, y_batch)
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        correct += predicted.eq(y_batch).sum().item()
        total += y_batch.size(0)
        
        train_bar.set_postfix(loss=loss.item(), acc=100 * correct / total)

    train_accuracy = 100 * correct / total
    avg_train_loss = train_loss / total

    # Validation loop
    model.eval()
    val_loss, val_correct, val_total = 0.0, 0, 0
    val_loader = load_embeddings_hdf5("val_embeddings.h5", batch_size=16 * NUM_GPUS)
    val_bar = tqdm(val_loader, total=len(y_train) // (16 * NUM_GPUS), desc="Validation")
    with torch.no_grad():
        for X_val, y_val in val_bar:
            X_val, y_val = X_val.to(DEVICE), y_val.to(DEVICE)
            outputs = model(X_val)
            loss = F.cross_entropy(outputs, y_val)

            val_loss += loss.item()
            _, predicted = outputs.max(1)
            val_correct += predicted.eq(y_val).sum().item()
            val_total += y_val.size(0)

            val_bar.set_postfix(loss=loss.item(), acc=100 * val_correct / val_total)
    
    val_accuracy = 100 * val_correct / val_total
    avg_val_loss = val_loss / val_total

    # Check for early stopping
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        epochs_without_improvement = 0  # Reset counter
        torch.save(model.state_dict(), "best_model.pth")  # Save best model
    else:
        epochs_without_improvement += 1

    # Stop if no improvement for 10 epochs
    if epochs_without_improvement >= patience:
        print(f"Early stopping at epoch {epoch+1}. No improvement in validation loss for {patience} epochs.")
        break

    # Learning rate scheduler step
    scheduler.step(avg_val_loss)

    print(f"Epoch [{epoch+1}/{num_epochs}]")
    print(f"  Train Loss: {avg_train_loss:.4f} | Train Acc: {train_accuracy:.2f}%")
    print(f"  Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.2f}%")

`torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.                         | 0/100 [00:00<?, ?it/s]
Epoch 1/100 Train: 101it [00:11,  8.58it/s, acc=35, loss=1.18]                                                                                              
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 50.01it/s, acc=49.1, loss=0.995]


Epoch [1/100]
  Train Loss: 0.0729 | Train Acc: 35.02%
  Val Loss: 0.0662 | Val Acc: 49.13%


Epoch 2/100 Train: 101it [00:11,  8.65it/s, acc=43.7, loss=1.03]                                                                                            
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 50.59it/s, acc=51.5, loss=0.962]


Epoch [2/100]
  Train Loss: 0.0670 | Train Acc: 43.68%
  Val Loss: 0.0652 | Val Acc: 51.45%


Epoch 3/100 Train: 101it [00:11,  8.66it/s, acc=50.2, loss=0.693]                                                                                           
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 50.63it/s, acc=56.7, loss=0.928]


Epoch [3/100]
  Train Loss: 0.0628 | Train Acc: 50.22%
  Val Loss: 0.0637 | Val Acc: 56.69%


Epoch 4/100 Train: 101it [00:11,  8.64it/s, acc=55.1, loss=0.923]                                                                                           
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.87it/s, acc=58.7, loss=0.848]


Epoch [4/100]
  Train Loss: 0.0596 | Train Acc: 55.14%
  Val Loss: 0.0623 | Val Acc: 58.72%


Epoch 5/100 Train: 101it [00:11,  8.65it/s, acc=58.8, loss=0.923]                                                                                           
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.61it/s, acc=60.2, loss=0.836]


Epoch [5/100]
  Train Loss: 0.0575 | Train Acc: 58.75%
  Val Loss: 0.0613 | Val Acc: 60.17%


Epoch 6/100 Train: 101it [00:11,  8.65it/s, acc=64.4, loss=0.666]                                                                                           
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 53.37it/s, acc=63.1, loss=0.805]


Epoch [6/100]
  Train Loss: 0.0528 | Train Acc: 64.36%
  Val Loss: 0.0602 | Val Acc: 63.08%


Epoch 7/100 Train: 101it [00:11,  8.66it/s, acc=67.8, loss=0.634]                                                                                           
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.22it/s, acc=63.4, loss=0.788]


Epoch [7/100]
  Train Loss: 0.0504 | Train Acc: 67.79%
  Val Loss: 0.0590 | Val Acc: 63.37%


Epoch 8/100 Train: 101it [00:11,  8.64it/s, acc=72.5, loss=0.574]                                                                                           
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.53it/s, acc=63.1, loss=0.729]


Epoch [8/100]
  Train Loss: 0.0470 | Train Acc: 72.46%
  Val Loss: 0.0582 | Val Acc: 63.08%


Epoch 9/100 Train: 101it [00:11,  8.65it/s, acc=74.8, loss=0.579]                                                                                           
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.71it/s, acc=62.5, loss=0.742]


Epoch [9/100]
  Train Loss: 0.0441 | Train Acc: 74.83%
  Val Loss: 0.0577 | Val Acc: 62.50%


Epoch 10/100 Train: 101it [00:11,  8.68it/s, acc=79.6, loss=0.49]                                                                                           
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.04it/s, acc=61.9, loss=0.718]


Epoch [10/100]
  Train Loss: 0.0399 | Train Acc: 79.63%
  Val Loss: 0.0569 | Val Acc: 61.92%


Epoch 11/100 Train: 101it [00:11,  8.64it/s, acc=82.9, loss=0.463]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.88it/s, acc=62.8, loss=0.713]


Epoch [11/100]
  Train Loss: 0.0369 | Train Acc: 82.93%
  Val Loss: 0.0562 | Val Acc: 62.79%


Epoch 12/100 Train: 101it [00:11,  8.68it/s, acc=85.6, loss=0.305]                                                                                          
Validation:  22%|██████████████████▉                                                                   | 22/100 [00:00<00:01, 50.67it/s, acc=64, loss=0.653]


Epoch [12/100]
  Train Loss: 0.0343 | Train Acc: 85.61%
  Val Loss: 0.0548 | Val Acc: 63.95%


Epoch 13/100 Train: 101it [00:11,  8.67it/s, acc=89.9, loss=0.505]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 50.91it/s, acc=63.1, loss=0.646]


Epoch [13/100]
  Train Loss: 0.0300 | Train Acc: 89.91%
  Val Loss: 0.0540 | Val Acc: 63.08%


Epoch 14/100 Train: 101it [00:11,  8.65it/s, acc=89.9, loss=0.284]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.01it/s, acc=63.1, loss=0.604]


Epoch [14/100]
  Train Loss: 0.0285 | Train Acc: 89.91%
  Val Loss: 0.0531 | Val Acc: 63.08%


Epoch 15/100 Train: 101it [00:11,  8.66it/s, acc=93.1, loss=0.212]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.25it/s, acc=64.2, loss=0.548]


Epoch [15/100]
  Train Loss: 0.0255 | Train Acc: 93.15%
  Val Loss: 0.0521 | Val Acc: 64.24%


Epoch 16/100 Train: 101it [00:11,  8.66it/s, acc=93.2, loss=0.188]                                                                                          
Validation:  22%|██████████████████▉                                                                   | 22/100 [00:00<00:01, 50.15it/s, acc=64, loss=0.503]


Epoch [16/100]
  Train Loss: 0.0235 | Train Acc: 93.21%
  Val Loss: 0.0518 | Val Acc: 63.95%


Epoch 17/100 Train: 101it [00:11,  8.63it/s, acc=94.4, loss=0.296]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 53.59it/s, acc=65.4, loss=0.509]


Epoch [17/100]
  Train Loss: 0.0216 | Train Acc: 94.39%
  Val Loss: 0.0510 | Val Acc: 65.41%


Epoch 18/100 Train: 101it [00:11,  8.65it/s, acc=96.2, loss=0.198]                                                                                          
Validation:  22%|██████████████████▉                                                                   | 22/100 [00:00<00:01, 51.92it/s, acc=66, loss=0.455]


Epoch [18/100]
  Train Loss: 0.0185 | Train Acc: 96.20%
  Val Loss: 0.0509 | Val Acc: 65.99%


Epoch 19/100 Train: 101it [00:11,  8.65it/s, acc=96.6, loss=0.229]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.26it/s, acc=66.3, loss=0.452]


Epoch [19/100]
  Train Loss: 0.0168 | Train Acc: 96.57%
  Val Loss: 0.0503 | Val Acc: 66.28%


Epoch 20/100 Train: 101it [00:11,  8.64it/s, acc=98.1, loss=0.155]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.68it/s, acc=65.7, loss=0.446]


Epoch [20/100]
  Train Loss: 0.0152 | Train Acc: 98.07%
  Val Loss: 0.0499 | Val Acc: 65.70%


Epoch 21/100 Train: 101it [00:11,  8.65it/s, acc=97.6, loss=0.184]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.94it/s, acc=64.8, loss=0.463]


Epoch [21/100]
  Train Loss: 0.0142 | Train Acc: 97.63%
  Val Loss: 0.0499 | Val Acc: 64.83%


Epoch 22/100 Train: 101it [00:11,  8.65it/s, acc=98.6, loss=0.091]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 50.55it/s, acc=63.7, loss=0.468]


Epoch [22/100]
  Train Loss: 0.0128 | Train Acc: 98.63%
  Val Loss: 0.0496 | Val Acc: 63.66%


Epoch 23/100 Train: 101it [00:11,  8.66it/s, acc=99.1, loss=0.0808]                                                                                         
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 53.48it/s, acc=64.8, loss=0.429]


Epoch [23/100]
  Train Loss: 0.0113 | Train Acc: 99.07%
  Val Loss: 0.0502 | Val Acc: 64.83%


Epoch 24/100 Train: 101it [00:11,  8.65it/s, acc=99.1, loss=0.128]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.94it/s, acc=63.7, loss=0.409]


Epoch [24/100]
  Train Loss: 0.0109 | Train Acc: 99.13%
  Val Loss: 0.0501 | Val Acc: 63.66%


Epoch 25/100 Train: 101it [00:11,  8.67it/s, acc=99, loss=0.0783]                                                                                           
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.18it/s, acc=66.3, loss=0.435]


Epoch [25/100]
  Train Loss: 0.0101 | Train Acc: 99.00%
  Val Loss: 0.0491 | Val Acc: 66.28%


Epoch 26/100 Train: 101it [00:11,  8.64it/s, acc=99.6, loss=0.0647]                                                                                         
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.91it/s, acc=67.2, loss=0.436]


Epoch [26/100]
  Train Loss: 0.0090 | Train Acc: 99.56%
  Val Loss: 0.0493 | Val Acc: 67.15%


Epoch 27/100 Train: 101it [00:11,  8.64it/s, acc=99.4, loss=0.1]                                                                                            
Validation:  22%|██████████████████▉                                                                   | 22/100 [00:00<00:01, 52.16it/s, acc=66, loss=0.445]


Epoch [27/100]
  Train Loss: 0.0084 | Train Acc: 99.44%
  Val Loss: 0.0496 | Val Acc: 65.99%


Epoch 28/100 Train: 101it [00:11,  8.66it/s, acc=99.6, loss=0.121]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.12it/s, acc=65.7, loss=0.472]


Epoch [28/100]
  Train Loss: 0.0078 | Train Acc: 99.56%
  Val Loss: 0.0493 | Val Acc: 65.70%


Epoch 29/100 Train: 101it [00:11,  8.61it/s, acc=99.6, loss=0.103]                                                                                          
Validation:  22%|██████████████████▉                                                                   | 22/100 [00:00<00:01, 51.19it/s, acc=66, loss=0.445]


Epoch [29/100]
  Train Loss: 0.0071 | Train Acc: 99.63%
  Val Loss: 0.0494 | Val Acc: 65.99%


Epoch 30/100 Train: 101it [00:11,  8.66it/s, acc=99.6, loss=0.0539]                                                                                         
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.38it/s, acc=65.7, loss=0.488]


Epoch [30/100]
  Train Loss: 0.0071 | Train Acc: 99.56%
  Val Loss: 0.0500 | Val Acc: 65.70%


Epoch 31/100 Train: 101it [00:11,  8.63it/s, acc=99.7, loss=0.11]                                                                                           
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.06it/s, acc=66.6, loss=0.458]


Epoch [31/100]
  Train Loss: 0.0066 | Train Acc: 99.69%
  Val Loss: 0.0502 | Val Acc: 66.57%


Epoch 32/100 Train: 101it [00:11,  8.64it/s, acc=99.6, loss=0.0603]                                                                                         
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.40it/s, acc=65.4, loss=0.466]


Epoch [32/100]
  Train Loss: 0.0061 | Train Acc: 99.63%
  Val Loss: 0.0497 | Val Acc: 65.41%


Epoch 33/100 Train: 101it [00:11,  8.65it/s, acc=99.8, loss=0.0372]                                                                                         
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 50.64it/s, acc=65.7, loss=0.425]


Epoch [33/100]
  Train Loss: 0.0058 | Train Acc: 99.81%
  Val Loss: 0.0498 | Val Acc: 65.70%


Epoch 34/100 Train: 101it [00:11,  8.64it/s, acc=99.8, loss=0.0414]                                                                                         
Validation:  22%|██████████████████▋                                                                  | 22/100 [00:00<00:01, 50.41it/s, acc=65.7, loss=0.46]


Epoch [34/100]
  Train Loss: 0.0055 | Train Acc: 99.81%
  Val Loss: 0.0501 | Val Acc: 65.70%


Epoch 35/100 Train: 101it [00:11,  8.64it/s, acc=99.8, loss=0.0199]                                                                                         
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 53.48it/s, acc=66.6, loss=0.449]


Epoch [35/100]
  Train Loss: 0.0053 | Train Acc: 99.75%
  Val Loss: 0.0501 | Val Acc: 66.57%


Epoch 36/100 Train: 101it [00:11,  8.66it/s, acc=99.9, loss=0.0709]                                                                                         
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.26it/s, acc=66.3, loss=0.457]


Epoch [36/100]
  Train Loss: 0.0051 | Train Acc: 99.88%
  Val Loss: 0.0501 | Val Acc: 66.28%


Epoch 37/100 Train: 101it [00:11,  8.69it/s, acc=99.7, loss=0.0576]                                                                                         
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 50.67it/s, acc=66.3, loss=0.441]


Epoch [37/100]
  Train Loss: 0.0050 | Train Acc: 99.69%
  Val Loss: 0.0502 | Val Acc: 66.28%


Epoch 38/100 Train: 101it [00:11,  8.63it/s, acc=99.9, loss=0.0293]                                                                                         
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.87it/s, acc=66.6, loss=0.441]


Epoch [38/100]
  Train Loss: 0.0048 | Train Acc: 99.88%
  Val Loss: 0.0502 | Val Acc: 66.57%


Epoch 39/100 Train: 101it [00:11,  8.66it/s, acc=99.9, loss=0.0639]                                                                                         
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.48it/s, acc=66.3, loss=0.428]


Epoch [39/100]
  Train Loss: 0.0047 | Train Acc: 99.88%
  Val Loss: 0.0502 | Val Acc: 66.28%


Epoch 40/100 Train: 101it [00:11,  8.72it/s, acc=99.9, loss=0.0489]                                                                                         
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 50.77it/s, acc=66.9, loss=0.438]

Early stopping at epoch 40. No improvement in validation loss for 15 epochs.





In [25]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Inference
y_true, y_pred = [], []

with torch.no_grad():
    for X_batch, y_batch in load_embeddings_hdf5("test_embeddings.h5", batch_size=4):
        outputs = model(X_batch)
        predicted_labels = torch.argmax(outputs, dim=1)

        y_true.extend(y_batch.cpu().numpy())
        y_pred.extend(predicted_labels.cpu().numpy())

# Compute metrics
accuracy = accuracy_score(y_true, y_pred)
precision = precision_score(y_true, y_pred, average="weighted")  # "weighted" accounts for class imbalance
recall = recall_score(y_true, y_pred, average="weighted")
f1 = f1_score(y_true, y_pred, average="weighted")

print(f"Test Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")

Test Accuracy: 0.6261
Precision: 0.6283
Recall: 0.6261
F1 Score: 0.6091


In [26]:
from sklearn.metrics import classification_report
print(classification_report(y_true, y_pred))

              precision    recall  f1-score   support

           0       0.65      0.52      0.58       106
           1       0.62      0.82      0.70       167
           2       0.62      0.33      0.43        72

    accuracy                           0.63       345
   macro avg       0.63      0.56      0.57       345
weighted avg       0.63      0.63      0.61       345



# Ft-Transformer

In [27]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, classification_report
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Import the FTTransformer model from rtdl_revisiting_models.
from rtdl_revisiting_models import FTTransformer

In [28]:
numerical_features = ["Age", "CDGLOBAL", "CDRSB", "MMSCORE", "HMSCORE", "NPISCORE", "GDTOTAL"]
categorical_features = ["GENOTYPE"]
label = "Group"

In [29]:
# Subset dataframes to desired columns.
cols = numerical_features + categorical_features + [label]
train_data = df[df["Image Data ID"].isin(train_image_ids)][cols]
val_data   = df[df["Image Data ID"].isin(val_image_ids)][cols]
test_data  = df[df["Image Data ID"].isin(test_image_ids)][cols]

In [30]:
train_data.to_csv("train_data.csv")
val_data.to_csv("val_data.csv")
test_data.to_csv("test_data.csv")

In [31]:
# Handle missingness for numerical features.
cols_with_missing = ["CDRSB", "MMSCORE", "HMSCORE", "NPISCORE", "GDTOTAL"]
for col in cols_with_missing:
    for df_ in [train_data, val_data, test_data]:
        df_[col + "_is_missing"] = df_[col].isnull().astype(int)
        df_[col] = df_[col].fillna(-999)

In [32]:
# Extend continuous features to include missing indicators.
numerical_features_extended = numerical_features + [col + "_is_missing" for col in cols_with_missing]

In [33]:
# Encode categorical features using LabelEncoder.
cat_encoders = {}
for col in categorical_features:
    le = LabelEncoder()
    train_data[col] = le.fit_transform(train_data[col].astype(str))
    val_data[col]   = le.transform(val_data[col].astype(str))
    test_data[col]  = le.transform(test_data[col].astype(str))
    cat_encoders[col] = le

In [34]:
# Encode the target.
label_encoder = LabelEncoder()
train_data[label] = label_encoder.fit_transform(train_data[label])
val_data[label]   = label_encoder.transform(val_data[label])
test_data[label]  = label_encoder.transform(test_data[label])
num_classes = len(label_encoder.classes_)  # e.g., 3 for classification

In [35]:
##########################################
# 2. Prepare NumPy Arrays and Create Dataset
##########################################
# Continuous features (including missing indicators).
X_train_cont = train_data[numerical_features_extended].values.astype(np.float32)
X_val_cont   = val_data[numerical_features_extended].values.astype(np.float32)
X_test_cont  = test_data[numerical_features_extended].values.astype(np.float32)

# Categorical features.
X_train_cat = train_data[categorical_features].values.astype(np.int64)
X_val_cat   = val_data[categorical_features].values.astype(np.int64)
X_test_cat  = test_data[categorical_features].values.astype(np.int64)

# Labels.
y_train = train_data[label].values.astype(np.int64)
y_val   = val_data[label].values.astype(np.int64)
y_test  = test_data[label].values.astype(np.int64)

In [36]:
# Create a simple PyTorch Dataset.
class TabularDataset(Dataset):
    def __init__(self, cont, cat, labels):
        self.cont = cont
        self.cat = cat
        self.labels = labels
        
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        return {
            "cont": torch.tensor(self.cont[idx], dtype=torch.float32),
            "cat": torch.tensor(self.cat[idx], dtype=torch.long),
            "target": torch.tensor(self.labels[idx], dtype=torch.long)
        }

In [37]:
train_dataset = TabularDataset(X_train_cont, X_train_cat, y_train)
val_dataset   = TabularDataset(X_val_cont, X_val_cat, y_val)
test_dataset  = TabularDataset(X_test_cont, X_test_cat, y_test)

# Create DataLoaders.
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [38]:
##########################################
# 3. Initialize and Train the FTTransformer Classifier
##########################################
# Get the number of continuous features.
n_cont_features = X_train_cont.shape[1]
# Determine the cardinalities for each categorical feature.
cat_cardinalities = [int(train_data[col].nunique()) for col in categorical_features]

# For classification, set d_out = number of classes.
d_out = num_classes

# Instantiate the FTTransformer.
model = FTTransformer(
    n_cont_features=n_cont_features,
    cat_cardinalities=cat_cardinalities,
    d_out=d_out,
    n_blocks=3,
    d_block=192,                # Backbone (hidden) dimension
    attention_n_heads=8,
    attention_dropout=0.2,
    ffn_d_hidden=None,          # Defaults internally if None.
    ffn_d_hidden_multiplier=4/3,
    ffn_dropout=0.1,
    residual_dropout=0.0
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

FTTransformer(
  (cls_embedding): _CLSEmbedding()
  (cont_embeddings): LinearEmbeddings()
  (cat_embeddings): CategoricalEmbeddings(
    (embeddings): ModuleList(
      (0): Embedding(6, 192)
    )
  )
  (backbone): FTTransformerBackbone(
    (blocks): ModuleList(
      (0): ModuleDict(
        (attention): MultiheadAttention(
          (W_q): Linear(in_features=192, out_features=192, bias=True)
          (W_k): Linear(in_features=192, out_features=192, bias=True)
          (W_v): Linear(in_features=192, out_features=192, bias=True)
          (W_out): Linear(in_features=192, out_features=192, bias=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
        (attention_residual_dropout): Dropout(p=0.0, inplace=False)
        (ffn_normalization): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
        (ffn): Sequential(
          (linear1): Linear(in_features=192, out_features=512, bias=True)
          (activation): _ReGLU()
          (dropout): Dropout(p=0.1, inplace

In [39]:
# Set up optimizer, loss function, and scheduler.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5, verbose=True)

# Training loop with early stopping.
max_epochs = 100
patience = 5
best_val_loss = float('inf')
patience_counter = 0

for epoch in range(max_epochs):
    model.train()
    train_loss = 0.0
    for batch in train_loader:
        cont = batch["cont"].to(device)
        cat = batch["cat"].to(device)
        targets = batch["target"].to(device)
        
        optimizer.zero_grad()
        logits = model(cont, cat)  # Forward pass returns logits (shape: [batch_size, d_out])
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item() * cont.size(0)
    train_loss /= len(train_dataset)
    
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            cont = batch["cont"].to(device)
            cat = batch["cat"].to(device)
            targets = batch["target"].to(device)
            
            logits = model(cont, cat)
            loss = criterion(logits, targets)
            val_loss += loss.item() * cont.size(0)
    val_loss /= len(val_dataset)
    
    scheduler.step(val_loss)
    print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model_state = model.state_dict()
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break

Epoch 1: Train Loss: 1.0597 | Val Loss: 1.0601
Epoch 2: Train Loss: 1.0354 | Val Loss: 0.9883
Epoch 3: Train Loss: 0.8518 | Val Loss: 0.6626
Epoch 4: Train Loss: 0.5633 | Val Loss: 0.3959
Epoch 5: Train Loss: 0.4657 | Val Loss: 0.3652
Epoch 6: Train Loss: 0.4325 | Val Loss: 0.3221
Epoch 7: Train Loss: 0.3945 | Val Loss: 0.3015
Epoch 8: Train Loss: 0.3642 | Val Loss: 0.3225
Epoch 9: Train Loss: 0.3462 | Val Loss: 0.3173
Epoch 10: Train Loss: 0.3730 | Val Loss: 0.3152
Epoch 11: Train Loss: 0.3429 | Val Loss: 0.3154
Epoch 12: Train Loss: 0.3237 | Val Loss: 0.2893
Epoch 13: Train Loss: 0.3101 | Val Loss: 0.2754
Epoch 14: Train Loss: 0.3139 | Val Loss: 0.2796
Epoch 15: Train Loss: 0.3050 | Val Loss: 0.2958
Epoch 16: Train Loss: 0.3178 | Val Loss: 0.3067
Epoch 17: Train Loss: 0.3099 | Val Loss: 0.2720
Epoch 18: Train Loss: 0.3074 | Val Loss: 0.2764
Epoch 19: Train Loss: 0.2993 | Val Loss: 0.2811
Epoch 20: Train Loss: 0.3157 | Val Loss: 0.2752
Epoch 21: Train Loss: 0.3014 | Val Loss: 0.2794
E

In [40]:
# Save the best model.
save_model_path = "best_ft_transformer_classification.pt"
torch.save(best_model_state, save_model_path)
print("Trained model saved to", save_model_path)

Trained model saved to best_ft_transformer_classification.pt


In [41]:
# Load the best model (optional).
model.load_state_dict(torch.load(save_model_path))

<All keys matched successfully>

In [42]:
# Evaluate classification performance on the test set.
model.eval()
all_preds = []
all_targets = []
with torch.no_grad():
    for batch in test_loader:
        cont = batch["cont"].to(device)
        cat = batch["cat"].to(device)
        targets = batch["target"].to(device)
        
        logits = model(cont, cat)
        preds = torch.argmax(logits, dim=1)
        all_preds.append(preds.cpu().numpy())
        all_targets.append(targets.cpu().numpy())

all_preds = np.concatenate(all_preds)
all_targets = np.concatenate(all_targets)
test_acc = accuracy_score(all_targets, all_preds)
print("Test Accuracy:", test_acc)
print("Classification Report (Test):")
print(classification_report(all_targets, all_preds, target_names=[str(c) for c in label_encoder.classes_]))

Test Accuracy: 0.8405797101449275
Classification Report (Test):
              precision    recall  f1-score   support

          AD       0.70      0.60      0.65        72
          CN       0.96      0.95      0.96       106
         MCI       0.82      0.87      0.84       167

    accuracy                           0.84       345
   macro avg       0.83      0.81      0.82       345
weighted avg       0.84      0.84      0.84       345



In [43]:
##########################################
# 4. Extract 192-Dimensional Embeddings (Before the Final Classification)
##########################################
# The classifier was trained with d_out = num_classes.
# To obtain 192-dim embeddings (the backbone outputs), we replace the final linear layer with an identity.
# Here we assume the final projection is stored in the attribute 'fc'.
model.fc = nn.Identity()  # Now model(cont, cat) returns the backbone features of shape (batch_size, 192).

# Function to extract features from a DataLoader.
def extract_features(loader, model, device):
    model.eval()
    features_list = []
    with torch.no_grad():
        for batch in loader:
            cont = batch["cont"].to(device)
            cat = batch["cat"].to(device)
            feats = model(cont, cat)  # Should now return features of shape (batch_size, 192)
            features_list.append(feats.cpu().numpy())
    return np.concatenate(features_list)

# Extract features from each set.
features_train = extract_features(train_loader, model, device)
features_val   = extract_features(val_loader, model, device)
features_test  = extract_features(test_loader, model, device)

print("Extracted feature shapes:")
print("Train features:", features_train.shape)
print("Validation features:", features_val.shape)
print("Test features:", features_test.shape)

# Save the extracted features.
np.save("ft_train_features.npy", features_train)
np.save("ft_val_features.npy", features_val)
np.save("ft_test_features.npy", features_test)
print("Extracted features saved as 'ft_train_features.npy', 'ft_val_features.npy', and 'ft_test_features.npy'.")

Extracted feature shapes:
Train features: (1605, 3)
Validation features: (344, 3)
Test features: (345, 3)
Extracted features saved as 'ft_train_features.npy', 'ft_val_features.npy', and 'ft_test_features.npy'.


# Early Fusion

In [44]:
import h5py
def load_embeddings_hdf5_np(filename):
    with h5py.File(filename, "r") as f:
        embeddings = f["embeddings"][:].astype(np.float32)
    return embeddings

In [45]:
train_features = load_embeddings_hdf5_np("train_embeddings.h5")
val_features   = load_embeddings_hdf5_np("val_embeddings.h5")
test_features  = load_embeddings_hdf5_np("test_embeddings.h5")

In [46]:
train_features.shape

(1605, 393216)

In [47]:
ft_train = np.load("ft_train_features.npy")  # shape: (num_train_samples, ft_feature_dim)
ft_val   = np.load("ft_val_features.npy")      # shape: (num_val_samples,   ft_feature_dim)
ft_test  = np.load("ft_test_features.npy")     # shape: (num_test_samples,  ft_feature_dim)
# -------------------------------------------------------------------
# 3. Concatenate DeiT features with TabNet embeddings
# -------------------------------------------------------------------
train_concat = np.concatenate([ft_train, train_features], axis=1)
val_concat   = np.concatenate([ft_val,   val_features],   axis=1)
test_concat  = np.concatenate([ft_test,  test_features],  axis=1)

In [48]:
print("\nConcatenated Train Shape:", train_concat.shape)
print("Concatenated Val Shape:  ", val_concat.shape)
print("Concatenated Test Shape: ", test_concat.shape)


Concatenated Train Shape: (1605, 393219)
Concatenated Val Shape:   (344, 393219)
Concatenated Test Shape:  (345, 393219)


In [49]:
from sklearn.preprocessing import LabelEncoder

# We assume the label column is named "Group"
label_col = "Group"

# Encode labels using LabelEncoder.
label_encoder = LabelEncoder()
train_labels = label_encoder.fit_transform(train_data[label_col])
val_labels   = label_encoder.transform(val_data[label_col])
test_labels  = label_encoder.transform(test_data[label_col])

num_classes = len(label_encoder.classes_)

In [50]:
from torch.utils.data import TensorDataset, DataLoader
##########################################
# 3. Create PyTorch Datasets and DataLoaders
##########################################
# Convert features and labels to tensors.
train_concat = torch.tensor(train_concat, dtype=torch.float32)
val_concat   = torch.tensor(val_concat, dtype=torch.float32)
test_concat  = torch.tensor(test_concat, dtype=torch.float32)

train_labels = torch.tensor(train_labels, dtype=torch.long)
val_labels   = torch.tensor(val_labels, dtype=torch.long)
test_labels  = torch.tensor(test_labels, dtype=torch.long)

# Create TensorDatasets.
train_dataset = TensorDataset(train_concat, train_labels)
val_dataset   = TensorDataset(val_concat, val_labels)
test_dataset  = TensorDataset(test_concat, test_labels)

# Create DataLoaders.
batch_size = 8
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [51]:
##########################################
# 4. Define an MLP Classifier for Fused Features
##########################################
class MLPClassifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(MLPClassifier, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 2048),  # Increased first layer size
            nn.LeakyReLU(0.01),
            nn.Dropout(0.1),

            nn.Linear(2048, 1024),
            nn.BatchNorm1d(1024),  # BatchNorm starts here
            nn.LeakyReLU(0.01),
            nn.Dropout(0.1),

            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.01),
            nn.Dropout(0.1),

            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.01),
            nn.Dropout(0.1),

            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        return self.net(x)

# The input dimension is the sum of the two feature dimensions.
input_dim = train_concat.shape[1]
model_mlp = MLPClassifier(input_dim, num_classes)
model_mlp.to(device)

MLPClassifier(
  (net): Sequential(
    (0): Linear(in_features=393219, out_features=2048, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=2048, out_features=1024, bias=True)
    (4): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.01)
    (6): Dropout(p=0.1, inplace=False)
    (7): Linear(in_features=1024, out_features=512, bias=True)
    (8): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): LeakyReLU(negative_slope=0.01)
    (10): Dropout(p=0.1, inplace=False)
    (11): Linear(in_features=512, out_features=128, bias=True)
    (12): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): LeakyReLU(negative_slope=0.01)
    (14): Dropout(p=0.1, inplace=False)
    (15): Linear(in_features=128, out_features=3, bias=True)
  )
)

In [52]:
##########################################
# 5. Train the MLP Classifier
##########################################
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_mlp.parameters(), lr=1e-3)
num_epochs = 50

for epoch in range(num_epochs):
    model_mlp.train()
    train_loss = 0.0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model_mlp(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * x.size(0)
    train_loss /= len(train_dataset)
    
    # Evaluate on the validation set.
    model_mlp.eval()
    all_preds = []
    all_targets = []
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            logits = model_mlp(x)
            preds = torch.argmax(logits, dim=1)
            all_preds.append(preds.cpu().numpy())
            all_targets.append(y.cpu().numpy())
    all_preds = np.concatenate(all_preds)
    all_targets = np.concatenate(all_targets)
    val_acc = accuracy_score(all_targets, all_preds)
    print(f"Epoch {epoch+1}/{num_epochs}: Train Loss = {train_loss:.4f}, Val Acc = {val_acc:.4f}")

Epoch 1/50: Train Loss = 1.0930, Val Acc = 0.4797
Epoch 2/50: Train Loss = 1.0628, Val Acc = 0.4244
Epoch 3/50: Train Loss = 1.0507, Val Acc = 0.4360
Epoch 4/50: Train Loss = 1.0129, Val Acc = 0.4419
Epoch 5/50: Train Loss = 0.9632, Val Acc = 0.3779
Epoch 6/50: Train Loss = 0.8950, Val Acc = 0.3779
Epoch 7/50: Train Loss = 0.8005, Val Acc = 0.3779
Epoch 8/50: Train Loss = 0.7174, Val Acc = 0.3983
Epoch 9/50: Train Loss = 0.6405, Val Acc = 0.3198
Epoch 10/50: Train Loss = 0.5782, Val Acc = 0.3779
Epoch 11/50: Train Loss = 0.5126, Val Acc = 0.3605
Epoch 12/50: Train Loss = 0.4780, Val Acc = 0.2645
Epoch 13/50: Train Loss = 0.4443, Val Acc = 0.3750
Epoch 14/50: Train Loss = 0.4350, Val Acc = 0.3488
Epoch 15/50: Train Loss = 0.3975, Val Acc = 0.4244
Epoch 16/50: Train Loss = 0.3879, Val Acc = 0.3808
Epoch 17/50: Train Loss = 0.3754, Val Acc = 0.3663
Epoch 18/50: Train Loss = 0.3884, Val Acc = 0.3517
Epoch 19/50: Train Loss = 0.3151, Val Acc = 0.3169


KeyboardInterrupt: 