In [1]:
import itertools

import altair as alt

import pandas as pd

import warnings
warnings.filterwarnings('ignore')

from IPython.utils import io

In [2]:
import os
os.chdir('../../../')

### Get averaged escape scores for main HK/19 cohorts

In [3]:
# define samples in each age cohort
# ped_sera_160_esc = [3944, 2323, 2367]
# ped_sera_160_sens = [2389, 2388, 3973, 4299, 4584]
ped_sera = [3944, 2389, 2323, 2388, 3973, 4299, 4584, 2367]
teen_sera = [2350, 2365, 2380, 2382, 3866, 3856, 3857, 3862]
adult_sera = ['33C', '34C', '197C', '199C', '215C', '210C', '74C', '68C', '150C', '18C']

# get list of lists for samples divided by age group
serum_lists = [ped_sera, teen_sera, adult_sera]
age_cohorts = ['2-5 years', '15-20 years', '40-45 years']
# serum_lists = [ped_sera_160_esc, ped_sera_160_sens, teen_sera, adult_sera]
# age_cohorts = ['2-5 years (160 esc)', '2-5 years (160 sens)', '15-20 years', '40-45 years']

# adjust this if we want more stringent filtering
min_times_seen = 3
min_func_score = -1.38

df_list = []

i = 0 # for looping across age cohort definitions

for list in serum_lists:
    for serum in list:
        # reading in values from averaged libA and libB models
        avg_df = pd.read_csv(f'results/antibody_escape/{serum}_avg.csv'
                            ).query(f"`times_seen` >= {min_times_seen}")
        
        avg_df = avg_df[['site', 'wildtype', 'mutant', 'escape_mean']]

        avg_df = avg_df.rename(columns={'escape_mean': 'escape'})
        
        serum = str(serum) # ped / teen sera automatically read as ints
        avg_df['serum'] = serum
        avg_df['cohort'] = age_cohorts[i]
        
        # filter out stop codons
        avg_df = avg_df.loc[avg_df['mutant'] != '*']

        # also get summed and mean site scores to check AA-level vs site-level metrics
        avg_df['site_escape_sum'] = avg_df['escape'].groupby(avg_df['site']).transform('sum')
        avg_df['site_escape_mean'] = avg_df['escape'].groupby(avg_df['site']).transform('mean')

        df_list.append(avg_df)

    i+=1

# concat to final df
escape_df = pd.concat(df_list).reset_index(drop=True)

# add functional effects
muteffects_csv = "results/muteffects_functional/muteffects_observed.csv"

muteffects = pd.read_csv(muteffects_csv).rename(
    columns={"reference_site": "site", "effect": "functional effect"}
)[["site", "mutant", "functional effect"]]

escape_df = escape_df.merge(
    muteffects[['site', 'mutant', 'functional effect']], 
    on=['site', 'mutant'], 
    how='left'
).assign(mutation=lambda x: x["wildtype"] + x["site"].astype(str) + x["mutant"])

# filter for minimum functional effect of -1.38
escape_df = escape_df.loc[escape_df['functional effect'] > min_func_score]

escape_df.to_csv('scratch_notebooks/figure_drafts/fst_analysis/hk19_escape_df_full.csv')

escape_df.head()

Unnamed: 0,site,wildtype,mutant,escape,serum,cohort,site_escape_sum,site_escape_mean,functional effect,mutation
0,-2,D,G,0.1278,3944,2-5 years,0.1616,0.0808,-0.6583,D-2G
1,-2,D,Y,0.0338,3944,2-5 years,0.1616,0.0808,-0.644,D-2Y
2,1,Q,H,0.0069,3944,2-5 years,-0.0166,-0.0083,-0.1601,Q1H
3,1,Q,R,-0.0235,3944,2-5 years,-0.0166,-0.0083,-0.6362,Q1R
4,2,K,N,-0.0178,3944,2-5 years,-0.0178,-0.0178,-0.1545,K2N


### Filter to key antigenic sites

In [4]:
site_list = [48, 50, 82, 121, 122, 124, 131, 135, 137, 144, 145, 156, 157, 
              159, 160, 189, 193, 275, 276]

# filter dataframe
escape_df_filtered = escape_df[escape_df['site'].isin(site_list)]
escape_df_filtered.head()

Unnamed: 0,site,wildtype,mutant,escape,serum,cohort,site_escape_sum,site_escape_mean,functional effect,mutation
399,48,I,A,0.0463,3944,2-5 years,0.2196,0.0122,-0.0545,I48A
400,48,I,D,0.0179,3944,2-5 years,0.2196,0.0122,-0.0362,I48D
401,48,I,E,0.014,3944,2-5 years,0.2196,0.0122,0.0566,I48E
402,48,I,F,0.0284,3944,2-5 years,0.2196,0.0122,0.0732,I48F
403,48,I,G,0.016,3944,2-5 years,0.2196,0.0122,-0.0311,I48G


### Compute distances in escape for each serum pair
These are the squared Euclidean distances between the escape vectors:

In [8]:
def generate_dist2(df, site_or_aa, site_metric='sum'):
    
    # Get column names based on site_or_aa
    if site_or_aa == 'aa':
        metric_column = 'escape'
        pivot_column = "mutation"
    elif site_or_aa == 'site':
        metric_column = 'site_escape_' + site_metric
        pivot_column = "site"
    else:
        raise ValueError("Invalid value for site_or_aa. Use 'aa' or 'site'.")
        
    # # try clipping to 0
    # df[metric_column] = df[metric_column].clip(lower=0)

    # Pivot the DataFrame to have sera as rows and escape as columns, then L2 normalize escape
    vector_df = (
        df
        [["serum", "cohort", metric_column, pivot_column]]
        .drop_duplicates()
        .pivot_table(
            index=["serum", "cohort"], 
            columns=pivot_column, 
            values=metric_column, 
            fill_value=0,
        )
    )

    # L2-normalize https://stackoverflow.com/a/35679163
    vector_df = vector_df.div((vector_df**2).sum(axis=1), axis=0)
    
#     # L-2 normalize by column
#     def normalize_escape(column):
#         return column / column.sum()
    
#     for column in vector_df:
#         vector_df[column] = vector_df[column].transform(normalize_escape)
    


    # make a single entry that is the vector of values
    vector_df = (
        vector_df
        .apply(lambda r: r.values, axis=1)
        .rename("escape_vector")
        .reset_index()
    )

    # get all pairwise distances squared
    dist2_records = []
    for (serum1, cohort1, escape_vector1), (serum2, cohort2, escape_vector2) in itertools.combinations(
        vector_df.itertuples(index=False), 2,
    ):
        dist2 = ((escape_vector1 - escape_vector2)**2).sum()
        dist2_records.append((serum1, cohort1, serum2, cohort2, dist2))
    dist2_df = (
        pd.DataFrame(
            dist2_records,
            columns=["serum1", "cohort1", "serum2", "cohort2", "distance2"],
        )
        .assign(
            same_cohort=lambda x: x["cohort1"] == x["cohort2"],
            cohort_pair=lambda x: x.apply(lambda r: " vs ".join(sorted([r["cohort1"], r["cohort2"]])), axis=1),
        )
    )

    return dist2_df

i=0
for df_type in [escape_df_filtered, escape_df]:
    if i==0:
        print('\n\nFILTERED SITES')
    else:
        print('\n\nALL SITES')
        
    for site_or_aa in ["aa", "site"]:

        dist2 = generate_dist2(df_type, site_or_aa, site_metric='sum').sort_values("distance2")

        dist2_chart = (
            alt.Chart(dist2)
            .encode(
                y="cohort_pair",
                x="distance2",
                tooltip=dist2.columns.tolist(),
            )
            .mark_circle(opacity=0.5)
        )

        print(f"\n\nMean {site_or_aa} distances betweeen cohorts:")
        mean_dist2 = dist2.groupby(["same_cohort", "cohort_pair"]).aggregate({"distance2": "mean"})
        display(mean_dist2)

        display(dist2_chart)
        
    i+=1



FILTERED SITES


Mean aa distances betweeen cohorts:


Unnamed: 0_level_0,Unnamed: 1_level_0,distance2
same_cohort,cohort_pair,Unnamed: 2_level_1
False,15-20 years vs 2-5 years,0.273744
False,15-20 years vs 40-45 years,0.389969
False,2-5 years vs 40-45 years,0.405698
True,15-20 years vs 15-20 years,0.261785
True,2-5 years vs 2-5 years,0.249281
True,40-45 years vs 40-45 years,0.456086




Mean site distances betweeen cohorts:


Unnamed: 0_level_0,Unnamed: 1_level_0,distance2
same_cohort,cohort_pair,Unnamed: 2_level_1
False,15-20 years vs 2-5 years,0.065916
False,15-20 years vs 40-45 years,0.169485
False,2-5 years vs 40-45 years,0.160716
True,15-20 years vs 15-20 years,0.061729
True,2-5 years vs 2-5 years,0.064499
True,40-45 years vs 40-45 years,0.2576




ALL SITES


Mean aa distances betweeen cohorts:


Unnamed: 0_level_0,Unnamed: 1_level_0,distance2
same_cohort,cohort_pair,Unnamed: 2_level_1
False,15-20 years vs 2-5 years,0.14206
False,15-20 years vs 40-45 years,0.143201
False,2-5 years vs 40-45 years,0.150822
True,15-20 years vs 15-20 years,0.129726
True,2-5 years vs 2-5 years,0.135918
True,40-45 years vs 40-45 years,0.145786




Mean site distances betweeen cohorts:


Unnamed: 0_level_0,Unnamed: 1_level_0,distance2
same_cohort,cohort_pair,Unnamed: 2_level_1
False,15-20 years vs 2-5 years,0.046187
False,15-20 years vs 40-45 years,0.068009
False,2-5 years vs 40-45 years,0.060358
True,15-20 years vs 15-20 years,0.050323
True,2-5 years vs 2-5 years,0.039533
True,40-45 years vs 40-45 years,0.079102
