In [1]:
import pickle

import altair as alt

import pandas as pd

import polyclonal

import warnings
warnings.filterwarnings('ignore')

from IPython.utils import io

In [2]:
import altair_saver

In [3]:
import os
os.chdir('../../../')

## Plotting immune escape across full protein and selected sites

In [8]:
def get_summed_escapes(sera_list, cohort, site_list=None):
    summed_escape_list = []
    
    for serum in sera_list:
        prob_escape = pd.read_csv(
            f'results/antibody_escape/{serum}_avg.csv'
        ).query(
            "`times_seen` >= 5"
        )
        
        prob_escape_sum = (
            prob_escape.groupby(['site', 'wildtype'], as_index=False)
            .aggregate({'escape_mean': 'sum'})
            .rename(columns={'escape_mean': 'escape'})
        )

        # if site_list:
        #     prob_escape_final = prob_escape_sum[prob_escape_sum['site'].isin(site_list)]
        #     prob_escape_final['site'] = pd.Categorical(prob_escape_final['site'], ordered=True)
            # prob_escape_final['site'] = prob_escape_final['site'].astype(str)

        # else:
        #     prob_escape_final = prob_escape_sum.copy()
            
        prob_escape_sum['serum'] = serum
        prob_escape_sum['cohort'] = cohort
        
        summed_escape_list.append(prob_escape_sum)
        
    summed_escape = pd.concat(summed_escape_list)
    return summed_escape

Get HK19 data:

In [22]:
# define samples in each age cohort
peds = [3944, 2389, 2323, 2388, 3973, 4299, 4584, 2367]
teens = [2350, 2365, 2382, 3866, 2380, 3856, 3857, 3862]
adults = ['33C', '34C', '197C', '199C', '215C', 
          '210C', '74C', '68C', '150C', '18C',]
elderly = ['AUSAB-13']
infant = [2462]

sample_lists = [peds, teens, adults, elderly, infant]
cohorts = ['02-05 years', '15-20 years', '40-45 years',  '68 years', 'infant']
cohort_dfs = []

i=0 # for looping through age cohort definitions

# start by getting full escape df for HK19
for entry in sample_lists:
    cohort_escape_df = get_summed_escapes(entry, cohorts[i], site_list)
    cohort_dfs.append(cohort_escape_df)

    i+=1

hk19_df = pd.concat(cohort_dfs)

hk19_df['serum'] = hk19_df['serum'].astype(str)

# add 'escape_mean' column of mean escape values per site within an age group
hk19_df['escape_mean'] = (
    hk19_df.groupby(['site', 'cohort'])['escape']
    .transform('mean')
)

hk19_df['ha_strain'] = 'hk19'

Get Perth09 data:

In [23]:
# define samples in each age cohort
sample_dict = {
    "2-4_years": [
        "age 2.1 (Vietnam)", 
        "age 2.2 (Vietnam)",
        "age 2.4 (Vietnam)",
        "age 2.5 (Vietnam)",
        "age 2.5b (Vietnam)",
        "age 3.3 (Vietnam)", 
        "age 3.3b (Vietnam)",
        "age 3.4 (Vietnam)", 
        "age 3.5 (Vietnam)",
    ],   
    "30-34_years": [
        "age 30.5 (Vietnam)",
        "age 31.5 (Vietnam)",
        "age 33.5 (Vietnam)",
    ],
    "misc_adult": [
        "age 21 (Seattle)",
        "age 53 (Seattle)",
        "age 64 (Seattle)",
        "age 65 (Seattle)",
    ],
    "ferret": [
        "ferret 1 (Pitt)",
        "ferret 2 (Pitt)",
        "ferret 3 (Pitt)",
        "ferret (WHO)",
    ]
}

# get full dataset
perth09_df = pd.read_csv(f'results/perth2009/merged_escape.csv')[['name', 'site', 'wildtype', 'mutant', 'escape']]
perth09_df = perth09_df.rename(columns={'name': 'serum'})

# Function to convert '(HA2)X' to numeric
def convert_site_to_numeric(site):
    if '(HA2)' in site:
        try:
            number = int(site.replace('(HA2)', '').strip())
            return number + 329
        except ValueError:
            return site  # If there's an issue with conversion, return the original value
    else:
        return site

# Apply the function to the 'site' column
perth09_df['site'] = perth09_df['site'].apply(convert_site_to_numeric)

# get summed escape at each site
perth09_df = perth09_df.groupby(['serum', 'site', 'wildtype'], as_index=False).aggregate({'escape': 'sum'})

# floor at 0
perth09_df['escape'] = perth09_df['escape'].clip(lower=0)

