# Simple model training 

In this notebook, we demonstrate how to train a simple classifier model using the QuantumGravPy framework. We will set up the data, define the model, and run the training process while monitoring performance metrics. 
This model does not perform well on the task, but serves as a basic example of how to use the framework for training. 

First, we import QuantumGravPy together with the other packages this notebook needs.

In [None]:
# main library
import QuantumGrav as QG

# logging, typing, configs, numpy
import yaml
from pathlib import Path
import logging
import pprint
from typing import Any
import numpy as np


# data
import zarr
from zarr.storage import LocalStore
from zarr import Group

# data handling and plotting
import pandas as pd
import seaborn as sns

# evaluation
from sklearn.metrics import f1_score

# ML with torch
import torch
from torch_geometric.utils import dense_to_sparse
from torch_geometric.data import Data



Packages of note here: 
- scikit-learn (`sklearn`) is a machine learning library built on NumPy and SciPy. It contains a vast collection of tools, models, and metrics for machine learning tasks.
- seaborn is a statistical data visualization library built on top of matplotlib, providing a high-level interface for creating informative and attractive statistical graphics. 
- zarr is a powerful library for working with large, chunked, compressed, N-dimensional arrays, enabling efficient storage and retrieval of large datasets.
- torch_geometric is an extension library for PyTorch that provides tools and functionalities for deep learning on graph-structured data, enabling the development of graph neural networks and related models. QuantumGrav's models are built on top of this library.


# Preparations 

First, we configure the logger's output format so that its output is more structured.

In [None]:
# set logging level to INFO to see the trainer's output
logging.basicConfig(
    level=logging.INFO,  # use logging.DEBUG for more verbose output
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
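To see what this format produces, you can format a hand-made log record with the same format string. `%(name)s` is replaced by the name of the logger that emitted the message. The logger name `QuantumGrav.train` below is made up for illustration.

```python
import logging

# the same format string that is passed to logging.basicConfig above
FORMAT = "%(asctime)s %(levelname)s %(name)s: %(message)s"
formatter = logging.Formatter(FORMAT)

# build a log record by hand; "QuantumGrav.train" is a hypothetical logger name
record = logging.LogRecord(
    name="QuantumGrav.train",
    level=logging.INFO,
    pathname="example.py",
    lineno=1,
    msg="epoch %d finished",
    args=(1,),
    exc_info=None,
)
line = formatter.format(record)
print(line)  # a timestamp, followed by: INFO QuantumGrav.train: epoch 1 finished
```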


We also add a pretty printer function to display configuration sections in a more readable way.

In [None]:
def pretty_print(data: Any, offset: str = "") -> None:
    """Recursively print nested config sections, indenting each level."""
    if isinstance(data, dict):
        print()
        for key, value in data.items():
            print(f"{offset}{key}:", end=" ")
            pretty_print(value, offset + "  ")
    elif isinstance(data, list):
        print()
        for i, item in enumerate(data):
            print(f"{offset}- Item {i}:", end=" ")
            pretty_print(item, offset + "  ")
    else:
        # leaf values (numbers, strings, callables, ...) are printed with pprint
        pprint.pprint(data, indent=4)

# Define reader functions for data

First, we need functions that read the data from disk and convert it into a format the model can use: a torch_geometric `Data` object. We split this task into two functions. 
The first is `read_data`. It reads the raw data from a Zarr store on disk and builds a dictionary from it. QuantumGrav provides a helper function to read an entire zarr group into memory. `read_data` has a fixed signature that needs to be adhered to: 
```python
def read_data(store: zarr.storage.Store, idx: int, float_dtype: torch.dtype, int_dtype: torch.dtype, validate_data: bool) -> dict
```
These have the following meanings: 
- `store`: The zarr store from which to read the data.
- `idx`: The index of the data sample to read. All samples in a given dataset are indexed from 0 to N-1, where N is the total number of samples. This index runs across all zarr stores that are a part of the dataset. 
- `float_dtype`: The desired floating-point data type for the data.
- `int_dtype`: The desired integer data type for the data.
- `validate_data`: A boolean flag indicating whether to validate the data after reading it.

The function should return a dictionary containing the data read from the zarr store.
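The description of `idx` implies that a global sample index has to be resolved to a particular store and a group within it. The sketch below shows one way to do that. It assumes that each store holds a known number of samples, that samples are numbered consecutively across stores, and that groups are named `cset_1`, `cset_2`, ... as in `read_data` below. The helper `locate_sample` is hypothetical and not part of QuantumGrav.

```python
from bisect import bisect_right
from itertools import accumulate


def locate_sample(idx: int, store_sizes: list[int]) -> tuple[int, str]:
    """Map a global sample index to (store number, group name).

    Hypothetical helper: assumes samples are numbered consecutively across
    stores in the given order and that groups are named cset_1, cset_2, ...
    """
    # cumulative end offsets, e.g. [100, 250, 300] for sizes [100, 150, 50]
    ends = list(accumulate(store_sizes))
    if not 0 <= idx < ends[-1]:
        raise IndexError(f"index {idx} out of range for {ends[-1]} samples")
    # first store whose end offset is strictly greater than idx
    store = bisect_right(ends, idx)
    start = ends[store - 1] if store > 0 else 0
    local_idx = idx - start
    # group names start at 1
    return store, f"cset_{local_idx + 1}"


print(locate_sample(0, [100, 150, 50]))    # (0, 'cset_1')
print(locate_sample(100, [100, 150, 50]))  # (1, 'cset_1')
print(locate_sample(299, [100, 150, 50]))  # (2, 'cset_50')
```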



In [None]:
def read_data(
    store: LocalStore,
    idx: int,
    # unused here, but required by the data reading function signature
    float_dtype: torch.dtype,
    int_dtype: torch.dtype,
    validate_data: bool,
) -> dict[str, Any]:
    """Read raw data from file"""

    # create an empty dict to hold the data
    data: dict[str, Any] = dict()

    # open the correct cset group. We always get the whole store, so we are not
    # dependent on a root group being present (Julia does not build one by default).
    # Group names start at cset_1, hence the idx + 1.
    cset_group = zarr.open_group(store, path=f"cset_{idx + 1}", mode="r")

    # read the entire cset group into the data dict
    QG.zarr_group_to_dict(cset_group, data)

    return data


# transform into a Data object

The second function takes this dictionary and transforms it into a torch_geometric Data object. This function has a fixed signature that needs to be adhered to: 
```python
def collect_data(data: dict) -> Data
```

i.e., it transforms a dictionary into a Data object. 
What it does internally with the data dictionary is up to the author of the function. In our case, we extract the adjacency matrix, node features, and labels from the dictionary and convert them into the appropriate format for a Data object.
However, we can also do more complex things, like computing eigenvalues, normalizing features, or adding additional attributes to the Data object.
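The core of this conversion can be sketched without torch. An edge list in the `[2, num_edges]` layout used by `edge_index` is the list of nonzero positions of the link matrix, and the degree features are its row and column sums. The sketch assumes the convention `adj[i][j] == 1` for a link from node `i` to node `j`, so out-degrees are row sums and in-degrees are column sums. The stored `in_degrees_link` and `out_degrees_link` may follow a different convention. The function `adjacency_to_graph` is a hypothetical illustration.

```python
def adjacency_to_graph(adj: list[list[int]]) -> tuple[list[list[int]], list[list[int]]]:
    """Return (edge_index, x) for a dense 0/1 adjacency matrix.

    edge_index has shape [2, num_edges] (sources, targets) in row-major order;
    x has shape [N, 2] with columns (in_degree, out_degree).
    """
    n = len(adj)
    sources, targets = [], []
    for i in range(n):
        for j in range(n):
            if adj[i][j] != 0:
                sources.append(i)
                targets.append(j)
    out_deg = [sum(row) for row in adj]                            # row sums
    in_deg = [sum(adj[i][j] for i in range(n)) for j in range(n)]  # column sums
    x = [[in_deg[k], out_deg[k]] for k in range(n)]
    return [sources, targets], x


# a small link matrix: 0 -> 1, 0 -> 2, 1 -> 3, 2 -> 3
adj = [
    [0, 1, 1, 0],
    [0, 0, 0, 1],
    [0, 0, 0, 1],
    [0, 0, 0, 0],
]
edge_index, x = adjacency_to_graph(adj)
print(edge_index)  # [[0, 0, 1, 2], [1, 2, 3, 3]]
print(x)           # [[0, 2], [1, 1], [1, 1], [2, 0]]
```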

In [None]:

def collect_data(data: dict) -> Data:
    """Transform the data dictionary into a PyG Data object."""

    adj = data["linkmat"]

    # normally, a transpose would be necessary here b/c Julia stores column major
    # and Python/numpy/torch row major by default but torch sparse does a transpose
    # again, so we can leave it as is...
    edge_index, _ = dense_to_sparse(torch.tensor(adj, dtype=torch.float32))

    # node features: in- and out-degrees in the link matrix
    x = torch.tensor(
        np.array([
            data["in_degrees_link"],
            data["out_degrees_link"],
        ]),
        dtype=torch.float32,
    ).t()  # shape [N, 2]: one row per node, one column per feature

    # manifold-likeness as the (binary) target
    y = torch.tensor(
        np.array([
            data["manifold_like"],
        ]),
        dtype=torch.float32,
    )

    # build the final PyG data object
    tgdata = Data(
        edge_index=edge_index,
        x=x,
        y=y,
        csettype=torch.tensor(data["cset_type"]),
        atom_count=torch.tensor(adj.shape[0]),
    )

    if not tgdata.validate():
        raise ValueError(f"Data validation failed for {tgdata}.")

    return tgdata


Why the split? Because you have a choice of when to apply the data transformations. You can do it once when reading the data from disk and cache the result (then it's called `pre_transform`) or you can do it on-the-fly during training (then it's called `transform`). What you choose depends on the complexity of the transformations and the size of your dataset, but also on the workflow you are using. If you are experimenting more on the data side, it can be more useful to do it on-the-fly. If the transformations are expensive, it can be better to do it once and cache the result. 
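The difference can be illustrated with a toy dataset class. `ToyDataset` is a hypothetical stand-in, not QuantumGrav's dataset. It only mirrors the general idea: `pre_transform` runs once per sample when the cache is built, while `transform` runs every time a sample is accessed.

```python
from typing import Any, Callable, Optional


class ToyDataset:
    """Hypothetical dataset illustrating pre_transform vs. transform."""

    def __init__(
        self,
        raw: list[Any],
        pre_transform: Optional[Callable[[Any], Any]] = None,
        transform: Optional[Callable[[Any], Any]] = None,
    ) -> None:
        # pre_transform: applied once and kept (a real dataset would cache to disk)
        self._cache = [pre_transform(r) if pre_transform else r for r in raw]
        self.transform = transform

    def __len__(self) -> int:
        return len(self._cache)

    def __getitem__(self, idx: int) -> Any:
        item = self._cache[idx]
        # transform: applied on every access, so it can change between experiments
        return self.transform(item) if self.transform else item


calls = {"pre": 0, "on_the_fly": 0}


def expensive(x: int) -> int:
    calls["pre"] += 1
    return x * x


def cheap(x: int) -> int:
    calls["on_the_fly"] += 1
    return x + 1


ds = ToyDataset([1, 2, 3], pre_transform=expensive, transform=cheap)
for epoch in range(2):
    values = [ds[i] for i in range(len(ds))]

print(values, calls)  # [2, 5, 10] {'pre': 3, 'on_the_fly': 6}
```

Over two epochs, the expensive function runs three times (once per sample), while the cheap one runs six times.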

Next, we need a set of functions that define the loss function to use to train the model, and functions that compute the monitoring metrics from the model outputs that we want to track during validation and testing. 

# define loss function


First, we define the loss function to use during training. In this case, we use binary cross-entropy on logits (`torch.nn.BCEWithLogitsLoss`), the standard choice for binary classification, which applies the sigmoid internally. The loss function takes the model outputs and the true labels as inputs and computes the loss value. Note that the model outputs a dictionary of tensors (one per task head), so we need to extract the relevant tensor for the loss computation. 

Note that the loss function must be appropriate for the task at hand. For multi-class classification, for example, you would use a different loss function, such as cross-entropy loss. 

We also check the model outputs and labels for not-a-number (NaN) values. If any are found, an error is raised instead of letting an invalid loss propagate through training. 


In [None]:
def compute_loss(
    predictions: dict[Any, torch.Tensor],
    data: Data,
    *args,
    **kwargs,
) -> torch.Tensor:
    """Compute loss for a single classifier task at index 0"""

    pred = predictions[0]
    tgt = data.y

    # check for NaNs; these checks add Python overhead and can be commented out for speed
    if torch.isnan(pred).any():
        raise ValueError("Nan in predictions")

    if torch.isnan(tgt).any():
        raise ValueError("Nan in targets")

    return torch.nn.BCEWithLogitsLoss(
        reduction="mean",
    )(pred, tgt)
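For reference, `BCEWithLogitsLoss` with `reduction="mean"` averages the binary cross-entropy of the sigmoid of each logit over all elements. Below is a plain-Python sketch of that computation. It uses the numerically stable form, which avoids evaluating `exp` of large positive numbers, and checks it against the textbook formula. Both function names are illustrative.

```python
import math


def bce_with_logits_mean(logits: list[float], targets: list[float]) -> float:
    """Mean binary cross-entropy computed directly from logits.

    Per element: max(x, 0) - x * y + log(1 + exp(-|x|)), which equals
    -[y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))] but does not overflow.
    """
    losses = [
        max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
        for x, y in zip(logits, targets)
    ]
    return sum(losses) / len(losses)


def bce_naive(logits: list[float], targets: list[float]) -> float:
    # textbook form via the sigmoid; fine for moderate logits only
    total = 0.0
    for x, y in zip(logits, targets):
        p = 1.0 / (1.0 + math.exp(-x))
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(logits)


logits = [2.0, -1.0, 0.0]
targets = [1.0, 0.0, 1.0]
print(bce_with_logits_mean(logits, targets))  # ≈ 0.3778
print(bce_naive(logits, targets))             # ≈ 0.3778
```

Inside `compute_loss`, `torch.nn.functional.binary_cross_entropy_with_logits(pred, tgt)` is the functional equivalent and avoids creating a new loss module on every call.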



# Monitor functions

Just like the loss function computes how 'well' a model is performing during training in order to compute the gradients for optimization, monitor functions compute metrics that help us understand how well the model is performing on validation and test datasets without affecting the training process itself. These therefore use metrics that measure model performance more directly, such as accuracy, precision, recall, F1-score, etc. 

In the same way as the loss function, the monitor functions take the model outputs and true labels as inputs and compute the desired metrics. Again, we need to extract the relevant tensor from the model outputs for the computation.

Here we monitor two versions of the [F1 score](https://en.wikipedia.org/wiki/F-score), a standard metric for classification tasks. The first is the weighted F1 score. It accounts for class imbalance by weighting each class's F1 score by its support, i.e., the number of true samples of that class. This is useful for imbalanced datasets where some classes are more frequent than others. It is computed in the function `f1_monitor`. 
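Let $P_c$ and $R_c$ be the precision and recall of class $c$, $n_c$ its support, and $N = \sum_c n_c$. The per-class and weighted F1 scores are then:

$$
F_{1,c} = \frac{2 P_c R_c}{P_c + R_c}, \qquad
F_{1,\mathrm{weighted}} = \sum_c \frac{n_c}{N} F_{1,c}
$$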


In [None]:
def f1_monitor(
    predictions: list[dict[Any, torch.Tensor]],
    targets: list[torch.Tensor],
) -> float:
    """Compute F1 score for a single classifier task (key 0)."""

    logits_list: list[torch.Tensor] = []
    for p in predictions:
        if not isinstance(p, dict):
            raise TypeError(f"Expected dict of tensors, got {type(p)}")
        if 0 in p:
            t = p[0]
        elif "0" in p:
            t = p["0"]
        else:
            t = next(iter(p.values()))
        logits_list.append(t.detach())

    pred_logits = torch.cat([t.reshape(-1) for t in logits_list])
    tgt_vec = torch.cat([t.detach().reshape(-1) for t in targets])

    pred_labels = (torch.sigmoid(pred_logits) >= 0.5).to(torch.int32).cpu().numpy()
    tgt_labels = tgt_vec.to(torch.int32).cpu().numpy()

    return f1_score(tgt_labels, pred_labels, average="weighted")



In addition, we monitor the F1 score per class, which shows how well the model performs on each individual class. This is particularly useful for identifying classes that are harder for the model to classify correctly. This is done in the function `f1_monitor_perclass` below. The only difference from the previous function is that we pass `average=None` instead of `average="weighted"` to sklearn's `f1_score`. This makes it return one F1 score per class instead of a single number. 

The disadvantage of this approach is that some work is done twice, for instance assembling the true and predicted labels into numpy arrays. However, QuantumGrav allows you to build your own evaluator classes. If this becomes a bottleneck, you can implement a custom evaluator that computes both metrics in one pass. This is covered elsewhere. 
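Computing both metrics in one pass only requires per-class counts of true positives, false positives, and false negatives. The following plain-Python sketch is illustrative only; it does not use QuantumGrav's evaluator interface. It thresholds the logits once and derives both the per-class and the support-weighted F1 from the same counts. Because `sigmoid(x) >= 0.5` is equivalent to `x >= 0`, no sigmoid is needed. The function name `f1_both` is hypothetical.

```python
def f1_both(logits: list[float], targets: list[int]) -> tuple[list[float], float]:
    """Return (per-class F1 for classes 0 and 1, support-weighted F1)."""
    # sigmoid(x) >= 0.5 exactly when x >= 0
    preds = [1 if x >= 0 else 0 for x in logits]

    per_class, supports = [], []
    for c in (0, 1):
        tp = sum(1 for p, t in zip(preds, targets) if p == c and t == c)
        fp = sum(1 for p, t in zip(preds, targets) if p == c and t != c)
        fn = sum(1 for p, t in zip(preds, targets) if p != c and t == c)
        # F1 = 2TP / (2TP + FP + FN); set to 0 when the denominator is 0
        denom = 2 * tp + fp + fn
        per_class.append(2 * tp / denom if denom else 0.0)
        supports.append(tp + fn)  # number of true samples of class c

    weighted = sum(f * n for f, n in zip(per_class, supports)) / sum(supports)
    return per_class, weighted


per_class, weighted = f1_both([2.0, 1.0, -1.0, -2.0, -0.5], [1, 0, 0, 0, 1])
print(per_class, weighted)  # per class ≈ [0.667, 0.5], weighted ≈ 0.6
```

For these labels, sklearn's `f1_score` with `average=None` and `average="weighted"` gives the same values.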

In [None]:

def f1_monitor_perclass(
    predictions: list[dict[Any, torch.Tensor]],
    targets: list[torch.Tensor],
) -> np.ndarray:
    """Compute per-class F1 scores for a single classifier task (key 0)."""

    logits_list: list[torch.Tensor] = []
    for p in predictions:
        if not isinstance(p, dict):
            raise TypeError(f"Expected dict of tensors, got {type(p)}")
        if 0 in p:
            t = p[0]
        elif "0" in p:
            t = p["0"]
        else:
            t = next(iter(p.values()))
        logits_list.append(t.detach())

    pred_logits = torch.cat([t.reshape(-1) for t in logits_list])
    tgt_vec = torch.cat([t.detach().reshape(-1) for t in targets])

    pred_labels = (torch.sigmoid(pred_logits) >= 0.5).to(torch.int32).cpu().numpy()
    tgt_labels = tgt_vec.to(torch.int32).cpu().numpy()

    return f1_score(tgt_labels, pred_labels, average=None)

# Set up trainer and datasets 

The trainer class is built from a config file. The config defines the model, optimizer, training epochs, dataset, evaluators, early stopping system, and loss function, which is everything we need to build a functioning training machinery. 

In [None]:
config_path = Path("../configs/train_classifier.yaml")

with open(config_path, "r") as configfile:
    config = yaml.load(configfile, QG.get_loader())


Let's have a look at what it contains.

In [None]:
pretty_print(config)


In [None]:
pretty_print(config["training"])

In [None]:
pretty_print(config["validation"]["validator"])

In [None]:
pretty_print(config["testing"]["tester"])

In [None]:
pretty_print(config["early_stopping"])

In [None]:
# set up trainer
trainer = QG.Trainer.from_config(config)

# Prepare dataset and dataloaders

Let's once more have a look at the config file, this time at the data section:

In [None]:
pretty_print(config["data"])

Here, we see that the functions for data reading and transformation are actually used. `collect_data` is used as a `pre_transform`, i.e., its results will be cached on disk, and `read_data` reads the raw data. Have a look at the documentation of `QGDataset` to learn more about them. We also want the dataset to be shuffled. Shuffling, either in the dataset or in the data loader, is important. Otherwise the data would be stacked by type (polynomial first, then complex, then random), which is bad for training. 
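To see why shuffling matters, consider a toy dataset in which samples are ordered by type, as happens when stores are simply concatenated. Without shuffling, every batch contains a single type. After a seeded shuffle, batches mix types. The type labels and counts here are made up for illustration.

```python
import random


def batch_type_counts(types: list[int], batch_size: int) -> list[int]:
    """Number of distinct types in each full batch (a partial last batch is dropped)."""
    n_full = len(types) // batch_size
    return [len(set(types[b * batch_size:(b + 1) * batch_size])) for b in range(n_full)]


# 3 types with 640 samples each, stored one type after the other
types = [0] * 640 + [1] * 640 + [2] * 640

ordered = batch_type_counts(types, batch_size=64)

shuffled_types = types.copy()
random.Random(0).shuffle(shuffled_types)
shuffled = batch_type_counts(shuffled_types, batch_size=64)

# ordered: every batch holds a single type; shuffled: batches mix several types
print(max(ordered), min(shuffled))
```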

The `prepare_dataloaders` function normally takes care of preparing the dataset too. However, we can also do this by hand, or use our own, independently built dataset that we have put together in code without the config. Have a look at the arguments of `prepare_dataloaders` for details. 
The dataloaders have their own arguments: 

In [None]:
pretty_print(config["training"])

Pay attention to the parameters from `batch_size` downwards. We use a batch size of 64 and drop the last batch if it has fewer than 64 data points. We also use 12 worker processes to read data and build batches, and pre-load up to 8 batches into memory to minimize the time the GPU has to wait for data. Finally, we use `pin_memory`, an optimization that speeds up transfers from CPU to GPU (see [here](https://discuss.pytorch.org/t/when-to-set-pin-memory-to-true/19723) and [here](https://developer.nvidia.com/blog/how-optimize-data-transfers-cuda-cc/) for more). 
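A quick sanity check of what these settings imply, using the 80% training share of the 55000 samples. This assumes the settings are passed on to PyTorch's `torch.utils.data.DataLoader`, where `prefetch_factor` counts batches loaded in advance *per worker*. If QuantumGrav interprets the value differently, the last number changes accordingly.

```python
# values quoted in the text; mapping them to the DataLoader keyword arguments
# batch_size, drop_last, num_workers and prefetch_factor is an assumption
n_train = 55000 * 8 // 10  # 80% of the dataset
batch_size = 64
num_workers = 12
prefetch_factor = 8

# with drop_last=True only full batches are produced
n_batches = n_train // batch_size
n_dropped = n_train - n_batches * batch_size

# in torch.utils.data.DataLoader, prefetch_factor applies to each worker
max_prefetched = num_workers * prefetch_factor

print(n_train, n_batches, n_dropped, max_prefetched)  # 44000 687 32 96
```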

In [None]:
pretty_print(config["validation"])

In [None]:
# get dataloaders
train_loader, valid_loader, test_loader = trainer.prepare_dataloaders()

We have a dataset of 55000 data points in this case, which we split into 80% training, 10% validation, and 10% test data.
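The split sizes follow directly from the fractions. Below is a minimal sketch of a seeded random 80/10/10 split of sample indices, for illustration only. The trainer's `prepare_dataloaders` performs its own split, and the helper `split_indices` is hypothetical.

```python
import random


def split_indices(n: int, fractions: tuple[float, float, float], seed: int = 0):
    """Shuffle range(n) and cut it into train/validation/test index lists."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    n_train = round(fractions[0] * n)
    n_val = round(fractions[1] * n)
    # the test set takes the remainder so that no sample is lost to rounding
    return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]


train_idx, val_idx, test_idx = split_indices(55000, (0.8, 0.1, 0.1))
print(len(train_idx), len(val_idx), len(test_idx))  # 44000 5500 5500
```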

In [None]:
len(train_loader.dataset)

In [None]:
len(valid_loader.dataset)

In [None]:
len(test_loader.dataset)

# check out data and plot some statistics

Let us examine the dataset to verify that it contains no unwanted structure. First, we put the data we want into pandas dataframes, which makes plotting with seaborn easy. That's substantial work, so we skip it if the dataframe has already been saved to disk: 

In [None]:
summary_path = Path(config["training"]["path"]) / "cset_data_summary.csv"

if not summary_path.exists():

    import tqdm

    def summarize(dataset, name: str) -> pd.DataFrame:
        """Collect cset type, size and manifold-likeness for every sample."""
        # collect rows in a list first; appending to a DataFrame row by row is slow
        rows = []
        for i in tqdm.tqdm(range(len(dataset)), desc=name):
            data = dataset[i]
            rows.append(
                {
                    "cset_type": data.csettype.to(torch.int32).item(),
                    "cset_size": data.atom_count.item(),
                    "manifold_like": data.y.item(),
                }
            )
        df = pd.DataFrame(rows, columns=["cset_type", "cset_size", "manifold_like"])
        df["dataset"] = name
        return df

    df_all = pd.concat(
        [
            summarize(trainer.train_dataset, "trainer"),
            summarize(trainer.val_dataset, "validator"),
            summarize(trainer.test_dataset, "tester"),
        ],
        ignore_index=True,
    )
    df_all["manifold_like"] = df_all["manifold_like"].astype(int)
    df_all.to_csv(summary_path, index=False)
else:
    df_all = pd.read_csv(summary_path)


Here, we only record `cset_type`, `cset_size` and manifold-likeness, but we could record any attribute that `collect_data` stores in the `Data` object. 

In [None]:
df_all.head()

Let's have a look at the histogram of `cset_type`. The types are distributed relatively evenly across the training, validation, and test sets. However, since the dataset contains two stores of complex and polynomial data (types 1, 2), their bars are roughly twice as high as the others. 

In [None]:

type_plot = sns.histplot(data=df_all, x="cset_type", hue="dataset", multiple="dodge", shrink=0.8, stat="count")
fig = type_plot.get_figure()

# save figure to output path
fig.savefig(Path(config["training"]["path"]) / "cset_type_distribution.png")

fig.show()


Consequently, a histogram of manifold-likeness shows, as expected, that the complex and polynomial csets are marked as manifold-like and the others are not. Overall, however, the manifold-like csets are still underrepresented compared to the others. 

In [None]:

manifoldlike_plot = sns.histplot(data=df_all, x="cset_type", hue="manifold_like", multiple="dodge", palette="tab10")
fig = manifoldlike_plot.get_figure()

# save figure to output path
fig.savefig(Path(config["training"]["path"]) / "manifold_like_distribution.png")
fig.show()


Let's make a more complex plot that shows the marginal distributions over types and sizes, to see whether there are other biases. We see that, at a given size, there are fewer complex csets than csets of the other types. This is because, for complex csets, some nodes can be deleted when they are too close to a cut through the domain. The rest is consistent with what we have seen before. 

In [None]:
size_plot = sns.jointplot(data=df_all, x="cset_size", y="cset_type", kind="hist", cmap="mako")

# save figure to output path
size_plot.savefig(Path(config["training"]["path"]) / "cset_size_distribution.png")


# run training

We can now run the training. First, we need to initialize the model and the optimizer that adjusts its parameters. Both are defined in the config file:

In [None]:
pretty_print(config["model"])

The trainer uses this config to set up the model when we call its `initialize_model` method.

In [None]:
# initialize model
trainer.initialize_model()


For the optimizer, we use the [AdamW optimizer](https://docs.pytorch.org/docs/main/generated/torch.optim.AdamW.html) here. In the same way, the trainer uses the config to build the optimizer when we call its `initialize_optimizer()` method.
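For intuition, here is a single AdamW update for one scalar parameter, written out by hand. It follows the update rule documented for `torch.optim.AdamW`: decoupled weight decay is applied directly to the parameter, followed by the bias-corrected Adam step. The default hyperparameters below match PyTorch's defaults, not necessarily the values in this notebook's config. `adamw_step` is an illustrative helper.

```python
import math


def adamw_step(theta, grad, m, v, t, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2):
    """One AdamW update for a scalar parameter; returns (theta, m, v)."""
    beta1, beta2 = betas
    # decoupled weight decay: shrink the parameter independently of the gradient
    theta = theta * (1 - lr * weight_decay)
    # exponential moving averages of the gradient and the squared gradient
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad**2
    # bias correction for the zero-initialised moments (t counts from 1)
    m_hat = m / (1 - beta1**t)
    v_hat = v / (1 - beta2**t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v


theta, m, v = adamw_step(theta=1.0, grad=0.5, m=0.0, v=0.0, t=1)
print(theta)  # ≈ 0.99899
```

On the first step, `m_hat / sqrt(v_hat)` is about `sign(grad)`, so the parameter moves by roughly `lr` against the gradient, plus the small weight-decay shrinkage.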

In [None]:
pretty_print(config["training"])

In [None]:
# initialize optimizer
trainer.initialize_optimizer()

Finally, we start the training process by calling the trainer's `run_training` function, passing the training and validation dataloaders as arguments. The rest (looping over epochs, calling the optimizer, and recording metrics) is handled by the trainer in the background. 
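Conceptually, the kind of loop that `run_training` automates looks like the one below. This is a generic, self-contained illustration on a one-parameter linear model with plain gradient descent, not QuantumGrav's implementation. The real trainer also handles batching, devices, evaluators, and early stopping as set in the config. The function `run_toy_training` is hypothetical.

```python
def run_toy_training(train_data, val_data, epochs=20, lr=0.1):
    """Fit y = w * x with gradient descent; return w and per-epoch losses."""
    w = 0.0
    train_history, val_history = [], []
    for epoch in range(epochs):
        # training: one gradient step per sample (batch size 1)
        epoch_losses = []
        for x, y in train_data:
            pred = w * x
            loss = (pred - y) ** 2
            grad = 2 * (pred - y) * x  # d loss / d w
            w -= lr * grad             # optimizer step
            epoch_losses.append(loss)
        train_history.append(sum(epoch_losses) / len(epoch_losses))

        # validation: evaluate only, no parameter updates
        val_losses = [(w * x - y) ** 2 for x, y in val_data]
        val_history.append(sum(val_losses) / len(val_losses))
    return w, train_history, val_history


train = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]
val = [(4.0, 8.0)]
w, train_hist, val_hist = run_toy_training(train, val)
print(round(w, 4), val_hist[-1] < val_hist[0])  # 2.0 True
```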

In [None]:
train_results, validation_results = trainer.run_training(
    train_loader=train_loader,
    val_loader=valid_loader,
)

# test model

After training, the model is tested on an unseen dataset. The results tell us how well the model generalizes to unseen data. 

In [None]:
test_results = trainer.run_test(test_loader)

In [None]:
test_results

The test results are mediocre, so our model is not doing too well. 

# save training, validation, test results to disk

Because the validation and test results are dataframes, we can use pandas' powerful data handling and its integration with seaborn for computation and visualization. Here, we save the results as CSV files in the same directory as the other training data, and then visualize them. 

In [None]:
# convert the training results to a dataframe and save them to disk
train_results = pd.DataFrame(train_results, columns=["loss_mean", "loss_std"])

train_results.to_csv(
    Path(config["training"]["path"]) / "training_results.csv"
)

# save validation results to disk
validation_results.to_csv(Path(config["training"]["path"]) / "validation_result.csv")

# save test results to disk
test_results.to_csv(Path(config["training"]["path"]) / "test_result.csv")


For plotting with seaborn, we need to convert the validation dataframe to long form. First, we add an `epoch` column. 
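Conceptually, `melt` turns each (epoch, metric) pair of the wide table into its own row, so that seaborn can map the `metric` column to `hue`. Below is a plain-Python sketch of the same reshaping on made-up numbers. The `melt` function here is an illustrative re-implementation, not pandas' own.

```python
def melt(rows: list[dict], id_var: str, value_vars: list[str]) -> list[dict]:
    """Wide-to-long reshape: one output row per (row, value_var) pair."""
    return [
        {id_var: row[id_var], "metric": var, "value": row[var]}
        for var in value_vars  # like pandas.melt: all rows of one variable, then the next
        for row in rows
    ]


# made-up validation results for two epochs, in wide form
wide = [
    {"epoch": 1, "loss_avg": 0.7, "f1_weighted": 0.5},
    {"epoch": 2, "loss_avg": 0.6, "f1_weighted": 0.6},
]
long = melt(wide, "epoch", ["loss_avg", "f1_weighted"])
for r in long:
    print(r)
# {'epoch': 1, 'metric': 'loss_avg', 'value': 0.7}
# {'epoch': 2, 'metric': 'loss_avg', 'value': 0.6}
# {'epoch': 1, 'metric': 'f1_weighted', 'value': 0.5}
# {'epoch': 2, 'metric': 'f1_weighted', 'value': 0.6}
```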

In [None]:
validation_results["epoch"] = range(1, len(validation_results) + 1)

validation_results.head()

In [None]:
valid_melted = validation_results.loc[:, ["epoch", "loss_avg", "loss_min", "loss_max", "f1_weighted"]].melt(
    id_vars=["epoch"],
    value_vars=["loss_avg", "loss_min", "loss_max", "f1_weighted"],
    var_name="metric",
    value_name="value",
)

In [None]:
ax = sns.lineplot(data=valid_melted, x="epoch", y="value", hue="metric", linewidth=2.5, marker="o", legend="brief")
ax.figure.savefig(Path(config["training"]["path"]) / "validation_metrics.png")


As is apparent above, the model does not seem powerful enough to learn well from the data, and the F1 scores and losses are not very stable. 