In [None]:
import torch
from torch_geometric.data import Data
from torch_geometric.utils import to_undirected
from torch_geometric.nn import knn_graph
import h5py

from time import time
from tqdm import tqdm
from pathlib import Path

In [None]:
def gen_random_graph(N: int, k: int = 4)  -> Data:
    pos = torch.rand(N, 3)
    x = torch.rand(N, 1)
    edge_index = knn_graph(pos, k=k, loop=False)
    edge_index = to_undirected(edge_index)
    src, dst = edge_index
    edge_vec = pos[dst] - pos[src]
    edge_attr = edge_vec.norm(dim=1).unsqueeze(1)
    graph = Data(x=x, pos=pos,
                 edge_attr=edge_attr, edge_index=edge_index)
    return graph

def gen_dataset(N: int, n_mean, n_sig) -> list[Data]:
    dataset = [torch.tensor([]) for _ in range(N)]
    for i in tqdm(range(N)):
        n = int(torch.normal(n_mean, n_sig, (1,)).item())
        dataset[i] = gen_random_graph(n)
    return dataset

In [18]:
n_events = 10000
n_mean = 250
n_sig = 50

dsnb_dataset = gen_dataset(n_events, n_mean, n_sig)
atm_dataset = gen_dataset(n_events, 4000, 800)

  0%|          | 0/10000 [00:00<?, ?it/s]

100%|██████████| 10000/10000 [00:09<00:00, 1039.88it/s]
100%|██████████| 10000/10000 [02:20<00:00, 71.04it/s]


In [35]:
EVENTS_PER_FILE = 700

def save_hdf5(dataset: list[Data], path: str) -> None:
    with h5py.File(path, 'w') as f:
        for i, data in enumerate(dataset):
            grp = f.create_group(f'graph_{i}')
            for key, value in data.items():
                grp.create_dataset(
                    key,
                    data=value.numpy(),
                    compression='gzip'
                )

def load_hdf5(path: str) -> list[Data]:
    with h5py.File(path, 'r') as f:
        num_graphs = sum(isinstance(
            f[key], h5py.Group
        ) for key in f.keys())
        dataset = [Data() for _ in range(EVENTS_PER_FILE)]
        for i in range(num_graphs):
            grp = f[f'graph_{i}']
            for key in grp.keys():
                dataset[i][key] = torch.from_numpy(
                    grp[key][()]
                )
    return dataset
    
def save_graphs(dataset: list[Data], root: str):
    num_graphs = len(dataset)
    num_files = num_graphs // EVENTS_PER_FILE
    # num_files += num_graphs % EVENTS_PER_FILE # euh wow
    root.mkdir(parents=True, exist_ok=True)
    root = str(root)
    for file_id in range(num_files):
        print(f'Saving graphs in file: {file_id}')
        to_save = dataset[
            file_id*EVENTS_PER_FILE:
            min(num_graphs, (file_id+1)*EVENTS_PER_FILE)
        ]
        torch.save(to_save, root+f'/graphs_{file_id}.pt')
        save_hdf5(to_save, root+f'/graphs_{file_id}.h5')
    

In [36]:
dsnb_root = Path('/home/amaterasu/work/RootToGraph/playground/profiling_dsnb_vs_atm/dsnb_700/')
atm_root = Path('/home/amaterasu/work/RootToGraph/playground/profiling_atm_vs_atm/atm_700/')
save_graphs(dsnb_dataset, dsnb_root)
save_graphs(atm_dataset, atm_root)

Saving graphs in file: 0
Saving graphs in file: 1
Saving graphs in file: 2
Saving graphs in file: 3
Saving graphs in file: 4
Saving graphs in file: 5
Saving graphs in file: 6
Saving graphs in file: 7
Saving graphs in file: 8
Saving graphs in file: 9
Saving graphs in file: 10
Saving graphs in file: 11
Saving graphs in file: 12
Saving graphs in file: 13
Saving graphs in file: 0
Saving graphs in file: 1
Saving graphs in file: 2
Saving graphs in file: 3
Saving graphs in file: 4
Saving graphs in file: 5
Saving graphs in file: 6
Saving graphs in file: 7
Saving graphs in file: 8
Saving graphs in file: 9
Saving graphs in file: 10
Saving graphs in file: 11
Saving graphs in file: 12
Saving graphs in file: 13


In [37]:
def time_openning(root: str, num_files: int):
    pt_times = []
    h5_times = []
    for i in range(num_files):
        pt_path = root + f'/graphs_{i}.pt'
        h5_path = root + f'/graphs_{i}.h5'
        print(f'Openning file: {i}')
        t0 = time()
        gs = torch.load(pt_path)
        t1 = time()
        gs = load_hdf5(h5_path)
        t2 = time()
        pt_times.append((t1-t0)*1000)  # ms
        h5_times.append((t2-t1)*1000)  # ms
    pt_times = torch.tensor(pt_times)
    h5_times = torch.tensor(h5_times)
    print(f'Torch: {torch.mean(pt_times)} +/- {torch.std(pt_times)} ms')
    print(f'HDF5 : {torch.mean(h5_times)} +/- {torch.std(h5_times)} ms')

In [39]:
dsnb_root = '/home/amaterasu/work/RootToGraph/playground/profiling_dsnb_vs_atm/dsnb_700/'
time_openning(dsnb_root, 13)

Openning file: 0


  gs = torch.load(pt_path)


Openning file: 1
Openning file: 2
Openning file: 3
Openning file: 4
Openning file: 5
Openning file: 6
Openning file: 7
Openning file: 8
Openning file: 9
Openning file: 10
Openning file: 11
Openning file: 12
Torch: 270.6748046875 +/- 63.40267562866211 ms
HDF5 : 915.78125 +/- 24.621042251586914 ms


In [40]:
atm_root = '/home/amaterasu/work/RootToGraph/playground/profiling_atm_vs_atm/atm_700/'
time_openning(atm_root, 13)

Openning file: 0


  gs = torch.load(pt_path)


Openning file: 1
Openning file: 2
Openning file: 3
Openning file: 4
Openning file: 5
Openning file: 6
Openning file: 7
Openning file: 8
Openning file: 9
Openning file: 10
Openning file: 11
Openning file: 12
Torch: 901.8040771484375 +/- 158.5144500732422 ms
HDF5 : 2751.423095703125 +/- 201.0260772705078 ms


In [None]:
x = torch.randn(1000, 3)
edge_index = torch.randn(2, 4000)


In [None]:
%%timeit
for _ in range(1000):
    __ =  Data(x=x, pos=x, edge_index=edge_index)

8.86 ms ± 211 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [7]:
%%timeit
for _ in range(500):
    __ =  Data(x=x, pos=x, edge_index=edge_index)

4.5 ms ± 133 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
