In [51]:
import numpy as np
import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt
import relaxed
import functools
from fast_soft_sort.jax_ops import soft_sort

np.random.seed(0)
nBg = 8000
nSig = 300
background = np.random.normal(40, 10, nBg)
signal = np.random.normal(50, 5, nSig)


def significance(S, B):
    """calculate the significance according to the formula above"""
    return jnp.sqrt(2 * ((S + B) * jnp.log(1 + S / B) - S))


def pipeline(pars, data):
    s, b = data
    bins = pars
    sig_hist = relaxed.hist(s, bins=bins, bandwidth=1e-1)
    bg_hist = relaxed.hist(b, bins=bins, bandwidth=1e-1)
    sig = significance(sig_hist, bg_hist)
    return 1 / jnp.nanmean(sig), sig


def new_sig(s, b):
    n = s+b
    print(n)
    mu_hat = jnp.mean(n)**.5
    print(n*(jnp.log((mu_hat*s + b)/b))- mu_hat*s)
    q0 = 2*jnp.sum(n*(jnp.log((mu_hat*s + b)/b)) - mu_hat*s)
    return q0#**0.5

def pipeline2(pars, data):
    s, b = data
    bins = pars
    sig_hist = relaxed.hist(s, bins=bins, bandwidth=1e-1)
    bg_hist = relaxed.hist(b, bins=bins, bandwidth=1e-1)
    sig = new_sig(sig_hist, bg_hist)
    return 1 / sig, sig


pipe = functools.partial(pipeline2, data=(signal, background))
init = jnp.linspace(0, 70, 8)
pipe(init)


[   7.839876  163.8499   1095.9736   2814.2742   2863.311    1176.6823
  169.69638 ]
[    0.            0.            0.           -7.1432953 -2019.7014
 -2883.5713      -40.315796 ]


(DeviceArray(-0.000101, dtype=float32), DeviceArray(-9901.463, dtype=float32))

In [54]:
import optax
import celluloid
import matplotlib.pyplot as plt

plt.rc("figure", figsize=(10, 10), dpi=100, facecolor="white")

fig, ax = plt.subplots(1, 1)
ax_copy = ax
camera = celluloid.Camera(fig)

def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
    opt_state = optimizer.init(params)

    @jax.jit
    def step(params, opt_state):
        grads, loss_value = jax.grad(pipe, has_aux=True)(params)
        updates, opt_state = optimizer.update(jnp.nan_to_num(grads), opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss_value

    for i in range(200):
        params, opt_state, loss_value = step(params, opt_state)
        ax.hist([background, signal], bins=params, stacked=True, label=["background B", "signal S"], color=["C0", "C1"])
        ax.set_title(f"iteration {i}, significances: {loss_value}")
        camera.snap()
        print(f"step {i}, loss: {loss_value}")

    return params


# Finally, we can fit our parametrized function using the Adam optimizer
# provided by optax.
optimizer = optax.adam(learning_rate=1e-2)
params = fit(init, optimizer)

# animate
animation = camera.animate()
animation.save("animation2.gif")

Traced<ShapedArray(float32[7])>with<JVPTrace(level=2/1)> with
  primal = Traced<ShapedArray(float32[7])>with<DynamicJaxprTrace(level=0/1)>
  tangent = Traced<ShapedArray(float32[7])>with<JaxprTrace(level=1/1)> with
    pval = (ShapedArray(float32[7]), *)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7fc911d4fd40>, invars=(Traced<ShapedArray(float32[7]):JaxprTrace(level=1/1)>, Traced<ShapedArray(float32[7]):JaxprTrace(level=1/1)>), outvars=[<weakref at 0x7fc910a38e00; to 'JaxprTracer' at 0x7fc910a38ef0>], primitive=xla_call, params={'device': None, 'backend': None, 'name': 'jvp(fn)', 'donated_invars': (False, False), 'inline': True, 'call_jaxpr': { lambda ; a:f32[7] b:f32[7]. let c:f32[7] = add a b in (c,) }}, source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7fc910a19770>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Traced<ShapedArray(float32[7])>with<JVPTrace(level=2/1)> with
  primal = Traced<ShapedArray(float32[7])>with<DynamicJaxprTrac

In [37]:
jnp.log(1.4)

DeviceArray(0.3364722, dtype=float32, weak_type=True)