In [1]:
import ipywidgets as widgets
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import statsmodels.api as sm
from IPython.display import display, Latex
from scipy import stats
from sklearn import metrics

import rp2.data
import rp2.environment
from rp2 import hagai_2018, create_gene_symbol_map
from rp2.environment import check_environment
from rp2.paths import get_output_path
from rp2.regression import power_function, calculate_curve_fit

# Run the project's environment check before any data is loaded.
rp2.environment.check_environment()

### Settings controlling downstream analysis

In [2]:
# Data selection: species, count scaling (passed to the loaders as
# count_type / scaling), and the treatments and time points to analyse.
analysis_species = "mouse"
analysis_counts = "median"
analysis_treatments = ["unst", "lps", "pic"]
analysis_time_points = ["0", "2", "4", "6"]

# Robust (Huber) regression thresholds. The mean-variance fit reuses the
# standard threshold (1.345, which is also statsmodels' HuberT default).
default_huber_epsilon = 1.345
mv_rlm_factor = default_huber_epsilon

# Gene-level filters. min_conditions is the number of conditions with valid
# burst parameters a gene needs to be analysed; min_mv_r2 is presumably a
# minimum mean-variance R^2 applied in later cells (TODO confirm).
min_mv_r2 = 0.60
min_conditions = 10

# Curve-fit options, presumably forwarded as the loss / f_scale arguments of
# calculate_curve_fit in later cells (NOTE(review): not used in the cells shown).
bp_curve_loss = "huber"
bp_curve_f_scale = 1.00

# Columns that identify a condition; together with "gene" they key each row.
condition_columns = ["replicate", "treatment", "time_point"]
index_columns = ["gene", *condition_columns]

In [3]:
# Mapping from gene ID to gene symbol for the analysed species; used below to
# label genes in the interactive selector.
gene_symbol_map = rp2.create_gene_symbol_map(analysis_species)

### Determine which genes have a sufficient number of conditions with valid burst parameters

In [4]:
# Load the txburst burst-parameter results and keep only the treatments and
# time points selected in the settings cell.
condition_info_df = rp2.data.load_and_recalculate_txburst_results(analysis_species, condition_columns, count_type=analysis_counts)
# .copy() so the columns added below are written to an independent frame rather
# than to a filtered slice of the loader's result (avoids SettingWithCopyWarning
# should that result still be referenced elsewhere, e.g. by a cache).
condition_info_df = condition_info_df.loc[
    condition_info_df.treatment.isin(analysis_treatments)
    & condition_info_df.time_point.isin(analysis_time_points)
].copy()

# Burstiness is k_off / k_on; its natural log is kept as a separate column.
condition_info_df["k_burstiness"] = condition_info_df.k_off / condition_info_df.k_on
condition_info_df["log_burstiness"] = np.log(condition_info_df.k_burstiness)
print(f"{len(condition_info_df):,} conditions for {condition_info_df.gene.nunique():,} genes have been processed by txburst")

# A condition has valid burst parameters only when both the bs_point and
# bf_point estimates are present.
condition_info_df["valid_bp"] = condition_info_df.bs_point.notna() & condition_info_df.bf_point.notna()
print(f"{condition_info_df.valid_bp.sum():,} conditions have valid burst parameters")

# Summing a boolean column counts its True values, giving the number of valid
# conditions per gene (vectorised, unlike a per-group np.count_nonzero call).
valid_counts = condition_info_df.groupby("gene").valid_bp.sum()
valid_gene_ids = valid_counts.index[valid_counts >= min_conditions]
print(f"{len(valid_gene_ids):,} genes have {min_conditions} or more conditions with valid burst parameters")

# Keep every condition (valid or not) of the selected genes; invalid ones are
# still plotted and used in the mean-variance fits below.
condition_info_df = condition_info_df.loc[condition_info_df.gene.isin(valid_gene_ids)]

11,604 conditions for 2,303 genes have been processed by txburst
11,604 conditions have valid burst parameters
123 genes have 10 or more conditions with valid burst parameters


### Calculate statistics of RNA counts

