In [None]:
import polars as pl
import pandas as pd
import os
import sys
import matplotlib.pyplot as plt
import tqdm
import gc

## Import classification compounds

In [None]:
class_comp = pl.read_csv("MoaLive_compoundProp_v5.csv")

In [None]:
class_comp.group_by("moa_broad").count()

In [None]:
specs3k_meta = pl.read_csv("/share/data/analyses/benjamin/Single_cell_project/DP_specs3k/inputs/metadata/Metadata_specs3k_DeepProfiler.csv")

In [None]:
specs3k_meta_big = pl.read_csv("/share/data/analyses/benjamin/Single_cell_project/specs3k/specs3k_metadata.csv")

In [None]:
# Left-join the classified compounds onto the specs3k metadata by batch id to pick up
# cbkid / barcode / well / compound_name, then de-duplicate the joined rows.
specs3k_filter = class_comp.join(specs3k_meta_big.select(["cbkid", "batchid", "barcode", "well", "compound_name"]), right_on= "batchid", left_on= "BatchID", how = "left").unique()
# Unique cbkid values present after the join.
specs3k_class_cbkid = list(specs3k_filter["cbkid"].unique())
# Keep metadata rows whose compound name is one of those cbkids, plus the "[dmso]" rows.
specs3k_class_comp = specs3k_meta.filter(pl.col("Metadata_cmpdName").is_in(specs3k_class_cbkid + ["[dmso]"]))
# Drop the existing "moa" column and left-join the class annotations per (compound, well, plate).
specs3k_class_comp = specs3k_class_comp.drop(["moa"]).join(specs3k_filter, left_on= ["Metadata_cmpdName", "Metadata_Well", "Metadata_Plate"], right_on= ["cbkid", "well", "barcode"], how = "left")

In [None]:
specs2k_comp = pl.read_csv("/share/data/analyses/benjamin/Single_cell_project/specs2k_cmpd.csv")
specs2k_meta = pl.read_csv("/share/data/analyses/benjamin/Single_cell_project/DP_specs2k/inputs/metadata/metadata_deepprofilerspecs2k.csv")

In [None]:
# Left-join the classified compounds onto the specs2k compound table by batch id to pick up
# cbkid / barcode / well / compound_name, then de-duplicate the joined rows.
specs2k_filter = class_comp.join(specs2k_comp.select(["cbkid", "batchid", "barcode", "well", "compound_name"]), right_on= "batchid", left_on= "BatchID", how = "left").unique()
# Unique cbkid values present after the join.
specs2k_class_cbkid = list(specs2k_filter["cbkid"].unique())
# Keep metadata rows whose compound name is one of those cbkids, plus the "[dmso]" rows.
specs2k_class_comp = specs2k_meta.filter(pl.col("Metadata_cmpdName").is_in(specs2k_class_cbkid + ["[dmso]"]))
# Left-join the class annotations per (compound, well, plate).
specs2k_class_comp = specs2k_class_comp.join(specs2k_filter,left_on= ["Metadata_cmpdName", "Metadata_Well", "Metadata_Plate"], right_on= ["cbkid", "well", "barcode"], how = "left")

In [None]:
# Stack the specs2k and specs3k compound tables after removing the per-channel columns
# (and the "Unnamed: 0" column that only the specs3k table drops here).
specs5k_classication_list = pl.concat([specs2k_class_comp.drop(["DNA", "ER", "AGP", "Mito", "RNA"]), specs3k_class_comp.drop(["Unnamed: 0", "DNA", "ER", "AGP", "Mito", "RNA"])])
# Upper-case the compound names in place (e.g. "[dmso]" becomes "[DMSO]").
specs5k_classication_list = specs5k_classication_list.with_columns(specs5k_classication_list['Metadata_cmpdName'].str.to_uppercase())

In [None]:
specs5k_classication_list.write_parquet("specs5k_compound_list.parquet")

In [None]:
# NOTE(review): the cell above writes "specs5k_compound_list.parquet" to the current working
# directory, while this reload uses an absolute path — confirm both point to the same file.
specs5k_classication_list = pl.read_parquet("/home/jovyan/share/data/analyses/benjamin/Single_cell_supervised/SPECS_MOA/DeepProfiler/datasets/specs5k_compound_list.parquet")

## Generate feature data

In [None]:
specs3k_feature_path = "/home/jovyan/share/data/analyses/benjamin/Single_cell_project_rapids/SPECS3K/cellprofiler/feature_parquets/"

In [None]:
specs2k_feature_path = "/home/jovyan/share/data/analyses/benjamin/Single_cell_project_rapids/SPECS2K/cellprofiler/feature_parquets"

In [None]:
# Non-feature bookkeeping columns: plate/well/site identifiers, nucleus centre location,
# image/object numbers, acquisition id, and the image file/path names.
# `drop_outliers` reads this module-level list to decide which columns to skip.
meta_features = ['Metadata_Plate',
                    'Metadata_cmpdName',
                    'Metadata_Well',
                    'Metadata_Site',
                    'Location_Center_X_nuclei',
                    'Location_Center_Y_nuclei',
                    'ImageNumber_nuclei',
                    'ObjectNumber_nuclei',
                    'Metadata_AcqID_nuclei',
                    'FileName_CONC_nuclei',
                    'FileName_HOECHST_nuclei',
                    'FileName_ICF_CONC_nuclei',
                    'FileName_ICF_HOECHST_nuclei',
                    'FileName_ICF_MITO_nuclei',
                    'FileName_ICF_PHAandWGA_nuclei',
                    'FileName_ICF_SYTO_nuclei',
                    'FileName_MITO_nuclei',
                    'FileName_PHAandWGA_nuclei',
                    'FileName_SYTO_nuclei',
                    'PathName_CONC_nuclei',
                    'PathName_HOECHST_nuclei',
                    'PathName_ICF_CONC_nuclei',
                    'PathName_ICF_HOECHST_nuclei',
                    'PathName_ICF_MITO_nuclei',
                    'PathName_ICF_PHAandWGA_nuclei',
                    'PathName_ICF_SYTO_nuclei',
                    'PathName_MITO_nuclei',
                    'PathName_PHAandWGA_nuclei',
                    'PathName_SYTO_nuclei']
    

