In [50]:
import scanpy as sc
#from .autonotebook import tqdm as notebook_tqdm
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

In [53]:
import os

# Define the directory and filename suffix
folder = "saturn_results"
suffix = "_saturn_seed_0.h5ad"

# Collect every file whose name ends with the suffix. os.listdir() returns
# names in arbitrary order, so sort them to make the choice reproducible
# when more than one file matches.
matching_files = sorted(name for name in os.listdir(folder) if name.endswith(suffix))
if not matching_files:
    raise FileNotFoundError(f"No file ending with '{suffix}' found in {folder}")

# Use the first match (in sorted order) as the atlas file to load
h5ad_file = os.path.join(folder, matching_files[0])


In [54]:
# Load the SATURN result file found by the search above.
atlas_ad = sc.read_h5ad(h5ad_file)

In [None]:
# Preview the first rows of the per-cell metadata table.
atlas_ad.obs.head()

# Data preprocessing

In [None]:
# Run PCA and then build the neighbor graph, both with scanpy's default
# parameters. The UMAP and Leiden cells below are run after these two steps.
sc.pp.pca(atlas_ad)
sc.pp.neighbors(atlas_ad)

In [None]:
# Compute a 2-component UMAP embedding; the sc.pl.umap plots below draw it.
sc.tl.umap(atlas_ad, n_components=2)

In [None]:
## Clustering
# Leiden clustering at resolution 0.1. Later cells read the cluster ids from
# the 'leiden' column of atlas_ad.obs.
sc.tl.leiden(atlas_ad, resolution=0.1)

In [None]:
# Save result
# Create the output directory first so the write does not fail on a fresh
# checkout where "output/" does not exist yet.
os.makedirs("output", exist_ok=True)
atlas_ad.write_h5ad("output/atlas.h5ad")

# Visualize data distribution

In [None]:
# Global plotting defaults for the figures below: 72 dpi, 'viridis_r' as the
# default colormap and an 8x8 figure size.
sc.set_figure_params(dpi=72, color_map = 'viridis_r',figsize=[8,8] )
# Lower scanpy's logging verbosity.
# NOTE(review): level 1 presumably means "warnings only" — confirm.
sc.settings.verbosity = 1
# Print scanpy's header with package version information.
sc.logging.print_header()

## By species

In [None]:
# UMAP of all cells colored by the 'species' column, using the 'Set1' palette.
sc.pl.umap(atlas_ad, color="species", projection="2d", palette='Set1')

## Visualize each species individually

In [None]:
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import math

# Get all unique species
species = atlas_ad.obs['species'].unique()

# Grid geometry: a fixed number of columns and as many rows as needed, so
# every species gets its own panel on a single PDF page.
num_species = len(species)
num_cols = 4
num_rows = math.ceil(num_species / num_cols)

# Fail with a clear message instead of trying to build a zero-row grid.
if num_species == 0:
    raise ValueError("atlas_ad.obs['species'] has no values to plot")

# Passing `palette=` to sc.pl.umap overwrites atlas_ad.uns['species_colors'].
# Remember the current colors so the red/gray highlight palette used here
# does not leak into later plots of 'species'.
had_species_colors = 'species_colors' in atlas_ad.uns
saved_species_colors = atlas_ad.uns.get('species_colors')

# NOTE: sc.set_figure_params() is deliberately not called in this cell. Every
# call resets the global figure defaults (dpi, figure size, default colormap)
# configured earlier in the notebook, and the explicit palette below already
# determines every color in these panels.
try:
    with PdfPages('output/species_plots.pdf') as pdf:
        # One figure holding the whole grid of panels
        fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, num_rows * 3))
        axes = axes.flatten()  # Flatten the 2D axes array into 1D

        # One panel per species: that species in red, all others in gray
        for ax, highlight_species in zip(axes, species):
            color_map = {
                sp: ("red" if sp == highlight_species else "gray")
                for sp in species
            }
            sc.pl.umap(atlas_ad, color='species', title=f'{highlight_species}', palette=color_map, legend_loc=None, show=False, ax=ax)

        # Hide the panels of the last row that have no species assigned
        for ax in axes[num_species:]:
            ax.axis('off')

        # Adjust layout and save the page
        fig.tight_layout()
        pdf.savefig(fig)  # Save current page
        plt.close(fig)  # Close the figure to free memory
