In [None]:
import os
# Make the project's code directory the working directory. The relative output paths used
# in later cells (e.g. 'temp/_fig_sup_4') are resolved against it; presumably the
# project-local imports (`generic`, `mds`, ...) rely on it as well — TODO confirm.
# NOTE(review): hard-coded absolute, user-specific path; the notebook only runs on a
# machine that has this directory.
os.chdir('/home/orenmil/raid/mds/240801_standalone_blood_aging/code')
# print(os.getcwd())


import sys
import shutil
# caution: path[0] is reserved for script path (or '' in REPL)

import importlib
import itertools
import metacells as mc
import anndata as ad
import pandas as pd
import numpy as np
import scipy.cluster.hierarchy
import scipy.spatial.distance
import matplotlib
import seaborn as sb
import matplotlib.pyplot as plt
import pathlib
import logging
import re
import collections
import sklearn
import pickle
import sklearn.neighbors

# from generic import gene_module_utils
from generic import generic_utils
from mds import mds_analysis_params
# from mds import ipssm_utils
from mds import mds_analysis
from generic import mc_utils
from generic import hg38_utils
from mds import arch_mutation_interface_and_utils
from mds import clinical_data_interface_and_utils
from sc_rna_seq_preprocessing import sc_rna_seq_preprocessing_params
from mds import lateral_and_noisy_genes

# Global matplotlib defaults: draw patches without an outline.
plt.rcParams["patch.force_edgecolor"] = False
plt.rcParams['patch.linewidth'] = 0
plt.rcParams['patch.edgecolor'] = 'none'
# plt.rcParams['scatter.edgecolors'] = 'black' # didn't affect sb.scatterplot

# Interactive figures (the `widget` backend comes from ipympl).
%matplotlib widget
# Pick up edits to imported modules without restarting the kernel.
%load_ext autoreload
%autoreload 2

def get_mds_params():
    """Return `mds_analysis_params.MDS_ANALYSIS_PARAMS`.

    The attribute is looked up on the module at call time, so the value
    returned is whatever the module currently holds (the notebook subscripts
    the result with string keys such as 'nimrod_oren_atlas').
    """
    params = mds_analysis_params.MDS_ANALYSIS_PARAMS
    return params


def get_sc_rna_seq_preprocessing_params():
    """Return `sc_rna_seq_preprocessing_params.SC_RNA_SEQ_PREPROCESSING_PARAMS`.

    The attribute is looked up on the module at call time, so the value
    returned is whatever the module currently holds.
    """
    params = sc_rna_seq_preprocessing_params.SC_RNA_SEQ_PREPROCESSING_PARAMS
    return params

# Create a throwaway figure and close it (together with any other open figure).
# NOTE(review): presumably a warm-up for the interactive widget backend — the reason is
# not visible here; confirm before removing.
plt.subplots()
plt.close('all')

In [None]:
# Load the reference atlas: `n2_mc_ad` via the project helper and `n2_c_ad` from the
# h5ad file named in the atlas parameters.
mds_params = get_mds_params()
n2_atlas_params = mds_params['nimrod_oren_atlas']
n2_mc_ad = mc_utils.get_atlas_ad(n2_atlas_params)
n2_c_ad = ad.read_h5ad(n2_atlas_params['c_ad_file_path'])

# the following call is needed to change the unassigned from white to light grey...
c_state_kwargs = dict(
    cell_state_and_info_list=mds_params['pb_cd34_enriched_cell_state_and_c_info_list'],
    mask_and_info_list=mds_params['pb_cd34_enriched_mask_and_c_info_list'],
    cell_type_colors_csv_file_path=mds_params['cell_type_colors_csv_file_path'],
    only_add_c_state=True,
)
mc_utils.add_c_state_and_mc_c_state_stats(n2_mc_ad, n2_c_ad, **c_state_kwargs)

In [None]:
out_dir_path = 'temp/_fig_sup_4'
pathlib.Path(out_dir_path).mkdir(parents=True, exist_ok=True)

# Value used as `skip_plot` by most entries of the gating configuration defined in this
# cell. Presumably True lets those scatters be skipped — confirm in mds_analysis.
allow_skipping_plots = False

from mds import pb_cd34_c_score_threshs

# Work on a copy so that the merged state names do not end up in the loaded atlas (n2_c_ad).
c_ad = n2_c_ad.copy()
b_states = ['aged-B?', 'CCDC88A-high-B', 'naive-B', 'memory-B']
monocyte_states = ['cMonocyte', 'intermMonocyte', 'ncMonocyte']
dc_states = ['AS-DC', 'pDC', 'cDC?', 'high_cMonocyte_sig_cDC2-prog?', 'cDC2-prog?', 'DC-prog?', 'cDC1']
nkt_states = ['NK', 'T']

# Collapse fine-grained non-HSPC states into coarse groups and rename a few states.
# The result is assigned back instead of calling `.replace(..., inplace=True)` on the
# selected column: the in-place form acts on an intermediate Series, which pandas warns
# about (chained assignment) and which leaves `c_ad.obs` untouched under Copy-on-Write.
c_ad.obs['c_state'] = c_ad.obs['c_state'].replace({
    **{x: 'B' for x in b_states},
    **{x: 'Monocyte' for x in monocyte_states},
    **{x: 'DC' for x in dc_states},
    **{x: 'NKT' for x in nkt_states},
    'monocyte_B_doublet': 'Doublet',
    'plasmablast_ighm_ighg': 'plasmablast',
    'state_unassigned': 'unassigned',
})

# Re-derive `c_state_color` for the (now merged) state names from the colour table.
c_ad.obs = generic_utils.merge_preserving_df1_index_and_row_order(
    c_ad.obs.drop(columns='c_state_color'),
    pd.read_csv(get_mds_params()['cell_type_colors_csv_file_path']).rename(
        columns={'cell_type': 'c_state', 'color': 'c_state_color'}),
)



def get_name_to_mask_func(c_ad):
    """Build named boolean masks over the cells of `c_ad` from `c_ad.obs['c_state']`.

    Returns a dict with, in this order:
      * 'HSPC': cells whose state is in `mds_analysis_params.PB_HSPC_STATE_NAMES`,
      * 'non_lymphoid_HSPC': cells whose state is in
        `mds_analysis_params.PB_NON_LYMPHOID_HSPC_STATE_NAMES`,
      * one entry per observed state name, selecting the cells in that state.

    States are compared as strings. A state literally named 'HSPC' or
    'non_lymphoid_HSPC' would replace the corresponding group mask.
    """
    c_states = c_ad.obs['c_state'].astype(str)
    name_to_mask = {
        'HSPC': c_states.isin(mds_analysis_params.PB_HSPC_STATE_NAMES),
        'non_lymphoid_HSPC': c_states.isin(mds_analysis_params.PB_NON_LYMPHOID_HSPC_STATE_NAMES),
    }
    for state_name in c_states.unique():
        name_to_mask[state_name] = c_states == state_name
    return name_to_mask

# Ordered gating configuration passed to mds_analysis.plot_gate_c_state_scatterplots.
# Each entry describes one 2D scatter of two scores with a threshold on each axis:
#   x, y               - names of the two scores (at least
#                        'higher_in_clp_m_than_all_myeloid_hspcs' is read as a c_ad.obs
#                        column elsewhere in this notebook; presumably all of them are).
#   x_thresh, y_thresh - thresholds taken from pb_cd34_c_score_threshs.
#   background_name    - presumably a label for the set of cells shown in the scatter.
#   get_background_mask_func - callable(scatter_name_to_mask_dict, c_ad) returning the cell
#                        mask for this scatter; absent for the first entry. Most of them
#                        combine quadrants ('both_low', 'low_x_high_y', 'high_x_low_y',
#                        'both_high') of an EARLIER entry, so the entry order matters.
#   gates_to_target_mask_name - maps a tuple of quadrant names to a key of the name_to_mask
#                        dict built by get_name_to_mask_func. NOTE(review): presumably used to
#                        compare the gated cells with the assigned states — confirm in
#                        mds_analysis.
#   skip_plot          - set from allow_skipping_plots (not set for the last entry).
scatter_name_to_info = {
    'separate_HSPC_from_DC_and_monocyte_and_B': dict(
        x='higher_in_dc_than_all_hspcs',
        y='higher_in_b_than_all_other_hspcs',
        x_thresh=pb_cd34_c_score_threshs.HIGHER_IN_DC_THAN_ALL_HSPCS_THRESH,
        y_thresh=pb_cd34_c_score_threshs.HIGHER_IN_B_THAN_ALL_OTHER_HSPCS_THRESH,
        background_name='all',
        # gates_to_target_mask_name={('both_low',): 'HSPC'},
        skip_plot=allow_skipping_plots,
    ),
    'separate_HSPC_from_NKT_and_endothel': dict(
        x='nkt_somewhat_specific',
        y='endothel_somewhat_specific',
        x_thresh=pb_cd34_c_score_threshs.NKT_SOMEWHAT_SPECIFIC_THRESH,
        y_thresh=pb_cd34_c_score_threshs.HIGHER_IN_ENDOTHEL_THAN_ALL_EXCEPT_UNKNOWN_THRESH,
        background_name='non-B/Monocyte/DC',
        get_background_mask_func=lambda scatter_name_to_mask_dict, c_ad: scatter_name_to_mask_dict['separate_HSPC_from_DC_and_monocyte_and_B']['both_low'],
        gates_to_target_mask_name={('both_low',): 'HSPC'},
        skip_plot=allow_skipping_plots,
    ),
    'identify_pro_B_and_pre_B': dict(
        x='higher_in_pre_b_than_all_hspcs',
        y='higher_in_pre_b_than_pro_b',
        x_thresh=pb_cd34_c_score_threshs.HIGHER_IN_PRE_B_THAN_ALL_HSPCS_THRESH,
        y_thresh=pb_cd34_c_score_threshs.HIGHER_IN_PRE_B_THAN_PRO_B_THRESH,
        background_name='HSPC',
        # get_background_mask_func=lambda scatter_name_to_mask_dict, c_ad: scatter_name_to_mask_dict['separate_HSPC_from_NKT_and_endothel']['both_low'],
        # The background comes from a precomputed obs column rather than from the previous gate.
        get_background_mask_func=lambda scatter_name_to_mask_dict, c_ad: c_ad.obs['not_dc_nkt_monocyte_endothel_b_lamp3'],
        gates_to_target_mask_name={('both_high',): 'pre-B?', ('high_x_low_y',): 'pro-B?'},
        skip_plot=allow_skipping_plots,
    ),
    'separate_non_lymphoid_HSPC_from_CLP_and_NKTDP': dict(
        x='higher_in_clp_m_than_all_myeloid_hspcs',
        y='higher_in_nktdp_than_all_myeloid_hspcs',
        x_thresh=pb_cd34_c_score_threshs.HIGHER_IN_CLP_M_THAN_ALL_MYELOID_HSPCS_THRESH,
        y_thresh=pb_cd34_c_score_threshs.HIGHER_IN_NKTDP_THAN_ALL_MYELOID_HSPCS_THRESH,
        background_name='non-B-prog HSPC',
        get_background_mask_func=lambda scatter_name_to_mask_dict, c_ad: (
            scatter_name_to_mask_dict['identify_pro_B_and_pre_B']['both_low'] |
            scatter_name_to_mask_dict['identify_pro_B_and_pre_B']['low_x_high_y']
        ),
        gates_to_target_mask_name={('both_low',): 'non_lymphoid_HSPC'},
        skip_plot=allow_skipping_plots,
    ),
    # NOTE(review): the key below is spelled 'indentify'; it is a runtime string that may be
    # used downstream (e.g. in output names), so it is left as is.
    'indentify_NKTDP_and_CLP': dict(
        x='higher_in_clp_m_than_all_myeloid_hspcs',
        y='higher_in_nktdp_than_clp',
        x_thresh=pb_cd34_c_score_threshs.HIGHER_IN_CLP_M_THAN_ALL_MYELOID_HSPCS_THRESH,
        y_thresh=pb_cd34_c_score_threshs.HIGHER_IN_NKTDP_THAN_CLP_THRESH,
        background_name='high-NKTDP-sig or high-CLP-sig non-B-prog HSPC',
        get_background_mask_func=lambda scatter_name_to_mask_dict, c_ad: (
            scatter_name_to_mask_dict['separate_non_lymphoid_HSPC_from_CLP_and_NKTDP']['low_x_high_y'] |
            scatter_name_to_mask_dict['separate_non_lymphoid_HSPC_from_CLP_and_NKTDP']['both_high'] |
            scatter_name_to_mask_dict['separate_non_lymphoid_HSPC_from_CLP_and_NKTDP']['high_x_low_y']
        ),
        gates_to_target_mask_name={('low_x_high_y','both_high'): 'NKTDP', ('high_x_low_y',): 'CLP'},
        skip_plot=allow_skipping_plots,
    ),
    'identify_GMP-L_and_BEMP': dict(
        x='higher_in_bemp_than_mebemp_l_and_eryp',
        y='higher_in_gmp_l_than_all_other_hspcs',
        x_thresh=pb_cd34_c_score_threshs.HIGHER_IN_BEMP_THAN_MEBEMP_L_AND_ERYP_THRESH,
        y_thresh=pb_cd34_c_score_threshs.HIGHER_IN_GMP_L_THAN_ALL_OTHER_HSPCS_THRESH,
        background_name='non-lymphoid HSPC',
        # get_background_mask_func=lambda scatter_name_to_mask_dict, c_ad: c_ad.obs['not_dc_nkt_monocyte_endothel_b_lamp3'],
        get_background_mask_func=lambda scatter_name_to_mask_dict, c_ad: scatter_name_to_mask_dict['separate_non_lymphoid_HSPC_from_CLP_and_NKTDP']['both_low'],
        gates_to_target_mask_name={('high_x_low_y',): 'BEMP', ('low_x_high_y','both_high'): 'GMP-L'},
        skip_plot=allow_skipping_plots,
    ),
    'identify_MKP_and_separate_by_ERYP_sig': dict(
        x='higher_in_mkp_than_mebemp_l_and_eryp',
        y='higher_in_ep_than_mpp',
        x_thresh=pb_cd34_c_score_threshs.HIGHER_IN_MKP_THAN_MEBEMP_L_AND_ERYP_THRESH,
        y_thresh=pb_cd34_c_score_threshs.HIGHER_IN_EP_THAN_MPP_THRESH,
        background_name='non-lymphoid/BEMP/GMP-L HSPC',
        # get_background_mask_func=lambda scatter_name_to_mask_dict, c_ad: c_ad.obs['not_dc_nkt_monocyte_endothel_b_lamp3'],
        get_background_mask_func=lambda scatter_name_to_mask_dict, c_ad: scatter_name_to_mask_dict['identify_GMP-L_and_BEMP']['both_low'],
        gates_to_target_mask_name={('high_x_low_y','both_high'): 'MKP'},
        skip_plot=allow_skipping_plots,
    ),
    'identify_MEBEMP-L_and_HSC_MPP': dict(
        x='higher_in_mebemp_l_than_mpp',
        y='higher_in_ep_than_mpp',
        x_thresh=pb_cd34_c_score_threshs.HIGHER_IN_MEBEMP_L_THAN_MPP_THRESH,
        y_thresh=pb_cd34_c_score_threshs.HIGHER_IN_EP_THAN_MPP_THRESH,
        background_name='non-lymphoid/BEMP/GMP-L/MKP HSPC',
        # get_background_mask_func=lambda scatter_name_to_mask_dict, c_ad: c_ad.obs['not_dc_nkt_monocyte_endothel_b_lamp3'],
        get_background_mask_func=lambda scatter_name_to_mask_dict, c_ad: (
            scatter_name_to_mask_dict['identify_MKP_and_separate_by_ERYP_sig']['both_low'] |
            scatter_name_to_mask_dict['identify_MKP_and_separate_by_ERYP_sig']['low_x_high_y']
        ),
        gates_to_target_mask_name={('high_x_low_y',): 'MEBEMP-L', ('both_low',): 'HSC_MPP'},
        skip_plot=allow_skipping_plots,
    ),
    'identify_ERYP_and_high_MPP_sig_ERYP': dict(
        x='higher_in_mpp_than_ep',
        y='higher_in_mebemp_l_than_mpp',
        x_thresh=pb_cd34_c_score_threshs.HIGHER_IN_MPP_THAN_EP_THRESH,
        y_thresh=pb_cd34_c_score_threshs.HIGHER_IN_MEBEMP_L_THAN_MPP_THRESH,
        background_name='high-ERYP-sig non-lymphoid/BEMP/GMP-L/MKP HSPC',
        # get_background_mask_func=lambda scatter_name_to_mask_dict, c_ad: c_ad.obs['not_dc_nkt_monocyte_endothel_b_lamp3'],
        get_background_mask_func=lambda scatter_name_to_mask_dict, c_ad: (
            scatter_name_to_mask_dict['identify_MEBEMP-L_and_HSC_MPP']['both_high'] |
            scatter_name_to_mask_dict['identify_MEBEMP-L_and_HSC_MPP']['low_x_high_y']
        ),
        gates_to_target_mask_name={('low_x_high_y','both_low'): 'ERYP', ('both_high','high_x_low_y'): 'high_MPP_sig_ERYP'},
        # skip_plot is deliberately not set for this last entry.
        # skip_plot=allow_skipping_plots,
    ),
}




# Palette for the merged cell states (presumably state name -> colour — confirm in mc_utils).
# The assertion checks that no entry is 'white'.
palette = mc_utils.get_palette(c_ad, 'c_state')
assert 'white' not in palette.values()

