In [1]:
import altair as alt

import pandas as pd

import polyclonal

import warnings
warnings.filterwarnings('ignore')

import altair_saver
import scipy as sp

import numpy as np

In [2]:
import os

# Move three levels up from the starting working directory (the repo root,
# where the relative 'results/...' and 'experiments/...' paths below live).
# The starting directory is remembered on the first run, so re-running this
# cell is idempotent instead of climbing three more levels each time.
NOTEBOOK_DIR = globals().get('NOTEBOOK_DIR', os.getcwd())
os.chdir(os.path.join(NOTEBOOK_DIR, '..', '..', '..'))

In [3]:
# set up dict for adding cohort label: cohort name -> raw serum names
sample_dict = {
    "2-4 years": [
        "age 2.1 (Vietnam)", 
        "age 2.2 (Vietnam)",
        "age 2.4 (Vietnam)",
        "age 2.5 (Vietnam)",
        "age 2.5b (Vietnam)",
        "age 3.3 (Vietnam)", 
        "age 3.3b (Vietnam)",
        "age 3.4 (Vietnam)", 
        "age 3.5 (Vietnam)",
    ],   
    "30-33 years": [
        "age 30.5 (Vietnam)",
        "age 31.5 (Vietnam)",
        "age 33.5 (Vietnam)",
    ],
    "misc_adult": [
        "age 21 (Seattle)",
        "age 53 (Seattle)",
        "age 64 (Seattle)",
        "age 65 (Seattle)",
    ],
    "ferret": [
        "ferret 1 (Pitt)",
        "ferret 2 (Pitt)",
        "ferret 3 (Pitt)",
        "ferret (WHO)",
    ]
}

# get full dataset: keep only the columns used downstream and call the
# serum column 'serum' (it is 'name' in the CSV)
escape_df = (
    pd.read_csv('results/perth2009/merged_escape.csv')
    [['name', 'site', 'wildtype', 'mutant', 'escape']]
    .rename(columns={'name': 'serum'})
)

# add cohort label
def find_sample_type(sample_name):
    """Return the cohort in ``sample_dict`` that lists ``sample_name``.

    Cohorts are checked in ``sample_dict`` order and the first one whose
    member list contains the name wins; ``None`` if no cohort lists it.
    """
    return next(
        (cohort for cohort, members in sample_dict.items() if sample_name in members),
        None,
    )

escape_df['cohort'] = escape_df['serum'].apply(find_sample_type)

# short display names for the sera kept in this analysis; any serum not
# listed here (e.g. the Seattle adults and every ferret except
# "ferret 2 (Pitt)") is dropped below
sample_rename_dict = {
    "age 2.1 (Vietnam)": "age-2.1", 
    "age 2.2 (Vietnam)": 'age-2.2',
    "age 2.4 (Vietnam)": 'age-2.4',
    "age 2.5 (Vietnam)": 'age-2.5',
    "age 2.5b (Vietnam)": 'age-2.5-b',
    "age 3.3 (Vietnam)": 'age-3.3',
    "age 3.3b (Vietnam)": 'age-3.3-b',
    "age 3.4 (Vietnam)": 'age-3.4',
    "age 3.5 (Vietnam)": 'age-3.5',
    "age 30.5 (Vietnam)": 'age-30.5',
    "age 31.5 (Vietnam)": 'age-31.5',
    "age 33.5 (Vietnam)": 'age-33.5',
    "ferret 2 (Pitt)": 'ferret-Pitt2',
}

# rename sera and drop extra samples. Series.map (unlike Series.replace,
# which leaves unmatched values untouched) yields NaN for sera missing from
# sample_rename_dict, so the dropna actually removes them.
escape_df['serum'] = escape_df['serum'].map(sample_rename_dict)
escape_df = escape_df.dropna(subset=['serum'])

# Convert HA2-numbered site labels such as '(HA2)5' to a single sequential number.
def convert_site_to_numeric(site, ha1_length=329):
    """Map an ``'(HA2)<n>'`` site label onto sequential numbering.

    Labels containing ``'(HA2)'`` whose remainder parses as an int become
    ``n + ha1_length`` (an int), so HA2 sites are numbered after HA1 sites.
    Everything else is returned unchanged: HA1 labels, non-strings (NaN,
    numbers), and HA2 labels whose number cannot be parsed.

    Parameters
    ----------
    site : object
        A value from the 'site' column, usually a str.
    ha1_length : int, default 329
        Offset added to HA2 site numbers. 329 is the offset this notebook has
        always used (presumably the HA1 length in this numbering — confirm).
    """
    # Non-strings cannot carry the HA2 tag; checking first avoids the
    # TypeError that `'(HA2)' in site` raises for e.g. NaN or int sites.
    if not isinstance(site, str) or '(HA2)' not in site:
        return site
    try:
        return int(site.replace('(HA2)', '').strip()) + ha1_length
    except ValueError:
        # Best effort: keep labels that are not '(HA2)<int>' as they were.
        return site

