In [1]:
import pickle

import altair as alt

import pandas as pd

import polyclonal

import warnings
warnings.filterwarnings('ignore')

import altair_saver
import scipy as sp

import numpy as np

import os

In [2]:
# Remember the directory the kernel started in. A bare relative chdir is not
# idempotent: every re-run of this cell would climb three more levels.
try:
    NOTEBOOK_DIR
except NameError:
    NOTEBOOK_DIR = os.getcwd()

# Later cells read and write paths relative to this directory.
os.chdir(os.path.join(NOTEBOOK_DIR, '..', '..', '..'))

In [3]:
# File names of the fold-change CSVs to load.
# '3944-215C_fold-change_v1.csv' is commented out of this list and not loaded.
ic80s_files = [
    '231021_3856-4584_fold-change.csv',
    '3857-215C_fold-change.csv',
    '3944-2365_fold-change.csv',
    '3857-215C_fold-change_pt2.csv',
    '197C_74C_fold-change_final.csv',
]

In [4]:
# Load every fold-change table, drop the 'T160K-Y159N' variant, and cast the
# serum IDs to str (they are compared against string IDs below).
ic80s_list = []

for ic80s_file in ic80s_files:
    # Chain the filter and the cast so each step returns a new frame; this
    # avoids assigning a column on a filtered slice and keeps the file name
    # and the table under separate names.
    ic80s_part = (
        pd.read_csv(f'experiments/validations_hk19/{ic80s_file}')
        .loc[lambda d: d['variant'] != 'T160K-Y159N']
        .assign(serum=lambda d: d['serum'].astype(str))
    )
    ic80s_list.append(ic80s_part)

ic80s = pd.concat(ic80s_list)

# Rows for serum '3944' are removed from the combined table.
ic80s = ic80s.loc[ic80s['serum'] != '3944']

ic80s.head()

Unnamed: 0,serum,variant,log2_fold_change_ic80
0,3856,E50K,-0.315872
1,3856,K189E,0.042243
2,3856,S193D,0.291549
3,3856,T160K,-1.700928
4,3856,Y159N,-0.706739


In [90]:
# Load one more fold-change table with the same clean-up as the others:
# drop the 'T160K-Y159N' variant and cast serum IDs to str.
extra_ic80s_file = '3944_fold-change_final.csv'
extra_ic80s = (
    pd.read_csv(f'experiments/validations_hk19/{extra_ic80s_file}')
    .loc[lambda d: d['variant'] != 'T160K-Y159N']
    .assign(serum=lambda d: d['serum'].astype(str))
)

# NOTE: this appends onto `ic80s` itself, so re-running the cell without
# re-running the cell that builds `ic80s` adds the same rows again.
ic80s = pd.concat([ic80s, extra_ic80s])

# Exclude serum '215C' and variant 'S124R' from the combined table.
ic80s = ic80s.loc[(ic80s['serum'] != '215C') & (ic80s['variant'] != 'S124R')]

In [5]:
# Measured IC50 fold changes: drop the 'T160K-Y159N' variant and cast serum
# IDs to str.
ics = (
    pd.read_csv('experiments/validations_hk19/ic50_fold_changes.csv')
    .loc[lambda table: table['variant'] != 'T160K-Y159N']
    .assign(serum=lambda table: table['serum'].astype(str))
)

In [30]:
# Sera whose averaged model output is read from results/antibody_escape/.
sera = ['4584', '3857', '3856', '2365', '3944', '197C', '74C']

# One model table per serum, each tagged with its serum ID.
models_list = [
    pd.read_csv(f'results/antibody_escape/{serum_id}_avg.csv').assign(serum=serum_id)
    for serum_id in sera
]

models = pd.concat(models_list)

# Variant label is the concatenation wildtype + site + mutant, so it can be
# matched against the 'variant' column of the measured ICs.
models['variant'] = (
    models["wildtype"] +
    models["site"].astype(str) +
    models["mutant"]
)

models = models[['variant', 'serum', 'escape_mean']]

# Merge model predictions with measured ICs. The left join keeps every
# measured row; fillna(0) then zero-fills any missing value, so a variant
# without a model prediction gets an escape of 0.
corr_df = (
    ics.merge(
        models,
        how="left",
        on=["variant", "serum"],
        validate="one_to_one",
    )
    .fillna(0)
)

corr_df = corr_df.rename(columns={'escape_mean': 'escape'})

