# QuantCell Paper Figures

In [None]:
import joblib
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.ticker import FuncFormatter

import seaborn as sns
import os
import warnings

from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrix, accuracy_score, average_precision_score, f1_score, precision_score, balanced_accuracy_score

from scipy.spatial import Voronoi, voronoi_plot_2d, Delaunay
from scipy.stats import pearsonr

os.chdir('/home/LULAB/wboohar/CODEX/data_processing/code')
from codex_project import codex_project, read_marker_combos
from quantcell import quantcell_project
warnings.filterwarnings('ignore')

In [None]:
# Filesystem layout for the paper project; later cells build every path from these names.
base_dir = '/store/Projects/wboohar/PhenoCycler'
project_name = 'QuantCellPaper'

# Marker-combination JSON passed to read_marker_combos() in the Fig 4 cells.
annotations_path = '/store/Projects/wboohar/PhenoCycler/annotation_strategies/marker_combos_062525_updated_verified.json'

# Derived directories: per-project outputs and the shared raw-data folder.
project_path = '/'.join([base_dir, project_name])
data_path = base_dir + '/raw_data'

In [None]:

# Display order of the annotated cell types. Used for confusion-matrix rows/columns,
# tick labels and the Fig 4 scatter panels. 'Other' is intentionally not in this list.
cell_order = [
    'KLS',
    'MEP',
    'CMP',
    'GMP',
    'CLP',
    'CFU-E',
    'MkP',
    'Erythroblasts',
    'Erythrocytes',
    'Megakaryocytes',
    'B Cells',
    'CD4 T Cells',
    'CD8 T Cells',
    'cDC',
    'Double Negative T Cells',
    'Monocytes',
    'Neutrophils',
    'Arterial ECs',
    'Capillary ECs',
    'Other BM ECs',
    'Lepr+ Perivascular Cells']


# Fill colour per cell type for the Voronoi panels (matplotlib named colours).
# NOTE(review): 'CMP' and 'Other BM ECs' both map to 'orange', so the two types are
# indistinguishable in the Voronoi figures — confirm this is intended.
color_order = {
    'KLS' : 'yellowgreen',
    'MEP' : 'palegreen',
    'CMP' : 'orange',
    'GMP' : 'olive',
    'CLP' : 'darkorange',
    'CFU-E' : 'green',
    'MkP' : 'darkseagreen',
    'Erythroblasts' : 'gold',
    'Erythrocytes' : 'wheat',
    'Megakaryocytes' : 'cornflowerblue',
    'B Cells' : 'violet',
    'CD4 T Cells' : 'red',
    'CD8 T Cells' : 'orangered',
    'cDC' : 'chocolate',
    'Double Negative T Cells' : 'blue',
    'Monocytes' : 'chartreuse',
    'Neutrophils' : 'teal',
    'Arterial ECs' : 'rosybrown',
    'Capillary ECs' : 'brown',
    'Other BM ECs' : 'orange',
    'Lepr+ Perivascular Cells' : 'deepskyblue',
    'Other' : 'white'
}

# Load the fitted label encoder if it has already been written to the project folder.
# When the file is missing, `encoder` is simply left undefined by this cell; the Fig 3
# and Fig 4 cells load it again unconditionally.
if os.path.exists(f'{project_path}/label_encoder.joblib'):
    encoder = joblib.load(f'{project_path}/label_encoder.joblib')

In [None]:


def plot_confusion_matrix(conf_matrix, name,  normalization='row', **kwargs):
    """Draw a confusion matrix as a seaborn heatmap with a percent colorbar.

    Parameters
    ----------
    conf_matrix : pandas.DataFrame
        Confusion matrix. Tick labels are taken from the module-level
        ``cell_order`` list, so the caller is expected to pass a frame whose
        rows and columns are already in that order.
    name : str
        Accepted for backward compatibility; not used by the function.
    normalization : {'row', 'col', 'all'} or anything else
        'row' divides each row by its row sum, 'col' divides each column by its
        column sum, 'all' divides by the grand total. Any other value leaves the
        matrix untouched.
    **kwargs
        Forwarded to ``seaborn.heatmap`` (e.g. ``ax``, ``cbar_ax``).

    Returns
    -------
    matplotlib.axes.Axes
        The axes the heatmap was drawn on.
    """
    if normalization == 'row':
        row_sums = conf_matrix.sum(axis=1)
        conf_matrix = conf_matrix.divide(row_sums, axis=0)
    elif normalization == 'col':
        col_sums = conf_matrix.sum(axis=0)
        conf_matrix = conf_matrix.divide(col_sums, axis=1)
    elif normalization == 'all':
        total_sum = conf_matrix.sum().sum()
        conf_matrix = conf_matrix / total_sum

    # Values are fractions in [0, 1]; scale by 100 so the formatter really prints
    # percentages (the unscaled version rendered 0.2 as "0%").
    percent_formatter = FuncFormatter(lambda x, _: f'{x * 100:.0f}%')

    ax = sns.heatmap(conf_matrix,
                     linecolor='grey',
                     linewidth=0.01,
                     xticklabels=cell_order,
                     yticklabels=cell_order,
                     cmap=sns.color_palette("flare", as_cmap=True),
                     cbar_kws={'format': percent_formatter, 'boundaries': np.linspace(0, 1, 100)},
                     **kwargs)

    # Pin the colorbar ticks and labels explicitly so the published look is unchanged.
    colorbar = ax.collections[0].colorbar
    colorbar.set_ticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
    colorbar.set_ticklabels(['0%', '20%', '40%', '60%', '80%', '100%'])

    ax.tick_params(length=6)
    ax.set_xticklabels(cell_order, rotation=90)
    return ax

### Fig 1
Voronoi

In [None]:
# Per-cell tables written by the conventional and QuantCell annotation runs.
# Later cells read the 'x', 'y' and 'cell_type' columns and a marker column ('CD45').
original_codex = pd.read_csv(f'{project_path}/codex_conventional_QuantCellPaper.csv', index_col=0)
relabeled_codex = pd.read_csv(f'{project_path}/codex_quantcell_QuantCellPaper.csv', index_col=0)

In [None]:
# Display window for the Fig 1 Voronoi panels, in the units of the 'x'/'y' columns.
X_LIM = (9150, 9257)
# Listed high-to-low: plt.ylim(Y_LIM) therefore draws the y axis inverted.
# The section masks index it as Y_LIM[1] (lower bound) and Y_LIM[0] (upper bound).
Y_LIM = (947, 878)

Marker Expression Annotation Example

In [None]:
# Fig 1: Voronoi tessellation of one tissue window, shaded by the CD45 marker call.
codex = original_codex
marker = 'CD45'

fig, ax = plt.subplots(figsize=(6, 4))

# Keep every cell inside the display window plus a 50-unit margin on each side, so the
# Voronoi regions that touch the window edge are still closed polygons.
x_lo, x_hi = X_LIM[0] - 50, X_LIM[1] + 50
y_lo, y_hi = Y_LIM[1] - 50, Y_LIM[0] + 50
section_mask = (
    (codex['x'] > x_lo) & (codex['x'] < x_hi)
    & (codex['y'] > y_lo) & (codex['y'] < y_hi)
)
section = codex.loc[section_mask, :].copy().reset_index(drop=True)

vor = Voronoi(section.loc[:, ['x', 'y']])
fig = voronoi_plot_2d(vor, show_vertices=False, show_points=False, line_width=0.05, ax=ax)
axs = fig.axes

# Grey when the marker call string contains at least one '+', white otherwise.
section['color'] = ['grey' if call.count('+') else 'white' for call in section.loc[:, marker]]

# vor.point_region[r] is the region of input point r, i.e. of row r in `section`.
for r, region_index in enumerate(vor.point_region):
    region = vor.regions[region_index]
    if -1 in region:
        # Unbounded region (has a vertex at infinity): nothing to fill.
        continue
    polygon = [vor.vertices[i] for i in region]
    ax.fill(*zip(*polygon), color=section.loc[r, 'color'], edgecolor='black')

ax.set_xlim(X_LIM)
ax.set_ylim(Y_LIM)
ax.axis('off')