# Bookkeeping columns (object counts, parent/neighbour ids, and the cells/cytoplasm copies of
# the image file/path names) listed for removal.
# NOTE(review): no code visible in this section references `cols_to_drop` — confirm it is still needed.
cols_to_drop = ['Children_cytoplasm_Count_nuclei',
                    'Location_Center_Z_nuclei',
                    'Neighbors_FirstClosestObjectNumber_Adjacent_nuclei',
                    'Neighbors_SecondClosestObjectNumber_Adjacent_nuclei',
                    'Number_Object_Number_nuclei',
                    'Parent_cells_nuclei',
                    'ImageNumber_cells',
                    'Metadata_AcqID_cells',
                    'FileName_CONC_cells',
                    'FileName_HOECHST_cells',
                    'FileName_ICF_CONC_cells',
                    'FileName_ICF_HOECHST_cells',
                    'FileName_ICF_MITO_cells',
                    'FileName_ICF_PHAandWGA_cells',
                    'FileName_ICF_SYTO_cells',
                    'FileName_MITO_cells',
                    'FileName_PHAandWGA_cells',
                    'FileName_SYTO_cells',
                    'PathName_CONC_cells',
                    'PathName_HOECHST_cells',
                    'PathName_ICF_CONC_cells',
                    'PathName_ICF_HOECHST_cells',
                    'PathName_ICF_MITO_cells',
                    'PathName_ICF_PHAandWGA_cells',
                    'PathName_ICF_SYTO_cells',
                    'PathName_MITO_cells',
                    'PathName_PHAandWGA_cells',
                    'PathName_SYTO_cells',
                    'Children_cytoplasm_Count_cells',
                    'Children_nuclei_Count_cells',
                    'Location_Center_Z_cells',
                    'Neighbors_FirstClosestObjectNumber_Adjacent_cells',
                    'Neighbors_SecondClosestObjectNumber_Adjacent_cells',
                    'Number_Object_Number_cells',
                    'Parent_precells_cells',
                    'ImageNumber_cytoplasm',
                    'Metadata_AcqID_cytoplasm',
                    'FileName_CONC_cytoplasm',
                    'FileName_HOECHST_cytoplasm',
                    'FileName_ICF_CONC_cytoplasm',
                    'FileName_ICF_HOECHST_cytoplasm',
                    'FileName_ICF_MITO_cytoplasm',
                    'FileName_ICF_PHAandWGA_cytoplasm',
                    'FileName_ICF_SYTO_cytoplasm',
                    'FileName_MITO_cytoplasm',
                    'FileName_PHAandWGA_cytoplasm',
                    'FileName_SYTO_cytoplasm',
                    'PathName_CONC_cytoplasm',
                    'PathName_HOECHST_cytoplasm',
                    'PathName_ICF_CONC_cytoplasm',
                    'PathName_ICF_HOECHST_cytoplasm',
                    'PathName_ICF_MITO_cytoplasm',
                    'PathName_ICF_PHAandWGA_cytoplasm',
                    'PathName_ICF_SYTO_cytoplasm',
                    'PathName_MITO_cytoplasm',
                    'PathName_PHAandWGA_cytoplasm',
                    'PathName_SYTO_cytoplasm',
                    'Number_Object_Number_cytoplasm',
                    'Parent_cells_cytoplasm',
                    'Parent_nuclei_cytoplasm']

import re
import pycytominer as pm

def is_meta_column(c):
    """
    Tell whether a column name looks like metadata rather than a measured feature.

    Parameters:
    - c: The column name to test.

    Returns:
    - True if any of the regular expressions below is found anywhere in `c`
      (via `re.search`), otherwise False.
    """
    meta_patterns = (
        "Metadata", "^Count", "ImageNumber", "Object", "Parent", "Children",
        "Plate", "Well", "location", "Location", "_[XYZ]_", "_[XYZ]$",
        "Phase", "Scale", "Scaling", "Width", "Height", "Group",
        "FileName", "PathName", "BoundingBox", "URL", "Execution",
        "ModuleError", "LargeBrightArtefact",
    )
    return any(re.search(pattern, c) is not None for pattern in meta_patterns)

def drop_skew(df, columns_to_check, quantile: float=0.8):
    """
    Remove the most skewed columns among `columns_to_check` and report how many were removed.

    The frame is converted to pandas for the computation and converted back to polars
    before returning.

    Parameters:
    - df: The input polars DataFrame.
    - columns_to_check: Column names whose skewness is evaluated; names missing from
      `df` are reported and skipped.
    - quantile: Quantile of the absolute skewness used as the cut-off (default 0.8);
      columns strictly above the cut-off are dropped.

    Returns:
    - A polars DataFrame without the dropped columns.
    """
    frame = df.to_pandas()
    known = [name for name in columns_to_check if name in frame.columns]
    unknown = set(columns_to_check) - set(known)

    if unknown:
        print(f"Warning: The following columns do not exist in DataFrame and will be skipped: {unknown}")

    abs_skew = frame[known].skew().abs()
    cutoff = abs_skew.quantile(quantile)
    too_skewed = abs_skew[abs_skew > cutoff].index.tolist()
    trimmed = frame.drop(columns=too_skewed)

    print(f"Skewness-based method dropped {len(frame.columns) - len(trimmed.columns)} columns.")
    return pl.DataFrame(trimmed)

def drop_low_variance(df, columns_to_check, threshold: float=0.001):
    """
    Remove near-constant columns among `columns_to_check` and report how many were removed.

    The frame is converted to pandas for the computation and converted back to polars
    before returning.

    Parameters:
    - df: The input polars DataFrame.
    - columns_to_check: Column names whose variance is evaluated; names missing from
      `df` are reported and skipped.
    - threshold: Columns whose variance is strictly below this value are dropped
      (default 0.001).

    Returns:
    - A polars DataFrame without the dropped columns.
    """
    frame = df.to_pandas()
    known = [name for name in columns_to_check if name in frame.columns]
    unknown = set(columns_to_check) - set(known)

    if unknown:
        print(f"Warning: The following columns do not exist in DataFrame and will be skipped: {unknown}")

    variances = frame[known].var().abs()
    near_constant = variances[variances < threshold].index.tolist()
    trimmed = frame.drop(columns=near_constant)

    print(f"Low variance-based method dropped {len(frame.columns) - len(trimmed.columns)} columns.")
    return pl.DataFrame(trimmed)

def drop_low_variance_pl(df, columns_to_check, threshold: float=0.001):
    """
    Drop columns based on variance threshold from a list of specified columns in a Polars DataFrame.

    All variances are computed in a single polars query (no pandas round-trip).

    Parameters:
    - df: The input Polars DataFrame.
    - columns_to_check: A list of column names to check for low variance. Names that are
      not present in `df` are ignored; duplicates are checked once.
    - threshold: The variance threshold below which columns are dropped.

    Returns:
    - A DataFrame with specified low variance columns dropped. Columns whose variance is
      undefined (null, e.g. fewer than two non-null values) are kept.
    """
    # Keep only existing columns, de-duplicated while preserving the caller's order
    # (duplicate output names are not allowed inside one select).
    valid_columns = [col for col in dict.fromkeys(columns_to_check) if col in df.columns]

    columns_to_drop = []
    if valid_columns:
        # One pass over the data: a single-row frame holding one variance per column.
        variances = df.select([pl.col(col).var() for col in valid_columns]).row(0)
        columns_to_drop = [
            col
            for col, variance in zip(valid_columns, variances)
            if variance is not None and variance < threshold
        ]

    # Print information about dropped columns
    if columns_to_drop:
        df = df.drop(columns_to_drop)
        print(f"Dropped {len(columns_to_drop)} columns for low variance: {columns_to_drop}")
    else:
        print("No columns dropped due to low variance.")

    return df


