In [1]:
# Reload imported modules (e.g. a locally edited swyft checkout) automatically
# before each cell executes; mode 2 reloads all modules, not only %aimport-ed ones.
%load_ext autoreload
%autoreload 2

In [2]:
import pylab as plt
import numpy as np
import torch
import swyft

In [3]:
class Simulator(swyft.Simulator):
    """Toy 3-d simulator recorded through a swyft trace.

    Generative model:
        z ~ Uniform[0, 1)^3
        x = z + eps,  eps ~ Normal(0, 0.02^2) independently per component
    """

    def __init__(self, bounds = None):
        # `bounds` is accepted for interface compatibility but is not used here.
        super().__init__()
        # Post-processing hook; presumably casts the simulated values to float32
        # numpy arrays (inferred from the helper's name) — TODO confirm in swyft.
        self.on_after_forward = swyft.to_numpy32

    def forward(self, trace):
        def add_noise(z_value):
            # Gaussian measurement noise with standard deviation 0.02.
            return z_value + np.random.randn(3)*0.02

        z = trace.sample('z', np.random.rand, 3)
        # The sampled value of 'x' is recorded in the trace; nothing else reads it here.
        trace.sample('x', add_noise, z)

In [4]:
class Network(swyft.SwyftModule):
    """Neural ratio estimator for the parameter `z` given the observation `x`.

    Args:
        dropout: dropout rate used inside the MLP ratio estimator.
        lr: learning rate. It is not used directly in this class; presumably it is
            picked up from the hyperparameters that `swyft.SwyftModule` records
            for this subclass — TODO confirm against swyft's optimizer setup.
        hidden_features: width of the MLP's hidden layers (default keeps the
            original 256).
    """
    def __init__(self, dropout = 0.1, lr = 1e-4, hidden_features = 256):
        super().__init__()
        # Use the constructor argument directly rather than `self.hparams.dropout`:
        # the latter only works if the base class happens to capture this
        # subclass's __init__ arguments into `hparams`.
        # NOTE(review): the two 3s presumably are the dimensionalities of x and z
        # produced by the simulator — confirm against RatioEstimatorMLP1d's signature.
        self.classifier = swyft.RatioEstimatorMLP1d(3, 3, hidden_features = hidden_features, dropout = dropout)

    def forward(self, x, z):
        """Return log-ratio estimates keyed by parameter name.

        `x` and `z` are dict-like batches; only their 'x' and 'z' entries are used.
        """
        x = x['x']
        z = z['z']
        ratios_z = self.classifier(x, z)
        return dict(z = ratios_z)

In [5]:
# Seed the global RNGs so that re-running the notebook reproduces the same data
# (the simulator draws from np.random) and the same network initialisation (torch).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

simulator = Simulator()
samples = simulator(1000)

100%|██████████| 1000/1000 [00:00<00:00, 46759.76it/s]


In [6]:
# Split the simulated samples by index: the first 800 for training, the next 100
# for validation, and the remainder for testing. Only training data is shuffled.
BATCH_SIZE = 128
dl_train = samples[:800].get_dataloader(batch_size = BATCH_SIZE, shuffle = True)
dl_valid = samples[800:900].get_dataloader(batch_size = BATCH_SIZE)
dl_test = samples[900:].get_dataloader(batch_size = BATCH_SIZE)

In [7]:
# Sweep the learning rate. Each run logs to its own TensorBoard version directory
# (version=None auto-increments) and is evaluated on the test split using the
# best checkpoint. Falls back to CPU so the notebook also runs without a GPU.
use_gpu = torch.cuda.is_available()
test_results = {}
for lr in [1e-1, 1e-2, 1e-3, 1e-4]:
    network = Network(dropout = 0.2, lr = lr)
    trainer = swyft.SwyftTrainer(
        accelerator = 'gpu' if use_gpu else 'cpu',
        gpus = 1 if use_gpu else None,
        max_epochs = 100,
        **swyft.tensorboard_config(save_dir = './lightning_logs2', name = '01-minimal-hparams', version = None),
    )
    trainer.fit(network, dl_train, dl_valid)
    # Keep the test metrics so the runs can be compared without scraping the logs.
    test_results[lr] = trainer.test(network, dl_test, ckpt_path = 'best')

test_results

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type                | Params
---------------------------------------------------
0 | classifier | RatioEstimatorMLP1d | 800 K 
---------------------------------------------------
800 K     Trainable params
0         Non-trainable params
800 K     Total params
3.201     Total estimated model params size (MB)


