In [1]:
import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import scipy.ndimage
from monai.networks.nets import resnet18
from torch.utils.data import Dataset, DataLoader
from sklearn.utils.class_weight import compute_class_weight

In [2]:
def preprocess_nifti(nifti_path, target_shape=(128, 128, 128)):
    # Normalize intensity to [0,1]
    img = (img - np.min(img)) / (np.max(img) - np.min(img) + 1e-8)
    # Resize to target shape: 
    img_resized = scipy.ndimage.zoom(img, np.array(target_shape) / np.array(img.shape), order=1)
    return img_resized

In [3]:
import os

def find_files_with_substring(directory, substring):
    matching_files = [f for f in os.listdir(directory) if substring in f]
    return matching_files

def get_nib_image(adni_file_name):
    return nib.load(adni_file_name).get_fdata()

def visualize_image(nib_image):
    plt.imshow(nib_image[:,:,nib_image.shape[2]//2])
    plt.show()

In [11]:
def get_image_file_names_for_subject(subject_id, date=None):
    os.path.expanduser("~/adni_flat_dataset/adni_flat_dataset")
    dir_ = "/home/rittikar-s/adni_flat_dataset/adni_flat_dataset"
    files = find_files_with_substring(dir_, subject_id)
    if date:
        files = [file for file in files if date in file]
    file_paths = [f"{dir_}/{file}" for file in files]
    return file_paths

In [12]:
import pandas as pd

df = pd.read_csv("ADNI1_Complete_1Yr_1.5T_1_26_2025.csv")

In [13]:
from monai.networks.nets.vitautoenc import ViTAutoEnc

vit_model = ViTAutoEnc(in_channels=1, patch_size=(16,16,16), img_size=(128,128,128))

def get_vit_embedding(img):
    return vit_model(img)

In [14]:
class NiftiDataset(Dataset):
    def __init__(self, image_paths, labels, target_shape=(128, 128, 128)):
        self.image_paths = image_paths
        self.labels = labels
        self.target_shape = target_shape

    def __len__(self):
        return len(self.image_paths)

    def preprocess_nifti(self, nifti_path):
        img = nib.load(nifti_path).get_fdata()
        
        # Normalize intensity to [0,1]
        img = (img - np.min(img)) / (np.max(img) - np.min(img) + 1e-8)
        
        # Resize to target shape
        img_resized = scipy.ndimage.zoom(img, np.array(self.target_shape) / np.array(img.shape), order=1)
        
        return torch.tensor(img_resized, dtype=torch.float32).unsqueeze(0)  # Add channel dim

    def __getitem__(self, idx):
        image = self.preprocess_nifti(self.image_paths[idx])
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        embedding = get_vit_embedding(image.reshape(1,1,128,128,128))
        return embedding, label

In [16]:
class_to_label = {
    "CN": 0,
    "MCI": 1,
    "AD": 2
}
image_paths = []
labels = []

for i in range(len(df)):
    row = df.iloc[i]
    subject = row["Subject"]
    date = row["Acq Date"]
    date = date.replace("/", "-")
    image_path = get_image_file_names_for_subject(subject, date)[0]
    image_paths.append(image_path)
    labels.append(class_to_label[row["Group"]])

In [17]:
len(image_paths)

2294

In [18]:
len(labels)

2294

In [19]:
from sklearn.model_selection import train_test_split
train_paths, test_paths, train_labels, test_labels = train_test_split(image_paths, labels, test_size=0.3, random_state=42, stratify=labels)
val_paths, test_paths, val_labels, test_labels = train_test_split(test_paths, test_labels, test_size=0.5, random_state=42, stratify=test_labels)

In [20]:
# Create train & test datasets
train_dataset = NiftiDataset(train_paths, train_labels)
val_dataset = NiftiDataset(val_paths, val_labels)
test_dataset = NiftiDataset(test_paths, test_labels)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, pin_memory=True)

print(f"Train Batches: {len(train_loader)}, Val Batches: {len(val_loader)}, Test Batches: {len(test_loader)}")

Train Batches: 402, Val Batches: 86, Test Batches: 87


In [21]:
import h5py

def save_embeddings_hdf5(dataloader, filename):
    """Save embeddings (from list format) and labels incrementally to an HDF5 file."""
    with h5py.File(filename, "w") as f:
        first_batch = True
        for i, (embedding_list, label) in enumerate(dataloader):
            # Extract the last layer embeddings
            embedding_tensor = embedding_list[1][-1]  # Extract final layer embeddings
            embedding_tensor = embedding_tensor.cpu().detach()  # Move to CPU
            
            embedding_numpy = embedding_tensor.numpy()  # Convert to NumPy
            label_numpy = label.cpu().numpy()

            # Reshape embeddings if needed
            embedding_numpy = embedding_numpy.reshape(embedding_numpy.shape[0], -1)  # (4, 1, 512, 768) → (4, 512 * 768)

            if first_batch:
                # Create expandable datasets with correct shape
                f.create_dataset("embeddings", data=embedding_numpy, 
                                 maxshape=(None, embedding_numpy.shape[1]))  # Now 2D
                f.create_dataset("labels", data=label_numpy, maxshape=(None,))
                first_batch = False
            else:
                # Resize and append new embeddings
                f["embeddings"].resize((f["embeddings"].shape[0] + embedding_numpy.shape[0]), axis=0)
                f["embeddings"][-embedding_numpy.shape[0]:] = embedding_numpy

                f["labels"].resize((f["labels"].shape[0] + label_numpy.shape[0]), axis=0)
                f["labels"][-label_numpy.shape[0]:] = label_numpy
            
            print(f"Saved embeddings for batch: {i+1}")

In [22]:
save_embeddings_hdf5(train_loader, "train_embeddings.h5")
save_embeddings_hdf5(val_loader, "val_embeddings.h5")
save_embeddings_hdf5(test_loader, "test_embeddings.h5")

Saved embeddings for batch: 1
Saved embeddings for batch: 2
Saved embeddings for batch: 3
Saved embeddings for batch: 4
Saved embeddings for batch: 5
Saved embeddings for batch: 6
Saved embeddings for batch: 7
Saved embeddings for batch: 8
Saved embeddings for batch: 9
Saved embeddings for batch: 10
Saved embeddings for batch: 11
Saved embeddings for batch: 12
Saved embeddings for batch: 13
Saved embeddings for batch: 14
Saved embeddings for batch: 15
Saved embeddings for batch: 16
Saved embeddings for batch: 17
Saved embeddings for batch: 18
Saved embeddings for batch: 19
Saved embeddings for batch: 20
Saved embeddings for batch: 21
Saved embeddings for batch: 22
Saved embeddings for batch: 23
Saved embeddings for batch: 24
Saved embeddings for batch: 25
Saved embeddings for batch: 26
Saved embeddings for batch: 27
Saved embeddings for batch: 28
Saved embeddings for batch: 29
Saved embeddings for batch: 30
Saved embeddings for batch: 31
Saved embeddings for batch: 32
Saved embeddings 

In [23]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [24]:
def load_embeddings_hdf5(filename, batch_size=32):
    with h5py.File(filename, "r") as f:
        num_samples = f["embeddings"].shape[0]  # Total samples
        for i in range(0, num_samples, batch_size):
            X_batch = torch.tensor(f["embeddings"][i : i + batch_size], dtype=torch.float32).to(DEVICE)
            y_batch = torch.tensor(f["labels"][i : i + batch_size], dtype=torch.long).to(DEVICE)
            yield X_batch, y_batch

In [25]:
import torch.nn as nn
import torch.optim as optim

class MLPClassifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(MLPClassifier, self).__init__()
        self.fc1 = nn.Linear(input_dim, 1024)
        self.batch_norm1 = nn.BatchNorm1d(1024)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.5)
        
        self.fc2 = nn.Linear(1024, 512)
        self.batch_norm2 = nn.BatchNorm1d(512)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.5)

        self.fc3 = nn.Linear(512, 256)
        self.batch_norm3 = nn.BatchNorm1d(256)
        self.relu3 = nn.ReLU()
        self.dropout3 = nn.Dropout(0.5)
        
        self.fc4 = nn.Linear(256, num_classes)  # Final layer

    def forward(self, x):
        x = self.fc1(x)
        if x.shape[0] > 1:  # Apply BatchNorm only if batch size > 1
            x = self.batch_norm1(x)
        x = self.relu1(x)
        x = self.dropout1(x)
    
        x = self.fc2(x)
        if x.shape[0] > 1:
            x = self.batch_norm2(x)
        x = self.relu2(x)
        x = self.dropout2(x)
    
        x = self.fc3(x)
        if x.shape[0] > 1:
            x = self.batch_norm3(x)
        x = self.relu3(x)
        x = self.dropout3(x)
        
        x = self.fc4(x)
        return x

