In [1]:
import random
import numpy as np
import torch

from torch.utils.data import Dataset

from rdkit import Chem
from dgllife.model.model_zoo.dgmg import MoleculeEnv
from dgl.data.utils import download, _get_dgl_url
from torch.utils.data import DataLoader

## args

In [2]:
args = {
    'seed': 0,
    'warmup_epochs': 10,
    'dataset': 'ChEMBL',
    'order': 'random',
    'train_file': None,
    'val_file': None,
    'log_dir': 'training_results',
    
    'num_processes': 8,
    'master_ip': '127.0.0.1',
    'master_port': '12345',
    
    # function configure(args)
    'node_hidden_size': 128,
    'num_propagation_rounds': 2,
    'lr': 1e-4,
    'dropout': 0.2,
    'nepochs': 400,
    'batch_size': 1,
}

In [3]:
import os
import datetime

def get_date_postfix():
    """Get a data based postfix for directory name.
    
    Returns
    -------
    post_fix : str
    """
    dt = datetime.datetime.now()
    post_fix = '{}_{:02d}-{:02d}-{:02d}'.format(
        dt.date(), dt.hour, dt.minute, dt.second)
    
    return post_fix


def setup_log_dir(args):
    """Name and create directory fot logging.
    
    Returns
    -------
    log_dir : str
        Path for logging directory
    """
    data_postfix = get_date_postfix()
    log_dir = os.path.join(
        args['log_dir'],
        '{}_{}_{}'.format(
            args['dataset'],
            args['order'],
            data_postfix)
        )
    mkdir_p(log_dir)
    return log_dir

def mkdir_p(path, log=True):
    """Create a directory for the specified path.
    
    Parameters
    ----------
    path : str
        Path name
    log : bool
        Whether to print result for directory creation
    """
    import errno
    try:
        os.makedirs(path)
        if log:
            print('Create directory {}'.format(path))
    except OSError as exc:
        if exc.errno == errno.EEXIST and ps.path.isdir(path) and log:
            print('Directory {} already exists.'.format(path))
        else:
            raise

In [4]:
log_dir = setup_log_dir(args)
args['log_dir'] = log_dir

Create directory training_results/ChEMBL_random_2023-12-08_11-35-20


In [5]:
args['checkpoint_dir'] = os.path.join(log_dir, 'checkpoint.pth')

## set random seed

In [6]:
def set_random_seed(seed):
    '''Fix random seed for reproducible results.
    
    Paramters
    ---------
    seed: int
        Random seed to use
    '''
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

Paramters of `torch.set_num_threads()`:
- int, sets the number of threads used for intraop parallelism of CPU

In [7]:
torch.set_num_threads(1)

## Setup dataset and data loader

