In [3]:
import pickle

import altair as alt

import pandas as pd

import polyclonal

import altair_saver

import warnings
# Silence every warning for the remainder of the kernel session.
# NOTE(review): this blanket filter also hides pandas warnings raised by later
# cells (e.g. chained-assignment or deprecation warnings) - consider narrowing it.
warnings.filterwarnings('ignore')

from IPython.utils import io

In [4]:
import os

# All data paths below are relative to the directory three levels above this
# notebook. Only move there when the input file is not already reachable, so
# that re-running this cell does not climb a further three levels and break
# every relative path that follows.
if not os.path.exists('results/perth2009/merged_escape.csv'):
    os.chdir('../../../')

In [5]:
# Sera belonging to each cohort. The keys are the labels later written to the
# `cohort` column. Each pediatric Vietnam serum appears in exactly one of the
# two `vietnam_ped_*` lists.
# NOTE(review): the 189K / 189N suffixes presumably relate to site 189 - confirm.
sample_dict = {
    "vietnam_ped_189K": [
        "age 2.1 (Vietnam)",
        "age 2.2 (Vietnam)",
        "age 2.5 (Vietnam)",
        "age 2.5b (Vietnam)",
        "age 3.5 (Vietnam)",
    ],
    "vietnam_ped_189N": [
        "age 2.4 (Vietnam)",
        "age 3.3 (Vietnam)",
        "age 3.3b (Vietnam)",
        "age 3.4 (Vietnam)",
    ],
    "vietnam_adult": [
        "age 30.5 (Vietnam)",
        "age 31.5 (Vietnam)",
        "age 33.5 (Vietnam)",
    ],
    "misc_adult": [
        "age 21 (Seattle)",
        "age 53 (Seattle)",
        "age 64 (Seattle)",
        "age 65 (Seattle)",
    ],
    "ferret": [
        "ferret 1 (Pitt)",
        "ferret 2 (Pitt)",
        "ferret 3 (Pitt)",
        "ferret (WHO)",
    ],
}

# Load the merged per-mutation escape table, keep only the columns used below,
# and rename `name` to `serum`. (The path is a plain string: the original
# f-string had no placeholders.)
escape_df = pd.read_csv('results/perth2009/merged_escape.csv')[
    ['name', 'site', 'wildtype', 'mutant', 'escape']
].rename(columns={'name': 'serum'})

def convert_site_to_numeric(site, ha2_offset=329):
    """Convert an HA2 site label such as '(HA2)12' to a single running number.

    Parameters
    ----------
    site : str or any
        Site label. If it is a string containing '(HA2)', that tag is removed
        and the remaining integer is shifted by `ha2_offset`. Any other value
        (including non-string values) is returned unchanged.
    ha2_offset : int, default 329
        Amount added to the HA2 number.
        NOTE(review): presumably the number of HA1 positions - confirm.

    Returns
    -------
    int
        Shifted site number for a parsable '(HA2)N' label.
    same as input
        The original value when the label has no '(HA2)' tag, is not a string,
        or the part after the tag is not an integer.
    """
    # Non-string values cannot carry the tag; pass them through rather than
    # raising TypeError on the `in` test.
    if not isinstance(site, str) or '(HA2)' not in site:
        return site
    try:
        return int(site.replace('(HA2)', '').strip()) + ha2_offset
    except ValueError:
        # Unparsable remainder: best effort, keep the original label.
        return site

# Renumber HA2 sites, collapse the per-mutation rows into one summed escape
# value per (serum, site, wildtype), then floor the sums at zero.
escape_df = (
    escape_df
    .assign(site=escape_df['site'].apply(convert_site_to_numeric))
    .groupby(['serum', 'site', 'wildtype'], as_index=False)
    .agg({'escape': 'sum'})
    .assign(escape=lambda summed: summed['escape'].clip(lower=0))
)

def find_sample_type(sample_name):
    """Return the `sample_dict` key whose list contains `sample_name`.

    Cohorts are checked in `sample_dict` order and the first match wins.
    Returns None when the sample is not listed under any cohort.
    """
    return next(
        (cohort for cohort, members in sample_dict.items() if sample_name in members),
        None,
    )

# Label each row with its cohort (None for sera not listed in `sample_dict`).
escape_df = escape_df.assign(cohort=escape_df['serum'].map(find_sample_type))

# Mean of the summed escape over all sera of a cohort at each site, repeated
# on every row of that (site, cohort) pair.
escape_df['mean_cohort_escape'] = (
    escape_df.groupby(by=['site', 'cohort'])['escape'].transform('mean')
)

escape_df.head()

