In [1]:
%load_ext autoreload
%autoreload 2

from collections import defaultdict
from functools import partial

import numpy as np
import pandas as pd
import os
import sklearn.metrics as sk_metrics
import time
import torch
import torch.nn as nn
import torch_geometric
# import tqdm
from atom3d.datasets import LMDBDataset
from atom3d.splits.splits import split_randomly
from atom3d.util import metrics
from torch.nn.utils.rnn import pad_sequence
from types import SimpleNamespace

import sys
sys.path.append('../')
import gvp
import gvp.atom3d
from gvp import set_seed, Logger
from egnn import egnn_clean as eg

from sklearn.metrics import precision_recall_fscore_support, confusion_matrix, classification_report, roc_auc_score, roc_curve

# import atom3d.datasets.datasets as da
# dataset = da.load_dataset('../data/PPI/DIPS-split/data/', 'pdb')

In [2]:
device = 'cuda'
# data_path = '../data/paul_PPbind/test/'
data_path = '../data/paul_DB5_lmdb/'
# dataset = LMDBDataset(data_path + 'train', transform=gvp.atom3d.PPBindingTransform(plm=0, device='cpu'))
dataset = LMDBDataset(data_path, transform=gvp.atom3d.PPBindingTransform(plm=1, device=device))
trainset, valset, testset = split_randomly(dataset)

dl = torch_geometric.loader.DataLoader(dataset=dataset,
                                            num_workers=0,
                                            batch_size=4,
                                            shuffle=False)

2023-02-26 01:50:41,588 INFO 48254: Splitting dataset with 690 entries.
2023-02-26 01:50:41,589 INFO 48254: Size of the training set: 552
2023-02-26 01:50:41,589 INFO 48254: Size of the validation set: 69
2023-02-26 01:50:41,590 INFO 48254: Size of the test set: 69


In [3]:
model = eg.PPBindingModel(plm=1, device=device)
# d = torch.load('../models/PPBind_centered_coords_egnn_plm1_epoch3_val_loss_0.14324405546033672.chkpt', map_location=device)
d = torch.load('../models/PPBind_FS_frame_centered_coords_egnn_plm1_epoch1_val_loss_0.2007694764067651.chkpt', map_location=device)


d = {k.replace('module.', ''): v for k, v in d.items()}
model.load_state_dict(d)
model = model.eval().to(device)


In [4]:
batch = next(iter(dl))
preds = []
labels = []
subunits1 = []
subunits2 = []

for bnum, batch in enumerate(dl):
    preds.extend(list(torch.sigmoid(model(batch)).detach().cpu().numpy()))
    labels.extend(list(batch[0].label))
    subunits1.extend(list(batch[0].subunit))
    subunits2.extend(list(batch[1].subunit))


    if bnum == 100:
        break

preds = pd.DataFrame({'preds': preds, 'labels': labels, 'subunit1': subunits1, 'subunit2': subunits2})
preds['call'] = preds['preds'].apply(lambda x: 1 if x > 0.5 else 0)
preds['prot1'] = preds['subunit1'].apply(lambda x: x.split('_')[0])
preds['prot2'] = preds['subunit2'].apply(lambda x: x.split('_')[0])
call_counts = preds.groupby(['prot1']).agg({'call': 'sum'}).reset_index().sort_values('call', ascending=False)


# print(classification_report(preds['labels'], preds['call']))
# print(roc_auc_score(preds['labels'], preds['preds']))

In [5]:
preds.groupby(['labels']).agg({'preds': 'mean', 'call': 'sum'}).reset_index()

Unnamed: 0,labels,preds,call
0,bound,0.395077,53
1,random,0.456879,64
2,unbound,0.384142,54


In [15]:
preds[preds['labels'] == 'random']['call'].value_counts()

0    166
1     64
Name: call, dtype: int64

In [33]:
preds[preds['call'] == 0]['preds'].drop_duplicates().mean()

0.09359451

In [29]:
singles = call_counts[call_counts['call'] == 1]['prot1'].values
preds[preds['prot1'].isin(singles)]

