In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from patch_gnn.data import load_ghesquire
import pandas as pd
from pyprojroot import here
import pickle as pkl
from patch_gnn.splitting import train_test_split
from jax import random
from patch_gnn.seqops import one_hot
from patch_gnn.unirep import unirep_reps
from patch_gnn.graph import graph_tensors
from patch_gnn.models import MPNN, DeepMPNN
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import explained_variance_score as evs
import matplotlib.pyplot as plt 
from sklearn.metrics import mean_squared_error as mse
import pickle as pkl
from patch_gnn.graph import met_position
import seaborn as sns



In [3]:
data = load_ghesquire()

graph_pickle_path = here() / "data/ghesquire_2011/graphs.pkl"

with open(graph_pickle_path, "rb") as f:
    graphs = pkl.load(f)

key = random.PRNGKey(490)

In [36]:
graphs['P40121-VSDATGQMNLTK'].nodes(data=True)

NodeDataView({'261MET': {'chain_id': '', 'residue_number': 261, 'residue_name': 'MET', 'x_coord': 11.016, 'y_coord': 41.385, 'z_coord': 83.63, 'features': None, 'log_Phob/A^2': 3.1166215908294443, 'log_Phil/A^2': 2.3243465847755584, 'log_SASA/A^2': 3.4901235908565567, 'log_N(overl)': 6.35088571671474}, '262ASN': {'chain_id': '', 'residue_number': 262, 'residue_name': 'ASN', 'x_coord': 13.867, 'y_coord': 42.722, 'z_coord': 81.462, 'features': None, 'log_Phob/A^2': 2.917770732084279, 'log_Phil/A^2': 3.733135545368474, 'log_SASA/A^2': 4.099497927462895, 'log_N(overl)': 6.030685260261263}, '312VAL': {'chain_id': '', 'residue_number': 312, 'residue_name': 'VAL', 'x_coord': 8.135, 'y_coord': 35.62, 'z_coord': 81.989, 'features': None, 'log_Phob/A^2': 2.9317269435780786, 'log_Phil/A^2': 1.0784095813505903, 'log_SASA/A^2': 3.077312260546414, 'log_N(overl)': 6.159095388491933}, '316PHE': {'chain_id': '', 'residue_number': 316, 'residue_name': 'PHE', 'x_coord': 8.551, 'y_coord': 37.946, 'z_coord

In [37]:
print(graphs['P40121-VSDATGQMNLTK']['261MET'])

{'260GLN': {'kind': ['backbone']}, '262ASN': {'kind': ['backbone']}, '312VAL': {'kind': ['hydrophobic']}, '257ALA': {'kind': ['hydrophobic']}, '316PHE': {'kind': ['hydrophobic']}, '263LEU': {'kind': ['hydrophobic']}, '254VAL': {'kind': ['hydrophobic']}}


In [5]:
# eric has cleaned the data to matcch the structured model to the sequence, to include the Met that was captured by the 3D structures
filtered = (
    data
    .query("`accession-sequence` in @graphs.keys()")
    .query("ox_fwd_logit < 2.0") # very few cases are with higher values of 2.0
    .join_apply(met_position, "met_position")
)

In [17]:
data.head

<bound method NDFrame.head of      accession                                       Description  \
0       A0AVT1      Ubiquitin-like modifier-activating enzyme 6    
1       A0AVT1                                               NaN   
2       A2RRP1             Neuroblastoma-amplified gene protein    
3       A2RRP1                                               NaN   
4       A3KN83               Protein strawberry notch homolog 1    
...        ...                                               ...   
2621    Q9Y6G9  Cytoplasmic dynein 1 light intermediate chain 1    
2622    Q9Y6I3                                          Epsin-1    
2623    Q9Y6I3                                               NaN   
2624    Q9Y6J0              Calcineurin-binding protein cabin-1    
2625    Q9Y6X4                          UPF0611 protein FAM169A    

               sequence isoforms     end  score  threshold       m/z    z  \
0         GMITVTDPDLIEK      NaN   503.0   79.0       32.0  732.3690  2.0   

In [6]:
train_df, test_df = train_test_split(key, filtered) # 70% training, 30% testing

In [7]:
(train_df.shape), (test_df.shape) #369 total raw data points

((258, 18), (111, 18))

In [8]:
train_oh = one_hot(train_df, 50) #pad sequence to 50 length and one hot encode it
test_oh = one_hot(test_df, 50)
print(train_oh.shape, test_oh.shape)

train_unirep = unirep_reps(train_df) # unirep representation of the sequence
test_unirep = unirep_reps(test_df)
print(train_unirep.shape, test_unirep.shape)

train_graph = graph_tensors(train_df, graphs) # use graph + sasa features
test_graph = graph_tensors(test_df, graphs)


(258, 1050) (111, 1050)
(258, 1900) (111, 1900)


In [9]:
type(train_graph)


tuple

In [None]:
train_graph

In [10]:
with open(here() / "data/ghesquire_2011/sasa.pkl", "rb") as f:
    sasa_dfs = pkl.load(f)

In [16]:
len(sasa_dfs)
sasa_dfs["O15305"] #sasa_dfs are associated with a unipot ID, there are 810 structures, for input for machine learning is the uniprot id + sequence, and then look up a particular met position


Unnamed: 0,ResidNe,Chain,ResidNr,iCode,Phob/A^2,Phil/A^2,SASA/A^2,Q(SASA),N(overl),Surf/A^2
0,PRO,-,4,-,99.79,7.32,107.12,0.6476,348,165.4
1,GLY,-,5,-,37.79,27.27,65.06,0.5368,192,121.2
2,PRO,-,6,-,96.03,4.12,100.15,0.6055,400,165.4
3,ALA,-,7,-,10.52,6.56,17.09,0.1184,452,144.3
4,LEU,-,8,-,7.74,2.43,10.17,0.0647,898,157.1
...,...,...,...,...,...,...,...,...,...,...
238,GLU,-,242,-,33.31,36.27,69.59,0.3455,479,201.4
239,LEU,-,243,-,112.63,16.03,128.66,0.8190,328,157.1
240,LEU,-,244,-,19.86,13.01,32.87,0.2092,440,157.1
241,PHE,-,245,-,27.41,7.35,34.76,0.2029,588,171.3


In [18]:
def linear_model_data(df, sasa_dfs):
    linear = []
    for acc, pos in zip(df["accession"], df["met_position"]):
        feats = sasa_dfs[acc].query("ResidNr == @pos")[["SASA/A^2", "N(overl)"]] #using only two features from sasa_dfs, want to compare with older models
        linear.append(feats)

    return pd.concat(linear)

train_linear = linear_model_data(train_df, sasa_dfs)
test_linear = linear_model_data(test_df, sasa_dfs)

In [19]:
train_target = train_df['ox_fwd_logit'].values
test_target = test_df['ox_fwd_logit'].values

In [20]:
len(filtered)

369

In [21]:
num_training_steps = 5000

models = {
    "mpnn": MPNN(
        node_feature_shape=(20, 65),# (num_nodes is number of unique animo acid, num_feats from data/amino_acid_properties.csv + 4 sasa_dfs features )
        num_adjacency=1,
        num_training_steps=num_training_steps
    ),
    "deep_mpnn": DeepMPNN(
        node_feature_shape=(20, 65),
        num_adjacency=1,
        num_training_steps=num_training_steps
    ),
    "rf_oh": RandomForestRegressor(n_estimators=300), #300 is set as default
    "rf_unirep": RandomForestRegressor(n_estimators=300),
}

In [38]:
from sklearn.linear_model import LinearRegression

In [39]:
model_linear = LinearRegression()
model_linear.fit(train_linear, train_target)

LinearRegression()

In [40]:
model_mpnn = MPNN(
    node_feature_shape=(20, 65),
    num_adjacency=1,
    num_training_steps=num_training_steps
)
model_mpnn.fit(train_graph, train_target)

  0%|          | 0/5000 [00:00<?, ?it/s]

<patch_gnn.models.MPNN at 0x2aacfe091040>

In [None]:
train_graph

In [41]:
model_deepmpnn = DeepMPNN(
    node_feature_shape=(20, 65),
    num_adjacency=1,
    num_training_steps=num_training_steps
)
model_deepmpnn.fit(train_graph, train_target)

  0%|          | 0/5000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
model_rfoh = RandomForestRegressor(oob_score=True, n_jobs=-1)
model_rfoh.fit(train_oh, train_target)

model_rf_unirep = RandomForestRegressor(oob_score=True, n_jobs=-1)
model_rf_unirep.fit(train_unirep, train_target)

In [None]:
def plot_y_eq_x(ax):
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()

    minval = min(xmin, ymin)
    maxval = max(xmax, ymax)

    ax.plot([minval, maxval], [minval, maxval])

def plot_performance(
    model,
    trainX, trainY, testX, testY,
    model_name: str,
    ev_func,
    checkpoint: int = None,
):
    fig, ax = plt.subplots(
        figsize=(10, 5), nrows=1, ncols=2, sharex=True, sharey=True,
    )

    if isinstance(model, MPNN):
        train_preds = model.predict(trainX, checkpoint=checkpoint)
    else:
        train_preds = model.predict(trainX)
    ax[0].scatter(trainY, train_preds.squeeze())
    ax[0].set_title(f"Model: {model_name}, Training Perf: {ev_func(trainY, train_preds.squeeze()):.3f}")
    plot_y_eq_x(ax[0])

    if isinstance(model, MPNN):
        test_preds = model.predict(testX, checkpoint=checkpoint)
    else:
        test_preds = model.predict(testX)
    ax[1].scatter(testY, test_preds.squeeze())
    ax[1].set_title(f"Model: {model_name}, Testing Perf: {ev_func(testY, test_preds.squeeze()):.3f}")
    plot_y_eq_x(ax[1])

## Baseline Model Performance

In [None]:
plot_performance(model_linear, train_linear, train_target, test_linear, test_target, 'Linear, evs', evs)


In [None]:
plot_performance(model_rfoh, train_oh, train_target, test_oh, test_target, 'One Hot rf, evs', evs)
plot_performance(model_rf_unirep, train_unirep, train_target, test_unirep, test_target, 'Unirep rf, evs', evs)

## Which checkpoint for NN models?

In [None]:
plt.plot(model_mpnn.loss_history)
plt.yscale("log")

In [None]:
plt.plot(model_deepmpnn.loss_history)
plt.yscale("log")

## Checkpoint at 400 steps

In [None]:
checkpoint = 400
plot_performance(model_mpnn, train_graph, train_target, test_graph, test_target, f'MPNN, evs, {checkpoint}', evs, checkpoint=checkpoint)
plot_performance(model_deepmpnn, train_graph, train_target, test_graph, test_target, f'Deep MPNN, evs, {checkpoint}', evs, checkpoint=checkpoint)


## Checkpoint at 600 steps

In [None]:
checkpoint = 600
plot_performance(model_mpnn, train_graph, train_target, test_graph, test_target, f'MPNN, evs, {checkpoint}', evs, checkpoint=checkpoint)
plot_performance(model_deepmpnn, train_graph, train_target, test_graph, test_target, f'Deep MPNN, evs, {checkpoint}', evs, checkpoint=checkpoint)


In [None]:
checkpoint = 1000
plot_performance(model_mpnn, train_graph, train_target, test_graph, test_target, f'MPNN, evs, {checkpoint}', evs, checkpoint=checkpoint)
plot_performance(model_deepmpnn, train_graph, train_target, test_graph, test_target, f'Deep MPNN, evs, {checkpoint}', evs, checkpoint=checkpoint)


## Checkpoint at Final steps

In [None]:
plot_performance(model_mpnn, train_graph, train_target, test_graph, test_target, 'MPNN, evs', evs, checkpoint=-1)
plot_performance(model_deepmpnn, train_graph, train_target, test_graph, test_target, 'Deep MPNN, evs', evs, checkpoint=-1)

## plot descriptor, target relation scatter plot

In [None]:
descriptor = pd.concat([train_linear.reset_index(drop=True), test_linear.reset_index(drop=True)], axis=0)
target = pd.concat([pd.Series(train_target), pd.Series(test_target)], axis=0)
train_test = pd.Series(["train"]*train_linear.shape[0] + ['test']* test_linear.shape[0])
scatter_df = pd.DataFrame({'fluc': descriptor['fluc'].tolist(), 'target': target.tolist(), 'train_test': train_test.tolist()})
scatter_df.shape

In [None]:
scatter_df.head()

In [None]:
sns.relplot(x="fluc", y="target",hue="train_test", data=scatter_df)
plt.title('ANM fluc vs. target in train and test set')

In [None]:
np.where(scatter_df['fluc']==100000000)

In [None]:
scatter_df.iloc[[ 56,  71,  90, 142, 190, 207, 265, 279, 337, 341],]

In [None]:
# next we will try dropping them and see the performance of the model