In [1]:
# Import necessary libraries
import numpy as np
import torch
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from scipy.ndimage import zoom

from torch.utils.data import DataLoader, Dataset
import os
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
from torchvision import models, transforms
from pathlib import Path

from matplotlib.animation import FuncAnimation, FFMpegWriter
from concurrent.futures import ThreadPoolExecutor, as_completed
import warnings
from sklearn.exceptions import UndefinedMetricWarning
from sklearn.metrics import roc_curve, auc, confusion_matrix, accuracy_score, precision_score, recall_score, f1_score, log_loss
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchvision.models.video import r3d_18

In [4]:
main_dir = "../MRNet-v1.0"
train_path = os.path.join(main_dir, "train")
valid_path = os.path.join(main_dir, "valid")

In [5]:
base_dir = "../MRNet-v1.0/train"
file_id = "0000"

axial_path = f"{base_dir}/axial/{file_id}.npy"
coronal_path = f"{base_dir}/coronal/{file_id}.npy"
sagittal_path = f"{base_dir}/sagittal/{file_id}.npy"

# Load MRI scans from the .npy files
axial_scan = np.load(axial_path)
coronal_scan = np.load(coronal_path)
sagittal_scan = np.load(sagittal_path)

# Check the shape of each MRI scan
print("Axial scan shape:", axial_scan.shape)
print("Coronal scan shape:", coronal_scan.shape)
print("Sagittal scan shape:", sagittal_scan.shape)

Axial scan shape: (44, 256, 256)
Coronal scan shape: (36, 256, 256)
Sagittal scan shape: (36, 256, 256)


In [6]:
# Load labels from CSV files
train_abnormal = pd.read_csv(os.path.join(main_dir, "train-abnormal.csv"), header=None, index_col=0).squeeze("columns").to_dict()
train_acl = pd.read_csv(os.path.join(main_dir, "train-acl.csv"), header=None, index_col=0).squeeze("columns").to_dict()
train_meniscus = pd.read_csv(os.path.join(main_dir, "train-meniscus.csv"), header=None, index_col=0).squeeze("columns").to_dict()

valid_abnormal = pd.read_csv(os.path.join(main_dir, "valid-abnormal.csv"), header=None, index_col=0).squeeze("columns").to_dict()
valid_acl = pd.read_csv(os.path.join(main_dir, "valid-acl.csv"), header=None, index_col=0).squeeze("columns").to_dict()
valid_meniscus = pd.read_csv(os.path.join(main_dir, "valid-meniscus.csv"), header=None, index_col=0).squeeze("columns").to_dict()


In [7]:
# Function to resize the depth of a scan to a target depth
def resize_depth(scan, target_depth):
    depth_factor = target_depth / scan.shape[0]
    return zoom(scan, (depth_factor, 1, 1), order=1)

In [8]:
# Function to pad a scan to a target shape
def pad_to_shape(scan, target_shape):
    padded_scan = np.zeros(target_shape, dtype=scan.dtype)
    min_d, min_h, min_w = min(scan.shape[0], target_shape[0]), min(scan.shape[1], target_shape[1]), min(scan.shape[2], target_shape[2])
    padded_scan[:min_d, :min_h, :min_w] = scan[:min_d, :min_h, :min_w]
    return padded_scan

In [11]:
# Function to load a specific range of MRI data with labels

