In [1]:
import os
import numpy as np
import torch
import torch.nn as nn

from dgllife.model import load_pretrained
from dgllife.utils import EarlyStopping, Meter, SMILESToBigraph
from torch.optim import Adam
from torch.utils.data import DataLoader

from utils import collate_molgraphs, load_model, predict
from utils import init_featurizer, mkdir_p, split_dataset, get_configure

源码作者将输入参数存于 args 字典，并使用 init_featurizer 函数对参数进行解析。
本例中仅仔细研究 GCN & canonical 方法的模型，其它模型可类比。

In [2]:
from dgllife.utils import CanonicalAtomFeaturizer

In [3]:
args = {
    'dataset': 'BBBP',
    'model': 'GCN',
    'featurizer_type': 'canonical',
    'pretrain': False,
    'split': 'scaffold',
    'split_ratio': '0.8,0.1,0.1',
    'metric': 'roc_auc_score',
    'num_epochs': 1000,
    'num_workers': 0,
    'print_every': 10,
    'result_path': 'classification_results',
    'device': torch.device('cpu'),
    # add by init_featurizer
    'node_featurizer': CanonicalAtomFeaturizer(),
    'edge_featurizer': None,
}

In [4]:
# make dir
if not os.path.exists(args['result_path']):
    os.makedirs(args['result_path'])

SMILES to Graph：对 `SMILESToBigraph` 包装后，成为一个新的函数，主要给出节点和边的编码方法

In [5]:
from dgllife.data import BBBP

smiles_to_g = SMILESToBigraph(add_self_loop=True, node_featurizer=args['node_featurizer'],
                              edge_featurizer=args['edge_featurizer'])

解析 `SMILESToBigraph`
- 将 SMILES 字符串转化为双向 `DGLGraphs` 对象并将其特征化
- 原子有 74 个特征（待查看）
- **需注意**，源码案例中 `add_self_loop` 设置为 True，如果改为False，则会在训练中出错

In [6]:
# 函数自带案例

from rdkit import Chem

def featurize_atoms(mol):
    feats = []
    for atom in mol.GetAtoms():
        feats.append(atom.GetAtomicNum())
    return {'atomic': torch.tensor(feats).reshape(-1, 1).float()}

def featurize_bonds(mol):
    feats = []
    bond_types = [Chem.rdchem.BondType.SINGLE,
                  Chem.rdchem.BondType.DOUBLE,
                  Chem.rdchem.BondType.TRIPLE,
                  Chem.rdchem.BondType.AROMATIC]
    for bond in mol.GetBonds():
        btype = bond_types.index(bond.GetBondType())
        feats.extend([btype, btype])
    return {'type': torch.tensor(feats).reshape(-1, 1).float()}

smi_to_g = SMILESToBigraph(
    node_featurizer=featurize_atoms,
    edge_featurizer=featurize_bonds)

g = smi_to_g('CCO')

print(g)
print(g.ndata)
print(g.edata)

Graph(num_nodes=3, num_edges=4,
      ndata_schemes={'atomic': Scheme(shape=(1,), dtype=torch.float32)}
      edata_schemes={'type': Scheme(shape=(1,), dtype=torch.float32)})
{'atomic': tensor([[6.],
        [8.],
        [6.]])}
{'type': tensor([[0.],
        [0.],
        [0.],
        [0.]])}


使用本例中的 `smiles_to_g` 函数编码 SMILES 后，结果如下：
- 返回 DGLGraph 对象，其包含 3 个节点（对应 3 个原子）和 4 个边（对应两个键，双向）
- 节点特征名为 'h'，其大小为 3 行 74 列，即每个原子由 74 个数字编码
- 边特征为空

In [7]:
g = smiles_to_g('CCO')

print(g)
print(g.ndata)
print(g.edata)

