In [None]:

import numpy as np

import numpy as np
import torch
from pysmiles import read_smiles
import networkx as nx
import dgllife
from rdkit import Chem
import random
import dgl.nn.pytorch as dglnn
import torch.nn as nn
import dgl
from dgl import DGLGraph
from torch.nn.utils.rnn import pad_sequence



import dgl.data
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from torch.distributions import Categorical
import matplotlib.pyplot as plt
import pandas
%run MoleculeGenerator2.ipynb
%run Discrim.ipynb





In [None]:
# Scratch cell: builds a one-element tensor from the list [2]; the value is
# only displayed as the cell output and is not stored anywhere.
torch.Tensor([2])

In [None]:
class Env_2(object):
    '''
    Molecule-building environment whose state is a DGL graph that is edited
    one action at a time.

    Every node of `stateSpaceGraph` stores, under ndata['atomic'], a single
    float holding the atom's index into `atom_list`. Every bond is stored as
    two directed edges whose edata['type'] holds the bond code: `step`
    produces codes 1 and 2, and `addAromaticRing` writes AROMATIC_BOND.
    '''

    # Bond code written on the ring edges created by addAromaticRing.
    AROMATIC_BOND = 3
    # Atom type used for the six ring atoms.
    # NOTE(review): carbon is an assumption (benzene-like ring) -- confirm.
    RING_ATOM = 'C'

    def __init__(self, num_atom_types, num_node_feats_internal):
        '''
        num_atom_types: number of leading `atom_list` entries that `step`
            exposes as "add atom" actions.
        num_node_feats_internal: width of the per-atom feature vector kept
            in `last_atom_features`.
        '''
        self.atom_list = ['N','C','O','S','F','Cl','Na','P','Br','Si','B','Se','K', 'Aro']
        # symbol -> [1-based id, 0, n]; mol_checker allows an explicit
        # valence of at most 8 - n for that symbol.
        self.atom_bond_dict = {'N':[1,0,5], 'C':[2,0,4], 'O':[3,0,6], 'S':[4,0,6],
                               'F':[5,0,7], 'Cl' : [6,0,7],'Na':[7,0,7], 'P' : [8,0,5],
                               'Br':[9,0,7], 'Si' : [10,0,4],'B':[11,0,3], 'Se' : [12,0,6],
                               'K':[13,0,7]}

        self.num_atom_types = num_atom_types
        self.num_node_feats_internal = num_node_feats_internal

        self.stateSpaceGraph = self._empty_graph()

        self.Done = False
        # 1 while the most recent successful action added a node that has
        # not been bonded yet, 0 otherwise.
        self.last_action_node = torch.zeros((1,1))
        self.last_atom_features = torch.zeros(1,self.num_node_feats_internal)
        self.reward = 0
        self.just_added_node = False

        self.log = ""

    @staticmethod
    def _empty_graph():
        '''Return a graph with no nodes/edges and width-1 feature columns.'''
        graph = dgl.graph(([], []))
        graph.ndata['atomic'] = torch.zeros(0, 1)
        graph.edata['type'] = torch.zeros(0, 1)
        return graph

    @property
    def stateSpaceGraph_full(self):
        '''The current state converted by MolFromGraphsAro1.'''
        return MolFromGraphsAro1(self.stateSpaceGraph)

    @property
    def n_nodes(self):
        '''Number of atoms currently in the state graph.'''
        return self.stateSpaceGraph.num_nodes()

    def __len__(self):
        return self.stateSpaceGraph.num_nodes()

    def mol_checker(self,graph):
        '''
        Return True when no atom of `graph` exceeds its allowed valence.

        `graph` must use the same encoding as `stateSpaceGraph`. An atom is
        rejected when its explicit valence is larger than 8 minus the last
        entry of its `atom_bond_dict` row.
        '''
        mol = MolFromGraphsAro1(graph)

        for atom in mol.GetAtoms():
            # strict=False: let the check below decide instead of having
            # RDKit raise on an over-valent atom.
            atom.UpdatePropertyCache(strict=False)
            max_valence = 8 - self.atom_bond_dict[atom.GetSymbol()][-1]
            if atom.GetExplicitValence() > max_valence:
                return False
        return True

    def featurize(self, graph):
        '''
        Forward `graph` to mol_to_graph_full and return its result.

        NOTE(review): despite the parameter name, `step` passes the output
        of MolFromGraphsAro here.
        '''
        return mol_to_graph_full(graph)

    def reset(self, start_with_ring=True):
        '''
        Clear all episode state and seed the graph.

        start_with_ring: when True (the default) the episode
            starts from one aromatic ring; when False it starts from two
            carbon atoms joined by a bond of code 1.
        '''
        self.just_added_node = False
        self.Done = False
        self.reward = 0
        self.log = ""

        self.last_action_node = torch.zeros((1,1))
        self.last_atom_features = torch.zeros(1,self.num_node_feats_internal)
        self.stateSpaceGraph = self._empty_graph()

        if start_with_ring:
            self.addAromaticRing()
        else:
            self.addNode('C')
            # addNode refuses to add a second node while the flag is set.
            self.last_action_node = torch.zeros((1,1))
            self.addNode('C')
            self.addEdge(1, 0, give_reward=False)
        # Seeding the graph must not count towards the first step's reward.
        self.reward = 0

    def addAromaticRing(self):
        '''Append six RING_ATOM nodes joined in a cycle by AROMATIC_BOND edges.'''
        first = self.n_nodes
        atom_code = float(self.atom_list.index(self.RING_ATOM))

        self.stateSpaceGraph.add_nodes(6, {'atomic': torch.full((6, 1), atom_code)})

        src, dst = [], []
        for i in range(6):
            u = first + i
            v = first + (i + 1) % 6
            # one bond == two directed edges
            src.extend([u, v])
            dst.extend([v, u])
        bond_feats = torch.full((len(src), 1), float(self.AROMATIC_BOND))
        self.stateSpaceGraph.add_edges(src, dst, {'type': bond_feats})

    def addNode(self,node_type):
        '''
        Add one atom of `node_type` (or a whole ring for 'Aro').

        If the previous successful action already added a node that has not
        been bonded, nothing is added and the reward is decreased by 0.1;
        otherwise the reward is increased by 0.1.
        '''
        if self.last_action_node == 1:
            self.reward -=.1
            return

        self.reward+=.1
        self.last_action_node = torch.ones((1,1))

        if node_type == 'Aro':
            self.addAromaticRing()
        else:
            idx = self.atom_list.index(node_type)
            self.stateSpaceGraph.add_nodes(1, {'atomic': torch.tensor([[float(idx)]])})

    def _touches_aromatic_bond(self, atom_id):
        '''True when any edge entering `atom_id` carries AROMATIC_BOND.'''
        eids = self.stateSpaceGraph.in_edges(atom_id, form='eid')
        bond_codes = self.stateSpaceGraph.edata['type'][eids]
        return bool((bond_codes == self.AROMATIC_BOND).any())

    def addEdge(self, edge_type, atom_id, give_reward=True):
        '''
        Try to bond the most recently added atom to `atom_id`.

        edge_type: bond code stored on both directed edges.
        atom_id: index of the atom to connect to.
        give_reward: when False, `self.reward` is left untouched.

        The bond is rejected (and logged) for a self loop, an already
        existing bond, a bond between two atoms that both touch an aromatic
        bond, or when mol_checker rejects the resulting graph.
        '''
        last_atom = self.n_nodes - 1

        failure = None
        if atom_id == last_atom:
            failure = 'self loop attempt \n'
        elif self.stateSpaceGraph.has_edges_between(last_atom, atom_id):
            failure = 'edge already present \n'
        elif self._touches_aromatic_bond(atom_id) and self._touches_aromatic_bond(last_atom):
            failure = 'bonds between two aromatic atoms \n'

        if failure is None:
            # Validate on a copy so a rejected bond leaves the state intact.
            graph_clone = self.stateSpaceGraph.clone()
            graph_clone.add_edges([atom_id, last_atom], [last_atom, atom_id],
                                  {'type': torch.tensor([[edge_type], [edge_type]],dtype=torch.float)})
            if self.mol_checker(graph_clone):
                self.stateSpaceGraph = graph_clone
                # The pending node is now bonded, so addNode is allowed again.
                self.last_action_node = torch.zeros((1,1))
                if give_reward:
                    self.reward += .1
                return
            failure = 'valence check failed \n'

        self.log += failure
        if give_reward:
            self.reward -=.1

    def step(self, action, final_step = False, verbose = False):
        '''
        Apply one action and return (obs, reward, done, info).

        action: 0 terminates; 1..num_atom_types adds atom_list[action-1];
            larger values encode (destination atom, bond code 1 or 2) for a
            bond from the last atom. Any other value is logged and ignored.
        final_step: currently unused.
        verbose: when True, print this step's log.

        obs is (featurized graph, last_action_node, last_atom_features).
        '''
        reward_dict_info = {'model_reward':0, 'property_reward':0, 'step_reward':0} #info for different rewards for logging

        self.reward = 0
        self.log = ""

        first_bond_action = self.num_atom_types + 1

        if action == 0:
            self.log += 'terminating \n'
            self.Done = True

        elif 0 < action < first_bond_action:
            self.log += ("------adding "+ self.atom_list[action-1] +" atom------ \n")
            self.addNode(self.atom_list[action-1])

        elif first_bond_action <= action < first_bond_action + (2*self.__len__()):
            destination_atom_idx = (action - first_bond_action) // 2
            edge_type = (action - first_bond_action)%2 + 1

            self.log +=("------attempting to add " + str(edge_type) + " bond between last atom added and atom "+ str(destination_atom_idx) +"------ \n")
            self.addEdge(edge_type, destination_atom_idx)

        else:
            self.log += ("invalid action " + str(action) + " ignored \n")

        if verbose:
            print(self.log)

        # NOTE(review): MolFromGraphsAro and selfLoop are not defined in the
        # visible cells -- presumably they come from a %run notebook; confirm.
        mol = MolFromGraphsAro(self.stateSpaceGraph)
        full_graph = self.featurize(mol)
        full_graph = selfLoop(full_graph)

        self.last_atom_features = torch.unsqueeze(full_graph.ndata['atomic'][-1], dim = 0)

        obs = full_graph.clone(), self.last_action_node, self.last_atom_features

        return obs, self.reward, self.Done, reward_dict_info