Unnamed: 0,preds,labels,subunit1,subunit2,call,prot1,prot2
20,0.64842,1,1XD3_l_b_cleaned.pdb,1XD3_r_b_cleaned.pdb,1,1XD3,1XD3
21,0.423775,0,1XD3_l_u_cleaned.pdb,1XD3_r_u_cleaned.pdb,0,1XD3,1XD3
52,0.496018,1,1H9D_l_b_cleaned.pdb,1H9D_r_b_cleaned.pdb,0,1H9D,1H9D
53,0.674025,0,1H9D_l_u_cleaned.pdb,1H9D_r_u_cleaned.pdb,1,1H9D,1H9D
60,0.305477,1,2HLE_l_b_cleaned.pdb,2HLE_r_b_cleaned.pdb,0,2HLE,2HLE
61,0.670287,0,2HLE_l_u_cleaned.pdb,2HLE_r_u_cleaned.pdb,1,2HLE,2HLE
140,0.349596,1,2AYO_l_b_cleaned.pdb,2AYO_r_b_cleaned.pdb,0,2AYO,2AYO
141,0.553598,0,2AYO_l_u_cleaned.pdb,2AYO_r_u_cleaned.pdb,1,2AYO,2AYO
158,0.441152,1,1DQJ_l_b_cleaned.pdb,1DQJ_r_b_cleaned.pdb,0,1DQJ,1DQJ
159,0.508031,0,1DQJ_l_u_cleaned.pdb,1DQJ_r_u_cleaned.pdb,1,1DQJ,1DQJ


Unnamed: 0,preds,labels,subunit1,subunit2,call,prot1,prot2
0,0.034413,1,1SBB_l_b_cleaned.pdb,1SBB_r_b_cleaned.pdb,0,1SBB,1SBB
1,0.037818,0,1SBB_l_u_cleaned.pdb,1SBB_r_u_cleaned.pdb,0,1SBB,1SBB
2,0.782478,1,1JPS_l_b_cleaned.pdb,1JPS_r_b_cleaned.pdb,1,1JPS,1JPS
3,0.554409,0,1JPS_l_u_cleaned.pdb,1JPS_r_u_cleaned.pdb,1,1JPS,1JPS
4,0.038708,1,1EXB_l_b_cleaned.pdb,1EXB_r_b_cleaned.pdb,0,1EXB,1EXB
...,...,...,...,...,...,...,...
455,0.094366,0,3MXW_l_u_cleaned.pdb,3MXW_r_u_cleaned.pdb,0,3MXW,3MXW
456,0.030631,1,BAAD_l_b_cleaned.pdb,BAAD_r_b_cleaned.pdb,0,BAAD,BAAD
457,0.021579,0,BAAD_l_u_cleaned.pdb,BAAD_r_u_cleaned.pdb,0,BAAD,BAAD
458,0.004770,1,2YVJ_l_b_cleaned.pdb,2YVJ_r_b_cleaned.pdb,0,2YVJ,2YVJ


In [5]:
batch[0]

DataBatch(x=[522, 3], edge_index=[2, 5090], atoms=[522], edge_s=[5090, 16], edge_v=[5090, 1, 3], label=[4], plm=[522, 1280], batch=[522], ptr=[5])

In [51]:
print(roc_auc_score(preds['labels'], preds['preds']))

0.49357277882797734


In [45]:
predsfrom sklearn.metrics import roc_auc_score

Unnamed: 0,preds,labels,call
0,0.034413,1,0
1,0.037818,0,0
2,0.782478,1,1
3,0.554409,0,1
4,0.038708,1,0
...,...,...,...
455,0.094366,0,0
456,0.030631,1,0
457,0.021579,0,0
458,0.004770,1,0


In [6]:
# summarize the fit of the model, print the confusion matrix



              precision    recall  f1-score   support

           0       0.95      0.95      0.95       393
           1       0.95      0.95      0.95       415

    accuracy                           0.95       808
   macro avg       0.95      0.95      0.95       808
weighted avg       0.95      0.95      0.95       808



In [17]:

p1 = batch[1]
sample_idx = 0
chain1 = p1.x[p1.ptr[sample_idx]:p1.ptr[sample_idx+1]].detach().cpu().numpy()
chain1.mean(axis=0)

array([ 7.0540375e-07, -9.0638468e-07,  4.9654119e-07], dtype=float32)

In [8]:
preds[preds['labels'] == preds['call']].shape[0] / preds.shape[0]

0.9492574257425742

In [55]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.6, patience=5, min_lr=5e-7)
losses = []
for bnum, batch in enumerate(dl):
    out = model(batch)
    labels = batch[0].label.float()

    # Calculate loss for binary labels
    loss = nn.BCEWithLogitsLoss()(out, labels)
    losses.append(loss.item())
    
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()


    # for i in range(len(batch)):
    #     out.append(model(batch[i]))

    if bnum % 10 == 0:
        print(bnum, np.mean(losses[-10:]))
    if bnum == 1000:
        break
    

0 0.717462420463562
10 0.6931776583194733
20 0.69017853140831
30 0.6965782046318054
40 0.6903312802314758
50 0.6791554868221283
60 0.6595579624176026
70 0.6180838644504547
80 0.6023252129554748
90 0.5775288462638855
100 0.587931090593338
110 0.5414257317781448
120 0.4844168394804001
130 0.5293464481830596
140 0.5114886581897735
150 0.46706070601940153
160 0.43244494795799254
170 0.4818073779344559
180 0.5525013148784638
190 0.4985406190156937
200 0.4412555664777756
210 0.4781623274087906
220 0.4804361343383789
230 0.415347820520401
240 0.37652755379676817
250 0.3916472226381302
260 0.4876827001571655
270 0.44699072241783144
280 0.42157927006483076
290 0.45795753300189973
300 0.40810640454292296
310 0.3996733009815216
320 0.42279467433691026
330 0.4264607042074203
340 0.4135534703731537
350 0.41993205845355985
360 0.4567837119102478
370 0.37821575701236726
380 0.4016067236661911


In [45]:
loss

tensor(0.7317, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)

In [39]:
# cast out as torch float


out.squeeze().shape


torch.Size([8])

In [31]:
out

tensor([[0.5231],
        [0.5392],
        [0.4919]], grad_fn=<SigmoidBackward0>)

In [11]:
batch

[DataBatch(x=[937, 3], edge_index=[2, 9072], atoms=[937], edge_s=[9072, 16], edge_v=[9072, 1, 3], label=[3], batch=[937], ptr=[4]),
 DataBatch(x=[836, 3], edge_index=[2, 8454], atoms=[836], edge_s=[8454, 16], edge_v=[8454, 1, 3], label=[3], batch=[836], ptr=[4])]

In [17]:
batch = next(iter(dl))
batch

[DataBatch(x=[669, 3], edge_index=[2, 6368], atoms=[669], edge_s=[6368, 16], edge_v=[6368, 1, 3], label=[669], batch=[669], ptr=[4]),
 DataBatch(x=[531, 3], edge_index=[2, 5080], atoms=[531], edge_s=[5080, 16], edge_v=[5080, 1, 3], label=[531], batch=[531], ptr=[4])]

In [9]:
batch[0]

DataBatch(x=[378, 3], edge_index=[2, 3550], atoms=[378], edge_s=[3550, 16], edge_v=[3550, 1, 3], label=[378], batch=[378], ptr=[3])

In [None]:
t = trainset[105]
output = model(t).detach().numpy()
cutoff = 0.3

pred_1 = (output > cutoff).astype(int)[:t[0].x.shape[0]]
pred_2 = (output > cutoff).astype(int)[t[0].x.shape[0]:]

label_1 = t[0].label.numpy().astype(int)
label_2 = t[1].label.numpy().astype(int)

pred = pd.concat([pd.DataFrame({'pred': pred_1, 'label': label_1, 'protein': 1}),
                 pd.DataFrame({'pred': pred_2, 'label': label_2, 'protein': 2})])

