In [None]:
# Automatically reload edited modules (e.g. patch_gnn) before each cell executes.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline, at high (retina) resolution.
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [None]:
import pickle as pkl

In [None]:
from pyprojroot import here

graph_pickle_path = here() / "data/ghesquire_2011/graphs.pkl"

# Dict of "accession-sequence" keys (e.g. "P15374-FLEESVSMSPEER") -> networkx graphs.
# NOTE: unpickling can execute arbitrary code; only ever load a trusted file here.
graphs = pkl.loads(graph_pickle_path.read_bytes())

In [None]:
import networkx as nx

# Sanity-check one peptide's graph by drawing it with node labels.
nx.draw(graphs["P15374-FLEESVSMSPEER"], with_labels=True)

In [None]:
from patch_gnn.data import load_ghesquire

# Load the Ghesquiere 2011 dataset; later cells use its "accession", "sequence",
# "end" and "ox_fwd_logit" columns.
data = load_ghesquire()

In [None]:
# Propagate the last seen accession down into rows where it is missing
# (presumably only the first row of each protein block carries it — TODO confirm).
# Series.ffill() replaces fillna(method="ffill"), which is deprecated since pandas 2.1.
data["accession"] = data["accession"].ffill()

In [None]:
filtered = (
    data
    # Build the "accession-sequence" key used by the `graphs` dict.
    .concatenate_columns(["accession", "sequence"], "accession-sequence")
    # Keep only peptides for which a graph exists.
    .query("`accession-sequence` in @graphs.keys()")
    # One row per (accession, end) position.
    .drop_duplicates(subset=["accession", "end"])
    # Restrict to negative forward-oxidation logits.
    .query("ox_fwd_logit < 0.0")
)

In [None]:
# Inspect the filtered table: rows with a graph, de-duplicated on (accession, end),
# ox_fwd_logit < 0.
filtered

In [None]:
import jax.numpy as np
import matplotlib.pyplot as plt


def ecdf_scatter(data, ax=None, xlabel=None):
    """Scatter-plot the empirical cumulative distribution function (ECDF) of `data`.

    The i-th smallest value (1-based) is drawn at height i / n, with n = len(data).

    Parameters
    ----------
    data : 1-D array-like
        Values whose ECDF to draw.
    ax : matplotlib Axes, optional
        Axes to draw on. Defaults to the current pyplot axes, as before.
    xlabel : str, optional
        Label for the x-axis. The y-axis is always labelled "ECDF".
    """
    if ax is None:
        ax = plt.gca()
    n = len(data)
    x, y = np.sort(data), np.arange(1, n + 1) / n
    ax.scatter(x, y)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    ax.set_ylabel("ECDF")
    plt.show()

ecdf_scatter(filtered["ox_fwd_logit"].values, xlabel="ox_fwd_logit")

In [None]:
# Summarise the graph dict instead of dumping the repr of every graph:
# the number of graphs plus a few example keys.
len(graphs), list(graphs)[:5]

In [None]:
# Re-display the filtered table before building graph tensors from it.
filtered

In [None]:
from patch_gnn.graph import graph_tensors

# Feature tensor F and adjacency tensor A for the filtered peptides. The (F, A)
# unpacking order follows graph_tensors' return order.
F, A = graph_tensors(filtered, graphs)
F.shape, A.shape

In [None]:
import pandas as pd

# Amino-acid property table. index_col=0 makes the CSV's first column the row index;
# featurize_aa_props below selects its columns by residue name.
aa_props = pd.read_csv(here() / "data/amino_acid_properties.csv", index_col=0)
aa_props

In [None]:
from patch_gnn.graph import generate_feature_dataframe

def featurize_aa_props(n, d, aa_props):
    """Return the `aa_props` column for this node's residue, as a Series named `n`.

    Parameters
    ----------
    n : hashable
        Node identifier; used as the name of the returned Series.
    d : mapping
        Node attributes; must contain "residue_name", a column label of `aa_props`.
    aa_props : pandas.DataFrame
        Property table whose columns are residue names.
    """
    residue = d["residue_name"]
    column = aa_props[residue]
    return pd.Series(column, name=n)


