In [20]:
import altair as alt

import pandas as pd

import polyclonal

import warnings
warnings.filterwarnings('ignore')

import Bio

In [3]:
import os

# Move to the project root (two levels above the notebook) so that the relative
# `data/`, `results/` and `scratch_notebooks/` paths used below resolve.
# The kernel's starting directory is recorded once, so re-running this cell is
# idempotent instead of climbing two more levels on every execution.
try:
    _KERNEL_START_DIR
except NameError:
    _KERNEL_START_DIR = os.getcwd()
os.chdir(os.path.join(_KERNEL_START_DIR, '..', '..'))

In [76]:
def reassign_b_factor(
    input_pdbfile,
    output_pdbfile,
    df,
    metric_col,
    *,
    site_col="site",
    chain_col="chain",
    missing_metric=0,
    model_index=0,
):
    """Write a copy of a PDB file with B factors replaced by a per-site metric.

    Parameters
    ----------
    input_pdbfile : str
        Path of the PDB file to read.
    output_pdbfile : str
        Path the recolored PDB is written to.
    df : pandas.DataFrame
        Must contain `metric_col`, `site_col` and `chain_col`. The caller's
        frame is not modified.
    metric_col : str
        Column whose value becomes the B factor of every atom of the residue.
    site_col : str
        Column of residue numbers: any integer dtype, or strings that encode
        (possibly negative) integers.
    chain_col : str
        Column of PDB chain identifiers.
    missing_metric : number or dict
        B factor for residues that have no row in `df`. A dict is looked up
        by chain id.
    model_index : int
        Index of the model inside the PDB structure to recolor.

    Raises
    ------
    ValueError
        If `df` lacks a required column, if `site_col` holds values that are
        neither integers nor integer-encoding strings, or if `df` names a
        chain that is absent from the PDB model.
    KeyError
        If `missing_metric` is a dict with no entry for a chain that has a
        residue absent from `df`.

    Notes
    -----
    Exact duplicate rows are dropped. If a (site, chain) pair still has more
    than one metric value, no error is raised and the value from the last
    such row is used.
    """
    # subset `df` to needed columns and error check it; `.copy()` gives a frame
    # that can be assigned to without touching (or warning about) the caller's
    cols = [metric_col, site_col, chain_col]
    for col in cols:
        if col not in df.columns:
            raise ValueError(f"`df` lacks column {col}")
    df = df[cols].drop_duplicates().copy()

    # accept every integer dtype (int32, int64, nullable Int64, ...), not only
    # the platform default that `dtype != int` would compare against
    if not pd.api.types.is_integer_dtype(df[site_col]):
        # if we have string type, convert to int
        if df[site_col].map(type).eq(str).all():
            # residue numbers may be negative, so allow a leading minus sign
            encodes_int = df[site_col].str.fullmatch(r"-?\d+")
            if encodes_int.all():
                df[site_col] = df[site_col].astype(int)
            else:
                # sites like 214a end up here; before supporting them, check how
                # they relate to `residue.get_id()[1]` used below
                raise ValueError(
                    f"`site_col` has non-integer entries:\n{df[site_col][~encodes_int]}"
                )
        else:
            raise ValueError(f"`site_col` is neither str nor int:\n{df[site_col]}")

    # `import Bio` alone does not load the `Bio.PDB` subpackage, so import it
    # explicitly; done here so that invalid `df` input fails before Biopython
    # or the input file is needed
    import Bio.PDB
    from Bio.PDB.PDBExceptions import PDBConstructionWarning

    # read PDB, catch warnings about discontinuous chains
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=PDBConstructionWarning)
        pdb = Bio.PDB.PDBParser().get_structure("_", input_pdbfile)

    # get the model out of the PDB
    model = list(pdb.get_models())[model_index]

    # make sure all chains in PDB
    missing_chains = set(df[chain_col]) - {chain.id for chain in model.get_chains()}
    if missing_chains:
        raise ValueError(f"`df` has chains not in PDB: {missing_chains}")

    # make missing_metric a dict if it isn't already
    if not isinstance(missing_metric, dict):
        missing_metric = {chain.id: missing_metric for chain in model.get_chains()}

    # loop over all chains and do coloring
    for chain in model.get_chains():
        chain_id = chain.id
        # boolean mask rather than `query` so any column name works
        site_to_val = (
            df[df[chain_col] == chain_id].set_index(site_col)[metric_col].to_dict()
        )
        for residue in chain:
            site = residue.get_id()[1]
            try:
                metric_val = site_to_val[site]
            except KeyError:
                metric_val = missing_metric[chain_id]
            # for disordered residues, get list of them
            try:
                residuelist = residue.disordered_get_list()
            except AttributeError:
                residuelist = [residue]
            for r in residuelist:
                for atom in r:
                    # for disordered atoms, get list of them
                    try:
                        atomlist = atom.disordered_get_list()
                    except AttributeError:
                        atomlist = [atom]
                    for a in atomlist:
                        a.bfactor = metric_val

    # write PDB
    io = Bio.PDB.PDBIO()
    io.set_structure(pdb)
    io.save(output_pdbfile)

