In [1]:
import sys, os
sys.path.insert(0, "/home/storage/hans/jax_reco_new")
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

import jax.numpy as jnp
from jax.scipy import optimize
import jax
jax.config.update("jax_enable_x64", True)
import optimistix as optx

from jax.scipy.special import gammaincc, erf

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from lib.smallest_network_eqx import get_network_eval_fn
from lib.trafos import transform_network_outputs, transform_network_inputs
from lib.plotting import adjust_plot_1d

from dom_track_eval import get_eval_network_doms_and_track

In [2]:
def log1mexp(x):
    return jnp.where(
            x < jnp.log(0.5), jnp.log1p(-jnp.exp(x)), jnp.log(-jnp.expm1(x))
    )

In [3]:
w = 1.e-30
print(jnp.log(w), jnp.log(0.5))

-69.07755278982137 -0.6931471805599453


In [4]:
print(1-w)

1.0


In [5]:
print(log1mexp(jnp.log(w)), jnp.exp(log1mexp(jnp.log(w))))

-1.0000000000000024e-30 1.0


In [6]:
'''
import functools

@functools.partial(jax.custom_jvp, nondiff_argnums=(1, 2, 4))
def logsumexp(  # noqa: D103
    mat, axis=None, keepdims=False, b=None, return_sign=False
):
  return jax.scipy.special.logsumexp(
      mat, axis=axis, keepdims=keepdims, b=b, return_sign=return_sign
  )


@logsumexp.defjvp
def logsumexp_jvp(axis, keepdims, return_sign, primals, tangents):
  """Custom derivative rule for lse that does not blow up with -inf.

  This logsumexp implementation uses the standard jax one in forward mode but
  implements a custom rule to differentiate. Given the preference of jax for
  jvp over vjp, and the fact that this is a simple linear rule, jvp is used.
  This custom differentiation address issues when the output of lse is
  -inf (which corresponds to the case where all inputs in a slice are -inf,
  which happens typically when ``a`` or ``b`` weight vectors have zeros.)

  Although both exp(lse) and its derivative should be 0, automatic
  differentiation returns a NaN derivative because of a -inf - (-inf) operation
  appearing in the definition of centered_exp below. This is corrected in the
  implementation below.

  Args:
    axis: argument from original logsumexp
    keepdims: argument from original logsumexp
    return_sign: argument from original logsumexp
    primals: mat and b, the two arguments against which we differentiate.
    tangents: of same size as mat and b.

  Returns:
    original primal outputs + their tangent.
  """  # noqa: D401
  mat, b = primals
  tan_mat, tan_b = tangents
  lse = logsumexp(mat, axis, keepdims, b, return_sign)
  if return_sign:
    lse, sign = lse
  lse = jnp.where(jnp.isfinite(lse), lse, 0.0)

  if axis is not None:
    centered_exp = jnp.exp(mat - jnp.expand_dims(lse, axis=axis))
  else:
    centered_exp = jnp.exp(mat - lse)

  if b is None:
    res = jnp.sum(centered_exp * tan_mat, axis=axis, keepdims=keepdims)
  else:
    res = jnp.sum(b * centered_exp * tan_mat, axis=axis, keepdims=keepdims)
    res += jnp.sum(tan_b * centered_exp, axis=axis, keepdims=keepdims)
  if return_sign:
    return (lse, sign), (sign * res, jnp.zeros_like(sign))
  return lse, res
'''

'\nimport functools\n\n@functools.partial(jax.custom_jvp, nondiff_argnums=(1, 2, 4))\ndef logsumexp(  # noqa: D103\n    mat, axis=None, keepdims=False, b=None, return_sign=False\n):\n  return jax.scipy.special.logsumexp(\n      mat, axis=axis, keepdims=keepdims, b=b, return_sign=return_sign\n  )\n\n\n@logsumexp.defjvp\ndef logsumexp_jvp(axis, keepdims, return_sign, primals, tangents):\n  """Custom derivative rule for lse that does not blow up with -inf.\n\n  This logsumexp implementation uses the standard jax one in forward mode but\n  implements a custom rule to differentiate. Given the preference of jax for\n  jvp over vjp, and the fact that this is a simple linear rule, jvp is used.\n  This custom differentiation address issues when the output of lse is\n  -inf (which corresponds to the case where all inputs in a slice are -inf,\n  which happens typically when ``a`` or ``b`` weight vectors have zeros.)\n\n  Although both exp(lse) and its derivative should be 0, automatic\n  differen

In [7]:
from jax.scipy.special import logsumexp
from lib.gamma_sf_approx import c_coeffs

