In [None]:

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import torch
from pysmiles import read_smiles
import networkx as nx
import dgllife
from rdkit import Chem
import random
import dgl.nn.pytorch as dglnn
import torch.nn as nn
import dgl
from dgl import DGLGraph
from torch.nn.utils.rnn import pad_sequence



import dgl.data
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from torch.distributions import Categorical
import matplotlib.pyplot as plt


In [None]:
def init_weights_recursive(m):
    '''
    Orthogonally initialise the weights of ``m`` (via ``init_weights``), dispatching on its type.

    * ``nn.Linear``, ``Linear_3``, DGL ``GatedGraphConv`` and ``nn.GRUCell``:
      ``init_weights`` is applied to the module and all of its submodules.
    * ``Batch_Edge``: its ``NodeEmbed`` and ``Edges`` children.
    * ``BaseLine`` / ``Spin2``: ``AddNode``, ``BatchEdge`` (and ``Weight`` for
      ``Spin2``) recurse through this function; ``GGN.linears[0]``..``[3]`` and
      ``GGN.gru`` get ``init_weights``.
    * ``nn.ModuleList``: only entries 0-3 are initialised (IndexError if it has
      fewer than four entries).
    * ``CriticSqueeze``: ``GGN`` and ``Dense`` recurse through this function.
    * Anything else is printed (module and its type) and left unchanged.

    Because ``Module.apply`` calls its function on every submodule as well as on
    the module itself, recursing via ``.apply(init_weights_recursive)`` may
    initialise nested layers more than once and prints every unmatched submodule.

    NOTE(review): ``Batch_Edge``, ``BaseLine`` and ``Spin2`` are not defined in the
    visible part of this notebook; if they do not exist when this runs, the
    corresponding ``isinstance`` check raises NameError for any module reaching it.
    '''
    if isinstance(m,nn.Linear):
        m.apply(init_weights)
    elif isinstance(m,Linear_3):
        m.apply(init_weights)
    elif isinstance(m, Batch_Edge):
        m.NodeEmbed.apply(init_weights)
        m.Edges.apply(init_weights)
    elif isinstance(m, BaseLine):
        m.AddNode.apply(init_weights_recursive)
        m.BatchEdge.apply(init_weights_recursive)
        # GGN is accessed as a gated graph conv exposing `linears` and `gru`;
        # exactly four linears are assumed here.
        m.GGN.linears[0].apply(init_weights)
        m.GGN.linears[1].apply(init_weights)
        m.GGN.linears[2].apply(init_weights)
        m.GGN.linears[3].apply(init_weights)
        m.GGN.gru.apply(init_weights)
    elif isinstance(m, Spin2):
        m.Weight.apply(init_weights_recursive)
        m.AddNode.apply(init_weights_recursive)
        m.BatchEdge.apply(init_weights_recursive)
        m.GGN.linears[0].apply(init_weights)
        m.GGN.linears[1].apply(init_weights)
        m.GGN.linears[2].apply(init_weights)
        m.GGN.linears[3].apply(init_weights)
        m.GGN.gru.apply(init_weights)
    elif isinstance(m,torch.nn.modules.container.ModuleList):
        # Only the first four entries are initialised.
        m[0].apply(init_weights)
        m[1].apply(init_weights)
        m[2].apply(init_weights)
        m[3].apply(init_weights)
        
    elif isinstance(m,dgl.nn.pytorch.conv.gatedgraphconv.GatedGraphConv):
        m.apply(init_weights)
    elif isinstance(m,torch.nn.modules.rnn.GRUCell):
        m.apply(init_weights)
    elif isinstance(m,CriticSqueeze):
        m.GGN.apply(init_weights_recursive)
        m.Dense.apply(init_weights_recursive)
    else:
        # Unhandled module type: report it and leave its weights untouched.
        print(m,type(m))
def init_weights(m):
    '''
    Orthogonally initialise ``m.weight`` in place, if it is a tensor with >= 2 dims.

    Intended for ``nn.Module.apply``. Modules are left untouched when they have:
    * no ``weight`` attribute (e.g. activations, ``GRUCell`` which uses ``weight_ih``/``weight_hh``),
    * a ``None`` weight (e.g. non-affine norms),
    * a 1-D weight (e.g. ``LayerNorm``/``BatchNorm``), which ``orthogonal_`` rejects.

    Unlike a bare ``except``, unexpected errors (and KeyboardInterrupt) are no
    longer silently swallowed.
    '''
    weight = getattr(m, 'weight', None)
    if not isinstance(weight, torch.Tensor) or weight.dim() < 2:
        return
    nn.init.orthogonal_(weight.data)

In [None]:
def selfLoop(graph):
    '''
    Return a new graph equal to ``graph`` but with exactly one self-loop per node.

    Existing self-loops are stripped first, so nodes never end up with duplicates.
    '''
    loop_free = dgl.remove_self_loop(graph)
    return dgl.add_self_loop(loop_free)

In [None]:
def CustomAtomFeaturizer(mol):
    '''
    Per-atom feature vector of length 26, concatenating:
      * atomic-number one-hot over [H, O, C, Cl, N, P, Br, S, F, I]   (10)
      * degree one-hot for degrees 1-4                                 (4)
      * formal charge                                                  (1)
      * hybridization one-hot over SP / SP2 / SP3                      (3)
      * aromaticity flag                                               (1)
      * total-valence one-hot for valences 1-6                         (6)
      * atomic mass / 127                                              (1)

    Degrees, hybridizations and valences outside the encoded ranges leave their
    one-hot segment all zeros.

    Returns {'atomic': float32 tensor of shape (num_atoms, 26)}.
    Raises ValueError for an atomic number not in the list above.
    '''
    atom_ids = [1, 8, 6, 17, 7, 15, 35, 16, 9, 53]
    hybridization_ids = ["SP", "SP2", "SP3"]

    def one_hot(position, size):
        # Zero vector of length `size` with a 1 at `position`, if 0 <= position < size.
        # Out-of-range positions leave it all zeros (no negative-index wrap-around).
        vec = np.zeros(size)
        if 0 <= position < size:
            vec[position] = 1
        return vec

    feats = []
    for atom in mol.GetAtoms():
        atomic_num_onehot = one_hot(atom_ids.index(atom.GetAtomicNum()), 10)

        # Degrees 1-4 map to slots 0-3.
        num_bonds_onehot = one_hot(atom.GetDegree() - 1, 4)

        formal_charge = np.array([atom.GetFormalCharge()], dtype=float)

        hybridization = atom.GetHybridization().name
        hyb_pos = hybridization_ids.index(hybridization) if hybridization in hybridization_ids else -1
        hybridization_onehot = one_hot(hyb_pos, 3)

        is_aromatic = np.array([float(atom.GetIsAromatic())])

        # Valences 1-6 map to slots 0-5.
        valence_onehot = one_hot(atom.GetTotalValence() - 1, 6)

        # 127 is just above the heaviest listed element (I, 126.904).
        mass = np.array([atom.GetMass() / 127])

        feats.append(np.concatenate((atomic_num_onehot, num_bonds_onehot, formal_charge,
                                     hybridization_onehot, is_aromatic, valence_onehot, mass)))

    return {'atomic': torch.tensor(np.asarray(feats, dtype=np.float32))}


