### Build TensorFlow Dataset

In [None]:
# !pip uninstall tensorflow -y
# !python3 -m venv .venv
# !source .venv/bin/activate
# !pip install tensorflow==2.15.0
# !pip install pandas pydicom scikit-learn seaborn nbformat

# The cuXXX (CUDA) wheels are Linux-only, NVIDIA GPU-only
# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

# Use the command from the PyTorch website for macOS!
# !pip install torch torchvision torchaudio

#### Set Parameters

In [None]:
# Set TensorFlow logging level to suppress warnings
# NOTE(review): this cell sits before the cell that imports tensorflow; presumably the
# variable has to be set before that import to take effect — confirm before reordering cells.
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
import torch
import ast
import pydicom
# from tensorflow.data import AUTOTUNE

# Report the TensorFlow build and the GPUs it can see
gpu_devices = tf.config.list_physical_devices('GPU')
print("TensorFlow version:", tf.__version__)
print("Num GPUs Available:", len(gpu_devices))
print("GPU Devices:", gpu_devices)

# Global configuration
# INPUT_SHAPE is (height, width, channels); TARGET_SIZE keeps only (height, width)
# and is the size every image and mask is resized to.
INPUT_SHAPE = (224, 224, 1)  # alternative: (512, 512, 1)
TARGET_SIZE = INPUT_SHAPE[:2]

In [None]:
# Report the PyTorch build and whether CUDA is usable from this environment
for caption, value in (
    ("PyTorch version:", torch.__version__),
    ("CUDA available:", torch.cuda.is_available()),
    ("CUDA device count:", torch.cuda.device_count()),
):
    print(caption, value)
# Never use the cuXXX index for macOS!
# Use the official PyPI source, or the command from the PyTorch website.
# print("Current device:", torch.cuda.current_device())
# print("Device name:", torch.cuda.get_device_name(torch.cuda.current_device()))

#### Helper Functions

In [None]:
# DICOM Loader
# Load and normalize a DICOM image from a byte string path
def load_dicom_image(path_tensor):
    """Read one DICOM file and return its pixels min-max scaled to [0, 1].

    Args:
        path_tensor: UTF-8 encoded ``bytes`` holding the file path (this is
            what ``tf.numpy_function`` passes for a string tensor).

    Returns:
        A float32 NumPy array with the file's pixel data scaled to [0, 1].
        If reading or decoding fails for any reason, the error is printed and
        an all-zero float32 array of shape ``TARGET_SIZE`` is returned instead.
    """
    dicom_path = path_tensor.decode('utf-8')  # bytes -> str
    try:
        pixels = pydicom.dcmread(dicom_path).pixel_array.astype(np.float32)
        # In-place ops keep the array float32, which the TF wrappers expect
        pixels -= pixels.min()
        pixels /= (pixels.max() + 1e-6)  # epsilon avoids 0/0 on constant images
        return pixels
    except Exception as err:
        # Best effort: a bad file becomes a blank image instead of aborting the pipeline
        print(f"[DICOM ERROR] {dicom_path}: {err}")
        return np.zeros(TARGET_SIZE, dtype=np.float32)

# TensorFlow Wrappers
# Load and preprocess a single full mammogram image
def tf_load_dicom(path):
    """Load one DICOM file inside a TensorFlow graph and resize it.

    Args:
        path: scalar string tensor with the file path.

    Returns:
        float32 tensor of shape ``(TARGET_SIZE[0], TARGET_SIZE[1], 1)``.
    """
    # numpy_function loses static shape information, so it is restored by hand
    raw = tf.numpy_function(func=load_dicom_image, inp=[path], Tout=tf.float32)
    raw.set_shape([None, None])  # 2-D: (H, W)
    with_channel = raw[..., tf.newaxis]  # (H, W, 1)
    with_channel.set_shape([None, None, 1])
    return tf.image.resize(with_channel, TARGET_SIZE)

