In [None]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
import squidpy as sq
import sccellfie
import scanpy as sc
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns
import glasbey
import h5py

import textwrap

from pathlib import Path

## Load Datasets

In [None]:
# Remove `uns/log1p/base` from the H5AD file before it is read in the next cell.
# NOTE(review): presumably this entry breaks reading the file — confirm.
# This modifies the file on disk (opened "r+"); the guard makes it a no-op once removed.
with h5py.File("merged_TMA_processed.h5ad", "r+") as h5_file:
    log1p_group = h5_file.get("uns/log1p")
    if log1p_group is not None and "base" in log1p_group:
        del h5_file["uns/log1p/base"]

In [None]:
# Read the H5AD file into an AnnData object and display its summary.
adata = sc.read('merged_TMA_processed.h5ad')
adata

In [None]:
# Index genes by the symbols stored in var["gene"] (as strings), then make
# repeated symbols unique so every var name is distinct.
gene_symbols = adata.var["gene"].astype(str)
adata.var_names = gene_symbols
adata.var_names_make_unique()

In [None]:
# Per-cell subject and treatment labels as a standalone DataFrame.
subject_treatment_df = pd.DataFrame(
    adata.obs.loc[:, ["Subject_ID", "Treatment_Status"]]
)
subject_treatment_df

In [None]:
# Distinct (Subject_ID, Treatment_Status) combinations. The previous cell
# already shows the full per-cell table, so repeating it here adds nothing.
(
    adata.obs[["Subject_ID", "Treatment_Status"]]
    .drop_duplicates()
    .reset_index(drop=True)
)

In [None]:
# Treatment status of the cells that belong to subject '89397'.
aa = adata.obs[["Treatment_Status", "Subject_ID"]]
aa.loc[aa["Subject_ID"] == '89397']

In [None]:
# 0/1 indicator per cell for each treatment status, indexed like adata.obs.
dot_color_df = pd.DataFrame(
    {
        label: (adata.obs["Treatment_Status"].astype(str) == label).astype(int)
        for label in ("Treated", "Untreated")
    },
    index=adata.obs_names,
)

In [None]:
# Add one 0/1 indicator column per treatment status to adata.obs so the
# next dotplot can pass them as var_names.
for _status in ("Treated", "Untreated"):
    adata.obs[_status] = (adata.obs["Treatment_Status"].astype(str) == _status).astype(int)


In [None]:
# Dot plot of the Treated/Untreated indicator columns per subject, on a fixed
# 0-1 colour scale.
sc.pl.dotplot(
    adata,
    groupby="Subject_ID",
    var_names=["Treated", "Untreated"],  # obs indicator columns added in the previous cell
    expression_cutoff=0,
    mean_only_expressed=False,
    swap_axes=False,
    vmin=0,
    vmax=1,
    colorbar_title="Fraction of cells",
)


In [None]:
# Gene dot plot per (Subject_ID, Treatment_Status) group, each gene scaled
# (standard_scale='var').
# `results` is only created by the scCellFie pipeline cell further down, so
# referencing `results['adata']` here fails on a fresh kernel. Use the genes
# of the loaded `adata` instead.
# NOTE(review): if only the genes kept by scCellFie were intended, re-run this
# after the pipeline with `results['adata'].var.index`.
sc.pl.dotplot(
    adata,
    var_names=adata.var_names,
    groupby=['Subject_ID', 'Treatment_Status'],
    swap_axes=True,
    standard_scale='var',
)

In [None]:
# Gene identifiers (the var index) after the renaming above.
adata.var.index

In [None]:
# Dot plot of all genes across subject × treatment groups.
# With return_fig=True scanpy returns a DotPlot object instead of drawing it.
# Keep the handle so it can be styled further (e.g. dot borders via `.style()`),
# and render it explicitly; otherwise the cell shows only the object repr.
subject_treatment_dp = sc.pl.dotplot(
    adata,
    var_names=adata.var_names,
    groupby=['Subject_ID', 'Treatment_Status'],
    swap_axes=False,
    standard_scale=None,
    mean_only_expressed=True,
    expression_cutoff=0,
    return_fig=True,
)
subject_treatment_dp.show()


## Apply scCellFie Pipeline

In [None]:
# Pipeline settings, collected in one place so they are easy to review and tweak.
pipeline_kwargs = dict(
    organism='human',
    sccellfie_data_folder=None,
    n_counts_col='nCount_Nanostring',
    process_by_group=False,
    groupby=None,
    neighbors_key='neighbors',
    n_neighbors=10,
    batch_key='sample',
    threshold_key='sccellfie_threshold',
    smooth_cells=True,
    alpha=0.33,
    chunk_size=5000,
    disable_pbar=False,
    save_folder=None,
    save_filename=None,
)