# Named boolean masks: 'HSPC', 'non_lymphoid_HSPC' and one per observed state.
name_to_mask = get_name_to_mask_func(c_ad)

# Draw the gating scatters configured in scatter_name_to_info.
# NOTE(review): presumably writes the figures into out_dir_path — confirm in mds_analysis.
mds_analysis.plot_gate_c_state_scatterplots(
    c_ad=c_ad,
    out_dir_path=out_dir_path,
    scatter_name_to_info=scatter_name_to_info,
    palette=palette,
    name_to_mask=name_to_mask,
)



In [None]:
out_dir_path = 'temp/_fig_sup_4'
pathlib.Path(out_dir_path).mkdir(parents=True, exist_ok=True)

# Keep only the states that actually occur in c_ad, in the project's canonical order.
observed_c_states = c_ad.obs['c_state'].unique()
ordered_hspc_c_states = [
    state for state in mds_analysis_params.ORDERED_PB_HSPC_STATE_NAMES
    if state in observed_c_states
]
ordered_non_hspc_c_states = [
    state for state in mds_analysis_params.ORDERED_PB_NON_HSPC_STATE_NAMES
    if state in observed_c_states
]

# Legend order: HSPC states reversed, followed by the non-HSPC states.
legend_states = list(reversed(ordered_hspc_c_states)) + ordered_non_hspc_c_states
cell_type_color_df = pd.read_csv(get_mds_params()['cell_type_colors_csv_file_path'])
state_color_df = generic_utils.merge_preserving_df1_index_and_row_order(
    pd.DataFrame({'state': legend_states}),
    cell_type_color_df.rename(columns={'cell_type': 'state'}),
)

# One dummy point per state: only the legend entries are of interest in this figure.
fig, ax = plt.subplots(figsize=(10, 10))
for record in state_color_df.to_records(index=False).tolist():
    # record[0] is the state name, record[1] the next column (the colour).
    ax.scatter([0.5], [0.5], s=40, color=record[1], label=record[0])

ax.legend(fontsize='large')
sb.move_legend(ax, "upper left", bbox_to_anchor=(1, 1), fontsize='large')
fig.tight_layout()
legend_png_path = os.path.join(out_dir_path, 'EDF_9A_cell_state_circle_legend.png')
fig.savefig(legend_png_path, dpi=300)

In [None]:
plt.close('all')
out_dir_path = 'temp/_fig_sup_4'
pathlib.Path(out_dir_path).mkdir(parents=True, exist_ok=True)

# Set to True to title each small-multiple panel with its cell state; this also changes the
# file name of the saved small-multiples figure.
add_ax_titles = False


# Work on a copy of the atlas metacells and merge fine-grained states into coarse groups.
mc_ad = n2_mc_ad.copy()
b_states = ['aged-B?', 'CCDC88A-high-B', 'naive-B', 'memory-B']
monocyte_states = ['cMonocyte', 'intermMonocyte', 'ncMonocyte']
dc_states = ['AS-DC', 'pDC', 'cDC?']
nkt_states = ['NK', 'T']
# The result is assigned back instead of calling `.replace(..., inplace=True)` on the
# selected column: the in-place form acts on an intermediate Series, which pandas warns
# about (chained assignment) and which leaves `mc_ad.obs` untouched under Copy-on-Write.
mc_ad.obs['state'] = mc_ad.obs['state'].replace({
    **{x: 'B' for x in b_states},
    **{x: 'Monocyte' for x in monocyte_states},
    **{x: 'DC' for x in dc_states},
    **{x: 'NKT' for x in nkt_states},
    'CLP_NKTDP_intermediate?': 'Doublet', # unsure, but anyway only 3 of those currently...
})
# Re-derive `state_color` for the (now merged) state names from the colour table.
mc_ad.obs = generic_utils.merge_preserving_df1_index_and_row_order(
    mc_ad.obs.drop(columns='state_color'),
    pd.read_csv(get_mds_params()['cell_type_colors_csv_file_path']).rename(
        columns={'cell_type': 'state', 'color': 'state_color'}),
)

# HSPC states that make up more than 1.5% of the atlas HSPC cells.
c_state_count_series = n2_c_ad.obs.loc[n2_c_ad.obs['c_state'].isin(mds_analysis_params.PB_HSPC_STATE_NAMES), 'c_state'].astype(str).value_counts(normalize=True)
common_c_states = list((c_state_count_series > 0.015).loc[lambda x: x].index)

# Figure 1: UMAP of the metacells with a legend.
fig, ax = mc_utils.plot_manifold_umap(mc_ad, legend=True, move_legend=True, legend_labels_ordered=mds_analysis_params.ORDERED_CELL_STATE_NAMES)
generic_utils.make_all_spines_and_x_and_y_axes_invisible(ax)
# ax.get_yaxis().set_visible(False)
# ax.get_xaxis().set_visible(False)

# Panel order for the small multiples; it must cover exactly the common states computed
# above, so the assertion fails if the atlas composition changes.
ordered_common_c_states = [
    'HSC_MPP',
    'MEBEMP-L',
    'ERYP',
    'BEMP',
    'CLP',
    'NKTDP',
]
assert set(ordered_common_c_states) == set(common_c_states)
fig.savefig(os.path.join(out_dir_path, 'EDF_9B_cHSPC_79_normal_atlas_umap.png'))

# Figure 2: one UMAP per common state, coloured by the `c_state_<state>_frac` attribute
# on a fixed 0-1 scale.
fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(4, 7), gridspec_kw=dict(hspace=0.3))
fig.subplots_adjust(top=1.1, bottom=0.01, left=0.01, right=0.99) # seems like top is ignored.
for c_state, ax in zip(ordered_common_c_states, axes.flatten()):
    plot_manifold_umap_kwargs = dict(
        mc_ad=mc_ad, color_by=f'c_state_{c_state}_frac', get_palette_kwargs=dict(numeric_colormap_name='gist_heat_r'),
        # plot_colorbar=True,
        hue_norm=(0,1),
        fig_colorbar_kwargs=dict(label=f'{c_state} fraction'),
        s=6,
    )
    fig, ax = mc_utils.plot_manifold_umap(ax=ax, **plot_manifold_umap_kwargs)
    generic_utils.make_all_spines_and_x_and_y_axes_invisible(ax)
    fig.subplots_adjust(top=0.9)
    if add_ax_titles:
        ax.set_title(c_state)
ax_title_repr = '' if add_ax_titles else '_without_titles'
fig.savefig(os.path.join(out_dir_path, f'EDF_9B_cHSPC_79_normal_atlas_umap_colored_by_c_state_frac{ax_title_repr}.png'), dpi=300)

# Figure 3: a stand-alone plot with a colorbar. It deliberately reuses the kwargs of the
# LAST panel of the loop above, so its data and colorbar label belong to the last state in
# ordered_common_c_states.
fig, ax = mc_utils.plot_manifold_umap(plot_colorbar=True, **plot_manifold_umap_kwargs)
fig.savefig(os.path.join(out_dir_path, 'EDF_9B_cHSPC_79_normal_atlas_umap_colored_by_c_state_with_colorbar.png'), dpi=300)
plt.close('all')

In [None]:
out_dir_path = 'temp/_fig_sup_4'
pathlib.Path(out_dir_path).mkdir(parents=True, exist_ok=True)

# One row per (donor, chromosome, deletion/duplication) combination, restricted to rows
# with a known chromosome and to donors present in the extended donor feature table.
all_cna_attrs_df_csv_file_path = get_mds_params()['all_cna_attrs_df_csv_file_path']
all_cna_attrs_df = pd.read_csv(all_cna_attrs_df_csv_file_path)
all_cna_attrs_df = all_cna_attrs_df[all_cna_attrs_df['chrom'].notna()]
all_cna_attrs_df = all_cna_attrs_df[['donor_id', 'chrom', 'is_del', 'is_dup']].drop_duplicates()

all_ext_feature_df = pd.read_csv(get_mds_params()['all_ext_donor_feature_df_csv_file_path'])
all_cna_attrs_df = all_cna_attrs_df[all_cna_attrs_df['donor_id'].isin(all_ext_feature_df['donor_id'])]
# all_ext_feature_df[all_ext_feature_df['cna_count'] > 0]
# all_ext_feature_df.sort_values('cna_count', ascending=False).drop_duplicates(subset='donor_id')

ordered_chrom_names = get_mds_params()['karyotype_estimation']['ordered_chrom_names_for_karyotype_estimation']

# Build a donor x chromosome matrix encoded as 0 = no CNA, 0.5 = deletion, 1 = duplication.
flat_dicts = []
for donor_id, curr_df in all_cna_attrs_df.groupby('donor_id'):
    curr_dict = {x: 0 for x in ordered_chrom_names}
    for _, row in curr_df.iterrows():
        assert row[['is_dup', 'is_del']].any()
        curr_dict[row['chrom']] += 0.5 if row['is_del'] else 1
    flat_dicts.append({
        'donor_id': donor_id,
        **curr_dict,
    })
df = pd.DataFrame(flat_dicts)
df.set_index('donor_id', inplace=True)
df = df[ordered_chrom_names]
# Any other value means several CNA rows were accumulated for one donor and chromosome.
assert df.isin([0,0.5,1]).all(axis=None), 'sorry, handling both dup and del in same chrom is not implemented'
# Order donors by the number of affected chromosomes (ascending).
ordered_donor_ids = list((df > 0).sum(axis=1).sort_values().index)
orig_donor_count = len(df)
df = df.loc[ordered_donor_ids]
assert len(df) == orig_donor_count

plt.close('all')
fig, ax = plt.subplots(figsize=(2.7,5))
# Colormap matching the encoding above: 0 -> white, 0.5 -> blue, 1 -> red.
ugly_hack_cna_cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
    '', [(0, 'white'), (0.5, 'blue'), (1, 'red')])
# Chromosomes as rows, donors as columns.
df = df.T
heatmap_obj = sb.heatmap(
    df, ax=ax, cmap=ugly_hack_cna_cmap, square=True, cbar=False,
    # Pin the colour range: otherwise seaborn infers it from the data, and 0.5 (deletion)
    # would be drawn red instead of blue whenever no duplication (value 1) is present.
    vmin=0, vmax=1,
    xticklabels=True,
    # yticklabels=True,
)
fig.subplots_adjust(top=0.95, bottom=0.1)
for _, spine in heatmap_obj.spines.items():
    spine.set_visible(True)
# Thin grid lines between the heatmap cells.
for x in range(len(df.index)):
    ax.axhline(x, color='grey', alpha=0.7, linewidth=0.5)
for x in range(len(df.columns)):
    ax.axvline(x, color='grey', alpha=0.7, linewidth=0.5)
ax.set_ylabel(None)
ax.set_xlabel(None)
if 0:
    ax.legend(
        handles=[matplotlib.patches.Patch(color=x[1], label=x[0]) for x in [('deletion', 'blue'), ('duplication', 'red')]],
        # fontsize='large',
    )
    sb.move_legend(ax, "upper left", bbox_to_anchor=(1, 1), fontsize='large')
# fig.tight_layout()
fig.savefig(
    os.path.join(out_dir_path, 'EDF_9C_cna_heatmap.png'),
    dpi=300,
)

# Sanity check against the expected number of (donor, chromosome) CNAs in the data.
total_cna_count = (df > 0).sum().sum()
print(f'total_cna_count: {total_cna_count}')
assert total_cna_count == 38 # 240729


In [None]:
out_dir_path = 'temp/_fig_sup_4'
pathlib.Path(out_dir_path).mkdir(parents=True, exist_ok=True)

# Feature catalogue (name -> type) and the per-donor feature table.
final_feature_names_and_types_df = pd.read_csv(
    get_mds_params()['final_feature_names_and_types_df_csv_file_path'])

# (exclusive upper bound on the p-value, number of stars), strictest first.
list_of_max_pval_and_star_count = [(0.001, 3), (0.01, 2), (0.05, 1)]

all_ext_feature_df = pd.read_csv(get_mds_params()['all_ext_donor_feature_df_csv_file_path'])
in_mds_cyto_groups_analysis_mask = all_ext_feature_df['in_mds_cyto_groups_analysis']
curr_feature_df = all_ext_feature_df[in_mds_cyto_groups_analysis_mask].copy()
curr_feature_df['diag_class_and_sex'] = curr_feature_df['diagnosis_class'] + ' ' + curr_feature_df['donor_sex']

final_feature_name_to_type = generic_utils.get_dict_mapping_one_df_column_to_other(
    final_feature_names_and_types_df, 'name', 'type')

# Every non-reference "<class> <sex>" group is compared with the same-sex reference group.
ref_class = 'normal'
diag_class_and_sex_to_ref_diag_class_and_sex = {}
for clas in ['cytopenia', 'MDS', 'MDS/MPN']:
    for sex in ('male', 'female'):
        clas_and_sex = f'{clas} {sex}'
        if clas == ref_class:
            continue
        diag_class_and_sex_to_ref_diag_class_and_sex[clas_and_sex] = f'{ref_class} {sex}'

# Group order on the x axis and the x position of each group; the gaps between the
# classes leave room for the separator lines.
ordered_class_and_sex = [
    f'{diag_class} {donor_sex}'
    for diag_class in ('normal', 'cytopenia', 'MDS', 'MDS/MPN')
    for donor_sex in ('male', 'female')
]
class_and_sex_to_pos_in_stripplot = dict(zip(
    ordered_class_and_sex,
    [0, 1, 2.5, 3.5, 5, 6, 7.5, 8.5],
))
male_female_vline_xs = [0.5, 3, 5.5, 8]
class_vline_xs = [1.75, 4.25, 6.75]
ordered_positions = [class_and_sex_to_pos_in_stripplot[x] for x in ordered_class_and_sex]
assert pd.Series(ordered_positions).is_monotonic_increasing

# For every feature of type 'log_c_state_freq', compare each non-reference
# "<class> <sex>" group with its same-sex reference group (two-sided test via
# generic_utils.perform_mw_test) and collect the p-values.
flat_dicts = []
for feature in sorted(final_feature_names_and_types_df.loc[final_feature_names_and_types_df['type'] == 'log_c_state_freq', 'name']):
    # print(feature)
    # Donors without a value for this feature are left out of its tests.
    curr_df = curr_feature_df[curr_feature_df[feature].notna()]


    for clas_and_sex, ref_clas_and_sex in diag_class_and_sex_to_ref_diag_class_and_sex.items():
        clas_vals = curr_df.loc[curr_df['diag_class_and_sex'] == clas_and_sex, feature]
        ref_vals = curr_df.loc[curr_df['diag_class_and_sex'] == ref_clas_and_sex, feature]
        # The test is skipped when either group is empty.
        if (len(clas_vals) >= 1) and (len(ref_vals) >= 1):
            mw_test_res = generic_utils.perform_mw_test(
                clas_vals,
                ref_vals,
                alternative='two-sided',
            )
            flat_dicts.append({'diag_class_and_sex': clas_and_sex, 'feature': feature, 'feature_type': final_feature_name_to_type[feature], 'test_type': 'mw', 'pval': mw_test_res['pvalue']})

# Multiple-testing adjustment over ALL collected tests (all features and groups) using
# scipy.stats.false_discovery_control. The adjusted values come back as an array in the
# order of the selected rows, which is the order they are written back in.
diff_c_freq_across_diag_classes_df = pd.DataFrame(flat_dicts)
diff_c_freq_across_diag_classes_df.sort_values('pval', inplace=True)
mw_pval_mask = diff_c_freq_across_diag_classes_df['test_type'] == 'mw'
diff_c_freq_across_diag_classes_df.loc[mw_pval_mask, 'mw_bh_adjusted_pval'] = scipy.stats.false_discovery_control(
    diff_c_freq_across_diag_classes_df.loc[mw_pval_mask, 'pval'])

# x position of every donor: the position of its group plus a reproducible horizontal jitter.
# NOTE: np.random.seed(0) resets the GLOBAL numpy random state.
curr_feature_df['diag_class_and_sex_pos_in_stripplot'] = curr_feature_df['diag_class_and_sex'].map(class_and_sex_to_pos_in_stripplot)
np.random.seed(0)
curr_feature_df['jittered_diag_class_and_sex_pos_in_stripplot'] = curr_feature_df['diag_class_and_sex_pos_in_stripplot'] + np.random.uniform(-0.3, 0.3, len(curr_feature_df))