plt.savefig(f'{project_path}/section_voronoi_{marker}_expression.png', dpi=300, bbox_inches='tight', pad_inches=0)

Conventional Annotation

In [None]:
# Fig 1: same tissue window, each Voronoi region filled with the colour of the
# conventional cell-type annotation.
codex = original_codex

fig, ax = plt.subplots(figsize=(6, 4))

# Display window plus a 50-unit margin, so regions at the window edge stay closed.
x_lo, x_hi = X_LIM[0] - 50, X_LIM[1] + 50
y_lo, y_hi = Y_LIM[1] - 50, Y_LIM[0] + 50
section_mask = (
    (codex['x'] > x_lo) & (codex['x'] < x_hi)
    & (codex['y'] > y_lo) & (codex['y'] < y_hi)
)
section = codex.loc[section_mask, :].copy().reset_index(drop=True)

vor = Voronoi(section.loc[:, ['x', 'y']])
fig = voronoi_plot_2d(vor, show_vertices=False, show_points=False, line_width=0.05, ax=ax)
axs = fig.axes

# vor.point_region[r] is the region of input point r, i.e. of row r in `section`.
for r, region_index in enumerate(vor.point_region):
    region = vor.regions[region_index]
    if -1 in region:
        # Unbounded region (has a vertex at infinity): nothing to fill.
        continue
    polygon = [vor.vertices[i] for i in region]
    ax.fill(*zip(*polygon), color=color_order[section.loc[r, 'cell_type']], edgecolor='black')

ax.set_xlim(X_LIM)
ax.set_ylim(Y_LIM)
ax.axis('off')

plt.savefig(f'{project_path}/section_voronoi_original.png', dpi=300, bbox_inches='tight', pad_inches=0)

QuantCell Annotation

In [None]:
# Fig 1: same tissue window, each Voronoi region filled with the colour of the
# QuantCell (relabeled) cell-type annotation.
codex = relabeled_codex

fig, ax = plt.subplots(figsize=(6, 4))

# Display window plus a 50-unit margin, so regions at the window edge stay closed.
x_lo, x_hi = X_LIM[0] - 50, X_LIM[1] + 50
y_lo, y_hi = Y_LIM[1] - 50, Y_LIM[0] + 50
section_mask = (
    (codex['x'] > x_lo) & (codex['x'] < x_hi)
    & (codex['y'] > y_lo) & (codex['y'] < y_hi)
)
section = codex.loc[section_mask, :].copy().reset_index(drop=True)

vor = Voronoi(section.loc[:, ['x', 'y']])
fig = voronoi_plot_2d(vor, show_vertices=False, show_points=False, line_width=0.05, ax=ax)
axs = fig.axes

# vor.point_region[r] is the region of input point r, i.e. of row r in `section`.
for r, region_index in enumerate(vor.point_region):
    region = vor.regions[region_index]
    if -1 in region:
        # Unbounded region (has a vertex at infinity): nothing to fill.
        continue
    polygon = [vor.vertices[i] for i in region]
    ax.fill(*zip(*polygon), color=color_order[section.loc[r, 'cell_type']], edgecolor='black')

ax.set_xlim(X_LIM)
ax.set_ylim(Y_LIM)
ax.axis('off')

plt.savefig(f'{project_path}/section_voronoi_relabeled.png', dpi=300, bbox_inches='tight', pad_inches=0)

### Fig 2


A)

In [None]:
# Fig 2a data: per-fold curves for the eight baseline classifiers.
base_clf_performance_dict = joblib.load(f'{project_path}/base_models/base_clf_performance_dict.joblib')
base_classifier_names = [
    'Nearest Neighbors',
    'Multilayer Perceptron',
    'Random Forest',
    'Extra Trees',
    'Decision Tree',
    'Gaussian Naive-Bayes',
    'Ridge Classifier',
    'Linear SVC',
]

# Axes: (classifier, 1000 curve points, curve component, 10 folds).
# The plotting cell draws component 1 on x ('Recall') and component 0 on y ('Precision').
macro_pr_rec = np.zeros((len(base_classifier_names), 1000, 2, 10))

for n, lab in enumerate(base_classifier_names):
    fold_curves = base_clf_performance_dict[lab][2]
    for fold in range(10):
        for component in (0, 1):
            macro_pr_rec[n, :, component, fold] = fold_curves[fold][component]

# Entry 4 of each classifier record is read as the number of classes.
num_classes = base_clf_performance_dict[base_classifier_names[0]][4]

B)

In [None]:
# Fig 2b data: mean over the 10 folds of entry 0 of each classifier record
# (labelled 'Average precision' in the figure), then sorted ascending for the bar chart.
macro_AP_list = []
for lab in base_classifier_names:
    fold_values = [base_clf_performance_dict[lab][0][fold] for fold in range(10)]
    macro_AP_list.append(np.mean(fold_values))

lab_order = np.argsort(macro_AP_list)
names_sorted = [base_classifier_names[n] for n in lab_order]
macro_AP_list = [macro_AP_list[n] for n in lab_order]

# Reference level drawn as a dashed line in panel b.
min_macro = 1/num_classes


C)

In [None]:
# Fig 2c data: macro-averaged AP and balanced accuracy of the three best models,
# averaged over 10 folds, split into default ("untuned") and tuned variants.
results = joblib.load(f'{project_path}/model_selection/top3models_results.joblib')

order = ['RandomForestClassifier', 'MLPClassifier', 'ExtraTreesClassifier']

standard_aps = {}
optimized_aps = {}

for clf_name, (y_test_list, y_pred_list, y_proba_list) in results.items():
    score = sum(
        average_precision_score(y_test_list[n], y_proba_list[n], average='macro')
        for n in range(10)
    ) / 10
    acc = sum(
        balanced_accuracy_score(y_test_list[n], y_pred_list[n])
        for n in range(10)
    ) / 10

    print(f'{clf_name} Macro-averaged AUPRC: {score:.3}')
    print(f'{clf_name} Balanced Accuracy: {acc:.3}')

    # Keys look like '<Estimator>_untuned' / '<Estimator>_tuned'; store under the bare name.
    if 'untuned' in clf_name:
        standard_aps[clf_name.split('_untuned')[0]] = score
    else:
        optimized_aps[clf_name.split('_tuned')[0]] = score


# Re-order into the fixed plotting order (dicts become plain lists here).
standard_aps = [standard_aps[name] for name in order]
optimized_aps = [optimized_aps[name] for name in order]


In [None]:
# Fig 2d data: macro-averaged AP when training on a single (location, statistic)
# feature subset, summarised as mean and standard deviation over the stored folds.
stats_loc_dict = joblib.load(f'{project_path}/model_selection/stats_loc_dict_included.joblib')

location_order = ['Membrane', 'Cytoplasm', 'Nucleus', 'Cell', 'All']
stat_order = ['Min', 'Max', 'Std.Dev.', 'Median', 'Mean', 'All']

included_average_precision_dict = {}
included_average_precision_std_dict = {}

for location in location_order:
    included_average_precision_dict[location] = {}
    included_average_precision_std_dict[location] = {}
    for stat in stat_order:
        # Entry 0 holds the true labels and entry 2 the predicted probabilities, keyed by fold.
        y_true_dict = stats_loc_dict[location][stat][0]
        y_proba_dict = stats_loc_dict[location][stat][2]

        average_precision_list = [
            average_precision_score(y_true_dict[n], y_proba_dict[n], average='macro')
            for n in y_true_dict.keys()
        ]
        included_average_precision_dict[location][stat] = np.mean(average_precision_list)
        included_average_precision_std_dict[location][stat] = np.std(average_precision_list)

# Transpose so that rows are locations and columns are statistics.
included_df = pd.DataFrame(included_average_precision_dict).T
included_std_df = pd.DataFrame(included_average_precision_std_dict).T

# Annotation strings "mean\n±std" for the heatmap cells. The text frame is created with
# object dtype up front: writing strings into float64 columns relies on silent upcasting,
# which recent pandas versions deprecate.
included_text_df = included_df.astype(object)
for row in included_text_df.index:
    for col in included_text_df.columns:
        included_text_df.loc[row, col] = f'{included_df.loc[row, col]:.3f}\n±{included_std_df.loc[row, col]:.3f}'

