# Data loading

In [None]:
import torch 
from torch_geometric.loader import DataLoader 
import torch_geometric.transforms as T
from torch_geometric.utils import dense_to_sparse
from torch_geometric.data import Dataset, Data
from functools import lru_cache
from collections.abc import Callable

import h5py
import os
import numpy as np
import math
import json

In [None]:
def _node_column(f, key: str, idx: int, dtype: torch.dtype) -> torch.Tensor:
    """Read the per-node vector ``f[key][:, idx]`` and return it as a column."""
    return torch.tensor(f[key][:, idx], dtype=dtype).unsqueeze(1)


def _sparse_matrix(f, key: str, idx: int, dtype: torch.dtype) -> torch.Tensor:
    """Read the matrix slice ``f[key][:, :, idx]`` as a sparse tensor."""
    return torch.tensor(f[key][:, :, idx], dtype=dtype).to_sparse()


def _graph_value(f, key: str, idx: int, dtype: torch.dtype) -> torch.Tensor:
    """Read the graph-level entry ``f[key][idx]`` as a tensor."""
    return torch.tensor(f[key][idx], dtype=dtype)


def _first_entry(value: torch.Tensor) -> torch.Tensor:
    """Return the first element of ``value`` as a 0-dim tensor.

    Works for 0-dim tensors too, where plain ``value[0]`` raises.
    """
    return value.reshape(-1)[0]


def load_graph(
    f: h5py.File,
    idx: int,
    float_dtype: torch.dtype,
    int_dtype: torch.dtype,
    validate: bool = False,
) -> Data:
    """Build a PyG ``Data`` object for the ``idx``-th graph stored in ``f``.

    The last axis of every array that is read is indexed with ``idx``.

    Args:
        f: Open HDF5 file (anything supporting ``f[key][...]`` indexing).
        idx: Index of the graph to load.
        float_dtype: dtype used for matrices, node features and
            ``relation_dim``.
        int_dtype: dtype used for the integer graph-level attributes.
        validate: If true, ``data.validate()`` is called before returning.

    Returns:
        A ``Data`` object with node features ``x`` (in-degree, out-degree,
        max future path length, max past path length), ``edge_index``,
        ``edge_attr``, the target ``y = [manifold_id, boundary_id, dimension]``
        and the graph-level attributes and sparse matrices read from ``f``.
    """
    # Adjacency matrix -> edge list; the dense matrix is kept in sparse form.
    adj_matrix = torch.tensor(f["adjacency_matrix"][:, :, idx], dtype=float_dtype)
    edge_index, edge_weight = dense_to_sparse(adj_matrix)
    adj_matrix = adj_matrix.to_sparse()

    # Node features, one column per dataset.
    # Currently disabled candidates (kept here as a reminder):
    #   - "sprinkling" coordinates (f["sprinkling"][:, :, idx]), also as `pos`
    #   - "max_path_lengths_future_links" / "max_path_lengths_past_links"
    #   - "topological_order_future" / "topological_order_past"
    #     (TODO: check again what these contain)
    # More topological information (angles, etc.) is still needed here.
    node_feature_keys = (
        "in_degrees",
        "out_degrees",
        "max_path_lengths_future",
        "max_path_lengths_past",
    )
    x = torch.cat(
        [_node_column(f, key, idx, float_dtype) for key in node_feature_keys],
        dim=1,
    )

    # Graph-level features (targets for regression).
    manifold_id = _graph_value(f, "manifold_ids", idx, int_dtype)
    boundary_id = _graph_value(f, "boundary_ids", idx, int_dtype)
    relation_dim = _graph_value(f, "relation_dim", idx, float_dtype)
    atom_count = _graph_value(f, "atom_count", idx, int_dtype)
    num_sources = _graph_value(f, "num_sources", idx, int_dtype)
    num_sinks = _graph_value(f, "num_sinks", idx, int_dtype)
    dimension = _graph_value(f, "dimension", idx, int_dtype)

    # Additional matrices stored as sparse graph attributes.
    # TODO: turn past/future relations into node attributes instead.
    link_matrix = _sparse_matrix(f, "link_matrix", idx, float_dtype)
    past_relations = _sparse_matrix(f, "past_relations", idx, float_dtype)
    future_relations = _sparse_matrix(f, "future_relations", idx, float_dtype)

    # NOTE(review): a graph without edges gets edge_attr=None; unsure whether
    # downstream code prefers an empty (0, 1) tensor instead.
    edge_attr = edge_weight.unsqueeze(1) if edge_weight.numel() > 0 else None

    data = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        # The ids may be stored as scalars or as one-element arrays; take the
        # first entry in a way that works for both.
        y=torch.tensor(
            [
                _first_entry(manifold_id),
                _first_entry(boundary_id),
                _first_entry(dimension),
            ]
        ),
        # Graph-level attributes
        manifold_id=manifold_id,
        boundary_id=boundary_id,
        relation_dim=relation_dim,
        dimension=dimension,
        atom_count=atom_count,
        num_sources=num_sources,
        num_sinks=num_sinks,
        # Additional matrices as graph attributes
        adjacency_matrix=adj_matrix,
        link_matrix=link_matrix,
        past_relations=past_relations,
        future_relations=future_relations,
    )

    if validate:
        data.validate()

    return data