# One strip plot per selected feature: donors are grouped by diagnosis class and sex, and
# stars mark the groups that differ from their same-sex reference group.
for attr in [
    'log_c_CLP',
    # 'log_c_pre-B?',
    # 'log_c_pro-B?',
    # 'log_c_MKP',
    # 'log_c_GMP-L',
]:
    plt.close('all')
    fig, ax = plt.subplots(figsize=(4, 2.2))

    assert curr_feature_df[attr].notna().all()

    # Every expected "<class> <sex>" group must be present in the data.
    curr_observed_class_and_sex = set(curr_feature_df['diag_class_and_sex'].tolist())
    curr_ordered_class_and_sex = [x for x in ordered_class_and_sex if x in curr_observed_class_and_sex]
    assert curr_ordered_class_and_sex == ordered_class_and_sex

    assert curr_feature_df['donor_id'].is_unique
    print(f'len(curr_feature_df): {len(curr_feature_df)}')

    sb.scatterplot(
        data=curr_feature_df,
        x='jittered_diag_class_and_sex_pos_in_stripplot',
        y=attr,
        ax=ax,
        hue='diagnosis_class',
        palette=clinical_data_interface_and_utils.DIAGNOSIS_CLASS_TO_COLOR,
        edgecolor='k',
        s=20,
    )
    ax.get_xaxis().set_visible(False)

    # Add headroom above the data for the significance stars.
    orig_y_lim = ax.get_ylim()
    y_lim_range = orig_y_lim[1] - orig_y_lim[0]
    ax.set_ylim(orig_y_lim[0], orig_y_lim[1] + 0.07 * y_lim_range)

    for clas_and_sex in curr_ordered_class_and_sex:
        if clas_and_sex in diag_class_and_sex_to_ref_diag_class_and_sex:
            # Recompute the test here and check it against the stored p-value, so that the
            # adjusted p-value read below is known to belong to this exact comparison.
            mw_test_res = generic_utils.perform_mw_test(
                curr_feature_df.loc[curr_feature_df['diag_class_and_sex'] == clas_and_sex, attr],
                curr_feature_df.loc[curr_feature_df['diag_class_and_sex'] == diag_class_and_sex_to_ref_diag_class_and_sex[clas_and_sex], attr],
                alternative='two-sided',
            )
            curr_pval = mw_test_res['pvalue']

            row_in_diff_feature_df = diff_c_freq_across_diag_classes_df[
                (diff_c_freq_across_diag_classes_df['test_type'] == 'mw') &
                (diff_c_freq_across_diag_classes_df['feature'] == attr) &
                (diff_c_freq_across_diag_classes_df['diag_class_and_sex'] == clas_and_sex)
            ]
            assert len(row_in_diff_feature_df) == 1
            row_in_diff_feature_df = row_in_diff_feature_df.iloc[0]
            assert np.isclose(row_in_diff_feature_df['pval'], curr_pval)
            bh_adjusted_pval = row_in_diff_feature_df['mw_bh_adjusted_pval']

            # raise
            # assert  == final_feature_name_to_type[attr]

            # Thresholds are ordered strictest first, so the first match gives the star count.
            for max_pval, star_count in list_of_max_pval_and_star_count:
                if bh_adjusted_pval < max_pval:
                    # Shift the stars slightly away from the male/female separator line.
                    x_pos = class_and_sex_to_pos_in_stripplot[clas_and_sex]
                    x_pos = x_pos + 0.125 * (1 if clas_and_sex.endswith('female') else -1)
                    ax.annotate(
                        # '$' + '*' * star_count + '$',
                        # '$*$' * star_count,
                        '*' * star_count,
                        # '$' + '\\star' * star_count + '$', # looks bad for some reason
                        xy=(x_pos, 0.944), xycoords=('data', 'axes fraction'), color='k', ha='center', va='center', fontsize='large',
                    )

                    break



    ax.set_ylabel(None)
    ax.set_xlabel(None)
    ax.set_title(attr)
    if 0:
        ax.set_xticklabels(
            ax.get_xticklabels(),
            rotation=90, ha='center',
            # rotation_mode='anchor',
            # fontsize='small',
        )
    # Dashed lines separate male from female, solid lines separate the diagnosis classes.
    for x in male_female_vline_xs:
        ax.axvline(x, color='grey', alpha=0.5, linestyle='--')
    for x in class_vline_xs:
        ax.axvline(x, color='black', alpha=0.9, linestyle='-')

    # if plot_legend:
    #     legend = ax.legend(fontsize='x-large')
    #     legend.get_frame().set_alpha(None)
    ax.get_legend().remove()
    # fig.tight_layout()
    fig.savefig(
        os.path.join(out_dir_path, f'EDF_10B_{attr}_stripplot.png'),
        dpi=300,
    )
    # break


In [None]:
# Compare each selected feature between the early and the late sample of the same donor
# (one scatter and one linear fit per feature).
out_dir_path = 'temp/_fig_sup_4'
pathlib.Path(out_dir_path).mkdir(parents=True, exist_ok=True)

# Flags are toggled by editing: the LAST assignment of each pair is the one in effect.
show_donor_ids = True
show_donor_ids = False

plot_y_equals_x = True
plot_y_equals_x = False

curr_df = pd.read_csv(get_mds_params()['all_ext_donor_feature_df_csv_file_path'])

curr_df = curr_df[curr_df['in_within_state_sig_bio_rep_analysis']]
print(f'len(curr_df): {len(curr_df)}')
print(curr_df.drop_duplicates(subset='donor_id')['latest_diagnosis_class'].value_counts())

final_feature_names_and_types_df = pd.read_csv(get_mds_params()['final_feature_names_and_types_df_csv_file_path'])

# The columns read below carry an `_early` / `_late` suffix per feature (plus
# `diagnosis_class_late` and `donor_id`).
# NOTE(review): presumably one row per donor pairing its earliest and latest sample —
# confirm in sc_rna_seq_preprocessing_params.
bio_rep_df = sc_rna_seq_preprocessing_params.get_earliest_and_latest_bio_rep_df(curr_df)
print(f'len(bio_rep_df): {len(bio_rep_df)}')

flat_dicts = []
for col in sorted([
    # *final_feature_names_and_types_df.loc[final_feature_names_and_types_df['type'].isin(['gene_prog_excluding_mhc_ii', 'mhc_ii_gene_prog', 'log_c_state_freq']), 'name'],
    'mebemp_l_hla_sig_MEBEMP-L_0.5',
    'bemp_cd74_sig_BEMP_0.5',
    's_phase_sig_MEBEMP-L_0.5',

    # 'prss2_sig_CLP_0.5',
    # 'gmp_l_elane_sig_GMP-L_0.5',
]):

    x_col = f'{col}_early'
    y_col = f'{col}_late'


    # Keep only the rows that have both values. Note that `curr_df` is rebound here: from
    # this point on it holds the per-feature subset, not the donor feature table loaded above.
    curr_df = bio_rep_df[bio_rep_df[[x_col, y_col]].notna().all(axis=1)]
    bio_rep_count = len(curr_df)
    print(len(bio_rep_df))
    print(f'{col} final bio_rep_count: {bio_rep_count}')

    fig, ax = plt.subplots(figsize=(3.5,3.5))
    sb.scatterplot(
        data=curr_df,
        x=x_col,
        y=y_col,
        hue='diagnosis_class_late',
        palette=clinical_data_interface_and_utils.DIAGNOSIS_CLASS_TO_COLOR,
        ax=ax,
        legend=False,
    )
    # Least-squares fit of late on early; r and the p-value go into the title.
    slope, intercept, r_value, p_value, std_err_of_estimated_gradient = scipy.stats.linregress(
        curr_df[x_col], curr_df[y_col],
        # alternative='two-sided',
    )
    generic_utils.plot_line_on_ax(ax, slope=slope, offset=intercept, color='grey', alpha=0.5, linestyle='--')
    if plot_y_equals_x:
        generic_utils.plot_y_equals_x_line_on_ax(ax, color='k', alpha=0.5, linestyle=':')
    linear_fit_repr = f'n={len(curr_df)}, r={r_value:.2f}, pval={p_value:.0E}'
    print(f'slope={slope:.3f}, intercept={intercept:.3f}')
    ax.set_title(linear_fit_repr)
    if show_donor_ids:
        for i, row in curr_df.iterrows():
            ax.text(row[x_col], row[y_col], row['donor_id'])
    fig.tight_layout()
    fig.savefig(os.path.join(out_dir_path, f'EDF_10D_{generic_utils.strip_special_chars_on_edges_and_replace_others_with_underscores(col)}_bio_reps.png'), dpi=300)
    # Close the figure right away so that figures do not pile up across iterations.
    plt.close('all')
    flat_dicts.append({
        'feature': col,
        'r_value': r_value,
        'p_value': p_value,
        'bio_rep_count': bio_rep_count,
    })
# Summary of the fits, most significant first.
feature_bio_rep_df = pd.DataFrame(flat_dicts)
feature_bio_rep_df.sort_values('p_value', inplace=True)
#

if 1:
    # Features with r below 0.5, followed by the full summary table.
    print(list(feature_bio_rep_df.loc[feature_bio_rep_df['r_value'] < 0.5, 'feature']))
    print(feature_bio_rep_df)

In [None]:
# Histogram of the 'higher_in_clp_m_than_all_myeloid_hspcs' score over the atlas cells
# assigned to CLP or HSC_MPP, with the two CLP_E thresholds drawn as vertical lines.
out_dir_path = 'temp/_fig_sup_4'
pathlib.Path(out_dir_path).mkdir(parents=True, exist_ok=True)
from mds import pb_cd34_c_score_threshs
# NOTE(review): this ALIASES the loaded atlas (no copy) and rebinds the notebook-level name
# `c_ad`, which an earlier cell set to a copy with merged state names. Cells that use `c_ad`
# therefore behave differently depending on which cell ran last.
c_ad = n2_c_ad
# c_ad = ad.read_h5ad(mds_analysis_params.get_mc_model_paths('final_mds_cyto_normal_excluding_atlas', dataset_name='illu_mds',
# )['cells_with_metacell_attrs_ad'])
# Called with None in place of the metacell object, so only the cell-level object is passed.
mc_utils.add_c_state_and_mc_c_state_stats(
    None, c_ad, 
    cell_state_and_info_list=get_mds_params()['pb_cd34_enriched_cell_state_and_c_info_list'], 
    mask_and_info_list=get_mds_params()['pb_cd34_enriched_mask_and_c_info_list'], 
    cell_type_colors_csv_file_path=get_mds_params()['cell_type_colors_csv_file_path'],
    only_add_c_state=True,
)

# Optional extra cell mask; None means "all cells" (see the fallback below).
c_mask = None

# Only needed by the commented-out `hue`/`palette` arguments of the histogram below.
palette = mc_utils.get_palette(c_ad, 'c_state')

if c_mask is None:
    c_mask = np.full(c_ad.n_obs, True)

# curr_df = c_ad.obs[generic_utils.get_equal_contrib_mask([c_ad.obs['c_state'] == x for x in ['CLP', 'HSC_MPP']])].copy()
curr_df = c_ad.obs[c_mask & c_ad.obs['c_state'].isin(['CLP', 'HSC_MPP'])].copy()
curr_df['c_state'] = curr_df['c_state'].astype(str)

plt.close('all')
fig, ax = plt.subplots(figsize=(4,4))
sb.histplot(
    curr_df, 
    x='higher_in_clp_m_than_all_myeloid_hspcs', 
    # hue='c_state',
    # palette=palette,
    # common_norm=False,
    # stat='proportion',
    bins=100,
    legend=False,
    ax=ax,
)
# Mark the low and high CLP_E thresholds.
for x in (pb_cd34_c_score_threshs.HIGHER_IN_CLP_M_THAN_ALL_MYELOID_HSPCS_CLP_E_LOW_THRESH, pb_cd34_c_score_threshs.HIGHER_IN_CLP_M_THAN_ALL_MYELOID_HSPCS_CLP_E_HIGH_THRESH):
    ax.axvline(x=x, color='red', alpha=0.3)
ax.set_title('CLP-sig across HSC_MPP and CLP')
fig.tight_layout()
fig.savefig(os.path.join(out_dir_path, 'EDF_11D_clp_e_definition_hist.png'), dpi=300)



In [None]:
# ---- Fig. 4B-D: configuration ------------------------------------------------
# Every switch is written once with its effective value; the value that was
# toggled off is kept in the trailing comment.
out_dir_path = 'temp/_fig4'
pathlib.Path(out_dir_path).mkdir(parents=True, exist_ok=True)

show_donor_ids = True  # alternative: False
skip_plotting = False  # alternative: True

# Hand-picked donor groups (numbered donor ids without tech-rep suffix). Later in
# this cell each group is asserted to sit at consecutive positions in the final
# donor order and is used as that donor's composition class.
only_mpp_nd_ids_without_suffix = ['N405', 'N421']
high_gmp_nd_ids_without_suffix = [
    'G12', 'N235_2', 'N367', 'N281_1', 'N250', 'N413', 'N403',
]
high_bemp_nd_ids_without_suffix = ['N151', 'N192_1']
high_clp_nd_ids_without_suffix = ['N387', 'N204_2']
high_clp_low_mebemp_nd_ids_without_suffix = ['N203', 'N241', 'N283']
comp_class_to_nd_ids_without_suffix = dict([
    ('only HSC/MPP', only_mpp_nd_ids_without_suffix),
    ('high GMP', high_gmp_nd_ids_without_suffix),
    ('high BEMP', high_bemp_nd_ids_without_suffix),
    ('high CLP', high_clp_nd_ids_without_suffix),
    ('high CLP low MEBEMP', high_clp_low_mebemp_nd_ids_without_suffix),
])

# Hierarchical clustering settings.
linkage_metric = 'euclidean'  # alternative: 'correlation'
linkage_method = 'ward'  # alternative: 'average'. scipy says "Method 'ward' requires the distance metric to be Euclidean"

epsilon_for_cell_freq = get_mds_params()['epsilon_for_donor_log_c_state_freq']

# Which normal donors to drop from the figure; the two options are mutually exclusive.
dont_show_normal = True  # alternative: False
dont_show_normal_atlas = False  # alternative: True
assert (not dont_show_normal_atlas) or (not dont_show_normal)

cluster_log_freqs = True  # alternative: False
cluster_log_enrich = False  # alternative: True

# Quantile of the normal (non-atlas) mean_near_neighbor_dist values used as the threshold.
nn_dist_high_quantile = 0.98

hclust_clust_count = 5  # alternative: None
order_by_k_means_k = None  # alternative: 5

# Columns expected in the lower annotation strip.
lower_ax_cols = [
    'diagnosis_c',
    'DNMT3A_max_mean_VAF',
    'TET2_max_mean_VAF',
    'SF3B1_max_mean_VAF',
    'other max VAF',
    'any_CNA',
]

# Donor ordering mode; with both off, the nn-dist / hierarchical-clustering order is used.
order_donors_by_id = False  # alternative: True
order_donors_by_bleeding_date_and_exp_name = False  # alternative: True

discard_numbered_donor_id_tech_rep_suffix = True  # alternative: False
verify_states_in_mds_analysis_params = True  # alternative: False

# Feature-table column names of the cell-state frequencies ('c_' + state name).
all_ordered_state_names = [f'c_{state}' for state in mds_analysis_params.ORDERED_CELL_STATE_NAMES]
ordered_pb_hspc_state_names = [f'c_{state}' for state in mds_analysis_params.ORDERED_PB_HSPC_STATE_NAMES]

# ---- Fig. 4B-D: load the per-donor feature table and derive plot columns -----
mutation_df = arch_mutation_interface_and_utils.get_minimal_all_mutation_df()
mut_genes = sorted(mutation_df['gene'].drop_duplicates())

all_ext_feature_df = pd.read_csv(get_mds_params()['all_ext_donor_feature_df_csv_file_path'])
# all_ext_feature_df = pd.concat([pd.read_csv(x).assign(ext_donor_feature_df_file_path=x) for x in extended_df_csv_file_paths], ignore_index=True)
# all_ext_feature_df = sc_rna_seq_preprocessing_params.get_df_with_numbered_donor_id(all_ext_feature_df)
curr_feature_df = all_ext_feature_df.copy()

# Donors whose dataset_name is 'ref' are flagged as atlas donors.
curr_feature_df['atlas'] = curr_feature_df['dataset_name'] == 'ref'

curr_feature_df['donor_age'] = curr_feature_df['donor_age'].astype(float)
curr_feature_df['diagnosis_class'] = curr_feature_df['diagnosis_class'].astype(str)

# 0/1 flag: at least one CNA was counted for the donor.
curr_feature_df['any_large_dels_or_dups'] = (curr_feature_df['cna_count'] > 0).astype(int)

# Highest per-gene max-mean-VAF over all mutated genes other than the three shown separately.
top_3_mut_genes = ['DNMT3A', 'TET2', 'SF3B1']
mut_genes_except_top_3 = sorted(set(mut_genes) - set(top_3_mut_genes))
curr_feature_df['max_mean_VAF_excluding_top_3'] = curr_feature_df[[f'{x}_max_mean_VAF' for x in mut_genes_except_top_3]].max(axis=1)

# log2(frequency + epsilon) for every HSPC state that has a column in the table.
observed_hspc_state_names = sorted(set(curr_feature_df.columns) & set(ordered_pb_hspc_state_names))
log_hspc_freq_cols = [f'log_{x}' for x in observed_hspc_state_names]
curr_feature_df[log_hspc_freq_cols] = np.log2(curr_feature_df[observed_hspc_state_names] + epsilon_for_cell_freq)

unexpected_diagnosis_mask = ~curr_feature_df['diagnosis_class'].isin(['MDS', 'MDS/MPN', 'cytopenia', 'normal'])
assert not unexpected_diagnosis_mask.any()

# Take an explicit copy of the filtered rows: a column ('normal') is added below,
# and assigning into a boolean-mask slice triggers pandas' SettingWithCopy
# behavior (a warning, or a lost write under copy-on-write).
curr_feature_df = curr_feature_df[curr_feature_df['in_mds_cyto_groups_analysis']].copy()
assert curr_feature_df['donor_id'].is_unique
print(f'len(curr_feature_df): {len(curr_feature_df)}')
non_normal_donor_count = (curr_feature_df['diagnosis_class'] != 'normal').sum()
print(f'(curr_feature_df[f"diagnosis_class"] != "normal").sum(): {non_normal_donor_count}')

# print(curr_ext_donor_feature_df['donor_id'].nunique())
# print(curr_ext_donor_feature_df[curr_ext_donor_feature_df['donor_id'] == 'N198'])
non_atlas_mask = ~curr_feature_df['atlas']
all_mean_near_neighbor_dists_without_atlas = curr_feature_df.loc[non_atlas_mask, 'mean_near_neighbor_dist']
normal_mean_nn_dists_without_atlas = curr_feature_df.loc[
    non_atlas_mask & (curr_feature_df['diagnosis_class'] == 'normal'), 'mean_near_neighbor_dist']
