In [1]:
import pickle

import altair as alt

import pandas as pd

import polyclonal

import warnings
warnings.filterwarnings('ignore')

from IPython.utils import io

In [2]:
import os

# Move two directory levels up so the relative 'results/...' and
# 'scratch_notebooks/...' paths used below resolve (presumably the project
# root -- confirm against the repository layout).
# The starting directory is remembered the first time this cell runs, so
# re-running the cell in the same kernel lands in the same place instead of
# climbing two more levels each time.
if '_NOTEBOOK_START_DIR' not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_START_DIR, '..', '..'))

## Plotting immune escape across full protein and selected sites

In [3]:
def get_summed_escapes(sera_list, age_group, site_list=None):
    """Sum per-mutation escape into one value per site for each serum.

    For every serum in ``sera_list`` this reads
    ``results/antibody_escape/{serum}_avg.csv`` (relative to the current
    working directory), keeps rows with ``times_seen >= 5``, and sums
    ``escape_mean`` within each ``(site, wildtype)`` pair.

    Parameters
    ----------
    sera_list : iterable
        Serum identifiers; each one is interpolated into the CSV file name.
    age_group : str
        Label written to the ``age_group`` column of every returned row.
    site_list : list, optional
        If given and non-empty, only rows whose ``site`` is in this list are
        kept and the ``site`` column is converted to strings. If omitted (or
        empty), all sites are kept and ``site`` keeps the dtype read from CSV.

    Returns
    -------
    pandas.DataFrame
        Columns ``site``, ``wildtype``, ``escape_mean``, ``serum``,
        ``age_group``; the per-serum frames are concatenated without
        resetting the index. An empty frame with those columns is returned
        when ``sera_list`` is empty.
    """
    output_columns = ['site', 'wildtype', 'escape_mean', 'serum', 'age_group']
    per_serum_frames = []

    for serum in sera_list:
        prob_escape = pd.read_csv(
            f'results/antibody_escape/{serum}_avg.csv'
        ).query(
            "`times_seen` >= 5"
        )

        per_site = prob_escape.groupby(
            ['site', 'wildtype'], as_index=False
        ).aggregate({'escape_mean': 'sum'})

        if site_list:
            # .assign builds a new frame, so nothing is written onto a slice
            # of the grouped frame (the original triggered pandas'
            # chained-assignment warning here).
            per_site = per_site[per_site['site'].isin(site_list)].assign(
                site=lambda df: df['site'].astype(str)
            )

        per_serum_frames.append(
            per_site.assign(serum=serum, age_group=age_group)
        )

    if not per_serum_frames:
        # pd.concat refuses an empty list; hand back an empty, well-formed frame.
        return pd.DataFrame(columns=output_columns)

    return pd.concat(per_serum_frames)

In [4]:
# initialize list of key sites, plus samples in each age cohort
site_list = [50, 82, 103, 121, 122, 124, 131, 135, 137, 138, 145, 156, 157, 
              159, 160, 186, 188, 189, 193, 220, 224, 244, 276]

peds = [2367, 3944, 2462, 2389, 2323, 2388, 2463, 3973, 4299, 4584]
teens = [2343, 2350, 2365, 2382, 3866, 2380, 3856, 3857, 3862, 3895]
adults = ['33C', '34C', '197C', '199C', '215C', '210C', '74C', '68C', '150C', '18C']
ferrets = ['ferret_1', 'ferret_2', 'ferret_3']

# sample_lists and cohorts are parallel: cohorts[k] labels sample_lists[k]
sample_lists = [peds, teens, adults, ferrets]
cohorts = ['0-5', '15-18', '40-45', 'ferrets']
assert len(sample_lists) == len(cohorts), 'each sample list needs a cohort label'

# start by getting full escape df filtered to key sites
# (zip pairs each sample list with its label; the loop variable is not named
# `list`, which would shadow the builtin for the rest of the notebook)
summed_escapes_filtered = [
    get_summed_escapes(cohort_sera, cohort, site_list)
    for cohort_sera, cohort in zip(sample_lists, cohorts)
]

# Sites are strings at this point (get_summed_escapes casts them when a
# site_list is given). Two-digit sites are zero-padded -- presumably so the
# string labels sort in numeric order on the chart axis.
site_dict = {'50': '050', 
             '62': '062', 
             '82': '082', 
             '94': '094'}

