In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
'''
Purpose: Run CoMTE with full ECG dataset

'''

'\nPurpose: Run CoMTE with full ECG dataset\n\n'

## Set up Data

In [3]:
# Third party modules
import os
import h5py
import math
import pandas as pd
import numpy as np
import time
from tensorflow.keras.utils import Sequence
from tensorflow.keras.layers import (
    Input, Conv1D, MaxPooling1D, Dropout, BatchNormalization, Activation, Add, Flatten, Dense)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import (ModelCheckpoint, TensorBoard, ReduceLROnPlateau,
                                        CSVLogger, EarlyStopping)
from tensorflow.keras.models import load_model
from sklearn.pipeline import Pipeline

# Project modules
import datasets as datasets

# Define environmental variables to prevent overuse of CPU Cores
# Access and modify environmental variables
os.environ['TF_NUM_INTRAOP_THREADS'] = '1' # Set to 1
os.environ['TF_NUM_INTEROP_THREADS'] = '3' # Set to 1 less than # of requested cores
print(f"TF_NUM_INTRAOP_THREADS is {os.getenv('TF_NUM_INTRAOP_THREADS')}")
print(f"TF_NUM_INTEROP_THREADS is {os.getenv('TF_NUM_INTEROP_THREADS')}")

2025-10-16 04:08:57.304361: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-10-16 04:08:57.305631: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-10-16 04:08:57.312599: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-10-16 04:08:57.323021: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1760602137.337456 3392248 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1760602137.34

TF_NUM_INTRAOP_THREADS is 1
TF_NUM_INTEROP_THREADS is 3


Select points to explain (from test set)

In [4]:
# Select points to explain from testing set

# load testing dataset 
path_to_hdf5_test = "/projectnb/peaclab-mon/JLi/projectx/AutoECGDiagnosisData/CODE/ecg_tracings.hdf5"
dataset_name_test = "tracings"  

# Import data. SEQ is an instance of class ECGSequence
seq = datasets.ECGSequence(path_to_hdf5_test, dataset_name_test)  # using default batch size

# load pretrained model (still need to compile later) 
model_path = "/projectnb/peaclab-mon/JLi/projectx/AutoECGDiagnosisData/PretrainedModels/model/model.hdf5"
pre_model = load_model(model_path)  

# compile and apply model to testing dataset
pre_model.compile(loss='binary_crossentropy', optimizer=Adam())
model_predictions = pre_model.predict(seq,verbose=1)   # y_score is a numpy array with dimensions 827x6. It holds the predictions generated by the model

# extra
print(model_predictions.shape)
print(model_predictions[:5])

# Generate dataframe
np.save("/projectnb/peaclab-mon/JLi/projectx/AutoECGDiagnosisData/dnn_output.npy", model_predictions)
print("Output predictions saved")

W0000 00:00:1760602157.834229 3392248 gpu_device.cc:2344] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
  self._warn_if_super_not_called()