# Score with every location and statistic included; reference for the leave-one-out panel.
baseline = included_df.loc['All', 'All']

In [None]:
# Fig 2e data: drop in macro-averaged AP (baseline minus score) when one
# (location, statistic) feature subset is excluded.
leave_one_out_dict = joblib.load(f'{project_path}/model_selection/stats_loc_dict_excluded.joblib')

excluded_average_precision_dict = {}
excluded_average_precision_std_dict = {}

for location in location_order:
    excluded_average_precision_dict[location] = {}
    excluded_average_precision_std_dict[location] = {}
    for stat in stat_order:
        # Excluding everything leaves no features, so that cell is left empty.
        if location == 'All' and stat == 'All':
            excluded_average_precision_dict[location][stat] = np.nan
            excluded_average_precision_std_dict[location][stat] = np.nan
            continue

        # Entry 0 holds the true labels and entry 2 the predicted probabilities, keyed by fold.
        y_true_dict = leave_one_out_dict[location][stat][0]
        y_proba_dict = leave_one_out_dict[location][stat][2]

        average_precision_list = [
            average_precision_score(y_true_dict[n], y_proba_dict[n], average='macro')
            for n in y_true_dict.keys()
        ]
        excluded_average_precision_dict[location][stat] = np.mean(average_precision_list)
        excluded_average_precision_std_dict[location][stat] = np.std(average_precision_list)

excluded_df = baseline - pd.DataFrame(excluded_average_precision_dict).T
# Use the leave-one-out standard deviations here. The previous version reused the
# "included" std dict, so panel e was annotated with the spread of the wrong experiment.
excluded_std_df = pd.DataFrame(excluded_average_precision_std_dict).T

# Annotation strings "delta\n±std"; object dtype so strings can be stored without
# relying on silent upcasting of float64 columns.
excluded_text_df = excluded_df.astype(object)
for row in excluded_text_df.index:
    for col in excluded_text_df.columns:
        excluded_text_df.loc[row, col] = f'{excluded_df.loc[row, col]:.3f}\n±{excluded_std_df.loc[row, col]:.3f}'


Design

In [None]:
# Shared styling for Fig 2: font sizes, per-classifier colours and the diverging
# colormap used by the leave-one-out heatmap (panel e).
default_fontsize=24
small_fontsize=21
big_fontsize=30
plt.rcParams['font.size'] = default_fontsize

# One colour per baseline classifier; keys match base_classifier_names.
color_dict = {
    'Nearest Neighbors': '#1f77b4',
    'Multilayer Perceptron': 'mediumslateblue',
    'Random Forest': '#2ca02c',
    'Extra Trees': '#d62728',
    'Decision Tree': 'darkgoldenrod',
    'Gaussian Naive-Bayes': '#8c564b',
    'Ridge Classifier': 'darkcyan',
    "Linear SVC": '#e377c2'
}


# x tick labels for panel c, in the same order as `order` in the model-score cell.
optimized_classifier_names = ['Random\nforest', 'Multilayer\nperceptron', 'Extra trees']

# Colour range of panel e; nan-aware because the ('All', 'All') cell is empty.
max_val=np.nanmax(excluded_df.values)
min_val=np.nanmin(excluded_df.values)

# NOTE(review): high_rgb and low_rgb are not referenced anywhere in the visible code.
high_rgb=(1, 0, 0)
low_rgb=(0, 0, 1)

strength=0.5
# Index (out of 256) at which the two colour ramps meet.
# NOTE(review): presumably meant to put white at a delta of 0; that only holds when
# min_val <= 0 <= max_val — confirm.
pivot=int(np.abs(min_val)/np.abs(max_val-min_val)*256)

colors=np.zeros((256, 3))


# Rows below `pivot`: red channel fixed at 1, green/blue ramp from 1-strength up to 1.
# Rows from `pivot` on: blue channel fixed at 1, red/green ramp from 1 down to 1-strength.
colors[:pivot, 0] = 1
colors[pivot:, 0] = np.linspace(1, 1-strength, 256-pivot)
colors[:pivot, 1] = np.linspace(1-strength, 1, pivot)
colors[pivot:, 1] = np.linspace(1, 1-strength, 256-pivot)
colors[:pivot, 2] = np.linspace(1-strength, 1, pivot)
colors[pivot:, 2] = 1


custom_coolwarm_cmap = LinearSegmentedColormap.from_list('Custom_Coolwarm', colors)

In [None]:
# Fig 2 layout: axes are created on a 3x3 grid and then moved to hand-tuned positions.
# 'X' and 'Y' are the narrow colorbar axes for the heatmaps in 'D' and 'E'.
fig = plt.figure(figsize=(17, 22))
axd={}
axd['A'] = fig.add_subplot(3, 3, 1)
axd['B'] = fig.add_subplot(3, 3, 2)
axd['C'] = fig.add_subplot(3, 3, 4)
axd['D'] = fig.add_subplot(3, 3, 5)
axd['X'] = fig.add_subplot(3, 3, 6)
axd['E'] = fig.add_subplot(3, 3, 7)
axd['Y'] = fig.add_subplot(3, 3, 8)


# Positions are [left, bottom, width, height] in figure coordinates.
axd['A'].set_position([0.05, 0.75, 0.3, 0.23])
axd['B'].set_position([0.6, 0.75, 0.32, 0.23])

axd['C'].set_position([0.05, 0.4, 0.37, 0.285])
axd['D'].set_position([0.52, 0.4, 0.4, 0.285])
axd['X'].set_position([0.94, 0.4, 0.02, 0.285])

axd['E'].set_position([0.05, 0.05, 0.4, 0.285])
axd['Y'].set_position([0.47, 0.05, 0.02, 0.285])




# Panel a: fold-averaged curve per baseline classifier
# (component 1 of macro_pr_rec on x, component 0 on y).
for n, lab in enumerate(base_classifier_names):
    axd['A'].plot(np.mean(macro_pr_rec[n, :, 1, :], axis=1), np.mean(macro_pr_rec[n, :, 0, :], axis=1), color=color_dict[lab], label=lab, linewidth=4, alpha=0.6)

axd['A'].set_xlim(0, 1)
axd['A'].spines[['top', 'right']].set_visible(False)
axd['A'].set_ylim(0.45, 1)
axd['A'].axvline(1, linestyle='--', color='black', linewidth=3)


axd['A'].set_xlabel('Recall', fontsize=default_fontsize)
axd['A'].set_ylabel('Precision', fontsize=default_fontsize)
axd['A'].tick_params(axis='both', which='major', labelsize=default_fontsize)

print('Minimum macro-averaged AUPRC:', min_macro)



# Panel b: an opaque white bar is drawn first so the translucent coloured bar on top
# is blended against white rather than against whatever lies behind the axes.
axd['B'].barh(names_sorted, macro_AP_list, color='white', alpha=1, edgecolor='black')
axd['B'].barh(names_sorted, macro_AP_list, color=[color_dict[base_classifier_names[n]] for n in lab_order], alpha=0.6, edgecolor='black')

axd['B'].set_yticklabels(names_sorted)

# Dashed reference line at 1/num_classes.
axd['B'].axvline(min_macro, color='black', linestyle='--', linewidth=1.5)

axd['B'].spines[['top', 'right']].set_visible(False)
axd['B'].set_xlim(0, 1)

def sentence_case(label_list):
    """Return the texts of tick-label objects in sentence case.

    Parameters
    ----------
    label_list : iterable
        Objects exposing ``get_text()`` (e.g. matplotlib ``Text`` tick labels).

    Returns
    -------
    list of str
        Each text lower-cased with its first character upper-cased. The acronym
        in 'Linear SVC' is restored. Empty texts are returned as empty strings.
    """
    word_list = [word.get_text().lower() for word in label_list]
    # Slice instead of indexing so an empty label does not raise IndexError.
    word_list = [x[:1].upper() + x[1:] for x in word_list]
    return ['Linear SVC' if x == 'Linear svc' else x for x in word_list]

