# Reassign B factors for visualizing escape on the H3 HA structure
To visualize escape, we reassign the B-factors to escape scores for each site (i.e. residue), then color by B-factor in PyMOL. This notebook generates PDB files in which the B-factor of each site is set to the maximum normalized escape score for each cohort analyzed in this paper.

We chose to use the normalized escape scores because these values are more directly comparable between the Perth/2009 and HongKong/2019 data. We chose the maximum normalized escape scores because our goal was to visualize the differences in regions targeted by **any** serum in a cohort, rather than the average escape. Using the maximum accounts for heterogeneity in targeting within a cohort (e.g. cases where the average escape at a site is relatively low because only one or two individuals have appreciable neutralization escape at that site). These structures are shown in Figure 5.

These PDB files are written to the folder `escape_score_pdbs/` and are used as inputs for the script `cohort_escape_viz.py`. Final recolored structures are saved in `figure_5/`.

In [1]:
import pickle
import warnings

import Bio
import Bio.PDB
import pandas as pd
import polyclonal
from IPython.utils import io

warnings.filterwarnings('ignore')

In [2]:
import os
# Move up two directories so that the relative paths used in later cells
# ('results/...', 'data/PDBs/...', 'figures/...') resolve.
# NOTE(review): this is not idempotent -- re-running the cell moves up another
# two levels. Presumably the target is the repository root; confirm.
os.chdir('../../')

### Get HongKong/19 escape data

In [3]:
def get_summed_escapes(sera_list, age_group, site_list=None):
    """Build a per-site summed-escape table for every serum in one cohort.

    For each serum, reads ``results/antibody_escape/{serum}_avg.csv``
    (relative to the current working directory), keeps rows with
    ``times_seen >= 5``, and sums ``escape_mean`` over all rows that share
    a ``(site, wildtype)`` pair.

    Parameters
    ----------
    sera_list : iterable
        Serum identifiers; each one is interpolated into the CSV file name.
    age_group : str
        Label written to the ``cohort`` column of every returned row.
    site_list : list, optional
        If given and non-empty, only sites in this list are kept and the
        ``site`` column is converted to an ordered categorical. If omitted
        (or empty), all sites are kept.

    Returns
    -------
    pandas.DataFrame
        Row-wise concatenation of the per-serum tables, with columns
        ``site``, ``wildtype``, ``escape``, ``serum`` and ``cohort``.
        The per-serum indices are kept as-is (not reset).
    """
    per_serum_tables = []

    for serum in sera_list:
        prob_escape = pd.read_csv(
            f'results/antibody_escape/{serum}_avg.csv'
        ).query(
            "`times_seen` >= 5"
        )

        # Collapse all mutations at a site into one summed escape value
        prob_escape_sum = (
            prob_escape.groupby(['site', 'wildtype'], as_index=False)
            .aggregate({'escape_mean': 'sum'})
            .rename(columns={'escape_mean': 'escape'})
        )

        if site_list:
            # Take an explicit copy of the filtered rows: the columns
            # assigned below must land on an independent frame rather than
            # on a slice of `prob_escape_sum` (SettingWithCopy hazard).
            prob_escape_final = prob_escape_sum[
                prob_escape_sum['site'].isin(site_list)
            ].copy()
            prob_escape_final['site'] = pd.Categorical(
                prob_escape_final['site'], ordered=True
            )
        else:
            prob_escape_final = prob_escape_sum.copy()

        prob_escape_final['serum'] = serum
        prob_escape_final['cohort'] = age_group

        per_serum_tables.append(prob_escape_final)

    return pd.concat(per_serum_tables)

In [4]:
# Serum identifiers for each HongKong/19 age cohort
peds = [3944, 2389, 2323, 2388, 3973, 4299, 4584, 2367]
teens = [2350, 2365, 2382, 3866, 2380, 3856, 3857, 3862]
adults = ['33C', '34C', '197C', '199C', '215C',
          '210C', '74C', '68C', '150C', '18C',]

