# How can a frog grow its tail back?


In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
%%capture
%pip install anndata
%pip install scanpy
%pip install igraph
%pip install umap-learn==0.5.8
%pip install louvain
%pip install leidenalg
%pip install magic-impute
%pip install pyALRA

In [None]:
import scanpy as sc
import igraph as ig
import umap
import magic
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.sparse import csr_matrix
from sklearn.mixture import GaussianMixture
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score, adjusted_rand_score, davies_bouldin_score, calinski_harabasz_score

In [None]:
# Load the pre-cleaned frog-tail AnnData object from the mounted Google Drive.
data = sc.read_h5ad('/content/drive/MyDrive/5243_projects/cleaned_processed_frogtail.h5ad')
# Show the expression matrix as this cell's output by converting it to a dense matrix.
# NOTE(review): this materialises a full dense copy of X just for display —
# confirm the dataset is small enough for the Colab runtime.
data.X.todense()

## Starter step 1: replicate Figure 1B

In [None]:
# Replicate Figure 1B: Fano-factor based gene selection, walktrap clustering on
# a fuzzy kNN graph, and a UMAP embedding coloured by the published 'cluster'
# annotation.
from scipy.sparse import issparse  # cell-local import; only needed for the variance below

# Library-size normalisation (modifies data.X in place).
sc.pp.normalize_total(data, target_sum=1e4)

hvg_params = {
    'min_mean_quantile': 0.05,
    'max_mean_quantile': 0.80,
    'min_fano_quantile': 0.65,
}

# Per-gene mean and population variance (ddof=0).
# scipy sparse matrices provide .mean() but no .var(), so the variance is
# computed as E[X^2] - E[X]^2, which works for sparse and dense X alike.
# The float64 cast limits cancellation error in that subtraction.
expression = data.X.astype(np.float64)
gene_means = np.asarray(expression.mean(axis=0)).flatten()
if issparse(expression):
    gene_sq_means = np.asarray(expression.multiply(expression).mean(axis=0)).flatten()
else:
    gene_sq_means = np.asarray(np.square(expression).mean(axis=0)).flatten()
# Clip tiny negative values caused by floating-point round-off.
gene_vars = np.maximum(gene_sq_means - gene_means ** 2, 0.0)
del expression, gene_sq_means

# Fano factor = variance / mean; genes with zero mean stay NaN.
gene_fano = np.full_like(gene_means, np.nan)
np.divide(gene_vars, gene_means, where=gene_means > 0, out=gene_fano)

mean_thresh_low = np.quantile(gene_means, hvg_params['min_mean_quantile'])
mean_thresh_high = np.quantile(gene_means, hvg_params['max_mean_quantile'])
fano_thresh = np.quantile(gene_fano[np.isfinite(gene_fano)], hvg_params['min_fano_quantile'])

# A gene is kept when its mean lies inside the quantile window AND its Fano
# factor reaches the quantile cut-off (NaN compares False, so it is dropped).
is_in_mean_range = (gene_means >= mean_thresh_low) & (gene_means <= mean_thresh_high)
has_high_fano = gene_fano >= fano_thresh
data.var['is_hvg'] = is_in_mean_range & has_high_fano

clustering_params = {
    'n_neighbors': 10,
    'walktrap_steps': 10,
    'umap_seed': 3,
    'clustering_seed': 8,
}

# Work on a copy restricted to the selected genes; only the copy is log2-transformed.
data_hvg = data[:, data.var['is_hvg']].copy()
sc.pp.log1p(data_hvg, base=2)

# Fuzzy kNN graph (cosine metric) used as the clustering graph.
knn_graph, _, _ = umap.umap_.fuzzy_simplicial_set(
    X=data_hvg.X,
    n_neighbors=clustering_params['n_neighbors'],
    random_state=clustering_params['umap_seed'],
    metric='cosine'
)

igraph_graph = ig.Graph.Weighted_Adjacency(knn_graph.toarray(), mode='undirected')
# NOTE(review): this seeds NumPy's global RNG; whether igraph's walktrap
# consumes it is not visible from this file — confirm.
np.random.seed(clustering_params['clustering_seed'])
walktrap_partition = igraph_graph.community_walktrap(weights='weight', steps=clustering_params['walktrap_steps'])
cluster_labels = walktrap_partition.as_clustering().membership

