In [1]:
import altair as alt

import pandas as pd

import itertools

import numpy

# import sklearn.manifold
# from sklearn.preprocessing import MinMaxScaler

import warnings
# NOTE: with no category/module arguments this filter silences *every* warning
# (including pandas ones such as SettingWithCopyWarning) until the filters are
# changed again, so real problems in later cells will not be reported.
warnings.filterwarnings('ignore')

from IPython.utils import io

import glob

In [2]:
import os

# Work from two directories above the directory the kernel started in, so the
# relative 'results/...' paths used by later cells resolve.
# The starting directory is remembered on first execution: a bare
# os.chdir('../../') is not idempotent and would climb two *more* levels every
# time this cell is re-run, silently breaking every relative path afterwards.
if '_NOTEBOOK_START_DIR' not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_START_DIR, '../../'))

In [3]:
# define samples in each age cohort
ped_sera = [2367, 3944, 2462, 2389, 2323, 2388, 2463, 3973, 4299, 4584]
teen_sera = [2343, 2350, 2365, 2380, 2382, 3866, 3856, 3857, 3862, 3895]
adult_sera = ['33C', '34C', '197C', '199C', '215C', '210C', '74C', '68C', '150C', '18C']

# get list of lists for samples divided by age group
serum_lists = [ped_sera, teen_sera, adult_sera]
age_cohorts = ['0-5', '15-18', '40-45']
assert len(serum_lists) == len(age_cohorts), "each serum list needs exactly one cohort label"

# adjust this if we want more stringent filtering
# NOTE(review): not referenced anywhere in this cell -- confirm a later cell uses it
min_times_seen = 3

# one DataFrame per serum (libA + libB rows stacked), in cohort order
df_list = []

# Pair each cohort label with its sera directly instead of keeping a manual
# index counter. The loop variable is deliberately NOT called `list`: doing so
# shadows the builtin for the rest of the kernel session.
for age_cohort, cohort_sera in zip(age_cohorts, serum_lists):
    for serum in cohort_sera:
        # reading in values from both libA and libB models
        lib_dfs = []
        for lib in ['libA', 'libB']:
            file_pattern = f'results/prob_escape/{lib}_*_{serum}_*_prob_escape.csv'
            # glob returns files in arbitrary order; sort so row order is
            # reproducible across machines / re-runs
            for file_path in sorted(glob.glob(file_pattern)):
                # keep only the rows flagged `retain` in the CSV
                df = pd.read_csv(file_path).query("`retain` == True")
                lib_dfs.append(df)

        # fail with an actionable message rather than pandas'
        # "No objects to concatenate" when nothing matched this serum
        if not lib_dfs:
            raise FileNotFoundError(
                f"no prob_escape CSVs found for serum {serum!r} "
                f"(cohort {age_cohort}) under {os.path.abspath('results/prob_escape')}"
            )

        full_df = pd.concat(lib_dfs).reset_index(drop=True)

        # ped / teen sera are ints in the lists above; store them as strings
        full_df['serum'] = str(serum)
        full_df['age_cohort'] = age_cohort

        df_list.append(full_df)

# The per-serum frames are intentionally kept as a list; use
# pd.concat(df_list).reset_index(drop=True) if a single frame is needed.

df_list[1].head()

Unnamed: 0,library,antibody_sample,no-antibody_sample,aa_substitutions_sequential,n_aa_substitutions,barcode,prob_escape,prob_escape_uncensored,antibody_count,no-antibody_count,antibody_neut_standard_count,no-antibody_neut_standard_count,no_antibody_count_threshold,antibody_count_threshold,aa_substitutions_reference,retain,antibody,antibody_concentration,serum,age_cohort
0,libA,230221_1_antibody_3944_0.004272_1,230221_1_no-antibody_control_1,K297I,1,ATAACACAAAAAAGTA,0.0025,0.0025,64222,369330,5953821,86955,23,,K278I,True,3944,0.0043,3944,0-5
1,libA,230221_1_antibody_3944_0.004272_1,230221_1_no-antibody_control_1,D123H K208E,2,AAGCCACAAGGTACTA,0.1891,0.1891,62174,4803,5953821,86955,23,,D104H K189E,True,3944,0.0043,3944,0-5
2,libA,230221_1_antibody_3944_0.004272_1,230221_1_no-antibody_control_1,T49Y K140N K154T F156S S163K Y178S T179K K208H...,11,ACAATCACGGTACTCC,0.1503,0.1503,44440,4317,5953821,86955,23,,T30Y K121N K135T F137S S144K Y159S T160K K189H...,True,3944,0.0043,3944,0-5
3,libA,230221_1_antibody_3944_0.004272_1,230221_1_no-antibody_control_1,P122V S212D Q216K G405D,4,AGACCGGGACTCCTCA,0.1968,0.1968,42644,3164,5953821,86955,23,,P103V S193D Q197K G386D,True,3944,0.0043,3944,0-5
4,libA,230221_1_antibody_3944_0.004272_1,230221_1_no-antibody_control_1,S143D K208E I261T R280P,4,CTTTTGTTAATTGATA,0.1436,0.1436,42460,4318,5953821,86955,23,,S124D K189E I242T R261P,True,3944,0.0043,3944,0-5