def clip_to_percentiles(df, cols, lower_percentile=1, upper_percentile=99):
    """
    Winsorise the given columns: values below the lower percentile are raised to it and
    values above the upper percentile are lowered to it. All other columns are returned
    untouched.

    Parameters:
    - df: The input polars DataFrame.
    - cols: Column names to clip; names not present in `df` are reported and skipped.
    - lower_percentile: Lower percentile bound (default 1).
    - upper_percentile: Upper percentile bound (default 99).

    Returns:
    - A DataFrame with the same columns as `df`, with the listed columns clipped.
    """
    lower_q = lower_percentile / 100.0
    upper_q = upper_percentile / 100.0

    for name in tqdm.tqdm(cols):
        if name in df.columns:
            column = pl.col(name)

            # Per-column percentile bounds, pulled out as scalars
            floor = df.select(column.quantile(lower_q)).to_numpy()[0,0]
            ceiling = df.select(column.quantile(upper_q)).to_numpy()[0,0]

            # Replace the column with its clipped version under the same name
            df = df.with_columns(
                pl.when(column < floor).then(floor)
                .when(column > ceiling).then(ceiling)
                .otherwise(column)
                .alias(name)
            )
        else:
            print(f"Column {name} does not exist in DataFrame.")
    return df
        

def drop_outliers(df, percentile=99, exclude_cols=None):
    """
    Keep only rows whose value is at or below the given percentile in every checked column.

    Parameters:
    - df: The input polars DataFrame.
    - percentile: Upper percentile used as the per-column cut-off (default 99).
    - exclude_cols: Column names that must not be checked. When omitted, the notebook-level
      `meta_features` list (and `extra_features`, if such a global exists) is used, which
      mirrors the previous behaviour without raising NameError when `extra_features`
      is not defined.

    Returns:
    - The filtered DataFrame. If no column is checked, `df` is returned unchanged.
    """
    if exclude_cols is None:
        namespace = globals()
        exclude_cols = list(namespace.get("meta_features", [])) + list(namespace.get("extra_features", []))
    skipped = set(exclude_cols)

    keep_mask = None
    for col in df.columns:
        if col in skipped:
            continue
        cutoff = df[col].quantile(percentile / 100.0)
        if cutoff is None:
            # No quantile available (e.g. all-null column): nothing to compare against.
            continue
        within_limit = df[col] <= cutoff
        # Row must satisfy all conditions to be retained
        keep_mask = within_limit if keep_mask is None else keep_mask & within_limit

    if keep_mask is None:
        return df
    return df.filter(keep_mask)

def feature_selection_cellprofiler(normalized_profiles, meta_dat, operation = "clip"):
    """
    Select CellProfiler feature columns, attach compound annotations and tame outliers.

    Steps: drop rows containing any null, left-join the compound annotations, remove
    metadata-like and block-listed feature columns, drop low-variance features, then either
    clip features to their 1st/99th percentiles or drop rows above the 99th percentile.

    Parameters:
    - normalized_profiles: polars DataFrame of per-cell profiles.
    - meta_dat: DataFrame whose column names identify the annotation columns to keep.
    - operation: 'clip' (default) or 'drop'.

    Returns:
    - A polars DataFrame with the selected features and annotation columns.

    Raises:
    - ValueError: if `operation` is neither 'clip' nor 'drop'.
    """
    # Fail fast, before any expensive filtering/joining is done.
    if operation not in ("clip", "drop"):
        raise ValueError("Unsupported operation. Choose 'clip' or 'drop'.")

    meta_df_features = set(meta_dat.columns)
    meta_features = {col for col in normalized_profiles.columns if is_meta_column(col)}

    normalized_profiles = normalized_profiles.filter(~pl.any_horizontal(pl.all().is_null()))

    # NOTE(review): the join uses the notebook-level `specs5k_classication_list`, not the
    # `meta_dat` argument — confirm both are meant to be the same table.
    join_keys = ["Metadata_Plate", "Metadata_Well", "Metadata_cmpdName", "Metadata_Site"]
    normalized_profiles_merge = normalized_profiles.drop(["Metadata_cmpdConc", "moa", "compound_name"]).join(
        specs5k_classication_list.drop("Metadata_cmpdConc"),
        left_on=join_keys,
        right_on=join_keys,
        how="left",
    )

    # Nuclei-level features that are excluded from the feature set.
    blocklist_tokens = ("Correlation_Manders", "Correlation_RWC", "Granularity_14", "Granularity_15", "Granularity_16")
    blocklist_features = {
        col
        for col in normalized_profiles.columns
        if "_nuclei" in col and any(token in col for token in blocklist_tokens)
    }

    # The unnamed '' column (if any) is never kept. Column order follows the merged frame,
    # so the output layout is the same on every run.
    merged_columns = [col for col in normalized_profiles_merge.columns if col != ""]
    features = [
        feat
        for feat in merged_columns
        if feat not in meta_features and feat not in blocklist_features and feat not in meta_df_features
    ]
    extra_features = [feat for feat in merged_columns if feat in meta_df_features]

    final_feat_df = normalized_profiles_merge.select(features + extra_features)
    #final_feat_df = drop_skew(final_feat_df, features)
    final_feat_df = drop_low_variance(final_feat_df, features)

    print(final_feat_df.shape)

    # Only features that survived the variance filter can still be processed.
    remaining_features = [feat for feat in features if feat in final_feat_df.columns]
    if operation == 'clip':
        return clip_to_percentiles(final_feat_df, remaining_features)
    return drop_outliers(final_feat_df, exclude_cols=extra_features)

In [None]:
import polars as pl
import os
import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

def merge_locations(df, location_folder):
    """
    Append the columns of the per-site nuclei CSV files to the matching rows of `df`.

    For every unique (Metadata_Plate, Metadata_Well, Metadata_Site) in `df`, the file
    `{location_folder}/{plate}/{well}-{site}-Nuclei.csv` is read (if it exists) and
    concatenated column-wise with the rows of that site. Rows are matched by position,
    so the CSV is assumed to list the cells in the same order as `df`.

    Parameters:
    - df: polars DataFrame with Metadata_Plate, Metadata_Well and Metadata_Site columns.
    - location_folder: Root folder containing one sub-folder per plate.

    Returns:
    - A polars DataFrame with all merged sites stacked vertically. Sites without a CSV
      file, or whose CSV row count differs from the number of rows in `df`, are left out
      (the latter are reported). An empty DataFrame is returned when nothing was merged.
    """
    key_cols = ["Metadata_Plate", "Metadata_Well", "Metadata_Site"]
    combinations = df.select(key_cols).unique(maintain_order=True)

    merged_sites = []
    # Iterate through unique combinations of Plate, Well, and Site
    for plate, well, site in tqdm.tqdm(combinations.iter_rows(), total=len(combinations)):
        file_path = f"{location_folder}/{plate}/{well}-{site}-Nuclei.csv"
        if not os.path.exists(file_path):
            continue

        csv_df = pl.read_csv(file_path)
        site_rows = df.filter((pl.col("Metadata_Plate") == plate) &
                              (pl.col("Metadata_Well") == well) &
                              (pl.col("Metadata_Site") == site))

        # A positional merge is only meaningful when both sides have the same row count.
        if len(csv_df) != len(site_rows):
            print(f"{(plate, well, site)} doesn't match")
            continue

        merged_sites.append(pl.concat([site_rows, csv_df], how="horizontal"))

    if not merged_sites:
        return pl.DataFrame()
    # Concatenate once at the end instead of growing a frame inside the loop.
    return pl.concat(merged_sites, how="vertical")