# Node featurizers for generate_feature_dataframe. Each is called as f(n, d)
# (presumably node id and node-attribute dict); this one looks up the node's residue
# in aa_props.
funcs = [
    lambda n, d: featurize_aa_props(n, d, aa_props)
]

# Preview the resulting feature table for one example graph.
generate_feature_dataframe(graphs["P15374-FLEESVSMSPEER"], funcs=funcs)


In [None]:
# Largest graph (node count). Compare with the padding size used when preparing tensors.
max(map(len, graphs.values()))

In [None]:
import jax.numpy as np

from tqdm.auto import tqdm
from patch_gnn.graph import prep_features
# Per-graph node-feature matrix built from `funcs`, then padded by prep_features to 20.
# 20 is presumably the padded node count — TODO confirm it is at least the largest
# graph size computed above.
feats = {
    acc: prep_features(np.array(generate_feature_dataframe(g, funcs).values), 20)
    for acc, g in tqdm(graphs.items())
}


In [None]:
from patch_gnn.graph import prep_adjacency_matrix

In [None]:
import networkx as nx

# Dense adjacency matrix per graph, given a trailing channel axis (n, n, 1), then
# padded by prep_adjacency_matrix to 20 (presumably the padded node count — TODO
# confirm). nx.to_numpy_array returns a plain ndarray in graph-node order. This avoids
# the scipy-sparse round trip of adjacency_matrix(...).todense(), which also yields a
# deprecated np.matrix on networkx 2.x. Note that the entries come back as floats.
adjs = {
    acc: prep_adjacency_matrix(np.expand_dims(np.array(nx.to_numpy_array(g)), 2), 20)
    for acc, g in tqdm(graphs.items())
}

In [None]:
# Stacking only succeeds if every padded adjacency tensor has the same shape.
np.stack(list(adjs.values())).shape

In [None]:
# Stacking only succeeds if every padded feature matrix has the same shape.
np.stack(list(feats.values())).shape

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
# 70/30 split of the filtered rows' accession-sequence keys, with a fixed random_state
# for reproducibility. The returned Series keep filtered's index, which later cells
# use to look up targets.
train_acc, test_acc = train_test_split(filtered["accession-sequence"], train_size=0.7, random_state=49049)

In [None]:
# Gather the pre-computed tensors for the training split, keyed by accession-sequence.
# Dict keys silently collapse repeated keys. That would leave fewer stacked tensors
# than targets built from train_acc.index (misaligning inputs and labels), so fail loudly.
assert train_acc.is_unique, "duplicate accession-sequence keys in the training split"
train_feats = {acc: feats[acc] for acc in tqdm(train_acc)}
train_adjs = {acc: adjs[acc] for acc in train_acc}

In [None]:
# Gather the pre-computed tensors for the test split, keyed by accession-sequence.
# Dict keys silently collapse repeated keys. That would leave fewer stacked tensors
# than targets built from test_acc.index (misaligning inputs and labels), so fail loudly.
assert test_acc.is_unique, "duplicate accession-sequence keys in the test split"
test_feats = {acc: feats[acc] for acc in tqdm(test_acc)}
test_adjs = {acc: adjs[acc] for acc in test_acc}

## Define the models

In [None]:
from patch_gnn.layers import MessagePassing, GraphAverage, GraphSummation, CustomGraphEmbedding
from jax.experimental import stax

# NOTE(review): embedding_init_fun / embedding_apply_fun are not used anywhere below;
# the models construct their own CustomGraphEmbedding layers.
embedding_init_fun, embedding_apply_fun = CustomGraphEmbedding(256)

In [None]:
from patch_gnn.layers import LinearRegression

# Model 1: 1024-dim graph embedding followed directly by a linear-regression head.
model_init_fun, model_apply_fun = stax.serial(
    CustomGraphEmbedding(1024),
    LinearRegression(1),
)

