Skip to content

Commit

Permalink
REF: stats: use a dist object as input and remove string API
Browse files Browse the repository at this point in the history
Take a distribution object as input.
  • Loading branch information
tirthasheshpatel committed Jun 20, 2021
1 parent cd2bd96 commit 5823dbf
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 80 deletions.
3 changes: 2 additions & 1 deletion scipy/stats/unuran/__init__.py
@@ -1,4 +1,5 @@
from .unuran_wrapper import (
TransformedDensityRejection,
DiscreteAliasUrn
DiscreteAliasUrn,
GenericGenerator
)
111 changes: 32 additions & 79 deletions scipy/stats/unuran/unuran_wrapper.pyx.templ
Expand Up @@ -474,7 +474,7 @@ cdef class Method:
unur_urng_free(self.urng)


cdef class GenericGenerator(Method):
class GenericGenerator:
"""
Generic generator to sample from a wide family of distributions.

Expand All @@ -486,8 +486,8 @@ cdef class GenericGenerator(Method):
dist : str
Distribution string. See [1]_ for allowed distribution keys and [2]_
for the grammar of a distribution string.
method : str, optional
Method to use. See [3]_ for available methods. Default is 'auto'.
method : str
Method to use. Can be 'tdr' or 'dau'.
domain : str, optional
Domain o the distribution. Can be used to truncate the distributions.
seed : int, BitGenerator, Generator, RandomState, SeedSequence, optional
Expand All @@ -496,101 +496,54 @@ cdef class GenericGenerator(Method):
Other Parameters
----------------
kwargs :
Other parameters that the Method takes. See [3]_ for all the available
methods and their parameters.

References
----------
.. [1] UNU.RAN reference manual, Section 3.2.1, "Keys for Distribution String",
http://statmath.wu.ac.at/software/unuran/doc/unuran.html#KeysDistr
.. [2] UNU.RAN reference manual, Section 3.2, "Distribution String",
http://statmath.wu.ac.at/software/unuran/doc/unuran.html#StringDistr
.. [3] UNU.RAN reference manual, Section 3.4.1, "Keys for Method String",
http://statmath.wu.ac.at/software/unuran/doc/unuran.html#KeysMethod
Other parameters that the method takes.

Examples
--------
>>> from scipy.stats import GenericGenerator
>>>
>>> class Distribution:
... def pdf(self, x):
... return np.exp(-0.5 * x*x)
... def dpdf(self, x):
... return -x * np.exp(-0.5 * x*x)
...
>>> dist = Distribution()

To create a sampler for the normal distribution, do:
To create a sampler for this distribution, do:

>>> rng = GenericGenerator("normal")
>>> rng = GenericGenerator(dist, method="tdr")

Random variates can be sampled usin the `rvs` method:
Random variates can be sampled using the `rvs` method:

>>> rvs = rng.rvs(100_000)

A method for sampling can also be specified by passing the `method`
parameter. In the example below, Transformed Density Rejection is used
as the method for sampling from the normal distirbution:

>>> rng = GenericGenerator("normal", method="tdr")

It is also possible to truncate the distribution by passing the `domain`
parameter and parameters of the `method` can be changed using keyword
arguments:

>>> rng = GenericGenerator("normal", method="tdr", domain=(-5, 5), c=0.,
>>> rng = GenericGenerator(dist, method="tdr", domain=(-5, 5), c=0.,
... cpoints=(-4.5, -3.5, -2.5, 0.0, 2.5, 3.5, 4.5))
>>> rvs = rng.rvs(100_000)

One can also change the parameters of the distributions. The below example
creates a samplers for the normal distribution with mean 6 and standard
deviation 4.25:

>>> rng = GenericGenerator("normal(6, 4.25)")
"""
_allmethods = {
'tdr': TransformedDensityRejection,
'dau': DiscreteAliasUrn
}

def __cinit__(self, dist, method="auto", domain=None, seed=None,
**kwargs):
dist, method, domain, seed = self._validate_args(dist, method,
domain, seed)
self._init_generator(dist, method, domain, seed, kwargs)

def _validate_args(self, dist, method, domain, seed):
dist = str(dist)
if dist == "":
raise ValueError("distribution string must not be empty")
def __init__(self, dist, method, domain=None, seed=None, **kwargs):
method = str(method)
if domain is not None:
domain = tuple(domain)
return dist, method, domain, seed

cdef int _init_generator(self, dist, method, domain, seed,
kwargs) except -1:
cdef:
ccallback_t callback
unuran_callback_t unur_callback
unur_urng *urng = NULL

cdef str distr_str = f"{dist}"
if domain is not None:
distr_str += f"; domain = {domain}"
cdef str meth_str = f"method = {method}"
for key, value in kwargs.items():
meth_str += f"; {key} = {value}"

self.callbacks = {}
self.params = ()

init_unuran_callback(&callback, &unur_callback, self.callbacks,
self.params)
if setjmp(callback.error_buf) != 0:
release_unuran_callback(&callback, &unur_callback)
return -1

if seed is not None:
self.numpy_rng = get_numpy_rng(seed)
urng = get_urng(self.numpy_rng)

self.urng = urng
self.rng = unur_makegen_ssu(distr_str.encode('UTF-8'),
meth_str.encode('UTF-8'),
urng)

release_unuran_callback(&callback, &unur_callback)

return 0
if method.lower() not in self._allmethods:
raise ValueError(f"unknown method `{method}`")
method = method.lower()
self.method = self._allmethods[method](dist=dist, domain=domain,
seed=seed, **kwargs)

# add all the public attributes/methods to our generator.
for attrname in dir(self.method):
if not attrname.startswith("_"):
attr = getattr(self.method, attrname)
self.__dict__[attrname] = attr


cdef class TransformedDensityRejection(Method):
Expand Down

0 comments on commit 5823dbf

Please sign in to comment.