In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from patch_gnn.data import load_ghesquire
import pandas as pd
from pyprojroot import here
import pickle as pkl
from patch_gnn.splitting import train_test_split
from jax import random
from patch_gnn.seqops import one_hot
from patch_gnn.unirep import unirep_reps
from patch_gnn.graph import graph_tensors
from patch_gnn.models import MPNN, DeepMPNN
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import explained_variance_score as evs
import matplotlib.pyplot as plt 
from sklearn.metrics import mean_squared_error as mse
import pickle as pkl
from patch_gnn.graph import met_position




In [3]:
# Load the Ghesquire 2011 dataset and the precomputed per-peptide graphs.
data = load_ghesquire()

graph_pickle_path = here() / "data" / "ghesquire_2011" / "graphs.pkl"
# NOTE: pickle.load can execute arbitrary code; only load trusted, project-generated files.
with graph_pickle_path.open("rb") as graph_file:
    graphs = pkl.load(graph_file)

# Fixed JAX PRNG seed so the train/test split is reproducible.
key = random.PRNGKey(490)

In [4]:
# Keep peptides that have an entry in `graphs` and a forward-oxidation logit
# below 2.0, then attach each peptide's methionine position.
filtered = (
    data
    .loc[lambda frame: frame["accession-sequence"].isin(list(graphs))]
    .loc[lambda frame: frame["ox_fwd_logit"] < 2.0]
    .join_apply(met_position, "met_position")
)

In [5]:
# Split the filtered peptides into train/test sets using the seeded JAX PRNG key.
train_df, test_df = train_test_split(key, filtered)

In [6]:
# Sanity check: the two split sizes should add up to the filtered dataset.
assert len(train_df) + len(test_df) == len(filtered), "train/test split does not cover `filtered` exactly"
len(train_df), len(test_df)

(258, 111)

In [7]:
# Featurize both splits three ways: one-hot, UniRep and graph tensors.
# `50` is forwarded to `one_hot` unchanged (presumably a padded sequence
# length; TODO confirm against patch_gnn.seqops).
train_oh, test_oh = (one_hot(split_df, 50) for split_df in (train_df, test_df))
train_unirep, test_unirep = (unirep_reps(split_df) for split_df in (train_df, test_df))
train_graph, test_graph = (graph_tensors(split_df, graphs) for split_df in (train_df, test_df))

In [8]:
# Per-protein SASA tables, keyed by accession.
# NOTE: pickle.load can execute arbitrary code; only load trusted, project-generated files.
with (here() / "data" / "ghesquire_2011" / "sasa.pkl").open("rb") as sasa_file:
    sasa_dfs = pkl.load(sasa_file)

In [9]:
# Peek at one SASA table (accession O15305): per-residue rows, residue number in `ResidNr`.
sasa_dfs["O15305"]

Unnamed: 0,ResidNe,Chain,ResidNr,iCode,Phob/A^2,Phil/A^2,SASA/A^2,Q(SASA),N(overl),Surf/A^2
0,PRO,-,4,-,99.79,7.32,107.12,0.6476,348,165.4
1,GLY,-,5,-,37.79,27.27,65.06,0.5368,192,121.2
2,PRO,-,6,-,96.03,4.12,100.15,0.6055,400,165.4
3,ALA,-,7,-,10.52,6.56,17.09,0.1184,452,144.3
4,LEU,-,8,-,7.74,2.43,10.17,0.0647,898,157.1
...,...,...,...,...,...,...,...,...,...,...
238,GLU,-,242,-,33.31,36.27,69.59,0.3455,479,201.4
239,LEU,-,243,-,112.63,16.03,128.66,0.8190,328,157.1
240,LEU,-,244,-,19.86,13.01,32.87,0.2092,440,157.1
241,PHE,-,245,-,27.41,7.35,34.76,0.2029,588,171.3


