# Data loading

We use PyTorch's data-loading capabilities. Because this is a prototype, data loading currently runs in a single process.

In [None]:
import torch 
torch.multiprocessing.set_start_method('spawn', force=True)  # For multiprocessing support
from torch_geometric.loader import DataLoader 
import torch_geometric.transforms as T
from torch_geometric.utils import dense_to_sparse
from torch_geometric.data import Dataset, Data
from collections.abc import Callable
import torch.nn.functional as F
from tqdm import tqdm
import pandas as pd
import h5py
import os
import numpy as np
import math
import json
import seaborn as sns
import matplotlib

The function `load_graph` constructs a `torch_geometric.data.Data` object from the data in an HDF5 file and selects some of the generated quantities stored there as node features. This is still very basic. We will use this function to build a custom Dataset for our graphs.

In [None]:
def load_graph(
    f: h5py.File,
    idx: int,
    float_dtype: torch.dtype,
    int_dtype: torch.dtype,
    validate: bool = False,
) -> Data:
    """Build a pytorch_geometric ``Data`` object for graph ``idx`` of an hdf5 file.

    Every per-graph dataset in ``f`` is indexed by graph along its first axis, so
    the same ``idx`` selects matching entries from all of them.

    Args:
        f (h5py.File): Open hdf5 file holding the graph data.
        idx (int): Index of the graph to read, shared across all datasets in ``f``.
        float_dtype (torch.dtype): dtype for floating point tensors.
        int_dtype (torch.dtype): dtype for the integer count tensors.
        validate (bool, optional): If True, call ``Data.validate()`` on the result.
            Defaults to False.

    Returns:
        Data: graph ``idx`` with
            - ``x``: (num_nodes, 4) node features [in degree, out degree,
              longest future path, longest past path],
            - ``edge_index`` / ``edge_attr`` derived from the adjacency matrix,
            - ``y = [manifold_id, boundary_id, dimension]`` (raw, unshifted),
            - further graph-level scalars and sparse relation matrices.
    """

    def per_node(name: str) -> torch.Tensor:
        # One value per node for graph idx, reshaped to a (num_nodes, 1) column.
        return torch.tensor(f[name][idx, :], dtype=float_dtype).unsqueeze(1)

    def sparse_matrix(name: str) -> torch.Tensor:
        # Matrix for graph idx, stored as a sparse tensor on the Data object.
        return torch.tensor(f[name][idx, :, :], dtype=float_dtype).to_sparse()

    # The files are Julia-generated (column-major, sample index on the last axis);
    # seen from Python (row-major) the sample index is the first axis.
    dense_adjacency = torch.tensor(f["adjacency_matrix"][idx, :, :], dtype=float_dtype)
    # PyG describes graph structure through edge indices plus edge weights.
    edge_index, edge_weight = dense_to_sparse(dense_adjacency)

    # Node features: degree information followed by path lengths.
    feature_names = (
        "in_degrees",
        "out_degrees",
        "max_path_lengths_future",
        "max_path_lengths_past",
    )
    x = torch.cat([per_node(name) for name in feature_names], dim=1)

    # Graph-level classification targets.
    manifold_id = int(f["manifold_ids"][idx])
    boundary_id = int(f["boundary_ids"][idx])
    # Read without idx: one value for the whole file.
    dimension = int(f["dimension"][()])

    data = Data(
        x=x,
        edge_index=edge_index,
        # Edge weights as a (num_edges, 1) attribute; None for graphs without edges.
        # TODO: add richer edge attributes if possible.
        edge_attr=edge_weight.unsqueeze(1) if edge_weight.numel() > 0 else None,
        # Raw classification target, combined into a single 1D tensor.
        y=torch.tensor([manifold_id, boundary_id, dimension], dtype=torch.long),
        # Graph-level attributes that may be useful to the model later.
        manifold_id=manifold_id,
        boundary_id=boundary_id,
        relation_dim=torch.tensor(f["relation_dim"][idx], dtype=float_dtype),
        dimension=dimension,
        atom_count=torch.tensor(f["atom_count"][idx], dtype=int_dtype),
        num_sources=torch.tensor(f["num_sources"][idx], dtype=int_dtype),
        num_sinks=torch.tensor(f["num_sinks"][idx], dtype=int_dtype),
        # Matrices kept as sparse graph attributes.
        adjacency_matrix=dense_adjacency.to_sparse(),
        link_matrix=sparse_matrix("link_matrix"),
        past_relations=sparse_matrix("past_relations"),
        future_relations=sparse_matrix("future_relations"),
    )

    if validate:
        # Checks that the Data object describes a consistent graph.
        data.validate()

    return data

The function `target_shift` will be used as a simple `pre_transform` for the PyTorch dataset. A `pre_transform` is applied once to every graph when the dataset is processed, and its result is what gets stored on disk. We can do much more complex things here in the future if we want.

In [None]:
def target_shift(data, manifold_classes=6, boundary_classes=3, dimension_classes=3) -> Data:
    """Turn the raw targets of ``data`` into 0-based class indices in a new order.

    The incoming ``data.y`` is read as ``[manifold_id, boundary_id, dimension]``,
    which is the order ``load_graph`` writes. It is replaced by the long tensor
    ``[[dimension - 2, boundary_id - 1, manifold_id - 1]]`` of shape (1, 3). The
    untouched original is kept as ``data.y_original``.

    Args:
        data (Data): Graph whose ``y`` holds the raw 1-based ids and the dimension.
        manifold_classes (int, optional): Number of manifold classes. Defaults to 6.
        boundary_classes (int, optional): Number of boundary classes. Defaults to 3.
        dimension_classes (int, optional): Number of dimension classes. Defaults to 3.

    Returns:
        Data: The same object, mutated in place, with shifted ``y`` and a
            ``target_info`` dict describing class counts and offsets.
    """
    raw = data.y
    shifted = [
        raw[2].long() - 2,  # dimension: 2D is the lowest possible, so 2, 3, 4 -> 0, 1, 2
        raw[1].long() - 1,  # boundary id: 1, 2, 3 -> 0, 1, 2
        raw[0].long() - 1,  # manifold id: 1, ..., 6 -> 0, ..., 5
    ]

    data.y_original = raw  # keep the unshifted targets for reference
    data.y = torch.tensor([shifted], dtype=torch.long)

    # Presumably the offsets locate each target's block in a concatenated
    # [dimension | boundary | manifold] class vector, matching the order of y.
    data.target_info = dict(
        manifold_classes=manifold_classes,
        boundary_classes=boundary_classes,
        dimension_classes=dimension_classes,
        dimension_offset=0,
        boundary_offset=dimension_classes,
        manifold_offset=dimension_classes + boundary_classes,
        total_classes=manifold_classes + boundary_classes + dimension_classes,
    )

    return data

### Implement a custom Dataset class 