In [26]:
input_dim = next(load_embeddings_hdf5("train_embeddings.h5"))[0].shape[1]  # Get feature size
num_classes = 3  # Adjust based on labels
model = MLPClassifier(input_dim, num_classes).to(DEVICE)

In [27]:
model

MLPClassifier(
  (fc1): Linear(in_features=393216, out_features=1024, bias=True)
  (batch_norm1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): ReLU()
  (dropout1): Dropout(p=0.5, inplace=False)
  (fc2): Linear(in_features=1024, out_features=512, bias=True)
  (batch_norm2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu2): ReLU()
  (dropout2): Dropout(p=0.5, inplace=False)
  (fc3): Linear(in_features=512, out_features=256, bias=True)
  (batch_norm3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu3): ReLU()
  (dropout3): Dropout(p=0.5, inplace=False)
  (fc4): Linear(in_features=256, out_features=3, bias=True)
)

In [28]:
# Compute class weights
# classes = np.unique(train_labels)
# class_weights = compute_class_weight(class_weight="balanced", classes=classes, y=train_labels)

# criterion = nn.CrossEntropyLoss(weight=torch.tensor(class_weights).to(DEVICE)).to(DEVICE)
# optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [29]:
# if torch.cuda.device_count() > 1:
#     print(f"Using {torch.cuda.device_count()} GPUs!")
#     model = torch.nn.DataParallel(model)

# model = model.to("cuda") 

In [30]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn.functional as F
from tqdm import tqdm  # Progress tracking

# Compute class weights
def compute_class_weights(y_train):
    class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(y_train), y=y_train)
    return torch.tensor(class_weights, dtype=torch.float).to(DEVICE)

y_train = np.concatenate([y.cpu().numpy() for _, y in load_embeddings_hdf5("train_embeddings.h5", batch_size=32)])
class_weights = compute_class_weights(y_train)

# Define optimizer, scheduler, and scaler
optimizer = optim.Adam(model.parameters(), lr=1e-5)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)
scaler = torch.cuda.amp.GradScaler()

NUM_GPUS = 1
num_epochs = 100
patience = 15  # Early stopping patience
best_val_loss = float("inf")
epochs_without_improvement = 0

