In [1]:
import os
import numpy as np
import torch
import torch.nn as nn

from dgllife.model import load_pretrained
from dgllife.utils import EarlyStopping, Meter, SMILESToBigraph
from torch.optim import Adam
from torch.utils.data import DataLoader

from utils import collate_molgraphs, load_model, predict
from utils import init_featurizer, mkdir_p, split_dataset, get_configure

源码作者将输入参数存于 args 字典，并使用 init_featurizer 函数对参数进行解析。
本例中仅仔细研究 GCN & canonical 方法的模型，其它模型可类比。

In [2]:
from dgllife.utils import CanonicalAtomFeaturizer

In [3]:
# Run configuration, kept as a plain dict that the later cells read from.
args = {
    'dataset': 'BBBP',
    'model': 'GCN',
    'featurizer_type': 'canonical',
    'pretrain': False,
    'split': 'scaffold',
    # Comma-separated train/val/test fractions; parsed into floats before splitting.
    'split_ratio': '0.8,0.1,0.1',
    'metric': 'roc_auc_score',
    'num_epochs': 1000,
    'num_workers': 0,
    'print_every': 10,
    'result_path': 'classification_results',
    'device': torch.device('cpu'),
    # The entries below are normally added by init_featurizer; here they are set by hand.
    'node_featurizer': CanonicalAtomFeaturizer(),
    'edge_featurizer': None,
}

In [4]:
# Create the results directory. exist_ok=True keeps the cell safe to re-run and
# avoids the check-then-create race of a separate os.path.exists() test.
os.makedirs(args['result_path'], exist_ok=True)

## SMILES to Graph



In [5]:
from dgllife.data import BBBP

# SMILES-to-graph converter used for the dataset: no self-loops are added, and the
# node/edge featurizers are taken from args (the edge featurizer is None there).
smiles_to_g = SMILESToBigraph(add_self_loop=False, node_featurizer=args['node_featurizer'],
                              edge_featurizer=args['edge_featurizer'])

解析 `SMILESToBigraph`
- 将 SMILES 字符串转化为双向 `DGLGraph` 对象并将其特征化
- 原子有 74 个特征（待查看）

In [6]:
from rdkit import Chem

def featurize_atoms(mol):
    """Build a single-column node feature holding each atom's atomic number.

    Parameters
    ----------
    mol : object
        Molecule exposing ``GetAtoms()``; every returned atom must expose
        ``GetAtomicNum()``.

    Returns
    -------
    dict
        ``{'atomic': tensor}`` where the tensor is float32 with one row per
        atom and a single column, in the order ``GetAtoms()`` yields the atoms.
    """
    atomic_numbers = [atom.GetAtomicNum() for atom in mol.GetAtoms()]
    # Shape the flat list into a column before casting to float32.
    column = torch.tensor(atomic_numbers).reshape(-1, 1)
    return {'atomic': column.float()}

def featurize_bonds(mol):
    """Build a single-column edge feature holding an integer code per bond type.

    The code is the position of the bond's type in
    ``[SINGLE, DOUBLE, TRIPLE, AROMATIC]`` (so 0 to 3).

    Parameters
    ----------
    mol : object
        Molecule exposing ``GetBonds()``; every returned bond must expose
        ``GetBondType()``.

    Returns
    -------
    dict
        ``{'type': tensor}`` where the tensor is float32 with a single column
        and two rows per bond.

    Raises
    ------
    ValueError
        If a bond's type is not one of the four listed types.
    """
    bond_type = Chem.rdchem.BondType
    known_types = [bond_type.SINGLE, bond_type.DOUBLE,
                   bond_type.TRIPLE, bond_type.AROMATIC]

    codes = []
    for bond in mol.GetBonds():
        code = known_types.index(bond.GetBondType())
        # Stored twice per bond: the bidirected graph has one edge per direction.
        codes += [code, code]

    return {'type': torch.tensor(codes).reshape(-1, 1).float()}

# Converter that uses the two hand-written featurizers defined earlier in this cell.
smi_to_g = SMILESToBigraph(
    node_featurizer=featurize_atoms,
    edge_featurizer=featurize_bonds)

# Convert a small example molecule and inspect the resulting graph.
g = smi_to_g('CCO')

print(g)
print(g.ndata)  # node features, stored under the 'atomic' key
print(g.edata)  # edge features, stored under the 'type' key

Graph(num_nodes=3, num_edges=4,
      ndata_schemes={'atomic': Scheme(shape=(1,), dtype=torch.float32)}
      edata_schemes={'type': Scheme(shape=(1,), dtype=torch.float32)})
{'atomic': tensor([[6.],
        [8.],
        [6.]])}
{'type': tensor([[0.],
        [0.],
        [0.],
        [0.]])}


In [7]:
# Same molecule, this time through the converter configured from args.
# NOTE: this rebinds `g`, replacing the graph built with the custom featurizers.
g = smiles_to_g('CCO')

print(g)
print(g.ndata)
print(g.edata)