# Model 2: same embedding, then one 512-unit ReLU hidden layer and a 1-unit output.
model2_init_fun, model2_apply_fun = stax.serial(
    CustomGraphEmbedding(1024),
    # One hidden layer (512 units, ReLU)
    stax.Dense(512),
    stax.Relu,
    stax.Dense(1),
)

In [None]:
# Stack the training split's tensors in dict insertion order, i.e. train_acc order.
train_Fs = np.stack(list(train_feats.values()))
train_As = np.stack(list(train_adjs.values()))

In [None]:
# Stack the test split's tensors in dict insertion order, i.e. test_acc order.
test_Fs = np.stack(list(test_feats.values()))
test_As = np.stack(list(test_adjs.values()))

In [None]:
# Check the stacked feature tensor's layout.
train_Fs.shape  # (num_graphs, num_nodes, num_feats)

In [None]:
from jax.random import PRNGKey

In [None]:
from jax import vmap
from functools import partial

# Model inputs are (adjacency, features) tuples. Note that A comes first here, the
# reverse of the (F, A) order returned by graph_tensors.
train_inputs = (train_As, train_Fs)
test_inputs = (test_As, test_Fs)

In [None]:
# Regression targets as column vectors, aligned row-for-row with train_acc / test_acc
# (and so with the stacked tensors). Looked up on `filtered`, the frame the split was
# drawn from, with a single .loc call rather than chained indexing on `data`.
train_output = filtered.loc[train_acc.index, "ox_fwd_logit"].values.reshape(-1, 1)
test_output = filtered.loc[test_acc.index, "ox_fwd_logit"].values.reshape(-1, 1)


In [None]:
# Targets are column vectors: (num_train, 1) and (num_test, 1).
train_output.shape, test_output.shape

In [None]:
# Per-graph input shape passed to model init: (num_nodes, num_feats, adjacency channels).
(*train_Fs[0].shape, train_As[0].shape[-1])

In [None]:
# Alignment check: within each split, features, adjacencies and targets must have the
# same leading (graph) dimension. Otherwise training pairs inputs with the wrong labels.
assert train_Fs.shape[0] == train_As.shape[0] == train_output.shape[0]
assert test_Fs.shape[0] == test_As.shape[0] == test_output.shape[0]
(*train_Fs[0].shape, train_As[0].shape[-1])

In [None]:
from patch_gnn.training import mseloss
from jax import grad
# Initialise model-1 parameters. The input shape is (num_nodes, num_feats) of one graph
# plus its number of adjacency channels (the trailing axis added by expand_dims).
output_shape, params = model_init_fun(PRNGKey(42), input_shape=(*train_Fs[0].shape, train_As[0].shape[-1]))

# Baseline losses of the untrained model, for comparison after training.
train_loss = mseloss(params, model_apply_fun, train_inputs, train_output)
test_loss = mseloss(params, model_apply_fun, test_inputs, test_output)

In [None]:
# Untrained-model baseline losses.
train_loss, test_loss

## Training loop

In [None]:
from jax.experimental.optimizers import adam, sgd
from patch_gnn.training import mseloss
from jax import grad

# Gradient of mseloss w.r.t. its first argument (params).
# NOTE(review): dmseloss is not used below; training goes through patch_gnn.training.step.
dmseloss = grad(mseloss)

In [None]:
import jax
from typing import Tuple
from jax import jit
from patch_gnn.training import step

# Adam optimizer for model 1; get_params is jitted once and reused.
init, update, get_params = adam(step_size=1e-5)
get_params = jit(get_params)

# Bind everything except (step index, optimizer state) so the training step can be
# jitted with the training data closed over.
training_step = jit(
    partial(
        step,
        loss_fun=mseloss,
        apply_fun=model_apply_fun,
        update_fun=update,
        get_params=get_params,
        inputs=train_inputs,
        outputs=train_output,
    )
)

In [None]:
from tqdm.autonotebook import tqdm
# Train model 1 for 100 steps, recording the optimizer state and loss at every step.
state = init(params)
states, losses = [], []
for i in tqdm(range(100)):
    state, loss = training_step(i, state)
    states.append(state)
    losses.append(loss)
    

In [None]:
import matplotlib.pyplot as plt

