In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os
import numpy as np
import pandas as pd
from pathlib import Path
from scipy.ndimage import zoom
from concurrent.futures import ThreadPoolExecutor, as_completed
import warnings
import torchio as tio
from torchio import SubjectsLoader
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, precision_score, recall_score
import torchvision.models as models


In [5]:
# Define paths based on your directory structure
main_dir = "/content/drive/My Drive/DATASET_MRNET/MRNet-v1.0"
train_path = os.path.join(main_dir, "train")
valid_path = os.path.join(main_dir, "valid")

In [6]:
# Load labels from CSV files
train_abnormal = pd.read_csv(os.path.join(main_dir, "train-abnormal.csv"), header=None, index_col=0).squeeze("columns").to_dict()
train_acl = pd.read_csv(os.path.join(main_dir, "train-acl.csv"), header=None, index_col=0).squeeze("columns").to_dict()
train_meniscus = pd.read_csv(os.path.join(main_dir, "train-meniscus.csv"), header=None, index_col=0).squeeze("columns").to_dict()

valid_abnormal = pd.read_csv(os.path.join(main_dir, "valid-abnormal.csv"), header=None, index_col=0).squeeze("columns").to_dict()
valid_acl = pd.read_csv(os.path.join(main_dir, "valid-acl.csv"), header=None, index_col=0).squeeze("columns").to_dict()
valid_meniscus = pd.read_csv(os.path.join(main_dir, "valid-meniscus.csv"), header=None, index_col=0).squeeze("columns").to_dict()


In [7]:
# Function to resize the depth of a scan to a target depth
def resize_depth(scan, target_depth):
    depth_factor = target_depth / scan.shape[0]
    return zoom(scan, (depth_factor, 1, 1), order=1)

In [8]:
# Function to pad a scan to a target shape
def pad_to_shape(scan, target_shape):
    padded_scan = np.zeros(target_shape, dtype=scan.dtype)
    min_d, min_h, min_w = min(scan.shape[0], target_shape[0]), min(scan.shape[1], target_shape[1]), min(scan.shape[2], target_shape[2])
    padded_scan[:min_d, :min_h, :min_w] = scan[:min_d, :min_h, :min_w]
    return padded_scan

In [9]:
# Function to load a specific range of MRI data with labels

def load_mri_data(data_type="train", start_idx=0, end_idx=9, target_shape=(48, 256, 256), target_depth=48):
    """
    Loads MRI data from a specified range and resizes/pads each scan to a target shape.
    Parameters:
    - data_type: "train" or "valid"
    - start_idx, end_idx: Range of file indices to load (e.g., 0 to 9 for train, 1130 to 1249 for valid)
    - target_shape: Target shape for each scan after resizing and padding
    - target_depth: Target depth for each scan to ensure consistent depth
    """
    # Set data path and range
    data_path = train_path if data_type == "train" else valid_path
    axial_path, coronal_path, sagittal_path = Path(data_path) / "axial", Path(data_path) / "coronal", Path(data_path) / "sagittal"

    # Select the appropriate labels dictionary based on data type
    abnormal_labels = train_abnormal if data_type == "train" else valid_abnormal
    acl_labels = train_acl if data_type == "train" else valid_acl
    meniscus_labels = train_meniscus if data_type == "train" else valid_meniscus

    # Initialize lists to store data and labels
    mri_data, labels = [], []

    # Load each MRI scan within the specified range
    for i in range(start_idx, end_idx + 1):
        # Generate file name with zero-padded format (e.g., 0000, 0001, ...)
        file_name = f"{i:04}.npy"

        # Load and process each view with resizing and padding
        axial_scan = pad_to_shape(resize_depth(np.load(axial_path / file_name), target_depth), target_shape)
        coronal_scan = pad_to_shape(resize_depth(np.load(coronal_path / file_name), target_depth), target_shape)
        sagittal_scan = pad_to_shape(resize_depth(np.load(sagittal_path / file_name), target_depth), target_shape)

        # Combine the three views into one structure (3, depth, height, width)
        combined_scan = np.stack([axial_scan, coronal_scan, sagittal_scan], axis=0)
        mri_data.append(combined_scan)

        # Retrieve actual labels for the current scan
        abnormal_label = abnormal_labels.get(i, 0)  # Default to 0 if label is missing
        acl_label = acl_labels.get(i, 0)
        meniscus_label = meniscus_labels.get(i, 0)

        # Append the actual labels
        labels.append({"abnormal": abnormal_label, "acl": acl_label, "meniscus": meniscus_label})

    return np.array(mri_data), labels

