In [None]:
! pip install geopandas==0.14.4

In [1]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import matplotlib.patches as patch
import json
import geopandas as gpd
from pathlib import Path

# Silence every warning raised from this point on so the plotting cells stay uncluttered.
# NOTE(review): this is a blanket filter, so deprecation warnings from the plotting
# libraries are hidden as well; consider narrowing it while debugging.
import warnings
warnings.filterwarnings('ignore')

# Directory holding the spreadsheets and the CSV read by the Figure 2 and
# Supplementary Figure cells (joined with the file names through os.path.join).
path_to_xlxs = "./xlxs"

Code to generate Figure 1

In [98]:
# Every dataset ships one JSON file with its summary statistics in ./metadata/.
root = Path('./metadata/')

sns.set(style="whitegrid")

# One dictionary per statistic, each keyed by dataset name.
sex, age, length, labels, geographic_origin = {}, {}, {}, {}, {}
stat_tables = {
    'sex': sex,
    'age': age,
    'length': length,
    'labels': labels,
    'geographic_origin': geographic_origin,
}

fns = root.glob('*.json')
for fn in fns:
    with open(fn) as f:
        js = json.load(f)
    # The dataset name is the file stem without its last nine characters
    # (presumably a fixed suffix shared by all metadata files; confirm).
    dataset_name = fn.stem[:-9]
    for field, table in stat_tables.items():
        table[dataset_name] = js[field]


def _frame_sorted_by_dataset(per_dataset):
    """Build a DataFrame with one column per dataset, columns in sorted order."""
    frame = pd.DataFrame(per_dataset)
    return frame.reindex(sorted(frame.columns), axis=1)


df_sex = _frame_sorted_by_dataset(sex)
df_age = _frame_sorted_by_dataset(age)
df_length = _frame_sorted_by_dataset(length)
df_labels = _frame_sorted_by_dataset(labels)

In [None]:
# Stacked bar chart of the sex breakdown, one bar per dataset.
# DataFrame.plot() opens a figure of its own unless it is handed an Axes, so the
# Axes is created explicitly and passed in; otherwise the 12x8 figure stays
# empty and the chart is drawn at the default size.
fig, ax = plt.subplots(figsize=(12, 8))
df_sex.T.plot(kind='bar', stacked=True, colormap='viridis', edgecolor='black', ax=ax)
plt.setp(ax.get_xticklabels(), rotation=45, ha='right', fontsize=12)

ax.set_title('Sex Divisions Across Datasets', fontsize=20)
ax.set_xlabel('Datasets', fontsize=15)
ax.set_ylabel('Percentage', fontsize=15)

ax.legend(title='Sex', title_fontsize='13', fontsize='11')

fig.tight_layout()
plt.show()

In [None]:
# Bar chart of the total number of samples per dataset, drawn on a log10 scale.
sample_totals = df_length.sum(axis=0)
colors = sns.color_palette("viridis", n_colors=len(df_length.columns))

fig, ax = plt.subplots(figsize=(6, 8))
ax.bar(sample_totals.index, np.log10(sample_totals), alpha=0.7, color=colors)
plt.setp(ax.get_xticklabels(), rotation=45, ha='right', fontsize=12)

ax.set_title('Samples Across Dataset', fontsize=20)
ax.set_xlabel('Dataset', fontsize=15)
ax.set_ylabel('Count (log10 scale)', fontsize=15)

fig.tight_layout()
fig.savefig('sample_count_across_datasets.pdf', dpi=300)

In [None]:
# Donut chart showing how the samples are split across the datasets.
dataset_sums = df_length.sum(axis=0)
colors = sns.color_palette("viridis", n_colors=len(df_length.columns))

fig, ax = plt.subplots(figsize=(10, 10), facecolor='white')
ax.set_title('Sample Distribution Across Datasets', fontsize=22, weight='bold', pad=20)

wedges, texts = ax.pie(
    dataset_sums,
    colors=colors,
    startangle=140,
    wedgeprops=dict(edgecolor='white', linewidth=2),
)

# A white disc drawn over the middle turns the pie into a ring.
centre_circle = plt.Circle((0, 0), 0.70, color='white', fc='white', linewidth=1.25)
ax.add_artist(centre_circle)

