# Test Priors

In [1]:
import numpy as np
from delight.priors import *
from scipy.misc import derivative
from delight.utils import derivative_test

In [2]:
class SimpleChildModel(Model):
    """Leaf model for the hierarchy test: no children, one parameter ``a``."""

    def __init__(self):
        # A leaf has no nested sub-models.
        self.children = list()
        # Single parameter 'a', initialised to 2.0.
        self.params = OrderedDict([('a', 2.0)])


class SimpleParentModel(Model):
    """Middle level of the test hierarchy: one child, parameters ``b``, ``e``."""

    def __init__(self):
        # One nested leaf model.
        self.children = [SimpleChildModel()]
        # Build from explicit (key, value) pairs so the parameter order
        # ('b' first, then 'e') does not depend on dict-literal ordering.
        self.params = OrderedDict([('b', 3.0), ('e', 2.0)])


class SimpleGrandParentModel(Model):
    """Top level of the test hierarchy: wraps one parent, parameter ``c``."""

    def __init__(self):
        # The single child is itself a model that owns a child.
        parent = SimpleParentModel()
        self.children = [parent]
        # Single parameter 'c', initialised to 4.0.
        self.params = OrderedDict([('c', 4.0)])

In [3]:
def test_SimpleModel():
    """Round-trip an all-zero parameter vector through the nested test models."""
    model = SimpleGrandParentModel()
    n_params = model.numparams()
    zeros = [0 for _ in range(n_params)]
    # Whatever is pushed in with set() must come back unchanged from get().
    model.set(zeros)
    fetched = model.get()
    assert fetched == zeros

In [4]:
# Run the check right away; a failed assert raises AssertionError here.
test_SimpleModel()

In [5]:
def test_RayleighRedshiftDistr():
    """Check RayleighRedshiftDistr's parameter round-trip and its value at one z.

    The expected value is the closed form
    ``z * exp(-0.5 * z**2 / alpha**2) / alpha**2`` evaluated at
    ``z = alpha = 2``.
    """
    mod = RayleighRedshiftDistr()
    alpha = 2.0
    mod.set([alpha])
    # No arithmetic happens between set() and get(), so exact equality is fine.
    assert mod.get() == [alpha]
    z = 2.0
    res = z * np.exp(-0.5 * z**2 / alpha**2) / alpha**2
    # Compare with a tight relative tolerance instead of ``==``: an
    # algebraically equivalent implementation may round differently.
    assert np.allclose(mod(z), res, rtol=1e-12, atol=0.0)

In [6]:
# Run the check right away; a failed assert raises AssertionError here.
test_RayleighRedshiftDistr()

In [7]:
def test_powerLawLuminosityFct():
    """Check powerLawLuminosityFct's parameters, one value and its gradient.

    Verifies that set()/get() round-trip the parameter vector, that the
    model evaluated at ``mod.ellStar`` equals ``exp(-1)``, and hands the
    model value and ``mod.jac`` to ``derivative_test`` (imported helper,
    presumably an analytic-vs-numerical gradient comparison).
    """
    mod = powerLawLuminosityFct()
    theta = np.array([-1.2])
    mod.set(theta)
    # allclose instead of ``==``: asserting on an elementwise comparison
    # only works for single-element arrays and raises for longer ones.
    assert np.allclose(mod.get(), theta, rtol=1e-12, atol=0.0)
    z = 4.0
    res = 1 / np.exp(1.0)
    # Tight tolerance rather than exact float equality on a computed value.
    assert np.allclose(mod(z, mod.ellStar), res, rtol=1e-12, atol=0.0)
    ell = 1.1*mod.ellStar

    def prob(alpha):
        # Model value at (z, ell) for the trial parameter vector.
        mod.set(alpha)
        return mod(z, ell)

    def prob_grad(alpha):
        # mod.jac at (z, ell), wrapped in a length-1 array.
        mod.set(alpha)
        return np.array([mod.jac(z, ell)])

    relative_accuracy = 0.01
    derivative_test(theta, prob, prob_grad, relative_accuracy)

In [8]:
# Run the check right away; a failed assert raises AssertionError here.
test_powerLawLuminosityFct()

