In [5]:
import pickle

import altair as alt

import pandas as pd

import polyclonal

import warnings
warnings.filterwarnings('ignore')

from IPython.utils import io
import glob

In [2]:
# Move the working directory up two levels so that the relative paths used later in
# this notebook (e.g. 'results/prob_escape/...') resolve from there.
# NOTE(review): the path is relative, so running this cell more than once in the same
# kernel climbs a further two levels each time — restart the kernel before re-running.
import os
os.chdir('../../')

In [3]:
# Helper that builds the mean probability-of-escape chart; defined once here so the cells that use it stay short.

def plot_avg_escape(prob_escape, max_aa_subs=4):
    """Chart mean probability of escape against antibody concentration.

    Variants are grouped by their number of amino-acid substitutions, counted as
    the whitespace-separated tokens in ``aa_substitutions_reference``. Every
    variant with ``max_aa_subs`` or more substitutions is pooled into a single
    group labelled ``">{max_aa_subs - 1}"``. For each concentration and group the
    mean of ``prob_escape`` and of ``prob_escape_uncensored`` is plotted, one
    facet column per measure.

    Parameters
    ----------
    prob_escape : pandas.DataFrame
        Must contain the columns ``aa_substitutions_reference``,
        ``antibody_concentration``, ``prob_escape`` and
        ``prob_escape_uncensored``.
    max_aa_subs : int, optional
        Pool variants with at least this many substitutions (default 4, which
        yields the groups "0", "1", "2", "3" and ">3").

    Returns
    -------
    altair.Chart
        Line chart with points, faceted into "censored to [0, 1]" and
        "not censored" columns and coloured by substitution group.

    Raises
    ------
    ValueError
        If ``max_aa_subs`` is smaller than 1, which would leave no meaningful
        un-pooled group.
    """
    if max_aa_subs < 1:
        raise ValueError(f"max_aa_subs must be at least 1, got {max_aa_subs}")

    pooled_label = f">{max_aa_subs - 1}"

    mean_prob_escape = (
        prob_escape.assign(
            n_subs=lambda x: (
                x["aa_substitutions_reference"]
                .str.split()
                .map(len)
                # cap the count so everything at/above the threshold shares one label
                .clip(upper=max_aa_subs)
                .map(lambda n: str(n) if n < max_aa_subs else pooled_label)
            )
        )
        .groupby(["antibody_concentration", "n_subs"], as_index=False)
        .aggregate({"prob_escape": "mean", "prob_escape_uncensored": "mean"})
        .rename(
            columns={
                "prob_escape": "censored to [0, 1]",
                "prob_escape_uncensored": "not censored",
            }
        )
        # long format: one row per (concentration, group, measure) for faceting
        .melt(
            id_vars=["antibody_concentration", "n_subs"],
            var_name="censored",
            value_name="probability escape",
        )
    )

    mean_prob_escape_chart = (
        alt.Chart(mean_prob_escape)
        .encode(
            x=alt.X("antibody_concentration"),
            y=alt.Y(
                "probability escape",
                scale=alt.Scale(type="symlog", constant=0.05),
            ),
            column=alt.Column("censored", title=None),
            color=alt.Color("n_subs", title="n substitutions"),
            # show float columns with 3 significant digits, others unformatted
            tooltip=[
                alt.Tooltip(c, format=".3g") if mean_prob_escape[c].dtype == float else c
                for c in mean_prob_escape.columns
            ],
        )
        .mark_line(point=True, size=0.5)
        .properties(width=200, height=125)
        .configure_axis(grid=False)
    )

    return mean_prob_escape_chart