data.obs['walktrap_labels'] = pd.Series(cluster_labels, index=data.obs_names, dtype="category")

umap_viz_params = {
    'n_neighbors': 20,
    'min_dist': 0.5,
    'random_state': 32,
    'metric': 'cosine'
}

# Separate neighbour graph (different k) for the visualisation embedding.
sc.pp.neighbors(
    data_hvg,
    n_neighbors=umap_viz_params['n_neighbors'],
    random_state=umap_viz_params['random_state'],
    metric=umap_viz_params['metric'],
    use_rep='X'
)
sc.tl.umap(
    data_hvg,
    min_dist=umap_viz_params['min_dist'],
    random_state=umap_viz_params['random_state']
)

# Copy the embedding back so the full object can be plotted.
data.obsm['X_umap'] = data_hvg.obsm['X_umap']

# Map each published cluster name to its hex colour; unknown names fall back to grey.
colors_df = pd.read_csv('cluster_color_hex.csv')
color_map = colors_df.set_index('names')['cols'].to_dict()

observed_categories = data.obs['cluster'].cat.categories
custom_palette = [color_map.get(cat, '#808080') for cat in observed_categories]

plot_kwargs = {
    'color': 'cluster',
    'palette': custom_palette,
    'frameon': False,
    's': 12,
    'legend_loc': 'right margin',
    'title': 'UMAP replication of figure 1B'
}

sc.pl.umap(data, **plot_kwargs)

In [None]:
# Standard scanpy preprocessing for the clustering comparison: normalise,
# log-transform, filter, select highly variable genes, scale, PCA and build
# the neighbour graph shared by the clustering cells below.
preprocessing_params = {
    'counts_layer_name': 'raw_counts',
    'min_cells_per_gene': 3,
    'min_genes_per_cell': 200,
}

# Snapshot of X before this cell's normalisation.
# NOTE(review): if the Figure 1B cell has already run, data.X was normalised
# in place there, so this layer will not hold the original raw counts.
data.layers[preprocessing_params['counts_layer_name']] = data.X.copy()

sc.pp.normalize_total(data, target_sum=1e4)
sc.pp.log1p(data)

sc.pp.filter_genes(data, min_cells=preprocessing_params['min_cells_per_gene'])
sc.pp.filter_cells(data, min_genes=preprocessing_params['min_genes_per_cell'])

feature_selection_params = {
    'n_top_genes': 2300
}
pca_params = {
    'n_components': 50,
    'scale_max_value': 10
}

# Keep the full (log-normalised, filtered) matrix reachable via .raw before
# subsetting to highly variable genes.
data.raw = data

sc.pp.highly_variable_genes(data, n_top_genes=feature_selection_params['n_top_genes'])
sc.pl.highly_variable_genes(data)

data_hvg = data[:, data.var['highly_variable']].copy()

sc.pp.scale(data_hvg, max_value=pca_params['scale_max_value'])
sc.pp.pca(data_hvg, n_comps=pca_params['n_components'])

sc.pl.pca_variance_ratio(data_hvg, n_pcs=pca_params['n_components'], log=False)

neighbors_params = {
    'n_neighbors': 15,
    'n_pcs': pca_params['n_components']
}

sc.pp.neighbors(
    data_hvg,
    n_neighbors=neighbors_params['n_neighbors'],
    n_pcs=neighbors_params['n_pcs']
)

ground_truth_labels = data_hvg.obs['cluster']

# 'leiden_clusters' is only written by the later PCA + Leiden cell, so on a
# top-to-bottom run it does not exist yet. Use it when present (e.g. on a
# re-run), otherwise fall back to the number of ground-truth clusters.
if 'leiden_clusters' in data_hvg.obs:
    target_cluster_count = data_hvg.obs['leiden_clusters'].nunique()
else:
    target_cluster_count = ground_truth_labels.nunique()

