In [292]:
import numpy as np
import pylab as plt
import torch

import swyft
DEVICE = 'gpu' if torch.cuda.is_available() else 'cpu'

In [293]:
torch.manual_seed(0)
np.random.seed(0)

In [309]:
N = 10_000  # Number of samples
z = np.random.rand(N, 1)*2-1  # Uniform prior over [-5, 1]
z = np.random.randn(N, 1)*0.1
x = z + np.random.randn(N, 1)*0.1

In [310]:
samples = swyft.Samples(x = x, z = z)

In [331]:
class Network(swyft.AdamWReduceLROnPlateau, swyft.SwyftModule):
    def __init__(self):
        super().__init__()
        self.learning_rate = 1e-3
        self.early_stopping_patience = 100
        self.logratios = swyft.LogRatioEstimator_1dim(num_features = 1, num_params = 1, varnames = 'z', num_blocks = 4)
        self.Q = torch.nn.Parameter(torch.zeros(1))
        self.c = torch.nn.Linear(1, 1)

    def forward(self, A, B):
        x = A['x'][:64]
        z = B['z'][:64]
        s = self.c(x)
        q = torch.exp(self.Q)
        n = torch.randn(1)#/q.detach()**0.5
        loss = 0.5*((q+0.1**-2)*(z-s))**2*n**2 - n**2*(q+0.1**-2)
        loss = loss.reshape(-1)
        return swyft.AuxLoss(loss, 'x')

In [332]:
trainer = swyft.SwyftTrainer(accelerator = DEVICE, precision = 64, max_epochs = 30)
dm = swyft.SwyftDataModule(samples, batch_size=128)
network = Network()
print(torch.exp(network.Q))
trainer.fit(network, dm)
print(1/torch.exp(network.Q)**0.5)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
The following callbacks returned in `LightningModule.configure_callbacks` will override existing callbacks passed to Trainer: ModelCheckpoint

  | Name      | Type                   | Params
-----------------------------------------------------
0 | logratios | LogRatioEstimator_1dim | 34.6 K
1 | c         | Linear                 | 2     
-----------------------------------------------------
34.6 K    Trainable params
0         Non-trainable params
34.6 K    Total params
0.277     Total estimated model params size (MB)


tensor([1.], grad_fn=<ExpBackward0>)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


Reloading best model: /Users/cweniger/Documents/swyft/notebooks/dev/lightning_logs/version_95/checkpoints/epoch=28-step=1827.ckpt
tensor([1.1861], dtype=torch.float64, grad_fn=<MulBackward0>)


In [333]:
print(1/torch.exp(network.Q)**0.5)

tensor([1.1861], dtype=torch.float64, grad_fn=<MulBackward0>)