def tf_load_multiple_dicom(paths):
    """Load several ROI mask files and merge them into a single mask.

    Args:
        paths: 1-D string tensor of shape ``[N]`` with the mask file paths.
            ``N`` may be 0 (``build_tf_dataset`` yields an empty list when a
            row has no parsable ``mask_paths``).

    Returns:
        float32 tensor of shape ``(TARGET_SIZE[0], TARGET_SIZE[1], 1)`` holding
        the pixel-wise maximum (union) of all masks. With ``N == 0`` the result
        is an all-zero mask.
    """
    mask_shape = (TARGET_SIZE[0], TARGET_SIZE[1], 1)

    # Each mask goes through the same load/resize path as the full image
    masks = tf.map_fn(
        tf_load_dicom,
        paths,
        fn_output_signature=tf.TensorSpec(shape=mask_shape, dtype=tf.float32)
    )

    # A max over zero elements would yield -inf everywhere. Masks are scaled to
    # [0, 1], so prepending an all-zero mask never changes a non-empty union
    # and turns the empty case into a blank mask.
    blank = tf.zeros((1,) + mask_shape, dtype=tf.float32)
    return tf.reduce_max(tf.concat([blank, masks], axis=0), axis=0)  # union of all masks

# Unified MTL Preprocessor
# Load the full image, merge its ROI masks into a single mask tensor, and package the multi-task targets
def load_and_preprocess(image_path, mask_paths, label):
    """Turn one sample's paths and label into model-ready tensors.

    Args:
        image_path: scalar string tensor, path of the full image.
        mask_paths: 1-D string tensor, paths of the ROI masks for that image.
        label: classification label; cast to float32.

    Returns:
        ``(image, targets)`` where ``image`` has shape ``TARGET_SIZE + (1,)``
        and ``targets`` is ``{"segmentation": mask, "classification": label}``
        with ``mask`` of the same shape as ``image``.
    """
    image = tf_load_dicom(image_path)
    targets = {
        "segmentation": tf_load_multiple_dicom(mask_paths),
        "classification": tf.cast(label, tf.float32),
    }
    return image, targets

# Parse a dictionary record into image + MTL target dict
def parse_record(record):
    """Unpack one metadata record and run it through ``load_and_preprocess``.

    Args:
        record: mapping with the keys ``'image_path'``, ``'mask_paths'`` and
            ``'label'``.

    Returns:
        The ``(image, target)`` pair produced by ``load_and_preprocess``.
    """
    return load_and_preprocess(
        record['image_path'],
        record['mask_paths'],
        record['label'],
    )

# Build tf.data.Dataset from metadata CSV
def build_tf_dataset(
    metadata_csv: str,
    batch_size: int = 8,
    shuffle: bool = True
) -> tf.data.Dataset:
    """Build a batched ``tf.data.Dataset`` from a metadata CSV.

    Args:
        metadata_csv: path to a CSV with the columns ``image_path``,
            ``mask_paths`` (a stringified Python list of paths) and ``label``.
        batch_size: number of samples per batch.
        shuffle: when True, reshuffle the records with a buffer that covers
            the whole dataset.

    Returns:
        A dataset yielding ``(image, {"segmentation": mask, "classification": label})``
        batches, prefetched for training.
    """
    # Load metadata CSV
    df = pd.read_csv(metadata_csv)

    # Parse the stringified list of mask paths; non-string cells (e.g. NaN) become []
    df['mask_paths'] = df['mask_paths'].apply(
        lambda x: ast.literal_eval(x) if isinstance(x, str) else []
    )

    # Ensure label column is float32-compatible (e.g., 0.0, 1.0)
    df['label'] = df['label'].astype(np.float32)

    # Convert to list of dicts
    records = df[['image_path', 'mask_paths', 'label']].to_dict(orient='records')

    # Create dataset of lightweight path/label records
    ds = tf.data.Dataset.from_generator(
        lambda: iter(records),
        output_signature={
            "image_path": tf.TensorSpec(shape=(), dtype=tf.string),
            "mask_paths": tf.TensorSpec(shape=(None,), dtype=tf.string),
            "label": tf.TensorSpec(shape=(), dtype=tf.float32),
        }
    )

    # Shuffle BEFORE decoding so the buffer holds path records rather than a
    # full dataset's worth of decoded images. Skipped for an empty CSV, because
    # a zero-sized shuffle buffer is invalid.
    if shuffle and records:
        ds = ds.shuffle(buffer_size=len(records))

    # Apply MTL-compatible mapping function.
    # tf.data.AUTOTUNE is referenced through `tf` so this function does not
    # depend on a bare `AUTOTUNE` name being imported by a later cell.
    ds = ds.map(parse_record, num_parallel_calls=tf.data.AUTOTUNE)

    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