def read_and_merge_single_file(df, plate, well, site, location_folder):
    """
    Merge the nuclei CSV of one (plate, well, site) with the matching rows of `df`.

    Parameters:
    - df: polars DataFrame with Metadata_Plate, Metadata_Well and Metadata_Site columns.
    - plate, well, site: Identify both the rows of `df` and the CSV file
      `{location_folder}/{plate}/{well}-{site}-Nuclei.csv`.
    - location_folder: Root folder containing one sub-folder per plate.

    Returns:
    - The rows of `df` for that site with the CSV columns appended column-wise, or None
      when the file does not exist or its row count differs from the number of rows.
    """
    file_path = f"{location_folder}/{plate}/{well}-{site}-Nuclei.csv"
    if not os.path.exists(file_path):
        return None

    locations = pl.read_csv(file_path)
    belongs_to_site = (
        (pl.col("Metadata_Plate") == plate)
        & (pl.col("Metadata_Well") == well)
        & (pl.col("Metadata_Site") == site)
    )
    site_rows = df.filter(belongs_to_site)

    if len(locations) != len(site_rows):
        return None
    return pl.concat([site_rows, locations], how="horizontal")

def merge_locations_parallel(df, location_folder, max_workers=10):
    """
    Thread-pooled variant of the per-site location merge.

    One `read_and_merge_single_file` task is submitted per unique
    (Metadata_Plate, Metadata_Well, Metadata_Site) in `df`.

    Parameters:
    - df: polars DataFrame with Metadata_Plate, Metadata_Well and Metadata_Site columns.
    - location_folder: Root folder containing one sub-folder per plate.
    - max_workers: Number of worker threads (default 10).

    Returns:
    - A polars DataFrame with all successfully merged sites stacked vertically, in the
      order the sites first appear in `df` (independent of thread completion order).
      An empty DataFrame is returned when no site could be merged.
    """
    key_cols = ["Metadata_Plate", "Metadata_Well", "Metadata_Site"]
    combinations = df.select(key_cols).unique(maintain_order=True)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Create and submit tasks
        futures = [
            executor.submit(read_and_merge_single_file, df, plate, well, site, location_folder)
            for plate, well, site in combinations.iter_rows()
        ]

        # Drive the progress bar as tasks finish; results are gathered afterwards
        # in submission order so the output is reproducible.
        for _ in tqdm.tqdm(as_completed(futures), total=len(futures)):
            pass

    dfs_to_concat = [result for result in (future.result() for future in futures) if result is not None]

    if not dfs_to_concat:
        return pl.DataFrame()
    # Concatenate all DataFrames at once at the end
    return pl.concat(dfs_to_concat, how="vertical")

In [None]:
import os
import gc
import tqdm


def generate_supervised_data(cmpd_df, feature_path, exclude_plates=None):
    """
    Load the per-plate single-cell feature files and keep only the listed compounds.

    For every plate in `cmpd_df` (minus the excluded ones) the file
    `{feature_path}/sc_profiles_normalized_cellprofiler_{plate}.parquet` is read, filtered
    to the compounds `cmpd_df` lists for that plate, and its Float64 columns are downcast
    to Float32.

    Parameters:
    - cmpd_df: polars DataFrame with Metadata_Plate and Metadata_cmpdName columns.
    - feature_path: Folder holding the per-plate parquet files.
    - exclude_plates: Plates to skip. Defaults to the seven plates that were previously
      hard-coded in this function (`dmso_plates`).

    Returns:
    - A polars DataFrame with the filtered rows of all plates stacked vertically.

    Raises:
    - ValueError: if no feature file was found for any plate.
    """
    if exclude_plates is None:
        exclude_plates = ["P103620", "P103621", "P103619", "P101387", "P101386", "P101385", "P101384"]
    excluded = set(exclude_plates)

    plates_fix = [s for s in cmpd_df["Metadata_Plate"].unique().to_list() if s not in excluded]
    plates_fix.sort()

    sc_features = []
    missing_plates = []
    for p in tqdm.tqdm(plates_fix):
        file_path = f"{feature_path}/sc_profiles_normalized_cellprofiler_{p}.parquet"
        if not os.path.exists(file_path):
            missing_plates.append(p)
            continue

        cmps = cmpd_df.filter(pl.col("Metadata_Plate") == p)["Metadata_cmpdName"].unique().to_list()
        features_filt = (
            pl.read_parquet(file_path)
            .filter(pl.col("Metadata_cmpdName").is_in(cmps))
            # Downcast all Float64 columns at once instead of one with_columns call per column.
            .with_columns(pl.col(pl.Float64).cast(pl.Float32))
        )
        sc_features.append(features_filt)
        gc.collect()

    if missing_plates:
        print(f"No feature file found for {len(missing_plates)} plate(s): {missing_plates}")
    if not sc_features:
        raise ValueError(f"No feature files found in {feature_path} for the requested plates.")

    sc_df = pl.concat(sc_features)
    return sc_df

In [None]:
import os
import gc
import tqdm
def check_data_size(cmpd_df, feature_path):
    """
    Print the estimated in-memory size and shape of each plate's filtered feature table.

    For every plate in `cmpd_df` whose file
    `{feature_path}/sc_profiles_normalized_cellprofiler_{plate}.parquet` exists, the file
    is read, restricted to the compounds `cmpd_df` lists for that plate, downcast from
    Float64 to Float32, and its size (MB) and dimensions are printed. Nothing is returned.

    Parameters:
    - cmpd_df: polars DataFrame with Metadata_Plate and Metadata_cmpdName columns.
    - feature_path: Folder holding the per-plate parquet files.
    """
    plates = cmpd_df["Metadata_Plate"].unique().to_list()
    plates.sort()
    for p in tqdm.tqdm(plates):
        file_path = f"{feature_path}/sc_profiles_normalized_cellprofiler_{p}.parquet"
        if not os.path.exists(file_path):
            continue

        cmps = cmpd_df.filter(pl.col("Metadata_Plate") == p)["Metadata_cmpdName"].unique().to_list()
        features_filt = (
            pl.read_parquet(file_path)
            .filter(pl.col("Metadata_cmpdName").is_in(cmps))
            # Downcast all Float64 columns at once instead of one with_columns call per column.
            .with_columns(pl.col(pl.Float64).cast(pl.Float32))
        )
        size = features_filt.estimated_size("mb")
        print(f"Plate {p} dataFrame size: {size} MB with dimensions {features_filt.shape}")
        gc.collect()

In [None]:
specs3k_sc_features = generate_supervised_data(specs5k_classication_list, specs3k_feature_path).unique()

