# Plotting Tensorboard Summaries

`sbi` can log training progress to Tensorboard. You can inspect these logs by running `tensorboard --logdir=...` in your terminal, or use `sbi.analysis.plot_summary` to plot the logged metrics directly in a notebook or Python script.

This guide shows how to use `plot_summary` to visualize training and validation losses, as well as other logged metrics.


In [None]:
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from sbi.analysis import plot_summary

# Create dummy tensorboard logs for demonstration
log_dir = "dummy_logs_guide"
if Path(log_dir).exists():
    shutil.rmtree(log_dir)

writer = SummaryWriter(log_dir)
for i in range(100):
    # Simulate training loss decreasing
    writer.add_scalar("training_loss", np.exp(-0.05 * i) + np.random.normal(0, 0.01), i)
    # Simulate a validation loss that decreases but stays about 0.1 above the training loss
    writer.add_scalar("validation_loss", np.exp(-0.05 * i) + 0.1 + np.random.normal(0, 0.01), i)
    # Simulate some other metric increasing
    writer.add_scalar("accuracy", 1 - np.exp(-0.05 * i) + np.random.normal(0, 0.01), i)
writer.close()


## Basic Usage

If `tags` is not specified, `plot_summary` plots a single metric, `"validation_loss"`. To plot several metrics, pass a list of tags; each tag gets its own subplot.


In [None]:
# Plot training_loss and validation_loss in separate subplots
fig, axes = plot_summary(Path(log_dir), tags=["training_loss", "validation_loss"])
plt.show()
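For intuition, the default layout (one subplot per tag) can be approximated with plain matplotlib. The sketch below is only an illustration, not `sbi`'s implementation: the `logged` dictionary of noise-free curves is a hypothetical stand-in for scalars read from the event files, and the `"step"` x-axis label is our own choice.

```python
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-in for logged scalars: tag -> (steps, values).
steps = np.arange(100)
logged = {
    "training_loss": (steps, np.exp(-0.05 * steps)),
    "validation_loss": (steps, np.exp(-0.05 * steps) + 0.1),
}

tags = ["training_loss", "validation_loss"]

# One subplot per tag; squeeze=False keeps `axes` 2-D even for a single tag.
fig, axes = plt.subplots(1, len(tags), figsize=(12, 4), squeeze=False)
for ax, tag in zip(axes[0], tags):
    x, y = logged[tag]
    ax.plot(x, y)
    ax.set_xlabel("step")
    ax.set_ylabel(tag)
plt.show()
```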


## Grouping Tags

To plot several tags on the same axis, nest them as a list inside `tags`. Each entry of `tags` becomes one subplot, so the tags in a nested list share a subplot while a plain string gets its own. This is useful for comparing related metrics, such as training and validation loss.


In [None]:
# Plot training_loss and validation_loss on the same subplot, and accuracy on another
fig, axes = plot_summary(
    Path(log_dir),
    tags=[["training_loss", "validation_loss"], "accuracy"],
)
plt.show()
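The grouping rule can be stated as: each entry of `tags` is one subplot, a plain string is a group of one, and a nested list is a group whose curves share an axis. The hypothetical `group_tags` helper below expresses that rule and then draws each group with plain matplotlib. It is a sketch of the idea under these assumptions, not `sbi`'s actual code, and the `logged` curves are made up.

```python
from typing import List, Sequence, Union

import matplotlib.pyplot as plt
import numpy as np

TagSpec = Union[str, Sequence[str]]


def group_tags(tags: Sequence[TagSpec]) -> List[List[str]]:
    """Hypothetical helper: turn a mixed tag list into one group per subplot."""
    groups = []
    for spec in tags:
        # Check for str first: a string is itself a Sequence of characters.
        if isinstance(spec, str):
            groups.append([spec])
        else:
            groups.append(list(spec))
    return groups


# Hypothetical logged scalars: tag -> values indexed by step.
steps = np.arange(100)
logged = {
    "training_loss": np.exp(-0.05 * steps),
    "validation_loss": np.exp(-0.05 * steps) + 0.1,
    "accuracy": 1 - np.exp(-0.05 * steps),
}

groups = group_tags([["training_loss", "validation_loss"], "accuracy"])
print(groups)  # [['training_loss', 'validation_loss'], ['accuracy']]

fig, axes = plt.subplots(1, len(groups), figsize=(12, 4), squeeze=False)
for ax, group in zip(axes[0], groups):
    for tag in group:
        # Every tag in a group is drawn on the same axis.
        ax.plot(steps, logged[tag], label=tag)
    ax.legend()
plt.show()
```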


## Customization

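You can also set the figure size and font size, give each subplot its own y-axis label (`ylabel`), and add an overall title (`title`) as well as per-subplot titles (`titles`).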


In [None]:
fig, axes = plot_summary(
    Path(log_dir),
    tags=[["training_loss", "validation_loss"], "accuracy"],
    figsize=(12, 5),
    fontsize=14,
    ylabel=["Loss", "Accuracy"],  # one y-axis label per subplot
    title="Training Summary",  # overall figure title
    titles=["Losses", "Accuracy Metric"],  # one title per subplot
)
plt.show()
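In plain matplotlib terms, these options plausibly correspond to the calls below. This mapping is an assumption made for illustration: we treat `title` like `Figure.suptitle`, `titles` like one `Axes.set_title` call per subplot, and `ylabel` like one `Axes.set_ylabel` call per subplot. Check the `plot_summary` docstring for its exact behavior.

```python
import matplotlib.pyplot as plt
import numpy as np

steps = np.arange(100)
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
axes[0].plot(steps, np.exp(-0.05 * steps), label="training_loss")
axes[0].plot(steps, np.exp(-0.05 * steps) + 0.1, label="validation_loss")
axes[0].legend()
axes[1].plot(steps, 1 - np.exp(-0.05 * steps))

fontsize = 14
# Figure-level title: assumed analogue of `title`.
fig.suptitle("Training Summary", fontsize=fontsize)
for ax, subtitle, label in zip(axes, ["Losses", "Accuracy Metric"], ["Loss", "Accuracy"]):
    ax.set_title(subtitle, fontsize=fontsize)  # assumed analogue of `titles`
    ax.set_ylabel(label, fontsize=fontsize)  # assumed analogue of `ylabel`
plt.show()
```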


In [None]:
# Cleanup
if Path(log_dir).exists():
    shutil.rmtree(log_dir)