# Equal aspect ratio keeps the ring circular.
ax.axis('equal')

# Per-dataset totals, one per line, written inside the ring.
center_text = "\n".join(
    f"{name}: {count}"
    for name, count in zip(dataset_sums.index, dataset_sums.values)
)
ax.text(0, 0, center_text, ha='center', va='center', fontsize=14, weight='bold', color='black')

fig.tight_layout()
fig.savefig('size_datasets_pie.pdf', dpi=300)

In [None]:
# Grouped bar chart of the age distribution, one bar colour per dataset.
# Each column is divided by its own total so datasets of different size can be compared.
df_age_norm = df_age / df_age.sum(axis=0)

n_datasets = len(df_age_norm.columns)
# Build the palette here so the cell does not rely on a `colors` variable
# left behind by whichever plotting cell happened to run before it.
colors = sns.color_palette("viridis", n_colors=n_datasets)

bar_width = 0.1
x = np.arange(len(df_age_norm.index))

fig, ax = plt.subplots(figsize=(12, 8))
for i, dataset in enumerate(df_age_norm.columns):
    ax.bar(x + i * bar_width, df_age_norm[dataset], width=bar_width, label=dataset, alpha=0.7, color=colors[i])

ax.set_title('Age Distribution Across Datasets', fontsize=20)
ax.set_xlabel('Age Range', fontsize=15)
# The bars show per-dataset fractions (see the normalisation above), not raw counts.
ax.set_ylabel('Fraction of samples', fontsize=15)
# Bars are centred at x, x + w, ..., x + (n - 1) * w, so a group is centred at x + (n - 1) * w / 2.
ax.set_xticks(x + bar_width * (n_datasets - 1) / 2)
ax.set_xticklabels(df_age_norm.index)

ax.legend(title='Datasets', title_fontsize='13', fontsize='11')

fig.tight_layout()
fig.savefig('age_distribution_across_datasets.pdf', dpi=300)

In [None]:
# Grouped bar chart of the recording-length distribution, one bar colour per dataset.
#
# The bins are regrouped into a new frame instead of editing df_length in place:
# the cell can then be re-run safely and earlier cells that read df_length keep
# seeing the raw table.
tail_start = df_length.index.get_loc('40-45')
# Keep the bins below '40-45', minus the two shortest ones ...
df_length_binned = df_length.iloc[:tail_start].drop(['0-5', '5-10'])
# ... and merge '40-45' together with every bin after it into a single '>40' row.
df_length_binned.loc['>40'] = df_length.iloc[tail_start:].sum(axis=0)

# Normalise each dataset by its own total over the retained bins.
df_length_norm = df_length_binned / df_length_binned.sum(axis=0)

n_datasets = len(df_length_norm.columns)
# Build the palette here so the cell does not rely on a `colors` variable
# left behind by an earlier plotting cell.
colors = sns.color_palette("viridis", n_colors=n_datasets)

bar_width = 0.1
x = np.arange(len(df_length_norm.index))

fig, ax = plt.subplots(figsize=(10, 8))
for i, dataset in enumerate(df_length_norm.columns):
    ax.bar(x + i * bar_width, df_length_norm[dataset], width=bar_width, label=dataset, alpha=0.7, color=colors[i])

ax.set_title('Length Distribution Across Datasets', fontsize=20)
ax.set_xlabel('Length Range', fontsize=15)
# The bars show per-dataset fractions (see the normalisation above), not raw counts.
ax.set_ylabel('Fraction of samples', fontsize=15)
# Bars are centred at x, x + w, ..., x + (n - 1) * w, so a group is centred at x + (n - 1) * w / 2.
ax.set_xticks(x + bar_width * (n_datasets - 1) / 2)
ax.set_xticklabels(df_length_norm.index)

ax.legend(title='Datasets', title_fontsize='13', fontsize='11')

fig.tight_layout()
# The output file name is kept exactly as before (including its spelling) so that
# anything referring to the existing PDF keeps working.
fig.savefig('lenght_distribution_across_datasets.pdf', dpi=300)