In [None]:
class CsDataset(Dataset):
    """On-disk PyG dataset of causal-set graphs read from HDF5 files.

    Each graph of every input file is converted with ``load_graph`` and stored
    as ``data_{i}.pt`` in ``processed_dir``, where ``i`` is a running index
    over all kept graphs of all input files. The label metadata
    (``manifold_codes``, ``manifold_names``, ``boundaries``,
    ``boundary_codes``) is stored next to them in ``metadata.json``.

    Args:
        input: Path(s) of the raw HDF5 files. A single path is accepted too.
        output: Root directory for the processed data.
        transform: Applied by the base class on every access.
        pre_transform: Applied once before a graph is written to disk.
        pre_filter: Graphs for which this returns a falsy value are skipped.
        validate_data: Forwarded to ``load_graph`` as ``validate``.
    """

    # Keys copied from the first raw file into ``metadata.json``.
    _METADATA_KEYS = (
        "manifold_codes",
        "manifold_names",
        "boundaries",
        "boundary_codes",
    )

    def __init__(
        self,
        input: list[str],
        output: str,  # = root directory for processed data
        transform: Callable[[Data], Data] | None = None,
        pre_transform: Callable[[Data], Data] | None = None,
        pre_filter: Callable[[Data], bool] | None = None,
        validate_data: bool = False,
    ):
        # Everything ``process()`` and the file-name properties rely on has to
        # exist before ``super().__init__`` runs, because the PyG base class
        # triggers processing from its constructor.
        self.input = [input] if isinstance(input, (str, os.PathLike)) else list(input)
        self.validate_data = validate_data
        self._num_samples = None
        self.processed_data = []
        self.manifold_codes = None
        self.manifold_names = None
        self.boundaries = None
        self.boundary_codes = None

        super().__init__(output, transform, pre_transform, pre_filter)

        self._load_metadata()

    @staticmethod
    def _to_jsonable(value):
        """Convert numpy / bytes values read from HDF5 into plain Python objects."""
        if isinstance(value, bytes):
            return value.decode("utf-8")
        if isinstance(value, np.ndarray):
            return CsDataset._to_jsonable(value.tolist())
        if isinstance(value, np.generic):
            return CsDataset._to_jsonable(value.item())
        if isinstance(value, (list, tuple)):
            return [CsDataset._to_jsonable(item) for item in value]
        return value

    @staticmethod
    def _num_causal_sets(f) -> int:
        """Return ``f["num_causal_sets"]`` as an int (scalar or 1-element dataset)."""
        return int(np.asarray(f["num_causal_sets"][()]).reshape(-1)[0])

    def _read_raw_metadata(self) -> dict:
        """Read the label metadata from the first raw file.

        NOTE(review): assumes the metadata keys are HDF5 datasets (not groups)
        and that all input files share the same metadata -- confirm.
        """
        if not self.input:
            raise ValueError("No input files given.")
        first = self.input[0]
        if not os.path.exists(first):
            raise FileNotFoundError(f"Input file {first} does not exist.")
        with h5py.File(first, "r") as f:
            # Materialise the values while the file is open; dataset handles
            # become unusable once it is closed.
            return {key: self._to_jsonable(f[key][()]) for key in self._METADATA_KEYS}

    def _load_metadata(self) -> None:
        """Populate metadata attributes and the sample count after processing."""
        metadata_path = os.path.join(self.processed_dir, "metadata.json")
        if os.path.exists(metadata_path):
            with open(metadata_path, "r") as f:
                metadata = json.load(f)
        else:
            metadata = self._read_raw_metadata()

        self.manifold_codes = metadata["manifold_codes"]
        self.manifold_names = metadata["manifold_names"]
        self.boundaries = metadata["boundaries"]
        self.boundary_codes = metadata["boundary_codes"]

        self.processed_data = self.processed_file_names
        self._num_samples = len(self.processed_data)

    @property
    def num_samples(self):
        """Number of processed graphs (same value as ``len(dataset)`` before slicing)."""
        return self.len()

    @property
    def raw_paths(self):
        return self.input

    @property
    def output(self):
        return self.root

    @property
    def raw_file_names(self):
        return [os.path.basename(path) for path in self.input]

    @property
    def processed_file_names(self):
        # An empty list makes the base class call ``process()``.
        if not os.path.isdir(self.processed_dir):
            return []
        return sorted(
            name
            for name in os.listdir(self.processed_dir)
            if name.startswith("data_") and name.endswith(".pt")
        )

    def process(self):
        """Convert every graph of every raw file into ``data_{i}.pt``."""
        os.makedirs(self.processed_dir, exist_ok=True)

        with open(os.path.join(self.processed_dir, "metadata.json"), "w") as f:
            json.dump(self._read_raw_metadata(), f)

        # Running index over all files; skipped graphs do not consume an
        # index, so ``get(i)`` is valid for every ``i < len(self)``.
        saved = 0
        for file in self.raw_paths:
            if not os.path.exists(file):
                raise FileNotFoundError(f"Input file {file} does not exist.")
            with h5py.File(file, "r") as f:
                # FIXME: this loop should be parallelized for large datasets
                for idx in range(self._num_causal_sets(f)):
                    data = load_graph(
                        f,
                        idx,
                        float_dtype=torch.float32,
                        int_dtype=torch.int64,
                        validate=self.validate_data,
                    )
                    if self.pre_filter is not None and not self.pre_filter(data):
                        continue
                    if self.pre_transform is not None:
                        data = self.pre_transform(data)
                    torch.save(
                        data, os.path.join(self.processed_dir, f"data_{saved}.pt")
                    )
                    saved += 1

        self._num_samples = saved

    def len(self):
        if self._num_samples is None:
            return len(self.processed_file_names)
        return self._num_samples

    def get(self, idx):
        """Load the ``idx``-th processed graph from disk.

        ``transform`` is deliberately not applied here: the base class applies
        it in ``__getitem__``, so doing it here would transform twice. Results
        are not cached either, since handing out one shared ``Data`` object to
        a transform that may modify it in place corrupts later accesses.
        """
        path = os.path.join(self.processed_dir, f"data_{idx}.pt")
        # The files hold pickled ``Data`` objects written by ``process()``;
        # only load processed directories you created yourself.
        return torch.load(path, weights_only=False)