def load_mri_data(data_type="train", start_idx=0, end_idx=9, target_shape=(48, 256, 256), target_depth=48):
    """
    Loads MRI data from a specified range and resizes/pads each scan to a target shape.
    Parameters:
    - data_type: "train" or "valid"
    - start_idx, end_idx: Range of file indices to load (e.g., 0 to 9 for train, 1130 to 1249 for valid)
    - target_shape: Target shape for each scan after resizing and padding
    - target_depth: Target depth for each scan to ensure consistent depth
    """
    # Set data path and range
    data_path = train_path if data_type == "train" else valid_path
    axial_path, coronal_path, sagittal_path = Path(data_path) / "axial", Path(data_path) / "coronal", Path(data_path) / "sagittal"

    # Select the appropriate labels dictionary based on data type
    abnormal_labels = train_abnormal if data_type == "train" else valid_abnormal
    acl_labels = train_acl if data_type == "train" else valid_acl
    meniscus_labels = train_meniscus if data_type == "train" else valid_meniscus

    # Initialize lists to store data and labels
    mri_data, labels = [], []

    # Load each MRI scan within the specified range
    for i in range(start_idx, end_idx + 1):
        # Generate file name with zero-padded format (e.g., 0000, 0001, ...)
        file_name = f"{i:04}.npy"

        # Load and process each view with resizing and padding
        axial_scan = pad_to_shape(resize_depth(np.load(axial_path / file_name), target_depth), target_shape)
        coronal_scan = pad_to_shape(resize_depth(np.load(coronal_path / file_name), target_depth), target_shape)
        sagittal_scan = pad_to_shape(resize_depth(np.load(sagittal_path / file_name), target_depth), target_shape)

        # Combine the three views into one structure (3, depth, height, width)
        combined_scan = np.stack([axial_scan, coronal_scan, sagittal_scan], axis=0)
        mri_data.append(combined_scan)

        # Retrieve actual labels for the current scan
        abnormal_label = abnormal_labels.get(i, 0)  # Default to 0 if label is missing
        acl_label = acl_labels.get(i, 0)
        meniscus_label = meniscus_labels.get(i, 0)

        # Append the actual labels
        # print("Loaded", i)
        if (i%10==0):
            print("Loaded, ", i) 
        labels.append({"abnormal": abnormal_label, "acl": acl_label, "meniscus": meniscus_label})

    return np.array(mri_data), labels

In [12]:
# Load the validation data from indices 1130 to 1249
train_data, train_labels = load_mri_data(data_type="train", start_idx=0, end_idx=19)

# Check data shapes and labels
print("Train data shape:", train_data.shape)  # Expected: (120, 3, 48, 256, 256)

Loaded 0
Loaded 1
Loaded 2
Loaded 3
Loaded 4
Loaded 5
Loaded 6
Loaded 7
Loaded 8
Loaded 9
Loaded 10
Loaded 11
Loaded 12
Loaded 13
Loaded 14
Loaded 15
Loaded 16
Loaded 17
Loaded 18
Loaded 19
Loaded 20
Loaded 21
Loaded 22
Loaded 23
Loaded 24
Loaded 25
Loaded 26
Loaded 27
Loaded 28
Loaded 29
Loaded 30
Loaded 31
Loaded 32
Loaded 33
Loaded 34
Loaded 35
Loaded 36
Loaded 37
Loaded 38
Loaded 39
Loaded 40
Loaded 41
Loaded 42
Loaded 43
Loaded 44
Loaded 45
Loaded 46
Loaded 47
Loaded 48
Loaded 49
Loaded 50
Loaded 51
Loaded 52
Loaded 53
Loaded 54
Loaded 55
Loaded 56
Loaded 57
Loaded 58
Loaded 59
Loaded 60
Loaded 61
Loaded 62
Loaded 63
Loaded 64
Loaded 65
Loaded 66
Loaded 67
Loaded 68
Loaded 69
Loaded 70
Loaded 71
Loaded 72
Loaded 73
Loaded 74
Loaded 75
Loaded 76
Loaded 77
Loaded 78
Loaded 79
Loaded 80
Loaded 81
Loaded 82
Loaded 83
Loaded 84
Loaded 85
Loaded 86
Loaded 87
Loaded 88
Loaded 89
Loaded 90
Loaded 91
Loaded 92
Loaded 93
Loaded 94
Loaded 95
Loaded 96
Loaded 97
Loaded 98
Loaded 99
Loaded 100

In [14]:
# Load the validation data from indices 1130 to 1249
valid_data, valid_labels = load_mri_data(data_type="valid", start_idx=1130, end_idx=1149)

