# Analysis of DeepProfiler experiment SPECS3K

In [None]:
%matplotlib inline

import numpy as np
import pandas as pd
from tqdm import tqdm

from sklearn.neighbors import KNeighborsClassifier
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import sklearn.metrics

import scipy.linalg
import scipy.spatial.distance
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rc('ytick', labelsize=7)
import seaborn as sns
import os
import psutil
import polars as pl
import cuml
import pandas as pd
from typing import List
from sklearn.preprocessing import StandardScaler
import numpy as np 
from collections import OrderedDict
#from umap import UMAP

In [None]:
import psutil

# Take one snapshot of system-wide virtual memory so both figures below come
# from the same reading.
memory = psutil.virtual_memory()

# Byte counts -> gigabytes (1 << 30 == 1024 ** 3 bytes per GB).
total_memory_gb = memory.total / (1 << 30)
available_memory_gb = memory.available / (1 << 30)

num_cores = os.cpu_count()
print(f"Number of CPU cores: {num_cores}")

pid = os.getpid()
print(f"PID of this notebook: {pid}")
def notebook_memory_usage(pid):
    """Return the resident set size of process ``pid`` in megabytes (1024**2 bytes)."""
    rss_bytes = psutil.Process(pid).memory_info().rss
    return rss_bytes / (1 << 20)

print(f"Total memory: {total_memory_gb:.2f} GB")
print(f"Available memory: {available_memory_gb:.2f} GB")
# Usage: format to two decimals, consistent with the two system-wide figures.
print(f"Memory Usage of Notebook: {notebook_memory_usage(pid):.2f} MB")

## Grit (pandas implementation)

In [None]:
from typing import List, Union
import pandas.api.types as ptypes
def evaluate(
    profiles: pd.DataFrame,
    features: List[str],
    meta_features: List[str],
    replicate_groups: Union[List[str], dict],
    operation: str = "replicate_reproducibility",
    similarity_metric: str = "pearson",
    grit_control_perts: List[str] = ["None"],
    grit_replicate_summary_method: str = "mean"
):
    r"""Evaluate profile quality and strength.

    This notebook port of ``cytominer_eval.evaluate`` only implements the
    ``"grit"`` operation. Every other ``operation`` raises
    ``NotImplementedError`` before any similarity matrix is computed.

    Parameters
    ----------
    profiles : pandas.DataFrame
        Profile samples as rows; the columns hold both metadata and feature
        measurements.
    features : list
        Feature measurement column names. All must be found in `profiles`.
    meta_features : list
        Metadata column names. All must be found in `profiles`.
    replicate_groups : dict
        For `operation="grit"`, a dict with two keys: "profile_col" (the column
        identifying each profile; may be unique) and "replicate_group_col" (the
        column holding higher-order replicate information, e.g. a gene column in
        a CRISPR experiment with multiple guides per gene).
    operation : str, optional
        The evaluation metric to calculate. Only "grit" is supported here. The
        default, "replicate_reproducibility", is kept for signature compatibility
        with cytominer_eval and raises ``NotImplementedError``.
    similarity_metric : {'pearson', 'spearman', 'kendall'}, optional
        Passed to ``pandas.DataFrame.corr`` to compute pairwise similarity.
        Defaults to "pearson".
    grit_control_perts : list, optional
        Profile identifiers (values of the "profile_col" column) used as the
        reference distribution when calculating grit. Defaults to ["None"].
    grit_replicate_summary_method : {"mean", "median"}, optional
        How the per-replicate z-scores are summarized into one grit score.

    Returns
    -------
    pandas.DataFrame
        The result of :py:func:`grit`: one row per perturbation with columns
        "perturbation", "group" and "grit".

    Raises
    ------
    NotImplementedError
        If `operation` is anything other than "grit".
    """
    # Fail fast: previously an unsupported operation melted the full similarity
    # matrix and then crashed with UnboundLocalError on the undefined result.
    if operation != "grit":
        raise NotImplementedError(
            f"operation={operation!r} is not supported here; only 'grit' is implemented"
        )

    # Melt the input profiles to long format (full matrix, diagonal removed)
    similarity_melted_df = metric_melt(
        df=profiles,
        features=features,
        metadata_features=meta_features,
        similarity_metric=similarity_metric,
        eval_metric=operation,
    )

    return grit(
        similarity_melted_df=similarity_melted_df,
        control_perts=grit_control_perts,
        profile_col=replicate_groups["profile_col"],
        replicate_group_col=replicate_groups["replicate_group_col"],
        replicate_summary_method=grit_replicate_summary_method,
    )


def grit(
    similarity_melted_df: pd.DataFrame,
    control_perts: List[str],
    profile_col: str,
    replicate_group_col: str,
    replicate_summary_method: str = "mean",
) -> pd.DataFrame:
    r"""Calculate grit for every perturbation.

    Parameters
    ----------
    similarity_melted_df : pandas.DataFrame or polars.DataFrame
        A long dataframe of annotated pairwise similarities, as output by
        :py:func:`metric_melt` with ``eval_metric="grit"``.
    control_perts : list
        a list of control perturbations to calculate a null distribution
    profile_col : str
        the metadata column storing profile ids. The column can have unique or replicate
        identifiers.
    replicate_group_col : str
        the metadata column indicating a higher order structure (group) than the
        profile column. E.g. target gene vs. guide in a CRISPR experiment.
    replicate_summary_method : {'mean', 'median'}, optional
        how replicate z-scores to control perts are summarized. Defaults to "mean".

    Returns
    -------
    pandas.DataFrame
        One row per perturbation with columns "perturbation", "group" and "grit".

    Raises
    ------
    ValueError
        If `replicate_summary_method` is not "mean" or "median".
    """
    # Validate up front instead of after the expensive per-perturbation loop.
    if replicate_summary_method not in ("mean", "median"):
        raise ValueError(
            "replicate_summary_method must be 'mean' or 'median', "
            f"got {replicate_summary_method!r}"
        )

    # Determine pairwise replicates
    similarity_melted_df = assign_replicates(
        similarity_melted_df=similarity_melted_df,
        replicate_groups=[profile_col, replicate_group_col],
    )

    # Check to make sure that the melted dataframe is full
    assert_melt(similarity_melted_df, eval_metric="grit")

    # assign_replicates may hand back a polars frame; groupby().apply() below and
    # calculate_grit are pandas-only.
    if not isinstance(similarity_melted_df, pd.DataFrame):
        similarity_melted_df = similarity_melted_df.to_pandas()

    # Define the columns to use in the calculation
    column_id_info = set_grit_column_info(
        profile_col=profile_col, replicate_group_col=replicate_group_col
    )
    # The "pair_a" copy of the profile column: grouping on it yields one group
    # per perturbation.
    profile_col_name = column_id_info["profile"]["id"]

    # Calculate grit for each perturbation
    grit_df = (
        similarity_melted_df.groupby(profile_col_name)
        .apply(
            lambda x: calculate_grit(
                replicate_group_df=x,
                control_perts=control_perts,
                column_id_info=column_id_info,
                replicate_summary_method=replicate_summary_method,
            )
        )
        .reset_index(drop=True)
    )

    return grit_df


def calculate_grit(
    replicate_group_df: pd.DataFrame,
    control_perts: List[str],
    column_id_info: dict,
    distribution_compare_method: str = "zscore",
    replicate_summary_method: str = "mean",
) -> pd.Series:
    """Given an elongated pairwise correlation dataframe of replicate groups,
    calculate grit.

    Usage: designed to be called within a pandas.DataFrame().groupby().apply(); see
    :py:func:`grit`.

    Parameters
    ----------
    replicate_group_df : pandas.DataFrame
        An elongated dataframe storing pairwise correlations of all profiles to a
        single replicate group. It needs a "similarity_metric" column plus the
        profile/group id columns named in `column_id_info`.
    control_perts : list
        The profile ids that should be considered controls (the reference).
    column_id_info : dict
        Nested dict ``{"profile": {"id": ..., "comparison": ...},
        "group": {"id": ..., "comparison": ...}}`` of column names, as built by
        :py:func:`set_grit_column_info`.
    distribution_compare_method : {'zscore'}, optional
        How to compare the replicate and reference distributions of pairwise
        similarity. Only "zscore" is supported.
    replicate_summary_method : {'mean', 'median'}, optional
        How to summarize replicate z-scores. Defaults to "mean".

    Returns
    -------
    pandas.Series
        Indexed by "perturbation", "group" and "grit". "grit" is NaN if no other
        profiles exist in the same group.

    Raises
    ------
    ValueError
        If `distribution_compare_method` is not "zscore".
    AssertionError
        If the frame holds more than one profile or group id (via
        :py:func:`get_grit_entry`), or fewer than two control similarities exist.
    """
    # Confirm that we support the user provided methods
    if distribution_compare_method != "zscore":
        raise ValueError(
            f"Unsupported distribution_compare_method {distribution_compare_method!r}; "
            "only 'zscore' is implemented"
        )

    group_entry = get_grit_entry(replicate_group_df, column_id_info["group"]["id"])
    pert = get_grit_entry(replicate_group_df, column_id_info["profile"]["id"])

    profile_comparison = replicate_group_df[column_id_info["profile"]["comparison"]]
    group_comparison = replicate_group_df[column_id_info["group"]["comparison"]]
    similarities = replicate_group_df["similarity_metric"]

    # Reference distribution: similarities to control perturbations
    control_distrib = (
        similarities[profile_comparison.isin(control_perts)].to_numpy().reshape(-1, 1)
    )

    # Z-scoring needs a spread, so a single control similarity is not enough.
    assert len(control_distrib) > 1, (
        "Error! Fewer than two control perturbation similarities found "
        f"({len(control_distrib)})."
    )

    # Target distribution: same group, but not the same perturbation
    same_group_mask = (group_comparison == group_entry) & (profile_comparison != pert)
    same_group_distrib = similarities[same_group_mask].to_numpy().reshape(-1, 1)

    return_bundle = {"perturbation": pert, "group": group_entry}
    if len(same_group_distrib) == 0:
        return_bundle["grit"] = np.nan
    else:
        return_bundle["grit"] = compare_distributions(
            target_distrib=same_group_distrib,
            control_distrib=control_distrib,
            method=distribution_compare_method,
            replicate_summary_method=replicate_summary_method,
        )

    return pd.Series(return_bundle)

def convert_pandas_dtypes(df: pd.DataFrame, col_fix: type = float) -> pd.DataFrame:
    r"""Helper function to convert pandas column dtypes

    Parameters
    ----------
    df : pandas.DataFrame
        A pandas dataframe to convert columns
    col_fix : {float, str}, optional
        The column type to convert the input dataframe to (passed to
        ``DataFrame.astype``).

    Returns
    -------
    pd.DataFrame
        A dataframe with converted columns

    Raises
    ------
    ValueError
        If any column cannot be converted. The underlying pandas error is chained
        as ``__cause__`` so the offending value stays visible in the traceback.
    """
    try:
        df = df.astype(col_fix)
    except ValueError as err:
        raise ValueError(
            "Columns cannot be converted to {col}; check input features".format(
                col=col_fix
            )
        ) from err

    return df


def get_grit_entry(df: pd.DataFrame, col: str) -> str:
    """Return the single identifier held in column ``col`` of ``df`` as a string.

    Grit is computed one perturbation at a time, so every row of ``df`` must carry
    the same value in ``col``; an AssertionError is raised otherwise.
    """
    column_values = df[col]
    distinct_values = column_values.unique()
    assert len(distinct_values) == 1, "grit is calculated for each perturbation independently"
    first_value = next(iter(column_values))
    return str(first_value)




def set_pair_ids():
    r"""Return the canonical column-naming scheme for the two sides of a melted pair.

    Returns
    -------
    collections.OrderedDict
        Keys "pair_a" then "pair_b"; each maps to a dict with an "index" column name
        (e.g. "pair_a_index") and a column "suffix" (e.g. "_pair_a").
    """
    pair_naming = OrderedDict()
    for label in ("pair_a", "pair_b"):
        pair_naming[label] = {"index": f"{label}_index", "suffix": f"_{label}"}
    return pair_naming

def assign_replicates(
    similarity_melted_df: Union[pd.DataFrame, pl.DataFrame],
    replicate_groups: List[str],
) -> Union[pd.DataFrame, pl.DataFrame]:
    """Determine which profiles should be considered replicates.

    Given an elongated pairwise correlation matrix with metadata annotations, determine
    how to assign replicate information.

    Parameters
    ----------
    similarity_melted_df : pandas.DataFrame or polars.DataFrame
        Long DataFrame of annotated pairwise correlations output from
        :py:func:`metric_melt`. It must contain ``<col>_pair_a`` and ``<col>_pair_b``
        columns for every entry of `replicate_groups`.
    replicate_groups : list
        a list of metadata column names in the original profile dataframe used to
        indicate replicate profiles.

    Returns
    -------
    pandas.DataFrame or polars.DataFrame
        The same frame type as the input, with added boolean columns:
        ``<col>_replicate`` (True where both sides of the pair share ``col``) and
        ``group_replicate`` (True where all ``<col>_replicate`` flags are True).
        A pandas result carries a fresh 0..n-1 index.
    """
    pair_ids = set_pair_ids()
    replicate_col_names = {x: "{x}_replicate".format(x=x) for x in replicate_groups}

    # The comparison logic is pandas-based. Remember the input type so polars
    # callers get polars back while pandas callers (e.g. ``grit``) get pandas.
    # Calling .to_pandas() unconditionally crashed on pandas input.
    return_polars = not isinstance(similarity_melted_df, pd.DataFrame)
    if return_polars:
        similarity_melted_df = similarity_melted_df.to_pandas()
    # The index-aligned merge at the end needs a 0..n-1 index on both sides.
    similarity_melted_df = similarity_melted_df.reset_index(drop=True)

    compare_dfs = []
    for replicate_col in replicate_groups:
        replicate_cols_with_suffix = [
            "{col}{suf}".format(col=replicate_col, suf=pair_ids[x]["suffix"])
            for x in pair_ids
        ]

        assert all(
            x in similarity_melted_df.columns for x in replicate_cols_with_suffix
        ), "replicate_group not found in melted dataframe columns"

        replicate_col_name = replicate_col_names[replicate_col]

        # Explicit copy: adding a column to a column subset can otherwise raise
        # SettingWithCopyWarning. Comparing the two Series directly is
        # index-aligned, so no positional-vs-label mixing is involved.
        compare_df = similarity_melted_df.loc[:, replicate_cols_with_suffix].copy()
        compare_df[replicate_col_name] = compare_df.iloc[:, 0] == compare_df.iloc[:, 1]
        compare_dfs.append(compare_df)

    compare_df = pd.concat(compare_dfs, axis="columns").reset_index(drop=True)
    replicate_flag_cols = list(replicate_col_names.values())
    # min() over booleans is a row-wise "all": replicate on every requested column.
    compare_df = compare_df.assign(
        group_replicate=compare_df.loc[:, replicate_flag_cols].min(axis="columns")
    ).loc[:, replicate_flag_cols + ["group_replicate"]]

    similarity_melted_df = similarity_melted_df.merge(
        compare_df, left_index=True, right_index=True
    )
    return pl.DataFrame(similarity_melted_df) if return_polars else similarity_melted_df


def assert_melt(
    df: pd.DataFrame, eval_metric: str = "replicate_reproducibility"
) -> None:
    r"""Helper function to ensure that we properly melted the pairwise correlation
    matrix

    Downstream functions depend on how we process the pairwise correlation matrix. The
    processing is different depending on the evaluation metric. The check compares the
    sums of the two pair-index columns: they must differ for
    "replicate_reproducibility" and be equal for "precision_recall", "grit" and
    "hitk". Other `eval_metric` values are not checked.

    Parameters
    ----------
    df : pandas.DataFrame or polars.DataFrame
        A melted pairwise correlation matrix containing the pair index columns
        named by :py:func:`set_pair_ids`.
    eval_metric : str
        The user input eval metric

    Returns
    -------
    None
        Assertion will fail if we incorrectly melted the matrix
    """
    pair_ids = set_pair_ids()
    index_cols = [pair_ids[x]["index"] for x in pair_ids]

    # Column selection by list works for both pandas and polars. Selecting first
    # means a polars frame converts only these two columns, not the whole matrix.
    df = df[index_cols]
    if not isinstance(df, pd.DataFrame):
        df = df.to_pandas()
    index_sums = df.sum().tolist()

    assert_error = "Stop! The eval_metric provided in 'metric_melt()' is incorrect!"
    assert_error = "{err} This is a fatal error providing incorrect results".format(
        err=assert_error
    )
    if eval_metric == "replicate_reproducibility":
        assert index_sums[0] != index_sums[1], assert_error
    elif eval_metric in ("precision_recall", "grit", "hitk"):
        assert index_sums[0] == index_sums[1], assert_error



def set_grit_column_info(profile_col: str, replicate_group_col: str) -> dict:
    """Build the suffixed column names that grit uses to locate ids in a melted frame.

    Grit needs a metadata feature naming the core replicate perturbation
    (``profile_col``) and one naming the larger group it belongs to
    (``replicate_group_col``, e.g. gene or MOA).

    Parameters
    ----------
    profile_col : str
        the metadata column storing profile ids. The column can have unique or replicate
        identifiers.
    replicate_group_col : str
        the metadata column indicating a higher order structure (group) than the
        profile column. E.g. target gene vs. guide in a CRISPR experiment.

    Returns
    -------
    dict
        ``{"profile": {...}, "group": {...}}``. In each inner dict, "id" is the
        column with the first pair suffix and "comparison" is the column with the
        second pair suffix.
    """
    suffixes = [entry["suffix"] for entry in set_pair_ids().values()]

    def _suffixed_columns(base_col):
        # Pair each role with the base column name carrying the matching suffix.
        return dict(zip(("id", "comparison"), (f"{base_col}{suf}" for suf in suffixes)))

    return {
        "profile": _suffixed_columns(profile_col),
        "group": _suffixed_columns(replicate_group_col),
    }



def compare_distributions(
    target_distrib: List[float],
    control_distrib: List[float],
    method: str = "zscore",
    replicate_summary_method: str = "mean",
) -> float:
    """Compare two distributions and output a single score indicating the difference.

    With ``method="zscore"``, a ``StandardScaler`` is fit on `control_distrib`. It
    is used to z-score `target_distrib`, and the scores are reduced with the chosen
    summary.

    Parameters
    ----------
    target_distrib : np.array
        A list-like (e.g. numpy.array) of floats representing the first distribution.
        Must be of shape (n_samples, 1).
    control_distrib : np.array
        A list-like (e.g. numpy.array) of floats representing the second distribution.
        Must be of shape (n_samples, 1).
    method : str, optional
        How to compare the two distributions. Only "zscore" is supported.
    replicate_summary_method : {"mean", "median"}, optional
        How to summarize the resulting z-scores.

    Returns
    -------
    float
        A single value comparing the two distributions

    Raises
    ------
    ValueError
        If `method` or `replicate_summary_method` is unsupported. Previously, an
        unknown method hit UnboundLocalError and an unknown summary returned the
        raw score array.
    """
    summarizers = {"mean": np.mean, "median": np.median}

    # Confirm that we support the provided methods
    if method != "zscore":
        raise ValueError(
            f"Unsupported comparison method {method!r}; only 'zscore' is implemented"
        )
    if replicate_summary_method not in summarizers:
        raise ValueError(
            "replicate_summary_method must be 'mean' or 'median', "
            f"got {replicate_summary_method!r}"
        )

    scaler = StandardScaler()
    scaler.fit(control_distrib)
    scores = scaler.transform(target_distrib)

    return summarizers[replicate_summary_method](scores)

def assert_pandas_dtypes(df: pd.DataFrame, col_fix: type = float) -> pd.DataFrame:
    r"""Convert every column of ``df`` to ``col_fix`` and verify that it worked.

    Parameters
    ----------
    df : pandas.DataFrame
        The dataframe whose columns should be converted.
    col_fix : {float, str}, optional
        Target column type. Only ``str`` and ``float`` are accepted.

    Returns
    -------
    pd.DataFrame
        The converted dataframe. Every column is string-typed for ``str`` and
        numeric for ``float``; an AssertionError is raised otherwise.
    """
    assert col_fix in [str, float], "Only str and float are supported"

    converted = convert_pandas_dtypes(df=df, col_fix=col_fix)

    # Pick the dtype predicate matching the requested conversion.
    dtype_check = {str: ptypes.is_string_dtype, float: ptypes.is_numeric_dtype}.get(col_fix)
    if dtype_check is not None:
        assert all(
            [dtype_check(converted[name]) for name in converted.columns]
        ), "Columns not successfully updated, is the dataframe consistent?"

    return converted