In [None]:
specs2k_sc_features = generate_supervised_data(specs5k_classication_list, specs2k_feature_path).unique()

In [None]:
gc.collect()

In [None]:
# Keep only cells whose nucleus centre lies strictly between 250 and 2250 on both axes.
# NOTE(review): presumably pixel coordinates used to discard cells near the image border — confirm.
specs2k_sc_locations =  specs2k_sc_features.filter((pl.col("Location_Center_X_nuclei") > 250) &
                                                  (pl.col("Location_Center_X_nuclei") < 2250) &
                                                  (pl.col("Location_Center_Y_nuclei") > 250) &
                                                  (pl.col("Location_Center_Y_nuclei") < 2250))

In [None]:
# Keep only cells whose nucleus centre lies strictly between 250 and 2250 on both axes.
# NOTE(review): presumably pixel coordinates used to discard cells near the image border — confirm.
specs3k_sc_locations =  specs3k_sc_features.filter((pl.col("Location_Center_X_nuclei") > 250) &
                                                  (pl.col("Location_Center_X_nuclei") < 2250) &
                                                  (pl.col("Location_Center_Y_nuclei") > 250) &
                                                  (pl.col("Location_Center_Y_nuclei") < 2250))

In [None]:
specs2k_sc_locations.write_parquet("sc_profiles_classification_specs2k_CellProfiler_standardized.parquet")

In [None]:
specs3k_sc_locations.write_parquet("sc_profiles_classification_specs3k_CellProfiler_standardized.parquet")

In [None]:
# NOTE(review): this replaces the in-memory specs2k table with a file whose name differs from
# the one written above (under "datasets/", without the "_standardized" suffix) — confirm
# that this is the intended file.
specs2k_sc_locations = pl.read_parquet("datasets/sc_profiles_classification_specs2k_CellProfiler.parquet")

## Load and merge features

In [None]:
specs3k_sc_features_total = pl.read_parquet("datasets/standardized/specs3k_sc_featfix_CP.parquet")

In [None]:
specs2k_sc_features_total = pl.read_parquet("datasets/standardized/specs2k_sc_featfix_CP.parquet")

In [None]:
gc.collect()

In [None]:
columns_df1 = set(specs3k_sc_features_total.columns)
columns_df2 = set(specs2k_sc_features_total.columns)

# Find common columns. They are kept as a list in the specs3k column order (rather than a
# set) so that selecting with them yields the same column layout on every kernel run.
common_columns = [column for column in specs3k_sc_features_total.columns if column in columns_df2]


In [None]:
specs3k_sc_features_total = specs3k_sc_features_total.select(common_columns)
specs2k_sc_features_total = specs2k_sc_features_total.select(common_columns)

In [None]:
# Downcast every Float64 column to Float32 in a single pass (one query instead of one
# with_columns call per column).
specs3k_sc_features_total = specs3k_sc_features_total.with_columns(pl.col(pl.Float64).cast(pl.Float32))

In [None]:
# Downcast every Float64 column to Float32 in a single pass (one query instead of one
# with_columns call per column).
specs2k_sc_features_total = specs2k_sc_features_total.with_columns(pl.col(pl.Float64).cast(pl.Float32))

In [None]:
gc.collect()

In [None]:
specs5k_sc_features_total = pl.concat([specs3k_sc_features_total.drop("_right"), specs2k_sc_features_total.drop("_right")]).unique()

In [None]:
specs5k_sc_features_total = specs5k_sc_features_total.with_columns(
    pl.col('moa_broad').fill_null('DMSO')
)

In [None]:
specs5k_sc_features_total = specs5k_sc_features_total.rename({"Location_Center_X_nuclei": "Nuclei_Location_Center_X", "Location_Center_Y_nuclei": "Nuclei_Location_Center_Y"})

In [None]:
specs5k_sc_features_total.write_parquet("datasets/standardized/sc_profiles_classification_specs5k_total.parquet")

In [None]:
# Rows per broad MoA class (`group_by`, consistent with the other cells of this notebook).
specs5k_sc_features_total.group_by("moa_broad").count()

## Show summary stats

In [None]:
def show_group_dist(feature_df, group_col):
    """
    Draw a bar chart with the number of rows per value of `group_col`.

    Parameters:
    - feature_df: polars DataFrame to summarise.
    - group_col: Name of the column to group by.

    The bars are sorted by group value so the figure looks the same on every run.
    """
    grouped_df = (
        feature_df
        .group_by(group_col)
        .agg(pl.count().alias('count'))
        .sort(group_col)
    )

    # Now plot the data using Matplotlib
    fig, ax = plt.subplots()
    ax.bar(grouped_df[group_col].to_list(), grouped_df['count'].to_list())

    ax.set_xlabel('Group')
    ax.set_ylabel('Count')
    ax.set_title('Number of Data Points per Group')
    ax.tick_params(axis='x', labelrotation=45)  # Rotate labels if they overlap
    plt.show()

In [None]:
show_group_dist(specs5k_sc_features_total, "moa_broad")

## Label encoding

In [None]:
from sklearn.preprocessing import LabelEncoder
def encode_labels(df):
    """
    Add an integer `label` column that encodes the values of `moa_broad`.

    Parameters:
    - df: polars DataFrame with a `moa_broad` column.

    Returns:
    - `df` with an additional `label` column produced by a scikit-learn LabelEncoder
      fitted on `moa_broad`.
    """
    moa_values = df["moa_broad"]
    encoder = LabelEncoder().fit(moa_values)
    encoded = list(encoder.transform(moa_values))
    return df.with_columns(pl.Series(name="label", values=encoded))

In [None]:
specs5k_sc_features_total = encode_labels(specs5k_sc_features_total)

In [None]:
specs5k_sc_features_total.group_by("label").count()

## Undersampling

In [None]:
from imblearn.under_sampling import NearMiss
import numpy as np

In [None]:
def stratified_sampling_pl(df, class_col, stratify_cols, fraction, seed=None):
    """
    Downsample every class towards `fraction` times the size of the smallest class.

    Parameters:
    - df: Polars DataFrame, the dataset to sample from.
    - class_col: str, the column name for class labels.
    - stratify_cols: list of str, columns for further stratification within each class.
    - fraction: float, target size of each class relative to the smallest class.
    - seed: optional int passed to the sampler for reproducible results (default None).

    Returns:
    - Polars DataFrame after downsampling.

    Notes:
    - Classes are sampled per `stratify_cols` group when the sampling rate is above 0.1;
      at or below 0.1 the class is sampled as a whole.
    - The sampling rate is `target_size / class_size`. `fraction` is applied once, when
      computing `target_size`; it was previously multiplied in a second time.
    """
    # Class sizes, computed without relying on the version-dependent column name
    # produced by Series.value_counts().
    class_sizes = df.group_by(class_col).agg(pl.count().alias("n")).sort(class_col)
    if class_sizes.height == 0:
        return df

    target_size = int(class_sizes["n"].min() * fraction)

    downsampled_frames = []
    for class_label, current_size in class_sizes.iter_rows():
        if class_label is None:
            class_df = df.filter(pl.col(class_col).is_null())
        else:
            class_df = df.filter(pl.col(class_col) == class_label)

        downsample_fraction = min(1.0, target_size / current_size)

        if downsample_fraction >= 1.0:
            sampled_df = class_df
        elif downsample_fraction > 0.1:
            # Sample inside each stratum so that all strata stay represented.
            sampled_df = class_df.group_by(stratify_cols).map_groups(
                lambda group: group.sample(fraction=downsample_fraction, seed=seed)
            )
        else:
            sampled_df = class_df.sample(fraction=downsample_fraction, seed=seed)

        downsampled_frames.append(sampled_df)

    # Concatenate the downsampled frames into a single DataFrame
    return pl.concat(downsampled_frames)