resolutions_to_test = np.arange(1.60, 2.51, 0.01)

pca_embedding = data_hvg.obsm['X_pca']

### PCA + Leiden

In [None]:
# Sweep Leiden resolutions, keep those that produce exactly the target number
# of clusters, pick the one with the best silhouette score on the PCA
# embedding, then cluster and plot with it.
optimization_params = {
    'target_cluster_count': 46,
    'random_seed': 5243,
    'silhouette_metric': 'euclidean',
}

valid_resolutions = []

# The sweep writes one obs column per resolution; run it on a copy so
# data_hvg itself is not cluttered with them.
data_hvg_for_testing = data_hvg.copy()

for res in resolutions_to_test:
    current_leiden_key = f'leiden_res_{res:.2f}'

    sc.tl.leiden(
        data_hvg_for_testing,
        resolution=res,
        key_added=current_leiden_key,
        random_state=optimization_params['random_seed']
    )

    cluster_labels = data_hvg_for_testing.obs[current_leiden_key]
    num_clusters_generated = cluster_labels.nunique()

    # Only resolutions that hit the target cluster count are candidates.
    if num_clusters_generated == optimization_params['target_cluster_count']:
        silhouette_avg = silhouette_score(
            pca_embedding,
            cluster_labels,
            metric=optimization_params['silhouette_metric'],
            random_state=optimization_params['random_seed']
        )
        valid_resolutions.append({'resolution': res, 'silhouette': silhouette_avg})

if valid_resolutions:
    results_df = pd.DataFrame(valid_resolutions)

    best_result = results_df.loc[results_df['silhouette'].idxmax()]
    best_resolution = best_result['resolution']
    best_silhouette = best_result['silhouette']

    print(f"Found best resolution: {best_resolution:.2f} with silhouette score: {best_silhouette:.4f}")

    sc.tl.leiden(
        data_hvg,
        resolution=best_resolution,
        key_added='leiden_clusters',
        random_state=optimization_params['random_seed']
    )
else:
    print(f"Warning: No resolution found that produces exactly {optimization_params['target_cluster_count']} clusters.")

# The UMAP embedding is computed regardless: the GMM and K-Means cells plot on it.
sc.tl.umap(data_hvg, random_state=optimization_params['random_seed'])

umap_plot_settings = {
    'color': 'leiden_clusters',
    'title': 'Leiden Clustering',
    'frameon': False,
    'legend_loc': 'on data',
    'legend_fontsize': 8,
}

# When the sweep found no valid resolution there is no 'leiden_clusters'
# column to colour by; skip the plot rather than fail on the missing key.
if umap_plot_settings['color'] in data_hvg.obs:
    sc.pl.umap(data_hvg, **umap_plot_settings)
else:
    print(f"Skipping Leiden UMAP plot: '{umap_plot_settings['color']}' was not computed.")

### PCA + GMM

In [None]:
# Fit a full-covariance Gaussian mixture on the PCA embedding, store the
# component assignments as a categorical obs column and plot them on the UMAP.
gmm_params = dict(
    n_components=target_cluster_count,
    covariance_type='full',
    random_state=optimization_params['random_seed'],
    reg_covar=1e-5,
)

gmm_model = GaussianMixture(**gmm_params)
gmm_cluster_labels = gmm_model.fit_predict(pca_embedding)

# Labels are stored as strings first so the resulting categories are string-valued.
gmm_label_strings = pd.Series(gmm_cluster_labels, index=data_hvg.obs_names, dtype='str')
data_hvg.obs['gmm_clusters'] = gmm_label_strings.astype('category')

gmm_plot_settings = dict(
    color='gmm_clusters',
    title='GMM Clustering',
    frameon=False,
    legend_loc='on data',
    legend_fontsize=8,
)

sc.pl.umap(data_hvg, **gmm_plot_settings)

### PCA + K-Means

In [None]:
# Run K-Means on the PCA embedding, store the assignments as a categorical
# obs column and plot them on the UMAP.
kmeans_params = dict(
    n_clusters=target_cluster_count,
    random_state=optimization_params['random_seed'],
    n_init='auto',
)

