In [None]:
import scanpy as sc 
import anndata as ad
import polars as pl

In [None]:
import os

In [None]:
import re
def is_meta_column(c):
    """Tell whether a column name looks like metadata rather than a feature.

    Each entry of the pattern tuple is a regular expression that is searched
    (not anchored unless the pattern says so) inside the column name.

    Parameters:
    c (str): Column name to classify.

    Returns:
    bool: True if at least one pattern matches, otherwise False.
    """
    meta_patterns = (
        'Metadata',
        '^Count',
        'ImageNumber',
        'Object',
        'Parent',
        'Children',
        'Plate',
        'Well',
        'Location',
        '_[XYZ]_',
        '_[XYZ]$',
        'Phase',
        'Scale',
        'Scaling',
        'BoundingBox',
        'Width',
        'Height',
        'Group',
        'FileName',
        'PathName',
        'URL',
        'Execution',
        'ModuleError',
        'LargeBrightArtefact',
        'label',
    )
    return any(re.search(pattern, c) for pattern in meta_patterns)

In [None]:
# Full list of metadata (non-feature) column names for the SPECS MoA data.
# NOTE(review): this list is not referenced by any other cell visible in this
# notebook; the shorter `meta_cols_specs_moa_df` list is the one actually used.
meta_cols_specs_moa = [
            'Metadata_Plate',
            'Metadata_Well',
            'Metadata_Site',
            'Metadata_cmpdName',
            'compound_name',
            'Compound',
            'moa_broad',
            'target',
            'moa',
            'Synonyms',
            'CAS No.',
            'M.Wt',
            'Information',
            'Formula',
            'Smiles',
            'Solubility',
            'URL',
            'Pathway',
            'concentration_uM',
            'grit_score',
            'flag',
            'secondary_target',
            'SPECS_name',
            'BatchID',
            'SPECS_moa',
            'SPECS_target',
            'SPECS_name2',
            'grit',
            'count',
            'smiles',
            'inchi',
            'inkey',
            'compound_name_right',
            'label',
            'project']

In [None]:
# Metadata columns of the loaded profile table. Every other column is treated
# as a numeric feature (see `features_fixed`), and these columns become
# `adata.obs` when the AnnData object is built.
meta_cols_specs_moa_df = ['Metadata_cmpdName',
 'moa',
 'Metadata_Plate',
 'Metadata_Well',
 'Metadata_Site',
 'compound_name',
 'Nuclei_Location_Center_X',
 'Nuclei_Location_Center_Y',
 'project',
 "label"]

In [None]:
# Absolute location of the shared project data; input files are resolved against it.
# NOTE(review): hard-coded, machine-specific path - consider making it configurable.
PROJECT_DIR = "/home/jovyan/share/data/analyses/benjamin/Single_cell_supervised"
# Working directory at the moment this cell is executed.
ROOT_DIR = os.getcwd()

In [None]:
import polars as pl 
import pandas as pd
# Load the single-cell profile table (parquet under PROJECT_DIR) as a polars DataFrame.
sc_profiles = pl.read_parquet(os.path.join(PROJECT_DIR, 'BF_MOA/CellProfiler/datasets/specs5k_undersampled_significant_CP_BF.parquet'))

In [None]:
# Remove the "AreaShape_FormFactor_nuclei" column from the profile table.
# NOTE(review): this rebinds `sc_profiles`, so re-running the cell without
# reloading the parquet presumably fails because the column is already gone.
sc_profiles = sc_profiles.drop("AreaShape_FormFactor_nuclei")

In [None]:
# Feature columns = every column of the table that is not listed as metadata.
features_fixed = [col for col in sc_profiles.columns if col not in meta_cols_specs_moa_df]

In [None]:
# Count NaN values in every floating-point feature column.
# Select the feature columns once; re-selecting them for every column inside
# the loop rebuilt the same sub-frame several times per iteration.
feature_frame = sc_profiles[features_fixed]

na_counts = []
for col_name in feature_frame.columns:
    column = feature_frame[col_name]
    # is_nan() is only defined for float columns, so other dtypes are skipped
    if column.dtype in [pl.Float32, pl.Float64]:
        na_count = column.is_nan().sum()
        na_counts.append((col_name, na_count))

# Build the summary with an explicit row orientation and schema so the column
# names ("column_0" = feature name, "column_1" = NaN count) do not depend on
# polars' orientation inference and exist even when `na_counts` is empty.
na_summary_df = pl.DataFrame(
    na_counts,
    schema={"column_0": pl.Utf8, "column_1": pl.Int64},
    orient="row",
)
# Features with the most NaN values first
na_summary_df = na_summary_df.sort("column_1", descending=True)

print(na_summary_df)

## Run scanpy analysis

