In [1]:
import altair as alt

import pandas as pd

import warnings
# Suppress every Python warning for the rest of the kernel session. This filter
# is global, so deprecation and pandas warnings raised by the cells below are
# hidden as well.
warnings.filterwarnings('ignore')

from IPython.utils import io

In [2]:
import os

# Work from the directory three levels above the kernel's starting directory so
# the relative 'results/...' and 'scratch_notebooks/...' paths used below resolve.
# The starting directory is recorded the first time this cell runs; re-running
# the cell then returns to the same target instead of climbing three more
# levels on every execution.
if '_notebook_start_dir' not in globals():
    _notebook_start_dir = os.getcwd()
os.chdir(os.path.join(_notebook_start_dir, '../../../'))

### Get averaged escape scores for main HK/19 cohorts

In [3]:
# Serum sample IDs in each age cohort.
# Pediatric and teen IDs are plain integers; adult IDs are strings with a 'C' suffix.
ped_sera = [3944, 2389, 2323, 2388, 3973, 4299, 4584, 2367]
teen_sera = [2350, 2365, 2380, 2382, 3866, 3856, 3857, 3862]
adult_sera = ['33C', '34C', '197C', '199C', '215C', '210C', '74C', '68C', '150C', '18C']

# Cohort sample lists and their display labels, kept in matching order
serum_lists = [ped_sera, teen_sera, adult_sera]
age_cohorts = ['2-5 years', '15-20 years', '40-45 years']
assert len(serum_lists) == len(age_cohorts), "each cohort needs exactly one label"

# adjust these if we want more stringent filtering
min_times_seen = 3      # minimum `times_seen` for a mutation to be kept
min_func_score = -1.38  # minimum functional effect, applied after the merge further down

df_list = []

# Pair each label with its sample list directly. The loop variable is not named
# `list`: that would rebind the builtin for every later cell in the kernel.
for cohort, cohort_sera in zip(age_cohorts, serum_lists):
    for serum in cohort_sera:
        # reading in values from averaged libA and libB models
        avg_df = (
            pd.read_csv(f'results/antibody_escape/{serum}_avg.csv')
            .query(f"`times_seen` >= {min_times_seen}")
            .loc[:, ['site', 'wildtype', 'mutant', 'escape_mean']]
            .rename(columns={'escape_mean': 'escape'})
            # ped / teen IDs are ints, so store every serum label as a string
            .assign(serum=str(serum), cohort=cohort)
        )

        # also get summed and mean site scores to check AA-level vs site-level metrics
        # NOTE(review): these aggregates use every mutation passing the times_seen
        # filter; they are not recomputed after the functional-effect filter that
        # is applied to escape_df later -- confirm this is intended.
        escape_by_site = avg_df.groupby('site')['escape']
        avg_df['site_escape_sum'] = escape_by_site.transform('sum')
        avg_df['site_escape_mean'] = escape_by_site.transform('mean')

        df_list.append(avg_df)

# concat to final df
escape_df = pd.concat(df_list).reset_index(drop=True)

# Attach each mutation's functional effect, matched on (site, mutant)
muteffects_csv = "results/muteffects_functional/muteffects_observed.csv"

muteffects = (
    pd.read_csv(muteffects_csv)
    .rename(columns={"reference_site": "site", "effect": "functional effect"})
    .loc[:, ["site", "mutant", "functional effect"]]
)

# Left join: every escape row is kept, rows with no measured effect get NaN
escape_df = pd.merge(escape_df, muteffects, how='left', on=['site', 'mutant'])

# Keep mutations whose functional effect is above min_func_score.
# NaN compares False, so rows without a functional effect are dropped here too.
escape_df = escape_df.loc[escape_df['functional effect'].gt(min_func_score)]

# Persist the filtered table (the row index is written as the first column)
escape_df.to_csv('scratch_notebooks/figure_drafts/fst_analysis/hk19_escape_df_full.csv')

escape_df.head()

Unnamed: 0,site,wildtype,mutant,escape,serum,cohort,site_escape_sum,site_escape_mean,functional effect
0,-2,D,G,0.1278,3944,2-5 years,0.1616,0.0808,-0.6583
1,-2,D,Y,0.0338,3944,2-5 years,0.1616,0.0808,-0.644
2,1,Q,H,0.0069,3944,2-5 years,-0.0166,-0.0083,-0.1601
3,1,Q,R,-0.0235,3944,2-5 years,-0.0166,-0.0083,-0.6362
4,2,K,N,-0.0178,3944,2-5 years,-0.0178,-0.0178,-0.1545


### Generate an escape vector for each serum, normalized so that each vector sums to one