def sample_n_rows_per_group(df, group_cols, fraction, seed=None, n_samples=0):
    """
    Sample a fraction of the rows of every group, leaving small groups untouched.

    Parameters:
    - df: polars DataFrame to sample from.
    - group_cols: Column name(s) defining the groups.
    - fraction: Fraction of rows to keep in each sampled group (without replacement).
    - seed: optional int for reproducible sampling (default None).
    - n_samples: groups with at most this many rows are returned unchanged. The default
      of 0 samples every group. This name was previously read without being defined,
      which raised NameError.

    Returns:
    - A polars DataFrame with the sampled rows of all groups.
    """
    def sample_group(group_df):
        if len(group_df) <= n_samples:
            return group_df
        return group_df.sample(fraction=fraction, with_replacement=False, seed=seed)

    # Group the DataFrame and apply the sampling function to each group
    return df.group_by(group_cols).map_groups(sample_group)

In [None]:
def undersampling(df, strategy, control_label=2):
    """
    Reduce class imbalance in `df` with one of three strategies.

    Parameters:
    - df: polars DataFrame with a `label` column.
    - strategy:
        * "nearmiss" (the historical spelling "nearmmiss" is still accepted): NearMiss-1
          on the columns whose name contains "Feature"; the selected feature rows are
          joined back to `df` to recover the remaining columns.
        * "random": `stratified_sampling_pl` on `label`, stratified by plate, well, site
          and compound.
        * "control_group_sampling": downsample only the control class to the size of the
          largest other class.
    - control_label: value of `label` identifying the control class (default 2, the value
      that was previously hard-coded).

    Returns:
    - The resampled polars DataFrame.

    Raises:
    - ValueError: if `strategy` is not one of the supported names.
    """
    group_keys = ["Metadata_Plate", "Metadata_Well", "Metadata_Site", "Metadata_cmpdName"]

    if strategy in ("nearmiss", "nearmmiss"):
        feature_cols = [col for col in df.columns if "Feature" in col]
        # The pandas copy is only needed for NearMiss, so it is created in this branch only.
        df_pd = df.to_pandas()
        nm = NearMiss(version=1, n_jobs= -1)

        # Apply NearMiss
        X_res, y_res = nm.fit_resample(df_pd[feature_cols], df_pd['label'])

        df_resampled = pl.DataFrame(X_res).with_columns(pl.Series('label', y_res))

        # NOTE(review): rows are recovered by joining on the feature values, so rows with
        # identical feature vectors would be duplicated — confirm that cannot happen here.
        resampled_df = df_resampled.join(df, on = feature_cols, how='left')
        if "" in resampled_df.columns:
            resampled_df = resampled_df.drop("")
        return resampled_df

    if strategy == "random":
        return stratified_sampling_pl(df, "label", group_keys, 1)

    if strategy == "control_group_sampling":
        # Filter the DataFrame for the control group and other groups
        control_group = df.filter(pl.col('label') == control_label)
        other_groups = df.filter(pl.col('label') != control_label)

        # Nothing to balance against when either side is empty.
        if control_group.height == 0 or other_groups.height == 0:
            return pl.concat([other_groups, control_group])

        value_counts = other_groups.group_by('label').agg(pl.count().alias('count'))
        most_abundant_class_size = value_counts['count'].max()

        sample_rate = most_abundant_class_size/(control_group.shape[0])
        print(sample_rate)

        if sample_rate >= 1.0:
            # The control group is already no larger than the biggest other class.
            sampled = control_group
        else:
            # For very low rates the site level is left out of the grouping.
            keys = group_keys if sample_rate > 0.1 else ["Metadata_Plate", "Metadata_Well", "Metadata_cmpdName"]
            sampled = control_group.group_by(keys).map_groups(
                lambda group: group.sample(fraction=sample_rate, seed = 42)
            )

        # Concatenate the sampled control group back with the other data
        return pl.concat([other_groups, sampled])

    raise ValueError(
        f"Unsupported strategy {strategy!r}. Choose 'nearmiss', 'random' or 'control_group_sampling'."
    )


In [None]:
import polars as pl
from polars import col, lit
def undersampling_lazy(df, strategy):
    """
    Control-group undersampling for a polars LazyFrame (a DataFrame is accepted as well).

    Only the "DMSO" rows of `moa_broad` are downsampled, to the size of the largest other
    class. The class filters run through the lazy engine; the per-group sampling is done
    eagerly on the collected control rows because LazyFrame has no sampling method.

    Parameters:
    - df: polars LazyFrame or DataFrame with `moa_broad`, `Metadata_Plate`,
      `Metadata_Well`, `Metadata_Site` and `Metadata_cmpdName` columns.
    - strategy: only "control_group_sampling" is implemented.

    Returns:
    - A collected polars DataFrame: all non-control rows plus the sampled control rows.

    Raises:
    - NotImplementedError: for "nearmiss" and "random", which have no lazy implementation.
    - ValueError: for any other unknown strategy.
    """
    if strategy in ("nearmiss", "random"):
        raise NotImplementedError(
            f"Strategy {strategy!r} has no lazy implementation; use undersampling() instead."
        )
    if strategy != "control_group_sampling":
        raise ValueError(
            f"Unsupported strategy {strategy!r}. Choose 'control_group_sampling'."
        )

    control_label = "DMSO"
    lazy_df = df.lazy()
    control_group = lazy_df.filter(pl.col('moa_broad') == pl.lit(control_label)).collect()
    other_groups = lazy_df.filter(pl.col('moa_broad') != pl.lit(control_label)).collect()

    # Nothing to balance against when either side is empty.
    if control_group.height == 0 or other_groups.height == 0:
        return pl.concat([other_groups, control_group])

    value_counts = other_groups.group_by('moa_broad').agg(pl.count().alias('count'))
    most_abundant_class_size = value_counts['count'].max()

    sample_rate = most_abundant_class_size / control_group.height

    if sample_rate >= 1.0:
        # The control group is already no larger than the biggest other class.
        sampled = control_group
    else:
        # For rates at or below 0.1 the site level is left out of the grouping.
        if sample_rate > 0.1:
            keys = ["Metadata_Plate", "Metadata_Well", "Metadata_Site", "Metadata_cmpdName"]
        else:
            keys = ["Metadata_Plate", "Metadata_Well", "Metadata_cmpdName"]
        sampled = control_group.group_by(keys).map_groups(
            lambda group: group.sample(fraction=sample_rate, seed=42)
        )

    return pl.concat([other_groups, sampled])