# Model definition

In [None]:
import torch 
import os
from torch.nn import Linear 
import torch.nn.functional as F 
from torch_geometric.nn.conv import GCNConv, GATConv, SAGEConv, GraphConv, GATv2Conv
from torch_geometric.nn import global_mean_pool, global_max_pool, global_add_pool, SAGPooling, Set2Set

In [None]:
class GCNBlock(torch.nn.Module):
    """Graph convolution -> normalisation -> activation -> dropout, plus a skip connection.

    Args:
        input_dim: Number of input node features.
        output_dim: Number of output node features.
        dropout: Dropout probability (only active in training mode).
        gcn_type: Convolution class, constructed as
            ``gcn_type(input_dim, output_dim, **gcn_kwargs)``.
        batchnorm: Normalisation module instance, or a class that is
            constructed as ``batchnorm(output_dim)``.
        activation: Callable applied after normalisation.
        gcn_kwargs: Extra keyword arguments for ``gcn_type``.
    """

    def __init__(self, input_dim, output_dim, dropout=0.5, gcn_type=GCNConv, batchnorm=torch.nn.Identity, activation=F.relu, gcn_kwargs=None):
        super().__init__()
        self.dropout = dropout
        self.gcn_type = gcn_type
        self.conv = gcn_type(input_dim, output_dim, **(gcn_kwargs or {}))
        self.activation = activation
        # The default is the Identity *class*; calling a class with a tensor
        # would build a module instead of normalising, so instantiate it here.
        self.batch_norm = batchnorm(output_dim) if isinstance(batchnorm, type) else batchnorm
        # The residual can only be added directly when the feature sizes
        # match; otherwise project the input to the output size first.
        if input_dim == output_dim:
            self.skip = torch.nn.Identity()
        else:
            self.skip = torch.nn.Linear(input_dim, output_dim, bias=False)

    def forward(self, x, edge_index, edge_weight=None):
        x_res = self.skip(x)
        # Only forward edge weights when given, so convolutions without an
        # ``edge_weight`` argument can be used as well.
        if edge_weight is None:
            x = self.conv(x, edge_index)
        else:
            x = self.conv(x, edge_index, edge_weight=edge_weight)
        x = self.batch_norm(x)  # no-op if batch normalization is not used
        x = self.activation(x)
        x = F.dropout(x, p=self.dropout, training=self.training)  # training only
        return x + x_res  # skip connection