This implements a custom Dataset as described [here](https://pytorch-geometric.readthedocs.io/en/2.5.3/tutorial/create_dataset.html). We create a dataset that lives on disk and loads examples on demand, instead of an `InMemoryDataset` that keeps everything in RAM. A PyG dataset is defined by implementing the `get` and `len` methods. Here, we also override the `process` method, because it is specific to our data storage format. Because this is a prototype, we process every data example into its own PyTorch `.pt` file, which essentially pickles the `Data` objects created by `load_graph`. This is not optimal performance-wise, but it is very simple: we don't have to handle perpetually open files accessed from multiple places, and we don't have to take care of caching right away. For the current prototype, this is good enough.

In [None]:
class CsDataset(Dataset):
    """Disk-backed PyG dataset of causal-set graphs read from hdf5 files.

    The first time a given ``output`` directory is used, every graph in the ``input``
    files is converted with ``loader``. Graphs are then optionally filtered
    (``pre_filter``) and transformed (``pre_transform``). Each surviving graph is
    pickled to its own ``data_<i>.pt`` file in ``self.processed_dir``, together with
    a ``metadata.json``. Later instantiations that find those files load them and
    do not read the hdf5 input again.
    """

    def __init__(
        self,
        input: list[str],
        output: str,  # = root directory for processed data
        transform: Callable[[Data], Data] | None = None,
        pre_transform: Callable[[Data], Data] | None = None,
        pre_filter: Callable[[Data], bool] | None = None,
        validate_data: bool = False,
        loader: Callable[[h5py.File, int, torch.dtype, torch.dtype, bool], Data] = load_graph,
    ):
        """Create the dataset, processing the hdf5 input if nothing has been processed yet.

        Args:
            input (list[str]): Paths of the hdf5 input files (any number). Only needed
                when ``output`` does not hold processed samples yet.
            output (str): Root directory for processed data (PyG ``root``).
            transform (Callable[[Data], Data] | None, optional): Transform handed to the
                PyG base class, which applies it whenever a sample is accessed.
                Defaults to None.
            pre_transform (Callable[[Data], Data] | None, optional): Applied once to every
                sample before it is written to disk. Defaults to None.
            pre_filter (Callable[[Data], bool] | None, optional): Predicate applied once
                during processing; samples for which it is falsy are not stored.
                Defaults to None.
            validate_data (bool, optional): Forwarded to ``loader`` as ``validate``.
                Defaults to False.
            loader (Callable, optional): Builds one ``Data`` object; called as
                ``loader(f, idx, float_dtype=..., int_dtype=..., validate=...)``.
                Defaults to load_graph.

        Raises:
            ValueError: If no processed data exists and ``input`` is None or empty.
            FileNotFoundError: If an input file does not exist, or processed samples
                exist but their ``metadata.json`` is missing.
        """
        self.input = input
        self._num_samples: int | None = None
        self.validate_data = validate_data
        self.loader = loader
        # processed_dir (used right below) is derived from root, so root must be set
        # before the base-class constructor runs.
        self.root = output

        existing = self.processed_file_names
        if len(existing) > 0:
            # Already processed: restore metadata and count the stored samples.
            self._load_metadata()
            self._num_samples = len(existing)
        else:
            if input is None or len(input) == 0:
                raise ValueError("Input files must be provided for processing.")

            # Check every path first so a missing file (including input[0]) is
            # reported as FileNotFoundError before any hdf5 file is opened.
            for file in input:
                if not os.path.exists(file):
                    raise FileNotFoundError(f"Input file {file} does not exist.")

            # Class metadata is taken from the first file only.
            with h5py.File(input[0], "r") as f:
                self.manifold_codes = f["manifold_codes"][()]
                self.manifold_names = f["manifolds"][()]
                self.boundaries = f["boundaries"][()]
                self.boundary_codes = f["boundary_codes"][()]

            # Upper bound on the sample count; process() replaces it with the number
            # of samples actually stored (pre_filter may drop some).
            self._num_samples = 0
            for file in self.input:
                with h5py.File(file, "r") as f:
                    self._num_samples += int(f["num_causal_sets"][()])
                    print(f"Processing file: {file}, current number of samples: {self._num_samples}")

        # The PyG base class calls process() from here when no processed files exist.
        super().__init__(output, transform, pre_transform, pre_filter)

    def _load_metadata(self):
        """Read manifold/boundary codes and names from ``metadata.json`` in the processed directory.

        Raises:
            FileNotFoundError: If the metadata file does not exist.
        """
        metadata_path = os.path.join(self.processed_dir, "metadata.json")
        if not os.path.exists(metadata_path):
            raise FileNotFoundError(f"Metadata file {metadata_path} does not exist.")

        with open(metadata_path, "r") as f:
            metadata = json.load(f)
            self.manifold_codes = metadata["manifold_codes"]
            self.manifold_names = metadata["manifolds"]
            self.boundaries = metadata["boundaries"]
            self.boundary_codes = metadata["boundary_codes"]

    @staticmethod
    def _as_text(value) -> str:
        """Return ``value`` as a str, decoding bytes as UTF-8.

        h5py returns hdf5 strings as ``bytes``. A plain ``str()`` on them would
        produce ``"b'...'"``.
        """
        if isinstance(value, bytes):
            return value.decode("utf-8")
        return str(value)

    @property
    def raw_paths(self) -> list[str]:
        """Return the input file paths (any number of files is allowed).

        Returns:
            list[str]: A list of raw file paths.
        """
        return self.input

    @property
    def output(self) -> str:
        """Return the output (root) directory for processed files.

        Returns:
            str: The output directory path.
        """
        return self.root

    @property
    def raw_file_names(self) -> list[str]:
        """Return the base names of the input files.

        Returns:
            list[str]: A list of raw file names.
        """
        return [os.path.basename(f) for f in self.input]

    @property
    def processed_file_names(self) -> list[str]:
        """List the ``data_*.pt`` sample files currently present in the processed directory.

        Returns:
            list[str]: File names (not full paths); empty if the directory does not exist.
        """
        if not os.path.isdir(self.processed_dir):
            return []

        return [
            f
            for f in os.listdir(self.processed_dir)
            if f.startswith("data_") and f.endswith(".pt")
        ]

    def process(self) -> None:
        """Convert every graph in the input files into one ``data_<i>.pt`` file.

        Also writes ``metadata.json`` with the manifold/boundary codes and names.
        Indices ``i`` are contiguous over the samples that pass ``pre_filter``. Does
        nothing if processed samples are already present.

        Raises:
            FileNotFoundError: If an input file does not exist.
        """
        if os.path.exists(self.processed_dir) and len(self.processed_file_names) > 0:
            return

        os.makedirs(self.processed_dir, exist_ok=True)

        # Convert NumPy arrays / bytes to plain Python values for JSON serialization.
        d = {
            "manifold_codes": [v.item() for v in self.manifold_codes],
            "manifolds": [self._as_text(m) for m in self.manifold_names],
            "boundaries": [self._as_text(m) for m in self.boundaries],
            "boundary_codes": [v.item() for v in self.boundary_codes],
        }

        with open(os.path.join(self.processed_dir, "metadata.json"), "w") as f:
            json.dump(d, f)

        file_index = 0
        for file in self.raw_paths:
            if not os.path.exists(file):
                raise FileNotFoundError(f"Input file {file} does not exist.")
            with h5py.File(file, "r") as f:
                for idx in tqdm(range(f["num_causal_sets"][()])):
                    data = self.loader(
                        f,
                        idx,
                        float_dtype=torch.float32,
                        int_dtype=torch.int64,
                        validate=self.validate_data,
                    )
                    if self.pre_filter is not None and not self.pre_filter(data):
                        continue
                    if self.pre_transform is not None:
                        data = self.pre_transform(data)
                    torch.save(data.to('cpu'), os.path.join(self.processed_dir, f"data_{file_index}.pt"))
                    file_index += 1

        # The count from __init__ ignores pre_filter; len() must match the files on disk.
        self._num_samples = file_index

    def len(self) -> int:
        """Return the number of stored samples.

        Returns:
            int: The number of samples in the dataset.
        """
        return self._num_samples

    def get(self, idx: int) -> Data:
        """Load sample ``idx`` from its ``data_<idx>.pt`` file.

        ``self.transform`` is intentionally not applied here. The PyG base class
        applies it in ``__getitem__``, so applying it here as well would run it
        twice per access.

        Args:
            idx (int): The index of the sample to retrieve.

        Returns:
            pytorch_geometric.data.Data: The stored sample.
        """
        # weights_only=False performs a full pickle load. The files are written by
        # process(); do not point `output` at untrusted data.
        return torch.load(os.path.join(self.processed_dir, f"data_{idx}.pt"), weights_only=False)


# Model definition

With the data code taken care of, we can go on to build the actual model code. 
The operators besides `GCNConv` are not used in the code, but are imported for potential use in future experiments. The same holds for the pooling operations besides `global_mean_pool`.

In [None]:
import torch 
import torch_geometric
import os
from torch.nn import Linear 
from torch_geometric.nn.conv import GCNConv, GATConv, SAGEConv, GraphConv, GATv2Conv  
from torch_geometric.nn import global_mean_pool, global_max_pool, global_add_pool, SAGPooling, Set2Set


## Architecture

We are using graph convolutional network operators here, i.e., we are not dealing with anything transformer like or so. For the simple experiments we are doing here at the moment, the architecture is very simple, but might have to become more complex later when we are dealing with larger, more complex graphs. See the diagram below for how it works: 

![gcnblock](img/gcnblock.png)

This comes a little out of the blue, especially since it's not clear how many layers we should use here. The idea is the following, however: 
- the graph convolutional layers build new representations from local neighborhoods; each additional layer aggregates information from a larger neighborhood until, layer by layer, the entire graph is covered.
- this causes feature squashing in deep networks, so we cannot make them too deep
- deep nets become difficult to train at some point, so it's a good idea to add normalization such as batch norm to stabilize training. It comes before the activation so that strongly fluctuating values are not amplified further (although for ReLU that shouldn't be much of a problem).
- the skip connection counteracts the feature squashing by reinjecting the block's unprocessed input. This also helps to avoid vanishing gradients
- dropout helps with generalization and counteracts overfitting in overparameterized models

As for how many of these we should have - 
- one or two is too little probably 
- 4 or 5 are too many because squashing becomes stronger with growing depth
- 3 seems to be a good starting point


The skip connection is perhaps placed a little oddly; it might make more sense to put it around the whole chain of blocks?



