In [None]:
from abp_artf_plugin.pipeline import ProcessingPipeline
import numpy as np
import wfdb
import random
from utils.stats import calculate_cls_metrics
from scipy.signal import butter, filtfilt
from scipy.signal import resample

import matplotlib.pyplot as plt
from matplotlib import rcParams, font_manager as fm
# Register the local Arial font file with matplotlib and make it the default font family.
arial_font_path = 'fonts/Arial.ttf'
arial_font = fm.FontProperties(fname=arial_font_path)
fm.fontManager.addfont(arial_font_path)
rcParams.update({'font.family': arial_font.get_name()})


import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio

# Reproducibility: seed NumPy's legacy global RNG and Python's `random` module with one value.
seed = 224
np.random.seed(seed)
random.seed(seed)


In [None]:
def plot_data_and_label(data, label, save_memory=True):
    """Plot a signal and its label series as two stacked panels sharing the x-axis.

    Parameters
    ----------
    data : array-like
        1D series drawn in the top panel (titled "Processed Data"), plotted against
        its sample index.
    label : array-like
        1D series drawn in the bottom panel (titled "Label"), plotted against its
        sample index.
    save_memory : bool, default True
        Kept for backward compatibility. NOTE(review): both values currently display
        the figure the same way, via Plotly's default renderer; no separate
        memory-saving display mode is implemented.

    Returns
    -------
    plotly.graph_objects.Figure
        The figure that was displayed. If this call is the last expression of a
        cell, Jupyter renders the returned figure again, so assign the result or
        end the line with ``;``.
    """
    # Two rows with linked x-axes, so zooming one panel zooms the other.
    fig = make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.1,
                        subplot_titles=("Processed Data", "Label"))

    # Top panel: the signal.
    fig.add_trace(go.Scatter(
        x=np.arange(len(data)),
        y=data,
        mode='lines',
        name='Processed Data',
    ), row=1, col=1)

    # Bottom panel: the label series.
    fig.add_trace(go.Scatter(
        x=np.arange(len(label)),
        y=label,
        mode='lines',
        name='Label',
    ), row=2, col=1)

    fig.update_layout(
        title="Processed Data and Label",
        xaxis_title="Index",
        yaxis_title="Data",
        xaxis2_title="Index",
        yaxis2_title="Label",
        height=600
    )

    # `plotly.io.show` requires the figure as its first argument; calling it bare
    # raises TypeError. This renders exactly like `fig.show()`.
    pio.show(fig)

    return fig

In [None]:
# One ProcessingPipeline instance, reused by every `p.process(...)` call below.
p = ProcessingPipeline()

## Test on MIMIC-III dataset

In [None]:
# Load one MIMIC-III matched-subset waveform record and inspect its metadata.
record = wfdb.rdrecord('../datasets/physionet.org/files/mimic3wdb-matched/1.0/p00/p000020/p000020-2183-04-28-17-47') 
# Physical-signal matrix; later cells index it as [sample, channel].
signals = record.p_signal
display(vars(record))
display(signals.shape)

wfdb.plot_wfdb(record=record, title='p000020')

In [None]:
fs = 125              # sampling rate (Hz) used to convert seconds to sample indices
skip_s = 60*5         # skip the first 5 minutes
duration_s = 60*60*3  # then take a 3-hour window
# The slice below is only correct if `fs` matches the record's actual sampling rate.
assert record.fs == fs, f"record.fs is {record.fs} Hz, expected {fs} Hz"
# Channel 2 is treated as ABP; check `record.sig_name` when switching records.
data_abp = signals[fs*skip_s:fs*(skip_s+duration_s), 2]
print(data_abp.shape)

In [None]:
# Memory-heavy step: the whole ABP window goes through the pipeline in a single call.
data, label, mse_list = p.process(data_abp, fs)
print(data.shape, label.shape, sep='\n')

In [None]:
# Raw ABP window alongside the pipeline's labels; assigned so the figure is not echoed twice.
fig = plot_data_and_label(data=data_abp, label=label.astype(int), save_memory=False)

## Arterial Blood Pressure Hypertension