In [None]:
class BatchWrapper:
    '''
    Wrapper around a deepchem-style batch iterator that turns each batch of
    SMILES strings into a list of DGL graphs.

    Parameters
    ----------
    batch_size : int
        Batch size passed to ``disk_dataset.iterbatches``.
    disk_dataset
        Object exposing ``iterbatches(batch_size)``, which yields
        ``(X, labels, <third item>, smiles)`` tuples (presumably a deepchem
        dataset -- TODO confirm).
    task_ids
        Currently unused.
    atom_featurizer, edge_featurizer : callable, optional
        Featurizers passed to ``dgllife.utils.smiles_to_bigraph``. They default to
        this class's atomic-number and bond-type featurizers.
    '''
    def __init__(self,batch_size, disk_dataset,task_ids, atom_featurizer = None, edge_featurizer = None):
        self.disk_dataset = disk_dataset
        self.batch_size = batch_size
        self.data_iter = disk_dataset.iterbatches(batch_size)

        # `is None` rather than `== None`: featurizers may define a custom __eq__.
        self.edge_featurizer = self.__base_edge_featurizer if edge_featurizer is None else edge_featurizer
        self.atom_featurizer = self.__base_atom_featurizer if atom_featurizer is None else atom_featurizer

    def __complement(self, graph):
        '''
        Return the complement of ``graph``: an edge, in both directions, between
        every pair of distinct nodes that ``graph`` does not connect.

        Intended for later link prediction. The complement may be disconnected.
        '''
        n = graph.num_nodes()
        u, v = [], []
        for i in range(n):
            for j in range(i + 1, n):
                if not graph.has_edges_between(i, j):
                    u.append(i)
                    v.append(j)
        return dgl.graph((u + v, v + u), num_nodes=n)

    def __base_atom_featurizer(self, mol):
        '''Return {'atomic': float tensor (num_atoms, 1)} of atomic numbers.'''
        feats = [atom.GetAtomicNum() for atom in mol.GetAtoms()]
        return {'atomic': torch.tensor(feats).reshape(-1, 1).float()}

    def __my_atom_featurizer(self, mol):
        '''
        Featurize N/C/O/S atoms as a 24-dim vector:
        [explicit-valence one-hot (12) | atom-type one-hot (4) | valence one-hot (8)].

        Raises KeyError for any other element.
        '''
        atom_bond_dict = {'N': [1, 0, 5], 'C': [2, 0, 4], 'O': [3, 0, 6], 'S': [4, 0, 6]}

        def oneHot(number, max_size, index):
            one_hot = np.zeros(max_size)
            one_hot[number - index] = 1
            return one_hot

        feats = []
        for atom in mol.GetAtoms():
            atom_type, _degree, valence = atom_bond_dict[atom.GetSymbol()]
            degree_onehot = oneHot(atom.GetExplicitValence(), 12, 0)
            atom_type_onehot = oneHot(atom_type, 4, 1)
            valence_onehot = oneHot(valence, 8, 1)
            feats.append(np.concatenate((degree_onehot, atom_type_onehot, valence_onehot)))

        return {'atomic': torch.tensor(np.asarray(feats, dtype=np.float32))}

    def __base_edge_featurizer(self ,mol, add_self_loop = False):
        '''
        Bond-type edge features: SINGLE 0, DOUBLE 1, TRIPLE 2, AROMATIC 3.

        Each bond yields two entries, one for (u, v) and one for (v, u).
        ``add_self_loop`` is accepted but unused.
        '''
        feats = []
        bond_types = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
                      Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
        for bond in mol.GetBonds():
            btype = bond_types.index(bond.GetBondType())
            feats.extend([btype, btype])
        return {'type': torch.tensor(feats).reshape(-1, 1).float()}

    def produce_batch_smiles(self,smiles):
        '''Convert an iterable of SMILES strings into a list of featurized DGL graphs.'''
        graphs = []
        for smile in smiles:
            graph = dgllife.utils.smiles_to_bigraph(smile, node_featurizer=self.atom_featurizer,
                                                    edge_featurizer=self.edge_featurizer)
            graphs.append(graph)
        return graphs

    def produce_batch(self):
        '''
        Pull the next batch from the dataset iterator and return its graphs.

        Raises StopIteration when the underlying iterator is exhausted.
        '''
        _X, _labels, _, smiles = next(self.data_iter)
        # Convert exactly the SMILES returned. The final batch may be shorter
        # than ``batch_size`` (deepchem's default without padding), so indexing
        # ``range(self.batch_size)`` would raise IndexError.
        return self.produce_batch_smiles(smiles)
        

In [None]:
# Supported elements -> [element type index (1-13), 0, third column].
# The featurizers one-hot encode the first and third columns; the same table is
# duplicated inside mol_checker, which rejects atoms whose explicit valence
# exceeds 8 - third column.
atom_bond_dict = {
    'N': [1, 0, 5],
    'C': [2, 0, 4],
    'O': [3, 0, 6],
    'S': [4, 0, 6],
    'F': [5, 0, 7],
    'Cl': [6, 0, 7],
    'Na': [7, 0, 7],
    'P': [8, 0, 5],
    'Br': [9, 0, 7],
    'Si': [10, 0, 4],
    'B': [11, 0, 3],
    'Se': [12, 0, 6],
    'K': [13, 0, 7],
}

In [None]:
# Sanity check that carbon is a supported element (cell displays True).
'C' in atom_bond_dict.keys()

In [None]:
def mol_checker(mol_rep, form = 'MOL'):
    '''
    Heuristic validity filter for generated molecules.

    Parameters
    ----------
    mol_rep
        The molecule, interpreted according to ``form``.
    form : {'MOL', 'SMILES', 'GRAPH'}
        'SMILES' parses ``mol_rep`` with ``Chem.MolFromSmiles``; 'GRAPH' decodes
        a DGL graph with ``MolFromGraphsAro``; 'MOL' uses ``mol_rep`` as-is.

    Returns
    -------
    bool
        False when no molecule could be obtained, or when the molecule:
        * has a number of aromatic atoms that is not a multiple of 6, or any
          non-carbon aromatic atom;
        * contains a bond of order >= 3;
        * contains an atom that fails ``UpdatePropertyCache``, whose symbol is
          not supported, or whose explicit valence exceeds
          ``8 - atom_bond_dict[symbol][-1]``.
        True otherwise.

    Raises
    ------
    ValueError
        For an unrecognised ``form``. Errors from ``MolFromGraphsAro`` propagate.
    '''
    atom_bond_dict = {'N':[1,0,5], 'C':[2,0,4], 'O':[3,0,6], 'S':[4,0,6],
                               'F':[5,0,7], 'Cl' : [6,0,7],'Na':[7,0,7], 'P' : [8,0,5],
                               'Br':[9,0,7], 'Si' : [10,0,4],'B':[11,0,3], 'Se' : [12,0,6],
                               'K':[13,0,7]}

    if form == "SMILES":
        if mol_rep == 'nan':
            return False
        try:
            mol = Chem.MolFromSmiles(mol_rep)
        except Exception:
            return False
    elif form == "GRAPH":
        mol = MolFromGraphsAro(mol_rep)
    elif form == 'MOL':
        mol = mol_rep
    else:
        raise ValueError(f"unknown form {form!r}; expected 'MOL', 'SMILES' or 'GRAPH'")

    # RDKit signals an unparsable SMILES by returning None rather than raising.
    if mol is None:
        return False

    aro_atoms = mol.GetAromaticAtoms()
    if len(aro_atoms) % 6 != 0 or not all(aro_atom.GetSymbol() == 'C' for aro_atom in aro_atoms):
        return False

    if any(bond.GetBondTypeAsDouble() >= 3 for bond in mol.GetBonds()):
        return False

    for atom in mol.GetAtoms():
        try:
            atom.UpdatePropertyCache()
        except Exception:
            return False
        symbol = atom.GetSymbol()
        if symbol not in atom_bond_dict or 8 - atom_bond_dict[symbol][-1] < atom.GetExplicitValence():
            return False
    return True
        
        

In [None]:
def featurize_atoms(mol):
    '''
    Return {'atomic': float tensor of shape (num_atoms, 1)} holding each atom's atomic number.

    (A stray ``np.concatenate()`` call with no arguments made every call raise TypeError.)
    '''
    feats = [atom.GetAtomicNum() for atom in mol.GetAtoms()]
    return {'atomic': torch.tensor(feats).reshape(-1, 1).float()}

In [None]:

def CustomAtomFeaturizer(mol):
    '''
    Per-atom feature vector of length 30, concatenating:
      * element one-hot, slot = type index from the table below (1-13)   (14)
      * degree one-hot for degrees 1-4                                    (4)
      * formal charge                                                     (1)
      * hybridization one-hot over SP / SP2 / SP3                         (3)
      * aromaticity flag                                                  (1)
      * total-valence one-hot for valences 1-6                            (6)
      * atomic mass / 127                                                 (1)

    Degrees, hybridizations and valences outside the encoded ranges leave their
    one-hot segment all zeros.

    Returns {'atomic': float32 tensor of shape (num_atoms, 30)}.
    Raises KeyError for an element not in the table.
    '''
    atom_bond_dict = {'N':[1,0,5], 'C':[2,0,4], 'O':[3,0,6], 'S':[4,0,6],
                               'F':[5,0,7], 'Cl' : [6,0,7],'Na':[7,0,7], 'P' : [8,0,5],
                               'Br':[9,0,7], 'Si' : [10,0,4],'B':[11,0,3], 'Se' : [12,0,6],
                               'K':[13,0,7]}
    hybridization_ids = ["SP", "SP2", "SP3"]

    def one_hot(position, size):
        # Zero vector of length `size` with a 1 at `position`, if 0 <= position < size.
        # Out-of-range positions leave it all zeros (no negative-index wrap-around).
        vec = np.zeros(size)
        if 0 <= position < size:
            vec[position] = 1
        return vec

    feats = []
    for atom in mol.GetAtoms():
        atomic_num_onehot = one_hot(atom_bond_dict[atom.GetSymbol()][0], 14)

        # Degrees 1-4 map to slots 0-3.
        num_bonds_onehot = one_hot(atom.GetDegree() - 1, 4)

        formal_charge = np.array([atom.GetFormalCharge()], dtype=float)

        hybridization = atom.GetHybridization().name
        hyb_pos = hybridization_ids.index(hybridization) if hybridization in hybridization_ids else -1
        hybridization_onehot = one_hot(hyb_pos, 3)

        is_aromatic = np.array([float(atom.GetIsAromatic())])

        # Valences 1-6 map to slots 0-5.
        valence_onehot = one_hot(atom.GetTotalValence() - 1, 6)

        # Scale factor 127: just above iodine's mass (126.904).
        mass = np.array([atom.GetMass() / 127])

        feats.append(np.concatenate((atomic_num_onehot, num_bonds_onehot, formal_charge,
                                     hybridization_onehot, is_aromatic, valence_onehot, mass)))

    return {'atomic': torch.tensor(np.asarray(feats, dtype=np.float32))}
        
        

In [None]:
def smiles_to_graph(smiles):
    '''
    Convert an iterable of SMILES strings into DGL graphs.

    Each graph is featurized with ``my_atom_featurizer`` / ``base_edge_featurizer``
    and has exactly one self-loop per node.

    Best effort: SMILES that fail to parse or featurize (e.g. unsupported
    elements or bond types) are skipped. The returned list may therefore be
    shorter than the input, and positions no longer line up with it.
    '''
    graphs = []
    for smile in smiles:
        try:
            graph = dgllife.utils.smiles_to_bigraph(smile, node_featurizer=my_atom_featurizer,
                                                    edge_featurizer=base_edge_featurizer)
            if graph is None:
                # Unparsable SMILES may yield None rather than raising.
                continue
            graph = dgl.add_self_loop(dgl.remove_self_loop(graph))
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still propagate.
            continue
        graphs.append(graph)
    return graphs

In [None]:
def mol_to_graph(mol):
    '''
    Featurize an RDKit molecule into a DGL bigraph with exactly one self-loop per node.

    Uses ``my_atom_featurizer`` / ``base_edge_featurizer`` and keeps the input
    atom order (``canonical_atom_order=False``).
    '''
    raw_graph = dgllife.utils.mol_to_bigraph(
        mol,
        node_featurizer=my_atom_featurizer,
        edge_featurizer=base_edge_featurizer,
        canonical_atom_order=False,
    )
    return dgl.add_self_loop(dgl.remove_self_loop(raw_graph))


In [None]:
def update_graph(graph):
    '''
    Rebuild ``graph`` from the molecule it encodes, so node/edge features are recomputed.

    The graph is round-tripped through RDKit (``MolFromGraphsAro`` -> ``mol_to_graph``).
    If decoding or featurizing fails, the graph for 'C-C' (ethane) is returned instead.
    '''
    try:
        mol = MolFromGraphsAro(graph)
        # mol_to_graph already adds one self-loop per node, so no extra selfLoop
        # call is needed. Its return value was previously discarded anyway.
        return mol_to_graph(mol)
    except Exception:
        return smiles_to_graph(['C-C'])[0]

In [None]:
def base_edge_featurizer(mol, add_self_loop = False):
    '''
    Bond-type edge features: SINGLE 1, DOUBLE 2, AROMATIC 3, TRIPLE 4.

    Each bond contributes two identical rows, for the (u, v) and (v, u) edges.
    Returns {'type': float tensor of shape (2 * num_bonds, 1)}.
    Raises ValueError for any other bond type. ``add_self_loop`` is accepted but unused.
    '''
    ordered_types = [
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.AROMATIC,
        Chem.rdchem.BondType.TRIPLE,
    ]
    codes = [ordered_types.index(bond.GetBondType()) + 1 for bond in mol.GetBonds()]
    both_directions = [code for code in codes for _ in range(2)]
    return {'type': torch.tensor(both_directions).reshape(-1, 1).float()}

In [None]:
def my_atom_featurizer2(mol):
    '''
    Featurize each atom as a 34-dim vector, without refreshing RDKit's property cache:
      [0:12]   one-hot of the atom's explicit valence (0-11)
      [12:26]  one-hot of the element's type index from the table below (slot = index - 1)
      [26:34]  one-hot of the table's third column (slot = value - 1)

    Returns {'atomic': float32 tensor of shape (num_atoms, 34)}.
    Raises KeyError for elements not in the table.
    '''
    element_table = {
        'N': [1, 0, 5], 'C': [2, 0, 4], 'O': [3, 0, 6], 'S': [4, 0, 6],
        'F': [5, 0, 7], 'Cl': [6, 0, 7], 'Na': [7, 0, 7], 'P': [8, 0, 5],
        'Br': [9, 0, 7], 'Si': [10, 0, 4], 'B': [11, 0, 3], 'Se': [12, 0, 6],
        'K': [13, 0, 7],
    }

    def encode(value, width, base):
        vec = np.zeros(width)
        vec[value - base] = 1
        return vec

    rows = []
    for atom in mol.GetAtoms():
        type_index, _, third_col = element_table[atom.GetSymbol()]
        rows.append(np.concatenate((
            encode(atom.GetExplicitValence(), 12, 0),
            encode(type_index, 14, 1),
            encode(third_col, 8, 1),
        )))

    return {'atomic': torch.tensor(rows).float()}

In [None]:
def my_atom_featurizer(mol):
    '''
    Featurize each atom as a 34-dim vector. RDKit's property cache is refreshed
    per atom first, so explicit valence is up to date:
      [0:12]   one-hot of the atom's explicit valence (0-11)
      [12:26]  one-hot of the element's type index from the table below (slot = index - 1)
      [26:34]  one-hot of the table's third column (slot = value - 1)

    Returns {'atomic': float32 tensor of shape (num_atoms, 34)}.
    Raises KeyError for elements not in the table.
    '''
    element_table = {
        'N': [1, 0, 5], 'C': [2, 0, 4], 'O': [3, 0, 6], 'S': [4, 0, 6],
        'F': [5, 0, 7], 'Cl': [6, 0, 7], 'Na': [7, 0, 7], 'P': [8, 0, 5],
        'Br': [9, 0, 7], 'Si': [10, 0, 4], 'B': [11, 0, 3], 'Se': [12, 0, 6],
        'K': [13, 0, 7],
    }

    def encode(value, width, base):
        vec = np.zeros(width)
        vec[value - base] = 1
        return vec

    rows = []
    for atom in mol.GetAtoms():
        atom.UpdatePropertyCache()
        type_index, _, third_col = element_table[atom.GetSymbol()]
        rows.append(np.concatenate((
            encode(atom.GetExplicitValence(), 12, 0),
            encode(type_index, 14, 1),
            encode(third_col, 8, 1),
        )))

    return {'atomic': torch.tensor(rows).float()}

