In [12]:
# Load helper notebooks into this kernel's namespace. Presumably these define the
# names this notebook uses without importing them here (torch, Chem, dgl, nx,
# random, mol_to_graph_full, SanitizeNoKEKU, ...) -- verify.
# NOTE(review): IPython's `%run` sets __name__ to "__main__" unless `-n` is
# given, so this guard only skips the loads when this notebook is itself
# executed via `%run -n` -- confirm that is the intended use.
if __name__ == "__main__":
    %run Discrim.ipynb
    %run ChemEnv.ipynb
    %run MolUtils.ipynb
    


In [9]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


In [41]:
class ChemEnvTest():
    '''
    Thin wrapper around ``_chemEnvTest`` for interactive experiments.

    Builds the underlying environment with a fixed configuration and exposes
    helpers to clear it, step it with an integer action id, and read or
    replace the molecule it is building.
    '''
    def __init__(self):
        chem_env_kwargs = {'max_nodes' : 12, 
                   'num_atom_types' : 17, 
                   'num_node_feats' : 54,
                   'num_edge_types' : 3, 
                   'bond_padding' : 12, 
                   'mol_featurizer': mol_to_graph_full, 
                   'RewardModule' : None, 
                   'writer' : None,
                   'num_chunks':1}
        
        self.__env = _chemEnvTest(**chem_env_kwargs)
        
    def Clear(self):
        '''Reset the wrapped environment to an empty molecule.'''
        self.__env.clear()
        
    def Step(self, action, verbose = True):
        '''
        Apply one integer ``action`` to the wrapped environment.

        ``verbose`` is forwarded by keyword: the environment's ``step`` signature
        is ``(action, final_step=False, verbose=False)``, so passing it
        positionally would bind it to ``final_step`` instead.
        '''
        self.__env.step(action, verbose=verbose)
        
    def AssignMol(self, mol):
        '''
        Replace the environment's molecule with an editable ``RWMol`` copy of ``mol``.

        Also resets ``last_action_node`` to zero, matching the value the
        environment starts with, so the next node-add action is not treated
        as a repeated node addition.
        '''
        self.Clear()
        mol = Chem.RWMol(mol)
        # Previously written as `== 0`, a no-op comparison; this is the intended reset.
        self.__env.last_action_node = torch.zeros((1, 1)).to(device)
        self.__env.StateSpace = mol
    
    def GetMol(self):
        '''Return the environment's current molecule (the live object, not a copy).'''
        return self.__env.StateSpace