curr_feature_df['normal'] = curr_feature_df['diagnosis_class'] == 'normal'

# The threshold below is a quantile of the non-atlas normal donors; with no usable
# value it would silently become NaN and every donor would be treated as "high".
assert normal_mean_nn_dists_without_atlas.notna().any(), 'no non-atlas normal donor has a mean_near_neighbor_dist'
normal_mean_nn_dist_without_atlas_high_quantile_val = normal_mean_nn_dists_without_atlas.quantile(nn_dist_high_quantile)

# EDF 10A: per-diagnosis cumulative distribution of mean_near_neighbor_dist for
# non-atlas donors, with the normal-donor quantile threshold drawn as a grey line.
if not skip_plotting:
    supp_out_dir_path = 'temp/_fig_sup_4'
    pathlib.Path(supp_out_dir_path).mkdir(parents=True, exist_ok=True)

    plt.close('all')
    fig, ax = plt.subplots(figsize=(3.5, 3.5))
    non_atlas_feature_df = curr_feature_df.loc[~curr_feature_df['atlas']]
    sb.histplot(
        non_atlas_feature_df,
        x='mean_near_neighbor_dist',
        hue='diagnosis_class',
        palette=clinical_data_interface_and_utils.DIAGNOSIS_CLASS_TO_COLOR,
        # each class is normalized on its own and drawn as an unfilled step curve
        cumulative=True,
        stat='proportion',
        common_norm=False,
        element='step',
        fill=False,
        bins=3000,
        ax=ax,
    )
    ax.axvline(normal_mean_nn_dist_without_atlas_high_quantile_val, color='grey', alpha=0.5)
    ax.legend().remove()

    cum_hist_file_path = os.path.join(supp_out_dir_path, 'EDF_10A_mean_near_neighbor_dist_cum_hist.png')
    fig.savefig(cum_hist_file_path, dpi=300)

# Keep a snapshot that still contains the normal donors before filtering them out.
feature_df_before_removing_normal = curr_feature_df.copy()

if dont_show_normal:
    # drop every normal donor
    normal_donor_mask = curr_feature_df['diagnosis_class'] == 'normal'
    curr_feature_df = curr_feature_df[~normal_donor_mask]
if dont_show_normal_atlas:
    # drop only donors that are both normal and from the atlas
    normal_atlas_donor_mask = (curr_feature_df['diagnosis_class'] == 'normal') & curr_feature_df['atlas']
    curr_feature_df = curr_feature_df[~normal_atlas_donor_mask]

assert not curr_feature_df.empty, 'filtered out all donors'

# NOTE(review): observed_cell_state_names is built by filtering on membership in
# all_ordered_state_names, so the "missing" set below can never be non-empty.
observed_cell_state_names = [col for col in curr_feature_df.columns if col in all_ordered_state_names]
missing_cell_state_names = set(observed_cell_state_names).difference(all_ordered_state_names)
assert (not missing_cell_state_names), f'{missing_cell_state_names} are missing from all_ordered_state_names'

observed_hspc_states_in_df = {col for col in observed_cell_state_names if col in ordered_pb_hspc_state_names}
assert set(observed_hspc_state_names) == observed_hspc_states_in_df

# Per-donor HSPC state frequencies (plus their log2) keyed by numbered_donor_id.
df = curr_feature_df[[*observed_hspc_state_names, 'numbered_donor_id']].copy()
df[log_hspc_freq_cols] = np.log2(df[observed_hspc_state_names] + epsilon_for_cell_freq)
# The pre-ordering donor order; later code maps donor ids back to row positions with it.
numbered_donor_ids_in_orig_order = df['numbered_donor_id'].to_numpy()
numbered_donor_ids_in_orig_order_as_list = list(numbered_donor_ids_in_orig_order)
df = df.set_index('numbered_donor_id')

if cluster_log_freqs:
    cols_for_clustering = log_hspc_freq_cols
else:
    cols_for_clustering = observed_hspc_state_names
df_for_plot = df[observed_hspc_state_names].copy()

plt.close('all')
type_to_color = mds_analysis_params.get_cell_state_to_color(add_outlier_black_color=True)

figsize_x, figsize_y = 25, 15

# Build the Fig. 4B-D canvas: one column of 8 grid rows, of which rows 2, 4 and 6
# receive axes (composition, mean_near_neighbor_dist, lower annotation strip).
if not skip_plotting:
    fig = plt.figure(figsize=(figsize_x, figsize_y))

    donor_id_space_height = 0.4 if show_donor_ids else 0.13
    height_ratios = [
        0.2,  # age, sex, diagnosis
        0.35,  # space
        4,  # HSPC composition
        donor_id_space_height,  # space (for donor_id)
        1,  # mean_near_neighbor_dist
        0.13,  # space (for donor_id)
        0.5,  # vafs, s_phase, bemp ms4a2, hla
        0.1,  # space
    ]
    num_of_rows = len(height_ratios)

    fig_gridspec = fig.add_gridspec(num_of_rows, 1, height_ratios=height_ratios)

    # Explicit margins; hspace is effectively zero because spacing comes from the spacer rows.
    fig.subplots_adjust(
        left=0.1,
        right=0.88,
        top=0.97,
        bottom=0.15,
        hspace=0.0000001,
    )

    composition_ax, mean_nn_dist_ax, lower_extra_data_ax = (
        fig.add_subplot(fig_gridspec[row_i, 0]) for row_i in (2, 4, 6)
    )

    mean_nn_dist_ax.get_xaxis().set_visible(False)

    lower_extra_data_ax.get_xaxis().set_visible(False)
    lower_extra_data_ax.get_yaxis().set_visible(True)
    lower_extra_data_ax.set_facecolor("pink")  # TODO: should remove?


# ---- Fig. 4B-D: decide the left-to-right donor order --------------------------
# Output: ordered_row_indices, positions into numbered_donor_ids_in_orig_order.
# Three modes: by donor id, by bleeding date + experiment name, or (default)
# "normal-like" donors sorted by lymphoid/non-lymphoid log ratio followed by the
# remaining donors in hierarchical-clustering leaf order.
if order_donors_by_id:
    assert not order_donors_by_bleeding_date_and_exp_name
    nd_ids_sorted_by_id = sorted(numbered_donor_ids_in_orig_order)
    assert set(numbered_donor_ids_in_orig_order) == set(nd_ids_sorted_by_id)
    ordered_row_indices = [numbered_donor_ids_in_orig_order_as_list.index(x) for x in nd_ids_sorted_by_id]
elif order_donors_by_bleeding_date_and_exp_name:
    ordered_exp_names = sc_rna_seq_preprocessing_params.get_ordered_exp_names_from_df_with_bleeding_dates(curr_feature_df)
    numbered_donor_id_to_exp_name = curr_feature_df.set_index('numbered_donor_id')['exp_name'].to_dict()
    assert set(numbered_donor_id_to_exp_name) == set(numbered_donor_ids_in_orig_order)
    # primary key: position of the donor's experiment in ordered_exp_names; tie-break: donor id
    nd_ids_sorted_by_bleeding_date_and_exp_name = sorted(
        numbered_donor_ids_in_orig_order, key=lambda x: (ordered_exp_names.index(numbered_donor_id_to_exp_name[x]), x))
    assert set(numbered_donor_ids_in_orig_order) == set(nd_ids_sorted_by_bleeding_date_and_exp_name)
    ordered_row_indices = [numbered_donor_ids_in_orig_order_as_list.index(x) for x in nd_ids_sorted_by_bleeding_date_and_exp_name]
else:
    # Split donors by whether mean_near_neighbor_dist exceeds the normal-donor quantile threshold.
    low_nn_dist_mask = curr_feature_df['mean_near_neighbor_dist'] <= normal_mean_nn_dist_without_atlas_high_quantile_val
    high_nn_dist_mask = ~low_nn_dist_mask

    print(f'low_nn_dist_mask.sum(): {low_nn_dist_mask.sum()}')
    print(f'high_nn_dist_mask.sum(): {high_nn_dist_mask.sum()}')

    # 2x2 table (low/high nn-dist x any CNA / no CNA) for a two-sided Fisher exact test.
    low_nn_dist_and_any_cna_mask = curr_feature_df['any_large_dels_or_dups'] & low_nn_dist_mask
    low_nn_dist_and_any_cna_count = low_nn_dist_and_any_cna_mask.sum()
    low_nn_dist_and_no_cna_count = low_nn_dist_mask.sum() - low_nn_dist_and_any_cna_count

    high_nn_dist_and_any_cna_mask = curr_feature_df['any_large_dels_or_dups'] & high_nn_dist_mask
    high_nn_dist_and_any_cna_count = high_nn_dist_and_any_cna_mask.sum()
    high_nn_dist_and_no_cna_count = high_nn_dist_mask.sum() - high_nn_dist_and_any_cna_count

    print(f'low_nn_dist_and_any_cna_count: {low_nn_dist_and_any_cna_count}')
    print(f'low_nn_dist_and_no_cna_count: {low_nn_dist_and_no_cna_count}')
    print(f'high_nn_dist_and_any_cna_count: {high_nn_dist_and_any_cna_count}')
    print(f'high_nn_dist_and_no_cna_count: {high_nn_dist_and_no_cna_count}')

    cna_fisher_result = scipy.stats.fisher_exact([
        [low_nn_dist_and_any_cna_count, low_nn_dist_and_no_cna_count],
        [high_nn_dist_and_any_cna_count, high_nn_dist_and_no_cna_count],
    ], alternative='two-sided')
    print(f'fisher result: {cna_fisher_result}')

    # Only the high-nn-dist donors are clustered. Explicit copy: the index and a new
    # column are assigned below, which must not be done on a slice of curr_feature_df.
    df_for_clustering = curr_feature_df.loc[high_nn_dist_mask, cols_for_clustering].copy()
    df_for_clustering.index = curr_feature_df.loc[high_nn_dist_mask, 'numbered_donor_id']
    # Seed kept from an earlier investigation of run-to-run ordering differences. Five
    # different seeds gave the same result; the differences turned out to come from a
    # changed input donor order (after the CBC-merging step was moved), not from randomness.
    np.random.seed(0)
    linkage_mat = scipy.cluster.hierarchy.linkage(
        df_for_clustering.to_numpy(),
        method=linkage_method,
        metric=linkage_metric,
        # optimal_ordering=True,
    )
    hclust_ordered_row_indices = scipy.cluster.hierarchy.leaves_list(linkage_mat)
    nd_ids_sorted_by_hclust = list(df_for_clustering.iloc[hclust_ordered_row_indices].index)

    # Flat cluster label (as a string) per clustered donor; cut into at most hclust_clust_count clusters.
    df_for_clustering['hclust_clust'] = scipy.cluster.hierarchy.fcluster(
        Z=linkage_mat,
        t=hclust_clust_count,
        criterion='maxclust',
    )
    df_for_clustering['hclust_clust'] = df_for_clustering['hclust_clust'].astype(str)
    df_for_clustering.reset_index(inplace=True)
    curr_feature_df = generic_utils.merge_preserving_df1_index_and_row_order(curr_feature_df, df_for_clustering[['numbered_donor_id', 'hclust_clust']], how='left')
    # Donors that were not clustered (low nn-dist) get the label '-1'. Assign the result
    # back instead of an in-place fill on the selected column, which operates on a
    # temporary under pandas copy-on-write and would leave the NaNs in place.
    curr_feature_df['hclust_clust'] = curr_feature_df['hclust_clust'].fillna('-1')


    # log2 ratio of summed lymphoid vs. non-lymphoid HSPC frequencies; the two sums must add up to 1.
    curr_lymph_hspc_freq_cols = sorted(set(curr_feature_df.columns) & set([f'c_{x}' for x in mds_analysis_params.PB_LYMPHOID_HSPC_STATE_NAMES]))
    curr_non_lymph_hspc_freq_cols = sorted(set(curr_feature_df.columns) & set([f'c_{x}' for x in mds_analysis_params.PB_NON_LYMPHOID_HSPC_STATE_NAMES]))
    curr_feature_df['lymph_hspc_freq'] = curr_feature_df[curr_lymph_hspc_freq_cols].sum(axis=1)
    curr_feature_df['non_lymph_hspc_freq'] = curr_feature_df[curr_non_lymph_hspc_freq_cols].sum(axis=1)
    assert np.allclose(curr_feature_df[['lymph_hspc_freq', 'non_lymph_hspc_freq']].sum(axis=1), 1)
    curr_feature_df['lymph_non_lymph_hspc_freq_log_ratio'] = np.log2(curr_feature_df['lymph_hspc_freq'] + epsilon_for_cell_freq) - np.log2(
        curr_feature_df['non_lymph_hspc_freq'] + epsilon_for_cell_freq)
    # Final default order: low-nn-dist donors by descending log ratio, then the clustered donors in leaf order.
    nd_ids_sorted_by_lymph_myelo_log_ratio = list(curr_feature_df.loc[low_nn_dist_mask].sort_values('lymph_non_lymph_hspc_freq_log_ratio', ascending=False)['numbered_donor_id'])
    ordered_row_indices = [numbered_donor_ids_in_orig_order_as_list.index(x) for x in nd_ids_sorted_by_lymph_myelo_log_ratio + nd_ids_sorted_by_hclust]


    if order_by_k_means_k is not None:
        # NOTE(review): this branch is disabled (order_by_k_means_k is None) and looks stale:
        # by this point df_for_clustering holds only the clustered donors and already has the
        # string columns 'numbered_donor_id' and 'hclust_clust'. Verify before re-enabling.
        kmeans = sklearn.cluster.KMeans(n_clusters=order_by_k_means_k, random_state=0, init='k-means++').fit(df_for_clustering.to_numpy())
        df_for_clustering['kmeans_cluster'] = kmeans.labels_.astype(str)
        df_for_clustering['hclust_i'] = [nd_ids_sorted_by_hclust.index(x) for x in numbered_donor_ids_in_orig_order_as_list]

        df_for_clustering = generic_utils.merge_preserving_df1_index_and_row_order(df_for_clustering, df_for_clustering.groupby('kmeans_cluster')['c_CLP'].mean().reset_index(name='kmeans_cluster_mean_c_CLP'))
        # assert not cluster_log_freqs
        # df_for_clustering['CLP_and_NKTDP'] = df_for_clustering['c_CLP'] + df_for_clustering['c_NKTDP']
        df_for_clustering = df_for_clustering.reset_index().rename(columns={'index': 'numbered_donor_id'})
        df_for_clustering = generic_utils.merge_preserving_df1_index_and_row_order(df_for_clustering, curr_feature_df[[
            'numbered_donor_id',
            'DNMT3A_max_mean_VAF',
            'TET2_max_mean_VAF',
            'SF3B1_max_mean_VAF',
            'max_mean_VAF_excluding_top_3',
            'any_large_dels_or_dups',
            'donor_age',
        ]])
        # nd_ids_sorted_by_kmeans = list(df_for_clustering.sort_values(['kmeans_cluster_mean_c_CLP', 'kmeans_cluster', 'CLP_and_NKTDP'])['numbered_donor_id'])
        nd_ids_sorted_by_kmeans = list(df_for_clustering.sort_values(['kmeans_cluster_mean_c_CLP', 'kmeans_cluster', 'hclust_i',
                                                                       'DNMT3A_max_mean_VAF',
            'TET2_max_mean_VAF',
            'SF3B1_max_mean_VAF',
            'max_mean_VAF_excluding_top_3',
            'any_large_dels_or_dups', 'donor_age'])['numbered_donor_id'])

        curr_feature_df = generic_utils.merge_preserving_df1_index_and_row_order(curr_feature_df, df_for_clustering[['numbered_donor_id', 'kmeans_cluster']], how='left')
        assert set(numbered_donor_ids_in_orig_order) == set(nd_ids_sorted_by_kmeans)
        ordered_row_indices = [numbered_donor_ids_in_orig_order_as_list.index(x) for x in nd_ids_sorted_by_kmeans]

# ---- Fig. 4B-D: apply the donor order and assign composition classes ----------
df_for_plot = df_for_plot.iloc[ordered_row_indices]

if discard_numbered_donor_id_tech_rep_suffix:
    df_for_plot.index = sc_rna_seq_preprocessing_params.get_numbered_donor_id_series_without_tech_rep_suffix(pd.Series(df_for_plot.index)).to_numpy() #

final_donor_ids_ordered = list(numbered_donor_ids_in_orig_order[ordered_row_indices])

# Sanity check for the two explicit ordering modes.
if order_donors_by_id:
    assert final_donor_ids_ordered == nd_ids_sorted_by_id
elif order_donors_by_bleeding_date_and_exp_name:
    assert final_donor_ids_ordered == nd_ids_sorted_by_bleeding_date_and_exp_name

ordered_observed_hspc_states = [x for x in ordered_pb_hspc_state_names if x in df_for_plot.columns]

df_for_hspc_plot = df_for_plot[ordered_observed_hspc_states]
# assert np.isclose(df_for_hspc_plot.sum(axis=1), 1, atol=0.02).all()

# The stacked bars assume each donor's HSPC state frequencies sum to 1.
unexpected_sum_mask = ~np.isclose(df_for_hspc_plot.sum(axis=1), 1)
assert not unexpected_sum_mask.any(), f'df_for_hspc_plot[unexpected_sum_mask]: {df_for_hspc_plot[unexpected_sum_mask]}'

curr_feature_df_ordered = generic_utils.get_df_sorted_by_column_vals(
    curr_feature_df, 'numbered_donor_id', final_donor_ids_ordered, left_join=True)
