In [None]:
import os
from collections.abc import Callable
from functools import lru_cache

import h5py
import torch
import torch_geometric.transforms as tgt
from torch_geometric.data import Dataset, Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import dense_to_sparse
from torch_geometric.utils import scatter


# Data

We first need to implement a data loader that reads the data from the HDF5 files. Much of what I have done in Julia until now (e.g., computing graph Laplacians and similar quantities) can be done with PyTorch Geometric instead.

In [None]:
# Per-node datasets that make up the columns of `Data.x`, in column order.
# NOTE: the "sprinkling" coordinates are deliberately not loaded at the moment
# (neither as a node feature nor as `pos`).
_NODE_FEATURE_KEYS = (
    "in_degrees",
    "out_degrees",
    "max_path_lengths_future",
    "max_path_lengths_past",
    "max_path_lengths_future_links",
    "max_path_lengths_past_links",
    "topological_order_future",  # TODO: check again what this encodes
    "topological_order_past",
)


def _load_node_column(f, key: str, idx: int, dtype: torch.dtype) -> torch.Tensor:
    """Read ``f[key][:, idx]`` and append a trailing axis so it can be used as a feature column."""
    return torch.tensor(f[key][:, idx], dtype=dtype).unsqueeze(1)


def _load_dense_matrix(f, key: str, idx: int, dtype: torch.dtype) -> torch.Tensor:
    """Read ``f[key][:, :, idx]`` as a dense tensor."""
    return torch.tensor(f[key][:, :, idx], dtype=dtype)


def _load_graph_value(f, key: str, idx: int, dtype: torch.dtype) -> torch.Tensor:
    """Read the graph-level entry ``f[key][idx]`` as a tensor."""
    return torch.tensor(f[key][idx], dtype=dtype)


def _first_entry(value: torch.Tensor) -> torch.Tensor:
    """Return the first element of ``value`` as a 0-dim tensor.

    Works for 0-dim tensors too (which plain ``value[0]`` rejects with an
    ``IndexError``) and is identical to ``value[0]`` for 1-D tensors.
    """
    return value.reshape(-1)[0]


