In [1]:
import pickle

import altair as alt

import pandas as pd

import polyclonal

import warnings
warnings.filterwarnings('ignore')

from IPython.utils import io

In [2]:
import os

# os.chdir('../../') is relative, so re-running this cell would climb two more
# directory levels each time. Resolve the target (presumably the directory that
# holds results/) once per kernel session and always chdir to that absolute path,
# which makes the cell safe to re-run.
if "REPO_ROOT" not in globals():
    REPO_ROOT = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
os.chdir(REPO_ROOT)

In [4]:
# Serum IDs per cohort, keyed by the run label used in the prob_escape file
# names (date-like strings, presumably YYMMDD).

# pediatric sera
ped_sera_230221 = [
    2367, 3944, 2462, 2389, 2323,
]
ped_sera_230323 = [
    2388, 2463, 3973, 4299, 4584,
]

# teen sera
teen_sera_230317 = [
    2343, 2350, 2365, 2382, 3866,
]
teen_sera_230403 = [
    2380, 3856, 3857, 3862, 3895,
]

In [55]:
def get_prob_escape(sera_list, lib, date, replicate, results_dir="results"):
    """Load per-variant probability-of-escape tables for a list of sera.

    For each serum, reads
    ``{results_dir}/prob_escape/{lib}_{date}_1_{serum}_{replicate}_prob_escape.csv``
    and keeps only the rows whose ``no-antibody_count`` is at least that row's
    ``no_antibody_count_threshold``.

    Parameters
    ----------
    sera_list : iterable
        Serum identifiers; each is interpolated into the file name.
    lib : str
        Library name used in the file name, e.g. ``'libA'``.
    date : str
        Run label used in the file name, e.g. ``'230221'``.
    replicate : str
        Replicate suffix used in the file name, e.g. ``'1'``.
    results_dir : str, default "results"
        Directory containing the ``prob_escape`` subdirectory. The default is
        relative to the current working directory (unchanged behavior).

    Returns
    -------
    list of pandas.DataFrame
        One filtered frame per serum, in the same order as ``sera_list``.
    """
    return [
        pd.read_csv(
            f"{results_dir}/prob_escape/{lib}_{date}_1_{serum}_{replicate}_prob_escape.csv",
            # empty fields stay as empty strings; only the literal "nan" is NaN
            keep_default_na=False,
            na_values="nan",
        )
        # both operands are CSV columns; backticks quote the hyphenated name
        .query("`no-antibody_count` >= no_antibody_count_threshold")
        for serum in sera_list
    ]

In [51]:
ped_sera_run1 = get_prob_escape(ped_sera_230221, 'libA', '230221', '1')
ped_sera_run2 = get_prob_escape(ped_sera_230323, 'libA', '230323', '1')
# all pediatric sera from both libA runs (run 1 sera first); must use the
# names defined in this cell so it works on a fresh kernel
ped_sera = ped_sera_run1 + ped_sera_run2

In [41]:
# teen sera from the two libA runs, loaded separately and then combined
teen_sera_run1 = get_prob_escape(
    sera_list=teen_sera_230317, lib='libA', date='230317', replicate='1'
)
teen_sera_run2 = get_prob_escape(
    sera_list=teen_sera_230403, lib='libA', date='230403', replicate='1'
)
teen_sera = [*teen_sera_run1, *teen_sera_run2]

