In [1]:
import os
os.environ['JAX_PLATFORM'] = 'cpu'

import jax.numpy as jnp
import jax
#jax.config.update("jax_enable_x64", True)


from jax.scipy.stats.norm import pdf as norm_pdf
from jax.scipy.stats.norm import logpdf as norm_logpdf
from jax.scipy.special import logsumexp

import matplotlib.pyplot as plt

def log1m_exp(x):
    """
    Numerically stable calculation
    of the quantity log(1 - exp(x)),
    following the algorithm of
    Machler [1]. This is
    the algorithm used in TensorFlow Probability,
    PyMC, and Stan, but it is not provided
    yet with Numpyro.

    Currently returns NaN for x > 0,
    but may be modified in the future
    to throw a ValueError

    [1] https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
    """
    # return 0. rather than -0. if
    # we get a negative exponent that exceeds
    # the floating point representation
    crit = -0.6931472
    arr_x = 1.0 * jnp.array(x)

    crit_oob = jnp.log(jnp.finfo(
        arr_x.dtype).smallest_normal)+5
    
    oob = arr_x < crit_oob
    mask = arr_x > crit

    more_val = jnp.log(-jnp.expm1(jnp.clip(arr_x, min=crit)))
    less_val = jnp.log1p(-jnp.exp(jnp.clip(arr_x, max=crit)))

    return jnp.where(
        oob,
        -jnp.exp(crit_oob),
        jnp.where(
            mask,
            more_val,
            less_val))

In [4]:
def log_cdf(x, a, b):
    return a * log1m_exp(-b*x)

def pdf(x, a, b):
    return a*b*jnp.power(1.-jnp.exp(-b*x), a-1) * jnp.exp(-b*x)

def log_cdf_diff(x1, x2, a, b):
    # x2 > x1: F[x2] - F[x1]
    
    crit_oob = jnp.log(jnp.finfo(x1.dtype).smallest_normal)+5
    
    x = a * (log1m_exp(-b*x1) - log1m_exp(-b*x2))
    

    #x_prime = jnp.where(x > -jnp.exp(crit_oob), -0.5, x)
    #y = log1m_exp(x_prime)
    #z = jnp.log(-x)
    #y = jnp.where(x == -0.5, z, y)
    
    #x_prime = jnp.where(x > -jnp.exp(crit_oob), -jnp.exp(crit_oob), x)
    #return log_cdf(x2, a, b) + log1m_exp(x_prime)
    return  log_cdf(x2, a, b) + log1m_exp(x)

    #return  log_cdf(x2, a, b) + y


#def log_cdf_diff2(x1, x2, a, b):
#    x = a * jnp.log((1-jnp.exp(-b*x1)) / (1-jnp.exp(-b*x2)))
#    print(x)
#    print(log1m_exp(x))
#    return log_cdf(x2, a, b) + log1m_exp(x)

In [8]:
print(jnp.log(jnp.finfo(jnp.float32).smallest_normal))

-87.33655


In [None]:
a = 3.
b = 0.1

xvals = jnp.linspace(0.0, 100, 1000)
yvals = pdf(xvals, a, b)

plt.plot(xvals, yvals)

In [45]:
x1 = jnp.array(1000.)
x2 = jnp.array(1001.)

print(cdf(x1,a,b)-cdf(x2,a,b))
print(jnp.log(cdf(x2,a,b)-cdf(x1,a,b)))

NameError: name 'cdf' is not defined

In [51]:
log_cdf_diff(x1, x2, a, b)

Array(-101.25355617, dtype=float64)

In [60]:
log_cdf_diff(edges[:-1], edges[1:], a, b)