# Training-loss curve for model 1; labelled so the figure stands on its own.
plt.plot(losses)
plt.xlabel("training step")
plt.ylabel("train loss")
plt.title("Model 1 training loss");

In [None]:
# Trained model-1 parameters from the final optimizer state.
params_final = get_params(state)

In [None]:
from sklearn.metrics import explained_variance_score as evs

In [None]:
# Model-1 predictions on the training graphs. vmap maps over the leading (graph) axis of
# both the adjacency and the feature tensors.
train_preds = vmap(partial(model_apply_fun, params_final))(train_inputs)

import matplotlib.pyplot as plt

# Predicted vs. observed. The dashed y = x line marks perfect prediction and spans the
# data range, instead of a fixed [-4, 1] interval that could clip points.
y_pred = train_preds.squeeze()
y_true = train_output.squeeze()
lim = [float(min(y_pred.min(), y_true.min())), float(max(y_pred.max(), y_true.max()))]
plt.scatter(y_pred, y_true)
plt.plot(lim, lim, linestyle="--")
plt.title(f"Model 1, train: explained variance = {evs(y_true, y_pred):.3f}")
plt.xlabel("pred")
plt.ylabel("true");

In [None]:
# Training-set loss of the trained model 1.
mseloss(params_final, model_apply_fun, train_inputs, train_output)

In [None]:
# Test-set loss of the trained model 1 (params re-read from the final optimizer state).
params_final = get_params(state)
mseloss(params_final, model_apply_fun, test_inputs, test_output)

In [None]:
# Model-1 predictions on the held-out graphs.
test_preds = vmap(partial(model_apply_fun, params_final))(test_inputs)

# Predicted vs. observed. The dashed y = x line spans the data range, instead of a
# fixed [-4, 1] interval that could clip points.
y_pred = test_preds.squeeze()
y_true = test_output.squeeze()
lim = [float(min(y_pred.min(), y_true.min())), float(max(y_pred.max(), y_true.max()))]
plt.scatter(y_pred, y_true)
plt.plot(lim, lim, linestyle="--")
plt.title(f"Model 1, test: explained variance = {evs(y_true, y_pred):.3f}")
plt.xlabel("pred")
plt.ylabel("true");

## Try model2

Model 2 adds one hidden dense layer (512 units, ReLU) between the graph embedding and the output layer.

In [None]:
# Jitted training step for model 2.
# NOTE(review): `update` and `get_params` are bound to whatever optimizer functions are
# in scope when this cell runs. That is the model-1 adam optimizer, because the model-2
# optimizer is only created in the next cell. Presumably equivalent while both use the
# same step_size, but a change to the model-2 step_size would not reach this step.
training2_step = partial(step, loss_fun=mseloss, apply_fun=model2_apply_fun, update_fun=update, get_params=get_params, inputs=train_inputs, outputs=train_output)
training2_step = jit(training2_step)

In [None]:
from tqdm.autonotebook import tqdm

output_shape, params = model2_init_fun(PRNGKey(42), input_shape=(*train_Fs[0].shape, train_As[0].shape[-1]))

# Fresh Adam optimizer for model 2.
init, update, get_params = adam(step_size=1e-5)
get_params = jit(get_params)

# (Re)build the jitted step here so it is bound to *this* optimizer's update/get_params.
# A step built before these names were reassigned keeps the previous optimizer's
# functions, so e.g. changing step_size above would silently not affect training.
training2_step = jit(
    partial(
        step,
        loss_fun=mseloss,
        apply_fun=model2_apply_fun,
        update_fun=update,
        get_params=get_params,
        inputs=train_inputs,
        outputs=train_output,
    )
)
state = init(params)

# Train for 100 steps, recording the loss and optimizer state at every step.
states = []
losses = []
for i in tqdm(range(100)):
    state, loss = training2_step(i, state)
    losses.append(loss)
    states.append(state)


In [None]:
# Training-loss curve for model 2, on a log scale.
plt.plot(losses)
plt.yscale("log")
plt.xlabel("training step")
plt.ylabel("train loss (log scale)")
plt.title("Model 2 training loss");

