In [1]:
import pickle

import altair as alt

import pandas as pd

import polyclonal

import warnings
warnings.filterwarnings('ignore')

from IPython.utils import io

In [2]:
import os

# The notebook reads its inputs through relative paths such as
# 'results/antibody_escape/...', which are resolved from the directory two
# levels above the one the kernel starts in.
# Remember the starting directory the first time this cell runs so that
# re-running the cell lands in the same place instead of climbing two more
# levels on every execution.
if 'NOTEBOOK_START_DIR' not in globals():
    NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_START_DIR, '../../'))

In [30]:
# Sites kept when filtering the per-site escape sums.
site_list = [
    50, 62, 82, 94,
    103, 121, 122, 124, 131, 135, 137, 138, 145, 156, 157, 159, 160,
    188, 189, 193,
    220, 224, 276,
]

# Serum identifiers, one list per cohort; each identifier is interpolated
# into the 'results/antibody_escape/{serum}_avg.csv' file name.
ped_sera_list = [
    2367, 3944, 2462, 2389, 2323,
    2388, 2463, 3973, 4299, 4584,
]
teen_sera_list = [
    2343, 2350, 2365, 2382, 3866,
    2380, 3856, 3857, 3862, 3895,
]
adult_sera_list = [
    '33C',
    '34C',
    '197C',
    '199C',
    '215C',
]

In [31]:
def get_summed_escapes(sera_list, age_group, site_list=None):
    """Sum ``escape_mean`` per site for every serum of one cohort.

    For each serum the file ``results/antibody_escape/{serum}_avg.csv``
    (relative to the current working directory) is read, rows with
    ``times_seen`` below 3 are dropped, and ``escape_mean`` is summed within
    each ``site``.

    Parameters
    ----------
    sera_list : iterable
        Serum identifiers; each one is interpolated into the CSV file name.
    age_group : str
        Label written into the ``age_group`` column of every returned row.
    site_list : list-like, optional
        When given, only sites contained in it are kept. When ``None``
        (the default) every site is returned.

    Returns
    -------
    pandas.DataFrame
        Columns ``site``, ``escape_mean``, ``serum``, ``age_group``. The
        per-serum frames are concatenated as they are, so index labels are
        not reset and repeat across sera.

    Raises
    ------
    ValueError
        Raised by ``pd.concat`` when ``sera_list`` is empty.
    """
    per_serum_frames = []

    for serum in sera_list:
        site_sums = (
            pd.read_csv(f'results/antibody_escape/{serum}_avg.csv')
            .query("`times_seen` >= 3")
            .groupby('site', as_index=False)
            .aggregate({'escape_mean': 'sum'})
        )

        if site_list is not None:
            site_sums = site_sums[site_sums['site'].isin(site_list)]

        # .assign returns a new frame, so the labels are never written onto a
        # boolean-filtered slice (which triggers SettingWithCopyWarning).
        per_serum_frames.append(
            site_sums.assign(serum=serum, age_group=age_group)
        )

    return pd.concat(per_serum_frames)

In [32]:
# Cohort labelled '0-5', restricted to the sites in site_list.
ped = get_summed_escapes(
    sera_list=ped_sera_list,
    age_group='0-5',
    site_list=site_list,
)

In [33]:
# Cohort labelled '15-18', restricted to the sites in site_list.
teen = get_summed_escapes(
    sera_list=teen_sera_list,
    age_group='15-18',
    site_list=site_list,
)

In [34]:
# Cohort labelled '40-45', restricted to the sites in site_list.
adult = get_summed_escapes(
    sera_list=adult_sera_list,
    age_group='40-45',
    site_list=site_list,
)

In [35]:
# Stack the three cohorts into one long-format table.
full_escape = pd.concat([ped, teen, adult])

# Turn sites into zero-padded strings ('50' -> '050'), presumably so that
# string ordering on the chart axis matches numeric ordering.
# The pad width comes from the largest site in site_list rather than from a
# hand-written lookup, so any shorter site number is padded, not just
# 50, 62, 82 and 94.
site_label_width = len(str(max(site_list)))
full_escape['site'] = full_escape['site'].astype(str).str.zfill(site_label_width)
full_escape

Unnamed: 0,site,escape_mean,serum,age_group
48,050,0.5339,2367,0-5
58,062,-1.1642,2367,0-5
73,082,-0.0175,2367,0-5
84,094,0.2265,2367,0-5
89,103,-0.5267,2367,0-5
...,...,...,...,...
154,189,4.5592,215C,40-45
158,193,0.0311,215C,40-45
181,220,-3.6198,215C,40-45
185,224,-2.1645,215C,40-45


