In [None]:
import scanpy as sc 
import anndata as ad
import polars as pl
import os
import sys
# Add the module_folder to the sys.path list
sys.path.append('/home/jovyan/share/data/analyses/benjamin/Single_cell_project_rapids/analysis_functions/')
from plotting import *

In [None]:
# Root directory of the project; the parquet paths read below are joined onto it.
PROJECT_DIR = "/home/jovyan/share/data/analyses/benjamin/Single_cell_project_rapids"

In [None]:
# Load a previously saved AnnData (.h5ad) file for a quick inspection.
tests = ad.read_h5ad("/home/jovyan/share/data/analyses/benjamin/cellxgene/sc_embedding_scanpy_Beactica_deep+cell_cellcycle2.h5ad")

In [None]:
# Display the compound name stored for every row of obs.
# NOTE(review): this emits one entry per row; `.unique()` or `.value_counts()` would give a compact view.
list(tests.obs["Metadata_cmpdName"])

## Read in old data

In [None]:
# Only use one plate for Beactica due to size of data
# (PB000046, per the file name); the table is read into a polars DataFrame.
mad_norm_df = pl.read_parquet(os.path.join(PROJECT_DIR, 'Beactica/Results/sc_profiles_normalized_Beactica_PB000046.parquet'))

In [None]:
#features_fixed = [f for f in mad_norm_df.columns if "Feature" in f]
# Load the merged table whose "grit" column is used for filtering in the following cells.
grit_filt_df = pl.read_parquet(os.path.join(PROJECT_DIR, "Beactica/deepprofiler/Results/sc_grit_merged_cellprofiler.parquet"))

In [None]:
# Keep rows with 0.8 <= grit < 4, plus every DIMETHYL SULFOXIDE row regardless of its grit value.
# NOTE(review): the AnnData-based filter later in this notebook uses `<= 4` as the upper bound;
# confirm which of the two is intended.
grit_filter = grit_filt_df.filter(
    ((pl.col("grit") >= 0.8) & (pl.col("grit") < 4)) |
    (pl.col("Metadata_cmpdName") == "DIMETHYL SULFOXIDE")
)

In [None]:
# Persist the filtered table; the path is relative, so the file lands in the current working directory.
grit_filter.write_parquet("sc_grit_filter_FILTERED.parquet")

In [None]:
# NOTE(review): this aliases the *unfiltered* table (grit_filt_df), not grit_filter — confirm that is intended.
grit_filter_df_sampled = grit_filt_df
#sample_compounds(mad_norm_df, grit_filt_df, sampling_rate= 1, mode = "normal")
# pandas copy of the polars table.
grit_filter_df_sampled_pd = grit_filter_df_sampled.to_pandas()
# Columns whose name contains "Feature" are treated as features; every other column as metadata.
features_fixed = [f for f in grit_filt_df.columns if "Feature" in f]
meta_features = [col for col in grit_filter_df_sampled_pd.columns if col not in features_fixed]

## Grit filtering

In [None]:
def fix_compound_names(adata):
    """Return a copy of ``adata`` with cleaned compound names in ``.obs``.

    On the copy:
    - ``Metadata_cmpdName`` and ``compound_id`` are cast to ``str``;
    - rows whose ``Metadata_cmpdName`` is the string ``'nan'`` take their
      name from ``compound_id`` instead;
    - ``Metadata_cmpdNameConc`` is (re)built as ``"<name>_<Metadata_cmpdConc>"``.

    Parameters:
    - adata: object exposing ``.copy()`` and a DataFrame-like ``.obs`` with the
      columns ``Metadata_cmpdName``, ``compound_id`` and ``Metadata_cmpdConc``.

    Returns:
    - The modified copy; the input object is left untouched.
    """
    fixed = adata.copy()

    for column in ("Metadata_cmpdName", "compound_id"):
        fixed.obs[column] = fixed.obs[column].astype(str)

    # After the string cast, missing names show up as the literal text 'nan'.
    name_missing = fixed.obs["Metadata_cmpdName"].eq('nan')
    fixed.obs.loc[name_missing, "Metadata_cmpdName"] = fixed.obs.loc[name_missing, 'compound_id']

    name_part = fixed.obs["Metadata_cmpdName"].astype(str)
    conc_part = fixed.obs["Metadata_cmpdConc"].astype(str)
    fixed.obs["Metadata_cmpdNameConc"] = name_part + "_" + conc_part
    return fixed

In [None]:
# NOTE(review): `adata` is not assigned anywhere above this cell (it is first read in the
# "Filter low intensity" section), so this cell depends on existing kernel state.
adata_copy = adata.copy()
# Cells without a grit score get 0.
adata_copy.obs['grit'] = adata_copy.obs['grit'].fillna(0)
# Fill missing compound names from compound_id and add the Metadata_cmpdNameConc column.
adata_copy = fix_compound_names(adata_copy)
#adata_copy.write("/home/jovyan/share/data/analyses/benjamin/cellxgene/sc_embedding_scanpy_Beactica_deep+cell_fix.h5ad")

In [None]:
adata_grit_filter = adata_copy.copy()
# Condition 1: grit within [0.8, 4] (both ends inclusive).
condition1 = (adata_grit_filter.obs['grit'] >= 0.8) & (adata_grit_filter.obs['grit'] <= 4)
values_to_keep = ['DIMETHYL SULFOXIDE']  # replace with your actual values
# Condition 2: compounds that are always kept, independent of grit.
condition2 = adata_grit_filter.obs['Metadata_cmpdName'].isin(values_to_keep)

# A cell is kept when either condition holds.
combined_condition = condition1 | condition2

# Filter the AnnData object
filtered_adata = adata_grit_filter[combined_condition]

# fix_compound_names works on a copy of its input and returns that copy.
filtered_adata = fix_compound_names(filtered_adata)

In [None]:
# Save the grit-filtered AnnData object to disk.
filtered_adata.write("/home/jovyan/share/data/analyses/benjamin/cellxgene/sc_embedding_scanpy_Beactica_deep+cellcylce_grit_filtered.h5ad")

In [None]:
# Keep only the cells that have a (non-missing) grit value.
# NOTE(review): uses `adata`, which is not assigned above this point in the notebook.
grit_dist = adata[adata.obs["grit"].notna()]

In [None]:
# `&` binds tighter than `|`, so this selects (0.8 <= grit <= 4) OR compound is DIMETHYL SULFOXIDE.
grit_dist_filt = grit_dist[(grit_dist.obs['grit'] >= 0.8) & (grit_dist.obs['grit'] <= 4)| grit_dist.obs['Metadata_cmpdName'].isin(["DIMETHYL SULFOXIDE"])]

In [None]:
# Fill missing compound names from compound_id and add Metadata_cmpdNameConc (on a copy).
grit_dist_filt = fix_compound_names(grit_dist_filt)

## Filter low intensity

In [None]:
# Load the AnnData object that the intensity/area filter below operates on (path relative to the working directory).
adata = ad.read_h5ad("sc_embedding_scanpy_Beactica_deep+cell_edist.h5ad")

In [None]:
# Show the column names available in obs.
adata.obs.columns

In [None]:
# Keep cells with Intensity_IntegratedIntensity_illumHOECHST_nuclei > 10 and AreaShape_Area_nuclei > 50.
# .copy() turns the selection into an independent object, so the scanpy calls below can write into it.
# NOTE(review): this reuses the name `filtered_adata` from the grit-filtering section above.
filtered_adata = adata[(adata.obs['Intensity_IntegratedIntensity_illumHOECHST_nuclei'] > 10) & 
                       (adata.obs['AreaShape_Area_nuclei'] > 50)].copy()

In [None]:
# Sanity check: the smallest remaining intensity should be above the cutoff of 10 applied above.
filtered_adata.obs["Intensity_IntegratedIntensity_illumHOECHST_nuclei"].min()

In [None]:
# Embedding pipeline on the filtered cells: PCA -> neighbour graph -> PAGA -> UMAP initialised from PAGA.
# The same sequence of calls is wrapped in run_scanpy() further down in this notebook.
print("Starting scanpy!")
sc.tl.pca(filtered_adata, svd_solver='arpack')
sc.pp.neighbors(filtered_adata, n_neighbors=10, n_pcs=50)
sc.tl.paga(filtered_adata, groups="Metadata_cmpdName")
sc.pl.paga(filtered_adata, plot=False)  # remove `plot=False` if you want to see the coarse-grained graph
sc.tl.umap(filtered_adata, init_pos='paga')
# NOTE(review): nothing is written in this cell despite the message; the file is saved in a later cell.
print("Embedding complete. Saving file!")

## Scanpy plots

In [None]:
# PCA scatter coloured by compound name, followed by the variance-ratio plot of the components.
sc.pl.pca(filtered_adata, color = "Metadata_cmpdName")
sc.pl.pca_variance_ratio(filtered_adata)

In [None]:
# UMAP scatter coloured by compound name.
sc.pl.umap(filtered_adata, color = ["Metadata_cmpdName"])

In [None]:
# Save the filtered object, including the embedding computed above, to the current working directory.
filtered_adata.write_h5ad("sc_embedding_scanpy_Beactica_deep+cell_edist_segmfix3.h5ad")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.colors as mcolors
# Initial settings for the plot
from matplotlib import rcParams
# Figure size and DPI are overridden only inside this block; UMAP is coloured by the "moa" obs column.
with plt.rc_context({"figure.figsize": (8, 8), "figure.dpi": (300)}):
    sc.pl.umap(filtered_adata, color ="moa")

## Cell cycle analysis

In [None]:
import numpy as np
import seaborn as sns
def show_meta_distribution(adata, column):
    """
    Plot a 300-bin histogram of ``adata.obs[column]`` with a KDE overlay.

    Each bar is coloured with the viridis colormap according to the position
    of its bin centre, rescaled to [0, 1] over the plotted range.

    Parameters:
    - adata (AnnData): The AnnData object containing the data.
    - column (str): The name of the column in adata.obs to plot.

    Returns:
    - None. The figure is shown with ``plt.show()``.
    """
    data = adata.obs[column]

    # Generate the histogram
    counts, bin_edges = np.histogram(data, bins=300)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    # Normalize bin_centers for color mapping
    norm_bin_centers = (bin_centers - bin_centers.min()) / (bin_centers.max() - bin_centers.min())

    # Create a colormap
    cmap = plt.cm.viridis

    fig, ax = plt.subplots(figsize=(10, 6))

    # One vectorised bar call instead of one call per bin; calling the colormap
    # on the array yields one RGBA colour per bar.
    ax.bar(bin_centers, counts, width=bin_edges[1] - bin_edges[0], color=cmap(norm_bin_centers))

    # Overlay KDE (drawn on the same axes as the counts)
    sns.kdeplot(data, color='k', linewidth=2, alpha=0.7, ax=ax)

    # The ScalarMappable is not attached to any Axes, so the Axes to take
    # space from must be passed explicitly.
    fig.colorbar(plt.cm.ScalarMappable(cmap=cmap), ax=ax, label='Normalized Bin Center')
    ax.set_title(f'Histogram of {column}')
    ax.set_xlabel(column)
    ax.set_ylabel('Frequency')
    plt.show()

In [None]:
import numpy as np
import seaborn as sns
from scipy.signal import find_peaks
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import anndata as ad

def find_histogram_groups(adata, obs_column, labels=None, threshold=None, bin_width=100, bw_adjust=0.5,
                          lower_bound=-10, upper_bound=5000):
    """
    Find groups in the distribution of a feature in an AnnData object, shade the identified regions in different colors, and label them.

    Region boundaries are placed halfway between the left bin edges of
    neighbouring histogram peaks (as returned by ``scipy.signal.find_peaks``).
    Regions that share a label are merged into one (min, max) range.

    Parameters:
    - adata (AnnData): The AnnData object containing the data.
    - obs_column (str): The name of the column in adata.obs of interest.
    - labels (list of str, optional): Labels for each identified region, in
      ascending order of value. Missing labels are filled with 'Region N'.
      The passed list is not modified.
    - threshold (float, optional): A value above which all data points are considered in the same category.
      Boundaries at or above the threshold are discarded; the region above it
      takes the last label.
    - bin_width (int): The number of bins for the histogram.
    - bw_adjust (float): Bandwidth adjustment for KDE.
    - lower_bound (float): Lower end stored for the first region's range
      (default -10, the value that was previously hard-coded).
    - upper_bound (float): Upper end stored for the above-threshold region's
      range (default 5000, the value that was previously hard-coded).

    Returns:
    - dict: label -> (start, end) range. The dict is also printed.
    """
    # Extract the data from the specified obs column
    data = adata.obs[obs_column]

    # Calculate histogram data
    density, bins = np.histogram(data, bins=bin_width, density=True)
    peaks, _ = find_peaks(density)

    # Define boundaries around peaks
    boundaries = [(bins[peaks[i]] + bins[peaks[i + 1]]) / 2 for i in range(len(peaks) - 1)]

    # Incorporate threshold
    if threshold is not None:
        boundaries = [b for b in boundaries if b < threshold]

    # Plotting
    plt.figure(figsize=(10, 6))
    sns.kdeplot(data, bw_adjust=bw_adjust, fill=True)
    cmap = plt.get_cmap('tab10')  # Colormap with distinct colors

    # Ensure there are enough labels. Work on a copy so the caller's list is
    # not extended in place.
    num_regions = len(boundaries) + 1 + (1 if threshold is not None else 0)
    labels = list(labels) if labels else []
    labels += [f'Region {i + 1}' for i in range(len(labels), num_regions)]

    label_ranges = {}

    def update_label_ranges(label, new_range):
        if label in label_ranges:
            current_range = label_ranges[label]
            # Update the range with the min and max of the current and new range
            label_ranges[label] = (min(current_range[0], new_range[0]), max(current_range[1], new_range[1]))
        else:
            label_ranges[label] = new_range

    y_span = [0, density.max()]
    y_text = density.max() / 2
    # Where the last boundary-delimited region ends: at the threshold when one
    # is given, otherwise at the right edge of the histogram.
    last_region_end = threshold if threshold is not None else bins[-1]

    # Shade regions below the lowest boundary
    if boundaries:
        plt.fill_betweenx(y_span, bins[0], boundaries[0], color=cmap(0), alpha=0.3)
        plt.text(boundaries[0] / 2, y_text, labels[0], color=cmap(0), ha='center', fontweight="bold")
        update_label_ranges(labels[0], (lower_bound, boundaries[0]))

    # Shade regions and plot boundaries, and label them
    for i, boundary in enumerate(boundaries):
        plt.axvline(x=boundary, color='green', linestyle='--')  # Plot boundaries
        region_end = boundaries[i + 1] if i < len(boundaries) - 1 else last_region_end
        # Shading, text position and stored range all use the same region end,
        # so the last region no longer overlaps the above-threshold region.
        plt.fill_betweenx(y_span, boundary, region_end, color=cmap(i + 1), alpha=0.3)
        plt.text((boundary + region_end) / 2, y_text, labels[i + 1], color=cmap(i + 1), ha='center', fontweight="bold")
        update_label_ranges(labels[i + 1], (boundary, region_end))

    # Shade region above threshold if it's defined and label it
    if threshold is not None:
        plt.fill_betweenx(y_span, threshold, bins[-1], color=cmap(len(boundaries) + 1), alpha=0.3)
        plt.axvline(x=threshold, color='black', linestyle='--', linewidth=1)  # Plot threshold line
        plt.text((threshold + bins[-1]) / 2, y_text, labels[-1], color=cmap(len(boundaries) + 1), ha='center', fontweight='bold')
        update_label_ranges(labels[-1], (threshold, upper_bound))

    plt.title(f'Histogram with Identified Groups of {obs_column}')
    plt.xlabel(obs_column)
    plt.ylabel('Density')
    plt.show()
    print(label_ranges)
    return label_ranges

In [None]:
# Derive the intensity ranges from the DIMETHYL SULFOXIDE cells only; regions that share a label
# ("G2/M" appears twice) are merged into one range by find_histogram_groups.
cell_cylce_ranges = find_histogram_groups(adata = adata_copy[adata_copy.obs["Metadata_cmpdName"] == "DIMETHYL SULFOXIDE"], obs_column ="Intensity_IntegratedIntensity_illumHOECHST_nuclei", labels =  ["Sub G1", "G1", "S", "G2/M", "G2/M", "Undefined"], threshold = 1300, bin_width=100)

In [None]:
def plot_histogram_with_labels(adata, obs_column, label_ranges, bin_width=100, bw_adjust=0.5,
                               save_path="dmso_cell_cyles_BEACTICA.png"):
    """
    Plot a histogram of a feature in an AnnData object with regions shaded and labeled based on provided label ranges.

    Parameters:
    - adata (AnnData): The AnnData object containing the data.
    - obs_column (str): The name of the column in adata.obs of interest.
    - label_ranges (dict): Dictionary with labels as keys and value ranges as values.
    - bin_width (int): The number of bins for the histogram.
    - bw_adjust (float): Bandwidth adjustment for KDE.
    - save_path (str or None): Where to save the figure. Defaults to the file
      name that was previously hard-coded; pass None to skip saving.

    Returns:
    - None. The figure is saved (unless save_path is None) and shown.
    """
    # Extract the data from the specified obs column
    data = adata.obs[obs_column]

    # The histogram is only used for the height of the shaded regions.
    density, _ = np.histogram(data, bins=bin_width, density=True)
    region_height = density.max()

    # Plotting
    plt.figure(figsize=(10, 6), dpi = 300)
    sns.kdeplot(data, bw_adjust=bw_adjust, fill=True)
    cmap = plt.get_cmap('tab10')  # Colormap with distinct colors

    # Iterate through the label_ranges and shade & label each region
    for i, (label, (start, end)) in enumerate(label_ranges.items()):
        plt.fill_betweenx([0, region_height], start, end, color=cmap(i), alpha=0.3)
        plt.text((start + end) / 2, region_height * 0.97, label, color=cmap(i), ha='center', fontweight='bold')

    plt.title(f'Histogram with Identified Groups of {obs_column}')
    plt.xlabel(obs_column)
    plt.ylabel('Density')
    if save_path is not None:
        plt.savefig(save_path, dpi = 300)
    plt.show()

In [None]:
# Re-plot the DIMETHYL SULFOXIDE intensity distribution with the merged label ranges computed above.
plot_histogram_with_labels(adata_copy[adata_copy.obs["Metadata_cmpdName"] == "DIMETHYL SULFOXIDE"], "Intensity_IntegratedIntensity_illumHOECHST_nuclei", cell_cylce_ranges)

In [None]:
import pandas as pd
def assing_cell_cycle_stage(adata, obs_column, label_ranges):
    """
    Write a 'cell_cycle' label into adata.obs according to value ranges.

    The column is created (filled with pd.NA) when it does not exist yet.
    Each range is inclusive at both ends; ranges are applied in dictionary
    order, so a value matching several ranges keeps the last matching label.
    Values matching no range keep whatever the column already held.

    Parameters:
    - adata (AnnData): The AnnData object containing the data; its obs is modified in place.
    - obs_column (str): The name of the column in adata.obs to use for labeling.
    - label_ranges (dict): Dictionary with labels as keys and (start, end) ranges as values.

    Returns:
    - None.
    """
    if 'cell_cycle' not in adata.obs.columns:
        adata.obs['cell_cycle'] = pd.NA

    for stage, bounds in label_ranges.items():
        lower, upper = bounds
        # Series.between is inclusive on both sides by default.
        within_range = adata.obs[obs_column].between(lower, upper)
        adata.obs.loc[within_range, 'cell_cycle'] = stage


In [None]:
# Label every cell in adata_copy by its Hoechst integrated intensity (adds/updates obs['cell_cycle'] in place).
assing_cell_cycle_stage(adata_copy,"Intensity_IntegratedIntensity_illumHOECHST_nuclei", cell_cylce_ranges)

In [None]:
# Save the object including the new 'cell_cycle' column.
adata_copy.write("/home/jovyan/share/data/analyses/benjamin/cellxgene/sc_embedding_scanpy_Beactica_deep+cell_cellcycle.h5ad")

In [None]:
# NOTE(review): this reads "..._cellcycle2.h5ad", which is not the file written in the previous cell
# ("..._cellcycle.h5ad") — confirm which file is intended. This also rebinds `adata`.
adata = ad.read_h5ad("/home/jovyan/share/data/analyses/benjamin/cellxgene/sc_embedding_scanpy_Beactica_deep+cell_cellcycle2.h5ad")

In [None]:
def kde_plot_by_category(adata, compound, value_column, category_column):
    """
    Draw one KDE curve per category for the cells of a single compound.

    Parameters:
    - adata (AnnData): The AnnData object containing the data.
    - compound (str): Value of obs["Metadata_cmpdName"] selecting the cells to plot.
    - value_column (str): The name of the column in adata.obs whose values are to be plotted.
    - category_column (str): The name of the column in adata.obs that defines the curves.

    Raises:
    - ValueError: if value_column or category_column is not a column of adata.obs.
    """
    available = adata.obs.columns
    if not (value_column in available and category_column in available):
        raise ValueError(f"Columns {value_column} and/or {category_column} not found in adata.obs")

    # Restrict to the requested compound; only obs is needed for plotting.
    compound_obs = adata[adata.obs["Metadata_cmpdName"] == compound].obs

    plt.figure(figsize=(15, 10), dpi = 300)

    # One curve per distinct category value, in order of first appearance.
    for level in compound_obs[category_column].unique():
        level_values = compound_obs.loc[compound_obs[category_column] == level, value_column]
        sns.kdeplot(level_values, label=level, fill=False, bw_adjust=0.5, alpha = 1)

    plt.title(f'KDE of {value_column} colored by {category_column}')
    plt.xlabel(value_column)
    plt.ylabel('Density')
    plt.legend(title=category_column)
    plt.show()

In [None]:
# KDE of Hoechst integrated intensity for CISPLATIN cells, one curve per concentration value.
kde_plot_by_category(adata, "CISPLATIN", "Intensity_IntegratedIntensity_illumHOECHST_nuclei", "Metadata_cmpdConc")

In [None]:
def violin_plot_by_category(adata, compound, value_column, category_column, scale):
    """
    Save a horizontal violin plot of one obs column for a compound, with the
    "DIMETHYL SULFOXIDE" cells added as an extra "DMSO" category.

    Parameters:
    - adata (AnnData): The AnnData object containing the data.
    - compound (str): The compound name to filter the data for plotting.
    - value_column (str): The name of the column in adata.obs whose values are to be plotted.
    - category_column (str): The name of the column in adata.obs to use for splitting the violin plots.
    - scale (str): Forwarded unchanged to seaborn's ``violinplot(scale=...)``.

    Raises:
    - ValueError: if value_column or category_column is not a column of adata.obs.

    The figure is written to a fixed directory as
    ``dna_intensity_kde_<compound>_dmso.png``; it is not shown explicitly.
    """
    absent = [name for name in (value_column, category_column) if name not in adata.obs.columns]
    if absent:
        raise ValueError(f"Columns {value_column} and/or {category_column} not found in adata.obs")

    compound_names = adata.obs["Metadata_cmpdName"]

    # DMSO rows get the literal category "DMSO" in place of their concentration.
    dmso_rows = adata.obs[compound_names == "DIMETHYL SULFOXIDE"].copy()
    dmso_rows[category_column] = "DMSO"
    compound_rows = adata.obs[compound_names == compound].copy()

    # DMSO first, then the compound's own rows.
    plot_frame = pd.concat([dmso_rows, compound_rows], axis=0)

    # Fixed order of the categories along the y-axis.
    y_order = ["DMSO", 0.1, 0.3, 1, 3, 5, 10]

    plt.figure(figsize=(15, 10), dpi=300)
    ax = sns.violinplot(
        x=value_column,
        y=category_column,
        orient="h",
        scale=scale,
        data=plot_frame,
        cut=0,
        inner='box',
        palette='viridis',
        split=True,
        order=y_order,
    )

    bold_14 = dict(fontsize=14, fontweight='bold')
    ax.set_title(f'Distribution of DNA integrated intensity for {compound} with DMSO', fontsize=20, fontweight='bold')
    ax.set_xlabel(value_column, **bold_14)
    ax.set_ylabel(category_column, **bold_14)
    plt.tight_layout()  # keep labels and title inside the figure
    plt.savefig(f"/home/jovyan/share/data/analyses/benjamin/cellxgene/Beactica/deepprofiler/figures/cell_cycle/dna_intensity_kde_{compound}_dmso.png", dpi = 300)
    #plt.show()

In [None]:
compounds = ["CLADRIBINE"]
# Other compounds this loop has been run with:
#["ETOPOSIDE", "CISPLATIN", "FK 866", "NIGERICIN", 'HS-173', "BEA-0005443-AQ-003", "DACTINOMYCIN (ACTINOMYCIN D)"]
for c in compounds:
    # These two compounds are drawn with width-scaled violins; all others with area scaling.
    violin_scale = "width" if c in ["NIGERICIN", "DACTINOMYCIN (ACTINOMYCIN D)"] else "area"
    violin_plot_by_category(adata, c, "Intensity_IntegratedIntensity_illumHOECHST_nuclei", "Metadata_cmpdConc", scale = violin_scale)

In [None]:
# Build a fresh AnnData holding only X and obs of the DIMETHYL SULFOXIDE cells.
dmso_cells = adata[adata.obs["Metadata_cmpdName"] == "DIMETHYL SULFOXIDE"]
dmso_anndata = ad.AnnData(X=dmso_cells.X, obs=dmso_cells.obs)