In [None]:
# Reload the cached ABP series and its labels from the local buffer directory.
data_abp = np.load('buffer/data_abp.npy')
data_label = np.load('buffer/data_abp_label.npy')

# Work on a copy so `data_abp` stays intact, then zero every sample labelled 1,
# leaving only the label-0 data.
data_abp_masked = np.array(data_abp, copy=True)
is_label_one = data_label == 1
data_abp_masked[is_label_one] = 0


In [None]:
# calculate the hypertension period
from scipy.signal import find_peaks

def find_sbp_dbp(abp_series, distance=50):
    """
    Find Systolic Blood Pressure (SBP) and Diastolic Blood Pressure (DBP) values in a high-resolution ABP series.

    SBP values are taken at the local maxima of the waveform and DBP values at the
    local minima; both are located with `scipy.signal.find_peaks`.

    Parameters:
    - abp_series: 1D array-like, high-resolution ABP data (lists are accepted).
    - distance: minimum spacing, in samples, between two detected peaks and between
      two detected troughs. The default of 50 samples is 0.4 s at 125 Hz; scale it
      with the sampling rate when using other data (e.g. int(0.4 * fs)).

    Returns:
    - sbp_values: 1D numpy array of ABP values at the detected peaks.
    - dbp_values: 1D numpy array of ABP values at the detected troughs.
    """
    # Accept any array-like: negation and fancy indexing below need an ndarray.
    abp_series = np.asarray(abp_series)

    # Systolic peaks (SBP): local maxima; `distance` suppresses small nearby bumps.
    sbp_peaks, _ = find_peaks(abp_series, distance=distance)
    sbp_values = abp_series[sbp_peaks]

    # Diastolic troughs (DBP): local minima, found as maxima of the negated signal.
    dbp_troughs, _ = find_peaks(-abp_series, distance=distance)
    dbp_values = abp_series[dbp_troughs]
    return sbp_values, dbp_values

def count_hypertension_events(sbp, dbp, sbp_threshold=140, dbp_threshold=90):
    """Count SBP values above `sbp_threshold` plus DBP values above `dbp_threshold`.

    Parameters:
    - sbp: 1D array-like of systolic values.
    - dbp: 1D array-like of diastolic values.
    - sbp_threshold, dbp_threshold: strict (>) cut-offs. The defaults 140/90 are the
      common hypertension limits; this assumes the values are in mmHg.

    Returns:
    - The number of SBP exceedances plus the number of DBP exceedances. The two
      arrays are counted independently and summed, so a cardiac cycle whose SBP and
      DBP both exceed their thresholds contributes 2.
    """
    # np.asarray lets plain lists be compared element-wise.
    n_sbp_hypertension = np.sum(np.asarray(sbp) > sbp_threshold)
    n_dbp_hypertension = np.sum(np.asarray(dbp) > dbp_threshold)
    return n_sbp_hypertension + n_dbp_hypertension

def cnt_hypertension(abp):
    """Hypertension count for an ABP series.

    Extracts the (SBP, DBP) value arrays with `find_sbp_dbp` and passes them
    straight to `count_hypertension_events`.
    """
    return count_hypertension_events(*find_sbp_dbp(abp))

In [None]:
# Quick look at the whole buffered ABP series.
fig = go.Figure()
fig.add_trace(go.Scatter(x=np.arange(data_abp.shape[0]), y=data_abp, mode='lines', name='ABP'))
fig.update_layout(title="Buffered ABP series", xaxis_title="Index", yaxis_title="ABP")
# `plotly.io.show` requires the figure as its first argument; a bare `pio.show()` raises TypeError.
pio.show(fig)
# Drop the figure reference once it has been displayed.
del fig

In [None]:
# Same window as the ABP cell (skip 5 min, take 3 h at 125 Hz), from channel 3
# (presumably PAP, per the variable name; check `record.sig_name`).
fs = 125
skip_s = 60*5
duration_s = 60*60*3
data_pap = signals[fs*skip_s:fs*(skip_s+duration_s), 3]
print(data_pap.shape)
pdata_pap, label_pap, mse_list_pap = p.process(data_pap, fs)
# Report the processed output's shape (the input's shape was printed above).
print(pdata_pap.shape)
print(label_pap.shape)