In [None]:
# Build the AnnData object: X holds the feature columns cast to float32 (via
# pandas), obs holds the metadata columns of the same rows.
adata = ad.AnnData(X = sc_profiles[features_fixed].to_pandas().astype('float32'), obs = sc_profiles[meta_cols_specs_moa_df].to_pandas())

In [None]:
import numpy as np
# Boolean mask that is True wherever adata.X holds a NaN
# (np.isnan requires a dense array - assumes adata.X is not sparse here).
nan_mask = np.isnan(adata.X)
# Tuple of (row positions, column positions) of the NaN entries
nan_indices = np.where(nan_mask)

# Look up the feature name of every NaN entry from its column position
column_names = np.array(adata.var_names)[nan_indices[1]]

# One row per NaN entry: where it is and which feature it belongs to
df_nans = pl.DataFrame({
        "Row_Index": nan_indices[0],
        "Column_Index": nan_indices[1],
        "Column_Name": column_names
    })

print(df_nans)

In [None]:
def run_scanpy(adata, n_comps=50, n_neighbors=10, groups="Metadata_cmpdName"):
    """
    Run the PCA -> neighbours -> PAGA -> UMAP pipeline on an AnnData object in place.

    Parameters:
    adata (AnnData): Object whose .X holds the feature matrix; results are stored on it.
    n_comps (int): Number of principal components to compute and to use for the
        neighbour graph. Reduced automatically when the matrix is too small.
    n_neighbors (int): Neighbourhood size for the neighbour graph.
    groups (str): Column in adata.obs that defines the PAGA groups.

    Returns:
    None: all results are written to `adata` by scanpy.
    """
    # The ARPACK solver needs strictly fewer components than min(n_obs, n_vars),
    # so clamp the request instead of failing on narrow feature sets.
    n_comps = min(n_comps, min(adata.n_obs, adata.n_vars) - 1)

    sc.tl.pca(adata, svd_solver='arpack', n_comps=n_comps)
    sc.pp.neighbors(adata, n_neighbors=n_neighbors, n_pcs=n_comps)
    sc.tl.paga(adata, groups=groups)
    sc.pl.paga(adata, plot=False)  # remove `plot=False` if you want to see the coarse-grained graph
    # UMAP starts from a random layout (the PAGA layout is not used as init)
    sc.tl.umap(adata, init_pos='random')

In [None]:
# Compute PCA, neighbour graph, PAGA and UMAP for the profile data (modifies adata).
run_scanpy(adata)

In [None]:
# Diagnostic plots for the embedding computed by run_scanpy:
# variance explained per principal component
sc.pl.pca_variance_ratio(adata)
# feature loadings of the first five components
sc.pl.pca_loadings(adata, components = '1,2,3,4,5')
# PCA and UMAP scatter plots coloured by the "moa" column of adata.obs
sc.pl.pca(adata, color = "moa")
sc.pl.umap(adata, color = "moa")

In [None]:
# Persist the AnnData object (including the computed embeddings) to disk.
# The path is relative, so it resolves against the current working directory.
adata.write_h5ad("data/sc_embedding_BF_undersampled_sign_CP.h5ad")

In [None]:
# Split the MoA categories into two groups that both contain the 'dmso' control.
# Unique values of adata.obs['moa'], in the order returned by unique()
categories = adata.obs['moa'].unique().tolist()
categories.remove('dmso')  # Handle 'dmso' separately; raises ValueError if 'dmso' is absent

# Cut the list in the middle. The list is not sorted, so the split follows the
# order returned by unique() rather than alphabetical order.
half = len(categories) // 2
group1 = categories[:half]
group2 = categories[half:]
# Add the control to both halves
group1.append('dmso')
group2.append('dmso')

In [None]:
def generate_density_plots(adata, basis, group_categories, plot_key_prefix):
    """
    Compute and plot embedding densities for a subset of the 'moa' categories.

    A temporary obs column is created in which every cell whose 'moa' value is
    not in `group_categories` is set to missing; densities are then computed per
    remaining category. The temporary column is removed afterwards, also when
    scanpy raises.

    Parameters:
    adata (AnnData): Object with an embedding for `basis` and a 'moa' column in .obs.
    basis (str): Name of the embedding (e.g. 'umap').
    group_categories (list): 'moa' values to include in the density plot.
    plot_key_prefix (str): Label inserted into the name of the saved figure.

    Returns:
    None: the figure is shown/saved by scanpy.
    """
    temp_group_col = 'temp_group'
    moa = adata.obs['moa']
    # Keep only the requested categories (others become missing) and make the
    # column explicitly categorical, which embedding_density expects for `groupby`.
    adata.obs[temp_group_col] = (
        moa.astype(object).where(moa.isin(group_categories)).astype('category')
    )

    try:
        # Generate and plot density
        sc.tl.embedding_density(adata, basis=basis, groupby=temp_group_col)
        sc.pl.embedding_density(adata, basis=basis, key=f'{basis}_density_{temp_group_col}',
                                save=f"moa/sc_BF_sign_{plot_key_prefix}_density_{basis}.png")
    finally:
        # Clean up temporary column even if the density computation or plot fails
        del adata.obs[temp_group_col]