In [None]:
def run_scanpy(adata, compound_col):
    """Run the scanpy embedding pipeline on ``adata`` and return it.

    Steps, in order: PCA (arpack solver), neighbour graph (10 neighbours,
    50 PCs), PAGA grouped by ``compound_col``, PAGA layout without drawing,
    and UMAP initialised from the PAGA positions.

    Parameters:
    - adata (AnnData): Object to embed; the scanpy calls store their results on it.
    - compound_col (str): obs column passed to PAGA as ``groups``.

    Returns:
    - The same ``adata`` object that was passed in.
    """
    sc.tl.pca(adata, svd_solver='arpack')
    sc.pp.neighbors(adata, n_neighbors=10, n_pcs=50)
    sc.tl.paga(adata, groups=compound_col)
    # Called with plot=False: computes the layout needed by init_pos='paga' without drawing.
    sc.pl.paga(adata, plot=False)  # remove `plot=False` if you want to see the coarse-grained graph
    sc.tl.umap(adata, init_pos='paga')
    return adata

In [None]:
# Compute PCA / neighbours / PAGA / UMAP on the DMSO-only object.
dmso_anndata = run_scanpy(dmso_anndata, "Metadata_cmpdName")

## Load in cellxgene embeddings

In [None]:
# Directory from which the .h5ad embedding files are loaded below.
CXG_DIR = "/home/jovyan/share/data/analyses/benjamin/cellxgene/Beactica/deepprofiler/embeddings"

In [None]:
def fix_keys(adata):
    """Expose existing obsm entries under the standard names 'X_pca' and 'X_umap'.

    - 'X_pca' is set from the first obsm key containing 'pca'.
    - 'X_umap' is set from the first obsm key containing 'dmso', or, if there
      is none, from the first key containing 'emb'.

    The source entries are kept; nothing is deleted. When no matching key is
    found, the corresponding entry is left untouched.

    Parameters:
    - adata: object with a mapping-like ``.obsm``; modified in place.

    Returns:
    - The same ``adata`` object.
    """
    def first_key_containing(fragment):
        # First obsm key (in iteration order) that contains the fragment, else None.
        return next((name for name in adata.obsm.keys() if fragment in name), None)

    # Resolve both sources before writing anything into obsm.
    pca_source = first_key_containing('pca')
    umap_source = first_key_containing('dmso')
    if umap_source is None:
        umap_source = first_key_containing('emb')

    for target, source in (('X_pca', pca_source), ('X_umap', umap_source)):
        if source:
            adata.obsm[target] = adata.obsm[source]

    return adata

In [None]:
import tqdm
import os
import anndata as ad
# Embeddings keyed by reference-compound tag. A tag matches every file whose name contains it;
# when several files match, the one listed last is the one that stays in the dict.
emb_dict = {}
ref_comp = ["fk866", "hs-173", "nigericin", "aq-003", "actinomycin-d", "dmso_only"]
h5ad_files = [file for file in os.listdir(CXG_DIR) if file.endswith(".h5ad")]

for comp in tqdm.tqdm(ref_comp):
    # h5ad_files already contains only names ending in ".h5ad".
    for filename in (name for name in h5ad_files if comp in name):
        file_path = os.path.join(CXG_DIR, filename)
        # Load the file and expose its embeddings as X_pca / X_umap.
        emb_dict[comp] = fix_keys(ad.read_h5ad(file_path))
        print(f"Loaded {filename}")

In [None]:
# Load the embedding stored in "umap.h5ad" and register it under the key "all".
temp = ad.read_h5ad(os.path.join(CXG_DIR, "umap.h5ad"))
temp_fix = fix_keys(temp)
emb_dict["all"] = temp_fix

In [None]:
emb_dict["all"].obs["Metadata_cmpdName"] = emb_dict["all"].obs["Metadata_cmpdName"].astype(str)
emb_dict["all"].obs["compound_id"] = emb_dict["all"].obs["compound_id"].astype(str)
# After the string cast, missing compound names appear as the text 'nan'.
nan_string_indices = emb_dict["all"].obs["Metadata_cmpdName"]== 'nan'
# For these rows, take the name from 'compound_id' instead.
emb_dict["all"].obs.loc[nan_string_indices, "Metadata_cmpdName"] = emb_dict["all"].obs.loc[nan_string_indices, 'compound_id']

# Build "<compound>_<cluster>" for clustered cells and keep the bare compound name for
# 'unassigned' ones. This is a vectorised replacement for a row-wise DataFrame.apply(axis=1),
# which calls a Python function once per cell.
cmpd_names = emb_dict["all"].obs["Metadata_cmpdName"]
leiden_labels = emb_dict["all"].obs['leiden_v4_r0.4'].astype(str)
emb_dict["all"].obs['subpopulations'] = (cmpd_names + "_" + leiden_labels).where(
    leiden_labels != 'unassigned', cmpd_names
)


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.colors as mcolors
# Initial settings for the plot
from matplotlib import rcParams
# obs column used for colouring, per embedding key; any key not listed here uses "Metadata_cmpdNameConc".
umap_colour_by_key = {"all": "subpopulations", "dmso_only": "leiden_v4_r0.4"}

with plt.rc_context({"figure.figsize": (8, 8), "figure.dpi": (300)}):
    for key, item in emb_dict.items():
        print(key)
        sc.pl.umap(item, color = umap_colour_by_key.get(key, "Metadata_cmpdNameConc"))
        

In [None]:
import scanpy as sc
import matplotlib.pyplot as plt
import numpy as np

def plot_umap_grid_colored(anndata_dict, color_by, n_cols=3):
    """
    Create a grid of UMAP plots from a dictionary of AnnData objects, colored by a specified column.

    The entry with key 'all' is skipped. Every panel uses the same
    category -> colour mapping, and one shared legend covering all categories
    is placed below the grid.

    Parameters:
    anndata_dict (dict): A dictionary of AnnData objects.
    color_by (str): Column name to color by.
    n_cols (int): Number of columns in the grid.

    Raises:
    ValueError: if there is nothing to plot once the 'all' entry is removed.
    """
    # Local import so the function does not depend on another cell having imported Line2D.
    from matplotlib.lines import Line2D

    anndata_dict = {k: v for k, v in anndata_dict.items() if k != 'all'}
    if not anndata_dict:
        raise ValueError("anndata_dict contains no embeddings to plot (the 'all' entry is skipped)")

    # Determine all unique categories across all AnnData objects
    all_categories = set()
    for adata in anndata_dict.values():
        all_categories.update(adata.obs[color_by].astype(str))

    # Sort categories for consistent ordering and create color palette
    sorted_categories = sorted(all_categories)
    color_palette = sc.pl.palettes.default_20  # colours repeat when there are more than len(color_palette) categories
    color_map = {cat: color_palette[i % len(color_palette)] for i, cat in enumerate(sorted_categories)}
    print(color_map)

    # Set up the figure for subplots. squeeze=False guarantees a 2-D array of
    # Axes, so flatten() also works for a 1x1 grid.
    n_rows = int(np.ceil(len(anndata_dict) / n_cols))
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
    axs = axs.flatten()

    # Plot UMAP for each AnnData object
    for ax, (key, adata) in zip(axs, anndata_dict.items()):
        sc.pl.umap(adata, color=color_by, ax=ax, show=False,
                   title=key, frameon=False,
                   palette=color_map,
                   legend_loc="none")  # Apply the consistent color map
        # Remove axis titles (optional, for cleaner look)
        ax.set_xlabel('')
        ax.set_ylabel('')

    # Hide any extra axes
    for unused_ax in axs[len(anndata_dict):]:
        unused_ax.axis('off')

    # Create an overall title
    fig.suptitle('UMAP Grid', fontsize=16)

    # Build the shared legend from the colour map itself, so it lists every
    # category of every panel rather than only what the last panel exposes.
    legend_handles = [
        Line2D([0], [0], marker='o', linestyle='None', color=color_map[cat], label=cat, markersize=6)
        for cat in sorted_categories
    ]
    fig.legend(handles=legend_handles, loc='lower center', ncol=3, bbox_to_anchor=(0.5, 0.01))

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

# Example usage:
# The function passes its own figsize to plt.subplots, so only the DPI from this context affects the grid.
with plt.rc_context({"figure.figsize": (15, 15), "figure.dpi": (300)}):
    plot_umap_grid_colored(emb_dict, "Metadata_cmpdNameConc")


In [None]:
import scanpy as sc
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

