Skip to content

Commit

Permalink
[ENH] Refactor of BaseDistribution and descendants - generalised di…
Browse files Browse the repository at this point in the history
…stribution param broadcasting in base class (#5176)

Mirror of sktime/skpro#21.

Moves the `_get_bc_params` method from the child distributions to the parent
distribution class, `BaseDistribution`.

This implementation is quite simple: it doesn't use `_tags`, and the child
distributions still have to call the `_get_bc_params` method explicitly.
  • Loading branch information
Alex-JG3 committed Sep 9, 2023
1 parent 7ad08e0 commit 7affb72
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 34 deletions.
37 changes: 37 additions & 0 deletions sktime/proba/base.py
Expand Up @@ -158,6 +158,43 @@ def _method_error_msg(self, method="this method", severity="warn", fill_in=None)
else:
return msg

def _get_bc_params(self, *args, dtype=None):
    """Fully broadcast tuple of parameters given param shapes and index, columns.

    Parameters
    ----------
    args : float, int, array of floats, or array of ints (1D or 2D)
        Distribution parameters that are to be made broadcastable. If no
        positional arguments are provided, all parameters of `self` are used
        except for `index` and `columns`.
    dtype : str, optional
        The returned arrays are cast to datatype `dtype`. If None, then no
        datatype casting is done.

    Returns
    -------
    tuple of float or integer arrays
        Each element of the tuple represents a different broadcastable
        distribution parameter, in the order the parameters were given.
        Without `dtype` the elements are whatever `np.broadcast_arrays`
        produced; with `dtype` they are the result of `astype(dtype)`.
    """
    if not args:
        # Handle case where no positional arguments are provided.
        params = self.get_params()
        # Tolerate distributions that do not expose index/columns as
        # parameters instead of failing with a KeyError.
        params.pop("index", None)
        params.pop("columns", None)
        args = tuple(params.values())
    number_of_params = len(args)

    # index/columns only take part in the broadcast to fix the target shape:
    # index is turned into a column (n, 1), columns is used as returned by
    # to_numpy(). They are appended after the parameters so the parameters
    # keep the leading positions of the result.
    index = getattr(self, "index", None)
    if index is not None:
        args += (index.to_numpy().reshape(-1, 1),)
    columns = getattr(self, "columns", None)
    if columns is not None:
        args += (columns.to_numpy(),)

    # Drop the index/columns entries *before* casting: their labels may be
    # non-numeric (e.g. strings) and must never be cast to `dtype`.
    bc = np.broadcast_arrays(*args)[:number_of_params]
    if dtype is not None:
        bc = [array.astype(dtype) for array in bc]
    return tuple(bc)

def pdf(self, x):
r"""Probability density function.
Expand Down
12 changes: 1 addition & 11 deletions sktime/proba/normal.py
Expand Up @@ -45,7 +45,7 @@ def __init__(self, mu, sigma, index=None, columns=None):
# and broadcast of parameters.
# move this functionality to the base class
# 0.19.0?
self._mu, self._sigma = self._get_bc_params()
self._mu, self._sigma = self._get_bc_params(self.mu, self.sigma)
shape = self._mu.shape

if index is None:
Expand All @@ -56,16 +56,6 @@ def __init__(self, mu, sigma, index=None, columns=None):

super().__init__(index=index, columns=columns)

def _get_bc_params(self):
"""Fully broadcast parameters of self, given param shapes and index, columns."""
# Parameters go first so they occupy positions 0 and 1 of the broadcast result.
to_broadcast = [self.mu, self.sigma]
# index, if set, is reshaped to a column (n, 1) before joining the broadcast.
if hasattr(self, "index") and self.index is not None:
to_broadcast += [self.index.to_numpy().reshape(-1, 1)]
# columns, if set, joins the broadcast as returned by to_numpy() (no reshape).
if hasattr(self, "columns") and self.columns is not None:
to_broadcast += [self.columns.to_numpy()]
bc = np.broadcast_arrays(*to_broadcast)
# Only the broadcast mu and sigma are returned; index/columns entries are dropped.
return bc[0], bc[1]

def energy(self, x=None):
r"""Energy of self, w.r.t. self or a constant frame x.
Expand Down
14 changes: 3 additions & 11 deletions sktime/proba/t.py
Expand Up @@ -44,7 +44,9 @@ def __init__(self, mu, sigma, df=1, index=None, columns=None):
self.index = index
self.columns = columns

self._mu, self._sigma, self._df = self._get_bc_params()
self._mu, self._sigma, self._df = self._get_bc_params(
self.mu, self.sigma, self.df
)
shape = self._mu.shape

if index is None:
Expand All @@ -55,16 +57,6 @@ def __init__(self, mu, sigma, df=1, index=None, columns=None):

super().__init__(index=index, columns=columns)

def _get_bc_params(self):
"""Fully broadcast parameters of self, given param shapes and index, columns."""
# Parameters go first so they occupy positions 0-2 of the broadcast result.
to_broadcast = [self.mu, self.sigma, self.df]
# index, if set, is reshaped to a column (n, 1) before joining the broadcast.
if hasattr(self, "index") and self.index is not None:
to_broadcast += [self.index.to_numpy().reshape(-1, 1)]
# columns, if set, joins the broadcast as returned by to_numpy() (no reshape).
if hasattr(self, "columns") and self.columns is not None:
to_broadcast += [self.columns.to_numpy()]
bc = np.broadcast_arrays(*to_broadcast)
# Only the broadcast mu, sigma and df are returned; index/columns entries are dropped.
return bc[0], bc[1], bc[2]

def mean(self):
r"""Return expected value of the distribution.
Expand Down
2 changes: 1 addition & 1 deletion sktime/proba/tests/test_base_default_methods.py
Expand Up @@ -34,7 +34,7 @@ def __init__(self, mu, sigma, index=None, columns=None):
self.index = index
self.columns = columns

self._mu, self._sigma = self._get_bc_params()
self._mu, self._sigma = self._get_bc_params(self.mu, self.sigma)
shape = self._mu.shape

if index is None:
Expand Down
12 changes: 1 addition & 11 deletions sktime/proba/tfp.py
Expand Up @@ -52,7 +52,7 @@ def __init__(self, mu, sigma, index=None, columns=None):
# and broadcast of parameters.
# move this functionality to the base class
# 0.19.0?
self._mu, self._sigma = self._get_bc_params()
self._mu, self._sigma = self._get_bc_params(self.mu, self.sigma, dtype="float")
distr = tfd.Normal(loc=self._mu, scale=self._sigma)
shape = self._mu.shape

Expand All @@ -64,16 +64,6 @@ def __init__(self, mu, sigma, index=None, columns=None):

super().__init__(index=index, columns=columns, distr=distr)

def _get_bc_params(self):
"""Fully broadcast parameters of self, given param shapes and index, columns."""
# Parameters go first so they occupy positions 0 and 1 of the broadcast result.
to_broadcast = [self.mu, self.sigma]
# index, if set, is reshaped to a column (n, 1) before joining the broadcast.
if hasattr(self, "index") and self.index is not None:
to_broadcast += [self.index.to_numpy().reshape(-1, 1)]
# columns, if set, joins the broadcast as returned by to_numpy() (no reshape).
if hasattr(self, "columns") and self.columns is not None:
to_broadcast += [self.columns.to_numpy()]
bc = np.broadcast_arrays(*to_broadcast)
# Only mu and sigma are returned, each converted to an array of dtype "float";
# the index/columns entries are dropped without being cast.
return np.array(bc[0], dtype="float"), np.array(bc[1], dtype="float")

def energy(self, x=None):
r"""Energy of self, w.r.t. self or a constant frame x.
Expand Down

0 comments on commit 7affb72

Please sign in to comment.