diff --git a/notebooks/fast_gen_pois.ipynb b/notebooks/fast_gen_pois.ipynb new file mode 100644 index 0000000..7a3b6d8 --- /dev/null +++ b/notebooks/fast_gen_pois.ipynb @@ -0,0 +1,1857 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ricardo/miniconda3/envs/aeppl/compiler_compat/ld: warning: libm.so.6, needed by /usr/lib/x86_64-linux-gnu/libblas.so, not found (try using -rpath or -rpath-link)\n", + "/home/ricardo/miniconda3/envs/aeppl/compiler_compat/ld: /usr/lib/x86_64-linux-gnu/libblas.so: undefined reference to `cabs@GLIBC_2.2.5'\n", + "/home/ricardo/miniconda3/envs/aeppl/compiler_compat/ld: /usr/lib/x86_64-linux-gnu/libblas.so: undefined reference to `cabsf@GLIBC_2.2.5'\n", + "collect2: error: ld returned 1 exit status\n", + "/home/ricardo/miniconda3/envs/aeppl/compiler_compat/ld: cannot find -lcblas\n", + "collect2: error: ld returned 1 exit status\n", + "/home/ricardo/miniconda3/envs/aeppl/compiler_compat/ld: warning: libm.so.6, needed by /usr/lib/x86_64-linux-gnu/libblas.so, not found (try using -rpath or -rpath-link)\n", + "/home/ricardo/miniconda3/envs/aeppl/compiler_compat/ld: /usr/lib/x86_64-linux-gnu/libblas.so: undefined reference to `cabs@GLIBC_2.2.5'\n", + "/home/ricardo/miniconda3/envs/aeppl/compiler_compat/ld: /usr/lib/x86_64-linux-gnu/libblas.so: undefined reference to `cabsf@GLIBC_2.2.5'\n", + "collect2: error: ld returned 1 exit status\n", + "/home/ricardo/miniconda3/envs/aeppl/compiler_compat/ld: cannot find -lcblas\n", + "collect2: error: ld returned 1 exit status\n", + "/home/ricardo/miniconda3/envs/aeppl/compiler_compat/ld: warning: libm.so.6, needed by /usr/lib/x86_64-linux-gnu/libblas.so, not found (try using -rpath or -rpath-link)\n", + "/home/ricardo/miniconda3/envs/aeppl/compiler_compat/ld: /usr/lib/x86_64-linux-gnu/libblas.so: undefined reference to 
`cabs@GLIBC_2.2.5'\n", + "/home/ricardo/miniconda3/envs/aeppl/compiler_compat/ld: /usr/lib/x86_64-linux-gnu/libblas.so: undefined reference to `cabsf@GLIBC_2.2.5'\n", + "collect2: error: ld returned 1 exit status\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import pymc as pm\n", + "from matplotlib import pyplot as plt\n", + "from scipy.special import gammaln\n", + "\n", + "\n", + "def _logpow(x, m):\n", + " \"\"\"\n", + " Calculates log(x**m) since m*log(x) will fail when m, x = 0.\n", + " \"\"\"\n", + " # return m * log(x)\n", + " return np.where(x == 0, np.where(m == 0, 0.0, -np.inf), m * np.log(x))\n", + "\n", + "\n", + "def _logprob(x, p, lam):\n", + " p_lam_x = p + lam * x\n", + " return np.where(\n", + " x >= 0,\n", + " np.log(p) + _logpow(p_lam_x, x - 1) - p_lam_x - gammaln(x + 1),\n", + " -np.inf,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "scrolled": true + }, + "outputs": [], + "source": [ + "def _rejection_region_monotonicity(rng, p, lam, dist_size, idxs_mask=None):\n", + " if idxs_mask is None:\n", + " idxs_mask = np.ones(dist_size, dtype=\"bool\")\n", + " p = np.broadcast_to(p, dist_size)[idxs_mask]\n", + " lam = np.broadcast_to(lam, dist_size)[idxs_mask]\n", + " dist_size = np.sum(idxs_mask)\n", + " p0 = np.exp(-p)\n", + " b = p * np.exp(2 - lam - np.minimum(lam, p)) * np.sqrt(2 / np.pi)\n", + " x = np.full(dist_size, np.nan)\n", + " inds_to_sample = np.ones(dist_size, dtype=bool) # dummy boolean mask#u > p0 / (p0 + b)\n", + " counter = -1\n", + " while np.any(inds_to_sample):\n", + " counter += 1\n", + " u = rng.uniform(size=dist_size)\n", + " zero_xs = u <= p0 / (p0 + b)\n", + " x[inds_to_sample & zero_xs] = 0\n", + " inds_to_sample = inds_to_sample & ~zero_xs\n", + "\n", + " v = rng.uniform(size=dist_size)\n", + " w = rng.uniform(size=dist_size)\n", + " _x = np.floor(1 / w ** 2)\n", + " accepted = v * b * (1 / np.sqrt(_x) - 1 / np.sqrt(_x + 1)) <= 
np.exp(_logprob(_x, p, lam))\n", + " x[inds_to_sample & accepted] = _x[inds_to_sample & accepted]\n", + " inds_to_sample = inds_to_sample & ~accepted\n", + " # print(counter)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "scrolled": true + }, + "outputs": [], + "source": [ + "def _rejection_region_poisson(rng, p, lam, dist_size, idxs_mask=None):\n", + " if idxs_mask is None:\n", + " idxs_mask = np.ones(dist_size, dtype=\"bool\")\n", + " p = np.broadcast_to(p, dist_size)[idxs_mask]\n", + " lam = np.broadcast_to(lam, dist_size)[idxs_mask]\n", + " dist_size = np.sum(idxs_mask)\n", + "\n", + " eps = (1 - lam) / (2 + (p - lam) * (1 - lam)) ** (1 / 3)\n", + " delta = (1 - lam) ** (2 / 5) / (2 + (p - lam) * (1 - lam)) ** (1 / 3)\n", + " mu = (p - lam) / (1 - lam)\n", + " sigma = np.sqrt((1 + delta) * (p - lam) / (1 - lam - eps) / (1 - lam) ** 2)\n", + " psi = (\n", + " p * delta * (2 + delta - 2 * lam)\n", + " + (1 + delta) * (1 - lam) ** 2\n", + " - lam * (1 - lam + delta) ** 2\n", + " ) / (2 * (p - 1 - delta))\n", + " G = (\n", + " (p * (1 - lam - eps) * np.sqrt(1 + delta))\n", + " / ((p - lam) * (1 - lam) * (1 - eps) ** 2)\n", + " * np.exp(psi / (1 + delta))\n", + " )\n", + "\n", + " def g(x, G, mu, sigma):\n", + " return G / (np.sqrt(2 * np.pi) * sigma) * np.exp(-((x - mu) ** 2) / (2 * sigma ** 2))\n", + "\n", + " def h_r(x, p, lam, eps, mu):\n", + " return (\n", + " (p * (1 - lam - eps) ** 1.5)\n", + " / (np.sqrt(2 * np.pi) * (p - lam) ** 1.5)\n", + " * np.exp(\n", + " -(1 - 2 * (1 - lam - eps) / (p - lam)) * (eps / 2) * (1 - lam) * (x - mu)\n", + " + 2 * (1 - lam)\n", + " )\n", + " )\n", + "\n", + " t_r = np.ceil((p - lam) / (1 - lam - eps) - 1)\n", + " H_r = (\n", + " (2 * p * (1 - lam - eps) ** 1.5 * np.exp(2 * (1 - lam)))\n", + " / (\n", + " np.sqrt(2 * np.pi)\n", + " * (p - lam) ** 1.5\n", + " * (1 - 2 * (1 - lam - eps) / (p - lam))\n", + " * eps\n", + " * (1 - lam)\n", + " 
)\n", + " * np.exp(-(1 - 2 * (1 - lam - eps) / (p - lam)) * (eps / 2) * (1 - lam) * (t_r - mu))\n", + " )\n", + "\n", + " def h_l(x, p, lam, delta, mu):\n", + " return p / np.sqrt(2 * np.pi) * np.exp(delta * (1 - lam) / (2 * (1 + delta)) * (x + 1 - mu))\n", + "\n", + " t_l = np.ceil((p - lam) / (1 - lam + delta) - 1)\n", + " H_l = (\n", + " (2 * p * (1 + delta))\n", + " / (np.sqrt(2 * np.pi) * delta * (1 - lam))\n", + " * np.exp(delta * (1 - lam) / (2 * (1 + delta)) * (t_l + 1 - mu))\n", + " )\n", + "\n", + " x = np.zeros(dist_size)\n", + " inds_to_sample = np.arange(dist_size)\n", + " n_to_accept = np.zeros(dist_size)\n", + " counter = -1\n", + " while inds_to_sample.size > 0:\n", + " counter += 1\n", + " _dist_size = len(inds_to_sample)\n", + " U = rng.uniform(size=_dist_size)\n", + " N = rng.normal(size=_dist_size)\n", + " V = rng.uniform(size=_dist_size)\n", + " E = rng.exponential(size=_dist_size)\n", + " _G = G[inds_to_sample]\n", + " _H_l = H_l[inds_to_sample]\n", + " _H_r = H_r[inds_to_sample]\n", + " _p = p[inds_to_sample]\n", + " _lam = lam[inds_to_sample]\n", + " _mu = mu[inds_to_sample]\n", + " _sigma = sigma[inds_to_sample]\n", + " _delta = delta[inds_to_sample]\n", + " _eps = eps[inds_to_sample]\n", + " _t_l = t_l[inds_to_sample]\n", + " _t_r = t_r[inds_to_sample]\n", + "\n", + " center = U < _G / (_G + _H_l + _H_r)\n", + " left = (U < (_G + _H_l) / (_G + _H_l + _H_r)) & ~center\n", + " raw_center_y = _mu + _sigma * N\n", + " raw_left_y = _t_l - 2 * E * (1 + _delta) / _delta / (1 - _lam)\n", + " raw_right_y = _t_r + 2 * E / ((1 - 2 * (1 - _lam - _eps) / (_p - _lam)) * _eps * (1 - _lam))\n", + " Y = np.where(\n", + " center,\n", + " np.where(\n", + " (raw_center_y >= _t_l) & (raw_center_y < _t_r),\n", + " raw_center_y,\n", + " np.nan,\n", + " ),\n", + " np.where(\n", + " left,\n", + " np.where(\n", + " raw_left_y >= 0,\n", + " raw_left_y,\n", + " np.nan,\n", + " ),\n", + " np.where(\n", + " raw_right_y >= 0,\n", + " raw_right_y,\n", + " np.nan,\n", + " 
),\n", + " ),\n", + " )\n", + " X = np.floor(Y)\n", + " accepted = (\n", + " V\n", + " * np.where(\n", + " center,\n", + " g(Y, G=_G, mu=_mu, sigma=_sigma),\n", + " np.where(\n", + " left,\n", + " h_l(Y, p=_p, lam=_lam, delta=_delta, mu=_mu),\n", + " h_r(Y, p=_p, lam=_lam, eps=_eps, mu=_mu),\n", + " ),\n", + " )\n", + " <= np.exp(_logprob(X, _p, _lam))\n", + " )\n", + "\n", + " x[inds_to_sample[accepted]] = X[accepted]\n", + " n_to_accept[inds_to_sample[accepted]] = counter\n", + " inds_to_sample = inds_to_sample[~accepted]\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "scrolled": true + }, + "outputs": [], + "source": [ + "def _rejection_region_abel(rng, p, lam, dist_size, idxs_mask=None):\n", + " if idxs_mask is None:\n", + " idxs_mask = np.ones(dist_size, dtype=\"bool\")\n", + " p = np.broadcast_to(p, dist_size)[idxs_mask]\n", + " lam = np.broadcast_to(lam, dist_size)[idxs_mask]\n", + " dist_size = np.sum(idxs_mask)\n", + "\n", + " nu = 2 * (p ** 2 - lam * p - 3 * lam ** 2) / (3 * lam ** 2)\n", + " alpha = 0.2746244084 # Taken from page 259\n", + " t = np.floor(alpha * np.maximum(nu, 0))\n", + " problematic = (p < 1 + lam) | ((p * (1 - lam)) > (2 * lam))\n", + " t[problematic] = 0\n", + " # b = p * np.exp(np.maximum(1 - p, 0)) * np.sqrt(2 / np.pi)\n", + " b = p * np.exp(2 - lam - np.minimum(lam, p)) * np.sqrt(2 / np.pi)\n", + " q_r = b / np.sqrt(t + 1)\n", + "\n", + " rho_t = ( # Taken from page 250\n", + " 1\n", + " - p\n", + " + np.log(p)\n", + " - 0.5 * np.log(2 * np.pi)\n", + " + (t - 1) * (np.log(lam * t + p) - np.log(t + 1))\n", + " - 1.5 * np.log(t + 1)\n", + " + (1 - lam) * t\n", + " )\n", + "# rho_t_prime = ( # Taken form page 271\n", + "# np.log(lam * t + p) \n", + "# - np.log(t + 1)\n", + "# + 1\n", + "# - lam\n", + "# + 0.5 / (t + 1)\n", + "# - (lam + p) / (lam * t + p)\n", + "# )\n", + " rho_t_prime = (\n", + " np.log(lam * t + p)\n", + " - np.log(t + 1)\n", + " 
+ 1\n", + " - lam\n", + " - (t + 0.5) / (t + 1) ** 2\n", + " - (t - 1) * lam / (lam * t + p)\n", + " )\n", + " q = np.exp(-rho_t_prime)\n", + " q_l = np.where(t == 0, np.exp(-p), np.exp(rho_t) / (1 - q))\n", + "\n", + " x = np.zeros(dist_size)\n", + " n_to_accept = np.zeros(dist_size)\n", + " inds_to_sample = np.arange(dist_size)\n", + " counter = -1\n", + " while inds_to_sample.size > 0:\n", + " counter += 1\n", + " _dist_size = len(inds_to_sample)\n", + " U = rng.uniform(size=_dist_size)\n", + " V = rng.uniform(size=_dist_size)\n", + " W = rng.uniform(size=_dist_size)\n", + " # E = rng.uniform(size=_dist_size)\n", + " E = rng.exponential(size=_dist_size)\n", + " _p = p[inds_to_sample]\n", + " _lam = lam[inds_to_sample]\n", + " _t = t[inds_to_sample]\n", + " _q = q[inds_to_sample]\n", + " _q_l = q_l[inds_to_sample]\n", + " _q_r = q_r[inds_to_sample]\n", + " _b = b[inds_to_sample]\n", + "# raw_left = np.where(_t == 0, 0, _t - np.floor(-E / np.log(1 - _q)))\n", + "# raw_left = np.where(_t == 0, 0, _t - np.floor(-E / np.log(_q)))\n", + " # raw_left = np.where(_t == 0, 0, _t + np.ceil(np.log(1 - E) / _q))\n", + " raw_left = np.where(_t == 0, 0, _t - np.floor(E / _q))\n", + " raw_right = np.floor((_t + 1) / W ** 2)\n", + " \n", + " left = U <= _q_l / (_q_l + _q_r)\n", + " accepted = np.where(\n", + " left,\n", + " np.where(\n", + " _t == 0,\n", + " True,\n", + " np.where(\n", + " raw_left < 0,\n", + " False,\n", + "# V * _q_l * _q ** (_t - raw_left) * (1 - _q)\n", + " V * _q_l * _q ** (_t - raw_left) * (1 - _q ** (_t + 1))\n", + " <= np.exp(_logprob(raw_left, _p, _lam)),\n", + " ),\n", + " ),\n", + " V * _b * (1 / np.sqrt(raw_right) - 1 / np.sqrt(raw_right + 1))\n", + " <= np.exp(_logprob(raw_right, _p, _lam)),\n", + " )\n", + " X = np.where(left, raw_left, raw_right)\n", + "\n", + " x[inds_to_sample[accepted]] = X[accepted]\n", + " n_to_accept[inds_to_sample[accepted]] = counter\n", + " inds_to_sample = inds_to_sample[~accepted]\n", + " return x" + ] + }, + { + 
"cell_type": "code", + "execution_count": 5, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "scrolled": true + }, + "outputs": [], + "source": [ + "def _inverse_rng_fn(rng, theta, lam, dist_size):\n", + " log_u = np.log(rng.uniform(size=dist_size))\n", + " pos_lam = lam > 0\n", + " abs_log_lam = np.log(np.abs(lam))\n", + " theta_m_lam = theta - lam\n", + " log_s = -theta\n", + " log_p = log_s.copy()\n", + " x_ = 0\n", + " x = np.zeros(dist_size)\n", + " below_cutpoint = log_s < log_u\n", + " with np.errstate(divide=\"ignore\", invalid=\"ignore\"):\n", + " counter = 0\n", + " while np.any(below_cutpoint):\n", + " counter += 1\n", + " x_ += 1\n", + " x[below_cutpoint] += 1\n", + " log_c = np.log(theta_m_lam + lam * x_)\n", + " # Compute log(1 + lam / C)\n", + " log1p_lam_m_C = np.where(\n", + " pos_lam,\n", + " np.log1p(np.exp(abs_log_lam - log_c)),\n", + " pm.math.log1mexp_numpy(abs_log_lam - log_c, negative_input=True),\n", + " )\n", + " log_p = log_c + log1p_lam_m_C * (x_ - 1) + log_p - np.log(x_) - lam\n", + " log_s = np.logaddexp(log_s, log_p)\n", + " below_cutpoint = log_s < log_u\n", + "# print(counter)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "def _branching_rng_fn(rng, theta, lam, dist_size, idxs_mask=None):\n", + " if idxs_mask is None:\n", + " idxs_mask = np.ones(dist_size, dtype=bool)\n", + " lam_ = np.abs(lam) # This algorithm is only valid for positive lam\n", + " y = rng.poisson(theta, size=dist_size)\n", + " x = y.copy()\n", + " higher_than_zero = y > 0\n", + " while np.any(higher_than_zero[idxs_mask]):\n", + " y = rng.poisson(lam_ * y)\n", + " x[higher_than_zero] = x[higher_than_zero] + y[higher_than_zero]\n", + " higher_than_zero = y > 0\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + 
"output_type": "stream", + "text": [ + "/tmp/ipykernel_21224/1238690607.py:5: RuntimeWarning: divide by zero encountered in true_divide\n", + " poisson_idxs = np.broadcast_to(p >= np.maximum(3, 2 * lam / (1 - lam)), dist_size)\n", + "/tmp/ipykernel_21224/1238690607.py:7: RuntimeWarning: divide by zero encountered in true_divide\n", + " (lam == 1) | ((p >= 1 + lam) & (p <= 2 * lam / (1 - lam))),\n" + ] + } + ], + "source": [ + "rng = np.random.default_rng(42)\n", + "p, lam = np.meshgrid(np.logspace(-2, 4, 50), np.linspace(0, 1, 50))\n", + "dist_size = (100, *p.shape)\n", + "monotonicity_idxs = np.broadcast_to(p <= np.exp(lam), dist_size)\n", + "poisson_idxs = np.broadcast_to(p >= np.maximum(3, 2 * lam / (1 - lam)), dist_size)\n", + "abel_idxs = np.broadcast_to(\n", + " (lam == 1) | ((p >= 1 + lam) & (p <= 2 * lam / (1 - lam))),\n", + " dist_size,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 3.56 s, sys: 108 ms, total: 3.67 s\n", + "Wall time: 3.69 s\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_21224/2250734458.py:80: RuntimeWarning: overflow encountered in power\n", + " V * _q_l * _q ** (_t - raw_left) * (1 - _q ** (_t + 1))\n", + "/tmp/ipykernel_21224/2555885245.py:12: RuntimeWarning: invalid value encountered in log\n", + " return np.where(x == 0, np.where(m == 0, 0.0, -np.inf), m * np.log(x))\n" + ] + } + ], + "source": [ + "%%time\n", + "samples = np.full(dist_size, np.nan)\n", + "samples[monotonicity_idxs] = _rejection_region_monotonicity(\n", + " rng=rng, p=p, lam=lam, dist_size=dist_size, idxs_mask=monotonicity_idxs\n", + ")\n", + "samples[poisson_idxs] = _rejection_region_poisson(\n", + " rng=rng,\n", + " p=p,\n", + " lam=lam,\n", + " dist_size=dist_size,\n", + " idxs_mask=poisson_idxs,\n", + ")\n", + 
"samples[abel_idxs] = _rejection_region_abel(\n", + " rng=rng,\n", + " p=p,\n", + " lam=lam,\n", + " dist_size=dist_size,\n", + " idxs_mask=abel_idxs,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "scrolled": true + }, + "outputs": [], + "source": [ + "c = np.zeros_like(p.flatten())\n", + "c[monotonicity_idxs[0].flatten()] = 0\n", + "c[poisson_idxs[0].flatten()] = 1\n", + "c[abel_idxs[0].flatten()] = 2" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_21224/1687800886.py:2: RuntimeWarning: divide by zero encountered in true_divide\n", + " (p / (1 - lam)).flatten(),\n" + ] + }, + { + "data": { + "text/plain": [ + "Text(0.5, 0, 'Sample mean')" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.scatter(\n", + " (p / (1 - lam)).flatten(),\n", + " np.mean(samples, axis=0).flatten(),\n", + " c=c,\n", + " alpha=0.2,\n", + " cmap=\"jet\",\n", + ")\n", + "ax = plt.gca()\n", + "ax.set_xscale(\"log\")\n", + "ax.set_yscale(\"log\")\n", + "plt.plot(ax.get_xlim(), ax.get_ylim(), \"-k\", alpha=0.2)\n", + "plt.xlabel(\"Expected mean\")\n", + "plt.ylabel(\"Sample mean\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_21224/816571136.py:2: RuntimeWarning: divide by zero encountered in true_divide\n", + " (p / (1 - lam) ** 3).flatten(),\n" + ] + }, + { + "data": { + "text/plain": [ + "Text(0.5, 0, 'Sample variance')" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.scatter(\n", + " (p / (1 - lam) ** 3).flatten(),\n", + " np.var(samples, axis=0).flatten(),\n", + " c=c,\n", + " alpha=0.2,\n", + " cmap=\"jet\",\n", + ")\n", + "ax = plt.gca()\n", + "ax.set_xscale(\"log\")\n", + "ax.set_yscale(\"log\")\n", + "\n", + "ax = plt.gca()\n", + "plt.plot(ax.get_xlim(), ax.get_ylim(), \"-k\", alpha=0.2)\n", + "plt.xlabel(\"Expected variance\")\n", + "plt.ylabel(\"Sample variance\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "scrolled": true + }, + "outputs": [], + "source": [ + "def normal_approx(x, p, lam):\n", + " mu = p / (1 - lam)\n", + " sigma = np.sqrt(p / (1 - lam) ** 3)\n", + " return 1 / np.sqrt(2 * np.pi) / sigma * np.exp(-0.5 * (x - mu) ** 2 / sigma ** 2)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "scrolled": true + }, + "outputs": [], + "source": [ + "def monotonicity_region_envelope(x, p, lam):\n", + " return p * np.exp(\n", + " 2\n", + " - lam\n", + " - np.minimum(lam, p) * np.sqrt(2 / np.pi) * ((1 / np.sqrt(x)) - (1 / np.sqrt(x + 1)))\n", + " ) + (x == 0) * np.exp(\n", + " -p\n", + " ) # Extra probability for x==0" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "scrolled": true + }, + "outputs": [], + "source": [ + "def poisson_region_envelope(x, p, lam):\n", + " eps = (1 - lam) / (2 + (p - lam) * (1 - lam)) ** (1 / 3)\n", + " delta = (1 - lam) ** (2 / 5) / (2 + (p - lam) * (1 - lam)) ** (1 / 3)\n", + " mu = (p - lam) / (1 - lam)\n", + " sigma = np.sqrt((1 + delta) * (p - lam) / (1 - lam - eps) / (1 - lam) ** 2)\n", + " psi = (\n", + " p * delta * (2 + delta - 2 * lam)\n", + " + (1 + delta) * (1 - lam) ** 2\n", + " - lam * (1 - lam + delta) ** 2\n", + " ) / (2 * (p - 1 - delta))\n", + " 
G = (\n", + " (p * (1 - lam - eps) * np.sqrt(1 + delta))\n", + " / ((p - lam) * (1 - lam) * (1 - eps) ** 2)\n", + " * np.exp(psi / (1 + delta))\n", + " )\n", + "\n", + " def g(x):\n", + " return G / (np.sqrt(2 * np.pi) * sigma) * np.exp(-((x - mu) ** 2) / (2 * sigma ** 2))\n", + "\n", + " h_r_norm = (p * (1 - lam - eps) ** 1.5) / (np.sqrt(2 * np.pi) * (p - lam) ** 1.5)\n", + " h_r_exp_A = -(1 - 2 * (1 - lam - eps) / (p - lam)) * (eps / 2) * (1 - lam)\n", + " h_r_exp_B = 2 * (1 - lam)\n", + "\n", + " def h_r(x):\n", + " return h_r_norm * np.exp(h_r_exp_A * (x - mu) + h_r_exp_B)\n", + "\n", + " t_r = np.ceil((p - lam) / (1 - lam - eps) - 1)\n", + " H_r = (\n", + " (2 * p * (1 - lam - eps) ** 1.5 * np.exp(2 * (1 - lam)))\n", + " / (\n", + " np.sqrt(2 * np.pi)\n", + " * (p - lam) ** 1.5\n", + " * (1 - 2 * (1 - lam - eps) / (p - lam))\n", + " * eps\n", + " * (1 - lam)\n", + " )\n", + " * np.exp(-(1 - 2 * (1 - lam - eps) / (p - lam)) * (eps / 2) * (1 - lam) * (t_r - mu))\n", + " )\n", + "\n", + " def h_l(x):\n", + " return p / np.sqrt(2 * np.pi) * np.exp(delta * (1 - lam) / (2 * (1 + delta)) * (x + 1 - mu))\n", + "\n", + " t_l = np.ceil((p - lam) / (1 - lam + delta) - 1)\n", + " H_l = (\n", + " (2 * p * (1 + delta))\n", + " / (np.sqrt(2 * np.pi) * delta * (1 - lam))\n", + " * np.exp(delta * (1 - lam) / (2 * (1 + delta)) * (t_l + 1 - mu))\n", + " )\n", + " return np.where(x < t_l, h_l(x), np.where(x < t_r, g(x), h_r(x)))" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYEAAAD8CAYAAACRkhiPAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAABG/0lEQVR4nO3deVxU1fvA8c9hX1RQQEEQFRfccF9Kw9S0zCy/aWWWlmlamtpmtthiZr8s00qzzHJNs1XNSnPfV3AFwQWVFAUREJWdYc7vj0FEBUVlGGCe9+vFi7nnnnvv4zDeZ+45956jtNYIIYSwTjaWDkAIIYTlSBIQQggrJklACCGsmCQBIYSwYpIEhBDCikkSEEIIKyZJQAghrJgkASGEsGIWTQJKqQCl1Cyl1O+WjEMIIayVKu4nhpVSs4GeQLzWukm+8u7AV4At8IPWemK+db9rrR+72b49PT11rVq1ijVeIYQo73bv3p2gtfYqaJ2dGY43F/gamH+5QCllC0wHugExQIhSapnWOuJWdlyrVi1CQ0OLMVQhhCj/lFL/Fbau2JuDtNabgKRritsCUVrr41rrLOBnoFdR9qeUGqqUClVKhZ47d66YoxVCCOtWUn0CvsCpfMsxgK9SykMpNQNooZR6u6ANtdYztdattdatvbwKvJoRQghxm8zRHFRkWutE4EVLxiCEENaspK4ETgM18i375ZYJIYSwoJJKAiFAPaVUbaWUA/AksKyEji2EEKIQxZ4ElFKLgO1AoFIqRik1WGttAEYAK4FI4Fet9cHiPrYQQohbU+x9AlrrfoWULweWF/fxhBBC3D4ZNkKUSlprwk9fYN62aNKzciwdjhDllkXvDhIiP601B2IusDwsluXhsZxKSgegursz3RpVs3B0QpRPkgSERWmt2Xcq2XTiD4vjdHI6djaKDnU96dXMl6/XR2HIMVo6TCHKLUkCosQZjZq9p86zPCyOFWGxnLmQgb2tIrieF690rcf9jbxxc7HnyNlLfL0+CmPxDm8lhMhHkoAoEUajZvfJ8/xzIJZ/w+OIu5iBg60NHet7MvqBQO5rWA03Z/urtlG5vzWSBYQwF0kCwmxyjJqQ6CRWhMWyIjyO+EuZONjZ0Km+F28FNeC+hlWp6GRf6PZKmdKAXAkIYT6SBESxMuQY2RWdxPKwWP4NP0tCSiaOdjZ0DqxKj6Y+dGlQlQqORfvY2eReChT3cOdCiCskCYg7ZsgxsuN4EsvDY1kZHkdiahbO9rZ0aVCVB4O86RxYFdcinvjzs8m7EpAkIIS5SBIQtyU7x8j2Y4msCI9l5cGzJKVm4eJgOvE/FOTDvYFeuDjc2cdL5V0JFEPAQogCSRIQRZZlMLLtWALLw2JZFXGW5LRsXB1sua9hNXoE+dAp0Asne9tiO56N9AkIYXaSBMQNZRmMbIk6x/KwOFYdjONihoGKjnZ0bVSNB5t407F+8Z7487t8JSDNQUKYjyQBcZ2M7By2HDV9418deZZLGQYqOtnRrVE1Hgry4Z56njjamefEn9/lu4OkY1gI85EkIADTiX/jkXOsCItlTWQ8KZkG3Jzt6d7Ymx5BPnSo64mDXckONWUjfQJCmJ0kASuWkZ3DhsPx/BMWx7rIs6Rm5eDuYs9DQT70aOrD3QEeJX7iz0/6BIQwP0kCViYty8CGw+f4JyyW9YfiScvKoYqrA48096VHkDd3BXhgb1s6BpeVPgEhzE+SgBVIzTSw7lA8K8JjWX/oHOnZOXhWcODRFr70CPKhXe0q2JWSE39+CukTEMLcJAmUUymZBtZGnmV5WCwbDp8j02DEq6Ijj7Xyo0eQD21rV8H2cqN7KZXXJ2DZMIQo1yQJlCMXM7JzT/xxbDxyjiyDkWqVHOnX1p8Hm3jTulbpP/Hnl9cnIJ0CQpiNJIEy7kJ6NmsizrIiPJZNRxLIyjHiXcmJp9v581CQDy39K2NThk78+UnHsBDmJ0mgDEpOy2J1hKmpZ0tUAtk5Gl93Z56
5uyYPBvnQooZ7mT3xX0U6hoUwO0kCZcT51CxWRcSxPCyOrVEJGIwav8rOPNehNj2CfGjm55b3cFV5UR7ymBClnSSBUiwxJZNVud/4tx1LJMeo8a/iwuDg2jwU5EOQb/k78ecno4gKYX4WTQJKqf8BDwGVgFla61WWjKc0SEjJ5N/wOFaEx7LjeBI5Rk0tDxde6BhAjyAfGlevVK5P/PlJn4AQ5nfbSUApNRvoCcRrrZvkK+8OfAXYAj9orScWtg+t9VJgqVKqMvA5YJVJIP5SBivD4/gnLJZdJ5IwagjwdGXYvXXoEeRDQ5+KVnPiz08eFhPC/O7kSmAu8DUw/3KBUsoWmA50A2KAEKXUMkwJ4ZNrth+ktY7Pff1u7nZW4+zFDP7NPfGHRCehNdStWoERXerRI8ibwGrWeeLPT+YTEML8bjsJaK03KaVqXVPcFojSWh8HUEr9DPTSWn+C6arhKsp0lpsIrNBa7ynoOEqpocBQAH9//9sNt1SIvZDOijBTU0/of+fRGgKrVeTl++rxUJAP9apVtHSIpYqNjCKK1pq0rBwuZmSTkmHgUkY2KZcukXL+DCkXYslKSYLMSxgz0jAaDWTbKLLtbbBxcEI5u2Hr4oa9kzv2jhWp6OKJR4XKVHK2x83ZnopOdlRwtCuVT4uLklPcfQK+wKl8yzFAuxvUHwl0BdyUUnW11jOuraC1ngnMBGjdunWZOxucTk5nRVgsy8Ni2XMyGYAG3hV5rWt9HgzyoW7VCpYNsBQr730CWmsupGfzX2IaJ5PSOB2fQMbJXaiTITjHR1MhKZlKyem4XTRgZ9DYGsDeAB4G8LqF9yTbFrLsTD/ZdnDSGVJcFZdcbbnkYk+6iwsZFTzIcvODqoFU9GmBn0cNanq4UsvDheruzqVmPClR/CzaMay1ngpMtWQM5nAqKY0V4bEsD4tj36lkABr5VOKNBwJ5sIk3AV5y4i+Ky41hZb1PwGjUxJxP5/DZSxyOu8ixMwnknNlMtZhteF48Q6WLabhfyKFlElRMv7Jdlh3EVYYzngrtaIeNvT02DvbYODpi5+SMvbMz9k4u2Do5oZwcsbWxxc5gxCbbiDEzi5yMdIwZGaafrCxyMrPIyczEPiUTnwvZ1D1twDnDAKQDicARYB2ZdpBcARIrK6LcHDlX0Z1Ejzpc9OuMd41mNPCpRKB3RQKrVbytuaNF6VLcf8HTQI18y365ZeXeycQ0loebvvEfiLkAQJCvG292b8CDTbyp5elq4QjLnrLYJ2A0ao6dS2HfqWT2nUom4nQy2ed2Uz9tC7UuROOZkEajOI1P0pVtLrlAsocd55pU5IK/DxXrNqBqs2D8G9xNMxc388abnk72yaNcOL6PpBMRXDwTTWr8WbLOJeNxLpO6JzOwz4kD4oCtJLtCUmUbwiu58o+rHyer3YNt0D008qtMC393mtVwp5KTvVljFsWruJNACFBPKVUb08n/SeCpYj5GqRGdkJp34g8/fRGAZn5uvP1gA3oE+VCjiouFIyzblFIoVbr7BFIyDYREJxFyIon9MckcOJVMtcz9tE/bQO3zZ2h9LosaseCaaaqf5gzJ/q7E3lMTj3bBBAQ/TsWqvhaL38bZGcfAplQNbErVAtbr7GySwjZxKmQlCYfDSI+JxeFcJo2jLtEmKxKIJHXN95yo4cga71q8Vbkbzt5NaFHDnRb+lWlbuzJ1vCpY/U0Opdmd3CK6COgEeCqlYoAPtNazlFIjgJWY7giarbU+WCyRlhLHz6WwPCyWf8LiiIw1nfhb+LsztkdDHgzyxq+ynPiLk41SpapPIC3LQGj0ebYfT2T7sUTCTl/AUV+go/6XjokHeeJkKgGnwM4IRgVJXrZcauWNfdt7qNO1H6516pepE6Kyt8ej5X14tLzvqvKUpBMcCvmF/7ZvgPDT+B/LpMnRwzyuDvOfr+KUXwXWVmnJh/b3UalSZe4O8KB9HQ/uruOBfxWXMvUelHeqNH/Lulbr1q11aGh
oiR83Kv4Sy8PiWB4Wy6G4SwC0qlmZHkE+dG/ija+7c4nHZC3qvrOcoR0DGNO9gUWOr7XmeEIq6w/Fs+5QPCHRSWTnaKraxtI75x+qnzqG34nsvOadRA8bsoKq49vtf9S5/xnsKlrHHV9ZqeeJXPEdZ9euxOFgPNXijQCcc4PTteyJql6Lvx17chZffN2d6RToRZcGVWlfxxNnB/PPV23tlFK7tdatC1wnSeB6WmuOxqfwz4FYVoTHcuRsCkpBm5pVeDDImweb+ODt5mT2OATUH7uCwcG1ebMEk0B2jpEdxxNZGxnP+sPx/JeYBkAb93juP/cHHlHR1IzOwSUTDLaQVNuZCne1IvCp16kQYJlkVdqcP7KHiN++Im3nAaoez8DBAClOcDLAhnN1arHU8SEOGWriaGdDh7qedGlQlS4NqlJdvlCZxY2SgHTt59JacyjuEivCYvknLJZj51JRCtrWqsL4Xo15oLE31SrJib+kKVUydwdl5xjZGpXA8rBYVkWcJTktG0c7G+7xt+f5S4tw37ufWscM2Bnhoiucb+KOy30P0Ojx17G1km/7t6Jy/ZZ0GDsPAENqChG/Teb8yhXUirhAo4jjtLOfxqnaNlyqX4dlJ/vw7iHTc6Mt/N1Nc1wH+UhCKCFWfSWgtSYi9iLLw2JZERbH8YRUbBTcFeDBg0E+PNC4GlUryonfkhq+9y8D7q7JOz0aFvu+c4yaLVEJ/LX/DKsjznIhPZuKjnZ0q1+F9onLcNz6L76RKThnwQVXSA5yo26vx6nT61WUjdw3fztysjIJXzKNM8uXUCU8iUqppucYEmrakd64KQurPsG2RNN308tNrg8FyZX3nZLmoHy01oSfvsjy8FhWhMUSnZiGrY3i7gAPHgzy5oHG3nhWcCymiMWdavT+vzzV1p93ezYqtn0ejrvE4j0xLNl7mvhLmaYTf8Oq3K+icVz3PRVCTlAhDVId4WwDB/y6dSZowCfYOMo30+KUnZ1J6JLJxKz4k2oHL+J1ETLt4UJQRZLueZKZWa2JjLuEUnBPXU8ea+XHA429cbKXPoRbZfVJQGvNgZgLLA+LZXl4LKeS0rG1UbSv48FDQT7c39ibKq4OZohY3KkmH6ykb5savHeHSeB8ahZL953mjz0xhJ++iJ2NolNgVfo0q0a10J/InvMTlRKzyLKDqACo3K4O7QdOxMm3yc13Lu5YcloSm359n5Tlm2gckY2DAZJ8FJU6tmJb8+eZfziH08npVHS0o2czHx5r5UdL/8pyl1ERWX0S2HPyPL2/2YadjeKeep70aOJDt0bVqCwn/lIvaNxKHmvlxwcPN77lbbXW7DmZzMId//F3WCxZBiONq1eiT0s/ejT24Oivk1DzluCRlEO0NyQ0t6XdIwOoHzwK7OVbvyVordl/eC37vp+A346z+CZChqPG2NILu0eGMS+7ESvC4kjPzqGOlysD7qpJ71Z+8oDaTVh9EjAaNcv2n6FzYFXcXOTDUpY0+3AVj7bwZdwjRU8CqZkGlu47zYIdJ4mMvUgFRzsebeHLU+388a9iw4ZZY3H5aQ3eCUbOeGoy763CvQPG4hrY48pjysLizqXGs2bRBxiWb6b5oRzsjJDuZ4dvnz5sajmABaFx7D+VjIuDLb2a+zLgrpo0ql7J0mGXSlafBETZ1Xz8Kh5pVp3xvW7eLHMyMY3ZW0/w++4YUjINNPSpRP+7/OnV3BeDTmHl7Ldx/3kj/vGahCoam07etHvhU+xq3miMQ2Fp2cZsVu2Yw+GF39FidxreyZDhDr6PduJs73f4cX8Sf+47Q6bBSOualXmmfS16NPGW0VHzkSQgyqyWH62mR5A3E/4XVGidvSfP88PmE6wIj8XWRtGzaXX631WTlv7uXMi8wPIf38f9pzXUjtVccNO4dqlBk1emoarJPf1liVEb2XRiNVvmT6DF+gQCzkJmBY3Pg61wfGE8v59QLNjxH9GJafi6OzP4ntr0bVNDBrlDkoAow1pPWM0
Djb35+NGrk0COUbMm8iw/bD5OSPR5KjrZ0f+umgxsX4tqlZxIzkhm6Z+f4P7DXwSe0lyqqKnUyYfAV6ehqktnb1m3JzaUVfPeIXD1KRqchmwXTbX7m+Lx8kTWn3dl5qZjhESfp9Llz0WHWlZ9u7ckAVFmtfl4DV0bVuOT3qYkYMgx8teBM0xbF8Xxc6l53/ieaFODCo52pBvS+W39NIzfzqddRA7pzhqXjh4EvjoFVUuafcqbPXG7+XPRuzRa+R9NozUGR03VzoF4vvUl+7PcmbnxOCsj4rC3saFPK1+Gd6prlQM7ShIQZVbbj9fQpUFVPvpfE5buPc309VFEJ6bRwLsiwzvXzWv7NRgNLNu3iJivp9BpZwYojUMLe+qN/j9smz1s6X+GMLOdsTv547cPCFp1ktZRGqO9xqtTIB5vTuKkky8/bD7Ob6ExGLWmT0s/XupcF3+PK8kg05DDiJ/24lnBgc6BVelQ17NcNSNJEhBl1l3/txZ3F3vSsnI4mZRGI59KjLqvHvc3qoaNjUJrzfoTawj5Zhyd1iZRMV2TUx8ajHgJ+67DwUYeLLIWWms2n97MoiXjaLM2jrsOGcEePLs1x+PtyZxz8GDGxmP8tOskOUZN7xa+jOhSl5oerhw9e4luX2zCRplmsnOwtaFdQJW8MY1qepTt+UAkCYgyK/izdZxKSqepnxujutTjvoZV8x4QOpp0lD9+eIM2Sw5T/Txk+mQT+GRnnPtPBFcPC0cuLMVgNLAkagmLV3xOt/WXuPuQBnuNZ/fWVHlnKgm2Lny74RiLdp3EYNQ82sKXtrWqMOaPA8wf1BZbG2UaNfZwPMfPpQIQ4OlK59yE0KZWFRzsytadR5IERJm1NSoBg1HTsZ5n3sn/YtZF5v/7KdWnLaHhKU2Gew4BXX2o9PL3KK/6Fo5YlBap2anMCpvFuvWzeXRTFm0Pa2wcoEqvzlQZ/QmJypEZG4+zYOd/ZBlMQ1+vee3eq+b9/i8xlXW5w4jvPJ5EVo6RCo523FPXk84NvOgcWJWqZWBgSUkColwwaiNLo5ay6tdPGfzbRey1xquVgerD3kW1GiBNP6JAcalxTA6dTOTOFQzYkkPTI2DjoKjSrzdVRrxJvNGOqWujiIy9yM9D7yp0bKK0LANboxJZdyie9YfiibuYAUAT30p0CaxK5wZVaebnjo1N6XvgUJKAKPMOJx3mw20fUH1lGM+tMWLjZqDOkPY4PPUVOLtbOjxRBuyI3cHHOz5GHz3B8I0Gah9T2DjZUnX067g/PfCWxiHSWhMZe4n1h00JYc/J8xg1eLg6cG99Lzo3qErH+l64OZeOEQokCYgyK8OQwYz9M/jxwByGrsqh494cKtRUVP/sM2yb9bR0eKKMyc7JZl7EPGbun4H/6UxGr8zA7YwdLg198fniexxq1b6t/Z5PzWLT0XOsOxTPxiPnSE7LxtZG0apmZbo0qErnwKrUr2a5uZYlCYgyaUfsDsZvH0/y2ZP83+85VDsNHl1q4zXlV5RThZvvQIhCxKbEMnHXRNafXEv/vYqH12WhtA1Vh/Sj8oixKNvbb1rMMWr2nTqf25dwLm8ucl93Zzo3ME2reXdAyU6rKUlAlCkXMi8wKWQSfx77kzaJjrz+00Vs023xeeUZ3Ia8benwRDmyKnoVH+/8GLv484z7JxvPaHCuURGfKV/jGNS2WI4ReyGd9YfOsf5wPFujEkjLysHRzoa763jkXSWY+wE2SQKizNgUs4kPtn1AcsZ5Rkfk0OrvLGycnajxzQyc29xt6fBEOZSckcyk0Eksi/qTPhH2PLEyDZtshefjnfAYOxXlUHxDzmcacth5PMnUuZxv/up6VSuYEkKDqrSqWRn7Yh78TpKAKPVSslKYFDqJxUcXU8++Mh//fQL2uOJU2we/2T9h7+Nj6RBFObc5ZjMfbv+QzHNnGf+vxvtIDk7V7PGZMAGn4EeK/Xhaa04kpOYlhF0nksjO0VR0sqNjPVPncqdAr2K
Z6bDUJgGlVCfgI+Ag8LPWesON6ksSKJ9C4kJ4d8u7xKXG8WJKJXouO0nKfy5U7NqJ6pOmYOMsE7yIkpGSlcLnoZ/zx9E/6B3lwpN/JUOmwuuRVnhMmIWyN98zAZcystkalZCbFM5x7lImSkFTP3c6B3rRtWE1mvi63da+zZIElFKzgZ5AvNa6Sb7y7sBXgC3wg9Z64g32cS/wFnAWmKC1jrrRMSUJlC8Zhgy+2vMVCyIX0Dy9Cm+ujMPusEJjg+ewF/EcMVKmDxQWse7kOj7c/iEkX+T/VkKViAyqdnDC4/1voKb5myWNRk1E7MW8B9X2xyTTObAqswe2ua39mSsJdARSgPmXk4BSyhY4AnQDYoAQoB+mhPDJNbsYBCRorY1KqWrAFK310zc6piSB8uNY8jHe2PQGaUeP8OpuN/z3JqIUuHXvhMdr7+Lg52fpEIWVS0hP4INtH7D51EY++tuJ+uGp+LZPotJTw+Det8CMVwXXSkzJ5EJ6NgFet3dX3I2SwG0Pk6e13qSUqnVNcVsgSmt9PPfAPwO9tNafYLpqKMx5oMCGL6XUUGAogL+//+2GK0oJrTWLjy7mp2Uf03tbDi0jcrCxTaByC1eqfDgb+3rNLR2iEAB4OnvydZev+f3o73ymPuPNC7bonVWo5TQdl0P/wCNfg3/JDE/uUcERj2LoGyhIcY+V6gucyrccAxT6LimlegMPAO7A1wXV0VrPBGaC6UqguAIVJe9S1iW+XfAq1f/YxoRjGuUAVRpdokq/x7F7dCLYmedDLsTtUkrxeP3HaevdlvcdX+WZqZFkb69GXbd4XGffDx3HwL1vgm3ZHXbaopFrrRcDiy0ZgygZYefCCBs+kIfD0sh2dcCz+SWq1EnG9um50Kj477wQojjVrFSTmY8v4huXD7h73J/s3+hGjWeCqLHpMzjyL/T5AbwCLR3mbSnu8VBPAzXyLfvllgkrpbVmYeRCnlnxDMdqOmJ4NIgm3f/Dq60ztsPXSAIQZYajrSOvPjKRjE9ew+ViNgd+jWZd68GQcASmt4Vd31s6xNtS3EkgBKinlKqtlHIAngSWFfMxRBmRlp3Gm5vfZOKuiQT73sOoDk0IclyJTeC98OJm8G1p6RCFuGWdHhiC28QPqRWbQ8x3//Jp8GCya7SD5aNhwWNwKc7SId6S204CSqlFwHYgUCkVo5QarLU2ACOAlUAk8KvW+mDxhCrKkugL0Ty9/GlWRq/k5caD+DJqP277f4EmfaDfL+BSxdIhCnHbavd8gqrvvE2boxqH75Yy1MebxHYvQNRq01XB0TWWDrHIbjsJaK37aa19tNb2Wms/rfWs3PLlWuv6Wus6WuuPiy9UUVas/W8t/f7pR0J6At+2+5Dnt87DJnY/dPsI+swq0VvrhDAXrwHPUOW55+i+R+O3fD99U/YQ/sT34FwZFvaBVe+CIdPSYd5U2e3SFqVOjjGHaXunMSt8Fo09GvNFvf74LB4OWSnQfzHUvc/SIQpRrKq+MZrsM2d4auVKMj0zeTbjU97tOppH9y2DbdPg2Hp4fC541rN0qIUqWxNlilIrJSuFl9e/zKzwWfSp14d5Nf6Hz6KnQQHPLZcEIMolZWND9c8+xbllS55bfJGeqXV5P/QzJgQ0JbvHZ6ZO4x/ugyMrLR1qoSQJiDt26tIpBqwYwJbTWxjbbizjnAJwXPICeDWEYdvB/y5LhyiE2dg4OuI3/WscfHwYeiaQ55o8xy9HfuX5xG2cf+5vcKgAPz0Bf70CWWmWDvc6MoqouCMhcSG8tuE1cnQOk++dzN0Rq2DLFFMCeG65dAALq2FITMS2cmWUjQ3Ljy/nva3vUc21Gl/f8ykBG6fAwcVQvSU8+RNUKtlRcW80bIRcCYjb9seRPxi6aijuju4s6vETd4csMCWAul1hyDpJAMKq2Hl4oGxMp9QeAT2Y3X02qdmp9F/zAtvbD4WHv4Ize+GbdnBsnYWjvUKSgLhlOcYcPgv5jHHbx9HOpx0
LH/yRmqs/gr0LoOWz0O9ncDDvTElClHbNvJqx6KFFeFfwZtjaYfxawcV0dWzvAj8+CpsmWTpEQJKAuEUZhgxGbxzNjxE/8nTDp/m681dU+nUghP8OzftDzy/B1t7SYQpRKlSvUJ0fH/yRDr4d+GjHR3x6djM5L26Geg/AugmwqB9kXLRojJIERJGdzzjPkFVDWHtyLWPajOGtVq9j9/PTcHw9BI+GXl+DjXykhMjP1d6VqZ2nMqDRABZELuDVnR+R/tgsaDsUDi+Hb9tD7AGLxSf/Y0WRXL4DKCIxgsmdJjOgbm/TrW9Rq6H9SOjyLsgEMEIUyNbGljFtxvB227fZcGoDQ9YOI7nLWHh8HqQmwNyHLPaUsSQBcVPhCeH0X96f5MxkfnjgB7pVvwd+6Aax+6HrOLh/giQAIYrgqYZPMbnTZCITI3nm32c4U7MtDN8Gtg6mp4zXfwIlfMemJAFxQxtPbWTQykE42znz44M/0sI90PStJf4g9PwC7nnV0iEKUaZ0q9mNmffPJCE9gf7L+3OILBi1B2p3hI0T4Zf+YMwpsXgkCYhCLTu2jJfXv0yAWwALeiygtqMHzLwXTu+GbuOh9SBLhyhEmdSqWivmd5+PjbJh4L8D2XE+EvovgQ6vwKG/YUYwXCiZUfglCYgCLYhYwNgtY2nj3YbZD8zG087V1AeQcAR6fA4dXrZ0iEKUaXUr12VBjwX4uPowbM0wVp5aC90+hO4T4VykqcP4v21mj0OSgLiK1prp+6bzacindPXvyvT7puOCgvn/MyWAR6ZB2yGWDlOIcsHb1Zt5D84jyDOIMZvGsOToErhrGDz7N6Dhx94Q+bdZY5AkIPIYtZGJuyYyY/8M/lf3f0y6dxIOWpv6AE7tgPs+gJbPWDpMIcqVSg6VmNF1Bnf53MX7295nQcQCqNUBhu8EZ3f45WnY9LnZji9JQACQbcxm7Jax/HToJ55p9Azj24/HzmiEWfeb+gC6fgjBr1k6TCHKJRd7F6Z1mUZX/658GvIp3+7/Fl3RG17aCTU7wLqPYM+PZjm2JAFBZk4mr61/jb+P/83IFiMZ3Xo0ypgDfwyG2H2m20DvecXCUQpRvjnYOjDp3kk8UucRvtn3DZNDJ6MdK8EzyyCgM0RvNstxZVIZK5dhyOCV9a+w9cxW3mn3Dv0a9DPdnvbHYIhcBm2GyG2gQpQQOxs7PurwEa72rsyLmEdKdgrv3fUetk8uNI05ZI5jmmWvokxIN6Qzat0odsbu5MP2H9K7Xm/TgypLXoCIpab2/4fM1xYphLiejbLh7bZvU8G+At+HfU9GTgYTOkzAzkwPZEoSsFJp2WmMWDeC0LhQPurwEb3q9jKt2DwZwn6Dpn3h4amWDVIIK6WUYlTLUTjbOTN171RyjDn8X/D/YW9T/IMzShKwQqnZqQxfM5x95/bxSfAnPBTwkGnF9m9MHVD+d8P/ZshQEEJY2JCmQ7C3sWfy7sloNJM6TkIV8/9LSQJWJiUrhWFrhhGWEManHT+le63uphV75sPKt8GjHgxYKqOBClFKDGwyEHtbezycPIo9AYCFk4BSyh+YCiQBR7TWEy0ZT3l3KesSL65+kYjECD6/93O61uxqWnF6NywbBW7+MHQD2DtZNE4hxNWebvi02fZ921/3lFKzlVLxSqnwa8q7K6UOK6WilFJv3WQ3QcDvWutBQIvbjUXcXFp2GsPXDM8bCjovASQchTkPgasXvLARHCtYNlAhRIm6kyuBucDXwPzLBUopW2A60A2IAUKUUssAW+CTa7YfBOwAfldKDQLM8ySEIN2Qzoh1IwhLCGPSvZPo4t/FtOLiGdN4QMZs6LdI5gQWwgrddhLQWm9SStW6prgtEKW1Pg6glPoZ6KW1/gToee0+lFKjgQ9y9/U7MKeAOkOBoQD+/v63G67VyszJ5JX1rxAaF8rE4Il0q9nNtCI9Gb7vAhkXYMAS8Gtt0TiFEJZR3L1/vsCpfMsxuWWF+RcYpZSaAUQ
XVEFrPVNr3Vpr3drLy6vYArUG2TnZvL7hdbad2cb4DuPpEdDDtCInG5YOg0ux0GcW1Oli2UCFEBZj0Y5hrXU48JglYyivDEYDYzaNYWPMRt676z3+V/d/phVGI/z9imlu0w4vQ5C8/UJYs+K+EjgN1Mi37JdbJkpQjjGHd7a8w5qTaxjTZgxPBD5xZeWGT2DvAmj0P9PEMEIIq1bcSSAEqKeUqq2UcgCeBJYV8zHEDWit+Xjnx6w4sYKXW77MgEYDrqyMWgObPgO/NvD4XIvFKIQoPe7kFtFFwHYgUCkVo5QarLU2ACOAlUAk8KvW+mDxhCqKYtreafx25DcGNxnM80HPX1kRHwk/Pw3u/qYJK+RpYCEEd3Z3UL9CypcDy287InHb5h+cz/dh39OnXh9ebplv+sdLcTC/l6lD+Ikf5WEwIUQeGTainFh2bBmTQifRrWY33rvrvSuPl2emmKaoSzkLT/0G1ZtbNE4hROkiA8SUA+tPruf9re/TzqcdE4MnYmtje2XlugkQf9A0OXz9+y0XpBCiVJIkUMaFxIUweuNoGlZpyFedv8LB1uHKyt3zYOe30KSPTA4vhCiQJIEyLDIxklHrRuFb0Zdvun6Dq73rlZXRW+GvUeBRF3pNt1yQQohSTZJAGXUm5QzD1w6ngkMFZnabSWWnyldWJh6D+Y+Ao5vpTiB7Z8sFKoQo1aRjuAy6kHmBYWuGkWnI5PsHv8fb1fvKyhwDLB0ORgM8swQq+VguUCFEqSdJoIzJysnilfWvcOrSKb7r9h11K9e9usI/r8GpHdD1Q/BtZZkghRBlhiSBMsSojby75V1Cz5pGBG3j3ebqCiGzYM88aNAT7nnFIjEKIcoW6RMoQ77c8yUrok3DQeTNC3zZyZ2mq4AqdWRICCFEkUkSKCMWHVrEnPA59A3sy+Amg69emZUKS14Ahwow8B+wtbdMkEKIMkeag8qA9SfXM3HXRO71u5e32r519WTTxhz4pT+cPwG9v5eOYCHELZErgVIuIjGCNze/SaMqjfis42fY2VyTtzd+CsfWQdsXoOkTBe9ECCEKIUmgFItPi2fk2pG4Obox7b5puNi7XF3h2DpTEqjeEh781DJBCiHKNGkOKqXSDemMXDeSS9mX+PHBH/F09ry6QloSLBlmeiCs7wIZGloIcVskCZRCRm1k7JaxRCZGMrXLVAKrBF5TwQh/PA8pcaY5gt1uNI2zEEIUTpqDSqHp+6az+r/VvNbqNTrV6HR9he1fw7G10HqQzBEshLgjkgRKmX+O/8PMAzN5tO6jPNv42esrnNgEq9+Dak3ggU9KPkAhRLkiSaAU2Re/j/e3vk/raq2vnhjmsux0WDYK7JzhyYUyQ5gQ4o5Jn0ApEZsSy8vrX6aaazW+6PQF9gU98LX8DdPzAI9Mg8q1SjxGIUT5I0mgFEg3pPPy+pfJyslizgNzcHdyv77S3oWw90eo/yC0GFDiMQohyidJAhamtebD7R9yKOkQ07pMI8A94PpKcWHw53Co6AOPfiu3gwohik2J9QkopQKUUrOUUr/fqMza/BjxI/8c/4eXmr/EvTXuLbjSmg9Nv59cCM6VC64jhBC3oUhJQCk1WykVr5QKv6a8u1LqsFIqSin11o32obU+rrUefLMya7IjdgdTdk/hPv/7GNK0kDmAN30OUavhvvdlfgAhRLEranPQXOBrYP7lAqWULTAd6AbEACFKqWWALXDtvYuDtNbxdxxtOXI65TRvbHyDWpVq8fE9H2OjCsjH/22DdR+BZ31oN6zkgxRClHtFSgJa601KqVrXFLcForTWxwGUUj8DvbTWnwA9izXKcibdkM7L614mx5jDV12+unqC+MtSE+GX3A7gvgvBweX6OkIIcYfupE/AFziVbzkmt6xASikPpdQMoIVS6u3CygrYbqhSKlQpFXru3Lk7CLd00FrzwdYPOHL+CJ92/JSalWoWXHH9x5CWYBoe2qt+yQYphLAaJXZ3kNY6EXjxZmU
FbDcTmAnQunVrbbYAS8i8g/PyZgcL9gsuuNKx9RA6C1o+K8NDCyHM6k6uBE4DNfIt++WWiULsit3FF3u+oFvNbtfPDnZZWhL8NtD0Ovj1EotNCGGd7iQJhAD1lFK1lVIOwJPAsuIJq/yJT4vnjU1vULNSTT7q8NH1Q0Jc9u9bkJEMT8yHyoU0FQkhRDEp6i2ii4DtQKBSKkYpNVhrbQBGACuBSOBXrfVB84VadmUbs3lj4xukG9KZcu+UgjuCwdQMdOAXaNYPGvUq2SCFEFapqHcH9SukfDmwvFgjKoem7pnKnvg9TAyeSN3KdQuuZMiExUNNr7u8V3LBCSGsmowiamZr/1vL3INz6RvYl4cCHiq84u+DIDUeHpstk8QIIUqMJAEzOnnxJO9ufZcmHk0Y02ZM4RX3/wyH/obAHtC4d8kFKISwepIEzCTDkMFrG17DRtkwudNkHGwdCq6Ykw3rJoBjJXjkaxkcTghRomQUUTP5v53/x+Hzh5l+33SqV6heeMU/R8CFU9DvF3D1KLkAhRACuRIwiyVHl7AkaglDmw6lo1/HwiseWQkHfgbf1lDv/pILUAghckkSKGZR56P4v53/RzvvdgxvNrzwillpsDR3ULi+C8BG/hRCiJInZ55ilG5I541Nb+Bi78LEjhOxtbEtvPKfL0FaomlwuEo+JRekEELkI30CxeizkM+ISo7iu67f4ensWXjFc0fg4GKo9wA0uMFto0IIYWZyJVBM/o3+l9+P/M6gJoNo79u+8IpGIyzoDcoGHvpc7gYSQliUJIFicOrSKT7c9iFNvZoyosWIG1de/Z7pbqDg0eDuXzIBCiFEISQJ3KHsnGzGbByDUorPOn6GvY194ZXP7IXtX0OVALj3Bg+PCSFECZE+gTs0de9UwhPDmdJpCr4VbjDcg9bw92um1/1+AdsbJAshhCghciVwBzbHbM4bF6hbzW43rnzgVzizB+6fIDOFCSFKDUkCtyk+LZ6xW8ZSv3J93mjzxo0rZ6XC8jfA3gVaDyqZAIUQogikOeg2GLWRsVvGkpGTwaSOk3C0dbzxBr8MgMwL8OQicChkLgEhSoHs7GxiYmLIyMiwdCjiNjg5OeHn54e9fdGbmyUJ3IYFEQvYEbuD9+9+nwD3gBtXjlgGx9ZCnfsg8MGSCVCI2xQTE0PFihWpVatW4bPfiVJJa01iYiIxMTHUrl27yNtJc9AtOpx0mC/3fEnnGp15rN5jN65szIGVY02vH/1OngkQpV5GRgYeHh6SAMogpRQeHh63fBUnSeAWZOZk8tbmt3BzdGNc+3E3/4+y5Qu4cBKe+BEqeJVMkELcIUkAZdft/O2kOegWfLn7S6KSo/i267dUcapy48qpibDxM3DxME0WI4QQpZBcCRTRttPbWBC5gKcaPMU9vvfcfIOFfSAnE/r/AbaSa4UoKltbW5o3b573M3HixGLd/9y5cxkx4iZP9lsROTsVwfmM84zdOpa67nV5tdWrN99g70LT08FNHoPqLcwfoBDliLOzM/v27bN0GFZDksBNaK0Zt20cFzIvMKPrDJzsnG68gSEL1n5oet1zivkDFMJMPvzrIBFnLhbrPhtVr8QHDze+rW1r1arFs88+y19//UV2dja//fYb9evXJyAggH379uHu7g5AvXr12LJlCzY2Nrz44oucPHkSgC+//JIOHTpctc/o6GgGDRpEQkICXl5ezJkzB39/fwYOHIiTkxOhoaFcvHiRKVOm0LNnT3JycnjrrbfYsGEDmZmZvPTSS7zwwgt39J5YWok1BymlApRSs5RSv19T7qqUClVK9SypWG7F4qOLWXdqHaNajCKwSuDNN9j6FaSchad+Ayc38wcoRDmTnp5+VXPQL7/8krfO09OTPXv2MGzYMD7//HNsbGzo1asXS5YsAWDnzp3UrFmTatWq8fLLL/Pqq68SEhLCH3/8wfPPP3/dsUaOHMmzzz7LgQMHePrppxk1alTeuujoaHbt2sU///zDiy+
+SEZGBrNmzcLNzY2QkBBCQkL4/vvvOXHihPnfFDMq0pWAUmo20BOI11o3yVfeHfgKsAV+0FoX2nintT4ODL42CQBvAr/eauAl4eTFk3wa8intvNvxTONnbr5BSjxsnAgVq0PdruYPUAgzut1v7HfqRs1BvXv3BqBVq1YsXrwYgL59+zJ+/Hiee+45fv75Z/r27QvAmjVriIiIyNv24sWLpKSkXLW/7du35+1nwIABjBlzZWDHJ554AhsbG+rVq0dAQACHDh1i1apVHDhwgN9/N53GLly4wNGjR2/pvvzSpqjNQXOBr4H5lwuUUrbAdKAbEAOEKKWWYUoIn1yz/SCtdfy1O1VKdQMigJu0sZS8HGMO7259Fztlx4R7JmCjinDRtOhJMBpkukghzMTR0fR0vq2tLQaDAYC7776bqKgozp07x9KlS3n33XcBMBqN7NixAyen2zu9XHu7pVIKrTXTpk3jgQceuIN/RelSpDOV1noTkHRNcVsgSmt9XGudBfwM9NJah2mte17zc10CyNUJuAt4ChiiVFHOtCXjx4gf2Ru/l7fbvY23q/fNNwj7HU7vhoaPgF8r8wcohABMJ+dHH32U1157jYYNG+Lh4QHA/fffz7Rp0/LqFXR10b59e37++WcAFi5cSHBwcN663377DaPRyLFjxzh+/DiBgYE88MADfPvtt2RnZwNw5MgRUlNTzfivM7876Rj2BU7lW44B2hVWWSnlAXwMtFBKva21/kRrPTZ33UAgQWttLGC7ocBQAH//kpmE5VjyMabtnUaXGl3oGVCEroqc7CtPBvf62rzBCVHOXe4TuKx79+43vU20b9++tGnThrlz5+aVTZ06lZdeeommTZtiMBjo2LEjM2bMuGq7adOm8dxzzzFp0qS8juHL/P39adu2LRcvXmTGjBk4OTnx/PPPEx0dTcuWLdFa4+XlxdKlS4vjn20xSmtdtIpK1QL+vtwnoJR6DOiutX4+d3kA0E5rbbYbcFu3bq1DQ0PNtXsAso3ZDFg+gDMpZ1jca/GN5wq+bM98WDYS+syCoJsMJSFEKRYZGUnDhg0tHYbFDRw4kJ49e/LYY2Xv/3NBf0Ol1G6tdeuC6t/JlcBpoEa+Zb/csjJtVtgsDiYeZPK9k4uWANKSTAmgki80ftT8AQohRDG6kyQQAtRTStXGdPJ/ElPbfpkVmRjJd/u/o0ftHtxf6/6ibbRspOn3g5+Bja35ghNClJj8zUrlXZE6YpVSi4DtQKBSKkYpNVhrbQBGACuBSOBXrfVB84VqXlk5Wbyz5R0qO1XmnXbvFG2jM/vg0N/g1wYalsrHHIQQ4oaKdCWgte5XSPlyYHmxRmQh3+z7hqjkKKbfNx03xyI85KU1LOoHyhb6LjR/gEIIYQal5pZMS9oXv485B+fQp14fOvp1LNpGJzbBpTNwzytQsZpZ4xNCCHOx+iSQbkjn3a3v4u3izejWo4u2UXa6acpIJ3foeJP5hYUQohSz+iQwfe90/rv4H+M7jKeCQ4WibbT1K9Ocwe1HgL2zeQMUwspcHkq6SZMmPP7446SlpV1V3rhxY5o1a8bkyZMxGk2PFm3YsAE3N7e88Ya6dpVhW4rKqpNA2Lkwfoz8kcfrP047n0Kfc7taaoJpsphKvnDP6+YNUAgrdHnsoPDwcBwcHPIe8LpcfvDgQVavXs2KFSv48MMP87YLDg5m37597Nu3jzVr1lgq/DLHaoeSzsrJ4v1t7+Pl7MVrrV4r+oZ/PA86x/RgmIwPJMqzFW9BXFjx7tM7CB4s+iQxwcHBHDhw4LryqlWrMnPmTNq0acO4ceOKMUDrY7VnsZkHZhKVHMX7d79f9GaghCg4vh5qBUPNu80boBBWzmAwsGLFCoKCggpcHxAQQE5ODvHxpqHJNm/enNcc9PHHH5dkqGWaVV4JHE46zKywWfQM6Fn0u4G0hhVjTLeE9vnBvAEKURrcwjf24pR/7KDg4GAGDx5cpO2Cg4P5+++/zRhZ+WR
1ScBgNPD+tvep5FiJN9u8WfQNT+2CY2uhTheoWIRRRYUQt6Wo00seP34cW1tbqlatSmRkpPkDK6esLgnMOziPiMQIJt87GXcn96JtlJMNPz8Ftg7w+DyzxieEuLlz587x4osvMmLEiOvG/Re3xqqSwIkLJ/hm3zd09e9a9LGBAELnQFoCdHgZnCqZL0AhRKEuNxNlZ2djZ2fHgAEDeO21W7ipQxTIapKAURv5YNsHONk5MfausUXfMDMFVr0LzpWhy/vmC1AIAXDdFJCX5eTkFLpNp06d6NSpk5kiKt+s5u6gRYcWsTd+L2PajCnaENGXRSyFnEzo+SXYWk3OFEJYCatIAjGXYvhqz1d08O3AI3UeKfqGqYnw50vgURca9TJfgEIIYSFW8dXW3saeYN9gRrcefWudSJsmmX53HAPS+SSEKIesIglUc63G5E6Tb22j89Gw81vwbgrN+polLiGEsDSraA66LUtfMv2WieOFEOWYJIGCxIXDf1sgoBP4NLN0NEIIYTaSBAqybSrY2EOf2ZaORAiro5Ti9devjND7+eefl/ggcZ06dSI0NLREj2kpkgSuFbMbDvwC9bqBq4eloxHC6jg6OrJ48WISEhJua3uDwVDMEZmXpeO1io7hIjMa4a+XTa+7W2bwLCFKi093fcqhpEPFus8GVRrwZtsbj9llZ2fH0KFD+eKLL64bDTQ6OppBgwaRkJCAl5cXc+bMwd/fn4EDB+Lk5MTevXvp0KEDSUlJODs7s3fvXuLj45k9ezbz589n+/bttGvXjrlz5wIwbNgwQkJCSE9P57HHHrtqfoKCjB8/nr/++ov09HTat2/Pd999h1KKTp060axZMzZu3IjBYGD27Nm0bduWcePGcezYMaKiokhISGDMmDEMGTKEDRs28N5771G5cmUOHTrEgQMHGDZsGKGhodjZ2TFlyhQ6d+5Mr1696NOnD8888wzfffcdmzZtYuHC4p3TXK4E8juxAc6GQdO+ULmmpaMRwmq99NJLLFy4kAsXLlxVPnLkSJ599lkOHDjA008/zahRo/LWxcTEsG3bNqZMmQLA+fPn2b59O1988QWPPPIIr776KgcPHiQsLCxvgLqPP/6Y0NBQDhw4wMaNGwucuyC/ESNGEBISQnh4OOnp6VeNWpqWlsa+ffv45ptvGDRoUF75gQMHWLduHdu3b2f8+PGcOXMGgD179vDVV19x5MgRpk+fjlKKsLAwFi1axLPPPktGRgYzZ85k/PjxbN68mcmTJzNt2rQ7el8LIlcClxmy4NdnwckNHp5q6WiEsLibfWM3p0qVKvHMM88wdepUnJ2vTOG6fft2Fi9eDMCAAQMYM2ZM3rrHH38cW1vbvOWHH34YpRRBQUFUq1Ytb16Cxo0bEx0dTfPmzfn111+ZOXMmBoOB2NhYIiIiaNq0aaFxrV+/ns8++4y0tDSSkpJo3LgxDz/8MAD9+vUDoGPHjly8eJHk5GQAevXqhbOzM87OznTu3Jldu3bh7u5O27ZtqV27NgBbtmxh5MiRADRo0ICaNWty5MgRmjZtyvjx4+ncuTNLliyhSpUqd/rWXqfEkoBSKgAYC7hprR/LLQsGns6No5HWun1JxXOdqNWQedE0cby9k8XCEEKYvPLKK7Rs2ZLnnnuuSPVdXV2vWnZ0dATAxsYm7/XlZYPBwIkTJ/j8888JCQmhcuXKDBw4kIyMjEL3n5GRwfDhwwkNDaVGjRqMGzfuqvrXPoh6ebmw8mvjLUxYWBgeHh55VxDFrUjNQUqp2UqpeKVU+DXl3ZVSh5VSUUqpt260D631ca314GvKNmutXwT+Biw3RrMxB9aOB1cvuPeG/wwhRAmpUqUKTzzxBLNmzcora9++PT///DMACxcuJDg4+Lb3f/HiRVxdXXFzc+Ps2bOsWLHihvUvn/A9PT1JSUnh999/v2r9L7/8Api+1bu5ueHm5gbAn3/+SUZGBomJiWzYsIE2bdpct+/g4OC8tv4jR45w8uR
JAgMD2bVrFytWrGDv3r18/vnnnDhx4rb/vYUp6pXAXOBrYP7lAqWULTAd6AbEACFKqWWALfDJNdsP0lrH32D/TwFFmz7IHCKWwrlD0PYFGSROiFLk9ddf5+uvrzywOW3aNJ577jkmTZqU1zF8u5o1a0aLFi1o0KABNWrUoEOHDjes7+7uzpAhQ2jSpAne3t7XncydnJxo0aIF2dnZzJ595fbypk2b0rlzZxISEnjvvfeoXr06R44cuWrb4cOHM2zYMIKCgrCzs8vruB4yZAhz5syhevXqTJ48mUGDBrFu3bpinUNBaa2LVlGpWsDfWusmuct3A+O01g/kLr8NoLW+NgFcu5/fLzcH5S77A+9prYcUUn8oMBTA39+/1X///VekeIssKw2mNAAUvHFMkoCwapGRkTRs2NDSYZQ5nTp14vPPP6d169ZXlY8bN44KFSowevToEouloL+hUmq31rp1QfXv5O4gX+BUvuWY3LICKaU8lFIzgBaXE0auwUCh6VxrPVNr3Vpr3drLy+sOwi1ExJ+QcQHaj5QEIISwOiV21tNaJwIvFlD+QUnFcJ30ZFj6IrjXhODXb1pdCCEKsmHDhgLLS/pJ59txJ1cCp4Ea+Zb9csvKjog/Tb87viFDRQshrNKdJIEQoJ5SqrZSygF4ElhWPGGVgOwMWPeRacKYFv0tHY0QQlhEUW8RXQRsBwKVUjFKqcFaawMwAlgJRAK/aq0Pmi/UYrZvIaSeg2b95CpACGG1itQnoLXuV0j5cmB5sUZUEtKTYcWbUCVA+gKEEFbNOscOCvsNjNmmO4LkKkCIUuXjjz+mcePGNG3alObNm7Nz506zHu9Wh43esGEDPXv2LJZjt29/80ESvvzyS9LS0orleAWxvnsiUxNg+Wio2hhaD7p5fSFEidm+fTt///03e/bswdHRkYSEBLKysiwdltls27btpnW+/PJL+vfvj4uLi1lisL4kEG4afIqOJffwhhBlUdz//R+ZkcU7lLRjwwZ4v/NOoetjY2Px9PTMG+vH09Mzb92NhnFu0aIFmzdvJjU1lfnz5/PJJ58QFhZG3759mTBhAtHR0XTv3p1WrVqxZ88eGjduzPz58687sa5atYoPPviAzMxM6tSpw5w5c6hQoQL//vsvr7zyCi4uLtxzzz0Fxj537lyWLFnChQsXOH36NP379+eDD0x3wE+ZMiXvKeLnn3+eV155BYAKFSqQkpLChg0bGDduHJ6enoSHh9OqVSsWLFjAtGnTOHPmDJ07d8bT05M1a9YwePBgQkNDUUoxaNAgXn311dv+e4C1NQcZMmHjp1CtCTTpbelohBDXuP/++zl16hT169dn+PDhbNy4MW/djYZxdnBwIDQ0lBdffJFevXoxffp0wsPDmTt3LomJiQAcPnyY4cOHExkZSaVKlfjmm2+uOnZCQgITJkxgzZo17Nmzh9atWzNlyhQyMjIYMmQIf/31F7t37yYuLq7Q+Hft2sUff/zBgQMH+O233wgNDWX37t3MmTOHnTt3smPHDr7//nv27t173bZ79+7lyy+/JCIiguPHj7N161ZGjRpF9erVWb9+PevXr2ffvn2cPn2a8PBwwsLCijy43o1Y15XA7nmQlgAdRt28rhBW7kbf2M2lQoUK7N69m82bN7N+/Xr69u3LxIkTGThw4A2HcX7kkUcACAoKonHjxvj4+AAQEBDAqVOncHd3v2p8oP79+zN16tSrhnPYsWMHEREReXWysrK4++67OXToELVr16ZevXp5286cObPA+Lt164aHh2lGwt69e7NlyxaUUjz66KN5o4b27t2bzZs306JFi6u2bdu2LX5+fgA0b96c6Ojo6646AgICOH78OCNHjuShhx7i/vvvv813+grrSQIZF2H1e+DuD+0lCQhRWtna2tKpUyc6depEUFAQ8+bN48knn7zhMM43GzYaCh/S+TKtNd26dWPRokVXlV+egKYobnaMG8kft62tbYHTTlauXJn9+/ezcuV
KZsyYwa+//nrVYHW3w3qagw78AoYMUwKQO4KEKJUOHz7M0aNH85b37dtHzZo1bzqMc1GcPHmS7du3A/DTTz9d9y37rrvuYuvWrURFRQGQmprKkSNHaNCgAdHR0Rw7dgzguiSR3+rVq0lKSiI9PZ2lS5fSoUMHgoODWbp0KWlpaaSmprJkyZJbGgK7YsWKXLp0CTA1WRmNRvr06cOECRPYs2fPLb0HBbGOK4HEY7l3BDWCtgUOViqEKAVSUlIYOXIkycnJ2NnZUbduXWbOnHnTYZyLIjAwkOnTpzNo0CAaNWrEsGHDrlrv5eXF3Llz6devH5mZmQBMmDCB+vXrM3PmTB566CFcXFwIDg7OOylfq23btvTp04eYmBj69++fN6rowIEDadu2LWDqGL62KehGhg4dSvfu3alevTpffvklzz33HEajEYBPPrnhoM1FUuShpEuD1q1b61u5nzfPhRhY+Q407w/177wNTYjyqrwOJR0dHU3Pnj0JDw+/eeXbNHfuXEJDQ6+a/8ASbnUoaeu4EnDzgyfm37yeEEJYGetIAkIIq1arVi2zXgWAqcln4MCBZj2GOVhPx7AQokjKUhOxuNrt/O0kCQgh8jg5OZGYmCiJoAzSWpOYmIiTk9MtbSfNQUKIPH5+fsTExHDu3DlLhyJug5OTU94DZ0UlSUAIkcfe3p7atWtbOgxRgqQ5SAghrJgkASGEsGKSBIQQwoqVqSeGlVLngP/uYBeeQEIxhWNuZSlWKFvxlqVYoWzFK7Gaz53EW1Nr7VXQijKVBO6UUiq0sEenS5uyFCuUrXjLUqxQtuKVWM3HXPFKc5AQQlgxSQJCCGHFrC0JFDwdUOlUlmKFshVvWYoVyla8Eqv5mCVeq+oTEEIIcTVruxIQQgiRjyQBIYSwYmU6CSilZiul4pVS4fnKqiilViuljub+rpxbrpRSU5VSUUqpA0qplvm2eTa3/lGl1LMlHO/jSqmDSimjUqr1NfXfzo33sFLqgXzl3XPLopRSb5VgrJOUUody378lSin30hDrDeL9KDfWfUqpVUqp6rnlFv0sFBRrvnWvK6W0UsqztMaqlBqnlDqd+77uU0r1yLeu1H0OcstH5n52DyqlPisN8Rby3v6S732NVkrtM3usWusy+wN0BFoC4fnKPgPeyn39FvBp7usewApAAXcBO3PLqwDHc39Xzn1duQTjbQgEAhuA1vnKGwH7AUegNnAMsM39OQYEAA65dRqVUKz3A3a5rz/N995aNNYbxFsp3+tRwIzS8FkoKNbc8hrASkwPRHqW1liBccDoAuqW1s9BZ2AN4Ji7XLU0xFvY5yDf+snA++aOtUxfCWitNwFJ1xT3Aublvp4H/C9f+XxtsgNwV0r5AA8Aq7XWSVrr88BqoHtJxau1jtRaHy6gei/gZ611ptb6BBAFtM39idJaH9daZwE/59YtiVhXaa0NuYs7gMtj1lo01hvEezHfoitw+S4Ii34WCvncAnwBjMkXZ2mOtSCl8nMADAMmaq0zc+vEl4Z4b/TeKqUU8ASwyNyxlukkUIhqWuvY3NdxQLXc177AqXz1YnLLCiu3tNIe7yBM31ChFMeqlPpYKXUKeBp4P7e41MWrlOoFnNZa779mVamLNdeI3Oap2Sq3yfUGMVk61vpAsFJqp1Jqo1KqTW55aY0XIBg4q7U+mrtstljLYxLIo03XUXIPbDFTSo0FDMBCS8dyM1rrsVrrGphiHWHpeAqilHIB3uFKkirtvgXqAM2BWEzNFqWZHaZms7uAN4Bfc79pl2b9uHIVYFblMQmczb1cJvf35Uu/05jaXC/zyy0rrNzSSmW8SqmBQE/g6dwkyw1iKk3v7UKgT+7r0hZvHUztvPuVUtG5x92jlPIuhbGitT6rtc7RWhuB7zE1SVAaY80VAyzObVLbBRgxDcZWKuNVStkBvYFf8hWbL9bi7uwo6R+gFld3Ak3i6o7hz3JfP8TVHWy7csurACcwda5Vzn1dpaTizVe
+gas7hhtzdUfQcUydQHa5r2tzpSOocQm9t92BCMDrmnoWj7WQeOvlez0S+L20fBYK+xzkrovmSsdwqYsV8Mn3+lVMbdWl+XPwIjA+93V9TM0nqjTEW9DnIPf/2cZryswWq1n+CCX1g+lyKRbIxpTtBwMewFrgKKY7Aqrk1lXAdEw96WFcfcIdhKmjJQp4roTjfTT3dSZwFliZr/7Y3HgPAw/mK+8BHMldN7YEY43K/Q+0L/dnRmmI9Qbx/gGEAweAvwDf0vBZKCjWa9ZHcyUJlLpYgR9zYzkALOPqpFAaPwcOwILcz8IeoEtpiLewzwEwF3ixgPpmiVWGjRBCCCtWHvsEhBBCFJEkASGEsGKSBIQQwopJEhBCCCsmSUAIIayYJAEhhLBikgSEEMKK/T+bSfvbWcItGAAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "x = np.linspace(1000, 1700, 1000)\n", + "lam = 0.3\n", + "p = 1000\n", + "y0 = poisson_region_envelope(x, p, lam)\n", + "y1 = np.exp(_logprob(np.floor(x), p, lam))\n", + "y2 = normal_approx(x, p, lam)\n", + "plt.semilogy(x, y0, label=\"Envelope\")\n", + "plt.semilogy(x, y1, label=\"PDF\")\n", + "plt.plot(x, y2, label=\"Normal approx\")\n", + "samples = _rejection_region_poisson(np.random.default_rng(42), p, lam, 100000)\n", + "y, edges = np.histogram(samples, bins=20, density=True)\n", + "plt.plot(0.5 * (edges[1:] + edges[:-1]), y, label=\"Sampled points\")\n", + "plt.legend()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "scrolled": true + }, + "outputs": [], + "source": [ + "def abel_rejection_envelope(x, p, lam):\n", + " p = np.asarray(p)\n", + " lam = np.asarray(lam)\n", + " nu = 2 / 3 * (p ** 2 - lam * p - 3 * lam ** 2) / lam ** 2\n", + " alpha = 0.2746244084 # Taken from page 259\n", + " # alpha = 3/7\n", + " t = np.floor(alpha * np.maximum(nu, 0))\n", + " problematic = (p < 1 + lam) | ((p * (1 - lam)) > (2 * lam))\n", + " if t.size == 1:\n", + " if problematic:\n", + " t = 0\n", + " else:\n", + " t[problematic] = 0\n", + " b = p * np.exp(2 - lam - np.minimum(lam, p)) * np.sqrt(2 / np.pi)\n", + " q_r = b / np.sqrt(t + 1)\n", + "\n", + " rho_t = ( # Taken from page 250\n", + " 1\n", + " - p\n", + " + np.log(p)\n", + " - 0.5 * np.log(2 * np.pi)\n", + " + (t - 1) * (np.log(lam * t + p) - np.log(t + 1))\n", + " - 1.5 * np.log(t + 1)\n", + " + (1 - lam) * t\n", + " )\n", + " rho_t_prime = (\n", + " np.log(lam * t + p)\n", + " - np.log(t + 1)\n", + " + 1\n", + " - lam\n", + " - (t + 0.5) / (t + 1) ** 2\n", + " - (t - 1) * lam / (lam * t + p)\n", + " )\n", + " q = np.where(t == 0, 0, np.exp(-rho_t_prime))\n", + " q_l = np.where(t == 0, np.exp(-p), np.exp(rho_t) / (1 - q))\n", + 
" return np.where(\n", + " x <= t,\n", + " q_l * q ** (t - x) * (1 - q ** (t + 1)),\n", + " b * (1 / np.sqrt(x) - 1 / np.sqrt(x + 1)),\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "scrolled": true + }, + "outputs": [], + "source": [ + "def abel_rejection_proposal_density(x, p, lam):\n", + " p = np.asarray(p)\n", + " lam = np.asarray(lam)\n", + " nu = 2 / 3 * (p ** 2 - lam * p - 3 * lam ** 2) / lam ** 2\n", + " alpha = 0.2746244084 # Taken from page 259\n", + " # alpha = 3/7\n", + " t = np.floor(alpha * np.maximum(nu, 0))\n", + " problematic = (p < 1 + lam) | ((p * (1 - lam)) > (2 * lam))\n", + " if t.size == 1:\n", + " if problematic:\n", + " t = 0\n", + " else:\n", + " t[problematic] = 0\n", + " b = p * np.exp(2 - lam - np.minimum(lam, p)) * np.sqrt(2 / np.pi)\n", + " q_r = b / np.sqrt(t + 1)\n", + "\n", + " rho_t = ( # Taken from page 250\n", + " 1\n", + " - p\n", + " + np.log(p)\n", + " - 0.5 * np.log(2 * np.pi)\n", + " + (t - 1) * (np.log(lam * t + p) - np.log(t + 1))\n", + " - 1.5 * np.log(t + 1)\n", + " + (1 - lam) * t\n", + " )\n", + " rho_t_prime = (\n", + " np.log(lam * t + p)\n", + " - np.log(t + 1)\n", + " + 1\n", + " - lam\n", + " - (t + 0.5) / (t + 1) ** 2\n", + " - (t - 1) * lam / (lam * t + p)\n", + " )\n", + " q = np.where(t == 0, 0, np.exp(-rho_t_prime))\n", + " q_l = np.where(t == 0, np.exp(-p), np.exp(rho_t) / (1 - q))\n", + " return np.where(\n", + " x <= t,\n", + " q ** (t - x) * (1 - q ** (t + 1)) / (1 - q),\n", + " np.sqrt(t + 1) * (1 / np.sqrt(x) - 1 / np.sqrt(x + 1)),\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_21224/4290048940.py:39: RuntimeWarning: divide by zero encountered in true_divide\n", + " b * (1 / np.sqrt(x) - 1 / np.sqrt(x + 1)),\n" + ] + }, + { + 
"data": { + "text/plain": [ + "" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "p, lam = 2.1209508879201904, 0.5510204081632653\n", + "mean = p / (1 - lam)\n", + "std = np.sqrt(p / (1 - lam) ** 3)\n", + "x = np.linspace(np.maximum(0, mean - 1.2 * std), mean + 15 * std, 1000)\n", + "plt.semilogy(x, np.exp(_logprob(x, p, lam)), label=\"PDF\")\n", + "plt.plot(x, abel_rejection_envelope(x, p, lam), label=\"Envelope\")\n", + "plt.plot(x, abel_rejection_proposal_density(x, p, lam), label=\"Proposal density\")\n", + "# samples = _rejection_region_abel(np.random.default_rng(42), p, lam, 100000)\n", + "# print(np.mean(samples <= 1), np.sum(np.exp(_logprob(np.array([0, 1]), p, lam))))\n", + "y, edges = np.histogram(samples, bins=20, density=True)\n", + "# plt.plot(0.5 * (edges[1:] + edges[:-1]), y, label=\"Sampled points\")\n", + "plt.legend()" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_21224/2555885245.py:12: RuntimeWarning: invalid value encountered in log\n", + " return np.where(x == 0, np.where(m == 0, 0.0, -np.inf), m * np.log(x))\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "p, lam = 2.1209508879201904, 0.5510204081632653\n", + "\n", + "mean = p / (1 - lam)\n", + "std = np.sqrt(p / (1 - lam) ** 3)\n", + "x = np.linspace(0, 10, 1000)\n", + "\n", + "plt.figure(figsize=(16, 6))\n", + "plt.semilogy(x, np.exp(_logprob(np.floor(x), p, lam)), label=\"PMF\")\n", + "\n", + "# plt.plot(x, abel_rejection_envelope(np.floor(x), p, lam), label=\"Envelope\")\n", + "# plt.plot(x, abel_rejection_proposal_density(np.floor(x), p, lam), label=\"Proposal density\")\n", + "samples_monot = _rejection_region_monotonicity(np.random.default_rng(), p, lam, 100000)\n", + "samples_branch = _branching_rng_fn(np.random.default_rng(), p, lam, 100000)\n", + "samples_abel = _rejection_region_abel(np.random.default_rng(), p, lam, 100000)\n", + "for samples, algo in zip((samples_monot, samples_branch, samples_abel), (\"monoticity\", \"branch\", \"abel\")):\n", + " u, c = np.unique(samples, return_counts=True)\n", + " edges = np.arange(11)\n", + " y = np.array([np.sum(c[u == e]) for e in edges])\n", + " plt.step(edges, y / samples.size, label=f\"Sampled points ({algo})\", where=\"post\")\n", + " plt.legend(loc=(1, 0))" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "0.11991754613423761" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.exp(_logprob(0, p, lam))" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_21224/4253790774.py:8: RuntimeWarning: divide by zero encountered in true_divide\n", + " poisson_idxs = p >= np.maximum(3, 2 * lam / (1 - lam))\n", + 
"/tmp/ipykernel_21224/4253790774.py:10: RuntimeWarning: divide by zero encountered in true_divide\n", + " abel_idxs = (lam == 1) | ((p >= 1 + lam) & (p <= 2 * lam / (1 - lam)))\n" + ] + } + ], + "source": [ + "p_range = np.logspace(-2, 4, 50)\n", + "lam_range = np.linspace(0, 1, 60)\n", + "p, lam = np.meshgrid(p_range, lam_range)\n", + "dist_size = p.shape\n", + "# monotonicity_idxs = p <= (1 + lam)\n", + "# monotonicity_idxs = p <= np.exp(lam)\n", + "# poisson_idxs = p > np.maximum(1 + lam, 2 * lam / (1 - lam))\n", + "poisson_idxs = p >= np.maximum(3, 2 * lam / (1 - lam))\n", + "# poisson_idxs = p > np.maximum(np.exp(lam), 2 * lam / (1 - lam))\n", + "abel_idxs = (lam == 1) | ((p >= 1 + lam) & (p <= 2 * lam / (1 - lam)))\n", + "# abel_idxs = (lam == 1) | ((p > np.exp(lam)) & (p <= 2 * lam / (1 - lam)))\n", + "# abel_idxs = (lam == 1) | ((p >= 3) & (p <= 2 * lam / (1 - lam)))" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(0.54,)" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "idxs = np.full((lam_range.size, 50), 0, dtype=int)\n", + "# idxs[monotonicity_idxs] = 0\n", + "idxs[poisson_idxs] = 1\n", + "idxs[abel_idxs] = 2\n", + "np.mean(idxs == 1)," + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_21224/2902846004.py:5: RuntimeWarning: divide by zero encountered in true_divide\n", + " poisson_idxs = p >= np.maximum(3, 2 * lam / (1 - lam))\n", + "/tmp/ipykernel_21224/2902846004.py:6: RuntimeWarning: divide by zero encountered in true_divide\n", + " abel_idxs = (lam == 1) | ((p >= 1 + lam) & (p <= 2 * lam / (1 - lam)))\n", + "/tmp/ipykernel_21224/1720600079.py:5: RuntimeWarning: 
invalid value encountered in multiply\n", + " - np.minimum(lam, p) * np.sqrt(2 / np.pi) * ((1 / np.sqrt(x)) - (1 / np.sqrt(x + 1)))\n", + "/tmp/ipykernel_21224/3405323858.py:41: RuntimeWarning: overflow encountered in exp\n", + " return p / np.sqrt(2 * np.pi) * np.exp(delta * (1 - lam) / (2 * (1 + delta)) * (x + 1 - mu))\n", + "/tmp/ipykernel_21224/3405323858.py:41: RuntimeWarning: overflow encountered in multiply\n", + " return p / np.sqrt(2 * np.pi) * np.exp(delta * (1 - lam) / (2 * (1 + delta)) * (x + 1 - mu))\n", + "/tmp/ipykernel_21224/4290048940.py:38: RuntimeWarning: overflow encountered in power\n", + " q_l * q ** (t - x) * (1 - q ** (t + 1)),\n", + "/tmp/ipykernel_21224/4290048940.py:39: RuntimeWarning: divide by zero encountered in true_divide\n", + " b * (1 / np.sqrt(x) - 1 / np.sqrt(x + 1)),\n", + "/tmp/ipykernel_21224/4290048940.py:38: RuntimeWarning: divide by zero encountered in power\n", + " q_l * q ** (t - x) * (1 - q ** (t + 1)),\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Poisson envelope is not higher than pmf for p=3.7, lam=0.475, xs=[4]\n", + "Poisson envelope is not higher than pmf for p=3.7, lam=0.492, xs=[3 4 5]\n", + "Poisson envelope is not higher than pmf for p=3.7, lam=0.508, xs=[3 4 5]\n", + "Poisson envelope is not higher than pmf for p=3.7, lam=0.525, xs=[3 4 5 6]\n", + "Poisson envelope is not higher than pmf for p=4.9, lam=0.525, xs=[6]\n", + "Poisson envelope is not higher than pmf for p=3.7, lam=0.542, xs=[3 4 5 6]\n", + "Poisson envelope is not higher than pmf for p=4.9, lam=0.542, xs=[6 7]\n", + "Poisson envelope is not higher than pmf for p=3.7, lam=0.559, xs=[3 4 5 6]\n", + "Poisson envelope is not higher than pmf for p=4.9, lam=0.559, xs=[5 6 7 8]\n", + "Poisson envelope is not higher than pmf for p=3.7, lam=0.576, xs=[3 4 5 6 7]\n", + "Poisson envelope is not higher than pmf for p=4.9, lam=0.576, xs=[5 6 7 8]\n", + "Poisson envelope is not higher than pmf for p=3.7, lam=0.593, xs=[3 4 5 6 
7]\n", + "Poisson envelope is not higher than pmf for p=4.9, lam=0.593, xs=[5 6 7 8]\n", + "Poisson envelope is not higher than pmf for p=3.7, lam=0.610, xs=[3 4 5 6 7]\n", + "Poisson envelope is not higher than pmf for p=4.9, lam=0.610, xs=[5 6 7 8 9]\n", + "Poisson envelope is not higher than pmf for p=3.7, lam=0.627, xs=[3 4 5 6 7 8]\n", + "Poisson envelope is not higher than pmf for p=4.9, lam=0.627, xs=[5 6 7 8 9]\n", + "Poisson envelope is not higher than pmf for p=3.7, lam=0.644, xs=[3 4 5 6 7 8]\n", + "Poisson envelope is not higher than pmf for p=4.9, lam=0.644, xs=[5 6 7 8 9]\n", + "Poisson envelope is not higher than pmf for p=4.9, lam=0.661, xs=[ 5 6 7 8 9 10]\n", + "Poisson envelope is not higher than pmf for p=4.9, lam=0.678, xs=[ 5 6 7 8 9 10]\n", + "Poisson envelope is not higher than pmf for p=4.9, lam=0.695, xs=[ 5 6 7 8 9 10]\n", + "Poisson envelope is not higher than pmf for p=4.9, lam=0.712, xs=[ 6 7 8 9 10]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_21224/4290048940.py:38: RuntimeWarning: overflow encountered in multiply\n", + " q_l * q ** (t - x) * (1 - q ** (t + 1)),\n", + "/tmp/ipykernel_21224/4290048940.py:38: RuntimeWarning: invalid value encountered in multiply\n", + " q_l * q ** (t - x) * (1 - q ** (t + 1)),\n" + ] + } + ], + "source": [ + "p_range = np.logspace(-2, 4, 50)\n", + "lam_range = np.linspace(0, 1, 60)\n", + "p, lam = np.meshgrid(p_range, lam_range)\n", + "\n", + "poisson_idxs = p >= np.maximum(3, 2 * lam / (1 - lam))\n", + "abel_idxs = (lam == 1) | ((p >= 1 + lam) & (p <= 2 * lam / (1 - lam)))\n", + "\n", + "idxs = np.full((lam_range.size, 50), 0, dtype=int)\n", + "# idxs[monotonicity_idxs] = 0\n", + "idxs[poisson_idxs] = 1\n", + "idxs[abel_idxs] = 2\n", + "\n", + "xs = np.concatenate([np.arange(10), np.logspace(1, 4.5)]).astype(int)\n", + "for i in range(p.shape[0]):\n", + " for j in range(p.shape[1]):\n", + " if idxs[i, j] == 0:\n", + " proposal = 
monotonicity_region_envelope(xs, p[i, j], lam[i, j])\n", + " pmf = np.exp(_logprob(xs, p[i, j], lam[i, j]))\n", + " bad_xs = xs[proposal < pmf + 0.00]\n", + " if len(bad_xs):\n", + " print(\n", + " f\"Monotonicity envelope is not higher than pmf for p={p[i, j]:.1f}, \"\n", + " f\"lam={lam[i, j]:.3f}, xs={bad_xs}\"\n", + " )\n", + "\n", + " elif idxs[i, j] == 1:\n", + " proposal = poisson_region_envelope(xs, p[i, j], lam[i, j])\n", + " pmf = np.exp(_logprob(xs, p[i, j], lam[i, j]))\n", + " bad_xs = xs[proposal < pmf - 0.03]\n", + " if len(bad_xs):\n", + " print(\n", + " f\"Poisson envelope is not higher than pmf for p={p[i, j]:.1f}, \"\n", + " f\"lam={lam[i, j]:.3f}, xs={bad_xs}\"\n", + " )\n", + "\n", + " elif idxs[i, j] == 2:\n", + " proposal = abel_rejection_envelope(xs, p[i, j], lam[i, j])\n", + " pmf = np.exp(_logprob(xs, p[i, j], lam[i, j]))\n", + " bad_xs = xs[proposal < pmf - 0.0 + 0.00]\n", + " if len(bad_xs):\n", + " print(\n", + " f\"Abel envelope is not higher than pmf for p={p[i, j]:.1f}, \"\n", + " f\"lam={lam[i, j]:.3f}, xs={bad_xs}\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "scrolled": true + }, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "data = idxs.T\n", + "\n", + "fig, ax = plt.subplots(figsize=(7, 7))\n", + "\n", + "# get discrete colormap\n", + "cmap = plt.get_cmap(\"coolwarm\", np.max(data) - np.min(data) + 1)\n", + "# set limits .5 outside true range\n", + "mat = ax.imshow(\n", + " data,\n", + " cmap=cmap,\n", + " vmin=0,\n", + " vmax=2,\n", + " origin=\"lower\",\n", + ")\n", + "# tell the colorbar to tick at integers\n", + "cbar = fig.colorbar(mat, ticks=[0, 1, 2], fraction=0.038)\n", + "cbar.ax.set_yticklabels([\"monotonicity\", \"poisson\", \"abel\"])\n", + "\n", + "ax.set_xlabel(\"lam\")\n", + "every = 8\n", + "plt.xticks(range(0, lam_range.size)[::every], np.round(lam_range[::every], 2))\n", + "\n", + "ax.set_ylabel(\"p\")\n", + "every = 8\n", + "plt.yticks(range(0, p_range.size)[::every], np.round(p_range[::every], 2));" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Benchmark algorithms" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "scrolled": true + }, + "outputs": [], + "source": [ + "import signal\n", + "import time\n", + "from functools import reduce" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class TimeOutError(RuntimeError):\n", + " pass\n", + "\n", + "\n", + "def handler(signum, frame):\n", + " raise TimeOutError\n", + "\n", + "\n", + "signal.signal(signal.SIGALRM, handler)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "scrolled": true + }, + "outputs": [], + "source": [ + "rng = np.random.default_rng(42)\n", + "dist_size = (100, 
*p.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "def benchmark_algorithm(algorithm_fn, draws=100, duration_cutoff=0.05):\n", + " duration = np.full_like(p, np.nan)\n", + " for i in range(p.shape[0]):\n", + " for j in range(p.shape[1]):\n", + " signal.setitimer(signal.ITIMER_REAL, duration_cutoff)\n", + " start = time.time()\n", + " try:\n", + " algorithm_fn(rng, p=p[i, j], lam=lam[i, j], dist_size=draws)\n", + " signal.alarm(0)\n", + " except TimeOutError:\n", + " continue\n", + " end = time.time()\n", + " duration[i, j] = end - start\n", + " signal.alarm(0)\n", + " return duration" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "def plot_benchmark(duration, title=\"\"):\n", + " fig, ax = plt.subplots(1, 2, figsize=(14, 7), sharey=True)\n", + "\n", + " # Region\n", + " data = idxs.T\n", + " cmap = plt.get_cmap(\"coolwarm\", np.max(data) - np.min(data) + 1)\n", + " mat = ax[0].imshow(\n", + " data,\n", + " cmap=cmap,\n", + " vmin=0,\n", + " vmax=2,\n", + " origin=\"lower\",\n", + " )\n", + " cbar = fig.colorbar(mat, ticks=[0, 1, 2], fraction=0.038, ax=ax[0])\n", + " cbar.ax.set_yticks([0.3, 1, 1.7])\n", + " cbar.ax.set_yticklabels([\"monoticity.\", \"poisson\", \"abel\"], rotation=90, va=\"center\")\n", + "\n", + " # Timings\n", + " data = duration.T\n", + " mat = ax[1].imshow(\n", + " data,\n", + " cmap=\"viridis\",\n", + " origin=\"lower\",\n", + " )\n", + " cbar = fig.colorbar(mat, fraction=0.038, ax=ax[1])\n", + "\n", + "\n", + " for axi in ax:\n", + " axi.set_xlabel(\"lam\")\n", + " every = 8\n", + " axi.set_xticks(range(0, lam_range.size)[::every])\n", + " axi.set_xticklabels(np.round(lam_range[::every], 2))\n", + "\n", + " axi.set_ylabel(\"p\")\n", + " every = 8\n", + " axi.set_yticks(range(0, p_range.size)[::every])\n", + " axi.set_yticklabels(np.round(p_range[::every], 2))\n", 
+ "\n", + " axi.axhline(20.5, color=\"white\")\n", + "\n", + " fig.suptitle(title, y=0.85, fontsize=18)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "duration_monot = benchmark_algorithm(_rejection_region_monotonicity)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_21224/2133065621.py:11: RuntimeWarning: invalid value encountered in sqrt\n", + " sigma = np.sqrt((1 + delta) * (p - lam) / (1 - lam - eps) / (1 - lam) ** 2)\n", + "/tmp/ipykernel_21224/2133065621.py:41: RuntimeWarning: invalid value encountered in power\n", + " * (p - lam) ** 1.5\n", + "/tmp/ipykernel_21224/2133065621.py:29: RuntimeWarning: invalid value encountered in power\n", + " / (np.sqrt(2 * np.pi) * (p - lam) ** 1.5)\n", + "/tmp/ipykernel_21224/2133065621.py:50: RuntimeWarning: overflow encountered in exp\n", + " return p / np.sqrt(2 * np.pi) * np.exp(delta * (1 - lam) / (2 * (1 + delta)) * (x + 1 - mu))\n", + "/tmp/ipykernel_21224/2133065621.py:10: RuntimeWarning: divide by zero encountered in true_divide\n", + " mu = (p - lam) / (1 - lam)\n", + "/tmp/ipykernel_21224/2133065621.py:11: RuntimeWarning: divide by zero encountered in true_divide\n", + " sigma = np.sqrt((1 + delta) * (p - lam) / (1 - lam - eps) / (1 - lam) ** 2)\n", + "/tmp/ipykernel_21224/2133065621.py:18: RuntimeWarning: invalid value encountered in true_divide\n", + " (p * (1 - lam - eps) * np.sqrt(1 + delta))\n", + "/tmp/ipykernel_21224/2133065621.py:36: RuntimeWarning: divide by zero encountered in true_divide\n", + " t_r = np.ceil((p - lam) / (1 - lam - eps) - 1)\n", + "/tmp/ipykernel_21224/2133065621.py:46: RuntimeWarning: invalid value encountered in subtract\n", + " * np.exp(-(1 - 2 * (1 - lam - eps) / (p - lam)) * (eps / 2) * (1 - lam) * (t_r - mu))\n", + 
"/tmp/ipykernel_21224/2133065621.py:52: RuntimeWarning: divide by zero encountered in true_divide\n", + " t_l = np.ceil((p - lam) / (1 - lam + delta) - 1)\n", + "/tmp/ipykernel_21224/2133065621.py:54: RuntimeWarning: divide by zero encountered in true_divide\n", + " (2 * p * (1 + delta))\n", + "/tmp/ipykernel_21224/2133065621.py:56: RuntimeWarning: invalid value encountered in subtract\n", + " * np.exp(delta * (1 - lam) / (2 * (1 + delta)) * (t_l + 1 - mu))\n", + "/tmp/ipykernel_21224/2133065621.py:85: RuntimeWarning: divide by zero encountered in true_divide\n", + " raw_left_y = _t_l - 2 * E * (1 + _delta) / _delta / (1 - _lam)\n", + "/tmp/ipykernel_21224/2133065621.py:86: RuntimeWarning: divide by zero encountered in true_divide\n", + " raw_right_y = _t_r + 2 * E / ((1 - 2 * (1 - _lam - _eps) / (_p - _lam)) * _eps * (1 - _lam))\n", + "/tmp/ipykernel_21224/2133065621.py:86: RuntimeWarning: invalid value encountered in add\n", + " raw_right_y = _t_r + 2 * E / ((1 - 2 * (1 - _lam - _eps) / (_p - _lam)) * _eps * (1 - _lam))\n", + "/tmp/ipykernel_21224/2133065621.py:38: RuntimeWarning: invalid value encountered in true_divide\n", + " (2 * p * (1 - lam - eps) ** 1.5 * np.exp(2 * (1 - lam)))\n", + "/tmp/ipykernel_21224/2133065621.py:84: RuntimeWarning: invalid value encountered in add\n", + " raw_center_y = _mu + _sigma * N\n", + "/tmp/ipykernel_21224/2133065621.py:85: RuntimeWarning: invalid value encountered in subtract\n", + " raw_left_y = _t_l - 2 * E * (1 + _delta) / _delta / (1 - _lam)\n", + "/tmp/ipykernel_21224/2133065621.py:24: RuntimeWarning: invalid value encountered in subtract\n", + " return G / (np.sqrt(2 * np.pi) * sigma) * np.exp(-((x - mu) ** 2) / (2 * sigma ** 2))\n", + "/tmp/ipykernel_21224/2133065621.py:50: RuntimeWarning: invalid value encountered in subtract\n", + " return p / np.sqrt(2 * np.pi) * np.exp(delta * (1 - lam) / (2 * (1 + delta)) * (x + 1 - mu))\n", + "/tmp/ipykernel_21224/2133065621.py:31: RuntimeWarning: invalid value encountered in 
subtract\n", + " -(1 - 2 * (1 - lam - eps) / (p - lam)) * (eps / 2) * (1 - lam) * (x - mu)\n", + "/tmp/ipykernel_21224/2555885245.py:19: RuntimeWarning: invalid value encountered in subtract\n", + " np.log(p) + _logpow(p_lam_x, x - 1) - p_lam_x - gammaln(x + 1),\n" + ] + } + ], + "source": [ + "duration_poisson = benchmark_algorithm(_rejection_region_poisson)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_21224/2250734458.py:8: RuntimeWarning: divide by zero encountered in true_divide\n", + " nu = 2 * (p ** 2 - lam * p - 3 * lam ** 2) / (3 * lam ** 2)\n", + "/tmp/ipykernel_21224/2555885245.py:12: RuntimeWarning: invalid value encountered in log\n", + " return np.where(x == 0, np.where(m == 0, 0.0, -np.inf), m * np.log(x))\n", + "/tmp/ipykernel_21224/2250734458.py:80: RuntimeWarning: overflow encountered in power\n", + " V * _q_l * _q ** (_t - raw_left) * (1 - _q ** (_t + 1))\n" + ] + } + ], + "source": [ + "duration_abel = benchmark_algorithm(_rejection_region_abel)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_21224/1066448282.py:4: RuntimeWarning: divide by zero encountered in log\n", + " abs_log_lam = np.log(np.abs(lam))\n" + ] + } + ], + "source": [ + "duration_inverse = benchmark_algorithm(\n", + " lambda rng, p, lam, dist_size: _inverse_rng_fn(rng, theta=p, lam=lam, dist_size=dist_size)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "duration_branch = benchmark_algorithm(\n", + " lambda rng, p, lam, dist_size: _branching_rng_fn(rng, theta=p, lam=lam, dist_size=dist_size)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": { + 
"scrolled": false + }, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plot_benchmark(duration=duration_monot, title=\"Monotonicity algorithm speed\")\n", + "plot_benchmark(duration=duration_poisson, title=\"Poisson algorithm speed\")\n", + "plot_benchmark(duration=duration_abel, title=\"Abel algorithm speed\")\n", + "plot_benchmark(duration=duration_inverse, title=\"Inverse algorithm speed\")\n", + "plot_benchmark(duration=duration_branch, title=\"Branching algorithm speed\")" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "scrolled": true + }, + "outputs": [], + "source": [ + "def plot_benchmark_comparison(benchmarks, benchmark_names, threshold=0):\n", + " \n", + " nan_to_inf = lambda x: np.nan_to_num(x, nan=np.inf)\n", + " best_duration = np.full_like(duration_monot, np.nan)\n", + "\n", + " for i, reference in enumerate(benchmarks):\n", + " comps = (reference + threshold < nan_to_inf(other) for other in benchmarks if other is not reference)\n", + " best_duration[reduce(lambda x, y: x & y, comps)] = i\n", + " \n", + " \n", + " fig, ax = plt.subplots(1, 2, figsize=(14, 7), sharey=True)\n", + "\n", + " # Region\n", + " data = idxs.T\n", + " cmap = plt.get_cmap(\"coolwarm\", np.max(data) - np.min(data) + 1)\n", + " mat = ax[0].imshow(\n", + " data,\n", + " cmap=cmap,\n", + " vmin=0,\n", + " vmax=2,\n", + " origin=\"lower\",\n", + " )\n", + " cbar = fig.colorbar(mat, ticks=[0, 1, 2], fraction=0.038, ax=ax[0])\n", + " cbar.ax.set_yticks([0.3, 1, 1.7])\n", + " cbar.ax.set_yticklabels([\"monoticity.\", \"poisson\", \"abel\"], rotation=90, va=\"center\")\n", + "\n", + " # Timings\n", + " data = best_duration.T\n", + " cmap = plt.get_cmap(\"coolwarm\", np.nanmax(data) - np.nanmin(data) + 1)\n", + " mat = ax[1].imshow(\n", + " data,\n", + " cmap=cmap,\n", + " origin=\"lower\",\n", + " )\n", + " cbar = fig.colorbar(mat, fraction=0.038, ax=ax[1])\n", + " 
cbar.ax.set_yticks([0.3, 1, 1.7, 2.8, 3.5][:len(benchmarks)])\n", + " cbar.ax.set_yticklabels(benchmark_names, rotation=90, va=\"center\")\n", + "\n", + "\n", + " for axi in ax:\n", + " axi.set_xlabel(\"lam\")\n", + " every = 8\n", + " axi.set_xticks(range(0, lam_range.size)[::every])\n", + " axi.set_xticklabels(np.round(lam_range[::every], 2))\n", + "\n", + " axi.set_ylabel(\"p\")\n", + " every = 8\n", + " axi.set_yticks(range(0, p_range.size)[::every])\n", + " axi.set_yticklabels(np.round(p_range[::every], 2))\n", + "\n", + " axi.axhline(20.5, color=\"k\")\n", + "\n", + " fig.suptitle(\"Best performance per region\", y=0.85, fontsize=18)" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plot_benchmark_comparison(\n", + " [duration_monot, duration_poisson, duration_abel],\n", + " [\"monot\", \"poisson\", \"abel\"],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plot_benchmark_comparison(\n", + " [duration_monot, duration_poisson, duration_abel, duration_inverse, duration_branch],\n", + " [\"monot\", \"poisson\", \"abel\", \"inverse\", \"branch\"],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plot_benchmark_comparison(\n", + " [duration_monot, duration_poisson, duration_abel, duration_inverse, duration_branch],\n", + " [\"monot\", \"poisson\", \"abel\", \"inverse\", \"branch\"],\n", + " threshold=0.001\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plot_benchmark_comparison(\n", + " [duration_monot, duration_branch],\n", + " [\"monot\", \"branch\"],\n", + " threshold=0.0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plot_benchmark_comparison(\n", + " [duration_poisson, duration_branch],\n", + " [\"poisson\", \"branch\"],\n", + " threshold=0.0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plot_benchmark_comparison(\n", + " [duration_abel, duration_branch],\n", + " [\"abel\", \"branch\"],\n", + " threshold=0.000\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "hide_input": false, + "jupytext": { + "formats": "ipynb,py:percent" + }, + "kernelspec": { + "display_name": "roche_preclinical_hd_gpu", + "language": "python", + "name": "roche_preclinical_hd_gpu" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + }, + "toc-showtags": false + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/fast_gen_pois.py b/notebooks/fast_gen_pois.py new file mode 100644 index 0000000..61db454 --- /dev/null +++ b/notebooks/fast_gen_pois.py @@ -0,0 +1,918 @@ +# --- +# jupyter: +# jupytext: +# formats: ipynb,py:percent +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.4.2 +# kernelspec: +# display_name: roche_preclinical_hd_gpu +# language: python +# name: roche_preclinical_hd_gpu +# --- + +# %% pycharm={"name": "#%%\n"} +import numpy as np +import pymc as pm + +from matplotlib import pyplot as plt +from scipy.special import gammaln + + +def _logpow(x, m): + """ + Calculates log(x**m) since m*log(x) will fail when m, x = 0. 
+ """ + # return m * log(x) + return np.where(x == 0, np.where(m == 0, 0.0, -np.inf), m * np.log(x)) + + +def _logprob(x, p, lam): + p_lam_x = p + lam * x + return np.where( + x >= 0, + np.log(p) + _logpow(p_lam_x, x - 1) - p_lam_x - gammaln(x + 1), + -np.inf, + ) + + +# %% pycharm={"name": "#%%\n"} +def _rejection_region_monotonicity(rng, p, lam, dist_size, idxs_mask=None): + if idxs_mask is None: + idxs_mask = np.ones(dist_size, dtype="bool") + p = np.broadcast_to(p, dist_size)[idxs_mask] + lam = np.broadcast_to(lam, dist_size)[idxs_mask] + dist_size = np.sum(idxs_mask) + p0 = np.exp(-p) + b = p * np.exp(2 - lam - np.minimum(lam, p)) * np.sqrt(2 / np.pi) + x = np.full(dist_size, np.nan) + inds_to_sample = np.ones(dist_size, dtype=bool) # dummy boolean mask#u > p0 / (p0 + b) + counter = -1 + while np.any(inds_to_sample): + counter += 1 + u = rng.uniform(size=dist_size) + zero_xs = u <= p0 / (p0 + b) + x[inds_to_sample & zero_xs] = 0 + inds_to_sample = inds_to_sample & ~zero_xs + + v = rng.uniform(size=dist_size) + w = rng.uniform(size=dist_size) + _x = np.floor(1 / w ** 2) + accepted = v * b * (1 / np.sqrt(_x) - 1 / np.sqrt(_x + 1)) <= np.exp(_logprob(_x, p, lam)) + x[inds_to_sample & accepted] = _x[inds_to_sample & accepted] + inds_to_sample = inds_to_sample & ~accepted + # print(counter) + return x + + +# %% pycharm={"name": "#%%\n"} +def _rejection_region_poisson(rng, p, lam, dist_size, idxs_mask=None): + if idxs_mask is None: + idxs_mask = np.ones(dist_size, dtype="bool") + p = np.broadcast_to(p, dist_size)[idxs_mask] + lam = np.broadcast_to(lam, dist_size)[idxs_mask] + dist_size = np.sum(idxs_mask) + + eps = (1 - lam) / (2 + (p - lam) * (1 - lam)) ** (1 / 3) + delta = (1 - lam) ** (2 / 5) / (2 + (p - lam) * (1 - lam)) ** (1 / 3) + mu = (p - lam) / (1 - lam) + sigma = np.sqrt((1 + delta) * (p - lam) / (1 - lam - eps) / (1 - lam) ** 2) + psi = ( + p * delta * (2 + delta - 2 * lam) + + (1 + delta) * (1 - lam) ** 2 + - lam * (1 - lam + delta) ** 2 + ) / (2 * (p - 
1 - delta)) + G = ( + (p * (1 - lam - eps) * np.sqrt(1 + delta)) + / ((p - lam) * (1 - lam) * (1 - eps) ** 2) + * np.exp(psi / (1 + delta)) + ) + + def g(x, G, mu, sigma): + return G / (np.sqrt(2 * np.pi) * sigma) * np.exp(-((x - mu) ** 2) / (2 * sigma ** 2)) + + def h_r(x, p, lam, eps, mu): + return ( + (p * (1 - lam - eps) ** 1.5) + / (np.sqrt(2 * np.pi) * (p - lam) ** 1.5) + * np.exp( + -(1 - 2 * (1 - lam - eps) / (p - lam)) * (eps / 2) * (1 - lam) * (x - mu) + + 2 * (1 - lam) + ) + ) + + t_r = np.ceil((p - lam) / (1 - lam - eps) - 1) + H_r = ( + (2 * p * (1 - lam - eps) ** 1.5 * np.exp(2 * (1 - lam))) + / ( + np.sqrt(2 * np.pi) + * (p - lam) ** 1.5 + * (1 - 2 * (1 - lam - eps) / (p - lam)) + * eps + * (1 - lam) + ) + * np.exp(-(1 - 2 * (1 - lam - eps) / (p - lam)) * (eps / 2) * (1 - lam) * (t_r - mu)) + ) + + def h_l(x, p, lam, delta, mu): + return p / np.sqrt(2 * np.pi) * np.exp(delta * (1 - lam) / (2 * (1 + delta)) * (x + 1 - mu)) + + t_l = np.ceil((p - lam) / (1 - lam + delta) - 1) + H_l = ( + (2 * p * (1 + delta)) + / (np.sqrt(2 * np.pi) * delta * (1 - lam)) + * np.exp(delta * (1 - lam) / (2 * (1 + delta)) * (t_l + 1 - mu)) + ) + + x = np.zeros(dist_size) + inds_to_sample = np.arange(dist_size) + n_to_accept = np.zeros(dist_size) + counter = -1 + while np.any(inds_to_sample): + counter += 1 + _dist_size = len(inds_to_sample) + U = rng.uniform(size=_dist_size) + N = rng.normal(size=_dist_size) + V = rng.uniform(size=_dist_size) + E = rng.exponential(size=_dist_size) + _G = G[inds_to_sample] + _H_l = H_l[inds_to_sample] + _H_r = H_r[inds_to_sample] + _p = p[inds_to_sample] + _lam = lam[inds_to_sample] + _mu = mu[inds_to_sample] + _sigma = sigma[inds_to_sample] + _delta = delta[inds_to_sample] + _eps = eps[inds_to_sample] + _t_l = t_l[inds_to_sample] + _t_r = t_r[inds_to_sample] + + center = U < _G / (_G + _H_l + _H_r) + left = (U < (_G + _H_l) / (_G + _H_l + _H_r)) & ~center + raw_center_y = _mu + _sigma * N + raw_left_y = _t_l - 2 * E * (1 + _delta) / _delta 
/ (1 - _lam) + raw_right_y = _t_r + 2 * E / ((1 - 2 * (1 - _lam - _eps) / (_p - _lam)) * _eps * (1 - _lam)) + Y = np.where( + center, + np.where( + (raw_center_y >= _t_l) & (raw_center_y < _t_r), + raw_center_y, + np.nan, + ), + np.where( + left, + np.where( + raw_left_y >= 0, + raw_left_y, + np.nan, + ), + np.where( + raw_right_y >= 0, + raw_right_y, + np.nan, + ), + ), + ) + X = np.floor(Y) + accepted = ( + V + * np.where( + center, + g(Y, G=_G, mu=_mu, sigma=_sigma), + np.where( + left, + h_l(Y, p=_p, lam=_lam, delta=_delta, mu=_mu), + h_r(Y, p=_p, lam=_lam, eps=_eps, mu=_mu), + ), + ) + <= np.exp(_logprob(X, _p, _lam)) + ) + + x[inds_to_sample[accepted]] = X[accepted] + n_to_accept[inds_to_sample[accepted]] = counter + inds_to_sample = inds_to_sample[~accepted] + return x + + +# %% pycharm={"name": "#%%\n"} +def _rejection_region_abel(rng, p, lam, dist_size, idxs_mask=None): + if idxs_mask is None: + idxs_mask = np.ones(dist_size, dtype="bool") + p = np.broadcast_to(p, dist_size)[idxs_mask] + lam = np.broadcast_to(lam, dist_size)[idxs_mask] + dist_size = np.sum(idxs_mask) + + nu = 2 * (p ** 2 - lam * p - 3 * lam ** 2) / (3 * lam ** 2) + alpha = 0.2746244084 # Taken from page 259 + t = np.floor(alpha * np.maximum(nu, 0)) + problematic = (p < 1 + lam) | ((p * (1 - lam)) > (2 * lam)) + t[problematic] = 0 + # b = p * np.exp(np.maximum(1 - p, 0)) * np.sqrt(2 / np.pi) + b = p * np.exp(2 - lam - np.minimum(lam, p)) * np.sqrt(2 / np.pi) + q_r = b / np.sqrt(t + 1) + + rho_t = ( # Taken from page 250 + 1 + - p + + np.log(p) + - 0.5 * np.log(2 * np.pi) + + (t - 1) * (np.log(lam * t + p) - np.log(t + 1)) + - 1.5 * np.log(t + 1) + + (1 - lam) * t + ) + # rho_t_prime = ( # Taken form page 271 + # np.log(lam * t + p) + # - np.log(t + 1) + # + 1 + # - lam + # + 0.5 / (t + 1) + # - (lam + p) / (lam * t + p) + # ) + rho_t_prime = ( + np.log(lam * t + p) + - np.log(t + 1) + + 1 + - lam + - (t + 0.5) / (t + 1) ** 2 + - (t - 1) * lam / (lam * t + p) + ) + q = np.exp(-rho_t_prime) + 
q_l = np.where(t == 0, np.exp(-p), np.exp(rho_t) / (1 - q)) + + x = np.zeros(dist_size) + n_to_accept = np.zeros(dist_size) + inds_to_sample = np.arange(dist_size) + counter = -1 + while np.any(inds_to_sample): + counter += 1 + _dist_size = len(inds_to_sample) + U = rng.uniform(size=_dist_size) + V = rng.uniform(size=_dist_size) + W = rng.uniform(size=_dist_size) + # E = rng.uniform(size=_dist_size) + E = rng.exponential(size=_dist_size) + _p = p[inds_to_sample] + _lam = lam[inds_to_sample] + _t = t[inds_to_sample] + _q = q[inds_to_sample] + _q_l = q_l[inds_to_sample] + _q_r = q_r[inds_to_sample] + _b = b[inds_to_sample] + # raw_left = np.where(_t == 0, 0, _t - np.floor(-E / np.log(1 - _q))) + # raw_left = np.where(_t == 0, 0, _t - np.floor(-E / np.log(_q))) + # raw_left = np.where(_t == 0, 0, _t + np.ceil(np.log(1 - E) / _q)) + raw_left = np.where(_t == 0, 0, _t - np.floor(E / _q)) + raw_right = np.floor((_t + 1) / W ** 2) + + left = U <= _q_l / (_q_l + _q_r) + accepted = np.where( + left, + np.where( + _t == 0, + True, + np.where( + raw_left < 0, + False, + # V * _q_l * _q ** (_t - raw_left) * (1 - _q) + V * _q_l * _q ** (_t - raw_left) * (1 - _q ** (_t + 1)) + <= np.exp(_logprob(raw_left, _p, _lam)), + ), + ), + V * _b * (1 / np.sqrt(raw_right) - 1 / np.sqrt(raw_right + 1)) + <= np.exp(_logprob(raw_right, _p, _lam)), + ) + X = np.where(left, raw_left, raw_right) + + x[inds_to_sample[accepted]] = X[accepted] + n_to_accept[inds_to_sample[accepted]] = counter + inds_to_sample = inds_to_sample[~accepted] + return x + + +# %% pycharm={"name": "#%%\n"} +def _inverse_rng_fn(rng, theta, lam, dist_size): + log_u = np.log(rng.uniform(size=dist_size)) + pos_lam = lam > 0 + abs_log_lam = np.log(np.abs(lam)) + theta_m_lam = theta - lam + log_s = -theta + log_p = log_s.copy() + x_ = 0 + x = np.zeros(dist_size) + below_cutpoint = log_s < log_u + with np.errstate(divide="ignore", invalid="ignore"): + counter = 0 + while np.any(below_cutpoint): + counter += 1 + x_ += 1 + 
x[below_cutpoint] += 1 + log_c = np.log(theta_m_lam + lam * x_) + # Compute log(1 + lam / C) + log1p_lam_m_C = np.where( + pos_lam, + np.log1p(np.exp(abs_log_lam - log_c)), + pm.math.log1mexp_numpy(abs_log_lam - log_c, negative_input=True), + ) + log_p = log_c + log1p_lam_m_C * (x_ - 1) + log_p - np.log(x_) - lam + log_s = np.logaddexp(log_s, log_p) + below_cutpoint = log_s < log_u + # print(counter) + return x + + +# %% +def _branching_rng_fn(rng, theta, lam, dist_size, idxs_mask=None): + if idxs_mask is None: + idxs_mask = np.ones(dist_size, dtype=bool) + lam_ = np.abs(lam) # This algorithm is only valid for positive lam + y = rng.poisson(theta, size=dist_size) + x = y.copy() + higher_than_zero = y > 0 + while np.any(higher_than_zero[idxs_mask]): + y = rng.poisson(lam_ * y) + x[higher_than_zero] = x[higher_than_zero] + y[higher_than_zero] + higher_than_zero = y > 0 + return x + + +# %% pycharm={"name": "#%%\n"} +rng = np.random.default_rng(42) +p, lam = np.meshgrid(np.logspace(-2, 4, 50), np.linspace(0, 1, 50)) +dist_size = (100, *p.shape) +monotonicity_idxs = np.broadcast_to(p <= np.exp(lam), dist_size) +poisson_idxs = np.broadcast_to(p >= np.maximum(3, 2 * lam / (1 - lam)), dist_size) +abel_idxs = np.broadcast_to( + (lam == 1) | ((p >= 1 + lam) & (p <= 2 * lam / (1 - lam))), + dist_size, +) + +# %% pycharm={"name": "#%%\n"} +# %%time +samples = np.full(dist_size, np.nan) +samples[monotonicity_idxs] = _rejection_region_monotonicity( + rng=rng, p=p, lam=lam, dist_size=dist_size, idxs_mask=monotonicity_idxs +) +samples[poisson_idxs] = _rejection_region_poisson( + rng=rng, + p=p, + lam=lam, + dist_size=dist_size, + idxs_mask=poisson_idxs, +) +samples[abel_idxs] = _rejection_region_abel( + rng=rng, + p=p, + lam=lam, + dist_size=dist_size, + idxs_mask=abel_idxs, +) + +# %% pycharm={"name": "#%%\n"} +c = np.zeros_like(p.flatten()) +c[monotonicity_idxs[0].flatten()] = 0 +c[poisson_idxs[0].flatten()] = 1 +c[abel_idxs[0].flatten()] = 2 + +# %% pycharm={"name": "#%%\n"} 
def monotonicity_region_envelope(x, p, lam):
    """Evaluate the dominating envelope used by the monotonicity sampler.

    The envelope is ``exp(-p)`` at ``x == 0`` and
    ``b * (1 / sqrt(x) - 1 / sqrt(x + 1))`` elsewhere, with
    ``b = p * exp(2 - lam - min(lam, p)) * sqrt(2 / pi)``. This is the same
    quantity the sampler multiplies by a uniform draw in its acceptance test.

    Parameters
    ----------
    x : float or array_like
        Points at which to evaluate the envelope.
    p, lam : float or array_like
        Distribution parameters, broadcastable against ``x``.

    Returns
    -------
    numpy.ndarray
        Envelope values with the broadcast shape of the inputs (0-d for
        scalar inputs).
    """
    x = np.asarray(x, dtype=float)
    # The scale factor multiplies the tail term; it must not sit inside the
    # exponent, or the envelope would stop decaying in x.
    b = p * np.exp(2 - lam - np.minimum(lam, p)) * np.sqrt(2 / np.pi)
    at_zero = x == 0
    # Substitute a harmless value at x == 0 to avoid dividing by zero; that
    # entry is replaced by the point mass below.
    safe_x = np.where(at_zero, 1.0, x)
    tail = b * (1 / np.sqrt(safe_x) - 1 / np.sqrt(safe_x + 1))
    return np.where(at_zero, np.exp(-p), tail)  # Extra probability for x==0
lam) + + def h_r(x): + return h_r_norm * np.exp(h_r_exp_A * (x - mu) + h_r_exp_B) + + t_r = np.ceil((p - lam) / (1 - lam - eps) - 1) + H_r = ( + (2 * p * (1 - lam - eps) ** 1.5 * np.exp(2 * (1 - lam))) + / ( + np.sqrt(2 * np.pi) + * (p - lam) ** 1.5 + * (1 - 2 * (1 - lam - eps) / (p - lam)) + * eps + * (1 - lam) + ) + * np.exp(-(1 - 2 * (1 - lam - eps) / (p - lam)) * (eps / 2) * (1 - lam) * (t_r - mu)) + ) + + def h_l(x): + return p / np.sqrt(2 * np.pi) * np.exp(delta * (1 - lam) / (2 * (1 + delta)) * (x + 1 - mu)) + + t_l = np.ceil((p - lam) / (1 - lam + delta) - 1) + H_l = ( + (2 * p * (1 + delta)) + / (np.sqrt(2 * np.pi) * delta * (1 - lam)) + * np.exp(delta * (1 - lam) / (2 * (1 + delta)) * (t_l + 1 - mu)) + ) + return np.where(x < t_l, h_l(x), np.where(x < t_r, g(x), h_r(x))) + + +# %% pycharm={"name": "#%%\n"} +x = np.linspace(1000, 1700, 1000) +lam = 0.3 +p = 1000 +y0 = poisson_region_envelope(x, p, lam) +y1 = np.exp(_logprob(np.floor(x), p, lam)) +y2 = normal_approx(x, p, lam) +plt.semilogy(x, y0, label="Envelope") +plt.semilogy(x, y1, label="PDF") +plt.plot(x, y2, label="Normal approx") +samples = _rejection_region_poisson(np.random.default_rng(42), p, lam, 100000) +y, edges = np.histogram(samples, bins=20, density=True) +plt.plot(0.5 * (edges[1:] + edges[:-1]), y, label="Sampled points") +plt.legend() + + +# %% pycharm={"name": "#%%\n"} +def abel_rejection_envelope(x, p, lam): + p = np.asarray(p) + lam = np.asarray(lam) + nu = 2 / 3 * (p ** 2 - lam * p - 3 * lam ** 2) / lam ** 2 + alpha = 0.2746244084 # Taken from page 259 + # alpha = 3/7 + t = np.floor(alpha * np.maximum(nu, 0)) + problematic = (p < 1 + lam) | ((p * (1 - lam)) > (2 * lam)) + if t.size == 1: + if problematic: + t = 0 + else: + t[problematic] = 0 + b = p * np.exp(2 - lam - np.minimum(lam, p)) * np.sqrt(2 / np.pi) + q_r = b / np.sqrt(t + 1) + + rho_t = ( # Taken from page 250 + 1 + - p + + np.log(p) + - 0.5 * np.log(2 * np.pi) + + (t - 1) * (np.log(lam * t + p) - np.log(t + 1)) + - 1.5 * 
def abel_rejection_proposal_density(x, p, lam):
    """Evaluate the per-side proposal density of the Abel rejection scheme.

    For ``x <= t`` the value is ``q**(t - x) * (1 - q**(t + 1)) / (1 - q)``.
    For ``x > t`` it is ``sqrt(t + 1) * (1 / sqrt(x) - 1 / sqrt(x + 1))``.
    Here ``t`` is the split point derived from ``p`` and ``lam``, and ``q``
    is the left-side ratio (0 when ``t == 0``).

    Parameters
    ----------
    x : float or array_like
        Points at which to evaluate the density.
    p, lam : float or array_like
        Distribution parameters.

    Returns
    -------
    numpy.ndarray
        Density values with the broadcast shape of the inputs.
    """
    p = np.asarray(p)
    lam = np.asarray(lam)
    # Both sides are evaluated everywhere and then selected, so the side that
    # is discarded may divide by zero; those warnings carry no information.
    with np.errstate(divide="ignore", invalid="ignore"):
        nu = 2 / 3 * (p ** 2 - lam * p - 3 * lam ** 2) / lam ** 2
        alpha = 0.2746244084  # Taken from page 259
        t = np.floor(alpha * np.maximum(nu, 0))
        # Outside this parameter region the split point is forced to 0.
        # np.where handles scalar and array parameters alike.
        problematic = (p < 1 + lam) | ((p * (1 - lam)) > (2 * lam))
        t = np.where(problematic, 0.0, t)

        rho_t_prime = (
            np.log(lam * t + p)
            - np.log(t + 1)
            + 1
            - lam
            - (t + 0.5) / (t + 1) ** 2
            - (t - 1) * lam / (lam * t + p)
        )
        q = np.where(t == 0, 0, np.exp(-rho_t_prime))

        # NOTE(review): the right side sums to 1 over x > t, but the left side
        # as written sums to ((1 - q**(t + 1)) / (1 - q)) ** 2 over 0..t.
        # A normalised truncated geometric would use
        # (1 - q) / (1 - q**(t + 1)). Kept as in the original; confirm.
        left_side = q ** (t - x) * (1 - q ** (t + 1)) / (1 - q)
        right_side = np.sqrt(t + 1) * (1 / np.sqrt(x) - 1 / np.sqrt(x + 1))
        return np.where(x <= t, left_side, right_side)
100000) +# print(np.mean(samples <= 1), np.sum(np.exp(_logprob(np.array([0, 1]), p, lam)))) +y, edges = np.histogram(samples, bins=20, density=True) +# plt.plot(0.5 * (edges[1:] + edges[:-1]), y, label="Sampled points") +plt.legend() + +# %% pycharm={"name": "#%%\n"} +p, lam = 2.1209508879201904, 0.5510204081632653 + +mean = p / (1 - lam) +std = np.sqrt(p / (1 - lam) ** 3) +x = np.linspace(0, 10, 1000) + +plt.figure(figsize=(16, 6)) +plt.semilogy(x, np.exp(_logprob(np.floor(x), p, lam)), label="PMF") + +# plt.plot(x, abel_rejection_envelope(np.floor(x), p, lam), label="Envelope") +# plt.plot(x, abel_rejection_proposal_density(np.floor(x), p, lam), label="Proposal density") +samples_monot = _rejection_region_monotonicity(np.random.default_rng(), p, lam, 100000) +samples_branch = _branching_rng_fn(np.random.default_rng(), p, lam, 100000) +samples_abel = _rejection_region_abel(np.random.default_rng(), p, lam, 100000) +for samples, algo in zip( + (samples_monot, samples_branch, samples_abel), ("monoticity", "branch", "abel") +): + u, c = np.unique(samples, return_counts=True) + edges = np.arange(11) + y = np.array([np.sum(c[u == e]) for e in edges]) + plt.step(edges, y / samples.size, label=f"Sampled points ({algo})", where="post") + plt.legend(loc=(1, 0)) + +# %% pycharm={"name": "#%%\n"} +np.exp(_logprob(0, p, lam)) + +# %% pycharm={"name": "#%%\n"} +p_range = np.logspace(-2, 4, 50) +lam_range = np.linspace(0, 1, 60) +p, lam = np.meshgrid(p_range, lam_range) +dist_size = p.shape +# monotonicity_idxs = p <= (1 + lam) +# monotonicity_idxs = p <= np.exp(lam) +# poisson_idxs = p > np.maximum(1 + lam, 2 * lam / (1 - lam)) +poisson_idxs = p >= np.maximum(3, 2 * lam / (1 - lam)) +# poisson_idxs = p > np.maximum(np.exp(lam), 2 * lam / (1 - lam)) +abel_idxs = (lam == 1) | ((p >= 1 + lam) & (p <= 2 * lam / (1 - lam))) +# abel_idxs = (lam == 1) | ((p > np.exp(lam)) & (p <= 2 * lam / (1 - lam))) +# abel_idxs = (lam == 1) | ((p >= 3) & (p <= 2 * lam / (1 - lam))) + +# %% 
pycharm={"name": "#%%\n"} +idxs = np.full((lam_range.size, 50), 0, dtype=int) +# idxs[monotonicity_idxs] = 0 +idxs[poisson_idxs] = 1 +idxs[abel_idxs] = 2 +np.mean(idxs == 1), + +# %% pycharm={"name": "#%%\n"} +p_range = np.logspace(-2, 4, 50) +lam_range = np.linspace(0, 1, 60) +p, lam = np.meshgrid(p_range, lam_range) + +poisson_idxs = p >= np.maximum(3, 2 * lam / (1 - lam)) +abel_idxs = (lam == 1) | ((p >= 1 + lam) & (p <= 2 * lam / (1 - lam))) + +idxs = np.full((lam_range.size, 50), 0, dtype=int) +# idxs[monotonicity_idxs] = 0 +idxs[poisson_idxs] = 1 +idxs[abel_idxs] = 2 + +xs = np.concatenate([np.arange(10), np.logspace(1, 4.5)]).astype(int) +for i in range(p.shape[0]): + for j in range(p.shape[1]): + if idxs[i, j] == 0: + proposal = monotonicity_region_envelope(xs, p[i, j], lam[i, j]) + pmf = np.exp(_logprob(xs, p[i, j], lam[i, j])) + bad_xs = xs[proposal < pmf + 0.00] + if len(bad_xs): + print( + f"Monotonicity envelope is not higher than pmf for p={p[i, j]:.1f}, " + f"lam={lam[i, j]:.3f}, xs={bad_xs}" + ) + + elif idxs[i, j] == 1: + proposal = poisson_region_envelope(xs, p[i, j], lam[i, j]) + pmf = np.exp(_logprob(xs, p[i, j], lam[i, j])) + bad_xs = xs[proposal < pmf - 0.03] + if len(bad_xs): + print( + f"Poisson envelope is not higher than pmf for p={p[i, j]:.1f}, " + f"lam={lam[i, j]:.3f}, xs={bad_xs}" + ) + + elif idxs[i, j] == 2: + proposal = abel_rejection_envelope(xs, p[i, j], lam[i, j]) + pmf = np.exp(_logprob(xs, p[i, j], lam[i, j])) + bad_xs = xs[proposal < pmf - 0.0 + 0.00] + if len(bad_xs): + print( + f"Abel envelope is not higher than pmf for p={p[i, j]:.1f}, " + f"lam={lam[i, j]:.3f}, xs={bad_xs}" + ) + +# %% pycharm={"name": "#%%\n"} +data = idxs.T + +fig, ax = plt.subplots(figsize=(7, 7)) + +# get discrete colormap +cmap = plt.get_cmap("coolwarm", np.max(data) - np.min(data) + 1) +# set limits .5 outside true range +mat = ax.imshow( + data, + cmap=cmap, + vmin=0, + vmax=2, + origin="lower", +) +# tell the colorbar to tick at integers +cbar = 
def benchmark_algorithm(algorithm_fn, draws=100, duration_cutoff=0.05):
    """Time ``algorithm_fn`` once for every cell of the module-level grid.

    For each ``(i, j)`` of the module-level ``p`` / ``lam`` grids the
    function is called as
    ``algorithm_fn(rng, p=p[i, j], lam=lam[i, j], dist_size=draws)``
    under a real-time interval timer. Calls interrupted by ``TimeOutError``
    are left as NaN in the result.

    Parameters
    ----------
    algorithm_fn : callable
        Sampler accepting ``(rng, p=..., lam=..., dist_size=...)``.
    draws : int, optional
        Value passed as ``dist_size``.
    duration_cutoff : float, optional
        Seconds after which a call is interrupted.

    Returns
    -------
    numpy.ndarray
        Array shaped like ``p`` holding elapsed seconds, or NaN where the
        call timed out.
    """
    duration = np.full_like(p, np.nan)
    for i in range(p.shape[0]):
        for j in range(p.shape[1]):
            signal.setitimer(signal.ITIMER_REAL, duration_cutoff)
            # perf_counter is the high-resolution clock for short intervals.
            start = time.perf_counter()
            try:
                algorithm_fn(rng, p=p[i, j], lam=lam[i, j], dist_size=draws)
                elapsed = time.perf_counter() - start
            except TimeOutError:
                continue
            finally:
                # Clear the timer with the same API that armed it, and do so
                # on every exit path so a stale alarm cannot fire later.
                signal.setitimer(signal.ITIMER_REAL, 0)
            duration[i, j] = elapsed
    return duration
# %% pycharm={"name": "#%%\n"}
def plot_benchmark_comparison(benchmarks, benchmark_names, threshold=0):
    """Plot the sampling regions next to the fastest algorithm per grid point.

    Parameters
    ----------
    benchmarks : sequence of numpy.ndarray
        Duration grids of identical shape, NaN where a run timed out.
    benchmark_names : sequence of str
        One colorbar label per entry of ``benchmarks``.
    threshold : float
        An algorithm only counts as best at a grid point when it beats every
        other one by more than this many seconds.

    Raises
    ------
    ValueError
        If ``benchmarks`` is empty or ``benchmark_names`` has a different
        length.

    Notes
    -----
    Reads the module-level ``idxs``, ``lam_range`` and ``p_range``. Grid
    points without a clear winner (ties within ``threshold`` or all runs
    timed out) stay NaN and are drawn uncoloured.
    """
    n_benchmarks = len(benchmarks)
    if n_benchmarks == 0:
        raise ValueError("At least one benchmark is required")
    if len(benchmark_names) != n_benchmarks:
        raise ValueError("benchmark_names must have one entry per benchmark")

    # A timed-out competitor (NaN) must never beat a finished run.
    competitors = [np.nan_to_num(other, nan=np.inf) for other in benchmarks]

    best_duration = np.full(np.shape(benchmarks[0]), np.nan)
    for i, reference in enumerate(benchmarks):
        # Start from "this run finished" so a lone benchmark is handled and a
        # timed-out reference can never win.
        wins = np.isfinite(reference)
        for k, other in enumerate(competitors):
            if k != i:
                wins = wins & (reference + threshold < other)
        best_duration[wins] = i

    fig, ax = plt.subplots(1, 2, figsize=(14, 7), sharey=True)

    # Region
    # Limits sit .5 outside the integer codes so each colour band is centred
    # on its code and the colorbar can simply tick at the integers.
    region_img = ax[0].imshow(
        idxs.T,
        cmap=plt.get_cmap("coolwarm", 3),
        vmin=-0.5,
        vmax=2.5,
        origin="lower",
    )
    region_cbar = fig.colorbar(region_img, ticks=[0, 1, 2], fraction=0.038, ax=ax[0])
    region_cbar.ax.set_yticklabels(
        ["monotonicity", "poisson", "abel"], rotation=90, va="center"
    )

    # Best algorithm
    # One colour per benchmark with fixed limits, so colours and labels stay
    # aligned even when some algorithm never wins.
    best_img = ax[1].imshow(
        best_duration.T,
        cmap=plt.get_cmap("coolwarm", n_benchmarks),
        vmin=-0.5,
        vmax=n_benchmarks - 0.5,
        origin="lower",
    )
    best_cbar = fig.colorbar(
        best_img, ticks=np.arange(n_benchmarks), fraction=0.038, ax=ax[1]
    )
    best_cbar.ax.set_yticklabels(benchmark_names, rotation=90, va="center")

    every = 8
    for axi in ax:
        axi.set_xlabel("lam")
        axi.set_xticks(range(0, lam_range.size)[::every])
        axi.set_xticklabels(np.round(lam_range[::every], 2))

        axi.set_ylabel("p")
        axi.set_yticks(range(0, p_range.size)[::every])
        axi.set_yticklabels(np.round(p_range[::every], 2))

        # NOTE(review): row 20.5 presumably marks the p = 3 boundary of the
        # poisson region on the default p_range - confirm.
        axi.axhline(20.5, color="k")

    fig.suptitle("Best performance per region", y=0.85, fontsize=18)
def _logpow(x, m):
    """
    Calculates log(x**m) since m*log(x) will fail when m, x = 0.

    Parameters
    ----------
    x : array-like
        Base of the power.
    m : array-like
        Exponent, broadcastable against ``x``.

    Returns
    -------
    numpy.ndarray
        ``m * log(x)`` where ``x != 0``. Where ``x == 0`` the result is
        ``0.0`` if ``m == 0`` and ``-inf`` otherwise.
    """
    x = np.asarray(x)
    is_zero = x == 0
    # np.where evaluates both branches for every element, so taking log(x)
    # directly would emit divide-by-zero / invalid-value warnings for the very
    # entries that are discarded. Those entries get a harmless placeholder.
    safe_x = np.where(is_zero, 1.0, x)
    return np.where(is_zero, np.where(m == 0, 0.0, -np.inf), m * np.log(safe_x))
@classmethod
def _rejection_region_monotonicity(cls, rng, p, lam, dist_size, idxs_mask=None):
    """Rejection sampler based on the universal (monotonicity) bound.

    The dominating measure puts mass ``p0 = exp(-p)`` on 0 and
    ``b * (1 / sqrt(x) - 1 / sqrt(x + 1))`` on every ``x >= 1``, which sums
    to ``b``. Each round proposes 0 with probability ``p0 / (p0 + b)`` and
    ``floor(1 / W**2)`` otherwise; a zero proposal is always accepted, a
    positive one is accepted when ``V`` times the envelope does not exceed
    the pmf.

    Parameters
    ----------
    rng : numpy.random.Generator
        Source of the uniform draws.
    p, lam : array-like
        Distribution parameters, broadcastable to ``dist_size``.
    dist_size : int or tuple of int
        Shape the parameters are broadcast to.
    idxs_mask : numpy.ndarray of bool, optional
        Mask of shape ``dist_size`` selecting the entries to draw. Defaults
        to all entries.

    Returns
    -------
    numpy.ndarray
        1-D float array with one draw per selected entry.
    """
    if idxs_mask is None:
        idxs_mask = np.ones(dist_size, dtype="bool")
    p = np.broadcast_to(p, dist_size)[idxs_mask]
    lam = np.broadcast_to(lam, dist_size)[idxs_mask]
    n_draws = int(np.sum(idxs_mask))

    p0 = np.exp(-p)
    b = p * np.exp(2 - lam - np.minimum(lam, p)) * np.sqrt(2 / np.pi)
    zero_proposal_prob = p0 / (p0 + b)

    x = np.zeros(n_draws)
    pending = np.ones(n_draws, dtype=bool)
    while np.any(pending):
        # The zero / non-zero choice is part of the proposal and has to be
        # redrawn every round. Deciding it once up front fixes
        # P(X = 0) at p0 / (p0 + b) instead of the target p0.
        u = rng.uniform(size=n_draws)
        v = rng.uniform(size=n_draws)
        w = rng.uniform(size=n_draws)

        propose_zero = u <= zero_proposal_prob
        tail = np.floor(1 / w ** 2)
        tail_accepted = v * b * (1 / np.sqrt(tail) - 1 / np.sqrt(tail + 1)) <= np.exp(
            cls._logprob(tail, p, lam)
        )

        # The envelope equals the pmf at 0, so a zero proposal never fails.
        accepted = propose_zero | tail_accepted
        settled = pending & accepted
        x[settled] = np.where(propose_zero, 0.0, tail)[settled]
        pending = pending & ~accepted
    return x
@classmethod
def _rejection_region_abel(cls, rng, p, lam, dist_size, idxs_mask=None):
    """Rejection sampler for the Abel region.

    The proposal has two pieces split at ``t``: a geometric tail running
    from ``t`` towards 0 with total mass ``q_l``, and a
    ``floor((t + 1) / W**2)`` tail to the right with total mass ``q_r``.
    When ``t == 0`` the left piece collapses to the single point 0, whose
    mass ``exp(-p)`` is accepted unconditionally.

    Parameters
    ----------
    rng : numpy.random.Generator
        Source of the uniform and exponential draws.
    p, lam : array-like
        Distribution parameters, broadcastable to ``dist_size``.
    dist_size : int or tuple of int
        Shape the parameters are broadcast to.
    idxs_mask : numpy.ndarray of bool, optional
        Mask of shape ``dist_size`` selecting the entries to draw. Defaults
        to all entries.

    Returns
    -------
    numpy.ndarray
        1-D float array with one draw per selected entry.
    """
    if idxs_mask is None:
        idxs_mask = np.ones(dist_size, dtype="bool")
    p = np.broadcast_to(p, dist_size)[idxs_mask]
    lam = np.broadcast_to(lam, dist_size)[idxs_mask]
    dist_size = np.sum(idxs_mask)

    nu = 2 * (p ** 2 - lam * p - 3 * lam ** 2) / (3 * lam ** 2)
    alpha = 0.2746244084  # Taken from page 259
    t = np.floor(alpha * np.maximum(nu, 0))
    # Outside the Abel region the split point falls back to 0.
    problematic = (p < 1 + lam) | ((p * (1 - lam)) > (2 * lam))
    t[problematic] = 0
    b = p * np.exp(2 - lam - np.minimum(lam, p)) * np.sqrt(2 / np.pi)
    q_r = b / np.sqrt(t + 1)

    rho_t = (  # Taken from page 250
        1
        - p
        + np.log(p)
        - 0.5 * np.log(2 * np.pi)
        + (t - 1) * (np.log(lam * t + p) - np.log(t + 1))
        - 1.5 * np.log(t + 1)
        + (1 - lam) * t
    )
    rho_t_prime = (
        np.log(lam * t + p)
        - np.log(t + 1)
        + 1
        - lam
        + 1.5 / (t + 1)
        - (lam + p) / (lam * t + p)
    )
    q = np.exp(-rho_t_prime)
    q_l = np.where(t == 0, np.exp(-p), np.exp(rho_t) / (1 - np.exp(-rho_t_prime)))

    x = np.zeros(dist_size)
    inds_to_sample = np.arange(dist_size)
    # Loop on the number of pending entries. Testing the index *values* with
    # np.any would stop as soon as only index 0 is left, leaving it unsampled
    # (and would never sample at all for a single draw).
    while inds_to_sample.size:
        _dist_size = len(inds_to_sample)
        U = rng.uniform(size=_dist_size)
        V = rng.uniform(size=_dist_size)
        W = rng.uniform(size=_dist_size)
        E = rng.exponential(size=_dist_size)
        _p = p[inds_to_sample]
        _lam = lam[inds_to_sample]
        _t = t[inds_to_sample]
        _q = q[inds_to_sample]
        _q_l = q_l[inds_to_sample]
        _q_r = q_r[inds_to_sample]
        _b = b[inds_to_sample]

        raw_left = np.where(_t == 0, 0, _t - np.floor(-E / np.log(1 - _q)))
        raw_right = np.floor((_t + 1) / W ** 2)
        left = U <= _q_l / (_q_l + _q_r)

        # Left proposals below 0 are outside the support and are rejected;
        # every value in 0..t goes through the usual envelope test.
        left_accepted = (_t == 0) | (
            (raw_left >= 0)
            & (
                V * _q_l * _q ** (_t - raw_left) * (1 - _q)
                <= np.exp(cls._logprob(raw_left, _p, _lam))
            )
        )
        right_accepted = V * _b * (
            1 / np.sqrt(raw_right) - 1 / np.sqrt(raw_right + 1)
        ) <= np.exp(cls._logprob(raw_right, _p, _lam))

        accepted = np.where(left, left_accepted, right_accepted)
        X = np.where(left, raw_left, raw_right)

        x[inds_to_sample[accepted]] = X[accepted]
        inds_to_sample = inds_to_sample[~accepted]
    return x