In [8]:
class MoleculeDataset(object):
    """Initialize and split the dataset.

    Parameters
    ----------
    dataset : str
        Dataset name
    order : None or str
        Order to extract a decision sequence for generating a molecule. Default to be None.
    modes : None or list
        List of subsets to use, which can contain 'train', 'val', corresponding to
        training and validation. Default to be None.
    subset_id : int
        With multiprocess training, we partition the training set into multiple subsets and
        each process will use one subset only. This subset_id corresponds to subprocess id.
    n_subsets : int
        With multiprocess training, this corresponds to the number of total subprocesses.
    """
    def __init__(self, dataset, order=None, modes=None, subset_id=0, n_subsets=1):
        super(MoleculeDataset, self).__init__()

        if modes is None:
            modes = []
        else:
            assert order is not None, 'An order should be specified for extracting ' \
                                      'decision sequences.'

        assert order in ['random', 'canonical', None], \
            "Unexpected order option to get sequences of graph generation decisions"
        assert len(set(modes) - {'train', 'val'}) == 0, \
            "modes should be a list, representing a subset of ['train', 'val']"

        self.dataset = dataset
        self.order = order
        self.modes = modes
        self.subset_id = subset_id
        self.n_subsets = n_subsets
        self._setup()

    def collate(self, samples):
        """PyTorch's approach to batch multiple samples.

        For auto-regressive generative models, we process one sample at a time.

        Parameters
        ----------
        samples : list
            A list of length 1 that consists of decision sequence to generate a molecule.

        Returns
        -------
        list
            List of 2-tuples, a decision sequence to generate a molecule
        """
        assert len(samples) == 1
        return samples[0]

    def _create_a_subset(self, smiles):
        """Create a dataset from a subset of smiles.

        Parameters
        ----------
        smiles : list of str
            List of molecules in SMILES format
        """
        # We evenly divide the smiles into multiple susbets with multiprocess
        subset_size = len(smiles) // self.n_subsets
        return Subset(smiles[self.subset_id * subset_size: (self.subset_id + 1) * subset_size],
                      self.order, self.env)

    def _setup(self):
        """
        1. Instantiate an MDP environment for molecule generation
        2. Download the dataset, which is a file of SMILES
        3. Create subsets for training and validation
        """
        if self.dataset == 'ChEMBL':
            # For new datasets, get_atom_and_bond_types can be used to
            # identify the atom and bond types in them.
            self.atom_types = ['O', 'Cl', 'C', 'S', 'F', 'Br', 'N']
            self.bond_types = [Chem.rdchem.BondType.SINGLE,
                               Chem.rdchem.BondType.DOUBLE,
                               Chem.rdchem.BondType.TRIPLE]

        elif self.dataset == 'ZINC':
            self.atom_types = ['Br', 'S', 'C', 'P', 'N', 'O', 'F', 'Cl', 'I']
            self.bond_types = [Chem.rdchem.BondType.SINGLE,
                               Chem.rdchem.BondType.DOUBLE,
                               Chem.rdchem.BondType.TRIPLE]

        else:
            path_to_atom_and_bond_types = '_'.join([self.dataset, 'atom_and_bond_types.pkl'])
            with open(path_to_atom_and_bond_types, 'rb') as f:
                type_info = pickle.load(f)
            self.atom_types = type_info['atom_types']
            self.bond_types = type_info['bond_types']
        self.env = MoleculeEnv(self.atom_types, self.bond_types)

        dataset_prefix = self._dataset_prefix()

        if 'train' in self.modes:
            fname = '_'.join([dataset_prefix, 'train.txt'])
            download_data(self.dataset, fname)
            smiles = load_smiles_from_file(fname)
            self.train_set = self._create_a_subset(smiles)

        if 'val' in self.modes:
            fname = '_'.join([dataset_prefix, 'val.txt'])
            download_data(self.dataset, fname)
            smiles = load_smiles_from_file(fname)
            # We evenly divide the smiles into multiple susbets with multiprocess
            self.val_set = self._create_a_subset(smiles)

    def _dataset_prefix(self):
        """Get the prefix for the data files of supported datasets.

        Returns
        -------
        str
            Prefix for dataset file name
        """
        return '_'.join([self.dataset, 'DGMG'])

class Subset(Dataset):
    """A set of molecules which can be used for training, validation, test.

    Parameters
    ----------
    smiles : list
        List of SMILES for the dataset
    order : str
        Specifies how decision sequences for molecule generation
        are obtained, can be either "random" or "canonical"
    env : MoleculeEnv object
        MDP environment for generating molecules
    """
    def __init__(self, smiles, order, env):
        super(Subset, self).__init__()
        self.smiles = smiles
        self.order = order
        self.env = env
        self._setup()

    def _setup(self):
        """Convert SMILES into rdkit molecule objects.

        Decision sequences are extracted if we use a fixed order.
        """
        smiles_ = []
        mols = []
        for s in self.smiles:
            m = smiles_to_standard_mol(s)
            if m is None:
                continue
            smiles_.append(s)
            mols.append(m)
        self.smiles = smiles_
        self.mols = mols

        if self.order is 'random':
            return

        self.decisions = []
        for m in self.mols:
            self.decisions.append(
                self.env.get_decision_sequence(m, list(range(m.GetNumAtoms())))
            )

    def __len__(self):
        """Get number of molecules in the dataset."""
        return len(self.mols)

    def __getitem__(self, item):
        """Get the decision sequence for generating the molecule indexed by item."""
        if self.order == 'canonical':
            return self.decisions[item]
        else:
            m = self.mols[item]
            nodes = list(range(m.GetNumAtoms()))
            random.shuffle(nodes)
            return self.env.get_decision_sequence(m, nodes)

def download_data(dataset, fname):
    """Download dataset if built-in support exists
    
    Parameters
    ----------
    dataset : str
        Dataset name
    fname : str
        Name of dataset file
    """
    if dataset not in ['ChEMBL', 'ZINC']:
        return
    
    data_path = fname
    download(_get_dgl_url(os.path.join('dataset', fname)), path=data_path)

    
def load_smiles_from_file(f_name):
    """Load dataset into a list of SMILES

    Parameters
    ----------
    f_name : str
        Path to a file of molecules, where each line of the file
        is a molecule in SMILES format.

    Returns
    -------
    smiles : list of str
        List of molecules as SMILES
    """
    with open(f_name, 'r') as f:
        smiles = f.read().splitlines()
    return smiles