In [None]:
# NOTE(review): the merged table is written to "datasets/standardized/..." earlier in this
# notebook, while this reload reads from "datasets/..." — confirm this is the intended file and
# that it already contains the `label` column that undersampling(..., "control_group_sampling")
# filters on.
specs5k_sc_features_total = pl.read_parquet("datasets/sc_profiles_classification_specs5k_total.parquet")

In [None]:
# Force a garbage-collection pass; as the last expression of the cell, the
# number of unreachable objects found is displayed.
import gc
gc.collect()

In [None]:
# Undersample the full single-cell table with the eager `undersampling` helper,
# using its "control_group_sampling" strategy.
resampled_specs5k_big = undersampling(specs5k_sc_features_total, "control_group_sampling")

In [None]:
# Inspect the `moa_broad` grouping after resampling. show_group_dist is defined
# outside this section; presumably it reports rows per group — TODO confirm.
show_group_dist(resampled_specs5k_big, "moa_broad")

In [None]:
def prepare_class_data(df, plate2k, plate3k):
    """Drop the unnamed column (if any) and label each row with its project.

    Parameters
    ----------
    df : pl.DataFrame
        Table with a ``Metadata_Plate`` column. A column named ``''`` is
        removed when present (presumably a leftover index column — TODO confirm).
    plate2k : list[str]
        Plate identifiers that are labelled ``"specs2k"``.
    plate3k : list[str]
        Plate identifiers that are labelled ``"specs3k"``.

    Returns
    -------
    pl.DataFrame
        ``df`` without the ``''`` column and with a string column ``project``
        holding ``"specs2k"``, ``"specs3k"`` or ``"other"``.
    """
    # Only drop the unnamed column when it exists, so the function can be
    # re-applied to an already prepared frame (e.g. when a cell is re-run).
    if '' in df.columns:
        df = df.drop('')

    plate = pl.col('Metadata_Plate')
    return df.with_columns(
        pl.when(plate.is_in(plate2k)).then(pl.lit("specs2k"))
        .when(plate.is_in(plate3k)).then(pl.lit("specs3k"))
        .otherwise(pl.lit("other"))
        .alias('project')
    )

In [None]:
resampled_specs5k_big = prepare_class_data(resampled_specs5k_big, specs2k_plates, specs3k_plates)

In [None]:
specs2k_plates = ['P103617',
 'P103602',
 'P103595',
 'P103597',
 'P103613',
 'P103591',
 'P103615',
 'P103607',
 'P103619',
 'P103606',
 'P103616',
 'P103601',
 'P103603',
 'P103620',
 'P103614',
 'P103621',
 'P103593',
 'P103592',
 'P103612',
 'P103608',
 'P103600',
 'P103609',
 'P103618',
 'P103589',
 'P103605',
 'P103590',
 'P103599',
 'P103610',
 'P103604',
 'P103611',
 'P103598',
 'P103596',
 'P103594']
specs3k_plates = ['P101382',
 'P101339',
 'P101338',
 'P101337',
 'P101354',
 'P101350',
 'P101360',
 'P101375',
 'P101363',
 'P101335',
 'P101373',
 'P101372',
 'P101352',
 'P101334',
 'P101369',
 'P101336',
 'P101345',
 'P101377',
 'P101346',
 'P101366',
 'P101359',
 'P101361',
 'P101364',
 'P101365',
 'P101362',
 'P101374',
 'P101380',
 'P101367',
 'P101358',
 'P101342',
 'P101371',
 'P101341',
 'P101368',
 'P101348',
 'P101370',
 'P101379',
 'P101386',
 'P101353',
 'P101381',
 'P101351',
 'P101357',
 'P101384',
 'P101347',
 'P101343',
 'P101387',
 'P101385',
 'P101355',
 'P101340',
 'P101378',
 'P101344',
 'P101349',
 'P101376',
 'P101356']

In [None]:
# Persist the resampled table to parquet in the current working directory.
resampled_specs5k_big.write_parquet("specs5k_undersampled_big_moa.parquet")

## Prepare splits

In [None]:
# Load the undersampled table used for the splits.
# NOTE(review): this path differs from the parquet file written in the previous
# section; the step that produces it is not visible here — confirm provenance.
resampled_specs5k = pl.read_parquet("datasets/standardized/specs5k_undersampled_moa_CP.parquet")

In [None]:
# Names of all columns containing "RadialDistribution_Frac".
radial_feats = [feat for feat in resampled_specs5k.columns if "RadialDistribution_Frac" in feat]

In [None]:
# Remove the RadialDistribution_Frac columns from the table.
resampled_specs5k = resampled_specs5k.drop(radial_feats)

In [None]:
# Compound identifiers to exclude from the table. The name suggests these are
# "non-significant" compounds — TODO document the criterion used to pick them.
non_sign_cmp = ["CBK041160", "CBK041211" ,"CBK277970", "CBK289918H", "CBK290118", "CBK308723"]

In [None]:
# Keep only rows whose Metadata_cmpdName is not listed in non_sign_cmp.
resampled_specs5k_sign = resampled_specs5k.filter(~(pl.col("Metadata_cmpdName").is_in(non_sign_cmp)))

In [None]:
# Inspect the `moa_broad` grouping after removing the excluded compounds
# (show_group_dist is defined outside this section).
show_group_dist(resampled_specs5k_sign, "moa_broad")

In [None]:
# Drop rows with a missing AreaShape_FormFactor_nuclei value by round-tripping
# through pandas (polars -> pandas -> dropna -> polars).
# NOTE(review): the round trip copies the whole table and may alter dtypes /
# null representation; a native polars filter (is_not_null & is_not_nan) would
# avoid that — confirm downstream cells do not rely on the pandas conversion.
resampled_specs5k_sign = pl.DataFrame(resampled_specs5k_sign.to_pandas().dropna(subset = "AreaShape_FormFactor_nuclei"))

In [None]:
# Pair every floating-point column with its number of NaN entries
# (is_nan is only evaluated for Float32/Float64 columns).
na_counts = [
    (name, resampled_specs5k_sign[name].is_nan().sum())
    for name in resampled_specs5k_sign.columns
    if resampled_specs5k_sign[name].dtype in [pl.Float32, pl.Float64]
]

# Tabulate the (column, NaN count) pairs and list the highest counts first;
# "column_1" is the auto-generated name of the count column.
na_summary_df = pl.DataFrame(na_counts).sort("column_1", descending=True)

print(na_summary_df)

In [None]:
# Persist the filtered table; later sections reload it from this path.
resampled_specs5k_sign.write_parquet("datasets/standardized/specs5k_undersampled_significant_CP.parquet")

## Aggregated

