In [None]:
import numpy as np
import pylab as plt
import swyft
import torch
from scipy import stats
%load_ext autoreload
%autoreload 2

# Seed NumPy's global RNG (prior draws, observation noise) and torch's RNG
# (network initialisation) so the whole notebook is reproducible.
SEED = 25
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Use the first GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs (more slowly) on machines without CUDA.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
NSAMPLES = 100000  # passed as `expected_n` to swyft.get_unit_intensity
MAX_EPOCHS = 100   # passed as `max_epochs` to the ratio-estimator training

In [None]:
def model(z, npix=32):
    """Toy simulator: map a parameter vector ``z`` to a smooth 2-D image.

    The image is evaluated on a regular ``npix x npix`` grid covering
    [-1, 1] x [-1, 1] and is a sum of cosine products, one term per
    consecutive triple ``(z[i], z[i+1], z[i+2])``.

    Args:
        z: 1-D sequence of parameters. Vectors with fewer than 3 entries
            produce an all-zero image.
        npix: number of grid points per axis (default 32, so callers that pass
            only ``z`` get a 32x32 image).

    Returns:
        Float array of shape ``(npix, npix)``.
    """
    # Third positional argument of linspace is the sample count; ``endpoint``
    # keeps its default (True), so both -1 and 1 are grid points.
    grid = np.linspace(-1, 1, npix)
    X, Y = np.meshgrid(grid, grid)
    mu = np.zeros_like(X)
    for i in range(len(z) - 2):
        # For i == 0 the spatial frequencies are multiplied by 0, so the first
        # term is the constant offset z[0] * cos(2 z[2]) * cos(2 z[1]).
        mu += (z[i]
               * np.cos(X * z[i + 1] * i + z[i + 2] * 2)
               * np.cos(Y * z[i + 2] * i + z[i + 1] * 2))
    return mu

def noisemodel(x, z=None, sigma=0.5):
    """Return ``x`` plus i.i.d. Gaussian noise of standard deviation ``sigma``.

    ``z`` is accepted but ignored; presumably the signature mirrors what swyft
    expects from a noise model (TODO confirm). Noise is drawn from NumPy's
    global RNG, so results depend on ``np.random.seed``.
    """
    noise = sigma * np.random.randn(*x.shape)
    return x + noise

In [None]:
# Convolutional network used as the head (feature extractor) of the inference network

class Head(torch.nn.Module):
    """Convolutional feature extractor for single-channel images.

    Three stages of ``Conv2d(kernel_size=3)`` (no padding) followed by
    ``MaxPool2d(2)`` map a batch of images of shape ``(batch, H, W)`` to flat
    feature vectors of shape ``(batch, features)``. For 32x32 inputs the
    spatial size goes 32 -> 30 -> 15 -> 13 -> 6 -> 4 -> 2, giving
    40 * 2 * 2 = 160 features.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept so parameter
        # initialisation and state_dict keys stay the same.
        self.conv1 = torch.nn.Conv2d(1, 10, 3)
        self.conv2 = torch.nn.Conv2d(10, 20, 3)
        self.conv3 = torch.nn.Conv2d(20, 40, 3)
        self.pool = torch.nn.MaxPool2d(2)

    def forward(self, x):
        """Map ``(batch, H, W)`` images to ``(batch, features)`` vectors."""
        # Add the channel axis Conv2d expects: (batch, H, W) -> (batch, 1, H, W).
        x = x.unsqueeze(1)
        # NOTE(review): there is no activation between stages; max-pooling is the
        # only non-linearity. Adding e.g. ReLU may help but changes results.
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(conv(x))
        # flatten(1) keeps the batch axis. Unlike view(len(x), -1) it also
        # works for an empty batch, where the -1 size would be ambiguous.
        return x.flatten(1)

# Build the feature extractor and move its parameters to the configured device.
head = Head()
head = head.to(DEVICE)

In [None]:
# Draw a "true" parameter vector from U[0, 1)^20, simulate a noisy observation
# x0 from it, and list each parameter index on its own (single-parameter
# combinations, later passed to the ratio estimator).
z0 = np.random.rand(20)
zdim = len(z0)
x0 = noisemodel(model(z0))
comb1d = [[i] for i in range(zdim)]

# Show the observation; the trailing `;` suppresses the return-value repr.
fig, ax = plt.subplots()
im = ax.imshow(x0)
fig.colorbar(im, ax=ax)
ax.set(title="Observed image $x_0$", xlabel="pixel x", ylabel="pixel y");

In [None]:
# In-memory simulation cache. The image shape is taken from the observation
# rather than hard-coded, so it always matches the grid produced by `model`.
cache = swyft.MemoryCache(zdim=zdim, xshape=x0.shape)

In [None]:
# Build a sampling intensity over the zdim-dimensional parameter space with
# NSAMPLES expected points, fill the cache from it, run the simulator, and
# train a ratio estimator on the single-parameter combinations in comb1d.
intensity = swyft.get_unit_intensity(expected_n=NSAMPLES, dim=zdim)
cache.grow(intensity)  # presumably adds points drawn from `intensity` to the cache — TODO confirm
cache.simulate(model)  # runs `model` on cached points (presumably only those not yet simulated)

# Training points combine cached simulations with `noisemodel` (presumably applied on the fly).
points = swyft.Points(cache, intensity, noisemodel)
# NOTE(review): the name `re` shadows the stdlib `re` module if it is ever imported here.
re = swyft.RatioEstimator(points, combinations=comb1d, head=head, device=DEVICE)
# lr_schedule halves at each entry (2e-3 ... 1.25e-4); presumably successive learning-rate stages.
re.train(max_epochs=MAX_EPOCHS, batch_size=32, lr_schedule=[2e-3, 1e-3, 5e-4, 2.5e-4, 1.25e-4])

In [None]:
# Plot the 1-D results of the trained estimator for the observation x0, with the
# true parameters z0 passed for reference.
swyft.plot1d(
    re,
    x0=x0,
    z0=z0,
    cmap='Greys',
    dims=(20, 10),
    ncol=5,
)