In [24]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!pip uninstall -y tensorflow
!pip install tensorflow-cpu


In [28]:
import os
import numpy as np
import pandas as pd
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18
from sklearn.metrics import accuracy_score, roc_auc_score
from concurrent.futures import ThreadPoolExecutor, as_completed
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.distributed.parallel_loader import MpDeviceLoader
from torchvision.models import resnet50

In [29]:
# Define main directories
main_dir = "/content/drive/My Drive/DATASET_MRNET/MRNet-v1.0"
train_path = os.path.join(main_dir, "train")
valid_path = os.path.join(main_dir, "valid")

In [30]:
def load_labels(label_type):
    train_labels = pd.read_csv(os.path.join(main_dir, f"train-{label_type}.csv"),
                               header=None, index_col=0).squeeze("columns").to_dict()
    valid_labels = pd.read_csv(os.path.join(main_dir, f"valid-{label_type}.csv"),
                               header=None, index_col=0).squeeze("columns").to_dict()
    return train_labels, valid_labels

In [31]:
train_abnormal, valid_abnormal = load_labels("abnormal")
train_acl, valid_acl = load_labels("acl")
train_meniscus, valid_meniscus = load_labels("meniscus")

In [32]:
def pad_to_shape(scan, target_shape):
    padded_scan = np.zeros(target_shape, dtype=scan.dtype)
    min_d = min(scan.shape[0], target_shape[0])
    min_h = min(scan.shape[1], target_shape[1])
    min_w = min(scan.shape[2], target_shape[2])
    padded_scan[:min_d, :min_h, :min_w] = scan[:min_d, :min_h, :min_w]
    return padded_scan

In [33]:
def pad_or_crop_scan(scan, target_shape=(48, 256, 256)):
    d, h, w = scan.shape
    td, th, tw = target_shape

    # Pad or crop depth
    if d < td:
        pad_width = ((0, td - d), (0, 0), (0, 0))
        scan = np.pad(scan, pad_width=pad_width, mode='constant', constant_values=0)
    elif d > td:
        start = (d - td) // 2
        scan = scan[start:start+td, :, :]

    # Pad or crop height and width
    scan = pad_to_shape(scan, target_shape)
    return scan

Starting with Axial axes


In [34]:
def load_single_axis_mri_data(data_type="train", start_idx=0, end_idx=9, target_shape=(48, 256, 256)):
    data_path = train_path if data_type == "train" else valid_path
    abnormal_labels = train_abnormal if data_type == "train" else valid_abnormal
    acl_labels = train_acl if data_type == "train" else valid_acl
    meniscus_labels = train_meniscus if data_type == "train" else valid_meniscus

    data = []
    labels = []

    for i in range(start_idx, end_idx + 1):
        file_name = f"{i:04}.npy"
        scan = np.load(Path(data_path) / "axial" / file_name)
        scan = pad_or_crop_scan(scan, target_shape)

        # Convert 3D (D,H,W) to 2D by taking the middle slice
        d = scan.shape[0]
        mid_slice = d // 2
        single_slice = scan[mid_slice, :, :]  # shape (H, W)

        # Add channel dimension: (1, H, W)
        single_slice = np.expand_dims(single_slice, axis=0)
        data.append(single_slice)  # will become (1,H,W)

        # Labels: [abnormal, acl, meniscus]
        abnormal_label = abnormal_labels.get(i, 0)
        acl_label = acl_labels.get(i, 0)
        meniscus_label = meniscus_labels.get(i, 0)
        labels.append([abnormal_label, acl_label, meniscus_label])

    data = np.array(data)    # (N, 1, H, W)
    labels = np.array(labels) # (N, 3)
    return data, labels


In [35]:
def load_data_in_batches(start_indices, end_indices, data_type):
    all_data = []
    all_labels = []

    with ThreadPoolExecutor() as executor:
        future_to_indices = {
            executor.submit(load_single_axis_mri_data, data_type, start, end): (start, end)
            for start, end in zip(start_indices, end_indices)
        }
        for future in as_completed(future_to_indices):
            data, labels = future.result()
            all_data.append(data)
            all_labels.append(labels)

    all_data = np.concatenate(all_data, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)
    return all_data, all_labels

In [36]:
start_indices = list(range(0, 1130, 100))
end_indices = [min(start + 99, 1129) for start in start_indices]

train_data, train_labels = load_data_in_batches(start_indices, end_indices, "train")
valid_data, valid_labels = load_single_axis_mri_data("valid", 1130, 1249)

print("Train data shape:", train_data.shape)   # (N,1,H,W)
print("Train labels shape:", train_labels.shape)# (N,3)
print("Valid data shape:", valid_data.shape)
print("Valid labels shape:", valid_labels.shape)

Train data shape: (1130, 1, 256, 256)
Train labels shape: (1130, 3)
Valid data shape: (120, 1, 256, 256)
Valid labels shape: (120, 3)


In [42]:
class SingleChannelResNet18(nn.Module):
    def __init__(self, num_classes=3):
        super(SingleChannelResNet18, self).__init__()
        self.model = resnet18()
        # Modify the first conv layer to accept 1-channel input
        # Original: Conv2d(3,64,kernel_size=7,stride=2,padding=3,bias=False)
        original_conv1 = self.model.conv1
        self.model.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=original_conv1.out_channels,
            kernel_size=original_conv1.kernel_size,
            stride=original_conv1.stride,
            padding=original_conv1.padding,
            bias=original_conv1.bias
        )
        # Change the final FC layer
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.model(x)

