In [None]:
from IPython.display import HTML
from bokeh_scatter import bokeh_scatter, bokeh_scatter_with_selection
import pandas as pd
import numpy as np
from markupsafe import Markup
import viewer

In [None]:
import anndata as ad
import os

def save_html(html_content, file_name):
    """Write an HTML document to ``file_name`` as UTF-8 text.

    Parameters
    ----------
    html_content : str
        The markup to write.
    file_name : str or path-like
        Destination path. An existing file is overwritten.
    """
    # Pin the encoding: without it open() falls back to the locale's preferred
    # encoding, which can raise UnicodeEncodeError (or write mojibake) when the
    # page contains non-ASCII text.
    with open(file_name, "w", encoding="utf-8") as file:
        file.write(html_content)

def prepare_bokeh_sc(path, file, box_size, embedding_type = "umap"):
    """Load an ``.h5ad`` file and flatten it into a DataFrame for the scatter plot.

    The returned frame is ``.obs`` plus the coordinates of the first ``.obsm``
    entry whose key contains ``embedding_type`` (case-insensitive), plus the
    helper columns ``site``, ``well``, ``barcode`` and ``clip``.

    Parameters
    ----------
    path, file : str
        Directory and file name of the ``.h5ad`` file; joined with ``os.path.join``.
    box_size :
        Passed as the third argument to ``viewer.ClipSquare`` for every row.
    embedding_type : str
        Substring searched for in the ``.obsm`` keys. It is also used as the
        prefix of the generated coordinate columns (``<embedding_type>_1`` ...).

    Returns
    -------
    pandas.DataFrame
        One row per observation with a fresh 0..n-1 index. ``UMAP 1`` / ``UMAP 2``
        are present only when the corresponding ``<embedding_type>_1`` /
        ``<embedding_type>_2`` columns exist.

    Raises
    ------
    KeyError
        If ``.obs`` lacks ``Metadata_Site``, ``Metadata_Well``,
        ``Nuclei_Location_Center_X``/``_Y``, or ``Metadata_Plate`` when there is
        no ``barcode`` column.
    """
    adata = ad.read_h5ad(os.path.join(path, file))

    # Take the first .obsm entry whose key mentions the requested embedding.
    wanted = embedding_type.lower()
    embedding_df = None
    for key in adata.obsm.keys():
        if wanted in key.lower():
            embedding_array = adata.obsm[key]
            cols = [f"{embedding_type}_{i+1}" for i in range(embedding_array.shape[1])]
            embedding_df = pd.DataFrame(embedding_array, columns=cols, index=adata.obs.index)
            break

    if embedding_df is not None and not embedding_df.empty:
        result_df = pd.concat([adata.obs, embedding_df], axis=1)
    else:
        print(f"No embedding found for '{embedding_type}'. Returning original .obs DataFrame.")
        # Work on a copy so the column edits below do not write into adata.obs.
        result_df = adata.obs.copy()

    # NOTE(review): only the last character of Metadata_Site is used, so a
    # multi-digit site would be truncated -- confirm sites are always 0-9.
    result_df["site"] = result_df["Metadata_Site"].str[-1].astype(int)
    result_df["well"] = result_df["Metadata_Well"]
    if "barcode" not in result_df.columns:
        result_df["barcode"] = result_df["Metadata_Plate"]
    result_df["clip"] = result_df.apply(
        lambda row: viewer.ClipSquare(
            row['Nuclei_Location_Center_X'],
            row['Nuclei_Location_Center_Y'],
            box_size,
        ).to_str(),
        axis=1,
    )

    # Alias the first two embedding axes to the names the plotting cells use.
    # Guarded so that the "no embedding found" path really does return the
    # .obs-based frame instead of failing with a KeyError here.
    for axis in (1, 2):
        embedding_col = f"{embedding_type}_{axis}"
        if embedding_col in result_df.columns:
            result_df[f"UMAP {axis}"] = result_df[embedding_col]

    result_df = result_df.reset_index(drop = True)
    result_df = result_df.drop(columns=["name_0"], errors="ignore")

    num_cols = ["Metadata_cmpdConc", "cells_per_well", "grit", "group"]
    for col in result_df.columns:
        if col in num_cols:
            # Unparseable entries become NaN rather than raising.
            result_df[col] = pd.to_numeric(result_df[col], errors='coerce')
        elif isinstance(result_df[col].dtype, pd.CategoricalDtype):
            # isinstance check replaces the deprecated
            # pd.api.types.is_categorical_dtype; categoricals become plain strings.
            result_df[col] = result_df[col].astype(str)

    return result_df

