In [1]:
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.display import display

from rp2 import hagai_2018, txburst, notebooks
from rp2.analysis import create_default_mouse_analysis
from rp2.data import load_and_recalculate_txburst_results

# Initialise this notebook's environment. The call returns a 2-tuple; only the first element is kept.
# NOTE(review): `dependencies` presumably declares that this notebook relies on outputs produced by the
# "Burst_Model_Fitting" notebook — confirm against rp2.notebooks.
nb_env, _ = notebooks.initialise_environment(
    "Explore_Gene_Conditions",
    dependencies=["Burst_Model_Fitting"],
)

In [2]:
class GenesFilter:
    """Keep only genes with txburst estimates in enough qualifying conditions.

    A condition row qualifies when both ``bs_point`` and ``bf_point`` are present and its
    ``treatment`` is one of the configured treatments.
    """

    def __init__(self, min_conditions, count_type, treatments):
        # min_conditions: minimum number of qualifying condition rows a gene must have
        # count_type: key selecting the condition table from the mapping passed to ``apply``
        # treatments: treatments whose condition rows are counted
        self._min_conditions = min_conditions
        self._count_type = count_type
        self._treatments = treatments

    def apply(self, gene_list, condition_info_map):
        """Return the entries of ``gene_list`` (indexed by gene id) that pass the filter.

        Args:
            gene_list: Series whose index holds gene ids matching the ``gene`` column of the
                condition table.
            condition_info_map: mapping of count type -> condition table with ``gene``,
                ``bs_point``, ``bf_point`` and ``treatment`` columns.

        Returns:
            The subset of ``gene_list`` (original order preserved) whose genes have at least
            ``min_conditions`` qualifying condition rows.
        """
        info_df = condition_info_map[self._count_type]
        qualifying_rows = (
            info_df.gene.isin(gene_list.index)
            & info_df.bs_point.notna()
            & info_df.bf_point.notna()
            & info_df.treatment.isin(self._treatments)
        )
        rows_per_gene = info_df.loc[qualifying_rows].groupby("gene").bs_point.count()
        passing_gene_ids = rows_per_gene.index[rows_per_gene >= self._min_conditions]
        return gene_list[gene_list.index.isin(passing_gene_ids)]


def get_available_count_types(species):
    """Return a fresh list of count types to analyse for ``species``.

    "umi" is only included for "mouse"; every species gets "median".
    """
    return ["umi", "median"] if species == "mouse" else ["median"]


# One analysis object per species, all derived from the default mouse analysis.
# NOTE(review): `create_orthologue_analysis` presumably maps mouse genes to each species' orthologues
# (including an identity mapping for "mouse") — confirm in rp2.analysis.
mouse_analysis = create_default_mouse_analysis()
analyses = {species: mouse_analysis.create_orthologue_analysis(species) for species in ["mouse", "pig", "rabbit", "rat"]}

In [3]:
# Species explored by the rest of the notebook (must be a key of `analyses`).
species = "mouse"

# Keep genes whose "median" txburst results have both bs_point and bf_point estimated in at least
# 10 conditions drawn from these treatments.
# NOTE(review): "unst" presumably means unstimulated — confirm against the hagai_2018 dataset.
gene_filter = GenesFilter(
    min_conditions=10,
    count_type="median",
    treatments=["unst", "lps", "pic"],
)

In [4]:
analysis = analyses[species]
available_count_types = get_available_count_types(analysis.species)

# Gene symbols for the species; the index is matched against txburst `gene` ids below.
analysis_genes = hagai_2018.load_biomart_gene_symbols_df(species)["symbol"]
# Union of condition values seen across all count types (sorted after the loop, used as widget options).
analysis_replicates = set()
analysis_treatments = set()
analysis_time_points = set()

# count type -> per-(gene, condition) txburst parameter table
condition_info_map = {}
for count_type in available_count_types:
    txburst_params_df = load_and_recalculate_txburst_results(species, analysis.condition_columns, count_type=count_type)
    # Ratio k_off / k_on stored alongside the fitted parameters.
    txburst_params_df["k_burstiness"] = txburst_params_df.k_off / txburst_params_df.k_on

    gene_ids = txburst_params_df.gene.unique()
    print(f'{len(gene_ids):,} genes have been processed for "{count_type}" by txburst')
    # Intersect: after the loop only genes present in every count type's results remain.
    analysis_genes = analysis_genes.loc[analysis_genes.index.isin(gene_ids)]
    analysis_replicates.update(txburst_params_df.replicate.values)
    analysis_treatments.update(txburst_params_df.treatment.values)
    analysis_time_points.update(txburst_params_df.time_point.values)

    condition_info_map[count_type] = txburst_params_df
    # Drop the loop name so it does not linger in the notebook namespace.
    del txburst_params_df