# add cohort label
def find_sample_type(sample_name):
    for sample_type, sample_list in sample_dict.items():
        if sample_name in sample_list:
            return sample_type
    return None

perth09_df['cohort'] = perth09_df['serum'].apply(find_sample_type)

perth09_df = perth09_df.loc[(perth09_df['cohort'] != 'misc_adult')]
perth09_df['site'] = perth09_df['site'].astype(int)

# add 'escape_mean' column of mean escape values per site within a cohort
perth09_df['escape_mean'] = (
    perth09_df.groupby(['site', 'cohort'])['escape']
    .transform('mean')
)

perth09_df['ha_strain'] = 'perth09'

Combine to a single dataframe, and label antigenic regions:

In [29]:
# make single full dataframe
escape_df = pd.concat([perth09_df, hk19_df])

# add antigenic region labels
site_A = list(range(121, 147))
site_B = list(range(155, 161)) + list(range(186, 199))
site_C = list(range(44, 55)) + list(range(273, 281))
site_D = list(range(201,220))
site_E = list(range(62, 66)) + list(range(78, 95)) + list(range(260, 266))

antigenic_regions = {
    'A': site_A, 
    'B': site_B, 
    'C': site_C, 
    'D': site_D,
    'E': site_E
}

# Function to map sites to antigenic regions
def map_site_to_antigenic_region(site):
    for region, sites in antigenic_regions.items():
        if site in sites:
            return region
    return None  # Handle the case where the site doesn't belong to any region

# Apply the mapping function to create the 'antigenic region' column
escape_df['antigenic_region'] = escape_df['site'].map(map_site_to_antigenic_region)

### Filter to selected sites and normalize

In [30]:
# define list of key sites
site_list = [48, 50, 81, 82, 121, 122, 124, 131, 135, 137, 145, 156, 157, 
              159, 160, 186, 188, 189, 192, 193, 275, 276]

# filter dataframe
escape_df_filtered = escape_df[escape_df['site'].isin(site_list)]

escape_df_filtered['site'] = escape_df_filtered['site'].astype(str)

Generate normalized version of the dataframe:

In [49]:
# Group the DataFrame by the 'serum' column
grouped = escape_df_filtered.groupby('serum')

# Define a function to normalize the 'escape_mean' column within each group
def normalize(group):
    group['escape'] = group['escape'] / group['escape'].max()
    return group

# Apply the normalization function to each group
normalized_df = grouped.apply(normalize)

# Reset the index of the resulting DataFrame
normalized_df.reset_index(drop=True, inplace=True)

# set lower bound to 0
normalized_df['escape'] = normalized_df['escape'].clip(lower=0)

# replace mean_escape_mean values
normalized_df['escape_mean'] = (
    normalized_df.groupby(['site', 'cohort'])['escape']
    .transform('mean')
)

## Set up different chart options

In [34]:
site_order = ['121', '122', '124', '131', '135', '137', '145', 
              '156', '157', '159', '160', '186', '188', '189', '192', '193', 
              '48', '50', '275', '276', 
              '81', '82']  

libraries = ['hk19', 'perth09']

In [63]:
def filtered_site_plot(filtered_df, ha_strain, domain, axis_values=None):
    
    filtered_df = filtered_df.loc[filtered_df['ha_strain'] == ha_strain]

    if axis_values:
        filtered_sites_lineplot = (
            alt.Chart()
            .encode(
                x=alt.X("site", 
                        title="site",
                        sort=site_order                 
                       ),
                y=alt.Y(
                    "escape",
                    scale=alt.Scale(domain=domain),
                    title="escape score",
                    axis=alt.Axis(values=axis_values),
                ),
                color=alt.Color('cohort:N', 
                                legend=None
                               ).scale(scheme='dark2'),
                # detail='serum',
                detail=alt.Detail(['serum', 'antigenic_region']),
                tooltip=['serum', 'site', 'escape']
            )
            .mark_line(size=2.7, opacity=0.3, clip=True)
            .properties(width=500, height=130)
        )
        
    else:
        filtered_sites_lineplot = (
            alt.Chart()
            .encode(
                x=alt.X("site", 
                        title="site",
                        sort=site_order                 
                       ),
                y=alt.Y(
                    "escape",
                    scale=alt.Scale(domain=domain),
                    title="escape score",
                ),
                color=alt.Color('cohort:N', 
                                legend=None
                               ).scale(scheme='dark2'),
                # detail='serum',
                detail=alt.Detail(['serum', 'antigenic_region']),
                tooltip=['serum', 'site', 'escape']
            )
            .mark_line(size=2.7, opacity=0.3, clip=True)
            .properties(width=500, height=130)     
        )
    
    mean_points = (
        alt.Chart()
        .encode(
            x=alt.X("site", 
                    title="site",
                    sort=site_order   
                   ),
            y=alt.Y("escape_mean"),
            color=alt.Color('cohort:N', legend=None).scale(scheme='dark2'),
            tooltip=['site', 'escape_mean']
        )
        .mark_circle(size=55, opacity=0.75)
    )
    
    x_axis = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(
        size=1, 
        opacity=0.5, color='gray').encode(y='y')
    
    faceted_overlay = alt.layer(
        filtered_sites_lineplot, mean_points, x_axis, data=filtered_df
    ).facet(
        facet=alt.Facet(
            'cohort:N',
            title='summed escape',
            header=alt.Header(
                titleFontSize=21,
                titleFontWeight='normal',
                titlePadding=5,
                labelFontSize=17,
                labelOrient='right',
                # labelAngle=0,
                labelFontStyle='italic'
            )
        ),
        spacing=3,
        columns=1
    ).configure_axis(
        grid=False,
        labelFontSize=14,
        titleFontSize=15
    )   

    return faceted_overlay