# Panel b (continued): axis label and sentence-cased classifier names.
axd['B'].set_xlabel('Average precision', fontsize=default_fontsize)
axd['B'].tick_params(axis='both', which='major', labelsize=default_fontsize)
axd['B'].set_yticklabels(sentence_case(axd['B'].get_yticklabels()), fontsize=default_fontsize)

# Panel c: default vs. tuned AP for the three best models, as adjacent bars per model.
# Default bars: translucent colour over an opaque white bar.
axd['C'].bar([0.75, 2.75, 4.75], standard_aps, width=0.75, color='white',zorder=1, alpha=1, edgecolor='black')
axd['C'].bar([0.75, 2.75, 4.75], standard_aps, width=0.75, color=['#2ca02c', 'mediumslateblue', '#d62728'],zorder=2, alpha=0.6, edgecolor='black')

# Tuned bars: fully opaque colour.
# NOTE(review): the hatched white bar sits at zorder 1 under an alpha=1 bar at zorder 2,
# so the hatch is presumably not visible — confirm this is intended.
axd['C'].bar([1.5, 3.5, 5.5], optimized_aps,width=0.75, color='white',zorder=1, alpha=1, hatch='\\\\',edgecolor='black')
axd['C'].bar([1.5, 3.5, 5.5], optimized_aps,width=0.75, color=['#2ca02c', 'mediumslateblue', '#d62728'],zorder=2, alpha=1, edgecolor='black')
# This y-range is overridden by the set_ylim([0.85, 1]) call further down.
axd['C'].set_ylim([0.80, 1])

axd['C'].set_ylabel('Average precision', fontsize=default_fontsize)
axd['C'].spines[['right', 'top']].set_visible(False)
axd['C'].tick_params(axis='both', which='major', labelsize=default_fontsize)

# Legend proxies (white = default, gray = tuned); not tied to the bar artists.
circ1 = mpatches.Patch( facecolor='white',alpha=1,label='Default', edgecolor='black')
circ2= mpatches.Patch( facecolor='gray',alpha=1, label='Hyperparameter tuned', edgecolor='black')
# Ticks are centred between each default/tuned pair.
axd['C'].set_xticks([1.125, 3.125, 5.125], optimized_classifier_names, fontsize=default_fontsize)
axd['C'].set_yticks([0.85, 0.9, 0.95, 1])
axd['C'].set_ylim([0.85, 1])
axd['C'].legend(handles=[circ1, circ2], fontsize=default_fontsize, edgecolor='white', bbox_to_anchor=[0.9, 1.05])

# Panel d: AP per (location, statistic) subset, annotated with the "mean\n±std" strings;
# fmt='' because the annotations are already formatted text.
sns.heatmap(included_df.loc[location_order, stat_order], 
            annot=included_text_df.loc[location_order, stat_order], 
            fmt='',
            vmin=0,
            vmax=1, 
            cmap='Reds',
            annot_kws={"size": small_fontsize},
            ax=axd['D'],
            cbar_ax=axd['X'])

axd['D'].tick_params(labelsize=default_fontsize)
axd['D'].set_xticks(axd['D'].get_xticks(), axd['D'].get_xticklabels(), rotation=90)
axd['D'].set_yticks(axd['D'].get_yticks(), axd['D'].get_yticklabels(), rotation=0)
axd['X'].set_ylabel('Average precision', rotation=270, fontsize=default_fontsize, labelpad=30)
axd['X'].set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1], ['0.0', '0.2', '0.4', '0.6', '0.8', '1.0'], fontsize=default_fontsize)


# Panel e: baseline minus leave-one-out AP, on the custom diverging colormap.
sns.heatmap(excluded_df.loc[location_order, stat_order], 
            vmin=min_val,
            vmax=max_val,
            cmap=custom_coolwarm_cmap, 
            annot=excluded_text_df.loc[location_order, stat_order],
            fmt='', 
            annot_kws={"size": small_fontsize},
            ax=axd['E'],
            cbar_ax=axd['Y'],)


axd['E'].tick_params(labelsize=default_fontsize)
axd['E'].set_xticks(axd['E'].get_xticks(), axd['E'].get_xticklabels(),  rotation=90)
axd['E'].set_yticks(axd['E'].get_yticks(), axd['E'].get_yticklabels(), rotation=0)
axd['Y'].set_ylabel('Δ Average precision', rotation=270, fontsize=default_fontsize, labelpad=30)
# Colorbar ticks are hard-coded rather than derived from min_val/max_val.
axd['Y'].set_yticks([0, 0.005, 0.01, 0.015], ['0.000', '0.005', '0.010', '0.015'], fontsize=default_fontsize)


# Panel letters. 'a' and 'b' are placed in the axes coordinates of panels C and D
# (transform=axd['C'/'D'].transAxes), which lines them up with 'c' and 'd' below.
axd['A'].text(-0.2, 2.09, 'a', fontsize=big_fontsize, fontweight='bold', transform=axd['C'].transAxes)
axd['B'].text(-0.2, 2.09, 'b', fontsize=big_fontsize, fontweight='bold', transform=axd['D'].transAxes)
axd['C'].text(-0.2, 1.05, 'c', fontsize=big_fontsize, fontweight='bold', transform=axd['C'].transAxes)
axd['D'].text(-0.2, 1.05, 'd', fontsize=big_fontsize, fontweight='bold', transform=axd['D'].transAxes)
axd['E'].text(-0.2, 1.05, 'e', fontsize=big_fontsize, fontweight='bold', transform=axd['E'].transAxes)



### Fig 3

A)

In [None]:


# Fig 3a data: confusion matrix of QuantCell predictions against the reference labels.
encoder = joblib.load(f'{project_path}/label_encoder.joblib')

# NOTE(review): the predictions are read with the default RangeIndex while the true labels
# use the CSV's first column as index. confusion_matrix pairs them by position, but the
# boolean masking further down aligns on index labels — confirm the two indices agree.
quantcell_labels = pd.read_csv(f'{project_path}/quantcell_labels.csv').iloc[:, 0]
quantcell_true = pd.read_csv(f'{project_path}/quantcell_true_labels.csv', index_col=0).loc[:, 'cell_type']
# rows are true labels, columns are predicted labels

confusion_matrix_quantcell = confusion_matrix(quantcell_true, quantcell_labels, labels=encoder.classes_)
confusion_matrix_quantcell = pd.DataFrame(confusion_matrix_quantcell, 
                              index=encoder.classes_, 
                              columns=encoder.classes_)

confusion_matrix_quantcell = confusion_matrix_quantcell.loc[cell_order, cell_order] # reorder rows and columns

# Cells that QuantCell labelled 'Other' are excluded from the precision calculation.
other_mask = quantcell_labels != 'Other'

# average=None yields one precision value per class; the smallest is reported.
scores=precision_score(quantcell_true[other_mask], quantcell_labels[other_mask], average=None)
print('Minimum cell-type precision score for QuantCell:', min(scores))

B)

In [None]:
# Fig 3b data: fraction of cells annotated, keyed by 'Conventional' and by FDR threshold.
fraction_annotated = joblib.load(f'{project_path}/fraction_annotated.joblib')

# The plotting cell stacks the 'FDR < ...' values on top of the 'Conventional' bar, i.e. it
# treats them as the additional fraction gained, so this ratio is the relative gain at FDR < 5%.
percent_increase = (fraction_annotated['FDR < 5%']) / fraction_annotated['Conventional'] * 100
print(f'Percent increase in fraction annotated: {percent_increase:.2f}%')

In [None]:
# Displayed for reference: conventional fraction plus the FDR < 5% value (height of the stacked bar).
fraction_annotated['Conventional']+fraction_annotated['FDR < 5%']

In [None]:
# Displayed for reference: fraction annotated by the conventional strategy alone.
fraction_annotated['Conventional']

C)

In [None]:
# Fig 3c data. The two tables are re-read here (same files as in the Fig 1 section), so
# this section also runs on its own.
original_codex = pd.read_csv(f'{project_path}/codex_conventional_QuantCellPaper.csv', index_col=0)
relabeled_codex = pd.read_csv(f'{project_path}/codex_quantcell_QuantCellPaper.csv', index_col=0)

