In [None]:
import sys
sys.path.append("..")
import anndata as ad
import pandas as pd
from viewer import *

In [None]:
import polars as pl
import numpy as np
from sklearn.cluster import KMeans
import random
from sklearn_extra.cluster import KMedoids
import tqdm
def find_representative_cells(df, group_column, feature_columns, method='random', n=1, random_state=0):
    """Polars variant: pick representative rows (cells) from every group of ``df``.

    NOTE(review): a pandas ``find_representative_cells`` is defined in a later cell and
    shadows this definition after Restart & Run All; consider renaming this one
    (e.g. ``find_representative_cells_pl``) or deleting it.

    Parameters
    ----------
    df : polars.DataFrame
        One row per cell.
    group_column : str or list of str
        Column(s) passed to ``DataFrame.group_by``.
    feature_columns : str or list of str
        Numeric columns used to measure similarity between rows.
    method : {'random', 'geomean', 'kmeans', 'kmedoid'}
        'random'  -> ``n`` rows sampled without replacement;
        'geomean' -> the single row closest (Euclidean) to the per-feature geometric mean;
        'kmeans'  -> the single row closest to the centroid of the most populated
                     k-means cluster;
        'kmedoid' -> the ``n`` medoid rows of a k-medoids fit with ``n`` clusters.
    n : int
        Representatives per group (used by 'random' and 'kmedoid').
    random_state : int or None
        Seed for sampling and clustering so results are reproducible.

    Returns
    -------
    polars.DataFrame
        The selected rows (same schema as ``df``). Groups are visited with
        ``maintain_order=True``. An empty frame with ``df``'s schema is returned
        when nothing is selected.

    Raises
    ------
    ValueError
        Unknown ``method``, or a group with fewer than ``n`` rows for 'kmedoid'.
    """
    if method not in ('random', 'geomean', 'kmeans', 'kmedoid'):
        raise ValueError("Unknown method: choose 'random', 'geomean', 'kmeans', or 'kmedoid'")

    # Ensure feature_columns is a list
    if isinstance(feature_columns, str):
        feature_columns = [feature_columns]

    selected_rows = []
    # maintain_order=True makes the group (and therefore output) order deterministic.
    groups = df.group_by(group_column, maintain_order=True)

    for name, group in tqdm.tqdm(groups, total=df[group_column].n_unique()):
        if method == 'random':
            selected_rows.append(group.sample(n=n, with_replacement=False, seed=random_state))
            continue

        if method == 'kmedoid' and group.height < n:
            raise ValueError(f"Group {name} has fewer rows than the number of requested representatives.")

        features = np.asarray(group.select(feature_columns).to_numpy(), dtype=float)

        if method == 'geomean':
            # Log-domain geometric mean: prod(x) ** (1/len) overflows to inf (or underflows
            # to 0) for large groups. Only meaningful for positive features; non-positive
            # values yield NaN / 0 as with the product form.
            with np.errstate(divide='ignore', invalid='ignore'):
                target = np.exp(np.log(features).mean(axis=0))
            closest = int(np.nanargmin(np.linalg.norm(features - target, axis=1)))
            selected_rows.append(group.slice(closest, 1))

        elif method == 'kmeans':
            n_rows = group.height
            # ~30 cells per cluster for large groups, ~5 clusters for small ones; max(1, ...)
            # prevents a division by zero for groups with fewer than 5 rows.
            cells_per_cluster = 30 if n_rows > 60 else max(1, n_rows // 5)
            n_clusts = max(1, n_rows // cells_per_cluster)
            kmeans = KMeans(n_clusters=n_clusts, random_state=random_state, n_init=10).fit(features)
            # Cluster labels are arbitrary, so use the centroid of the most populated cluster.
            largest = np.bincount(kmeans.labels_, minlength=n_clusts).argmax()
            target = kmeans.cluster_centers_[largest]
            closest = int(np.nanargmin(np.linalg.norm(features - target, axis=1)))
            selected_rows.append(group.slice(closest, 1))

        else:  # 'kmedoid'
            kmedoids = KMedoids(n_clusters=n, random_state=random_state).fit(features)
            # slice() keeps each medoid as a 1-row DataFrame so pl.concat works below.
            selected_rows.extend(group.slice(int(index), 1) for index in kmedoids.medoid_indices_)

    if not selected_rows:
        # pl.concat([]) raises; return an empty frame with the input schema instead.
        return df.head(0)

    return pl.concat(selected_rows)

def _to_dense_array(matrix):
    """Return ``matrix`` as a dense NumPy array (sparse matrices via ``.toarray()``)."""
    if hasattr(matrix, "toarray"):
        return matrix.toarray()
    return np.asarray(matrix)


def anndata_to_pandas(ad):
    """Flatten an AnnData object into a single pandas DataFrame.

    Column order: the ``ad.obs`` columns (with its index reset into a column), then
    one column per ``ad.var_names`` entry holding ``ad.X``, then ``{key}_{i}``
    columns for every two-dimensional ``ad.obsm`` entry. Entries of ``obsm`` that are
    not 2-D are skipped with a printed message.

    Parameters
    ----------
    ad : anndata.AnnData
        The object to flatten. (Inside this function the name shadows the
        module-level ``anndata as ad`` alias.)

    Returns
    -------
    pandas.DataFrame
        One row per observation, with a fresh RangeIndex.
    """
    # Every part gets a RangeIndex so the column-wise concat aligns rows by position.
    parts = [
        ad.obs.reset_index(),
        pd.DataFrame(_to_dense_array(ad.X), columns=ad.var_names),
    ]

    for key, matrix in ad.obsm.items():
        # Densify first: obsm may hold sparse matrices or DataFrames. Passing a DataFrame
        # to pd.DataFrame(..., columns=new_names) would reindex by name and yield all-NaN.
        values = _to_dense_array(matrix)
        if values.ndim == 2:
            columns = [f"{key}_{i}" for i in range(values.shape[1])]
            parts.append(pd.DataFrame(values, columns=columns))
        else:
            print(f"Skipping {key} as it is not 2-dimensional")

    # Single concat instead of growing the frame inside the loop.
    return pd.concat(parts, axis=1)

def pd_to_polars(df):
    """
    Convert a Pandas DataFrame to a Polars DataFrame, un-categorizing numeric categoricals.

    Categorical columns whose categories are integers are cast to ``int``; if such a
    column contains missing values it is cast to ``float`` instead, because pandas
    refuses to cast a categorical with NaN to a plain integer dtype. Categorical
    columns with float categories are cast to ``float``. Each cast is reported with a
    print. The input frame is not modified (a copy is converted).
    """
    df = df.copy()
    for col in df.columns:
        series = df[col]
        if not isinstance(series.dtype, pd.CategoricalDtype):
            continue

        categories_dtype = series.cat.categories.dtype
        if pd.api.types.is_integer_dtype(categories_dtype):
            if series.isna().any():
                # Categorical.astype(int) raises "Cannot convert float NaN to integer".
                df[col] = series.astype(float)
                print(f"Column [{col}] cast to float (contains missing values)")
            else:
                df[col] = series.astype(int)
                print(f"Column [{col}] cast to int")
        elif pd.api.types.is_float_dtype(categories_dtype):
            df[col] = series.astype(float)
            print(f"Column [{col}] cast to float")

    return pl.from_pandas(df)

In [None]:
import pandas as pd
import numpy as np
from sklearn.cluster import KMeans
from sklearn_extra.cluster import KMedoids
import tqdm

def _nearest_row_position(features, target):
    """Return the position of the row of ``features`` closest (Euclidean) to ``target``.

    NaN distances are ignored (like ``Series.idxmin``); raises ValueError if all are NaN.
    """
    distances = np.linalg.norm(features - target, axis=1)
    return int(np.nanargmin(distances))


def find_representative_cells(df, group_column, feature_columns, method='random', n=1,
                              random_state=0, exclude_groups=("unassigned",)):
    """Pick representative rows (cells) from every group of a pandas DataFrame.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per cell.
    group_column : str or list of str
        Column(s) passed to ``DataFrame.groupby``.
    feature_columns : str or list of str
        Numeric columns used to measure similarity between cells.
    method : {'random', 'geomean', 'kmeans', 'kmedoid'}
        'random'  -> ``n`` rows sampled without replacement;
        'geomean' -> the single row closest (Euclidean) to the per-feature geometric mean;
        'kmeans'  -> the single row closest to the centroid of the most populated
                     k-means cluster;
        'kmedoid' -> the ``n`` medoid rows of a k-medoids fit with ``n`` clusters.
    n : int
        Representatives per group (used by 'random' and 'kmedoid').
    random_state : int or None
        Seed for sampling and clustering so results are reproducible.
    exclude_groups : collection
        Group keys that are skipped with a printed message (default: "unassigned").

    Returns
    -------
    pandas.DataFrame
        The selected rows with a fresh RangeIndex. Sorted by ``Metadata_cmpdName``
        when that column exists. Empty (same columns as ``df``) when nothing is selected.

    Raises
    ------
    ValueError
        Unknown ``method``, or a group with fewer than ``n`` rows for 'kmedoid'.
    """
    if method not in ('random', 'geomean', 'kmeans', 'kmedoid'):
        raise ValueError("Unknown method: choose 'random', 'geomean', 'kmeans', or 'kmedoid'")

    # Ensure feature_columns is a list
    if isinstance(feature_columns, str):
        feature_columns = [feature_columns]
    exclude_groups = exclude_groups or ()

    selected_rows = []

    for name, group in tqdm.tqdm(df.groupby(group_column)):
        if name in exclude_groups:
            print(name, "not a valid cluster")
            continue

        if method == 'random':
            selected_rows.append(group.sample(n=n, random_state=random_state))
            continue

        if method == 'kmedoid' and len(group) < n:
            raise ValueError(f"Group {name} has fewer rows than the number of requested representatives.")

        features = group[feature_columns].to_numpy(dtype=float)

        if method == 'geomean':
            # Log-domain geometric mean: prod(x) ** (1/len) overflows to inf (or underflows
            # to 0) for large groups. Only meaningful for positive features; non-positive
            # values yield NaN / 0 as with the product form.
            with np.errstate(divide='ignore', invalid='ignore'):
                target = np.exp(np.log(features).mean(axis=0))
            selected_rows.append(group.iloc[[_nearest_row_position(features, target)]])

        elif method == 'kmeans':
            n_rows = len(group)
            # ~30 cells per cluster for large groups, ~5 clusters for small ones; max(1, ...)
            # prevents a ZeroDivisionError for groups with fewer than 5 rows.
            cells_per_cluster = 30 if n_rows > 60 else max(1, n_rows // 5)
            n_clusts = max(1, n_rows // cells_per_cluster)
            kmeans = KMeans(n_clusters=n_clusts, random_state=random_state, n_init=10).fit(features)
            # Cluster labels are arbitrary, so cluster 0 is not meaningful; use the centroid
            # of the most populated cluster instead.
            largest = np.bincount(kmeans.labels_, minlength=n_clusts).argmax()
            target = kmeans.cluster_centers_[largest]
            selected_rows.append(group.iloc[[_nearest_row_position(features, target)]])

        else:  # 'kmedoid'
            kmedoids = KMedoids(n_clusters=n, random_state=random_state).fit(features)
            # medoid_indices_ are positions within `group`, hence iloc.
            selected_rows.append(group.iloc[kmedoids.medoid_indices_])

    if not selected_rows:
        # pd.concat([]) raises "No objects to concatenate".
        return df.iloc[0:0].reset_index(drop=True)

    result_df = pd.concat(selected_rows, axis=0).reset_index(drop=True)
    if 'Metadata_cmpdName' in result_df.columns:
        result_df = result_df.sort_values(by=['Metadata_cmpdName'], ascending=[True])
    return result_df


In [None]:
def show_representatives_v2(df, box_size, grouping, n_cells):
    """Build a grid (list of rows of ``View``) of single-cell crops, one row per group.

    Row 0 is a header: a padding view, ``Cell 1`` .. ``Cell n_cells`` labels, and a
    padding view. Each following row starts with the group label, then one view per
    row of ``df`` in that group, clipped with ``ClipSquare`` at the nucleus location.
    A final single-view padding row is appended.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``Metadata_Plate``, ``Metadata_Well``, ``Metadata_Site``,
        ``Nuclei_Location_Center_X``, ``Nuclei_Location_Center_Y`` and ``grouping``.
        Not modified.
    box_size : int
        Size passed to ``ClipSquare``.
    grouping : str
        Column whose distinct values define the rows (in order of first appearance).
    n_cells : int
        Number of cell-label columns in the header row.

    Returns
    -------
    list of list of View
    """
    # Copy so the caller's frame is not mutated. Site labels may contain non-digit
    # characters; keep only the first run of digits.
    df = df.copy()
    df['Metadata_Site'] = (
        df['Metadata_Site'].astype(str).str.extract(r'(\d+)', expand=False).astype(int)
    )

    # Header row: padding, one label per cell column, padding.
    header = [View(hover='(padding at top left)')]
    header += [
        View(overlay=f'Cell {i}', overlay_style=style7 + ";white-space: pre", overlay_dir='S')
        for i in range(1, n_cells + 1)
    ]
    header.append(View(hover='(padding at top right)'))
    views = [header]

    for g in df[grouping].unique():
        print(g)
        group = df[df[grouping] == g]
        row_views = [View(overlay=g, overlay_style=style7 + ";text-align:right", overlay_dir='E')]
        for _, row in group.iterrows():
            row_views.append(View(
                barcode=row["Metadata_Plate"],
                well=row["Metadata_Well"],
                site=row["Metadata_Site"],
                clip=ClipSquare(row["Nuclei_Location_Center_X"], row["Nuclei_Location_Center_Y"], box_size),
            ))
        views.append(row_views)

    views.append([View(hover='(padding at bottom left)')])
    return views

def table(grid):
    """Place a 2-D grid (list of rows) of ``View`` objects into a ``Viewer``.

    Each view is copied via ``replace`` with ``x`` set to its column index and ``y``
    to its row index.

    Raises
    ------
    AssertionError
        If any cell is not a ``View``.
    """
    placed = []
    for y, row in enumerate(grid):
        for x, cell in enumerate(row):
            assert isinstance(cell, View), f'Expected View, got: {cell} ({y=}, {x=})'
            placed.append(replace(cell, x=x, y=y))
    return Viewer(placed)

In [None]:
def find_representative_views(file_path, grouping_col, n_rep, method="kmedoid", image_size=250,
                              compounds_filter=None, filter_column="Metadata_cmpdName"):
    """Load an .h5ad file, pick representative cells per group and build the view grid.

    Parameters
    ----------
    file_path : str
        Path to the ``.h5ad`` file.
    grouping_col : str
        Column defining the groups (one grid row per group).
    n_rep : int
        Representatives per group, passed to ``find_representative_cells`` as ``n``.
    method : str
        Selection method passed to ``find_representative_cells``.
    image_size : int
        Crop size passed to ``show_representatives_v2``.
    compounds_filter : list or None
        If given, keep only cells whose ``filter_column`` value is in this list.
    filter_column : str
        ``obs`` column used by ``compounds_filter``.

    Returns
    -------
    list of list of View
        Output of ``show_representatives_v2`` (render with ``table``).

    Raises
    ------
    ValueError
        If the filter leaves no cells, or no 'Feature' columns exist for a
        feature-based method.
    """
    print("Importing data")
    adata = ad.read_h5ad(file_path)
    if compounds_filter is not None:
        adata = adata[adata.obs[filter_column].isin(compounds_filter)]
        if adata.n_obs == 0:
            raise ValueError(f"No cells match compounds_filter={compounds_filter!r}")

    # Reset the index to avoid issues with previous groupings
    location_pd = anndata_to_pandas(adata).reset_index(drop=True)

    feature_cols = [col for col in location_pd.columns if "Feature" in col]
    if not feature_cols and method != "random":
        raise ValueError("No feature columns found (expected column names containing 'Feature')")

    print("Finding representatives", location_pd.shape)
    representatives = find_representative_cells(location_pd, grouping_col, feature_cols, method=method, n=n_rep)
    # Stable sort keeps the within-group order chosen by find_representative_cells.
    representative_sort = representatives.sort_values(by=[grouping_col], kind="stable")

    print("Generating views")
    return show_representatives_v2(representative_sort, image_size, grouping=grouping_col, n_cells=n_rep)

## Find representative cells

In [None]:
# Representative BORTEZOMIB cells, one grid row per compound/concentration.
views = find_representative_views(
    file_path="/home/jovyan/share/data/analyses/benjamin/cellxgene/sc_embedding_scanpy_Beactica_deep+cell_cellcycle2.h5ad",
    grouping_col="Metadata_cmpdNameConc",
    n_rep=10,
    image_size=150,
    compounds_filter=["BORTEZOMIB"],
)

In [None]:
# Render the representative-cell grid built above.
table(grid=views)

In [None]:
def show_single_cell(plate, well, site, loc_x, loc_y, box_size):
    """Return one ``View`` of a single cell.

    The image is identified by ``plate`` (passed as ``barcode``), ``well`` and ``site``,
    and clipped with ``ClipSquare(loc_x, loc_y, box_size)``.
    """
    square = ClipSquare(loc_x, loc_y, box_size)
    return View(barcode=plate, well=well, site=site, clip=square)

In [None]:
# Inspect one specific cell.
Viewer(show_single_cell(plate="PB000046", well="M17", site=6, loc_x=2633, loc_y=253, box_size=150))

## Show cells in wells

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns


def plot_treatment_scatter(df, x_coord_column, y_coord_column, treatment_column,
                           x1=None, x2=None, y1=None, y2=None,
                           hue_column="Metadata_cmpdName", limits=(0, 2500)):
    """
    Create a grid of scatter plots of points based on their spatial coordinates,
    with each subplot corresponding to a different treatment group.

    Parameters:
    df (DataFrame): The DataFrame containing the coordinate and treatment data.
    x_coord_column (str): The name of the column containing x coordinates.
    y_coord_column (str): The name of the column containing y coordinates.
    treatment_column (str): The name of the column containing treatment labels.
    x1, x2 (float or None): x positions of dashed red vertical guide lines; None skips a line.
    y1, y2 (float or None): y positions of dashed red horizontal guide lines; None skips a line.
    hue_column (str): Column used to colour points; the colour mapping is shared by all panels.
    limits (tuple): (min, max) applied to both axes of every panel.

    Raises:
    ValueError: If ``df`` contains no treatment groups.
    """
    from matplotlib.lines import Line2D

    # Determine the number of treatment groups and setup the subplot grid
    treatments = df[treatment_column].unique()
    n_treatments = len(treatments)
    if n_treatments == 0:
        raise ValueError("No treatment groups to plot: df is empty")
    n_cols = 3
    n_rows = int(np.ceil(n_treatments / n_cols))

    # One global colour mapping so a hue value has the same colour in every panel;
    # seaborn would otherwise derive colours per panel from the levels present there.
    hue_levels = list(df[hue_column].dropna().unique())
    if hue_levels:
        palette = dict(zip(hue_levels, sns.color_palette("Set2", n_colors=len(hue_levels))))
        hue_kwargs = dict(hue=hue_column, hue_order=hue_levels, palette=palette)
    else:
        palette = {}
        hue_kwargs = {}

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False, dpi=300)
    axes = axes.flatten()

    for ax, treatment in zip(axes, treatments):
        subset = df[df[treatment_column] == treatment]
        sns.scatterplot(data=subset, x=x_coord_column, y=y_coord_column, ax=ax,
                        legend=False, s=5, **hue_kwargs)

        for x_line in (x1, x2):
            if x_line is not None:
                ax.axvline(x=x_line, color='red', linestyle='--')
        for y_line in (y1, y2):
            if y_line is not None:
                ax.axhline(y=y_line, color='red', linestyle='--')

        ax.set_title(f'Treatment: {treatment}')
        ax.set_xlim(*limits)
        ax.set_ylim(*limits)

    # Hide any unused subplots
    for ax in axes[n_treatments:]:
        ax.axis('off')

    # Per-panel legends are disabled, so the axes carry no labelled artists; build the
    # shared legend explicitly from the palette.
    if hue_levels:
        handles = [
            Line2D([0], [0], marker='o', linestyle='', color=palette[level], label=str(level))
            for level in hue_levels
        ]
        fig.legend(handles=handles, loc='upper right', bbox_to_anchor=(1.1, 1), ncol=1)

    plt.tight_layout()
    plt.show()

# Example usage:
# plot_treatment_scatter(your_dataframe, 'X_Coord', 'Y_Coord', 'Treatment')


In [None]:
# `location_pd` only exists as a local variable inside find_representative_views, so on a
# fresh kernel it is undefined here. Rebuild the cell table from the same file so this cell
# and the plot below survive Restart & Run All.
# NOTE(review): this loads the unfiltered table (all compounds) - confirm that is intended.
location_pd = anndata_to_pandas(
    ad.read_h5ad("/home/jovyan/share/data/analyses/benjamin/cellxgene/sc_embedding_scanpy_Beactica_deep+cell_cellcycle2.h5ad")
).reset_index(drop=True)
location_pd['grit'] = location_pd['grit'].astype(float)


In [None]:
# Nucleus positions per compound, with dashed guides at 250 and 2250 on both axes.
plot_treatment_scatter(
    location_pd,
    x_coord_column="Nuclei_Location_Center_X",
    y_coord_column="Nuclei_Location_Center_Y",
    treatment_column="Metadata_cmpdName",
    x1=250,
    x2=2250,
    y1=250,
    y2=2250,
)

In [None]:
# NOTE(review): `representative__old` is not defined anywhere in this notebook, so this cell
# raises NameError on a fresh kernel. It also passes no x1/x2/y1/y2 guide-line bounds;
# make sure that matches plot_treatment_scatter's signature.
# TODO: define the frame (and a `dmso_populations` column) or delete this cell.
plot_treatment_scatter(representative__old, "Nuclei_Location_Center_X", "Nuclei_Location_Center_Y", "dmso_populations")