In [None]:
# Relabel the 'CHAGAS' row as 'SAMITROP-DEATH' before plotting.
df_labels = df_labels.rename(index={'CHAGAS': 'SAMITROP-DEATH'})

# Colour of the stripe drawn under each label, by diagnostic-role group.
class_colors = {
    2: 'tomato',      # Group 2: "ECG is not used in clinic: prediction of CVE"
    1: 'skyblue',     # Group 1: "ECG is a Supportive Diagnostic Tool"
    0: 'lightgreen',  # Group 0: "ECG is the Primary Diagnostic Tool"
}

# Group of every label that is not in group 0; labels missing here default to group 0.
classes_map = {'SAMITROP-DEATH': 2, 'TIA': 2}
classes_map.update(dict.fromkeys(
    [
        'EAMI', 'IPLMI', 'PMI', 'ILMI', 'IPMI', 'ALMI', 'LMI', 'ASMI', 'IMI',
        'INJLA', 'LAA', 'LAH', 'RAAB', 'LVH', 'SEHYP', 'RVH', 'AH', 'VH',
        'CHD', 'CMIS', 'HF', 'HVD', 'LVS', 'PACE', 'ISC_', 'HYP',
    ],
    1,
))

plt.figure(figsize=(18, 6))
sns.set(style="whitegrid")

# One translucent bar series per dataset, all drawn over the same label positions.
colors = sns.color_palette("viridis", n_colors=len(df_labels.columns))
for i, dataset in enumerate(df_labels.columns):
    log_counts = np.log10(df_labels[dataset])
    plt.bar(df_labels.index, log_counts, label=dataset, color=colors[i], alpha=0.7)

plt.xticks(rotation=45, ha='right', fontsize=6)

# Short stripe below the axis, coloured by the group each label belongs to.
for label in df_labels.index:
    stripe_color = class_colors[classes_map.get(label, 0)]
    plt.bar(label, -0.25, width=0.8, color=stripe_color, align='center')

plt.title('Label Distribution Across Datasets', fontsize=20)
plt.xlabel('Label', fontsize=15)
plt.ylabel('Count (log10 scale)', fontsize=15)
plt.grid(visible=False)

plt.legend(title='Datasets', title_fontsize='13', fontsize='11')

plt.tight_layout()
plt.savefig('label_distribution_across_datasets_with_classes.pdf', dpi=300)

plt.show()

In [None]:
# Country each dataset comes from. The country names are matched against the
# 'name' column of the world map loaded below.
dataset_countries = {
    'mimic': 'United States of America',
    'ribeiroLabled': 'Brazil',
    'samitrop': 'Brazil',
    'chapman': 'China',
    'georgia': 'United States of America',
    'ptb': 'Germany',
    'ningbo': 'China',
    'ribeiroUnlabled': 'Brazil',
    'cpscExtra': 'China',
    'hefei': 'China',
    'cpsc': 'China',
    'sph': 'China',
    'ptbxl': 'Germany',
}

# World map data bundled with the pinned geopandas release.
# To use a local Natural Earth shapefile instead:
#world = gpd.read_file("path/to/your/ne_110m_admin_0_countries.shp")
world = gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))

# Rows of the world map that belong to a country with at least one dataset.
countries = set(dataset_countries.values())
highlight_countries = world[world['name'].isin(countries)]

# Draw the whole world in grey and the dataset countries on top in blue, both on
# the same explicitly created Axes. Only countries are highlighted; individual
# dataset names are not written on the map.
fig, ax = plt.subplots(figsize=(10, 10))
world.plot(ax=ax, color='lightgray')
highlight_countries.plot(ax=ax, color='skyblue')

ax.set_title('Geographical Origin of Datasets', fontsize=20)
fig.savefig('geographical_origin_datasets.pdf', dpi=300)

Code to generate Figures 2b, 2c and 2d in the paper, i.e. the label-wise performance obtained by HuBERT-ECG SMALL, BASE and LARGE across tasks

In [None]:
# One spreadsheet per model size; each holds one row per task and one column per condition.
paths = ["Figure 2b.xlsx", "Figure 2c.xlsx", "Figure 2d.xlsx"]
sizes = ['small', 'base', 'large']