In [43]:
# Horizontal reference rule at y = 0; it carries its own one-row data frame.
x_axis = alt.Chart(pd.DataFrame({'y': [0]})).encode(y='y').mark_rule(opacity=0.5)

# One line per serum. No data is bound here; alt.layer supplies it below.
summed_escape_lineplot = (
    alt.Chart()
    .mark_line(size=1, opacity=0.7)
    .encode(
        x=alt.X("site", title="site"),
        y=alt.Y("escape_mean", title="summed_escape"),
        color=alt.Color('serum:N', legend=alt.Legend(orient="right", title='sera')),
    )
    .properties(width=400, height=200)
)

# Layer both marks over full_escape and draw one panel per age cohort in a
# single column, each panel with its own x axis.
alt.layer(summed_escape_lineplot, x_axis, data=full_escape).facet(
    facet=alt.Facet('age_group:N', title='age cohort'),
    columns=1,
).configure_axis(grid=False).resolve_axis(x='independent')

# chart.save('immune_escape_by_age_test.png')


In [24]:
# NOTE: this cell uses `get_summed_escape` and `prob_escape_list`, which are
# defined in cells further down the notebook; on a fresh kernel those cells
# have to be run first.
sites_list = [103, 121, 122, 124, 135, 137, 138, 145, 159, 160, 186, 189, 193, 220, 224]

summed_escape_list = []

# prob_escape_list is built from adult_sera_list, so the two are aligned
# element by element; zip pairs them without a hand-maintained counter.
for serum, prob_escape in zip(adult_sera_list, prob_escape_list):
    # get_summed_escape returns a filtered slice; .assign adds the label on a
    # new frame instead of writing into that slice.
    summed_escape = get_summed_escape(prob_escape, sites_list).assign(sera=serum)
    summed_escape_list.append(summed_escape)

