-
Notifications
You must be signed in to change notification settings - Fork 2
Add Heart tutorial #5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d9d6ec6
0eee90b
1928b7c
400412c
59782e4
40bdc1d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| DATA: | ||
| ECG_PATH: "/content/drive/MyDrive/EMBC_workshop_data/ecg_features_tensor_last_1000.pt" | ||
| CXR_PATH: "/content/drive/MyDrive/EMBC_workshop_data/cxr_features_tensor_last_1000.pt" | ||
| CSV_PATH: "/content/drive/MyDrive/EMBC_workshop_data/chexpert_healthy_abnormality_subset.csv" | ||
| BATCH_SIZE: 32 | ||
| NUM_WORKERS: 2 | ||
|
|
||
| MODEL: | ||
| LATENT_DIM: 256 | ||
| INPUT_IMAGE_CHANNELS: 1 | ||
| INPUT_DIM_ECG: 60000 | ||
| NUM_LEADS: 12 | ||
|
|
||
| FT: | ||
| EPOCHS: 10 | ||
| LR: 0.001 | ||
| CKPT_PATH: "/content/drive/MyDrive/EMBC_workshop_data/CardioVAE.pth" | ||
| ACCELERATOR: "gpu" | ||
| DEVICE: "cuda" | ||
| KFOLDS: 5 | ||
| SEED: 42 | ||
|
|
||
| INTERPRET: | ||
| SAMPLE_IDX: 101 | ||
| ZOOM_RANGE: [3, 3.5] | ||
| ECG_THRESHOLD: 0.7 | ||
| CXR_THRESHOLD: 0.7 | ||
| SAMPLING_RATE: 500 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| DATA: | ||
| ECG_PATH: "/content/drive/MyDrive/EMBC_workshop_data/ecg_features_tensor_1000.pt" | ||
| CXR_PATH: "/content/drive/MyDrive/EMBC_workshop_data/cxr_features_tensor_1000.pt" | ||
| BATCH_SIZE: 32 | ||
| NUM_WORKERS: 2 | ||
|
|
||
| MODEL: | ||
| LATENT_DIM: 128 | ||
| INPUT_DIM_ECG: 60000 | ||
| INPUT_DIM_CXR: 1 | ||
| NUM_LEADS: 12 | ||
|
|
||
| TRAIN: | ||
| EPOCHS: 10 | ||
| LR: 0.001 | ||
| SEED: 123 | ||
| DEVICE: "cuda" | ||
| DATA_DEVICE: "cpu" | ||
| LAMBDA_IMAGE: 1.0 | ||
| LAMBDA_SIGNAL: 10.0 | ||
| SCALE_FACTOR: 0.0001 | ||
| SAVE_PATH: "cardioVAE.pth" | ||
| ACCELERATOR: "gpu" | ||
| DEVICES: 1 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| from yacs.config import CfgNode as CN | ||
|
|
||
| _C = CN() | ||
|
|
||
| # Data configuration | ||
| _C.DATA = CN() | ||
| _C.DATA.ECG_PATH = ( | ||
| "/content/drive/MyDrive/EMBC_workshop_data/ecg_features_tensor_last_1000.pt" | ||
| ) | ||
| _C.DATA.CXR_PATH = ( | ||
| "/content/drive/MyDrive/EMBC_workshop_data/cxr_features_tensor_last_1000.pt" | ||
| ) | ||
| _C.DATA.CSV_PATH = ( | ||
| "/content/drive/MyDrive/EMBC_workshop_data/chexpert_healthy_abnormality_subset.csv" | ||
| ) | ||
| _C.DATA.BATCH_SIZE = 32 | ||
| _C.DATA.NUM_WORKERS = 2 | ||
| _C.DATA.DATA_DEVICE = "cpu" | ||
|
|
||
| # Model configuration | ||
| _C.MODEL = CN() | ||
| _C.MODEL.LATENT_DIM = 256 | ||
| _C.MODEL.INPUT_IMAGE_CHANNELS = 1 | ||
| _C.MODEL.INPUT_DIM_ECG = 60000 | ||
| _C.MODEL.NUM_LEADS = 12 | ||
|
|
||
| # Fine-tuning configuration | ||
| _C.FT = CN() | ||
| _C.FT.EPOCHS = 15 | ||
| _C.FT.LR = 0.001 | ||
| _C.FT.HIDDEN_DIM = 128 | ||
| _C.FT.NUM_CLASSES = 2 | ||
| _C.FT.CKPT_PATH = "/content/drive/MyDrive/EMBC_workshop_data/CardioVAE.pth" | ||
| _C.FT.ACCELERATOR = "gpu" | ||
| _C.FT.DEVICES = 1 # This is for PyTorch Lightning's Trainer, set as int not string | ||
| _C.FT.DEVICE = "cuda" # For torch.device() | ||
| _C.FT.KFOLDS = 5 | ||
| _C.FT.SEED = 42 | ||
|
|
||
| # Interpretation configuration | ||
| _C.INTERPRET = CN() | ||
| _C.INTERPRET.SAMPLE_IDX = 101 | ||
| _C.INTERPRET.ZOOM_RANGE = [3, 3.5] | ||
| _C.INTERPRET.ECG_THRESHOLD = 0.7 | ||
| _C.INTERPRET.CXR_THRESHOLD = 0.7 | ||
| _C.INTERPRET.SAMPLING_RATE = 500 | ||
|
|
||
|
|
||
| def get_cfg_defaults(): | ||
| return _C.clone() |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,187 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # interpret.py | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import neurokit2 as nk | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from captum.attr import IntegratedGradients | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from scipy.ndimage import binary_dilation | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def multimodal_ecg_cxr_attribution( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| last_fold_model, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| last_val_loader, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sample_idx=0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ecg_threshold=0.70, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cxr_threshold=0.7, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| zoom_range=(3, 3.5), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lead_number=12, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sampling_rate=500, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Computes model attributions for multimodal (ECG + CXR) input using Integrated Gradients. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| This function selects a sample from the provided validation loader and computes the attributions | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| (importance scores) for both ECG and CXR modalities using Captum's Integrated Gradients. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| It returns all relevant arrays and data needed for downstream visualization, including normalized | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| attributions, important indices, and segment data for zoomed-in views. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Parameters | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ---------- | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| last_fold_model : torch.nn.Module | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Trained multimodal model that accepts both CXR images and ECG waveforms as input. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| last_val_loader : DataLoader | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| PyTorch DataLoader for the validation dataset. Each batch should yield (CXR, ECG, label). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sample_idx : int, optional | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Index of the sample in the validation set to interpret (default is 0). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ecg_threshold : float, optional | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Threshold (0-1) to consider ECG attributions as important (default is 0.70). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cxr_threshold : float, optional | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Threshold (0-1) to consider CXR attributions as important (default is 0.70). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| zoom_range : tuple of float, optional | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Start and end (in seconds) for zoomed ECG visualization window (default is (3, 3.5)). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lead_number : int, optional | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Number of ECG leads (default is 12). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sampling_rate : int, optional | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Sampling rate of the ECG waveform in Hz (default is 500). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Returns | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ------- | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dict | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Dictionary containing: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - label : int | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| True class label for the selected sample. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - predicted_label : int | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Model's predicted class for the sample. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - predicted_probability : float | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Probability of the predicted class. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - ecg_waveform_np : np.ndarray | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 1D numpy array of the processed ECG waveform. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - full_time : np.ndarray | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Time axis (seconds) for the full ECG. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - full_length : int | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Number of time points in the (possibly trimmed) ECG. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - important_indices_full : np.ndarray | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Indices in the full ECG considered important by attribution threshold. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - segment_ecg_waveform : np.ndarray | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Zoomed ECG segment. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - zoom_time : np.ndarray | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Time axis (seconds) for the zoomed ECG segment. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - important_indices_zoom : np.ndarray | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Important indices within the zoomed ECG segment. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - zoom_start_sec : float | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Start time (seconds) of the zoomed window. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - zoom_end_sec : float | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| End time (seconds) of the zoomed window. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - xray_image_np : np.ndarray | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| CXR image as a numpy array. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - x_pts, y_pts : np.ndarray | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Coordinates of important points in the CXR image (after dilation). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - importance_pts : np.ndarray | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Attribution values at (x_pts, y_pts). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - ecg_threshold : float | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| The threshold used for ECG attributions. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - cxr_threshold : float | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| The threshold used for CXR attributions. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Gather all batches (as in your code) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| batches = list(last_val_loader) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| all_xray_images, all_ecg_waveforms, all_labels = [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.cat(items) for items in zip(*batches) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # --- Select Sample --- | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| xray_image = ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| all_xray_images[sample_idx] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| .unsqueeze(0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| .to(next(last_fold_model.parameters()).device) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ecg_waveform = ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| all_ecg_waveforms[sample_idx] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| .unsqueeze(0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| .to(next(last_fold_model.parameters()).device) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| label = all_labels[sample_idx].item() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+86
to
+103
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Gather all batches (as in your code) | |
| batches = list(last_val_loader) | |
| all_xray_images, all_ecg_waveforms, all_labels = [ | |
| torch.cat(items) for items in zip(*batches) | |
| ] | |
| # --- Select Sample --- | |
| xray_image = ( | |
| all_xray_images[sample_idx] | |
| .unsqueeze(0) | |
| .to(next(last_fold_model.parameters()).device) | |
| ) | |
| ecg_waveform = ( | |
| all_ecg_waveforms[sample_idx] | |
| .unsqueeze(0) | |
| .to(next(last_fold_model.parameters()).device) | |
| ) | |
| label = all_labels[sample_idx].item() | |
| # --- Select Sample --- | |
| xray_image, ecg_waveform, label = last_val_loader.dataset[sample_idx] | |
| xray_image = ( | |
| xray_image.unsqueeze(0) | |
| .to(next(last_fold_model.parameters()).device) | |
| ) | |
| ecg_waveform = ( | |
| ecg_waveform.unsqueeze(0) | |
| .to(next(last_fold_model.parameters()).device) | |
| ) | |
| label = label.item() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need for now. Rejecting it.
Copilot
AI
Jun 23, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The model forward pass uses the raw ecg_waveform, but IntegratedGradients is applied to ecg_smoothed_tensor. This mismatch can produce incorrect attributions; use the same input tensor in both the prediction and attribution steps.
| logits = last_fold_model(xray_image, ecg_waveform) | |
| logits = last_fold_model(xray_image, ecg_smoothed_tensor) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is for prediction. so the Smoothing is for making the visualization better. So, rejecting this copilot suggestion.
Copilot
AI
Jun 23, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Dividing time indices by lead_number skews the time axis. Time should be computed as np.arange(full_length) / sampling_rate without dividing by the number of leads.
| full_time = np.arange(0, full_length) / sampling_rate / lead_number | |
| important_indices_full = np.where( | |
| norm_attributions_ecg[:full_length] >= ecg_threshold | |
| )[0] | |
| zoom_start = int(zoom_range[0] * 6000) | |
| zoom_end = int(zoom_range[1] * 6000) | |
| zoom_time = np.arange(zoom_start, zoom_end) / sampling_rate / lead_number | |
| segment_ecg_waveform = ecg_waveform_np[zoom_start:zoom_end] | |
| segment_attributions = norm_attributions_ecg[zoom_start:zoom_end] | |
| important_indices_zoom = np.where(segment_attributions >= ecg_threshold)[0] | |
| zoom_start_sec = zoom_start / sampling_rate / lead_number | |
| zoom_end_sec = zoom_end / sampling_rate / lead_number | |
| full_time = np.arange(0, full_length) / sampling_rate | |
| important_indices_full = np.where( | |
| norm_attributions_ecg[:full_length] >= ecg_threshold | |
| )[0] | |
| zoom_start = int(zoom_range[0] * 6000) | |
| zoom_end = int(zoom_range[1] * 6000) | |
| zoom_time = np.arange(zoom_start, zoom_end) / sampling_rate | |
| segment_ecg_waveform = ecg_waveform_np[zoom_start:zoom_end] | |
| segment_attributions = norm_attributions_ecg[zoom_start:zoom_end] | |
| important_indices_zoom = np.where(segment_attributions >= ecg_threshold)[0] | |
| zoom_start_sec = zoom_start / sampling_rate | |
| zoom_end_sec = zoom_end / sampling_rate |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Again, rejecting this suggestion as we divide by the lead to get the values in seconds.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add docstring. Consider how to integrate this function to kale in the future
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed.