diff --git a/torch/csrc/dynamo/cache_entry.cpp b/torch/csrc/dynamo/cache_entry.cpp
index 8a07f595f81d..ee783c710a0d 100644
--- a/torch/csrc/dynamo/cache_entry.cpp
+++ b/torch/csrc/dynamo/cache_entry.cpp
@@ -10,7 +10,8 @@ CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend) {
   this->backend = backend;
   // TODO - clean this up when enable_cpp_guard_manager is True by default
   if (py::hasattr(this->check_fn, "root")) {
-    this->root_mgr = convert_to_root_guard_manager(this->check_fn.attr("root"));
+    this->root_mgr = torch::dynamo::convert_to_root_guard_manager(
+        this->check_fn.attr("root"));
   }
 }
diff --git a/torch/csrc/dynamo/extra_state.cpp b/torch/csrc/dynamo/extra_state.cpp
index 2e4110f42930..7c9b4be0009b 100644
--- a/torch/csrc/dynamo/extra_state.cpp
+++ b/torch/csrc/dynamo/extra_state.cpp
@@ -97,7 +97,8 @@ PyObject* lookup(
     // TODO(anijain2305) - Clean this up when enable_cpp_guard_manager is
     // True by default
     if (cache_entry.root_mgr != nullptr) {
-      valid = run_root_guard_manager(cache_entry.root_mgr, f_locals);
+      valid = torch::dynamo::run_root_guard_manager(
+          cache_entry.root_mgr, f_locals);
     } else {
       valid = cache_entry.check_fn(locals).cast<bool>();
     }
diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp
index d61ac4219a85..cedaad3e3557 100644
--- a/torch/csrc/dynamo/guards.cpp
+++ b/torch/csrc/dynamo/guards.cpp
@@ -40,6 +40,8 @@ typedef struct {
 
 #endif // IS_PYTHON_3_12_PLUS
 
+namespace torch::dynamo {
+
 // Macro to skip addition of duplicate guards like EQUALS_MATCH
 #define SKIP_IF_GUARD_ALREADY_PRESENT(name) \
   if (self.is_leaf_guard_present(name)) {   \
@@ -47,149 +49,131 @@ typedef struct {
   }                                         \
   self.insert_leaf_guard(name);
 
-namespace {
-
-struct LocalState {
-  // TLS state that changes operators
-  c10::impl::LocalDispatchKeySet dispatch_modifier;
-  bool grad_mode_enabled;
-
-  at::DispatchKeySet apply(at::DispatchKeySet ks) const {
-    return (ks | dispatch_modifier.included_) - dispatch_modifier.excluded_;
-  }
-
-  LocalState()
-      : dispatch_modifier(c10::impl::tls_local_dispatch_key_set()),
-        grad_mode_enabled(at::GradMode::is_enabled()) {}
-};
+TensorCheck::TensorCheck(
+    const LocalState& state,
+    PyTypeObject* pt,
+    const at::Tensor& v,
+    std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,
+    std::vector<std::optional<c10::SymInt>> dynamic_dims_strides)
+    : pytype(pt),
+      dispatch_key_(state.apply(v.key_set()).raw_repr()),
+      dtype_(v.dtype().toScalarType()),
+      device_index_(v.device().index()),
+      requires_grad_(v.requires_grad()),
+      sizes_(std::move(dynamic_dims_sizes)),
+      strides_(std::move(dynamic_dims_strides)),
+      dim_(static_cast<int64_t>(sizes_.size())) {
+  // TODO(voz): In cases where sizes_ and strides_ are fully dynamic, should
+  // we just treat this as optional?
+}
 
-class TensorCheck {
- public:
-  TensorCheck(
-      const LocalState& state,
-      PyTypeObject* pt,
-      const at::Tensor& v,
-      std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,
-      std::vector<std::optional<c10::SymInt>> dynamic_dims_strides)
-      : pytype(pt),
-        dispatch_key_(state.apply(v.key_set()).raw_repr()),
-        dtype_(v.dtype().toScalarType()),
-        device_index_(v.device().index()),
-        requires_grad_(v.requires_grad()),
-        sizes_(std::move(dynamic_dims_sizes)),
-        strides_(std::move(dynamic_dims_strides)),
-        dim_(static_cast<int64_t>(sizes_.size())) {
-    // TODO(voz): In cases where sizes_ and strides_ are fully dynamic, should
-    // we just treat this as optional?
-  }
-
-  // See note in guards.py [Note - On Export Tensor Guards]
-  // Logic parallel to here must be maintained in python
-  bool check(const LocalState& state, const at::Tensor& v) {
-    if (dispatch_key_ != state.apply(v.key_set()).raw_repr() ||
-        dtype_ != v.dtype().toScalarType() ||
-        device_index_ != v.device().index() ||
-        requires_grad_ != v.requires_grad()) {
-      return false;
-    }
-    auto ndim = v.ndimension();
-    if (ndim != dim_) {
-      return false;
-    }
-    const auto& sizes = v.sym_sizes();
-    const auto& strides = v.sym_strides();
-    for (auto i : c10::irange(ndim)) {
-      auto known_size = sizes_[i];
-      auto known_stride = strides_[i];
-      if (known_size.has_value()) {
-        if (known_size.value() != sizes[i]) {
-          return false;
-        }
-      }
-      if (known_stride.has_value()) {
-        if (known_stride.value() != strides[i]) {
-          return false;
-        }
-      }
-    }
-    return true;
-  }
+TensorCheck::TensorCheck(
+    const LocalState& state,
+    PyTypeObject* pt,
+    uint64_t dispatch_key,
+    at::ScalarType dtype,
+    at::DeviceIndex device_index,
+    std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,
+    std::vector<std::optional<c10::SymInt>> dynamic_dims_strides)
+    : pytype(pt),
+      dispatch_key_(dispatch_key),
+      dtype_(dtype),
+      device_index_(device_index),
+      requires_grad_(false),
+      sizes_(std::move(dynamic_dims_sizes)),
+      strides_(std::move(dynamic_dims_strides)),
+      dim_(static_cast<int64_t>(sizes_.size())) {}
+
+// See note in guards.py [Note - On Export Tensor Guards]
+// Logic parallel to here must be maintained in python
+bool TensorCheck::check(const LocalState& state, const at::Tensor& v) {
+  if (dispatch_key_ != state.apply(v.key_set()).raw_repr() ||
+      dtype_ != v.dtype().toScalarType() ||
+      device_index_ != v.device().index() ||
+      requires_grad_ != v.requires_grad()) {
+    return false;
+  }
+  auto ndim = v.ndimension();
+  if (ndim != dim_) {
+    return false;
+  }
+  const auto& sizes = v.sym_sizes();
+  const auto& strides = v.sym_strides();
+  for (auto i : c10::irange(ndim)) {
+    auto known_size = sizes_[i];
+    auto known_stride = strides_[i];
+    if (known_size.has_value()) {
+      if (known_size.value() != sizes[i]) {
+        return false;
+      }
+    }
+    if (known_stride.has_value()) {
+      if (known_stride.value() != strides[i]) {
+        return false;
+      }
+    }
+  }
+  return true;
+}
 
-  std::string check_verbose(
-      const LocalState& state,
-      const at::Tensor& v,
-      const std::string& tensor_name) {
-    std::stringstream fail_reason;
-    fail_reason << "tensor '" << tensor_name << "' ";
-    if (dispatch_key_ != state.apply(v.key_set()).raw_repr()) {
-      // return fmt::format("tensor dispatch key mismatch. expected {}, actual
-      // {}", dispatch_key_, state.apply(v.key_set()).raw_repr());
-      fail_reason << "dispatch key set mismatch. expected "
-                  << c10::DispatchKeySet(
-                         c10::DispatchKeySet::RAW, dispatch_key_)
-                  << ", actual " << state.apply(v.key_set());
-      return fail_reason.str();
-    } else if (dtype_ != v.dtype().toScalarType()) {
-      // return fmt::format("tensor dtype mismatch. expected {}, actual {}",
-      // dtype_, v.dtype().toScalarType());
-      fail_reason << "dtype mismatch. expected " << dtype_ << ", actual "
-                  << v.dtype().toScalarType();
-      return fail_reason.str();
-    } else if (device_index_ != v.device().index()) {
-      fail_reason
-          << "Tensor device index mismatch. Expected device index to be "
-          << device_index_ << ", actual " << v.device().index();
-      return fail_reason.str();
-    } else if (requires_grad_ != v.requires_grad()) {
-      // return fmt::format("tensor requires_grad mismatch. expected {}",
-      // requires_grad_);
-      fail_reason << "requires_grad mismatch. expected requires_grad="
-                  << requires_grad_;
-      return fail_reason.str();
-    }
-    auto ndim = v.ndimension();
-    if (ndim != dim_) {
-      // return fmt::format("tensor rank mismatch. expected {}, actual {}",
-      // sizes_.size(), ndim);
-      fail_reason << "rank mismatch. expected " << sizes_.size() << ", actual "
-                  << ndim;
-      return fail_reason.str();
-    }
-    const auto& sizes = v.sym_sizes();
-    const auto& strides = v.sym_strides();
-    for (auto i : c10::irange(ndim)) {
-      auto known_size = sizes_[i];
-      auto known_stride = strides_[i];
-      if (known_size.has_value() && (known_size.value() != sizes[i])) {
-        fail_reason << "size mismatch at index " << i << ". expected "
-                    << known_size.value() << ", actual " << sizes[i];
-        return fail_reason.str();
-      }
-      if (known_stride.has_value() && known_stride.value() != strides[i]) {
-        fail_reason << "stride mismatch at index " << i << ". expected "
-                    << known_stride.value() << ", actual " << strides[i];
-        return fail_reason.str();
-      }
-    }
-    return "";
-  }
+std::string TensorCheck::check_verbose(
+    const LocalState& state,
+    const at::Tensor& v,
+    const std::string& tensor_name) {
+  std::stringstream fail_reason;
+  fail_reason << "tensor '" << tensor_name << "' ";
+  if (dispatch_key_ != state.apply(v.key_set()).raw_repr()) {
+    // return fmt::format("tensor dispatch key mismatch. expected {}, actual
+    // {}", dispatch_key_, state.apply(v.key_set()).raw_repr());
+    fail_reason << "dispatch key set mismatch. expected "
+                << c10::DispatchKeySet(c10::DispatchKeySet::RAW, dispatch_key_)
+                << ", actual " << state.apply(v.key_set());
+    return fail_reason.str();
+  } else if (dtype_ != v.dtype().toScalarType()) {
+    // return fmt::format("tensor dtype mismatch. expected {}, actual {}",
+    // dtype_, v.dtype().toScalarType());
+    fail_reason << "dtype mismatch. expected " << dtype_ << ", actual "
+                << v.dtype().toScalarType();
+    return fail_reason.str();
+  } else if (device_index_ != v.device().index()) {
+    fail_reason << "Tensor device index mismatch. Expected device index to be "
+                << device_index_ << ", actual " << v.device().index();
+    return fail_reason.str();
+  } else if (requires_grad_ != v.requires_grad()) {
+    // return fmt::format("tensor requires_grad mismatch. expected {}",
+    // requires_grad_);
+    fail_reason << "requires_grad mismatch. expected requires_grad="
+                << requires_grad_;
+    return fail_reason.str();
+  }
+  auto ndim = v.ndimension();
+  if (ndim != dim_) {
+    // return fmt::format("tensor rank mismatch. expected {}, actual {}",
+    // sizes_.size(), ndim);
+    fail_reason << "rank mismatch. expected " << sizes_.size() << ", actual "
+                << ndim;
+    return fail_reason.str();
+  }
+  const auto& sizes = v.sym_sizes();
+  const auto& strides = v.sym_strides();
+  for (auto i : c10::irange(ndim)) {
+    auto known_size = sizes_[i];
+    auto known_stride = strides_[i];
+    if (known_size.has_value() && (known_size.value() != sizes[i])) {
+      fail_reason << "size mismatch at index " << i << ". expected "
+                  << known_size.value() << ", actual " << sizes[i];
+      return fail_reason.str();
+    }
+    if (known_stride.has_value() && known_stride.value() != strides[i]) {
+      fail_reason << "stride mismatch at index " << i << ". expected "
+                  << known_stride.value() << ", actual " << strides[i];
+      return fail_reason.str();
+    }
+  }
+  return "";
+}
 
-  PyTypeObject* pytype;
-
- private:
-  uint64_t dispatch_key_; // DispatchKeySet includes device/layout
-  at::ScalarType dtype_;
-  // Note(voz): While dispatch_key_ is sufficiently representative of a device
-  // In that keys are more granular AND device specific - they do not
-  // necessarily capture device indices correctly.
-  at::DeviceIndex device_index_;
-  bool requires_grad_;
-  // NB: These are unset if dynamic shapes is enabled.
-  std::vector<std::optional<c10::SymInt>> sizes_;
-  std::vector<std::optional<c10::SymInt>> strides_;
-  // Not strictly required for dense tensors, but nested tensors need it.
-  int64_t dim_;
-};
+namespace {
 
 typedef std::vector<TensorCheck> ChecksList;
@@ -3949,3 +3933,5 @@ PyObject* torch_c_dynamo_guards_init() {
 
   return m;
 }
+
+} // namespace torch::dynamo
diff --git a/torch/csrc/dynamo/guards.h b/torch/csrc/dynamo/guards.h
index f47a52775c83..26accf742181 100644
--- a/torch/csrc/dynamo/guards.h
+++ b/torch/csrc/dynamo/guards.h
@@ -1,10 +1,79 @@
 #pragma once
+#include <ATen/ATen.h>
 #include <torch/csrc/python_headers.h>
 #include <torch/csrc/utils/pybind.h>
 
+namespace torch::dynamo {
+
 PyObject* torch_c_dynamo_guards_init();
 
 // interfaces for extra_state and eval_frame.c because RootGuardManager class is
 // not visible there.
 void* convert_to_root_guard_manager(py::object root);
 bool run_root_guard_manager(void* root, PyObject* f_locals);
+
+struct LocalState {
+  // TLS state that changes operators
+  c10::impl::LocalDispatchKeySet dispatch_modifier;
+  c10::DispatchKeySet override_dispatch_key_set;
+  bool grad_mode_enabled;
+
+  at::DispatchKeySet apply(at::DispatchKeySet ks) const {
+    if (override_dispatch_key_set.empty()) {
+      return (ks | dispatch_modifier.included_) - dispatch_modifier.excluded_;
+    } else {
+      return override_dispatch_key_set;
+    }
+  }
+
+  LocalState()
+      : dispatch_modifier(c10::impl::tls_local_dispatch_key_set()),
+        grad_mode_enabled(at::GradMode::is_enabled()) {}
+
+  void overrideDispatchKeySet(c10::DispatchKeySet ks) {
+    override_dispatch_key_set = ks;
+  }
+};
+
+class TensorCheck {
+ public:
+  TensorCheck(
+      const LocalState& state,
+      PyTypeObject* pt,
+      const at::Tensor& v,
+      std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,
+      std::vector<std::optional<c10::SymInt>> dynamic_dims_strides);
+
+  TensorCheck(
+      const LocalState& state,
+      PyTypeObject* pt,
+      uint64_t dispatch_key,
+      at::ScalarType dtype,
+      at::DeviceIndex device_index,
+      std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,
+      std::vector<std::optional<c10::SymInt>> dynamic_dims_strides);
+
+  bool check(const LocalState& state, const at::Tensor& v);
+  std::string check_verbose(
+      const LocalState& state,
+      const at::Tensor& v,
+      const std::string& tensor_name);
+
+  PyTypeObject* pytype;
+
+ private:
+  uint64_t dispatch_key_; // DispatchKeySet includes device/layout
+  at::ScalarType dtype_;
+  // Note(voz): While dispatch_key_ is sufficiently representative of a device
+  // In that keys are more granular AND device specific - they do not
+  // necessarily capture device indices correctly.
+  at::DeviceIndex device_index_;
+  bool requires_grad_;
+  // NB: These are unset if dynamic shapes is enabled.
+  std::vector<std::optional<c10::SymInt>> sizes_;
+  std::vector<std::optional<c10::SymInt>> strides_;
+  // Not strictly required for dense tensors, but nested tensors need it.
+  int64_t dim_;
+};
+
+} // namespace torch::dynamo
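Illustrative usage sketch (not part of the diff above): with LocalState and TensorCheck now declared in torch/csrc/dynamo/guards.h inside the torch::dynamo namespace, other C++ code in the tree can record a guard for a tensor and re-run it later without going through Python. The function name below is hypothetical, std::optional<c10::SymInt> is assumed as the dynamic-dims element type as reconstructed here, and the new LocalState::overrideDispatchKeySet() additionally lets a caller pin the dispatch key set instead of reading it from TLS.

#include <torch/csrc/dynamo/guards.h>

#include <ATen/ATen.h>

#include <optional>
#include <vector>

bool recheck_tensor_example(const at::Tensor& recorded, const at::Tensor& now) {
  using torch::dynamo::LocalState;
  using torch::dynamo::TensorCheck;

  // Snapshot of TLS dispatch-key state and grad mode at record time.
  LocalState state;

  // An empty optional means "do not guard on this dimension's size/stride";
  // here every dimension is left dynamic. In the code shown above, pytype is
  // only stored, so a null PyTypeObject* is enough for a pure C++ check.
  std::vector<std::optional<c10::SymInt>> dyn(recorded.dim(), std::nullopt);
  TensorCheck guard(state, /*pt=*/nullptr, recorded, dyn, dyn);

  // Later: does `now` still match the recorded dispatch key set, dtype,
  // device index, requires_grad and rank (plus any non-dynamic sizes/strides)?
  return guard.check(LocalState(), now);
}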