In [None]:
def test_atom_featurizer(mol):
    feats = []
    for atom in mol.GetAtoms():
        feats.append([1,3])
    return {'atomic': torch.tensor(feats).float()}

In [None]:
def atom_type_from_tensor(tensor):
    '''
    Placeholder for decoding an element symbol from a feature tensor.

    Not implemented: it only lists the supported elements and returns None.
    TODO: implement, or use ``FeatToAtomAro``, which already decodes this layout.
    '''
    atom_list = ('N', 'C', 'O', 'S', 'F', 'Cl', 'Na', 'P', 'Br', 'Si', 'B', 'Se', 'K')
    return None
    

In [None]:
def CustomAtomFeaturizer_full(mol):
    '''
    Rich per-atom featurizer. For every atom the feature vector concatenates:
      * one-hot element type (15 slots; slot = table index - 1)
      * one-hot of the table's last column (8 slots)
      * dgllife degree one-hot
      * dgllife hybridization one-hot
      * aromaticity flag (1)
      * dgllife explicit-valence one-hot
      * dgllife implicit-valence one-hot
      * atomic mass / 127 (1)

    Hybridization is recomputed with ``Chem.SetHybridization`` before featurizing,
    and each atom's property cache is refreshed.

    Returns {'atomic': float32 tensor of shape (num_atoms, feature_dim)}.
    Raises KeyError for elements outside the table.

    NOTE(review): this table uses 'B': [11, 0, 5], whereas the other copies in
    the notebook use [11, 0, 3].
    '''
    element_table = {
        'N': [1, 0, 5], 'C': [2, 0, 4], 'O': [3, 0, 6], 'S': [4, 0, 6],
        'F': [5, 0, 7], 'Cl': [6, 0, 7], 'Na': [7, 0, 7], 'P': [8, 0, 5],
        'Br': [9, 0, 7], 'Si': [10, 0, 4], 'B': [11, 0, 5], 'Se': [12, 0, 6],
        'K': [13, 0, 7],
    }

    Chem.SetHybridization(mol)

    rows = []
    for atom in mol.GetAtoms():
        atom.UpdatePropertyCache()
        symbol = atom.GetSymbol()

        element_oh = np.zeros(15)
        element_oh[element_table[symbol][0] - 1] = 1

        last_col_oh = np.zeros(8)
        last_col_oh[element_table[symbol][-1]] = 1

        rows.append(np.concatenate((
            element_oh,
            last_col_oh,
            dgllife.utils.atom_degree_one_hot(atom),
            dgllife.utils.atom_hybridization_one_hot(atom),
            np.array([int(atom.GetIsAromatic())]),
            dgllife.utils.atom_explicit_valence_one_hot(atom),
            dgllife.utils.atom_implicit_valence_one_hot(atom),
            np.array([atom.GetMass() / 127]),  # 127 ~ just above iodine's 126.904
        )))

    return {'atomic': torch.tensor(np.array(rows)).float()}
        

In [None]:
def edge_featurizer_full(mol, add_self_loop = False):
    '''
    Bond-type edge features: SINGLE 1, DOUBLE 2, AROMATIC 3.

    Each bond contributes two identical rows, for the (u, v) and (v, u) edges.
    Returns {'type': float tensor of shape (2 * num_bonds, 1)}.
    Raises ValueError for any other bond type (e.g. TRIPLE).
    ``add_self_loop`` is accepted but unused.
    '''
    known_types = [Chem.rdchem.BondType.SINGLE,
                   Chem.rdchem.BondType.DOUBLE,
                   Chem.rdchem.BondType.AROMATIC]
    codes = []
    for bond in mol.GetBonds():
        code = known_types.index(bond.GetBondType()) + 1
        codes += [code, code]
    return {'type': torch.tensor(codes).reshape(-1, 1).float()}

In [None]:
def mol_to_graph_full(mol, node_featurizer=None, edge_featurizer=None):
    '''
    Featurize an RDKit molecule with the "full" featurizers into a DGL bigraph.

    The atom order is preserved, and the graph has exactly one self-loop per node.

    Parameters
    ----------
    mol : rdkit Mol
    node_featurizer, edge_featurizer : callable, optional
        Override the defaults, ``CustomAtomFeaturizer_full`` and
        ``edge_featurizer_full`` (both defined in this notebook).

    NOTE(review): the names ``my_atom_featurizer_full`` / ``base_edge_featurizer_full``
    are not defined in the visible notebook, so calls raised NameError. Confirm the
    defaults above are the intended featurizers.
    '''
    if node_featurizer is None:
        node_featurizer = CustomAtomFeaturizer_full
    if edge_featurizer is None:
        edge_featurizer = edge_featurizer_full
    graph = dgllife.utils.mol_to_bigraph(mol, node_featurizer=node_featurizer, edge_featurizer=edge_featurizer,
                                        canonical_atom_order=False)
    graph = dgl.add_self_loop(dgl.remove_self_loop(graph))
    return graph

In [None]:
class AddEdge(nn.Module):
    '''
    Score the existing nodes as attachment points for a newly generated node.

    Two independent gated graph convolutions (``GGC_1``, ``GGC_2``) each produce
    tanh node states. Each node state is dotted with an embedding of ``new_node``,
    and the scores are softmax-normalised over the nodes. The two (1, num_nodes)
    score rows are concatenated, so the output is (1, 2 * num_nodes), with the
    GGC_1 scores first.

    NOTE(review): the two channels are presumably meant to represent two bond types -- confirm.
    '''

    def __init__(self, in_dim, hidden_dim):
        super(AddEdge, self).__init__()
        self.GGC_1 = dgl.nn.GatedGraphConv(in_dim, hidden_dim, 5, 4)
        self.GGC_2 = dgl.nn.GatedGraphConv(in_dim, hidden_dim, 5, 4)
        self.NodeEmbed = nn.Linear(in_dim, hidden_dim)

    def forward(self, g, h, e, new_node):
        # new_node is expected to be (1, node_feats); transposing its embedding
        # gives a (hidden_dim, 1) query column.
        query = torch.transpose(self.NodeEmbed(new_node), -1, 0)
        states = [torch.tanh(conv(g, h, e)) for conv in (self.GGC_1, self.GGC_2)]
        rows = [torch.transpose(torch.softmax(torch.matmul(s, query), dim=0), 1, 0) for s in states]
        return torch.cat(rows, dim=1)

In [None]:
class AddNode(nn.Module):
    '''
    Graph-level encoder intended for choosing which atom to add next.

    forward applies a gated graph convolution with ReLU, mean-pools the node
    states per graph, and returns that pooled embedding.

    NOTE(review): the atom-proposal logits (``Dense`` -> ``NodeProposal``,
    ``num_atom`` wide) are computed but NOT returned, and ``Pool`` is never used.
    Callers receive the pooled embedding. Confirm which output is intended before
    changing it, because that alters the action-space width downstream.
    '''
    def __init__(self, in_dim, hidden_dim, num_atom):
        super(AddNode, self).__init__()
        self.GGC_NewNode = dgl.nn.GatedGraphConv(in_dim, hidden_dim, 5, 4)
        self.Pool = dgl.nn.pytorch.glob.GlobalAttentionPooling(nn.Linear(hidden_dim, 20))
        self.Dense = nn.Linear(hidden_dim, 50)
        self.NodeProposal = nn.Linear(50, num_atom)

    def forward(self, g, h, e):
        node_states = F.relu(self.GGC_NewNode(g, h, e))
        with g.local_scope():
            g.ndata['h'] = node_states
            pooled = dgl.mean_nodes(g, 'h')
        proposal_logits = self.NodeProposal(torch.relu(self.Dense(pooled)))  # unused, see NOTE
        return pooled


In [None]:
class Linear_3(nn.Module):
    '''
    Three-layer MLP: Linear -> tanh -> Linear -> tanh -> Linear.

    Maps (..., in_dim) inputs to (..., out_dim) outputs. The last layer has no activation.
    '''
    def __init__(self, in_dim, hidden_dim, out_dim):
        super(Linear_3, self).__init__()
        self.Dense1 = nn.Linear(in_dim, hidden_dim)
        self.Dense2 = nn.Linear(hidden_dim, hidden_dim)
        self.Dense3 = nn.Linear(hidden_dim, out_dim)

    def forward(self, inputs):
        hidden = inputs
        for layer in (self.Dense1, self.Dense2):
            hidden = torch.tanh(layer(hidden))
        return self.Dense3(hidden)
        

