In [None]:
import os

# Move to the project root, two levels above the directory the kernel started in.
# The start directory is remembered on first execution so that re-running this
# cell does not climb two more levels every time (the relative '../..' did).
# NOTE(review): the chdir is kept ahead of the STForte import, presumably so the
# local STForte package resolves from the project root — confirm.
if "_NOTEBOOK_START_DIR" not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_START_DIR, "..", ".."))

import numpy as np
import scanpy as sc
import seaborn as sns
import plotly.express as px
import matplotlib.pyplot as plt
import STForte.helper as stfhelper

# Global scanpy figure defaults: transparent, frameless PDF output.
sc.set_figure_params(dpi=120, transparent=True, dpi_save=400, frameon=False, vector_friendly=False, format="pdf", fontsize=16)

# Trial folder (relative to the project root) holding this sample's inputs and outputs.
trial_name = "trial-OSCC/sample_1"
# Qualitative palette used for the categorical groups in the plots below.
palette = px.colors.qualitative.Plotly

# All figures go to <trial>/plots. Create it up front: the plotly
# `fig.write_image(...)` calls further down do not create missing directories.
plot_dir = f"./{trial_name}/plots"
os.makedirs(plot_dir, exist_ok=True)
sc.settings.figdir = plot_dir

# Preferred sans-serif fonts for matplotlib, in fallback order.
plt.rcParams['font.sans-serif'] = ['Helvetica', 'Arial', 'sans-serif']

# Sequential colormap for expression values, derived from Plotly's Viridis scale
# through the STForte plotting helpers (hex colours -> RGB -> refined colormap).
palette_seq = stfhelper.pl.hex2rgb(px.colors.sequential.Viridis)
palette_seq = stfhelper.pl.create_refined_colormap(palette_seq)

In [None]:
# Read the STForte result for this trial and display its summary.
stforte_h5ad = f"./{trial_name}/outputs/stforte.h5ad"
adata = sc.read_h5ad(stforte_h5ad)
adata

In [None]:
# Collapse both annotations into a binary SCC / normal-tissue call.
# Louvain cluster '2' is the cluster treated as SCC; every other cluster
# (and any missing value) is labelled normal tissue.
# Labels are built with np.where instead of mutating the `to_numpy()` result in
# place, so nothing depends on that array being a writable copy.
is_scc_pred = adata.obs['louvain'].to_numpy() == '2'
STForte_pred = np.where(is_scc_pred, 'SCC (STForte)', 'Normal tissue (STForte)')

# Pathologist labels: keep 'SCC', fold everything else into 'Normal tissue'.
is_scc_anno = adata.obs['pathologist_anno.x'].to_numpy() == 'SCC'
pathologist_anno = np.where(is_scc_anno, 'SCC', 'Normal tissue')

adata.obs['STForte_anno'] = STForte_pred
adata.obs['pathologist_anno'] = pathologist_anno

# Fix the category order (normal first, SCC second) so palette colours are
# assigned consistently. `set_categories` also declares a class that happens to
# be absent from the data, where `reorder_categories` would raise ValueError.
adata.obs['STForte_anno'] = adata.obs['STForte_anno'].astype('category').cat.set_categories(['Normal tissue (STForte)', 'SCC (STForte)'])
adata.obs['pathologist_anno'] = adata.obs['pathologist_anno'].astype('category').cat.set_categories(['Normal tissue', 'SCC'])

In [None]:
# Read the second AnnData written by the trial (sp.h5ad) and give it the
# unstructured metadata of `adata`.
# NOTE(review): this assigns the same `uns` object rather than a copy, so later
# writes to `adata.uns` may also show up in `adata_sp.uns` — confirm this is intended.
sp_h5ad = f"./{trial_name}/outputs/sp.h5ad"
adata_sp = sc.read_h5ad(sp_h5ad)
adata_sp.uns = adata.uns
adata_sp

In [None]:
# Highly-variable-gene selection with flavor="seurat_v3" expects raw count data,
# so it has to run BEFORE normalisation / log-transform (scanpy warns about
# non-integer input otherwise and the ranking is computed on the wrong scale).
# assumes adata.X still holds raw counts when loaded — TODO confirm
sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000)

# Library-size normalisation to 1e4 counts per spot followed by log1p; these are
# the expression values used by every plot below.
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

In [None]:
# Select the genes whose 'highly_variable_rank' is below 4; genes without a rank
# (NaN) fail the comparison and are dropped. The result follows var-index order.
top_ranked = (adata.var['highly_variable_rank'] < 4).to_numpy()
gene_names = np.asarray(adata.var.index)[top_ranked]
# Manual alternative to the rank-based selection:
# gene_names = np.array(['S100A7','S100A8','S100A9','IGKC'])