In [None]:
class _RegressionActivation(torch.nn.Module):
    """Wrap a plain callable (e.g. ``F.relu``) so it can be used inside ``torch.nn.Sequential``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x)


class RegressionBlock(torch.nn.Module):
    """Multi-layer perceptron mapping pooled graph embeddings to the outputs.

    Args:
        input_dim: Size of the input features.
        output_dim: Size of the output.
        hidden_dims: Sizes of the hidden layers; empty for a single linear layer.
        activation: Callable or ``torch.nn.Module`` applied after every hidden layer.
    """

    def __init__(self, input_dim, output_dim, hidden_dims, activation=F.relu):
        super().__init__()
        self.activation = activation
        self.hidden_dims = hidden_dims

        if len(hidden_dims) == 0:
            self.layers = Linear(input_dim, output_dim)
        else:
            # Sequential only accepts modules, so functional activations
            # have to be wrapped.
            if isinstance(activation, torch.nn.Module):
                act = activation
            else:
                act = _RegressionActivation(activation)

            layers = []
            in_dim = input_dim
            for hidden_dim in hidden_dims:
                layers.append(Linear(in_dim, hidden_dim))  # TODO: check whether the bias is needed
                layers.append(act)
                in_dim = hidden_dim
            layers.append(Linear(in_dim, output_dim))
            self.layers = torch.nn.Sequential(*layers)

    def forward(self, x):
        # No output non-linearity here; the caller decides how to turn the
        # raw outputs into categorical or regression predictions.
        return self.layers(x)

In [None]:
class _GraphFeaturesActivation(torch.nn.Module):
    """Wrap a plain callable (e.g. ``F.relu``) so it can be used inside ``torch.nn.Sequential``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x)


class GraphFeaturesBlock(torch.nn.Module):
    """Multi-layer perceptron applied to graph-level features.

    Args:
        input_dim: Size of the graph feature vector.
        output_dim: Size of the output.
        hidden_dims: Sizes of the hidden layers; empty for a single linear layer.
        dropout: Stored on the module. NOTE(review): it is currently not
            applied anywhere in ``forward`` -- confirm whether it should be.
        activation: Callable or ``torch.nn.Module`` applied after every hidden layer.
    """

    def __init__(self, input_dim, output_dim, hidden_dims, dropout=0.5, activation=F.relu):
        super().__init__()
        self.dropout = dropout
        self.activation = activation
        self.hidden_dims = hidden_dims

        if len(hidden_dims) == 0:
            self.linear = Linear(input_dim, output_dim)
        else:
            # Sequential only accepts modules, so functional activations
            # have to be wrapped.
            if isinstance(activation, torch.nn.Module):
                act = activation
            else:
                act = _GraphFeaturesActivation(activation)

            layers = []
            in_dim = input_dim
            for hidden_dim in hidden_dims:
                layers.append(Linear(in_dim, hidden_dim))  # TODO: check bias
                layers.append(act)
                in_dim = hidden_dim
            layers.append(Linear(in_dim, output_dim))
            self.linear = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.linear(x)