Unnamed: 0,serum,site,wildtype,escape,cohort,mean_cohort_escape
0,age 2.1 (Vietnam),330,G,0.0,vietnam_ped_189K,0.0
1,age 2.1 (Vietnam),331,I,0.0,vietnam_ped_189K,0.0
2,age 2.1 (Vietnam),332,F,0.602227,vietnam_ped_189K,0.13672
3,age 2.1 (Vietnam),333,G,0.0,vietnam_ped_189K,0.0
4,age 2.1 (Vietnam),334,A,0.0,vietnam_ped_189K,0.0


In [6]:
# Exclude the `misc_adult` cohort from everything that follows.
escape_df = escape_df[escape_df['cohort'] != 'misc_adult']

In [7]:
# Key sites to plot.
site_list = [48, 50, 81, 82, 121, 122, 124, 131, 135, 137, 145, 156, 157,
             159, 160, 186, 188, 189, 192, 193, 275, 276]
             # 103, 244]

# These sites are stored as strings in `escape_df`, so compare against strings.
site_list_str = [str(site) for site in site_list]

# Keep only the key sites. `.copy()` gives an independent frame: columns are
# assigned onto it further down, and doing that on a slice of `escape_df` is
# chained assignment (only silent here because warnings are globally ignored).
escape_df_filtered = escape_df[escape_df['site'].isin(site_list_str)].copy()

# filter to just vietnam cohorts
# escape_df_filtered = escape_df_filtered.loc[(escape_df_filtered['cohort'] == 'vietnam_ped') | (escape_df_filtered['cohort'] == 'vietnam_adult')]

# Site numbers covered by each antigenic region (range end points are exclusive).
site_A = [*range(121, 147)]
site_B = [*range(155, 161), *range(186, 199)]
site_C = [*range(44, 55), *range(273, 281)]
site_D = [*range(201, 220)]
site_E = [*range(62, 66), *range(78, 95), *range(260, 266)]

# Region letter -> list of sites, in the order the lookup checks them.
antigenic_regions = dict(A=site_A, B=site_B, C=site_C, D=site_D, E=site_E)

def map_site_to_antigenic_region(site):
    """Return the key of the first `antigenic_regions` entry that contains `site`.

    Returns None when the site is not listed under any region.
    """
    return next(
        (label for label, members in antigenic_regions.items() if site in members),
        None,
    )

# Look the antigenic region up from the integer site number, while leaving the
# `site` column itself stored as strings.
escape_df_filtered = escape_df_filtered.assign(
    site=lambda frame: frame['site'].astype(int).astype(str),
    antigenic_region=lambda frame: (
        frame['site'].astype(int).map(map_site_to_antigenic_region)
    ),
)

escape_df_filtered.head()

Unnamed: 0,serum,site,wildtype,escape,cohort,mean_cohort_escape,antigenic_region
262,age 2.1 (Vietnam),121,N,0.0,vietnam_ped_189K,0.0,A
263,age 2.1 (Vietnam),122,N,0.0,vietnam_ped_189K,0.0,A
265,age 2.1 (Vietnam),124,S,0.0,vietnam_ped_189K,0.0,A
273,age 2.1 (Vietnam),131,T,0.0,vietnam_ped_189K,0.0,A
277,age 2.1 (Vietnam),135,T,0.0,vietnam_ped_189K,0.0,A


In [6]:
# One group of key-site rows per serum.
grouped = escape_df_filtered.groupby(by='serum')

def normalize(group):
    """Scale the `escape` column of `group` so that its maximum becomes 1.

    Parameters
    ----------
    group : pandas.DataFrame
        Rows to normalise together; must contain an `escape` column.

    Returns
    -------
    pandas.DataFrame
        A new frame with `escape` divided by the group's maximum `escape`.
        The input frame is left untouched, because mutating the object that
        `groupby.apply` passes in can give unexpected results.
        When the maximum is 0 the values are returned unchanged instead of
        dividing 0 by 0, which would fill the column with NaN.
    """
    peak = group['escape'].max()
    if peak == 0:
        return group.copy()
    return group.assign(escape=group['escape'] / peak)

# Normalise escape within each serum, flatten the index again, floor at zero,
# and recompute the per-(site, cohort) mean from the normalised values.
normalized_df = (
    grouped.apply(normalize)
    .reset_index(drop=True)
    .assign(escape=lambda frame: frame['escape'].clip(lower=0))
    .assign(
        mean_cohort_escape=lambda frame: (
            frame.groupby(['site', 'cohort'])['escape'].transform('mean')
        )
    )
)

