
AOT Autograd - Functionalization pass with torch.dispatch #88

Closed
anijain2305 opened this issue Mar 22, 2022 · 5 comments
Labels: enhancement (New feature or request)


anijain2305 commented Mar 22, 2022

Run functionalization to resolve mutation-related errors in AOT Autograd.
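For context, here is a minimal sketch (mine, not code from the issue) of what running functionalization over a traced graph looks like, using the same `make_fx`/`functionalize` entry points that show up in the repros below; the real AOT Autograd integration in functorch is more involved:

```
import torch
from functorch import make_fx
from functorch.experimental import functionalize

# A toy function whose in-place add_ is the kind of mutation that trips up
# graph-level passes in AOT Autograd.
def f(x):
    y = x.clone()
    y.add_(1)
    return y * 2

# Tracing the functionalized version produces a mutation-free FX graph that a
# downstream compiler can consume safely.
gm = make_fx(functionalize(f))(torch.randn(3))
print(gm.graph)
```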


anijain2305 commented Apr 15, 2022

WIP PR in functorch - pytorch/functorch#703

cc @bdhirsh

Issues seen in functionalization by running it on TorchBench:

  • zero_
  • index
  • _fused_moving_avg_obs_fq_helper
  • tensor_constant
  • fill_
  • getitem

Repros (the snippets below assume `import torch` and `from functorch.experimental import functionalize`)

For zero_


```
def fn1(tangents_1):
    new_empty = torch.ops.aten.new_empty(tangents_1, [1, 3, 2, 10])
    zero_ = torch.ops.aten.zero_(new_empty)
    return (zero_, None)
```
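The fragment above only defines the graph; to actually trigger the failure it can be driven the same way as the later repros. A minimal, hypothetical driver (the input shape is arbitrary, since `new_empty` only borrows dtype/device from its argument):

```
import torch
from functorch.experimental import functionalize

# fn1 as defined in the snippet above
inp = torch.randn(1, 3, 2, 10)
ref = fn1(inp)
res = functionalize(fn1)(inp)  # this call is where the zero_ issue surfaced
```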

For tensor_constant


```
class FxModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('_tensor_constant1', torch.empty([512], dtype=torch.int64))

    def forward(self):
        _tensor_constant1 = self._tensor_constant1
        return (_tensor_constant1, )


print("Starting 4")
mod = FxModule()
def fn4():
    return mod()

ref = fn4()
res = functionalize(fn4)()
for t1, t2 in zip(ref, res):
    assert torch.allclose(t1, t2)
```

For fill_



```
def fn5(primals):
    fill_ = torch.ops.aten.fill_(primals, 1.0);  primals = None
    return fill_


inputs = [torch.empty(torch.Size([10, 204, 320])),]
ref = fn5(*inputs)
res = functionalize(fn5)(*inputs)
```

For getitem

```
def fn6(inp0):
    linspace = torch.linspace(0.125, 0.875, 7)
    getitem_2 = linspace[inp0]
    return getitem_2


inp0 = torch.zeros(torch.Size([6]), dtype=torch.int64)
ref = fn6(inp0)
res = functionalize(fn6)(inp0)
```

For index

```
import torch
from functorch.experimental import functionalize

sizes = [torch.Size([32, 128]), torch.Size([32])]
dtypes = [torch.float32, torch.int64]


def fn(primals_1, primals_2):
    view = torch.ops.aten.view(primals_2, [1, -1]);  primals_2 = None
    select = torch.ops.aten.select(view, 0, 0);  view = None
    index = torch.ops.aten.index(primals_1, [select]);  primals_1 = select = None
    return index


inputs = [torch.ones(size=size, dtype=dtype) for (size, dtype) in zip(sizes, dtypes)]

ref = fn(*inputs)
print(ref)

res = functionalize(fn)(*inputs)
print(res)
```

bdhirsh added a commit to pytorch/pytorch that referenced this issue Apr 15, 2022
This adds support for `zero_()` in the functionalization pass by introducing a new `at::zero()`.

It's identical to `at::zeros_like(t)`, but adding it directly into `native_functions.yaml` allows the functionalization pass to automatically figure out how to undo a mutation from `zero_()`.
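Roughly speaking (my own sketch of the intent, not code from the PR), the rewrite turns the in-place zeroing into a pure value that replaces the original tensor:

```
import torch

# Mutating form, as in the fn1 repro above:
def zero_mutating(t):
    t.zero_()
    return t

# Functionalized form: the new out-of-place zero behaves like zeros_like, so
# torch.zeros_like is used here as a stand-in for at::zero(t).
def zero_functional(t):
    return torch.zeros_like(t)
```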

We probably don't want users to actually use the operator, so I didn't give it a tensor method or a python binding.

From a conversation with @ezyang, we should probably just do the same with `at::_copy()` (even though `at::copy()` will be a pretty unintuitive op).

This also fixes one of the torch dynamo integration issues mentioned in pytorch/torchdynamo#88




bdhirsh added a commit to pytorch/pytorch that referenced this issue Apr 20, 2022
Addresses `fill_` issue in pytorch/torchdynamo#88

This adds out-of-place `fill.Tensor` and `fill.Scalar` ops so that `fill_()` can be properly functionalized.

I ended up giving `fill` a derivative formula, because I think that we want to consider it a "base op" as part of tracing. The decomposition I wrote for it just calls back into `fill_()`, so we don't want to run that decomposition as part of tracing.
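For illustration only, a sketch (not the actual PR code) of what such a decomposition looks like, and why it must not run during functionalization tracing - it reintroduces the very mutation being removed:

```
import torch

# Hedged sketch: an out-of-place fill.Scalar expressed via the in-place op.
def fill_scalar_decomp(self: torch.Tensor, value: float) -> torch.Tensor:
    out = torch.empty_like(self)
    out.fill_(value)  # calls back into the mutating fill_()
    return out
```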




bdhirsh added a commit to pytorch/pytorch that referenced this issue Apr 20, 2022
This should fix `at::index.Tensor` for functionalization and address pytorch/torchdynamo#88.

### The bug
We have a bunch of code that expects each of our `c10::Type` objects to have a unique singleton instance, but it turns out that this isn't true for container types (like `c10::optional`).

It turns out that the `IValue::isOptionalTensorList` function I added earlier in #75716 doesn't always work:
```
  const auto& ty = static_cast<detail::ListImpl*>(payload.u.as_intrusive_ptr)->elementType;
  return ty == c10::getTypePtr<c10::optional<at::Tensor>>();
```

That equality check calls [this](https://github.com/pytorch/pytorch/blob/e20793b05426284d973ac25e563046c037b4e4b2/aten/src/ATen/core/jit_type_base.h#L586) code:
```
template <typename T, typename U>
bool operator==(const SingletonOrSharedTypePtr<T>& x, const SingletonOrSharedTypePtr<U>& y) {
  return (void*)x.get() == (void*)y.get();
}
``` 

Every `c10::Type` can be compared, but it also has its own singleton instance to make equality checks cheaper (just check that the two singleton instances are the same, instead of comparing the full type objects).

You can call `c10::getTypePtr<T>()`, and get back the singleton instance of the pointer to that type. For `optional<T>`, that singleton instance lives [here](https://github.com/pytorch/pytorch/blob/e20793b05426284d973ac25e563046c037b4e4b2/aten/src/ATen/core/jit_type.h#L1871).

When I was debugging, I noticed that `isOptionalTensorList` was returning false because the two pointers being compared were different, but the actual type objects were equal. I was able to repro this with `functionalize()`, but I couldn't repro it directly with a test in core. Changing to this code worked:
```
  const auto& ty = static_cast<detail::ListImpl*>(payload.u.as_intrusive_ptr)->elementType;
  const auto& expected_ty = c10::getTypePtr<c10::optional<at::Tensor>>();
  // compare pointers, but if that fails compare the actual type objects
  return expected_ty == ty || *expected_ty == *ty;
```

So why do we have more than one "static singleton" instance of the same type object? The singleton instance for `c10::optional` lives [here](https://github.com/pytorch/pytorch/blob/e20793b05426284d973ac25e563046c037b4e4b2/aten/src/ATen/core/jit_type.h#L1871), and is defined in a header file (it has to be because it's a template).

I think that's because "function local statics are duplicated across DSO's". We have a similar comment about the dispatcher singleton and why it needs to live in a .cpp file [here](https://github.com/pytorch/pytorch/blob/e20793b05426284d973ac25e563046c037b4e4b2/aten/src/ATen/core/dispatch/Dispatcher.h#L95). Basically, since functorch and pytorch core live in two separate `.so` files, we'll end up with a new static singleton instance for each library.

We can't just move the singleton into a cpp though, since the function is templated - we want one singleton instance *per `optional<T>` type*.

I ended up doing it by converting each `T` into a `TypePtr` object, and keeping a mapping from `TypePtr` objects of the inner type to the static singleton instances in the .cpp file.
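As a loose analogy in Python (mine, not the actual C++ change): route every lookup through one shared, per-process registry keyed by the inner type, so the same singleton comes back no matter which library asks.

```
# Hypothetical illustration of the "one registry in a single translation unit"
# idea; the real fix maps inner TypePtrs to singleton Optional-type instances
# in a .cpp file.
_optional_singletons = {}

def get_optional_type(inner_type):
    if inner_type not in _optional_singletons:
        _optional_singletons[inner_type] = ("Optional", inner_type)
    return _optional_singletons[inner_type]

# Repeated lookups return the *same* object, so identity/pointer-style
# comparisons (like the failing `==` above) hold.
assert get_optional_type(int) is get_optional_type(int)
```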

### Testing?

I couldn't figure out how to repro this failure in core, since I think the `functionalize()` failure comes from the fact that we invoke the `c10::getTypePtr` call from multiple loaded libraries (`libtorch_cpu.so` and `functorch/_C.so`).

I confirmed that with this patch, this code runs successfully (it would break before):

```
import torch
from functorch import make_fx
from functorch.experimental import functionalize

def f(x, y):
    return x[y]

t1 = make_fx(functionalize(f))(torch.arange(3), torch.ones(2, dtype=torch.long))
print("Functionalized:\n", t1.graph)
```


### Generalizing this fix to other container types?

This bug probably affects the other container `c10::Type`s, like List/Dict. I put this up as a PoC first, but if this seems like a reasonable fix then I can use the same fix for the other container types too.




bdhirsh added a commit to pytorch/pytorch that referenced this issue Apr 25, 2022

Our JIT data model currently allows for a class of schemas that:
(1) mutate some of their inputs (based on the aliasing info)
(2) potentially return *new* outputs (unrelated to the mutated inputs)
(3) the mutated inputs are not `self` or `out` kwargs, so the op is neither inplace nor out=

This PR adds support for functionalizing that class of ops, and also adds support for `_fused_moving_avg_obs_fq_helper` to ensure that it all works. (This op is needed for torchdynamo, as it's used in some resnet models on TorchBench. See pytorch/torchdynamo#88 (comment).)

The majority of the work in this PR consisted of:

(1) Getting the functionalization codegen to detect "schemas that have any mutable args", instead of special-casing directly on `SchemaKind.inplace/out`

(2) Ensuring that we properly group mutable ops with their corresponding functional variants (and like the above, you can't rely on `SchemaKind` anymore because the mutable op is neither inplace nor out=)

(3) Removing some assumptions that the codegen made about mutable ops. For example, I used to assume it was always ok to return the `self` or `out=` args - but you can't always do that. Mutable ops are allowed to mutate their inputs by side effect, and return totally different output tensors (that then need to be wrapped by functionalization).

Here's what the codegen'd kernel for `_fused_moving_avg_obs_fq_helper` looks like:

```
    ::std::tuple<at::Tensor,at::Tensor> _fused_moving_avg_obs_fq_helper(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, at::Tensor & running_min, at::Tensor & running_max, at::Tensor & scale, at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant, bool symmetric_quant) {

      at::Tensor self_;
      if (at::functionalization::impl::isFunctionalTensor(self)) {
        at::functionalization::impl::sync(self);
        self_ = at::functionalization::impl::from_functional_tensor(self);
      } else {
        self_ = self;
      }

      at::Tensor observer_on_;
      if (at::functionalization::impl::isFunctionalTensor(observer_on)) {
        at::functionalization::impl::sync(observer_on);
        observer_on_ = at::functionalization::impl::from_functional_tensor(observer_on);
      } else {
        observer_on_ = observer_on;
      }

      at::Tensor fake_quant_on_;
      if (at::functionalization::impl::isFunctionalTensor(fake_quant_on)) {
        at::functionalization::impl::sync(fake_quant_on);
        fake_quant_on_ = at::functionalization::impl::from_functional_tensor(fake_quant_on);
      } else {
        fake_quant_on_ = fake_quant_on;
      }

      at::Tensor running_min_;
      if (at::functionalization::impl::isFunctionalTensor(running_min)) {
        at::functionalization::impl::sync(running_min);
        running_min_ = at::functionalization::impl::from_functional_tensor(running_min);
      } else {
        running_min_ = running_min;
      }

      at::Tensor running_max_;
      if (at::functionalization::impl::isFunctionalTensor(running_max)) {
        at::functionalization::impl::sync(running_max);
        running_max_ = at::functionalization::impl::from_functional_tensor(running_max);
      } else {
        running_max_ = running_max;
      }

      at::Tensor scale_;
      if (at::functionalization::impl::isFunctionalTensor(scale)) {
        at::functionalization::impl::sync(scale);
        scale_ = at::functionalization::impl::from_functional_tensor(scale);
      } else {
        scale_ = scale;
      }

      at::Tensor zero_point_;
      if (at::functionalization::impl::isFunctionalTensor(zero_point)) {
        at::functionalization::impl::sync(zero_point);
        zero_point_ = at::functionalization::impl::from_functional_tensor(zero_point);
      } else {
        zero_point_ = zero_point;
      }
      if (!(true && at::functionalization::impl::isFunctionalTensor(running_min) && at::functionalization::impl::isFunctionalTensor(running_max) && at::functionalization::impl::isFunctionalTensor(scale) && at::functionalization::impl::isFunctionalTensor(zero_point))) {
        if ((false || at::functionalization::impl::isFunctionalTensor(self) || at::functionalization::impl::isFunctionalTensor(observer_on) || at::functionalization::impl::isFunctionalTensor(fake_quant_on))) {
          // case 1: trying to mutate a non functional tensor with a functional tensor is an error
          TORCH_INTERNAL_ASSERT(false,
           "mutating a non-functional tensor with a functional tensor is not allowed.",
           " Please ensure that all of your inputs are wrapped inside of a functionalize() call.");
        } else {
          // case 2: arguments are not functional tensors, so we no-op and redispatch.
          at::AutoDispatchSkipFunctionalize guard;
          ::std::tuple<at::Tensor,at::Tensor> tmp_output = at::_ops::_fused_moving_avg_obs_fq_helper::call(self_, observer_on_, fake_quant_on_, running_min_, running_max_, scale_, zero_point_, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant);
          auto output_0 = at::functionalization::impl::to_functional_tensor(std::get<0>(tmp_output));
          auto output_1 = at::functionalization::impl::to_functional_tensor(std::get<1>(tmp_output));
          return ::std::tuple<at::Tensor,at::Tensor>(output_0, output_1);;
        }
      } else {
        ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> tmp_output;
        {
          at::AutoDispatchSkipFunctionalize guard;
          tmp_output = at::_ops::_fused_moving_avg_obs_fq_helper_functional::call(self_, observer_on_, fake_quant_on_, running_min_, running_max_, scale_, zero_point_, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant);
        }
        at::functionalization::impl::replace_(running_min, std::get<0>(tmp_output));
        at::functionalization::impl::commit_update(running_min);
        at::functionalization::impl::replace_(running_max, std::get<1>(tmp_output));
        at::functionalization::impl::commit_update(running_max);
        at::functionalization::impl::replace_(scale, std::get<2>(tmp_output));
        at::functionalization::impl::commit_update(scale);
        at::functionalization::impl::replace_(zero_point, std::get<3>(tmp_output));
        at::functionalization::impl::commit_update(zero_point);
        auto output_0 = at::functionalization::impl::to_functional_tensor(std::get<4>(tmp_output));
        auto output_1 = at::functionalization::impl::to_functional_tensor(std::get<5>(tmp_output));
        return ::std::tuple<at::Tensor,at::Tensor>(output_0, output_1);
      }
    }
```
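In plainer Python terms (a hedged toy sketch with hypothetical op names, not the real kernel), the pattern the generated code above implements is: call the functional variant, write its extra outputs back into the mutated arguments, and return the fresh outputs.

```
import torch

# Mutable variant: updates running_min by side effect and returns a fresh output.
def toy_op_(x, running_min):
    running_min.copy_(torch.minimum(running_min, x.min()))
    return x * 2

# Functional variant: returns the would-be-mutated value as an extra output.
def toy_op_functional(x, running_min):
    new_running_min = torch.minimum(running_min, x.min())
    return new_running_min, x * 2

# What the functionalization kernel does, conceptually: call the functional
# variant, then "replace_/commit_update" the original argument.
def toy_op_functionalized(x, running_min):
    new_running_min, out = toy_op_functional(x, running_min)
    running_min.copy_(new_running_min)
    return out
```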






[ghstack-poisoned]
bdhirsh added a commit to pytorch/pytorch that referenced this issue Apr 25, 2022
…args"

Our JIT data model currently allows for a class of schemas that:
(1) mutate some of their inputs (based on the aliasing info)
(2) potentially return *new* outputs (unrelated to the mutated inputs)
(3) the mutated inputs are not `self` or `out` kwargs, so the op is neither inplace nor out=

This PR adds support to be able to functionalize that class of ops, and also adds support for `_fused_moving_avg_obs_fq_helper` to ensure that it all works. (This op is needed for torchdynamo, as its used in some resnet models on torchbench. See pytorch/torchdynamo#88 (comment))

The majority of the work in this PR consisted of:

(1) Getting the functionalization codegen to detect "schemas that have any mutable args", instead of special-casing directly on `SchemaKind.inplace/out`

(2) Ensuring that we properly group mutable ops with their corresponding functional variants properly (and like the above, you can't rely on `SchemaKind` anymore because the mutable op is neither inplace nor out=)

(3) Removing some assumptions that the codegen made about mutable ops. For example, I used to assume it was always ok to return the `self` or `out=` args - but you can't always do that. Mutable ops are allowed to mutate their inputs by side effect, and return totally different output tensors (that then need to be wrapped by functionalization).

Here's what the codegen'd kernel for `_fused_moving_avg_obs_fq_helper` looks like:

```
    ::std::tuple<at::Tensor,at::Tensor> _fused_moving_avg_obs_fq_helper(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, at::Tensor & running_min, at::Tensor & running_max, at::Tensor & scale, at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant, bool symmetric_quant) {

      at::Tensor self_;
      if (at::functionalization::impl::isFunctionalTensor(self)) {
        at::functionalization::impl::sync(self);
        self_ = at::functionalization::impl::from_functional_tensor(self);
      } else {
        self_ = self;
      }

      at::Tensor observer_on_;
      if (at::functionalization::impl::isFunctionalTensor(observer_on)) {
        at::functionalization::impl::sync(observer_on);
        observer_on_ = at::functionalization::impl::from_functional_tensor(observer_on);
      } else {
        observer_on_ = observer_on;
      }

      at::Tensor fake_quant_on_;
      if (at::functionalization::impl::isFunctionalTensor(fake_quant_on)) {
        at::functionalization::impl::sync(fake_quant_on);
        fake_quant_on_ = at::functionalization::impl::from_functional_tensor(fake_quant_on);
      } else {
        fake_quant_on_ = fake_quant_on;
      }

      at::Tensor running_min_;
      if (at::functionalization::impl::isFunctionalTensor(running_min)) {
        at::functionalization::impl::sync(running_min);
        running_min_ = at::functionalization::impl::from_functional_tensor(running_min);
      } else {
        running_min_ = running_min;
      }

      at::Tensor running_max_;
      if (at::functionalization::impl::isFunctionalTensor(running_max)) {
        at::functionalization::impl::sync(running_max);
        running_max_ = at::functionalization::impl::from_functional_tensor(running_max);
      } else {
        running_max_ = running_max;
      }

      at::Tensor scale_;
      if (at::functionalization::impl::isFunctionalTensor(scale)) {
        at::functionalization::impl::sync(scale);
        scale_ = at::functionalization::impl::from_functional_tensor(scale);
      } else {
        scale_ = scale;
      }

      at::Tensor zero_point_;
      if (at::functionalization::impl::isFunctionalTensor(zero_point)) {
        at::functionalization::impl::sync(zero_point);
        zero_point_ = at::functionalization::impl::from_functional_tensor(zero_point);
      } else {
        zero_point_ = zero_point;
      }
      if (!(true && at::functionalization::impl::isFunctionalTensor(running_min) && at::functionalization::impl::isFunctionalTensor(running_max) && at::functionalization::impl::isFunctionalTensor(scale) && at::functionalization::impl::isFunctionalTensor(zero_point))) {
        if ((false || at::functionalization::impl::isFunctionalTensor(self) || at::functionalization::impl::isFunctionalTensor(observer_on) || at::functionalization::impl::isFunctionalTensor(fake_quant_on))) {
          // case 1: trying to mutate a non functional tensor with a functional tensor is an error
          TORCH_INTERNAL_ASSERT(false,
           "mutating a non-functional tensor with a functional tensor is not allowed.",
           " Please ensure that all of your inputs are wrapped inside of a functionalize() call.");
        } else {
          // case 2: arguments are not functional tensors, so we no-op and redispatch.
          at::AutoDispatchSkipFunctionalize guard;
          ::std::tuple<at::Tensor,at::Tensor> tmp_output = at::_ops::_fused_moving_avg_obs_fq_helper::call(self_, observer_on_, fake_quant_on_, running_min_, running_max_, scale_, zero_point_, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant);
          auto output_0 = at::functionalization::impl::to_functional_tensor(std::get<0>(tmp_output));
          auto output_1 = at::functionalization::impl::to_functional_tensor(std::get<1>(tmp_output));
          return ::std::tuple<at::Tensor,at::Tensor>(output_0, output_1);;
        }
      } else {
        ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> tmp_output;
        {
          at::AutoDispatchSkipFunctionalize guard;
          tmp_output = at::_ops::_fused_moving_avg_obs_fq_helper_functional::call(self_, observer_on_, fake_quant_on_, running_min_, running_max_, scale_, zero_point_, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant);
        }
        at::functionalization::impl::replace_(running_min, std::get<0>(tmp_output));
        at::functionalization::impl::commit_update(running_min);
        at::functionalization::impl::replace_(running_max, std::get<1>(tmp_output));
        at::functionalization::impl::commit_update(running_max);
        at::functionalization::impl::replace_(scale, std::get<2>(tmp_output));
        at::functionalization::impl::commit_update(scale);
        at::functionalization::impl::replace_(zero_point, std::get<3>(tmp_output));
        at::functionalization::impl::commit_update(zero_point);
        auto output_0 = at::functionalization::impl::to_functional_tensor(std::get<4>(tmp_output));
        auto output_1 = at::functionalization::impl::to_functional_tensor(std::get<5>(tmp_output));
        return ::std::tuple<at::Tensor,at::Tensor>(output_0, output_1);
      }
```






[ghstack-poisoned]
bdhirsh added a commit to pytorch/pytorch that referenced this issue Apr 26, 2022
…args"

Our JIT data model currently allows for a class of schemas that:
(1) mutate some of their inputs (based on the aliasing info)
(2) potentially return *new* outputs (unrelated to the mutated inputs)
(3) the mutated inputs are not `self` or `out` kwargs, so the op is neither inplace nor out=

This PR adds support to be able to functionalize that class of ops, and also adds support for `_fused_moving_avg_obs_fq_helper` to ensure that it all works. (This op is needed for torchdynamo, as its used in some resnet models on torchbench. See pytorch/torchdynamo#88 (comment))

The majority of the work in this PR consisted of:

(1) Getting the functionalization codegen to detect "schemas that have any mutable args", instead of special-casing directly on `SchemaKind.inplace/out`

(2) Ensuring that we properly group mutable ops with their corresponding functional variants properly (and like the above, you can't rely on `SchemaKind` anymore because the mutable op is neither inplace nor out=)

(3) Removing some assumptions that the codegen made about mutable ops. For example, I used to assume it was always ok to return the `self` or `out=` args - but you can't always do that. Mutable ops are allowed to mutate their inputs by side effect, and return totally different output tensors (that then need to be wrapped by functionalization).

Here's what the codegen'd kernel for `_fused_moving_avg_obs_fq_helper` looks like:

```
    ::std::tuple<at::Tensor,at::Tensor> _fused_moving_avg_obs_fq_helper(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, at::Tensor & running_min, at::Tensor & running_max, at::Tensor & scale, at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant, bool symmetric_quant) {

      at::Tensor self_;
      if (at::functionalization::impl::isFunctionalTensor(self)) {
        at::functionalization::impl::sync(self);
        self_ = at::functionalization::impl::from_functional_tensor(self);
      } else {
        self_ = self;
      }

      at::Tensor observer_on_;
      if (at::functionalization::impl::isFunctionalTensor(observer_on)) {
        at::functionalization::impl::sync(observer_on);
        observer_on_ = at::functionalization::impl::from_functional_tensor(observer_on);
      } else {
        observer_on_ = observer_on;
      }

      at::Tensor fake_quant_on_;
      if (at::functionalization::impl::isFunctionalTensor(fake_quant_on)) {
        at::functionalization::impl::sync(fake_quant_on);
        fake_quant_on_ = at::functionalization::impl::from_functional_tensor(fake_quant_on);
      } else {
        fake_quant_on_ = fake_quant_on;
      }

      at::Tensor running_min_;
      if (at::functionalization::impl::isFunctionalTensor(running_min)) {
        at::functionalization::impl::sync(running_min);
        running_min_ = at::functionalization::impl::from_functional_tensor(running_min);
      } else {
        running_min_ = running_min;
      }

      at::Tensor running_max_;
      if (at::functionalization::impl::isFunctionalTensor(running_max)) {
        at::functionalization::impl::sync(running_max);
        running_max_ = at::functionalization::impl::from_functional_tensor(running_max);
      } else {
        running_max_ = running_max;
      }

      at::Tensor scale_;
      if (at::functionalization::impl::isFunctionalTensor(scale)) {
        at::functionalization::impl::sync(scale);
        scale_ = at::functionalization::impl::from_functional_tensor(scale);
      } else {
        scale_ = scale;
      }

      at::Tensor zero_point_;
      if (at::functionalization::impl::isFunctionalTensor(zero_point)) {
        at::functionalization::impl::sync(zero_point);
        zero_point_ = at::functionalization::impl::from_functional_tensor(zero_point);
      } else {
        zero_point_ = zero_point;
      }
      if (!(true && at::functionalization::impl::isFunctionalTensor(running_min) && at::functionalization::impl::isFunctionalTensor(running_max) && at::functionalization::impl::isFunctionalTensor(scale) && at::functionalization::impl::isFunctionalTensor(zero_point))) {
        if ((false || at::functionalization::impl::isFunctionalTensor(self) || at::functionalization::impl::isFunctionalTensor(observer_on) || at::functionalization::impl::isFunctionalTensor(fake_quant_on))) {
          // case 1: trying to mutate a non functional tensor with a functional tensor is an error
          TORCH_INTERNAL_ASSERT(false,
           "mutating a non-functional tensor with a functional tensor is not allowed.",
           " Please ensure that all of your inputs are wrapped inside of a functionalize() call.");
        } else {
          // case 2: arguments are not functional tensors, so we no-op and redispatch.
          at::AutoDispatchSkipFunctionalize guard;
          ::std::tuple<at::Tensor,at::Tensor> tmp_output = at::_ops::_fused_moving_avg_obs_fq_helper::call(self_, observer_on_, fake_quant_on_, running_min_, running_max_, scale_, zero_point_, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant);
          auto output_0 = at::functionalization::impl::to_functional_tensor(std::get<0>(tmp_output));
          auto output_1 = at::functionalization::impl::to_functional_tensor(std::get<1>(tmp_output));
          return ::std::tuple<at::Tensor,at::Tensor>(output_0, output_1);;
        }
      } else {
        ::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor> tmp_output;
        {
          at::AutoDispatchSkipFunctionalize guard;
          tmp_output = at::_ops::_fused_moving_avg_obs_fq_helper_functional::call(self_, observer_on_, fake_quant_on_, running_min_, running_max_, scale_, zero_point_, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant);
        }
        at::functionalization::impl::replace_(running_min, std::get<0>(tmp_output));
        at::functionalization::impl::commit_update(running_min);
        at::functionalization::impl::replace_(running_max, std::get<1>(tmp_output));
        at::functionalization::impl::commit_update(running_max);
        at::functionalization::impl::replace_(scale, std::get<2>(tmp_output));
        at::functionalization::impl::commit_update(scale);
        at::functionalization::impl::replace_(zero_point, std::get<3>(tmp_output));
        at::functionalization::impl::commit_update(zero_point);
        auto output_0 = at::functionalization::impl::to_functional_tensor(std::get<4>(tmp_output));
        auto output_1 = at::functionalization::impl::to_functional_tensor(std::get<5>(tmp_output));
        return ::std::tuple<at::Tensor,at::Tensor>(output_0, output_1);
      }
    }
```
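As a quick user-level complement to the codegen output above, here is a minimal sketch (not part of this PR) of inspecting what functionalization does to an in-place op, assuming a functorch/pytorch build where `zero_()` functionalization has already landed:

```
import torch
from functorch import make_fx
from functorch.experimental import functionalize

# Minimal sketch: trace a function that mutates an intermediate in place and
# inspect the functionalized graph (assumes zero_() functionalization support).
def f(x):
    y = x.clone()
    y.zero_()        # in-place mutation on an intermediate tensor
    return y + 1

gm = make_fx(functionalize(f))(torch.randn(4))
print(gm.graph)      # the printed graph should contain no trailing-underscore (mutating) ops
```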






bdhirsh added a commit to pytorch/pytorch that referenced this issue Apr 26, 2022
bdhirsh added a commit to pytorch/pytorch that referenced this issue Apr 27, 2022
@ciciewang

@bdhirsh is working on it right now.

bdhirsh added a commit to pytorch/pytorch that referenced this issue May 10, 2022
bdhirsh added a commit to pytorch/pytorch that referenced this issue May 10, 2022
@chekangliang chekangliang added the enhancement New feature or request label May 14, 2022
@anijain2305
Contributor Author

anijain2305 commented May 25, 2022

New issues for @bdhirsh:

  • index

The repro for `index` is as follows:

import torch
from torch.nn import *
from functorch.experimental import functionalize

class FxModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('_tensor_constant0', torch.zeros([4, 1], dtype=torch.int64))
        self.register_buffer('_tensor_constant1', torch.zeros([4, 359613], dtype=torch.int64))

    
    
    def forward(self, primals_1, primals_2):
        view = torch.ops.aten.view(primals_2, [4, 359613]);  primals_2 = None
        split_with_sizes = torch.ops.aten.split_with_sizes(view, [269952, 67488, 16872, 4218, 1083], 1)
        getitem = split_with_sizes[0]
        getitem_1 = split_with_sizes[1]
        getitem_2 = split_with_sizes[2]
        getitem_3 = split_with_sizes[3]
        getitem_4 = split_with_sizes[4];  split_with_sizes = None
        topk = torch.ops.aten.topk(getitem, 1000, 1);  getitem = None
        getitem_5 = topk[0]
        getitem_6 = topk[1];  topk = None
        add = torch.ops.aten.add(getitem_6, 0);  getitem_6 = None
        topk_1 = torch.ops.aten.topk(getitem_1, 1000, 1);  getitem_1 = None
        getitem_7 = topk_1[0]
        getitem_8 = topk_1[1];  topk_1 = None
        add_1 = torch.ops.aten.add(getitem_8, 269952);  getitem_8 = None
        topk_2 = torch.ops.aten.topk(getitem_2, 1000, 1);  getitem_2 = None
        getitem_9 = topk_2[0]
        getitem_10 = topk_2[1];  topk_2 = None
        add_2 = torch.ops.aten.add(getitem_10, 337440);  getitem_10 = None
        topk_3 = torch.ops.aten.topk(getitem_3, 1000, 1);  getitem_3 = None
        getitem_11 = topk_3[0]
        getitem_12 = topk_3[1];  topk_3 = None
        add_3 = torch.ops.aten.add(getitem_12, 354312);  getitem_12 = None
        topk_4 = torch.ops.aten.topk(getitem_4, 1000, 1);  getitem_4 = None
        getitem_13 = topk_4[0]
        getitem_14 = topk_4[1];  topk_4 = None
        add_4 = torch.ops.aten.add(getitem_14, 358530);  getitem_14 = None
        cat = torch.ops.aten.cat([add, add_1, add_2, add_3, add_4], 1);  add = add_1 = add_2 = add_3 = add_4 = None
        _tensor_constant0 = self._tensor_constant0
        index = torch.ops.aten.index(view, [_tensor_constant0, cat]);  view = _tensor_constant0 = None
        _tensor_constant1 = self._tensor_constant1
        _tensor_constant0_1 = self._tensor_constant0
        index_1 = torch.ops.aten.index(_tensor_constant1, [_tensor_constant0_1, cat]);  _tensor_constant1 = _tensor_constant0_1 = None
        _tensor_constant0_2 = self._tensor_constant0
        index_2 = torch.ops.aten.index(primals_1, [_tensor_constant0_2, cat]);  primals_1 = _tensor_constant0_2 = cat = None
        sigmoid = torch.ops.aten.sigmoid(index);  index = None
        select = torch.ops.aten.select(index_2, 0, 0)
        select_1 = torch.ops.aten.select(index_2, 0, 1)
        select_2 = torch.ops.aten.select(index_2, 0, 2)
        select_3 = torch.ops.aten.select(index_2, 0, 3);  index_2 = None
        select_4 = torch.ops.aten.select(sigmoid, 0, 0)
        select_5 = torch.ops.aten.select(sigmoid, 0, 1)
        select_6 = torch.ops.aten.select(sigmoid, 0, 2)
        select_7 = torch.ops.aten.select(sigmoid, 0, 3);  sigmoid = None
        select_8 = torch.ops.aten.select(index_1, 0, 0)
        select_9 = torch.ops.aten.select(index_1, 0, 1)
        select_10 = torch.ops.aten.select(index_1, 0, 2)
        select_11 = torch.ops.aten.select(index_1, 0, 3);  index_1 = None
        slice_1 = torch.ops.aten.slice(select, 1, 0, 9223372036854775807, 2)
        slice_2 = torch.ops.aten.slice(select, 1, 1, 9223372036854775807, 2);  select = None
        clamp = torch.ops.aten.clamp(slice_1, 0, 1199);  slice_1 = None
        clamp_1 = torch.ops.aten.clamp(slice_2, 0, 799);  slice_2 = None
        stack = torch.ops.aten.stack([clamp, clamp_1], 2);  clamp = clamp_1 = None
        view_1 = torch.ops.aten.view(stack, [5000, 4]);  stack = None
        return [select_1, select_5, select_9, select_2, select_6, select_10, select_3, select_7, select_11, view_1, select_4, select_8, None, None]
        

primals_sizes = [torch.Size([4, 359613, 4]), torch.Size([1438452, 1])]
primals_dtypes = [torch.float32, torch.float32]



mod = FxModule().to(device="cpu")
primals = [torch.empty(size, dtype=dtype, device="cpu") for (size, dtype) in zip(primals_sizes, primals_dtypes)]

inputs = primals
print(mod(*inputs))



res = functionalize(mod)(*inputs)
print(res)
A second, separate repro exercising `zero_`, `as_strided`, `copy_`, and `new_empty_strided`:

import torch
from functorch.experimental import functionalize


def fn(x, y):
    zero_ = torch.ops.aten.zero_(x)
    as_strided = torch.ops.aten.as_strided(zero_, [16, 64, 128, 128], [1048576, 16384, 128, 1], 0)
    copy__34 = torch.ops.aten.copy_(as_strided, y);  as_strided = None
    as_strided_1 = torch.ops.aten.as_strided(zero_, [1, 1024, 128, 128], [16777216, 16384, 128, 1], 0);  zero_ = None
    new_empty_strided = torch.ops.aten.new_empty_strided(as_strided_1, [1, 1024, 128, 128], [16777216, 16384, 128, 1])
    return new_empty_strided

sizes = [(1, 1024, 128, 128), (16, 64, 128, 128)]
dtypes = [torch.float32, torch.float32]

inputs = [torch.randn(size=size, dtype=dtype) for (size, dtype) in zip(sizes, dtypes)]

ref = fn(*inputs)
res = functionalize(fn)(*inputs)
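One note on checking this repro once the functionalized run succeeds: the output comes from `new_empty_strided`, whose contents are uninitialized, so comparing `ref` and `res` by value is not meaningful; only metadata can be checked, e.g.:

```
# new_empty_strided returns uninitialized memory, so compare only metadata.
assert ref.shape == res.shape
assert ref.stride() == res.stride()
```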

@ezyang
Contributor

ezyang commented Jul 23, 2022

@bdhirsh index still triggers assert error

RuntimeError: functional_count == 0 || nonfunctional_count == 0 INTERNAL ASSERT FAILED at "/raid/ezyang/pytorch-scratch2/aten/src/ATen/FunctionalTensorWrapper.cpp":549, please report a bug to PyTorch. Functionalization encountered a list of tensors where some are functional and some are not, which is not currently unsupported.

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this issue Jul 29, 2022
…sorlists (#82326)

There's an existing assert in functionalization that's probably too restrictive: when you pass an op a list of tensors that mixes functional and nonfunctional tensors, we should just selectively unwrap the functional tensors and call the op rather than erroring.

I added a test for it in `test_functionalization.py` - it looks like this behavior can also show up when tracing with `make_fx()`, when constants get baked in as module properties, which don't get wrapped up when you try to functionalize the module's forward function.

Should fix the last of pytorch/torchdynamo#88 (comment)

Pull Request resolved: #82326
Approved by: https://github.com/ezyang
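For reference, a minimal hypothetical sketch of the mixed-tensorlist situation described above (not the test added in the PR): a tensor captured from the enclosing scope stays unwrapped, while the function's inputs are wrapped by `functionalize()`, so `aten.index` receives a list mixing functional and non-functional tensors.

```
import torch
from functorch.experimental import functionalize

# idx_rows is captured from the enclosing scope (like a baked-in module constant),
# so functionalize() does not wrap it; idx_cols is a function input, so it does.
idx_rows = torch.zeros(4, dtype=torch.int64)

def fn(x, idx_cols):
    # the indices list mixes a non-functional tensor with a functional one
    return torch.ops.aten.index(x, [idx_rows, idx_cols])

x = torch.randn(4, 8)
idx_cols = torch.arange(4)
out = functionalize(fn)(x, idx_cols)  # previously hit the INTERNAL ASSERT above
print(out.shape)
```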
facebook-github-bot pushed a commit to pytorch/pytorch that referenced this issue Aug 1, 2022
@anijain2305
Contributor Author

Closed in favor of pytorch/pytorch#93621