precision, recall, f1, support = precision_recall_fscore_support(pred['label'], pred['pred'], average='binary')
summary = pred.groupby(['label', 'pred'])['protein'].count().reset_index()
summary['percent'] = summary['protein'] / summary.groupby('label')['protein'].transform('sum')

print(f'Precision: {precision:.3f}, Recall: {recall:.3f}, F1: {f1:.3f}, Support: {support}')
print(summary)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_addmm)

In [None]:
batch = next(iter(dl))

[DataBatch(x=[1512, 3], edge_index=[2, 14058], atoms=[1512], edge_s=[14058, 16], edge_v=[14058, 1, 3], label=[1512], plm=[1512, 1280], batch=[1512], ptr=[9]),
 DataBatch(x=[1433, 3], edge_index=[2, 13486], atoms=[1433], edge_s=[13486, 16], edge_v=[13486, 1, 3], label=[1433], plm=[1433, 1280], batch=[1433], ptr=[9])]

In [None]:
x = trainset[110]
y = trainset[2000]
mix = (x[0], y[0])

output = model(x).detach().numpy()
cutoff = 0.3

pred_1 = (output > cutoff).astype(int)[:x[0].x.shape[0]]
pred_2 = (output > cutoff).astype(int)[x[0].x.shape[0]:]

label_1 = x[0].label.numpy().astype(int)
label_2 = x[1].label.numpy().astype(int)

x = pd.concat([pd.DataFrame({'pred': pred_1, 'label': label_1, 'protein': 1}),
                 pd.DataFrame({'pred': pred_2, 'label': label_2, 'protein': 2})])


output = model(mix).detach().numpy()
cutoff = 0.3

pred_1 = (output > cutoff).astype(int)[:mix[0].x.shape[0]]
pred_2 = (output > cutoff).astype(int)[mix[0].x.shape[0]:]

label_1 = mix[0].label.numpy().astype(int)
label_2 = mix[1].label.numpy().astype(int)

mix = pd.concat([pd.DataFrame({'pred': pred_1, 'label': label_1, 'protein': 1}),
                 pd.DataFrame({'pred': pred_2, 'label': label_2, 'protein': 2})])

In [None]:
(x[x['protein'] == 1]['pred'] == mix[mix['protein'] == 1]['pred']).all()

True

In [None]:
def hook_fn(module, input, output):
    hook_fn.output = output

model.dense[3].register_forward_hook(hook_fn)

output = model(t)
second_to_last_output = hook_fn.output

In [None]:
out = model(t)

In [None]:
second_to_last_output.shape

torch.Size([112, 128])

In [None]:
t.atoms

tensor([11, 15, 19, 13, 19,  7,  5, 10, 16, 17, 11,  5, 18,  5,  0,  1, 19,  0,
         0,  7,  3,  4, 19, 10, 12, 10, 14, 19,  7,  0, 10,  5,  6,  8,  7,  8,
         8, 12,  4, 12,  2, 19,  3, 19, 10, 10, 14, 16,  0, 19,  4, 11,  1, 19,
         0,  5,  1,  9,  7,  0, 10, 19, 12, 14,  7, 10,  6, 18,  7, 18, 11, 15,
         6,  6, 11, 15,  7,  7,  7,  2,  8, 13, 14,  7, 16, 16, 15, 10,  3,  7,
         0, 16, 10, 16,  7, 16, 19,  6,  3,  9,  9,  1,  5, 10,  0,  1,  8,  7,
         0,  1,  1, 10, 19, 10, 12,  2,  7,  8, 18,  6,  2, 15, 12, 13,  9, 19,
         5,  7,  9,  3, 10,  0, 10,  1,  5, 10,  1, 18,  0,  7,  9,  6,  3, 13,
        11, 19, 19, 19, 10, 15, 18, 17,  3, 13, 19, 11,  3, 14,  0, 19,  9,  6,
         6, 10, 18, 14,  5,  7, 13, 10,  7, 17,  3,  9,  5,  8,  7,  7, 19, 13,
         5, 16, 15, 10, 12, 10,  0, 10, 18, 14,  3, 10, 19,  3, 10,  3,  1, 19,
        19,  3,  8, 14, 14,  0, 16, 13, 14, 14, 18,  3, 19, 13, 14, 19,  3, 14,
         0,  1, 16, 14,  0, 14,  7, 16, 

In [None]:
t['edge_s']

tensor([[5.9416e-26, 2.1925e-19, 8.3118e-14,  ..., 1.1505e-17, 6.0327e-24,
         3.2499e-31],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 8.6126e-01, 6.2961e-01,
         4.7287e-02],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 8.0241e-02, 7.6177e-01,
         7.4300e-01],
        ...,
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 2.9982e-02, 5.2217e-01,
         9.3435e-01],
        [3.7651e-26, 1.4808e-19, 5.9831e-14,  ..., 1.6693e-17, 9.3289e-24,
         5.3565e-31],
        [0.0000e+00, 0.0000e+00, 1.4013e-45,  ..., 8.7249e-02, 9.9931e-04,
         1.1759e-06]])