The resulting `ds` is a complete TensorFlow dataset yielding `(image, {"segmentation": mask, "classification": label})` pairs — normalized, resized, shuffled, and batched — ready for model training and validation.

### Explore the Resulting Dataset

In [None]:
# Build the dataset from the full (unsplit) metadata file
ds = build_tf_dataset(
    metadata_csv="../data/processed/cbis_ddsm_metadata_full.csv",
    batch_size=8
)

# Peek at a single batch and report the shape of each component
for images, targets in ds.take(1):
    batch_shapes = {
        "Images": images.shape,                    # (8, 224, 224, 1)
        "Masks": targets['segmentation'].shape,    # (8, 224, 224, 1)
        "Labels": targets['classification'].shape  # (8,)
    }
    for caption, shape in batch_shapes.items():
        print(f"{caption} batch shape: {shape}")

In [None]:
import matplotlib.pyplot as plt

# Show the image, merged mask and label of the first few samples of one batch
for images, targets in ds.take(1):
    # Visualize at most 3 samples; a batch can hold fewer than that (small
    # dataset or batch_size), and indexing past its end would raise.
    num_examples = min(3, int(images.shape[0]))

    # Rows: image / mask / label; one column per sample.
    # squeeze=False keeps `axes` 2-D even when there is a single column.
    fig, axes = plt.subplots(
        3, num_examples, figsize=(num_examples * 3, 6), squeeze=False
    )

    for i in range(num_examples):
        image_ax, mask_ax, label_ax = axes[:, i]

        # Plot image
        image_ax.imshow(images[i, ..., 0], cmap="gray")
        image_ax.set_title(f"Image {i+1}")

        # Plot mask
        mask_ax.imshow(targets['segmentation'][i, ..., 0], cmap="gray")
        mask_ax.set_title(f"Mask {i+1}")

        # Display the label as text in an otherwise empty panel
        label = targets['classification'][i].numpy()
        label_ax.text(0.5, 0.5, str(label), fontsize=16, ha='center', va='center')
        label_ax.set_title(f"Label {i+1}")

        for panel in (image_ax, mask_ax, label_ax):
            panel.axis("off")

    fig.tight_layout()
    plt.show()

### Develop a Baseline Sequential CNN Classification Model

Das et al. (2023) provide a comprehensive overview of the architecture and training process for deep learning-based breast cancer classification, presenting a pipeline that can be adapted to a variety of datasets and tasks. Their framework delineates two primary CNN strategies: **Approach 1 (Shallow CNN)** and **Approach 2 (Deep CNN)**. In this work, we begin by implementing the shallow CNN approach as depicted in their proposed workflow.

As a first step, we **develop a baseline convolutional neural network (CNN) model consisting of an encoder and a classification head only**. This model will serve as a foundational benchmark, using a series of convolutional and pooling layers followed by fully connected layers for binary (or multiclass) classification, without any segmentation or auxiliary outputs. Establishing such a baseline is critical for objectively evaluating the impact of subsequent model enhancements.

