In [1]:
import numpy as np
import pylab as plt
import swyft
import torch
from scipy import stats
%load_ext autoreload
%autoreload 2

np.random.seed(25)
torch.manual_seed(25)

<torch._C.Generator at 0x151ede780720>

In [2]:
DEVICE = 'cuda:0'
MAX_EPOCHS = 100
EXPECTED_N = 10000

## Torus model

In [3]:
def model(params, center = np.array([0.6, 0.8])):
    a, b, c = params['a'], params['b'], params['c']
    r = sum((c-center)**2)**0.5  # Return radial distance from center
    x = np.array([a, r, c])
    return dict(x=x)

def noise(obs, params, noise = np.array([0.03, 0.005, 0.2])):
    x = obs['x']
    n = np.random.randn(*x.shape)*noise
    return dict(x = x + n)

par0 = dict(a=0.57, b=0.8, c=1.0)
obs0 = model(par0)  # Using Asimov data

In [5]:
cache = swyft.MemoryCache(params = ['a', 'b', 'c'], obs_shapes = dict(x=(3,)))

Creating new cache.


In [6]:
prior = swyft.Prior({"a": ["uniform", 0., 1.], "b": ["uniform",  0., 1.], "c": ["uniform", 0., 1.]})

## Inference

In [None]:
s = swyft.SWYFT(model, noise, prior, cache, obs0, device = DEVICE)
s.infer1d(Ninit = 500)

Simulate:   7%|▋         | 36/529 [00:00<00:01, 350.89it/s]

N = 500
Round: 1
Adding 529 new samples. Run simulator!


Simulate: 100%|██████████| 529/529 [00:01<00:00, 347.76it/s]


n_features = 3
Start training
LR iteration 0
Validation loss: 4.0702338585486775
Validation loss: 3.164757563517644
Validation loss: 2.883050698500413
Validation loss: 2.8072382486783543
Validation loss: 2.8102125387925367
Total epochs: 5
LR iteration 1
Validation loss: 2.7685034458453837
Validation loss: 2.6877523202162523
Validation loss: 2.756403299478384
Total epochs: 3


Simulate:   0%|          | 0/281 [00:00<?, ?it/s]

Volume shrinkage: 0.8966771532265209
N = 724
Round: 2
Adding 281 new samples. Run simulator!


Simulate: 100%|██████████| 281/281 [00:00<00:00, 347.16it/s]


n_features = 3
Start training
LR iteration 0
Validation loss: 4.0214254591200085
Validation loss: 3.3320137129889593
Validation loss: 3.0513672961129084
Validation loss: 3.0233454969194202
Validation loss: 2.941438297430674
Validation loss: 3.0595359007517495
Total epochs: 6
LR iteration 1
Validation loss: 2.9844327767690024


In [None]:
s.infer2d(N = 10000)

In [None]:
post = s.posteriors()
swyft.corner(post, ["a", "b", "c"], color='r', figsize = (15,15), truth=par0)