# Per-cell-type counts divided by the total number of rows. value_counts() sorts by count,
# so the two Series are generally NOT in the same label order.
original_frequencies = original_codex.loc[:, 'cell_type'].value_counts()/len(original_codex)
relabeled_frequencies = relabeled_codex.loc[:, 'cell_type'].value_counts()/len(relabeled_codex)

In [None]:
# Inspection only: displays the conventional-annotation table.
original_codex

D)

In [None]:
# Fig 3d data: flow-cytometry frequencies per mouse, with column names reduced to the
# cell-type names used in cell_order.
facs_data = pd.read_csv(f'{data_path}/OldYoungRound1FACS.csv', index_col=0)

# Replace the index by hard-coded mouse IDs; this requires the table to have exactly six rows.
facs_data.loc[:, 'Mouse ID'] = ['O1', 'O2', 'O3', 'Y1', 'Y2', 'Y3']
facs_data.set_index('Mouse ID', inplace=True, drop=True)

# Column clean-up, applied in place one column at a time:
#   1. keep only the text before the first space,
#   2. keep only the text before the first '+',
#   3. 'T'/'B' -> 'T Cells'/'B Cells', and 'CD4'/'CD8' -> 'CD4 T Cells'/'CD8 T Cells'.
# `x` is updated after each rename so the next step refers to the current column name.
for x in facs_data.columns:
    facs_data.rename(columns={x: x.split(' ')[0]}, inplace=True)
    x= x.split(' ')[0]

    facs_data.rename(columns={x: x.split('+')[0]}, inplace=True)
    x= x.split('+')[0]
    if x in ['T', 'B']:
        facs_data.rename(columns={x: x + ' Cells'}, inplace=True)

    if x in ['CD4', 'CD8']:
        facs_data.rename(columns={x: x + ' T Cells'}, inplace=True)
    
# NOTE(review): 'MPP' is computed here and dropped again two lines below, so it has no
# effect on the final table.
facs_data.loc[:, 'MPP'] = facs_data.loc[:, 'KLS'] - facs_data.loc[:, 'HSC']
# CLP gating strategy is not the same as we use for cell annotation, so we will not use it
facs_data.drop(columns=['Granulocytes', 'FLK2', 'FLK2-', 'HSC', 'Progenitor', 'Lineage-', 'MPP', 'T Cells', 'CLP'], inplace=True)
# Divide by 100 (presumably percent -> fraction, to match the frequency Series — confirm).
facs_data /= 100

E)

In [None]:
# Fig 3e data: balanced accuracy of four annotation tools, each scored on its own
# prediction/true-label files.
annospat_labels = pd.read_csv(f'{project_path}/AnnoSpat/outputdir/trte_labels_ELM_IMC_T1D_AnnoSpat.csv', index_col=0).loc[:, 'label']
annospat_true = pd.read_csv(f'{project_path}/AnnoSpat/annospat_true_labels.csv', index_col=0).loc[:, 'cell_type']

maps_labels = pd.read_csv(f'{project_path}/MAPS/results/cell_phenotyping/pred_labels.csv', index_col=0).iloc[:, 0]
maps_true = pd.read_csv(f'{project_path}/MAPS/data/cell_phenotyping/test_features.csv', index_col=0).loc[:, 'cell_label']

# MAPS labels are mapped back to cell-type names with the label encoder.
maps_labels = encoder.inverse_transform(maps_labels)
maps_true = encoder.inverse_transform(maps_true)

quantcell_labels = pd.read_csv(f'{project_path}/quantcell_labels.csv').iloc[:, 0]
quantcell_true = pd.read_csv(f'{project_path}/quantcell_true_labels.csv', index_col=0).loc[:, 'cell_type']

# NOTE(review): astir_labels uses the first CSV column as index while astir_true does not;
# balanced_accuracy_score pairs them by position — confirm the row order matches.
astir_labels = pd.read_csv(f'{project_path}/Astir/astir_labels.csv', index_col=0).loc[:, 'cell_type']
astir_true = pd.read_csv(f'{project_path}/Astir/astir_true_labels.csv').loc[:, 'cell_type']

annospat = balanced_accuracy_score(annospat_true, annospat_labels)
maps = balanced_accuracy_score(maps_true, maps_labels)
# QuantCell is scored only on cells it did not label 'Other'; the other three tools are
# scored on all of their cells.
quantcell_other_mask = quantcell_labels != 'Other'
quantcell = balanced_accuracy_score(quantcell_true[quantcell_other_mask], quantcell_labels[quantcell_other_mask])
astir = balanced_accuracy_score(astir_true, astir_labels)

In [None]:
# Font sizes for Fig 3 and Fig 4; these overwrite the values set in the Fig 2 design cell.
default_fontsize=22
small_fontsize=14
big_fontsize=30
plt.rcParams['font.size'] = default_fontsize


In [None]:
# Fig 3 layout: axes are created on a 3x3 grid and then moved to hand-tuned positions
# ([left, bottom, width, height] in figure coordinates). 'X' is the colorbar axis of panel a.
fig = plt.figure(figsize=(17, 22))
axd={}
axd['A'] = fig.add_subplot(3, 3, 1)
axd['X'] = fig.add_subplot(3, 3, 2)
axd['B'] = fig.add_subplot(3, 3, 3)
axd['C'] = fig.add_subplot(3, 3, 4)
axd['D'] = fig.add_subplot(3, 3, 5)
axd['E'] = fig.add_subplot(3, 3, 7)


axd['A'].set_position([0.1, 0.725, 0.25, 0.205])
axd['X'].set_position([0.37, 0.725, 0.02, 0.205])
axd['B'].set_position([0.55, 0.7, 0.45, 0.23])

axd['C'].set_position([0.05, 0.35, 0.35, 0.193])
axd['D'].set_position([0.55, 0.35, 0.4, 0.193])

axd['E'].set_position([0.05, 0.1, 0.4, 0.15])


# --- Panel a: column-normalised confusion matrix (rows = true, columns = predicted). ---
plot_confusion_matrix(confusion_matrix_quantcell, 'QuantCell', normalization='col', ax=axd['A'], cbar_ax=axd['X'])


axd['A'].set_xticklabels((axd['A'].get_xticklabels()), fontsize=small_fontsize, rotation=90)
axd['A'].set_yticklabels((axd['A'].get_yticklabels()), fontsize=small_fontsize, rotation=0)
axd['A'].set_ylabel('Conventional annotation', fontsize=default_fontsize)
axd['A'].set_xlabel('QuantCell', fontsize=default_fontsize)

# --- Panel b: fraction annotated; QuantCell gain stacked on the conventional base. ---
base = fraction_annotated['Conventional']
labels = ['Conventional\nannotation', 'FDR < 1%', 'FDR < 5%', 'FDR < 10%']
# No QuantCell gain for the first (conventional-only) bar.
quantcell_gain = [0] + [fraction_annotated[x] for x in labels[1:]]

# Each translucent bar is drawn over an opaque white bar so its colour is blended
# against white.
axd['B'].bar(range(len(labels)),
       height=base,
       color='white',
       alpha=1,
       edgecolor='black')

axd['B'].bar(range(len(labels)),
       height=base,
       color='blue',
       alpha=0.5,
       edgecolor='black')

axd['B'].bar(range(len(labels)),
       height=quantcell_gain,
       bottom=base,
       color='white',
       alpha=1,
       edgecolor='black')

axd['B'].bar(range(len(labels)),
       height=quantcell_gain,
       bottom=base,
       color='green',
       alpha=0.5,
       edgecolor='black')

green_patch = mpatches.Patch(facecolor=(0,0.5,0,0.5), label='QuantCell', edgecolor=(0,0,0,1))
blue_patch = mpatches.Patch(facecolor=(0,0,1,0.5), label='Conventional annotation', edgecolor=(0,0,0,1))

axd['B'].legend(handles=[blue_patch, green_patch], fontsize=default_fontsize, frameon=False, bbox_transform=axd['B'].transAxes, bbox_to_anchor=(0.05, 1.15), loc='upper left')
axd['B'].set_ylim([0, 1])
axd['B'].set_xticks(range(len(labels)),  labels=labels)
axd['B'].set_ylabel('Fraction annotated', fontsize=default_fontsize)
axd['B'].spines[['right', 'top']].set_visible(False)
axd['B'].tick_params(labelsize=20)


