# Data loading

In [None]:
import torch 
from torch_geometric.loader import DataLoader 
import torch_geometric.transforms as T
from torch_geometric.utils import dense_to_sparse
from torch_geometric.data import Dataset, Data
from functools import lru_cache
from collections.abc import Callable
import concurrent.futures 
import torch.nn.functional as F 

import h5py
import os
import numpy as np
import math
from tqdm import tqdm
import json

Possible optimizations:
- load a batch of samples from all datasets at once
- process them in parallel
- then load the next batch


In [None]:
def _read_node_column(
    f: h5py.File, key: str, idx: int, dtype: torch.dtype
) -> torch.Tensor:
    """Return ``f[key][idx, :]`` as a tensor with a trailing singleton axis (one column)."""
    return torch.tensor(f[key][idx, :], dtype=dtype).unsqueeze(1)


def _read_sparse_matrix(
    f: h5py.File, key: str, idx: int, dtype: torch.dtype
) -> torch.Tensor:
    """Return the matrix ``f[key][idx, :, :]`` converted to a sparse COO tensor."""
    return torch.tensor(f[key][idx, :, :], dtype=dtype).to_sparse()


def _load_node_features(
    f: h5py.File, idx: int, float_dtype: torch.dtype
) -> torch.Tensor:
    """Concatenate the per-node feature columns of graph ``idx`` along dim 1.

    Columns, in order: in-degree, out-degree, max future path length, max past
    path length. Earlier drafts also read e.g. ``sprinkling``,
    ``max_path_lengths_*_links`` and ``topological_order_*``; further per-node
    datasets can be appended to ``keys``.
    """
    # TODO: more topological node information (angles, etc.) would be useful here.
    keys = (
        "in_degrees",
        "out_degrees",
        "max_path_lengths_future",
        "max_path_lengths_past",
    )
    return torch.cat(
        [_read_node_column(f, key, idx, float_dtype) for key in keys], dim=1
    )


def load_graph(
    f: h5py.File,
    idx: int,
    float_dtype: torch.dtype,
    int_dtype: torch.dtype,
    validate: bool = False,
) -> Data:
    """Build a ``Data`` object for the causal set at position ``idx`` of an open HDF5 file.

    Args:
        f: Open HDF5 file with one entry per causal set in each per-graph dataset.
        idx: Index of the causal set to load.
        float_dtype: dtype of node features, matrices and ``relation_dim``.
        int_dtype: dtype of ``atom_count``, ``num_sources`` and ``num_sinks``.
        validate: If True, ``Data.validate()`` is called on the result.

    Returns:
        A ``Data`` object holding
          * ``x``: node features (in/out degree, max future/past path length),
          * ``edge_index``/``edge_attr``: edges and weights taken from the adjacency
            matrix (``edge_attr`` is None when the graph has no edges),
          * ``y``: long tensor ``[manifold_id, boundary_id, dimension]``,
          * graph-level scalars and the adjacency, link, past- and future-relation
            matrices as sparse tensors.
    """
    adj_dense = torch.tensor(f["adjacency_matrix"][idx, :, :], dtype=float_dtype)
    edge_index, edge_weight = dense_to_sparse(adj_dense)
    adj_matrix = adj_dense.to_sparse()

    x = _load_node_features(f, idx, float_dtype)

    # Graph-level attributes; manifold, boundary and dimension form the targets.
    manifold_id = int(f["manifold_ids"][idx])
    boundary_id = int(f["boundary_ids"][idx])
    relation_dim = torch.tensor(f["relation_dim"][idx], dtype=float_dtype)
    atom_count = torch.tensor(f["atom_count"][idx], dtype=int_dtype)
    num_sources = torch.tensor(f["num_sources"][idx], dtype=int_dtype)
    num_sinks = torch.tensor(f["num_sinks"][idx], dtype=int_dtype)
    dimension = int(f["dimension"][()])  # read as a scalar, not indexed by idx

    # TODO: consider turning past/future relations into node attributes.
    link_matrix = _read_sparse_matrix(f, "link_matrix", idx, float_dtype)
    past_relations = _read_sparse_matrix(f, "past_relations", idx, float_dtype)
    future_relations = _read_sparse_matrix(f, "future_relations", idx, float_dtype)

    # TODO: add edge attributes beyond the adjacency weight where possible.
    edge_attr = edge_weight.unsqueeze(1).to("cpu") if edge_weight.numel() > 0 else None

    data = Data(
        x=x,
        edge_index=edge_index.to("cpu"),
        edge_attr=edge_attr,
        y=torch.tensor([manifold_id, boundary_id, dimension], dtype=torch.long).to("cpu"),
        manifold_id=manifold_id,
        boundary_id=boundary_id,
        relation_dim=relation_dim,
        dimension=dimension,
        atom_count=atom_count,
        num_sources=num_sources.to("cpu"),
        num_sinks=num_sinks.to("cpu"),
        adjacency_matrix=adj_matrix.to("cpu"),
        link_matrix=link_matrix.to("cpu"),
        past_relations=past_relations.to("cpu"),
        future_relations=future_relations.to("cpu"),
    )

    if validate:
        data.validate()

    return data

In [None]:

class OneHotEncodeTargets:
    """Transform that replaces ``data.y`` with concatenated one-hot encodings.

    ``data.y`` is expected to hold ``[manifold_id, boundary_id, dimension]`` with
    1-based manifold and boundary ids. The new ``data.y`` is the float vector
    ``[manifold one-hot | boundary one-hot | dimension one-hot]``. The original
    targets are kept in ``data.y_original`` and the layout is described in
    ``data.target_info``.

    Args:
        manifold_classes: Number of manifold classes.
        boundary_classes: Number of boundary classes.
        dimension_classes: Number of supported dimensions.
        min_dimension: Smallest supported dimension. Dimensions
            ``min_dimension .. min_dimension + dimension_classes - 1`` map to the
            class indices ``0 .. dimension_classes - 1``.
    """

    def __init__(self, manifold_classes=6, boundary_classes=3, dimension_classes=3, min_dimension=2):
        self.manifold_classes = manifold_classes
        self.boundary_classes = boundary_classes
        self.dimension_classes = dimension_classes
        self.min_dimension = min_dimension

        # With the defaults this is {2: 0, 3: 1, 4: 2}.
        self.dim_to_idx = {min_dimension + i: i for i in range(dimension_classes)}

    def __call__(self, data: "Data") -> "Data":
        """Encode ``data.y`` in place and return ``data``.

        Raises:
            KeyError: If the dimension in ``data.y[2]`` is not supported.
        """
        manifold_id = data.y[0].long() - 1  # stored 1-based
        boundary_id = data.y[1].long() - 1  # stored 1-based
        dimension = data.y[2].long().item()

        if dimension not in self.dim_to_idx:
            raise KeyError(
                f"Unsupported dimension {dimension}; expected one of {sorted(self.dim_to_idx)}"
            )
        dim_idx = self.dim_to_idx[dimension]

        manifold_onehot = F.one_hot(manifold_id, num_classes=self.manifold_classes).float()
        boundary_onehot = F.one_hot(boundary_id, num_classes=self.boundary_classes).float()
        dimension_onehot = F.one_hot(torch.tensor(dim_idx), num_classes=self.dimension_classes).float()
        y_onehot = torch.cat([manifold_onehot, boundary_onehot, dimension_onehot], dim=0)

        data.y_original = data.y  # Keep original for reference
        data.y = y_onehot

        data.target_info = {
            'manifold_classes': self.manifold_classes,
            'boundary_classes': self.boundary_classes,
            'dimension_classes': self.dimension_classes,
            'manifold_offset': 0,
            'boundary_offset': self.manifold_classes,
            'dimension_offset': self.manifold_classes + self.boundary_classes,
            'total_classes': self.manifold_classes + self.boundary_classes + self.dimension_classes,
        }

        return data

    def __repr__(self):
        return (
            f'{self.__class__.__name__}(manifold_classes={self.manifold_classes}, '
            f'boundary_classes={self.boundary_classes}, dimension_classes={self.dimension_classes}, '
            f'min_dimension={self.min_dimension})'
        )

In [None]:
def target_shift(data, manifold_classes=6, boundary_classes=3, dimension_classes=3, min_dimension=2) -> "Data":
    """Convert ``data.y`` into 0-based class indices ``[[dimension, boundary, manifold]]``.

    ``data.y`` is expected to hold ``[manifold_id, boundary_id, dimension]`` with
    1-based manifold/boundary ids. ``dimension`` is shifted by ``min_dimension``
    (2 by default, the lowest dimension in the data). The result is a long tensor
    of shape ``(1, 3)``, so batches stack to ``(B, 3)``. The original targets are
    kept in ``data.y_original`` and ``data.target_info`` describes the layout.
    """
    manifold_id = int(data.y[0].long()) - 1  # convert to 0-based
    boundary_id = int(data.y[1].long()) - 1  # convert to 0-based
    dimension = int(data.y[2].long()) - min_dimension  # convert to 0-based

    data.y_original = data.y  # Keep original for reference
    data.y = torch.tensor([[dimension, boundary_id, manifold_id]], dtype=torch.long)

    data.target_info = {
        'manifold_classes': manifold_classes,
        'boundary_classes': boundary_classes,
        'dimension_classes': dimension_classes,
        'dimension_offset': 0,
        'boundary_offset': dimension_classes,
        'manifold_offset': dimension_classes + boundary_classes,
        'total_classes': manifold_classes + boundary_classes + dimension_classes,
    }

    return data

