# Storing training data on disk via Zarr

First we need some imports.

In [1]:
# Development convenience: with mode 2, autoreload re-imports edited modules
# before each cell is executed, so local changes are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
from scipy import stats
import pylab as plt
import torch
import torchist
import swyft

## Training data

Now we generate training data.  As a simple example, we consider the model

$$
x = z + \epsilon
$$

where the parameter $z \sim \mathcal{U}(0, 1)$ is drawn uniformly from the unit interval (this is what `np.random.rand` in the simulator below samples), and $\epsilon \sim \mathcal{N}(\mu = 0, \sigma = 0.1)$ is a small noise contribution.  We are interested in the posterior of $z$ given a measurement of the observable $x$.

In [4]:
class Simulator(swyft.Simulator):
    """Toy simulator for the additive-noise model ``x = z + epsilon``.

    ``z`` is drawn with ``np.random.rand`` (uniform on [0, 1)) and ``x`` is
    ``z`` plus Gaussian noise with standard deviation 0.1. Both are
    one-element arrays.
    """

    def __init__(self):
        super().__init__()
        # Hook applied to generated samples; going by its name,
        # swyft.to_numpy32 casts them to float32 numpy arrays
        # -- TODO confirm against the swyft documentation.
        self.transform_samples = swyft.to_numpy32

    def forward(self, trace):
        """Register the sampling steps for 'z' and 'x' on ``trace``."""

        def draw_z():
            # Parameter: a single value, uniform on [0, 1).
            return np.random.rand(1)

        def draw_x(z):
            # Observation: the parameter plus Gaussian noise (sigma = 0.1).
            return z + np.random.randn(1)*0.1

        z = trace.sample('z', draw_z)
        x = trace.sample('x', draw_x, z)


# Instantiate the simulator and query the shapes/dtypes of its outputs;
# they are needed to lay out the on-disk store.
sim = Simulator()
shapes, dtypes = sim.get_shapes_and_dtypes()

In [8]:
# Zarr-backed sample store located at the relative path "./zarr_store".
store = swyft.ZarrStore("./zarr_store")
# 10000 = total number of sample slots (filled below in 10 batches of 1000).
# 64 is presumably the Zarr chunk size -- TODO confirm against ZarrStore.init.
# shapes/dtypes describe the simulator outputs and come from the simulator cell.
store.init(10000, 64, shapes, dtypes)

<swyft.lightning.stores.ZarrStore at 0x14ea66fb6c40>

In [9]:
# Run the simulator and write the results into the store, 1000 samples per batch
# (the progress bars below show 10 such batches).
# NOTE(review): the remark about parallel execution likely refers to several
# workers/processes filling the same store rather than Python threads -- confirm.
store.simulate(sim, batch_size = 1000)  # This function can be run in parallel in many threads

100%|██████████| 1000/1000 [00:00<00:00, 49457.63it/s]
100%|██████████| 1000/1000 [00:00<00:00, 47723.83it/s]
100%|██████████| 1000/1000 [00:00<00:00, 43627.95it/s]
100%|██████████| 1000/1000 [00:00<00:00, 32552.85it/s]
100%|██████████| 1000/1000 [00:00<00:00, 44971.90it/s]
100%|██████████| 1000/1000 [00:00<00:00, 50843.75it/s]
100%|██████████| 1000/1000 [00:00<00:00, 33880.50it/s]
100%|██████████| 1000/1000 [00:00<00:00, 48008.97it/s]
100%|██████████| 1000/1000 [00:00<00:00, 46181.08it/s]
100%|██████████| 1000/1000 [00:00<00:00, 40057.15it/s]


In [10]:
class Network(swyft.SwyftModule):
    """Inference network estimating the log-ratio for the parameter 'z'.

    Wraps a single ``swyft.LogRatioEstimator_1dim`` configured for one data
    feature and one parameter.
    """

    def __init__(self):
        super().__init__()
        # One feature (the scalar observation) and one parameter ('z').
        self.logratios = swyft.LogRatioEstimator_1dim(
            num_features = 1,
            num_params = 1,
            varnames = 'z',
        )

    def forward(self, A, B):
        """Return the estimator output for data ``A['x']`` and parameters ``B['z']``.

        ``A`` and ``B`` are indexed by sample name; only the 'x' entry of
        ``A`` and the 'z' entry of ``B`` are used.
        """
        data, params = A['x'], B['z']
        return self.logratios(data, params)