def process_melt(
    df: pd.DataFrame,
    meta_df: pd.DataFrame,
    eval_metric: str = "replicate_reproducibility",
) -> pd.DataFrame:
    """Helper function to annotate and process an input similarity matrix

    Parameters
    ----------
    df : pandas.DataFrame
        A square similarity matrix output from :py:func:`get_pairwise_metric`.
        It is not modified.
    meta_df : pandas.DataFrame
        A wide matrix of metadata information where the index aligns to the similarity
        matrix index
    eval_metric : str, optional
        Which metric to ultimately calculate. "replicate_reproducibility" keeps only
        the strict upper triangle; any other value keeps the full matrix without
        the diagonal. Defaults to "replicate_reproducibility".

    Returns
    -------
    pandas.DataFrame
        A long pairwise similarity frame: "pair_a_index", "pair_b_index",
        "similarity_metric", and the metadata columns with "_pair_a"/"_pair_b"
        suffixes.
    """
    # Confirm that the user formed the input arguments properly
    assert df.shape[0] == df.shape[1], "Matrix must be symmetrical"

    # Get identifiers for pairing metadata
    pair_ids = set_pair_ids()

    # Remove the diagonal and, for replicate_reproducibility, redundant comparisons
    if eval_metric == "replicate_reproducibility":
        upper_tri = get_upper_matrix(df)
        df = df.where(upper_tri)
    else:
        # Build a new frame instead of np.fill_diagonal(df.values, ...). That
        # mutated the caller's matrix, and under pandas Copy-on-Write ``.values``
        # is a read-only view, so the write raises.
        values = df.to_numpy(dtype=float, copy=True)
        np.fill_diagonal(values, np.nan)
        df = pd.DataFrame(values, index=df.index, columns=df.columns)

    # Convert pairwise matrix to melted (long) version based on index value.
    # rename_axis guarantees reset_index() produces a column literally named
    # "index", even if the matrix index carries a name.
    metric_unlabeled_df = (
        pd.melt(
            df.rename_axis(index="index").reset_index(),
            id_vars="index",
            value_vars=df.columns,
            var_name=pair_ids["pair_b"]["index"],
            value_name="similarity_metric",
        )
        .dropna()
        .reset_index(drop=True)
        .rename({"index": pair_ids["pair_a"]["index"]}, axis="columns")
    )

    # Merge metadata on index for both comparison pairs
    output_df = meta_df.merge(
        meta_df.merge(
            metric_unlabeled_df,
            left_index=True,
            right_on=pair_ids["pair_b"]["index"],
        ),
        left_index=True,
        right_on=pair_ids["pair_a"]["index"],
        suffixes=[pair_ids["pair_a"]["suffix"], pair_ids["pair_b"]["suffix"]],
    ).reset_index(drop=True)

    return output_df


def metric_melt(
    df: pd.DataFrame,
    features: List[str],
    metadata_features: List[str],
    eval_metric: str = "replicate_reproducibility",
    similarity_metric: str = "pearson",
) -> pd.DataFrame:
    """Helper function to fully transform an input dataframe of metadata and feature
    columns into a long, melted dataframe of pairwise metric comparisons between
    profiles.

    Parameters
    ----------
    df : pandas.DataFrame or polars.DataFrame
        A profiling dataset with a mixture of metadata and feature columns. A
        polars frame is converted to pandas first.
    features : list
        Which features make up the profile; included in the pairwise calculations
    metadata_features : list
        Which features are considered metadata features; annotate melted dataframe and
        do not use in pairwise calculations.
    eval_metric : str, optional
        Which metric to ultimately calculate. Determines whether or not to keep the full
        similarity matrix or only one diagonal. Defaults to "replicate_reproducibility".
    similarity_metric : str, optional
        The pairwise comparison to calculate

    Returns
    -------
    pandas.DataFrame
        A fully melted dataframe of pairwise correlations and associated metadata
    """
    # Everything below (reset_index, .loc, corr) is pandas-only. Converting here
    # lets the polars entry points pass their frames straight through.
    if not isinstance(df, pd.DataFrame):
        df = df.to_pandas()

    # Subset dataframes to specific features
    df = df.reset_index(drop=True)

    assert all(
        [x in df.columns for x in metadata_features]
    ), "Metadata feature not found"
    assert all([x in df.columns for x in features]), "Profile feature not found"

    meta_df = df.loc[:, metadata_features]
    df = df.loc[:, features]

    # Convert pandas column types and assert conversion success
    meta_df = assert_pandas_dtypes(df=meta_df, col_fix=str)
    df = assert_pandas_dtypes(df=df, col_fix=float)

    # Get pairwise metric matrix
    pair_df = get_pairwise_metric(df=df, similarity_metric=similarity_metric)

    # Convert pairwise matrix into metadata-labeled melted matrix
    output_df = process_melt(df=pair_df, meta_df=meta_df, eval_metric=eval_metric)

    return output_df


def get_pairwise_metric(df: pd.DataFrame, similarity_metric: str) -> pd.DataFrame:
    """Compute the profile-by-profile similarity matrix of a feature-only dataframe.

    Parameters
    ----------
    df : pandas.DataFrame
        Samples x features; every column must be coercible to float.
    similarity_metric : str
        Correlation method forwarded to ``DataFrame.corr`` (e.g. "pearson").

    Returns
    -------
    pandas.DataFrame
        Square matrix indexed and labelled by the rows of ``df``.

    Raises
    ------
    TypeError
        If the correlation comes back empty.
    """
    numeric_df = assert_pandas_dtypes(df=df, col_fix=float)

    # DataFrame.corr correlates columns; transpose so that profiles (rows) are compared.
    similarity_matrix = numeric_df.T.corr(method=similarity_metric)

    # Defensive check; current pandas versions make it redundant.
    if similarity_matrix.shape == (0, 0):
        raise TypeError(
            "Something went wrong - check that 'features' are profile measurements"
        )
    return similarity_matrix



def get_upper_matrix(df: "pd.DataFrame | pl.DataFrame", batch_size: int = 10000) -> np.ndarray:
    """
    Generate a strict upper-triangle boolean mask for ``df``'s shape, in row batches.

    Entry ``(i, j)`` is True iff ``j > i``, i.e. the same as
    ``np.triu(np.ones(df.shape, dtype=bool), k=1)``. The previous implementation
    applied ``np.triu`` to each block in block-local coordinates. That produced
    wrong masks for every off-diagonal block once the matrix exceeded
    ``batch_size``.

    Parameters
    ----------
    df : pandas.DataFrame
        Any object with a 2-D ``.shape``; only the shape is used.
    batch_size : int
        Number of rows filled per step. It bounds the temporary comparison array
        to ``batch_size x ncols`` booleans.

    Returns
    -------
    np.ndarray
        A boolean array the same shape as the input DataFrame.
    """
    nrows, ncols = df.shape
    upper_matrix = np.zeros((nrows, ncols), dtype=bool)
    col_idx = np.arange(ncols)

    for start_row in range(0, nrows, batch_size):
        end_row = min(start_row + batch_size, nrows)
        row_idx = np.arange(start_row, end_row)
        # Compare global row/column indices, so blocks far from the diagonal are
        # correct (fully True above it, fully False below it).
        upper_matrix[start_row:end_row] = col_idx[np.newaxis, :] > row_idx[:, np.newaxis]

    return upper_matrix


In [None]:
# Load previously computed Plate 1 grit results (the file name suggests single-cell grit; confirm).
grit_plate1 = pd.read_parquet(path="sc_grit_Plate1.parquet")

## Polars grit

In [None]:
def evaluate_grit(
    profiles: pl.DataFrame,
    features: List[str],
    meta_features: List[str],
    replicate_groups: dict,
    operation: str = "replicate_reproducibility",
    similarity_metric: str = "pearson",
    grit_control_perts: List[str] = ["None"],
    grit_replicate_summary_method: str = "mean",
):
    """Evaluate profile quality using the 'grit' metric, returning a Polars DataFrame.

    Parameters
    ----------
    profiles : pl.DataFrame or pandas.DataFrame
        Profile samples as rows; the columns hold both metadata and feature
        measurements. A polars frame is converted to pandas for the melt.
    features : list
        A list of strings corresponding to feature measurement column names in the
        `profiles` DataFrame. All features listed must be found in `profiles`.
    meta_features : list
        A list of strings corresponding to metadata column names in the `profiles`
        DataFrame. All features listed must be found in `profiles`.
    replicate_groups : dict
        A dict with keys "profile_col" and "replicate_group_col" indicating columns
        that store identifiers for each profile and higher order replicate information,
        respectively.
    operation : str, optional
        Retained for signature compatibility and ignored. Grit is always computed,
        and the similarity matrix is always melted in full, as grit requires.
    similarity_metric: str, optional
        How to calculate pairwise similarity. Defaults to "pearson".
    grit_control_perts : list, optional
        Specific profile identifiers used as a reference when calculating grit.
    grit_replicate_summary_method : str, optional
        Defines how the replicate z scores are summarized, either "mean" or "median".

    Returns
    -------
    pl.DataFrame
        The resulting 'grit' metric as a Polars DataFrame.
    """
    # metric_melt relies on pandas-only methods (reset_index, .loc).
    if not isinstance(profiles, pd.DataFrame):
        profiles = profiles.to_pandas()

    # Always melt the full matrix. The previous eval_metric=operation kept only
    # the upper triangle for the default operation, which then failed the grit
    # check in assert_melt (and left the frame undefined for "mp_value").
    similarity_melted_df = metric_melt(
        df=profiles,
        features=features,
        metadata_features=meta_features,
        similarity_metric=similarity_metric,
        eval_metric="grit",
    )
    metric_result = grit(
        similarity_melted_df=similarity_melted_df,
        control_perts=grit_control_perts,
        profile_col=replicate_groups["profile_col"],
        replicate_group_col=replicate_groups["replicate_group_col"],
        replicate_summary_method=grit_replicate_summary_method,
    )

    # grit() produces pandas; honour the documented polars return type.
    if isinstance(metric_result, pd.DataFrame):
        metric_result = pl.from_pandas(metric_result)
    return metric_result

def grit_pl(
    similarity_melted_df: pl.DataFrame,
    control_perts: List[str],
    profile_col: str,
    replicate_group_col: str,
    replicate_summary_method: str = "mean",
) -> pl.DataFrame:
    """Calculate grit per perturbation, splitting the work on polars partitions.

    Parameters
    ----------
    similarity_melted_df : pl.DataFrame
        Long frame of annotated pairwise similarities (full matrix, no diagonal).
    control_perts : list
        Profile ids forming the reference (null) distribution.
    profile_col : str
        Metadata column storing profile ids.
    replicate_group_col : str
        Metadata column indicating the higher-order group of each profile.
    replicate_summary_method : {"mean", "median"}, optional
        How replicate z-scores are summarized. Defaults to "mean".

    Returns
    -------
    pl.DataFrame
        One row per perturbation with columns "perturbation", "group" and "grit".
        Empty input yields an empty frame with that schema.

    Raises
    ------
    ValueError
        If `replicate_summary_method` is not "mean" or "median".
    """
    if replicate_summary_method not in ("mean", "median"):
        raise ValueError(
            "replicate_summary_method must be 'mean' or 'median', "
            f"got {replicate_summary_method!r}"
        )

    # Determine pairwise replicates
    similarity_melted_df = assign_replicates(
        similarity_melted_df=similarity_melted_df,
        replicate_groups=[profile_col, replicate_group_col],
    )

    # Check to make sure that the melted dataframe is full
    assert_melt(similarity_melted_df, eval_metric="grit")

    # Partitioning below is polars; accept a pandas result from assign_replicates too.
    if isinstance(similarity_melted_df, pd.DataFrame):
        similarity_melted_df = pl.from_pandas(similarity_melted_df)

    # Define the columns to use in the calculation
    column_id_info = set_grit_column_info(
        profile_col=profile_col, replicate_group_col=replicate_group_col
    )
    # The "pair_a" copy of the profile column identifies each row's perturbation.
    profile_col_name = column_id_info["profile"]["id"]

    rows = []
    # partition_by splits in one pass instead of a full-frame filter per unique value.
    for group_df in similarity_melted_df.partition_by(profile_col_name, maintain_order=True):
        # calculate_grit uses pandas indexing (.loc / .isin), so hand it pandas.
        result = calculate_grit(
            replicate_group_df=group_df.to_pandas(),
            control_perts=control_perts,
            column_id_info=column_id_info,
            replicate_summary_method=replicate_summary_method,
        )
        rows.append(result.to_dict())

    if not rows:
        return pl.DataFrame(
            schema={"perturbation": pl.Utf8, "group": pl.Utf8, "grit": pl.Float64}
        )
    return pl.DataFrame(rows)

def assert_melt(
    df: pl.DataFrame, eval_metric: str = "replicate_reproducibility"
) -> None:
    r"""Helper function to ensure that we properly melted the pairwise correlation
    matrix

    Downstream functions depend on how we process the pairwise correlation matrix. The
    processing is different depending on the evaluation metric: replicate
    reproducibility uses only the upper triangle, while the other metrics use the full
    matrix without its diagonal.

    Parameters
    ----------
    df : polars.DataFrame or pandas.DataFrame
        A melted pairwise correlation matrix containing both pair index columns
        from ``set_pair_ids()``.
    eval_metric : str
        The user input eval metric. Metrics other than the four checked below are
        not validated.

    Returns
    -------
    None
        Assertion will fail if we incorrectly melted the matrix
    """
    pair_ids = set_pair_ids()
    index_cols = [pair_ids[x]["index"] for x in pair_ids]

    # Only the two pair-index columns are needed: select them before converting so
    # the (potentially huge) melted table is never copied to pandas in full.
    if isinstance(df, pl.DataFrame):
        index_df = df.select(index_cols).to_pandas()
    else:
        index_df = df.loc[:, index_cols]
    index_sums = index_df.sum().tolist()

    assert_error = (
        "Stop! The eval_metric provided in 'metric_melt()' is incorrect!"
        " This is a fatal error providing incorrect results"
    )
    if eval_metric == "replicate_reproducibility":
        # Upper-triangle melt: the two index columns hold different multisets.
        assert index_sums[0] != index_sums[1], assert_error
    elif eval_metric in ("precision_recall", "grit", "hitk"):
        # Full off-diagonal melt: (i, j) and (j, i) are both present, so sums match.
        assert index_sums[0] == index_sums[1], assert_error



def calculate_grit_pl(
    replicate_group_df: Union[pd.DataFrame, pl.DataFrame],
    control_perts: List[str],
    column_id_info: dict,
    distribution_compare_method: str = "zscore",
    replicate_summary_method: str = "mean",
) -> pl.DataFrame:
    """
    Calculate grit for one perturbation using Polars.

    Parameters
    ----------
    replicate_group_df : pl.DataFrame or pd.DataFrame
        Pairwise similarities of all profiles to a single perturbation; pandas input
        is converted to Polars.
    control_perts : list
        The profile_ids that should be considered controls (the reference).
    column_id_info : dict
        A dictionary of column identifiers noting profile and replicate group ids
        (keys ``["profile"|"group"]["id"|"comparison"]``).
    distribution_compare_method : str, optional
        How to compare the replicate and reference distributions of pairwise similarity.
    replicate_summary_method : str, optional
        How to summarize replicate z-scores. Defaults to "mean".

    Returns
    -------
    pl.DataFrame
        A one-row DataFrame with columns ``perturbation``, ``group`` and ``grit``.
        ``grit`` is NaN when there are no other perturbations in the same group.

    Raises
    ------
    AssertionError
        If fewer than two control similarities are found, or if the frame does not
        hold exactly one perturbation / group.
    """
    if not isinstance(replicate_group_df, pl.DataFrame):
        replicate_group_df = pl.DataFrame(replicate_group_df)

    profile_comparison = column_id_info["profile"]["comparison"]
    group_comparison = column_id_info["group"]["comparison"]

    # Polars-aware helper: the frame must describe exactly one perturbation/group.
    group_entry = get_grit_entry_pl(replicate_group_df, column_id_info["group"]["id"])
    pert = get_grit_entry_pl(replicate_group_df, column_id_info["profile"]["id"])

    # Reference distribution: similarities to control perturbations.
    control_distrib = (
        replicate_group_df.filter(pl.col(profile_comparison).is_in(control_perts))[
            "similarity_metric"
        ]
        .to_numpy()
        .reshape(-1, 1)
    )
    assert control_distrib.shape[0] > 1, "Error! No control perturbations found."

    # Target distribution: same replicate group, but a different perturbation.
    same_group_distrib = (
        replicate_group_df.filter(
            (pl.col(group_comparison) == group_entry)
            & (pl.col(profile_comparison) != pert)
        )["similarity_metric"]
        .to_numpy()
        .reshape(-1, 1)
    )

    if same_group_distrib.size == 0:
        grit_score = None
    else:
        grit_score = compare_distributions(
            target_distrib=same_group_distrib,
            control_distrib=control_distrib,
            method=distribution_compare_method,
            replicate_summary_method=replicate_summary_method,
        )
    if grit_score is None:
        grit_score = float("nan")

    # One-row frame so grit_pl can pl.concat the per-perturbation results.
    return pl.DataFrame(
        {"perturbation": [pert], "group": [group_entry], "grit": [grit_score]}
    )

def assign_replicates_pl(
    similarity_melted_df: pl.DataFrame,
    replicate_groups: List[str]
) -> pl.DataFrame:
    """Determine which profiles should be considered replicates in a Polars DataFrame.

    Parameters
    ----------
    similarity_melted_df : pl.DataFrame
        Long Polars DataFrame of annotated pairwise correlations.
    replicate_groups : list
        a list of metadata column names in the original profile dataframe used to
        indicate replicate profiles.

    Returns
    -------
    pl.DataFrame
        ``similarity_melted_df`` with one boolean ``<col>_replicate`` column per
        replicate group column (True when both pair members share the value; null
        comparisons count as False) and a ``group_replicate`` column that is True only
        when every ``<col>_replicate`` is True.

    Raises
    ------
    AssertionError
        If a suffixed replicate column is missing from the melted frame.
    """
    pair_ids = set_pair_ids()
    replicate_col_names = {x: "{x}_replicate".format(x=x) for x in replicate_groups}

    compare_exprs = []
    for replicate_col in replicate_groups:
        suffixed_cols = [f"{replicate_col}{pair_ids[x]['suffix']}" for x in pair_ids]
        assert all(
            col in similarity_melted_df.columns for col in suffixed_cols
        ), "replicate_group not found in melted dataframe columns"

        # when/then/otherwise maps a null comparison to False rather than null.
        compare_exprs.append(
            pl.when(pl.col(suffixed_cols[0]) == pl.col(suffixed_cols[1]))
            .then(True)
            .otherwise(False)
            .alias(replicate_col_names[replicate_col])
        )

    # Expressions within one with_columns call cannot see each other's outputs,
    # so the per-column flags are materialised before they are combined.
    similarity_melted_df = similarity_melted_df.with_columns(compare_exprs)

    # Row-wise AND of all replicate flags (positional args work across Polars versions).
    group_replicate_expr = pl.fold(
        pl.lit(True),
        lambda acc, x: acc & x,
        [pl.col(name) for name in replicate_col_names.values()],
    ).alias("group_replicate")

    return similarity_melted_df.with_columns(group_replicate_expr)


def metric_melt(
    df: pl.DataFrame,
    features: List[str],
    metadata_features: List[str],
    eval_metric: str = "replicate_reproducibility",
    similarity_metric: str = "pearson",
) -> pl.DataFrame:
    """Compute pairwise profile similarities and melt them into an annotated long table.

    Parameters
    ----------
    df : pl.DataFrame
        Profiles, one row per sample, with feature and metadata columns.
    features : list of str
        Feature columns used to compute similarity.
    metadata_features : list of str
        Metadata columns attached to both members of every pair.
    eval_metric : str, optional
        Passed to ``process_melt``; decides whether only the upper triangle or the
        full off-diagonal matrix is kept.
    similarity_metric : str, optional
        Passed to ``get_pairwise_metric``.

    Returns
    -------
    pl.DataFrame
        Output of ``process_melt``: one row per retained pair with a
        ``similarity_metric`` column and suffixed metadata for both pair members.

    Raises
    ------
    AssertionError
        If any requested feature or metadata column is missing from ``df``.
    """
    # Set lookup keeps these checks linear even with thousands of feature columns.
    available_cols = set(df.columns)
    assert all(x in available_cols for x in metadata_features), "Metadata feature not found"
    assert all(x in available_cols for x in features), "Profile feature not found"

    meta_df = df.select(metadata_features)
    feature_df = df.select(features)

    # get_pairwise_metric takes a samples x features pandas frame and returns a
    # square Polars similarity matrix.
    pair_df = get_pairwise_metric(df=feature_df.to_pandas(), similarity_metric=similarity_metric)

    return process_melt(df=pair_df, meta_df=meta_df, eval_metric=eval_metric)

def get_pairwise_metric(df: pd.DataFrame, similarity_metric: str = 'pearson') -> pl.DataFrame:
    """Calculate the pairwise similarity between the rows (samples) of a feature-only frame.

    Parameters
    ----------
    df : pd.DataFrame
        Samples x features, where all columns can be coerced to floats
    similarity_metric : str, optional
        The pairwise comparison to calculate, defaults to 'pearson'.

    Returns
    -------
    pl.DataFrame
        A square n_samples x n_samples similarity matrix; row/column i refers to
        row i of ``df``.

    Raises
    ------
    AssertionError
        If ``similarity_metric`` is not supported.
    """
    assert similarity_metric in ["pearson"], "Unsupported similarity metric"

    # np.corrcoef treats each ROW as a variable. Rows of df are samples, so the
    # untransposed matrix yields sample-by-sample correlations (the equivalent of
    # pandas df.T.corr()); transposing would correlate features instead.
    values = np.asarray(df.to_numpy(), dtype=float)
    # atleast_2d: corrcoef returns a 0-d array when there is a single sample.
    corr_matrix = np.atleast_2d(np.corrcoef(values))

    # Going through pandas keeps the column labels process_melt expects
    # (the integer positions, as strings, after Polars conversion).
    return pl.DataFrame(pd.DataFrame(corr_matrix))



def get_grit_entry_pl(df: pl.DataFrame, col: str) -> str:
    """
    Return the single perturbation identifier stored in ``col`` of a Polars frame.

    Grit is computed one perturbation at a time, so the frame handed in must
    contain exactly one distinct value in ``col``.

    Parameters
    ----------
    df : pl.DataFrame
        Frame holding the rows for one perturbation.
    col : str
        Column whose single distinct value is returned.

    Returns
    -------
    str
        The only distinct value found in ``col``.

    Raises
    ------
    AssertionError
        When ``col`` holds zero or more than one distinct value.
    """
    distinct_values = df.get_column(col).unique()
    assert len(distinct_values) == 1, "grit is calculated for each perturbation independently"
    return distinct_values[0]


