From 6de82f7d229e8d4d4973f544b05967db27747941 Mon Sep 17 00:00:00 2001
From: mikeiovine
Date: Tue, 30 Nov 2021 09:05:04 -0800
Subject: [PATCH] [SR] Improve set_inputs

This diff includes a variety of improvements to `set_inputs` that unify its behavior with `torch::jit::Module`:

1. Eliminate the code duplication between the rvalue/lvalue overloads
2. Add type checks
3. Make the input length check a `TORCH_CHECK` instead of a debug check - we have to fail when the wrong number of inputs is passed.
4. `schema` now always includes `self`, even if we release `module_`. This is consistent with `torch::jit::Module`.

The two sketches below illustrate the patterns behind points 1-3.
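For reviewers who want the shape of change (1) in isolation, here is a minimal, self-contained sketch of the forwarding pattern used in this diff. `IValue`, `inputs_`, `set_arg`, and `set_inputs` below are simplified stand-ins, not the real Static Runtime types:

```cpp
#include <cstddef>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Stand-in for c10::IValue; the real runtime writes into its input slots.
using IValue = std::string;
static std::vector<IValue> inputs_(2);

// Copying helper, chosen when the caller passes an lvalue argument list.
void set_arg(size_t idx, const std::vector<IValue>& args) {
  inputs_[idx] = args[idx];
}

// Moving helper, chosen when the caller passes an rvalue argument list.
void set_arg(size_t idx, std::vector<IValue>&& args) {
  inputs_[idx] = std::move(args[idx]);
}

// One forwarding template replaces the old lvalue/rvalue overload pair:
// std::forward preserves the value category of `args`, so overload
// resolution picks the copying or moving set_arg for each element.
template <typename IValueList>
void set_inputs(IValueList&& args) {
  for (size_t i = 0; i < args.size(); ++i) {
    set_arg(i, std::forward<IValueList>(args));
  }
}

int main() {
  std::vector<IValue> args = {"a", "b"};
  set_inputs(args);            // lvalue: binds the copying set_arg
  set_inputs(std::move(args)); // rvalue: binds the moving set_arg
  std::cout << inputs_[0] << inputs_[1] << "\n";
}
```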
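Changes (2) and (3) are easiest to see as the schema-driven binding loop that the new `set_inputs` implements. A rough sketch with toy types (the real code also type-checks each value via `check_type` and writes into `Input(i)` rather than returning a vector):

```cpp
#include <map>
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

// Simplified stand-ins for c10::Argument / c10::FunctionSchema.
struct Argument {
  std::string name;
  std::optional<int> default_value;
};

// Binds positional args first, then kwargs by name, then schema defaults.
// A missing required argument or a leftover (e.g. misspelled) kwarg is a
// hard error, mirroring the new TORCH_CHECK behavior.
std::vector<int> bind_inputs(
    const std::vector<Argument>& schema_args, // schema minus `self`
    const std::vector<int>& args,
    const std::map<std::string, int>& kwargs) {
  if (args.size() + kwargs.size() > schema_args.size()) {
    throw std::runtime_error("too many inputs");
  }
  std::vector<int> bound;
  size_t consumed_kwargs = 0;
  for (size_t i = 0; i < schema_args.size(); ++i) {
    if (i < args.size()) {
      bound.push_back(args[i]); // positional match
      continue;
    }
    auto it = kwargs.find(schema_args[i].name);
    if (it != kwargs.end()) {
      bound.push_back(it->second); // matched by name
      ++consumed_kwargs;
      continue;
    }
    if (schema_args[i].default_value) {
      bound.push_back(*schema_args[i].default_value); // schema default
      continue;
    }
    throw std::runtime_error("missing required arg " + schema_args[i].name);
  }
  if (consumed_kwargs != kwargs.size()) {
    throw std::runtime_error("unexpected kwarg");
  }
  return bound;
}
```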
Differential Revision: [D32711705](https://our.internmc.facebook.com/intern/diff/D32711705/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D32711705/)!

[ghstack-poisoned]
---
 .../static_runtime/test_static_module.cc |  64 +++++++
 torch/csrc/jit/runtime/static/impl.cpp   | 159 +++++++++---------
 torch/csrc/jit/runtime/static/impl.h     |  20 ++-
 3 files changed, 163 insertions(+), 80 deletions(-)

diff --git a/benchmarks/static_runtime/test_static_module.cc b/benchmarks/static_runtime/test_static_module.cc
index 975d733b9caa..93ca3206f1c9 100644
--- a/benchmarks/static_runtime/test_static_module.cc
+++ b/benchmarks/static_runtime/test_static_module.cc
@@ -1159,3 +1159,67 @@ TEST(ManagedTensorRanges, OverlappingLifetimesOutputs) {
   EXPECT_TRUE(ranges.lifetimesOverlap(b, output));
 }
+
+namespace {
+void testStaticModuleThrows(
+    const std::string& src,
+    const std::vector<c10::IValue>& args,
+    const std::unordered_map<std::string, c10::IValue>& kwargs) {
+  auto static_module = makeStaticModuleFromScript(src);
+  EXPECT_THROW(static_module(args, kwargs), c10::Error);
+}
+} // namespace
+
+TEST(StaticModule, IncorrectTypesPassed) {
+  const std::string args_bool_script = R"JIT(
+    def forward(self, x: bool):
+        return x
+  )JIT";
+  testStaticModuleThrows(args_bool_script, {at::randn({1})}, {});
+
+  const std::string args_tensor_script = R"JIT(
+    def forward(self, x: Tensor):
+        return x
+  )JIT";
+  testStaticModuleThrows(args_tensor_script, {false}, {});
+
+  const std::string kwargs_int_script = R"JIT(
+    def forward(self, x: bool = True):
+        return x
+  )JIT";
+  testStaticModuleThrows(kwargs_int_script, {}, {{"x", at::randn({1})}});
+
+  const std::string kwargs_tensor_script = R"JIT(
+    def forward(self, x: Tensor = torch.randn((1, ))):
+        return x
+  )JIT";
+  testStaticModuleThrows(kwargs_tensor_script, {}, {{"x", 1.0}});
+}
+
+TEST(StaticModule, TooManyArgs) {
+  const std::string args_src = R"JIT(
+    def forward(self, x: int):
+        return x
+  )JIT";
+  testStaticModuleThrows(args_src, {0, 1}, {});
+
+  const std::string kwargs_src = R"JIT(
+    def forward(self, x: int = 1):
+        return x
+  )JIT";
+  testStaticModuleThrows(kwargs_src, {}, {{"y", 0}, {"x", 1}});
+}
+
+TEST(StaticModule, NotEnoughArgs) {
+  const std::string args_src = R"JIT(
+    def forward(self, x: int):
+        return x
+  )JIT";
+  testStaticModuleThrows(args_src, {}, {});
+
+  const std::string kwargs_src = R"JIT(
+    def forward(self, *, x: int):
+        return x
+  )JIT";
+  testStaticModuleThrows(kwargs_src, {}, {});
+}
diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp
index 2e4c8cce13da..58b49d1f1c04 100644
--- a/torch/csrc/jit/runtime/static/impl.cpp
+++ b/torch/csrc/jit/runtime/static/impl.cpp
@@ -128,7 +128,7 @@ void OptimizeGraph(
 }
 
 // remove unused input 0 from graph
-bool RemoveSelfFromGraphInput(std::shared_ptr<Graph>& graph) {
+bool removeSelfFromGraphInput(std::shared_ptr<Graph>& graph) {
   if (graph->inputs().at(0)->type()->is_module()) {
     if (graph->inputs().at(0)->hasUses()) {
       return false;
@@ -138,13 +138,6 @@ bool removeSelfFromGraphInput(std::shared_ptr<Graph>& graph) {
   graph->eraseInput(0);
   return true;
 }
 
-// remove "self" from function schema
-c10::FunctionSchema RemoveSelfFromSchema(const c10::FunctionSchema& s) {
-  TORCH_CHECK(s.arguments().size() >= 1 && s.arguments()[0].name() == "self");
-  std::vector<Argument> args({s.arguments().begin() + 1, s.arguments().end()});
-  return s.cloneWithArguments(args);
-}
-
 std::vector<const Value*> valueVecFromFastSet(const FastSet<const Value*>& s) {
   std::vector<const Value*> result;
   result.reserve(s.size());
@@ -808,7 +801,8 @@ StaticModule::StaticModule(
     const StaticModuleOptions& opts)
     : opts_(opts),
       graph_(std::move(graph_and_module.first)),
-      module_(std::move(graph_and_module.second)) {
+      module_(std::move(graph_and_module.second)),
+      num_inputs_(graph_->inputs().size()) {
   // check opt flags
   if (opts.manage_output_tensors) {
     TORCH_CHECK(
@@ -824,11 +818,12 @@ StaticModule::StaticModule(
   // handle schema
   if (module_.has_value()) {
     Method method = module_->get_method("forward");
-    if (RemoveSelfFromGraphInput(graph_)) {
-      schema_ = RemoveSelfFromSchema(method.function().getSchema());
+    schema_ = method.function().getSchema();
+    const auto num_schema_args = schema_->arguments().size();
+    DCHECK(num_schema_args > 0);
+    if (removeSelfFromGraphInput(graph_)) {
       module_ = c10::nullopt;
-    } else {
-      schema_ = method.function().getSchema();
+      num_inputs_ = num_schema_args - 1;
     }
   }
@@ -1008,7 +1003,7 @@ size_t StaticModule::num_outputs() const {
 }
 
 size_t StaticModule::num_inputs() const {
-  return graph_->inputs().size();
+  return num_inputs_;
 }
 
 StaticRuntime& StaticModule::runtime() {
@@ -1071,81 +1066,93 @@ StaticRuntime::StaticRuntime(const StaticModule& sm)
 
 StaticRuntime::~StaticRuntime() = default;
 
-void StaticRuntime::set_inputs(
-    const std::vector<c10::IValue>& args,
-    const std::unordered_map<std::string, c10::IValue>& kwargs) {
-  if (!kwargs.empty()) {
-    // This is not ideal
-    TORCH_CHECK(
-        static_module_.schema(),
-        "Schema is not available. Consider creating the Static Runtime "
-        "with StaticModule(const torch::jit::Module& m) instead.");
-    std::vector<c10::IValue> stack;
-    stack.reserve(static_module_.num_inputs());
-    if (static_module_.first_input_is_self()) {
-      stack.emplace_back(static_module_.module()._ivalue());
-    }
-    stack.insert(stack.end(), args.begin(), args.end());
+void StaticRuntime::set_arg(const size_t idx, std::vector<IValue>&& args) {
+  const bool first_input_is_self = static_module_.first_input_is_self();
+  DCHECK(idx < args.size());
+  Input(idx + first_input_is_self) = std::move(args[idx]);
+}
 
-    static_module_.schema()->checkAndNormalizeInputs(stack, kwargs);
-    DCHECK_EQ(static_module_.num_inputs(), stack.size());
-    for (const auto i : c10::irange(stack.size())) {
-      Input(i) = std::move(stack[i]);
-    }
-  } else {
-    if (static_module_.first_input_is_self()) {
-      Input(0) = static_module_.module()._ivalue();
-      DCHECK_EQ(static_module_.num_inputs(), args.size() + 1);
-      for (const auto i : c10::irange(args.size())) {
-        Input(i + 1) = args[i];
-      }
-    } else {
-      DCHECK_EQ(static_module_.num_inputs(), args.size());
-      for (const auto i : c10::irange(args.size())) {
-        Input(i) = args[i];
-      }
-    }
+void StaticRuntime::set_arg(const size_t idx, const std::vector<IValue>& args) {
+  const bool first_input_is_self = static_module_.first_input_is_self();
+  DCHECK(idx < args.size());
+  Input(idx + first_input_is_self) = args[idx];
+}
+
+void StaticRuntime::set_arg(const size_t idx, const IValue& arg) {
+  const bool first_input_is_self = static_module_.first_input_is_self();
+  Input(idx + first_input_is_self) = arg;
+}
+
+namespace {
+void check_type(const Argument& schema_arg, const IValue& arg) {
+  // Fast path for the most common case
+  if (arg.isTensor() && schema_arg.type()->castRaw<TensorType>()) {
+    return;
   }
+  TORCH_CHECK(arg.type()->isSubtypeOf(schema_arg.type()));
 }
+} // namespace
 
+template <typename IValueList>
 void StaticRuntime::set_inputs(
-    std::vector<c10::IValue>&& args,
+    IValueList&& args,
     const std::unordered_map<std::string, c10::IValue>& kwargs) {
-  if (!kwargs.empty()) {
-    // This is not ideal
+  const bool first_input_is_self = static_module_.first_input_is_self();
+  const auto total_num_inputs =
+      args.size() + kwargs.size() + first_input_is_self;
+  TORCH_CHECK(total_num_inputs == static_module_.num_inputs());
+
+  const auto& schema = static_module_.schema();
+  if (first_input_is_self) {
+    Input(0) = static_module_.module()._ivalue();
+  }
+
+  if (C10_UNLIKELY(!schema)) {
     TORCH_CHECK(
-        static_module_.schema(),
-        "Schema is not available. Consider creating the Static Runtime "
+        kwargs.empty(),
+        "Schema is not available, but StaticRuntime got kwargs. "
+        "Consider creating the Static Runtime instance "
         "with StaticModule(const torch::jit::Module& m) instead.");
-    std::vector<c10::IValue> stack;
-    stack.reserve(static_module_.num_inputs());
-    if (static_module_.first_input_is_self()) {
-      stack.emplace_back(static_module_.module()._ivalue());
+    for (size_t i = 0; i < args.size(); ++i) {
+      set_arg(i, std::forward<IValueList>(args));
     }
-    stack.insert(
-        stack.end(),
-        std::make_move_iterator(args.begin()),
-        std::make_move_iterator(args.end()));
+    return;
+  }
+
+  const auto& schema_args = schema->arguments();
+  size_t consumed_kwargs = 0;
+  DCHECK(schema_args.size() > 0);
 
-    static_module_.schema()->checkAndNormalizeInputs(stack, kwargs);
-    DCHECK_EQ(static_module_.num_inputs(), stack.size());
-    for (const auto i : c10::irange(stack.size())) {
-      Input(i) = std::move(stack[i]);
+  for (size_t i = 0; i < schema_args.size() - 1; ++i) {
+    // Start at 1 since the schema always contains `self`.
+    const auto& schema_arg = schema_args[i + 1];
+
+    if (i < args.size()) {
+      check_type(schema_arg, args[i]);
+      set_arg(i, std::forward<IValueList>(args));
+      continue;
     }
-  } else {
-    if (static_module_.first_input_is_self()) {
-      Input(0) = static_module_.module()._ivalue();
-      DCHECK_EQ(static_module_.num_inputs(), args.size() + 1);
-      for (const auto i : c10::irange(args.size())) {
-        Input(i + 1) = std::move(args[i]);
-      }
-    } else {
-      DCHECK_EQ(static_module_.num_inputs(), args.size());
-      for (const auto i : c10::irange(args.size())) {
-        Input(i) = std::move(args[i]);
-      }
+
+    auto it = kwargs.find(schema_arg.name());
+    if (it != kwargs.end()) {
+      check_type(schema_arg, it->second);
+      set_arg(i, it->second);
+      ++consumed_kwargs;
+      continue;
     }
+
+    auto maybe_default_val = schema_arg.default_value();
+    if (maybe_default_val) {
+      set_arg(i, *maybe_default_val);
+      continue;
+    }
+
+    TORCH_CHECK(
+        false, "Static runtime is missing required kwarg ", schema_arg.name());
   }
+  TORCH_CHECK(
+      consumed_kwargs == kwargs.size() &&
+          args.size() + consumed_kwargs == schema_args.size() - 1);
 }
 
 void StaticRuntime::create_memory_planner() {
diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h
index 252abd48ae14..d2b3c98d76fc 100644
--- a/torch/csrc/jit/runtime/static/impl.h
+++ b/torch/csrc/jit/runtime/static/impl.h
@@ -332,6 +332,12 @@ class TORCH_API StaticModule {
       value_to_same_storage_values_;
 
   FastSet<const Node*> node_is_optimizable_container_type_;
+
+  // Includes self if module_ != nullopt.
+  // Note that we might have num_inputs_ == 0 even if the schema has a `self`
+  // argument. In this case, `self` isn't used in the graph, but the schema
+  // includes it anyway to be consistent with the JIT interpreter.
+  size_t num_inputs_;
 };
 
 class TORCH_API StaticRuntime {
@@ -446,13 +452,19 @@ class TORCH_API StaticRuntime {
      const std::unordered_map<std::string, c10::IValue>& kwargs);
 
   // helper method for copying input args/kwargs into inputs_
+  template <typename IValueList>
   void set_inputs(
-      const std::vector<c10::IValue>& args,
-      const std::unordered_map<std::string, c10::IValue>& kwargs);
-  void set_inputs(
-      std::vector<c10::IValue>&& args,
+      IValueList&& args,
       const std::unordered_map<std::string, c10::IValue>& kwargs);
 
+  // Set Input(idx) to args[idx]. Invoked by set_inputs. Copies or moves
+  // depending on the overload.
+  void set_arg(const size_t idx, std::vector<IValue>&& args);
+  void set_arg(const size_t idx, const std::vector<IValue>& args);
+
+  // Set Input(idx) to arg. Always copies. Used for kwargs.
+  void set_arg(const size_t idx, const IValue& arg);
+
   void verify_and_correct_memory_overlap(ProcessedNode& n);
 
   // clean up owning refs of input IValues