In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from patch_gnn.data import load_ghesquire
import pandas as pd
from pyprojroot import here
import pickle as pkl
from patch_gnn.splitting import train_test_split
from jax import random
from patch_gnn.seqops import one_hot
from patch_gnn.unirep import unirep_reps
from patch_gnn.graph import graph_tensors
from patch_gnn.models import MPNN, DeepMPNN
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import explained_variance_score as evs
import matplotlib.pyplot as plt 
from sklearn.metrics import mean_squared_error as mse
import pickle as pkl
from patch_gnn.graph import met_position


In [None]:
# Load the dataset via `load_ghesquire()`; it is filtered with DataFrame
# `.query` calls further down.
data = load_ghesquire()

# Pickled graphs live under the project root (resolved by pyprojroot's `here()`).
graph_pickle_path = here() / "data/ghesquire_2011/graphs.pkl"

# NOTE(security): unpickling can execute arbitrary code -- only load files
# that were produced by this project.
with open(graph_pickle_path, "rb") as f:
    graphs = pkl.load(f)

# JAX PRNG key (seed 490); passed to `train_test_split` below.
key = random.PRNGKey(490)

In [None]:
# Keep only rows whose `accession-sequence` has an entry in `graphs` and whose
# `ox_fwd_logit` is below 2.0, then add a `met_position` column computed by
# the `met_position` function.
# NOTE(review): `join_apply` is not a core pandas method (it looks like
# pyjanitor's); presumably it is registered by an import inside `patch_gnn`
# -- confirm, since this notebook does not import janitor itself.
filtered = (
    data
    .query("`accession-sequence` in @graphs.keys()")
    .query("ox_fwd_logit < 2.0")
    .join_apply(met_position, "met_position")
)

In [None]:
# Split the filtered rows into train and test frames using the PRNG key
# created above.
train_df, test_df = train_test_split(key, filtered)

In [None]:
# Show the number of rows in each split.
len(train_df), len(test_df)

In [None]:
# Three alternative featurisations of the same train/test rows.

# One-hot features.
# NOTE(review): the meaning of the second argument (50) is not shown here
# (presumably a sequence/padding length) -- confirm against `one_hot`.
train_oh = one_hot(train_df, 50)
test_oh = one_hot(test_df, 50)

# UniRep representations.
train_unirep = unirep_reps(train_df)
test_unirep = unirep_reps(test_df)

# Graph tensors built from the pickled `graphs` loaded earlier.
train_graph = graph_tensors(train_df, graphs)
test_graph = graph_tensors(test_df, graphs)

In [None]:
# Load the pickled SASA tables; `sasa_dfs` is indexed by accession below.
# NOTE(security): unpickling can execute arbitrary code -- only load files
# that were produced by this project.
with open(here() / "data/ghesquire_2011/sasa.pkl", "rb") as f:
    sasa_dfs = pkl.load(f)

In [None]:
# Preview one SASA table (first rows only, rather than dumping the whole frame).
sasa_dfs["O15305"].head()

In [None]:
def linear_model_data(df, sasa_dfs):
    """Collect the SASA features of each site for the linear baseline model.

    For every row of ``df``, the table stored in ``sasa_dfs`` under that row's
    accession is filtered to the entry whose ``ResidNr`` equals the row's
    ``met_position``, and the ``SASA/A^2`` and ``N(overl)`` columns are kept.

    :param df: DataFrame with ``accession`` and ``met_position`` columns.
    :param sasa_dfs: Mapping from accession to a DataFrame that can be queried
        on ``ResidNr`` and has ``SASA/A^2`` and ``N(overl)`` columns.
    :returns: DataFrame with columns ``SASA/A^2`` and ``N(overl)`` holding
        exactly one row per row of ``df``, in the same order. The index is
        inherited from the SASA tables, not from ``df``. An empty ``df``
        yields an empty frame with the two feature columns.
    :raises KeyError: if an accession has no entry in ``sasa_dfs``.
    :raises ValueError: if a site does not match exactly one SASA row; without
        this check the features would silently fall out of step with the
        target values taken from ``df``.
    """
    feature_columns = ["SASA/A^2", "N(overl)"]
    linear = []
    for acc, pos in zip(df["accession"], df["met_position"]):
        feats = sasa_dfs[acc].query("ResidNr == @pos")[feature_columns]
        if len(feats) != 1:
            raise ValueError(
                f"Expected exactly one SASA row for accession {acc!r} at "
                f"position {pos!r}, found {len(feats)}."
            )
        linear.append(feats)

    # pd.concat refuses an empty list, so handle the no-rows case explicitly.
    if not linear:
        return pd.DataFrame(columns=feature_columns)

    return pd.concat(linear)