# Add a WT entry for each serum with log2 fold change of 0 and escape of 0.
# Built in a single constructor call: concatenating row by row onto an empty
# frame is quadratic and leaves the numeric columns with object dtype.
wt_df = pd.DataFrame({
    'serum': corr_df['serum'].unique(),
    'variant': 'WT',
    'log2_fold_change_ic50': 0.0,
    'escape': 0.0,
})

# Append the 'WT' rows to the merged table
corr_df = pd.concat([corr_df, wt_df], ignore_index=True)

corr_df.head()

Unnamed: 0,serum,variant,log2_fold_change_ic50,escape
0,2365,Y159N,1.15878,0.3106
1,2365,T160K,0.08919,0.0047
2,2365,K189E,2.199497,0.781
3,2365,S193D,1.856389,0.5509
4,2365,S145H,-0.300784,-0.0419


In [31]:
# Cohort label for each serum ID
cohort_dict = {
    '3944': '2-5 years',
    '4584': '2-5 years',
    '2365': '15-20 years',
    '3856': '15-20 years',
    '3857': '15-20 years',
    '74C': '40-45 years',
    '197C': '40-45 years',
}
corr_df['cohort'] = corr_df['serum'].map(cohort_dict)

# Per-serum Pearson r between the measured IC column (first column whose name
# contains 'ic') and the escape score, rounded to 3 decimals.
sera = corr_df['serum'].unique().tolist()
ic_col = [name for name in corr_df.columns if 'ic' in name][0]

r_dict = {}
for serum in sera:
    in_serum = corr_df['serum'] == serum
    r, p = sp.stats.pearsonr(
        corr_df.loc[in_serum, ic_col],
        corr_df.loc[in_serum, "escape"],
    )
    r_dict[serum] = round(r, 3)

# Attach each serum's r to all of its rows
corr_df['r'] = corr_df['serum'].map(r_dict)
corr_df.head()

Unnamed: 0,serum,variant,log2_fold_change_ic50,escape,cohort,r
0,2365,Y159N,1.15878,0.3106,15-20 years,0.987
1,2365,T160K,0.08919,0.0047,15-20 years,0.987
2,2365,K189E,2.199497,0.781,15-20 years,0.987
3,2365,S193D,1.856389,0.5509,15-20 years,0.987
4,2365,S145H,-0.300784,-0.0419,15-20 years,0.987


In [53]:
def get_corr_plot(serum, df=None):
    """Scatter the measured IC fold change against the model escape score.

    Parameters
    ----------
    serum : str or list/tuple/set of str
        A single serum ID, or a collection of serum IDs, to keep.
    df : pandas.DataFrame, optional
        Table with 'serum', 'variant' and 'escape' columns plus one column
        whose name contains 'ic' (the first such column is plotted on x).
        When omitted, the notebook-level ``corr_df`` is looked up at call
        time, so the function always sees the current table rather than the
        one that existed when this cell was run.

    Returns
    -------
    A layered Altair chart: the scatter points, an "R=" label with the
    Pearson correlation rounded to 3 decimals, and gray rules at x=0 and y=0.

    Raises
    ------
    ValueError
        If no rows match ``serum`` (a correlation cannot be computed).
    """
    if df is None:
        df = corr_df

    if isinstance(serum, (list, tuple, set)):
        df = df.loc[df['serum'].isin(serum)]
    else:
        df = df.loc[df['serum'] == serum]

    if df.empty:
        raise ValueError(f"no rows found for serum selection {serum!r}")

    # get ic column name
    ic_col = [col for col in df.columns if 'ic' in col][0]

    # Pearson correlation between measured fold change and escape score
    r, p = sp.stats.pearsonr(df[ic_col], df["escape"])
    r = round(r, 3)

    # Define the order for 'variant' and 'serum'
    variant_order = ['WT', 'E50K', 'S145H', 'S145K', 'Y159N', 'T160K', 'T160S', 'K189E', 'S193D', 'S193Y']
    serum_order = ['3944', '74C', '2365', '3856', '3857', '4584', '197C']

    # Define cb-friendly color scheme (one color per entry of variant_order)
    custom_color_scheme = ['#333333', '#CC6677', '#999933', '#DDCC77', '#117733',
                           '#882255', '#88CCEE', '#44AA99', '#332288', '#AA4499']

    # Scatter of escape score vs. measured fold change
    corrs = (
        alt.Chart(df)
        .mark_point(filled=True, size=200)
        .encode(
            x=alt.X(ic_col,
                    title=ic_col,
                    scale=alt.Scale(
                        domain=[-4, 4],
                    )
                   ),
            y=alt.Y('escape',
                    title='escape score',
                    scale=alt.Scale(
                        domain=[-1, 1],
                        # values outside the domain are drawn at its edge
                        clamp=True
                    )
                   ),
            color=alt.Color('variant:N', scale=alt.Scale(range=custom_color_scheme, domain=variant_order)),
            shape=alt.Shape('serum:N', scale=alt.Scale(domain=serum_order)),
            tooltip=['variant', 'serum', ic_col, 'escape']
        )
    )

    # Correlation label placed at fixed pixel coordinates
    r_text = alt.Chart().mark_text(
        align='left',
        baseline='bottom',
        fontSize=14,
        fontWeight=300
    ).encode(
        x=alt.value(10),
        y=alt.value(20),
        text=alt.value([f"R={r}"])
    )

    # Gray reference rules through the origin
    x_axis = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(
        size=1,
        opacity=0.5, color='gray').encode(y='y')

    y_axis = alt.Chart(pd.DataFrame({'x': [0]})).mark_rule(
        size=1,
        opacity=0.5, color='gray').encode(x='x')

    chart = (
        (corrs + r_text + x_axis + y_axis)
        .configure_axis(
            grid=False,
            labelFontSize=14,
            titleFontSize=15,
        )
        .configure_title(
            fontSize=21,
            fontWeight='normal',
        )
    )

    return chart

