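# Merge heterogeneous graph datasets

Merge the created heterogeneous graph datasets using a binary-tree merge approach: the dataset files are merged pairwise, then the merged files are merged pairwise again, and so on, so that each stage halves the number of files.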

In [2]:
# Torch
import torch
from torch_geometric.data import HeteroData, DataLoader
from torch_geometric.data import Dataset, Data

# Data manipulation
import pandas as pd
import numpy as np

# Other
from math import floor, ceil
from joblib import dump, load
import time
import os

# Pandas display options: show up to 1300 columns and 100 rows when displaying DataFrames
pd.set_option('display.max_columns', 1300)
pd.set_option('display.max_rows', 100)

In [3]:
# Graph related path constants
# Folder holding the saved hetero-graph dataset files (*.pt). The trailing '/' is
# required because the cells below build file paths by plain string concatenation.
PATH_HETERO_GRAPH_DATASET_FOLDER = '../../data/graph-dataset/semester-2/hetero-graph/'

# Graph dataset class
class InfernoDataset(Dataset):
    """torch_geometric ``Dataset`` that keeps all of its items in an in-memory list.

    The items are stored in ``self.data_list`` and served by position.
    """

    def __init__(self, data_list):
        """Initialise the base ``Dataset`` and keep a reference to ``data_list``."""
        super().__init__()
        self.data_list = data_list

    def len(self):
        """Return how many items are stored in ``data_list``."""
        item_count = len(self.data_list)
        return item_count

    def get(self, idx):
        """Return the item stored at position ``idx`` of ``data_list``."""
        item = self.data_list[idx]
        return item

    def concat(self, dataset):
        """Extend this dataset with the items of ``dataset.data_list``.

        The two lists are joined with ``+`` and the result is re-bound to
        ``self.data_list``; neither original list object is modified in place.
        """
        combined = self.data_list + dataset.data_list
        self.data_list = combined

### 1. Merge single datasets in pairs of two

In [10]:
# Stage 1: merge the individual dataset files two at a time and write each merged
# dataset into the 'merged-2/' sub-folder.
# NOTE(review): torch.load unpickles whole InfernoDataset objects; recent PyTorch
# releases may need weights_only=False for this - confirm against the installed version.
input_folder = PATH_HETERO_GRAPH_DATASET_FOLDER
output_folder = PATH_HETERO_GRAPH_DATASET_FOLDER + 'merged-2/'

# torch.save does not create missing directories.
os.makedirs(output_folder, exist_ok=True)

# os.listdir returns entries in arbitrary order; sort so that the same neighbouring
# files are paired on every run and on every machine.
file_list = sorted(f for f in os.listdir(input_folder) if f.endswith('.pt'))

for i in range(0, len(file_list), 2):
    pair = file_list[i:i + 2]
    # Output names are built from the last four characters of the file stem
    # (e.g. 'inferno_graph_data_00-1000.pt' -> '1000').
    first_tag = pair[0].split('.')[0][-4:]

    if len(pair) == 1:
        # Odd number of files: the last one has no partner, so carry it over unchanged.
        print('No partner for', pair[0], 'copying the file.')
        dataset = torch.load(input_folder + pair[0])
        torch.save(dataset, output_folder + first_tag + '.pt')
        continue

    try:
        dataset = InfernoDataset(
            torch.load(input_folder + pair[0]).data_list +
            torch.load(input_folder + pair[1]).data_list
        )
        file_name = first_tag + '-' + pair[1].split('.')[0][-4:]
        torch.save(dataset, output_folder + file_name + '.pt')
        print('Merged', pair[0], pair[1])
    except Exception as error:
        # Best-effort fallback kept from the original cell: keep the first file of the
        # pair. The error is printed so that the dropped second file is not lost silently.
        print('Error in merging', pair[0], pair[1], 'copying the first file.', repr(error))
        dataset = torch.load(input_folder + pair[0])
        torch.save(dataset, output_folder + first_tag + '.pt')

Merged inferno_graph_data_00-1000.pt inferno_graph_data_01-1001.pt
Merged inferno_graph_data_01-1002.pt inferno_graph_data_01-1003.pt
Merged inferno_graph_data_01-1004.pt inferno_graph_data_01-1005.pt
Merged inferno_graph_data_02-1006.pt inferno_graph_data_02-1007.pt
Merged inferno_graph_data_02-1008.pt inferno_graph_data_02-1009.pt
Merged inferno_graph_data_02-1010.pt inferno_graph_data_03-1011.pt
Merged inferno_graph_data_03-1012.pt inferno_graph_data_03-1013.pt
Merged inferno_graph_data_03-1014.pt inferno_graph_data_04-1015.pt
Merged inferno_graph_data_04-1016.pt inferno_graph_data_04-1017.pt
Merged inferno_graph_data_04-1018.pt inferno_graph_data_04-1019.pt
Merged inferno_graph_data_05-1020.pt inferno_graph_data_05-1021.pt
Merged inferno_graph_data_05-1022.pt inferno_graph_data_05-1023.pt
Merged inferno_graph_data_05-1024.pt inferno_graph_data_06-1025.pt
Merged inferno_graph_data_06-1026.pt inferno_graph_data_06-1027.pt
Merged inferno_graph_data_06-1028.pt inferno_graph_data_06-102

### 2. Merge double datasets into groups of four

In [None]:
# Stage 2: merge the files in 'merged-2/' two at a time and write each merged
# dataset into the 'merged-4/' sub-folder.
# NOTE(review): torch.load unpickles whole InfernoDataset objects; recent PyTorch
# releases may need weights_only=False for this - confirm against the installed version.
input_folder = PATH_HETERO_GRAPH_DATASET_FOLDER + 'merged-2/'
output_folder = PATH_HETERO_GRAPH_DATASET_FOLDER + 'merged-4/'

# torch.save does not create missing directories.
os.makedirs(output_folder, exist_ok=True)

# os.listdir returns entries in arbitrary order; sort so that the same neighbouring
# files are paired on every run and on every machine.
file_list = sorted(f for f in os.listdir(input_folder) if f.endswith('.pt'))

for i in range(0, len(file_list), 2):
    pair = file_list[i:i + 2]
    first_stem = pair[0].split('.')[0]

    if len(pair) == 1:
        # Odd number of files: the last one has no partner, so carry it over unchanged.
        print('No partner for', pair[0], 'copying the file.')
        dataset = torch.load(input_folder + pair[0])
        torch.save(dataset, output_folder + first_stem + '.pt')
        continue

    try:
        # Load from the folder that was listed (the files live in 'merged-2/', not in
        # the dataset root folder).
        dataset = InfernoDataset(
            torch.load(input_folder + pair[0]).data_list +
            torch.load(input_folder + pair[1]).data_list
        )
        file_name = first_stem + '-' + pair[1].split('.')[0]
        torch.save(dataset, output_folder + file_name + '.pt')
        print('Merged', pair[0], pair[1])
    except Exception as error:
        # Best-effort fallback kept from the original cell: keep the first file of the
        # pair. The error is printed so that the dropped second file is not lost silently.
        print('Error in merging', pair[0], pair[1], 'copying the first file.', repr(error))
        dataset = torch.load(input_folder + pair[0])
        torch.save(dataset, output_folder + first_stem + '.pt')

### 3. Merge quadruple datasets into groups of eight

In [None]:
# Stage 3: merge the files in 'merged-4/' two at a time and write each merged
# dataset into the 'merged-8/' sub-folder.
# NOTE(review): torch.load unpickles whole InfernoDataset objects; recent PyTorch
# releases may need weights_only=False for this - confirm against the installed version.
input_folder = PATH_HETERO_GRAPH_DATASET_FOLDER + 'merged-4/'
output_folder = PATH_HETERO_GRAPH_DATASET_FOLDER + 'merged-8/'

# torch.save does not create missing directories.
os.makedirs(output_folder, exist_ok=True)

# os.listdir returns entries in arbitrary order; sort so that the same neighbouring
# files are paired on every run and on every machine.
file_list = sorted(f for f in os.listdir(input_folder) if f.endswith('.pt'))

for i in range(0, len(file_list), 2):
    pair = file_list[i:i + 2]
    first_stem = pair[0].split('.')[0]

    if len(pair) == 1:
        # Odd number of files: the last one has no partner, so carry it over unchanged.
        print('No partner for', pair[0], 'copying the file.')
        dataset = torch.load(input_folder + pair[0])
        torch.save(dataset, output_folder + first_stem + '.pt')
        continue

    try:
        # Load from the folder that was listed (the files live in 'merged-4/', not in
        # the dataset root folder).
        dataset = InfernoDataset(
            torch.load(input_folder + pair[0]).data_list +
            torch.load(input_folder + pair[1]).data_list
        )
        file_name = first_stem + '-' + pair[1].split('.')[0]
        torch.save(dataset, output_folder + file_name + '.pt')
        print('Merged', pair[0], pair[1])
    except Exception as error:
        # Best-effort fallback kept from the original cell: keep the first file of the
        # pair. The error is printed so that the dropped second file is not lost silently.
        print('Error in merging', pair[0], pair[1], 'copying the first file.', repr(error))
        dataset = torch.load(input_folder + pair[0])
        torch.save(dataset, output_folder + first_stem + '.pt')