# Run scCellFie on the loaded data. The output is used below via
# results['adata'] (with .reactions / .metabolic_tasks) and results['task_info'].
results = sccellfie.run_sccellfie_pipeline(adata, **pipeline_kwargs)

## Inspecting the scCellFie Results

In [None]:
# Display the pipeline output (later cells index it with 'adata' and 'task_info').
results

In [None]:
# Metabolic-task AnnData attached to the scCellFie output.
sccellfie_adata = results['adata']
metabolic_data = sccellfie_adata.metabolic_tasks
metabolic_data

In [None]:
# Reaction-level AnnData attached to the scCellFie output.
scoring_adata = results['adata']
reaction_data = scoring_adata.reactions
reaction_data

## Save Results (H5AD; optional CSV export of metabolic tasks)

In [None]:
# Optional, disabled by default: build a cells × metabolic-tasks DataFrame
# (index = metabolic_tasks.obs.index, columns = metabolic_tasks.var.index)
# for CSV export. Uncomment this cell and the following one to use it.
# df = pd.DataFrame(
#     results['adata'].metabolic_tasks.X,
#     index=results['adata'].metabolic_tasks.obs.index,
#     columns=results['adata'].metabolic_tasks.var.index
#     )
# df

In [None]:
# Optional, disabled by default: write the commented-out `df` from the previous cell to CSV.
# df.to_csv('scCellFie_metabolic_tasks_with_name.csv')

In [None]:
# Save the AnnData returned by the pipeline (single-cell/spatial predictions)
# as H5AD under output/. The module is `sccellfie`; the previous `ccellfie`
# spelling raised NameError.
sccellfie.io.save_adata(adata=results['adata'],
                        output_directory='output/',
                        filename='sccellfie_results'
                        )

## Visualization of scCellFie Results

In [None]:
# Column of `.obs` used to group cells in every summary and plot below.
# Later cells refer to it as `gp` (previously undefined -> NameError);
# `group_by` is kept as an alias for any code still using that name.
gp = 'Subject_ID'
group_by = gp

### Cell-group-level summary for the Metabolic Task Visualizer

In [None]:
# Summarise metabolic-task results per `gp` group for the Metabolic Task Visualizer.
report = sccellfie.reports.generate_report_from_adata(
    results['adata'].metabolic_tasks,
    group_by=gp,
    feature_name='metabolic_task',
)

In [None]:
# Write the report produced above into the report/ folder.
report_dir = 'report/'
sccellfie.io.save_result_summary(results_dict=report, output_directory=report_dir)

In [None]:
# To restrict the plots to a subset, replace the assignment below with e.g.:
# metabolic_tasks = ['ATP generation from glucose (hypoxic conditions) - glycolysis', 
#                    'ATP regeneration from glucose (normoxic conditions) - glycolysis + krebs cycle',
#                    'Gluconeogenesis from Lactate',
#                    'Glutaminolysis (glutamine to lactate)',
#                    'Glucose to lactate conversion'
#                   ]
# Default: every metabolic task scored by scCellFie, in var order.
metabolic_tasks = list(results['adata'].metabolic_tasks.var_names)
metabolic_tasks

In [None]:
# Palette with one colour per `gp` group, but never fewer than 10 colours.
n_groups = len(results['adata'].metabolic_tasks.obs[gp].unique())
palette = glasbey.extend_palette('Set2', palette_size=max(10, n_groups))

plt.rcParams['figure.figsize'] = (3, 3)
plt.rcParams['font.size'] = 10

# First panel shows the groups, then one panel per metabolic task; long
# task names are wrapped at 60 characters so titles stay readable.
panels = [gp] + metabolic_tasks
wrapped_titles = ["\n".join(textwrap.wrap(label, width=60)) for label in panels]

sc.pl.embedding(
    results['adata'].metabolic_tasks,
    basis='X_umap',
    color=panels,
    title=wrapped_titles,
    palette=palette,
    cmap='OrRd',
    ncols=1,
    wspace=0.7,
    frameon=False,
)

In [None]:
# Violin plot: one panel per metabolic task, grouped by `gp`, two panels per row.
fig, axes = sccellfie.plotting.create_multi_violin_plots(
    results['adata'].metabolic_tasks,
    features=metabolic_tasks,
    groupby=gp,
    n_cols=2,
    stripplot=False,
    ylabel='Metabolic Score',
)