In [38]:
class MRNetDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        x = torch.tensor(self.data[idx], dtype=torch.float32)   # (1,H,W)
        y = torch.tensor(self.labels[idx], dtype=torch.float32) # (3,)
        return x, y

In [39]:
# Use TPU device
device = xm.xla_device()

In [40]:
train_dataset = MRNetDataset(train_data, train_labels)
valid_dataset = MRNetDataset(valid_data, valid_labels)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=8, shuffle=False)

train_loader = MpDeviceLoader(train_loader, device)
valid_loader = MpDeviceLoader(valid_loader, device)

In [43]:
model = SingleChannelResNet18(num_classes=3).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [46]:
def train_and_validate(model, train_loader, valid_loader, criterion, optimizer, epochs=10, classes=["abnormal", "acl", "meniscus"]):
    for epoch in range(epochs):
        # Training Phase
        model.train()
        train_loss = 0
        train_preds = {cls: [] for cls in classes}
        train_truth = {cls: [] for cls in classes}

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            outputs = model(x)  # (N,3)
            loss = criterion(outputs, y)
            loss.backward()
            xm.optimizer_step(optimizer)

            train_loss += loss.item()
            preds_sig = torch.sigmoid(outputs).detach().cpu().numpy()  # (N,3)
            labels_np = y.detach().cpu().numpy()                       # (N,3)

            for i, cls in enumerate(classes):
                train_preds[cls].extend(preds_sig[:, i])
                train_truth[cls].extend(labels_np[:, i])

        train_loss /= len(train_loader)

        # Validation Phase
        model.eval()
        valid_loss = 0
        valid_preds = {cls: [] for cls in classes}
        valid_truth = {cls: [] for cls in classes}

        with torch.no_grad():
            for x, y in valid_loader:
                x, y = x.to(device), y.to(device)
                outputs = model(x)
                loss = criterion(outputs, y)
                valid_loss += loss.item()

                preds_sig = torch.sigmoid(outputs).cpu().numpy()
                labels_np = y.cpu().numpy()
                for i, cls in enumerate(classes):
                    valid_preds[cls].extend(preds_sig[:, i])
                    valid_truth[cls].extend(labels_np[:, i])

        valid_loss /= len(valid_loader)

        print(f"Epoch {epoch+1}/{epochs}")
        print(f"Train Loss: {train_loss:.4f} | Valid Loss: {valid_loss:.4f}")

        # Compute total accuracy
        # Convert dictionary lists into arrays of shape (N,3)
        train_pred_matrix = np.column_stack([train_preds[cls] for cls in classes])
        train_truth_matrix = np.column_stack([train_truth[cls] for cls in classes])
        valid_pred_matrix = np.column_stack([valid_preds[cls] for cls in classes])
        valid_truth_matrix = np.column_stack([valid_truth[cls] for cls in classes])

        # Threshold predictions at 0.5
        train_binary_preds = (train_pred_matrix > 0.5).astype(int)
        valid_binary_preds = (valid_pred_matrix > 0.5).astype(int)

        # Print per-class metrics
        for cls in classes:
            train_acc = accuracy_score(train_truth[cls], (np.array(train_preds[cls]) > 0.5).astype(int))
            valid_acc = accuracy_score(valid_truth[cls], (np.array(valid_preds[cls]) > 0.5).astype(int))

            # ROC-AUC might require both classes present in data
            if len(np.unique(train_truth[cls])) > 1:
                train_roc = roc_auc_score(train_truth[cls], train_preds[cls])
            else:
                train_roc = float('nan')
            if len(np.unique(valid_truth[cls])) > 1:
                valid_roc = roc_auc_score(valid_truth[cls], valid_preds[cls])
            else:
                valid_roc = float('nan')

            print(f"\nClass: {cls.upper()}")
            print(f"Train Acc: {train_acc:.4f}, Train ROC-AUC: {train_roc:.4f}")
            print(f"Valid Acc: {valid_acc:.4f}, Valid ROC-AUC: {valid_roc:.4f}")

        print("---------------------------------------------------")


In [47]:
classes = ["abnormal", "acl", "meniscus"]

In [48]:
train_and_validate(model, train_loader, valid_loader, criterion, optimizer, epochs=20, classes=classes)

Epoch 1/20
Train Loss: 0.4838 | Valid Loss: 0.9988

Class: ABNORMAL
Train Acc: 0.8071, Train ROC-AUC: 0.7934
Valid Acc: 0.7500, Valid ROC-AUC: 0.6926

Class: ACL
Train Acc: 0.8124, Train ROC-AUC: 0.6276
Valid Acc: 0.5667, Valid ROC-AUC: 0.5881

Class: MENISCUS
Train Acc: 0.6805, Train ROC-AUC: 0.6955
Valid Acc: 0.5333, Valid ROC-AUC: 0.6674
---------------------------------------------------
Epoch 2/20
Train Loss: 0.4919 | Valid Loss: 0.6691

Class: ABNORMAL
Train Acc: 0.7903, Train ROC-AUC: 0.7899
Valid Acc: 0.7250, Valid ROC-AUC: 0.7242

Class: ACL
Train Acc: 0.8168, Train ROC-AUC: 0.6451
Valid Acc: 0.5583, Valid ROC-AUC: 0.5629

Class: MENISCUS
Train Acc: 0.6637, Train ROC-AUC: 0.6677
Valid Acc: 0.5667, Valid ROC-AUC: 0.6643
---------------------------------------------------
Epoch 3/20
Train Loss: 0.4839 | Valid Loss: 2.7417