The **“Shallow CNN” approach** in the diagram is a progressive three-part strategy aimed at incrementally increasing model robustness and generalization:
- Part 1: CNN with 2 Convolutional Layers
- Part 2: CNN with 2 Conv Layers + Dropout
- Part 3: CNN with 2 Conv Layers + Data Augmentation

For all three parts, the training pipeline begins with data collection (for example, from public datasets such as CBIS-DDSM and INbreast), followed by pre-processing steps including image resizing and partitioning into training and testing sets. The chosen shallow CNN model is then trained on the prepared data, with hyperparameters fine-tuned as needed to achieve the desired accuracy. Performance is evaluated using standard metrics, and the architecture can be iteratively refined based on results.

**In summary,** starting with the shallow CNN path and a simple encoder-classifier model enables rapid prototyping and establishes a robust baseline. It also provides valuable insights into the data and the classification task, which can inform the design and tuning of deeper or more complex models in subsequent experiments.

### Part 1: CNN with 2 Convolutional Layers
  This baseline model consists of two convolutional layers followed by pooling, a dense layer, and an output layer. It serves as the simplest deep learning architecture in the pipeline and acts as a benchmark for evaluating further enhancements.

#### Import Libraries

In [None]:
import pandas as pd
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.data import AUTOTUNE

import wandb

# Ensure the models directory exists
# "../models" is resolved against the current working directory;
# exist_ok=True makes re-running this cell harmless.
import os
model_dir = os.path.abspath("../models")
os.makedirs(model_dir, exist_ok=True)


#### Build and compile a shallow CNN model as a baseline

Compile the model with an appropriate loss function and optimizer, and train it on the dataset. The model will be evaluated on a validation set to monitor performance metrics such as accuracy and loss.