In [58]:
# Single correlation plot across every serum in `sera`, saved as an SVG draft
# and then displayed as the cell output.
full_corr = get_corr_plot(serum=sera)

full_corr.save(
    'scratch_notebooks/figure_drafts/validation_correlations/full-serum-corr_ic50.svg',
    scale_factor=2.0,
)

full_corr

In [57]:
# First column of corr_df whose name contains 'ic' is plotted on the x axis
ic_col = [col for col in corr_df.columns if 'ic' in col][0]

variant_order = ['WT', 'E50K', 'S145H', 'S145K', 'Y159N', 'T160K', 'T160S', 'K189E', 'S193D', 'S193Y']
serum_order = ['3944', '4584', '2365', '3856', '3857', '74C', '197C']
cohort_order = ['2-5 years', '15-20 years', '40-45 years']

# Define cb-friendly color scheme (one color per entry of variant_order)
custom_color_scheme = ['#404040', '#CC6677', '#999933', '#DDCC77', '#117733',
                       '#882255', '#88CCEE', '#44AA99', '#332288', '#AA4499']

# Scatter layer. No per-panel title is set here: the facet header already
# labels each panel with its serum, and the previous `title=serum` picked up
# whatever value a loop in an earlier cell happened to leave in `serum`.
corrs = (
    alt.Chart()
    .mark_point(filled=True, size=100)
    .encode(
        x=alt.X(ic_col,
                title=ic_col,
                scale=alt.Scale(
                    domain=[-4, 4],
                ),
               ),
        y=alt.Y('escape',
                title='escape score',
                scale=alt.Scale(
                    domain=[-1, 1],
                    # values outside the domain are drawn at its edge
                    clamp=True
                )
               ),
        color=alt.Color('variant:N', scale=alt.Scale(range=custom_color_scheme, domain=variant_order)),
        shape=alt.Shape('cohort:N', scale=alt.Scale(domain=cohort_order)),
        tooltip=['variant', 'serum', ic_col, 'escape']
    )
    .properties(
        width=120,
        height=120
    )
)

# Per-serum r value (the 'r' column of corr_df) at fixed pixel coordinates
# near the bottom-right corner of each 120x120 panel
r_text = alt.Chart().mark_text(
    align='right',
    baseline='bottom',
    fontSize=12,
    fontWeight=300
).encode(
    x=alt.value(115),
    y=alt.value(115),
    text='r:N'
)

# Gray reference rules through the origin
x_axis = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(
    size=1,
    opacity=0.5, color='gray').encode(y='y')

y_axis = alt.Chart(pd.DataFrame({'x': [0]})).mark_rule(
    size=1,
    opacity=0.5, color='gray').encode(x='x')

