In [1]:
import argparse
import os
import torch
import pyro
import json
import math
from tqdm import tqdm
from pyro.infer import SVI, Trace_ELBO, TraceMeanField_ELBO
from pyro.optim import PyroOptim
from torch.optim import Adam
import pyro.distributions as dist
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from matplotlib.cm import get_cmap
from scipy.sparse import csr_matrix
from scipy.spatial import KDTree
import seaborn as sns

import numpy as np
import pandas as pd
import anndata as ad
import scanpy as sc
import lightning as L
import torch.nn.functional as F
from torch.utils.data import DataLoader

import subprocess
import warnings
warnings.filterwarnings("ignore")
from importlib import reload

# this ensures that I can update the class without losing my variables in my notebook
import xenium_cluster
reload(xenium_cluster)
from xenium_cluster import XeniumCluster
from utils.metrics import *

from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans

In [2]:
def prepare_Xenium_data(
        dataset="hBreast", 
        spots=True, 
        spot_size=100, 
        third_dim=False, 
        log_normalize=True,
        likelihood_mode="PCA",
        num_pcs=5,
        hvg_var_prop=0.5,
        min_expressions_per_spot=10
    ):

    data_filepath = f"data/spot_data/{dataset}/hBreast_SPOTSIZE={spot_size}um_z={third_dim}.h5ad"
    
    if spots:

        if os.path.exists(data_filepath):

            clustering = XeniumCluster(data=None, dataset_name="hBreast")
            clustering.set_spot_size(spot_size)
            print("Loading data.")
            clustering.xenium_spot_data = ad.read_h5ad(data_filepath)

        else:

            # Path to your .gz file
            file_path = f'data/{dataset}/transcripts.csv.gz'

            # Read the gzipped CSV file into a DataFrame
            df_transcripts = pd.read_csv(file_path, compression='gzip')
            df_transcripts["error_prob"] = 10 ** (-df_transcripts["qv"]/10)
            df_transcripts.head(), df_transcripts.shape

            # drop cells without ids
            df_transcripts = df_transcripts[df_transcripts["cell_id"] != -1]

            # drop blanks and controls
            df_transcripts = df_transcripts[~df_transcripts["feature_name"].str.startswith('BLANK_') & ~df_transcripts["feature_name"].str.startswith('NegControl')]

            clustering = XeniumCluster(data=df_transcripts, dataset_name="hBreast")
            clustering.set_spot_size(spot_size)

            if not os.path.exists(data_filepath):
                print("Generating and saving data")
                clustering.create_spot_data(third_dim=third_dim, save_data=True)
                clustering.xenium_spot_data.write_h5ad(data_filepath)

        print("Number of spots: ", clustering.xenium_spot_data.shape[0])
        clustering.xenium_spot_data = clustering.xenium_spot_data[clustering.xenium_spot_data.X.sum(axis=1) > min_expressions_per_spot]
        print("Number of spots after filtering: ", clustering.xenium_spot_data.shape[0])

        if log_normalize:
            clustering.normalize_counts(clustering.xenium_spot_data)

        if likelihood_mode == "PCA":
            sc.tl.pca(clustering.xenium_spot_data, svd_solver='arpack', n_comps=num_pcs)
            data = clustering.xenium_spot_data.obsm["X_pca"]
        elif likelihood_mode == "HVG":
            min_dispersion = torch.distributions.normal.Normal(0.0, 1.0).icdf(hvg_var_prop)
            clustering.filter_only_high_variable_genes(clustering.xenium_spot_data, flavor="seurat", min_mean=0.0125, max_mean=1000, min_disp=min_dispersion)
            clustering.xenium_spot_data = clustering.xenium_spot_data[:,clustering.xenium_spot_data.var.highly_variable==True]
            data = clustering.xenium_spot_data.X
        elif likelihood_mode == "ALL":
            data = clustering.xenium_spot_data.X

        spatial_locations = clustering.xenium_spot_data.obs[["row", "col"]]
    
    # prepare cells data
    else:

        cells = df_transcripts.groupby(['cell_id', 'feature_name']).size().reset_index(name='count')
        cells_pivot = cells.pivot_table(index='cell_id', 
                                        columns='feature_name', 
                                        values='count', 
                                        fill_value=0)
        
        location_means = df_transcripts.groupby('cell_id').agg({
            'x_location': 'mean',
            'y_location': 'mean',
            'z_location': 'mean'
        }).reset_index()

        cells_pivot = location_means.join(cells_pivot, on='cell_id')

        if log_normalize:
            # log normalization
            cells_pivot.iloc[:, 4:] = np.log1p(cells_pivot.iloc[:, 4:])

        if likelihood_mode == "PCA":
            pca = PCA(n_components=num_pcs)
            data = pca.fit_transform(cells_pivot.iloc[:, 4:])

        elif likelihood_mode == "HVG":
            genes = cells_pivot.iloc[:, 4:]
            gene_variances = genes.var(axis=0)
            gene_variances = gene_variances.sort_values(ascending=False)
            gene_var_proportions = (gene_variances / sum(gene_variances))
            relevant_genes = list(gene_var_proportions[(gene_var_proportions.cumsum() < hvg_var_prop)].index)
            cells_pivot.iloc[:, 4:] = cells_pivot.iloc[:, 4:][[relevant_genes]]
            data = cells_pivot.iloc[:, 4:]

        elif likelihood_mode == "ALL":
            data = cells_pivot.iloc[:, 4:]

        spatial_locations = cells_pivot[["x_location", "y_location"]]


    return data, spatial_locations, clustering # the last one is to regain var/obs access from original data