# --- Panel c: conventional vs. QuantCell frequency per cell type ('Other' excluded). ---
index = relabeled_frequencies.index[relabeled_frequencies.index != 'Other']
xvals = original_frequencies[original_frequencies.index.isin(index)]
# Align y to x by cell-type label. Both Series come from value_counts(), each sorted by
# its own counts, so pairing them positionally could match different cell types in the
# scatter and in pearsonr.
yvals = relabeled_frequencies.loc[xvals.index]
axd['C'].scatter(xvals, yvals, s=20, color='black')
axd['C'].set_xlabel('Conventional annotation frequency', fontsize=default_fontsize)
axd['C'].set_ylabel('QuantCell frequency', fontsize=default_fontsize)
axd['C'].tick_params(labelsize=default_fontsize)
r, p = pearsonr(xvals, yvals)
axd['C'].set_title(f'r={str(r)[:5]}   $p$-value={p:.2}', fontsize=default_fontsize)
axd['C'].spines[['top', 'right']].set_visible(False)

# Least-squares line through the label-aligned points.
xvals_sorted = xvals.sort_values()
a, b = np.polyfit(xvals_sorted, yvals[xvals_sorted.index], 1)
axd['C'].plot(xvals_sorted, a * xvals_sorted + b, color='grey', linestyle='--')

axd['C'].set_xlim([-0.002, None])
axd['C'].set_ylim([-0.005, None])

# --- Panel d: flow cytometry vs. QuantCell frequency for the shared cell types. ---
cells_of_interest = [x for x in cell_order if x in facs_data.columns and x in original_frequencies.index and x in relabeled_frequencies.index]


# Both Series are indexed by cells_of_interest, so they are label-aligned by construction.
xvals = facs_data.loc[:, cells_of_interest].mean()
yvals = relabeled_frequencies[cells_of_interest]
axd['D'].scatter(xvals, yvals, s=20, color='black')
axd['D'].set_xlabel('Flow cytometry frequency', fontsize=default_fontsize)
axd['D'].set_ylabel('QuantCell frequency', fontsize=default_fontsize)
axd['D'].tick_params(labelsize=default_fontsize)
r, p = pearsonr(xvals, yvals)
axd['D'].set_title(f'r={str(r)[:5]}   $p$-value={p:.4f}', fontsize=default_fontsize)
axd['D'].spines[['top', 'right']].set_visible(False)

xvals_sorted = xvals.sort_values()
a, b = np.polyfit(xvals_sorted, yvals[xvals_sorted.index], 1)
axd['D'].plot(xvals_sorted, a * xvals_sorted + b, color='grey', linestyle='--')
axd['D'].set_xlim([-0.001, 0.08])
axd['D'].set_ylim([-0.003, None])


# --- Panel e: balanced accuracy of the four tools. ---
axd['E'].bar(['Astir', 'AnnoSpat', 'MAPS', 'QuantCell'],
       [astir, annospat, maps, quantcell],
       color='white',
       edgecolor='black',
       alpha=1)

axd['E'].bar(['Astir', 'AnnoSpat', 'MAPS', 'QuantCell'],
       [astir, annospat, maps, quantcell],
       color=['crimson', 'purple', 'orange', 'green'],
       edgecolor='black',
       alpha=0.5)



axd['E'].set_ylabel('Balanced accuracy', fontsize=default_fontsize)
axd['E'].tick_params(labelsize=default_fontsize)
# Text kwargs such as fontsize are only accepted by set_yticks together with `labels`;
# the tick label size is already set by tick_params above.
axd['E'].set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1])
axd['E'].spines[['top', 'right']].set_visible(False)

# Print the score just above each bar.
for i, score in enumerate([astir, annospat, maps, quantcell]):
    axd['E'].text(i-0.23, score + 0.01, f'{score:.3f}', fontsize=default_fontsize, color='black')

# Panel letters; 'a' is placed in panel C's axes coordinates so it lines up with 'c'.
axd['A'].text(-0.2, 3.064, 'a', fontsize=big_fontsize, fontweight='bold', transform=axd['C'].transAxes)
axd['B'].text(-0.15, 1.05, 'b', fontsize=big_fontsize, fontweight='bold', transform=axd['B'].transAxes)
axd['C'].text(-0.2, 1.093, 'c', fontsize=big_fontsize, fontweight='bold', transform=axd['C'].transAxes)
axd['D'].text(-0.15, 1.093, 'd', fontsize=big_fontsize, fontweight='bold', transform=axd['D'].transAxes)
axd['E'].text(-0.15, 1.05, 'e', fontsize=big_fontsize, fontweight='bold', transform=axd['E'].transAxes)



### Fig 4

A)

In [None]:

# Re-load the cross-validation results and label encoder so the Fig 4 section can be run
# without the Fig 2/3 cells.
results = joblib.load(f'{project_path}/model_selection/top3models_results.joblib')
encoder = joblib.load(f'{project_path}/label_encoder.joblib')

In [None]:
# Per-class F1 of the tuned random forest: one column per fold, then the fold mean.
y_test, y_pred, y_proba = results['RandomForestClassifier_tuned']

score_matrix = np.zeros((len(encoder.classes_), 10))
for n in range(10):
    # average=None returns one F1 value per class.
    per_class_f1 = f1_score(y_test[n], y_pred[n], average=None)
    score_matrix[:, n] = per_class_f1

# Row i of score_matrix is keyed by encoder.classes_[i]
# (assumes the fold labels are encoder-encoded — TODO confirm).
f1_scores = {
    label: np.mean(score_matrix[i, :])
    for i, label in enumerate(encoder.classes_)
}

In [None]:


# Open the QuantCell project; later cells read its `codex` table.
# Note: base_dir is redefined here WITH a trailing slash, unlike in the configuration cell.
base_dir = '/store/Projects/wboohar/PhenoCycler/' 
project_name = 'QuantCellPaper'

project = quantcell_project()
project.initialize(base_path=base_dir, project_name=project_name)

In [None]:
# Per-cell-type covariates that Fig 4 (top row) plots against F1:
# frequency, mean area, mean circularity and number of markers in the annotation rule.
frequency_dict, size_dict, number_marker_dict, circularity_dict = {}, {}, {}, {}

marker_combos = read_marker_combos(annotations_path)

# Annotated cells are all labelled cells except those labelled 'Other'.
cell_type_counts = project.codex.loc[:, 'cell_type'].value_counts()
total_annotated = cell_type_counts.sum() - cell_type_counts['Other']

for ct in cell_order:
    ct_cells = project.codex.loc[project.codex.loc[:, 'cell_type'] == ct, :]
    frequency_dict[ct] = ct_cells.loc[:, 'cell_type'].value_counts()[ct] / total_annotated
    size_dict[ct] = ct_cells.loc[:, 'Cell: Area µm^2'].mean()
    circularity_dict[ct] = ct_cells.loc[:, 'Cell: Circularity'].mean()
    number_marker_dict[ct] = len(marker_combos[ct])
    

In [None]:
gene_ko_dict = joblib.load(f'{project_path}/model_selection/gene_ko_dict.joblib')
tuned_predictions = joblib.load(f'{project_path}/model_selection/top3models_results.joblib')['RandomForestClassifier_tuned']
gene_order = sorted(gene_ko_dict.keys(), key=str.lower)
f1_baseline = np.mean([f1_score(tuned_predictions[0][n], tuned_predictions[1][n], average=None) for n in range(10)], axis=0)

f1_matrix = np.zeros((len(gene_order), 10, len(encoder.classes_)))
for n, gene in enumerate(gene_order):
    for fold in range(10):
        y_true = gene_ko_dict[gene][0][fold]
        y_pred = gene_ko_dict[gene][1][fold]
        f1_matrix[n, fold, :] = f1_score(y_true, y_pred, average=None)


long_df = pd.DataFrame(columns=['Gene', 'Fold', 'Cell Type', 'F1 Score'])
for n, gene in enumerate(gene_order):
    for fold in range(10):
        for cell_type in encoder.classes_:
            cell_index = encoder.transform([cell_type])[0]
            long_df = long_df.append({
                'Gene': gene,
                'Fold': fold,
                'Cell Type': cell_type,
                'F1 Score': f1_matrix[n, fold, cell_index] - f1_baseline[cell_index]
            }, ignore_index=True)