In [8]:
def W(x, c):
    x_t = c[1] * (x-c[2])
    return 0.5 + 0.5 * jnp.tanh(x_t)

def R2(x, a, c):
    r2 = W(x, c) * (1.0-jnp.power(c[3], -x))
    return jax.scipy.special.gamma(a) * r2

def R1(x, a, c):
    w = W(x, c)
    return  (1-w) *jnp.exp(-x)*jnp.power(x, a)*(1.0/a + 
                                                 c[0]*x/(a*(a+1)) + 
                                                 (c[0]*x)**2/(a*(a+1)*(a+2)))

In [9]:
def logW(x, c):
    x = jnp.clip(c[1] * (x-c[2]), min=-15, max=15)
    _x = jnp.concatenate([jnp.expand_dims(x, axis=0), jnp.expand_dims(-x, axis=0)],
                         axis=0)
    
    return x - logsumexp(_x, 0)

def logR2(x, a, c):
    return jax.scipy.special.gammaln(a) + logW(x, c) + log1mexp(-x * jnp.log(c[3]))

def logR1(x, a, c):
    l1mw = log1mexp(logW(x, c))
    return l1mw - x + a*jnp.log(x) + jnp.log((1.0/a + 
                                                 c[0]*x/(a*(a+1)) + 
                                                 (c[0]*x)**2/(a*(a+1)*(a+2))))

In [10]:
a = 2
b = 1.e-2
c = c_coeffs(a)
x = 100

In [11]:
print(W(x, c))
print(1-W(x, c))

1.0
0.0


In [12]:
lw = logW(x, c) # log W
print(lw)
l1mw = log1mexp(lw) # log(1-W)
print(l1mw)

-9.414691248821327e-14
-29.993919933885245


In [13]:
print(jnp.exp(lw))
print(jnp.exp(l1mw))

0.9999999999999059
9.414691248820891e-14


In [14]:
print(R2(x*b, a, c))
print(logR2(x*b, a, c), jnp.exp(logR2(x*b, a, c)))


0.036919490393103556
-3.2990156724464432 0.03691949039310357


In [15]:
print(R1(x*b, a, c))

0.2363586184813647


In [16]:
print(logR1(x*b, a, c), jnp.exp(logR1(x*b, a, c)))

-1.442405057498202 0.2363586184813647


In [17]:
from lib.gamma_sf_approx import log_gamma_cdf_fast, gamma_cdf_fast, log_gamma_sf_fast

In [18]:
print(gamma_cdf_fast(x, a, b))

0.2732781088744679


In [19]:
print(log_gamma_cdf_fast(x, a, b), jnp.exp(log_gamma_cdf_fast(x, a, b)))

-1.2972652885080862 0.273278108874468


In [20]:
eval_network = get_network_eval_fn(bpath='/home/storage/hans/photondata/large_table_training/all_time_bins/w_penalty_b5/cache/test_penalties_tree_start_epoch_160', dtype=jnp.float64)

dist = 1
z = -500
rho = 0.0
zenith = np.pi/2
azimuth = 0.0

x = jnp.array([dist, rho, z, zenith, azimuth])
x_prime = transform_network_inputs(x)
y = eval_network(x_prime)
logits, gamma_as, gamma_bs = transform_network_outputs(y)
mix_probs = jax.nn.softmax(logits)
log_mix_probs = jnp.log(mix_probs)
# index 1 is main component
g_a = gamma_as
g_b = gamma_bs
sigma = 3

print(gamma_as, gamma_bs, mix_probs)

[1.00002936 1.00059187 1.67519907] [3.72481150e-03 1.31571253e-01 4.99725565e+00] [0.01525278 0.04859376 0.93615346]


In [21]:
from lib.c_mpe_gamma import c_multi_gamma_mpe_logprob_midpoint2
from jax.scipy.stats.norm import logpdf as norm_logpdf
from jax.scipy.stats.gamma import logpdf as gamma_logpdf

def c_multi_gamma_mpe_logprob_midpoint2_stable(x, log_mix_probs, a, b, n, sigma=3.0):
    """
    Q < 30
    """
    nmax = 10
    nint1 = 10
    nint2 = 15
    nint3 = 35
    #eps = 1.e-12
    eps = 1.e-6

    x0 = eps
    x_m0 = 0.01
    xvals0 = jnp.linspace(x0, x_m0, 10)[:-1]

    x_m1 = 0.05
    xvals1 = jnp.linspace(x_m0, x_m1, 10)[:-1]

    x_m2 = 0.25
    xvals2 = jnp.linspace(x_m1, x_m2, 10)[:-1]

    x_m25 = 0.75
    xvals25 = jnp.linspace(x_m2, x_m25, 10)[:-1]

    x_m3 = 2.5
    xvals3 = jnp.linspace(x_m25, x_m3, 10)[:-1]

    x_m4 = 8.0
    xvals4 = jnp.linspace(x_m3, x_m4, 20)

    xmin = jnp.max(jnp.array([1.5 * eps, x - 10 * sigma]))
    xmax = jnp.max(jnp.array([xmin+1.5*eps, x + 10 * sigma]))
    xvals_x = jnp.linspace(xmin, xmax, 101)
    xvals = jnp.sort(jnp.concatenate([xvals0, xvals1, xvals2, xvals25, xvals3, xvals4, xvals_x]))

    dx = xvals[1:]-xvals[:-1]

    xvals = 0.5*(xvals[:-1]+xvals[1:])
    log_n_pdf = norm_logpdf(xvals, loc=x, scale=sigma)

    a_e = jnp.expand_dims(a, axis=-1)
    b_e = jnp.expand_dims(b, axis=-1)
    log_mix_probs_e = jnp.expand_dims(log_mix_probs, axis=-1)

    xvals_e = jnp.expand_dims(xvals, axis=0)
    log_pdfs = logsumexp(gamma_logpdf(xvals_e, a_e, scale=1./b_e) + log_mix_probs_e, 0)
    log_sfs = logsumexp(log_gamma_sf_fast(xvals_e, a_e, b_e) + log_mix_probs_e, 0)
 
    print(dx)
    return logsumexp(log_n_pdf + log_pdfs + (n-1) * log_sfs + jnp.log(dx) + jnp.log(n), 0)

In [22]:
n = 10
x = 200
print(c_multi_gamma_mpe_logprob_midpoint2(x, log_mix_probs, g_a, g_b, n, 3.0))
print(c_multi_gamma_mpe_logprob_midpoint2_stable(x, log_mix_probs, g_a, g_b, n, 3.0))

