ENH: stats.unuran: add a higher level API
This adds a higher-level class called `GenericGenerator` which can be used to
create generators using UNU.RAN's String API. Currently, only strings are accepted,
but it can be extended to accept an object containing a UNU.RAN distribution.

One can simply pass a string with the distribution to sample from and create a
generator:

    >>> rng = GenericGenerator("normal")
    >>> rvs = rng.rvs(100_000)

It is also possible to pass the method and its parameters as keyword arguments.

    >>> rng = GenericGenerator("normal", method="tdr", domain=(-5, 5), c=0.,
    ...                        cpoints=(-4.5, -3.5, -2.5, 0.0, 2.5, 3.5, 4.5))
    >>> rvs = rng.rvs(100_000)
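
As a rough illustration of what happens to these arguments, the sketch below mirrors in plain Python the string assembly done by `_init_generator` in this diff. The helper name `build_unuran_strings` is hypothetical and is not part of SciPy or UNU.RAN; it only shows which two strings would be handed to the String API.

```python
def build_unuran_strings(dist, method="auto", domain=None, **kwargs):
    """Hypothetical pure-Python mirror of the string assembly in
    GenericGenerator._init_generator; not part of SciPy or UNU.RAN."""
    dist = str(dist)
    if dist == "":
        raise ValueError("distribution string must not be empty")

    # The distribution string is the user's string, optionally followed
    # by a "domain" key.
    distr_str = dist
    if domain is not None:
        distr_str += f"; domain = {tuple(domain)}"

    # The method string starts with the method name; every extra keyword
    # argument is appended as a "key = value" pair.
    meth_str = f"method = {method}"
    for key, value in kwargs.items():
        meth_str += f"; {key} = {value}"
    return distr_str, meth_str


distr_str, meth_str = build_unuran_strings(
    "normal", method="tdr", domain=(-5, 5), c=0.0,
    cpoints=(-4.5, -3.5, -2.5, 0.0, 2.5, 3.5, 4.5),
)
print(distr_str)  # normal; domain = (-5, 5)
print(meth_str)   # method = tdr; c = 0.0; cpoints = (-4.5, -3.5, -2.5, 0.0, 2.5, 3.5, 4.5)
```

Tuples are formatted with Python's default `repr`, so their parenthesised form is passed through unchanged.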

It is also possible to change the parameters of the distribution. The example below
creates a sampler for the normal distribution with mean 6 and standard
deviation 4.25:

    >>> rng = GenericGenerator("normal(6, 4.25)")
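
A quick sanity check for such a sampler is to compare the sample mean and standard deviation with the requested parameters. The sketch below is only an illustration: `summarize` is a hypothetical helper, and the samples come from Python's `random.gauss` as a stand-in, since `GenericGenerator` is only available in a build that includes this commit. With such a build, `rng.rvs(20_000)` could be passed to `summarize` instead.

```python
import random
import statistics


def summarize(samples):
    """Hypothetical helper: return (sample mean, sample standard deviation)."""
    return statistics.fmean(samples), statistics.stdev(samples)


# Stand-in for rng.rvs(20_000): draws from N(6, 4.25**2) using the
# standard library, seeded so the check is reproducible.
stand_in = random.Random(0)
samples = [stand_in.gauss(6, 4.25) for _ in range(20_000)]

mean, std = summarize(samples)
print(f"mean = {mean:.3f}, std = {std:.3f}")  # expected to be close to 6 and 4.25
```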

An advanced use case is passing a custom distribution whose functions (such as the PDF) are written as strings:

    >>> rng = GenericGenerator("distr = cont; pdf = 'exp(-0.5 * x^2)'")
    >>> rvs = rng.rvs(100_000)
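
Note that the PDF string in this example, `exp(-0.5 * x^2)`, is the standard normal density without its normalizing constant `1/sqrt(2*pi)`. The sketch below checks this numerically with a plain trapezoid rule from the standard library; the helper names are hypothetical. Whether every UNU.RAN method tolerates an unnormalized PDF is not something this commit states, so treat that as an assumption to verify against the UNU.RAN manual.

```python
import math


def unnormalized_pdf(x):
    # Same expression as the PDF string in the example: exp(-0.5 * x^2).
    return math.exp(-0.5 * x * x)


def trapezoid(f, a, b, n=10_000):
    """Composite trapezoid rule for f on [a, b] with n equal subintervals."""
    h = (b - a) / n
    total = 0.5 * (f(a) + f(b))
    for i in range(1, n):
        total += f(a + i * h)
    return total * h


# The tails beyond |x| = 10 are negligible for this integrand.
area = trapezoid(unnormalized_pdf, -10.0, 10.0)
print(area, math.sqrt(2 * math.pi))  # both are approximately 2.5066
```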
tirthasheshpatel committed Jun 20, 2021
1 parent ca45d2b commit cd2bd96
Showing 1 changed file with 121 additions and 1 deletion.
122 changes: 121 additions & 1 deletion scipy/stats/unuran/unuran_wrapper.pyx.templ
@@ -19,7 +19,8 @@ import numpy as np
from scipy.stats._distn_infrastructure import argsreduce
__all__ = ["TransformedDensityRejection", "DiscreteAliasUrn"]
__all__ = ["GenericGenerator", "TransformedDensityRejection",
"DiscreteAliasUrn"]
# Internal API for handling Python callbacks.
@@ -473,6 +474,125 @@ cdef class Method:
unur_urng_free(self.urng)
cdef class GenericGenerator(Method):
    """
    Generic generator to sample from a wide family of distributions.

    This is a higher-level API for sampling from common distributions
    easily.

    Parameters
    ----------
    dist : str
        Distribution string. See [1]_ for allowed distribution keys and [2]_
        for the grammar of a distribution string.
    method : str, optional
        Method to use. See [3]_ for available methods. Default is 'auto'.
    domain : tuple of float, optional
        Domain of the distribution, given as ``(lower, upper)``. Can be used
        to truncate the distribution.
    seed : int, BitGenerator, Generator, RandomState, SeedSequence, optional
        Seed for the underlying uniform random number generation.

    Other Parameters
    ----------------
    kwargs :
        Other parameters that the method takes. See [3]_ for all the available
        methods and their parameters.

    References
    ----------
    .. [1] UNU.RAN reference manual, Section 3.2.1, "Keys for Distribution String",
           http://statmath.wu.ac.at/software/unuran/doc/unuran.html#KeysDistr
    .. [2] UNU.RAN reference manual, Section 3.2, "Distribution String",
           http://statmath.wu.ac.at/software/unuran/doc/unuran.html#StringDistr
    .. [3] UNU.RAN reference manual, Section 3.4.1, "Keys for Method String",
           http://statmath.wu.ac.at/software/unuran/doc/unuran.html#KeysMethod

    Examples
    --------
    >>> from scipy.stats import GenericGenerator

    To create a sampler for the normal distribution, do:

    >>> rng = GenericGenerator("normal")

    Random variates can be sampled using the `rvs` method:

    >>> rvs = rng.rvs(100_000)

    A method for sampling can also be specified by passing the `method`
    parameter. In the example below, Transformed Density Rejection is used
    as the method for sampling from the normal distribution:

    >>> rng = GenericGenerator("normal", method="tdr")

    It is also possible to truncate the distribution by passing the `domain`
    parameter, and the parameters of the `method` can be changed using keyword
    arguments:

    >>> rng = GenericGenerator("normal", method="tdr", domain=(-5, 5), c=0.,
    ...                        cpoints=(-4.5, -3.5, -2.5, 0.0, 2.5, 3.5, 4.5))
    >>> rvs = rng.rvs(100_000)

    One can also change the parameters of the distribution. The example below
    creates a sampler for the normal distribution with mean 6 and standard
    deviation 4.25:

    >>> rng = GenericGenerator("normal(6, 4.25)")
    """
    def __cinit__(self, dist, method="auto", domain=None, seed=None,
                  **kwargs):
        dist, method, domain, seed = self._validate_args(dist, method,
                                                         domain, seed)
        self._init_generator(dist, method, domain, seed, kwargs)

    def _validate_args(self, dist, method, domain, seed):
        dist = str(dist)
        if dist == "":
            raise ValueError("distribution string must not be empty")
        method = str(method)
        if domain is not None:
            domain = tuple(domain)
        return dist, method, domain, seed

    cdef int _init_generator(self, dist, method, domain, seed,
                             kwargs) except -1:
        cdef:
            ccallback_t callback
            unuran_callback_t unur_callback
            unur_urng *urng = NULL
        cdef str distr_str = f"{dist}"
        if domain is not None:
            distr_str += f"; domain = {domain}"
        cdef str meth_str = f"method = {method}"
        for key, value in kwargs.items():
            meth_str += f"; {key} = {value}"
        self.callbacks = {}
        self.params = ()
        init_unuran_callback(&callback, &unur_callback, self.callbacks,
                             self.params)
        if setjmp(callback.error_buf) != 0:
            release_unuran_callback(&callback, &unur_callback)
            return -1
        if seed is not None:
            self.numpy_rng = get_numpy_rng(seed)
            urng = get_urng(self.numpy_rng)
            self.urng = urng
        self.rng = unur_makegen_ssu(distr_str.encode('UTF-8'),
                                    meth_str.encode('UTF-8'),
                                    urng)
        release_unuran_callback(&callback, &unur_callback)
        return 0

cdef class TransformedDensityRejection(Method):
r"""
Transformed Density Rejection (TDR) Method.