In [77]:
def get_b_factor_sum(
    model,
    serum,
    *,
    input_pdbfile='data/PDBs/4o5n.pdb',
    output_dir='scratch_notebooks/221227_model_fitting',
    min_times_seen=3,
    max_site=326,
    chain='A',
):
    """Write one PDB per epitope with B factors set to the summed site escape.

    The metric for each site is `total positive + total negative` from
    `model.mut_escape_site_summary_df(min_times_seen=min_times_seen)`.

    Parameters
    ----------
    model
        Fitted model providing `mut_escape_site_summary_df` and `epitopes`.
    serum : str
        Label used as the prefix of each output file name.
    input_pdbfile : str
        PDB structure to recolor.
    output_dir : str
        Directory for the output PDBs; created if missing.
    min_times_seen : int
        Passed through to `mut_escape_site_summary_df`.
    max_site : int
        Only sites with number strictly below this are kept.
        NOTE(review): the default 326 is carried over unchanged; confirm what
        boundary in the structure it corresponds to.
    chain : str
        PDB chain the retained sites are assigned to.

    Returns
    -------
    pandas.DataFrame
        Columns `epitope` and `PDB file`, one row per epitope.
    """
    # build the per-site table in one pass (the model is queried only once)
    site_summary = model.mut_escape_site_summary_df(min_times_seen=min_times_seen)
    site_summary = site_summary.assign(
        site=site_summary['site'].astype(int),
        sum=site_summary['total positive'] + site_summary['total negative'],
        chain=chain,
    )
    site_summary = site_summary[site_summary['site'] < max_site]

    result_files = []

    for epitope in model.epitopes:
        filename = f'{serum}_sum_epitope_{epitope}.pdb'
        output_pdbfile = f'{output_dir}/{filename}' if output_dir else filename
        # spaces are replaced across the whole path, directory included
        output_pdbfile = output_pdbfile.replace(" ", "_")
        if os.path.dirname(output_pdbfile):
            os.makedirs(os.path.dirname(output_pdbfile), exist_ok=True)
        result_files.append((epitope, output_pdbfile))
        reassign_b_factor(
            input_pdbfile=input_pdbfile,
            output_pdbfile=output_pdbfile,
            df=site_summary[site_summary['epitope'] == epitope],
            metric_col='sum',
            site_col='site',
            chain_col='chain',
            missing_metric=0,
            model_index=0,
        )

    return pd.DataFrame(result_files, columns=["epitope", "PDB file"])

In [78]:
# Write one recolored PDB per epitope for serum AUSAB13 and keep the table of
# output paths. Relies on `model` already having been fit in the kernel.
df = get_b_factor_sum(model, 'AUSAB13')
# df.to_csv('scratch_notebooks/221227_model_fitting/test_df_2.csv')

In [80]:
import matplotlib.colors

# Report each epitope's color both as hex and as an RGB triple rounded to 3 decimals.
for epitope, hex_color in model.epitope_colors.items():
    rgb_channels = matplotlib.colors.to_rgb(hex_color)
    rgb = [round(channel, 3) for channel in rgb_channels]
    print(f"{epitope}: hex color is {hex_color}; RGB tuple is {rgb}")

1: hex color is #0072B2; RGB tuple is [0.0, 0.447, 0.698]
2: hex color is #CC79A7; RGB tuple is [0.8, 0.475, 0.655]


In [85]:
# RGB components of hex color #E69F00, rounded to 3 decimals.
rgb = [round(val, 3) for val in matplotlib.colors.to_rgb('#E69F00')]
rgb

[0.902, 0.624, 0.0]

In [8]:
# Inter-residue distances for chains A and B of the 4o5n structure; passed as
# `spatial_distances` when the two-epitope model is constructed.
spatial_distances = polyclonal.pdb_utils.inter_residue_distances(
    "data/PDBs/4o5n.pdb",
    target_chains=["A", "B"],
)

spatial_distances

Unnamed: 0,site_1,site_2,distance,chain_1,chain_2
0,9,10,1.328212,A,A
1,9,11,3.469929,B,B
2,9,12,6.336130,B,B
3,9,13,9.189821,B,B
4,9,14,8.930696,B,A
...,...,...,...,...,...
260276,497,499,15.936294,B,B
260277,497,500,16.632641,B,B
260278,498,499,23.859705,B,B
260279,498,500,13.285421,B,B