analysis_replicates = sorted(analysis_replicates)
analysis_treatments = sorted(analysis_treatments)
analysis_time_points = sorted(analysis_time_points)

# Sort by symbol so the gene selector is alphabetical.
analysis_genes = analysis_genes.sort_values()
print(f"{len(analysis_genes):,} genes are common to all results")

filtered_genes_list = gene_filter.apply(analysis_genes, condition_info_map)
print(f"{len(filtered_genes_list):,} genes remain after filtering")

696 genes have been processed for "umi" by txburst
2,303 genes have been processed for "median" by txburst
693 genes are common to all results
133 genes remain after filtering


In [5]:
# Load one count matrix per count type, restrict it to the genes common to all txburst results, and
# attach per-condition count statistics to the matching txburst parameter table.
counts_adata_map = dict(zip(available_count_types, hagai_2018.load_counts(species, scaling=available_count_types)))
for count_type in available_count_types:
    counts_adata = counts_adata_map[count_type]
    print(f'Total of {counts_adata.n_vars:,} genes for "{count_type}"')
    counts_adata = counts_adata[:, analysis_genes.index].copy()

    counts_stats_df = hagai_2018.calculate_counts_condition_stats(counts_adata).set_index(analysis.index_columns)
    # `condition_info_map` is updated in place, so on a re-run of this cell it already holds the stats
    # columns. Drop them first so `join` does not fail on overlapping columns (keeps the cell
    # idempotent); on the first run this drop is a no-op.
    condition_info_df = (
        condition_info_map[count_type]
        .set_index(analysis.index_columns)
        .drop(columns=counts_stats_df.columns, errors="ignore")
    )
    condition_info_map[count_type] = condition_info_df.join(counts_stats_df, how="left").reset_index()

    counts_adata_map[count_type] = counts_adata
    # Drop loop-local names so they do not linger in the notebook namespace.
    del counts_adata, counts_stats_df, condition_info_df

Total of 16,798 genes for "umi"
Total of 16,798 genes for "median"


In [6]:
def create_df_row_mask(df, column_values):
    """Build a boolean mask selecting the rows of ``df`` that match every column filter.

    Args:
        df: DataFrame to filter.
        column_values: mapping of column name -> required value. A list, tuple, set, frozenset,
            numpy array or pandas Index means "any of these values" (``isin``); any other value
            must match exactly (``==``).

    Returns:
        Boolean Series aligned with ``df.index``; all True when ``column_values`` is empty.
    """
    # Explicit bool dtype so the mask stays boolean even for an empty frame.
    mask = pd.Series(True, index=df.index, dtype=bool)
    for column_name, column_value in column_values.items():
        # Collections must use `isin`: comparing a column to e.g. a set with `==` silently yields
        # all False, and `==` with an array compares positionally.
        if isinstance(column_value, (list, tuple, set, frozenset, np.ndarray, pd.Index)):
            mask &= df[column_name].isin(column_value)
        else:
            mask &= df[column_name] == column_value
    return mask


def filter_df_rows(df, column_values):
    """Return the rows of ``df`` that satisfy every filter in ``column_values``.

    The filter semantics are those of ``create_df_row_mask``.
    """
    return df.loc[create_df_row_mask(df, column_values)]


def plot_expression_distribution(expression_adata, burst_param_subset):
    """Plot a histogram of one gene's counts with the fitted txburst Poisson-beta PMF overlaid.

    Args:
        expression_adata: AnnData-like object for a single gene; ``X`` may be a scipy sparse
            matrix or a dense array.
        burst_param_subset: txburst parameters (``k_on``, ``k_off``, ``k_syn`` columns) for the
            plotted condition. The PMF is overlaid only when this holds exactly one row of
            parameters; otherwise only the histogram and mean line are drawn.
    """
    counts_X = expression_adata.X
    # `.A` is gone from recent SciPy sparse matrices and never existed on dense arrays.
    counts = np.squeeze(counts_X.toarray() if hasattr(counts_X, "toarray") else np.asarray(counts_X))
    max_count = np.max(counts)
    mean_count = np.mean(counts)

    upper_bin = int(np.ceil(max_count))
    # At least one bin: an all-zero gene would otherwise request `bins=0`, which `hist` rejects.
    bin_values, bin_edges, _ = plt.hist(counts, bins=max(1, min(50, upper_bin)), color="orange")
    hist_area = np.sum(np.diff(bin_edges) * bin_values)

    burst_params = burst_param_subset.squeeze()
    # Only a single row squeezes to a Series of parameters; an empty or multi-row frame would
    # otherwise unpack column names instead of values.
    if isinstance(burst_params, pd.Series) and {"k_on", "k_off", "k_syn"}.issubset(burst_params.index):
        k_on, k_off, k_syn = burst_params[["k_on", "k_off", "k_syn"]]

        # Every integer from 0 up to and including the largest observed count.
        pmf_in = np.arange(upper_bin + 1)
        pmf_out = txburst.poisson_beta_pmf(pmf_in, k_on, k_off, k_syn)
        # Shift by half a bin width to centre points in their bins; unlike offsetting by the first
        # bin's midpoint, this does not move the curve when the smallest count is non-zero.
        half_bin_width = (bin_edges[1] - bin_edges[0]) / 2
        plt.plot(pmf_in + half_bin_width, hist_area * pmf_out, "g--", linewidth=2)

    plt.axvline(x=mean_count, linestyle=":")
    plt.xlabel("Count")
    plt.ylabel("Cells")
    plt.show()