Basic model class that ties the other blocks together. It does the following:
- passes input through gcn network. This is a succession of GCN blocks
- applies graph feature network to graph level features and concatenates them with pooled node features **if** `use_graph_features = true`. 

- passes the result through the regression net to obtain the output (manifold_id, boundary_id, dimension):
```text
x -> gcn -> pool -> concat(_, g) -> regression -> output
g ----------------> MLP_g(g) _|
```

In [None]:
class GCNModel(torch.nn.Module):
    """Combine the GCN stack, pooling, optional graph features and the regression net.

    Data flow: ``x -> gcn_net -> pooling_layer -> [concat graph features] ->
    regression_net -> [softmax]``.

    Args:
        gcn_net: Callable taking ``(x, edge_index, edge_weight=...)``.
        regression_net: Callable mapping the pooled features to the output.
        pooling_layer: Callable taking ``(x, batch)``.
        use_graph_features: If true, ``graph_features_net(graph_features)`` is
            concatenated to the pooled node features along the last dimension.
        graph_features_net: Module instance, or a module class that is
            constructed without arguments.
        apply_softmax: If true (default), softmax is applied over the last
            output dimension. Pass ``False`` to get the raw outputs.
    """

    def __init__(self, gcn_net, regression_net, pooling_layer, use_graph_features=False, graph_features_net=torch.nn.Identity, apply_softmax=True):
        super().__init__()
        self.gcn_net = gcn_net
        self.regression_net = regression_net
        # The default is the Identity *class*; calling a class with a tensor
        # would build a module instead of transforming the features.
        if isinstance(graph_features_net, type):
            graph_features_net = graph_features_net()
        self.graph_features_net = graph_features_net
        self.use_graph_features = use_graph_features
        self.pooling_layer = pooling_layer
        self.apply_softmax = apply_softmax

    def forward(self, x, edge_index, batch, edge_weight=None, graph_features=None):
        x = self.gcn_net(x, edge_index, edge_weight=edge_weight)
        x = self.pooling_layer(x, batch)
        if self.use_graph_features:
            if graph_features is None:
                raise ValueError(
                    "graph_features must be given when use_graph_features=True"
                )
            graph_features = self.graph_features_net(graph_features)
            x = torch.cat((x, graph_features), dim=-1)  # last dim
        x = self.regression_net(x)
        if self.apply_softmax:
            # NOTE(review): this normalises across all outputs of the
            # regression net at once; confirm that is intended when the
            # outputs are separate targets.
            x = torch.softmax(x, dim=-1)
        return x

        


**TODO**
- [ ] how do GCNConv, GraphConv, SAGEConv work?
- [ ] how does GlobalAttention, Set2Set work?
- [ ] train the damn thing

# Apply the model 

## make dataset from graphs and have them processed

In [None]:
# Raw HDF5 files live in ~/data/causal_sets, one file per dimension d.
# (``os.path`` has no ``home``; ``expanduser("~")`` resolves the home directory.)
datapath = os.path.join(os.path.expanduser("~"), "data", "causal_sets")
files = [
    os.path.join(datapath, f"cset_data_min=300_max=600_N=10000_d={dim}.h5")
    for dim in (2, 3, 4)
]

In [None]:
# TODO: add one-hot encoding for manifold and boundary ids and for the dimension
dset = CsDataset(
    input=files,
    output=os.path.join(datapath, "processed"),
    # A transform has to be an instance, not the class itself. ``edge_index``
    # is kept because the training loop further down still passes it to the
    # model.
    transform=T.ToSparseTensor(remove_edge_index=False),
    pre_transform=None,  # maybe add some augmentation stuff here later
    pre_filter=None,  # filter graphs before saving, e.g. by manifold or boundary
    validate_data=True,  # validate data after loading
).shuffle()