## Gene Set Enrichment Analysis (GSEA)

Intro: https://www.youtube.com/watch?v=egO7Lt92gDY

Results Explanation: https://www.youtube.com/watch?v=Yi4d7JIlAsM

Differential Expression: https://www.youtube.com/watch?v=wIvxFEMQVwg 

Assume n total cells
* Calculate the total number of UMIs in each cell
counts_per_cell: n values
* Calculate a size factor for each cell by dividing the cell's total UMI count by the median of those n counts_per_cell
counts_per_cell / median(counts_per_cell): n values
* Calculate a size factor for each cluster by summing the size factors of each cell in that cluster.
* Normalize the UMI counts for each gene in each cluster by dividing by the size factor for that cluster
* Calculate fold change per gene by dividing the normalized total UMI counts for that gene in cluster1 by cluster2

#### Introduce a pseudocount into log2(fold_change)
'log2_fold_change': np.log2((1+gene_sums_a)/(1+size_factor_a)) -  np.log2((1+gene_sums_b)/(1+size_factor_b))

In [3]:
SPOT_SIZE = 50
LIKELIHOOD_MODE="PCA"
NUM_PCS=5
HVG_VAR_PROP=0.9
MARKER_GENES = ["BANK1", "CEACAM6", "FASN", "FGL2", "IL7R", "KRT6B", "POSTN", "TCIM"]

gene_data, spatial_locations, original_adata = prepare_Xenium_data(
    dataset="hBreast", 
    spots=True, 
    spot_size=SPOT_SIZE, 
    third_dim=False, 
    log_normalize=False, 
    likelihood_mode=LIKELIHOOD_MODE, 
    num_pcs=NUM_PCS,
    hvg_var_prop=HVG_VAR_PROP,
    min_expressions_per_spot=0
)

Loading data.
Number of spots:  23444
Number of spots after filtering:  23444


In [5]:
sample_clustering = pd.read_csv("results/hBreast/BayXenSmooth/clusters/PCA/25/INIT=K-Means/NEIGHBORSIZE=2/NUMCLUSTERS=17/SPATIALINIT=False/SAMPLEFORASSIGNMENT=False/SPATIALNORM=0.0/SPATIALPRIORMULT=1.0/SPOTSIZE=50/AGG=mean/clusters_K=17.csv", index_col=0)

original_adata.xenium_spot_data.obs["cluster"] = sample_clustering.values
original_adata.xenium_spot_data.obs

Unnamed: 0,spot_number,x_location,y_location,z_location,row,col,cluster
0,43,548.426080,2218.907850,12.599341,43.0,0.0,9
1,45,550.663970,2337.095433,19.034785,45.0,0.0,9
2,46,552.062194,2358.384500,16.009149,46.0,0.0,9
3,47,551.965956,2423.005319,16.599941,47.0,0.0,3
4,48,546.802622,2483.570972,15.898581,48.0,0.0,1
...,...,...,...,...,...,...,...
23439,33366,10015.713547,5868.944360,27.402354,116.0,190.0,11
23440,33367,10018.078500,5910.929274,27.193971,117.0,190.0,9
23441,33368,10017.667292,5977.355553,27.780879,118.0,190.0,9
23442,33369,10015.040000,6039.822000,30.913712,119.0,190.0,9


In [6]:
from pydeseq2.dds import DeseqDataSet
from pydeseq2.ds import DeseqStats

from gseapy.plot import gseaplot
import gseapy as gp

In [48]:
dds = DeseqDataSet(adata=original_adata.xenium_spot_data, design_factors='cluster')
dds.deseq2()

Fitting size factors...
Fitting dispersions...
... done in 0.58 seconds.

Fitting MAP dispersions...
... done in 0.70 seconds.



In [None]:
stat_res = DeseqStats(original_adata, contrast = ('cluster', '1', '2'))

AttributeError: 'XeniumCluster' object has no attribute 'varm'

In [None]:
pre_res = gp.prerank()

In [None]:
out = []

for term in list(pre_res.results):
    out.append([term,
               pre_res.results[term]['fdr'],
               pre_res.results[term]['es'],
               pre_res.results[term]['nes']])

out_df = pd.DataFrame(out, columns = ['Term','fdr', 'es', 'nes']).sort_values('fdr').reset_index(drop = True)
out_df

In [None]:
term_to_graph = out_df.iloc[0].Term
term_to_graph

NameError: name 'out_df' is not defined

In [None]:
gseaplot(pre_res.ranking, term = term_to_graph, **pre_res.results[term_to_graph])

NameError: name 'pre_res' is not defined