# curr_ext_donor_feature_df = generic_utils.merge_preserving_df1_index_and_row_order(
#     curr_ext_donor_feature_df, curr_ext_donor_feature_df[['numbered_donor_id', 'donor_age']].drop_duplicates(subset='numbered_donor_id'))

# Default class: clustered donors vs. the rest.
# NOTE(review): nd_ids_sorted_by_hclust is only assigned by the default ordering mode
# (neither order_donors_by_id nor order_donors_by_bleeding_date_and_exp_name).
in_hclust_mask = curr_feature_df_ordered['numbered_donor_id'].isin(nd_ids_sorted_by_hclust)
curr_feature_df_ordered.loc[in_hclust_mask, 'composition_class'] = 'other abnorm comp'
curr_feature_df_ordered.loc[~in_hclust_mask, 'composition_class'] = 'norm-like comp'

# Override the class for the hand-picked donor groups and mark each group with a
# horizontal bar above the composition panel; a group must be contiguous in the order.
ordered_nd_ids_without_suffix = list(curr_feature_df_ordered['numbered_donor_id_without_tech_rep_suffix'])
for comp_class, class_nd_ids_without_suffix in comp_class_to_nd_ids_without_suffix.items():
    # Fail with the offending ids rather than a bare ValueError from list.index().
    missing_nd_ids = [x for x in class_nd_ids_without_suffix if x not in ordered_nd_ids_without_suffix]
    assert not missing_nd_ids, f'{comp_class}: {missing_nd_ids} are not among the plotted donors'

    indices = [ordered_nd_ids_without_suffix.index(x) for x in class_nd_ids_without_suffix]
    min_i = min(indices)
    max_i = max(indices)
    # Report the class and its positions (the message used to dump the full donor list).
    assert set(indices) == set(range(min_i, max_i + 1)), (
        f'{comp_class}: {class_nd_ids_without_suffix} are not consecutive (positions: {sorted(indices)})')

    curr_feature_df_ordered.loc[curr_feature_df_ordered['numbered_donor_id_without_tech_rep_suffix'].isin(class_nd_ids_without_suffix), 'composition_class'] = comp_class

    if not skip_plotting:
        # x is in data (bar position) coordinates, y in axes fraction (just above the panel)
        composition_ax.annotate('', xy=(min_i - 0.33, 1.01), xycoords=('data', 'axes fraction'), xytext=(max_i + 0.33, 1.01),
                arrowprops=dict(arrowstyle="-", linewidth=2.5, color='k'))

assert curr_feature_df_ordered['composition_class'].notna().all()
curr_feature_df_ordered[['donor_id', 'numbered_donor_id', 'exp_name', 'composition_class']].to_csv(get_mds_params()['fig4_donor_composition_class_df_csv_file_path'], index=False)

# ---- Fig. 4B-D: draw the three panels and save the figure ---------------------
if not skip_plotting:

    # Grey-scale map for the numeric annotation rows; masked ('bad') values are drawn translucent red.
    binary_cmap = plt.get_cmap('binary')
    binary_cmap.set_bad(color='red', alpha=0.2)

    extra_data_source_cols = [
        'DNMT3A_max_mean_VAF',
        'TET2_max_mean_VAF',
        'SF3B1_max_mean_VAF',
        'max_mean_VAF_excluding_top_3',
        'any_large_dels_or_dups',
        'diagnosis_class',
    ]
    extra_data_col_to_label = {
        'log_HSPC_count': 'log(#HSPCs)',
        'any_large_dels_or_dups': 'any_CNA',
        'max_mean_VAF_excluding_top_3': 'other max VAF',
        'diagnosis_class': 'diagnosis_c',
    }
    raw_extra_data_df = curr_feature_df_ordered[extra_data_source_cols].rename(columns=extra_data_col_to_label)

    # Numeric columns: (colormap, min, max). Categorical columns: value -> color.
    vaf_cmap_and_min_and_max = (binary_cmap, 0, 0.5)
    col_to_cmap_and_min_and_max = {
        'DNMT3A_max_mean_VAF': vaf_cmap_and_min_and_max,
        'TET2_max_mean_VAF': vaf_cmap_and_min_and_max,
        'SF3B1_max_mean_VAF': vaf_cmap_and_min_and_max,
        'other max VAF': vaf_cmap_and_min_and_max,
        'any_CNA': (binary_cmap, 0, 1),
    }
    col_to_val_to_color = {
        'diagnosis_c': clinical_data_interface_and_utils.DIAGNOSIS_CLASS_TO_COLOR,
    }

    def plot_col_color(ax, cols):
        """Draw the raw_extra_data_df columns `cols` on `ax` as an image with one row per column."""
        selected_df = raw_extra_data_df[cols].copy()
        selected_cmap_specs = {
            col: spec for col, spec in col_to_cmap_and_min_and_max.items() if col in cols
        }
        selected_val_to_color = {
            col: val_to_color for col, val_to_color in col_to_val_to_color.items() if col in cols
        }
        colors = generic_utils.get_color_arr(
            selected_df,
            allow_nans=True,
            col_to_cmap_and_min_and_max=selected_cmap_specs,
            col_to_val_to_color=selected_val_to_color,
        )
        # swap the first two axes so that the y ticks below correspond to the columns
        ax.imshow(colors.transpose((1, 0, 2)), aspect='auto')
        col_labels = selected_df.columns
        ax.set_yticks(range(len(col_labels)))
        ax.set_yticklabels(col_labels)

    assert set(lower_ax_cols) == set(raw_extra_data_df.columns)
    plot_col_color(lower_extra_data_ax, lower_ax_cols)

    # Strip the 'c_' prefix from the state columns so they match the type_to_color keys.
    assert all(state_col.startswith('c_') for state_col in ordered_pb_hspc_state_names)
    df_for_hspc_plot = df_for_hspc_plot.rename(
        columns={state_col: state_col[2:] for state_col in ordered_pb_hspc_state_names})

    # make sure the stacked bar composition shows at most a single sample per donor
    plotted_nd_id_df = pd.DataFrame({'numbered_donor_id_without_tech_rep_suffix': df_for_hspc_plot.index})
    nd_id_and_donor_id_df = all_ext_feature_df[['numbered_donor_id_without_tech_rep_suffix', 'donor_id']].drop_duplicates()
    df_for_sanity_check = generic_utils.merge_preserving_df1_index_and_row_order(
        plotted_nd_id_df, nd_id_and_donor_id_df)
    assert df_for_sanity_check['donor_id'].is_unique

    print(f'len(df_for_hspc_plot): {len(df_for_hspc_plot)}')

    # Panel 1: stacked HSPC composition, one bar per donor.
    state_colors = [type_to_color[state] for state in df_for_hspc_plot.columns]
    df_for_hspc_plot.plot(kind="bar", ax=composition_ax, stacked=True, color=state_colors)
    composition_ax.set_xlabel(None)
    if not show_donor_ids:
        composition_ax.get_xaxis().set_visible(False)
    composition_ax.set_ylim(0, 1.001)

    # Legend entries reversed relative to plotting order, placed outside the panel's top-right corner.
    handles, labels = composition_ax.get_legend_handles_labels()
    composition_ax.legend(handles[::-1], labels[::-1], fontsize='x-large')
    sb.move_legend(composition_ax, "upper left", bbox_to_anchor=(1, 1), fontsize='x-large')

    # Panel 2: mean_near_neighbor_dist per donor. Colormap is grey up to the (min-max
    # normalized) normal-donor threshold and then ramps to red.
    nn_dists = curr_feature_df_ordered['mean_near_neighbor_dist']
    min_nn_dist = nn_dists.min()
    max_nn_dist = nn_dists.max()
    norm_normal_mean_nn_dist_without_atlas_high_quantile_val = (
        normal_mean_nn_dist_without_atlas_high_quantile_val - min_nn_dist
    ) / (max_nn_dist - min_nn_dist)
    print(f'norm_normal_mean_nn_dist_without_atlas_high_quantile_val: {norm_normal_mean_nn_dist_without_atlas_high_quantile_val}')

    mean_nn_dist_cmap = matplotlib.colors.LinearSegmentedColormap.from_list('', [
        (0, 'grey'),
        (norm_normal_mean_nn_dist_without_atlas_high_quantile_val, 'grey'),
        (1, 'red'),
    ])

    nn_dist_color_df = generic_utils.get_color_df(
        curr_feature_df_ordered[['mean_near_neighbor_dist']],
        col_to_cmap_and_min_and_max={'mean_near_neighbor_dist': (mean_nn_dist_cmap, min_nn_dist, max_nn_dist)},
    )
    curr_feature_df_ordered['mean_nn_dist_color'] = nn_dist_color_df['mean_near_neighbor_dist']
    mean_nn_dist_palette = generic_utils.get_dict_mapping_one_df_column_to_other(
        curr_feature_df_ordered, 'numbered_donor_id', 'mean_nn_dist_color')
    # hue == x with a per-donor palette gives every bar its own color
    sb.barplot(
        data=curr_feature_df_ordered,
        x='numbered_donor_id',
        y='mean_near_neighbor_dist',
        hue='numbered_donor_id',
        palette=mean_nn_dist_palette,
        dodge=False,
        ax=mean_nn_dist_ax,
    )
    mean_nn_dist_ax.get_legend().remove()

    # The output file name encodes the settings that shaped the figure.
    ordered_by_k_means_repr = f'_kmeans{order_by_k_means_k}' if (order_by_k_means_k is not None) else ''
    dont_show_normal_atlas_repr = '_no_normal_atlas' if dont_show_normal_atlas else ''
    dont_show_normal_repr = '_no_normal' if dont_show_normal else ''
    nn_dist_high_quantile_repr = f'_{nn_dist_high_quantile}'
    fig_file_name = ''.join([
        'fig_4BCD_',
        ordered_by_k_means_repr,
        dont_show_normal_atlas_repr,
        dont_show_normal_repr,
        nn_dist_high_quantile_repr,
        '.png',
    ])
    fig.savefig(os.path.join(out_dir_path, fig_file_name), dpi=300)
    plt.close('all')


In [None]:
# ---- Fig. 4C/4D: standalone colorbars ----------------------------------------
# Relies on min_nn_dist, max_nn_dist, norm_normal_mean_nn_dist_without_atlas_high_quantile_val,
# normal_mean_nn_dist_without_atlas_high_quantile_val and binary_cmap from the Fig. 4B-D cell
# (run with skip_plotting = False).
out_dir_path = 'temp/_fig4'
pathlib.Path(out_dir_path).mkdir(parents=True, exist_ok=True)

horizontal = True  # alternative: False
orientation = 'horizontal' if horizontal else 'vertical'


def _save_standalone_colorbar(cmap, norm, label, file_name, orientation, out_dir_path):
    """Draw a single colorbar on a small fresh figure and save it as a PNG.

    Closes all open figures first. `orientation` is 'horizontal' or 'vertical' and
    also selects the figure shape. Returns (fig, ax, colorbar); the figure is left open.
    """
    is_horizontal = orientation == 'horizontal'
    plt.close('all')
    fig, ax = plt.subplots(figsize=(3, 1) if is_horizontal else (1, 3))
    # leave half of the figure free for the tick labels and the colorbar label
    if is_horizontal:
        fig.subplots_adjust(bottom=0.5)
    else:
        fig.subplots_adjust(right=0.5)
    colorbar = matplotlib.colorbar.ColorbarBase(ax, cmap=cmap, norm=norm, orientation=orientation)
    colorbar.set_label(label)
    fig.savefig(os.path.join(out_dir_path, file_name), dpi=300)
    return fig, ax, colorbar


mean_nn_dist_cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
    '', [(0, 'grey'), (norm_normal_mean_nn_dist_without_atlas_high_quantile_val, 'grey'), (1, 'red')])

print(normal_mean_nn_dist_without_atlas_high_quantile_val)

# The grey->red breakpoint of mean_nn_dist_cmap was computed as
# (threshold - min_nn_dist) / (max_nn_dist - min_nn_dist), so the colorbar must span
# [min_nn_dist, max_nn_dist]. Starting the scale at 0 would draw the transition at
# the wrong score whenever min_nn_dist != 0.
norm = matplotlib.colors.Normalize(vmin=min_nn_dist, vmax=max_nn_dist)
fig, ax, cb1 = _save_standalone_colorbar(
    mean_nn_dist_cmap, norm, 'composition abnormality score',
    f'FIG_4C_comp_abnorm_score_colorbar_{orientation}.png', orientation, out_dir_path)

# VAF colorbar, labelled on a 0-50 scale.
norm = matplotlib.colors.Normalize(vmin=0, vmax=50)
fig, ax, cb1 = _save_standalone_colorbar(
    binary_cmap, norm, 'VAF',
    f'FIG_4D_vaf_colorbar_{orientation}.png', orientation, out_dir_path)

In [None]:
# ---- EDF 12C: HSPC composition of paired samples, one small figure per pair ----
out_dir_path = 'temp/_fig_sup_4'
pathlib.Path(out_dir_path).mkdir(parents=True, exist_ok=True)

# Everything down to the loop is the same for every pair, so it is computed once
# (it used to be redone, including the CSV read, on every iteration).
epsilon_for_cell_freq = get_mds_params()['epsilon_for_donor_log_c_state_freq']

all_ordered_state_names = [f'c_{x}' for x in mds_analysis_params.ORDERED_CELL_STATE_NAMES]
ordered_pb_hspc_state_names = [f'c_{x}' for x in mds_analysis_params.ORDERED_PB_HSPC_STATE_NAMES]
type_to_color = mds_analysis_params.get_cell_state_to_color(add_outlier_black_color=True)

all_ext_feature_df = pd.read_csv(get_mds_params()['all_ext_donor_feature_df_csv_file_path'])
# Keep samples with at least 450 HSPCs whose composition is not flagged as biased.
bio_rep_candidate_feature_df = all_ext_feature_df[
    (all_ext_feature_df['final_c_hspc_count'] >= 450) &
    ~all_ext_feature_df['biased_composition_due_to_discarded_low_umi_count_barcodes']
].copy()

bio_rep_candidate_feature_df['diagnosis_class'] = bio_rep_candidate_feature_df['diagnosis_class'].astype(str)

observed_hspc_state_names = sorted(set(bio_rep_candidate_feature_df.columns) & set(ordered_pb_hspc_state_names))
log_hspc_freq_cols = [f'log_{x}' for x in observed_hspc_state_names]
bio_rep_candidate_feature_df[log_hspc_freq_cols] = np.log2(
    bio_rep_candidate_feature_df[observed_hspc_state_names] + epsilon_for_cell_freq)

unexpected_diagnosis_mask = ~bio_rep_candidate_feature_df['diagnosis_class'].isin(['MDS', 'MDS/MPN', 'cytopenia', 'normal'])
assert not unexpected_diagnosis_mask.any()

# These depend only on the column names, which the per-pair row filtering does not change.
observed_cell_state_names = [x for x in bio_rep_candidate_feature_df.columns if x in all_ordered_state_names]
missing_cell_state_names = set(observed_cell_state_names) - set(all_ordered_state_names)
assert (not missing_cell_state_names), f'{missing_cell_state_names} are missing from all_ordered_state_names'
assert set(observed_hspc_state_names) == {x for x in observed_cell_state_names if x in ordered_pb_hspc_state_names}

for ordered_numbered_donor_ids_without_suffix in [
    ['N192_1', 'N192_2'],
    ['N235_1', 'N235_2'],
    ['N281_1', 'N281_2'],
    ['N204_1', 'N204_2'],
    ['N226_1', 'N226_2'],
    ['N224_1', 'N224_2'],
    # 'N74',
]:
    # Kept inside the loop on purpose: the helper's name suggests the choice may be
    # random, so it is still called once per pair, on the same input as before.
    # curr_feature_df = curr_feature_df[curr_feature_df['in_mds_cyto_groups_analysis']]
    curr_feature_df = bio_rep_candidate_feature_df[bio_rep_candidate_feature_df['numbered_donor_id'].isin(
        sc_rna_seq_preprocessing_params.get_best_or_random_numbered_donor_id_per_donor_or_donor_sample(bio_rep_candidate_feature_df))]

    # curr_feature_df = curr_feature_df[curr_feature_df['diagnosis_class'] == 'normal']

    df = curr_feature_df[observed_hspc_state_names + ['numbered_donor_id']].copy()
    df[log_hspc_freq_cols] = np.log2(df[observed_hspc_state_names] + epsilon_for_cell_freq)
    df.set_index('numbered_donor_id', inplace=True)
    df_for_plot = df[observed_hspc_state_names].copy()

    plt.close('all')
    fig, ax = plt.subplots(figsize=(2, 4))

    # Index by the id without tech-rep suffix so the pair can be looked up directly.
    df_for_plot.index = sc_rna_seq_preprocessing_params.get_numbered_donor_id_series_without_tech_rep_suffix(pd.Series(df_for_plot.index)).to_numpy() #

    ordered_observed_hspc_states = [x for x in ordered_pb_hspc_state_names if x in df_for_plot.columns]
    df_for_hspc_plot = df_for_plot[ordered_observed_hspc_states]
    print(df_for_hspc_plot.index)
    df_for_hspc_plot = df_for_hspc_plot.loc[ordered_numbered_donor_ids_without_suffix]
    # df_for_hspc_plot = df_for_hspc_plot.loc[sorted(df_for_hspc_plot.index)]

    # The stacked bars assume each sample's HSPC state frequencies sum to 1.
    unexpected_sum_mask = ~np.isclose(df_for_hspc_plot.sum(axis=1), 1)
    assert not unexpected_sum_mask.any(), f'df_for_hspc_plot[unexpected_sum_mask]: {df_for_hspc_plot[unexpected_sum_mask]}'

    # Strip the 'c_' prefix so the column names match the type_to_color keys.
    df_for_hspc_plot = df_for_hspc_plot.copy()
    df_for_hspc_plot.rename(columns={x: x[2:] for x in ordered_pb_hspc_state_names}, inplace=True)

    df_for_hspc_plot.plot(
        kind="bar",
        ax=ax,
        stacked=True,
        color=[
            type_to_color[x]
            for x in list(df_for_hspc_plot)
        ],
        legend=False,
    )
    ax.set_xlabel(None)
    ax.set_ylim(0,1.001)
    # hide all spines and axes, then bring back only the x axis (sample ids)
    generic_utils.make_all_spines_and_x_and_y_axes_invisible(ax)
    ax.get_xaxis().set_visible(True)
    fig.tight_layout()

    fig.savefig(
        os.path.join(out_dir_path, f'EDF_12C_bio_rep_compos_stacked_{ordered_numbered_donor_ids_without_suffix[0]}.png'),
        dpi=300,
    )
    plt.close('all')