In [4]:
def _mean_wt_prob_escape(prob_escape_lib, serum, max_aa_subs=4):
    """Mean censored probability of escape of WT variants, per concentration.

    Parameters
    ----------
    prob_escape_lib : pandas.DataFrame
        Non-empty rows for a single library. Must contain the columns
        ``antibody_concentration``, ``n_aa_substitutions``, ``library``,
        ``age_cohort``, ``prob_escape`` and ``prob_escape_uncensored``.
    serum : object
        Label written to the ``serum`` column (stored as a string).
    max_aa_subs : int
        Substitution counts at or above this value are pooled into a single
        ``">{max_aa_subs - 1}"`` bin.

    Returns
    -------
    pandas.DataFrame
        One row per ``antibody_concentration`` for the ``n_subs == "0"`` bin,
        restricted to the "censored to [0, 1]" measurement, with an extra
        ``concentration`` column holding the 1-based rank of the concentration
        (lowest concentration is ``"1"``) as a string.
    """
    group_cols = ["antibody_concentration", "n_subs", "library", "age_cohort"]

    mean_prob_escape = (
        prob_escape_lib.assign(
            # bin the substitution count: "0", "1", ... and a pooled top bin
            n_subs=lambda x: (
                x["n_aa_substitutions"]
                .clip(upper=max_aa_subs)
                .map(lambda n: str(n) if n < max_aa_subs else f">{max_aa_subs - 1}")
            )
        )
        .groupby(group_cols, as_index=False)
        .aggregate({"prob_escape": "mean", "prob_escape_uncensored": "mean"})
        .rename(
            columns={
                "prob_escape": "censored to [0, 1]",
                "prob_escape_uncensored": "not censored",
            }
        )
        .melt(
            id_vars=group_cols,
            var_name="censored",
            value_name="probability escape",
        )
        .assign(serum=serum)
    )

    is_wt = mean_prob_escape["n_subs"] == "0"
    is_censored = mean_prob_escape["censored"] != "not censored"

    # .assign on the filtered frame returns a new object, so no columns are
    # written onto a slice of `mean_prob_escape` (chained-assignment hazard).
    return mean_prob_escape.loc[is_wt & is_censored].assign(
        # explicit rank of the concentration instead of relying on row position
        concentration=lambda d: (
            d["antibody_concentration"].rank(method="dense").astype(int).astype(str)
        ),
        serum=lambda d: d["serum"].astype(str),
    )


def plot_wt_prob_escape(serum_list):
    """Plot the average probability of escape of WT variants for each serum.

    Parameters
    ----------
    serum_list : list of pandas.DataFrame
        One prob-escape frame per serum. Each frame needs the columns
        ``antibody``, ``library``, ``antibody_concentration``,
        ``n_aa_substitutions``, ``prob_escape``, ``prob_escape_uncensored``
        and ``age_cohort``. The serum label is taken from the first row of
        ``antibody``.

    Returns
    -------
    altair chart
        Line plots of mean WT probability of escape against concentration
        rank, coloured by serum and faceted by ``age_cohort`` (rows) and
        ``library`` (columns).

    Raises
    ------
    ValueError
        If ``serum_list`` yields no data to plot.
    """
    max_aa_subs = 4
    libraries = ['libA', 'libB']

    mean_prob_escape_list = []

    for prob_escape in serum_list:
        if prob_escape.empty:
            continue
        # positional lookup: a label lookup (`[0]`) raises KeyError whenever
        # the frame's index does not happen to contain the label 0
        serum = prob_escape['antibody'].iloc[0]

        for lib in libraries:
            prob_escape_lib = prob_escape.loc[prob_escape['library'] == lib]
            if prob_escape_lib.empty:
                # this serum has no rows for this library; nothing to plot
                continue
            mean_prob_escape_list.append(
                _mean_wt_prob_escape(prob_escape_lib, serum, max_aa_subs)
            )

    if not mean_prob_escape_list:
        raise ValueError("serum_list contains no prob_escape data to plot")

    mean_prob_escape_full = pd.concat(mean_prob_escape_list, axis=0, ignore_index=True)

    mean_prob_escape_lineplot = (
        alt.Chart()
        .encode(
            x=alt.X("concentration"),
            y=alt.Y(
                "probability escape",
                scale=alt.Scale(type="symlog", constant=0.05),
            ),
            color=alt.Color("serum", title="serum"),
            # built from the frame that is actually plotted (not a variable
            # left over from the last loop iteration), so every plotted
            # column, including `concentration`, is available on hover
            tooltip=[
                alt.Tooltip(c, format=".3g") if mean_prob_escape_full[c].dtype == float else c
                for c in mean_prob_escape_full.columns
            ],
        )
        .mark_line(point=True, size=0.5)
        .properties(width=200, height=125, title='avg prob escape of WT variants')
    )

    chart = alt.layer(
        mean_prob_escape_lineplot, data=mean_prob_escape_full
    ).facet(
        row='age_cohort:N',
        column='library:N',
    ).configure_title(
        dy=-5,
        fontWeight=500
    ).resolve_axis(
        x='independent'
    )

    return chart

In [5]:
# Draw the WT prob-escape chart for the per-serum frames in `df_list`; the
# returned chart is the cell's last expression, so the notebook displays it.
plot_wt_prob_escape(df_list)