In [73]:
def plot_wt_prob_escape(serum_list):
    """Plot mean probability of escape of unmutated (WT) variants for each serum.

    For every prob-escape frame in ``serum_list``, keeps the variants with no
    entries in ``aa_substitutions_reference`` and averages their (censored)
    ``prob_escape`` per ``antibody_concentration``. Each serum's concentrations
    are then replaced by their 1-based rank in ascending order, presumably so
    that sera run at different dilutions can share one x axis.

    Parameters
    ----------
    serum_list : list of pandas.DataFrame
        One frame per serum (e.g. the output of ``get_prob_escape``), each with
        ``antibody``, ``antibody_concentration``, ``aa_substitutions_reference``
        and ``prob_escape`` columns.

    Returns
    -------
    altair.Chart
        Line chart of mean WT probability of escape vs. concentration rank,
        one colored line per serum.

    Raises
    ------
    ValueError
        If ``serum_list`` is empty.
    """
    if len(serum_list) == 0:
        raise ValueError("serum_list is empty; nothing to plot")

    per_serum = []
    for prob_escape in serum_list:
        # positional lookup: the count-threshold query upstream can drop the
        # row labelled 0, so ['antibody'][0] could raise KeyError
        serum = str(prob_escape["antibody"].iloc[0])

        n_subs = prob_escape["aa_substitutions_reference"].str.split().map(len)
        wt_mean = (
            prob_escape.loc[n_subs == 0]
            # groupby sorts concentrations ascending, which the rank relies on
            .groupby("antibody_concentration", as_index=False)
            .aggregate({"prob_escape": "mean"})
            .rename(columns={"prob_escape": "probability escape"})
        )
        # build new columns on a fresh frame (no chained assignment on a slice)
        wt_mean = wt_mean.assign(
            censored="censored to [0, 1]",
            serum=serum,
            concentration=list(range(1, len(wt_mean) + 1)),
        )
        per_serum.append(wt_mean)

    wt_means = pd.concat(per_serum, axis=0, ignore_index=True)

    return (
        alt.Chart(wt_means)
        .encode(
            # ordinal on integer ranks so that 10 sorts after 9, not before 2
            x=alt.X("concentration:O"),
            y=alt.Y(
                "probability escape",
                scale=alt.Scale(type="symlog", constant=0.05),
            ),
            color=alt.Color("serum", title="serum"),
            # tooltips describe the frame actually being plotted
            tooltip=[
                alt.Tooltip(c, format=".3g") if wt_means[c].dtype == float else c
                for c in wt_means.columns
            ],
        )
        .mark_line(point=True, size=0.5)
        .properties(width=200, height=125, title='avg prob escape of WT variants')
        .configure_axis(grid=False)
        .configure_title(
            dy=-5,
            fontWeight=500
        )
    )

In [74]:
# WT escape curves, all pediatric sera (libA runs 230221 + 230323)
plot_wt_prob_escape(serum_list=ped_sera)

In [75]:
# WT escape curves, all teen sera (libA runs 230317 + 230403)
plot_wt_prob_escape(serum_list=teen_sera)

In [76]:
# teen sera, libA run 230317 only
plot_wt_prob_escape(serum_list=teen_sera_run1)

In [77]:
# teen sera, libA run 230403 only
plot_wt_prob_escape(serum_list=teen_sera_run2)

In [78]:
# pediatric sera, libA run 230221 only
plot_wt_prob_escape(serum_list=ped_sera_run1)

In [79]:
# pediatric sera, libA run 230323 only
plot_wt_prob_escape(serum_list=ped_sera_run2)

In [56]:
# libB pediatric sera; these are the same serum IDs as ped_sera_230221 and
# ped_sera_230323, respectively, but all come from libB run 230407
ped_sera_libB_h6 = [
    2367, 3944, 2462, 2389, 2323,
]
ped_sera_libB_spikein = [
    2388, 2463, 3973, 4299, 4584,
]

ped_sera_libB_potent = get_prob_escape(
    sera_list=ped_sera_libB_h6, lib='libB', date='230407', replicate='1'
)
ped_sera_libB_weak = get_prob_escape(
    sera_list=ped_sera_libB_spikein, lib='libB', date='230407', replicate='2'
)

In [57]:
ped_sera_libB = [*ped_sera_libB_potent, *ped_sera_libB_weak]  # potent first, then weak

In [80]:
# WT escape curves, all pediatric sera in libB (run 230407)
plot_wt_prob_escape(serum_list=ped_sera_libB)