In [10]:
# Define the parameters for batch loading with exact final index coverage
start_indices = list(range(0, 1130, 100))
end_indices = [min(start + 99, 1129) for start in start_indices]  # Ensure final batch ends at 1129

# Function to load a batch of data given start and end indices
def load_batch(start, end):
    return load_mri_data(data_type="train", start_idx=start, end_idx=end)

# Initialize lists to store all data and labels
all_data, all_labels = [], []

# Use ThreadPoolExecutor to parallelize data loading for all batches
with ThreadPoolExecutor() as executor:
    # Launch parallel tasks for loading each batch
    future_to_indices = {executor.submit(load_batch, start, end): (start, end) for start, end in zip(start_indices, end_indices)}

    for future in as_completed(future_to_indices):
        data, labels = future.result()
        all_data.append(data)
        all_labels.extend(labels)  # Extend to add lists of labels directly

# Concatenate all data batches into a single array
train_data = np.concatenate(all_data, axis=0)
train_labels = all_labels  # Already extended to combine all label lists

# Check the final shape of training data and labels
print("Final training data shape:", train_data.shape)  # Expected: (1130, 3, 48, 256, 256)
print("Final number of training labels:", len(train_labels))  # Should match the number of samples in train_data


Final training data shape: (1130, 3, 48, 256, 256)
Final number of training labels: 1130


In [11]:
# Load the validation data from indices 1130 to 1249
valid_data, valid_labels = load_mri_data(data_type="valid", start_idx=1130, end_idx=1249)

# Check data shapes and labels
print("Validation data shape:", valid_data.shape)  # Expected: (120, 3, 48, 256, 256)

Validation data shape: (120, 3, 48, 256, 256)


In [15]:
# Define transformations using TorchIO
train_transforms = tio.Compose([
    tio.RandomFlip(axes=('LR',), flip_probability=0.5, include=('image',)),
    tio.ZNormalization(include=('image',)),
])

valid_transforms = tio.Compose([
    tio.ZNormalization(include=('image',)),
])

In [16]:
# Prepare subjects for training data
train_subjects = []
for i in range(len(train_data)):
    image_tensor = train_data[i]  # Shape: (C, D, H, W)
    label = train_labels[i]
    image = tio.ScalarImage(tensor=image_tensor)
    subject = tio.Subject(
    image=image,
    abnormal=label['abnormal'],
    acl=label['acl'],
    meniscus=label['meniscus']
    )
    train_subjects.append(subject)

# Prepare subjects for validation data
valid_subjects = []
for i in range(len(valid_data)):
    image_tensor = valid_data[i]
    label = valid_labels[i]
    image = tio.ScalarImage(tensor=image_tensor)
    subject = tio.Subject(
    image=image,
    abnormal=label['abnormal'],
    acl=label['acl'],
    meniscus=label['meniscus']
    )
    valid_subjects.append(subject)

# Create datasets using tio.SubjectsDataset
train_dataset = tio.SubjectsDataset(train_subjects, transform=train_transforms)
valid_dataset = tio.SubjectsDataset(valid_subjects, transform=valid_transforms)

In [17]:
train_loader = SubjectsLoader(train_dataset, batch_size=2, shuffle=True, num_workers=4)
valid_loader = SubjectsLoader(valid_dataset, batch_size=2, shuffle=False, num_workers=4)

In [18]:
print(type(train_dataset[0]))

<class 'torchio.data.subject.Subject'>