In [None]:
class Linear_3_bn(nn.Module):
    '''
    Three-layer MLP with batch norm applied after each hidden ReLU:
    Linear -> ReLU -> BatchNorm -> Linear -> ReLU -> BatchNorm -> Linear.
    '''
    def __init__(self, in_dim, hidden_dim, out_dim):
        super(Linear_3_bn, self).__init__()
        self.Dense1 = nn.Linear(in_dim, hidden_dim)
        self.Dense2 = nn.Linear(hidden_dim, hidden_dim)
        self.Dense3 = nn.Linear(hidden_dim, out_dim)

        self.norm1 = nn.BatchNorm1d(hidden_dim)
        self.norm2 = nn.BatchNorm1d(hidden_dim)

    def forward(self, inputs):
        hidden = inputs
        for layer, norm in ((self.Dense1, self.norm1), (self.Dense2, self.norm2)):
            hidden = norm(torch.relu(layer(hidden)))
        return self.Dense3(hidden)
        

In [None]:
class SingleGraphEdge(nn.Module):
    '''
    Per-node edge scorer for a single (unbatched) graph.

    Each node's gated-graph-conv state (ReLU) is concatenated with an embedding
    of ``last_node`` and mapped to 2 scores. The result is flattened to shape
    (1, 2 * num_nodes), with node i's two scores at positions 2i and 2i + 1.
    '''
    def __init__(self, in_dim, hidden_dim):
        super(SingleGraphEdge, self).__init__()
        self.NodeEmbed = nn.Linear(in_dim, hidden_dim)
        self.GGN = dgl.nn.GatedGraphConv(in_dim, hidden_dim, 5, 4)
        self.Edges = Linear_3(hidden_dim * 2, hidden_dim * 2, 2)

    def forward(self, graph, last_node):
        node_states = F.relu(self.GGN(graph, graph.ndata['atomic'], graph.edata['type'].squeeze()))
        node_count = graph.num_nodes()
        query = self.NodeEmbed(last_node).repeat(node_count, 1)
        scores = self.Edges(torch.cat((node_states, query), dim=1))
        return scores.reshape(1, node_count * 2)

In [None]:
class Spin(nn.Module):
    '''
    Policy head sharing one message-passing network across all action heads.

    A single gated graph convolution feeds three heads:
      * ``Weight``:  3 values from [pooled graph state | tiled last action]; these are
                     scales for the add-node and add-edge scores plus the terminate logit.
      * ``AddNode``: ``num_nodes`` add-node scores from the pooled graph state.
      * ``Edges``:   2 scores per existing node, conditioned on ``last_node``.

    forward returns a softmax over [terminate | add-node | add-edge], of shape
    (1, 1 + num_nodes + 2 * graph.num_nodes()). Only row 0 of the weights is
    used, so a single unbatched graph is assumed.
    '''

    def __init__(self, in_dim, hidden_dim, num_nodes):
        super(Spin, self).__init__()
        self.hidden_dim = hidden_dim
        self.GGN = dgl.nn.GatedGraphConv(in_dim, hidden_dim, 5, 4)

        self.Weight = Linear_3(hidden_dim * 2, hidden_dim * 2, 3)
        self.AddNode = Linear_3(hidden_dim, hidden_dim * 2, num_nodes)
        self.NodeEmbed = nn.Linear(in_dim, hidden_dim)
        self.Edges = Linear_3(hidden_dim * 2, hidden_dim * 2, 2)

    def forward(self, graph, last_action_node, last_node):
        node_states = F.relu(self.GGN(graph, graph.ndata['atomic'], graph.edata['type']))
        with graph.local_scope():
            graph.ndata['h'] = node_states
            graph_embedding = dgl.mean_nodes(graph, 'h')

        # last_action_node is presumably (1, 1); it is tiled to hidden_dim so it
        # matches the width of the pooled state.
        action_context = torch.cat((graph_embedding, last_action_node.repeat(1, self.hidden_dim)), dim=1)
        weights = self.Weight(action_context)
        add_node_scores = self.AddNode(graph_embedding)

        node_count = graph.num_nodes()
        candidate = self.NodeEmbed(last_node).repeat(node_count, 1)
        edge_scores = self.Edges(torch.cat((node_states, candidate), dim=1))
        edge_scores = torch.reshape(edge_scores, (1, node_count * 2))

        w_node, w_edge, w_stop = weights[0][0], weights[0][1], weights[0][2]
        terminate = w_stop.unsqueeze(0).unsqueeze(0)
        logits = torch.cat((terminate, add_node_scores * w_node, edge_scores * w_edge), dim=1)
        return torch.softmax(logits, dim=1)
        
        
        
        
from torch.distributions import Categorical



In [None]:
class EdgeTesting(nn.Module):
    '''
    Batched per-node edge scorer.

    For a batched graph, each node's gated-graph-conv state (ReLU) is
    concatenated with the embedding of its own graph's ``last_node_batch`` row
    and mapped to 2 scores. The output is a (total_num_nodes, 2) tensor whose
    rows follow the batched node order; it is not split per graph.
    '''

    def __init__(self, in_dim, hidden_dim):
        super(EdgeTesting, self).__init__()
        self.GGN = dgl.nn.GatedGraphConv(in_dim,hidden_dim,5,4)

        self.NodeEmbed = nn.Linear(in_dim, hidden_dim)
        self.Edges = Linear_3(hidden_dim*2,hidden_dim*2,2)

    def forward(self, graph, last_node_batch):
        '''
        Parameters
        ----------
        graph
            Batched graph with ``len(graph.batch_num_nodes()) == len(last_node_batch)``.
        last_node_batch
            One row of node features per graph (the most recently added node).

        Returns
        -------
        Tensor of shape (total_num_nodes, 2).
        '''
        h = F.relu(self.GGN(graph, graph.ndata['atomic'], graph.edata['type']))
        node_embed_batch = self.NodeEmbed(last_node_batch)
        nodes_per_graph = graph.batch_num_nodes()

        # Repeat each graph's last-node embedding once per node of that graph,
        # so it lines up row-for-row with `h`.
        batch_node_stacks = torch.cat(
            [node_embed_batch[i].repeat(nodes_per_graph[i], 1) for i in range(last_node_batch.shape[0])],
            dim=0,
        )

        return self.Edges(torch.cat((h, batch_node_stacks), dim=1))
        
        

In [None]:
class Weighting(nn.Module):
    '''
    Action-type policy: a gated graph convolution, mean pooling, and then two
    dense layers that output a softmax over 3 action types per graph.
    '''
    def __init__(self, in_dim, hidden_dim):
        super(Weighting, self).__init__()
        self.GGN = dgl.nn.GatedGraphConv(in_dim, hidden_dim, 5, 4)
        self.Dense = nn.Linear(hidden_dim, 50)
        self.Prediction = nn.Linear(50, 3)

    def forward(self, graph, last_action_node):
        # last_action_node is accepted for interface compatibility but not used.
        node_feats = graph.ndata['atomic']
        edge_feats = graph.edata['type']
        node_states = F.relu(self.GGN(graph, node_feats, edge_feats))
        with graph.local_scope():
            graph.ndata['h'] = node_states
            pooled = dgl.mean_nodes(graph, 'h')
        hidden = F.relu(self.Dense(pooled))
        return torch.softmax(self.Prediction(hidden), dim=1)
    