In [None]:
# Build the plotting table from dmso_only.h5ad (250 is passed as box_size).
test = prepare_bokeh_sc(
    "/home/jovyan/share/data/analyses/benjamin/cellxgene/",
    "dmso_only.h5ad",
    250,
    embedding_type="X_dmso_only",
)

In [None]:
# Derive a string "cluster" column by prefixing the values of `leiden_v5_r0.7`.
test["cluster"] = ("cluster_" + test["leiden_v5_r0.7"]).astype(str)

In [None]:

# Scatter of the `test` table on 'UMAP 1' / 'UMAP 2' with three selectable hues
# and "cluster" passed as the filter column; the resulting page is saved to disk.
plot_html = bokeh_scatter_with_selection(
    test,
    x='UMAP 1',
    y='UMAP 2',
    hues=['Metadata_cmpdName', 'grit', 'cluster'],
    hover_columns=['Metadata_Plate', 'Metadata_cmpdName', 'Metadata_Well', 'clip'],
    title='Beactica DMSO ',
    size=5,
    filter_column="cluster",
)
save_html(plot_html, "sc_scatter_dp_dmso.html")

In [None]:
# Same scatter for the SPECS3K table, coloured and filtered by compound name.
# NOTE(review): `cp_specs3k_ref` is not assigned in any cell visible here --
# confirm where it is supposed to come from before running on a fresh kernel.
plot_html = bokeh_scatter_with_selection(
    cp_specs3k_ref,
    x='UMAP 1',
    y='UMAP 2',
    hues=['Metadata_cmpdName'],
    hover_columns=['Metadata_Plate', 'Metadata_cmpdName', 'Metadata_Well', 'clip'],
    title='CellProfiler SPECS3K',
    size=5,
    filter_column="Metadata_cmpdName",
)
save_html(plot_html, "sc_scatter_specs3k_cp.html")

In [None]:
from __future__ import annotations
from typing import *

import numpy as np
import pandas as pd
import uuid
from bokeh.models import *
from bokeh.plotting import *
from bokeh.layouts import *
from bokeh.palettes import *
from bokeh.transform import factor_cmap, linear_cmap
import bokeh.core.properties as properties

import viewer

from bokeh.models import Select, CustomJS, ColumnDataSource, Button, ColorBar, HoverTool, LinearColorMapper, CategoricalColorMapper, DataRange1d
from bokeh.io import output_notebook
from typing import Sequence, Literal
from bokeh.models import CheckboxGroup
from bokeh.models import CheckboxGroup, CustomJS
from bokeh.layouts import column


def bokeh_to_html(root, id: str):
    """Render a Bokeh layout plus an image-viewer pane as one HTML fragment.

    ``root`` is converted to a script/div pair with the ``dark_minimal`` theme,
    placed next to the markup returned by ``viewer.Viewer([], height='100%')``
    inside a two-column grid container whose element id is ``id``.

    Every ``<script `` opening tag in the final markup is rewritten to
    ``<script eval ``.  NOTE(review): presumably whatever injects this HTML
    looks for that attribute -- confirm with the consumer.

    Returns the assembled HTML as a string.
    """
    import bokeh.document
    import bokeh.embed
    import bokeh.themes

    dark_theme = bokeh.themes.built_in_themes['dark_minimal']
    script, div = bokeh.embed.components(root, theme=dark_theme)

    osd = viewer.Viewer([], height='100%').to_html()

    html = '''
        <style>
            .bokeh_scatter {
                box-sizing: border-box;
                display: grid;
                width: 100%;
                grid-template-columns: auto 1fr;
                margin: 0;
                padding: 0;
            }
        </style>
        <script type="text/javascript" src="https://cdn.bokeh.org/bokeh/release/bokeh-3.2.2.min.js"></script>
        <script type="text/javascript" src="https://cdn.bokeh.org/bokeh/release/bokeh-gl-3.2.2.min.js"></script>
        <script type="text/javascript" src="https://cdn.bokeh.org/bokeh/release/bokeh-widgets-3.2.2.min.js"></script>
        <script type="text/javascript">
            Bokeh.set_log_level("info");
        </script>
        <div class="bokeh_scatter" id=@id>
            @div
            @osd
        </div>
        @script
    '''

    # Placeholders are substituted one after another, in this exact order.
    substitutions = (
        ('@id', id),
        ('@div', div),
        ('@osd', osd),
        ('@script', script),
    )
    for placeholder, value in substitutions:
        html = html.replace(placeholder, value)

    # Applied last so it also reaches the script tags that came from Bokeh.
    return html.replace('<script ', '<script eval ')