In [None]:
# Density plot on the UMAP embedding for the second half of the MoA categories.
# (`group1` is defined above but not plotted in any visible cell.)
generate_density_plots(adata, 'umap', group2, 'group2')

In [None]:
# Drop every feature (column) whose mean over all cells is larger than 1.
# NOTE(review): `adata_copy` is not defined in any cell visible above - it must
# exist in the kernel (presumably a copy of `adata`) before this cell can run.
mean_per_col = np.asarray(adata_copy.X.mean(axis=0)).ravel()
col_remove = np.where(mean_per_col > 1)[0]  # integer positions of the columns to drop

# `col_remove` holds positions, while var/obs indices hold string labels, so the
# columns are selected with a positional boolean mask. The former second step,
# which filtered rows (obs) by these column positions, is dropped: row labels
# are unrelated to column positions.
keep_cols = np.ones(adata_copy.n_vars, dtype=bool)
keep_cols[col_remove] = False
new_adata = adata_copy[:, keep_cols]

In [None]:
# Remove features whose mean is above 1 or below -1, working on plain arrays.
# NOTE(review): `adata_copy` is not defined in any visible cell - confirm where it comes from.
mean_per_col = adata_copy.X.mean(axis=0)
bigger = mean_per_col > 1
smaller = mean_per_col < -1
# Integer positions of the columns to drop
col_remove = np.where(bigger|smaller)[0]
# Feature matrix and feature names with those columns removed
X = np.delete(adata_copy.X, col_remove, axis = 1)
var_names = np.delete(adata_copy.var_names, col_remove)


In [None]:
# Display the positions of the columns selected for removal.
col_remove

In [None]:
# NOTE(review): neither `run_scanpy_debug` nor `testing_adata` is defined in any
# visible cell; this cell fails on a fresh kernel unless they come from elsewhere.
run_scanpy_debug(testing_adata)

In [None]:
# PCA diagnostics (same two plots as shown after the first run_scanpy call):
# variance ratio per component and loadings of the first five components.
sc.pl.pca_variance_ratio(adata)
sc.pl.pca_loadings(adata, components = '1,2,3,4,5')

In [None]:
# PCA coloured by broad MoA class, UMAP coloured by compound name.
# NOTE(review): "moa_broad" is not among the obs columns used when `adata` is
# built from `meta_cols_specs_moa_df` above - confirm the column exists.
sc.pl.pca(adata, color = "moa_broad")
sc.pl.umap(adata, color = "Metadata_cmpdName")

In [None]:
# Save the current AnnData object to an .h5ad file in the working directory.
adata.write('grit_reference_locations_cellprofiler_test.h5ad')


In [None]:
# Replace `adata` with a previously saved dataset (relative path); the object
# built and embedded in the cells above is no longer reachable through this name.
adata = ad.read_h5ad("data/sc_features_SPECS3k_ref_cellprofiler.h5ad")

In [None]:
# Display the per-cell metadata table of the loaded dataset.
adata.obs

## Load in cellxgene embeddings

In [None]:
# Directory from which the cellxgene embedding .h5ad files are read.
# NOTE(review): hard-coded absolute path - consider making it configurable.
CXG_DIR = "/home/jovyan/share/data/analyses/benjamin/cellxgene/embeddings"

In [None]:
def fix_keys(adata):
    """
    Expose embeddings stored under custom .obsm keys as 'X_pca' and 'X_umap'.

    The first .obsm key containing 'pca' is copied to 'X_pca'. For 'X_umap' the
    first key containing 'dmso' is used, falling back to the first key
    containing 'emb'. The original keys are kept; nothing is copied when no
    matching key exists.

    Parameters:
    adata: Object with a dict-like `.obsm` attribute (modified in place).

    Returns:
    The same `adata` object.
    """
    def find_key_with_substring(obsm, substring):
        """Return the first key of `obsm` containing `substring`, or None."""
        return next((key for key in obsm.keys() if substring in key), None)

    # Find the keys
    pca_key = find_key_with_substring(adata.obsm, 'pca')
    umap_key = find_key_with_substring(adata.obsm, 'dmso')
    if umap_key is None:
        umap_key = find_key_with_substring(adata.obsm, 'emb')

    # Alias the keys if they are found (original entries are left in place)
    if pca_key:
        adata.obsm['X_pca'] = adata.obsm[pca_key]

    if umap_key:
        adata.obsm['X_umap'] = adata.obsm[umap_key]

    return adata

