Skip to content

Commit

Permalink
Use tuples throughout _get_bc_params
Browse files Browse the repository at this point in the history
Update distributions to pass tuples to _get_bc_params
  • Loading branch information
Alex-JG3 committed Aug 29, 2023
1 parent 9569688 commit ae7f366
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 11 deletions.
14 changes: 7 additions & 7 deletions sktime/proba/base.py
Expand Up @@ -158,17 +158,17 @@ def _method_error_msg(self, method="this method", severity="warn", fill_in=None)
else:
return msg

def _get_bc_params(self, to_broadcast, dtype=None):
"""Fully broadcast parameters of self, given param shapes and index, columns."""
length = len(to_broadcast)
def _get_bc_params(self, *args, dtype=None):
"""Fully broadcast tuple of parameters given param shapes and index, columns."""
number_of_params = len(args)
if hasattr(self, "index") and self.index is not None:
to_broadcast += [self.index.to_numpy().reshape(-1, 1)]
args += tuple(self.index.to_numpy().reshape(-1, 1))
if hasattr(self, "columns") and self.columns is not None:
to_broadcast += [self.columns.to_numpy()]
bc = np.broadcast_arrays(*to_broadcast)
args += tuple(self.columns.to_numpy())
bc = np.broadcast_arrays(*args)
if dtype is not None:
bc = [array.astype(dtype) for array in bc]
return bc[:length]
return bc[:number_of_params]

def pdf(self, x):
r"""Probability density function.
Expand Down
2 changes: 1 addition & 1 deletion sktime/proba/normal.py
Expand Up @@ -45,7 +45,7 @@ def __init__(self, mu, sigma, index=None, columns=None):
# and broadcast of parameters.
# move this functionality to the base class
# 0.19.0?
self._mu, self._sigma = self._get_bc_params([self.mu, self.sigma])
self._mu, self._sigma = self._get_bc_params(*(self.mu, self.sigma))
shape = self._mu.shape

if index is None:
Expand Down
2 changes: 1 addition & 1 deletion sktime/proba/t.py
Expand Up @@ -45,7 +45,7 @@ def __init__(self, mu, sigma, df=1, index=None, columns=None):
self.columns = columns

self._mu, self._sigma, self._df = self._get_bc_params(
[self.mu, self.sigma, self.df]
*(self.mu, self.sigma, self.df)
)
shape = self._mu.shape

Expand Down
2 changes: 1 addition & 1 deletion sktime/proba/tests/test_base_default_methods.py
Expand Up @@ -34,7 +34,7 @@ def __init__(self, mu, sigma, index=None, columns=None):
self.index = index
self.columns = columns

self._mu, self._sigma = self._get_bc_params([self.mu, self.sigma])
self._mu, self._sigma = self._get_bc_params(*(self.mu, self.sigma))
shape = self._mu.shape

if index is None:
Expand Down
2 changes: 1 addition & 1 deletion sktime/proba/tfp.py
Expand Up @@ -53,7 +53,7 @@ def __init__(self, mu, sigma, index=None, columns=None):
# move this functionality to the base class
# 0.19.0?
self._mu, self._sigma = self._get_bc_params(
[self.mu, self.sigma], dtype="float"
*(self.mu, self.sigma), dtype="float"
)
distr = tfd.Normal(loc=self._mu, scale=self._sigma)
shape = self._mu.shape
Expand Down

0 comments on commit ae7f366

Please sign in to comment.