In [10]:
def linear_model_data(df, sasa_dfs):
    """Gather per-methionine SASA features for a linear baseline.

    For each row of ``df``, the table ``sasa_dfs[accession]`` is filtered to
    the residue whose ``ResidNr`` equals the row's ``met_position``. That
    residue's ``SASA/A^2`` and ``N(overl)`` columns are kept.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``accession`` and ``met_position`` columns.
    sasa_dfs : dict
        Maps accession -> DataFrame with ``ResidNr``, ``SASA/A^2`` and
        ``N(overl)`` columns.

    Returns
    -------
    pandas.DataFrame
        The matched rows stacked in the order of ``df``. Each row keeps its
        index from the residue table it came from. An empty ``df`` gives an
        empty frame with the two feature columns.

    Warns
    -----
    UserWarning
        If a position matches zero or several residues. The result then no
        longer has exactly one row per row of ``df``, so it will not line up
        with targets taken from ``df``.
    """
    import warnings

    feature_cols = ["SASA/A^2", "N(overl)"]
    linear = []
    for acc, pos in zip(df["accession"], df["met_position"]):
        feats = sasa_dfs[acc].query("ResidNr == @pos")[feature_cols]
        if len(feats) != 1:
            # A missing or duplicated residue silently shifts every later row
            # relative to `df`; make that visible instead.
            warnings.warn(
                f"{acc}: expected 1 residue with ResidNr == {pos}, found {len(feats)}; "
                "features will not line up with the rows of df"
            )
        linear.append(feats)

    if not linear:
        # pd.concat([]) raises ValueError; return an empty, correctly-shaped frame.
        return pd.DataFrame(columns=feature_cols)
    return pd.concat(linear)

# Build the SASA-based linear features for both splits.
train_linear, test_linear = (linear_model_data(split_df, sasa_dfs) for split_df in (train_df, test_df))

In [11]:
# Regression targets: the forward-oxidation logit of each peptide, as arrays.
train_target, test_target = (split_df["ox_fwd_logit"].values for split_df in (train_df, test_df))

In [12]:
### Try the new fluctuation ("fluc") descriptors: ANM and NMA

In [65]:
# Per-protein ANM fluctuation tables, keyed by accession.
# NOTE(review): this rebinds `sasa_dfs`; the SASA tables loaded earlier are no
# longer reachable under that name, and the cells below operate on ANM data.
# NOTE: pickle.load can execute arbitrary code; only load trusted, project-generated files.
with (here() / "data" / "ghesquire_2011" / "ANM.pkl").open("rb") as anm_file:
    sasa_dfs = pkl.load(anm_file)

In [66]:
# Number of accessions that have an ANM table.
len(sasa_dfs)

810

In [67]:
# Example ANM table (accession P05386): columns chains/resnos/resids/fluc.
# Some `fluc` values are inf, so they must be made finite before regression.
sasa_dfs['P05386']

Unnamed: 0,chains,resnos,resids,fluc
0,,1,MET,inf
1,,2,ALA,inf
2,,3,SER,inf
3,,4,VAL,inf
4,,5,SER,inf
...,...,...,...,...
109,,110,PHE,inf
110,,111,GLY,inf
111,,112,LEU,inf
112,,113,PHE,inf


In [68]:
import numpy as np
import math
def linear_model_data(df, sasa_dfs, inf_fill=100000000):
    """Gather the per-methionine ANM fluctuation (``fluc``) feature.

    For each row of ``df``, the table ``sasa_dfs[accession]`` is filtered to
    the residue whose ``resnos`` equals the row's ``met_position``. Its
    ``fluc`` column is kept, and infinite values are replaced with
    ``inf_fill`` so the regressor receives finite inputs.

    ``sasa_dfs`` is not modified: the replacement is applied only to the
    selected rows, so repeated calls never rewrite the caller's tables.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``accession`` and ``met_position`` columns.
    sasa_dfs : dict
        Maps accession -> DataFrame with ``resnos`` and ``fluc`` columns.
    inf_fill : float, optional
        Value substituted for ``+inf`` fluctuations (default ``100000000``).

    Returns
    -------
    pandas.DataFrame
        Single ``fluc`` column, one matched row per row of ``df`` in order.
        Each row keeps its index from the residue table it came from. An
        empty ``df`` gives an empty frame with a ``fluc`` column.

    Warns
    -----
    UserWarning
        If a position matches zero or several residues. The result then no
        longer lines up with targets taken from ``df``.
    """
    import warnings

    linear = []
    for acc, pos in zip(df["accession"], df["met_position"]):
        feats = (
            sasa_dfs[acc]
            .query("resnos == @pos")[["fluc"]]
            .replace(np.inf, inf_fill)
        )
        if len(feats) != 1:
            warnings.warn(
                f"{acc}: expected 1 residue with resnos == {pos}, found {len(feats)}; "
                "features will not line up with the rows of df"
            )
        linear.append(feats)

    if not linear:
        # pd.concat([]) raises ValueError; return an empty, correctly-shaped frame.
        return pd.DataFrame(columns=["fluc"])
    return pd.concat(linear)

