## Contents:
* [Loading packages](#Loading_packages)
* [Reading cell2location model output](#read)


* [Fig 4C](#Fig4C)

* [Fig 4F - B IFN+ & co-located cells](#Fig4F)
* [Fig 4G - binned pseudospace dotplot](#Fig4G) 


* [Read saved sklearn colocated cell group model](#read_nmf)
* [Fig 4D - co-located cell groups - dotplot](#Fig4D_dotplot)
* [Fig 4D - Supplementary - matching clustering - dotplot](#Fig4D_suppl_clust)
* [Fig 4D - factor spatial](#Fig4D_spatial)

### Loading packages <a class="anchor" id="Loading_packages"></a>

In [None]:
# Import modules and packages:
# NOTE: `%pylab inline` is an IPython magic. `rcParams` is used below before its
# explicit import further down, and `matplotlib` is used in later cells without
# ever being imported explicitly, so both names must come from this magic.
%pylab inline
import pandas as pd
import sys, ast, os
# switch off the right and top axes spines for all figures
rcParams['axes.spines.right'] = False
rcParams['axes.spines.top'] = False
import pickle as pickle
import numpy as np
import time
import itertools
data_type = 'float32'

# THEANO_FLAGS is set before cell2location is imported below
os.environ["THEANO_FLAGS"] = 'device=cuda,floatX=' + data_type + ',force_device=True'
# put local cell2location checkouts near the front of the module search path
sys.path.insert(1, '/nfs/team205/vk7/sanger_projects/BayraktarLab/cell2location/')
sys.path.insert(1, '/nfs/team205/vk7/sanger_projects/cell2location_dev/')

%matplotlib inline
import seaborn as sns
import matplotlib.pyplot as plt
import warnings
# NOTE(review): `os` is already imported above; this second import is redundant
import os
import cell2location.models as c2l
import anndata
import scanpy as sc

from matplotlib import rcParams
rcParams['pdf.fonttype'] = 42 # enables correct plotting of text

### Reading cell2location model output <a class="anchor" id="read"></a>

In [None]:
# read cell2location model results
results_folder = '/nfs/team205/vk7/sanger_projects/cell2location_paper/notebooks/results/visium_human_ln/'
# NOTE(review): sc_data_folder is not referenced again in the visible notebook
sc_data_folder = '/nfs/team205/vk7/sanger_projects/cell2location_proj/notebooks/data/b_t_cells_tonsils_hk/'

run_name = 'CoLocationModelNB4V2_34clusters_4039locations_10241genes_input_inferred_V4_batch1024_l2_0001_n_comb50_5_cps5_fpc3_alpha001'

# path for saving figures
fig_path = results_folder + 'std_model/'+run_name+'/plots/figures/'
# remember the previous scanpy figure directory, then point scanpy at fig_path
sc_figpath = sc.settings.figdir
sc.settings.figdir = f'{fig_path}'


sp_data_file = results_folder + 'std_model/'+run_name+'/sp_with_clusters.h5ad'
adata_vis = anndata.read(sp_data_file)
# label every location with the first key of uns['spatial'] (a single value for all rows)
adata_vis.obs['sample'] = list(adata_vis.uns['spatial'].keys())[0]

In [None]:
from cell2location.plt.plot_heatmap import clustermap
# posterior-mean 'comb2fact' matrix with columns labelled by the model's fact_names
comb2fact_df = pd.DataFrame(adata_vis.uns['mod']['post_sample_means']['comb2fact'],
                            columns=adata_vis.uns['mod']['fact_names'])

rcParams["figure.figsize"] = [6, 5]
rcParams["axes.facecolor"] = "white"
# histogram of the log10 row-wise maximum, with the filtering threshold drawn as a vertical line
plt.hist(np.log10(comb2fact_df.max(1)), bins=50);
thresh = 0.15
plt.vlines(np.log10(thresh), 0, 50);
# boolean mask: rows whose largest value exceeds the threshold
fact_filt = comb2fact_df.max(1) > thresh

comb2fact_df = comb2fact_df.loc[fact_filt.values, :]

clustermap(comb2fact_df, figure_size=(17, 7))

# same matrix with every column rescaled to sum to 1
comb2fact_df_prop = (comb2fact_df / comb2fact_df.sum(0))
clustermap(comb2fact_df_prop, figure_size=(17, 7))

# posterior-mean 'combs_factors' values per location, one 'combs_<i>' column per combination
n_combs = adata_vis.uns['mod']['post_sample_means']['combs_factors'].shape[1]
combs_factors = pd.DataFrame(adata_vis.uns['mod']['post_sample_means']['combs_factors'],
                             columns=[f'combs_{i}' for i in range(n_combs)],
                             index=adata_vis.uns['mod']['obs_names'])
# keep only the combinations selected by the row filter computed above
combs_factors = combs_factors.loc[:, fact_filt.values]
adata_vis.obs[combs_factors.columns] = combs_factors


s = 'V1_Human_Lymph_Node'
rcParams["axes.facecolor"] = "black"
# one spatial panel per retained combination, restricted to sample `s`
sc.pl.spatial(adata_vis[adata_vis.obs['sample'].isin([s]),:], cmap='magma',
              color=combs_factors.columns, # limit size in this notebook
              library_id=f'{s}',
              ncols=3, 
              size=1, img_key='hires', 
              alpha_img=0,
              vmin=0, vmax='p99.0'
             )

In [None]:
import matplotlib.patches as patches
from mpl_toolkits.axes_grid.inset_locator import inset_axes as inset_axes_func


def add_rectange_to_axis(ax, crop_coord, **kwargs):
    """Draw a rectangle patch on `ax`.

    `crop_coord` holds four values: the first two give the horizontal extent,
    the last two the vertical extent. The rectangle is anchored at
    (crop_coord[0], crop_coord[3]) with width crop_coord[1] - crop_coord[0]
    and height crop_coord[2] - crop_coord[3].

    Extra keyword arguments are passed to `patches.Rectangle` and override the
    defaults (red outline, line width 3, no fill, zorder 10).
    """
    x_start, x_end = crop_coord[0], crop_coord[1]
    y_first, y_second = crop_coord[2], crop_coord[3]

    # defaults first, caller-supplied settings win
    style = dict(linewidth=3, edgecolor='r', facecolor='none', zorder=10)
    style.update(kwargs)

    box = patches.Rectangle((x_start, y_second),
                            x_end - x_start,
                            y_first - y_second,
                            **style)
    ax.add_patch(box)
    

def add_rectangle_to_fig(fig, crop_coord, **kwargs):
    """Draw the same rectangle on the axes of `fig` that pass an aspect filter.

    An axis is annotated when its aspect is a string or a number below 5.
    NOTE(review): the numeric cut-off presumably skips elongated colorbar
    axes -- confirm against the figures this is used on.

    `crop_coord` and `kwargs` are forwarded unchanged to `add_rectange_to_axis`.
    """
    for ax in fig.axes:
        aspect = ax.get_aspect()
        # isinstance also accepts str subclasses and avoids the exact-type check
        if isinstance(aspect, str) or aspect < 5:
            add_rectange_to_axis(ax, crop_coord, **kwargs)
            
def plot_1D_posterior(adata, x, ct_list, xlabel='', 
                      ylabel='Inferred cell density',
                      param='spot_factors', shaded_alpha=0.5,
                      show_points=False, point_size=0.5,
                      reorder_cmap=None, vlines=None):
    r""" Plot posterior of cell types and factors along some 1D gradient (e.g. diffusion pseudospace).
    This is done by computing LOESS function of the location-specific cell type density 
    and factor expression parameters (up to 7). Solid line shows the LOESS-smoothed function of the posterior mean, 
    shaded areas highlight 5% and 95% posterior quantiles highlighting cell type mapping uncertainty.

    :param adata: object whose `.obs` holds the columns 'mean_' + param + ct,
        'q05_' + param + ct and 'q95_' + param + ct for every ct in `ct_list`
    :param x: position of every location along the gradient; lines are drawn in the order given
    :param ct_list: names of cell types / factors to plot (one colour each, 7 colours available)
    :param xlabel: x axis label
    :param ylabel: y axis label
    :param param: middle part of the `.obs` column names
    :param shaded_alpha: transparency of the area between the smoothed 5% and 95% quantiles
    :param show_points: also scatter the unsmoothed posterior means
    :param point_size: marker size used when `show_points` is True
    :param reorder_cmap: indices selecting / reordering the built-in colours
    :param vlines: x position(s) of vertical lines; they reach 5% above the highest
        smoothed 95% quantile over all plotted cell types
    """
    
    colors=[( 240/256, 228/256, 66/256),
            ( 213/256, 94/256, 0/256),
            ( 86/256, 180/256, 233/256),
            ( 0/256,158/256, 115/256),
            'purple',
            ( 200/256, 200/256, 200/256),
            ( 50/256, 50/256, 50/256)
           ]
    if reorder_cmap is not None:
        colors = [colors[i] for i in reorder_cmap]

    from statsmodels.nonparametric.smoothers_lowess import lowess

    # peak of the smoothed q95 curve of every cell type, used to size the vlines
    q95_peaks = []

    for i, ct in enumerate(ct_list):
    
        y1_mean=adata.obs['mean_' + param + ct]
        y1_mean_s=lowess(y1_mean, x, return_sorted=False)
        y1_q05=adata.obs['q05_' + param + ct]
        y1_q05_s=lowess(y1_q05, x, return_sorted=False)
        y1_q95=adata.obs['q95_' + param + ct]
        y1_q95_s=lowess(y1_q95, x, return_sorted=False)
        q95_peaks.append(np.max(y1_q95_s))

        if show_points:
            plt.scatter(x, y1_mean, color=colors[i], s=point_size)
        plt.plot(x, y1_mean_s, color=colors[i])
        plt.fill_between(x, y1_q05_s, y1_q95_s, color=colors[i], alpha=shaded_alpha)
    # previously only the last cell type's curve set the line height (and an
    # empty ct_list raised NameError); use the tallest curve instead
    if vlines is not None and q95_peaks:
        y_top = np.max(q95_peaks)
        plt.vlines(vlines, 0, y_top + 0.05 * y_top)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)


### Fig 4C <a class="anchor" id="Fig4C"></a>

In [None]:
from cell2location.plt.mapping_video import plot_spatial
# select up to 6 clusters 
sel_clust = ['FDC', 'T_CD4+_naive']
# names of the matching 5% posterior quantile columns in adata_vis.obs
sel_clust_col = ['q05_spot_factors' + str(i) for i in sel_clust]

crop_x = [230, 1640]
crop_y = [1760, 150]
# zoomed-in region (same values as crop_x / crop_y of the following zoom-in cell);
# the vertical limits are computed as hires image height minus [1500, 1850]
crop_x_small = [810, 1060]
crop_y_small = list(adata_vis.uns['spatial'].values())[0]['images']['hires'].shape[0] - np.array([1500, 1850])

rcParams["figure.figsize"] = [25,25]
# coordinates are multiplied by the hires scale factor to match the hires image
fig = plot_spatial(adata_vis.obs[sel_clust_col], labels=sel_clust,
              coords=adata_vis.obsm['spatial'] * list(adata_vis.uns['spatial'].values())[0]['scalefactors']['tissue_hires_scalef'], 
              img=list(adata_vis.uns['spatial'].values())[0]['images']['hires'], show_img=True, img_alpha=1,
              circle_diameter=11, alpha_scaling=1, reorder_cmap=[1, 3],
              crop_x=crop_x, crop_y=crop_y, axis_y_flipped=False,
                 colorbar_label_kw={'size': 45, 'weight': 'bold', 'y': 1.29},
                 colorbar_shape={'vertical_gaps': 0.12, 'horizontal_gaps': 1, 
                                 'width': 3.5, 'height': 0.045},
                 colorbar_tick_size=20)
# dashed white box around the zoomed-in region
add_rectangle_to_fig(fig, crop_coord=list(crop_x_small) + list(crop_y_small),
                     edgecolor='white', linestyle='--', linewidth=7)
fig.savefig(f'{fig_path}Fig4C_maps.pdf',
                bbox_inches='tight', facecolor='white')

# same call with alpha_scaling=0, saved as the "histology" panel
# (presumably this hides the abundance overlay -- confirm in plot_spatial)
fig = plot_spatial(adata_vis.obs[sel_clust_col], labels=sel_clust,
              coords=adata_vis.obsm['spatial'] * list(adata_vis.uns['spatial'].values())[0]['scalefactors']['tissue_hires_scalef'], 
              img=list(adata_vis.uns['spatial'].values())[0]['images']['hires'], show_img=True, img_alpha=1,
              circle_diameter=11, alpha_scaling=0, reorder_cmap=[1, 3],
              crop_x=crop_x, crop_y=crop_y, axis_y_flipped=False,
                 colorbar_label_kw={'size': 45, 'weight': 'bold', 'y': 1.29},
                 colorbar_shape={'vertical_gaps': 0.12, 'horizontal_gaps': 1, 
                                 'width': 3.5, 'height': 0.045},
                 colorbar_tick_size=20)
add_rectangle_to_fig(fig, crop_coord=list(crop_x_small) + list(crop_y_small),
                     edgecolor='white', linestyle='--', linewidth=7)
fig.savefig(f'{fig_path}Fig4C_histology.pdf',
                bbox_inches='tight', facecolor='white')

In [None]:
# zoom-in panels: crop limits equal the crop_x_small / crop_y_small box of the previous cell
crop_x = [810, 1060]
crop_y = list(adata_vis.uns['spatial'].values())[0]['images']['hires'].shape[0] - np.array([1500, 1850])

rcParams["figure.figsize"] = [7, 7]
# reuses sel_clust / sel_clust_col defined in the previous cell
fig = plot_spatial(adata_vis.obs[sel_clust_col], labels=sel_clust,
              coords=adata_vis.obsm['spatial'] * list(adata_vis.uns['spatial'].values())[0]['scalefactors']['tissue_hires_scalef'], 
              img=list(adata_vis.uns['spatial'].values())[0]['images']['hires'], show_img=True, img_alpha=1,
              circle_diameter=16, alpha_scaling=1, reorder_cmap=[1, 3],
              crop_x=crop_x, crop_y=crop_y, axis_y_flipped=False)
fig.savefig(f'{fig_path}Fig4C_maps_zoom_in.pdf',
            bbox_inches='tight', facecolor='white')

# same region with alpha_scaling=0, saved as the histology version
fig = plot_spatial(adata_vis.obs[sel_clust_col], labels=sel_clust,
              coords=adata_vis.obsm['spatial'] * list(adata_vis.uns['spatial'].values())[0]['scalefactors']['tissue_hires_scalef'], 
              img=list(adata_vis.uns['spatial'].values())[0]['images']['hires'], show_img=True, img_alpha=1,
              circle_diameter=16, alpha_scaling=0, reorder_cmap=[1, 3],
              crop_x=crop_x, crop_y=crop_y, axis_y_flipped=False)
fig.savefig(f'{fig_path}Fig4C_maps_zoom_in_histology.pdf',
            bbox_inches='tight', facecolor='white')

### Fig 4F - B IFN+ & co-located cells <a class="anchor" id="Fig4F"></a>

In [None]:
# Visualize cell type locations
# work on a copy and expose the 5% quantile columns under the plain cell type names
# (values are copied unchanged; no log transform is applied here)
adata_vis_pl = adata_vis.copy()
clust_names_orig = ['q05_spot_factors' + i for i in adata_vis.uns['mod']['fact_names']]
clust_names = adata_vis.uns['mod']['fact_names']
adata_vis_pl.obs[clust_names] = (adata_vis_pl.obs[clust_names_orig])

rcParams["axes.facecolor"] = "black"
rcParams["savefig.facecolor"] = "white"
rcParams["figure.figsize"] = [10,10]
rcParams["font.size"] = 24

crop_x = [230, 1640]
crop_y = [150, 1760]

# NOTE(review): this takes shape[1] of the hires image while the Fig 4C cell uses
# shape[0] for vertical limits -- equivalent only if the image is square; confirm
max_y = list(adata_vis_pl.uns['spatial'].values())[0]['images']['hires'].shape[1]
crop_y = [max_y - i + 80 for i in crop_y]
# img_shape is also read later by plot_spatial_factors
img_shape = list(adata_vis.uns['spatial'].values())[0]['images']['hires'].shape

ct_list=['B_IFN', 'T_TIM3+', 'NK', 'T_CD8+_cytotoxic']

# overview panels, one per entry of ct_list
fig = sc.pl.spatial(adata_vis_pl, cmap='magma',
              color=ct_list, ncols=4, 
              size=1.3, img_key='hires', alpha_img=0,
              frameon=True, legend_fontsize=50,
              crop_coord=crop_x + [crop_y[0]] + [crop_y[1]],
              vmin=0, vmax='p99.5', 
              return_fig=True, save='Fig_4F.pdf'
             )
# dashed box; its x limits (960, 1430) and the 1670 / 1370 offsets match the
# crop_coord of the cropped figure below
add_rectangle_to_fig(fig, [960, 1430, 
                           img_shape[1] - 1670 + (img_shape[1] - crop_y[0]), 
                           img_shape[1] - 1370 + (img_shape[1] - crop_y[0])],
                     edgecolor='white', linestyle='--', linewidth=5)
fig.savefig(f'{fig_path}Fig_4F.pdf',
            bbox_inches='tight', facecolor='white')

rcParams["figure.figsize"] = [13.5,6]
# cropped panels of the same cell types
sc.pl.spatial(adata_vis_pl, cmap='magma',
              color=ct_list, ncols=4, 
              size=1.3, img_key='hires', alpha_img=0,
              frameon=True, legend_fontsize=50,
              #crop_coord=crop_x + [crop_y[0]] + [crop_y[1]],
              vmin=0, vmax='p99.5', save=f'Fig_4F_cropped.pdf',
              crop_coord=[960, 1430, 1670, 1370]
             )

### Fig 4G - binned pseudospace dotplot <a class="anchor" id="Fig4G"></a>

In [None]:
adata_vis_dpt = adata_vis.copy()
# root the diffusion pseudotime at the location with the highest mean B_IFN value;
# .values makes np.argmax return a positional index regardless of pandas version
adata_vis_dpt.uns['iroot'] = np.argmax(adata_vis.obs['mean_spot_factorsB_IFN'].values)
sc.tl.dpt(adata_vis_dpt)
# order adata object by pseudotime (descending); copy so that the new obs column
# below is written to a real object rather than to a view
adata_vis_dpt = adata_vis_dpt[adata_vis_dpt.obs['dpt_pseudotime'].sort_values(ascending=False).index, :].copy()

# split pseudotime into 8 equal-width bins and keep the interval labels as strings
adata_vis_dpt.obs['dpt_pseudotime_bin'] = pd.cut(adata_vis_dpt.obs['dpt_pseudotime'], bins=8).astype('str')

# Compute average abundance of each pseudotime bin
from cell2location.cluster_averages.cluster_averages import get_cluster_averages_df
from cell2location.plt.plot_heatmap import clustermap

ct_aver = get_cluster_averages_df(X=adata_vis_dpt.obs[['mean_spot_factors' + i 
                                                      for i in adata_vis_dpt.uns['mod']['fact_names']]],
                                  cluster_col=adata_vis_dpt.obs["dpt_pseudotime_bin"])
ct_aver.index = adata_vis_dpt.uns['mod']['fact_names']
# keep the cell types selected for Fig 4F (ct_list), bins in descending label order
ct_aver = ct_aver.loc[ct_list, :]
ct_aver = ct_aver[ct_aver.columns.sort_values(ascending=False)]

# scale every cell type (row) by its maximum across bins, so each row peaks at 1
ct_aver = (ct_aver.T / ct_aver.max(1)).T
with matplotlib.rc_context({'axes.facecolor':  'white'}):
    clustermap(ct_aver, fun_type='dotplot',
               cluster_rows=False, cluster_cols=False, figure_size=(7, 2))
    plt.savefig(f'{str(sc.settings.figdir)}/Fig_4G_diffmap_dotplot.pdf',
                bbox_inches='tight', facecolor='white')

In [None]:
adata_vis_dpt = adata_vis.copy()
# root the diffusion pseudotime at the location with the highest mean B_IFN value;
# .values makes np.argmax return a positional index regardless of pandas version
adata_vis_dpt.uns['iroot'] = np.argmax(adata_vis.obs['mean_spot_factorsB_IFN'].values)
sc.tl.dpt(adata_vis_dpt)
# order adata object by pseudotime (descending); copy so later cells can add
# obs columns to a real object rather than to a view
adata_vis_dpt = adata_vis_dpt[adata_vis_dpt.obs['dpt_pseudotime'].sort_values(ascending=False).index, :].copy()

rcParams["axes.facecolor"] = "white"
rcParams["figure.figsize"] = [7, 4*len(ct_list)]
x=adata_vis_dpt.obs['dpt_pseudotime']

n_col = len(ct_list)
# colour index (into the palette of plot_1D_posterior) used for each stacked panel
panel_colors = [1, 3, 4, 2]
# one panel per cell type, stacked vertically
for i, (ct, color_idx) in enumerate(zip(ct_list, panel_colors)):
    plt.subplot(n_col, 1, i+1)
    plot_1D_posterior(adata_vis_dpt, x, [ct], reorder_cmap=[color_idx],
                      show_points=False,
                      xlabel='Diffusion pseudospace', 
                      ylabel=f'{ct} cell density')

plt.tight_layout()
plt.savefig(f'{fig_path}/Fig4G_suppl_full_diffmap.pdf', bbox_inches='tight')
plt.show()

```python
# extract umap coordinates
umap_coord = adata_snrna_raw.obsm['X_umap'].copy()
    
# make positive and rescale to fill the image
umap_coord[:, 0] = umap_coord[:, 0] # + abs(umap_coord[:, 0].min()) + abs(umap_coord[:, 0].max())*0.01
umap_coord[:, 1] = umap_coord[:, 1] # + abs(umap_coord[:, 1].min()) + abs(umap_coord[:, 1].max())*0.01

adata_cluster_col = 'Subset'
cell_fact_df = pd.get_dummies(adata_snrna_raw.obs[adata_cluster_col], columns=[adata_cluster_col])
cell_fact_df = cell_fact_df[ct_list]
cell_fact_df.columns = cell_fact_df.columns.tolist()
cell_fact_df['other'] = (cell_fact_df.sum(1) == 0).astype(np.int64)

sc.settings.set_figure_params(dpi = 150, color_map = 'RdPu', dpi_save = 260, vector_friendly = True,
                              format = 'svg')
plt.rcParams["axes.grid"] = False
rcParams["figure.figsize"] = [8, 10]
for i in cell_fact_df.columns:
    cell_fact_df[i] = cell_fact_df[i].astype('float32')

plot_spatial(cell_fact_df, 
              coords=umap_coord, 
              img=None, reorder_cmap=[0, 1, 2, 3, 6],
              show_img=False, img_alpha=0.8, lim='no_limit',
              circle_diameter=2.5, alpha_scaling=1, axis_y_flipped=False,
              max_color_quantile=1,
              save_path=fig_path + '/', save_name='Fig6_ext_G_UMAP', 
              save_facecolor='white', save_extension='png'
             )
```

In [None]:
# inverted pseudotime, so the pseudotime root gets the highest value
adata_vis_dpt.obs['inv_dpt_pseudotime'] = 1 - adata_vis_dpt.obs['dpt_pseudotime']

rcParams["axes.facecolor"] = "black"
rcParams["savefig.facecolor"] = "white"
rcParams["figure.figsize"] = [10,10]
rcParams["font.size"] = 24

crop_x = [230, 1640]
crop_y = [150, 1760]

# relies on adata_vis_pl created in the Fig 4F cell for the image size
max_y = list(adata_vis_pl.uns['spatial'].values())[0]['images']['hires'].shape[1]
crop_y = [max_y - i + 80 for i in crop_y]


sc.pl.spatial(adata_vis_dpt, cmap='magma',
              color='inv_dpt_pseudotime', ncols=1, 
              size=1.3, img_key='hires', alpha_img=0,
              frameon=True, legend_fontsize=50,
              crop_coord=crop_x + [crop_y[0]] + [crop_y[1]],
              vmin=0, vmax='p99.0', save='Fig4G_suppl_pseudospace_B_IFN.pdf'
             )

# restore the white axes background for subsequent cells
rcParams["axes.facecolor"] = "white"

### Read saved sklearn colocated cell group model <a class="anchor" id="read_nmf"></a>

In [None]:
# import models
def unpickle_model(path, mod_name):
    r""" Unpickle model

    Reads the file `path + 'model_' + mod_name + '.p'` and returns the object
    stored under the 'mod' key of the pickled dictionary.

    :param path: directory prefix; it is concatenated as-is, so it must end with a path separator
    :param mod_name: model name used in the file name
    :return: the value stored under 'mod'
    """
    file_path = path + 'model_' + mod_name + ".p"

    # NOTE(security): unpickling can execute arbitrary code -- only load trusted files.
    # The context manager closes the file handle, which the previous version leaked.
    with open(file_path, "rb") as handle:
        mod1_ann = pickle.load(handle)
    return mod1_ann['mod']

# number of factors of the saved model to load (also used in later cells to build uns keys)
n_fact = 14
mod_path = f'{results_folder}std_model/{run_name}/CoLocatedComb/CoLocatedCombination_sklearnNMF_4039locations_34factors/models/' 
adata_file = f'{results_folder}std_model/{run_name}/CoLocatedComb/CoLocatedCombination_sklearnNMF_4039locations_34factors/anndata/sp.h5ad' 


# loads the file '<mod_path>model_n_fact14.p'
mod_sk = unpickle_model(mod_path, f'n_fact{n_fact}')

adata_vis_sk = anndata.read(adata_file)

### Fig 4D - co-located cell groups - dotplot <a class="anchor" id="Fig4D_dotplot"></a>

In [None]:
# row order (cell types) of the dotplot
b_dev_sel = ['B_plasma', 'B_naive',
             
             'B_GC_LZ', 'T_CD4+_TfH_GC', 'B_GC_prePB', 'FDC', 
             'B_Cycling', 'B_GC_DZ',
             
             'B_preGC',
             'Endo', 'VSMC', 'Mast', 'Monocytes', 'DC_cDC2',
             
             'T_CD4+', 'B_mem',
             'T_CD4+_naive', 'T_CD8+_naive',
             
             'DC_CCR7+', 'T_TfR', 'T_Treg',
             'T_CD4+_TfH', 'T_CD8+_cytotoxic', 'T_CD8+_CD161+', 'NK', 'ILC', 'NKT',  
             'B_activated', 'Macrophages_M2', 'DC_pDC',
             
             'B_IFN', 'T_TIM3+',
             'DC_cDC1', 'Macrophages_M1',  
            ]

# column order (factors); note this rebinds fact_filt, which was a boolean mask in an earlier cell
fact_filt = ['fact_8', 'fact_4', 
             'fact_3', 'fact_6', 
             'fact_12', 'fact_11', 
             'fact_9', 'fact_0', 'fact_7',
             'fact_5', 
             'fact_2', 'fact_10', 'fact_13', 'fact_1']

# scale every row of cell_type_fractions by its row maximum (overwrites the attribute on mod_sk)
mod_sk.cell_type_fractions = (mod_sk.cell_type_fractions.T / mod_sk.cell_type_fractions.max(1)).T

# reset rcParams changed by earlier cells, then re-apply the PDF font setting
matplotlib.rc_file_defaults()
rcParams['pdf.fonttype'] = 42 # enables correct plotting of text
# columns are plotted in the reverse of the fact_filt order
clustermap(mod_sk.cell_type_fractions.loc[b_dev_sel, fact_filt[::-1]],
           cluster_rows=False, cluster_cols=False, 
           figure_size=[5.9 + 0.12 * mod_sk.n_fact, 5.9 + 0.1 * mod_sk.n_genes],
           fun_type='dotplot', array_size=None)

plt.savefig(f'{fig_path}/Fig4D_fact_dotplot.pdf', bbox_inches='tight')
plt.show()

### Fig 4D - Supplementary - matching clustering - dotplot <a class="anchor" id="Fig4D_suppl_clust"></a>

In [None]:
# Repeat clustering of regions to get matching number
# Cluster spots into regions using scanpy
sc.tl.leiden(adata_vis_sk, resolution=0.95)
adata_vis_sk.obs["region_cluster"] = adata_vis_sk.obs["leiden"]
adata_vis_sk.obs["region_cluster"] = adata_vis_sk.obs["region_cluster"].astype("category")

rcParams["figure.figsize"] = [10, 10]
rcParams["axes.facecolor"] = "black"
crop_x = [230, 1640]
crop_y = [150, 1760]

max_y = list(adata_vis_sk.uns['spatial'].values())[0]['images']['hires'].shape[1]
crop_y = [max_y - i + 80 for i in crop_y]

# plot the object that was just clustered: the previous version plotted
# adata_vis_pl, which does not carry the 'region_cluster' column computed above
# (these are the same clusters the following dotplot cell averages over)
sc.pl.spatial(adata_vis_sk, cmap='magma',
              color=['region_cluster'], ncols=4, 
              size=1.3, img_key='hires', alpha_img=0,
              frameon=True, legend_fontsize=20,
              crop_coord=crop_x + [crop_y[0]] + [crop_y[1]],
              vmin=0, vmax='p99.5', save='Fig4D_suppl_spatial_clusters.pdf',
              palette=sc.pl.palettes.default_102
             )
rcParams["axes.facecolor"] = "white"

In [None]:
# Compute average abundance of each region cluster
from cell2location.cluster_averages.cluster_averages import get_cluster_averages_df
from cell2location.plt.plot_heatmap import clustermap
#plt.rcParams['xtick.bottom'] = plt.rcParams['xtick.labelbottom'] = True
#plt.rcParams['xtick.top'] = plt.rcParams['xtick.labeltop'] = False

ct_aver = get_cluster_averages_df(X=adata_vis_sk.obs[['mean_spot_factors' + i 
                                                      for i in adata_vis_sk.uns['mod']['fact_names']]],
                        cluster_col=adata_vis_sk.obs["region_cluster"])
# rows are relabelled with the plain cell type names, columns get a 'region_' prefix
ct_aver.index = adata_vis_sk.uns['mod']['fact_names']
ct_aver.columns = ['region_' + c for c in ct_aver.columns]

# scale every cell type (row) by its maximum across region clusters, so each row peaks at 1
ct_aver = (ct_aver.T / ct_aver.max(1)).T

# reset rcParams changed by earlier cells, then re-apply the PDF font setting
matplotlib.rc_file_defaults()
rcParams['pdf.fonttype'] = 42 # enables correct plotting of text
# rows follow b_dev_sel from the Fig 4D dotplot cell; columns are clustered
clustermap(ct_aver.loc[b_dev_sel, :],
           cluster_rows=False, cluster_cols=True, 
           figure_size=[5.9 + 0.12 * mod_sk.n_fact, 5.9 + 0.1 * mod_sk.n_genes],
           fun_type='dotplot', array_size=None)

plt.savefig(f'{fig_path}/Fig4D_suppl_cluster_dotplot.pdf', bbox_inches='tight')
plt.show()

###  Fig 4D - factor spatial <a class="anchor" id="Fig4D_spatial"></a>

In [None]:
# Visualize cell type locations
# work on a copy and expose the 5% quantile columns under the plain cell type names
# (values are copied unchanged; no log transform is applied here)
adata_vis_pl = adata_vis_sk.copy()
clust_names_orig = ['q05_spot_factors' + i for i in adata_vis_sk.uns['mod']['fact_names']]
clust_names = adata_vis_sk.uns['mod']['fact_names']
adata_vis_pl.obs[clust_names] = (adata_vis_pl.obs[clust_names_orig])

# NOTE(review): fact_names_orig is not used anywhere in the visible notebook
fact_names_orig = ['mean_nUMI_factors' + i for i in adata_vis_sk.uns[f'mod_coloc_n_fact{n_fact}']['fact_names']]
fact_names = adata_vis_sk.uns[f'mod_coloc_n_fact{n_fact}']['fact_names']
# add the posterior-mean 'nUMI_factors' of the co-location model as obs columns named after the factors
adata_vis_pl.obs[fact_names] = pd.DataFrame(adata_vis_sk.uns[f'mod_coloc_n_fact{n_fact}']['post_sample_means']['nUMI_factors'],
                                            index=adata_vis_sk.uns[f'mod_coloc_n_fact{n_fact}']['obs_names'],
                                            columns=fact_names)

rcParams["axes.facecolor"] = "black"
rcParams["savefig.facecolor"] = "white"
rcParams["font.size"] = 24

# crop_x / crop_y are read as globals by plot_spatial_factors below
crop_x = [230, 1640]
crop_y = [150, 1760]

max_y = list(adata_vis_pl.uns['spatial'].values())[0]['images']['hires'].shape[1]
crop_y = [max_y - i + 80 for i in crop_y]

def plot_spatial_factors(adata, ct, fig_name='', crop_coord_small=(810, 1200, 1850, 1450)):
    r""" Plot spatial maps of the selected obs columns twice: an overview with a
    dashed box marking a region of interest, and a version cropped to that region.

    The overview is written to `{fig_path}{fig_name}{ct[0]}.pdf`; the cropped
    version is saved by scanpy with `save='{fig_name}{ct[0]}_cropped.pdf'`.
    Overview crop limits are taken from the notebook-level `crop_x` / `crop_y`.

    :param adata: AnnData object to plot; must hold a 'hires' image in `uns['spatial']`
    :param ct: list of obs column names to show, one panel each; the first entry names the output files
    :param fig_name: prefix of the output file names
    :param crop_coord_small: four crop coordinates of the region of interest, passed
        to `sc.pl.spatial(crop_coord=...)` for the cropped figure. The default is a
        literal because the notebook defines `coord_small` only after this function,
        so using it as a default raised NameError on a fresh kernel.
    """
    crop_coord_small = list(crop_coord_small)
    # size of the image of the object being plotted, instead of a global set in another cell
    # NOTE(review): assumes this image matches the one the global img_shape came from -- confirm
    img_shape = list(adata.uns['spatial'].values())[0]['images']['hires'].shape

    with matplotlib.rc_context({'axes.facecolor':  'black',
                                'savefig.facecolor': 'white',
                                'font.size': 24,
                                'figure.figsize': [10, 10]}):
        
        # plot the object passed in (previously the argument was ignored in favour of a global)
        fig = sc.pl.spatial(adata, cmap='magma',
                      color=ct, ncols=5, 
                      size=1.3, img_key='hires', alpha_img=0,
                      frameon=True, legend_fontsize=50,
                      crop_coord=crop_x + [crop_y[0]] + [crop_y[1]],
                      vmin=0, vmax='p99.5',
                      return_fig=True
                     )
        # dashed box around the region of interest, shifted into overview coordinates
        add_rectangle_to_fig(fig, [crop_coord_small[0], crop_coord_small[1], 
                                   img_shape[1] - crop_coord_small[2] + (img_shape[1] - crop_y[0]), 
                                   img_shape[1] - crop_coord_small[3] + (img_shape[1] - crop_y[0])],
                             edgecolor='white', linestyle='--', linewidth=5)
        fig.savefig(f'{fig_path}{fig_name}{ct[0]}.pdf',
                    bbox_inches='tight', facecolor='white')

    with matplotlib.rc_context({'axes.facecolor':  'black',
                                'savefig.facecolor': 'white',
                                'font.size': 24,
                                'figure.figsize': [12.5, 10]}):
        
        sc.pl.spatial(adata, cmap='magma',
                      color=ct, ncols=5, 
                      size=1.3, img_key='hires', alpha_img=0,
                      frameon=True, legend_fontsize=50,
                      vmin=0, vmax='p99.5', save=f'{fig_name}{ct[0]}_cropped.pdf',
                      crop_coord=crop_coord_small
                     )


# fact_3 plotted with four cell types; output files are named after the first entry ('Fig4D_fact_3...')
# note: this rebinds the notebook-wide ct_list used by the Fig 4F / 4G cells
ct_list=['fact_3', 'B_GC_LZ', 'T_CD4+_TfH_GC', 'B_GC_prePB', 'FDC']
# region of interest shared by all of the following plot_spatial_factors calls
coord_small = [810, 1200, 1850, 1450]
plot_spatial_factors(adata_vis_pl, ct_list, 'Fig4D_', coord_small)

In [None]:
# fact_6 plotted with two cell types; output files are named after the first entry
ct_list=['fact_6', 'B_Cycling', 'B_GC_DZ']

plot_spatial_factors(adata_vis_pl, ct_list, 'Fig4D_', coord_small)

In [None]:
# fact_9 plotted with two cell types; output files are named after the first entry
ct_list=['fact_9', 'B_mem', 'T_CD4+']

plot_spatial_factors(adata_vis_pl, ct_list, 'Fig4D_', coord_small)

In [None]:
# fact_11 plotted with three cell types; output files are named after the first entry
ct_list=['fact_11', 'Endo', 'VSMC', 'Monocytes']

plot_spatial_factors(adata_vis_pl, ct_list, 'Fig4D_', coord_small)

In [None]:
# single cell type panel (no factor column); output files are named 'Fig4D_B_preGC...'
ct_list=['B_preGC']

plot_spatial_factors(adata_vis_pl, ct_list, 'Fig4D_', coord_small)

In [None]:
# fact_0 plotted with two cell types; output files are named after the first entry
ct_list=['fact_0', 'T_CD4+_naive', 'T_CD8+_naive']

plot_spatial_factors(adata_vis_pl, ct_list, 'Fig4D_', coord_small)

In [None]:
# fact_7 plotted with four cell types; output files are named after the first entry
ct_list=['fact_7', 'T_Treg', 'T_TfR', 'DC_CCR7+', 'T_CD8+_cytotoxic']

plot_spatial_factors(adata_vis_pl, ct_list, 'Fig4D_', coord_small)

In [None]:
# fact_5 plotted with three cell types; output files are named after the first entry
ct_list=['fact_5', 'B_activated', 'Macrophages_M2', 'DC_pDC']

plot_spatial_factors(adata_vis_pl, ct_list, 'Fig4D_', coord_small)