-52.38723396862468
[1.11100000e-03 1.11100000e-03 1.11100000e-03 1.11100000e-03
 1.11100000e-03 1.11100000e-03 1.11100000e-03 1.11100000e-03
 1.11100000e-03 4.44444444e-03 4.44444444e-03 4.44444444e-03
 4.44444444e-03 4.44444444e-03 4.44444444e-03 4.44444444e-03
 4.44444444e-03 4.44444444e-03 2.22222222e-02 2.22222222e-02
 2.22222222e-02 2.22222222e-02 2.22222222e-02 2.22222222e-02
 2.22222222e-02 2.22222222e-02 2.22222222e-02 5.55555556e-02
 5.55555556e-02 5.55555556e-02 5.55555556e-02 5.55555556e-02
 5.55555556e-02 5.55555556e-02 5.55555556e-02 5.55555556e-02
 1.94444444e-01 1.94444444e-01 1.94444444e-01 1.94444444e-01
 1.94444444e-01 1.94444444e-01 1.94444444e-01 1.94444444e-01
 1.94444444e-01 2.89473684e-01 2.89473684e-01 2.89473684e-01
 2.89473684e-01 2.89473684e-01 2.89473684e-01 2.89473684e-01
 2.89473684e-01 2.89473684e-01 2.89473684e-01 2.89473684e-01
 2.89473684e-01 2.89473684e-01 2.89473684e-01 2.89473684e-01
 2.89473684e-01 2.89473684e-01 2.89473684e-01 2.89473684e-01
 1.62

In [23]:
df0 = jax.grad(c_multi_gamma_mpe_logprob_midpoint2, argnums=2)
df1 = jax.grad(c_multi_gamma_mpe_logprob_midpoint2_stable, argnums=3)

In [24]:
df0(x, log_mix_probs, g_a, g_b, n, 3.0)

Array([9.01024387e+00, 7.97495881e-09, 0.00000000e+00], dtype=float64)

In [25]:
df1(x, log_mix_probs, g_a, g_b, n, 3.0)

[1.11100000e-03 1.11100000e-03 1.11100000e-03 1.11100000e-03
 1.11100000e-03 1.11100000e-03 1.11100000e-03 1.11100000e-03
 1.11100000e-03 4.44444444e-03 4.44444444e-03 4.44444444e-03
 4.44444444e-03 4.44444444e-03 4.44444444e-03 4.44444444e-03
 4.44444444e-03 4.44444444e-03 2.22222222e-02 2.22222222e-02
 2.22222222e-02 2.22222222e-02 2.22222222e-02 2.22222222e-02
 2.22222222e-02 2.22222222e-02 2.22222222e-02 5.55555556e-02
 5.55555556e-02 5.55555556e-02 5.55555556e-02 5.55555556e-02
 5.55555556e-02 5.55555556e-02 5.55555556e-02 5.55555556e-02
 1.94444444e-01 1.94444444e-01 1.94444444e-01 1.94444444e-01
 1.94444444e-01 1.94444444e-01 1.94444444e-01 1.94444444e-01
 1.94444444e-01 2.89473684e-01 2.89473684e-01 2.89473684e-01
 2.89473684e-01 2.89473684e-01 2.89473684e-01 2.89473684e-01
 2.89473684e-01 2.89473684e-01 2.89473684e-01 2.89473684e-01
 2.89473684e-01 2.89473684e-01 2.89473684e-01 2.89473684e-01
 2.89473684e-01 2.89473684e-01 2.89473684e-01 2.89473684e-01
 1.62000000e+02 6.000000

Array([-1.60463170e+003, -2.13430376e-007, -1.10861642e-276], dtype=float64)

In [26]:
g_a

Array([1.00002936, 1.00059187, 1.67519907], dtype=float64)

In [27]:
g_b

Array([3.72481150e-03, 1.31571253e-01, 4.99725565e+00], dtype=float64)

In [28]:
def logR1(x, a, c):
    l1mw = log1mexp(logW(x, c))
    return l1mw - x + a*jnp.log(x) + jnp.log((1.0/a + 
                                                 c[0]*x/(a*(a+1)) + 
                                                 (c[0]*x)**2/(a*(a+1)*(a+2))))

In [29]:
c = c_coeffs(a)
logR1(x*g_b, g_a, c)

Array([-7.25585506e-01, -4.82675146e+01, -1.00695792e+03], dtype=float64)

In [30]:
f = lambda x, y, z: jnp.sum(logR1(x, y, z))
jax.grad(f, argnums=1)(x*g_b, g_a, c)

Array([-1.46604227,  1.48003319,  5.66541929], dtype=float64)

In [31]:
from lib.gamma_sf_approx import log_gamma_sf_fast

In [32]:
def t(x, a, b):
    return logsumexp(log_gamma_sf_fast(x, a, b) + log_mix_probs, 0)

jax.grad(t, argnums=2)(x, g_a, g_b)

Array([-1.86200749e+002, -2.36180585e-009, -1.56801122e-301], dtype=float64)

In [33]:
t(x, g_a, g_b)

Array(-4.90829838, dtype=float64)

In [34]:
print(log_mix_probs)

[-4.18299335 -3.02426018 -0.06597586]


In [35]:
log_gamma_sf_fast(x, a, b) 

Array(-0.91747711, dtype=float64)

In [36]:
g_a

Array([1.00002936, 1.00059187, 1.67519907], dtype=float64)

In [37]:
g_b

Array([3.72481150e-03, 1.31571253e-01, 4.99725565e+00], dtype=float64)

In [38]:
x * g_b

Array([7.44962299e-01, 2.63142505e+01, 9.99451131e+02], dtype=float64)

In [39]:
log1mexp(logW(x*g_b, c))

Array([ -0.0493198 , -29.99391993, -29.99391993], dtype=float64)

In [40]:
W(x*g_b, c)

Array([0.04812333, 1.        , 1.        ], dtype=float64)

In [41]:
print(c[1] * (x*g_b-c[2]))

[ -1.49233421  23.69729168 982.38467141]


In [42]:
    x = c[1] * (x*g_b-c[2])
    _x = jnp.concatenate([jnp.expand_dims(x, axis=0), jnp.expand_dims(-x, axis=0)],
                         axis=0)

In [43]:
print(x)

[ -1.49233421  23.69729168 982.38467141]


In [44]:
print(_x)

[[  -1.49233421   23.69729168  982.38467141]
 [   1.49233421  -23.69729168 -982.38467141]]


In [45]:
x - logsumexp(_x)

Array([-983.87700562, -958.68737974,    0.        ], dtype=float64)

In [46]:
print(x)

[ -1.49233421  23.69729168 982.38467141]


In [47]:
jnp.tanh(956.80492373)

Array(1., dtype=float64, weak_type=True)

In [48]:
jnp.tanh(100)

Array(1., dtype=float64, weak_type=True)

In [49]:
1+jnp.exp(-2*50)

Array(1., dtype=float64, weak_type=True)

In [50]:
jnp.log1p

<PjitFunction of <function log1p at 0x73a9d1fee520>>

In [51]:
jax.nn.softplus(-900)

Array(0., dtype=float64, weak_type=True)

In [52]:
jnp.tanh(10)

Array(1., dtype=float64, weak_type=True)