## make test, validation and train data loaders

In [None]:
# 80 % training, 10 % test, the remainder is used for validation.
n_graphs = len(dset)
train_size = math.ceil(0.8 * n_graphs)
test_size = math.ceil(0.1 * n_graphs)
val_size = n_graphs - (train_size + test_size)

In [None]:
batch_size = 32

# The dataset was already shuffled once, so consecutive slices are random splits.
# Only the training loader reshuffles; evaluation does not depend on the order,
# and a fixed order keeps the reported numbers reproducible.
test_loader = DataLoader(dset[:test_size], batch_size=batch_size, shuffle=False)
train_loader = DataLoader(
    dset[test_size:test_size + train_size], batch_size=batch_size, shuffle=True
)
val_loader = DataLoader(
    dset[test_size + train_size:], batch_size=batch_size, shuffle=False
)

## make model

The model consists of 3 GCN blocks, a global mean pooling layer (as a first attempt) and a final regression block.

In [None]:
# Load the first graph once; every ``dset[0]`` access reads it from disk again.
first_graph = dset[0]
n_node_features = first_graph.x.shape[1]  # Number of node features
n_edge_features = (
    first_graph.edge_attr.shape[1] if first_graph.edge_attr is not None else 0
)  # Number of edge features


First, build the graph convolutional model.

In [None]:
from torch_geometric.nn import Sequential as GraphSequential

normalizer = torch.nn.BatchNorm1d
conv_layer = GCNConv  # You can change this to GATConv, SAGEConv, etc. as needed

# ``cached=True`` is not passed to the convolutions: caching the normalised
# adjacency is only valid when the same graph is used for every forward pass,
# whereas here every mini-batch contains different graphs.
conv1 = GCNBlock(
    input_dim=n_node_features,
    output_dim=128,
    dropout=0.3,
    batchnorm=normalizer(128),  # Use BatchNorm1d for batch normalization
    gcn_type=conv_layer,
    activation=F.relu,
)

conv2 = GCNBlock(
    input_dim=128,
    output_dim=256,
    dropout=0.3,
    batchnorm=normalizer(256),  # Use BatchNorm1d for batch normalization
    gcn_type=conv_layer,
    activation=F.relu,
)

conv3 = GCNBlock(
    input_dim=256,
    output_dim=128,
    dropout=0.3,
    batchnorm=normalizer(128),  # Use BatchNorm1d for batch normalization
    gcn_type=conv_layer,
    activation=F.relu,
)

# ``torch.nn.Sequential`` forwards a single input only. The PyG variant lets
# every block receive ``edge_index`` / ``edge_weight`` alongside ``x``.
gcn_chain = GraphSequential(
    "x, edge_index, edge_weight",
    [
        (conv1, "x, edge_index, edge_weight -> x"),
        (conv2, "x, edge_index, edge_weight -> x"),
        (conv3, "x, edge_index, edge_weight -> x"),
    ],
)


Then build the regression layer; this is simple.

In [None]:
# MLP head: 128 pooled features -> 64 -> 32 -> 3 outputs
# (manifold_id, boundary_id and dimension).
regression_net = RegressionBlock(
    128,  # input_dim: output dimension of the last GCN layer
    3,  # output_dim: one value per target
    [64, 32],  # hidden_dims
    activation=F.relu,
)

Finally, put things together and add the global pooling layer. We don't use the graph features for the moment, so we don't define a graph-features processing network.

In [None]:
# Graph-level readout; global_max_pool, global_add_pool, etc. work as well.
pooling_layer = global_mean_pool

# Graph features stay disabled for now (set use_graph_features=True to use them).
model = GCNModel(
    gcn_chain,
    regression_net,
    pooling_layer,
    use_graph_features=False,
)


# Training