# Long CSV column names -> short names for the three diagnostic-role groups.
group_columns = {
    "Gruppo 1 (ECG is the Primary Diagnostic Tool)": "Gruppo 1",
    "Gruppo 2 (ECG is a Supportive, Not Primary, Diagnostic Tool)": "Gruppo 2",
    "Gruppo 3 (prediction of CVE)": "Gruppo 3",
}

# Group number -> (shading colour, legend text).
group_styles = {
    1: ('green', 'ECG is the primary diagnostic tool'),
    2: ('blue', 'ECG is a supportive, not primary, diagnostic tool'),
    3: ('red', 'Prediction of CVE'),
}

# Table indexed by label abbreviation with one column per group; empty cells become 0.
labels_abbreviations = pd.read_csv(os.path.join(path_to_xlxs, "labels_abbreviations.csv"), sep=';')
labels_abbreviations = (
    labels_abbreviations[["Abbreviation", *group_columns]]
    .rename(columns=group_columns)
    .set_index("Abbreviation")
    .fillna(0)
)

for size, path in zip(sizes, paths):
    performance = pd.read_excel(os.path.join(path_to_xlxs, path), index_col=0)

    fig, ax = plt.subplots(figsize=(40, 10))

    # One colour per task (row of the spreadsheet).
    colors = sns.color_palette('tab20', len(performance.index))

    # One scatter series per task, plotted over the condition names.
    for i, label in enumerate(performance.index):
        ax.scatter(performance.columns, performance.iloc[i], label=label, color=colors[i])

    ax.set_xlabel('CONDITIONS', fontsize=15)
    # Style the tick labels without replacing their text: freezing rounded
    # y labels can leave them out of step with the actual tick positions.
    ax.tick_params(axis='x', labelrotation=90, labelsize=15)
    ax.tick_params(axis='y', labelsize=15)
    ax.set_xlim(-1, len(performance.columns) + 0.01)
    ax.set_ylabel('AUROC', fontsize=15)
    fig.suptitle(f'HuBERT-ECG {size.upper()} label-wise performance', fontsize=23, y=1.03)

    # Task legend, using the same colours as the scatter series.
    # max(1, ...) keeps the column count valid when there are fewer than three tasks.
    legend = ax.legend(
        ncol=max(1, len(performance.index) // 3),
        handles=[patch.Patch(color=colors[i], label=task) for i, task in enumerate(performance.index)],
        title='Tasks',
        loc=(0.42, 1.02),
        fontsize=18,
        title_fontsize=15,
    )

    # Shade the column of every condition with the colour of its group.
    for i, label in enumerate(performance.columns):
        # Only one group column is expected to be non-zero per label, so the row
        # sum is read as the group number. TODO confirm against the CSV.
        group = labels_abbreviations.loc[label].values.sum()
        if group not in group_styles:
            # No recognised group: leave the column unshaded rather than reuse
            # the colour of the previous label.
            continue
        ax.axvspan(i - 0.5, i + 0.5, color=group_styles[group][0], alpha=0.2)

    group_legend = ax.legend(
        handles=[patch.Patch(color=color, label=text, alpha=0.2) for color, text in group_styles.values()],
        loc=(0.0, 1.07),
        ncol=3,
        title='ECG diagnostic role-based classes',
        fontsize=18,
        title_fontsize=15,
    )
    # Creating the second legend detached the task legend from the Axes, so it is
    # re-added at figure level. The group legend is still owned by the Axes and
    # must not be added again, or it would be drawn twice.
    fig.add_artist(legend)

    # Dashed reference line at AUROC = 0.9 across all conditions.
    ax.hlines(0.9, -0.5, len(performance.columns) - 0.5, linestyles='dashed', color='black', linewidth=0.5)
    fig.tight_layout()
    fig.savefig("./label_wise_performance_" + size + ".svg", bbox_inches="tight")


Code to generate Supplementary Figures

In [None]:
# Supplementary Figure 1: SSE against the number of clusters for three feature sets.
supplementary_figure_1 = pd.read_excel(os.path.join(path_to_xlxs, "Supplementary figure 1.xlsx"), index_col=0)

# The spreadsheet index holds the number of clusters C.
c = supplementary_figure_1.index

time_freq = supplementary_figure_1["time-freq"]
mixed = supplementary_figure_1["mixed"]
mfccs = supplementary_figure_1["mfccs"]

# Values reported in the legend as DB_C100 (presumably the Davies-Bouldin index
# at C = 100; confirm).
db_time_freq = 1.225
db_mixed = 1.516
db_mfcc = 1.213

fig, ax = plt.subplots()
ax.plot(c, time_freq, '--*', label=f"time_freq (n=16, DB_C100={db_time_freq})", color='tab:orange')
ax.plot(c, mixed, '--x', label=f"mixed (n=29, DB_C100={db_mixed})", color='tab:blue')
ax.plot(c, mfccs, '--*', label=f"mfcc (n=39, DB_C100={db_mfcc})", color='tab:green')
# Ask for the grid explicitly: a bare grid() call toggles it, which switches the
# grid OFF when a style such as seaborn's "whitegrid" has already enabled it.
ax.grid(True)
ax.legend()
ax.set_xlabel("Number of clusters - C")
ax.set_ylabel("SSE")
# Save before showing: with the inline backend plt.show() closes the figure, so
# a later pyplot savefig would write an empty image.
fig.savefig("./SSE_vs_C.svg", bbox_inches='tight')
plt.show()

In [None]:
# Supplementary Figure 2: pre-training validation loss (a) and linear-evaluation
# AUROC (b) for the 100 Hz and 50 Hz sampling rates.
supplementary_figure_2 = pd.read_excel(os.path.join(path_to_xlxs, "Supplementary Figure 2.xlsx"))
# Shift the default 0-based row index so that the x axis starts at 1.
supplementary_figure_2.index = supplementary_figure_2.index + 1

fig, imgs = plt.subplots(nrows=1, ncols=2, figsize=(20, 5))
img1, img2 = imgs

# (a) validation loss during pre-training
img1.plot(supplementary_figure_2['Pre-training 100 Hz Validation Loss'], '-o', label='100 Hz')
img1.plot(supplementary_figure_2['Pre-training 50 Hz Validation Loss'], '-o', label='50 Hz')
img1.grid(True)
img1.set_xlabel("Steps x 2500")
img1.set_ylabel("Validation loss")
img1.set_xticks([1, 10, 20, 30, 40, 50])
img1.legend(loc="upper right")
img1.set_title("(a)")

# (b) macro-averaged AUROC under linear evaluation
img2.plot(supplementary_figure_2['Macro-avg AUROC Linear Evaluation 100 Hz'], '-o', label='100 Hz')
img2.plot(supplementary_figure_2['Macro-avg AUROC Linear Evaluation 50 Hz'], '-o', label='50 Hz')
img2.grid(True)
img2.set_xlabel("Steps x 5000")
img2.set_ylabel("Macro-avg AUROC")
img2.set_xticks([1, 3, 5, 7, 9, 11, 13])
img2.legend(loc="lower right")
img2.set_title("(b)")

# Save before showing: with the inline backend plt.show() closes the figure, so
# a later pyplot savefig would write an empty image.
fig.savefig("upstream_downstream_performance_varying_samp_rate.svg", bbox_inches='tight')
plt.show()


In [None]:
# Supplementary Figure 3: clustering quality (SSE, Davies-Bouldin,
# Calinski-Harabasz) per encoding layer, for C = 500 / 1000 and two iterations.
supplementary_figure_3 = pd.read_excel(os.path.join(path_to_xlxs, "Supplementary Figure 3.xlsx"))

# Encoding layers the measurements refer to. Every spreadsheet column stores the
# iteration-1 values for these layers first, followed by the iteration-2 values.
layers = [5, 6, 7, 8, 9, 10]
n_layers = len(layers)

# One entry per panel: (column prefix, y-axis label, legend location, top of the grey bands).
panels = [
    ("SSE", "SSE ←", "lower left", 4300),
    ("DB", "Davies-Bouldin ←", "lower left", 3),
    ("CH", "Calinski-Harabasz →", "upper left", 65),
]

# One entry per curve: (number of clusters, iteration, line/marker format, colour).
curves = [
    (500, 1, '-s', 'tab:blue'),
    (1000, 1, '-D', 'tab:red'),
    (500, 2, '-s', 'tab:green'),
    (1000, 2, '-D', 'tab:purple'),
]

fig, imgs = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))
for img, (prefix, ylabel, legend_loc, band_top) in zip(imgs, panels):
    for n_clusters, iteration, fmt, color in curves:
        # Rows [0, n_layers) belong to iteration 1, rows [n_layers, 2 * n_layers) to iteration 2.
        first_row = (iteration - 1) * n_layers
        values = supplementary_figure_3[f"{prefix} C {n_clusters}"].iloc[first_row:first_row + n_layers]
        img.plot(layers, values, fmt, label=f'C = {n_clusters} (it{iteration})', color=color)
    img.grid(True)
    img.set_xticks(np.arange(1, 13))
    # Grey bands over the layer ranges on either side of the plotted layers
    # (presumably layers that were not evaluated; confirm).
    img.fill_between(np.arange(1, 5, 0.1), 0, band_top, color='tab:grey', alpha=0.5)
    img.fill_between(np.arange(10.1, 12.1, 0.1), 0, band_top, color='tab:grey', alpha=0.5)
    img.set_xlabel("Encoding layers")
    img.set_ylabel(ylabel)
    img.legend(loc=legend_loc)

