In [1]:
import panel as pn
import scanpy as sc
import spacec as sp
import matplotlib.pyplot as plt
import pandas as pd
import warnings
from pyFlowSOM import map_data_to_nodes, som
import os

2024-11-17 21:37:06.807506: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
INFO:root: * TissUUmaps version: 3.1.1.6


In [2]:
def launch_interactive_clustering():
    # Silence library warnings (scanpy/anndata are noisy) for the rest of the session.
    # NOTE: warnings.filterwarnings is process-wide, not scoped to the app.
    warnings.filterwarnings('ignore')
    pn.extension('deckgl', design='bootstrap', theme='default', template='bootstrap')
    # Remove the default padding around the template's main area. Append the rule only
    # once so relaunching the app in the same kernel does not stack duplicate CSS.
    main_padding_css = """
    #main {
    padding: 0;
    }"""
    template_raw_css = pn.state.template.config.raw_css
    if main_padding_css not in template_raw_css:
        template_raw_css.append(main_padding_css)

    # Define the app
    def create_clustering_app():
        
        # Define the clustering function
        def clustering(
            adata,
            clustering="leiden",
            marker_list=None,
            resolution=1,
            n_neighbors=10,
            reclustering=False,
            key_added=None,
            key_filter=None,
            subset_cluster=None,
            seed=42,
            fs_xdim=10,
            fs_ydim=10,
            fs_rlen=10,  # FlowSOM parameters
            **cluster_kwargs,
        ):
            """
            Perform clustering on the given annotated data matrix.

            When ``marker_list`` is given, clustering runs on a copy restricted to
            those markers. The results (``obs[key_added]`` plus every ``obsm``,
            ``obsp`` and ``uns`` entry of the copy, e.g. neighbor graph and UMAP)
            are then written back to ``adata``, so the caller keeps all variables.

            Parameters
            ----------
            adata : AnnData
                The annotated data matrix of shape n_obs x n_vars.
            clustering : str, optional
                One of "leiden", "louvain" or "flowSOM". "leiden_gpu" is a
                recognised name but is not implemented here. Defaults to "leiden".
            marker_list : list, optional
                Markers to cluster on. Markers missing from ``adata.var_names``
                are dropped with a notice. Defaults to None (use all variables).
            resolution : float, optional
                Resolution for leiden/louvain. Defaults to 1.
            n_neighbors : int, optional
                Number of neighbors for the neighbors graph. Defaults to 10.
            reclustering : bool, optional
                If True, skip recomputing the neighbors graph and UMAP and reuse
                whatever ``adata`` already holds. Defaults to False.
            key_added : str, optional
                obs column for the labels. Defaults to ``f"{clustering}_{resolution}"``.
            key_filter, subset_cluster : optional
                Unused; kept for signature compatibility.
            seed : int, optional
                Random seed passed to the clustering algorithm. Default is 42.
            fs_xdim, fs_ydim : int, optional
                X / Y dimension of the FlowSOM grid. Default is 10 each.
            fs_rlen : int, optional
                Rlen for FlowSOM. Default is 10.
            **cluster_kwargs
                Extra keyword arguments forwarded to ``sc.tl.leiden`` / ``sc.tl.louvain``.

            Returns
            -------
            AnnData
                ``adata`` (the object passed in) with the clustering results added.

            Raises
            ------
            ValueError
                If ``clustering`` is not a known method, or none of the requested
                markers is present in ``adata.var_names``.
            NotImplementedError
                If ``clustering`` is "leiden_gpu".
            """
            if clustering not in ["leiden", "louvain", "leiden_gpu", "flowSOM"]:
                # Raise (rather than exit the interpreter) so the calling UI can report it.
                raise ValueError(
                    "Invalid clustering options. Please select from leiden, louvain, leiden_gpu, or flowSOM!"
                )
            if clustering == "leiden_gpu":
                # Fail before the expensive neighbors/UMAP step instead of returning
                # an object without ``key_added``.
                raise NotImplementedError("Clustering method not implemented in this example.")

            if key_added is None:
                key_added = clustering + "_" + str(resolution)

            full_adata = None
            if marker_list is not None:
                var_names = set(adata.var_names)
                # Deduplicate while keeping the user's marker order.
                marker_list = list(dict.fromkeys(marker_list))
                if any(marker not in var_names for marker in marker_list):
                    print("Marker list not all in adata var_names! Using intersection instead!")
                    marker_list = [marker for marker in marker_list if marker in var_names]
                    print("New marker_list: " + " ".join(marker_list))
                if not marker_list:
                    raise ValueError("None of the markers in marker_list are present in adata.var_names.")
                # Work on an explicit copy (not a view) of the marker subset; results
                # are copied back onto the full object below.
                full_adata = adata
                adata = adata[:, marker_list].copy()

            if not reclustering and clustering != "flowSOM":
                sc.pp.neighbors(adata, n_neighbors=n_neighbors)
                sc.tl.umap(adata)

            if clustering == "leiden":
                sc.tl.leiden(
                    adata,
                    resolution=resolution,
                    key_added=key_added,
                    random_state=seed,
                    **cluster_kwargs,
                )
            elif clustering == "louvain":
                sc.tl.louvain(
                    adata,
                    resolution=resolution,
                    key_added=key_added,
                    random_state=seed,
                    **cluster_kwargs,
                )
            else:  # flowSOM (all other names were rejected above)
                import numpy as np  # local import: numpy is not imported at notebook level

                expression = adata.X
                if hasattr(expression, "toarray"):
                    # AnnData.X may be a scipy sparse matrix; som() is given a dense array.
                    expression = expression.toarray()
                som_input_arr = np.ascontiguousarray(expression, dtype=np.float64)
                node_output = som(
                    som_input_arr,
                    xdim=fs_xdim,
                    ydim=fs_ydim,
                    rlen=fs_rlen,
                    seed=seed,
                )
                clusters, _dists = map_data_to_nodes(node_output, som_input_arr)
                adata.obs[key_added] = pd.Categorical(clusters)

            if full_adata is not None:
                # Propagate labels and embeddings/graphs to the full object so the
                # caller does not lose the non-marker variables.
                full_adata.obs[key_added] = adata.obs[key_added].values
                for name, value in adata.obsm.items():
                    full_adata.obsm[name] = value
                for name, value in adata.obsp.items():
                    full_adata.obsp[name] = value
                full_adata.uns.update(adata.uns)
                adata = full_adata

            return adata
        
        # Callback to load data
        def load_data(event):
            """Button callback: read the .h5ad file at ``input_path`` into the app.

            On success the new AnnData replaces everything previously held in
            ``adata_container``. Marker choices are refreshed from
            ``adata.var_names``, keeping only selected markers that still exist.
            The annotation table from a previous dataset is cleared. Read errors
            are shown in ``output_area`` instead of escaping the callback.
            """
            path = input_path.value
            if not path or not os.path.isfile(path):
                output_area.object = "**Please enter a valid AnnData file path.**"
                return
            try:
                adata = sc.read_h5ad(path)
            except Exception as e:
                output_area.object = f"**Error loading AnnData file: {e}**"
                return
            # Drop any results tied to a previously loaded dataset.
            adata_container.clear()
            adata_container['adata'] = adata
            new_markers = list(adata.var_names)
            available = set(new_markers)
            # Narrow the selection before swapping options so it stays a subset of them.
            marker_list_input.value = [m for m in (marker_list_input.value or []) if m in available]
            marker_list_input.options = new_markers
            cluster_annotation.value = pd.DataFrame(columns=['Cluster', 'Annotation'])
            output_area.object = "**AnnData file loaded successfully.**"

        # Callback to run clustering
        def _default_cluster_key():
            """obs key implied by the clustering widgets: 'Key Added' or '<method>_<resolution>'."""
            if key_added_input.value:
                return key_added_input.value
            return clustering_method.value + '_' + str(resolution.value)

        def _show_cluster_results(adata, key, marker_list, tab_suffix, group_label):
            """Append UMAP / dotplot / histogram tabs for ``adata.obs[key]`` and reset the annotation table.

            Parameters
            ----------
            adata : AnnData
                Object holding the labels in ``adata.obs[key]``.
            key : str
                obs column to visualize and annotate.
            marker_list : list or None
                Markers for the dotplot; the dotplot is skipped when empty/None.
            tab_suffix : str
                Appended to each tab title ('' for clustering, '_Sub' for subclustering).
            group_label : str
                'Cluster' or 'Subcluster'; used for the histogram x label and title.
            """
            tabs = []
            # A UMAP only exists if neighbors/UMAP were computed (the FlowSOM path skips them).
            if 'X_umap' in adata.obsm:
                sc.pl.umap(adata, color=[key], show=False)
                umap_fig = plt.gcf()
                plt.close(umap_fig)
                tabs.append(('UMAP' + tab_suffix, pn.pane.Matplotlib(umap_fig, dpi=100)))
            if marker_list:
                sc.pl.dotplot(adata, marker_list, groupby=key, dendrogram=True, show=False)
                dotplot_fig = plt.gcf()
                plt.close(dotplot_fig)
                tabs.append(('Dotplot' + tab_suffix, pn.pane.Matplotlib(dotplot_fig, dpi=100)))
            # Explicit figure/axes rather than pyplot's implicit "current figure".
            cluster_counts = adata.obs[key].value_counts().sort_index()
            hist_fig, hist_ax = plt.subplots()
            cluster_counts.plot(kind='bar', ax=hist_ax)
            hist_ax.set_xlabel(group_label)
            hist_ax.set_ylabel('Number of Cells')
            hist_ax.set_title(f'{group_label} Counts for {key}')
            plt.close(hist_fig)
            tabs.append(('Histogram' + tab_suffix, pn.pane.Matplotlib(hist_fig, dpi=100)))
            for name, pane in tabs:
                visualization_area.append((name, pane))
            clusters = adata.obs[key].unique().astype(str)
            cluster_annotation.value = pd.DataFrame({'Cluster': clusters, 'Annotation': [''] * len(clusters)})
            # Record which column the table describes so save_annotations maps onto it.
            adata_container['annotation_key'] = key

        # Callback to run clustering
        def run_clustering(event):
            """Button callback: cluster the loaded AnnData with the widget settings and show the results.

            Clustering errors and plotting errors are reported separately in
            ``output_area``. A plotting failure does not discard a successful
            clustering, which is already stored in ``adata_container``.
            """
            adata = adata_container.get('adata', None)
            if adata is None:
                output_area.object = "**Please load an AnnData file first.**"
                return
            marker_list = list(marker_list_input.value) if marker_list_input.value else None
            key_added = _default_cluster_key()
            loading_indicator.active = True
            output_area.object = "**Clustering in progress...**"
            try:
                try:
                    # clustering() ignores parameters irrelevant to the chosen method
                    # (n_neighbors/resolution for FlowSOM, fs_* for leiden/louvain).
                    adata = clustering(
                        adata,
                        clustering=clustering_method.value,
                        marker_list=marker_list,
                        resolution=resolution.value,
                        n_neighbors=n_neighbors.value,
                        reclustering=reclustering.value,
                        seed=seed.value,
                        key_added=key_added,
                        fs_xdim=fs_xdim.value,
                        fs_ydim=fs_ydim.value,
                        fs_rlen=fs_rlen.value,
                    )
                except Exception as e:
                    output_area.object = f"**Error during clustering: {e}**"
                    return
                adata_container['adata'] = adata
                output_area.object = "**Clustering completed.**"
                try:
                    _show_cluster_results(adata, key_added, marker_list, '', 'Cluster')
                except Exception as e:
                    output_area.object = f"**Clustering completed, but plotting failed: {e}**"
            finally:
                loading_indicator.active = False

        # Callback to run subclustering
        def run_subclustering(event):
            """Button callback: Leiden-recluster selected clusters of ``subcluster_key``.

            Labels are written to ``adata.obs[f"{subcluster_key}_subcluster"]``.
            ``sc.tl.leiden`` works on the neighbors graph, so this needs a prior
            leiden/louvain run (the FlowSOM path in clustering() builds no graph).
            """
            adata = adata_container.get('adata', None)
            if adata is None:
                output_area.object = "**Please run clustering first.**"
                return
            # Ignore blanks from inputs like "1, 2," so they are not passed as categories.
            clusters = [c.strip() for c in (subcluster_values.value or '').split(',') if c.strip()]
            if not subcluster_key.value or not clusters:
                output_area.object = "**Please provide subcluster key and values.**"
                return
            key_added = subcluster_key.value + '_subcluster'
            loading_indicator_subcluster.active = True
            output_area.object = "**Subclustering in progress...**"
            try:
                try:
                    # scanpy's leiden takes the seed as `random_state`; an unknown
                    # `seed=` kwarg would be overridden by scanpy's default seed.
                    sc.tl.leiden(
                        adata,
                        random_state=seed.value,
                        restrict_to=(subcluster_key.value, clusters),
                        resolution=subcluster_resolution.value,
                        key_added=key_added,
                    )
                except Exception as e:
                    output_area.object = f"**Error during subclustering: {e}**"
                    return
                adata_container['adata'] = adata
                output_area.object = "**Subclustering completed.**"
                marker_list = list(marker_list_input.value) if marker_list_input.value else None
                try:
                    _show_cluster_results(adata, key_added, marker_list, '_Sub', 'Subcluster')
                except Exception as e:
                    output_area.object = f"**Subclustering completed, but plotting failed: {e}**"
            finally:
                loading_indicator_subcluster.active = False

        # Callback to save annotations
        def save_annotations(event):
            """Button callback: map the annotation table onto ``adata.obs['cell_type']``.

            The table is applied to the obs column it was generated from, i.e. the
            last clustering or subclustering shown. Before any run, it falls back
            to the key implied by the clustering widgets. Clusters missing from
            the table become NaN.
            """
            adata = adata_container.get('adata', None)
            if adata is None:
                output_area.object = "**No AnnData object to annotate.**"
                return
            key_to_annotate = adata_container.get('annotation_key') or _default_cluster_key()
            if key_to_annotate not in adata.obs:
                output_area.object = f"**Column '{key_to_annotate}' not found in adata.obs. Please run clustering first.**"
                return
            table = cluster_annotation.value
            # Keys as str to match the .astype(str) labels below, even if the table was edited.
            annotation_dict = {str(cluster): label for cluster, label in zip(table['Cluster'], table['Annotation'])}
            adata.obs['cell_type'] = adata.obs[key_to_annotate].astype(str).map(annotation_dict).astype('category')
            output_area.object = "**Annotations saved to AnnData object.**"

        def save_adata(event):
            """Button callback: write the current AnnData to ``<output_dir>/adata_annotated.h5ad``.

            Creates the output directory if needed. Filesystem/write errors are
            reported in ``output_area`` (like the other callbacks) instead of
            escaping the callback.
            """
            adata = adata_container.get('adata', None)
            if adata is None:
                output_area.object = "**No AnnData object to save.**"
                return
            if not output_dir.value:
                output_area.object = "**Please specify an output directory.**"
                return
            output_filepath = os.path.join(output_dir.value, 'adata_annotated.h5ad')
            try:
                os.makedirs(output_dir.value, exist_ok=True)
                adata.write(output_filepath)
            except Exception as e:
                output_area.object = f"**Error saving AnnData: {e}**"
                return
            output_area.object = f"**AnnData saved to {output_filepath}.**"

        # Callback to run spatial visualization
        def run_spatial_visualization(event):
            """Button callback: draw ``sp.pl.catplot`` with the spatial widgets' settings.

            The figure that is pyplot's current figure after the call is added as a
            'Spatial Visualization' tab. Missing required inputs and plotting errors
            are reported in ``output_area``.
            """
            adata = adata_container.get('adata', None)
            if adata is None:
                output_area.object = "**Please load an AnnData file first.**"
                return
            # Validate inputs up front rather than passing empty strings to catplot.
            if not spatial_color.value:
                output_area.object = "**Please enter a column to color by.**"
                return
            if spatial_savefig.value and not spatial_output_fname.value:
                output_area.object = "**Please enter an output filename to save the figure.**"
                return
            try:
                sp.pl.catplot(
                    adata,
                    color=spatial_color.value,
                    unique_region=spatial_unique_region.value,
                    X=spatial_x.value,
                    Y=spatial_y.value,
                    n_columns=spatial_n_columns.value,
                    palette=spatial_palette.value,
                    savefig=spatial_savefig.value,
                    output_fname=spatial_output_fname.value,
                    output_dir=output_dir.value,
                    figsize=spatial_figsize.value,
                    size=spatial_size.value
                )
                spatial_fig = plt.gcf()
                plt.close(spatial_fig)
                visualization_area.append(('Spatial Visualization', pn.pane.Matplotlib(spatial_fig, dpi=100)))
                output_area.object = "**Spatial visualization completed.**"
            except Exception as e:
                output_area.object = f"**Error during spatial visualization: {e}**"

        # File paths
        input_path = pn.widgets.TextInput(name='AnnData File Path', placeholder='Enter path to .h5ad file')
        output_dir = pn.widgets.TextInput(name='Output Directory', placeholder='Enter output directory path')
        load_data_button = pn.widgets.Button(name='Load Data', button_type='primary')

        # Clustering parameters
        clustering_method = pn.widgets.Select(name='Clustering Method', options=["leiden", "louvain", "flowSOM"])
        resolution = pn.widgets.FloatInput(name='Resolution', value=1.0)
        n_neighbors = pn.widgets.IntInput(name='Number of Neighbors', value=10)
        reclustering = pn.widgets.Checkbox(name='Reclustering', value=False)
        seed = pn.widgets.IntInput(name='Random Seed', value=42)
        key_added_input = pn.widgets.TextInput(name='Key Added', placeholder='Enter key to add to AnnData.obs', value='')
        marker_list_input = pn.widgets.MultiChoice(name='Marker List', options=[], width=950)

        # Subclustering parameters
        subcluster_key = pn.widgets.TextInput(name='Subcluster Key', placeholder='Enter key to filter on (e.g., "leiden_1")')
        subcluster_values = pn.widgets.TextInput(name='Subcluster Values', placeholder='Enter clusters to subset (comma-separated)')
        subcluster_resolution = pn.widgets.FloatInput(name='Subcluster Resolution', value=0.3)
        subcluster_button = pn.widgets.Button(name='Run Subclustering', button_type='primary')

        # Cluster annotation
        cluster_annotation = pn.widgets.DataFrame(pd.DataFrame(columns=['Cluster', 'Annotation']), name='Cluster Annotations', autosize_mode='fit_columns')
        save_annotations_button = pn.widgets.Button(name='Save Annotations', button_type='success')

        fs_xdim = pn.widgets.IntInput(name='FlowSOM xdim', value=10)
        fs_ydim = pn.widgets.IntInput(name='FlowSOM ydim', value=10)
        fs_rlen = pn.widgets.IntInput(name='FlowSOM rlen', value=10)

        # Buttons
        run_clustering_button = pn.widgets.Button(name='Run Clustering', button_type='primary')
        save_adata_button = pn.widgets.Button(name='Save AnnData', button_type='success')

        # Loading indicators
        loading_indicator = pn.widgets.Progress(name='Clustering Progress', active=False, bar_color='primary')
        loading_indicator_subcluster = pn.widgets.Progress(name='Subclustering Progress', active=False, bar_color='primary')

        # Output areas
        output_area = pn.pane.Markdown()
        visualization_area = pn.Tabs()  # Changed to pn.Tabs to hold multiple plots

        # Global variable to hold the AnnData object
        adata_container = {}
        
        # Spatial visualization parameters
        spatial_color = pn.widgets.TextInput(name='Color By Column', placeholder='Enter group column name (e.g., cell_type_coarse)')
        spatial_unique_region = pn.widgets.TextInput(name='Unique Region Column', value='unique_region')
        spatial_x = pn.widgets.TextInput(name='X Coordinate Column', value='x')
        spatial_y = pn.widgets.TextInput(name='Y Coordinate Column', value='y')
        spatial_n_columns = pn.widgets.IntInput(name='Number of Columns', value=2)
        spatial_palette = pn.widgets.TextInput(name='Color Palette', value='tab20')
        spatial_figsize = pn.widgets.FloatInput(name='Figure Size', value=17)
        spatial_size = pn.widgets.FloatInput(name='Point Size', value=20)
        spatial_savefig = pn.widgets.Checkbox(name='Save Figure', value=False)
        spatial_output_fname = pn.widgets.TextInput(name='Output Filename', placeholder='Enter output filename')
        run_spatial_visualization_button = pn.widgets.Button(name='Run Spatial Visualization', button_type='primary')

        # Link callbacks
        load_data_button.on_click(load_data)
        run_clustering_button.on_click(run_clustering)
        subcluster_button.on_click(run_subclustering)
        save_annotations_button.on_click(save_annotations)
        save_adata_button.on_click(save_adata)
        run_spatial_visualization_button.on_click(run_spatial_visualization)

        # Clustering Tab Layout
        clustering_tab = pn.Column(
            pn.pane.Markdown("### Load Data"),
            pn.Row(input_path, output_dir, load_data_button),
            pn.layout.Divider(),
            pn.pane.Markdown("### Clustering Parameters"),
            pn.Row(clustering_method, resolution, n_neighbors),
            pn.Row(seed, reclustering),
            pn.Row(fs_xdim, fs_ydim, fs_rlen),
            key_added_input,
            marker_list_input,
            pn.layout.Divider(),
            pn.Row(run_clustering_button, loading_indicator),
            output_area
        )

        # Subclustering Tab Layout
        subclustering_tab = pn.Column(
            pn.pane.Markdown("### Subclustering Parameters"),
            pn.Row(subcluster_key, subcluster_values, subcluster_resolution),
            pn.layout.Divider(),
            pn.Row(subcluster_button, loading_indicator_subcluster),
            output_area
        )

        # Annotation Tab Layout
        annotation_tab = pn.Column(
            pn.pane.Markdown("### Cluster Annotation"),
            cluster_annotation,
            pn.layout.Divider(),
            save_annotations_button,
            output_area
        )

        # Save Tab Layout
        save_tab = pn.Column(
            pn.pane.Markdown("### Save Data"),
            save_adata_button,
            output_area
        )

        # Spatial Visualization Tab Layout
        spatial_visualization_tab = pn.Column(
            pn.pane.Markdown("### Spatial Visualization Parameters"),
            pn.Row(spatial_color, spatial_palette),
            pn.Row(spatial_unique_region, spatial_n_columns),
            pn.Row(spatial_x, spatial_y),
            pn.Row(spatial_figsize, spatial_size),
            pn.layout.Divider(),
            pn.Row(spatial_savefig, spatial_output_fname),
            pn.layout.Divider(),
            pn.Row(run_spatial_visualization_button),
            output_area
        )

        # Assemble Tabs
        tabs = pn.Tabs(
            ("Clustering", clustering_tab),
            ("Subclustering", subclustering_tab),
            ("Annotation", annotation_tab),
            ("Spatial Visualization", spatial_visualization_tab),
            ("Save", save_tab)
        )

        # Main Layout with Visualization Area
        main_layout = pn.Row(
            tabs,
            visualization_area,
            sizing_mode='stretch_both'
        )

        return main_layout

    # Run the app
    main_layout = create_clustering_app()

    main_layout.servable(title='SPACEc Clustering App')
    
    return main_layout

In [3]:
# Build the app; as the cell's last expression it is rendered inline.
clustering_app = launch_interactive_clustering()
clustering_app

BokehModel(combine_events=True, render_bundle={'docs_json': {'3e472175-93e1-4b23-ba37-d2289e829f0e': {'version…