# SASA-based feature tables for the linear baseline, one per split.
train_linear = linear_model_data(train_df, sasa_dfs)
test_linear = linear_model_data(test_df, sasa_dfs)

In [None]:
# Regression target for every model: the `ox_fwd_logit` column as an array.
train_target = train_df['ox_fwd_logit'].values
test_target = test_df['ox_fwd_logit'].values

In [None]:
# Total number of rows that survived filtering (before the split).
len(filtered)

In [None]:
# Number of training steps passed to every MPNN / DeepMPNN constructed in
# this notebook.
num_training_steps = 5000

# NOTE(review): this registry is not referenced again in the visible part of
# the notebook. The models that are actually fitted are constructed one by one
# in the cells below, and the random forests there use different settings
# (oob_score=True with the default n_estimators rather than n_estimators=300).
# Confirm whether this dict is still needed.
models = {
    "mpnn": MPNN(
        node_feature_shape=(20, 65),
        num_adjacency=1,
        num_training_steps=num_training_steps
    ),
    "deep_mpnn": DeepMPNN(
        node_feature_shape=(20, 65),
        num_adjacency=1,
        num_training_steps=num_training_steps
    ),
    "rf_oh": RandomForestRegressor(n_estimators=300),
    "rf_unirep": RandomForestRegressor(n_estimators=300),
}

In [None]:
from sklearn.linear_model import LinearRegression

In [None]:
# Linear baseline fitted on the two SASA-derived feature columns.
model_linear = LinearRegression()
model_linear.fit(train_linear, train_target)

In [None]:
# MPNN fitted on the graph tensors; same constructor arguments as the "mpnn"
# entry of `models`.
# NOTE(review): node_feature_shape=(20, 65) presumably means 20 nodes with 65
# features each -- confirm against `graph_tensors`.
model_mpnn = MPNN(
    node_feature_shape=(20, 65),
    num_adjacency=1,
    num_training_steps=num_training_steps
)
model_mpnn.fit(train_graph, train_target)

In [None]:
# DeepMPNN fitted on the same graph tensors and targets, with the same
# constructor arguments as the MPNN above.
model_deepmpnn = DeepMPNN(
    node_feature_shape=(20, 65),
    num_adjacency=1,
    num_training_steps=num_training_steps
)
model_deepmpnn.fit(train_graph, train_target)

In [None]:
# Random-forest baselines on the one-hot and UniRep features.
# `random_state` is fixed (same seed as the JAX PRNG key) so that a
# Restart & Run All reproduces the same forests and therefore the same plots.
model_rfoh = RandomForestRegressor(oob_score=True, n_jobs=-1, random_state=490)
model_rfoh.fit(train_oh, train_target)

model_rf_unirep = RandomForestRegressor(oob_score=True, n_jobs=-1, random_state=490)
model_rf_unirep.fit(train_unirep, train_target)

In [None]:
def plot_y_eq_x(ax):
    """Draw the identity line (y = x) across the current extent of ``ax``.

    The line starts at the smaller of the two lower axis limits and ends at the
    larger of the two upper axis limits, as reported by ``ax`` at call time.

    :param ax: Axes-like object providing ``get_xlim``, ``get_ylim`` and
        ``plot``.
    """
    (x_low, x_high), (y_low, y_high) = ax.get_xlim(), ax.get_ylim()

    diag_start = min(x_low, y_low)
    diag_end = max(x_high, y_high)

    endpoints = [diag_start, diag_end]
    ax.plot(endpoints, list(endpoints))