# Layer the pieces over corr_df and facet into one panel per serum
faceted_corrs = alt.layer(
    corrs, r_text, x_axis, y_axis, data=corr_df,
).facet(
    facet=alt.Facet(
        'serum:N',
        sort=serum_order,
        header=alt.Header(
            titleFontSize=20,
            labelFontSize=14,
            labelPadding=2
        )
    ),
    spacing=3,
    columns=4
).configure_axis(
    grid=False,
    labelFontSize=14,
    titleFontSize=10,
)

faceted_corrs.save(
    'scratch_notebooks/figure_drafts/validation_correlations/hk19_faceted-corrs_ic50.svg',
    scale_factor=2.0
)

faceted_corrs

In [12]:
# One correlation chart per serum, in the order of `sera`
corr_charts = [get_corr_plot(serum_id) for serum_id in sera]

In [13]:
corr_charts[0]

In [14]:
corr_charts[1]

In [15]:
corr_charts[2]

In [16]:
corr_charts[3]

In [17]:
corr_charts[4]

In [18]:
corr_charts[5]

In [19]:
corr_charts[6]

In [129]:
# Calculate correlation between predicted and measured
# NOTE(review): `validation_vs_prediction` is not defined in any cell visible
# in this notebook; it must come from earlier kernel state, so this cell will
# fail on a fresh Restart & Run All unless that frame is created first.
x = validation_vs_prediction["log2_fold_change_ic80"]
y = validation_vs_prediction["escape_median"]

# Calculate Pearson correlation coefficient
r, p = sp.stats.pearsonr(x, y)

# Calculate the slope and y-intercept for the correlation line
# (slope = r * std(y) / std(x); the line passes through the point of means)
slope = r * (np.std(y) / np.std(x))
intercept = np.mean(y) - slope * np.mean(x)

# Create a range of x values for the line (100 points spanning the data)
x_range = np.linspace(min(x), max(x), 100)

# Calculate corresponding y values for the line
y_range = slope * x_range + intercept

# x_range, y_range and r are reused by the plotting cell that follows
print(f"R={r}")
print(f"R^2={r**2}")

R=0.9124791960397256
R^2=0.832618283205304


In [130]:
# Round the correlation coefficient from the previous cell for display
r = round(r, 3)

# Scatter of predicted escape (escape_median) against measured log2 IC80 fold
# change, colored by variant and shaped by serum
corrs = (
    alt.Chart(validation_vs_prediction)
    .mark_point(filled=True, size=200)
    .encode(
        x=alt.X('log2_fold_change_ic80',
                title='log2_fold_change_ic80',
                # scale=alt.Scale(type="symlog"),
               ),
        y=alt.Y('escape_median',
                title='predicted escape score'
               ),
        color=alt.Color('variant:N'), 
        shape=alt.Shape('serum:N'),
        tooltip=['variant', 'serum']
    )

)

# Points along the fitted line; x_range and y_range come from the previous
# cell and are used as-is (no log transform is applied)
line_data = pd.DataFrame({'x': x_range, 'y': y_range})

# Fitted correlation line, drawn on the same linear axes as the scatter
correlation_line = alt.Chart(line_data).mark_line(
    color='red',  # Set the line color to red
    strokeWidth=2  # Set the line width
).encode(
    x=alt.X('x'),
    y=alt.Y('y')
)

# Two-line label ("R-value: " then the rounded r) at fixed pixel coordinates
text_2 = alt.Chart().mark_text(
    align='left', 
    baseline='bottom',
    fontSize=14,
    fontWeight=300
).encode(
    x=alt.value(10),
    y=alt.value(20),
    text=alt.value(["R-value: ", f'{r}'])
)

# Gray reference rules through the origin
x_axis = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(
    size=1, 
    opacity=0.5, color='gray').encode(y='y')

y_axis = alt.Chart(pd.DataFrame({'x': [0]})).mark_rule(
    size=1, 
    opacity=0.5, color='gray').encode(x='x')


# Layer everything and apply axis/title styling
chart = (
    (corrs + text_2 + x_axis + y_axis + correlation_line)
    .configure_axis(
        grid=False,
        labelFontSize=14,
        titleFontSize=15,  
    )
    .configure_title(
        fontSize=21,
        fontWeight='normal',
        # padding=5,
    )
)

# Saving is currently disabled; uncomment to write the PNG draft.
# chart.save(
#     'scratch_notebooks/figure_drafts/validation_correlations/231021_full-serum-corr.png',
#     scale_factor=2.0
# )

chart