In [None]:
import tqdm
import os
import anndata as ad
# Load one embedding per reference compound into `emb_dict`, keyed by compound tag.
emb_dict = {}
ref_comp = ["berb", "cao", "etop", "fenb", "flup", "tetr", "dmso_only"]
# os.listdir returns files in arbitrary order; sorting makes the result
# reproducible when more than one file name contains the same compound tag
# (the last match in sorted order is the one kept).
h5ad_files = sorted(file for file in os.listdir(CXG_DIR) if file.endswith(".h5ad"))

for comp in tqdm.tqdm(ref_comp):
    for filename in h5ad_files:
        # Only files whose name contains the compound tag are loaded
        if comp not in filename:
            continue
        file_path = os.path.join(CXG_DIR, filename)
        temp = ad.read_h5ad(file_path)
        # Make the stored embeddings available under the standard scanpy keys
        temp_fix = fix_keys(temp)
        emb_dict[comp] = temp_fix
        print(f"Loaded {filename}")

In [None]:
# Load the combined embedding file and store it under the key "all".
temp = ad.read_h5ad(os.path.join(CXG_DIR, "umap.h5ad"))
# Alias its embeddings to the standard 'X_pca' / 'X_umap' keys
temp_fix = fix_keys(temp)
emb_dict["all"] = temp_fix

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.colors as mcolors
# Inital setting for plot
from matplotlib import rcParams
# Plot one UMAP per loaded embedding, coloured by the "subpopulations" column.
# The figure size / dpi settings apply only inside this `with` block.
with plt.rc_context({"figure.figsize": (8, 8), "figure.dpi": (300)}):
    for key, item in emb_dict.items():
        print(key)  # name of the embedding shown next
        sc.pl.umap(item, color = "subpopulations",)

In [None]:
import scanpy as sc
import matplotlib.pyplot as plt
import numpy as np

def plot_umap_grid_colored(anndata_dict, color_by, n_cols=3):
    """
    Create a grid of UMAP plots from a dictionary of AnnData objects, colored by a specified column.

    All panels share one category -> color mapping, and a single legend for
    every category is drawn below the grid. The entry with key 'all' is skipped.

    Parameters:
    anndata_dict (dict): A dictionary of AnnData objects.
    color_by (str): Column name to color by.
    n_cols (int): Number of columns in the grid.

    Returns:
    None: the figure is displayed.

    Raises:
    ValueError: If there is nothing to plot after removing the 'all' entry.
    """
    # Local import so the function does not depend on an import made in a later cell
    from matplotlib.lines import Line2D

    anndata_dict = {k: v for k, v in anndata_dict.items() if k != 'all'}
    if not anndata_dict:
        raise ValueError("anndata_dict contains no embeddings to plot (the 'all' entry is skipped)")

    # Determine all unique categories across all AnnData objects
    all_categories = set()
    for adata in anndata_dict.values():
        all_categories.update(adata.obs[color_by].astype(str))

    # Sort categories for consistent ordering and create color palette
    sorted_categories = sorted(all_categories)
    color_palette = sc.pl.palettes.default_20  # colors repeat when there are more than 20 categories
    color_map = {cat: color_palette[i % len(color_palette)] for i, cat in enumerate(sorted_categories)}
    print(color_map)

    # squeeze=False always returns a 2-D array of axes, so flatten() also works
    # for a 1x1 grid (a single Axes object has no flatten method).
    n_rows = int(np.ceil(len(anndata_dict) / n_cols))
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
    axs = axs.flatten()

    # Plot UMAP for each AnnData object with the shared color map
    for ax, (key, adata) in zip(axs, anndata_dict.items()):
        sc.pl.umap(adata, color=color_by, ax=ax, show=False,
                   title=key, frameon=False,
                   palette=color_map,
                   legend_loc = "none")
        # Remove axis titles for a cleaner look
        ax.set_xlabel('')
        ax.set_ylabel('')

    # Hide any extra axes
    for extra_ax in axs[len(anndata_dict):]:
        extra_ax.axis('off')

    fig.suptitle('UMAP Grid', fontsize=16)

    # The panels are drawn without per-axis legends, so the axes expose no
    # labelled artists to collect; build the shared legend from the color map.
    handles = [Line2D([0], [0], marker='o', color='w', label=cat,
                      markerfacecolor=color_map[cat], markersize=8)
               for cat in sorted_categories]
    fig.legend(handles=handles, loc='lower center', ncol=3, bbox_to_anchor=(0.5, 0.01))

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

# Grid of UMAPs for the per-compound embeddings, coloured by "subpopulations".
# The rc settings are only active inside the `with` block.
with plt.rc_context({"figure.figsize": (8, 8), "figure.dpi": (300)}):
    plot_umap_grid_colored(emb_dict, "subpopulations")


In [None]:
import scanpy as sc
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