finally:
    # Put the original species colors back (or remove the key if it did not
    # exist before this cell ran).
    if had_species_colors:
        atlas_ad.uns['species_colors'] = saved_species_colors
    else:
        atlas_ad.uns.pop('species_colors', None)


## By original labels

In [None]:
# UMAP colored by the 'labels2' annotation column.
# NOTE(review): the per-cluster panels below use the 'labels' column instead —
# confirm that both columns are intended.
sc.pl.umap(atlas_ad, color="labels2")

## Visualize each cluster individually

In [None]:
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import math

# Get all clusters (unique values of the 'labels' column)
clusters = atlas_ad.obs['labels'].unique()

# Get the list of clusters to be plotted. `.tolist()` exists on both the
# Categorical and the ndarray that `unique()` can return, so this works
# whether or not the column is categorical.
all_clusters = clusters.tolist()

# Grid geometry: a fixed number of columns and as many rows as needed, so
# every cluster gets its own panel on a single PDF page.
num_clusters = len(all_clusters)
num_cols = 4
num_rows = math.ceil(num_clusters / num_cols)

# Fail with a clear message instead of trying to build a zero-row grid.
if num_clusters == 0:
    raise ValueError("atlas_ad.obs['labels'] has no values to plot")

# Passing `palette=` to sc.pl.umap overwrites atlas_ad.uns['labels_colors'].
# Remember the current colors so the red/gray highlight palette used here
# does not leak into later plots of 'labels'.
had_labels_colors = 'labels_colors' in atlas_ad.uns
saved_labels_colors = atlas_ad.uns.get('labels_colors')

# NOTE: sc.set_figure_params() is deliberately not called in this cell. Every
# call resets the global figure defaults (dpi, figure size, default colormap)
# configured earlier in the notebook, and the explicit palette below already
# determines every color in these panels.
try:
    with PdfPages('output/cluster_plots.pdf') as pdf:
        # One figure holding the whole grid of panels
        fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, num_rows * 3))
        axes = axes.flatten()  # Flatten the 2D axes array into 1D

        # One panel per cluster: that cluster in red, all others in gray
        for ax, highlight_cluster in zip(axes, all_clusters):
            color_map = {
                cluster: ("red" if cluster == highlight_cluster else "gray")
                for cluster in all_clusters
            }
            sc.pl.umap(atlas_ad, color='labels', title=f'{highlight_cluster}', palette=color_map, legend_loc=None, show=False, ax=ax)

        # Hide the panels of the last row that have no cluster assigned
        for ax in axes[num_clusters:]:
            ax.axis('off')

        # Adjust layout and save the page
        fig.tight_layout()
        pdf.savefig(fig)  # Save current page
        plt.close(fig)  # Close the figure to free memory
finally:
    # Put the original label colors back (or remove the key if it did not
    # exist before this cell ran).
    if had_labels_colors:
        atlas_ad.uns['labels_colors'] = saved_labels_colors
    else:
        atlas_ad.uns.pop('labels_colors', None)


## Species composition in different clusters

In [None]:
# UMAP colored by Leiden cluster, with the cluster ids drawn on the data.
sc.pl.umap(adata=atlas_ad, color='leiden', legend_loc='on data')

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Convert data to a DataFrame
data = pd.DataFrame(atlas_ad.obs)

# Count the number of cells of each species in each Leiden cluster.
# `observed=False` is passed explicitly: it matters when the grouping columns
# are categorical, and pandas has announced a change of its default, so
# pinning it keeps the table identical across pandas versions and avoids
# the FutureWarning.
count_table = (
    data.groupby(['leiden', 'species'], observed=False)
    .size()
    .unstack(fill_value=0)
)

# Convert counts to proportions (each cluster's row sums to 1)
cells_per_cluster = count_table.sum(axis=1)
proportion_table = count_table.div(cells_per_cluster, axis=0)

# Plot stacked bar chart on an explicitly created figure/axes pair
fig, ax = plt.subplots(figsize=(10, 6))
proportion_table.plot(kind='bar', stacked=True, width=.8, ax=ax)

# Set title and labels
ax.set_title('Proportion of Species in Each Cluster')
ax.set_xlabel('Cluster')
ax.set_ylabel('Proportion')
ax.grid(False)

# Show legend outside the plotting area so it does not cover the bars
ax.legend(title='Species', bbox_to_anchor=(1.05, 1), loc='upper left')

