In [1]:
from pathlib import Path

import ipywidgets as widgets
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from IPython.core.display import display

import rp2
from rp2 import data, hagai_2018, notebooks, paths, regression, ui

_ = notebooks.initialise_environment(
    "Mean_Variance_Fits",
    dependencies=["Burst_Model_Fitting"],
)

In [2]:
def map_between_columns(df, index_column, value_column, indices):
    """Translate values found in one column of ``df`` into another column's values.

    Parameters
    ----------
    df : pandas.DataFrame
        Table holding both ``index_column`` and ``value_column``.
    index_column : str
        Column whose values are looked up.
    value_column : str
        Column supplying the value returned for each match.
    indices : pandas.Series, pandas.Index or other list-like
        Keys to translate. Anything that is not already a Series or Index
        (list, tuple, array, ...) is wrapped in a Series first.

    Returns
    -------
    Array with one entry per element of ``indices``, in the same order.
    Keys with no (non-null) pairing in ``df`` come back as missing values.

    Notes
    -----
    NOTE(review): if a key is paired with more than one distinct value the
    lookup Series has a non-unique index and ``map`` is expected to raise;
    that behaviour is unchanged here - confirm whether callers rely on it.
    """
    if not isinstance(indices, (pd.Series, pd.Index)):
        indices = pd.Series(indices)

    pairs = df[[index_column, value_column]].drop_duplicates().dropna()
    pairs = pairs.loc[pairs[index_column].isin(indices)]

    # Select the value column explicitly instead of calling squeeze():
    # squeeze() turns a one-row frame into a bare scalar, which cannot be
    # used as a mapper and broke lookups that matched exactly one key.
    index_map = pairs.set_index(index_column)[value_column]
    return indices.map(index_map).values


def map_gene_ids_to_mouse(species, ids):
    """Return the ``mouse_gene`` entry paired with each of ``ids``.

    The lookup runs against the notebook-level ``all_orthologues`` table,
    using its ``<species>_gene`` column as the key column.
    """
    return map_between_columns(
        df=all_orthologues,
        index_column=f"{species}_gene",
        value_column="mouse_gene",
        indices=ids,
    )


def make_condition_info_df(species):
    """Assemble the per-condition export table for one species.

    Reads the notebook-level ``condition_info_df``, ``gene_info_df``,
    ``index_columns``, ``condition_columns`` and ``count_type``.

    The result has one row per row of ``condition_info_df`` for ``species``
    and contains:

    * the index columns (``gene`` plus the condition columns);
    * ``symbol`` looked up in ``gene_info_df``;
    * count statistics renamed to ``rna_min`` ... ``rna_skew``;
    * ``k_on``, ``k_off``, ``k_syn`` and ``k_keep`` (source column ``keep``)
      from ``data.load_txburst_results``;
    * ``bs`` / ``bf`` (source columns ``bs_point`` / ``bf_point``);
    * for non-mouse species, ``mouse_id`` and ``mouse_symbol``.
    """
    stats_columns = ["min", "max", "mean", "variance", "skew"]
    burst_columns = ["k_on", "k_off", "k_syn", "keep"]

    # rename() without inplace returns a fresh frame, so nothing is mutated
    # on an object that was obtained by slicing another frame.
    txburst_df = data.load_txburst_results(species, condition_columns, count_type)
    txburst_df = txburst_df[index_columns + burst_columns].rename(columns={"keep": "k_keep"})

    is_species = condition_info_df.species == species
    info_df = condition_info_df.loc[is_species, index_columns + stats_columns].rename(
        columns={col: f"rna_{col}" for col in stats_columns}
    )

    stats_df = info_df.merge(txburst_df, on=index_columns, how="left")
    stats_df = stats_df.merge(
        condition_info_df[index_columns + ["bs_point", "bf_point"]],
        on=index_columns,
        how="left",
    )
    stats_df = stats_df.rename(columns={"bs_point": "bs", "bf_point": "bf"})
    stats_df.insert(1, "symbol", gene_info_df.loc[stats_df.gene].symbol.values)

    if species != "mouse":
        stats_df["mouse_id"] = map_gene_ids_to_mouse(species, stats_df.gene)
        # The mapping yields a missing value for genes without a mouse
        # orthologue; look up symbols only for mapped rows so that such a
        # gene leaves an empty symbol instead of raising KeyError.
        has_mouse = stats_df["mouse_id"].notna()
        stats_df["mouse_symbol"] = None
        if has_mouse.any():
            mapped_ids = stats_df.loc[has_mouse, "mouse_id"]
            stats_df.loc[has_mouse, "mouse_symbol"] = gene_info_df.loc[mapped_ids].symbol.values

    return stats_df