def process_melt_pl(
    df: pl.DataFrame,
    meta_df: pl.DataFrame,
    eval_metric: str = "replicate_reproducibility",
) -> pl.DataFrame:
    """Polars implementation of ``process_melt``: melt a square similarity matrix and
    annotate both members of every pair with their metadata.

    Parameters
    ----------
    df : pl.DataFrame
        Square similarity matrix (for example from ``get_pairwise_metric``). Row and
        column ``i`` are taken to describe the profile in row ``i`` of ``meta_df``;
        matching is positional and the matrix column names are not used.
    meta_df : pl.DataFrame
        Metadata, one row per profile, in the same order as the matrix.
    eval_metric : str, optional
        "replicate_reproducibility" keeps only the strict upper triangle (each
        unordered pair once); any other value keeps every off-diagonal entry.

    Returns
    -------
    pl.DataFrame
        One row per retained, non-NaN pair with the two integer pair-index columns
        from ``set_pair_ids()``, ``similarity_metric``, and every ``meta_df`` column
        twice: once with the pair_a suffix and once with the pair_b suffix.

    Raises
    ------
    AssertionError
        If ``df`` is not square.
    """
    assert df.height == df.width, "Matrix must be symmetrical"

    pair_ids = set_pair_ids()
    a_index = pair_ids["pair_a"]["index"]
    b_index = pair_ids["pair_b"]["index"]

    # Private float copy so masking never modifies the caller's matrix.
    values = np.array(df.to_numpy(), dtype=float)
    if eval_metric == "replicate_reproducibility":
        values[~get_upper_matrix(df)] = np.nan
    else:
        np.fill_diagonal(values, np.nan)

    # Positions of surviving entries, enumerated column by column (melt order):
    # np.nonzero walks values.T row-major, i.e. values column-major.
    col_pos, row_pos = np.nonzero(~np.isnan(values.T))
    melted = pl.DataFrame(
        {
            a_index: row_pos.astype(np.int64),
            b_index: col_pos.astype(np.int64),
            "similarity_metric": values[row_pos, col_pos],
        }
    )

    # Positional key on the metadata; each pair side gets its own suffixed copy.
    row_key = "__meta_row"
    keyed_meta = meta_df.with_row_count(row_key).with_columns(
        pl.col(row_key).cast(pl.Int64)
    )

    def _side(suffix: str) -> pl.DataFrame:
        # Rename every metadata column with the pair-side suffix; the key stays as-is.
        return keyed_meta.rename({c: f"{c}{suffix}" for c in meta_df.columns})

    return melted.join(
        _side(pair_ids["pair_b"]["suffix"]), left_on=b_index, right_on=row_key, how="inner"
    ).join(
        _side(pair_ids["pair_a"]["suffix"]), left_on=a_index, right_on=row_key, how="inner"
    )

def process_melt(
    df: pl.DataFrame,
    meta_df: pl.DataFrame,
    eval_metric: str = "replicate_reproducibility",
) -> pl.DataFrame:
    """Helper function to annotate and process an input similarity matrix

    Parameters
    ----------
    df : polars.DataFrame
        A square similarity matrix output from ``get_pairwise_metric``; its column
        labels must be the integer positions (as strings).
    meta_df : polars.DataFrame
        A wide matrix of metadata information where the row order aligns to the
        similarity matrix rows
    eval_metric : str, optional
        Which metric to ultimately calculate. Determines whether or not to keep the full
        similarity matrix or only one diagonal. Defaults to "replicate_reproducibility".

    Returns
    -------
    polars.DataFrame
        A melted pairwise similarity table with the two pair index columns,
        ``similarity_metric`` and suffixed metadata for both pair members.

    Raises
    ------
    AssertionError
        If ``df`` is not square.
    """
    pair_ids = set_pair_ids()
    a_index = pair_ids["pair_a"]["index"]
    b_index = pair_ids["pair_b"]["index"]

    sim_pd = df.to_pandas()
    meta_pd = meta_df.to_pandas()
    assert sim_pd.shape[0] == sim_pd.shape[1], "Matrix must be symmetrical"

    # Mask on an explicit numpy copy: writing through DataFrame.values is not
    # guaranteed to reach the frame (and fails under pandas copy-on-write).
    #   "replicate_reproducibility" - keep only the strict upper triangle
    #   anything else               - keep the full matrix without its diagonal
    values = sim_pd.to_numpy(dtype=float, copy=True)
    if eval_metric == "replicate_reproducibility":
        values[~get_upper_matrix(sim_pd)] = np.nan
    else:
        np.fill_diagonal(values, np.nan)
    masked = pd.DataFrame(values, index=sim_pd.index, columns=sim_pd.columns)

    # Convert pairwise matrix to melted (long) version based on index value
    metric_unlabeled_df = (
        pd.melt(
            masked.reset_index(),
            id_vars="index",
            value_vars=masked.columns,
            var_name=b_index,
            value_name="similarity_metric",
        )
        .dropna()
        .reset_index(drop=True)
        .rename({"index": a_index}, axis="columns")
    )

    # Column labels arrive as strings; cast so they can be joined to the metadata index.
    metric_unlabeled_df[b_index] = metric_unlabeled_df[b_index].astype("int64")

    # Inner merge attaches pair_b metadata; outer merge attaches pair_a metadata and
    # applies the suffixes to every overlapping metadata column.
    output_df = meta_pd.merge(
        meta_pd.merge(
            metric_unlabeled_df,
            left_index=True,
            right_on=b_index,
        ),
        left_index=True,
        right_on=a_index,
        suffixes=[pair_ids["pair_a"]["suffix"], pair_ids["pair_b"]["suffix"]],
    ).reset_index(drop=True)

    return pl.DataFrame(output_df)


