In [1]:
import random
from pathlib import Path

import numpy as np
import pandas as pd
from bokeh import palettes
from bokeh.plotting import figure, show, ColumnDataSource, output_file
from bokeh.models.tools import HoverTool
from bokeh.transform import dodge
from bokeh.io import reset_output, output_notebook
from plotly import offline
import plotly.express as px

In [2]:
class PlotArgBin():
    """
        Store variables for processing of RNA seq file into DataFrame.

        Attributes:
            df_in_path:   Path to the input file from which the main DataFrame is built.
            out_dir_path: Directory where plots are written.
            sample:       Size of the random sample to take from the dataset (None = no sampling).
            terms:        List of search terms to look for in the dataset.
            search_locs:  List of column names to search within.
            universal:    If True, all columns are searched.
    """

    # NOTE: the out_dir_path_ default is evaluated once, when the class is defined,
    # so it reflects the working directory at that moment.
    def __init__(
                    self, df_in_path_, out_dir_path_=Path(Path.cwd().parent / "Plots"),
                    sample_=None, terms_=None, search_locs_="Gene", universal_=False,
                ):

        self.df_in_path = df_in_path_ #  Path to input file from which the main DataFrame will be constructed
        self.out_dir_path = out_dir_path_ #  Path to directory where plots should be stored
        self.sample = sample_ #  (int) size of random sample to be taken from dataset
        self.terms = self._as_list(terms_) #  (list of str) List of terms to search for in dataset
        self.search_locs = self._as_list(search_locs_) #  (list of str) List of columns to search for search terms within
        self.universal = universal_ #  (bool) If True, all columns will be searched

    @staticmethod
    def _as_list(value):
        """
            Normalise an optional str-or-iterable argument to a new list.

            A bare string is wrapped rather than iterated, so that "Gene" becomes
            ["Gene"] instead of being treated as the characters "G", "e", "n", "e".
        """
        if value is None:
            return []
        if isinstance(value, str):
            return [value]
        return list(value)

In [3]:
# Plot dimensions in pixels as (width, height), keyed by screen resolution and
# by where the plot is displayed. Pick a preset by editing `res` and
# `display_target` below instead of commenting blocks of numbers in and out.
_PLOT_SIZE_PRESETS = {
    ("1080p", "browser"): (1600, 800),
    ("1080p", "notebook"): (1200, 600),
    ("1440p", "browser"): (2200, 1100),
    ("1440p", "notebook"): (1600, 800),
}

res = "1440p"                 # "1080p" or "1440p"; also read by plot_rna_3d_scat
display_target = "notebook"   # "notebook" or "browser"

# A KeyError here means the (res, display_target) pair has no preset defined.
plot_width, plot_height = _PLOT_SIZE_PRESETS[(res, display_target)]

In [4]:
def style(plot, res="1440p"):
    """
        Apply the shared title, axis and legend font styling to a plot.

        Only the resolutions "1080p" and "1440p" are recognised; for any other
        value the plot is left untouched. The plot that was passed in is
        modified in place and returned.
    """
    # (title size, axis label / tick label size, legend label size)
    if res == "1080p":
        title_size, axis_size, legend_size = '18pt', '12pt', '20pt'
    elif res == "1440p":
        title_size, axis_size, legend_size = '24pt', '20pt', '20pt'
    else:
        return plot

    plot.title.align = 'center'
    plot.title.text_font_size = title_size

    # Axis titles and tick labels share one size on both axes
    for axis in (plot.xaxis, plot.yaxis):
        axis.axis_label_text_font_size = axis_size
        axis.major_label_text_font_size = axis_size

    plot.legend.label_text_font_size = legend_size

    return plot

## RNA seq plotting

