In [1]:
import pickle

import altair as alt

import pandas as pd

import polyclonal

import warnings
warnings.filterwarnings('ignore')

from IPython.utils import io

In [2]:
import os

# Move two directories up from the notebook's own directory so that the
# relative `results/prob_escape/...` paths used below resolve. The starting
# directory is remembered the first time this cell runs, so re-running the cell
# is idempotent instead of climbing two more levels each time.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir, os.pardir))

In [17]:
# Helper for the mean probability-of-escape chart, defined once so the plotting code is not repeated in every cell below

def plot_avg_escape(prob_escape, max_aa_subs=4):
    """Plot mean probability of escape versus antibody concentration.

    Variants are grouped by their number of amino-acid substitutions relative
    to the reference; variants with ``max_aa_subs`` or more substitutions are
    pooled into a single ``">{max_aa_subs - 1}"`` group. For every
    (concentration, group) pair the mean of both ``prob_escape`` (labelled
    "censored to [0, 1]") and ``prob_escape_uncensored`` (labelled
    "not censored") is plotted, one facet column per quantity.

    Parameters
    ----------
    prob_escape : pandas.DataFrame
        Must contain the columns ``antibody_concentration``,
        ``aa_substitutions_reference`` (whitespace-separated substitutions;
        an empty string counts as zero substitutions), ``prob_escape`` and
        ``prob_escape_uncensored``.
    max_aa_subs : int, default 4
        Variants with at least this many substitutions are grouped together.
        Must be >= 1.

    Returns
    -------
    altair.Chart
        Line chart faceted into "censored to [0, 1]" / "not censored" columns.

    Raises
    ------
    ValueError
        If ``max_aa_subs`` is less than 1.
    """
    if max_aa_subs < 1:
        raise ValueError(f"max_aa_subs must be >= 1, got {max_aa_subs}")

    def _n_subs_label(n):
        # Counts were clipped to max_aa_subs, so anything at the cap becomes
        # the pooled ">{max_aa_subs - 1}" group (e.g. ">3" for the default).
        return str(n) if n < max_aa_subs else f">{max_aa_subs - 1}"

    mean_prob_escape = (
        prob_escape.assign(
            n_subs=lambda x: (
                x["aa_substitutions_reference"]
                .str.split()
                .map(len)
                .clip(upper=max_aa_subs)
                .map(_n_subs_label)
            )
        )
        .groupby(["antibody_concentration", "n_subs"], as_index=False)
        .aggregate({"prob_escape": "mean", "prob_escape_uncensored": "mean"})
        .rename(
            columns={
                "prob_escape": "censored to [0, 1]",
                "prob_escape_uncensored": "not censored",
            }
        )
        # Long format: one row per (concentration, n_subs, censoring) so the
        # two quantities can be faceted side by side.
        .melt(
            id_vars=["antibody_concentration", "n_subs"],
            var_name="censored",
            value_name="probability escape",
        )
    )

    mean_prob_escape_chart = (
        alt.Chart(mean_prob_escape)
        .encode(
            x=alt.X("antibody_concentration"),
            # symlog (unlike a plain log scale) can display the zeros that
            # censoring to [0, 1] may produce.
            y=alt.Y(
                "probability escape",
                scale=alt.Scale(type="symlog", constant=0.05),
            ),
            column=alt.Column("censored", title=None),
            color=alt.Color("n_subs", title="n substitutions"),
            # Float columns are shown to 3 significant digits; others as-is.
            tooltip=[
                alt.Tooltip(c, format=".3g") if mean_prob_escape[c].dtype == float else c
                for c in mean_prob_escape.columns
            ],
        )
        .mark_line(point=True, size=0.5)
        .properties(width=200, height=125)
        .configure_axis(grid=False)
    )

    return mean_prob_escape_chart

In [23]:
def generate_model(
    prob_escape_df,
    n_epitopes=1,
    reg_escape_weight=0.1,
    times_seen=3,
):
    """Fit a ``polyclonal.Polyclonal`` model and return its mutation-escape plot.

    Despite the name, the fitted model itself is not returned -- only the
    chart produced by ``Polyclonal.mut_escape_plot``.

    Parameters
    ----------
    prob_escape_df : pandas.DataFrame
        Probability-of-escape table. Its ``antibody_concentration`` and
        ``aa_substitutions_reference`` columns are renamed to
        ``concentration`` and ``aa_substitutions`` before being passed to
        ``Polyclonal`` as ``data_to_fit``; other columns (presumably
        including ``prob_escape``) are passed through unchanged.
    n_epitopes : int, default 1
        Number of epitopes passed to ``Polyclonal``.
    reg_escape_weight : float, default 0.1
        Passed to ``Polyclonal.fit`` as ``reg_escape_weight``.
    times_seen : int, default 3
        Value given for the ``times_seen`` slider via ``addtl_slider_stats``.

    Returns
    -------
    The chart returned by ``Polyclonal.mut_escape_plot`` (called with
    ``init_floor_at_zero=False`` and ``show_heatmap=False``).
    """
    model = polyclonal.Polyclonal(
        n_epitopes=n_epitopes,
        data_to_fit=prob_escape_df.rename(
            columns={
                "antibody_concentration": "concentration",
                "aa_substitutions_reference": "aa_substitutions",
            }
        ),
        alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
    )

    # Fit the model, capturing (and discarding) the fitting log printed every
    # `logfreq` steps so it does not clutter the notebook.
    with io.capture_output():
        model.fit(logfreq=200, reg_escape_weight=reg_escape_weight)

    return model.mut_escape_plot(
        addtl_slider_stats={"times_seen": times_seen},
        init_floor_at_zero=False,
        show_heatmap=False,
    )