In [None]:
class MoleculeGenerator(nn.Module):
    '''
    Molecular generation agent.

    Observation s = (current graph, last action, node added). forward returns
    the sigmoid of [terminate | add-node scores | add-edge scores]. The add-node
    and add-edge parts are scaled by the policy network's weights (row 0 only,
    so a single graph is assumed).
    '''
    def __init__(self, in_dim, hidden_dim, num_atoms):
        super(MoleculeGenerator, self).__init__()
        self.policyNetwork = Weighting(in_dim, hidden_dim)
        self.AddNode = AddNode(in_dim, hidden_dim, num_atoms)
        self.AddEdge = AddEdge(in_dim, hidden_dim)

    def forward(self, graph, last_action_node, last_node):
        '''Produce the action-type weights and the per-action scores, squashed by a sigmoid.'''
        node_feats = graph.ndata['atomic']
        edge_feats = graph.edata['type']

        action_weights = self.policyNetwork(graph, last_action_node)
        node_part = self.AddNode(graph, node_feats, edge_feats) * action_weights[0][0]
        edge_part = self.AddEdge(graph, node_feats, edge_feats, last_node) * action_weights[0][1]
        stop_part = action_weights[0][2].reshape(1, 1)

        return torch.sigmoid(torch.cat((stop_part, node_part, edge_part), dim=1))
        


In [None]:
class GraphDiscriminator(nn.Module):
    '''
    Graph-level discriminator.

    A gated graph convolution with ReLU is followed by mean pooling over nodes,
    then a ``Linear_3`` MLP and a sigmoid, giving one score in (0, 1) per graph.
    '''
    def __init__(self, in_dim, hidden_dim):
        super(GraphDiscriminator, self).__init__()
        self.GGN = dgl.nn.GatedGraphConv(in_dim,hidden_dim,5,5)
        self.Dense = Linear_3(hidden_dim,hidden_dim,1)

    def forward(self, graph):
        h = F.relu(self.GGN(graph, graph.ndata['atomic'], graph.edata['type'].squeeze()))
        with graph.local_scope():
            graph.ndata['h'] = h
            hg = dgl.mean_nodes(graph, 'h')
        # No dropout is applied to `hg`. To add some, register an nn.Dropout in
        # __init__ and assign its output back to `hg`. Constructing one inline and
        # discarding the result does nothing.
        return torch.sigmoid(self.Dense(hg))
    

In [None]:
class CriticSqueeze(nn.Module):
    '''
    Value network (critic) for advantage estimation.

    Steps:
    1. Squeeze the edge-type features and run a gated graph convolution with ReLU.
    2. Mean-pool the node states.
    3. Concatenate the pooled vector with ``last_action_node`` and ``last_node``.
    4. Map the result through a ``Linear_3`` MLP to one value per graph.
    '''
    def __init__(self, in_dim, hidden_dim):
        super(CriticSqueeze, self).__init__()
        self.GGN = dgl.nn.GatedGraphConv(in_dim, hidden_dim, 5, 4)
        # Input width: pooled state (hidden_dim) + last action (1) + last node (in_dim).
        self.Dense = Linear_3(hidden_dim + 1 + in_dim, hidden_dim, 1)

    def forward(self, graph, last_action_node, last_node):
        node_feats = graph.ndata['atomic']
        edge_types = graph.edata['type'].squeeze()
        node_states = F.relu(self.GGN(graph, node_feats, edge_types))
        with graph.local_scope():
            graph.ndata['h'] = node_states
            pooled = dgl.mean_nodes(graph, 'h')
        joint = torch.cat((pooled, last_action_node, last_node), dim=1)
        return self.Dense(joint)
    

In [None]:
class Critic(nn.Module):
    '''
    Value network (critic) for advantage estimation.

    Same as ``CriticSqueeze`` except that the edge-type features are passed to
    the gated graph convolution unsqueezed. Node states are mean-pooled,
    concatenated with ``last_action_node`` and ``last_node``, and mapped to one
    value per graph.
    '''
    def __init__(self, in_dim, hidden_dim):
        super(Critic, self).__init__()
        self.GGN = dgl.nn.GatedGraphConv(in_dim, hidden_dim, 5, 4)
        # Input width: pooled state (hidden_dim) + last action (1) + last node (in_dim).
        self.Dense = Linear_3(hidden_dim + 1 + in_dim, hidden_dim, 1)

    def forward(self, graph, last_action_node, last_node):
        node_feats = graph.ndata['atomic']
        edge_feats = graph.edata['type']
        node_states = F.relu(self.GGN(graph, node_feats, edge_feats))
        with graph.local_scope():
            graph.ndata['h'] = node_states
            pooled = dgl.mean_nodes(graph, 'h')
        joint = torch.cat((pooled, last_action_node, last_node), dim=1)
        return self.Dense(joint)
    
                   

In [None]:
class CriticTest(nn.Module):
    '''
    Minimal critic: a gated graph convolution with ReLU, mean pooling over
    nodes, then a single linear layer giving one value per graph.
    '''
    def __init__(self, in_dim, hidden_dim):
        super(CriticTest, self).__init__()
        self.GGN = dgl.nn.GatedGraphConv(in_dim, hidden_dim, 5, 2)
        self.Dense = nn.Linear(hidden_dim, 1)

    def forward(self, graph):
        node_feats = graph.ndata['atomic']
        edge_feats = graph.edata['type']
        node_states = F.relu(self.GGN(graph, node_feats, edge_feats))
        with graph.local_scope():
            graph.ndata['h'] = node_states
            pooled = dgl.mean_nodes(graph, 'h')
        return self.Dense(pooled)

In [None]:
class CriticBN(nn.Module):
    '''
    Critic variant whose MLP head uses batch norm (``Linear_3_bn``).

    Node states from a gated graph convolution with ReLU are mean-pooled,
    concatenated with ``last_action_node`` and ``last_node``, and mapped to one
    value per graph.
    '''
    def __init__(self, in_dim, hidden_dim):
        super(CriticBN, self).__init__()
        self.GGN = dgl.nn.GatedGraphConv(in_dim, hidden_dim, 5, 4)
        # Input width: pooled state (hidden_dim) + last action (1) + last node (in_dim).
        self.Dense = Linear_3_bn(hidden_dim + 1 + in_dim, hidden_dim, 1)

    def forward(self, graph, last_action_node, last_node):
        node_feats = graph.ndata['atomic']
        edge_feats = graph.edata['type']
        node_states = F.relu(self.GGN(graph, node_feats, edge_feats))
        with graph.local_scope():
            graph.ndata['h'] = node_states
            pooled = dgl.mean_nodes(graph, 'h')
        joint = torch.cat((pooled, last_action_node, last_node), dim=1)
        return self.Dense(joint)
                   

In [None]:
def FeatToAtom(feat):
    '''
    Decode the element symbol from a node feature vector.

    This is the 4-element layout: slots 12-15 hold a one-hot over ['N', 'C', 'O', 'S'].
    The first slot equal to 1 wins. Raises IndexError if none of those slots is 1.
    '''
    symbols = ['N', 'C', 'O', 'S']
    type_slots = feat[12:16].cpu()
    hot_positions = np.where(type_slots == 1)[0]
    return symbols[hot_positions[0]]

def MolFromGraphs(graph):
    '''
    Rebuild an RDKit ``Mol`` from a DGL molecule graph that uses the 4-element
    (N/C/O/S) feature layout decoded by ``FeatToAtom``.

    Edge feature 'type' codes: 1 = single bond, 2 = double bond. Self-loops are
    ignored, because only node pairs u < v are inspected.

    Raises
    ------
    ValueError
        For any other bond code. Previously it printed a message and then
        reused the previous bond type, or failed with UnboundLocalError.
    '''
    bond_lookup = {
        1: Chem.rdchem.BondType.SINGLE,
        2: Chem.rdchem.BondType.DOUBLE,
    }

    feat_list = graph.ndata['atomic']
    node_list = [FeatToAtom(feat_list[i]) for i in range(feat_list.shape[0])]

    # Create an editable molecule and record each node's atom index.
    mol = Chem.RWMol()
    node_to_idx = {}
    for i, symbol in enumerate(node_list):
        node_to_idx[i] = mol.AddAtom(Chem.Atom(symbol))

    n = len(node_list)
    for u in range(n - 1):
        for v in range(u + 1, n):
            if graph.has_edges_between(u, v):
                bond_type = int(graph.edges[u, v][0]['type'].cpu().numpy()[0][0])
                if bond_type not in bond_lookup:
                    raise ValueError(f"unsupported bond type code {bond_type} between nodes {u} and {v}")
                mol.AddBond(node_to_idx[u], node_to_idx[v], bond_lookup[bond_type])

    return mol.GetMol()