Array([ -1.37602544,  -0.93170427,  -1.55358075,  -2.42926362,
        -3.38526323,  -4.36929933,  -5.36345616,  -6.36131056,
        -7.36052178,  -8.36023167,  -9.36012496, -10.3600857 ,
       -11.36007126, -12.36006595, -13.36006399, -14.36006328,
       -15.36006301, -16.36006291, -17.36006288, -18.36006286,
       -19.36006286, -20.36006286, -21.36006286, -22.36006286,
       -23.36006286, -24.36006286, -25.36006286, -26.36006286,
       -27.36006286, -28.36006286, -29.36006286, -30.36006286,
       -31.36006286, -32.36006286, -33.36006286, -34.36006286,
       -35.36006286, -36.36006286, -37.36006286, -38.36006286,
       -39.36006286, -40.36006286, -41.36006286, -42.36006286,
       -43.36006286, -44.36006286, -45.36006286, -46.36006286,
       -47.36006286, -48.36006286, -49.36006286, -50.36006286,
       -51.36006286, -52.36006286, -53.36006286, -54.36006286,
       -55.36006286, -56.36006286, -57.36006286, -58.36006286,
       -59.36006286, -60.36006286, -61.36006286, -62.36

In [56]:
g = jax.grad(log_cdf_diff, argnums=3)
print(g(x1, x2, a, b))

-990.4916680552251


In [57]:
gv = jax.vmap(g, (0, 0, None, None), 0)

In [64]:
edges = jnp.linspace(1.e-15, 1000, 101)

In [65]:
print(gv(edges[:-1], edges[1:], a, b))

[  17.45930121    4.21722183   -9.57588662  -21.87999803  -33.09793787
  -43.69249773  -53.96733739  -64.08951116  -74.14228089  -84.16458482
  -94.17385557 -104.17765841 -114.17920178 -124.17982267 -134.18007063
 -144.18016904 -154.18020789 -164.18022315 -174.18022913 -184.18023145
 -194.18023236 -204.18023271 -214.18023285 -224.1802329  -234.18023292
 -244.18023293 -254.18023293 -264.18023293 -274.18023293 -284.18023293
 -294.18023293 -304.18023293 -314.18023293 -324.18023293 -334.18023293
 -344.18023293 -354.18023293 -364.18023293 -374.18023293 -384.18023293
 -394.18023293 -404.18023293 -414.18023293 -424.18023293 -434.18023293
 -444.18023293 -454.18023293 -464.18023293 -474.18023293 -484.18023293
 -494.18023293 -504.18023293 -514.18023293 -524.18023293 -534.18023293
 -544.18023293 -554.18023293 -564.18023293 -574.18023293 -584.18023293
 -594.18023293 -604.18023293 -614.18023293 -624.18023293 -634.18023293
 -644.18023293 -654.18023293 -664.18023293 -674.18023293 -684.18023293
 -694.

In [61]:
print(edges)

[   0.   10.   20.   30.   40.   50.   60.   70.   80.   90.  100.  110.
  120.  130.  140.  150.  160.  170.  180.  190.  200.  210.  220.  230.
  240.  250.  260.  270.  280.  290.  300.  310.  320.  330.  340.  350.
  360.  370.  380.  390.  400.  410.  420.  430.  440.  450.  460.  470.
  480.  490.  500.  510.  520.  530.  540.  550.  560.  570.  580.  590.
  600.  610.  620.  630.  640.  650.  660.  670.  680.  690.  700.  710.
  720.  730.  740.  750.  760.  770.  780.  790.  800.  810.  820.  830.
  840.  850.  860.  870.  880.  890.  900.  910.  920.  930.  940.  950.
  960.  970.  980.  990. 1000.]


In [None]:
crit_oob = jnp.log(jnp.finfo(jnp.float32).smallest_normal)+5
print(crit_oob)

In [None]:
log_cdf(x2, a, b)

In [None]:
log_cdf(x1, a, b)

In [None]:
x = a * (log1m_exp(-b*x1) - log1m_exp(-b*x2))

In [None]:
print(x)

In [None]:
log1m_exp(x)

In [None]:
log1m_exp(crit_oob)

In [None]:
jnp.log(1-jnp.exp(-100))

In [None]:
jnp.log(0.9999)

In [None]:
log1m_exp(-300)

In [None]:
crit_oob

In [None]:
jnp.log(jnp.exp(crit_oob))

In [None]:
jnp.exp(crit_oob)

In [None]:
crit_oob

In [None]:
-jnp.exp(crit_oob)

In [5]:
jnp.log(1-jnp.exp(-17))

Array(-4.1399378e-08, dtype=float64, weak_type=True)

In [6]:
-jnp.exp(-17)

Array(-4.13993772e-08, dtype=float64, weak_type=True)

In [7]:
-jnp.exp(crit_oob)

NameError: name 'crit_oob' is not defined

In [8]:
jnp.exp(-jnp.exp(crit_oob))

NameError: name 'crit_oob' is not defined

In [9]:
-jnp.exp(crit_oob)

NameError: name 'crit_oob' is not defined

In [10]:
from jax._src.typing import Array, ArrayLike
from jax._src.scipy.special import gammaln, xlogy
def logpmf(x: ArrayLike, n: ArrayLike, p: ArrayLike) -> Array:
  r"""Multinomial log probability mass function.

  JAX implementation of :obj:`scipy.stats.multinomial` ``logpdf``.

  The multinomial probability distribution is given by

  .. math::

     f(x, n, p) = n! \prod_{i=1}^k \frac{p_i^{x_i}}{x_i!}

  with :math:`n = \sum_i x_i`.

  Args:
    x: arraylike, value at which to evaluate the PMF
    n: arraylike, distribution shape parameter
    p: arraylike, distribution shape parameter

  Returns:
    array of logpmf values.

  See Also:
    :func:`jax.scipy.stats.multinomial.pmf`
  """
  logprobs = gammaln(n + 1) + jnp.sum(xlogy(x, p) - gammaln(x + 1), axis=-1)
  return logprobs

In [11]:
def logpmf2(x: ArrayLike, n: ArrayLike, logp: ArrayLike) -> Array:
  r"""Multinomial log probability mass function.

  JAX implementation of :obj:`scipy.stats.multinomial` ``logpdf``.

  The multinomial probability distribution is given by

  .. math::

     f(x, n, p) = n! \prod_{i=1}^k \frac{p_i^{x_i}}{x_i!}

  with :math:`n = \sum_i x_i`.

  Args:
    x: arraylike, value at which to evaluate the PMF
    n: arraylike, distribution shape parameter
    p: arraylike, distribution shape parameter

  Returns:
    array of logpmf values.

  See Also:
    :func:`jax.scipy.stats.multinomial.pmf`
  """
  logprobs = gammaln(n + 1) + jnp.sum(x*logp - gammaln(x + 1), axis=-1)
  return logprobs

In [12]:
edges = jnp.linspace(0, 1000, 101)

In [13]:
log_probs = log_cdf_diff(edges[:-1], edges[1:], a, b)

In [14]:
probs = jnp.exp(log_probs)

In [15]:
print(probs)

[2.52580458e-01 3.93881857e-01 2.11489327e-01 8.81016854e-02
 3.38687259e-02 1.26601080e-02 4.68468705e-03 1.72710175e-03
 6.35866592e-04 2.33990118e-04 8.60893404e-05 3.16717417e-05
 1.16515509e-05 4.28638881e-06 1.57687740e-06 5.80101195e-07
 2.13407360e-07 7.85081879e-08 2.88815493e-08 1.06249284e-08
 3.90869273e-09 1.43792770e-09 5.28984039e-10 1.94602353e-10
 7.15902047e-11 2.63365645e-11 9.68868063e-12 3.56426642e-12
 1.31122034e-12 4.82371005e-13 1.77454376e-13 6.52818166e-14
 2.40158382e-14 8.83493314e-15 3.25019027e-15 1.19567818e-15
 4.39865420e-16 1.61817445e-16 5.95293112e-17 2.18996098e-17
 8.05641620e-18 2.96378989e-18 1.09031737e-18 4.01105344e-19
 1.47558410e-19 5.42837053e-20 1.99698592e-20 7.34650064e-21
 2.70262655e-21 9.94240745e-22 3.65760730e-22 1.34555853e-22
 4.95003319e-23 1.82101544e-23 6.69914144e-24 2.46447641e-24
 9.06630204e-25 3.33530613e-25 1.22699056e-25 4.51384600e-26
 1.66055114e-26 6.10882627e-27 2.24731159e-27 8.26739733e-28
 3.04140551e-28 1.118870

In [16]:
log_probs

Array([ -1.37602544,  -0.93170427,  -1.55358075,  -2.42926362,
        -3.38526323,  -4.36929933,  -5.36345616,  -6.36131056,
        -7.36052178,  -8.36023167,  -9.36012496, -10.3600857 ,
       -11.36007126, -12.36006595, -13.36006399, -14.36006328,
       -15.36006301, -16.36006291, -17.36006288, -18.36006286,
       -19.36006286, -20.36006286, -21.36006286, -22.36006286,
       -23.36006286, -24.36006286, -25.36006286, -26.36006286,
       -27.36006286, -28.36006286, -29.36006286, -30.36006286,
       -31.36006286, -32.36006286, -33.36006286, -34.36006286,
       -35.36006286, -36.36006286, -37.36006286, -38.36006286,
       -39.36006286, -40.36006286, -41.36006286, -42.36006286,
       -43.36006286, -44.36006286, -45.36006286, -46.36006286,
       -47.36006286, -48.36006286, -49.36006286, -50.36006286,
       -51.36006286, -52.36006286, -53.36006286, -54.36006286,
       -55.36006286, -56.36006286, -57.36006286, -58.36006286,
       -59.36006286, -60.36006286, -61.36006286, -62.36

In [17]:
import numpy as np
vals = np.random.randint(0, 32, len(edges)-1)
print(logpmf(jnp.array(vals), jnp.sum(vals), probs))

-59532.712870354604


In [18]:
print(logpmf2(jnp.array(vals), jnp.sum(vals), log_probs))

-59532.712870354604


In [19]:
delta = cdf(edges[1:], a, b) - cdf(edges[:-1], a, b)
print(delta)

NameError: name 'cdf' is not defined

In [None]:
g = jax.grad(logpmf2, argnums=2)

In [None]:
vals = jnp.array(vals, dtype=jnp.float32)

In [20]:
g(vals, jnp.sum(vals), log_probs)

NameError: name 'g' is not defined

In [21]:
vals

array([ 9,  5,  7, 31,  0, 23,  8,  0, 30,  9,  9,  3, 28,  7, 19, 25,  4,
       15, 30, 19, 24, 20, 25, 10,  9,  4, 29, 19, 22, 21,  8, 22, 12, 21,
       11, 17, 10, 21, 25, 23,  9,  2, 15, 20, 10,  8,  6, 17,  5, 26, 26,
       25, 26, 18, 29,  2, 19,  3,  0, 11, 30, 30, 13, 25,  0, 16, 24, 16,
        1,  9,  5, 10, 31,  3,  4,  5, 25, 23,  3,  4, 17, 13, 10,  5, 28,
        5, 10, 26, 12, 26,  6,  6,  3,  4,  0, 14, 10,  2, 15, 18])

In [22]:
def logpmf3(x: ArrayLike, n: ArrayLike, a, b) -> Array:
  r"""Multinomial log probability mass function.

  JAX implementation of :obj:`scipy.stats.multinomial` ``logpdf``.

  The multinomial probability distribution is given by

  .. math::

     f(x, n, p) = n! \prod_{i=1}^k \frac{p_i^{x_i}}{x_i!}

  with :math:`n = \sum_i x_i`.

  Args:
    x: arraylike, value at which to evaluate the PMF
    n: arraylike, distribution shape parameter
    p: arraylike, distribution shape parameter

  Returns:
    array of logpmf values.

  See Also:
    :func:`jax.scipy.stats.multinomial.pmf`
  """
  logp = log_cdf_diff(edges[:-1], edges[1:], a, b)
  logprobs = gammaln(n + 1) + jnp.sum(x*logp - gammaln(x + 1), axis=-1)
  return logprobs

In [23]:
print(logpmf3(vals, jnp.sum(vals), a, b))

-59532.712870354604


In [24]:
vals = jnp.array(vals, dtype=jnp.float32)
print(vals)

[ 9.  5.  7. 31.  0. 23.  8.  0. 30.  9.  9.  3. 28.  7. 19. 25.  4. 15.
 30. 19. 24. 20. 25. 10.  9.  4. 29. 19. 22. 21.  8. 22. 12. 21. 11. 17.
 10. 21. 25. 23.  9.  2. 15. 20. 10.  8.  6. 17.  5. 26. 26. 25. 26. 18.
 29.  2. 19.  3.  0. 11. 30. 30. 13. 25.  0. 16. 24. 16.  1.  9.  5. 10.
 31.  3.  4.  5. 25. 23.  3.  4. 17. 13. 10.  5. 28.  5. 10. 26. 12. 26.
  6.  6.  3.  4.  0. 14. 10.  2. 15. 18.]


In [25]:
print(logpmf3(vals, jnp.sum(vals), a, b))

-59532.71282160684


In [29]:
g = jax.grad(logpmf3, argnums=2)

In [30]:
g(vals, jnp.sum(vals), a, b)

Array(nan, dtype=float64, weak_type=True)

In [28]:
a

3.0