In [7]:
prob_escape_1 = pd.read_csv(
    "results/prob_escape/libA_230403_1_2380_1_prob_escape.csv",
    # Only the literal string "nan" is treated as missing, so e.g. an empty
    # `aa_substitutions_reference` (no substitutions) stays "" rather than NaN.
    keep_default_na=False,
    na_values="nan",
).query(
    # Bare names in `query` resolve to columns (locals would need `@`), so
    # the threshold is a per-row column of the CSV.
    "`no-antibody_count` >= no_antibody_count_threshold"
)

In [8]:
prob_escape_2 = pd.read_csv(
    "results/prob_escape/libA_230419_1_2380_2_prob_escape.csv",
    # Only the literal string "nan" is treated as missing, so e.g. an empty
    # `aa_substitutions_reference` (no substitutions) stays "" rather than NaN.
    keep_default_na=False,
    na_values="nan",
).query(
    # Bare names in `query` resolve to columns (locals would need `@`), so
    # the threshold is a per-row column of the CSV.
    "`no-antibody_count` >= no_antibody_count_threshold"
)

In [9]:
# Number of distinct barcoded variants at each antibody concentration.
display(
    prob_escape_1.groupby("antibody_concentration")["barcode"]
    .nunique()
    .to_frame("n_variants")
)

Unnamed: 0_level_0,n_variants
antibody_concentration,Unnamed: 1_level_1
0.0096,28153
0.0191,28153
0.0383,28153
0.0765,28153


In [10]:
# Number of distinct barcoded variants at each antibody concentration.
display(
    prob_escape_2.groupby("antibody_concentration")["barcode"]
    .nunique()
    .to_frame("n_variants")
)

Unnamed: 0_level_0,n_variants
antibody_concentration,Unnamed: 1_level_1
0.0191,28140


In [15]:
# Replace the run-1 measurements at every concentration that run 2 re-measured
# (only 0.0191, per the table above) with the run-2 data. This frame gets its
# own name because `prob_escape_1_filt` is assigned a different filter later.
prob_escape_1_not_rerun = prob_escape_1.loc[
    ~prob_escape_1["antibody_concentration"].isin(
        prob_escape_2["antibody_concentration"].unique()
    )
]

# ignore_index avoids duplicate row labels from the two source frames.
prob_escape_combined = pd.concat(
    [prob_escape_1_not_rerun, prob_escape_2], ignore_index=True
)

display(
    prob_escape_combined.groupby("antibody_concentration").aggregate(
        n_variants=pd.NamedAgg("barcode", "nunique")
    )
)

Unnamed: 0_level_0,n_variants
antibody_concentration,Unnamed: 1_level_1
0.0096,28153
0.0191,28140
0.0383,28153
0.0765,28153


In [19]:
# Mean escape by number of substitutions, run-1 file only.
plot_avg_escape(prob_escape=prob_escape_1)

In [20]:
# Same plot with the run-2 data substituted at 0.0191.
plot_avg_escape(prob_escape=prob_escape_combined)

In [28]:
# Fit on the run-1 file only.
generate_model(prob_escape_df=prob_escape_1)

In [29]:
# Fit on the combined run-1 / run-2 data.
generate_model(prob_escape_df=prob_escape_combined)

In [26]:
# Drop the lowest concentration: 0.01 separates 0.0096 from all the others
# (>= 0.0191) in the tables above.
prob_escape_1_filt = prob_escape_1.query("antibody_concentration > 0.01")
prob_escape_comb_filt = prob_escape_combined.query("antibody_concentration > 0.01")

In [27]:
# Refit run 1 without the lowest concentration.
generate_model(prob_escape_df=prob_escape_1_filt)

In [30]:
# Refit the combined data without the lowest concentration.
generate_model(prob_escape_df=prob_escape_comb_filt)

In [31]:
# Mean escape for run 1 without the lowest concentration.
plot_avg_escape(prob_escape=prob_escape_1_filt)

In [33]:
# Mean escape for the combined data without the lowest concentration.
plot_avg_escape(prob_escape=prob_escape_comb_filt)

## 3856