def plot_single_umap_colored(adata, color_by):
    """Plot one UMAP coloured by ``color_by`` with a legend that shows per-category cell counts.

    Parameters:
    - adata (AnnData): Object with a UMAP embedding that ``sc.pl.umap`` can draw.
    - color_by (str): obs column used for colouring and for the legend entries.

    Colours are assigned to the sorted category values from scanpy's
    ``default_102`` palette (wrapping around if there are more categories
    than colours). Legend entries follow the order of ``value_counts()``.
    """
    counts_per_category = adata.obs[color_by].value_counts()

    # Deterministic category -> colour assignment based on sorted category values.
    palette = sc.pl.palettes.default_102
    ordered_categories = sorted(adata.obs[color_by].unique())
    colour_of = {}
    for position, category in enumerate(ordered_categories):
        colour_of[category] = palette[position % len(palette)]

    fig, ax = plt.subplots(figsize=(8, 6))

    # scanpy's own legend is switched off; a custom one is added below.
    sc.pl.umap(
        adata,
        color=color_by,
        ax=ax,
        show=False,
        title=f'UMAP colored by {color_by}',
        frameon=False,
        legend_loc='none',
        palette=colour_of,
        s = 2,
    )

    # One marker-only legend entry per category, annotated with its cell count.
    legend_entries = []
    for category in counts_per_category.index:
        legend_entries.append(
            Line2D([0], [0], marker='o', color='w',
                   label=f"{category} (n={counts_per_category[category]})",
                   markerfacecolor=colour_of[category], markersize=10)
        )

    # Legend sits to the right of the axes; tight_layout leaves room for it.
    ax.legend(handles=legend_entries, title=color_by, loc='center left',
              bbox_to_anchor=(1, 0.5), ncol=1, fontsize='x-small')

    plt.tight_layout(rect=[0, 0, 0.85, 1])
    plt.show()

# DMSO-only embedding coloured by its "leiden_v4_r0.4" column; the function sets its own figsize,
# so only the DPI from this context applies.
with plt.rc_context({"figure.figsize": (15, 15), "figure.dpi": (300)}):
    plot_single_umap_colored(emb_dict["dmso_only"], "leiden_v4_r0.4")

## Custom plots

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import datetime

def make_jointplot_anndata(adata, dmso_name, colouring, cmpd, save_path=None):
    """Draw a UMAP joint plot: scatter in the centre, per-treatment KDEs on the margins.

    Parameters:
    - adata (AnnData): Object with a two-column ``obsm['X_umap']`` and the obs
      columns ``Metadata_cmpdName`` and ``Metadata_cmpdConc``.
    - dmso_name (str): Value of the colouring column that marks the control.
      It is drawn first, in light grey, with smaller markers.
    - colouring (str): obs column to colour by. ``'Metadata_cmpdNameConc2'``
      ("<name>_<conc>") is built inside this function and may be used here.
    - cmpd (str): Title of the scatter panel.
    - save_path (str, optional): If given, the figure is saved as PNG. ".png"
      is appended only when the path does not already end with it.

    Side effects:
    - Sets ``plt.rcParams['figure.dpi']`` to 300 (this persists after the call).
    """
    # Local import so the function does not depend on another cell having imported Line2D.
    from matplotlib.lines import Line2D

    # Combine UMAP coordinates with the per-cell metadata, row by row.
    umap_data = pd.DataFrame(adata.obsm['X_umap'], columns=['UMAP1', 'UMAP2'])
    embedding = pd.concat([umap_data, pd.DataFrame(adata.obs).reset_index()], axis=1)
    embedding["Metadata_cmpdNameConc2"] = embedding["Metadata_cmpdName"].astype(str) + "_" + embedding["Metadata_cmpdConc"].astype(str)

    # Cast to str BEFORE collecting the unique values, so the colour-map keys
    # match the values that are looked up later.
    embedding[colouring] = embedding[colouring].astype(str)
    unique_treatments = list(embedding[colouring].unique())

    # One Set2 colour per treatment; the control is always light grey.
    palette = sns.color_palette("Set2", len(unique_treatments))
    color_map = dict(zip(unique_treatments, palette))
    if dmso_name in color_map:
        color_map[dmso_name] = 'lightgrey'

    # Increase the DPI for displaying
    plt.rcParams['figure.dpi'] = 300

    # Create the base joint plot
    g = sns.JointGrid(x='UMAP1', y='UMAP2', data=embedding, height=10)

    # Draw the control first so that the treatments are plotted on top of it.
    # Nothing to reorder when the control is not present in the data.
    if dmso_name in unique_treatments:
        unique_treatments.remove(dmso_name)
        unique_treatments.insert(0, dmso_name)

    # Plot KDE plots for each category
    for treatment in unique_treatments:
        subset = embedding[embedding[colouring] == treatment]
        sns.kdeplot(x=subset["UMAP1"], ax=g.ax_marg_x, fill=True, color=color_map[treatment], legend=False)
        sns.kdeplot(y=subset["UMAP2"], ax=g.ax_marg_y, fill=True, color=color_map[treatment], legend=False)

    # Plot the scatter plots
    for treatment in unique_treatments:
        subset = embedding[embedding[colouring] == treatment]
        is_control = treatment == dmso_name
        alpha_val = 0.3 if is_control and colouring == 'Metadata_cmpdName' else 0.8
        # Each subset holds a single treatment, so one colour and one size suffice.
        g.ax_joint.scatter(subset["UMAP1"], subset["UMAP2"],
                           color=color_map[treatment], s=3 if is_control else 10,
                           label=treatment, alpha=alpha_val, edgecolor='white', linewidth=0.5)

    g.ax_joint.set_title(cmpd)

    # Custom legend with opaque markers, independent of the scatter alpha.
    legend_elements = [Line2D([0], [0], marker='o', linestyle= "None", color=color_map[treatment], label=treatment, markersize=5, markerfacecolor=color_map[treatment], alpha=1) for treatment in unique_treatments]
    legend = g.ax_joint.legend(handles=legend_elements, fontsize=10, title=colouring)
    legend.get_frame().set_facecolor('white')

    if save_path is not None:
        # Avoid producing "name.png.png" when the caller already supplies the extension.
        out_path = save_path if str(save_path).lower().endswith(".png") else f"{save_path}.png"
        g.savefig(out_path, dpi=300)

    plt.show()

# Usage

In [None]:
# Joint plot of the intensity/area-filtered cells, coloured by compound name.
# The filtered object is named `filtered_adata` in this notebook; `adata_filtered` is never defined.
make_jointplot_anndata(filtered_adata, 'DIMETHYL SULFOXIDE', 'Metadata_cmpdName', "")

In [None]:
# For every per-compound embedding (the combined "all" entry is skipped): clean the compound
# names in place, then save a joint plot coloured by compound + concentration.
for key, value in emb_dict.items():
    if key == "all":
        continue
    for text_column in ("Metadata_cmpdName", "compound_id"):
        value.obs[text_column] = value.obs[text_column].astype(str)
    # After the string cast, missing names appear as 'nan'; take compound_id for those rows.
    nan_string_indices = value.obs["Metadata_cmpdName"].eq('nan')
    value.obs.loc[nan_string_indices, "Metadata_cmpdName"] = value.obs.loc[nan_string_indices, 'compound_id']
    make_jointplot_anndata(value, 'DIMETHYL SULFOXIDE_0.1', 'Metadata_cmpdNameConc2', key, save_path= f"/home/jovyan/share/data/analyses/benjamin/cellxgene/deepprofiler/BEACTICA/figures/sc_dmso_{key}_umap.png")