Graph(num_nodes=3, num_edges=4,
      ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
      edata_schemes={})
{'h': tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
         1., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.,
         0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.,

In [9]:
# Load BBBP, converting each SMILES with smiles_to_g (n_jobs=1).
dataset = BBBP(smiles_to_graph=smiles_to_g, n_jobs=1)

Processing dgl graphs from scratch...
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Invalid mol found
Processing molecule 1000/2050
Processing molecule 2000/2050


In [12]:
# Number of prediction tasks in the dataset (displays 1 here).
dataset.n_tasks

1

每个 dataset 的元素返回一个元组，元组中包含 4 个元素：
- SMILES
- DGLGraph 对象，包含节点数、边数、节点特征（名为 h 的 74 个特征）、边特征（空）
- Labels，dtype 为 float32，shape 为 task 个数
- masks，dtype 为 float32，表示在多任务学习中该图是否存在 Label

In [20]:
# First sample, displayed as a (SMILES, graph, labels, masks) tuple.
dataset[0]

('[Cl].CC(C)NCC(O)COc1cccc2ccccc12',
 Graph(num_nodes=20, num_edges=40,
       ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
       edata_schemes={}),
 tensor([1.]),
 tensor([1.]))

In [21]:
# Size of the loaded dataset.
len(dataset)

2039

## 划分训练集、验证集与测试集

In [26]:
# utils

# Scratch check: how a 'train,val,test' ratio string unpacks into three floats.
a, b, c = (float(part) for part in '0.8,0.1,0.1'.split(','))
print(a, b, c)

0.8 0.1 0.1


In [27]:
from dgllife.utils import ScaffoldSplitter

# Parse the comma-separated 'train,val,test' fractions from the config string.
train_ratio, val_ratio, test_ratio = (
    float(fraction) for fraction in args['split_ratio'].split(',')
)

# Scaffold-based split of the dataset into train / validation / test subsets.
train_set, val_set, test_set = ScaffoldSplitter.train_val_test_split(
    dataset,
    frac_train=train_ratio,
    frac_val=val_ratio,
    frac_test=test_ratio,
)

Start initializing RDKit molecule instances...
Creating RDKit molecule instance 1000/2039
Creating RDKit molecule instance 2000/2039
Start computing Bemis-Murcko scaffolds.
Computing Bemis-Murcko for compound 1000/2039
Computing Bemis-Murcko for compound 2000/2039


In [32]:
# Show what kind of object the splitter returned for the training part.
train_set

<dgl.data.utils.Subset at 0x7fca2e6a0c10>

In [33]:
# For comparison: the repr of the full, unsplit dataset object.
dataset

<dgllife.data.bbbp.BBBP at 0x7fca2c0d00a0>

## 载入模型参数

In [36]:
import json

# Load the hyperparameter file for the GCN + canonical featurizer setup on BBBP.
# JSON is UTF-8 by specification, so the encoding is given explicitly instead of
# being left to the platform's locale default.
with open('configures/BBBP/GCN_canonical.json', encoding='utf-8') as f:
    config = json.load(f)

config

{'batch_size': 64,
 'batchnorm': False,
 'dropout': 0.0272564399565973,
 'gnn_hidden_feats': 256,
 'lr': 0.02020086171843634,
 'num_gnn_layers': 4,
 'patience': 30,
 'predictor_hidden_feats': 32,
 'residual': True,
 'weight_decay': 0.001168051063650801}

## 训练前的准备工作

In [38]:
# Length of the per-atom feature vector produced by the node featurizer (displays 74).
args['node_featurizer'].feat_size()

74

In [40]:
# Add the input feature size to the model config; it is taken from the featurizer
# because the JSON file loaded above does not contain it.
config['in_node_feats'] = args['node_featurizer'].feat_size()

In [41]:
def collate_molgraphs(data):
    """Merge a list of dataset samples into one batch (DataLoader ``collate_fn``).

    This definition shadows the ``collate_molgraphs`` imported from ``utils``.

    Parameters
    ----------
    data : list of tuple
        Each item is either ``(smiles, graph, labels)`` or
        ``(smiles, graph, labels, masks)``. Only the first item is inspected
        to decide which layout is in use, so all items must share it.

    Returns
    -------
    smiles : list
        The SMILES entries of the samples, in input order.
    bg : batched graph
        Result of ``dgl.batch`` over the sample graphs, with zero initializers
        set for node and edge features.
    labels : torch.Tensor
        Sample labels stacked along a new leading dimension.
    masks : torch.Tensor
        Sample masks stacked along a new leading dimension, or an all-ones
        tensor with the shape of ``labels`` when the samples carry no masks.
    """
    # Imported locally because the notebook never imports dgl at the top level;
    # without this the function fails with NameError on a fresh kernel.
    import dgl

    has_masks = len(data[0]) != 3
    if has_masks:
        smiles, graphs, labels, masks = map(list, zip(*data))
    else:
        smiles, graphs, labels = map(list, zip(*data))

    bg = dgl.batch(graphs)
    bg.set_n_initializer(dgl.init.zero_initializer)
    bg.set_e_initializer(dgl.init.zero_initializer)
    labels = torch.stack(labels, dim=0)

    if has_masks:
        masks = torch.stack(masks, dim=0)
    else:
        masks = torch.ones(labels.shape)

    # Previously the function ended here without a return statement, so every
    # batch produced through the DataLoader was None.
    return smiles, bg, labels, masks

In [42]:
# Training DataLoader: shuffled mini-batches merged by collate_molgraphs.
# The worker count is read from args so it stays in sync with the run config
# (args['num_workers'] is 0, the same value that was hard-coded here before).

train_loader = DataLoader(dataset=train_set, batch_size=config['batch_size'], shuffle=True,
                          collate_fn=collate_molgraphs, num_workers=args['num_workers'])