# Integrated stAge pipeline

#### stAge steps:
1. Load H5AD dataset and set parameters
2. Optimal Resolution Search (ORS) or custom resolution
3. Apply stAge at optimal resolution
4. Display/Save results

Requirements:
- Scaled and YuGene EN pkl files
- Mus_musculus.gene_info
- st_utils.py & st_resol.py

Instructions: 
- Provide a directory containing one or more H5AD files (one sample per file)
- Make sure the gene names (`var_names`) are gene symbols (SYMBOL)
- Make sure the main gene expression matrix contains raw counts

Notes: 
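- Spot-level predictions are only computed (and can only be saved) when spatial plots are enabled (`spatial_plot=True`)
- Metaspot-level predictions are only computed (and can only be saved) when box plots are enabled (`box_plot=True`)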

In [12]:
def _save_predictions(preds_per_file, save_dir, suffix=''):
    """Write each prediction AnnData in ``preds_per_file`` to ``save_dir``.

    Parameters
    ----------
    preds_per_file : dict
        Maps an input file name (e.g. ``'sample.h5ad'``) to the object
        returned for it by the prediction pipeline; each value must provide
        ``write_h5ad``.
    save_dir : str
        Output directory. It is created if it does not exist yet.
    suffix : str
        Text inserted between the file stem and its extension. An empty
        suffix keeps the input file name unchanged.
    """
    import os

    os.makedirs(save_dir, exist_ok=True)
    for file, adata in preds_per_file.items():
        stem, ext = os.path.splitext(file)
        adata.write_h5ad(os.path.join(save_dir, f'{stem}{suffix}{ext}'))


def stAge(rawdata_dir='',  # directory with the H5AD files
          control_file_pattern='',  # string in the file names that identifies the control group for relative predictions
          ORS=True,  # ORS active or not
          alt_res=1,  # custom resolution if ORS=False
          clocks_dir='',  # directory of folder with EN tAge models
          spatial_plot=True,  # show spatial plots of spot-level predictions
          spot_size=10,  # spatial plots spot size
          box_plot=True,  # show box plots of metaspot-level predictions
          group_patterns=None,  # strings in the file names that identifies/separates the different groups to be plotted in the box plot
          save_at_spot=False,  # save spot-level predictions as H5AD
          save_at_metaspot=False,  # save metaspot-level predictions as H5AD
          save_dir=''):  # directory to save results
    """Run the integrated stAge pipeline on every H5AD file of a directory.

    Steps:
      1. Load all ``*.h5ad`` files found in ``rawdata_dir``.
      2. Pick the resolution: Optimal Resolution Search (ORS) or ``alt_res``.
      3. Run ``full_nonoverlap_mp_pipeline`` at that resolution, at spot level
         (``lower_res=False``) and/or metaspot level (``lower_res=True``).
      4. Display the plots and optionally save the predictions.

    ``optimal_resolution_search``, ``full_nonoverlap_mp_pipeline`` and
    ``plot_clock_distributions`` are not defined here; they must already be
    available in the global namespace (the notebook star-imports
    ``st_utils`` and ``st_resol``).

    Parameters
    ----------
    rawdata_dir : str
        Directory containing the H5AD files (one sample per file).
    control_file_pattern : str
        String in the file names that identifies the control group used for
        relative predictions.
    ORS : bool
        If true, run the optimal resolution search and use the resolution it
        reports for the ``'orig'`` clock; otherwise use ``alt_res``.
    alt_res : int or float
        Custom resolution, used only when ``ORS`` is false.
    clocks_dir : str
        Folder with the EN tAge models, forwarded as ``clock_folder``.
        (Original author note: "EN differential models 4.6" is for mouse,
        "5.4" is for human.)
    spatial_plot : bool
        Compute spot-level predictions and show one spatial plot per sample.
    spot_size : int or float
        Spot size forwarded to ``sc.pl.spatial``.
    box_plot : bool
        Compute metaspot-level predictions and show the box plots.
    group_patterns : list of str, optional
        Strings in the file names that separate the groups shown in the box
        plot. ``None`` (default) is treated as an empty list.
    save_at_spot : bool
        Save spot-level predictions as H5AD. Only effective together with
        ``spatial_plot``.
    save_at_metaspot : bool
        Save metaspot-level predictions as H5AD. Only effective together with
        ``box_plot``.
    save_dir : str
        Directory in which results are saved (created when missing).

    Raises
    ------
    ValueError
        If saving is requested without ``save_dir``, or if no sample has at
        least 20 spots.
    FileNotFoundError
        If ``rawdata_dir`` does not exist or contains no ``.h5ad`` file.

    Notes
    -----
    Files are saved under their input file name. When both spot-level and
    metaspot-level results are written in the same call, ``_spot`` and
    ``_metaspot`` are appended to the file stems so that the second write
    does not overwrite the first.

    The two ``warnings.filterwarnings`` calls change the process-wide warning
    filters and stay in effect after this function returns.
    """
    import os
    import warnings

    import scanpy as sc

    warnings.filterwarnings("ignore", category=FutureWarning)     # Suppress specific warning types
    warnings.filterwarnings("ignore", category=RuntimeWarning)

    # A fresh list per call instead of a shared mutable default argument.
    if group_patterns is None:
        group_patterns = []

    # Predictions only exist for the levels that are actually computed, so a
    # save flag is only effective together with its plot flag.
    write_spot = bool(spatial_plot and save_at_spot)
    write_metaspot = bool(box_plot and save_at_metaspot)
    if save_at_spot and not spatial_plot:
        warnings.warn('save_at_spot is ignored because spatial_plot is disabled: '
                      'spot-level predictions are only computed for the spatial plots.')
    if save_at_metaspot and not box_plot:
        warnings.warn('save_at_metaspot is ignored because box_plot is disabled: '
                      'metaspot-level predictions are only computed for the box plots.')
    # Fail before the expensive computation; an empty save_dir used to turn
    # the output path into '/<file>' (filesystem root).
    if (write_spot or write_metaspot) and not save_dir:
        raise ValueError('save_dir must be set when save_at_spot or save_at_metaspot is enabled.')

    # 1. Load H5AD dataset and set parameters
    # endswith() skips files that merely contain '.h5ad' (e.g. 'x.h5ad.bak');
    # sorting makes the processing/plotting order reproducible.
    h5ad_files = sorted(file for file in os.listdir(rawdata_dir) if file.endswith('.h5ad'))
    if not h5ad_files:
        raise FileNotFoundError(f'No .h5ad files found in {rawdata_dir!r}')
    assembled_adatas = {file: sc.read(os.path.join(rawdata_dir, file)) for file in h5ad_files}

    # 2. Optimal Resolution Search (ORS) or custom resolution
    if ORS:
        resol_df = optimal_resolution_search(
            assembled_adatas,
            ipynb_dir=os.getcwd(),
            pred_pipeline=full_nonoverlap_mp_pipeline,
            control_file_pattern=control_file_pattern,
            cohen_weight=0.6,             # weight for Cohen's d in composite score
            tstat_weight=0.4,             # weight for t-statistic
            tolerance=0.1,                # tie-breaking tolerance for score 0.1 = 10%
        )
        optimal_resolutions = {row['Clock']: row['Resolution'] for _, row in resol_df.iterrows()}
        orig_resol = optimal_resolutions['orig']  # optimal resolution for the original tAge, use 'tms' or 'tmsh' for other
    else:
        orig_resol = alt_res

    # 3. MAIN: Apply stAge pipeline with the resolution
    # keep only well-sized slices
    cleaned = {name: adata for name, adata in assembled_adatas.items()
               if adata.n_obs >= 20}
    if not cleaned:
        raise ValueError(f'None of the samples in {rawdata_dir!r} has at least 20 spots.')

    def predict(lower_res):
        # Same settings for both levels; only lower_res differs
        # (False: spot level, True: metaspot level).
        return full_nonoverlap_mp_pipeline(
            cleaned,
            res=orig_resol,
            lower_res=lower_res,
            control_file_pattern=control_file_pattern,
            mp_coverage_threshold=1_000,
            save_plot=False,
            save_result=False,
            clock_folder=clocks_dir,
        )

    if spatial_plot:
        # 3.1. Spot-level predictions for SPATIAL PLOTTING
        spot_preds = predict(lower_res=False)

        if write_spot:
            _save_predictions(spot_preds, save_dir, '_spot' if write_metaspot else '')

        # 4.1. SPATIAL plotting
        if spot_preds:
            # One color scale shared by all samples, symmetric around 0.
            vmax_sm = max(adata.obs['tAge_SM'].max() for adata in spot_preds.values())
            vmin_sm = min(adata.obs['tAge_SM'].min() for adata in spot_preds.values())
            vfinal = max(abs(vmax_sm), abs(vmin_sm))

            for tis, adata_pred in spot_preds.items():
                # Computed outside the f-string: reusing the same quote type
                # inside an f-string is a SyntaxError before Python 3.12.
                sample_name = tis.replace('.h5ad', '')
                sc.pl.spatial(
                    adata_pred,
                    color='tAge_SM',
                    spot_size=spot_size,
                    cmap='coolwarm',
                    vmax=vfinal,
                    vmin=-vfinal,
                    title=sample_name,
                )

    if box_plot:
        # 3.2. MetaSpot-level predictions for BOX PLOTTING
        metaspot_preds = predict(lower_res=True)

        if write_metaspot:
            _save_predictions(metaspot_preds, save_dir, '_metaspot' if write_spot else '')

        # 4.2. BOX plotting
        plot_clock_distributions(metaspot_preds, group_patterns, norm_cols=['tAge_SM'], test='Mann-Whitney')

In [None]:
# Example usage 
import os
from st_utils import *
from st_resol import *

# Run stAge on one directory of samples at a fixed resolution.
stAge(
    # Input data and clock models
    rawdata_dir='/home/vvicente/spatial_aging/data/immunoglobulin/Hippocampus',
    clocks_dir='/home/vvicente/spatial_aging/tAge_clocks/EN differential models 4.6',
    control_file_pattern='Y',
    # Resolution: ORS is off, so the custom resolution alt_res is used
    ORS=False,
    alt_res=1,
    # Spot-level spatial plots
    spatial_plot=True,
    spot_size=1,
    # Metaspot-level box plots, one group per file-name pattern
    box_plot=True,
    group_patterns=['Y', 'O'],
    # Directory for saved results (both save flags keep their default, False)
    save_dir='/home/vvicente/spatial_aging/vvicente/results',
)