def plot_performance(
    model,
    trainX, trainY, testX, testY,
    model_name: str,
    ev_func,
    checkpoint: int = None,
):
    """Scatter actual vs. predicted values for the training and test sets.

    Creates a 1x2 figure with shared axes: training data in the left panel,
    test data in the right panel. Each panel's title reports
    ``ev_func(actual, predicted)`` to three decimals, and each panel gets a
    y = x reference line.

    :param model: Fitted model exposing ``predict``.
    :param trainX: Training inputs passed to ``model.predict``.
    :param trainY: Training targets (x-axis of the left panel).
    :param testX: Test inputs passed to ``model.predict``.
    :param testY: Test targets (x-axis of the right panel).
    :param model_name: Label inserted into both panel titles.
    :param ev_func: Metric called as ``ev_func(actual, predicted)``.
    :param checkpoint: Forwarded as ``predict(..., checkpoint=checkpoint)``
        when ``model`` is an ``MPNN`` instance; ignored for any other model.
    """
    _, axes = plt.subplots(
        figsize=(10, 5), nrows=1, ncols=2, sharex=True, sharey=True,
    )

    # Only MPNN instances are given the `checkpoint` argument.
    # NOTE(review): a DeepMPNN receives the checkpoint only if DeepMPNN
    # subclasses MPNN -- confirm, otherwise its checkpoint plots silently use
    # the default predict() behaviour.
    predict_kwargs = {"checkpoint": checkpoint} if isinstance(model, MPNN) else {}

    panels = (
        ("Training", trainX, trainY),
        ("Testing", testX, testY),
    )
    for ax, (split_name, inputs, actual) in zip(axes, panels):
        predicted = model.predict(inputs, **predict_kwargs).squeeze()
        ax.scatter(actual, predicted)
        ax.set_title(
            f"Model: {model_name}, {split_name} Perf: {ev_func(actual, predicted):.3f}"
        )
        ax.set_xlabel("Actual")
        plot_y_eq_x(ax)

    # The y-axis is shared, so one label on the left panel is enough.
    axes[0].set_ylabel("Predicted")

## Baseline Model Performance

In [None]:
# Linear baseline on the SASA features, scored with explained variance (`evs`).
plot_performance(model_linear, train_linear, train_target, test_linear, test_target, 'Linear, evs', evs)


In [None]:
# Random-forest baselines: one-hot features first, then UniRep features.
plot_performance(model_rfoh, train_oh, train_target, test_oh, test_target, 'One Hot rf, evs', evs)
plot_performance(model_rf_unirep, train_unirep, train_target, test_unirep, test_target, 'Unirep rf, evs', evs)

## Which checkpoint for NN models?

In [None]:
# MPNN loss history on a log scale, used to choose the checkpoints below.
# NOTE(review): the x-axis label assumes `loss_history` holds one entry per
# training step -- confirm.
loss_fig, loss_ax = plt.subplots()
loss_ax.plot(model_mpnn.loss_history)
loss_ax.set_yscale("log")
loss_ax.set(xlabel="Training step", ylabel="Loss", title="MPNN loss history");

In [None]:
# Deep MPNN loss history on a log scale, used to choose the checkpoints below.
# NOTE(review): the x-axis label assumes `loss_history` holds one entry per
# training step -- confirm.
loss_fig, loss_ax = plt.subplots()
loss_ax.plot(model_deepmpnn.loss_history)
loss_ax.set_yscale("log")
loss_ax.set(xlabel="Training step", ylabel="Loss", title="Deep MPNN loss history");

## Checkpoint at 400 steps

In [None]:
# Evaluate both graph models at the same checkpoint.
checkpoint = 400
nn_panels = (
    (model_mpnn, f'MPNN, evs, {checkpoint}'),
    (model_deepmpnn, f'Deep MPNN, evs, {checkpoint}'),
)
for nn_model, nn_title in nn_panels:
    plot_performance(
        nn_model, train_graph, train_target, test_graph, test_target,
        nn_title, evs, checkpoint=checkpoint,
    )


## Checkpoint at 600 steps

In [None]:
# Evaluate both graph models at the same checkpoint.
checkpoint = 600
nn_panels = (
    (model_mpnn, f'MPNN, evs, {checkpoint}'),
    (model_deepmpnn, f'Deep MPNN, evs, {checkpoint}'),
)
for nn_model, nn_title in nn_panels:
    plot_performance(
        nn_model, train_graph, train_target, test_graph, test_target,
        nn_title, evs, checkpoint=checkpoint,
    )


In [None]:
# Checkpoint at 1000 steps: evaluate both graph models at the same checkpoint.
checkpoint = 1000
nn_panels = (
    (model_mpnn, f'MPNN, evs, {checkpoint}'),
    (model_deepmpnn, f'Deep MPNN, evs, {checkpoint}'),
)
for nn_model, nn_title in nn_panels:
    plot_performance(
        nn_model, train_graph, train_target, test_graph, test_target,
        nn_title, evs, checkpoint=checkpoint,
    )


## Checkpoint at the final step

In [None]:
# Evaluate both graph models with checkpoint=-1.
nn_panels = (
    (model_mpnn, 'MPNN, evs'),
    (model_deepmpnn, 'Deep MPNN, evs'),
)
for nn_model, nn_title in nn_panels:
    plot_performance(
        nn_model, train_graph, train_target, test_graph, test_target,
        nn_title, evs, checkpoint=-1,
    )