In [32]:
class _chemEnvTest(object):
    '''
    Stripped-down ChemEnv used for testing.

    Holds the molecule under construction in ``self.StateSpace`` (an RDKit
    ``RWMol``) and mutates it in place in response to integer actions passed
    to ``step``. Reset molecules are sampled from DGL graph files stored at
    ``./graph_decomp/chunk_<k>``.

    Args:
        num_chunks: number of chunk files; a chunk index is drawn from
            ``range(num_chunks)``. NOTE(review): assumes the files are numbered
            0..num_chunks-1 -- confirm against the data directory.
        max_nodes: stored as ``self.max_nodes``; not used by this class.
        num_atom_types: ignored; recomputed as ``len(self.atom_list)``.
        num_node_feats: width of the zero tensor kept in ``last_atom_features``.
        num_edge_types: not used by this class.
        bond_padding: stored as ``self.bond_padding``; not used by this class.
        RewardModule: object whose ``GiveReward(mol)`` is called by
            ``modelRewards``; may be None if that method is never called.
        mol_featurizer: callable mapping an RDKit mol to a DGL graph (used by
            ``graphObs``).
        writer: stored as ``self.writer``; not used by this class.
    '''
    def __init__(self, num_chunks, max_nodes, num_atom_types, num_node_feats, num_edge_types, bond_padding, RewardModule, mol_featurizer, writer):
        
        # Data source for reset molecules.
        self.num_chunks = num_chunks
        # randrange excludes its bound; randint(0, num_chunks) could select a
        # chunk number one past the last file.
        self.curr_chunk = random.randrange(num_chunks)
        self.path = './graph_decomp/chunk_'
        self.reset_state_graphs = dgl.load_graphs(self.path + str(self.curr_chunk))[0]
        
        # ENV_Atoms: action vocabulary for node additions. The last three
        # entries are ring fragments rather than single atoms.
        self.mol_featurizer = mol_featurizer
        self.atom_list = ['N','C','O','S','F','Cl','Na','P','Br','Si','B','Se','K', 'Benz','Pyri','Pyrr']
        self.atom_bond_dict = {'N':[1,0,5], 'C':[2,0,4], 'O':[3,0,6], 'S':[4,0,6],
                               'F':[5,0,7], 'Cl' : [6,0,7],'Na':[7,0,7], 'P' : [8,0,5],
                               'Br':[9,0,7], 'Si' : [10,0,4],'B':[11,0,5], 'Se' : [12,0,6],
                               'K':[13,0,7]}
        
        # ENV_Attributes
        self.max_nodes = max_nodes
        self.bond_padding = bond_padding 
        self.num_atom_types = len(self.atom_list)
        self.batch_dim = 1
        # Stored so modelRewards() works (it previously raised AttributeError).
        self.RewardModule = RewardModule
        self.writer = writer
        
        self.StateSpace = Chem.RWMol()
        
        # ENV_State
        self.Done = False
        self.last_action_node = torch.zeros((1,1)).to(device)
        self.num_node_feats = num_node_feats
        self.last_atom_features = torch.zeros(1,self.num_node_feats).to(device)
        self.just_added_node = False
        self.reward = 0
        # step() resets this every call; initialised here so addEdge can also
        # be called directly without an AttributeError.
        self.log = ""

        
    def __len__(self):
        '''Number of atoms currently in the molecule.'''
        return self.StateSpace.GetNumAtoms()
    
    @property
    def n_nodes(self):
        '''Number of atoms currently in the molecule.'''
        return self.StateSpace.GetNumAtoms()
             
            
    def clear(self):
        '''Empty the molecule and zero ``last_atom_features``.'''
        self.StateSpace = Chem.RWMol()
        self.last_atom_features = torch.zeros(1,self.num_node_feats).to(device)
        
    def addStructure(self,mol2):
        '''
        Append every atom and bond of ``mol2`` to ``self.StateSpace`` (unconnected).

        Only each atom's element symbol and each bond's type are copied; other
        atom properties of ``mol2`` (aromatic flag, charge, explicit Hs) are not.
        '''
        mol1 = self.StateSpace
        add_dif = mol1.GetNumAtoms()
        for atom in mol2.GetAtoms():
            new_atom = Chem.Atom(atom.GetSymbol())
            mol1.AddAtom(new_atom)
        for bond in mol2.GetBonds():
            a1 = bond.GetBeginAtom().GetIdx()
            a2 = bond.GetEndAtom().GetIdx()
            bt = bond.GetBondType()
            # Offset mol2's indices past the atoms that were already present.
            mol1.AddBond(add_dif + a1,add_dif+ a2, bt)
            mol1.UpdatePropertyCache()
            
            
    def addBenzine(self):
        '''Append an unconnected benzene ring.'''
        mol = Chem.MolFromSmiles('c1ccccc1')
        self.addStructure(mol)
        
        
    def addPyridine(self):
        '''
        Append an unconnected pyridine ring.

        NOTE(review): atom order is permuted by ``permute_rot`` first,
        presumably a random rotation of the ring indices -- confirm.
        '''
        mol = Chem.MolFromSmiles('N1=CC=CC=C1')
        mol = permute_mol(mol,permute_rot(mol.GetNumAtoms()))
        SanitizeNoKEKU(mol)
        self.addStructure(mol)
        
    def addPyrrole(self):
        '''
        Append an unconnected pyrrole ring.

        ``permuteAtomToEnd(mol, 0)`` is applied to atom 0 (the N in the SMILES),
        presumably moving it to the last index -- confirm.
        '''
        mol = Chem.MolFromSmiles('N1C=CC=C1')
        mol = permuteAtomToEnd(mol,0)
        self.addStructure(mol)
        
    def addNaptholene(self):
        '''Append an unconnected naphthalene fragment.'''
        mol = Chem.MolFromSmiles('C1=CC=C2C=CC=CC2=C1')
        self.addStructure(mol)
        
    def getResetMol(self, max_attempts = 11):
        '''
        Sample a starting molecule from the loaded graph chunk.

        With probability 1/10 a new chunk file is loaded first. Up to
        ``max_attempts`` random graphs are tried. A graph is accepted if, after
        dropping its trailing unconnected piece, the remainder sanitizes. The
        returned molecule is then rebuilt from the full graph, including that
        trailing piece. If no graph is accepted, a sanitized ethane
        (``"CC"``) ``RWMol`` is returned.
        '''
        if random.randint(0,9) == 1:
            self.curr_chunk = random.randrange(self.num_chunks)
            self.reset_state_graphs = dgl.load_graphs(self.path + str(self.curr_chunk))[0]
        
        # Bounded retry loop: the previous `while (not found) or i > 10` never
        # terminated when no candidate sanitized.
        for _ in range(max_attempts):
            # randrange(len) stays in bounds; randint(0, len) could overflow.
            idx = random.randrange(len(self.reset_state_graphs))
            temp_mol = Chem.RWMol(MolFromGraphsFULL(self.reset_state_graphs[idx]))
            try:
                self.removeUnconnected(temp_mol,sanitize = False)
                Chem.SanitizeMol(temp_mol)
            except Exception:
                continue
            
            new_mol = Chem.RWMol(MolFromGraphsFULL(self.reset_state_graphs[idx]))
            SanitizeNoKEKU(new_mol)
            return new_mol
        
        # SanitizeMol works in place and returns flags, so return the mol itself.
        fallback = Chem.RWMol(Chem.MolFromSmiles("CC"))
        Chem.SanitizeMol(fallback)
        return fallback
                
        
    def reset(self): 
        '''Load a fresh starting molecule and zero the last-action trackers.'''
        self.StateSpace = self.getResetMol()
        self.just_added_node = False
        self.last_action_node = torch.zeros((self.batch_dim,1)).to(device)
        self.last_atom_features = torch.zeros(1, self.num_node_feats).to(device)
        

    
    def addNode(self, node_choice, give_reward = True):  
        '''
        Add an atom, or a ring fragment for 'Benz'/'Pyri'/'Pyrr', to the molecule.

        If the previous action already added a node (``last_action_node == 1``),
        nothing is added and, when ``give_reward`` is set, the reward drops by 0.1.
        Otherwise the node is added, ``last_action_node`` is set to 1 and, when
        ``give_reward`` is set, the reward rises by 0.1.
        '''
        if self.last_action_node == 1:
            if give_reward:
                self.reward -= .1
            return
        
        self.last_action_node = torch.ones((1,1)).to(device)
        if give_reward:
            self.reward+=.1
        if node_choice == 'Benz':
            self.addBenzine()
        elif node_choice == 'Pyri':
            self.addPyridine()
        elif node_choice == 'Pyrr':
            self.addPyrrole()
        else:
            self.StateSpace.AddAtom(Chem.Atom(node_choice))
            
        
    def addEdge(self, edge_type, atom_id, give_reward = True):
        '''
        Try to bond the most recently added atom to atom ``atom_id`` (in place).

        The bond is first added to a copy of the molecule. It is applied to
        ``self.StateSpace`` only if the copy accepts the bond, forms a single
        connected graph, kekulizes, and passes the valence check.

        Args:
            edge_type: 1 for a single bond, 2 for a double bond.
            atom_id: index of the other atom; an object with ``.item()`` (e.g. a
                one-element tensor) is converted to a Python number.
            give_reward: when False, ``self.reward`` is left untouched.

        Raises:
            ValueError: if ``edge_type`` is not 1 or 2.
        '''
        try:
            atom_id = atom_id.item()
        except AttributeError:
            # Plain Python ints have no .item(); use them as-is.
            pass
            
        if edge_type == 1:
            bond = Chem.rdchem.BondType.SINGLE
        elif edge_type == 2:
            bond = Chem.rdchem.BondType.DOUBLE
        else:
            raise ValueError("edge_type must be 1 (single) or 2 (double), got " + str(edge_type))

        last_idx = self.StateSpace.GetNumAtoms() - 1

        # Identity permutation: used here only to obtain an independent copy.
        mol_copy = permute_mol(self.StateSpace, lambda x: x)
        mol_copy.UpdatePropertyCache()
        SanitizeNoKEKU(mol_copy)
        
        addable = True
        connected = False
        good_keku = True 
        good_valence = True
        unknown_pass = True
        
        # Bond must be addable (valid indices, no existing bond).
        try:
            mol_copy.AddBond(atom_id, last_idx, bond)
        except Exception:
            addable = False
            
        # Resulting graph must be a single connected component.
        try:
            if nx.is_connected(mol_to_graph(mol_copy).to_networkx().to_undirected()):
                connected = True
        except Exception:
            unknown_pass = False

        # Must kekulize.
        try:
            Chem.Kekulize(mol_copy)
        except Chem.rdchem.KekulizeException:
            good_keku = False

        # Atom valences must stay legal.
        try:
            SanitizeNoKEKU(mol_copy)
        except Chem.rdchem.AtomValenceException:
            self.log += 'valence overload \n' 
            good_valence = False   

        success = all([addable, connected, good_keku, good_valence, unknown_pass])
        
        if success:
            self.StateSpace.AddBond(atom_id, last_idx, bond)
            self.StateSpace.UpdatePropertyCache()
            Chem.SanitizeMol(self.StateSpace)
            
            if give_reward:
                self.reward+=.1
            
            # Keep on the same device as every other assignment of this tensor.
            self.last_action_node = torch.zeros((self.batch_dim,1)).to(device)
            self.log += ('edge added \n')
        elif give_reward:
            self.reward-=.1
     
    def removeUnconnected(self, mol, sanitize = True):
        '''
        Strip a trailing unconnected fragment from ``mol`` in place.

        If the last atom has degree 0 it is removed. Otherwise, when ``mol`` has
        more than 6 atoms, the last 6 atoms are removed if all have degree 2;
        failing that, the last 5 are removed if all have degree 2. ``mol``
        itself (not ``self.StateSpace``) then has its property cache updated
        and, when ``sanitize`` is set, is sanitized. An empty ``mol`` is left
        untouched.
        '''
        n_atoms = mol.GetNumAtoms()
        if n_atoms == 0:
            return

        if mol.GetAtomWithIdx(n_atoms - 1).GetDegree() == 0:
            mol.RemoveAtom(n_atoms - 1)
        elif n_atoms > 6:
            for ring_size in (6, 5):
                if all(mol.GetAtomWithIdx(i).GetDegree() == 2 for i in range(n_atoms - ring_size, n_atoms)):
                    # Re-read the size each time: it shrinks as atoms are removed.
                    for _ in range(ring_size):
                        mol.RemoveAtom(mol.GetNumAtoms() - 1)
                    break
            
        mol.UpdatePropertyCache()
        if sanitize:
            Chem.SanitizeMol(mol)
    
    def checkValence(self, atom_id, edge_type):
        '''
        Return True if adding ``edge_type`` to atom ``atom_id``'s explicit valence
        would exceed ``8 - atom_bond_dict[symbol][-1]``.

        Raises KeyError for symbols not listed in ``atom_bond_dict``.
        '''
        atom = self.StateSpace.GetAtomWithIdx(atom_id)
        currValence = atom.GetExplicitValence()
        maxValence = 8 - self.atom_bond_dict[atom.GetSymbol()][-1]      
        return currValence + edge_type > maxValence                
    
    def modelRewards(self, mol): 
        '''Delegate to ``self.RewardModule.GiveReward(mol)``.'''
        return self.RewardModule.GiveReward(mol)
    
    def graphObs(self):
        '''Featurize the molecule into a graph with exactly one self-loop per node, on ``device``.'''
        self.StateSpace.UpdatePropertyCache()
        return dgl.add_self_loop(dgl.remove_self_loop(self.mol_featurizer(self.StateSpace))).to(device)
    
    
    def step(self, action, final_step = False, verbose = False):
        '''
        Apply a single integer action to the molecule.

        Action layout:
            0                                   -> terminate (sets ``Done``)
            1 .. num_atom_types                 -> add ``atom_list[action-1]``
            num_atom_types+1 .. +2*len(self)-1  -> add a bond; each target atom
                                                   has two slots (single, double)
        Negative or too-large ids only write a message to ``self.log``.
        ``final_step`` and ``verbose`` are accepted for interface compatibility
        and are not used.
        '''
        self.TempSmiles = Chem.MolToSmiles(self.StateSpace)
        self.log = ""
        
        if action == 0:
            self.log += 'terminating \n' 
            self.Done = True        
                
        elif action > 0 and action < self.num_atom_types+1:
            self.log += ("------adding "+ self.atom_list[action-1] +" atom------ \n")
            self.addNode(self.atom_list[action-1])
            SanitizeNoKEKU(self.StateSpace)

        elif action < 0:
            # Previously fell through to the edge branch with a negative atom index.
            self.log += "------action id is negative------ \n"
                
        elif action < 1 + self.num_atom_types + (2*len(self)):
            offset = action - self.num_atom_types - 1
            destination_atom_idx = offset // 2
            edge_type = offset % 2 + 1
            
            self.log +=("------attempting to add " + str(edge_type) + " bond between last atom added and atom "+ str(destination_atom_idx) +"------ \n")
            self.addEdge(edge_type,destination_atom_idx)
        else:
            self.log += "------action id is too large for state space------ \n"


            
            