marker_combos = read_marker_combos(annotations_path)

In [None]:
fig, axs = plt.subplots(4, 2, figsize=(17, 22))
axs = axs.flatten()

# Manual layout, one [left, bottom, width, height] box (figure coordinates) per axes:
# axs[0:4] are the shorter scatter panels, axs[4:8] the taller swarm panels.
panel_positions = [
    [0.05, 0.75, 0.4, 0.13],
    [0.55, 0.75, 0.4, 0.13],
    [0.05, 0.53, 0.4, 0.13],
    [0.55, 0.53, 0.4, 0.13],
    [0.05, 0.28, 0.4, 0.18],
    [0.55, 0.28, 0.4, 0.18],
    [0.05, 0.05, 0.4, 0.18],
    [0.55, 0.05, 0.4, 0.18],
]
for ax, position in zip(axs, panel_positions):
    ax.set_position(position)


# --- Panel a: per-cell-type F1 score against four per-cell-type properties ---
f1_by_cell_type = [f1_scores[ct] for ct in cell_order]
scatter_panels = [
    # (per-cell-type values, x-axis label, draw the y-axis label?)
    (frequency_dict, 'Cell frequency', True),
    (size_dict, 'Cell size (µm²)', False),
    (circularity_dict, 'Cell circularity', True),
    (number_marker_dict, 'Number of markers', False),
]
for ax, (feature_dict, xlabel, show_ylabel) in zip(axs[:4], scatter_panels):
    feature_values = [feature_dict[ct] for ct in cell_order]
    ax.scatter(feature_values, f1_by_cell_type, color='black')
    ax.set_xlabel(xlabel, fontsize=default_fontsize)
    if show_ylabel:
        # only the left-hand column carries the shared 'F1 score' label
        ax.set_ylabel('F1 score', fontsize=default_fontsize)
    ax.spines[['top', 'right']].set_visible(False)
    ax.tick_params(labelsize=default_fontsize)
    # Pearson correlation between the property and the F1 score, written above the panel
    stat, pvalue = pearsonr(feature_values, f1_by_cell_type)
    ax.text(0.20, 1.05, f'r= {stat:.2f}       $p$-value={pvalue:.3f}', fontsize=default_fontsize, transform=ax.transAxes)

# Only the frequency panel uses fixed axis limits.
axs[0].set_xlim(0, 0.3)
axs[0].set_ylim(0.68, 1)


# --- Panel b: F1 score differences per excluded gene, one swarm plot per cell type ---
highlighted_cell_types = ['KLS', 'Erythrocytes', 'Arterial ECs', 'CD8 T Cells']

# Per cell type: standard deviation of 'F1 Score' within each (gene, cell type) group,
# averaged over genes. It does not depend on the panel, so it is computed once, and the
# column is selected explicitly instead of relying on it being the first remaining column.
mean_std_by_cell_type = (
    long_df.groupby(['Gene', 'Cell Type'])['F1 Score'].std()
    .groupby(level='Cell Type').mean()
)

for i, ct in enumerate(highlighted_cell_types):
    ax = axs[4 + i]
    ct_df = long_df[long_df['Cell Type'] == ct]
    # `order=gene_order` pins category n to gene_order[n]; the annotation x-positions
    # below are derived from the index in gene_order and rely on exactly that mapping.
    sns.swarmplot(data=ct_df,
                  x='Gene',
                  y='F1 Score',
                  order=gene_order,
                  ax=ax,
                  color='black',
                  size=4)
    ax.spines[['top', 'right']].set_visible(False)
    ax.set_title(ct, fontsize=default_fontsize)
    ax.axhline(0, color='black', linestyle='--', linewidth=2)
    ax.set_ylabel('Δ F1 Score', fontsize=default_fontsize)

    # Tick labels are blanked; genes that stand out are labelled directly in the plot.
    ax.set_xticklabels([' '] * len(gene_order), fontsize=default_fontsize)
    ax.tick_params(labelsize=default_fontsize)

    mean_std = mean_std_by_cell_type.loc[ct]
    offset = -1  # side (+1 right / -1 left) on which the next gene label is drawn
    ylim = 0     # lowest per-gene minimum seen so far (never above 0)
    for n, gene in enumerate(gene_order):
        gene_scores = ct_df.loc[ct_df['Gene'] == gene, 'F1 Score']
        mean_difference = gene_scores.mean()
        min_difference = gene_scores.min()

        # Label a gene only when its mean difference exceeds 3x the cell type's mean std.
        if np.abs(mean_difference) > 3 * mean_std:
            # Force the label side near the plot edges (and for Endomucin);
            # elsewhere the side alternates between consecutive labels.
            if n < 5:
                offset = 1
            if n > len(gene_order) - 4:
                offset = -1
            if gene == 'Endomucin':
                offset = 1
            ax.text(n + 1.6*offset + 0.2*len(gene)*offset, 1.1*min_difference, gene, fontsize=default_fontsize, ha='center', va='bottom')
            offset *= -1
        if min_difference < ylim:
            ylim = min_difference
    # Leave head-room below the lowest point; the upper limit stays automatic.
    ax.set_ylim([1.2*ylim, None])

# Only the bottom row names the x axis; seaborn's automatic label is cleared on the upper row.
axs[4].set_xlabel('', fontsize=default_fontsize)
axs[5].set_xlabel('', fontsize=default_fontsize)

axs[6].set_xlabel('Gene excluded', fontsize=default_fontsize)
axs[7].set_xlabel('Gene excluded', fontsize=default_fontsize)


# Panel letters
axs[0].text(-0.15, 1.1, 'a', fontsize=big_fontsize, fontweight='bold', transform=axs[0].transAxes)
axs[4].text(-0.15, 1.1, 'b', fontsize=big_fontsize, fontweight='bold', transform=axs[4].transAxes)

## S1
A

In [None]:


base_model_names = [
    'Nearest Neighbors',
    'Multilayer Perceptron',
    'Random Forest',
    'Extra Trees',
    'Decision Tree',
    'Gaussian Naive-Bayes',
    'Ridge Classifier',
    'Linear SVC',
]

# One bar colour per base model, listed in the same order as base_model_names.
base_model_colors = [
    '#1f77b4',
    'mediumslateblue',
    '#2ca02c',
    '#d62728',
    'darkgoldenrod',
    '#8c564b',
    'darkcyan',
    '#e377c2',
]
color_dict = dict(zip(base_model_names, base_model_colors))


fig, axs = plt.subplots(2, 1, figsize=(12, 14))
axs = axs.flatten()


# Reuse `base_names_sorted` when an earlier cell already defined it;
# otherwise fall back to this fixed order (reversed for bottom-to-top bar placement).
if 'base_names_sorted' not in globals():
    base_names_sorted = list(reversed([
        'Random Forest',
        'Multilayer Perceptron',
        'Extra Trees',
        'Linear SVC',
        'Nearest Neighbors',
        'Decision Tree',
        'Ridge Classifier',
        'Gaussian Naive-Bayes',
    ]))

# Panel a: one horizontal bar per base model, annotated with its value.
training_ax = axs[0]
for n, model in enumerate(base_names_sorted):
    # stored value divided by 10: average over 10 folds
    time_elapsed = joblib.load(f'{project_path}/base_models/{model}_time_elapsed.joblib') / 10

    print(f'{model} time elapsed: {time_elapsed:.2f} seconds')
    training_ax.barh(n, time_elapsed, color=color_dict[model], label=model, alpha=0.6, edgecolor='black')
    # numeric label placed just past the end of the bar
    training_ax.text(time_elapsed*1.05, n, f'{time_elapsed:.2f}', va='center', fontsize=default_fontsize)

training_ax.set_xscale('log')
training_ax.spines[['top', 'right']].set_visible(False)
training_ax.set_xlabel('Training time (seconds)', fontsize=default_fontsize)