In [None]:
class GCNBlock(torch.nn.Module):
    """Graph convolution block: conv -> normalizer -> activation -> residual add -> dropout.

    When input and output widths differ, the residual branch goes through a
    bias-free linear projection. Otherwise it is passed through unchanged.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        dropout=0.5,
        gcn_type=GCNConv,
        normalizer=torch.nn.Identity,
        activation=F.relu,
        gcn_kwargs=None,
    ):
        """Initialize the GCNBlock.

        Args:
            input_dim (int): Input node-feature dimension.
            output_dim (int): Output node-feature dimension.
            dropout (float, optional): Dropout probability, only applied in training
                mode. Defaults to 0.5.
            gcn_type (type, optional): Graph convolution class, constructed as
                ``gcn_type(input_dim, output_dim, **gcn_kwargs)``. Defaults to GCNConv.
            normalizer (torch.nn.Module | type, optional): Normalization applied to the
                convolution output. Pass either a module instance
                (e.g. ``torch.nn.BatchNorm1d(output_dim)``) or a module class, which is
                instantiated as ``normalizer(output_dim)``. Defaults to
                torch.nn.Identity, i.e. no normalization.
            activation (Callable[[Tensor], Tensor], optional): Activation function.
                Defaults to F.relu.
            gcn_kwargs (dict, optional): Extra keyword arguments for the convolution
                constructor. Defaults to None.
        """
        super().__init__()
        self.dropout = dropout
        self.gcn_type = gcn_type
        self.conv = gcn_type(input_dim, output_dim, **(gcn_kwargs or {}))
        self.activation = activation

        # Accept a normalizer class as well as an instance. Storing the class itself
        # (the previous behaviour for the default) made forward() call its constructor
        # on the features, returning a new module instead of a tensor.
        if isinstance(normalizer, type):
            normalizer = normalizer(output_dim)
        self.normalizer = normalizer

        if input_dim != output_dim:
            # Bias-free linear map (the graph analogue of a 1x1 convolution) so the
            # residual has the same width as the convolution output.
            self.projection = Linear(input_dim, output_dim, bias=False)
        else:
            self.projection = torch.nn.Identity()

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_weight: torch.Tensor = None,
        kwargs: dict = None,
    ):
        """Forward pass for the GCNBlock.

        Args:
            x (Tensor): Input node features.
            edge_index (Tensor): Graph connectivity in COO format.
            edge_weight (Tensor, optional): Edge weights. Defaults to None.
            kwargs (dict, optional): Extra keyword arguments for the convolution call.
                Defaults to None.

        Returns:
            Tensor: Output node features of width ``output_dim``.
        """
        x_res = x
        x = self.conv(x, edge_index, edge_weight=edge_weight, **(kwargs or {}))
        x = self.normalizer(x)  # identity unless a normalizer was supplied
        x = self.activation(x)
        # Residual connection; the projection is an identity when input_dim == output_dim.
        x = x + self.projection(x_res)
        # Noise regularization, active only in training mode.
        x = F.dropout(x, p=self.dropout, training=self.training)
        return x

We collect a chain of `GCNBlock`s into a `GCNBackbone` structure. All this does is to apply a sequence of `GCNBlock`s to the data, one after the other. We could think about adding a final skip connection here too if the backbone becomes too deep. 

In [None]:
class GCNBackbone(torch.nn.Module):
    """Chain of ``GCNBlock`` layers applied one after another to the node features."""

    def __init__(self, gcn_net: list[GCNBlock]):
        """Initialize the GCNBackbone.

        Args:
            gcn_net (list[GCNBlock]): Blocks to apply, in order.
        """
        super().__init__()
        # A ModuleList registers the blocks (and their parameters) with this module.
        self.gcn_net = torch.nn.ModuleList(gcn_net)

    def forward(self, x: torch.tensor, edge_index: torch.tensor, edge_weight: torch.tensor = None, gcn_kwargs: dict = None):
        """Run every block in sequence on the same graph structure.

        Args:
            x (Tensor): Input node features.
            edge_index (Tensor): Graph connectivity in COO format.
            edge_weight (Tensor, optional): Edge weights shared by all blocks. Defaults to None.
            gcn_kwargs (dict, optional): Extra keyword arguments forwarded to every
                block's convolution call. Defaults to None.

        Returns:
            Tensor: Node features produced by the last block (``x`` itself if there are no blocks).
        """
        hidden = x
        for block in self.gcn_net:
            hidden = block(hidden, edge_index, edge_weight=edge_weight, kwargs=gcn_kwargs)
        return hidden

The classifier block is the part of the model that produces the output classification. It passes the input graph representation of size `input_dim` (the number of features per node produced by the backbone) through several affine (`linear`) layers to create a new embedding of the graph representation. It then passes this embedding through 3 parallel affine layers that produce the final classification for dimension, boundary and manifold. It doesn't add any regularization in between, as it should not be a very deep network. If it has to be deep, that is probably a sign that something with the backbone is off.

In [None]:
class _CallableModule(torch.nn.Module):
    """Wrap a plain callable (e.g. ``F.relu``) so it can live inside a ``torch.nn.Sequential``."""

    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
        super().__init__()
        self.fn = fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fn(x)


class ClassifierBlock(torch.nn.Module):
    """Multi-head classifier: shared linear trunk followed by dimension/boundary/manifold heads."""

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dims: list[int],
        manifold_classes: int = 6,
        boundary_classes: int = 3,
        dimension_classes: int = 3,
        activation: Callable[[torch.Tensor], torch.Tensor] = F.relu,
        linear_kwargs: list[dict] = None,
        dim_kwargs: dict = None,
        boundary_kwargs: dict = None,
        manifold_kwargs: dict = None,
    ):
        """Initialize the ClassifierBlock.

        With non-empty ``hidden_dims`` the trunk is
        ``Linear -> activation`` for each hidden width, and the heads read the last
        hidden width. With empty ``hidden_dims`` the trunk is a single
        ``Linear(input_dim, output_dim)`` and the heads read ``output_dim``.
        ``output_dim`` is only used in that case.

        Args:
            input_dim (int): Input feature dimension.
            output_dim (int): Trunk output width when ``hidden_dims`` is empty.
            hidden_dims (list[int]): Hidden layer widths of the trunk.
            manifold_classes (int, optional): Number of manifold classes. Defaults to 6.
            boundary_classes (int, optional): Number of boundary classes. Defaults to 3.
            dimension_classes (int, optional): Number of dimension classes. Defaults to 3.
            activation (Callable[[Tensor], Tensor], optional): Activation after each
                hidden layer; plain functions are wrapped into a module. Defaults to F.relu.
            linear_kwargs (list[dict], optional): Per-layer keyword arguments for the
                trunk's linear layers; missing or empty entries mean no extra arguments.
                Defaults to None.
            dim_kwargs (dict, optional): Extra arguments for the dimension head. Defaults to None.
            boundary_kwargs (dict, optional): Extra arguments for the boundary head. Defaults to None.
            manifold_kwargs (dict, optional): Extra arguments for the manifold head. Defaults to None.
        """
        super().__init__()
        self.activation = activation
        self.hidden_dims = hidden_dims
        self.total_classes = manifold_classes + boundary_classes + dimension_classes
        self.manifold_classes = manifold_classes
        self.boundary_classes = boundary_classes
        self.dimension_classes = dimension_classes

        def kwargs_for(i: int) -> dict:
            # Tolerate a linear_kwargs list that is shorter than the number of layers.
            if linear_kwargs and i < len(linear_kwargs) and linear_kwargs[i]:
                return linear_kwargs[i]
            return {}

        if len(hidden_dims) == 0:
            self.backbone = Linear(input_dim, output_dim, **kwargs_for(0))
            embedding_dim = output_dim
        else:
            # torch.nn.Sequential only accepts modules, so wrap plain functions.
            act_module = (
                activation
                if isinstance(activation, torch.nn.Module)
                else _CallableModule(activation)
            )
            layers = []
            in_dim = input_dim
            for i, hidden_dim in enumerate(hidden_dims):
                # TODO: check again whether the bias is needed here.
                layers.append(Linear(in_dim, hidden_dim, **kwargs_for(i)))
                layers.append(act_module)
                in_dim = hidden_dim
            self.backbone = torch.nn.Sequential(*layers)
            embedding_dim = in_dim

        # Heads are created in both cases; previously they were missing without hidden layers.
        self.dim_layer = torch.nn.Linear(embedding_dim, self.dimension_classes, **(dim_kwargs or {}))
        self.boundary_layer = torch.nn.Linear(embedding_dim, self.boundary_classes, **(boundary_kwargs or {}))
        self.manifold_layer = torch.nn.Linear(embedding_dim, self.manifold_classes, **(manifold_kwargs or {}))

    def forward(self, x: torch.Tensor, backbone_kwargs: dict = None, dim_layer_kwargs: dict = None, boundary_layer_kwargs: dict = None, manifold_layer_kwargs: dict = None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward pass for the ClassifierBlock.

        Args:
            x (torch.Tensor): Input tensor of width ``input_dim``.
            backbone_kwargs (dict, optional): Extra arguments for the trunk call. Defaults to None.
            dim_layer_kwargs (dict, optional): Extra arguments for the dimension head. Defaults to None.
            boundary_layer_kwargs (dict, optional): Extra arguments for the boundary head. Defaults to None.
            manifold_layer_kwargs (dict, optional): Extra arguments for the manifold head. Defaults to None.

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Dimension, boundary, and manifold logits.
        """
        x = self.backbone(x, **(backbone_kwargs or {}))

        dim_logit = self.dim_layer(x, **(dim_layer_kwargs or {}))
        boundary_logit = self.boundary_layer(x, **(boundary_layer_kwargs or {}))
        manifold_logit = self.manifold_layer(x, **(manifold_layer_kwargs or {}))

        return dim_logit, boundary_logit, manifold_logit

We also have a block for including graph features. These would normally be added after the graph representation has been created from the node features and are used to inject additional global information into the network. This would be done via concatenation of the graph representation created by the `GCNBackbone` and the output of a `GraphFeaturesBlock` instance. Alternatively, they can be summed up, too. 

This is somewhat similar to the classifier in that it passes the input features through a number of affine/linear layers to provide the model with enough capacity to learn useful representations of them to add to the output of the `GCNBackbone`. Currently, this has no regularizers, because it has not been tested in the model yet. 

