In [None]:
%load_ext autoreload

In [None]:
import os
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import pandas as pd

In [None]:
%autoreload 2
from self_supervision.paths import DATA_DIR, TRAINING_FOLDER, RESULTS_FOLDER

In [None]:
# Preprocessed data store (the directory name suggests size-factor + log1p normalisation).
# NOTE(review): neither STORE_DIR nor HVG is referenced in this part of the notebook.
STORE_DIR = os.path.join(DATA_DIR, "merlin_cxg_2023_05_15_sf-log1p")
HVG = True  # presumably "use highly variable genes" — confirm against the training code

In [None]:
# Typography shared by the figures (small sizes for multi-panel paper figures).
font = {'family': 'sans-serif', 'size': 5}  # titles and axis labels
tick_font = {'fontsize': 5, 'fontname': 'sans-serif'}  # tick labels

# Colorblind-friendly theme. sns.set_theme() installs its own palette ("deep" by
# default), so calling sns.set_palette("colorblind") *before* it was silently undone.
# Passing the palette to set_theme keeps both the style and the palette.
sns.set_theme(style="whitegrid", palette="colorblind")

# Explicit colours used to colour-code the model families in every figure.
palette_colors = sns.color_palette("colorblind")

color_supervised = palette_colors[0]  # supervised models
color_ssl = palette_colors[1]  # self-supervised models
color_zeroshot = palette_colors[2]  # zero-shot ("Only Pretrained") models
color_baseline = palette_colors[3]  # fourth colour: baselines (Random, PCA, ...)
color_else1 = palette_colors[5]  # note: palette_colors[4] is not used
color_else2 = palette_colors[6]
color_else3 = palette_colors[7]

In [None]:
# Map each figure label (model family + pretext task) to its colour.
# Later cells label the same models with slightly different strings: with or without
# a colon after the family, and GeneFormer is also plotted. All of these spellings are
# listed so a lookup never returns NaN.
color_dict = {
    'Zero-Shot\nRandom Mask': color_zeroshot,
    'Zero-Shot:\nRandom Mask': color_zeroshot,
    'Zero-Shot\nGP Mask': color_zeroshot,
    'Zero-Shot\nGP to TF': color_zeroshot,
    'Zero-Shot\nGP to GP': color_zeroshot,
    'Zero-Shot\nBYOL': color_zeroshot,
    'Zero-Shot\nBarlow Twins': color_zeroshot,
    'Supervised': color_supervised,
    'PCA': color_baseline,
    'GeneFormer': color_baseline,
    'Random': color_baseline,
    'Self-Supervised\nRandom Mask': color_ssl,
    'Self-Supervised:\nRandom Mask': color_ssl,
    'Self-Supervised\nGP Mask': color_ssl,
    'Self-Supervised\nGP to TF': color_ssl,
    'Self-Supervised\nGP to GP': color_ssl,
    'Self-Supervised\nBYOL': color_ssl,
    'Self-Supervised\nBarlow Twins': color_ssl
}

In [None]:
# Display the RGB tuple used for supervised models (quick check of the palette).
color_supervised

# Figure 1
Compare performance on the held-out CellNet test set.

kNN classification comparing PCA, scVI, Only Pretrained (zero-shot), Supervised, and Self-Supervised models (plus GeneFormer and a Random baseline).

1) Full Transcriptome

In [None]:
# kNN classification report on CellNet. Each row is one evaluated model/run, and the
# run identifier is in the unnamed first column ('Unnamed: 0').
file_path = os.path.join(RESULTS_FOLDER, 'classification', 'val_clf_report_CellNet_knn.csv')
df = pd.read_csv(file_path)

# Collapse fully identical rows and renumber the index 0..n-1.
df_unique = df.drop_duplicates().reset_index(drop=True)

df_unique

In [None]:
# Runs shown in Figure 1, one per method ("Best Run" marks the run chosen among several).
# '_Only Pretrained' is the zero-shot Barlow Twins model (see the renaming cell below).
models_to_select = ['CN_MLP_50prun2_Only Pretrained',  # Best Run
                    'CN_MLP_gene_program_C8_25p_Only Pretrained',
                    'CN_MLP_gp_to_tf_Only Pretrained',
                    'CN_MLP_single_gene_program_Only Pretrained',
                    'MLP_BYOL_Gaussian_0_001_v4_Only Pretrained',
                    'GeneFormer',
                    'No_SSL_run1_No SSL',  # Best Run
                    'PCA',
                    'Random',
                    'SSL_CN_MLP_50prun4_SSL',  # Best Run
                    'SSL_CN_MLP_gene_program_C8_25prun0_SSL',  # Best Run
                    'SSL_CN_MLP_gp_to_tfrun0_SSL',
                    'SSL_CN_MLP_single_gene_programrun0_SSL',
                    'SSL_MLP_BYOL_Gaussian_0_001run0_SSL',
                    'SSL_contrastive_MLP_bt_Gaussian_0_01run0_SSL',
                    '_Only Pretrained',
                   ]

# .copy(): later cells assign columns on df_subset. Without it, pandas treats the
# filtered frame as a possible view of df_unique and emits SettingWithCopyWarning.
df_subset = df_unique[df_unique['Unnamed: 0'].isin(models_to_select)].copy()
df_subset

In [None]:
# Raw run identifier -> label printed on the figure axes.
custom_model_names = {
    'CN_MLP_50prun2_Only Pretrained': 'Zero-Shot\nRandom Mask',
    'CN_MLP_gene_program_C8_25p_Only Pretrained': 'Zero-Shot\nGP Mask',
    'CN_MLP_gp_to_tf_Only Pretrained': 'Zero-Shot\nGP to TF',
    'CN_MLP_single_gene_program_Only Pretrained': 'Zero-Shot\nGP to GP',
    'MLP_BYOL_Gaussian_0_001_v4_Only Pretrained': 'Zero-Shot\nBYOL',
    '_Only Pretrained': 'Zero-Shot\nBarlow Twins',
    'No_SSL_run1_No SSL': 'Supervised',
    'Random': 'Random',
    'PCA': 'PCA',
    'GeneFormer': 'GeneFormer',
    'SSL_CN_MLP_50prun4_SSL': 'Self-Supervised\nRandom Mask',
    'SSL_CN_MLP_gene_program_C8_25prun0_SSL': 'Self-Supervised\nGP Mask',
    'SSL_CN_MLP_gp_to_tfrun0_SSL': 'Self-Supervised\nGP to TF',
    'SSL_CN_MLP_single_gene_programrun0_SSL': 'Self-Supervised\nGP to GP',
    'SSL_MLP_BYOL_Gaussian_0_001run0_SSL': 'Self-Supervised\nBYOL',
    'SSL_contrastive_MLP_bt_Gaussian_0_01run0_SSL': 'Self-Supervised\nBarlow Twins',
}

# Identifiers without a custom label are kept unchanged.
renamed_ids = df_subset['Unnamed: 0'].map(lambda run_id: custom_model_names.get(run_id, run_id))

# One bar per label: when a label occurs more than once, the first row is kept.
df_subset = (
    df_subset
    .assign(**{'Unnamed: 0': renamed_ids})
    .drop_duplicates(subset='Unnamed: 0', keep='first')
)

df_subset

In [None]:

# Map each displayed model label to its bar colour (zero-shot / supervised /
# self-supervised / baselines). Adds 'GeneFormer' and the colon-separated
# Self-Supervised label to the earlier mapping.
color_dict = {
    'Zero-Shot\nRandom Mask': color_zeroshot,
    'Zero-Shot:\nRandom Mask': color_zeroshot,
    'Zero-Shot\nGP Mask': color_zeroshot,
    'Zero-Shot\nGP to TF': color_zeroshot,
    'Zero-Shot\nGP to GP': color_zeroshot,
    'Zero-Shot\nBYOL': color_zeroshot,
    'Zero-Shot\nBarlow Twins': color_zeroshot,
    'Supervised': color_supervised,
    'PCA': palette_colors[3],
    'GeneFormer': palette_colors[3],
    'Random': palette_colors[3],
    'Self-Supervised\nRandom Mask': color_ssl,
    'Self-Supervised:\nRandom Mask': color_ssl,
    'Self-Supervised\nGP Mask': color_ssl,
    'Self-Supervised\nGP to TF': color_ssl,
    'Self-Supervised\nGP to GP': color_ssl,
    'Self-Supervised\nBYOL': color_ssl,
    'Self-Supervised\nBarlow Twins': color_ssl
}

tick_font_size = 5  # For tick labels

df_subset['Color'] = df_subset['Unnamed: 0'].map(color_dict)
assert not df_subset['Color'].isnull().any(), "Some model types don't have a color assigned in the color_dict."

# Sorted ascending so the bars increase from left to right.
df_subset_sorted_micro = df_subset.sort_values('f1-score: accuracy')
df_subset_sorted_macro = df_subset.sort_values('f1-score: macro avg')


def annotate_bars(ax, data, score_column):
    """Write each bar's height, rounded to two decimals, just above the bar.

    ``data`` and ``score_column`` are kept for call compatibility only; the values
    are read from the drawn bar patches.
    """
    for patch in ax.patches:
        ax.annotate(f"{patch.get_height():.2f}",
                    (patch.get_x() + patch.get_width() / 2., patch.get_height()),
                    ha='center', va='bottom', fontsize=font['size'], weight='normal')


def _plot_sorted_f1_bars(data_sorted, score_column, ylabel, title, filename):
    """Draw one annotated F1 bar chart, save it as SVG under RESULTS_FOLDER/classification, and show it.

    Bars are coloured with the per-row 'Color' column. ``data_sorted`` must already be
    in the desired bar order with one row per bar. Returns the Axes.
    """
    plt.figure(figsize=(3.5, 1.75))
    ax = sns.barplot(x='Unnamed: 0', y=score_column, data=data_sorted,
                     palette=data_sorted['Color'].tolist())
    ax.set_ylim(0.0, 1.0)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right', fontsize=tick_font_size)
    ax.set_yticklabels([])  # the values are printed on the bars instead
    ax.set_xlabel('Model', fontsize=font['size'])
    ax.set_ylabel(ylabel, fontsize=font['size'])
    # Set the title before tight_layout() so the layout reserves space for it.
    # The Micro plot previously added its title after the layout pass.
    ax.set_title(title, fontsize=font['size'])
    annotate_bars(ax, data_sorted, score_column)
    plt.tight_layout()
    plt.savefig(os.path.join(RESULTS_FOLDER, 'classification', filename), bbox_inches='tight')
    plt.show()
    return ax


# Micro F1 Score plot
ax1 = _plot_sorted_f1_bars(df_subset_sorted_micro, 'f1-score: accuracy', 'Micro F1 Score',
                           'Comparison of Models Based on Micro F1 Score',
                           'Model_Comparison_Micro_F1_incl_GeneFormer.svg')

# Macro F1 Score plot
ax2 = _plot_sorted_f1_bars(df_subset_sorted_macro, 'f1-score: macro avg', 'Macro F1 Score',
                           'Comparison of Models Based on Macro F1 Score',
                           'Model_Comparison_Macro_F1_incl_GeneFormer.svg')


In [None]:
# Alternative rendering of the same comparison: default seaborn theme, no per-model
# colours, and the right-most (highest, after the ascending sort) value in bold.
sns.set_theme()
sns.set_palette("colorblind")

font = {'family': 'sans-serif', 'size': 5}  # titles and axis labels
tick_font = {'fontsize': 5, 'fontname': 'sans-serif'}  # tick labels


def annotate_bars(ax):
    """Write each bar's height above it; the last bar drawn is labelled in bold."""
    n_bars = len(ax.patches)
    for i, p in enumerate(ax.patches):
        vertical_offset = p.get_height() * 0.01  # small gap between bar and label
        font_weight = 'bold' if i == n_bars - 1 else 'normal'
        ax.annotate(f"{p.get_height():.2f}",
                    (p.get_x() + p.get_width() / 2., p.get_height() + vertical_offset),
                    ha='center', va='baseline',
                    **font, weight=font_weight)


def _plot_f1_comparison(score_column, ylabel, title, filename):
    """Bar chart of df_subset sorted by ``score_column``; annotated, saved as SVG and shown."""
    plt.figure(figsize=(5, 2))
    ax = sns.barplot(x='Unnamed: 0', y=score_column, data=df_subset.sort_values(score_column))
    ax.set_ylim(0.0, 1.0)

    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right', **tick_font)
    plt.setp(ax.get_xticklabels(), va="top", ha="right")  # adjust the vertical alignment
    # Restyle the existing y tick labels in place. Re-setting them with
    # set_yticklabels(get_yticklabels()) freezes the current strings in a FixedFormatter
    # (matplotlib warns), and those strings may still be empty before the figure is drawn.
    plt.setp(ax.get_yticklabels(), **tick_font)

    ax.set_xlabel('Model', fontdict=font)
    ax.set_ylabel(ylabel, fontdict=font)
    ax.set_title(title, fontdict=font)

    annotate_bars(ax)
    plt.tight_layout()
    plt.savefig(os.path.join(RESULTS_FOLDER, 'classification', filename), bbox_inches='tight')
    plt.show()
    return ax


# Plot for Micro F1 Score
ax1 = _plot_f1_comparison('f1-score: accuracy', 'Micro F1 Score',
                          'Comparison of Models Based on Micro F1 Score',
                          'Model_Comparison_Micro_F1.svg')

# Plot for Macro F1 Score
ax2 = _plot_f1_comparison('f1-score: macro avg', 'Macro F1 Score',
                          'Comparison of Models Based on Macro F1 Score',
                          'Model_Comparison_Macro_F1.svg')


In [None]:
# Reload the CellNet kNN report (without deduplication) for the per-run analysis below.
file_path = os.path.join(RESULTS_FOLDER, "classification", "val_clf_report_CellNet_knn.csv")
df = pd.read_csv(file_path)

df

In [None]:
# Runs 0-4 of the three Random-Mask regimes: zero-shot, supervised, self-supervised.
# Run 0 of the zero-shot model is stored without a "run0" suffix.
models_to_select = (
    ['CN_MLP_50p_Only Pretrained']
    + [f'CN_MLP_50prun{run}_Only Pretrained' for run in range(1, 5)]
    + [f'No_SSL_run{run}_No SSL' for run in range(5)]
    + [f'SSL_CN_MLP_50prun{run}_SSL' for run in range(5)]
)

