Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion test/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from torch.testing._internal.common_cuda import with_tf32_off
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests, onlyCPU, onlyCUDA, toleranceOverride, tol, skipMeta)
from torch.testing._internal.common_modules import module_db, modules, TrainEvalMode
from torch.testing._internal.common_modules import module_db, modules, ModuleErrorEnum, TrainEvalMode
from torch.testing._internal.common_utils import (
TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck,
gradgradcheck)
Expand Down Expand Up @@ -766,6 +766,26 @@ def test_device_ctx_init(self, device, dtype, module_info, training):
assert_metadata_eq(self.assertEqual, p_meta, p_cpu)


@modules([module for module in module_db if module.module_error_inputs_func is not None])
def test_errors(self, device, dtype, module_info, training):
module_cls = module_info.module_cls
error_inputs = module_info.module_error_inputs_func(module_info, device=device, dtype=dtype,
requires_grad=False, training=training)
for error_input in error_inputs:
module_input = error_input.module_error_input
c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
if error_input.error_on == ModuleErrorEnum.CONSTRUCTION_ERROR:
with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
m = module_cls(*c_args, **c_kwargs)
elif error_input.error_on == ModuleErrorEnum.FORWARD_ERROR:
m = module_cls(*c_args, **c_kwargs)
fw_args, fw_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
m(*fw_args, **fw_kwargs)
else:
raise NotImplementedError(f"Unknown error type {error_input.error_on}")


instantiate_device_type_tests(TestModule, globals(), allow_mps=True)

if __name__ == '__main__':
Expand Down
16 changes: 0 additions & 16 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2876,22 +2876,6 @@ def test_RNN_cell(self):

hx.sum().backward()

def test_RNN_cell_forward_input_size(self):
input = torch.randn(3, 11)
hx = torch.randn(3, 20)
for module in (nn.RNNCell, nn.GRUCell):
cell = module(10, 20)
self.assertRaises(Exception, lambda: cell(input, hx))

def test_RNN_cell_forward_hidden_size(self):
input = torch.randn(3, 10)
hx = torch.randn(3, 21)
cell_shared_param = (10, 20)
for cell in (nn.RNNCell(*cell_shared_param, nonlinearity="relu"),
nn.RNNCell(*cell_shared_param, nonlinearity="tanh"),
nn.GRUCell(*cell_shared_param)):
self.assertRaises(Exception, lambda: cell(input, hx))

def test_RNN_cell_forward_zero_hidden_size(self):
input = torch.randn(3, 10)
hx = torch.randn(3, 0)
Expand Down
76 changes: 73 additions & 3 deletions torch/testing/_internal/common_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,30 @@ def copy_reference_fn(m, *args, **kwargs):

self.reference_fn = copy_reference_fn

class ModuleErrorEnum(Enum):
""" Enumerates when error is raised when testing modules. """
CONSTRUCTION_ERROR = 0
FORWARD_ERROR = 1

class ErrorModuleInput:
"""
A ModuleInput that will cause the operation to throw an error plus information
about the resulting error.
"""

__slots__ = ["module_error_input", "error_on", "error_type", "error_regex"]

def __init__(self,
module_error_input,
*,
error_on=ModuleErrorEnum.CONSTRUCTION_ERROR,
error_type=RuntimeError,
error_regex):
self.module_error_input = module_error_input
self.error_on = error_on
self.error_type = error_type
self.error_regex = error_regex


class ModuleInfo:
""" Module information to be used in testing. """
Expand All @@ -182,6 +206,7 @@ def __init__(self,
module_memformat_affects_out=False, # whether converting module to channels last will generate
# channels last output
train_and_eval_differ=False, # whether the module has differing behavior between train and eval
module_error_inputs_func=None, # Function to generate module inputs that error
):
self.module_cls = module_cls
self.module_inputs_func = module_inputs_func
Expand All @@ -191,6 +216,7 @@ def __init__(self,
self.gradcheck_nondet_tol = gradcheck_nondet_tol
self.module_memformat_affects_out = module_memformat_affects_out
self.train_and_eval_differ = train_and_eval_differ
self.module_error_inputs_func = module_error_inputs_func

def get_decorators(self, test_class, test_name, device, dtype, param_kwargs):
result = [set_single_threaded_if_parallel_tbb]
Expand All @@ -210,6 +236,7 @@ def name(self):
def formatted_name(self):
return self.name.replace('.', '_')

# Start of module inputs functions.

def module_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, training, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
Expand Down Expand Up @@ -2206,9 +2233,6 @@ def module_inputs_torch_nn_ConstantPad3d(module_info, device, dtype, requires_gr
),
]




# All these operators share similar issues on cuDNN and MIOpen
rnn_gru_lstm_module_info_decorators = (
# RuntimeError: Batching rule not implemented for aten::_cudnn_rnn_backward.
Expand Down Expand Up @@ -2243,6 +2267,50 @@ def module_inputs_torch_nn_ConstantPad3d(module_info, device, dtype, requires_gr
)
)

# Start of module error inputs functions.

def module_error_inputs_torch_nn_RNN_GRU_Cell(module_info, device, dtype, requires_grad, training, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
samples = [
ErrorModuleInput(
ModuleInput(
constructor_input=FunctionInput(10, 20),
forward_input=FunctionInput(make_input(3, 11), make_input(3, 20)),
),
error_on=ModuleErrorEnum.FORWARD_ERROR,
error_type=RuntimeError,
error_regex="input has inconsistent input_size: got 11 expected 10"
),
ErrorModuleInput(
ModuleInput(
constructor_input=FunctionInput(10, 20),
forward_input=FunctionInput(make_input(3, 10), make_input(3, 21)),
),
error_on=ModuleErrorEnum.FORWARD_ERROR,
error_type=RuntimeError,
error_regex="hidden0 has inconsistent hidden_size: got 21, expected 20"
),
ErrorModuleInput(
ModuleInput(
constructor_input=FunctionInput(10, 20, 'relu'),
forward_input=FunctionInput(make_input(3, 10), make_input(3, 21)),
),
error_on=ModuleErrorEnum.FORWARD_ERROR,
error_type=RuntimeError,
error_regex="hidden0 has inconsistent hidden_size: got 21, expected 20"
),
ErrorModuleInput(
ModuleInput(
constructor_input=FunctionInput(10, 20, 'tanh'),
forward_input=FunctionInput(make_input(3, 10), make_input(3, 21)),
),
error_on=ModuleErrorEnum.FORWARD_ERROR,
error_type=RuntimeError,
error_regex="hidden0 has inconsistent hidden_size: got 21, expected 20"
),
]
return samples

# Database of ModuleInfo entries in alphabetical order.
module_db: List[ModuleInfo] = [
ModuleInfo(torch.nn.AdaptiveAvgPool1d,
Expand Down Expand Up @@ -2912,11 +2980,13 @@ def module_inputs_torch_nn_ConstantPad3d(module_info, device, dtype, requires_gr
),
ModuleInfo(torch.nn.RNNCell,
module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU_Cell, is_rnn=True),
module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU_Cell,
skips=(
DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
),
ModuleInfo(torch.nn.GRUCell,
module_inputs_func=module_inputs_torch_nn_RNN_GRU_Cell,
module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU_Cell,
skips=(
DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
),
Expand Down