In [1]:
import nibabel as nib
import numpy as np
import cv2
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import scipy.ndimage
from monai.networks.nets import resnet18
from torch.utils.data import Dataset, DataLoader
from sklearn.utils.class_weight import compute_class_weight

In [2]:
import os

def find_files_with_substring(directory, substring):
    matching_files = [f for f in os.listdir(directory) if substring in f]
    return matching_files

def get_nib_image(adni_file_name):
    return nib.load(adni_file_name).get_fdata()

def visualize_image(nib_image):
    plt.imshow(nib_image[:,:,nib_image.shape[2]//2])
    plt.show()

In [3]:
def get_image_file_names_for_subject(subject_id, date=None):
    os.path.expanduser("~/adni_flat_dataset/adni_flat_dataset")
    dir_ = "/home/rittikar-s/adni_flat_dataset/adni_flat_dataset"
    files = find_files_with_substring(dir_, subject_id)
    if date:
        files = [file for file in files if date in file]
    file_paths = [f"{dir_}/{file}" for file in files]
    return file_paths

In [4]:
import pandas as pd

df = pd.read_csv("ADNI1_Complete_1Yr_1.5T_1_26_2025.csv")

In [5]:
from monai.networks.nets.vitautoenc import ViTAutoEnc

vit_model = ViTAutoEnc(in_channels=1, patch_size=(16,16,16), img_size=(128,128,128))

def get_vit_embedding(img):
    return vit_model(img)

In [6]:
class NiftiDataset(Dataset):
    def __init__(self, image_paths, labels, target_shape=(128, 128, 128)):
        self.image_paths = image_paths
        self.labels = labels
        self.target_shape = target_shape

    def __len__(self):
        return len(self.image_paths)

    def preprocess_nifti(self, nifti_path, threshold_value=80):
        # Load the NIfTI file
        img = nib.load(nifti_path).get_fdata()
    
        # Resize the image to the target shape
        img_resized = scipy.ndimage.zoom(img, np.array(self.target_shape) / np.array(img.shape), order=1)
    
        # Normalize intensity to [0, 1] for neural network
        img_normalized = (img_resized - np.min(img_resized)) / (np.max(img_resized) - np.min(img_resized) + 1e-8)
        img_normalized = (img_normalized * 255).astype(np.uint8)  # You can adjust this based on model input needs
    
        # Create an empty array to store the processed result
        processed_img = np.zeros_like(img_normalized)
    
        # Iterate over each slice in the z-axis (assuming this is a 3D image)
        for z in range(img_normalized.shape[2]):
            slice_img = img_normalized[:, :, z]
            
            # Apply Sharpening to the slice using a kernel
            sharpening_kernel = np.array([[0, -1, 0], [-1, 7, -1], [0, -1, 0]])
            sharpened_slice = cv2.filter2D(slice_img, -1, sharpening_kernel)
    
            # Threshold the low-intensity regions (set them to black)
            thresholded_slice = np.where(sharpened_slice < threshold_value, 0, sharpened_slice)
            
            # Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) to the thresholded image
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(16, 16))
            clahe_slice = clahe.apply(thresholded_slice)
    
            # Store the processed slice back into the 3D image
            processed_img[:, :, z] = clahe_slice
    
        # Convert the processed image to a tensor and add a channel dimension (assuming 1 channel)
        return torch.tensor(processed_img, dtype=torch.float32).unsqueeze(0)  # Add batch dimension
    
    def __getitem__(self, idx):
        image = self.preprocess_nifti(self.image_paths[idx])
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        embedding = get_vit_embedding(image.reshape(1,1,128,128,128))
        return embedding, label

In [7]:
class_to_label = {
    "CN": 0,
    "MCI": 1,
    "AD": 2
}
image_paths = []
labels = []

for i in range(len(df)):
    row = df.iloc[i]
    subject = row["Subject"]
    date = row["Acq Date"]
    date = date.replace("/", "-")
    image_path = get_image_file_names_for_subject(subject, date)[0]
    image_paths.append(image_path)
    labels.append(class_to_label[row["Group"]])

In [8]:
len(image_paths)

2294

In [9]:
len(labels)

2294

In [10]:
from sklearn.model_selection import train_test_split
train_paths, test_paths, train_labels, test_labels = train_test_split(image_paths, labels, test_size=0.3, random_state=42, stratify=labels)
val_paths, test_paths, val_labels, test_labels = train_test_split(test_paths, test_labels, test_size=0.5, random_state=42, stratify=test_labels)

In [11]:
# Create train & test datasets
train_dataset = NiftiDataset(train_paths, train_labels)
val_dataset = NiftiDataset(val_paths, val_labels)
test_dataset = NiftiDataset(test_paths, test_labels)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, pin_memory=True)