## Trainer

Training is now done using the `SwyftTrainer` class, which extends `pytorch_lightning.Trainer` with additional methods such as `infer` (see below).

In [12]:
# Train on the GPU when CUDA is available and fall back to the CPU otherwise,
# so the notebook also runs on machines without a GPU (a hard-coded 'gpu'
# accelerator makes the trainer fail there). The check is deliberately
# CUDA-specific rather than accelerator='auto': training uses precision = 64,
# which not every non-CUDA accelerator supports.
trainer = swyft.SwyftTrainer(
    accelerator = 'gpu' if torch.cuda.is_available() else 'cpu',
    devices = 1,
    max_epochs = 2,   # short demonstration run
    precision = 64,   # double-precision training
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


The `swyft.ZarrStore` class provides a convenience method, `get_dataloader`, to generate data loaders for training and validation data directly from the on-disk store.

In [16]:
# Split the 10000 stored samples by index: the first 9500 for training and
# the last 500 for validation, both in mini-batches of 64.
# idx_range is presumably half-open, [start, stop) -- consistent with the
# 149 training batches (= ceil(9500 / 64)) reported during fitting.
dl_train = store.get_dataloader(batch_size = 64, idx_range = [0, 9500])
dl_valid = store.get_dataloader(batch_size = 64, idx_range = [9500, 10000])

In [17]:
# Fresh, untrained instance of the Network class defined earlier in the notebook.
network = Network()

In [18]:
# Fit the network on dl_train, using dl_valid for validation; the run stops
# once the trainer's max_epochs setting is reached.
trainer.fit(network, dl_train, dl_valid)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                   | Params
-----------------------------------------------------
0 | logratios | LogRatioEstimator_1dim | 17.4 K
-----------------------------------------------------
17.4 K    Trainable params
0         Non-trainable params
17.4 K    Total params
0.139     Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 0: : 149it [00:02, 68.00it/s, loss=-0.54, v_num=1e+7] 
Validation: 0it [00:00, ?it/s][A
Validation: 0it [00:00, ?it/s][A
Validation DataLoader 0: : 0it [00:00, ?it/s][A
Epoch 0: : 150it [00:02, 67.53it/s, loss=-0.54, v_num=1e+7]
Epoch 0: : 151it [00:02, 67.65it/s, loss=-0.54, v_num=1e+7]
Epoch 0: : 152it [00:02, 67.82it/s, loss=-0.54, v_num=1e+7]
Epoch 0: : 153it [00:02, 67.98it/s, loss=-0.54, v_num=1e+7]
Epoch 0: : 154it [00:02, 68.06it/s, loss=-0.54, v_num=1e+7]
Epoch 0: : 155it [00:02, 68.21it/s, loss=-0.54, v_num=1e+7]
Epoch 0: : 156it [00:02, 68.39it/s, loss=-0.54, v_num=1e+7, val_loss=-.610]
Epoch 1: : 149it [00:02, 69.88it/s, loss=-0.549, v_num=1e+7, val_loss=-.610]
Validation: 0it [00:00, ?it/s][A
Validation: 0it [00:00, ?it/s][A
Validation DataLoader 0: : 0it [00:00, ?it/s][A
Epoch 1: : 150it [00:02, 69.59it/s, loss=-0.549, v_num=1e+7, val_loss=-.610]
Epoch 1: : 151it [00:02, 69.77it/s, loss=-0.549, v_num=1e+7, val_loss=-.610]
Epoch 1: : 152it [00:02, 69.95it/s, lo

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: : 156it [00:03, 45.35it/s, loss=-0.549, v_num=1e+7, val_loss=-.561]