kmeans_model = KMeans(**kmeans_params)
kmeans_cluster_labels = kmeans_model.fit_predict(pca_embedding)

# Labels are stored as strings first so the resulting categories are string-valued.
kmeans_label_strings = pd.Series(kmeans_cluster_labels, index=data_hvg.obs_names, dtype='str')
data_hvg.obs['kmeans_clusters'] = kmeans_label_strings.astype('category')

kmeans_plot_settings = dict(
    color='kmeans_clusters',
    title='K-Means Clustering',
    frameon=False,
    legend_loc='on data',
    legend_fontsize=8,
)

sc.pl.umap(data_hvg, **kmeans_plot_settings)

### Comparing Metrics

In [None]:
def calculate_clustering_scores(adata, embedding, true_labels, method_name):
    """Score one clustering stored in ``adata.obs['<method_name>_clusters']``.

    Args:
        adata: Object whose ``obs`` table holds the predicted labels.
        embedding: Coordinates passed to the embedding-based metrics.
        true_labels: Reference labels used for the adjusted Rand index.
        method_name: Clustering name; ``'_clusters'`` is appended to find the column.

    Returns:
        A dict with keys ``'ARI'``, ``'Silhouette'``, ``'Davies-Bouldin'`` and
        ``'Calinski-Harabasz'``, or ``None`` (after printing a warning) when
        the label column is missing.
    """
    label_column = f'{method_name}_clusters'

    if label_column in adata.obs:
        labels = adata.obs[label_column]

        # ARI compares against the reference labels; the remaining metrics
        # are computed on the embedding and share one call signature.
        scores = {'ARI': adjusted_rand_score(true_labels, labels)}
        embedding_metrics = (
            ('Silhouette', silhouette_score),
            ('Davies-Bouldin', davies_bouldin_score),
            ('Calinski-Harabasz', calinski_harabasz_score),
        )
        for metric_name, metric_fn in embedding_metrics:
            scores[metric_name] = metric_fn(embedding, labels)
        return scores

    print(f"Warning: Column '{label_column}' not found in adata.obs. Skipping '{method_name}'.")
    return None

# Score every clustering against the ground-truth labels and show the results
# as one table (one row per method). Methods without a stored label column
# are reported by calculate_clustering_scores and left out.
clustering_methods_to_evaluate = ["leiden", "gmm", "kmeans", "louvain"]
evaluation_results = {}

for method in clustering_methods_to_evaluate:
    metric_scores = calculate_clustering_scores(
        adata=data_hvg,
        embedding=pca_embedding,
        true_labels=ground_truth_labels,
        method_name=method,
    )
    if not metric_scores:
        continue
    evaluation_results[method] = metric_scores

if not evaluation_results:
    print("No clustering results were evaluated.")
else:
    # Transpose so that methods are rows and metrics are columns.
    metrics_df = pd.DataFrame(evaluation_results).T
    display(metrics_df)

## Marker Identification

In [None]:
# Settings consumed by analyze_cluster_distribution.
epidermis_analysis_params = {
    # Lower-case prefix matched against lower-cased gene names to build the scored gene list.
    'gene_prefix': 'krt',
    # obs column that receives the per-cell gene-set score.
    'score_name': 'keratin_score',
    # Clusters whose mean score exceeds this quantile of cluster means are flagged.
    'quantile_threshold': 0.85,
    # Seed passed to sc.tl.score_genes.
    'random_seed': 0,
    # obs column holding the "True"/"False" (string) epidermis flag.
    'epidermis_col': 'is_epidermis',
    # obs column holding the condition-group label for each cell.
    'condition_group_col': 'condition_group',
    # Label for cells outside every condition group; such cells are filtered out.
    'ignore_label': 'ignore'
}