Much of this will not work properly yet, I'm sure.

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
# ``torch.nn`` has no ``BinaryCrossEntropyLoss``; binary cross entropy is ``BCELoss``.
# NOTE(review): BCELoss expects float targets in [0, 1]. The targets are
# currently raw ids (manifold_id, boundary_id, dimension), so they have to be
# one-hot encoded first, or a different loss has to be used for the
# mixed regression / classification outputs.
criterion = torch.nn.BCELoss()

Check that this is sensible first!

In [None]:
def train(model, train_loader, validation_loader, optimizer, criterion):
    """Run one pass over ``train_loader``, validating after every batch.

    Args:
        model: Called as ``model(x, edge_index, batch, graph_features=...)``.
        train_loader: Sized iterable of batches with ``x``, ``edge_index``,
            ``batch`` and ``y`` attributes.
        validation_loader: Same, used for the validation loss.
        optimizer: Optimizer over the model parameters.
        criterion: Loss taking ``(output, target)``.

    Returns:
        Tuple ``(training_loss, valid_loss)`` of float32 arrays with one entry
        per training batch: the training loss of that batch and the mean
        validation loss measured right after the update (NaN if the
        validation loader is empty).
    """
    n_steps = len(train_loader)
    n_val_batches = len(validation_loader)
    training_loss = np.zeros(n_steps, dtype=np.float32)
    valid_loss = np.zeros(n_steps, dtype=np.float32)

    # NOTE(review): ``graph_features=data.y`` feeds the targets to the model.
    # That is harmless while the model ignores graph features, but leaks the
    # labels as soon as it uses them -- replace with real graph-level
    # features before enabling them.
    for step, data in enumerate(train_loader):
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch, graph_features=data.y)
        # Batching concatenates the per-graph targets along dim 0, so bring
        # them into the shape and dtype of the model output.
        target = data.y.reshape(out.shape).to(out.dtype)
        loss = criterion(out, target)
        loss.backward()
        optimizer.step()
        training_loss[step] = loss.item()

        model.eval()
        validation_total = 0.0
        with torch.no_grad():
            for val_data in validation_loader:
                val_out = model(
                    val_data.x,
                    val_data.edge_index,
                    val_data.batch,
                    graph_features=val_data.y,
                )
                val_target = val_data.y.reshape(val_out.shape).to(val_out.dtype)
                validation_total += criterion(val_out, val_target).item()
        if n_val_batches:
            mean_validation_loss = validation_total / n_val_batches
        else:
            mean_validation_loss = float("nan")
        # "Epoch" here counts training batches.
        print(f"Epoch: {step}, Validation Loss: {mean_validation_loss:.4f}")
        valid_loss[step] = mean_validation_loss

    return training_loss, valid_loss

def test(model, test_loader):
    """Return the fraction of target entries the model predicts exactly.

    Model outputs are rounded to the nearest integer and compared entry by
    entry with the targets.

    NOTE(review): this assumes the outputs are meant to approximate the
    integer ids in ``data.y`` directly -- revisit once the targets are
    one-hot encoded.

    Args:
        model: Called as ``model(x, edge_index, batch, graph_features=...)``.
        test_loader: Iterable of batches with ``x``, ``edge_index``, ``batch``
            and ``y`` attributes.

    Returns:
        ``correct / total`` over all target entries, or 0.0 if there are none.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            out = model(data.x, data.edge_index, data.batch, graph_features=data.y)
            # Batching concatenates the per-graph targets along dim 0, so
            # bring them into the shape and dtype of the model output.
            target = data.y.reshape(out.shape).to(out.dtype)
            predicted = torch.round(out)

            total += target.numel()
            correct += (predicted == target).sum().item()

    accuracy = correct / total if total else 0.0
    print(f"Test Accuracy: {accuracy:.4f}")
    return accuracy


# Not a unit test despite its name; keep pytest from collecting it.
test.__test__ = False

In [None]:
# Run the training pass and keep both loss curves.
train_loss, validation_loss = train(
    model=model,
    train_loader=train_loader,
    validation_loader=val_loader,
    optimizer=optimizer,
    criterion=criterion,
)

In [None]:
# Evaluate the trained model on the held-out test split.
accuracy = test(model=model, test_loader=test_loader)