Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions _toc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,6 @@ parts:
- caption: Tutorials
chapters:
- file: tutorials/brain-disorder-diagnosis/notebook
- file: tutorials/cardiac-hemodynamics-assesment/notebook
- file: tutorials/drug-target-interaction/notebook-cross-domain
- file: tutorials/multiomics-cancer-classification/notebook
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Fine-tuning and interpretation settings.
# Section and key names mirror the yacs defaults declared in
# tutorials/cardiac-hemodynamics-assesment/finetune_config.py.

DATA:
  # Pre-extracted feature tensors and label CSV (Google Drive / Colab paths).
  ECG_PATH: "/content/drive/MyDrive/EMBC_workshop_data/ecg_features_tensor_last_1000.pt"
  CXR_PATH: "/content/drive/MyDrive/EMBC_workshop_data/cxr_features_tensor_last_1000.pt"
  CSV_PATH: "/content/drive/MyDrive/EMBC_workshop_data/chexpert_healthy_abnormality_subset.csv"
  BATCH_SIZE: 32
  NUM_WORKERS: 2

MODEL:
  LATENT_DIM: 256
  INPUT_IMAGE_CHANNELS: 1
  # 60000 = 12 leads x 5000 samples if the leads are flattened -- TODO confirm.
  INPUT_DIM_ECG: 60000
  NUM_LEADS: 12

FT:
  EPOCHS: 10
  LR: 0.001
  # Checkpoint used as the starting point for fine-tuning.
  CKPT_PATH: "/content/drive/MyDrive/EMBC_workshop_data/CardioVAE.pth"
  ACCELERATOR: "gpu"
  DEVICE: "cuda"
  KFOLDS: 5
  SEED: 42

INTERPRET:
  # Validation-set sample to explain and the ECG zoom window (in seconds).
  SAMPLE_IDX: 101
  ZOOM_RANGE: [3, 3.5]
  # Normalised-attribution cut-offs (0-1) above which a point counts as important.
  ECG_THRESHOLD: 0.7
  CXR_THRESHOLD: 0.7
  # ECG sampling rate in Hz.
  SAMPLING_RATE: 500
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Pre-training settings for the ECG + CXR model.
# NOTE(review): section nesting reconstructed from the blank-line grouping of
# the flattened original -- confirm against the loader that reads this file.

DATA:
  # Pre-extracted feature tensors (Google Drive / Colab paths).
  ECG_PATH: "/content/drive/MyDrive/EMBC_workshop_data/ecg_features_tensor_1000.pt"
  CXR_PATH: "/content/drive/MyDrive/EMBC_workshop_data/cxr_features_tensor_1000.pt"
  BATCH_SIZE: 32
  NUM_WORKERS: 2

MODEL:
  LATENT_DIM: 128
  INPUT_DIM_ECG: 60000
  INPUT_DIM_CXR: 1
  NUM_LEADS: 12

TRAIN:
  EPOCHS: 10
  LR: 0.001
  SEED: 123
  DEVICE: "cuda"
  DATA_DEVICE: "cpu"
  # Loss weights; presumably per-modality reconstruction terms -- TODO confirm.
  LAMBDA_IMAGE: 1.0
  LAMBDA_SIGNAL: 10.0
  SCALE_FACTOR: 0.0001
  # Relative path: the checkpoint is written to the current working directory.
  SAVE_PATH: "cardioVAE.pth"
  ACCELERATOR: "gpu"
  DEVICES: 1
50 changes: 50 additions & 0 deletions tutorials/cardiac-hemodynamics-assesment/finetune_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from yacs.config import CfgNode as CN

# Root node holding every default for the fine-tuning / interpretation tutorial.
_C = CN()

# Data configuration: input tensors, label CSV and DataLoader settings.
_C.DATA = CN(
    {
        "ECG_PATH": "/content/drive/MyDrive/EMBC_workshop_data/ecg_features_tensor_last_1000.pt",
        "CXR_PATH": "/content/drive/MyDrive/EMBC_workshop_data/cxr_features_tensor_last_1000.pt",
        "CSV_PATH": "/content/drive/MyDrive/EMBC_workshop_data/chexpert_healthy_abnormality_subset.csv",
        "BATCH_SIZE": 32,
        "NUM_WORKERS": 2,
        "DATA_DEVICE": "cpu",
    }
)

# Model configuration: sizes of the latent space and of each input modality.
_C.MODEL = CN(
    {
        "LATENT_DIM": 256,
        "INPUT_IMAGE_CHANNELS": 1,
        "INPUT_DIM_ECG": 60000,
        "NUM_LEADS": 12,
    }
)

# Fine-tuning configuration: optimisation, checkpoint and hardware settings.
_C.FT = CN(
    {
        "EPOCHS": 15,
        "LR": 0.001,
        "HIDDEN_DIM": 128,
        "NUM_CLASSES": 2,
        "CKPT_PATH": "/content/drive/MyDrive/EMBC_workshop_data/CardioVAE.pth",
        "ACCELERATOR": "gpu",
        # For PyTorch Lightning's Trainer; must be an int, not a string.
        "DEVICES": 1,
        # For torch.device().
        "DEVICE": "cuda",
        "KFOLDS": 5,
        "SEED": 42,
    }
)

