# Fit polyclonal model
Here we fit [polyclonal](https://jbloomlab.github.io/polyclonal) models to the data.

First, import Python modules:

In [1]:
import pickle

import altair as alt

import pandas as pd

import polyclonal

import yaml

In [2]:
# Lift Altair's default cap on the number of rows allowed in a chart's data.
# The return value is bound to `_` so the cell does not display anything.
_ = alt.data_transformers.disable_max_rows()

In [3]:
# Move the working directory two levels up; later cells open relative paths
# such as "config.yaml" and "results/...".
# NOTE(review): a relative `chdir` is not idempotent -- re-running this cell
# climbs two more directories. Presumably the target is the repository root;
# confirm, and run this cell exactly once per kernel.
import os
os.chdir('../../')

## Read input data

Get the parameterized variables from [papermill](https://papermill.readthedocs.io/):

In [4]:
# papermill parameters cell (tagged as `parameters`)
# Every name is a `None` placeholder here. The "# Parameters" cell that follows
# assigns `prob_escape_csv`, `pickle_file` and `n_threads`; `antibody` is
# instead assigned further down from the "antibody" column of the data.
prob_escape_csv = None
n_threads = None
pickle_file = None
antibody = None

In [5]:
# Parameters
# Values used for this run; they replace the `None` placeholders assigned in
# the preceding parameters cell.
prob_escape_csv = "results/prob_escape/libA_221021_1_1C04_1_prob_escape.csv"
pickle_file = "results/polyclonal_fits/libA_221021_1_1C04_1.pickle"
n_threads = 2


Read the probabilities of escape, and filter for those with sufficient no-antibody counts:

In [6]:
print(f"\nReading probabilities of escape from {prob_escape_csv}")

# Only the literal string "nan" is parsed as missing; pandas' default NA
# strings are disabled, so e.g. empty fields stay as empty strings.
prob_escape = pd.read_csv(prob_escape_csv, keep_default_na=False, na_values="nan")

# `no_antibody_count_threshold` carries no `@` prefix, so `query` resolves it
# as a column of the CSV (a per-row threshold) -- presumably written by the
# upstream pipeline step; verify against the file.
prob_escape = prob_escape.query("`no-antibody_count` >= no_antibody_count_threshold")

# after filtering there must be no missing value anywhere in the table
assert not prob_escape.isnull().any().any()


Reading probabilities of escape from results/prob_escape/libA_221021_1_1C04_1_prob_escape.csv


Read the rest of the configuration and input data:

In [7]:
# load the top-level pipeline configuration
with open("config.yaml") as f:
    config = yaml.safe_load(f)

# the input must describe exactly one antibody; reduce the array to that name
antibody = prob_escape["antibody"].unique()
assert len(antibody) == 1, antibody
antibody = antibody[0]

# reference-site labels, ordered by their sequential-site number
site_numbering_map = pd.read_csv(config["site_numbering_map"])
reference_sites = (
    site_numbering_map
    .sort_values("sequential_site")
    .loc[:, "reference_site"]
    .tolist()
)

# per-antibody `polyclonal` settings; a missing entry is an error
with open(config["polyclonal_config"]) as f:
    polyclonal_config = yaml.safe_load(f)
if not (antibody in polyclonal_config):
    raise ValueError(f"`polyclonal_config` lacks configuration for {antibody=}")
antibody_config = polyclonal_config[antibody]

# report the settings in effect, one `name=value` line each
print(
    f"{antibody=}",
    f"{n_threads=}",
    f"{pickle_file=}",
    f"{antibody_config=}",
    sep="\n",
)

antibody='1C04'
n_threads=2
pickle_file='results/polyclonal_fits/libA_221021_1_1C04_1.pickle'
antibody_config={'min_epitope_activity_to_include': 0.2, 'plot_kwargs': {'addtl_slider_stats': {'times_seen': 3, 'functional effect': -1.38}, 'slider_binding_range_kwargs': {'n_models': {'step': 1}, 'times_seen': {'step': 1, 'min': 1, 'max': 25}}, 'heatmap_max_at_least': 2, 'heatmap_min_at_least': -2}, 'max_epitopes': 1, 'fit_kwargs': {'reg_escape_weight': 0.1, 'reg_spread_weight': 0.25, 'reg_activity_weight': 1.0}}


## Some summary statistics
Note that these statistics are only for the variants that passed upstream filtering in the pipeline.

Number of variants per concentration:

In [13]:
# count distinct barcodes at each antibody concentration
display(
    prob_escape
    .groupby("antibody_concentration")
    .agg(n_variants=("barcode", "nunique"))
)

Unnamed: 0_level_0,n_variants
antibody_concentration,Unnamed: 1_level_1
0.2,26662
0.4,26662
0.8,26662


In [12]:
# Drop every measurement taken at antibody concentration 0.10.
# NOTE(review): the 0.10 cutoff is a hard-coded magic number and this rebinds
# `prob_escape`, so cells that ran before it saw the unfiltered table.
prob_escape = prob_escape.query("antibody_concentration != 0.10")

Plot mean probability of escape across all variants with the indicated number of mutations.
Note that this plot weights each variant the same in the means regardless of how many barcode counts it has.
We plot means for both censored (set to between 0 and 1) and uncensored probabilities of escape.
Also, note it uses a symlog scale for the y-axis.
Mouseover points for values:

In [16]:
max_aa_subs = 4  # variants with this many or more substitutions are pooled

# Mean escape per (concentration, substitution-count bin), one row per
# combination of bin and censoring mode after melting to long form.
mean_prob_escape = (
    prob_escape.assign(
        # count whitespace-separated substitutions, then turn the count into a
        # string label, pooling everything at or above `max_aa_subs`
        n_subs=lambda x: (
            x["aa_substitutions_reference"]
            .str.split()
            .map(len)
            .map(
                lambda n: f">{max_aa_subs - 1}" if n >= max_aa_subs else str(n)
            )
        )
    )
    .groupby(["antibody_concentration", "n_subs"], as_index=False)[
        ["prob_escape", "prob_escape_uncensored"]
    ]
    .mean()
    .rename(
        columns={
            "prob_escape": "censored to [0, 1]",
            "prob_escape_uncensored": "not censored",
        }
    )
    .melt(
        id_vars=["antibody_concentration", "n_subs"],
        var_name="censored",
        value_name="probability escape",
    )
)

# One small line chart per censoring mode, colored by substitution-count bin.
mean_prob_escape_chart = (
    alt.Chart(mean_prob_escape)
    .mark_line(point=True, size=0.5)
    .encode(
        x=alt.X("antibody_concentration"),
        y=alt.Y(
            "probability escape",
            scale=alt.Scale(type="symlog", constant=0.05),
        ),
        column=alt.Column("censored", title=None),
        color=alt.Color("n_subs", title="n substitutions"),
        # float columns get 3-significant-digit formatting in the tooltip
        tooltip=[
            alt.Tooltip(col, format=".3g")
            if mean_prob_escape[col].dtype == float
            else col
            for col in mean_prob_escape.columns
        ],
    )
    .properties(width=200, height=125)
    .configure_axis(grid=False)
)

mean_prob_escape_chart

  for col_name, dtype in df.dtypes.iteritems():


## Fit `polyclonal` model
Fit a single-epitope model. Note that the regularization weights are written out explicitly in the cell below; they are not read from the `antibody_config` loaded above:

In [15]:
# Build a one-epitope model from the escape probabilities, renaming the two
# columns to the names handed to `polyclonal`.
model = polyclonal.Polyclonal(
    data_to_fit=prob_escape.rename(
        columns=dict(
            antibody_concentration="concentration",
            aa_substitutions_reference="aa_substitutions",
        )
    ),
    n_epitopes=1,
    sites=reference_sites,
    alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
)

# Fit with explicitly written-out regularization weights.
# NOTE(review): the escape/spread/activity weights equal the `fit_kwargs`
# printed for this antibody's config above, but they are hard-coded here.
opt_res = model.fit(
    **{
        "logfreq": 200,
        "reg_escape_weight": 0.1,
        "reg_spread_weight": 0.25,
        "reg_activity_weight": 1.0,
        "reg_uniqueness_weight": 0.2,  # regularize epitope similarity
    }
)

# last expression of the cell: show the escape plot
model.mut_escape_plot()

# First fitting site-level model.
# Starting optimization of 507 parameters at Wed Nov 30 12:51:39 2022.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.025213       30655       30655           0           0           0              0               0     0.030412
         178      4.5892      496.92      469.04      20.222           0           0              0               0       7.6658
# Successfully finished at Wed Nov 30 12:51:43 2022.
# Starting optimization of 3254 parameters at Wed Nov 30 12:51:43 2022.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.026712      728.83      533.93      187.23  5.7179e-31           0              0               0       7.6658
         168      4.5928      496.38       454.8      30.455      5.2278           0              0               0       5.9

  for col_name, dtype in df.dtypes.iteritems():


In [18]:
# Display the escape plot of whichever `model` is currently bound.
model.mut_escape_plot()

  for col_name, dtype in df.dtypes.iteritems():


In [17]:
# NOTE(review): this cell repeats the model-fitting cell earlier in the
# notebook, and rebinds `model` / `opt_res`. The two logged runs report
# different losses, so they presumably ran on different states of
# `prob_escape` -- confirm which fit is intended and drop the other.
renamed_for_fit = {
    "antibody_concentration": "concentration",
    "aa_substitutions_reference": "aa_substitutions",
}
model = polyclonal.Polyclonal(
    data_to_fit=prob_escape.rename(columns=renamed_for_fit),
    n_epitopes=1,
    sites=reference_sites,
    alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
)

# fit with explicitly written-out regularization weights
opt_res = model.fit(
    reg_escape_weight=0.1,
    reg_spread_weight=0.25,
    reg_activity_weight=1.0,
    reg_uniqueness_weight=0.2,  # regularize epitope similarity
    logfreq=200,
)

# last expression of the cell: show the escape plot
model.mut_escape_plot()

# First fitting site-level model.
# Starting optimization of 507 parameters at Wed Nov 30 12:38:23 2022.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.026844       39495       39494           0           0           0              0               0      0.51759
          90      2.6103      2414.1      2399.3      11.917           0           0              0               0       2.7982
# Successfully finished at Wed Nov 30 12:38:26 2022.
# Starting optimization of 3254 parameters at Wed Nov 30 12:38:26 2022.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.029281      2931.3      2854.8      73.692  9.2364e-32           0              0               0       2.7982
         162      5.0235        2633      2556.8      63.463      9.3513           0              0               0       3.3

  for col_name, dtype in df.dtypes.iteritems():


In [17]:
def reassign_b_factor(
    input_pdbfile,
    output_pdbfile,
    df,
    metric_col,
    *,
    site_col="site",
    chain_col="chain",
    missing_metric=0,
    model_index=0,
):
    """Write a copy of a PDB file whose B-factors hold a per-site metric.

    Every atom of every residue in the selected model gets its B-factor set to
    the metric value listed in `df` for that residue's chain and residue
    number. Residues without an entry in `df` get `missing_metric`.

    Parameters
    ----------
    input_pdbfile : str
        Path of the PDB file to read.
    output_pdbfile : str
        Path the modified structure is written to.
    df : pandas.DataFrame
        Must contain `metric_col`, `site_col` and `chain_col`, with a single
        metric value per (site, chain) pair. Other columns are ignored.
    metric_col : str
        Column of `df` holding the value to store as B-factor.
    site_col : str
        Column of `df` holding the residue number. Values are looked up with
        the integer residue number from the PDB, so they must compare equal
        to integers (e.g. string sites never match and silently fall back to
        `missing_metric`). Insertion codes are not taken into account.
    chain_col : str
        Column of `df` holding the chain identifier.
    missing_metric : number or dict
        Value for residues absent from `df`: either one value used for all
        chains, or a dict keyed by chain ID.
    model_index : int
        Index of the model in the PDB file whose atoms are modified.

    Raises
    ------
    ValueError
        If `df` lacks a required column, has more than one metric value for a
        site in a chain, or names a chain that is not in the PDB model.
    """
    # subset `df` to needed columns and error check it
    required_cols = [metric_col, site_col, chain_col]
    for col in required_cols:
        if col not in df.columns:
            raise ValueError(f"`df` lacks column {col}")
    df = df[required_cols].drop_duplicates()
    if len(df) != len(df.groupby([site_col, chain_col])):
        raise ValueError("non-unique metric for a site in a chain")

    # Imported here so the function does not depend on a later notebook cell
    # having run: `import Bio` alone does not load the `Bio.PDB` subpackage.
    import warnings

    import Bio.PDB
    import Bio.PDB.PDBExceptions

    # read PDB, catch warnings about discontinuous chains
    with warnings.catch_warnings():
        warnings.simplefilter(
            "ignore", category=Bio.PDB.PDBExceptions.PDBConstructionWarning
        )
        pdb = Bio.PDB.PDBParser().get_structure("_", input_pdbfile)

    # pick the requested model (named `pdb_model` so it is not confused with
    # the notebook-level `model` variable)
    pdb_model = list(pdb.get_models())[model_index]
    chain_ids = [chain.id for chain in pdb_model.get_chains()]

    # make sure all chains in PDB
    missing_chains = set(df[chain_col]) - set(chain_ids)
    if missing_chains:
        raise ValueError(f"`df` has chains not in PDB: {missing_chains}")

    # make missing_metric a dict if it isn't already
    if not isinstance(missing_metric, dict):
        missing_metric = {chain_id: missing_metric for chain_id in chain_ids}

    # loop over all chains and overwrite the B-factors
    for chain in pdb_model.get_chains():
        chain_id = chain.id
        site_to_val = (
            df.query(f"{chain_col} == @chain_id")
            .set_index(site_col)[metric_col]
            .to_dict()
        )
        for residue in chain:
            # second element of the residue id is the residue number
            site = residue.get_id()[1]
            try:
                metric_val = site_to_val[site]
            except KeyError:
                metric_val = missing_metric[chain_id]
            # for disordered residues, get list of them
            try:
                residuelist = residue.disordered_get_list()
            except AttributeError:
                residuelist = [residue]
            for r in residuelist:
                for atom in r:
                    # for disordered atoms, get list of them
                    try:
                        atomlist = atom.disordered_get_list()
                    except AttributeError:
                        atomlist = [atom]
                    for a in atomlist:
                        a.bfactor = metric_val

    # write the whole structure (with the modified model) back out
    io = Bio.PDB.PDBIO()
    io.set_structure(pdb)
    io.save(output_pdbfile)

In [19]:
# Per-site escape summary restricted to sites strictly between 0 and 326, plus
# a constant `chain` column and `sum` = "total positive" + "total negative".
# Built as a single chain so that no column is ever assigned onto a filtered
# slice, which is the pattern that triggers pandas' SettingWithCopyWarning.
# NOTE(review): the 326 cutoff and chain "A" are hard-coded -- presumably they
# match the residues of the PDB chain colored later in the notebook; confirm.
site_summary_df = (
    model.mut_escape_site_summary_df()
    # sites are compared numerically, so convert before filtering
    .assign(site=lambda x: x["site"].astype(float))
    .loc[lambda x: (x["site"] < 326) & (x["site"] > 0)]
    .assign(
        site=lambda x: x["site"].astype(int),
        chain="A",
        sum=lambda x: x["total positive"] + x["total negative"],
    )
)

In [20]:
import warnings
import Bio

# Store the per-site `sum` column as B-factors of the structure and write the
# result to a new PDB file; residues without a value get 0.
reassign_b_factor(
    "scratch_notebooks/221111_model-fitting/4o5n.pdb",
    "scratch_notebooks/221130_summary_figures/libA_1c04_b-factors_sum.pdb",
    site_summary_df,
    "sum",
    site_col="site",
    chain_col="chain",
    missing_metric=0,
    model_index=0,
)