In [None]:
# Three legend-only figures for the diagnosis classes: patches, circles and lines.
out_dir_path = 'temp/_fig4'
pathlib.Path(out_dir_path).mkdir(parents=True, exist_ok=True)

diagnosis_legend_order = ['normal', 'cytopenia', 'MDS', 'MDS/MPN']
sorted_diagnosis_and_color = [
    (diagnosis, clinical_data_interface_and_utils.DIAGNOSIS_CLASS_TO_COLOR[diagnosis])
    for diagnosis in diagnosis_legend_order
]

# 1) colored patches (legend stays at matplotlib's default location)
plt.close('all')
fig, ax = plt.subplots()
diagnosis_patches = [
    matplotlib.patches.Patch(color=color, label=diagnosis)
    for diagnosis, color in sorted_diagnosis_and_color
]
ax.legend(handles=diagnosis_patches, fontsize='x-large')
fig.savefig(os.path.join(out_dir_path, 'patch_diagnosis_legend.png'), dpi=300)

# 2) colored circles: one dummy scatter point per class just to obtain a handle
plt.close('all')
fig, ax = plt.subplots()
for diagnosis, color in sorted_diagnosis_and_color:
    ax.scatter([0.5], [0.5], s=60, color=color, label=diagnosis)
ax.legend(fontsize='large')
sb.move_legend(ax, "upper left", bbox_to_anchor=(1, 1), fontsize='large')
fig.tight_layout()
fig.savefig(os.path.join(out_dir_path, 'circle_diagnosis_legend.png'), dpi=300)

# 3) colored lines: same idea with a dummy line per class
plt.close('all')
fig, ax = plt.subplots()
for diagnosis, color in sorted_diagnosis_and_color:
    ax.plot([0.5], [0.5], color=color, label=diagnosis)
ax.legend(fontsize='large')
sb.move_legend(ax, "upper left", bbox_to_anchor=(1, 1), fontsize='large')
fig.tight_layout()
fig.savefig(os.path.join(out_dir_path, 'line_diagnosis_legend.png'), dpi=300)

plt.close('all')

fig, ax = plt.subplots(figsize=(10,10))

if 'n2_c_ad' not in locals():
    raise RuntimeError('run code above to get n2_c_ad with updated c_states (e.g., without CFD_tryptase_GMP-L) - need to call mc_utils.add_c_state_and_mc_c_state_stats()')


assert set(n2_c_ad.obs['c_state'].astype(str)) <= set(mds_analysis_params.ORDERED_CELL_STATE_NAMES)

ordered_c_states = [x for x in mds_analysis_params.ORDERED_PB_HSPC_STATE_NAMES if x in n2_c_ad.obs['c_state'].unique()]
ordered_c_states = [x for x in mds_analysis_params.ORDERED_CELL_STATE_NAMES if x in n2_c_ad.obs['c_state'].unique()]
state_color_df = pd.DataFrame({'state': ordered_c_states})
state_color_df = generic_utils.merge_preserving_df1_index_and_row_order(state_color_df, pd.read_csv(get_mds_params()['cell_type_colors_csv_file_path']).rename(columns={'cell_type': 'state'}))
for x in state_color_df.to_records(index=False).tolist()[::-1]:
    ax.scatter([0.5], [0.5], s=40, color=x[1], label=x[0])

ax.legend(fontsize='large')
sb.move_legend(ax, "upper left", bbox_to_anchor=(1, 1), fontsize='large')
fig.tight_layout()
fig.savefig(
    os.path.join(out_dir_path, f'cell_state_circle_legend.png'), 
    dpi=300,
)

In [None]:
# Configuration for the per-class strip plots and the accompanying statistical tests.
out_dir_path = 'temp/_fig4'
pathlib.Path(out_dir_path).mkdir(parents=True, exist_ok=True)

# Whether donors with a 'normal' diagnosis take part in the Kruskal-Wallis tests (set True to include them).
krusk_with_normal = False

final_feature_names_and_types_df = pd.read_csv(get_mds_params()['final_feature_names_and_types_df_csv_file_path'])

# (exclusive upper bound on the adjusted p-value, number of stars), most stringent first;
# the plotting code uses the first threshold that is satisfied.
list_of_max_pval_and_star_count = [(0.001, 3), (0.01, 2), (0.05, 1)]

# One inner list per saved figure; each entry becomes one panel.
list_of_attrs_to_plot = [
    [
        'mebemp_l_hla_sig_MEBEMP-L_0.5',
        # 'prss2_sig_CLP_0.5', # nah, though maybe interesting? anyway very nice in bio reps.
        'bemp_cd74_sig_BEMP_0.5',
        's_phase_sig_MEBEMP-L_0.5',
        # 'gmp_l_elane_sig_GMP-L_0.5', # nah, though maybe interesting? anyway looks ok in bio reps if considering only normal. N280 is an outlier there (and she was treated with EPO)
    ],
    ['donor_age', 'Hemoglobin (g/dl)', 'RBC (10^6/microliter)'],
    ['RDW (%)', 'MCV (fL)', 'Hematocrit (%)'],
    [
        'Neutro#',
        'Mono%',
        'WBC (10^3/microliter)',
        # 'mean_near_neighbor_dist', # TODO: maybe not show this? it is biased for atlas, because one neighbor is at distance=0... so if showing this, should not show atlas donors
    ],
    [
        # 'Lympho%',
        'Lympho#', # looks better than Lympho% here.
        'Platelets (10^3/microliter)',
        'log_c_clp_e',
    ],
]
# Flat list of every plotted attribute, in figure order.
all_attrs_to_plot = [attr for attrs_of_one_fig in list_of_attrs_to_plot for attr in attrs_of_one_fig]


# Load the per-sample feature table and keep only the samples flagged for the MDS/cytopenia groups analysis.
all_ext_feature_df = pd.read_csv(get_mds_params()['all_ext_donor_feature_df_csv_file_path'])
curr_feature_df = all_ext_feature_df[all_ext_feature_df['in_mds_cyto_groups_analysis']]

# Attach the composition class of each sample (left merge: samples without a class get NaN).
curr_feature_df = generic_utils.merge_preserving_df1_index_and_row_order(curr_feature_df, pd.read_csv(get_mds_params()['fig4_donor_composition_class_df_csv_file_path'])[['numbered_donor_id', 'composition_class']], how='left')
# Only samples with a 'normal' diagnosis are allowed to lack a composition class.
assert (curr_feature_df.loc[curr_feature_df['composition_class'].isna(), 'diagnosis_class'] == 'normal').all()
# Assign the result back instead of calling fillna(inplace=True) on the column selection: the chained
# in-place form raises a FutureWarning on recent pandas and does not modify the frame under Copy-on-Write.
curr_feature_df['composition_class'] = curr_feature_df['composition_class'].fillna('normal diagnosis')


# Collapse the fine composition classes into the coarser groups shown in the figure.
curr_feature_df['rough_composition_class'] = curr_feature_df['composition_class'].replace({
    'high CLP low MEBEMP': 'high CLP',
    'high BEMP': 'other abnorm comp',
    'only HSC/MPP': 'other abnorm comp',
})
# "<rough class> <sex>" label; this is the grouping key for the strip plots and the Mann-Whitney tests.
curr_feature_df['rough_class_and_sex'] = curr_feature_df['rough_composition_class'].astype(str) + ' ' + curr_feature_df['donor_sex']
# Rough composition classes in left-to-right plotting order.
# ('high CLP low MEBEMP', 'high BEMP' and 'only HSC/MPP' are not listed separately.)
ordered_classes = ['normal diagnosis', 'norm-like comp', 'high CLP', 'high GMP', 'other abnorm comp']
ref_class = 'normal diagnosis'
ordered_sexes = ('male', 'female')

# Every non-reference "<class> <sex>" group is compared against the reference class of the same sex.
class_and_sex_to_ref_class_and_sex = {
    f'{curr_class} {curr_sex}': f'{ref_class} {curr_sex}'
    for curr_class in ordered_classes
    if curr_class != ref_class
    for curr_sex in ordered_sexes
}

ordered_class_and_sex = [f'{curr_class} {curr_sex}' for curr_class in ordered_classes for curr_sex in ordered_sexes]

# x layout: classes are 2.5 apart; within a class, male is at the class offset and female one unit to the right.
class_x_stride = 2.5
class_and_sex_to_pos_in_stripplot = {
    f'{curr_class} {curr_sex}': class_i * class_x_stride + sex_i
    for class_i, curr_class in enumerate(ordered_classes)
    for sex_i, curr_sex in enumerate(ordered_sexes)
}
# Separator between male and female inside each class (midway between the two positions).
male_female_vline_xs = [class_i * class_x_stride + 0.5 for class_i in range(len(ordered_classes))]
# Separator between consecutive classes (midway between a female position and the next male position).
class_vline_xs = [class_i * class_x_stride + 1.75 for class_i in range(len(ordered_classes) - 1)]
assert pd.Series([class_and_sex_to_pos_in_stripplot[x] for x in ordered_class_and_sex]).is_monotonic_increasing



# Test every final feature (except donor_sex) for differences between composition classes:
# a Kruskal-Wallis test per class column, plus (for the rough classes only) a two-sided
# Mann-Whitney test of each "<class> <sex>" group against its same-sex reference group.
final_feature_name_to_type = generic_utils.get_dict_mapping_one_df_column_to_other(final_feature_names_and_types_df, 'name', 'type')
features_to_test = sorted(set(final_feature_names_and_types_df['name']) - {'donor_sex'})

flat_dicts = []
for class_col in ['composition_class', 'rough_composition_class']:
    print(class_col)
    for feature in features_to_test:
        feature_type = final_feature_name_to_type[feature]
        # Samples missing this feature are ignored for all tests of the feature.
        curr_df = curr_feature_df[curr_feature_df[feature].notna()]

        krusk_df = curr_df if krusk_with_normal else curr_df[curr_df['diagnosis_class'] != 'normal']
        vals_of_each_class = krusk_df.groupby(class_col)[feature].apply(list)
        krusk_stat, krusk_pval = scipy.stats.kruskal(*vals_of_each_class)
        flat_dicts.append({
            'class_col': class_col,
            'feature': feature,
            'feature_type': feature_type,
            'test_type': 'krusk',
            'krusk_stat': krusk_stat,
            'pval': krusk_pval,
        })

        if class_col != 'rough_composition_class':
            continue
        for clas_and_sex, ref_clas_and_sex in class_and_sex_to_ref_class_and_sex.items():
            clas_vals = curr_df.loc[curr_df['rough_class_and_sex'] == clas_and_sex, feature]
            ref_vals = curr_df.loc[curr_df['rough_class_and_sex'] == ref_clas_and_sex, feature]
            if (len(clas_vals) < 1) or (len(ref_vals) < 1):
                # Nothing to compare when either group is empty.
                continue
            mw_test_res = generic_utils.perform_mw_test(clas_vals, ref_vals, alternative='two-sided')
            flat_dicts.append({
                'class_col': class_col,
                'clas_and_sex': clas_and_sex,
                'feature': feature,
                'feature_type': feature_type,
                'test_type': 'mw',
                'pval': mw_test_res['pvalue'],
            })

# One row per test, most significant first.
diff_feature_across_our_groups_df = pd.DataFrame(flat_dicts).sort_values('pval')
# False-discovery-rate adjustment over the Mann-Whitney p-values only (Kruskal rows keep NaN in this column).
mw_pval_mask = diff_feature_across_our_groups_df['test_type'] == 'mw'
mw_pvals = diff_feature_across_our_groups_df.loc[mw_pval_mask, 'pval']
diff_feature_across_our_groups_df.loc[mw_pval_mask, 'mw_bh_adjusted_pval'] = scipy.stats.false_discovery_control(mw_pvals)


# Kruskal-Wallis results of the gene-program features across the rough composition classes.
is_gene_prog_row = diff_feature_across_our_groups_df['feature_type'].isin(['gene_prog_excluding_mhc_ii', 'mhc_ii_gene_prog'])
is_rough_class_row = diff_feature_across_our_groups_df['class_col'] == 'rough_composition_class'
is_krusk_row = diff_feature_across_our_groups_df['test_type'] == 'krusk'
gene_prog_kruskal_rough_comp_df = diff_feature_across_our_groups_df[is_gene_prog_row & is_rough_class_row & is_krusk_row]

# Drop the bookkeeping columns, then merge with the within-state signature table.
within_state_sig_df = mds_analysis_params.get_within_state_sig_df()
cols_not_needed_in_table = ['mw_bh_adjusted_pval', 'clas_and_sex', 'krusk_stat', 'feature_type', 'test_type', 'class_col']
minimal_gene_prog_kruskal_rough_comp_df = generic_utils.merge_preserving_df1_index_and_row_order(
    gene_prog_kruskal_rough_comp_df.drop(columns=cols_not_needed_in_table),
    within_state_sig_df,
)
# Translate internal signature names to the names used in the paper.
minimal_gene_prog_kruskal_rough_comp_df['signature name'] = minimal_gene_prog_kruskal_rough_comp_df['sig_name'].replace(
    get_mds_params()['sig_name_to_name_in_paper'])
minimal_gene_prog_kruskal_rough_comp_df = minimal_gene_prog_kruskal_rough_comp_df.rename(columns={
    'state': 'state within which signature was evaluated',
    'pval': 'kruskal pvalue',
})


supp_table_out_dir_path = 'temp/__rejected_sup_tables'
pathlib.Path(supp_table_out_dir_path).mkdir(parents=True, exist_ok=True)
table_cols = ['signature name', 'state within which signature was evaluated', 'kruskal pvalue']
minimal_gene_prog_kruskal_rough_comp_df[table_cols].to_csv(
    os.path.join(supp_table_out_dir_path, 'within_state_gene_signature_kruskal_across_fig4_groups.csv'),
    index=False,
)

# Disabled interactive inspection snippets (never executed).
if 0:
    diff_feature_across_our_groups_df[~diff_feature_across_our_groups_df['plotted_in_fig4_or_sup'] & ~diff_feature_across_our_groups_df['feature_type'].isin(['log_c_state_freq'])].head(40)
    diff_feature_across_our_groups_df[diff_feature_across_our_groups_df['plotted_in_fig4_or_sup']].head(40)




# Map each sample's "<rough class> <sex>" label to its x position, then add a uniform jitter in
# [-0.3, 0.3) so points of the same group do not overlap. The global numpy RNG is seeded so the
# jitter (and hence the figure) is the same on every run.
curr_feature_df['rough_class_and_sex_pos_in_stripplot'] = curr_feature_df['rough_class_and_sex'].map(class_and_sex_to_pos_in_stripplot)
np.random.seed(0)
curr_feature_df['jittered_rough_class_and_sex_pos_in_stripplot'] = curr_feature_df['rough_class_and_sex_pos_in_stripplot'] + np.random.uniform(-0.3, 0.3, len(curr_feature_df))


