In [1]:
from rdkit import Chem

In [5]:
from torch.utils.data import Dataset
from dgllife.model.model_zoo.dgmg import MoleculeEnv

class MoleculeDataset(object):
    """Initialize and split the dataset

    On construction this sets up the molecule-generation environment for the
    chosen dataset and, for every requested mode, loads the corresponding
    SMILES file and wraps this process' share of it in a ``Subset``
    (exposed as ``self.train_set`` / ``self.val_set``).

    Parameters
    ----------
    dataset : str
        Dataset name. 'ChEMBL' and 'ZINC' have built-in atom and bond types;
        for any other name the types are read from the pickle file
        ``<dataset>_atom_and_bond_types.pkl``.
    order : None (default) or str
        Order to extract a decision sequence for generating a molecule,
        either 'random' or 'canonical'. Must be given whenever ``modes`` is.
    modes : None (default) or list
        List of subsets to use, which can contain 'train', 'val'.
    subset_id : int
        With multiprocess training, we partition the training set into multiple subsets
        and each process will use only one subset only.
        This subset_id corresponds to subprocess id.
    n_subsets : int
        With multiprocess training, this corresponds to the number of total subprocesses.
    """
    def __init__(self, dataset, order=None, modes=None, subset_id=0, n_subsets=1):
        super(MoleculeDataset, self).__init__()

        # Identity comparison: ``== None`` would dispatch to a custom
        # ``__eq__`` on whatever object was passed as ``modes``.
        if modes is None:
            modes = []
        else:
            assert order is not None, 'An order should be specified for extracting ' \
                                      'decision sequences.'

        assert order in ['random', 'canonical', None], 'Unexpected order option to get sequences of graph generation decisions'
        assert len(set(modes) - {'train', 'val'}) == 0, "modes should be a list, representing a subset of ['train', 'val']"

        self.dataset = dataset
        self.order = order
        self.modes = modes
        self.subset_id = subset_id
        self.n_subsets = n_subsets
        self._setup()

    def collate(self, samples):
        """PyTorch's approach to batch multiple samples.

        For auto-regressive generative models, we process one sample at a time.

        Parameters
        ----------
        samples : list
            A list of length 1 that consists of decision sequence to generate a molecule.

        Returns
        -------
        list
            List of 2-tuples, a decision sequence to generate a molecule.
        """
        assert len(samples) == 1
        return samples[0]

    def _create_a_subset(self, smiles):
        """Create a dataset from a subset of smiles.

        Parameters
        ----------
        smiles : list or str
            Lists of molecules in SMILES format

        Returns
        -------
        Subset
            Dataset over the slice of ``smiles`` assigned to ``self.subset_id``.
        """
        # divide the smiles into multiple subsets with multiprocess.
        # Floor division gives every subset the same size; the trailing
        # ``len(smiles) % n_subsets`` entries are not assigned to any subset.
        subset_size = len(smiles) // self.n_subsets
        start = self.subset_id * subset_size
        stop = start + subset_size
        return Subset(smiles[start:stop], self.order, self.env)

    def _load_atom_and_bond_types(self):
        """Look up the atom and bond types for ``self.dataset``.

        Returns
        -------
        tuple
            ``(atom_types, bond_types)``. For 'ChEMBL' and 'ZINC' these are
            hard-coded; otherwise they are the 'atom_types' and 'bond_types'
            entries of the dict pickled in
            ``<dataset>_atom_and_bond_types.pkl``.
        """
        if self.dataset in ('ChEMBL', 'ZINC'):
            # For new datasets, get_atom_and_bond_types can be used to
            # identify the atom and bond types in them.
            if self.dataset == 'ChEMBL':
                atom_types = ['O', 'Cl', 'C', 'S', 'F', 'Br', 'N']
            else:
                atom_types = ['C', 'N', 'O', 'P', 'S', 'F', 'Cl', 'Br', 'I']
            bond_types = [Chem.rdchem.BondType.SINGLE,
                          Chem.rdchem.BondType.DOUBLE,
                          Chem.rdchem.BondType.TRIPLE]
            return atom_types, bond_types

        # Imported here so this branch works even when the surrounding
        # module does not import pickle at the top.
        import pickle

        path_to_atom_and_bond_type = '_'.join([self.dataset, 'atom_and_bond_types.pkl'])
        # NOTE(security): unpickling executes arbitrary code embedded in the
        # file; only point ``dataset`` at type files you produced yourself.
        with open(path_to_atom_and_bond_type, 'rb') as f:
            type_info = pickle.load(f)
        return type_info['atom_types'], type_info['bond_types']

    def _load_split(self, split):
        """Fetch the SMILES file for one split and wrap it in a ``Subset``.

        Parameters
        ----------
        split : str
            Split name, 'train' or 'val'; the file read is
            ``<dataset prefix>_<split>.txt``.

        Returns
        -------
        Subset
            This process' share of the split.
        """
        fname = '_'.join([self._dataset_prefix(), split + '.txt'])
        download_data(self.dataset, fname)
        smiles = load_smiles_from_file(fname)
        return self._create_a_subset(smiles)

    def _setup(self):
        """
        1. Instantiate an MDP environment for molecule generation
        2. Download the dataset, which is a file of SMILES
        3. Create subsets for training and validation
        """
        self.atom_types, self.bond_types = self._load_atom_and_bond_types()
        self.env = MoleculeEnv(self.atom_types, self.bond_types)

        # Attributes are only created for the modes that were requested.
        if 'train' in self.modes:
            self.train_set = self._load_split('train')

        if 'val' in self.modes:
            self.val_set = self._load_split('val')

    def _dataset_prefix(self):
        """Get the prefix for the data files of supported datasets.

        Returns
        -------
        str
            Prefix for dataset file name
        """
        return '_'.join([self.dataset, 'DGMG'])