In [None]:
def FeatToAtomAro(feat):
    '''
    Decode an element symbol from a node feature vector (13-element alphabet).

    Entries 12..24 of ``feat`` are read as a one-hot block over
    ('N', 'C', 'O', 'S', 'F', 'Cl', 'Na', 'P', 'Br', 'Si', 'B', 'Se', 'K');
    the first entry equal to 1 picks the symbol.
    Raises IndexError when no entry in that block equals 1.
    '''
    symbols = ['N', 'C', 'O', 'S', 'F', 'Cl', 'Na', 'P', 'Br', 'Si', 'B', 'Se', 'K']
    one_hot_block = feat[12:25].cpu()
    hot_positions = np.where(one_hot_block == 1)[0]
    return symbols[hot_positions[0]]

def MolFromGraphsAro(graph):
    '''
    Rebuild an RDKit molecule from a DGL molecular graph, including aromatic bonds.

    Each node's element is decoded from ``graph.ndata['atomic']`` with
    ``FeatToAtomAro``. For every node pair u < v joined by an edge, a bond is
    added whose type is read from ``graph.edges[u, v][0]['type']``
    (1 -> single, 2 -> double, 3 -> aromatic).

    Parameters
    ----------
    graph : DGLGraph with node feature 'atomic' and edge feature 'type'.

    Returns
    -------
    rdkit.Chem.Mol built from an RWMol via ``GetMol()`` (not sanitized here).

    Raises
    ------
    ValueError
        If an edge carries a bond type other than 1, 2 or 3.
    '''
    bond_types = {
        1: Chem.rdchem.BondType.SINGLE,
        2: Chem.rdchem.BondType.DOUBLE,
        3: Chem.rdchem.BondType.AROMATIC,
    }

    feat_list = graph.ndata['atomic']
    node_list = [FeatToAtomAro(feat_list[i]) for i in range(feat_list.shape[0])]

    # Create an empty editable molecule and record each graph node's atom index.
    mol = Chem.RWMol()
    node_to_idx = {}
    for i, symbol in enumerate(node_list):
        node_to_idx[i] = mol.AddAtom(Chem.Atom(symbol))

    num_nodes = len(node_list)
    for u in range(num_nodes - 1):
        for v in range(u + 1, num_nodes):
            if not graph.has_edges_between(u, v):
                continue
            bond_type = int(graph.edges[u, v][0]['type'].cpu().numpy()[0][0])
            bond = bond_types.get(bond_type)
            if bond is None:
                # An unknown type must not fall through to AddBond. Otherwise
                # `bond` would be unbound, or left over from an earlier pair.
                raise ValueError(
                    f"Unsupported bond type {bond_type} between nodes {u} and {v}"
                )
            mol.AddBond(node_to_idx[u], node_to_idx[v], bond)

    return mol.GetMol()

In [None]:
class Any_Edge(nn.Module):
    '''
    Scores every ordered (i, j) node pair of every graph in a batch.

    For each graph, the element-wise product of node embeddings h_i * h_j is
    formed for all n*n ordered pairs and passed through ``self.Edges``
    (a ``Linear_3`` whose last argument, 2, appears to be the output width).
    '''

    def __init__(self, in_dim, hidden_dim):
        super(Any_Edge, self).__init__()
        self.Edges = Linear_3(hidden_dim,hidden_dim*2,2)
        self.in_dim = in_dim

    def forward(self, graphs, h):
        '''
        Parameters
        ----------
        graphs : batched graph; only ``graphs.batch_num_nodes()`` is used, to
            split ``h`` into per-graph row blocks.
        h : node embeddings for the whole batch, rows ordered graph by graph
            (assumed (total_nodes, hidden_dim) — TODO confirm).

        Returns
        -------
        Tensor with one row per graph. Each row is that graph's flattened pair
        scores: for a graph with n nodes, pair row i*n + j holds the score of
        h_i * h_j. Rows of smaller graphs are right-padded with -10000.
        '''
        counts = [int(c) for c in graphs.batch_num_nodes()]
        per_graph_nodes = torch.split(h, counts, dim=0)

        # For n nodes, repeat() cycles rows (row k -> node k % n) and
        # repeat_interleave() holds each row n times (row k -> node k // n),
        # so their product enumerates all n*n ordered pairs.
        pair_products = torch.cat(
            [nodes.repeat(n, 1) * nodes.repeat_interleave(n, 0)
             for nodes, n in zip(per_graph_nodes, counts)],
            dim=0,
        )
        # Score the whole batch in one call, then cut it back into per-graph blocks.
        edges = self.Edges(pair_products)
        per_graph_edges = torch.split(edges, [n * n for n in counts], dim=0)

        return pad_sequence(
            [graph_edges.flatten() for graph_edges in per_graph_edges],
            batch_first=True,
            padding_value=-10000,
        ).flatten(1, -1)
    
    

In [None]:
class Spin3(nn.Module):
    '''
    Policy network with one shared message-passing trunk and three gated heads.

    A gated graph convolution embeds the nodes. The mean-pooled graph vector
    feeds an add-node head (``AddNode``), the node embeddings feed an
    all-pairs edge head (``AnyEdge``), and ``Weight`` produces three sigmoid
    gates per graph: column 0 scales the add-node logits, column 1 scales the
    edge logits, and column 2 is used directly as the terminate logit.
    '''

    def __init__(self, in_dim, hidden_dim, num_nodes):
        super(Spin3, self).__init__()
        self.hidden_dim = hidden_dim
        self.GGN = dgl.nn.GatedGraphConv(in_dim,hidden_dim,4,4)

        self.Weight = Linear_3(hidden_dim*2, hidden_dim*2, 3)
        self.AddNode = Linear_3(hidden_dim, hidden_dim*2,num_nodes)
        self.AnyEdge = Any_Edge(in_dim,hidden_dim)

    def forward(self, graph, last_action_node, softmax = True):
        '''
        Parameters
        ----------
        graph : batched DGLGraph with node feature 'atomic' and edge feature 'type'.
        last_action_node : per-graph tensor tiled to hidden_dim columns
            (assumed (B, 1) — TODO confirm).
        softmax : if True, return a softmax over dim 1; otherwise return raw scores.

        Returns
        -------
        Per graph, columns [terminate | add-node (num_nodes) | edge scores].
        '''
        # GatedGraphConv expects 1-D edge types of shape (E,). reshape(-1) handles
        # (E, 1) and (E,) alike, and unlike .squeeze() it never yields a 0-d
        # tensor for a single-edge graph.
        etypes = graph.edata['type'].reshape(-1)
        h = F.relu(self.GGN(graph, graph.ndata['atomic'], etypes))
        with graph.local_scope():
            graph.ndata['h'] = h
            hg = dgl.mean_nodes(graph, 'h')

        # Tile the last-action indicator so the gate input is hidden_dim*2 wide.
        last_action_node = last_action_node.repeat(1, self.hidden_dim)
        catted = torch.cat((hg, last_action_node), dim=1)
        # Every assignment to `weights` was commented out, so forward raised
        # NameError. Restore the sigmoid gating used by Spin2.
        weights = torch.sigmoid(self.Weight(catted))

        addNode = self.AddNode(hg)
        edges = self.AnyEdge(graph, h)

        # Gate each sample with its own weights. Indexing weights[0] would apply
        # the first sample's gates to the whole batch.
        addNode = addNode * weights[:, 0:1]
        edges = edges * weights[:, 1:2]
        terminate = weights[:, 2:3]

        out = torch.cat((terminate, addNode, edges), dim=1)
        if not softmax:
            return out
        return torch.softmax(out, dim=1)

