Skip to content

Commit

Permalink
Simplify helper function (#514)
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Apr 6, 2020
1 parent 3695a0e commit bc1ffb1
Showing 1 changed file with 5 additions and 28 deletions.
33 changes: 5 additions & 28 deletions test/test_batch_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,7 @@
from common_utils import AudioBackendScope, BACKENDS


def _test_batch_shape(functional, tensor, *args, **kwargs):

kwargs_compare = {}
if 'atol' in kwargs:
atol = kwargs['atol']
del kwargs['atol']
kwargs_compare['atol'] = atol

if 'rtol' in kwargs:
rtol = kwargs['rtol']
del kwargs['rtol']
kwargs_compare['rtol'] = rtol

def _test_batch_shape(functional, tensor, *args, atol=1e-8, rtol=1e-5, **kwargs):
# Single then transform then batch

torch.random.manual_seed(42)
Expand All @@ -36,24 +24,13 @@ def _test_batch_shape(functional, tensor, *args, **kwargs):
computed = functional(tensors.clone(), *args, **kwargs)

assert expected.shape == computed.shape, (expected.shape, computed.shape)
assert torch.allclose(expected, computed, **kwargs_compare)
assert torch.allclose(expected, computed, atol=atol, rtol=rtol)

return tensors, expected


def _test_batch(functional, tensor, *args, **kwargs):
tensors, expected = _test_batch_shape(functional, tensor, *args, **kwargs)

kwargs_compare = {}
if 'atol' in kwargs:
atol = kwargs['atol']
del kwargs['atol']
kwargs_compare['atol'] = atol

if 'rtol' in kwargs:
rtol = kwargs['rtol']
del kwargs['rtol']
kwargs_compare['rtol'] = rtol
def _test_batch(functional, tensor, *args, atol=1e-8, rtol=1e-5, **kwargs):
tensors, expected = _test_batch_shape(functional, tensor, *args, atol=atol, rtol=rtol, **kwargs)

# 3-Batch then transform

Expand All @@ -67,7 +44,7 @@ def _test_batch(functional, tensor, *args, **kwargs):
computed = functional(tensors.clone(), *args, **kwargs)

assert expected.shape == computed.shape, (expected.shape, computed.shape)
assert torch.allclose(expected, computed, **kwargs_compare)
assert torch.allclose(expected, computed, atol=atol, rtol=rtol)


class TestFunctional(unittest.TestCase):
Expand Down

0 comments on commit bc1ffb1

Please sign in to comment.