In [1]:
import torch
from torchdrug import datasets
import pickle

# dataset = datasets.ZINC250k("~/molecule-datasets/", kekulize=True,
#                             atom_feature="symbol")
# with open("torchdrug_data/zinc250k.pkl", "wb") as fout:
#     pickle.dump(dataset, fout)
    
with open("torchdrug_data/zinc250k.pkl", "rb") as fin:
    dataset = pickle.load(fin)

In [2]:
from torchdrug import core, models, tasks

model = models.RGCN(input_dim=dataset.node_feature_dim,
                    num_relation=dataset.num_bond_type,
                    hidden_dims=[256, 256, 256, 256], batch_norm=False)
task = tasks.GCPNGeneration(model, dataset.atom_types, max_edge_unroll=12,
                            max_node=38, criterion="nll")

In [3]:
from torch import nn, optim
optimizer = optim.Adam(task.parameters(), lr = 1e-3)
solver = core.Engine(task, dataset, None, None, optimizer,
                     gpus=(0,), batch_size=128, log_interval=10)

solver.train(num_epoch=1)
solver.save("torchdrug_data/graphgeneration/gcpn_zinc250k_1epoch.pkl")

20:09:40   Preprocess training set
20:09:41   {'batch_size': 128,
 'class': 'core.Engine',
 'gpus': (0,),
 'gradient_interval': 1,
 'log_interval': 10,
 'logger': 'logging',
 'num_worker': 0,
 'optimizer': {'amsgrad': False,
               'betas': (0.9, 0.999),
               'capturable': False,
               'class': 'optim.Adam',
               'differentiable': False,
               'eps': 1e-08,
               'foreach': None,
               'fused': None,
               'lr': 0.001,
               'maximize': False,
               'weight_decay': 0},
 'scheduler': None,
 'task': {'agent_update_interval': 10,
          'atom_types': [6, 7, 8, 9, 15, 16, 17, 35, 53],
          'baseline_momentum': 0.9,
          'class': 'tasks.GCPNGeneration',
          'criterion': 'nll',
          'gamma': 0.9,
          'hidden_dim_mlp': 128,
          'max_edge_unroll': 12,
          'max_node': 38,
          'model': {'activation': 'relu',
                    'batch_norm': False,
          



20:09:45   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
20:09:45   edge acc: 0.588088
20:09:45   edge loss: 1.09507
20:09:45   node1 acc: 0.168966
20:09:45   node1 loss: 2.27437
20:09:45   node2 acc: 0
20:09:45   node2 loss: 2.9571
20:09:45   stop acc: 0.897227
20:09:45   stop bce loss: 0.686563
20:09:45   total loss: 7.0131




20:09:47   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
20:09:47   edge acc: 0.740025
20:09:47   edge loss: 0.643667
20:09:47   node1 acc: 0.259663
20:09:47   node1 loss: 2.20115
20:09:47   node2 acc: 0.662407
20:09:47   node2 loss: 1.90984
20:09:47   stop acc: 0.742806
20:09:47   stop bce loss: 0.60739
20:09:47   total loss: 5.36204
20:09:48   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
20:09:48   edge acc: 0.75852
20:09:48   edge loss: 0.540533
20:09:48   node1 acc: 0.24927
20:09:48   node1 loss: 2.0794
20:09:48   node2 acc: 0.657254
20:09:48   node2 loss: 1.42557
20:09:48   stop acc: 0.685883
20:09:48   stop bce loss: 0.486943
20:09:48   total loss: 4.53244
20:09:50   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
20:09:50   edge acc: 0.769637
20:09:50   edge loss: 0.463028
20:09:50   node1 acc: 0.330953
20:09:50   node1 loss: 1.92287
20:09:50   node2 acc: 0.650109
20:09:50   node2 loss: 1.34563
20:09:50   stop acc: 0.895193
20:09:50   stop bce loss: 0.489291
20:09:50   total loss: 4.22082
20:09:52   >>>>>>>>>>>>>>>>>>>>>>>>

In [4]:
%matplotlib inline

solver.load("torchdrug_data/graphgeneration/gcpn_zinc250k_1epoch.pkl")
results = task.generate(num_sample=32, max_resample=5)
print(results)
print(results.to_smiles())

20:15:26   Load checkpoint from torchdrug_data/graphgeneration/gcpn_zinc250k_1epoch.pkl
20:15:27   1 / 30 molecules are invalid even after 5 resampling
PackedMolecule(batch_size=31, num_atoms=[6, 11, 13, ..., 22, 25, 24], num_bonds=[10, 22, 24, ..., 48, 54, 54], device='cuda:0')
['C=CCC(C)C', 'CCC(=O)C1=CC=CC=C1C', 'CCCC(CNC(=O)NC)C(C)=O', 'C=[SH]NCC(CC)NC(=O)CCC', 'C=[SH]NCC(C)CN1CCC1CC', 'CCCC1=NOC(C(C)CC)=C1', 'CCCC(CC)(CC)C(=O)NC(C)C', 'CCCC(=CN)C(=O)CN1C=CN=C1', 'COC1=CC=CC=C1C(C)=NNC=O', 'CCCC1=NC2=CC=C(N)C=C2S1', 'NC1=C(C=C2C=N2)C(F)=CC=C1CO', 'COC1=CC=C2C=C1C=C(C#N)CO2', 'CCCCC(C)OC(=O)C1=CC=CC=C1', 'O=C(C1=CC=C(Cl)C=C1)N1CCCCC1', 'C=C(C)C(=O)C1=C(CC)C=CC(OCC)=C1', 'CCCC(C=C1C=C1)=CC(CN)=C(N)N(C)N', 'CCCC(=O)NC1=CC=C(F)C(C(C)NC)=C1', 'CCCN(CC)C(CC)CNC(=O)C1=CC=CS1', 'CCSC1=NC2=NC=CC=C3C=CC=C2N31', 'CCN1CCCC(CNCCC2=CC=CC=C2C)C1', 'CC(O)C(O)C(N)(C1=CC=CC=C1)C1=CC=CC=C1', 'COC1=CC=C(CN(C(C)=O)C2=CC=CC=C2)C=C1', 'CCCSC1=NC=CN1C(=O)C1=CC=CC(C)=C1C', 'C=C1NC(C2=CC=CC=C2O)CCC1(C)N(CC)

In [5]:
import torch
from torchdrug import core, datasets, models, tasks
from torch import nn, optim
from collections import defaultdict

dataset = datasets.ZINC250k("~/molecule-datasets/", kekulize=True,
                            atom_feature="symbol")

model = models.RGCN(input_dim=dataset.node_feature_dim,
                    num_relation=dataset.num_bond_type,
                    hidden_dims=[256, 256, 256, 256], batch_norm=False)
task = tasks.GCPNGeneration(model, dataset.atom_types,
                            max_edge_unroll=12, max_node=38,
                            task="qed", criterion=("ppo", "nll"),
                            reward_temperature=1,
                            agent_update_interval=3, gamma=0.9)


optimizer = optim.Adam(task.parameters(), lr=1e-5)
solver = core.Engine(task, dataset, None, None, optimizer,
                     gpus=(0,), batch_size=16, log_interval=10)

solver.load("torchdrug_data/graphgeneration/gcpn_zinc250k_1epoch.pkl",
            load_optimizer=False)

# RL finetuning
solver.train(num_epoch=10)
solver.save("torchdrug_data/graphgeneration/gcpn_zinc250k_1epoch_finetune.pkl")

Loading /data/yulai/molecule-datasets/250k_rndm_zinc_drugs_clean_3.csv:  50%|█████     | 249456/498911 [00:02<00:02, 102490.04it/s]
Constructing molecules from SMILES: 100%|██████████| 249455/249455 [05:59<00:00, 693.93it/s] 


20:21:32   Preprocess training set
20:21:33   {'batch_size': 16,
 'class': 'core.Engine',
 'gpus': (0,),
 'gradient_interval': 1,
 'log_interval': 10,
 'logger': 'logging',
 'num_worker': 0,
 'optimizer': {'amsgrad': False,
               'betas': (0.9, 0.999),
               'capturable': False,
               'class': 'optim.Adam',
               'differentiable': False,
               'eps': 1e-08,
               'foreach': None,
               'fused': None,
               'lr': 1e-05,
               'maximize': False,
               'weight_decay': 0},
 'scheduler': None,
 'task': {'agent_update_interval': 3,
          'atom_types': [6, 7, 8, 9, 15, 16, 17, 35, 53],
          'baseline_momentum': 0.9,
          'class': 'tasks.GCPNGeneration',
          'criterion': ('ppo', 'nll'),
          'gamma': 0.9,
          'hidden_dim_mlp': 128,
          'max_edge_unroll': 12,
          'max_node': 38,
          'model': {'activation': 'relu',
                    'batch_norm': False,
   

: 