print(f"Train Batches: {len(train_loader)}, Val Batches: {len(val_loader)}, Test Batches: {len(test_loader)}")

Train Batches: 402, Val Batches: 86, Test Batches: 87


In [12]:
import h5py

def save_embeddings_hdf5(dataloader, filename):
    """Save embeddings (from list format) and labels incrementally to an HDF5 file."""
    with h5py.File(filename, "w") as f:
        first_batch = True
        for i, (embedding_list, label) in enumerate(dataloader):
            # Extract the last layer embeddings
            embedding_tensor = embedding_list[1][-1]  # Extract final layer embeddings
            embedding_tensor = embedding_tensor.cpu().detach()  # Move to CPU
            
            embedding_numpy = embedding_tensor.numpy()  # Convert to NumPy
            label_numpy = label.cpu().numpy()

            # Reshape embeddings if needed
            embedding_numpy = embedding_numpy.reshape(embedding_numpy.shape[0], -1)  # (4, 1, 512, 768) → (4, 512 * 768)

            if first_batch:
                # Create expandable datasets with correct shape
                f.create_dataset("embeddings", data=embedding_numpy, 
                                 maxshape=(None, embedding_numpy.shape[1]))  # Now 2D
                f.create_dataset("labels", data=label_numpy, maxshape=(None,))
                first_batch = False
            else:
                # Resize and append new embeddings
                f["embeddings"].resize((f["embeddings"].shape[0] + embedding_numpy.shape[0]), axis=0)
                f["embeddings"][-embedding_numpy.shape[0]:] = embedding_numpy

                f["labels"].resize((f["labels"].shape[0] + label_numpy.shape[0]), axis=0)
                f["labels"][-label_numpy.shape[0]:] = label_numpy
            
            print(f"Saved embeddings for batch: {i+1}")

In [13]:
save_embeddings_hdf5(train_loader, "train_embeddings.h5")
save_embeddings_hdf5(val_loader, "val_embeddings.h5")
save_embeddings_hdf5(test_loader, "test_embeddings.h5")

Saved embeddings for batch: 1
Saved embeddings for batch: 2
Saved embeddings for batch: 3
Saved embeddings for batch: 4
Saved embeddings for batch: 5
Saved embeddings for batch: 6
Saved embeddings for batch: 7
Saved embeddings for batch: 8
Saved embeddings for batch: 9
Saved embeddings for batch: 10
Saved embeddings for batch: 11
Saved embeddings for batch: 12
Saved embeddings for batch: 13
Saved embeddings for batch: 14
Saved embeddings for batch: 15
Saved embeddings for batch: 16
Saved embeddings for batch: 17
Saved embeddings for batch: 18
Saved embeddings for batch: 19
Saved embeddings for batch: 20
Saved embeddings for batch: 21
Saved embeddings for batch: 22
Saved embeddings for batch: 23
Saved embeddings for batch: 24
Saved embeddings for batch: 25
Saved embeddings for batch: 26
Saved embeddings for batch: 27
Saved embeddings for batch: 28
Saved embeddings for batch: 29
Saved embeddings for batch: 30
Saved embeddings for batch: 31
Saved embeddings for batch: 32
Saved embeddings 

In [14]:
DEVICE = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