# .copy(): the next cells add columns to df_subset. Without it, pandas treats the
# filtered frame as a possible view of `df` and emits SettingWithCopyWarning.
df_subset = df[df['Unnamed: 0'].isin(models_to_select)].copy()
df_subset

In [None]:
# Base model name = the run identifier without its run index and training-mode suffix,
# e.g. 'CN_MLP_50prun1_Only Pretrained' -> 'CN_MLP_50p' and 'No_SSL_run0_No SSL' -> 'No_SSL'.
# A plain r'(.*?)_run' pattern requires an underscore before "run", so it returns NaN
# for every 'CN_MLP_50p...' / 'SSL_CN_MLP_50p...' identifier; this pattern handles both.
df_subset['Model'] = df_subset['Unnamed: 0'].str.extract(
    r'^(.*?)(?:_?run\d+)?_(?:Only Pretrained|No SSL|SSL)$', expand=False
)

# Ensure that the f1-score columns are of float type
df_subset['f1-score: accuracy'] = df_subset['f1-score: accuracy'].astype(float)
df_subset['f1-score: macro avg'] = df_subset['f1-score: macro avg'].astype(float)

df_subset

In [None]:

def extract_model_type(row):
    """Classify a run identifier (column 'Unnamed: 0') into its training regime.

    Returns the Self-Supervised or Zero-Shot Random-Mask label, 'Supervised',
    or 'Unknown' when no marker matches.
    """
    run_id = row['Unnamed: 0']
    # The first match wins. 'SSL_CN_MLP' must precede 'CN_MLP' because it contains it.
    rules = [
        ('SSL_CN_MLP', 'Self-Supervised\nRandom Mask'),
        ('CN_MLP', 'Zero-Shot\nRandom Mask'),
        ('No_SSL', 'Supervised'),
    ]
    for marker, label in rules:
        if marker in run_id:
            return label
    return 'Unknown'


# Replace 'Model' with the regime label of every selected run.
df_subset['Model'] = df_subset.apply(extract_model_type, axis=1)
df_subset

In [None]:
# Spread of macro F1 across runs for each Random-Mask regime: each box summarises one
# regime, and the swarm shows the individual runs.
sns.set_palette("colorblind")

# Font properties, using the same keys as the other cells.
font = {'family': 'sans-serif', 'size': 5}
tick_font = {'fontsize': 5, 'fontname': 'sans-serif'}

plt.figure(figsize=(2.3, 1.3))
# showfliers=False: every run is already drawn by the swarm, so boxplot "outlier"
# markers would plot those runs a second time.
ax = sns.boxplot(x='Model', y='f1-score: macro avg', data=df_subset, linewidth=0.5, showfliers=False)
sns.swarmplot(x='Model', y='f1-score: macro avg', data=df_subset, color='black', size=1, ax=ax)

ax.set_xticklabels(ax.get_xticklabels(), **tick_font)

ax.set_xlabel('Model', **font)
ax.set_ylabel('Macro F1 Score', **font)
ax.set_title('Classification Performance on scTab Test Set', **font)

plt.tight_layout()
plt.show()


In [None]:
# Label each selected run by training regime. This figure uses colon-separated labels.
def extract_model_type(row):
    """Map a run identifier (column 'Unnamed: 0') to its figure label; 'Error' if unknown."""
    run_id = row['Unnamed: 0']
    if 'Only Pretrained' in run_id:
        return 'Zero-Shot:\nRandom Mask'
    elif 'No SSL' in run_id:
        return 'Supervised'
    elif 'SSL_CN' in run_id:
        return 'Self-Supervised:\nRandom Mask'
    else:
        return 'Error'


# Apply the function to the dataframe
df_subset['Model'] = df_subset.apply(extract_model_type, axis=1)

# Colours come from the labels produced just above. Previously 'Color' was mapped from
# the labels left over by an earlier cell, *before* relabelling, so the result depended
# on execution history (a re-run produced NaN colours). Box colours also relied on the
# order of first appearance in `.unique()`. A dict keyed by label avoids both problems.
box_palette = {
    'Zero-Shot:\nRandom Mask': color_zeroshot,
    'Supervised': color_supervised,
    'Self-Supervised:\nRandom Mask': color_ssl,
}
df_subset['Color'] = df_subset['Model'].map(box_palette)
assert not df_subset['Color'].isnull().any(), "Some model labels have no colour assigned."

# Ensure that the f1-score columns are of float type
df_subset['f1-score: accuracy'] = df_subset['f1-score: accuracy'].astype(float)
df_subset['f1-score: macro avg'] = df_subset['f1-score: macro avg'].astype(float)

# Define font properties for titles and labels
font = {'family': 'sans-serif', 'size': 5}
tick_font = {'fontsize': 5, 'fontname': 'sans-serif'}

# Plot for Micro F1 Score
plt.figure(figsize=(2.3, 1.8))
ax1 = sns.boxplot(x='Model', y='f1-score: accuracy', data=df_subset, linewidth=0.5, palette=box_palette)
ax1.set_xticklabels(ax1.get_xticklabels(), **tick_font)
ax1.set_yticklabels([])
ax1.set_xlabel('Model', fontdict=font)
ax1.set_ylabel('Micro F1 Score', fontdict=font)
ax1.set_title('Classification Performance on scTab Test Set', fontdict=font)
plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/Model_Comparison_Micro_F1_Boxplot.svg", bbox_inches='tight')  # Save as SVG
plt.show()

# Plot for Macro F1 Score
plt.figure(figsize=(2.3, 1.8))
ax2 = sns.boxplot(x='Model', y='f1-score: macro avg', data=df_subset, linewidth=0.5, palette=box_palette)
ax2.set_xticklabels(ax2.get_xticklabels(), **tick_font)
ax2.set_yticklabels([])
ax2.set_xlabel('Model', fontdict=font)
ax2.set_ylabel('Macro F1 Score', fontdict=font)
ax2.set_title('Classification Performance on scTab Test Set', fontdict=font)
plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/Model_Comparison_Macro_F1_Boxplot.svg", bbox_inches='tight')  # Save as SVG
plt.show()


# Figure 2

Out-of-distribution (OOD) tasks, where supervised learning may reach its limits.

Novel, unseen dataset

**Dissection: Tail of Hippocampus (HiT) - Caudal Hippocampus - CA4-DGC**

- 56,367 cells
- 10x 3' v3
- hippocampal formation
- astrocyte (3761), central nervous system macrophage (1782), endothelial cell (174), ependymal cell (111), ~~fibroblast (86)~~, leukocyte (36), neuron (36588), oligodendrocyte (11875), oligodendrocyte precursor cell (1896), pericyte (39), vascular associated smooth muscle cell (19)

In [None]:
# OOD evaluation on the HiT dataset: load the kNN classification report and collapse
# fully identical rows.
file_path = os.path.join(RESULTS_FOLDER, "classification", "val_clf_report_OOD_HiT_knn.csv")
df = pd.read_csv(file_path).drop_duplicates()

df

In [None]:
def extract_model_type(row):
    """Map a run identifier (column 'Unnamed: 0') to its figure label; 'Error' if unknown."""
    run_id = row['Unnamed: 0']
    if 'Only Pretrained' in run_id:
        return 'Zero-Shot\nRandom Mask'
    elif 'No SSL' in run_id:
        return 'Supervised'
    elif 'SSL_CN' in run_id:
        return 'Self-Supervised\nRandom Mask'
    elif 'Random' in run_id:
        return 'Random'
    else:
        return 'Error'


# Apply the function to the dataframe
df['Model'] = df.apply(extract_model_type, axis=1)

# Runs excluded from this figure ("outliers").
# NOTE(review): this is a hand-picked list, not a computed outlier rule. It drops
# supervised runs 0 and 4, Random-Mask SSL and zero-shot runs 1 and 2, and two HLCA
# runs. Document the exclusion criterion (or also report all runs) so the comparison
# cannot be read as selective.
HIT_EXCLUDED_RUNS = [
    'No_SSL_run0_No SSL',
    'No_SSL_run4_No SSL',
    'SSL_CN_MLP_50prun0_HLCA_SSL',
    'No_SSL_run0_HLCA_No SSL',
    'SSL_CN_MLP_50prun1_SSL',
    'SSL_CN_MLP_50prun2_SSL',
    'CN_MLP_50prun1_Only Pretrained',
    'CN_MLP_50prun2_Only Pretrained',
]
# .copy(): the conversions below then act on an independent frame instead of a possible
# view of the previous `df`, which is still referenced by the notebook's output cache
# (avoids SettingWithCopyWarning).
df = df[~df['Unnamed: 0'].isin(HIT_EXCLUDED_RUNS)].copy()

# Ensure that the f1-score columns are of float type
df['f1-score: accuracy'] = df['f1-score: accuracy'].astype(float)
df['f1-score: macro avg'] = df['f1-score: macro avg'].astype(float)
df

In [None]:
# One colour per model label, looked up by name. A list palette is assigned to the x
# categories by position, and here the category order follows the score-sorted rows.
# List colours could therefore land on the wrong model and differ between the two plots.
ood_palette = {model: color_dict.get(model, 'grey') for model in df['Model'].unique()}

# Plot for Micro F1 Score
plt.figure(figsize=(2.7, 2))
ax1 = sns.boxplot(x='Model', y='f1-score: accuracy', data=df.sort_values('f1-score: accuracy'), linewidth=0.5, palette=ood_palette)
ax1.set_xticklabels(ax1.get_xticklabels(), **tick_font)
# Restyle the y tick labels in place; re-setting them would freeze them in a FixedFormatter.
plt.setp(ax1.get_yticklabels(), **tick_font)
ax1.set_xlabel('Model', fontdict=font)
ax1.set_ylabel('Micro F1 Score', fontdict=font)
ax1.set_title('OOD Classification Performance\nTail of Hippocampus (HiT) - Caudal Hippocampus - CA4-DGC', fontdict=font)
plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/OOD_HiT_Micro_F1_Boxplot.svg", bbox_inches='tight')
plt.show()

# Plot for Macro F1 Score
plt.figure(figsize=(2.7, 2))
ax2 = sns.boxplot(x='Model', y='f1-score: macro avg', data=df.sort_values('f1-score: macro avg'), linewidth=0.5, palette=ood_palette)
ax2.set_xticklabels(ax2.get_xticklabels(), **tick_font)
plt.setp(ax2.get_yticklabels(), **tick_font)
ax2.set_xlabel('Model', fontdict=font)
ax2.set_ylabel('Macro F1 Score', fontdict=font)
ax2.set_title('OOD Classification Performance\nTail of Hippocampus (HiT) - Caudal Hippocampus - CA4-DGC', fontdict=font)
plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/OOD_HiT_Macro_F1_Boxplot.svg", bbox_inches='tight')
plt.show()

**All non-neuronal cells**

- 888,263 cells
- 10x 3' v3
- Bergmann glial cell (8041), astrocyte (155025), central nervous system macrophage (91383), ~~choroid plexus epithelial cell (7689)~~, endothelial cell (5165), ependymal cell (5882), ~~fibroblast (9156)~~, oligodendrocyte (494966), oligodendrocyte precursor cell (105734), pericyte (3693), vascular associated smooth muscle cell (1074)

In [None]:
# OOD evaluation on the non-neuronal brain-atlas cells: load the kNN report and
# collapse fully identical rows.
file_path = os.path.join(RESULTS_FOLDER, "classification", "val_clf_report_OOD_nn_knn.csv")
df = pd.read_csv(file_path).drop_duplicates()

df

In [None]:
def extract_model_type(row):
    """Map a run identifier (column 'Unnamed: 0') to its figure label; 'Error' if unknown."""
    run_id = row['Unnamed: 0']
    if 'Only Pretrained' in run_id:
        return 'Zero-Shot\nRandom Mask'
    elif 'No SSL' in run_id:
        return 'Supervised'
    elif 'SSL_CN' in run_id:
        return 'Self-Supervised\nRandom Mask'
    elif 'Random' in run_id:
        return 'Random'
    else:
        return 'Error'


# Apply the function to the dataframe
df['Model'] = df.apply(extract_model_type, axis=1)

# Runs excluded from this figure ("outliers").
# NOTE(review): a hand-picked list rather than a computed outlier rule. It drops
# supervised runs 0 and 4 and Random-Mask SSL and zero-shot runs 1 and 2. Document the
# criterion (or also report all runs).
NN_EXCLUDED_RUNS = [
    'No_SSL_run0_No SSL',
    'No_SSL_run4_No SSL',
    'SSL_CN_MLP_50prun1_SSL',
    'SSL_CN_MLP_50prun2_SSL',
    'CN_MLP_50prun1_Only Pretrained',
    'CN_MLP_50prun2_Only Pretrained',
]
# .copy(): the columns set below are then written to an independent frame rather than a
# possible view of the previous `df` (avoids SettingWithCopyWarning).
df = df[~df['Unnamed: 0'].isin(NN_EXCLUDED_RUNS)].copy()

# Ensure that the f1-score columns are of float type
df['f1-score: accuracy'] = df['f1-score: accuracy'].astype(float)
df['f1-score: macro avg'] = df['f1-score: macro avg'].astype(float)

# Map the 'Model' column to colors
df['Color'] = df['Model'].map(color_dict)
df

In [None]:
# One colour per model label, looked up by name. A list palette is assigned to the x
# categories by position, and here the category order follows the score-sorted rows,
# so list colours could land on the wrong model.
ood_palette = {model: color_dict.get(model, 'grey') for model in df['Model'].unique()}

# Plot for Micro F1 Score
plt.figure(figsize=(2.7, 2))  # Adjusted for consistency with other figures
ax1 = sns.boxplot(x='Model', y='f1-score: accuracy', data=df.sort_values('f1-score: accuracy'), linewidth=0.5, palette=ood_palette)
# ax1.set_ylim(0.75, 1.0)  # Uncomment if you want to set a limit for the y-axis

# Tick label fonts; y labels are restyled in place rather than frozen via set_yticklabels.
ax1.set_xticklabels(ax1.get_xticklabels(), **tick_font)
plt.setp(ax1.get_yticklabels(), **tick_font)