In [None]:
class GraphFeaturesBlock(torch.nn.Module):
    """MLP that embeds global graph features, e.g. for combination with the pooled GCN output."""

    def __init__(self, input_dim: int, output_dim: int, hidden_dims: list[int], activation: Callable[[torch.Tensor], torch.Tensor] = F.relu, linear_kwargs: list[dict] = None, final_linear_kwargs: dict = None):
        """Initialize the GraphFeaturesBlock.

        With non-empty ``hidden_dims`` the network is
        ``Linear -> activation`` per hidden width followed by a final
        ``Linear(last_hidden, output_dim)``. With empty ``hidden_dims`` it is a single
        ``Linear(input_dim, output_dim)``.

        Args:
            input_dim (int): Input feature dimension.
            output_dim (int): Output feature dimension.
            hidden_dims (list[int]): Hidden layer widths.
            activation (Callable[[torch.Tensor], torch.Tensor], optional): Activation
                applied between consecutive linear layers. Defaults to F.relu.
            linear_kwargs (list[dict], optional): Per-layer keyword arguments for the
                hidden (or, without hidden layers, the only) linear layers; missing or
                empty entries mean no extra arguments. Defaults to None.
            final_linear_kwargs (dict, optional): Extra arguments for the final linear
                layer when hidden layers exist. Defaults to None.
        """
        super().__init__()
        self.activation = activation
        self.hidden_dims = hidden_dims

        def kwargs_for(i: int) -> dict:
            # Tolerate a linear_kwargs list that is shorter than the number of layers.
            if linear_kwargs and i < len(linear_kwargs) and linear_kwargs[i]:
                return linear_kwargs[i]
            return {}

        if len(hidden_dims) == 0:
            layers = [Linear(input_dim, output_dim, **kwargs_for(0))]
        else:
            layers = []
            in_dim = input_dim
            for i, hidden_dim in enumerate(hidden_dims):
                layers.append(Linear(in_dim, hidden_dim, **kwargs_for(i)))
                in_dim = hidden_dim
            layers.append(Linear(in_dim, output_dim, **(final_linear_kwargs or {})))

        # A ModuleList (instead of a plain list) registers the layers, so their
        # parameters are trained, saved in state_dict and moved by .to(device).
        # Activations are applied in forward() because plain functions are not modules.
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass for the GraphFeaturesBlock.

        Args:
            x (torch.Tensor): Input tensor of width ``input_dim``.

        Returns:
            torch.Tensor: Output tensor of width ``output_dim``.
        """
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < last:  # no activation after the output layer
                x = self.activation(x)
        return x

### Main model class 
This class binds everything together: the `GCNBackbone`, the `ClassifierBlock` and the optional `GraphFeaturesBlock`. Additionally, it adds a pooling layer after the `GCNBackbone` to create the final graph representation.

In [None]:
class GCNModel(torch.nn.Module):
    """Full GCN model: GCN backbone -> graph pooling -> (optional graph features) -> classifier.

    The backbone produces node embeddings, ``pooling_layer`` reduces them to one
    vector per graph, which is optionally concatenated with the output of
    ``graph_features_net``, and ``classifier`` turns the result into the raw
    logits for the dimension, boundary and manifold tasks.

    Args:
        torch.nn.Module: base class
    """

    def __init__(
        self,
        gcn_net: GCNBackbone,
        classifier: ClassifierBlock,
        pooling_layer: torch.nn.Module,
        use_graph_features: bool = False,
        graph_features_net: torch.nn.Module = torch.nn.Identity,
    ):
        """Initialize the GCNModel.

        Args:
            gcn_net (GCNBackbone): GCN backbone network.
            classifier (ClassifierBlock): Classifier block; must return the
                (dimension, boundary, manifold) logits in that order.
            pooling_layer (torch.nn.Module): Pooling callable invoked as
                ``pooling_layer(x, batch)``.
            use_graph_features (bool, optional): Whether to use graph features. Defaults to False.
            graph_features_net (torch.nn.Module, optional): Network applied to the graph
                features. Either a module instance or a ``torch.nn.Module`` subclass, which
                is instantiated without arguments. Defaults to torch.nn.Identity.
        """
        super(GCNModel, self).__init__()
        # The default is the Identity *class*; calling a class would construct a new
        # module instead of transforming the features, so instantiate it here.
        if isinstance(graph_features_net, type) and issubclass(
            graph_features_net, torch.nn.Module
        ):
            graph_features_net = graph_features_net()
        self.gcn_net = gcn_net
        self.classifier = classifier
        self.graph_features_net = graph_features_net
        self.use_graph_features = use_graph_features
        self.pooling_layer = pooling_layer

    def get_embeddings(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        batch: torch.Tensor,
        edge_weight: torch.Tensor = None,
        gcn_kwargs: dict = None,
    ):
        """Get graph-level embeddings from the GCN backbone and pooling layer.

        Args:
            x (torch.Tensor): Input node features.
            edge_index (torch.Tensor): Graph connectivity information.
            batch (torch.Tensor): Batch vector assigning each node to its graph; ``None``
                means all nodes belong to a single graph.
            edge_weight (torch.Tensor, optional): Edge weights. Defaults to None.
            gcn_kwargs (dict, optional): Additional arguments for the GCN backbone. Defaults to None.

        Returns:
            torch.Tensor: Pooled embeddings, one row per graph.
        """
        # apply the GCN backbone to the node features
        x = self.gcn_net(
            x, edge_index, edge_weight=edge_weight, **(gcn_kwargs if gcn_kwargs else {})
        )

        # for single graph processing, we can have batch=None, so we need to handle that case
        if batch is None:
            batch = torch.zeros(x.shape[0], dtype=torch.long, device=x.device)

        # pool everything together into a single graph representation
        return self.pooling_layer(x, batch)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        batch: torch.Tensor,
        edge_weight: torch.Tensor = None,
        graph_features: torch.Tensor = None,
        gcn_kwargs: dict = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward pass for the GCNModel.

        Args:
            x (torch.Tensor): Input node features.
            edge_index (torch.Tensor): Graph connectivity information.
            batch (torch.Tensor): Batch vector (``None`` for a single graph).
            edge_weight (torch.Tensor, optional): Edge weights. Defaults to None.
            graph_features (torch.Tensor, optional): Graph features; required when
                ``use_graph_features`` is True. Defaults to None.
            gcn_kwargs (dict, optional): Additional arguments for the GCN backbone. Defaults to None.

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Raw dimension, boundary and
            manifold logits, in that order (no softmax applied).

        Raises:
            ValueError: If ``use_graph_features`` is True but ``graph_features`` is None.
        """
        x = self.get_embeddings(
            x, edge_index, batch, edge_weight=edge_weight, gcn_kwargs=gcn_kwargs
        )

        # If we have graph features, process them and concatenate them with the pooled embedding
        if self.use_graph_features:
            if graph_features is None:
                raise ValueError(
                    "use_graph_features=True but no graph_features were passed to forward()"
                )
            graph_features = self.graph_features_net(graph_features)
            # -1 -> last dim. This concatenates, but we also could sum them
            x = torch.cat((x, graph_features), dim=-1)

        # The classifier creates the raw logits; no softmax/sigmoid here, the loss works on logits.
        dim, boundary, manifold = self.classifier(x)

        return dim, boundary, manifold

# Training

## Loss function

In [None]:
def criterion(x_pred: tuple[torch.Tensor, torch.Tensor, torch.Tensor], y: torch.Tensor, dim_kwargs: dict = None, boundary_kwargs: dict = None, manifold_kwargs: dict = None, dim_weight: float = 1.0, boundary_weight: float = 1.0, manifold_weight: float = 1.0):
    """Compute the weighted multi-task cross-entropy loss for the GCN model.

    Args:
        x_pred (tuple[torch.Tensor, torch.Tensor, torch.Tensor]): Raw logits
            ``(dim_logits, boundary_logits, manifold_logits)``, the order in which
            ``GCNModel.forward`` returns them; each is expected to be
            ``(num_graphs, num_classes_of_that_task)``.
        y (torch.Tensor): Ground truth labels; column 0 is the dimension, column 1 the
            boundary and column 2 the manifold class index.
        dim_kwargs (dict, optional): Keyword arguments for the dimension ``CrossEntropyLoss``
            (e.g. ``weight=`` for class balancing). Defaults to None.
        boundary_kwargs (dict, optional): Same for the boundary loss. Defaults to None.
        manifold_kwargs (dict, optional): Same for the manifold loss. Defaults to None.
        dim_weight (float, optional): Weight for dimension loss. Defaults to 1.0.
        boundary_weight (float, optional): Weight for boundary loss. Defaults to 1.0.
        manifold_weight (float, optional): Weight for manifold loss. Defaults to 1.0.

    Returns:
        tuple: Total (weighted) loss, dimension loss, boundary loss, manifold loss.
    """
    dim_logits, boundary_logits, manifold_logits = x_pred[0], x_pred[1], x_pred[2]

    # CrossEntropyLoss needs int64 class indices; the labels may be stored with a
    # narrower integer dtype, so cast explicitly (no-op if already int64).
    targets = y.long()
    dim_truth = targets[:, 0]
    boundary_truth = targets[:, 1]
    manifold_truth = targets[:, 2]

    # No class re-weighting happens here. To counter class imbalance, pass e.g.
    # ``weight=`` through the *_kwargs arguments to the respective CrossEntropyLoss.
    dim_cel = torch.nn.CrossEntropyLoss(**(dim_kwargs or {}))
    boundary_cel = torch.nn.CrossEntropyLoss(**(boundary_kwargs or {}))
    manifold_cel = torch.nn.CrossEntropyLoss(**(manifold_kwargs or {}))

    loss_dim = dim_cel(dim_logits, dim_truth)
    loss_boundary = boundary_cel(boundary_logits, boundary_truth)
    loss_manifold = manifold_cel(manifold_logits, manifold_truth)

    total_loss = (
        loss_dim * dim_weight
        + loss_boundary * boundary_weight
        + loss_manifold * manifold_weight
    )

    return total_loss, loss_dim, loss_boundary, loss_manifold