# Build the ANM-fluctuation features for both splits (replaces the SASA features).
train_linear, test_linear = (linear_model_data(split_df, sasa_dfs) for split_df in (train_df, test_df))

In [69]:
# ANM `fluc` feature for the test split (expected: one value per methionine).
test_linear

Unnamed: 0,fluc
391,0.498642
89,0.592851
68,0.646654
202,0.785049
74,7.825681
...,...
124,0.610817
115,0.301053
154,0.526359
21,8.188917


In [71]:
# Infinite ANM fluctuations are replaced inside `linear_model_data`; nothing to do here.

In [72]:
# Same targets as the earlier cell (forward-oxidation logits); presumably
# re-derived so this section can be re-run on its own.
test_target = test_df["ox_fwd_logit"].values
train_target = train_df["ox_fwd_logit"].values

In [73]:
# Size of the filtered dataset; should equal len(train_df) + len(test_df).
len(filtered)

369

In [74]:
num_training_steps = 5000

# NOTE(review): `models` is not referenced by any later cell shown here; the
# estimators that are actually fit are constructed separately below.
models = {
    name: model_cls(
        node_feature_shape=(20, 65),
        num_adjacency=1,
        num_training_steps=num_training_steps,
    )
    for name, model_cls in (("mpnn", MPNN), ("deep_mpnn", DeepMPNN))
}
models["rf_oh"] = RandomForestRegressor(n_estimators=300)
models["rf_unirep"] = RandomForestRegressor(n_estimators=300)

In [75]:
from sklearn.linear_model import LinearRegression

In [76]:
# Ordinary least-squares baseline on the linear features built above.
model_linear = LinearRegression().fit(train_linear, train_target)
model_linear

LinearRegression()

In [None]:
# Single-layer message-passing network on the graph tensors.
# (20, 65) is the node-feature shape passed to the model; TODO confirm its meaning.
mpnn_model_kwargs = None  # placeholder removed below; keeps no extra state
del mpnn_model_kwargs
model_mpnn = MPNN(node_feature_shape=(20, 65), num_adjacency=1, num_training_steps=num_training_steps)
model_mpnn.fit(train_graph, train_target)

  0%|          | 0/5000 [00:00<?, ?it/s]

In [None]:
# Deeper message-passing network, same inputs and training budget as the MPNN.
model_deepmpnn = DeepMPNN(node_feature_shape=(20, 65), num_adjacency=1, num_training_steps=num_training_steps)
model_deepmpnn.fit(train_graph, train_target)

In [None]:
# Random-forest baselines on one-hot and UniRep features. random_state is fixed
# (same value as the JAX split seed) so bootstrap sampling, OOB scores and the
# plots below are reproducible across kernel restarts.
model_rfoh = RandomForestRegressor(oob_score=True, n_jobs=-1, random_state=490)
model_rfoh.fit(train_oh, train_target)

model_rf_unirep = RandomForestRegressor(oob_score=True, n_jobs=-1, random_state=490)
model_rf_unirep.fit(train_unirep, train_target)

In [None]:
def plot_y_eq_x(ax):
    """Overlay the identity line y = x on ``ax``.

    The segment starts at the smaller of the two current lower axis limits
    and ends at the larger of the two upper limits. It therefore covers both
    the x and y ranges that were visible when the function was called.
    """
    (x_low, x_high), (y_low, y_high) = ax.get_xlim(), ax.get_ylim()
    lo, hi = min(x_low, y_low), max(x_high, y_high)
    ax.plot([lo, hi], [lo, hi])

def plot_performance(
    model,
    trainX, trainY, testX, testY,
    model_name: str,
    ev_func,
    checkpoint: int = None,
):
    """Scatter actual vs. predicted targets for the train and test splits.

    Draws a 1x2 figure with shared axes: training data on the left, test data
    on the right. Each panel reports ``ev_func(actual, predicted)`` in its
    title and shows a y = x reference line.

    Parameters
    ----------
    model : fitted estimator exposing ``predict(X)``. ``MPNN`` and
        ``DeepMPNN`` models are called as ``predict(X, checkpoint=checkpoint)``.
    trainX, trainY, testX, testY : features and targets of each split.
    model_name : label used in the panel titles.
    ev_func : metric called as ``ev_func(y_true, y_pred)`` (e.g. ``evs``).
    checkpoint : forwarded to MPNN-style models only; ignored otherwise.

    Returns None on purpose. Returning the figure would make Jupyter display
    it a second time when the call is the last line of a cell.
    """
    _, axes = plt.subplots(
        figsize=(10, 5), nrows=1, ncols=2, sharex=True, sharey=True,
    )
    for panel, X, y, split_label in (
        (axes[0], trainX, trainY, "Training"),
        (axes[1], testX, testY, "Testing"),
    ):
        _plot_split(panel, model, X, y, model_name, split_label, ev_func, checkpoint)