Class: ABNORMAL
Train Acc: 0.8044, Train ROC-AUC: 0.7869
Valid Acc: 0.7917, Valid ROC-AUC: 0.6139

Class: ACL
Train Acc: 0.8230, Train ROC-AUC: 0.6757
Valid A

Now, Coronal Axes



In [49]:
def load_single_axis_mri_data(data_type="train", start_idx=0, end_idx=9, target_shape=(48, 256, 256)):
    data_path = train_path if data_type == "train" else valid_path
    abnormal_labels = train_abnormal if data_type == "train" else valid_abnormal
    acl_labels = train_acl if data_type == "train" else valid_acl
    meniscus_labels = train_meniscus if data_type == "train" else valid_meniscus

    data = []
    labels = []

    for i in range(start_idx, end_idx + 1):
        file_name = f"{i:04}.npy"
        scan = np.load(Path(data_path) / "coronal" / file_name)
        scan = pad_or_crop_scan(scan, target_shape)

        # Convert 3D (D,H,W) to 2D by taking the middle slice
        d = scan.shape[0]
        mid_slice = d // 2
        single_slice = scan[mid_slice, :, :]  # shape (H, W)

        # Add channel dimension: (1, H, W)
        single_slice = np.expand_dims(single_slice, axis=0)
        data.append(single_slice)  # will become (1,H,W)

        # Labels: [abnormal, acl, meniscus]
        abnormal_label = abnormal_labels.get(i, 0)
        acl_label = acl_labels.get(i, 0)
        meniscus_label = meniscus_labels.get(i, 0)
        labels.append([abnormal_label, acl_label, meniscus_label])

    data = np.array(data)    # (N, 1, H, W)
    labels = np.array(labels) # (N, 3)
    return data, labels


In [50]:
start_indices = list(range(0, 1130, 100))
end_indices = [min(start + 99, 1129) for start in start_indices]

train_data, train_labels = load_data_in_batches(start_indices, end_indices, "train")
valid_data, valid_labels = load_single_axis_mri_data("valid", 1130, 1249)

print("Train data shape:", train_data.shape)   # (N,1,H,W)
print("Train labels shape:", train_labels.shape)# (N,3)
print("Valid data shape:", valid_data.shape)
print("Valid labels shape:", valid_labels.shape)

Train data shape: (1130, 1, 256, 256)
Train labels shape: (1130, 3)
Valid data shape: (120, 1, 256, 256)
Valid labels shape: (120, 3)


In [51]:
train_dataset = MRNetDataset(train_data, train_labels)
valid_dataset = MRNetDataset(valid_data, valid_labels)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=8, shuffle=False)

train_loader = MpDeviceLoader(train_loader, device)
valid_loader = MpDeviceLoader(valid_loader, device)

In [52]:
train_and_validate(model, train_loader, valid_loader, criterion, optimizer, epochs=20, classes=classes)

Epoch 1/20
Train Loss: 0.5954 | Valid Loss: 0.7354

Class: ABNORMAL
Train Acc: 0.7965, Train ROC-AUC: 0.6530
Valid Acc: 0.7917, Valid ROC-AUC: 0.5962

Class: ACL
Train Acc: 0.7982, Train ROC-AUC: 0.4951
Valid Acc: 0.5500, Valid ROC-AUC: 0.5460

Class: MENISCUS
Train Acc: 0.6274, Train ROC-AUC: 0.5626
Valid Acc: 0.5500, Valid ROC-AUC: 0.5263
---------------------------------------------------
Epoch 2/20
Train Loss: 0.5221 | Valid Loss: 0.7165

Class: ABNORMAL
Train Acc: 0.8071, Train ROC-AUC: 0.6860
Valid Acc: 0.7917, Valid ROC-AUC: 0.5945

Class: ACL
Train Acc: 0.8150, Train ROC-AUC: 0.5519
Valid Acc: 0.5417, Valid ROC-AUC: 0.4952

Class: MENISCUS
Train Acc: 0.6460, Train ROC-AUC: 0.6255
Valid Acc: 0.6083, Valid ROC-AUC: 0.5761
---------------------------------------------------
Epoch 3/20
Train Loss: 0.5147 | Valid Loss: 0.7382

Class: ABNORMAL
Train Acc: 0.8062, Train ROC-AUC: 0.7022
Valid Acc: 0.7917, Valid ROC-AUC: 0.5979

Class: ACL
Train Acc: 0.8159, Train ROC-AUC: 0.5738
Valid A

Now, Sagittal axes


In [53]:
def load_single_axis_mri_data(data_type="train", start_idx=0, end_idx=9, target_shape=(48, 256, 256)):
    data_path = train_path if data_type == "train" else valid_path
    abnormal_labels = train_abnormal if data_type == "train" else valid_abnormal
    acl_labels = train_acl if data_type == "train" else valid_acl
    meniscus_labels = train_meniscus if data_type == "train" else valid_meniscus

    data = []
    labels = []

    for i in range(start_idx, end_idx + 1):
        file_name = f"{i:04}.npy"
        scan = np.load(Path(data_path) / "sagittal" / file_name)
        scan = pad_or_crop_scan(scan, target_shape)

        # Convert 3D (D,H,W) to 2D by taking the middle slice
        d = scan.shape[0]
        mid_slice = d // 2
        single_slice = scan[mid_slice, :, :]  # shape (H, W)

        # Add channel dimension: (1, H, W)
        single_slice = np.expand_dims(single_slice, axis=0)
        data.append(single_slice)  # will become (1,H,W)

        # Labels: [abnormal, acl, meniscus]
        abnormal_label = abnormal_labels.get(i, 0)
        acl_label = acl_labels.get(i, 0)
        meniscus_label = meniscus_labels.get(i, 0)
        labels.append([abnormal_label, acl_label, meniscus_label])

    data = np.array(data)    # (N, 1, H, W)
    labels = np.array(labels) # (N, 3)
    return data, labels


