In [None]:
import numpy as np
import pandas as pd
from tqdm import tqdm

from sklearn.neighbors import KNeighborsClassifier
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import sklearn.metrics

import scipy.linalg
import scipy.spatial.distance
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rc('ytick', labelsize=7)
import seaborn as sns
import os
import psutil
import polars as pl
import cuml
import pandas as pd
from typing import List
from sklearn.preprocessing import StandardScaler
import numpy as np 
from collections import OrderedDict

In [None]:
import sys
sys.path.append("..")
from analysis_functions.plotting import *
from analysis_functions.sampling import *
from analysis_functions.utils import *


In [None]:
# Relative results directory; no other cell visible in this notebook reads `featdir`.
featdir = "outputs/results/"
# Root of the DeepProfiler project; the metadata CSV is loaded from <PROJECT_ROOT>/inputs/metadata/.
PROJECT_ROOT = "/share/data/analyses/benjamin/Single_cell_project/DP_BEACTICA/"
# Root of the RAPIDS outputs; per-plate parquet profiles and grit results are loaded from here.
RAPIDS_ROOT = "/share/data/analyses/benjamin/Single_cell_project_rapids/Beactica/"
# NOTE(review): not used by any cell visible here — confirm it is still needed.
REG_PARAM = 1e-2

In [None]:
def find_file_with_string(directory, string):
    """
    Finds a file in the specified directory that contains the given string in its name.

    Entries are examined in sorted order, so the result is deterministic when
    several names match (os.listdir itself returns entries in arbitrary order).

    Args:
    directory (str): The directory to search in.
    string (str): The string to look for in the file names.

    Returns:
    str: The path to the first matching entry (in sorted order), or None if the
    directory is missing / not a directory or no entry matches.
    """
    # A plain file path would make os.listdir raise, so require a real directory.
    if not os.path.isdir(directory):
        print(f"The directory {directory} does not exist or is not a directory.")
        return None

    for file in sorted(os.listdir(directory)):
        if string in file:
            return os.path.join(directory, file)

    # Report the miss, then return None explicitly instead of returning print()'s result.
    print(f"No file found with {string}")
    return None


import colorcet as cc
def make_plot_custom(embedding, colouring, save_dir=False, file_name="file_name", name="Emb type", description="details"):
    """Scatter-plot a 2-D UMAP embedding, coloured by a metadata column.

    Args:
        embedding: pandas DataFrame with "UMAP1", "UMAP2" and the `colouring` column.
            It is not modified; the marker-size column is added to a copy.
        colouring: Name of the column that defines hue and marker size.
        save_dir: False (default) to skip saving, True to save into the current
            working directory, or a directory path to save into.
        file_name, name: Concatenated to form the PNG file name.
        description: Second line of the plot title.
    """
    # Set the background to white
    sns.set(style="whitegrid", rc={"figure.figsize": (18, 12),'figure.dpi': 300, "axes.facecolor": "white", "grid.color": "white"})

    # One glasbey colour per distinct value of the colouring column
    unique_treatments = set(embedding[colouring])
    custom_palette = sns.color_palette(cc.glasbey, len(unique_treatments))
    color_dict = dict(zip(unique_treatments, custom_palette))

    # Draw the DMSO control group in black
    if "DIMETHYL SULFOXIDE" in color_dict:
        color_dict["DIMETHYL SULFOXIDE"] = "black"

    # Controls get a slightly smaller marker than treatments
    size_dict = {treatment: 20 if treatment != "DIMETHYL SULFOXIDE" else 15 for treatment in unique_treatments}
    # assign() returns a new frame, so the caller's DataFrame does not gain a 'size' column
    plot_data = embedding.assign(size=embedding[colouring].map(size_dict))

    # Create the scatter plot
    sns_plot = sns.scatterplot(data=plot_data, x="UMAP1", y="UMAP2", hue=colouring, size='size', palette=color_dict, sizes=(8, 25), linewidth=0.1, alpha=0.6)
    sns_plot.tick_params(labelbottom=False, labelleft=False, bottom=False, left=False)
    sns_plot.set_title("Beactica UMAP embedding of "+str(len(plot_data))+" data points" + " \n"+description, fontsize=12)
    sns.move_legend(sns_plot, "lower left", title='Treatments', prop={'size': 10}, title_fontsize=12, markerscale=0.5)

    # Remove axis spines
    sns.despine(bottom=True, left=True)
    # NOTE(review): plt.legend() builds a fresh legend outside the axes, which
    # appears to supersede the move_legend() settings above — confirm which one is wanted.
    plt.legend(loc='upper left', bbox_to_anchor=(1,1))

    # The previous `save_dir == True` test never fired for a directory string and
    # formatted the literal "True" into the path when it did fire.
    if save_dir:
        out_dir = "" if save_dir is True else str(save_dir)
        sns_plot.figure.savefig(os.path.join(out_dir, f"{file_name}{name}.png"), dpi=600)

    plt.show()


## Load metadata and features

In [None]:
# Read the metadata table, drop exact duplicate rows and order it by well and site.
meta = (
    pd.read_csv(os.path.join(PROJECT_ROOT, "inputs", "metadata", "metadata_deepprofiler_beactica.csv"))
    .drop_duplicates()
    .sort_values(by=['Metadata_Well', 'Metadata_Site'])
)
# Upper-case compound names and build a combined "<name> <concentration>" label.
meta['Metadata_cmpdName'] = meta['Metadata_cmpdName'].str.upper()
meta["Metadata_cmpdNameConc"] = meta["Metadata_cmpdName"] + " " + meta["Metadata_cmpdConc"].astype(str)