def check_colormap_type(cmap_name):
    """Classify a Matplotlib colormap as ``'continuous'`` or ``'discrete'``.

    Parameters
    ----------
    cmap_name : str
        Name of a registered Matplotlib colormap.

    Returns
    -------
    str
        ``'continuous'`` for a ``LinearSegmentedColormap``,
        ``'discrete'`` for a ``ListedColormap``.

    Raises
    ------
    ValueError
        If the name is not registered, or the colormap is of neither class.
    """
    import matplotlib.pyplot as plt
    import matplotlib.colors as colors

    if cmap_name not in plt.colormaps():
        raise ValueError(f"{cmap_name} is not a valid Matplotlib colormap.")

    # Look the colormap up through pyplot: matplotlib.cm.get_cmap was deprecated
    # in Matplotlib 3.7 and removed in 3.9, while pyplot.get_cmap is still there.
    cmap = plt.get_cmap(cmap_name)

    if isinstance(cmap, colors.LinearSegmentedColormap):
        return 'continuous'
    if isinstance(cmap, colors.ListedColormap):
        return 'discrete'
    raise ValueError(f"{cmap_name} type is unknown.")
    
    
def bokeh_scatter_with_selection(
    df,
    x,
    y,
    hues: Sequence[str],
    hover_columns: Sequence[str] | None = None,
    title='Scatter plot',
    invert: str = '',
    marker: Literal['circle', 'square'] = 'circle',
    size: float | int = 10,
    aspect_ratio: float | None = None,
    filter_column: str | Sequence[str] | None = None
):
    """Build an interactive scatter plot and return it as an HTML string.

    The page contains one button per hue (switches the fill colour and the
    visible colour bar), optional checkbox groups that filter the points, and
    a viewer pane that is sent the selected rows (``barcode``, ``well``,
    ``site``, ``clip``) whenever the selection changes. Double-tapping the
    plot clears the selection.

    Parameters
    ----------
    df : pandas.DataFrame
        Source table. Only the columns needed for the plot are kept.
    x, y : str
        Columns used for the two axes.
    hues : sequence of str
        Columns selectable as colour. Float columns, and integer columns with
        more than 10 distinct values, get a linear Viridis mapper; everything
        else gets a categorical mapper. The first hue is shown initially.
    hover_columns : sequence of str, optional
        Columns listed in the hover tooltip.
    title : str
        Plot title.
    invert : str
        Axes to flip: any string containing ``'x'`` and/or ``'y'``.
    marker : {'circle', 'square'}
        Glyph type.
    size : float or int
        Circle size, or width/height of the square.
    aspect_ratio : float, optional
        When given, ``match_aspect=True`` and ``aspect_scale=aspect_ratio`` are
        passed to the figure.
    filter_column : str or sequence of str, optional
        Column name(s) to expose as checkbox filters. A single name may be
        passed as a plain string.

    Returns
    -------
    str
        HTML produced by ``bokeh_to_html``.

    Raises
    ------
    ValueError
        If ``hues`` is empty or ``marker`` is not supported.
    """
    if not hues:
        raise ValueError('At least one hue is needed')

    # A bare string must be treated as ONE column name; iterating it directly
    # would walk over its characters.
    if filter_column is None:
        filter_columns = []
    elif isinstance(filter_column, str):
        filter_columns = [filter_column]
    else:
        filter_columns = list(filter_column)

    cols = [x, y, *(hover_columns or []), *hues, *filter_columns]
    # Columns consumed by the viewer / selection callback, forwarded when present.
    for c in 'barcode site well clip overlay overlay_style'.split():
        if c in df.columns:
            cols.append(c)

    # De-duplicate while keeping a deterministic column order.
    cols = list(dict.fromkeys(cols))
    df = df[cols]

    source = ColumnDataSource(data=df)
    # Untouched copy that the filter callback refills `source` from.
    original_source = ColumnDataSource(data=df.copy())

    mappers = {}
    color_bars = {}

    for hue in hues:
        is_float = pd.api.types.is_float_dtype(df[hue])
        is_integer = pd.api.types.is_integer_dtype(df[hue])
        linear = is_float or (is_integer and df[hue].nunique() > 10)
        if linear:
            mappers[hue] = LinearColorMapper(
                palette="Viridis256",
                low=df[hue].min(),
                high=df[hue].max(),
            )
        else:
            factors = df[hue].unique()
            # Cycle the 10-colour palette so every factor gets a colour; a
            # palette shorter than `factors` leaves the extras on nan_color.
            base_palette = Category10[10]
            mappers[hue] = CategoricalColorMapper(
                palette=[base_palette[i % len(base_palette)] for i in range(len(factors))],
                factors=factors,
            )
        color_bars[hue] = ColorBar(
            color_mapper=mappers[hue],
            height=300 if linear else 24 * len(factors),
            width=15 if linear else 24,
            background_fill_color='rgba(0,0,0,0)',
            title=hue,
        )

    if aspect_ratio is not None:
        aspect = dict(
            match_aspect=True,
            aspect_scale=aspect_ratio,
        )
    else:
        aspect = {}

    # Both ranges share the same zoom-out limit: 1.2x the larger data extent.
    max_interval = 1.2 * max(np.ptp(df[x]), np.ptp(df[y]))

    p = figure(
        width=1000,
        height=1000,
        x_axis_label=x,
        y_axis_label=y,
        x_range=DataRange1d(flipped='x' in invert, max_interval=max_interval),
        y_range=DataRange1d(flipped='y' in invert, max_interval=max_interval),
        title=title,
        tools='lasso_select box_select wheel_zoom box_zoom reset fullscreen pan'.split(),
        toolbar_location='right',
        active_scroll='wheel_zoom',
        active_drag='box_select',
        border_fill_color='#2d2d2d',
        output_backend='webgl',
        **aspect,
    )

    if hover_columns:
        hover = HoverTool(
            # The braced form @{name} also works for column names with spaces.
            tooltips=[(c, f'@{{{c}}}') for c in hover_columns],
            mode='mouse',
            point_policy='snap_to_data',
        )
        p.add_tools(hover)

    # Narrow companion figure that only hosts the colour bars.
    p2 = figure(
        width=100,
        height=p.height,
        tools='',
        toolbar_location=None,
        outline_line_alpha=0,
        background_fill_color='rgba(0,0,0,0)',
        border_fill_color='rgba(0,0,0,0)',
    )

    initial_fill = {
        'field': hues[0],
        'transform': mappers[hues[0]],
    }
    if marker == 'square':
        scatter = p.rect(
            x=x,
            y=y,
            width=size, height=size,
            source=source,
            fill_color=initial_fill,
            line_color='#111',
            alpha=0.9,
            line_alpha=1.0,
            line_width=1.0,
        )
        # Add invisible circles for lasso selection
        p.circle(x=x, y=y, size=10, source=source, color=None, alpha=0)
    elif marker == 'circle':
        scatter = p.circle(
            x=x,
            y=y,
            size=size,
            source=source,
            fill_color=initial_fill,
            line_color='#111',
            alpha=0.9,
            line_alpha=1.0,
            line_width=1.0,
        )
    else:
        raise ValueError(f'Unsupported {marker=}!')

    # Selected / unselected glyphs start as copies of the base glyph and differ
    # only in outline colour and alpha.
    scatter.selection_glyph = type(scatter.glyph)(**scatter.glyph.properties_with_values())
    scatter.nonselection_glyph = type(scatter.glyph)(**scatter.glyph.properties_with_values())
    scatter.selection_glyph.line_color = '#eee'
    scatter.nonselection_glyph.fill_alpha = 0.7
    scatter.nonselection_glyph.line_alpha = 0.9
    for hue in hues:
        p2.add_layout(color_bars[hue], 'right')
        color_bars[hue].visible = False
    color_bars[hues[0]].visible = True

    # Hue buttons: recolour all three glyphs and show only the matching bar.
    js_code = """
        scatter.glyph.fill_color = {field: hue, transform: mappers[hue]}
        scatter.selection_glyph.fill_color = {field: hue, transform: mappers[hue]};
        scatter.nonselection_glyph.fill_color = {field: hue, transform: mappers[hue]};
        for (let hue_i of Object.keys(color_bars)) {
            color_bars[hue_i].visible = hue_i === hue
        }
        source.change.emit()
    """

    buttons = []
    for hue in hues:
        button = Button(label=hue, button_type='success')
        callback = CustomJS(
            args=dict(
                source=source,
                scatter=scatter,
                mappers=mappers,
                color_bars=color_bars,
                hue=hue,
            ),
            code=js_code,
        )
        button.js_on_click(callback)
        buttons.append(button)

    # Double-tap clears the selection.
    p.js_on_event(
        'doubletap',
        CustomJS(
            args=dict(source=source),
            code="""
                source.selected.indices = []
                source.change.emit()
            """,
        ),
    )

    id = f'scatter-{uuid.uuid4()}'

    # Post the selected rows to the viewer iframe inside this plot's container.
    source.selected.js_on_change('indices', CustomJS(
        args=dict(source=source, id=id),
        code="""
            const data = source.data
            const rows = []
            for (const i of source.selected.indices) {
              const row = {}
              for (const col of 'barcode well site clip'.split(' ')) {
                if (data.hasOwnProperty(col)) 
                  row[col] = data[col][i]
              }
              rows.push(row)
            }
            console.log(rows)
            const osd_iframe = document.querySelector(`#${id} iframe`)

            function call_osd(method, ...args) {
              osd_iframe.contentWindow.postMessage({method, arguments: args}, '*')
            }
            call_osd('update_tile_source', rows)
        """,
    ))

    # One checkbox group per filter column, with every value ticked initially.
    checkbox_widgets = {}
    for col in filter_columns:
        unique_values = sorted(df[col].unique().tolist())
        checkbox_widgets[col] = CheckboxGroup(
            labels=unique_values,
            active=list(range(len(unique_values))),
            name=col,
        )

    if checkbox_widgets:
        # A row is kept only if, for every filter column, its value is ticked.
        # The row count is passed in from Python so the callback does not rely
        # on the data source having a column literally named 'index'.
        filter_callback_code = """
                const data = source.data;
                const original_data = original_source.data;
                const selections = {};

                // Gather active selections for each filter
                for (const [key, widget] of Object.entries(checkbox_widgets)) {
                    const selected_indices = widget.active;
                    const selected_labels = selected_indices.map(index => widget.labels[index]);
                    selections[widget.name] = selected_labels;
                }

                // Reset data
                for (const key in data) {
                    data[key] = [];
                }

                // Filter data based on active selections
                main_loop:
                for (let i = 0; i < n_rows; ++i) {
                    for (const [column, selected_labels] of Object.entries(selections)) {
                        if (!selected_labels.includes(original_data[column][i])) {
                            continue main_loop;
                        }
                    }
                    for (let key in data) {
                        data[key].push(original_data[key][i]);
                    }
                }

                source.change.emit();
            """
        filter_callback = CustomJS(
            args=dict(
                source=source,
                original_source=original_source,
                checkbox_widgets=checkbox_widgets,
                n_rows=len(df),
            ),
            code=filter_callback_code,
        )
        for checkbox_widget in checkbox_widgets.values():
            checkbox_widget.js_on_change('active', filter_callback)

    buttons_layout = row(buttons)
    plot_layout = row(p2, p)

    # Filters (if any) on top, then the hue buttons, then colour bars + plot.
    l = column(*checkbox_widgets.values(), buttons_layout, plot_layout)

    return bokeh_to_html(l, id=id)