## Training function


We record the mean and standard deviation (over batches) of each loss per epoch so that we can investigate them later.

In [None]:
def _store_epoch_losses(
    loss_data: pd.DataFrame, epoch: int, phase: str, losses: dict
) -> None:
    """Write mean and std of per-batch losses for one epoch into ``loss_data``.

    Args:
        loss_data (pd.DataFrame): Frame with columns ``{task}_{phase}_loss_{mean,std}``.
        epoch (int): Row label to write to.
        phase (str): ``"training"`` or ``"validation"``.
        losses (dict): Maps task name (``"total"``, ``"dim"``, ...) to a float64 array
            of per-batch losses.
    """
    for task, values in losses.items():
        loss_data.loc[epoch, f"{task}_{phase}_loss_mean"] = values.mean()
        loss_data.loc[epoch, f"{task}_{phase}_loss_std"] = values.std()


def _run_epoch(model, loader, lossfunc, device, desc: str, optimizer=None) -> dict:
    """Run one pass over ``loader`` and return the per-batch losses.

    Trains (``model.train()``, backward pass and optimizer step per batch) when
    ``optimizer`` is given; otherwise evaluates (``model.eval()``, no gradient
    tracking).

    Returns:
        dict: ``"total"``, ``"dim"``, ``"boundary"``, ``"manifold"`` -> float64 arrays
        of length ``len(loader)``.
    """
    tasks = ("total", "dim", "boundary", "manifold")
    losses = {task: np.zeros(len(loader), dtype=np.float64) for task in tasks}
    is_training = optimizer is not None
    model.train(is_training)  # train(False) is what model.eval() does
    with torch.set_grad_enabled(is_training):
        for batchnum, data in enumerate(tqdm(loader, desc=desc)):
            data = data.to(device)
            if is_training:
                optimizer.zero_grad()
            out = model(data.x, data.edge_index, data.batch)
            batch_losses = lossfunc(out, data.y)  # (total, dim, boundary, manifold)
            if is_training:
                batch_losses[0].backward()
                optimizer.step()
            for task, value in zip(tasks, batch_losses):
                losses[task][batchnum] = value.item()
    return losses


def train(
    model: GCNModel,
    train_loader: DataLoader,
    val_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    epochs: int = 10,
    device: str = "cpu",
    criterion=criterion,
    early_stopping_patience: int = 10,
    early_stopping_window: int = 5,
    early_stopping_delta: float = 0.01,
) -> tuple[pd.DataFrame, dict, int, int]:
    """Train the model with early stopping and return the training results.

    Early stopping monitors the validation loss smoothed with a trailing rolling
    mean over ``early_stopping_window`` epochs. An epoch counts as an improvement if
    the smoothed loss drops below the best one by more than ``early_stopping_delta``.
    When patience runs out, the best snapshot is loaded back into ``model``.

    Args:
        model (GCNModel): Model to be trained (moved to ``device``).
        train_loader (DataLoader): DataLoader for the training set.
        val_loader (DataLoader): DataLoader for the validation set.
        optimizer (torch.optim.Optimizer): Optimizer for the training process.
        epochs (int, optional): Maximum number of training epochs. Defaults to 10.
        device (str, optional): Device to run the model on. Defaults to 'cpu'.
        criterion (callable, optional): Loss function returning
            ``(total, dim, boundary, manifold)`` losses. Defaults to criterion.
        early_stopping_patience (int, optional): Number of epochs without improvement
            before stopping. Defaults to 10.
        early_stopping_window (int, optional): Window of the rolling mean applied to
            the validation loss. Defaults to 5.
        early_stopping_delta (float, optional): Minimum decrease of the smoothed
            validation loss to qualify as an improvement. Defaults to 0.01.

    Returns:
        tuple[pd.DataFrame, dict, int, int]: The per-epoch loss statistics (only the
        epochs that were run), a snapshot of the best model's ``state_dict`` (None if
        the validation loss never improved), the best epoch, and the last epoch run
        (-1 if ``epochs <= 0``).
    """
    import copy  # local import keeps this cell self-contained

    tasks = ("total", "dim", "boundary", "manifold")
    # columns: {task}_{phase}_loss_{mean,std}, training block first, then validation
    columns = [
        f"{task}_{phase}_loss_{stat}"
        for phase in ("training", "validation")
        for task in tasks
        for stat in ("mean", "std")
    ]
    # make the losses into a pandas DataFrame so we can easily plot them later
    loss_data = pd.DataFrame(np.nan, index=range(epochs), columns=columns)

    best_val_loss = float("inf")
    best_epoch = 0
    best_model = None
    current_patience = early_stopping_patience

    lossfunc = criterion
    model = model.to(device)
    epoch = -1  # stays -1 if no epoch runs

    for epoch in range(epochs):
        train_losses = _run_epoch(
            model, train_loader, lossfunc, device, f"Epoch {epoch} Training", optimizer=optimizer
        )
        _store_epoch_losses(loss_data, epoch, "training", train_losses)

        val_losses = _run_epoch(
            model, val_loader, lossfunc, device, f"Epoch {epoch} Validation"
        )
        _store_epoch_losses(loss_data, epoch, "validation", val_losses)

        # Early stopping on the smoothed validation loss. F1 or precision/recall could
        # give a better estimate of model performance, but for now we use the loss.
        smoothed_val_loss = (
            loss_data["total_validation_loss_mean"]
            .rolling(window=early_stopping_window, min_periods=1)
            .mean()
        )

        if smoothed_val_loss[epoch] < best_val_loss - early_stopping_delta:
            best_val_loss = smoothed_val_loss[epoch]
            current_patience = early_stopping_patience  # Reset patience
            # state_dict() returns references to the live tensors, which the optimizer
            # keeps updating in place; deep-copy so this really is the best snapshot.
            best_model = copy.deepcopy(model.state_dict())
            best_epoch = epoch
        else:
            print(
                f"  No improvement in validation loss: {loss_data.loc[epoch, 'total_validation_loss_mean']:.4f} at epoch {epoch}, current patience: {current_patience}"
            )
            current_patience -= 1

        if current_patience <= 0:
            print(
                f"  Early stopping at epoch {epoch} with best validation loss: {best_val_loss:.4f}"
            )
            if best_model is not None:
                model.load_state_dict(best_model)
            break

        print(
            f"  training loss: {loss_data.loc[epoch, 'total_training_loss_mean']:.4f} ± {loss_data.loc[epoch, 'total_training_loss_std']:.4f}, Dim : {loss_data.loc[epoch, 'dim_training_loss_mean']:.4f} ± {loss_data.loc[epoch, 'dim_training_loss_std']:.4f}, Boundary : {loss_data.loc[epoch, 'boundary_training_loss_mean']:.4f} ± {loss_data.loc[epoch, 'boundary_training_loss_std']:.4f}, Manifold : {loss_data.loc[epoch, 'manifold_training_loss_mean']:.4f} ± {loss_data.loc[epoch, 'manifold_training_loss_std']:.4f}"
        )

        print(
            f"  validation loss: {loss_data.loc[epoch, 'total_validation_loss_mean']:.4f} ± {loss_data.loc[epoch, 'total_validation_loss_std']:.4f}, Dim : {loss_data.loc[epoch, 'dim_validation_loss_mean']:.4f} ± {loss_data.loc[epoch, 'dim_validation_loss_std']:.4f}, Boundary : {loss_data.loc[epoch, 'boundary_validation_loss_mean']:.4f} ± {loss_data.loc[epoch, 'boundary_validation_loss_std']:.4f}, Manifold : {loss_data.loc[epoch, 'manifold_validation_loss_mean']:.4f} ± {loss_data.loc[epoch, 'manifold_validation_loss_std']:.4f}"
        )
        torch.cuda.empty_cache()

    # Keep only the epochs that actually ran (positional slice, end-exclusive;
    # the old inclusive .loc[0:epoch + 1] kept an extra all-NaN row after early stopping).
    loss_data = loss_data.iloc[: epoch + 1]
    return loss_data, best_model, best_epoch, epoch

### test run
The `test` function simply applies the model to the test data without gradient computation and in evaluation mode, i.e., without dropout or other regularizers that are only active during training.

