In [1]:
import pickle

import altair as alt

import pandas as pd

import polyclonal

import warnings
# Blanket filter: no category or module is given, so every warning raised for the
# rest of the kernel session is hidden, including pandas warnings from later cells.
# NOTE(review): consider narrowing this to the specific categories that are noisy.
warnings.filterwarnings('ignore')

from IPython.utils import io

In [2]:
import os

# Move two levels up from the kernel's starting directory; the relative paths used
# in later cells (e.g. 'results/perth2009/merged_escape.csv') are resolved from there.
# The starting directory is recorded the first time this cell runs, so re-running the
# cell in the same kernel lands in the same place instead of climbing two more levels.
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, '..', '..'))

In [20]:
# Cohort label -> names of the sera in that cohort. The names are matched against
# the 'serum' column of escape_df when the cohort label is assigned below.
sample_dict = {
    "vietnam_ped": [
        "age 2.1 (Vietnam)", "age 2.2 (Vietnam)", "age 2.4 (Vietnam)",
        "age 2.5 (Vietnam)", "age 2.5b (Vietnam)", "age 3.3 (Vietnam)",
        "age 3.3b (Vietnam)", "age 3.4 (Vietnam)", "age 3.5 (Vietnam)",
    ],
    "vietnam_adult": [
        "age 30.5 (Vietnam)", "age 31.5 (Vietnam)", "age 33.5 (Vietnam)",
    ],
    "misc_adult": [
        "age 21 (Seattle)", "age 53 (Seattle)",
        "age 64 (Seattle)", "age 65 (Seattle)",
    ],
    "ferret": [
        "ferret 1 (Pitt)", "ferret 2 (Pitt)", "ferret 3 (Pitt)",
        "ferret (WHO)",
    ],
}

# Load the merged escape table, keep the columns used here, relabel 'name' as
# 'serum', then sum escape within each (serum, site, wildtype) group. 'mutant' is
# not a grouping key, so rows for different mutants at a site are added together.
escape_df = (
    pd.read_csv('results/perth2009/merged_escape.csv')
    .loc[:, ['name', 'site', 'wildtype', 'mutant', 'escape']]
    .rename(columns={'name': 'serum'})
    .groupby(['serum', 'site', 'wildtype'], as_index=False)
    .agg({'escape': 'sum'})
)

# add cohort label
def find_sample_type(sample_name):
    """Return the key of ``sample_dict`` whose list contains ``sample_name``.

    Cohorts are checked in the dictionary's insertion order and the first match
    wins. ``None`` is returned when no cohort lists the sample.
    """
    matching_cohorts = (
        cohort
        for cohort, members in sample_dict.items()
        if sample_name in members
    )
    return next(matching_cohorts, None)

# Two derived columns, added in order so the second can use the first:
#   cohort             - label looked up from the serum name (None if unlisted)
#   mean_cohort_escape - mean of 'escape' over all rows sharing a site and cohort
escape_df = escape_df.assign(
    cohort=lambda df: df['serum'].apply(find_sample_type),
    mean_cohort_escape=lambda df: (
        df.groupby(['site', 'cohort'])['escape'].transform('mean')
    ),
)

escape_df.head()

Unnamed: 0,serum,site,wildtype,escape,cohort,mean_cohort_escape
0,age 2.1 (Vietnam),(HA2)1,G,-1.153665,vietnam_ped,-1.263039
1,age 2.1 (Vietnam),(HA2)10,I,0.580139,vietnam_ped,-1.211677
2,age 2.1 (Vietnam),(HA2)100,V,-2.019882,vietnam_ped,-2.426894
3,age 2.1 (Vietnam),(HA2)101,A,-5.917736,vietnam_ped,-4.327383
4,age 2.1 (Vietnam),(HA2)102,L,-4.925232,vietnam_ped,-3.581447


In [25]:
# key sites to show in the filtered plot
site_list = [
    50, 82, 103, 121, 122, 124, 131, 135, 137, 138, 145, 156,
    157, 159, 160, 186, 188, 189, 192, 193, 220, 224, 244, 276,
]

# sites are stored as strings in escape_df, so compare against string labels
site_list_str = [str(site) for site in site_list]

# keep only the rows whose site is one of the key sites
is_key_site = escape_df['site'].isin(site_list_str)
escape_df_filtered = escape_df[is_key_site]

# add leading 0 to 2-digit sites for correct ordering on plots
site_dict = {'50': '050',
             '82': '082',
             '94': '094'}

# Build the relabelled column with `.assign`, which returns a new frame, instead of
# writing into `escape_df_filtered` in place. That frame is a boolean-filtered
# selection of `escape_df`, and assigning a column on such a selection is the
# chained-assignment pattern pandas reports as SettingWithCopyWarning.
# Sites without an entry in `site_dict` keep their label, so re-running is a no-op.
escape_df_filtered = escape_df_filtered.assign(
    site=escape_df_filtered['site'].map(lambda site: site_dict.get(site, site))
)