# Columns removed before the table is de-duplicated again on the polars side.
columns_to_drop = [
    'Unnamed: 0.1', 'Unnamed: 0',
    "AR", "ER", "RNA", "AGP", "DNA", "Mito", "libtxt",
    'solvent',
    'stock_conc', 'stock_conc_unit',
    'cmpd_vol', 'cmpd_vol_unit',
    'well_vol', 'well_vol_unit',
    'treatment_h',
    "cat",
    "cmpd_conc_unit",
    "pertType",
]
meta_pl = pl.DataFrame(meta).drop(columns_to_drop).unique()

# `validation` is the name the later cells use for this table.
validation = meta_pl

In [None]:
# Display the trimmed, de-duplicated metadata table.
meta_pl

## Aggregated analysis

In [None]:
import tqdm
import polars as pl
plates = validation["Metadata_Plate"].unique()
# NOTE(review): `feat_out` is not used by any visible cell.
feat_out = "Beactica/Results/parquets"
feature_dfs = []
site_df = []
for p in tqdm.tqdm(plates):
    # Per-plate normalised single-cell profiles
    file_path = os.path.join(RAPIDS_ROOT, "Results", f"sc_profiles_normalized_Beactica_{p}.parquet")
    # os.path.join never returns None, so test for the file itself in order to
    # skip plates that have no parquet on disk.
    if not os.path.exists(file_path):
        print(f"Skipping plate {p}: {file_path} not found")
        continue
    feature_df = pl.read_parquet(file_path)
    features_fixed = [f for f in feature_df.columns if "Feature" in f]
    # The same per-feature mean expressions serve both aggregation levels
    mean_exprs = [pl.col(feature).mean().alias(feature) for feature in features_fixed]
    # Well-level profile: mean over all cells of a well
    aggregated_df_norm = feature_df.group_by(['Metadata_Plate', 'Metadata_Well', 'Metadata_cmpdName']).agg(mean_exprs)
    # Site-level profile: mean over all cells of a site
    site_level = feature_df.group_by(['Metadata_Plate', 'Metadata_Site', 'Metadata_Well', 'Metadata_cmpdName']).agg(mean_exprs)
    feature_dfs.append(aggregated_df_norm)
    site_df.append(site_level)
# Concatenate all DataFrames in the list outside the loop
master_df_aggregrated = pl.concat(feature_dfs)
site_df_aggregated = pl.concat(site_df)

### Multiple MoA analysis

In [None]:
# Feature columns are identified by the substring "Feature"; everything else is treated as metadata.
features = [col for col in master_df_aggregrated.columns if "Feature" in col]
meta_features = [ col for col in master_df_aggregrated.columns if col not in features]

In [None]:
# Estimated in-memory size of the well-level table, in megabytes.
master_df_aggregrated.estimated_size(unit = "mb") 

In [None]:
def _downcast_float64_features(frame):
    """Return `frame` with every Float64 column whose name contains "Feature" cast to Float32."""
    casts = [
        pl.col(col_name).cast(pl.Float32)
        for col_name, col_dtype in zip(frame.columns, frame.dtypes)
        if "Feature" in col_name and col_dtype == pl.Float64
    ]
    return frame.with_columns(casts)


# Apply the same downcast to the well-level and the site-level tables.
master_df_aggregrated = _downcast_float64_features(master_df_aggregrated)
site_df_aggregated = _downcast_float64_features(site_df_aggregated)

## Aggregated UMAP analysis

In [None]:
import cuml
import math
import umap
def run_umap_and_merge(df, features, option = 'cuml', spread = 4, min_dist=0.1, n_components=2, metric='cosine', aggregate=False, n_neighbors=None):
    """Embed the feature columns of `df` with UMAP and attach the metadata columns.

    Args:
        df: Polars DataFrame holding feature and metadata columns.
        features: Names of the columns fed to UMAP; all other columns are metadata.
        option: "cuml" (GPU, cuml.UMAP) or "standard" (umap.UMAP).
        spread, min_dist, n_components, metric: Passed through to UMAP.
        aggregate: If True, also average the features per plate/well/compound and
            project those averages with the fitted model.
        n_neighbors: Neighbourhood size. None keeps the previous defaults:
            ceil(sqrt(n_rows)) for "cuml" and 15 for "standard".

    Returns:
        merged_df, or (merged_df, merged_agg) when `aggregate` is True. The first
        two embedding columns are named "UMAP1" and "UMAP2".

    Raises:
        ValueError: If `option` is neither "cuml" nor "standard".
    """
    # Fail before doing any work; previously an unknown option only printed a
    # message and then crashed on the undefined embedding.
    if option not in ("cuml", "standard"):
        raise ValueError("Option not available. Please choose 'cuml' or 'standard'")

    # Split the frame into the feature matrix and the metadata columns
    feature_data = df.select(features).to_pandas()
    meta_features = [col for col in df.columns if col not in features]
    meta_data = df.select(meta_features)

    print("Starting UMAP")
    if option == "cuml":
        k = n_neighbors if n_neighbors is not None else math.ceil(np.sqrt(len(feature_data)))
        umap_model = cuml.UMAP(n_neighbors=k, spread= spread,  min_dist=min_dist, n_components=n_components, metric=metric).fit(feature_data)
        umap_embedding = umap_model.transform(feature_data)
    else:
        k = n_neighbors if n_neighbors is not None else 15
        umap_model = umap.UMAP(n_neighbors=k, spread = spread, min_dist=min_dist, n_components=n_components, metric=metric, n_jobs = -1)
        umap_embedding = umap_model.fit_transform(feature_data)

    # Convert the embedding to Polars and give the first two axes stable names
    umap_df = pl.DataFrame(umap_embedding)
    axis_names = {umap_df.columns[0]: "UMAP1", umap_df.columns[1]: "UMAP2"}
    umap_df = umap_df.rename(axis_names)

    # Rows of the embedding are in the same order as `df`, so a horizontal concat lines them up
    merged_df = pl.concat([meta_data, umap_df], how="horizontal")

    if not aggregate:
        return merged_df

    print("Aggregating data")
    group_keys = ['Metadata_Plate', 'Metadata_Well', 'Metadata_cmpdName']
    aggregated_data = (
        df.group_by(group_keys)
        .agg([pl.col(feature).mean().alias(feature) for feature in features])
        .to_pandas()
    )
    # Project the per-well means into the embedding fitted on the individual rows
    aggregated_umap_embedding = umap_model.transform(aggregated_data[features])
    umap_agg = pl.DataFrame(aggregated_umap_embedding).rename(axis_names)

    aggregated_meta_data = pl.DataFrame(aggregated_data[group_keys])
    merged_agg = pl.concat([aggregated_meta_data, umap_agg], how="horizontal")
    return merged_df, merged_agg