In [9]:
# Load the AUSAB-05 escape probabilities; with `keep_default_na=False` only the
# literal string "nan" is parsed as missing.
prob_escape = (
    pd.read_csv(
        "results/prob_escape/libA_221223_1_AUSAB-05_1_prob_escape.csv",
        keep_default_na=False,
        na_values="nan",
    )
    # keep only rows with sufficient no-antibody counts
    .query("`no-antibody_count` >= no_antibody_count_threshold")
)
assert prob_escape.notnull().all().all()

# Restrict to the two antibody concentrations used for fitting.
prob_escape_filtered_05 = prob_escape.loc[
    prob_escape['antibody_concentration'].isin([0.0074, 0.0111])
]

In [10]:
# Single-epitope model for AUSAB-05, fit to the two-concentration subset
# `prob_escape_filtered_05`. Columns are renamed before being handed to
# Polyclonal (presumably to the names it expects; confirm against its docs).
# NOTE: `model` is reassigned by the later two-epitope fit, so cells that use
# `model` depend on which fit cell was run last.
model = polyclonal.Polyclonal(
    n_epitopes=1,
    data_to_fit=prob_escape_filtered_05.rename(
        columns={
            "antibody_concentration": "concentration",
            "aa_substitutions_reference": "aa_substitutions",
        }
    ),
    alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
)

# fit model
opt_res = model.fit(
    logfreq=200,
    reg_escape_weight=0.1,
)

# display results
display(model.activity_wt_barplot())
display(model.mut_escape_plot(addtl_slider_stats={"times_seen": 3}, init_floor_at_zero=False))

# First fitting site-level model.
# Starting optimization of 503 parameters at Thu Jan  5 14:47:35 2023.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0     0.01971       44868       44865           0           0           0              0               0       3.6049
          42      1.2909      612.49      606.83      2.1069           0           0              0               0       3.5517
# Successfully finished at Thu Jan  5 14:47:36 2023.
# Starting optimization of 3244 parameters at Thu Jan  5 14:47:36 2023.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.021703      774.35      748.62      22.181  8.0512e-33           0              0               0       3.5517
          64      1.6005      746.34      739.99      2.1956    0.088411           0              0               0       4.0

In [21]:
# Write B-factor-recolored PDBs for the current `model` under the AUSAB05 prefix.
# NOTE(review): the saved output of this cell is a per-site table, not the
# epitope / "PDB file" table that `get_b_factor_sum` returns, so it appears to
# come from a different version of the function; re-run to refresh it.
get_b_factor_sum(model, 'AUSAB05')

Unnamed: 0,epitope,site,wildtype,mean,total positive,max,min,total negative,n mutations,sum,chain
0,1,-2,D,-0.015089,0.000000,-0.015089,-0.015089,-0.015089,1,-0.015089,A
1,1,1,Q,0.003310,0.012187,0.012187,-0.005568,-0.005568,2,0.006619,A
2,1,2,K,0.006951,0.006951,0.006951,0.006951,0.000000,1,0.006951,A
3,1,3,I,-0.003990,0.150278,0.039109,-0.077150,-0.222093,18,-0.071815,A
4,1,4,P,0.003146,0.219551,0.080116,-0.029464,-0.159781,19,0.059770,A
...,...,...,...,...,...,...,...,...,...,...,...
823,1,533,G,-0.009041,0.000000,-0.009041,-0.009041,-0.009041,1,-0.009041,B
824,1,536,M,0.009438,0.009438,0.009438,0.009438,0.000000,1,0.009438,B
825,1,537,W,-0.027538,0.000000,-0.009128,-0.045949,-0.055077,2,-0.055077,B
826,1,538,A,-0.001626,0.022949,0.022949,-0.020690,-0.029454,4,-0.006505,B


In [22]:
# Load the AUSAB-13 escape probabilities; with `keep_default_na=False` only the
# literal string "nan" is parsed as missing.
prob_escape_13 = (
    pd.read_csv(
        "results/prob_escape/libA_221027_1_AUSAB-13_1_prob_escape.csv",
        keep_default_na=False,
        na_values="nan",
    )
    # keep only rows with sufficient no-antibody counts
    .query("`no-antibody_count` >= no_antibody_count_threshold")
)
assert prob_escape_13.notnull().all().all()