def plot_single_umap_colored(adata, color_by, color_map=None):
    """
    Plot one UMAP colored by a categorical column, with per-category counts in the legend.

    Parameters:
    adata (AnnData): Object with a UMAP embedding and `color_by` in .obs.
    color_by (str): Column in adata.obs used for coloring.
    color_map (dict, optional): Mapping category -> color. Defaults to the
        built-in subpopulation colors. Categories missing from the mapping get
        a color from scanpy's default palette instead of raising KeyError.

    Returns:
    None: the figure is displayed.
    """
    # Calculate the count of each category in the color_by column
    category_counts = adata.obs[color_by].value_counts()

    if color_map is None:
        color_map = {'berb': '#1f77b4', 'cao': '#ff7f0e', 'dmso_big': '#279e68', 'dmso_small': '#d62728', 'etop_big': '#aa40fc', 'etop_nocluster': '#8c564b', 'etop_small': '#e377c2', 'fenb': '#b5bd61', 'flup': '#17becf', 'tetr_big': '#aec7e8', 'tetr_nocluster': '#ffbb78'}
    # Work on a copy so the caller's mapping is never modified
    color_map = dict(color_map)

    # Give categories without a predefined color one that is not yet in use
    # (falling back to the whole palette once all of its colors are taken).
    missing = [cat for cat in category_counts.index if cat not in color_map]
    if missing:
        palette = list(sc.pl.palettes.default_20)
        unused = [color for color in palette if color not in color_map.values()] or palette
        for i, cat in enumerate(missing):
            color_map[cat] = unused[i % len(unused)]

    # Create figure and axis for UMAP plot
    fig, ax = plt.subplots(figsize=(8, 6))

    # Create UMAP plot without scanpy's own legend; a custom one is added below
    sc.pl.umap(adata, color=color_by, ax=ax, show=False,
               title=f'UMAP colored by {color_by}',
               frameon=False, legend_loc='none',
               palette=color_map, s = 2)

    # Create a custom legend for all categories with counts
    legend_elements = [Line2D([0], [0], marker='o', color='w',
                              label=f"{cat} (n={category_counts[cat]})",
                              markerfacecolor=color_map[cat], markersize=10)
                       for cat in category_counts.index]

    # Place legend outside the plot to the right
    ax.legend(handles=legend_elements, title=color_by, loc='center left',
              bbox_to_anchor=(1, 0.5), ncol=1, fontsize='x-small')

    plt.tight_layout(rect=[0, 0, 0.85, 1])  # leave room on the right for the legend
    plt.show()

# UMAP of the combined embedding, coloured by "subpopulations".
# The rc settings are only active inside the `with` block.
with plt.rc_context({"figure.figsize": (12, 12), "figure.dpi": (300)}):
    plot_single_umap_colored(emb_dict["all"], "subpopulations")

In [None]:
# Per-feature summary statistics (count, mean, std, min, 5%/50%/95%, max).
# NOTE(review): `grit_filter_df_sampled_pd` is not defined in any visible cell -
# it must already exist in the kernel for this cell to run.
summary_features = grit_filter_df_sampled_pd[features_fixed].describe(percentiles= [0.05, 0.95, 0.5])
min_of_min = summary_features.loc['min'].min()  # Minimum of the 'min' values
max_of_max = summary_features.loc['max'].max()  # Maximum of the 'max' values
max_of_95th = summary_features.loc['95%'].max()  # Maximum of the '95th percentile' values
min_of_5th = summary_features.loc['5%'].min()  # Minimum of the '5th percentile' values
print("Minimum of 'min' values:", min_of_min)
print("Maximum of 'max' values:", max_of_max)
print("Maximum of '95th percentile' values:", max_of_95th)
print("Minimum of '5th percentile' values:", min_of_5th)

In [None]:
def show_summary_stats(df):
    """
    Plot per-feature summary statistics from a describe()-style DataFrame.

    Parameters:
    df (DataFrame): Columns are features; the index must contain the rows
        'mean', '5%', '95%', 'max' and 'min'.

    Returns:
    None: the figure is displayed.
    """
    features = df.columns

    plt.figure(figsize=(12,6))

    # Lines: (row of df, legend label, color), drawn in this order
    line_specs = (
        ('mean', 'Mean', 'blue'),
        ('5%', '5th Percentile', 'green'),
        ('95%', '95th Percentile', 'red'),
    )
    for row_name, legend_label, line_color in line_specs:
        plt.plot(features, df.loc[row_name], label=legend_label, color=line_color)

    # Extremes as small dots: (row of df, legend label, color)
    dot_specs = (
        ('max', 'Max', 'black'),
        ('min', 'Min', 'grey'),
    )
    for row_name, legend_label, dot_color in dot_specs:
        plt.scatter(features, df.loc[row_name], color=dot_color, label=legend_label, s=5)

    plt.xlabel('Features')
    plt.ylabel('Values')
    plt.title('Feature distributions')
    plt.xticks([])  # hide the feature names; there are too many to read

    plt.legend()

    plt.tight_layout()
    plt.show()

