Skip to content

Commit

Permalink
Implement ChiSquare via Gamma (#490)
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 committed Nov 15, 2023
1 parent 9653ade commit 902eeb6
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 52 deletions.
3 changes: 0 additions & 3 deletions doc/library/tensor/random/basic.rst
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,6 @@ PyTensor can produce :class:`RandomVariable`\s that draw samples from many diffe
.. autoclass:: pytensor.tensor.random.basic.CategoricalRV
:members: __call__

.. autoclass:: pytensor.tensor.random.basic.ChiSquareRV
:members: __call__

.. autoclass:: pytensor.tensor.random.basic.DirichletRV
:members: __call__

Expand Down
1 change: 0 additions & 1 deletion pytensor/link/numba/dispatch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ def {sized_fn_name}({random_fn_input_names}):
@numba_funcify.register(aer.NormalRV)
@numba_funcify.register(aer.LogNormalRV)
@numba_funcify.register(aer.GammaRV)
@numba_funcify.register(aer.ChiSquareRV)
@numba_funcify.register(aer.ParetoRV)
@numba_funcify.register(aer.GumbelRV)
@numba_funcify.register(aer.ExponentialRV)
Expand Down
51 changes: 16 additions & 35 deletions pytensor/tensor/random/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,56 +487,37 @@ def gamma(shape, rate=None, scale=None, **kwargs):
return _gamma(shape, scale, **kwargs)


class ChiSquareRV(RandomVariable):
r"""A chi square continuous random variable.
def chisquare(df, size=None, **kwargs):
r"""Draw samples from a chisquare distribution.
The probability density function for `chisquare` in terms of the number of degrees of
freedom :math:`k` is:
.. math::
f(x; k) = \frac{(1/2)^{k/2}}{\Gamma(k/2)} x^{k/2-1} e^{-x/2}
for :math:`k > 2`. :math:`\Gamma` is the gamma function:
.. math::
\Gamma(x) = \int_0^{\infty} t^{x-1} e^{-t} \mathrm{d}t
This variable is obtained by summing the squares :math:`k` independent, standard normally
distributed random variables.
"""
name = "chisquare"
ndim_supp = 0
ndims_params = [0]
dtype = "floatX"
_print_name = ("ChiSquare", "\\operatorname{ChiSquare}")

def __call__(self, df, size=None, **kwargs):
r"""Draw samples from a chisquare distribution.
Signature
---------
`() -> ()`
Parameters
----------
df
The number :math:`k` of degrees of freedom. Must be positive.
size
Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k`
independent, identically distributed random variables are
returned. Default is `None` in which case a single random variable
is returned.
"""
return super().__call__(df, size=size, **kwargs)

Signature
---------
`() -> ()`
chisquare = ChiSquareRV()
Parameters
----------
df
The number :math:`k` of degrees of freedom. Must be positive.
size
Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k`
independent, identically distributed random variables are
returned. Default is `None` in which case a single random variable
is returned.
"""
return gamma(shape=df / 2.0, scale=2.0, size=size, **kwargs)


class ParetoRV(ScipyRandomVariable):
Expand Down
13 changes: 0 additions & 13 deletions pytensor/tensor/random/rewriting/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.random.basic import (
BetaBinomialRV,
ChiSquareRV,
GenGammaRV,
GeometricRV,
HalfNormalRV,
Expand Down Expand Up @@ -104,13 +103,6 @@ def inverse_gamma_from_gamma(fgraph, node):
return [next_rng, reciprocal(g)]


@node_rewriter([ChiSquareRV])
def chi_square_from_gamma(fgraph, node):
*other_inputs, df = node.inputs
next_rng, g = _gamma.make_node(*other_inputs, df / 2, 2).outputs
return [next_rng, g]


@node_rewriter([GenGammaRV])
def generalized_gamma_from_gamma(fgraph, node):
*other_inputs, alpha, p, lambd = node.inputs
Expand Down Expand Up @@ -171,11 +163,6 @@ def beta_binomial_from_beta_binomial(fgraph, node):
in2out(inverse_gamma_from_gamma),
"jax",
)
random_vars_opt.register(
"chi_square_from_gamma",
in2out(chi_square_from_gamma),
"jax",
)
random_vars_opt.register(
"generalized_gamma_from_gamma",
in2out(generalized_gamma_from_gamma),
Expand Down

0 comments on commit 902eeb6

Please sign in to comment.