Skip to content

Commit

Permalink
moved network file into folder nn. blackened scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
bkmi committed Feb 22, 2021
1 parent 6876dd8 commit 666ea00
Show file tree
Hide file tree
Showing 18 changed files with 546 additions and 586 deletions.
4 changes: 2 additions & 2 deletions notebooks/Tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -852,7 +852,7 @@
"metadata": {},
"outputs": [],
"source": [
"from swyft.network import OnlineNormalizationLayer\n",
"from swyft.nn import OnlineNormalizationLayer\n",
"from swyft.utils import Module"
]
},
Expand Down Expand Up @@ -1276,4 +1276,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}
10 changes: 7 additions & 3 deletions scripts/conf/definitions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import torch

import swyft

prior = swyft.Prior(
Expand All @@ -12,6 +13,7 @@
}
)


def simulator(a, ox, oy, p1, p2, sigma=0.1):
"""Some examplary image simulator."""
x = np.linspace(-5, 5, 50, 50)
Expand All @@ -25,18 +27,19 @@ def simulator(a, ox, oy, p1, p2, sigma=0.1):
mu = diff * 5 + psc + n
return mu


def model(params):
"""Model wrapper around simulator code."""
mu = simulator(
params["a"], params["ox"], params["oy"], params["p1"], params["p2"]
)
mu = simulator(params["a"], params["ox"], params["oy"], params["p1"], params["p2"])
return dict(mu=mu)


def noise(obs, params=None, sigma=1.0):
"""Associated noise model."""
data = {k: v + np.random.randn(*v.shape) * sigma for k, v in obs.items()}
return data


class CustomHead(swyft.Module):
def __init__(self, obs_shapes):
super().__init__(obs_shapes=obs_shapes)
Expand Down Expand Up @@ -65,5 +68,6 @@ def forward(self, obs):

return x


par0 = dict(ox=5.0, oy=5.0, a=1.5, p1=0.4, p2=1.1)
obs0 = noise(model(par0))
8 changes: 6 additions & 2 deletions scripts/run_swyft.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
# Requires python 3.5+

import logging

logging.basicConfig(level=logging.DEBUG, format="%(message)s")

import os
import importlib.util
import os

import numpy as np
import pylab as plt
import torch
Expand All @@ -15,11 +17,12 @@

DEVICE = "cuda:0"


def main():
# Pretty hacky way to import local model
# https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
cwd = os.getcwd()
spec = importlib.util.spec_from_file_location("defs", cwd+"/definitions.py")
spec = importlib.util.spec_from_file_location("defs", cwd + "/definitions.py")
defs = importlib.util.module_from_spec(spec)
spec.loader.exec_module(defs)

Expand Down Expand Up @@ -71,5 +74,6 @@ def main():
state_dict = {"NestedRatios": s.state_dict(), "diagnostics": diagnostics}
torch.save(state_dict, "sample_diagnostics.pt")


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion swyft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .estimation import Points, RatioEstimator
from .intensity import Prior
from .interface import Marginals, NestedRatios
from .network import DefaultHead, DefaultTail, OnlineNormalizationLayer
from .nn import DefaultHead, DefaultTail, OnlineNormalizationLayer
from .plot import corner, plot1d
from .utils import Module, format_param_list, set_verbosity

Expand Down
32 changes: 5 additions & 27 deletions swyft/estimation.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,17 @@
# pylint: disable=no-member, not-callable
from copy import deepcopy
from warnings import warn

import numpy as np
import torch
import torch.nn as nn
from scipy.integrate import trapz
from scipy.special import xlogy

from .cache import Dataset, Normalize
from .network import DefaultHead, DefaultTail
from .cache import Dataset
from .nn import DefaultHead, DefaultTail
from .train import trainloop
from .types import (
Array,
Callable,
Combinations,
Device,
Dict,
Optional,
PathType,
Sequence,
Tuple,
Union,
)
from .types import Array, Device, Optional, Sequence, Tuple
from .utils import (
Module,
array_to_tensor,
dict_to_device,
dict_to_tensor,
dict_to_tensor_unsqueeze,
format_param_list,
get_obs_shapes,
process_combinations,
tobytes,
verbosity,
)


Expand Down Expand Up @@ -211,13 +189,13 @@ def posterior(self, obs0, prior, n_samples=100000):
pars = prior.sample(n_samples) # prior samples

# Unmasked original wrongly normalized log_prob densities
log_probs = prior.log_prob(pars, unmasked = True)
log_probs = prior.log_prob(pars, unmasked=True)

lnL = self.lnL(obs0, pars) # evaluate lnL for reference observation
weights = {}
for k, v in lnL.items():
weights[k] = np.exp(v)
return dict(params=pars, weights=weights, log_priors = log_probs)
return dict(params=pars, weights=weights, log_priors=log_probs)


class Points:
Expand Down
4 changes: 2 additions & 2 deletions swyft/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

logging.basicConfig(level=logging.DEBUG, format="%(message)s")

from .cache import DirectoryCache, MemoryCache
from .cache import MemoryCache
from .estimation import Points, RatioEstimator
from .intensity import Prior
from .network import DefaultHead, DefaultTail
from .nn import DefaultHead, DefaultTail
from .types import Dict
from .utils import all_finite, format_param_list

Expand Down
Loading

0 comments on commit 666ea00

Please sign in to comment.