# 🧪 CNN Classifier for Waveform Image Data

In this notebook, you will build and fine-tune a **Convolutional Neural Network (CNN)** to classify medical waveform images — such as plethysmography or ECG lead II traces.

We use a **pretrained ResNet18** model and customize it to learn from a small dataset in which each class is a clinical group, e.g.:
- `as`      (Traces from patients with Aortic Stenosis)  
- `control` (Traces from control patients)  

This project introduces key machine learning concepts:
- Preparing image data for training
- Using file naming conventions for auto-labeling
- Creating a robust train/val/test folder structure
- Fine-tuning a CNN using transfer learning
- Tracking model performance with accuracy, loss, and confusion matrices

> 🎯 **Goal**: Understand how a CNN can learn from medical image patterns — even in small datasets — and evaluate model performance with appropriate metrics.

You can experiment by:
- Adding your own image classes
- Modifying batch size, learning rate, or architecture
- Inspecting misclassified examples

Record any changes you make and the effect on the accuracy of your model. 

## 🧠 A note on CNNs

A **Convolutional Neural Network** is a type of neural network designed to automatically learn spatial patterns in images. It applies filters (like edge detectors or curve recognizers) at multiple layers and learns hierarchical features — from pixels to shapes to complex patterns.
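
To make the idea of a filter concrete, here is a minimal sketch in plain Python. The toy image, the hand-written `vertical_edge` kernel, and the helper `cross_correlate2d` are all illustrative assumptions; a real CNN learns its filter weights from data and uses optimized library routines rather than nested loops.

```python
def cross_correlate2d(image, kernel):
    """Slide `kernel` over `image` (no padding, stride 1) and sum the products."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    output = []
    for r in range(out_h):
        row = []
        for c in range(out_w):
            total = 0
            for i in range(kh):
                for j in range(kw):
                    total += image[r + i][c + j] * kernel[i][j]
            row.append(total)
        output.append(row)
    return output


# Toy 4x4 image: dark (0) on the left, bright (1) on the right.
toy_image = [
    [0, 0, 1, 1],
    [0, 0, 1, 1],
    [0, 0, 1, 1],
    [0, 0, 1, 1],
]

# Hand-written vertical-edge filter (hypothetical; a CNN would learn such weights).
vertical_edge = [
    [-1, 1],
    [-1, 1],
]

feature_map = cross_correlate2d(toy_image, vertical_edge)
for row in feature_map:
    print(row)
```

The resulting feature map is non-zero only where the filter straddles the dark-to-bright boundary, which is the sense in which a filter "detects" an edge.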

## 🏗️ Our Model: ResNet18

We use a **pretrained ResNet18 model**, which was originally trained on **ImageNet** — a massive dataset of **over 1 million images** across **1,000 classes** (e.g., cats, planes, trees, tools, etc.).

ResNet18 is built from **residual blocks**, whose skip connections add each block's input back to its output. This helps the network train efficiently even though it is 18 layers deep. ResNet architectures are widely used as baselines in medical imaging research.
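
As a rough illustration of a residual block, the sketch below adds a block's input back to the output of two convolutions. `TinyResidualBlock` is a simplified, hypothetical class written for this note; it is not the exact block used inside torchvision's ResNet18, which also handles changes in channel count and stride.

```python
import torch
from torch import nn


class TinyResidualBlock(nn.Module):
    """Illustrative block: output = relu(F(x) + x), where F is two 3x3 convolutions."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        identity = x                                # keep the input for the skip connection
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + identity)            # add the input back before activating


block = TinyResidualBlock(channels=8)
dummy = torch.randn(2, 8, 16, 16)   # (batch, channels, height, width), random stand-in data
result = block(dummy)
print(result.shape)
```

Because the skip connection is a plain addition, the output must have the same shape as the input, which the printed shape confirms.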

## 🛠️ What We're Doing Here

We keep the powerful pretrained layers — which already know how to detect general visual features (edges, curves, textures) — and **only modify the final classification layer** to output **2 classes** instead of 1,000.

This approach is called **transfer learning**, specifically **fine-tuning**:
- We first freeze the pretrained layers and only train the final layer on our waveform data.
- Then we "unfreeze" the model and continue training the whole network gently to adjust to our task.

We are turning a general-purpose image recognizer into a specialized waveform classifier, using only limited clinical data.
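
The two phases can be sketched on a tiny stand-in network, so that nothing needs to be downloaded. The `backbone`, `head`, and `count_trainable` names are hypothetical and exist only for this illustration; the notebook applies the same `requires_grad` pattern to ResNet18 further down.

```python
from torch import nn


def count_trainable(model):
    """Number of parameters the optimizer would be allowed to update."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Hypothetical stand-in for a pretrained network: a backbone plus a 1,000-class head.
backbone = nn.Sequential(nn.Linear(16, 8), nn.ReLU())
head = nn.Linear(8, 1000)
toy_model = nn.Sequential(backbone, head)

# Phase 1: freeze every existing layer, then swap in a fresh 2-class head.
for param in toy_model.parameters():
    param.requires_grad = False
toy_model[1] = nn.Linear(8, 2)      # newly created layers are trainable by default
frozen_phase = count_trainable(toy_model)

# Phase 2: unfreeze everything so the whole network can adapt.
for param in toy_model.parameters():
    param.requires_grad = True
finetune_phase = count_trainable(toy_model)

print(f"Trainable parameters, frozen phase:    {frozen_phase}")
print(f"Trainable parameters, fine-tune phase: {finetune_phase}")
```

In the frozen phase only the new head's weights and biases are trainable; after unfreezing, the backbone's parameters are counted as well.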

# 🧭 Getting Started: Imports and Path Setup

This section loads all the Python libraries you'll need — including `torch`, `torchvision`, and `sklearn`.

We also define the base folders for your project. Make sure:
- Your waveform images are placed inside the `input/` folder.
- Each image filename begins with its class, e.g. `as_image001.jpg`.  
- The `output/` folder is where your preprocessed image files will be saved in training, validation, and test subfolders.
- Each subfolder will have further subfolders automatically named for each class.  
- The `output/` folder is also where the trained model weights will be saved.  

⚠️ Make sure `base_path` points to the root of your project on your machine.

In [None]:
import os
import shutil
from pathlib import Path

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torchvision.models as models
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report

base_path = Path.cwd()
input_folder = base_path / "input"
output_folder = base_path / "output"



# 🗂️ Understanding Your Data 

We scan the `input/` folder and store the path of every image file in the list `image_files`.  
We then print the first ten filenames as a quick check.

In [None]:
image_extensions = [".jpg", ".jpeg", ".png"]

image_files = []
for f in input_folder.glob("*"):
    try:
        if f.suffix.lower() in image_extensions and f.is_file():
            image_files.append(f)
    except Exception as e:
        print(f"⚠️ Skipped {f.name}: {e}")

print(f"Found {len(image_files)} images:")
for file in image_files[:10]:  # Only print first 10
    print(f"File: {file.name}")


# 🧩 Map Class Names & Integer Labels

In this step, we list the class names we'll be working with (e.g. `as`, `control`) and automatically assign each one a unique integer label starting from 0.
We create a dictionary called `class_dict` that maps each class name to its integer label.  
This mapping is important because neural networks require numerical labels, not text, for classification.

💡 For example:
- `"as"` → `0`
- `"control"` → `1`

These will be used during model training and when building your folder structure.  
You will also refer to this map at inference - when you run your trained model - to decode the model output back into clinically recognisable classes. 
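
As a minimal sketch of how the map is used in both directions, the snippet below builds the name-to-index dictionary and its inverse. The `predicted_indices` list is made up for illustration and `index_to_class` is a name introduced here; neither comes from a trained model.

```python
class_names = ["as", "control"]

# Name -> integer label, assigned in list order.
class_dict = {name: index for index, name in enumerate(class_names)}

# Integer label -> name, used to decode predictions back into class names.
index_to_class = {index: name for name, index in class_dict.items()}

predicted_indices = [1, 0, 0]  # made-up model outputs for illustration
decoded = [index_to_class[i] for i in predicted_indices]

print(class_dict)
print(decoded)
```

Note that `ImageFolder`, used later in the notebook, assigns indices by sorting the class folder names alphabetically. Keeping `class_names` in alphabetical order, as it is here, keeps the two mappings consistent.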


In [None]:

class_names = ["as", "control"]  # You may need to change classes based on your project
num_classes = len(class_names)

# 📋 Build parallel lists of class names (keys) and integer labels (values).
# class_keys is reused in later cells, so we keep it alongside the dictionary.
class_keys = list(class_names)
class_values = list(range(num_classes))

# 🗺️ Map each class name to its integer label, e.g. {"as": 0, "control": 1}
class_dict = dict(zip(class_keys, class_values))

# 🖨️ Preview the result
print("class_dict =", class_dict)

# Organize Image Files by Class

We iterate through the `image_files` list and assign each image to its class based on filename prefix (e.g. `as_001.png` → `"as"`).  
We create a dictionary called `class_files` that holds a list of image files for each class, grouped by their filename prefix.  
We will later split these into training, validation, and test sets.
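
One thing to watch with prefix matching is that a short class name such as `as` would also match a file called `asthma_001.png`. The sketch below shows one possible variant that compares the text before the first underscore instead. It assumes filenames follow the `class_identifier.ext` pattern used in this notebook; `group_by_prefix` and the example paths are hypothetical.

```python
from pathlib import Path


def group_by_prefix(paths, class_names, separator="_"):
    """Group paths by the lower-cased text before the first separator."""
    groups = {name: [] for name in class_names}
    unmatched = []
    for path in paths:
        prefix = path.name.lower().split(separator, 1)[0]
        if prefix in groups:
            groups[prefix].append(path)
        else:
            unmatched.append(path)   # report these instead of silently dropping them
    return groups, unmatched


example_paths = [
    Path("as_001.png"),
    Path("control_001.png"),
    Path("AS_002.jpg"),
    Path("asthma_001.png"),   # shares the letters "as" but is not class `as`
]

groups, unmatched = group_by_prefix(example_paths, ["as", "control"])
for name, members in groups.items():
    print(name, [p.name for p in members])
print("unmatched:", [p.name for p in unmatched])
```

Returning the unmatched files separately makes it easy to spot misnamed images before they are left out of training.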


In [None]:
# 🗃️ Create a dictionary to hold files for each class
class_files = {class_key: [] for class_key in class_keys}

# 🔁 Loop through all images and assign to the correct class list
for file in image_files: # We created this list at the very beginning.
    for class_key in class_keys:
        if file.name.lower().startswith(class_key):
            class_files[class_key].append(file)
            break  # Stop after first match

# 🖨️ Summary
for class_key, files in class_files.items():
    print(f"\n📁 {class_key.upper()} ({len(files)} files):")
    for f in files[:3]:  # show first 3 files per class
        print(f"  {f.name}")

# 🧼 Preprocess, Resize, and Split Data into Train/Val/Test

In this step, we:

- **Standardize image size and format**  
   We define a function called `standardise_image()` that loads each image, converts it to RGB, and resizes it to 224×224 pixels.  
   This is the format expected by ResNet18.

- **Create the output subfolders required by our classes.**  
    `ImageFolder`, which we use to load the data, expects one subfolder per class inside each split folder.  
    We take the number and names of the classes from `class_keys`.  
    We create `train`, `val`, and `test` folders, each containing a child folder named for every class.

- **Split into train/val/test and save**  
   We split the files 80% for training, 10% for validation, and 10% for testing.  
   We save the standardised images to a structured `output/` directory.
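
To see what an 80/10/10 split means in file counts, here is a small sketch using only the standard library. `split_three_ways` is a hypothetical helper with simple rounding; the notebook itself uses sklearn's `train_test_split`, whose rounding may differ slightly for class sizes that do not divide evenly.

```python
import random


def split_three_ways(items, ratios=(0.8, 0.1, 0.1), seed=42):
    """Shuffle reproducibly, then slice into train / val / test lists."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    n_train = round(len(shuffled) * ratios[0])
    n_val = round(len(shuffled) * ratios[1])
    train = shuffled[:n_train]
    val = shuffled[n_train:n_train + n_val]
    test = shuffled[n_train + n_val:]   # whatever remains goes to the test set
    return train, val, test


# Made-up filenames standing in for one class with 50 images.
fake_files = [f"as_{i:03d}.png" for i in range(50)]
train, val, test = split_three_ways(fake_files)
print(len(train), len(val), len(test))
```

With only 50 images per class, the validation and test sets hold 5 images each, so a single misclassified image moves accuracy for that class by 20 percentage points. This is worth keeping in mind when interpreting results on small datasets.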



In [None]:

def standardise_image(path, size=(224, 224)):
    """
    Open an image, convert to RGB, resize to target size using PIL.
    """
    img = Image.open(path).convert("RGB")
    img = img.resize(size)
    return img


# 🗂️ Define the splits and their ratios (processed files are saved under output_folder)
splits = ["train", "val", "test"]
split_ratios = (0.8, 0.1, 0.1)  # 8:1:1

# 📂 Preview and create the required folder structure
print("📦 Creating output folder structure:")
for split in splits:
    for cls in class_keys:
        dest_dir = output_folder / split / cls
        dest_dir.mkdir(parents=True, exist_ok=True)
        print(f"  {Path(split) / cls}")

# 🔀 Split and copy the files for each class
for cls in class_keys:
    files = class_files[cls]
    
    # Train/val/test split using sklearn
    train_files, temp_files = train_test_split(files, test_size=(1 - split_ratios[0]), random_state=42)
    val_files, test_files = train_test_split(temp_files, test_size=0.5, random_state=42)

    split_map = {
        "train": train_files,
        "val": val_files,
        "test": test_files
    }

    # 📥 Copy each file into its split/class folder
    for split, split_files in split_map.items():
        for src_path in split_files:
            dst_path = output_folder / split / cls / src_path.name
            standardise_image(src_path).save(dst_path)

# ✅ Done
print("\n✅ Images successfully preprocessed and saved into child folders in `output/` for fine-tuning.")

# 🧠 Model Setup: ResNet18 for Classification

We use a pre-trained ResNet18 model — a deep convolutional network that has already learned to recognize general features like edges, textures, and shapes.

We'll:
- Freeze the base layers
- Replace the final classification layer
- Train it on our waveform images (e.g., plethysmography or ECG)

In [None]:
# ---------------------------
# 🔧 HYPERPARAMETERS
# ---------------------------

# num_classes was already set in the class-mapping cell above

train_batch_size = 8
val_batch_size = 32
initial_lr = 0.0001
momentum = 0.9
frozen_epochs = 6
finetune_epochs = 20

train_path = output_folder / "train"
val_path = output_folder / "val"
ft_model_path = output_folder
model_weights_filename = 'model_weights.pth'

# ---------------------------
# 🖼️ TRANSFORM
# ---------------------------

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# ---------------------------
# 📁 DATASETS & LOADERS
# ---------------------------

train_dataset = ImageFolder(train_path, transform=transform)
val_dataset = ImageFolder(val_path, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=val_batch_size, shuffle=False)

print("🧭 Class to index mapping:", train_dataset.class_to_idx)

# ---------------------------
# 🧠 MODEL INIT
# ---------------------------

model = models.resnet18(weights='ResNet18_Weights.IMAGENET1K_V1')

for param in model.parameters():
    param.requires_grad = False

model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

device = torch.device("mps" if torch.backends.mps.is_available()
                      else "cuda" if torch.cuda.is_available()
                      else "cpu")
print(f"Using device: {device}")
model.to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.fc.parameters(), lr=initial_lr, momentum=momentum)

# 🔁 Model Training

We train the model in **two phases**:
1. **Frozen base** — only the final layer is trained for a few epochs.
2. **Fine-tuning** — all layers are unfrozen and trained further to adapt the model more fully to our dataset.

Training and validation losses and accuracies are printed per epoch.

📈 This allows us to track whether the model is overfitting (memorizing training data) or generalizing well.
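
As a rough sketch of how to read the per-epoch numbers, the snippet below finds the epoch with the lowest validation loss and the final gap between validation and training loss. The loss values are invented for illustration and `summarise_curves` is a hypothetical helper, not part of the training code.

```python
def summarise_curves(train_losses, val_losses):
    """Return the 1-based epoch with the lowest validation loss and the final loss gap."""
    best_index = min(range(len(val_losses)), key=val_losses.__getitem__)
    final_gap = val_losses[-1] - train_losses[-1]
    return best_index + 1, final_gap


# Made-up curves: training loss keeps falling while validation loss turns upward.
train_losses = [0.70, 0.50, 0.35, 0.20, 0.10]
val_losses = [0.72, 0.55, 0.48, 0.52, 0.60]

best_epoch, final_gap = summarise_curves(train_losses, val_losses)
print(f"Lowest validation loss at epoch {best_epoch}")
print(f"Final gap (val - train): {final_gap:.2f}")
```

A validation loss that bottoms out early while the training loss keeps falling, together with a widening gap, is a typical sign of overfitting.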

In [None]:
def train(model, train_loader, val_loader, criterion, optimizer, num_epochs):
    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_corrects = 0
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
        # Use the loader's own dataset so the function does not depend on global variables
        train_loss = running_loss / len(train_loader.dataset)
        train_acc = running_corrects.float() / len(train_loader.dataset)

        model.eval()
        running_loss = 0.0
        running_corrects = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
        # Use the loader's own dataset so the function does not depend on global variables
        val_loss = running_loss / len(val_loader.dataset)
        val_acc = running_corrects.float() / len(val_loader.dataset)

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accuracies.append(train_acc.item())
        val_accuracies.append(val_acc.item())

        print(f"Epoch [{epoch+1}/{num_epochs}] → "
              f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.2%} | "
              f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.2%}")

    return train_losses, val_losses, train_accuracies, val_accuracies



# Freeze phase
metrics1 = train(model, train_loader, val_loader, criterion, optimizer, num_epochs=frozen_epochs)

# Unfreeze and retrain
for param in model.parameters():
    param.requires_grad = True
optimizer = torch.optim.SGD(model.parameters(), lr=initial_lr, momentum=momentum)
metrics2 = train(model, train_loader, val_loader, criterion, optimizer, num_epochs=finetune_epochs)

print('so far so good')

# 📉 Visualizing Training Progress

Here we plot:
- Training vs Validation **Loss** (how wrong the model was)
- Training vs Validation **Accuracy** (how often it was right)

These plots help you understand:
- Is your model improving?
- Is it overfitting?
- Did it converge?

🚨 Flat lines or diverging curves may signal problems with your data, learning rate, or model capacity.

In [None]:
def plot_training_curves(train_losses, val_losses, train_accuracies, val_accuracies, title="Training Curves"):
    epochs = range(1, len(train_losses) + 1)

    plt.figure(figsize=(12, 5))

    # 📉 Loss plot
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_losses, label="Train Loss")
    plt.plot(epochs, val_losses, label="Val Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Loss")
    plt.legend()

    # 📈 Accuracy plot
    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_accuracies, label="Train Acc")
    plt.plot(epochs, val_accuracies, label="Val Acc")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.title("Accuracy")
    plt.legend()

    plt.suptitle(title)
    plt.tight_layout()
    plt.show()


# Merge and plot
train_losses = metrics1[0] + metrics2[0]
val_losses = metrics1[1] + metrics2[1]
train_accuracies = metrics1[2] + metrics2[2]
val_accuracies = metrics1[3] + metrics2[3]

plot_training_curves(train_losses, val_losses, train_accuracies, val_accuracies)

# ✅ Evaluating Performance

We compute a **confusion matrix** and calculate:
- Accuracy
- Precision
- Recall
- F1 Score

Together, these metrics give a fuller picture than accuracy alone of how well your model separates the classes, especially when the dataset is imbalanced.
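
The worked example below shows why accuracy alone can mislead on imbalanced data. All counts are made up, `as` is treated as the positive class, and `binary_metrics` is a hypothetical helper written for this illustration.

```python
def binary_metrics(tp, fn, fp, tn):
    """Accuracy, precision, recall and F1 from the four confusion-matrix counts."""
    total = tp + fn + fp + tn
    accuracy = (tp + tn) / total if total else 0.0
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    denominator = precision + recall
    f1 = 2 * precision * recall / denominator if denominator else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# Scenario A (made-up, balanced): 10 `as` and 10 `control` traces.
balanced = binary_metrics(tp=8, fn=2, fp=1, tn=9)

# Scenario B (made-up, imbalanced): 5 `as` and 95 `control` traces,
# scored for a model that predicts `control` for every trace.
always_control = binary_metrics(tp=0, fn=5, fp=0, tn=95)

for name, result in [("balanced", balanced), ("always_control", always_control)]:
    print(name, {key: round(value, 3) for key, value in result.items()})
```

In scenario B the accuracy is higher than in scenario A, yet recall is zero because every `as` trace is missed. Precision, recall, and F1 expose this failure where accuracy hides it.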

In [None]:
# confusion_matrix, plt and np were imported in the first cell; only seaborn is new here
import seaborn as sns

# 🔍 Evaluate on validation set
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for inputs, labels in val_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# 📊 Confusion matrix
cm = confusion_matrix(all_labels, all_preds)
class_labels = train_dataset.classes  # ← from ImageFolder, sorted alphabetically by folder name

# 🎨 Plot heatmap
plt.figure(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_labels, yticklabels=class_labels)
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Confusion Matrix (Validation Set)")
plt.tight_layout()
plt.show()

In [None]:
def compute_binary_metrics_from_cm(cm):
    """
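    Compute accuracy, precision, recall, and F1 score from a 2x2 confusion matrix.
    Rows are true labels and columns are predicted labels (sklearn's convention).
    Class index 0 (here `as`) is treated as the positive class, so the layout is
    cm = [[TP, FN], [FP, TN]].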
    """
    if cm.shape != (2, 2):
        raise ValueError("This function only supports binary classification (2x2 confusion matrix).")

    tp = cm[0, 0]
    fn = cm[0, 1]
    fp = cm[1, 0]
    tn = cm[1, 1]

    total = tp + fn + fp + tn

    accuracy = (tp + tn) / total if total else 0
    precision = tp / (tp + fp) if (tp + fp) else 0
    recall = tp / (tp + fn) if (tp + fn) else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) else 0

    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1_score": f1
    }

metrics_from_cm = compute_binary_metrics_from_cm(cm)

print("\n📋 Evaluation Metrics from Confusion Matrix:")
for metric, value in metrics_from_cm.items():
    print(f"{metric.capitalize()}: {value:.2f}")

# 🎉 Inference with the Fine-Tuned Model

Now that we’ve trained our CNN model, it's time to put it to the test!

Earlier, we carefully set aside a **test group** during preprocessing that the model has **never seen**. This allows us to evaluate how well our model generalizes to new data.

Below, we’ll run inference on each test image and compare the **true class** to the **predicted class** using our fine-tuned ResNet18.
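
The model's final layer outputs one raw score (logit) per class. As a minimal sketch of how such scores become a prediction, the snippet below applies a softmax and picks the largest value. The logits are invented for illustration, and `softmax` is written by hand here to keep the example dependency-free.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    largest = max(logits)
    exps = [math.exp(x - largest) for x in logits]  # subtract the max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]


idx_to_class = {0: "as", 1: "control"}
logits = [2.0, 0.5]  # made-up scores for a single image

probs = softmax(logits)
predicted_index = max(range(len(probs)), key=probs.__getitem__)
predicted_label = idx_to_class[predicted_index]

print(f"predicted class = {predicted_label}, probability = {probs[predicted_index]:.3f}")
```

Taking the index of the largest logit gives the same class as taking the largest softmax probability, which is why the inference code can use `torch.max` on the raw outputs directly.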

In [None]:


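# 🔁 Use the same transform as training.
# The saved images are already 224×224, so no extra resize or crop is applied here;
# resizing to 256 and center-cropping would show the model a different view than it was trained on.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])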

# 📂 Load test data
test_path = output_folder / "test"
test_dataset = ImageFolder(test_path, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)  # batch_size=1 for per-file printing

# 🏷️ Map class index to label
idx_to_class = {v: k for k, v in test_dataset.class_to_idx.items()}

# 🔍 Run inference and show per-file results
model.eval()
all_preds = []
all_labels = []

print("\n🖼️ Predictions on Test Files:")
with torch.no_grad():
    for i, (inputs, labels) in enumerate(test_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)

        true_label = idx_to_class[labels.item()]
        pred_label = idx_to_class[preds.item()]

        # Get image filename
        img_path = test_dataset.samples[i][0]
        filename = Path(img_path).name

        print("-" * 60)
        print(f"{filename}\n  true class = {true_label:<10}   predicted class = {pred_label:<10}")

        all_preds.append(preds.item())
        all_labels.append(labels.item())

# 📊 Summary
test_acc = accuracy_score(all_labels, all_preds)
print(f"\n✅ Test Accuracy: {test_acc:.2%}")
print("\n📄 Classification Report:")
print(classification_report(all_labels, all_preds, target_names=test_dataset.classes))