In [12]:
# Left-to-right order of sites on the x axis, as strings. One row per antigenic
# region (per the region site lists defined earlier) rather than numeric order.
site_order = [
    str(site)
    for site in (
        121, 122, 124, 131, 135, 137, 142, 144, 145,        # region A
        156, 157, 159, 160, 186, 188, 189, 192, 193,        # region B
        48, 50, 275, 276,                                   # region C
        81, 82,                                             # region E
    )
]

In [9]:
# Key sites only: one faint line per serum with the cohort mean overlaid as
# points, faceted into one row per cohort.
# NOTE(review): the `sort` lists below name 'vietnam_ped' and 'misc_adult', while
# the cohort labels defined in `sample_dict` are 'vietnam_ped_189K' /
# 'vietnam_ped_189N' and 'misc_adult' rows are dropped before plotting -
# confirm the intended ordering.
summed_escape_lineplot = (
    alt.Chart()
    .encode(
        x=alt.X("site",
                title="site",
                sort=site_order
               ),
        y=alt.Y(
            "escape",
            scale=alt.Scale(
                domain=[-1, 45],
                # clamp=True
            ),
            axis=alt.Axis(values=[0, 10, 20, 30, 40, 50]),
            title="escape score",
        ),
        color=alt.Color('cohort:N',
                        legend=None,
                        sort=['vietnam_ped', 'vietnam_adult', 'misc_adult'],
                       ).scale(scheme='dark2'),
        # a separate line for every (serum, antigenic region) pair
        detail=alt.Detail(['serum', 'antigenic_region']),
        tooltip=['serum', 'site', 'escape']
    )
    .mark_line(size=2.7, opacity=0.3, clip=True)
    .properties(width=500, height=130)
)

mean_points = (
    alt.Chart()
    .encode(
        x=alt.X("site",
                title="site",
                sort=site_order
               ),
        y=alt.Y("mean_cohort_escape"),
        color=alt.Color('cohort:N',
                        legend=None,
                        sort=['vietnam_ped', 'vietnam_adult', 'misc_adult'],
                       ).scale(scheme='dark2'),
        # Show the plotted quantity (the cohort mean), not one serum's escape;
        # this matches the tooltip used in the normalised version of this plot.
        tooltip=['site', 'mean_cohort_escape']
    )
    .mark_circle(size=55, opacity=0.75)
    # .properties(width=400, height=150)
)

# horizontal reference rule at y = 0
x_axis = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(
    size=1,
    opacity=0.5, color='gray').encode(y='y')

faceted_line_scatter_overlay = alt.layer(
    summed_escape_lineplot, mean_points, x_axis, data=escape_df_filtered
).facet(
    facet=alt.Facet(
        'cohort:N',
        sort=['vietnam_ped', 'vietnam_adult', 'misc_adult'],
        title='summed escape',
        header=alt.Header(
            titleFontSize=21,
            titleFontWeight='normal',
            titlePadding=5,
            labelFontSize=17,
            labelOrient='right',
            # labelAngle=0,
            labelFontStyle='italic'
        )
    ),
    spacing=3,
    columns=1
).configure_axis(
    grid=False,
    labelFontSize=14,
    titleFontSize=15,
)

# faceted_line_scatter_overlay.save(
#     'scratch_notebooks/figure_drafts/perth09_analysis/sitewise_escape_plots/231004_perth09_escape.png',
#     scale_factor=2.0
# )

faceted_line_scatter_overlay

In [10]:
# Same layout as the raw key-site plot, drawn from `normalized_df` (escape is
# scaled per serum, so the y axis runs from 0 to 1).
# NOTE(review): 'vietnam_ped' in the `sort` lists is not one of the cohort
# labels defined in `sample_dict` - confirm the intended ordering.

# Layer 1: one faint line per (serum, antigenic region).
summed_escape_lineplot = (
    alt.Chart()
    .mark_line(size=2.7, opacity=0.3, clip=True)
    .encode(
        x=alt.X("site", title="site", sort=site_order),
        y=alt.Y(
            "escape",
            title="escape score",
            scale=alt.Scale(domain=[-0.1, 1.1], clamp=True),
            axis=alt.Axis(values=[0, 0.5, 1]),
        ),
        color=alt.Color(
            'cohort:N',
            legend=None,
            sort=['vietnam_ped', 'vietnam_adult', 'misc_adult'],
        ).scale(scheme='dark2'),
        detail=alt.Detail(['serum', 'antigenic_region']),
        tooltip=['serum', 'site', 'escape'],
    )
    .properties(width=500, height=130)
)

