## MBP Protein NMR Example

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
import maxent

# Notebook-wide plotting look: seaborn "paper" context on a white grid, with
# tick marks drawn on the bottom/left axes in dark gray.
sns.set_context("paper")
tick_style = {
    "xtick.bottom": True,
    "ytick.left": True,
    "xtick.color": "#333333",
    "ytick.color": "#333333",
}
sns.set_style("whitegrid", rc=tick_style)
# Optional: switch the text font to serif as well.
# plt.rcParams["font.family"] = "serif"
plt.rcParams.update({"mathtext.fontset": "dejavuserif"})
# Five-color palette available to the plotting cells.
colors = [
    "#1b9e77",
    "#d95f02",
    "#7570b3",
    "#e7298a",
    "#66a61e",
]
import pynmrstar
from functools import partialmethod
from tqdm import tqdm

# Silence every tqdm progress bar created after this point by forcing
# `disable=True` into the constructor (keeps the notebook output clean).
# NOTE(review): re-running this cell wraps the already-patched __init__ again.
tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

In [None]:
# Load reference data from the BMRB: fetch entry 20062 and, for every
# "Atom_chem_shift" loop, pull out the columns needed below.
bmrb = pynmrstar.Entry.from_database(20062, convert_data_types=True)
cs_result_sets = [
    shift_loop.get_tag(
        ["Comp_index_ID", "Comp_ID", "Atom_ID", "Atom_type", "Val", "Val_err"]
    )
    for shift_loop in bmrb.get_loops_by_category("Atom_chem_shift")
]

# Only the first chemical-shift loop is used as the experimental reference.
ref_data = pd.DataFrame(
    cs_result_sets[0],
    columns=["id", "res", "atom", "type", "shift", "error"],
)

# Rows whose atom name is "H"; their residue ids select the matching
# simulated shifts later on.
amide_h_rows = ref_data[ref_data.atom == "H"]
ref_resids = amide_h_rows.id.values
amide_h_rows.head(25)

In [None]:
# Set to True only when the raw per-frame shift table (./cs.csv) is available;
# otherwise the pre-extracted arrays in mbp_files/mbp_cs.npz are used.
HAVE_MD_FILE = False

# Experimental shifts for atoms named "H", as a float array.
ref_h_rows = ref_data[ref_data.atom == "H"]
# cut GLU because proton type mismatch (drops the first "H" row)
ref_hdata = ref_h_rows["shift"].values[1:].astype(float)
resnames = ref_h_rows.res[1:]

if HAVE_MD_FILE:
    md_shifts = pd.read_csv("./cs.csv")
    # only need weights, so we extract only shifts that will be biased
    hdata_df = md_shifts[md_shifts.names == "HN"]
    hdata_df = hdata_df[hdata_df["resids"].isin(ref_resids)]
    # One row per frame, one column per retained shift.
    n_frames = len(md_shifts.frame.unique())
    hdata_c = hdata_df.confident.values.reshape(n_frames, -1)
    hdata = hdata_df.peaks.values.reshape(n_frames, -1)
    # The simulated columns must line up one-to-one with the reference shifts.
    assert hdata.shape[-1] == ref_hdata.shape[0]
    np.savez("mbp_files/mbp_cs.npz", hdata=hdata, hdata_c=hdata_c)

# Always read back from the archive so both code paths yield the same arrays.
# The context manager closes the underlying file once the arrays are loaded
# (np.load on an .npz otherwise leaves the file handle open).
with np.load("mbp_files/mbp_cs.npz") as data:
    hdata, hdata_c = data["hdata"], data["hdata_c"]

In [None]:
# Frame-averaged simulated shifts vs. experiment, before the unconfident
# peaks are filled in. Labels, legend and title make the figure readable on
# its own and distinguish it from the post-fill plot further down.
plt.plot(np.mean(hdata, axis=0), "o-", label="Prior")
plt.plot(ref_hdata, "o-", label="Experiment")
plt.xlabel("Amide proton index")
plt.ylabel("Chemical Shift [ppm]")
plt.title("Before filling unconfident peaks")
plt.legend()
plt.show()

In [None]:
# Replace every unconfident peak with the mean of the confident peaks in the
# same column. hdata is modified in place.
# NOTE(review): a column with no confident entries divides 0 by 0 and would be
# filled with NaN -- confirm this cannot happen for the shipped data.
confident_sums = np.sum(hdata * hdata_c, axis=0)
confident_counts = np.sum(hdata_c, axis=0)
hdata_m = confident_sums / confident_counts