In [None]:
# Show the HTML currently held in `plot_html` as this cell's output.
HTML(plot_html)

## Filter widget

In [None]:
# Scratch draft of a single-value filter callback.
# NOTE(review): `source`, `original_source`, `filter_widget` and `filter_column`
# are not assigned at notebook level in the visible cells, so this cell would
# raise NameError on a fresh kernel unless they are defined elsewhere.
# The JS empties every column of `source.data` and refills it from
# `original_source.data`, keeping row i when `filter_widget.value` is 'All' or
# equals the row's value in the column named by `filter_column` (the name is
# spliced into the JS text). The row count is taken from the 'index' column.
callback = CustomJS(args=dict(source=source, original_source=original_source, filter_widget=filter_widget), code="""
            const data = source.data;
            const original_data = original_source.data;
            const selected = filter_widget.value;
            for (let key in data) {
                data[key] = [];
            }
            for (let i = 0; i < original_data['index'].length; ++i) {
                if (selected === 'All' || original_data['"""+filter_column+"""'][i] === selected) {
                    for (let key in data) {
                        data[key].push(original_data[key][i]);
                    }
                }
            }
            source.change.emit();
        """)

## Simple checkbox widget

In [None]:
    # Scratch draft: one CheckboxGroup over the unique values of `filter_column`,
    # all ticked initially. On every change of `active` the JS empties `source.data`
    # and refills it from `original_source.data` with the rows whose value in that
    # column is among the ticked labels.
    # NOTE(review): this cell starts indented and uses `df`, `source`,
    # `original_source` and `filter_column`, none of which are assigned at notebook
    # level in the visible cells -- it looks like a fragment lifted from a function
    # body and will not run as a standalone cell as written.
    if filter_column is not None:
        unique_values = sorted(df[filter_column].unique().tolist())
        filter_widget = CheckboxGroup(labels=unique_values, active=list(range(len(unique_values))))  # All items selected by default
        filter_callback = CustomJS(args=dict(source=source, original_source=original_source, checkbox_group=filter_widget, labels=unique_values), code="""
            const selected_indices = checkbox_group.active;
            const selected_labels = selected_indices.map(index => labels[index]);
            const data = source.data;
            const original_data = original_source.data;
            
            // Reset data
            for (const key in data) {
                data[key] = [];
            }
            
            // Filter data based on selected labels
            for (let i = 0; i < original_data['index'].length; ++i) {
                if (selected_labels.includes(original_data['"""+filter_column+"""'][i])) {
                    for (let key in data) {
                        data[key].push(original_data[key][i]);
                    }
                }
            }
            
            source.change.emit();
        """)
        filter_widget.js_on_change('active', filter_callback)
        #filter_widget.js_on_change('active', filter_callback)
        
        #filter_widget.js_on_change('value', callback)