# Renumber '(HA2)<n>' sites onto the sequential scale, and floor negative
# escape scores at 0; both columns are derived from the same input frame.
escape_df = escape_df.assign(
    site=escape_df['site'].apply(convert_site_to_numeric),
    escape=escape_df['escape'].clip(lower=0),
)

escape_df.head()

Unnamed: 0,serum,site,wildtype,mutant,escape,cohort
0,age-30.5,159,F,G,3.482,30-33 years
1,age-30.5,159,F,N,2.451,30-33 years
2,age-30.5,159,F,H,2.334,30-33 years
3,age-30.5,159,F,T,1.514,30-33 years
4,age-30.5,159,F,S,1.195,30-33 years


In [4]:
# Label each mutation as <wildtype><site><mutant> (e.g. 'F159G') so it can be
# joined to the validation table on 'variant', then keep only merge columns.
variant_labels = (
    escape_df['wildtype']
    + escape_df['site'].astype(str)
    + escape_df['mutant']
)
escape_df = escape_df.assign(variant=variant_labels)[['variant', 'serum', 'escape', 'cohort']]

In [5]:
# Measured IC50 fold changes for the validation variants; serum labels are
# cast to str so they compare equal to escape_df's serum names in the merge.
ics = (
    pd.read_csv('experiments/validations_perth09/ic50_fold_changes.csv')
    .astype({'serum': str})
)

In [6]:
# Merge model predictions with measured ICs (one row per measured
# variant/serum pair; validate guards against duplicated keys on either side).
corr_df = ics.merge(
    escape_df,
    how="left",
    on=["variant", "serum"],
    validate="one_to_one",
)

# A variant with no DMS entry gets escape 0. Only this column is filled, so
# a missing measurement is not fabricated as 0 and a missing cohort does not
# become the integer 0 (a bogus category in the shape legend).
corr_df["escape"] = corr_df["escape"].fillna(0)

# Recover the cohort of unmatched rows from the serum's rows in escape_df.
serum_cohort = escape_df.drop_duplicates("serum").set_index("serum")["cohort"]
corr_df["cohort"] = corr_df["cohort"].fillna(corr_df["serum"].map(serum_cohort))

corr_df.head()

Unnamed: 0,serum,variant,log2_fold_change_ic50,escape,cohort
0,age-2.1,F193F,-0.100906,0.0,2-4 years
1,age-2.1,K189D,0.898353,2.336,2-4 years
2,age-2.1,F193D,0.606342,1.402,2-4 years
3,age-2.1,F159G,0.653439,0.1604,2-4 years
4,age-2.2,F193F,-0.038777,0.0,2-4 years


In [7]:
# Pearson r between the measured IC column and the DMS escape score, computed
# separately for each serum and attached to every row of that serum as 'r'.
sera = corr_df['serum'].unique().tolist()
ic_col = corr_df.filter(like='ic').columns[0]

r_dict = {}
for serum in sera:
    rows = corr_df[corr_df['serum'] == serum]
    r_dict[serum] = round(sp.stats.pearsonr(rows[ic_col], rows['escape'])[0], 3)

corr_df['r'] = corr_df['serum'].map(r_dict)
corr_df.head()

Unnamed: 0,serum,variant,log2_fold_change_ic50,escape,cohort,r
0,age-2.1,F193F,-0.100906,0.0,2-4 years,0.738
1,age-2.1,K189D,0.898353,2.336,2-4 years,0.738
2,age-2.1,F193D,0.606342,1.402,2-4 years,0.738
3,age-2.1,F159G,0.653439,0.1604,2-4 years,0.738
4,age-2.2,F193F,-0.038777,0.0,2-4 years,0.716