# Set the font for the axis labels and title
ax1.set_xlabel('Model', fontdict=font)
ax1.set_ylabel('Micro F1 Score', fontdict=font)
ax1.set_title('OOD Classification Performance\nBrain Atlas - Non-neuronal cells', fontdict=font)

plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/OOD_Brain_Atlas_Micro_F1_Boxplot.svg", bbox_inches='tight')  # Save as SVG
plt.show()

# Plot for Macro F1 Score
plt.figure(figsize=(2.7, 2))  # Adjusted for consistency with other figures
ax2 = sns.boxplot(x='Model', y='f1-score: macro avg', data=df.sort_values('f1-score: macro avg'), linewidth=0.5, palette=ood_palette)
# ax2.set_ylim(0.0, 1.0)  # Uncomment if you want to set a limit for the y-axis

ax2.set_xticklabels(ax2.get_xticklabels(), **tick_font)
plt.setp(ax2.get_yticklabels(), **tick_font)

# Set the font for the axis labels and title
ax2.set_xlabel('Model', fontdict=font)
ax2.set_ylabel('Macro F1 Score', fontdict=font)
ax2.set_title('OOD Classification Performance\nBrain Atlas - Non-neuronal cells', fontdict=font)

plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/OOD_Brain_Atlas_Macro_F1_Boxplot.svg", bbox_inches='tight')  # Save as SVG
plt.show()


### **Circulating Immune cells -- CV19 infection, vaccination and HC**

- 195,632 cells
- 10x 5' v1, 10x 5' v2
- B Cell (21190), CD4-positive, alpha-beta T cell (61350), CD8-positive, alpha-beta T cell (35752), T cell (1407), dendritic cell (3368), gamma-delta T cell (3184), monocyte (38476), mucosal invariant T cell (1244), natural killer cell (28834), stem cell (827) 

In [None]:
# OOD evaluation on circulating immune cells: load the kNN report and collapse fully
# identical rows.
file_path = os.path.join(RESULTS_FOLDER, "classification", "val_clf_report_OOD_Circ_Imm_knn.csv")
df = pd.read_csv(file_path).drop_duplicates()

df

In [None]:
def extract_model_type(row):
    """Return the figure label for one run, or 'Error' if its identifier matches no family."""
    run_id = row['Unnamed: 0']
    # The first matching marker wins, so this order is significant.
    label_rules = (
        ('Only Pretrained', 'Zero-Shot\nRandom Mask'),
        ('No SSL', 'Supervised'),
        ('SSL_CN', 'Self-Supervised\nRandom Mask'),
        ('Random', 'Random'),
    )
    return next((label for marker, label in label_rules if marker in run_id), 'Error')


df['Model'] = df.apply(extract_model_type, axis=1)

# No runs are excluded for this dataset; both F1 columns are cast to float.
for score_col in ('f1-score: accuracy', 'f1-score: macro avg'):
    df[score_col] = df[score_col].astype(float)

df['Color'] = df['Model'].map(color_dict)
df

In [None]:
# One colour per model label, looked up by name. A list palette is assigned to the x
# categories by position, and the categories here are ordered by score (and by mean
# score, via `order=`, in the macro plot). List colours could therefore land on the
# wrong model.
ood_palette = {model: color_dict.get(model, 'grey') for model in df['Model'].unique()}

# Plot for Micro F1 Score
plt.figure(figsize=(2.7, 2))  # Adjusted for consistency with other figures
ax1 = sns.boxplot(x='Model', y='f1-score: accuracy', data=df.sort_values('f1-score: accuracy'), linewidth=0.5, palette=ood_palette)
# ax1.set_ylim(0.75, 1.0)  # Uncomment if you want to set a limit for the y-axis

# Tick label fonts; y labels are restyled in place rather than frozen via set_yticklabels.
ax1.set_xticklabels(ax1.get_xticklabels(), **tick_font)
plt.setp(ax1.get_yticklabels(), **tick_font)

# Set the font for the axis labels and title
ax1.set_xlabel('Model', fontdict=font)
ax1.set_ylabel('Micro F1 Score', fontdict=font)
ax1.set_title('OOD Classification Performance\nCirculating Immune cells\nCV19 infection, vaccination and HC', fontdict=font)

plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/OOD_Circ_Imm_Micro_F1_Boxplot.svg", bbox_inches='tight')  # Save as SVG
plt.show()

# Plot for Macro F1 Score, with boxes ordered by mean macro F1 (ascending)
mean_scores = df.groupby('Model')['f1-score: macro avg'].mean().sort_values()
sorted_models = mean_scores.index.tolist()

plt.figure(figsize=(2.7, 2))  # Adjusted for consistency with other figures
ax2 = sns.boxplot(x='Model', y='f1-score: macro avg', data=df.sort_values('f1-score: macro avg'), linewidth=0.5, palette=ood_palette, order=sorted_models)
# ax2.set_ylim(0.0, 1.0)  # Uncomment if you want to set a limit for the y-axis

ax2.set_xticklabels(ax2.get_xticklabels(), **tick_font)
plt.setp(ax2.get_yticklabels(), **tick_font)

# Set the font for the axis labels and title
ax2.set_xlabel('Model', fontdict=font)
ax2.set_ylabel('Macro F1 Score', fontdict=font)
ax2.set_title('OOD Classification Performance\nCirculating Immune cells\nCV19 infection, vaccination and HC', fontdict=font)

plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/OOD_Circ_Imm_Macro_F1_Boxplot.svg", bbox_inches='tight')  # Save as SVG
plt.show()


### **Single-cell analysis of prenatal and postnatal human cortical development**

- 709,372 cells
- 10x 3' v2, 10x 3' v3, 10x multiome
- astrocyte (67868), microglial cell (15857), native cell (15828), neural cell (537452), oligodendrocyte (40875), oligodendrocyte precursor cell (31392)

In [None]:
# OOD evaluation on prenatal/postnatal cortical development: load the kNN report and
# collapse fully identical rows.
file_path = os.path.join(RESULTS_FOLDER, "classification", "val_clf_report_OOD_Cort_Dev_knn.csv")
df = pd.read_csv(file_path).drop_duplicates()

df

In [None]:
def extract_model_type(row):
    """Return the figure label for one run, or 'Error' if its identifier matches no family."""
    run_id = row['Unnamed: 0']
    # The first matching marker wins, so this order is significant.
    label_rules = (
        ('Only Pretrained', 'Zero-Shot\nRandom Mask'),
        ('No SSL', 'Supervised'),
        ('SSL_CN', 'Self-Supervised\nRandom Mask'),
        ('Random', 'Random'),
    )
    return next((label for marker, label in label_rules if marker in run_id), 'Error')


df['Model'] = df.apply(extract_model_type, axis=1)

# No runs are excluded for this dataset; both F1 columns are cast to float.
for score_col in ('f1-score: accuracy', 'f1-score: macro avg'):
    df[score_col] = df[score_col].astype(float)

df['Color'] = df['Model'].map(color_dict)
df

In [None]:
# One colour per model label, looked up by name. A list palette is assigned to the x
# categories by position, and the categories here are ordered by score (and by mean
# score, via `order=`, in the macro plot). List colours could therefore land on the
# wrong model.
ood_palette = {model: color_dict.get(model, 'grey') for model in df['Model'].unique()}

# Plot for Micro F1 Score
plt.figure(figsize=(2.7, 2))  # Adjusted for consistency with other figures
ax1 = sns.boxplot(x='Model', y='f1-score: accuracy', data=df.sort_values('f1-score: accuracy'), linewidth=0.5, palette=ood_palette)
# ax1.set_ylim(0.75, 1.0)  # Uncomment if you want to set a limit for the y-axis

# Tick label fonts; y labels are restyled in place rather than frozen via set_yticklabels.
ax1.set_xticklabels(ax1.get_xticklabels(), **tick_font)
plt.setp(ax1.get_yticklabels(), **tick_font)

# Set the font for the axis labels and title
ax1.set_xlabel('Model', fontdict=font)
ax1.set_ylabel('Micro F1 Score', fontdict=font)
ax1.set_title('OOD Classification Performance\nPrenatal and postnatal human cortical development', fontdict=font)

plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/OOD_Cort_Dev_Micro_F1_Boxplot.svg", bbox_inches='tight')  # Save as SVG
plt.show()

# Plot for Macro F1 Score, with boxes ordered by mean macro F1 (ascending)
mean_scores = df.groupby('Model')['f1-score: macro avg'].mean().sort_values()
sorted_models = mean_scores.index.tolist()

plt.figure(figsize=(2.7, 2))  # Adjusted for consistency with other figures
ax2 = sns.boxplot(x='Model', y='f1-score: macro avg', data=df.sort_values('f1-score: macro avg'), linewidth=0.5, palette=ood_palette, order=sorted_models)
# ax2.set_ylim(0.0, 1.0)  # Uncomment if you want to set a limit for the y-axis

ax2.set_xticklabels(ax2.get_xticklabels(), **tick_font)
plt.setp(ax2.get_yticklabels(), **tick_font)

# Set the font for the axis labels and title
ax2.set_xlabel('Model', fontdict=font)
ax2.set_ylabel('Macro F1 Score', fontdict=font)
ax2.set_title('OOD Classification Performance\nPrenatal and postnatal human cortical development', fontdict=font)

plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/OOD_Cort_Dev_Macro_F1_Boxplot.svg", bbox_inches='tight')  # Save as SVG
plt.show()


### **Human: Great apes study**

- 156,285 cells
- 10x 3' v3, Smart-seq v4
- L2/3-6 intratelencephalic projecting glutamatergic cortical neuron (85276), L5 cell type (full name not recorded; not included) (392), L6b glutamatergic cortical neuron (3415), astrocyte of the cerebral cortex (3047), caudal ganglionic ... (not included) (844), cerebral cortex endothelial cell (168), chandelier pvalb (728), corticothalamic ... (3118), lamp5 GABAergic ... (6416), microglial cell (1263), near-projecting ... (3461), oligodendrocyte (7876), oligodendrocyte precursor cell (2392), pvalb GABAergic (11778), sncg GABAergic (2025), sst GABAergic cortical interneuron (13593), vascular leptomeningeal cell (276), vip GABAergic cortical interneuron (10219)

In [None]:
# OOD evaluation on the great apes study: load the kNN report and collapse fully
# identical rows.
file_path = os.path.join(RESULTS_FOLDER, "classification", "val_clf_report_OOD_Great_Apes_knn.csv")
df = pd.read_csv(file_path).drop_duplicates()

df

In [None]:
def extract_model_type(row):
    """Return the model-family label for one row of the classification report.

    The run identifier is read from column ``'Unnamed: 0'``. Substrings are
    tested in priority order: 'Only Pretrained' -> zero-shot, 'No SSL' ->
    supervised, 'SSL_CN' -> self-supervised, 'Random' -> random baseline.
    Any other identifier yields the sentinel ``'Error'``.
    """
    run_id = row['Unnamed: 0']
    if 'Only Pretrained' in run_id:
        return 'Zero-Shot\nRandom Mask'
    if 'No SSL' in run_id:
        return 'Supervised'
    if 'SSL_CN' in run_id:
        return 'Self-Supervised\nRandom Mask'
    return 'Random' if 'Random' in run_id else 'Error'

# Label every run with its model family
df['Model'] = df.apply(extract_model_type, axis=1)

# Outlier runs that were once considered for exclusion (kept for reference):
# df = df[~df['Unnamed: 0'].isin(['No_SSL_run0_No SSL', 'No_SSL_run4_No SSL', 'SSL_CN_MLP_50prun1_SSL', 'SSL_CN_MLP_50prun2_SSL', 'CN_MLP_50prun1_Only Pretrained', 'CN_MLP_50prun2_Only Pretrained'])]

# Both score columns must be numeric for grouping and plotting
df = df.astype({'f1-score: accuracy': float, 'f1-score: macro avg': float})

# Attach each row's plotting colour, looked up by model label
df['Color'] = df['Model'].map(color_dict)
df

In [None]:
# Population std (ddof=0) of five hand-copied scores.
# NOTE(review): these values are hard-coded; their source row/model is not recorded — confirm.
np.array([0.211831, 0.164773, 0.240348, 0.107757, 0.173766]).std()

In [None]:
# Colour boxes by model name. A list palette is applied positionally to the
# category order seaborn ends up using (score order for the micro plot, mean
# order for the macro plot), so a fixed list only matches the intended model
# colours by chance. A dict keyed by model keeps colours attached to models;
# unknown labels (e.g. 'Error') fall back to grey instead of raising.
model_palette = {model: color_dict.get(model, 'grey') for model in df['Model'].unique()}

# Plot for Micro F1 Score
plt.figure(figsize=(2.7, 2))  # Adjusted for consistency with other figures
ax1 = sns.boxplot(x='Model', y='f1-score: accuracy', data=df.sort_values('f1-score: accuracy'), linewidth=0.5, palette=model_palette)
# ax1.set_ylim(0.75, 1.0)  # Uncomment if you want to set a limit for the y-axis

# Set the font for the tick labels
ax1.set_xticklabels(ax1.get_xticklabels(), **tick_font)
ax1.set_yticklabels(ax1.get_yticklabels(), **tick_font)

# Set the font for the axis labels and title
ax1.set_xlabel('Model', fontdict=font)
ax1.set_ylabel('Micro F1 Score', fontdict=font)
ax1.set_title('OOD Classification Performance\nGreat Apes Study', fontdict=font)

plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/OOD_Great_Apes_Micro_F1_Boxplot.svg", bbox_inches='tight')  # Save as SVG
plt.show()

# Plot for Macro F1 Score, boxes ordered by ascending mean macro F1
mean_scores = df.groupby('Model')['f1-score: macro avg'].mean().sort_values()
sorted_models = mean_scores.index.tolist()

plt.figure(figsize=(2.7, 2))  # Adjusted for consistency with other figures
ax2 = sns.boxplot(x='Model', y='f1-score: macro avg', data=df.sort_values('f1-score: macro avg'), linewidth=0.5, palette=model_palette, order=sorted_models)
# ax2.set_ylim(0.0, 1.0)  # Uncomment if you want to set a limit for the y-axis