In [None]:
# Well-level UMAP with the default (cuml) backend and a larger min_dist than the function default.
well_umap = run_umap_and_merge(master_df_aggregrated, features, min_dist = 0.4)

In [None]:
# Plot the well-level embedding coloured by compound name (save_dir is left at its default, so nothing is written).
make_plot_custom(well_umap.to_pandas(), "Metadata_cmpdName", description = "Well-level")

## Single cells

### Prepare grit

In [None]:
import time
def subsample_dataset_pl(df, grouping_cols, fraction=0.5):
    '''
    Subsample a Polars DataFrame group by group so that every group keeps the
    same fraction of its rows.

    Parameters:
    - df: The original Polars DataFrame.
    - grouping_cols: Column name(s) defining the groups that are sampled independently.
    - fraction: The fraction of rows to keep in each group, between 0 and 1.
      Default is 0.5 (50%). The per-group count is rounded down, so small
      groups can end up with zero rows.

    Returns:
    - A subsampled Polars DataFrame (empty, with the same schema, if `df` has no rows).

    Raises:
    - ValueError: if `fraction` is outside [0, 1].
    '''
    if not 0 <= fraction <= 1:
        raise ValueError(f"fraction must be between 0 and 1, got {fraction}")

    # Start tracking time
    start_time = time.time()

    subsampled_data = []
    grouped = df.group_by(grouping_cols)

    # Sample each group with a fixed seed so reruns select the same rows
    for name, group in tqdm.tqdm(grouped, desc="Subsampling groups", unit="group"):
        subsample_size = int(group.height * fraction)
        subsampled_data.append(group.sample(n=subsample_size, seed=42))

    # pl.concat rejects an empty list, so an empty input returns an empty frame instead
    subsampled_df = pl.concat(subsampled_data) if subsampled_data else df.head(0)

    # Print running time
    end_time = time.time()
    print(f"Finished in {end_time - start_time:.2f} seconds.")

    return subsampled_df

import polars as pl
import os
import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

def read_and_merge_single_file(df, plate, well, site, location_folder):
    """Attach the nuclei-location CSV of one plate/well/site to the matching rows of `df`.

    Returns the rows of `df` for that site with the CSV columns appended side by
    side, or None when the CSV is missing or its row count differs from the
    number of matching rows. Rows are paired purely by position.
    """
    csv_path = f"{location_folder}/{plate}/{well}-{site}-Nuclei.csv"
    if not os.path.exists(csv_path):
        return None

    locations = pl.read_csv(csv_path)
    same_site = (
        (pl.col("Metadata_Plate") == plate)
        & (pl.col("Metadata_Well") == well)
        & (pl.col("Metadata_Site") == site)
    )
    site_rows = df.filter(same_site)

    # A positional merge is only meaningful when both sides have the same length
    if len(locations) != len(site_rows):
        return None
    return pl.concat([site_rows, locations], how="horizontal")

def merge_locations_parallel(df, location_folder, max_workers=10):
    """Merge nuclei-location CSVs into `df`, one plate/well/site at a time, using threads.

    Args:
        df: Polars DataFrame with Metadata_Plate, Metadata_Well and Metadata_Site columns.
        location_folder: Root folder containing <plate>/<well>-<site>-Nuclei.csv files.
        max_workers: Size of the thread pool.

    Returns:
        Vertical concatenation of all sites that could be merged. Sites whose CSV
        is missing or has a different row count are dropped, and the number
        dropped is printed. Site order follows task completion and is therefore
        not deterministic.

    Raises:
        ValueError: If no site could be merged.
    """
    key_cols = ["Metadata_Plate", "Metadata_Well", "Metadata_Site"]
    # Select the key columns before de-duplicating so that to_dicts() only
    # materialises three values per site instead of every column.
    combinations = df.select(key_cols).unique()
    dfs_to_concat = []

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(read_and_merge_single_file, df, comb["Metadata_Plate"], comb["Metadata_Well"], comb["Metadata_Site"], location_folder)
            for comb in combinations.to_dicts()
        ]

        for future in tqdm.tqdm(as_completed(futures), total=len(futures)):
            result = future.result()
            if result is not None:
                dfs_to_concat.append(result)

    skipped = len(futures) - len(dfs_to_concat)
    if skipped:
        print(f"Skipped {skipped} of {len(futures)} sites (missing CSV or row-count mismatch)")
    if not dfs_to_concat:
        raise ValueError(f"No location files under {location_folder} could be merged")

    # Concatenate all DataFrames at once at the end
    return pl.concat(dfs_to_concat, how="vertical")

