In [1]:
import pickle

import altair as alt

import pandas as pd

import polyclonal

import warnings
warnings.filterwarnings('ignore')

from IPython.utils import io

In [2]:
import os

# Move from the kernel's starting directory (normally this notebook's folder)
# two levels up, so the relative 'results/...' paths used below resolve.
# Guarded so that re-running this cell does not keep walking further up.
if 'PROJECT_ROOT' not in globals():
    os.chdir(os.path.join(os.pardir, os.pardir))
    PROJECT_ROOT = os.getcwd()
else:
    os.chdir(PROJECT_ROOT)

In [4]:
def get_prob_escape(sera_list, lib, date, replicate):
    """Load the per-serum probability-of-escape tables for one library/date/replicate.

    For each serum, reads
    ``results/prob_escape/{lib}_{date}_1_{serum}_{replicate}_prob_escape.csv``
    (relative to the current working directory) and keeps only rows whose
    ``no-antibody_count`` is at least ``no_antibody_count_threshold``.

    Parameters
    ----------
    sera_list : iterable
        Serum identifiers; each is formatted into the file name.
    lib : str
        Library name formatted into the file name (e.g. 'libA').
    date : str
        Date string formatted into the file name (e.g. '230419').
    replicate : str
        Replicate label formatted into the file name.

    Returns
    -------
    list of pandas.DataFrame
        One filtered table per serum, in the same order as ``sera_list``.
    """
    def _load_one(serum):
        csv_path = f'results/prob_escape/{lib}_{date}_1_{serum}_{replicate}_prob_escape.csv'
        # keep_default_na=False keeps empty strings (presumably variants with no
        # substitutions) as '' instead of NaN; only the literal "nan" is missing.
        raw_table = pd.read_csv(csv_path, keep_default_na=False, na_values="nan")
        # NOTE(review): a bare name inside DataFrame.query resolves against the
        # columns, so this assumes the CSV has a `no_antibody_count_threshold`
        # column — confirm (a Python variable would need an '@' prefix).
        return raw_table.query("`no-antibody_count` >= no_antibody_count_threshold")

    return [_load_one(serum) for serum in sera_list]

In [6]:
# Serum identifiers per age cohort. Each cohort is split into two batches
# (_1 / _2) whose prob-escape tables are loaded separately; the combined
# *_names lists keep batch order so labels line up with the concatenated tables.
# Adult IDs are strings; pediatric and teen IDs are ints.
adult_sera_1 = ['33C', '34C', '197C', '199C', '215C']
adult_sera_2 = ['210C', '74C', '68C', '150C', '18C']
adult_sera_names = [*adult_sera_1, *adult_sera_2]

ped_sera_1 = [2367, 3944, 2462, 2389, 2323]
ped_sera_2 = [2388, 2463, 3973, 4299, 4584]
ped_sera_names = [*ped_sera_1, *ped_sera_2]

teen_sera_1 = [2343, 2350, 2365, 2382, 3866]
teen_sera_2 = [2380, 3856, 3857, 3862, 3895]
teen_sera_names = [*teen_sera_1, *teen_sera_2]

In [7]:
# libA prob-escape tables. Each cohort's two serum batches have different
# dates; the per-batch lists are concatenated in batch order so they match
# the corresponding *_sera_names lists.
adult_1 = get_prob_escape(adult_sera_1, 'libA', '230419', '1')
adult_2 = get_prob_escape(adult_sera_2, 'libA', '230425', '1')
adult_prob_escape = [*adult_1, *adult_2]

ped_1 = get_prob_escape(ped_sera_1, 'libA', '230221', '1')
ped_2 = get_prob_escape(ped_sera_2, 'libA', '230323', '1')
ped_prob_escape = [*ped_1, *ped_2]

teen_1 = get_prob_escape(teen_sera_1, 'libA', '230317', '1')
teen_2 = get_prob_escape(teen_sera_2, 'libA', '230403', '1')
teen_prob_escape = [*teen_1, *teen_2]

In [11]:
def get_filtered_models(prob_escape_list, site=160):
    """Fit a one-epitope polyclonal model per table on variants mutated at ``site``.

    For each prob-escape table, keeps only rows whose
    ``aa_substitutions_reference`` contains a substitution at ``site`` written
    as uppercase wildtype letter + site + uppercase mutant letter (e.g.
    ``K160T``). It then fits a ``polyclonal.Polyclonal`` model with one epitope
    on that subset.

    Parameters
    ----------
    prob_escape_list : list of pandas.DataFrame
        Tables such as those returned by ``get_prob_escape``. Each must have
        ``aa_substitutions_reference`` and ``antibody_concentration`` columns.
    site : int, default 160
        Site whose mutants define the subset to fit.

    Returns
    -------
    list of polyclonal.Polyclonal
        Fitted models, one per input table, in input order.
    """
    # No capture group: pandas warns when str.contains() gets a regex with match
    # groups (that warning is hidden by the global filterwarnings). The letter
    # classes exclude non-letter mutants such as stop or gap characters.
    site_pattern = rf'[A-Z]{site}[A-Z]'

    filt_model_list = []
    for prob_escape in prob_escape_list:
        # na=False: a missing substitution string counts as "no match" rather
        # than yielding a NaN mask that cannot be used for boolean indexing.
        has_site_mut = prob_escape['aa_substitutions_reference'].str.contains(
            site_pattern, na=False
        )
        prob_escape_filt = prob_escape[has_site_mut]

        model = polyclonal.Polyclonal(
            n_epitopes=1,
            data_to_fit=prob_escape_filt.rename(
                columns={
                    "antibody_concentration": "concentration",
                    "aa_substitutions_reference": "aa_substitutions",
                }
            ),
            alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
        )

        # fit model, suppressing output text to avoid clutter in notebook
        with io.capture_output():
            model.fit(
                logfreq=200,
                reg_escape_weight=0.1,
            )

        filt_model_list.append(model)

    return filt_model_list

In [32]:
def get_agg_escapes(serum_list, model_list, site_list, age_group, agg='mean'):
    """Aggregate per-mutation escape to per-site values for a set of fitted models.

    For each model, groups ``model.mut_escape_df`` by ``site`` and aggregates
    its ``escape`` column with ``agg``. It keeps only sites in ``site_list`` and
    tags the rows with the matching serum label and ``age_group``.

    Parameters
    ----------
    serum_list : sequence
        Serum labels, one per model, in the same order as ``model_list``.
    model_list : sequence
        Fitted models exposing a ``mut_escape_df`` DataFrame with ``site`` and
        ``escape`` columns.
    site_list : iterable
        Sites to keep, compared against the raw ``site`` values (before they
        are converted to strings below).
    age_group : str
        Cohort label written to the ``age_group`` column.
    agg : str or callable, default 'mean'
        Aggregation passed to ``DataFrame.aggregate`` (e.g. 'mean', 'sum').

    Returns
    -------
    pandas.DataFrame
        Columns ``site`` (as str), ``escape``, ``serum`` and ``age_group``.

    Raises
    ------
    ValueError
        If ``serum_list`` and ``model_list`` differ in length; otherwise sera
        would be silently mislabeled or dropped.
    """
    if len(serum_list) != len(model_list):
        raise ValueError(
            f"serum_list has {len(serum_list)} entries but model_list has "
            f"{len(model_list)}; they must correspond one-to-one"
        )

    per_model_escapes = []
    for serum, model in zip(serum_list, model_list):
        # NOTE(review): grouping by site alone pools all epitopes; fine for the
        # one-epitope models in this notebook, but would mix epitopes otherwise.
        site_escape = (
            model.mut_escape_df
            .groupby('site', as_index=False)
            .aggregate({'escape': agg})
        )
        per_model_escapes.append(
            site_escape.loc[site_escape['site'].isin(site_list)]
            # assign() builds a new frame, so no SettingWithCopyWarning from
            # writing columns onto a filtered slice.
            .assign(serum=serum, age_group=age_group)
        )

    agg_escape = pd.concat(per_model_escapes)
    # Sites as strings so they are treated as categorical labels downstream.
    agg_escape['site'] = agg_escape['site'].astype(str)

    return agg_escape

In [17]:
# Sites 155-159 and 186-198 (inclusive).
# NOTE(review): presumably HA antigenic region B, going by the name — confirm.
b_sites = [*range(155, 160), *range(186, 199)]

In [14]:
# Contiguous sites 122-146 (inclusive).
# NOTE(review): presumably HA antigenic region A, going by the name — confirm.
a_sites = list(range(122, 147))

In [72]:
# libB prob-escape tables (pediatric and teen cohorts only). Both batches of a
# cohort share a date and are distinguished by the replicate label '1' / '2'.
# NOTE: this overwrites the libA ped_1/ped_2/teen_1/teen_2 names.
ped_1 = get_prob_escape(ped_sera_1, 'libB', '230407', '1')
ped_2 = get_prob_escape(ped_sera_2, 'libB', '230407', '2')
ped_prob_escape_b = [*ped_1, *ped_2]

teen_1 = get_prob_escape(teen_sera_1, 'libB', '230412', '1')
teen_2 = get_prob_escape(teen_sera_2, 'libB', '230412', '2')
teen_prob_escape_b = [*teen_1, *teen_2]

In [28]:
# Fit the site-160-filtered models for every serum in each libA age cohort.
# Keys are cohort labels; values are lists of fitted models in serum order.
data = [ped_prob_escape, teen_prob_escape, adult_prob_escape]

filtered_models = {
    age_group: get_filtered_models(prob_escape_list)
    for age_group, prob_escape_list in zip(['0-5', '15-18', '40-45'], data)
}

{'0-5': [<polyclonal.polyclonal.Polyclonal at 0x7f598e7cf6d0>,
  <polyclonal.polyclonal.Polyclonal at 0x7f58be8e2ad0>,
  <polyclonal.polyclonal.Polyclonal at 0x7f58c11674d0>,
  <polyclonal.polyclonal.Polyclonal at 0x7f58bea6bad0>,
  <polyclonal.polyclonal.Polyclonal at 0x7f58bed27110>,
  <polyclonal.polyclonal.Polyclonal at 0x7f58be5bb7d0>,
  <polyclonal.polyclonal.Polyclonal at 0x7f58be72e910>,
  <polyclonal.polyclonal.Polyclonal at 0x7f58be7583d0>,
  <polyclonal.polyclonal.Polyclonal at 0x7f58beaee910>,
  <polyclonal.polyclonal.Polyclonal at 0x7f58bea2b0d0>],
 '15-18': [<polyclonal.polyclonal.Polyclonal at 0x7f58bee96910>,
  <polyclonal.polyclonal.Polyclonal at 0x7f58bed92150>,
  <polyclonal.polyclonal.Polyclonal at 0x7f58cf978e90>,
  <polyclonal.polyclonal.Polyclonal at 0x7f58be634b10>,
  <polyclonal.polyclonal.Polyclonal at 0x7f58c23510d0>,
  <polyclonal.polyclonal.Polyclonal at 0x7f58bf008d10>,
  <polyclonal.polyclonal.Polyclonal at 0x7f58c33d5410>,
  <polyclonal.polyclonal.Polycl

In [73]:
# Same as the libA fits, but for the libB pediatric and teen tables.
data = [ped_prob_escape_b, teen_prob_escape_b]

filtered_models_b = {
    age_group: get_filtered_models(prob_escape_list)
    for age_group, prob_escape_list in zip(['0-5', '15-18'], data)
}

In [44]:
def plot_agg_escape(model_dict, sera_names, sites, agg_type):
    """Line plot of per-site aggregated escape for each serum, faceted by age cohort.

    Parameters
    ----------
    model_dict : dict
        Maps age-group label -> list of fitted models (one per serum).
    sera_names : sequence of lists
        Serum labels for each entry of ``model_dict``, in the dict's
        iteration order.
    sites : list
        Sites to show. This list also fixes the left-to-right x-axis order.
    agg_type : str
        Per-site aggregation passed to ``get_agg_escapes`` ('mean', 'sum', ...).
        It is also used in the y-axis title.

    Returns
    -------
    altair chart
        Faceted layer of per-serum lines plus a zero reference rule.

    Raises
    ------
    ValueError
        If ``sera_names`` and ``model_dict`` have different lengths.
    """
    if len(sera_names) != len(model_dict):
        raise ValueError(
            f"got {len(sera_names)} serum-name lists for {len(model_dict)} "
            "age groups; they must correspond one-to-one"
        )

    # Pair each age group with its serum names by position (dict order).
    agg_escape_full = pd.concat([
        get_agg_escapes(group_sera, models, sites, group, agg=agg_type)
        for (group, models), group_sera in zip(model_dict.items(), sera_names)
    ])

    # get_agg_escapes returns `site` as strings, which a nominal axis sorts
    # lexicographically ('103' before '50'); pin the axis order to `sites`.
    site_order = [str(site) for site in sites]

    # generate altair plot
    agg_escape_lineplot = (
        alt.Chart()
        .encode(
            x=alt.X("site:N", title="site", sort=site_order),
            y=alt.Y("escape", title=f"{agg_type}_escape"),
            color=alt.Color('serum:N',
                            legend=alt.Legend(orient="right", title='sera')),
        )
        .mark_line(size=1, opacity=0.7)
        .properties(width=400, height=200)
    )

    # Horizontal reference rule at escape = 0.
    zero_rule = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(opacity=0.5).encode(y='y')

    chart = alt.layer(
        agg_escape_lineplot, zero_rule, data=agg_escape_full
    ).facet(
        facet=alt.Facet('age_group:N', title='age cohort'),
        columns=2,
    ).configure_axis(
        grid=False
    ).resolve_axis(
        x='independent'
    )

    return chart

In [45]:
# One serum-name list per libA cohort, in the same order as the keys of
# filtered_models ('0-5', '15-18', '40-45').
sera_names = [ped_sera_names, teen_sera_names, adult_sera_names]

plot_agg_escape(filtered_models, sera_names, sites=b_sites, agg_type='sum')

In [74]:
# libB data are only loaded for the pediatric and teen cohorts.
sera_names_b = [ped_sera_names, teen_sera_names]

plot_agg_escape(filtered_models_b, sera_names_b, sites=b_sites, agg_type='sum')

In [46]:
plot_agg_escape(filtered_models, sera_names, sites=b_sites, agg_type='mean')

In [76]:
# libB data are only loaded for the pediatric and teen cohorts.
sera_names_b = [ped_sera_names, teen_sera_names]

plot_agg_escape(filtered_models_b, sera_names_b, sites=b_sites, agg_type='mean')

In [47]:
plot_agg_escape(filtered_models, sera_names, sites=a_sites, agg_type='sum')

In [77]:
# libB data are only loaded for the pediatric and teen cohorts.
sera_names_b = [ped_sera_names, teen_sera_names]

plot_agg_escape(filtered_models_b, sera_names_b, sites=a_sites, agg_type='sum')

In [48]:
plot_agg_escape(filtered_models, sera_names, sites=a_sites, agg_type='mean')

In [78]:
# libB data are only loaded for the pediatric and teen cohorts.
sera_names_b = [ped_sera_names, teen_sera_names]

plot_agg_escape(filtered_models_b, sera_names_b, sites=a_sites, agg_type='mean')

In [58]:
plot_agg_escape(filtered_models, sera_names, sites=[*a_sites, *b_sites], agg_type='mean')

In [52]:
plot_agg_escape(filtered_models, sera_names, sites=[*a_sites, *b_sites], agg_type='sum')

In [56]:
# Hand-picked, non-contiguous sites of interest (ascending).
site_list = [
    50, 62, 82, 94, 103, 121, 122, 124, 131, 135, 137,
    138, 145, 156, 157, 159, 188, 189, 193, 220, 224, 276,
]

In [57]:
plot_agg_escape(filtered_models, sera_names, sites=site_list, agg_type='sum')

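## Scratch code

Exploratory cells kept for reference. They are not part of the main analysis above and may rely on kernel state that this notebook does not define (e.g. `test`), so they may fail after a kernel restart.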

In [9]:
def generate_model(prob_escape_df, n_epitopes=1):
    """Fit a polyclonal model to one prob-escape table and return its escape plot.

    Despite the name, the fitted model is not returned; only the chart from
    ``Polyclonal.mut_escape_plot`` is.

    Parameters
    ----------
    prob_escape_df : pandas.DataFrame
        Table with ``antibody_concentration`` and ``aa_substitutions_reference``
        columns. These are renamed to ``concentration`` / ``aa_substitutions``
        before fitting.
    n_epitopes : int, default 1
        Number of epitopes for the polyclonal model.

    Returns
    -------
    The object returned by ``mut_escape_plot`` (an interactive chart).
    """
    fit_data = prob_escape_df.rename(
        columns={
            "antibody_concentration": "concentration",
            "aa_substitutions_reference": "aa_substitutions",
        }
    )
    fitted = polyclonal.Polyclonal(
        n_epitopes=n_epitopes,
        data_to_fit=fit_data,
        alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
    )

    # Silence the optimizer's progress output so it doesn't clutter the notebook.
    with io.capture_output():
        fitted.fit(logfreq=200, reg_escape_weight=0.1)

    return fitted.mut_escape_plot(
        addtl_slider_stats={"times_seen": 3},
        init_floor_at_zero=False,
    )

In [11]:
import re

In [30]:
# Substitution pattern: uppercase wildtype letter + site 160 + uppercase mutant.
pat = r'[A-Z](160)[A-Z]'
# FIXME: `test` is not defined anywhere in this notebook; this cell only runs
# against leftover kernel state.
test_filt = test.loc[test['aa_substitutions_reference'].str.contains(pat)]

In [29]:
generate_model(test, n_epitopes=1)  # FIXME: `test` is undefined in this notebook

In [28]:
generate_model(test_filt, n_epitopes=1)

In [32]:
# Reload only the first pediatric libA batch for quick per-serum inspection.
ped_sera_list = [2367, 3944, 2462, 2389, 2323]
ped_1 = get_prob_escape(sera_list=ped_sera_list, lib='libA', date='230221', replicate='1')

In [36]:
# Third table of the batch, restricted to variants matching `pat` (site 160).
ped_test = ped_1[2]
ped_test_filt = ped_test.loc[ped_test['aa_substitutions_reference'].str.contains(pat)]

In [37]:
generate_model(ped_test_filt, n_epitopes=1)

In [38]:
# Fourth table of the batch, restricted to variants matching `pat` (site 160).
ped_test = ped_1[3]
ped_test_filt = ped_test.loc[ped_test['aa_substitutions_reference'].str.contains(pat)]

generate_model(ped_test_filt, n_epitopes=1)

In [44]:
# Site-160-filtered models for the adult sera. `adults` is not defined anywhere
# in this notebook; adult_prob_escape (adult_1 + adult_2, in adult_sera order)
# appears to be the intended input.
adults_list = get_filtered_models(adult_prob_escape)

In [69]:
adult_sera = adult_sera_1 + adult_sera_2

# `get_mean_escapes` is not defined in this notebook; get_agg_escapes with its
# default agg='mean' computes per-site mean escape and already returns `site`
# as strings, so no extra astype(str) is needed.
mean_escapes = get_agg_escapes(adult_sera, adults_list, b_sites, '40-45')

# Drop site 160 (the site used to select variants) if present.
mean_escapes = mean_escapes.loc[mean_escapes['site'] != '160']

In [76]:
# Per-site mean escape over the A sites for the adult sera. `get_mean_escapes` is
# not defined in this notebook; get_agg_escapes (default agg='mean') does this
# and already returns `site` as strings.
mean_escapes_a = get_agg_escapes(adult_sera, adults_list, a_sites, '40-45')

In [70]:
# Per-site escape lines for each serum in `mean_escapes`, faceted by age group.
# NOTE: despite the variable name, the y values plotted here are means.
summed_escape_lineplot = (
    alt.Chart()
    .mark_line(size=1, opacity=0.7)
    .encode(
        x=alt.X("site", title="site"),
        y=alt.Y("escape", title="mean_escape"),
        color=alt.Color('serum:N', legend=alt.Legend(orient="right", title='sera')),
    )
    .properties(width=400, height=200)
)

# Horizontal reference rule at escape = 0.
x_axis = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(opacity=0.5).encode(y='y')

alt.layer(summed_escape_lineplot, x_axis, data=mean_escapes).facet(
    facet=alt.Facet('age_group:N', title='age cohort'),
    columns=1,
).configure_axis(grid=False).resolve_axis(x='independent')


In [77]:
# Per-site escape lines for each serum in `mean_escapes_a`, faceted by age group.
# NOTE: despite the variable name, the y values plotted here are means.
summed_escape_lineplot = (
    alt.Chart()
    .mark_line(size=1, opacity=0.7)
    .encode(
        x=alt.X("site", title="site"),
        y=alt.Y("escape", title="mean_escape"),
        color=alt.Color('serum:N', legend=alt.Legend(orient="right", title='sera')),
    )
    .properties(width=400, height=200)
)

# Horizontal reference rule at escape = 0.
x_axis = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(opacity=0.5).encode(y='y')

alt.layer(summed_escape_lineplot, x_axis, data=mean_escapes_a).facet(
    facet=alt.Facet('age_group:N', title='age cohort'),
    columns=1,
).configure_axis(grid=False).resolve_axis(x='independent')


In [18]:
data = [ped_prob_escape, teen_prob_escape, adult_prob_escape]
sera = [ped_sera_names, teen_sera_names, adult_sera_names]
age_groups = ['0-5', '15-18', '40-45']

# Fit site-160-filtered models per cohort and collect per-site mean escape over
# the B sites. zip() keeps tables, serum names and age labels aligned without a
# hand-maintained index counter.
mean_escape_list = []
for group, group_sera, age_group in zip(data, sera, age_groups):
    filt_model_list = get_filtered_models(group)
    mean_escapes = get_agg_escapes(group_sera, filt_model_list, b_sites, age_group)
    mean_escape_list.append(mean_escapes)

mean_escape_160_filt = pd.concat(mean_escape_list)

In [19]:
# Per-site mean escape (B sites) for every serum, one facet per age cohort.
# NOTE: despite the variable name, the y values plotted here are means.
summed_escape_lineplot = (
    alt.Chart()
    .mark_line(size=1, opacity=0.7)
    .encode(
        x=alt.X("site", title="site"),
        y=alt.Y("escape", title="mean_escape"),
        color=alt.Color('serum:N', legend=alt.Legend(orient="right", title='sera')),
    )
    .properties(width=400, height=200)
)

# Horizontal reference rule at escape = 0.
x_axis = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(opacity=0.5).encode(y='y')

alt.layer(summed_escape_lineplot, x_axis, data=mean_escape_160_filt).facet(
    facet=alt.Facet('age_group:N', title='age cohort'),
    columns=1,
).configure_axis(grid=False).resolve_axis(x='independent')


In [24]:
data = [ped_prob_escape, teen_prob_escape, adult_prob_escape]
sera = [ped_sera_names, teen_sera_names, adult_sera_names]
age_groups = ['0-5', '15-18', '40-45']

# Same as the B-site computation, but over the A sites. zip() keeps tables,
# serum names and age labels aligned without a hand-maintained index counter.
mean_escape_list_a = []
for group, group_sera, age_group in zip(data, sera, age_groups):
    filt_model_list = get_filtered_models(group)
    mean_escapes = get_agg_escapes(group_sera, filt_model_list, a_sites, age_group)
    mean_escape_list_a.append(mean_escapes)

mean_escape_160_filt_a = pd.concat(mean_escape_list_a)

In [25]:
# Per-site mean escape (A sites) for every serum, one facet per age cohort.
# NOTE: despite the variable name, the y values plotted here are means.
summed_escape_lineplot = (
    alt.Chart()
    .mark_line(size=1, opacity=0.7)
    .encode(
        x=alt.X("site", title="site"),
        y=alt.Y("escape", title="mean_escape"),
        color=alt.Color('serum:N', legend=alt.Legend(orient="right", title='sera')),
    )
    .properties(width=400, height=200)
)

# Horizontal reference rule at escape = 0.
x_axis = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(opacity=0.5).encode(y='y')

alt.layer(summed_escape_lineplot, x_axis, data=mean_escape_160_filt_a).facet(
    facet=alt.Facet('age_group:N', title='age cohort'),
    columns=1,
).configure_axis(grid=False).resolve_axis(x='independent')


In [None]:
data = [ped_prob_escape, teen_prob_escape, adult_prob_escape]
sera = [ped_sera_names, teen_sera_names, adult_sera_names]
age_groups = ['0-5', '15-18', '40-45']

# Per-site SUMMED escape over the B sites for each cohort. The previous version
# did not parse (keyword argument before positional ones), appended to a list
# that was never created, and concatenated the wrong list.
summed_escape_list = []
for group, group_sera, age_group in zip(data, sera, age_groups):
    filt_model_list = get_filtered_models(group)
    summed_escapes = get_agg_escapes(
        group_sera, filt_model_list, b_sites, age_group, agg='sum'
    )
    summed_escape_list.append(summed_escapes)

mean_escape_160_filt_sum = pd.concat(summed_escape_list)