In [None]:
from scipy import stats
import numpy as np
def make_jointplot_seaborn_density_anndata(adata, colouring, dmso_name, cmpd, save_path = None, overlay=False, overlay_df=None):
    """
    Draw a UMAP joint plot in which treated cells are coloured by their local point density.

    Parameters:
    adata (AnnData): Object with a two-column 'X_umap' entry in .obsm and the columns
        'Metadata_cmpdName' and 'Metadata_cmpdConc' in .obs.
    colouring (str): Column of the embedding table that splits the cells into treatments.
    dmso_name (str): Label of the control treatment. It is drawn in light grey, unless it is the
        only treatment present, in which case it is density-coloured as well.
    cmpd (str): Title of the joint axes.
    save_path (str, optional): If given, the figure is written to this path at 300 dpi.
    overlay (bool): If True and overlay_df is given, its points are drawn on top.
    overlay_df (DataFrame, optional): Needs 'UMAP1', 'UMAP2' and the colouring column.
        It is copied before the helper columns are added, so the caller's frame stays untouched.

    Returns:
    None: Displays the plot.
    """
    # One row per cell: UMAP coordinates next to the observation metadata (both positionally indexed).
    umap_data = pd.DataFrame(adata.obsm['X_umap'], columns=['UMAP1', 'UMAP2'])
    obs_data = pd.DataFrame(adata.obs).reset_index()
    embedding = pd.concat([umap_data, obs_data], axis=1)
    embedding["Metadata_cmpdNameConc2"] = embedding["Metadata_cmpdName"].astype(str) + "_" + embedding["Metadata_cmpdConc"].astype(str)

    def get_color(val):
        # Flat colour, used for overlay points and for groups too small for a density estimate.
        return "lightgrey" if dmso_name in val else "#e96565"

    def get_size(val):
        return 10 if val != dmso_name else 3

    embedding['color'] = embedding[colouring].apply(get_color)
    embedding['size'] = embedding[colouring].apply(get_size)

    all_treatments = list(embedding[colouring].unique())
    other_treatments = [t for t in all_treatments if t != dmso_name]
    # Control first so it ends up underneath the treatments. It is only listed when it is
    # actually present, which avoids an empty "<control> - 0 cells" legend entry.
    sorted_treatments = ([dmso_name] if dmso_name in all_treatments else []) + other_treatments

    g = sns.JointGrid(x='UMAP1', y='UMAP2', data=embedding, height=10)

    cmap = plt.cm.jet
    norm = plt.Normalize(vmin=0, vmax=1)
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])

    is_dmso_only = len(all_treatments) == 1 and all_treatments[0] == dmso_name

    for treatment in sorted_treatments:
        subset = embedding[embedding[colouring] == treatment]
        label = f"{treatment} - {len(subset)} cells"

        if is_dmso_only or treatment != dmso_name:
            if len(subset) < 3:
                # A 2-D Gaussian KDE needs more points than dimensions; draw tiny groups flat.
                g.ax_joint.scatter(subset["UMAP1"], subset["UMAP2"], c=subset['color'], s=subset['size'], label=label, alpha=0.7, edgecolor='white', linewidth=0.5)
                continue

            # Density of every point under a KDE fitted to its own treatment group.
            values = np.vstack([subset["UMAP1"], subset["UMAP2"]])
            kernel = stats.gaussian_kde(values)(values)
            # NOTE(review): the raw density values are fed to the colormap here without rescaling
            # to [0, 1]; confirm that this is the intended colour for the marginal curves.
            marginal_color = cmap(kernel).mean(axis=0)

            sns.kdeplot(x=subset["UMAP1"], ax=g.ax_marg_x, fill=True, color=marginal_color, legend=False)
            sns.kdeplot(y=subset["UMAP2"], ax=g.ax_marg_y, fill=True, color=marginal_color, legend=False)
            g.ax_joint.scatter(subset["UMAP1"], subset["UMAP2"], c=kernel, s=subset['size'], cmap=cmap, label=label, alpha=0.7, edgecolor='white', linewidth=0.5)
        else:
            # Control treatment next to other treatments: plain light grey background cloud.
            sns.kdeplot(x=subset["UMAP1"], ax=g.ax_marg_x, fill=True, color='lightgrey', legend=False)
            sns.kdeplot(y=subset["UMAP2"], ax=g.ax_marg_y, fill=True, color='lightgrey', legend=False)
            g.ax_joint.scatter(subset["UMAP1"], subset["UMAP2"], c='lightgrey', s=subset['size'], label=label, alpha=0.7, edgecolor='white', linewidth=0.5)

    # Overlay additional points if the option is active
    if overlay and overlay_df is not None:
        overlay_points = overlay_df.copy()
        overlay_points['color'] = overlay_points[colouring].apply(get_color)
        overlay_points['size'] = overlay_points[colouring].apply(lambda val: get_size(val) * 2)

        # The control label is always looked up in the overlay, even if absent from the embedding.
        for treatment in [dmso_name] + other_treatments:
            subset = overlay_points[overlay_points[colouring] == treatment]
            g.ax_joint.scatter(subset["UMAP1"], subset["UMAP2"], c=subset['color'], s=subset['size'], alpha=0.9, edgecolor='grey', linewidth=0.5)

    fig = g.fig  # Figure that owns the JointGrid axes
    cbar_ax = fig.add_axes([0.93, 0.1, 0.02, 0.7])  # Dedicated axes for the colorbar

    # The colorbar belongs to the figure, not to the joint axes
    cbar = fig.colorbar(sm, cax=cbar_ax)
    cbar.set_label('Relative density', rotation=270, labelpad=15)

    # Make room for the colorbar on the right
    fig.subplots_adjust(right=0.9)

    g.ax_joint.set_title(cmpd)
    g.ax_joint.legend()
    if save_path is not None:
        # Save through the figure handle so that exactly this figure is written.
        fig.savefig(save_path, dpi = 300)
    plt.show()

In [None]:
# Density-coloured UMAP of the "dmso_only" embedding, split by compound name with
# 'DIMETHYL SULFOXIDE' as the control label; the figure is also written to save_path.
make_jointplot_seaborn_density_anndata(emb_dict["dmso_only"], "Metadata_cmpdName", "DIMETHYL SULFOXIDE", "", save_path= "/home/jovyan/share/data/analyses/benjamin/cellxgene/deepprofiler/BEACTICA/figures/dmso_density_umap.png")

In [None]:
# Load the grit-filtered embedding. read_h5ad is used instead of the deprecated ad.read alias,
# matching the loader call at the top of the notebook.
adata = ad.read_h5ad("/home/jovyan/share/data/analyses/benjamin/cellxgene/sc_embedding_scanpy_Beactica_deep+cellcylce_grit_filtered.h5ad")

## Feature distributions

In [None]:
def show_summary_stats(df):
    """
    Plot per-feature summary statistics: lines for mean and 5th/95th percentile, dots for max and min.

    Parameters:
    df (DataFrame): One column per feature and rows labelled 'mean', '5%', '95%', 'max' and 'min'.

    Returns:
    None: Displays the figure.
    """
    feature_names = df.columns

    # (row label in df, legend label, colour)
    line_specs = (
        ('mean', 'Mean', 'blue'),
        ('5%', '5th Percentile', 'green'),
        ('95%', '95th Percentile', 'red'),
    )
    dot_specs = (
        ('max', 'Max', 'black'),
        ('min', 'Min', 'grey'),
    )

    plt.figure(figsize=(12,6))

    # Central tendency and percentile band as lines
    for row_label, legend_label, colour in line_specs:
        plt.plot(feature_names, df.loc[row_label], label=legend_label, color=colour)

    # Extremes as small dots (s is the marker size)
    for row_label, legend_label, colour in dot_specs:
        plt.scatter(feature_names, df.loc[row_label], color=colour, label=legend_label, s=5)

    plt.xlabel('Features')
    plt.ylabel('Values')
    plt.title('Feature distributions')
    # Hide the x tick labels; there are too many feature names to print
    plt.xticks([])

    plt.legend()
    plt.tight_layout()
    plt.show()

In [None]:
# NOTE(review): summary_features is not defined in the visible cells. show_summary_stats reads its
# 'mean', '5%', '95%', 'max' and 'min' rows, so it presumably comes from a
# describe(percentiles=[0.05, 0.95]) call on the feature table - confirm.
show_summary_stats(summary_features)

In [None]:
def plot_grouped_feature_statistics(df, group_column, feature_columns):
    """
    Plot the mean and the 5th / 95th percentile of every feature, one subplot per group.

    Parameters:
    df (DataFrame): The original pandas DataFrame with data.
    group_column (str): The name of the column to group by.
    feature_columns (list): List of columns to calculate statistics on.

    Returns:
    None: Displays the figure. Nothing is drawn when df contains no groups.
    """
    # Materialise the groups once so the subplot count always matches what is iterated below.
    groups = list(df.groupby(group_column))
    n_groups = len(groups)
    if n_groups == 0:
        return

    # Determine the number of subplots needed
    n_cols = 1  # You can adjust the number of columns per row
    n_rows = int(np.ceil(n_groups / n_cols))

    # Create a figure with subplots
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15 * n_cols, 10 * n_rows), squeeze=False)
    axes = axes.flatten()  # Flatten to 1D array for easy iteration

    for i, (group_name, group_data) in enumerate(groups):
        features = group_data[feature_columns]
        mean = features.mean()
        percentile_5 = features.quantile(0.05)
        percentile_95 = features.quantile(0.95)

        # Plotting on the ith subplot
        ax = axes[i]
        ax.plot(feature_columns, mean, label='Mean', color='blue')
        ax.plot(feature_columns, percentile_5, label='5th Percentile', color='green')
        ax.plot(feature_columns, percentile_95, label='95th Percentile', color='red')

        ax.set_title(f'Group: {group_name}')
        ax.set_xticks([])  # Remove x-axis labels

        if i == 0:  # One legend is enough; every subplot uses the same colours
            ax.legend()

    # Hide any unused subplots (only possible when n_cols > 1)
    for unused_ax in axes[n_groups:]:
        unused_ax.axis('off')

    plt.tight_layout()
    plt.show()

# One subplot per compound name, computed over the "Feature" columns collected in features_fixed.
plot_grouped_feature_statistics(grit_filter_df_sampled_pd, "Metadata_cmpdName", features_fixed)

## Heatmap analysis

In [None]:
# Reload both embeddings for the heatmaps. `adata` is rebound here to the cellcycle2 file, and
# `adata_filtered` holds the grit-filtered file.
# read_h5ad is used instead of the deprecated ad.read alias.
adata = ad.read_h5ad("/home/jovyan/share/data/analyses/benjamin/cellxgene/sc_embedding_scanpy_Beactica_deep+cell_cellcycle2.h5ad")
adata_filtered = ad.read_h5ad("/home/jovyan/share/data/analyses/benjamin/cellxgene/sc_embedding_scanpy_Beactica_deep+cellcylce_grit_filtered.h5ad")