In [None]:
class CsDataset(Dataset):
    """On-disk PyG dataset of causal sets converted from one or more HDF5 files.

    If no ``data_*.pt`` files exist in ``processed_dir`` yet, every graph of every
    input file is converted with ``loader``, optionally filtered by ``pre_filter``
    and transformed by ``pre_transform``, and saved as ``data_<i>.pt``. A
    ``metadata.json`` with the manifold and boundary codes/names is written next
    to them. Later instantiations that find processed files only read that
    metadata and reuse the saved graphs.

    Args:
        input: Paths of the HDF5 input files (needed only for processing).
        output: Root directory of the dataset.
        transform: Passed to the PyG base class.
        pre_transform: Applied once to every graph before it is saved.
        pre_filter: Predicate; graphs for which it returns a falsy value are skipped.
        validate_data: Forwarded to ``loader`` as ``validate``.
        loader: Called as ``loader(f, idx, float_dtype=..., int_dtype=..., validate=...)``
            to build the ``Data`` object for graph ``idx`` of the open file ``f``.
    """

    def __init__(
        self,
        input: list[str],
        output: str,  # = root directory for processed data
        transform: Callable[[Data], Data] | None = None,
        pre_transform: Callable[[Data], Data] | None = None,
        pre_filter: Callable[[Data], bool] | None = None,
        validate_data: bool = False,
        loader: Callable[..., Data] = load_graph,
    ):
        self.input = input
        self._num_samples = None
        self.validate_data = validate_data
        self.loader = loader
        # ``processed_dir`` is derived from ``root``, so ``root`` must be set before
        # ``processed_file_names`` is queried below.
        self.root = output

        n_processed = len(self.processed_file_names)
        if n_processed > 0:
            self._load_metadata()
            self._num_samples = n_processed
        else:
            self._scan_input_files()

        super().__init__(output, transform, pre_transform, pre_filter)

    def _load_metadata(self) -> None:
        """Read the manifold/boundary tables written by a previous ``process`` run."""
        metadata_path = os.path.join(self.processed_dir, "metadata.json")
        if not os.path.exists(metadata_path):
            raise FileNotFoundError(
                "Metadata file not found in the output directory. "
                "Please ensure the dataset has been processed."
            )
        with open(metadata_path, "r") as f:
            metadata = json.load(f)
        self.manifold_codes = metadata["manifold_codes"]
        self.manifold_names = metadata["manifolds"]
        self.boundaries = metadata["boundaries"]
        self.boundary_codes = metadata["boundary_codes"]

    def _scan_input_files(self) -> None:
        """Read the code/name tables from the first input file and count all raw samples."""
        if not self.input:
            raise ValueError("Input files must be provided for processing.")
        for file in self.input:
            if not os.path.exists(file):
                raise FileNotFoundError(f"Input file {file} does not exist.")

        with h5py.File(self.input[0], "r") as f:
            self.manifold_codes = f["manifold_codes"][()]
            self.manifold_names = f["manifolds"][()]
            self.boundaries = f["boundaries"][()]
            self.boundary_codes = f["boundary_codes"][()]

        self._num_samples = 0
        for file in self.input:
            with h5py.File(file, "r") as f:
                self._num_samples += int(f["num_causal_sets"][()])
                print(f"Processing file: {file}, current number of samples: {self._num_samples}")

    @staticmethod
    def _to_str(value) -> str:
        """Decode byte strings as read from HDF5; ``str(b"x")`` would give ``"b'x'"``."""
        return value.decode("utf-8") if isinstance(value, bytes) else str(value)

    @property
    def raw_paths(self):
        return self.input

    @property
    def output(self):
        return self.root

    @property
    def raw_file_names(self):
        return [os.path.basename(f) for f in self.input]

    @property
    def processed_file_names(self):
        """Names of the saved ``data_<i>.pt`` files (empty if ``processed_dir`` is missing)."""
        if not os.path.isdir(self.processed_dir):
            return []
        return [
            f
            for f in os.listdir(self.processed_dir)
            if f.startswith("data_") and f.endswith(".pt")
        ]

    def process(self):
        """Convert all input graphs and save them as ``data_<i>.pt`` plus ``metadata.json``."""
        print("processed dir: ", self.processed_dir, len(self.processed_file_names))
        if self.processed_file_names:
            return  # already processed

        os.makedirs(self.processed_dir, exist_ok=True)
        # Convert NumPy values to plain Python objects for JSON serialization.
        metadata = {
            "manifold_codes": [v.item() for v in self.manifold_codes],
            "manifolds": [self._to_str(m) for m in self.manifold_names],
            "boundaries": [self._to_str(b) for b in self.boundaries],
            "boundary_codes": [v.item() for v in self.boundary_codes],
        }
        with open(os.path.join(self.processed_dir, "metadata.json"), "w") as f:
            json.dump(metadata, f)

        file_index = 0
        for file in self.raw_paths:
            if not os.path.exists(file):
                raise FileNotFoundError(f"Input file {file} does not exist.")
            with h5py.File(file, "r") as f:
                # FIXME: this loop should be parallelized for large datasets
                print(f"Processing file: {file}")
                print(f"Number of causal sets: {f['num_causal_sets'][()]} of total {self._num_samples} with current index {file_index}")
                for idx in tqdm(range(f["num_causal_sets"][()])):
                    data = self.loader(
                        f,
                        idx,
                        float_dtype=torch.float32,
                        int_dtype=torch.int64,
                        validate=self.validate_data,
                    )
                    if self.pre_filter is not None and not self.pre_filter(data):
                        continue
                    if self.pre_transform is not None:
                        data = self.pre_transform(data)
                    torch.save(data.to('cpu'), os.path.join(self.processed_dir, f"data_{file_index}.pt"))
                    file_index += 1

        # pre_filter may drop graphs, so the dataset length is the number actually written.
        self._num_samples = file_index

    def len(self):
        return self._num_samples

    def get(self, idx):
        """Load processed graph ``idx`` from disk; ``transform`` is not applied here."""
        # The PyG base class' ``__getitem__`` applies ``self.transform`` to whatever
        # ``get`` returns; applying it here too would transform every sample twice.
        # weights_only=False is needed to unpickle ``Data`` objects. Pickle can run
        # arbitrary code, so only load files produced by ``process``.
        return torch.load(os.path.join(self.processed_dir, f"data_{idx}.pt"), weights_only=False)