In [64]:
perth09_site_plot = filtered_site_plot(escape_df_filtered, 'perth09', domain=[-1, 45])
perth09_site_plot

In [65]:
hk19_site_plot = filtered_site_plot(escape_df_filtered, 'hk19', domain=[-12.5, 12.5])
hk19_site_plot

In [66]:
filtered_site_plot(normalized_df, 'perth09', domain=[-0.1, 1.1], axis_values=[0, 0.5, 1])

In [67]:
filtered_site_plot(normalized_df, 'hk19', domain=[-0.1, 1.1], axis_values=[0, 0.5, 1])

In [12]:
# filtered sites, line and scatter overlay
summed_escape_lineplot = (
    alt.Chart()
    .encode(
        x=alt.X("site", 
                title="site",
                sort=site_order                 
               ),
        y=alt.Y(
            "escape_mean",
            scale=alt.Scale(domain=[-12.5, 12.5]),
            title="escape score",
        ),
        color=alt.Color('age_group:N', 
                        legend=None
                       ).scale(scheme='dark2'),
        detail='serum',
        tooltip=['serum', 'site', 'escape_mean']
    )
    .mark_line(size=1, opacity=0.4, clip=True)
    .properties(width=500, height=130)
)

mean_points = (
    alt.Chart()
    .encode(
        x=alt.X("site", 
                title="site",
                sort=site_order   
               ),
        y=alt.Y("mean_escape_mean"),
        color=alt.Color('age_group:N', legend=None).scale(scheme='set2'),
        tooltip=['site', 'escape_mean']
    )
    .mark_circle(size=30, opacity=0.75)
    # .properties(width=400, height=150)
)

x_axis = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(
    size=1, 
    opacity=0.5, color='gray').encode(y='y')

faceted_line_scatter_overlay = alt.layer(
    summed_escape_lineplot, data=escape_df_filtered
# ).facet(
#     facet=alt.Facet(
#         'age_group:N',
#         title='summed escape',
#         header=alt.Header(
#             titleFontSize=21,
#             titleFontWeight='normal',
#             titlePadding=5,
#             labelFontSize=17,
#             labelOrient='right',
#             # labelAngle=0,
#             labelFontStyle='italic'
#         )
#     ),
#     spacing=3,
#     columns=1
).configure_axis(
    grid=False,
    labelFontSize=14,
    titleFontSize=15
)

faceted_line_scatter_overlay.save(
    'scratch_notebooks/figure_drafts/sitewise_escape/23090912_filt_escape_overlay.png',
    scale_factor=2.0
)

faceted_line_scatter_overlay

In [14]:
# filtered sites, line and scatter overlay
norm_escape_lineplot = (
    alt.Chart()
    .encode(
        x=alt.X("site", 
                title="site",
                sort=site_order 
               ),
        y=alt.Y(
            "escape_mean",
            # scale=alt.Scale(domain=[-12.5, 12.5]),
            scale=alt.Scale(domain=[-0.1, 1.1], clamp=True),
            title="escape score",
            axis=alt.Axis(values=[0, 0.5, 1])
        ),
        color=alt.Color('age_group:N', 
                        legend=None
                       ).scale(scheme='dark2'),
        detail=['serum', 'antigenic_region'],
        tooltip=['serum', 'site', 'escape_mean']
    )
    .mark_line(size=2.7, opacity=0.3, clip=True)
    .properties(width=500, height=130)
)

