Skip to content

Commit

Permalink
Added ONNX STFT support, including unit tests. Addressed all CR
Browse files Browse the repository at this point in the history
comments.
  • Loading branch information
urinieto committed Feb 15, 2023
1 parent 65b9983 commit a95ffd7
Show file tree
Hide file tree
Showing 4 changed files with 495 additions and 8 deletions.
1 change: 1 addition & 0 deletions test/onnx/test_op_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ def reason_flaky() -> str:
[
"ceil",
"sqrt",
"stft",
"t",
]
)
Expand Down
125 changes: 125 additions & 0 deletions test/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,6 +1143,131 @@ def test_dynamic_axes_unchange(self):
opset_version=12,
)

def test_stft_default(self):
"""Test STFT with default parameters"""
m1 = torch.randn((1, 32))
n_fft = 16
self.assertONNX(
lambda x: torch.stft(x, n_fft=n_fft, center=False, return_complex=False),
(m1,),
opset_version=17,
)

def test_stft_hop_length(self):
"""Test STFT with custom hop length"""
m1 = torch.randn((1, 32))
n_fft = 16
hop_length = 4
self.assertONNX(
lambda x: torch.stft(
x,
n_fft=n_fft,
center=False,
hop_length=hop_length,
return_complex=False,
),
(m1,),
opset_version=17,
)

def test_stft_non_divisible_hop_length(self):
"""Test STFT with non-divisible custom hop length"""
m1 = torch.randn((1, 32))
n_fft = 16
hop_length = 5
self.assertONNX(
lambda x: torch.stft(
x,
n_fft=n_fft,
center=False,
hop_length=hop_length,
return_complex=False,
),
(m1,),
opset_version=17,
)

def test_stft_window_int_same_size(self):
"""Test STFT with specific window length equals n_fft"""
m1 = torch.randn((1, 32))
n_fft = 16
win_length = 16
self.assertONNX(
lambda x: torch.stft(
x,
n_fft=n_fft,
center=False,
win_length=win_length,
return_complex=False,
),
(m1,),
opset_version=17,
)

def test_stft_window_int_different_size(self):
"""Test STFT with specific window length different than n_fft"""
m1 = torch.randn((1, 32))
n_fft = 16
win_length = 9
self.assertONNX(
lambda x: torch.stft(
x,
n_fft=n_fft,
center=False,
win_length=win_length,
return_complex=False,
),
(m1,),
opset_version=17,
)

def test_stft_window_custom(self):
"""Test STFT with a custom window"""
m1 = torch.randn((1, 32))
n_fft = 16
window = torch.hann_window(16)
self.assertONNX(
lambda x: torch.stft(
x, n_fft=n_fft, center=False, window=window, return_complex=False
),
(m1,),
opset_version=17,
)

def test_stft_one_dimension(self):
"""Test STFT with a single dimension"""
m1 = torch.randn((32))
n_fft = 16
self.assertONNX(
lambda x: torch.stft(x, n_fft=n_fft, center=False, return_complex=False),
(m1,),
opset_version=17,
)

def test_stft_normalize(self):
"""Test STFT with normalization"""
m1 = torch.randn((32))
n_fft = 16
self.assertONNX(
lambda x: torch.stft(
x, n_fft=n_fft, center=False, normalized=True, return_complex=False
),
(m1,),
opset_version=17,
)

def test_stft_not_onesided(self):
"""Test STFT without returning a single side"""
m1 = torch.randn((32))
n_fft = 16
self.assertONNX(
lambda x: torch.stft(
x, n_fft=n_fft, center=False, onesided=False, return_complex=False
),
(m1,),
opset_version=17,
)

def test_aten_embedding_1(self):
_onnx_opset_version = 12

Expand Down

0 comments on commit a95ffd7

Please sign in to comment.