Graph(num_nodes=3, num_edges=7,
      ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
      edata_schemes={})
{'h': tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
         1., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.,
         0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.,

每个 dataset 的元素返回一个元组，元组中包含 4 个元素：
- SMILES
- DGLGraph 对象，包含节点数、边数、节点特征（名为 h 的 74 个特征）、边特征（空）
- Labels，dtype 为 float32，shape 为 task 个数
- masks，dtype 为 float32，表示在多任务学习中该图是否存在 Label

In [8]:
dataset = BBBP(smiles_to_graph=smiles_to_g, n_jobs=1)

Processing dgl graphs from scratch...
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Processing molecule 1000/2050
Processing molecule 2000/2050


In [9]:
dataset.n_tasks

1

In [10]:
dataset[0]

('[Cl].CC(C)NCC(O)COc1cccc2ccccc12',
 Graph(num_nodes=20, num_edges=60,
       ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
       edata_schemes={}),
 tensor([1.]),
 tensor([1.]))

In [11]:
len(dataset)

2039

## 划分训练集与测试集

In [12]:
# utils

a,b,c = map(float, '0.8,0.1,0.1'.split(','))
print(a, b, c)

0.8 0.1 0.1


In [13]:
from dgllife.utils import ScaffoldSplitter

# transfer string to float
train_ratio, val_ratio, test_ratio = map(float, args['split_ratio'].split(','))

# splitting
train_set, val_set, test_set = ScaffoldSplitter.train_val_test_split(
    dataset, frac_train=train_ratio, frac_val=val_ratio, frac_test=test_ratio
)

Start initializing RDKit molecule instances...
Creating RDKit molecule instance 1000/2039
Creating RDKit molecule instance 2000/2039
Start computing Bemis-Murcko scaffolds.
Computing Bemis-Murcko for compound 1000/2039
Computing Bemis-Murcko for compound 2000/2039


In [14]:
train_set

<dgl.data.utils.Subset at 0x7fa950169f40>

In [15]:
dataset

<dgllife.data.bbbp.BBBP at 0x7fa970abebe0>

## 载入模型参数

In [16]:
import json

with open('configures/BBBP/GCN_canonical.json') as f:
    config = json.load(f)

config

{'batch_size': 64,
 'batchnorm': False,
 'dropout': 0.0272564399565973,
 'gnn_hidden_feats': 256,
 'lr': 0.02020086171843634,
 'num_gnn_layers': 4,
 'patience': 30,
 'predictor_hidden_feats': 32,
 'residual': True,
 'weight_decay': 0.001168051063650801}

## 训练前的准备工作

- `CanonicalAtomFeaturizer() 具有一个 feat_size() 方法，返回它编码的原子特征个数`
- 创建一个整理函数 `collate_molgraphs()`，其输入参数为 DGL 数据集，输出我们期待的一系列数据

In [17]:
print(args['node_featurizer'].feat_size())
print(CanonicalAtomFeaturizer().feat_size())

74
74


In [18]:
config['in_node_feats'] = args['node_featurizer'].feat_size()

In [19]:
import dgl

def collate_molgraphs(data):
    if len(data[0]) == 3:
        smiles, graphs, labels = map(list, zip(*data))
    else:
        smiles, graphs, labels, masks = map(list, zip(*data))
    
    bg = dgl.batch(graphs)
    bg.set_n_initializer(dgl.init.zero_initializer)
    bg.set_e_initializer(dgl.init.zero_initializer)
    labels = torch.stack(labels, dim=0)
    
    if len(data[0]) == 3:
        masks = torch.ones(labels.shape)
    else:
        masks = torch.stack(masks, dim=0)
    
    return smiles, bg, labels, masks

In [20]:
_ = collate_molgraphs(train_set)
len(_)

4

### 对 collate_molgraphs 拆解

In [21]:
smiles, graphs, labels, masks = map(list, zip(*train_set))
print(len(smiles))
print(labels[0])
print(len(graphs))
print(graphs[0])

bg = dgl.batch(graphs)
print(bg)

bg.set_n_initializer(dgl.init.zero_initializer)
bg.set_e_initializer(dgl.init.zero_initializer)
print(bg)

print(labels[:5])
labels = torch.stack(labels, dim=0)
print(labels[:5])

1631
tensor([1.])
1631
Graph(num_nodes=23, num_edges=69,
      ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
      edata_schemes={})
Graph(num_nodes=36682, num_edges=115678,
      ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
      edata_schemes={})
Graph(num_nodes=36682, num_edges=115678,
      ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
      edata_schemes={})
[tensor([1.]), tensor([1.]), tensor([1.]), tensor([1.]), tensor([1.])]
tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.]])


