In [None]:
# jupyter magic commands to avoid having to restart upon change of lib code
%load_ext autoreload
%autoreload 2

# Training a custom model with QG 

You are not limited to the prebuilt models that ship with QuantumGrav. You can use its facilities in full or in part to train your own models, or replace parts of them with your own implementation. This notebook gives an example of how to build and train your own model with QuantumGrav.

## Building the model 

We start by building a simple model by hand. It runs a torch_geometric graph convolutional layer twice: once on the original edges and once on the flipped edges. This corresponds to 'source-to-target' and 'target-to-source' message passing, which effectively acts as 'looking into the future neighborhood' and 'looking into the past neighborhood' if we think about causal sets (compare [torch_geometric's `DirGNNConv`](https://pytorch-geometric.readthedocs.io/en/2.5.2/generated/torch_geometric.nn.conv.DirGNNConv.html)). Two trainable scalar weights let the model learn how to weight the two results in the output.
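To make the effect of flipping the edges concrete, here is a minimal pure-Python sketch that does not use torch_geometric. The toy graph and the helper names `flip` and `received_messages` are illustrative assumptions, not part of QuantumGrav. The sketch only shows which neighbors each node receives messages from before and after swapping the source and target rows. Which of the two corresponds to 'past' or 'future' depends on how the edges of your causal sets are oriented.

```python
# Toy directed graph as a 2 x E edge index: 0 -> 1, 0 -> 2, 1 -> 3, 2 -> 3
edge_index = [
    [0, 0, 1, 2],  # source nodes
    [1, 2, 3, 3],  # target nodes
]


def flip(edge_index):
    """Swap the source and target rows, which reverses every edge."""
    return [edge_index[1], edge_index[0]]


def received_messages(edge_index, num_nodes):
    """For each node, list the nodes it receives messages from when
    messages travel from the source to the target of every edge."""
    received = {node: [] for node in range(num_nodes)}
    for src, dst in zip(*edge_index):
        received[dst].append(src)
    return received


original = received_messages(edge_index, 4)
flipped = received_messages(flip(edge_index), 4)

print(original)
print(flipped)
```

With the original edges, node 3 aggregates from nodes 1 and 2. With the flipped edges, node 0 aggregates from nodes 1 and 2 instead, so the two convolutions see opposite neighborhoods.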

In [None]:
import torch
import torch_geometric
from typing import Any
from copy import deepcopy
import yaml
from pathlib import Path
import zarr  # needed for zarr.open_group in the data reader below
from zarr.storage import LocalStore
import numpy as np
import QuantumGrav as QG

A general way of building this model, parameterized by the type of convolution layer and its arguments, is the following:

In [None]:
class BidirectionalModel(torch.nn.Module):

    def __init__(self, conv: type[torch.nn.Module], args: list[Any], kwargs: dict[str, Any]):
        super().__init__()
        # each call to `conv(...)` builds a new layer, so the two directions do not share weights
        self.futureconv = conv(*args, **kwargs)
        self.pastconv = conv(*args, **kwargs)
        self.activation = torch.nn.ReLU()
        # trainable scalar weights for the two directions
        self.weightf = torch.nn.Parameter(torch.tensor(1.0))
        self.weightb = torch.nn.Parameter(torch.tensor(1.0))
        # assumes args[1] is the output dimension of the conv layer
        self.mlp = torch_geometric.nn.dense.Linear(args[1], 1)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, batch: torch.Tensor | None = None) -> torch.Tensor:
        x_f = self.futureconv(x, edge_index) # future direction
        x_p = self.pastconv(x, edge_index.flip([0])) # past direction
        x_ = self.weightf*x_f + self.weightb*x_p
        x_ = self.activation(x_)
        x_ = torch_geometric.nn.pool.global_mean_pool(x_, batch)
        x_ = self.mlp(x_)
        return x_


## Instantiating the custom model

In [None]:
test_model = BidirectionalModel(torch_geometric.nn.conv.SAGEConv, [2,32], {"normalize": True, "root_weight": False,})

# Using our custom model in a config

To use the custom model with the trainer directly, we can reference it in a config node. First, have a look at the config we already know and how the model is defined there:

```yaml
model:
  encoder_type: !pyobject QuantumGrav.models.GNNBlock
  encoder_args: [2, 32]
  encoder_kwargs:
    dropout: 0.3
    with_skip: True
    gnn_layer_type: !pyobject torch_geometric.nn.conv.GCNConv
    gnn_layer_args: []
    gnn_layer_kwargs: {cached: False, bias: True, add_self_loops: True}
    normalizer_type: !pyobject torch.nn.BatchNorm1d
    norm_args: [32]
    norm_kwargs: {eps: 0.00001, momentum: 0.2}
    activation_type: !pyobject torch.nn.ReLU
    skip_args: [2, 32]
    skip_kwargs: {weight_initializer: "glorot"}
  downstream_tasks:
    - [
        !pyobject QuantumGrav.models.LinearSequential,
        [
            [[32, 1],],
            [!pyobject torch.nn.Identity,],
        ],
        {
            linear_kwargs: [
                {bias: True,},
            ],
            activation_kwargs: [{},],
        },
    ]
  pooling_layers:
    - [!pyobject torch_geometric.nn.global_mean_pool, [], {}]

  active_tasks:
    0: True

```

This is geared towards building a `QG.GNNModel` instance, which is general enough for many use cases and provides a mental scheme for thinking about the architecture. This scheme no longer fits our custom model, so we define the model in the config in a different way, namely with the usual (type, args, kwargs) scheme.

**Note** This approach can become complicated for complex models and can lead to confusion in the config definition. It is therefore usually worth thinking about the required level of generality and complexity first, and parameterizing the model classes accordingly.

Our simple example is easily manageable, so here is the config definition for the custom model:

```yaml 
model:
  type: !pyobject __main__.BidirectionalModel
  args:
    - !pyobject torch_geometric.nn.conv.SAGEConv # type
    - [2, 32,] # args 
    - {
      normalize: True,
      root_weight: False
    } # kwargs 
  # no further kwargs for `BidirectionalModel`
```
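How QuantumGrav resolves such a node internally is not shown in this notebook. As a rough illustration of the (type, args, kwargs) idea, the following hypothetical sketch instantiates a class from a plain dictionary. The helper `build_from_node` and the stand-in classes `ToyConv` and `ToyModel` are assumptions made for this example and are not part of the QuantumGrav API.

```python
from typing import Any


def build_from_node(node: dict[str, Any]) -> Any:
    """Instantiate node['type'] with optional positional and keyword arguments.

    Hypothetical helper for illustration only, not a QuantumGrav function.
    """
    cls = node["type"]
    args = node.get("args", [])
    kwargs = node.get("kwargs", {})
    return cls(*args, **kwargs)


class ToyConv:
    """Stand-in for a convolution layer such as SAGEConv."""

    def __init__(self, in_dim: int, out_dim: int, normalize: bool = False):
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.normalize = normalize


class ToyModel:
    """Stand-in for BidirectionalModel: receives a layer type plus its arguments."""

    def __init__(self, conv: type, args: list[Any], kwargs: dict[str, Any]):
        self.conv = conv(*args, **kwargs)


# Mirrors the structure of the YAML node above, with the types already resolved
node = {
    "type": ToyModel,
    "args": [ToyConv, [2, 32], {"normalize": True}],
    # no "kwargs" entry: the model itself takes no keyword arguments
}

model = build_from_node(node)
print(type(model).__name__, model.conv.in_dim, model.conv.out_dim)
```

The nested list `[2, 32]` and the inner dictionary are passed through untouched as positional arguments of the model, which then forwards them to the layer constructor. This is the same nesting the YAML node expresses.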

Let's define the loss function, the data reader and the data transform so the trainer has something to work with:

In [None]:
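def compute_loss(predictions: torch.Tensor, data: torch_geometric.data.Data, *args) -> torch.Tensor:
    # match the target shape to the predictions to avoid silent broadcasting in mse_loss
    target = data.y.reshape(predictions.shape)
    return torch.nn.functional.mse_loss(predictions, target)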


In [None]:
def read_data(
    store: LocalStore,
    idx: int,
    # these parameters are needed for the data reading function signature
    float_dtype: torch.dtype,
    int_dtype: torch.dtype,
    validate_data: bool,
) -> dict[str, Any]:
    """Read raw data from file"""

    # create an empty dict to hold the data
    data: dict[str, Any] = dict()

    # open the correct cset group. We always get the whole store, so we are not
    # dependent on a root group being present. Julia does not build one by default
    cset_group = zarr.open_group(store, path = f"cset_{idx+1}", mode="r")

    # read the data from the cset group into the data dict. This reads the entire group.
    QG.zarr_group_to_dict(
        cset_group, data
    )  # insert name of cset group to read

    return data


In [None]:

def collect_data(data: dict) -> torch_geometric.data.Data:
    """Transform the data dictionary into a PyG Data object."""

    adj = data["linkmat"]

    # normally, a transpose would be necessary here b/c Julia stores column major
    # and Python/numpy/torch row major by default but torch sparse does a transpose
    # again, so we can leave it as is...
    edge_index, _ = torch_geometric.utils.dense_to_sparse(torch.tensor(adj, dtype=torch.float32))

    # in and out degrees for the degree labels of the link matrix
    x = torch.tensor(
        np.array([
            data["in_degrees_link"],
            data["out_degrees_link"],
        ]),
        dtype=torch.float32,
    ).t()  # shape [N,num_labels]

    # manifold like as target
    y = torch.tensor(
        np.array([
            data["manifold_like"],
        ]),
        dtype=torch.float32,
    )

    # build the final PyG data object
    tgdata = torch_geometric.data.Data(
        edge_index=edge_index,
        x=x,
        y=y,
        csettype=torch.tensor(data["cset_type"]),
        atom_count=torch.tensor(adj.shape[0]),
    )

    if not tgdata.validate():
        raise ValueError(f"Data validation failed for {tgdata}.")

    return tgdata


We skipped the definition of additional monitor functions here, so only the average losses will be monitored. We also didn't make use of early stopping: the goal is to show how to use a custom model, not how to use the infrastructure.

# Load config and build trainer

In [None]:
config_path = Path("../../configs/train_custom_model.yaml")

with open(config_path, "r") as configfile:
    config = yaml.load(configfile, QG.get_loader())

In [None]:
trainer = QG.Trainer(config)

In [None]:
train_loader, valid_loader, test_loader = trainer.prepare_dataloaders()

In [None]:
len(train_loader)

In [None]:
len(valid_loader)

In [None]:
len(test_loader)

In [None]:
trainer.run_training(train_loader=train_loader, val_loader=valid_loader)