# Model definition

In [None]:
import torch 
import torch_geometric
import os
from torch.nn import Linear 
from torch_geometric.nn.conv import GCNConv, GATConv, SAGEConv, GraphConv, GATv2Conv
from torch_geometric.nn import global_mean_pool, global_max_pool, global_add_pool, SAGPooling, Set2Set
import torchviz

In [None]:
class GCNBlock(torch.nn.Module):
    """Residual graph-convolution block.

    ``forward`` computes ``dropout(activation(norm(conv(x))) + projection(x))``.
    ``projection`` is a bias-free ``Linear`` when ``input_dim != output_dim`` and
    the identity otherwise.

    Args:
        input_dim: Number of input node features.
        output_dim: Number of output node features.
        dropout: Dropout probability, applied only in training mode.
        gcn_type: Convolution class, built as ``gcn_type(input_dim, output_dim, **gcn_kwargs)``.
        batchnorm: Normalization module instance (e.g. ``BatchNorm1d(output_dim)``) or a
            module class, which is then instantiated as ``batchnorm(output_dim)``. The
            default ``torch.nn.Identity`` means no normalization.
        activation: Function or module applied after normalization.
        gcn_kwargs: Extra keyword arguments for the convolution constructor.
    """

    def __init__(self, input_dim, output_dim, dropout=0.5, gcn_type=GCNConv, batchnorm=torch.nn.Identity, activation=F.relu, gcn_kwargs=None):
        super().__init__()
        self.dropout = dropout
        self.gcn_type = gcn_type
        self.conv = gcn_type(input_dim, output_dim, **(gcn_kwargs or {}))
        self.activation = activation
        # A class (like the default Identity) must be instantiated. Calling the class
        # itself in ``forward`` would build a new module instead of transforming ``x``.
        self.batch_norm = batchnorm(output_dim) if isinstance(batchnorm, type) else batchnorm

        if input_dim != output_dim:
            # Linear projection so the skip connection matches the output width.
            self.projection = Linear(input_dim, output_dim, bias=False)
        else:
            self.projection = torch.nn.Identity()

    def forward(self, x, edge_index, edge_weight=None, kwargs=None):
        """Apply the block; ``kwargs`` are forwarded to the convolution call."""
        x_res = x
        x = self.conv(x, edge_index, edge_weight=edge_weight, **(kwargs or {}))
        x = self.batch_norm(x)
        x = self.activation(x)
        x = x + self.projection(x_res)  # skip connection
        x = F.dropout(x, p=self.dropout, training=self.training)
        return x


In [None]:
class GCNBackbone(torch.nn.Module):
    """Apply a sequence of GCN blocks one after another on the same graph.

    Args:
        gcn_net: Blocks to apply, in order. They are wrapped in a ``ModuleList`` so
            their parameters are registered with the backbone.
    """

    def __init__(self, gcn_net: list[GCNBlock]):
        super().__init__()
        self.gcn_net = torch.nn.ModuleList(gcn_net)

    def forward(self, x, edge_index, edge_weight=None, gcn_kwargs=None):
        """Feed ``x`` through every block; ``gcn_kwargs`` is handed to each block as ``kwargs``."""
        hidden = x
        for block in self.gcn_net:
            hidden = block(hidden, edge_index, edge_weight=edge_weight, kwargs=gcn_kwargs)
        return hidden