In [None]:
# Plot the summary statistics computed with describe() above.
show_summary_stats(summary_features)

In [None]:
def plot_grouped_feature_statistics(df, group_column, feature_columns):
    """
    Plot the mean, 5th and 95th percentile of every feature, one subplot per group.

    Parameters:
    df (DataFrame): The original pandas DataFrame with data.
    group_column (str): The name of the column to group by.
    feature_columns (list): List of columns to calculate statistics on.

    Returns:
    None: the figure is displayed.

    Raises:
    ValueError: If grouping by `group_column` yields no groups.
    """
    # Grouping the DataFrame by the specified column
    grouped = df.groupby(group_column)

    n_groups = len(grouped)
    if n_groups == 0:
        raise ValueError(f"No groups found in column {group_column!r}")

    # Determine the number of subplots needed
    n_cols = 1  # You can adjust the number of columns per row
    n_rows = int(np.ceil(n_groups / n_cols))

    # squeeze=False keeps `axes` 2-D so flatten() works for any grid size
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15 * n_cols, 10 * n_rows), squeeze=False)
    axes = axes.flatten()

    for i, (group_name, group_data) in enumerate(grouped):
        # Only the statistics that are actually drawn are computed
        group_features = group_data[feature_columns]
        mean = group_features.mean()
        percentile_5 = group_features.quantile(0.05)
        percentile_95 = group_features.quantile(0.95)

        ax = axes[i]
        ax.plot(feature_columns, mean, label='Mean', color='blue')
        ax.plot(feature_columns, percentile_5, label='5th Percentile', color='green')
        ax.plot(feature_columns, percentile_95, label='95th Percentile', color='red')

        ax.set_title(f'Group: {group_name}')
        ax.set_xticks([])  # Remove x-axis labels

        if i == 0:  # One legend is enough; all subplots use the same colors
            ax.legend()

    # Hide any unused subplots (only possible when n_cols > 1)
    for unused_ax in axes[n_groups:]:
        unused_ax.axis('off')

    plt.tight_layout()
    plt.show()

# Per-compound feature statistics.
# NOTE(review): `grit_filter_df_sampled_pd` is not defined in any visible cell.
plot_grouped_feature_statistics(grit_filter_df_sampled_pd, "Metadata_cmpdName", features_fixed)

In [None]:
# Observation labels of the combined embedding, ordered by their "subpopulations" value
sorted_indices = emb_dict["all"].obs["subpopulations"].sort_values().index

# Reorder the cells (rows of .X and .obs) accordingly by indexing with those labels.
# NOTE(review): the stored object is rebound to the result of the subsetting,
# so re-running the cell re-sorts the already sorted object.
emb_dict["all"] = emb_dict["all"][sorted_indices]

In [None]:
import anndata
import numpy as np
import scanpy as sc
import scipy
import seaborn as sns
import matplotlib.colors as mcolors
# Inital setting for plot
import matplotlib.pyplot as plt
from matplotlib import rcParams


def plot_clipped_heatmap(adata, max_val=10, min_val=-10, genes=None, groupby=None):
    """
    Plot a heatmap from clipped data of an AnnData object.

    The input object is not modified: clipping is applied to a copy, and the
    figure size is set only for the duration of the plot.

    Parameters:
    adata (AnnData): The original AnnData object.
    max_val (float): Maximum value to clip data to.
    min_val (float): Minimum value to clip data to.
    genes (list): List of gene names to be plotted. They should match the var_names in adata.
    groupby (str): Name of the observation annotation to group by (usually categorical).

    Returns:
    None: Displays a heatmap.

    Raises:
    TypeError: If adata.X is neither a numpy array nor a CSR/CSC sparse matrix.
    """

    # Step 1: Make a copy of the AnnData object to avoid overwriting original data
    adata_copy = adata.copy()

    # Step 2: Clip the data in the X matrix of the copied AnnData object.
    # For sparse matrices only the stored values are clipped.
    if isinstance(adata_copy.X, np.ndarray):
        adata_copy.X = np.clip(adata_copy.X, a_min=min_val, a_max=max_val)
    elif isinstance(adata_copy.X, (scipy.sparse.csr_matrix, scipy.sparse.csc_matrix)):
        adata_copy.X.data = np.clip(adata_copy.X.data, a_min=min_val, a_max=max_val)
    else:
        raise TypeError("adata.X must be a numpy array or a scipy sparse matrix.")

    # Step 3: Use scanpy's pl.heatmap function to visualize the clipped data.
    # rc_context restores the previous figure size afterwards, so calling this
    # function no longer changes the size of every later figure in the notebook.
    with plt.rc_context({"figure.figsize": (10, 10)}):
        sc.pl.heatmap(adata_copy, var_names=genes, groupby=groupby, swap_axes= True, standard_scale = "obs")