def assign_condition_groups(obs_df, epidermis_col, ignore_label):
    """Label each row with a condition group, or ``ignore_label``.

    A row gets a group only when its ``epidermis_col`` value equals the string
    ``"True"`` and its ``"Condition"`` value is one of the four mapped codes.

    Args:
        obs_df: Table with a ``"Condition"`` column and the epidermis flag column.
        epidermis_col: Name of the column holding ``"True"``/``"False"`` strings.
        ignore_label: Value assigned to rows that match no group.

    Returns:
        A NumPy array of labels, one per row of ``obs_df``.
    """
    # Condition code -> group label, in the priority order used by np.select.
    group_by_condition = {
        "ST46_0": "inc_int",
        "ST40_0": "com_int",
        "ST46_1": "inc_amp",
        "ST40_1": "com_amp",
    }

    flagged_epidermis = obs_df[epidermis_col] == "True"
    condition_values = obs_df["Condition"]

    masks = [(condition_values == code) & flagged_epidermis for code in group_by_condition]
    labels = list(group_by_condition.values())

    return np.select(masks, labels, default=ignore_label)

def analyze_cluster_distribution(adata, cluster_col_name, params):
    """Return per-condition-group cluster proportions for flagged cells.

    Works on a copy of ``adata``. Steps:
      1. score every cell on the genes in ``adata.raw`` whose lower-cased name
         starts with ``params['gene_prefix']``;
      2. flag the clusters whose mean score exceeds the
         ``params['quantile_threshold']`` quantile of all cluster means;
      3. assign condition groups to cells in flagged clusters and drop the rest;
      4. cross-tabulate cluster against condition group and normalise each
         condition-group column to sum to 1.

    Args:
        adata: Annotated data with ``.raw`` set and a ``"Condition"`` obs column.
        cluster_col_name: obs column holding the cluster labels to analyse.
        params: Settings dict (see ``epidermis_analysis_params``).

    Returns:
        DataFrame of proportions, clusters as rows and condition groups as columns.
    """
    score_key = params['score_name']
    flag_col = params['epidermis_col']
    group_col = params['condition_group_col']
    skip_label = params['ignore_label']
    prefix = params['gene_prefix']

    working = adata.copy()

    # Gene set: every raw gene whose lower-cased name starts with the prefix.
    scored_genes = list(
        filter(lambda name: name.lower().startswith(prefix), adata.raw.var_names)
    )
    sc.tl.score_genes(
        working,
        gene_list=scored_genes,
        score_name=score_key,
        use_raw=True,
        random_state=params['random_seed'],
    )

    # Clusters strictly above the quantile cut-off of cluster means are flagged.
    mean_score_by_cluster = working.obs.groupby(cluster_col_name, observed=True)[score_key].mean()
    cutoff = mean_score_by_cluster.quantile(params['quantile_threshold'])
    flagged_clusters = mean_score_by_cluster.index[mean_score_by_cluster > cutoff].tolist()

    # The flag is stored as the strings "True"/"False", which is what
    # assign_condition_groups compares against.
    working.obs[flag_col] = "False"
    in_flagged_cluster = working.obs[cluster_col_name].isin(flagged_clusters)
    working.obs.loc[in_flagged_cluster, flag_col] = "True"

    working.obs[group_col] = assign_condition_groups(working.obs, flag_col, skip_label)

    keep_mask = working.obs[group_col] != skip_label
    kept = working[keep_mask, :].copy()

    counts = pd.crosstab(kept.obs[cluster_col_name], kept.obs[group_col])
    # Dividing by the column totals makes every condition-group column sum to 1.
    column_totals = counts.sum(axis=0)
    return counts / column_totals

# Draw one proportion heatmap per clustering result that is present in
# data_hvg.obs; missing label columns are reported and skipped.
clustering_methods_to_analyze = ["louvain_clusters", "leiden_clusters", "gmm_clusters", "kmeans_clusters"]

heatmap_settings = dict(annot=True, fmt=".2f", cmap="viridis", linewidths=0.5)

for method in clustering_methods_to_analyze:
    print(f"--- Analyzing: {method} ---")

    if method not in data_hvg.obs.columns:
        print(f"Warning: Column '{method}' not found in data_hvg.obs. Skipping.")
        continue

    proportions_df = analyze_cluster_distribution(
        adata=data_hvg,
        cluster_col_name=method,
        params=epidermis_analysis_params,
    )

    plt.figure(figsize=(8, 10))
    plt.title(f"Proportional Distribution for {method}")
    sns.heatmap(proportions_df, **heatmap_settings)
    plt.show()