In [None]:
class ClassifierBlock(torch.nn.Module):
    """MLP trunk followed by three linear heads (dimension, boundary, manifold).

    The trunk is ``Linear -> activation`` for every entry of ``hidden_dims``, and
    the heads read its last hidden layer. When ``hidden_dims`` is empty, the trunk
    is a single ``Linear(input_dim, output_dim)`` and the heads read its
    ``output_dim`` outputs. ``output_dim`` is unused otherwise.

    Args:
        input_dim: Width of the input features.
        output_dim: Trunk output width when ``hidden_dims`` is empty.
        hidden_dims: Widths of the hidden layers of the trunk.
        manifold_classes, boundary_classes, dimension_classes: Output sizes of the heads.
        activation: Activation after each hidden layer, either a module or a plain
            function such as ``F.relu``.
        linear_kwargs: Optional per-layer keyword arguments for the trunk ``Linear``
            layers. Missing or falsy entries mean no extra arguments.
        dim_kwargs, boundary_kwargs, manifold_kwargs: Extra keyword arguments for the heads.
    """

    class _Activation(torch.nn.Module):
        """Wrap a plain function (e.g. ``F.relu``) so it can be placed in ``nn.Sequential``."""

        def __init__(self, fn):
            super().__init__()
            self.fn = fn

        def forward(self, x):
            return self.fn(x)

    def __init__(
        self,
        input_dim,
        output_dim,
        hidden_dims,
        manifold_classes=6,
        boundary_classes=3,
        dimension_classes=3,
        activation=F.relu,
        linear_kwargs=None,
        dim_kwargs=None,
        boundary_kwargs=None,
        manifold_kwargs=None,
    ):
        super().__init__()
        self.activation = activation
        self.hidden_dims = hidden_dims
        self.total_classes = manifold_classes + boundary_classes + dimension_classes
        self.manifold_classes = manifold_classes
        self.boundary_classes = boundary_classes
        self.dimension_classes = dimension_classes

        # nn.Sequential only accepts modules, so wrap plain activation functions.
        act = activation if isinstance(activation, torch.nn.Module) else self._Activation(activation)

        def layer_kwargs(i):
            """Keyword arguments for trunk layer ``i`` (empty if not given)."""
            if linear_kwargs and i < len(linear_kwargs) and linear_kwargs[i]:
                return linear_kwargs[i]
            return {}

        if len(hidden_dims) == 0:
            self.backbone = Linear(input_dim, output_dim, **layer_kwargs(0))
            head_dim = output_dim
        else:
            layers = []
            in_dim = input_dim
            for i, hidden_dim in enumerate(hidden_dims):
                layers.append(Linear(in_dim, hidden_dim, **layer_kwargs(i)))
                layers.append(act)
                in_dim = hidden_dim
            self.backbone = torch.nn.Sequential(*layers)
            head_dim = in_dim

        self.dim_layer = torch.nn.Linear(head_dim, self.dimension_classes, **(dim_kwargs or {}))
        self.boundary_layer = torch.nn.Linear(head_dim, self.boundary_classes, **(boundary_kwargs or {}))
        self.manifold_layer = torch.nn.Linear(head_dim, self.manifold_classes, **(manifold_kwargs or {}))

    def forward(self, x, backbone_kwargs=None, dim_layer_kwargs=None, boundary_layer_kwargs=None, manifold_layer_kwargs=None):
        """Return ``(dim_logit, boundary_logit, manifold_logit)`` for the input features ``x``."""
        x = self.backbone(x, **(backbone_kwargs if backbone_kwargs is not None else {}))

        dim_logit = self.dim_layer(x, **(dim_layer_kwargs if dim_layer_kwargs is not None else {}))
        boundary_logit = self.boundary_layer(x, **(boundary_layer_kwargs if boundary_layer_kwargs is not None else {}))
        manifold_logit = self.manifold_layer(x, **(manifold_layer_kwargs if manifold_layer_kwargs is not None else {}))

        return dim_logit, boundary_logit, manifold_logit