In [None]:
def test(
    model: GCNModel, test_loader: DataLoader, device: str
) -> pd.DataFrame:
    """Run the model on the test set and return predictions, labels and correctness flags.

    The model is evaluated in eval mode without gradient tracking.

    Args:
        model (GCNModel): Model to be evaluated (moved to ``device``).
        test_loader (DataLoader): torch_geometric DataLoader for the test set.
        device (str): Device to run the model on.

    Returns:
        pd.DataFrame: One float64 row per graph in ``test_loader.dataset``, in the
        loader's iteration order. Columns ``dim``, ``boundary``, ``manifold`` hold the
        argmax predictions; ``*_correct`` is 1.0 if the prediction equals the label and
        0.0 otherwise; ``*_true`` holds the labels from ``data.y[:, 0..2]``. Rows the
        loader never yields stay NaN.
    """
    tasks = ("dim", "boundary", "manifold")
    columns = [
        *tasks,
        *(f"{task}_correct" for task in tasks),
        *(f"{task}_true" for task in tasks),
    ]
    n_tasks = len(tasks)

    # Fill a NumPy buffer and build the DataFrame once at the end; per-batch .iloc
    # writes into a DataFrame are comparatively slow.
    results = np.full(
        (len(test_loader.dataset), len(columns)), np.nan, dtype=np.float64
    )

    model = model.to(device)
    model.eval()
    start = 0
    with torch.no_grad():
        for data in tqdm(test_loader):
            data = data.to(device)
            logits = model(data.x, data.edge_index, data.batch)
            # argmax is enough here, because the label encodings coincide with the class indices
            predictions = [
                torch.argmax(task_logits, dim=1).cpu().numpy() for task_logits in logits
            ]
            truth = data.y.cpu().numpy()

            stop = start + data.num_graphs
            for k, predicted in enumerate(predictions):
                results[start:stop, k] = predicted
                results[start:stop, n_tasks + k] = predicted == truth[:, k]
                results[start:stop, 2 * n_tasks + k] = truth[:, k]
            start = stop

    return pd.DataFrame(results, columns=columns)

## Actual training loop

Set the seed of torch's internal RNG. This is important for reproducibility.

In [None]:
# Fixed seed for torch's global random number generator, for reproducible runs.
SEED = 532432
torch.manual_seed(SEED)

### get cuda device
If we don't have one, we need to run this on the CPU. That is not suitable for any serious work, though, so you might want to reduce the data size in that case.

In [None]:
# Check whether torch can see a CUDA-capable GPU (the notebook displays True/False).
torch.cuda.is_available()

In [None]:
# Prefer the GPU whenever torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

### Define where the data is.  

We can read data from arbitrarily many HDF5 files.

In [None]:
# Directory with the HDF5 files. The default is machine-specific; set the
# QUANTUMGRAV_DATAPATH environment variable to point elsewhere, e.g. to
# os.path.join(os.path.expanduser("~"), "data", "causal_sets").
datapath = os.environ.get(
    "QUANTUMGRAV_DATAPATH",
    os.path.join("/mnt", "dataLinux", "machinelearning_data", "QuantumGrav", "causal_sets"),
)
# One file per spacetime dimension d.
files = [
    os.path.join(datapath, f"cset_data_min=300_max=650_N=25000_d={d}.h5")
    for d in (2, 3, 4)
]

Define the dataset and the train, validation and test splits. Then create `DataLoader` instances from them to use with the model; these take care of batching, shuffling and so on. Notice that we use the same transformations as before. The dataset is smart enough to notice whether it has to process the data again. (The current implementation only does this by checking for existing files, though, so you have to delete the processed files if you want to change the transformations.)

In [None]:
dset = CsDataset(
    input=files,
    output=datapath,
    pre_transform=target_shift,  # maybe add some augmentation stuff here later
    pre_filter=None,  # Filter data before loading, e.g., based on manifold or boundary or something like that
    validate_data=True,  # Validate data after loading
    loader=load_graph,  # Custom loader function
).shuffle()  # shuffle the dataset to ensure a good mix of samples in each batch

# roughly 80-10-10 split for train-test-validation
train_size = int(math.ceil(0.8 * len(dset)))
test_size = int(math.ceil(0.1 * len(dset)))
val_size = len(dset) - train_size - test_size

batch_size = 128
# Pinned host memory only speeds up host-to-GPU copies; skip it on CPU-only machines.
pin_memory = torch.cuda.is_available()

train_loader = DataLoader(
    dset[0:train_size],
    batch_size=batch_size,
    shuffle=True,
    pin_memory=pin_memory,
)

# Evaluation sets are not shuffled: order does not affect the metrics, and a fixed
# batching keeps the per-epoch validation statistics comparable.
test_loader = DataLoader(
    dset[train_size : train_size + test_size],
    batch_size=batch_size,
    shuffle=False,
    pin_memory=pin_memory,
)

val_loader = DataLoader(
    dset[train_size + test_size : train_size + test_size + val_size],
    batch_size=batch_size,
    shuffle=False,
    pin_memory=pin_memory,
)

# len(loader) is the number of batches; len(loader.dataset) the number of graphs.
print(f"Train batches: {len(train_loader)}, Test batches: {len(test_loader)}, Validation batches: {len(val_loader)}")
print(f"Dataset sizes (train, test, validation): {len(train_loader.dataset)}, {len(test_loader.dataset)}, {len(val_loader.dataset)}")

### Define a new model 

This is done by building it element by element --> first the graph convolutional blocks, then the classifier, and finally the whole model is assembled from these blocks together with a global pooling operator.

In [None]:
n_node_features = dset[0].x.shape[1]  # Number of node features
n_edge_features = dset[0].edge_attr.shape[1] if dset[0].edge_attr is not None else 0  # Number of edge features

# normalizer = torch.nn.Identity
normalizer = torch.nn.BatchNorm1d  # batch normalization after each GCN layer
conv_layer = GCNConv  # You can change this to GATConv, SAGEConv, etc. as needed

# (input_dim, output_dim) of each GCN block: n_node_features -> 128 -> 256 -> 128.
# The blocks are built in order, so parameter initialization matches one-by-one construction.
block_dims = [(n_node_features, 128), (128, 256), (256, 128)]
conv1, conv2, conv3 = [
    GCNBlock(
        input_dim=in_dim,
        output_dim=out_dim,
        dropout=0.3,
        normalizer=normalizer(out_dim),
        gcn_type=conv_layer,
        activation=torch.nn.ReLU(),
        # extra arguments forwarded to the GCN layer (fresh dict per block)
        gcn_kwargs={"cached": False, "bias": True, "add_self_loops": True},
    )
    for in_dim, out_dim in block_dims
]

gcn_backbone = GCNBackbone([conv1, conv2, conv3])


In [None]:
# Classification heads on top of the pooled graph embedding. input_dim must equal
# the output_dim (128) of the last GCN block.
# NOTE(review): confirm dimension_classes / boundary_classes / manifold_classes match
# the number of distinct labels of each task in the data.
classifier = ClassifierBlock(
    input_dim=128,
    output_dim=3,  # presumably one output per task: dimension, boundary, manifold
    hidden_dims=[64, 32],
    dimension_classes=3,
    boundary_classes=3,
    manifold_classes=6,
    activation=torch.nn.ReLU(),
)

We use mean pooling here simply because it is the simplest option. Be aware, however, that averaging over all nodes can wash out ("squish") distinguishing node features.

In [None]:
# Pooling callable that GCNModel invokes as pooling_layer(x, batch) to build one vector per graph.
pooling_layer = global_mean_pool  # You can change this to global_max_pool, global_add_pool, etc.

In [None]:
# Assemble the full model from the blocks above. Graph-level features are disabled,
# so no graph_features_net is passed.
model = GCNModel(
    gcn_backbone,
    classifier,
    pooling_layer,
    use_graph_features=False,  # Set to True if you want to use graph features
)

### train the model

In [None]:
# Adam with L2 weight decay.
learning_rate = 0.001
weight_decay = 5e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

Hyperparameters for training: 

In [None]:
# Hyperparameters handed to train() below.
epochs = 120      # maximum number of epochs
patience = 10     # early-stopping patience (epochs without improvement)
tolerance = 0.01  # minimum improvement of the smoothed validation loss

# Not read by train(), which tracks patience and the best loss internally.
current_patience = patience
best_val_loss = float("inf")

In [None]:
# Returns per-epoch loss stats, the best state_dict snapshot, the best epoch and the last epoch run.
loss_data, best_model, best_epoch, stopping_epoch = train(
    model,
    train_loader,
    val_loader,
    optimizer,
    epochs=epochs,
    device=device,
    criterion=criterion,
    early_stopping_patience=patience,
    early_stopping_delta=tolerance,
)

# Loss visualization

In [None]:
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns

# 'petroff10' only ships with recent matplotlib releases; fall back to the default
# style instead of failing with an unknown-style error on older installations.
plot_style = "petroff10" if "petroff10" in plt.style.available else "default"
matplotlib.style.use(plot_style)

fig, ax = plt.subplots(figsize=(8, 5))

loss_tasks = ["total", "dim", "boundary", "manifold"]
loss_phases = ["training", "validation"]
mean_cols = [f"{task}_{phase}_loss_mean" for task in loss_tasks for phase in loss_phases]
std_cols = [f"{task}_{phase}_loss_std" for task in loss_tasks for phase in loss_phases]