# Serum IDs are ints for some cohorts and strings for others; cast to str so
# the column has a single type.
escape_df_filtered = pd.concat(summed_escapes_filtered).assign(
    site=lambda df: df['site'].apply(lambda x: site_dict.get(x, x)),
    serum=lambda df: df['serum'].astype(str),
)

In [5]:
# also generate escape df with all sites included
# (no site_list is passed, so get_summed_escapes keeps every site)
assert len(sample_lists) == len(cohorts), 'each sample list needs a cohort label'

# zip pairs each sample list with its cohort label; avoids a manual counter
# and a loop variable named `list` that would shadow the builtin
summed_escapes = [
    get_summed_escapes(cohort_sera, cohort)
    for cohort_sera, cohort in zip(sample_lists, cohorts)
]

# Serum IDs are ints for some cohorts and strings for others; cast to str so
# the column has a single type.
escape_df_full = pd.concat(summed_escapes).assign(
    serum=lambda df: df['serum'].astype(str)
)

## Set up different chart options

In [6]:
# filtered sites, scatterplot
# One circle per (serum, site); data is supplied later at the layer level.
summed_escape_scatterplot = (
    alt.Chart()
    .encode(
        x=alt.X("site", 
                title="site",
               ),
        y=alt.Y(
            "escape_mean",
            title="summed_escape",
        ),
        # legend disabled here; the chart is faceted by age_group below
        color=alt.Color('age_group:N', 
                        legend=None
                       ).scale(scheme='set2'),
        detail='serum',
        tooltip=['serum', 'site', 'escape_mean']
    )
    .mark_circle(size=50, opacity=0.7)
    .properties(width=500, height=200)
)

# horizontal reference rule drawn at y = 0
x_axis = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(opacity=0.5).encode(y='y')

# Layer the points over the zero rule, then split into one panel per age group.
faceted_scatter_filtered = alt.layer(
    summed_escape_scatterplot, x_axis, data=escape_df_filtered
).facet(
    facet=alt.Facet(
        'age_group:N',
        title='summed escape at significant sites',
        header=alt.Header(
            titleFontSize=23,
            labelFontSize=17,
        )
    ),
    columns=2
).configure_axis(
    grid=False,
    labelFontSize=13,
    titleFontSize=15
).resolve_axis(
    x='independent'  # every panel draws its own x axis
)

# Make sure the output folder exists before saving; saving does not create
# missing parent directories.
os.makedirs('scratch_notebooks/figure_drafts/sitewise_escape_plots', exist_ok=True)
faceted_scatter_filtered.save('scratch_notebooks/figure_drafts/sitewise_escape_plots/230713_summed_escape_scatterplot.html')

In [7]:
# filtered sites, lineplot
# One line per serum (detail='serum'); data is supplied later at the layer level.
summed_escape_lineplot = (
    alt.Chart()
    .encode(
        x=alt.X("site", 
                title="site",
               ),
        y=alt.Y(
            "escape_mean",
            title="summed_escape",
        ),
        # legend disabled here; the chart is faceted by age_group below
        color=alt.Color('age_group:N', 
                        legend=None
                       ).scale(scheme='set2'),
        detail='serum',
        tooltip=['serum', 'site', 'escape_mean']
    )
    .mark_line(size=1.2, opacity=0.7)
    .properties(width=500, height=200)
)

# horizontal reference rule drawn at y = 0
x_axis = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(opacity=0.5).encode(y='y')

# Layer the lines over the zero rule, then split into one panel per age group.
faceted_lineplot_filtered = alt.layer(
    summed_escape_lineplot, x_axis, data=escape_df_filtered
).facet(
    facet=alt.Facet(
        'age_group:N',
        title='summed escape at significant sites',
        header=alt.Header(
            titleFontSize=23,
            labelFontSize=17,
        )
    ),
    columns=2
).configure_axis(
    grid=False,
    labelFontSize=13,
    titleFontSize=15
).resolve_axis(
    x='independent'  # every panel draws its own x axis
)

# Make sure the output folder exists before saving; saving does not create
# missing parent directories.
os.makedirs('scratch_notebooks/figure_drafts/sitewise_escape_plots', exist_ok=True)
faceted_lineplot_filtered.save('scratch_notebooks/figure_drafts/sitewise_escape_plots/230713_summed_escape_lineplot.html')