In [54]:
start_indices = list(range(0, 1130, 100))
end_indices = [min(start + 99, 1129) for start in start_indices]

train_data, train_labels = load_data_in_batches(start_indices, end_indices, "train")
valid_data, valid_labels = load_single_axis_mri_data("valid", 1130, 1249)

print("Train data shape:", train_data.shape)   # (N,1,H,W)
print("Train labels shape:", train_labels.shape)# (N,3)
print("Valid data shape:", valid_data.shape)
print("Valid labels shape:", valid_labels.shape)

Train data shape: (1130, 1, 256, 256)
Train labels shape: (1130, 3)
Valid data shape: (120, 1, 256, 256)
Valid labels shape: (120, 3)


In [55]:
train_dataset = MRNetDataset(train_data, train_labels)
valid_dataset = MRNetDataset(valid_data, valid_labels)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=8, shuffle=False)

train_loader = MpDeviceLoader(train_loader, device)
valid_loader = MpDeviceLoader(valid_loader, device)

In [56]:
train_and_validate(model, train_loader, valid_loader, criterion, optimizer, epochs=20, classes=classes)

Epoch 1/20
Train Loss: 0.5977 | Valid Loss: 0.6508

Class: ABNORMAL
Train Acc: 0.7938, Train ROC-AUC: 0.6613
Valid Acc: 0.7917, Valid ROC-AUC: 0.7322

Class: ACL
Train Acc: 0.8062, Train ROC-AUC: 0.5622
Valid Acc: 0.5500, Valid ROC-AUC: 0.5342

Class: MENISCUS
Train Acc: 0.6204, Train ROC-AUC: 0.5681
Valid Acc: 0.6583, Valid ROC-AUC: 0.6867
---------------------------------------------------
Epoch 2/20
Train Loss: 0.4882 | Valid Loss: 0.7269

Class: ABNORMAL
Train Acc: 0.8088, Train ROC-AUC: 0.7606
Valid Acc: 0.7750, Valid ROC-AUC: 0.7213

Class: ACL
Train Acc: 0.8159, Train ROC-AUC: 0.6875
Valid Acc: 0.5500, Valid ROC-AUC: 0.5800

Class: MENISCUS
Train Acc: 0.6735, Train ROC-AUC: 0.6641
Valid Acc: 0.6083, Valid ROC-AUC: 0.6643
---------------------------------------------------
Epoch 3/20
Train Loss: 0.4537 | Valid Loss: 0.6917

Class: ABNORMAL
Train Acc: 0.8212, Train ROC-AUC: 0.8072
Valid Acc: 0.7917, Valid ROC-AUC: 0.7663

Class: ACL
Train Acc: 0.8372, Train ROC-AUC: 0.7508
Valid A

# 1. Performance Analysis of Current Views with ResNet-18

## Axial View

### Final Epoch Metrics (Epoch 20/20)

| Metric         | Value  |
|----------------|--------|
| **Train Loss** | 0.1882 |
| **Valid Loss** | 1.8368 |

### ABNORMAL

| Metric            | Value  |
|-------------------|--------|
| **Train Acc**     | 0.9133 |
| **Train ROC-AUC** | 0.9616 |
| **Valid Acc**     | 0.8000 |
| **Valid ROC-AUC** | 0.7596 |

### ACL

| Metric            | Value  |
|-------------------|--------|
| **Train Acc**     | 0.9531 |
| **Train ROC-AUC** | 0.9751 |
| **Valid Acc**     | 0.5667 |
| **Valid ROC-AUC** | 0.5835 |

### MENISCUS

| Metric            | Value  |
|-------------------|--------|
| **Train Acc**     | 0.9009 |
| **Train ROC-AUC** | 0.9599 |
| **Valid Acc**     | 0.5417 |
| **Valid ROC-AUC** | 0.6202 |

## Sagittal View

### Final Epoch Metrics (Epoch 20/20)

| Metric         | Value  |
|----------------|--------|
| **Train Loss** | 0.2576 |
| **Valid Loss** | 1.1965 |

### ABNORMAL

| Metric            | Value  |
|-------------------|--------|
| **Train Acc**     | 0.8673 |
| **Train ROC-AUC** | 0.8731 |
| **Valid Acc**     | 0.7917 |
| **Valid ROC-AUC** | 0.6177 |

### ACL

| Metric            | Value  |
|-------------------|--------|
| **Train Acc**     | 0.9133 |
| **Train ROC-AUC** | 0.9120 |
| **Valid Acc**     | 0.5667 |
| **Valid ROC-AUC** | 0.6089 |

### MENISCUS

| Metric            | Value  |
|-------------------|--------|
| **Train Acc**     | 0.8761 |
| **Train ROC-AUC** | 0.9663 |
| **Valid Acc**     | 0.5833 |
| **Valid ROC-AUC** | 0.5894 |

## Coronal View

### Final Epoch Metrics (Epoch 20/20)

| Metric         | Value  |
|----------------|--------|
| **Train Loss** | 0.2576 |
| **Valid Loss** | 1.1965 |

### ABNORMAL

| Metric            | Value  |
|-------------------|--------|
| **Train Acc**     | 0.8673 |
| **Train ROC-AUC** | 0.8731 |
| **Valid Acc**     | 0.7917 |
| **Valid ROC-AUC** | 0.6177 |