In [24]:
# Two-epitope model for AUSAB-13. Sites come from the `reference_site` column
# of the site map, and `spatial_distances` must already exist in the kernel.
# NOTE: this reassigns `model`, replacing the single-epitope AUSAB-05 fit.
reference_sites = pd.read_csv("data/site_map.csv")["reference_site"].tolist()
model = polyclonal.Polyclonal(
    n_epitopes=2,
    data_to_fit=prob_escape_13.rename(
        columns={
            "antibody_concentration": "concentration",
            "aa_substitutions_reference": "aa_substitutions",
        }
    ),
    alphabet=polyclonal.AAS_WITHSTOP_WITHGAP,
    sites=reference_sites,
    spatial_distances=spatial_distances,
)

# fit model
# (of the uniqueness and spatial regularizers, only the "2" variants have a
# nonzero weight here)
opt_res = model.fit(
    logfreq=200,
    reg_escape_weight=0.1,
    reg_uniqueness_weight=0,
    reg_uniqueness2_weight=1,
    reg_spatial_weight=0.0,
    reg_spatial2_weight=0.0005,
)

# display results
display(model.activity_wt_barplot())
display(model.mut_escape_plot(addtl_slider_stats={"times_seen": 3}, init_floor_at_zero=False))

# First fitting site-level model.
# Starting optimization of 1006 parameters at Thu Jan  5 15:06:39 2023.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0    0.087781  1.6416e+05  1.6415e+05           0           0           0              0               0       11.001
          82      9.5339      848.47      837.86      1.7832           0      3.5513              0         0.15505        5.127
# Successfully finished at Thu Jan  5 15:06:48 2023.
# Starting optimization of 6474 parameters at Thu Jan  5 15:06:48 2023.
        step    time_sec        loss    fit_loss  reg_escape  reg_spread reg_spatial reg_uniqueness reg_uniqueness2 reg_activity
           0     0.12317      1093.1      1006.5      19.327  2.2396e-32      3.5513              0          58.611        5.127
         162      18.598      949.25      920.66      14.826        1.79      6.2682              0         0.58849       5.

In [29]:
# Write B-factor-recolored PDBs for the two-epitope AUSAB-13 model.
# NOTE(review): the saved output is a ValueError ("non-unique metric for a site
# in a chain") from a uniqueness check that is not active in the current
# `reassign_b_factor`; re-run so the notebook does not ship a failing cell.
get_b_factor_sum(model, 'AUSAB13')

ValueError: non-unique metric for a site in a chain

In [13]:
# Write a PDB recolored by the "mean" site metric on chains A and B.
# NOTE(review): the output file is named for AUSAB-05, but in top-to-bottom
# order `model` is the AUSAB-13 fit at this point. The execution counts suggest
# this ran right after the AUSAB-05 fit; move the cell or pass the intended
# model explicitly.
model.mut_escape_pdb_b_factor(
    input_pdbfile="data/PDBs/4o5n.pdb",
    chains=["A", "B"],
    metric="mean",
    outfile="scratch_notebooks/221227_model_fitting/AUSAB-05_mean_epitope1.pdb"
)

Unnamed: 0,epitope,PDB file
0,1,scratch_notebooks/221227_model_fitting/AUSAB-0...


In [12]:
# Site-level escape summary for whichever `model` is current in the kernel.
# NOTE: this rebinds the global `df`, which an earlier cell also assigns.
df = model.mut_escape_site_summary_df()
df

Unnamed: 0,epitope,site,wildtype,mean,total positive,max,min,total negative,n mutations
0,1,-9,A,-0.003663,0.000000,-0.003663,-0.003663,-0.003663,1
1,1,-4,D,-0.002608,0.000000,-0.002608,-0.002608,-0.002608,1
2,1,-3,A,-0.002419,0.000000,-0.000600,-0.004238,-0.004839,2
3,1,-2,D,-0.004222,0.002369,0.002369,-0.015089,-0.023477,5
4,1,-1,T,-0.001964,0.000000,-0.001329,-0.002599,-0.003929,2
...,...,...,...,...,...,...,...,...,...
488,1,537,W,-0.027538,0.000000,-0.009128,-0.045949,-0.055077,2
489,1,538,A,-0.001626,0.022949,0.022949,-0.020690,-0.029454,4
490,1,539,C,-0.002923,0.000000,-0.002923,-0.002923,-0.002923,1
491,1,540,Q,-0.006883,0.000000,-0.000407,-0.013782,-0.020649,3


In [None]:
# NOTE(review): `poly_abs` is not defined in any visible cell and this cell has
# never been executed; it will raise NameError under Restart & Run All. The
# 6M0J / chain E / RBD arguments also differ from the 4o5n structure used
# everywhere else here, so this looks like a pasted example. Remove or adapt.
poly_abs.mut_escape_pdb_b_factor(
    input_pdbfile="6M0J.pdb",
    chains="E",
    metric="mean",
    outfile="RBD_{metric}_{epitope}.pdb",
)