In [None]:
        # Scratch draft of the multi-column filter: one CheckboxGroup per entry of
        # `filter_columns`, each named after its column and fully ticked initially.
        # NOTE(review): this cell starts indented and uses `df`, `filter_columns`,
        # `source` and `original_source`, none of which are assigned at notebook
        # level in the visible cells -- it will not run as a standalone cell as written.
        # NOTE(review): the loop variable `column` has the same name as the `column`
        # layout helper imported from bokeh.layouts; if this ever ran at top level it
        # would rebind that name.
        checkbox_widgets = {}
        for column in filter_columns:
            print(df[column])
            unique_values = sorted(df[column].unique().tolist())
            checkbox_widget = CheckboxGroup(labels=unique_values, active=list(range(len(unique_values))), name=column)
            checkbox_widgets[column] = checkbox_widget
        #unique_values = sorted(df[filter_column].unique().tolist())
        #filter_widget = CheckboxGroup(labels=unique_values, active=list(range(len(unique_values))))  # All items selected by default
        #unique_values = ['All'] + sorted(df[filter_column].unique().tolist())
        #filter_widget = Select(title=f"Filter by {filter_column}", value="All", options=unique_values)
        
        # CustomJS callback to filter data based on the selected value
    

        

        # The JS collects the ticked labels of every widget (keyed by widget name),
        # empties `source.data`, then copies back from `original_source.data` each row
        # whose value is ticked in every filter column.
        filter_callback_code = """
            const data = source.data;
            const original_data = original_source.data;
            const selections = {};
            
            // Gather active selections for each filter
            for (const [key, widget] of Object.entries(checkbox_widgets)) {
                const selected_indices = widget.active;
                const selected_labels = selected_indices.map(index => widget.labels[index]);
                selections[widget.name] = selected_labels;
            }
            
            // Reset data
            for (const key in data) {
                data[key] = [];
            }
            
            // Filter data based on active selections
            main_loop:
            for (let i = 0; i < original_data['index'].length; ++i) {
                for (const [column, selected_labels] of Object.entries(selections)) {
                    if (!selected_labels.includes(original_data[column][i])) {
                        continue main_loop;
                    }
                }
                for (let key in data) {
                    data[key].push(original_data[key][i]);
                }
            }
            
            source.change.emit();
        """
        # Every widget gets its own CustomJS built from the same code and arguments.
        for checkbox_widget in checkbox_widgets.values():
            checkbox_widget.js_on_change('active', CustomJS(args=dict(source=source, original_source=original_source, checkbox_widgets=checkbox_widgets), code=filter_callback_code))
        