In [4]:
def generate_vectors(df, site_or_aa, site_metric='sum'):
    """Build one escape vector per serum, scaled so each vector sums to one.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format table with a 'serum' column, a 'site' column and the
        metric column selected below. Amino-acid-level vectors also need a
        'mutant' column.
    site_or_aa : {'aa', 'site'}
        'aa' pivots the 'escape' column on (site, mutant); 'site' pivots the
        'site_escape_<site_metric>' column on site.
    site_metric : str, default 'sum'
        Suffix of the site-level metric column. Only used when
        ``site_or_aa == 'site'``.

    Returns
    -------
    pandas.DataFrame
        The index column(s) ('site', plus 'mutant' for 'aa') followed by one
        column per serum. Site/mutation entries missing for a serum are
        filled with 0 before normalization.

    Raises
    ------
    ValueError
        If ``site_or_aa`` is not 'aa' or 'site', or if the selected metric
        column is not present in ``df``.

    Notes
    -----
    ``pivot_table`` averages duplicate index/serum entries (its default
    aggregation). Each column is divided by its plain sum, so negative values
    stay negative, and a serum whose values sum to zero yields non-finite
    entries rather than an error.
    """
    # Get column names based on site_or_aa
    if site_or_aa == 'aa':
        metric_column = 'escape'
        pivot_column = ['site', 'mutant']
    elif site_or_aa == 'site':
        metric_column = 'site_escape_' + site_metric
        pivot_column = 'site'
    else:
        raise ValueError("Invalid value for site_or_aa. Use 'aa' or 'site'.")

    # Fail with a clear message instead of a KeyError from inside pivot_table
    if metric_column not in df.columns:
        raise ValueError(
            f"Column '{metric_column}' not found in df; "
            f"check site_metric={site_metric!r}."
        )

    # Pivot so that rows are sites / mutations and columns are sera
    vector_df = df.pivot_table(index=pivot_column,
                               columns='serum',
                               values=metric_column,
                               fill_value=0)

    # Normalize every serum column to sum to one in a single vectorized step.
    # Dividing directly avoids Series.transform, which first attempts the
    # function element-wise and only works via its exception fallback.
    vector_df = vector_df.div(vector_df.sum(axis=0), axis='columns')

    # Move the index back into ordinary columns and drop the 'serum' axis name
    return vector_df.reset_index().rename_axis(None, axis=1)

In [5]:
# Amino-acid-level escape vectors: one row per (site, mutant), one column per serum
aa_vectors = generate_vectors(df=escape_df, site_or_aa='aa')

# Save for downstream analysis, then preview the first rows
aa_vectors.to_csv(path_or_buf='scratch_notebooks/figure_drafts/fst_analysis/aa_escape_vectors.csv')

aa_vectors.head(5)

Unnamed: 0,site,mutant,150C,18C,197C,199C,210C,215C,2323,2350,...,3856,3857,3862,3866,3944,3973,4299,4584,68C,74C
0,-2,G,0.000885,-0.000318,-0.002787,0.000245,-2e-06,-0.000829,-0.00056,0.000233,...,0.000309,-0.00029,0.00035,0.002887,-0.002642,0.000301,0.000111,0.001102,-0.000286,-5.2e-05
1,-2,Y,0.000648,-0.000932,-0.000113,-0.003608,0.000147,-0.000997,-0.004531,-4.5e-05,...,-0.000613,-6.4e-05,-1.5e-05,0.002124,-0.000699,0.000367,0.000597,0.000706,0.000621,0.000279
2,1,H,-0.000264,-0.0,-0.002979,2.3e-05,-0.000851,-0.001712,-0.00165,0.000471,...,0.00019,0.000148,-0.001666,-0.001011,-0.000143,-0.000891,-0.000587,-0.001276,-0.002525,0.000185
3,1,R,-5.2e-05,-0.00094,0.000481,6.2e-05,0.00023,0.000235,-0.001364,0.000424,...,0.000708,0.000734,0.001207,0.001622,0.000486,-0.001319,-0.002206,-0.001041,0.00129,0.00118
4,2,N,0.000248,3.9e-05,0.000169,0.000395,4.3e-05,0.000546,-0.000542,0.000606,...,-0.000248,-0.001776,-0.00073,-5e-06,0.000368,0.000114,-0.000234,0.000255,-0.001378,8.2e-05


In [6]:
# Site-level escape vectors built from the per-site summed escape: one row per site
site_sum_vectors = generate_vectors(df=escape_df, site_or_aa='site', site_metric='sum')

# Save for downstream analysis, then preview the first rows
site_sum_vectors.to_csv(path_or_buf='scratch_notebooks/figure_drafts/fst_analysis/site_sum_vectors.csv')

site_sum_vectors.head(5)

Unnamed: 0,site,150C,18C,197C,199C,210C,215C,2323,2350,2365,...,3856,3857,3862,3866,3944,3973,4299,4584,68C,74C
0,-2,0.001206,-0.00097,-0.002321,-0.002833,0.000114,-0.001336,-0.003409,0.000148,-0.002193,...,-0.000227,-0.000265,0.000255,0.003781,-0.002354,0.000509,0.000526,0.001166,0.000252,0.000185
1,1,-0.000249,-0.00073,-0.002,7.2e-05,-0.000489,-0.00108,-0.002018,0.000704,0.001131,...,0.000671,0.000659,-0.00035,0.000461,0.000242,-0.001686,-0.002073,-0.001495,-0.000931,0.001113
2,2,0.000195,3e-05,0.000136,0.000333,3.4e-05,0.000399,-0.000363,0.000476,-0.00036,...,-0.000185,-0.001328,-0.000556,-4e-06,0.000259,8.7e-05,-0.000173,0.000164,-0.00104,6.7e-05
3,3,0.006979,-0.003193,0.002574,-1.5e-05,0.009416,0.005233,0.008152,0.002681,-0.004403,...,0.005321,-0.002766,0.001064,0.003783,0.008479,0.003843,0.007689,0.002981,0.000221,-0.000548
4,4,0.003432,-0.005548,0.002594,0.002815,0.004248,0.005427,0.002937,0.000952,0.004182,...,-7.1e-05,0.004713,0.000685,0.003159,0.012934,0.002218,0.006326,0.007063,0.009117,0.001367


In [7]:
# Sanity check on a single serum: its normalized vector should total 1
site_sum_vectors.loc[:, '199C'].sum()

1.0