In [15]:
def load_embeddings_hdf5(filename, batch_size=4):
    with h5py.File(filename, "r") as f:
        num_samples = f["embeddings"].shape[0]  # Total samples
        for i in range(0, num_samples, batch_size):
            X_batch = torch.tensor(f["embeddings"][i : i + batch_size], dtype=torch.float32).to(DEVICE)
            y_batch = torch.tensor(f["labels"][i : i + batch_size], dtype=torch.long).to(DEVICE)
            yield X_batch, y_batch

In [16]:
import torch.nn as nn
import torch.optim as optim

class MLPClassifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(MLPClassifier, self).__init__()
        self.fc1 = nn.Linear(input_dim, 1024)
        self.batch_norm1 = nn.BatchNorm1d(1024)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.5)
        
        self.fc2 = nn.Linear(1024, 512)
        self.batch_norm2 = nn.BatchNorm1d(512)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.5)

        self.fc3 = nn.Linear(512, 256)
        self.batch_norm3 = nn.BatchNorm1d(256)
        self.relu3 = nn.ReLU()
        self.dropout3 = nn.Dropout(0.5)
        
        self.fc4 = nn.Linear(256, num_classes)  # Final layer

    def forward(self, x):
        x = self.fc1(x)
        if x.shape[0] > 1:  # Apply BatchNorm only if batch size > 1
            x = self.batch_norm1(x)
        x = self.relu1(x)
        x = self.dropout1(x)
    
        x = self.fc2(x)
        if x.shape[0] > 1:
            x = self.batch_norm2(x)
        x = self.relu2(x)
        x = self.dropout2(x)
    
        x = self.fc3(x)
        if x.shape[0] > 1:
            x = self.batch_norm3(x)
        x = self.relu3(x)
        x = self.dropout3(x)
        
        x = self.fc4(x)
        return x

In [17]:
input_dim = next(load_embeddings_hdf5("train_embeddings.h5"))[0].shape[1]  # Get feature size
num_classes = 3  # Adjust based on labels
model = MLPClassifier(input_dim, num_classes).to(DEVICE)

In [18]:
model

MLPClassifier(
  (fc1): Linear(in_features=393216, out_features=1024, bias=True)
  (batch_norm1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): ReLU()
  (dropout1): Dropout(p=0.5, inplace=False)
  (fc2): Linear(in_features=1024, out_features=512, bias=True)
  (batch_norm2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu2): ReLU()
  (dropout2): Dropout(p=0.5, inplace=False)
  (fc3): Linear(in_features=512, out_features=256, bias=True)
  (batch_norm3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu3): ReLU()
  (dropout3): Dropout(p=0.5, inplace=False)
  (fc4): Linear(in_features=256, out_features=3, bias=True)
)

In [42]:
# Compute class weights
# classes = np.unique(train_labels)
# class_weights = compute_class_weight(class_weight="balanced", classes=classes, y=train_labels)

# criterion = nn.CrossEntropyLoss(weight=torch.tensor(class_weights).to(DEVICE)).to(DEVICE)
# optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [43]:
# if torch.cuda.device_count() > 1:
#     print(f"Using {torch.cuda.device_count()} GPUs!")
#     model = torch.nn.DataParallel(model)

# model = model.to("cuda") 

In [19]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn.functional as F
from tqdm import tqdm  # Progress tracking

NUM_GPUS = 1

# Compute class weights
def compute_class_weights(y_train):
    class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(y_train), y=y_train)
    return torch.tensor(class_weights, dtype=torch.float).to(DEVICE)

y_train = np.concatenate([y.cpu().numpy() for _, y in load_embeddings_hdf5("train_embeddings.h5", batch_size=4)])
class_weights = compute_class_weights(y_train)

# Define optimizer, scheduler, and scaler
optimizer = optim.Adam(model.parameters(), lr=1e-5)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2, verbose=True)
scaler = torch.cuda.amp.GradScaler()

num_epochs = 100
patience = 10  # Early stopping patience
best_val_acc = 0.0
epochs_without_improvement = 0

