<a href="https://colab.research.google.com/github/uoacapstonegroup6/CapstoneUOATeam6/blob/main/Capstone_MedVit_SynapseMNIST3d_FinalCopy.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the MedMNIST library, a large-scale benchmark dataset for medical image classification tasks

!pip install medmnist



In [None]:
%pwd

'/teamspace/studios/this_studio'

# Import Library

In [None]:
!pip install torchsummary



In [None]:
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
import sys
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.optim as optim

import torchvision.utils
from torchvision import models
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torchsummary import summary

from tqdm import tqdm
import numpy as np

import torch.utils.data as data
import torchvision.transforms as transforms
from sklearn.metrics import precision_score, recall_score, accuracy_score, f1_score

import medmnist
from medmnist import INFO, Evaluator

from PIL import Image

# Download Data

In [None]:
# data_flag = 'vesselmnist3d'
data_flag='synapsemnist3d'
# data_flag = 'fracturemnist3d'

In [None]:

# Look up the metadata of the selected MedMNIST dataset.
info = INFO[data_flag]
task, n_channels = info['task'], info['n_channels']
n_classes = len(info['label'])

# Resolve the dataset class by name, then build one dataset object per split
# (downloading the data with size=64 when it is not cached yet).
DataClass = getattr(medmnist, info['python_class'])
train_dataset, val_dataset, test_dataset = [
    DataClass(split=split_name, download=True, size=64)
    for split_name in ('train', 'val', 'test')
]



Using downloaded and verified file: /teamspace/studios/this_studio/.medmnist/synapsemnist3d_64.npz
Using downloaded and verified file: /teamspace/studios/this_studio/.medmnist/synapsemnist3d_64.npz
Using downloaded and verified file: /teamspace/studios/this_studio/.medmnist/synapsemnist3d_64.npz


In [None]:
# Directories for saving 2D frames.
# The paths are made absolute because the slice files are opened lazily by the
# data loaders, i.e. after later cells change the working directory with %cd.
train_dir = os.path.abspath(f'./{data_flag}/train')
valid_dir = os.path.abspath(f'./{data_flag}/valid')
test_dir = os.path.abspath(f'./{data_flag}/test')

In [None]:
# Create the output directory of every split.
# The loop variable is not named `dir`, so the builtin dir() stays usable at
# notebook scope, and exist_ok=True makes the cell safe to re-run.
for split_dir in [train_dir, valid_dir, test_dir]:
    os.makedirs(split_dir, exist_ok=True)

# Extract 2D slices from 3D data and save to local drive
def extract_and_save_2d_slices(dataset, dir):
    """Save every 2D slice of every 3D volume in ``dataset`` as a PNG file.

    Files are written to ``<dir>/<label>/<volume index>_<slice index>.png``,
    where ``<label>`` is ``str(dataset.labels[idx])`` and slices are taken
    along the first axis of each volume.

    Args:
        dataset: object exposing ``imgs`` (iterable of 3D arrays) and
            ``labels`` (indexable by volume index).
        dir: root output directory for this split.
    """
    for idx, img in enumerate(dataset.imgs):
        # The label folder is the same for all slices of a volume, so it is
        # created once per volume instead of being checked for every slice.
        label_dir = os.path.join(dir, str(dataset.labels[idx]))
        os.makedirs(label_dir, exist_ok=True)
        for i in range(img.shape[0]):
            slice_2d_img = Image.fromarray(img[i, :, :])
            slice_2d_img.save(os.path.join(label_dir, f"{idx}_{i}.png"))

# Write the slices of each split into its own directory (train, then val, then test).
for split_dataset, split_output_dir in (
    (train_dataset, train_dir),
    (val_dataset, valid_dir),
    (test_dataset, test_dir),
):
    extract_and_save_2d_slices(split_dataset, split_output_dir)

# Data Preprocessing

In [None]:
x, y = test_dataset[0]

print(x.shape, y.shape)

(1, 64, 64, 64) (1,)


In [None]:
print(x[0][0])