[1m104/104[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m24s[0m 227ms/step
(827, 6)
[[1.4243184e-06 1.0710034e-07 2.6337020e-07 4.5377439e-07 9.4853954e-07
  6.4135128e-09]
 [2.8897336e-02 2.0066653e-03 3.1778628e-01 2.8277325e-05 4.8343450e-02
  3.2050055e-04]
 [3.1124667e-04 2.9402891e-05 4.1752287e-06 1.9712781e-05 9.3489848e-03
  2.4932444e-05]
 [2.3969071e-09 1.7344906e-09 6.9393691e-10 8.1738605e-10 5.6821450e-09
  2.7672692e-10]
 [5.3062354e-04 3.5334433e-06 3.3941666e-07 1.4301372e-06 2.2422880e-04
  4.7077428e-06]]
Output predictions saved


In [5]:
def ECG_one_d_labels(model_predictions, onehot_labels = True):
    '''
    
    Purpose: turn one-hot encoding (N,6) array into (Nx1) vector of classes

    Input: 
    model_predictions: 2D array of probabilities or one-hot encodings (827x6)
    onehot_labels: boolean variable 

    Output: 
    (Nx1) vector of classes

    Comments: 
    The sample class is the class that exceeds the threshold
    If there are >1 classes that exceed the threshold, a tuple will be used to store the multiple classes 
    '''
    

    if not onehot_labels:
        # establish threshold
        threshold = np.array([0.124, 0.07, 0.05, 0.278, 0.390, 0.174])
        # generate class 0 probability
        exceedances = 1 - (np.maximum((model_predictions - threshold) , 0) / (1 - threshold))
        normal_prob = np.mean(exceedances, axis = 1, keepdims = True) # normal prob should be (N,1)
        # add normal prob
        probability_n = np.column_stack((normal_prob, model_predictions))
        # new threshold
        new_threshold = np.array([1, 0.124, 0.07, 0.05, 0.278, 0.390, 0.174])

        # make mask
        mask = probability_n >= new_threshold
    else:
        print(model_predictions.shape)

        mask = model_predictions == 1

        # Ensure each row has at least one '1'
        # no_positive_class is a column vector
        # Find rows with all False (no '1') # rows with all false becomes true
        no_positive_class = ~mask.any(axis=1) 
        
        # Expand mask by adding a new first column of zeros
        mask = np.column_stack((no_positive_class, mask))
    

    sample_classes = []
    for row in mask:
        passing_indices = np.where(row)[0]
        if len(passing_indices) > 1:  # If more than one indices pass
            if not onehot_labels: 
                # calc exceedances    
                exceedances = row - new_threshold
                # Get class with the highest exceedance
                max_class = np.argmax(exceedances)
                sample_classes.append(max_class)
            else:
                sample_classes.append(tuple(sorted(passing_indices)))  # Ensure passing indices are sorted in ascending order
        elif len(passing_indices) == 0:  # no passes
            sample_classes.append(0) 
        else:
            sample_classes.append(passing_indices[0])  

    return sample_classes


In [6]:
# load predictions to make y_pred
model_predictions = np.load("/projectnb/peaclab-mon/JLi/projectx/AutoECGDiagnosisData/dnn_output.npy")
# make y_pred
y_pred = ECG_one_d_labels(model_predictions, onehot_labels = False)

# make y_true
y_true_2D = pd.read_csv('/projectnb/peaclab-mon/JLi/projectx/AutoECGDiagnosisData/CODE/annotations/gold_standard.csv').values
# convert 2D to 1D
y_true = ECG_one_d_labels(y_true_2D, onehot_labels = True)

(827, 6)


In [7]:
# select indices/conditions for CoMTE
true_select = 0 #UPDATE HERE FOR OTHER CLASSES
pred_select = 0 #UPDATE HERE FOR OTHER CLASSES

# find relevant indices
indices_test = []
for idx, (true, pred) in enumerate(zip(y_true, y_pred)):
    print(f"Index:{idx}, True Label: {true}, Predicted Label: {pred}") # print elements
    if true ==  true_select and pred == pred_select:
        indices_test.append(idx)   
        
print('\n\n\n')
print(f"The {indices_test} indices match the case defined above:\n(true_select = {true_select}, pred_select = {pred_select})")

Index:0, True Label: 0, Predicted Label: 0
Index:1, True Label: 3, Predicted Label: 3
Index:2, True Label: 0, Predicted Label: 0
Index:3, True Label: 0, Predicted Label: 0
Index:4, True Label: 0, Predicted Label: 0
Index:5, True Label: 0, Predicted Label: 0
Index:6, True Label: 0, Predicted Label: 0
Index:7, True Label: 0, Predicted Label: 0
Index:8, True Label: 0, Predicted Label: 0
Index:9, True Label: 0, Predicted Label: 0
Index:10, True Label: 0, Predicted Label: 0
Index:11, True Label: 0, Predicted Label: 0
Index:12, True Label: 1, Predicted Label: 1
Index:13, True Label: 0, Predicted Label: 0
Index:14, True Label: 0, Predicted Label: 0
Index:15, True Label: (np.int64(1), np.int64(3)), Predicted Label: 3
Index:16, True Label: 0, Predicted Label: 0
Index:17, True Label: 0, Predicted Label: 0
Index:18, True Label: (np.int64(1), np.int64(3)), Predicted Label: 3
Index:19, True Label: 0, Predicted Label: 0
Index:20, True Label: 0, Predicted Label: 0
Index:21, True Label: 0, Predicted L

In [8]:
# feature importance experiment: 

# select indices/conditions for CoMTE (modified for getting TP's)

true_select = 6 #UPDATE HERE FOR OTHER CLASSES
pred_select = 6 #UPDATE HERE FOR OTHER CLASSES

# find relevant indices
indices_test = []
tp_indices = {}
class_array = [1,2,3,4,5,6]


for c in class_array:
    true_select = c
    pred_select = c

    # get TP's 
    for idx, (true, pred) in enumerate(zip(y_true, y_pred)):
        print(f"Index:{idx}, True Label: {true}, Predicted Label: {pred}") # print elements
        if true ==  true_select and pred == pred_select:
            if true not in tp_indices:
                tp_indices[true] = []
            tp_indices[true].append(idx)
            indices_test.append(idx)   
        
print('\n\n\n')
print(f"The {indices_test} indices match the case defined above:\n(true_select = {true_select}, pred_select = {pred_select})")
print(tp_indices.keys())

Index:0, True Label: 0, Predicted Label: 0
Index:1, True Label: 3, Predicted Label: 3
Index:2, True Label: 0, Predicted Label: 0
Index:3, True Label: 0, Predicted Label: 0
Index:4, True Label: 0, Predicted Label: 0
Index:5, True Label: 0, Predicted Label: 0
Index:6, True Label: 0, Predicted Label: 0
Index:7, True Label: 0, Predicted Label: 0
Index:8, True Label: 0, Predicted Label: 0
Index:9, True Label: 0, Predicted Label: 0
Index:10, True Label: 0, Predicted Label: 0
Index:11, True Label: 0, Predicted Label: 0
Index:12, True Label: 1, Predicted Label: 1
Index:13, True Label: 0, Predicted Label: 0
Index:14, True Label: 0, Predicted Label: 0
Index:15, True Label: (np.int64(1), np.int64(3)), Predicted Label: 3
Index:16, True Label: 0, Predicted Label: 0
Index:17, True Label: 0, Predicted Label: 0
Index:18, True Label: (np.int64(1), np.int64(3)), Predicted Label: 3
Index:19, True Label: 0, Predicted Label: 0
Index:20, True Label: 0, Predicted Label: 0
Index:21, True Label: 0, Predicted L

In [9]:
# feature importance experiment: 
# select 3 random indices from each key

# import random
# sampled_tp_indices = {}
# for key, indices in tp_indices.items():
#     # sample min(3, len(indices)) elements to avoid errors when fewer than 3 exist
#     sampled_tp_indices[key] = random.sample(indices, min(3, len(indices)))

# print("Randomly sampled indices per class:")
# for k, v in sampled_tp_indices.items():
#     print(f"{k}: {v}")

sampled_tp_indices = {
    1: [249, 420, 463],
    2: [106, 618, 683],
    3: [293, 1, 253],
    4: [495, 98, 188],
    5: [501, 572, 259],
    6: [492, 659, 767]}

## Make KDTree to pass into CoMTE 

In [10]:
# load in KDTree + inspect
import pickle

kd_tree_path = "/projectnb/peaclab-mon/JLi/projectx/AutomaticECGDiagnosis/KDTree_all_training/combined_kdtrees.pkl"

with open(kd_tree_path, 'rb') as f:
    print('loading kdtree from pkl')
    combined_kdtrees = pickle.load(f)

for key, kdtree in combined_kdtrees.items():
    print(f"Key: {key}")
    print("Data points in the KDTree:")
    print(kdtree.data.shape)  # Print the data points stored in the KDTree
    print("-" * 50)  # Print a separator line for clarity

loading kdtree from pkl
Key: 0
Data points in the KDTree:
(280394, 49152)
--------------------------------------------------
Key: 1
Data points in the KDTree:
(3428, 49152)
--------------------------------------------------
Key: 2
Data points in the KDTree:
(6769, 49152)
--------------------------------------------------
Key: 3
Data points in the KDTree:
(4597, 49152)
--------------------------------------------------
Key: 4
Data points in the KDTree:
(4103, 49152)
--------------------------------------------------
Key: 5
Data points in the KDTree:
(4498, 49152)
--------------------------------------------------
Key: 6
Data points in the KDTree:
(6499, 49152)
--------------------------------------------------


## Apply CoMTE_V2 that KDTree input

In [11]:
"""
Part 1: A Classifier that works with COMLEX

The classifier must have 2 capabilities:
1. Predict a class ie: class 0 in classes {0, 1}
2. Predict the probability for each class
-ie: [0.1, 0.9]

and

Be able to execute capability 1 and 2 on a PANDAS dataframe,
returning an array of corresponding predictions.



input:
    samples to be classified (pandas multiindex dataframe)

output: 
    for contrived_classification: length N list of classes

    for contrived_classification_proba: 
            length N list of 1x7 np arrays
"""

class BasicClassifier:
    classifier = pre_model  # tensorflow CNN
    import os
    
    @staticmethod
    def contrived_classification(pandas_dfs):
        classifier = pre_model  # tensorflow CNN

        # convert 2D pandas df to 3D dataframe (N,4096,12)
        array_3d = pandas_dfs.to_numpy().reshape(int(pandas_dfs.shape[0]/4096), 4096, 12)

        # create instance of ECGSequence to store the (N,4096,12) dataset
        temp_path = "/projectnb/peaclab-mon/JLi/projectx/AutoECGDiagnosisData/temporary.hdf5"
        temp_dataset_name = "tracings"
        if os.path.exists(temp_path):
            os.remove(temp_path)
        # create hdf with appropriate data
        hdf_file = h5py.File(temp_path, 'w')
        hdf_file.create_dataset(temp_dataset_name,data = array_3d)
        # init instnace of ECG Sequence holding modified with hdf path
        modified_instance = datasets.ECGSequence(temp_path, temp_dataset_name)

        # get classification and probability
        probability = classifier.predict(modified_instance, verbose = 1)    
        
    
        # close hdf5's
        modified_instance._closehdf()
        hdf_file.close()
        os.remove(temp_path)

        # analyze model output with thresholding
        # define given thresholds
        threshold = np.array([0.124, 0.07, 0.05, 0.278, 0.390, 0.174])
        
        # generate class 0 probability
        exceedances = 1 - (np.maximum((probability - threshold) , 0) / (1 - threshold))
        normal_prob = np.mean(exceedances, axis = 1, keepdims = True) # normal prob should be (N,1)
        
        # Add normal_prob as a new column
        probability_n = np.column_stack((normal_prob, probability))     

        # new threshold
        new_threshold = np.array([1, 0.124, 0.07, 0.05, 0.278, 0.390, 0.174])
        
        mask = probability_n >= new_threshold
        sample_classes = []  # init list for appends later
        
        for row, mask in zip(probability_n, mask):
            passing_indices = np.where(mask)[0]
            if len(passing_indices) > 1:  # If more than one indices pass
                # find margin between threshold and probability
                diff_array = row - new_threshold
                passing_index = np.argmax(diff_array)
                # append the index that has the highest margin
                sample_classes.append(passing_index)
            
            elif len(passing_indices) == 0:  # no passes
                sample_classes.append(0) 
            else:
                sample_classes.append(passing_indices[0])  # Select the first (or adjust logic)
                
        return sample_classes


    @staticmethod
    def contrived_classification_proba(pandas_dfs):
        classifier = pre_model  # tensorflow CNN
        
        # convert 2D pandas df to 3D dataframe (N,4096,12)
        array_3d = pandas_dfs.to_numpy().reshape(int(pandas_dfs.shape[0]/4096), 4096, 12)

        # create instance of ECGSequence to store the (N,4096,12) dataset
        temp_path = "/projectnb/peaclab-mon/JLi/projectx/AutoECGDiagnosisData/temporary.hdf5"
        temp_dataset_name = "tracings"
        if os.path.exists(temp_path):
            os.remove(temp_path)
        # create hdf with appropriate data
        hdf_file = h5py.File(temp_path, 'w')
        hdf_file.create_dataset(temp_dataset_name,data = array_3d)
        # init instnace of ECG Sequence holding modified with hdf path
        modified_instance = datasets.ECGSequence(temp_path, temp_dataset_name)

        # get classification and probability
        probability = classifier.predict(modified_instance, verbose = 1)  
        
        # close hdf5's
        modified_instance._closehdf()
        hdf_file.close()
        os.remove(temp_path)

        # analyze model output with thresholding
         # define given thresholds
        threshold = np.array([0.124, 0.07, 0.05, 0.278, 0.390, 0.174])
        
        # generate class 0 probability
        exceedances = 1 - (np.maximum((probability - threshold) , 0) / (1 - threshold))
        normal_prob = np.mean(exceedances)

        # modify result 
        probability = np.insert(probability,0,normal_prob)   

        # probability should be in a 2D array format
        if probability.ndim == 1:  # Check if it's 1D
            probability = probability.reshape(1, -1)
        
        return probability

In [12]:
"""
Part 2: Training data and labels

[The explanation will use counterfactuals drawn from this input data]

The training data can be should be an iterable of samples
(ie: python array, numpy array, pandas dataframe),
where each sample needs to be the same size array as the others.

The labels should be a corresponding iterable to the samples.

COMLEX will only use samples for which the labels are the same
as the prediction from the trained classifier.

Note:
We don't support variable-length training data at this time,
use a different projection of the data if you have such data.
"""

class BasicData:
    # define basic variables
    classes_available = [0,1,2,3,4,5,6]
    num_columns = 4096

    # define key paths and variables for training data
    path_to_hdf5_test = "/projectnb/peaclab-mon/JLi/projectx/AutoECGDiagnosisData/CODE/ecg_tracings.hdf5"
    num_features = 12
    dataset_name_hdf_tracings = "tracings" 
    # training set data and lables
    training_set_hdf_path = "/projectnb/peaclab-mon/JLi/projectx/AutoECGDiagnosisData/combined_V2.hdf5"
    y_train_csv_path = "/projectnb/peaclab-mon/JLi/projectx/AutoECGDiagnosisData/labels_combined_V2.csv"
    # read the csv into a np dataframe 
    np_train_labels = np.genfromtxt(y_train_csv_path, delimiter=",")
    
    # for the ECG implementation, the data wrapper must convert a 3D HDF5 file into a pandas multiindex array
    # create instances of ECGSequence for train data 
    train_seq, valid_seq = datasets.ECGSequence.get_train_and_val(training_set_hdf_path, dataset_name_hdf_tracings, y_train_csv_path,val_split=0.02)
    # return array-like samples for the data wrapper (returns 20000x4096x12 np array)
    timeseries = train_seq._gettimeseries_()
    num_features = 12
    # iterable of corresponding labels for the samples for the data wrapper (returns 20000x6 np array) <--- take out first column that represents ExamID
    labels = train_seq._gettruelabel_()[:,:] 

The number of samples in the dataset is 20000
The index in which the validation set starts and train set ends is 19600
<datasets.ECGSequence object at 0x152f5dceded0>
<datasets.ECGSequence object at 0x152f5dcef910>


In [13]:
print(type(BasicData.labels))

print(BasicData.labels)


<class 'numpy.ndarray'>
[[False False False False False False]
 [False False False False False False]
 [False False False False False False]
 ...
 [False False False False False False]
 [False False  True False False False]
 [False False False False False False]]


In [14]:
"""
Part 3: Wrapping it up.

The training data, training labels, and trained classifier need to be wrapped up
into a form that can pass through COMLEX.

While wrapping up the training data and labels is relatively straightforward,
wrapping up the classifier is more difficult
"""

import sys
sys.path.append('/projectnb/peaclab-mon/JLi/projectx/CoMTE_V2_JLi/CoMTE_V2/comlex_core/src')  # Path to the comlex_core directory

# import project (wrapper) modules
import explainers_input_kd as explainers_V2
from explainable_model_ECG import ClfModel as ClfModel
from explainable_data_ECG import ClfData as ClfData

class BasicComlexInput:

    # 1. wrap training points
    df_train_points = ClfData.wrap_df_x(BasicData.timeseries, BasicData.num_features)
    
    # 2. wrap training labels
    df_train_labels = ClfData.wrap_df_y(BasicData.labels)
    
    # 3. wrap up the classifier
    # note: column_attr, or the corresponding name of the columns in the sample,
    #  is unique to dataframes, and auto-generated by wrap_df_x
    wrapped_classifier = ClfModel(BasicClassifier.classifier,
                                predict_attr=BasicClassifier.contrived_classification,
                                predict_proba_attr=BasicClassifier.contrived_classification_proba,
                                column_attr=df_train_points.columns.values.tolist(),
                                classes_attr=BasicData.classes_available,
                                window_size_attr=BasicData.num_columns)

In [15]:
# Part 4: run through COMLEX

"""
Part 4: Running it through COMLEX

Requires:
1. wrapped classifier
2. wrapped training data
3. wrapped training labels

To run COMLEX:
1. wrap the test point
2. instantiate a comlex runner on the wrapped components
-OptimizedSearch sets up a KDTree for based on the data,
 in order to speed up the search time for the counterfactual
 explanation.
-OptimizedSearch will fallback to BruteForceSearch if it fails
 to find a counterfactual explanation with a predicted
 probability greater than 0.95.
3. use the comlex runner to explain wrapped datapoint
"""


# get testing point
test_point = seq._getsample_(40)
# wrap test point 
test_df = ClfData.wrap_df_test_point(test_point)

# 2. set up an optimized search comlex runner
comlex = explainers_V2.OptimizedSearch(BasicComlexInput.wrapped_classifier,
                                    BasicComlexInput.df_train_points,
                                    BasicComlexInput.df_train_labels,
                                    combined_kdtrees,
                                    silent=True, threads=4, num_distractors=2)



In [16]:
# 3. explain the test point

target_class = 0
explanation = comlex.explain(test_df,to_maximize=target_class,
                             return_dist=True,single=True,
                             savefig=True,train_iter=100,
                             timeseries=False,filename="sample_result.png")

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 647ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 67ms/step
-------Preliminary Statistics-------
Original Sample Class: [np.int64(4)] 
Sample Probabilities: [[9.2353284e-01 2.2014480e-03 1.1947299e-03 4.0518144e-06 6.0925579e-01
  5.8941361e-03 1.0822760e-05]]
Class of Interest: 0


generating distractors
Class 0 has 280394 indices.
Class 1 has 3428 indices.
Class 2 has 6769 indices.
Class 3 has 4597 indices.
Class 4 has 4103 indices.
Class 5 has 4498 indices.
Class 6 has 6499 indices.
queried indices are [139974 131593  28847  67596 193349  41171 272573 271726  25194  24076
 197950 203839  24224 131822 200524 197787 268750 268409  81188  22607
 103501 120696 111502 267596  28891 220730 206469 149997 258884   9186
 243946 227692 149405 258204 267776]
edited queried indices: [139974 131593  28847  67596 193349  41171 272573 271726  25194  24076
 197950 203839  24224 131822 200524 197787 268750 268409  8

## Analysis and Visualization of Explanation

In [17]:
import sys
sys.path.insert(0,"/projectnb/peaclab-mon/JLi/projectx/CoMTE_V2_JLi/ECG_Visualization/ecg_plot_counterfactual")
import ecg_plot_counterfactual

In [18]:
# analyze output
# print(f"explanation is {explanation}\n")

replacements_np = explanation[0]
replacements = {str(item) for item in replacements_np}
distractor_new = explanation[1]
counterfactual_explanation = test_df.copy()
#counterfactual_explanation = [point for point in test_df] # make copy of original test data before doing replacements

for replacement_i in replacements:
    counterfactual_explanation[replacement_i] = distractor_new[replacement_i].values

print(replacements_np)
print(distractor_new.columns)
print(test_df.columns)

['V1']
Index(['DI', 'DII', 'DIII', 'AVR', 'AVL', 'AVF', 'V1', 'V2', 'V3', 'V4', 'V5',
       'V6'],
      dtype='object')
Index(['DI', 'DII', 'DIII', 'AVR', 'AVL', 'AVF', 'V1', 'V2', 'V3', 'V4', 'V5',
       'V6'],
      dtype='object')


In [19]:
# validate explanation

print(type(counterfactual_explanation))
print(counterfactual_explanation.shape)
# convert pd df to 3D numpy
# Assuming df is your DataFrame
array_3d = np.expand_dims(counterfactual_explanation.values, axis=0)
print(array_3d.shape)
predictions = BasicComlexInput.wrapped_classifier.predict_proba(counterfactual_explanation)
print(predictions)




<class 'pandas.core.frame.DataFrame'>
(4096, 12)
(1, 4096, 12)
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 57ms/step
[[1.0000000e+00 1.3703873e-02 2.5442715e-03 5.0810835e-05 1.8770416e-01
  8.8611737e-02 7.7898771e-04]]


In [20]:
# visualize with ecg_plot counterfactual

lead_map = {'DI':0,
            'DII':1,
            'DIII':2,
            'AVR':3,
            'AVL':4,
            'AVF':5,
            'V1':6,
            'V2':7,
            'V3':8,
            'V4':9,
            'V5':10,
            'V6':11}


explanation_mapped = [lead_map[lead] for lead in replacements_np]
print(explanation_mapped)


# transform distractor t_test
test_df_visualization = test_df.to_numpy().T
print(test_df_visualization.shape)
distractor_visualization = distractor_new.to_numpy().T
print(distractor_visualization.shape)

[6]
(12, 4096)
(12, 4096)


In [21]:
# import sys
# sys.path.insert(0,"/projectnb/peaclab-mon/JLi/projectx/CoMTE_V2_JLi/ECG_Visualization/ecg_plot_counterfactual")
# import ecg_plot_counterfactual 

# ecg_plot_counterfactual.plot(test_df_visualization, distractor_visualization, explanation_mapped, sample_rate = 400, title = 'ECG 12', style = 'bw', row_height = 10)
# ecg_plot_counterfactual.show()
# ecg_plot_counterfactual.save_as_jpg('ECGCounterfactual')


## Extra Feature Experiment Runs


In [None]:
from collections import Counter
import json

target_class_list = [0, 1, 2, 3, 4, 5, 6]

# Sort keys by class label (k) descending so largest class goes first
for k, v in sorted(sampled_tp_indices.items()):
    print(f"{k}: {v}")
    for sample in v:
        feature_list = []
        for target_class in target_class_list:
            # get explanation
            explanation = comlex.explain(
                test_df,
                to_maximize=target_class,
                return_dist=True,
                single=True,
                savefig=True,
                train_iter=100,
                timeseries=False,
                filename="sample_result.png"
            )

            # store features
            feature_list.append(explanation[0])

        print(feature_list)

        # Flatten all sub-collections (lists or sets)
        flattened = [lead for group in feature_list for lead in group]

        # Count frequencies
        freq = Counter(flattened)
        print(freq)
        freq_dict = dict(freq)

        # save off dictionary
        with open(f"./LIME_SHAP_Comparison_Experiments/lime_shap_explanations/sample{sample}_class{k}_comte_.json", "w") as f:
            json.dump(freq_dict, f)


1: [249, 420, 463]
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 56ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 55ms/step
-------Preliminary Statistics-------
Original Sample Class: [np.int64(4)] 
Sample Probabilities: [[9.2353284e-01 2.2014480e-03 1.1947299e-03 4.0518144e-06 6.0925579e-01
  5.8941361e-03 1.0822760e-05]]
Class of Interest: 0


generating distractors
Class 0 has 280394 indices.
Class 1 has 3428 indices.
Class 2 has 6769 indices.
Class 3 has 4597 indices.
Class 4 has 4103 indices.
Class 5 has 4498 indices.
Class 6 has 6499 indices.
queried indices are [139974 131593  28847  67596 193349  41171 272573 271726  25194  24076
 197950 203839  24224 131822 200524 197787 268750 268409  81188  22607
 103501 120696 111502 267596  28891 220730 206469 149997 258884   9186
 243946 227692 149405 258204 267776]
edited queried indices: [139974 131593  28847  67596 193349  41171 272573 271726  25194  24076
 197950 203839  24224 131822 200524 19778

In [None]:
## extra feature experiment runs
from collections import Counter
import json


# sample_num = 40
# # get testing point
# test_point = seq._getsample_(40)
# # wrap test point 
# test_df = ClfData.wrap_df_test_point(test_point)


# make list of target classes
target_class_list = [0,1,2,3,4,5,6]

for k, v in sampled_tp_indices.items():
    print(f"{k}: {v}")
    for sample in v:
        feature_list = []    
        for target_class in target_class_list:
            # get explanation
            explanation = comlex.explain(test_df,to_maximize=target_class,
                                         return_dist=True,single=True,
                                         savefig=True,train_iter=100,
                                         timeseries=False,filename="sample_result.png")
            # store features
            feature_list.append(explanation[0])
        print(feature_list)

        # Flatten all sub-collections (lists or sets)
        flattened = [lead for group in feature_list for lead in group]
        
        # Count frequencies
        freq = Counter(flattened)
        print(freq)
        freq_dict = dict(freq)
        
        # save off dictionary
        with open(f"./LIME_SHAP_Comparison_Experiments/lime_shap_explanations/sample{sample}_class{k}_comte_.json", "w") as f:
            json.dump(freq_dict, f)
    
          