In [None]:
# Build the environment with num_atom_types=13 and num_node_feats_internal=34.
env = Env_2(13,34)
    

In [None]:
# Start a fresh episode; reset() rebuilds env.stateSpaceGraph from scratch.
env.reset()

In [None]:
# Display the environment's current graph as the cell output.
env.stateSpaceGraph

In [None]:
def CustomAtomFeaturizer_full(mol):
    '''
    Build one feature row per atom of `mol`.

    Each row is the concatenation, in this order, of: a 15-slot one-hot of
    the atom type, an 8-slot one-hot of the last entry of the atom's
    `atom_bond_dict` row, dgllife's degree one-hot, dgllife's hybridization
    one-hot, the aromaticity flag (0/1), dgllife's explicit- and
    implicit-valence one-hots, and the atomic mass divided by 127.

    Returns a dict {'atomic': float tensor with one row per atom}.
    Raises KeyError for an element symbol missing from `atom_bond_dict`.
    '''
    atom_bond_dict = {'N':[1,0,5], 'C':[2,0,4], 'O':[3,0,6], 'S':[4,0,6],
                               'F':[5,0,7], 'Cl' : [6,0,7],'Na':[7,0,7], 'P' : [8,0,5],
                               'Br':[9,0,7], 'Si' : [10,0,4],'B':[11,0,3], 'Se' : [12,0,6],
                               'K':[13,0,7]}

    Chem.SetHybridization(mol)

    feats = []
    for atom in mol.GetAtoms():
        atom.UpdatePropertyCache()
        atom_symbol = atom.GetSymbol()

        # The dict stores 1-based ids; the one-hot position is 0-based.
        atom_oh = np.zeros(15)
        atom_oh[atom_bond_dict[atom_symbol][0] - 1] = 1

        max_valence_oh = np.zeros(8)
        max_valence_oh[atom_bond_dict[atom_symbol][-1]] = 1

        degree = dgllife.utils.atom_degree_one_hot(atom)
        hybridization = dgllife.utils.atom_hybridization_one_hot(atom)

        is_aromatic = np.expand_dims(np.asarray(atom.GetIsAromatic()).astype(int),0) # Bool

        exp_valence = dgllife.utils.atom_explicit_valence_one_hot(atom)
        imp_valence = dgllife.utils.atom_implicit_valence_one_hot(atom)

        mass = np.expand_dims(np.asarray(atom.GetMass())/127,0) # max val of 126.904

        feats.append(np.concatenate((atom_oh,max_valence_oh,degree,
                                     hybridization,is_aromatic,exp_valence,imp_valence,mass)))

    # Convert through a single ndarray: building a tensor straight from a
    # list of arrays is slow and triggers a PyTorch warning.
    return {'atomic': torch.tensor(np.array(feats)).float()}
        
        