# Save before showing: with the inline backend plt.show() closes the figure, so
# a later pyplot savefig would write an empty image.
fig.savefig("clustering_quality_across_iterations_and_layers.svg", bbox_inches='tight')
plt.show()

In [None]:
# Supplementary Figure 4: linear-evaluation AUROC over training steps for the
# different model sizes.
supplementary_figure_4 = pd.read_excel(os.path.join(path_to_xlxs, "Supplementary Figure 4.xlsx"), index_col=0)

fig, ax = plt.subplots(figsize=(12, 8))
# Each series is plotted against the spreadsheet index (the steps).
ax.plot(supplementary_figure_4['Linear Evaluation BASE it1'], '-o', color='tab:orange', label='BASE it1')
ax.plot(supplementary_figure_4['Linear Evaluation BASE it2'], '-o', color='tab:blue', label='BASE it2')
ax.plot(supplementary_figure_4['Linear Evaluation SMALL'], '-o', color='tab:green', label='SMALL')
ax.plot(supplementary_figure_4['Linear Evaluation LARGE'], '-o', color='tab:red', label='LARGE')
# Ask for the grid explicitly: a bare grid() call toggles it, which switches the
# grid OFF when a style such as seaborn's "whitegrid" has already enabled it.
ax.grid(True)
ax.legend()
ax.set_xlabel("Steps")
ax.set_ylabel("Macro-avg AUROC")
ax.set_xticks(supplementary_figure_4.index)
# Save before showing: with the inline backend plt.show() closes the figure, so
# a later pyplot savefig would write an empty image.
fig.savefig("linear_eval_varying_model_sizes.svg", bbox_inches='tight')
plt.show()


In [None]:
# Supplementary Figure 5: linear-evaluation AUROC as a function of the masking percentage p.
supplementary_figure_5 = pd.read_excel(os.path.join(path_to_xlxs, "Supplementary Figure 5.xlsx"), index_col=0)

p = supplementary_figure_5['masking p']
aucs = supplementary_figure_5['Linear Evaluation AUROC']

fig, ax = plt.subplots(figsize=(12, 8))
ax.plot(p, aucs, '-o')
# Dashed red marker at p = 0.21 (presumably the masking value that was selected; confirm).
ax.axvline(0.21, linestyle='--', color='tab:red')
ax.set_xlabel("Masking percentage p = Percentage of masked embeddings")
ax.set_xticks(p)
ax.set_ylabel("Macro-averaged AUROC")
# Ask for the grid explicitly: a bare grid() call toggles it, which switches the
# grid OFF when a style such as seaborn's "whitegrid" has already enabled it.
ax.grid(True)
fig.savefig("masking_p.svg", bbox_inches='tight')