def make_mv_trend_info_df(species):
    """Return the mean-variance trend fits for the genes analysed in ``species``.

    Rows of the notebook-level ``mv_lr_df`` are kept when their index is one
    of the species' orthologue IDs (``analysis_orthologues[species]``) or one
    of its ``additional_gene_ids``. A ``symbol`` column is inserted first;
    for non-mouse species ``mouse_id`` and ``mouse_symbol`` are appended.
    """
    species_gene_ids = analysis_orthologues[species].tolist() + additional_gene_ids[species]
    trend_info = mv_lr_df.loc[mv_lr_df.index.isin(species_gene_ids)].copy()
    trend_info.insert(0, "symbol", gene_info_df.loc[trend_info.index].symbol.values)

    if species != "mouse":
        trend_info["mouse_id"] = map_gene_ids_to_mouse(species, trend_info.index)
        # Genes without a mouse orthologue come back as missing values;
        # resolve symbols only for mapped rows rather than letting the
        # label lookup raise KeyError on a missing ID.
        has_mouse = trend_info["mouse_id"].notna()
        trend_info["mouse_symbol"] = None
        if has_mouse.any():
            mapped_ids = trend_info.loc[has_mouse, "mouse_id"]
            trend_info.loc[has_mouse, "mouse_symbol"] = gene_info_df.loc[mapped_ids].symbol.values

    return trend_info


def save_all_species_data(output_folder):
    """Write the per-species CSV exports into ``output_folder``.

    Parameters
    ----------
    output_folder : str or pathlib.Path
        Destination directory; it is passed to ``rp2.create_folder`` before
        any file is written. Strings are accepted and converted to ``Path``.

    For every species in ``species_to_compare`` this writes
    ``<species>_condition_info.csv`` (without the index) and
    ``<species>_trend_info.csv`` (with the index). For mouse it additionally
    reads ``bp_curve_fits.csv`` (located via ``paths.get_rp2_path``) and
    writes ``mouse_bs_curves_info.csv`` and ``mouse_bf_curves_info.csv``,
    each sorted by gene symbol.
    """
    # Accept plain strings as well as Path objects.
    output_folder = Path(output_folder)
    rp2.create_folder(output_folder)

    for species in species_to_compare:
        condition_path = output_folder / f"{species}_condition_info.csv"
        make_condition_info_df(species).to_csv(condition_path, index=False)

        trend_info_df = make_mv_trend_info_df(species)
        trend_info_df.to_csv(output_folder / f"{species}_trend_info.csv")

        if species != "mouse":
            continue

        # The curve-fit table uses its first two columns as a MultiIndex;
        # the outer level selects the fitted quantity ("bs_point"/"bf_point").
        bp_curve_fits_df = pd.read_csv(paths.get_rp2_path("bp_curve_fits.csv"), index_col=[0, 1])
        for bp in ["bs", "bf"]:
            curve_fits_df = (
                trend_info_df[["symbol"]]
                .join(bp_curve_fits_df.loc[f"{bp}_point"])
                .sort_values(by="symbol")
            )
            curve_fits_df.to_csv(output_folder / f"{species}_{bp}_curves_info.csv")

In [3]:
# Species whose data are loaded and compared; the per-species loops in this
# notebook iterate over this list.
species_to_compare = ["mouse", "pig", "rabbit", "rat"]

# When True, the orthologue table is restricted to the rows flagged
# `in_subset` by data.load_rp2_analysis_genes().
use_rp2_gene_subset = True # high-coverage gene set

# Folder that receives the exported CSV files; set to None to skip the export.
results_output_path = "Output/Mean_Variance_Fits"

# Mouse gene symbols to analyse in addition to the orthologue table.
additional_mouse_genes = ["Tnf", "Il1b"]
# When True, the orthologues of the additional mouse genes are also added
# for every species in species_to_compare.
include_additional_orthologues = True

In [4]:
# Load the LPS-responsive mouse genes and the full orthologue table, then
# build the one-to-one orthologue table restricted to those genes.
lps_responsive_mouse_genes = hagai_2018.load_lps_responsive_genes()
all_orthologues = rp2.load_mouse_orthologues()

analysis_orthologues = (
    rp2.load_one_to_one_mouse_orthologues()
    .loc[lps_responsive_mouse_genes]
    .reset_index()
)
# Drop the last five characters of every column name
# (presumably a "_gene" suffix, leaving bare species names - TODO confirm).
analysis_orthologues.columns = analysis_orthologues.columns.str[:-5]

print(f"{len(analysis_orthologues):,} one-to-one arthologues")

if use_rp2_gene_subset:
    # Keep only the rows whose mouse gene is flagged `in_subset`.
    rp2_analysis_genes = data.load_rp2_analysis_genes()
    analysis_orthologues = analysis_orthologues.loc[
        rp2_analysis_genes.loc[analysis_orthologues.mouse].in_subset.to_numpy()
    ]
    print(f"  Using subset of {len(analysis_orthologues):,} from RP2 project")