In [None]:
class GraphFeaturesBlock(torch.nn.Module):
    """MLP that maps graph-level features to ``output_dim`` outputs.

    The network is ``Linear -> activation`` for every entry of ``hidden_dims``,
    followed by a final ``Linear`` to ``output_dim``. When ``hidden_dims`` is
    empty, it is only that final layer.

    Args:
        input_dim: Width of the graph-level feature vector.
        output_dim: Output width.
        hidden_dims: Widths of the hidden layers.
        dropout: Stored on the module. NOTE(review): it is not applied anywhere in this block.
        activation: Activation after each hidden layer, either a module or a plain function.
        linear_kwargs: Optional per-hidden-layer keyword arguments for ``Linear``.
        final_linear_kwargs: Keyword arguments for the final ``Linear``.
    """

    class _Activation(torch.nn.Module):
        """Wrap a plain function (e.g. ``F.relu``) so it can be placed in ``nn.Sequential``."""

        def __init__(self, fn):
            super().__init__()
            self.fn = fn

        def forward(self, x):
            return self.fn(x)

    def __init__(self, input_dim, output_dim, hidden_dims, dropout=0.5, activation=F.relu, linear_kwargs=None, final_linear_kwargs=None):
        super().__init__()
        self.dropout = dropout
        self.activation = activation
        self.hidden_dims = hidden_dims

        final_kwargs = final_linear_kwargs or {}
        if len(hidden_dims) == 0:
            self.linear = Linear(input_dim, output_dim, **final_kwargs)
        else:
            # nn.Sequential only accepts modules, so wrap plain activation functions.
            act = activation if isinstance(activation, torch.nn.Module) else self._Activation(activation)
            layers = []
            in_dim = input_dim
            for i, hidden_dim in enumerate(hidden_dims):
                kwargs = linear_kwargs[i] if linear_kwargs and i < len(linear_kwargs) and linear_kwargs[i] else {}
                layers.append(Linear(in_dim, hidden_dim, **kwargs))
                layers.append(act)
                in_dim = hidden_dim
            layers.append(Linear(in_dim, output_dim, **final_kwargs))
            self.linear = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Map the graph-level features ``x`` to ``output_dim`` outputs."""
        return self.linear(x)

`GCNModel` is a thin wrapper that wires the components above together. It:
- passes the node features through the GCN backbone (a succession of GCN blocks),
- pools the node embeddings into one vector per graph,
- **if** `use_graph_features=True`, passes the graph-level features through the graph-feature network and concatenates the result with the pooled node features,
- passes the result through the classifier to obtain logits for (dimension, boundary_id, manifold_id):
```text
x -> gcn -> pool -> concat(_, g) -> classifier -> output
g ----------------> MLP_g(g) _|
```

In [None]:
class GCNModel(torch.nn.Module):
    """Graph classifier: GCN backbone -> pooling -> (optional graph features) -> classifier.

    Args:
        gcn_net: Node-level network called as ``gcn_net(x, edge_index, edge_weight=...,
            gcn_kwargs=...)`` (e.g. a ``GCNBackbone``).
        classifier: Module mapping pooled features to
            ``(dim_logits, boundary_logits, manifold_logits)``.
        pooling_layer: Callable ``pooling_layer(x, batch)`` returning one row per graph.
        use_graph_features: If True, ``graph_features`` passed to ``forward`` are run
            through ``graph_features_net`` and concatenated to the pooled features.
        graph_features_net: Module instance or module class (instantiated without
            arguments) applied to the graph-level features.
    """

    def __init__(
        self,
        gcn_net,
        classifier,
        pooling_layer,
        use_graph_features=False,
        graph_features_net=torch.nn.Identity,
    ):
        super().__init__()
        self.gcn_net = gcn_net
        self.classifier = classifier
        # The default is a class. Calling the class in ``forward`` would build a new
        # module instead of transforming the features, so instantiate it here.
        self.graph_features_net = graph_features_net() if isinstance(graph_features_net, type) else graph_features_net
        self.use_graph_features = use_graph_features
        self.pooling_layer = pooling_layer

    def forward(
        self,
        x,
        edge_index,
        batch,
        edge_weight=None,
        graph_features=None,
        gcn_kwargs=None,
    ):
        """Return ``(dim_logits, boundary_logits, manifold_logits)`` for the graphs in the batch.

        ``batch`` maps nodes to graphs; ``None`` treats all nodes as one graph.
        ``gcn_kwargs`` is handed to ``gcn_net`` as its ``gcn_kwargs`` argument.

        Raises:
            ValueError: If ``use_graph_features`` is set but ``graph_features`` is None.
        """
        backbone_extra = {"gcn_kwargs": gcn_kwargs} if gcn_kwargs else {}
        x = self.gcn_net(x, edge_index, edge_weight=edge_weight, **backbone_extra)

        if batch is None:
            batch = torch.zeros(x.shape[0], dtype=torch.long, device=x.device)

        x = self.pooling_layer(x, batch)
        if self.use_graph_features:
            if graph_features is None:
                raise ValueError("graph_features must be provided when use_graph_features=True")
            x = torch.cat((x, self.graph_features_net(graph_features)), dim=-1)

        dim_logits, boundary_logits, manifold_logits = self.classifier(x)
        return dim_logits, boundary_logits, manifold_logits


# Training

## Loss function

In [None]:
def criterion(
    x_pred,
    y,
    dim_kwargs=None,
    boundary_kwargs=None,
    manifold_kwargs=None,
    dim_weight=1.0,
    boundary_weight=1.0,
    manifold_weight=1.0,
):
    """Weighted sum of the cross-entropy losses of the three classification heads.

    Args:
        x_pred: ``(dim_logits, boundary_logits, manifold_logits)`` from the model.
        y: Targets with one row per graph and columns ``[dimension, boundary, manifold]``.
        dim_kwargs, boundary_kwargs, manifold_kwargs: Extra keyword arguments for
            the corresponding ``CrossEntropyLoss``.
        dim_weight, boundary_weight, manifold_weight: Weight of each loss in the total.

    Returns:
        ``(total_loss, loss_dim, loss_boundary, loss_manifold)``.
    """
    def head_loss(logits, targets, loss_kwargs):
        return torch.nn.CrossEntropyLoss(**(loss_kwargs or {}))(logits, targets)

    loss_dim = head_loss(x_pred[0], y[:, 0], dim_kwargs)
    loss_boundary = head_loss(x_pred[1], y[:, 1], boundary_kwargs)
    loss_manifold = head_loss(x_pred[2], y[:, 2], manifold_kwargs)

    total_loss = (
        loss_dim * dim_weight
        + loss_boundary * boundary_weight
        + loss_manifold * manifold_weight
    )
    return total_loss, loss_dim, loss_boundary, loss_manifold


Check that this is sensible first!

In [None]:
def train(model, train_loader, validation_loader, optimizer, loss_func, device):
    """Run one training pass over ``train_loader`` followed by one validation pass.

    Args:
        model: Called as ``model(data.x, data.edge_index, data.batch)``.
        train_loader: Iterable of training batches with a ``len``.
        validation_loader: Iterable of validation batches with a ``len``.
        optimizer: Optimizer over the model parameters.
        loss_func: Called as ``loss_func(out, data.y)``. It must return a 4-tuple
            whose first element is the total loss to optimize.
        device: Device every batch is moved to.

    Returns:
        ``(training_loss, valid_loss)``: float32 arrays with the total loss of every
        training batch and of every validation batch, respectively.
    """
    training_loss = np.zeros(len(train_loader), dtype=np.float32)
    valid_loss = np.zeros(len(validation_loader), dtype=np.float32)

    model.train()
    for i, data in enumerate(train_loader):
        data = data.to(device)  # Move data to the same device as the model
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        total_loss, _, _, _ = loss_func(out, data.y)
        total_loss.backward()
        optimizer.step()
        training_loss[i] = total_loss.item()

    # Validate once after the training pass, not after every training batch.
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(validation_loader):
            data = data.to(device)
            out = model(data.x, data.edge_index, data.batch)
            total_loss, _, _, _ = loss_func(out, data.y)
            valid_loss[i] = total_loss.item()

    if training_loss.size and valid_loss.size:
        print(f"Training Loss: {training_loss.mean():.4f}, Validation Loss: {valid_loss.mean():.4f}")

    return training_loss, valid_loss

def test(model, test_loader, device=None):
    """Evaluate ``model`` on ``test_loader`` and return the mean accuracy of the three heads.

    Each batch must provide ``x``, ``edge_index``, ``batch`` and integer targets ``y``
    with columns ``[dimension, boundary, manifold]``. The model must return the
    logits of the three heads in that order. Per-head accuracies are printed too.

    Args:
        model: Model returning ``(dim_logits, boundary_logits, manifold_logits)``.
        test_loader: Iterable of batches.
        device: Device batches are moved to. Defaults to the device of the model's
            first parameter (CPU for a parameter-free model).

    Returns:
        The average of the dimension, boundary and manifold accuracies.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct = [0, 0, 0]  # dimension, boundary, manifold
    total = 0
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            logits = model(data.x, data.edge_index, data.batch)
            # Class indices coincide with the encoded targets, so argmax is the prediction.
            for task in range(3):
                predicted = torch.argmax(logits[task], dim=1)
                correct[task] += (predicted == data.y[:, task]).sum().item()
            total += data.y.size(0)

    if total == 0:
        raise ValueError("test_loader yielded no samples")

    dim_acc, boundary_acc, manifold_acc = (c / total for c in correct)
    accuracy = (dim_acc + boundary_acc + manifold_acc) / 3
    print(
        f"Test Accuracy: {accuracy:.4f} "
        f"(Dim: {dim_acc:.4f}, Boundary: {boundary_acc:.4f}, Manifold: {manifold_acc:.4f})"
    )
    return accuracy