In [None]:
# Cast labels to int as in the ABP plot; the trailing `;` stops Jupyter from rendering the returned figure a second time.
plot_data_and_label(data_pap, label_pap.astype(int));

## Test on Our Dataset (please write your own dataloader function!)

In [None]:
# Load our own dataset here with a project-specific loader, for example:
# from data_provider import data_factory
# data, label = data_factory.npy_provider('240405', flag='test') # use help function
# Guard against silently running on `data`/`label` left over from the MIMIC-III cells:
# the dataset is expected as an (N, samples) array with one label per row.
assert data.ndim == 2, "Load the dataset into `data` as an (N, samples) array before running this cell"
assert len(label) == data.shape[0], "`label` must have one entry per row of `data`"
display(data.shape)
# Keep the reconstruction under its own name so `data` stays the raw input that
# the sampling-frequency experiments below resample.
data_recon, item_labels, mse_list = p.process(data, 120)
# Round the per-sample mean item label to a 0/1 prediction, as the sweep below does.
items_per_sample = np.around(np.mean(item_labels, axis=1))
print(calculate_cls_metrics(label, items_per_sample))
# Reference result recorded by the author:
# {'accuracy': 0.95, 'f1_score': 0.9480249480249481, 'sensitivity': 0.912, 'specificity': 0.988}


## Sampling Frequency Experiments

In [None]:
# import sys
# sys.path.append("../BP-artefact-removal")
# from data_provider import data_factory

# data, label = data_factory.npy_provider('240405', flag='test')
# print(data.shape, label.shape)

In [None]:
import time

def resample_to(data, ori_fs, target_fs):
    """Resample every row of a 2D signal array from `ori_fs` to `target_fs`.

    Uses FFT-based `scipy.signal.resample` along axis 1, so each row is treated as
    an independent (periodic) signal.

    Parameters
    ----------
    data : np.ndarray, shape (n_rows, n_samples)
        One signal segment per row, e.g. N x 1200.
    ori_fs, target_fs : float
        Original and target sampling rates (same units).

    Returns
    -------
    np.ndarray, float64, shape (n_rows, round(n_samples * target_fs / ori_fs))

    Raises
    ------
    AssertionError
        If `data` is not 2D.
    """
    assert data.ndim == 2, "Data must be 2D, such as N*1200"

    # Multiply before dividing and round instead of truncating: (29 / 100) * 100
    # evaluates to 28.999999999999996, which int() would turn into 28.
    num_samples = int(round(data.shape[1] * target_fs / ori_fs))
    # `resample` works along a chosen axis, so all rows go through one call; the
    # float64 cast keeps the output dtype independent of the input dtype.
    return np.asarray(resample(data, num_samples, axis=1), dtype=np.float64)

# Target sampling rates
target_fs_list = [50, 75, 100, 120, 125, 150, 175, 200, 240]
# Sampling rate of `data` as loaded (the rate passed to p.process for the dataset);
# a target equal to it skips resampling.
ORI_FS = 120

# Per-rate timing and metrics, kept for later comparison.
sweep_results = {}
for target_fs in target_fs_list:
    resampled_data = data if target_fs == ORI_FS else resample_to(data, ORI_FS, target_fs)
    print(f"Target FS: {target_fs} Hz, Resampled Shape: {resampled_data.shape}")
    # perf_counter is a monotonic, high-resolution clock meant for timing code.
    t_start = time.perf_counter()
    data_recon, item_labels, mse_list = p.process(resampled_data, target_fs)
    elapsed_s = time.perf_counter() - t_start
    # Round each row's mean item label to a 0/1 prediction.
    items_per_sample = np.around(np.mean(item_labels, axis=1))
    metrics = calculate_cls_metrics(label, items_per_sample)
    print(f"Processing time: {elapsed_s:.3f} s")
    print(metrics)
    sweep_results[target_fs] = {'time_s': elapsed_s, 'metrics': metrics}