# Example usage
# plot_clipped_heatmap(your_adata_object, max_val=10, min_val=-10, genes=your_genes_list, groupby='your_groupby_column')


In [None]:
# Heatmap of all features for the combined embedding, cells grouped by
# "subpopulations"; the color scale is limited to [-3, 3].
with plt.rc_context({"figure.figsize": (8, 8), "figure.dpi": (300)}):
    sc.pl.heatmap(emb_dict["all"], var_names=features_fixed, groupby="subpopulations", dendrogram= False, swap_axes= False, vmin = -3, vmax = 3, cmap='RdBu_r')

In [None]:
import anndata
import numpy as np
import pandas as pd

def aggregate_by_group(adata, group_by):
    """
    Aggregate the data of an AnnData object per group using the median.

    Parameters:
    adata (AnnData): The original AnnData object. Its `group_by` column in .obs
        is converted to a categorical in place.
    group_by (str): The column in adata.obs to group by.

    Returns:
    AnnData: A new AnnData object with one row per group that occurs in the
        data (the per-feature median), the same .var as the input, and the
        group label stored in .obs[group_by].
    """
    # Ensure the group_by column is categorical for efficiency
    adata.obs[group_by] = adata.obs[group_by].astype('category')

    # Bring .X into a dense array (sparse matrices are densified)
    if isinstance(adata.X, (np.ndarray, np.generic)):
        dense_values = adata.X
    else:
        dense_values = adata.X.toarray()
    adata_df = pd.DataFrame(dense_values, columns=adata.var_names)

    # Group by the label array (matched by position) instead of inserting it as
    # a column, so a feature that happens to share the name cannot be overwritten.
    group_labels = adata.obs[group_by].values

    # Median per group. observed=True skips categories without any cell, which
    # would otherwise produce rows consisting only of NaN.
    aggregated_data = adata_df.groupby(group_labels, observed=True).median()

    # Create a new AnnData object with the aggregated data.
    # Note: .var is copied unchanged; .obs only receives the group label.
    aggregated_adata = anndata.AnnData(X=aggregated_data.values, var=adata.var.copy())
    aggregated_adata.obs[group_by] = aggregated_data.index.values

    return aggregated_adata

# Example usage:
# aggregated_adata = aggregate_by_group(your_adata, 'cell_type')


In [None]:
# One aggregated profile per subpopulation of the combined embedding.
aggregated = aggregate_by_group(emb_dict["all"], "subpopulations")

In [None]:
# Compute the subpopulation dendrogram on the aggregated profiles, using the
# feature columns; the heatmap below is drawn with dendrogram=True.
sc.tl.dendrogram(aggregated, var_names=features_fixed, groupby="subpopulations")

In [None]:
# Heatmap of the aggregated profiles with the dendrogram attached; the color
# scale is limited to [-3, 3]. The rc settings apply only inside the block.
with plt.rc_context({"figure.figsize": (8, 8), "figure.dpi": (300)}):
    sc.pl.heatmap(aggregated, dendrogram=True, var_names=features_fixed, groupby="subpopulations", swap_axes=False, vmin=-3, vmax=3, cmap='RdBu_r')


In [None]:
# Compute the subpopulation dendrogram on the full (non-aggregated) combined embedding.
sc.tl.dendrogram(emb_dict["all"], var_names=features_fixed, groupby="subpopulations")

In [None]:
# Correlation matrix between subpopulations of the combined embedding.
# NOTE(review): presumably relies on the dendrogram computed in the previous cell - confirm.
sc.pl.correlation_matrix(emb_dict["all"], 'subpopulations')

