## K - Logging with WandB

Author: Christoph Weniger

Last update: 18 September 2023

In [None]:
import numpy as np
import pylab as plt
import torch
import wandb
from pytorch_lightning.loggers import WandbLogger

import swyft
# Lightning accelerator name: prefer the GPU whenever CUDA is usable, else run on the CPU.
if torch.cuda.is_available():
    DEVICE = 'gpu'
else:
    DEVICE = 'cpu'
print(DEVICE)

In [None]:
# Seed torch first, then NumPy (same order as before), so that network
# initialisation and the simulated data below are reproducible.
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(0)

In [None]:
N = 10_000  # number of simulated (z, x) training pairs
z = 2 * np.random.rand(N, 1) - 1  # parameters: uniform prior on [-1, 1]
x = z + 0.2 * np.random.randn(N, 1)  # observations: z plus Gaussian noise with std 0.2

In [None]:
# Bundle observations `x` and parameters `z` into a swyft training set.
samples = swyft.Samples(x=x, z=z)

In [None]:
class Network(swyft.SwyftModule):
    """Ratio estimator for the one-dimensional toy problem.

    Holds a single ``swyft.LogRatioEstimator_1dim`` configured with one
    feature, one parameter named ``'z'`` and four blocks.
    """

    def __init__(self):
        super().__init__()
        estimator_kwargs = dict(num_features=1, num_params=1, varnames='z', num_blocks=4)
        self.logratios = swyft.LogRatioEstimator_1dim(**estimator_kwargs)

    def forward(self, A, B):
        """Evaluate the estimator on observations ``A['x']`` and parameters ``B['z']``."""
        features = A['x']
        params = B['z']
        return self.logratios(features, params)

In [None]:
# Train the same network from scratch five times. Each repetition is logged as
# its own W&B run ("round0" ... "round4") in project '0K-WandB', group
# 'experiment_1', together with a posterior plot for the observation x0.
for i in range(5):
    wandb.init(reinit=True, project='0K-WandB', group='experiment_1', name=f'round{i}')
    try:
        # Lightning logger for this round; log_model='all' requests checkpoint
        # uploads to W&B (see WandbLogger docs).
        wandb_logger = WandbLogger(log_model='all')
        trainer = swyft.SwyftTrainer(accelerator=DEVICE, precision=64, logger=wandb_logger, max_epochs=-1)

        dm = swyft.SwyftDataModule(samples, batch_size=128)
        network = Network()
        trainer.fit(network, dm)

        # Infer the posterior of z for a single observation x0, using fresh
        # draws from the same uniform prior on [-1, 1] as the training data.
        x0 = 0.0
        obs = swyft.Sample(x=np.array([x0]))
        prior_samples = swyft.Samples(z=np.random.rand(30_000, 1)*2-1)
        predictions = trainer.infer(network, obs, prior_samples)

        fig = swyft.plot_posterior(predictions, 'z[0]', smooth=10, smooth_prior=True)
        # Dotted green guides at multiples of 0.2 around x0 (presumably 1-3 sigma
        # of the 0.2 observation noise); solid line at x0 itself.
        for offset in (-0.6, -0.4, -0.2, 0, 0.2, 0.4, 0.6):
            plt.axvline(x0 + offset, color='g', ls=':')
        plt.axvline(x0)
        wandb.log({"z[0]": wandb.Image(fig)})
        # The figure is now stored in W&B; close it so open figures do not pile up across rounds.
        plt.close(fig)
    finally:
        # Always close the run, even if training/inference raised, so no run is left dangling.
        wandb.finish()