In [8]:
def get_corr_plot(serum, df=None):
    """Scatter the measured IC column against DMS escape for one or more sera.

    Parameters
    ----------
    serum : str or list-like of str
        A single serum name, or any list-like collection of names (list,
        tuple, set, Series, ...) whose rows are pooled into one plot.
    df : pandas.DataFrame, optional
        Table with 'serum', 'variant', 'escape', 'cohort' columns and a
        measured column whose name contains 'ic'. Defaults to the notebook's
        ``corr_df``, looked up when the function is called.

    Returns
    -------
    altair.LayerChart
        Points colored by variant and shaped by cohort, with the Pearson r
        pooled over all selected rows written in the top-left corner.

    Raises
    ------
    ValueError
        If fewer than two rows match ``serum`` (no correlation possible).
    """
    if df is None:
        # Resolve at call time; a `df=corr_df` default would freeze whatever
        # frame existed when this cell was first run.
        df = corr_df

    if pd.api.types.is_list_like(serum):
        df = df.loc[df['serum'].isin(serum)]
    else:
        df = df.loc[df['serum'] == serum]

    if len(df) < 2:
        raise ValueError(
            f"need at least 2 rows to compute a correlation, found {len(df)} "
            f"for serum={serum!r}"
        )

    # measured column: first column whose name contains 'ic'
    ic_col = [col for col in df.columns if 'ic' in col][0]

    # Pearson r pooled across every selected row, rounded to 3 decimals
    r = round(sp.stats.pearsonr(df[ic_col], df['escape'])[0], 3)

    # fixed variant order so each variant keeps the same color across plots
    variant_order = ['F159G', 'K189D', 'I192E', 'F193D', 'F193F']

    # cb-friendly color scheme
    custom_color_scheme = ['#CC6677', '#332288', '#117733',
                           '#88CCEE', '#882255', '#44AA99', '#DDCC77', '#999933', '#AA4499']

    corrs = (
        alt.Chart(df)
        .mark_point(filled=True, size=200)
        .encode(
            x=alt.X(ic_col, title=ic_col),
            y=alt.Y('escape', title='DMS escape score'),
            color=alt.Color('variant:N', scale=alt.Scale(range=custom_color_scheme, domain=variant_order)),
            shape=alt.Shape('cohort:N'),
            tooltip=['variant', 'serum', ic_col, 'escape'],
        )
    )

    # Text layer with no data of its own: position and text are constants.
    r_text = alt.Chart().mark_text(
        align='left',
        baseline='bottom',
        fontSize=14,
        fontWeight=400,
    ).encode(
        x=alt.value(10),
        y=alt.value(20),
        text=alt.value([f"R={r}"]),
    )

    return (
        (corrs + r_text)
        .configure_axis(
            grid=False,
            labelFontSize=14,
            titleFontSize=15,
        )
        .configure_title(
            fontSize=21,
            fontWeight='normal',
        )
    )

In [179]:
# Correlation pooled over every serum in corr_df, saved as a vector figure.
full_corrs = get_corr_plot(sera)

full_corrs_path = 'scratch_notebooks/figure_drafts/perth09_analysis/full-corrs_ic50.svg'
# the output directory may not exist on a fresh checkout
os.makedirs(os.path.dirname(full_corrs_path), exist_ok=True)
full_corrs.save(full_corrs_path, scale_factor=2.0)

full_corrs

In [13]:
# Small multiples: one scatter panel per serum, each annotated with that
# serum's Pearson r from the 'r' column of corr_df.
ic_col = [col for col in corr_df.columns if 'ic' in col][0]

variant_order = ['F159G', 'K189D', 'I192E', 'F193D', 'F193F']

# cb-friendly color scheme
custom_color_scheme = ['#CC6677', '#332288', '#117733',
                       '#88CCEE', '#882255', '#44AA99', '#DDCC77', '#999933', '#AA4499']

# No per-panel title: each panel is labelled by its 'serum:N' facet header.
# A title here could only come from a global name and would be identical on
# every panel.
corrs = (
    alt.Chart()
    .mark_point(filled=True, size=100)
    .encode(
        x=alt.X(ic_col, title=ic_col),
        y=alt.Y('escape', title='DMS escape score'),
        color=alt.Color('variant:N', scale=alt.Scale(range=custom_color_scheme, domain=variant_order)),
        shape=alt.Shape('cohort:N'),
        tooltip=['variant', 'serum', ic_col, 'escape']
    )
    .properties(
        width=100,
        height=100
    )
)

# r label pinned near the bottom-right of each 100x100 panel; 'r' is the same
# for every row of a serum, so all rows of a panel draw the same text.
r_text = alt.Chart().mark_text(
    align='right',
    baseline='bottom',
    fontSize=12,
    fontWeight=300
).encode(
    x=alt.value(95),
    y=alt.value(95),
    text='r:N'
)

faceted_corrs = alt.layer(
    corrs, r_text, data=corr_df,
).facet(
    facet=alt.Facet(
        'serum:N',
        header=alt.Header(
            titleFontSize=20,
            labelFontSize=14,
            labelPadding=3
        )
    ),
    spacing=1,
    columns=5
).configure_axis(
    grid=False,
    labelFontSize=14,
    titleFontSize=15,
)

faceted_corrs_path = 'scratch_notebooks/figure_drafts/perth09_analysis/faceted-corrs_ic50.png'
# the output directory may not exist on a fresh checkout
os.makedirs(os.path.dirname(faceted_corrs_path), exist_ok=True)
faceted_corrs.save(faceted_corrs_path, scale_factor=2.0)

faceted_corrs