In [5]:
def calculate_count_stats(condition_subset):
    """Compute per-condition RNA count statistics for the rows of ``condition_subset``.

    Counts for ``analysis_species`` are loaded with ``hagai_2018.load_counts``
    (scaling ``analysis_counts``). They are restricted to the genes listed in
    ``condition_subset.gene`` and to cells whose value in each of
    ``condition_columns`` occurs in ``condition_subset``. The result is then
    summarised by ``hagai_2018.calculate_counts_condition_stats`` grouped by
    ``condition_columns``.

    Each condition column is filtered independently. A cell can therefore be
    kept even if its exact combination of values is absent from
    ``condition_subset``.

    Returns
    -------
    The statistics table produced by ``calculate_counts_condition_stats``
    (the caller re-indexes it with ``set_index``).
    """
    adata = hagai_2018.load_counts(analysis_species, scaling=analysis_counts)
    print(f"Counts available for {adata.n_obs:,} cells and {adata.n_vars:,} genes")

    # Genes first: keep only variables that appear in the condition table.
    adata = adata[:, adata.var_names.isin(condition_subset.gene)]

    # Then cells: one condition column at a time, each mask evaluated on the
    # already-reduced object.
    for column_name in condition_columns:
        keep_cells = adata.obs[column_name].isin(condition_subset[column_name])
        adata = adata[keep_cells]

    # Materialise the subset (slicing presumably yields a view) before
    # computing statistics on it.
    adata = adata.copy()
    print(f"Calculating count statistics for {adata.n_obs:,} cells and {adata.n_vars:,} genes")

    return hagai_2018.calculate_counts_condition_stats(adata, group_columns=condition_columns)


# Attach the count statistics to each (gene, condition) row. Both tables are
# keyed on index_columns; the left join keeps every row of condition_info_df
# and leaves NaN statistics where no count data matched.
condition_info_df = (
    condition_info_df
    .set_index(index_columns)
    .join(calculate_count_stats(condition_info_df).set_index(index_columns), how="left")
    .reset_index()
)

Counts available for 55,898 cells and 16,798 genes
Calculating count statistics for 53,086 cells and 123 genes


### Display an interactive mean-variance plot for genes with enough conditions that have valid burst parameters

Although the list of genes is restricted to those with at least `min_conditions` conditions with valid burst parameters, all of a gene's conditions (including those without valid burst parameters) are plotted and used to fit the regression line. The solid line shows the fit to all points. When `plot_treatment_lines` is checked, dotted lines show a separate fit for each stimulated treatment, each pooled with the unstimulated points.

Plotted points are sized according to the weight assigned to them by the robust linear model. The model's sensitivity to outliers can be adjusted with the `rlm_factor` slider. Changes to this value are for illustration only and do not affect the downstream analysis. To change that, set `mv_rlm_factor` in the settings cell above and re-run all cells.

In [6]:
def apply_huber_regressor(df, x_var, y_var, epsilon, include_weights=False):
    """Fit ``y_var = slope * x_var + intercept`` with a robust (Huber) linear model.

    Parameters
    ----------
    df : pandas.DataFrame
        Table containing the ``x_var`` and ``y_var`` columns.
    x_var, y_var : str
        Names of the predictor and response columns.
    epsilon : float
        Threshold ``t`` passed to ``sm.robust.norms.HuberT``.
    include_weights : bool, default False
        When True, also report the weighted R^2 and the per-row robust weights.

    Returns
    -------
    dict
        Always contains ``slope`` and ``intercept``.
        If ``include_weights`` is False, it also contains ``r2`` (the unweighted R^2).
        Otherwise it also contains ``r2_unweighted``, ``r2_weighted`` (R^2 weighted
        by the robust weights) and ``weights`` (one robust weight per row of
        ``df``, in row order).
    """
    x, y = df.loc[:, [x_var, y_var]].to_numpy().T
    # has_constant="add" guarantees an intercept column even if x happens to be
    # constant. Otherwise add_constant's default ("skip") would omit it, and
    # params[1] below would not exist.
    design = sm.add_constant(x, has_constant="add")

    lm_res = sm.RLM(y, design, sm.robust.norms.HuberT(t=epsilon)).fit()

    results = {
        "slope": lm_res.params[1],
        "intercept": lm_res.params[0],
    }

    # Predict once and reuse for both R^2 variants.
    y_pred = lm_res.predict(design)
    r2_unweighted = metrics.r2_score(y, y_pred)
    if include_weights:
        results["r2_unweighted"] = r2_unweighted
        results["r2_weighted"] = metrics.r2_score(y, y_pred, sample_weight=lm_res.weights)
        results["weights"] = lm_res.weights
    else:
        results["r2"] = r2_unweighted

    return results


