In [1]:
from torch_geometric.loader import DataLoader
from rdkit import Chem
import torch.nn.functional as F
import torch
import pandas as pd
from math import sqrt
import sys
sys.path.insert(0, '/home/shenwanxiang/Research/bidd-clsar/')

In [9]:
from clsar.dataset import LSSNS
from clsar.feature import Gen39AtomFeatures
from clsar.model.model import ACANet_GCN, ACANet_GIN, ACANet_GAT, ACANet_PNA, get_deg  # model

In [7]:
# List every dataset key exposed by LSSNS.
# The loop variable is deliberately NOT called `dataset_name`: a `for` target
# stays bound in the notebook's global namespace after the loop, and
# `dataset_name` is the variable the configuration cell uses to select the
# dataset to train on. Re-running this cell must not silently change it.
for available_name in LSSNS.names.keys():
    print(available_name)

ido1
plk1
rip2
braf
usp7
phgdh
pkci
rxfp1
mglur2


In [11]:
# --- Experiment configuration (read by the cells below) ---------------------

# Dataset class and the key of the dataset to load from it.
# NOTE(review): the original comment here read "MoleculeNet", presumably an
# alternative dataset class that can be swapped in — confirm.
Dataset = LSSNS
dataset_name = 'phgdh'

# Optimisation settings.
batch_size = 128      # graphs per mini-batch
epochs = 800          # number of training epochs
lr = 10 ** -4         # Adam learning rate

In [15]:
# Seed torch's global RNG before anything random happens so that a
# "Restart & Run All" reproduces the same split, mini-batch order and weight
# initialisation.
# NOTE(review): assumes LSSNS.shuffle() draws from torch's global RNG the way
# torch_geometric's Dataset.shuffle() does — confirm. GPU kernels may still
# introduce small run-to-run differences.
seed = 42
torch.manual_seed(seed)

# Featuriser applied when the dataset is first processed; it also reports the
# node-feature width the model has to accept.
pre_transform = Gen39AtomFeatures()
in_channels = pre_transform.in_channels
path = '/tmp/data1'

# Build the dataset with the featuriser above and shuffle it once, so the
# slices below form a random split.
dataset = Dataset(path, name=dataset_name,
                  pre_transform=pre_transform).shuffle()

# train, valid, test splitting: first fifth -> validation, second fifth ->
# test, everything else -> training.
N = len(dataset) // 5
val_dataset = dataset[:N]
test_dataset = dataset[N:2 * N]
train_dataset = dataset[2 * N:]


# Only the training loader reshuffles every epoch; evaluation order is fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): edge_dim=10 presumably matches the edge-feature width that
# the featuriser produces — confirm against clsar.feature.
model = ACANet_GCN(in_channels=in_channels, out_channels=1,
                   edge_dim=10).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                             weight_decay=10**-5)


def train():
    """Run one optimisation epoch over the global ``train_loader``.

    Uses the notebook-level ``model``, ``optimizer`` and ``device``.

    Returns:
        float: square root of the graph-weighted mean of the per-batch MSE
        losses, i.e. an epoch-level training RMSE.
    """
    # Explicitly enter training mode so the epoch never runs in whatever mode
    # earlier code (e.g. an evaluation pass) happened to leave the model in.
    model.train()

    total_loss = total_examples = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        # The model returns (prediction, embedding); only the prediction is
        # used for the regression loss.
        out, embed = model(data.x.float(), data.edge_index,
                           data.edge_attr, data.batch)
        loss = F.mse_loss(out, data.y)
        loss.backward()
        optimizer.step()
        # mse_loss averages over the batch, so weight by the batch size to
        # get a correct mean when the last batch is smaller.
        total_loss += loss.item() * data.num_graphs
        total_examples += data.num_graphs
    return sqrt(total_loss / total_examples)


@torch.no_grad()
def test(loader):
    mse = []
    for data in loader:
        data = data.to(device)
        out, embed = model(data.x.float(), data.edge_index,
                    data.edge_attr, data.batch)
        mse.append(F.mse_loss(out, data.y, reduction='none').cpu())
    return float(torch.cat(mse, dim=0).mean().sqrt())


# Per-epoch metrics, one dict per epoch, ready for pd.DataFrame(history1).
history1 = []
# range's upper bound is exclusive, so `epochs + 1` is needed to really run
# `epochs` epochs numbered 1..epochs.
for epoch in range(1, epochs + 1):
    train_rmse = train()
    val_rmse = test(val_loader)
    test_rmse = test(test_loader)
    print(f'Epoch: {epoch:03d}, Loss: {train_rmse:.4f} Val: {val_rmse:.4f} '
          f'Test: {test_rmse:.4f}')

    history1.append({'Epoch': epoch, 'train_rmse': train_rmse,
                     'val_rmse': val_rmse, 'test_rmse': test_rmse})

# Optional: persist the learning curves (left disabled, as in the original).
#pd.DataFrame(history1).to_csv('./test/%s_%s.csv' % (dataset_name, in_channels))

Epoch: 001, Loss: 5.9247 Val: 6.8136 Test: 5.9953
Epoch: 002, Loss: 5.8854 Val: 6.7684 Test: 5.9551
Epoch: 003, Loss: 5.8477 Val: 6.7252 Test: 5.9177
Epoch: 004, Loss: 5.8136 Val: 6.6893 Test: 5.8865
Epoch: 005, Loss: 5.7839 Val: 6.6579 Test: 5.8596
Epoch: 006, Loss: 5.7583 Val: 6.6396 Test: 5.8415
Epoch: 007, Loss: 5.7402 Val: 6.6221 Test: 5.8252
Epoch: 008, Loss: 5.7232 Val: 6.6031 Test: 5.8076
Epoch: 009, Loss: 5.7046 Val: 6.5825 Test: 5.7879
Epoch: 010, Loss: 5.6844 Val: 6.5613 Test: 5.7677
Epoch: 011, Loss: 5.6637 Val: 6.5396 Test: 5.7471
Epoch: 012, Loss: 5.6424 Val: 6.5168 Test: 5.7253
Epoch: 013, Loss: 5.6203 Val: 6.4937 Test: 5.7032
Epoch: 014, Loss: 5.5979 Val: 6.4711 Test: 5.6813
Epoch: 015, Loss: 5.5759 Val: 6.4475 Test: 5.6585
Epoch: 016, Loss: 5.5531 Val: 6.4232 Test: 5.6348
Epoch: 017, Loss: 5.5291 Val: 6.3986 Test: 5.6099
Epoch: 018, Loss: 5.5038 Val: 6.3728 Test: 5.5835
Epoch: 019, Loss: 5.4765 Val: 6.3464 Test: 5.5562
Epoch: 020, Loss: 5.4483 Val: 6.3202 Test: 5.5289


In [16]:
from torchviz import make_dot

# One mini-batch is enough to trace the computation graph; take the first one
# instead of iterating over the whole loader.
viz_batch = next(iter(train_loader))

# Use dedicated names for the throw-away CPU model and its inputs/outputs.
# Rebinding the global `model` here would leave `optimizer`, train() and
# test() pointing at mismatched objects if the training cell is run again.
viz_model = ACANet_GCN(in_channels=in_channels, out_channels=1,
                       edge_dim=10)
viz_outputs = viz_model(viz_batch.x.float(), viz_batch.edge_index,
                        viz_batch.edge_attr, viz_batch.batch)

# Last expression of the cell: render() returns the path of the written PNG,
# which the notebook displays as the cell output.
make_dot(viz_outputs, params=dict(viz_model.named_parameters())
         ).render("./test/model_torchviz", format="png")

'test/model_torchviz.png'