In [8]:
import pickle

import altair as alt
import altair_saver

import pandas as pd

import polyclonal

import warnings
warnings.filterwarnings('ignore')

import yaml

from IPython.utils import io

In [9]:
import os

# Work from the repository root so the relative `results/` and `data/` paths
# used below resolve. The notebook's starting directory is captured only the
# first time this cell runs; re-running the cell then lands in the same place
# instead of climbing three more directories each time.
NOTEBOOK_DIR = globals().get('NOTEBOOK_DIR', os.getcwd())
os.chdir(os.path.join(NOTEBOOK_DIR, '../../../'))

In [10]:
def get_escape_lineplot(antibody, cohort, functional_effect_cutoff=-1.38):
    """Render and save the escape line plot for one antibody/serum.

    Loads the pickled polyclonal model for ``antibody``. Merges the site
    numbering map and the observed functional mutational effects into the
    plot data. Saves a PNG of the line plot (heatmap and zoom bar disabled) to
    ``scratch_notebooks/figure_drafts/escape_map_lineplots/{cohort}/{antibody}.png``.

    Parameters
    ----------
    antibody : str or int
        Serum/antibody identifier. It is used as the pickle file stem under
        ``results/antibody_escape/``, as the key into
        ``data/polyclonal_config.yaml``, and (stringified) as the plot title.
    cohort : str
        Name of the output subdirectory. It is created if it does not exist.
    functional_effect_cutoff : float, optional
        Initial value of the "functional effect" slider passed via
        ``addtl_slider_stats``. Defaults to -1.38.
        NOTE(review): the origin of -1.38 is not documented here — confirm.

    Returns
    -------
    The object returned by ``model.mut_escape_plot`` (presumably an Altair
    chart), so callers can also display it inline.
    """
    from pathlib import Path

    # NOTE: pickle.load can execute arbitrary code — only load model files
    # produced by this project's own pipeline.
    with open(f'results/antibody_escape/{antibody}.pickle', "rb") as f:
        model = pickle.load(f)

    # Per-antibody plotting configuration; raises KeyError if `antibody` has
    # no entry in the YAML file.
    with open('data/polyclonal_config.yaml') as f:
        antibody_config = yaml.safe_load(f)[antibody]

    site_map = pd.read_csv("data/site_map.csv").rename(
        columns={"reference_site": "site"}
    )
    muteffects = pd.read_csv(
        "results/muteffects_functional/muteffects_observed.csv"
    ).rename(
        columns={"reference_site": "site", "effect": "functional effect"}
    )[["site", "mutant", "functional effect"]]

    # Work on copies so the overrides below never alter the loaded config.
    plot_kwargs = dict(antibody_config["plot_kwargs"])
    slider_stats = dict(plot_kwargs.get("addtl_slider_stats") or {})
    slider_stats["functional effect"] = functional_effect_cutoff
    plot_kwargs["addtl_slider_stats"] = slider_stats
    plot_kwargs["addtl_slider_stats_hide_not_filter"] = ["functional effect"]
    plot_kwargs["init_floor_at_zero"] = False
    # Merge the fixed options into the same dict instead of passing them as
    # separate keyword arguments. A config that already defines any of them
    # would otherwise raise "got multiple values for keyword argument".
    plot_kwargs.update(
        df_to_merge=[site_map, muteffects],
        show_heatmap=False,
        show_zoombar=False,
        plot_title=str(antibody),
    )

    plot = model.mut_escape_plot(**plot_kwargs)

    # Save into the cohort's directory, creating it on first use.
    out_dir = Path('scratch_notebooks/figure_drafts/escape_map_lineplots') / str(cohort)
    out_dir.mkdir(parents=True, exist_ok=True)
    plot.save(str(out_dir / f'{antibody}.png'), scale_factor=2.0)

    return plot

In [11]:
# Sera to plot, grouped by cohort. Each cohort name becomes the output
# subdirectory; each ID (int or str) is passed unchanged to
# get_escape_lineplot.
sample_dict = dict(
    children=[3944, 2389, 2323, 2388, 3973, 4299, 4584, 2367],
    teenagers=[2350, 2365, 2382, 3866, 2380, 3856, 3857, 3862],
    adults=[
        '33C', '34C', '197C', '199C', '215C',
        '210C', '74C', '68C', '150C', '18C',
    ],
    misc=['AUSAB-13', 2462],
)

for cohort in sample_dict:
    serum_list = sample_dict[cohort]
    for serum in serum_list:
        get_escape_lineplot(serum, cohort)