def apply_mv_regressor(df, x_var, y_var, epsilon=mv_rlm_factor):
    """Robust Huber fit for mean-variance data, including R^2 variants and weights.

    ``epsilon`` defaults to the value ``mv_rlm_factor`` had when this cell ran.
    The returned dict holds ``slope``, ``intercept``, ``r2_unweighted``,
    ``r2_weighted`` and ``weights``.
    """
    return apply_huber_regressor(
        df,
        x_var=x_var,
        y_var=y_var,
        epsilon=epsilon,
        include_weights=True,
    )


def apply_standard_regressor(df, x_var, y_var):
    """Robust Huber fit using ``default_huber_epsilon``.

    Returns a dict with ``slope``, ``intercept`` and the unweighted ``r2``.
    """
    return apply_huber_regressor(
        df,
        x_var=x_var,
        y_var=y_var,
        epsilon=default_huber_epsilon,
        include_weights=False,
    )


def make_gene_selector(gene_ids):
    """Build a ``Select`` widget listing ``gene_ids`` by gene symbol.

    Options are ``(symbol, gene_id)`` pairs sorted by symbol, so the list shows
    symbols while the widget's value is the gene ID.
    """
    symbols_by_id = gene_symbol_map.lookup(gene_ids).sort_values()
    choices = [
        (symbol, gene_id)
        for symbol, gene_id in zip(symbols_by_id.values, symbols_by_id.index)
    ]
    return widgets.Select(options=choices, rows=3)


def format_plus_c(c):
    """Format an additive constant with an explicit sign and two decimals.

    Used to append an intercept to an equation, e.g. ``1.5`` -> ``"+1.50"`` and
    ``-0.25`` -> ``"-0.25"``. Exactly zero is rendered as ``"+0.00"``.
    """
    # The "+" format flag always emits the sign. This avoids the spurious
    # "-0.00" that a ``c > 0`` test produced for an intercept of exactly 0.
    return f"{c:+.2f}"


def _draw_mean_var_plot(condition_info_subset, lr_results, treatment_lr_results_map,
                        scale, plot_treatment_lines, treatment_colour_map):
    """Scatter mean vs variance for one gene and overlay the robust fit line(s).

    Point sizes are interpolated from each condition's robust weight
    (weight 0 -> 10, weight 1 -> 50). With ``scale == "log"``, 1 is added to both
    axes so that zero values remain plottable.
    """
    fig, ax = plt.subplots()

    legend_handles = [
        matplotlib.lines.Line2D([], [], marker="o", color=colour, label=treatment_name,
                                linestyle="None", markersize=8)
        for treatment_name, colour in treatment_colour_map.items()
    ]

    is_log = scale == "log"
    # Log-log axes show (value + 1); fitted lines are still evaluated in raw units.
    log_shift = 1 if is_log else 0
    space_function = np.geomspace if is_log else np.linspace

    x, y = condition_info_subset.loc[:, ["mean", "variance"]].to_numpy().T
    sizes = np.interp(lr_results["weights"], (0, 1), (10, 50))
    ax.scatter(x + log_shift, y + log_shift,
               c=condition_info_subset.treatment.map(treatment_colour_map), s=sizes)

    # Grid spanning 0..max(mean) in raw units (shifted by log_shift for display).
    lr_x = space_function(log_shift, x.max() + log_shift)
    lr_y = ((lr_x - log_shift) * lr_results["slope"]) + lr_results["intercept"]
    ax.plot(lr_x, lr_y + log_shift, "-")

    if plot_treatment_lines:
        for treatment, treatment_lr_results in treatment_lr_results_map.items():
            lr_y_treatment = ((lr_x - log_shift) * treatment_lr_results["slope"]) + treatment_lr_results["intercept"]
            ax.plot(lr_x, lr_y_treatment + log_shift, ":", c=treatment_colour_map[treatment])

    ax.set_xscale(scale)
    ax.set_xlim(left=log_shift)
    # Raw strings: "\m" / "\s" are invalid escape sequences in normal literals.
    ax.set_xlabel(r"Mean count ($\mu$)")
    ax.set_yscale(scale)
    ax.set_ylim(bottom=log_shift)
    ax.set_ylabel(r"Variance ($\sigma^2$)")
    ax.legend(handles=legend_handles, loc="upper left", bbox_to_anchor=(1, 1))
    plt.show()