In [None]:
# Trained model-2 parameters and their predictions on the training graphs.
params_final = get_params(state)
train_preds = vmap(partial(model2_apply_fun, params_final))(train_inputs)

# Predicted vs. observed. The dashed y = x line spans the data range, instead of a
# fixed [-4, 1] interval that could clip points.
y_pred = train_preds.squeeze()
y_true = train_output.squeeze()
lim = [float(min(y_pred.min(), y_true.min())), float(max(y_pred.max(), y_true.max()))]
plt.scatter(y_pred, y_true)
plt.plot(lim, lim, linestyle="--")
plt.title(f"Model 2, train: explained variance = {evs(y_true, y_pred):.3f}")
plt.xlabel("pred")
plt.ylabel("true");

In [None]:
# Model-2 predictions on the held-out graphs.
test_preds = vmap(partial(model2_apply_fun, params_final))(test_inputs)

# Predicted vs. observed. The dashed y = x line spans the data range, instead of a
# fixed [-4, 1] interval that could clip points.
y_pred = test_preds.squeeze()
y_true = test_output.squeeze()
lim = [float(min(y_pred.min(), y_true.min())), float(max(y_pred.max(), y_true.max()))]
plt.scatter(y_pred, y_true)
plt.plot(lim, lim, linestyle="--")
plt.title(f"Model 2, test: explained variance = {evs(y_true, y_pred):.3f}")
plt.xlabel("pred")
plt.ylabel("true");

## Benchmarking code


In [None]:
def acc2oh(accession):
    """Given accession-sequence pair, return its (presumably one-hot encoded) sequence.

    TODO: placeholder, not implemented yet. Raises rather than silently returning None,
    so that benchmarking code calling it fails loudly.

    Raises
    ------
    NotImplementedError
        Always, until implemented.
    """
    raise NotImplementedError("acc2oh is a placeholder and has not been implemented yet")

def acc2unirep(accession):
    """Given accession-sequence pair, return its UniRep representation.

    TODO: placeholder, not implemented yet. Raises rather than silently returning None,
    so that benchmarking code calling it fails loudly.

    Raises
    ------
    NotImplementedError
        Always, until implemented.
    """
    raise NotImplementedError("acc2unirep is a placeholder and has not been implemented yet")

def acc2graph(accession):
    """Given accession-sequence pair, return its F (features) and A (adjacency) matrices.

    TODO: placeholder, not implemented yet. Raises rather than silently returning None,
    so that benchmarking code calling it fails loudly.

    Raises
    ------
    NotImplementedError
        Always, until implemented.
    """
    raise NotImplementedError("acc2graph is a placeholder and has not been implemented yet")


In [None]:
class OneHotRF:
    """Placeholder benchmark model (presumably a random forest on one-hot sequences).

    TODO: not implemented yet. fit/predict raise instead of silently returning None,
    so that benchmarking code using this stub fails loudly.
    """

    def __init__(self):
        pass

    def fit(self, X, y):
        """Fit on features X and targets y. Not implemented yet.

        Raises
        ------
        NotImplementedError
            Always, until implemented.
        """
        raise NotImplementedError("OneHotRF.fit is not implemented yet")

    def predict(self, X):
        """Predict targets for features X. Not implemented yet.

        Raises
        ------
        NotImplementedError
            Always, until implemented.
        """
        raise NotImplementedError("OneHotRF.predict is not implemented yet")


class UniRepRF:
    """Placeholder benchmark model (presumably a random forest on UniRep representations).

    TODO: not implemented yet. fit/predict raise instead of silently returning None,
    so that benchmarking code using this stub fails loudly.
    """

    def __init__(self):
        pass

    def fit(self, X, y):
        """Fit on features X and targets y. Not implemented yet.

        Raises
        ------
        NotImplementedError
            Always, until implemented.
        """
        raise NotImplementedError("UniRepRF.fit is not implemented yet")

    def predict(self, X):
        """Predict targets for features X. Not implemented yet.

        Raises
        ------
        NotImplementedError
            Always, until implemented.
        """
        raise NotImplementedError("UniRepRF.predict is not implemented yet")