def load_graph(
    f: h5py.File, idx: int, float_dtype: torch.dtype, int_dtype: torch.dtype, validate:bool = False
) -> Data:
    """Build a PyTorch Geometric ``Data`` object for sample ``idx`` of an HDF5 file.

    Args:
        f: Open HDF5 file (or any mapping of dataset name -> array-like). It must
            provide "adjacency_matrix", "link_matrix", "past_relations" and
            "future_relations" (indexed as ``[:, :, idx]``), the per-node datasets
            listed in ``_NODE_FEATURE_KEYS`` (indexed as ``[:, idx]``) and the
            graph-level datasets "manifold_ids", "boundary_ids", "relation_dim",
            "atom_count", "num_sources", "num_sinks" and "dimension"
            (indexed as ``[idx]``).
        idx: Index of the sample to load.
        float_dtype: dtype used for matrices, node features and "relation_dim".
        int_dtype: dtype used for the remaining graph-level values.
        validate: If True, ``Data.validate()`` is called on the result.

    Returns:
        A ``Data`` object with

        * ``x``: the node-feature columns concatenated along dim 1,
        * ``edge_index`` / ``edge_attr``: the non-zero entries of the adjacency
          matrix (``edge_attr`` is ``None`` when the graph has no edges),
        * ``y``: ``[manifold_id, boundary_id, dimension]``,
        * the graph-level values and the four matrices (as sparse tensors) as
          additional attributes.
    """
    # Adjacency matrix -> edge list; the matrix itself is kept in sparse form.
    adjacency = _load_dense_matrix(f, "adjacency_matrix", idx, float_dtype)
    edge_index, edge_weight = dense_to_sparse(adjacency)

    # Node features: one column per dataset in `_NODE_FEATURE_KEYS`.
    x = torch.cat(
        [_load_node_column(f, key, idx, float_dtype) for key in _NODE_FEATURE_KEYS],
        dim=1,
    )

    # Graph-level values (classification / regression targets).
    manifold_id = _load_graph_value(f, "manifold_ids", idx, int_dtype)
    boundary_id = _load_graph_value(f, "boundary_ids", idx, int_dtype)
    relation_dim = _load_graph_value(f, "relation_dim", idx, float_dtype)
    atom_count = _load_graph_value(f, "atom_count", idx, int_dtype)
    num_sources = _load_graph_value(f, "num_sources", idx, int_dtype)
    num_sinks = _load_graph_value(f, "num_sinks", idx, int_dtype)
    dimension = _load_graph_value(f, "dimension", idx, int_dtype)

    # The label entries may be stored as scalars or as 1-element arrays
    # (depends on the HDF5 layout -- TODO confirm); `_first_entry` handles both.
    y = torch.tensor(
        [_first_entry(manifold_id), _first_entry(boundary_id), _first_entry(dimension)]
    )

    data = Data(
        x=x,
        edge_index=edge_index,
        # No edges -> no edge attributes.
        # TODO: revisit whether richer edge attributes should be added here.
        edge_attr=edge_weight.unsqueeze(1) if edge_weight.numel() > 0 else None,
        y=y,
        # Graph-level attributes
        manifold_id=manifold_id,
        boundary_id=boundary_id,
        relation_dim=relation_dim,
        dimension=dimension,
        atom_count=atom_count,
        num_sources=num_sources,
        num_sinks=num_sinks,
        # Additional matrices as graph attributes.
        # TODO: turn the past/future relations into node attributes instead.
        adjacency_matrix=adjacency.to_sparse(),
        link_matrix=_load_dense_matrix(f, "link_matrix", idx, float_dtype).to_sparse(),
        past_relations=_load_dense_matrix(f, "past_relations", idx, float_dtype).to_sparse(),
        future_relations=_load_dense_matrix(f, "future_relations", idx, float_dtype).to_sparse(),
    )

    if validate:
        data.validate()

    return data

