diff --git a/pyttb/tensor.py b/pyttb/tensor.py index 38cca47a..5144d172 100644 --- a/pyttb/tensor.py +++ b/pyttb/tensor.py @@ -1366,7 +1366,7 @@ def reshape(self, shape: Shape) -> tensor: def scale( self, factor: np.ndarray | ttb.tensor, - dims: float | np.ndarray, + dims: OneDArray, ) -> tensor: """ Scale along specified dimensions for tensors. @@ -1398,11 +1398,6 @@ def scale( [ 1., 4., 7., 10.], [ 2., 5., 8., 11.]]) """ - if isinstance(dims, list): - dims = np.array(dims) - elif isinstance(dims, (float, int, np.generic)): - dims = np.array([dims]) - # TODO update tt_dimscheck overload so I don't need explicit # Nones to appease mypy dims, _ = tt_dimscheck(self.ndims, None, dims, None)