In [None]:
def edge_featurizer(mol, add_self_loop = False):
    '''
    Encode every bond of `mol` as a bond code, once per direction.

    The code is the 1-based position of the bond's type in
    (SINGLE, DOUBLE, AROMATIC); any other bond type raises ValueError.
    `add_self_loop` is accepted but not used.

    Returns {'type': float tensor with one single-column row per directed edge}.
    '''
    known_types = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
                   Chem.rdchem.BondType.AROMATIC]

    codes_per_bond = [known_types.index(b.GetBondType()) + 1 for b in mol.GetBonds()]
    # A bond between u and v becomes the two edges (u, v) and (v, u), so
    # each code is emitted twice in a row.
    directed_codes = [code for code in codes_per_bond for _ in range(2)]

    return {'type': torch.tensor(directed_codes).reshape(-1, 1).float()}

In [None]:
# Parse the example molecule first so that `mol` exists on a fresh kernel
# before it is featurized.
# NOTE(review): `smiles_values` is not defined in the visible cells --
# presumably it comes from one of the %run notebooks; confirm.
mol = Chem.MolFromSmiles(smiles_values[1])
graph = dgllife.utils.mol_to_bigraph(mol,node_featurizer=CustomAtomFeaturizer_full,edge_featurizer=edge_featurizer)