def plot_expression_distributions(expression_adata, counts_obs_subset, burst_param_subset):
    """Plot count distributions for the selected cells, split by condition.

    With a single condition, delegates to ``plot_expression_distribution`` (histogram plus fitted
    PMF). With several, draws side-by-side histograms and per-condition KDEs.

    Args:
        expression_adata: AnnData-like object for a single gene; ``X`` may be sparse or dense.
        counts_obs_subset: obs rows (cells) to plot, carrying ``analysis.condition_columns``.
        burst_param_subset: txburst parameters passed through to the single-condition plot.
    """
    if counts_obs_subset.empty:
        print("No data available")
        return

    # observed=True is a no-op for plain columns; if the obs columns are categorical it stops
    # unobserved condition combinations from being counted as groups.
    condition_groups = counts_obs_subset.groupby(analysis.condition_columns, observed=True)
    if len(condition_groups) == 1:
        _, group_df = next(iter(condition_groups))
        plot_expression_distribution(expression_adata[group_df.index], burst_param_subset)
        return

    _, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    counts_list = []
    labels_list = []
    # NOTE(review): assumes analysis.condition_columns is ordered (replicate, treatment, time_point).
    for (replicate, treatment, time_point), obs_group in condition_groups:
        group_X = expression_adata[obs_group.index].X
        counts_list.append(np.squeeze(group_X.toarray() if hasattr(group_X, "toarray") else np.asarray(group_X)))
        labels_list.append(f"{treatment} R{replicate} T{time_point}")
        # `bw` is deprecated in seaborn >= 0.11; `bw_method=1` is what it maps a scalar `bw` to.
        sns.kdeplot(
            counts_list[-1],
            label=labels_list[-1],
            bw_method=1,
            ax=ax2,
        )

    ax1.hist(
        counts_list,
        label=labels_list,
    )

    for ax in (ax1, ax2):
        ax.set_xlabel("Count")
        ax.set_xlim(left=0)
        ax.legend(loc="upper left", bbox_to_anchor=(1, 1))

    ax1.set_ylabel("Cells")
    ax2.set_ylabel("Density (kde)")
    # Share the histogram's x-range so the two panels line up.
    ax2.set_xlim(right=ax1.get_xlim()[1])

    plt.tight_layout()
    plt.show()


# Options are (label, value) pairs: the gene symbol is shown and the gene id is the value.
gene_selector = widgets.Select(
    description="Gene:",
    options=list(zip(filtered_genes_list.values, filtered_genes_list.index)),
    rows=4,
)
replicates_selector = widgets.SelectMultiple(
    description="Replicate:",
    options=analysis_replicates,
    value=analysis_replicates,
    rows=3,
)
treatments_selector = widgets.SelectMultiple(
    description="Treatments:",
    options=analysis_treatments,
    # Default to these treatments, restricted to those present in this species' results:
    # SelectMultiple raises if a default value is not among its options.
    value=[treatment for treatment in ["unst", "lps", "pic"] if treatment in analysis_treatments],
    rows=3,
)
time_points_selector = widgets.SelectMultiple(
    description="Time points:",
    options=analysis_time_points,
    value=analysis_time_points,
    rows=len(analysis_time_points),
)
expression_type_selector = widgets.Select(
    description="Count type:",
    options=available_count_types,
    value="median",
    rows=len(available_count_types),
)

# One output area per tab; `update_ui` clears and redraws each of them.
mv_plot_output = widgets.Output()
dist_plot_output = widgets.Output()
bp_output = widgets.Output()
bp_plot_output = widgets.Output()


