From cd2bd960980c0ecf2766f2fcc4bf0f7af0c0e2fe Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Fri, 11 Jun 2021 18:12:04 +0530 Subject: [PATCH] ENH: stats.unuran: add a higher level API This adds a higher-level class called `GenericGenerator` which can be used to create generators using UNU.RAN's String API. Currently, only strings are allowed, but it can be extended to accept an object containing a UNU.RAN distribution. One can simply pass a string with the distribution to sample from and create a generator: >>> rng = GenericGenerator("normal") >>> rvs = rng.rvs(100_000) It is also possible to pass the method and its parameters as keyword arguments. >>> rng = GenericGenerator("normal", method="tdr", domain=(-5, 5), c=0., ... cpoints=(-4.5, -3.5, -2.5, 0.0, 2.5, 3.5, 4.5)) >>> rvs = rng.rvs(100_000) It is also possible to change the parameters of the distribution. The below example creates a sampler for the normal distribution with mean 6 and standard deviation 4.25: >>> rng = GenericGenerator("normal(6, 4.25)") Advanced use cases include passing a custom distribution with methods coded as strings. >>> rng = GenericGenerator("distr = cont; pdf = 'exp(-0.5 * x^2)'") >>> rvs = rng.rvs(100_000) --- scipy/stats/unuran/unuran_wrapper.pyx.templ | 122 +++++++++++++++++++- 1 file changed, 121 insertions(+), 1 deletion(-) diff --git a/scipy/stats/unuran/unuran_wrapper.pyx.templ b/scipy/stats/unuran/unuran_wrapper.pyx.templ index a69e322d6424..49d5889d5fb4 100644 --- a/scipy/stats/unuran/unuran_wrapper.pyx.templ +++ b/scipy/stats/unuran/unuran_wrapper.pyx.templ @@ -19,7 +19,8 @@ import numpy as np from scipy.stats._distn_infrastructure import argsreduce -__all__ = ["TransformedDensityRejection", "DiscreteAliasUrn"] +__all__ = ["GenericGenerator", "TransformedDensityRejection", + "DiscreteAliasUrn"] # Internal API for handling Python callbacks. 
@@ -473,6 +474,125 @@ cdef class Method: unur_urng_free(self.urng) +cdef class GenericGenerator(Method): + """ + Generic generator to sample from a wide family of distributions. + + This is a higher-level API for sampling from common distributions + easily. + + Parameters + ---------- + dist : str + Distribution string. See [1]_ for allowed distribution keys and [2]_ + for the grammar of a distribution string. + method : str, optional + Method to use. See [3]_ for available methods. Default is 'auto'. + domain : tuple, optional + Domain of the distribution. Can be used to truncate the distributions. + seed : int, BitGenerator, Generator, RandomState, SeedSequence, optional + Seed for the underlying uniform random number generation. + + Other Parameters + ---------------- + kwargs : + Other parameters that the Method takes. See [3]_ for all the available + methods and their parameters. + + References + ---------- + .. [1] UNU.RAN reference manual, Section 3.2.1, "Keys for Distribution String", + http://statmath.wu.ac.at/software/unuran/doc/unuran.html#KeysDistr + .. [2] UNU.RAN reference manual, Section 3.2, "Distribution String", + http://statmath.wu.ac.at/software/unuran/doc/unuran.html#StringDistr + .. [3] UNU.RAN reference manual, Section 3.4.1, "Keys for Method String", + http://statmath.wu.ac.at/software/unuran/doc/unuran.html#KeysMethod + + Examples + -------- + >>> from scipy.stats import GenericGenerator + + To create a sampler for the normal distribution, do: + + >>> rng = GenericGenerator("normal") + + Random variates can be sampled using the `rvs` method: + + >>> rvs = rng.rvs(100_000) + + A method for sampling can also be specified by passing the `method` + parameter. 
In the example below, Transformed Density Rejection is used + as the method for sampling from the normal distribution: + + >>> rng = GenericGenerator("normal", method="tdr") + + It is also possible to truncate the distribution by passing the `domain` + parameter and parameters of the `method` can be changed using keyword + arguments: + + >>> rng = GenericGenerator("normal", method="tdr", domain=(-5, 5), c=0., + ... cpoints=(-4.5, -3.5, -2.5, 0.0, 2.5, 3.5, 4.5)) + >>> rvs = rng.rvs(100_000) + + One can also change the parameters of the distributions. The below example + creates a sampler for the normal distribution with mean 6 and standard + deviation 4.25: + + >>> rng = GenericGenerator("normal(6, 4.25)") + """ + + def __cinit__(self, dist, method="auto", domain=None, seed=None, + **kwargs): + dist, method, domain, seed = self._validate_args(dist, method, + domain, seed) + self._init_generator(dist, method, domain, seed, kwargs) + + def _validate_args(self, dist, method, domain, seed): + dist = str(dist) + if dist == "": + raise ValueError("distribution string must not be empty") + method = str(method) + if domain is not None: + domain = tuple(domain) + return dist, method, domain, seed + + cdef int _init_generator(self, dist, method, domain, seed, + kwargs) except -1: + cdef: + ccallback_t callback + unuran_callback_t unur_callback + unur_urng *urng = NULL + + cdef str distr_str = f"{dist}" + if domain is not None: + distr_str += f"; domain = {domain}" + cdef str meth_str = f"method = {method}" + for key, value in kwargs.items(): + meth_str += f"; {key} = {value}" + + self.callbacks = {} + self.params = () + + init_unuran_callback(&callback, &unur_callback, self.callbacks, + self.params) + if setjmp(callback.error_buf) != 0: + release_unuran_callback(&callback, &unur_callback) + return -1 + + if seed is not None: + self.numpy_rng = get_numpy_rng(seed) + urng = get_urng(self.numpy_rng) + + self.urng = urng + self.rng = unur_makegen_ssu(distr_str.encode('UTF-8'), 
+ meth_str.encode('UTF-8'), + urng) + + release_unuran_callback(&callback, &unur_callback) + + return 0 + + cdef class TransformedDensityRejection(Method): r""" Transformed Density Rejection (TDR) Method.