# Set the font for the tick labels
ax2.set_xticklabels(ax2.get_xticklabels(), **tick_font)
ax2.set_yticklabels(ax2.get_yticklabels(), **tick_font)

# Set the font for the axis labels and title
ax2.set_xlabel('Model', fontdict=font)
ax2.set_ylabel('Macro F1 Score', fontdict=font)
ax2.set_title('OOD Classification Performance\nGreat Apes Study', fontdict=font)

plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/OOD_Great_Apes_Macro_F1_Boxplot.svg", bbox_inches='tight')  # Save as SVG
plt.show()


# Figure 3

Does pretraining on a large, auxiliary dataset improve performance on a specific, known dataset to be classified?

### 1. HLCA

In [None]:
# Load the HLCA kNN classification report into a DataFrame
file_path = os.path.join(RESULTS_FOLDER, 'classification', 'val_clf_report_hlca_knn.csv')
df = pd.read_csv(file_path)

# Remove exact duplicate rows
df = df.drop_duplicates()

# Show only the first few rows to get an overview of the data
df.head()

In [None]:
# Curated HLCA runs: five 'Only Pretrained' runs, five supervised runs,
# the random baseline and five self-supervised runs.
zero_shot_runs = ['CN_MLP_50p_Only Pretrained'] + [f'CN_MLP_50prun{i}_Only Pretrained' for i in range(1, 5)]
supervised_runs = [f'No_SSL_new_run{i}_HLCA_No SSL' for i in range(5)]
ssl_runs = ['SSL_CN_MLP_50pnew_run0_HLCA_SSL'] + [f'SSL_CN_MLP_50prun{i}_HLCA_SSL' for i in range(1, 5)]
models_to_select = zero_shot_runs + supervised_runs + ['Random'] + ssl_runs

df_new_run = df[df['Unnamed: 0'].isin(models_to_select)]
df_new_run

In [None]:
# Mean of five hand-copied scores.
# NOTE(review): values are hard-coded; their source rows are not recorded — confirm.
np.array([0.744914, 0.782049, 0.809279, 0.787886, 0.737570]).mean()

In [None]:
# Step 1: Filter to include only 'new_run' entries. Copy, so that columns
# assigned later modify an independent frame rather than a slice of `df`.
# NOTE(review): this replaces the curated `models_to_select` subset displayed
# above; identifiers without 'new_run' (e.g. 'Random', 'CN_MLP_50p_Only Pretrained',
# 'SSL_CN_MLP_50prun1_HLCA_SSL') are dropped here — confirm that is intended.
df_new_run = df[df['Unnamed: 0'].str.contains('new_run')].copy()

# Step 2: Rename model types
def rename_model(row):
    """Translate a raw HLCA run identifier into its figure label.

    Substrings are checked in priority order: 'Only Pretrained' (zero-shot),
    'HLCA_No SSL' (supervised), 'HLCA_SSL' (self-supervised) and 'Random'.
    Identifiers matching none of them are returned unchanged.

    Note: later cells in this notebook redefine ``rename_model`` with
    different matching rules.
    """
    label_rules = (
        ('Only Pretrained', 'Zero-Shot\nRandom Mask'),
        ('HLCA_No SSL', 'Supervised'),
        ('HLCA_SSL', 'Self-Supervised\nRandom Mask'),
        ('Random', 'Random'),
    )
    return next((label for marker, label in label_rules if marker in row), row)

# Replace raw run identifiers with display labels. `assign` returns a new
# frame, so this never writes through to `df` (no SettingWithCopyWarning).
df_new_run = df_new_run.assign(**{'Unnamed: 0': df_new_run['Unnamed: 0'].apply(rename_model)})

# Step 3: Calculate mean and std for each model. Several columns must be
# selected with a list; the tuple form `gb['a', 'b']` raises in pandas >= 2.0.
mean_std_df = df_new_run.groupby('Unnamed: 0')[['f1-score: macro avg', 'f1-score: accuracy']].agg(['mean', 'std'])

# Step 4: Create box plots
sns.set_palette("colorblind")

# Colour boxes by model name rather than by position. Box order follows the
# score-sorted data, so a fixed colour list would not track the models.
model_palette = {model: color_dict.get(model, 'grey') for model in df_new_run['Unnamed: 0'].unique()}

# One figure per metric; `font` and `tick_font` come from the configuration cell
for metric, ylabel, filename in [
    ('f1-score: accuracy', 'Micro F1 Score', 'HLCA_Clf_Micro_F1.svg'),
    ('f1-score: macro avg', 'Macro F1 Score', 'HLCA_Clf_Macro_F1.svg'),
]:
    plt.figure(figsize=(3, 2))
    ax = sns.boxplot(x='Unnamed: 0', y=metric, data=df_new_run.sort_values(metric), linewidth=0.5, palette=model_palette)
    ax.set_xlabel('Model', fontdict=font)
    ax.set_ylabel(ylabel, fontdict=font)
    ax.set_title('HLCA Classification Performance', fontdict=font)
    ax.set_xticklabels(ax.get_xticklabels(), **tick_font)
    ax.set_yticklabels(ax.get_yticklabels(), **tick_font)
    # Lay out before saving so the SVG matches the figure shown inline
    plt.tight_layout()
    plt.savefig(os.path.join(RESULTS_FOLDER, 'classification', filename), bbox_inches='tight')  # Save as SVG
    plt.show()


In [None]:
# Step 1: Use every run in the report; the 'new_run' filter used for the
# previous figure is disabled here. Copy, so that relabelling below does not
# rename the run identifiers of `df` itself.
df_new_run = df.copy()  # previously filtered with: df[df['Unnamed: 0'].str.contains('new_run')]

# Step 2: Rename model types
def rename_model(row):
    """Map a run identifier from the full HLCA report to its figure label.

    The first matching substring wins, tested in this order: 'SSL_CN'
    (self-supervised), 'No SSL' (supervised), 'Only Pretrained' (zero-shot),
    'Random'. Unmatched identifiers pass through unchanged. This redefines
    the earlier ``rename_model`` with rules suited to all runs.
    """
    for marker, label in {
        'SSL_CN': 'Self-Supervised\nRandom Mask',
        'No SSL': 'Supervised',
        'Only Pretrained': "Zero-Shot\nRandom Mask",
        'Random': 'Random',
    }.items():
        if marker in row:
            return label
    return row

# Relabel run identifiers on a new frame (`assign`), so the source frame is not mutated
df_new_run = df_new_run.assign(**{'Unnamed: 0': df_new_run['Unnamed: 0'].apply(rename_model)})

# Step 3: Calculate mean and std for each model (list selection; a tuple raises in pandas >= 2.0)
mean_std_df = df_new_run.groupby('Unnamed: 0')[['f1-score: macro avg', 'f1-score: accuracy']].agg(['mean', 'std'])

# Map the model label to its plotting colour
df_new_run['Color'] = df_new_run['Unnamed: 0'].map(color_dict)

df_new_run

In [None]:
# Function to generate color palette based on sorted model order
def get_palette(df, col_name, colors=None):
    """Return box colours in the order seaborn will draw the categories.

    Seaborn orders string categories by first appearance in the data, so the
    models are taken in the order they first occur after sorting ``df`` by
    ``col_name``.

    Args:
        df: Frame with a ``'Unnamed: 0'`` column holding model labels.
        col_name: Column the plotted data is sorted by.
        colors: Optional mapping from model label to colour; defaults to the
            notebook-wide ``color_dict``.

    Returns:
        List of colours, one per distinct model, in drawing order.

    Raises:
        KeyError: if a model label has no entry in ``colors``.
    """
    if colors is None:
        colors = color_dict
    ordered_models = df.sort_values(col_name)['Unnamed: 0'].unique()
    return [colors[model] for model in ordered_models]

# Micro and macro F1 box plots over all HLCA runs. Colours are derived per
# metric because box order follows the score-sorted data.
for metric, ylabel in (('f1-score: accuracy', 'Micro F1 Score'), ('f1-score: macro avg', 'Macro F1 Score')):
    plt.figure(figsize=(3, 2))
    palette = get_palette(df_new_run, metric)
    ax = sns.boxplot(x='Unnamed: 0', y=metric, data=df_new_run.sort_values(metric), linewidth=0.5, palette=palette)
    ax.set_xlabel('Model', fontdict=font)
    ax.set_ylabel(ylabel, fontdict=font)
    ax.set_title('HLCA Classification Performance', fontdict=font)
    ax.set_xticklabels(ax.get_xticklabels(), **tick_font)
    ax.set_yticklabels(ax.get_yticklabels(), **tick_font)
    # Not saved: HLCA_Clf_Micro_F1.svg / HLCA_Clf_Macro_F1.svg are already
    # written by the 'new_run'-only figures above and would be overwritten.
    plt.tight_layout()
    plt.show()

Per-cell-type F1 vs. cell count (scatter with marginal histograms)

In [None]:
# Load the merged per-cell-type HLCA report (one row per cell type)
file_path = os.path.join(RESULTS_FOLDER, 'classification', 'val_clf_per_class_report_hlca_merged_knn.csv')
df = pd.read_csv(file_path)

# Remove exact duplicate rows
df = df.drop_duplicates()

# Show only the first few rows to get an overview of the data
df.head()

In [None]:
# Per-cell-type F1 of both models against the number of cells of that type,
# with marginal histograms of cell counts (top) and F1 scores (right).
sns.set_palette("colorblind")
colors = sns.color_palette()

font = {'family': 'sans-serif', 'size': 5}

g = sns.JointGrid(x='Cell Count', y='F1_Supervised', data=df, height=3, marginal_ticks=True, space=0.2)
# Supervised points come from the grid's own y column; self-supervised points are overlaid
g = g.plot_joint(plt.scatter, s=6, color=colors[0], label="Supervised")
g.ax_joint.scatter(df['Cell Count'], df['F1_Self-Supervised'], s=5, color=colors[1], label="Self Supervised")

# Log-spaced bins for cell counts, linear bins on [0, 1] for F1
count_bins = np.geomspace(df['Cell Count'].min(), df['Cell Count'].max(), 20)
f1_bins = np.linspace(0, 1, 20)
g.ax_marg_x.hist(df['Cell Count'], bins=count_bins, alpha=.6, edgecolor='black', color=colors[2])
for f1_column, bar_color in (('F1_Supervised', colors[0]), ('F1_Self-Supervised', colors[1])):
    g.ax_marg_y.hist(df[f1_column], bins=f1_bins, alpha=.6, orientation='horizontal', edgecolor='black', color=bar_color)

g.set_axis_labels('Number of Cells per Cell Type (log scale)', 'F1-Score per Cell Type', **font)
g.ax_joint.set_xscale('log')
g.ax_joint.legend(prop=font)

# Uniform font for every tick label in the three panels
for ax in (g.ax_joint, g.ax_marg_x, g.ax_marg_y):
    for label in [*ax.get_xticklabels(), *ax.get_yticklabels()]:
        label.set_fontsize(font['size'])
        label.set_family(font['family'])

plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/HLCA_per_celltype_perf.svg")
plt.show()


In [None]:
# Per-cell-type F1 difference: positive means the self-supervised model is better
df['F1_Difference'] = df['F1_Self-Supervised'] - df['F1_Supervised']

# Split cell types by the sign of the difference
positive_diff = df[df['F1_Difference'] > 0]
negative_diff = df[df['F1_Difference'] < 0]
# Ties. Despite the name, this includes ties at any F1 value (the non-zero
# restriction is commented out). Ties are not plotted in this figure.
equal_non_zero_diff = df[(df['F1_Difference'] == 0)]  # & (df['F1_Self-Supervised'] != 0) & (df['F1_Supervised'] != 0)]

# JointGrid: scatter plus a cell-count histogram on top only
g = sns.JointGrid(x='Cell Count', y='F1_Difference', data=df, height=2.5, marginal_ticks=True, space=0.2)

# Points where Self-Supervised is better (positive difference)
g.ax_joint.scatter(positive_diff['Cell Count'], positive_diff['F1_Difference'],
                   s=5, color=color_ssl, label="Self-Supervised Better")

# Points where Supervised is better (negative difference)
g.ax_joint.scatter(negative_diff['Cell Count'], negative_diff['F1_Difference'],
                   s=5, color=color_supervised, label="Supervised Better")

# Histogram on the top
g.ax_marg_x.hist(df['Cell Count'], bins=np.geomspace(df['Cell Count'].min(), df['Cell Count'].max(), 20),
                 alpha=.6, edgecolor='black', color='grey')
# The right-hand marginal is never filled; hide it instead of showing empty axes
g.ax_marg_y.set_visible(False)

# The y values are per-cell-type F1 differences, not a macro-averaged F1
g.set_axis_labels('Number of Cells per Cell Type (log scale)', 'Performance Difference (Δ F1 per Cell Type)', **font)
g.ax_joint.set_xscale('log')

# Adjust legend with font properties
g.ax_joint.legend(prop=font)

# Apply font properties to all tick labels
for label in g.ax_joint.get_xticklabels() + g.ax_joint.get_yticklabels():
    label.set_fontsize(font['size'])
    label.set_family(font['family'])

plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/HLCA_Per_CT_Difference.svg", bbox_inches='tight')  # Save as SVG
plt.show()


Per-class prediction differences and confusion matrices

In [None]:
# Predicted labels of one supervised and one self-supervised HLCA run, plus the true labels
clf_dir = os.path.join(RESULTS_FOLDER, 'classification')
y_pred_supervised = np.load(os.path.join(clf_dir, 'new_predicted_labels_No_SSL_new_run1_HLCA_No SSL.npy'))
y_pred_ssl = np.load(os.path.join(clf_dir, 'new_predicted_labels_SSL_CN_MLP_50prun4_HLCA_SSL.npy'))
# True labels are the same for both runs (deterministic data loader)
y_true = np.load(os.path.join(clf_dir, 'new_true_labels_No_SSL_new_run1_HLCA_No SSL.npy'))


In [None]:
def _per_class_correct(y_pred, y_ref, classes):
    """Return the 0/1 hit vector and the number of hits for each class in `classes`."""
    hits = np.equal(y_pred, y_ref).astype(int)
    return hits, [np.sum(hits[y_ref == cls]) for cls in classes]