In [None]:
class BaseLine(nn.Module):
    '''
    Simple baseline policy network to compare improvements against.

    A gated graph convolution embeds the nodes. The mean-pooled graph vector
    feeds an add-node head, and ``Batch_Edge`` scores an edge from every node
    to the last node. Both sets of logits are concatenated per graph.
    '''
    def __init__(self,in_dim,hidden_dim, num_nodes):
        super(BaseLine, self).__init__()
        self.in_dim = in_dim
        self.num_nodes = num_nodes
        self.GGN = dgl.nn.GatedGraphConv(in_dim, hidden_dim,5,4)
        self.AddNode = Linear_3(hidden_dim,hidden_dim,num_nodes)
        self.BatchEdge = Batch_Edge(in_dim,hidden_dim)

    def forward(self,graph,last_action_node, last_node, mask = False, softmax = True):
        '''
        Parameters
        ----------
        graph : batched DGLGraph with node feature 'atomic' and edge feature 'type'.
        last_action_node : per-graph multiplier for the add-node mask
            (assumed (B, 1) — TODO confirm).
        last_node : per-graph features of the last node, passed to ``Batch_Edge``.
        mask : if truthy, add ``-100000 * last_action_node`` to every add-node
            logit except index 0.
        softmax : if True, return a softmax over dim 1; otherwise return raw logits.

        Returns
        -------
        Per graph, columns [add-node (num_nodes) | edge scores].
        '''
        # GatedGraphConv expects 1-D edge types of shape (E,). reshape(-1) handles
        # (E, 1) and (E,) alike, and unlike .squeeze() it never yields a 0-d
        # tensor for a single-edge graph.
        etypes = graph.edata['type'].reshape(-1)
        h = F.relu(self.GGN(graph, graph.ndata['atomic'], etypes))
        with graph.local_scope():
            graph.ndata['h'] = h
            hg = dgl.mean_nodes(graph, 'h')

        addNode = self.AddNode(hg)

        if mask:
            # Build the mask on the logits' own device instead of relying on a
            # global `device`. node_mask = [0, 1, ..., 1] leaves index 0 unmasked.
            dev = addNode.device
            node_mask = torch.unsqueeze(
                torch.cat((torch.zeros(1, device=dev),
                           torch.ones(self.num_nodes - 1, device=dev))),
                dim=0,
            )
            addNode = addNode + last_action_node * (node_mask * -100000)

        edges = self.BatchEdge(graph, last_node, h)
        out = torch.cat((addNode, edges), dim=1)
        if softmax:
            return torch.softmax(out, dim=1)
        return out
        
        

In [None]:
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# #batch = dgl.load_graphs('./graph_decomp/chunk_3')
# bl = BaseLine(54,300,18)
# spin = Spin2(54,300,17)
# graphs = batch[0]
# last_action_nodes = batch[1]['last_action']
# last_atom_feats = batch[1]['last_atom_feats']
# actions = batch[1]['actions']
# i = 64
# # print(graphs[i],last_action_nodes[i],last_atom_feats[i])
# # print(actions[i])

# p = bl(graphs[i],torch.unsqueeze(last_action_nodes[i],dim=0),torch.unsqueeze(last_atom_feats[i],dim=0))
# s = spin(graphs[i],torch.unsqueeze(last_action_nodes[i],dim=0),torch.unsqueeze(last_atom_feats[i],dim=0))
# p.shape, s.shape


In [None]:
class Spin2(torch.nn.Module):
    '''
    Policy network with one shared message-passing trunk and three gated heads.

    A gated graph convolution embeds the nodes. The mean-pooled graph vector
    feeds an add-node head (``AddNode``), the node embeddings plus the last
    node feed ``Batch_Edge``, and ``Weight`` produces three sigmoid gates per
    graph: column 0 scales the add-node logits, column 1 scales the edge
    logits, and column 2 (halved) is the terminate logit.
    '''

    def __init__(self, in_dim, hidden_dim, num_nodes):
        super(Spin2, self).__init__()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.hidden_dim = hidden_dim
        self.GGN = dgl.nn.GatedGraphConv(in_dim,hidden_dim,5,4)

        self.Weight = Linear_3(hidden_dim*2, hidden_dim*2, 3)
        self.AddNode = Linear_3(hidden_dim, hidden_dim*2,num_nodes)
        self.BatchEdge = Batch_Edge(in_dim,hidden_dim)

    def forward(self, graph, last_action_node, last_node,softmax = True):
        '''
        Parameters
        ----------
        graph : batched DGLGraph with node feature 'atomic' and edge feature 'type'.
        last_action_node : per-graph tensor tiled to hidden_dim columns
            (assumed (B, 1) — TODO confirm).
        last_node : per-graph features of the last node, passed to ``Batch_Edge``.
        softmax : if True, return a softmax over dim 1; otherwise return raw scores.

        Returns
        -------
        Per graph, columns [terminate | add-node (num_nodes) | edge scores].
        '''
        # GatedGraphConv expects 1-D edge types of shape (E,). reshape(-1) handles
        # (E, 1) and (E,) alike, and unlike .squeeze() it never yields a 0-d
        # tensor for a single-edge graph.
        etypes = graph.edata['type'].reshape(-1)
        h = F.relu(self.GGN(graph, graph.ndata['atomic'], etypes))
        with graph.local_scope():
            graph.ndata['h'] = h
            hg = dgl.mean_nodes(graph, 'h')

        # Tile the last-action indicator so the gate input is hidden_dim*2 wide.
        last_action_node = last_action_node.repeat(1, self.hidden_dim)
        catted = torch.cat((hg, last_action_node), dim=1)
        weights = torch.sigmoid(self.Weight(catted))

        addNode = self.AddNode(hg)
        edges = self.BatchEdge(graph, last_node, h)

        # Gate each sample with its own weights. Indexing weights[0] would apply
        # the first sample's gates to the whole batch, while terminate was
        # already per-sample.
        addNode = addNode * weights[:, 0:1]
        edges = edges * weights[:, 1:2]
        terminate = weights[:, 2:3] / 2

        out = torch.cat((terminate, addNode, edges), dim=1)
        if softmax:
            return torch.softmax(out, dim=1)
        return out

In [None]:
class Batch_Edge(nn.Module):
    '''
    Scores, for every node of every graph in a batch, an edge to that graph's last node.

    The last node's features are embedded with ``NodeEmbed``, broadcast onto
    every node of the same graph, concatenated with the node embedding ``h``,
    and passed through ``self.Edges`` (a ``Linear_3`` whose last argument, 2,
    appears to be the output width).
    '''

    def __init__(self, in_dim, hidden_dim):
        super(Batch_Edge, self).__init__()
        self.NodeEmbed = nn.Linear(in_dim, hidden_dim)
        self.Edges = Linear_3(hidden_dim*2,hidden_dim*2,2)

    def forward(self, graphs, last_node_batch, h):
        '''
        Parameters
        ----------
        graphs : batched graph; only ``graphs.batch_num_nodes()`` is used.
            Its length must equal ``last_node_batch.shape[0]``.
        last_node_batch : one row of last-node features per graph.
        h : node embeddings for the whole batch, rows ordered graph by graph.

        Returns
        -------
        Tensor with one row per graph holding that graph's flattened per-node
        scores, right-padded with -10000 for graphs with fewer nodes.
        '''
        counts = [int(c) for c in graphs.batch_num_nodes()]
        node_embed_batch = self.NodeEmbed(last_node_batch)

        # Repeat each graph's last-node embedding once per node of that graph.
        repeats = torch.tensor(counts, device=node_embed_batch.device)
        batch_node_stacks = torch.repeat_interleave(node_embed_batch, repeats, dim=0)

        edges = self.Edges(torch.cat((h, batch_node_stacks), dim=1))

        # In a batched graph each graph's node rows are contiguous. Splitting by
        # node count therefore yields the same blocks as dgl.unbatch, without
        # rebuilding every graph.
        edges_per_graph = list(torch.split(edges, counts, dim=0))

        return pad_sequence(edges_per_graph, batch_first=True, padding_value=-10000).flatten(1, -1)

In [None]:
# Scratch check: scaling the -10000 padding sentinel by zero gives an all-zero (4, 1) tensor.
torch.mul(torch.full((4, 1), -10000), 0)