### Marker genes (t-test and Wilcoxon)

In [None]:
# Cluster ID to run marker analysis on, per clustering method (the key plus
# '_clusters' is the obs column name).
# NOTE(review): these IDs are hard-coded; presumably each is the cluster
# identified as ROCs for that method in a particular run — verify after re-clustering.
roc_cluster_ids = {
    "louvain": "19",
    "leiden": "15",
    "gmm": "16",
    "kmeans": "16"
}

# Statistical tests passed to sc.tl.rank_genes_groups and the number of
# ranked genes kept per test.
marker_analysis_params = {
    'statistical_tests': ['t-test', 'wilcoxon'],
    'n_top_genes': 50
}

# Extra keyword arguments forwarded to sc.pl.rank_genes_groups_heatmap.
heatmap_plot_params = {
    'dendrogram': False,
    'vmax': 5,
    'n_genes': 25
}

def find_marker_genes(adata, cluster_col, target_group, methods_to_run, n_genes):

    marker_gene_dfs = {}
    
    for stat_method in methods_to_run:
        results_key = f"{cluster_col.replace('_clusters', '')}_{stat_method}_ranks"
        
        sc.tl.rank_genes_groups(
            adata, 
            groupby=cluster_col, 
            groups=[target_group], 
            method=stat_method,
            key_added=results_key,
            n_genes=n_genes
        )
        
        df = sc.get.rank_genes_groups_df(
            adata, 
            group=target_group,
            key=results_key
        )
        marker_gene_dfs[stat_method] = df
        
    return marker_gene_dfs

# Rank marker genes for the chosen cluster of every clustering method and plot
# a heatmap per statistical test. Work on a copy because the ranking results
# are written into the AnnData object.
adata_for_markers = data_hvg.copy()

all_marker_results = {}

for method_name, cluster_id in roc_cluster_ids.items():
    print(f"--- Finding Markers for: {method_name.capitalize()} (Cluster {cluster_id}) ---")

    cluster_col_name = f"{method_name}_clusters"

    # Same guard as the evaluation and heatmap loops: a clustering that was
    # never computed (no obs column) is reported and skipped, not a crash.
    if cluster_col_name not in adata_for_markers.obs.columns:
        print(f"Warning: Column '{cluster_col_name}' not found in adata_for_markers.obs. Skipping.")
        continue

    # The cluster IDs are hard-coded, so check the requested one actually exists.
    available_ids = set(adata_for_markers.obs[cluster_col_name].astype(str).unique())
    if cluster_id not in available_ids:
        print(f"Warning: Cluster '{cluster_id}' not found in '{cluster_col_name}'. Skipping.")
        continue

    marker_dfs = find_marker_genes(
        adata=adata_for_markers,
        cluster_col=cluster_col_name,
        target_group=cluster_id,
        methods_to_run=marker_analysis_params['statistical_tests'],
        n_genes=marker_analysis_params['n_top_genes']
    )

    all_marker_results[method_name] = marker_dfs

    for stat_test in marker_analysis_params['statistical_tests']:
        # Must match the key built inside find_marker_genes.
        plot_key = f"{method_name}_{stat_test}_ranks"
        plot_title = f"{method_name.capitalize()} Markers (Cluster {cluster_id}, {stat_test})"

        sc.pl.rank_genes_groups_heatmap(
            adata_for_markers,
            key=plot_key,
            show=False,
            title=plot_title,
            **heatmap_plot_params
        )
        plt.show()

In [None]:
# Rank marker genes for the 'ROCs' group of the ground-truth 'cluster'
# annotation, plot one heatmap per statistical test and show the top rows of
# each ranking.
ground_truth_marker_params = dict(cluster_col='cluster', target_group='ROCs')

print(f"--- Finding Markers for Ground Truth: '{ground_truth_marker_params['target_group']}' cluster ---")

