In [None]:
# jupyter magic commands to avoid having to restart upon change of lib code
%load_ext autoreload
%autoreload 2

# Simple model training 

In this notebook, we demonstrate how to train a simple classifier model using the QuantumGravPy framework. This follows the `train_simple_model.ipynb` notebook, but we will show how to build and include a weighted sampler for imbalanced data. If you are not interested in reviewing the details of setting up the model training, you can skip to the "Prepare dataset and dataloaders with weighted sampler" section.

First, we import QuantumGravPy and all the other packages and modules we need. 

In [None]:
# main library
import QuantumGrav as QG

# logging, pretty-printing, typing, configs, numpy
import pprint
import yaml
from pathlib import Path
import logging
from typing import Any
import numpy as np


# data
import zarr
from zarr.storage import LocalStore
from zarr import Group

# data handling and plotting
import pandas as pd
import seaborn as sns

# evaluation
from sklearn.metrics import f1_score

# ML with torch
import torch
from torch_geometric.utils import dense_to_sparse
from torch_geometric.data import Data



Packages of note here: 
- sklearn (scikit-learn) is a general-purpose machine learning library for Python, built on NumPy and SciPy. It contains a large collection of tools, models, and metrics for machine learning tasks.
- seaborn is a statistical data visualization library built on top of matplotlib, providing a high-level interface for creating informative and attractive statistical graphics. 
- zarr is a library for working with large, chunked, compressed, N-dimensional arrays, enabling efficient storage and retrieval of large datasets.
- torch_geometric is an extension library for PyTorch that provides tools for deep learning on graph-structured data, such as graph neural networks. QuantumGrav's models are built on top of this library.


# Preparations 

We start by configuring the logger's output format so that the trainer's log messages are more structured.

In [None]:
# set logging level to INFO for trainer output
logging.basicConfig(
    level=logging.INFO,  # or logging.DEBUG if you want more verbosity
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)


We also add a pretty printer function to display configuration sections in a more readable way.

In [None]:
def pretty_print(data: Any, offset: str = "") -> None:
    """Recursively print nested dicts and lists, indenting each level by two spaces."""
    pp = pprint.PrettyPrinter(indent=4)
    if isinstance(data, dict):
        print()
        for key, value in data.items():
            print(f"{offset}{key}:", end=" ")
            pretty_print(value, offset + "  ")
    elif isinstance(data, list):
        print()
        for i, item in enumerate(data):
            print(f"{offset}- Item {i}:", end=" ")
            pretty_print(item, offset + "  " )
    else:
        pp.pprint(data)

# Define reader function for data

First, we need functions that read the data from disk and convert it into a format the model can use: a torch_geometric `Data` object. We split this task into two functions. 
The first is `read_data`. It reads the raw data for one sample from a Zarr store on disk and builds a dictionary from it. QuantumGrav provides a helper function, `QG.zarr_group_to_dict`, that reads an entire zarr group into memory. `read_data` has a fixed signature that must be adhered to: 
```python
def read_data(store: zarr.storage.Store, idx: int, float_dtype: torch.dtype, int_dtype: torch.dtype, validate_data: bool) -> dict
```
These have the following meanings: 
- `store`: The zarr store from which to read the data.
- `idx`: The index of the data sample to read. All samples in a given dataset are indexed from 0 to N-1, where N is the total number of samples. This index runs across all zarr stores that are a part of the dataset. 
- `float_dtype`: The desired floating-point data type for the data.
- `int_dtype`: The desired integer data type for the data.
- `validate_data`: A boolean flag indicating whether to validate the data after reading it.
The function should return a dictionary containing the data read from the zarr store.
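
To make the idea of a dataset-wide index concrete: if a dataset spans several stores, a global index has to be mapped to a store and a position inside it. How QuantumGrav does this internally is not shown in this notebook, so the following is only a sketch of one common scheme (cumulative offsets). The function `locate` and the per-store sample counts are hypothetical.

```python
from bisect import bisect_right
from itertools import accumulate


def locate(global_idx: int, samples_per_store: list[int]) -> tuple[int, int]:
    """Hypothetical: map a dataset-wide index to (store number, index within that store)."""
    # cumulative start offsets, e.g. [0, 100, 250] for stores holding 100 and 150 samples
    offsets = [0, *accumulate(samples_per_store)]
    if not 0 <= global_idx < offsets[-1]:
        raise IndexError(f"index {global_idx} out of range for {offsets[-1]} samples")
    # the last offset that is <= global_idx identifies the store
    store = bisect_right(offsets, global_idx) - 1
    return store, global_idx - offsets[store]
```

For example, with two hypothetical stores of 100 and 150 samples, `locate(120, [100, 150])` returns `(1, 20)`, i.e. the 21st sample of the second store.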



In [None]:
def read_data(
    store: LocalStore,
    idx: int,
    # these parameters are needed for the data reading function signature
    float_dtype: torch.dtype,
    int_dtype: torch.dtype,
    validate_data: bool,
) -> dict[str, Any]:
    """Read raw data from file"""

    # create an empty dict to hold the data
    data: dict[str, Any] = dict()

    # open the cset group for this sample. Group names start at cset_1 (1-based,
    # as written from Julia), hence idx + 1. We always get the whole store, so we
    # do not depend on a root group being present; Julia does not build one by default
    cset_group = zarr.open_group(store, path=f"cset_{idx + 1}", mode="r")

    # read the entire cset group into the data dict
    QG.zarr_group_to_dict(cset_group, data)

    return data


# transform into a Data object

The second function takes this dictionary and transforms it into a torch_geometric Data object. This function has a fixed signature that needs to be adhered to: 
```python
def collect_data(data: dict) -> Data
```

i.e., it transforms a dictionary into a `Data` object. 
What it does internally with the data dictionary is up to the author of the function. In our case, we extract the adjacency (link) matrix, node features, and labels from the dictionary and convert them into the appropriate format for a `Data` object.
The function can also do more complex things, such as computing eigenvalues, normalizing features, or adding extra attributes to the `Data` object.
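
As an illustration of such extra processing, here are two small helpers that could be called inside `collect_data`. Both are hypothetical and not part of QuantumGrav: `normalize_features` standardizes each node-feature column (e.g. the in/out degrees), and `laplacian_spectrum` computes the smallest eigenvalues of the graph Laplacian of the symmetrized link matrix, which could be stored as an extra attribute on the `Data` object.

```python
import torch


def normalize_features(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical helper: z-score each column of a [N, F] node-feature matrix."""
    mean = x.mean(dim=0, keepdim=True)
    std = ((x - mean) ** 2).mean(dim=0, keepdim=True).sqrt()  # population std per column
    return (x - mean) / (std + eps)  # eps avoids division by zero for constant columns


def laplacian_spectrum(adj: torch.Tensor, k: int = 4) -> torch.Tensor:
    """Hypothetical helper: k smallest eigenvalues of L = D - A for the symmetrized graph."""
    a = ((adj + adj.T) > 0).to(torch.float64)  # ignore edge direction, drop weights
    lap = torch.diag(a.sum(dim=1)) - a         # degree matrix minus adjacency
    eigvals = torch.linalg.eigvalsh(lap)       # symmetric input -> real, ascending eigenvalues
    return eigvals[:k]
```

For the three-element chain with link matrix `[[0, 1, 0], [0, 0, 1], [0, 0, 0]]`, the symmetrized Laplacian has eigenvalues 0, 1 and 3. A dense eigendecomposition scales cubically with the number of nodes, which is one reason to compute such features once and cache them rather than on every access.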

In [None]:

def collect_data(data: dict) -> Data:
    """Transform the data dictionary into a PyG Data object. Caching, if any, happens when this is used as a pre_transform."""

    adj = data["linkmat"]

    # normally, a transpose would be necessary here b/c Julia stores column major
    # and Python/numpy/torch row major by default but torch sparse does a transpose
    # again, so we can leave it as is...
    edge_index, _ = dense_to_sparse(torch.tensor(adj, dtype=torch.float32))

    # in and out degrees of the link matrix as node features
    x = torch.tensor(
        np.array([
            data["in_degrees_link"],
            data["out_degrees_link"],
        ]),
        dtype=torch.float32,
    ).t()  # shape [N, 2]: one row per node, columns (in-degree, out-degree)

    # manifold like as target
    y = torch.tensor(
        np.array([
            data["manifold_like"],
        ]),
        dtype=torch.float32,
    )

    # build the final PyG data object
    tgdata = Data(
        edge_index=edge_index,
        x=x,
        y=y,
        csettype = torch.tensor(data["cset_type"]),
        atom_count=torch.tensor(adj.shape[0]),
    )

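    # validate() raises by default; with raise_on_error=False it returns a bool instead
    if not tgdata.validate(raise_on_error=False):
        raise ValueError(f"Data validation failed for {tgdata}.")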

    return tgdata


Why the split? Because you have a choice of when to apply the data transformations. You can do it once when reading the data from disk and cache the result (then it's called `pre_transform`) or you can do it on-the-fly during training (then it's called `transform`). What you choose depends on the complexity of the transformations and the size of your dataset, but also on the workflow you are using. If you are experimenting more on the data side, it can be more useful to do it on-the-fly. If the transformations are expensive, it can be better to do it once and cache the result. 
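
The trade-off can be illustrated with a toy sketch in plain Python. The two classes below are hypothetical stand-ins, not QuantumGrav's or PyG's dataset classes, and real `pre_transform` results are typically written to disk rather than kept in memory as here. Counting how often an expensive transform runs over three epochs of four samples shows the difference:

```python
from typing import Any, Callable


class CachedDataset:
    """Hypothetical: run the expensive transform once per sample up front (pre_transform style)."""

    def __init__(self, raw: list[Any], transform: Callable[[Any], Any]):
        self.items = [transform(r) for r in raw]

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, i: int) -> Any:
        return self.items[i]


class OnTheFlyDataset:
    """Hypothetical: run the transform every time a sample is accessed (transform style)."""

    def __init__(self, raw: list[Any], transform: Callable[[Any], Any]):
        self.raw = raw
        self.transform = transform

    def __len__(self) -> int:
        return len(self.raw)

    def __getitem__(self, i: int) -> Any:
        return self.transform(self.raw[i])


calls = {"n": 0}


def expensive(x: int) -> int:
    calls["n"] += 1  # count how often the transform actually runs
    return x * x


def count_calls(dataset_cls: type, epochs: int = 3) -> int:
    calls["n"] = 0
    ds = dataset_cls(list(range(4)), expensive)
    for _ in range(epochs):
        _ = [ds[i] for i in range(len(ds))]  # one pass over the data = one epoch
    return calls["n"]


print(count_calls(CachedDataset), count_calls(OnTheFlyDataset))  # 4 12
```

The cached variant pays the cost once per sample (4 calls), the on-the-fly variant once per access (12 calls), in exchange for being able to change the transform between runs without rebuilding a cache.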

Next, we need a set of functions that define the loss function to use to train the model, and functions that compute the monitoring metrics from the model outputs that we want to track during validation and testing. 

# define loss function


First, we define the loss function to use during training. Here we use binary cross-entropy on logits (`torch.nn.BCEWithLogitsLoss`), the standard choice for binary classification. It applies the sigmoid internally, which is numerically more stable than a separate sigmoid followed by `BCELoss`. The loss function takes the model outputs and the true labels as inputs and computes the loss value. Because the model outputs a dictionary of tensors (one per task head), we need to extract the relevant tensor for the loss computation. 

The loss function must be appropriate for the task at hand. For multi-class classification, for example, you would use a different loss function, such as cross-entropy loss (`torch.nn.CrossEntropyLoss`). 
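
For comparison, a multi-class variant could look like the sketch below. It is hypothetical and deliberately simplified: it takes the target tensor directly instead of a `Data` object, and it assumes the head at key `0` returns raw logits of shape `[B, C]` and the targets are integer class indices, which is what `torch.nn.CrossEntropyLoss` expects (it applies the softmax internally, analogous to the sigmoid in `BCEWithLogitsLoss`).

```python
import torch


def compute_loss_multiclass(
    predictions: dict[int, torch.Tensor],
    targets: torch.Tensor,
) -> torch.Tensor:
    """Hypothetical multi-class loss: logits [B, C] at key 0, integer class targets [B]."""
    logits = predictions[0]
    # CrossEntropyLoss expects class indices of shape [B], not float labels of shape [B, 1]
    tgt = targets.reshape(-1).long()
    return torch.nn.CrossEntropyLoss(reduction="mean")(logits, tgt)
```

With all-zero logits over three classes, every class gets probability 1/3, so the loss is log 3 ≈ 1.0986 regardless of the targets.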

Here, we also added checks for not-a-number (NaN) values in the model outputs and labels to ensure that the loss computation is valid. If NaN values are detected, an error is raised to prevent invalid computations during training. 


In [None]:
def compute_loss(
    predictions: dict[Any, torch.Tensor],
    data: Data,
    *args,
    **kwargs,
) -> torch.Tensor:
    """Compute loss for a single classifier task at index 0"""

    pred = predictions[0]
    tgt = data.y

    # check for NaNs. These checks add Python overhead (and a device sync on GPU); comment them out for speed
    if torch.isnan(pred).any():
        raise ValueError("Nan in predictions")

    if torch.isnan(tgt).any():
        raise ValueError("Nan in targets")

    return torch.nn.BCEWithLogitsLoss(
        reduction="mean",
    )(pred, tgt)



# Monitor functions

Just like the loss function computes how 'well' a model is performing during training in order to compute the gradients for optimization, monitor functions compute metrics that help us understand how well the model is performing on validation and test datasets without affecting the training process itself. These therefore use metrics that measure model performance more directly, such as accuracy, precision, recall, F1-score, etc. 

In the same way as the loss function, the monitor functions take the model outputs and true labels as inputs and compute the desired metrics. Again, we need to extract the relevant tensor from the model outputs for the computation.

Here we monitor two versions of the [F1 score](https://en.wikipedia.org/wiki/F-score), a standard metric for classification tasks. The first is the weighted F1 score, which accounts for class imbalance by weighting each class's F1 score by its support, i.e. the number of true samples of that class. This is useful for imbalanced datasets where some classes are more frequent than others. It is computed in the function `f1_monitor`. 
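
To see what the weighted average does, here is a small check with toy labels (hypothetical data, not from this dataset), comparing sklearn's per-class, weighted and macro F1 scores:

```python
import numpy as np
from sklearn.metrics import f1_score

# Toy labels (hypothetical): class 0 has 4 true samples, class 1 has 2
y_true = np.array([0, 0, 0, 0, 1, 1])
y_pred = np.array([0, 0, 0, 1, 1, 0])

per_class = f1_score(y_true, y_pred, average=None)      # one score per class
weighted = f1_score(y_true, y_pred, average="weighted")  # support-weighted mean
macro = f1_score(y_true, y_pred, average="macro")        # unweighted mean

support = np.bincount(y_true)  # number of true samples per class: [4, 2]
print(per_class, weighted, macro)
# the weighted score is the per-class F1 averaged with the class supports as weights
print(np.isclose(weighted, np.average(per_class, weights=support)))
```

Class 0 scores 0.75 and class 1 scores 0.5, so the weighted F1 is (4 · 0.75 + 2 · 0.5) / 6 ≈ 0.667, while the unweighted macro average is 0.625: the weighted score leans towards the majority class.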


In [None]:
def f1_monitor(
    predictions: list[dict[Any, torch.Tensor]],
    targets: list[torch.Tensor],
) -> float:
    """Compute F1 score for a single classifier task (key 0)."""

    logits_list: list[torch.Tensor] = []
    for p in predictions:
        if not isinstance(p, dict):
            raise TypeError(f"Expected dict of tensors, got {type(p)}")
        if 0 in p:
            t = p[0]
        elif "0" in p:
            t = p["0"]
        else:
            t = next(iter(p.values()))
        logits_list.append(t.detach())

    pred_logits = torch.cat([t.reshape(-1) for t in logits_list])
    tgt_vec = torch.cat([t.detach().reshape(-1) for t in targets])

    pred_labels = (torch.sigmoid(pred_logits) >= 0.5).to(torch.int32).cpu().numpy()
    tgt_labels = tgt_vec.to(torch.int32).cpu().numpy()

    return f1_score(tgt_labels, pred_labels, average="weighted")



In addition, we monitor the F1 score per class, which shows how well the model performs on each individual class. This is particularly useful for identifying classes that are harder for the model to classify correctly. This is done in the function `f1_monitor_perclass` below. The only difference from the previous function is that it passes `average=None` instead of `average="weighted"` to sklearn's `f1_score`, which then returns an array with one F1 score per class instead of a single averaged value. 

The disadvantage of this approach is that some work is done twice, for instance assembling the true and predicted labels into numpy arrays. However, QuantumGrav allows you to build your own evaluator classes, so if this becomes a bottleneck you can implement a custom evaluator that computes both metrics in one pass. This is covered elsewhere. 
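
Without going into QuantumGrav's evaluator API, the one-pass idea itself can be sketched as a plain function that builds the label arrays once and computes both scores from them. The names `_pick_head` and `f1_both` are hypothetical, and whether a monitor may return a dictionary like this depends on how your evaluator consumes monitor results, which is an assumption here.

```python
from typing import Any

import numpy as np
import torch
from sklearn.metrics import f1_score


def _pick_head(p: dict[Any, torch.Tensor]) -> torch.Tensor:
    """Select the classifier head with the same lookup order as the monitors: 0, "0", first value."""
    if 0 in p:
        return p[0]
    if "0" in p:
        return p["0"]
    return next(iter(p.values()))


def f1_both(
    predictions: list[dict[Any, torch.Tensor]],
    targets: list[torch.Tensor],
    threshold: float = 0.5,
) -> dict[str, Any]:
    """Hypothetical one-pass monitor: weighted and per-class F1 from a single label assembly."""
    logits = torch.cat([_pick_head(p).detach().reshape(-1) for p in predictions])
    tgt = torch.cat([t.detach().reshape(-1) for t in targets])

    # assemble the label arrays once ...
    pred_labels = (torch.sigmoid(logits) >= threshold).to(torch.int32).cpu().numpy()
    tgt_labels = tgt.to(torch.int32).cpu().numpy()

    # ... and reuse them for both metrics
    return {
        "f1_weighted": float(f1_score(tgt_labels, pred_labels, average="weighted")),
        "f1_per_class": np.asarray(f1_score(tgt_labels, pred_labels, average=None)),
    }
```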

In [None]:

def f1_monitor_perclass(
    predictions: list[dict[Any, torch.Tensor]],
    targets: list[torch.Tensor],
) -> np.ndarray:
    """Compute per-class F1 scores for a single classifier task (key 0)."""

    logits_list: list[torch.Tensor] = []
    for p in predictions:
        if not isinstance(p, dict):
            raise TypeError(f"Expected dict of tensors, got {type(p)}")
        if 0 in p:
            t = p[0]
        elif "0" in p:
            t = p["0"]
        else:
            t = next(iter(p.values()))
        logits_list.append(t.detach())

    pred_logits = torch.cat([t.reshape(-1) for t in logits_list])
    tgt_vec = torch.cat([t.detach().reshape(-1) for t in targets])

    pred_labels = (torch.sigmoid(pred_logits) >= 0.5).to(torch.int32).cpu().numpy()
    tgt_labels = tgt_vec.to(torch.int32).cpu().numpy()

    return f1_score(tgt_labels, pred_labels, average=None)

# Set up trainer and datasets 

The trainer class is built from a config file that defines the model, optimizer, training epochs, dataset, evaluators, early stopping system, and loss function: everything we need to build a functioning training pipeline. 

In [None]:
config_path = Path("../configs/train_classifier.yaml")

with open(config_path, "r") as configfile:
    config = yaml.load(configfile, QG.get_loader())


Let's have a look at what the config contains.

In [None]:
pretty_print(config)


In [None]:
pretty_print(config["training"])

In [None]:
pretty_print(config["validation"]["validator"])

In [None]:
pretty_print(config["testing"]["tester"])

In [None]:
pretty_print(config["early_stopping"])

In [None]:
# set up trainer
trainer = QG.Trainer.from_config(config)

# Prepare dataset and dataloaders with weighted sampler

This time, we first prepare the datasets explicitly by calling `trainer.prepare_dataset()`. This returns the train, validation, and test datasets, which we then use to build our dataloaders. We will build a weighted sampler for the training dataloader to deal with imbalanced data.

In [None]:
train_dataset, valid_dataset, test_dataset = trainer.prepare_dataset()

First, we need to build the weights for the sampler. We do this by computing class weights from the frequency of each class in the training dataset and then assigning each sample the weight of its class label. For our binary classification task (manifold-like or not), there are only two classes, so we compute a weight for each and assign them accordingly.

In [None]:
# collect training labels for the weighted sampler
# (this accesses every training sample once, so it can take a while for large datasets)
train_labels = torch.tensor(
    [train_dataset[i].y[:, 0].long().item() for i in range(len(train_dataset))]
)

Next, we count the samples per class and compute each class weight as the inverse of its count. Finally, we assign each sample in the training dataset the weight of its class.

In [None]:
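class_counts = torch.bincount(train_labels)
# inverse class frequency: every class gets the same total weight, so the sampler
# draws each class with equal probability
class_weights = 1.0 / class_counts.float()
sample_weights = class_weights[train_labels]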
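
As a sanity check of the weighting, here is a toy example with a hypothetical 9:1 class imbalance. With inverse-frequency weights, each class receives the same total weight, so a `WeightedRandomSampler` draws both classes about equally often. For two classes, weights of `1 - frequency` lead to exactly the same normalized sampling probabilities, since the sampler only uses the weights relative to each other.

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Hypothetical toy labels with a 9:1 imbalance: 90 samples of class 0, 10 of class 1
labels = torch.cat([torch.zeros(90, dtype=torch.long), torch.ones(10, dtype=torch.long)])

counts = torch.bincount(labels)         # tensor([90, 10])
inv_w = (1.0 / counts.float())[labels]  # inverse-frequency weight for every sample

# every class ends up with the same total weight (1.0 each here)
per_class_total = torch.stack([inv_w[labels == c].sum() for c in range(len(counts))])

# for two classes, "1 - frequency" weights give the same normalized probabilities
freq_w = (1.0 - counts.float() / len(labels))[labels]
same_probs = torch.allclose(inv_w / inv_w.sum(), freq_w / freq_w.sum())

# draw with a seeded sampler and measure the empirical share of class 1
sampler = WeightedRandomSampler(
    weights=inv_w,
    num_samples=20_000,
    replacement=True,
    generator=torch.Generator().manual_seed(0),
)
drawn = labels[torch.tensor(list(sampler))]
minority_share = drawn.float().mean().item()
print(per_class_total, same_probs, f"{minority_share:.3f}")
```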

Finally, we set up the sampler. Note that we use weighted sampling only for the training dataset: resampling is only needed during training, while the validation and test sets keep their natural class distribution.

In [None]:
weighted_sampler = torch.utils.data.WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
)

... and now pass it to the dataloader preparation function, together with the already built train, validation, and test datasets. Note that the `shuffle` option cannot be used together with the weighted sampler, because the sampler already defines the sampling strategy (PyTorch's `DataLoader` raises a `ValueError` if both are given).
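
To see the `shuffle` restriction in isolation, here is a small sketch with a hypothetical toy dataset. It uses `torch.utils.data.DataLoader` directly; that `trainer.prepare_dataloaders` builds a PyTorch or PyG loader under the hood is an assumption, but PyG's `DataLoader` subclasses PyTorch's, so the same check applies there.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Hypothetical toy dataset of 10 samples with uniform sampling weights
toy_ds = TensorDataset(torch.arange(10))
toy_sampler = WeightedRandomSampler(torch.ones(10), num_samples=10, replacement=True)

# combining a sampler with shuffle=True is rejected when the loader is constructed
try:
    DataLoader(toy_ds, batch_size=4, shuffle=True, sampler=toy_sampler)
    shuffle_rejected = False
except ValueError as err:
    shuffle_rejected = True
    print(f"ValueError: {err}")

# without shuffle, the sampler alone decides the order: 10 draws -> 3 batches of size <= 4
loader = DataLoader(toy_ds, batch_size=4, sampler=toy_sampler)
print(len(loader))
```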

In [None]:
# get dataloaders
train_loader, valid_loader, test_loader = trainer.prepare_dataloaders(
    train_dataset=train_dataset,
    val_dataset=valid_dataset,
    test_dataset=test_dataset,
    training_sampler=weighted_sampler,
)

In this case, the dataset contains 55,000 samples, split into 80% training, 10% validation, and 10% test data (about 44,000 / 5,500 / 5,500 samples).

In [None]:
len(train_loader.dataset)

In [None]:
len(valid_loader.dataset)

In [None]:
len(test_loader.dataset)

Have a look at [the PyTorch data documentation](https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.SequentialSampler) to see what other kinds of samplers are available, or how to build your own if needed. 
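
As an example of a custom sampler, the sketch below subclasses `torch.utils.data.Sampler` and implements `__iter__` and `__len__`, which is all PyTorch's `DataLoader` needs. It swaps in a different balancing technique than the one above: random undersampling, where each epoch uses only as many samples per class as the smallest class has, drawn without replacement. The class `BalancedUndersampler` is hypothetical, and it assumes you already have a 1D tensor of integer labels such as `train_labels`.

```python
import torch
from torch.utils.data import Sampler


class BalancedUndersampler(Sampler[int]):
    """Hypothetical sampler: per epoch, draw min-class-size indices from every class
    without replacement, then shuffle them together."""

    def __init__(self, labels: torch.Tensor, seed: int = 0):
        self.labels = labels
        self.classes = torch.unique(labels)
        # number of samples taken from each class = size of the smallest class
        self.per_class = int(torch.bincount(labels)[self.classes].min())
        self.generator = torch.Generator().manual_seed(seed)

    def __iter__(self):
        chosen = []
        for c in self.classes:
            idx = torch.nonzero(self.labels == c).flatten()  # all indices of class c
            perm = torch.randperm(len(idx), generator=self.generator)[: self.per_class]
            chosen.append(idx[perm])
        all_idx = torch.cat(chosen)
        # shuffle so that batches mix the classes
        all_idx = all_idx[torch.randperm(len(all_idx), generator=self.generator)]
        return iter(all_idx.tolist())

    def __len__(self) -> int:
        return self.per_class * len(self.classes)
```

It would be passed in the same way as the weighted sampler, e.g. `DataLoader(dataset, batch_size=..., sampler=BalancedUndersampler(labels))`. Compared with `WeightedRandomSampler`, undersampling never repeats a sample within an epoch but shrinks the epoch to `num_classes × smallest class size` samples; because the generator is reused, each epoch draws a different subset of the majority class.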

# run training

Now, the training is run the usual way: 

In [None]:
# initialize model
trainer.initialize_model()


In [None]:
# initialize optimizer
trainer.initialize_optimizer()

In [None]:
train_results, validation_results = trainer.run_training(
    train_loader=train_loader,
    val_loader=valid_loader,
)

# test model

After training, the model is evaluated on the test set, which it has not seen during training. The results tell us how well the model generalizes to unseen data. 

In [None]:
test_results = trainer.run_test(test_loader)

In [None]:
test_results

The test results are mediocre, so the model does not perform particularly well yet.