# `sample_lists` and `cohorts` are parallel: entry k of one belongs with
# entry k of the other
sample_lists = [peds, teens, adults]
cohorts = ['2-5_years', '15-20_years', '40-45_years']

# Pair each serum list with its cohort label directly instead of keeping a
# hand-maintained index counter in step with the loop
summed_escapes = [
    get_summed_escapes(sera, cohort_label)
    for sera, cohort_label in zip(sample_lists, cohorts)
]

escape_df_hk19 = pd.concat(summed_escapes)

# The pediatric/teen identifiers above are ints while the adult ones are
# strings, so cast to give the column a single type
escape_df_hk19['serum'] = escape_df_hk19['serum'].astype(str)
escape_df_hk19['ha_strain'] = 'hk19'

### Get Perth09 escape data

In [5]:
# Cohort label -> serum names belonging to that cohort
sample_dict = {
    "2-4_years": [
        "age 2.1 (Vietnam)",
        "age 2.2 (Vietnam)",
        "age 2.4 (Vietnam)",
        "age 2.5 (Vietnam)",
        "age 2.5b (Vietnam)",
        "age 3.3 (Vietnam)",
        "age 3.3b (Vietnam)",
        "age 3.4 (Vietnam)",
        "age 3.5 (Vietnam)",
    ],
    "30-34_years": [
        "age 30.5 (Vietnam)",
        "age 31.5 (Vietnam)",
        "age 33.5 (Vietnam)",
    ],
    "misc_adult": [
        "age 21 (Seattle)",
        "age 53 (Seattle)",
        "age 64 (Seattle)",
        "age 65 (Seattle)",
    ],
    "ferret": [
        "ferret 1 (Pitt)",
        "ferret 2 (Pitt)",
        "ferret 3 (Pitt)",
        "ferret (WHO)",
    ],
}

# Load the full Perth/2009 dataset, keep only the columns used below, and
# call the serum identifier column 'serum'
perth_columns = ['name', 'site', 'wildtype', 'mutant', 'escape']
escape_df = (
    pd.read_csv('results/perth2009/merged_escape.csv')[perth_columns]
    .rename(columns={'name': 'serum'})
)

# Function to convert '(HA2)X' site labels to integers
def convert_site_to_numeric(site, offset=329):
    """Map an ``'(HA2)X'`` site label to the integer ``X + offset``.

    Parameters
    ----------
    site : object
        Site label. Only strings containing ``'(HA2)'`` are converted.
    offset : int, optional
        Value added to the number that follows ``'(HA2)'``. Defaults to 329
        (presumably the HA1 length in the numbering scheme used here --
        confirm).

    Returns
    -------
    object
        ``X + offset`` as an ``int`` for a well-formed ``'(HA2)X'`` label.
        Any other input is returned unchanged: strings without the prefix,
        malformed HA2 labels, and non-string values.
    """
    # Non-string labels (e.g. sites already parsed as integers) cannot carry
    # the '(HA2)' prefix; pass them through instead of raising TypeError on
    # the `in` test.
    if not isinstance(site, str) or '(HA2)' not in site:
        return site
    try:
        return int(site.replace('(HA2)', '').strip()) + offset
    except ValueError:
        # Not a clean '(HA2)<integer>' label; return the original value
        return site

# Translate '(HA2)X' labels in the 'site' column
escape_df = escape_df.assign(
    site=escape_df['site'].apply(convert_site_to_numeric)
)

# One row per (serum, site, wildtype), holding the escape summed over all
# rows that share that combination
site_keys = ['serum', 'site', 'wildtype']
escape_df = escape_df.groupby(site_keys, as_index=False).agg({'escape': 'sum'})

# Negative summed escape is floored at 0
escape_df['escape'] = escape_df['escape'].clip(lower=0)

# helper used to add the cohort label
def find_sample_type(sample_name):
    """Return the key of ``sample_dict`` whose list contains ``sample_name``.

    Keys are scanned in dictionary order and the first match is returned;
    the result is ``None`` when no list contains the name.
    """
    matching_types = (
        sample_type
        for sample_type, members in sample_dict.items()
        if sample_name in members
    )
    return next(matching_types, None)