ground_truth_marker_dfs = find_marker_genes(
    adata=adata_for_markers,
    cluster_col=ground_truth_marker_params['cluster_col'],
    target_group=ground_truth_marker_params['target_group'],
    methods_to_run=marker_analysis_params['statistical_tests'],
    n_genes=marker_analysis_params['n_top_genes'],
)

for stat_test in marker_analysis_params['statistical_tests']:
    # Must match the key built inside find_marker_genes.
    plot_key = f"{ground_truth_marker_params['cluster_col']}_{stat_test}_ranks"
    plot_title = f"Ground Truth 'ROCs' Markers ({stat_test.capitalize()})"

    sc.pl.rank_genes_groups_heatmap(
        adata_for_markers,
        key=plot_key,
        title=plot_title,
        show=False,
        **heatmap_plot_params,
    )
    plt.show()

# DataFrame.head() shows the first five ranked genes of each test.
top_marker_headers = (
    ('t-test', "Top 5 markers from t-test:"),
    ('wilcoxon', "\nTop 5 markers from Wilcoxon test:"),
)
for test_key, header in top_marker_headers:
    print(header)
    display(ground_truth_marker_dfs[test_key].head())

In [None]:
def calculate_marker_match_stats(marker_df, reference_genes):
    """Count how many ranked marker genes appear in a reference gene set.

    Args:
        marker_df: DataFrame with a ``'names'`` column of gene names.
        reference_genes: Collection of reference gene names.

    Returns:
        Dict with ``'match_count'`` (number of rows whose name is in the
        reference) and ``'match_percentage'`` (that count divided by the
        number of rows, or 0 for an empty frame).

    Raises:
        ValueError: If ``marker_df`` has no ``'names'`` column.
    """
    if 'names' not in marker_df.columns:
        raise ValueError("Input DataFrame must contain a 'names' column.")

    total_rows = len(marker_df)
    hits = marker_df['names'].isin(reference_genes).sum()

    # Avoid dividing by zero when no genes were ranked.
    if total_rows > 0:
        fraction = hits / total_rows
    else:
        fraction = 0

    return {'match_count': hits, 'match_percentage': fraction}

# Compare the ground-truth marker rankings with the reference marker list:
# print and tabulate how many ranked genes appear in the reference set.
reference_marker_genes_series = pd.read_csv("reference_marker_genes.csv", header=None).squeeze("columns")
reference_gene_set = set(reference_marker_genes_series)

marker_overlap_stats = {}
separator = "-" * 25

for method, df in ground_truth_marker_dfs.items():
    stats = calculate_marker_match_stats(df, reference_gene_set)
    marker_overlap_stats[method] = stats

    report_lines = (
        f"--- Results for {method.capitalize()} ---",
        f"Matching Genes Found: {stats['match_count']}",
        f"Match Percentage: {stats['match_percentage']:.2%}",
        separator,
    )
    for line in report_lines:
        print(line)

# One row per statistical test, one column per statistic.
overlap_stats_df = pd.DataFrame.from_dict(marker_overlap_stats, orient='index')
display(overlap_stats_df)

In [None]:
# Build one table of reference-overlap percentages: a row per clustering
# method plus a 'reference' row for the ground-truth markers, and a column
# per statistical test.
final_overlap_percentages = {
    method_name: {
        stat_test: calculate_marker_match_stats(marker_df, reference_gene_set)['match_percentage']
        for stat_test, marker_df in marker_data.items()
    }
    for method_name, marker_data in all_marker_results.items()
}

reference_stats = {
    stat_test: calculate_marker_match_stats(marker_df, reference_gene_set)['match_percentage']
    for stat_test, marker_df in ground_truth_marker_dfs.items()
}
final_overlap_percentages['reference'] = reference_stats

comparison_df = pd.DataFrame.from_dict(final_overlap_percentages, orient='index')

# Column headers become e.g. "T-test Overlap" / "Wilcoxon Overlap".
comparison_df = comparison_df.rename(columns=lambda c: f"{c.capitalize()} Overlap")

print("Comparison of Marker Gene Overlap Across All Methods")
display(comparison_df.style.format("{:.2%}"))