In [1]:
%load_ext autoreload
%autoreload 2

# Swyft in 15 Minutes

We discuss seven key steps of a typical Swyft workflow.

## 1. Installing Swyft

We can use `pip` to install the lightning branch (latest development branch) of Swyft.

In [2]:
#%pip install git+https://github.com/undark-lab/swyft.git@lightning

In [3]:
import numpy as np
import pylab as plt
import seaborn
import torch
from scipy import stats

import swyft

## 2. Define the Simulator

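Next we define a simulator class, which specifies the computational graph of our simulator. Here, two parameters `z` are drawn uniformly from $[-1, 1)$, mapped to the straight line `f = z[0] + z[1]*self.x` on a grid of 10 points between $-1$ and $1$, and turned into the noisy data `x` by adding Gaussian noise with standard deviation 0.1.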

In [4]:
class Simulator(swyft.Simulator):
    """Toy simulator: a noisy straight line evaluated on a fixed grid.

    ``forward`` registers three nodes through ``trace.sample``:

    * ``z`` -- two parameters drawn uniformly from [-1, 1),
    * ``f`` -- the noise-free line ``z[0] + z[1] * self.x``,
    * ``x`` -- ``f`` plus independent Gaussian noise.

    Parameters
    ----------
    n_bins : int, optional
        Number of grid points in ``self.x``, and hence the length of ``f``
        and ``x``. Defaults to 10, the input size the notebook's ``Network``
        embedding expects.
    sigma : float, optional
        Standard deviation of the Gaussian noise added to ``f``.
        Defaults to 0.1.
    """

    def __init__(self, n_bins = 10, sigma = 0.1):
        super().__init__()
        # Post-processing hook; presumably converts samples to 32-bit NumPy
        # arrays, judging by the helper's name -- TODO confirm.
        self.on_after_forward = swyft.to_numpy32
        self.sigma = sigma
        self.x = np.linspace(-1, 1, n_bins)

    def forward(self, trace):
        """Register the nodes ``z`` -> ``f`` -> ``x`` on ``trace``."""
        z = trace.sample('z', lambda: np.random.rand(2)*2-1)
        f = trace.sample('f', lambda z: z[0] + z[1]*self.x, z)
        # The noise vector takes its length from the grid, so the two can
        # never get out of sync (the original repeated the literal 10).
        trace.sample('x', lambda f: f + np.random.randn(*self.x.shape)*self.sigma, f)
        
# The simulator draws from NumPy's global random state (np.random.rand / randn),
# so seed it here to make the training set reproducible under Restart & Run All.
np.random.seed(0)
sim = Simulator()
samples = sim.sample(N = 10000)

100%|██████████| 10000/10000 [00:00<00:00, 23818.21it/s]


## 3. Define the SwyftModule

In [5]:
class Network(swyft.SwyftModule):
    """Inference network: a linear embedding of the data feeding two log-ratio estimators.

    ``logratios1`` is a ``LogRatioEstimator_1dim`` configured for the two
    components of ``z``; ``logratios2`` is a ``LogRatioEstimator_Ndim``
    configured for the single marginal ``(0, 1)``.
    """

    def __init__(self):
        super().__init__()
        # Linear map from the 10-element data vector to 2 features.
        self.embedding = torch.nn.Linear(10, 2)
        self.logratios1 = swyft.LogRatioEstimator_1dim(
            num_features = 2, num_params = 2, varnames = 'z')
        self.logratios2 = swyft.LogRatioEstimator_Ndim(
            num_features = 2, marginals = ((0, 1),), varnames = 'z')

    def forward(self, A, B):
        """Embed ``A['x']`` and evaluate both estimators against ``B['z']``.

        Returns the pair ``(1-dim result, N-dim result)``, in that order.
        """
        features = self.embedding(A['x'])
        params = B['z']
        return self.logratios1(features, params), self.logratios2(features, params)

## 4. Train the model

In [6]:
# Use a GPU when PyTorch can see one and fall back to the CPU otherwise, so the
# notebook also runs on machines without CUDA. `devices` replaces the
# deprecated `gpus` argument of the Lightning trainer.
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
trainer = swyft.SwyftTrainer(accelerator = accelerator, devices = 1, max_epochs = 2)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [7]:
# Hold out the last 500 simulations for validation and train on the rest.
n_valid = 500
dl_train = samples[:-n_valid].get_dataloader(batch_size = 64, shuffle = True)
dl_valid = samples[-n_valid:].get_dataloader(batch_size = 64)

In [8]:
# Instantiate a fresh network and fit it with the trainer, passing the training
# loader first and the validation loader second.
network = Network()
trainer.fit(network, dl_train, dl_valid)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type                   | Params
------------------------------------------------------
0 | embedding  | Linear                 | 22    
1 | logratios1 | LogRatioEstimator_1dim | 34.9 K
2 | logratios2 | LogRatioEstimator_Ndim | 17.5 K
------------------------------------------------------
52.5 K    Trainable params
0         Non-trainable params
52.5 K    Total params
0.210     Total estimated model params size (MB)
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                      

  rank_zero_warn(


Epoch 0:  95%|█████████▍| 149/157 [00:03<00:00, 47.84it/s, loss=-3, v_num=9927147]   
Validating: 0it [00:00, ?it/s][A
Epoch 0: 100%|██████████| 157/157 [00:03<00:00, 49.02it/s, loss=-3, v_num=9927147, val_loss=-3.15]
Epoch 1:  95%|█████████▍| 149/157 [00:02<00:00, 54.84it/s, loss=-3.13, v_num=9927147, val_loss=-3.15]
Validating: 0it [00:00, ?it/s][A
Epoch 1: 100%|██████████| 157/157 [00:02<00:00, 56.37it/s, loss=-3.13, v_num=9927147, val_loss=-3.25]
Epoch 1: 100%|██████████| 157/157 [00:02<00:00, 55.28it/s, loss=-3.13, v_num=9927147, val_loss=-3.25]


## 5. Visualize training

N/A (visualizing the training progress is not covered in this notebook yet).

## 6. Perform validation tests

In [9]:
# Evaluate the trained network on the held-out loader; the returned list of
# metric dictionaries is shown as the cell output.
trainer.test(network, dl_valid)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'hp/JS-div': -3.2495720386505127, 'hp/KL-div': -7.909391403198242}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 8/8 [00:00<00:00, 113.25it/s]


[{'hp/JS-div': -3.2495720386505127, 'hp/KL-div': -7.909391403198242}]

In [10]:
# A and B are both the first 1000 simulations.
# NOTE(review): the recorded output of this cell ends in
# "AttributeError: 'list' object has no attribute 'items'". Presumably
# estimate_mass expects mapping-like predictions while Network.forward returns
# a tuple -- confirm against the swyft version in use before relying on `mass`.
B = samples[:1000]
A = samples[:1000]
mass = trainer.estimate_mass(network, A, B)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 149it [00:00, ?it/s]



Predicting: 149it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



Predicting: 100%|██████████| 1000/1000 [00:13<00:00, 62.11it/s]


AttributeError: 'list' object has no attribute 'items'

## 7. Generate posteriors

In [None]:
# Mock observation: run the simulator with the parameters pinned to z0 and
# keep only the data node 'x'.
z0 = np.array([1.0, 1.0])
observation = sim.sample(conditions = dict(z = z0))
x0 = observation['x']
plt.plot(x0)

In [None]:
# Draw 100000 samples with 'z' as the only target; presumably this skips the
# downstream nodes 'f' and 'x' -- TODO confirm against swyft.Simulator.sample.
prior_samples = sim.sample(targets = ['z'], N = 100000)

In [None]:
# Run the trained network on the mock observation x0 (wrapped in a swyft.Sample)
# against the prior samples of z.
predictions = trainer.infer(network, swyft.Sample(x = x0), prior_samples)

In [None]:
# Show the `parnames` attribute of the first entry of `predictions`.
# NOTE(review): this indexes `predictions` by position, whereas the plotting
# cells further down index it by the keys 'z1' / 'z2' -- confirm which access
# pattern the installed swyft version supports.
predictions[0].parnames

In [None]:
# Weighted 1-dim KDE for each of the two entries stored under 'z1', with the
# corresponding component of z0 marked by a dotted red line.
for i in [0, 1]:
    v, w = predictions['z1'][i]
    # Restrict the KDE support to the range covered by the values.
    clip = [v.min(), v.max()]
    # `fill` is the current name of seaborn's deprecated `shade` argument.
    seaborn.kdeplot(v, weights = w, clip = clip, fill = True)
    plt.axvline(z0[i], color='r', ls=':')
# Both curves share one axes, so the x-range only needs to be set once.
plt.xlim([0.5, 1.2])

In [None]:
# Unpack the first entry stored under 'z2' into v and w; the final plotting cell
# passes them to kdeplot as sample coordinates and weights.
v, w = predictions['z2'][0]

In [None]:
# Call predictions.sample(100) and display the result as the cell output;
# presumably this draws 100 samples from the predictions -- TODO confirm.
predictions.sample(100)

In [None]:
# Weighted 2-dim KDE built from the first 1000 rows of v (column 0 against column 1).
n_shown = 1000
seaborn.kdeplot(x = v[:n_shown, 0], y = v[:n_shown, 1], weights = w[:n_shown])