In [1]:
import os
import numpy as np
import torch
import torch.nn as nn

from dgllife.model import load_pretrained
from dgllife.utils import EarlyStopping, Meter, SMILESToBigraph
from torch.optim import Adam
from torch.utils.data import DataLoader

from utils import collate_molgraphs, load_model, predict
from utils import init_featurizer, mkdir_p, split_dataset, get_configure

源码作者将输入参数存于 `args` 字典，并调用 `init_featurizer` 向其中补充 `node_featurizer` 与 `edge_featurizer` 两项；下方直接手动写出补充后的结果。
本例仅详细研究 GCN + canonical 特征化的模型，其它模型可类比。

In [2]:
from dgllife.utils import CanonicalAtomFeaturizer

In [3]:
# Run settings. The original script receives these as input arguments; here they are written out by hand.
args = dict(
    dataset='BBBP',
    model='GCN',
    featurizer_type='canonical',
    pretrain=False,
    split='scaffold',
    split_ratio='0.8,0.1,0.1',
    metric='roc_auc_score',
    num_epochs=1000,
    num_workers=0,
    print_every=10,
    result_path='classification_results',
    device=torch.device('cpu'),  # e.g. torch.device('cuda:0') to train on a GPU
    # The two featurizer entries are normally added by init_featurizer (imported from utils).
    node_featurizer=CanonicalAtomFeaturizer(),
    edge_featurizer=None,
)

In [4]:
# Create the output directory for checkpoints/results.
# exist_ok=True avoids the check-then-create race and makes re-running this cell a no-op.
os.makedirs(args['result_path'], exist_ok=True)

SMILES to Graph：用指定的节点、边特征化方法实例化 `SMILESToBigraph`，得到可调用对象 `smiles_to_g`，用于把 SMILES 字符串转换为带特征的图。

In [5]:
from dgllife.data import BBBP

# SMILES -> featurized bidirected graph, using the featurizers chosen in `args`; no self-loops.
smiles_to_g = SMILESToBigraph(
    node_featurizer=args['node_featurizer'],
    edge_featurizer=args['edge_featurizer'],
    add_self_loop=False,
)

解析 `SMILESToBigraph`
- 将 SMILES 字符串转化为双向 `DGLGraph` 对象，并对节点（原子）和边（化学键）进行特征化
- 使用 `CanonicalAtomFeaturizer` 时，每个原子被编码为 74 维特征（见下文 `feat_size()` 的输出）

In [6]:
# 函数自带案例

from rdkit import Chem

def featurize_atoms(mol):
    """Toy atom featurizer: a single feature per atom, its atomic number.

    Args:
        mol: an RDKit molecule (anything exposing GetAtoms() whose items have GetAtomicNum()).

    Returns:
        dict mapping 'atomic' to a float32 tensor of shape (num_atoms, 1).
    """
    atomic_numbers = [atom.GetAtomicNum() for atom in mol.GetAtoms()]
    column = torch.tensor(atomic_numbers).float()
    # Add the trailing feature dimension: (num_atoms,) -> (num_atoms, 1).
    return {'atomic': column.unsqueeze(-1)}

def featurize_bonds(mol):
    """Toy bond featurizer: encode each bond's type as an integer index.

    SINGLE, DOUBLE, TRIPLE and AROMATIC map to 0, 1, 2 and 3. Each bond's code is
    emitted twice, one for each of the two directed edges a bond becomes in the
    bidirected graph. Any other bond type makes list.index raise ValueError.

    Args:
        mol: an RDKit molecule.

    Returns:
        dict mapping 'type' to a float32 tensor of shape (2 * num_bonds, 1).
    """
    bond_type = Chem.rdchem.BondType
    type_order = [bond_type.SINGLE, bond_type.DOUBLE, bond_type.TRIPLE, bond_type.AROMATIC]
    codes = []
    for bond in mol.GetBonds():
        code = type_order.index(bond.GetBondType())
        codes += [code, code]
    return {'type': torch.tensor(codes).reshape(-1, 1).float()}