mean_points_norm = (
    alt.Chart()
    .encode(
        x=alt.X("site", 
                title="site",
                sort=site_order 
               ),
        y=alt.Y("mean_escape_mean"),
        color=alt.Color('age_group:N', legend=None).scale(scheme='dark2'),
        tooltip=['site', 'mean_escape_mean']
    )
    .mark_circle(size=55, opacity=0.75)
    # .properties(width=400, height=150)
)

faceted_norm_overlay = alt.layer(
    norm_escape_lineplot, mean_points_norm, data=normalized_df
).facet(
    facet=alt.Facet(
        'age_group:N',
        title='normalized positive summed escape',
        header=alt.Header(
            titleFontSize=21,
            titleFontWeight='normal',
            titlePadding=5,
            labelFontSize=17,
            labelOrient='right',
            # labelAngle=0,
            labelFontStyle='italic'
        )
    ),
    spacing=3,
    columns=1
).configure_axis(
    grid=False,
    labelFontSize=14,
    titleFontSize=15
)

# faceted_norm_overlay.save(
#     'scratch_notebooks/figure_drafts/sitewise_escape/230923_normalized_escape.png',
#     scale_factor=2.0
# )

faceted_norm_overlay

In [22]:
# filtered sites, line and scatter overlay
summed_escape_lineplot = (
    alt.Chart()
    .encode(
        x=alt.X("site", 
                title="site",
                sort=site_order
               ),
        y=alt.Y(
            "escape_mean",
            # scale=alt.Scale(domain=[-12.5, 12.5]),
            scale=alt.Scale(domain=[-0.1, 1.1], clamp=True),
            title="escape score",
            axis=alt.Axis(values=[0, 0.5, 1])
        ),
        color=alt.Color('age_group:N', 
                        legend=None
                       ).scale(scheme='dark2'),
        detail='serum',
        tooltip=['serum', 'site', 'escape_mean']
    )
    .mark_line(size=1.5, opacity=0.4, clip=True)
    .properties(width=500, height=130)
)

mean_points = (
    alt.Chart()
    .encode(
        x=alt.X("site", title="site"),
        y=alt.Y("mean_escape_mean"),
        color=alt.Color('age_group:N', legend=None).scale(scheme='dark2'),
        tooltip=['site', 'escape_mean']
    )
    .mark_circle(size=30, opacity=0.75)
    # .properties(width=400, height=150)
)

x_axis = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(
    size=0.5, 
    opacity=0.5, color='gray').encode(y='y')

line_scatter_overlay = alt.layer(
    summed_escape_lineplot, data=normalized_df
).configure_axis(
    grid=False,
    labelFontSize=14,
    titleFontSize=15
)

line_scatter_overlay.save(
    'scratch_notebooks/figure_drafts/sitewise_escape/230912_normalized_escape_overlay.png',
    scale_factor=2.0
)

line_scatter_overlay

In [13]:
# site_brush = alt.selection_interval(
#     encodings=["x"],
#     mark=alt.BrushConfig(stroke="black", strokeWidth=2),
# )

# site_zoom_bar = (
#     alt.Chart(escape_df_full)
#     .mark_rect()
#     .encode(
#         x=alt.X(
#             "site:O",
#             sort=alt.EncodingSortField(field="_stat_site_order", order="ascending"),
#         ),
#         color=(
#             alt.Color(
#                 'antigenic_region:N',
#                 # scale=alt.Scale(scheme=site_zoom_bar_color_scheme),
#                 legend=alt.Legend(orient="left"),
#                 # sort=(
#                 #     site_zoom_bar_df.set_index("site")
#                 #     .loc[sites]['antigenic_region']
#                 #     .unique()
#                 # ),
#             )
#             # if site_zoom_bar_color_col
#             # else alt.value("gray")
#         ),
#         # tooltip=[c for c in site_zoom_bar_df.columns if not c.startswith("_stat")],
#     )
#     .mark_rect()
#     .add_params(site_brush)
#     # .properties(width=500, height=10, title="site zoom bar")
# )



# all sites, lineplot
summed_escape_lineplot_full = (
    alt.Chart()
    .encode(
        x=alt.X("site", 
                title="site",
                scale=alt.Scale(domain=[0, 540])
               ),
        y=alt.Y(
            "escape_mean",
            title=None,
            scale=alt.Scale(domain=[-12.5, 12.5],
                            clamp=True
                           )
        ),
        color=alt.Color('age_group:N', 
                        legend=None,
                       ).scale(scheme='dark2'),
        detail='serum',
        tooltip=['serum', 'site', 'escape_mean']
    )
    .mark_line(size=.75, opacity=0.7)
    .properties(width=600, height=70)
)