# pytest collects any module-level function named ``test``; this is not a test case.
test.__test__ = False

## Actual training loop

### get cuda device

In [None]:
# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

### define the data 

In [None]:
# Alternative location: datapath = os.path.join(os.path.expanduser("~"), "data", "causal_sets")
datapath = os.path.join("/mnt", "dataLinux", "machinelearning_data", "QuantumGrav", "causal_sets")

# One HDF5 file per spacetime dimension (d = 2, 3, 4).
files = [
    os.path.join(datapath, f"cset_data_min=300_max=650_N=25000_d={d}.h5")
    for d in (2, 3, 4)
]

In [None]:
# Build (or reuse) the processed dataset and shuffle it once before splitting.
dset = CsDataset(
    input=files,
    output=datapath,
    pre_transform=target_shift,  # maybe add some augmentation stuff here later
    pre_filter=None,  # e.g. filter on manifold or boundary before saving
    validate_data=True,  # call Data.validate() on every loaded graph
    loader=load_graph,
).shuffle()

batch_size = 64

# Small fixed-size subsets for quick experiments. For a proportional split use e.g.
# train_size = int(math.ceil(0.8 * len(dset))); test_size = int(math.ceil(0.1 * len(dset)))
# val_size = len(dset) - train_size - test_size
train_size = 10 * batch_size
test_size = 5 * batch_size
val_size = 5 * batch_size