[[0.67843137 0.5372549  0.58431373 ... 0.89803922 0.61960784 0.68627451]
 [0.50196078 0.39215686 0.43921569 ... 0.87843137 0.72941176 0.75686275]
 [0.34901961 0.23921569 0.49803922 ... 0.75294118 0.83529412 0.65098039]
 ...
 [0.47843137 0.30196078 0.50196078 ... 0.28627451 0.09803922 0.12941176]
 [0.23921569 0.32156863 0.24313725 ... 0.50588235 0.26666667 0.05882353]
 [0.4        0.21568627 0.17647059 ... 0.74117647 0.40392157 0.10196078]]


In [None]:
os.listdir(test_dir)

['[0]', '[1]']

In [None]:
label_dir = os.path.join(test_dir, '[0]')

In [None]:
print(label_dir)

./synapsemnist3d/test/[0]


In [None]:
# Show only a short, sorted sample: the full listing is very long.
sorted(os.listdir(label_dir))[:10]

['257_0.png',
 '257_1.png',
 '257_10.png',
 '257_11.png',
 '257_12.png',
 '257_13.png',
 '257_14.png',
 '257_15.png',
 '257_16.png',
 '257_17.png',
 '257_18.png',
 '257_19.png',
 '257_2.png',
 '257_20.png',
 '257_21.png',
 '257_22.png',
 '257_23.png',
 '257_24.png',
 '257_25.png',
 '257_26.png',
 '257_27.png',
 '257_28.png',
 '257_29.png',
 '257_3.png',
 '257_30.png',
 '257_31.png',
 '257_32.png',
 '257_33.png',
 '257_34.png',
 '257_35.png',
 '257_36.png',
 '257_37.png',
 '257_38.png',
 '257_39.png',
 '257_4.png',
 '257_40.png',
 '257_41.png',
 '257_42.png',
 '257_43.png',
 '257_44.png',
 '257_45.png',
 '257_46.png',
 '257_47.png',
 '257_48.png',
 '257_49.png',
 '257_5.png',
 '257_50.png',
 '257_51.png',
 '257_52.png',
 '257_53.png',
 '257_54.png',
 '257_55.png',
 '257_56.png',
 '257_57.png',
 '257_58.png',
 '257_59.png',
 '257_6.png',
 '257_60.png',
 '257_61.png',
 '257_62.png',
 '257_63.png',
 '257_7.png',
 '257_8.png',
 '257_9.png',
 '258_0.png',
 '258_1.png',
 '258_10.png',
 '258_1

In [None]:
# Create a custom dataset class to handle 2D slices
class SliceDataset(data.Dataset):
    """Dataset of 2D slice images stored as ``<dir>/<label>/<image id>_<slice index>.png``.

    Label folders are named like ``[0]``; the brackets are stripped and the rest
    is parsed as the integer class label. Samples are ordered by label folder,
    then numerically by image ID and slice index, so all slices of one 3D image
    are contiguous and in slice order.

    Args:
        dir: root directory of one split.
        transform: optional callable applied to the opened PIL image.

    Each item is a tuple ``(image, label, image_id)`` where ``image_id`` is the
    part of the filename before the first underscore (a string).
    """

    def __init__(self, dir, transform=None):
        self.dir = dir
        self.transform = transform
        self.labels = []
        self.filenames = []
        self.image_ids = []

        # Sorted so the sample order does not depend on the filesystem's listing order.
        for label in sorted(os.listdir(dir)):
            label_dir = os.path.join(dir, label)
            # Skip stray files and hidden folders (e.g. notebook checkpoint directories),
            # which are not label folders and would fail the int() conversion below.
            if label.startswith('.') or not os.path.isdir(label_dir):
                continue
            label_value = int(label.strip('[]'))
            # A plain string sort would give 0, 1, 10, 11, ... 2; the key sorts numerically.
            for filename in sorted(os.listdir(label_dir), key=self._slice_sort_key):
                image_id = filename.split('_')[0]  # Extract 3D image ID (e.g., '12' from '12_0.png')
                self.labels.append(label_value)
                self.filenames.append(os.path.join(label_dir, filename))
                self.image_ids.append(image_id)  # Store the image ID for tracking

    @staticmethod
    def _slice_sort_key(filename):
        """Return a key ordering ``<image id>_<slice index>.<ext>`` names numerically.

        Names that do not match the pattern sort after all matching names, in
        plain string order.
        """
        stem = os.path.splitext(filename)[0]
        image_id, _, slice_index = stem.partition('_')
        try:
            return (0, int(image_id), int(slice_index))
        except ValueError:
            return (1, filename)

    def __len__(self):
        """Return the number of 2D slices in the dataset."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Open slice ``idx`` and return ``(image, label, image_id)``."""
        filename = self.filenames[idx]
        img = Image.open(filename)
        if self.transform:
            img = self.transform(img)
        return img, self.labels[idx], self.image_ids[idx]  # Return image ID for tracking

# # Create data loaders for the 2D slices
# data_transform = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[.5], std=[.5])
# ])
# preprocessing
def _build_slice_transform():
    """Return the slice pipeline: Resize(224), convert to RGB, convert to tensor."""
    steps = [
        transforms.Resize(224),
        transforms.Lambda(lambda image: image.convert('RGB')),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)


# Training and evaluation use the same steps; AugMix augmentation and
# Normalize(mean=[.5], std=[.5]) were tried in this notebook but are disabled.
train_transform = _build_slice_transform()
test_transform = _build_slice_transform()

train_slice_dataset = SliceDataset(train_dir, transform=train_transform)
valid_slice_dataset = SliceDataset(valid_dir, transform=test_transform)
test_slice_dataset = SliceDataset(test_dir, transform=test_transform)

In [None]:
# Custom collate function: a batch must contain slices of a single 3D image
def collate_fn(batch):
    """Stack the images of a batch after checking they share one 3D image ID.

    Args:
        batch: list of ``(image_tensor, label, image_id)`` samples.

    Returns:
        Tuple ``(images, labels)``: the stacked image tensor and the labels as a
        plain Python list.

    Raises:
        ValueError: if any sample's image ID differs from the first sample's.
    """
    current_image_id = batch[0][2]  # ID of the first sample is the reference

    # Print the 3D image ID being processed
    print(f"Processing 3D image ID: {current_image_id}")

    # Validate the whole batch first, then build the outputs in one go.
    for _, _, image_id in batch:
        if image_id != current_image_id:
            raise ValueError(f"Mixed slices in batch. Expected image ID {current_image_id}, but got {image_id}")

    stacked_images = torch.stack([sample[0] for sample in batch])
    labels = [sample[1] for sample in batch]
    return stacked_images, labels

In [None]:
# Hyperparameters
lr = 0.001
NUM_EPOCHS = 10
# One batch holds 64 slices, the number of slices saved per 3D image; the
# unshuffled test loader relies on this to deliver one 3D image per batch.
BATCH_SIZE = 64

# Training and validation draw shuffled slice batches.
train_loader = data.DataLoader(train_slice_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = data.DataLoader(valid_slice_dataset, batch_size=BATCH_SIZE, shuffle=True)
# The test loader keeps file order and uses collate_fn, which rejects batches
# that mix slices of different 3D images.
test_loader = data.DataLoader(
    test_slice_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
)

# MedViT Training

In [None]:
# Install MedVit
!git clone https://github.com/Omid-Nejati/MedViT.git

In [None]:
# Change directory to MedVit
%cd ./MedViT

/home/exouser/Documents/RRR_MedVit/MedViT


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [None]:
# Check Folder Content
%ls

[0m[01;32mColab_MedViT.ipynb[0m*  [01;32mLICENSE[0m*      [01;34mfracturemnist3d[0m/   [01;32mutils.py[0m*
[01;34mCustomDataset[0m/       [01;32mMedViT.py[0m*    [01;34mimages[0m/
[01;32mCustomDataset.md[0m*    [01;32mREADME.md[0m*    [01;32mrequirements.txt[0m*
[01;32mInstructions.ipynb[0m*  [01;34m__pycache__[0m/  [01;34msynapsemnist3d[0m/


In [None]:
# Install Dependencies
pip install -r requirements.txt

Note: you may need to restart the kernel to use updated packages.


In [None]:
# Detect GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [None]:
# Load MedVit Base model
from MedViT import MedViT_small as tiny, MedViT_base as base
model = base()

initialize_weights...


In [None]:
# Move the model to the GPU
model = model.to(device)

  return torch._C._cuda_getDeviceCount() > 0


In [None]:
# Replace the classification head so it outputs one score per class of the selected dataset.
# - in_features is read from the layer being replaced, so it always matches the
#   width of the features fed into the head (assumes proj_head[0] is an nn.Linear).
# - out_features follows n_classes instead of a hard-coded 2.
# - the new layer is moved to `device` because the model was already moved
#   there; a freshly constructed layer would otherwise stay on the CPU.
head_in_features = model.proj_head[0].in_features
model.proj_head[0] = torch.nn.Linear(
    in_features=head_in_features, out_features=n_classes, bias=True
).to(device)

In [None]:
print(task)

binary-class


In [None]:
# Loss: BCE-with-logits for the multi-label task, cross-entropy for every other task.
criterion = (
    nn.BCEWithLogitsLoss()
    if task == "multi-label, binary-class"
    else nn.CrossEntropyLoss()
)

# Optimizer: SGD with momentum over all model parameters.
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

In [None]:
# train
for epoch in range(NUM_EPOCHS):
    print('Epoch [%d/%d]'% (epoch+1, NUM_EPOCHS))
    model.train()
    running_loss = 0.0
    num_batches = 0
    # The slice dataset yields (image, label, image_id); fields after the label
    # are not needed for training and are discarded by the starred target.
    for inputs, targets, *_ in tqdm(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        # forward + backward + optimize
        optimizer.zero_grad()
        outputs = model(inputs)

        if task == 'multi-label, binary-class':
            targets = targets.to(torch.float32)
        else:
            # reshape(-1) keeps the batch dimension even for a one-sample batch,
            # where squeeze() would produce a 0-dim tensor.
            targets = targets.reshape(-1).long()
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    if num_batches:
        print('Mean training loss: %.4f' % (running_loss / num_batches))

Epoch [1/10]


100%|███████████████████████████████████████| 1230/1230 [07:41<00:00,  2.66it/s]


Epoch [2/10]


100%|███████████████████████████████████████| 1230/1230 [07:29<00:00,  2.73it/s]


Epoch [3/10]


100%|███████████████████████████████████████| 1230/1230 [07:20<00:00,  2.79it/s]


Epoch [4/10]


100%|███████████████████████████████████████| 1230/1230 [07:20<00:00,  2.79it/s]


Epoch [5/10]


100%|███████████████████████████████████████| 1230/1230 [07:20<00:00,  2.79it/s]


Epoch [6/10]


100%|███████████████████████████████████████| 1230/1230 [07:20<00:00,  2.79it/s]


Epoch [7/10]


100%|███████████████████████████████████████| 1230/1230 [07:20<00:00,  2.79it/s]


Epoch [8/10]


100%|███████████████████████████████████████| 1230/1230 [07:20<00:00,  2.79it/s]


Epoch [9/10]


100%|███████████████████████████████████████| 1230/1230 [07:20<00:00,  2.79it/s]


Epoch [10/10]


100%|███████████████████████████████████████| 1230/1230 [07:20<00:00,  2.79it/s]


In [None]:
torch.save(model.state_dict(), 'medvit_synapsemninst3d_base_weights.pth')

In [None]:
# Save the entire model
torch.save(model, 'medvit_synapsemninst3d_base_model.pth')

# MedViT Test Accuracy

In [None]:
%cd ./MedViT

/teamspace/studios/this_studio/MedViT


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [None]:
%pwd

'/teamspace/studios/this_studio/MedViT'

In [None]:
pip install -r requirements.txt

Note: you may need to restart the kernel to use updated packages.


In [None]:
%ls

[0m[01;32mColab_MedViT.ipynb[0m*  [01;32mLICENSE[0m*      [01;34mfracturemnist3d[0m/   [01;32mutils.py[0m*
[01;34mCustomDataset[0m/       [01;32mMedViT.py[0m*    [01;34mimages[0m/
[01;32mCustomDataset.md[0m*    [01;32mREADME.md[0m*    [01;32mrequirements.txt[0m*
[01;32mInstructions.ipynb[0m*  [01;34m__pycache__[0m/  [01;34msynapsemnist3d[0m/


In [None]:
# Load the saved model.
# - map_location=device places the weights on the GPU when one is available
#   and on the CPU otherwise, so no separate CPU variant of this call is needed.
# - weights_only=False is needed because the file contains a fully pickled
#   model object (written with torch.save(model, ...)), not a state_dict.
# NOTE(review): unpickling a model executes arbitrary code; only load
# checkpoints that come from a trusted source.
model = torch.load(
    '/teamspace/studios/this_studio/SB_synapse_medMedViT/medvit_synapsemninst3d_base_model.pth',
    map_location=device,
    weights_only=False,
)


In [None]:

# Slice-level evaluation on the test set
model.eval()  # Set model to evaluation mode
is_multi_label = task == "multi-label, binary-class"

total_correct = 0
all_labels = []
all_preds = []
with torch.no_grad():
    # The loop variable is deliberately not called `data`: that name is the
    # alias of torch.utils.data imported at the top of the notebook.
    for batch in test_loader:
        # Works for loaders yielding (images, labels) or (images, labels, image_ids).
        inputs, labels = batch[0], batch[1]
        inputs = inputs.to(device)
        # as_tensor accepts both tensors and the plain list of labels that a
        # custom collate function may return.
        labels = torch.as_tensor(labels).to(device)

        outputs = model(inputs)
        if is_multi_label:
            # Sigmoid + 0.5 threshold gives one independent decision per class.
            predicted = (torch.sigmoid(outputs) > 0.5).int()
        else:
            labels = labels.reshape(-1)
            _, predicted = torch.max(outputs, 1)  # Get the predicted class

        total_correct += (predicted == labels).sum().item()
        # Collect all labels and predictions for metric calculation
        all_labels.extend(labels.cpu().numpy().flatten())
        all_preds.extend(predicted.cpu().numpy().flatten())

# One entry per evaluated decision (per slice, or per slice and class for multi-label).
accuracy = total_correct / len(all_labels) * 100
print(f'Test Accuracy: {accuracy:.2f}%')

# Calculate Precision, Recall, and F1 Score using sklearn
precision = precision_score(all_labels, all_preds, average='macro')
recall = recall_score(all_labels, all_preds, average='macro')
f1 = f1_score(all_labels, all_preds, average='macro')

print(f'Precision: {precision:.2f}')
print(f'Recall: {recall:.2f}')
print(f'F1 Score: {f1:.2f}')

Test Accuracy: 80.08%
Precision: 0.76
Recall: 0.69
F1 Score: 0.71


# Voting by Training XGB for Synapse3D

In [None]:
import os
import torch
import numpy as np
import joblib
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import xgboost as xgb

# Paths to save the features and labels
SAVE_DIR = './XGBfeatures'
os.makedirs(SAVE_DIR, exist_ok=True)  # Create directory if it doesn't exist

TRAIN_FEATURES_PATH = os.path.join(SAVE_DIR, 'train_features.npy')
TRAIN_LABELS_PATH = os.path.join(SAVE_DIR, 'train_labels.npy')
TEST_FEATURES_PATH = os.path.join(SAVE_DIR, 'test_features.npy')
TEST_LABELS_PATH = os.path.join(SAVE_DIR, 'test_labels.npy')

# Function to extract and save features and labels
def extract_and_save_features_and_labels(model, loader, device, features_path, labels_path):
    """Run ``model`` over ``loader`` and save its outputs and one label per batch.

    The model outputs of all batches are stacked row-wise and saved to
    ``features_path``; the label of the first sample of every batch is saved to
    ``labels_path``. This is only meaningful when every batch contains the
    slices of a single 3D image, so a warning is issued when a batch that
    carries image IDs mixes several of them.

    Args:
        model: network whose outputs are used as features.
        loader: iterable of batches ``(images, labels)`` or
            ``(images, labels, image_ids, ...)``.
        device: device the image batches are moved to before the forward pass.
        features_path: destination ``.npy`` file for the stacked outputs.
        labels_path: destination ``.npy`` file for the per-batch labels.

    Raises:
        ValueError: if a batch has fewer than two elements.
    """
    import warnings

    model.eval()  # Set model to evaluation mode
    all_features = []
    all_labels = []

    with torch.no_grad():
        for batch in loader:
            if len(batch) < 2:
                raise ValueError("Each batch must contain at least (images, labels)")
            batch_images, batch_labels = batch[0], batch[1]
            image_ids = batch[2] if len(batch) > 2 else None

            # One label is stored per batch, which is wrong if the batch mixes images.
            if image_ids is not None and len(set(image_ids)) > 1:
                warnings.warn(
                    "Batch mixes slices from several 3D images; the saved per-batch "
                    "label and feature grouping will not correspond to one 3D image. "
                    "Use an unshuffled loader that yields one 3D image per batch."
                )

            outputs = model(batch_images.to(device))  # Extract features
            all_features.append(outputs.cpu().numpy())  # Move to CPU and store

            # Use batch_labels[0] since all slices are expected to share the 3D label;
            # int() accepts both Python ints and single-element tensors.
            all_labels.append(int(batch_labels[0]))

    # Stack features and labels
    features = np.vstack(all_features)
    labels = np.array(all_labels)

    # Save features and labels as npy files
    np.save(features_path, features)
    np.save(labels_path, labels)

    print(f"Saved features to {features_path} and labels to {labels_path}")

# Initialize device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Feature extraction stores one label per batch and the features are later
# regrouped in consecutive blocks of slices, so every batch must hold the slices
# of exactly one 3D image. The training loader is shuffled, so a separate
# unshuffled loader with collate_fn (which rejects mixed batches) is used here.
train_feature_loader = torch.utils.data.DataLoader(
    dataset=train_slice_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
)

# Extract and save train features and labels
extract_and_save_features_and_labels(model, train_feature_loader, device, TRAIN_FEATURES_PATH, TRAIN_LABELS_PATH)

# Extract and save test features and labels
extract_and_save_features_and_labels(model, test_loader, device, TEST_FEATURES_PATH, TEST_LABELS_PATH)



Saved features to ./XGBfeatures/train_features.npy and labels to ./XGBfeatures/train_labels.npy
Saved features to ./XGBfeatures/test_features.npy and labels to ./XGBfeatures/test_labels.npy


In [None]:
# Paths of the saved features and labels (all four are defined here so the
# cell does not depend on names created by another cell)
SAVE_DIR = './XGBfeatures'
TRAIN_FEATURES_PATH = os.path.join(SAVE_DIR, 'train_features.npy')
TRAIN_LABELS_PATH = os.path.join(SAVE_DIR, 'train_labels.npy')
TEST_FEATURES_PATH = os.path.join(SAVE_DIR, 'test_features.npy')
TEST_LABELS_PATH = os.path.join(SAVE_DIR, 'test_labels.npy')

# Load saved features and labels for training and testing
train_features = np.load(TRAIN_FEATURES_PATH)
train_labels = np.load(TRAIN_LABELS_PATH)
test_features = np.load(TEST_FEATURES_PATH)
test_labels = np.load(TEST_LABELS_PATH)

# Shapes of all four arrays: train features/labels, then test features/labels
print(train_features.shape)
print(train_labels.shape)
print(test_features.shape)
print(test_labels.shape)

(78720, 2)
(1230,)
(22528, 2)
(352,)


In [None]:
# train_features holds one row of class scores per slice; consecutive groups of
# NUM_SLICES rows are flattened into a single feature vector per 3D image.
NUM_SLICES = 64  # Number of slices per 3D image
NUM_CLASSES = 2  # Number of class scores per slice

features_per_image = NUM_SLICES * NUM_CLASSES
num_images = len(train_features) // NUM_SLICES
reshaped_features = train_features.reshape(num_images, features_per_image)

print(f"Reshaped Features Shape: {reshaped_features.shape}")
print(f"Train Labels Shape: {train_labels.shape}")


Reshaped Features Shape: (1230, 128)
Train Labels Shape: (1230,)


In [None]:
# Fit an XGBoost classifier on the per-image feature vectors.
xgb_params = {
    'objective': 'binary:logistic',
    'use_label_encoder': False,
    'eval_metric': 'logloss',
}
xgb_model = xgb.XGBClassifier(**xgb_params)
xgb_model.fit(reshaped_features, train_labels)

# Persist the fitted classifier so the evaluation cell can reload it.
joblib.dump(xgb_model, 'xgb_model.pkl')
print("XGB model saved as 'xgb_model.pkl'")

XGB model saved as 'xgb_model.pkl'


In [None]:
# Reload the fitted classifier
xgb_model = joblib.load('xgb_model.pkl')

# Locations of the saved test features and labels
SAVE_DIR = './XGBfeatures'
TEST_FEATURES_PATH = os.path.join(SAVE_DIR, 'test_features.npy')
TEST_LABELS_PATH = os.path.join(SAVE_DIR, 'test_labels.npy')

test_features, test_labels = np.load(TEST_FEATURES_PATH), np.load(TEST_LABELS_PATH)

# Same layout as for training: one row of NUM_SLICES * NUM_CLASSES values per 3D image
test_features_reshaped = test_features.reshape(-1, NUM_SLICES * NUM_CLASSES)

# Hard predictions for the threshold metrics, class-1 probabilities for ROC-AUC
test_predictions = xgb_model.predict(test_features_reshaped)
test_probabilities = xgb_model.predict_proba(test_features_reshaped)[:, 1]

accuracy, precision, recall, f1 = [
    metric(test_labels, test_predictions)
    for metric in (accuracy_score, precision_score, recall_score, f1_score)
]
roc_auc = roc_auc_score(test_labels, test_probabilities)

# Report
print(f"XGB Model - Accuracy: {accuracy * 100:.2f}%")
print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")
print(f"F1-score: {f1:.2f}")
print(f"ROC-AUC Score: {roc_auc:.2f}")


XGB Model - Accuracy: 77.27%
Precision: 0.82
Recall: 0.89
F1-score: 0.85
ROC-AUC Score: 0.68


# Voting using Shannon Entropy (for Synapse3D)

Each 3D image is classified from its own 64 2D slices, so the test loader must deliver those slices together and in a consistent order. The high-level strategy is:

1. Load images in a structured way: the file-reading code must return the 64 consecutive 2D slices that belong to a single 3D image.
2. Group these 64 slices: because each batch contains slices from exactly one 3D image, their probabilities can be combined into one prediction per 3D image before accuracy is calculated.

## Shannon entropy voting: GPU optimized

In [None]:
import torch
import torch.nn.functional as F
from sklearn.metrics import precision_score, recall_score, f1_score
from sklearn.metrics import roc_auc_score
import numpy as np

# Binary Shannon entropy of each slice's class-1 probability
def calculate_entropy_gpu(probabilities):
    """Return the element-wise binary entropy ``-p*log(p) - (1-p)*log(1-p)``.

    A constant of 1e-10 is added inside each logarithm so that probabilities of
    exactly 0 or 1 do not produce ``log(0)``.

    Args:
        probabilities: tensor of probabilities.

    Returns:
        Tensor of the same shape holding the entropy of every element.
    """
    complement = 1 - probabilities
    positive_term = -probabilities * torch.log(probabilities + 1e-10)
    negative_term = complement * torch.log(complement + 1e-10)
    return positive_term - negative_term

# Turn entropies into weights that sum to one
def normalize_entropies_gpu(entropies):
    """Divide ``entropies`` by their sum; fall back to uniform weights if the sum is zero.

    Args:
        entropies: tensor of entropy values.

    Returns:
        Tensor of the same shape: ``entropies / sum(entropies)``, or
        ``1 / len(entropies)`` in every position when the sum equals zero.
    """
    entropy_sum = entropies.sum()
    if entropy_sum != 0:
        return entropies / entropy_sum
    # All-zero entropies cannot be normalized, so every element gets equal weight.
    return torch.ones_like(entropies) / len(entropies)

# Weighted sum of the slice probabilities
def combine_probabilities_gpu(probabilities, weights):
    """Return the sum of the element-wise product of ``probabilities`` and ``weights``."""
    weighted = probabilities * weights
    return weighted.sum()

def combine_scores_gpu(score, weights):
    """Return the matrix product ``weights @ score``.

    With a 1-D ``weights`` vector and a 2-D ``score`` matrix this is the
    weighted sum of the rows of ``score``.
    """
    return weights @ score

# Function to calculate 3D image-wise metrics
def calculate_3d_image_metrics_gpu(model, test_loader, device):
    """Classify each batch as one 3D image by entropy-weighted voting and score the result.

    For every batch the slice logits are converted to class-1 probabilities,
    the binary entropy of each probability is normalized into a weight, the
    slice logits are combined with those weights, and a softmax over the
    combined logits gives the image-level class-1 probability (thresholded at
    0.5). Weights are proportional to entropy, so slices with higher entropy
    receive larger weights.
    NOTE(review): confirm this weighting direction is intended rather than its inverse.

    Args:
        model: network returning two logits per slice.
        test_loader: iterable of ``(images, labels)`` batches in which every
            batch holds the slices of a single 3D image.
        device: device the image batches are moved to.

    Returns:
        Tuple ``(accuracy, precision, recall, f1, auc)`` computed over 3D images.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()  # Set model to evaluation mode
    true_labels = []
    predicted_labels = []
    combined_probabilities = []

    # No gradients are needed anywhere in the evaluation, including the voting step.
    with torch.no_grad():
        for batch_images, batch_labels in test_loader:
            batch_images = batch_images.to(device)  # Move batch to GPU

            # Step 1: Get prediction for each individual slice in the batch
            outputs = model(batch_images)
            softmax_outputs = F.softmax(outputs, dim=1)
            slice_probs = softmax_outputs[:, 1]  # Probability for class 1

            # Step 2: Calculate entropies for the slices
            entropies = calculate_entropy_gpu(slice_probs)

            # Step 3: Normalize the entropies into weights
            weights = normalize_entropies_gpu(entropies)

            # Step 4: Combine the slice logits using the entropy-based weights
            combined_scores = combine_scores_gpu(outputs, weights)
            combined_probability = F.softmax(combined_scores, dim=-1)[1].item()

            # Step 5: All slices share one label, so the first one is the image label.
            # int() reads it directly from a list or tensor without a device transfer.
            true_labels.append(int(batch_labels[0]))
            predicted_labels.append(1 if combined_probability >= 0.5 else 0)  # Binary threshold
            combined_probabilities.append(combined_probability)

    if not true_labels:
        raise ValueError("test_loader produced no batches; cannot compute metrics")

    # Calculate accuracy, precision, recall, F1-score and AUC on the collected Python values
    accuracy = np.mean(np.array(true_labels) == np.array(predicted_labels))
    precision = precision_score(true_labels, predicted_labels)
    recall = recall_score(true_labels, predicted_labels)
    f1 = f1_score(true_labels, predicted_labels)
    auc = roc_auc_score(true_labels, combined_probabilities)

    return accuracy, precision, recall, f1, auc

# Example usage
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)  # Move model to the appropriate device

# Assuming test_loader is already defined
accuracy, precision, recall, f1, auc = calculate_3d_image_metrics_gpu(model, test_loader, device)

print(f"3D Image-wise Accuracy: {accuracy * 100:.2f}%")
print(f"Precision: {precision*100:.2f}%")
print(f"Recall: {recall*100:.2f}%")
print(f"F1-score: {f1*100:.2f}%")
print(f"AUC: {auc*100:.2f}%")


In [None]:
print(f"3D Image-wise Accuracy: {accuracy * 100:.2f}%")
print(f"Precision: {precision*100:.2f}%")
print(f"Recall: {recall*100:.2f}%")
print(f"F1-score: {f1*100:.2f}%")
print(f"AUC: {auc*100:.2f}%")

3D Image-wise Accuracy: 82.39%
Precision: 82.61%
Recall: 96.11%
F1-score: 88.85%
AUC: 83.82%