total_fill = 0
for col, col_mean in enumerate(hdata_m):
    unconfident = ~hdata_c[:, col]
    hdata[unconfident, col] = col_mean
    total_fill += np.sum(unconfident)
print("Filled", total_fill)

In [None]:
# Same comparison as before, now that unconfident peaks have been replaced by
# their column means. Labels, legend and title make the figure readable on its
# own and distinguish it from the pre-fill plot.
plt.plot(np.mean(hdata, axis=0), "o-", label="Prior")
plt.plot(ref_hdata, "o-", label="Experiment")
plt.xlabel("Amide proton index")
plt.ylabel("Chemical Shift [ppm]")
plt.title("After filling unconfident peaks")
plt.legend()
plt.show()

In [None]:
# Build one restraint per shift in the first half of the reference set; the
# second half is deliberately left unrestrained.
do_restrain = range(len(ref_hdata) // 2)
restraints = [
    maxent.Restraint(
        # Bind the index through a default argument so each lambda keeps its
        # own column rather than the final value of the iteration variable.
        lambda h, idx=idx: h[idx],
        ref_hdata[idx],
        prior=maxent.Laplace(0.05),
    )
    for idx in do_restrain
]

In [None]:
# Fit the maximum-entropy model to the per-frame shifts using the restraints
# built above.
optimizer = tf.keras.optimizers.Adam(0.1)
loss_name = "mean_squared_error"

model = maxent.MaxentModel(restraints)
model.compile(optimizer, loss_name)
history = model.fit(hdata, epochs=500, verbose=0)

In [None]:
# Training loss per epoch, followed by the final value.
loss_curve = history.history["loss"]
plt.plot(loss_curve)
print(loss_curve[-1])

In [None]:
# Mean absolute deviation between the reweighted (weight-summed over frames)
# shifts and the experimental reference, taken over all columns.
reweighted_shifts = np.sum(hdata * model.traj_weights[..., np.newaxis], axis=0)
np.mean(np.abs(reweighted_shifts - ref_hdata))

In [None]:
# Display the fitted model's `lambdas` attribute.
# NOTE(review): presumably one multiplier per restraint -- confirm against the
# maxent documentation.
model.lambdas

In [None]:
# Plot the weight the fitted model assigns to each row (frame) of hdata.
plt.plot(model.traj_weights)

In [None]:
# Final figure: reweighted (posterior) vs. unweighted (prior) simulated shifts
# against experiment, with residue one-letter codes on the x axis.
plt.figure(figsize=(3, 2), dpi=300)

# Three-letter -> one-letter residue codes for the tick labels.
seq_dict = {
    "CYS": "C",
    "ASP": "D",
    "SER": "S",
    "GLN": "Q",
    "LYS": "K",
    "ILE": "I",
    "PRO": "P",
    "THR": "T",
    "PHE": "F",
    "ASN": "N",
    "GLY": "G",
    "HIS": "H",
    "LEU": "L",
    "ARG": "R",
    "TRP": "W",
    "ALA": "A",
    "VAL": "V",
    "GLU": "E",
    "TYR": "Y",
    "MET": "M",
}

n_shifts = len(ref_hdata)
# The first half of the shifts was restrained (biased); the rest was not.
n_biased = n_shifts // 2
# Weighted sum over frames using the fitted per-frame weights.
posterior_shifts = np.sum(hdata * model.traj_weights[..., np.newaxis], axis=0)

plt.plot(posterior_shifts, "o-", label="Posterior")
plt.plot(np.mean(hdata, axis=0), "o-", label="Prior")
plt.plot(ref_hdata, "*", label="Experiment")
# Dashed divider between the biased and unbiased halves.
plt.axvline(x=n_biased - 0.5, color="gray", linestyle="--")
plt.xticks(range(n_shifts), [seq_dict[r] for r in resnames])
# The legend sits outside the axes, to the right of the plot area.
plt.legend(loc="center left", bbox_to_anchor=(1.0, 0.8))
plt.text(n_shifts // 5, 8.55, "Biased")
plt.text(n_biased, 8.55, "Unbiased")
plt.xlabel("Sequence")
plt.ylabel("Chemical Shift [ppm]")
# bbox_inches="tight" expands the saved canvas to include the outside legend,
# which would otherwise be clipped at the figure edge in the PDF.
plt.savefig("protein.pdf", bbox_inches="tight")

In [None]:
# Indices of the three largest trajectory weights, in ascending weight order.
top_three = np.argsort(model.traj_weights)[-3:]
print("most favored clusters", top_three)