In [None]:
from patch_gnn.seqops import encoder, padding, one_hot

In [None]:
# One-hot encode sequences of the full (unfiltered) dataset. 50 is presumably the padded
# sequence length — TODO confirm.
one_hot(data, 50).shape

In [None]:
from jax import random

# Explicit JAX PRNG key; the same key is also used below to split the data.
key = random.PRNGKey(490)

# Demo: a random permutation of 0..29.
a = np.arange(30)
random.permutation(key, a)

In [None]:
from patch_gnn.splitting import train_test_split

In [None]:
# NOTE: this `train_test_split` is patch_gnn.splitting's (imported in the cell above).
# It shadows the sklearn function of the same name imported earlier.
train_df, test_df = train_test_split(key, filtered)
train_df


In [None]:
# Size of the new test split.
test_df.shape

In [None]:
from patch_gnn.unirep import unirep_reps

# UniRep representations for the test split.
unirep_reps(test_df)

In [None]:
# One-hot encoding of the test split's sequences (same length argument, 50).
one_hot(test_df, 50)

In [None]:
# Adjacency tensor shape from the earlier sklearn-based split, for comparison.
test_As.shape

In [None]:
# Graph tensors for the new test split, and the shapes of both returned tensors.
test_tensors = graph_tensors(test_df, graphs)
test_tensors[0].shape, test_tensors[1].shape

In [None]:
# Graph tensors for the new training split.
train_tensors = graph_tensors(train_df, graphs)

In [None]:
# 1-D regression targets for the new split. Unlike train_output/test_output, they are
# not reshaped to column vectors.
test_target = test_df["ox_fwd_logit"].values
train_target = train_df["ox_fwd_logit"].values

In [None]:
from sklearn.ensemble import RandomForestRegressor
from functools import partial
from patch_gnn.training import step, mseloss



In [None]:
# NOTE(review): `MPNN` is neither imported nor defined anywhere in this notebook, so this
# cell raises NameError on a fresh kernel. It presumably comes from a patch_gnn module —
# TODO add the import. node_feature_shape=(20, 61) is presumably (padded nodes,
# features per node); TODO confirm it matches train_tensors.
model = MPNN(node_feature_shape=(20, 61), num_adjacency=1, num_training_steps=200)
model.fit(train_tensors, train_target)

train_preds = model.predict(train_tensors)
test_preds = model.predict(test_tensors)

# Predicted vs. observed for each split, titled with its explained-variance score.
fig, ax = plt.subplots(figsize=(12, 6), nrows=1, ncols=2)
for panel, split, preds, target in [
    (ax[0], "train", train_preds, train_target),
    (ax[1], "test", test_preds, test_target),
]:
    preds = preds.squeeze()
    panel.scatter(preds, target)
    panel.set_title(f"MPNN, {split}: explained variance = {evs(target, preds):.3f}")
    panel.set_xlabel("pred")
    panel.set_ylabel("true")

In [None]:
# NOTE(review): `DeepMPNN` is neither imported nor defined anywhere in this notebook, so
# this cell raises NameError on a fresh kernel. It presumably comes from a patch_gnn
# module — TODO add the import. node_feature_shape=(20, 61) is presumably (padded nodes,
# features per node); TODO confirm it matches train_tensors.
model = DeepMPNN(node_feature_shape=(20, 61), num_adjacency=1, num_training_steps=200)
model.fit(train_tensors, train_target)

train_preds = model.predict(train_tensors)
test_preds = model.predict(test_tensors)

# Predicted vs. observed for each split, titled with its explained-variance score.
fig, ax = plt.subplots(figsize=(12, 6), nrows=1, ncols=2)
for panel, split, preds, target in [
    (ax[0], "train", train_preds, train_target),
    (ax[1], "test", test_preds, test_target),
]:
    preds = preds.squeeze()
    panel.scatter(preds, target)
    panel.set_title(f"DeepMPNN, {split}: explained variance = {evs(target, preds):.3f}")
    panel.set_xlabel("pred")
    panel.set_ylabel("true")