In [None]:
import tqdm
import polars as pl
plates = validation["Metadata_Plate"].unique()
# NOTE(review): `feat_out` is not used by any visible cell.
feat_out = "Beactica/Results/parquets"
feature_dfs = []
for p in tqdm.tqdm(plates):
    # Per-plate normalised single-cell profiles
    file_path = os.path.join(RAPIDS_ROOT, "Results", f"sc_profiles_normalized_Beactica_{p}.parquet")
    # os.path.join never returns None, so test for the file itself in order to
    # skip plates that have no parquet on disk.
    if not os.path.exists(file_path):
        print(f"Skipping plate {p}: {file_path} not found")
        continue
    feature_df = pl.read_parquet(file_path)
    print("Starting location merge")
    # NOTE(review): this path carries a /home/jovyan prefix that PROJECT_ROOT lacks — confirm both point at the same data.
    feature_df = merge_locations_parallel(feature_df,  "/home/jovyan/share/data/analyses/benjamin/Single_cell_project/DP_BEACTICA/inputs/locations/", max_workers= 20)
    # Keep 5% of the cells of every plate/well/site
    sampled_features = subsample_dataset_pl(feature_df, ["Metadata_Plate", "Metadata_Well", "Metadata_Site"], fraction = 0.05)
    feature_dfs.append(sampled_features)
# Concatenate all DataFrames in the list outside the loop
master_df = pl.concat(feature_dfs)

In [None]:
# Raw string: "\[" and "\]" are invalid escape sequences in a normal string
# literal and trigger a warning on recent Python versions. The regex is unchanged.
mask = master_df["compound_id"].str.contains(r"\[|\]")

# Upper-case compound_id on the rows where it contains a square bracket; all other rows keep their value
master_df = master_df.with_columns(
    pl.when(mask).then(master_df["compound_id"].str.to_uppercase()).otherwise(master_df["compound_id"]).alias("compound_id")
)

In [None]:
# Persist the 5% subsample.
# NOTE(review): the following cell reads this file back through RAPIDS_ROOT, which has no /home/jovyan prefix — confirm both paths resolve to the same location.
master_df.write_parquet("/home/jovyan/share/data/analyses/benjamin/Single_cell_project_rapids/Beactica/Results/sc_profiles_all_sampled_5%_BEACTICA.parquet")

In [None]:
# Reload the 5% subsample from disk so later cells do not depend on the sampling loop having run.
mad_norm_df = pl.read_parquet(os.path.join(RAPIDS_ROOT,"Results/sc_profiles_all_sampled_5%_BEACTICA.parquet"))

In [None]:
# Number of sampled rows per compound_id, sorted by compound_id.
mad_norm_df.group_by("compound_id").count().sort("compound_id")

## Load grit data

In [None]:
plates = validation["Metadata_Plate"].unique()
def load_grit_data(folder, plates):
    """
    Load per-plate grit results from `folder` and join features with grit metrics.

    For every plate, files whose name contains the plate id are collected. The
    "neg_control" file supplies the identities of the control cells to keep,
    "sc_features" files supply feature rows and "sc_grit" files supply metric rows.

    :param folder: Path to the folder containing the Parquet files.
    :param plates: Iterable of plate identifiers to look for in the file names.
    :return: Polars DataFrame of feature rows inner-joined with one metric row
             per cell on "Metadata_Cell_Identity".
    :raises FileNotFoundError: if a plate has files but none is a neg_control file.
    :raises ValueError: if no feature or no metric file was found for any plate.
    """
    # Collect the pieces in lists and concatenate once; growing a frame with
    # pl.concat inside the loop copies the accumulated data on every iteration.
    feature_parts = []
    metric_parts = []
    all_files = os.listdir(folder)

    for plate in tqdm.tqdm(plates):
        file_names = [file for file in all_files if plate in file]
        if len(file_names) == 0:
            print(f"Plate {plate} not found")
            continue
        neg_paths = [file for file in file_names if "neg_control" in file]
        if not neg_paths:
            raise FileNotFoundError(f"No neg_control file found for plate {plate} in {folder}")
        neg_cells = pl.read_parquet(os.path.join(folder, neg_paths[0]))["Metadata_Cell_Identity"].unique()
        for i in file_names:
            file_path = os.path.join(folder, i)
            if "sc_features" in i:
                # Keep all non-DMSO cells, and only those DMSO cells listed in the neg_control file
                feat = pl.read_parquet(file_path).filter(((pl.col("Metadata_Cell_Identity").is_in(neg_cells))) | ~(pl.col("compound_id") == "[DMSO]"))
                feature_parts.append(feat)
            elif "sc_grit" in i:
                metrics = pl.read_parquet(file_path)
                # Rows scored against their own group come first, rows scored against another group second
                metrics_treat = metrics.filter(pl.col("group") == pl.col("comp")).drop("comp")
                metrics_ctrl = metrics.filter(pl.col("group") != pl.col("comp")).drop("comp")
                metric_parts.extend([metrics_treat, metrics_ctrl])

    if not feature_parts or not metric_parts:
        raise ValueError(f"No sc_features / sc_grit files found in {folder} for the requested plates")

    feature_dfs = pl.concat(feature_parts)
    # keep="first" makes the surviving row per cell deterministic: the earliest
    # appended row wins, which favours the group == comp rows of a file.
    # NOTE(review): confirm that preferring those rows is the intended rule.
    metric_df = pl.concat(metric_parts).unique(subset=["Metadata_Cell_Identity"], keep="first")
    # Merge Feature and Metric DataFrames
    merged_df = feature_dfs.join(metric_df, on="Metadata_Cell_Identity", how= "inner")
    return merged_df

In [None]:
# Load features joined with grit metrics for every plate in the metadata table.
grit_cells = load_grit_data(os.path.join(RAPIDS_ROOT,"Results/grit/"), plates)

In [None]:
# DMSO (negative control) cells taken from the 5% subsample.
neg_ctrl_cells = mad_norm_df.filter(pl.col("compound_id") == "[DMSO]")