def update_ui():
    """Redraw every output tab for the current gene, condition and count-type selections.

    Reads the selector widgets and the module-level ``counts_adata_map`` and
    ``condition_info_map``, clears all output widgets, then renders:

    - the mean-variance plot;
    - the burst-parameter scatter plots;
    - the count distributions;
    - the burst-parameter table.
    """
    selected_gene_id = gene_selector.value
    selected_conditions = {
        "replicate": replicates_selector.value,
        "treatment": treatments_selector.value,
        "time_point": time_points_selector.value,
    }
    selected_expression_type = expression_type_selector.value

    counts_adata = counts_adata_map[selected_expression_type]

    mv_plot_output.clear_output()
    dist_plot_output.clear_output()
    bp_output.clear_output()
    bp_plot_output.clear_output()

    condition_stats_subset = filter_df_rows(
        condition_info_map[selected_expression_type],
        {**selected_conditions, "gene": selected_gene_id},
    )
    # Compute variance / mean only for the selected rows. Computing it on the full stats table would
    # copy the whole table on every widget change. `assign` returns a new frame, so the shared table
    # is never mutated.
    condition_stats_subset = condition_stats_subset.assign(
        m_burstiness=condition_stats_subset["variance"] / condition_stats_subset["mean"],
    )

    if not condition_stats_subset.empty:
        with mv_plot_output:
            ax = sns.scatterplot(
                x="mean",
                y="variance",
                hue="replicate",
                style="treatment",
                size="time_point",
                data=condition_stats_subset,
            )
            plt.xlim(left=0)
            plt.ylim(bottom=0)
            # Regression line over the same points, drawn across the full x-range.
            sns.regplot(
                x="mean",
                y="variance",
                truncate=False,
                scatter=False,
                data=condition_stats_subset,
                ax=ax,
            )
            plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
            plt.show()

        bp_plot_params = ["bs_point", "bf_point", "k_off", "k_syn", "k_burstiness"]
        with bp_plot_output:
            fig, axes = plt.subplots(1, len(bp_plot_params), figsize=(len(bp_plot_params) * 4, 4))
            for plot_param, ax in zip(bp_plot_params, axes.flat):
                sns.scatterplot(
                    x="mean",
                    y=plot_param,
                    hue="time_point",
                    style="replicate",
                    data=condition_stats_subset,
                    ax=ax,
                )
                ax.set_xlim(left=0)
                ax.set_ylim(bottom=0)
                # Remove per-panel legends; a single shared legend is drawn after the loop.
                ax_legend = ax.get_legend()
                if ax_legend is not None:
                    ax_legend.remove()
            plt.legend(*ax.get_legend_handles_labels(), loc="upper left", bbox_to_anchor=(1, 1))
            plt.tight_layout()
            plt.show()

    gene_expression_adata = counts_adata[:, selected_gene_id]
    counts_obs_subset = filter_df_rows(gene_expression_adata.obs, selected_conditions)
    # Condition-indexed table of every column from "k_on" onwards (relies on column order).
    txburst_params_subset = condition_stats_subset.drop(columns="gene").set_index(analysis.condition_columns).loc[:, "k_on":]

    with dist_plot_output:
        plot_expression_distributions(gene_expression_adata, counts_obs_subset, txburst_params_subset)

    with bp_output:
        display(txburst_params_subset)


def event_handler(event):
    """Widget observer callback: redraw the UI, but only when a widget's ``value`` changed."""
    is_value_change = event["type"] == "change" and event["name"] == "value"
    if is_value_change:
        update_ui()


# Every selector notifies `event_handler`, which only redraws on "change" events for "value".
gene_selector.observe(event_handler)
replicates_selector.observe(event_handler)
treatments_selector.observe(event_handler)
time_points_selector.observe(event_handler)
expression_type_selector.observe(event_handler)

# One tab per output area; titles are assigned by tab position.
tab_widget = widgets.Tab()
tab_widget.children = [mv_plot_output, dist_plot_output, bp_output, bp_plot_output]
tab_widget.set_title(0, "Mean-variance plot")
tab_widget.set_title(1, "Count distribution")
tab_widget.set_title(2, "Burst parameters")
tab_widget.set_title(3, "Burst parameter plots")

# Layout: the selectors in one row (in three columns) above the output tabs.
ui_container = widgets.VBox([
    widgets.HBox([
        widgets.VBox([
            gene_selector,
            replicates_selector,
        ]),
        widgets.VBox([
            treatments_selector,
            time_points_selector,
        ]),
        expression_type_selector,
    ]),
    tab_widget,
])

display(ui_container)

# Render the initial selections once; later redraws come from the observers above.
update_ui()

VBox(children=(HBox(children=(VBox(children=(Select(description='Gene:', options=(('Abca1', 'ENSMUSG0000001524…