In [None]:
import torch 
from torch.nn import (BatchNorm1d, Embedding, Linear, ModuleList, ReLU, Sequential)
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.loader import DataLoader 
import torch_geometric.transforms as T
from torch_geometric.utils import dense_to_sparse
from torch_geometric.data import Dataset, Data
import h5py
from torch_geometric.nn import attention as Att
import os
from typing import List, Optional, Callable
import numpy as np


# Data

We first need to implement a dataset class that loads the data from the HDF5 files.

In [None]:
class CausalDataset(Dataset):
    """PyTorch Geometric dataset backed by a single HDF5 file of causal sets.

    Dataset-wide metadata (dimension, manifold/boundary codes and names, and
    the number of stored causal sets) is read from the HDF5 file when the
    object is constructed. Individual graphs are built from the file by
    :meth:`_load_graph`.

    Several hooks (``download``, ``process``, ``get`` and the ``num_*``
    helpers) are still placeholders; see the TODO notes on each of them.

    Args:
        input: Path to the HDF5 file holding the raw data.
        output: Root directory handed to the PyG ``Dataset`` base class.
        transform: Optional callable forwarded to the base class.
        pre_transform: Optional callable forwarded to the base class.
        pre_filter: Optional callable forwarded to the base class.
    """

    input: str
    # The concrete layout of the metadata entries is not visible from this
    # file; the annotations below state the intended types - TODO confirm
    # against the HDF5 writer.
    dimension: int | None = None
    manifold_codes: list[int] | None = None
    manifold_names: list[str] | None = None
    manifold_boundaries: list | None = None
    boundaries: list[str] | None = None
    boundary_codes: list[int] | None = None
    _num_samples: int | None = None

    def __init__(
        self,
        input: str,
        output: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
    ):
        # Everything the base-class constructor may query (``raw_paths``,
        # ``processed_file_names`` -> ``len()``) must exist *before*
        # ``super().__init__`` runs, because the PyG base class invokes the
        # download/process hooks from inside its own ``__init__``.
        self.input = input

        with h5py.File(input, "r") as f:
            # Read the values eagerly: an ``h5py.Dataset`` handle becomes
            # unusable as soon as the file is closed at the end of this block.
            self.dimension = self._read_metadata(f, "dimension")
            self.manifold_codes = self._read_metadata(f, "manifold_codes")
            self.manifold_names = self._read_metadata(f, "manifold_names")
            self.manifold_boundaries = self._read_metadata(
                f, "manifold_boundaries"
            )
            self.boundaries = self._read_metadata(f, "boundaries")
            self.boundary_codes = self._read_metadata(f, "boundary_codes")
            # ``len()`` feeds ``range(...)``, so this has to be a plain int.
            # ``.item()`` raises ``ValueError`` if the entry does not hold
            # exactly one element.
            self._num_samples = int(
                np.asarray(f["num_causal_sets"][()]).item()
            )

        super().__init__(output, transform, pre_transform, pre_filter)

    @staticmethod
    def _read_metadata(f: h5py.File, key: str):
        """Read the HDF5 entry ``key`` fully into plain Python objects.

        NumPy arrays become (nested) lists, NumPy scalars become Python
        scalars and ``bytes`` values are decoded as UTF-8 strings.

        Assumes ``f[key]`` is an HDF5 dataset (not a group) - TODO confirm.
        """

        def to_python(value):
            if isinstance(value, np.ndarray):
                value = value.tolist()
            elif isinstance(value, np.generic):
                value = value.item()
            if isinstance(value, bytes):
                return value.decode("utf-8")
            if isinstance(value, list):
                return [to_python(item) for item in value]
            return value

        return to_python(f[key][()])

    @property
    def raw_paths(self):
        """List containing the path of the single raw HDF5 file."""
        # The base class iterates over this value, so it must be a list and
        # not a bare string (which would be walked character by character).
        return [self.input]

    @property
    def output(self):
        """Root directory of the dataset (alias for ``self.root``)."""
        return self.root

    @property
    def raw_file_names(self):
        """Base name of the raw HDF5 file, wrapped in a list."""
        return [os.path.basename(self.input)]

    @property
    def processed_file_names(self):
        """Expected processed file names, one ``data_<i>.pt`` per sample."""
        return [f"data_{i}.pt" for i in range(self.len())]

    def download(self):
        """Placeholder: nothing is downloaded yet."""
        # TODO:
        pass

    def process(self):
        """Placeholder: processing of the raw file is not implemented yet.

        The previous stub iterated over ``self.raw_data``, an attribute that
        is never defined, and therefore raised ``AttributeError`` as soon as
        the base class triggered processing.
        """
        # TODO: build each graph with ``_load_graph`` and persist it under the
        # names announced by ``processed_file_names``.
        pass

    def len(self):
        """Return the number of causal sets reported by the HDF5 file."""
        # TODO:
        return self._num_samples

    def get(self, idx):
        """Placeholder: always returns ``None`` for now."""
        # TODO:
        return None

    # NOTE(review): the PyG ``Dataset`` base class appears to expose some of
    # the names below as properties - confirm that overriding them with plain
    # methods is intended.
    def num_node_features(self):
        """Placeholder: returns 0."""
        # TODO:
        return 0

    def num_edge_features(self):
        """Placeholder: returns 0."""
        # TODO:
        return 0

    def num_graph_features(self):
        """Placeholder: returns 0."""
        # TODO:
        return 0

    # NOTE(review): the name looks like a typo for ``num_classes``; it is kept
    # unchanged so existing callers keep working.
    def num_classe(self):
        """Placeholder: returns 0."""
        # TODO:
        return 0

    def _load_graph(self, f: h5py.File, idx: int, float_dtype: torch.dtype, int_dtype: torch.dtype) -> Data:
        """Build the ``Data`` object for sample ``idx`` from an open HDF5 file.

        Every entry is indexed with the sample index on its last axis.
        Node features are concatenated column-wise in this order: sprinkling
        coordinates, in/out degrees, maximal path lengths (future/past),
        link-based maximal path lengths (future/past) and topological order
        (future/past).

        Args:
            f: Open HDF5 file to read from.
            idx: Index of the sample along the last axis of each entry.
            float_dtype: dtype used for matrices, node features and
                ``relation_dim``.
            int_dtype: dtype used for the integer graph-level attributes.

        Returns:
            A ``Data`` object with ``x``, ``edge_index``, ``edge_attr``, the
            graph-level attributes and the dense relation matrices.
        """

        def matrix(key: str) -> torch.Tensor:
            # 2-D slice of a 3-D entry for this sample.
            return torch.tensor(f[key][:, :, idx], dtype=float_dtype)

        def node_column(key: str) -> torch.Tensor:
            # 1-D per-node quantity turned into a single feature column.
            return torch.tensor(f[key][:, idx], dtype=float_dtype).unsqueeze(1)

        def graph_value(key: str, dtype: torch.dtype) -> torch.Tensor:
            # One value per sample.
            return torch.tensor(f[key][idx], dtype=dtype)

        # Adjacency matrix -> sparse edge list.
        adj_matrix = matrix("adjacency_matrix")
        edge_index, edge_weight = dense_to_sparse(adj_matrix)

        # Order matters: it fixes the column layout of ``x``.
        node_column_keys = (
            "in_degrees",
            "out_degrees",
            "max_path_lengths_future",
            "max_path_lengths_past",
            "max_path_lengths_future_links",
            "max_path_lengths_past_links",
            "topological_order_future",
            "topological_order_past",
        )
        node_features = [matrix("sprinkling")]
        node_features.extend(node_column(key) for key in node_column_keys)
        x = torch.cat(node_features, dim=1)

        return Data(
            x=x,
            edge_index=edge_index,
            # No edges -> no edge attributes at all.
            edge_attr=edge_weight.unsqueeze(1) if edge_weight.numel() > 0 else None,
            # Graph-level attributes
            manifold_id=graph_value("manifold_ids", int_dtype),
            boundary_id=graph_value("boundary_ids", int_dtype),
            relation_dim=graph_value("relation_dim", float_dtype),
            atom_count=graph_value("atom_count", int_dtype),
            num_sources=graph_value("num_sources", int_dtype),
            num_sinks=graph_value("num_sinks", int_dtype),
            # Additional matrices as graph attributes
            adjacency_matrix=adj_matrix,
            link_matrix=matrix("link_matrix"),
            past_relations=matrix("past_relations"),
            future_relations=matrix("future_relations"),
        )