Skip to content

Commit

Permalink
Ensure axis masking operations are not in-place (#1481)
Browse files Browse the repository at this point in the history
It was reported in #1478 that spectrogram masking operations were done in-place and modified the original input tensors. This PR fixes this behavior and adds tests to ensure that the input tensor is not changed.
  • Loading branch information
Caroline Chen committed May 3, 2021
1 parent b540e5d commit 7fd5fce
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 9 deletions.
32 changes: 32 additions & 0 deletions test/torchaudio_unittest/functional/functional_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,38 @@ def test_mask_along_axis_iid(self, mask_param, mask_value, axis):
assert mask_specgrams.size() == specgrams.size()
assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel()

@parameterized.expand(
    list(
        itertools.product(
            [(2, 1025, 400), (1, 201, 100)],  # shape
            [100],                            # mask_param
            [0.0, 30.0],                      # mask_value
            [1, 2],                           # axis
        )
    )
)
def test_mask_along_axis_preserve(self, shape, mask_param, mask_value, axis):
    """Verify that ``mask_along_axis`` leaves the Tensor passed to it unchanged.

    The check is repeated 5 times so that the probability of no masking
    occurring in any of the runs is bounded by 1e-10.
    See https://github.com/pytorch/audio/issues/1478
    """
    num_trials = 5
    torch.random.manual_seed(42)
    for _trial in range(num_trials):
        original = torch.randn(*shape, dtype=self.dtype, device=self.device)
        snapshot = original.clone()

        # The returned (masked) tensor is deliberately ignored: only the
        # state of the input after the call is under test.
        F.mask_along_axis(original, mask_param, mask_value, axis)

        self.assertEqual(original, snapshot)

@parameterized.expand(
    list(
        itertools.product(
            [100],        # mask_param
            [0.0, 30.0],  # mask_value
            [2, 3],       # axis
        )
    )
)
def test_mask_along_axis_iid_preserve(self, mask_param, mask_value, axis):
    """Verify that ``mask_along_axis_iid`` leaves the Tensor passed to it unchanged.

    The check is repeated 5 times so that the probability of no masking
    occurring in any of the runs is bounded by 1e-10.
    See https://github.com/pytorch/audio/issues/1478
    """
    num_trials = 5
    torch.random.manual_seed(42)
    for _trial in range(num_trials):
        batch = torch.randn(4, 2, 1025, 400, dtype=self.dtype, device=self.device)
        snapshot = batch.clone()

        # The returned (masked) tensor is deliberately ignored: only the
        # state of the input after the call is under test.
        F.mask_along_axis_iid(batch, mask_param, mask_value, axis)

        self.assertEqual(batch, snapshot)


class FunctionalComplex(TestBaseMixin):
complex_dtype = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,20 @@


class Functional(TempDirMixin, TestBaseMixin):
"""Implements test for `functinoal` modul that are performed for different devices"""
"""Implements test for `functional` module that are performed for different devices"""
def _assert_consistency(self, func, tensor, shape_only=False):
tensor = tensor.to(device=self.device, dtype=self.dtype)

path = self.get_temp_path('func.zip')
torch.jit.script(func).save(path)
ts_func = torch.jit.load(path)

torch.random.manual_seed(40)
output = func(tensor)

torch.random.manual_seed(40)
ts_output = ts_func(tensor)

if shape_only:
ts_output = ts_output.shape
output = output.shape
Expand Down
17 changes: 9 additions & 8 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ def mask_along_axis_iid(

# Per batch example masking
specgrams = specgrams.transpose(axis, -1)
specgrams.masked_fill_((mask >= mask_start) & (mask < mask_end), mask_value)
specgrams = specgrams.masked_fill((mask >= mask_start) & (mask < mask_end), mask_value)
specgrams = specgrams.transpose(axis, -1)

return specgrams
Expand All @@ -772,24 +772,25 @@ def mask_along_axis(
Returns:
Tensor: Masked spectrogram of dimensions (channel, freq, time)
"""
if axis != 1 and axis != 2:
raise ValueError('Only Frequency and Time masking are supported')

# pack batch
shape = specgram.size()
specgram = specgram.reshape([-1] + list(shape[-2:]))

value = torch.rand(1) * mask_param
min_value = torch.rand(1) * (specgram.size(axis) - value)

mask_start = (min_value.long()).squeeze()
mask_end = (min_value.long() + value.long()).squeeze()
mask = torch.arange(0, specgram.shape[axis], device=specgram.device, dtype=specgram.dtype)
mask = (mask >= mask_start) & (mask < mask_end)
if axis == 1:
mask = mask.unsqueeze(-1)

assert mask_end - mask_start < mask_param
if axis == 1:
specgram[:, mask_start:mask_end] = mask_value
elif axis == 2:
specgram[:, :, mask_start:mask_end] = mask_value
else:
raise ValueError('Only Frequency and Time masking are supported')

specgram = specgram.masked_fill(mask, mask_value)

# unpack batch
specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:])
Expand Down

0 comments on commit 7fd5fce

Please sign in to comment.