In [19]:
class MRI3DResNet(nn.Module):
    def __init__(self, num_classes=3):
        super(MRI3DResNet, self).__init__()
        # Load a pretrained 3D ResNet
        self.backbone = models.video.r3d_18(pretrained=True)
        # Modify the first convolutional layer to accept 3 input channels
        self.backbone.stem[0] = nn.Conv3d(
            in_channels=3,  # Change to 3 channels
            out_channels=64,
            kernel_size=(3, 7, 7),
            stride=(1, 2, 2),
            padding=(1, 3, 3),
            bias=False
        )
        # Replace the fully connected layer
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, num_classes)
        # No activation here because we'll use BCEWithLogitsLoss

    def forward(self, x):
        x = self.backbone(x)
        return x

In [20]:
criterion = nn.BCEWithLogitsLoss()

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MRI3DResNet(num_classes=3).to(device)

In [23]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

In [24]:
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

In [25]:
print(type(train_subjects[0]))

<class 'torchio.data.subject.Subject'>


In [None]:
import torch
print(torch.__version__)

2.5.1+cu121


In [26]:
for name, param in model.named_parameters():
    if not param.requires_grad:
        print(f"Parameter {name} does not require gradient")

In [27]:
num_epochs = 20
warnings.filterwarnings("ignore", category=UserWarning, module="torchio")

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    all_train_preds = []
    all_train_labels = []
    for batch in train_loader:
        data = batch['image'][tio.DATA].to(device)
        labels = torch.tensor(
            [batch['abnormal'], batch['acl'], batch['meniscus']],
            dtype=torch.float32
        ).T.to(device)

        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * data.size(0)

        # Collect predictions and labels for metrics
        outputs = torch.sigmoid(outputs)
        preds = (outputs > 0.5).float()
        all_train_preds.append(preds.cpu())
        all_train_labels.append(labels.cpu())

    # Compute training metrics
    all_train_preds = torch.cat(all_train_preds)
    all_train_labels = torch.cat(all_train_labels)
    train_accuracy = accuracy_score(all_train_labels.numpy(), all_train_preds.numpy())
    train_precision = precision_score(all_train_labels.numpy(), all_train_preds.numpy(), average='weighted', zero_division=0)
    train_recall = recall_score(all_train_labels.numpy(), all_train_preds.numpy(), average='weighted', zero_division=0)
    epoch_loss = running_loss / len(train_dataset)

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Training Loss: {epoch_loss:.4f}, Accuracy: {train_accuracy:.4f}, Precision: {train_precision:.4f}, Recall: {train_recall:.4f}")

    model.eval()
    val_running_loss = 0.0
    all_val_preds = []
    all_val_labels = []
    with torch.no_grad():
        for batch in valid_loader:
            data = batch['image'][tio.DATA].to(device)
            labels = torch.tensor(
                [batch['abnormal'], batch['acl'], batch['meniscus']],
                dtype=torch.float32
            ).T.to(device)

            outputs = model(data)
            loss = criterion(outputs, labels)
            val_running_loss += loss.item() * data.size(0)

            # Collect predictions and labels for metrics
            outputs = torch.sigmoid(outputs)
            preds = (outputs > 0.5).float()
            all_val_preds.append(preds.cpu())
            all_val_labels.append(labels.cpu())

    # Compute validation metrics
    all_val_preds = torch.cat(all_val_preds)
    all_val_labels = torch.cat(all_val_labels)
    val_accuracy = accuracy_score(all_val_labels.numpy(), all_val_preds.numpy())
    val_precision = precision_score(all_val_labels.numpy(), all_val_preds.numpy(), average='weighted', zero_division=0)
    val_recall = recall_score(all_val_labels.numpy(), all_val_preds.numpy(), average='weighted', zero_division=0)
    val_loss = val_running_loss / len(valid_dataset)

    print(f"Validation Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.4f}, Precision: {val_precision:.4f}, Recall: {val_recall:.4f}\n")
    scheduler.step()


Epoch 1/20
Training Loss: 0.5448, Accuracy: 0.3690, Precision: 0.6314, Recall: 0.6304
Validation Loss: 1.7998, Accuracy: 0.2417, Precision: 0.4035, Recall: 0.1741

Epoch 2/20
Training Loss: 0.5310, Accuracy: 0.3673, Precision: 0.5943, Recall: 0.6192
Validation Loss: 0.6786, Accuracy: 0.2250, Precision: 0.6074, Recall: 0.4826

