# Correlations between DMS escape scores and validation neut assays for circulating H3 HA strains
In this notebook, we generate plots showing the correlation between predicted strain escape (the summed DMS escape scores of each strain's mutations) and measured IC50 fold changes from neutralization assays, for five H3 HA strains circulating since 2014. These validations covered 5 strains, each tested against 4 sera. Visualization of the neutralization curves and calculation of the IC50 fold changes can be found in `../../neut_assays/validations_hk19`. 

Final plots include a single correlation plot for all analyzed sera (`figure_6C`) and correlation plots for each individual serum (`figure_S13`). Each plot is annotated with its Pearson R correlation coefficient.

In [1]:
import altair as alt

import pandas as pd

import warnings
warnings.filterwarnings('ignore')

import altair_saver
import scipy as sp

import numpy as np

import os

In [2]:
# Move two directories up, to the directory containing `neut_assays/` and
# `figures/`, so the relative paths used below resolve. The guard keeps the
# cell idempotent: re-running it does not climb further up the directory tree.
if not os.path.isdir('neut_assays'):
    os.chdir('../../')

### Read in IC50-fold-change data for validated strains
See `neut_assays/validations_hk19` for code generating these IC50-fold-change measurements from neutralization curves.

In [3]:
# Measured IC50 fold changes per (variant, serum). Reading `serum` as str at
# parse time keeps IDs such as '2388' and '199C' in one consistent dtype.
ics = pd.read_csv(
    'neut_assays/validations_hk19/ic50_fold_changes_strains.csv',
    dtype={'serum': str},
)

# age cohort of each serum
cohort_dict = {
    '2388': 'children',
    '3862': 'teenagers',
    '2380': 'teenagers',
    '199C': 'adults',
}

ics['cohort'] = ics['serum'].map(cohort_dict)

# `cohort` is used as a merge key with the DMS escape scores below; a serum
# missing from `cohort_dict` would get NaN here, fail to match in that left
# merge, and have its escape score silently filled with 0.
unmapped_sera = sorted(ics.loc[ics['cohort'].isna(), 'serum'].unique())
assert not unmapped_sera, f"sera missing from cohort_dict: {unmapped_sera}"

ics = ics.rename(columns={'variant': 'strain'})

### Merge with predicted DMS escape scores for each strain
`h3_strain_validations.csv` contains predicted DMS escape scores, based on summed effects of each mutation relative to A/Hong Kong/45/2019. This file was generated from the notebook https://github.com/matsengrp/seasonal-flu-dmsa/blob/master/profiles/dmsa-phenotype/notebooks/analysis_code.ipynb in a separate repository.

In [4]:
# Predicted DMS escape scores per (strain, serum, cohort). `serum` is read as
# str so its dtype always matches `ics['serum']` when used as a merge key.
strain_escape = pd.read_csv(
    'figures/validation_correlations/h3_strain_validations.csv',
    dtype={'serum': str},
)
strain_escape.head()

Unnamed: 0,strain,year,all_aa_substitutions,escape_score,serum,cohort,n_aa_substitutions
0,A/Stockholm/18/2014,2014.619178,G62E R92K K121N A128T K135T F137S S138A G142R ...,0.2894,199C,adults,14
1,A/Stockholm/18/2014,2014.619178,G62E R92K K121N A128T K135T F137S S138A G142R ...,-0.2799,74C,adults,14
2,A/Stockholm/18/2014,2014.619178,G62E R92K K121N A128T K135T F137S S138A G142R ...,-0.8299,4299,children,14
3,A/Stockholm/18/2014,2014.619178,G62E R92K K121N A128T K135T F137S S138A G142R ...,-0.5252,2388,children,14
4,A/Stockholm/18/2014,2014.619178,G62E R92K K121N A128T K135T F137S S138A G142R ...,-0.9317,3862,teenagers,14


In [5]:
# Left-join predicted escape scores onto the measured fold changes, requiring
# at most one match per (strain, serum, cohort); any unmatched value becomes 0.
corr_df = pd.merge(
    ics,
    strain_escape,
    on=["strain", "serum", "cohort"],
    how="left",
    validate="one_to_one",
).fillna(0)

corr_df.head()

Unnamed: 0,serum,strain,log2_fold_change_ic50,cohort,year,all_aa_substitutions,escape_score,n_aa_substitutions
0,2388,A/Stockholm/18/2014,-2.080628,children,2014.619178,G62E R92K K121N A128T K135T F137S S138A G142R ...,-0.5252,14
1,2388,A/Singapore/TT0495/2017,-2.016404,children,2017.320548,L59I A128T K135T F137S S138A S193F N312S,-0.5908,7
2,2388,A/Ecuador/4472/2019,-2.32045,children,2019.024658,F137S S138A S193F N312S,-0.4543,4
3,2388,A/SouthAuckland/2/2020,0.189488,children,2020.989041,E50K S91N S144T N312S,0.0561,4
4,2388,A/Zhejiang-yongkang/1312/2022,1.348478,children,2022.2,I48T K83E Y94N A128T T131K K135T F137S S138A K...,0.4945,15


In [6]:
# Rebuild the merged table from its inputs (so this cell is self-contained),
# keep only the columns used downstream, and shorten `escape_score` to `escape`.
corr_df = (
    pd.merge(
        ics,
        strain_escape,
        on=["strain", "serum", "cohort"],
        how="left",
        validate="one_to_one",
    )
    .fillna(0)
    .loc[:, ['serum', 'log2_fold_change_ic50', 'strain', 'n_aa_substitutions', 'escape_score', 'cohort']]
    .rename(columns={'escape_score': 'escape'})
)

Add additional data to `corr_df`:
* A WT entry for each serum with log2 fold change of 0 and escape of 0
* Age cohort for each serum
* Pearson R correlation coefficient for each serum

In [7]:
# measured-neutralization column, named explicitly rather than found by a
# substring search over the column names
ic_col = 'log2_fold_change_ic50'

# Drop WT rows left over from a previous run of this cell, so re-running it
# does not append duplicate WT points.
corr_df = corr_df.loc[corr_df['strain'] != 'WT'].reset_index(drop=True)

# One WT reference row per serum, with fold change and escape both 0. The
# serum's cohort is carried over so WT points do not end up with a missing
# cohort (cohort is used as a shape encoding when plotting per serum).
# Rows are collected as records and turned into a frame once, instead of
# growing a DataFrame with concat inside a loop.
serum_cohorts = corr_df.drop_duplicates('serum').set_index('serum')['cohort']
wt_df = pd.DataFrame([
    {
        'serum': serum_value,
        'strain': 'WT',
        ic_col: 0,
        'n_aa_substitutions': 0,
        'escape': 0,
        'cohort': serum_cohorts[serum_value],
    }
    for serum_value in corr_df['serum'].unique()
])

# Append the WT rows to the original DataFrame
corr_df = pd.concat([corr_df, wt_df], ignore_index=True)

# Pearson R between measured fold change and predicted escape for each serum;
# the WT point at (0, 0) is part of every serum's correlation. sort=False
# visits sera in order of first appearance.
r_dict = {}
for serum, serum_df in corr_df.groupby('serum', sort=False):
    r, p = sp.stats.pearsonr(serum_df[ic_col], serum_df['escape'])
    r_dict[serum] = round(r, 3)

corr_df['r'] = corr_df['serum'].map(r_dict)

corr_df.head()

Unnamed: 0,serum,log2_fold_change_ic50,strain,n_aa_substitutions,escape,cohort,r
0,2388,-2.080628,A/Stockholm/18/2014,14,-0.5252,children,0.98
1,2388,-2.016404,A/Singapore/TT0495/2017,7,-0.5908,children,0.98
2,2388,-2.32045,A/Ecuador/4472/2019,4,-0.4543,children,0.98
3,2388,0.189488,A/SouthAuckland/2/2020,4,0.0561,children,0.98
4,2388,1.348478,A/Zhejiang-yongkang/1312/2022,15,0.4945,children,0.98


### Plot correlation for all sera
All sera are overlaid on a single plot, and a single Pearson R value is calculated across all serum–strain pairs pooled together (including the WT reference points).

In [8]:
# Pooled correlation plot: all sera on one chart, annotated with a single
# Pearson R computed over every row of `corr_df`.

# measured-neutralization column, named explicitly rather than found by a
# substring search over the column names
ic_col = 'log2_fold_change_ic50'

# Calculate Pearson correlation between predicted and measured, pooled across sera
x = corr_df[ic_col]
y = corr_df["escape"]
r, p = sp.stats.pearsonr(x, y)

r = round(r, 3)

# Define the plotting order for 'strain' and 'serum'
strain_order = ['WT', 
                'A/Stockholm/18/2014', 
                'A/Singapore/TT0495/2017', 
                'A/Ecuador/4472/2019', 
                'A/SouthAuckland/2/2020', 
                'A/Zhejiang-yongkang/1312/2022',
               ]

serum_order = ['2388', '3862', '2380', '199C']

# Define cb-friendly color scheme (one color per entry of strain_order)
custom_color_scheme = ['#333333', 
                       '#CC6677', 
                       '#332288',  
                       '#117733', 
                       '#882255', 
                       '#88CCEE', 
                      ]

# Scatter of measured fold change (x) vs predicted escape (y); y values outside
# [-1, 1] are clamped to the plot edge
corrs = (
    alt.Chart(corr_df)
    .mark_point(filled=True, size=200)
    .encode(
        x=alt.X(ic_col,
                title=ic_col,
                scale=alt.Scale(
                    domain=[-4, 4],
                )
               ),
        y=alt.Y('escape',
                title='escape score',
                scale=alt.Scale(
                    domain=[-1, 1],
                    clamp=True
                )
               ),
        color=alt.Color('strain:N', scale=alt.Scale(range=custom_color_scheme, domain=strain_order)), 
        shape=alt.Shape('serum:N', scale=alt.Scale(domain=serum_order)),
        tooltip=['strain', 'serum', ic_col, 'escape']
    )

)

# Get R value to add to plot, at a fixed pixel position in the top-left corner
r_text = alt.Chart().mark_text(
    align='left', 
    baseline='bottom',
    fontSize=14,
    fontWeight=300
).encode(
    x=alt.value(10),
    y=alt.value(20),
    text=alt.value([f"R={r}"])
)

# Add defined x and y axes (gray reference lines through 0)
x_axis = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(
    size=1, 
    opacity=0.5, color='gray').encode(y='y')

y_axis = alt.Chart(pd.DataFrame({'x': [0]})).mark_rule(
    size=1, 
    opacity=0.5, color='gray').encode(x='x')

# plot final layered chart
chart = (
    (corrs + r_text + x_axis + y_axis)
    .configure_axis(
        grid=False,
        labelFontSize=14,
        titleFontSize=15,  
    )
    .configure_title(
        fontSize=21,
        fontWeight='normal',
    )
)

chart.save(
    'figures/validation_correlations/figure_6/hk19_strain-corr_ic50.png',
    scale_factor=2.0
)

chart

In [9]:
# Per-serum correlation plots: one facet per serum, each annotated with that
# serum's Pearson R from the `r` column of corr_df.

# define ordering for plot
strain_order = ['WT', 
                'A/Stockholm/18/2014', 
                'A/Singapore/TT0495/2017', 
                'A/Ecuador/4472/2019', 
                'A/SouthAuckland/2/2020', 
                'A/Zhejiang-yongkang/1312/2022',
               ]

serum_order = ['2388', '3862', '2380', '199C']

cohort_order = ['children', 'teenagers', 'adults']

# re-define `corrs` chart to adjust plot sizing for facet. No per-chart title:
# the facet header already labels each panel with its serum, and a fixed title
# here would repeat one serum's name on every panel.
corrs = (
    alt.Chart()
    .mark_point(filled=True, size=100)
    .encode(
        x=alt.X(ic_col, 
                title=ic_col,
                scale=alt.Scale(
                        domain=[-4, 4],
                ),
               ),
        y=alt.Y('escape', 
                title='escape score',
                scale=alt.Scale(
                    domain=[-1, 1],
                    clamp=True
                )
               ),
        color=alt.Color('strain:N', scale=alt.Scale(range=custom_color_scheme, domain=strain_order)),
        shape=alt.Shape('cohort:N', scale=alt.Scale(domain=cohort_order)),
        tooltip=['strain', 'serum', ic_col, 'escape']
    )
    .properties(
        width=120,
        height=120
    )
)

# re-define `r-text` using the R value defined for each serum in corr_df
# and adjust font size and positioning for faceted chart.
# `r` is constant within a serum, so aggregate to one row per serum; otherwise
# one identical text mark is drawn per data row, stacked on top of each other.
r_text = alt.Chart().transform_aggregate(
    r='max(r)',
    groupby=['serum'],
).mark_text(
    align='right',
    baseline='bottom',
    fontSize=12,
    fontWeight=300
).encode(
    x=alt.value(115),
    y=alt.value(115),
    text='r:N'
)

# plot faceted chart
faceted_corrs = alt.layer(
    corrs, r_text, x_axis, y_axis, data=corr_df, 
).facet(
    facet=alt.Facet(
        'serum:N',
        sort=serum_order,
        header=alt.Header(
            titleFontSize=20,
            labelFontSize=14,
            labelPadding=2
    )
    ), 
    spacing=3,
    columns=4
).configure_axis(
    grid=False,
    labelFontSize=14,
    titleFontSize=10,
)

faceted_corrs.save(
    'figures/validation_correlations/figure_S13/hk19_faceted-corrs_strains.png',
    scale_factor=2.0
)

faceted_corrs