AOT Autograd - Functionalization pass with torch.dispatch #88
WIP PR in functorch: pytorch/functorch#703 (cc @bdhirsh). These are the issues seen in the functionalization pass when running it on TorchBench.
Fixes pytorch/functorch#705

This adds support for `zero_()` in the functionalization pass by introducing a new `at::zero()`. It's identical to `at::zeros_like(t)`, but adding it directly into `native_functions.yaml` allows the functionalization pass to automatically figure out how to undo a mutation from `zero_()`. We probably don't want users to actually use the operator, so I didn't give it a tensor method or a python binding. From a conversation with @ezyang, we should probably do the same with `at::_copy()` (even though `at::copy()` will be a pretty unintuitive op). This also fixes one of the torch dynamo integration issues mentioned in pytorch/torchdynamo#88.

Differential Revision: [D35705378](https://our.internmc.facebook.com/intern/diff/D35705378)
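For context, here's a minimal sketch of what this enables, reusing the `make_fx(functionalize(...))` pattern from the repro later in this thread. The exact ops recorded in the trace (e.g. whether `zero` or a `zeros_like`-style call appears) depend on the build, so treat the expected output as an assumption rather than a guarantee:

```
import torch
from functorch import make_fx
from functorch.experimental import functionalize

def f(x):
    y = x.clone()
    y.zero_()      # in-place mutation that previously broke functionalization
    return y + 1

# With zero_() functionalized, the traced graph should contain only
# out-of-place ops (an at::zero()-style call) instead of zero_().
g = make_fx(functionalize(f))(torch.ones(3))
print(g.graph)
```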
Addresses the `fill_` issue in pytorch/torchdynamo#88.

This adds out-of-place `fill.Tensor` and `fill.Scalar` ops so that `fill_()` can be properly functionalized. I ended up giving `fill` a derivative formula, because I think we want to consider it a "base op" as part of tracing. The decomposition I wrote for it just calls back into `fill_()`, so we don't want to run that decomposition as part of tracing.
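As a sketch of the intended out-of-place semantics (assuming the new overloads end up reachable as `torch.ops.aten.fill.Scalar`; the PR description itself doesn't promise a particular Python spelling):

```
import torch

t = torch.zeros(3)
# Out-of-place fill: returns a new tensor instead of mutating t,
# which is what the functionalization pass needs to rewrite fill_() into.
u = torch.ops.aten.fill.Scalar(t, 7.0)
assert t.eq(0).all()  # t is left untouched
assert u.eq(7).all()
```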
This should fix `at::index.Tensor` for functionalization and address pytorch/torchdynamo#88.

### The bug

We have a bunch of code that expects all of our `c10::Type` objects to have a unique singleton instance, but it turns out that this isn't true for container types (like `c10::optional`). The `IValue::isOptionalTensorList` function I added earlier in #75716 doesn't always work:

```
const auto& ty = static_cast<detail::ListImpl*>(payload.u.as_intrusive_ptr)->elementType;
return ty == c10::getTypePtr<c10::optional<at::Tensor>>();
```

That equality check calls [this](https://github.com/pytorch/pytorch/blob/e20793b05426284d973ac25e563046c037b4e4b2/aten/src/ATen/core/jit_type_base.h#L586) code:

```
template <typename T, typename U>
bool operator==(const SingletonOrSharedTypePtr<T>& x, const SingletonOrSharedTypePtr<U>& y) {
  return (void*)x.get() == (void*)y.get();
}
```

Every `c10::Type` can be compared, but it also has its own singleton instance to make equality checks cheaper (just check that the two singleton instances are the same, instead of comparing the full type objects). You can call `c10::getTypePtr<T>()` and get back the singleton instance of the pointer to that type. For `optional<T>`, that singleton instance lives [here](https://github.com/pytorch/pytorch/blob/e20793b05426284d973ac25e563046c037b4e4b2/aten/src/ATen/core/jit_type.h#L1871).

When I was debugging, I noticed that `isOptionalTensorList` was returning false because the two pointers being compared were different, but the actual type objects were equal. I was able to repro this with `functionalize()`, but I couldn't repro it directly with a test in core. Changing to this code worked:

```
const auto& ty = static_cast<detail::ListImpl*>(payload.u.as_intrusive_ptr)->elementType;
const auto& expected_ty = c10::getTypePtr<c10::optional<at::Tensor>>();
// compare pointers, but if that fails compare the actual type objects
return expected_ty == ty || *expected_ty == *ty;
```

So why do we have more than one "static singleton" instance of the same type object? The singleton instance for `c10::optional` lives [here](https://github.com/pytorch/pytorch/blob/e20793b05426284d973ac25e563046c037b4e4b2/aten/src/ATen/core/jit_type.h#L1871) and is defined in a header file (it has to be, because it's a template). I think the problem is that function-local statics are duplicated across DSOs. We have a similar comment about the dispatcher singleton and why it needs to live in a .cpp file [here](https://github.com/pytorch/pytorch/blob/e20793b05426284d973ac25e563046c037b4e4b2/aten/src/ATen/core/dispatch/Dispatcher.h#L95). Basically, since functorch and pytorch core live in two separate `.so` files, we end up with a new static singleton instance for each library. We can't just move the singleton into a .cpp file though, since the function is templated - we want one singleton instance *per `optional<T>` type*. I ended up doing it by converting each `T` into a `TypePtr` object and keeping a mapping from `TypePtr` objects of the inner type to the static singleton instances in the .cpp file.

### Testing?

I couldn't figure out how to repro this failure in core, since I think the `functionalize()` failure came from the fact that we're invoking the `c10::getTypePtr` call from multiple loaded libraries (`libtorch_cpu.so` and `functorch/_C.so`).

I confirmed that with this patch, this code runs successfully (it would break before):

```
import torch
from functorch import make_fx
from functorch.experimental import functionalize

def f(x, y):
    return x[y]

t1 = make_fx(functionalize(f))(torch.arange(3), torch.ones(2, dtype=torch.long))
print("Functionalized:\n", t1.graph)
```

### Generalizing this fix to other container types?

This bug probably affects the other container `c10::Type`s, like List/Dict. I put this up as a PoC first, but if this seems like a reasonable fix then I can use the same fix for the other container types too.
Our JIT data model currently allows for a class of schemas that:
(1) mutate some of their inputs (based on the aliasing info)
(2) potentially return *new* outputs (unrelated to the mutated inputs)
(3) have mutated inputs that are not `self` or `out=` kwargs, so the op is neither inplace nor out=

This PR adds support for functionalizing that class of ops, and also adds support for `_fused_moving_avg_obs_fq_helper` to ensure that it all works. (This op is needed for torchdynamo, as it's used in some resnet models on torchbench. See pytorch/torchdynamo#88 (comment).)

The majority of the work in this PR consisted of:
(1) getting the functionalization codegen to detect "schemas that have any mutable args", instead of special-casing directly on `SchemaKind.inplace/out`
(2) ensuring that we properly group mutable ops with their corresponding functional variants (like the above, you can't rely on `SchemaKind` anymore because the mutable op is neither inplace nor out=)
(3) removing some assumptions that the codegen made about mutable ops. For example, I used to assume it was always ok to return the `self` or `out=` args - but you can't always do that. Mutable ops are allowed to mutate their inputs by side effect and return totally different output tensors (which then need to be wrapped by functionalization).

Here's what the codegen'd kernel for `_fused_moving_avg_obs_fq_helper` looks like:

```
::std::tuple<at::Tensor,at::Tensor> _fused_moving_avg_obs_fq_helper(
    c10::DispatchKeySet dispatchKeySet,
    const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on,
    at::Tensor & running_min, at::Tensor & running_max, at::Tensor & scale, at::Tensor & zero_point,
    double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis,
    bool per_row_fake_quant, bool symmetric_quant) {
  at::Tensor self_;
  if (at::functionalization::impl::isFunctionalTensor(self)) {
    at::functionalization::impl::sync(self);
    self_ = at::functionalization::impl::from_functional_tensor(self);
  } else {
    self_ = self;
  }
  at::Tensor observer_on_;
  if (at::functionalization::impl::isFunctionalTensor(observer_on)) {
    at::functionalization::impl::sync(observer_on);
    observer_on_ = at::functionalization::impl::from_functional_tensor(observer_on);
  } else {
    observer_on_ = observer_on;
  }
  at::Tensor fake_quant_on_;
  if (at::functionalization::impl::isFunctionalTensor(fake_quant_on)) {
    at::functionalization::impl::sync(fake_quant_on);
    fake_quant_on_ = at::functionalization::impl::from_functional_tensor(fake_quant_on);
  } else {
    fake_quant_on_ = fake_quant_on;
  }
  at::Tensor running_min_;
  if (at::functionalization::impl::isFunctionalTensor(running_min)) {
    at::functionalization::impl::sync(running_min);
    running_min_ = at::functionalization::impl::from_functional_tensor(running_min);
  } else {
    running_min_ = running_min;
  }
  at::Tensor running_max_;
  if (at::functionalization::impl::isFunctionalTensor(running_max)) {
    at::functionalization::impl::sync(running_max);
    running_max_ = at::functionalization::impl::from_functional_tensor(running_max);
  } else {
    running_max_ = running_max;
  }
  at::Tensor scale_;
  if (at::functionalization::impl::isFunctionalTensor(scale)) {
    at::functionalization::impl::sync(scale);
    scale_ = at::functionalization::impl::from_functional_tensor(scale);
  } else {
    scale_ = scale;
  }
  at::Tensor zero_point_;
  if (at::functionalization::impl::isFunctionalTensor(zero_point)) {
    at::functionalization::impl::sync(zero_point);
    zero_point_ = at::functionalization::impl::from_functional_tensor(zero_point);
  } else {
    zero_point_ = zero_point;
  }
  if (!(true && at::functionalization::impl::isFunctionalTensor(running_min)
             && at::functionalization::impl::isFunctionalTensor(running_max)
             && at::functionalization::impl::isFunctionalTensor(scale)
             && at::functionalization::impl::isFunctionalTensor(zero_point))) {
    if ((false || at::functionalization::impl::isFunctionalTensor(self)
               || at::functionalization::impl::isFunctionalTensor(observer_on)
               || at::functionalization::impl::isFunctionalTensor(fake_quant_on))) {
      // case 1: trying to mutate a non functional tensor with a functional tensor is an error
      TORCH_INTERNAL_ASSERT(false,
          "mutating a non-functional tensor with a functional tensor is not allowed.",
          " Please ensure that all of your inputs are wrapped inside of a functionalize() call.");
    } else {
      // case 2: arguments are not functional tensors, so we no-op and redispatch.
      at::AutoDispatchSkipFunctionalize guard;
      ::std::tuple<at::Tensor,at::Tensor> tmp_output = at::_ops::_fused_moving_avg_obs_fq_helper::call(
          self_, observer_on_, fake_quant_on_, running_min_, running_max_, scale_, zero_point_,
          averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant);
      auto output_0 = at::functionalization::impl::to_functional_tensor(std::get<0>(tmp_output));
      auto output_1 = at::functionalization::impl::to_functional_tensor(std::get<1>(tmp_output));
      return ::std::tuple<at::Tensor,at::Tensor>(output_0, output_1);
    }
  } else {
    ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> tmp_output;
    {
      at::AutoDispatchSkipFunctionalize guard;
      tmp_output = at::_ops::_fused_moving_avg_obs_fq_helper_functional::call(
          self_, observer_on_, fake_quant_on_, running_min_, running_max_, scale_, zero_point_,
          averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant);
    }
    at::functionalization::impl::replace_(running_min, std::get<0>(tmp_output));
    at::functionalization::impl::commit_update(running_min);
    at::functionalization::impl::replace_(running_max, std::get<1>(tmp_output));
    at::functionalization::impl::commit_update(running_max);
    at::functionalization::impl::replace_(scale, std::get<2>(tmp_output));
    at::functionalization::impl::commit_update(scale);
    at::functionalization::impl::replace_(zero_point, std::get<3>(tmp_output));
    at::functionalization::impl::commit_update(zero_point);
    auto output_0 = at::functionalization::impl::to_functional_tensor(std::get<4>(tmp_output));
    auto output_1 = at::functionalization::impl::to_functional_tensor(std::get<5>(tmp_output));
    return ::std::tuple<at::Tensor,at::Tensor>(output_0, output_1);
  }
}
```
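A Python-level analogue of this schema class, sketched with a plain in-place update on a non-`self` argument. This is not the actual `_fused_moving_avg_obs_fq_helper` call, just an illustration of "mutates an input by side effect and returns an unrelated output", assuming `functionalize` handles the inplace op at trace time:

```
import torch
from functorch import make_fx
from functorch.experimental import functionalize

def f(x, stats):
    stats.add_(x.sum())   # side effect on a non-self, non-out argument
    return x * 2          # fresh output, unrelated to the mutated input

# Functionalization should rewrite the add_() into an out-of-place add
# and propagate the mutation back to `stats` at the end of the trace.
g = make_fx(functionalize(f))(torch.randn(4), torch.zeros(()))
print(g.graph)
```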
…args" Our JIT data model currently allows for a class of schemas that: (1) mutate some of their inputs (based on the aliasing info) (2) potentially return *new* outputs (unrelated to the mutated inputs) (3) the mutated inputs are not `self` or `out` kwargs, so the op is neither inplace nor out= This PR adds support to be able to functionalize that class of ops, and also adds support for `_fused_moving_avg_obs_fq_helper` to ensure that it all works. (This op is needed for torchdynamo, as its used in some resnet models on torchbench. See pytorch/torchdynamo#88 (comment)) The majority of the work in this PR consisted of: (1) Getting the functionalization codegen to detect "schemas that have any mutable args", instead of special-casing directly on `SchemaKind.inplace/out` (2) Ensuring that we properly group mutable ops with their corresponding functional variants properly (and like the above, you can't rely on `SchemaKind` anymore because the mutable op is neither inplace nor out=) (3) Removing some assumptions that the codegen made about mutable ops. For example, I used to assume it was always ok to return the `self` or `out=` args - but you can't always do that. Mutable ops are allowed to mutate their inputs by side effect, and return totally different output tensors (that then need to be wrapped by functionalization). Here's what the codegen'd kernel for `_fused_moving_avg_obs_fq_helper` looks like: ``` ::std::tuple<at::Tensor,at::Tensor> _fused_moving_avg_obs_fq_helper(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, at::Tensor & running_min, at::Tensor & running_max, at::Tensor & scale, at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant, bool symmetric_quant) { at::Tensor self_; if (at::functionalization::impl::isFunctionalTensor(self)) { at::functionalization::impl::sync(self); self_ = at::functionalization::impl::from_functional_tensor(self); } else { self_ = self; } at::Tensor observer_on_; if (at::functionalization::impl::isFunctionalTensor(observer_on)) { at::functionalization::impl::sync(observer_on); observer_on_ = at::functionalization::impl::from_functional_tensor(observer_on); } else { observer_on_ = observer_on; } at::Tensor fake_quant_on_; if (at::functionalization::impl::isFunctionalTensor(fake_quant_on)) { at::functionalization::impl::sync(fake_quant_on); fake_quant_on_ = at::functionalization::impl::from_functional_tensor(fake_quant_on); } else { fake_quant_on_ = fake_quant_on; } at::Tensor running_min_; if (at::functionalization::impl::isFunctionalTensor(running_min)) { at::functionalization::impl::sync(running_min); running_min_ = at::functionalization::impl::from_functional_tensor(running_min); } else { running_min_ = running_min; } at::Tensor running_max_; if (at::functionalization::impl::isFunctionalTensor(running_max)) { at::functionalization::impl::sync(running_max); running_max_ = at::functionalization::impl::from_functional_tensor(running_max); } else { running_max_ = running_max; } at::Tensor scale_; if (at::functionalization::impl::isFunctionalTensor(scale)) { at::functionalization::impl::sync(scale); scale_ = at::functionalization::impl::from_functional_tensor(scale); } else { scale_ = scale; } at::Tensor zero_point_; if (at::functionalization::impl::isFunctionalTensor(zero_point)) { at::functionalization::impl::sync(zero_point); zero_point_ = at::functionalization::impl::from_functional_tensor(zero_point); } 
else { zero_point_ = zero_point; } if (!(true && at::functionalization::impl::isFunctionalTensor(running_min) && at::functionalization::impl::isFunctionalTensor(running_max) && at::functionalization::impl::isFunctionalTensor(scale) && at::functionalization::impl::isFunctionalTensor(zero_point))) { if ((false || at::functionalization::impl::isFunctionalTensor(self) || at::functionalization::impl::isFunctionalTensor(observer_on) || at::functionalization::impl::isFunctionalTensor(fake_quant_on))) { // case 1: trying to mutate a non functional tensor with a functional tensor is an error TORCH_INTERNAL_ASSERT(false, "mutating a non-functional tensor with a functional tensor is not allowed.", " Please ensure that all of your inputs are wrapped inside of a functionalize() call."); } else { // case 2: arguments are not functional tensors, so we no-op and redispatch. at::AutoDispatchSkipFunctionalize guard; ::std::tuple<at::Tensor,at::Tensor> tmp_output = at::_ops::_fused_moving_avg_obs_fq_helper::call(self_, observer_on_, fake_quant_on_, running_min_, running_max_, scale_, zero_point_, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant); auto output_0 = at::functionalization::impl::to_functional_tensor(std::get<0>(tmp_output)); auto output_1 = at::functionalization::impl::to_functional_tensor(std::get<1>(tmp_output)); return ::std::tuple<at::Tensor,at::Tensor>(output_0, output_1);; } } else { ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> tmp_output; { at::AutoDispatchSkipFunctionalize guard; tmp_output = at::_ops::_fused_moving_avg_obs_fq_helper_functional::call(self_, observer_on_, fake_quant_on_, running_min_, running_max_, scale_, zero_point_, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant); } at::functionalization::impl::replace_(running_min, std::get<0>(tmp_output)); at::functionalization::impl::commit_update(running_min); at::functionalization::impl::replace_(running_max, std::get<1>(tmp_output)); at::functionalization::impl::commit_update(running_max); at::functionalization::impl::replace_(scale, std::get<2>(tmp_output)); at::functionalization::impl::commit_update(scale); at::functionalization::impl::replace_(zero_point, std::get<3>(tmp_output)); at::functionalization::impl::commit_update(zero_point); auto output_0 = at::functionalization::impl::to_functional_tensor(std::get<4>(tmp_output)); auto output_1 = at::functionalization::impl::to_functional_tensor(std::get<5>(tmp_output)); return ::std::tuple<at::Tensor,at::Tensor>(output_0, output_1); } ``` [ghstack-poisoned]
…args" Our JIT data model currently allows for a class of schemas that: (1) mutate some of their inputs (based on the aliasing info) (2) potentially return *new* outputs (unrelated to the mutated inputs) (3) the mutated inputs are not `self` or `out` kwargs, so the op is neither inplace nor out= This PR adds support to be able to functionalize that class of ops, and also adds support for `_fused_moving_avg_obs_fq_helper` to ensure that it all works. (This op is needed for torchdynamo, as its used in some resnet models on torchbench. See pytorch/torchdynamo#88 (comment)) The majority of the work in this PR consisted of: (1) Getting the functionalization codegen to detect "schemas that have any mutable args", instead of special-casing directly on `SchemaKind.inplace/out` (2) Ensuring that we properly group mutable ops with their corresponding functional variants properly (and like the above, you can't rely on `SchemaKind` anymore because the mutable op is neither inplace nor out=) (3) Removing some assumptions that the codegen made about mutable ops. For example, I used to assume it was always ok to return the `self` or `out=` args - but you can't always do that. Mutable ops are allowed to mutate their inputs by side effect, and return totally different output tensors (that then need to be wrapped by functionalization). Here's what the codegen'd kernel for `_fused_moving_avg_obs_fq_helper` looks like: ``` ::std::tuple<at::Tensor,at::Tensor> _fused_moving_avg_obs_fq_helper(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, at::Tensor & running_min, at::Tensor & running_max, at::Tensor & scale, at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant, bool symmetric_quant) { at::Tensor self_; if (at::functionalization::impl::isFunctionalTensor(self)) { at::functionalization::impl::sync(self); self_ = at::functionalization::impl::from_functional_tensor(self); } else { self_ = self; } at::Tensor observer_on_; if (at::functionalization::impl::isFunctionalTensor(observer_on)) { at::functionalization::impl::sync(observer_on); observer_on_ = at::functionalization::impl::from_functional_tensor(observer_on); } else { observer_on_ = observer_on; } at::Tensor fake_quant_on_; if (at::functionalization::impl::isFunctionalTensor(fake_quant_on)) { at::functionalization::impl::sync(fake_quant_on); fake_quant_on_ = at::functionalization::impl::from_functional_tensor(fake_quant_on); } else { fake_quant_on_ = fake_quant_on; } at::Tensor running_min_; if (at::functionalization::impl::isFunctionalTensor(running_min)) { at::functionalization::impl::sync(running_min); running_min_ = at::functionalization::impl::from_functional_tensor(running_min); } else { running_min_ = running_min; } at::Tensor running_max_; if (at::functionalization::impl::isFunctionalTensor(running_max)) { at::functionalization::impl::sync(running_max); running_max_ = at::functionalization::impl::from_functional_tensor(running_max); } else { running_max_ = running_max; } at::Tensor scale_; if (at::functionalization::impl::isFunctionalTensor(scale)) { at::functionalization::impl::sync(scale); scale_ = at::functionalization::impl::from_functional_tensor(scale); } else { scale_ = scale; } at::Tensor zero_point_; if (at::functionalization::impl::isFunctionalTensor(zero_point)) { at::functionalization::impl::sync(zero_point); zero_point_ = at::functionalization::impl::from_functional_tensor(zero_point); } 
else { zero_point_ = zero_point; } if (!(true && at::functionalization::impl::isFunctionalTensor(running_min) && at::functionalization::impl::isFunctionalTensor(running_max) && at::functionalization::impl::isFunctionalTensor(scale) && at::functionalization::impl::isFunctionalTensor(zero_point))) { if ((false || at::functionalization::impl::isFunctionalTensor(self) || at::functionalization::impl::isFunctionalTensor(observer_on) || at::functionalization::impl::isFunctionalTensor(fake_quant_on))) { // case 1: trying to mutate a non functional tensor with a functional tensor is an error TORCH_INTERNAL_ASSERT(false, "mutating a non-functional tensor with a functional tensor is not allowed.", " Please ensure that all of your inputs are wrapped inside of a functionalize() call."); } else { // case 2: arguments are not functional tensors, so we no-op and redispatch. at::AutoDispatchSkipFunctionalize guard; ::std::tuple<at::Tensor,at::Tensor> tmp_output = at::_ops::_fused_moving_avg_obs_fq_helper::call(self_, observer_on_, fake_quant_on_, running_min_, running_max_, scale_, zero_point_, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant); auto output_0 = at::functionalization::impl::to_functional_tensor(std::get<0>(tmp_output)); auto output_1 = at::functionalization::impl::to_functional_tensor(std::get<1>(tmp_output)); return ::std::tuple<at::Tensor,at::Tensor>(output_0, output_1);; } } else { ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> tmp_output; { at::AutoDispatchSkipFunctionalize guard; tmp_output = at::_ops::_fused_moving_avg_obs_fq_helper_functional::call(self_, observer_on_, fake_quant_on_, running_min_, running_max_, scale_, zero_point_, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant); } at::functionalization::impl::replace_(running_min, std::get<0>(tmp_output)); at::functionalization::impl::commit_update(running_min); at::functionalization::impl::replace_(running_max, std::get<1>(tmp_output)); at::functionalization::impl::commit_update(running_max); at::functionalization::impl::replace_(scale, std::get<2>(tmp_output)); at::functionalization::impl::commit_update(scale); at::functionalization::impl::replace_(zero_point, std::get<3>(tmp_output)); at::functionalization::impl::commit_update(zero_point); auto output_0 = at::functionalization::impl::to_functional_tensor(std::get<4>(tmp_output)); auto output_1 = at::functionalization::impl::to_functional_tensor(std::get<5>(tmp_output)); return ::std::tuple<at::Tensor,at::Tensor>(output_0, output_1); } ``` [ghstack-poisoned]
…args" Our JIT data model currently allows for a class of schemas that: (1) mutate some of their inputs (based on the aliasing info) (2) potentially return *new* outputs (unrelated to the mutated inputs) (3) the mutated inputs are not `self` or `out` kwargs, so the op is neither inplace nor out= This PR adds support to be able to functionalize that class of ops, and also adds support for `_fused_moving_avg_obs_fq_helper` to ensure that it all works. (This op is needed for torchdynamo, as its used in some resnet models on torchbench. See pytorch/torchdynamo#88 (comment)) The majority of the work in this PR consisted of: (1) Getting the functionalization codegen to detect "schemas that have any mutable args", instead of special-casing directly on `SchemaKind.inplace/out` (2) Ensuring that we properly group mutable ops with their corresponding functional variants properly (and like the above, you can't rely on `SchemaKind` anymore because the mutable op is neither inplace nor out=) (3) Removing some assumptions that the codegen made about mutable ops. For example, I used to assume it was always ok to return the `self` or `out=` args - but you can't always do that. Mutable ops are allowed to mutate their inputs by side effect, and return totally different output tensors (that then need to be wrapped by functionalization). Here's what the codegen'd kernel for `_fused_moving_avg_obs_fq_helper` looks like: ``` ::std::tuple<at::Tensor,at::Tensor> _fused_moving_avg_obs_fq_helper(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, at::Tensor & running_min, at::Tensor & running_max, at::Tensor & scale, at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant, bool symmetric_quant) { at::Tensor self_; if (at::functionalization::impl::isFunctionalTensor(self)) { at::functionalization::impl::sync(self); self_ = at::functionalization::impl::from_functional_tensor(self); } else { self_ = self; } at::Tensor observer_on_; if (at::functionalization::impl::isFunctionalTensor(observer_on)) { at::functionalization::impl::sync(observer_on); observer_on_ = at::functionalization::impl::from_functional_tensor(observer_on); } else { observer_on_ = observer_on; } at::Tensor fake_quant_on_; if (at::functionalization::impl::isFunctionalTensor(fake_quant_on)) { at::functionalization::impl::sync(fake_quant_on); fake_quant_on_ = at::functionalization::impl::from_functional_tensor(fake_quant_on); } else { fake_quant_on_ = fake_quant_on; } at::Tensor running_min_; if (at::functionalization::impl::isFunctionalTensor(running_min)) { at::functionalization::impl::sync(running_min); running_min_ = at::functionalization::impl::from_functional_tensor(running_min); } else { running_min_ = running_min; } at::Tensor running_max_; if (at::functionalization::impl::isFunctionalTensor(running_max)) { at::functionalization::impl::sync(running_max); running_max_ = at::functionalization::impl::from_functional_tensor(running_max); } else { running_max_ = running_max; } at::Tensor scale_; if (at::functionalization::impl::isFunctionalTensor(scale)) { at::functionalization::impl::sync(scale); scale_ = at::functionalization::impl::from_functional_tensor(scale); } else { scale_ = scale; } at::Tensor zero_point_; if (at::functionalization::impl::isFunctionalTensor(zero_point)) { at::functionalization::impl::sync(zero_point); zero_point_ = at::functionalization::impl::from_functional_tensor(zero_point); } 
else { zero_point_ = zero_point; } if (!(true && at::functionalization::impl::isFunctionalTensor(running_min) && at::functionalization::impl::isFunctionalTensor(running_max) && at::functionalization::impl::isFunctionalTensor(scale) && at::functionalization::impl::isFunctionalTensor(zero_point))) { if ((false || at::functionalization::impl::isFunctionalTensor(self) || at::functionalization::impl::isFunctionalTensor(observer_on) || at::functionalization::impl::isFunctionalTensor(fake_quant_on))) { // case 1: trying to mutate a non functional tensor with a functional tensor is an error TORCH_INTERNAL_ASSERT(false, "mutating a non-functional tensor with a functional tensor is not allowed.", " Please ensure that all of your inputs are wrapped inside of a functionalize() call."); } else { // case 2: arguments are not functional tensors, so we no-op and redispatch. at::AutoDispatchSkipFunctionalize guard; ::std::tuple<at::Tensor,at::Tensor> tmp_output = at::_ops::_fused_moving_avg_obs_fq_helper::call(self_, observer_on_, fake_quant_on_, running_min_, running_max_, scale_, zero_point_, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant); auto output_0 = at::functionalization::impl::to_functional_tensor(std::get<0>(tmp_output)); auto output_1 = at::functionalization::impl::to_functional_tensor(std::get<1>(tmp_output)); return ::std::tuple<at::Tensor,at::Tensor>(output_0, output_1);; } } else { ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> tmp_output; { at::AutoDispatchSkipFunctionalize guard; tmp_output = at::_ops::_fused_moving_avg_obs_fq_helper_functional::call(self_, observer_on_, fake_quant_on_, running_min_, running_max_, scale_, zero_point_, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant); } at::functionalization::impl::replace_(running_min, std::get<0>(tmp_output)); at::functionalization::impl::commit_update(running_min); at::functionalization::impl::replace_(running_max, std::get<1>(tmp_output)); at::functionalization::impl::commit_update(running_max); at::functionalization::impl::replace_(scale, std::get<2>(tmp_output)); at::functionalization::impl::commit_update(scale); at::functionalization::impl::replace_(zero_point, std::get<3>(tmp_output)); at::functionalization::impl::commit_update(zero_point); auto output_0 = at::functionalization::impl::to_functional_tensor(std::get<4>(tmp_output)); auto output_1 = at::functionalization::impl::to_functional_tensor(std::get<5>(tmp_output)); return ::std::tuple<at::Tensor,at::Tensor>(output_0, output_1); } ``` [ghstack-poisoned]
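To see the user-visible effect of this machinery from Python, here is a minimal sketch (not from the PR itself; it uses a toy in-place op rather than the fused quantization helper) showing `functorch.functionalize` rewriting a mutation into functional ops when tracing with `make_fx`:

```python
import torch
from functorch import functionalize
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    # mutate an intermediate tensor in place
    y = torch.zeros_like(x)
    y.add_(x)
    return y

# functionalize() intercepts add_ at the dispatcher level and swaps it
# for its functional counterpart, so the traced graph has no mutations
gm = make_fx(functionalize(f))(torch.ones(4))
print(gm.code)  # graph contains zeros_like + add, not add_
```

The same mechanism is what lets the codegen'd kernel above route a mutable op to its `_functional` variant and then apply `replace_`/`commit_update` to the mutated arguments.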
…args" Our JIT data model currently allows for a class of schemas that: (1) mutate some of their inputs (based on the aliasing info) (2) potentially return *new* outputs (unrelated to the mutated inputs) (3) the mutated inputs are not `self` or `out` kwargs, so the op is neither inplace nor out= This PR adds support to be able to functionalize that class of ops, and also adds support for `_fused_moving_avg_obs_fq_helper` to ensure that it all works. (This op is needed for torchdynamo, as its used in some resnet models on torchbench. See pytorch/torchdynamo#88 (comment)) The majority of the work in this PR consisted of: (1) Getting the functionalization codegen to detect "schemas that have any mutable args", instead of special-casing directly on `SchemaKind.inplace/out` (2) Ensuring that we properly group mutable ops with their corresponding functional variants properly (and like the above, you can't rely on `SchemaKind` anymore because the mutable op is neither inplace nor out=) (3) Removing some assumptions that the codegen made about mutable ops. For example, I used to assume it was always ok to return the `self` or `out=` args - but you can't always do that. Mutable ops are allowed to mutate their inputs by side effect, and return totally different output tensors (that then need to be wrapped by functionalization). Here's what the codegen'd kernel for `_fused_moving_avg_obs_fq_helper` looks like: ``` ::std::tuple<at::Tensor,at::Tensor> _fused_moving_avg_obs_fq_helper(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, at::Tensor & running_min, at::Tensor & running_max, at::Tensor & scale, at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant, bool symmetric_quant) { at::Tensor self_; if (at::functionalization::impl::isFunctionalTensor(self)) { at::functionalization::impl::sync(self); self_ = at::functionalization::impl::from_functional_tensor(self); } else { self_ = self; } at::Tensor observer_on_; if (at::functionalization::impl::isFunctionalTensor(observer_on)) { at::functionalization::impl::sync(observer_on); observer_on_ = at::functionalization::impl::from_functional_tensor(observer_on); } else { observer_on_ = observer_on; } at::Tensor fake_quant_on_; if (at::functionalization::impl::isFunctionalTensor(fake_quant_on)) { at::functionalization::impl::sync(fake_quant_on); fake_quant_on_ = at::functionalization::impl::from_functional_tensor(fake_quant_on); } else { fake_quant_on_ = fake_quant_on; } at::Tensor running_min_; if (at::functionalization::impl::isFunctionalTensor(running_min)) { at::functionalization::impl::sync(running_min); running_min_ = at::functionalization::impl::from_functional_tensor(running_min); } else { running_min_ = running_min; } at::Tensor running_max_; if (at::functionalization::impl::isFunctionalTensor(running_max)) { at::functionalization::impl::sync(running_max); running_max_ = at::functionalization::impl::from_functional_tensor(running_max); } else { running_max_ = running_max; } at::Tensor scale_; if (at::functionalization::impl::isFunctionalTensor(scale)) { at::functionalization::impl::sync(scale); scale_ = at::functionalization::impl::from_functional_tensor(scale); } else { scale_ = scale; } at::Tensor zero_point_; if (at::functionalization::impl::isFunctionalTensor(zero_point)) { at::functionalization::impl::sync(zero_point); zero_point_ = at::functionalization::impl::from_functional_tensor(zero_point); } 
else { zero_point_ = zero_point; } if (!(true && at::functionalization::impl::isFunctionalTensor(running_min) && at::functionalization::impl::isFunctionalTensor(running_max) && at::functionalization::impl::isFunctionalTensor(scale) && at::functionalization::impl::isFunctionalTensor(zero_point))) { if ((false || at::functionalization::impl::isFunctionalTensor(self) || at::functionalization::impl::isFunctionalTensor(observer_on) || at::functionalization::impl::isFunctionalTensor(fake_quant_on))) { // case 1: trying to mutate a non functional tensor with a functional tensor is an error TORCH_INTERNAL_ASSERT(false, "mutating a non-functional tensor with a functional tensor is not allowed.", " Please ensure that all of your inputs are wrapped inside of a functionalize() call."); } else { // case 2: arguments are not functional tensors, so we no-op and redispatch. at::AutoDispatchSkipFunctionalize guard; ::std::tuple<at::Tensor,at::Tensor> tmp_output = at::_ops::_fused_moving_avg_obs_fq_helper::call(self_, observer_on_, fake_quant_on_, running_min_, running_max_, scale_, zero_point_, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant); auto output_0 = at::functionalization::impl::to_functional_tensor(std::get<0>(tmp_output)); auto output_1 = at::functionalization::impl::to_functional_tensor(std::get<1>(tmp_output)); return ::std::tuple<at::Tensor,at::Tensor>(output_0, output_1);; } } else { ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> tmp_output; { at::AutoDispatchSkipFunctionalize guard; tmp_output = at::_ops::_fused_moving_avg_obs_fq_helper_functional::call(self_, observer_on_, fake_quant_on_, running_min_, running_max_, scale_, zero_point_, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant); } at::functionalization::impl::replace_(running_min, std::get<0>(tmp_output)); at::functionalization::impl::commit_update(running_min); at::functionalization::impl::replace_(running_max, std::get<1>(tmp_output)); at::functionalization::impl::commit_update(running_max); at::functionalization::impl::replace_(scale, std::get<2>(tmp_output)); at::functionalization::impl::commit_update(scale); at::functionalization::impl::replace_(zero_point, std::get<3>(tmp_output)); at::functionalization::impl::commit_update(zero_point); auto output_0 = at::functionalization::impl::to_functional_tensor(std::get<4>(tmp_output)); auto output_1 = at::functionalization::impl::to_functional_tensor(std::get<5>(tmp_output)); return ::std::tuple<at::Tensor,at::Tensor>(output_0, output_1); } ``` [ghstack-poisoned]
@bdhirsh is working on it right now.
…positional-only args" Our JIT data model currently allows for a class of schemas that: (1) mutate some of their inputs (based on the aliasing info) (2) potentially return *new* outputs (unrelated to the mutated inputs) (3) the mutated inputs are not `self` or `out` kwargs, so the op is neither inplace nor out= This PR adds support to be able to functionalize that class of ops, and also adds support for `_fused_moving_avg_obs_fq_helper` to ensure that it all works. (This op is needed for torchdynamo, as its used in some resnet models on torchbench. See pytorch/torchdynamo#88 (comment)) The majority of the work in this PR consisted of: (1) Getting the functionalization codegen to detect "schemas that have any mutable args", instead of special-casing directly on `SchemaKind.inplace/out` (2) Ensuring that we properly group mutable ops with their corresponding functional variants properly (and like the above, you can't rely on `SchemaKind` anymore because the mutable op is neither inplace nor out=) (3) Removing some assumptions that the codegen made about mutable ops. For example, I used to assume it was always ok to return the `self` or `out=` args - but you can't always do that. Mutable ops are allowed to mutate their inputs by side effect, and return totally different output tensors (that then need to be wrapped by functionalization). Here's what the codegen'd kernel for `_fused_moving_avg_obs_fq_helper` looks like: ``` ::std::tuple<at::Tensor,at::Tensor> _fused_moving_avg_obs_fq_helper(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, at::Tensor & running_min, at::Tensor & running_max, at::Tensor & scale, at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant, bool symmetric_quant) { at::Tensor self_; if (at::functionalization::impl::isFunctionalTensor(self)) { at::functionalization::impl::sync(self); self_ = at::functionalization::impl::from_functional_tensor(self); } else { self_ = self; } at::Tensor observer_on_; if (at::functionalization::impl::isFunctionalTensor(observer_on)) { at::functionalization::impl::sync(observer_on); observer_on_ = at::functionalization::impl::from_functional_tensor(observer_on); } else { observer_on_ = observer_on; } at::Tensor fake_quant_on_; if (at::functionalization::impl::isFunctionalTensor(fake_quant_on)) { at::functionalization::impl::sync(fake_quant_on); fake_quant_on_ = at::functionalization::impl::from_functional_tensor(fake_quant_on); } else { fake_quant_on_ = fake_quant_on; } at::Tensor running_min_; if (at::functionalization::impl::isFunctionalTensor(running_min)) { at::functionalization::impl::sync(running_min); running_min_ = at::functionalization::impl::from_functional_tensor(running_min); } else { running_min_ = running_min; } at::Tensor running_max_; if (at::functionalization::impl::isFunctionalTensor(running_max)) { at::functionalization::impl::sync(running_max); running_max_ = at::functionalization::impl::from_functional_tensor(running_max); } else { running_max_ = running_max; } at::Tensor scale_; if (at::functionalization::impl::isFunctionalTensor(scale)) { at::functionalization::impl::sync(scale); scale_ = at::functionalization::impl::from_functional_tensor(scale); } else { scale_ = scale; } at::Tensor zero_point_; if (at::functionalization::impl::isFunctionalTensor(zero_point)) { at::functionalization::impl::sync(zero_point); zero_point_ = 
at::functionalization::impl::from_functional_tensor(zero_point); } else { zero_point_ = zero_point; } if (!(true && at::functionalization::impl::isFunctionalTensor(running_min) && at::functionalization::impl::isFunctionalTensor(running_max) && at::functionalization::impl::isFunctionalTensor(scale) && at::functionalization::impl::isFunctionalTensor(zero_point))) { if ((false || at::functionalization::impl::isFunctionalTensor(self) || at::functionalization::impl::isFunctionalTensor(observer_on) || at::functionalization::impl::isFunctionalTensor(fake_quant_on))) { // case 1: trying to mutate a non functional tensor with a functional tensor is an error TORCH_INTERNAL_ASSERT(false, "mutating a non-functional tensor with a functional tensor is not allowed.", " Please ensure that all of your inputs are wrapped inside of a functionalize() call."); } else { // case 2: arguments are not functional tensors, so we no-op and redispatch. at::AutoDispatchSkipFunctionalize guard; ::std::tuple<at::Tensor,at::Tensor> tmp_output = at::_ops::_fused_moving_avg_obs_fq_helper::call(self_, observer_on_, fake_quant_on_, running_min_, running_max_, scale_, zero_point_, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant); auto output_0 = at::functionalization::impl::to_functional_tensor(std::get<0>(tmp_output)); auto output_1 = at::functionalization::impl::to_functional_tensor(std::get<1>(tmp_output)); return ::std::tuple<at::Tensor,at::Tensor>(output_0, output_1);; } } else { ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> tmp_output; { at::AutoDispatchSkipFunctionalize guard; tmp_output = at::_ops::_fused_moving_avg_obs_fq_helper_functional::call(self_, observer_on_, fake_quant_on_, running_min_, running_max_, scale_, zero_point_, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant); } at::functionalization::impl::replace_(running_min, std::get<0>(tmp_output)); at::functionalization::impl::commit_update(running_min); at::functionalization::impl::replace_(running_max, std::get<1>(tmp_output)); at::functionalization::impl::commit_update(running_max); at::functionalization::impl::replace_(scale, std::get<2>(tmp_output)); at::functionalization::impl::commit_update(scale); at::functionalization::impl::replace_(zero_point, std::get<3>(tmp_output)); at::functionalization::impl::commit_update(zero_point); auto output_0 = at::functionalization::impl::to_functional_tensor(std::get<4>(tmp_output)); auto output_1 = at::functionalization::impl::to_functional_tensor(std::get<5>(tmp_output)); return ::std::tuple<at::Tensor,at::Tensor>(output_0, output_1); } ``` [ghstack-poisoned]
…args" Our JIT data model currently allows for a class of schemas that: (1) mutate some of their inputs (based on the aliasing info) (2) potentially return *new* outputs (unrelated to the mutated inputs) (3) the mutated inputs are not `self` or `out` kwargs, so the op is neither inplace nor out= This PR adds support to be able to functionalize that class of ops, and also adds support for `_fused_moving_avg_obs_fq_helper` to ensure that it all works. (This op is needed for torchdynamo, as its used in some resnet models on torchbench. See pytorch/torchdynamo#88 (comment)) The majority of the work in this PR consisted of: (1) Getting the functionalization codegen to detect "schemas that have any mutable args", instead of special-casing directly on `SchemaKind.inplace/out` (2) Ensuring that we properly group mutable ops with their corresponding functional variants properly (and like the above, you can't rely on `SchemaKind` anymore because the mutable op is neither inplace nor out=) (3) Removing some assumptions that the codegen made about mutable ops. For example, I used to assume it was always ok to return the `self` or `out=` args - but you can't always do that. Mutable ops are allowed to mutate their inputs by side effect, and return totally different output tensors (that then need to be wrapped by functionalization). Here's what the codegen'd kernel for `_fused_moving_avg_obs_fq_helper` looks like: ``` ::std::tuple<at::Tensor,at::Tensor> _fused_moving_avg_obs_fq_helper(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, at::Tensor & running_min, at::Tensor & running_max, at::Tensor & scale, at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant, bool symmetric_quant) { at::Tensor self_; if (at::functionalization::impl::isFunctionalTensor(self)) { at::functionalization::impl::sync(self); self_ = at::functionalization::impl::from_functional_tensor(self); } else { self_ = self; } at::Tensor observer_on_; if (at::functionalization::impl::isFunctionalTensor(observer_on)) { at::functionalization::impl::sync(observer_on); observer_on_ = at::functionalization::impl::from_functional_tensor(observer_on); } else { observer_on_ = observer_on; } at::Tensor fake_quant_on_; if (at::functionalization::impl::isFunctionalTensor(fake_quant_on)) { at::functionalization::impl::sync(fake_quant_on); fake_quant_on_ = at::functionalization::impl::from_functional_tensor(fake_quant_on); } else { fake_quant_on_ = fake_quant_on; } at::Tensor running_min_; if (at::functionalization::impl::isFunctionalTensor(running_min)) { at::functionalization::impl::sync(running_min); running_min_ = at::functionalization::impl::from_functional_tensor(running_min); } else { running_min_ = running_min; } at::Tensor running_max_; if (at::functionalization::impl::isFunctionalTensor(running_max)) { at::functionalization::impl::sync(running_max); running_max_ = at::functionalization::impl::from_functional_tensor(running_max); } else { running_max_ = running_max; } at::Tensor scale_; if (at::functionalization::impl::isFunctionalTensor(scale)) { at::functionalization::impl::sync(scale); scale_ = at::functionalization::impl::from_functional_tensor(scale); } else { scale_ = scale; } at::Tensor zero_point_; if (at::functionalization::impl::isFunctionalTensor(zero_point)) { at::functionalization::impl::sync(zero_point); zero_point_ = at::functionalization::impl::from_functional_tensor(zero_point); } 
else { zero_point_ = zero_point; } if (!(true && at::functionalization::impl::isFunctionalTensor(running_min) && at::functionalization::impl::isFunctionalTensor(running_max) && at::functionalization::impl::isFunctionalTensor(scale) && at::functionalization::impl::isFunctionalTensor(zero_point))) { if ((false || at::functionalization::impl::isFunctionalTensor(self) || at::functionalization::impl::isFunctionalTensor(observer_on) || at::functionalization::impl::isFunctionalTensor(fake_quant_on))) { // case 1: trying to mutate a non functional tensor with a functional tensor is an error TORCH_INTERNAL_ASSERT(false, "mutating a non-functional tensor with a functional tensor is not allowed.", " Please ensure that all of your inputs are wrapped inside of a functionalize() call."); } else { // case 2: arguments are not functional tensors, so we no-op and redispatch. at::AutoDispatchSkipFunctionalize guard; ::std::tuple<at::Tensor,at::Tensor> tmp_output = at::_ops::_fused_moving_avg_obs_fq_helper::call(self_, observer_on_, fake_quant_on_, running_min_, running_max_, scale_, zero_point_, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant); auto output_0 = at::functionalization::impl::to_functional_tensor(std::get<0>(tmp_output)); auto output_1 = at::functionalization::impl::to_functional_tensor(std::get<1>(tmp_output)); return ::std::tuple<at::Tensor,at::Tensor>(output_0, output_1);; } } else { ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> tmp_output; { at::AutoDispatchSkipFunctionalize guard; tmp_output = at::_ops::_fused_moving_avg_obs_fq_helper_functional::call(self_, observer_on_, fake_quant_on_, running_min_, running_max_, scale_, zero_point_, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant); } at::functionalization::impl::replace_(running_min, std::get<0>(tmp_output)); at::functionalization::impl::commit_update(running_min); at::functionalization::impl::replace_(running_max, std::get<1>(tmp_output)); at::functionalization::impl::commit_update(running_max); at::functionalization::impl::replace_(scale, std::get<2>(tmp_output)); at::functionalization::impl::commit_update(scale); at::functionalization::impl::replace_(zero_point, std::get<3>(tmp_output)); at::functionalization::impl::commit_update(zero_point); auto output_0 = at::functionalization::impl::to_functional_tensor(std::get<4>(tmp_output)); auto output_1 = at::functionalization::impl::to_functional_tensor(std::get<5>(tmp_output)); return ::std::tuple<at::Tensor,at::Tensor>(output_0, output_1); } ``` [ghstack-poisoned]
New issues - @bdhirsh
The repro is as follows
@bdhirsh `index` still triggers the assert error.
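The repro itself was not captured above. As a hypothetical sketch of the failure shape (assuming the assert fires because `aten::index` takes a tensorlist of indices, and an index tensor created outside the functionalized region stays non-functional), it might look like:

```python
import torch
from functorch import functionalize

# an index tensor created outside the functionalized region; it is a
# plain (non-functional) tensor when f runs under functionalize()
idx = torch.tensor([0, 2])

def f(x):
    # advanced indexing dispatches to aten::index, whose indices
    # argument is a tensorlist; mixing the functional input x with
    # the non-functional idx reportedly tripped the assert
    return x[idx]

out = functionalize(f)(torch.randn(4))
```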
…sorlists (#82326) There's an existing assert in functionalization that's probably too restrictive: when you pass an op a list of tensors that mixes functional and non-functional tensors, we should just selectively unwrap the functional tensors and call the op rather than erroring. I added a test for it in `test_functionalization.py`. It looks like this behavior can also show up when tracing with `make_fx()`, when constants get baked in as module attributes, which don't get wrapped when you try to functionalize the module's forward function. Should fix the last of pytorch/torchdynamo#88 (comment) Pull Request resolved: #82326 Approved by: https://github.com/ezyang
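A sketch of the `make_fx()` scenario the commit describes (an assumed shape, not the actual regression test): a module constant is baked into the trace as a plain tensor, so an op taking a tensorlist sees a mix of functional and non-functional inputs:

```python
import torch
from functorch import functionalize
from torch.fx.experimental.proxy_tensor import make_fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # baked in as a plain attribute; never wrapped when the
        # forward function is functionalized
        self.const = torch.ones(2)

    def forward(self, x):
        # aten::cat receives a tensorlist mixing the functional
        # input x with the non-functional constant
        return torch.cat([x, self.const])

m = M()
# after #82326 this traces cleanly instead of hitting the assert
gm = make_fx(functionalize(m))(torch.ones(2))
print(gm.code)
```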
Closed in favor of pytorch/pytorch#93621
Run functionalization to resolve mutation-related errors in AOT Autograd.
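As a rough sketch of what the issue is asking for (assumed wiring, not the final integration): composing `functionalize` with `make_fx` yields the mutation-free graph that AOT Autograd's partitioner and compilers can consume:

```python
import torch
from functorch import functionalize
from torch.fx.experimental.proxy_tensor import make_fx

def step(param, x):
    # an input mutation of the kind that breaks plain AOT Autograd tracing
    param.mul_(0.9)
    return (param * x).sum()

# remove='mutations' rewrites in-place ops to functional ones and
# replays input mutations (as copy_ calls) at the end of the graph
gm = make_fx(functionalize(step, remove='mutations'))(torch.ones(3), torch.ones(3))
print(gm.code)
```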