def _plot_split(ax, model, X, y, model_name, split_label, ev_func, checkpoint):
    """Scatter one split's predictions on ``ax``, then title, label and add y = x."""
    preds = _predict(model, X, checkpoint).squeeze()
    ax.scatter(y, preds)
    ax.set_title(f"Model: {model_name}, {split_label} Perf: {ev_func(y, preds):.3f}")
    ax.set_xlabel("Actual")
    ax.set_ylabel("Predicted")
    plot_y_eq_x(ax)


def _predict(model, X, checkpoint):
    """Call ``model.predict``, forwarding ``checkpoint`` only to MPNN-style models."""
    # DeepMPNN is listed explicitly so its checkpoint is honoured even if it does
    # not subclass MPNN (the class hierarchy is not visible from this notebook).
    if isinstance(model, (MPNN, DeepMPNN)):
        return model.predict(X, checkpoint=checkpoint)
    return model.predict(X)

## Baseline Model Performance

In [None]:
# Linear baseline on the per-methionine linear features.
plot_performance(
    model_linear,
    train_linear, train_target,
    test_linear, test_target,
    model_name='Linear, evs',
    ev_func=evs,
)


In [None]:
# Random-forest baselines: one-hot features, then UniRep features.
plot_performance(
    model_rfoh, train_oh, train_target, test_oh, test_target,
    model_name='One Hot rf, evs', ev_func=evs,
)
plot_performance(
    model_rf_unirep, train_unirep, train_target, test_unirep, test_target,
    model_name='Unirep rf, evs', ev_func=evs,
)

## Which checkpoint for NN models?

In [None]:
# Training-loss curve for the MPNN (assumes one loss value per training step).
plt.plot(model_mpnn.loss_history)
plt.title("MPNN training loss")
plt.xlabel("Training step")
plt.ylabel("Loss (log scale)")
plt.yscale("log")

In [None]:
# Training-loss curve for the Deep MPNN (assumes one loss value per training step).
plt.plot(model_deepmpnn.loss_history)
plt.title("Deep MPNN training loss")
plt.xlabel("Training step")
plt.ylabel("Loss (log scale)")
plt.yscale("log")

## Checkpoint at 400 steps

In [None]:
checkpoint = 400
# Same evaluation for both graph models at this checkpoint.
for nn_model, nn_label in ((model_mpnn, "MPNN"), (model_deepmpnn, "Deep MPNN")):
    plot_performance(
        nn_model, train_graph, train_target, test_graph, test_target,
        f"{nn_label}, evs, {checkpoint}", evs, checkpoint=checkpoint,
    )


## Checkpoint at 600 steps

In [None]:
checkpoint = 600
# Same evaluation for both graph models at this checkpoint.
for nn_model, nn_label in ((model_mpnn, "MPNN"), (model_deepmpnn, "Deep MPNN")):
    plot_performance(
        nn_model, train_graph, train_target, test_graph, test_target,
        f"{nn_label}, evs, {checkpoint}", evs, checkpoint=checkpoint,
    )


In [None]:
checkpoint = 1000
# Same evaluation for both graph models at this checkpoint.
for nn_model, nn_label in ((model_mpnn, "MPNN"), (model_deepmpnn, "Deep MPNN")):
    plot_performance(
        nn_model, train_graph, train_target, test_graph, test_target,
        f"{nn_label}, evs, {checkpoint}", evs, checkpoint=checkpoint,
    )


## Checkpoint at the final training step

In [None]:
# checkpoint=-1 presumably selects the last saved training state (TODO confirm).
for nn_model, nn_label in ((model_mpnn, "MPNN"), (model_deepmpnn, "Deep MPNN")):
    plot_performance(
        nn_model, train_graph, train_target, test_graph, test_target,
        f"{nn_label}, evs", evs, checkpoint=-1,
    )