In [None]:
# Build a shallow CNN model as a baseline
def build_shallow_cnn(input_shape=INPUT_SHAPE, num_classes=1):
    """Assemble the baseline two-convolution CNN classifier (uncompiled).

    Architecture: Conv(32) -> MaxPool -> Conv(64) -> MaxPool -> Flatten ->
    Dense(64) -> Dense(num_classes, sigmoid).

    Args:
        input_shape: shape of one input image, ``(height, width, channels)``.
        num_classes: number of sigmoid output units.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    net = models.Sequential()

    # Feature extractor: two conv/pool stages
    net.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=input_shape))
    net.add(layers.MaxPooling2D((2, 2)))
    net.add(layers.Conv2D(64, (3, 3), activation='relu'))
    net.add(layers.MaxPooling2D((2, 2)))

    # Classification head
    net.add(layers.Flatten())
    net.add(layers.Dense(64, activation='relu'))
    net.add(layers.Dense(num_classes, activation='sigmoid'))
    return net

# Metrics tracked during training and validation; the explicit names become
# the keys of `history.history` ('auc', 'precision', 'recall', plus 'val_' variants)
baseline_metrics = [
    'accuracy',
    tf.keras.metrics.AUC(name='auc'),
    tf.keras.metrics.Precision(name='precision'),
    tf.keras.metrics.Recall(name='recall')
]

# Build the baseline network and compile it for binary classification
model = build_shallow_cnn(INPUT_SHAPE)
model.compile(
    optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=1e-4),
    loss='binary_crossentropy',
    metrics=baseline_metrics
)
model.summary()

#### Split and Build a TensorFlow Dataset

In [None]:
# Load metadata and make a stratified 80/20 train/validation split
metadata = pd.read_csv("../data/processed/cbis_ddsm_metadata_full.csv")
train_meta, val_meta = train_test_split(
    metadata, test_size=0.2, stratify=metadata['label'], random_state=42
)

# DataFrame.to_csv does not create missing parent directories, so make sure
# the scratch directory exists before writing the split files.
os.makedirs("../temporary", exist_ok=True)
train_meta.to_csv("../temporary/train_split.csv", index=False)
val_meta.to_csv("../temporary/val_split.csv", index=False)

# Build datasets
train_ds = build_tf_dataset(metadata_csv="../temporary/train_split.csv", batch_size=8)
# Validation order does not affect the metrics, so skip the shuffle buffer
val_ds = build_tf_dataset(metadata_csv="../temporary/val_split.csv", batch_size=8, shuffle=False)

# The baseline is classification-only: keep the label, drop the segmentation mask
train_ds = train_ds.map(lambda x, y: (x, y["classification"])).prefetch(AUTOTUNE)
val_ds = val_ds.map(lambda x, y: (x, y["classification"])).prefetch(AUTOTUNE)

#### Set up Weights & Biases for experiment tracking

- Install Weights & Biases (wandb) for experiment tracking and visualization. 
  - This tool will help us log metrics, visualize model performance, and manage experiments effectively.
```sh
pip install wandb
```
- Login to wandb: 
    - The first time we use wandb, we need to log in by pasting an API key.
```sh
wandb login
```

#### Train the Model

Train the model using the training dataset, and validate it using the validation dataset. Monitor the training and validation loss and accuracy to ensure the model is learning effectively without overfitting.

In [None]:
# Initialize Weights & Biases for experiment tracking
WandbMetricsLogger = wandb.keras.WandbMetricsLogger
WandbModelCheckpoint = wandb.keras.WandbModelCheckpoint

# Single source of truth for the run's hyperparameters: the same dict is
# logged to wandb and read back below, so the two cannot drift apart.
run_config = {
    "batch_size": 8,
    "epochs": 20,
    "optimizer": "Adam",
    "learning_rate": 1e-4,
    "architecture": "shallow_cnn"
}
wandb.init(project="baseline_part_1_cnn_2_conv_layers", config=run_config)

# Train model with callbacks
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=run_config["epochs"],
    callbacks=[
        # Streams per-epoch metrics to the wandb run
        WandbMetricsLogger(),
        # Save the best checkpoint into `model_dir` (the "../models" directory
        # created earlier in this notebook) rather than a "./models" folder
        # relative to the current working directory.
        WandbModelCheckpoint(
            filepath=os.path.join(model_dir, "best_model_epoch.keras"),
            monitor="val_loss",
            save_best_only=True
        ),
        # Per-epoch alternative: os.path.join(model_dir, "best_model_epoch{epoch:02d}.keras")
        # Stop after 5 epochs without val_loss improvement and roll back to the best weights
        EarlyStopping(patience=5, restore_best_weights=True, monitor="val_loss")
    ]
)

Epoch 5/20


#### Integrate wandb in Python Code

In [None]:
# import wandb

# # Initialize a new run
# wandb.init(project="your-project-name")

# # Example: log hyperparameters
# config = wandb.config
# config.learning_rate = 0.001
# config.batch_size = 32

# # During/after training, log metrics
# wandb.log({'accuracy': 0.85, 'loss': 0.3})

# # End the run (optional, wandb does this automatically on script exit)
# wandb.finish()

#### Save history to CSV/Pickle

In [None]:
# Save history to CSV
import pandas as pd
filename = "../results/history/baseline-part-1-cnn-2-conv-layers.csv"
# DataFrame.to_csv does not create missing parent directories
os.makedirs(os.path.dirname(filename), exist_ok=True)
# Convert history to DataFrame: one row per epoch, one column per tracked metric
df = pd.DataFrame(history.history)
# Save to CSV (the index column is the zero-based epoch number)
df.to_csv(filename, index=True)
# Load history from CSV
# loaded_history = pd.read_csv(filename, index_col=0)

In [None]:
import pickle
filename = "../results/history/baseline-part-1-cnn-2-conv-layers.pkl"
# open(..., 'wb') fails if the parent directory is missing, so create it first
os.makedirs(os.path.dirname(filename), exist_ok=True)
# Save history (the plain dict of per-epoch metric lists, not the History object)
with open(filename, 'wb') as f:
    pickle.dump(history.history, f)
# To load later
# with open(filename, 'rb') as f:
#     hist_dict = pickle.load(f)

#### Visualize Training History

In [None]:
import matplotlib.pyplot as plt

# Training vs. validation loss per epoch
for history_key, legend_label in (('loss', 'train_loss'), ('val_loss', 'val_loss')):
    plt.plot(history.history[history_key], label=legend_label)
plt.legend()
plt.show()

In [None]:
import matplotlib.pyplot as plt

# `history` is the object returned by model.fit() in the training cell above.

# One panel per metric, each showing the training and validation curves
metrics = ['accuracy', 'auc', 'precision', 'recall']

plt.figure(figsize=(12, 10))
for idx, metric in enumerate(metrics, 1):
    pretty_name = metric.capitalize()
    plt.subplot(2, 2, idx)
    # Keras stores validation series under the 'val_' prefix
    for key_prefix, legend_label in (('', 'Train'), ('val_', 'Validation')):
        plt.plot(history.history[key_prefix + metric], label=legend_label)
    plt.title(pretty_name)
    plt.xlabel('Epoch')
    plt.ylabel(pretty_name)
    plt.legend()

plt.tight_layout()
plt.show()

Loss serves as an indicator of model learning: while a lower loss reflects better model fit, a plateau or persistently high loss signals limited learning capacity. In this case, the shallow model can grasp basic patterns but quickly reaches its performance ceiling. To achieve further reductions in loss—and corresponding gains in accuracy and recall—a deeper architecture or improved data strategies are necessary.

To understand the behavior and limitations of the shallow CNN model, we visualize performance metrics—including **accuracy**, **AUC (Area Under the ROC Curve)**, **precision**, and **recall**—for both the training and validation datasets across all epochs. By tracking these metrics throughout training, we gain a comprehensive view of the model’s overall performance as well as its ability to correctly identify each class. This approach highlights not only how well the model distinguishes between positive and negative cases but also its strengths and weaknesses in real-world clinical contexts.

Interpretation of the Plots

- **Accuracy & AUC:** Both accuracy and AUC increased gradually over time, reflecting the model’s ability to learn basic distinctions between classes. However, validation AUC started near 0.59 and only reached about 0.71 by the end of training, while training AUC was noticeably higher at about 0.82. These values fall short of the 0.8–0.85+ threshold generally expected for clinically useful screening models. Similarly, validation accuracy plateaued around 0.65, which is better than random guessing (0.5 for binary classification) but still not adequate for high-stakes medical applications.

- **Precision:** Validation precision showed significant fluctuations, at times exceeding 0.7 but dropping lower in other epochs. This inconsistency points to the model’s variable confidence in positive predictions. In most epochs, precision was higher than recall, indicating the model preferred to make fewer positive predictions—acting cautiously to avoid false positives, but at the expense of missing true positives.

- **Recall:** Validation recall was especially unstable, ranging from very low values (as low as 0.07) up to just above 0.6 in certain epochs. This volatility suggests that the model sometimes effectively “gave up” on detecting positives, which could be due to class imbalance or other optimization challenges. Critically, persistently low recall means that the model misses a large proportion of true positive cases (e.g., actual cancers), which is unacceptable in the context of medical diagnosis.

- **Summary:**
  - The shallow CNN is capable of learning some meaningful features, but its performance **saturates quickly**, never advancing beyond moderate accuracy and AUC.
  - **Recall remains a major weakness:** even at its peak, it fails to reliably identify enough positive cases, severely limiting clinical applicability.
  - To address these shortcomings, it is recommended to **adopt a deeper, more expressive model architecture** and/or apply **more advanced data processing techniques** such as stronger augmentation, handling class imbalance, or experimenting with alternative loss functions.

While the shallow CNN provides a useful learning baseline, substantial improvements are necessary to reach clinically relevant performance, particularly in terms of recall and robust, consistent predictions.

### Part 2: CNN with 2 Conv Layers + Dropout
  To mitigate overfitting and improve the model’s ability to generalize, a Dropout layer is added after the dense layer. Dropout randomly deactivates a proportion of neurons during training, preventing the network from relying too heavily on specific features.

In [None]:
def build_shallow_cnn_dropout(input_shape=INPUT_SHAPE, num_classes=2):
    """Assemble the two-convolution CNN with dropout before the output layer (uncompiled).

    Architecture: Conv(32) -> MaxPool -> Conv(64) -> MaxPool -> Flatten ->
    Dense(64) -> Dropout(0.5) -> Dense(num_classes).

    Args:
        input_shape: shape of one input image, ``(height, width, channels)``.
        num_classes: number of output units. With ``num_classes == 1`` the
            output layer uses a sigmoid (single-probability binary head);
            otherwise it uses softmax over the classes.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    # Softmax over a single unit is identically 1.0, so a one-unit head would
    # be a constant predictor; use sigmoid in that case.
    output_activation = 'sigmoid' if num_classes == 1 else 'softmax'

    model = models.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dropout(0.5),  # active only during training
        layers.Dense(num_classes, activation=output_activation)
    ])
    return model

