In [3]:
from collections import defaultdict
from functools import partial

import numpy as np
import os
import sklearn.metrics as sk_metrics
import time
import torch
import torch.nn as nn
import torch_geometric
# import tqdm
from atom3d.datasets import LMDBDataset
from atom3d.splits.splits import split_randomly
from atom3d.util import metrics
from torch.nn.utils.rnn import pad_sequence
from types import SimpleNamespace

sys.path.append('../')
import gvp
import gvp.atom3d
from gvp import set_seed, Logger
from egnn import egnn_clean as eg

In [4]:
def get_datasets(task, gnn_model, lba_split=30, device='cpu'):
    data_path = {'RES': 'atom3d-data/RES/raw/RES/data/', 'PPI': 'data/PPI/DIPS-split/data/', 'PSR': 'data/PSR/split-by-year/data/',
                 'MSP': 'atom3d-data/MSP/splits/split-by-sequence-identity-30/data/', 'LEP': 'atom3d-data/LEP/splits/split-by-protein/data/',
                 'LBA': f'data/LBA/split-by-sequence-identity-{lba_split}/data/', 'TOY': 'data/TOY/split-by-cath-topology/data/'}[task]      # TOY use the test dataset of RES

    if task == 'RES':
        split_path = 'atom3d-data/RES/splits/split-by-cath-topology/indices/'
        dataset = partial(gvp.atom3d.RESDataset, data_path)
        trainset = dataset(split_path=split_path + 'train_indices.txt')
        valset = dataset(split_path=split_path + 'val_indices.txt')
        testset = dataset(split_path=split_path + 'test_indices.txt')
    elif task == 'PPI':
        if args.model == 'molformer':
            train_dataset, val_dataset, test_dataset = split_randomly(LMDBDataset(data_path + 'test'))
            trainset = gvp.atom3d.PPIDataset(train_dataset, plm=args.plm)
            valset = gvp.atom3d.PPIDataset(val_dataset, plm=args.plm)
            testset = gvp.atom3d.PPIDataset(test_dataset, plm=args.plm)
        else:
            dataset = LMDBDataset(data_path + 'test', transform=gvp.atom3d.PPITransform(plm=args.plm, device=device))
            trainset, valset, testset = split_randomly(dataset)
    elif task == 'TOY':
        train_dataset, val_dataset, test_dataset = split_randomly(LMDBDataset(data_path + 'test'))
        if args.model == 'molformer':
            trainset = gvp.atom3d.TOYDataset2(train_dataset, label=args.toy)
            valset = gvp.atom3d.TOYDataset2(val_dataset, label=args.toy)
            testset = gvp.atom3d.TOYDataset2(test_dataset, label=args.toy)
        else:
            trainset = gvp.atom3d.TOYDataset(train_dataset, label=args.toy, connection=args.connect)
            valset = gvp.atom3d.TOYDataset(val_dataset, label=args.toy, connection=args.connect)
            testset = gvp.atom3d.TOYDataset(test_dataset, label=args.toy, connection=args.connect)
    else:
        if task == 'PSR':
            if args.model == 'molformer':
                trainset = gvp.atom3d.PSRDataset(LMDBDataset(data_path + 'train'), plm=args.plm)
                valset = gvp.atom3d.PSRDataset(LMDBDataset(data_path + 'val'), plm=args.plm)
                testset = gvp.atom3d.PSRDataset(LMDBDataset(data_path + 'test'), plm=args.plm)
                return trainset, valset, testset
            transform = gvp.atom3d.PSRTransform(plm=args.plm)
        elif task == 'LBA':
            if args.model == 'molformer':
                trainset = gvp.atom3d.LBADataset(LMDBDataset(data_path + 'train'), plm=args.plm)
                valset = gvp.atom3d.LBADataset(LMDBDataset(data_path + 'val'), plm=args.plm)
                testset = gvp.atom3d.LBADataset(LMDBDataset(data_path + 'test'), plm=args.plm)
                return trainset, valset, testset
            transform = gvp.atom3d.LBATransform(plm=args.plm)
        else:
            transform = {'MSP': gvp.atom3d.MSPTransform, 'LEP': gvp.atom3d.LEPTransform}[task]()
        trainset = LMDBDataset(data_path + 'train', transform=transform)
        valset = LMDBDataset(data_path + 'val', transform=transform)
        testset = LMDBDataset(data_path + 'test', transform=transform)
        print(len(trainset), len(valset), len(testset))
    return trainset, valset, testset

In [8]:
data_path = {'RES': 'atom3d-data/RES/raw/RES/data/', 'PPI': '../data/PPI/DIPS-split/data/', 'PSR': 'data/PSR/split-by-year/data/',
                'MSP': 'atom3d-data/MSP/splits/split-by-sequence-identity-30/data/', 'LEP': 'atom3d-data/LEP/splits/split-by-protein/data/',
                'LBA': f'data/LBA/split-by-sequence-identity-/data/', 'TOY': 'data/TOY/split-by-cath-topology/data/'} 
data_path = data_path['PPI']
dataset = LMDBDataset(data_path + 'test', transform=gvp.atom3d.PPITransform(plm=1, device='cpu'))
trainset, valset, testset = split_randomly(dataset)

model = gvp.atom3d.PPIModel(plm=1)

2023-01-28 16:06:11,361 INFO 32059: Splitting dataset with 15268 entries.
2023-01-28 16:06:11,362 INFO 32059: Size of the training set: 12216
2023-01-28 16:06:11,363 INFO 32059: Size of the validation set: 1526
2023-01-28 16:06:11,363 INFO 32059: Size of the test set: 1526


In [21]:
egnn = gvp.atom3d.BaseModel(plm=1)

In [18]:
t = trainset[1]
t

(Data(x=[313, 3], edge_index=[2, 3228], atoms=[313], edge_s=[3228, 16], edge_v=[3228, 1, 3], label=[313], plm=[313, 1280]),
 Data(x=[313, 3], edge_index=[2, 3192], atoms=[313], edge_s=[3192, 16], edge_v=[3192, 1, 3], label=[313], plm=[313, 1280]))

In [23]:
t[0]

Data(x=[313, 3], edge_index=[2, 3228], atoms=[313], edge_s=[3228, 16], edge_v=[3228, 1, 3], label=[313], plm=[313, 1280])

In [22]:
egnn(t[0])

AttributeError: 'NoneType' object has no attribute 'dim'

In [19]:
model(t).shape

torch.Size([626])