In [81]:
def plot_avg_escape(prob_escape, max_aa_subs=4):
    """Plot mean probability of escape vs. concentration, by substitution count.

    Variants are binned by the number of whitespace-separated entries in
    ``aa_substitutions_reference``; variants with ``max_aa_subs`` or more
    substitutions are pooled into a single ``">{max_aa_subs - 1}"`` bin.
    ``prob_escape`` and ``prob_escape_uncensored`` are each averaged per
    (``antibody_concentration``, bin). The two averages are drawn in separate
    column facets labelled "censored to [0, 1]" and "not censored".

    Parameters
    ----------
    prob_escape : pandas.DataFrame
        A prob-escape table such as one element returned by
        ``get_prob_escape``.
    max_aa_subs : int, default 4
        Substitution count at and above which variants share one bin.

    Returns
    -------
    altair.Chart
        Faceted line chart, one colored line per substitution bin.
    """

    def n_subs_label(n):
        # pooled bin is labelled ">k"; smaller counts keep their exact value
        return str(n) if n < max_aa_subs else f">{max_aa_subs - 1}"

    mean_prob_escape = (
        prob_escape.assign(
            n_subs=lambda x: (
                x["aa_substitutions_reference"]
                .str.split()
                .map(len)
                .clip(upper=max_aa_subs)
                .map(n_subs_label)
            )
        )
        .groupby(["antibody_concentration", "n_subs"], as_index=False)
        .aggregate({"prob_escape": "mean", "prob_escape_uncensored": "mean"})
        .rename(
            columns={
                "prob_escape": "censored to [0, 1]",
                "prob_escape_uncensored": "not censored",
            }
        )
        # long format: one row per (concentration, bin, censoring variant)
        .melt(
            id_vars=["antibody_concentration", "n_subs"],
            var_name="censored",
            value_name="probability escape",
        )
    )

    return (
        alt.Chart(mean_prob_escape)
        .encode(
            x=alt.X("antibody_concentration"),
            y=alt.Y(
                "probability escape",
                scale=alt.Scale(type="symlog", constant=0.05),
            ),
            column=alt.Column("censored", title=None),
            color=alt.Color("n_subs", title="n substitutions"),
            tooltip=[
                alt.Tooltip(c, format=".3g") if mean_prob_escape[c].dtype == float else c
                for c in mean_prob_escape.columns
            ],
        )
        .mark_line(point=True, size=0.5)
        .properties(width=200, height=125)
        .configure_axis(grid=False)
    )

In [100]:
def generate_model(
    prob_escape_df,
    n_epitopes=1,
    reg_escape_weight=0.1,
):
    """Fit a ``polyclonal.Polyclonal`` model to one prob-escape table.

    Parameters
    ----------
    prob_escape_df : pandas.DataFrame
        Prob-escape data, e.g. one element returned by ``get_prob_escape``.
        Its ``antibody_concentration`` and ``aa_substitutions_reference``
        columns are renamed to ``concentration`` and ``aa_substitutions``
        before being passed as ``data_to_fit``.
    n_epitopes : int, default 1
        Number of epitopes in the model.
    reg_escape_weight : float, default 0.1
        Forwarded to ``Polyclonal.fit``.

    Returns
    -------
    polyclonal.Polyclonal
        The fitted model.
    """
    model = polyclonal.Polyclonal(
        n_epitopes=n_epitopes,
        data_to_fit=prob_escape_df.rename(
            columns={
                "antibody_concentration": "concentration",
                "aa_substitutions_reference": "aa_substitutions",
            }
        ),
        alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
    )

    # fit model, suppressing its printed log output to avoid notebook clutter
    with io.capture_output():
        model.fit(
            logfreq=200,
            reg_escape_weight=reg_escape_weight,
        )

    # (a mut_escape_plot used to be built here and then discarded unused;
    # call model.mut_escape_plot(...) on the returned model to display it)
    return model

In [83]:
# serum 3944, libA run 230221, replicate 1. Reuse the shared loader so the file
# path, NA handling and no-antibody count filter match the sera lists above.
prob_escape_3944 = get_prob_escape([3944], 'libA', '230221', '1')[0]

prob_escape_3944.head()

Unnamed: 0,library,antibody_sample,no-antibody_sample,aa_substitutions_sequential,n_aa_substitutions,barcode,prob_escape,prob_escape_uncensored,antibody_count,no-antibody_count,antibody_neut_standard_count,no-antibody_neut_standard_count,no_antibody_count_threshold,antibody_count_threshold,aa_substitutions_reference,retain,antibody,antibody_concentration
0,libA,230221_1_antibody_3944_0.004272_1,230221_1_no-antibody_control_1,K297I,1,ATAACACAAAAAAGTA,0.0025,0.0025,64222,369330,5953821,86955,23,,K278I,True,3944,0.0043
1,libA,230221_1_antibody_3944_0.004272_1,230221_1_no-antibody_control_1,D123H K208E,2,AAGCCACAAGGTACTA,0.1891,0.1891,62174,4803,5953821,86955,23,,D104H K189E,True,3944,0.0043
2,libA,230221_1_antibody_3944_0.004272_1,230221_1_no-antibody_control_1,T49Y K140N K154T F156S S163K Y178S T179K K208H...,11,ACAATCACGGTACTCC,0.1503,0.1503,44440,4317,5953821,86955,23,,T30Y K121N K135T F137S S144K Y159S T160K K189H...,True,3944,0.0043
3,libA,230221_1_antibody_3944_0.004272_1,230221_1_no-antibody_control_1,P122V S212D Q216K G405D,4,AGACCGGGACTCCTCA,0.1968,0.1968,42644,3164,5953821,86955,23,,P103V S193D Q197K G386D,True,3944,0.0043
4,libA,230221_1_antibody_3944_0.004272_1,230221_1_no-antibody_control_1,S143D K208E I261T R280P,4,CTTTTGTTAATTGATA,0.1436,0.1436,42460,4318,5953821,86955,23,,S124D K189E I242T R261P,True,3944,0.0043


