diff --git a/aten/src/ATen/core/dispatch/Dispatcher.cpp b/aten/src/ATen/core/dispatch/Dispatcher.cpp index 4d0af26bd397b..3549359eb36bb 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.cpp +++ b/aten/src/ATen/core/dispatch/Dispatcher.cpp @@ -57,26 +57,27 @@ OperatorHandle Dispatcher::findSchemaOrThrow(const char* name, const char* overl return findSchema({name, overload_name}).value(); } -OperatorHandle Dispatcher::findOrRegisterSchema_(FunctionSchema&& schema, OperatorOptions&& options) { +OperatorHandle Dispatcher::findOrRegisterSchema_(FunctionSchema&& schema) { const auto found = findSchema(schema.operator_name()); if (found != c10::nullopt) { if (found->schema() != schema) { TORCH_CHECK(false, "Tried to register multiple operators with the same name and the same overload name but different schemas: ", schema, " vs ", found->schema()); } - if (options.isDefaultAliasAnalysisKind()) { + if (schema.isDefaultAliasAnalysisKind()) { // just do nothing and let it pass. - } else if (found->options().isDefaultAliasAnalysisKind()) { - found->operatorIterator_->op.updateOptionsAliasAnalysis(options.aliasAnalysis()); + } else if (found->schema().isDefaultAliasAnalysisKind()) { + found->operatorIterator_->op.updateSchemaAliasAnalysis(schema.aliasAnalysis()); } else { + // TODO: This error message is crappy TORCH_CHECK( - found->options() == options, - "Tried to register multiple operators with the same schema but different options: ", toString(schema)); + found->schema().aliasAnalysis() == schema.aliasAnalysis(), + "Tried to register multiple operators with the same schema but different alias analysis kind: ", toString(schema)); } return *found; } OperatorName op_name = schema.operator_name(); - operators_.emplace_back(std::move(schema), std::move(options)); + operators_.emplace_back(std::move(schema)); OperatorHandle handle(--operators_.end()); operatorLookupTable_.write([&] (ska::flat_hash_map& operatorLookupTable) { operatorLookupTable.emplace(op_name, handle); @@ -85,13 +86,13 @@ OperatorHandle Dispatcher::findOrRegisterSchema_(FunctionSchema&& schema, Operat return handle; } -std::pair Dispatcher::registerSchema(FunctionSchema schema, OperatorOptions options) { +std::pair Dispatcher::registerSchema(FunctionSchema schema) { // we need a lock to avoid concurrent writes std::lock_guard lock(mutex_); OperatorName op_name = schema.operator_name(); - auto op = findOrRegisterSchema_(std::move(schema), std::move(options)); + auto op = findOrRegisterSchema_(std::move(schema)); ++op.operatorIterator_->refcount; if (1 == op.operatorIterator_->refcount) { diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h index 7b08deb7ebf0a..0420aae63c57c 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.h +++ b/aten/src/ATen/core/dispatch/Dispatcher.h @@ -37,8 +37,8 @@ class SchemaRegistrationHandleRAII; class CAFFE2_API Dispatcher final { private: struct OperatorDef final { - explicit OperatorDef(FunctionSchema&& schema, OperatorOptions&& options) - : op(std::move(schema), std::move(options)), refcount(0) {} + explicit OperatorDef(FunctionSchema&& schema) + : op(std::move(schema)), refcount(0) {} impl::OperatorEntry op; size_t refcount; @@ -116,7 +116,7 @@ class CAFFE2_API Dispatcher final { * object that manages the lifetime of the registration. Once that * object is destructed, the kernel will be deregistered. */ - std::pair registerSchema(FunctionSchema schema, OperatorOptions options); + std::pair registerSchema(FunctionSchema schema); /** * Register a kernel to the dispatch table for an operator. @@ -152,7 +152,7 @@ class CAFFE2_API Dispatcher final { private: Dispatcher(); - OperatorHandle findOrRegisterSchema_(FunctionSchema&& schema, OperatorOptions&& options); + OperatorHandle findOrRegisterSchema_(FunctionSchema&& schema); void deregisterSchema_(const OperatorHandle& op, const OperatorName& op_name); void deregisterBackendFallbackKernel_(DispatchKey dispatchKey); @@ -187,10 +187,6 @@ class CAFFE2_API OperatorHandle final { return operatorIterator_->op.schema(); } - const OperatorOptions& options() const { - return operatorIterator_->op.options(); - } - template Return callUnboxed(Args... args) const { return c10::Dispatcher::singleton().callUnboxed(*this, std::forward(args)...); diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.cpp b/aten/src/ATen/core/dispatch/OperatorEntry.cpp index 0001133b473ce..f6fc1954878ec 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.cpp +++ b/aten/src/ATen/core/dispatch/OperatorEntry.cpp @@ -26,11 +26,10 @@ namespace { } } -OperatorEntry::OperatorEntry(FunctionSchema&& schema, OperatorOptions&& options) +OperatorEntry::OperatorEntry(FunctionSchema&& schema) : schema_(std::move(schema)) , dispatchTable_(schema_) -, kernels_() -, options_(std::move(options)) { +, kernels_() { } void OperatorEntry::prepareForDeregistration() { diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.h b/aten/src/ATen/core/dispatch/OperatorEntry.h index bce7a44c40bde..14b47916204c3 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.h +++ b/aten/src/ATen/core/dispatch/OperatorEntry.h @@ -16,7 +16,7 @@ namespace impl { // and its dispatch table. This is not part of the public API. class OperatorEntry final { public: - explicit OperatorEntry(FunctionSchema&& schema, OperatorOptions&& options); + explicit OperatorEntry(FunctionSchema&& schema); OperatorEntry(const OperatorEntry&) = delete; OperatorEntry(OperatorEntry&&) noexcept = delete; @@ -35,12 +35,8 @@ class OperatorEntry final { RegistrationHandleRAII registerKernel(c10::optional dispatch_key, KernelFunction kernel); - const OperatorOptions& options() { - return options_; - } - - void updateOptionsAliasAnalysis(AliasAnalysisKind a) { - options_.setAliasAnalysis(a); + void updateSchemaAliasAnalysis(AliasAnalysisKind a) { + schema_.setAliasAnalysis(a); } private: @@ -84,9 +80,6 @@ class OperatorEntry final { // currently not high-pri. ska::flat_hash_map, std::list> kernels_; - // Some metadata about the operator - OperatorOptions options_; - std::mutex kernelsMutex_; // protects kernels_ // This function re-establishes the invariant that dispatchTable diff --git a/aten/src/ATen/core/dispatch/OperatorOptions.h b/aten/src/ATen/core/dispatch/OperatorOptions.h index 0fe5eeafae560..5c87f93657ac1 100644 --- a/aten/src/ATen/core/dispatch/OperatorOptions.h +++ b/aten/src/ATen/core/dispatch/OperatorOptions.h @@ -3,9 +3,6 @@ #include namespace c10 { -namespace impl { -class OperatorEntry; -} enum class AliasAnalysisKind : uint8_t { INTERNAL_SPECIAL_CASE, @@ -30,32 +27,4 @@ inline const char* toString(AliasAnalysisKind aliasAnalysisKind) { : "UNKNOWN"; } -struct OperatorOptions final { -public: - bool isDefaultAliasAnalysisKind() const { - return aliasAnalysisKind_ == c10::nullopt; - } - - AliasAnalysisKind aliasAnalysis() const { - return !isDefaultAliasAnalysisKind() - ? *aliasAnalysisKind_ - : AliasAnalysisKind::CONSERVATIVE; - } - - void setAliasAnalysis(AliasAnalysisKind v) { - aliasAnalysisKind_ = v; - } - - friend bool operator==(const OperatorOptions& lhs, const OperatorOptions& rhs) { - return lhs.aliasAnalysisKind_ == rhs.aliasAnalysisKind_; - } - - friend bool operator!=(const OperatorOptions& lhs, const OperatorOptions& rhs) { - return !(lhs == rhs); - } - -private: - c10::optional aliasAnalysisKind_; -}; - } // namespace c10 diff --git a/aten/src/ATen/core/function_schema.h b/aten/src/ATen/core/function_schema.h index 82fc260683722..01a2c04df4325 100644 --- a/aten/src/ATen/core/function_schema.h +++ b/aten/src/ATen/core/function_schema.h @@ -6,6 +6,7 @@ #include #include #include +#include #include namespace c10 { @@ -69,6 +70,7 @@ struct Argument { bool is_inferred_type() const { return is_inferred_type_; } + std::string formatTypeMismatchMsg(const std::string& actual_type) const { std::string inferred_type_hint; if (is_inferred_type()) { @@ -188,6 +190,13 @@ struct FunctionSchema { // arguments are not checked by schema bool is_vararg_; bool is_varret_; + + // if no alias information is directly specified, what kind of "default" + // alias information should we infer? + // NB: due to alias analysis kind merging, this may be nullopt. Eventually + // this should always be set no matter what + c10::optional alias_kind_; + void checkArg(const IValue& value, const Argument& argument, optional pos) const; void checkSchema() const { @@ -301,6 +310,18 @@ struct FunctionSchema { return false; } + + // TODO remove the mutation here + bool isDefaultAliasAnalysisKind() const { + return !alias_kind_; + } + AliasAnalysisKind aliasAnalysis() const { + return alias_kind_.value_or(AliasAnalysisKind::CONSERVATIVE); + } + void setAliasAnalysis(AliasAnalysisKind v) { + alias_kind_ = v; + } + // can a function with this schema be substituted for a function of rhs's // schema and have the program typecheck? // as_method - if true, treat this schema as a method and ignore diff --git a/aten/src/ATen/core/op_registration/op_registration.cpp b/aten/src/ATen/core/op_registration/op_registration.cpp index 9e1c87afdab22..0c94a9f8f8242 100644 --- a/aten/src/ATen/core/op_registration/op_registration.cpp +++ b/aten/src/ATen/core/op_registration/op_registration.cpp @@ -12,8 +12,8 @@ static_assert(std::is_nothrow_move_assignable dispatch_key, c10::optional kernel) - : op_(Dispatcher::singleton().registerSchema(std::move(schema), std::move(operatorOptions))), kernel_registration_handle_(c10::nullopt) { + explicit OperatorRegistrar(FunctionSchema&& schema, c10::optional dispatch_key, c10::optional kernel) + : op_(Dispatcher::singleton().registerSchema(std::move(schema))), kernel_registration_handle_(c10::nullopt) { if (kernel.has_value()) { TORCH_INTERNAL_ASSERT(kernel->isValid()); kernel_registration_handle_ = Dispatcher::singleton().registerKernel(op_.second, dispatch_key, std::move(*kernel)); @@ -123,37 +123,34 @@ void RegisterOperators::checkNoDuplicateKernels_(const Options& options) { void RegisterOperators::registerOp_(Options&& options) { FunctionSchema schema = std::move(*options.schemaOrName_).right(); - OperatorName op_name = schema.operator_name(); - auto operatorOptions = makeOperatorOptions_(options); + // HACK: bong in the alias analysis kind from the legacy API directly + // into schema + if (options.aliasAnalysisKind_.has_value()) { + schema.setAliasAnalysis(*options.aliasAnalysisKind_); + } + + OperatorName op_name = schema.operator_name(); if (0 == options.kernels.size()) { - registerSchemaOnly_(std::move(schema), std::move(operatorOptions)); + registerSchemaOnly_(std::move(schema)); } else { for (auto& kernel : options.kernels) { - registerSchemaAndKernel_(schema, std::move(kernel), std::move(operatorOptions)); + registerSchemaAndKernel_(schema, std::move(kernel)); } } TORCH_INTERNAL_ASSERT(c10::Dispatcher::singleton().findSchema(op_name).has_value()); } -OperatorOptions RegisterOperators::makeOperatorOptions_(const RegisterOperators::Options& options) { - OperatorOptions result; - if (options.aliasAnalysisKind_.has_value()) { - result.setAliasAnalysis(*options.aliasAnalysisKind_); - } - return result; -} - -void RegisterOperators::registerSchemaAndKernel_(FunctionSchema schema, Options::KernelRegistrationConfig&& kernel, OperatorOptions&& operatorOptions) { +void RegisterOperators::registerSchemaAndKernel_(FunctionSchema schema, Options::KernelRegistrationConfig&& kernel) { TORCH_INTERNAL_ASSERT(kernel.func.isValid(), "Kernel must be set"); - registrars_.emplace_back(std::move(schema), std::move(operatorOptions), kernel.dispatch_key, std::move(kernel.func)); + registrars_.emplace_back(std::move(schema), kernel.dispatch_key, std::move(kernel.func)); } -void RegisterOperators::registerSchemaOnly_(FunctionSchema&& schema, OperatorOptions&& operatorOptions) { - registrars_.emplace_back(std::move(schema), std::move(operatorOptions), c10::nullopt, c10::nullopt); +void RegisterOperators::registerSchemaOnly_(FunctionSchema&& schema) { + registrars_.emplace_back(std::move(schema), c10::nullopt, c10::nullopt); } RegisterOperators::RegisterOperators() = default; diff --git a/aten/src/ATen/core/op_registration/op_registration.h b/aten/src/ATen/core/op_registration/op_registration.h index 83a62f12df114..c152b12c7d2e6 100644 --- a/aten/src/ATen/core/op_registration/op_registration.h +++ b/aten/src/ATen/core/op_registration/op_registration.h @@ -582,9 +582,8 @@ class CAFFE2_API RegisterOperators final { static c10::FunctionSchema inferSchemaFromKernels_(const OperatorName& opNameStr, const Options& options); void checkNoDuplicateKernels_(const Options& options); void registerOp_(Options&& options); - void registerSchemaAndKernel_(FunctionSchema schema, Options::KernelRegistrationConfig&& config, OperatorOptions&& options); - void registerSchemaOnly_(FunctionSchema&& schema, OperatorOptions&& options); - static OperatorOptions makeOperatorOptions_(const Options& options); + void registerSchemaAndKernel_(FunctionSchema schema, Options::KernelRegistrationConfig&& config); + void registerSchemaOnly_(FunctionSchema&& schema); class OperatorRegistrar; diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp index 42e36fad768cb..f282654f3bab3 100644 --- a/aten/src/ATen/core/op_registration/op_registration_test.cpp +++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp @@ -46,7 +46,7 @@ TEST(OperatorRegistrationTest, whenRegisteringSameSchemaWithAliasAnalysisAfterRe auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); - EXPECT_EQ(op->options().aliasAnalysis(), at::AliasAnalysisKind::PURE_FUNCTION); + EXPECT_EQ(op->schema().aliasAnalysis(), at::AliasAnalysisKind::PURE_FUNCTION); } { auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::XLATensorId).aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)); @@ -54,7 +54,7 @@ TEST(OperatorRegistrationTest, whenRegisteringSameSchemaWithAliasAnalysisAfterRe auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); - EXPECT_EQ(op->options().aliasAnalysis(), at::AliasAnalysisKind::PURE_FUNCTION); + EXPECT_EQ(op->schema().aliasAnalysis(), at::AliasAnalysisKind::PURE_FUNCTION); } } @@ -64,7 +64,7 @@ TEST(OperatorRegistrationTest, whenRegisteringSameSchemaWithSameAliasAnalysis_th auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); - EXPECT_EQ(op->options().aliasAnalysis(), at::AliasAnalysisKind::PURE_FUNCTION); + EXPECT_EQ(op->schema().aliasAnalysis(), at::AliasAnalysisKind::PURE_FUNCTION); } TEST(OperatorRegistrationTest, whenRegisteringSameSchemaWithNoAliasAnalysis_thenCanBeCalled) { @@ -73,15 +73,15 @@ TEST(OperatorRegistrationTest, whenRegisteringSameSchemaWithNoAliasAnalysis_then auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); - EXPECT_TRUE(op->options().isDefaultAliasAnalysisKind()); - EXPECT_EQ(op->options().aliasAnalysis(), at::AliasAnalysisKind::CONSERVATIVE); + EXPECT_TRUE(op->schema().isDefaultAliasAnalysisKind()); + EXPECT_EQ(op->schema().aliasAnalysis(), at::AliasAnalysisKind::CONSERVATIVE); } TEST(OperatorRegistrationTest, whenRegisteringSameSchemaWithDifferentAliasAnalysis_thenShouldThrow) { expectThrows([] { auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId).aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)); auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::XLATensorId).aliasAnalysis(at::AliasAnalysisKind::CONSERVATIVE)); - }, "Tried to register multiple operators with the same schema but different options:"); + }, "Tried to register multiple operators with the same schema but different alias analysis kind:"); } TEST(OperatorRegistrationTest, whenRegisteringWithSchemaBeforeKernelInOptionsObject_thenCanBeCalled) { diff --git a/test/cpp/jit/test_alias_analysis.cpp b/test/cpp/jit/test_alias_analysis.cpp index e2ad5c8e4d9e0..65fb61eda7633 100644 --- a/test/cpp/jit/test_alias_analysis.cpp +++ b/test/cpp/jit/test_alias_analysis.cpp @@ -9,10 +9,8 @@ namespace torch { namespace jit { -inline c10::OperatorOptions aliasAnalysisFromSchema() { - c10::OperatorOptions result; - result.setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA); - return result; +inline c10::AliasAnalysisKind aliasAnalysisFromSchema() { + return c10::AliasAnalysisKind::FROM_SCHEMA; } // Fixture to set up a graph and make assertions clearer diff --git a/test/cpp/jit/test_base.cpp b/test/cpp/jit/test_base.cpp index 8655586695d58..e197e31199aee 100644 --- a/test/cpp/jit/test_base.cpp +++ b/test/cpp/jit/test_base.cpp @@ -5,10 +5,8 @@ namespace torch { namespace jit { -inline c10::OperatorOptions aliasAnalysisFromSchema() { - c10::OperatorOptions result; - result.setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA); - return result; +inline c10::AliasAnalysisKind aliasAnalysisFromSchema() { + return c10::AliasAnalysisKind::FROM_SCHEMA; } RegisterOperators reg({ diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp index 56351fe02f478..fd5feb631b25d 100644 --- a/test/cpp/jit/test_misc.cpp +++ b/test/cpp/jit/test_misc.cpp @@ -69,10 +69,8 @@ namespace torch { namespace jit { -inline c10::OperatorOptions aliasAnalysisFromSchema() { - c10::OperatorOptions result; - result.setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA); - return result; +inline c10::AliasAnalysisKind aliasAnalysisFromSchema() { + return c10::AliasAnalysisKind::FROM_SCHEMA; } diff --git a/test/cpp/jit/test_schema_matching.cpp b/test/cpp/jit/test_schema_matching.cpp index ecd3781f9f52d..157d52d6d577f 100644 --- a/test/cpp/jit/test_schema_matching.cpp +++ b/test/cpp/jit/test_schema_matching.cpp @@ -21,7 +21,7 @@ void testSchemaMatching() { pop(stack, list, a); push(stack, a); return 0; - }), + }, c10::AliasAnalysisKind::FROM_SCHEMA), }); script::Module m("m"); m.define(R"( @@ -57,7 +57,7 @@ void testSchemaMatching() { pop(stack, a, list); push(stack, a); return 0; - }), + }, AliasAnalysisKind::FROM_SCHEMA), }); script::Module m("m"); m.define(R"JIT( diff --git a/tools/jit/templates/register_aten_ops.cpp b/tools/jit/templates/register_aten_ops.cpp index 7633e66ee7f38..5cc2d32a994fa 100644 --- a/tools/jit/templates/register_aten_ops.cpp +++ b/tools/jit/templates/register_aten_ops.cpp @@ -80,10 +80,8 @@ std::array as_bool_array(const c10::List& list) { return res; } -c10::OperatorOptions atenOperatorOptions() { - c10::OperatorOptions result; - result.setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA); - return result; +c10::AliasAnalysisKind atenOperatorOptions() { + return c10::AliasAnalysisKind::FROM_SCHEMA; } int (*DUMMY_OPERATION)(Stack&) = [](Stack& stack) -> int { diff --git a/torch/csrc/jit/codegen/fuser/fallback.cpp b/torch/csrc/jit/codegen/fuser/fallback.cpp index 0c55ee0b10892..3f354213a8a9f 100644 --- a/torch/csrc/jit/codegen/fuser/fallback.cpp +++ b/torch/csrc/jit/codegen/fuser/fallback.cpp @@ -14,10 +14,8 @@ namespace jit { namespace fuser { namespace { -c10::OperatorOptions aliasAnalysisIsSpecialCase() { - c10::OperatorOptions options; - options.setAliasAnalysis(AliasAnalysisKind::INTERNAL_SPECIAL_CASE); - return options; +c10::AliasAnalysisKind aliasAnalysisIsSpecialCase() { + return AliasAnalysisKind::INTERNAL_SPECIAL_CASE; } } // namespace diff --git a/torch/csrc/jit/ir/constants.cpp b/torch/csrc/jit/ir/constants.cpp index 4b53eba96cf26..dfb052dbbcd64 100644 --- a/torch/csrc/jit/ir/constants.cpp +++ b/torch/csrc/jit/ir/constants.cpp @@ -9,10 +9,8 @@ namespace torch { namespace jit { namespace { -c10::OperatorOptions aliasAnalysisInternalSpecialCase() { - c10::OperatorOptions options; - options.setAliasAnalysis(AliasAnalysisKind::INTERNAL_SPECIAL_CASE); - return options; +c10::AliasAnalysisKind aliasAnalysisInternalSpecialCase() { + return AliasAnalysisKind::INTERNAL_SPECIAL_CASE; } } // namespace diff --git a/torch/csrc/jit/passes/batch_mm.cpp b/torch/csrc/jit/passes/batch_mm.cpp index 8e1c1b85aaaca..fa62fdbb3eec6 100644 --- a/torch/csrc/jit/passes/batch_mm.cpp +++ b/torch/csrc/jit/passes/batch_mm.cpp @@ -17,10 +17,8 @@ namespace torch { namespace jit { namespace { -c10::OperatorOptions aliasAnalysisIsSpecialCase() { - c10::OperatorOptions options; - options.setAliasAnalysis(AliasAnalysisKind::INTERNAL_SPECIAL_CASE); - return options; +c10::AliasAnalysisKind aliasAnalysisIsSpecialCase() { + return AliasAnalysisKind::INTERNAL_SPECIAL_CASE; } } // namespace diff --git a/torch/csrc/jit/passes/decompose_ops.cpp b/torch/csrc/jit/passes/decompose_ops.cpp index a33d15af87fe9..5b4c531417620 100644 --- a/torch/csrc/jit/passes/decompose_ops.cpp +++ b/torch/csrc/jit/passes/decompose_ops.cpp @@ -11,10 +11,8 @@ namespace torch { namespace jit { namespace { -c10::OperatorOptions aliasAnalysisFromSchema() { - c10::OperatorOptions result; - result.setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA); - return result; +c10::AliasAnalysisKind aliasAnalysisFromSchema() { + return c10::AliasAnalysisKind::FROM_SCHEMA; } } // namespace diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp index b2afc03830234..43771e4b2d35f 100644 --- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp +++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp @@ -128,10 +128,8 @@ Operation createTensorExprOp(const Node* node) { }; } -c10::OperatorOptions getAliasAnalysisOption(AliasAnalysisKind k) { - auto options = c10::OperatorOptions(); - options.setAliasAnalysis(k); - return options; +c10::AliasAnalysisKind getAliasAnalysisOption(AliasAnalysisKind k) { + return k; } RegisterOperators TensorExprOps({ diff --git a/torch/csrc/jit/python/python_interpreter.cpp b/torch/csrc/jit/python/python_interpreter.cpp index ecd345caa00c9..9f756108e69ff 100644 --- a/torch/csrc/jit/python/python_interpreter.cpp +++ b/torch/csrc/jit/python/python_interpreter.cpp @@ -72,10 +72,8 @@ Operation createPythonOperation(const Node* op_) { }; } -c10::OperatorOptions aliasAnalysisIsSpecialCase() { - c10::OperatorOptions options; - options.setAliasAnalysis(AliasAnalysisKind::INTERNAL_SPECIAL_CASE); - return options; +c10::AliasAnalysisKind aliasAnalysisIsSpecialCase() { + return AliasAnalysisKind::INTERNAL_SPECIAL_CASE; } RegisterOperators reg({Operator( diff --git a/torch/csrc/jit/runtime/graph_executor.cpp b/torch/csrc/jit/runtime/graph_executor.cpp index a42431294a5e5..2b9759cb2abee 100644 --- a/torch/csrc/jit/runtime/graph_executor.cpp +++ b/torch/csrc/jit/runtime/graph_executor.cpp @@ -53,10 +53,8 @@ namespace torch { namespace jit { namespace { -c10::OperatorOptions aliasAnalysisInternalSpecialCase() { - c10::OperatorOptions options; - options.setAliasAnalysis(AliasAnalysisKind::INTERNAL_SPECIAL_CASE); - return options; +c10::AliasAnalysisKind aliasAnalysisInternalSpecialCase() { + return AliasAnalysisKind::INTERNAL_SPECIAL_CASE; } } // namespace diff --git a/torch/csrc/jit/runtime/operator.h b/torch/csrc/jit/runtime/operator.h index 048987cded3f2..5b1f6c94df2ee 100644 --- a/torch/csrc/jit/runtime/operator.h +++ b/torch/csrc/jit/runtime/operator.h @@ -66,26 +66,25 @@ struct TORCH_API Operator { Operator(c10::OperatorHandle opHandle, Operation operation) : schema_(std::make_shared(opHandle.schema())), op_(std::make_shared(std::move(operation))), - c10Handle_(opHandle), - options_(c10Handle_->options()) {} + c10Handle_(opHandle) {} Operator( const std::string& schema, int(*op)(Stack&), - c10::OperatorOptions options = c10::OperatorOptions()) + c10::AliasAnalysisKind alias_analysis) : schema_string_(schema), - op_(std::make_shared(std::move(op))), - options_(std::move(options)) {} + alias_analysis_(alias_analysis), + op_(std::make_shared(std::move(op))) {} Operator( const std::string& schema, OperationCreator op_creator, - c10::OperatorOptions options = c10::OperatorOptions()) + c10::AliasAnalysisKind alias_analysis) : schema_string_(schema), - op_creator_(std::move(op_creator)), - options_(std::move(options)) {} + alias_analysis_(alias_analysis), + op_creator_(std::move(op_creator)) {} // Helper constructor to register `op` to run // run for _every_ IR Node where n.kind() == name, regardless of arguments. @@ -94,10 +93,11 @@ struct TORCH_API Operator { Operator( Symbol name, OperationCreator op_creator, - c10::OperatorOptions options = c10::OperatorOptions()) + c10::AliasAnalysisKind alias_analysis) : schema_(std::make_shared(varArgSchemaWithName(name))), - op_creator_(std::move(op_creator)), - options_(std::move(options)) {} + op_creator_(std::move(op_creator)) { + schema_->setAliasAnalysis(alias_analysis); + } Operation getOperation(const Node* node = nullptr) const { if (op_) { @@ -113,7 +113,11 @@ struct TORCH_API Operator { if (!schema_) { schema_ = std::make_shared(parseSchema(schema_string_.value())); + if (alias_analysis_.has_value()) { + schema_->setAliasAnalysis(*alias_analysis_); + } schema_string_ = c10::nullopt; + alias_analysis_ = c10::nullopt; } return *schema_; } @@ -126,13 +130,13 @@ struct TORCH_API Operator { if (isC10Op()) { const FunctionSchema& schemaRef = schema(); TORCH_CHECK( - options_.aliasAnalysis() == AliasAnalysisKind::FROM_SCHEMA || + schemaRef.aliasAnalysis() == AliasAnalysisKind::FROM_SCHEMA || !schemaRef.hasAnyAliasInfo(), "In operator registration: Tried to register operator ", schemaRef, " with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA."); } - return options_.aliasAnalysis(); + return schema().aliasAnalysis(); } bool hasOperation() const { return op_ != nullptr; @@ -152,13 +156,13 @@ struct TORCH_API Operator { // assignment operator to be generated cannot use std::unique_ptr because // initializer lists of Operators end up copying the Operator mutable std::shared_ptr schema_; + mutable c10::optional alias_analysis_; // Essentially a variant. // NB: std::function has a default state (where it == nullptr). std::shared_ptr op_; OperationCreator op_creator_; c10::optional c10Handle_; - c10::OperatorOptions options_; }; TORCH_API std::string canonicalSchemaString(const FunctionSchema& schema); diff --git a/torch/csrc/jit/runtime/register_distributed_ops.cpp b/torch/csrc/jit/runtime/register_distributed_ops.cpp index deb28bf8c453f..58d6777a4d25e 100644 --- a/torch/csrc/jit/runtime/register_distributed_ops.cpp +++ b/torch/csrc/jit/runtime/register_distributed_ops.cpp @@ -30,16 +30,12 @@ at::Tensor optional_to_tensor(c10::optional v) { return v.has_value() ? *v : at::Tensor(); } -c10::OperatorOptions aliasAnalysisFromSchema() { - c10::OperatorOptions result; - result.setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA); - return result; +c10::AliasAnalysisKind aliasAnalysisFromSchema() { + return c10::AliasAnalysisKind::FROM_SCHEMA; } -c10::OperatorOptions aliasAnalysisSpecialCase() { - c10::OperatorOptions result; - result.setAliasAnalysis(c10::AliasAnalysisKind::INTERNAL_SPECIAL_CASE); - return result; +c10::AliasAnalysisKind aliasAnalysisSpecialCase() { + return c10::AliasAnalysisKind::INTERNAL_SPECIAL_CASE; } RegisterOperators reg_rpc_ops({ diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp index 7513cbf28e28f..6d709c9a66f60 100644 --- a/torch/csrc/jit/runtime/register_prim_ops.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops.cpp @@ -65,22 +65,16 @@ int noop(Stack& n) { return 0; } -c10::OperatorOptions aliasAnalysisFromSchema() { - c10::OperatorOptions result; - result.setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA); - return result; +c10::AliasAnalysisKind aliasAnalysisFromSchema() { + return c10::AliasAnalysisKind::FROM_SCHEMA; } -c10::OperatorOptions aliasAnalysisConservative() { - c10::OperatorOptions result; - result.setAliasAnalysis(c10::AliasAnalysisKind::CONSERVATIVE); - return result; +c10::AliasAnalysisKind aliasAnalysisConservative() { + return c10::AliasAnalysisKind::CONSERVATIVE; } -c10::OperatorOptions aliasAnalysisSpecialCase() { - c10::OperatorOptions result; - result.setAliasAnalysis(c10::AliasAnalysisKind::INTERNAL_SPECIAL_CASE); - return result; +c10::AliasAnalysisKind aliasAnalysisSpecialCase() { + return c10::AliasAnalysisKind::INTERNAL_SPECIAL_CASE; } // using the rules from python_arg_parser FunctionParameter::check diff --git a/torch/csrc/jit/runtime/register_special_ops.cpp b/torch/csrc/jit/runtime/register_special_ops.cpp index 4efa7abb78c3d..cb4eae99b1920 100644 --- a/torch/csrc/jit/runtime/register_special_ops.cpp +++ b/torch/csrc/jit/runtime/register_special_ops.cpp @@ -21,10 +21,8 @@ namespace jit { namespace { -c10::OperatorOptions aliasAnalysisFromSchema() { - c10::OperatorOptions result; - result.setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA); - return result; +c10::AliasAnalysisKind aliasAnalysisFromSchema() { + return c10::AliasAnalysisKind::FROM_SCHEMA; } void checkListInputType(const c10::TypePtr& elem_type, bool empty_list) { diff --git a/torch/csrc/jit/runtime/register_string_ops.cpp b/torch/csrc/jit/runtime/register_string_ops.cpp index 3e244f241d219..3cd8aba542649 100644 --- a/torch/csrc/jit/runtime/register_string_ops.cpp +++ b/torch/csrc/jit/runtime/register_string_ops.cpp @@ -5,10 +5,8 @@ namespace torch { namespace jit { namespace { -c10::OperatorOptions aliasAnalysisFromSchema() { - c10::OperatorOptions result; - result.setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA); - return result; +c10::AliasAnalysisKind aliasAnalysisFromSchema() { + return c10::AliasAnalysisKind::FROM_SCHEMA; } // Convert an python index (which may be negative) into an index usable for a