fix the issue described by #108223: validate that hidden_size and num_layers are greater than zero when constructing RNN, LSTM, and GRU modules
FFFrog committed Sep 8, 2023
1 parent b193f29 commit f03fd5b
Showing 4 changed files with 54 additions and 0 deletions.
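In short, the commit adds construction-time validation that hidden_size and num_layers are positive for the RNN family, in both the C++ and Python frontends, along with tests on both sides. A quick sketch of the post-commit behavior from the Python side (a minimal illustration, assuming a build that includes this commit; the messages match the checks added below):

import torch.nn as nn

for args, kwargs in [((10, 0), {}), ((10, 10), {"num_layers": 0})]:
    try:
        nn.RNN(*args, **kwargs)
    except ValueError as e:
        print(e)
# hidden_size must be greater than zero
# num_layers must be greater than zero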
24 changes: 24 additions & 0 deletions test/cpp/api/rnn.cpp
@@ -773,3 +773,27 @@ TEST_F(RNNTest, UsePackedSequenceAsInput) {
      std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
  }
}

TEST_F(RNNTest, CheckErrorInfos) {
  {
    auto options = torch::nn::RNNOptions(1, 0).num_layers(1);
    ASSERT_THROWS_WITH(RNN(options), "hidden_size must be greater than zero");

    options = torch::nn::RNNOptions(1, 1).num_layers(0);
    ASSERT_THROWS_WITH(RNN(options), "num_layers must be greater than zero");
  }
  {
    auto options = torch::nn::LSTMOptions(1, 0).num_layers(1);
    ASSERT_THROWS_WITH(LSTM(options), "hidden_size must be greater than zero");

    options = torch::nn::LSTMOptions(1, 1).num_layers(0);
    ASSERT_THROWS_WITH(LSTM(options), "num_layers must be greater than zero");
  }
  {
    auto options = torch::nn::GRUOptions(1, 0).num_layers(1);
    ASSERT_THROWS_WITH(GRU(options), "hidden_size must be greater than zero");

    options = torch::nn::GRUOptions(1, 1).num_layers(0);
    ASSERT_THROWS_WITH(GRU(options), "num_layers must be greater than zero");
  }
}
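The Python frontend enforces the same contract (see the torch/nn/modules/rnn.py hunk below) but raises ValueError rather than a c10::Error, and since LSTM and GRU route through RNNBase.__init__, one set of checks covers all three module types. A minimal Python mirror of the C++ test above, assuming a build that includes this commit:

import torch.nn as nn

cases = [
    (dict(hidden_size=0, num_layers=1), "hidden_size must be greater than zero"),
    (dict(hidden_size=1, num_layers=0), "num_layers must be greater than zero"),
]
for cls in (nn.RNN, nn.LSTM, nn.GRU):
    for kwargs, expected in cases:
        try:
            cls(input_size=1, **kwargs)
        except ValueError as e:
            assert expected in str(e), str(e)
        else:
            raise AssertionError(f"{cls.__name__} accepted {kwargs}")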
8 changes: 8 additions & 0 deletions torch/csrc/api/src/nn/modules/rnn.cpp
@@ -80,6 +80,14 @@ void RNNImplBase<Derived>::reset() {
        options_base.num_layers());
  }

  TORCH_CHECK(
      options_base.hidden_size() > 0,
      "hidden_size must be greater than zero");

  TORCH_CHECK(
      options_base.num_layers() > 0,
      "num_layers must be greater than zero");

  TORCH_CHECK(
      0 <= options_base.proj_size() &&
          options_base.proj_size() < options_base.hidden_size(),
2 changes: 2 additions & 0 deletions torch/nn/modules/rnn.py
@@ -88,6 +88,8 @@ def __init__(self, mode: str, input_size: int, hidden_size: int,
            raise TypeError(f"hidden_size should be of type int, got: {type(hidden_size).__name__}")
        if hidden_size <= 0:
            raise ValueError("hidden_size must be greater than zero")
        if num_layers <= 0:
            raise ValueError("num_layers must be greater than zero")
        if proj_size < 0:
            raise ValueError("proj_size should be a positive integer or zero to disable projections")
        if proj_size >= hidden_size:
20 changes: 20 additions & 0 deletions torch/testing/_internal/common_modules.py
@@ -2524,6 +2524,23 @@ def module_error_inputs_torch_nn_LSTMCell(module_info, device, dtype, requires_grad, training, **kwargs):
    return samples


def module_error_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    samples = [
        ErrorModuleInput(
            ModuleInput(constructor_input=FunctionInput(10, 0, 1)),
            error_on=ModuleErrorEnum.FORWARD_ERROR,
            error_type=RuntimeError,
            error_regex="hidden_size must be greater than zero"
        ),
        ErrorModuleInput(
            ModuleInput(constructor_input=FunctionInput(10, 10, 0)),
            error_on=ModuleErrorEnum.FORWARD_ERROR,
            error_type=RuntimeError,
            error_regex="num_layers must be greater than zero"
        ),
    ]
    return samples
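Each ErrorModuleInput above bundles constructor arguments with the exception the shared module-test machinery should observe; FunctionInput(10, 0, 1) maps positionally to (input_size, hidden_size, num_layers). A rough sketch of how such a sample is consumed (the function below is illustrative, not the actual harness, which lives in test/test_modules.py; Exception is used in the demo call because the Python constructor raises ValueError, whereas TORCH_CHECK failures coming from C++ surface as RuntimeError):

import re
import torch

def run_error_sample(module_cls, ctor_args, error_type, error_regex):
    # Illustrative consumer: construct the module and verify that the
    # expected exception type and message pattern show up.
    try:
        module_cls(*ctor_args)
    except error_type as e:
        assert re.search(error_regex, str(e)), str(e)
    else:
        raise AssertionError(f"expected {error_type.__name__} to be raised")

# FunctionInput(10, 10, 0) -> RNN(input_size=10, hidden_size=10, num_layers=0)
run_error_sample(torch.nn.RNN, (10, 10, 0), Exception, "num_layers must be greater than zero")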

def module_error_inputs_torch_nn_Pad1d(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -3441,19 +3458,22 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad, training, **kwargs):
    ModuleInfo(torch.nn.RNN,
               train_and_eval_differ=True,
               module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU, is_rnn=True),
               module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU,
               skips=(
                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),),
               decorators=rnn_gru_lstm_module_info_decorators
               ),
    ModuleInfo(torch.nn.GRU,
               train_and_eval_differ=True,
               module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU, is_rnn=False),
               module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU,
               skips=(
                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),),
               decorators=rnn_gru_lstm_module_info_decorators),
    ModuleInfo(torch.nn.LSTM,
               train_and_eval_differ=True,
               module_inputs_func=module_inputs_torch_nn_LSTM,
               module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU,
               skips=(
                   # LSTM with projections is not currently supported with MPS
                   DecorateInfo(skipMPS),),
