In [1]:
import json
import networkx as nx
import copy
import os

In [2]:
def load_data_from_file(filename):
    """
    Load a dictionary of graphs from JSON file.
    """
    with open(filename, "r") as file_handle:
        string_dict = json.load(file_handle)
    return _load_data_from_string_dict(string_dict)

def _load_data_from_string_dict(string_dict):
    """
    Internal helper to parse graphs from node-link JSON format.
    """
    result_dict = {}
    for key in string_dict:
        data = copy.deepcopy(string_dict[key])
        if 'edges' in data:
            data["links"] = data.pop("edges")
        graph = nx.node_link_graph(data)
        result_dict[key] = graph
    return result_dict

def write_data_to_json_file(graph_dict, filename, **kwargs):
    """
    Write dictionary of graphs to JSON file.
    """
    with open(filename, "w") as file_handle:
        json_string = json.dumps(graph_dict, default=nx.node_link_data, **kwargs)
        file_handle.write(json_string)

In [5]:
def create_duplicates(input_file, output_file):
    # Load the input graph data
    train_data = load_data_from_file(input_file)

    # List to hold all graphs and names
    training_graphs = []
    training_graph_names = []

    # Process each graph in the training data
    for i in range(len(train_data)):  # For each graph in the training set
        mol = list(train_data.keys())[i]
        graph = train_data[mol]  # Graph to be duplicated
        
        # Get the target nodes (with orbitals, even if they are -1)
        target_node_indices = [n for n, v in graph.nodes(data=True) if v['orbitals']]

        # List to hold all duplicated graphs and names for this graph
        all_graphs = []
        all_names = []

        count = 0
        for n_i in target_node_indices:  # For each target node
            graph_i = graph.copy()

            # Print orbitals for inspection before filtering
            #print(f"Processing molecule {mol}: Node {n_i} orbitals before filtering: {graph_i.nodes[n_i]['orbitals']}")

            # Add 'NONE' bonds to all non-neighboring nodes of the target node
            for nb_i in nx.non_neighbors(graph_i, n_i):
                graph_i.add_edge(n_i, nb_i, bond_type='NONE')  # Add edge to target node
            
            # Create duplicates for each orbital of the target node, including those with -1
            for j in range(len(graph_i.nodes[n_i]['orbitals'])):
                graph_ij = graph_i.copy()

                # No skipping based on orbital value now
                graph_ij.nodes[n_i]['orbitals'] = [graph_ij.nodes[n_i]['orbitals'][j]]

                # Print orbitals after filtering for comparison
                #print(f"Processing molecule {mol}: Node {n_i} orbitals after filtering: {graph_ij.nodes[n_i]['orbitals']}")

                # Mark the target node and other nodes
                graph_ij.nodes[n_i]['target'] = True
                for n in graph_ij.nodes:
                    if n != n_i:
                        graph_ij.nodes[n]['target'] = False
                
                # Append the duplicated graph and its name
                all_graphs.append(graph_ij)
                name = f"{mol}_{n_i}_{j}"  # Name based on molecule, target node, and orbital index
                all_names.append(name)
                count += 1

        # Add the generated graphs and names to the overall lists
        training_graphs.extend(all_graphs)
        training_graph_names.extend(all_names)

        # Print how many duplicates were created for this graph
        #print(f"Graph {i} for molecule {mol} created {count} duplicates.")

    # Write the duplicates to a new JSON file
    training_data_dict = dict(zip(training_graph_names, training_graphs))
    write_data_to_json_file(training_data_dict, output_file, indent=2)

In [6]:
create_duplicates('graph_data.json', 'graph_data_duplicates.json')

Graph 0 for molecule [Ag] created 1 duplicates.
Graph 1 for molecule C/C(=C\C(=O)C(F)(F)F)/O[Al](O/C(=C\C(=O)C(F)(F)F)/C)O/C(=C\C(=O)C(F)(F)F)/C created 31 duplicates.
Graph 2 for molecule C(=C(\O[Al](O/C(=C\C(=O)C(F)(F)F)/C(F)(F)F)O/C(=C\C(=O)C(F)(F)F)/C(F)(F)F)/C(F)(F)F)\C(=O)C(F)(F)F created 40 duplicates.
Graph 3 for molecule C/C(=C/C(=O)C)/O[Al](O/C(=C\C(=O)C)/C)O/C(=C\C(=O)C)/C created 22 duplicates.
Graph 4 for molecule CC(/C(=C/C(=O)C(C)(C)C)/O[Al](O/C(=C\C(=O)C(C)(C)C)/C(C)(C)C)O/C(=C\C(=O)C(C)(C)C)/C(C)(C)C)(C)C created 40 duplicates.