In [5]:
def plot_rna_bar(df, args, out_file_name, save=False):
    """
        Generates bar plot showing mean FPKM values across the three conditions.

        Genes are laid out along the x-axis in descending order of YPD mean FPKM,
        with one bar per condition drawn side by side for each gene.

        Args:
            df:            DataFrame with "gene", "accession", "YPD_rna_mean",
                           "NS_rna_mean" and "AS_rna_mean" columns.
            args:          Object whose out_dir_path is used when save is True.
            out_file_name: File name of the HTML output inside args.out_dir_path.
            save:          If True, the plot is also directed to an HTML file.

        Reads the module-level plot_width, plot_height and res settings.
    """

    reset_output()
    if save:
        output_file(args.out_dir_path / out_file_name)

    cols = ["YPD_rna_mean", "NS_rna_mean", "AS_rna_mean"]
    legend_labels = ["YPD Mean FPKM", "NS Mean FPKM", "AS Mean FPKM"]
    pal = palettes.Plasma[len(cols)]

    # Bar positions on a categorical axis follow the order of the x_range
    # factors, not the row order of the data source, so the sorted frame must
    # feed both of them for the sort to be visible.
    df_sorted = df.sort_values(by=[cols[0]], ascending=False)
    source = ColumnDataSource(df_sorted)
    x_range_ = list(df_sorted["gene"].astype(str))

    plot = figure(
                    x_range=x_range_, plot_width=plot_width, plot_height=plot_height,
                    x_axis_label="Gene Name",
                    y_axis_label="Mean FPKM",
                  )

    hover = HoverTool()
    # Hover tips are hardcoded and thus not receptive to different col inputs
    hover.tooltips=[
                        ("Name", "@gene"), ("Accession", "@accession"), ("YPD Mean FPKM", "@YPD_rna_mean"),
                        ("AS Mean FPKM", "@AS_rna_mean"), ("NS Mean FPKM", "@NS_rna_mean")
                   ]

    plot.add_tools(hover)

    # One 0.2-wide bar per condition, each shifted a further 0.2 to the right
    for i, col in enumerate(cols):
        plot.vbar(
                    x=dodge("gene", (i * 0.2), range=plot.x_range), top=col, width=.2,
                    color=pal[i], source=source, legend_label=legend_labels[i]
                 )

    # Use the notebook-wide resolution so font sizes follow the chosen preset
    plot = style(plot, res)
    output_notebook()
    show(plot)

In [6]:
def plot_rna_scat(df, args, out_file_name, save=False):
    """
        Generates 2D scatter plot showing mean FPKM values across the three conditions.

        Each row of df is drawn three times at x = its index value: once per
        condition (YPD in black, NS in blue, AS in red).

        Args:
            df:            DataFrame with "gene", "accession", "YPD_rna_mean",
                           "NS_rna_mean" and "AS_rna_mean" columns.
            args:          Object whose out_dir_path is used when save is True.
            out_file_name: File name of the HTML output inside args.out_dir_path.
            save:          If True, the plot is also directed to an HTML file.

        Reads the module-level plot_width, plot_height and res settings.
    """

    reset_output()
    if save:
        output_file(args.out_dir_path / out_file_name)

    plot = figure(
                    plot_width=plot_width, plot_height=plot_height,
                    x_axis_label="Gene Index",
                    y_axis_label="Mean FPKM",
                 )

    source = ColumnDataSource(df)
    hover = HoverTool()
    hover.tooltips= [
                        ("Name", "@gene"), ("Accession", "@accession"), ("YPD Mean FPKM", "@YPD_rna_mean"),
                        ("AS Mean FPKM", "@AS_rna_mean"), ("NS Mean FPKM", "@NS_rna_mean")
                    ]

    plot.add_tools(hover)

    # (column, marker colour, legend entry) for each condition
    series = [
                ("YPD_rna_mean", "black", "YPD Mean FPKM"),
                ("NS_rna_mean", "blue", "NS Mean FPKM"),
                ("AS_rna_mean", "red", "AS Mean FPKM"),
             ]
    for col, color, label in series:
        plot.circle(x="index", y=col, source=source, color=color, size=20, legend_label=label)

    # Use the notebook-wide resolution so font sizes follow the chosen preset
    plot = style(plot, res)
    output_notebook()
    show(plot)

In [7]:
def plot_rna_3d_scat(df, args, out_file_name):
    """
        Generates 3D scatter plot showing mean FPKM values in log10 scale across the three conditions.

        Each axis shows log10(mean FPKM + 1) for one condition. Points are
        coloured by the absolute value of "AN_rna_L2FC" and their symbol is
        chosen by "search_hit". The figure is written to
        args.out_dir_path / out_file_name via plotly's offline plotting.

        Args:
            df:            DataFrame with "gene", "accession", "search_hit",
                           "AN_rna_L2FC", "YPD_rna_mean", "NS_rna_mean" and
                           "AS_rna_mean" columns. It is not modified.
            args:          Object providing out_dir_path.
            out_file_name: File name of the HTML output.

        Reads the module-level res setting to choose the figure size.
    """

    width_ = 1800 if res == "1080p" else 2400
    height_ = 900 if res == "1080p" else 1250

    # Work on a copy so the derived columns never leak into the caller's frame
    df_clone = df.copy()
    df_clone["AS_NS_L2FC"] = df["AN_rna_L2FC"].abs()

    # The +1 shift keeps zero FPKM values finite (log10(1) == 0)
    log_cols = {"YPD_log10": "YPD_rna_mean", "NS_log10": "NS_rna_mean", "AS_log10": "AS_rna_mean"}
    for log_col, mean_col in log_cols.items():
        df_clone[log_col] = np.log10(df[mean_col].astype(float) + 1)

    fig = px.scatter_3d(
                            df_clone, x="YPD_log10", y="NS_log10", z="AS_log10", color="AS_NS_L2FC", hover_data=["gene", "accession"],
                            symbol="search_hit", width=width_, height=height_
                       )

    # All three axes share one fixed range and tick count inside a cubic scene
    scene = {axis + "axis": {"nticks": 9, "range": [0, 4.5]} for axis in "xyz"}
    scene["aspectmode"] = "cube"
    fig.update_layout(scene=scene, font={"size": 18}, legend_x=0)

    out_path = str(args.out_dir_path / out_file_name)
    offline.plot(fig, filename=out_path)