In [None]:
# Observation labels ordered by compound name, so that cells of one compound become contiguous

sorted_indices = adata.obs["Metadata_cmpdName"].sort_values().index

# Index the object by those labels to reorder it as a whole; this rebinds `adata`, so later
# cells see the sorted object
adata = adata[sorted_indices]

In [None]:
import anndata
import numpy as np
import scanpy as sc
import scipy
import seaborn as sns
import matplotlib.colors as mcolors
# Inital setting for plot
import matplotlib.pyplot as plt
from matplotlib import rcParams


def plot_clipped_heatmap(adata, max_val=10, min_val=-10, genes=None, groupby=None):
    """
    Plot a scanpy heatmap of an AnnData object after clipping its values to [min_val, max_val].

    Parameters:
    adata (AnnData): The original AnnData object. It is copied, so the caller's data is not changed.
    max_val (float): Maximum value to clip data to.
    min_val (float): Minimum value to clip data to.
    genes (list, optional): Variable names to plot; they should match adata.var_names.
        All variables are plotted when this is None.
    groupby (str): Name of the observation annotation to group by (usually categorical).

    Returns:
    None: Displays a heatmap.

    Raises:
    TypeError: If adata.X is neither a numpy array nor a CSR/CSC sparse matrix.
    """
    # Work on a copy so the clipping never touches the caller's object
    adata_copy = adata.copy()

    if isinstance(adata_copy.X, np.ndarray):
        adata_copy.X = np.clip(adata_copy.X, a_min=min_val, a_max=max_val)
    elif isinstance(adata_copy.X, (scipy.sparse.csr_matrix, scipy.sparse.csc_matrix)):
        # Only the stored entries are clipped; implicit zeros stay zero.
        adata_copy.X.data = np.clip(adata_copy.X.data, a_min=min_val, a_max=max_val)
    else:
        raise TypeError("adata.X must be a numpy array or a scipy sparse matrix.")

    var_names = genes if genes is not None else list(adata_copy.var_names)

    # Scope the figure size to this plot instead of changing the global rcParams, which would
    # silently affect every later figure in the notebook.
    with plt.rc_context({"figure.figsize": (10, 10)}):
        # standard_scale="obs" is applied by scanpy on top of the clipped values
        sc.pl.heatmap(adata_copy, var_names=var_names, groupby=groupby, swap_axes= True, standard_scale = "obs")

# Example usage
# plot_clipped_heatmap(your_adata_object, max_val=10, min_val=-10, genes=your_genes_list, groupby='your_groupby_column')


In [None]:
# Single-cell heatmap of the "all" embedding restricted to the compounds of interest plus the
# DMSO control, grouped by compound name, with the colour scale limited to [-3, 3].
# NOTE(review): `compounds` is not defined in the visible cells - it must come from an earlier cell.
with plt.rc_context({"figure.figsize": (8, 8), "figure.dpi": (300)}):
    sc.pl.heatmap(emb_dict["all"][emb_dict["all"].obs["Metadata_cmpdName"].isin(compounds + ["DIMETHYL SULFOXIDE"])], var_names=features_fixed, groupby="Metadata_cmpdName", dendrogram= False, swap_axes= False, vmin = -3, vmax = 3, cmap='RdBu_r')

In [None]:
def aggregate_by_group(adata, group_by, additional_columns=None):
    """
    Aggregate the data matrix of an AnnData object to one median profile per group.

    Parameters:
    adata (AnnData): The original AnnData object. It is not modified.
    group_by (str): The column in adata.obs to group by.
    additional_columns (list of str, optional): Further adata.obs columns to carry over;
        the first non-missing value within each group is kept.

    Returns:
    AnnData: A new AnnData object with one observation per group that occurs in the data,
    the per-group median in .X, the group label in .obs[group_by] and the selected
    additional columns in .obs.
    """
    # Categorical labels on a positional index. Kept local instead of being written back
    # into adata.obs, so the caller's object is left untouched.
    groups = adata.obs[group_by].astype('category').reset_index(drop=True)

    if isinstance(adata.X, (np.ndarray, np.generic)):  # If .X is already a dense matrix
        adata_df = pd.DataFrame(adata.X, columns=adata.var_names)
    else:  # If .X is a sparse matrix
        adata_df = pd.DataFrame(adata.X.toarray(), columns=adata.var_names)

    # Group by the external label series (no helper column, so a variable that happens to share
    # the name of group_by cannot be overwritten). observed=True skips categories without any
    # cells, which would otherwise become all-NaN rows.
    aggregated_data = adata_df.groupby(groups, observed=True).median()

    # Create a new AnnData object with the aggregated data
    aggregated_adata = anndata.AnnData(X=aggregated_data.values, var=adata.var.copy())
    aggregated_adata.obs[group_by] = aggregated_data.index.values

    # Include additional columns from adata.obs, grouped the same way so the rows line up
    if additional_columns:
        obs_positional = adata.obs.reset_index(drop=True)
        for column in additional_columns:
            aggregated_adata.obs[column] = obs_positional[column].groupby(groups, observed=True).first().values

    return aggregated_adata

In [None]:
# Median feature profile per compound for the full data set, keeping each compound's 'moa' value.
aggregated = aggregate_by_group(adata, "Metadata_cmpdName", ["moa"])

In [None]:
# Same aggregation for the grit-filtered data set.
aggregated_filtered = aggregate_by_group(adata_filtered, "Metadata_cmpdName", ["moa"])

In [None]:
# Heatmap of the per-compound median profiles with a dendrogram; swap_axes=True is passed so the
# axes are exchanged, and the colour scale is limited to [-2, 2].
with plt.rc_context({"figure.figsize": (15, 20), "figure.dpi": (300), 'lines.linewidth': 1}):
    sc.pl.heatmap(aggregated, dendrogram=True, var_names=features_fixed, groupby="Metadata_cmpdName", swap_axes=True, vmin=-2, vmax=2, cmap='RdBu_r', show = True)

In [None]:
# scanpy correlation-matrix plot of the single-cell object, grouped by the 'moa' annotation.
with plt.rc_context({"figure.figsize": (15, 20), "figure.dpi": (300), 'lines.linewidth': 1}):
    sc.pl.correlation_matrix(adata, 'moa')

In [None]:
def create_heatmap_with_colorbar(adata, groupby_column, color_column, title="", cmap='viridis', figsize=(10, 8), vmin=None, vmax=None, save_path=None):
    """
    Draw a heatmap of an aggregated AnnData object with one labelled row per observation and a
    colour strip plus legend that encodes a second annotation.

    Parameters:
    adata (AnnData): The aggregated AnnData object. It is not modified.
    groupby_column (str): Column in adata.obs to use for y-axis labels; rows are sorted by it.
    color_column (str): Column in adata.obs to use for the colour strip and the legend.
    title (str, optional): Title of the heatmap.
    cmap (str, optional): Colormap for the heatmap.
    figsize (tuple, optional): Size of the figure.
    vmin, vmax (float, optional): Min and max values for colormap scaling.
    save_path (str, optional): Where to save the figure. Defaults to
        "heatmap_aggregated_<groupby_column>.png" in the working directory.

    Returns:
    None: Saves the figure and displays it.

    Raises:
    ValueError: If groupby_column or color_column is missing from adata.obs.
    """
    # Imported locally because the notebook-level import of matplotlib.patches only happens in
    # a later cell than this definition and its calls.
    import matplotlib.patches as mpatches

    # Ensure the label and color columns are present
    if groupby_column not in adata.obs:
        raise ValueError(f"{groupby_column} not found in adata.obs")
    if color_column not in adata.obs:
        raise ValueError(f"{color_column} not found in adata.obs")

    # String version of the colour annotation, kept local instead of being written into adata.obs
    color_values = adata.obs[color_column].astype(str)
    color_labels = color_values.unique()
    color_dict = dict(zip(color_labels, sns.color_palette('Set1', len(color_labels))))

    matrix = adata.X.toarray() if hasattr(adata.X, "toarray") else adata.X
    data_df = pd.DataFrame(matrix, index=adata.obs_names, columns=adata.var_names)

    # Sort once, positionally, BEFORE plotting so that heatmap rows, tick labels and colour strip
    # all follow the same order.
    row_order = adata.obs[groupby_column].reset_index(drop=True).sort_values(kind="stable").index.to_numpy()
    data_sorted = data_df.iloc[row_order]
    row_labels = adata.obs[groupby_column].iloc[row_order].astype(str).tolist()
    row_colors = color_values.iloc[row_order]
    n_rows = data_sorted.shape[0]

    # Creating the heatmap
    plt.figure(figsize=figsize, dpi=300)
    ax = sns.heatmap(data_sorted, cmap=cmap, annot=False, vmin=vmin, vmax=vmax)
    plt.title(title)
    plt.ylabel('')
    plt.xlabel('')
    plt.xticks([])  # Remove x-axis tick labels

    # One centred tick label per row
    ax.set_yticks(np.arange(n_rows) + 0.5)
    ax.set_yticklabels(row_labels, rotation=0)

    # Thin separators between rows
    for boundary in range(1, n_rows):
        ax.axhline(boundary, color='grey', lw=0.5)

    # Colour strip: one rectangle per row.
    # NOTE(review): the rectangle spans x = -0.5 .. 4.5, i.e. it is drawn over the first heatmap
    # columns rather than next to them - confirm that this is intended.
    for row_position, label in enumerate(row_colors):
        ax.add_patch(mpatches.Rectangle((-0.5, row_position), 5, 1, color=color_dict[label]))

    # Frame: top, bottom and right edge of the heatmap
    ax.axhline(0, color='black', lw=2)
    ax.axhline(n_rows, color='black', lw=2)
    ax.axvline(x=ax.get_xlim()[1], color='black', lw=2)

    # Legend for the colour strip, placed in its own invisible axes right of the heatmap
    legend_handles = [mpatches.Patch(color=color_dict[label], label=label) for label in row_colors.unique()]
    ax_legend = ax.figure.add_axes([0.95, 0.15, 0.02, 0.7])
    ax_legend.legend(handles=legend_handles, title=color_column)
    ax_legend.axis('off')

    if save_path is None:
        save_path = f"heatmap_aggregated_{groupby_column}.png"
    ax.figure.savefig(save_path, dpi = 300)
    plt.show()