# DataFrame.plot matches yerr columns to the plotted columns by name, so the std
# columns are renamed to their corresponding mean columns.
loss_data.plot(
    y=mean_cols,
    yerr=loss_data[std_cols].set_axis(mean_cols, axis=1),
    ax=ax,
    title="Training and Validation Losses",
    xlabel="Epoch",
    ylabel="Loss",
    legend=True,
)

# Evaluation

Run the model on the test set.

In [None]:
# train() returns best_model=None if the validation loss never improved (e.g. NaN
# losses); in that case keep the weights the model currently has.
if best_model is not None:
    model.load_state_dict(best_model)
model.eval()
test_df = test(model, test_loader, device)
display(test_df)

## Per-label Accuracy
Cumulative accuracy over all tasks.

In [None]:
# Fraction of correct labels over all (sample, task) pairs.
correct_cols = ["dim_correct", "boundary_correct", "manifold_correct"]
n_predictions = 3 * len(test_df)
cum_acc = test_df[correct_cols].sum().sum() / n_predictions
print(f"Per-label accuracy: {cum_acc}")

# Per-task Accuracy 
How often each task is predicted correctly.

In [None]:
# Accuracy of each task separately: mean of its 0/1 correctness column.
dim_acc, boundary_acc, manifold_acc = (
    test_df[col].mean() for col in ("dim_correct", "boundary_correct", "manifold_correct")
)
for task_label, task_acc in (
    ("Dimension", dim_acc),
    ("Boundary", boundary_acc),
    ("Manifold", manifold_acc),
):
    print(f"{task_label} Accuracy: {task_acc:.4f}")

The model is not too bad on the individual tasks, but the manifold accuracy is much lower than the other two because that task is much more difficult.

# Per-sample accuracy

How often all three labels are correct for a given sample.

In [None]:
# A sample only counts if dimension, boundary and manifold are all predicted correctly.
correct_flags = test_df[["dim_correct", "boundary_correct", "manifold_correct"]].astype(bool)
strict_accuracy = correct_flags.all(axis=1).mean()
print(f"Strict (all-correct) Accuracy: {strict_accuracy:.4f}")


But it is pretty bad on the full task. This is the metric we need to push, because all three labels are relevant. Revisit the loss function for this.

# Confusion matrix
We build one per task: dimension, boundary, manifold.

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay


def _confusion_for(true_col: str, pred_col: str):
    """Compute and plot the confusion matrix of one task in ``test_df``.

    The labels are the sorted union of true and predicted classes, i.e. exactly the
    rows/columns sklearn builds, so the display labels always match the matrix
    shape. A hard-coded label list fails as soon as the model predicts (or the
    test set contains) a different set of classes.

    Returns:
        tuple: The confusion matrix and the plotted ConfusionMatrixDisplay.
    """
    valid = test_df[[true_col, pred_col]].dropna()  # unfilled rows would break the int cast
    y_true = valid[true_col].to_numpy().astype(int)
    y_pred = valid[pred_col].to_numpy().astype(int)
    labels = np.union1d(y_true, y_pred)
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    return cm, ConfusionMatrixDisplay(cm, display_labels=labels).plot()


cm_dim, dim_cd = _confusion_for("dim_true", "dim")
cm_boundary, boundary_cd = _confusion_for("boundary_true", "boundary")
cm_manifold, manifold_cd = _confusion_for("manifold_true", "manifold")

In [None]:
# Collect the (dim, boundary, manifold) labels of every training graph to inspect class balance.
label_rows = [
    [data.y[0][col].item() for col in range(3)]
    for data in tqdm(train_loader.dataset)
]
df = pd.DataFrame(
    label_rows,
    index=range(len(train_loader.dataset)),
    columns=["dim", "boundary", "manifold"],
    dtype=np.float64,
)
df

In [None]:
# Number of training graphs per manifold class, ordered by class id.
manifold_counts = df["manifold"].value_counts()
manifold_counts.sort_index()

# Precision/Recall per class

| Term |	Meaning | 
| -------|---------- | 
| TP | True Positives: correctly predicted samples of the class | 
| FP | False Positives: predicted as the class, but actually something else | 
| FN | False Negatives: actual class samples predicted as something else | 


**Precision** 
```math
Precision = \frac{TP}{TP + FP} \; \in [0, 1]
```
High precision -> few false positives

**Recall** 
```math 
Recall = \frac{TP}{TP + FN}  \; \in [0, 1]
```
High recall -> few false negatives 

We want both to be high, i.e., as close to 1 as possible.


In [None]:
from sklearn.metrics import recall_score, precision_score

# Per-class scores (average=None). sklearn orders the entries by the sorted union of
# the labels present in y_true and y_pred, which need not be 0..n_classes-1.
precision_dim_per_class = precision_score(test_df['dim_true'], test_df["dim"], average=None)
recall_dim_per_class = recall_score(test_df['dim_true'], test_df["dim"], average=None)

precision_boundary_per_class = precision_score(test_df['boundary_true'], test_df["boundary"], average=None)
recall_boundary_per_class = recall_score(test_df['boundary_true'], test_df["boundary"], average=None)

precision_manifold_per_class = precision_score(test_df['manifold_true'], test_df["manifold"], average=None)
recall_manifold_per_class = recall_score(test_df['manifold_true'], test_df["manifold"], average=None)

# Macro average: unweighted mean over classes.
precision_dim_avg = precision_score(test_df['dim_true'], test_df["dim"], average='macro')
recall_dim_avg = recall_score(test_df['dim_true'], test_df["dim"], average='macro')

precision_boundary_avg = precision_score(test_df['boundary_true'], test_df["boundary"], average='macro')
recall_boundary_avg = recall_score(test_df['boundary_true'], test_df["boundary"], average='macro')

precision_manifold_avg = precision_score(test_df['manifold_true'], test_df["manifold"], average='macro')
recall_manifold_avg = recall_score(test_df['manifold_true'], test_df["manifold"], average='macro')

# Weighted average over classes - weight with support = number of samples per class.
precision_dim_weighted = precision_score(test_df['dim_true'], test_df["dim"], average='weighted')
recall_dim_weighted = recall_score(test_df['dim_true'], test_df["dim"], average='weighted')

precision_boundary_weighted = precision_score(test_df['boundary_true'], test_df["boundary"], average='weighted')
recall_boundary_weighted = recall_score(test_df['boundary_true'], test_df["boundary"], average='weighted')

precision_manifold_weighted = precision_score(test_df['manifold_true'], test_df["manifold"], average='weighted')
recall_manifold_weighted = recall_score(test_df['manifold_true'], test_df["manifold"], average='weighted')

_pr_report = [
    ("Dimension", "dim", precision_dim_per_class, recall_dim_per_class,
     precision_dim_avg, recall_dim_avg, precision_dim_weighted, recall_dim_weighted),
    ("Boundary", "boundary", precision_boundary_per_class, recall_boundary_per_class,
     precision_boundary_avg, recall_boundary_avg, precision_boundary_weighted, recall_boundary_weighted),
    ("Manifold", "manifold", precision_manifold_per_class, recall_manifold_per_class,
     precision_manifold_avg, recall_manifold_avg, precision_manifold_weighted, recall_manifold_weighted),
]
for task_name, col, prec_cls, rec_cls, prec_macro, rec_macro, prec_w, rec_w in _pr_report:
    print(f"Precision and Recall per class for {task_name}:")
    # Same label set and order that sklearn used for the per-class arrays.
    class_labels = np.union1d(test_df[f"{col}_true"], test_df[col])
    for label, prec, rec in zip(class_labels, prec_cls, rec_cls):
        print(f"Class {int(label)}: Precision = {prec:.4f}, Recall = {rec:.4f}")
    print(f"{task_name} - Average Precision (Macro): {prec_macro:.4f}, Average Recall (Macro): {rec_macro:.4f}")
    print(f"{task_name} - Average Precision (Weighted): {prec_w:.4f}, Average Recall (Weighted): {rec_w:.4f}")

# F1 score

The F1 score is the harmonic mean of precision and recall. It combines them into a single number that balances both, which is especially useful when you want a trade-off between the two:

```math
F1 = 2 \cdot \frac{Precision \cdot Recall}{Precision + Recall}
```

The harmonic mean penalizes extreme imbalance more than the arithmetic mean.
E.g., if precision = 1.0 and recall = 0.0, the F1 score is 0.0 — you can't ignore one metric entirely.

macro F1: simple mean across all classes (treats each class equally)

weighted F1: mean weighted by number of samples per class (favours common classes)

per-class F1: no mean at all

The F1 score is especially valuable when you're optimizing for model robustness across all classes, particularly if one class is harder

In [None]:
from sklearn.metrics import f1_score


def _print_f1_report(target, labels, per_class, macro, weighted):
    """Print per-class, macro-averaged and support-weighted F1 scores for one target.

    Args:
        target (str): Name of the prediction target, used in the printed headers.
        labels (array-like): Class labels, in the same order as `per_class`.
        per_class (array-like): Per-class F1 scores from `f1_score(..., average=None)`.
        macro (float): Macro-averaged (unweighted) F1 score.
        weighted (float): F1 score averaged with weights equal to class support.
    """
    print(f"F1 score per class for {target}:")
    # Iterate over the scores actually returned instead of a hard-coded class count.
    for label, score in zip(labels, per_class):
        print(f"  class {label}: {score}")
    print(f"base average of F1 score for {target}: ", macro)
    print(f"weighted average of F1 score for {target}: ", weighted)