### ACL

| Metric            | Value  |
|-------------------|--------|
| **Train Acc**     | 0.9133 |
| **Train ROC-AUC** | 0.9120 |
| **Valid Acc**     | 0.5667 |
| **Valid ROC-AUC** | 0.6089 |

### MENISCUS

| Metric            | Value  |
|-------------------|--------|
| **Train Acc**     | 0.8761 |
| **Train ROC-AUC** | 0.9663 |
| **Valid Acc**     | 0.5833 |
| **Valid ROC-AUC** | 0.5894 |

# 2. Why Choose the Axial View for ResNet-50?

## a. Superior Performance Metrics

### ABNORMAL Class
- **Axial view** achieved the highest **Train ROC-AUC (0.9616)** and **Valid ROC-AUC (0.7596)** compared to sagittal and coronal views.

### ACL Class
- While **axial view** has a high **Train ROC-AUC (0.9751)**, the **Valid ROC-AUC (0.5835)** indicates room for improvement. However, it still outperforms the coronal view's **Valid ROC-AUC (~0.6)** and is comparable to the sagittal view.

### MENISCUS Class
- **Axial view** shows a high **Train ROC-AUC (0.9599)** and a moderate **Valid ROC-AUC (0.6202)**, which is better than the coronal view and comparable to the sagittal view.

## b. Consistency and Feature Richness
- The **axial view** likely captures more discriminative anatomical features relevant to your classification tasks. This is evident from the higher ROC-AUC scores in critical classes.
- **Axial scans** provide a top-down view, which might be more informative for detecting abnormalities, ACL injuries, and meniscus issues compared to other views.

## c. Potential for Better Generalization with ResNet-50
- **ResNet-50** is a deeper architecture with more parameters, which can potentially capture more complex patterns and improve performance, especially in scenarios where the current model shows promising results but has room to grow.
- Given that **ResNet-18** already performs relatively well on the axial view, **ResNet-50** can leverage its increased capacity to further enhance feature extraction and classification performance.

---

We will now try **ResNet50** to explore its effectiveness and potentially reduce overfitting while maintaining strong performance metrics.


In [62]:
class SingleChannelResNet50(nn.Module):
    def __init__(self, num_classes=3, pretrained=True):
        super(SingleChannelResNet50, self).__init__()
        # Initialize the ResNet-50 model
        self.model = resnet50()

        # Modify the first convolutional layer to accept 1-channel input
        original_conv1 = self.model.conv1
        self.model.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=original_conv1.out_channels,
            kernel_size=original_conv1.kernel_size,
            stride=original_conv1.stride,
            padding=original_conv1.padding,
            bias=original_conv1.bias
        )

        # Initialize the new conv1 weights
        if pretrained:
            with torch.no_grad():
                # Average the weights across the original RGB channels to initialize the single channel
                self.model.conv1.weight = nn.Parameter(original_conv1.weight.sum(dim=1, keepdim=True))

        # Replace the final fully connected layer to output `num_classes` logits
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.model(x)


In [63]:
def load_single_axis_mri_data(data_type="train", start_idx=0, end_idx=9, target_shape=(48, 256, 256)):
    data_path = train_path if data_type == "train" else valid_path
    abnormal_labels = train_abnormal if data_type == "train" else valid_abnormal
    acl_labels = train_acl if data_type == "train" else valid_acl
    meniscus_labels = train_meniscus if data_type == "train" else valid_meniscus

    data = []
    labels = []

    for i in range(start_idx, end_idx + 1):
        file_name = f"{i:04}.npy"
        scan = np.load(Path(data_path) / "axial" / file_name)
        scan = pad_or_crop_scan(scan, target_shape)

        # Convert 3D (D,H,W) to 2D by taking the middle slice
        d = scan.shape[0]
        mid_slice = d // 2
        single_slice = scan[mid_slice, :, :]  # shape (H, W)

        # Add channel dimension: (1, H, W)
        single_slice = np.expand_dims(single_slice, axis=0)
        data.append(single_slice)  # will become (1,H,W)

        # Labels: [abnormal, acl, meniscus]
        abnormal_label = abnormal_labels.get(i, 0)
        acl_label = acl_labels.get(i, 0)
        meniscus_label = meniscus_labels.get(i, 0)
        labels.append([abnormal_label, acl_label, meniscus_label])

    data = np.array(data)    # (N, 1, H, W)
    labels = np.array(labels) # (N, 3)
    return data, labels


In [64]:
start_indices = list(range(0, 1130, 100))
end_indices = [min(start + 99, 1129) for start in start_indices]

train_data, train_labels = load_data_in_batches(start_indices, end_indices, "train")
valid_data, valid_labels = load_single_axis_mri_data("valid", 1130, 1249)

print("Train data shape:", train_data.shape)   # (N,1,H,W)
print("Train labels shape:", train_labels.shape)# (N,3)
print("Valid data shape:", valid_data.shape)
print("Valid labels shape:", valid_labels.shape)

Train data shape: (1130, 1, 256, 256)
Train labels shape: (1130, 3)
Valid data shape: (120, 1, 256, 256)
Valid labels shape: (120, 3)


In [65]:
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [66]:
model = SingleChannelResNet50(num_classes=3).to(device)

In [67]:
train_and_validate(model, train_loader, valid_loader, criterion, optimizer, epochs=20, classes=classes)

Epoch 1/20
Train Loss: 0.6417 | Valid Loss: 0.6687

