From 5823dbf285534b354a829cb075a538253c61e7ea Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Sun, 20 Jun 2021 23:27:43 +0530 Subject: [PATCH] REF: stats: use a dist object as input and remove string API Take a distribution object as input. --- scipy/stats/unuran/__init__.py | 3 +- scipy/stats/unuran/unuran_wrapper.pyx.templ | 111 ++++++-------------- 2 files changed, 34 insertions(+), 80 deletions(-) diff --git a/scipy/stats/unuran/__init__.py b/scipy/stats/unuran/__init__.py index 8d7e80ca08df..0f82803294a5 100644 --- a/scipy/stats/unuran/__init__.py +++ b/scipy/stats/unuran/__init__.py @@ -1,4 +1,5 @@ from .unuran_wrapper import ( TransformedDensityRejection, - DiscreteAliasUrn + DiscreteAliasUrn, + GenericGenerator ) diff --git a/scipy/stats/unuran/unuran_wrapper.pyx.templ b/scipy/stats/unuran/unuran_wrapper.pyx.templ index 49d5889d5fb4..a57388a1fd1b 100644 --- a/scipy/stats/unuran/unuran_wrapper.pyx.templ +++ b/scipy/stats/unuran/unuran_wrapper.pyx.templ @@ -474,7 +474,7 @@ cdef class Method: unur_urng_free(self.urng) -cdef class GenericGenerator(Method): +class GenericGenerator: """ Generic generator to sample from a wide family of distributions. @@ -486,8 +486,8 @@ cdef class GenericGenerator(Method): dist : str Distribution string. See [1]_ for allowed distribution keys and [2]_ for the grammar of a distribution string. - method : str, optional - Method to use. See [3]_ for available methods. Default is 'auto'. + method : str + Method to use. Can be 'tdr' or 'dau'. domain : str, optional Domain o the distribution. Can be used to truncate the distributions. seed : int, BitGenerator, Generator, RandomState, SeedSequence, optional @@ -496,101 +496,54 @@ cdef class GenericGenerator(Method): Other Parameters ---------------- kwargs : - Other parameters that the Method takes. See [3]_ for all the available - methods and their parameters. - - References - ---------- - .. [1] UNU.RAN reference manual, Section 3.2.1, "Keys for Distribution String", - http://statmath.wu.ac.at/software/unuran/doc/unuran.html#KeysDistr - .. [2] UNU.RAN reference manual, Section 3.2, "Distribution String", - http://statmath.wu.ac.at/software/unuran/doc/unuran.html#StringDistr - .. [3] UNU.RAN reference manual, Section 3.4.1, "Keys for Method String", - http://statmath.wu.ac.at/software/unuran/doc/unuran.html#KeysMethod + Other parameters that the method takes. Examples -------- >>> from scipy.stats import GenericGenerator + >>> + >>> class Distribution: + ... def pdf(self, x): + ... return np.exp(-0.5 * x*x) + ... def dpdf(self, x): + ... return -x * np.exp(-0.5 * x*x) + ... + >>> dist = Distribution() - To create a sampler for the normal distribution, do: + To create a sampler for this distribution, do: - >>> rng = GenericGenerator("normal") + >>> rng = GenericGenerator(dist, method="tdr") - Random variates can be sampled usin the `rvs` method: + Random variates can be sampled using the `rvs` method: >>> rvs = rng.rvs(100_000) - A method for sampling can also be specified by passing the `method` - parameter. In the example below, Transformed Density Rejection is used - as the method for sampling from the normal distirbution: - - >>> rng = GenericGenerator("normal", method="tdr") - It is also possible to truncate the distribution by passing the `domain` parameter and parameters of the `method` can be changed using keyword arguments: - >>> rng = GenericGenerator("normal", method="tdr", domain=(-5, 5), c=0., + >>> rng = GenericGenerator(dist, method="tdr", domain=(-5, 5), c=0., ... cpoints=(-4.5, -3.5, -2.5, 0.0, 2.5, 3.5, 4.5)) >>> rvs = rng.rvs(100_000) - - One can also change the parameters of the distributions. The below example - creates a samplers for the normal distribution with mean 6 and standard - deviation 4.25: - - >>> rng = GenericGenerator("normal(6, 4.25)") """ + _allmethods = { + 'tdr': TransformedDensityRejection, + 'dau': DiscreteAliasUrn + } - def __cinit__(self, dist, method="auto", domain=None, seed=None, - **kwargs): - dist, method, domain, seed = self._validate_args(dist, method, - domain, seed) - self._init_generator(dist, method, domain, seed, kwargs) - - def _validate_args(self, dist, method, domain, seed): - dist = str(dist) - if dist == "": - raise ValueError("distribution string must not be empty") + def __init__(self, dist, method, domain=None, seed=None, **kwargs): method = str(method) - if domain is not None: - domain = tuple(domain) - return dist, method, domain, seed - - cdef int _init_generator(self, dist, method, domain, seed, - kwargs) except -1: - cdef: - ccallback_t callback - unuran_callback_t unur_callback - unur_urng *urng = NULL - - cdef str distr_str = f"{dist}" - if domain is not None: - distr_str += f"; domain = {domain}" - cdef str meth_str = f"method = {method}" - for key, value in kwargs.items(): - meth_str += f"; {key} = {value}" - - self.callbacks = {} - self.params = () - - init_unuran_callback(&callback, &unur_callback, self.callbacks, - self.params) - if setjmp(callback.error_buf) != 0: - release_unuran_callback(&callback, &unur_callback) - return -1 - - if seed is not None: - self.numpy_rng = get_numpy_rng(seed) - urng = get_urng(self.numpy_rng) - - self.urng = urng - self.rng = unur_makegen_ssu(distr_str.encode('UTF-8'), - meth_str.encode('UTF-8'), - urng) - - release_unuran_callback(&callback, &unur_callback) - - return 0 + if method.lower() not in self._allmethods: + raise ValueError(f"unknown method `{method}`") + method = method.lower() + self.method = self._allmethods[method](dist=dist, domain=domain, + seed=seed, **kwargs) + + # add all the public attributes/methods to our generator. + for attrname in dir(self.method): + if not attrname.startswith("_"): + attr = getattr(self.method, attrname) + self.__dict__[attrname] = attr cdef class TransformedDensityRejection(Method):