In [None]:
# Column names of the grit table and of the sampled control cells
columns_grit = set(grit_cells.columns)
columns_sample = set(neg_ctrl_cells.columns)

# Find common columns
common_columns = columns_grit.intersection(columns_sample)

# Convert to a list for use as join keys (element order is arbitrary because it comes from a set)
common_columns_list = list(common_columns)

In [None]:
# Left join: every sampled control cell is kept, and grit-only columns are filled where a match exists.
# NOTE(review): the join key is every shared column; if that includes float feature columns the match relies on exact float equality — confirm.
neg_ctrl_grit = neg_ctrl_cells.join(grit_cells, on = common_columns_list, how = "left")

In [None]:
# Replace the DMSO rows of the grit table with the sampled control cells joined above.
# NOTE(review): pl.concat defaults to a vertical concat; presumably both frames must share the same columns in the same order — verify.
grit_cells_fixed = pl.concat([grit_cells.filter(pl.col("compound_id") != "[DMSO]"), neg_ctrl_grit])

In [None]:
import polars as pl

# Replace three specific concentration values (0.09 -> 0.1, 9.99 -> 10,
# 0.99 -> 1); every other concentration is passed through unchanged.
conc = pl.col('Metadata_cmpdConc')
grit_fixed_cells = grit_cells_fixed.with_columns(
    pl.when(conc == 0.09).then(0.1)
    .when(conc == 9.99).then(10)
    .when(conc == 0.99).then(1)
    .otherwise(conc)
    .alias('Metadata_cmpdConc')
)

# Keep only nuclei whose centre lies strictly inside the 150..2850 window on both axes.
centre_x = pl.col("Nuclei_Location_Center_X")
centre_y = pl.col("Nuclei_Location_Center_Y")
inside_window = (centre_x > 150) & (centre_x < 2850) & (centre_y > 150) & (centre_y < 2850)
grit_fixed_cells = grit_fixed_cells.filter(inside_window)



In [None]:
# Persist the concentration-corrected, border-filtered grit table.
# NOTE(review): this writes under Results/grit/, whereas the following cell reads from deepprofiler/Results/grit/ — confirm which location is current.
grit_fixed_cells.write_parquet(os.path.join(RAPIDS_ROOT,"Results/grit/sc_grit_full_FILTERED.parquet"))

In [None]:
# Load the filtered grit table used by the remaining cells.
# NOTE(review): the path contains a deepprofiler/ segment that the write in the previous cell does not — confirm this is the same file.
grit_full = pl.read_parquet(os.path.join(RAPIDS_ROOT,"deepprofiler/Results/grit/sc_grit_full_FILTERED.parquet"))

## Merge cellprofiler features

In [None]:
import re
import gc
# CellProfiler columns kept when merging with the DeepProfiler cells: the
# plate/site/well and nuclei-centre columns used as join keys, compound
# annotations, mean/integrated intensities and areas.
features_interesting = ['Metadata_Plate',
 'Metadata_Site',
 'Metadata_Well',
 'Location_Center_X_nuclei',
 'Location_Center_Y_nuclei',
 'Metadata_cmpdName',
 'compound_id',
 'Metadata_cmpdConc',
 'Intensity_MeanIntensity_illumCONC_nuclei',
 'Intensity_MeanIntensity_illumHOECHST_nuclei',
 'Intensity_IntegratedIntensity_illumHOECHST_nuclei',
 'Intensity_MeanIntensity_illumMITO_nuclei',
 'Intensity_MeanIntensity_illumPHAandWGA_nuclei',
 'Intensity_MeanIntensity_illumSYTO_nuclei',
  'Intensity_MeanIntensity_illumCONC_cells',
 'Intensity_MeanIntensity_illumHOECHST_cells',
 'Intensity_MeanIntensity_illumMITO_cells',
 'Intensity_MeanIntensity_illumPHAandWGA_cells',
 'Intensity_MeanIntensity_illumSYTO_cells',
 'AreaShape_Area_cells',
 'AreaShape_Area_nuclei',
 ]
 
def is_meta_column(c):
    """Return True when column name `c` matches any of the metadata regex patterns below."""
    meta_patterns = (
        "Metadata",
        "^Count",
        "ImageNumber",
        "Object",
        "Parent",
        "Children",
        "Plate",
        "compound",
        "Well",
        "location",
        "Location",
        "_[XYZ]_",
        "_[XYZ]$",
        "Phase",
        "Scale",
        "Scaling",
        "Width",
        "Height",
        "Group",
        "FileName",
        "PathName",
        "BoundingBox",
        "URL",
        "Execution",
        "ModuleError",
        "LargeBrightArtefact",
    )
    # re.search matches anywhere in the name unless the pattern anchors itself
    return any(re.search(pattern, c) for pattern in meta_patterns)