In [84]:
# libA serum 3944, all concentrations
plot_avg_escape(prob_escape=prob_escape_3944)

In [110]:
# NOTE(review): the 0.002 cutoff drops the lowest concentration(s); the
# rationale is not recorded here — document it next to the plot above.
filt_3944 = prob_escape_3944.query("antibody_concentration > 0.002")

In [111]:
# libA serum 3944 after dropping low concentrations
plot_avg_escape(prob_escape=filt_3944)

In [96]:
# NOTE(review): this fit is only displayed and then discarded; the
# `libA_3944_model = generate_model(filt_3944)` cell below fits it again.
generate_model(prob_escape_df=filt_3944)

In [103]:
# serum 3944, libB run 230407, replicate 1. Reuse the shared loader so the file
# path, NA handling and no-antibody count filter match the other sera.
libB_3944 = get_prob_escape([3944], 'libB', '230407', '1')[0]

libB_3944.head()

Unnamed: 0,library,antibody_sample,no-antibody_sample,aa_substitutions_sequential,n_aa_substitutions,barcode,prob_escape,prob_escape_uncensored,antibody_count,no-antibody_count,antibody_neut_standard_count,no-antibody_neut_standard_count,no_antibody_count_threshold,antibody_count_threshold,aa_substitutions_reference,retain,antibody,antibody_concentration
0,libB,230407_1_antibody_3944_0.006408_1,230407_1_no-antibody_control_1,K208V L263Y,2,CCAATCGCTGACAATA,0.0732,0.0732,142449,48782,20882384,523496,38,,K189V L244Y,True,3944,0.0064
1,libB,230407_1_antibody_3944_0.006408_1,230407_1_no-antibody_control_1,K208E L263S R280Y Q473M W556L,5,CCCATCCTTTAGTCAG,0.0934,0.0934,25764,6912,20882384,523496,38,,K189E L244S R261Y Q454M W537L,True,3944,0.0064
2,libB,230407_1_antibody_3944_0.006408_1,230407_1_no-antibody_control_1,K208E R227G I233L,3,GCTAAATAGGCTAATC,0.0927,0.0927,22072,5972,20882384,523496,38,,K189E R208G I214L,True,3944,0.0064
3,libB,230407_1_antibody_3944_0.006408_1,230407_1_no-antibody_control_1,K208V L263Q,2,GGCACTCGCTATAAAC,0.0626,0.0626,19292,7722,20882384,523496,38,,K189V L244Q,True,3944,0.0064
4,libB,230407_1_antibody_3944_0.006408_1,230407_1_no-antibody_control_1,N57Y Y113T K208I I233E S388H,5,AATCATACGATTGGAC,0.0752,0.0752,17626,5879,20882384,523496,38,,N38Y Y94T K189I I214E S369H,True,3944,0.0064


In [93]:
# libB serum 3944, all concentrations
plot_avg_escape(prob_escape=libB_3944)

In [112]:
libA_3944_model = generate_model(prob_escape_df=filt_3944)  # libA, filtered concentrations

In [113]:
libB_3944_model = generate_model(prob_escape_df=libB_3944)  # libB, all concentrations

In [114]:
# one row per fitted model, in the layout handed to PolyclonalAverage below
libs = ['libA', 'libB']
replicates = ['1', '1']
models = [libA_3944_model, libB_3944_model]

models_df = pd.DataFrame(
    list(zip(libs, replicates, models)),
    columns=['library', 'replicate', 'model'],
)

models_df

Unnamed: 0,library,replicate,model
0,libA,1,<polyclonal.polyclonal.Polyclonal object at 0x...
1,libB,1,<polyclonal.polyclonal.Polyclonal object at 0x...


In [115]:
# Combine the libA and libB serum-3944 models (the rows of models_df).
avg_model = polyclonal.PolyclonalAverage(models_df)
# NOTE(review): per the method name, a heatmap of correlations between the
# models' mutation-escape values — confirm against the polyclonal docs.
avg_model.mut_escape_corr_heatmap()