### Part 3: CNN with 2 Conv Layers + Data Augmentation
  Building on Part 2, data augmentation techniques are introduced during the training phase. Methods such as random rotations, translations, zooms, and horizontal flips are applied to the input images, artificially increasing the diversity of the training set and further reducing the risk of overfitting.

We apply data augmentation to our data pipeline to enhance the model's robustness and generalization capabilities. This involves applying transformations such as random rotations, shifts, zooms, and flips to the training images, which helps the model learn invariant features and improves its performance on unseen data.

In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Example of setting up data augmentation for training data
# The generators below are only constructed in this cell; nothing here feeds data through them.
# NOTE(review): load_dicom_image already scales pixels to [0, 1]; rescale=1./255 would
# shrink them again if these generators were fed that output — confirm the intended input range.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True,
    fill_mode='nearest'
)

# For validation/test, only rescale
val_datagen = ImageDataGenerator(rescale=1./255)

# Example:
# train_generator = train_datagen.flow_from_directory(
#     'data/train',
#     target_size=(224, 224),
#     batch_size=32,
#     class_mode='categorical',  # or 'binary'
#     color_mode='grayscale',    # if images are grayscale
# )

# Use the same model as Part 2 (with dropout)
# This rebinds the global `model` to a fresh, uncompiled network.
# NOTE(review): num_classes=2 gives a two-unit softmax head, while the training cells in this
# notebook use scalar labels with binary_crossentropy — confirm loss/label format before fitting.
model = build_shallow_cnn_dropout(input_shape=INPUT_SHAPE, num_classes=2)