# Correct predictions per true class for both models, and their difference
unique_classes = np.unique(y_true)
correct_supervised, correct_counts_supervised = _per_class_correct(y_pred_supervised, y_true, unique_classes)
correct_ssl, correct_counts_ssl = _per_class_correct(y_pred_ssl, y_true, unique_classes)
differences = np.array(correct_counts_ssl) - np.array(correct_counts_supervised)

In [None]:
# Per-class change in correct predictions (self-supervised minus supervised),
# ordered like `unique_classes`
differences

In [None]:
# Cell-type lookup table with title-cased labels for display.
# NOTE(review): presumably indexed by the integer class codes used in y_true — confirm.
cell_type_mapping = pd.read_parquet(os.path.join(STORE_DIR, "categorical_lookup/cell_type.parquet"))
cell_type_mapping = cell_type_mapping.assign(label=cell_type_mapping['label'].str.title())


In [None]:
# One row per true class with its change in correct predictions
df_plot = pd.DataFrame({
    'Class': unique_classes,
    'Difference': differences
})

# Keep the n classes with the largest absolute change
n = 6
df_plot['Absolute Difference'] = df_plot['Difference'].abs()
# .copy(): head() returns a slice of df_plot and columns are added to it below
df_top_n = df_plot.sort_values(by='Absolute Difference', ascending=False).head(n).copy()

# Colour by the sign of the difference (zero counts as supervised)
df_top_n['Color'] = df_top_n['Difference'].apply(lambda x: color_ssl if x > 0 else color_supervised)

# Map integer class codes to readable names.
# NOTE(review): assumes the row index of cell_type_mapping equals the class
# code stored in y_true — confirm against the data store.
label_to_name_dict = cell_type_mapping['label'].to_dict()
df_top_n['Class'] = df_top_n['Class'].map(label_to_name_dict)

# Bar colours come from the data, not a hand-written list, so they stay correct
# when the predictions (and hence the top-n classes) change. Keyed by class
# name so colours follow the bars regardless of category order.
bar_palette = dict(zip(df_top_n['Class'], df_top_n['Color']))

plt.figure(figsize=(2.5, 2.5))
ax = sns.barplot(x='Class', y='Difference', data=df_top_n, palette=bar_palette)
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, **tick_font, ha='right')
ax.set_yticklabels(ax.get_yticklabels(), rotation=45, **tick_font, ha='right')
ax.set_xlabel('Class', fontdict=font)
ax.set_ylabel('Δ Correct Predictions', fontdict=font)
ax.set_title('HLCA Cell Type Prediction Difference', fontdict=font)

# Annotate each bar with its integer difference
for p in ax.patches:
    ax.annotate(f"{int(p.get_height())}", (p.get_x() + p.get_width() / 2., p.get_height()),
                ha='center', va='bottom', fontsize=font['size'])

# Legend explaining the two bar colours
ssl_patch = mpatches.Patch(color=color_ssl, label='Self-Supervised Better')
supervised_patch = mpatches.Patch(color=color_supervised, label='Supervised Better')
ax.legend(handles=[ssl_patch, supervised_patch], loc='upper right', prop=font)

plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/hlca_biggest_difference_barplot.svg", bbox_inches='tight')  # Save as SVG
plt.show()

In [None]:
# Confusion matrices of the same two HLCA runs (first CSV column holds the row labels)
conf_matrix_supervised, conf_matrix_ssl = (
    pd.read_csv(os.path.join(RESULTS_FOLDER, 'classification', csv_name), index_col='Unnamed: 0')
    for csv_name in ('conf_matrix_No_SSL_new_run1_HLCA_No SSL.csv', 'conf_matrix_SSL_CN_MLP_50prun4_HLCA_SSL.csv')
)

In [None]:
import matplotlib.colors as mcolors

# Difference of the two confusion matrices (positive: more counts for the
# self-supervised model). `sub(..., fill_value=0)` aligns on labels and treats
# a class missing from one matrix as zero counts instead of producing NaN.
conf_matrix_difference = conf_matrix_ssl.sub(conf_matrix_supervised, fill_value=0).fillna(0)

# Identify the top N true classes with the largest total absolute change in their row
N = 5  # Number of top differences to display
top_differences = conf_matrix_difference.abs().sum(axis=1).nlargest(N).index

# Square sub-matrix over those classes (used as both true and predicted labels)
conf_matrix_subset = conf_matrix_difference.loc[top_differences, top_differences]

# Title-case the row and column labels for display
conf_matrix_subset.columns = conf_matrix_subset.columns.str.title()
conf_matrix_subset.index = conf_matrix_subset.index.str.title()

# Diverging colormap: supervised colour -> white -> self-supervised colour
top = mcolors.to_rgba(color_ssl)
bottom = mcolors.to_rgba(color_supervised)
custom_colormap = mcolors.LinearSegmentedColormap.from_list("custom_map", [bottom, "white", top])

# Symmetric limits so that white means "no change"; `or 1` avoids vmin == vmax for an all-zero subset
max_abs_value = np.abs(conf_matrix_subset.values).max() or 1
vmin, vmax = -max_abs_value, max_abs_value

# Create heatmap without annotations for the subset difference matrix
plt.figure(figsize=(1.5, 1.2))  # Adjust figure size as needed for better visibility
ax = sns.heatmap(conf_matrix_subset, annot=False, cmap=custom_colormap, linewidths=.5, vmin=vmin, vmax=vmax)

# Set the font for the tick labels
ax.set_xticklabels(conf_matrix_subset.columns, **tick_font, rotation=45, ha='right')
ax.set_yticklabels(conf_matrix_subset.index, **tick_font, rotation=0)

# Set the font for the axis labels and title
ax.set_xlabel('Predicted Label', fontdict=font)
ax.set_ylabel('True Label', fontdict=font)
ax.set_title('HLCA Performance Difference', fontdict=font)

# Font for the colour bar tick labels
cbar = ax.collections[0].colorbar
cbar.ax.tick_params(labelsize=6)
for label in cbar.ax.get_yticklabels():
    label.set_fontname('sans-serif')
    label.set_fontsize(6)

plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/hlca_biggest_difference_confusion_matrix.svg", bbox_inches='tight')  # Save as SVG
plt.show()


In [None]:
import matplotlib.colors as mcolors

# Difference of the two confusion matrices, aligned on labels (missing classes count as zero)
conf_matrix_difference = conf_matrix_ssl.sub(conf_matrix_supervised, fill_value=0).fillna(0)

# Diverging colormap: supervised colour -> white -> self-supervised colour
top = mcolors.to_rgba(color_ssl)
bottom = mcolors.to_rgba(color_supervised)
custom_colormap = mcolors.LinearSegmentedColormap.from_list("custom_map", [bottom, "white", top])

# Symmetric limits so that white (the colormap midpoint) means "no change".
# Without them the midpoint sits halfway between the data min and max.
max_abs_value = np.abs(conf_matrix_difference.values).max() or 1
vmin, vmax = -max_abs_value, max_abs_value

# Create heatmap without annotations for the difference matrix
plt.figure(figsize=(10, 8))  # Adjust figure size as needed for better visibility
ax = sns.heatmap(conf_matrix_difference, annot=False, cmap=custom_colormap, linewidths=.5, vmin=vmin, vmax=vmax)

# Set the font for the tick labels
ax.set_xticklabels(conf_matrix_difference.columns, **tick_font, rotation=45, ha='right')
ax.set_yticklabels(conf_matrix_difference.index, **tick_font, rotation=0)

# Set the font for the axis labels and title
ax.set_xlabel('Predicted Label', fontdict=font)
ax.set_ylabel('True Label', fontdict=font)
ax.set_title('HLCA Performance Difference: Self-Supervised vs Supervised Model\n(Positive: More Counts Self-Supervised, Negative: More Counts Supervised)', fontdict=font)

# Font for the colour bar tick labels
cbar = ax.collections[0].colorbar
cbar.ax.tick_params(labelsize=6)
for label in cbar.ax.get_yticklabels():
    label.set_fontname('sans-serif')
    label.set_fontsize(6)

plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/hlca_difference_confusion_matrix.svg", bbox_inches='tight')  # Save as SVG
plt.show()


In [None]:
# Re-apply the notebook-wide seaborn theme. A bare `sns.set_theme()` would
# switch to the default 'darkgrid' style and leak into every later figure,
# unlike the 'whitegrid' style set in the configuration cell.
sns.set_theme(style="whitegrid")
sns.set_palette("colorblind")

# Raw confusion matrix of the supervised model (`font`/`tick_font` from the configuration cell)
plt.figure(figsize=(5, 4))  # Adjust figure size as needed
ax = sns.heatmap(conf_matrix_supervised, annot=False, cmap='viridis', linewidths=.5)

# Set the font for the tick labels
ax.set_xticklabels(ax.get_xticklabels(), **tick_font, rotation=45, ha='right')
ax.set_yticklabels(ax.get_yticklabels(), **tick_font, rotation=0)

# Set the font for the axis labels and title
ax.set_xlabel('Predicted Label', fontdict=font)
ax.set_ylabel('True Label', fontdict=font)
ax.set_title('HLCA Confusion Matrix Supervised Model', fontdict=font)

# Colour bar tick labels: 5 pt sans-serif, matching the rest of the figure
cbar = ax.collections[0].colorbar
cbar.ax.tick_params(labelsize=5)
for label in cbar.ax.get_yticklabels():
    label.set_fontname('sans-serif')

plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/confusion_matrix_hlca_supervised.svg", bbox_inches='tight')  # Save as SVG
plt.show()


In [None]:
# Re-apply the notebook-wide seaborn theme. A bare `sns.set_theme()` would
# switch to the default 'darkgrid' style and leak into every later figure,
# unlike the 'whitegrid' style set in the configuration cell.
sns.set_theme(style="whitegrid")
sns.set_palette("colorblind")

# Raw confusion matrix of the self-supervised model (`font`/`tick_font` from the configuration cell)
plt.figure(figsize=(5, 4))  # Adjust figure size as needed
ax = sns.heatmap(conf_matrix_ssl, annot=False, cmap='viridis', linewidths=.5)

# Set the font for the tick labels
ax.set_xticklabels(ax.get_xticklabels(), **tick_font, rotation=45, ha='right')
ax.set_yticklabels(ax.get_yticklabels(), **tick_font, rotation=0)

# Set the font for the axis labels and title
ax.set_xlabel('Predicted Label', fontdict=font)
ax.set_ylabel('True Label', fontdict=font)
ax.set_title('HLCA Confusion Matrix Self-Supervised Model', fontdict=font)

# Colour bar tick labels: 5 pt sans-serif, matching the rest of the figure
cbar = ax.collections[0].colorbar
cbar.ax.tick_params(labelsize=5)
for label in cbar.ax.get_yticklabels():
    label.set_fontname('sans-serif')

plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/confusion_matrix_hlca_ssl.svg", bbox_inches='tight')  # Save as SVG
plt.show()


### 2. PBMC

In [None]:
# Load the PBMC kNN classification report into a DataFrame
file_path = os.path.join(RESULTS_FOLDER, 'classification', 'val_clf_report_pbmc_knn.csv')
df = pd.read_csv(file_path)

# Remove exact duplicate rows
df = df.drop_duplicates()

# Run identifiers to exclude (values of the 'Unnamed: 0' column, not the DataFrame index).
# NOTE(review): the reason these runs are excluded is not recorded here — document it.
rows_to_drop = ['No_SSL_new_run5_PBMC_No SSL', 'No_SSL_new_run2_PBMC_No SSL', 'No_SSL_run0_PBMC_No SSL']
df = df[~df['Unnamed: 0'].isin(rows_to_drop)]

# Show only the first few rows to get an overview of the data
df.head()

In [None]:
# Curated PBMC runs, built up group by group: zero-shot, supervised, random baseline, self-supervised
models_to_select = ['CN_MLP_50p_Only Pretrained']
models_to_select += [f'CN_MLP_50prun{i}_Only Pretrained' for i in range(1, 5)]
models_to_select += [f'No_SSL_new_run{i}_PBMC_No SSL' for i in range(5)]
models_to_select.append('Random')
models_to_select += [f'SSL_CN_MLP_50pnew_run{i}_PBMC_SSL' for i in range(5)]

df_new_run = df[df['Unnamed: 0'].isin(models_to_select)]
df_new_run

In [None]:
# Mean of five hand-copied scores.
# NOTE(review): values are hard-coded; their source rows are not recorded — confirm.
np.array([0.657838, 0.681130, 0.727992, 0.702843, 0.674552]).mean()

In [None]:
# Step 1: Filter to include only 'new_run' entries. Copy, so that columns
# assigned later modify an independent frame rather than a slice of `df`.
# NOTE(review): this replaces the curated `models_to_select` subset displayed
# above; identifiers without 'new_run' (e.g. 'Random', 'CN_MLP_50p_Only Pretrained')
# are dropped here — confirm that is intended.
df_new_run = df[df['Unnamed: 0'].str.contains('new_run')].copy()

# Step 2: Rename model types
def rename_model(row):
    """Map a PBMC run identifier to its figure label.

    The first matching substring wins: 'PBMC_SSL' (self-supervised),
    'No SSL' (supervised), 'Only Pretrained' (zero-shot), 'Random'.
    Unmatched identifiers are returned as-is. Redefines the earlier
    HLCA versions of ``rename_model``.
    """
    ordered_markers = [
        ('PBMC_SSL', 'Self-Supervised\nRandom Mask'),
        ('No SSL', 'Supervised'),
        ('Only Pretrained', 'Zero-Shot\nRandom Mask'),
        ('Random', 'Random'),
    ]
    for needle, pretty_name in ordered_markers:
        if needle in row:
            return pretty_name
    return row

# Replace raw run identifiers with display labels. `assign` returns a new
# frame, so this never writes through to `df` (no SettingWithCopyWarning).
df_new_run = df_new_run.assign(**{'Unnamed: 0': df_new_run['Unnamed: 0'].apply(rename_model)})

# Step 3: Calculate mean and std for each model. Several columns must be
# selected with a list; the tuple form `gb['a', 'b']` raises in pandas >= 2.0.
mean_std_df = df_new_run.groupby('Unnamed: 0')[['f1-score: macro avg', 'f1-score: accuracy']].agg(['mean', 'std'])

# Step 4: Create box plots
sns.set_palette("colorblind")