In [8]:
# filtered sites, line and scatter overlay
# NOTE: this cell rebinds summed_escape_scatterplot / summed_escape_lineplot
# with smaller, fainter marks than the earlier cells used for the stand-alone
# charts. It also reuses the x_axis zero rule defined in an earlier cell.
summed_escape_scatterplot = (
    alt.Chart()
    .encode(
        x=alt.X("site", 
                title="site",
               ),
        y=alt.Y(
            "escape_mean",
            title="summed_escape",
        ),
        # legend disabled here; the chart is faceted by age_group below
        color=alt.Color('age_group:N', 
                        legend=None
                       ).scale(scheme='set2'),
        detail='serum',
        tooltip=['serum', 'site', 'escape_mean']
    )
    .mark_circle(size=40, opacity=0.7)
    .properties(width=500, height=200)
)

summed_escape_lineplot = (
    alt.Chart()
    .encode(
        x=alt.X("site", 
                title="site",
               ),
        y=alt.Y(
            "escape_mean",
            title="summed_escape",
        ),
        color=alt.Color('age_group:N', 
                        legend=None
                       ).scale(scheme='set2'),
        detail='serum',
        tooltip=['serum', 'site', 'escape_mean']
    )
    .mark_line(size=0.8, opacity=0.5)
    .properties(width=500, height=200)
)

# Layer order: lines first, then points, then the zero rule.
faceted_line_scatter_overlay = alt.layer(
    summed_escape_lineplot, summed_escape_scatterplot, x_axis, data=escape_df_filtered
).facet(
    facet=alt.Facet(
        'age_group:N',
        title='summed escape at significant sites',
        header=alt.Header(
            titleFontSize=23,
            labelFontSize=17,
        )
    ),
    columns=2
).configure_axis(
    grid=False,
    labelFontSize=13,
    titleFontSize=15
).resolve_axis(
    x='independent'  # every panel draws its own x axis
)

# Make sure the output folder exists before saving; saving does not create
# missing parent directories.
os.makedirs('scratch_notebooks/figure_drafts/sitewise_escape_plots', exist_ok=True)
faceted_line_scatter_overlay.save('scratch_notebooks/figure_drafts/sitewise_escape_plots/230713_summed_escape_line_scatter.html')

In [9]:
# all sites, lineplot
# One line per serum across every site in escape_df_full. Reuses the x_axis
# zero rule defined in an earlier cell.
summed_escape_lineplot_full = (
    alt.Chart()
    .encode(
        x=alt.X("site", 
                title="site",
               ),
        y=alt.Y(
            "escape_mean",
            title="summed_escape",
        ),
        # unlike the key-site charts, this one shows a colour legend
        color=alt.Color('age_group:N', 
                        legend=alt.Legend(orient="right", title='age group')
                       ).scale(scheme='set2'),
        detail='serum',
        tooltip=['serum', 'site', 'escape_mean']
    )
    .mark_line(size=1, opacity=0.7)
    .properties(width=800, height=150)
)

# Stack one panel per age group in a single column.
faceted_lineplot_full = alt.layer(
    summed_escape_lineplot_full, x_axis, data=escape_df_full
).facet(
    facet=alt.Facet(
        'age_group:N',
        title='summed escape at all residues',
        header=alt.Header(
            titleFontSize=20,
            # presumably zero to hide the per-panel labels, since the legend
            # already names the age groups
            labelFontSize=0
        )
    ),
    columns=1
).configure_axis(
    grid=False,
    labelFontSize=13,
    titleFontSize=15
)

# Make sure the output folder exists before saving; saving does not create
# missing parent directories.
os.makedirs('scratch_notebooks/figure_drafts/sitewise_escape_plots', exist_ok=True)
faceted_lineplot_full.save('scratch_notebooks/figure_drafts/sitewise_escape_plots/230713_summed_escape_full.html')

## Chart comparison

In [10]:
# display the key-site scatterplot (one panel per age group)
faceted_scatter_filtered

In [11]:
# display the key-site lineplot (one panel per age group)
faceted_lineplot_filtered

In [12]:
# display the key-site chart with lines and points overlaid
faceted_line_scatter_overlay

In [13]:
# display the all-sites lineplot (one stacked panel per age group)
faceted_lineplot_full