# Interpretation configuration: which sample to explain and how to threshold
# and display the attributions.
_C.INTERPRET = CN(
    {
        "SAMPLE_IDX": 101,
        "ZOOM_RANGE": [3, 3.5],
        "ECG_THRESHOLD": 0.7,
        "CXR_THRESHOLD": 0.7,
        "SAMPLING_RATE": 500,
    }
)


def get_cfg_defaults():
    """Return a fresh copy of the default configuration.

    The module-level ``_C`` node is cloned so that callers can modify the
    returned config without altering the shared defaults.
    """
    defaults = _C.clone()
    return defaults
187 changes: 187 additions & 0 deletions tutorials/cardiac-hemodynamics-assesment/interpret.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
# interpret.py

import torch
import numpy as np
import neurokit2 as nk
from captum.attr import IntegratedGradients
from scipy.ndimage import binary_dilation


def _min_max_normalize(values):
    """Rescale an array to [0, 1]; the epsilon avoids division by zero on constant input."""
    lowest = values.min()
    return (values - lowest) / (values.max() - lowest + 1e-8)


def multimodal_ecg_cxr_attribution(
    last_fold_model,
    last_val_loader,
    sample_idx=0,
    ecg_threshold=0.70,
    cxr_threshold=0.7,
    zoom_range=(3, 3.5),
    lead_number=12,
    sampling_rate=500,
):
    """
    Compute Integrated Gradients attributions for one multimodal (ECG + CXR) sample.

    A sample is picked from the validation loader, the model's prediction for it
    is computed, and Captum's Integrated Gradients is used to attribute that
    prediction to the CXR image and to the cleaned ECG waveform. Everything
    needed for downstream visualisation is returned: the waveform and image as
    numpy arrays, time axes, indices/points whose normalised attribution meets
    the thresholds, and a zoomed ECG segment.

    Parameters
    ----------
    last_fold_model : torch.nn.Module
        Trained multimodal model called as ``model(cxr, ecg)`` and returning
        class logits.
    last_val_loader : DataLoader
        Validation DataLoader. Each batch must yield ``(CXR, ECG, label)``.
    sample_idx : int, optional
        Index of the sample within the concatenated validation batches
        (default is 0).
    ecg_threshold : float, optional
        Normalised attribution (0-1) at or above which an ECG point is
        considered important (default is 0.70).
    cxr_threshold : float, optional
        Normalised attribution (0-1) at or above which a CXR pixel is
        considered important (default is 0.70).
    zoom_range : tuple of float, optional
        Start and end, in seconds, of the zoomed ECG window (default is (3, 3.5)).
        The window is clamped to the available waveform.
    lead_number : int, optional
        Number of ECG leads (default is 12).
    sampling_rate : int, optional
        ECG sampling rate in Hz (default is 500).

    Returns
    -------
    dict
        Dictionary containing:
        - label : int
            True class label of the selected sample.
        - predicted_label : int
            Class predicted by the model.
        - predicted_probability : float
            Softmax probability of the predicted class.
        - ecg_waveform_np : np.ndarray
            Cleaned ECG waveform.
        - full_time : np.ndarray
            Time axis (seconds) for the full ECG.
        - full_length : int
            Number of ECG points covered by ``full_time`` (capped at 60000).
        - important_indices_full : np.ndarray
            Indices of the full ECG whose attribution meets ``ecg_threshold``.
        - segment_ecg_waveform : np.ndarray
            Zoomed ECG segment.
        - zoom_time : np.ndarray
            Time axis (seconds) for the zoomed segment.
        - important_indices_zoom : np.ndarray
            Indices, relative to the zoomed segment, meeting ``ecg_threshold``.
        - zoom_start_sec, zoom_end_sec : float
            Start and end time (seconds) of the zoomed window.
        - xray_image_np : np.ndarray
            CXR image as a numpy array.
        - x_pts, y_pts : np.ndarray
            Coordinates of important CXR pixels after one step of dilation.
        - importance_pts : np.ndarray
            Normalised attribution values at ``(y_pts, x_pts)``.
        - ecg_threshold, cxr_threshold : float
            The thresholds that were applied.
    """
    # TODO(review): consider how to integrate this function into kale in the future.

    # Resolve the model's device once instead of on every tensor transfer.
    device = next(last_fold_model.parameters()).device

    # NOTE(review): every validation batch is concatenated in memory, which may
    # be expensive for large datasets. Indexing `last_val_loader.dataset`
    # directly was suggested in review and deliberately deferred by the author.
    batches = list(last_val_loader)
    all_xray_images, all_ecg_waveforms, all_labels = [
        torch.cat(items) for items in zip(*batches)
    ]

    # --- Select Sample ---
    xray_image = all_xray_images[sample_idx].unsqueeze(0).to(device)
    ecg_waveform = all_ecg_waveforms[sample_idx].unsqueeze(0).to(device)
    label = all_labels[sample_idx].item()

    # --- ECG Preprocessing ---
    # The waveform is flattened to 1D and cleaned with NeuroKit2; the cleaned
    # signal is what gets attributed and plotted.
    ecg_waveform_1d = all_ecg_waveforms[sample_idx].cpu().numpy().ravel()
    ecg_smoothed = nk.ecg_clean(ecg_waveform_1d, sampling_rate=sampling_rate)
    ecg_smoothed_tensor = (
        torch.tensor(ecg_smoothed.copy(), dtype=torch.float32)
        .unsqueeze(0)
        .unsqueeze(0)
        .to(device)
    )

    # --- Prediction ---
    # NOTE(review): the prediction uses the raw ECG while attribution below uses
    # the cleaned ECG. This is intentional per the author: cleaning is only
    # meant to improve the visualisation.
    last_fold_model.eval()
    with torch.no_grad():
        logits = last_fold_model(xray_image, ecg_waveform)
        probabilities = torch.softmax(logits, dim=1)
        predicted_label = torch.argmax(probabilities, dim=1).item()
        predicted_probability = probabilities[0, predicted_label].item()

    # --- Integrated Gradients ---
    integrated_gradients = IntegratedGradients(last_fold_model)
    xray_image.requires_grad_(True)
    attributions, _ = integrated_gradients.attribute(
        inputs=(xray_image, ecg_smoothed_tensor),
        target=predicted_label,
        return_convergence_delta=True,
    )
    attributions_xray = attributions[0]
    attributions_ecg = attributions[1]

    # --- ECG Attribution ---
    attributions_ecg_np = attributions_ecg.cpu().detach().numpy().squeeze()
    norm_attributions_ecg = _min_max_normalize(attributions_ecg_np)
    ecg_waveform_np = ecg_smoothed_tensor.cpu().detach().numpy().squeeze()

    # 60000 matches MODEL.INPUT_DIM_ECG in the tutorial configs.
    full_length = min(60000, len(ecg_waveform_np))
    # NOTE(review): the index is divided by both the sampling rate and the
    # number of leads to obtain seconds (author's convention; a reviewer
    # questioned it). Presumably the leads are flattened into one vector --
    # confirm.
    full_time = np.arange(0, full_length) / sampling_rate / lead_number
    important_indices_full = np.where(
        norm_attributions_ecg[:full_length] >= ecg_threshold
    )[0]

    # Points per second under the time-axis convention above (6000 for the
    # defaults), derived from the arguments rather than hard-coded.
    samples_per_second = sampling_rate * lead_number
    total_samples = len(ecg_waveform_np)
    # Clamp the window so the time axis and the sliced segment always have the
    # same length, even when zoom_range exceeds the recording.
    zoom_start = min(max(int(zoom_range[0] * samples_per_second), 0), total_samples)
    zoom_end = min(
        max(int(zoom_range[1] * samples_per_second), zoom_start), total_samples
    )
    zoom_time = np.arange(zoom_start, zoom_end) / sampling_rate / lead_number
    segment_ecg_waveform = ecg_waveform_np[zoom_start:zoom_end]
    segment_attributions = norm_attributions_ecg[zoom_start:zoom_end]
    important_indices_zoom = np.where(segment_attributions >= ecg_threshold)[0]
    zoom_start_sec = zoom_start / sampling_rate / lead_number
    zoom_end_sec = zoom_end / sampling_rate / lead_number

    # --- CXR Attribution: Points ---
    attributions_xray_np = attributions_xray.cpu().detach().numpy().squeeze()
    norm_attributions_xray = _min_max_normalize(attributions_xray_np)
    xray_image_np = xray_image.cpu().detach().numpy().squeeze()

    # Threshold, then grow the mask by one dilation step so isolated pixels
    # remain visible when plotted.
    binary_mask = norm_attributions_xray >= cxr_threshold
    dilated_mask = binary_dilation(binary_mask, iterations=1)
    y_pts, x_pts = np.where(dilated_mask)
    importance_pts = norm_attributions_xray[y_pts, x_pts]

    return {
        "label": label,
        "predicted_label": predicted_label,
        "predicted_probability": predicted_probability,
        "ecg_waveform_np": ecg_waveform_np,
        "full_time": full_time,
        "full_length": full_length,
        "important_indices_full": important_indices_full,
        "segment_ecg_waveform": segment_ecg_waveform,
        "zoom_time": zoom_time,
        "important_indices_zoom": important_indices_zoom,
        "zoom_start_sec": zoom_start_sec,
        "zoom_end_sec": zoom_end_sec,
        "xray_image_np": xray_image_np,
        "x_pts": x_pts,
        "y_pts": y_pts,
        "importance_pts": importance_pts,
        "ecg_threshold": ecg_threshold,
        "cxr_threshold": cxr_threshold,
    }
Loading