In [22]:
# dataloader

train_loader = DataLoader(dataset=train_set, batch_size=config['batch_size'], shuffle=True,
                          collate_fn=collate_molgraphs, num_workers=0)
val_loader = DataLoader(dataset=val_set, batch_size=config['batch_size'], shuffle=True,
                        collate_fn=collate_molgraphs, num_workers=0)
test_loader = DataLoader(dataset=test_set, batch_size=config['batch_size'], shuffle=True,
                         collate_fn=collate_molgraphs, num_workers=0)

In [23]:
print(len(train_loader), len(val_loader), len(test_loader))

26 4 4


In [24]:
for data in train_loader:
    
    _ = data
    
    print(len(_))
    print(_[0][:3])
    print(_[1])
    print(_[2][:3])
    print(_[3][:3])
    
    break

4
['C1=C(OC)C(=CC2=C1C(=C(C)[NH]2)CCN4CCN(C3=CC=CC=C3OC)CC4)OC', 'C2=C(OCC1OC(NCC1)=S)C=CC=C2', 'CCC(C)(CC)OC(N)=O']
Graph(num_nodes=1514, num_edges=4756,
      ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
      edata_schemes={})
tensor([[1.],
        [1.],
        [1.]])
tensor([[1.],
        [1.],
        [1.]])


## 载入模型

In [25]:
config.update({'n_tasks': 1})
config

{'batch_size': 64,
 'batchnorm': False,
 'dropout': 0.0272564399565973,
 'gnn_hidden_feats': 256,
 'lr': 0.02020086171843634,
 'num_gnn_layers': 4,
 'patience': 30,
 'predictor_hidden_feats': 32,
 'residual': True,
 'weight_decay': 0.001168051063650801,
 'in_node_feats': 74,
 'n_tasks': 1}

In [26]:
from dgllife.model import GCNPredictor
import torch.nn.functional as F

def load_model(config):
    model = GCNPredictor(in_feats=config['in_node_feats'],
                         hidden_feats=[config['gnn_hidden_feats']] * config['num_gnn_layers'],
                         activation=[F.relu] * config['num_gnn_layers'],
                         residual=[config['residual']] * config['num_gnn_layers'],
                         batchnorm=[config['batchnorm']] * config['num_gnn_layers'],
                         dropout=[config['dropout']] * config['num_gnn_layers'],
                         predictor_hidden_feats=config['predictor_hidden_feats'],
                         predictor_dropout=config['dropout'],
                         n_tasks=config['n_tasks']
                        )
    return model

In [27]:
model = load_model(config).to(args['device'])
model