faceted_lineplot_full = alt.layer(
    summed_escape_lineplot_full, data=escape_df_full
).facet(
    facet=alt.Facet(
        'age_group:N',
        title='escape at all residues',
        header=alt.Header(
            titleFontSize=21,
            titleFontWeight='normal',
            titlePadding=5,
            # labelExpr="''", title=None
            labelFontSize=1,
            labelOrient='right',
            # labelAngle=0,
            # labelFontStyle='italic',
        )
    ),
    spacing=2,
    columns=1
).configure_axis(
    grid=False,
    labelFontSize=13,
    titleFontSize=15
)

faceted_lineplot_full.save(
    'scratch_notebooks/figure_drafts/sitewise_escape/231003_summed_escape_full.png',
    scale_factor=2.0
)

faceted_lineplot_full

In [18]:
site_brush = alt.selection_interval(
    encodings=["x"],
    mark=alt.BrushConfig(stroke="black", strokeWidth=2),
)

site_zoom_bar = (
    alt.Chart(escape_df_full)
    .mark_rect()
    .encode(
        x=alt.X(
            "site:O",
            sort=alt.EncodingSortField(field="_stat_site_order", order="ascending"),
        ),
        color=(
            alt.Color(
                'antigenic_region:N',
                # scale=alt.Scale(scheme=site_zoom_bar_color_scheme),
                legend=alt.Legend(orient="left"),
                # sort=(
                #     site_zoom_bar_df.set_index("site")
                #     .loc[sites]['antigenic_region']
                #     .unique()
                # ),
            )
            # if site_zoom_bar_color_col
            # else alt.value("gray")
        ),
        # tooltip=[c for c in site_zoom_bar_df.columns if not c.startswith("_stat")],
    )
    .mark_rect()
    # .add_params(site_brush)
    .properties(width=500, height=10, title="site zoom bar")
)

site_zoom_bar

In [None]:
# create site zoom bar
    site_brush = alt.selection_interval(
        encodings=["x"],
        mark=alt.BrushConfig(stroke="black", strokeWidth=2),
    )
    if site_zoom_bar_color_col:
        assert site_zoom_bar_color_col not in {"_n", "_drop"}, site_zoom_bar_color_col
        site_zoom_bar_df = (
            data_df[["site", "_stat_site_order", site_zoom_bar_color_col]]
            .drop_duplicates()
            .assign(
                _n=lambda x: (
                    x.groupby("site")[site_zoom_bar_color_col].transform("size")
                ),
                _drop=lambda x: (
                    x[site_zoom_bar_color_col]
                    .isnull()
                    .where(
                        x["_n"] > 1,
                        False,
                    )
                ),
            )
            .query("not _drop")
            .drop(columns=["_n", "_drop"])
        )
        site_zoom_bar_df[site_zoom_bar_color_col] = site_zoom_bar_df[
            site_zoom_bar_color_col
        ].fillna("null")
        if any(site_zoom_bar_df.groupby("site").size() > 1):
            raise ValueError(
                f"multiple {site_zoom_bar_color_col=} values for sites:"
                + str(
                    site_zoom_bar_df.assign(
                        n=lambda x: (
                            x.groupby("site")[site_zoom_bar_color_col].transform("size")
                        ),
                    )
                    .sort_values("n", ascending=False)
                    .reset_index(drop=True)
                )
            )
    else:
        site_zoom_bar_df = data_df[["site", "_stat_site_order"]].drop_duplicates()
    site_zoom_bar = (
        alt.Chart(site_zoom_bar_df)
        .mark_rect()
        .encode(
            x=alt.X(
                "site:O",
                sort=alt.EncodingSortField(field="_stat_site_order", order="ascending"),
            ),
            color=(
                alt.Color(
                    site_zoom_bar_color_col,
                    type="nominal",
                    scale=alt.Scale(scheme=site_zoom_bar_color_scheme),
                    legend=alt.Legend(orient="left"),
                    sort=(
                        site_zoom_bar_df.set_index("site")
                        .loc[sites][site_zoom_bar_color_col]
                        .unique()
                    ),
                )
                if site_zoom_bar_color_col
                else alt.value("gray")
            ),
            tooltip=[c for c in site_zoom_bar_df.columns if not c.startswith("_stat")],
        )
        .mark_rect()
        .add_params(site_brush)
        .properties(width=site_zoom_bar_width, height=cell_size, title="site zoom bar")
    )



# Chart comparison

In [15]:
faceted_line_scatter_overlay

In [16]:
faceted_lineplot_full

## scratch code