print(f'len(curr_feature_df): {len(curr_feature_df)}')
# One saved figure per inner list of list_of_attrs_to_plot; one panel per attribute.
for fig_i, attrs_to_plot in enumerate(list_of_attrs_to_plot):

    plt.close('all')
    fig, axes = plt.subplots(ncols=len(attrs_to_plot), figsize=(13, 2.2))

    # Manual margins are used instead of tight_layout.
    if 1:
        fig.subplots_adjust(
            top=0.9,
            left=0.05,
            bottom=0.05,
            right=0.99,
        )

    for i, (ax, attr) in enumerate(zip(axes, attrs_to_plot)):
        # Samples missing this attribute are not plotted or tested.
        curr_curr_feature_df = curr_feature_df[curr_feature_df[attr].notna()]
        curr_observed_class_and_sex = set(curr_curr_feature_df['rough_class_and_sex'].tolist())
        curr_ordered_class_and_sex = [x for x in ordered_class_and_sex if x in curr_observed_class_and_sex]
        # Every "<class> <sex>" group must have at least one sample with this attribute.
        assert curr_ordered_class_and_sex == ordered_class_and_sex

        # At most one row per donor is expected here.
        assert curr_curr_feature_df['donor_id'].is_unique
        
        if 0:
            # Disabled alternative: seaborn's own strip plot.
            sb.stripplot(
                data=curr_curr_feature_df, 
                # x='rough_class_and_sex', 
                x='rough_class_and_sex_pos_in_stripplot', 
                dodge=False,
                jitter=0.3,
                linewidth=0.5,
                # order=curr_ordered_class_and_sex,
                y=attr, 
                ax=ax,
                hue='diagnosis_class', 
                palette=clinical_data_interface_and_utils.DIAGNOSIS_CLASS_TO_COLOR,
                s=4,
            )
        else:
            # Scatter plot at the pre-jittered x positions, colored by diagnosis class.
            sb.scatterplot(
                data=curr_curr_feature_df, 
                x='jittered_rough_class_and_sex_pos_in_stripplot', 
                y=attr, 
                ax=ax,
                hue='diagnosis_class', 
                palette=clinical_data_interface_and_utils.DIAGNOSIS_CLASS_TO_COLOR,
                edgecolor='k',
                s=20,
            )
            ax.get_xaxis().set_visible(False)

        # Extend the y range upwards by 7% to leave room for the significance stars.
        orig_y_lim = ax.get_ylim()
        y_lim_range = orig_y_lim[1] - orig_y_lim[0]
        ax.set_ylim(orig_y_lim[0], orig_y_lim[1] + 0.07 * y_lim_range)
        
        for clas_and_sex in curr_ordered_class_and_sex:
            # Only non-reference groups have a reference group to be compared against.
            if clas_and_sex in class_and_sex_to_ref_class_and_sex:
                # Recompute the Mann-Whitney test; it is used only as a consistency check against the stored row.
                mw_test_res = generic_utils.perform_mw_test(
                    curr_curr_feature_df.loc[curr_curr_feature_df['rough_class_and_sex'] == clas_and_sex, attr],
                    curr_curr_feature_df.loc[curr_curr_feature_df['rough_class_and_sex'] == class_and_sex_to_ref_class_and_sex[clas_and_sex], attr],
                    alternative='two-sided',
                )
                curr_pval = mw_test_res['pvalue']

                # Fetch the matching precomputed test row, which carries the adjusted p-value.
                row_in_diff_feature_df = diff_feature_across_our_groups_df[
                    (diff_feature_across_our_groups_df['test_type'] == 'mw') &
                    (diff_feature_across_our_groups_df['feature'] == attr) & 
                    (diff_feature_across_our_groups_df['clas_and_sex'] == clas_and_sex) 
                ]
                assert len(row_in_diff_feature_df) == 1
                row_in_diff_feature_df = row_in_diff_feature_df.iloc[0]
                # NOTE(review): the same assertion appears twice; the second one is redundant.
                assert np.isclose(row_in_diff_feature_df['pval'], curr_pval)
                assert np.isclose(row_in_diff_feature_df['pval'], curr_pval)
                bh_adjusted_pval = row_in_diff_feature_df['mw_bh_adjusted_pval']

                # Thresholds are ordered most stringent first; draw stars for the first one met.
                for max_pval, star_count in list_of_max_pval_and_star_count:
                    if bh_adjusted_pval < max_pval:
                        x_pos = class_and_sex_to_pos_in_stripplot[clas_and_sex]
                        # Nudge the stars to the right for female groups and to the left for male groups.
                        x_pos = x_pos + 0.125 * (1 if clas_and_sex.endswith('female') else -1) 
                        ax.annotate(
                            '*' * star_count,
                            # x is in data coordinates; y is a fraction of the axes height.
                            xy=(x_pos, 0.944), xycoords=('data', 'axes fraction'), color='k', ha='center', va='center', fontsize='large',
                        )
                        
                        break

        
        
        # The attribute name is shown as the panel title instead of axis labels.
        ax.set_ylabel(None)
        ax.set_xlabel(None)
        ax.set_title(attr)
        if 0:
            # Disabled: rotated tick labels.
            ax.set_xticklabels(
                ax.get_xticklabels(),
                rotation=90, ha='center', 
                # rotation_mode='anchor', 
                # fontsize='small',
            )
        # Dashed grey lines separate male/female within a class; solid black lines separate classes.
        for x in male_female_vline_xs:
            ax.axvline(x, color='grey', alpha=0.5, linestyle='--')
        for x in class_vline_xs:
            ax.axvline(x, color='black', alpha=0.9, linestyle='-')

        # Remove the per-panel legend created by the hue mapping.
        ax.get_legend().remove()
    # fig.tight_layout()
    fig.savefig(
        os.path.join(out_dir_path, f'fig_4EG_EDF_10C_sig_stripplots{fig_i}.png'), 
        dpi=300,
    )
    # break

# The table is already sorted by p-value, so head() shows the most significant gene-program tests.
print('top mw gene sigs:')
print(diff_feature_across_our_groups_df[mw_pval_mask & diff_feature_across_our_groups_df['feature_type'].isin(['gene_prog_excluding_mhc_ii', 'mhc_ii_gene_prog'])].head(15))

In [None]:
# Total donor counts per latest diagnosis class.

all_ext_feature_df = pd.read_csv(get_mds_params()['all_ext_donor_feature_df_csv_file_path'])

# One row per distinct (donor, latest diagnosis) pair.
latest_diagnosis_cols = ['donor_id', 'latest_diagnosis_class']
latest_diagnosis_df = all_ext_feature_df[latest_diagnosis_cols].drop_duplicates()
latest_diagnosis_counts = latest_diagnosis_df['latest_diagnosis_class'].value_counts()
print('all latest diagnosis counts')
print(latest_diagnosis_counts)
# latest_diagnosis_df.to_csv('donor_id_latest_diagnosis_class.csv', index=False)

# Counts among the rows flagged for the MDS/cytopenia groups analysis
# (rows are counted as they appear in the table, without de-duplicating by donor).
in_groups_analysis_mask = all_ext_feature_df['in_mds_cyto_groups_analysis']
print('fig2b latest diagnosis counts')
print(all_ext_feature_df.loc[in_groups_analysis_mask, 'latest_diagnosis_class'].value_counts())

In [None]:
# Biological replicates: non-normal donors sampled on more than one bleeding date.

all_ext_feature_df = pd.read_csv(get_mds_params()['all_ext_donor_feature_df_csv_file_path'])
curr_feature_df = all_ext_feature_df[all_ext_feature_df['in_mds_cyto_groups_analysis']]
print(curr_feature_df['diagnosis_class'].value_counts())

# Parse bleeding dates strictly (day.month.two-digit-year); any malformed date raises.
all_ext_feature_df['bleeding_date_as_date'] = pd.to_datetime(all_ext_feature_df['bleeding_date'], format='%d.%m.%y', errors='raise')

# Distinct (donor, bleeding date) pairs among cytopenia / MDS / MDS-MPN rows.
unhealthy_mask = all_ext_feature_df['diagnosis_class'].isin(['cytopenia', 'MDS', 'MDS/MPN'])
uniq_donor_bleeding_date_df = all_ext_feature_df.loc[unhealthy_mask, ['donor_id', 'bleeding_date_as_date']].drop_duplicates()

# Donors that appear with more than one distinct bleeding date.
bleeding_date_count_of_donor = uniq_donor_bleeding_date_df['donor_id'].value_counts()
unhealthy_with_bio_rep_donor_ids = list(bleeding_date_count_of_donor[bleeding_date_count_of_donor > 1].index)
print(f'len(unhealthy_with_bio_rep_donor_ids): {len(unhealthy_with_bio_rep_donor_ids)}')

uniq_patient_bleeding_date_df = uniq_donor_bleeding_date_df[uniq_donor_bleeding_date_df['donor_id'].isin(unhealthy_with_bio_rep_donor_ids)]
print(uniq_patient_bleeding_date_df['donor_id'].value_counts())

In [None]:
# Scatter of log BM FACS blast fraction against log CLP-E frequency, with an optional linear fit.
out_dir_path = 'temp/_fig4'
pathlib.Path(out_dir_path).mkdir(parents=True, exist_ok=True)

all_ext_feature_df = pd.read_csv(get_mds_params()['all_ext_donor_feature_df_csv_file_path'])

# Short numeric labels for composition classes (used only when SHOW_COMP_CLASS is on).
comp_class_to_num = {
    'norm-like comp': 1,
    'only HSC/MPP': 5,
    'high GMP': 6,
    'high BEMP': 4,
    'high CLP': 2,
    'high CLP low MEBEMP': 3,
    'other abnormal composition': 7,
}

# Plot toggles (flip manually while exploring).
SHOW_COMP_CLASS = False  # annotate selected points with their composition-class number
SHOW_DONOR_INFO = False  # annotate every point with its donor id
PLOT_LINEAR_FIT = True  # draw the least-squares line and report r / p-value in the title
PLOT_Y_EQUALS_X = False  # draw the y=x diagonal
FORCE_LOWER_LIM_ZERO = False  # start an axis at 0 when all its values are non-negative

HORIZ_LINE_YS = [
    # np.log2(0.05 + 0.005), # 0.05 is the BM FACS blast clinical threshold. 0.005 is the epsilon for log we used.
]

# NOTE: N413 (?) in demux_n_18_12_23_2 is discarded due to biased_composition_due_to_discarded_low_umi_count_barcodes

plt.close('all')

# Take an explicit copy of the filtered rows so that adding columns below operates on an
# independent frame (avoids SettingWithCopyWarning on a boolean-filtered frame).
curr_df = all_ext_feature_df[all_ext_feature_df['in_blast_vs_clp_e_scatter']].copy()

curr_df = generic_utils.merge_preserving_df1_index_and_row_order(curr_df, pd.read_csv(get_mds_params()['fig4_donor_composition_class_df_csv_file_path']), how='left', allow_multiple_common_column_names=True)
# Derived once, after the merge, so the derived column cannot take part in the merge on shared columns.
curr_df['c_clp_e_count'] = curr_df['c_clp_e'] * curr_df['final_c_hspc_count']

# Alternative axes used during exploration:
# y_column_name = 'final_c_hspc_count'
# y_column_name = 'c_clp_e_count'
# x_column_name = 'c_clp_e_count'
# x_column_name = 'c_clp_e'
y_column_name = 'log_bm_FACS_blast_frac'
x_column_name = 'log_c_clp_e'

palette = clinical_data_interface_and_utils.DIAGNOSIS_CLASS_TO_COLOR.copy()

# Every selected sample must have both coordinates.
x_or_y_nan_mask = curr_df[x_column_name].isna() | curr_df[y_column_name].isna()
assert not x_or_y_nan_mask.any()

fig, ax = plt.subplots(figsize=(5,4))
sb.scatterplot(
    data=curr_df,
    x=x_column_name,
    y=y_column_name,
    palette=palette,
    hue='diagnosis_class',
    # hue='in_mds_cyto_groups_analysis',
    # legend=False,
    ax=ax,
)

if SHOW_COMP_CLASS:
    for _, row in curr_df[curr_df['composition_class'].notna()].iterrows():
        # Label only points above either threshold.
        if (row['log_c_clp_e'] > -3) or (row['log_bm_FACS_blast_frac'] > -5):
            ax.annotate(' ' + str(comp_class_to_num[row['composition_class']]), row[[x_column_name, y_column_name]])

if SHOW_DONOR_INFO:
    for _, row in curr_df.iterrows():
        # Use the row's own coordinates; looking the point up again by numbered_donor_id took the
        # first match, which misplaces labels if ids repeat and costs a full scan per row.
        ax.annotate(
            row['numbered_donor_id_without_tech_rep_suffix'],
            (row[x_column_name], row[y_column_name]),
        )

if FORCE_LOWER_LIM_ZERO:
    x_lims = ax.get_xlim()
    y_lims = ax.get_ylim()
    if (curr_df[x_column_name] >= 0).all():
        ax.set_xlim(0, x_lims[1])
    if (curr_df[y_column_name] >= 0).all():
        ax.set_ylim(0, y_lims[1])

if PLOT_Y_EQUALS_X:
    generic_utils.plot_y_equals_x_line_on_ax(ax=ax)

# Stays None unless the fit is computed, so the title code below does not depend on a
# conditionally defined name.
linear_fit_repr = None
if PLOT_LINEAR_FIT:
    slope, intercept, r_value, p_value, std_err_of_estimated_gradient = scipy.stats.linregress(
        curr_df[x_column_name], curr_df[y_column_name],
    )
    generic_utils.plot_line_on_ax(ax, slope=slope, offset=intercept, color='grey', alpha=0.5, linestyle='--')
    linear_fit_repr = f'r={r_value:.2f}, pval={p_value:.0E}'
    print(f'slope={slope}, intercept={intercept}')

sb.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))

for y in HORIZ_LINE_YS:
    ax.axhline(y, color='red', alpha=0.2)

# Title: sample count, plus the fit summary when available.
ax_title = f'n={len(curr_df)}'
if linear_fit_repr is not None:
    ax_title += f', {linear_fit_repr}'
ax.set_title(ax_title)

fig.tight_layout()
fig.savefig(os.path.join(out_dir_path, 'fig_4H_blast_vs_clp_e.png'), dpi=300)

In [None]:
# Supplementary table: CH mutations per blood sample.
out_dir_path = 'temp/__sup_tables'
pathlib.Path(out_dir_path).mkdir(parents=True, exist_ok=True)

all_ext_feature_df = pd.read_csv(get_mds_params()['all_ext_donor_feature_df_csv_file_path'])
mut_df = arch_mutation_interface_and_utils.get_minimal_all_mutation_df()
# Translate internal donor ids to the ids used in the published tables.
mut_df['nimrod_compatible_donor_id'] = mut_df['donor_id'].replace(get_sc_rna_seq_preprocessing_params()['240723_donor_id_i_use_to_nimrod_donor_id'])

# exp_date and bleeding_date are interchangeable in the feature table; on that basis the mutation
# table's exp_date is used as the bleeding date for the merge below.
assert (all_ext_feature_df['exp_date'] == all_ext_feature_df['bleeding_date']).all()
mut_df['bleeding_date'] = mut_df['exp_date']
mut_df = generic_utils.merge_preserving_df1_index_and_row_order(
    mut_df, pd.read_csv(get_sc_rna_seq_preprocessing_params()['blood_sample_id_df_csv_file_path']), how='left', on=['donor_id', 'bleeding_date'])

# Keep only mutations that map to a known blood sample. The explicit copy makes the column
# assignments below operate on an independent frame (avoids SettingWithCopyWarning).
mut_df = mut_df[mut_df['blood_sample_id'].notna()].copy()
mut_df['blood_sample_id'] = mut_df['blood_sample_id'].astype(int)
mut_df['CHR'] = mut_df['CHR'].astype(str)
mut_df['POS'] = mut_df['POS'].astype(int)

mut_df = mut_df[[
    'blood_sample_id',
    'nimrod_compatible_donor_id',
    'gene', 
    'CHR', 'POS', 'REF', 'ALT', 
    'mean_VAF', 
]].drop_duplicates().sort_values('blood_sample_id')

# Row count before and after dropping the blood samples reported as unreliable.
print(len(mut_df))
mut_df = mut_df[~mut_df['blood_sample_id'].isin(mds_analysis.get_unreliable_arch_blood_sample_ids(all_ext_feature_df))]
print(len(mut_df))
mut_df.rename(columns=get_mds_params()['col_name_to_name_in_paper']).to_csv(os.path.join(out_dir_path, 's11_fig4_CH_mutations.csv'), index=False)

# Rows sharing chromosome, position and mean VAF with another row are printed for manual inspection.
manually_copied_ones_df = mut_df[mut_df[['CHR', 'POS', 'mean_VAF']].duplicated(keep=False)].sort_values('nimrod_compatible_donor_id')
if not manually_copied_ones_df.empty:
    print(f'\nmanually copied ones, i guess:\n{manually_copied_ones_df}')

In [None]:
# Supplementary table: copy-number alterations (CNAs) per blood sample.
out_dir_path = 'temp/__sup_tables'
pathlib.Path(out_dir_path).mkdir(parents=True, exist_ok=True)

# Samples with at least one CNA.
curr_df = pd.read_csv(get_mds_params()['all_ext_donor_feature_df_csv_file_path'])
curr_df = curr_df[curr_df['cna_count'] >= 1]
donor_with_cnas_ids = set(curr_df['donor_id'])
sample_id_cols = ['nimrod_compatible_donor_id', 'donor_id', 'exp_name', 'blood_sample_id']
curr_df = curr_df[sample_id_cols].drop_duplicates()

# CNA attributes, restricted to rows without a donor-level manual comment.
cna_attr_cols = [
    'donor_id', 'exp_name', 'chrom', 'whole_chrom',
    'first_gene_end_pos_in_chrom', 'last_gene_end_pos_in_chrom',
    'is_dup', 'is_del', 'manual_comment_for_blood_aging_paper',
]
cna_df = pd.read_csv(get_mds_params()['all_cna_attrs_df_csv_file_path'])
cna_df = cna_df.loc[cna_df['donor_general_manual_comment'].isna(), cna_attr_cols]

# exp_name == 'all' marks CNAs that apply to every experiment of the donor; the others are experiment-specific.
applies_to_all_exps_mask = cna_df['exp_name'] == 'all'
non_exp_specific_cna_df = cna_df[applies_to_all_exps_mask].drop(columns='exp_name')
exp_specific_cna_df = cna_df[~applies_to_all_exps_mask]
non_exp_specific_cna_donor_ids = set(non_exp_specific_cna_df['donor_id'])
exp_specific_cna_donor_ids = set(exp_specific_cna_df['donor_id'])

# A donor is described in exactly one of the two ways, and every donor with CNAs is described.
assert not (non_exp_specific_cna_donor_ids & exp_specific_cna_donor_ids)
assert (non_exp_specific_cna_donor_ids | exp_specific_cna_donor_ids) >= set(curr_df['donor_id'])