escape_df_filtered.head()

Unnamed: 0,serum,site,wildtype,escape,cohort,mean_cohort_escape
242,age 2.1 (Vietnam),103,P,-4.446905,vietnam_ped,-7.034206
262,age 2.1 (Vietnam),121,N,-0.086454,vietnam_ped,-3.135215
263,age 2.1 (Vietnam),122,N,-4.088973,vietnam_ped,-3.877047
265,age 2.1 (Vietnam),124,S,-3.376425,vietnam_ped,-2.446253
273,age 2.1 (Vietnam),131,T,-2.24867,vietnam_ped,-1.713087


In [26]:
# filtered sites, line and scatter overlay
# Line layer: `detail='serum'` gives one faint line per serum, coloured by cohort.
# No data is attached here; it is supplied when the layers are combined.
summed_escape_lineplot = (
    alt.Chart()
    .mark_line(size=1.5, opacity=0.4)
    .encode(
        x=alt.X("site", title="site"),
        y=alt.Y("escape", title="escape score"),
        color=alt.Color('cohort:N', legend=None, scale=alt.Scale(scheme='set2')),
        detail='serum',
        tooltip=['serum', 'site', 'escape'],
    )
    .properties(width=525, height=120)
)

# Cohort-mean markers: y is the precomputed per-(site, cohort) mean stored in
# 'mean_cohort_escape'. No aggregation is applied in the chart, so one circle is
# drawn per data row; rows sharing a site and cohort carry the same mean, so their
# circles coincide.
mean_points = (
    alt.Chart()
    .encode(
        x=alt.X("site", title="site"),
        y=alt.Y("mean_cohort_escape"),
        color=alt.Color('cohort:N', legend=None).scale(scheme='set2'),
        # Show the value that is actually plotted. The per-serum 'escape' used here
        # before belonged to an arbitrary row and did not match the marker's y.
        tooltip=['cohort', 'site', 'mean_cohort_escape']
    )
    .mark_circle(size=45, opacity=0.75)
)

# horizontal rule at y = 0, used as a reference-line layer in the charts below
x_axis = (
    alt.Chart(pd.DataFrame({'y': [0]}))
    .encode(y='y')
    .mark_rule(opacity=0.5)
)

# Layer the per-serum lines, the cohort-mean circles and the zero rule on the
# filtered data, then stack one panel per cohort in a single column.
faceted_line_scatter_overlay = (
    alt.layer(
        summed_escape_lineplot,
        mean_points,
        x_axis,
        data=escape_df_filtered,
    )
    .facet(
        facet=alt.Facet(
            'cohort:N',
            title='summed escape',
            # (labelAngle=0 was tried here and is currently disabled)
            header=alt.Header(
                titleFontSize=21,
                titleFontWeight='normal',
                titlePadding=5,
                labelFontSize=17,
                labelFontStyle='italic',
                labelOrient='right',
            ),
        ),
        columns=1,
        spacing=5,
    )
    .configure_axis(grid=False, labelFontSize=14, titleFontSize=15)
)

# write the figure as standalone HTML, then display it as the cell output
faceted_line_scatter_overlay.save(
    'scratch_notebooks/figure_drafts/perth09_analysis/sitewise_escape_plots/230724_summed_escape_filt_sites.html'
)

faceted_line_scatter_overlay

In [24]:
# all sites, lineplot
# Line layer for the unfiltered data: one line per serum (`detail='serum'`),
# coloured by cohort, on a wider panel.
summed_escape_lineplot_full = (
    alt.Chart()
    .mark_line(size=0.7, opacity=0.7)
    .encode(
        x=alt.X("site", title="site"),
        y=alt.Y("escape", title="summed escape"),
        color=alt.Color('cohort:N', legend=None, scale=alt.Scale(scheme='set2')),
        detail='serum',
        tooltip=['serum', 'site', 'escape'],
    )
    .properties(width=800, height=120)
)

# Layer the per-serum lines and the zero rule on the full table, then stack one
# panel per cohort in a single column.
faceted_lineplot_full = (
    alt.layer(summed_escape_lineplot_full, x_axis, data=escape_df)
    .facet(
        facet=alt.Facet(
            'cohort:N',
            title='escape at all residues',
            # (labelAngle=0 was tried here and is currently disabled)
            header=alt.Header(
                titleFontSize=21,
                titleFontWeight='normal',
                titlePadding=5,
                labelFontSize=17,
                labelFontStyle='italic',
                labelOrient='right',
            ),
        ),
        columns=1,
        spacing=5,
    )
    .configure_axis(grid=False, labelFontSize=13, titleFontSize=15)
)

# faceted_lineplot_full.save('scratch_notebooks/figure_drafts/figures_for_talks/230713_summed_escape_full.html')

faceted_lineplot_full

**Note:** Sites are currently stored as strings, with stalk residues labeled as `(HA2)1` etc., so they do not order correctly on this plot. They need to be converted to numeric values, similar to the HK/19 data.