In [10]:
def filt_model(
    prob_escape_df,
    min_conc,
    n_epitopes=1
):
    """Fit a ``polyclonal.Polyclonal`` model to the data above a concentration cutoff.

    Rows whose ``antibody_concentration`` is strictly greater than ``min_conc``
    are kept, the columns are renamed to the names passed to ``Polyclonal``
    (``concentration`` and ``aa_substitutions``), the model is fit with the
    fitting output captured so it does not clutter the notebook, and the
    mutation-escape plot is returned.

    Parameters
    ----------
    prob_escape_df : pandas.DataFrame
        Must contain ``antibody_concentration`` and
        ``aa_substitutions_reference``; any further columns are passed to
        ``polyclonal.Polyclonal`` unchanged.
    min_conc : float
        Exclusive lower bound on ``antibody_concentration``.
    n_epitopes : int, optional
        Number of epitopes for the model (default 1).

    Returns
    -------
    The object returned by ``Polyclonal.mut_escape_plot``.

    Raises
    ------
    ValueError
        If no row has a concentration above ``min_conc``, so there is nothing
        to fit.
    """
    above_cutoff = prob_escape_df['antibody_concentration'] > min_conc
    filt_df = prob_escape_df.loc[above_cutoff]
    if filt_df.empty:
        raise ValueError(
            f"no rows have antibody_concentration > {min_conc}; nothing to fit"
        )

    model = polyclonal.Polyclonal(
        n_epitopes=n_epitopes,
        data_to_fit=filt_df.rename(
            columns={
                "antibody_concentration": "concentration",
                "aa_substitutions_reference": "aa_substitutions",
            }
        ),
        alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
    )

    # fit model, suppressing output text to avoid clutter in notebook
    # (`io` is IPython.utils.io; the captured output is deliberately discarded)
    with io.capture_output():
        model.fit(
            logfreq=200,
            reg_escape_weight=0.1,
        )

    # presumably times_seen=3 is the initial slider setting — see the polyclonal docs
    return model.mut_escape_plot(
        addtl_slider_stats={"times_seen": 3},
        init_floor_at_zero=False,
    )

In [7]:
# Serum identifiers substituted into the file pattern below.
libB_sera = [4584, 2463, 3973, 2388, 4299, 2462, 3944]
# One dataframe per matching CSV, appended in libB_sera order; later cells index
# this list by position, so the order must be reproducible.
prob_escape_list = []

for serum in libB_sera:

    file_pattern = f'results/prob_escape/libB_*_{serum}_*_prob_escape.csv'
    # glob returns matches in arbitrary order; sort so list positions are stable
    file_list = sorted(glob.glob(file_pattern))
    if not file_list:
        # a missing serum would silently shift every later list position
        raise FileNotFoundError(f"no prob_escape files match {file_pattern!r}")
    for file_path in file_list:
        # keep_default_na=False with na_values="nan": only the literal "nan" is read
        # as missing, so empty strings stay empty strings.
        # `no_antibody_count_threshold` has no `@` prefix, so query() resolves it as a
        # column of the CSV rather than a Python variable.
        df = pd.read_csv(
            file_path, keep_default_na=False, na_values="nan"
        ).query("`no-antibody_count` >= no_antibody_count_threshold")
        prob_escape_list.append(df)

In [8]:
# Mean-escape chart for entry 0 of prob_escape_list
# (presumably serum 4584 if each serum matched exactly one file — TODO confirm).
plot_avg_escape(prob_escape=prob_escape_list[0])

In [11]:
# Single-epitope fit on entry 0, using only concentrations above 0.02.
filt_model(prob_escape_df=prob_escape_list[0], min_conc=0.02)

In [12]:
# Single-epitope fit on entry 1, using only concentrations above 0.02.
filt_model(prob_escape_df=prob_escape_list[1], min_conc=0.02)

In [13]:
# Single-epitope fit on entry 2, using only concentrations above 0.02.
filt_model(prob_escape_df=prob_escape_list[2], min_conc=0.02)

In [15]:
# Single-epitope fit on entry 3, using only concentrations above 0.015.
filt_model(prob_escape_df=prob_escape_list[3], min_conc=0.015)

In [16]:
# Single-epitope fit on entry 5, using only concentrations above 0.01.
filt_model(prob_escape_df=prob_escape_list[5], min_conc=0.01)

In [17]:
# Single-epitope fit on entry 6, using only concentrations above 0.002.
filt_model(prob_escape_df=prob_escape_list[6], min_conc=0.002)