In [1]:
import json
import networkx as nx
import copy
import os

In [2]:
def load_data_from_file(filename):
    """
    Load a dictionary of graphs from a JSON file.

    Parameters
    ----------
    filename : str or path-like
        Path of the JSON file to read. The parsed top-level JSON value is
        handed unchanged to ``_load_data_from_string_dict``.

    Returns
    -------
    dict
        The mapping built by ``_load_data_from_string_dict`` from the parsed
        JSON content.

    Raises
    ------
    OSError
        If the file cannot be opened.
    json.JSONDecodeError
        If the file content is not valid JSON.
    """
    # JSON is UTF-8 text; without an explicit encoding open() uses the
    # platform locale, which may mis-decode or reject non-ASCII characters.
    with open(filename, "r", encoding="utf-8") as file_handle:
        string_dict = json.load(file_handle)
    return _load_data_from_string_dict(string_dict)

def _load_data_from_string_dict(string_dict):
    """
    Internal helper to parse graphs from node-link JSON format.
    """
    result_dict = {}
    for key in string_dict:
        data = copy.deepcopy(string_dict[key])
        if 'edges' in data:
            data["links"] = data.pop("edges")
        graph = nx.node_link_graph(data)
        result_dict[key] = graph
    return result_dict

def write_data_to_json_file(graph_dict, filename, **kwargs):
    """
    Write a dictionary of graphs to a JSON file.

    Parameters
    ----------
    graph_dict : dict
        Mapping to serialize. Any value the ``json`` module cannot encode
        natively is converted with ``nx.node_link_data``.
    filename : str or path-like
        Path of the file to create or overwrite.
    **kwargs
        Extra keyword arguments forwarded to ``json.dumps`` (for example
        ``indent``). ``default`` is already supplied by this function and
        must not be passed again.

    Raises
    ------
    TypeError
        If ``default`` is passed in ``kwargs`` (duplicate keyword argument).
    OSError
        If the file cannot be opened for writing.
    """
    # Serialize before opening the file: if encoding fails, the exception is
    # raised while any existing file at `filename` is still untouched instead
    # of having already been truncated by open(..., "w").
    json_string = json.dumps(graph_dict, default=nx.node_link_data, **kwargs)

    # Explicit UTF-8 so the output does not depend on the platform locale
    # (relevant when the caller passes ensure_ascii=False).
    with open(filename, "w", encoding="utf-8") as file_handle:
        file_handle.write(json_string)

In [7]:
def _orbital_duplicates(mol, graph):
    """
    Build the per-target-node, per-orbital copies of a single graph.

    Parameters
    ----------
    mol : hashable
        Name of the graph; used as the prefix of every generated name.
    graph : networkx graph
        Graph whose nodes each carry an ``'orbitals'`` sequence attribute.

    Returns
    -------
    list of (str, graph)
        ``(name, graph_copy)`` pairs in generation order, where ``name`` is
        ``f"{mol}_{node}_{orbital_index}"``.

    Raises
    ------
    KeyError
        If a node has no ``'orbitals'`` attribute.
    IndexError
        If a node's ``'orbitals'`` sequence is empty.
    """
    duplicates = []

    # Target nodes are the ones whose first orbital entry is not -1.
    target_node_indices = [
        n for n, v in graph.nodes(data=True) if v['orbitals'][0] != -1
    ]

    for n_i in target_node_indices:
        graph_i = graph.copy()

        # Connect the target node to every node it is not yet adjacent to.
        # The non-neighbors are snapshotted into a list first so that adding
        # edges never happens while the graph is still being iterated.
        for nb_i in list(nx.non_neighbors(graph_i, n_i)):
            graph_i.add_edge(n_i, nb_i, bond_type='NONE')

        # One duplicate per orbital entry of the target node.
        for j, orbital in enumerate(graph_i.nodes[n_i]['orbitals']):
            # Entries equal to -1 produce no duplicate; check before copying
            # so no graph copy is made just to be thrown away.
            if orbital == -1:
                continue

            graph_ij = graph_i.copy()

            # Keep only the selected orbital on the target node.
            graph_ij.nodes[n_i]['orbitals'] = [orbital]

            # Flag the target node True and every other node False.
            for n in graph_ij.nodes:
                graph_ij.nodes[n]['target'] = (n == n_i)

            # j is the index in the original orbital list, not a running
            # counter, so skipped entries leave gaps in the numbering.
            duplicates.append((f"{mol}_{n_i}_{j}", graph_ij))

    return duplicates


def create_duplicates(input_file, output_file):
    """
    Expand every graph of a JSON file into per-orbital duplicates.

    For each graph loaded from ``input_file`` and each node whose first
    ``'orbitals'`` entry is not -1, the graph is copied, edges with
    ``bond_type='NONE'`` are added from that node to all of its
    non-neighbors, and one further copy is produced per orbital entry that is
    not -1. In each such copy the node keeps only that single orbital and is
    marked ``target=True`` while all other nodes get ``target=False``.

    Parameters
    ----------
    input_file : str or path-like
        JSON file read with ``load_data_from_file``.
    output_file : str or path-like
        JSON file written with ``write_data_to_json_file`` using ``indent=2``.
        Keys have the form ``"{graph name}_{target node}_{orbital index}"``.

    Returns
    -------
    None
    """
    # Load the input graph data
    train_data = load_data_from_file(input_file)

    # Iterate name/graph pairs directly; insertion order of the output
    # follows graph order, then target-node order, then orbital index.
    training_data_dict = {}
    for mol, graph in train_data.items():
        training_data_dict.update(_orbital_duplicates(mol, graph))

    # Write the duplicates to a new JSON file
    write_data_to_json_file(training_data_dict, output_file, indent=2)

In [8]:
# Read graph_data.json and write the generated duplicates to
# graph_data_duplicates.json.
create_duplicates(
    input_file='graph_data.json',
    output_file='graph_data_duplicates.json',
)