escape_df['cohort'] = escape_df['serum'].apply(find_sample_type)

# Drop the cohorts that are not analysed here. `.copy()` makes the filtered
# frame independent, so the column assignments below are not writes into a
# slice of the unfiltered frame.
# NOTE: sera that `find_sample_type` could not place get cohort None and are
# kept by this filter, exactly as with the original `!=` comparisons.
escape_df = escape_df.loc[
    ~escape_df['cohort'].isin(['misc_adult', 'ferret'])
].copy()
escape_df['site'] = escape_df['site'].astype(int)

escape_df['ha_strain'] = 'perth09'

### Merge datasets and calculate max normalized escape value for each cohort

In [6]:
escape_df_full = pd.concat([escape_df, escape_df_hk19])

# Group the DataFrame by the 'serum' column
grouped = escape_df_full.groupby('serum')


# Define a function to normalize the 'escape' column within each group
def normalize(group):
    """Return a copy of ``group`` with ``escape`` divided by its maximum.

    The input frame is left untouched.
    """
    return group.assign(escape=group['escape'] / group['escape'].max())


# Normalize each serum separately and stack the results with a fresh index.
# The groups are iterated explicitly rather than passed through
# `grouped.apply(...)`: pandas does not support functions that mutate the
# group handed to `apply`, and newer pandas releases deprecate passing the
# grouping column to the applied function. Iterating always yields the full
# per-serum frame, 'serum' column included.
normalized_df = pd.concat(
    [normalize(serum_rows) for _, serum_rows in grouped],
    ignore_index=True,
)

# Clip to just the positive values
normalized_df['escape'] = normalized_df['escape'].clip(lower=0)

# Per-(site, cohort) summaries across sera, broadcast back onto every row
escape_by_site_cohort = normalized_df.groupby(['site', 'cohort'])['escape']
mean_cohort_escape = escape_by_site_cohort.transform('mean')
max_cohort_escape = escape_by_site_cohort.transform('max')

# add mean cohort escape column
normalized_df['mean_cohort_escape'] = mean_cohort_escape

# add max cohort escape column
normalized_df['max_cohort_escape'] = max_cohort_escape

In [7]:
# Build one table per cohort: the unique per-site cohort summaries, repeated
# once for every chain ID in `chains`
chains = ['A', 'C', 'E']
summary_columns = [
    'site', 'mean_cohort_escape', 'max_cohort_escape', 'ha_strain', 'cohort'
]

grouped = normalized_df.groupby('cohort')

cohort_dfs = []
for cohort, cohort_data in grouped:
    # Rows are per serum, so the cohort-level columns repeat; keep one copy
    per_site = cohort_data[summary_columns].drop_duplicates()

    # Tag a copy of the per-site table with each chain ID, then stack them
    # (this triples the number of rows)
    cohort_escape = pd.concat(
        [per_site.assign(chain=chain_id) for chain_id in chains],
        ignore_index=True,
    )

    cohort_dfs.append(cohort_escape)

### Reassign B-factors in PDB file for each cohort

