In [1]:
import networkx as nx
import json
import copy

In [2]:
def load_data_from_file(filename):
    """Load a JSON file of node-link graph data into a dict of name -> networkx graph.

    The file must contain a single JSON object whose values are node-link dicts
    (edge list stored under either ``"links"`` or ``"edges"``); conversion of each
    value is delegated to ``_load_data_from_string_dict``.

    Args:
        filename: Path to the JSON file to read.

    Returns:
        dict mapping each top-level key of the JSON object to a networkx graph.
    """
    # JSON is UTF-8 by specification; don't let the platform's locale encoding decide.
    with open(filename, "r", encoding="utf-8") as file_handle:
        string_dict = json.load(file_handle)
    return _load_data_from_string_dict(string_dict)

def _load_data_from_string_dict(string_dict):
    result_dict = {}
    for key in string_dict:
        data = copy.deepcopy(string_dict[key])
        if 'edges' in data:
            data["links"] = data.pop("edges")
        graph = nx.node_link_graph(data)
        result_dict[key] = graph
    return result_dict

In [3]:
# train_data maps each molecule's SMILES string to its networkx graph.
loaddir = "../data/graphs/"
train_data = load_data_from_file(f"{loaddir}cleaned_graph_data_10June.json")

In [4]:
# Expand every molecule into one training graph per (target atom, labelled orbital) pair.
# In each expanded graph:
#   * the target atom gets a bond_type='NONE' edge to every atom it is not already bonded to,
#     so message passing reaches it from every other node in one hop;
#   * its 'orbitals' / 'binding_energies' lists are cut down to the single orbital being predicted;
#   * 'pred' is True on the target atom and False on every other atom.
training_graphs = []       # expanded graphs, one per (molecule, target node, orbital)
training_graph_names = []  # parallel list of names "<SMILES>_<node id>_<orbital index>"

for mol, graph in train_data.items():  # keys of train_data are SMILES strings
    for n_i in graph.nodes:
        # Orbitals stored as -1 have no binding energy, so they yield no training example.
        labelled_orbitals = [
            j for j, orbital in enumerate(graph.nodes[n_i]['orbitals']) if str(orbital) != '-1'
        ]
        if not labelled_orbitals:
            continue  # nothing to predict for this atom; skip copying the graph at all

        # Fully connect the target atom: nx.non_neighbors yields every node that is
        # neither n_i itself nor already adjacent to it.
        graph_i = graph.copy()
        for nb_i in nx.non_neighbors(graph_i, n_i):
            graph_i.add_edge(n_i, nb_i, bond_type='NONE')

        for j in labelled_orbitals:
            # Graph.copy() gives the copy its own node-attribute dicts, so the key
            # reassignments below never leak into graph_i or the source molecule graph.
            graph_ij = graph_i.copy()
            target = graph_ij.nodes[n_i]
            target['orbitals'] = [target['orbitals'][j]]
            target['binding_energies'] = [target['binding_energies'][j]]
            target['e_neg_score'] = [target['e_neg_score'][0]]  # only depends on the atom, not the orbital

            # Mark this atom as the single prediction target.
            for n in graph_ij.nodes:
                graph_ij.nodes[n]['pred'] = False
            target['pred'] = True

            training_graphs.append(graph_ij)
            training_graph_names.append(f"{mol}_{n_i}_{j}")

In [5]:
# Key each expanded graph by its "<SMILES>_<node>_<orbital>" name.
# zip() would silently truncate mismatched lists and dict() would silently keep only the
# last graph for a repeated name, so check both invariants explicitly.
assert len(training_graph_names) == len(training_graphs), "graph names and graphs are out of sync"
training_data_dict = dict(zip(training_graph_names, training_graphs))
assert len(training_data_dict) == len(training_graph_names), "duplicate graph names would drop graphs"

In [6]:
def write_data_to_json_file(graph_dict, filename, **kwargs):
    """Serialize a dict whose values may be networkx graphs to a UTF-8 JSON file.

    Graphs (any ``nx.Graph`` subclass) are converted with ``nx.node_link_data``;
    any other object that ``json`` cannot encode natively raises ``TypeError``.
    The whole document is serialized before the file is opened, so a failure
    leaves any existing file at ``filename`` untouched.

    Args:
        graph_dict: Mapping to write (typically name -> networkx graph).
        filename: Destination path; overwritten on success.
        **kwargs: Extra keyword arguments forwarded to ``json.dumps``
            (e.g. ``indent=2``, ``ensure_ascii=False``).

    Raises:
        TypeError: if a value is neither JSON-serializable nor a networkx graph.
    """
    def _encode_graph(obj):
        # json calls `default` for every object it cannot encode natively; only graphs
        # get node-link conversion, everything else follows json's TypeError contract.
        if isinstance(obj, nx.Graph):
            return nx.node_link_data(obj)
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    # Serialize first: opening with "w" truncates the file, so a failure mid-way
    # must not happen after the previous output has already been wiped.
    json_string = json.dumps(graph_dict, default=_encode_graph, **kwargs)
    with open(filename, "w", encoding="utf-8") as file_handle:
        file_handle.write(json_string)

In [7]:
# Persist the expanded training set (one graph per target atom/orbital) next to the input data.
write_data_to_json_file(
    training_data_dict,
    f"{loaddir}graph_data_duplicates.json",
    indent=2,
)