In [75]:
def get_escapes(sera_list, age_group, agg_type, site_list=None):
    escape_df_list = []
    
    for serum in sera_list:
        prob_escape = pd.read_csv(
            f'results/antibody_escape/{serum}_avg.csv'
        ).query(
            "`times_seen` >= 5"
        )

        # prob_escape['summed'] = prob_escape.groupby(['site', 'wildtype'])['escape_mean'].transform('sum')
        # prob_escape['mean'] = prob_escape.groupby(['site', 'wildtype'])['escape_mean'].transform('mean')
        # prob_escape = prob_escape[['site', 'wildtype', 'summed', 'mean']].drop_duplicates()

        # prob_escape = prob_escape.melt(
        #     id_vars=['site', 'wildtype'], 
        #     value_vars=['summed', 'mean'],
        #     var_name='escape_type',
        #     value_name='escape'
        # )

        prob_escape['escape'] = prob_escape.groupby(['site', 'wildtype'])['escape_mean'].transform(agg_type)
        prob_escape = prob_escape[['site', 'wildtype', 'escape']].drop_duplicates()

        if site_list:
            prob_escape = prob_escape[prob_escape['site'].isin(site_list)]
            prob_escape['site'] = pd.Categorical(prob_escape['site'], ordered=True)
            prob_escape['site'] = prob_escape['site'].astype(str)
            
        prob_escape['serum'] = serum
        prob_escape['age_group'] = age_group
        
        escape_df_list.append(prob_escape)
        
    full_escape_df = pd.concat(escape_df_list)
    return full_escape_df

In [76]:
# initialize list of key sites, plus samples in each age cohort
site_list = [50, 82, 103, 121, 122, 124, 131, 135, 137, 138, 145, 156, 157, 
              159, 160, 186, 188, 189, 193, 220, 224, 244, 276]

peds = [2367, 3944, 2462, 2389, 2323, 2388, 2463, 3973, 4299, 4584]
teens = [2343, 2350, 2365, 2382, 3866, 2380, 3856, 3857, 3862, 3895]
adults = ['33C', '34C', '197C', '199C', '215C', '210C', '74C', '68C', '150C', '18C']
ferrets = ['ferret_1', 'ferret_2', 'ferret_3']

sample_lists = [peds, teens, adults, ferrets]
cohorts = ['0-5', '15-18', '40-45', 'ferrets']
summed_escapes_filtered = []
mean_escapes_filtered = []

i=0 # for looping through age cohort definitions

# start by getting full escape df filtered to key sites
for list in sample_lists:
    escape_sum = get_escapes(list, cohorts[i], 'sum', site_list)
    escape_mean = get_escapes(list, cohorts[i], 'mean', site_list)
    
    summed_escapes_filtered.append(escape_sum)
    mean_escapes_filtered.append(escape_mean)

    i+=1

filtered_df_sum = pd.concat(summed_escapes_filtered)
filtered_df_mean = pd.concat(mean_escapes_filtered)

site_dict = {'50': '050', 
             '62': '062', 
             '82': '082', 
             '94': '094'}

filtered_df_sum['site'] = filtered_df_sum['site'].apply(
    lambda x: site_dict[x] if x in site_dict else x
)
filtered_df_mean['site'] = filtered_df_mean['site'].apply(
    lambda x: site_dict[x] if x in site_dict else x
)

filtered_df_sum['serum'] = filtered_df_sum['serum'].astype(str)
filtered_df_mean['serum'] = filtered_df_mean['serum'].astype(str)

# add 'escape_mean' columns of mean escape values per site within an age group
filtered_df_sum['escape_mean'] = (
    filtered_df_sum.groupby(['site', 'age_group'])['escape']
    .transform('mean')
)

filtered_df_mean['escape_mean'] = (
    filtered_df_sum.groupby(['site', 'age_group'])['escape']
    .transform('mean')
)

# full_df_filtered['summed_escape_mean'] = (
#     full_df_filtered.groupby(['site', 'age_group'])['summed_escape']
#     .transform('mean')
# )

# full_df_filtered['mean_escape_mean'] = (
#     full_df_filtered.groupby(['site', 'age_group'])['mean_escape']
#     .transform('mean')
# )

filtered_df_mean

Unnamed: 0,site,wildtype,escape,serum,age_group,escape_mean
487,050,E,0.030716,2367,0-5,1.644210
645,082,K,-0.003329,2367,0-5,0.108860
768,103,P,-0.032294,2367,0-5,-0.376590
859,121,K,-0.034284,2367,0-5,-0.643280
879,122,N,0.000029,2367,0-5,-0.560240
...,...,...,...,...,...,...
1316,193,S,-0.038428,ferret_3,ferrets,-0.542600
1571,220,R,-0.333139,ferret_3,ferrets,-5.531733
1614,224,R,-0.357231,ferret_3,ferrets,-4.203567
1693,244,L,-0.235982,ferret_3,ferrets,-2.630167