def merge_cellprofiler_deepprofiler(cell_locations, feats, plates=None):
    """Join DeepProfiler cells with selected CellProfiler features, plate by plate.

    Args:
        cell_locations: Polars DataFrame of cells with Metadata_Plate/Site/Well and
            Nuclei_Location_Center_X / Nuclei_Location_Center_Y columns.
        feats: CellProfiler columns to load. Must include the join keys
            Metadata_Plate, Metadata_Site, Metadata_Well, Location_Center_X_nuclei
            and Location_Center_Y_nuclei.
        plates: Plate ids to process; defaults to the eight PB0000xx plates.

    Returns:
        List with one joined Polars DataFrame per plate (inner join, so cells
        without an exact coordinate match are dropped).
    """
    if plates is None:
        plates = ['PB000051',
                  'PB000047',
                  'PB000049',
                  'PB000053',
                  'PB000046',
                  'PB000048',
                  'PB000050',
                  'PB000052']
    feats = list(feats)

    cells = []

    for p in tqdm.tqdm(plates):
        loc_filt = cell_locations.filter(pl.col("Metadata_Plate") == p)
        print("Reading in plate", p)
        # Load only the requested columns (previously the whole file was read and
        # the `feats` argument was ignored in favour of a global list); select()
        # fixes the column order to that of `feats`.
        df = pl.read_parquet(
            os.path.join(RAPIDS_ROOT, "cellprofiler/feature_parquets", f"sc_profiles_cellprofiler_{p}.parquet"),
            columns=feats,
        ).select(feats)
        print("Joining with deepprofiler")
        # NOTE(review): the nuclei coordinates are join keys, so a match presumably needs identical values on both sides — confirm.
        temp = loc_filt.join(df, left_on = ["Metadata_Plate", "Metadata_Site", "Metadata_Well", "Nuclei_Location_Center_X", "Nuclei_Location_Center_Y"], right_on=["Metadata_Plate", "Metadata_Site", "Metadata_Well", "Location_Center_X_nuclei", "Location_Center_Y_nuclei"], how = "inner")
        cells.append(temp)
        # Release the per-plate frame before the next plate is read
        gc.collect()
    return cells

In [None]:
# Inspect the CellProfiler parquet of a single plate (PB000047).
pl.read_parquet(os.path.join(RAPIDS_ROOT, "cellprofiler/feature_parquets", f"sc_profiles_cellprofiler_PB000047.parquet"))

In [None]:
import gc
# Force a garbage-collection pass before the plate-by-plate merge.
gc.collect()

In [None]:
# One joined frame per plate: grit cells matched to their CellProfiler features.
cells_list = merge_cellprofiler_deepprofiler(grit_full, features_interesting)

In [None]:
# Stack the per-plate frames into a single table.
cells_merged = pl.concat(cells_list)

In [None]:
# Written relative to the current working directory.
# NOTE(review): the "Fix segmentation error" section reads this file from deepprofiler/Results/ — confirm the two paths agree.
cells_merged.write_parquet("sc_grit_merged_cellprofiler.parquet")

## Fix segmentation error

In [None]:
# Load the merged grit + CellProfiler table (path is relative to the working directory).
# NOTE(review): the cell that writes this file uses the bare file name without deepprofiler/Results/ — confirm the location.
sc_profiles = pl.read_parquet("deepprofiler/Results/sc_grit_merged_cellprofiler.parquet")

In [None]:
# Display the merged table.
sc_profiles

In [None]:
# Single well (plate PB000051, well K18) used as a small test case for the sampling below.
test_small = sc_profiles.filter((pl.col("Metadata_Plate") == "PB000051") & (pl.col("Metadata_Well") == "K18"))

In [None]:
# Display the single-well test subset.
test_small

In [None]:
import polars as pl
import gc
import numpy as np
from sklearn.neighbors import RadiusNeighborsRegressor
import random

def sample_one_per_radius(X, regressor, rng=None):
    """Greedily pick points so that each neighbourhood visited contributes at most one pick.

    Points are visited in order. If none of a point's radius-neighbours has been
    picked yet, one of those neighbours is chosen at random and recorded.

    Args:
        X: Sequence of points accepted by `regressor.radius_neighbors`.
        regressor: Fitted object exposing `radius_neighbors(points, return_distance=False)`
            that returns, per query point, the indices of its neighbours.
        rng: Optional `random.Random` instance for reproducible picks. None keeps
            the previous behaviour of drawing from the global `random` module.

    Returns:
        Set of picked row indices (ints).
    """
    chooser = random if rng is None else rng
    sampled_indices = set()
    # Nothing to query for an empty input
    if len(X) == 0:
        return sampled_indices

    # One batched neighbour query instead of one call per point
    neighbourhoods = regressor.radius_neighbors(X, return_distance=False)
    for indices in neighbourhoods:
        if sampled_indices.isdisjoint(indices):
            # list() lets choice() work the same on a NumPy index array as on a plain list
            sampled_indices.add(int(chooser.choice(list(indices))))
    return sampled_indices

def assign_sampling_labels(df, radius=50):
    """Add a 0/1 "Sampled" column marking the cells picked by `sample_one_per_radius`.

    The frame is processed one plate/well/site group at a time. Within each
    group a radius-neighbour model is fitted on the nuclei centre coordinates
    and used to pick cells. The labelled groups are concatenated and returned.
    """
    site_keys = ['Metadata_Plate', 'Metadata_Well', 'Metadata_Site']
    labelled_groups = []

    for _, site_cells in tqdm.tqdm(df.groupby(site_keys)):
        # Nuclei centres of this site as an (n_cells, 2) array
        coords = site_cells.select(['Nuclei_Location_Center_X', 'Nuclei_Location_Center_Y']).to_numpy()

        # Only the neighbour lookup of the model is used; the zero targets are placeholders
        neighbour_model = RadiusNeighborsRegressor(radius=radius)
        neighbour_model.fit(coords, np.zeros(coords.shape[0]))

        picked = sample_one_per_radius(coords, neighbour_model)

        # 1 for picked rows, 0 for the rest, in row order
        flags = [int(row in picked) for row in range(len(site_cells))]
        labelled_groups.append(site_cells.with_columns(pl.Series("Sampled", flags)))
        gc.collect()

    return pl.concat(labelled_groups)

In [None]:
# Label the single-well test subset with a 50-unit sampling radius.
check = assign_sampling_labels(test_small, radius= 50)

In [None]:
# Display the labelled subset, including the new "Sampled" column.
check