In [None]:
# Per-compound median heatmap of the full data set, with the 'moa' annotation as colour strip.
create_heatmap_with_colorbar(aggregated, groupby_column= "Metadata_cmpdName", color_column= "moa", vmin = -2, vmax = 2, figsize=(20, 20), cmap= "RdBu_r")

In [None]:
# Same heatmap for the grit-filtered data set.
# NOTE(review): called like this, the figure is saved as "heatmap_aggregated_Metadata_cmpdName.png",
# the same file the call for `aggregated` writes, so the earlier image is overwritten.
create_heatmap_with_colorbar(aggregated_filtered, groupby_column= "Metadata_cmpdName", color_column= "moa", vmin = -2, vmax = 2, figsize=(20, 20), cmap= "RdBu_r")

In [None]:
import matplotlib.patches as mpatches
import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import linkage
def create_clustermap_adata(adata, groupby_column, title="", cmap='viridis', figsize=(15, 10), vmin=None, vmax=None, method='average', metric='euclidean'):
    """
    Create a clustermap from an aggregated AnnData object with group labels on the y-axis and a
    dendrogram for the rows. Columns are not clustered.

    Parameters:
    adata (AnnData): The aggregated AnnData object.
    groupby_column (str): Column in adata.obs to use for y-axis labels.
    title (str, optional): Title of the figure.
    cmap (str, optional): Colormap for the heatmap.
    figsize (tuple, optional): Size of the figure.
    vmin, vmax (float, optional): Min and max values for colormap scaling.
    method (str, optional): Linkage method to use for clustering (e.g., 'average', 'single', 'complete').
    metric (str, optional): Distance metric to use for clustering (e.g., 'euclidean', 'cityblock').

    Returns:
    None: Displays the clustermap.

    Raises:
    ValueError: If groupby_column is missing from adata.obs.
    """
    # Ensure the label column is present
    if groupby_column not in adata.obs:
        raise ValueError(f"{groupby_column} not found in adata.obs")

    # Data matrix as a DataFrame, sorted by the label column
    data_df = pd.DataFrame(adata.X, index=adata.obs_names, columns=adata.var_names)
    data_df[groupby_column] = adata.obs[groupby_column]
    data_df = data_df.sort_values(by=groupby_column)

    row_labels = data_df[groupby_column].astype(str).to_numpy()
    values = data_df.drop(columns=[groupby_column])

    # Row linkage is computed explicitly so that method and metric are under the caller's control
    row_linkage = linkage(values, method=method, metric=metric)
    g = sns.clustermap(values, cmap=cmap, figsize=figsize, vmin=vmin, vmax=vmax,
                       row_cluster=True, col_cluster=False, row_linkage=row_linkage,
                       yticklabels=True, xticklabels=False)

    # The clustermap reorders the rows to follow the dendrogram, so the labels must be put into
    # the same order; applying them in input order would attach them to the wrong rows.
    clustered_order = g.dendrogram_row.reordered_ind
    g.ax_heatmap.set_yticklabels(row_labels[clustered_order], rotation=0)

    g.fig.suptitle(title)

    plt.show()

In [None]:
# Row-clustered heatmap of the per-compound median profiles, colour scale limited to [-2, 2].
create_clustermap_adata(aggregated, "Metadata_cmpdName", figsize= (20,15), cmap = "RdBu_r", vmin = -2, vmax = 2)

In [None]:
from scipy.cluster import hierarchy
def calc_correlation_matrix(adata, col, layer, plot = False):
    """
    Correlate the mean profiles of the groups defined by an observation column.

    Parameters:
    adata (AnnData): Object holding the data matrix and the observation annotations.
    col (str): Column in adata.obs that defines the groups.
    layer (str): 'X' for the main matrix, otherwise a key of adata.obsm (e.g. 'X_pca').
    plot (bool): If True, additionally show the matrix as a clustered heatmap.

    Returns:
    DataFrame: Group-by-group Pearson correlation matrix of the per-group mean profiles.

    Raises:
    ValueError: If layer is neither 'X' nor a key of adata.obsm.
    """
    # Choose the data representation
    if layer == 'X':
        data_matrix = adata.X
    elif layer in adata.obsm:
        data_matrix = adata.obsm[layer]
    else:
        raise ValueError(f"layer must be 'X' or a key of adata.obsm, got {layer!r}")

    if hasattr(data_matrix, "toarray"):  # sparse matrices have to be densified for pandas
        data_matrix = data_matrix.toarray()

    # One row per observation, plus the group label (assigned positionally)
    df = pd.DataFrame(data_matrix, index=adata.obs_names)
    df[col] = adata.obs[col].values

    # Mean profile per group; observed=True skips categories that have no observations and
    # would otherwise show up as all-NaN rows
    grouped = df.groupby(col, observed=True).mean()

    # corr() correlates columns, so the groups are moved to the columns first
    correlation_matrix = grouped.T.corr()

    if plot:
        # clustermap creates its own figure; tick styling is applied to its heatmap axes
        g = sns.clustermap(correlation_matrix, annot=False, col_cluster=True, cmap='bwr', center = 0, cbar_pos=(1, .4, .03, .4))
        plt.setp(g.ax_heatmap.get_xticklabels(), rotation=45, ha='right')
        plt.setp(g.ax_heatmap.get_yticklabels(), rotation=0)
        plt.show()

    return correlation_matrix

In [None]:
def find_strongest_comps(corr_mat, N, reference_group=None):
    """
    Return the N groups that are least correlated with a reference (control) group.

    Parameters:
    corr_mat (DataFrame): Square correlation matrix with group labels as index and columns.
    N (int): Number of groups to return.
    reference_group (str, optional): Column of corr_mat to compare against. When omitted,
        'DIMETHYL SULFOXIDE_0.1' is used if present, otherwise 'DIMETHYL SULFOXIDE', so the
        function works for matrices grouped by compound+concentration as well as by compound name.

    Returns:
    Series: Correlations with the reference group, sorted ascending, truncated to N entries.

    Raises:
    KeyError: If the reference group is not a column of corr_mat.
    """
    if reference_group is None:
        default_references = ('DIMETHYL SULFOXIDE_0.1', 'DIMETHYL SULFOXIDE')
        reference_group = next((name for name in default_references if name in corr_mat.columns), None)
        if reference_group is None:
            raise KeyError(f"none of the default reference groups {default_references} is a column of corr_mat")
    elif reference_group not in corr_mat.columns:
        raise KeyError(f"reference group {reference_group!r} is not a column of corr_mat")

    # Lowest correlation first = most dissimilar to the reference
    correlations_with_reference = corr_mat[reference_group]
    return correlations_with_reference.sort_values(ascending=True).head(N)


In [None]:
# Correlate the per-compound mean PCA profiles of the grit-filtered data, then keep the 50
# compounds that are least correlated with the reference group.
compound_correlations = calc_correlation_matrix(adata_filtered, "Metadata_cmpdName", "X_pca", plot = False)
least_correlated_comps = find_strongest_comps(compound_correlations, 50)