# Colour boxes by model name rather than by position. Box order follows the
# score-sorted data, so a fixed colour list would not track the models.
model_palette = {model: color_dict.get(model, 'grey') for model in df_new_run['Unnamed: 0'].unique()}

# One figure per metric; `font` and `tick_font` come from the configuration cell
for metric, ylabel, filename in [
    ('f1-score: accuracy', 'Micro F1 Score', 'PBMC_Clf_Micro_F1.svg'),
    ('f1-score: macro avg', 'Macro F1 Score', 'PBMC_Clf_Macro_F1.svg'),
]:
    plt.figure(figsize=(3, 2))
    ax = sns.boxplot(x='Unnamed: 0', y=metric, data=df_new_run.sort_values(metric), linewidth=0.5, palette=model_palette)
    ax.set_xlabel('Model', fontdict=font)
    ax.set_ylabel(ylabel, fontdict=font)
    ax.set_title('PBMC Classification Performance', fontdict=font)
    ax.set_xticklabels(ax.get_xticklabels(), **tick_font)
    ax.set_yticklabels(ax.get_yticklabels(), **tick_font)
    # Lay out before saving so the SVG matches the figure shown inline
    plt.tight_layout()
    plt.savefig(os.path.join(RESULTS_FOLDER, 'classification', filename), bbox_inches='tight')  # Save as SVG
    plt.show()


Per-cell-type F1 vs. cell count (scatter with marginal histograms)

In [None]:
# Load the merged per-cell-type PBMC report (one row per cell type)
file_path = os.path.join(RESULTS_FOLDER, 'classification', 'val_clf_per_class_report_pbmc_merged_knn.csv')
df = pd.read_csv(file_path)

# Remove exact duplicate rows
df = df.drop_duplicates()

# Show only the first few rows to get an overview of the data
df.head()

In [None]:
# PBMC: per-cell-type F1 of both models vs. cell-type size, with marginal
# histograms of cell counts (top) and F1 scores (right).
sns.set_palette("colorblind")
colors = sns.color_palette()

font = {'family': 'sans-serif', 'size': 5}

g = sns.JointGrid(x='Cell Count', y='F1_Supervised', data=df, height=3, marginal_ticks=True, space=0.2)
g = g.plot_joint(plt.scatter, s=6, color=colors[0], label="Supervised")
g.ax_joint.scatter(df['Cell Count'], df['F1_Self-Supervised'], s=6, color=colors[1], label="Self Supervised")

# Log-spaced bins for cell counts, linear bins on [0, 1] for F1
count_edges = np.geomspace(df['Cell Count'].min(), df['Cell Count'].max(), 20)
score_edges = np.linspace(0, 1, 20)
g.ax_marg_x.hist(df['Cell Count'], bins=count_edges, alpha=.6, edgecolor='black', color=colors[2])
for score_col, fill in (('F1_Supervised', colors[0]), ('F1_Self-Supervised', colors[1])):
    g.ax_marg_y.hist(df[score_col], bins=score_edges, alpha=.6, orientation='horizontal', edgecolor='black', color=fill)

g.set_axis_labels('Number of Cells per Cell Type (log scale)', 'F1-Score per Cell Type', **font)
g.ax_joint.set_xscale('log')
g.ax_joint.legend(prop=font)

# Same tick-label font in all three panels
for ax in (g.ax_joint, g.ax_marg_x, g.ax_marg_y):
    for label in [*ax.get_xticklabels(), *ax.get_yticklabels()]:
        label.set_fontsize(font['size'])
        label.set_family(font['family'])

plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/PBMC_per_celltype_perf.svg")
plt.show()


In [None]:
# Per-cell-type F1 difference: positive means the self-supervised model is better
df['F1_Difference'] = df['F1_Self-Supervised'] - df['F1_Supervised']

# Split cell types by the sign of the difference
positive_diff = df[df['F1_Difference'] > 0]
negative_diff = df[df['F1_Difference'] < 0]
# Ties. Despite the name, this includes ties at any F1 value (the non-zero
# restriction is commented out), so they are labelled simply as equal below.
equal_non_zero_diff = df[(df['F1_Difference'] == 0)]  # & (df['F1_Self-Supervised'] != 0) & (df['F1_Supervised'] != 0)]

# JointGrid: scatter plus a cell-count histogram on top only
g = sns.JointGrid(x='Cell Count', y='F1_Difference', data=df, height=2.5, marginal_ticks=True, space=0.2)

# Points where Self-Supervised is better (positive difference)
g.ax_joint.scatter(positive_diff['Cell Count'], positive_diff['F1_Difference'],
                   s=5, color=color_ssl, label="Self-Supervised Better")

# Points where Supervised is better (negative difference)
g.ax_joint.scatter(negative_diff['Cell Count'], negative_diff['F1_Difference'],
                   s=5, color=color_supervised, label="Supervised Better")

# Points where both models score the same (at any F1 value, not only 0)
g.ax_joint.scatter(equal_non_zero_diff['Cell Count'], equal_non_zero_diff['F1_Difference'],
                   s=5, color=color_else2, label="Equal Performance")

# Histogram on the top
g.ax_marg_x.hist(df['Cell Count'], bins=np.geomspace(df['Cell Count'].min(), df['Cell Count'].max(), 20),
                 alpha=.6, edgecolor='black', color='grey')
# The right-hand marginal is never filled; hide it instead of showing empty axes
g.ax_marg_y.set_visible(False)

# The y values are per-cell-type F1 differences, not a macro-averaged F1
g.set_axis_labels('Number of Cells per Cell Type (log scale)', 'Performance Difference (Δ F1 per Cell Type)', **font)
g.ax_joint.set_xscale('log')

# Adjust legend with font properties
g.ax_joint.legend(prop=font)

# Apply font properties to all tick labels
for label in g.ax_joint.get_xticklabels() + g.ax_joint.get_yticklabels():
    label.set_fontsize(font['size'])
    label.set_family(font['family'])

plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/PBMC_Per_CT_Difference.svg", bbox_inches='tight')  # Save as SVG
plt.show()


Per-class prediction differences and confusion matrices

In [None]:
# Predicted labels of one supervised and one self-supervised PBMC run, plus the true labels.
# NOTE(review): the supervised file is 'No_SSL_run2' (no 'new_run'), while the
# self-supervised file is 'new_run2' — confirm this pairing is intended.
clf_dir = os.path.join(RESULTS_FOLDER, 'classification')
y_pred_supervised = np.load(os.path.join(clf_dir, 'new_predicted_labels_No_SSL_run2_PBMC_No SSL.npy'))
y_pred_ssl = np.load(os.path.join(clf_dir, 'new_predicted_labels_SSL_CN_MLP_50pnew_run2_PBMC_SSL.npy'))
# True labels are the same for both runs (deterministic data loader)
y_true = np.load(os.path.join(clf_dir, 'new_true_labels_No_SSL_run2_PBMC_No SSL.npy'))


In [None]:
# Sanity check: the distinct labels predicted by the self-supervised PBMC run
np.unique(y_pred_ssl)

In [None]:
# Boolean mask per true class, computed once and reused for both models
unique_classes = np.unique(y_true)
class_masks = [y_true == cls for cls in unique_classes]

# 0/1 vectors marking correct predictions of each model
correct_supervised = np.equal(y_pred_supervised, y_true).astype(int)
correct_ssl = np.equal(y_pred_ssl, y_true).astype(int)

# Correct predictions per class, and the self-supervised minus supervised difference
correct_counts_supervised = [np.sum(correct_supervised[mask]) for mask in class_masks]
correct_counts_ssl = [np.sum(correct_ssl[mask]) for mask in class_masks]
differences = np.array(correct_counts_ssl) - np.array(correct_counts_supervised)

In [None]:
# Cell-type lookup with title-cased labels (reloaded so this section runs on its own)
cell_type_mapping = (
    pd.read_parquet(os.path.join(STORE_DIR, "categorical_lookup/cell_type.parquet"))
    .assign(label=lambda table: table['label'].str.title())
)


In [None]:
# One row per true class with its change in correct predictions
df_plot = pd.DataFrame({
    'Class': unique_classes,
    'Difference': differences
})

# Keep the n classes with the largest absolute change
n = 6
df_plot['Absolute Difference'] = df_plot['Difference'].abs()
# .copy(): head() returns a slice of df_plot and columns are added to it below
df_top_n = df_plot.sort_values(by='Absolute Difference', ascending=False).head(n).copy()

# Colour by the sign of the difference (zero counts as supervised)
df_top_n['Color'] = df_top_n['Difference'].apply(lambda x: color_ssl if x > 0 else color_supervised)

# Map integer class codes to readable names.
# NOTE(review): assumes the row index of cell_type_mapping equals the class
# code stored in y_true — confirm against the data store.
label_to_name_dict = cell_type_mapping['label'].to_dict()
df_top_n['Class'] = df_top_n['Class'].map(label_to_name_dict)

# Bar colours come from the data; a hard-coded all-SSL list would colour any
# negative (supervised-better) bar incorrectly. Keyed by class name so colours
# follow the bars regardless of category order.
bar_palette = dict(zip(df_top_n['Class'], df_top_n['Color']))

plt.figure(figsize=(2.5, 1.5))
ax = sns.barplot(x='Class', y='Difference', data=df_top_n, palette=bar_palette)
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, **tick_font, ha='right')
ax.set_yticklabels(ax.get_yticklabels(), rotation=45, **tick_font, ha='right')
ax.set_xlabel('Class', fontdict=font)
ax.set_ylabel('Δ Correct Predictions', fontdict=font)
ax.set_title('PBMC Cell Type Prediction Difference', fontdict=font)

# Annotate each bar with its integer difference
for p in ax.patches:
    ax.annotate(f"{int(p.get_height())}", (p.get_x() + p.get_width() / 2., p.get_height()),
                ha='center', va='bottom', fontsize=font['size'])

# Legend explaining the two bar colours
ssl_patch = mpatches.Patch(color=color_ssl, label='Self-Supervised Better')
supervised_patch = mpatches.Patch(color=color_supervised, label='Supervised Better')
ax.legend(handles=[ssl_patch, supervised_patch], loc='upper right', prop=font)

plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/pbmc_biggest_difference_barplot.svg", bbox_inches='tight')  # Save as SVG
plt.show()

In [None]:
import matplotlib.colors as mcolors  # imported here too so the cell does not depend on an import elsewhere

# NOTE(review): the PBMC cells after the "Confusion Matrix" heading never read a confusion matrix,
# so without these reads the difference below silently reuses `conf_matrix_*` left over from an
# earlier section. The file names mirror the PBMC prediction files (run2) and the
# `conf_matrix_<run>.csv` convention used for Tabula Sapiens — confirm they exist.
classification_dir = os.path.join(RESULTS_FOLDER, 'classification')
conf_matrix_supervised = pd.read_csv(os.path.join(classification_dir, 'conf_matrix_No_SSL_run2_PBMC_No SSL.csv'), index_col='Unnamed: 0')
conf_matrix_ssl = pd.read_csv(os.path.join(classification_dir, 'conf_matrix_SSL_CN_MLP_50pnew_run2_PBMC_SSL.csv'), index_col='Unnamed: 0')

# Positive cells: more counts for the self-supervised model; negative: more for the supervised one
conf_matrix_difference = conf_matrix_ssl - conf_matrix_supervised

# Diverging colormap: supervised colour -> white -> self-supervised colour
top = mcolors.to_rgba(color_ssl)
bottom = mcolors.to_rgba(color_supervised)
custom_colormap = mcolors.LinearSegmentedColormap.from_list("custom_map", [bottom, "white", top])

# Symmetric limits so white sits exactly at 0 (no change). Without them seaborn scales to
# [min, max] and white lands at that range's midpoint, which is generally not 0.
max_abs_value = np.nanmax(np.abs(conf_matrix_difference.values))

plt.figure(figsize=(10, 8))  # large figure so every class label stays readable
ax = sns.heatmap(conf_matrix_difference, annot=False, cmap=custom_colormap, linewidths=.5,
                 vmin=-max_abs_value, vmax=max_abs_value)

# Tick labels
ax.set_xticklabels(conf_matrix_difference.columns, **tick_font, rotation=45, ha='right')
ax.set_yticklabels(conf_matrix_difference.index, **tick_font, rotation=0)

# Axis labels and title
ax.set_xlabel('Predicted Label', fontdict=font)
ax.set_ylabel('True Label', fontdict=font)
ax.set_title('PBMC Performance Difference: Self-Supervised vs Supervised Model\n(Positive: More Counts Self-Supervised, Negative: More Counts Supervised)', fontdict=font)

# Colour-bar tick labels
cbar = ax.collections[0].colorbar
cbar.ax.tick_params(labelsize=6)
for label in cbar.ax.get_yticklabels():
    label.set_fontname('sans-serif')
    label.set_fontsize(6)

plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/pbmc_difference_confusion_matrix.svg", bbox_inches='tight')  # Save as SVG
plt.show()


In [None]:
# Heatmap of the supervised model's PBMC confusion matrix (rows: true label, columns: predicted label).
# NOTE(review): confirm `conf_matrix_supervised` holds the PBMC matrix; no PBMC confusion matrix is
# read directly after the PBMC "Confusion Matrix" heading.

# Re-apply the notebook's whitegrid theme. A bare sns.set_theme() resets the style to seaborn's
# default ("darkgrid"), silently changing every figure drawn after this cell.
sns.set_theme(style="whitegrid")
sns.set_palette("colorblind")

# Font properties for titles/labels and for tick labels
font = {'family': 'sans-serif', 'size': 5}
tick_font = {'fontsize': 5, 'fontname': 'sans-serif'}

plt.figure(figsize=(5, 4))
ax = sns.heatmap(conf_matrix_supervised, annot=False, cmap='viridis', linewidths=.5)

# Tick labels
ax.set_xticklabels(ax.get_xticklabels(), **tick_font, rotation=45, ha='right')
ax.set_yticklabels(ax.get_yticklabels(), **tick_font, rotation=0)

# Axis labels and title
ax.set_xlabel('Predicted Label', fontdict=font)
ax.set_ylabel('True Label', fontdict=font)
ax.set_title('PBMC Confusion Matrix Supervised Model', fontdict=font)

