Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions aten/src/ATen/core/dispatch/Dispatcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,26 +57,27 @@ OperatorHandle Dispatcher::findSchemaOrThrow(const char* name, const char* overl
return findSchema({name, overload_name}).value();
}

OperatorHandle Dispatcher::findOrRegisterSchema_(FunctionSchema&& schema, OperatorOptions&& options) {
OperatorHandle Dispatcher::findOrRegisterSchema_(FunctionSchema&& schema) {
const auto found = findSchema(schema.operator_name());
if (found != c10::nullopt) {
if (found->schema() != schema) {
TORCH_CHECK(false, "Tried to register multiple operators with the same name and the same overload name but different schemas: ", schema, " vs ", found->schema());
}
if (options.isDefaultAliasAnalysisKind()) {
if (schema.isDefaultAliasAnalysisKind()) {
// just do nothing and let it pass.
} else if (found->options().isDefaultAliasAnalysisKind()) {
found->operatorIterator_->op.updateOptionsAliasAnalysis(options.aliasAnalysis());
} else if (found->schema().isDefaultAliasAnalysisKind()) {
found->operatorIterator_->op.updateSchemaAliasAnalysis(schema.aliasAnalysis());
} else {
// TODO: This error message is crappy
TORCH_CHECK(
found->options() == options,
"Tried to register multiple operators with the same schema but different options: ", toString(schema));
found->schema().aliasAnalysis() == schema.aliasAnalysis(),
"Tried to register multiple operators with the same schema but different alias analysis kind: ", toString(schema));
}
return *found;
}

OperatorName op_name = schema.operator_name();
operators_.emplace_back(std::move(schema), std::move(options));
operators_.emplace_back(std::move(schema));
OperatorHandle handle(--operators_.end());
operatorLookupTable_.write([&] (ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) {
operatorLookupTable.emplace(op_name, handle);
Expand All @@ -85,13 +86,13 @@ OperatorHandle Dispatcher::findOrRegisterSchema_(FunctionSchema&& schema, Operat
return handle;
}

std::pair<RegistrationHandleRAII, OperatorHandle> Dispatcher::registerSchema(FunctionSchema schema, OperatorOptions options) {
std::pair<RegistrationHandleRAII, OperatorHandle> Dispatcher::registerSchema(FunctionSchema schema) {
// we need a lock to avoid concurrent writes
std::lock_guard<std::mutex> lock(mutex_);

OperatorName op_name = schema.operator_name();

auto op = findOrRegisterSchema_(std::move(schema), std::move(options));
auto op = findOrRegisterSchema_(std::move(schema));

++op.operatorIterator_->refcount;
if (1 == op.operatorIterator_->refcount) {
Expand Down
12 changes: 4 additions & 8 deletions aten/src/ATen/core/dispatch/Dispatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ class SchemaRegistrationHandleRAII;
class CAFFE2_API Dispatcher final {
private:
struct OperatorDef final {
explicit OperatorDef(FunctionSchema&& schema, OperatorOptions&& options)
: op(std::move(schema), std::move(options)), refcount(0) {}
explicit OperatorDef(FunctionSchema&& schema)
: op(std::move(schema)), refcount(0) {}

impl::OperatorEntry op;
size_t refcount;
Expand Down Expand Up @@ -116,7 +116,7 @@ class CAFFE2_API Dispatcher final {
* object that manages the lifetime of the registration. Once that
* object is destructed, the kernel will be deregistered.
*/
std::pair<RegistrationHandleRAII, OperatorHandle> registerSchema(FunctionSchema schema, OperatorOptions options);
std::pair<RegistrationHandleRAII, OperatorHandle> registerSchema(FunctionSchema schema);

/**
* Register a kernel to the dispatch table for an operator.
Expand Down Expand Up @@ -152,7 +152,7 @@ class CAFFE2_API Dispatcher final {
private:
Dispatcher();

OperatorHandle findOrRegisterSchema_(FunctionSchema&& schema, OperatorOptions&& options);
OperatorHandle findOrRegisterSchema_(FunctionSchema&& schema);

void deregisterSchema_(const OperatorHandle& op, const OperatorName& op_name);
void deregisterBackendFallbackKernel_(DispatchKey dispatchKey);
Expand Down Expand Up @@ -187,10 +187,6 @@ class CAFFE2_API OperatorHandle final {
return operatorIterator_->op.schema();
}

const OperatorOptions& options() const {
return operatorIterator_->op.options();
}

template<class Return, class... Args>
Return callUnboxed(Args... args) const {
return c10::Dispatcher::singleton().callUnboxed<Return, Args...>(*this, std::forward<Args>(args)...);
Expand Down
5 changes: 2 additions & 3 deletions aten/src/ATen/core/dispatch/OperatorEntry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,10 @@ namespace {
}
}

OperatorEntry::OperatorEntry(FunctionSchema&& schema, OperatorOptions&& options)
OperatorEntry::OperatorEntry(FunctionSchema&& schema)
: schema_(std::move(schema))
, dispatchTable_(schema_)
, kernels_()
, options_(std::move(options)) {
, kernels_() {
}

void OperatorEntry::prepareForDeregistration() {
Expand Down
13 changes: 3 additions & 10 deletions aten/src/ATen/core/dispatch/OperatorEntry.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace impl {
// and its dispatch table. This is not part of the public API.
class OperatorEntry final {
public:
explicit OperatorEntry(FunctionSchema&& schema, OperatorOptions&& options);
explicit OperatorEntry(FunctionSchema&& schema);

OperatorEntry(const OperatorEntry&) = delete;
OperatorEntry(OperatorEntry&&) noexcept = delete;
Expand All @@ -35,12 +35,8 @@ class OperatorEntry final {

RegistrationHandleRAII registerKernel(c10::optional<DispatchKey> dispatch_key, KernelFunction kernel);

const OperatorOptions& options() {
return options_;
}

void updateOptionsAliasAnalysis(AliasAnalysisKind a) {
options_.setAliasAnalysis(a);
void updateSchemaAliasAnalysis(AliasAnalysisKind a) {
schema_.setAliasAnalysis(a);
}

private:
Expand Down Expand Up @@ -84,9 +80,6 @@ class OperatorEntry final {
// currently not high-pri.
ska::flat_hash_map<c10::optional<DispatchKey>, std::list<KernelFunction>> kernels_;

// Some metadata about the operator
OperatorOptions options_;

std::mutex kernelsMutex_; // protects kernels_

// This function re-establishes the invariant that dispatchTable
Expand Down
31 changes: 0 additions & 31 deletions aten/src/ATen/core/dispatch/OperatorOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
#include <cstdint>

namespace c10 {
namespace impl {
class OperatorEntry;
}

enum class AliasAnalysisKind : uint8_t {
INTERNAL_SPECIAL_CASE,
Expand All @@ -30,32 +27,4 @@ inline const char* toString(AliasAnalysisKind aliasAnalysisKind) {
: "UNKNOWN";
}

struct OperatorOptions final {
public:
bool isDefaultAliasAnalysisKind() const {
return aliasAnalysisKind_ == c10::nullopt;
}

AliasAnalysisKind aliasAnalysis() const {
return !isDefaultAliasAnalysisKind()
? *aliasAnalysisKind_
: AliasAnalysisKind::CONSERVATIVE;
}

void setAliasAnalysis(AliasAnalysisKind v) {
aliasAnalysisKind_ = v;
}

friend bool operator==(const OperatorOptions& lhs, const OperatorOptions& rhs) {
return lhs.aliasAnalysisKind_ == rhs.aliasAnalysisKind_;
}

friend bool operator!=(const OperatorOptions& lhs, const OperatorOptions& rhs) {
return !(lhs == rhs);
}

private:
c10::optional<AliasAnalysisKind> aliasAnalysisKind_;
};

} // namespace c10
21 changes: 21 additions & 0 deletions aten/src/ATen/core/function_schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <ATen/core/ivalue.h>
#include <ATen/core/alias_info.h>
#include <ATen/core/operator_name.h>
#include <ATen/core/dispatch/OperatorOptions.h>
#include <unordered_map>

namespace c10 {
Expand Down Expand Up @@ -69,6 +70,7 @@ struct Argument {
bool is_inferred_type() const {
return is_inferred_type_;
}

std::string formatTypeMismatchMsg(const std::string& actual_type) const {
std::string inferred_type_hint;
if (is_inferred_type()) {
Expand Down Expand Up @@ -188,6 +190,13 @@ struct FunctionSchema {
// arguments are not checked by schema
bool is_vararg_;
bool is_varret_;

// if no alias information is directly specified, what kind of "default"
// alias information should we infer?
// NB: due to alias analysis kind merging, this may be nullopt. Eventually
// this should always be set no matter what
c10::optional<AliasAnalysisKind> alias_kind_;

void checkArg(const IValue& value, const Argument& argument, optional<size_t> pos) const;

void checkSchema() const {
Expand Down Expand Up @@ -301,6 +310,18 @@ struct FunctionSchema {
return false;
}


// TODO remove the mutation here
bool isDefaultAliasAnalysisKind() const {
return !alias_kind_;
}
AliasAnalysisKind aliasAnalysis() const {
return alias_kind_.value_or(AliasAnalysisKind::CONSERVATIVE);
}
void setAliasAnalysis(AliasAnalysisKind v) {
alias_kind_ = v;
}

// can a function with this schema be substituted for a function of rhs's
// schema and have the program typecheck?
// as_method - if true, treat this schema as a method and ignore
Expand Down
33 changes: 15 additions & 18 deletions aten/src/ATen/core/op_registration/op_registration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ static_assert(std::is_nothrow_move_assignable<c10::optional<RegistrationHandleRA
// table deregisters it in the destructor.
class RegisterOperators::OperatorRegistrar final {
public:
explicit OperatorRegistrar(FunctionSchema&& schema, OperatorOptions&& operatorOptions, c10::optional<DispatchKey> dispatch_key, c10::optional<KernelFunction> kernel)
: op_(Dispatcher::singleton().registerSchema(std::move(schema), std::move(operatorOptions))), kernel_registration_handle_(c10::nullopt) {
explicit OperatorRegistrar(FunctionSchema&& schema, c10::optional<DispatchKey> dispatch_key, c10::optional<KernelFunction> kernel)
: op_(Dispatcher::singleton().registerSchema(std::move(schema))), kernel_registration_handle_(c10::nullopt) {
if (kernel.has_value()) {
TORCH_INTERNAL_ASSERT(kernel->isValid());
kernel_registration_handle_ = Dispatcher::singleton().registerKernel(op_.second, dispatch_key, std::move(*kernel));
Expand Down Expand Up @@ -123,37 +123,34 @@ void RegisterOperators::checkNoDuplicateKernels_(const Options& options) {

void RegisterOperators::registerOp_(Options&& options) {
FunctionSchema schema = std::move(*options.schemaOrName_).right();
OperatorName op_name = schema.operator_name();

auto operatorOptions = makeOperatorOptions_(options);
// HACK: bong in the alias analysis kind from the legacy API directly
// into schema
if (options.aliasAnalysisKind_.has_value()) {
schema.setAliasAnalysis(*options.aliasAnalysisKind_);
}

OperatorName op_name = schema.operator_name();

if (0 == options.kernels.size()) {
registerSchemaOnly_(std::move(schema), std::move(operatorOptions));
registerSchemaOnly_(std::move(schema));
} else {
for (auto& kernel : options.kernels) {
registerSchemaAndKernel_(schema, std::move(kernel), std::move(operatorOptions));
registerSchemaAndKernel_(schema, std::move(kernel));
}
}

TORCH_INTERNAL_ASSERT(c10::Dispatcher::singleton().findSchema(op_name).has_value());
}

OperatorOptions RegisterOperators::makeOperatorOptions_(const RegisterOperators::Options& options) {
OperatorOptions result;
if (options.aliasAnalysisKind_.has_value()) {
result.setAliasAnalysis(*options.aliasAnalysisKind_);
}
return result;
}

void RegisterOperators::registerSchemaAndKernel_(FunctionSchema schema, Options::KernelRegistrationConfig&& kernel, OperatorOptions&& operatorOptions) {
void RegisterOperators::registerSchemaAndKernel_(FunctionSchema schema, Options::KernelRegistrationConfig&& kernel) {
TORCH_INTERNAL_ASSERT(kernel.func.isValid(), "Kernel must be set");

registrars_.emplace_back(std::move(schema), std::move(operatorOptions), kernel.dispatch_key, std::move(kernel.func));
registrars_.emplace_back(std::move(schema), kernel.dispatch_key, std::move(kernel.func));
}

void RegisterOperators::registerSchemaOnly_(FunctionSchema&& schema, OperatorOptions&& operatorOptions) {
registrars_.emplace_back(std::move(schema), std::move(operatorOptions), c10::nullopt, c10::nullopt);
void RegisterOperators::registerSchemaOnly_(FunctionSchema&& schema) {
registrars_.emplace_back(std::move(schema), c10::nullopt, c10::nullopt);
}

RegisterOperators::RegisterOperators() = default;
Expand Down
5 changes: 2 additions & 3 deletions aten/src/ATen/core/op_registration/op_registration.h
Original file line number Diff line number Diff line change
Expand Up @@ -582,9 +582,8 @@ class CAFFE2_API RegisterOperators final {
static c10::FunctionSchema inferSchemaFromKernels_(const OperatorName& opNameStr, const Options& options);
void checkNoDuplicateKernels_(const Options& options);
void registerOp_(Options&& options);
void registerSchemaAndKernel_(FunctionSchema schema, Options::KernelRegistrationConfig&& config, OperatorOptions&& options);
void registerSchemaOnly_(FunctionSchema&& schema, OperatorOptions&& options);
static OperatorOptions makeOperatorOptions_(const Options& options);
void registerSchemaAndKernel_(FunctionSchema schema, Options::KernelRegistrationConfig&& config);
void registerSchemaOnly_(FunctionSchema&& schema);

class OperatorRegistrar;

Expand Down
12 changes: 6 additions & 6 deletions aten/src/ATen/core/op_registration/op_registration_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ TEST(OperatorRegistrationTest, whenRegisteringSameSchemaWithAliasAnalysisAfterRe

auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
ASSERT_TRUE(op.has_value());
EXPECT_EQ(op->options().aliasAnalysis(), at::AliasAnalysisKind::PURE_FUNCTION);
EXPECT_EQ(op->schema().aliasAnalysis(), at::AliasAnalysisKind::PURE_FUNCTION);
}
{
auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<DummyKernel>(c10::DispatchKey::XLATensorId).aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION));
auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<DummyKernel>(c10::DispatchKey::CPUTensorId));

auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
ASSERT_TRUE(op.has_value());
EXPECT_EQ(op->options().aliasAnalysis(), at::AliasAnalysisKind::PURE_FUNCTION);
EXPECT_EQ(op->schema().aliasAnalysis(), at::AliasAnalysisKind::PURE_FUNCTION);
}
}

Expand All @@ -64,7 +64,7 @@ TEST(OperatorRegistrationTest, whenRegisteringSameSchemaWithSameAliasAnalysis_th

auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
ASSERT_TRUE(op.has_value());
EXPECT_EQ(op->options().aliasAnalysis(), at::AliasAnalysisKind::PURE_FUNCTION);
EXPECT_EQ(op->schema().aliasAnalysis(), at::AliasAnalysisKind::PURE_FUNCTION);
}

TEST(OperatorRegistrationTest, whenRegisteringSameSchemaWithNoAliasAnalysis_thenCanBeCalled) {
Expand All @@ -73,15 +73,15 @@ TEST(OperatorRegistrationTest, whenRegisteringSameSchemaWithNoAliasAnalysis_then

auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
ASSERT_TRUE(op.has_value());
EXPECT_TRUE(op->options().isDefaultAliasAnalysisKind());
EXPECT_EQ(op->options().aliasAnalysis(), at::AliasAnalysisKind::CONSERVATIVE);
EXPECT_TRUE(op->schema().isDefaultAliasAnalysisKind());
EXPECT_EQ(op->schema().aliasAnalysis(), at::AliasAnalysisKind::CONSERVATIVE);
}

TEST(OperatorRegistrationTest, whenRegisteringSameSchemaWithDifferentAliasAnalysis_thenShouldThrow) {
expectThrows<c10::Error>([] {
auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<DummyKernel>(c10::DispatchKey::CPUTensorId).aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION));
auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<DummyKernel>(c10::DispatchKey::XLATensorId).aliasAnalysis(at::AliasAnalysisKind::CONSERVATIVE));
}, "Tried to register multiple operators with the same schema but different options:");
}, "Tried to register multiple operators with the same schema but different alias analysis kind:");
}

TEST(OperatorRegistrationTest, whenRegisteringWithSchemaBeforeKernelInOptionsObject_thenCanBeCalled) {
Expand Down
6 changes: 2 additions & 4 deletions test/cpp/jit/test_alias_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
namespace torch {
namespace jit {

inline c10::OperatorOptions aliasAnalysisFromSchema() {
c10::OperatorOptions result;
result.setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA);
return result;
inline c10::AliasAnalysisKind aliasAnalysisFromSchema() {
return c10::AliasAnalysisKind::FROM_SCHEMA;
}

// Fixture to set up a graph and make assertions clearer
Expand Down
6 changes: 2 additions & 4 deletions test/cpp/jit/test_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@

namespace torch {
namespace jit {
inline c10::OperatorOptions aliasAnalysisFromSchema() {
c10::OperatorOptions result;
result.setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA);
return result;
inline c10::AliasAnalysisKind aliasAnalysisFromSchema() {
return c10::AliasAnalysisKind::FROM_SCHEMA;
}

RegisterOperators reg({
Expand Down
6 changes: 2 additions & 4 deletions test/cpp/jit/test_misc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,8 @@

namespace torch {
namespace jit {
inline c10::OperatorOptions aliasAnalysisFromSchema() {
c10::OperatorOptions result;
result.setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA);
return result;
inline c10::AliasAnalysisKind aliasAnalysisFromSchema() {
return c10::AliasAnalysisKind::FROM_SCHEMA;
}


Expand Down
Loading