In [1]:
import swyft
import numpy as np
from swyft.utils.simulator import Simulator
from dask.distributed import LocalCluster
import torch

In [2]:
# Uniform [0, 1] prior specification for the ring centre (x1, y1) and the radius parameter r
# (the simulator below rescales r by 0.5 before using it as the ring radius).
prior = swyft.Prior(
    {name: ['uniform', 0., 1.] for name in ("x1", "y1", "r")}
)

In [3]:
def model(params, w = 0.03, n_pix = 32, n_lines = 20, rng = None):
    """Ring simulator.

    Renders a Gaussian ring on an ``n_pix`` x ``n_pix`` grid covering [0, 1]^2 and adds
    ``n_lines`` random straight-line distortions on top of it.

    Parameters
    ----------
    params : mapping
        Needs keys 'x1', 'y1' (ring centre) and 'r'; the ring radius used is
        ``0.5 * params['r']``.
    w : float, default 0.03
        Width (standard deviation) of the Gaussian ring profile.
    n_pix : int, default 32
        Pixels per image side.
    n_lines : int, default 20
        Number of random line distortions added to the image.
    rng : numpy.random.Generator, optional
        Source of randomness for the line distortions. When None, the global
        ``np.random`` state is used, so ``np.random.seed`` still controls it.

    Returns
    -------
    dict
        ``{'x': mu}`` where ``mu`` is a float array of shape ``(n_pix, n_pix)``.
    """
    x1, y1, r = params['x1'], params['y1'], params['r'] * 0.5

    # n_pix evenly spaced points on [0, 1], both ends included. Do not pass a 4th
    # positional argument here: np.linspace reads it as `endpoint`, not as a size.
    grid = np.linspace(0, 1, n_pix)
    X, Y = np.meshgrid(grid, grid)

    # Gaussian profile in the distance from the ring centre, peaked at radius r.
    R1 = ((X - x1)**2 + (Y - y1)**2)**0.5
    mu = np.exp(-(R1 - r)**2 / w**2 / 2)

    # Add random distortions in terms of lines: each is a narrow (width 0.01) ridge of
    # amplitude 0.8 along X*xr[0] + Y*(1 - xr[0]) = xr[1].
    rand = np.random.rand if rng is None else rng.random
    for _ in range(n_lines):
        xr = rand(2)
        mu += 0.8 * np.exp(-(X * xr[0] + Y * (1 - xr[0]) - xr[1])**2 / 0.01**2)

    # mu is freshly allocated above, so no defensive copy is needed.
    return dict(x=mu)

def noise(obs, params = None, sigma=0.1):
    """Return a new observation dict whose 'x' is ``obs['x']`` plus i.i.d. Gaussian noise.

    The noise has standard deviation ``sigma`` and is drawn from the global ``np.random``
    state. ``params`` is accepted for interface compatibility and is not used.
    """
    clean = obs['x']
    perturbation = np.random.randn(*clean.shape) * sigma
    return {'x': clean + perturbation}


In [4]:
# Fiducial parameters for the mock observation; z0 holds the same values as an array,
# in the order (x1, y1, r), so the ground truth is written down only once.
param0 = {'x1': 0.4, 'y1': 0.5, 'r': 0.6}
z0 = np.array(list(param0.values()))
mu0 = model(param0)   # simulator output at the fiducial parameters
obs0 = noise(mu0)     # noisy version, passed to NestedRatios as `obs`

In [5]:
from swyft.nn import OnlineNormalizationLayer
from swyft.nn.module import Module

class Head(Module):
    """Convolutional feature extractor for the 2D ring images in ``obs['x']``.

    Three Conv2d(kernel 3) + MaxPool2d(2) stages map a single-channel image of shape
    ``image_shape`` to a flat feature vector of length ``self.n_features``
    (160 for the default 32x32 input).
    """

    # Spatial shape of obs['x'] this head is built for (the simulator's default grid).
    # NOTE(review): obs_shapes is forwarded to the base Module but not used to size the
    # network — confirm it agrees with this shape.
    image_shape = (32, 32)

    def __init__(self, obs_shapes, online_norm = True, obs_transform = None):
        super().__init__(obs_shapes=obs_shapes, obs_transform=obs_transform,
                online_norm=online_norm)
        self.obs_transform = obs_transform

        if online_norm:
            self.onl_f = OnlineNormalizationLayer(torch.Size(self.image_shape))
        else:
            # Same pass-through as a lambda, but picklable and a proper submodule.
            self.onl_f = torch.nn.Identity()
        self.conv1 = torch.nn.Conv2d(1, 10, 3)
        self.conv2 = torch.nn.Conv2d(10, 20, 3)
        self.conv3 = torch.nn.Conv2d(20, 40, 3)
        self.pool = torch.nn.MaxPool2d(2)

        # Infer the flattened feature size from a dummy pass through the conv stack
        # instead of hard-coding it. The dummy pass skips onl_f so its statistics are untouched.
        with torch.no_grad():
            dummy = torch.zeros(1, 1, *self.image_shape)
            self.n_features = self._conv_features(dummy).shape[-1]

    def _conv_features(self, x):
        """Apply the conv/pool stages to a (batch, 1, H, W) tensor; return (batch, features)."""
        # NOTE(review): there is no activation between conv layers; MaxPool is the only
        # nonlinearity. Consider adding ReLU (this would change the trained model).
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(conv(x))
        return x.flatten(start_dim=1)

    def forward(self, obs):
        """Map an observation dict to a (batch, n_features) feature tensor.

        Only ``obs['x']`` is used, after the optional ``obs_transform``. It is assumed
        to be a (batch, *image_shape) tensor — TODO confirm against the data pipeline.
        """
        if self.obs_transform is not None:
            obs = self.obs_transform(obs)
        x = obs['x']
        x = self.onl_f(x)
        x = x.unsqueeze(1)  # add the single input channel expected by conv1
        return self._conv_features(x)

In [6]:
# Wrap the ring `model` function in swyft's Simulator; this object is what
# swyft.NestedRatios receives as its simulator argument.
simulator = Simulator(model)

In [7]:
# Ratio estimation with swyft.NestedRatios on the mock observation obs0, using the
# convolutional Head as the network head and a two-step learning-rate schedule.
s = swyft.NestedRatios(
    simulator,
    prior,
    noise=noise,
    obs=obs0,
)
s.run(
    max_rounds=2,
    train_args={"lr_schedule": [1e-3, 1e-4]},
    head=Head,
)

Creating new cache.
  adding 2945 new samples to simulator cache.


  return torch._C._cuda_getDeviceCount() > 0


  adding 3100 new samples to simulator cache.