# Colour-bar ticks: tick_params and the per-label loop use the same 5 pt size
cbar = ax.collections[0].colorbar
cbar.ax.tick_params(labelsize=5)
for label in cbar.ax.get_yticklabels():
    label.set_fontname('sans-serif')
    label.set_fontsize(5)

plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/confusion_matrix_pbmc_supervised.svg", bbox_inches='tight')  # Save as SVG
plt.show()


In [None]:
# Heatmap of the self-supervised model's PBMC confusion matrix (rows: true label, columns: predicted label).
# NOTE(review): confirm `conf_matrix_ssl` holds the PBMC matrix; no PBMC confusion matrix is
# read directly after the PBMC "Confusion Matrix" heading.

# Re-apply the notebook's whitegrid theme. A bare sns.set_theme() resets the style to seaborn's
# default ("darkgrid"), silently changing every figure drawn after this cell.
sns.set_theme(style="whitegrid")
sns.set_palette("colorblind")

# Font properties for titles/labels and for tick labels
font = {'family': 'sans-serif', 'size': 5}
tick_font = {'fontsize': 5, 'fontname': 'sans-serif'}

plt.figure(figsize=(5, 4))
ax = sns.heatmap(conf_matrix_ssl, annot=False, cmap='viridis', linewidths=.5)

# Tick labels
ax.set_xticklabels(ax.get_xticklabels(), **tick_font, rotation=45, ha='right')
ax.set_yticklabels(ax.get_yticklabels(), **tick_font, rotation=0)

# Axis labels and title
ax.set_xlabel('Predicted Label', fontdict=font)
ax.set_ylabel('True Label', fontdict=font)
ax.set_title('PBMC Confusion Matrix Self-Supervised Model', fontdict=font)

# Colour-bar ticks: tick_params and the per-label loop use the same 5 pt size
cbar = ax.collections[0].colorbar
cbar.ax.tick_params(labelsize=5)
for label in cbar.ax.get_yticklabels():
    label.set_fontname('sans-serif')
    label.set_fontsize(5)

plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/confusion_matrix_pbmc_ssl.svg", bbox_inches='tight')  # Save as SVG
plt.show()


### 3. Tabula Sapiens

In [None]:
# kNN classification report for Tabula Sapiens, with exact duplicate rows dropped
file_path = os.path.join(RESULTS_FOLDER, 'classification', 'val_clf_report_tabula_sapiens_knn.csv')
df = pd.read_csv(file_path).drop_duplicates()

df

In [None]:
# Runs compared on Tabula Sapiens, grouped by model type
models_to_select = [
    # Zero-shot: pretrained encoder only
    'CN_MLP_50p_Only Pretrained',
    'CN_MLP_50prun1_Only Pretrained',
    'CN_MLP_50prun2_Only Pretrained',
    'CN_MLP_50prun3_Only Pretrained',
    'CN_MLP_50prun4_Only Pretrained',
    # Supervised
    'No_SSL_new_run0_Tabula_Sapiens_No SSL',
    'No_SSL_new_run1_Tabula_Sapiens_No SSL',
    'No_SSL_new_run2_Tabula_Sapiens_No SSL',
    'No_SSL_new_run3_Tabula_Sapiens_No SSL',
    'No_SSL_new_run4_Tabula_Sapiens_No SSL',
    # Baseline
    'Random',
    # Self-supervised
    'SSL_CN_MLP_50pnew_run0_Tabula_Sapiens_SSL',
    'SSL_CN_MLP_50pnew_run1_Tabula_Sapiens_SSL',
    'SSL_CN_MLP_50pnew_run2_Tabula_Sapiens_SSL',
    'SSL_CN_MLP_50pnew_run3_Tabula_Sapiens_SSL',
    'SSL_CN_MLP_50pnew_run4_Tabula_Sapiens_SSL',
]

df_new_run = df.loc[df['Unnamed: 0'].isin(models_to_select)]
df_new_run

In [None]:
# NOTE(review): hard-coded values — presumably five per-run scores read off the table above;
# TODO compute them from the DataFrame instead of copying numbers by hand
np.array([0.378728, 0.333734, 0.415839, 0.366523, 0.372547]).mean()

In [None]:
# Step 1: keep only the re-trained runs (model names containing 'new_run').
# NOTE(review): 'Random' and the '*_Only Pretrained' rows of `models_to_select` do not contain
# 'new_run', so they are dropped here and the Zero-Shot / Random branches of `rename_model` never
# fire. Filter with `df['Unnamed: 0'].isin(models_to_select)` if those baselines should be plotted.
# .copy() so the relabelling below writes to an independent frame (no SettingWithCopyWarning).
df_new_run = df[df['Unnamed: 0'].str.contains('new_run')].copy()


# Step 2: map raw run names onto the display labels used in `color_dict`
def rename_model(row):
    """Return the display label for a raw run name; unmatched names are returned unchanged."""
    if 'SSL_CN' in row:
        return 'Self-Supervised\nRandom Mask'
    elif 'No SSL' in row:
        return 'Supervised'
    elif 'Only Pretrained' in row:
        return 'Zero-Shot\nRandom Mask'
    elif 'Random' in row:
        return 'Random'
    else:
        return row


df_new_run['Unnamed: 0'] = df_new_run['Unnamed: 0'].apply(rename_model)

# Positional palette kept for any later cell that still refers to it
model_colors = [color_baseline, color_zeroshot, color_supervised, color_ssl]

# Label-keyed palette: each box gets its model type's colour regardless of which models survive
# the filter or in which order boxes are drawn. With the positional list above, the first box
# (e.g. 'Supervised') was painted color_baseline. Unknown labels fall back to a neutral colour.
model_palette = {label: color_dict.get(label, color_else1) for label in df_new_run['Unnamed: 0'].unique()}

# Step 3: mean and std per model (a list of columns; tuple selection raises in pandas >= 2.0)
mean_std_df = df_new_run.groupby('Unnamed: 0')[['f1-score: macro avg', 'f1-score: accuracy']].agg(['mean', 'std'])

# Step 4: box plots
sns.set_palette("colorblind")

# Font properties for titles/labels and for tick labels
font = {'family': 'sans-serif', 'size': 5}
tick_font = {'fontsize': 5, 'fontname': 'sans-serif'}


def plot_tabula_sapiens_boxplot(metric, ylabel, filename):
    """Draw one box per model for column `metric`, save it under RESULTS_FOLDER/classification, return the Axes."""
    plt.figure(figsize=(3, 2))
    # Sorting by the metric orders the boxes by first appearance, i.e. roughly by lowest score
    box_ax = sns.boxplot(x='Unnamed: 0', y=metric, data=df_new_run.sort_values(metric),
                         linewidth=0.5, palette=model_palette)
    box_ax.set_xlabel('Model', fontdict=font)
    box_ax.set_ylabel(ylabel, fontdict=font)
    box_ax.set_title('Tabula Sapiens Classification Performance', fontdict=font)
    box_ax.set_xticklabels(box_ax.get_xticklabels(), **tick_font)
    # tick_params instead of set_yticklabels: re-setting numeric tick labels pins them with a FixedFormatter
    box_ax.tick_params(axis='y', labelsize=tick_font['fontsize'])
    # Lay out before saving so the SVG matches what is shown
    plt.tight_layout()
    plt.savefig(RESULTS_FOLDER + "/classification/" + filename, bbox_inches='tight')  # Save as SVG
    plt.show()
    return box_ax


ax = plot_tabula_sapiens_boxplot('f1-score: accuracy', 'Micro F1 Score', 'Tabula_Sapiens_Clf_Micro_F1.svg')
ax = plot_tabula_sapiens_boxplot('f1-score: macro avg', 'Macro F1 Score', 'Tabula_Sapiens_Clf_Macro_F1.svg')


Histogram: Per-Cell-Type Performance vs. Cell Count

In [None]:
# Per-class kNN report for Tabula Sapiens (first CSV column as index), exact duplicates dropped
file_path = os.path.join(RESULTS_FOLDER, 'classification', 'val_clf_per_class_report_tabula_sapiens_merged_knn.csv')
df = pd.read_csv(file_path, index_col=0).drop_duplicates()

df

In [None]:
# Per cell type: change in F1 (self-supervised minus supervised) against the number of cells of that type
df['F1_Difference'] = df['F1_Self-Supervised'] - df['F1_Supervised']

# Split cell types by which model scores higher
positive_diff = df[df['F1_Difference'] > 0]
negative_diff = df[df['F1_Difference'] < 0]
# Ties: both models reach exactly the same F1. Despite the variable name this includes types
# where both scores are 0 (no non-zero filter is applied), hence the "Δ F1 = 0" legend entry.
equal_non_zero_diff = df[df['F1_Difference'] == 0]

# JointGrid; only the top marginal axis is used (the right one stays empty)
g = sns.JointGrid(x='Cell Count', y='F1_Difference', data=df, height=2.5, marginal_ticks=True, space=0.2)

# Scatter, coloured by the sign of the difference
g.ax_joint.scatter(positive_diff['Cell Count'], positive_diff['F1_Difference'],
                   s=5, color=color_ssl, label="Self-Supervised Better")
g.ax_joint.scatter(negative_diff['Cell Count'], negative_diff['F1_Difference'],
                   s=5, color=color_supervised, label="Supervised Better")
g.ax_joint.scatter(equal_non_zero_diff['Cell Count'], equal_non_zero_diff['F1_Difference'],
                   s=5, color=color_else2, label="Equal Performance (Δ F1 = 0)")

# Histogram of cell counts on the top axis, log-spaced bins to match the log x-scale
g.ax_marg_x.hist(df['Cell Count'], bins=np.geomspace(df['Cell Count'].min(), df['Cell Count'].max(), 20),
                 alpha=.6, edgecolor='black', color='grey')

# Labels & scale
g.set_axis_labels('Number of Cells per Cell Type (log scale)', 'Performance Difference (Δ Macro F1)', **font)
g.ax_joint.set_xscale('log')

# Legend with the notebook font
g.ax_joint.legend(prop=font)

# Apply font properties to all joint-axis tick labels
for label in g.ax_joint.get_xticklabels() + g.ax_joint.get_yticklabels():
    label.set_fontsize(font['size'])
    label.set_family(font['family'])

plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/Tabula_Sapiens_Per_CT_Difference.svg", bbox_inches='tight')  # Save as SVG
plt.show()


In [None]:
# Per-cell-type F1 of both models against the number of cells of that type,
# with marginal histograms of cell counts (top) and F1 scores (right).
sns.set_palette("colorblind")
colors = sns.color_palette()

# Font properties
font = {'family': 'sans-serif', 'size': 5}

# Named model colours (same RGB as colors[0]/colors[1] of the colorblind palette) keep this
# figure consistent with the other figures in the notebook
g = sns.JointGrid(x='Cell Count', y='F1_Supervised', data=df, height=3, marginal_ticks=True, space=0.2)
g = g.plot_joint(plt.scatter, s=6, color=color_supervised, label="Supervised")
g.ax_joint.scatter(df['Cell Count'], df['F1_Self-Supervised'], s=6, color=color_ssl, label="Self-Supervised")

# Histograms: log-spaced cell-count bins on top, F1 in [0, 1] on the right
g.ax_marg_x.hist(df['Cell Count'], bins=np.geomspace(df['Cell Count'].min(), df['Cell Count'].max(), 20), alpha=.6, edgecolor='black', color=colors[2])
g.ax_marg_y.hist(df['F1_Supervised'], bins=np.linspace(0, 1, 20), alpha=.6, orientation='horizontal', edgecolor='black', color=color_supervised)
g.ax_marg_y.hist(df['F1_Self-Supervised'], bins=np.linspace(0, 1, 20), alpha=.6, orientation='horizontal', edgecolor='black', color=color_ssl)

# Labels & scale
g.set_axis_labels('Number of Cells per Cell Type (log scale)', 'F1-Score per Cell Type', **font)
g.ax_joint.set_xscale('log')

# Legend with the notebook font
g.ax_joint.legend(prop=font)

# Apply font properties to all tick labels
for ax in [g.ax_joint, g.ax_marg_x, g.ax_marg_y]:
    for label in ax.get_xticklabels() + ax.get_yticklabels():
        label.set_fontsize(font['size'])
        label.set_family(font['family'])

# Save with a tight bounding box, like every other figure in this notebook
plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/Tabula_Sapiens_per_celltype_perf.svg", bbox_inches='tight')
plt.show()


Confusion Matrix

In [None]:
# Per-cell predictions: supervised run4 and self-supervised run0 (per the file names)
classification_dir = os.path.join(RESULTS_FOLDER, 'classification')
y_pred_supervised = np.load(os.path.join(classification_dir, 'new_predicted_labels_No_SSL_new_run4_Tabula_Sapiens_No SSL.npy'))
y_pred_ssl = np.load(os.path.join(classification_dir, 'new_predicted_labels_SSL_CN_MLP_50pnew_run0_Tabula_Sapiens_SSL.npy'))
# Ground truth: identical for the SSL run because the data loader is deterministic
y_true = np.load(os.path.join(classification_dir, 'new_true_labels_No_SSL_new_run4_Tabula_Sapiens_No SSL.npy'))


In [None]:
# Sanity check: which class ids does the self-supervised model predict at all?
predicted_classes_ssl = np.unique(y_pred_ssl)
predicted_classes_ssl

In [None]:
# 1 where a model's prediction matches the ground truth, 0 otherwise
correct_supervised = (y_pred_supervised == y_true).astype(int)
correct_ssl = (y_pred_ssl == y_true).astype(int)

# Number of correctly classified cells per true class, in np.unique order
unique_classes = np.unique(y_true)
correct_counts_supervised = [correct_supervised[y_true == label].sum() for label in unique_classes]
correct_counts_ssl = [correct_ssl[y_true == label].sum() for label in unique_classes]

# Positive entries: the self-supervised model gets more cells of that class right
differences = np.subtract(correct_counts_ssl, correct_counts_supervised)

In [None]:
# Cell-type lookup table; its index is later used as the integer class id and
# its 'label' column is title-cased here for plot labels
cell_type_mapping = pd.read_parquet(
    os.path.join(STORE_DIR, "categorical_lookup/cell_type.parquet")
).pipe(lambda lookup: lookup.assign(label=lookup['label'].str.title()))