In [None]:
# Stacked violins of every metabolic task per `gp` group, each task standardised (standard_scale='var').
ax = sc.pl.stacked_violin(
    results['adata'].metabolic_tasks,
    metabolic_tasks,
    groupby=gp,
    standard_scale='var',
    swap_axes=True,
    dendrogram=False,
)

### Dot Plot

In [None]:
# Genes: dot plot of every gene in results['adata'] per `gp` group, each gene standardised.
genes_adata = results['adata']
sc.pl.dotplot(genes_adata, var_names=genes_adata.var.index, groupby=gp,
              standard_scale='var', swap_axes=True)

In [None]:
# Reactions: dot plot of every reaction score per `gp` group, each reaction standardised.
reactions_adata = results['adata'].reactions
sc.pl.dotplot(reactions_adata, var_names=reactions_adata.var.index, groupby=gp,
              standard_scale='var', swap_axes=True)

In [None]:
# Metabolic tasks: dot plot of each task score per `gp` group, each task standardised.
sc.pl.dotplot(
    results['adata'].metabolic_tasks,
    var_names=metabolic_tasks,
    groupby=gp,
    standard_scale='var',
    swap_axes=True,
)

### Heatmap

In [None]:
# Per-cell heatmap of metabolic task scores, cells grouped by `gp`.
ax = sc.pl.heatmap(
    results['adata'].metabolic_tasks,
    var_names=metabolic_tasks,
    groupby=gp,
    dendrogram=True,
    swap_axes=True,
    cmap="YlGnBu",
    figsize=(16, 4),
)

In [None]:
# Aggregate metabolic task scores per `gp` group using the trimean.
agg = sccellfie.expression.aggregation.agg_expression_cells(
    results['adata'].metabolic_tasks,
    groupby=gp,
    agg_func='trimean',
)

In [None]:
# Transpose so metabolic tasks are rows (the heatmap below selects rows by task
# name), then min-max normalise along axis=1.
# NOTE(review): presumably this scales each task across groups — confirm axis semantics.
tasks_by_group = agg.T
input_df = sccellfie.preprocessing.matrix_utils.min_max_normalization(tasks_by_group, axis=1)

In [None]:
# Heatmap of the scaled per-group activity for the selected metabolic tasks.
heatmap_fig, heatmap_ax = plt.subplots(figsize=(16, 4))
g = sns.heatmap(
    input_df.loc[metabolic_tasks, :],
    ax=heatmap_ax,
    cmap='YlGnBu',
    linewidths=0.5,
    xticklabels=1,
    yticklabels=1,
)

# Colour-bar label, enlarged and rotated to run along the bar.
cbar = g.collections[0].colorbar
cbar.set_label('Scaled metabolic activity', size=14, rotation=270, labelpad=25)

# Uncomment code below to save figure
# plt.savefig('./figures/Heatmap-Seaborn.pdf', dpi=300, bbox_inches='tight')

In [None]:
# Tracks plot of each metabolic task across cells, ordered by `gp` group.
ax = sc.pl.tracksplot(
    results['adata'].metabolic_tasks,
    var_names=metabolic_tasks,
    groupby=gp,
    figsize=(16, 4),
    dendrogram=True,
)

### Radial Plot

In [None]:
# Long-format table (one row per metabolic task × group) for the radial plots,
# with columns metabolic_task, cell_type, scaled_trimean.
# rename_axis names the task index explicitly, so this no longer depends on the
# index already being called 'Task' (any other name made the melt fail).
df_melted = pd.melt(
    input_df.rename_axis('metabolic_task').reset_index(),
    id_vars='metabolic_task',
    var_name='cell_type',
    value_name='scaled_trimean',
)
df_melted.head()

In [None]:
# First four groups (in order of appearance) to draw as radial plots.
ct = df_melted['cell_type'].unique()[:4]
ct

In [None]:
# 2 × 2 grid of polar axes: one radial plot per group in `ct` (at most four),
# with the legend drawn on the second panel only.
fig, axes = plt.subplots(2, 2, figsize=(16, 16), subplot_kw={'projection': 'polar'})
ax1, ax2, ax3, ax4 = axes.flat

# Iterate over the selected groups themselves. Zipping over the characters of
# the `gp` column name and indexing `ct[i]` raised IndexError when `ct` had
# fewer than four entries.
for i, (cell, ax) in enumerate(zip(ct, axes.flat)):
    sccellfie.plotting.create_radial_plot(df_melted,
                                          results['task_info'],
                                          cell_type=str(cell),
                                          ax=ax,
                                          show_legend=i == 1,
                                          ylim=1.0)

# Hide panels left over when there are fewer than four groups.
for spare_ax in axes.flat[len(ct):]:
    spare_ax.set_visible(False)