Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement ChiSquare via Gamma #490

Merged
merged 4 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 0 additions & 3 deletions doc/library/tensor/random/basic.rst
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,6 @@ PyTensor can produce :class:`RandomVariable`\s that draw samples from many diffe
.. autoclass:: pytensor.tensor.random.basic.CategoricalRV
:members: __call__

.. autoclass:: pytensor.tensor.random.basic.ChiSquareRV
:members: __call__

.. autoclass:: pytensor.tensor.random.basic.DirichletRV
:members: __call__

Expand Down
1 change: 0 additions & 1 deletion pytensor/link/numba/dispatch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ def {sized_fn_name}({random_fn_input_names}):
@numba_funcify.register(aer.NormalRV)
@numba_funcify.register(aer.LogNormalRV)
@numba_funcify.register(aer.GammaRV)
@numba_funcify.register(aer.ChiSquareRV)
@numba_funcify.register(aer.ParetoRV)
@numba_funcify.register(aer.GumbelRV)
@numba_funcify.register(aer.ExponentialRV)
Expand Down
52 changes: 2 additions & 50 deletions pytensor/tensor/random/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,56 +487,8 @@ def gamma(shape, rate=None, scale=None, **kwargs):
return _gamma(shape, scale, **kwargs)


class ChiSquareRV(RandomVariable):
r"""A chi square continuous random variable.

The probability density function for `chisquare` in terms of the number of degrees of
freedom :math:`k` is:

.. math::

f(x; k) = \frac{(1/2)^{k/2}}{\Gamma(k/2)} x^{k/2-1} e^{-x/2}

for :math:`k > 2`. :math:`\Gamma` is the gamma function:

.. math::

\Gamma(x) = \int_0^{\infty} t^{x-1} e^{-t} \mathrm{d}t


This variable is obtained by summing the squares :math:`k` independent, standard normally
distributed random variables.

"""
name = "chisquare"
ndim_supp = 0
ndims_params = [0]
dtype = "floatX"
_print_name = ("ChiSquare", "\\operatorname{ChiSquare}")

def __call__(self, df, size=None, **kwargs):
r"""Draw samples from a chisquare distribution.

Signature
---------

`() -> ()`

Parameters
----------
df
The number :math:`k` of degrees of freedom. Must be positive.
size
Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k`
independent, identically distributed random variables are
returned. Default is `None` in which case a single random variable
is returned.

"""
return super().__call__(df, size=size, **kwargs)


chisquare = ChiSquareRV()
def chisquare(df, **kwargs):
wd60622 marked this conversation as resolved.
Show resolved Hide resolved
return gamma(shape=df / 2.0, scale=2.0, **kwargs)
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved
wd60622 marked this conversation as resolved.
Show resolved Hide resolved


class ParetoRV(ScipyRandomVariable):
Expand Down
13 changes: 0 additions & 13 deletions pytensor/tensor/random/rewriting/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.random.basic import (
BetaBinomialRV,
ChiSquareRV,
GenGammaRV,
GeometricRV,
HalfNormalRV,
Expand Down Expand Up @@ -104,13 +103,6 @@ def inverse_gamma_from_gamma(fgraph, node):
return [next_rng, reciprocal(g)]


@node_rewriter([ChiSquareRV])
def chi_square_from_gamma(fgraph, node):
*other_inputs, df = node.inputs
next_rng, g = _gamma.make_node(*other_inputs, df / 2, 2).outputs
return [next_rng, g]


@node_rewriter([GenGammaRV])
def generalized_gamma_from_gamma(fgraph, node):
*other_inputs, alpha, p, lambd = node.inputs
Expand Down Expand Up @@ -171,11 +163,6 @@ def beta_binomial_from_beta_binomial(fgraph, node):
in2out(inverse_gamma_from_gamma),
"jax",
)
random_vars_opt.register(
"chi_square_from_gamma",
in2out(chi_square_from_gamma),
"jax",
)
random_vars_opt.register(
"generalized_gamma_from_gamma",
in2out(generalized_gamma_from_gamma),
Expand Down