In [None]:
# How many test cells have true class id 160 (the class inspected in the next two cells)?
np.count_nonzero(y_true == 160)

In [None]:
# Correct supervised predictions among cells of true class 160.
# NOTE(review): the name suggests class 160 is a pneumocyte type — verify via cell_type_mapping.loc[160]
correct_pneumocyte_supervised = [correct_supervised[y_true == 160].sum()]
correct_pneumocyte_supervised

In [None]:
# Correct self-supervised predictions among cells of true class 160.
# NOTE(review): the name suggests class 160 is a pneumocyte type — verify via cell_type_mapping.loc[160]
correct_pneumocyte_ssl = [correct_ssl[y_true == 160].sum()]
correct_pneumocyte_ssl

In [None]:
# Bar plot of the n classes whose number of correct predictions differs most
# between the self-supervised and the supervised Tabula Sapiens model.
df_plot = pd.DataFrame({
    'Class': unique_classes,
    'Difference': differences
})

# Sort by absolute difference and select top n classes
n = 6  # number of classes to show
df_plot['Absolute Difference'] = df_plot['Difference'].abs()
# .copy() so the column assignments below write to an independent frame
df_top_n = df_plot.sort_values(by='Absolute Difference', ascending=False).head(n).copy()

# Colour each bar by the sign of its difference: > 0 -> self-supervised better, otherwise supervised
df_top_n['Color'] = df_top_n['Difference'].apply(lambda x: color_ssl if x > 0 else color_supervised)

# Integer class id -> cell-type name (the lookup's index is the class id)
label_to_name_dict = cell_type_mapping['label'].to_dict()
df_top_n['Class'] = df_top_n['Class'].map(label_to_name_dict)

# One colour per bar, in row order (= the bar order seaborn uses). A hard-coded
# "3 x SSL, 3 x supervised" list only matches by accident, because rows are sorted by
# |difference|, not by sign.
bar_colors = df_top_n['Color'].tolist()

plt.figure(figsize=(2.5, 2.5))
ax = sns.barplot(x='Class', y='Difference', data=df_top_n, palette=bar_colors)
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, **tick_font, ha='right')
# tick_params instead of set_yticklabels: re-setting numeric tick labels pins them with a FixedFormatter
ax.tick_params(axis='y', labelsize=tick_font['fontsize'], labelrotation=45)
ax.set_xlabel('Class', fontdict=font)
ax.set_ylabel('Δ Correct Predictions', fontdict=font)
ax.set_title('Tabula Sapiens Cell Type Prediction Difference', fontdict=font)

# Annotate bars; labels of negative bars hang below the bar end instead of sitting inside it
for p in ax.patches:
    height = p.get_height()
    ax.annotate(f"{int(height)}", (p.get_x() + p.get_width() / 2., height),
                ha='center', va='bottom' if height >= 0 else 'top', fontsize=font['size'])

# Legend explaining the two bar colours
ssl_patch = mpatches.Patch(color=color_ssl, label='Self-Supervised Better')
supervised_patch = mpatches.Patch(color=color_supervised, label='Supervised Better')
ax.legend(handles=[ssl_patch, supervised_patch], loc='upper right', prop=font)

plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/tabula_sapiens_biggest_difference_barplot.svg", bbox_inches='tight')  # Save as SVG
plt.show()

In [None]:
# Exported per-cell-type count tables (first column 'Unnamed: 0' used as index)
classification_dir = os.path.join(RESULTS_FOLDER, 'classification')
true_counts_supervised = pd.read_csv(os.path.join(classification_dir, 'correct_counts_No_SSL_new_run4_Tabula_Sapiens_No SSL.csv'), index_col='Unnamed: 0')
true_counts_ssl = pd.read_csv(os.path.join(classification_dir, 'correct_counts_SSL_CN_MLP_50pnew_run0_Tabula_Sapiens_SSL.csv'), index_col='Unnamed: 0')
true_counts = pd.read_csv(os.path.join(classification_dir, 'Tabula_Sapiens_true_counts.csv'), index_col='Unnamed: 0')

In [None]:
# Bar plot of the cell types with the largest difference in 'Correct Count'
# between the self-supervised and the supervised model (from the exported CSVs).

# Title-cased cell type names serve as the merge key
true_counts_ssl['Cell Type'] = true_counts_ssl.index.str.title()
true_counts_supervised['Cell Type'] = true_counts_supervised.index.str.title()

# Inner merge on cell type; '_self' / '_supervised' suffixes disambiguate shared columns
df_merged = pd.merge(true_counts_ssl, true_counts_supervised, on='Cell Type', suffixes=('_self', '_supervised'))

# Positive: the self-supervised model has the larger count
df_merged['Count Difference'] = df_merged['Correct Count_self'] - df_merged['Correct Count_supervised']

# Select n cell types with the largest absolute differences
n = 6
df_subset = df_merged.reindex(df_merged['Count Difference'].abs().nlargest(n).index)

# Colour by which model is better
df_subset['Color'] = df_subset['Count Difference'].apply(lambda x: color_ssl if x > 0 else color_supervised)

plt.figure(figsize=(4, 3))
ax = sns.barplot(x='Cell Type', y='Count Difference', data=df_subset, palette=df_subset['Color'].tolist())
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, **tick_font, ha='right')
# tick_params instead of set_yticklabels: re-setting numeric tick labels pins them with a FixedFormatter
ax.tick_params(axis='y', labelsize=tick_font['fontsize'], labelrotation=45)
ax.set_xlabel('Cell Type', fontdict=font)
ax.set_ylabel('Count Difference', fontdict=font)
ax.set_title('Tabula Sapiens Cell Type Prediction Difference', fontdict=font)

# Annotate bars; labels of negative bars hang below the bar end instead of sitting inside it
for p in ax.patches:
    height = p.get_height()
    ax.annotate(f"{height:.2f}", (p.get_x() + p.get_width() / 2., height),
                ha='center', va='bottom' if height >= 0 else 'top', fontsize=font['size'])

plt.tight_layout()
plt.show()

In [None]:
# Confusion matrices of the two Tabula Sapiens runs (first column 'Unnamed: 0' used as index)
classification_dir = os.path.join(RESULTS_FOLDER, 'classification')
conf_matrix_supervised = pd.read_csv(os.path.join(classification_dir, 'conf_matrix_No_SSL_new_run4_Tabula_Sapiens_No SSL.csv'), index_col='Unnamed: 0')
conf_matrix_ssl = pd.read_csv(os.path.join(classification_dir, 'conf_matrix_SSL_CN_MLP_50pnew_run0_Tabula_Sapiens_SSL.csv'), index_col='Unnamed: 0')

In [None]:
# Column sums of each confusion matrix (columns are plotted as "Predicted Label" below,
# so presumably the number of cells predicted as each class)
sum_per_class_1, sum_per_class_2 = (matrix.sum(axis=0) for matrix in (conf_matrix_supervised, conf_matrix_ssl))

print("Sum per class for the first model:\n", sum_per_class_1)
print("\nSum per class for the second model:\n", sum_per_class_2)

In [None]:
# Zoom into the N cell types whose confusion-matrix rows change most between the two models.
# Positive cells: more counts for the self-supervised model; negative: more for the supervised one.
conf_matrix_difference = conf_matrix_ssl - conf_matrix_supervised

# Rank rows by total absolute change; the same labels are reused for the columns
N = 5
top_differences = (
    conf_matrix_difference
    .abs()
    .sum(axis=1)
    .nlargest(N)
    .index
)

conf_matrix_subset = conf_matrix_difference.loc[top_differences, top_differences]
# Title-case row and column labels for display
conf_matrix_subset.index = conf_matrix_subset.index.str.title()
conf_matrix_subset.columns = conf_matrix_subset.columns.str.title()

# Diverging map: supervised colour -> white -> self-supervised colour
bottom, top = mcolors.to_rgba(color_supervised), mcolors.to_rgba(color_ssl)
custom_colormap = mcolors.LinearSegmentedColormap.from_list("custom_map", [bottom, "white", top])

# Symmetric colour limits so white marks zero difference
max_abs_value = np.abs(conf_matrix_subset.to_numpy()).max()
vmax = max_abs_value
vmin = -vmax

plt.figure(figsize=(1.5, 1.2))
ax = sns.heatmap(
    conf_matrix_subset,
    annot=False,
    cmap=custom_colormap,
    linewidths=.5,
    vmin=vmin,
    vmax=vmax,
)

ax.set_xticklabels(conf_matrix_subset.columns, **tick_font, rotation=45, ha='right')
ax.set_yticklabels(conf_matrix_subset.index, **tick_font, rotation=0)

ax.set_xlabel('Predicted Label', fontdict=font)
ax.set_ylabel('True Label', fontdict=font)
ax.set_title('Tabula Sapiens Performance Difference', fontdict=font)

# Colour-bar tick labels at 6 pt, sans-serif
cbar = ax.collections[0].colorbar
cbar.ax.tick_params(labelsize=6)
plt.setp(cbar.ax.get_yticklabels(), fontname='sans-serif', fontsize=6)

plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/tabula_sapiens_biggest_difference_confusion_matrix.svg", bbox_inches='tight')  # Save as SVG
plt.show()


In [None]:
import matplotlib.colors as mcolors  # imported here too so the cell does not depend on an import elsewhere

# Positive cells: more counts for the self-supervised model; negative: more for the supervised one
conf_matrix_difference = conf_matrix_ssl - conf_matrix_supervised

# Diverging colormap: supervised colour -> white -> self-supervised colour
top = mcolors.to_rgba(color_ssl)
bottom = mcolors.to_rgba(color_supervised)
custom_colormap = mcolors.LinearSegmentedColormap.from_list("custom_map", [bottom, "white", top])

# Symmetric limits so white sits exactly at 0 (no change), as in the top-N subset heatmap.
# Without them seaborn scales to [min, max] and white lands at that range's midpoint.
max_abs_value = np.nanmax(np.abs(conf_matrix_difference.values))

plt.figure(figsize=(10, 8))  # large figure so every class label stays readable
ax = sns.heatmap(conf_matrix_difference, annot=False, cmap=custom_colormap, linewidths=.5,
                 vmin=-max_abs_value, vmax=max_abs_value)

# Tick labels
ax.set_xticklabels(conf_matrix_difference.columns, **tick_font, rotation=45, ha='right')
ax.set_yticklabels(conf_matrix_difference.index, **tick_font, rotation=0)

# Axis labels and title
ax.set_xlabel('Predicted Label', fontdict=font)
ax.set_ylabel('True Label', fontdict=font)
ax.set_title('Tabula Sapiens Performance Difference: Self-Supervised vs Supervised Model\n(Positive: More Counts Self-Supervised, Negative: More Counts Supervised)', fontdict=font)

# Colour-bar tick labels
cbar = ax.collections[0].colorbar
cbar.ax.tick_params(labelsize=6)
for label in cbar.ax.get_yticklabels():
    label.set_fontname('sans-serif')
    label.set_fontsize(6)

plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/tabula_sapiens_difference_confusion_matrix.svg", bbox_inches='tight')  # Save as SVG
plt.show()


In [None]:
# Heatmap of the supervised model's Tabula Sapiens confusion matrix
# (rows: true label, columns: predicted label).

# Re-apply the notebook's whitegrid theme. A bare sns.set_theme() resets the style to seaborn's
# default ("darkgrid"), silently changing every figure drawn after this cell.
sns.set_theme(style="whitegrid")
sns.set_palette("colorblind")

# Font properties for titles/labels and for tick labels
font = {'family': 'sans-serif', 'size': 5}
tick_font = {'fontsize': 5, 'fontname': 'sans-serif'}

plt.figure(figsize=(5, 4))
ax = sns.heatmap(conf_matrix_supervised, annot=False, cmap='viridis', linewidths=.5)

# Tick labels
ax.set_xticklabels(ax.get_xticklabels(), **tick_font, rotation=45, ha='right')
ax.set_yticklabels(ax.get_yticklabels(), **tick_font, rotation=0)

# Axis labels and title
ax.set_xlabel('Predicted Label', fontdict=font)
ax.set_ylabel('True Label', fontdict=font)
ax.set_title('Tabula Sapiens Confusion Matrix Supervised Model', fontdict=font)

# Colour-bar ticks: tick_params and the per-label loop use the same 5 pt size
cbar = ax.collections[0].colorbar
cbar.ax.tick_params(labelsize=5)
for label in cbar.ax.get_yticklabels():
    label.set_fontname('sans-serif')
    label.set_fontsize(5)

plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/confusion_matrix_tabula_sapiens_supervised.svg", bbox_inches='tight')  # Save as SVG
plt.show()


In [None]:
# Heatmap of the self-supervised model's Tabula Sapiens confusion matrix
# (rows: true label, columns: predicted label).

# Re-apply the notebook's whitegrid theme. A bare sns.set_theme() resets the style to seaborn's
# default ("darkgrid"), silently changing every figure drawn after this cell.
sns.set_theme(style="whitegrid")
sns.set_palette("colorblind")

# Font properties for titles/labels and for tick labels
font = {'family': 'sans-serif', 'size': 5}
tick_font = {'fontsize': 5, 'fontname': 'sans-serif'}

plt.figure(figsize=(5, 4))
ax = sns.heatmap(conf_matrix_ssl, annot=False, cmap='viridis', linewidths=.5)

# Tick labels
ax.set_xticklabels(ax.get_xticklabels(), **tick_font, rotation=45, ha='right')
ax.set_yticklabels(ax.get_yticklabels(), **tick_font, rotation=0)

# Axis labels and title
ax.set_xlabel('Predicted Label', fontdict=font)
ax.set_ylabel('True Label', fontdict=font)
ax.set_title('Tabula Sapiens Confusion Matrix Self-Supervised Model', fontdict=font)

# Colour-bar ticks: tick_params and the per-label loop use the same 5 pt size
cbar = ax.collections[0].colorbar
cbar.ax.tick_params(labelsize=5)
for label in cbar.ax.get_yticklabels():
    label.set_fontname('sans-serif')
    label.set_fontsize(5)

plt.tight_layout()
plt.savefig(RESULTS_FOLDER + "/classification/confusion_matrix_tabula_sapiens_ssl.svg", bbox_inches='tight')  # Save as SVG
plt.show()
