In [1]:
import nibabel as nib
import numpy as np
import cv2
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import scipy.ndimage
from monai.networks.nets import resnet18
from torch.utils.data import Dataset, DataLoader
from sklearn.utils.class_weight import compute_class_weight

In [2]:
import os

def find_files_with_substring(directory, substring):
    matching_files = [f for f in os.listdir(directory) if substring in f]
    return matching_files

def get_nib_image(adni_file_name):
    return nib.load(adni_file_name).get_fdata()

def visualize_image(nib_image):
    plt.imshow(nib_image[:,:,nib_image.shape[2]//2])
    plt.show()

In [3]:
def get_image_file_names_for_subject(subject_id, date=None):
    os.path.expanduser("~/adni_flat_dataset/adni_flat_dataset")
    dir_ = "/home/rittikar-s/adni_flat_dataset/adni_flat_dataset"
    files = find_files_with_substring(dir_, subject_id)
    if date:
        files = [file for file in files if date in file]
    file_paths = [f"{dir_}/{file}" for file in files]
    return file_paths

In [4]:
import pandas as pd

df = pd.read_csv("ADNI1_Complete_1Yr_1.5T_1_26_2025.csv")

In [5]:
from monai.networks.nets.vitautoenc import ViTAutoEnc

vit_model = ViTAutoEnc(in_channels=1, patch_size=(16,16,16), img_size=(128,128,128))

def get_vit_embedding(img):
    return vit_model(img)

In [6]:
class NiftiDataset(Dataset):
    def __init__(self, image_paths, labels, target_shape=(128, 128, 128)):
        self.image_paths = image_paths
        self.labels = labels
        self.target_shape = target_shape

    def __len__(self):
        return len(self.image_paths)

    def apply_multi_bit_plane_slicing(self, img_3d, bit_planes=[6, 7]):
        """
        Apply multi-bit-plane slicing to a 3D MRI image.
        
        Args:
            img_3d (torch.Tensor or np.ndarray): A 3D image (shape: [depth, height, width]).
            bit_planes (list of int): List of bit-planes to extract (0 = LSB, 7 = MSB).
        
        Returns:
            numpy.ndarray: The combined bit-plane image.
        """
        X, Y, Z = img_3d.shape
        processed_img = np.zeros_like(img_3d, dtype=np.uint8)
        
        if isinstance(img_3d, torch.Tensor):
            img_3d = img_3d.cpu().numpy()  # Convert to NumPy if it's a tensor
    
        for z in range(Z):  # Iterate through each slice dynamically
            slice_img = img_3d[:, :, z]  # Extract the 2D slice
            
            # Ensure the slice is in 8-bit format
            slice_img = (slice_img / np.max(slice_img) * 255).astype(np.uint8)  
            
            bit_sliced = np.zeros_like(slice_img, dtype=np.uint8)
    
            # Combine selected bit planes
            for bit in bit_planes:
                bit_sliced |= ((slice_img >> bit) & 1) << bit  
    
            processed_img[:, :, z] = bit_sliced  # Store processed slice back
    
        return processed_img
    
    def preprocess_nifti(self, nifti_path, threshold_value=80):
        # Load the NIfTI file
        img = nib.load(nifti_path).get_fdata()
    
        # Resize the image to the target shape
        img_resized = scipy.ndimage.zoom(img, np.array(self.target_shape) / np.array(img.shape), order=1)
    
        # Normalize intensity to [0, 1] for neural network
        img_normalized = (img_resized - np.min(img_resized)) / (np.max(img_resized) - np.min(img_resized) + 1e-8)
        # img_normalized = (img_normalized * 255).astype(np.uint8)  # You can adjust this based on model input needs
    
        # Create an empty array to store the processed result
        processed_img = self.apply_multi_bit_plane_slicing(img_normalized, [5, 6, 7])
    
        # Convert the processed image to a tensor and add a channel dimension (assuming 1 channel)
        return torch.tensor(processed_img, dtype=torch.float32).unsqueeze(0)  # Add batch dimension
    
    def __getitem__(self, idx):
        image = self.preprocess_nifti(self.image_paths[idx])
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        embedding = get_vit_embedding(image.reshape(1,1,128,128,128))
        return embedding, label

In [7]:
class_to_label = {
    "CN": 0,
    "MCI": 1,
    "AD": 2
}
image_paths = []
labels = []

for i in range(len(df)):
    row = df.iloc[i]
    subject = row["Subject"]
    date = row["Acq Date"]
    date = date.replace("/", "-")
    image_path = get_image_file_names_for_subject(subject, date)[0]
    image_paths.append(image_path)
    labels.append(class_to_label[row["Group"]])

In [8]:
len(image_paths)

2294

In [9]:
len(labels)

2294

In [10]:
from sklearn.model_selection import train_test_split
train_paths, test_paths, train_labels, test_labels = train_test_split(image_paths, labels, test_size=0.3, random_state=42, stratify=labels)
val_paths, test_paths, val_labels, test_labels = train_test_split(test_paths, test_labels, test_size=0.5, random_state=42, stratify=test_labels)

In [11]:
# Create train & test datasets
train_dataset = NiftiDataset(train_paths, train_labels)
val_dataset = NiftiDataset(val_paths, val_labels)
test_dataset = NiftiDataset(test_paths, test_labels)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, pin_memory=True)

print(f"Train Batches: {len(train_loader)}, Val Batches: {len(val_loader)}, Test Batches: {len(test_loader)}")

Train Batches: 402, Val Batches: 86, Test Batches: 87


In [12]:
import h5py

def save_embeddings_hdf5(dataloader, filename):
    """Save embeddings (from list format) and labels incrementally to an HDF5 file."""
    with h5py.File(filename, "w") as f:
        first_batch = True
        for i, (embedding_list, label) in enumerate(dataloader):
            # Extract the last layer embeddings
            embedding_tensor = embedding_list[1][-1]  # Extract final layer embeddings
            embedding_tensor = embedding_tensor.cpu().detach()  # Move to CPU
            
            embedding_numpy = embedding_tensor.numpy()  # Convert to NumPy
            label_numpy = label.cpu().numpy()

            # Reshape embeddings if needed
            embedding_numpy = embedding_numpy.reshape(embedding_numpy.shape[0], -1)  # (4, 1, 512, 768) → (4, 512 * 768)

            if first_batch:
                # Create expandable datasets with correct shape
                f.create_dataset("embeddings", data=embedding_numpy, 
                                 maxshape=(None, embedding_numpy.shape[1]))  # Now 2D
                f.create_dataset("labels", data=label_numpy, maxshape=(None,))
                first_batch = False
            else:
                # Resize and append new embeddings
                f["embeddings"].resize((f["embeddings"].shape[0] + embedding_numpy.shape[0]), axis=0)
                f["embeddings"][-embedding_numpy.shape[0]:] = embedding_numpy

                f["labels"].resize((f["labels"].shape[0] + label_numpy.shape[0]), axis=0)
                f["labels"][-label_numpy.shape[0]:] = label_numpy
            
            print(f"Saved embeddings for batch: {i+1}")

In [13]:
save_embeddings_hdf5(train_loader, "train_embeddings.h5")
save_embeddings_hdf5(val_loader, "val_embeddings.h5")
save_embeddings_hdf5(test_loader, "test_embeddings.h5")

invalid value encountered in divide
invalid value encountered in cast


Saved embeddings for batch: 1
Saved embeddings for batch: 2
Saved embeddings for batch: 3
Saved embeddings for batch: 4
Saved embeddings for batch: 5
Saved embeddings for batch: 6
Saved embeddings for batch: 7
Saved embeddings for batch: 8
Saved embeddings for batch: 9
Saved embeddings for batch: 10
Saved embeddings for batch: 11
Saved embeddings for batch: 12
Saved embeddings for batch: 13
Saved embeddings for batch: 14
Saved embeddings for batch: 15
Saved embeddings for batch: 16
Saved embeddings for batch: 17
Saved embeddings for batch: 18
Saved embeddings for batch: 19
Saved embeddings for batch: 20
Saved embeddings for batch: 21
Saved embeddings for batch: 22
Saved embeddings for batch: 23
Saved embeddings for batch: 24
Saved embeddings for batch: 25
Saved embeddings for batch: 26
Saved embeddings for batch: 27
Saved embeddings for batch: 28
Saved embeddings for batch: 29
Saved embeddings for batch: 30
Saved embeddings for batch: 31
Saved embeddings for batch: 32
Saved embeddings 

In [24]:
DEVICE = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

In [25]:
def load_embeddings_hdf5(filename, batch_size=4):
    with h5py.File(filename, "r") as f:
        num_samples = f["embeddings"].shape[0]  # Total samples
        for i in range(0, num_samples, batch_size):
            X_batch = torch.tensor(f["embeddings"][i : i + batch_size], dtype=torch.float32).to(DEVICE)
            y_batch = torch.tensor(f["labels"][i : i + batch_size], dtype=torch.long).to(DEVICE)
            yield X_batch, y_batch

In [26]:
import torch.nn as nn
import torch.optim as optim

class MLPClassifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(MLPClassifier, self).__init__()
        self.fc1 = nn.Linear(input_dim, 1024)
        self.batch_norm1 = nn.BatchNorm1d(1024)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.5)
        
        self.fc2 = nn.Linear(1024, 512)
        self.batch_norm2 = nn.BatchNorm1d(512)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.5)

        self.fc3 = nn.Linear(512, 256)
        self.batch_norm3 = nn.BatchNorm1d(256)
        self.relu3 = nn.ReLU()
        self.dropout3 = nn.Dropout(0.5)
        
        self.fc4 = nn.Linear(256, num_classes)  # Final layer

    def forward(self, x):
        x = self.fc1(x)
        if x.shape[0] > 1:  # Apply BatchNorm only if batch size > 1
            x = self.batch_norm1(x)
        x = self.relu1(x)
        x = self.dropout1(x)
    
        x = self.fc2(x)
        if x.shape[0] > 1:
            x = self.batch_norm2(x)
        x = self.relu2(x)
        x = self.dropout2(x)
    
        x = self.fc3(x)
        if x.shape[0] > 1:
            x = self.batch_norm3(x)
        x = self.relu3(x)
        x = self.dropout3(x)
        
        x = self.fc4(x)
        return x

