In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pylab as plt
import numpy as np
import torch
import swyft.lightning as sl

In [3]:
class Simulator(sl.Simulator):
    """Toy simulator with a 3-component parameter vector ``z`` and noisy data ``x``.

    ``z`` is drawn uniformly from [0, 1) per component, and ``x`` is ``z`` plus
    independent Gaussian noise with standard deviation 0.02 per component.

    Parameters
    ----------
    bounds : optional
        Accepted for interface compatibility but not used by this class.
    """

    def __init__(self, bounds = None):
        super().__init__()
        # Post-processing hook applied after forward(); presumably casts the
        # sampled values to float32 numpy arrays (per the helper's name) — TODO confirm.
        self.on_after_forward = sl.to_numpy32

    def forward(self, trace):
        """Record the draws of ``'z'`` and then ``'x'`` on ``trace``; returns None."""
        def observe(params):
            # Measurement model: each component perturbed by N(0, 0.02**2) noise.
            return params + 0.02*np.random.randn(3)

        params = trace.sample('z', np.random.rand, 3)
        trace.sample('x', observe, params)

In [4]:
class Network(sl.SwyftModule):
    """Ratio estimator for the 3-component parameters ``z`` given 3-component data ``x``.

    Parameters
    ----------
    dropout : float
        Dropout setting forwarded to the ``RatioEstimatorMLP1d`` classifier.
    lr : float
        Learning rate. Not read directly in this class; NOTE(review): presumably
        consumed by ``sl.SwyftModule`` through ``self.hparams`` when building the
        optimizer — confirm against the swyft source.
    """

    def __init__(self, dropout = 0.1, lr = 1e-4):
        super().__init__()
        # Use the constructor argument directly rather than ``self.hparams.dropout``:
        # reading it back from hparams silently depends on the parent class having
        # called ``save_hyperparameters()``, which this class does not control.
        # The two sizes are the dimensionality of x and z (both 3 for the Simulator
        # above); argument order follows RatioEstimatorMLP1d's signature.
        self.classifier = sl.RatioEstimatorMLP1d(3, 3, hidden_features = 256, dropout = dropout)

    def forward(self, x, z):
        """Return ``{'z': ratios}`` computed by the classifier from ``x['x']`` and ``z['z']``."""
        ratios_z = self.classifier(x['x'], z['z'])
        return dict(z = ratios_z)

In [5]:
# Simulate the training set with a fixed seed so Restart & Run All reproduces it.
# The Simulator.forward defined above draws from NumPy's global RNG
# (np.random.rand / np.random.randn), so seeding that RNG fixes the samples.
SAMPLE_SEED = 0
N_SAMPLES = 1000

np.random.seed(SAMPLE_SEED)
simulator = Simulator()
samples = simulator.sample(N_SAMPLES).to_numpy()

100%|██████████| 1000/1000 [00:00<00:00, 31347.56it/s]


In [6]:
# Wrap the simulated samples in a data module; the learning-rate sweep below
# passes it to both trainer.fit and trainer.test.
datamodule = sl.SwyftDataModule(
    batch_size = 128,
    store = samples,
)



In [7]:
# Learning-rate sweep: train one network per learning rate (each run is logged
# as a new TensorBoard version) and evaluate its best checkpoint on the test split.
LEARNING_RATES = [1e-1, 1e-2, 1e-3, 1e-4]
DROPOUT = 0.2
MAX_EPOCHS = 100
TRAIN_SEED = 0

# Fall back to CPU instead of failing on machines without CUDA.
use_gpu = torch.cuda.is_available()

test_results = {}
for lr in LEARNING_RATES:
    # Re-seed before each run so every learning rate starts from the same initial
    # weights; otherwise differences between runs mix lr with initialization noise.
    torch.manual_seed(TRAIN_SEED)
    network = Network(dropout = DROPOUT, lr = lr)
    trainer = sl.SwyftTrainer(
        accelerator = 'gpu' if use_gpu else 'cpu',
        gpus = 1 if use_gpu else None,
        max_epochs = MAX_EPOCHS,
        **sl.tensorboard_config(save_dir = './lightning_logs2', name = '01-minimal-hparams', version = None),
    )
    trainer.fit(network, datamodule)
    # Keep the test metrics per learning rate (in PyTorch Lightning, Trainer.test
    # returns a list of metric dicts, one per test dataloader; SwyftTrainer is
    # assumed to inherit this) instead of leaving them only in the log output.
    test_results[lr] = trainer.test(network, datamodule, ckpt_path = 'best')