In [None]:
def MolFromGraphsAro1(graph):
    '''
    Convert a DGL graph into an RDKit molecule.

    graph.ndata['atomic'] must hold, per node, the atom's index into
    `atom_list`; graph.edata['type'] must hold, per edge, a bond code
    (1 = single, 2 = double, 3 = aromatic). Only the u -> v edge with
    u < v is read for each connected pair.

    Returns the molecule, unsanitized.
    Raises ValueError for a bond code other than 1, 2 or 3.
    '''
    atom_list = ['N','C','O','S','F','Cl','Na','P','Br','Si','B','Se','K', 'Aro']
    bond_by_code = {1: Chem.rdchem.BondType.SINGLE,
                    2: Chem.rdchem.BondType.DOUBLE,
                    3: Chem.rdchem.BondType.AROMATIC}

    feat_list = graph.ndata['atomic']
    num_nodes = feat_list.shape[0]

    mol = Chem.RWMol()
    for node in range(num_nodes):
        # Features are float tensors; a list needs a plain int index.
        atom_code = int(feat_list[node].reshape(-1)[0].item())
        # Atoms are added in node order, so RDKit index == graph node id.
        mol.AddAtom(Chem.Atom(atom_list[atom_code]))

    for u in range(num_nodes - 1):
        for v in range(u + 1, num_nodes):
            if not graph.has_edges_between(u, v):
                continue
            bond_code = int(graph.edges[u, v].data['type'].reshape(-1)[0].item())
            if bond_code not in bond_by_code:
                raise ValueError("unsupported bond code %r between atoms %d and %d"
                                 % (bond_code, u, v))
            mol.AddBond(u, v, bond_by_code[bond_code])

    return mol.GetMol()

In [None]:
# Scratch call; the result is only displayed as the cell output.
# NOTE(review): `graph` was built with CustomAtomFeaturizer_full features,
# while MolFromGraphsAro1 uses each 'atomic' entry as an index into its
# atom_list -- confirm this input is the intended one.
MolFromGraphsAro1(graph)

In [None]:
# NOTE(review): `CustomAtomFeaturizer` is not defined in the visible cells
# (only `CustomAtomFeaturizer_full` is) -- presumably it comes from one of
# the %run notebooks; confirm.
CustomAtomFeaturizer(mol)

In [None]:
def mol_to_graph_full(mol):
    '''
    Featurize `mol` into a DGL bigraph that has exactly one self loop per node.

    Atoms keep their original order (canonical_atom_order=False). Existing
    self loops are stripped before new ones are added so none is duplicated.
    '''
    bigraph = dgllife.utils.mol_to_bigraph(
        mol,
        node_featurizer=my_atom_featurizer,
        edge_featurizer=base_edge_featurizer,
        canonical_atom_order=False,
    )
    without_loops = dgl.remove_self_loop(bigraph)
    return dgl.add_self_loop(without_loops)
    

In [None]:
from dgl import data

In [None]:
from dgllife import data

In [None]:
# Display the `alchemy` attribute of the most recently imported `data` package.
data.alchemy