In [None]:
class CsDataset(Dataset):
    """On-disk PyTorch Geometric dataset built from one or more HDF5 files.

    Every sample of every input file is converted with ``load_graph`` and stored
    as ``<output>/processed/data_<i>.pt`` with a contiguous running index ``i``
    (samples rejected by ``pre_filter`` do not consume an index).

    Args:
        input: Path, or list of paths, of the HDF5 input files.
        output: Root directory of the dataset; processed files are written to
            its ``processed`` sub-directory.
        transform: Optional callable applied to every sample on access. It is
            applied by ``Dataset.__getitem__``, not by ``get``.
        pre_transform: Optional callable applied once before a sample is saved.
        pre_filter: Optional predicate; samples for which it returns a falsy
            value are not saved.
        validate_data: Forwarded to ``load_graph`` as ``validate``.

    Raises:
        ValueError: If ``input`` is empty.
        FileNotFoundError: If an input file does not exist.
    """

    def __init__(
        self,
        input: list[str] | str,
        output: str,
        transform: Callable | None = None,
        pre_transform: Callable | None = None,
        pre_filter: Callable | None = None,
        validate_data: bool = False,
    ):
        # Accept a single path as well as a list of paths; a bare string must not
        # be iterated character by character in `process`.
        if isinstance(input, (str, os.PathLike)):
            input = [input]
        self.input = [os.fspath(path) for path in input]
        if not self.input:
            raise ValueError("CsDataset needs at least one input file.")

        # Everything `process()` reads has to exist *before* calling the base
        # constructor, because `Dataset.__init__` may run `process()` right away.
        self.validate_data = validate_data
        self._num_samples = 0

        # Dataset-wide metadata is copied into memory: h5py handles become
        # unusable as soon as the file is closed.
        # NOTE(review): with several input files the metadata is taken from the
        # first one only -- confirm that all files share it.
        first_file = self.input[0]
        if not os.path.exists(first_file):
            raise FileNotFoundError(f"Input file {first_file} does not exist.")
        with h5py.File(first_file, "r") as f:
            self.dimension = self._read_h5(f["dimension"])
            self.manifold_codes = self._read_h5(f["manifold_codes"])
            self.manifold_names = self._read_h5(f["manifold_names"])
            self.boundaries = self._read_h5(f["boundaries"])
            self.boundary_codes = self._read_h5(f["boundary_codes"])

        super().__init__(output, transform, pre_transform, pre_filter)

        # The dataset size is the number of processed samples on disk, which can
        # be smaller than the number of raw samples when `pre_filter` is used.
        self._num_samples = len(self.processed_file_names)

    @staticmethod
    def _read_h5(node):
        """Copy an HDF5 dataset (or, recursively, a group) into memory."""
        if isinstance(node, h5py.Dataset):
            return node[()]
        return {name: CsDataset._read_h5(child) for name, child in node.items()}

    @staticmethod
    def _count_samples(f) -> int:
        """Number of samples stored in the open HDF5 file ``f``.

        Assumes "num_causal_sets" holds exactly one value (scalar or
        one-element array) -- TODO confirm against the writer.
        """
        return int(f["num_causal_sets"][()].item())

    @property
    def raw_paths(self):
        """The input files are used in place; nothing is copied into ``raw_dir``."""
        return self.input

    @property
    def output(self):
        """Root directory of the dataset (alias for ``root``)."""
        return self.root

    @property
    def raw_file_names(self):
        """Base names of the input files."""
        return [os.path.basename(path) for path in self.input]

    @property
    def processed_file_names(self):
        """Names of the ``data_*.pt`` files currently present in ``processed_dir``.

        Returns an empty list when nothing has been processed yet, which makes
        the base class trigger ``process()``.
        """
        if not os.path.isdir(self.processed_dir):
            return []
        return [
            name
            for name in os.listdir(self.processed_dir)
            if name.startswith("data_") and name.endswith(".pt")
        ]

    def process(self):
        """Convert every sample of every input file into a ``data_<i>.pt`` file."""
        os.makedirs(self.processed_dir, exist_ok=True)

        next_idx = 0  # running index over all files, only advanced for saved samples
        for file in self.raw_paths:
            if not os.path.exists(file):
                raise FileNotFoundError(f"Input file {file} does not exist.")
            with h5py.File(file, "r") as f:
                for sample_idx in range(self._count_samples(f)):
                    data = load_graph(
                        f,
                        sample_idx,
                        float_dtype=torch.float32,
                        int_dtype=torch.int64,
                        validate=self.validate_data,
                    )

                    if self.pre_filter is not None and not self.pre_filter(data):
                        continue
                    if self.pre_transform is not None:
                        data = self.pre_transform(data)

                    torch.save(data, os.path.join(self.processed_dir, f"data_{next_idx}.pt"))
                    next_idx += 1

        self._num_samples = next_idx
        # Files on disk have changed; drop anything cached from a previous run.
        type(self)._load_processed.cache_clear()

    def len(self):
        """Number of processed samples."""
        return self._num_samples

    @lru_cache(maxsize=100)  # cache some results for faster access
    def _load_processed(self, idx):
        """Load processed sample ``idx`` from disk (memoised).

        The cache lives on the class and keeps a reference to the dataset
        instance for as long as an entry is cached.
        """
        # NOTE(review): `torch.load` unpickles arbitrary objects; only load files
        # this class wrote itself. Newer PyTorch releases default to
        # `weights_only=True`, which may reject pickled `Data` objects -- if so,
        # pass `weights_only=False` here (check the installed PyTorch version).
        return torch.load(os.path.join(self.processed_dir, f"data_{idx}.pt"))

    def get(self, idx):
        """Return processed sample ``idx``.

        ``self.transform`` is intentionally *not* applied here:
        ``Dataset.__getitem__`` already applies it to the result of ``get``, so
        applying it here as well would transform every sample twice. A clone of
        the cached object is returned so that in-place changes made by
        transforms or callers cannot leak into the cache.
        """
        return self._load_processed(idx).clone()
        