test_results

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type                | Params
---------------------------------------------------
0 | classifier | RatioEstimatorMLP1d | 800 K 
---------------------------------------------------
800 K     Trainable params
0         Non-trainable params
800 K     Total params
3.201     Total estimated model params size (MB)


Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                      

  rank_zero_warn(
  rank_zero_warn(


Epoch 0:  78%|███████▊  | 7/9 [00:00<00:00, 64.15it/s, loss=247, v_num=8] 
Validating: 0it [00:00, ?it/s][A
Epoch 0: 100%|██████████| 9/9 [00:00<00:00, 69.70it/s, loss=247, v_num=8, val_loss=1.12e+3]
Epoch 1:  78%|███████▊  | 7/9 [00:00<00:00, 74.74it/s, loss=149, v_num=8, val_loss=1.12e+3]
Validating: 0it [00:00, ?it/s][A
Epoch 1: 100%|██████████| 9/9 [00:00<00:00, 79.91it/s, loss=149, v_num=8, val_loss=62.60]  
Epoch 2:  78%|███████▊  | 7/9 [00:00<00:00, 74.24it/s, loss=108, v_num=8, val_loss=62.60]
Validating: 0it [00:00, ?it/s][A
Epoch 2: 100%|██████████| 9/9 [00:00<00:00, 80.83it/s, loss=108, v_num=8, val_loss=10.10]
Epoch 3:  78%|███████▊  | 7/9 [00:00<00:00, 70.26it/s, loss=17.4, v_num=8, val_loss=10.10]
Validating: 0it [00:00, ?it/s][A
Epoch 3: 100%|██████████| 9/9 [00:00<00:00, 77.52it/s, loss=17.4, v_num=8, val_loss=4.590]
Epoch 4:  78%|███████▊  | 7/9 [00:00<00:00, 75.52it/s, loss=5.6, v_num=8, val_loss=4.590] 
Validating: 0it [00:00, ?it/s][A
Epoch 4: 100%|██████████| 

Restoring states from the checkpoint path at ./lightning_logs2/01-minimal-hparams/version_8/checkpoints/epoch=18-step=132.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at ./lightning_logs2/01-minimal-hparams/version_8/checkpoints/epoch=18-step=132.ckpt
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'hp/JS-div': 1.2884330749511719, 'hp/KL-div': -6.22100305557251}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:00<00:00, 129.09it/s]


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  rank_zero_deprecation(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type                | Params
---------------------------------------------------
0 | classifier | RatioEstimatorMLP1d | 800 K 
---------------------------------------------------
800 K     Trainable params
0         Non-trainable params
800 K     Total params
3.201     Total estimated model params size (MB)


Epoch 0:  78%|███████▊  | 7/9 [00:00<00:00, 80.74it/s, loss=6.54, v_num=9]
Validating: 0it [00:00, ?it/s][A
Epoch 0: 100%|██████████| 9/9 [00:00<00:00, 85.78it/s, loss=6.54, v_num=9, val_loss=3.950]
Epoch 1:  78%|███████▊  | 7/9 [00:00<00:00, 68.80it/s, loss=4.71, v_num=9, val_loss=3.950]
Validating: 0it [00:00, ?it/s][A
Epoch 1: 100%|██████████| 9/9 [00:00<00:00, 76.26it/s, loss=4.71, v_num=9, val_loss=3.630]
Epoch 2:  89%|████████▉ | 8/9 [00:00<00:00, 78.29it/s, loss=3.74, v_num=9, val_loss=3.630]
Validating: 0it [00:00, ?it/s][A
Epoch 2: 100%|██████████| 9/9 [00:00<00:00, 72.23it/s, loss=3.74, v_num=9, val_loss=2.960]
Epoch 3:  78%|███████▊  | 7/9 [00:00<00:00, 73.77it/s, loss=2.03, v_num=9, val_loss=2.960]
Validating: 0it [00:00, ?it/s][A
Epoch 3: 100%|██████████| 9/9 [00:00<00:00, 81.28it/s, loss=2.03, v_num=9, val_loss=2.720]
Epoch 4:  78%|███████▊  | 7/9 [00:00<00:00, 74.69it/s, loss=1.54, v_num=9, val_loss=2.720]
Validating: 0it [00:00, ?it/s][A
Epoch 4: 100%|██████████| 9

  rank_zero_deprecation(
Restoring states from the checkpoint path at ./lightning_logs2/01-minimal-hparams/version_9/checkpoints/epoch=13-step=97.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at ./lightning_logs2/01-minimal-hparams/version_9/checkpoints/epoch=13-step=97.ckpt


Testing: 0it [00:00, ?it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'hp/JS-div': 1.2166918516159058, 'hp/KL-div': -8.114272117614746}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:00<00:00, 135.49it/s]


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type                | Params
---------------------------------------------------
0 | classifier | RatioEstimatorMLP1d | 800 K 
---------------------------------------------------
800 K     Trainable params
0         Non-trainable params
800 K     Total params
3.201     Total estimated model params size (MB)


Epoch 0:  78%|███████▊  | 7/9 [00:00<00:00, 82.66it/s, loss=3.28, v_num=10]
Validating: 0it [00:00, ?it/s][A
Epoch 0: 100%|██████████| 9/9 [00:00<00:00, 89.19it/s, loss=3.28, v_num=10, val_loss=4.150]
Epoch 1:  78%|███████▊  | 7/9 [00:00<00:00, 62.95it/s, loss=2.63, v_num=10, val_loss=4.150]
Validating: 0it [00:00, ?it/s][A
Epoch 1: 100%|██████████| 9/9 [00:00<00:00, 70.08it/s, loss=2.63, v_num=10, val_loss=4.170]
Epoch 2:  78%|███████▊  | 7/9 [00:00<00:00, 78.46it/s, loss=2.12, v_num=10, val_loss=4.170]
Validating: 0it [00:00, ?it/s][A
Epoch 2: 100%|██████████| 9/9 [00:00<00:00, 85.01it/s, loss=2.12, v_num=10, val_loss=4.220]
Epoch 3:  78%|███████▊  | 7/9 [00:00<00:00, 82.89it/s, loss=1.55, v_num=10, val_loss=4.220]
Validating: 0it [00:00, ?it/s][A
Epoch 3: 100%|██████████| 9/9 [00:00<00:00, 89.37it/s, loss=1.55, v_num=10, val_loss=3.250]
Epoch 4:  78%|███████▊  | 7/9 [00:00<00:00, 75.97it/s, loss=1.36, v_num=10, val_loss=3.250]
Validating: 0it [00:00, ?it/s][A
Epoch 4: 100%|████

Restoring states from the checkpoint path at ./lightning_logs2/01-minimal-hparams/version_10/checkpoints/epoch=12-step=90.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at ./lightning_logs2/01-minimal-hparams/version_10/checkpoints/epoch=12-step=90.ckpt


Testing: 0it [00:00, ?it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'hp/JS-div': 1.1405411958694458, 'hp/KL-div': -7.927431583404541}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:00<00:00, 135.54it/s]


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type                | Params
---------------------------------------------------
0 | classifier | RatioEstimatorMLP1d | 800 K 
---------------------------------------------------
800 K     Trainable params
0         Non-trainable params
800 K     Total params
3.201     Total estimated model params size (MB)


Epoch 0:  78%|███████▊  | 7/9 [00:00<00:00, 82.09it/s, loss=4.02, v_num=11]
Validating: 0it [00:00, ?it/s][A
Epoch 0: 100%|██████████| 9/9 [00:00<00:00, 88.12it/s, loss=4.02, v_num=11, val_loss=4.160]
Epoch 1:  89%|████████▉ | 8/9 [00:00<00:00, 79.87it/s, loss=3.63, v_num=11, val_loss=4.160]
Validating: 0it [00:00, ?it/s][A
Epoch 1: 100%|██████████| 9/9 [00:00<00:00, 77.76it/s, loss=3.63, v_num=11, val_loss=4.160]
Epoch 2:  78%|███████▊  | 7/9 [00:00<00:00, 78.13it/s, loss=3.21, v_num=11, val_loss=4.160]
Validating: 0it [00:00, ?it/s][A
Epoch 2: 100%|██████████| 9/9 [00:00<00:00, 84.08it/s, loss=3.21, v_num=11, val_loss=4.140]
Epoch 3:  78%|███████▊  | 7/9 [00:00<00:00, 69.31it/s, loss=2.52, v_num=11, val_loss=4.140]
Validating: 0it [00:00, ?it/s][A
Epoch 3: 100%|██████████| 9/9 [00:00<00:00, 76.00it/s, loss=2.52, v_num=11, val_loss=4.090]
Epoch 4:  78%|███████▊  | 7/9 [00:00<00:00, 68.61it/s, loss=1.98, v_num=11, val_loss=4.090]
Validating: 0it [00:00, ?it/s][A
Epoch 4: 100%|████

Restoring states from the checkpoint path at ./lightning_logs2/01-minimal-hparams/version_11/checkpoints/epoch=10-step=76.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at ./lightning_logs2/01-minimal-hparams/version_11/checkpoints/epoch=10-step=76.ckpt


Testing:  88%|████████▊ | 7/8 [00:00<00:00, 64.67it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'hp/JS-div': 1.1909857988357544, 'hp/KL-div': -7.024380683898926}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:00<00:00, 61.59it/s]