In [None]:
egnn


EGNN(
  (embed): Embedding(22, 32)
  (gcl_0): E_GCL(
    (edge_mlp): Sequential(
      (0): Linear(in_features=81, out_features=32, bias=True)
      (1): SiLU()
      (2): Linear(in_features=32, out_features=32, bias=True)
      (3): SiLU()
    )
    (node_mlp): Sequential(
      (0): Linear(in_features=64, out_features=32, bias=True)
      (1): SiLU()
      (2): Linear(in_features=32, out_features=32, bias=True)
    )
    (coord_mlp): Sequential(
      (0): Linear(in_features=32, out_features=32, bias=True)
      (1): SiLU()
      (2): Linear(in_features=32, out_features=1, bias=False)
    )
  )
  (gcl_1): E_GCL(
    (edge_mlp): Sequential(
      (0): Linear(in_features=81, out_features=32, bias=True)
      (1): SiLU()
      (2): Linear(in_features=32, out_features=32, bias=True)
      (3): SiLU()
    )
    (node_mlp): Sequential(
      (0): Linear(in_features=64, out_features=32, bias=True)
      (1): SiLU()
      (2): Linear(in_features=32, out_features=32, bias=True)
    )
    (coo

In [None]:
318**2

101124

In [None]:
layer = nn.Linear(in_features=8, out_features=1, bias=True)
x = torch.randn(10, 8)

In [None]:
for p in layer.parameters():
    print(p)

Parameter containing:
tensor([[-0.2156,  0.2200,  0.2242,  0.0592,  0.2251, -0.2823,  0.0648,  0.0478]],
       requires_grad=True)
Parameter containing:
tensor([0.2128], requires_grad=True)


In [None]:
x

tensor([[ 1.4885, -0.6366,  0.1303, -0.2946,  0.9986,  0.2229,  0.1274,  0.6775],
        [ 0.1450, -0.8104, -2.1237, -0.2748, -0.8929, -0.2312, -2.8565, -0.5362],
        [ 0.9685,  0.1583, -0.3575,  1.1180,  0.0391,  1.5113, -1.5070, -0.7822],
        [ 0.5667, -0.3549,  0.6392, -0.0496, -0.7520,  0.5488, -0.2220, -0.2424],
        [ 0.1282,  0.6665, -0.3754,  0.9832, -1.1055,  1.5733, -0.4250,  0.5120],
        [-0.1649, -1.4688, -0.9590,  0.7493, -2.1332, -0.8609, -1.3448,  0.6251],
        [ 2.2488, -0.2345, -0.1216, -1.0074,  0.4047,  0.3989, -0.2835, -0.1022],
        [-1.2178, -0.3435, -0.0983, -1.4357,  0.7948, -0.6118,  1.0365,  0.6706],
        [ 0.6781, -0.0499, -1.5981,  0.6171,  1.2137,  0.5906, -1.4453, -1.8942],
        [-0.4125,  0.2209, -0.4835,  0.5741,  2.2615, -0.8910,  1.7489,  0.4334]])

In [None]:
layer(x)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (10x9 and 8x1)