Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Refactor of BaseDistribution and descendants - generalised distribution param broadcasting in base class #5176

Merged
merged 8 commits into from Sep 9, 2023
37 changes: 37 additions & 0 deletions sktime/proba/base.py
Expand Up @@ -158,6 +158,43 @@ def _method_error_msg(self, method="this method", severity="warn", fill_in=None)
else:
return msg

def _get_bc_params(self, *args, dtype=None):
"""Fully broadcast tuple of parameters given param shapes and index, columns.

Parameters
----------
args : float, int, array of floats, or array of ints (1D or 2D)
Distribution parameters that are to be made broadcastable. If no positional
arguments are provided, all parameters of `self` are used except for `index`
and `columns`.
dtype : str, optional
broadcasted arrays are cast to all have datatype `dtype`. If None, then no
datatype casting is done.

Returns
-------
Tuple of float or integer arrays
Each element of the tuple represents a different broadcastable distribution
parameter.
"""
number_of_params = len(args)
if number_of_params == 0:
# Handle case where no positional arguments are provided
params = self.get_params()
params.pop("index")
params.pop("columns")
args = tuple(params.values())
number_of_params = len(args)

if hasattr(self, "index") and self.index is not None:
args += (self.index.to_numpy().reshape(-1, 1),)
if hasattr(self, "columns") and self.columns is not None:
args += (self.columns.to_numpy(),)
bc = np.broadcast_arrays(*args)
if dtype is not None:
bc = [array.astype(dtype) for array in bc]
return bc[:number_of_params]

def pdf(self, x):
r"""Probability density function.

Expand Down
12 changes: 1 addition & 11 deletions sktime/proba/normal.py
Expand Up @@ -45,7 +45,7 @@ def __init__(self, mu, sigma, index=None, columns=None):
# and broadcast of parameters.
# move this functionality to the base class
# 0.19.0?
self._mu, self._sigma = self._get_bc_params()
self._mu, self._sigma = self._get_bc_params(self.mu, self.sigma)
shape = self._mu.shape

if index is None:
Expand All @@ -56,16 +56,6 @@ def __init__(self, mu, sigma, index=None, columns=None):

super().__init__(index=index, columns=columns)

def _get_bc_params(self):
"""Fully broadcast parameters of self, given param shapes and index, columns."""
to_broadcast = [self.mu, self.sigma]
if hasattr(self, "index") and self.index is not None:
to_broadcast += [self.index.to_numpy().reshape(-1, 1)]
if hasattr(self, "columns") and self.columns is not None:
to_broadcast += [self.columns.to_numpy()]
bc = np.broadcast_arrays(*to_broadcast)
return bc[0], bc[1]

def energy(self, x=None):
r"""Energy of self, w.r.t. self or a constant frame x.

Expand Down
14 changes: 3 additions & 11 deletions sktime/proba/t.py
Expand Up @@ -44,7 +44,9 @@ def __init__(self, mu, sigma, df=1, index=None, columns=None):
self.index = index
self.columns = columns

self._mu, self._sigma, self._df = self._get_bc_params()
self._mu, self._sigma, self._df = self._get_bc_params(
self.mu, self.sigma, self.df
)
shape = self._mu.shape

if index is None:
Expand All @@ -55,16 +57,6 @@ def __init__(self, mu, sigma, df=1, index=None, columns=None):

super().__init__(index=index, columns=columns)

def _get_bc_params(self):
"""Fully broadcast parameters of self, given param shapes and index, columns."""
to_broadcast = [self.mu, self.sigma, self.df]
if hasattr(self, "index") and self.index is not None:
to_broadcast += [self.index.to_numpy().reshape(-1, 1)]
if hasattr(self, "columns") and self.columns is not None:
to_broadcast += [self.columns.to_numpy()]
bc = np.broadcast_arrays(*to_broadcast)
return bc[0], bc[1], bc[2]

def mean(self):
r"""Return expected value of the distribution.

Expand Down
2 changes: 1 addition & 1 deletion sktime/proba/tests/test_base_default_methods.py
Expand Up @@ -34,7 +34,7 @@ def __init__(self, mu, sigma, index=None, columns=None):
self.index = index
self.columns = columns

self._mu, self._sigma = self._get_bc_params()
self._mu, self._sigma = self._get_bc_params(self.mu, self.sigma)
shape = self._mu.shape

if index is None:
Expand Down
12 changes: 1 addition & 11 deletions sktime/proba/tfp.py
Expand Up @@ -52,7 +52,7 @@ def __init__(self, mu, sigma, index=None, columns=None):
# and broadcast of parameters.
# move this functionality to the base class
# 0.19.0?
self._mu, self._sigma = self._get_bc_params()
self._mu, self._sigma = self._get_bc_params(self.mu, self.sigma, dtype="float")
distr = tfd.Normal(loc=self._mu, scale=self._sigma)
shape = self._mu.shape

Expand All @@ -64,16 +64,6 @@ def __init__(self, mu, sigma, index=None, columns=None):

super().__init__(index=index, columns=columns, distr=distr)

def _get_bc_params(self):
"""Fully broadcast parameters of self, given param shapes and index, columns."""
to_broadcast = [self.mu, self.sigma]
if hasattr(self, "index") and self.index is not None:
to_broadcast += [self.index.to_numpy().reshape(-1, 1)]
if hasattr(self, "columns") and self.columns is not None:
to_broadcast += [self.columns.to_numpy()]
bc = np.broadcast_arrays(*to_broadcast)
return np.array(bc[0], dtype="float"), np.array(bc[1], dtype="float")

def energy(self, x=None):
r"""Energy of self, w.r.t. self or a constant frame x.

Expand Down