GCNPredictor(
  (gnn): GCN(
    (gnn_layers): ModuleList(
      (0): GCNLayer(
        (graph_conv): GraphConv(in=74, out=256, normalization=none, activation=<function relu at 0x7fa9a131baf0>)
        (dropout): Dropout(p=0.0272564399565973, inplace=False)
        (res_connection): Linear(in_features=74, out_features=256, bias=True)
      )
      (1): GCNLayer(
        (graph_conv): GraphConv(in=256, out=256, normalization=none, activation=<function relu at 0x7fa9a131baf0>)
        (dropout): Dropout(p=0.0272564399565973, inplace=False)
        (res_connection): Linear(in_features=256, out_features=256, bias=True)
      )
      (2): GCNLayer(
        (graph_conv): GraphConv(in=256, out=256, normalization=none, activation=<function relu at 0x7fa9a131baf0>)
        (dropout): Dropout(p=0.0272564399565973, inplace=False)
        (res_connection): Linear(in_features=256, out_features=256, bias=True)
      )
      (3): GCNLayer(
        (graph_conv): GraphConv(in=256, out=256, normalization

In [28]:
import torch.nn as nn

loss_criterion = nn.BCEWithLogitsLoss(reduction='none')
optimizer = Adam(model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])
stopper = EarlyStopping(patience=config['patience'], filename=args['result_path'] + '/model.pth',
                       metric=args['metric'])

For metric roc_auc_score, the higher the better


In [29]:
def predict(args, model, bg):
    bg = bg.to(args['device'])
    if args['edge_featurizer'] is None:
        node_feats = bg.ndata.pop('h').to(args['device'])
        return model(bg, node_feats)
    else:
        node_feats = bg.ndata.pop('h').to(args['device'])
        edge_feats = bg.edata.pop('h').to(args['device'])
        return model(bg, node_feats, edge_feats)

In [30]:
def run_a_train_epoch(args, epoch, model, data_loader, loss_criterion, optimizer):
    model.train()
    train_meter = Meter()
    for batch_id, batch_data in enumerate(data_loader):
        smiles, bg, labels, masks = batch_data
        if len(smiles) == 1:
            continue
        
        labels, masks = labels.to(args['device']), masks.to(args['device'])
        logits = predict(args, model, bg)
        # mask non-existing labels
        loss = (loss_criterion(logits, labels) * (masks != 0).float()).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_meter.update(logits, labels, masks)
        
        if batch_id % args['print_every'] == 0:
            print('epoch {:d}/{:d}, batch {}/{}, loss {:.3f}'.format(
                epoch+1, args['num_epochs'], batch_id+1, len(data_loader), loss.item()
            ))
        
    train_score = np.mean(train_meter.compute_metric(args['metric']))
    print('epoch {:d}/{:d}, training {} {:.3f}'.format(
        epoch+1, args['num_epochs'], args['metric'], train_score)
    )
        
def run_an_eval_epoch(args, model, data_loader):
    model.eval()
    eval_meter = Meter()
    with torch.no_grad():
        for batch_id, batch_data in enumerate(data_loader):
            smiles, bg, labels, masks = batch_data
            labels = labels.to(args['device'])
            logits = predict(args, model, bg)
            eval_meter.update(logits, labels, masks)
    return np.mean(eval_meter.compute_metric(args['metric']))

In [31]:
Meter()

<dgllife.utils.eval.Meter at 0x7fa970b10790>

In [32]:
test_bg = _[1]
test_bg.ndata.pop('h')

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [1., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.]])

## 开始训练

In [33]:
for epoch in range(args['num_epochs']):
    # Train
    run_a_train_epoch(args, epoch, model, train_loader, loss_criterion, optimizer)
    
    # Validation and Early stop
    val_score = run_an_eval_epoch(args, model, val_loader)
    early_stop = stopper.step(val_score, model)
    
    print('epoch {:d}/{:d}, validation {} {:.3f}, best validation {} {:.3f}'.format(
        epoch+1, args['num_epochs'], args['metric'], val_score, args['metric'], stopper.best_score
    ))
    
    if early_stop: break

epoch 1/1000, batch 1/26, loss 0.718
epoch 1/1000, batch 11/26, loss 0.483
epoch 1/1000, batch 21/26, loss 0.382
epoch 1/1000, training roc_auc_score 0.522
epoch 1/1000, validation roc_auc_score 0.727, best validation roc_auc_score 0.727
epoch 2/1000, batch 1/26, loss 0.417
epoch 2/1000, batch 11/26, loss 0.335
epoch 2/1000, batch 21/26, loss 0.543
epoch 2/1000, training roc_auc_score 0.668
epoch 2/1000, validation roc_auc_score 0.776, best validation roc_auc_score 0.776
epoch 3/1000, batch 1/26, loss 0.511
epoch 3/1000, batch 11/26, loss 0.466
epoch 3/1000, batch 21/26, loss 0.326
epoch 3/1000, training roc_auc_score 0.679
epoch 3/1000, validation roc_auc_score 0.782, best validation roc_auc_score 0.782
epoch 4/1000, batch 1/26, loss 0.421
epoch 4/1000, batch 11/26, loss 0.327
epoch 4/1000, batch 21/26, loss 0.620
epoch 4/1000, training roc_auc_score 0.684
EarlyStopping counter: 1 out of 30
epoch 4/1000, validation roc_auc_score 0.725, best validation roc_auc_score 0.782
epoch 5/1000,

epoch 32/1000, batch 11/26, loss 0.294
epoch 32/1000, batch 21/26, loss 0.254
epoch 32/1000, training roc_auc_score 0.775
EarlyStopping counter: 10 out of 30
epoch 32/1000, validation roc_auc_score 0.686, best validation roc_auc_score 0.924
epoch 33/1000, batch 1/26, loss 0.494
epoch 33/1000, batch 11/26, loss 0.465
epoch 33/1000, batch 21/26, loss 0.353
epoch 33/1000, training roc_auc_score 0.643
EarlyStopping counter: 11 out of 30
epoch 33/1000, validation roc_auc_score 0.913, best validation roc_auc_score 0.924
epoch 34/1000, batch 1/26, loss 0.251
epoch 34/1000, batch 11/26, loss 0.280
epoch 34/1000, batch 21/26, loss 0.219
epoch 34/1000, training roc_auc_score 0.794
EarlyStopping counter: 12 out of 30
epoch 34/1000, validation roc_auc_score 0.683, best validation roc_auc_score 0.924
epoch 35/1000, batch 1/26, loss 0.302
epoch 35/1000, batch 11/26, loss 0.409
epoch 35/1000, batch 21/26, loss 0.355
epoch 35/1000, training roc_auc_score 0.780
EarlyStopping counter: 13 out of 30
epoch

epoch 62/1000, batch 11/26, loss 0.285
epoch 62/1000, batch 21/26, loss 0.236
epoch 62/1000, training roc_auc_score 0.838
EarlyStopping counter: 3 out of 30
epoch 62/1000, validation roc_auc_score 0.939, best validation roc_auc_score 0.952
epoch 63/1000, batch 1/26, loss 0.408
epoch 63/1000, batch 11/26, loss 0.243
epoch 63/1000, batch 21/26, loss 0.334
epoch 63/1000, training roc_auc_score 0.838
EarlyStopping counter: 4 out of 30
epoch 63/1000, validation roc_auc_score 0.952, best validation roc_auc_score 0.952
epoch 64/1000, batch 1/26, loss 0.206
epoch 64/1000, batch 11/26, loss 0.224
epoch 64/1000, batch 21/26, loss 0.226
epoch 64/1000, training roc_auc_score 0.848
EarlyStopping counter: 5 out of 30
epoch 64/1000, validation roc_auc_score 0.952, best validation roc_auc_score 0.952
epoch 65/1000, batch 1/26, loss 0.217
epoch 65/1000, batch 11/26, loss 0.382
epoch 65/1000, batch 21/26, loss 0.366
epoch 65/1000, training roc_auc_score 0.838
EarlyStopping counter: 6 out of 30
epoch 65/

epoch 92/1000, batch 11/26, loss 0.247
epoch 92/1000, batch 21/26, loss 0.271
epoch 92/1000, training roc_auc_score 0.817
EarlyStopping counter: 15 out of 30
epoch 92/1000, validation roc_auc_score 0.944, best validation roc_auc_score 0.958
epoch 93/1000, batch 1/26, loss 0.264
epoch 93/1000, batch 11/26, loss 0.297
epoch 93/1000, batch 21/26, loss 0.463
epoch 93/1000, training roc_auc_score 0.829
EarlyStopping counter: 16 out of 30
epoch 93/1000, validation roc_auc_score 0.936, best validation roc_auc_score 0.958
epoch 94/1000, batch 1/26, loss 0.271
epoch 94/1000, batch 11/26, loss 0.412
epoch 94/1000, batch 21/26, loss 0.238
epoch 94/1000, training roc_auc_score 0.812
EarlyStopping counter: 17 out of 30
epoch 94/1000, validation roc_auc_score 0.950, best validation roc_auc_score 0.958
epoch 95/1000, batch 1/26, loss 0.199
epoch 95/1000, batch 11/26, loss 0.411
epoch 95/1000, batch 21/26, loss 0.215
epoch 95/1000, training roc_auc_score 0.833
EarlyStopping counter: 18 out of 30
epoch

In [34]:
val_score = run_an_eval_epoch(args, model, val_loader)
test_score = run_an_eval_epoch(args, model, test_loader)

In [35]:
print(val_score, test_score)

0.9375684556407449 0.5981385030864197


## 拆解 `run_a_train_epoch`

In [36]:
model = GCNPredictor(in_feats = 74,
                     hidden_feats = [256, 128, 64, 32],
                     activation = [F.relu] * 4,
                     residual = [True] * 4,
                     batchnorm = [False] * 4,
                     dropout = [0.0, 0.5, 0.25, 0.0],
                     predictor_hidden_feats = 32,
                     predictor_dropout = 0.25,
                     n_tasks = 1
                    )

model = model.to(args['device'])
loss_criterion = nn.BCEWithLogitsLoss(reduction='none')
optimizer = Adam(model.parameters(), lr=0.01, weight_decay=0.001)
stopper = EarlyStopping(patience=20, filename=args['result_path'] + '/model_custom.pth',
                       metric='roc_auc_score')

For metric roc_auc_score, the higher the better


In [37]:
demo_loader = DataLoader(train_set, batch_size=4, shuffle=True, drop_last=True, collate_fn=collate_molgraphs,
                         num_workers=0)

In [38]:
model.train()
train_meter = Meter()

for batch_id, batch_data in enumerate(demo_loader):
    smiles, bg, labels, masks = batch_data
    labels, masks = labels.to(args['device']), masks.to(args['device'])
    print(smiles)
    print(bg)
    print(labels)
    print(masks)
    
    # predict function
    bg = bg.to(args['device'])
    node_feats = bg.ndata.pop('h').to(args['device'])
    print(bg)
    logits = model(bg, node_feats)
    print(bg)
    print(logits)
    
    # mask non-existing labels
    print(loss_criterion(logits, labels))
    print(masks != 0)
    loss = (loss_criterion(logits, labels) * (masks != 0).float()).mean()
    print(loss)
    print(loss.item())
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    
    train_meter.update(logits, labels, masks)
    train_score = train_meter.compute_metric('roc_auc_score')
    print(train_score)
    train_score = np.mean(train_score)
    print(train_score)

    break

['NC(=N)NCCCOc1ccccc1', 'c1(CC(N2[C@H](CN(CC2)C(=O)C)C[N@]2CC[C@H](O)C2)=O)ccc(N(=O)=O)cc1', 'C1=C(Cl)C=CC3=C1N(C2=CC=CC=C2)C(CCN3)=O', 'CC1CC2C3CC(F)(F)C4=CC(=O)C=C[C@]4(C)C3(F)C(O)CC2(C)C1(O)C(=O)CO']
Graph(num_nodes=91, num_edges=287,
      ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
      edata_schemes={})
tensor([[1.],
        [0.],
        [1.],
        [1.]])
tensor([[1.],
        [1.],
        [1.],
        [1.]])
Graph(num_nodes=91, num_edges=287,
      ndata_schemes={}
      edata_schemes={})
Graph(num_nodes=91, num_edges=287,
      ndata_schemes={}
      edata_schemes={})
tensor([[-0.1110],
        [-0.3321],
        [ 0.3468],
        [-0.4268]], grad_fn=<AddmmBackward0>)
tensor([[0.7502],
        [0.5408],
        [0.5347],
        [0.9292]], grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor([[True],
        [True],
        [True],
        [True]])
tensor(0.6887, grad_fn=<MeanBackward0>)
0.6887178421020508
[0.6666666666666666]
0.6666666666666666


In [39]:
model.eval()
eval_meter = Meter()

with torch.no_grad():
    for batch_id, batch_data in enumerate(demo_loader):
        smiles, bg, labels, masks = batch_data
        labels = labels.to(args['device'])
        
        bg = bg.to(args['device'])
        node_feats = bg.ndata.pop('h').to(args['device'])
        logits = model(bg, node_feats)
        
        eval_meter.update(logits, labels, masks)
        if (batch_id+1) % 50 == 0:
            print('batch id {}, auc = {:.3f}'.format(batch_id+1, np.mean(eval_meter.compute_metric('roc_auc_score'))))
        
    print(np.mean(eval_meter.compute_metric('roc_auc_score')))

batch id 50, auc = 0.501
batch id 100, auc = 0.543
batch id 150, auc = 0.581
batch id 200, auc = 0.601
batch id 250, auc = 0.611
batch id 300, auc = 0.611
batch id 350, auc = 0.611
batch id 400, auc = 0.614
0.6148307310585317