Epoch 3/20
Training Loss: 0.5303, Accuracy: 0.3726, Precision: 0.5861, Recall: 0.6212
Validation Loss: 0.6962, Accuracy: 0.1833, Precision: 0.5896, Recall: 0.4030

Epoch 4/20
Training Loss: 0.5205, Accuracy: 0.3690, Precision: 0.6059, Recall: 0.6271
Validation Loss: 0.7724, Accuracy: 0.3500, Precision: 0.7381, Recall: 0.6915

Epoch 5/20
Training Loss: 0.5133, Accuracy: 0.3814, Precision: 0.6831, Recall: 0.6469
Validation Loss: 0.6772, Accuracy: 0.2000, Precision: 0.7362, Recall: 0.5821

Epoch 6/20
Training Loss: 0.5123, Accuracy: 0.3708, Precision: 0.6705, Recall: 0.6344
Validation Loss: 0.6924, Accuracy: 0.2083, Precision: 0.8128, Recall: 0.5274

Epoch 7/20
Train

## Explanation of 3D-ResNet Results:

### **Epochs 1-5**

| Epoch | Training Loss | Training Accuracy | Training Precision | Training Recall | Validation Loss | Validation Accuracy | Validation Precision | Validation Recall |
|-------|---------------|-------------------|--------------------|------------------|------------------|---------------------|----------------------|--------------------|
| 1     | 0.5448        | 0.3690            | 0.6314             | 0.6304           | 1.7998           | 0.2417              | 0.4035               | 0.1741             |
| 2     | 0.5310        | 0.3673            | 0.5943             | 0.6192           | 0.6786           | 0.2250              | 0.6074               | 0.4826             |
| 3     | 0.5303        | 0.3726            | 0.5861             | 0.6212           | 0.6962           | 0.1833              | 0.5896               | 0.4030             |
| 4     | 0.5205        | 0.3690            | 0.6059             | 0.6271           | 0.7724           | 0.3500              | 0.7381               | 0.6915             |
| 5     | 0.5133        | 0.3814            | 0.6831             | 0.6469           | 0.6772           | 0.2000              | 0.7362               | 0.5821             |

**Observations:**
- **Training Metrics:** There's a slight decrease in training loss and a modest increase in training accuracy from Epoch 1 to Epoch 5. Precision and recall show improvement, indicating better performance on the training data.
- **Validation Metrics:** Validation loss fluctuates without a clear downward trend, and validation accuracy remains low (~18-35%). Precision and recall are inconsistent, suggesting that the model struggles to generalize beyond the training data.

---

### **Epochs 6-10**

| Epoch | Training Loss | Training Accuracy | Training Precision | Training Recall | Validation Loss | Validation Accuracy | Validation Precision | Validation Recall |
|-------|---------------|-------------------|--------------------|------------------|------------------|---------------------|----------------------|--------------------|
| 6     | 0.5123        | 0.3708            | 0.6705             | 0.6344           | 0.6924           | 0.2083              | 0.8128               | 0.5274             |
| 7     | 0.5073        | 0.3867            | 0.6706             | 0.6489           | 0.7137           | 0.2167              | 0.7282               | 0.6468             |
| 8     | 0.5002        | 0.3832            | 0.6522             | 0.6331           | 0.6368           | 0.3167              | 0.7719               | 0.7363             |
| 9     | 0.5065        | 0.3796            | 0.6862             | 0.6542           | 0.6742           | 0.2083              | 0.8099               | 0.5423             |
| 10    | 0.5021        | 0.3735            | 0.6955             | 0.6449           | 0.6415           | 0.1750              | 0.7418               | 0.5721             |

**Observations:**
- **Training Metrics:** Training loss continues to decrease slightly, and accuracy improves marginally. Precision shows a positive trend, while recall remains relatively stable.
- **Validation Metrics:** Validation loss remains high and varies without consistent improvement. Validation accuracy shows minimal changes, and precision and recall fluctuate, further indicating poor generalization.

---

### **Epochs 11-15**