In [34]:
input_dim = next(load_embeddings_hdf5("train_embeddings.h5"))[0].shape[1]  # Get feature size
num_classes = 3  # Adjust based on labels
model = MLPClassifier(input_dim, num_classes).to(DEVICE)

In [35]:
model

MLPClassifier(
  (fc1): Linear(in_features=393216, out_features=1024, bias=True)
  (batch_norm1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): ReLU()
  (dropout1): Dropout(p=0.5, inplace=False)
  (fc2): Linear(in_features=1024, out_features=512, bias=True)
  (batch_norm2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu2): ReLU()
  (dropout2): Dropout(p=0.5, inplace=False)
  (fc3): Linear(in_features=512, out_features=256, bias=True)
  (batch_norm3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu3): ReLU()
  (dropout3): Dropout(p=0.5, inplace=False)
  (fc4): Linear(in_features=256, out_features=3, bias=True)
)

In [29]:
# Compute class weights
# classes = np.unique(train_labels)
# class_weights = compute_class_weight(class_weight="balanced", classes=classes, y=train_labels)

# criterion = nn.CrossEntropyLoss(weight=torch.tensor(class_weights).to(DEVICE)).to(DEVICE)
# optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [30]:
# if torch.cuda.device_count() > 1:
#     print(f"Using {torch.cuda.device_count()} GPUs!")
#     model = torch.nn.DataParallel(model)