In [8]:
def reassign_b_factor(
    input_pdbfile,
    output_pdbfile,
    df,
    metric_col,
    *,
    site_col="site",
    chain_col="chain",
    missing_metric=0,
    model_index=0,
):
    """Write a copy of a PDB file with B-factors replaced by a per-site metric.

    Every atom of a residue (including all alternatives of disordered
    residues and atoms) gets the value that `df` holds for that residue's
    chain and residue number. Only the residue number is used for the
    lookup, so residues that differ only by insertion code share a value.

    Requires the `Bio.PDB` subpackage to be imported at module level.

    Parameters
    ----------
    input_pdbfile : str
        Path of the PDB file to read.
    output_pdbfile : str
        Path of the PDB file to write.
    df : pandas.DataFrame
        Table holding the columns `metric_col`, `site_col` and `chain_col`.
    metric_col : str
        Column of `df` with the value to store as the B-factor.
    site_col : str
        Column of `df` with residue numbers. Values must compare equal to
        the integer residue numbers read from the PDB to be matched.
    chain_col : str
        Column of `df` with chain identifiers.
    missing_metric : number or dict
        B-factor for residues with no entry in `df`. Either a single value
        used for all chains, or a dict keyed by chain ID.
    model_index : int
        Index of the model within the PDB file to recolor.

    Raises
    ------
    ValueError
        If `df` lacks a required column, holds more than one metric value
        for the same site and chain, or names a chain absent from the PDB.
    """
    # subset `df` to needed columns and error check it
    cols = [metric_col, site_col, chain_col]
    for col in cols:
        if col not in df.columns:
            raise ValueError(f"`df` lacks column {col}")
    df = df[cols].drop_duplicates()
    if len(df) != len(df.groupby([site_col, chain_col])):
        raise ValueError("non-unique metric for a site in a chain")

    # read PDB, catch warnings about discontinuous chains
    with warnings.catch_warnings():
        warnings.simplefilter(
            "ignore", category=Bio.PDB.PDBExceptions.PDBConstructionWarning
        )
        pdb = Bio.PDB.PDBParser().get_structure("_", input_pdbfile)

    # get the model out of the PDB
    model = list(pdb.get_models())[model_index]

    # make sure all chains in PDB
    pdb_chain_ids = {chain.id for chain in model.get_chains()}
    missing_chains = set(df[chain_col]) - pdb_chain_ids
    if missing_chains:
        raise ValueError(f"`df` has chains not in PDB: {missing_chains}")

    # make missing_metric a dict if it isn't already
    if not isinstance(missing_metric, dict):
        missing_metric = {chain.id: missing_metric for chain in model.get_chains()}

    # loop over all chains and do coloring
    for chain in model.get_chains():
        chain_id = chain.id
        # Boolean indexing instead of `df.query(f"...")`: interpolating the
        # column name into a query string fails for names that are not
        # valid Python identifiers.
        site_to_val = (
            df.loc[df[chain_col] == chain_id]
            .set_index(site_col)[metric_col]
            .to_dict()
        )
        for residue in chain:
            # second element of the residue id is the residue number
            site = residue.get_id()[1]
            try:
                metric_val = site_to_val[site]
            except KeyError:
                metric_val = missing_metric[chain_id]
            # for disordered residues, get list of them
            try:
                residuelist = residue.disordered_get_list()
            except AttributeError:
                residuelist = [residue]
            for r in residuelist:
                for atom in r:
                    # for disordered atoms, get list of them
                    try:
                        atomlist = atom.disordered_get_list()
                    except AttributeError:
                        atomlist = [atom]
                    for a in atomlist:
                        a.bfactor = metric_val

    # write PDB; the writer is not called `io` so that it does not shadow
    # the `io` module imported at the top of the notebook
    pdb_writer = Bio.PDB.PDBIO()
    pdb_writer.set_structure(pdb)
    pdb_writer.save(output_pdbfile)

In [9]:
# Write one recolored PDB per cohort, with B-factors set to the cohort's
# maximum normalized escape at each site
output_dir = 'figures/structure_plots/escape_score_pdbs'
# Create the output folder up front so saving does not fail on a fresh
# checkout where it does not exist yet
os.makedirs(output_dir, exist_ok=True)

for df in cohort_dfs:
    # Each table holds a single cohort/strain, so the first row's labels
    # identify the whole table
    cohort_name = df['cohort'].iloc[0]
    ha_strain = df['ha_strain'].iloc[0]
    reassign_b_factor(
        input_pdbfile='data/PDBs/4o5n.pdb',
        output_pdbfile=f'{output_dir}/{ha_strain}_{cohort_name}_normalized_max.pdb',
        df=df,
        metric_col='max_cohort_escape',
        site_col="site",
        chain_col="chain",
        missing_metric=0,
        model_index=0,
    )