Class: ABNORMAL
Train Acc: 0.7133, Train ROC-AUC: 0.6898
Valid Acc: 0.7167, Valid ROC-AUC: 0.6261

Class: ACL
Train Acc: 0.7938, Train ROC-AUC: 0.4604
Valid Acc: 0.5333, Valid ROC-AUC: 0.4826

Class: MENISCUS
Train Acc: 0.6487, Train ROC-AUC: 0.5943
Valid Acc: 0.5667, Valid ROC-AUC: 0.6852
---------------------------------------------------
Epoch 2/20
Train Loss: 0.6431 | Valid Loss: 0.6706

Class: ABNORMAL
Train Acc: 0.7071, Train ROC-AUC: 0.6734
Valid Acc: 0.7000, Valid ROC-AUC: 0.6278

Class: ACL
Train Acc: 0.7929, Train ROC-AUC: 0.4694
Valid Acc: 0.5500, Valid ROC-AUC: 0.4717

Class: MENISCUS
Train Acc: 0.6496, Train ROC-AUC: 0.5682
Valid Acc: 0.5667, Valid ROC-AUC: 0.6547
---------------------------------------------------
Epoch 3/20
Train Loss: 0.6428 | Valid Loss: 0.6709

Class: ABNORMAL
Train Acc: 0.7035, Train ROC-AUC: 0.6707
Valid Acc: 0.7167, Valid ROC-AUC: 0.6291

Class: ACL
Train Acc: 0.8009, Train ROC-AUC: 0.4646
Valid A

# Summary of Results

## 1. ResNet-18 on Axial View (Final Epoch: Epoch 20/20)

**Train Loss:** 0.1882  
**Valid Loss:** 1.8368  

### Class-wise Metrics

#### ABNORMAL

| Metric            | ResNet-18 |
|-------------------|-----------|
| **Train Accuracy** | 0.9133    |
| **Train ROC-AUC** | 0.9616    |
| **Valid Accuracy** | 0.8000    |
| **Valid ROC-AUC** | 0.7596    |

#### ACL

| Metric            | ResNet-18 |
|-------------------|-----------|
| **Train Accuracy** | 0.9531    |
| **Train ROC-AUC** | 0.9751    |
| **Valid Accuracy** | 0.5667    |
| **Valid ROC-AUC** | 0.5835    |

#### MENISCUS

| Metric            | ResNet-18 |
|-------------------|-----------|
| **Train Accuracy** | 0.9009    |
| **Train ROC-AUC** | 0.9599    |
| **Valid Accuracy** | 0.5417    |
| **Valid ROC-AUC** | 0.6202    |

---

## 2. ResNet-50 on Axial View (Final Epoch: Epoch 20/20)

**Train Loss:** 0.6432  
**Valid Loss:** 0.6661  

### Class-wise Metrics

#### ABNORMAL

| Metric            | ResNet-50 |
|-------------------|-----------|
| **Train Accuracy** | 0.7000    |
| **Train ROC-AUC** | 0.6818    |
| **Valid Accuracy** | 0.7167    |
| **Valid ROC-AUC** | 0.6341    |

#### ACL

| Metric            | ResNet-50 |
|-------------------|-----------|
| **Train Accuracy** | 0.7912    |
| **Train ROC-AUC** | 0.4585    |
| **Valid Accuracy** | 0.5167    |
| **Valid ROC-AUC** | 0.4902    |

#### MENISCUS

| Metric            | ResNet-50 |
|-------------------|-----------|
| **Train Accuracy** | 0.6478    |
| **Train ROC-AUC** | 0.5719    |
| **Valid Accuracy** | 0.5667    |
| **Valid ROC-AUC** | 0.6377    |

---

# Detailed Comparison

## A. Loss Metrics

| Metric              | ResNet-18 (Axial) | ResNet-50 (Axial) |
|---------------------|-------------------|-------------------|
| **Training Loss**   | 0.1882 (Lower)    | 0.6432 (Higher)   |
| **Validation Loss** | 1.8368 (Higher)   | 0.6661 (Lower)    |

**Interpretation:**

- **ResNet-18** achieves a lower training loss but a higher validation loss, indicating that it fits the training data very well but may be **overfitting**, struggling to generalize to unseen data.
- **ResNet-50** exhibits a higher training loss but a lower validation loss, suggesting that it might be **underfitting** and not capturing the underlying patterns in the data as effectively as ResNet-18.

## B. Per-Class Accuracy and ROC-AUC

### ABNORMAL Class

| Metric            | ResNet-18 | ResNet-50 |
|-------------------|-----------|-----------|
| **Train Acc**     | 91.33%    | 70.00%    |
| **Train ROC-AUC** | 0.9616    | 0.6818    |
| **Valid Acc**     | 80.00%    | 71.67%    |
| **Valid ROC-AUC** | 0.7596    | 0.6341    |

**Conclusion:**  
ResNet-18 significantly outperforms ResNet-50 in both training and validation for the **ABNORMAL** class.

### ACL Class

| Metric            | ResNet-18 | ResNet-50 |
|-------------------|-----------|-----------|
| **Train Acc**     | 95.31%    | 79.12%    |
| **Train ROC-AUC** | 0.9751    | 0.4585    |
| **Valid Acc**     | 56.67%    | 51.67%    |
| **Valid ROC-AUC** | 0.5835    | 0.4902    |

**Conclusion:**  
ResNet-18 shows superior performance in both accuracy and ROC-AUC for the **ACL** class compared to ResNet-50.

### MENISCUS Class