### Original Code

In [None]:
import pandas as pd
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

# Use new Keras format for saving
# Checkpoint keeps only the best model by val_loss; early stopping waits 5 epochs
# without val_loss improvement and then restores the best weights.
checkpoint_cb = ModelCheckpoint("../models/best_model.keras", save_best_only=True, monitor="val_loss")
earlystop_cb = EarlyStopping(patience=5, restore_best_weights=True, monitor="val_loss")

# Stratified 80/20 split with a fixed seed
metadata = pd.read_csv("../data/processed/cbis_ddsm_metadata_full.csv")
train_meta, val_meta = train_test_split(
    metadata, test_size=0.2, stratify=metadata['label'], random_state=42
)

# Save to new CSV files
# NOTE(review): presumably "../temporary" must already exist for these writes to succeed — confirm.
train_meta.to_csv("../temporary/train_split.csv", index=False)
val_meta.to_csv("../temporary/val_split.csv", index=False)

# Build the datasets from the split files with build_tf_dataset
train_ds = build_tf_dataset(metadata_csv="../temporary/train_split.csv", batch_size=8)
val_ds = build_tf_dataset(metadata_csv="../temporary/val_split.csv", batch_size=8)

# Keep only the classification label
train_ds = train_ds.map(lambda x, y: (x, y["classification"]))
val_ds = val_ds.map(lambda x, y: (x, y["classification"]))