def smiles_to_standard_mol(s):
    """Convert SMILES to a standard molecule.

    Parameters
    ----------
    s : str
        SMILES

    Returns
    -------
    Chem.rdchem.Mol
        Standardized molecule
    """
    mol = Chem.MolFromSmiles(s)
    return standardize_mol(mol)

def standardize_mol(mol):
    """Standardize molecule to avoid false novel molecule.

    Kekulize and deprotonate molecules to avoid false novel molecules.

    In addition to deprotonation, we also kekulize molecules to avoid
    explicit Hs in the SMILES. Otherwise we will get false novel molecules
    as well. For example, DGMG can only generate
    O=S(=O)(NC1=CC=CC(C(F)(F)F)=C1)C1=CNC=N1
    from
    O=S(=O)(Nc1cccc(C(F)(F)F)c1)c1c[nH]cn1.

    One downside is that we remove all explicit aromatic rings and to
    explicitly predict aromatic bond might make the learning easier for
    the model.
    """
    reactions = initialize_neuralization_reactions()
    Chem.Kekulize(mol, clearAromaticFlags=True)
    mol = neutralize_charges(mol, reactions)
    return mol

def initialize_neuralization_reactions():
    """Reference neuralization reactions

    Code adapted from RDKit Cookbook, by Hans de Winter.
    """
    patts = (
        # Imidazoles
        ('[n+;H]', 'n'),
        # Amines
        ('[N+;!H0]', 'N'),
        # Carboxylic acids and alcohols
        ('[$([O-]);!$([O-][#7])]', 'O'),
        # Thiols
        ('[S-;X1]', 'S'),
        # Sulfonamides
        ('[$([N-;X2]S(=O)=O)]', 'N'),
        # Enamines
        ('[$([N-;X2][C,N]=C)]', 'N'),
        # Tetrazoles
        ('[n-]', '[n]'),
        # Sulfoxides
        ('[$([S-]=O)]', 'S'),
        # Amides
        ('[$([N-]C=O)]', 'N'),
    )
    return [(Chem.MolFromSmarts(x), Chem.MolFromSmiles(y, False)) for x, y in patts]

def neutralize_charges(mol, reactions=None):
    """Deprotonation for molecules.

    Code adapted from RDKit Cookbook, by Hans de Winter.

    DGMG currently cannot generate protonated molecules.
    For example, it can only generate
    CC(C)(C)CC1CCC[NH+]1Cc1nnc(-c2ccccc2F)o1
    from
    CC(C)(C)CC1CCCN1Cc1nnc(-c2ccccc2F)o1
    even with correct decisions.

    Deprotonation is therefore an important step to avoid
    false novel molecules.

    Parameters
    ----------
    mol : Chem.rdchem.Mol
    reactions : list of 2-tuples
        Rules for deprotonation

    Returns
    -------
    mol : Chem.rdchem.Mol
        Deprotonated molecule
    """
    if reactions is None:
        reactions = initialize_neuralization_reactions()
    for i, (reactant, product) in enumerate(reactions):
        while mol.HasSubstructMatch(reactant):
            rms = AllChem.ReplaceSubstructs(mol, reactant, product)
            mol = rms[0]
    return mol

  if self.order is 'random':


In [9]:
rank = 0

dataset = MoleculeDataset(args['dataset'], args['order'], ['train', 'val'],
                          subset_id=rank, n_subsets=args['num_processes'])

Downloading ChEMBL_DGMG_train.txt from https://data.dgl.ai/dataset/ChEMBL_DGMG_train.txt...
Downloading ChEMBL_DGMG_val.txt from https://data.dgl.ai/dataset/ChEMBL_DGMG_val.txt...


In [10]:
train_loader = DataLoader(dataset.train_set, batch_size=args['batch_size'],
                              shuffle=True, collate_fn=dataset.collate)
val_loader = DataLoader(dataset.val_set, batch_size=args['batch_size'],
                            shuffle=True, collate_fn=dataset.collate)

## Initialize model

In [11]:
from dgllife.model import DGMG

model = DGMG(atom_types = dataset.atom_types,
             bond_types = dataset.bond_types,
             node_hidden_size=args['node_hidden_size'],
             num_prop_rounds=args['num_propagation_rounds'],
             dropout=args['dropout'])

In [12]:
model

