In [None]:
# Reload all imported modules (e.g. patch_gnn) before each cell executes, so edits
# to the package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from patch_gnn.data import load_ghesquire
import pandas as pd
from pyprojroot import here
import pickle as pkl
from patch_gnn.splitting import train_test_split
from jax import random
from patch_gnn.seqops import one_hot
from patch_gnn.unirep import unirep_reps
from patch_gnn.graph import graph_tensors
from patch_gnn.models import MPNN, DeepMPNN
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import explained_variance_score as evs
import matplotlib.pyplot as plt 
from sklearn.metrics import mean_squared_error as mse

In [None]:
# Raw dataset plus the precomputed graphs; `graphs` is presumably keyed by
# accession-sequence (the next cell filters rows on membership in its keys).
data = load_ghesquire()

graph_pickle_path = here() / "data/ghesquire_2011/graphs.pkl"
# NOTE(review): unpickling can execute arbitrary code — only load this file from a trusted source.
graphs = pkl.loads(graph_pickle_path.read_bytes())

# Seed for the JAX-based train/test split below.
key = random.PRNGKey(490)

In [None]:
# Keep only rows that have a precomputed graph and a negative `ox_fwd_logit`.
has_graph = data["accession-sequence"].isin(graphs.keys())
filtered = data[has_graph & (data["ox_fwd_logit"] < 0.0)]

In [None]:
# Split once, then featurise both halves three ways: one-hot encodings, UniRep
# representations, and graph tensors built from the loaded `graphs`.
train_df, test_df = train_test_split(key, filtered)

# NOTE(review): 50 is presumably the padded/truncated sequence length — confirm in patch_gnn.seqops.
train_oh, test_oh = one_hot(train_df, 50), one_hot(test_df, 50)

train_unirep, test_unirep = unirep_reps(train_df), unirep_reps(test_df)

train_graph, test_graph = graph_tensors(train_df, graphs), graph_tensors(test_df, graphs)

In [None]:
# Regression targets: the `ox_fwd_logit` column of each split, as arrays.
train_target, test_target = (frame['ox_fwd_logit'].values for frame in (train_df, test_df))

In [None]:
# Candidate models keyed by short name.
# NOTE(review): none of the cells shown here read this dict — the next cell builds and
# fits its own models (with different RandomForest settings). Either drive fitting from
# this dict or remove it so the two definitions cannot drift apart.
# NOTE(review): node_feature_shape=(20, 61) is presumably (nodes, features per node) — confirm.
models = {
    "mpnn": MPNN(node_feature_shape=(20, 61), num_adjacency=1, num_training_steps=200),
    "deep_mpnn": DeepMPNN(node_feature_shape=(20, 61), num_adjacency=1, num_training_steps=200),
    "rf_oh": RandomForestRegressor(n_estimators=300),
    "rf_unirep": RandomForestRegressor(n_estimators=300),
}

In [None]:
# Fit every model on the same training split. The random forests are seeded so their
# bootstrap samples (and therefore predictions and oob_score_) are reproducible on re-run.
RF_SEED = 490  # same value as the JAX PRNGKey used for the train/test split

model_mpnn = MPNN(node_feature_shape=(20, 61), num_adjacency=1, num_training_steps=200)
model_mpnn.fit(train_graph, train_target)

model_deepmpnn = DeepMPNN(node_feature_shape=(20, 61), num_adjacency=1, num_training_steps=200)
model_deepmpnn.fit(train_graph, train_target)

# oob_score=True makes the forest compute an out-of-bag score, inspected in the next cells.
model_rfoh = RandomForestRegressor(oob_score=True, random_state=RF_SEED)
model_rfoh.fit(train_oh, train_target)

model_rf_unirep = RandomForestRegressor(oob_score=True, random_state=RF_SEED)
model_rf_unirep.fit(train_unirep, train_target)

In [None]:
# Out-of-bag score of the one-hot random forest (scikit-learn reports R^2 for regressors).
model_rfoh.oob_score_

In [None]:
# Out-of-bag score of the UniRep random forest (scikit-learn reports R^2 for regressors).
model_rf_unirep.oob_score_

In [None]:
def plot_y_eq_x(ax):
    """Draw the identity line y = x on `ax`.

    The line runs from the smallest to the largest of the axes' current x and y
    limits, so it covers both ranges as they stand when this is called.
    """
    (x_lo, x_hi), (y_lo, y_hi) = ax.get_xlim(), ax.get_ylim()
    lo = min(x_lo, y_lo)
    hi = max(x_hi, y_hi)
    ax.plot([lo, hi], [lo, hi])

def plot_performance(model, trainX, trainY, testX, testY, model_name: str, ev_func):
    """Scatter predicted vs. actual targets for the train and test splits side by side.

    Parameters
    ----------
    model : object with a ``predict(X)`` method; its output is ``.squeeze()``-d.
    trainX, trainY : training features and targets.
    testX, testY : held-out features and targets.
    model_name : label used in both panel titles.
    ev_func : metric called as ``ev_func(y_true, y_pred)``; its result is shown
        in each title formatted to three decimals.

    Returns None (the figure is left to the notebook's inline display).
    """
    fig, axes = plt.subplots(figsize=(10, 5), nrows=1, ncols=2, sharex=True, sharey=True)
    panels = (
        (axes[0], trainX, trainY, "Training"),
        (axes[1], testX, testY, "Testing"),
    )
    for ax, X, y_true, split_name in panels:
        y_pred = model.predict(X).squeeze()
        ax.scatter(y_pred, y_true)
        ax.set_title(f"Model: {model_name}, {split_name} Perf: {ev_func(y_true, y_pred):.3f}")
        ax.set_xlabel("Predicted")
        ax.set_ylabel("Actual")

    # The panels share x/y limits, so draw the identity lines only after both scatters
    # exist; otherwise the training panel's line would be sized to the training data alone.
    for ax in axes:
        plot_y_eq_x(ax)

In [None]:
# Predicted-vs-actual panels (train | test) for every fitted model, scored by mean squared error.
for fitted, X_tr, X_te, title in (
    (model_rfoh, train_oh, test_oh, 'One Hot rf, mse'),
    (model_rf_unirep, train_unirep, test_unirep, 'Unirep rf, mse'),
    (model_mpnn, train_graph, test_graph, 'MPNN, mse'),
    (model_deepmpnn, train_graph, test_graph, 'Deep MPNN, mse'),
):
    plot_performance(fitted, X_tr, train_target, X_te, test_target, title, mse)

In [None]:
# Same panels as above, scored by explained variance instead.
for fitted, X_tr, X_te, title in (
    (model_rfoh, train_oh, test_oh, 'One Hot rf, evs'),
    (model_rf_unirep, train_unirep, test_unirep, 'Unirep rf, evs'),
    (model_mpnn, train_graph, test_graph, 'MPNN, evs'),
    (model_deepmpnn, train_graph, test_graph, 'Deep MPNN, evs'),
):
    plot_performance(fitted, X_tr, train_target, X_te, test_target, title, evs)