train_ds = train_ds.prefetch(AUTOTUNE)
val_ds = val_ds.prefetch(AUTOTUNE)

# Build the model
# NOTE(review): build_classification_model is not defined in the visible part of this
# notebook — confirm where it comes from. This also rebinds the global `model`.
model = build_classification_model(INPUT_SHAPE)

# Compile the model for binary classification
model.compile(
    optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=1e-4), 
    loss="binary_crossentropy", 
    metrics=[
        'accuracy',
        tf.keras.metrics.AUC(name='auc'),
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall')
    ]
)

# Proceed with model.fit
# Rebinds the global `history`, replacing the one from any earlier training cell.
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=20,
    callbacks=[checkpoint_cb, earlystop_cb]
)

### Evaluating and Plotting Metrics

In [None]:
import matplotlib.pyplot as plt

def plot_training_history(history):
    """Plot every recorded training/validation metric on a single axis.

    Args:
        history: object with a ``history`` dict mapping metric names to
            per-epoch value lists (as returned by ``model.fit``).

    Only the keys actually present are drawn; missing ones are skipped.
    """
    recorded = history.history
    for base_name in ('accuracy', 'auc', 'precision', 'recall'):
        # Training series first, then its 'val_' counterpart
        for series_key in (base_name, 'val_' + base_name):
            if series_key in recorded:
                plt.plot(recorded[series_key], label=series_key)
    plt.xlabel('Epochs')
    plt.ylabel('Metric')
    plt.legend()
    plt.title('Training & Validation Metrics')
    plt.show()

plot_training_history(history)

In [None]:
# Evaluate on the validation set and pair each value with its metric name
results = model.evaluate(val_ds)
print({metric_name: value for metric_name, value in zip(model.metrics_names, results)})

In [None]:
import numpy as np

# Collect all true labels and predictions
y_true = []
y_pred = []

for x_batch, y_batch in val_ds:
    # predict_on_batch runs one forward pass on an already-formed batch;
    # model.predict would re-wrap every batch in its own prediction loop
    # and print a progress bar per batch.
    probs = model.predict_on_batch(x_batch)
    # For binary, threshold at 0.5 (assumes a single sigmoid output per sample)
    preds = (probs.flatten() > 0.5).astype(int)
    y_true.extend(y_batch.numpy().astype(int))
    y_pred.extend(preds)

y_true = np.array(y_true)
y_pred = np.array(y_pred)


In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# Confusion matrix as an annotated seaborn heatmap (rows: true, columns: predicted)
cm = confusion_matrix(y_true, y_pred)
labels = ["BENIGN", "MALIGNANT"]

fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=labels,
    yticklabels=labels,
    ax=ax,
)
ax.set_xlabel("Predicted Label")
ax.set_ylabel("True Label")
ax.set_title("Confusion Matrix")
fig.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt

# Same confusion matrix drawn with matplotlib only (no seaborn dependency)
cm = confusion_matrix(y_true, y_pred)
labels = ["BENIGN", "MALIGNANT"]

plt.figure(figsize=(6, 5))
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix')
plt.colorbar()
tick_marks = np.arange(len(labels))
plt.xticks(tick_marks, labels)
plt.yticks(tick_marks, labels)
plt.xlabel("Predicted Label")
plt.ylabel("True Label")

# Annotate values: white text on dark cells, black on light ones
thresh = cm.max() / 2
for i, j in np.ndindex(*cm.shape):
    text_color = "white" if cm[i, j] > thresh else "black"
    plt.text(j, i, cm[i, j], horizontalalignment="center", color=text_color)
plt.tight_layout()
plt.show()
