In [None]:
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score, log_loss
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd

# Define the Flexible 3D CNN model class
class Flexible3DCNN(nn.Module):
    """Configurable stack of 3D convolutions with a linear multi-label head.

    Every stage is ``Conv3d(kernel_size=3, padding=1) -> activation -> pooling``.
    Pooling (factor 2) is skipped as soon as any spatial dimension of the
    feature map is smaller than 2. The flattened output of the last stage is
    mapped by a single ``nn.Linear`` to 3 logits (one per binary label).

    Args:
        conv_layers: Number of convolutional stages to build.
        filters: Output channels of each stage. Must contain at least
            ``conv_layers`` entries; surplus entries are ignored.
        pooling: ``"avg"`` selects average pooling; any other value selects
            max pooling.
        activation: ``"sigmoid"`` selects sigmoid; any other value selects ReLU.
        optimizer_type: Stored as ``self.optimizer_type`` so training code can
            pick the optimizer from the model; the model itself never reads it.
        in_channels: Number of channels of the input volume.
        input_shape: ``(depth, height, width)`` of the input volume. The linear
            head is sized for exactly this shape, so inputs passed to
            ``forward`` must match it.

    Raises:
        ValueError: If ``filters`` has fewer than ``conv_layers`` entries.
    """

    def __init__(self, conv_layers=3, filters=(8, 16, 32), pooling="avg", activation="relu",
                 optimizer_type="adam", in_channels=3, input_shape=(48, 256, 256)):
        super().__init__()

        if len(filters) < conv_layers:
            raise ValueError(
                f"filters must provide at least {conv_layers} entries, got {len(filters)}."
            )

        self.layers = nn.ModuleList()
        self.pooling_type = pooling
        self.activation_type = activation
        self.optimizer_type = optimizer_type  # Store optimizer type as part of the model

        # Chain the convolutions: each stage consumes the previous stage's channels.
        channels = in_channels
        for i in range(conv_layers):
            self.layers.append(nn.Conv3d(channels, filters[i], kernel_size=3, padding=1))
            channels = filters[i]

        # Probe the conv stack with a dummy volume to size the linear head.
        sample_input = torch.zeros(1, in_channels, *input_shape)
        flattened_size = self.determine_flattened_size(sample_input)

        # Fully connected classifier for 3 binary outputs (one per class)
        self.fc = nn.Linear(flattened_size, 3)

    def _extract_features(self, x):
        """Run ``x`` through every conv stage (conv -> activation -> optional pooling)."""
        for layer in self.layers:
            x = self.apply_activation(layer(x))
            if min(x.shape[2:]) >= 2:  # Pool only while every spatial dimension is >= 2
                x = self.apply_pooling(x)
        return x

    def determine_flattened_size(self, x):
        """Return the total number of elements the conv stack produces for ``x``.

        Runs without gradient tracking because the result is only used to size
        the linear layer.
        """
        with torch.no_grad():
            return self._extract_features(x).numel()

    def apply_pooling(self, x):
        """Halve the spatial dimensions with average or max pooling."""
        return F.avg_pool3d(x, 2) if self.pooling_type == "avg" else F.max_pool3d(x, 2)

    def apply_activation(self, x):
        """Apply sigmoid when configured, ReLU otherwise."""
        return torch.sigmoid(x) if self.activation_type == "sigmoid" else F.relu(x)

    def forward(self, x):
        """Return one logit per class, shape ``(batch, 3)``."""
        x = self._extract_features(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)  # Output logits for each class

# Define the loss function with BCEWithLogitsLoss
def calculate_loss(preds, targets):
    """Return the mean binary cross-entropy between logits and 0/1 targets.

    ``preds`` are raw logits (the sigmoid is applied inside the loss);
    ``targets`` are cast to float before the comparison.
    """
    float_targets = targets.float()
    return F.binary_cross_entropy_with_logits(preds, float_targets)

def compute_metrics(y_true, y_pred, y_proba):
    """Print and return per-label classification metrics.

    Column ``i`` of each input is treated as the binary problem for the
    ``i``-th label in ``("abnormal", "acl", "meniscus")``.

    Args:
        y_true: Ground-truth 0/1 labels, one row per sample.
        y_pred: Thresholded 0/1 predictions, same layout as ``y_true``.
        y_proba: Predicted probabilities, same layout as ``y_true``.

    Returns:
        Dict mapping each label name to a dict with keys ``accuracy``,
        ``precision``, ``recall``, ``f1`` and ``log_loss``. ``log_loss`` is
        ``None`` when scikit-learn raises ``ValueError`` for that column.
    """
    truth = np.array(y_true)
    predicted = np.array(y_pred)
    probabilities = np.array(y_proba)

    per_label = {}
    for column, name in enumerate(("abnormal", "acl", "meniscus")):
        true_col = truth[:, column]
        pred_col = predicted[:, column]

        matrix = confusion_matrix(true_col, pred_col, labels=[0, 1])
        scores = {
            'accuracy': accuracy_score(true_col, pred_col),
            'precision': precision_score(true_col, pred_col, zero_division=0),
            'recall': recall_score(true_col, pred_col, zero_division=0),
            'f1': f1_score(true_col, pred_col, zero_division=0),
        }
        try:
            scores['log_loss'] = log_loss(true_col, probabilities[:, column], labels=[0, 1])
        except ValueError:
            # Log loss could not be computed for this column; report it as missing.
            scores['log_loss'] = None

        # For display
        print(f"\nMetrics for {name}:")
        print(f"Confusion Matrix:\n{matrix}")
        print(f"Accuracy: {scores['accuracy']}")
        print(f"Precision: {scores['precision']}")
        print(f"Recall: {scores['recall']}")
        print(f"F1-Score: {scores['f1']}")
        print(f"Log Loss: {scores['log_loss']}")

        per_label[name] = scores
    return per_label


def flatten_metrics(metrics, prefix):
    """Collapse ``{label: {metric: value}}`` into a single-level dict.

    Keys become ``"{prefix}_{label}_{metric}"``. ``np.float64`` values are
    converted to built-in ``float``; every other value is passed through
    unchanged.
    """
    flat = {}
    for class_name, class_scores in metrics.items():
        for key, raw in class_scores.items():
            if isinstance(raw, np.float64):
                raw = float(raw)
            flat[f"{prefix}_{class_name}_{key}"] = raw
    return flat


def train_and_evaluate(model, train_loader, valid_loader, epochs=5, learning_rate=0.001,
                       return_history=False):
    """Train ``model`` and score it on the validation set after every epoch.

    The optimizer is chosen from ``model.optimizer_type`` (``"adam"``, or
    ``"sgd"`` with momentum 0.9; case-insensitive). In each epoch the logits
    are turned into probabilities with a sigmoid and thresholded at 0.5, and
    ``compute_metrics`` is run for both the training and validation phase.
    After the last epoch a per-epoch summary table is shown.

    Args:
        model: Module exposing ``optimizer_type``; called as ``model(inputs)``
            and expected to return logits accepted by ``calculate_loss``.
        train_loader: Sized iterable of ``(inputs, labels)`` training batches.
        valid_loader: Sized iterable of ``(inputs, labels)`` validation batches.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Learning rate handed to the optimizer.
        return_history: When true, additionally return the per-epoch summary
            DataFrame and the per-class detailed DataFrame.

    Returns:
        ``(y_true_valid, y_proba_valid)`` from the last epoch as NumPy arrays
        (both ``None`` when ``epochs`` is 0). With ``return_history=True`` the
        tuple is extended to
        ``(y_true_valid, y_proba_valid, summary_df, detailed_df)``.

    Raises:
        ValueError: If ``model.optimizer_type`` is neither ``"adam"`` nor ``"sgd"``.
    """
    optimizer_name = model.optimizer_type.lower()
    if optimizer_name == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    elif optimizer_name == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
    else:
        raise ValueError("Invalid optimizer type. Choose 'adam' or 'sgd'.")

    results = []  # One summary row per epoch
    detailed_metrics = []  # One row per (epoch, phase, class)
    final_y_true_valid, final_y_pred_proba_valid = None, None  # Last epoch's validation outputs

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        y_true_train, y_pred_proba_train = [], []

        # Training loop
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = calculate_loss(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            # Compute the sigmoid once; .cpu() is a no-op for CPU tensors and
            # keeps the NumPy conversion valid for tensors on an accelerator.
            y_true_train.extend(labels.detach().cpu().numpy())
            y_pred_proba_train.extend(torch.sigmoid(outputs).detach().cpu().numpy())

        # Threshold exactly like the validation phase so both phases feed
        # integer predictions to the metrics.
        y_pred_proba_train = np.array(y_pred_proba_train)
        y_pred_train = (y_pred_proba_train > 0.5).astype(int)
        train_metrics = compute_metrics(np.array(y_true_train), y_pred_train, y_pred_proba_train)

        model.eval()
        valid_loss = 0.0
        y_true_valid, y_pred_proba_valid = [], []

        # Validation loop (no gradient tracking, no weight updates)
        with torch.no_grad():
            for inputs, labels in valid_loader:
                outputs = model(inputs)
                loss = calculate_loss(outputs, labels)
                valid_loss += loss.item()

                y_true_valid.extend(labels.detach().cpu().numpy())
                y_pred_proba_valid.extend(torch.sigmoid(outputs).detach().cpu().numpy())

        y_pred_proba_valid = np.array(y_pred_proba_valid)
        y_pred_valid = (y_pred_proba_valid > 0.5).astype(int)
        valid_metrics = compute_metrics(np.array(y_true_valid), y_pred_valid, y_pred_proba_valid)

        # Remember the most recent validation outputs; after the loop these
        # are the final epoch's values.
        final_y_true_valid = np.array(y_true_valid)
        final_y_pred_proba_valid = y_pred_proba_valid

        # Per-class rows for `detailed_df`
        for phase, metrics in zip(["Train", "Validation"], [train_metrics, valid_metrics]):
            for label, label_metrics in metrics.items():
                detailed_metrics.append({
                    "Epoch": epoch + 1,
                    "Phase": phase,
                    "Class": label,
                    "Accuracy": label_metrics["accuracy"],
                    "Precision": label_metrics["precision"],
                    "Recall": label_metrics["recall"],
                    "F1-Score": label_metrics["f1"],
                    "Log Loss": label_metrics["log_loss"]
                })

        # One wide row per epoch for `summary_df`; losses are averaged per batch.
        row_data = {
            "Epoch": epoch + 1,
            "Train Loss": train_loss / len(train_loader),
            "Valid Loss": valid_loss / len(valid_loader),
        }
        row_data.update(flatten_metrics(train_metrics, "Train"))
        row_data.update(flatten_metrics(valid_metrics, "Valid"))
        results.append(row_data)

    # Display summarized DataFrame (`summary_df`) with per-epoch metrics
    summary_df = pd.DataFrame(results)
    pd.set_option('display.max_rows', None)
    try:
        display(summary_df.style.set_table_styles([{'selector': 'th', 'props': [('font-weight', 'bold')]}]))
    except NameError:
        # `display` only exists inside IPython/Jupyter; fall back to plain text
        # so a finished training run is not lost to a NameError in a script.
        print(summary_df)

    # Per-class metrics in long format, suitable for further analysis
    detailed_df = pd.DataFrame(detailed_metrics)
    if return_history:
        return final_y_true_valid, final_y_pred_proba_valid, summary_df, detailed_df
    return final_y_true_valid, final_y_pred_proba_valid

# Convert labels from dictionaries to multi-label binary tensors
def convert_labels_to_tensor(labels):
    """Encode a sequence of label dicts as a ``(len(labels), 3)`` float tensor.

    Columns are, in order, the values stored under the keys ``'abnormal'``,
    ``'acl'`` and ``'meniscus'`` of each dict.
    """
    class_keys = ('abnormal', 'acl', 'meniscus')
    encoded = torch.zeros(len(labels), len(class_keys))
    for row, entry in enumerate(labels):
        for col, key in enumerate(class_keys):
            encoded[row, col] = entry[key]
    return encoded


Here’s a detailed explanation of how the code works when the model is configured with the following parameters:

- **`conv_layers=3`**: Three convolutional layers.
- **`filters=[8, 16, 32]`**: Number of filters in each convolutional layer.
- **`pooling="avg"`**: Average pooling is applied after convolution.
- **`activation="relu"`**: ReLU activation is used after convolution.
- **`optimizer_type="adam"`**: Adam optimizer is used for training.

We’ll walk through the process from the perspective of a 3D image being inputted into the model and moving through its components.

---

### **1. Input (3D Image)**
- **Input Shape**: A single input MRI scan is represented as a tensor of shape `(Batch Size, Channels, Depth, Height, Width)`.
  - Example: For one image in the batch: `(1, 3, 48, 256, 256)`
    - **Batch Size = 1**: One image in the batch.
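    - **Channels = 3**: Three input channels, as fixed by `in_channels = 3` in the model. MRI slices are greyscale, so these are presumably three stacked views/sequences or a replicated greyscale channel rather than true RGB.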
    - **Depth = 48**: Number of slices along the depth (z-axis).
    - **Height = 256**: Vertical size of the image.
    - **Width = 256**: Horizontal size of the image.

The input represents a 3D volume with three channels.

---

### **2. First Convolutional Layer (Conv3D)**
- **Operation**: A 3D convolution is applied.
  - Input Shape: `(1, 3, 48, 256, 256)` (3 channels from the input).
  - Filters: The first layer uses `8 filters`, so the output will have 8 channels.
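  - Kernel Size: `(3, 3, 3)` (set explicitly with `kernel_size=3`; `nn.Conv3d` has no default kernel size).
  - Padding: `padding=1`, which together with the 3×3×3 kernel and the default stride of 1 preserves the spatial dimensions.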

- **Output Shape**: After convolution:
  - `(1, 8, 48, 256, 256)` (8 channels).

- **Activation**: ReLU activation is applied to introduce non-linearity:
  - Any negative values in the output are set to 0.

- **Pooling**: Average pooling reduces the spatial dimensions by a factor of 2:
  - `(1, 8, 24, 128, 128)`.

---

### **3. Second Convolutional Layer (Conv3D)**
- **Operation**: A second 3D convolution is applied to the output of the first layer.
  - Input Shape: `(1, 8, 24, 128, 128)` (8 channels from the first layer).
  - Filters: The second layer uses `16 filters`.
  - Kernel Size: `(3, 3, 3)`.
  - Padding: Preserves spatial dimensions.

- **Output Shape**: After convolution:
  - `(1, 16, 24, 128, 128)`.

- **Activation**: ReLU activation is applied again.

- **Pooling**: Average pooling reduces spatial dimensions by a factor of 2:
  - `(1, 16, 12, 64, 64)`.

---

### **4. Third Convolutional Layer (Conv3D)**
- **Operation**: A third 3D convolution is applied.
  - Input Shape: `(1, 16, 12, 64, 64)` (16 channels from the second layer).
  - Filters: The third layer uses `32 filters`.
  - Kernel Size: `(3, 3, 3)`.
  - Padding: Preserves spatial dimensions.

- **Output Shape**: After convolution:
  - `(1, 32, 12, 64, 64)`.

- **Activation**: ReLU activation is applied.

- **Pooling**: Average pooling reduces spatial dimensions by a factor of 2:
  - `(1, 32, 6, 32, 32)`.

---

### **5. Flattening**
- **Operation**: The output from the final convolutional layer is flattened.
  - Input Shape: `(1, 32, 6, 32, 32)` (32 channels, depth = 6, height = 32, width = 32).
  - Flattened Size: Computed as \( 32 \times 6 \times 32 \times 32 = 196{,}608 \).

- **Output Shape**: The flattened tensor becomes a 2D tensor of shape `(1, 196608)`, i.e. 196,608 features for the single sample in the batch.

---

### **6. Fully Connected Layer**
- **Operation**: The flattened vector is passed through a fully connected (dense) layer for classification.
  - Input Shape: `(1, 196608)`.
  - Output Shape: `(1, 3)` (one logit for each class: `abnormal`, `acl`, `meniscus`).

---

### **7. Output**
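- **Output Logits**: The output is a raw, unbounded score (logit) for each of the 3 classes. Each logit is the log-odds of its class; applying the sigmoid to each logit independently gives that class's probability, so the three probabilities do not need to sum to 1.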

---

### **8. Loss Computation**
- **Loss Function**: Binary Cross-Entropy with Logits Loss (`BCEWithLogitsLoss`) is applied.
  - Converts the logits to probabilities using the sigmoid function:
    - \( \text{Sigmoid}(x) = \frac{1}{1 + e^{-x}} \).
  - Compares the predicted probabilities to the ground truth multi-label targets.
  - Example Target: `[1, 0, 1]` (indicating abnormalities and meniscus tear).

---

### **9. Optimizer (Adam)**
- **Optimizer**: After backpropagation has computed the gradients of the loss, the Adam optimizer uses them to update the model weights (`optimizer.step()`).

---

### **10. Training Process**
- **Steps**:
  1. Input a batch of 3D images and labels.
  2. Compute predictions by passing the input through the model.
  3. Calculate loss between predictions and true labels.
  4. Backpropagate to compute gradients.
  5. Update model weights using the Adam optimizer.
  6. Repeat for all batches in the training set.

---

### **11. Validation Process**
- Similar to training but without updating weights (model evaluation mode).
- Computes metrics like accuracy, precision, recall, F1 score, and log loss for each class.

---

### **Summary of the Workflow**
1. **Input**: 3D MRI image of shape `(Batch Size, 3, Depth, Height, Width)`.
2. **Convolution Layers**: Extract hierarchical features with increasing filter depth (8 → 16 → 32).
3. **Pooling**: Reduces spatial dimensions while retaining important features.
4. **Flattening**: Converts the 3D feature map into a 1D vector.
5. **Fully Connected Layer**: Maps the flattened features to 3 binary outputs (classes).
6. **Output**: Raw logits for each class, converted to probabilities for multi-label classification.
7. **Training**: Minimizes the Binary Cross-Entropy loss using the Adam optimizer.

This setup processes 3D medical images and predicts three independent binary labels (`abnormal`, `acl`, `meniscus`) — a scan can be positive for several of them at once — while allowing flexible configuration of layers and training parameters.