| Metric            | ResNet-18 | ResNet-50 |
|-------------------|-----------|-----------|
| **Train Acc**     | 90.09%    | 64.78%    |
| **Train ROC-AUC** | 0.9599    | 0.5719    |
| **Valid Acc**     | 54.17%    | 56.67%    |
| **Valid ROC-AUC** | 0.6202    | 0.6377    |

**Conclusion:**  
While ResNet-50 shows a slight improvement in validation accuracy and ROC-AUC for the **MENISCUS** class, ResNet-18's training performance is substantially better. However, both models struggle with generalization in validation.

## C. Overall Performance

- **ResNet-18:** Demonstrates higher training accuracies and ROC-AUC scores across all classes, indicating strong learning from the training data.
- **ResNet-50:** Despite its more complex architecture, fails to generalize effectively to the validation set, as evidenced by lower ROC-AUC and comparable or slightly better validation accuracies in some classes.

## D. Potential Reasons for ResNet-50's Underperformance

- **Overfitting vs. Underfitting:**
  - **ResNet-18** might be overfitting the training data, achieving low training loss but high validation loss.
  - **ResNet-50** appears to underfit, not capturing the data's complexities sufficiently, leading to higher training loss but slightly better validation loss.

- **Training Configuration:**
  - **Learning Rate:** ResNet-50 may require different learning rate settings. Using the same learning rate as ResNet-18 might not be optimal.
  - **Batch Size:** Due to its larger size, ResNet-50 might benefit from different batch sizes to stabilize training.
  - **Epochs and Early Stopping:** ResNet-50 may need more epochs or early stopping criteria to optimize its performance effectively.

- **Model Complexity and Data Size:**
  - **ResNet-50** has more parameters, requiring more data to train effectively. If the dataset isn't large enough, ResNet-50 might not utilize its capacity fully.

- **Regularization Techniques:**
  - Insufficient regularization (e.g., dropout, weight decay) can adversely affect ResNet-50's performance, leading to underfitting or overfitting.

- **Initialization and Fine-Tuning:**
  - If ResNet-50 wasn't properly initialized or fine-tuned, it might not leverage the pre-trained weights effectively.

---

# Recommendations

**Continue with ResNet-18 on Axial View:**

Given the current performance metrics, **ResNet-18** trained on the axial view is performing better overall, especially in critical classes like **ABNORMAL** and **ACL**. It shows higher ROC-AUC scores, indicating better discriminative ability.

We will now try **ResNet18** to further explore its effectiveness and potentially reduce overfitting while maintaining strong performance metrics.


In [68]:
optimizer = optim.Adam(model.parameters(), lr=0.0001)

In [71]:
model = SingleChannelResNet18(num_classes=3).to(device)

In [72]:
train_and_validate(model, train_loader, valid_loader, criterion, optimizer, epochs=20, classes=classes)

Epoch 1/20
Train Loss: 0.6330 | Valid Loss: 0.6672

Class: ABNORMAL
Train Acc: 0.8000, Train ROC-AUC: 0.6401
Valid Acc: 0.7250, Valid ROC-AUC: 0.5815

Class: ACL
Train Acc: 0.8142, Train ROC-AUC: 0.4810
Valid Acc: 0.5500, Valid ROC-AUC: 0.5449

Class: MENISCUS
Train Acc: 0.6487, Train ROC-AUC: 0.5421
Valid Acc: 0.5667, Valid ROC-AUC: 0.4893
---------------------------------------------------
Epoch 2/20
Train Loss: 0.6323 | Valid Loss: 0.6650

Class: ABNORMAL
Train Acc: 0.8000, Train ROC-AUC: 0.6730
Valid Acc: 0.7250, Valid ROC-AUC: 0.5865

Class: ACL
Train Acc: 0.8159, Train ROC-AUC: 0.5278
Valid Acc: 0.5500, Valid ROC-AUC: 0.5474

Class: MENISCUS
Train Acc: 0.6487, Train ROC-AUC: 0.5284
Valid Acc: 0.5667, Valid ROC-AUC: 0.5062
---------------------------------------------------
Epoch 3/20
Train Loss: 0.6337 | Valid Loss: 0.6662

Class: ABNORMAL
Train Acc: 0.7947, Train ROC-AUC: 0.6435
Valid Acc: 0.7250, Valid ROC-AUC: 0.5861

Class: ACL
Train Acc: 0.8150, Train ROC-AUC: 0.4978
Valid A

# Summary of Results

## 1. Previous ResNet-18 on Axial View (Final Epoch: Epoch 20/20)

**Train Loss:** 0.1882  
**Valid Loss:** 1.8368  

### Class-wise Metrics

#### ABNORMAL

| Metric             | Value  |
|--------------------|--------|
| **Train Accuracy** | 0.9133 |
| **Train ROC-AUC**  | 0.9616 |
| **Valid Accuracy** | 0.8000 |
| **Valid ROC-AUC**  | 0.7596 |

#### ACL

| Metric             | Value  |
|--------------------|--------|
| **Train Accuracy** | 0.9531 |
| **Train ROC-AUC**  | 0.9751 |
| **Valid Accuracy** | 0.5667 |
| **Valid ROC-AUC**  | 0.5835 |

#### MENISCUS

| Metric             | Value  |
|--------------------|--------|
| **Train Accuracy** | 0.9009 |
| **Train ROC-AUC**  | 0.9599 |
| **Valid Accuracy** | 0.5417 |
| **Valid ROC-AUC**  | 0.6202 |

---

## 2. New ResNet-18 with Learning Rate = 0.0001 on Axial View (Final Epoch: Epoch 20/20)

