Skip to content

Commit

Permalink
Add fit scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
cweniger committed Feb 20, 2021
1 parent 5849e44 commit d10662c
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 61 deletions.
32 changes: 0 additions & 32 deletions dev/heads.py

This file was deleted.

36 changes: 8 additions & 28 deletions dev/fit.py → scripts/fit.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,29 @@
#!/usr/bin/env python3

import logging

logging.basicConfig(level=logging.DEBUG, format="%(message)s")

import heads
import numpy as np
import pylab as plt
import simulators
import torch
from omegaconf import OmegaConf

import swyft
from swyft_model import noise, prior, model, par0, obs0, CustomHead

DEVICE = "cuda:0"
CACHE_PATH = "cache_FermiV1.zarr"


def noise(obs, params=None, sigma=1.0):
data = {k: v + np.random.randn(*v.shape) * sigma for k, v in obs.items()}
return data


def main():
root = conf_cli = OmegaConf.from_cli().root
conf = OmegaConf.load(root + ".yaml")

# Define model and prior
prior = simulators.prior_FermiV1
model = simulators.model_FermiV1

# Target observation
par0 = dict(ox=5.0, oy=5.0, a=1.5, p1=0.4, p2=1.1)
obs0 = noise(model(par0))
conf = OmegaConf.load("swyft_config.yaml")

params = par0.keys()
obs_shapes = {k: v.shape for k, v in obs0.items()}

cache = swyft.DirectoryCache(params, obs_shapes=obs_shapes, path=CACHE_PATH)
cache = swyft.DirectoryCache(params, obs_shapes=obs_shapes, path=conf.cache)

s = swyft.NestedRatios(
model,
prior,
noise=None,
noise=noise,
obs=obs0,
device=DEVICE,
Ninit=conf.Ninit,
Expand All @@ -53,12 +34,12 @@ def main():
s.run(
max_rounds=conf.max_rounds,
train_args=conf.train_args,
head=heads.Head_FermiV1,
head=CustomHead,
tail_args=conf.tail_args,
head_args=conf.head_args,
)

samples = s.marginals(obs0, 10000)
samples = s.marginals(obs0, 3000)
swyft.plot.plot1d(
samples,
list(prior.params()),
Expand All @@ -67,12 +48,11 @@ def main():
grid_interpolate=True,
truth=par0,
)
plt.savefig("%s.marginals.pdf" % root)
plt.savefig("marginals.pdf")

diagnostics = swyft.utils.sample_diagnostics(samples)
state_dict = {"NestedRatios": s.state_dict(), "diagnostics": diagnostics}
torch.save(state_dict, "%s.diags.pt" % root)

torch.save(state_dict, "sample_diagnostics.pt")

if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion dev/base.yaml → scripts/swyft_config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
run_args: {}
train_args:
lr_schedule: [1e-3]
tail_args: {}
head_args: {}
Ninit: 500
max_rounds: 1
cache: cache.zarr
69 changes: 69 additions & 0 deletions scripts/swyft_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import numpy as np
import torch
import swyft

prior = swyft.Prior(
{
"ox": ["uniform", 0.0, 10.0],
"oy": ["uniform", 0.0, 10.0],
"a": ["uniform", 1.0, 2.0],
"p1": ["uniform", 0.0, 0.5],
"p2": ["uniform", 1.0, 2.0],
}
)

def simulator(a, ox, oy, p1, p2, sigma=0.1):
"""Some examplary image simulator."""
x = np.linspace(-5, 5, 50, 50)
X, Y = np.meshgrid(x, x)

diff = np.cos(X + ox) * np.cos(Y + oy) * a + 2

p = np.random.randn(*X.shape) * p1 - 0.3
psc = 10 ** p * p2
n = np.random.randn(*X.shape) * sigma
mu = diff * 5 + psc + n
return mu

def model(params):
"""Model wrapper around simulator code."""
mu = simulator(
params["a"], params["ox"], params["oy"], params["p1"], params["p2"]
)
return dict(mu=mu)

def noise(obs, params=None, sigma=1.0):
"""Associated noise model."""
data = {k: v + np.random.randn(*v.shape) * sigma for k, v in obs.items()}
return data

class CustomHead(swyft.Module):
def __init__(self, obs_shapes):
super().__init__(obs_shapes=obs_shapes)

self.n_features = 10

self.conv1 = torch.nn.Conv2d(1, 10, 5)
self.conv2 = torch.nn.Conv2d(10, 20, 5)
self.conv3 = torch.nn.Conv2d(20, 40, 5)
self.pool = torch.nn.MaxPool2d(2)
self.l = torch.nn.Linear(160, 10)

def forward(self, obs):
x = obs["mu"].unsqueeze(1)
nbatch = len(x)
# x = torch.log(0.1+x)

x = self.conv1(x)
x = self.pool(x)
x = self.conv2(x)
x = self.pool(x)
x = self.conv3(x)
x = self.pool(x)
x = x.view(nbatch, -1)
x = self.l(x)

return x

par0 = dict(ox=5.0, oy=5.0, a=1.5, p1=0.4, p2=1.1)
obs0 = noise(model(par0))

0 comments on commit d10662c

Please sign in to comment.