def sentence_case(names):
    """Convert display names to sentence case while keeping the 'SVC' acronym upper-case.

    Each name is lower-cased, every 'svc' is restored to 'SVC', and the first
    character is upper-cased, e.g. 'Random Forest' -> 'Random forest' and
    'Linear SVC' -> 'Linear SVC'.

    Parameters
    ----------
    names : iterable of str
        Names to convert. Empty strings are allowed and are returned unchanged.

    Returns
    -------
    list of str
        The converted names, in input order.
    """
    converted = []
    for name in names:
        # Restore the acronym before capitalising so a leading 'SVC' is not turned into 'Svc'.
        lowered = name.lower().replace('svc', 'SVC')
        # Slicing instead of indexing keeps an empty name from raising IndexError.
        converted.append(lowered[:1].upper() + lowered[1:])
    return converted

# Panel a: sentence-case model names on the y axis, decade ticks on the log x axis.
base_model_positions = range(len(base_names_sorted))
axs[0].set_yticks(base_model_positions, sentence_case(base_names_sorted), fontsize=default_fontsize)
axs[0].tick_params(axis='both', which='major', labelsize=default_fontsize)
training_decades = [1, 10, 100, 1000]
axs[0].set_xticks(training_decades, training_decades, fontsize=default_fontsize)


# Panel b: runtime of each annotation method, drawn bottom-to-top in this order.
names_sorted = ['Astir', 'AnnoSpat', 'MAPS', 'QuantCell'][::-1]

astir_time = pd.read_csv(f'{project_path}/Astir/astir_time.csv', index_col=0).iloc[0, 0]
with open(f'{project_path}/AnnoSpat/time_elapsed.txt', 'r') as f:
    # convert from nanoseconds to seconds
    annospat_time = float(f.read().strip())/10**9
maps_time = pd.read_csv(f'{project_path}/MAPS/results/cell_phenotyping/time_elapsed_maps.csv', index_col=0).iloc[0, 0]
quantcell_time = pd.read_csv(f'{project_path}/time_elapsed_quantcell.csv').iloc[0, 0]

runtime_by_method = {
    'Astir': astir_time,
    'AnnoSpat': annospat_time,
    'MAPS': maps_time,
    'QuantCell': quantcell_time,
}
color_by_method = {
    'Astir': 'crimson',
    'AnnoSpat': 'purple',
    'MAPS': 'orange',
    'QuantCell': 'green',
}
method_positions = range(len(names_sorted))
method_runtimes = [runtime_by_method[name] for name in names_sorted]
method_colors = [color_by_method[name] for name in names_sorted]

# An opaque white bar is drawn first, then the semi-transparent coloured bar on top of it.
axs[1].barh(method_positions, method_runtimes,
            color='white', label='Training time', alpha=1, edgecolor='black')
axs[1].barh(method_positions, method_runtimes,
            color=method_colors, label='Training time', alpha=0.5, edgecolor='black')
axs[1].set_xscale('log')
axs[1].spines[['top', 'right']].set_visible(False)
axs[1].set_xlabel('Annotation runtime (seconds)', fontsize=default_fontsize)
axs[1].set_yticks(method_positions, names_sorted, fontsize=default_fontsize)
axs[1].tick_params(axis='both', which='major', labelsize=default_fontsize)
runtime_decades = [1, 10, 100, 1000, 10000]
axs[1].set_xticks(runtime_decades, runtime_decades, fontsize=default_fontsize)

# Numeric label just past the end of each bar (top bar first).
for bar_position, runtime in reversed(list(enumerate(method_runtimes))):
    axs[1].text(runtime*1.1, bar_position, f'{runtime:.2f}', fontsize=default_fontsize, color='black', va='center')


# Panel letters
for panel_ax, panel_letter in zip(axs, ['a', 'b']):
    panel_ax.text(-0.1, 1.05, panel_letter, fontsize=big_fontsize, fontweight='bold', transform=panel_ax.transAxes)

plt.subplots_adjust(hspace=0.3, wspace=0.3)

## S2

In [None]:
default_fontsize = 22
small_fontsize = 14
big_fontsize = 30
plt.rcParams['font.size'] = default_fontsize


fig, axs = plt.subplots(4, 2, figsize=(18, 18))
axs = axs.flatten()

# Manual layout ([left, bottom, width, height] in figure coordinates): every wide
# heatmap axes is followed by a narrow axes that is passed on as its `cbar_ax`.
axes_positions = [
    [0.05, 0.45, 0.23, 0.23],
    [0.30, 0.45, 0.02, 0.23],
    [0.55, 0.45, 0.23, 0.23],
    [0.80, 0.45, 0.02, 0.23],
    [0.05, 0.05, 0.23, 0.23],
    [0.30, 0.05, 0.02, 0.23],
    [0.55, 0.05, 0.23, 0.23],
    [0.80, 0.05, 0.02, 0.23],
]
for layout_ax, layout_position in zip(axs, axes_positions):
    layout_ax.set_position(layout_position)


def _ordered_confusion_matrix(true_labels, predicted_labels):
    """Build a labelled confusion matrix and reorder it to `cell_order`.

    Rows are true labels, columns are predicted labels. The matrix is computed
    over `encoder.classes_` and then reindexed so rows and columns follow
    `cell_order`.
    """
    counts = pd.DataFrame(
        confusion_matrix(true_labels, predicted_labels, labels=encoder.classes_),
        index=encoder.classes_,
        columns=encoder.classes_,
    )
    return counts.loc[cell_order, cell_order]


# AnnoSpat: predicted and true labels are read from CSV as 'label' / 'cell_type' columns.
annospat_labels = pd.read_csv(f'{project_path}/AnnoSpat/outputdir/trte_labels_ELM_IMC_T1D_AnnoSpat.csv', index_col=0).loc[:, 'label']
annospat_true = pd.read_csv(f'{project_path}/AnnoSpat/annospat_true_labels.csv', index_col=0).loc[:, 'cell_type']
confusion_matrix_annospat = _ordered_confusion_matrix(annospat_true, annospat_labels)

# MAPS: both label vectors are passed through `encoder.inverse_transform` before comparison.
maps_labels = pd.read_csv(f'{project_path}/MAPS/results/cell_phenotyping/pred_labels.csv', index_col=0).iloc[:, 0]
maps_true = pd.read_csv(f'{project_path}/MAPS/data/cell_phenotyping/test_features.csv', index_col=0).loc[:, 'cell_label']
maps_labels = encoder.inverse_transform(maps_labels)
maps_true = encoder.inverse_transform(maps_true)
confusion_matrix_maps = _ordered_confusion_matrix(maps_true, maps_labels)

# Astir: predicted and true labels are both read from a 'cell_type' column.
astir_labels = pd.read_csv(f'{project_path}/Astir/astir_labels.csv', index_col=0).loc[:, 'cell_type']
astir_true = pd.read_csv(f'{project_path}/Astir/astir_true_labels.csv').loc[:, 'cell_type']
confusion_matrix_astir = _ordered_confusion_matrix(astir_true, astir_labels)


# (matrix, method name, heatmap axes, colour-bar axes) for each benchmarked method
benchmark_panels = [
    (confusion_matrix_astir, 'Astir', axs[0], axs[1]),
    (confusion_matrix_annospat, 'AnnoSpat', axs[2], axs[3]),
    (confusion_matrix_maps, 'MAPS', axs[4], axs[5]),
]

for panel_matrix, method_name, heatmap_ax, colorbar_ax in benchmark_panels:
    plot_confusion_matrix(panel_matrix, method_name, normalization='col', ax=heatmap_ax, cbar_ax=colorbar_ax)

# Shrink and rotate the tick labels, then name both axes of every heatmap.
for panel_matrix, method_name, heatmap_ax, colorbar_ax in benchmark_panels:
    heatmap_ax.set_xticklabels(heatmap_ax.get_xticklabels(), fontsize=small_fontsize, rotation=90)
    heatmap_ax.set_yticklabels(heatmap_ax.get_yticklabels(), fontsize=small_fontsize, rotation=0)
    heatmap_ax.set_ylabel('Conventional annotation', fontsize=default_fontsize)
    heatmap_ax.set_xlabel(method_name, fontsize=default_fontsize)

# The fourth heatmap / colour-bar slot is unused.
axs[6].set_visible(False)
axs[7].set_visible(False)