In [84]:
# filtered sites, scatterplot
escape_scatterplot = (
    alt.Chart()
    .encode(
        x=alt.X("site", 
                title="site",
               ),
        y=alt.Y(
            "escape",
            title="escape score",
        ),
        color=alt.Color('age_group:N', 
                        legend=None
                       ).scale(scheme='set2'),
        detail='serum',
        tooltip=['serum', 'site', 'escape_mean']
    )
    .mark_circle(size=30, opacity=0.7)
    .properties(width=500, height=200)
)



x_axis = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(opacity=0.5).encode(y='y')

faceted_scatter_sum = alt.layer(
    escape_scatterplot, x_axis, data=filtered_df_sum
).facet(
    facet=alt.Facet(
        'age_group:N',
        title='summed escape',
        header=alt.Header(
            titleFontSize=21,
            titleFontWeight='normal',
            titlePadding=20,
            labelFontSize=0
            # labelOrient='right',
            # labelAngle=0,
            # labelFontStyle='italic'
        )
    ),
    columns=1
)

faceted_scatter_mean = alt.layer(
    escape_scatterplot, x_axis, data=filtered_df_mean
).facet(
    facet=alt.Facet(
        'age_group:N',
        title='mean escape',
        header=alt.Header(
            titleFontSize=21,
            titleFontWeight='normal',
            titlePadding=20,
            labelFontSize=17,
            labelOrient='right',
            labelAngle=0,
            labelFontStyle='italic'
        )
    ),
    columns=1
)

# faceted_scatter_filtered.save('scratch_notebooks/figure_drafts/sitewise_escape_plots/230714_summed_escape_scatterplot.html')

(faceted_scatter_sum | faceted_scatter_mean).configure_axis(
        grid=False,
        labelFontSize=13,
        titleFontSize=15
    )

In [89]:
escape_lineplot = (
    alt.Chart()
    .encode(
        x=alt.X("site", 
                title="site",
               ),
        y=alt.Y(
            "escape",
            title="escape score",
        ),
        color=alt.Color('age_group:N', 
                        legend=None
                       ).scale(scheme='set2'),
        detail='serum',
        tooltip=['serum', 'site', 'escape']
    )
    .mark_line(size=1.5, opacity=0.4)
    .properties(width=500, height=150)
)

mean_points = (
    alt.Chart()
    .encode(
        x=alt.X("site", title="site"),
        y=alt.Y("mean", title="escape score"),
        color=alt.Color('age_group:N', legend=None).scale(scheme='set2'),
        tooltip=['site', 'escape']
    )
    .mark_circle(size=45, opacity=0.75)
    .properties(width=500, height=150)
)

x_axis = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(opacity=0.5).encode(y='y')

line_scatter_sum = alt.layer(
    escape_lineplot, mean_points, x_axis, data=filtered_df_sum
).facet(
    facet=alt.Facet(
        'age_group:N',
        title='summed escape at significant sites',
        header=alt.Header(
            titleFontSize=21,
            titleFontWeight='normal',
            titlePadding=20,
            labelFontSize=17,
            labelOrient='right',
            # labelAngle=0,
            labelFontStyle='italic'
        )
    ),
    columns=1
)

line_scatter_mean = alt.layer(
    escape_lineplot, mean_points, x_axis, data=filtered_df_mean
).facet(
    facet=alt.Facet(
        'age_group:N',
        title='mean escape',
        header=alt.Header(
            titleFontSize=21,
            titleFontWeight='normal',
            titlePadding=20,
            labelFontSize=17,
            labelOrient='right',
            labelAngle=0,
            labelFontStyle='italic'
        )
    ),
    columns=1
)

# faceted_scatter_filtered.save('scratch_notebooks/figure_drafts/sitewise_escape_plots/230714_summed_escape_scatterplot.html')

(line_scatter_sum | line_scatter_mean).configure_axis(
        grid=False,
        labelFontSize=13,
        titleFontSize=15
    )

ValueError: Unable to determine data type for the field "mean"; verify that the field name is not misspelled. If you are referencing a field from a transform, also confirm that the data type is specified correctly.

alt.HConcatChart(...)

In [8]:
# filtered sites, scatterplot
summed_escape_scatterplot = (
    alt.Chart()
    .encode(
        x=alt.X("site", 
                title="site",
               ),
        y=alt.Y(
            "escape_mean",
            title="summed_escape",
        ),
        color=alt.Color('age_group:N', 
                        legend=None
                       ).scale(scheme='set2'),
        detail='serum',
        tooltip=['serum', 'site', 'escape_mean']
    )
    .mark_circle(size=30, opacity=0.7)
    .properties(width=500, height=200)
)