for epoch in range(num_epochs):
    model.train()
    train_loss, correct, total = 0.0, 0, 0

    # Training loop
    train_loader = load_embeddings_hdf5("train_embeddings.h5", batch_size=4 * NUM_GPUS)  # Adjust batch size
    train_bar = tqdm(train_loader, total=len(y_train) // (4 * NUM_GPUS), desc=f"Epoch {epoch+1}/{num_epochs} Train")
    for X_batch, y_batch in train_bar:
        X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast():  # Mixed precision
            outputs = model(X_batch)
            loss = F.cross_entropy(outputs, y_batch, weight=class_weights)
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        correct += predicted.eq(y_batch).sum().item()
        total += y_batch.size(0)
        
        train_bar.set_postfix(loss=loss.item(), acc=100 * correct / total)

    train_accuracy = 100 * correct / total
    avg_train_loss = train_loss / total

    # Validation loop
    model.eval()
    val_loss, val_correct, val_total = 0.0, 0, 0
    val_loader = load_embeddings_hdf5("val_embeddings.h5", batch_size=4 * NUM_GPUS)
    val_bar = tqdm(val_loader, total=len(y_train) // (4 * NUM_GPUS), desc="Validation")
    with torch.no_grad():
        for X_val, y_val in val_bar:
            X_val, y_val = X_val.to(DEVICE), y_val.to(DEVICE)
            outputs = model(X_val)
            loss = F.cross_entropy(outputs, y_val, weight=class_weights)

            val_loss += loss.item()
            _, predicted = outputs.max(1)
            val_correct += predicted.eq(y_val).sum().item()
            val_total += y_val.size(0)

            val_bar.set_postfix(loss=loss.item(), acc=100 * val_correct / val_total)
    
    val_accuracy = 100 * val_correct / val_total
    avg_val_loss = val_loss / val_total

    # Check for early stopping based on accuracy
    if val_accuracy > best_val_acc:
        best_val_acc = val_accuracy
        epochs_without_improvement = 0  # Reset counter
        torch.save(model.state_dict(), "best_model.pth")  # Save best model
    else:
        epochs_without_improvement += 1

    # Stop if no improvement for 10 epochs
    if epochs_without_improvement >= patience:
        print(f"Early stopping at epoch {epoch+1}. No improvement in validation accuracy for {patience} epochs.")
        break

    # Learning rate scheduler step based on validation accuracy
    scheduler.step(val_accuracy)

    print(f"Epoch [{epoch+1}/{num_epochs}]")
    print(f"  Train Loss: {avg_train_loss:.4f} | Train Acc: {train_accuracy:.2f}%")
    print(f"  Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.2f}%")

The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate.
`torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.                         | 0/401 [00:00<?, ?it/s]
Epoch 1/100 Train: 402it [00:41,  9.60it/s, acc=35.5, loss=86.4]                                                                                            
Validation:  21%|█████████████████▊                                                                 | 86/401 [00:00<00:02, 113.00it/s, acc=48.3, loss=0.992]


Epoch [1/100]
  Train Loss: 0.3447 | Train Acc: 35.45%
  Val Loss: 0.2701 | Val Acc: 48.26%


Epoch 2/100 Train: 402it [00:42,  9.54it/s, acc=38.4, loss=44.3]                                                                                            
Validation:  21%|█████████████████▊                                                                 | 86/401 [00:00<00:02, 122.26it/s, acc=47.7, loss=0.994]


Epoch [2/100]
  Train Loss: 0.3108 | Train Acc: 38.44%
  Val Loss: 0.2680 | Val Acc: 47.67%


Epoch 3/100 Train: 402it [00:42,  9.54it/s, acc=40, loss=148]                                                                                               
Validation:  21%|█████████████████▊                                                                 | 86/401 [00:00<00:02, 117.28it/s, acc=50.9, loss=0.929]


Epoch [3/100]
  Train Loss: 0.3725 | Train Acc: 40.00%
  Val Loss: 0.2659 | Val Acc: 50.87%


Epoch 4/100 Train: 402it [00:42,  9.54it/s, acc=41.4, loss=168]                                                                                             
Validation:  21%|█████████████████▊                                                                 | 86/401 [00:00<00:02, 119.32it/s, acc=51.2, loss=0.922]


Epoch [4/100]
  Train Loss: 0.3794 | Train Acc: 41.37%
  Val Loss: 0.2660 | Val Acc: 51.16%


Epoch 5/100 Train: 402it [00:42,  9.52it/s, acc=44.1, loss=0]                                                                                               
Validation:  21%|█████████████████▊                                                                 | 86/401 [00:00<00:02, 122.96it/s, acc=52.3, loss=0.892]


Epoch [5/100]
  Train Loss: 0.2695 | Train Acc: 44.11%
  Val Loss: 0.2650 | Val Acc: 52.33%


Epoch 6/100 Train: 402it [00:42,  9.52it/s, acc=44.7, loss=29.8]                                                                                            
Validation:  21%|█████████████████▊                                                                 | 86/401 [00:00<00:02, 115.52it/s, acc=53.5, loss=0.885]


Epoch [6/100]
  Train Loss: 0.2857 | Train Acc: 44.67%
  Val Loss: 0.2631 | Val Acc: 53.49%


Epoch 7/100 Train: 402it [00:42,  9.49it/s, acc=46.4, loss=9.54e-7]                                                                                         
Validation:  21%|█████████████████▊                                                                 | 86/401 [00:00<00:02, 117.70it/s, acc=52.3, loss=0.833]


Epoch [7/100]
  Train Loss: 0.2601 | Train Acc: 46.42%
  Val Loss: 0.2636 | Val Acc: 52.33%


Epoch 8/100 Train: 402it [00:42,  9.51it/s, acc=51.1, loss=0]                                                                                               
Validation:  21%|█████████████████▊                                                                 | 86/401 [00:00<00:02, 116.57it/s, acc=52.6, loss=0.816]


Epoch [8/100]
  Train Loss: 0.2542 | Train Acc: 51.09%
  Val Loss: 0.2617 | Val Acc: 52.62%


Epoch 9/100 Train: 402it [00:42,  9.53it/s, acc=49.8, loss=57.9]                                                                                            
Validation:  21%|█████████████████▊                                                                 | 86/401 [00:00<00:02, 119.78it/s, acc=51.2, loss=0.748]


Epoch [9/100]
  Train Loss: 0.2895 | Train Acc: 49.78%
  Val Loss: 0.2619 | Val Acc: 51.16%


Epoch 10/100 Train: 402it [00:42,  9.53it/s, acc=52.9, loss=419]                                                                                            
Validation:  21%|██████████████████▏                                                                  | 86/401 [00:00<00:02, 118.70it/s, acc=52, loss=0.758]


Epoch [10/100]
  Train Loss: 0.5048 | Train Acc: 52.90%
  Val Loss: 0.2626 | Val Acc: 52.03%


Epoch 11/100 Train: 402it [00:42,  9.53it/s, acc=51.2, loss=171]                                                                                            
Validation:  21%|█████████████████▊                                                                 | 86/401 [00:00<00:02, 120.42it/s, acc=52.3, loss=0.797]


Epoch [11/100]
  Train Loss: 0.3544 | Train Acc: 51.21%
  Val Loss: 0.2613 | Val Acc: 52.33%


Epoch 12/100 Train: 402it [00:42,  9.53it/s, acc=51.6, loss=578]                                                                                            
Validation:  21%|█████████████████▊                                                                 | 86/401 [00:00<00:02, 117.21it/s, acc=53.8, loss=0.737]


Epoch [12/100]
  Train Loss: 0.6055 | Train Acc: 51.59%
  Val Loss: 0.2594 | Val Acc: 53.78%


Epoch 13/100 Train: 402it [00:42,  9.54it/s, acc=55, loss=0]                                                                                                
Validation:  21%|█████████████████▊                                                                 | 86/401 [00:00<00:02, 119.88it/s, acc=54.4, loss=0.757]


Epoch [13/100]
  Train Loss: 0.2420 | Train Acc: 55.02%
  Val Loss: 0.2584 | Val Acc: 54.36%


Epoch 14/100 Train: 402it [00:42,  9.52it/s, acc=56.7, loss=159]                                                                                            
Validation:  21%|█████████████████▊                                                                 | 86/401 [00:00<00:02, 115.83it/s, acc=54.1, loss=0.726]


Epoch [14/100]
  Train Loss: 0.3335 | Train Acc: 56.70%
  Val Loss: 0.2576 | Val Acc: 54.07%


Epoch 15/100 Train: 402it [00:42,  9.52it/s, acc=56.1, loss=0]                                                                                              
Validation:  21%|█████████████████▊                                                                 | 86/401 [00:00<00:02, 120.69it/s, acc=54.4, loss=0.726]


Epoch [15/100]
  Train Loss: 0.2335 | Train Acc: 56.14%
  Val Loss: 0.2554 | Val Acc: 54.36%


Epoch 16/100 Train: 402it [00:42,  9.49it/s, acc=56.7, loss=0]                                                                                              
Validation:  21%|█████████████████▊                                                                 | 86/401 [00:00<00:02, 117.56it/s, acc=55.2, loss=0.759]


Epoch [16/100]
  Train Loss: 0.2307 | Train Acc: 56.70%
  Val Loss: 0.2529 | Val Acc: 55.23%


Epoch 17/100 Train: 402it [00:42,  9.50it/s, acc=59.7, loss=0]                                                                                              
Validation:  21%|█████████████████▊                                                                 | 86/401 [00:00<00:02, 118.32it/s, acc=55.2, loss=0.688]


Epoch [17/100]
  Train Loss: 0.2227 | Train Acc: 59.69%
  Val Loss: 0.2538 | Val Acc: 55.23%


Epoch 18/100 Train: 402it [00:42,  9.49it/s, acc=59.9, loss=0]                                                                                              
Validation:  21%|█████████████████▊                                                                 | 86/401 [00:00<00:02, 120.49it/s, acc=55.5, loss=0.732]


Epoch [18/100]
  Train Loss: 0.2245 | Train Acc: 59.94%
  Val Loss: 0.2530 | Val Acc: 55.52%


Epoch 19/100 Train: 402it [00:42,  9.54it/s, acc=62.5, loss=0]                                                                                              
Validation:  21%|█████████████████▊                                                                 | 86/401 [00:00<00:02, 120.17it/s, acc=56.1, loss=0.679]


Epoch [19/100]
  Train Loss: 0.2178 | Train Acc: 62.49%
  Val Loss: 0.2514 | Val Acc: 56.10%


Epoch 20/100 Train: 402it [00:42,  9.55it/s, acc=64.8, loss=238]                                                                                            
Validation:  21%|█████████████████▊                                                                 | 86/401 [00:00<00:02, 116.33it/s, acc=55.8, loss=0.713]


Epoch [20/100]
  Train Loss: 0.3581 | Train Acc: 64.80%
  Val Loss: 0.2516 | Val Acc: 55.81%


Epoch 21/100 Train: 402it [00:42,  9.52it/s, acc=65, loss=0]                                                                                                
Validation:  21%|█████████████████▊                                                                 | 86/401 [00:00<00:02, 118.51it/s, acc=58.4, loss=0.686]


Epoch [21/100]
  Train Loss: 0.2068 | Train Acc: 65.05%
  Val Loss: 0.2485 | Val Acc: 58.43%


Epoch 22/100 Train: 402it [00:42,  9.52it/s, acc=65.2, loss=0]                                                                                              
Validation:  21%|█████████████████▊                                                                 | 86/401 [00:00<00:02, 117.96it/s, acc=58.4, loss=0.697]


Epoch [22/100]
  Train Loss: 0.2070 | Train Acc: 65.17%
  Val Loss: 0.2485 | Val Acc: 58.43%


Epoch 23/100 Train: 402it [00:42,  9.51it/s, acc=66.2, loss=0]                                                                                              
Validation:  21%|██████████████████▏                                                                  | 86/401 [00:00<00:02, 118.37it/s, acc=57, loss=0.598]


Epoch [23/100]
  Train Loss: 0.1984 | Train Acc: 66.17%
  Val Loss: 0.2478 | Val Acc: 56.98%


Epoch 24/100 Train: 402it [00:42,  9.52it/s, acc=67.5, loss=0]                                                                                              
Validation:  21%|█████████████████▊                                                                 | 86/401 [00:00<00:02, 120.44it/s, acc=57.8, loss=0.657]


Epoch [24/100]
  Train Loss: 0.1926 | Train Acc: 67.48%
  Val Loss: 0.2461 | Val Acc: 57.85%


Epoch 25/100 Train: 402it [00:42,  9.52it/s, acc=68.8, loss=525]                                                                                            
Validation:  21%|██████████████████▋                                                                    | 86/401 [00:00<00:02, 119.30it/s, acc=57, loss=0.6]


Epoch [25/100]
  Train Loss: 0.5180 | Train Acc: 68.85%
  Val Loss: 0.2453 | Val Acc: 56.98%


Epoch 26/100 Train: 402it [00:42,  9.51it/s, acc=71.4, loss=0]                                                                                              
Validation:  21%|█████████████████▊                                                                 | 86/401 [00:00<00:02, 120.79it/s, acc=57.3, loss=0.594]


Epoch [26/100]
  Train Loss: 0.1834 | Train Acc: 71.40%
  Val Loss: 0.2482 | Val Acc: 57.27%


Epoch 27/100 Train: 402it [00:42,  9.51it/s, acc=70.3, loss=0]                                                                                              
Validation:  21%|█████████████████▊                                                                 | 86/401 [00:00<00:02, 115.30it/s, acc=57.8, loss=0.624]


Epoch [27/100]
  Train Loss: 0.1822 | Train Acc: 70.28%
  Val Loss: 0.2456 | Val Acc: 57.85%


Epoch 28/100 Train: 402it [00:42,  9.52it/s, acc=71.6, loss=0]                                                                                              
Validation:  21%|█████████████████▊                                                                 | 86/401 [00:00<00:02, 122.21it/s, acc=57.6, loss=0.629]


Epoch [28/100]
  Train Loss: 0.1817 | Train Acc: 71.59%
  Val Loss: 0.2444 | Val Acc: 57.56%


Epoch 29/100 Train: 402it [00:42,  9.52it/s, acc=72.1, loss=0]                                                                                              
Validation:  21%|█████████████████▊                                                                 | 86/401 [00:00<00:02, 122.01it/s, acc=57.3, loss=0.587]


Epoch [29/100]
  Train Loss: 0.1765 | Train Acc: 72.09%
  Val Loss: 0.2442 | Val Acc: 57.27%


Epoch 30/100 Train: 402it [00:42,  9.52it/s, acc=73.8, loss=0]                                                                                              
Validation:  21%|█████████████████▊                                                                 | 86/401 [00:00<00:02, 121.41it/s, acc=57.6, loss=0.589]


Epoch [30/100]
  Train Loss: 0.1743 | Train Acc: 73.77%
  Val Loss: 0.2458 | Val Acc: 57.56%


Epoch 31/100 Train: 402it [00:42,  9.51it/s, acc=73.2, loss=0]                                                                                              
Validation:  21%|█████████████████▊                                                                 | 86/401 [00:00<00:02, 121.66it/s, acc=58.1, loss=0.577]

Early stopping at epoch 31. No improvement in validation accuracy for 10 epochs.





In [20]:
# Inference
y_true, y_pred = [], []

with torch.no_grad():
    for X_batch, y_batch in load_embeddings_hdf5("test_embeddings.h5", batch_size=4):
        outputs = model(X_batch)
        predicted_labels = torch.argmax(outputs, dim=1)

        y_true.extend(y_batch.cpu().numpy())
        y_pred.extend(predicted_labels.cpu().numpy())

# Compute accuracy
from sklearn.metrics import accuracy_score
accuracy = accuracy_score(y_true, y_pred)
print(f"Test Accuracy: {accuracy:.4f}")

Test Accuracy: 0.5507