# model = model.to("cuda") 

In [36]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn.functional as F
from tqdm import tqdm  # Progress tracking

# Compute class weights
def compute_class_weights(y_train):
    class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(y_train), y=y_train)
    return torch.tensor(class_weights, dtype=torch.float).to(DEVICE)

y_train = np.concatenate([y.cpu().numpy() for _, y in load_embeddings_hdf5("train_embeddings.h5", batch_size=32)])
class_weights = compute_class_weights(y_train)

# Define optimizer, scheduler, and scaler
optimizer = optim.Adam(model.parameters(), lr=1e-5)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)
scaler = torch.cuda.amp.GradScaler()

NUM_GPUS = 1
num_epochs = 100
patience = 15  # Early stopping patience
best_val_loss = float("inf")
epochs_without_improvement = 0

for epoch in range(num_epochs):
    model.train()
    train_loss, correct, total = 0.0, 0, 0

    # Training loop
    train_loader = load_embeddings_hdf5("train_embeddings.h5", batch_size=16 * NUM_GPUS)  # Adjust batch size
    train_bar = tqdm(train_loader, total=len(y_train) // (16 * NUM_GPUS), desc=f"Epoch {epoch+1}/{num_epochs} Train")
    for X_batch, y_batch in train_bar:
        X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast():  # Mixed precision
            outputs = model(X_batch)
            loss = F.cross_entropy(outputs, y_batch)
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        correct += predicted.eq(y_batch).sum().item()
        total += y_batch.size(0)
        
        train_bar.set_postfix(loss=loss.item(), acc=100 * correct / total)

    train_accuracy = 100 * correct / total
    avg_train_loss = train_loss / total

    # Validation loop
    model.eval()
    val_loss, val_correct, val_total = 0.0, 0, 0
    val_loader = load_embeddings_hdf5("val_embeddings.h5", batch_size=16 * NUM_GPUS)
    val_bar = tqdm(val_loader, total=len(y_train) // (16 * NUM_GPUS), desc="Validation")
    with torch.no_grad():
        for X_val, y_val in val_bar:
            X_val, y_val = X_val.to(DEVICE), y_val.to(DEVICE)
            outputs = model(X_val)
            loss = F.cross_entropy(outputs, y_val)

            val_loss += loss.item()
            _, predicted = outputs.max(1)
            val_correct += predicted.eq(y_val).sum().item()
            val_total += y_val.size(0)

            val_bar.set_postfix(loss=loss.item(), acc=100 * val_correct / val_total)
    
    val_accuracy = 100 * val_correct / val_total
    avg_val_loss = val_loss / val_total

    # Check for early stopping
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        epochs_without_improvement = 0  # Reset counter
        torch.save(model.state_dict(), "best_model.pth")  # Save best model
    else:
        epochs_without_improvement += 1

    # Stop if no improvement for 10 epochs
    if epochs_without_improvement >= patience:
        print(f"Early stopping at epoch {epoch+1}. No improvement in validation loss for {patience} epochs.")
        break

    # Learning rate scheduler step
    scheduler.step(avg_val_loss)

    print(f"Epoch [{epoch+1}/{num_epochs}]")
    print(f"  Train Loss: {avg_train_loss:.4f} | Train Acc: {train_accuracy:.2f}%")
    print(f"  Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.2f}%")

The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate.
`torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.                         | 0/100 [00:00<?, ?it/s]
Epoch 1/100 Train: 101it [00:11,  8.62it/s, acc=33.8, loss=1.36]                                                                                            
Validation:  22%|██████████████████▋                                                                  | 22/100 [00:00<00:01, 52.61it/s, acc=47.7, loss=1.02]


Epoch [1/100]
  Train Loss: 0.0752 | Train Acc: 33.83%
  Val Loss: 0.0664 | Val Acc: 47.67%


Epoch 2/100 Train: 101it [00:11,  8.66it/s, acc=44.2, loss=0.791]                                                                                           
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.41it/s, acc=47.7, loss=0.952]


Epoch [2/100]
  Train Loss: 0.0678 | Train Acc: 44.17%
  Val Loss: 0.0656 | Val Acc: 47.67%


Epoch 3/100 Train: 101it [00:11,  8.61it/s, acc=48.2, loss=0.761]                                                                                           
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.09it/s, acc=51.2, loss=0.965]


Epoch [3/100]
  Train Loss: 0.0642 | Train Acc: 48.22%
  Val Loss: 0.0645 | Val Acc: 51.16%


Epoch 4/100 Train: 101it [00:11,  8.60it/s, acc=55.8, loss=0.698]                                                                                           
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.26it/s, acc=57.6, loss=0.859]


Epoch [4/100]
  Train Loss: 0.0583 | Train Acc: 55.83%
  Val Loss: 0.0622 | Val Acc: 57.56%


Epoch 5/100 Train: 101it [00:11,  8.60it/s, acc=59.5, loss=0.75]                                                                                            
Validation:  22%|██████████████████▉                                                                   | 22/100 [00:00<00:01, 53.43it/s, acc=61, loss=0.825]


Epoch [5/100]
  Train Loss: 0.0567 | Train Acc: 59.50%
  Val Loss: 0.0607 | Val Acc: 61.05%


Epoch 6/100 Train: 101it [00:11,  8.62it/s, acc=64.9, loss=0.827]                                                                                           
Validation:  22%|██████████████████▋                                                                  | 22/100 [00:00<00:01, 53.40it/s, acc=59.3, loss=0.78]


Epoch [6/100]
  Train Loss: 0.0530 | Train Acc: 64.92%
  Val Loss: 0.0596 | Val Acc: 59.30%


Epoch 7/100 Train: 101it [00:11,  8.63it/s, acc=68.2, loss=0.32]                                                                                            
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.97it/s, acc=60.8, loss=0.735]


Epoch [7/100]
  Train Loss: 0.0498 | Train Acc: 68.16%
  Val Loss: 0.0584 | Val Acc: 60.76%


Epoch 8/100 Train: 101it [00:11,  8.64it/s, acc=72.8, loss=0.556]                                                                                           
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.58it/s, acc=62.2, loss=0.703]


Epoch [8/100]
  Train Loss: 0.0459 | Train Acc: 72.77%
  Val Loss: 0.0570 | Val Acc: 62.21%


Epoch 9/100 Train: 101it [00:11,  8.65it/s, acc=76.1, loss=0.452]                                                                                           
Validation:  22%|███████████████████▏                                                                   | 22/100 [00:00<00:01, 52.17it/s, acc=61, loss=0.69]


Epoch [9/100]
  Train Loss: 0.0429 | Train Acc: 76.07%
  Val Loss: 0.0560 | Val Acc: 61.05%


Epoch 10/100 Train: 101it [00:11,  8.62it/s, acc=79.8, loss=0.47]                                                                                           
Validation:  22%|██████████████████▋                                                                  | 22/100 [00:00<00:01, 52.66it/s, acc=62.2, loss=0.66]


Epoch [10/100]
  Train Loss: 0.0401 | Train Acc: 79.75%
  Val Loss: 0.0547 | Val Acc: 62.21%


Epoch 11/100 Train: 101it [00:11,  8.66it/s, acc=83.8, loss=0.375]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.82it/s, acc=63.1, loss=0.673]


Epoch [11/100]
  Train Loss: 0.0364 | Train Acc: 83.80%
  Val Loss: 0.0544 | Val Acc: 63.08%


Epoch 12/100 Train: 101it [00:11,  8.63it/s, acc=87, loss=0.45]                                                                                             
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.70it/s, acc=62.2, loss=0.645]


Epoch [12/100]
  Train Loss: 0.0328 | Train Acc: 87.04%
  Val Loss: 0.0535 | Val Acc: 62.21%


Epoch 13/100 Train: 101it [00:11,  8.63it/s, acc=88.1, loss=0.432]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.77it/s, acc=64.5, loss=0.627]


Epoch [13/100]
  Train Loss: 0.0310 | Train Acc: 88.10%
  Val Loss: 0.0527 | Val Acc: 64.53%


Epoch 14/100 Train: 101it [00:11,  8.65it/s, acc=90.2, loss=0.2]                                                                                            
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 53.80it/s, acc=65.1, loss=0.562]


Epoch [14/100]
  Train Loss: 0.0278 | Train Acc: 90.16%
  Val Loss: 0.0519 | Val Acc: 65.12%


Epoch 15/100 Train: 101it [00:11,  8.61it/s, acc=92.3, loss=0.404]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.89it/s, acc=64.2, loss=0.605]


Epoch [15/100]
  Train Loss: 0.0250 | Train Acc: 92.27%
  Val Loss: 0.0509 | Val Acc: 64.24%


Epoch 16/100 Train: 101it [00:11,  8.60it/s, acc=93.4, loss=0.295]                                                                                          
Validation:  22%|███████████████████▏                                                                   | 22/100 [00:00<00:01, 50.16it/s, acc=64, loss=0.57]


Epoch [16/100]
  Train Loss: 0.0230 | Train Acc: 93.40%
  Val Loss: 0.0498 | Val Acc: 63.95%


Epoch 17/100 Train: 101it [00:11,  8.63it/s, acc=94.6, loss=0.235]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 53.01it/s, acc=63.7, loss=0.596]


Epoch [17/100]
  Train Loss: 0.0208 | Train Acc: 94.58%
  Val Loss: 0.0501 | Val Acc: 63.66%


Epoch 18/100 Train: 101it [00:11,  8.62it/s, acc=96.6, loss=0.233]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 53.32it/s, acc=64.5, loss=0.566]


Epoch [18/100]
  Train Loss: 0.0185 | Train Acc: 96.64%
  Val Loss: 0.0498 | Val Acc: 64.53%


Epoch 19/100 Train: 101it [00:11,  8.61it/s, acc=97.3, loss=0.15]                                                                                           
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.79it/s, acc=66.3, loss=0.588]


Epoch [19/100]
  Train Loss: 0.0168 | Train Acc: 97.32%
  Val Loss: 0.0496 | Val Acc: 66.28%


Epoch 20/100 Train: 101it [00:11,  8.65it/s, acc=97.6, loss=0.18]                                                                                           
Validation:  22%|██████████████████▉                                                                   | 22/100 [00:00<00:01, 52.40it/s, acc=66, loss=0.549]


Epoch [20/100]
  Train Loss: 0.0154 | Train Acc: 97.63%
  Val Loss: 0.0484 | Val Acc: 65.99%


Epoch 21/100 Train: 101it [00:11,  8.63it/s, acc=98.3, loss=0.13]                                                                                           
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.77it/s, acc=65.1, loss=0.501]


Epoch [21/100]
  Train Loss: 0.0138 | Train Acc: 98.26%
  Val Loss: 0.0486 | Val Acc: 65.12%


Epoch 22/100 Train: 101it [00:11,  8.62it/s, acc=98.8, loss=0.13]                                                                                           
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.63it/s, acc=66.3, loss=0.515]


Epoch [22/100]
  Train Loss: 0.0129 | Train Acc: 98.82%
  Val Loss: 0.0482 | Val Acc: 66.28%


Epoch 23/100 Train: 101it [00:11,  8.63it/s, acc=98.8, loss=0.125]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.62it/s, acc=63.1, loss=0.517]


Epoch [23/100]
  Train Loss: 0.0116 | Train Acc: 98.82%
  Val Loss: 0.0480 | Val Acc: 63.08%


Epoch 24/100 Train: 101it [00:11,  8.63it/s, acc=98.9, loss=0.113]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.75it/s, acc=64.5, loss=0.556]


Epoch [24/100]
  Train Loss: 0.0109 | Train Acc: 98.94%
  Val Loss: 0.0488 | Val Acc: 64.53%


Epoch 25/100 Train: 101it [00:11,  8.63it/s, acc=99.2, loss=0.0903]                                                                                         
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.75it/s, acc=65.4, loss=0.463]


Epoch [25/100]
  Train Loss: 0.0101 | Train Acc: 99.19%
  Val Loss: 0.0480 | Val Acc: 65.41%


Epoch 26/100 Train: 101it [00:11,  8.61it/s, acc=99.1, loss=0.0918]                                                                                         
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.24it/s, acc=65.7, loss=0.538]


Epoch [26/100]
  Train Loss: 0.0092 | Train Acc: 99.13%
  Val Loss: 0.0478 | Val Acc: 65.70%


Epoch 27/100 Train: 101it [00:11,  8.66it/s, acc=99.3, loss=0.053]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 53.01it/s, acc=65.4, loss=0.514]


Epoch [27/100]
  Train Loss: 0.0085 | Train Acc: 99.31%
  Val Loss: 0.0474 | Val Acc: 65.41%


Epoch 28/100 Train: 101it [00:11,  8.61it/s, acc=99.4, loss=0.088]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.68it/s, acc=66.3, loss=0.554]


Epoch [28/100]
  Train Loss: 0.0081 | Train Acc: 99.44%
  Val Loss: 0.0480 | Val Acc: 66.28%


Epoch 29/100 Train: 101it [00:11,  8.62it/s, acc=99.4, loss=0.0968]                                                                                         
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.79it/s, acc=64.8, loss=0.514]


Epoch [29/100]
  Train Loss: 0.0077 | Train Acc: 99.44%
  Val Loss: 0.0484 | Val Acc: 64.83%


Epoch 30/100 Train: 101it [00:11,  8.63it/s, acc=99.4, loss=0.087]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 53.89it/s, acc=65.4, loss=0.572]


Epoch [30/100]
  Train Loss: 0.0072 | Train Acc: 99.44%
  Val Loss: 0.0485 | Val Acc: 65.41%


Epoch 31/100 Train: 101it [00:11,  8.63it/s, acc=99.7, loss=0.0604]                                                                                         
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.98it/s, acc=64.2, loss=0.545]


Epoch [31/100]
  Train Loss: 0.0066 | Train Acc: 99.69%
  Val Loss: 0.0479 | Val Acc: 64.24%


Epoch 32/100 Train: 101it [00:11,  8.61it/s, acc=99.5, loss=0.0548]                                                                                         
Validation:  22%|██████████████████▉                                                                   | 22/100 [00:00<00:01, 50.76it/s, acc=66, loss=0.544]


Epoch [32/100]
  Train Loss: 0.0065 | Train Acc: 99.50%
  Val Loss: 0.0478 | Val Acc: 65.99%


Epoch 33/100 Train: 101it [00:11,  8.63it/s, acc=99.7, loss=0.0771]                                                                                         
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.46it/s, acc=65.4, loss=0.535]


Epoch [33/100]
  Train Loss: 0.0060 | Train Acc: 99.69%
  Val Loss: 0.0474 | Val Acc: 65.41%


Epoch 34/100 Train: 101it [00:11,  8.61it/s, acc=99.8, loss=0.0316]                                                                                         
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 53.68it/s, acc=65.7, loss=0.517]


Epoch [34/100]
  Train Loss: 0.0057 | Train Acc: 99.75%
  Val Loss: 0.0478 | Val Acc: 65.70%


Epoch 35/100 Train: 101it [00:11,  8.64it/s, acc=99.6, loss=0.0542]                                                                                         
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 53.07it/s, acc=66.6, loss=0.507]


Epoch [35/100]
  Train Loss: 0.0057 | Train Acc: 99.56%
  Val Loss: 0.0478 | Val Acc: 66.57%


Epoch 36/100 Train: 101it [00:11,  8.61it/s, acc=99.8, loss=0.0226]                                                                                         
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.77it/s, acc=65.1, loss=0.574]


Epoch [36/100]
  Train Loss: 0.0052 | Train Acc: 99.81%
  Val Loss: 0.0481 | Val Acc: 65.12%


Epoch 37/100 Train: 101it [00:11,  8.64it/s, acc=99.6, loss=0.0306]                                                                                         
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.51it/s, acc=65.7, loss=0.397]


Epoch [37/100]
  Train Loss: 0.0050 | Train Acc: 99.63%
  Val Loss: 0.0481 | Val Acc: 65.70%


Epoch 38/100 Train: 101it [00:11,  8.63it/s, acc=99.6, loss=0.0732]                                                                                         
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.13it/s, acc=66.9, loss=0.528]


Epoch [38/100]
  Train Loss: 0.0051 | Train Acc: 99.56%
  Val Loss: 0.0494 | Val Acc: 66.86%


Epoch 39/100 Train: 101it [00:11,  8.67it/s, acc=99.6, loss=0.026]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.99it/s, acc=63.7, loss=0.491]


Epoch [39/100]
  Train Loss: 0.0047 | Train Acc: 99.63%
  Val Loss: 0.0500 | Val Acc: 63.66%


Epoch 40/100 Train: 101it [00:11,  8.62it/s, acc=99.8, loss=0.0145]                                                                                         
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.64it/s, acc=64.5, loss=0.506]


Epoch [40/100]
  Train Loss: 0.0043 | Train Acc: 99.81%
  Val Loss: 0.0491 | Val Acc: 64.53%


Epoch 41/100 Train: 101it [00:11,  8.69it/s, acc=99.8, loss=0.0366]                                                                                         
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 53.07it/s, acc=64.5, loss=0.509]


Epoch [41/100]
  Train Loss: 0.0042 | Train Acc: 99.75%
  Val Loss: 0.0493 | Val Acc: 64.53%


Epoch 42/100 Train: 101it [00:11,  8.58it/s, acc=100, loss=0.027]                                                                                           
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 50.94it/s, acc=63.7, loss=0.535]


Epoch [42/100]
  Train Loss: 0.0038 | Train Acc: 100.00%
  Val Loss: 0.0493 | Val Acc: 63.66%


Epoch 43/100 Train: 101it [00:11,  8.62it/s, acc=99.8, loss=0.0522]                                                                                         
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.79it/s, acc=64.2, loss=0.552]


Epoch [43/100]
  Train Loss: 0.0038 | Train Acc: 99.81%
  Val Loss: 0.0491 | Val Acc: 64.24%


Epoch 44/100 Train: 101it [00:11,  8.63it/s, acc=99.9, loss=0.0259]                                                                                         
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 53.62it/s, acc=64.5, loss=0.507]


Epoch [44/100]
  Train Loss: 0.0037 | Train Acc: 99.94%
  Val Loss: 0.0495 | Val Acc: 64.53%


Epoch 45/100 Train: 101it [00:11,  8.62it/s, acc=99.9, loss=0.0312]                                                                                         
Validation:  22%|██████████████████▉                                                                   | 22/100 [00:00<00:01, 51.97it/s, acc=64, loss=0.518]


Epoch [45/100]
  Train Loss: 0.0038 | Train Acc: 99.88%
  Val Loss: 0.0495 | Val Acc: 63.95%


Epoch 46/100 Train: 101it [00:11,  8.61it/s, acc=99.8, loss=0.0291]                                                                                         
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 53.06it/s, acc=64.5, loss=0.525]


Epoch [46/100]
  Train Loss: 0.0035 | Train Acc: 99.81%
  Val Loss: 0.0494 | Val Acc: 64.53%


Epoch 47/100 Train: 101it [00:11,  8.65it/s, acc=99.9, loss=0.0202]                                                                                         
Validation:  22%|██████████████████▉                                                                   | 22/100 [00:00<00:01, 52.84it/s, acc=64, loss=0.533]


Epoch [47/100]
  Train Loss: 0.0035 | Train Acc: 99.88%
  Val Loss: 0.0490 | Val Acc: 63.95%


Epoch 48/100 Train: 101it [00:11,  8.63it/s, acc=99.9, loss=0.0196]                                                                                         
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 53.13it/s, acc=64.5, loss=0.511]

Early stopping at epoch 48. No improvement in validation loss for 15 epochs.





In [39]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Inference
y_true, y_pred = [], []

with torch.no_grad():
    for X_batch, y_batch in load_embeddings_hdf5("test_embeddings.h5", batch_size=4):
        outputs = model(X_batch)
        predicted_labels = torch.argmax(outputs, dim=1)

        y_true.extend(y_batch.cpu().numpy())
        y_pred.extend(predicted_labels.cpu().numpy())

# Compute metrics
accuracy = accuracy_score(y_true, y_pred)
precision = precision_score(y_true, y_pred, average="weighted")  # "weighted" accounts for class imbalance
recall = recall_score(y_true, y_pred, average="weighted")
f1 = f1_score(y_true, y_pred, average="weighted")

print(f"Test Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")

Test Accuracy: 0.6493
Precision: 0.6486
Recall: 0.6493
F1 Score: 0.6437


In [40]:
from sklearn.metrics import classification_report
print(classification_report(y_true, y_pred))

              precision    recall  f1-score   support

           0       0.61      0.58      0.60       106
           1       0.67      0.77      0.71       167
           2       0.67      0.47      0.55        72

    accuracy                           0.65       345
   macro avg       0.65      0.61      0.62       345
weighted avg       0.65      0.65      0.64       345