**Train Loss:** 0.6335  
**Valid Loss:** 0.6657  

### Class-wise Metrics

#### ABNORMAL

| Metric             | Value  |
|--------------------|--------|
| **Train Accuracy** | 0.7912 |
| **Train ROC-AUC**  | 0.6481 |
| **Valid Accuracy** | 0.7250 |
| **Valid ROC-AUC**  | 0.5941 |

#### ACL

| Metric             | Value  |
|--------------------|--------|
| **Train Accuracy** | 0.8150 |
| **Train ROC-AUC**  | 0.5133 |
| **Valid Accuracy** | 0.5500 |
| **Valid ROC-AUC**  | 0.5477 |

#### MENISCUS

| Metric             | Value  |
|--------------------|--------|
| **Train Accuracy** | 0.6487 |
| **Train ROC-AUC**  | 0.5130 |
| **Valid Accuracy** | 0.5667 |
| **Valid ROC-AUC**  | 0.5170 |

---

# Detailed Comparison

## A. Loss Metrics

| Metric            | Previous ResNet-18 | New ResNet-18 (LR=0.0001) |
|-------------------|--------------------|----------------------------|
| **Training Loss** | 0.1882 (Lower)     | 0.6335 (Higher)            |
| **Validation Loss** | 1.8368 (Higher)   | 0.6657 (Lower)             |

**Interpretation:**

- **Training Loss:** The previous ResNet-18 achieved a much lower training loss, indicating it fit the training data more effectively.
- **Validation Loss:** The new ResNet-18 has a lower validation loss, which might suggest better generalization at a glance. However, this should be interpreted cautiously given the other metrics.

## B. Per-Class Accuracy and ROC-AUC

### 1. ABNORMAL Class

| Metric             | Previous ResNet-18 | New ResNet-18 (LR=0.0001) |
|--------------------|--------------------|----------------------------|
| **Train Accuracy** | 91.33%             | 79.12%                     |
| **Train ROC-AUC**  | 0.9616             | 0.6481                     |
| **Valid Accuracy** | 80.00%             | 72.50%                     |
| **Valid ROC-AUC**  | 0.7596             | 0.5941                     |

**Conclusion:**  
The previous ResNet-18 significantly outperforms the new ResNet-18 in both training and validation metrics for the **ABNORMAL** class.

### 2. ACL Class

| Metric             | Previous ResNet-18 | New ResNet-18 (LR=0.0001) |
|--------------------|--------------------|----------------------------|
| **Train Accuracy** | 95.31%             | 81.50%                     |
| **Train ROC-AUC**  | 0.9751             | 0.5133                     |
| **Valid Accuracy** | 56.67%             | 55.00%                     |
| **Valid ROC-AUC**  | 0.5835             | 0.5477                     |

**Conclusion:**  
The previous ResNet-18 shows superior performance in both accuracy and ROC-AUC for the **ACL** class compared to the new ResNet-18.

### 3. MENISCUS Class

| Metric             | Previous ResNet-18 | New ResNet-18 (LR=0.0001) |
|--------------------|--------------------|----------------------------|
| **Train Accuracy** | 90.09%             | 64.87%                     |
| **Train ROC-AUC**  | 0.9599             | 0.5130                     |
| **Valid Accuracy** | 54.17%             | 56.67%                     |
| **Valid ROC-AUC**  | 0.6202             | 0.5170                     |

**Conclusion:**  
While the new ResNet-18 shows a slight improvement in validation accuracy for the **MENISCUS** class, the ROC-AUC remains significantly lower, indicating poorer discriminative ability.

## C. Overall Performance

- **ResNet-18 (Previous):**
  - **High Training Performance:** The model fits the training data exceptionally well, as evidenced by high accuracies and ROC-AUC scores.
  - **Moderate Validation Performance:** Despite overfitting, the validation metrics are reasonably good, especially for the ABNORMAL class.

- **ResNet-18 (New with LR=0.0001):**
  - **Lower Training Performance:** The model does not fit the training data as well, resulting in lower accuracies and ROC-AUC scores.
  - **Comparable or Slightly Lower Validation Performance:** Validation metrics are on par or slightly worse than the previous model, despite a lower validation loss.

## D. Potential Reasons for ResNet-18 (New) Underperformance

- **Learning Rate Adjustment:**
  - **Lower Learning Rate (0.0001):** The reduced learning rate may have slowed down the training process, preventing the model from adequately fitting the training data within 20 epochs.

- **Model Capacity:**
  - **ResNet-18's Flexibility:** Even though ResNet-18 is less complex than deeper models, the change in learning rate may have affected its ability to learn effectively from the data.

- **Training Duration:**
  - **Insufficient Epochs:** With a lower learning rate, the model might require more epochs to converge and achieve better performance.

- **Regularization Techniques:**
  - **Potential Over-regularization:** If additional regularization was applied (e.g., dropout, weight decay), it might have hindered the model's ability to learn from the training data.

- **Batch Size and Optimization:**
  - **Batch Size Impact:** The chosen batch size in training could influence the optimization dynamics, especially with a lower learning rate.
  - **Optimizer Configuration:** The optimizer settings (e.g., momentum, weight decay) might need adjustment to complement the lower learning rate.

---

# Recommendations

**Revert to the Previous ResNet-18 Configuration:**

Given the substantial drop in performance metrics with the new ResNet-18 (LR=0.0001), it is advisable to continue with the **previous ResNet-18 configuration**. The earlier model demonstrated superior training and validation performance, particularly in critical classes like **ABNORMAL** and **ACL**.