DGMG(
  (graph_embed): GraphEmbed(
    (node_gating): Sequential(
      (0): Linear(in_features=128, out_features=1, bias=True)
      (1): Sigmoid()
    )
    (node_to_graph): Linear(in_features=128, out_features=256, bias=True)
  )
  (graph_prop): GraphProp(
    (message_funcs): ModuleList(
      (0): Linear(in_features=259, out_features=256, bias=True)
      (1): Linear(in_features=259, out_features=256, bias=True)
    )
    (node_update_funcs): ModuleList(
      (0): GRUCell(256, 128)
      (1): GRUCell(256, 128)
    )
  )
  (add_node_agent): AddNode(
    (add_node): Sequential(
      (0): Linear(in_features=256, out_features=256, bias=True)
      (1): Dropout(p=0.2, inplace=False)
      (2): Linear(in_features=256, out_features=8, bias=True)
    )
    (node_type_embed): Embedding(7, 128)
    (initialize_hv): Linear(in_features=384, out_features=128, bias=True)
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (add_edge_agent): AddEdge(
    (add_edge): Sequential(
      (0): Linear

In [13]:
import torch.nn as nn
from torch.optim import Adam

class Optimizer(nn.Module):
    """Wrapper for optimization

    Parameters
    ----------
    lr : float
        Initial learning rate
    optimizer
        model optimizer
    """
    def __init__(self, lr, optimizer):
        super(Optimizer, self).__init__()
        self.lr = lr
        self.optimizer = optimizer
        self._reset()

    def _reset(self):
        self.optimizer.zero_grad()

    def backward_and_step(self, loss):
        """Backward and update model.

        Parameters
        ----------
        loss : torch.tensor consisting of a float only
        """
        loss.backward()
        self.optimizer.step()
        self._reset()

    def decay_lr(self, decay_rate=0.99):
        """Decay learning rate.

        Parameters
        ----------
        decay_rate : float
            Multiply the current learning rate by the decay_rate
        """
        self.lr *= decay_rate
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

In [14]:
optimizer = Optimizer(args['lr'], Adam(model.parameters(), lr=args['lr']))

In [15]:
best_val_prob = 0


## Training

In [16]:
import torch

def evaluate(epoch, model, data_loader, printer):
    model.eval()
    batch_size = data_loader.batch_size
    total_log_prob = 0
    with torch.no_grad():
        for i, data in enumerate(data_loader):
            log_prob = model(actions=data, compute_log_prob=True).detach()
            total_log_prob -= log_prob
            if printer is not None:
                prob = log_prob.detach().exp()
                printer.update(epoch + 1, - log_prob / batch_size, prob / batch_size)
    return total_log_prob / len(data_loader)

In [18]:
for epoch in range(1):
    model.train()
    
    for i, data in enumerate(train_loader):
        log_prob = model(actions=data, compute_log_prob=True)
        prob = log_prob.detach().exp()
        
        loss_averaged = - log_prob
        prob_averaged = prob
        optimizer.backward_and_step(loss_averaged)
        
        if (i % 100) == 0:
            print('process steps {} / {}'.format(i+1, len(train_loader)))

process steps 1 / 16353
process steps 101 / 16353
process steps 201 / 16353
process steps 301 / 16353
process steps 401 / 16353
process steps 501 / 16353
process steps 601 / 16353
process steps 701 / 16353
process steps 801 / 16353
process steps 901 / 16353
process steps 1001 / 16353
process steps 1101 / 16353
process steps 1201 / 16353
process steps 1301 / 16353
process steps 1401 / 16353
process steps 1501 / 16353
process steps 1601 / 16353
process steps 1701 / 16353
process steps 1801 / 16353
process steps 1901 / 16353
process steps 2001 / 16353
process steps 2101 / 16353
process steps 2201 / 16353
process steps 2301 / 16353
process steps 2401 / 16353
process steps 2501 / 16353
process steps 2601 / 16353
process steps 2701 / 16353
process steps 2801 / 16353
process steps 2901 / 16353
process steps 3001 / 16353
process steps 3101 / 16353
process steps 3201 / 16353
process steps 3301 / 16353
process steps 3401 / 16353
process steps 3501 / 16353
process steps 3601 / 16353
process steps

In [20]:
loss_averaged

tensor(99.3213, grad_fn=<NegBackward0>)

In [21]:
prob_averaged

tensor(7.2868e-44)

In [22]:
val_log_prob = evaluate(0, model, val_loader, None)

In [23]:
val_log_prob

tensor(90.0992)

In [25]:
(- val_log_prob).exp().item()

7.419861355615263e-40

In [26]:
torch.save({'model_state_dict': model.state_dict()}, args['checkpoint_dir'])

## multiprocessing