Validation sanity check:   0%|          | 0/1 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                      

  rank_zero_warn(
  rank_zero_warn(


Epoch 0:  88%|████████▊ | 7/8 [00:00<00:00, 63.72it/s, loss=322, v_num=4]    
Validating: 0it [00:00, ?it/s][A
Epoch 0: 100%|██████████| 8/8 [00:00<00:00, 63.24it/s, loss=322, v_num=4, val_loss=556.0]
Epoch 1:  88%|████████▊ | 7/8 [00:00<00:00, 73.62it/s, loss=204, v_num=4, val_loss=556.0]
Validating: 0it [00:00, ?it/s][A
Epoch 1: 100%|██████████| 8/8 [00:00<00:00, 73.71it/s, loss=204, v_num=4, val_loss=47.10]
Epoch 2:  88%|████████▊ | 7/8 [00:00<00:00, 73.56it/s, loss=147, v_num=4, val_loss=47.10]
Validating: 0it [00:00, ?it/s][A
Epoch 2: 100%|██████████| 8/8 [00:00<00:00, 70.56it/s, loss=147, v_num=4, val_loss=13.30]
Epoch 3:  88%|████████▊ | 7/8 [00:00<00:00, 65.71it/s, loss=24.4, v_num=4, val_loss=13.30]
Validating: 0it [00:00, ?it/s][A
Epoch 3: 100%|██████████| 8/8 [00:00<00:00, 67.37it/s, loss=24.4, v_num=4, val_loss=4.350]
Epoch 4:  88%|████████▊ | 7/8 [00:00<00:00, 73.04it/s, loss=4.41, v_num=4, val_loss=4.350]
Validating: 0it [00:00, ?it/s][A
Epoch 4: 100%|██████████| 8/8

Restoring states from the checkpoint path at ./lightning_logs2/01-minimal-hparams/version_4/checkpoints/epoch=19-step=139.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at ./lightning_logs2/01-minimal-hparams/version_4/checkpoints/epoch=19-step=139.ckpt
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'hp/JS-div': -2.5854907035827637, 'hp/KL-div': -4.858042240142822}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 1/1 [00:00<00:00, 68.19it/s]


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type                | Params
---------------------------------------------------
0 | classifier | RatioEstimatorMLP1d | 800 K 
---------------------------------------------------
800 K     Trainable params
0         Non-trainable params
800 K     Total params
3.201     Total estimated model params size (MB)


Epoch 0:  88%|████████▊ | 7/8 [00:00<00:00, 81.69it/s, loss=2.41, v_num=5]   
Validating: 0it [00:00, ?it/s][A
Epoch 0: 100%|██████████| 8/8 [00:00<00:00, 81.17it/s, loss=2.41, v_num=5, val_loss=-.0119]
Epoch 1:  88%|████████▊ | 7/8 [00:00<00:00, 81.69it/s, loss=0.712, v_num=5, val_loss=-.0119]
Validating: 0it [00:00, ?it/s][A
Epoch 1: 100%|██████████| 8/8 [00:00<00:00, 78.06it/s, loss=0.712, v_num=5, val_loss=-.609] 
Epoch 2:  88%|████████▊ | 7/8 [00:00<00:00, 82.63it/s, loss=-0.286, v_num=5, val_loss=-.609] 
Validating: 0it [00:00, ?it/s][A
Epoch 2: 100%|██████████| 8/8 [00:00<00:00, 82.84it/s, loss=-0.286, v_num=5, val_loss=-.741]
Epoch 3:  88%|████████▊ | 7/8 [00:00<00:00, 75.98it/s, loss=-2.07, v_num=5, val_loss=-.741] 
Validating: 0it [00:00, ?it/s][A
Epoch 3: 100%|██████████| 8/8 [00:00<00:00, 75.70it/s, loss=-2.07, v_num=5, val_loss=-1.22]
Epoch 4:  88%|████████▊ | 7/8 [00:00<00:00, 68.76it/s, loss=-2.53, v_num=5, val_loss=-1.22]
Validating: 0it [00:00, ?it/s][A
Epoch 4: 1

Restoring states from the checkpoint path at ./lightning_logs2/01-minimal-hparams/version_5/checkpoints/epoch=5-step=41.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at ./lightning_logs2/01-minimal-hparams/version_5/checkpoints/epoch=5-step=41.ckpt


Testing: 0it [00:00, ?it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'hp/JS-div': -1.8838624954223633, 'hp/KL-div': -5.458173751831055}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 1/1 [00:00<00:00, 68.60it/s]


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type                | Params
---------------------------------------------------
0 | classifier | RatioEstimatorMLP1d | 800 K 
---------------------------------------------------
800 K     Trainable params
0         Non-trainable params
800 K     Total params
3.201     Total estimated model params size (MB)


Epoch 0:  88%|████████▊ | 7/8 [00:00<00:00, 82.57it/s, loss=-0.983, v_num=6] 
Validating: 0it [00:00, ?it/s][A
Epoch 0: 100%|██████████| 8/8 [00:00<00:00, 81.99it/s, loss=-0.983, v_num=6, val_loss=-.0083]
Epoch 1:  88%|████████▊ | 7/8 [00:00<00:00, 80.39it/s, loss=-1.58, v_num=6, val_loss=-.0083] 
Validating: 0it [00:00, ?it/s][A
Epoch 1: 100%|██████████| 8/8 [00:00<00:00, 79.67it/s, loss=-1.58, v_num=6, val_loss=-.0196]
Epoch 2: 100%|██████████| 8/8 [00:00<00:00, 78.74it/s, loss=-2.03, v_num=6, val_loss=-.0196]
Validating: 0it [00:00, ?it/s][A
Epoch 2: 100%|██████████| 8/8 [00:00<00:00, 71.44it/s, loss=-2.03, v_num=6, val_loss=-.131] 
Epoch 3:  88%|████████▊ | 7/8 [00:00<00:00, 74.50it/s, loss=-2.52, v_num=6, val_loss=-.131]
Validating: 0it [00:00, ?it/s][A
Epoch 3: 100%|██████████| 8/8 [00:00<00:00, 74.87it/s, loss=-2.52, v_num=6, val_loss=-1.30]
Epoch 4:  88%|████████▊ | 7/8 [00:00<00:00, 75.06it/s, loss=-2.69, v_num=6, val_loss=-1.30]
Validating: 0it [00:00, ?it/s][A
Epoch 4: 

Restoring states from the checkpoint path at ./lightning_logs2/01-minimal-hparams/version_6/checkpoints/epoch=10-step=76.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at ./lightning_logs2/01-minimal-hparams/version_6/checkpoints/epoch=10-step=76.ckpt


Testing: 0it [00:00, ?it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'hp/JS-div': -2.9667282104492188, 'hp/KL-div': -8.11322021484375}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 1/1 [00:00<00:00, 73.56it/s]


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type                | Params
---------------------------------------------------
0 | classifier | RatioEstimatorMLP1d | 800 K 
---------------------------------------------------
800 K     Trainable params
0         Non-trainable params
800 K     Total params
3.201     Total estimated model params size (MB)


Epoch 0:  88%|████████▊ | 7/8 [00:00<00:00, 82.19it/s, loss=-0.146, v_num=7]  
Validating: 0it [00:00, ?it/s][A
Epoch 0: 100%|██████████| 8/8 [00:00<00:00, 81.77it/s, loss=-0.146, v_num=7, val_loss=0.000149]
Epoch 1: 100%|██████████| 8/8 [00:00<00:00, 77.30it/s, loss=-0.544, v_num=7, val_loss=0.000149]
Validating: 0it [00:00, ?it/s][A
Epoch 1: 100%|██████████| 8/8 [00:00<00:00, 66.94it/s, loss=-0.544, v_num=7, val_loss=-.00424] 
Epoch 2:  88%|████████▊ | 7/8 [00:00<00:00, 77.41it/s, loss=-0.951, v_num=7, val_loss=-.00424]
Validating: 0it [00:00, ?it/s][A
Epoch 2: 100%|██████████| 8/8 [00:00<00:00, 77.25it/s, loss=-0.951, v_num=7, val_loss=-.0221] 
Epoch 3: 100%|██████████| 8/8 [00:00<00:00, 79.00it/s, loss=-1.65, v_num=7, val_loss=-.0221] 
Validating: 0it [00:00, ?it/s][A
Epoch 3: 100%|██████████| 8/8 [00:00<00:00, 71.77it/s, loss=-1.65, v_num=7, val_loss=-.0853]
Epoch 4: 100%|██████████| 8/8 [00:00<00:00, 79.22it/s, loss=-2.19, v_num=7, val_loss=-.0853]
Validating: 0it [00:00, ?it

Restoring states from the checkpoint path at ./lightning_logs2/01-minimal-hparams/version_7/checkpoints/epoch=11-step=83.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at ./lightning_logs2/01-minimal-hparams/version_7/checkpoints/epoch=11-step=83.ckpt


Testing: 0it [00:00, ?it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'hp/JS-div': -3.0364508628845215, 'hp/KL-div': -6.616926670074463}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 1/1 [00:00<00:00, 53.05it/s]