In [None]:
import matplotlib.patches as mpatches
def create_heatmap_from_aggregated_adata(adata, groupby_column, title="", cmap='viridis', figsize=(10, 8), vmin=None, vmax=None):
    """
    Create a heatmap from an aggregated AnnData object with group labels on the y-axis.

    Parameters:
    adata (AnnData): The aggregated AnnData object (one row per group, dense .X).
    groupby_column (str): Column in adata.obs to use for y-axis labels and row colors.
    title (str, optional): Title of the heatmap.
    cmap (str, optional): Colormap for the heatmap.
    figsize (tuple, optional): Size of the figure.
    vmin, vmax (float, optional): Min and max values for colormap scaling.

    Returns:
    None: the heatmap is displayed.

    Raises:
    ValueError: If `groupby_column` is not a column of adata.obs.
    """
    # Ensure the label column is present
    if groupby_column not in adata.obs:
        raise ValueError(f"{groupby_column} not found in adata.obs")

    # One label per row, plus one color per distinct label (first-appearance order)
    row_labels = adata.obs[groupby_column].astype(str).tolist()
    unique_labels = list(dict.fromkeys(row_labels))
    color_dict = dict(zip(unique_labels, sns.color_palette('hsv', len(unique_labels))))

    # Convert the .X matrix to a DataFrame
    data_df = pd.DataFrame(adata.X, index=adata.obs_names, columns=adata.var_names)
    n_rows, n_features = data_df.shape

    # Creating the heatmap. Passing the labels to seaborn gives exactly one
    # tick label per row, even when a label occurs on more than one row.
    plt.figure(figsize=figsize)
    ax = sns.heatmap(data_df, cmap=cmap, annot=False, vmin=vmin, vmax=vmax,
                     yticklabels=row_labels)
    plt.title(title)
    plt.ylabel('')
    plt.xlabel('')
    plt.xticks([])  # Remove x-axis tick labels
    ax.tick_params(axis='y', labelrotation=0)

    # Thin separators between rows
    for i in range(n_rows - 1):
        ax.axhline(i + 1, color='black', lw=1)

    # Add lines around the plot
    ax.axhline(0, color='black', lw=2)  # Top horizontal line
    ax.axhline(n_rows, color='black', lw=2)  # Bottom horizontal line
    # Right border exactly at the last feature column; drawing it one unit
    # further right would stretch the x-axis beyond the data.
    ax.axvline(n_features, color='black', lw=2)

    # Color bars: a patch per row in the color of its group. The patch covers
    # the first 5 feature columns of that row.
    for i, label in enumerate(row_labels):
        ax.add_patch(mpatches.Rectangle((0, i), 5, 1, color=color_dict[label]))

    plt.show()

In [None]:
# Custom heatmap of the aggregated subpopulation profiles, color scale limited to [-3, 3].
create_heatmap_from_aggregated_adata(aggregated, groupby_column= "subpopulations", vmin = -3, vmax = 3, cmap= "RdBu_r")

In [None]:
import scipy.cluster.hierarchy as sch
def create_heatmap_with_dendrogram(adata, groupby_column, title="", cmap='viridis', figsize=(12, 10), vmin=None, vmax=None):
    """
    Create a heatmap of group means with a dendrogram of the groups next to it.

    The rows of .X are averaged per group, the groups are clustered with
    average linkage, and the heatmap rows are ordered to match the leaves of
    the dendrogram drawn on the left.

    Parameters:
    adata (AnnData): The aggregated AnnData object (dense .X).
    groupby_column (str): Column in adata.obs to use for groupings.
    title (str, optional): Title of the figure.
    cmap (str, optional): Colormap for the heatmap.
    figsize (tuple, optional): Size of the figure.
    vmin, vmax (float, optional): Min and max values for colormap scaling.

    Returns:
    None: the figure is displayed.

    Raises:
    ValueError: If `groupby_column` is not in adata.obs, or fewer than two
        groups are available for clustering.
    """
    if groupby_column not in adata.obs:
        raise ValueError(f"{groupby_column} not found in adata.obs")

    # Convert the .X matrix to a DataFrame and average the rows of each group.
    # Labels are matched by position and used as plain strings so they can be
    # looked up again from the dendrogram's leaf labels.
    data_df = pd.DataFrame(adata.X, index=adata.obs_names, columns=adata.var_names)
    group_labels = adata.obs[groupby_column].astype(str).values
    grouped = data_df.groupby(group_labels).mean()
    if grouped.shape[0] < 2:
        raise ValueError("At least two groups are required to build a dendrogram")

    # Perform hierarchical clustering
    Z = sch.linkage(grouped.values, method='average')

    # Separate, non-overlapping axes: dendrogram on the left, heatmap on the right
    fig, (ax_dendro, ax_heatmap) = plt.subplots(
        1, 2, figsize=figsize, constrained_layout=True,
        gridspec_kw={"width_ratios": [1, 4]},
    )
    dendro = sch.dendrogram(Z, labels=grouped.index.tolist(), ax=ax_dendro,
                            orientation='left', above_threshold_color='black')
    ax_dendro.axis('off')  # the group names are shown on the heatmap instead

    # A left-oriented dendrogram lists its leaves from bottom to top, whereas
    # heatmap rows run top to bottom, so the leaf order is reversed.
    row_order = dendro['ivl'][::-1]
    sns.heatmap(grouped.reindex(row_order), cmap=cmap, ax=ax_heatmap, vmin=vmin, vmax=vmax)
    ax_heatmap.tick_params(axis='y', labelrotation=0)

    fig.suptitle(title)
    plt.show()


In [None]:
# Heatmap with dendrogram for the aggregated subpopulation profiles, color scale limited to [-3, 3].
create_heatmap_with_dendrogram(aggregated, groupby_column= "subpopulations", vmin = -3, vmax = 3, cmap= "RdBu_r")