# How to track experiments

Experiment tracking helps compare model variants and keep a record of hyperparameters and training metrics. By default, `sbi` logs to TensorBoard. You can also bring your own tracker by implementing the lightweight `Tracker` protocol and passing it as `tracker=...`.

If you use your own tracker (e.g., `wandb`, `mlflow`, or `trackio`), you are responsible for the run lifecycle (e.g., `wandb.init`, `mlflow.start_run`); `sbi` does not start or finish runs for you.

## Define a minimal training setup

```python
import torch

from sbi.inference import NPE
from sbi.neural_nets import posterior_nn
from sbi.neural_nets.embedding_nets import FCEmbedding
from sbi.utils import BoxUniform

torch.manual_seed(0)

def simulator(theta):
    return theta + 0.1 * torch.randn_like(theta)

prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))

theta = prior.sample((5000,))
x = simulator(theta)

embedding_net = FCEmbedding(input_dim=x.shape[1], output_dim=32)
density_estimator = posterior_nn(
    model="nsf",
    embedding_net=embedding_net,
    num_transforms=4,
)

train_kwargs = dict(
    max_num_epochs=50,
    training_batch_size=128,
    validation_fraction=0.1,
    show_train_summary=False,
)
```

## Train with a tracker

By default, `sbi` uses a TensorBoard tracker to log the training loss, validation loss,
number of epochs, and more.

To track additional quantities, such as hyperparameters, instantiate the tracker yourself
and pass it to the inference class:

```python
from torch.utils.tensorboard.writer import SummaryWriter

from sbi.utils.tracking import TensorBoardTracker

tracker = TensorBoardTracker(SummaryWriter("sbi-logs"))
tracker.log_params({"embedding_dim": 32, "num_transforms": 4})

# Pass the density estimator defined above so that the logged
# hyperparameters describe the model that is actually trained.
inference = NPE(prior=prior, density_estimator=density_estimator, tracker=tracker)
inference.append_simulations(theta, x)
estimator = inference.train(**train_kwargs)
posterior = inference.build_posterior(estimator)
```

## View TensorBoard results

You can then inspect your tracked run(s) in the TensorBoard dashboard, which is served
on localhost and opened in the browser. By default, `sbi` creates a log directory
`sbi-logs` in the directory from which the training script was called.

```bash
tensorboard --logdir=sbi-logs
```

## Using other trackers

To support other trackers, `sbi` provides a lightweight `Tracker` protocol that trackers
need to follow. You can implement a small adapter that satisfies this protocol and pass
it as `tracker=`. Below are minimal examples for common tools.

```python
# W&B adapter (requires `wandb.init()` before training)
class WandBAdapter:
    log_dir = None

    def __init__(self, run):
        self._run = run

    def log_metric(self, name, value, step=None):
        self._run.log({name: value}, step=step)

    def log_metrics(self, metrics, step=None):
        self._run.log(metrics, step=step)

    def log_params(self, params):
        self._run.config.update(params)

    def add_figure(self, name, figure, step=None):
        import wandb
        self._run.log({name: wandb.Image(figure)}, step=step)

    def flush(self):
        pass
```

```python
# MLflow adapter (configure tracking URI separately)
class MLflowAdapter:
    log_dir = None

    def __init__(self, mlflow):
        self._mlflow = mlflow

    def log_metric(self, name, value, step=None):
        self._mlflow.log_metric(name, value, step=step)

    def log_metrics(self, metrics, step=None):
        for name, value in metrics.items():
            self.log_metric(name, value, step=step)

    def log_params(self, params):
        self._mlflow.log_params(params)

    def add_figure(self, name, figure, step=None):
        self._mlflow.log_figure(figure, f"{name}.png")

    def flush(self):
        pass
```

```python
# Trackio adapter (requires `trackio.init()` before training)
class TrackioAdapter:
    log_dir = None

    def __init__(self, trackio):
        self._trackio = trackio

    def log_metric(self, name, value, step=None):
        self._trackio.log({name: value}, step=step)

    def log_metrics(self, metrics, step=None):
        self._trackio.log(metrics, step=step)

    def log_params(self, params):
        self._trackio.log(params)

    def add_figure(self, name, figure, step=None):
        self._trackio.log_image(figure, name=name, step=step)

    def flush(self):
        pass
```
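
For unit tests or quick debugging, a dependency-free adapter can follow the same shape as the three adapters above. The sketch below is an illustration, not part of `sbi`: `TrackerLike` is a hypothetical stand-in for the interface inferred from those adapters (the actual `Tracker` protocol in `sbi` may differ), `InMemoryTracker` is a hypothetical name, and the metric names are placeholders.

```python
from typing import Any, Dict, List, Optional, Protocol, Tuple, runtime_checkable


@runtime_checkable
class TrackerLike(Protocol):
    """Hypothetical mirror of the interface the adapters above implement."""

    log_dir: Optional[str]

    def log_metric(self, name: str, value: float, step: Optional[int] = None) -> None: ...

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: ...

    def log_params(self, params: Dict[str, Any]) -> None: ...

    def add_figure(self, name: str, figure: Any, step: Optional[int] = None) -> None: ...

    def flush(self) -> None: ...


class InMemoryTracker:
    """Keeps everything in plain Python containers (assumed helper, not in sbi)."""

    log_dir = None

    def __init__(self) -> None:
        self.metrics: Dict[str, List[Tuple[Optional[int], float]]] = {}
        self.params: Dict[str, Any] = {}
        self.figures: Dict[str, Any] = {}

    def log_metric(self, name, value, step=None):
        # Store (step, value) pairs per metric name.
        self.metrics.setdefault(name, []).append((step, float(value)))

    def log_metrics(self, metrics, step=None):
        for name, value in metrics.items():
            self.log_metric(name, value, step=step)

    def log_params(self, params):
        self.params.update(params)

    def add_figure(self, name, figure, step=None):
        # Keep only the latest figure per name; the step is ignored here.
        self.figures[name] = figure

    def flush(self):
        pass  # nothing is buffered


tracker = InMemoryTracker()
tracker.log_params({"num_transforms": 4})
tracker.log_metrics({"training_loss": 1.5, "validation_loss": 1.7}, step=0)
tracker.log_metric("training_loss", 1.2, step=1)
```

After a run, `tracker.metrics` and `tracker.params` can be inspected directly, which makes it easy to check what a training call would send to a real backend.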

When using external trackers, create an adapter instance and pass it to `tracker=`:

```python
import wandb

# wandb.init(...)  # start the run yourself before creating the adapter
tracker = WandBAdapter(wandb.run)
inference = NPE(prior=prior, density_estimator=density_estimator, tracker=tracker)
```
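
Because `tracker=` takes a single object, logging to several backends at once needs a small fan-out wrapper. The following is a minimal sketch under that assumption: `MultiTracker` and `RecordingTracker` are hypothetical names, not `sbi` classes, and the recording stub only stands in for real adapters such as those above.

```python
class RecordingTracker:
    """Tiny stand-in backend that records calls; used only for this demo."""

    log_dir = None

    def __init__(self):
        self.calls = []

    def log_metric(self, name, value, step=None):
        self.calls.append(("metric", name, value, step))

    def log_metrics(self, metrics, step=None):
        self.calls.append(("metrics", dict(metrics), step))

    def log_params(self, params):
        self.calls.append(("params", dict(params)))

    def add_figure(self, name, figure, step=None):
        self.calls.append(("figure", name, step))

    def flush(self):
        self.calls.append(("flush",))


class MultiTracker:
    """Forwards every call to each wrapped tracker (hypothetical helper)."""

    log_dir = None

    def __init__(self, *trackers):
        self._trackers = trackers

    def log_metric(self, name, value, step=None):
        for tracker in self._trackers:
            tracker.log_metric(name, value, step=step)

    def log_metrics(self, metrics, step=None):
        for tracker in self._trackers:
            tracker.log_metrics(metrics, step=step)

    def log_params(self, params):
        for tracker in self._trackers:
            tracker.log_params(params)

    def add_figure(self, name, figure, step=None):
        for tracker in self._trackers:
            tracker.add_figure(name, figure, step=step)

    def flush(self):
        for tracker in self._trackers:
            tracker.flush()


first, second = RecordingTracker(), RecordingTracker()
multi = MultiTracker(first, second)
multi.log_params({"num_transforms": 4})
multi.log_metric("validation_loss", 0.8, step=3)
multi.flush()
```

In practice the wrapped objects would be real adapters, for example a `WandBAdapter` and an `MLflowAdapter`, and `multi` would be passed as `tracker=`.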

## Log figures

Trackers can also store matplotlib figures. For example, after training you can log a pairplot:

```python
from sbi.analysis import pairplot

x_o = x[:1]
samples = posterior.sample((1000,), x=x_o)
fig, _ = pairplot(samples)
tracker.add_figure("posterior_pairplot", fig, step=0)
```

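How a figure is stored depends on the tracker implementation (e.g., `wandb.Image`, `mlflow.log_figure`). The MLflow adapter above, for instance, ignores `step` and saves the figure as an artifact named after `name`.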

## Custom training loop (optional)

If you want to log custom diagnostics per epoch, see the [training interface tutorial](https://sbi.readthedocs.io/en/latest/advanced_tutorials/18_training_interface.html).

## Notes

- Each tool supports richer logging (artifacts, checkpoints, plots), but the patterns above are enough to track hyperparameters, epoch-wise losses, and validation metrics.
- If you already use Optuna or other sweep tools, you can call the tracker inside the objective function to log each trial.
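
The sweep pattern from the last note can be sketched without any sweep library. In the example below a plain grid loop stands in for Optuna, `ListTracker` is a hypothetical dependency-free adapter, and `placeholder_validation_loss` is a made-up function that stands in for a real training call; none of these names come from `sbi` or Optuna.

```python
import itertools


class ListTracker:
    """Dependency-free stand-in for any of the adapters shown earlier."""

    log_dir = None

    def __init__(self):
        self.params = {}
        self.metrics = []

    def log_metric(self, name, value, step=None):
        self.metrics.append((name, value, step))

    def log_metrics(self, metrics, step=None):
        for name, value in metrics.items():
            self.log_metric(name, value, step=step)

    def log_params(self, params):
        self.params.update(params)

    def add_figure(self, name, figure, step=None):
        pass  # figures are not needed in this sketch

    def flush(self):
        pass


def placeholder_validation_loss(learning_rate, num_transforms):
    """Made-up stand-in; replace with real training and its validation loss."""
    return (learning_rate - 0.01) ** 2 + 1.0 / num_transforms


def objective(params, tracker, trial_id):
    # Prefix parameter names so that trials do not overwrite each other.
    tracker.log_params({f"trial_{trial_id}/{key}": val for key, val in params.items()})
    loss = placeholder_validation_loss(**params)
    # Use the trial index as the step so trials line up on one curve.
    tracker.log_metric("validation_loss", loss, step=trial_id)
    return loss


grid = {"learning_rate": [0.001, 0.01], "num_transforms": [2, 4]}
tracker = ListTracker()
results = []
for trial_id, values in enumerate(itertools.product(*grid.values())):
    params = dict(zip(grid.keys(), values))
    results.append((objective(params, tracker, trial_id), params))
tracker.flush()

best_loss, best_params = min(results, key=lambda item: item[0])
```

With W&B or MLflow, an alternative design is to open one run per trial inside the objective, which keeps trials separate in the tracker's UI at the cost of managing the run lifecycle per trial.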