# Merge on the shared columns: donor_id only for the 'all' CNAs, donor_id and exp_name for the rest.
samples_of_non_exp_specific_donors_df = curr_df[curr_df['donor_id'].isin(list(non_exp_specific_cna_donor_ids))]
samples_of_exp_specific_donors_df = curr_df[curr_df['donor_id'].isin(list(exp_specific_cna_donor_ids))]
curr_df = pd.concat(
    [
        samples_of_non_exp_specific_donors_df.merge(non_exp_specific_cna_df),
        samples_of_exp_specific_donors_df.merge(exp_specific_cna_df),
    ],
    ignore_index=True,
)

assert set(curr_df['donor_id']) == donor_with_cnas_ids
assert (curr_df['is_dup'] | curr_df['is_del']).all()

# Human-readable CNA type; whole-chromosome events get a prefix and no gene coordinates.
whole_chrom_mask = curr_df['whole_chrom']
curr_df['CNA type'] = curr_df['is_dup'].replace({True: 'duplication', False: 'deletion'})
curr_df.loc[whole_chrom_mask, 'CNA type'] = 'whole chromosome ' + curr_df.loc[whole_chrom_mask, 'CNA type']
curr_df.loc[whole_chrom_mask, ['first_gene_end_pos_in_chrom', 'last_gene_end_pos_in_chrom']] = np.nan

table_cols = [
    'blood_sample_id',
    'nimrod_compatible_donor_id',
    'chrom',
    # 'whole_chrom',
    'first_gene_end_pos_in_chrom',
    'last_gene_end_pos_in_chrom',
    'CNA type',
    'manual_comment_for_blood_aging_paper',
]
curr_df = curr_df[table_cols].drop_duplicates().sort_values(['blood_sample_id', 'chrom', 'first_gene_end_pos_in_chrom'])
curr_df.rename(columns=get_mds_params()['col_name_to_name_in_paper']).to_csv(os.path.join(out_dir_path, 's12_fig4_CNAs.csv'), index=False)

# Sanity check on the number of distinct CNAs, ignoring sample and coordinates.
# 240729: that's the number in EDF 9C
distinct_cna_df = curr_df.drop(columns=['blood_sample_id', 'first_gene_end_pos_in_chrom', 'last_gene_end_pos_in_chrom']).drop_duplicates()
assert len(distinct_cna_df) == 38


In [None]:
# Supplementary table: per-sample scRNA-seq metadata, clinical values, features and cell-filtering counts.
out_dir_path = 'temp/__sup_tables'
pathlib.Path(out_dir_path).mkdir(parents=True, exist_ok=True)

# Set to True to build the table without the cell-filtering count columns
# (needed when the per-dataset anndata files are not available).
skip_cell_filtering_count_cols = False

all_ext_feature_df = pd.read_csv(get_mds_params()['all_ext_donor_feature_df_csv_file_path'])

# Blank out placeholder treatment values. The result is assigned back instead of calling
# replace(inplace=True) on the column selection: the chained in-place form raises a FutureWarning on
# recent pandas and does not modify the frame under Copy-on-Write.
all_ext_feature_df['treatment'] = all_ext_feature_df['treatment'].replace(
    {x: np.nan for x in clinical_data_interface_and_utils.UNKNOWN_TREATMENT_VALUES})

all_ext_feature_df['BM FACS blasts (%)'] = all_ext_feature_df['bm_FACS_blast_frac'] * 100
all_ext_feature_df['FACS_BM_date'] = pd.to_datetime(all_ext_feature_df['FACS_BM_date'])
# Published date is "<year>-<month>", or only "<year>" when fake_FACS_BM_month is truthy.
# NOTE(review): for rows with a missing FACS_BM_date this presumably yields the literal string 'nan' - confirm
# that is acceptable in the published table.
all_ext_feature_df['facs_bm_date_repr'] = all_ext_feature_df.apply(
    lambda row: str(row["FACS_BM_date"].year) + ('' if row['fake_FACS_BM_month'] else ('-' + str(row["FACS_BM_date"].month))), axis=1)

# Names of the final features, excluding those of type 'CBC' (CBC columns are added separately below).
with_vaf_classifier_final_feature_names_and_types_df = pd.read_csv(get_mds_params()['with_vaf_classifier_final_feature_names_and_types_df_csv_file_path'])
with_vaf_classifier_final_feature_names_without_cbc = with_vaf_classifier_final_feature_names_and_types_df.loc[
    with_vaf_classifier_final_feature_names_and_types_df['type'] != 'CBC', 'name'].tolist()


if skip_cell_filtering_count_cols:
    count_cols = []
    curr_df = all_ext_feature_df.copy()
else:
    cell_filtering_count_dfs = []
    prev_count_cols = None
    for dataset_name in ['illu_mds', 'ult_mds']:
        mds_path_dict = mds_analysis_params.get_mds_path_dict(dataset_name=dataset_name)
        if not os.path.isfile(mds_path_dict['assigned_and_no_doublet_enirhced_mc_c_ad_file_path']):
            raise RuntimeError('you can generate the table excluding cell filtering count cols by setting skip_cell_filtering_count_cols=True')
        count_df, count_cols = mds_analysis.write_and_get_cell_filtering_count_df(dataset_name, use_existing=True)
        # Both datasets must report the same count columns so their tables can be stacked.
        if prev_count_cols is not None:
            assert count_cols == prev_count_cols
        else:
            prev_count_cols = count_cols
        cell_filtering_count_dfs.append(count_df)
    cell_filtering_count_df = pd.concat(cell_filtering_count_dfs, ignore_index=True)

    curr_df = generic_utils.merge_preserving_df1_index_and_row_order(all_ext_feature_df, cell_filtering_count_df)

curr_df['seq_platform'] = curr_df['is_ultima'].replace({True: 'Ultima', False: 'Illumina'})
assert curr_df['seq_platform'].isin(['Ultima', 'Illumina']).all()

# Count columns that are not published.
unpublished_count_cols = {'mouse_human_doublet_count', 'not_whitelist_barcode_count', 'dc_nkt_monocyte_endothel_b_lamp3_count', 'presumably_hspc_count'}

scrna_sample_df = curr_df[[
    'blood_sample_id',
    'nimrod_compatible_donor_id',
    'donor_sex',
    'donor_age',
    'bleeding_date_as_date',
    'diagnosis_class',
    'treatment',
    'scrna_exp_id',
    'seq_platform',
    'in_fig4_ref_model',
    'in_mds_cyto_groups_analysis',
    'in_train_set',
    'in_test_set',
    'in_blast_vs_clp_e_scatter',
    'used_to_fit_hspc_compos_knn',
    'in_orig_ref_model',

    'facs_bm_date_repr',
    'BM FACS blasts (%)',
    'cbc_date',
    # CBC columns, in the order they appear in the feature table.
    *all_ext_feature_df.columns[all_ext_feature_df.columns.isin(list(clinical_data_interface_and_utils.CBC_COL_TO_PUBLISH_COL))],

    *with_vaf_classifier_final_feature_names_without_cbc,

    'final_c_hspc_count',
    *[x for x in count_cols if x not in unpublished_count_cols],

]].sort_values(['blood_sample_id', 'scrna_exp_id'])

# No column may be selected twice, and every final feature (CBC included) must be present.
assert len(scrna_sample_df.columns) == len(set(scrna_sample_df.columns))
assert set(with_vaf_classifier_final_feature_names_and_types_df['name']) <= set(scrna_sample_df.columns)

# Rename internal column names to the names used in the paper and write the table.
scrna_sample_df.rename(columns={
    **mds_analysis_params.get_gene_sig_score_to_name_in_paper(), 
    **get_mds_params()['col_name_to_name_in_paper'],
    **clinical_data_interface_and_utils.CBC_COL_TO_PUBLISH_COL,
}).to_csv(os.path.join(out_dir_path, 's9_fig4_scRNA_sample_metadata.csv'), index=False)

# df.to_csv(os.path.join(out_dir_path, 'asdf.csv'), index=False)

In [None]:
# Supplementary table: gene membership in each signature (one boolean column per signature).
out_dir_path = 'temp/__sup_tables'
pathlib.Path(out_dir_path).mkdir(parents=True, exist_ok=True)

sig_name_to_name_in_paper = get_mds_params()['sig_name_to_name_in_paper']
sig_names = list(sig_name_to_name_in_paper)
# Union of the genes of all listed signatures, sorted.
all_genes = sorted({
    gene
    for sig_name in sig_names
    for gene in mds_analysis_params.PB_SIG_NAME_TO_GENES[sig_name]
})
within_state_sig_df = mds_analysis_params.get_within_state_sig_df()
within_state_sig_names = set(within_state_sig_df['sig_name'])

df = pd.DataFrame({'gene': all_genes})

for sig_name in sig_names:
    col = sig_name_to_name_in_paper[sig_name]
    # Signatures that are not within-state signatures are labelled as in-silico sorting signatures.
    if sig_name not in within_state_sig_names:
        col += ' (used for in-silico sorting)'
    df[col] = df['gene'].isin(mds_analysis_params.PB_SIG_NAME_TO_GENES[sig_name])

out_file_path = os.path.join(out_dir_path, 's13_fig4_gene_signatures.csv')
df.rename(columns=get_mds_params()['col_name_to_name_in_paper']).to_csv(out_file_path, index=False)


In [None]:
# Supplementary table: genes excluded from the CNA analysis, with one boolean column per exclusion reason.
out_dir_path = 'temp/__sup_tables'
pathlib.Path(out_dir_path).mkdir(parents=True, exist_ok=True)

unified_mc_model_name = 'final_mds_cyto_normal_excluding_atlas'

platform_specific_genes = lateral_and_noisy_genes.ULT_ILL_PROBLEMATIC_GENES
sex_diff_genes = get_mds_params()['noisy_gene_names']['sex_diff_expressed']

ult_ill_problematic_or_sex_diff_genes = set(platform_specific_genes) | set(sex_diff_genes)
# all_genes starts from these two gene lists and grows with every per-dataset exclusion list read below.
all_genes = set(ult_ill_problematic_or_sex_diff_genes)
print(f'len(ult_ill_problematic_or_sex_diff_genes): {len(ult_ill_problematic_or_sex_diff_genes)}')

print(f'len(lateral_and_noisy_genes.ULT_ILL_PROBLEMATIC_GENES): {len(set(platform_specific_genes))}')
print(f"len(get_mds_params()['noisy_gene_names']['sex_diff_expressed']): {len(set(sex_diff_genes))}")

gene_descs = [
    'genes_excluded_from_cna_due_to_consistent_obs_exp_diff',
    'genes_excluded_from_cna_due_to_low_expr',
    'genes_excluded_from_cna_due_to_missing_position_or_on_y',
]

col_to_genes = {}  # '<dataset>_<reason>' -> gene list read from that dataset's file
desc_to_illu_and_ult_genes = {}  # reason -> union of the genes over both datasets
for gene_desc in gene_descs:
    genes_of_both_datasets = set()
    for dataset_name in ('ult_mds', 'illu_mds'):
        unified_mc_model_paths = mds_analysis_params.get_mc_model_paths(unified_mc_model_name, dataset_name)
        # One gene name per line.
        curr_genes = generic_utils.read_text_file(unified_mc_model_paths[gene_desc]).splitlines()
        col_to_genes[f'{dataset_name}_{gene_desc}'] = curr_genes
        genes_of_both_datasets.update(curr_genes)
    all_genes |= genes_of_both_datasets
    desc_to_illu_and_ult_genes[gene_desc] = genes_of_both_datasets

for gene_desc, genes in desc_to_illu_and_ult_genes.items():
    print(f'{gene_desc}: {len(set(genes))}')

df = pd.DataFrame({'gene': sorted(all_genes)})
df['sex-specific'] = df['gene'].isin(sex_diff_genes)
df['sequencing platform-specific'] = df['gene'].isin(platform_specific_genes)
for col, genes in col_to_genes.items():
    df[col] = df['gene'].isin(genes)

# Column names used in the published table.
col_to_published_col = {
    **get_mds_params()['col_name_to_name_in_paper'],
    'ult_mds_genes_excluded_from_cna_due_to_consistent_obs_exp_diff': 'consistent difference between observed and projected (Ultima)',
    'ult_mds_genes_excluded_from_cna_due_to_low_expr': 'lowly expressed (Ultima)',
    'ult_mds_genes_excluded_from_cna_due_to_missing_position_or_on_y': 'missing chromosomal coordinates or on Y chromosome (Ultima)',
    'illu_mds_genes_excluded_from_cna_due_to_consistent_obs_exp_diff': 'consistent difference between observed and projected (Illumina)',
    'illu_mds_genes_excluded_from_cna_due_to_low_expr': 'lowly expressed (Illumina)',
    'illu_mds_genes_excluded_from_cna_due_to_missing_position_or_on_y': 'missing chromosomal coordinates or on Y chromosome (Illumina)',
}
df.rename(columns=col_to_published_col).to_csv(
    os.path.join(out_dir_path, 's14_fig4_genes_excluded_from_CNA_analysis.csv'), index=False)


In [None]:
# For each N200 experiment: scatter of N200_del_5q vs N200_dup_8 over the metacells whose
# state is in PB_HSPC_STATE_NAMES, with the per-experiment thresholds drawn as red lines.
out_dir_path = 'temp/_fig_sup_4/N200'
os.makedirs(out_dir_path, exist_ok=True)

karyo_out_dir_path = mds_analysis_params.get_mds_path_dict(dataset_name='ult_mds')['karyotype_estimation_out_dir_path']

# (experiment name, del(5q) threshold, dup(8) threshold)
for exp_name, N200_del_5q_thresh, N200_dup_8_thresh in (
    ('demux_26_03_23_1', -0.2, 0.24),
    ('demux_11_04_21_2', -0.2, 0.18),
    ('demux_10_01_22_1', -0.25, 0.21),
):
    mc_model_paths = mds_analysis_params.get_mc_model_paths(f'final_only_N200_only_{exp_name}', dataset_name='ult_mds')
    mc_ad_file_path = mc_model_paths['metacells_with_projection_ad']
    mc_ad = ad.read_h5ad(mc_ad_file_path)
    df_csv_file_path = f'{karyo_out_dir_path}/N200/{exp_name}/corrected_by_gc/mc_hists/cna_median_normalized_projected_fold_df.csv'
    plt.close('all')
    df = pd.read_csv(df_csv_file_path)
    # Only the row count is checked; the mask below is applied positionally, so the CSV rows
    # are assumed to be in the same order as mc_ad.obs.
    assert len(df) == mc_ad.n_obs

    hspc_mask = mc_ad.obs['state'].isin(mds_analysis_params.PB_HSPC_STATE_NAMES).to_numpy()
    df = df.loc[hspc_mask]

    fig, ax = plt.subplots()
    sb.scatterplot(
        data=df,
        x='N200_del_5q',
        y='N200_dup_8',
        hue='N200_del_3p_end_and_3q_start',
        s=15,
        ax=ax,
    )
    ax.axvline(N200_del_5q_thresh, color='red', alpha=0.3)
    ax.axhline(N200_dup_8_thresh, color='red', alpha=0.3)

    # A metacell counts as mutated when it is below the del(5q) threshold AND above the dup(8) threshold.
    is_mutated_mc = df['N200_del_5q'].lt(N200_del_5q_thresh) & df['N200_dup_8'].gt(N200_dup_8_thresh)
    mutated_mc_frac = is_mutated_mc.mean()
    ax.set_title(f'{exp_name}, mutated_mc_frac: {mutated_mc_frac:.3f}')
    sb.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
    fig.tight_layout()
    fig_file_name = f'for_EDF_12A_{exp_name}_mutated_mc_by_del_5q_and_dup_8_scatter.png'
    fig.savefig(os.path.join(out_dir_path, fig_file_name), dpi=300)
    

In [None]:
# Per-donor CBC history plots: the chosen CBC column over cbc_date, with a grey vertical
# line at each of the donor's bleeding dates taken from the donor feature table.
out_dir_path = 'temp/_fig_sup_4/CBC_history_case_studies'
os.makedirs(out_dir_path, exist_ok=True)

all_ext_feature_df = pd.read_csv(get_mds_params()['all_ext_donor_feature_df_csv_file_path'])
# bleeding_date is parsed as day.month.two-digit-year; an unparsable value raises.
all_ext_feature_df['bleeding_date_as_date'] = pd.to_datetime(
    all_ext_feature_df['bleeding_date'], format='%d.%m.%y', errors='raise')

minimal_cbc_df_csv_file_path = get_sc_rna_seq_preprocessing_params()['donor_table_paths']['minimal_cbc_df_csv']
cbc_df = pd.read_csv(minimal_cbc_df_csv_file_path).drop_duplicates()
cbc_df['cbc_date'] = pd.to_datetime(cbc_df['cbc_date'], errors='raise')

for donor_id in ('N200', 'N211'):
    curr_df = cbc_df.loc[cbc_df['donor_id'].eq(donor_id)]
    tenx_dates = all_ext_feature_df.loc[all_ext_feature_df['donor_id'].eq(donor_id), 'bleeding_date_as_date']

    for y_col in (
        'Hemoglobin (g/dl)',
        # 'RDW (%)',
    ):
        plt.close('all')
        fig, ax = plt.subplots(figsize=(15,3))
        sb.scatterplot(data=curr_df, x='cbc_date', y=y_col, ax=ax)
        ax.set_title(donor_id)
        ax.set_xlabel(None)
        for x in tenx_dates:
            print(x)
            ax.axvline(x, color='grey', alpha=0.5)
        fig.tight_layout()

        # The column name is sanitized by the helper so it can be used inside a file name.
        y_col_file_name_part = generic_utils.strip_special_chars_on_edges_and_replace_others_with_underscores(y_col)
        fig.savefig(
            os.path.join(out_dir_path, f'EDF_12AB_{donor_id}_{y_col_file_name_part}_history.png'),
            dpi=300,
        )