def get_upper_matrix(df: "pd.DataFrame | pl.DataFrame", batch_size: int = 10000) -> np.ndarray:
    """
    Return a boolean mask that is True strictly above the main diagonal.

    The mask is filled one band of rows at a time, so only a
    ``batch_size x ncols`` boolean temporary exists at once.

    Parameters
    ----------
    df : pandas.DataFrame or polars.DataFrame
        Only ``df.shape`` is used, so any object with a 2-D ``shape`` works.
    batch_size : int
        Number of rows filled per step; must be at least 1.

    Returns
    -------
    np.ndarray
        Boolean array of shape ``df.shape`` with ``mask[i, j] == (j > i)``.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    nrows, ncols = df.shape
    upper_matrix = np.zeros((nrows, ncols), dtype=bool)
    col_ids = np.arange(ncols)

    for start_row in range(0, nrows, batch_size):
        end_row = min(start_row + batch_size, nrows)
        # Compare GLOBAL row/column positions so bands away from the diagonal are
        # right too: everything right of the diagonal True, left of it False.
        row_ids = np.arange(start_row, end_row)[:, None]
        upper_matrix[start_row:end_row] = col_ids > row_ids

    return upper_matrix


## Set paths and load functions

In [None]:
# Root of the SPECS3K DeepProfiler project on the shared drive.
PROJECT_ROOT = "/share/data/analyses/benjamin/Single_cell_project/DP_specs3k/"
# DeepProfiler feature sub-directory, relative to PROJECT_ROOT.
# NOTE: the feature-loading cell reassigns this to "outputs/results/parquets".
featdir = "outputs/results/features"
# Regularisation parameter (not referenced in this section of the notebook).
REG_PARAM = 1e-2


In [None]:
def make_plot_custom(embedding, colouring, save_dir=False, file_name="file_name", name="Emb type", description="details"):
    """Draw a seaborn scatter plot of a 2-D embedding coloured by ``colouring``.

    Parameters
    ----------
    embedding : pandas.DataFrame
        Must contain "UMAP 1", "UMAP 2" and the ``colouring`` column. A copy is
        plotted, so the caller's frame is left unchanged.
    colouring : str
        Column used for the hue; the "DMSO_0.1%" level is drawn small and light grey.
    save_dir : str or False
        Path prefix for saving; a falsy value disables saving. The PNG is written to
        ``{save_dir}{file_name}{name}.png`` and the PDF to
        ``{save_dir}pdf_format/{file_name}{name}.pdf``.
    file_name, name : str
        Used in the figure title and the output file names.
    description : str
        Second line of the axes title.
    """
    # White background, large high-dpi figure.
    sns.set(style="whitegrid", rc={"figure.figsize": (18, 12), 'figure.dpi': 300, "axes.facecolor": "white", "grid.color": "white"})

    # Plot a copy so the helper 'size' column does not leak into the caller's frame.
    embedding = embedding.copy()

    # First-appearance order keeps colour assignment stable between runs
    # (iteration order of a set of strings is hash-dependent).
    unique_treatments = list(pd.unique(embedding[colouring]))
    custom_palette = sns.color_palette("hls", len(unique_treatments))
    color_dict = dict(zip(unique_treatments, custom_palette))

    # Make the DMSO control group grey
    if "DMSO_0.1%" in color_dict:
        color_dict["DMSO_0.1%"] = "lightgrey"

    # Smaller markers for the control so treatments stand out
    size_dict = {treatment: 20 if treatment != "DMSO_0.1%" else 8 for treatment in unique_treatments}
    embedding['size'] = embedding[colouring].map(size_dict)

    sns_plot = sns.scatterplot(data=embedding, x="UMAP 1", y="UMAP 2", hue=colouring, size='size', palette=color_dict, sizes=(8, 25), linewidth=0.1, alpha=0.9)

    plt.suptitle(f"{name}_{file_name}", fontsize=16)
    sns_plot.tick_params(labelbottom=False, labelleft=False, bottom=False, left=False)
    sns_plot.set_title("CLS Token embedding of "+str(len(embedding))+" cells" + " \n"+description, fontsize=12)
    sns.move_legend(sns_plot, "lower left", title='Treatments', prop={'size': 10}, title_fontsize=12, markerscale=0.5)

    # Remove axis spines
    sns.despine(bottom=True, left=True)

    # save_dir is a path prefix, so test truthiness (comparing to True never matched a path)
    if save_dir:
        sns_plot.figure.savefig(f"{save_dir}{file_name}{name}.png", dpi=600)
        sns_plot.figure.savefig(f"{save_dir}pdf_format/{file_name}{name}.pdf", dpi=600)

    plt.show()

import plotly.express as px
import plotly.graph_objects as go

def make_plot_custom_plotly(embedding, colouring, save_dir=False, file_name="file_name", name="Emb type", description="details"):
    """Interactive Plotly scatter of a 2-D UMAP embedding coloured by ``colouring``.

    Parameters
    ----------
    embedding : pandas.DataFrame or polars.DataFrame
        Must contain "UMAP1", "UMAP2" and the ``colouring`` column. Polars input is
        converted to pandas; pandas input is copied, so the caller's frame is not
        modified.
    colouring : str
        Column used for colour; "DMSO_0.1%" is drawn smaller and light grey.
    save_dir : str or False
        Path prefix for saving; falsy disables saving. Writes
        ``{save_dir}{file_name}{name}.png`` and
        ``{save_dir}pdf_format/{file_name}{name}.pdf``.
    file_name, name, description : str
        Used in the title and the output file names.
    """
    if isinstance(embedding, pl.DataFrame):
        embedding = embedding.to_pandas()
    else:
        embedding = embedding.copy()

    # First-appearance order keeps colours stable between runs (set order is hash-dependent).
    unique_treatments = list(pd.unique(embedding[colouring]))
    custom_palette = px.colors.qualitative.Set1[:len(unique_treatments)]
    # Set1 has a fixed number of colours; treatments beyond it get Plotly's defaults.
    color_dict = dict(zip(unique_treatments, custom_palette))

    # Make the DMSO control group grey
    if "DMSO_0.1%" in color_dict:
        color_dict["DMSO_0.1%"] = "lightgrey"

    # Smaller markers for the control
    size_dict = {treatment: 6 if treatment != "DMSO_0.1%" else 3 for treatment in unique_treatments}
    embedding['size'] = embedding[colouring].map(size_dict)

    fig = px.scatter(embedding, x="UMAP1", y="UMAP2", color=colouring, size='size', color_discrete_map=color_dict, size_max=6, opacity=0.9, width=1000, height=700)

    fig.update_layout(
        title={
            # Plotly titles break lines with <br>; a literal "\n" is not rendered.
            'text': f"{name} {file_name} - UMAP embedding of {len(embedding)} points <br>{description}",
            'y': 0.95,
            'x': 0.5,
            'xanchor': 'center',
            'yanchor': 'top'},
        plot_bgcolor='white',
        showlegend=True,
        xaxis_title="UMAP 1",
        yaxis_title="UMAP 2",
        xaxis_showgrid=False,
        yaxis_showgrid=False
    )

    # Remove axis ticks and labels
    fig.update_xaxes(showticklabels=False, zeroline=False)
    fig.update_yaxes(showticklabels=False, zeroline=False)

    if save_dir:
        fig.write_image(f"{save_dir}{file_name}{name}.png")
        fig.write_image(f"{save_dir}pdf_format/{file_name}{name}.pdf")

    fig.show()


import plotly.express as px
import plotly.io as pio

def make_plot_custom_plotly_print(embedding, colouring, save_dir=False, file_name="file_name", name="Emb type", description="details"):
    """Publication-style Plotly scatter of a 2-D embedding with three fixed groups.

    Points are split into the DMSO control ("DMSO_0.1%"), labels containing "_1"
    and labels containing "_5" (plain substring matches, so e.g. "_10" also counts
    as "_1"). Each group becomes one trace, named after the first matching label and
    coloured by ``get_color`` of that label. Groups with no points are skipped.

    Parameters
    ----------
    embedding : pandas.DataFrame
        Needs "UMAP 1", "UMAP 2" and the ``colouring`` column; it is not modified.
    colouring : str
        Label column used for grouping.
    save_dir : any
        If truthy, the figure is written to "{file_name}{name}.pdf" in the current
        working directory (the value itself is not part of the path) and is not
        shown; otherwise the figure is shown.
    file_name, name, description : str
        Used in the title and the output file name.
    """

    # Colour by concentration suffix in the label
    def get_color(val):
        if "_1" in val:
            return "#28B6D2"
        elif "_5" in val:
            return "#e96565"
        elif val == "DMSO_0.1%":
            return "lightgrey"
        else:
            return "grey"  # A default color for treatments that don't fit any of the conditions

    # Work on a copy so the helper 'size' column does not leak into the caller's frame.
    embedding = embedding.copy()
    labels = embedding[colouring]
    embedding['size'] = labels.map(lambda treatment: 8 if treatment != "DMSO_0.1%" else 5)

    fig = go.Figure()
    # Drawing order: DMSO first (underneath), then the "_1" and "_5" groups.
    group_masks = [
        labels == "DMSO_0.1%",
        labels.str.contains("_1", na=False),
        labels.str.contains("_5", na=False),
    ]
    for mask in group_masks:
        subset = embedding[mask]
        if subset.empty:
            continue
        label = subset[colouring].unique()[0]
        fig.add_trace(
            go.Scatter(
                x=subset["UMAP 1"],
                y=subset["UMAP 2"],
                mode="markers",
                marker=dict(
                    color=get_color(label),
                    size=subset['size']
                ),
                opacity=1,
                name=str(label)
            )
        )

    fig.update_layout(
        width=1200,
        height=900,
        title={
            # Plotly titles break lines with <br>; a literal "\n" is not rendered.
            'text': f"{name}_{file_name} embedding of {len(embedding)} cells <br>{description}",
            'y': 0.95,
            'x': 0.5,
            'xanchor': 'center',
            'yanchor': 'top'},
        plot_bgcolor='white',
        showlegend=True,
        xaxis_title="UMAP 1",
        yaxis_title="UMAP 2",
        xaxis_showgrid=False,
        yaxis_showgrid=False
    )

    # Remove axis ticks and labels
    fig.update_xaxes(showticklabels=False, zeroline=False)
    fig.update_yaxes(showticklabels=False, zeroline=False)

    # Either save (high-resolution PDF) or show, never both
    if save_dir:
        pio.write_image(fig, f"{file_name}{name}.pdf", scale=3)
    else:
        fig.show()

import seaborn as sns
import matplotlib.pyplot as plt

import datetime

def make_jointplot(embedding, colouring, cmpd, save_path=None):
    """Joint UMAP scatter plot with a marginal KDE per category of ``colouring``.

    Parameters
    ----------
    embedding : pandas.DataFrame
        Needs "UMAP1", "UMAP2" and the ``colouring`` column; plotted from a copy, so
        the caller's frame is not modified.
    colouring : str
        Category column. When it is 'Metadata_cmpdName', DMSO points (labelled
        'DMSO' or '[DMSO]') are drawn light grey and more transparent.
    cmpd : str
        Title of the joint axes.
    save_path : str, optional
        If given, the figure is saved to ``f"{save_path}.png"`` at 300 dpi.
    """
    embedding = embedding.copy()

    # One Set2 colour per category
    unique_treatments = embedding[colouring].unique()
    palette = sns.color_palette("Set2", len(unique_treatments))
    color_map = dict(zip(unique_treatments, palette))

    # DMSO appears as '[DMSO]' in the SPECS metadata and as 'DMSO' elsewhere.
    dmso_labels = {"DMSO", "[DMSO]"}
    colour_by_compound = colouring == 'Metadata_cmpdName'
    if colour_by_compound:
        for label in dmso_labels.intersection(color_map):
            color_map[label] = 'lightgrey'

    embedding['color'] = embedding[colouring].map(color_map)
    embedding['size'] = 10

    # Increase the DPI for displaying
    plt.rcParams['figure.dpi'] = 300

    g = sns.JointGrid(x='UMAP1', y='UMAP2', data=embedding, height=10)

    for treatment in unique_treatments:
        subset = embedding[embedding[colouring] == treatment]
        # Marginal density per category
        sns.kdeplot(x=subset["UMAP1"], ax=g.ax_marg_x, fill=True, color=color_map[treatment], legend=False)
        sns.kdeplot(y=subset["UMAP2"], ax=g.ax_marg_y, fill=True, color=color_map[treatment], legend=False)
        # Fade the control so treated points stay visible
        alpha_val = 0.3 if colour_by_compound and treatment in dmso_labels else 0.5
        g.ax_joint.scatter(subset["UMAP1"], subset["UMAP2"], c=subset['color'], s=subset['size'], label=treatment, alpha=alpha_val, edgecolor='white', linewidth=0.5)

    g.ax_joint.set_title(cmpd)
    legend = g.ax_joint.legend(fontsize=10)
    legend.get_frame().set_facecolor('white')

    # Save before showing: an inline backend may close the figure on show().
    if save_path is not None:
        g.savefig(f"{save_path}.png", dpi=300)

    plt.show()

def join_plot_dmso(embedding, plate_column, save_path=None, dmso_label="DMSO"):
    """Joint UMAP plot colouring DMSO points by plate and everything else grey.

    Parameters
    ----------
    embedding : pandas.DataFrame
        Needs "UMAP1", "UMAP2", "Metadata_cmpdName" and ``plate_column``; plotted
        from a copy, so the caller's frame is not modified.
    plate_column : str
        Column whose values define the per-plate colours and marginal KDEs.
    save_path : str, optional
        If truthy, the figure is saved there at 300 dpi.
    dmso_label : str, optional
        ``Metadata_cmpdName`` value marking DMSO points. Defaults to "DMSO"; the
        SPECS metadata loaded in this notebook spells it "[DMSO]".
    """
    other_color = 'lightgrey'
    point_size = 20

    embedding = embedding.copy()
    is_dmso = embedding['Metadata_cmpdName'] == dmso_label

    # One husl colour per plate that contains DMSO
    unique_plates = embedding.loc[is_dmso, plate_column].unique()
    palette = sns.color_palette("husl", len(unique_plates))
    plate_color_map = dict(zip(unique_plates, palette))

    # Vectorised colour/size assignment (DMSO: plate colour, large; others: grey, small)
    embedding['color'] = embedding[plate_column].map(plate_color_map).where(is_dmso, other_color)
    embedding['size'] = np.where(is_dmso, 50, point_size)

    # Increase the DPI for displaying
    plt.rcParams['figure.dpi'] = 300

    g = sns.JointGrid(x='UMAP1', y='UMAP2', data=embedding, height=10)

    # Grey background layer for all non-DMSO points
    others = embedding[~is_dmso]
    g.ax_joint.scatter(others["UMAP1"], others["UMAP2"], c=other_color, s=others['size'], label="Others", alpha=0.5, edgecolor='white', linewidth=0.5)

    # Per-plate DMSO marginal KDEs and scatter points on top
    dmso = embedding[is_dmso]
    for plate in unique_plates:
        subset = dmso[dmso[plate_column] == plate]
        sns.kdeplot(x=subset["UMAP1"], ax=g.ax_marg_x, fill=True, color=plate_color_map[plate], legend=False)
        sns.kdeplot(y=subset["UMAP2"], ax=g.ax_marg_y, fill=True, color=plate_color_map[plate], legend=False)
        g.ax_joint.scatter(subset["UMAP1"], subset["UMAP2"], c=subset['color'], s=subset['size'], label=plate, alpha=0.7, edgecolor='white', linewidth=0.5)

    g.ax_joint.set_title('DMSO by Plate')
    g.ax_joint.legend()

    # Save before showing: an inline backend may close the figure on show().
    if save_path:
        g.savefig(save_path, dpi=300)

    plt.show()

def run_umap(df, feat, n_neighbors=None, min_dist=None, metadata=None):
    """Embed ``df[feat]`` in 2-D with UMAP and attach metadata.

    Parameters
    ----------
    df : pandas.DataFrame
        Profiles with the ``feat`` columns plus 'Plate', 'Well' and 'Site'.
    feat : list of str
        Feature columns fed to UMAP.
    n_neighbors, min_dist : optional
        UMAP hyper-parameters. When omitted they fall back to the notebook globals
        ``n_neighbors_value`` / ``min_dist_value``.
    metadata : pandas.DataFrame, optional
        Merged on Plate/Well/Site; falls back to the notebook global ``meta``.

    Returns
    -------
    pandas.DataFrame
        Columns "UMAP1", "UMAP2", the id columns and the merged metadata.
    """
    # Fall back to the notebook globals so existing calls behave as before.
    if n_neighbors is None:
        n_neighbors = n_neighbors_value
    if min_dist is None:
        min_dist = min_dist_value
    if metadata is None:
        metadata = meta

    id_cols = ['Plate', 'Well', 'Site']
    # NOTE(review): `umap` is not imported in the visible notebook cells
    # (`#from umap import UMAP` is commented out) — confirm it is imported elsewhere.
    reducer = umap.UMAP(n_neighbors=n_neighbors, metric="euclidean", min_dist=min_dist)
    embedding = reducer.fit_transform(df[feat])

    umap_df = pd.DataFrame(data=embedding, columns=["UMAP1", "UMAP2"])
    # Embedding rows follow df's row order; align positionally via a fresh index.
    umap_df[id_cols] = df[id_cols].reset_index(drop=True)
    return pd.merge(umap_df, metadata, on=id_cols)



def find_file_with_string(directory, string):
    """
    Find a file in the specified directory whose name contains the given string.

    Entries are scanned in sorted name order and only regular files are
    considered, so the result is deterministic and never a sub-directory.

    Args:
    directory (str): The directory to search in.
    string (str): The string to look for in the file names.

    Returns:
    str: The path to the first matching file in sorted name order, or None if the
    directory does not exist or no file matches.
    """
    if not os.path.exists(directory):
        print(f"The directory {directory} does not exist.")
        return None

    # os.listdir order is arbitrary; sort so repeated runs pick the same file.
    for file in sorted(os.listdir(directory)):
        path = os.path.join(directory, file)
        if string in file and os.path.isfile(path):
            return path

    return None


## Load metadata and features

In [None]:
# Read the DeepProfiler metadata, de-duplicate and order it by well/site.
meta = (
    pd.read_csv(os.path.join(PROJECT_ROOT, "inputs", "metadata", "Metadata_specs3k_DeepProfiler.csv"))
    .drop_duplicates()
    .sort_values(by=['Metadata_Well', 'Metadata_Site'])
)
# Normalise compound names to upper case and build a "NAME conc" label.
meta['Metadata_cmpdName'] = meta['Metadata_cmpdName'].str.upper()
meta["Metadata_cmpdNameConc"] = meta["Metadata_cmpdName"] + " " + meta["Metadata_cmpdConc"].astype(str)
# Polars copy without index artefacts and channel-path columns, de-duplicated.
meta_pl = (
    pl.DataFrame(meta)
    .drop('Unnamed: 0.1', 'Unnamed: 0', "AR", "ER", "RNA", "AGP", "DNA", "Mito")
    .unique()
)

In [None]:
# Rows whose compound name contains '[' (vectorised literal match instead of a
# per-row Python lambda).
validation = meta_pl.filter(pl.col('Metadata_cmpdName').str.contains('[', literal=True)).unique()

In [None]:
import tqdm
DMSO_plates = ["P101385", "P101386", "P101387"]
# Plates holding validation compounds, minus the DMSO-only plates. Done as a set
# difference on Python values: subtracting a list from a Polars Series would try
# element-wise arithmetic. Sorted for a deterministic processing order.
plates = sorted(
    {p for p in validation["Metadata_Plate"].unique().to_list() if p is not None}
    - set(DMSO_plates)
)
featdir = "outputs/results/parquets"
plate_frames = []
for p in tqdm.tqdm(plates):
    parquet_path = find_file_with_string(os.path.join(PROJECT_ROOT, featdir), p)
    if parquet_path is None:
        raise FileNotFoundError(
            f"No feature parquet containing '{p}' in {os.path.join(PROJECT_ROOT, featdir)}"
        )
    feature_df = pl.read_parquet(parquet_path)
    # Keep only validation wells/sites and attach their metadata.
    feature_df = feature_df.join(
        validation,
        on=['Metadata_Plate', 'Metadata_Well', 'Metadata_Site'],
        how='inner')
    plate_frames.append(feature_df)
# Concatenate once at the end; concatenating inside the loop copies everything each time.
master_df = pl.concat(plate_frames) if plate_frames else pl.DataFrame()

### Compounds with multiple MOAs

In [None]:
# Site/plate/well/compound keys that occur on more than one metadata row
# (presumably compounds annotated with several MOAs).
duplicates = (
    meta_pl
    .groupby(['Site', 'Plate', 'Well', 'Metadata_cmpdName'])
    .agg(pl.count())
    .filter(pl.col('count') > 1)
)
print(duplicates)

In [None]:
# Full metadata rows for every duplicated key.
duplicate_rows = meta_pl.join(
    duplicates,
    on=['Site', 'Plate', 'Well', 'Metadata_cmpdName'],
)
# Sorted view shown as the cell output (the sorted frame is not stored).
duplicate_rows.sort(['Site', 'Plate', 'Well'])

In [None]:
# DeepProfiler feature columns contain "Feature"; every other column is metadata.
features = [name for name in master_df.columns if "Feature" in name]
meta_features = [name for name in master_df.columns if "Feature" not in name]

## Run normalization (per-plate MAD robustize and feature selection)

In [None]:
import scipy.stats
from sklearn.preprocessing import StandardScaler, RobustScaler
import pycytominer as pm

def prep_data(df1, features, meta_features, plate,
              control_query="Metadata_cmpdName == '[DMSO]'", corr_threshold=0.8):
    """MAD-normalise one plate and report which features selection would drop there.

    Parameters
    ----------
    df1 : pandas.DataFrame
        Single-cell profiles, possibly spanning several plates.
    features, meta_features : list of str
        Feature and metadata columns passed to pycytominer.
    plate : str
        Value of ``Metadata_Plate`` to process.
    control_query : str, optional
        pandas query selecting the normalisation reference cells (default: DMSO).
    corr_threshold : float, optional
        Correlation threshold for pycytominer's ``correlation_threshold`` operation.

    Returns
    -------
    data_norm : pandas.DataFrame
        MAD-robustized profiles for ``plate`` with ALL features kept; feature
        selection is not applied to this frame.
    removed_cols : set of str
        Columns the correlation-threshold / drop-NA selection removed on this plate.
    """
    # Restrict to the plate first so only that plate's rows are copied.
    plate_df = df1.loc[df1['Metadata_Plate'] == plate]
    # Exclude empty and untreated wells from normalisation.
    keep = ~plate_df['Metadata_cmpdName'].isin(['BLANK', 'UNTREATED', 'null'])
    temp1 = plate_df.loc[keep].copy()

    data_norm = pm.normalize(profiles=temp1, features=features, meta_features=meta_features,
                             samples=control_query, method="mad_robustize")
    print("Feature selection starts, shape:", data_norm.shape)
    df_selected = pm.feature_select(data_norm, features=features,
                                    operation=['correlation_threshold', 'drop_na_columns'],
                                    corr_threshold=corr_threshold)
    print('Number of columns removed:', data_norm.shape[1] - df_selected.shape[1])
    # Only the names of the dropped columns are returned; the caller drops their
    # union across all plates so every plate keeps the same feature set.
    removed_cols = set(data_norm.columns) - set(df_selected.columns)
    return data_norm, removed_cols

In [None]:
# Distinct compound names present in the loaded features.
comp = master_df["Metadata_cmpdName"].unique().to_list()

In [None]:
# Normalize every plate separately (except the excluded plates below) and
# record which columns feature selection would drop on each plate.
# Import the module under a private name: elsewhere in the notebook `tqdm` is
# bound to either the function or the module depending on which cell ran last.
import tqdm as _tqdm_module

excluded_plates = {"P101384", "P101385", "P101386", "P101387"}
plate = [p for p in master_df["Metadata_Plate"].unique() if p not in excluded_plates]
drop_cols = {}
normalized_plates = []
for p in _tqdm_module.tqdm(plate):
    plate_pd = master_df.filter(pl.col('Metadata_Plate') == p).to_pandas()
    temp_processed, dropped_cols = prep_data(plate_pd, features, meta_features, p)
    drop_cols[p] = list(dropped_cols)
    normalized_plates.append(pl.from_pandas(temp_processed))

# Concatenate once at the end: growing a DataFrame inside the loop copies all
# previously processed plates on every iteration.
mad_norm_df = pl.concat(normalized_plates)
mad_norm_df = mad_norm_df.filter(pl.col("Metadata_cmpdName").is_not_null())

In [None]:
# Drop, from every plate, each column that feature selection removed on at
# least one plate, so all plates share the same feature set.
cols_to_drop = list(set().union(*drop_cols.values()))
mad_norm_df = mad_norm_df.drop(cols_to_drop)
# The remaining non-metadata columns are the selected features.
features_fixed = [col for col in mad_norm_df.columns if col not in meta_features]

In [None]:
# Persist the normalized single-cell profiles (absolute, user-specific path).
# NOTE(review): the reload cell reads 'Results/sc_profiles_normalized_SPECS3K.parquet',
# a different path -- confirm both refer to the same data.
mad_norm_df.write_parquet('/home/jovyan/share/data/analyses/benjamin/Single_cell_project_rapids/sc_profiles_normalized.parquet')

In [None]:
# Reload the normalized profiles instead of recomputing them, exclude plate
# P101384 and rebuild the feature list from the column names.
mad_norm_df = pl.read_parquet('Results/sc_profiles_normalized_SPECS3K.parquet')
mad_norm_df = mad_norm_df.filter(pl.col('Metadata_Plate') != "P101384")
features_fixed = [f for f in mad_norm_df.columns if "Feature" in f]

## Analyse aggregated

In [None]:
# Well-level profiles: mean of every selected feature per
# (plate, well, compound).
aggregated_df_norm = (
    mad_norm_df
    .groupby(['Metadata_Plate', 'Metadata_Well', 'Metadata_cmpdName'])
    .agg([pl.col(feature).mean().alias(feature) for feature in features_fixed])
)

In [None]:
# Inspect the well-level profiles.
aggregated_df_norm

## Harmony batch normalization

In [None]:
import anndata as ad
import scanpy as sc 
from harmonypy import run_harmony
def create_anndata(df, aggregated, feature_cols=None, meta_cols=None):
    """Build an AnnData object from a profile table.

    Parameters
    ----------
    df : polars.DataFrame or pandas.DataFrame
        Profiles; polars input is converted to pandas first.
    aggregated : bool
        Chooses the default metadata columns: ``['Plate', 'Well', 'Site',
        'Metadata_cmpdName']`` when True, otherwise the notebook-level
        ``meta_features`` list.
    feature_cols : list of str, optional
        Columns stored in ``adata.X``.  Defaults to the notebook-level
        ``features`` list (the previous implicit behaviour).  Pass the list
        that matches ``df`` -- e.g. ``features_fixed`` once correlated
        columns have been dropped, since ``features`` still lists them.
    meta_cols : list of str, optional
        Columns stored in ``adata.obs``; overrides the ``aggregated`` default.

    Returns
    -------
    anndata.AnnData
        ``X`` holds the feature values, ``obs`` the metadata columns.
    """
    if isinstance(df, pl.DataFrame):
        df = df.to_pandas()

    data_columns = features if feature_cols is None else feature_cols
    if meta_cols is not None:
        meta_data_columns = meta_cols
    elif aggregated:
        meta_data_columns = ['Plate', 'Well', 'Site', 'Metadata_cmpdName']
    else:
        meta_data_columns = meta_features

    X = df[data_columns].values
    obs = df[meta_data_columns]
    return ad.AnnData(X=X, obs=obs)

def harmony(adata: ad.AnnData, batch_key: str,
            corrected_embed: str, pca=True, nclust=None):
    '''Run Harmony batch correction and store the corrected embedding.

    Parameters
    ----------
    adata : AnnData
        Input data; ``adata.obs`` must contain ``batch_key``.
    batch_key : str
        ``adata.obs`` column(s) Harmony corrects for.
    corrected_embed : str
        Key under which the corrected embedding is written to ``adata.obsm``.
    pca : bool
        If True, compute a PCA with scanpy first (written to
        ``adata.obsm['X_pca']``) and run Harmony on it; otherwise run Harmony
        on ``adata.X`` directly.
    nclust : int, optional
        Number of Harmony clusters.  Defaults to ``len(comp)``, the number of
        compounds in the notebook-level ``comp`` list.

    Returns
    -------
    AnnData
        The same object, modified in place.
    '''
    if nclust is None:
        nclust = len(comp)  # one cluster per compound
    if pca:
        n_latent = min(adata.shape) - 1  # required for arpack
        print('Computing PCA...')
        # sc.tl.pca works in place and returns None (copy=False), so the
        # embedding must be read back from adata.obsm.
        sc.tl.pca(adata, n_comps=n_latent)
        data_for_harmony = adata.obsm['X_pca']
        print('Computing PCA Done.')
    else:
        data_for_harmony = adata.X

    harmony_out = run_harmony(data_for_harmony,
                              adata.obs,
                              batch_key,
                              max_iter_harmony=20,
                              nclust=nclust)
    # harmonypy keeps cells in the columns of Z_corr; AnnData wants cells as rows.
    adata.obsm[corrected_embed] = harmony_out.Z_corr.T
    return adata

In [None]:
def obsm_to_df(adata: "ad.AnnData", obsm_key: str,
               columns: Union[List[str], None] = None) -> pd.DataFrame:
    '''Convert an AnnData embedding plus its ``obs`` metadata into one DataFrame.

    Parameters
    ----------
    adata : AnnData
        Object with metadata in ``adata.obs`` and the embedding in
        ``adata.obsm[obsm_key]``.
    obsm_key : str
        Key of the embedding to export.
    columns : list of str, optional
        Names for the embedding columns.  When None or empty they are
        generated as ``f"{obsm_key}_0000"``, ``f"{obsm_key}_0001"``, ...

    Returns
    -------
    pandas.DataFrame
        The ``obs`` columns followed by the embedding columns, with a fresh
        RangeIndex.
    '''
    # reset_index gives both frames a RangeIndex, so concat pairs rows by position.
    meta = adata.obs.reset_index(drop=True)
    feats = adata.obsm[obsm_key]
    # Explicit length check: `not columns` is ambiguous for numpy arrays/Index.
    if columns is None or len(columns) == 0:
        columns = [f'{obsm_key}_{i:04d}' for i in range(feats.shape[1])]
    data = pd.DataFrame(feats, columns=columns)
    return pd.concat([meta, data], axis=1)

## Show the number of wells per compound and plate

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
# Number of wells per (compound, plate) in the well-aggregated profiles,
# one bar chart per compound.
aggregated_cmpd = aggregated_df_norm.groupby(['Metadata_cmpdName', 'Metadata_Plate']).count()
aggregated_cmpd = aggregated_cmpd.to_pandas()
aggregated_cmpd['Metadata_Plate'] = aggregated_cmpd['Metadata_Plate'].astype('category')

unique_compounds = aggregated_cmpd['Metadata_cmpdName'].unique()
n_compounds = len(unique_compounds)

# squeeze=False keeps `axes` 2-D even for a single compound; plt.subplots
# otherwise returns a bare Axes, which cannot be indexed.
fig, axes = plt.subplots(n_compounds, 1, figsize=(20, 30), sharex=True, squeeze=False)
axes = axes[:, 0]

for ax, cmpd in zip(axes, unique_compounds):
    subset = aggregated_cmpd[aggregated_cmpd['Metadata_cmpdName'] == cmpd]
    sns.barplot(ax=ax, data=subset, x='Metadata_Plate', y='count', hue='Metadata_Plate', width=2.5)
    ax.set_title(f'Compound: {cmpd}')
    ax.set_ylabel('Count')
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()

# Shared x-axis labelling (applies to the bottom subplot).
plt.xlabel('Metadata Plate')
plt.xticks(rotation=45)
plt.tight_layout()
plt.savefig('reference_distributions.png', dpi=300)


## Well-level UMAP (aggregated profiles)

In [None]:
import cuml
#import umap
def run_umap_and_merge(df, features, min_dist=0.1, n_components=2, metric='cosine',
                       n_neighbors=15, random_state=42):
    """Embed ``df[features]`` with cuML UMAP and attach the metadata columns.

    Parameters
    ----------
    df : polars.DataFrame
        Profiles with feature and metadata columns.
    features : list of str
        Columns used as UMAP input; every other column is kept as metadata.
    min_dist, n_components, metric, n_neighbors, random_state
        Passed to ``cuml.UMAP``; the defaults reproduce the previous
        hard-coded settings.

    Returns
    -------
    polars.DataFrame
        Metadata columns followed by ``UMAP1`` ... ``UMAP<n_components>``.
    """
    feature_data = df.select(features).to_pandas()
    feature_set = set(features)
    meta_data = df.select([col for col in df.columns if col not in feature_set])

    print("Starting UMAP")
    umap_model = cuml.UMAP(n_neighbors=n_neighbors, min_dist=min_dist,
                           n_components=n_components, metric=metric,
                           random_state=random_state)
    umap_embedding = umap_model.fit_transform(feature_data)

    umap_df = pl.DataFrame(umap_embedding)
    # Rename every embedding column so any n_components works (not just 2).
    umap_df = umap_df.rename(
        {old: f"UMAP{i + 1}" for i, old in enumerate(umap_df.columns)}
    )
    return pl.concat([meta_data, umap_df], how="horizontal")


In [None]:
# UMAP of the well-level normalized profiles (cosine metric).
aggregated_umap = run_umap_and_merge(aggregated_df_norm, features_fixed, metric = 'cosine')

In [None]:
# Plot the UMAP coloured by compound.
# NOTE(review): `make_jointplot` is not defined in this part of the notebook.
make_jointplot(aggregated_umap.to_pandas(), "Metadata_cmpdName", cmpd ="")

## PCA

In [None]:
# 2-component PCA of site-averaged features, coloured by compound.
# NOTE(review): `grouped_features`, `feature_columns` and `meta` are not defined
# in this part of the notebook; this cell looks left over from an earlier
# version of the analysis.
pca = PCA(n_components=2)
principal_components = pca.fit_transform(grouped_features[feature_columns])

pca_df = pd.DataFrame(data=principal_components, columns=["Principal Component 1", "Principal Component 2"])
pca_df[['Plate', 'Well', 'Site']] = grouped_features[['Plate', 'Well', 'Site']]
pca_df = pd.merge(pca_df, meta, on=["Plate", "Well", "Site"])

plt.figure(figsize=(10, 8))
sns.scatterplot(x="Principal Component 1", y="Principal Component 2", hue="Metadata_cmpdName", data=pca_df, palette="rainbow")
# The hue is the compound, so the title says so (it previously said "Well").
plt.title("PCA of Mean Features Colored by Compound")
plt.show()

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# UMAP hyper-parameters used by `run_umap` below.
n_neighbors_value = 35
min_dist_value = 0.1
#mask = ~df['Metadata_cmpdName'].isin(['blank', 'UNTREATED'])

# Use the mask to filter the DataFrame and keep only the desired rows
#filtered_df = df
#grouped_features = filtered_df.groupby(['Plate','Well', 'Site'])[feature_columns].mean().reset_index()
# Columns of the Harmony-corrected table (names containing "harm").
# NOTE(review): `harmony_df` is not defined in this part of the notebook;
# presumably it is built with `obsm_to_df` from a Harmony-corrected AnnData.
harmony_features = [col for col in harmony_df.columns if "harm" in col]

def run_umap(df, feat):
    """UMAP-embed ``df[feat]`` (cosine metric) and merge the result with ``meta``.

    ``df`` must be a pandas DataFrame with 'Plate', 'Well' and 'Site' columns;
    the two UMAP coordinates are returned as 'UMAP1'/'UMAP2' merged with the
    notebook-level ``meta`` table on those keys.

    NOTE(review): `UMAP` is only imported in a commented-out line
    (`#from umap import UMAP`) and `meta` is not defined in this part of the
    notebook, so this raises NameError unless both are provided elsewhere.
    """
    reducer = UMAP(n_neighbors=n_neighbors_value, metric = "cosine", min_dist=min_dist_value)
    embedding = reducer.fit_transform(df[feat])

    umap_df = pd.DataFrame(data=embedding, columns=["UMAP1", "UMAP2"])
    # Positional copy of the keys: assumes `df` has a default RangeIndex.
    umap_df[['Plate', 'Well', 'Site']] = df[['Plate', 'Well', 'Site']]

    umap_df = pd.merge(umap_df, meta, on=["Plate", "Well", "Site"])
    return umap_df

In [None]:
# Re-draw the well-level UMAP coloured by compound.
# NOTE(review): `make_jointplot` is not defined in this part of the notebook.
make_jointplot(aggregated_umap.to_pandas(), "Metadata_cmpdName", cmpd ="")

## Run PCA on features

In [None]:
import cuml
import cudf
def _plot_pca_biplot(principalDf, pca, features):
    """Draw PC1 vs PC2 of ``principalDf`` with one loading arrow per feature.

    ``pca`` is the fitted model (explained variance for the axis labels,
    ``components_`` for the arrows); ``features`` names the model's input
    columns in the order they were passed in.
    """
    variance_ratio = pca.explained_variance_ratio_
    fig = plt.figure(figsize=(14, 6))
    ax = fig.add_subplot(111)
    ax.set_xlabel('PC 1: {:.2%}'.format(variance_ratio[0]), fontsize=10)
    ax.set_ylabel('PC 2: {:.2%}'.format(variance_ratio[1]), fontsize=10)
    ax.spines['top'].set_color('w')
    ax.spines['right'].set_color('w')
    ax.spines['left'].set_color('grey')
    ax.spines['bottom'].set_color('grey')

    ax = sns.scatterplot(x='pc1', y='pc2', data=principalDf, palette='bright', s=10, alpha=0.8, marker='o')
    # The arrow scale depends only on the data; compute it once, not per arrow.
    pc1_scale = max(principalDf["pc1"])
    pc2_scale = max(principalDf["pc2"])
    # components_ is (n_components, n_features); its transpose yields one row
    # of loadings per input feature, in the same order as `features`.
    for feature_name, loadings in zip(features, pca.components_.T):
        plt.arrow(0, 0, loadings[0] * pc1_scale, loadings[1] * pc2_scale,
                  head_width=0.2, head_length=0.2, linewidth=2, color='red')
        plt.text(loadings[0] * pc1_scale, loadings[1] * pc2_scale, feature_name, color='red')
    ax.set_facecolor('w')
    ax.grid(False)

    plt.subplots_adjust(top=0.85, wspace=0.1, hspace=0.1, right=1)
    fig.suptitle('PCA, {}, {}, plate'.format("Single cell data", "RH30"), fontsize=12)
    plt.show()


def run_pca(df, plot=True, n_components=50):
    """Run a cuML PCA on every "Feature" column of ``df``.

    Parameters
    ----------
    df : polars.DataFrame
        Profiles; columns whose name contains "Feature" are the PCA input and
        all other columns are carried along as metadata.
    plot : bool
        If True, also draw a PC1/PC2 biplot with feature-loading arrows.
    n_components : int
        Number of principal components (default 50, as before).

    Returns
    -------
    polars.DataFrame
        The principal components followed by the metadata columns.  The first
        three components are named 'pc1', 'pc2' and 'pc 3'; the others keep
        their positional names.
    """
    features = [feature for feature in df.columns if "Feature" in feature]
    feature_set = set(features)
    feature_df = df.select(features).to_pandas()
    meta = df.select([col for col in df.columns if col not in feature_set])

    # The PCA is fitted on the pandas frame directly; no separate cudf copy is
    # made, since it was never used.
    pca = cuml.PCA(n_components=n_components)
    principalComponents = pca.fit_transform(feature_df)
    principalDf = pd.DataFrame(data=principalComponents)
    # 'pc 3' (with a space) is kept as-is for existing callers.
    principalDf.rename(columns={0: 'pc1', 1: 'pc2', 2: 'pc 3'}, inplace=True)
    principal_out = pl.concat([pl.DataFrame(principalDf), meta], how="horizontal")

    if plot:
        _plot_pca_biplot(principalDf, pca, features)
    return principal_out


In [None]:
def make_plot_custom_pca(embedding, colouring, save_dir=False, file_name="", name="", description=""):
    """Scatter PC1 vs PC2 coloured by ``colouring``, with '[DMSO]' de-emphasised.

    Parameters
    ----------
    embedding : pandas.DataFrame
        Must contain 'pc1', 'pc2' and the ``colouring`` column.  A 'size'
        column is added to it in place.
    colouring : str
        Column used for the hue.
    save_dir : str or False
        Path prefix for saved figures; any falsy value disables saving.  The
        PNG goes to ``{save_dir}{file_name}{name}.png`` and the PDF to
        ``{save_dir}pdf_format/{file_name}{name}.pdf``.
    file_name, name, description : str
        Used in the file names and plot titles.
    """
    sns.set(style="whitegrid", rc={"figure.figsize": (18, 12), 'figure.dpi': 300, "axes.facecolor": "white", "grid.color": "white"})

    # dict.fromkeys keeps first-appearance order, so colours are reproducible;
    # iterating a set of strings changes order between Python sessions.
    unique_treatments = list(dict.fromkeys(embedding[colouring]))
    custom_palette = sns.color_palette("hls", len(unique_treatments))
    color_dict = dict(zip(unique_treatments, custom_palette))

    # Control group in grey.
    if "[DMSO]" in color_dict:
        color_dict["[DMSO]"] = "lightgrey"

    # Smaller markers for the control group.
    size_dict = {treatment: 20 if treatment != "[DMSO]" else 8 for treatment in unique_treatments}
    embedding['size'] = embedding[colouring].map(size_dict)

    sns_plot = sns.scatterplot(data=embedding, x="pc1", y="pc2", hue=colouring, size='size', palette=color_dict, sizes=(8, 25), linewidth=0.1, alpha=0.9)

    plt.suptitle(f"{name}_{file_name}", fontsize=16)
    sns_plot.tick_params(labelbottom=False, labelleft=False, bottom=False, left=False)
    sns_plot.set_title("PCA of " + str(len(embedding)) + " data points" + " \n" + description, fontsize=12)
    sns.move_legend(sns_plot, "lower left", title='Treatments', prop={'size': 10}, title_fontsize=12, markerscale=0.5)

    sns.despine(bottom=True, left=True)

    # `save_dir` is a path prefix: the former `save_dir == True` check was never
    # true for an actual path, so figures were silently not saved.
    if save_dir:
        sns_plot.figure.savefig(f"{save_dir}{file_name}{name}.png", dpi=600)
        sns_plot.figure.savefig(f"{save_dir}pdf_format/{file_name}{name}.pdf", dpi=600)

    plt.show()

In [None]:
def make_jointplot_pca(embedding, colouring, cmpd, save_path=None):
    """PC1/PC2 scatter with per-group marginal KDEs.

    Parameters
    ----------
    embedding : pandas.DataFrame
        Must contain 'pc1', 'pc2' and the ``colouring`` column.  'color' and
        'size' columns are added to it in place.
    colouring : str
        Column whose values define the groups/colours.  When it is
        'Metadata_cmpdName', '[DMSO]' is drawn light grey and more transparent.
    cmpd : str
        Title of the joint axes.
    save_path : str, optional
        If given, the figure is also saved to ``f"{save_path}.png"``.

    Note: sets ``plt.rcParams['figure.dpi'] = 300`` globally.
    """
    unique_treatments = embedding[colouring].unique()
    palette = sns.color_palette("Set2", len(unique_treatments))
    color_map = dict(zip(unique_treatments, palette))

    highlight_dmso = colouring == 'Metadata_cmpdName'
    if highlight_dmso and '[DMSO]' in color_map:
        color_map['[DMSO]'] = 'lightgrey'

    embedding['color'] = embedding[colouring].map(color_map)
    point_size = 10
    embedding['size'] = point_size

    plt.rcParams['figure.dpi'] = 300

    g = sns.JointGrid(x='pc1', y='pc2', data=embedding, height=10)

    # Marginal KDEs, one per group.
    for treatment in unique_treatments:
        subset = embedding[embedding[colouring] == treatment]
        sns.kdeplot(x=subset["pc1"], ax=g.ax_marg_x, fill=True, color=color_map[treatment], legend=False)
        sns.kdeplot(y=subset["pc2"], ax=g.ax_marg_y, fill=True, color=color_map[treatment], legend=False)

    # Scatter points, one call per group so each gets a legend entry.
    for treatment in unique_treatments:
        subset = embedding[embedding[colouring] == treatment]
        # Uses the correctly spelled column name; the old 'Metadat_cmpdName'
        # comparison never matched, so DMSO was never made more transparent.
        alpha_val = 0.3 if treatment == '[DMSO]' and highlight_dmso else 0.5
        g.ax_joint.scatter(subset["pc1"], subset["pc2"], c=subset['color'], s=subset['size'], label=treatment, alpha=alpha_val, edgecolor='white', linewidth=0.5)

    g.ax_joint.set_title(cmpd)
    legend = g.ax_joint.legend(fontsize=10)
    legend.get_frame().set_facecolor('white')

    # The unused timestamp (which also required an unimported `datetime`)
    # is no longer computed.
    if save_path is not None:
        g.savefig(f"{save_path}.png", dpi=300)

    plt.show()

In [None]:
# Joint PCA plot of `test_pca`.
# NOTE(review): `test_pca` is only assigned inside the loop of the next cell;
# run that cell first.
make_jointplot_pca(test_pca.to_pandas(), "Metadata_cmpdName", "test")

In [None]:
# For each compound, run a PCA on that compound's cells together with the DMSO
# cells and plot PC1/PC2 (plots are not saved: save_dir=False).
# NOTE(review): `grit_filter_df_normal` is not defined in this part of the
# notebook -- confirm which grit-filtered table is intended.
compounds = list(mad_norm_df["Metadata_cmpdName"].unique())
for c in compounds:
    dat = grit_filter_df_normal.filter(pl.col("Metadata_cmpdName").is_in(["[DMSO]", c]))
    print("Now running PCA for", c )
    test_pca = run_pca(dat, plot = False)
    make_plot_custom_pca(test_pca.to_pandas(), "Metadata_cmpdName", save_dir = False, file_name = c, name ="PCA" )

## Single cell UMAP analysis

### Sample data

In [None]:
import time
def subsample_dataset_pl(df, grouping_cols, fraction=0.5, seed=42):
    '''
    Subsample a dataset group-wise so the composition of the groups is kept.

    Parameters:
    - df: The original Polars DataFrame.
    - grouping_cols: Column name(s) defining the groups (e.g. plate and
      Metadata_cmpdName); each group is subsampled independently.
    - fraction: Fraction of rows kept per group, rounded down. Default is
      0.5 (50%).
    - seed: Seed passed to ``DataFrame.sample`` (default 42, as before).

    Returns:
    - A subsampled Polars DataFrame (an empty frame with ``df``'s schema if
      ``df`` has no rows).
    '''
    # Import the progress-bar function explicitly: elsewhere in the notebook
    # `import tqdm` rebinds the name to the (non-callable) module.
    from tqdm import tqdm as progress

    start_time = time.time()

    subsampled_data = [
        group.sample(n=int(group.height * fraction), seed=seed)
        for _, group in progress(df.group_by(grouping_cols), desc="Subsampling groups", unit="group")
    ]
    # pl.concat raises on an empty list, which happens for an empty input.
    subsampled_df = pl.concat(subsampled_data) if subsampled_data else df.head(0)

    end_time = time.time()
    print(f"Finished in {end_time - start_time:.2f} seconds.")

    return subsampled_df

In [None]:
# Reload the normalized single-cell profiles, exclude plate P101384 and
# rebuild the feature list from the column names.
mad_norm_df = pl.read_parquet('Results/sc_profiles_normalized_SPECS3K.parquet')
mad_norm_df = mad_norm_df.filter(pl.col('Metadata_Plate') != "P101384")
features_fixed = [f for f in mad_norm_df.columns if "Feature" in f]

In [None]:
# Number of normalized single cells per compound.
mad_norm_df.groupby(['Metadata_cmpdName']).count()

In [None]:
#subsample_data= subsample_dataset_pl(mad_norm_df, ["Metadata_Plate", "Metadata_cmpdNameConc", "Metadata_Well"], 0.7)
# NOTE(review): `subsample_dmso_pl` is not defined in this part of the notebook;
# presumably it keeps 10% of the '[DMSO]' cells per plate and well -- confirm.
subsample_data_dmso = subsample_dmso_pl(mad_norm_df,["Metadata_Plate", "Metadata_Well"], "Metadata_cmpdName", "[DMSO]", fraction=0.1)

## Filter by grit score

In [None]:
# Plate ids present in the normalized profiles.
plates = list(mad_norm_df["Metadata_Plate"].unique())

In [None]:
import tqdm
def load_grit_data(folder, plates):
    """
    Combine per-plate single-cell features with their grit scores.

    For every plate id in ``plates``, the files in ``folder`` whose names
    contain that id are used:

    - the first file whose name contains "neg_control" supplies the
      ``Metadata_Cell_Identity`` values of the DMSO cells to keep;
    - "sc_features" files supply features (DMSO cells outside that
      negative-control set are dropped; other compounds are kept);
    - "sc_grit" files supply the grit metrics.

    When a cell has several metric rows, a row with ``group == comp`` is
    preferred over the others.  The merged table is also written to
    ``folder/sc_grit_FULL.parquet``.

    :param folder: Path to the folder containing the per-plate Parquet files.
    :param plates: Plate ids searched for in the file names; plates with no
        matching file are reported and skipped.
    :return: Features inner-joined with metrics on ``Metadata_Cell_Identity``.
    :raises FileNotFoundError: if a plate has files but no "neg_control" file.
    :raises ValueError: if no feature or grit files were found at all.
    """
    feature_frames = []
    metric_frames = []
    folder_files = os.listdir(folder)  # list the folder once, not once per plate

    for plate in tqdm.tqdm(plates):
        file_names = [file for file in folder_files if plate in file]
        if not file_names:
            print(f"Plate {plate} not found")
            continue
        neg_files = [file for file in file_names if "neg_control" in file]
        if not neg_files:
            raise FileNotFoundError(f"No neg_control file found for plate {plate} in {folder}")
        neg_cells = pl.read_parquet(os.path.join(folder, neg_files[0]))["Metadata_Cell_Identity"].unique()

        for file_name in file_names:
            file_path = os.path.join(folder, file_name)
            if "sc_features" in file_name:
                feat = pl.read_parquet(file_path).filter(
                    pl.col("Metadata_Cell_Identity").is_in(neg_cells)
                    | ~(pl.col("Metadata_cmpdName") == "[DMSO]")
                )
                feature_frames.append(feat)
            elif "sc_grit" in file_name:
                metrics = pl.read_parquet(file_path)
                # group == comp rows go first so they win the de-duplication below.
                metric_frames.append(metrics.filter(pl.col("group") == pl.col("comp")).drop("comp"))
                metric_frames.append(metrics.filter(pl.col("group") != pl.col("comp")).drop("comp"))

    if not feature_frames or not metric_frames:
        raise ValueError(f"No sc_features/sc_grit files found in {folder} for the given plates")

    # Concatenate once instead of re-concatenating inside the loop.
    feature_df = pl.concat(feature_frames)
    # keep="first" enforces the preference for group == comp rows; polars'
    # default (keep="any") does not guarantee which duplicate survives.
    metric_df = pl.concat(metric_frames).unique(subset=["Metadata_Cell_Identity"], keep="first")
    merged_df = feature_df.join(metric_df, on="Metadata_Cell_Identity", how="inner")
    merged_df.write_parquet(os.path.join(folder, "sc_grit_FULL.parquet"))
    return merged_df


In [None]:
def merge_locations(df, location_folder):
    """Attach per-site nuclei location CSVs to the single-cell profiles.

    For every unique (Metadata_Plate, Metadata_Well, Metadata_Site) in ``df``,
    ``{location_folder}/{plate}/{well}-{site}-Nuclei.csv`` is read and its
    columns are appended horizontally to that site's rows.  Rows are paired
    purely by position, so the CSV must list cells in the same order as
    ``df``.  Sites without a CSV are skipped silently; sites whose CSV row
    count differs from the number of profile rows are reported and skipped.

    Returns
    -------
    polars.DataFrame
        All merged sites stacked vertically (empty if none merged).
    """
    merged_sites = []
    combinations = df.unique(["Metadata_Plate", "Metadata_Well", "Metadata_Site"])
    for combination in tqdm.tqdm(combinations.to_pandas().itertuples(index=False), total=len(combinations)):
        plate, well, site = combination.Metadata_Plate, combination.Metadata_Well, combination.Metadata_Site
        file_path = f"{location_folder}/{plate}/{well}-{site}-Nuclei.csv"
        if not os.path.exists(file_path):
            continue

        csv_df = pl.read_csv(file_path)
        site_df = df.filter((pl.col("Metadata_Plate") == plate) &
                            (pl.col("Metadata_Well") == well) &
                            (pl.col("Metadata_Site") == site))
        if len(csv_df) != len(site_df):
            # Frames of different heights cannot be paired row by row; skip
            # the site rather than attach misaligned locations.
            print(f"{combination} doesn't match")
            continue
        merged_sites.append(pl.concat([site_df, csv_df], how="horizontal"))

    # Concatenate once at the end instead of growing a DataFrame per site.
    return pl.concat(merged_sites, how="vertical") if merged_sites else pl.DataFrame()

In [None]:
import polars as pl
import os
import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

def read_and_merge_single_file(df, plate, well, site, location_folder):
    """Merge one site's nuclei-location CSV onto that site's profile rows.

    Returns the horizontally concatenated frame, or None when the CSV does
    not exist or its row count differs from the number of rows of ``df``
    for (plate, well, site).  Rows are paired by position.
    """
    csv_path = f"{location_folder}/{plate}/{well}-{site}-Nuclei.csv"
    if not os.path.exists(csv_path):
        return None
    locations = pl.read_csv(csv_path)
    site_rows = df.filter(
        (pl.col("Metadata_Plate") == plate)
        & (pl.col("Metadata_Well") == well)
        & (pl.col("Metadata_Site") == site)
    )
    if len(locations) != len(site_rows):
        return None
    return pl.concat([site_rows, locations], how="horizontal")

def merge_locations_parallel(df, location_folder, max_workers=10):
    """Thread-parallel version of ``merge_locations``.

    Each unique (Metadata_Plate, Metadata_Well, Metadata_Site) is handled by
    ``read_and_merge_single_file`` in a worker thread.  Sites without a CSV,
    or whose CSV row count does not match, are dropped.  Results are
    collected in completion order, so the row order of the output is not
    deterministic.

    Returns
    -------
    polars.DataFrame
        All successfully merged sites, or an empty DataFrame if none merged.
    """
    combinations = df.unique(["Metadata_Plate", "Metadata_Well", "Metadata_Site"])
    dfs_to_concat = []

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(read_and_merge_single_file, df, comb["Metadata_Plate"],
                            comb["Metadata_Well"], comb["Metadata_Site"], location_folder)
            for comb in combinations.to_dicts()
        ]
        for future in tqdm.tqdm(as_completed(futures), total=len(futures)):
            result = future.result()
            if result is not None:
                dfs_to_concat.append(result)

    # pl.concat raises on an empty list; return an empty frame instead.
    if not dfs_to_concat:
        return pl.DataFrame()
    return pl.concat(dfs_to_concat, how="vertical")


In [None]:
# Attach nuclei locations to the normalized single cells (per-site CSVs read in parallel).
mad_norm_locations = merge_locations_parallel(mad_norm_df,  "/home/jovyan/share/data/analyses/benjamin/Single_cell_project/DP_specs3k/inputs/locations/", max_workers=10)

In [None]:
# Size of the normalized table (compare with mad_norm_locations to see how many cells got locations).
mad_norm_df.shape

In [None]:
# Load the combined grit table (the commented line rebuilds it with
# load_grit_data) and keep cells with 1 <= grit <= 4.5, plus all DMSO cells.
#grit_filt_df = load_grit_data("grit_parquet_specs3k", plates)
grit_filt_df = pl.read_parquet("/home/jovyan/share/data/analyses/benjamin/Single_cell_project_rapids/SPECS/grit_parquet_specs3k/sc_grit_FULL.parquet")
grit_filt_df = grit_filt_df.filter(((pl.col("grit") >= 1) & (pl.col("grit") <= 4.5)) | (pl.col("Metadata_cmpdName") == "[DMSO]"))

In [None]:
# Number of grit-filtered cells per compound.
grit_filt_df.groupby(['Metadata_cmpdName']).count()

### Merge locations to grit scores

In [None]:
# Join keys: every column shared by the grit table and the location table,
# except 'moa' (which is dropped from both sides before the join).
# NOTE(review): list.remove raises ValueError if 'moa' is not a shared column.
merge_cols = list(set(grit_filt_df.columns).intersection(set(mad_norm_locations.columns)))
merge_cols.remove("moa")

In [None]:
# Left join keeps every grit-filtered cell; .unique() then removes exact duplicate rows.
grit_filt_locations = grit_filt_df.drop("moa").join(mad_norm_locations.drop("moa"), on=merge_cols, how='left').unique()

In [None]:
# Save grit-filtered cells together with their nuclei locations.
grit_filt_locations.write_parquet("Results/sc_profiles_locations.parquet")

In [None]:
# Number of location-annotated cells per compound.
grit_filt_locations.groupby(['Metadata_cmpdName']).count()

## Fix cell duplicates

In [None]:
# Location-annotated profiles for all grit values (this file is not written in
# this part of the notebook).
sc_grit = pl.read_parquet("deepprofiler/Results/sc_profiles_locations_all_grits.parquet")

In [None]:
# Keep only nuclei whose centre lies strictly between 250 and 2250 in both X
# and Y, i.e. drop cells close to the image border.
# NOTE(review): the symmetric margins suggest a 2500 x 2500 px image -- confirm.
sc_grit_location_fix =  sc_grit.filter((pl.col("Nuclei_Location_Center_X") > 250) &
                                                  (pl.col("Nuclei_Location_Center_X") < 2250) &
                                                  (pl.col("Nuclei_Location_Center_Y") > 250) &
                                                  (pl.col("Nuclei_Location_Center_Y") < 2250))

In [None]:
import polars as pl
import gc
import numpy as np
from sklearn.neighbors import RadiusNeighborsRegressor
import random

def sample_one_per_radius(X, regressor, rng=random):
    """Thin out points that lie within ``radius`` of each other.

    Points are visited in order.  If none of the points within the fitted
    radius of the current point (as returned by
    ``regressor.radius_neighbors``) has been sampled yet, one of them --
    possibly the point itself -- is chosen at random and sampled.  This
    removes most near-duplicate detections, but it does not strictly guarantee
    a minimum distance between sampled points.

    Parameters
    ----------
    X : array-like of shape (n_points, n_dims)
        Query points, normally the points the neighbour model was fitted on.
    regressor : object
        Any fitted object with a scikit-learn style
        ``radius_neighbors(X, return_distance=False)`` method.
    rng : random.Random or module, optional
        Source of the random choice; defaults to the global ``random`` module.

    Returns
    -------
    set
        Indices (into X) of the sampled points.
    """
    # One batched query instead of one query per point.
    neighbourhoods = regressor.radius_neighbors(X, return_distance=False)
    sampled_indices = set()
    for indices in neighbourhoods:
        if len(indices) and not any(idx in sampled_indices for idx in indices):
            sampled_indices.add(rng.choice(indices))
    return sampled_indices


def assign_sampling_labels(df, radius=50, seed=None):
    """Label one cell per ``radius`` neighbourhood within each imaged site.

    Cells are grouped by (Metadata_Plate, Metadata_Well, Metadata_Site) and
    thinned with ``sample_one_per_radius`` on their nucleus centres.

    Parameters
    ----------
    df : polars.DataFrame
        Must contain the grouping columns and 'Nuclei_Location_Center_X' /
        'Nuclei_Location_Center_Y'.
    radius : float
        Neighbourhood radius, in the units of the location columns.
    seed : int, optional
        Seed for the random choices.  None (default) uses the global
        ``random`` state as before; an int gives reproducible labels.

    Returns
    -------
    polars.DataFrame
        The input rows, concatenated group by group, plus an integer
        'Sampled' column (1 = kept, 0 = not kept).
    """
    # Local imports: the notebook rebinds `tqdm` to the (non-callable) module
    # via `import tqdm`, so the function is imported explicitly here.
    from tqdm import tqdm as progress
    from sklearn.neighbors import NearestNeighbors

    group_cols = ['Metadata_Plate', 'Metadata_Well', 'Metadata_Site']
    rng = random if seed is None else random.Random(seed)
    results = []

    # maintain_order=True makes the group order (and so the sequence of random
    # draws) stable, which a seed needs to be reproducible.
    for _, group_df in progress(df.group_by(group_cols, maintain_order=True)):
        X = group_df.select(['Nuclei_Location_Center_X', 'Nuclei_Location_Center_Y']).to_numpy()

        # NearestNeighbors provides radius_neighbors directly; no dummy
        # regression targets are needed.
        neighbours = NearestNeighbors(radius=radius).fit(X)
        sampled_indices = sample_one_per_radius(X, neighbours, rng)

        sampled_labels = [1 if i in sampled_indices else 0 for i in range(group_df.height)]
        results.append(group_df.with_columns(pl.Series("Sampled", sampled_labels)))
        # NOTE(review): a full collection per site is costly; kept to preserve
        # the original memory behaviour.
        gc.collect()
    return pl.concat(results)

In [None]:
# Flag near-duplicate nuclei per site (neighbourhood radius 60 in location units).
dublicate_cells = assign_sampling_labels(sc_grit_location_fix, radius=60)


In [None]:
# Save all cells with their 'Sampled' flag.
dublicate_cells.write_parquet("sc_profiles_all_grit_NODUPLICATES.parquet")

In [None]:
# Number of kept (Sampled == 1) cells per compound.
dublicate_cells.filter(pl.col("Sampled") == 1).group_by("Metadata_cmpdName").count()

In [None]:
# Number of dropped (Sampled == 0) cells per compound.
dublicate_cells.filter(pl.col("Sampled") == 0).group_by("Metadata_cmpdName").count()

In [None]:
def plot_sampled_points(df):
    """Scatter nucleus centres of a polars DataFrame by their 'Sampled' flag.

    Cells with Sampled == 0 are drawn grey and semi-transparent, cells with
    Sampled == 1 orange; the y axis is inverted.
    """
    pdf = df.to_pandas()
    dropped = pdf["Sampled"] == 0
    kept = pdf["Sampled"] == 1

    plt.scatter(
        pdf.loc[dropped, "Nuclei_Location_Center_X"],
        pdf.loc[dropped, "Nuclei_Location_Center_Y"],
        color='grey', alpha=0.5, label='Not Sampled', s=5,
    )
    plt.scatter(
        pdf.loc[kept, "Nuclei_Location_Center_X"],
        pdf.loc[kept, "Nuclei_Location_Center_Y"],
        color='tab:orange', label='Sampled', s=5,
    )

    plt.xlabel('Nuclei_Location_Center_X')
    plt.ylabel('Nuclei_Location_Center_Y')
    plt.title('Sampled vs Non-Sampled Nuclei Locations')
    plt.gca().invert_yaxis()
    plt.show()

# NOTE(review): `test_small_with_labels` is not defined in this part of the notebook.
plot_sampled_points(test_small_with_labels)

In [None]:
# Scatter all nucleus centres of `test`.
# NOTE(review): `test` is not defined in this part of the notebook.
X = test.select(['Nuclei_Location_Center_X', 'Nuclei_Location_Center_Y']).to_numpy()
plt.style.use('default')
# One vectorised call instead of one scatter call (and one artist) per point.
plt.scatter(X[:, 0], X[:, 1], color='tab:orange', s=5)

plt.xlabel('X Coordinate')
plt.ylabel('Y Coordinate')
# Every point is drawn the same way, so the title makes no highlighting claim.
plt.title('Nuclei Locations')
plt.gca().invert_yaxis()
plt.show()

## Summary statistics 

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Number of grit-filtered cells per (compound, plate): one bar chart per
# compound, with the plates in sorted order on a shared x axis.
aggregated_cmpd = grit_filt_df.groupby(['Metadata_cmpdName', 'Metadata_Plate']).count()
aggregated_cmpd = aggregated_cmpd.to_pandas()
aggregated_cmpd['Metadata_Plate'] = aggregated_cmpd['Metadata_Plate'].astype('category')
unique_compounds = aggregated_cmpd['Metadata_cmpdName'].unique()
unique_compounds.sort()

# One colour per plate.
palette = sns.color_palette("viridis", n_colors=aggregated_cmpd['Metadata_Plate'].nunique())

sorted_plates = sorted(aggregated_cmpd['Metadata_Plate'].unique())
n_unique_plates = len(sorted_plates)
x_positions = np.arange(n_unique_plates)
bar_width = 0.5

n_compounds = len(unique_compounds)
# squeeze=False keeps `axes` indexable even when there is a single compound.
fig, axes = plt.subplots(n_compounds, 1, figsize=(20, 30), sharex='all', squeeze=False)
axes = axes[:, 0]

for ax, cmpd in zip(axes, unique_compounds):
    subset = aggregated_cmpd[aggregated_cmpd['Metadata_cmpdName'] == cmpd]
    # observed=True avoids pandas' categorical-groupby FutureWarning; plates
    # missing for this compound are filled with 0 by the reindex either way.
    counts = subset.groupby('Metadata_Plate', observed=True)['count'].mean().reindex(sorted_plates).fillna(0)
    ax.bar(x_positions, counts, width=bar_width, color=palette)
    ax.set_title(f'Compound: {cmpd}')
    ax.set_ylabel('Count')

# Plate names once, on the bottom subplot.
axes[-1].set_xticks(x_positions)
axes[-1].set_xticklabels(sorted_plates, rotation=45)

plt.tight_layout()
plt.savefig('Figures_SPECS3K/reference_distributions_grit.png', dpi=300)
plt.show()


In [None]:
# Violin plot of grit scores per compound, with a box plot inside each violin.
# Tick sizes are set before the figure is created so they apply to its axes
# (this also changes the global rc for later figures, as before).
plt.rc('xtick', labelsize=20)
plt.rc('ytick', labelsize=20)
fig, ax = plt.subplots(figsize=(24, 15), dpi=300)
sns.violinplot(data=grit_filt_df.to_pandas(), x='Metadata_cmpdName', y='grit', ax=ax,
               palette="GnBu",
               inner="box",
               density_norm='area')
plt.title('Violin of grit score based on compound', fontsize=30)
plt.xlabel('Compound', fontsize=25)
plt.ylabel('Grit', fontsize=25)

# No plt.legend(): nothing on this plot carries a legend label, so the call
# only produced a "No artists with labels found" warning.
plt.show()

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Ridge plot: one grit-score density per compound, stacked vertically.
sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})
# NOTE(review): only 10 colours; seaborn cycles them if there are more compounds.
pal = sns.cubehelix_palette(10, rot=-.25, light=.7)
# Initialize the FacetGrid object: one row (and one hue) per compound.
g = sns.FacetGrid(data=grit_filt_df.to_pandas(), row="Metadata_cmpdName", hue = "Metadata_cmpdName", aspect=6, height=1.8, palette=pal)

# Draw the filled densities, then a white outline on top of each.
g.map_dataframe(sns.kdeplot, "grit",
                bw_adjust=.5, clip_on=False, fill=True, alpha=1, linewidth=1.5)
g.map(sns.kdeplot, "grit", clip_on=False, color="w", lw=2, bw_adjust=.5)

# Baseline under every density.
g.refline(y=0, linewidth=2, linestyle="-", color=None, clip_on=False)

# Write each compound name at the left of its row, in axes coordinates.
def label(x, color, label):
    ax = plt.gca()
    ax.text(0, .2, label, fontweight="bold", color=color,
            ha="left", va="center", transform=ax.transAxes)


g.map(label, "grit")
# Negative vertical spacing makes neighbouring rows overlap.
g.figure.subplots_adjust(hspace=-.25)

# Remove axes details that don't play well with overlap
g.set_titles("")
g.set(yticks=[], ylabel="")
g.despine(bottom=True, left=True)

## Refined sampling grit + all cells

In [None]:
import tqdm
def sample_groups(df, grouping_cols, ratio):
    """Randomly keep a ``ratio`` fraction of rows from every group.

    Each group of ``df`` defined by ``grouping_cols`` is sampled on its own
    with ``sample(fraction=ratio, seed=42)``, and the samples are
    concatenated in polars' group-iteration order.
    """
    progress = tqdm.tqdm(df.groupby(grouping_cols), desc="Subsampling groups", unit="group")
    return pl.concat([frame.sample(fraction=ratio, seed=42) for _key, frame in progress])

In [None]:
def sample_compounds(df1, df2, sampling_rate, mode="normal"):
    """Build a per-cell table that mixes DMSO controls and treated cells.

    Parameters
    ----------
    df1 : pl.DataFrame
        Full per-cell table. It is the DMSO source (except, in ``"normal"`` mode,
        when the grit-reference DMSO cells of ``df2`` are used) and provides the
        per-well cell counts.
    df2 : pl.DataFrame
        Per-cell table from which treated cells are sampled. Its ``group``
        column, if present, is dropped.
    sampling_rate : float
        Fraction of cells kept per (plate, well).
    mode : {"normal", "dmso_only", "equal_sampling"}
        ``"dmso_only"``: subsample only DMSO (from ``df1``) and keep every
        non-DMSO cell of ``df2``.
        ``"normal"``: subsample every compound from ``df2``. Then sample DMSO so
        its size is the largest compound sample scaled by the ratio of DMSO to
        (largest) compound average per-well counts. The DMSO fraction is capped
        at 1.0, which keeps every DMSO cell.
        ``"equal_sampling"``: subsample DMSO (from ``df1``) and every compound
        (from ``df2``) with the same ``sampling_rate``.

    Returns
    -------
    pl.DataFrame
        DMSO rows first. Their ``Metadata_Cell_Identity`` is set to ``"cell_na"``
        and ``grit`` to ``0.0``. ``moa`` is cast to ``Utf8`` everywhere.
        "[SORB]" (a negative-control compound) is always excluded.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported values.
    """
    def _drop_group(frame):
        # Only df2 carries the helper "group" column; drop it when present so
        # schemas line up for the vertical concat on every polars version.
        return frame.drop("group") if "group" in frame.columns else frame

    def _moa_as_utf8(frame):
        return frame.with_columns(pl.col("moa").cast(pl.Utf8))

    def _mark_dmso(frame):
        # DMSO cells get a placeholder identity and a grit score of 0.
        return frame.with_columns(
            pl.lit("cell_na").alias("Metadata_Cell_Identity"),
            pl.lit(0).cast(pl.Float64).alias("grit"),
        )

    # Filter out sorbitol because it is a negative-control compound.
    df1 = df1.filter(pl.col("Metadata_cmpdName") != "[SORB]")
    df2 = df2.filter(pl.col("Metadata_cmpdName") != "[SORB]")
    is_dmso = pl.col("Metadata_cmpdName") == "[DMSO]"
    well_cols = ["Metadata_Plate", "Metadata_Well"]

    if mode == "dmso_only":
        # Sample only the DMSO group, evenly from each well.
        dmso_sampled = _moa_as_utf8(_mark_dmso(_drop_group(
            sample_groups(df1.filter(is_dmso), well_cols, sampling_rate))))
        treated = _moa_as_utf8(_drop_group(df2.filter(~is_dmso)))
        return pl.concat([dmso_sampled, treated])

    elif mode == "normal":
        n_grit_dmso = df2.filter(is_dmso).height
        avg_compound_counts = (
            df1.group_by(["Metadata_cmpdName", *well_cols])
            .agg(pl.count().alias('count'))
            .group_by("Metadata_cmpdName")
            .agg(pl.mean("count").alias("avg_count"))
        )
        # Highest per-well average count among non-DMSO compounds.
        max_cmpd_avg = avg_compound_counts.filter(~is_dmso)["avg_count"].max()
        dmso_avg_count = avg_compound_counts.filter(is_dmso)["avg_count"][0]
        dmso_ratio = dmso_avg_count / max_cmpd_avg

        cmpd_samples = []
        max_rows = 0
        for cmpd in df1["Metadata_cmpdName"].unique(maintain_order=True):
            if cmpd == "[DMSO]":
                continue
            treated = df2.filter(pl.col("Metadata_cmpdName") == cmpd)
            if treated.height == 0:
                # Compound absent from df2: nothing to sample.
                continue
            cmpd_sample = _moa_as_utf8(_drop_group(sample_groups(treated, well_cols, sampling_rate)))
            cmpd_samples.append(cmpd_sample)
            max_rows = max(max_rows, cmpd_sample.height)

        target = max_rows * dmso_ratio
        if target < n_grit_dmso:
            # If the DMSO sample would be smaller than the number of grit
            # reference DMSO cells, sample only from those.
            dmso_df = df2.filter(is_dmso)
        else:
            dmso_df = df1.filter(is_dmso)

        # DataFrame.sample raises for fraction > 1 without replacement, so cap it.
        dmso_fraction = min(target / dmso_df.height, 1.0)
        dmso_sampled = _moa_as_utf8(_mark_dmso(_drop_group(
            sample_groups(dmso_df, well_cols, dmso_fraction))))
        return pl.concat([dmso_sampled, *cmpd_samples])

    elif mode == "equal_sampling":
        samples = []
        for cmpd in df1["Metadata_cmpdName"].unique(maintain_order=True):
            if cmpd == "[DMSO]":
                sample = _mark_dmso(_drop_group(
                    sample_groups(df1.filter(is_dmso), well_cols, sampling_rate)))
            else:
                treated = df2.filter(pl.col("Metadata_cmpdName") == cmpd)
                if treated.height == 0:
                    continue
                sample = _drop_group(sample_groups(treated, well_cols, sampling_rate))
            samples.append(_moa_as_utf8(sample))
        return pl.concat(samples)

    raise ValueError(f"{mode} not valid as a sampling mode!")

## UMAP sampling strategies (external script)

In [None]:
def make_jointplot_seaborn_benchmark(embedding, colouring, cmpd, samp_meth, samp_rate, overlay=False, overlay_df=None, show=False):
    """Plot a 2-D UMAP embedding (scatter plus marginal KDEs) and save it as a PNG.

    Parameters
    ----------
    embedding : pandas.DataFrame
        Must contain ``UMAP1``, ``UMAP2`` and the ``colouring`` column. The
        ``color`` and ``size`` columns are added to it in place.
    colouring : str
        Treatment-label column. "[DMSO]" is drawn light grey and smaller; every
        other label is red.
    cmpd, samp_meth, samp_rate
        Used in the title and in the file name
        ``Figures_SPECS3K/sampling_benchmark/grit_umap_specs3k_{cmpd}_{samp_meth}_{samp_rate}.png``.
    overlay, overlay_df
        When ``overlay`` is True and ``overlay_df`` is given, its points are
        drawn on top at double size (``overlay_df`` is modified in place too).
    show : bool
        If True, call ``plt.show()``. Otherwise the figure is closed after
        saving, so that repeated calls in a loop do not accumulate open figures.
    """
    def get_color(val):
        if "[DMSO]" in val:
            return "lightgrey"
        else:
            return "#e96565"

    def get_size(val):
        return 20 if val != "[DMSO]" else 10

    embedding['color'] = embedding[colouring].apply(get_color)
    embedding['size'] = embedding[colouring].apply(get_size)

    # Draw DMSO first (underneath), but only when it is actually present;
    # otherwise an empty "[DMSO] - 0 cells" legend entry would appear.
    treatments = list(embedding[colouring].unique())
    if '[DMSO]' in treatments:
        sorted_treatments = ['[DMSO]'] + [t for t in treatments if t != '[DMSO]']
    else:
        sorted_treatments = treatments

    g = sns.JointGrid(x='UMAP1', y='UMAP2', data=embedding, height=10)

    for treatment in sorted_treatments:
        subset = embedding[embedding[colouring] == treatment]
        sns.kdeplot(x=subset["UMAP1"], ax=g.ax_marg_x, fill=True, color=get_color(treatment), legend=False)
        sns.kdeplot(y=subset["UMAP2"], ax=g.ax_marg_y, fill=True, color=get_color(treatment), legend=False)

    for treatment in sorted_treatments:
        subset = embedding[embedding[colouring] == treatment]
        g.ax_joint.scatter(subset["UMAP1"], subset["UMAP2"], c=subset['color'], s=subset['size'], label=f"{treatment} - {len(subset)} cells", alpha=0.7, edgecolor='white', linewidth=0.5)

    # Overlay additional points if the option is active
    if overlay and overlay_df is not None:
        overlay_df['color'] = overlay_df[colouring].apply(get_color)
        # Increase the size for the overlay points
        overlay_df['size'] = overlay_df[colouring].apply(lambda val: get_size(val) * 2)
        for treatment in sorted_treatments:
            subset = overlay_df[overlay_df[colouring] == treatment]
            g.ax_joint.scatter(subset["UMAP1"], subset["UMAP2"], c=subset['color'], s=subset['size'], alpha=0.9, edgecolor='grey', linewidth=0.5)

    g.ax_joint.set_title(f"{cmpd} {samp_meth} rate: {samp_rate}")
    g.ax_joint.legend()

    out_dir = os.path.join("Figures_SPECS3K", "sampling_benchmark")
    os.makedirs(out_dir, exist_ok=True)
    filename = f"grit_umap_specs3k_{cmpd}_{samp_meth}_{samp_rate}.png"
    g.figure.savefig(os.path.join(out_dir, filename), dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    else:
        # Called in a benchmark loop: release the figure once it is on disk.
        plt.close(g.figure)

In [None]:
import time
import torch
# Sampling benchmark: for each compound, each sampling rate and each sampling
# mode of sample_compounds, embed DMSO + that compound and save a joint plot.
# NOTE: run_umap_and_merge is defined in a later cell and must be executed first.
sampling_rates = [0.1, 0.3, 0.5, 1]
sampling_types = ["normal", "dmso_only", "equal_sampling"]
# NOTE(review): "[TETR" lacks the closing bracket used by the other names
# ("[DMSO]", "[CA-0]"). If it is a typo, the is_in filter below keeps only DMSO — confirm.
compounds = ["[CA-0]", "[TETR"]
for k in compounds:
    for r in sampling_rates:
        for t in sampling_types:
            print(f"Now running compound {k} in mode {t} with rate {r}")
            start_time = time.time()
            dat = sample_compounds(mad_norm_df, grit_filt_df, r, mode = t)
            dat = dat.filter(pl.col("Metadata_cmpdName").is_in(["[DMSO]", k]))
            print(f"Data shape: {dat.shape}")
            # Large inputs go to the CPU umap-learn path ('standard'), the rest to cuML.
            if len(dat) > 200000:
                res = run_umap_and_merge(dat, features_fixed, option = 'standard')
            else:
                res = run_umap_and_merge(dat, features_fixed, option = 'cuml')

            make_jointplot_seaborn_benchmark(res.to_pandas(), "Metadata_cmpdName",k, samp_meth = t, samp_rate = r)
            # Free cached GPU memory between iterations
            torch.cuda.empty_cache()

            end_time = time.time()  # Record end time of the iteration
            iteration_time = end_time - start_time 
            print(f"Analysis for {t}, {r} took {iteration_time:.2f} seconds")
        

## UMAP single cell dataset

In [None]:
import cuml
import math
import umap
def run_umap_and_merge(df, features, option='cuml', n_neigh=None, min_dist=0.1, n_components=2, metric='cosine', aggregate=False, spread=4):
    """Fit a UMAP on the feature columns of ``df`` and attach the embedding to its metadata.

    Parameters
    ----------
    df : pl.DataFrame
        Per-cell table. Every column not in ``features`` is kept as metadata.
    features : list of str
        Feature columns used as the UMAP input.
    option : {"cuml", "standard"}
        ``"cuml"``: GPU cuML UMAP (fit, then transform). The number of
        neighbours is ``n_neigh`` or, by default, ``ceil(sqrt(n_rows))``.
        ``"standard"``: CPU umap-learn. The number of neighbours is ``n_neigh``
        or, by default, a fixed 15.
    n_neigh : int, optional
        Number of UMAP neighbours (see ``option`` for the defaults).
    min_dist, n_components, metric
        Passed to the UMAP constructor.
    aggregate : bool
        If True, also project the per-(plate, well, compound) feature means
        through the fitted model.
    spread : float
        UMAP ``spread``, default 4. It was previously read from an undefined
        local name. The default matches the ``spread=4`` used by
        ``supervised_umap``.

    Returns
    -------
    pl.DataFrame, or (pl.DataFrame, pl.DataFrame) when ``aggregate`` is True
        Metadata followed by ``UMAP1``/``UMAP2``. With ``aggregate``, the second
        frame holds the well-level embedding.

    Raises
    ------
    ValueError
        If ``option`` is neither ``"cuml"`` nor ``"standard"``.
    """
    feature_data = df.select(features).to_pandas()
    meta_features = [col for col in df.columns if col not in features]
    meta_data = df.select(meta_features)

    if option == "cuml":
        # Default neighbourhood size scales with the number of cells.
        n_neighbors = math.ceil(np.sqrt(len(feature_data))) if n_neigh is None else n_neigh
        print(f"Starting UMAP with {n_neighbors} neighbors")
        umap_model = cuml.UMAP(n_neighbors=n_neighbors, spread=spread, min_dist=min_dist, n_components=n_components, metric=metric).fit(feature_data)
        umap_embedding = umap_model.transform(feature_data)
    elif option == "standard":
        # The CPU path keeps its fixed neighbourhood size unless one is given.
        n_neighbors = 15 if n_neigh is None else n_neigh
        print(f"Starting UMAP with {n_neighbors} neighbors")
        umap_model = umap.UMAP(n_neighbors=n_neighbors, spread=spread, min_dist=min_dist, n_components=n_components, metric=metric, n_jobs=-1)
        umap_embedding = umap_model.fit_transform(feature_data)
    else:
        raise ValueError(f"Option {option!r} not available. Please choose 'cuml' or 'standard'")

    # Convert the UMAP result to a DataFrame with named axes and merge it with the metadata.
    umap_df = pl.DataFrame(umap_embedding)
    rename_map = dict(zip(umap_df.columns[:2], ["UMAP1", "UMAP2"]))
    umap_df = umap_df.rename(rename_map)
    merged_df = pl.concat([meta_data, umap_df], how="horizontal")

    if not aggregate:
        return merged_df

    print("Aggregating data")
    group_cols = ['Metadata_Plate', 'Metadata_Well', 'Metadata_cmpdName']
    aggregated_data = (
        df.group_by(group_cols)
        .agg([pl.col(feature).mean().alias(feature) for feature in features])
        .to_pandas()
    )
    print(aggregated_data)
    aggregated_umap_embedding = umap_model.transform(aggregated_data[features])
    umap_agg = pl.DataFrame(aggregated_umap_embedding).rename(rename_map)

    aggregated_meta_data = pl.DataFrame(aggregated_data[group_cols])
    merged_agg = pl.concat([aggregated_meta_data, umap_agg], how="horizontal")
    return merged_df, merged_agg

In [None]:
def umap_projection(df, features, option='cuml', n_neigh=None, min_dist=0.1, n_components=2, metric='cosine', aggregate=False, spread=4):
    """Fit a UMAP on all feature columns of ``df`` and merge the embedding with its metadata.

    NOTE: later cells in this notebook redefine ``umap_projection`` with a
    DMSO-only projection; this version is only in effect until those cells run.

    Parameters
    ----------
    df : pl.DataFrame
        Per-cell table. Columns not in ``features`` are kept as metadata.
    features : list of str
        UMAP input columns.
    option : {"cuml", "standard"}
        GPU cuML (neighbours: ``n_neigh`` or ``ceil(sqrt(n_rows))``) or CPU
        umap-learn (neighbours: ``n_neigh`` or a fixed 15).
    n_neigh : int, optional
        Number of UMAP neighbours.
    min_dist, n_components, metric
        Passed to the UMAP constructor.
    aggregate : bool
        If True, also project the per-(plate, well, compound) feature means.
    spread : float
        UMAP ``spread`` (previously an undefined local name).

    Returns
    -------
    pl.DataFrame, or (pl.DataFrame, pl.DataFrame) when ``aggregate`` is True

    Raises
    ------
    ValueError
        For an unknown ``option``.
    """
    feature_data = df.select(features).to_pandas()
    meta_features = [col for col in df.columns if col not in features]
    meta_data = df.select(meta_features)

    if option == "cuml":
        n_neighbors = math.ceil(np.sqrt(len(feature_data))) if n_neigh is None else n_neigh
        print(f"Starting UMAP with {n_neighbors} neighbors")
        umap_model = cuml.UMAP(n_neighbors=n_neighbors, spread=spread, min_dist=min_dist, n_components=n_components, metric=metric).fit(feature_data)
        umap_embedding = umap_model.transform(feature_data)
    elif option == "standard":
        n_neighbors = 15 if n_neigh is None else n_neigh
        print(f"Starting UMAP with {n_neighbors} neighbors")
        umap_model = umap.UMAP(n_neighbors=n_neighbors, spread=spread, min_dist=min_dist, n_components=n_components, metric=metric, n_jobs=-1)
        umap_embedding = umap_model.fit_transform(feature_data)
    else:
        raise ValueError(f"Option {option!r} not available. Please choose 'cuml' or 'standard'")

    umap_df = pl.DataFrame(umap_embedding)
    rename_map = dict(zip(umap_df.columns[:2], ["UMAP1", "UMAP2"]))
    umap_df = umap_df.rename(rename_map)
    merged_df = pl.concat([meta_data, umap_df], how="horizontal")

    if not aggregate:
        return merged_df

    print("Aggregating data")
    group_cols = ['Metadata_Plate', 'Metadata_Well', 'Metadata_cmpdName']
    aggregated_data = (
        df.group_by(group_cols)
        .agg([pl.col(feature).mean().alias(feature) for feature in features])
        .to_pandas()
    )
    print(aggregated_data)
    aggregated_umap_embedding = umap_model.transform(aggregated_data[features])
    umap_agg = pl.DataFrame(aggregated_umap_embedding).rename(rename_map)

    aggregated_meta_data = pl.DataFrame(aggregated_data[group_cols])
    merged_agg = pl.concat([aggregated_meta_data, umap_agg], how="horizontal")
    return merged_df, merged_agg

In [None]:
import cudf
from sklearn.preprocessing import LabelEncoder

def supervised_umap(df, features, spread = 4, min_dist=0.1, n_components=2, metric='cosine', aggregate=False):
    """Supervised cuML UMAP that uses the compound name as the target label.

    Parameters
    ----------
    df : pl.DataFrame
        Per-cell table with a ``Metadata_cmpdName`` column. Columns not in
        ``features`` are returned unchanged as metadata.
    features : list of str
        UMAP input columns.
    spread, min_dist, n_components, metric
        UMAP hyper-parameters. The number of neighbours is always
        ``ceil(sqrt(n_rows))``.
    aggregate : bool
        Accepted for signature parity; not used.

    Returns
    -------
    pl.DataFrame
        Metadata columns followed by ``UMAP1`` and ``UMAP2``.
    """
    host_features = df.select(features).to_pandas()
    metadata = df.select([col for col in df.columns if col not in features])
    k = math.ceil(np.sqrt(len(host_features)))

    # Move the inputs to the GPU; compound names become integer class labels.
    gpu_features = cudf.DataFrame(host_features)
    encoder = LabelEncoder()
    gpu_labels = cudf.Series(encoder.fit_transform(df["Metadata_cmpdName"]))

    print("Starting UMAP")
    reducer = cuml.UMAP(n_neighbors=k, spread=spread, min_dist=min_dist, n_components=n_components, metric=metric)
    embedded = reducer.fit_transform(gpu_features, y=gpu_labels)

    embedding = pl.DataFrame(embedded.to_pandas())
    first_col, second_col = embedding.columns[0], embedding.columns[1]
    embedding = embedding.rename({first_col: "UMAP1", second_col: "UMAP2"})
    return pl.concat([metadata, embedding], how="horizontal")

In [None]:
import pickle
# Load the precomputed reference UMAP embeddings into `umap_dict`
# (e.g. umap_dict["all"] is read by using_datashader further down).
# NOTE: pickle.load can execute arbitrary code — only load files this project produced.
with open('SPECS3K_ref_umaps.pkl', 'rb') as pickle_file:
   umap_dict = pickle.load(pickle_file)

In [None]:
# "normal"-mode sample at rate 1: every compound cell is kept per well and DMSO is
# sampled to a matching size. NOTE(review): the next cells use
# grit_filter_df_full, not this variable — confirm which one is intended.
grit_filter_df_normal = sample_compounds(mad_norm_df, grit_filt_df, sampling_rate= 1, mode = "normal")

In [None]:
#grit_filter_df_full.group_by(pl.col("Metadata_cmpdName")).count()

In [None]:
# cuML UMAP over the whole sampled table (min_dist=0.5).
# NOTE(review): grit_filter_df_full is not created in this part of the notebook —
# presumably an earlier cell; confirm it exists before running.
big_umap = run_umap_and_merge(grit_filter_df_full, features_fixed, option = "cuml", min_dist = 0.5)

In [None]:
# NOTE(review): make_jointplot is not defined in this part of the notebook
# (presumably an earlier cell); its save_path semantics are not visible here.
make_jointplot(big_umap.to_pandas(), "Metadata_cmpdName", cmpd ="", save_path = False)

## Compound embeddings

### Simple UMAPs

In [None]:
import time
# Per-compound UMAPs: DMSO + one compound per plot. `method` selects the pipeline.
method = "normal"
compounds = list(grit_filt_df["Metadata_cmpdName"].unique())
compounds.sort()
# list.remove raises ValueError if either control name is missing.
compounds.remove("[DMSO]")
compounds.remove("[SORB]")
features_fixed = [feat for feat in grit_filt_df.columns if "Feature" in feat]
print(compounds)
#min_dist_value = 0.01
#n_neighbors_value = 130

for k in compounds:
    start_time = time.time()
    # NOTE(review): subsample_data_dmso is not created in this part of the notebook.
    dat = subsample_data_dmso.filter(pl.col('Metadata_cmpdName').is_in(['[DMSO]', k]))
    print(dat.shape)
    if method == "harmony":
        # NOTE(review): create_anndata / harmony / obsm_to_df / run_umap are not
        # defined in the visible cells; this branch also makes no plot.
        ann_pca = create_anndata(dat, aggregated = True)
        pca_df = harmony(ann_pca, "Plate", "harm", pca = False)
        harmony_df = obsm_to_df(pca_df, obsm_key= "harm", columns = None)
        harmony_features = [col for col in harmony_df.columns if "harm" in col]
        res = run_umap(harmony_df, harmony_features) 
    elif method == "median":
        # Cell-level UMAP plus well-level means projected into it as an overlay.
        res, res_agg = run_umap_and_merge(dat, features_fixed, aggregate= True)
        make_jointplot_seaborn_specs(res.to_pandas(), "Metadata_cmpdName", overlay = True, overlay_df = res_agg.to_pandas(), cmpd = k)
    elif method == "normal":
        res = run_umap_and_merge(dat, features_fixed)
        make_jointplot_seaborn_specs(res.to_pandas(), "Metadata_cmpdName",k, overlay_df=None)
        # torch is imported in an earlier cell
        torch.cuda.empty_cache()
    else:
        print("Select UMAP method")

    end_time = time.time()  # Record end time of the iteration
    iteration_time = end_time - start_time 
    print(f"Analysis for {k} took {iteration_time:.2f} seconds")

### UMAP Color by grit

In [None]:
import time
import torch
# Per-compound UMAPs (DMSO + one compound), coloured by each cell's grit score.
method = "normal"
compounds = list(grit_filt_df["Metadata_cmpdName"].unique())
compounds.sort()
# list.remove raises ValueError if either control name is missing.
compounds.remove("[DMSO]")
compounds.remove("[SORB]")
features_fixed = [feat for feat in grit_filt_df.columns if "Feature" in feat]
print(compounds)

# Cache of the "normal" embeddings, keyed by compound name.
umap_dfs = {}
for k in compounds:
    start_time = time.time()
    # NOTE(review): grit_filter_df_full is not created in this part of the notebook.
    dat = grit_filter_df_full.filter(pl.col("Metadata_cmpdName").is_in(["[DMSO]", k]))
    print(dat.shape)
    if method == "harmony":
        # NOTE(review): create_anndata / harmony / obsm_to_df / run_umap are not
        # defined in the visible cells; this branch also makes no plot.
        ann_pca = create_anndata(dat, aggregated = True)
        pca_df = harmony(ann_pca, "Plate", "harm", pca = False)
        harmony_df = obsm_to_df(pca_df, obsm_key= "harm", columns = None)
        harmony_features = [col for col in harmony_df.columns if "harm" in col]
        res = run_umap(harmony_df, harmony_features) 
    elif method == "median":
        res, res_agg = run_umap_and_merge(dat, features_fixed, aggregate= True)
        make_jointplot_seaborn_specs(res.to_pandas(), "Metadata_cmpdName", overlay = True, overlay_df = res_agg.to_pandas(), cmpd = k)
    elif method == "normal":
        res = run_umap_and_merge(dat, features_fixed)
        umap_dfs[k] = res
        make_jointplot_seaborn_grit(res.to_pandas(), "Metadata_cmpdName",k, name = "grit_filter", overlay_df=None, color_by_other_column=True, other_column="grit")
        # Free cached GPU memory between compounds
        torch.cuda.empty_cache()
    else:
        print("Select UMAP method")

    end_time = time.time()  # Record end time of the iteration
    iteration_time = end_time - start_time 
    print(f"Analysis for {k} took {iteration_time:.2f} seconds")

### DMSO projections (of reference compounds)

In [None]:
import math
import polars as pl
def umap_projection(df, features, option='cuml', n_neigh=None, spread=4, min_dist=0.1, n_components=2, metric='cosine', aggregate=False):
    """Fit a UMAP on the DMSO cells only and project every other compound into it.

    NOTE: the next cell redefines ``umap_projection`` with a different return
    value (a dict plus the separate DMSO embedding).

    Parameters
    ----------
    df : pl.DataFrame
        Per-cell table with a ``Metadata_cmpdName`` column; "[DMSO]" rows are
        used to fit the model.
    features : list of str
        UMAP input columns. All other columns are carried along as metadata.
    option : {"cuml", "standard"}
        GPU cuML (neighbours: ``n_neigh`` or ``ceil(sqrt(len(df)))``) or CPU
        umap-learn (neighbours: ``n_neigh`` or a fixed 15).
    n_neigh : int, optional
        Number of UMAP neighbours.
    spread, min_dist, n_components, metric
        UMAP hyper-parameters.
    aggregate : bool
        Accepted for signature parity; not used.

    Returns
    -------
    dict[str, pl.DataFrame]
        Compound name -> that compound's projected cells stacked on the DMSO
        embedding. Columns are ``UMAP1``, ``UMAP2``, then metadata.

    Raises
    ------
    ValueError
        For an unknown ``option``.
    """
    dmso_data = df.filter(pl.col("Metadata_cmpdName") == "[DMSO]")
    feature_dmso = dmso_data.select(features).to_pandas()
    meta_features = [col for col in df.columns if col not in features]
    meta_dmso = dmso_data.select(meta_features)
    cmpd = sorted(df["Metadata_cmpdName"].unique())

    if option == "cuml":
        n_neighbors = math.ceil(np.sqrt(len(df))) if n_neigh is None else n_neigh
        umap_model = cuml.UMAP(n_neighbors=n_neighbors, spread=spread, min_dist=min_dist, n_components=n_components, metric=metric)
    elif option == "standard":
        import umap
        n_neighbors = 15 if n_neigh is None else n_neigh
        umap_model = umap.UMAP(n_neighbors=n_neighbors, spread=spread, min_dist=min_dist, n_components=n_components, metric=metric, n_jobs=-1)
    else:
        raise ValueError(f"Option {option!r} not available. Please choose 'cuml' or 'standard'")
    print(f"Starting UMAP with {n_neighbors} neighbors")

    # Fit on DMSO only; every compound is later projected into this space.
    dmso_df = pl.DataFrame(umap_model.fit_transform(feature_dmso))
    rename_map = {dmso_df.columns[0]: "UMAP1", dmso_df.columns[1]: "UMAP2"}
    dmso_merged = pl.concat([dmso_df.rename(rename_map), meta_dmso], how="horizontal")

    embedding_dict = {}
    for c in cmpd:
        if c == "[DMSO]":
            continue
        print("Now running UMAP for:", c)
        subset = df.filter(pl.col("Metadata_cmpdName") == c)
        umap_df = pl.DataFrame(umap_model.transform(subset.select(features).to_pandas())).rename(rename_map)
        # Use the compound's own metadata so the row counts match the projection.
        temp_df = pl.concat([umap_df, subset.select(meta_features)], how="horizontal")
        embedding_dict[c] = pl.concat([temp_df, dmso_merged], how="vertical")

    return embedding_dict

In [None]:
def umap_projection(df, features, option='cuml', n_neigh=None, spread=4, min_dist=0.1, n_components=2, metric='cosine', aggregate=False):
    """Fit a UMAP on the DMSO cells and project every other compound into it.

    Parameters
    ----------
    df : pl.DataFrame
        Per-cell table with a ``Metadata_cmpdName`` column; "[DMSO]" rows are
        used to fit the model.
    features : list of str
        UMAP input columns. All other columns are carried along as metadata.
    option : {"cuml", "standard"}
        GPU cuML (neighbours: ``n_neigh`` or ``ceil(sqrt(len(df)))``) or CPU
        umap-learn (neighbours: ``n_neigh`` or a fixed 15).
    n_neigh : int, optional
        Number of UMAP neighbours.
    spread, min_dist, n_components, metric
        UMAP hyper-parameters.
    aggregate : bool
        Accepted for signature parity; not used.

    Returns
    -------
    (dict[str, pl.DataFrame], pl.DataFrame)
        Compound name -> its projected cells, and the DMSO embedding. Both use
        the column layout ``UMAP1``, ``UMAP2``, then metadata.

    Raises
    ------
    ValueError
        For an unknown ``option``. Previously this surfaced as an
        UnboundLocalError on return.
    """
    dmso_data = df.filter(pl.col("Metadata_cmpdName") == "[DMSO]")
    feature_dmso = dmso_data.select(features).to_pandas()
    meta_features = [col for col in df.columns if col not in features]
    meta_dmso = dmso_data.select(meta_features)
    cmpd = sorted(df["Metadata_cmpdName"].unique())
    embedding_dict = {}

    if option == "cuml":
        n_neighbors = math.ceil(np.sqrt(len(df))) if n_neigh is None else n_neigh
        umap_model = cuml.UMAP(n_neighbors=n_neighbors, spread=spread, min_dist=min_dist, n_components=n_components, metric=metric)
    elif option == "standard":
        import umap
        n_neighbors = 15 if n_neigh is None else n_neigh
        umap_model = umap.UMAP(n_neighbors=n_neighbors, spread=spread, min_dist=min_dist, n_components=n_components, metric=metric, n_jobs=-1)
    else:
        raise ValueError(f"Option {option!r} not available. Please choose 'cuml' or 'standard'")

    print(f"Starting UMAP with {n_neighbors} neighbors")

    # Fit on DMSO only; the compounds are projected into this space below.
    dmso_df = pl.DataFrame(umap_model.fit_transform(feature_dmso))
    rename_map = {dmso_df.columns[0]: "UMAP1", dmso_df.columns[1]: "UMAP2"}
    dmso_merged = pl.concat([dmso_df.rename(rename_map), meta_dmso], how="horizontal")

    for c in cmpd:
        if c == "[DMSO]":
            continue

        print("Now running UMAP for:", c)
        subset = df.filter(pl.col("Metadata_cmpdName") == c)
        feature_df = subset.select(features).to_pandas()
        meta = subset.select(meta_features)
        umap_df = pl.DataFrame(umap_model.transform(feature_df)).rename(rename_map)
        embedding_dict[c] = pl.concat([umap_df, meta], how="horizontal")

    return embedding_dict, dmso_merged

In [None]:
# Project every compound in `dat` onto a DMSO-only UMAP. `dmos_merged` holds the
# DMSO embedding (the spelling is reused by the next cell).
# NOTE(review): `dat` is whatever an earlier cell last assigned (e.g. the final
# loop iteration's DMSO + one compound subset) — confirm it is the intended table.
emb_dict, dmos_merged = umap_projection(dat, features_fixed)

In [None]:
import time
import torch
start_time = time.time()
# Plot every projected compound on top of the shared DMSO embedding, coloured by grit.
for key, value in emb_dict.items():
    dat = pl.concat([value, dmos_merged], how="vertical")
    make_jointplot_seaborn_grit(dat.to_pandas(), "Metadata_cmpdName", key, save=False, name="grit_filter", overlay_df=None, color_by_other_column=True, other_column="grit")
end_time = time.time()
iteration_time = end_time - start_time
# The timing covers the whole loop, so report it for all compounds rather than
# for a loop variable left over from another cell.
print(f"Analysis for {len(emb_dict)} compounds took {iteration_time:.2f} seconds")

## Jointplot

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt


import seaborn as sns
import matplotlib.pyplot as plt

def make_jointplot_seaborn_RH30(embedding, colouring, cmpd):
    """Plot a 2-D UMAP embedding (scatter plus marginal KDEs) for RH30-style labels.

    Parameters
    ----------
    embedding : pandas.DataFrame
        Must contain ``UMAP1``, ``UMAP2`` and the ``colouring`` column. The
        ``color`` and ``size`` columns are added in place.
    colouring : str
        Label column. Labels containing ``"_1"`` are drawn blue and those
        containing ``"_5"`` red. ``"DMSO_0"`` is light grey and smaller; any
        other label is grey.
    cmpd : str
        Plot title.
    """
    # Define your custom color mapping based on conditions
    def get_color(val):
        if "_1" in val:
            return "#28B6D2"
        elif "_5" in val:
            return "#e96565"
        elif val == "DMSO_0":
            return "lightgrey"
        else:
            return "grey"

    # Size mapping
    def get_size(val):
        return 20 if val != "DMSO_0" else 10

    embedding['color'] = embedding[colouring].apply(get_color)
    embedding['size'] = embedding[colouring].apply(get_size)

    # Create the base joint plot
    g = sns.JointGrid(x='UMAP1', y='UMAP2', data=embedding, height=10)

    # Plot KDE plots for each category
    for treatment in embedding[colouring].unique():
        subset = embedding[embedding[colouring] == treatment]
        sns.kdeplot(x=subset["UMAP1"], ax=g.ax_marg_x, fill=True, color=get_color(treatment), legend=False)
        sns.kdeplot(y=subset["UMAP2"], ax=g.ax_marg_y, fill=True, color=get_color(treatment), legend=False)

    # Ensure plotting order: first DMSO_0, then _1, and finally _5.
    # Plain substring match (regex=False); missing labels never match (na=False).
    for treatment in ["DMSO_0", "_1", "_5"]:
        subset = embedding[embedding[colouring].str.contains(treatment, regex=False, na=False)]
        # The legend reports the size of this group, not of the whole embedding.
        g.ax_joint.scatter(subset["UMAP1"], subset["UMAP2"], c=subset['color'], s=subset['size'], label=f"{treatment} - {len(subset)} cells", alpha=0.7, edgecolor='white', linewidth=0.5)

    g.ax_joint.set_title(cmpd)
    g.ax_joint.legend()

    plt.show()

import seaborn as sns
import matplotlib.pyplot as plt

def make_jointplot_seaborn_specs(embedding, colouring, cmpd, overlay=False, overlay_df=None):
    """Plot a 2-D UMAP embedding (scatter plus marginal KDEs), DMSO versus treated cells.

    Parameters
    ----------
    embedding : pandas.DataFrame
        Must contain ``UMAP1``, ``UMAP2`` and the ``colouring`` column. The
        ``color`` and ``size`` columns are added in place.
    colouring : str
        Treatment-label column. "[DMSO]" is drawn light grey and smaller; other
        labels are red.
    cmpd : str
        Plot title.
    overlay, overlay_df
        When ``overlay`` is True and ``overlay_df`` is given, its points are
        drawn on top at double size (``overlay_df`` is modified in place).
    """
    def get_color(val):
        if "[DMSO]" in val:
            return "lightgrey"
        else:
            return "#e96565"

    def get_size(val):
        return 20 if val != "[DMSO]" else 10

    embedding['color'] = embedding[colouring].apply(get_color)
    embedding['size'] = embedding[colouring].apply(get_size)

    # Draw DMSO first (underneath), but only when present; otherwise an empty
    # "[DMSO] - 0 cells" legend entry would be added.
    treatments = list(embedding[colouring].unique())
    if '[DMSO]' in treatments:
        sorted_treatments = ['[DMSO]'] + [t for t in treatments if t != '[DMSO]']
    else:
        sorted_treatments = treatments

    g = sns.JointGrid(x='UMAP1', y='UMAP2', data=embedding, height=10)

    for treatment in sorted_treatments:
        subset = embedding[embedding[colouring] == treatment]
        sns.kdeplot(x=subset["UMAP1"], ax=g.ax_marg_x, fill=True, color=get_color(treatment), legend=False)
        sns.kdeplot(y=subset["UMAP2"], ax=g.ax_marg_y, fill=True, color=get_color(treatment), legend=False)

    for treatment in sorted_treatments:
        subset = embedding[embedding[colouring] == treatment]
        g.ax_joint.scatter(subset["UMAP1"], subset["UMAP2"], c=subset['color'], s=subset['size'], label=f"{treatment} - {len(subset)} cells", alpha=0.7, edgecolor='white', linewidth=0.5)

    # Overlay additional points if the option is active
    if overlay and overlay_df is not None:
        overlay_df['color'] = overlay_df[colouring].apply(get_color)
        # Increase the size for the overlay points
        overlay_df['size'] = overlay_df[colouring].apply(lambda val: get_size(val) * 2)
        for treatment in sorted_treatments:
            subset = overlay_df[overlay_df[colouring] == treatment]
            g.ax_joint.scatter(subset["UMAP1"], subset["UMAP2"], c=subset['color'], s=subset['size'], alpha=0.9, edgecolor='grey', linewidth=0.5)

    g.ax_joint.set_title(cmpd)
    g.ax_joint.legend()
    plt.show()



import matplotlib.cm as cm
import matplotlib.colors as mcolors

def make_jointplot_seaborn_grit(embedding, colouring, cmpd, name, save=True, overlay=False, overlay_df=None, color_by_other_column=False, other_column=None):
    """Plot a 2-D UMAP embedding (scatter plus marginal KDEs), optionally coloured by a numeric column.

    Parameters
    ----------
    embedding : pandas.DataFrame
        Must contain ``UMAP1``, ``UMAP2`` and ``colouring`` (and ``other_column``
        when it is used). The ``color`` and ``size`` columns are added in place.
    colouring : str
        Treatment-label column. "[DMSO]" points are light grey and smaller.
    cmpd : str
        Plot title and part of the output file name.
    name : str
        Suffix of the output file ``Figures_SPECS3K/grit_umap_specs3k_{cmpd}_{name}.png``.
    save : bool
        Write the PNG (at 300 dpi) before showing the figure.
    overlay, overlay_df
        When ``overlay`` is True and ``overlay_df`` is given, its points are
        drawn on top (``overlay_df`` is modified in place).
    color_by_other_column, other_column
        If both are set, non-DMSO points are coloured by ``other_column`` (e.g.
        ``"grit"``) through a viridis colormap and a colorbar is added.
        Otherwise non-DMSO points are red.
    """
    def get_continuous_color_map(data, cmap_name="viridis"):
        # Map the value range of `data` onto a matplotlib colormap.
        norm = mcolors.Normalize(vmin=data.min(), vmax=data.max())
        return cm.ScalarMappable(norm=norm, cmap=plt.get_cmap(cmap_name))

    use_continuous = color_by_other_column and other_column is not None
    color_map = None
    if use_continuous:
        color_map = get_continuous_color_map(embedding[other_column])
        embedding['color'] = embedding.apply(lambda row: color_map.to_rgba(row[other_column]) if row[colouring] != "[DMSO]" else "lightgrey", axis=1)
    else:
        embedding['color'] = embedding[colouring].apply(lambda val: "lightgrey" if "[DMSO]" in val else "#e96565")
    embedding['size'] = embedding[colouring].apply(lambda val: 20 if val != "[DMSO]" else 10)

    # Draw DMSO first (underneath), but only when present. An inserted-but-absent
    # DMSO group would make `subset['color'].iloc[0]` below raise IndexError.
    treatments = list(embedding[colouring].unique())
    if '[DMSO]' in treatments:
        sorted_treatments = ['[DMSO]'] + [t for t in treatments if t != '[DMSO]']
    else:
        sorted_treatments = treatments

    g = sns.JointGrid(x='UMAP1', y='UMAP2', data=embedding, height=10)

    for treatment in sorted_treatments:
        subset = embedding[embedding[colouring] == treatment]
        if use_continuous:
            # A marginal KDE takes a single colour; the first row's colour is used.
            color = subset['color'].iloc[0]
        else:
            color = "lightgrey" if "[DMSO]" in treatment else "#e96565"
        sns.kdeplot(x=subset["UMAP1"], ax=g.ax_marg_x, fill=True, color=color, legend=False)
        sns.kdeplot(y=subset["UMAP2"], ax=g.ax_marg_y, fill=True, color=color, legend=False)

    for treatment in sorted_treatments:
        subset = embedding[embedding[colouring] == treatment]
        g.ax_joint.scatter(subset["UMAP1"], subset["UMAP2"], c=subset['color'], s=subset['size'], label=f"{treatment} - {len(subset)} cells", alpha=0.7, edgecolor='white', linewidth=0.5)

    # Overlay additional points if the option is active
    if overlay and overlay_df is not None:
        if use_continuous:
            # ScalarMappable is not subscriptable: map values via to_rgba, as for the main points.
            overlay_df['color'] = overlay_df.apply(lambda row: "lightgrey" if "[DMSO]" in row[colouring] else color_map.to_rgba(row[other_column]), axis=1)
        else:
            overlay_df['color'] = overlay_df[colouring].apply(lambda val: "lightgrey" if "[DMSO]" in val else "#e96565")
        overlay_df['size'] = overlay_df[colouring].apply(lambda val: 20 if val != "[DMSO]" else 10)

        for treatment in sorted_treatments:
            subset = overlay_df[overlay_df[colouring] == treatment]
            g.ax_joint.scatter(subset["UMAP1"], subset["UMAP2"], c=subset['color'], s=subset['size'], alpha=0.9, edgecolor='grey', linewidth=0.5)

    g.ax_joint.set_title(cmpd)
    g.ax_joint.tick_params(axis='both', labelsize=10)
    g.ax_marg_x.tick_params(axis='x', labelsize=10)
    g.ax_marg_y.tick_params(axis='y', labelsize=10)
    legend = g.ax_joint.legend(fontsize=10)
    legend.get_frame().set_facecolor('white')

    if color_map is not None:
        cbar_ax = g.figure.add_axes([1, 0.25, 0.02, 0.5])
        plt.colorbar(color_map, cax=cbar_ax, orientation='vertical', label=other_column)

    if save:
        os.makedirs("Figures_SPECS3K", exist_ok=True)
        filename = f"grit_umap_specs3k_{cmpd}_{name}.png"
        g.figure.savefig(os.path.join("Figures_SPECS3K", filename), dpi=300, bbox_inches='tight')
    plt.show()

## Datashader analysis

In [None]:
import datashader as ds
import datashader.transfer_functions as tf
from datashader.colors import Elevation

In [None]:
def make_jointplot_datashader(embedding, colouring, cmpd, name, overlay=False, overlay_df=None):
    """Plot non-DMSO cells as a datashader density image with DMSO points and marginal KDEs.

    Non-"[DMSO]" rows are aggregated into a 400x400 point-count raster, shaded
    with histogram equalisation and shown with ``imshow``. "[DMSO]" rows are
    drawn on top as grey points. 1-D KDEs of the non-DMSO UMAP1/UMAP2 values
    are drawn on separate marginal axes above and to the right of the image.

    Parameters
    ----------
    embedding : pandas.DataFrame
        Must contain 'UMAP1', 'UMAP2' and the ``colouring`` column.
    colouring : str
        Column holding treatment labels; rows equal to "[DMSO]" are the controls.
    cmpd, name : str
        Combined into the plot title as ``"{cmpd} - {name}"``.
    overlay : bool, default False
        When True and ``overlay_df`` is given, scatter ``overlay_df`` on the image.
    overlay_df : pandas.DataFrame, optional
        Extra points with 'UMAP1', 'UMAP2' and ``colouring`` columns; not modified.
    """
    non_dmso = embedding[embedding[colouring] != "[DMSO]"]
    dmso = embedding[embedding[colouring] == "[DMSO]"]

    # Pin the canvas ranges so the imshow extent is exactly the rasterised region.
    x_range = (non_dmso['UMAP1'].min(), non_dmso['UMAP1'].max())
    y_range = (non_dmso['UMAP2'].min(), non_dmso['UMAP2'].max())
    cvs = ds.Canvas(plot_width=400, plot_height=400, x_range=x_range, y_range=y_range)
    agg = cvs.points(non_dmso, 'UMAP1', 'UMAP2', ds.count())
    img = tf.shade(agg, how='eq_hist')
    # dynspread only spreads non-transparent pixels, so it has to run before
    # set_background; an opaque white background would make it a no-op.
    img = tf.dynspread(img, threshold=0.5, max_px=4)
    img = tf.set_background(img, 'white')
    img = img.to_pil()

    # Joint axes plus real marginal axes (top: UMAP1 KDE, right: UMAP2 KDE).
    fig = plt.figure(figsize=(10, 10))
    grid = fig.add_gridspec(2, 2, width_ratios=(4, 1), height_ratios=(1, 4),
                            wspace=0.05, hspace=0.05)
    ax = fig.add_subplot(grid[1, 0])
    ax_marg_x = fig.add_subplot(grid[0, 0], sharex=ax)
    ax_marg_y = fig.add_subplot(grid[1, 1], sharey=ax)

    # aspect='auto' lets the raster fill the joint axes so it lines up with
    # the shared-axis marginals.
    ax.imshow(img, extent=[*x_range, *y_range], aspect='auto')
    ax.scatter(dmso['UMAP1'], dmso['UMAP2'], color='grey', s=10, edgecolor='white')

    sns.kdeplot(x=non_dmso['UMAP1'], ax=ax_marg_x, bw_adjust=0.5, fill=True, color='blue', alpha=0.3)
    sns.kdeplot(y=non_dmso['UMAP2'], ax=ax_marg_y, bw_adjust=0.5, fill=True, color='blue', alpha=0.3)
    ax_marg_x.tick_params(axis='x', labelbottom=False)
    ax_marg_y.tick_params(axis='y', labelleft=False)

    if overlay and overlay_df is not None:
        # Same convention as the other joint plots in this notebook: grey,
        # small markers for controls; red, larger markers for everything else.
        is_dmso = (overlay_df[colouring] == "[DMSO]").to_numpy()
        ax.scatter(overlay_df['UMAP1'], overlay_df['UMAP2'],
                   c=np.where(is_dmso, 'lightgrey', '#e96565'),
                   s=np.where(is_dmso, 10, 20),
                   alpha=0.9, edgecolor='grey', linewidth=0.5)

    ax_marg_x.set_title(f"{cmpd} - {name}")
    plt.show()

In [None]:
def using_datashader(ax, x, y, df=None):
    """Draw a datashader point-count image of ``df[x]`` against ``df[y]`` on *ax*.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes; a colorbar for the count image is attached to it.
    x, y : str
        Column names used for the horizontal and vertical coordinates.
    df : pandas.DataFrame, optional
        Data to plot. Defaults to ``umap_dict["all"].to_pandas()`` (a notebook
        global), which was the only data source of the original helper.

    Returns
    -------
    The artist returned by ``dsshow``.
    """
    # Imported locally so the helper does not rely on a module-level import
    # of dsshow existing somewhere else in the notebook.
    from datashader.mpl_ext import dsshow

    if df is None:
        df = umap_dict["all"].to_pandas()
    dsartist = dsshow(
        df,
        ds.Point(x, y),
        ds.count(),
        norm="linear",
        aspect="auto",
        ax=ax,
    )
    plt.colorbar(dsartist, ax=ax)
    return dsartist


fig, ax = plt.subplots()
using_datashader(ax, "UMAP1", "UMAP2")
plt.show()

In [None]:
import mpl_scatter_density
from matplotlib.colors import LinearSegmentedColormap
# Stop positions and colours for the 'white_viridis' colormap. Position 0 is
# white and the next stop (1e-20) is already dark purple, presumably so that
# zero-valued pixels render white while any non-zero value is visibly coloured.
_WHITE_VIRIDIS_POSITIONS = (0, 1e-20, 0.2, 0.4, 0.6, 0.8, 1)
_WHITE_VIRIDIS_COLOURS = (
    '#ffffff',
    '#440053',
    '#404388',
    '#2a788e',
    '#21a784',
    '#78d151',
    '#fde624',
)
white_viridis = LinearSegmentedColormap.from_list(
    'white_viridis',
    list(zip(_WHITE_VIRIDIS_POSITIONS, _WHITE_VIRIDIS_COLOURS)),
    N=10000,
)

def using_mpl_scatter_density(fig, df, x="UMAP1", y="UMAP2"):
    """Add a full-figure scatter-density axes to *fig* plotting ``df[x]`` vs ``df[y]``.

    The 'scatter_density' projection is provided by the ``mpl_scatter_density``
    package, which must be imported before calling this. Pixels are coloured
    with the notebook-level ``white_viridis`` colormap and a colorbar labelled
    'Number of points per pixel' is added to the figure.

    Parameters
    ----------
    fig : matplotlib Figure
    df : pandas.DataFrame
        Source data.
    x, y : str, default "UMAP1" / "UMAP2"
        Column names for the horizontal and vertical coordinates.

    Returns
    -------
    The created scatter-density axes.
    """
    ax = fig.add_subplot(1, 1, 1, projection='scatter_density')
    density = ax.scatter_density(df[x], df[y], cmap=white_viridis)
    fig.colorbar(density, label='Number of points per pixel')
    return ax

# Scatter-density view of the "[BERB]" UMAP embedding.
berb_embedding = umap_dict["[BERB]"].to_pandas()
fig = plt.figure()
using_mpl_scatter_density(fig, berb_embedding)
plt.show()

## Density analysis

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from scipy import stats
import numpy as np

def _relative_kde_density(x, y):
    """Return each point's 2-D Gaussian KDE density, min-max scaled to [0, 1].

    Scaling to [0, 1] lets the values share a fixed ``Normalize(0, 1)`` with a
    "Relative density" colorbar. If the KDE cannot be fitted (too few points, or
    coordinates collapsing onto a line/point so the covariance is singular) or
    all densities are equal, every point gets 1.0. Empty input returns an
    empty array.
    """
    values = np.vstack([np.asarray(x, dtype=float), np.asarray(y, dtype=float)])
    n_points = values.shape[1]
    if n_points == 0:
        return np.empty(0)
    try:
        density = stats.gaussian_kde(values)(values)
    except (np.linalg.LinAlgError, ValueError):
        return np.ones(n_points)
    low, high = density.min(), density.max()
    if not np.isfinite(high - low) or high == low:
        return np.ones(n_points)
    return (density - low) / (high - low)


def make_jointplot_seaborn_density(embedding, colouring, cmpd, overlay=False, overlay_df=None):
    """Draw a UMAP joint plot with non-DMSO cells coloured by local 2-D density.

    Each non-"[DMSO]" treatment gets its own Gaussian KDE. Point colours (jet
    colormap) encode that density rescaled to [0, 1], which matches the
    "Relative density" colorbar. Marginal axes show per-treatment 1-D KDEs.
    "[DMSO]" cells, if present, are drawn first in light grey so treated cells
    sit on top.

    Parameters
    ----------
    embedding : pandas.DataFrame
        Needs 'UMAP1', 'UMAP2' and the ``colouring`` column. Not modified.
    colouring : str
        Column with treatment labels; the control label is "[DMSO]".
    cmpd : str
        Plot title.
    overlay : bool, default False
        When True and ``overlay_df`` is given, scatter its points on top.
    overlay_df : pandas.DataFrame, optional
        Extra points with 'UMAP1', 'UMAP2' and ``colouring``. Not modified.
    """
    dmso_label = '[DMSO]'

    def get_color(val):
        # Overlay colour: grey for any label containing "[DMSO]", red otherwise.
        return "lightgrey" if dmso_label in val else "#e96565"

    def get_size(val):
        return 20 if val != dmso_label else 10

    labels = embedding[colouring]
    # Marker sizes are kept in a local Series rather than added as a column to
    # the caller's DataFrame.
    sizes = labels.map(get_size)

    treatments = [t for t in labels.unique() if t != dmso_label]
    # Controls first so treated cells are drawn over them; skip the control
    # label when there are no control cells (no empty "0 cells" legend entry).
    has_dmso = bool((labels == dmso_label).any())
    joint_order = ([dmso_label] if has_dmso else []) + treatments

    g = sns.JointGrid(x='UMAP1', y='UMAP2', data=embedding, height=10)

    cmap = plt.cm.jet
    # Densities are already in [0, 1]; one fixed norm keeps scatter colours,
    # marginal colours and the colorbar on the same scale.
    norm = plt.Normalize(vmin=0, vmax=1)
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])

    for treatment in joint_order:
        mask = labels == treatment
        subset = embedding[mask]
        label = f"{treatment} - {len(subset)} cells"

        if treatment == dmso_label:
            sns.kdeplot(x=subset["UMAP1"], ax=g.ax_marg_x, fill=True, color='lightgrey', legend=False)
            sns.kdeplot(y=subset["UMAP2"], ax=g.ax_marg_y, fill=True, color='lightgrey', legend=False)
            g.ax_joint.scatter(subset["UMAP1"], subset["UMAP2"], c='lightgrey', s=sizes[mask],
                               label=label, alpha=0.7, edgecolor='white', linewidth=0.5)
            continue

        density = _relative_kde_density(subset["UMAP1"], subset["UMAP2"])
        # Marginal KDE colour: mean RGBA of this treatment's point colours.
        marg_color = cmap(norm(density)).mean(axis=0)
        sns.kdeplot(x=subset["UMAP1"], ax=g.ax_marg_x, fill=True, color=marg_color, legend=False)
        sns.kdeplot(y=subset["UMAP2"], ax=g.ax_marg_y, fill=True, color=marg_color, legend=False)
        g.ax_joint.scatter(subset["UMAP1"], subset["UMAP2"], c=density, s=sizes[mask],
                           cmap=cmap, norm=norm, label=label, alpha=0.7,
                           edgecolor='white', linewidth=0.5)

    if overlay and overlay_df is not None:
        overlay_labels = overlay_df[colouring]
        # Overlay controls always; other labels only if they occur in `embedding`.
        for treatment in [dmso_label] + treatments:
            mask = overlay_labels == treatment
            if not mask.any():
                continue
            subset = overlay_df[mask]
            subset_labels = overlay_labels[mask]
            g.ax_joint.scatter(subset["UMAP1"], subset["UMAP2"],
                               c=subset_labels.map(get_color).tolist(),
                               s=subset_labels.map(get_size) * 2,
                               alpha=0.9, edgecolor='grey', linewidth=0.5)

    fig = g.fig
    # Dedicated axes for the colorbar so it does not shrink the joint axes.
    cbar_ax = fig.add_axes([0.93, 0.1, 0.02, 0.7])
    cbar = fig.colorbar(sm, cax=cbar_ax)
    cbar.set_label('Relative density', rotation=270, labelpad=15)

    # Leave room on the right for the colorbar axes.
    fig.subplots_adjust(right=0.9)

    g.ax_joint.set_title(cmpd)
    g.ax_joint.legend()
    plt.show()


In [None]:
# For every compound in emb_dict, stack its embedding on top of `dmos_merged`
# (presumably the merged DMSO control embeddings — confirm where it is defined)
# and draw the density joint plot titled with the compound key.
# NOTE(review): `key`, `value` and `dat` stay defined as notebook globals after
# the loop; later cells may rely on them.
for key, value in emb_dict.items():
    print("Plotting compound", key)
    # Vertical polars concat; assumes both frames share the same schema.
    dat = pl.concat([value, dmos_merged], how ="vertical")
    make_jointplot_seaborn_density(dat.to_pandas(), "Metadata_cmpdName", cmpd = key)