## RNA seq-SILAC plotting

In [8]:
def plot_prot_scat(df, args, out_file_name, save):
    """
        Generates 2D scatter plot of RNA versus protein AS to NS log 2 fold change.

        Every row is drawn in black ("Bulk Data"); rows whose "search_hit" is True
        are drawn again in red on top ("Search Hit"). Clicking a legend entry
        hides that series. A short summary of the search is printed as well.

        Args:
            df:            DataFrame with "gene", "accession", "search_hit",
                           "AN_rna_L2FC" and "SD-AA/SD-N_6h_prot" columns.
            args:          Object providing terms and search_locs for the summary,
                           and out_dir_path when save is True.
            out_file_name: File name of the HTML output inside args.out_dir_path.
            save:          If True, the plot is also directed to an HTML file.

        Reads the module-level plot_width, plot_height and res settings.
    """

    reset_output()
    if save:
        output_file(args.out_dir_path / out_file_name)

    # Initialize plot
    plot = figure(
                plot_width=plot_width, plot_height=plot_height,
                x_axis_label="RNA AS to NS Log 2 Fold Change",
                y_axis_label="Protein AS to NS Log 2 Fold Change"
              )

    source = ColumnDataSource(df)
    hover = HoverTool(tooltips=[("Name", "@gene"), ("Accession", "@accession")])
    plot.add_tools(hover)

    # Plot entire DataFrame
    plot.circle(x="AN_rna_L2FC", y="SD-AA/SD-N_6h_prot", source=source, color="black", legend_label="Bulk Data", size=20)

    # Plot search hits (subset computed once, reused for the stats below)
    hit_df = df[df["search_hit"].eq(True)]
    hit_source = ColumnDataSource(hit_df)
    plot.circle(x="AN_rna_L2FC", y="SD-AA/SD-N_6h_prot", source=hit_source, color="red", legend_label="Search Hit", size=20)

    # Output stats
    print(f"Initial DataFrame shape: \t{df.shape}")
    print(f"Filtered DataFrame shape: \t{hit_df.shape}")
    print(f"Search terms: \t\t\t{args.terms}")
    print(f"Search locations: \t\t{args.search_locs}")

    # Output plot
    plot.legend.click_policy="hide"
    # Use the notebook-wide resolution so font sizes follow the chosen preset
    plot = style(plot, res)
    output_notebook()
    show(plot)

In [9]:
def take_sample(df, args):
    """
        Take random subset of input data if requested.

        When args.sample is truthy, "search_hit" is set to True for exactly
        args.sample randomly chosen rows and to False for every other row.
        Otherwise df is left untouched. df is modified in place.

        Raises:
            ValueError: if args.sample is negative or larger than the number of
                        rows in df (raised before df is modified).
    """

    if not args.sample:
        return

    # Draw distinct row positions in one call; this also rejects a sample size
    # larger than the frame instead of looping forever looking for new rows.
    positions = random.sample(range(df.shape[0]), args.sample)

    df.loc[:, "search_hit"] = False
    df.iloc[positions, df.columns.get_loc("search_hit")] = True