In [36]:
# Combine the per-serum frames and store the site column as strings.
summed_escape_full = pd.concat(summed_escape_list).astype({'site': str})
summed_escape_full.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 75 entries, 87 to 185
Data columns (total 3 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   site    75 non-null     object 
 1   escape  75 non-null     float64
 2   sera    75 non-null     object 
dtypes: float64(1), object(2)
memory usage: 2.3+ KB


In [56]:
# Shared encodings for the line and point layers: site on x, summed escape
# on y, one colour per serum.
summed_escape_base = alt.Chart(summed_escape_full).encode(
    x=alt.X("site", title="site"),
    y=alt.Y("escape", title="summed_escape"),
    color=alt.Color('sera:N', legend=alt.Legend(orient="right", title='sera')),
)

# Lines, then points, then a horizontal reference rule at y = 0.
site_lineplot = (
    summed_escape_base.mark_line(size=1, opacity=0.7)
    + summed_escape_base.mark_circle(opacity=0.7, size=20)
    + alt.Chart(pd.DataFrame({'y': [0]})).encode(y='y').mark_rule(opacity=0.5)
)

site_lineplot.configure_axis(grid=False)

In [10]:
def get_summed_escapes(sera_list, age_group, site_list=None):
    """Sum ``escape_mean`` per site for every serum of one cohort.

    For each serum the file ``results/antibody_escape/{serum}_avg.csv``
    (relative to the current working directory) is read, rows with
    ``times_seen`` below 3 are dropped, and ``escape_mean`` is summed within
    each ``site``.

    This notebook defines ``get_summed_escapes`` twice, and the later
    definition replaces the earlier one in the kernel. ``site_list`` is
    accepted here as an optional argument so that calls written for either
    definition keep working.

    Parameters
    ----------
    sera_list : iterable
        Serum identifiers; each one is interpolated into the CSV file name.
    age_group : str
        Label written into the ``age_group`` column of every returned row.
    site_list : list-like, optional
        When given, only sites contained in it are kept. When ``None``
        (the default) every site is returned.

    Returns
    -------
    pandas.DataFrame
        Columns ``site``, ``escape_mean``, ``serum``, ``age_group``. The
        per-serum frames are concatenated as they are, so index labels are
        not reset and repeat across sera.

    Raises
    ------
    ValueError
        Raised by ``pd.concat`` when ``sera_list`` is empty.
    """
    per_serum_frames = []

    for serum in sera_list:
        site_sums = (
            pd.read_csv(f'results/antibody_escape/{serum}_avg.csv')
            .query("`times_seen` >= 3")
            .groupby('site', as_index=False)
            .aggregate({'escape_mean': 'sum'})
        )

        if site_list is not None:
            site_sums = site_sums[site_sums['site'].isin(site_list)]

        # .assign returns a new frame carrying the two label columns.
        per_serum_frames.append(
            site_sums.assign(serum=serum, age_group=age_group)
        )

    return pd.concat(per_serum_frames)

In [11]:
# Per-site escape sums for each cohort, with no site filtering.
ped = get_summed_escapes(sera_list=ped_sera_list, age_group='0-5')
teen = get_summed_escapes(sera_list=teen_sera_list, age_group='15-18')
adult = get_summed_escapes(sera_list=adult_sera_list, age_group='40-45')

In [12]:
# Stack the three cohorts into one long-format table and display it.
full_escape = pd.concat(objs=[ped, teen, adult])
full_escape

Unnamed: 0,site,escape_mean,serum,age_group
0,-2,0.0681,2367,0-5
1,1,-0.1331,2367,0-5
2,2,0.0374,2367,0-5
3,3,-0.4908,2367,0-5
4,4,-0.2832,2367,0-5
...,...,...,...,...
426,533,0.0147,215C,40-45
427,536,-0.0543,215C,40-45
428,537,-0.0926,215C,40-45
429,538,-0.0992,215C,40-45


In [15]:
# Horizontal reference rule at y = 0; it carries its own one-row data frame.
x_axis = alt.Chart(pd.DataFrame({'y': [0]})).encode(y='y').mark_rule(opacity=0.5)

# One line per serum. No data is bound here; alt.layer supplies it below.
summed_escape_lineplot = (
    alt.Chart()
    .mark_line(size=1, opacity=0.7)
    .encode(
        x=alt.X("site", title="site"),
        y=alt.Y("escape_mean", title="summed_escape"),
        color=alt.Color('serum:N', legend=alt.Legend(orient="right", title='sera')),
    )
    .properties(width=800, height=200)
)

# Layer both marks over full_escape and draw one panel per age cohort in a
# single column.
alt.layer(summed_escape_lineplot, x_axis, data=full_escape).facet(
    facet=alt.Facet('age_group:N', title='age cohort'),
    columns=1,
).configure_axis(grid=False)

# chart.save('immune_escape_by_age.pdf')

In [200]:
# Horizontal reference rule at y = 0; it carries its own one-row data frame.
x_axis = alt.Chart(pd.DataFrame({'y': [0]})).encode(y='y').mark_rule(opacity=0.5)

# One line per serum at 400 px panel width. No data is bound here; alt.layer
# supplies it below.
summed_escape_lineplot = (
    alt.Chart()
    .mark_line(size=1, opacity=0.7)
    .encode(
        x=alt.X("site", title="site"),
        y=alt.Y("escape_mean", title="summed_escape"),
        color=alt.Color('serum:N', legend=alt.Legend(orient="right", title='sera')),
    )
    .properties(width=400, height=200)
)

# Layer both marks over full_escape and draw one panel per age cohort in a
# single column.
alt.layer(summed_escape_lineplot, x_axis, data=full_escape).facet(
    facet=alt.Facet('age_group:N', title='age cohort'),
    columns=1,
).configure_axis(grid=False)

# chart.save('immune_escape_by_age.pdf')

### Scratch code

In [19]:
def get_prob_escape(sera_list, lib, date, replicate):
    """Load one filtered prob-escape table per serum.

    Each table is read from
    ``results/prob_escape/{lib}_{date}_1_{serum}_{replicate}_prob_escape.csv``
    (relative to the current working directory). Only the literal string
    ``"nan"`` is parsed as missing; pandas' default NA markers are disabled.
    Rows are kept when ``no-antibody_count`` is at least
    ``no_antibody_count_threshold``. That name has no ``@`` prefix in the
    query, so pandas resolves it as a column of the CSV and not as a Python
    variable.

    Parameters
    ----------
    sera_list : iterable
        Serum identifiers interpolated into the file name.
    lib, date, replicate : str
        Remaining file-name components.

    Returns
    -------
    list of pandas.DataFrame
        One frame per serum, in the order of ``sera_list``.
    """

    def _load_one(serum):
        # Read a single serum's table and drop rows below the count threshold.
        csv_path = f'results/prob_escape/{lib}_{date}_1_{serum}_{replicate}_prob_escape.csv'
        table = pd.read_csv(csv_path, keep_default_na=False, na_values="nan")
        return table.query("`no-antibody_count` >= no_antibody_count_threshold")

    return [_load_one(serum) for serum in sera_list]

In [20]:
# Serum identifiers whose prob-escape tables are loaded below.
adult_sera_list = [
    '33C',
    '34C',
    '197C',
    '199C',
    '215C',
]

# One filtered prob-escape table per serum, in the same order as
# adult_sera_list.
prob_escape_list = get_prob_escape(
    sera_list=adult_sera_list,
    lib='libA',
    date='230419',
    replicate='2',
)

In [6]:
def generate_model(prob_escape_df, n_epitopes=1):
    """Build a ``polyclonal.Polyclonal`` model from a prob-escape table and fit it.

    Parameters
    ----------
    prob_escape_df : pandas.DataFrame
        Input table. Its ``antibody_concentration`` and
        ``aa_substitutions_reference`` columns are renamed to
        ``concentration`` and ``aa_substitutions`` before being handed to
        the model.
    n_epitopes : int, optional
        Passed through to ``polyclonal.Polyclonal`` (default 1).

    Returns
    -------
    polyclonal.Polyclonal
        The model object after ``fit`` has been called on it.
    """
    column_renames = {
        "antibody_concentration": "concentration",
        "aa_substitutions_reference": "aa_substitutions",
    }
    fit_input = prob_escape_df.rename(columns=column_renames)

    model = polyclonal.Polyclonal(
        n_epitopes=n_epitopes,
        data_to_fit=fit_input,
        alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
    )

    # Fit inside capture_output so the optimizer's log text does not clutter
    # the notebook; the return value of fit is not used.
    with io.capture_output():
        model.fit(
            logfreq=200,
            reg_escape_weight=0.1,
        )

    # The escape plot is built but not returned or displayed by this function.
    model.mut_escape_plot(addtl_slider_stats={"times_seen": 3}, init_floor_at_zero=False)

    return model

In [23]:
def get_summed_escape(prob_escape, sites_list):
    """Fit a model to ``prob_escape`` and sum its mutation escape per site.

    Parameters
    ----------
    prob_escape : pandas.DataFrame
        Table handed to ``generate_model``.
    sites_list : list-like
        Sites to keep in the result.

    Returns
    -------
    pandas.DataFrame
        Columns ``site`` and ``escape``: the sum of ``escape`` over rows of
        the model's ``mut_escape_df`` with ``times_seen`` of at least 3,
        restricted to sites in ``sites_list``.
    """
    mut_escape = generate_model(prob_escape).mut_escape_df
    well_observed = mut_escape[mut_escape['times_seen'] >= 3]

    per_site = well_observed.groupby('site', as_index=False).aggregate({'escape': 'sum'})

    return per_site[per_site['site'].isin(sites_list)]

In [24]:
# Sites kept in each serum's per-site escape sums.
sites_list = [103, 121, 122, 124, 135, 137, 138, 145, 159, 160, 186, 189, 193, 220, 224]

summed_escape_list = []

# prob_escape_list is built from adult_sera_list, so the two are aligned
# element by element; zip pairs them without a hand-maintained counter.
for serum, prob_escape in zip(adult_sera_list, prob_escape_list):
    # get_summed_escape returns a filtered slice; .assign adds the label on a
    # new frame instead of writing into that slice.
    summed_escape = get_summed_escape(prob_escape, sites_list).assign(sera=serum)
    summed_escape_list.append(summed_escape)

In [36]:
# Combine the per-serum frames and store the site column as strings.
summed_escape_full = pd.concat(summed_escape_list).astype({'site': str})
summed_escape_full.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 75 entries, 87 to 185
Data columns (total 3 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   site    75 non-null     object 
 1   escape  75 non-null     float64
 2   sera    75 non-null     object 
dtypes: float64(1), object(2)
memory usage: 2.3+ KB


In [56]:
# Shared encodings for the line and point layers: site on x, summed escape
# on y, one colour per serum.
summed_escape_base = alt.Chart(summed_escape_full).encode(
    x=alt.X("site", title="site"),
    y=alt.Y("escape", title="summed_escape"),
    color=alt.Color('sera:N', legend=alt.Legend(orient="right", title='sera')),
)

# Lines, then points, then a horizontal reference rule at y = 0.
site_lineplot = (
    summed_escape_base.mark_line(size=1, opacity=0.7)
    + summed_escape_base.mark_circle(opacity=0.7, size=20)
    + alt.Chart(pd.DataFrame({'y': [0]})).encode(y='y').mark_rule(opacity=0.5)
)

site_lineplot.configure_axis(grid=False)