x_axis = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(opacity=0.5).encode(y='y')

faceted_scatter_filtered = alt.layer(
    summed_escape_scatterplot, x_axis, data=escape_df_filtered
).facet(
    facet=alt.Facet(
        'age_group:N',
        title='summed escape at significant sites',
        header=alt.Header(
            titleFontSize=21,
            titleFontWeight='normal',
            titlePadding=20,
            labelFontSize=17,
            labelOrient='right',
            # labelAngle=0,
            labelFontStyle='italic'
        )
    ),
    columns=1
).configure_axis(
    grid=False,
    labelFontSize=13,
    titleFontSize=15
)

# faceted_scatter_filtered.save('scratch_notebooks/figure_drafts/sitewise_escape_plots/230714_summed_escape_scatterplot.html')

faceted_scatter_filtered

In [142]:
# filtered sites, lineplot
summed_escape_lineplot = (
    alt.Chart()
    .encode(
        x=alt.X("site", 
                title="site",
               ),
        y=alt.Y(
            "escape_mean",
            title="summed escape",
        ),
        color=alt.Color('age_group:N', 
                        legend=None
                       ).scale(scheme='set2'),
        detail='serum',
        tooltip=['serum', 'site', 'escape_mean']
    )
    .mark_line(size=0.6, opacity=.6)
    .properties(width=550, height=175)
)

mean_line = (
    alt.Chart()
    .encode(
        x=alt.X("site", title="site"),
        y=alt.Y("mean_escape_mean", title="summed escape"),
        color=alt.Color('age_group:N', legend=None).scale(scheme='set2'),
        tooltip=['site', 'escape_mean']
    )
    .mark_line(size=3, opacity=0.75)
    .properties(width=550, height=175)
)

x_axis = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(opacity=0.5).encode(y='y')

faceted_lineplot_filtered = alt.layer(
    summed_escape_lineplot, mean_line, x_axis, data=escape_df_filtered
).facet(
    facet=alt.Facet(
        'age_group:N',
        title='summed escape at significant sites',
        header=alt.Header(
            titleFontSize=21,
            titleFontWeight='normal',
            titlePadding=20,
            labelFontSize=17,
            labelOrient='right',
            # labelAngle=0,
            labelFontStyle='italic'
        )
    ),
    columns=1
).configure_axis(
    grid=False,
    labelFontSize=13,
    titleFontSize=15
)

# faceted_lineplot_filtered.save('scratch_notebooks/figure_drafts/sitewise_escape_plots/230715_summed_escape_lineplot.html')

faceted_lineplot_filtered

In [8]:
# initialize list of key sites, plus samples in each age cohort
site_list = [48, 50, 81, 82, 121, 122, 124, 131, 135, 137, 145, 156, 157, 
              159, 160, 186, 188, 189, 192, 193, 275, 276]


peds = [3944, 2389, 2323, 2388, 3973, 4299, 4584, 2367]
teens = [2350, 2365, 2382, 3866, 2380, 3856, 3857, 3862]
adults = ['33C', '34C', '197C', '199C', '215C', 
          '210C', '74C', '68C', '150C', '18C',]
elderly = ['AUSAB-13']
infant = [2462]

sample_lists = [peds, teens, adults, elderly, infant]
cohorts = ['02-05 years', '15-20 years', '40-45 years',  '68 years', 'infant']
summed_escapes_filtered = []

i=0 # for looping through age cohort definitions

# start by getting full escape df filtered to key sites
for entry in sample_lists:
    summed_escape_selected_sites = get_summed_escapes(entry, cohorts[i], site_list)
    summed_escapes_filtered.append(summed_escape_selected_sites)

    i+=1

escape_df_filtered = pd.concat(summed_escapes_filtered)

escape_df_filtered['serum'] = escape_df_filtered['serum'].astype(str)

# add 'mean_escape_mean' column of mean escape values per site within an age group
escape_df_filtered['mean_escape_mean'] = (
    escape_df_filtered.groupby(['site', 'age_group'])['escape_mean']
    .transform('mean')
)

In [11]:
# also generate escape df with all sites included
summed_escapes = []
i=0 # for looping through age cohort definitions

for entry in sample_lists:
    summed_escape = get_summed_escapes(entry, cohorts[i])
    summed_escapes.append(summed_escape)

    i+=1

escape_df_full = pd.concat(summed_escapes)

escape_df_full['serum'] = escape_df_full['serum'].astype(str)

escape_df_full['antigenic_region'] = escape_df_full['site'].map(map_site_to_antigenic_region)