In [52]:
import pandas as pd
import numpy as np
# Cohort labels (values of the `source` column) retained for the A1 breadth analysis.
sources = [
    'XBB infection',
    'XBB BTI',
    'BA.5 + XBB infection',
    'BA.5 + JN.1 infection',
    'BA.5 BTI + HK.3 infection',
    'BA.5 BTI + XBB infection',
    'BA.5 BTI + JN.1 infection',
]

# Group-A1 antibodies from the cohorts above, restricted to those with
# XBB1_5_IC50 < 1 and paper_reactivity == 'cross'. Dropping the last two
# conditions from the query gives the unfiltered A1 set.
data_A1 = (
    pd.read_csv("../data/DMS/antibody/_clustering.csv")
    .set_index('id')
    .query("new_group == 'A1' and source in @sources and XBB1_5_IC50 < 1 and paper_reactivity == 'cross'")
)

# IC50 columns over which breadth is judged.
use_variants = [
    'XBB1_5_IC50',
    'HK3_1_IC50',
    'JN1_IC50',
    'JN1_F456L_IC50',
    'JN1_R346T_F456L_IC50',
    'KP3_IC50',
    'JN1_F456L_A475V_IC50',
    'KP3_A475V_IC50',
]

# Per-antibody maximum IC50 across `use_variants`, keyed by antibody id.
x = np.max(data_A1[use_variants], axis=1).to_dict()


def _breadth_label(worst_ic50):
    """Label an antibody from its largest IC50 across `use_variants`.

    "broad" when the value is below 1, "escaped" when it is 1 or above, and
    "unknown" when neither comparison holds (e.g. a missing value).
    """
    if worst_ic50 < 1:
        return "broad"
    if worst_ic50 >= 1:
        return "escaped"
    return "unknown"


# Attach the breadth label, left-join the per-antibody table from A1_HCDR.csv
# on 'id', and restore 'id' as the index before saving.
data_A1 = (
    data_A1
    .assign(breadth=[_breadth_label(x[ab_id]) for ab_id in data_A1.index])
    .merge(pd.read_csv("../data/A1_HCDR.csv"), on='id', how='left')
    .set_index('id')
)

data_A1.to_csv("../data/_A1_breadth.csv")


In [53]:
def calc_mAbs_by_source(res, group, group_col="new_group", src_col="source"):
    """Average per-mutation escape over the antibodies of each (source, group) pair.

    Parameters
    ----------
    res : pandas.DataFrame
        Long-format scores with columns 'antibody', 'site', 'mutation',
        'mut_escape', plus the columns named by ``group_col`` and ``src_col``.
    group : str or iterable
        Group label, or collection of labels, to keep from ``group_col``.
    group_col : str
        Column holding the group label.
    src_col : str
        Column holding the source label the averages are split by.

    Returns
    -------
    pandas.DataFrame
        One row per (src_col, group_col, site, mutation) with columns
        ``src_col``, ``group_col``, 'site', 'mutation', 'mut_escape' and
        'count'. 'mut_escape' is the summed escape divided by 'count', the
        number of distinct antibodies in that (src_col, group_col) pair, and
        the ``src_col`` label has ' (<count>)' appended.
    """
    # A bare string means "this one group"; anything else is taken as a
    # collection of labels. Boolean indexing (rather than an f-string query)
    # also works when group_col is not a valid Python identifier.
    groups = [group] if isinstance(group, str) else list(group)
    _use_res = res[res[group_col].isin(groups)]

    # Number of distinct antibodies contributing to each (source, group) pair.
    _cnt = (
        _use_res[["antibody", src_col, group_col]]
        .drop_duplicates()
        .groupby([src_col, group_col])['antibody']
        .count()
        .reset_index()
        .rename(columns={'antibody': 'count'})
    )

    summed = (
        _use_res
        .groupby([src_col, group_col, "site", "mutation"])['mut_escape']
        .sum()
        .reset_index()
    )

    # The original had a second, unreachable `return` after this one; it has
    # been removed.
    return summed.merge(_cnt, on=[src_col, group_col]).assign(**{
        'mut_escape': lambda df: df['mut_escape'] / df['count'],
        src_col: lambda df: df[src_col] + ' (' + df['count'].astype(str) + ')',
    })


import logomaker
from matplotlib import rcParams
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

def site_to_pos(sites, split='+'):
    """Order site labels numerically and map each label to its rank.

    Each label is split on ``split`` and its pieces are compared as integers,
    so '10' sorts after '2' and compound labels such as '10+2' compare piece
    by piece.

    Returns a tuple ``(ordered_sites, site2pos)``: the sorted list of labels
    and a dict from label to its 0-based position in that list.
    """
    def _numeric_key(label):
        return [int(piece) for piece in label.split(split)]

    ordered = sorted(sites, key=_numeric_key)
    positions = {label: rank for rank, label in enumerate(ordered)}
    return ordered, positions