# Display the plot
fig.tight_layout()
plt.show()


# Macrogene differential expression

In [126]:
import pickle

In [127]:
import os
import pickle

# Define the directory and suffix pattern
folder = "saturn_results"
suffix = "_saturn_seed_0_genes_to_macrogenes.pkl"

# Collect every file whose name ends with the suffix. os.listdir() returns
# names in arbitrary order, so sort them to make the choice reproducible
# when more than one file matches.
matching_files = sorted(name for name in os.listdir(folder) if name.endswith(suffix))
if not matching_files:
    raise FileNotFoundError(f"No file ending with '{suffix}' found in {folder}")
filepath = os.path.join(folder, matching_files[0])

# SECURITY: pickle.load() can execute arbitrary code while deserializing.
# Only load files produced by your own pipeline run, never a pickle obtained
# from an untrusted source.
with open(filepath, "rb") as f:
    macrogene_weights = pickle.load(f)


In [None]:
# macrogene_weights is a dictionary of (species_{gene name}) : [gene to macrogene weight] (1x2000)
# Number of entries (one per species/gene key) in the dictionary
len(macrogene_weights)

In [104]:
# Create a copy of the adata with macrogenes as the X values
macrogene_adata = sc.AnnData(atlas_ad.obsm["macrogenes"])
# Attach an independent copy of the cell metadata, so that later changes to
# macrogene_adata.obs cannot leak back into atlas_ad.obs.
macrogene_adata.obs = atlas_ad.obs.copy()


In [None]:
# Rows are cells, columns are macrogenes, each value corresponds to the gene weight
# NOTE(review): the values come from atlas_ad.obsm["macrogenes"], so they are
# presumably per-cell macrogene values rather than gene weights — confirm.
macrogene_adata.shape

In [None]:
# Inspect the cell metadata attached to the macrogene AnnData.
macrogene_adata.obs

In [134]:
# Differential analysis based on specified group, eg. 11
# Ranks macrogenes for Leiden cluster "11" using the Wilcoxon method.
# NOTE(review): the cluster id "11" is hard-coded here and again in the
# rank_genes_groups_df cell below; it must exist in the 'leiden' column —
# confirm against the clustering result before running.
sc.tl.rank_genes_groups(macrogene_adata, groupby="leiden", groups=["11"], method="wilcoxon")

  return reduction(axis=axis, out=out, **passkwargs)


In [None]:
# Plot the macrogene ranking computed by rank_genes_groups above.
sc.pl.rank_genes_groups(macrogene_adata)

In [None]:
# Dot plot of the ranked macrogenes, drawn with swapped axes (swap_axes=True).
sc.pl.rank_genes_groups_dotplot(macrogene_adata,swap_axes=True)

In [None]:
# Keep the first 20 rows of the ranking table for cluster "11" and display them.
# The 'names' column of this table is what the loop further down iterates over.
de_df = sc.get.rank_genes_groups_df(macrogene_adata, group="11").head(20)
de_df

In [137]:
def get_scores(macrogene):
    """Collect every gene's weight for one macrogene (centroid).

    Parameters
    ----------
    macrogene : int or str
        Index of the macrogene; it is converted with ``int()`` before being
        used to index each weight vector, so numeric strings are accepted.

    Returns
    -------
    dict
        One entry per key of the module-level ``macrogene_weights``, mapped
        to the element of that key's weight vector at position ``macrogene``.
    """
    return {
        gene_name: weight_vector[int(macrogene)]
        for gene_name, weight_vector in macrogene_weights.items()
    }

In [138]:

# get macrogene , e.g 891
macrogene = 891

# Rank all genes by their weight for this macrogene, highest weight first
df = (
    pd.DataFrame(list(get_scores(macrogene).items()), columns=["gene", "weight"])
    .sort_values("weight", ascending=False)
)

# Show the top-weighted genes so this cell has a visible result
df.head(20)

In [None]:
# For every macrogene in the differential-expression table, show the 20
# genes with the highest weight for that macrogene.
for macrogene in de_df["names"]:
    print(f"Macrogene {macrogene}")
    # Build a (gene, weight) table and order it by descending weight
    df = (
        pd.DataFrame(get_scores(macrogene).items(), columns=["gene", "weight"])
        .sort_values("weight", ascending=False)
    )
    display(df.head(20))