# Check data shapes and labels
print("Validation data shape:", valid_data.shape)  # Expected: (120, 3, 48, 256, 256)

Loaded 1130
Loaded 1131
Loaded 1132
Loaded 1133
Loaded 1134
Loaded 1135
Loaded 1136
Loaded 1137
Loaded 1138
Loaded 1139
Loaded 1140
Loaded 1141
Loaded 1142
Loaded 1143
Loaded 1144
Loaded 1145
Loaded 1146
Loaded 1147
Loaded 1148
Loaded 1149
Loaded 1150
Loaded 1151
Loaded 1152
Loaded 1153
Loaded 1154
Loaded 1155
Loaded 1156
Loaded 1157
Loaded 1158
Loaded 1159
Loaded 1160
Loaded 1161
Loaded 1162
Loaded 1163
Loaded 1164
Loaded 1165
Loaded 1166
Loaded 1167
Loaded 1168
Loaded 1169
Loaded 1170
Loaded 1171
Loaded 1172
Loaded 1173
Loaded 1174
Loaded 1175
Loaded 1176
Loaded 1177
Loaded 1178
Loaded 1179
Loaded 1180
Loaded 1181
Loaded 1182
Loaded 1183
Loaded 1184
Loaded 1185
Loaded 1186
Loaded 1187
Loaded 1188
Loaded 1189
Loaded 1190
Loaded 1191
Loaded 1192
Loaded 1193
Loaded 1194
Loaded 1195
Loaded 1196
Loaded 1197
Loaded 1198
Loaded 1199
Loaded 1200
Loaded 1201
Loaded 1202
Loaded 1203
Loaded 1204
Loaded 1205
Loaded 1206
Loaded 1207
Loaded 1208
Loaded 1209
Loaded 1210
Loaded 1211
Loaded 1212
Load

In [15]:
# Set display options to avoid truncation
pd.set_option('display.max_rows', None)      # Show all rows in the DataFrame
pd.set_option('display.max_columns', None)   # Show all columns in the DataFrame
pd.set_option('display.width', None)         # Expand display width to accommodate more columns
pd.set_option('display.max_colwidth', None)  # Expand column width if necessary

In [16]:
# Define the Flexible 3D ResNet model class
class ResNet3D(nn.Module):
    def __init__(self, pretrained=True, num_classes=3, optimizer_type="adam"):
        super(ResNet3D, self).__init__()
        self.optimizer_type = optimizer_type  # Store optimizer type as part of the model

        # Load a pre-trained 3D ResNet-18 model
        self.resnet3d = r3d_18(pretrained=pretrained)

        # Replace the final fully connected layer to match the desired output size
        in_features = self.resnet3d.fc.in_features
        self.resnet3d.fc = nn.Linear(in_features, num_classes)  # 3 binary outputs (one per class)

    def forward(self, x):
        return self.resnet3d(x)

# Define the loss function with BCEWithLogitsLoss
def calculate_loss(preds, targets):
    criterion = nn.BCEWithLogitsLoss()
    return criterion(preds, targets.float())

# Compute metrics for multi-label classification
def compute_metrics(y_true, y_pred, y_proba):
    labels = ["abnormal", "acl", "meniscus"]  # Label names for multi-label classification
    y_true_np, y_pred_np, y_proba_np = np.array(y_true), np.array(y_pred), np.array(y_proba)
    results = {}
    for i, label in enumerate(labels):
        cm = confusion_matrix(y_true_np[:, i], y_pred_np[:, i], labels=[0, 1])
        accuracy = accuracy_score(y_true_np[:, i], y_pred_np[:, i])
        precision = precision_score(y_true_np[:, i], y_pred_np[:, i], zero_division=0)
        recall = recall_score(y_true_np[:, i], y_pred_np[:, i], zero_division=0)
        f1 = f1_score(y_true_np[:, i], y_pred_np[:, i], zero_division=0)
        try:
            log_loss_value = log_loss(y_true_np[:, i], y_proba_np[:, i], labels=[0, 1])
        except ValueError:
            log_loss_value = None

        print(f"\nMetrics for {label}:")
        print(f"Confusion Matrix:\n{cm}")
        print(f"Accuracy: {accuracy}")
        print(f"Precision: {precision}")
        print(f"Recall: {recall}")
        print(f"F1-Score: {f1}")
        print(f"Log Loss: {log_loss_value}")

        results[label] = {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'log_loss': log_loss_value
        }
    return results