def plot_res_logo(res, prefix, by='name', site_thres=0.1, width=26, shownames={}, num_per_page = 10, force_plot_sites = None, force_ylim = None, highlight_res = {}):
    """Render one escape-score sequence logo per entity into a multi-page PDF.

    Parameters
    ----------
    res : pandas.DataFrame
        Long-format scores with columns ``by``, 'site', 'mutation' and
        'mut_escape'. NOTE: its 'site' column is converted to ``str`` in
        place, so the caller's frame is modified (behaviour kept from the
        original implementation).
    prefix : str
        Output path prefix; the file written is ``prefix + '_aa_logo.pdf'``.
    by : str
        Column naming the entity (e.g. antibody) that gets its own panel.
    site_thres : float
        A site is plotted when its summed score exceeds this value for at
        least one entity. Ignored when ``force_plot_sites`` is given.
    width : float
        Figure width passed to matplotlib.
    shownames : mapping
        Optional entity -> panel title; the entity name is used otherwise.
    num_per_page : int
        Maximum number of panels per PDF page.
    force_plot_sites : iterable of str, optional
        Site labels to plot instead of the thresholded selection.
    force_ylim : float, optional
        Upper y-axis limit applied to every panel.
    highlight_res : mapping
        color -> site label or iterable of site labels to shade. Labels that
        are not among the plotted sites are skipped.
    """
    rcParams['pdf.fonttype'] = 42

    # In-place cast kept on purpose: callers may rely on 'site' being str afterwards.
    res["site"] = res["site"].astype(str)
    flat_res = res.rename(columns={by:'antibody'}).pivot(index=['antibody', 'site'], columns='mutation', values='mut_escape').fillna(0)
    sites_total_score = flat_res.sum(axis=1)

    if force_plot_sites is not None:
        plot_sites = force_plot_sites
    else:
        plot_sites = list(pd.unique(sites_total_score[sites_total_score > site_thres].reset_index()['site']))

    flat_res = flat_res.query('site in @plot_sites')
    Abs = flat_res.index.get_level_values('antibody').unique()

    plot_sites, site2pos = site_to_pos(plot_sites)

    with PdfPages(prefix+'_aa_logo.pdf') as pdf:
        # Step through the entities in chunks of num_per_page. This honours
        # the parameter (the original sliced by a hard-coded 10) and never
        # produces a trailing empty page when the count divides evenly.
        for start in range(0, len(Abs), num_per_page):
            Abs_p = Abs[start:start + num_per_page]
            fig = plt.figure(figsize=(width, len(Abs_p)*4.6))
            fig.subplots_adjust(wspace=0.2, hspace=0.5)

            for i, ab in enumerate(Abs_p):
                # Rows for this entity, padded with zeros for plotted sites it
                # has no data for, then re-indexed by integer x position.
                panel = flat_res.loc[ab, :].reindex(plot_sites, fill_value=0.0)
                panel.index = [site2pos[site] for site in panel.index]

                ax = fig.add_subplot(len(Abs_p), 1, i+1)
                logo = logomaker.Logo(panel,
                               ax=ax,
                               color_scheme='dmslogo_funcgroup',
                               vpad=.1,
                               width=.8)
                logo.style_xticks(anchor=0, spacing=1, rotation=90, fontsize=16)
                _max = np.sum(panel.to_numpy(), axis=1).max()
                ax.yaxis.set_tick_params(labelsize=20)
                # Tick spacing grows with the tallest stack; small panels get
                # a fixed 0-3 range.
                if force_ylim is not None:
                    ax.set_ylim(0, force_ylim)
                elif _max < 3:
                    ax.set_ylim(0,3)
                    ax.set_yticks(range(0, 3, 1))
                elif _max < 5:
                    ax.set_yticks(range(0, int(_max)+1, 1))
                elif _max < 8:
                    ax.set_yticks(range(0, int(_max)+1, 2))
                else:
                    ax.set_yticks(range(0, int(_max)+1, 3))

                # Shade requested sites. The original referenced an undefined
                # name here and raised NameError for any non-empty mapping.
                for color, sites in highlight_res.items():
                    if isinstance(sites, str):
                        sites = [sites]
                    for site in sites:
                        site = str(site)
                        if site in site2pos:
                            logo.highlight_position(p=site2pos[site], color=color, alpha=.2)

                ax.set_xticklabels(plot_sites)

                if ab in shownames:
                    ax.set_title(shownames[ab], fontsize=24, fontweight="bold")
                else:
                    ax.set_title(ab, fontsize=24, fontweight="bold")
            pdf.savefig(fig)
            plt.close(fig)



In [54]:
# DMS escape scores for the antibodies present in data_A1, each row tagged
# with that antibody's breadth label and assigned new_group 'A1'.
dms_scores = (
    pd.read_csv("../data/DMS/antibody/dms_antibodies_XBB15_JN1_agg.csv")
    .query('antibody in @data_A1.index')
    .merge(data_A1[['breadth']], left_on='antibody', right_index=True, how='left')
    .assign(new_group='A1')
)

# Per-breadth average escape, computed separately for each antigen.
avg_JN1 = calc_mAbs_by_source(
    dms_scores.query('antigen == "JN.1_RBD"'), 'A1', src_col='breadth')
avg_XBB15 = calc_mAbs_by_source(
    dms_scores.query('antigen == "XBB.1.5_RBD"'), 'A1', src_col='breadth')

# Logos are drawn before the tables are exported; plot_res_logo converts the
# 'site' column of the frame it receives to str, so keep this order.
plot_res_logo(avg_JN1, '../plots/Extended/DMS_scores_A1_JN1_breadth', 'breadth', 0.5, 24)
plot_res_logo(avg_XBB15, '../plots/Extended/DMS_scores_A1_XBB15_breadth', 'breadth', 0.5, 24)


def _export_breadth_comparison(avg_scores, csv_path):
    """Write a wide table with one row per (site, mutation) and one column per breadth label."""
    wide = avg_scores.pivot(
        index=['site', 'mutation'], columns='breadth', values='mut_escape'
    ).reset_index()
    wide = wide.assign(site_mut=lambda tbl: tbl['site'].astype(str) + tbl['mutation'])
    wide.to_csv(csv_path, index=None)


_export_breadth_comparison(
    avg_JN1, "../data/DMS/antibody/_A1_breadth_DMS_scores_compare_avg_JN1.csv")
_export_breadth_comparison(
    avg_XBB15, "../data/DMS/antibody/_A1_breadth_DMS_scores_compare_avg_XBB15.csv")