# Layer 2: cohort mean at each site.
mean_points = (
    alt.Chart()
    .mark_circle(size=55, opacity=0.75)
    .encode(
        x=alt.X("site", title="site", sort=site_order),
        y=alt.Y("mean_cohort_escape"),
        color=alt.Color(
            'cohort:N',
            legend=None,
            sort=['vietnam_ped', 'vietnam_adult', 'misc_adult'],
        ).scale(scheme='dark2'),
        tooltip=['site', 'mean_cohort_escape'],
    )
)

# Layer 3: horizontal reference rule at y = 0.
x_axis = (
    alt.Chart(pd.DataFrame({'y': [0]}))
    .mark_rule(size=1, opacity=0.5, color='gray')
    .encode(y='y')
)

# Stack the layers and draw one row per cohort.
faceted_line_scatter_overlay = (
    alt.layer(summed_escape_lineplot, mean_points, x_axis, data=normalized_df)
    .facet(
        facet=alt.Facet(
            'cohort:N',
            sort=['vietnam_ped', 'vietnam_adult'],
            title='summed escape',
            header=alt.Header(
                titleFontSize=21,
                titleFontWeight='normal',
                titlePadding=5,
                labelFontSize=17,
                labelOrient='right',
                labelFontStyle='italic',
            ),
        ),
        spacing=3,
        columns=1,
    )
    .configure_axis(grid=False, labelFontSize=14, titleFontSize=15)
)

# faceted_line_scatter_overlay.save(
#     'scratch_notebooks/figure_drafts/perth09_analysis/sitewise_escape_plots/231004_perth09_normalized.png',
#     scale_factor=2.0
# )

faceted_line_scatter_overlay

In [41]:
# All sites, one line per serum, faceted by cohort.
#
# The x axis needs integer site numbers. Cast them on a derived frame rather
# than overwriting `escape_df['site']`: other cells compare that column against
# string site labels and would silently match nothing after an in-place cast.
escape_df_numeric_site = escape_df.assign(site=escape_df['site'].astype(int))

# NOTE(review): the `sort` lists below name 'vietnam_ped' and 'misc_adult', which
# are not among the cohort labels present in the data at this point - confirm.
summed_escape_lineplot_full = (
    alt.Chart()
    .encode(
        x=alt.X("site",
                title="site",
                scale=alt.Scale(domain=[0, 540],
                                clamp=True
                               )
               ),
        y=alt.Y(
            "escape",
            title=None,
            scale=alt.Scale(domain=[-5, 55],
                            clamp=True
                           )
        ),
        color=alt.Color('cohort:N',
                        legend=None,
                        sort=['vietnam_ped', 'vietnam_adult', 'misc_adult'],
                       ).scale(scheme='dark2'),
        detail='serum',
        tooltip=['serum', 'site', 'escape']
    )
    .mark_line(size=.75, opacity=0.7)
    .properties(width=600, height=70)
)

# horizontal reference rule at y = 0
x_axis = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(
    size=1,
    opacity=0.5, color='gray').encode(y='y')

faceted_lineplot_full = alt.layer(
    summed_escape_lineplot_full, x_axis, data=escape_df_numeric_site
).facet(
    facet=alt.Facet(
        'cohort:N',
        sort=['vietnam_ped', 'vietnam_adult', 'misc_adult'],
        title='escape at all residues',
        header=alt.Header(
            titleFontSize=21,
            titleFontWeight='normal',
            titlePadding=5,
            labelFontSize=1,
            labelOrient='right',
            # labelAngle=0,
            labelFontStyle='italic',
        )
    ),
    spacing=2,
    columns=1
).configure_axis(
    grid=False,
    labelFontSize=13,
    titleFontSize=15
)

faceted_lineplot_full.save(
    'scratch_notebooks/figure_drafts/perth09_analysis/sitewise_escape_plots/231004_summed_escape_full.png',
    scale_factor=2.0
)

faceted_lineplot_full

In [104]:
# All sites with escape floored at zero, one line per serum, faceted by cohort.
#
# Build the plotting frame with integer sites and floored escape in one step,
# without overwriting `escape_df['site']`: other cells compare that column
# against string site labels and would silently match nothing after an
# in-place cast.
escape_df_floored = escape_df.assign(
    site=escape_df['site'].astype(int),
    escape=escape_df['escape'].clip(lower=0),
)

