Skip to content

Commit 6a59e92

Browse files
igorsugakfacebook-github-bot
authored andcommittedOct 18, 2024
[CODEMOD][pytorch] replace uses of np.ndarray with npt.NDArray (pytorch#3845)
Summary: X-link: pytorch/opacus#680 X-link: pytorch/captum#1387 X-link: pytorch/botorch#2584 This replaces uses of `numpy.ndarray` in type annotations with `numpy.typing.NDArray`. In Numpy-1.24.0+ `numpy.ndarray` is annotated as generic type. Without template parameters it triggers static analysis errors: ```counterexample Generic type `ndarray` expects 2 type parameters. ``` `numpy.typing.NDArray` is an alias that provides default template parameters. Differential Revision: D64619891
1 parent 79047bf commit 6a59e92

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed
 

‎test/torchaudio_unittest/prototype/functional/dsp_utils.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import numpy.typing as npt
23

34

45
def oscillator_bank(
@@ -43,8 +44,8 @@ def freq_ir(magnitudes):
4344

4445

4546
def exp_sigmoid(
46-
input: np.ndarray, exponent: float = 10.0, max_value: float = 2.0, threshold: float = 1e-7
47-
) -> np.ndarray:
47+
input: npt.NDArray, exponent: float = 10.0, max_value: float = 2.0, threshold: float = 1e-7
48+
) -> npt.NDArray:
4849
"""Exponential Sigmoid pointwise nonlinearity (Numpy version).
4950
Implements the equation:
5051
``max_value`` * sigmoid(``input``) ** (log(``exponent``)) + ``threshold``

0 commit comments

Comments
 (0)
Failed to load comments.