class Subset(Dataset):
    """A set of molecules which can be used for training, validation, test.

    Parameters
    ----------
    smiles : list
        List of SMILES for the dataset
    order : str
        Specifies how decision sequences for molecule generation
        are obtained, can be either 'random' or 'canonical'
    env : MoleculeEnv object
        MDP environment for generating molecules
    """
    def __init__(self, smiles, order, env):
        super(Subset, self).__init__()
        self.smiles = smiles
        self.order = order
        self.env = env
        self._setup()

    def _setup(self):
        """Convert SMILES into rdkit molecule objects.

        SMILES for which ``smiles_to_standard_mol`` returns None are dropped,
        so ``self.smiles`` and ``self.mols`` stay index-aligned.

        Decision sequences are extracted if we use a fixed order.
        """
        kept_smiles = []
        mols = []
        for s in self.smiles:
            m = smiles_to_standard_mol(s)
            if m is None:
                continue
            kept_smiles.append(s)
            mols.append(m)
        self.smiles = kept_smiles
        self.mols = mols

        # With a random order a fresh sequence is drawn on every access,
        # so there is nothing to precompute.
        if self.order == 'random':
            return

        # Fixed order: visit atoms in index order 0..N-1 and cache the result.
        self.decisions = [
            self.env.get_decision_sequence(m, list(range(m.GetNumAtoms())))
            for m in self.mols
        ]

    def __len__(self):
        """Return the number of molecules that were kept by ``_setup``."""
        return len(self.mols)

    def __getitem__(self, item):
        """Get the decision sequence for generating the molecule indexed by item.

        For the 'canonical' order the sequence cached by ``_setup`` is
        returned; otherwise the atom indices are shuffled and a new sequence
        is extracted on each call.
        """
        if self.order == 'canonical':
            return self.decisions[item]

        # Imported here so this path works even when the surrounding
        # module does not import random at the top.
        import random

        m = self.mols[item]
        nodes = list(range(m.GetNumAtoms()))
        random.shuffle(nodes)
        return self.env.get_decision_sequence(m, nodes)