2,336 one-to-one arthologues
  Using subset of 97 from RP2 project


In [5]:
# One gene-symbol table covering every compared species, stacked row-wise.
gene_info_df = pd.concat(
    list(map(rp2.load_biomart_gene_symbols_df, species_to_compare))
)

In [6]:
# Extra gene IDs per species, on top of the orthologue table. Starts empty
# for every species and is filled from `additional_mouse_genes` below.
additional_gene_ids = {species: [] for species in species_to_compare}

if len(additional_mouse_genes) > 0:
    # Resolve the requested mouse symbols to gene IDs, considering only
    # gene_info_df rows whose ID starts with "ENSMUSG".
    additional_gene_ids["mouse"] = map_between_columns(
        gene_info_df.loc[gene_info_df.index.str.startswith("ENSMUSG")].reset_index(),
        "symbol",
        "id",
        additional_mouse_genes,
    ).tolist()
    additional_orthologues = all_orthologues.loc[all_orthologues.mouse_gene.isin(additional_gene_ids["mouse"])]

    if include_additional_orthologues:
        for species in species_to_compare:
            # Orthologues of the extra mouse genes that are not already in
            # the analysis table. sorted() rather than list(set(...)) keeps
            # the gene order identical from one kernel session to the next;
            # plain set iteration order varies with string hash randomisation.
            additional_gene_ids[species] = sorted(
                set(additional_orthologues[f"{species}_gene"].dropna().unique().tolist())
                .difference(analysis_orthologues[species])
            )
            if len(additional_gene_ids[species]) > 0:
                print(f"Including orthologues for {species}: ", ", ".join(sorted(gene_info_df.loc[additional_gene_ids[species]].symbol)))

Including orthologues for mouse:  Il1b, Tnf
Including orthologues for pig:  IL1B1, IL1B2, TNF
Including orthologues for rabbit:  IL1B, TNF
Including orthologues for rat:  Il1b, LOC103694380, Tnf


In [11]:
# Columns identifying one experimental condition; passed to the txburst loaders.
condition_columns = ["replicate", "treatment", "time_point"]
# Merge key used when joining per-condition tables: gene plus the condition columns.
index_columns = ["gene"] + condition_columns
# Passed as `scaling` to hagai_2018.load_counts and as the count type to the txburst loaders.
count_type = "median"

def create_condition_info(species, time_points=("0", "2", "4", "6")):
    """Build the per-gene, per-condition statistics table for one species.

    Parameters
    ----------
    species : str
        Species name; used to load the counts, to select the gene IDs from
        ``analysis_orthologues`` / ``additional_gene_ids`` and to fill the
        ``species`` column.
    time_points : iterable of str, optional
        Values of ``obs.time_point`` to keep. Defaults to the four time
        points that were previously hard-coded.

    Returns
    -------
    The output of ``hagai_2018.calculate_counts_condition_stats`` for the
    selected observations and genes, with a ``species`` column inserted at
    position 1 and left-merged on ``index_columns`` with the results of
    ``data.load_and_recalculate_txburst_results``.
    """
    counts_adata = hagai_2018.load_counts(species, scaling=count_type)
    species_gene_ids = analysis_orthologues[species].tolist() + additional_gene_ids[species]

    # Restrict to the requested time points (rows) and genes (columns).
    in_time_points = counts_adata.obs.time_point.isin(list(time_points))
    counts_adata = counts_adata[in_time_points, species_gene_ids].copy()

    info_df = hagai_2018.calculate_counts_condition_stats(counts_adata)
    info_df.insert(1, "species", species)

    txburst_df = data.load_and_recalculate_txburst_results(
        species,
        condition_columns=condition_columns,
        count_type=count_type,
    )
    return info_df.merge(txburst_df, on=index_columns, how="left")


# Stack the per-species condition tables into a single frame.
condition_info_df = pd.concat(
    list(map(create_condition_info, species_to_compare))
)

In [8]:
# Fit one trend of variance against mean per gene, across that gene's conditions.
mv_lr_df = (
    condition_info_df
    .groupby("gene")
    .apply(regression.fit_robust_linear_trend, x_var="mean", y_var="variance")
)



In [9]:
# Export the per-species CSV tables; setting results_output_path to None
# in the configuration cell skips writing any files.
if results_output_path is not None:
    save_all_species_data(Path(results_output_path))

In [10]:
# Output areas the interactive callback renders into: the mean-variance
# plot, the table of fitted trends, and a text area for the bursting tab.
mv_plot_output, mv_info_output, burst_info_output = (widgets.Output() for _ in range(3))
# Outputs for the disabled "Bursting" tab:
#bs_plot_output = widgets.Output()
#bf_plot_output = widgets.Output()