In [None]:
def plot_sampled_points(df):
    """Scatter the nuclei centres of `df`, grey for Sampled == 0 and orange for Sampled == 1."""
    points = df.to_pandas()

    # Drawn in this order so that the sampled points end up on top
    layers = [
        (0, dict(color='grey', alpha=0.5, label='Not Sampled', s = 10)),
        (1, dict(color='tab:orange', label='Sampled', s = 5)),
    ]
    for flag, style in layers:
        subset = points[points["Sampled"] == flag]
        plt.scatter(
            subset["Nuclei_Location_Center_X"],
            subset["Nuclei_Location_Center_Y"],
            **style
        )

    plt.xlabel('Nuclei_Location_Center_X')
    plt.ylabel('Nuclei_Location_Center_Y')
    plt.title('Sampled vs Non-Sampled Nuclei Locations')
    # Flip the y axis so the plot follows image coordinates (origin at the top)
    plt.gca().invert_yaxis()
    plt.show()

plot_sampled_points(check)

In [None]:
plt.style.use('default')
# `X` was never defined at notebook scope (it only exists as a local inside
# assign_sampling_labels), so this cell failed on a fresh kernel.
# NOTE(review): taking X to be the sampled points of `check`, in line with the title below — confirm.
X = (
    check.filter(pl.col("Sampled") == 1)
    .select(['Nuclei_Location_Center_X', 'Nuclei_Location_Center_Y'])
    .to_numpy()
)
# One vectorised scatter call instead of one call per point
plt.scatter(X[:, 0], X[:, 1], color='tab:orange')

plt.xlabel('X Coordinate')
plt.ylabel('Y Coordinate')
plt.title('Nuclei Locations with Sampled Points Highlighted')
plt.show()

## Summary statistics

In [None]:
plates = validation["Metadata_Plate"].unique()
count_df_plates = []
for p in tqdm.tqdm(plates):
    # Per-plate normalised single-cell profiles
    file_path = os.path.join(RAPIDS_ROOT, "Results", f"sc_profiles_normalized_Beactica_{p}.parquet")
    # os.path.join never returns None, so test for the file itself in order to
    # skip plates that have no parquet on disk.
    if not os.path.exists(file_path):
        print(f"Skipping plate {p}: {file_path} not found")
        continue
    feature_df = pl.read_parquet(file_path)
    # Number of cells per plate and compound
    aggregated_df = feature_df.group_by(["Metadata_Plate", 'compound_id']).count()
    count_df_plates.append(aggregated_df)
# Concatenate all DataFrames in the list outside the loop
count_df = pl.concat(count_df_plates)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Cell counts per compound and plate, taken from the filtered grit table
aggregated_cmpd = grit_full.group_by(['compound_id', 'Metadata_Plate']).count()
aggregated_cmpd = aggregated_cmpd.to_pandas()
aggregated_cmpd['Metadata_Plate'] = aggregated_cmpd['Metadata_Plate'].astype('category')

# Sorted compound and plate lists fix the bar order and the subplot order
unique_compounds = aggregated_cmpd['compound_id'].unique()
unique_compounds.sort()
unique_plates = sorted(aggregated_cmpd['Metadata_Plate'].unique())

# One colour per compound, identical across subplots
palette = sns.color_palette("viridis", n_colors=unique_compounds.size)

# One subplot row per plate. squeeze=False always returns a 2-D array of axes,
# so indexing also works when there is only a single plate.
n_plates = len(unique_plates)
fig, axes = plt.subplots(n_plates, 1, figsize=(20, 5 * n_plates), sharex='all', squeeze=False)
axes = axes[:, 0]

bar_width = 0.5
bar_positions = np.arange(unique_compounds.size)

for i, plate in enumerate(unique_plates):
    subset = aggregated_cmpd[aggregated_cmpd['Metadata_Plate'] == plate]

    # Align the counts with the sorted compound list; compounds absent from the plate get 0
    counts = subset.groupby('compound_id')['count'].mean().reindex(unique_compounds).fillna(0)

    axes[i].bar(bar_positions, counts, width=bar_width, color=palette)
    axes[i].set_title(f'Plate: {plate}')
    axes[i].set_ylabel('Cell Count')

# The x axis is shared, so tick labels are only set on the bottom subplot
axes[-1].set_xticks(bar_positions)
axes[-1].set_xticklabels(unique_compounds, rotation=90)

plt.tight_layout()

# Make sure the output folder exists before saving
os.makedirs('Figures', exist_ok=True)
plt.savefig('Figures/cellcount_dist_sampled.png', dpi=300)
plt.show()

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import math

# Work on a pandas copy of the filtered grit table
df = grit_full.to_pandas()

# The DMSO control is shown in every figure as the reference distribution
special_compound = "[DMSO]"

# All compounds except the control
unique_compounds = df[df['compound_id'] != special_compound]['compound_id'].unique()

# Each figure holds the control plus up to 7 compounds. The previous index
# arithmetic (start = i*8 - 1, end = start + 7) skipped the compounds at
# positions 14, 22, 30, ..., so they never appeared in any figure.
compounds_per_plot = 7
num_plots = math.ceil(len(unique_compounds) / compounds_per_plot)

# Make sure the output folder exists before saving
os.makedirs('Figures', exist_ok=True)

for i in range(num_plots):
    start_idx = i * compounds_per_plot
    subset_compounds = [special_compound, *unique_compounds[start_idx:start_idx + compounds_per_plot]]

    subset_df = df[df['compound_id'].isin(subset_compounds)]
    subset_df = subset_df[subset_df["grit"].notnull()]

    # Explicit x order: control first, then the compounds of this chunk that still have grit values
    present = set(subset_df['compound_id'])
    plot_order = [compound for compound in subset_compounds if compound in present]

    # Create a new figure for each plot
    plt.figure(figsize=(24, 15), dpi=300)

    sns.violinplot(data=subset_df, x='compound_id', y='grit', order=plot_order,
                   palette="GnBu", inner="box", density_norm='area')

    plt.title(f'Violin Plot for compound group {i+1}', fontsize=20)
    plt.xlabel('Compound', fontsize=15)
    plt.ylabel('Grit', fontsize=15)
    plt.xticks(rotation=45, fontsize=12)
    plt.yticks(fontsize=12)

    # Save the plot with a naming pattern
    plt.savefig(f'Figures/violin_compound_group_{i+1}.png')
    # Close the plot to free memory
    plt.close()