# Violin plot of each selected gene, split by the binary pathologist annotation.
sc.pl.violin(adata, keys=gene_names, groupby="pathologist_anno", palette=palette, xlabel="")

In [None]:
import plotly.graph_objects as go

# Grouped horizontal box plot: expression of each selected gene, split by the
# binary pathologist annotation.

# Subset to the selected genes first (columns come back in `gene_names` order, so
# values and labels cannot get out of step), then densify only that small slice.
# This also works when adata.X is already a dense array.
expr = adata[:, gene_names].X
expr = expr.toarray() if hasattr(expr, "toarray") else np.asarray(expr)

n_genes = len(gene_names)
n_obs = adata.n_obs

# Long-format vectors, gene-major: all spots of gene 0, then all spots of gene 1, ...
x = expr.T.reshape(-1)                                          # expression values
y = np.repeat(np.asarray(gene_names), n_obs)                    # gene label per value
z = np.tile(adata.obs['pathologist_anno'].to_numpy(), n_genes)  # annotation per value

fig = go.Figure()
# One trace per annotation group; boxmode='group' puts them side by side per gene.
for label, color in (('Normal tissue', palette[0]), ('SCC', palette[1])):
    in_group = z == label
    fig.add_trace(go.Box(
        x=x[in_group],
        y=y[in_group],
        name=label,
        marker_color=color
    ))

fig.update_layout(
    xaxis=dict(title='', zeroline=False),
    boxmode='group'
)

fig.update_traces(orientation='h') # horizontal box plots
fig.write_image(f"./{trial_name}/plots/select_gene.pdf")
fig.show()

In [None]:
# Map each original pathologist category (as a string) to the colour stored for
# it in adata.uns['pathologist_anno.x_colors'].
marker_palette = dict(zip(
    adata.obs["pathologist_anno.x"].cat.categories.astype(str),
    adata.uns['pathologist_anno.x_colors'],
))

# Per-gene trend plot across the original pathologist annotation groups.
fig = stfhelper.pl.plot_trend_genes(adata, gene_names,
                                    group="pathologist_anno.x",
                                    marker_palette=marker_palette,
                                    line_color="#333333")
fig.update_layout(
    paper_bgcolor='rgba(0,0,0,0)',
    plot_bgcolor="#F8F8FF",
    title=None,
    template='plotly_white', width=220, height=500,
    showlegend=False,
    margin=dict(l=16, r=2, b=2, t=16, pad=2),
    font_family="Arial", font_size=16,
)
fig.update_yaxes(title_font_size=14)

# assumes plot_trend_genes lays out one subplot row per gene (the row indexing
# below relies on it) — TODO confirm against the helper
n_rows = len(gene_names)
if n_rows > 0:
    # Upper rows: ticks on category boundaries only.
    for row in range(1, n_rows):
        fig.update_xaxes(tickson="boundaries", row=row, col=1)
    # Bottom row: additionally draw visible outside ticks. The row index follows
    # the number of genes instead of a hard-coded 8, which matched no subplot
    # when a different number of genes was selected.
    fig.update_xaxes(ticks="outside", tickson="boundaries", ticklen=10, tickcolor="#aaaaaa", row=n_rows, col=1)

fig.write_image(f"./{trial_name}/plots/trend_markergene.pdf")
fig.show()

In [None]:
# Run the STForte helper for the selected genes on `adata_sp`, using `adata` and
# the 'SP_TOPO' key.
# NOTE(review): presumably this adds the '<gene>_with_padding' entries that the
# spatial plot of `adata_sp` colours by — confirm against the helper.
sp_exp = stfhelper.complete_unseen_expression(adata_sp, gene_names, adata, 'SP_TOPO')

In [None]:
# Both spatial maps share the colormap, the fixed 0-7 colour range and the
# greyscale background image, so the two figures are directly comparable.
spatial_style = dict(color_map=palette_seq, vmin=0, vmax=7, bw=True)

# Selected genes on `adata`.
sc.pl.spatial(adata, color=gene_names, s=13, save="Gene_unpaded.pdf", **spatial_style)

# The matching '<gene>_with_padding' entries on `adata_sp`, drawn with smaller markers.
padded_keys = [gene + '_with_padding' for gene in gene_names]
sc.pl.spatial(adata_sp, color=padded_keys, s=6, save="Gene_STForte_padding.pdf", **spatial_style)