for epoch in range(num_epochs):
    model.train()
    train_loss, correct, total = 0.0, 0, 0

    # Training loop
    train_loader = load_embeddings_hdf5("train_embeddings.h5", batch_size=16 * NUM_GPUS)  # Adjust batch size
    train_bar = tqdm(train_loader, total=len(y_train) // (16 * NUM_GPUS), desc=f"Epoch {epoch+1}/{num_epochs} Train")
    for X_batch, y_batch in train_bar:
        X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast():  # Mixed precision
            outputs = model(X_batch)
            loss = F.cross_entropy(outputs, y_batch, weight=class_weights)
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        correct += predicted.eq(y_batch).sum().item()
        total += y_batch.size(0)
        
        train_bar.set_postfix(loss=loss.item(), acc=100 * correct / total)

    train_accuracy = 100 * correct / total
    avg_train_loss = train_loss / total

    # Validation loop
    model.eval()
    val_loss, val_correct, val_total = 0.0, 0, 0
    val_loader = load_embeddings_hdf5("val_embeddings.h5", batch_size=16 * NUM_GPUS)
    val_bar = tqdm(val_loader, total=len(y_train) // (16 * NUM_GPUS), desc="Validation")
    with torch.no_grad():
        for X_val, y_val in val_bar:
            X_val, y_val = X_val.to(DEVICE), y_val.to(DEVICE)
            outputs = model(X_val)
            loss = F.cross_entropy(outputs, y_val, weight=class_weights)

            val_loss += loss.item()
            _, predicted = outputs.max(1)
            val_correct += predicted.eq(y_val).sum().item()
            val_total += y_val.size(0)

            val_bar.set_postfix(loss=loss.item(), acc=100 * val_correct / val_total)
    
    val_accuracy = 100 * val_correct / val_total
    avg_val_loss = val_loss / val_total

    # Check for early stopping
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        epochs_without_improvement = 0  # Reset counter
        torch.save(model.state_dict(), "best_model.pth")  # Save best model
    else:
        epochs_without_improvement += 1

    # Stop if no improvement for 10 epochs
    if epochs_without_improvement >= patience:
        print(f"Early stopping at epoch {epoch+1}. No improvement in validation loss for {patience} epochs.")
        break

    # Learning rate scheduler step
    scheduler.step(avg_val_loss)

    print(f"Epoch [{epoch+1}/{num_epochs}]")
    print(f"  Train Loss: {avg_train_loss:.4f} | Train Acc: {train_accuracy:.2f}%")
    print(f"  Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.2f}%")

The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate.
`torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.                         | 0/100 [00:00<?, ?it/s]
Epoch 1/100 Train: 101it [00:11,  8.56it/s, acc=33.1, loss=1.25]                                                                                            
Validation:  22%|██████████████████▋                                                                  | 22/100 [00:00<00:01, 49.14it/s, acc=40.4, loss=1.04]


Epoch [1/100]
  Train Loss: 0.0737 | Train Acc: 33.08%
  Val Loss: 0.0697 | Val Acc: 40.41%


Epoch 2/100 Train: 101it [00:11,  8.45it/s, acc=35.8, loss=0.83]                                                                                            
Validation:  22%|██████████████████▋                                                                  | 22/100 [00:00<00:01, 50.44it/s, acc=41.3, loss=1.12]


Epoch [2/100]
  Train Loss: 0.0728 | Train Acc: 35.83%
  Val Loss: 0.0690 | Val Acc: 41.28%


Epoch 3/100 Train: 101it [00:11,  8.62it/s, acc=37.9, loss=1.11]                                                                                            
Validation:  22%|██████████████████▋                                                                  | 22/100 [00:00<00:01, 50.83it/s, acc=44.2, loss=1.08]


Epoch [3/100]
  Train Loss: 0.0717 | Train Acc: 37.88%
  Val Loss: 0.0685 | Val Acc: 44.19%


Epoch 4/100 Train: 101it [00:11,  8.64it/s, acc=40.2, loss=1.24]                                                                                            
Validation:  22%|██████████████████▋                                                                  | 22/100 [00:00<00:01, 50.66it/s, acc=44.5, loss=1.06]


Epoch [4/100]
  Train Loss: 0.0709 | Train Acc: 40.19%
  Val Loss: 0.0681 | Val Acc: 44.48%


Epoch 5/100 Train: 101it [00:11,  8.70it/s, acc=40.3, loss=0.966]                                                                                           
Validation:  22%|██████████████████▋                                                                  | 22/100 [00:00<00:01, 50.21it/s, acc=43.9, loss=1.04]


Epoch [5/100]
  Train Loss: 0.0704 | Train Acc: 40.31%
  Val Loss: 0.0683 | Val Acc: 43.90%


Epoch 6/100 Train: 101it [00:11,  8.45it/s, acc=40.1, loss=1.18]                                                                                            
Validation:  22%|██████████████████▋                                                                  | 22/100 [00:00<00:01, 50.17it/s, acc=42.7, loss=1.05]


Epoch [6/100]
  Train Loss: 0.0698 | Train Acc: 40.06%
  Val Loss: 0.0679 | Val Acc: 42.73%


Epoch 7/100 Train: 101it [00:11,  8.57it/s, acc=43.9, loss=0.911]                                                                                           
Validation:  22%|██████████████████▉                                                                   | 22/100 [00:00<00:01, 50.17it/s, acc=45.1, loss=1.1]


Epoch [7/100]
  Train Loss: 0.0689 | Train Acc: 43.93%
  Val Loss: 0.0677 | Val Acc: 45.06%


Epoch 8/100 Train: 101it [00:11,  8.60it/s, acc=41.3, loss=0.941]                                                                                           
Validation:  22%|██████████████████▋                                                                  | 22/100 [00:00<00:01, 51.86it/s, acc=47.4, loss=1.07]


Epoch [8/100]
  Train Loss: 0.0690 | Train Acc: 41.31%
  Val Loss: 0.0675 | Val Acc: 47.38%


Epoch 9/100 Train: 101it [00:11,  8.65it/s, acc=41.9, loss=0.803]                                                                                           
Validation:  22%|███████████████████▏                                                                   | 22/100 [00:00<00:01, 50.75it/s, acc=52, loss=1.04]


Epoch [9/100]
  Train Loss: 0.0689 | Train Acc: 41.87%
  Val Loss: 0.0670 | Val Acc: 52.03%


Epoch 10/100 Train: 101it [00:11,  8.63it/s, acc=40.9, loss=0.982]                                                                                          
Validation:  22%|██████████████████▋                                                                  | 22/100 [00:00<00:01, 50.55it/s, acc=49.1, loss=1.06]


Epoch [10/100]
  Train Loss: 0.0687 | Train Acc: 40.87%
  Val Loss: 0.0672 | Val Acc: 49.13%


Epoch 11/100 Train: 101it [00:11,  8.44it/s, acc=44.8, loss=1.02]                                                                                           
Validation:  22%|██████████████████▋                                                                  | 22/100 [00:00<00:01, 50.89it/s, acc=53.2, loss=1.06]


Epoch [11/100]
  Train Loss: 0.0670 | Train Acc: 44.80%
  Val Loss: 0.0660 | Val Acc: 53.20%


Epoch 12/100 Train: 101it [00:11,  8.65it/s, acc=43.2, loss=1.15]                                                                                           
Validation:  22%|██████████████████▋                                                                  | 22/100 [00:00<00:01, 49.70it/s, acc=49.7, loss=1.07]


Epoch [12/100]
  Train Loss: 0.0680 | Train Acc: 43.18%
  Val Loss: 0.0664 | Val Acc: 49.71%


Epoch 13/100 Train: 101it [00:11,  8.64it/s, acc=45.9, loss=0.749]                                                                                          
Validation:  22%|███████████████████▏                                                                   | 22/100 [00:00<00:01, 50.59it/s, acc=48, loss=1.09]


Epoch [13/100]
  Train Loss: 0.0667 | Train Acc: 45.92%
  Val Loss: 0.0664 | Val Acc: 47.97%


Epoch 14/100 Train: 101it [00:11,  8.63it/s, acc=47, loss=0.839]                                                                                            
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.72it/s, acc=51.5, loss=0.986]


Epoch [14/100]
  Train Loss: 0.0652 | Train Acc: 47.04%
  Val Loss: 0.0654 | Val Acc: 51.45%


Epoch 15/100 Train: 101it [00:11,  8.65it/s, acc=48.5, loss=0.965]                                                                                          
Validation:  22%|███████████████████▎                                                                    | 22/100 [00:00<00:01, 50.41it/s, acc=49.1, loss=1]


Epoch [15/100]
  Train Loss: 0.0648 | Train Acc: 48.47%
  Val Loss: 0.0653 | Val Acc: 49.13%


Epoch 16/100 Train: 101it [00:11,  8.65it/s, acc=47.9, loss=0.776]                                                                                          
Validation:  22%|██████████████████▋                                                                  | 22/100 [00:00<00:01, 50.38it/s, acc=52.6, loss=1.01]


Epoch [16/100]
  Train Loss: 0.0654 | Train Acc: 47.91%
  Val Loss: 0.0648 | Val Acc: 52.62%


Epoch 17/100 Train: 101it [00:11,  8.62it/s, acc=50, loss=0.821]                                                                                            
Validation:  22%|██████████████████▋                                                                  | 22/100 [00:00<00:01, 51.27it/s, acc=53.8, loss=1.08]


Epoch [17/100]
  Train Loss: 0.0640 | Train Acc: 50.03%
  Val Loss: 0.0648 | Val Acc: 53.78%


Epoch 18/100 Train: 101it [00:11,  8.44it/s, acc=50.2, loss=0.85]                                                                                           
Validation:  22%|██████████████████▋                                                                  | 22/100 [00:00<00:01, 51.05it/s, acc=53.5, loss=1.05]


Epoch [18/100]
  Train Loss: 0.0625 | Train Acc: 50.22%
  Val Loss: 0.0645 | Val Acc: 53.49%


Epoch 19/100 Train: 101it [00:11,  8.64it/s, acc=51.3, loss=1.01]                                                                                           
Validation:  22%|███████████████████▏                                                                   | 22/100 [00:00<00:01, 50.95it/s, acc=50, loss=1.03]


Epoch [19/100]
  Train Loss: 0.0622 | Train Acc: 51.34%
  Val Loss: 0.0642 | Val Acc: 50.00%


Epoch 20/100 Train: 101it [00:11,  8.65it/s, acc=54.5, loss=0.758]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.64it/s, acc=52.6, loss=0.898]


Epoch [20/100]
  Train Loss: 0.0604 | Train Acc: 54.45%
  Val Loss: 0.0640 | Val Acc: 52.62%


Epoch 21/100 Train: 101it [00:11,  8.47it/s, acc=55.4, loss=0.858]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.32it/s, acc=51.7, loss=0.955]


Epoch [21/100]
  Train Loss: 0.0598 | Train Acc: 55.39%
  Val Loss: 0.0640 | Val Acc: 51.74%


Epoch 22/100 Train: 101it [00:11,  8.63it/s, acc=55.1, loss=0.768]                                                                                          
Validation:  22%|██████████████████▋                                                                  | 22/100 [00:00<00:01, 50.89it/s, acc=56.4, loss=1.03]


Epoch [22/100]
  Train Loss: 0.0596 | Train Acc: 55.08%
  Val Loss: 0.0634 | Val Acc: 56.40%


Epoch 23/100 Train: 101it [00:11,  8.63it/s, acc=57.4, loss=0.899]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.42it/s, acc=51.7, loss=0.915]


Epoch [23/100]
  Train Loss: 0.0582 | Train Acc: 57.38%
  Val Loss: 0.0631 | Val Acc: 51.74%


Epoch 24/100 Train: 101it [00:11,  8.65it/s, acc=54.4, loss=0.971]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 50.95it/s, acc=55.2, loss=0.915]


Epoch [24/100]
  Train Loss: 0.0587 | Train Acc: 54.39%
  Val Loss: 0.0628 | Val Acc: 55.23%


Epoch 25/100 Train: 101it [00:11,  8.47it/s, acc=57.4, loss=0.889]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 49.52it/s, acc=57.6, loss=0.905]


Epoch [25/100]
  Train Loss: 0.0575 | Train Acc: 57.45%
  Val Loss: 0.0620 | Val Acc: 57.56%


Epoch 26/100 Train: 101it [00:11,  8.62it/s, acc=58.1, loss=0.665]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.04it/s, acc=55.8, loss=0.962]


Epoch [26/100]
  Train Loss: 0.0554 | Train Acc: 58.13%
  Val Loss: 0.0622 | Val Acc: 55.81%


Epoch 27/100 Train: 101it [00:11,  8.64it/s, acc=62.2, loss=0.601]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 50.86it/s, acc=55.5, loss=0.899]


Epoch [27/100]
  Train Loss: 0.0530 | Train Acc: 62.18%
  Val Loss: 0.0612 | Val Acc: 55.52%


Epoch 28/100 Train: 101it [00:11,  8.64it/s, acc=60.1, loss=0.522]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 49.70it/s, acc=54.9, loss=0.938]


Epoch [28/100]
  Train Loss: 0.0529 | Train Acc: 60.06%
  Val Loss: 0.0610 | Val Acc: 54.94%


Epoch 29/100 Train: 101it [00:11,  8.61it/s, acc=63.6, loss=0.599]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.06it/s, acc=55.8, loss=0.886]


Epoch [29/100]
  Train Loss: 0.0513 | Train Acc: 63.55%
  Val Loss: 0.0597 | Val Acc: 55.81%


Epoch 30/100 Train: 101it [00:11,  8.65it/s, acc=64.4, loss=0.794]                                                                                          
Validation:  22%|██████████████████▋                                                                  | 22/100 [00:00<00:01, 50.74it/s, acc=57.8, loss=0.87]


Epoch [30/100]
  Train Loss: 0.0503 | Train Acc: 64.42%
  Val Loss: 0.0599 | Val Acc: 57.85%


Epoch 31/100 Train: 101it [00:11,  8.63it/s, acc=65.4, loss=0.521]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 50.34it/s, acc=58.7, loss=0.895]


Epoch [31/100]
  Train Loss: 0.0495 | Train Acc: 65.42%
  Val Loss: 0.0585 | Val Acc: 58.72%


Epoch 32/100 Train: 101it [00:11,  8.63it/s, acc=67.4, loss=0.654]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.91it/s, acc=60.8, loss=0.774]


Epoch [32/100]
  Train Loss: 0.0482 | Train Acc: 67.41%
  Val Loss: 0.0586 | Val Acc: 60.76%


Epoch 33/100 Train: 101it [00:12,  8.42it/s, acc=66, loss=0.541]                                                                                            
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 49.90it/s, acc=57.3, loss=0.738]


Epoch [33/100]
  Train Loss: 0.0475 | Train Acc: 66.04%
  Val Loss: 0.0593 | Val Acc: 57.27%


Epoch 34/100 Train: 101it [00:11,  8.63it/s, acc=69, loss=0.407]                                                                                            
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 50.67it/s, acc=61.9, loss=0.735]


Epoch [34/100]
  Train Loss: 0.0450 | Train Acc: 69.03%
  Val Loss: 0.0572 | Val Acc: 61.92%


Epoch 35/100 Train: 101it [00:11,  8.64it/s, acc=70.8, loss=0.763]                                                                                          
Validation:  22%|██████████████████▉                                                                   | 22/100 [00:00<00:01, 52.43it/s, acc=61, loss=0.728]


Epoch [35/100]
  Train Loss: 0.0450 | Train Acc: 70.84%
  Val Loss: 0.0571 | Val Acc: 61.05%


Epoch 36/100 Train: 101it [00:11,  8.63it/s, acc=71.8, loss=0.551]                                                                                          
Validation:  22%|██████████████████▉                                                                   | 22/100 [00:00<00:01, 50.70it/s, acc=61, loss=0.779]


Epoch [36/100]
  Train Loss: 0.0432 | Train Acc: 71.78%
  Val Loss: 0.0568 | Val Acc: 61.05%


Epoch 37/100 Train: 101it [00:11,  8.62it/s, acc=74.5, loss=0.803]                                                                                          
Validation:  22%|██████████████████▉                                                                   | 22/100 [00:00<00:01, 50.31it/s, acc=61, loss=0.701]


Epoch [37/100]
  Train Loss: 0.0409 | Train Acc: 74.52%
  Val Loss: 0.0578 | Val Acc: 61.05%


Epoch 38/100 Train: 101it [00:11,  8.64it/s, acc=74.5, loss=0.557]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 50.63it/s, acc=60.2, loss=0.694]


Epoch [38/100]
  Train Loss: 0.0404 | Train Acc: 74.52%
  Val Loss: 0.0570 | Val Acc: 60.17%


Epoch 39/100 Train: 101it [00:11,  8.63it/s, acc=75.9, loss=0.384]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.38it/s, acc=59.9, loss=0.703]


Epoch [39/100]
  Train Loss: 0.0380 | Train Acc: 75.89%
  Val Loss: 0.0578 | Val Acc: 59.88%


Epoch 40/100 Train: 101it [00:11,  8.47it/s, acc=77, loss=0.431]                                                                                            
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 50.12it/s, acc=57.6, loss=0.696]


Epoch [40/100]
  Train Loss: 0.0379 | Train Acc: 77.01%
  Val Loss: 0.0577 | Val Acc: 57.56%


Epoch 41/100 Train: 101it [00:11,  8.63it/s, acc=78.3, loss=0.482]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.67it/s, acc=59.9, loss=0.697]


Epoch [41/100]
  Train Loss: 0.0356 | Train Acc: 78.26%
  Val Loss: 0.0573 | Val Acc: 59.88%


Epoch 42/100 Train: 101it [00:11,  8.65it/s, acc=77.8, loss=0.393]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.05it/s, acc=60.2, loss=0.712]


Epoch [42/100]
  Train Loss: 0.0353 | Train Acc: 77.76%
  Val Loss: 0.0566 | Val Acc: 60.17%


Epoch 43/100 Train: 101it [00:11,  8.64it/s, acc=79.8, loss=0.471]                                                                                          
Validation:  22%|██████████████████▋                                                                  | 22/100 [00:00<00:01, 50.48it/s, acc=58.1, loss=0.72]


Epoch [43/100]
  Train Loss: 0.0335 | Train Acc: 79.81%
  Val Loss: 0.0563 | Val Acc: 58.14%


Epoch 44/100 Train: 101it [00:11,  8.44it/s, acc=81.2, loss=0.333]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 50.59it/s, acc=59.6, loss=0.724]


Epoch [44/100]
  Train Loss: 0.0321 | Train Acc: 81.18%
  Val Loss: 0.0571 | Val Acc: 59.59%


Epoch 45/100 Train: 101it [00:11,  8.63it/s, acc=80.5, loss=0.414]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 50.65it/s, acc=59.6, loss=0.613]


Epoch [45/100]
  Train Loss: 0.0313 | Train Acc: 80.50%
  Val Loss: 0.0549 | Val Acc: 59.59%


Epoch 46/100 Train: 101it [00:11,  8.71it/s, acc=83.6, loss=0.362]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 49.43it/s, acc=60.2, loss=0.581]


Epoch [46/100]
  Train Loss: 0.0300 | Train Acc: 83.55%
  Val Loss: 0.0580 | Val Acc: 60.17%


Epoch 47/100 Train: 101it [00:11,  8.61it/s, acc=83.9, loss=0.298]                                                                                          
Validation:  22%|██████████████████▋                                                                  | 22/100 [00:00<00:01, 51.33it/s, acc=61.9, loss=0.66]


Epoch [47/100]
  Train Loss: 0.0287 | Train Acc: 83.93%
  Val Loss: 0.0563 | Val Acc: 61.92%


Epoch 48/100 Train: 101it [00:11,  8.46it/s, acc=84.7, loss=0.411]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 50.66it/s, acc=58.4, loss=0.662]


Epoch [48/100]
  Train Loss: 0.0277 | Train Acc: 84.67%
  Val Loss: 0.0548 | Val Acc: 58.43%


Epoch 49/100 Train: 101it [00:11,  8.64it/s, acc=85.9, loss=0.237]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 49.73it/s, acc=57.6, loss=0.625]


Epoch [49/100]
  Train Loss: 0.0265 | Train Acc: 85.92%
  Val Loss: 0.0569 | Val Acc: 57.56%


Epoch 50/100 Train: 101it [00:11,  8.63it/s, acc=87.8, loss=0.258]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.99it/s, acc=61.6, loss=0.527]


Epoch [50/100]
  Train Loss: 0.0253 | Train Acc: 87.79%
  Val Loss: 0.0569 | Val Acc: 61.63%


Epoch 51/100 Train: 101it [00:11,  8.63it/s, acc=85.3, loss=0.268]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.25it/s, acc=58.1, loss=0.537]


Epoch [51/100]
  Train Loss: 0.0256 | Train Acc: 85.30%
  Val Loss: 0.0571 | Val Acc: 58.14%


Epoch 52/100 Train: 101it [00:11,  8.46it/s, acc=86.7, loss=0.302]                                                                                          
Validation:  22%|██████████████████▉                                                                   | 22/100 [00:00<00:01, 50.53it/s, acc=61, loss=0.608]


Epoch [52/100]
  Train Loss: 0.0245 | Train Acc: 86.67%
  Val Loss: 0.0544 | Val Acc: 61.05%


Epoch 53/100 Train: 101it [00:11,  8.63it/s, acc=88.2, loss=0.484]                                                                                          
Validation:  22%|██████████████████▉                                                                   | 22/100 [00:00<00:01, 51.87it/s, acc=64, loss=0.551]


Epoch [53/100]
  Train Loss: 0.0238 | Train Acc: 88.22%
  Val Loss: 0.0590 | Val Acc: 63.95%


Epoch 54/100 Train: 101it [00:11,  8.64it/s, acc=90.9, loss=0.209]                                                                                          
Validation:  22%|██████████████████▉                                                                   | 22/100 [00:00<00:01, 50.21it/s, acc=66, loss=0.599]


Epoch [54/100]
  Train Loss: 0.0214 | Train Acc: 90.90%
  Val Loss: 0.0584 | Val Acc: 65.99%


Epoch 55/100 Train: 101it [00:11,  8.66it/s, acc=90.3, loss=0.298]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 50.15it/s, acc=61.9, loss=0.557]


Epoch [55/100]
  Train Loss: 0.0206 | Train Acc: 90.28%
  Val Loss: 0.0571 | Val Acc: 61.92%


Epoch 56/100 Train: 101it [00:11,  8.44it/s, acc=90.6, loss=0.251]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.23it/s, acc=63.4, loss=0.548]


Epoch [56/100]
  Train Loss: 0.0197 | Train Acc: 90.59%
  Val Loss: 0.0566 | Val Acc: 63.37%


Epoch 57/100 Train: 101it [00:11,  8.58it/s, acc=92, loss=0.263]                                                                                            
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 50.62it/s, acc=63.1, loss=0.412]


Epoch [57/100]
  Train Loss: 0.0187 | Train Acc: 92.02%
  Val Loss: 0.0576 | Val Acc: 63.08%


Epoch 58/100 Train: 101it [00:11,  8.62it/s, acc=92, loss=0.307]                                                                                            
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 50.07it/s, acc=62.5, loss=0.422]


Epoch [58/100]
  Train Loss: 0.0184 | Train Acc: 91.96%
  Val Loss: 0.0577 | Val Acc: 62.50%


Epoch 59/100 Train: 101it [00:11,  8.61it/s, acc=93.1, loss=0.127]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.49it/s, acc=64.2, loss=0.633]


Epoch [59/100]
  Train Loss: 0.0167 | Train Acc: 93.08%
  Val Loss: 0.0539 | Val Acc: 64.24%


Epoch 60/100 Train: 101it [00:11,  8.62it/s, acc=94.8, loss=0.253]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.39it/s, acc=64.5, loss=0.538]


Epoch [60/100]
  Train Loss: 0.0148 | Train Acc: 94.83%
  Val Loss: 0.0547 | Val Acc: 64.53%


Epoch 61/100 Train: 101it [00:11,  8.62it/s, acc=94.9, loss=0.235]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 50.40it/s, acc=67.7, loss=0.422]


Epoch [61/100]
  Train Loss: 0.0136 | Train Acc: 94.89%
  Val Loss: 0.0536 | Val Acc: 67.73%


Epoch 62/100 Train: 101it [00:11,  8.61it/s, acc=96.3, loss=0.169]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.63it/s, acc=65.7, loss=0.566]


Epoch [62/100]
  Train Loss: 0.0132 | Train Acc: 96.26%
  Val Loss: 0.0539 | Val Acc: 65.70%


Epoch 63/100 Train: 101it [00:11,  8.64it/s, acc=96.6, loss=0.115]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 50.03it/s, acc=64.5, loss=0.496]


Epoch [63/100]
  Train Loss: 0.0121 | Train Acc: 96.64%
  Val Loss: 0.0545 | Val Acc: 64.53%


Epoch 64/100 Train: 101it [00:11,  8.61it/s, acc=96, loss=0.14]                                                                                             
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 49.77it/s, acc=65.1, loss=0.516]


Epoch [64/100]
  Train Loss: 0.0119 | Train Acc: 96.01%
  Val Loss: 0.0544 | Val Acc: 65.12%


Epoch 65/100 Train: 101it [00:11,  8.62it/s, acc=96.6, loss=0.138]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.43it/s, acc=64.2, loss=0.513]


Epoch [65/100]
  Train Loss: 0.0116 | Train Acc: 96.57%
  Val Loss: 0.0552 | Val Acc: 64.24%


Epoch 66/100 Train: 101it [00:11,  8.69it/s, acc=96.6, loss=0.228]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.19it/s, acc=63.7, loss=0.617]


Epoch [66/100]
  Train Loss: 0.0118 | Train Acc: 96.57%
  Val Loss: 0.0542 | Val Acc: 63.66%


Epoch 67/100 Train: 101it [00:11,  8.63it/s, acc=96.9, loss=0.273]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 49.45it/s, acc=65.4, loss=0.473]


Epoch [67/100]
  Train Loss: 0.0106 | Train Acc: 96.88%
  Val Loss: 0.0548 | Val Acc: 65.41%


Epoch 68/100 Train: 101it [00:11,  8.61it/s, acc=97.9, loss=0.248]                                                                                          
Validation:  22%|███████████████████▏                                                                   | 22/100 [00:00<00:01, 51.44it/s, acc=66, loss=0.52]


Epoch [68/100]
  Train Loss: 0.0097 | Train Acc: 97.94%
  Val Loss: 0.0542 | Val Acc: 65.99%


Epoch 69/100 Train: 101it [00:11,  8.63it/s, acc=97.9, loss=0.114]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 50.85it/s, acc=65.4, loss=0.531]


Epoch [69/100]
  Train Loss: 0.0096 | Train Acc: 97.94%
  Val Loss: 0.0534 | Val Acc: 65.41%


Epoch 70/100 Train: 101it [00:11,  8.65it/s, acc=98.1, loss=0.178]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 50.61it/s, acc=64.8, loss=0.582]


Epoch [70/100]
  Train Loss: 0.0090 | Train Acc: 98.13%
  Val Loss: 0.0532 | Val Acc: 64.83%


Epoch 71/100 Train: 101it [00:11,  8.63it/s, acc=97.9, loss=0.142]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.99it/s, acc=64.8, loss=0.555]


Epoch [71/100]
  Train Loss: 0.0088 | Train Acc: 97.94%
  Val Loss: 0.0542 | Val Acc: 64.83%


Epoch 72/100 Train: 101it [00:11,  8.64it/s, acc=97.7, loss=0.157]                                                                                          
Validation:  22%|██████████████████▉                                                                   | 22/100 [00:00<00:01, 51.45it/s, acc=66, loss=0.483]


Epoch [72/100]
  Train Loss: 0.0091 | Train Acc: 97.69%
  Val Loss: 0.0542 | Val Acc: 65.99%


Epoch 73/100 Train: 101it [00:11,  8.63it/s, acc=98.6, loss=0.151]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 50.81it/s, acc=65.7, loss=0.471]


Epoch [73/100]
  Train Loss: 0.0083 | Train Acc: 98.57%
  Val Loss: 0.0544 | Val Acc: 65.70%


Epoch 74/100 Train: 101it [00:11,  8.63it/s, acc=98.2, loss=0.169]                                                                                          
Validation:  22%|██████████████████▉                                                                   | 22/100 [00:00<00:01, 52.61it/s, acc=66, loss=0.461]


Epoch [74/100]
  Train Loss: 0.0085 | Train Acc: 98.19%
  Val Loss: 0.0556 | Val Acc: 65.99%


Epoch 75/100 Train: 101it [00:11,  8.65it/s, acc=98.9, loss=0.104]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.48it/s, acc=65.4, loss=0.516]


Epoch [75/100]
  Train Loss: 0.0078 | Train Acc: 98.88%
  Val Loss: 0.0531 | Val Acc: 65.41%


Epoch 76/100 Train: 101it [00:11,  8.61it/s, acc=98.9, loss=0.181]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 50.34it/s, acc=64.5, loss=0.529]


Epoch [76/100]
  Train Loss: 0.0079 | Train Acc: 98.88%
  Val Loss: 0.0536 | Val Acc: 64.53%


Epoch 77/100 Train: 101it [00:11,  8.61it/s, acc=99.1, loss=0.206]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.17it/s, acc=64.8, loss=0.496]


Epoch [77/100]
  Train Loss: 0.0078 | Train Acc: 99.07%
  Val Loss: 0.0542 | Val Acc: 64.83%


Epoch 78/100 Train: 101it [00:11,  8.62it/s, acc=98.9, loss=0.0673]                                                                                         
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 50.31it/s, acc=67.2, loss=0.439]


Epoch [78/100]
  Train Loss: 0.0075 | Train Acc: 98.94%
  Val Loss: 0.0544 | Val Acc: 67.15%


Epoch 79/100 Train: 101it [00:11,  8.63it/s, acc=99.2, loss=0.145]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 50.19it/s, acc=65.1, loss=0.449]


Epoch [79/100]
  Train Loss: 0.0072 | Train Acc: 99.19%
  Val Loss: 0.0541 | Val Acc: 65.12%


Epoch 80/100 Train: 101it [00:11,  8.63it/s, acc=99.2, loss=0.16]                                                                                           
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 52.52it/s, acc=64.5, loss=0.475]


Epoch [80/100]
  Train Loss: 0.0073 | Train Acc: 99.19%
  Val Loss: 0.0548 | Val Acc: 64.53%


Epoch 81/100 Train: 101it [00:11,  8.62it/s, acc=99.1, loss=0.175]                                                                                          
Validation:  22%|██████████████████▉                                                                   | 22/100 [00:00<00:01, 50.88it/s, acc=65.1, loss=0.6]


Epoch [81/100]
  Train Loss: 0.0071 | Train Acc: 99.07%
  Val Loss: 0.0545 | Val Acc: 65.12%


Epoch 82/100 Train: 101it [00:11,  8.63it/s, acc=99.1, loss=0.348]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 50.86it/s, acc=66.3, loss=0.481]


Epoch [82/100]
  Train Loss: 0.0070 | Train Acc: 99.13%
  Val Loss: 0.0549 | Val Acc: 66.28%


Epoch 83/100 Train: 101it [00:11,  8.60it/s, acc=99.4, loss=0.223]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.70it/s, acc=65.7, loss=0.507]


Epoch [83/100]
  Train Loss: 0.0066 | Train Acc: 99.38%
  Val Loss: 0.0539 | Val Acc: 65.70%


Epoch 84/100 Train: 101it [00:11,  8.64it/s, acc=99.3, loss=0.195]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.70it/s, acc=65.4, loss=0.503]


Epoch [84/100]
  Train Loss: 0.0069 | Train Acc: 99.25%
  Val Loss: 0.0540 | Val Acc: 65.41%


Epoch 85/100 Train: 101it [00:11,  8.66it/s, acc=99.4, loss=0.0901]                                                                                         
Validation:  22%|██████████████████▉                                                                   | 22/100 [00:00<00:01, 50.51it/s, acc=66, loss=0.526]


Epoch [85/100]
  Train Loss: 0.0065 | Train Acc: 99.38%
  Val Loss: 0.0542 | Val Acc: 65.99%


Epoch 86/100 Train: 101it [00:11,  8.62it/s, acc=99.3, loss=0.0878]                                                                                         
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 51.34it/s, acc=65.7, loss=0.541]


Epoch [86/100]
  Train Loss: 0.0064 | Train Acc: 99.31%
  Val Loss: 0.0533 | Val Acc: 65.70%


Epoch 87/100 Train: 101it [00:11,  8.63it/s, acc=99.6, loss=0.0767]                                                                                         
Validation:  22%|██████████████████▉                                                                   | 22/100 [00:00<00:01, 51.51it/s, acc=66, loss=0.479]


Epoch [87/100]
  Train Loss: 0.0062 | Train Acc: 99.63%
  Val Loss: 0.0537 | Val Acc: 65.99%


Epoch 88/100 Train: 101it [00:11,  8.64it/s, acc=99.4, loss=0.172]                                                                                          
Validation:  22%|██████████████████▍                                                                 | 22/100 [00:00<00:01, 49.70it/s, acc=66.6, loss=0.489]


Epoch [88/100]
  Train Loss: 0.0064 | Train Acc: 99.38%
  Val Loss: 0.0543 | Val Acc: 66.57%


Epoch 89/100 Train: 101it [00:11,  8.62it/s, acc=99.3, loss=0.14]                                                                                           
Validation:  22%|██████████████████▉                                                                   | 22/100 [00:00<00:01, 52.31it/s, acc=66, loss=0.513]


Epoch [89/100]
  Train Loss: 0.0059 | Train Acc: 99.31%
  Val Loss: 0.0544 | Val Acc: 65.99%


Epoch 90/100 Train: 101it [00:11,  8.65it/s, acc=99.6, loss=0.082]                                                                                          
Validation:  22%|██████████████████▋                                                                  | 22/100 [00:00<00:01, 51.53it/s, acc=66.3, loss=0.49]

Early stopping at epoch 90. No improvement in validation loss for 15 epochs.





In [31]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Inference
y_true, y_pred = [], []

with torch.no_grad():
    for X_batch, y_batch in load_embeddings_hdf5("test_embeddings.h5", batch_size=4):
        outputs = model(X_batch)
        predicted_labels = torch.argmax(outputs, dim=1)

        y_true.extend(y_batch.cpu().numpy())
        y_pred.extend(predicted_labels.cpu().numpy())

# Compute metrics
accuracy = accuracy_score(y_true, y_pred)
precision = precision_score(y_true, y_pred, average="weighted")  # "weighted" accounts for class imbalance
recall = recall_score(y_true, y_pred, average="weighted")
f1 = f1_score(y_true, y_pred, average="weighted")

print(f"Test Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")

Test Accuracy: 0.6029
Precision: 0.5990
Recall: 0.6029
F1 Score: 0.5983


In [32]:
from sklearn.metrics import classification_report
print(classification_report(y_true, y_pred))

              precision    recall  f1-score   support

           0       0.60      0.50      0.55       106
           1       0.65      0.73      0.69       167
           2       0.48      0.46      0.47        72

    accuracy                           0.60       345
   macro avg       0.58      0.56      0.57       345
weighted avg       0.60      0.60      0.60       345