In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

def generate_dot_plots(df, grouping_col, cmpd_col):
    """Save dot plots of cell count (dot size) and median grit (dot colour).

    One dot is drawn per (grouping_col, cmpd_col) combination. Compounds are
    split into figures of 14, each saved as Figures/grit_cell_count_group<N>.png.
    Colour and size scales are shared across all figures, so dots can be
    compared between figures and against the legends.

    Args:
        df: pandas DataFrame with a 'grit' column plus the two grouping columns.
        grouping_col: Column shown on the x axis (converted to str if numeric).
        cmpd_col: Column shown on the y axis.
    """
    # Aggregate the data
    summary_df = df.groupby([grouping_col, cmpd_col]).agg(
        num_cells=pd.NamedAgg(column='grit', aggfunc='size'),
        median_grit=pd.NamedAgg(column='grit', aggfunc='median')
    ).reset_index()

    # Treat a numeric grouping column as categorical on the x axis
    if grouping_col and pd.api.types.is_numeric_dtype(summary_df[grouping_col]):
        summary_df[grouping_col] = summary_df[grouping_col].astype(str)

    unique_compounds = summary_df[cmpd_col].unique()

    # Number of figures (14 compounds per figure)
    compounds_per_plot = 14
    num_plots = len(unique_compounds) // compounds_per_plot + (1 if len(unique_compounds) % compounds_per_plot > 0 else 0)

    # Colour scale shared by every figure
    global_min_grit = summary_df['median_grit'].min()
    global_max_grit = summary_df['median_grit'].max()
    norm = plt.Normalize(global_min_grit, global_max_grit)
    sm = plt.cm.ScalarMappable(cmap="viridis", norm=norm)
    sm.set_array([])

    # Size scale shared by every figure, plus five representative legend entries
    min_cells = summary_df['num_cells'].min()
    max_cells = summary_df['num_cells'].max()
    representative_sizes = np.linspace(min_cells, max_cells, 5, dtype=int)
    labels = [f'{size} Cells' for size in representative_sizes]

    smallest_dot, largest_dot = 100, 500
    cell_span = max_cells - min_cells

    def _legend_marker_area(n_cells):
        # Same linear mapping that scatterplot applies with size_norm=(min_cells, max_cells)
        if cell_span == 0:
            return smallest_dot
        return smallest_dot + (n_cells - min_cells) / cell_span * (largest_dot - smallest_dot)

    os.makedirs('Figures', exist_ok=True)

    for i in range(num_plots):
        # Compounds shown in this figure
        compounds_subset = unique_compounds[i*compounds_per_plot : (i+1)*compounds_per_plot]
        plot_data = summary_df[summary_df[cmpd_col].isin(compounds_subset)]

        fig, ax = plt.subplots(figsize=(20, 15), dpi=300)

        # size_norm pins the size scale to the global range; without it seaborn
        # rescales per figure and the size legend below would not match the dots.
        sns.scatterplot(
            data=plot_data,
            x=grouping_col, y=cmpd_col,
            size='num_cells', sizes=(smallest_dot, largest_dot),
            size_norm=(min_cells, max_cells),
            hue='median_grit', palette='viridis',
            hue_norm=(global_min_grit, global_max_grit),
            alpha=0.6, edgecolor='w', ax=ax,
            legend=False,  # custom colour bar and size legend are drawn below
        )

        # Adjust layout for the plot area
        plt.tight_layout(pad=4)

        # Colour bar using the global normalisation
        cax = fig.add_axes([ax.get_position().x1+0.05, ax.get_position().y0, 0.03, ax.get_position().height / 2])
        cbar = fig.colorbar(sm, cax=cax)
        cbar.set_label('Median Grit')

        # Dot-size legend placed above the colour bar
        size_legend_ax = fig.add_axes([ax.get_position().x1+0.05, ax.get_position().y0 + ax.get_position().height / 2 + 0.02, 0.03, 0.1])
        for size, label in zip(representative_sizes, labels):
            size_legend_ax.scatter([], [], s=_legend_marker_area(size), label=label, color='black', alpha=0.6)
        size_legend_ax.legend(title='Number of Cells', loc='center', frameon=False, fontsize='large')
        size_legend_ax.axis('off')

        # Leave room on the right for the colour bar and legend
        plt.subplots_adjust(right=0.85)

        plt.savefig(f'Figures/grit_cell_count_group{i+1}.png')
        # Display the figure, then release it so large 300-dpi figures do not pile up in memory
        plt.show()
        plt.close(fig)


In [None]:
# Dot plots of cell count and median grit per concentration (x axis) and compound (y axis).
generate_dot_plots(grit_full.to_pandas(), "Metadata_cmpdConc", "compound_id")

In [None]:
# Cell count and median grit of the DMSO control rows, per concentration.
# NOTE(review): this uses the in-memory `grit_fixed_cells`, whereas the plots above use `grit_full` reloaded from disk — confirm they hold the same data.
grit_fixed_cells.filter(pl.col("Metadata_cmpdName") == "DIMETHYL SULFOXIDE").to_pandas().groupby(["Metadata_cmpdConc", 'Metadata_cmpdName']).agg(
        num_cells=pd.NamedAgg(column='grit', aggfunc='size'),
        median_grit=pd.NamedAgg(column='grit', aggfunc='median')
    )