def search_terms(df, args):
    """
        Search for the presence of any of the search terms provided by the user.

        A row stays a search hit only if its "search_hit" value was already True
        and at least one searched column contains one of the terms. All other
        rows get "search_hit" set to False. df is modified in place and also
        returned.

        Matching details:
            * terms are treated as regular expressions, matched case-insensitively;
            * a match immediately preceded by "non-" does not count;
            * missing values never match;
            * with args.universal, every column except "search_hit" is searched,
              otherwise only the columns in args.search_locs.

        If args.terms is empty or None, df is returned unchanged.

        Raises:
            AssertionError: if a requested column is not present in df.
    """

    if not args.terms:
        return df

    # A lone string counts as a single term / column, not a sequence of characters
    raw_terms = [args.terms] if isinstance(args.terms, str) else list(args.terms)

    # Excludes hits to search terms prefixed with "non-". The negative
    # lookbehind is repeated per term, and each term is wrapped in a
    # non-capturing group so an alternation inside a term stays self-contained.
    pattern = "|".join(rf"(?<!non-)(?:{term})" for term in raw_terms)

    if args.universal:
        # The flag column itself is bookkeeping, not data to be searched
        locs = [col for col in df.columns if col != "search_hit"]
    elif isinstance(args.search_locs, str):
        locs = [args.search_locs]
    else:
        locs = list(args.search_locs)

    # Check if all provided locations are valid
    for col in locs:
        assert col in df.columns, ("Invalid column provided: " + str(col))

    # Hits will be a Series of Boolean values indicating one or more search hit.
    # It shares df's index so the element-wise operations below stay aligned.
    hits = pd.Series(False, index=df.index)
    for col in locs:
        column = df[col]
        # astype(str) would turn missing values into text such as "nan", which
        # could then match a term; mask those cells out explicitly.
        sub_hits = column.astype(str).str.contains(pattern, case=False) & column.notna()
        hits = hits | sub_hits

    # Maintain restriction imposed by previous call to take_sample
    filt = df["search_hit"].astype(bool) & hits
    df.loc[~filt, "search_hit"] = False

    return df

def apply_thresh(df, col, lower=None, upper=None, quantile=False):
    """
        Imposes a threshold, either absolute or by quantile for a given column which
        must be met for the row to be considered a search_hit.

        Rows that are not already search hits, or whose value in col falls
        outside [lower, upper] (bounds inclusive), get "search_hit" set to False.
        No other column is modified. df is changed in place.

        Args:
            df:       DataFrame with a "search_hit" column and the column col.
            col:      Name of the column to threshold; values are compared as floats.
            lower:    Lower bound, or None for no lower bound.
            upper:    Upper bound, or None for no upper bound.
            quantile: If True, lower and upper are quantiles (0 to 1) evaluated
                      over the rows that are currently search hits.

        Raises:
            AssertionError: if col is not a column of df.
    """

    assert col in df.columns, ("Invalid column provided: " + str(col))

    # Only look at rows passing previous restrictions
    search_filt = df["search_hit"].eq(True)

    # Cast once so both bounds are compared against the same numeric values
    values = df[col].astype(float)

    if quantile:
        candidates = values[search_filt]
        lower = candidates.quantile(lower) if lower is not None else float("-inf")
        upper = candidates.quantile(upper) if upper is not None else float("inf")
    else:
        lower = lower if lower is not None else float("-inf")
        upper = upper if upper is not None else float("inf")

    val_filt = (values >= lower) & (values <= upper)

    # Clear only the flag column; assigning to whole rows would wipe their data
    df.loc[~(val_filt & search_filt), "search_hit"] = False

def set_search_col(df, args):
    """
        Build the "search_hit" column of df in place.

        Every row starts as a hit; take_sample and then search_terms are applied
        in that order, each one able to narrow the set of hits further.
    """
    # Initialize all rows to true
    df.loc[:, "search_hit"] = True
    take_sample(df, args)
    search_terms(df, args)
    # Optional extra restriction on the RNA fold change; currently disabled:
    #apply_thresh(df, "AN_rna_L2FC", lower=0.2, upper=0.70, quantile=False)

## Initialize DataFrame, define search parameters and perform search

In [10]:
# Search configuration: flag rows whose "gene" column contains "atg",
# without random sampling and without searching every column.
plot_args = PlotArgBin(
    Path.cwd().parent / "DataFrames" / "comb_df.xlsx",
    sample_=None,
    terms_=["atg"],
    search_locs_=["gene"],
    universal_=False,
)

In [11]:
# Load the combined dataset from the configured Excel file
df = pd.read_excel(plot_args.df_in_path)
# Adds/overwrites the "search_hit" column on df in place
set_search_col(df, plot_args)
# Number of rows that are / are not search hits
df["search_hit"].value_counts()

False    4411
True       31
Name: search_hit, dtype: int64

## Generate plots

In [12]:
# Bar chart of mean FPKM per gene; save=False, so no HTML output file is configured
plot_rna_bar(df, plot_args, out_file_name="rna_bar.html", save=False)

In [None]:
# Scatter of mean FPKM against row index; save=False, so no HTML output file is configured
plot_rna_scat(df, plot_args, out_file_name="rna_scat.html", save=False)

In [229]:
# 3D log10 FPKM scatter; this one always writes its HTML file into plot_args.out_dir_path
plot_rna_3d_scat(df, plot_args, "rna_3d_scat.html")

In [230]:
# RNA vs protein fold-change scatter with search hits highlighted; also prints search stats
plot_prot_scat(df, plot_args, "rna_prot_scat.html", save=False)

Initial DataFrame shape: 	(4442, 31)
Filtered DataFrame shape: 	(31, 31)
Search terms: 			['atg']
Search locations: 		['gene', 'search_hit']