In [9]:
def test_MultiTypePopulationPrior():
    """Exercise MultiTypePopulationPrior: parameters, grid, draws, gradients.

    Steps:
      1. set()/get() round-trip of an all-zero parameter vector.
      2. ``mod.grid`` has leading dimensions (numTypes, nz, nl) and agrees
         with calling the model point by point on the same mesh.
      3. ``mod.draw`` returns non-negative types, redshifts, luminosities.
      4. ``gridflat_grad`` is passed with ``gridflat`` to ``derivative_test``
         for the first ten flattened grid entries of every type, and for
         the sum over the whole grid.
    """
    from copy import deepcopy

    numTypes, nz, nl = 3, 50, 50
    mod = MultiTypePopulationPrior(numTypes)
    # NOTE(review): presumably (numTypes - 1) free type fractions plus one
    # shared parameter -- confirm against MultiTypePopulationPrior.
    ntot = numTypes * 1 - 1 + 1
    assert mod.numparams() == ntot
    theta = [0]*ntot
    mod.set(theta)
    assert mod.get() == theta

    mod = MultiTypePopulationPrior(numTypes, maglim=24)
    print(mod.get())
    theta = np.array(mod.get())
    redshifts = np.linspace(1e-2, 2, nz)
    luminosities = np.linspace(1e7, 1e9, nl)
    grid = mod.grid(redshifts, luminosities)
    assert grid.shape[0] == numTypes
    assert grid.shape[1] == nz
    assert grid.shape[2] == nl

    # The mesh does not depend on the type index, so build it once.
    zz, ll = np.meshgrid(redshifts, luminosities, indexing='ij')
    grid2 = 0*grid
    for i in range(numTypes):
        types = np.repeat(i, zz.ravel().size)
        grid2[i, :, :] = mod(types, zz.ravel(), ll.ravel()).reshape(zz.shape)
    assert np.allclose(grid, grid2)

    types2, redshifts2, luminosities2 = mod.draw(100, redshifts, luminosities)
    assert np.all(types2 >= 0)
    assert np.all(redshifts2 >= 0)
    assert np.all(luminosities2 >= 0)

    def model_at(x):
        # Work on a copy so trial parameters never leak into ``mod``.
        mod2 = deepcopy(mod)
        mod2.set(x)
        return mod2

    relative_accuracy = 0.01

    for it in range(numTypes):
        for i in range(10):
            # These closures read ``it`` and ``i`` and are consumed inside
            # this iteration, so late binding is not an issue.
            def prob(x):
                return model_at(x).gridflat(redshifts, luminosities)[it, i]

            def prob_grad(x):
                return model_at(x).gridflat_grad(
                    redshifts, luminosities)[:, it, i]

            derivative_test(theta, prob, prob_grad, relative_accuracy,
                            dxfac=1e-2, order=15, lim=1e-4, superverbose=True)

    # Aggregate check over the whole grid. The original evaluated ``mod``
    # here instead of the freshly configured copy, so the trial parameters
    # ``x`` were silently ignored.
    def prob(x):
        return np.sum(model_at(x).gridflat(redshifts, luminosities))

    def prob_grad(x):
        return np.sum(model_at(x).gridflat_grad(redshifts, luminosities),
                      axis=(1, 2))

    print(prob_grad(theta))
    derivative_test(theta, prob, prob_grad, relative_accuracy,
                    dxfac=1e-1, order=15, lim=1e6, superverbose=True)


In [10]:
# Run the check right away; it prints the model parameters and, presumably
# via superverbose=True, the gradient comparisons captured below.
test_MultiTypePopulationPrior()

[0.5, 0.5, -1.2]
0 analytical: -475503.0685275375 numerical: -475503.0685275386
1 analytical: 0.0 numerical: 2.582467573120084e-09
2 analytical: -547443.138632217 numerical: -547443.1386322211
0 analytical: -38956245.54569018 numerical: -38956245.545690045
1 analytical: 0.0 numerical: -2.5609097065171227e-07
2 analytical: -23319074.522364516 numerical: -23319074.522364367
0 analytical: -164501805.23014805 numerical: -164501805.23014754
1 analytical: 0.0 numerical: 1.4268152881413698e-06
2 analytical: -56343271.0374275 numerical: -56343271.03742717
0 analytical: -364019476.7346705 numerical: -364019476.7346708
1 analytical: 0.0 numerical: -2.753404260147363e-06
2 analytical: -63333312.74727201 numerical: -63333312.74727404
0 analytical: -616337170.2576281 numerical: -616337170.2576311
1 analytical: 0.0 numerical: 6.076152203604579e-06
2 analytical: -29686221.54544626 numerical: -29686221.545447167
0 analytical: -892297925.3708895 numerical: -892297925.3709092
1 analytical: 0.0 numerical