train_loader = DataLoader(
    dset[0:train_size],
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

# Evaluation loaders are not shuffled. Shuffling them only makes the per-batch
# loss statistics vary between runs; a fixed order keeps them reproducible.
test_loader = DataLoader(
    dset[train_size : train_size + test_size],
    batch_size=batch_size,
    shuffle=False,
)

val_loader = DataLoader(
    dset[train_size + test_size : train_size + test_size + val_size],
    batch_size=batch_size,
    shuffle=False,
)

### define a new model again

In [None]:
sample = dset[0]  # load one graph once to infer the feature dimensions
n_node_features = sample.x.shape[1]  # Number of node features
n_edge_features = sample.edge_attr.shape[1] if sample.edge_attr is not None else 0  # Number of edge features

normalizer = torch.nn.BatchNorm1d
conv_layer = GCNConv  # You can change this to GATConv, SAGEConv, etc. as needed

# Do NOT pass cached=True to GCNConv here. It caches the normalized adjacency of
# the first call and reuses it for every later call, which is only valid for
# transductive learning on one fixed graph. Every mini-batch here holds different graphs.
conv1 = GCNBlock(
    input_dim=n_node_features,
    output_dim=128,
    dropout=0.3,
    batchnorm=normalizer(128),
    gcn_type=conv_layer,
    activation=torch.nn.ReLU(),
)

conv2 = GCNBlock(
    input_dim=128,
    output_dim=256,
    dropout=0.3,
    batchnorm=normalizer(256),
    gcn_type=conv_layer,
    activation=torch.nn.ReLU(),
)

conv3 = GCNBlock(
    input_dim=256,
    output_dim=128,
    dropout=0.3,
    batchnorm=normalizer(128),
    gcn_type=conv_layer,
    activation=torch.nn.ReLU(),
)

gcn_backbone = GCNBackbone([conv1, conv2, conv3])


In [None]:
# Three-headed classifier on the 128-wide pooled graph embedding produced by conv3.
# ClassifierBlock only uses ``output_dim`` when ``hidden_dims`` is empty.
classifier = ClassifierBlock(
    input_dim=128,
    output_dim=3,
    hidden_dims=[64, 32],
    dimension_classes=3,
    boundary_classes=3,
    manifold_classes=6,
    activation=torch.nn.ReLU(),
)


In [None]:
pooling_layer = global_mean_pool  # alternatives: global_max_pool, global_add_pool, ...

# Backbone -> pooling -> classifier. Graph-level features are not used yet.
model = GCNModel(gcn_backbone, classifier, pooling_layer, use_graph_features=False)


### Visualize the model

In [None]:
# Take one batch to trace the model with (model and batch must be on the same device):
# first = next(iter(train_loader))
# x = torch.randn((first.x.shape[0], n_node_features), dtype=torch.float32)
# out = model(x, first.edge_index, first.batch)

# dot = torchviz.make_dot(
#     out,
#     params=dict(model.named_parameters()),
#     show_attrs=False,  # Hide detailed attributes
#     show_saved=False,  # Hide saved tensors
# )

# dot.render("model_visualization", format="pdf", cleanup=True)

### train the model

In [None]:
# Re-select the device for the training section (GPU if available).
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Adam with L2 weight decay over all model parameters.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3, weight_decay=5e-4)

In [None]:
epochs = 10
# Loss history with one row per epoch: column 0 is the mean and column 1 the
# standard deviation of that epoch's per-batch losses.
total_training_loss = np.zeros((epochs, 2), dtype=np.float32)
dim_training_loss = np.zeros((epochs, 2), dtype=np.float32)
boundary_training_loss = np.zeros((epochs, 2), dtype=np.float32)
manifold_training_loss = np.zeros((epochs, 2), dtype=np.float32)

total_validation_loss = np.zeros((epochs, 2), dtype=np.float32)
dim_validation_loss = np.zeros((epochs, 2), dtype=np.float32)
boundary_validation_loss = np.zeros((epochs, 2), dtype=np.float32)
manifold_validation_loss = np.zeros((epochs, 2), dtype=np.float32)

lossfunc = criterion  # Use the custom criterion defined above

model = model.to(device)


def _run_epoch(model, loader, optimizer, training, desc):
    """Run one pass over ``loader`` and return the per-batch losses.

    Returns an array of shape ``(len(loader), 4)`` whose columns are the total,
    dimension, boundary and manifold losses. Gradients are tracked and the
    optimizer is stepped only when ``training`` is True; otherwise the model runs
    in eval mode without gradients.
    """
    losses = np.zeros((len(loader), 4), dtype=np.float64)
    model.train(training)
    with torch.set_grad_enabled(training):
        for i, data in enumerate(tqdm(loader, desc=desc)):
            data = data.to(device)
            if training:
                optimizer.zero_grad()
            out = model(data.x, data.edge_index, data.batch)
            batch_losses = lossfunc(out, data.y)
            if training:
                batch_losses[0].backward()
                optimizer.step()
            losses[i] = [loss.item() for loss in batch_losses]
    return losses


def _record(histories, epoch, losses):
    """Store the mean/std of each loss column of ``losses`` in row ``epoch`` of ``histories``."""
    for column, history in enumerate(histories):
        history[epoch, 0] = losses[:, column].mean()
        history[epoch, 1] = losses[:, column].std()


def _describe(losses):
    """Format mean ± std of each loss column for printing."""
    mean, std = losses.mean(axis=0), losses.std(axis=0)
    return (
        f"{mean[0]:.4f} ± {std[0]:.4f}, Dim : {mean[1]:.4f} ± {std[1]:.4f}, "
        f"Boundary : {mean[2]:.4f} ± {std[2]:.4f}, Manifold : {mean[3]:.4f} ± {std[3]:.4f}"
    )


training_histories = (total_training_loss, dim_training_loss, boundary_training_loss, manifold_training_loss)
validation_histories = (total_validation_loss, dim_validation_loss, boundary_validation_loss, manifold_validation_loss)

for epoch in range(epochs):
    train_losses = _run_epoch(model, train_loader, optimizer, True, f"Epoch {epoch} Training")
    val_losses = _run_epoch(model, val_loader, optimizer, False, f"Epoch {epoch} Validation")

    # One history row per epoch.
    _record(training_histories, epoch, train_losses)
    _record(validation_histories, epoch, val_losses)

    print(f"Epoch: {epoch}")
    print(f"  training loss: {_describe(train_losses)}")
    print(f"  validation loss: {_describe(val_losses)}")

In [None]:
# Evaluate on the held-out test split; batches are moved to the training device.
accuracy = test(model, test_loader, device)

# TODO
- [ ] GPU utilization is low
- [ ] implement early stopping
- [ ] performance improvements
- [ ] experiments with different models