From 9044dafb852b99c349120dd789e104abeffdb72a Mon Sep 17 00:00:00 2001
From: "Wang, Eikan"
Date: Wed, 24 Apr 2024 09:35:55 +0000
Subject: [PATCH] Expose tensor check from guard for reusing

[ghstack-poisoned]
---
 torch/csrc/dynamo/guards.cpp | 244 ++++++++++++++++-------------------
 torch/csrc/dynamo/guards.h   |  65 ++++++++++
 2 files changed, 178 insertions(+), 131 deletions(-)

diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp
index ac8fddfbd37b..929251d558e0 100644
--- a/torch/csrc/dynamo/guards.cpp
+++ b/torch/csrc/dynamo/guards.cpp
@@ -47,149 +47,131 @@ typedef struct {
   }                                         \
   self.insert_leaf_guard(name);
 
-namespace {
-
-struct LocalState {
-  // TLS state that changes operators
-  c10::impl::LocalDispatchKeySet dispatch_modifier;
-  bool grad_mode_enabled;
-
-  at::DispatchKeySet apply(at::DispatchKeySet ks) const {
-    return (ks | dispatch_modifier.included_) - dispatch_modifier.excluded_;
-  }
-
-  LocalState()
-      : dispatch_modifier(c10::impl::tls_local_dispatch_key_set()),
-        grad_mode_enabled(at::GradMode::is_enabled()) {}
-};
-
-class TensorCheck {
- public:
-  TensorCheck(
-      const LocalState& state,
-      PyTypeObject* pt,
-      const at::Tensor& v,
-      std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,
-      std::vector<std::optional<c10::SymInt>> dynamic_dims_strides)
-      : pytype(pt),
-        dispatch_key_(state.apply(v.key_set()).raw_repr()),
-        dtype_(v.dtype().toScalarType()),
-        device_index_(v.device().index()),
-        requires_grad_(v.requires_grad()),
-        sizes_(std::move(dynamic_dims_sizes)),
-        strides_(std::move(dynamic_dims_strides)),
-        dim_(static_cast<int64_t>(sizes_.size())) {
-    // TODO(voz): In cases where sizes_ and strides_ are fully dynamic, should
-    // we just treat this as optional?
-  }
-
-  // See note in guards.py [Note - On Export Tensor Guards]
-  // Logic parallel to here must be maintained in python
-  bool check(const LocalState& state, const at::Tensor& v) {
-    if (dispatch_key_ != state.apply(v.key_set()).raw_repr() ||
-        dtype_ != v.dtype().toScalarType() ||
-        device_index_ != v.device().index() ||
-        requires_grad_ != v.requires_grad()) {
-      return false;
-    }
-    auto ndim = v.ndimension();
-    if (ndim != dim_) {
-      return false;
-    }
-    const auto& sizes = v.sym_sizes();
-    const auto& strides = v.sym_strides();
-    for (auto i : c10::irange(ndim)) {
-      auto known_size = sizes_[i];
-      auto known_stride = strides_[i];
-      if (known_size.has_value()) {
-        if (known_size.value() != sizes[i]) {
-          return false;
-        }
-      }
-      if (known_stride.has_value()) {
-        if (known_stride.value() != strides[i]) {
-          return false;
-        }
-      }
-    }
-    return true;
-  }
-
-  std::string check_verbose(
-      const LocalState& state,
-      const at::Tensor& v,
-      const std::string& tensor_name) {
-    std::stringstream fail_reason;
-    fail_reason << "tensor '" << tensor_name << "' ";
-    if (dispatch_key_ != state.apply(v.key_set()).raw_repr()) {
-      // return fmt::format("tensor dispatch key mismatch. expected {}, actual
-      // {}", dispatch_key_, state.apply(v.key_set()).raw_repr());
-      fail_reason << "dispatch key set mismatch. expected "
-                  << c10::DispatchKeySet(
-                         c10::DispatchKeySet::RAW, dispatch_key_)
-                  << ", actual " << state.apply(v.key_set());
-      return fail_reason.str();
-    } else if (dtype_ != v.dtype().toScalarType()) {
-      // return fmt::format("tensor dtype mismatch. expected {}, actual {}",
-      // dtype_, v.dtype().toScalarType());
-      fail_reason << "dtype mismatch. expected " << dtype_ << ", actual "
-                  << v.dtype().toScalarType();
-      return fail_reason.str();
-    } else if (device_index_ != v.device().index()) {
-      fail_reason
-          << "Tensor device index mismatch. Expected device index to be "
-          << device_index_ << ", actual " << v.device().index();
-      return fail_reason.str();
-    } else if (requires_grad_ != v.requires_grad()) {
-      // return fmt::format("tensor requires_grad mismatch. expected {}",
-      // requires_grad_);
-      fail_reason << "requires_grad mismatch. expected requires_grad="
-                  << requires_grad_;
-      return fail_reason.str();
-    }
-    auto ndim = v.ndimension();
-    if (ndim != dim_) {
-      // return fmt::format("tensor rank mismatch. expected {}, actual {}",
-      // sizes_.size(), ndim);
-      fail_reason << "rank mismatch. expected " << sizes_.size() << ", actual "
-                  << ndim;
-      return fail_reason.str();
-    }
-    const auto& sizes = v.sym_sizes();
-    const auto& strides = v.sym_strides();
-    for (auto i : c10::irange(ndim)) {
-      auto known_size = sizes_[i];
-      auto known_stride = strides_[i];
-      if (known_size.has_value() && (known_size.value() != sizes[i])) {
-        fail_reason << "size mismatch at index " << i << ". expected "
-                    << known_size.value() << ", actual " << sizes[i];
-        return fail_reason.str();
-      }
-      if (known_stride.has_value() && known_stride.value() != strides[i]) {
-        fail_reason << "stride mismatch at index " << i << ". expected "
-                    << known_stride.value() << ", actual " << strides[i];
-        return fail_reason.str();
-      }
-    }
-    return "";
-  }
-
-  PyTypeObject* pytype;
-
- private:
-  uint64_t dispatch_key_; // DispatchKeySet includes device/layout
-  at::ScalarType dtype_;
-  // Note(voz): While dispatch_key_ is sufficiently representative of a device
-  // In that keys are more granular AND device specific - they do not
-  // necessarily capture device indices correctly.
-  at::DeviceIndex device_index_;
-  bool requires_grad_;
-  // NB: These are unset if dynamic shapes is enabled.
-  std::vector<std::optional<c10::SymInt>> sizes_;
-  std::vector<std::optional<c10::SymInt>> strides_;
-  // Not strictly required for dense tensors, but nested tensors need it.
-  int64_t dim_;
-};
+TensorCheck::TensorCheck(
+    const LocalState& state,
+    PyTypeObject* pt,
+    const at::Tensor& v,
+    std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,
+    std::vector<std::optional<c10::SymInt>> dynamic_dims_strides)
+    : pytype(pt),
+      dispatch_key_(state.apply(v.key_set()).raw_repr()),
+      dtype_(v.dtype().toScalarType()),
+      device_index_(v.device().index()),
+      requires_grad_(v.requires_grad()),
+      sizes_(std::move(dynamic_dims_sizes)),
+      strides_(std::move(dynamic_dims_strides)),
+      dim_(static_cast<int64_t>(sizes_.size())) {
+  // TODO(voz): In cases where sizes_ and strides_ are fully dynamic, should
+  // we just treat this as optional?
+}
+
+TensorCheck::TensorCheck(
+    const LocalState& state,
+    PyTypeObject* pt,
+    uint64_t dispatch_key,
+    at::ScalarType dtype,
+    at::DeviceIndex device_index,
+    std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,
+    std::vector<std::optional<c10::SymInt>> dynamic_dims_strides)
+    : pytype(pt),
+      dispatch_key_(dispatch_key),
+      dtype_(dtype),
+      device_index_(device_index),
+      requires_grad_(false),
+      sizes_(std::move(dynamic_dims_sizes)),
+      strides_(std::move(dynamic_dims_strides)),
+      dim_(static_cast<int64_t>(sizes_.size())) {}
+
+// See note in guards.py [Note - On Export Tensor Guards]
+// Logic parallel to here must be maintained in python
+bool TensorCheck::check(const LocalState& state, const at::Tensor& v) {
+  if (dispatch_key_ != state.apply(v.key_set()).raw_repr() ||
+      dtype_ != v.dtype().toScalarType() ||
+      device_index_ != v.device().index() ||
+      requires_grad_ != v.requires_grad()) {
+    return false;
+  }
+  auto ndim = v.ndimension();
+  if (ndim != dim_) {
+    return false;
+  }
+  const auto& sizes = v.sym_sizes();
+  const auto& strides = v.sym_strides();
+  for (auto i : c10::irange(ndim)) {
+    auto known_size = sizes_[i];
+    auto known_stride = strides_[i];
+    if (known_size.has_value()) {
+      if (known_size.value() != sizes[i]) {
+        return false;
+      }
+    }
+    if (known_stride.has_value()) {
+      if (known_stride.value() != strides[i]) {
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
+std::string TensorCheck::check_verbose(
+    const LocalState& state,
+    const at::Tensor& v,
+    const std::string& tensor_name) {
+  std::stringstream fail_reason;
+  fail_reason << "tensor '" << tensor_name << "' ";
+  if (dispatch_key_ != state.apply(v.key_set()).raw_repr()) {
+    // return fmt::format("tensor dispatch key mismatch. expected {}, actual
+    // {}", dispatch_key_, state.apply(v.key_set()).raw_repr());
+    fail_reason << "dispatch key set mismatch. expected "
+                << c10::DispatchKeySet(c10::DispatchKeySet::RAW, dispatch_key_)
+                << ", actual " << state.apply(v.key_set());
+    return fail_reason.str();
+  } else if (dtype_ != v.dtype().toScalarType()) {
+    // return fmt::format("tensor dtype mismatch. expected {}, actual {}",
+    // dtype_, v.dtype().toScalarType());
+    fail_reason << "dtype mismatch. expected " << dtype_ << ", actual "
+                << v.dtype().toScalarType();
+    return fail_reason.str();
+  } else if (device_index_ != v.device().index()) {
+    fail_reason << "Tensor device index mismatch. Expected device index to be "
+                << device_index_ << ", actual " << v.device().index();
+    return fail_reason.str();
+  } else if (requires_grad_ != v.requires_grad()) {
+    // return fmt::format("tensor requires_grad mismatch. expected {}",
+    // requires_grad_);
+    fail_reason << "requires_grad mismatch. expected requires_grad="
+                << requires_grad_;
+    return fail_reason.str();
+  }
+  auto ndim = v.ndimension();
+  if (ndim != dim_) {
+    // return fmt::format("tensor rank mismatch. expected {}, actual {}",
+    // sizes_.size(), ndim);
+    fail_reason << "rank mismatch. expected " << sizes_.size() << ", actual "
+                << ndim;
+    return fail_reason.str();
+  }
+  const auto& sizes = v.sym_sizes();
+  const auto& strides = v.sym_strides();
+  for (auto i : c10::irange(ndim)) {
+    auto known_size = sizes_[i];
+    auto known_stride = strides_[i];
+    if (known_size.has_value() && (known_size.value() != sizes[i])) {
+      fail_reason << "size mismatch at index " << i << ". expected "
+                  << known_size.value() << ", actual " << sizes[i];
+      return fail_reason.str();
+    }
+    if (known_stride.has_value() && known_stride.value() != strides[i]) {
+      fail_reason << "stride mismatch at index " << i << ". expected "
+                  << known_stride.value() << ", actual " << strides[i];
+      return fail_reason.str();
+    }
+  }
+  return "";
+}
+
+namespace {
 
 typedef std::vector<TensorCheck> ChecksList;
 
diff --git a/torch/csrc/dynamo/guards.h b/torch/csrc/dynamo/guards.h
index f47a52775c83..0e13601a7118 100644
--- a/torch/csrc/dynamo/guards.h
+++ b/torch/csrc/dynamo/guards.h
@@ -1,4 +1,5 @@
 #pragma once
+#include <ATen/ATen.h>
 #include <torch/csrc/python_headers.h>
 #include <torch/csrc/utils/pybind.h>
 
@@ -8,3 +9,67 @@ PyObject* torch_c_dynamo_guards_init();
 // not visible there.
 void* convert_to_root_guard_manager(py::object root);
 bool run_root_guard_manager(void* root, PyObject* f_locals);
+
+struct LocalState {
+  // TLS state that changes operators
+  c10::impl::LocalDispatchKeySet dispatch_modifier;
+  c10::DispatchKeySet override_dispatch_key_set;
+  bool grad_mode_enabled;
+
+  at::DispatchKeySet apply(at::DispatchKeySet ks) const {
+    if (override_dispatch_key_set.empty()) {
+      return (ks | dispatch_modifier.included_) - dispatch_modifier.excluded_;
+    } else {
+      return override_dispatch_key_set;
+    }
+  }
+
+  LocalState()
+      : dispatch_modifier(c10::impl::tls_local_dispatch_key_set()),
+        grad_mode_enabled(at::GradMode::is_enabled()) {}
+
+  void overrideDispatchKeySet(c10::DispatchKeySet ks) {
+    override_dispatch_key_set = ks;
+  }
+};
+
+class TensorCheck {
+ public:
+  TensorCheck(
+      const LocalState& state,
+      PyTypeObject* pt,
+      const at::Tensor& v,
+      std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,
+      std::vector<std::optional<c10::SymInt>> dynamic_dims_strides);
+
+  TensorCheck(
+      const LocalState& state,
+      PyTypeObject* pt,
+      uint64_t dispatch_key,
+      at::ScalarType dtype,
+      at::DeviceIndex device_index,
+      std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,
+      std::vector<std::optional<c10::SymInt>> dynamic_dims_strides);
+
+  bool check(const LocalState& state, const at::Tensor& v);
+  std::string check_verbose(
+      const LocalState& state,
+      const at::Tensor& v,
+      const std::string& tensor_name);
+
+  PyTypeObject* pytype;
+
+ private:
+  uint64_t dispatch_key_; // DispatchKeySet includes device/layout
+  at::ScalarType dtype_;
+  // Note(voz): While dispatch_key_ is sufficiently representative of a device
+  // In that keys are more granular AND device specific - they do not
+  // necessarily capture device indices correctly.
+  at::DeviceIndex device_index_;
+  bool requires_grad_;
+  // NB: These are unset if dynamic shapes is enabled.
+  std::vector<std::optional<c10::SymInt>> sizes_;
+  std::vector<std::optional<c10::SymInt>> strides_;
+  // Not strictly required for dense tensors, but nested tensors need it.
+  int64_t dim_;
+};
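Below is a consumer-side sketch of the API this patch exposes, appended here as a reviewer note. It is illustrative only and not part of the change: reuse_tensor_check and make_metadata_check are hypothetical names, and the snippet assumes a translation unit that can include torch/csrc/dynamo/guards.h and link against the dynamo sources.

#include <ATen/ATen.h>
#include <optional>
#include <vector>
#include <torch/csrc/dynamo/guards.h>

// Re-create the classic guard: capture the properties of a reference tensor
// via the tensor-based constructor, then test another tensor against them.
// Nothing in check()/check_verbose() reads pytype, so a pure C++ caller that
// never compares Python types can pass nullptr for it.
bool reuse_tensor_check(const at::Tensor& reference, const at::Tensor& other) {
  LocalState state;

  // Record every dimension as static. Passing std::nullopt for an entry
  // would mark that dimension dynamic and exempt it from the check.
  std::vector<std::optional<c10::SymInt>> sizes;
  std::vector<std::optional<c10::SymInt>> strides;
  for (const auto& s : reference.sym_sizes()) {
    sizes.emplace_back(s);
  }
  for (const auto& s : reference.sym_strides()) {
    strides.emplace_back(s);
  }

  TensorCheck tc(
      state, /*pt=*/nullptr, reference, std::move(sizes), std::move(strides));
  return tc.check(state, other);
}

The second constructor added by this patch serves callers that only have metadata rather than a live tensor; note that it pins requires_grad_ to false and derives dim_ from the length of the size/stride vectors. LocalState::overrideDispatchKeySet lets such a caller fix the key set instead of deriving it from thread-local dispatch state:

// Metadata-only variant: build a check without materializing a reference
// tensor. The CPU key set is an arbitrary example, and the empty size/stride
// vectors mean only 0-dim tensors will pass the rank check.
TensorCheck make_metadata_check(at::ScalarType dtype, at::DeviceIndex index) {
  LocalState state;
  auto keys = c10::DispatchKeySet(c10::DispatchKey::CPU);
  state.overrideDispatchKeySet(keys);
  return TensorCheck(
      state,
      /*pt=*/nullptr,
      keys.raw_repr(),
      dtype,
      index,
      /*dynamic_dims_sizes=*/{},
      /*dynamic_dims_strides=*/{});
}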