| Epoch | Training Loss | Training Accuracy | Training Precision | Training Recall | Validation Loss | Validation Accuracy | Validation Precision | Validation Recall |
|-------|---------------|-------------------|--------------------|------------------|------------------|---------------------|----------------------|--------------------|
| 11    | 0.4836        | 0.3938            | 0.6969             | 0.6798           | 0.5906           | 0.3250              | 0.7765               | 0.6468             |
| 12    | 0.4977        | 0.3894            | 0.6772             | 0.6344           | 0.6221           | 0.3583              | 0.7552               | 0.7662             |
| 13    | 0.4807        | 0.3832            | 0.7060             | 0.6871           | 0.6123           | 0.2333              | 0.7707               | 0.5970             |
| 14    | 0.4787        | 0.4000            | 0.7021             | 0.6607           | 0.5710           | 0.3250              | 0.7798               | 0.6716             |
| 15    | 0.4754        | 0.4000            | 0.6945             | 0.6719           | 0.6450           | 0.3333              | 0.7538               | 0.7711             |

**Observations:**
- **Training Metrics:** There's a slight improvement in training accuracy and a decrease in training loss. Precision and recall remain relatively stable, showing that the model maintains consistent performance on training data.
- **Validation Metrics:** Validation loss shows minor fluctuations but doesn't exhibit a clear downward trend. Validation accuracy remains low to moderate (~23-35%), while precision and recall vary, indicating inconsistent performance.

---

### **Epochs 16-20**

| Epoch | Training Loss | Training Accuracy | Training Precision | Training Recall | Validation Loss | Validation Accuracy | Validation Precision | Validation Recall |
|-------|---------------|-------------------|--------------------|------------------|------------------|---------------------|----------------------|--------------------|
| 16    | 0.4769        | 0.3965            | 0.6785             | 0.6634           | 0.5431           | 0.3167              | 0.7291               | 0.7662             |
| 17    | 0.4700        | 0.3929            | 0.7046             | 0.6686           | 0.6213           | 0.4083              | 0.7466               | 0.8109             |
| 18    | 0.4721        | 0.4106            | 0.7113             | 0.6871           | 0.5710           | 0.2833              | 0.8019               | 0.6517             |
| 19    | 0.4682        | 0.4186            | 0.7109             | 0.6831           | 0.6158           | 0.2583              | 0.7711               | 0.6716             |
| 20    | 0.4658        | 0.4257            | 0.7173             | 0.6798           | 0.5662           | 0.2917              | 0.7749               | 0.7164             |

**Observations:**
- **Training Metrics:** Training loss decreases slightly, and accuracy shows a modest improvement. Precision and recall remain consistent, indicating stable performance on the training set.
- **Validation Metrics:** Validation loss experiences minor reductions but remains relatively high. Validation accuracy shows negligible improvement, and precision and recall vary without a clear upward trend. These patterns reinforce the presence of **overfitting**, where the model fails to generalize effectively to validation data despite improving performance on training data.

---

### **Summary of 3D ResNet Performance**

- **Training Performance:**
  - **Loss:** Decreases gradually from ~0.5448 to ~0.4658.
  - **Accuracy:** Increases from ~37% to ~43%.
  - **Precision & Recall:** Fluctuate between ~0.59 to ~0.71, showing inconsistent improvement.

- **Validation Performance:**
  - **Loss:** Remains high, fluctuating between ~0.54 to ~1.80.
  - **Accuracy:** Remains low to moderate, between ~17% to ~41%.
  - **Precision & Recall:** Vary significantly, indicating inconsistent model behavior on unseen data.

**Conclusion:** The **3D ResNet** model exhibits clear signs of **overfitting**. While there's some improvement in training metrics, the validation performance remains subpar and inconsistent, suggesting that the model struggles to generalize beyond the training dataset.

---

### **Next Steps: Transitioning to ResNet18**

Given the overfitting observed with the 3D ResNet model, we will now **transition to a 3D ResNet18** architecture. **ResNet18** is less complex and has fewer parameters compared to deeper ResNet variants, which can help mitigate overfitting and improve the model's ability to generalize to validation data.