# Single-tab layout: plot on the left, fit table on the right.
tab_widget = widgets.Tab(
    children=[widgets.HBox([mv_plot_output, mv_info_output])],
)
# Two-tab layout used when the "Bursting" tab is enabled:
#tab_widget.children = [widgets.HBox([mv_plot_output, mv_info_output]), widgets.HBox([bs_plot_output, bf_plot_output, burst_info_output])]

tab_widget.set_title(0, "Mean-Variance")
#tab_widget.set_title(1, "Bursting")


@widgets.interact(mouse_gene_id=ui.make_gene_selector(gene_info_df.loc[analysis_orthologues.mouse].symbol, rows=5))
def plot_mv(mouse_gene_id):
    """Plot variance against mean for one mouse gene and its orthologues.

    Called by the gene selector with the chosen mouse gene ID. Per-condition
    points are drawn into ``mv_plot_output`` (colour = species, marker =
    treatment) together with each gene's fitted trend as a dashed line; the
    fitted trend table is shown in ``mv_info_output``.
    """
    gene_ids = analysis_orthologues.loc[analysis_orthologues.mouse == mouse_gene_id].squeeze().to_list()
    stats_subset = condition_info_df.loc[condition_info_df.gene.isin(gene_ids)]
    lr_subset = mv_lr_df.loc[gene_ids].copy()
    lr_subset.insert(0, "symbol", gene_info_df.symbol[lr_subset.index])

    plot_columns = ["gene", "species", "replicate", "treatment", "time_point", "mean", "variance"]
    # Column list when the disabled bursting plots are re-enabled:
    #plot_columns = ["gene", "species", "replicate", "treatment", "time_point", "mean", "variance", "bs_point", "bf_point"]

    plot_df = stats_subset[plot_columns].merge(lr_subset, on="gene").sort_values(by=["gene", "mean"])
    # Trend value at each observed mean, used to draw the fitted line.
    plot_df["lr_y"] = (plot_df["mean"] * plot_df["slope"]) + plot_df["intercept"]

    # One fixed colour per species. Keying the palette on species_to_compare
    # (the hue levels) keeps its size matched to hue_order regardless of how
    # many columns the orthologue row has.
    species_colours = dict(
        zip(species_to_compare, sns.color_palette(n_colors=len(species_to_compare)))
    )

    mv_plot_output.clear_output()
    with mv_plot_output:
        # x/y are passed by keyword: positional x/y were deprecated in
        # seaborn 0.11 and are no longer accepted from 0.12 onwards.
        ax = sns.scatterplot(
            data=plot_df,
            x="mean",
            y="variance",
            style="treatment",
            style_order=["unst", "lps", "pic"],
            hue="species",
            hue_order=species_to_compare,
            palette=species_colours,
        )

        sns.lineplot(
            data=plot_df,
            x="mean",
            y="lr_y",
            hue="species",
            hue_order=species_to_compare,
            palette=species_colours,
            legend=False,
            ax=ax,
        )
        # Only the trend lines are Line2D artists on these axes.
        for line in ax.lines:
            line.set_linestyle("--")

        ui.zero_axes_origin()
        plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
        plt.xlabel("Mean")
        plt.ylabel("Variance")
        plt.show()

    mv_info_output.clear_output()
    with mv_info_output:
        display(lr_subset)
# The triple-quoted string below is a bare expression statement: nothing in
# it is executed. It holds the disabled plotting code for the "Bursting" tab,
# which draws bs_point / bf_point against the mean into bs_plot_output and
# bf_plot_output. Those outputs, the extra plot_df columns and the second
# tab are all commented out earlier in this cell.
'''
    bs_plot_output.clear_output()
    with bs_plot_output:
        sns.scatterplot(
            plot_df["mean"],
            plot_df["bs_point"],
            style=plot_df["treatment"],
            style_order=["unst", "lps", "pic"],
            hue=plot_df["species"],
            hue_order=species_to_compare,
            palette=colours,
        )
        ui.zero_axes_origin()
        plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
        plt.show()

    bf_plot_output.clear_output()
    with bf_plot_output:
        sns.scatterplot(
            plot_df["mean"],
            plot_df["bf_point"],
            style=plot_df["treatment"],
            style_order=["unst", "lps", "pic"],
            hue=plot_df["species"],
            hue_order=species_to_compare,
            palette=colours,
        )
        ui.zero_axes_origin()
        plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
        plt.show()

    burst_info_output.clear_output()
    with burst_info_output:
        print(f"{plot_df.bs_point.count()} result(s)")

'''
# Show the tab container that holds the plot and fit-table outputs.
display(tab_widget)

interactive(children=(Select(description='mouse_gene_id', options=(('Abca1', 'ENSMUSG00000015243'), ('Abracl',…

Tab(children=(HBox(children=(Output(), Output())),), _titles={'0': 'Mean-Variance'})