# Visualization of copy number distribution for human and mouse proteins

__Author:__ \
Emanuel Lange \
Mehrdimensionale OMICS-Datenanalyse \
ISAS e.V. \
Bunsen-Kirchhoff-Straße 11 \
44139 Dortmund, Germany \
emanuel.lange@isas.de

__Last Revision:__ \
March 7, 2024

__License:__
MIT

__Objective:__ \
This script generates interactive visualizations of protein copy number distributions.

__How it works:__ \
For interactive visualization we use the Bokeh library (http://bokeh.org/). \
Visualizations and data are stored as HTML documents and can be viewed in any modern web browser.

__How to execute:__ \
Jupyter notebooks consist of cells (text and code) that can be executed with CTRL + ENTER (runs the selected cell) or SHIFT + ENTER (runs the selected cell and moves to the next one).

We recommend using [ANACONDA navigator](https://docs.anaconda.com/free/navigator/index.html) for setup. Before running this script, make sure you installed and activated the provided CONDA environment (protein_profiles_env.yaml). \
You can also run this notebook in [Google Colab](https://colab.google/) without installing anything locally. Check out this [Guide](https://saturncloud.io/blog/how-can-i-run-notebooks-of-a-github-project-in-google-colab/) on how to open a jupyter notebook from GitHub.

The protein data spreadsheet (copy_number_distribution_human_mouse_105.xlsx) must be located in the same directory as this notebook.

## 1) Importing libraries

In [2]:
import pandas as pd #data handling
from bokeh.layouts import column, row
from bokeh.plotting import figure, show # visualization
from bokeh.models import (ColumnDataSource, BoxAnnotation, DataTable, StringFormatter, NumberFormatter, TableColumn, Legend, LegendItem, CategoricalColorMapper, Circle, HoverTool, TextInput, CDSView, CustomJS, IndexFilter, Div) # utilities for visualization
from bokeh.palettes import BrBG4 # color palette for annotation boxes
from bokeh.io import output_notebook, output_file, export_svg  # output visualization in notebook

## 2) Importing data

In [3]:
def read_data(path, sheet_names):
    """
    Reads data from an Excel file and merges the data from different sheets into a single DataFrame.

    Parameters:
    path (str): The path to the Excel file.
    sheet_names (list): The sheet names to read from the Excel file. If a name is
        given more than once, only one copy of that sheet is kept (the last read).

    Returns:
    DataFrame: The rows of all specified sheets stacked in the order given, with a
    fresh 0..n-1 index. Each row includes the name of the sheet it came from in a
    'category' column.

    Raises:
    ValueError: If ``sheet_names`` is empty.
    """
    sheet_names = list(sheet_names)
    if not sheet_names:
        raise ValueError("sheet_names must contain at least one sheet name")

    # A dict keyed by sheet name, so a repeated name does not duplicate rows.
    frames = {}
    for sheet_name in sheet_names:
        df = pd.read_excel(path, sheet_name=sheet_name)
        # Tag every row with its origin; the plot colours points by this column.
        df['category'] = sheet_name
        frames[sheet_name] = df

    # ignore_index=True renumbers the rows 0..n-1. A `keys=` argument would be
    # ignored in that case, so none is passed.
    return pd.concat(list(frames.values()), ignore_index=True)

## 3) Functions to generate dashboard

In [6]:
def make_data_source(data):
    """
    Wraps the columns of a DataFrame in a ColumnDataSource.

    Parameters:
    data (DataFrame): The input DataFrame; it must provide the columns
        'Rank' and 'Copy number'.

    Returns:
    ColumnDataSource: A source with one entry per DataFrame column, keyed by
    the column name.

    Raises:
    ValueError: If 'Rank' or 'Copy number' is missing from ``data``.
    """
    required_columns = ('Rank', 'Copy number')
    if any(name not in data.columns for name in required_columns):
        raise ValueError("Data must contain columns 'Rank' and 'Copy number'")

    # Build an explicit column-name -> column mapping rather than passing the frame itself.
    columns_by_name = {name: data[name] for name in data.columns}
    return ColumnDataSource(columns_by_name)

def make_data_table(data_source, view):
    """
    Builds a read-only DataTable showing every column of the data source.

    Columns whose dtype is float64 or float32 use NumberFormatter(format='0'),
    right-aligned. All other columns use a bold StringFormatter.

    Parameters:
    data_source (ColumnDataSource): The data source for the DataTable; each
        column value is expected to expose a ``dtype`` (e.g. a pandas Series).
    view (CDSView): The view that filters which rows the table shows.

    Returns:
    DataTable: A non-editable DataTable that stretches to fill its container.
    """
    float_dtypes = ('float64', 'float32')
    table_columns = []

    for name in data_source.column_names:
        if data_source.data[name].dtype in float_dtypes:
            cell_formatter = NumberFormatter(format='0', text_align='right')
        else:
            cell_formatter = StringFormatter(font_style="bold")
        table_columns.append(TableColumn(field=name, title=name, formatter=cell_formatter))

    return DataTable(
        source=data_source,
        view=view,
        columns=table_columns,
        editable=False,
        scroll_to_selection=True,
        margin=(20, 20, 20, 20),
        sizing_mode="stretch_both",
    )

def initialize_figure(title, title_size):
    """
    Creates the empty figure for the copy number plot.

    The y axis is logarithmic. Box-select along y is the active drag tool. The
    WebGL output backend is used, and the title is centred.

    Parameters:
    title (str): The title of the plot.
    title_size (int): The title font size in points.

    Returns:
    Figure: The configured, still empty, figure.
    """
    figure_options = dict(
        title=title,
        x_axis_label="protein rank",
        y_axis_label="Log10(protein copy number)",
        y_axis_type="log",
        tools="pan,wheel_zoom,ybox_select,tap,reset,save",
        active_drag="ybox_select",
        margin=(20, 20, 20, 20),
        output_backend="webgl",
        sizing_mode="stretch_both",
    )
    fig = figure(**figure_options)

    fig.title.align = 'center'
    fig.title.text_font_size = f"{title_size}pt"
    return fig

def set_axes_and_grid(p, data_source, axis_label_size, tick_label_size):
    """
    Limits both axes to the data range, hides the grid lines and sets font sizes.

    Parameters:
    p (Figure): The plot figure; modified in place.
    data_source (ColumnDataSource): The data source. Its 'Rank' and
        'Copy number' columns must support ``.min()``/``.max()`` (pandas
        Series, as built by make_data_source).
    axis_label_size (int): The axis label font size in points.
    tick_label_size (int): The tick label font size in points.
    """
    ranks = data_source.data['Rank']
    copy_numbers = data_source.data['Copy number']
    p.xaxis.bounds = (ranks.min(), ranks.max())
    p.yaxis.bounds = (copy_numbers.min(), copy_numbers.max())

    for grid in (p.xgrid, p.ygrid):
        grid.grid_line_color = None

    label_font = f"{axis_label_size}pt"
    tick_font = f"{tick_label_size}pt"
    for axis in (p.xaxis, p.yaxis):
        axis.axis_label_text_font_size = label_font
        axis.major_label_text_font_size = tick_font

def add_annotation_boxes(p, legend_font_size, categories, palette):
    """
    Shades the high and low copy number bands and adds the plot legend.

    The area above y = 500000 is shaded as "high abundance" and the area below
    y = 10000 as "low abundance". The legend entries point at glyph renderers
    created without any data; they serve only as colour swatches.

    Parameters:
    p (Figure): The plot figure; modified in place.
    legend_font_size (int): The legend label font size in points.
    categories (list): The categories in the data; their order selects the palette colour.
    palette (list): The colours, indexed like ``categories``.
    """
    high_color = BrBG4[3]
    low_color = BrBG4[0]

    # Shaded bands drawn underneath the data points.
    for band in (
        BoxAnnotation(bottom=500000, fill_alpha=0.3, fill_color=high_color, level='underlay'),
        BoxAnnotation(top=10000, fill_alpha=0.1, fill_color=low_color, level='underlay'),
    ):
        p.add_layout(band)

    entries = [
        LegendItem(label='high abundance', renderers=[p.square(fill_alpha=0.2, color=high_color)]),
        LegendItem(label='low abundance', renderers=[p.square(fill_alpha=0.2, color=low_color)]),
    ]
    entries.extend(
        LegendItem(
            label=category,
            renderers=[p.circle(fill_alpha=0.4, color=palette[categories.index(category)])],
        )
        for category in categories
    )

    p.add_layout(Legend(items=entries))
    p.legend.label_text_font_size = f"{legend_font_size}pt"

def add_data_points(p, source, categories, palette):
    """
    Draws one circle per row at (Rank, Copy number), coloured by 'category'.

    Selected circles become fully opaque in their category colour. Non-selected
    circles turn semi-transparent grey.

    Parameters:
    p (Figure): The plot figure; modified in place.
    source (ColumnDataSource): The data source providing 'Rank', 'Copy number' and 'category'.
    categories (list): The category factors for the colour mapper.
    palette (list): The colours mapped onto ``categories``.
    """
    category_mapper = CategoricalColorMapper(palette=palette, factors=categories)

    points = p.circle(
        x='Rank',
        y='Copy number',
        source=source,
        size=10,
        alpha=0.4,
        hover_alpha=1,
        color={'field': 'category', 'transform': category_mapper},
        level='glyph',
    )

    points.selection_glyph = Circle(
        fill_alpha=1,
        fill_color={'field': 'category', 'transform': category_mapper},
        line_color=None,
    )
    points.nonselection_glyph = Circle(fill_alpha=0.5, fill_color="grey", line_color=None)

def add_hover(p, columns):
    """
    Attaches a HoverTool whose HTML tooltip lists every column as "<b>name</b>: value".

    The embedded CSS hides every tooltip entry after the first, presumably so that
    only one point is described when several overlap.
    NOTE(review): the opening tag contains "@tooltip{custom}", which refers to a
    field not obviously present in the data — confirm whether it is still needed.

    Parameters:
    p (Figure): The plot figure; modified in place.
    columns (list): The column names to show; each is referenced as ``@{name}``.
    """
    field_rows = "".join(f"    <b>{name}</b>: @{{{name}}} <br>" for name in columns)

    closing_html = """
    </div>
    <style>
        div.bk-tooltip-content > div > div:not(:first-child) {
            display:none !important;
            }
    </style>"""

    tooltip_html = "<div @tooltip{custom}>" + field_rows + closing_html

    p.add_tools(HoverTool(tooltips=tooltip_html))

def make_plot(data_source, categories, title, title_size, axis_label_size, tick_label_size, legend_font_size, palette):
    """
    Assembles the complete copy number figure.

    Parameters:
    data_source (ColumnDataSource): The data source for the plot.
    categories (list): The categories in the data.
    title (str): The title of the plot.
    title_size (int): The title font size in points.
    axis_label_size (int): The axis label font size in points.
    tick_label_size (int): The tick label font size in points.
    legend_font_size (int): The legend font size in points.
    palette (list): The colours for the categories.

    Returns:
    Figure: The configured figure with data points, hover tool, bands and legend.
    """
    fig = initialize_figure(title, title_size)
    set_axes_and_grid(fig, data_source, axis_label_size, tick_label_size)
    # The data-point renderer is created before the legend-only swatch
    # renderers that add_annotation_boxes creates.
    add_data_points(fig, data_source, categories, palette)
    add_hover(fig, data_source.column_names)
    add_annotation_boxes(fig, legend_font_size, categories, palette)
    return fig

## create documentation
def make_docu():
    """
    Builds the green-bordered notice linking to the project's GitHub repository.

    Returns:
    Div: A zero-margin Div containing the notice HTML.
    """
    notice_html = """
            <div style="padding: 0.5rem; border-style: solid; border-radius: 0.5rem; border-width: 2px; border-color: #28a745;">
                Visit our <a href="https://github.com/voidsailor/protein_abundance_visualization"> GitHub repository </a> for a manual and the source code of this dashboard.
            </div>    
            """
    return Div(text=notice_html, margin=(0, 0, 0, 0))

## compose visualization

# Browser-side callback for the search box. It is a raw string so that the
# JavaScript regex escapes (\s, \/) reach the browser unchanged and Python does
# not warn about invalid escape sequences. Expects the CustomJS args
# `search_input`, `source`, `filter` and `columns`.
SEARCH_FILTER_JS = r"""
    // Remove whitespace and the characters . ; , / _ # - (every occurrence),
    // then lower-case, so that e.g. "P12345-2" and "p12345 2" match each other.
    const STRIP = /[\s.;,\/_#-]/g
    const normalize = (value) => String(value).replace(STRIP, '').toLowerCase()

    function filterAll(data, columns, val) {
        const needle = normalize(val)
        const n = columns.length > 0 ? data[columns[0]].length : 0
        const indices = []
        // Single pass over the rows: indices come out ascending and unique.
        for (let i = 0; i < n; i++) {
            if (needle.length === 0 || columns.some((col) => normalize(data[col][i]).includes(needle))) {
                indices.push(i)
            }
        }
        return indices
    }

    filter.indices = filterAll(source.data, columns, search_input.value)
    source.change.emit()
"""


def make_widgets_and_docu(data, source, filter):
    """
    Creates the search box, the protein counters and the documentation notice.

    The search box rewrites ``filter.indices`` in the browser. A row matches when
    the normalised search text is contained in the normalised value of any column.
    An empty search selects every row. The three counters show the total number of
    rows, the rows passing the filter, and the currently selected rows.

    Parameters:
    data (DataFrame): The input DataFrame; its column names are the fields searched.
    source (ColumnDataSource): The data source shared by the plot and the table.
    filter (IndexFilter): The filter whose indices the search box rewrites.

    Returns:
    Column: A Column layout containing the documentation notice, the search box
    and a row of counters.
    """
    columns = list(data.columns)
    total = len(data)

    search_input = TextInput(value="", title="Search", placeholder="Type and hit Enter to search", width=300, margin=(20, 0, 10, 0))
    search_input.js_on_change(
        'value',
        CustomJS(
            args=dict(search_input=search_input, source=source, filter=filter, columns=columns),
            code=SEARCH_FILTER_JS,
        ),
    )

    # All three counters use the same "<b>label:</b> value" markup that the JS updates write.
    total_proteins_div = Div(text=f"<b>Total number of proteins:</b> {total}")

    selected_proteins_div = Div(text="<b>Selected proteins:</b> 0")
    source.selected.js_on_change(
        'indices',
        CustomJS(
            args=dict(div_selected=selected_proteins_div),
            # cb_obj is the Selection whose `indices` changed.
            code="""
                div_selected.text = "<b>Selected proteins:</b> " + cb_obj.indices.length
            """,
        ),
    )

    filtered_proteins_div = Div(text=f"<b>Number of filtered proteins:</b> {total}")
    # The search callback emits source.change after updating the filter.
    source.js_on_change(
        "change",
        CustomJS(
            args=dict(filter=filter, filtered_proteins_div=filtered_proteins_div),
            code="""
                filtered_proteins_div.text = "<b>Number of filtered proteins:</b> " + filter.indices.length
            """,
        ),
    )

    counters = row(total_proteins_div, filtered_proteins_div, selected_proteins_div, margin=(0, 0, 10, 0))
    return column(make_docu(), search_input, counters, sizing_mode='stretch_width', margin=(20, 20, 0, 20))

def get_catgories(data):
    """
    Returns the unique values of the 'category' column in order of first appearance.

    Missing values (NaN/None) are not returned, because pd.factorize excludes them
    from its uniques.

    Parameters:
    data (DataFrame): The input DataFrame; it must have a 'category' column.

    Returns:
    list: The unique categories in the input DataFrame.
    """
    _, uniques = pd.factorize(data['category'])
    return list(uniques)


# Correctly spelled alias; the misspelled name above is kept for existing callers.
get_categories = get_catgories

def make_interactive_plot(
        data,
        palette,
        title,
        output_file_name=None,
        title_size=20,
        axis_label_size=14,
        tick_label_size=14,
        legend_font_size=14):
    """
    Builds and shows the dashboard: a searchable table on the left and the copy
    number scatter plot on the right, both backed by one ColumnDataSource.

    The output is always shown in the notebook. If ``output_file_name`` is given,
    the output is also written to "<output_file_name>.html" with inline resources.

    Parameters:
    data (DataFrame): The input DataFrame; needs 'Rank', 'Copy number' and 'category'.
    palette (list): One colour per category, in order of first appearance in ``data``.
    title (str): The plot title, also used as the HTML document title.
    output_file_name (str): The output file name without the ".html" extension, or None.
    title_size (int): The title font size in points.
    axis_label_size (int): The axis label font size in points.
    tick_label_size (int): The tick label font size in points.
    legend_font_size (int): The legend font size in points.

    Returns:
    None
    """
    output_notebook()
    if output_file_name:
        output_file(filename=output_file_name + ".html", title=title, mode="inline")

    categories = get_catgories(data)
    shared_source = make_data_source(data)

    # Every row starts visible; the search box later overwrites these indices.
    index_filter = IndexFilter(list(range(len(data))))
    table = make_data_table(shared_source, CDSView(filter=index_filter))

    plot = make_plot(
        shared_source, categories, title, title_size,
        axis_label_size, tick_label_size, legend_font_size, palette,
    )
    widgets = make_widgets_and_docu(data, shared_source, index_filter)

    left_panel = column(widgets, table, sizing_mode="stretch_both")
    right_panel = column(plot, sizing_mode="stretch_both")
    show(row(left_panel, right_panel, sizing_mode="stretch_both"))


## 4) Read data and create the Visualization

In [7]:
# Merge the 'Human' and 'Mouse' sheets; each row is tagged with its sheet in 'category'.
data = read_data('copy_number_distribution_human_mouse_105.xlsx', ['Human', 'Mouse'])

# One colour per category, in the order the categories first appear in `data`.
palette = ["tomato", "dodgerblue"]

title = "Human and Mouse_copy number per cell (10^5)"

# Shows the dashboard in the notebook and also writes human_mouse_105_copy_number_plot.html.
make_interactive_plot(
    data=data,
    palette=palette,
    title=title,
    output_file_name="human_mouse_105_copy_number_plot",
    title_size=16,
    axis_label_size=14,
    tick_label_size=14,
    legend_font_size=14,
)