# Flatten metrics for logging
def flatten_metrics(metrics, prefix):
    flattened = {}
    for label, label_metrics in metrics.items():
        for metric_name, value in label_metrics.items():
            flattened[f"{prefix}_{label}_{metric_name}"] = float(value) if isinstance(value, np.float64) else value
    return flattened

# Training and evaluation function
def train_and_evaluate(model, train_loader, valid_loader, epochs=5, learning_rate=0.001):
    if model.optimizer_type.lower() == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    elif model.optimizer_type.lower() == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
    else:
        raise ValueError("Invalid optimizer type. Choose 'adam' or 'sgd'.")

    results = []  # Store summarized epoch metrics for table output
    detailed_metrics = []  # Store metrics for `detailed_df` with all per-class data
    final_y_true_valid, final_y_pred_proba_valid = None, None  # To store final epoch validation values

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        model.train()
        train_loss = 0
        y_true_train, y_pred_train, y_pred_proba_train = [], [], []

        # Training loop
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = calculate_loss(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            y_true_train.extend(labels.numpy())
            y_pred_train.extend((torch.sigmoid(outputs) > 0.5).detach().numpy())
            y_pred_proba_train.extend(torch.sigmoid(outputs).detach().numpy())

        # Compute training metrics
        train_metrics = compute_metrics(np.array(y_true_train), np.array(y_pred_train), np.array(y_pred_proba_train))

        model.eval()
        valid_loss = 0
        y_true_valid, y_pred_proba_valid = [], []

        # Validation loop
        with torch.no_grad():
            for inputs, labels in valid_loader:
                outputs = model(inputs)
                loss = calculate_loss(outputs, labels)
                valid_loss += loss.item()

                y_true_valid.extend(labels.numpy())
                y_pred_proba_valid.extend(torch.sigmoid(outputs).detach().numpy())

        y_pred_valid = (np.array(y_pred_proba_valid) > 0.5).astype(int)
        valid_metrics = compute_metrics(np.array(y_true_valid), y_pred_valid, np.array(y_pred_proba_valid))

        # Update final validation metrics for last epoch outputs
        final_y_true_valid = np.array(y_true_valid)
        final_y_pred_proba_valid = np.array(y_pred_proba_valid)

        # Add detailed metrics (all classes, all epochs) for `detailed_df`
        for phase, metrics in zip(["Train", "Validation"], [train_metrics, valid_metrics]):
            for label, label_metrics in metrics.items():
                detailed_metrics.append({
                    "Epoch": epoch + 1,
                    "Phase": phase,
                    "Class": label,
                    "Accuracy": label_metrics["accuracy"],
                    "Precision": label_metrics["precision"],
                    "Recall": label_metrics["recall"],
                    "F1-Score": label_metrics["f1"],
                    "Log Loss": label_metrics["log_loss"]
                })

        # Aggregate metrics for `results` (summary output without per-class breakdown)
        overall_train_metrics = flatten_metrics(train_metrics, "Train")
        overall_valid_metrics = flatten_metrics(valid_metrics, "Valid")

        row_data = {
            "Epoch": epoch + 1,
            "Train Loss": train_loss / len(train_loader),
            "Valid Loss": valid_loss / len(valid_loader),
        }
        #format and print the row data

        print(row_data)
        row_data.update(overall_train_metrics)
        row_data.update(overall_valid_metrics)
        results.append(row_data)

    # Display summarized DataFrame (`summary_df`) with per-epoch metrics
    summary_df = pd.DataFrame(results)
    pd.set_option('display.max_rows', None)
    display(summary_df.style.set_table_styles([{'selector': 'th', 'props': [('font-weight', 'bold')]}]))

    # Create `detailed_df` with per-class metrics in columns, suitable for further analysis
    detailed_df = pd.DataFrame(detailed_metrics)
    return final_y_true_valid, final_y_pred_proba_valid

# Convert labels from dictionaries to multi-label binary tensors
def convert_labels_to_tensor(labels):
    labels_tensor = torch.zeros(len(labels), 3)  # Three classes: abnormal, ACL, meniscus
    for i, label_dict in enumerate(labels):
        labels_tensor[i, 0] = label_dict['abnormal']
        labels_tensor[i, 1] = label_dict['acl']
        labels_tensor[i, 2] = label_dict['meniscus']
    return labels_tensor


In [17]:
# Function to convert a list of dictionaries to a tensor
def convert_labels_to_tensor(label_dicts):
    # Convert each dictionary's values into a list and create a tensor from it
    labels_as_lists = [list(label.values()) for label in label_dicts]
    return torch.tensor(labels_as_lists, dtype=torch.float32)

# Apply this to train and validation labels
train_labels_tensor = convert_labels_to_tensor(train_labels)
valid_labels_tensor = convert_labels_to_tensor(valid_labels)

# Create TensorDatasets for training and validation data
train_tensor = TensorDataset(torch.tensor(train_data).float(), train_labels_tensor)
valid_tensor = TensorDataset(torch.tensor(valid_data).float(), valid_labels_tensor)

# Adjust batch size to avoid high memory usage
batch_size = 8  # Adjust based on available memory

# Create DataLoaders with updated parameters
train_loader = DataLoader(
    train_tensor,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4  # Adjust based on your CPU cores and memory
)

valid_loader = DataLoader(
    valid_tensor,
    batch_size=batch_size,
    pin_memory=True,
    num_workers=4
)

# Print out DataLoader settings for confirmation
print(f"Train DataLoader - Batch Size: {batch_size}, Num Workers: 4")
print(f"Validation DataLoader - Batch Size: {batch_size}, Num Workers: 4")


Train DataLoader - Batch Size: 8, Num Workers: 4
Validation DataLoader - Batch Size: 8, Num Workers: 4


In [18]:
def plot_roc_and_calculate_auc(y_true, y_pred_proba, class_names=["abnormal", "acl", "meniscus"]):
    plt.figure(figsize=(10, 8))
    auc_scores = {}

    for i in range(y_true.shape[1]):  # Assuming y_true and y_pred_proba are (num_samples, num_classes)
        fpr, tpr, _ = roc_curve(y_true[:, i], y_pred_proba[:, i])
        roc_auc = auc(fpr, tpr)
        auc_scores[class_names[i]] = roc_auc  # Store AUC for each class

        plt.plot(fpr, tpr, label=f"{class_names[i]} (AUC = {roc_auc:.2f})")

    # Plot the random classifier line
    plt.plot([0, 1], [0, 1], "k--")
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve for Each Class")
    plt.legend(loc="lower right")
    plt.show()

    # Print AUC scores for each class
    print("AUC Scores for each class:")
    for class_name, auc_score in auc_scores.items():
        print(f"{class_name}: {auc_score:.2f}")

In [21]:
model = ResNet3D(pretrained=True, num_classes=3, optimizer_type="adam")

# Run the training and evaluation function and capture the outputs
y_true_valid_np, y_pred_proba_valid_np = train_and_evaluate(model, train_loader, valid_loader, epochs=1, learning_rate=0.1)

print("y_true_valid_np shape:", y_true_valid_np.shape)
print("y_pred_proba_valid_np shape:", y_pred_proba_valid_np.shape)

# Directly call the function with `y_true_valid_np` and `y_pred_proba_valid_np`
plot_roc_and_calculate_auc(y_true_valid_np, y_pred_proba_valid_np)

Epoch 1/1


RuntimeError: [enforce fail at alloc_cpu.cpp:114] data. DefaultCPUAllocator: not enough memory: you tried to allocate 1610612736 bytes.