In [None]:
# Columns to aggregate: every column whose name contains "Feature".
# NOTE(review): if the feature columns of this table do not carry "Feature" in
# their names, this list is empty and only the group keys survive — confirm.
features_fixed = [feat for feat in resampled_specs5k_sign.columns if "Feature" in feat]

# Per-well median profile. `group_by` replaces the `groupby` alias, which was
# removed in polars 1.0 and is already spelled `group_by` elsewhere in this notebook.
resampled_specs5k_aggregated = (
    resampled_specs5k_sign
    .group_by(["moa_broad", "project", 'Metadata_Plate', 'Metadata_Well', 'Metadata_cmpdName'])
    .agg([pl.col(feature).median() for feature in features_fixed])
)

In [None]:
# Display the aggregated frame.
resampled_specs5k_aggregated

In [None]:
# Persist the aggregated profiles to parquet in the current working directory.
resampled_specs5k_aggregated.write_parquet("specs5k_undersampled_moa_aggregated.parquet")

## Split for training csv

In [None]:
# Reload the filtered table from disk so this section can start from the saved file.
resampled_specs5k_sign = pl.read_parquet("datasets/standardized/specs5k_undersampled_significant_CP.parquet")

In [None]:
import polars as pl
import tqdm
def stratified_split(df, group_columns, n_splits=3):
    """Split ``df`` into ``n_splits`` frames, dividing every group evenly.

    Rows are grouped by the unique combination of ``group_columns``. Each
    group of ``size`` rows is cut, in its existing row order, into
    ``n_splits`` consecutive slices of ``size // n_splits`` rows; the first
    ``size % n_splits`` slices receive one extra row. Slice ``i`` of every
    group ends up in split ``i``.

    Parameters
    ----------
    df : pl.DataFrame
        Table to split.
    group_columns : list[str]
        Columns whose value combination defines a group.
    n_splits : int, default 3
        Number of output frames; must be at least 1.

    Returns
    -------
    list[pl.DataFrame]
        ``n_splits`` frames with the same columns as ``df``. A split that
        receives no rows is an empty frame with ``df``'s schema.

    Raises
    ------
    ValueError
        If ``n_splits`` is smaller than 1.
    """
    if n_splits < 1:
        raise ValueError("n_splits must be a positive integer")

    # Pieces are gathered per split and concatenated once at the end, instead
    # of growing each split frame inside the loop.
    split_parts = [[] for _ in range(n_splits)]

    # partition_by walks the table once and yields one frame per group, which
    # replaces filtering the full table for every single group.
    # maintain_order=True makes the group order (and so the output) reproducible.
    partitions = df.partition_by(group_columns, maintain_order=True)

    for group_df in tqdm.tqdm(partitions):
        base_size, remainder = divmod(group_df.height, n_splits)

        start_idx = 0
        for i in range(n_splits):
            # The first `remainder` splits take one extra row.
            slice_length = base_size + (1 if i < remainder else 0)
            if slice_length:
                split_parts[i].append(group_df.slice(start_idx, slice_length))
            start_idx += slice_length

    # df.clear() yields an empty frame with the input schema for splits that got no rows.
    return [pl.concat(parts) if parts else df.clear() for parts in split_parts]

In [None]:
# Split the table into the default 3 parts, dividing every
# (moa_broad, compound, plate, well) group across the parts.
split = stratified_split(resampled_specs5k_sign, ["moa_broad", "Metadata_cmpdName", "Metadata_Plate", "Metadata_Well"])

In [None]:
# Columns treated as metadata: every column of specs5k_classication_list
# (built at the top of the notebook) plus the nuclei centre coordinates and `project`.
meta_cols = specs5k_classication_list.columns + ["Nuclei_Location_Center_X", "Nuclei_Location_Center_Y","project"]

In [None]:
# Feature columns = every column of the first split that is not metadata.
# Note: this rebinds features_fixed, which the "Aggregated" section defined differently.
features_fixed = [feat for feat in split[0].columns if feat not in meta_cols]

In [None]:
# Write the feature columns of every split to its own CSV file.
# NOTE(review): the training_split_CP/ directory is assumed to exist already — confirm.
for i, df in enumerate(split):
    df = df.select(features_fixed)
    file_name = f"training_split_CP/specs5k_moa_split_{i}_CP_standardized.csv"
    df.write_csv(file_name)

In [None]:
# Bound to a descriptive name instead of `all`, which shadowed the built-in
# all() for the rest of the kernel session.
specs5k_undersampled_significant = pl.read_parquet("specs5k_undersampled_significant.parquet")

In [None]:
# Write the feature columns of the complete (unsplit) table to one CSV file.
df = resampled_specs5k_sign.select(features_fixed )
file_name = f"training_split_CP/specs5k_moa_split_ALL_CP_standardized.csv"
df.write_csv(file_name)

## Check discrepancy

In [None]:
# CP-side table for the comparison below.
resampled_specs5k_sign = pl.read_parquet("datasets/standardized/specs5k_undersampled_significant_CP.parquet")

In [None]:
# DP-side table for the comparison, read from the DeepProfiler datasets folder.
# NOTE(review): absolute, user-specific path — consider a configurable base directory.
resampled_specs5k_sign_DP = pl.read_parquet("/home/jovyan/share/data/analyses/benjamin/Single_cell_supervised/SPECS_MOA/DeepProfiler/datasets/specs5k_undersampled_significant.parquet")

In [None]:
# Rows per compound in the DP table, largest first.
count_dp = resampled_specs5k_sign_DP.group_by("Metadata_cmpdName").count().sort("count", descending = True)

In [None]:
# Rows per compound in the CP table, largest first.
count_cp = resampled_specs5k_sign.group_by("Metadata_cmpdName").count().sort("count", descending = True)

In [None]:
# Expose the per-compound row counts as Int64 under the name "cell_count".
count_cp = count_cp.with_columns(pl.col("count").cast(pl.Int64).alias("cell_count"))
count_dp = count_dp.with_columns(pl.col("count").cast(pl.Int64).alias("cell_count"))

# Attach the DP counts to the CP counts; the DP columns get the "_df2" suffix
# and stay null for compounds that only occur in the CP table.
df_joined = count_cp.join(count_dp, on="Metadata_cmpdName", how="left", suffix="_df2")

df_joined = df_joined.with_columns(
    # True when the compound has no counterpart in the DP table.
    pl.col("cell_count_df2").is_null().alias("exclusive_to_df1"),
    # Absolute count difference relative to the CP count, in percent
    # (a missing DP count is treated as 0).
    (
        (pl.col("cell_count") - pl.col("cell_count_df2").fill_null(0)).abs()
        / pl.col("cell_count")
        * 100
    ).fill_null(0).alias("percentage_diff"),
)

# Keep compounds that differ by more than 20% or that are missing from the DP table.
df_filtered = df_joined.filter(
    (pl.col("percentage_diff") > 20) | pl.col("exclusive_to_df1")
)

print(df_filtered)


In [None]:
# Rich display of the filtered discrepancy table.
df_filtered

In [None]:
# Scratch calculation: absolute difference between the two literal counts (evaluates to 14).
abs((4466 - 4480))