def _show_mean_var_info(condition_info_subset, lr_results, treatment_lr_results_map):
    """Print condition / down-weighting counts and display the fit equation and R^2 values."""
    lr_weights = lr_results["weights"]
    print(f"No. of conditions with burst parameters: {np.count_nonzero(condition_info_subset.valid_bp)} / {len(condition_info_subset)}")
    # A weight below 1 means the robust fit down-weighted that condition.
    print(f"No. of weights < 1: {np.count_nonzero(lr_weights < 1)}")
    for treatment in analysis_treatments:
        print(f"  {np.count_nonzero((lr_weights < 1) & (condition_info_subset.treatment == treatment))} {treatment}")
    display(Latex(rf"$\sigma^2={lr_results['slope']:.2f}\mu{format_plus_c(lr_results['intercept'])}$"))
    display(Latex(f"Weighted $R^2$: {lr_results['r2_weighted']:.3f}"))
    display(Latex(f"Unweighted $R^2$: {lr_results['r2_unweighted']:.3f}"))

    for name, treatment_lr_results in treatment_lr_results_map.items():
        display(Latex(f"Unweighted $R_{{{name}}}^2$: {treatment_lr_results['r2_unweighted']:.3f}"))


def plot_mean_var(gene_id, scale, plot_treatment_lines, rlm_factor):
    """Show an interactive mean-variance plot and fit summary for one gene.

    Parameters
    ----------
    gene_id : str
        Gene identifier matched against ``condition_info_df.gene``.
    scale : {"linear", "log"}
        Axis scale. With ``"log"``, both axes show value + 1.
    plot_treatment_lines : bool
        Also draw dotted fits for each stimulated treatment pooled with "unst".
    rlm_factor : float
        Huber threshold used for every robust fit shown (overall and per treatment).

    The plot and the text summary are rendered side by side in an ``HBox``.
    """
    treatment_colour_map = {"unst": "black", "lps": "red", "pic": "green"}

    condition_info_subset = condition_info_df.loc[condition_info_df.gene == gene_id]

    lr_results = apply_mv_regressor(condition_info_subset, "mean", "variance", rlm_factor)

    # Per-treatment fits pool each stimulated treatment with the unstimulated
    # conditions. They use the same rlm_factor as the overall fit so the slider
    # affects every line, not just the solid one.
    treatment_lr_results_map = {
        treatment: apply_mv_regressor(
            condition_info_subset.loc[condition_info_subset.treatment.isin(["unst", treatment])],
            "mean", "variance", rlm_factor,
        )
        for treatment in analysis_treatments
        if treatment != "unst"
    }

    plot_output = widgets.Output()
    info_output = widgets.Output()

    with plot_output:
        _draw_mean_var_plot(condition_info_subset, lr_results, treatment_lr_results_map,
                            scale, plot_treatment_lines, treatment_colour_map)

    with info_output:
        _show_mean_var_info(condition_info_subset, lr_results, treatment_lr_results_map)

    display(widgets.HBox((plot_output, info_output)))


# Interactive view: each keyword becomes a control whose value is passed to
# plot_mean_var. The controls are a gene selector, radio buttons for the axis
# scale, a checkbox (from the bool abbreviation) for the per-treatment lines,
# and a slider for the Huber threshold. The slider only affects this plot;
# downstream analysis uses mv_rlm_factor.
widgets.interactive(
    plot_mean_var,
    gene_id=make_gene_selector(valid_gene_ids),
    scale=widgets.RadioButtons(
        options=[
            ["Linear", "linear"],
            ["Log-log (+1)", "log"],
        ],
    ),
    plot_treatment_lines=False,
    rlm_factor=widgets.FloatSlider(
        value=mv_rlm_factor,
        min=1.001,
        max=5,
        step=0.001,
        readout_format=".3f",
    ),
)

interactive(children=(Select(description='gene_id', options=(('Abca1', 'ENSMUSG00000015243'), ('Abracl', 'ENSM…