# The labels are passed explicitly as the sorted union of true and predicted values. This is
# the same label set sklearn uses by default, so the scores are unchanged. Each per-class score
# can then be printed next to the label it belongs to, even when a class is absent from the
# test split.
dim_labels = np.union1d(test_df['dim_true'], test_df['dim'])
f1_dim = f1_score(test_df['dim_true'], test_df['dim'], labels=dim_labels, average=None)  # Per-class F1
f1_dim_macro = f1_score(test_df['dim_true'], test_df['dim'], average='macro')  # Unweighted average
f1_dim_weighted = f1_score(test_df['dim_true'], test_df['dim'], average='weighted')  # Weighted by class support

boundary_labels = np.union1d(test_df['boundary_true'], test_df['boundary'])
f1_boundary = f1_score(test_df['boundary_true'], test_df['boundary'], labels=boundary_labels, average=None)  # Per-class F1
f1_boundary_macro = f1_score(test_df['boundary_true'], test_df['boundary'], average='macro')  # Unweighted average
f1_boundary_weighted = f1_score(test_df['boundary_true'], test_df['boundary'], average='weighted')  # Weighted by class support

manifold_labels = np.union1d(test_df['manifold_true'], test_df['manifold'])
f1_manifold = f1_score(test_df['manifold_true'], test_df['manifold'], labels=manifold_labels, average=None)  # Per-class F1
f1_manifold_macro = f1_score(test_df['manifold_true'], test_df['manifold'], average='macro')  # Unweighted average
f1_manifold_weighted = f1_score(test_df['manifold_true'], test_df['manifold'], average='weighted')  # Weighted by class support

_print_f1_report("dimension", dim_labels, f1_dim, f1_dim_macro, f1_dim_weighted)
_print_f1_report("boundary", boundary_labels, f1_boundary, f1_boundary_macro, f1_boundary_weighted)
_print_f1_report("manifold", manifold_labels, f1_manifold, f1_manifold_macro, f1_manifold_weighted)



# Save model

In [None]:
# Hand-written record of the hyperparameters used for this run; the save cell serializes it
# to hyperparameters.json next to the model. Because the values are typed in here rather than
# read from the live objects, they must be kept in sync with the setup used for training.
# `epochs`, `patience` and `tolerance` come from an earlier cell.
hyperparams = {
    "optimizer": {
        "type": "Adam",
        "learning_rate": 0.001,
        "weight_decay_rate": 5e-4,
    },
    "training": {
        "epochs": epochs,
        "patience": patience,
        "tolerance": tolerance,
        "batchsize": 128,
    },
    "model": {
        # Three GCNConv layers, 4 -> 128 -> 256 -> 128, identical apart from their widths.
        "Backbone": [
            {
                "input": width_in,
                "output": width_out,
                "type": "GCNConv",
                "activation": "relu",
                "dropout": 0.3,
                "normalizer": "batchnorm",
                "kwargs": {"cached": False, "bias": True, "add_self_loops": True},
            }
            for width_in, width_out in ((4, 128), (128, 256), (256, 128))
        ],
        "Classifier": {
            "input_dim": 128,  # equals the output width of the last backbone layer
            "output_dim": 3,  # one output per target: manifold_id, boundary_id and dimension
            "hidden_dims": [64, 32],
            "manifold_classes": 6,
            # NOTE(review): confirm this matches the number of boundary labels in the data.
            "boundary_classes": 3,
            "dimension_classes": 3,
            "activation": "relu",
        },
        "GraphFeatures": None,
    },
    "data": {
        # Four node features, matching the first backbone layer's input width of 4.
        "node_features": ["in_degree", "out_degree", "max_path_future", "max_path_past"],
        "graph_features": None,
    },
}

In [None]:
# Write the hyperparameter record and the whole (pickled) model into the models directory.
# NOTE(review): only the directory part of `model_path` is used. The file name
# "gcn_model_simple_unoptimized.pt" is never written. The outputs are always
# <datapath>/models/hyperparameters.json and <datapath>/models/model.pth, so re-running this
# cell overwrites earlier runs. Confirm that this is intended.
model_path = os.path.join(datapath, "models", "gcn_model_simple_unoptimized.pt")
model_dir = os.path.abspath(os.path.dirname(model_path))
os.makedirs(model_dir, exist_ok=True)

hyperparams_path = os.path.join(model_dir, "hyperparameters.json")
with open(hyperparams_path, "w") as hp_file:
    json.dump(hyperparams, hp_file)

# nn.Module.to() moves the parameters in place, so `model` is left on the CPU afterwards.
# Later cells must move it back to `device` before running inference.
weights_path = os.path.join(model_dir, "model.pth")
torch.save(model.to('cpu'), weights_path)



## t-SNE feature-space visualization

In [None]:
from sklearn.manifold import TSNE
import seaborn as sns
import matplotlib.pyplot as plt

First, compute the model's embedding vectors on the test set.

In [None]:
import warnings

# Run the trained model over the whole test set. For every graph, collect the embedding
# returned by `model.get_embeddings` together with its three targets (columns 0/1/2 of
# data.y: dim, boundary, manifold).
# Batches are gathered in lists and concatenated afterwards, so neither the embedding width
# nor the number of graphs is hard-coded. A loader that yields more or fewer graphs than
# `test_size` therefore cannot overflow a fixed buffer or leave all-zero placeholder rows
# that would then be fed to t-SNE.
model = model.to(device)
model.eval()  # Set the model to evaluation mode

embedding_chunks = []
label_chunks = []
with torch.no_grad():
    for data in tqdm(test_loader):
        data = data.to(device)
        embeddings = model.get_embeddings(data.x, data.edge_index, data.batch)
        embedding_chunks.append(embeddings.cpu().numpy())
        # Rows of the embeddings and of data.y are in the same (per-graph) order within a batch.
        label_chunks.append(data.y[:, :3].cpu().numpy())

if not embedding_chunks:
    raise ValueError("test_loader yielded no batches; nothing to embed.")

embedding_vectors = np.concatenate(embedding_chunks, axis=0).astype(np.float32, copy=False)
graph_labels = np.concatenate(label_chunks, axis=0)

if len(embedding_vectors) != test_size:
    warnings.warn(
        f"test_loader yielded {len(embedding_vectors)} graphs but test_size is {test_size}; "
        "using the number of graphs actually seen."
    )

# One row per test graph. The t-SNE coordinates are filled in once the projection is computed.
results = pd.DataFrame(
    {
        "dim": graph_labels[:, 0],
        "boundary": graph_labels[:, 1],
        "manifold": graph_labels[:, 2],
        "tsne_x": np.nan,
        "tsne_y": np.nan,
    }
)


In [None]:
# 2-D t-SNE projector for the graph embeddings. A fixed random_state and PCA initialisation
# make the layout more stable across runs; n_jobs=-1 lets sklearn use all CPU cores.
tsne = TSNE(
    n_components=2,
    perplexity=45,
    early_exaggeration=4.0,
    max_iter=1000,
    init='pca',
    random_state=42,
    n_jobs=-1,
)

In [None]:
# Fit t-SNE on the embedding matrix (one row per test graph) and get one 2-D point per graph.
tsne_result = tsne.fit_transform(X=embedding_vectors)

In [None]:
# Store the 2-D t-SNE coordinates next to the labels. Row i of tsne_result corresponds to
# row i of embedding_vectors and therefore to row i of results.
results["tsne_x"] = tsne_result[:, 0]
results["tsne_y"] = tsne_result[:, 1]
# A bare expression is only displayed when it is the last line of a cell, so show the shape here.
tsne_result.shape

In [None]:
# Preview the first rows: the labels plus the t-SNE coordinates.
results.head(n=5)

In [None]:

# Scatter the 2-D t-SNE embedding, coloured by the target dimension label.
fig, ax = plt.subplots(figsize=(8, 6))
sns.scatterplot(
    data=results,
    x="tsne_x",
    y="tsne_y",
    hue="dim",
    palette='muted',
    alpha=0.7,
    ax=ax,
)
ax.set_title("t-SNE Visualization of Dimension Classes")
plt.show()

In [None]:

# Scatter the 2-D t-SNE embedding, coloured by the target boundary label.
fig, ax = plt.subplots(figsize=(8, 6))
sns.scatterplot(
    data=results,
    x="tsne_x",
    y="tsne_y",
    hue="boundary",
    palette='muted',
    alpha=0.7,
    ax=ax,
)
ax.set_title("t-SNE Visualization of Boundary Classes")
plt.show()

In [None]:

# Scatter the 2-D t-SNE embedding, coloured by the target manifold label.
fig, ax = plt.subplots(figsize=(8, 6))
sns.scatterplot(
    data=results,
    x="tsne_x",
    y="tsne_y",
    hue="manifold",
    palette='muted',
    alpha=0.7,
    ax=ax,
)
ax.set_title("t-SNE Visualization of Manifold Classes")
plt.show()