In [34]:
prob_escape_1 = pd.read_csv(
    "results/prob_escape/libA_230403_1_3856_1_prob_escape.csv",
    # Only the literal string "nan" is treated as missing, so e.g. an empty
    # `aa_substitutions_reference` (no substitutions) stays "" rather than NaN.
    keep_default_na=False,
    na_values="nan",
).query(
    # Bare names in `query` resolve to columns (locals would need `@`), so
    # the threshold is a per-row column of the CSV.
    "`no-antibody_count` >= no_antibody_count_threshold"
)

In [35]:
prob_escape_2 = pd.read_csv(
    "results/prob_escape/libA_230419_1_3856_2_prob_escape.csv",
    # Only the literal string "nan" is treated as missing, so e.g. an empty
    # `aa_substitutions_reference` (no substitutions) stays "" rather than NaN.
    keep_default_na=False,
    na_values="nan",
).query(
    # Bare names in `query` resolve to columns (locals would need `@`), so
    # the threshold is a per-row column of the CSV.
    "`no-antibody_count` >= no_antibody_count_threshold"
)

In [36]:
# Number of distinct barcoded variants at each antibody concentration.
display(
    prob_escape_1.groupby("antibody_concentration")["barcode"]
    .nunique()
    .to_frame("n_variants")
)

Unnamed: 0_level_0,n_variants
antibody_concentration,Unnamed: 1_level_1
0.0097,28153
0.0194,28153
0.0388,28153
0.0776,28153


In [37]:
# Number of distinct barcoded variants at each antibody concentration.
display(
    prob_escape_2.groupby("antibody_concentration")["barcode"]
    .nunique()
    .to_frame("n_variants")
)

Unnamed: 0_level_0,n_variants
antibody_concentration,Unnamed: 1_level_1
0.0097,28140
0.0194,28140


In [48]:
# Keep run-1 concentrations above 0.02 (0.0388, 0.0776), i.e. drop the two
# (0.0097, 0.0194) that run 2 re-measured.
run1_filt = prob_escape_1.query("antibody_concentration > 0.02")

In [47]:
# NOTE(review): this cell's saved output (In [47]) predates the cell defining
# `run1_filt` (In [48]), so it may reflect an earlier definition; re-run the
# notebook top to bottom to refresh it.
plot_avg_escape(prob_escape=run1_filt)

In [40]:
# Mean escape for the run-2 file only.
plot_avg_escape(prob_escape=prob_escape_2)

In [43]:
# Run-2 rows at the higher of its two concentrations.
run2_c2 = prob_escape_2.query("antibody_concentration == 0.0194")

In [44]:
# Run-2 rows at the lower of its two concentrations.
run2_c1 = prob_escape_2.query("antibody_concentration == 0.0097")

In [51]:
# Mean escape for the run-1 file only.
plot_avg_escape(prob_escape=prob_escape_1)

In [49]:
# Run 1 above 0.02 plus run 2 at 0.0194.
run1_comb_1 = pd.concat((run1_filt, run2_c2))
plot_avg_escape(prob_escape=run1_comb_1)

In [52]:
# Additionally add run 2 at 0.0097.
run1_comb_2 = pd.concat((run1_comb_1, run2_c1))
plot_avg_escape(prob_escape=run1_comb_2)

In [55]:
# Drop only the lowest run-1 concentration (0.0097); 0.0194 and above remain.
run1_filt1 = prob_escape_1.query("antibody_concentration > 0.01")

In [53]:
# Fit on the run-1 file only.
generate_model(prob_escape_df=prob_escape_1)

In [54]:
# Fit on run 1 (> 0.02) plus both run-2 concentrations.
generate_model(prob_escape_df=run1_comb_2)

In [56]:
# Fit on run 1 without its lowest concentration.
generate_model(prob_escape_df=run1_filt1)

In [57]:
# Fit on run 1 (> 0.02) plus run 2 at 0.0194.
generate_model(prob_escape_df=run1_comb_1)

In [58]:
# Mean escape for run 1 without its lowest concentration.
plot_avg_escape(prob_escape=run1_filt1)

In [59]:
# Mean escape for run 1 (> 0.02) plus run 2 at 0.0194.
plot_avg_escape(prob_escape=run1_comb_1)

In [60]:
# Run 1 without its lowest concentration, plus run 2 at that lowest one.
single_comb = pd.concat((run1_filt1, run2_c1))
plot_avg_escape(prob_escape=single_comb)

In [61]:
# Number of distinct barcoded variants at each antibody concentration.
display(
    single_comb.groupby("antibody_concentration")["barcode"]
    .nunique()
    .to_frame("n_variants")
)

Unnamed: 0_level_0,n_variants
antibody_concentration,Unnamed: 1_level_1
0.0097,28140
0.0194,28153
0.0388,28153
0.0776,28153


In [62]:
# NOTE(review): same call as the earlier run-1 fit (In [53]); it refits the
# same data so the chart sits next to the next cell's result.
generate_model(prob_escape_df=prob_escape_1)

In [63]:
# Fit on run 1 without its lowest concentration plus run 2 at 0.0097.
generate_model(prob_escape_df=single_comb)