# NOTE(review): the `sort` lists below name 'vietnam_ped' and 'misc_adult', which
# are not among the cohort labels present in the data at this point - confirm.
summed_escape_lineplot_full = (
    alt.Chart()
    .encode(
        x=alt.X("site",
                title="site",
                scale=alt.Scale(domain=[0, 540],
                                clamp=True
                               )
               ),
        y=alt.Y(
            "escape",
            title=None,
            scale=alt.Scale(domain=[-5, 55],
                            clamp=True
                           )
        ),
        color=alt.Color('cohort:N',
                        legend=None,
                        sort=['vietnam_ped', 'vietnam_adult', 'misc_adult'],
                       ).scale(scheme='dark2'),
        detail='serum',
        tooltip=['serum', 'site', 'escape']
    )
    .mark_line(size=.75, opacity=0.7)
    .properties(width=550, height=130)
)

# horizontal reference rule at y = 0
x_axis = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(
    size=1,
    opacity=0.5, color='gray').encode(y='y')

faceted_lineplot_full = alt.layer(
    summed_escape_lineplot_full, x_axis, data=escape_df_floored
).facet(
    facet=alt.Facet(
        'cohort:N',
        sort=['vietnam_ped', 'vietnam_adult', 'misc_adult'],
        title='escape at all residues',
        header=alt.Header(
            titleFontSize=21,
            titleFontWeight='normal',
            titlePadding=5,
            labelFontSize=17,
            labelOrient='right',
            # labelAngle=0,
            labelFontStyle='italic',
        )
    ),
    spacing=2,
    columns=1
).configure_axis(
    grid=False,
    labelFontSize=13,
    titleFontSize=15
)

faceted_lineplot_full.save(
    'scratch_notebooks/figure_drafts/perth09_analysis/sitewise_escape_plots/230810_summed_escape_full_floored.png',
    scale_factor=2.0
)

faceted_lineplot_full

In [9]:
# Site-189 rows for the two pediatric Vietnam cohorts.
# Compare on the string form of `site` so the selection works whether the
# column currently holds strings or integers (other cells in this notebook
# cast it to int), and hand the plot a string `site` either way.
escape_df_189 = escape_df.loc[
    (escape_df['site'].astype(str) == '189')
    & escape_df['cohort'].isin(['vietnam_ped_189K', 'vietnam_ped_189N'])
].assign(site=lambda frame: frame['site'].astype(str))

In [10]:
# Display the site-189 subset selected above.
escape_df_189

Unnamed: 0,serum,site,wildtype,escape,cohort,mean_cohort_escape
336,age 2.1 (Vietnam),189,K,23.981084,vietnam_ped_189K,36.049357
902,age 2.2 (Vietnam),189,K,31.83535,vietnam_ped_189K,36.049357
1468,age 2.4 (Vietnam),189,K,0.0,vietnam_ped_189N,0.022521
2034,age 2.5 (Vietnam),189,K,55.2002,vietnam_ped_189K,36.049357
2600,age 2.5b (Vietnam),189,K,42.86996,vietnam_ped_189K,36.049357
3732,age 3.3 (Vietnam),189,K,0.0,vietnam_ped_189N,0.022521
4298,age 3.3b (Vietnam),189,K,0.090084,vietnam_ped_189N,0.022521
4864,age 3.4 (Vietnam),189,K,0.0,vietnam_ped_189N,0.022521
5430,age 3.5 (Vietnam),189,K,26.36019,vietnam_ped_189K,36.049357


In [22]:
# Two greys used as the categorical colour range for the two cohorts.
# Other candidates noted while choosing:
#   189N - BDBEC0
#   189K - 6E6F72
#   58595B
#   D1D3D4
colors = ['#58595B', '#BCBEC0']

# One circle per serum at site 189, spread horizontally by a random offset so
# overlapping points stay visible. The offset comes from `random()` in the
# chart expression, so the horizontal positions differ on every render.
summed_escape_lineplot = (
    alt.Chart(escape_df_189)
    .transform_calculate(jitter="sqrt(-2*log(random()))*cos(2*PI*random())")
    .mark_circle(size=150)
    .encode(
        x=alt.X("site", title="site", sort=site_order),
        y=alt.Y(
            "escape",
            title="escape score",
            scale=alt.Scale(domain=[-1, 35], clamp=True),
            axis=alt.Axis(values=[0, 10, 20, 30]),
        ),
        xOffset='jitter:Q',
        color=alt.Color('cohort:N', legend=None),
        detail=alt.Detail(['serum']),
        tooltip=['serum', 'site', 'escape'],
    )
    .properties(width=60, height=150)
    .configure_range(category=alt.RangeScheme(colors))
    .configure_axis(grid=False, labelFontSize=18, titleFontSize=15)
)

summed_escape_lineplot.save(
    'scratch_notebooks/figure_drafts/perth09_analysis/sitewise_escape_plots/ped-189.png',
    scale_factor=2.0
)

summed_escape_lineplot