# Build a converter from the two toy featurizers above and run it on ethanol ('CCO').
smi_to_g = SMILESToBigraph(node_featurizer=featurize_atoms, edge_featurizer=featurize_bonds)
g = smi_to_g('CCO')

# Show the graph summary, then the node-feature and edge-feature dicts.
for shown in (g, g.ndata, g.edata):
    print(shown)

Graph(num_nodes=3, num_edges=4,
      ndata_schemes={'atomic': Scheme(shape=(1,), dtype=torch.float32)}
      edata_schemes={'type': Scheme(shape=(1,), dtype=torch.float32)})
{'atomic': tensor([[6.],
        [8.],
        [6.]])}
{'type': tensor([[0.],
        [0.],
        [0.],
        [0.]])}


使用本例中的 `smiles_to_g` 函数编码 SMILES 后，结果如下：
- 返回 DGLGraph 对象，其包含 3 个节点（对应 3 个原子）和 4 个边（对应两个键，双向）
- 节点特征名为 'h'，其大小为 3 行 74 列，即每个原子由 74 个数字编码
- 边特征为空

In [7]:
# Same molecule, now featurized with this notebook's converter (canonical atom features).
g = smiles_to_g('CCO')
for shown in (g, g.ndata, g.edata):
    print(shown)

Graph(num_nodes=3, num_edges=4,
      ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
      edata_schemes={})
{'h': tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
         1., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.,
         0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.,

dataset 中的每个元素是一个包含 4 个元素的元组：
- SMILES 字符串
- DGLGraph 对象，包含节点数、边数、节点特征（名为 `h`，每个原子 74 维）、边特征（空）
- labels：dtype 为 float32，形状为 `(n_tasks,)`
- masks：dtype 为 float32，形状与 labels 相同，表示在多任务学习中该分子是否具有对应任务的标签

In [8]:
# Building the dataset featurizes every molecule (see the progress log below).
# Derive the parallelism from args instead of hard-coding it: num_workers == 0 means
# "single process", which BBBP expresses as n_jobs=1.
dataset = BBBP(smiles_to_graph=smiles_to_g,
               n_jobs=1 if args['num_workers'] == 0 else args['num_workers'])

Processing dgl graphs from scratch...
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Processing molecule 1000/2050
Processing molecule 2000/2050


In [9]:
# Number of prediction tasks, i.e. labels per molecule.
dataset.n_tasks

1

In [10]:
# Inspect one sample: (SMILES, DGLGraph, labels, masks).
dataset[0]

('[Cl].CC(C)NCC(O)COc1cccc2ccccc12',
 Graph(num_nodes=20, num_edges=40,
       ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
       edata_schemes={}),
 tensor([1.]),
 tensor([1.]))

In [11]:
# Number of samples. This is fewer than the molecules processed above, presumably
# because the SMILES reported as "Invalid mol found" are dropped.
len(dataset)

2039

## 划分训练集、验证集与测试集

In [12]:
# Quick check of how a "train,val,test" ratio string is parsed into floats.
a, b, c = [float(token) for token in '0.8,0.1,0.1'.split(',')]
print(a, b, c)

0.8 0.1 0.1


In [13]:
from dgllife.utils import ScaffoldSplitter

# Parse the "train,val,test" fractions from the comma-separated setting.
ratios = [float(part) for part in args['split_ratio'].split(',')]
train_ratio, val_ratio, test_ratio = ratios

# Scaffold split (it computes Bemis-Murcko scaffolds, see the log below).
# NOTE: args['split'] is not consulted here; this cell always uses the scaffold splitter.
train_set, val_set, test_set = ScaffoldSplitter.train_val_test_split(
    dataset,
    frac_train=train_ratio,
    frac_val=val_ratio,
    frac_test=test_ratio,
)

Start initializing RDKit molecule instances...
Creating RDKit molecule instance 1000/2039
Creating RDKit molecule instance 2000/2039
Start computing Bemis-Murcko scaffolds.
Computing Bemis-Murcko for compound 1000/2039
Computing Bemis-Murcko for compound 2000/2039


In [14]:
# A dgl Subset: a view of `dataset` restricted to the training indices.
train_set

<dgl.data.utils.Subset at 0x7fd3247868b0>

In [15]:
# The full BBBP dataset object, for comparison with the subset above.
dataset

<dgllife.data.bbbp.BBBP at 0x7fd2f869c1c0>

## 载入模型参数

In [16]:
import json

# Load the tuned hyperparameters for this dataset / model / featurizer combination.
# Building the path from `args` keeps it in sync with the settings above
# (for the current args this is configures/BBBP/GCN_canonical.json).
config_path = os.path.join('configures', args['dataset'],
                           '{}_{}.json'.format(args['model'], args['featurizer_type']))
with open(config_path) as f:
    config = json.load(f)

config

{'batch_size': 64,
 'batchnorm': False,
 'dropout': 0.0272564399565973,
 'gnn_hidden_feats': 256,
 'lr': 0.02020086171843634,
 'num_gnn_layers': 4,
 'patience': 30,
 'predictor_hidden_feats': 32,
 'residual': True,
 'weight_decay': 0.001168051063650801}

## 训练前的准备工作

- `CanonicalAtomFeaturizer()` 具有一个 `feat_size()` 方法，返回它编码的原子特征个数
- 创建一个整理函数 `collate_molgraphs()`，作为 `DataLoader` 的 `collate_fn`：其输入为一个 batch 的样本列表（每个样本即上文所述的元组），输出 `(smiles, 批处理后的图, labels, masks)`

In [21]:
# Compare the feature width of the configured atom featurizer with a fresh CanonicalAtomFeaturizer.
for atom_featurizer in (args['node_featurizer'], CanonicalAtomFeaturizer()):
    print(atom_featurizer.feat_size())

74
74


In [22]:
# The GCN's input width must match the atom featurizer's output width.
config.update(in_node_feats=args['node_featurizer'].feat_size())

In [25]:
import dgl

def collate_molgraphs(data):
    """Collate a list of dataset samples into a single batch.

    Args:
        data: non-empty sequence of samples, each either (smiles, graph, labels)
            or (smiles, graph, labels, masks).

    Returns:
        (smiles, batched_graph, labels, masks). smiles is a list, labels/masks are
        stacked along a new leading batch dimension. When the samples carry no masks,
        masks is all ones with the same shape as labels.
    """
    has_masks = len(data[0]) != 3
    columns = [list(column) for column in zip(*data)]
    if has_masks:
        smiles, graphs, labels, masks = columns
    else:
        smiles, graphs, labels = columns

    batched_graph = dgl.batch(graphs)
    # Features for nodes/edges added later are zero-initialized.
    batched_graph.set_n_initializer(dgl.init.zero_initializer)
    batched_graph.set_e_initializer(dgl.init.zero_initializer)

    labels = torch.stack(labels, dim=0)
    masks = torch.stack(masks, dim=0) if has_masks else torch.ones(labels.shape)
    return smiles, batched_graph, labels, masks

In [32]:
# Sanity check: collate the *entire* training subset as one batch and confirm that
# collate_molgraphs returns a 4-tuple (smiles, batched graph, labels, masks).
# NOTE(review): `_` is also IPython's last-output variable and is read again by a later
# cell (`_[1]`); a descriptive name would be less fragile.
_ = collate_molgraphs(train_set)
len(_)

4

In [37]:
# dataloader
def make_loader(subset, shuffle):
    """Build a DataLoader over `subset` that batches samples with collate_molgraphs.

    Batch size comes from `config` and worker count from `args`, so both are tuned in one place.
    """
    return DataLoader(dataset=subset, batch_size=config['batch_size'], shuffle=shuffle,
                      collate_fn=collate_molgraphs, num_workers=args['num_workers'])

# Only training batches need random order. Validation/test are evaluated as a whole,
# so a fixed order keeps evaluation deterministic.
train_loader = make_loader(train_set, shuffle=True)
val_loader = make_loader(val_set, shuffle=False)
test_loader = make_loader(test_set, shuffle=False)

In [38]:
# Number of batches per split (train, val, test).
print(*(len(loader) for loader in (train_loader, val_loader, test_loader)))

26 4 4


## 载入模型

In [50]:
# Take the number of prediction tasks from the dataset instead of hard-coding it.
config['n_tasks'] = dataset.n_tasks
config

{'batch_size': 64,
 'batchnorm': False,
 'dropout': 0.0272564399565973,
 'gnn_hidden_feats': 256,
 'lr': 0.02020086171843634,
 'num_gnn_layers': 4,
 'patience': 30,
 'predictor_hidden_feats': 32,
 'residual': True,
 'weight_decay': 0.001168051063650801,
 'in_node_feats': 74,
 'n_task': 1,
 'n_tasks': 1}

In [51]:
from dgllife.model import GCNPredictor
import torch.nn.functional as F

def load_model(config):
    """Build a dgllife GCNPredictor from a hyperparameter dict.

    Every GNN layer gets the same width, activation (ReLU), residual flag, batchnorm
    flag and dropout, so each per-layer argument is one value repeated
    config['num_gnn_layers'] times. The readout predictor reuses config['dropout'].

    Note: this definition shadows the `load_model` imported from `utils` at the top.

    Args:
        config: dict with keys 'in_node_feats', 'gnn_hidden_feats', 'num_gnn_layers',
            'residual', 'batchnorm', 'dropout', 'predictor_hidden_feats', 'n_tasks'.

    Returns:
        A newly constructed GCNPredictor.
    """
    num_layers = config['num_gnn_layers']

    def per_layer(value):
        # One copy of `value` for every GNN layer.
        return [value] * num_layers

    return GCNPredictor(
        in_feats=config['in_node_feats'],
        hidden_feats=per_layer(config['gnn_hidden_feats']),
        activation=per_layer(F.relu),
        residual=per_layer(config['residual']),
        batchnorm=per_layer(config['batchnorm']),
        dropout=per_layer(config['dropout']),
        predictor_hidden_feats=config['predictor_hidden_feats'],
        predictor_dropout=config['dropout'],
        n_tasks=config['n_tasks'],
    )

In [52]:
# Build the GCN from the hyperparameters, then move its parameters to the configured device.
model = load_model(config)
model = model.to(args['device'])
model

GCNPredictor(
  (gnn): GCN(
    (gnn_layers): ModuleList(
      (0): GCNLayer(
        (graph_conv): GraphConv(in=74, out=256, normalization=none, activation=<function relu at 0x7fd320f2daf0>)
        (dropout): Dropout(p=0.0272564399565973, inplace=False)
        (res_connection): Linear(in_features=74, out_features=256, bias=True)
      )
      (1): GCNLayer(
        (graph_conv): GraphConv(in=256, out=256, normalization=none, activation=<function relu at 0x7fd320f2daf0>)
        (dropout): Dropout(p=0.0272564399565973, inplace=False)
        (res_connection): Linear(in_features=256, out_features=256, bias=True)
      )
      (2): GCNLayer(
        (graph_conv): GraphConv(in=256, out=256, normalization=none, activation=<function relu at 0x7fd320f2daf0>)
        (dropout): Dropout(p=0.0272564399565973, inplace=False)
        (res_connection): Linear(in_features=256, out_features=256, bias=True)
      )
      (3): GCNLayer(
        (graph_conv): GraphConv(in=256, out=256, normalization

In [53]:
import torch.nn as nn

# reduction='none' keeps one loss value per (sample, task), intended to be weighted by the label masks.
loss_criterion = nn.BCEWithLogitsLoss(reduction='none')
optimizer = Adam(model.parameters(), weight_decay=config['weight_decay'], lr=config['lr'])
# Early stopping tracks args['metric'] and checkpoints the model under the result directory.
stopper = EarlyStopping(
    patience=config['patience'],
    filename=f"{args['result_path']}/model.pth",
    metric=args['metric'],
)

For metric roc_auc_score, the higher the better


In [67]:
def predict(args, model, bg, node_feat_name='h', edge_feat_name='e'):
    """Run `model` on a batched graph, feeding it the features stored on the graph.

    Args:
        args: settings dict; reads 'device' and 'edge_featurizer'.
        model: callable taking (graph, node_feats), or (graph, node_feats, edge_feats)
            when an edge featurizer is configured.
        bg: batched graph, e.g. from collate_molgraphs.
        node_feat_name: key of the node features in bg.ndata ('h' in this notebook).
        edge_feat_name: key of the edge features in bg.edata. Defaults to 'e', the
            field dgllife's bond featurizers write to by default (the previous code
            read 'h' here, which raises KeyError for such graphs).

    Returns:
        Whatever `model` returns (the logits).

    Note:
        Features are removed with pop(). If bg.to() returns the same graph object,
        the caller's graph loses them too, so the same batch cannot be fed twice.
    """
    bg = bg.to(args['device'])
    node_feats = bg.ndata.pop(node_feat_name).to(args['device'])
    if args['edge_featurizer'] is None:
        return model(bg, node_feats)
    edge_feats = bg.edata.pop(edge_feat_name).to(args['device'])
    return model(bg, node_feats, edge_feats)

In [None]:
def run_a_train_epoch(args, epoch, model, data_loader, loss_criterion, optimizer):
    """Train `model` for one pass over `data_loader` and report the training metric.

    Args:
        args: settings dict; reads 'device', 'print_every', 'num_epochs', 'metric'
            (and, through predict, 'edge_featurizer').
        epoch: zero-based epoch index, used only for logging.
        model: the network to train; switched to train mode here.
        data_loader: yields (smiles, batched_graph, labels, masks) batches.
        loss_criterion: element-wise loss, e.g. BCEWithLogitsLoss(reduction='none').
        optimizer: optimizer over the model's parameters.

    Returns:
        Mean of the per-task training scores for this epoch.
    """
    model.train()
    train_meter = Meter()
    for batch_id, batch_data in enumerate(data_loader):
        smiles, bg, labels, masks = batch_data
        if len(smiles) == 1:
            # Skip single-molecule batches, presumably because batch-norm layers
            # cannot be trained on a single sample.
            continue

        labels, masks = labels.to(args['device']), masks.to(args['device'])
        logits = predict(args, model, bg)
        # Zero the loss of entries whose label is missing (mask == 0), then average.
        loss = (loss_criterion(logits, labels) * (masks != 0).float()).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_meter.update(logits, labels, masks)
        if batch_id % args['print_every'] == 0:
            print('epoch {:d}/{:d}, batch {:d}/{:d}, loss {:.4f}'.format(
                epoch + 1, args['num_epochs'], batch_id + 1, len(data_loader), loss.item()))
    train_score = np.mean(train_meter.compute_metric(args['metric']))
    print('epoch {:d}/{:d}, training {} {:.4f}'.format(
        epoch + 1, args['num_epochs'], args['metric'], train_score))
    return train_score

In [54]:
# dgllife's metric accumulator, instantiated here only to inspect it;
# run_a_train_epoch creates a fresh one per epoch.
Meter()

<dgllife.utils.eval.Meter at 0x7fd2c155db80>

In [64]:
# Inspect the node features of the batched training graph collated earlier.
# Read them by key rather than pop(): popping deletes 'h' from the shared graph,
# so re-running the cell (or later feeding this graph to predict) would raise KeyError.
bg = _[1]
bg.ndata['h']

tensor([[1., 0., 0.,  ..., 1., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 1., 0., 0.],
        ...,
        [1., 0., 0.,  ..., 1., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.]])