# Extracting PyTorch Geometric Datasets from IAM Graph Dataset Repository's Graph Data Formats (in .txt forms)
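###### Run the cell below to generate the PyTorch Geometric datasets used for training
###### Each raw dataset from the IAM Graph Datasets Repository must be in its own folder, ./Dataset/Raw/<DatasetName>/, and that folder must contain a <DatasetName>_pos_exists.txt file
###### The functions in this cell read the raw datasets, process them into PyTorch Geometric datasets and pickle them to ./Dataset/Raw/Generated/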

In [11]:
import os
import pickle
import statistics
import torch
from torch_geometric.data import Data

"""
A graph is used to model pairwise relations (edges) between objects (nodes). A single graph in PyG is described by an instance of torch_geometric.data.Data, which holds the following attributes by default:
Ref: https://pytorch-geometric.readthedocs.io/en/latest/get_started/introduction.html
data.x: Node feature matrix with shape [num_nodes, num_node_features]
data.edge_index: Graph connectivity in COO format with shape [2, num_edges] and type torch.long
data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
data.y: Target to train against (may have arbitrary shape), e.g., node-level targets of shape [num_nodes, *] or graph-level targets of shape [1, *]
data.pos: Node position matrix with shape [num_nodes, num_dimensions]
"""

""" Extracts Graph Properties from .txt files and creates a Raw Dict Dataset structure"""
def extract_geometric_dataset(dataset_name: str, rel_path: str, pde: bool = False):
    print("Dataset Name: ", dataset_name)
    all_data_files = [ str(("_".join(data.split("_")[1:])).split(".")[0]) for data in os.listdir(rel_path)]
    print("Dataset Attributes: ")
    for x in all_data_files:
        print(f"\t{x}")

    edges_path = os.path.join(rel_path, f'{dataset_name}_A.txt') if 'A' in all_data_files else None
    edge_attrib_path = os.path.join(rel_path, f'{dataset_name}_edge_attributes.txt') if 'edge_attributes' in all_data_files else None 
    node_attrib_path = os.path.join(rel_path, f'{dataset_name}_node_attributes.txt') if 'node_attributes' in all_data_files else None 
    node_labels_path = os.path.join(rel_path, f'{dataset_name}_node_labels.txt') if 'node_labels' in all_data_files else None
    graph_indicator_path = os.path.join(rel_path, f'{dataset_name}_graph_indicator.txt') if 'graph_indicator' in all_data_files else None
    graph_classes_path = os.path.join(rel_path, f'{dataset_name}_graph_labels.txt') if 'graph_labels' in all_data_files else None
    node_pos_descr_exists_path = os.path.join(rel_path, f'{dataset_name}_pos_exists.txt') if 'pos_exists' in all_data_files else None

    # Read to determine whether positional descriptors exist
    assert not isinstance(node_pos_descr_exists_path, type(None)), "Every data to be processed must define *_pos_exists.txt to define whether positional descriptors exist"
    node_pos_descr_exists = False
    with open(node_pos_descr_exists_path, "r") as data_file:
        node_pos_descr_exists = int(str(data_file.readlines()[0]).split(",")[-1].strip()) == 1
        print(f"Node positional descr exists: {node_pos_descr_exists}")

    # Loop and read in data files
    # Graph Indicator
    assert not isinstance(graph_indicator_path, type(None)), "Graph Indicator must not be None"
    graph_indicator_dict = { "config": None, "indicators": None, "num_graphs": -1 }
    with open(graph_indicator_path, "r") as data_file:
        graph_indicator_config = {} # e.g. 1: {start:<int>, end:<int>}
        indicator_ordered_list = [] # used to run through the other lists for read
        for i, graph_indicator in enumerate(data_file.readlines()):
            if int(graph_indicator) not in graph_indicator_config:
                # create obj for this
                indicator_obj = {"start": i, "end": i }
                graph_indicator_config[int(graph_indicator)] = indicator_obj
                indicator_ordered_list.append(int(graph_indicator))
            else:
                graph_indicator_config[int(graph_indicator)]["end"] += 1
        # create graph indic. struct
        graph_indicator_dict["config"]     = graph_indicator_config
        graph_indicator_dict["indicators"] = indicator_ordered_list
        graph_indicator_dict["num_graphs"] = len(indicator_ordered_list)

    # Read the graph labels and assign labels to graphs
    assert not isinstance(graph_classes_path, type(None)), "Graph Labels must not be None"
    with open(graph_classes_path, "r") as data_file:
        all_graph_labels = [ int(graph_label_str) for graph_label_str in data_file.readlines() ]
        min_graph_label = min(all_graph_labels)
        # re-normalized to zero-indexed
        renormed_labels = [ int(label - min_graph_label) for label in all_graph_labels ]
        # Labels remapped -> contiguous array
        remapped_labels = { label: idx for idx, label in enumerate(set(renormed_labels)) }
        # save graph labels in graph configs
        for graph_indicator, graph_label in zip(graph_indicator_dict["indicators"], renormed_labels):
            graph_indicator_dict["config"][graph_indicator]['label'] = int(remapped_labels[graph_label])

    # Use config to extract out the node attrib details
    """
    Node Attributes: E.g. below
        - AIDS: 
            Node labels:		[symbol]
            Node attributes:	[chem, charge, x, y]
          Fingerprints:
            Node attributes:	[x, y]
    Used attributes:
        - Zip node labels together with node attributes -> New node attributes
        - Remove last 2 node attributes and save that as node_pos
        - Remaining goes into node attributes list  
    """
    assert not (isinstance(node_attrib_path, type(None)) and isinstance(node_labels_path, type(None))), "Node Attributes must not be None"
    all_node_attribs = [] # holds all node attributes
    # Extract node attributes
    if node_attrib_path is not None:
        with open(node_attrib_path, "r") as data_file:
            all_node_attribs = [ [ float(x.strip()) for x in line.split(",")] for line in data_file.readlines() ]
    # Extract node labels - as node attributes
    if node_labels_path is not None:
        with open(node_labels_path, "r") as node_labels_data_file:
            all_node_labels = [ [ float(x.strip()) for x in line.split(",")] for line in node_labels_data_file.readlines() ]
            # extend the primary list with node labels
            if len(all_node_attribs) > 0:
                assert len(all_node_attribs) == len(all_node_labels), "Total num. of node features and node labels must match"
                all_node_attribs = [ [ *all_node_labels[i], *orig_node_attribs_lst ] for i, orig_node_attribs_lst in enumerate(all_node_attribs) ]
            else:
                # Treat Node labels as Node features
                all_node_attribs = all_node_labels

    # Read in edge attributes
    assert not isinstance(edges_path, type(None)), "Edge Attributes must not be None"
    with open(edges_path, "r") as edges_data_file:
        edge_connectivities = [ [int(x.strip()) for x in line.split(",")] for line in edges_data_file.readlines() ]
        edge_connectivity_adjacency_list_map = {}
        for in_node, out_node in edge_connectivities:
            if in_node not in edge_connectivity_adjacency_list_map:
                edge_connectivity_adjacency_list_map[in_node] = set( [out_node] )
            else:
                edge_connectivity_adjacency_list_map[in_node].add(out_node)

        # Add self loops
        for node_label in range(1, len(all_node_attribs)+1):
            if node_label in edge_connectivity_adjacency_list_map:
                edge_connectivity_adjacency_list_map[node_label].add(node_label)
            else:
                edge_connectivity_adjacency_list_map[node_label] = set( [ node_label ] )

        # # assertions
        assert len(edge_connectivity_adjacency_list_map.keys()) == len(all_node_attribs), "Nodes in edge connectivities must match nodes in attribute descriptors file"

        # Extract the node degree as an attribute on the nodes themselves
        for i in range(len(all_node_attribs)):
            node_label = i + 1
            node_degree = len(edge_connectivity_adjacency_list_map[node_label])
            all_node_attribs[i].insert(0, float(node_degree))

        # Build the edge features if they exist
        edge_feats_map = None if edge_attrib_path is None else {}
        if edge_attrib_path is not None:
            with open(edge_attrib_path, "r") as edges_attr_data_file:
                edge_attrs = [ [float(x.strip()) for x in line.split(",")] for line in edges_attr_data_file.readlines() ]
                assert len(edge_attrs) == len(edge_connectivities), "Edge Attributes length must match num edges in Edge Index (Edges Path) data type"
                # Create a mapping from each edge to each edge feature
                num_edge_feats = -1
                for idx, (in_node, out_node) in enumerate(edge_connectivities):
                    if num_edge_feats == -1: num_edge_feats = len(edge_attrs[idx])
                    assert num_edge_feats == len(edge_attrs[idx]), "Edge Attributes must be same dimension"
                    edge_feats_map[(in_node, out_node)] = edge_attrs[idx]
                # Handle case for self-loops [feats = [0] if self-loop doesn't exist]
                for i in range(len(all_node_attribs)):
                    node_label = i + 1
                    self_loop_edge = (node_label, node_label)
                    if self_loop_edge not in edge_feats_map:
                        edge_feats_map[self_loop_edge] = num_edge_feats * [0.0] # Null all features

    # Assert that node attributes have length >= 2
    # NOTE: idx=0 -> Node Degree
    #       idx=1 -> Node Label
    #       ...
    assert all([ len(node_attr) >= 2 for node_attr in all_node_attribs ]), "All nodes must have at least 2 features, idx=0 (Node Degree) and idx=1 (Node label)"

    # Separate node position indicators from other attributes
    # node_position_descr = [ node_attrib[-2:] if node_pos_descr_exists else node_attrib[:2] for node_attrib in all_node_attribs ]
    extract_pos_descr = lambda node_attr: [ statistics.mean(node_attr[:len(node_attr)//2]), statistics.mean(node_attr[len(node_attr)//2:]) ]
    node_position_descr = [ node_attrib[-2:] if node_pos_descr_exists else extract_pos_descr(node_attrib) for node_attrib in all_node_attribs ]
    node_attrib_descr   = [ node_attrib[:-2] for node_attrib in all_node_attribs ] if (node_pos_descr_exists and (not pde)) else all_node_attribs
    assert len(node_position_descr) == len(node_attrib_descr), "Node attributes and position descriptors must be same"

    # NOTE:
    #   Goal: Build the PyTorch object for a graph dataset
    # Loop through graphs config and assign nodes and edge attributes
    graphs = []
    max_node_degree = 0
    for graph_indicator in graph_indicator_dict["indicators"]:
        indicator_start = graph_indicator_dict["config"][graph_indicator]["start"]
        indicator_end = graph_indicator_dict["config"][graph_indicator]["end"]
        # base_attribs = {  }
        # pde_optional_attrib = { "pos": [] }
        # graph_attribs = { **base_attribs, **(pde_optional_attrib if pde else {}) }
        graph_attribs = { "x": [], "edge_index": [], "pos": [], "y": [] }
        if edge_feats_map is not None: graph_attribs.update({ "edge_attr": [] }) 
        edge_indices_coo_format = [[], []]
        for i in range(indicator_start, indicator_end + 1):
            node_label = i + 1 # ---> nodes file line number
            graph_attribs["x"].append(node_attrib_descr[i])
            graph_attribs["pos"].append(node_position_descr[i])
            # if "x" in graph_attribs: graph_attribs["x"].append(node_attrib_descr[i])
            # if "pos" in graph_attribs: graph_attribs["pos"].append(node_position_descr[i])
            # add the edge indices
            target_nodes = edge_connectivity_adjacency_list_map[node_label]
            max_node_degree = max_node_degree if max_node_degree >= len(target_nodes) else len(target_nodes)
            edge_indices_coo_format[0].extend( len(target_nodes) * [node_label] )
            edge_indices_coo_format[1].extend( list(target_nodes))
            # compile the edge attributes (if they exist)
            if edge_feats_map is not None:
                for target_node in target_nodes:
                    graph_attribs["edge_attr"].append(edge_feats_map[(node_label, target_node)])
        
        # re-normalize edge_indices 
        min_node_id = min(edge_indices_coo_format[0])
        edge_indices_coo_format[0] = [ (node_id - min_node_id) for node_id in edge_indices_coo_format[0] ]
        edge_indices_coo_format[1] = [ (node_id - min_node_id) for node_id in edge_indices_coo_format[1] ]

        # Add graph edge indices and label
        graph_attribs["edge_index"] = edge_indices_coo_format
        graph_attribs["y"] = [ graph_indicator_dict["config"][graph_indicator]['label'] ] # pytorch geometric expects this format
        # Add to graphs pile
        graphs.append(graph_attribs)
    
    # Assertions
    assert len(graphs) == len(graph_indicator_dict["indicators"]), "Num graphs must match Num graph indicators"

    # Print out max node degree
    print(f"{dataset_name}: max_node_degree: ", max_node_degree, "\n")

    # Return graphs
    return graphs

""" Converts Raw Dataset structured in GEO format to PyTorch type and Geometric Data Type """
def convert_to_pytorch_graph_dataset(dataset: list):
    data_attrib_typing = {
        'x': torch.FloatTensor,
        'y': torch.LongTensor,
        'pos': torch.FloatTensor,
        'edge_attr': torch.FloatTensor,
        'edge_index': torch.LongTensor
    }

    # Loop through and init data
    torch_geo_dataset = []
    for data_item in dataset:
        # Init empty Geo Graph Data
        geo_graph_data_elem = Data()
        geo_graph_data_allowed_keys = list(Data.__dict__.keys())
        assert all([ raw_data_key in geo_graph_data_allowed_keys  for raw_data_key in list(data_item.keys()) ]), "Raw Graph Data structure to be parsed into PyTorch Geometric struct contains keys not supported"
        # populate with fields
        for data_attrib_key, data_attrib_val in data_item.items():
            data_attrib_type = data_attrib_typing[data_attrib_key]
            data_attrib_val_geo_torch = torch.tensor(data_attrib_val).type(data_attrib_type)
            if hasattr(geo_graph_data_elem, data_attrib_key):
                setattr(geo_graph_data_elem, data_attrib_key, data_attrib_val_geo_torch)
        # Add to collection
        torch_geo_dataset.append(geo_graph_data_elem)

    # return new dataset
    return torch_geo_dataset

""" Saves the PyTorch Geometric Dataset to Generated/ dir """
def save_graph_datasets(positional_descriptors_enabled: bool = False):
    print(f"Generating Graph Dataset with positional descriptors enabled: { 'ON' if positional_descriptors_enabled else 'OFF'}")
    # Save pickle file objects of the datasets required
    parent_dirpath = os.path.join(".", "Dataset", "Raw")
    gen_folder_name = "Generated"
    for dir in os.listdir(parent_dirpath):
        if dir != gen_folder_name and os.path.isdir(os.path.join(parent_dirpath, dir)):
            rel_path = os.path.join(parent_dirpath, dir)
            dataset = extract_geometric_dataset(dataset_name=dir, rel_path=rel_path, pde=positional_descriptors_enabled)
            geo_graph_dataset = convert_to_pytorch_graph_dataset(dataset)
            # Save a pickle object of this file
            filename = f"{dir.strip().lower()}_graph_data_pde={ 'yes' if positional_descriptors_enabled else 'no' }.pkl"
            if not os.path.exists(os.path.join(parent_dirpath, gen_folder_name)):
                os.mkdir(os.path.join(parent_dirpath, gen_folder_name))
            # Add generated file
            store_file_path = os.path.join(parent_dirpath, gen_folder_name, filename)
            with open(store_file_path, 'wb') as handle:
                pickle.dump(geo_graph_dataset, handle, protocol=pickle.HIGHEST_PROTOCOL)
    print("DONE \n")

""" Important helper function to read generated Torch Dataset """
def read_dataset(dataset_name="AIDS", pde: bool = False):
    # Save pickle file objects of the datasets required
    parent_dirpath = os.path.join(".", "Dataset", "Raw", "Generated")
    filename = f"{dataset_name.strip().lower()}_graph_data_pde={ 'yes' if pde else 'no' }.pkl"
    rel_dataset_path = os.path.join(parent_dirpath, filename)
    with open(rel_dataset_path, "rb") as pickle_file_handle:
        loaded_dataset = pickle.load(pickle_file_handle)
    # inspect dataset
    return loaded_dataset

"""
Dataset Characteristics:
    - Self loops available = YES
    - Node attributes have node degree in addition to primary attributes
    - Edge labels and Edge attributes are unused
    - Node mux var: pde: positional descriptors enabled
        pde == yes:
            We separate out positional descriptors into data.pos
        pde == no:
            Positional descriptors are removed for geo graph data elem
"""
save_graph_datasets(positional_descriptors_enabled = False)
save_graph_datasets(positional_descriptors_enabled = True)
############################ EOF #########################

Generating Graph Dataset with positional descriptors enabled: ON
Dataset Name:  AIDS
Dataset Attributes: 
	A
	edge_labels
	graph_indicator
	graph_labels
	label_readme
	node_attributes
	node_labels
	pos_exists
Node positional descr exists: True
AIDS: max_node_degree:  7 

Dataset Name:  COIL-DEL
Dataset Attributes: 
	A
	edge_labels
	graph_indicator
	graph_labels
	label_readme
	node_attributes
	pos_exists
Node positional descr exists: True
COIL-DEL: max_node_degree:  15 

Dataset Name:  ENZYMES
Dataset Attributes: 
	A
	graph_indicator
	graph_labels
	node_attributes
	node_labels
	pos_exists
	
Node positional descr exists: False
ENZYMES: max_node_degree:  10 

Dataset Name:  Fingerprint
Dataset Attributes: 
	A
	edge_attributes
	graph_indicator
	graph_labels
	label_readme
	node_attributes
	pos_exists
Node positional descr exists: True
Fingerprint: max_node_degree:  4 

Dataset Name:  FRANKENSTEIN
Dataset Attributes: 
	A
	graph_indicator
	graph_labels
	node_attributes
	pos_exists
Node positi

## Creating train/test splits for each dataset and
## packaging them into the struct that experiment.py consumes
###### experiment.py struct format:
###### struct = { "raw": {"x_train_data": [], "y_train_data": [], "x_test_data" : [], "y_test_data" : []}, "geometric": {"qgcn_train_data":  [], "qgcn_test_data" : [], "sgcn_train_data": [], "sgcn_test_data" :  []}}

In [12]:
import random
import math
from copy import copy, deepcopy
# TODO: Create test and train datasets: Example showcased below
    # AIDs: [Binary]
    #               [100, 20] & [1000, 200] & [1600, 400]
    #                50s 10s     500s 100s     800s 200s
    # Fingerprints: [15 Classes]
    #               [150, 30] & [1050, 150] & [1650, 450]
    #                10s  2s     70s  10s      110s  30s
# NOTE: Make sure for small dataset split, num classes are balanced out
# NOTE: Randomize dataset for all split

def read_dataset(dataset_name="AIDS", pde: bool = False):
    """Return the pickled graph dataset for ``dataset_name`` from ./Dataset/Raw/Generated.

    The filename is ``<stripped, lower-cased name>_graph_data_pde=<yes|no>.pkl``.
    NOTE: uses pickle.load, so only point this at files generated by this notebook.
    """
    generated_dir = os.path.join(".", "Dataset", "Raw", "Generated")
    normalized_name = dataset_name.strip().lower()
    pde_flag = "yes" if pde else "no"
    dataset_file = os.path.join(generated_dir, f"{normalized_name}_graph_data_pde={pde_flag}.pkl")
    with open(dataset_file, "rb") as fh:
        dataset = pickle.load(fh)
    return dataset

def order_dataset_into_classes(loaded_dataset: list):
    """Group graphs by their class label ``data.y.item()``.

    Returns a plain dict mapping class id -> list of graphs. Classes appear in order of first
    occurrence, and graphs keep their input order within each class.
    """
    grouped = {}
    for graph in loaded_dataset:
        grouped.setdefault(graph.y.item(), []).append(graph)
    return grouped

def find_dataset_splits(dataset_name: str = "AIDS", pde: bool = False):
    """Compute per-class train/test allocations for every predefined split of ``dataset_name``.

    Looks up the hard-coded ``[train_size, test_size]`` targets for the dataset (keyed by the
    stripped, lower-cased name), loads the generated dataset via ``read_dataset`` and groups it by
    class label. For each target it distributes ``train_size + test_size`` graphs across the
    classes, aiming for the per-class train/test ratio ``train_sf`` / ``test_sf``.

    Args:
        dataset_name: dataset name; must be a key of ``splits`` after ``strip().lower()``
            (otherwise ``KeyError``).
        pde: selects the ``pde=yes`` / ``pde=no`` generated pickle to load.

    Returns:
        ``(classes_dict, dataset_mapping_splits)``: ``classes_dict`` maps class id -> list of
        graphs; ``dataset_mapping_splits`` maps ``"train_<n>_test_<m>"`` to
        ``{"splits": {class_id: [num_train, num_test]}, "config": {"train_size": n, "test_size": m}}``.

    Raises:
        AssertionError: if the allocated total does not equal ``train_size + test_size``.
    """
    # Target [train_size, test_size] pairs per dataset, keyed by the lower-cased dataset name
    splits = {
        "aids": [
            [100, 20], 
            [1000, 200], 
            [1600, 400]
        ],
        "coil-del": [
            [500,  100], 
            [1000, 200], 
            [3200, 700]
        ],
        "enzymes": [
            [480,  120]
        ],
        "fingerprint": [
            [150, 30], 
            [1050, 150], 
            [1650, 450],
            [1650, 495]
        ],
        "frankenstein": [
            [100, 20], 
            [1000, 200], 
            [2000, 500],
            [3400, 900]
        ],
        "letter-low": [
            [150, 30], 
            [1050, 150], 
            [1650, 450],
            [1725, 525]
        ],
        "letter-med": [
            [150, 30], 
            [1050, 150], 
            [1650, 450],
            [1725, 525]
        ],
        "letter-high": [
            [150, 30], 
            [1050, 150], 
            [1650, 450],
            [1725, 525]
        ],
        "mutag": [
            [100, 20],
            [148, 40]
        ],
        "mutagenicity": [
            [100, 20], 
            [1000, 200], 
            [2000, 500],
            [3400, 900]
        ],
        "proteins": [
            [100, 20], 
            [850, 250], 
            [1000, 100]
        ],
        "proteins-full": [
            [100, 20], 
            [850, 250], 
            [1000, 100]
        ],
        "synthie": [
            [100, 20],
            [320, 80]
        ],
    }

    # Reference notes per dataset (class count, max node degree, feature sizes, split targets).
    # NOTE: this is a bare string expression, not a docstring; it has no runtime effect.
    """"
    AIDS: 
        num_classes: 2
        max_node_degree:  7
        node_features: pde_on-6, pde_off-4
        total_dataset: 2000     
        splits:
            - train:100_test:20
            - train:1000_test:200
            - train:1600_test:400
    COIL-DEL: 
        num_classes: 100
        max_node_degree:  15
        node_features: pde_on-3, pde_off-1
        total_dataset: 3900
        splits:
            - train:500_test:100
            - train:1000_test:200
            - train:3200_test:700
    ENZYMES:
        num_classes: 6
        max_node_degree:  10
        node_features: pde_on-20, pde_off-20
        total_dataset: 600
        splits:
            - train:480_test:120
    Fingerprint:
        num_classes: 15
        max_node_degree: 4
        node_features: pde_on-3, pde_off-3
        num_edge_feats: 2
        total_dataset: 2149
        splits:
            - train:150_test:30
            - train:1050_test:150
            - train:1650_test:450
            - train:1650_test:495
    FRANKENSTEIN:
        num_classes: 2
        max_node_degree: 4
        node_features: pde_on-781, pde_off-779
        total_dataset: 4337
        splits:
            - train:100_test:20
            - train:1000_test:200
            - train:2000_test:500
            - train:3400_test:900
    Letter-high: 
        num_classes: 15
        max_node_degree:  6
        node_features: pde_on-3, pde_off-1
        total_dataset: 2250
        splits:
            - train:150_test:30
            - train:1050_test:150
            - train:1650_test:450
            - train:1725_test:525
    Letter-low:
        num_classes: 15
        max_node_degree:  5
        node_features: pde_on-3, pde_off-1
        total_dataset: 2250
        splits:
            - train:150_test:30
            - train:1050_test:150
            - train:1650_test:450
            - train:1725_test:525
    Letter-med: 
        num_classes: 15
        max_node_degree:  5
        node_features: pde_on-3, pde_off-1
        total_dataset: 2250
        splits:
            - train:150_test:30
            - train:1050_test:150
            - train:1650_test:450
            - train:1725_test:525
    MUTAG:
        num_classes: 2
        max_node_degree: 5
        node_features: pde_on-2, pde_off-2
        total_dataset: 188
        splits:
            - train:100_test:20
            - train:148_test:40
    Mutagenicity:
        num_classes: 2
        max_node_degree:  5
        node_features: pde_on-2, pde_off-2
        total_dataset: 4337
        splits:
            - train:100_test:20
            - train:1000_test:200
            - train:2000_test:500
            - train:3400_test:900
    PROTEINS:
        num_classes: 2
        max_node_degree:  26
        node_features: pde_on-3, pde_off-3
        total_dataset: 1113
        splits:
            - train:100_test:20
            - train:850_test:250
            - train:1000_test:100
    PROTEINS-Full:
        num_classes: 2
        max_node_degree:  26
        node_features: pde_on-31, pde_off-31
        total_dataset: 1113
        splits:
            - train:100_test:20
            - train:850_test:250
            - train:1000_test:100
    Synthie:
        num_classes: 4
        max_node_degree:  21
        node_features: pde_on-16, pde_off-16
        total_dataset: 400
        splits:
            - train:100_test:20
            - train:320_test:80
    """

    # Split targets for this dataset (KeyError for datasets not listed above)
    dataset_split = splits[dataset_name.strip().lower()]
    dataset = read_dataset(dataset_name=dataset_name, pde=pde)
    # class id -> list of graphs, and class id -> number of graphs in that class
    classes_dict = order_dataset_into_classes(dataset)
    classes_size = { it: len(val) for it, val in classes_dict.items() }

    print(f"{dataset_name} - Classes splits: ", { k: len(v) for k, v in classes_dict.items() }, "\n")

    # Mapping of how to partition data
    dataset_mapping_splits = {}

    # Split into the train and test size
    for train_size, test_size in dataset_split:
        # Determine if we meet split targets
        # classes_size_cpy tracks how many graphs of each class are still unallocated for THIS split
        classes_size_cpy = classes_size.copy()
        # Per-class share if classes were perfectly balanced; only used to derive the train/test ratio
        per_class_train_size = math.floor(train_size / len(classes_size_cpy.keys()))
        per_class_test_size  = math.ceil(test_size / len(classes_size_cpy.keys()))
        total_per_class_contrib = per_class_test_size + per_class_train_size
        # Fractions of every allocation that go to train (train_sf) and test (test_sf); they sum to 1
        train_sf = (per_class_train_size / total_per_class_contrib)
        test_sf = (per_class_test_size / total_per_class_contrib)
        # class id -> [num_train, num_test] allocated so far
        split_classes_col_dict = {  class_id: [0, 0] for class_id in classes_size_cpy }
        remaining = (train_size + test_size) - sum([ sum(split) for split in split_classes_col_dict.values() ])
        while remaining > 0:
            # For classes that have more data left, add more to train and test
            non_zero_classes = list(filter(lambda x: classes_size_cpy[x] > 0, classes_size_cpy.keys()))
            filtered_classes = { class_id: classes_size_cpy[class_id] for class_id in non_zero_classes }
            # NOTE(review): raises ZeroDivisionError if every class is exhausted while remaining > 0
            size_per_class = remaining / len(non_zero_classes)
            add_train_size = math.floor(size_per_class * train_sf)
            add_test_size  = math.ceil(size_per_class * test_sf)

            # Helper functions for determining whether a class can contribute more training / testing data:
            # the new train (test) count must stay within the class's full size scaled by train_sf (test_sf)
            train_data_can_accommodate = lambda x: math.floor(classes_size[x] * train_sf) >= (split_classes_col_dict[x][0] + add_train_size)
            test_data_can_accommodate  = lambda x: math.ceil(classes_size[x] * test_sf)   >= (split_classes_col_dict[x][1] + add_test_size)
            # redist = len(list(filter(lambda x: train_data_can_accommodate(x) and test_data_can_accommodate(x), filtered_classes.keys()))) != len(filtered_classes)

            # Less than one graph per class is left to allocate: take from the classes with the
            # FEWEST unallocated graphs first (ascending sort) until `remaining` is covered
            if math.floor(size_per_class) == 0:
                sorted_redist_list = sorted(non_zero_classes, key=lambda x: filtered_classes[x])
                filtered_redist_list = []
                rolling_sum = 0
                for class_id in sorted_redist_list:
                    if rolling_sum < remaining:
                        filtered_redist_list.append(class_id)
                        rolling_sum += filtered_classes[class_id]
                    else:
                        break
                # redist from filtered class ids
                for class_id in filtered_redist_list:
                    num_data = filtered_classes[class_id]
                    if num_data < remaining:
                        # Class cannot cover the rest: allocate all its unallocated graphs, split by train_sf / test_sf
                        delta_train_size = math.floor(num_data * train_sf)
                        delta_test_size = math.ceil(num_data * test_sf)
                        split_classes_col_dict[class_id][0] += delta_train_size
                        split_classes_col_dict[class_id][1] += delta_test_size
                        classes_size_cpy[class_id] -= (delta_train_size + delta_test_size)
                        remaining -= (delta_train_size + delta_test_size)
                    else:
                        # Class covers the rest: allocate exactly `remaining` graphs and stop
                        # NOTE(review): unlike the branch above, classes_size_cpy is not decremented here
                        delta_train_size = math.floor(remaining * train_sf)
                        delta_test_size = math.ceil(remaining * test_sf)
                        split_classes_col_dict[class_id][0] += delta_train_size
                        split_classes_col_dict[class_id][1] += delta_test_size
                        remaining -= (delta_train_size + delta_test_size)
                        break
            else:
                # Redistribute data from other classes
                for class_id in filtered_classes:
                    if train_data_can_accommodate(class_id) and test_data_can_accommodate(class_id):
                        # Room for a full equal share
                        split_classes_col_dict[class_id][0] += add_train_size
                        split_classes_col_dict[class_id][1] += add_test_size
                        classes_size_cpy[class_id] -= (add_train_size + add_test_size) 
                    else:
                        # Not enough room: allocate what this class has left, split by train_sf / test_sf
                        delta_train_size = math.floor(filtered_classes[class_id] * train_sf)
                        delta_test_size  = math.ceil(filtered_classes[class_id] * test_sf)
                        split_classes_col_dict[class_id][0] += delta_train_size
                        split_classes_col_dict[class_id][1] += delta_test_size
                        classes_size_cpy[class_id] -= (delta_train_size + delta_test_size) 

            # compute remaining
            # NOTE(review): the ceil() roundings can push the total past the target (remaining < 0),
            # which ends the loop and then fails the assert below
            remaining = (train_size + test_size) - sum([ sum(split) for split in split_classes_col_dict.values() ])

        # assert that distribution meets target split requirements
        assert (train_size + test_size) == sum([ sum(split_lst) for split_lst in split_classes_col_dict.values() ])

        dataset_mapping_splits[f"train_{train_size}_test_{test_size}"] = { "splits": split_classes_col_dict, "config": { "train_size": train_size, "test_size": test_size }}

    # return the mapping for creating the dataset splits
    return classes_dict, dataset_mapping_splits

    # Example split targets (graphs per class shown below each split):
    # AIDs: [Binary]
    #               [100, 20] & [1000, 200] & [1600, 400]
    #                50s 10s     500s 100s     800s 200s
    # Fingerprints: [15 Classes]
    #               [150, 30] & [1050, 150] & [1650, 450]
    #                10s  2s     70s  10s      110s  30s

def create_dataset_splits(dataset_name: str = "AIDS", pde: bool = False, seed=None):
    """Materialize every train/test split of ``dataset_name`` and pickle one struct per split.

    For each split computed by ``find_dataset_splits``, the graphs of every class are shuffled in
    place. The first ``class_train_size`` graphs go to train and the next ``class_test_size`` go to
    test. The resulting struct (format expected by experiment.py) is written to
    ``./Dataset/Raw/Generated/<split>_struct_<lower-cased name>_graph_data_pde=<yes|no>.pkl``.

    Args:
        dataset_name: dataset name, e.g. "AIDS" (case-insensitive, see ``find_dataset_splits``).
        pde: use the positional-descriptors-enabled generated dataset.
        seed: optional seed for a private ``random.Random`` so the splits are reproducible.
            ``None`` (default) keeps using the global ``random`` module.

    Raises:
        IndexError: if a split requests more graphs from a class than that class holds.
    """
    classes_dict, dataset_splits = find_dataset_splits(dataset_name=dataset_name, pde=pde)
    shuffle = random.Random(seed).shuffle if seed is not None else random.shuffle
    parent_dirpath = os.path.join(".", "Dataset", "Raw", "Generated")
    os.makedirs(parent_dirpath, exist_ok=True)
    pde_tag = "yes" if pde else "no"

    for dataset_split, split_info in dataset_splits.items():
        train_data, test_data = [], []
        for class_id, (class_train_size, class_test_size) in split_info["splits"].items():
            class_graphs = classes_dict[class_id]
            shuffle(class_graphs)  # shuffle the class before distributing it
            class_total = class_train_size + class_test_size
            if class_total > len(class_graphs):
                raise IndexError(
                    f"Split {dataset_split} requests {class_total} graphs from class {class_id}, "
                    f"which only has {len(class_graphs)}"
                )
            train_data.extend(class_graphs[:class_train_size])
            test_data.extend(class_graphs[class_train_size:class_total])

        # Struct format consumed by experiment.py; each model gets its own deep copies
        struct = {
            "raw": {
                "x_train_data": [],
                "y_train_data": [],
                "x_test_data" : [],
                "y_test_data" : []
            },
            "geometric": {
                "qgcn_train_data": deepcopy(train_data),
                "qgcn_test_data" : deepcopy(test_data),
                "sgcn_train_data": deepcopy(train_data),
                "sgcn_test_data" : deepcopy(test_data),
            }
        }

        filename = f"{dataset_split}_struct_{dataset_name.strip().lower()}_graph_data_pde={pde_tag}.pkl"
        with open(os.path.join(parent_dirpath, filename), "wb") as pickle_file_handle:
            pickle.dump(struct, pickle_file_handle, protocol=pickle.HIGHEST_PROTOCOL)



# Empty template of the split struct format consumed by experiment.py
# ("raw" lists are unused placeholders; "geometric" holds the graph lists per model).
struct = {
    "raw": {key: [] for key in ("x_train_data", "y_train_data", "x_test_data", "y_test_data")},
    "geometric": {key: [] for key in ("qgcn_train_data", "qgcn_test_data", "sgcn_train_data", "sgcn_test_data")},
}

# dataset_name="AIDS", pde: bool = False -> 2 classes 400, 1600
# {0: 400, 1: 1600}
# dataset_name="Fingerprints", pde: bool = False -> 2 classes 400, 1600
# {0: 369, 1: 136, 2: 496, 3: 75, 4: 396, 5: 368, 6: 134, 7: 4, 8: 105, 9: 18, 10: 20, 11: 20, 12: 2, 13: 4, 14: 2}
# order_dataset_into_classes(read_dataset())

# # dataset with pde=OFF
# create_dataset_splits(dataset_name="AIDS", pde=False)
# create_dataset_splits(dataset_name="COIL-DEL", pde=False)
# create_dataset_splits(dataset_name="ENZYMES", pde=False)
# create_dataset_splits(dataset_name="Fingerprint", pde=False)
# create_dataset_splits(dataset_name="FRANKENSTEIN", pde=False)
# create_dataset_splits(dataset_name="Letter-high", pde=False)
# create_dataset_splits(dataset_name="Letter-low", pde=False)
# create_dataset_splits(dataset_name="Letter-med", pde=False)
# create_dataset_splits(dataset_name="MUTAG", pde=False)
# create_dataset_splits(dataset_name="Mutagenicity", pde=False)
# create_dataset_splits(dataset_name="PROTEINS", pde=False)
# create_dataset_splits(dataset_name="PROTEINS-Full", pde=False)
# create_dataset_splits(dataset_name="Synthie", pde=False)

# dataset with pde=ON
create_dataset_splits(dataset_name="AIDS", pde=True)
create_dataset_splits(dataset_name="COIL-DEL", pde=True)
create_dataset_splits(dataset_name="ENZYMES", pde=True)
create_dataset_splits(dataset_name="Fingerprint", pde=True)
create_dataset_splits(dataset_name="FRANKENSTEIN", pde=True)
create_dataset_splits(dataset_name="Letter-high", pde=True)
create_dataset_splits(dataset_name="Letter-low", pde=True)
create_dataset_splits(dataset_name="Letter-med", pde=True)
create_dataset_splits(dataset_name="MUTAG", pde=True)
create_dataset_splits(dataset_name="Mutagenicity", pde=True)
create_dataset_splits(dataset_name="PROTEINS", pde=True)
create_dataset_splits(dataset_name="PROTEINS-Full", pde=True)
create_dataset_splits(dataset_name="Synthie", pde=True)

AIDS - Classes splits:  {0: 400, 1: 1600} 

COIL-DEL - Classes splits:  {14: 39, 99: 39, 18: 39, 3: 39, 57: 39, 87: 39, 38: 39, 24: 39, 63: 39, 98: 39, 46: 39, 52: 39, 2: 39, 89: 39, 78: 39, 15: 39, 73: 39, 79: 39, 29: 39, 16: 39, 6: 39, 12: 39, 51: 39, 47: 39, 39: 39, 86: 39, 31: 39, 88: 39, 10: 39, 61: 39, 22: 39, 35: 39, 1: 39, 42: 39, 27: 39, 93: 39, 23: 39, 36: 39, 75: 39, 58: 39, 43: 39, 91: 39, 26: 39, 60: 39, 83: 39, 33: 39, 56: 39, 55: 39, 72: 39, 49: 39, 13: 39, 84: 39, 5: 39, 90: 39, 20: 39, 45: 39, 69: 39, 74: 39, 82: 39, 7: 39, 28: 39, 70: 39, 94: 39, 76: 39, 9: 39, 50: 39, 71: 39, 80: 39, 25: 39, 95: 39, 40: 39, 21: 39, 34: 39, 64: 39, 11: 39, 32: 39, 96: 39, 41: 39, 19: 39, 97: 39, 8: 39, 66: 39, 4: 39, 59: 39, 68: 39, 92: 39, 81: 39, 62: 39, 17: 39, 0: 39, 65: 39, 48: 39, 30: 39, 37: 39, 77: 39, 53: 39, 67: 39, 85: 39, 54: 39, 44: 39} 

ENZYMES - Classes splits:  {5: 100, 4: 100, 0: 100, 1: 100, 2: 100, 3: 100} 

Fingerprint - Classes splits:  {0: 369, 1: 136, 2: 496, 3