@bdhirsh bdhirsh commented Aug 10, 2021

This PR adds the bare bones of a generic functionalization pass through the dispatcher that works for both functorch and XLA. It's based on @ailzhang's prototype, which lives here and is nicely documented here. That prototype in turn is heavily inspired by pytorch/xla's implementation of functionalization! (See the core parts here.)

I won't go into the details of the aliasing removal, which you can see in the original prototype. Instead I'll focus more on the API to integrate with it (as a composable transform, and as a backend like XLA), the mutation removal logic, and the followups.

I have corresponding test branches in functorch and pt/xla.

Working examples
I used two basic examples to help figure out how to make changes, using the corresponding changes in functorch/xla that I linked above. I'm planning on testing some more complicated functorch examples soon:

# functorch example
import torch
from functorch import make_fx, grad, vmap, functionalize

def f(x):  # tests that inputs are still successfully mutated
    tmp = torch.ones(4)
    y = x.view(4)
    y.add_(tmp)
    return x

def f2(x):  # tests the free variable mutation case, which currently breaks in functorch
    tmp = torch.ones(4)
    tmp.add_(x)
    return tmp

batched_input = torch.ones(2, 4)
vmap(functionalize(f))(batched_input)
vmap(functionalize(f2))(batched_input)

# XLA example
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.ones(2, 2, device=device)
b = a.view(4)
a.add_(1)  # successfully mutates b

API for functorch / backends

The main difference between the two is that functorch wants a composable transform (calling it functionalize() for now), which involves some notion of levels and wrapping. Backends like XLA only want the direct mutation / view removal.

I've factored those pieces into a FunctionalTensorImplBase (inherits from TensorImpl) and a FunctionalTensorWrapper (inherits from the base). The base class deals with mutation and view removal, while the subclass knows about levels and wrapping tensors. Backend tensors like XLA will also inherit from FunctionalTensorImplBase to get opted into the functionalization pass.

functionalize() is implemented similar to vmap/grad in the functorch repo, and just involves wrapping inputs/outputs in a FunctionalTensorWrapper of the appropriate level (and making DynamicLayerFrontMode aware of the Functionalization key). The (unpolished) functorch-side changes currently live in a branch here. FunctionalTensorWrapper also overrides all virtual methods to call into its wrapped value - I'm not sure if this is overkill, but the goal is that aside from the functionalization pass, everything else in pytorch shouldn't see the wrapper. We can do that for dispatcher-aware operators by explicitly unwrapping, but calling e.g. t.sizes() or t.numel() should also unwrap the tensor first.

I moved FunctionalTensorWrapper into the functorch branch, so core is unaware of it. The invariant is that when the functionalization kernels redispatch on an operator, the returned tensor is always a FunctionalTensorImplBase subclass. For backends like XLA, this happens automatically (since we get back an XLATensor, which inherits from FunctionalTensorImplBase). For functorch, the DynamicLayerBackFallback kernel is responsible for unwrapping inputs and wrapping up the outputs in a FunctionalTensorWrapper. This is necessary in order to properly record the current composable transform level in the wrapper (although if the dynamic leveling dispatcher logic eventually moves into core, then this can move into core too).

Integration with backends is currently implemented by updating your TensorImpl subclass to inherit from FunctionalTensorImplBase instead of TensorImpl, and adding an overridden replace_() method, which tells us how to "re-use" your tensor. I have a branch with XLA-side changes here (it's very small, and still a POC - it also doesn't remove any xla-side alias handling yet, since we need to add support for more views in core first).

Mutation removal logic

There's a new at::replace_ op in native_functions.yaml. The only reason it needs to be in the yaml is because I think it needs to show up in traces from functorch (functorch would like "functionalized" traces. We can't do that directly, but it should be pretty easy to convert a graph with a bunch of replace_ calls in it to its functionalized version).
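To illustrate why a trace containing replace_ calls is easy to convert, here is a minimal sketch over a toy trace format. The node layout `(op_name, output_name, input_names)` and the helper `functionalize_trace` are hypothetical stand-ins invented for this example; they are not functorch's actual trace representation.

```python
from typing import Dict, List, Tuple

# Hypothetical toy trace node: (op_name, output_name, input_names).
# A replace_ node is written as ("replace_", "", (target, source)).
Node = Tuple[str, str, Tuple[str, ...]]


def functionalize_trace(trace: List[Node]) -> List[Node]:
    """Drop replace_ nodes and redirect later uses of the replaced name."""
    current: Dict[str, str] = {}  # name -> name of the value it now holds
    out: List[Node] = []
    for op, result, args in trace:
        if op == "replace_":
            target, source = args
            # From here on, reads of `target` should read `source` instead.
            current[target] = current.get(source, source)
            continue
        resolved = tuple(current.get(a, a) for a in args)
        out.append((op, result, resolved))
    return out


# Toy trace for: tmp = ones(); tmp.add_(x); out = tmp * tmp
trace = [
    ("ones", "tmp", ()),
    ("add", "t0", ("tmp", "x")),
    ("replace_", "", ("tmp", "t0")),
    ("mul", "out", ("tmp", "tmp")),
]
functional_trace = functionalize_trace(trace)
```

After the rewrite, the `mul` node reads `t0` directly and no node mutates an existing value.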

The idea behind replace_ is that it knows how to swap out the data backing a tensor with fresh data. For cpu/cuda (TensorImpl), that corresponds to replacing a storage pointer (and associated metadata). For xla (XLATensorImpl), that corresponds to swapping out IR.

However, the idea behind the FunctionalTensorWrapper wrapper is that we can't just swap out metadata on a TensorImpl for functorch. Functorch has to deal with mixed inputs: e.g. cpu_tensor.add_(batched_tensor) should "promote" the input to be a batched tensor. That corresponds to swapping out cpu_tensor's TensorImpl object with the BatchedTensorImpl result of summing the two together.

FunctionalTensorWrapper does this by storing a full Tensor, and fully swapping it out for a new tensor when the corresponding new and old TensorImpl* types don't match up.
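A rough sketch of that swapping rule, in plain Python. `PlainImpl`, `BatchedImpl`, and `Wrapper` are hypothetical stand-ins for TensorImpl, BatchedTensorImpl, and FunctionalTensorWrapper; the real classes are C++ and carry far more state.

```python
class PlainImpl:
    """Stand-in for an ordinary TensorImpl (hypothetical)."""

    def __init__(self, data):
        self.data = data

    def replace_(self, other):
        # Same-type case: swap the backing data in place.
        self.data = other.data


class BatchedImpl:
    """Stand-in for a BatchedTensorImpl (hypothetical)."""

    def __init__(self, data):
        self.data = data

    def replace_(self, other):
        self.data = other.data


class Wrapper:
    """Stand-in for FunctionalTensorWrapper: holds a full inner value."""

    def __init__(self, value):
        self.value = value

    def replace_(self, new_value):
        if type(new_value) is type(self.value):
            # Types match: delegate to the inner metadata swap.
            self.value.replace_(new_value)
        else:
            # Types differ (e.g. plain -> batched): swap the whole object.
            self.value = new_value


w = Wrapper(PlainImpl([1, 1]))
inner = w.value
w.replace_(PlainImpl([2, 2]))      # same type: inner object is reused
w.replace_(BatchedImpl([[3, 3]]))  # different type: inner object is replaced
```

The second call is the "promotion" case described above: the wrapper keeps its identity while the tensor it holds changes type.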

Right now this is implemented with two different methods:

  • replace_(TensorImpl* other_impl). This is a virtual method on TensorImpl that knows how to do the metadata swap as long as the TensorImpl argument is of the same type (also implemented for BatchedTensorImpl and TensorWrapper in functorch).
  • replace_(Tensor& other). This is a composite kernel exposed in native_functions.yaml so we can trace it. It calls into the virtual method.

I made the dispatcher-aware replace_(Tensor) op a composite kernel, which required registering fallthroughs to each functorch pass in order to avoid going to their boxed kernels. Another option would have been to not do that and instead have a single replace_() operator, with different kernels registered for functionalize/vmap/grad/cpu/xla. I mostly opted for the current approach because it was easier for me to picture what replacing a FunctionalTensorWrapper means as repeatedly unwrapping and replacing the internals, vs. doing a bunch of redispatch trips through the dispatcher (which would also be slower).

[functorch] should the Alias object hold the wrapper or the unwrapped tensor?
(See the comments for details - this is resolved. I'm adding new {view}_copy operator variants for each view op, that knows to skip the functionalization pass).
When you create a view, we fork off the original tensor into an Alias object that both the original tensor and the view tensor are now aware of. Right now, I have that alias object hold a clone of the original tensor (which in the functorch case is a FunctionalTensorWrapper). I did that because it made the view-handling logic more consistent (we don't need to special-case functorch), but there's a problem where we need to perform a bunch of view operations on the alias tensor when we sync. If the alias is a FunctionalTensorWrapper, we'll call back into the functionalization pass machinery and loop infinitely (creating an alias for the alias). That's true for xla too, but we can just add an exclude guard before calling sync_() in the xla case. We can't do that for functorch, since the DynamicLayer logic overwrites TLS at each layer.

One way to get around this would have been to create separate {view}_copy operators for every view op, and register fallthroughs for them in the functionalization pass. Instead, I added a bit on the FunctionalTensorImplBase class to tell us if it's actually stored inside of an alias, so the functionalization pass will know to skip the view machinery if the bit is set.

I tried to call out other major details in specific comments in the code.

Codegen output

Inplace ops are all codegen'd (at least, the ones that have a functional version). View ops need to be added one-by-one, so I've only added at::view for now. The codegen output looks like this for view:

at::Tensor view(c10::DispatchKeySet ks, const at::Tensor & self, at::IntArrayRef size) {
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.key_set().has(c10::DispatchKey::Functionalize));
        at::Tensor out;
        {
          at::AutoDispatchBelowFunctionalize guard;
          auto tmp_output = at::redispatch::view(ks & c10::after_func_keyset, self, size);
          out = tmp_output.clone();
        }
        // See Note [Marking Alias Tensors]
        if (!at::functionalization::impl::is_alias_tensor(out)) {
          // TODO we'll probably want a separate function for each view op that creates the corresponding ViewMeta.
          at::ViewMeta view_meta = at::functionalization::impl::get_meta_view(self, size);
          at::functionalization::impl::set_view_meta(out, self, view_meta);
        }
        return out;
}

And for an example inplace op, the codegen looks like this:

at::Tensor & add__Tensor(c10::DispatchKeySet ks, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
        {
            at::AutoDispatchBelowFunctionalize guard;
            auto tmp_output = at::redispatch::add(ks & c10::after_func_keyset, self, other, alpha);
            self.replace_(tmp_output);
            TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.key_set().has(c10::DispatchKey::Functionalize));
            at::functionalization::impl::maybe_add_update(self);
        }
        return self;
}
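The in-place kernel above follows a fixed pattern: run the out-of-place op, swap the result into self, then record the update. A minimal Python sketch of that pattern follows. `ToyTensor`, `functionalized_inplace`, and the `pending_updates` counter are hypothetical names for illustration; the counter only stands in for the bookkeeping that maybe_add_update does in the real pass.

```python
class ToyTensor:
    """Hypothetical stand-in for a functional tensor."""

    def __init__(self, values):
        self.values = list(values)
        self.pending_updates = 0  # stand-in for maybe_add_update bookkeeping

    def add(self, other, alpha=1):
        # Out-of-place op: returns a fresh tensor, mutates nothing.
        return ToyTensor(a + alpha * b for a, b in zip(self.values, other.values))

    def replace_(self, other):
        # Swap the backing data for the freshly computed data.
        self.values = other.values
        return self


def functionalized_inplace(functional_op):
    """Build an in-place method from its out-of-place counterpart."""

    def inplace(self, *args, **kwargs):
        tmp_output = functional_op(self, *args, **kwargs)
        self.replace_(tmp_output)
        self.pending_updates += 1
        return self

    return inplace


# add_ is derived mechanically from add, mirroring the codegen'd kernel.
ToyTensor.add_ = functionalized_inplace(ToyTensor.add)

x = ToyTensor([1, 2])
y = x.add_(ToyTensor([10, 20]), alpha=2)
```

Because the wrapper only needs the functional op, the same helper could be applied to any in-place op that has an out-of-place equivalent, which is why those kernels can be generated rather than written by hand.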

Printing tensors
In the XLA case, printing tensors should just work - the call to _tensor_str() eventually just calls to('cpu'), which hits the functionalization fallback and syncs everything.

In the functorch case, functorch overrides the _tensor_str function to first recursively unwrap the tensor wrappers before calling the original version of _tensor_str(), so I added logic in functorch to be aware of unwrapping FunctionalTensorWrapper objects, and sync before unwrapping. I also have the functionalize() pass call sync_() on every tensor input after the pass completes, to ensure that inputs get mutated correctly.

The handling for when to wrap up outputs is a little fragile though, because of factory functions + printing. For factory functions, we need to make sure that we still wrap output tensors even if there are no input tensor arguments. For printing, we need to make sure not to wrap arguments. This case is distinguished by the fact that none of the tensor input arguments are wrapped. That means that stuff can break if any factory functions are called inside of _tensor_str(), but it looks like that currently isn't the case.

Other stuff

  • The functionalization pass has a boxed fallback (in VariableFallback.cpp), but it shouldn’t be too hard to move that to codegen, since all it does is sync all of the input tensors.
  • For any unsupported view/mutation ops (either view ops that aren’t implemented yet, or mutation ops that don’t currently have an out-of-place equivalent), I codegen a kernel that’s pretty similar to the boxed fallback - it just syncs the inputs and redispatches. It also prints a warning, to help figure out which mutation ops we need to add out-of-place versions for.
  • I defined a bunch of helper functions in the at::functionalization::impl namespace, mostly as utility functions to make the codegen easier.
  • Right now I have an enum for each view op. The Alias contains a stack of updates, and each update contains a stack of ViewMetas, explaining what view ops were run on the base tensor to get to the view before the mutation occurred. The logic to sync mutations across aliases involves replaying the views in reverse, to figure out what the base tensor looks like after every mutation. The enums are a little bit ugly, but I’m not sure of a significantly more elegant way to represent them. I also listed out the full set of enum values, but that might be too presumptive.
  • Each view op needs to store some extra info in order to replay it in reverse - this will probably need to be implemented separately per view - right now I have an at::functionalization::impl::get_meta_{view}(…) function that knows which information to store (called by the codegen), and I’m planning on trying to implement similar functions for the other view ops.
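The reverse replay described in the list above can be sketched with a toy model in which tensors are flat Python lists and the only view op is a slice. `SliceMeta`, `Update`, and `Alias` here are hypothetical simplifications for illustration, not the classes in this PR.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class SliceMeta:
    """Toy ViewMeta: records enough to replay a slice view in reverse."""
    start: int
    stop: int

    def forward(self, parent: List[int]) -> List[int]:
        return parent[self.start:self.stop]

    def reverse(self, parent: List[int], mutated_view: List[int]) -> List[int]:
        # Rebuild the parent with the mutated view scattered back in.
        return parent[:self.start] + list(mutated_view) + parent[self.stop:]


@dataclass
class Update:
    view_metas: List[SliceMeta]  # views applied to the base, in order
    new_value: List[int]         # contents of the view after the mutation


@dataclass
class Alias:
    base: List[int]
    updates: List[Update] = field(default_factory=list)

    def sync(self) -> List[int]:
        for update in self.updates:
            # Recreate the parent of each view by replaying views forward.
            parents = [self.base]
            for meta in update.view_metas[:-1]:
                parents.append(meta.forward(parents[-1]))
            # Walk the views in reverse, pushing the mutation toward the base.
            value = update.new_value
            for meta, parent in zip(reversed(update.view_metas), reversed(parents)):
                value = meta.reverse(parent, value)
            self.base = value
        self.updates.clear()
        return self.base


alias = Alias(base=[0, 1, 2, 3, 4, 5])
# view1 = base[2:6]; view2 = view1[1:3]; view2 is then mutated to [30, 40].
alias.updates.append(Update([SliceMeta(2, 6), SliceMeta(1, 3)], [30, 40]))
synced = alias.sync()
```

Each view op in the real pass would need its own forward and reverse logic, which is the per-view information the get_meta_{view} helpers are meant to capture.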

Followups

  • I haven't carefully tested a bunch of use cases with functorch yet (like nested calls to functionalize).
  • Add support for more view ops (probably not all of them... but the important ones. The handful that xla implements are probably a good starting point).
  • In particular, I have a feeling the codegen will change a little as more view ops are added. For example, torch.split() is a view op that returns multiple output tensors, which will all alias the same input tensor. Need to make sure that the codegen handles that gracefully. There are also a bunch of view ops that are both views and mutations, like transpose_ and as_strided_. That's probably gonna require extra codegen.
  • Audit the pass for perf (unnecessary tensor clones and refcount bumps).
  • Think about a version of the pass that's just mutation removal, or just alias removal? One option is to add separate keys for AliasRemovalOnly and MutationRemovalOnly, and factor the codegen well enough that it can be re-used. Another would be to split out the current alias and mutation removal bits into two passes, although that'll require an extra dispatcher trip everywhere.

facebook-github-bot commented Aug 10, 2021

💊 CI failures summary and remediations

As of commit 76f0961 (more details on the Dr. CI page):


  • 16/16 failures possibly* introduced in this PR
    • 1/16 non-scanned failure(s)

🕵️ 12 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See GitHub Actions build win-vs2019-cpu-py3 / test (default, 2, 2, windows.4xlarge) (1/12)

Step: "Run test scripts" (full log | diagnosis details | 🔁 rerun)

2021-08-14T02:10:50.2759777Z AssertionError: Th...eturned by torch._overrides.get_ignored_functions.
2021-08-14T02:10:48.4742843Z Generated XML report: test-reports\dist-gloo\test_vmap\TEST-TestVmapAPI-20210814021045.xml
2021-08-14T02:10:48.4744155Z Generated XML report: test-reports\dist-gloo\test_vmap\TEST-TestVmapBatchedGradientCPU-20210814021045.xml
2021-08-14T02:10:48.4745470Z Generated XML report: test-reports\dist-gloo\test_vmap\TEST-TestVmapOperators-20210814021045.xml
2021-08-14T02:10:48.6418814Z Running test_overrides ... [2021-08-14 02:10:48.634960]
2021-08-14T02:10:48.6419686Z Executing ['coverage', 'run', '--parallel-mode', '--source=torch', 'test_overrides.py', '-v', '--import-slow-tests', '--import-disabled-tests'] ... [2021-08-14 02:10:48.634960]
2021-08-14T02:10:50.2755205Z Traceback (most recent call last):
2021-08-14T02:10:50.2756462Z   File "test_overrides.py", line 348, in <module>
2021-08-14T02:10:50.2756977Z     generate_tensor_like_torch_implementations()
2021-08-14T02:10:50.2757768Z   File "test_overrides.py", line 336, in generate_tensor_like_torch_implementations
2021-08-14T02:10:50.2758458Z     assert len(untested_funcs) == 0, msg.format(pprint.pformat(untested_funcs))
2021-08-14T02:10:50.2759777Z AssertionError: The following functions are not tested for __torch_function__ support, please ensure there is an entry in the dict returned by torch._overrides.get_testing_overrides for this function or if a __torch_function__ override does not make sense, add an entry to the tuple returned by torch._overrides.get_ignored_functions.
2021-08-14T02:10:50.2760797Z 
2021-08-14T02:10:50.2761197Z ["<class 'torch.Tensor'>.replace_", "<class 'torch.Tensor'>.sync_"]
2021-08-14T02:10:51.1727661Z Traceback (most recent call last):
2021-08-14T02:10:51.1728466Z   File "run_test.py", line 1091, in <module>
2021-08-14T02:10:51.1728821Z     main()
2021-08-14T02:10:51.1729340Z   File "run_test.py", line 1070, in main
2021-08-14T02:10:51.1729776Z     raise RuntimeError(err_message)
2021-08-14T02:10:51.1730285Z RuntimeError: test_overrides failed!
2021-08-14T02:10:51.3937052Z 
2021-08-14T02:10:51.3937869Z (base) C:\actions-runner\_work\pytorch\pytorch\pytorch-1129406720\test>popd

See GitHub Actions build linux-bionic-py3.8-gcc9-coverage / test (default, 1, 2, linux.2xlarge) (2/12)

Step: "Test PyTorch" (full log | diagnosis details | 🔁 rerun)

2021-08-14T00:59:31.5573460Z AssertionError: Fa...ot true : Tensor.replace_ is missing documentation
2021-08-14T00:59:31.5567455Z   test_cuda_vitals_gpu_only_cpu (__main__.TestVitalSignsCudaCPU) ... skip (0.001s)
2021-08-14T00:59:31.5568213Z 
2021-08-14T00:59:31.5568648Z ======================================================================
2021-08-14T00:59:31.5569021Z FAIL [0.014s]: test_doc (__main__.TestTorch)
2021-08-14T00:59:31.5569994Z ----------------------------------------------------------------------
2021-08-14T00:59:31.5570479Z Traceback (most recent call last):
2021-08-14T00:59:31.5571115Z   File "test_torch.py", line 200, in test_doc
2021-08-14T00:59:31.5571554Z     test_namespace(torch.randn(1),
2021-08-14T00:59:31.5572007Z   File "test_torch.py", line 195, in test_namespace
2021-08-14T00:59:31.5572740Z     self.assertTrue(has_doc, '{} is missing documentation'.format(full_name))
2021-08-14T00:59:31.5573460Z AssertionError: False is not true : Tensor.replace_ is missing documentation
2021-08-14T00:59:31.5573892Z 
2021-08-14T00:59:31.5574395Z ----------------------------------------------------------------------
2021-08-14T00:59:31.5574806Z Ran 661 tests in 37.179s
2021-08-14T00:59:31.5575013Z 
2021-08-14T00:59:31.5575333Z FAILED (failures=1, skipped=44)
2021-08-14T00:59:31.5575572Z 
2021-08-14T00:59:31.5575905Z Generating XML reports...
2021-08-14T00:59:31.5588280Z Generated XML report: test-reports/python-unittest/test_torch/TEST-TestBasicVitalSigns-20210814005854.xml
2021-08-14T00:59:31.5979257Z Generated XML report: test-reports/python-unittest/test_torch/TEST-TestTorch-20210814005854.xml
2021-08-14T00:59:31.6950860Z Generated XML report: test-reports/python-unittest/test_torch/TEST-TestTorchDeviceTypeCPU-20210814005854.xml

See GitHub Actions build linux-xenial-py3.6-gcc7-bazel-test / build-and-test (3/12)

Step: "Build PyTorch" (full log | diagnosis details | 🔁 rerun)

2021-08-14T00:29:02.4429108Z ModuleNotFoundError: No module named 'tools.autograd'
2021-08-14T00:29:02.4335618Z ERROR: /var/lib/jenkins/workspace/BUILD.bazel:121:1: Executing genrule //:generated_cpp failed (Exit 1) bash failed: error executing command /bin/bash -c ... (remaining 1 argument(s) skipped)
2021-08-14T00:29:02.4336972Z 
2021-08-14T00:29:02.4338191Z Use --sandbox_debug to see verbose messages from the sandbox
2021-08-14T00:29:02.4407772Z Traceback (most recent call last):
2021-08-14T00:29:02.4410946Z   File "/var/lib/jenkins/.cache/bazel/_bazel_jenkins/fdf6d09bf4b4f04a71e2a7dfceb40620/sandbox/processwrapper-sandbox/33/execroot/pytorch/bazel-out/host/bin/gen.runfiles/pytorch/tools/setup_helpers/gen.py", line 9, in <module>
2021-08-14T00:29:02.4413124Z     import tools.codegen.gen
2021-08-14T00:29:02.4415700Z   File "/var/lib/jenkins/.cache/bazel/_bazel_jenkins/fdf6d09bf4b4f04a71e2a7dfceb40620/sandbox/processwrapper-sandbox/33/execroot/pytorch/bazel-out/host/bin/gen.runfiles/pytorch/tools/codegen/gen.py", line 36, in <module>
2021-08-14T00:29:02.4418120Z     from tools.codegen.gen_functionalization_type import Functionalize
2021-08-14T00:29:02.4425262Z   File "/var/lib/jenkins/.cache/bazel/_bazel_jenkins/fdf6d09bf4b4f04a71e2a7dfceb40620/sandbox/processwrapper-sandbox/33/execroot/pytorch/bazel-out/host/bin/gen.runfiles/pytorch/tools/codegen/gen_functionalization_type.py", line 14, in <module>
2021-08-14T00:29:02.4428208Z     from tools.autograd.gen_inplace_or_view_type import (
2021-08-14T00:29:02.4429108Z ModuleNotFoundError: No module named 'tools.autograd'
2021-08-14T00:29:02.4429694Z ----------------
2021-08-14T00:29:02.4431854Z Note: The failure of target //:gen (with exit code 1) may have been caused by the fact that it is running under Python 3 instead of Python 2. Examine the error to determine if that appears to be the problem. Since this target is built in the host configuration, the only way to change its version is to set --host_force_python=PY2, which affects the entire build.
2021-08-14T00:29:02.4434440Z 
2021-08-14T00:29:02.4436384Z If this error started occurring in Bazel 0.27 and later, it may be because the Python toolchain now enforces that targets analyzed as PY2 and PY3 run under a Python 2 and Python 3 interpreter, respectively. See https://github.com/bazelbuild/bazel/issues/7899 for more information.
2021-08-14T00:29:02.4438859Z ----------------
2021-08-14T00:29:02.5208207Z Target //:torch failed to build
2021-08-14T00:29:02.5213780Z Use --verbose_failures to see the command lines of failed build steps.
2021-08-14T00:29:02.5221265Z ERROR: /var/lib/jenkins/workspace/BUILD.bazel:564:1 Executing genrule //:generated_cpp failed (Exit 1) bash failed: error executing command /bin/bash -c ... (remaining 1 argument(s) skipped)
2021-08-14T00:29:02.5222593Z 
2021-08-14T00:29:02.5224300Z Use --sandbox_debug to see verbose messages from the sandbox

See GitHub Actions build win-vs2019-cpu-py3 / test (default, 1, 2, windows.4xlarge) (4/12)

Step: "Run test scripts" (full log | diagnosis details | 🔁 rerun)

2021-08-14T02:05:52.4312983Z AssertionError: Fa...ot true : Tensor.replace_ is missing documentation
2021-08-14T02:05:52.4306189Z   test_cuda_vitals_gpu_only_cpu (__main__.TestVitalSignsCudaCPU) ... skip (0.002s)
2021-08-14T02:05:52.4307003Z 
2021-08-14T02:05:52.4307276Z ======================================================================
2021-08-14T02:05:52.4307628Z FAIL [0.031s]: test_doc (__main__.TestTorch)
2021-08-14T02:05:52.4308048Z ----------------------------------------------------------------------
2021-08-14T02:05:52.4308472Z Traceback (most recent call last):
2021-08-14T02:05:52.4310878Z   File "test_torch.py", line 200, in test_doc
2021-08-14T02:05:52.4311300Z     test_namespace(torch.randn(1),
2021-08-14T02:05:52.4311810Z   File "test_torch.py", line 195, in test_namespace
2021-08-14T02:05:52.4312336Z     self.assertTrue(has_doc, '{} is missing documentation'.format(full_name))
2021-08-14T02:05:52.4312983Z AssertionError: False is not true : Tensor.replace_ is missing documentation
2021-08-14T02:05:52.4313359Z 
2021-08-14T02:05:52.9126141Z ----------------------------------------------------------------------
2021-08-14T02:05:52.9126622Z Ran 661 tests in 15.282s
2021-08-14T02:05:52.9126804Z 
2021-08-14T02:05:52.9127101Z FAILED (failures=1, skipped=44)
2021-08-14T02:05:52.9127326Z 
2021-08-14T02:05:52.9127606Z Generating XML reports...
2021-08-14T02:05:52.9128424Z Generated XML report: test-reports\python-unittest\test_torch\TEST-TestBasicVitalSigns-20210814020537.xml
2021-08-14T02:05:52.9129511Z Generated XML report: test-reports\python-unittest\test_torch\TEST-TestTorch-20210814020537.xml
2021-08-14T02:05:52.9130657Z Generated XML report: test-reports\python-unittest\test_torch\TEST-TestTorchDeviceTypeCPU-20210814020537.xml

See CircleCI build pytorch_macos_10_13_py3_test (5/12)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

Aug 14 00:54:29 AssertionError: False is not true : Tensor.replace_ is missing documentation
Aug 14 00:54:29   test_cuda_vitals_gpu_only_cpu (__main__.TestVitalSignsCudaCPU) ... skip (0.001s)
Aug 14 00:54:29 
Aug 14 00:54:29 ======================================================================
Aug 14 00:54:29 FAIL [0.010s]: test_doc (__main__.TestTorch)
Aug 14 00:54:29 ----------------------------------------------------------------------
Aug 14 00:54:29 Traceback (most recent call last):
Aug 14 00:54:29   File "test_torch.py", line 217, in test_doc
Aug 14 00:54:29     'unsafe_split_with_sizes',
Aug 14 00:54:29   File "test_torch.py", line 195, in test_namespace
Aug 14 00:54:29     self.assertTrue(has_doc, '{} is missing documentation'.format(full_name))
Aug 14 00:54:29 AssertionError: False is not true : Tensor.replace_ is missing documentation
Aug 14 00:54:29 
Aug 14 00:54:29 ----------------------------------------------------------------------
Aug 14 00:54:29 Ran 661 tests in 28.202s
Aug 14 00:54:29 
Aug 14 00:54:29 FAILED (failures=1, skipped=44)
Aug 14 00:54:29 
Aug 14 00:54:29 Generating XML reports...
Aug 14 00:54:29 Generated XML report: test-reports/python-unittest/test_torch/TEST-TestBasicVitalSigns-20210814005401.xml
Aug 14 00:54:29 Generated XML report: test-reports/python-unittest/test_torch/TEST-TestTorch-20210814005401.xml
Aug 14 00:54:29 Generated XML report: test-reports/python-unittest/test_torch/TEST-TestTorchDeviceTypeCPU-20210814005401.xml

See CircleCI build pytorch_linux_xenial_py3_clang7_asan_test1 (6/12)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Aug 14 01:21:05 AssertionError: False is not true : Tensor.replace_ is missing documentation
Aug 14 01:21:05   test_cuda_vitals_gpu_only_cpu (__main__.TestVitalSignsCudaCPU) ... skip (0.001s)
Aug 14 01:21:05 
Aug 14 01:21:05 ======================================================================
Aug 14 01:21:05 FAIL [0.020s]: test_doc (__main__.TestTorch)
Aug 14 01:21:05 ----------------------------------------------------------------------
Aug 14 01:21:05 Traceback (most recent call last):
Aug 14 01:21:05   File "test_torch.py", line 217, in test_doc
Aug 14 01:21:05     'unsafe_split_with_sizes',
Aug 14 01:21:05   File "test_torch.py", line 195, in test_namespace
Aug 14 01:21:05     self.assertTrue(has_doc, '{} is missing documentation'.format(full_name))
Aug 14 01:21:05 AssertionError: False is not true : Tensor.replace_ is missing documentation
Aug 14 01:21:05 
Aug 14 01:21:05 ----------------------------------------------------------------------
Aug 14 01:21:05 Ran 661 tests in 63.272s
Aug 14 01:21:05 
Aug 14 01:21:05 FAILED (failures=1, skipped=44)
Aug 14 01:21:05 
Aug 14 01:21:05 Generating XML reports...
Aug 14 01:21:05 Generated XML report: test-reports/python-unittest/test_torch/TEST-TestBasicVitalSigns-20210814012001.xml
Aug 14 01:21:05 Generated XML report: test-reports/python-unittest/test_torch/TEST-TestTorch-20210814012001.xml
Aug 14 01:21:05 Generated XML report: test-reports/python-unittest/test_torch/TEST-TestTorchDeviceTypeCPU-20210814012001.xml

See CircleCI build pytorch_linux_xenial_py3_6_gcc5_4_test (7/12)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Aug 14 02:13:19 AssertionError: False is not true : Tensor.replace_ is missing documentation
Aug 14 02:13:19   test_cuda_vitals_gpu_only_cpu (__main__.TestVitalSignsCudaCPU) ... skip (0.001s)
Aug 14 02:13:19 
Aug 14 02:13:19 ======================================================================
Aug 14 02:13:19 FAIL [0.009s]: test_doc (__main__.TestTorch)
Aug 14 02:13:19 ----------------------------------------------------------------------
Aug 14 02:13:19 Traceback (most recent call last):
Aug 14 02:13:19   File "test_torch.py", line 217, in test_doc
Aug 14 02:13:19     'unsafe_split_with_sizes',
Aug 14 02:13:19   File "test_torch.py", line 195, in test_namespace
Aug 14 02:13:19     self.assertTrue(has_doc, '{} is missing documentation'.format(full_name))
Aug 14 02:13:19 AssertionError: False is not true : Tensor.replace_ is missing documentation
Aug 14 02:13:19 
Aug 14 02:13:19 ----------------------------------------------------------------------
Aug 14 02:13:19 Ran 661 tests in 20.404s
Aug 14 02:13:19 
Aug 14 02:13:19 FAILED (failures=1, skipped=44)
Aug 14 02:13:19 
Aug 14 02:13:19 Generating XML reports...
Aug 14 02:13:19 Generated XML report: test-reports/dist-gloo/test_torch/TEST-TestBasicVitalSigns-20210814021258.xml
Aug 14 02:13:19 Generated XML report: test-reports/dist-gloo/test_torch/TEST-TestTorch-20210814021258.xml
Aug 14 02:13:19 Generated XML report: test-reports/dist-gloo/test_torch/TEST-TestTorchDeviceTypeCPU-20210814021258.xml

See CircleCI build pytorch_linux_bionic_cuda10_2_cudnn7_py3_9_gcc7_test2 (8/12)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Aug 14 02:46:20 AssertionError: The following f...eturned by torch._overrides.get_ignored_functions.
Aug 14 02:46:19 
Aug 14 02:46:19 Generating XML reports...
Aug 14 02:46:19 Generated XML report: test-reports/dist-gloo/test_complex/TEST-TestComplexTensorCUDA-20210814024617.xml
Aug 14 02:46:19 Running test_overrides ... [2021-08-14 02:46:19.584564]
Aug 14 02:46:19 Executing ['/opt/conda/bin/python', 'test_overrides.py', '-v', '--import-slow-tests', '--import-disabled-tests'] ... [2021-08-14 02:46:19.584679]
Aug 14 02:46:20 Traceback (most recent call last):
Aug 14 02:46:20   File "/var/lib/jenkins/workspace/test/test_overrides.py", line 348, in <module>
Aug 14 02:46:20     generate_tensor_like_torch_implementations()
Aug 14 02:46:20   File "/var/lib/jenkins/workspace/test/test_overrides.py", line 336, in generate_tensor_like_torch_implementations
Aug 14 02:46:20     assert len(untested_funcs) == 0, msg.format(pprint.pformat(untested_funcs))
Aug 14 02:46:20 AssertionError: The following functions are not tested for __torch_function__ support, please ensure there is an entry in the dict returned by torch._overrides.get_testing_overrides for this function or if a __torch_function__ override does not make sense, add an entry to the tuple returned by torch._overrides.get_ignored_functions.
Aug 14 02:46:20 
Aug 14 02:46:20 ["<class 'torch.Tensor'>.replace_", "<class 'torch.Tensor'>.sync_"]
Aug 14 02:46:20 Traceback (most recent call last):
Aug 14 02:46:20   File "/var/lib/jenkins/workspace/test/run_test.py", line 1091, in <module>
Aug 14 02:46:20     main()
Aug 14 02:46:20   File "/var/lib/jenkins/workspace/test/run_test.py", line 1070, in main
Aug 14 02:46:20     raise RuntimeError(err_message)
Aug 14 02:46:20 RuntimeError: test_overrides failed!
Aug 14 02:46:21 
Aug 14 02:46:21 real	86m59.235s

See CircleCI build pytorch_linux_bionic_cuda10_2_cudnn7_py3_9_gcc7_test1 (9/12)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Aug 14 01:22:45 AssertionError: False is not true : Tensor.replace_ is missing documentation
Aug 14 01:22:45 [TORCH_VITAL] CUDA.used		 true
Aug 14 01:22:45 
Aug 14 01:22:45 ======================================================================
Aug 14 01:22:45 FAIL [0.006s]: test_doc (__main__.TestTorch)
Aug 14 01:22:45 ----------------------------------------------------------------------
Aug 14 01:22:45 Traceback (most recent call last):
Aug 14 01:22:45   File "/var/lib/jenkins/workspace/test/test_torch.py", line 200, in test_doc
Aug 14 01:22:45     test_namespace(torch.randn(1),
Aug 14 01:22:45   File "/var/lib/jenkins/workspace/test/test_torch.py", line 195, in test_namespace
Aug 14 01:22:45     self.assertTrue(has_doc, '{} is missing documentation'.format(full_name))
Aug 14 01:22:45 AssertionError: False is not true : Tensor.replace_ is missing documentation
Aug 14 01:22:45 
Aug 14 01:22:45 ----------------------------------------------------------------------
Aug 14 01:22:45 Ran 693 tests in 89.210s
Aug 14 01:22:45 
Aug 14 01:22:45 FAILED (failures=1, skipped=31)
Aug 14 01:22:45 
Aug 14 01:22:45 Generating XML reports...
Aug 14 01:22:45 Generated XML report: test-reports/python-unittest/test_torch/TEST-TestBasicVitalSigns-20210814012116.xml
Aug 14 01:22:45 Generated XML report: test-reports/python-unittest/test_torch/TEST-TestDevicePrecisionCUDA-20210814012116.xml
Aug 14 01:22:45 Generated XML report: test-reports/python-unittest/test_torch/TEST-TestTorch-20210814012116.xml

See CircleCI build pytorch_linux_bionic_py3_6_clang9_noarch_test (10/12)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Aug 14 01:16:59 AssertionError: False is not true : Tensor.replace_ is missing documentation
Aug 14 01:16:59   test_cuda_vitals_gpu_only_meta (__main__.TestVitalSignsCudaMETA) ... skip (0.001s)
Aug 14 01:16:59 
Aug 14 01:16:59 ======================================================================
Aug 14 01:16:59 FAIL [0.007s]: test_doc (__main__.TestTorch)
Aug 14 01:16:59 ----------------------------------------------------------------------
Aug 14 01:16:59 Traceback (most recent call last):
Aug 14 01:16:59   File "test_torch.py", line 217, in test_doc
Aug 14 01:16:59     'unsafe_split_with_sizes',
Aug 14 01:16:59   File "test_torch.py", line 195, in test_namespace
Aug 14 01:16:59     self.assertTrue(has_doc, '{} is missing documentation'.format(full_name))
Aug 14 01:16:59 AssertionError: False is not true : Tensor.replace_ is missing documentation
Aug 14 01:16:59 
Aug 14 01:16:59 ----------------------------------------------------------------------
Aug 14 01:16:59 Ran 1121 tests in 21.889s
Aug 14 01:16:59 
Aug 14 01:16:59 FAILED (failures=1, skipped=429)
Aug 14 01:16:59 
Aug 14 01:16:59 Generating XML reports...
Aug 14 01:16:59 Generated XML report: test-reports/python-unittest/test_torch/TEST-TestBasicVitalSigns-20210814011637.xml
Aug 14 01:16:59 Generated XML report: test-reports/python-unittest/test_torch/TEST-TestTorch-20210814011637.xml
Aug 14 01:16:59 Generated XML report: test-reports/python-unittest/test_torch/TEST-TestTorchDeviceTypeCPU-20210814011637.xml

See CircleCI build pytorch_linux_xenial_cuda11_1_cudnn8_py3_gcc7_test1 (11/12)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Aug 14 02:05:22 AssertionError: False is not true : Tensor.replace_ is missing documentation
Aug 14 02:05:22 [TORCH_VITAL] CUDA.used		 true
Aug 14 02:05:22 
Aug 14 02:05:22 ======================================================================
Aug 14 02:05:22 FAIL [0.007s]: test_doc (__main__.TestTorch)
Aug 14 02:05:22 ----------------------------------------------------------------------
Aug 14 02:05:22 Traceback (most recent call last):
Aug 14 02:05:22   File "test_torch.py", line 217, in test_doc
Aug 14 02:05:22     'unsafe_split_with_sizes',
Aug 14 02:05:22   File "test_torch.py", line 195, in test_namespace
Aug 14 02:05:22     self.assertTrue(has_doc, '{} is missing documentation'.format(full_name))
Aug 14 02:05:22 AssertionError: False is not true : Tensor.replace_ is missing documentation
Aug 14 02:05:22 
Aug 14 02:05:22 ----------------------------------------------------------------------
Aug 14 02:05:22 Ran 693 tests in 95.651s
Aug 14 02:05:22 
Aug 14 02:05:22 FAILED (failures=1, skipped=31)
Aug 14 02:05:22 
Aug 14 02:05:22 Generating XML reports...
Aug 14 02:05:22 Generated XML report: test-reports/python-unittest/test_torch/TEST-TestBasicVitalSigns-20210814020346.xml
Aug 14 02:05:22 Generated XML report: test-reports/python-unittest/test_torch/TEST-TestDevicePrecisionCUDA-20210814020346.xml
Aug 14 02:05:22 Generated XML report: test-reports/python-unittest/test_torch/TEST-TestTorch-20210814020346.xml

See CircleCI build pytorch_linux_xenial_cuda11_1_cudnn8_py3_gcc7_test2 (12/12)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Aug 14 03:34:25 AssertionError: The following f...eturned by torch._overrides.get_ignored_functions.
Aug 14 03:34:24 
Aug 14 03:34:24 Generating XML reports...
Aug 14 03:34:24 Generated XML report: test-reports/dist-gloo/test_namedtuple_return_api/TEST-TestNamedTupleAPI-20210814033422.xml
Aug 14 03:34:24 Running test_overrides ... [2021-08-14 03:34:24.292683]
Aug 14 03:34:24 Executing ['/opt/conda/bin/python', 'test_overrides.py', '-v', '--import-slow-tests', '--import-disabled-tests'] ... [2021-08-14 03:34:24.292765]
Aug 14 03:34:25 Traceback (most recent call last):
Aug 14 03:34:25   File "test_overrides.py", line 348, in <module>
Aug 14 03:34:25     generate_tensor_like_torch_implementations()
Aug 14 03:34:25   File "test_overrides.py", line 336, in generate_tensor_like_torch_implementations
Aug 14 03:34:25     assert len(untested_funcs) == 0, msg.format(pprint.pformat(untested_funcs))
Aug 14 03:34:25 AssertionError: The following functions are not tested for __torch_function__ support, please ensure there is an entry in the dict returned by torch._overrides.get_testing_overrides for this function or if a __torch_function__ override does not make sense, add an entry to the tuple returned by torch._overrides.get_ignored_functions.
Aug 14 03:34:25 
Aug 14 03:34:25 ["<class 'torch.Tensor'>.replace_", "<class 'torch.Tensor'>.sync_"]
Aug 14 03:34:25 Traceback (most recent call last):
Aug 14 03:34:25   File "test/run_test.py", line 1091, in <module>
Aug 14 03:34:25     main()
Aug 14 03:34:25   File "test/run_test.py", line 1070, in main
Aug 14 03:34:25     raise RuntimeError(err_message)
Aug 14 03:34:25 RuntimeError: test_overrides failed!
Aug 14 03:34:26 + cleanup
Aug 14 03:34:26 + retcode=1

3 failures not recognized by patterns:

| Job | Step |
| --- | --- |
| GitHub Actions Lint / flake8-py3 | Fail if there were any warnings |
| GitHub Actions Lint / quick-checks | Ensure correct trailing newlines |
| GitHub Actions Lint / clang-format | Run clang-format |

ci.pytorch.org: 1 failed


This comment was automatically generated by Dr. CI.

bdhirsh added a commit that referenced this pull request Aug 10, 2021
ghstack-source-id: e2a2349
Pull Request resolved: #63048
First cut. Description coming soon.




[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Aug 11, 2021
ghstack-source-id: 51c37b6
Pull Request resolved: #63048
@ezyang
Contributor

ezyang commented Aug 11, 2021

waiting on desc before review


- def get_view_info(fn: NativeFunctionWithDifferentiabilityInfo) -> Optional[str]:
-     f = fn.func
+ def get_view_info(f: NativeFunction) -> Optional[str]:
Contributor Author

I probably should have factored this into a separate PR, but this was just a small change to allow me to re-use get_view_info in the functionalization codegen. The extra DifferentiabilityInfo isn't used anywhere.

Contributor

never too late to stack

First cut. Description coming soon.




[ghstack-poisoned]
This PR adds the barebones of a generic functionalization pass through the dispatcher that works for both functorch and XLA. It's based on @ailzhang's prototype, which lives [here](ailzhang@83f647e) and is nicely documented [here](https://gist.github.com/ailzhang/75af24db042ec5e101a6fa4fef1122c3#implementation-details).

I won't go into the details of the aliasing removal, which you can see in the original prototype. Instead I'll focus more on the API to integrate with it (as a composable transform, and as a backend like XLA), the mutation removal logic, and the followups.

I have corresponding test branches in [functorch](pytorch/functorch@c6c5e49#diff-da7b94c8791e30d89b2b2a6641bcbc8ad9282554cf4fac3784666f3aab9ea09bR300) and [pt/xla](pytorch/xla@db13eae#diff-1edb9f3bb5aa676852199f34cbb2fb731175df85e4bf59c0976df389d01386e3R14).

**Working examples**
I used two basic examples to help figure out how to make changes, using the corresponding changes in functorch/xla that I linked above. I'm planning on testing some more complicated functorch examples soon:

```
import torch
from functorch import make_fx, grad, vmap, functionalize

def f(x):  # tests that inputs are still successfully mutated
    tmp = torch.ones(4)
    y = x.view(4)
    y.add_(tmp)
    return x

def f2(x):  # tests the free variable mutation case, which currently breaks in functorch
    tmp = torch.ones(4)
    tmp.add_(x)
    return tmp

batched_input = torch.ones(2, 4)
vmap(functionalize(f))(batched_input)
vmap(functionalize(f2))(batched_input)
```

```
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.ones(2, 2, device=device)
b = a.view(4)
a.add_(1) # successfully mutates b
```

**API for functorch / backends**

The main difference between the two is that functorch wants a composable transform (calling it `functionalize()` for now), which involves some notion of levels and wrapping. Backends like XLA only want the direct mutation / view removal.

I've factored those pieces into a `FunctionalTensorImplBase` (inherits from `TensorImpl`), and `FunctionalTensorImpl` (inherits from the base). The base class deals with mutation and view removal, while the subclass knows about levels and wrapping tensors. Backend tensors like XLA will also inherit from `FunctionalTensorImplBase` to get opted into the functionalization pass.

`functionalize()` is implemented similarly to `vmap`/`grad` in the functorch repo, and just involves wrapping inputs/outputs in a `FunctionalTensorImpl` of the appropriate level (and making `DynamicLayerFrontMode` aware of the `Functionalization` key). The (unpolished) functorch-side changes currently live in a branch [here](pytorch/functorch@c6c5e49#diff-da7b94c8791e30d89b2b2a6641bcbc8ad9282554cf4fac3784666f3aab9ea09bR300). `FunctionalTensorImpl` also overrides all virtual methods to call into its wrapped value - I'm not sure if this is overkill, but the goal is that aside from the functionalization pass, everything else in pytorch shouldn't see the wrapper. We can do that for dispatcher-aware operators by explicitly unwrapping, but calling e.g. `t.sizes()` or `t.numel()` should also unwrap the tensor first.

I moved `FunctionalTensorImpl` into the functorch branch, so core is unaware of it. The invariant is that when the functionalization kernels redispatch on an operator, the returned tensor is always a `FunctionalTensorImplBase` subclass. For backends like XLA, this happens automatically (since we get back an XLATensor, which inherits from `FunctionalTensorImplBase`). For functorch, the `DynamicLayerBackFallback` kernel is responsible for unwrapping inputs and wrapping up the outputs in a `FunctionalTensorImpl`. This is necessary in order to properly record the current composable transform level in the wrapper (although if the dynamic leveling dispatcher logic eventually moves in core, then this can move in core too).

Integration with backends is currently implemented by updating your TensorImpl subclass to inherit from `FunctionalTensorImplBase` instead of `TensorImpl`, and adding an overridden `replace_()` method, which tells us how to "re-use" your tensor. I have a branch with XLA-side changes [here](pytorch/xla@db13eae#diff-1edb9f3bb5aa676852199f34cbb2fb731175df85e4bf59c0976df389d01386e3R14) (it's very small, and still a POC - it also doesn't remove any xla-side alias handling yet, since we need to add support for more views in core first).

**Mutation removal logic**

There's a new `at::replace_` op in native_functions.yaml. The only reason it needs to be in the yaml is because I think it needs to show up in traces from functorch (functorch would like "functionalized" traces. We can't do that directly, but it should be pretty easy to convert a graph with a bunch of `replace_` calls in it to its functionalized version).

The idea behind `replace_` is that it knows how to swap out the data backing a tensor with fresh data. For cpu/cuda (`TensorImpl`), that corresponds to replacing a storage pointer (and associated metadata). For xla (`XLATensorImpl`), that corresponds to swapping out IR.

However, the idea behind the `FunctionalTensorImpl` wrapper is that we can't just swap out metadata on a TensorImpl for functorch. Functorch has to deal with mixed inputs: e.g. `cpu_tensor.add_(batched_tensor)` should "promote" the input to be a batched tensor. That corresponds to swapping out `cpu_tensor`'s `TensorImpl` object with the `BatchedTensorImpl` result of summing the two together.

`FunctionalTensorImpl` does this by storing a full `Tensor`, and fully swapping it out for a new tensor when the corresponding new and old `TensorImpl*` types don't match up.

Right now this is implemented with two different methods:
- `replace_(TensorImpl* other_impl)`. This is a virtual method on `TensorImpl` that knows how to do the metadata swap as long as the TensorImpl argument is of the same type (also implemented for `BatchedTensorImpl` and `TensorWrapper` in functorch).
- `replace_(Tensor& other)`. This is a virtual method on `FunctionalTensorImplBase` that knows how to swap arbitrary tensors. The `FunctionalTensorImpl` implementation swaps the tensors completely if their types are different, and otherwise calls `replace_(other_impl)` to do an inplace swap if the types are the same.

I made the dispatcher-aware `replace_(Tensor)` op a composite kernel, which required registering fallthroughs to each functorch pass in order to avoid going to their boxed kernels. Another option would have been to not do that and instead have a single `replace_()` operator, with different kernels registered for functionalize/vmap/grad/cpu/xla. I mostly opted for the current approach because it was easier for me to picture what replacing a `FunctionalTensorImpl` means as repeatedly unwrapping and replacing the internals, vs. doing a bunch of redispatch trips through the dispatcher (which would also be slower).

**[functorch] should the Alias object hold the wrapper or the unwrapped tensor?**
When you create a view, we fork off the original tensor into an `Alias` object, that both the original tensor and the view tensor are now aware of. Right now, I have that alias object hold a clone of the original tensor (which in the functorch case is a `FunctionalTensorImpl` wrapper). I did that because it made the view-handling logic more consistent (we don't need to special-case functorch), but there's a problem where we need to perform a bunch of view operations on the alias tensor when we sync. If the alias is a `FunctionalTensorImpl` wrapper, we'll call back into the functionalization pass machinery and infinite loop (creating an alias for the alias). One way to get around this would have been to create separate `{view}_copy` operators for every view op, and register fallthroughs for them in the functionalization pass. Instead, I added a bit on the `FunctionalTensorImplBase` class to tell us if it's actually stored inside of an alias, so the functionalization pass will know to skip the view machinery if the bit is set.

I tried to call out other major details in specific comments in the code.

**Codegen output**

Inplace ops are all codegen'd (at least, the ones that have a functional version). View ops need to be added one-by-one, so I've only added `at::view` for now. The codegen output looks like this for `view`:
```
at::Tensor view(c10::DispatchKeySet ks, const at::Tensor & self, at::IntArrayRef size) {
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.key_set().has(c10::DispatchKey::Functionalize));
        at::Tensor out;
        {
          at::AutoDispatchBelowFunctionalize guard;
          auto tmp_output = at::redispatch::view(ks & c10::after_func_keyset, self, size);
          out = tmp_output.clone();
        }
        // See Note [Marking Alias Tensors]
        if (!at::functionalization::impl::is_alias_tensor(out)) {
          // TODO we'll probably want a separate function for each view op that creates the corresponding ViewMeta.
          at::ViewMeta view_meta = at::functionalization::impl::get_meta_view(self, size);
          at::functionalization::impl::set_view_meta(out, self, view_meta);
        }
        return out;
}
```

And for an example inplace op, the codegen looks like this:
```
at::Tensor & add__Tensor(c10::DispatchKeySet ks, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
        {
            at::AutoDispatchBelowFunctionalize guard;
            auto tmp_output = at::redispatch::add(ks & c10::after_func_keyset, self, other, alpha);
            self.replace_(tmp_output);
            TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.key_set().has(c10::DispatchKey::Functionalize));
            at::functionalization::impl::maybe_add_update(self);
        }
        return self;
}
```

**Printing tensors**
In the XLA case, printing tensors should just work - the call to `_tensor_str()` eventually just calls `to('cpu')`, which hits the functionalization fallback and syncs everything.

In the functorch case, functorch overrides the `_tensor_str` function to first recursively unwrap the tensor wrappers before calling the original version of `_tensor_str()`, so I added logic in functorch to be aware of unwrapping FunctionalTensorWrapper objects. I also have the `functionalize()` pass call `sync_()` on every tensor input after the pass completes, to ensure that inputs get mutated correctly.

The handling for when to wrap up outputs is a little fragile though, because of factory functions + printing. For factory functions, we need to make sure that we still wrap output tensors even if there are no input tensor arguments. For printing, we need to make sure not to wrap arguments. This case is distinguished by the fact that none of the tensor input arguments are wrapped. That means that stuff can break if any factory functions are called inside of `_tensor_str()`, but it looks like that currently isn't the case.

**Other stuff**
- The functionalization pass has a boxed fallback (in `VariableFallback.cpp`), but it shouldn’t be too hard to move that to codegen, since all it does is sync all of the input tensors.
- For any unsupported view/mutation ops (either view ops that aren’t implemented yet, or mutation ops that don’t currently have an out-of-place equivalent), I codegen a kernel that’s pretty similar to the boxed fallback - it just syncs the inputs and redispatches. It also prints a warning, to help figure out which mutation ops we need to add out-of-place versions for.
- I defined a bunch of helper functions in the `at::functionalization::impl` namespace, mostly as utility functions to make the codegen easier.
- Right now I have an enum for each view op. The Alias contains a stack of updates, and each update contains a stack of `ViewMetas`, explaining what view ops were run on the base tensor to get to the view before the mutation occurred. The logic to sync mutations across aliases involves replaying the views in reverse, to figure out what the base tensor looks like after every mutation. The enums are a little bit ugly, but I’m not sure of a significantly more elegant way to represent them. I also listed out the full set of enum values, but that might be too presumptive.
- Each view op needs to store some extra info in order to replay it in reverse - this will probably need to be implemented separately per view - right now I have an `at::functionalization::impl::get_meta_{view}(…)` function that knows which information to store (called by the codegen), and I’m planning on trying to implement similar functions for the other view ops.


**Followups**

- I haven't carefully tested a bunch of use cases with functorch yet (like nested calls to `functionalize()`).
- Add support for more view ops (probably not all of them... but the important ones. The handful that xla implements are probably a good starting point).
- In particular, I have a feeling the codegen will change a little as more view ops are added. For example, `torch.split()` is a view op that returns multiple output tensors, which will all alias the same input tensor. Need to make sure that the codegen handles that gracefully. There are also a bunch of view ops that are both views and mutations, like `transpose_` and `as_strided_`. That's probably gonna require extra codegen.
- Audit the pass for perf (unnecessary tensor clones and refcount bumps).





[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Aug 13, 2021
ghstack-source-id: 8733c62
Pull Request resolved: #63048
@bdhirsh
Contributor Author

bdhirsh commented Aug 13, 2021

I'm starting to test some more cases and clean up CI failures, but I think the PR is in a state where it's ready for review. cc @ezyang

This PR adds the barebones of a generic functionalization pass through the dispatcher that works for both functorch and XLA. It's based on @ailzhang's prototype, which lives [here](ailzhang@83f647e) and is nicely documented [here](https://gist.github.com/ailzhang/75af24db042ec5e101a6fa4fef1122c3#implementation-details).

I won't go into the details of the aliasing removal, which you can see in the original prototype. Instead I'll focus more on the API to integrate with it (as a composable transform, and as a backend like XLA), the mutation removal logic, and the followups.

I have corresponding test branches in [functorch](pytorch/functorch@c6c5e49#diff-da7b94c8791e30d89b2b2a6641bcbc8ad9282554cf4fac3784666f3aab9ea09bR300) and [pt/xla](pytorch/xla@db13eae#diff-1edb9f3bb5aa676852199f34cbb2fb731175df85e4bf59c0976df389d01386e3R14).

**Working examples**
I used two basic examples to help figure out how to make changes, using the corresponding changes in functorch/xla that I linked above. I'm planning on testing some more complicated functorch examples soon:

```
import torch
from functorch import make_fx, grad, vmap, functionalize

def f(x):  # tests that inputs are still successfully mutated
    tmp = torch.ones(4)
    y = x.view(4)
    y.add_(tmp)
    return x

def f2(x):  # tests the free variable mutation case, which currently breaks in functorch
    tmp = torch.ones(4)
    tmp.add_(x)
    return tmp

batched_input = torch.ones(2, 4)
vmap(functionalize(f))(batched_input)
vmap(functionalize(f2))(batched_input)
```

```
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.ones(2, 2, device=device)
b = a.view(4)
a.add_(1) # successfully mutates b
```

**API for functorch / backends**

The main difference between the two is that functorch wants a composable transform (calling it `functionalize()` for now), which involves some notion of levels and wrapping. Backends like XLA only want the direct mutation / view removal.

I've factored those pieces into a `FunctionalTensorImplBase` (inherits from `TensorImpl`), and `FunctionalTensorWrapper` (inherits from the base). The base class deals with mutation and view removal, while the subclass knows about levels and wrapping tensors. Backend tensors like XLA will also inherit from `FunctionalTensorImplBase` to get opted into the functionalization pass.

`functionalize()` is implemented similarly to `vmap`/`grad` in the functorch repo, and just involves wrapping inputs/outputs in a `FunctionalTensorWrapper` of the appropriate level (and making `DynamicLayerFrontMode` aware of the `Functionalization` key). The (unpolished) functorch-side changes currently live in a branch [here](pytorch/functorch@c6c5e49#diff-da7b94c8791e30d89b2b2a6641bcbc8ad9282554cf4fac3784666f3aab9ea09bR300). `FunctionalTensorWrapper` also overrides all virtual methods to call into its wrapped value - I'm not sure if this is overkill, but the goal is that aside from the functionalization pass, everything else in pytorch shouldn't see the wrapper. We can do that for dispatcher-aware operators by explicitly unwrapping, but calling e.g. `t.sizes()` or `t.numel()` should also unwrap the tensor first.

I moved `FunctionalTensorWrapper` into the functorch branch, so core is unaware of it. The invariant is that when the functionalization kernels redispatch on an operator, the returned tensor is always a `FunctionalTensorImplBase` subclass. For backends like XLA, this happens automatically (since we get back an XLATensor, which inherits from `FunctionalTensorImplBase`). For functorch, the `DynamicLayerBackFallback` kernel is responsible for unwrapping inputs and wrapping up the outputs in a `FunctionalTensorWrapper`. This is necessary in order to properly record the current composable transform level in the wrapper (although if the dynamic leveling dispatcher logic eventually moves into core, then this can move into core too).

Integration with backends is currently implemented by updating your TensorImpl subclass to inherit from `FunctionalTensorImplBase` instead of `TensorImpl`, and adding an overridden `replace_()` method, which tells us how to "re-use" your tensor. I have a branch with XLA-side changes [here](pytorch/xla@db13eae#diff-1edb9f3bb5aa676852199f34cbb2fb731175df85e4bf59c0976df389d01386e3R14) (it's very small, and still a POC - it also doesn't remove any xla-side alias handling yet, since we need to add support for more views in core first).

**Mutation removal logic**

There's a new `at::replace_` op in native_functions.yaml. The only reason it needs to be in the yaml is because I think it needs to show up in traces from functorch (functorch would like "functionalized" traces. We can't do that directly, but it should be pretty easy to convert a graph with a bunch of `replace_` calls in it to its functionalized version).
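To make the "convert a graph with `replace_` calls to its functionalized version" idea concrete, here is a minimal sketch in plain Python. The trace format (tuples of output name, op name, input names) and the helper `remove_replace_calls` are hypothetical and exist only for illustration; they are not functorch's real IR, and the sketch ignores views entirely.

```python
# Toy trace format, assumed for illustration only (not functorch's real IR):
# each entry is (output_name, op_name, input_names).

def remove_replace_calls(trace):
    """Return an equivalent trace with every replace_ call removed."""
    latest = {}      # original tensor name -> name currently holding its value
    functional = []
    for out, op, inputs in trace:
        if op == "replace_":
            # replace_(target, new_value): later reads of target see new_value.
            target, new_value = inputs
            latest[target] = latest.get(new_value, new_value)
            continue
        # Read each input through the rename map so it sees the latest value.
        renamed = tuple(latest.get(name, name) for name in inputs)
        functional.append((out, op, renamed))
    return functional


trace = [
    ("t0", "add", ("a", "b")),
    (None, "replace_", ("a", "t0")),   # stands for a.add_(b)
    ("t1", "mul", ("a", "a")),
    (None, "replace_", ("a", "t1")),   # stands for a.mul_(a)
    ("out", "sub", ("a", "b")),
]
functional_trace = remove_replace_calls(trace)
```

Each `replace_` entry disappears and every later read of `a` is redirected to the tensor that replaced it, so the remaining ops are purely functional.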

The idea behind `replace_` is that it knows how to swap out the data backing a tensor with fresh data. For cpu/cuda (`TensorImpl`), that corresponds to replacing a storage pointer (and associated metadata). For xla (`XLATensorImpl`), that corresponds to swapping out IR.

However, the idea behind the `FunctionalTensorWrapper` is that we can't just swap out metadata on a TensorImpl for functorch. Functorch has to deal with mixed inputs: e.g. `cpu_tensor.add_(batched_tensor)` should "promote" the input to be a batched tensor. That corresponds to swapping out `cpu_tensor`'s `TensorImpl` object with the `BatchedTensorImpl` result of summing the two together.

`FunctionalTensorWrapper` does this by storing a full `Tensor`, and fully swapping it out for a new tensor when the corresponding new and old `TensorImpl*` types don't match up.

Right now this is implemented with two different methods:
- `replace_(TensorImpl* other_impl)`. This is a virtual method on `TensorImpl` that knows how to do the metadata swap as long as the TensorImpl argument is of the same type (also implemented for `BatchedTensorImpl` and `TensorWrapper` in functorch).
- `replace_(Tensor& other)`. This is a virtual method on `FunctionalTensorImplBase` that knows how to swap arbitrary tensors. The `FunctionalTensorWrapper` implementation swaps the tensors completely if their types are different, and otherwise calls `replace_(other_impl)` to do an inplace swap if the types are the same.
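The two-level swap described in these bullets can be sketched with a toy Python model. The classes `PlainImpl`, `BatchedImpl`, and `FunctionalWrapper` below are hypothetical stand-ins for `TensorImpl`, `BatchedTensorImpl`, and the functorch wrapper; the real implementations are C++ and carry far more state.

```python
class PlainImpl:
    """Stand-in for a regular TensorImpl: just a storage."""

    def __init__(self, storage):
        self.storage = storage

    def replace_(self, other):
        # Same-type swap: take over the other impl's storage in place.
        self.storage = other.storage


class BatchedImpl(PlainImpl):
    """Stand-in for BatchedTensorImpl: storage plus a batch dimension."""

    def __init__(self, storage, bdim):
        super().__init__(storage)
        self.bdim = bdim

    def replace_(self, other):
        self.storage = other.storage
        self.bdim = other.bdim


class FunctionalWrapper:
    """Stand-in for the wrapper: holds a full value and can swap it out."""

    def __init__(self, value):
        self.value = value

    def replace_(self, other):
        if type(self.value) is type(other):
            self.value.replace_(other)   # types match: in-place metadata swap
        else:
            self.value = other           # types differ: swap the whole object


wrapper = FunctionalWrapper(PlainImpl([1.0, 1.0]))
original_inner = wrapper.value

wrapper.replace_(PlainImpl([2.0, 2.0]))
same_object_after_plain_swap = wrapper.value is original_inner

wrapper.replace_(BatchedImpl([[3.0, 3.0]], bdim=0))
promoted = isinstance(wrapper.value, BatchedImpl)
```

The second call models the `cpu_tensor.add_(batched_tensor)` case: the wrapper keeps its identity while the object it holds is "promoted" to the batched type.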

I made the dispatcher-aware `replace_(Tensor)` op a composite kernel, which required registering fallthroughs to each functorch pass in order to avoid going to their boxed kernels. Another option would have been to not do that and instead have a single `replace_()` operator, with different kernels registered for functionalize/vmap/grad/cpu/xla. I mostly opted for the current approach because it was easier for me to picture what replacing a `FunctionalTensorWrapper` means as repeatedly unwrapping and replacing the internals, vs. doing a bunch of redispatch trips through the dispatcher (which would also be slower).

**[functorch] should the Alias object hold the wrapper or the unwrapped tensor?**
When you create a view, we fork off the original tensor into an `Alias` object, that both the original tensor and the view tensor are now aware of. Right now, I have that alias object hold a clone of the original tensor (which in the functorch case is a `FunctionalTensorWrapper`). I did that because it made the view-handling logic more consistent (we don't need to special-case functorch), but there's a problem where we need to perform a bunch of view operations on the alias tensor when we sync. If the alias is a `FunctionalTensorWrapper`, we'll call back into the functionalization pass machinery and loop infinitely (creating an alias for the alias). That's true for xla too, but we can just add an exclude guard before calling `sync_()` in the xla case. We can't do that for functorch, since the DynamicLayer logic overwrites TLS at each layer.

One way to get around this would have been to create separate `{view}_copy` operators for every view op, and register fallthroughs for them in the functionalization pass. Instead, I added a bit on the `FunctionalTensorImplBase` class to tell us if it's actually stored inside of an alias, so the functionalization pass will know to skip the view machinery if the bit is set.
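A toy Python sketch of that guard bit follows. Everything here (`ToyTensor`, `is_alias_storage`, `functional_narrow`, the `aliases_created` counter) is a hypothetical name chosen for illustration, and the check is done on the input tensor for simplicity, whereas the generated kernel shown below checks the output.

```python
class ToyTensor:
    def __init__(self, data, is_alias_storage=False):
        self.data = list(data)
        # Hypothetical flag: True when this tensor lives inside an alias object.
        self.is_alias_storage = is_alias_storage
        self.alias = None


aliases_created = 0


def functional_narrow(tensor, start, length):
    """Toy view op: copies a slice and records alias bookkeeping."""
    global aliases_created
    out = ToyTensor(tensor.data[start:start + length])
    if tensor.is_alias_storage:
        # Views replayed on the alias itself skip the bookkeeping entirely,
        # so syncing never creates an alias for the alias.
        return out
    if tensor.alias is None:
        tensor.alias = ToyTensor(tensor.data, is_alias_storage=True)
        aliases_created += 1
    out.alias = tensor.alias
    return out


x = ToyTensor([1, 2, 3, 4])
v = functional_narrow(x, 1, 2)            # user-visible view: creates the alias
replayed = functional_narrow(x.alias, 1, 2)  # view replayed during a sync
```

Only the user-visible view registers an alias; the replayed view on the alias tensor returns a plain result with no further bookkeeping.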

I tried to call out other major details in specific comments in the code.

**Codegen output**

Inplace ops are all codegen'd (at least, the ones that have a functional version). View ops need to be added one-by-one, so I've only added `at::view` for now. The codegen output looks like this for `view`:
```
at::Tensor view(c10::DispatchKeySet ks, const at::Tensor & self, at::IntArrayRef size) {
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.key_set().has(c10::DispatchKey::Functionalize));
        at::Tensor out;
        {
          at::AutoDispatchBelowFunctionalize guard;
          auto tmp_output = at::redispatch::view(ks & c10::after_func_keyset, self, size);
          out = tmp_output.clone();
        }
        // See Note [Marking Alias Tensors]
        if (!at::functionalization::impl::is_alias_tensor(out)) {
          // TODO we'll probably want a separate function for each view op that creates the corresponding ViewMeta.
          at::ViewMeta view_meta = at::functionalization::impl::get_meta_view(self, size);
          at::functionalization::impl::set_view_meta(out, self, view_meta);
        }
        return out;
}
```

And for an example inplace op, the codegen looks like this:
```
at::Tensor & add__Tensor(c10::DispatchKeySet ks, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
        {
            at::AutoDispatchBelowFunctionalize guard;
            auto tmp_output = at::redispatch::add(ks & c10::after_func_keyset, self, other, alpha);
            self.replace_(tmp_output);
            TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.key_set().has(c10::DispatchKey::Functionalize));
            at::functionalization::impl::maybe_add_update(self);
        }
        return self;
}
```

**Printing tensors**
In the XLA case, printing tensors should just work - the call to `_tensor_str()` eventually just calls `to('cpu')`, which hits the functionalization fallback and syncs everything.

In the functorch case, functorch overrides the `_tensor_str` function to first recursively unwrap the tensor wrappers before calling the original version of `_tensor_str()`, so I added logic in functorch to be aware of unwrapping FunctionalTensorWrapper objects, and sync before unwrapping. I also have the `functionalize()` pass call `sync_()` on every tensor input after the pass completes, to ensure that inputs get mutated correctly.

The handling for when to wrap up outputs is a little fragile though, because of factory functions + printing. For factory functions, we need to make sure that we still wrap output tensors even if there are no input tensor arguments. For printing, we need to make sure not to wrap arguments. This case is distinguished by the fact that none of the tensor input arguments are wrapped. That means that stuff can break if any factory functions are called inside of `_tensor_str()`, but it looks like that currently isn't the case.

**Other stuff**
- The functionalization pass has a boxed fallback (in `VariableFallback.cpp`), but it shouldn’t be too hard to move that to codegen, since all it does is sync all of the input tensors.
- For any unsupported view/mutation ops (either view ops that aren’t implemented yet, or mutation ops that don’t currently have an out-of-place equivalent), I codegen a kernel that’s pretty similar to the boxed fallback - it just syncs the inputs and redispatches. It also prints a warning, to help figure out which mutation ops we need to add out-of-place versions for.
- I defined a bunch of helper functions in the `at::functionalization::impl` namespace, mostly as utility functions to make the codegen easier.
- Right now I have an enum for each view op. The Alias contains a stack of updates, and each update contains a stack of `ViewMetas`, explaining what view ops were run on the base tensor to get to the view before the mutation occurred. The logic to sync mutations across aliases involves replaying the views in reverse, to figure out what the base tensor looks like after every mutation. The enums are a little bit ugly, but I’m not sure of a significantly more elegant way to represent them. I also listed out the full set of enum values, but that might be too presumptive.
- Each view op needs to store some extra info in order to replay it in reverse - this will probably need to be implemented separately per view - right now I have an `at::functionalization::impl::get_meta_{view}(…)` function that knows which information to store (called by the codegen), and I’m planning on trying to implement similar functions for the other view ops.
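
To make the enum-based replay concrete, here is a minimal sketch that uses plain Python lists in place of tensors. Everything in it (`ViewType`, `forward`, `inverse`, `apply_update`, and the two toy views) is hypothetical and only illustrates the idea; it is not the C++ API in this PR. Each `ViewMeta` stores the extra arguments its view needs, and a mutation made on a view is pushed back to the base by walking the view chain in reverse.

```python
from dataclasses import dataclass
from enum import Enum, auto


class ViewType(Enum):
    SELECT_ROW = auto()  # 2-D -> 1-D: pick one row
    SLICE = auto()       # 1-D -> 1-D: take elements [start, stop)


@dataclass(frozen=True)
class ViewMeta:
    view_type: ViewType
    args: tuple  # extra info needed to replay the view and its inverse


def forward(meta, base):
    """Replay one view op on `base`, returning a copy of the viewed region."""
    if meta.view_type is ViewType.SELECT_ROW:
        (row,) = meta.args
        return list(base[row])
    if meta.view_type is ViewType.SLICE:
        start, stop = meta.args
        return list(base[start:stop])
    raise ValueError(f"unsupported view: {meta.view_type}")


def inverse(meta, base, mutated_view):
    """Return a copy of `base` with `mutated_view` written into the viewed region."""
    if meta.view_type is ViewType.SELECT_ROW:
        (row,) = meta.args
        out = [list(r) for r in base]
        out[row] = list(mutated_view)
        return out
    if meta.view_type is ViewType.SLICE:
        start, stop = meta.args
        out = list(base)
        out[start:stop] = mutated_view
        return out
    raise ValueError(f"unsupported view: {meta.view_type}")


def apply_update(base, view_metas, new_view_value):
    """Propagate a mutated view value back through the view chain to the base."""
    # Replay the views forward to recover the parent of each view in the chain.
    parents = [base]
    for meta in view_metas[:-1]:
        parents.append(forward(meta, parents[-1]))
    # Walk the chain in reverse, embedding the new value into each parent.
    value = new_view_value
    for meta, parent in zip(reversed(view_metas), reversed(parents)):
        value = inverse(meta, parent, value)
    return value


base = [[0, 0, 0, 0], [1, 1, 1, 1]]
metas = [ViewMeta(ViewType.SELECT_ROW, (1,)), ViewMeta(ViewType.SLICE, (1, 3))]
view = forward(metas[1], forward(metas[0], base))
mutated = [v + 10 for v in view]  # stands in for an out-of-place add
new_base = apply_update(base, metas, mutated)
```

The `if` chains in `forward` and `inverse` are the part that grows with every new view op, which is the boilerplate concern raised about the enum design.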


**Followups**

- I haven't carefully tested a bunch of use cases with functorch yet (like nested calls to `functionalize()`)
- Add support for more view ops (probably not all of them... but the important ones. The handful that xla implements are probably a good starting point).
- In particular, I have a feeling the codegen will change a little as more view ops are added. For example, `torch.split()` is a view op that returns multiple output tensors, which will all alias the same input tensor. Need to make sure that the codegen handles that gracefully. There are also a bunch of view ops that are both views and mutations, like `transpose_` and `as_strided_`. That's probably gonna require extra codegen.
- audit the pass for perf (unnecessary tensor clones and refcount bumps)
- Think about a version of the pass that's just mutation removal, or just alias removal? One option is to add separate keys for AliasRemovalOnly and MutationRemovalOnly, and factor the codegen well enough that it can be re-used. Another would be to split out the current alias and mutation removal bits into two passes, although that'll require an extra dispatcher trip everywhere.





bdhirsh added a commit that referenced this pull request Aug 14, 2021
ghstack-source-id: af734a5
Pull Request resolved: #63048

ezyang commented Aug 16, 2021

> FunctionalTensorWrapper also overrides all virtual methods to call into its wrapped value - I'm not sure if this is overkill, but the goal is that aside from the functionalization pass, everything else in pytorch shouldn't see the wrapper.

The current policy is that we're supposed to replicate all metadata in the wrapped class, in which case it wouldn't be necessary to override the virtual methods. @swolchok has worked pretty hard to devirtualize these methods in fbcode and it would be best to make sure things work even if they are not virtual.


ezyang commented Aug 16, 2021

> Integration with backends is currently implemented by updating your TensorImpl subclass to inherit from FunctionalTensorImplBase instead of TensorImpl, and adding an overridden replace_() method, which tells us how to "re-use" your tensor

Not sure why the virtual method is needed here, if there's also a dispatched at::replace_; you can just implement the dispatched operator instead?


zou3519 commented Aug 19, 2021

> If the alias is a FunctionalTensorWrapper wrapper, we'll call back into the functionalization pass machinery and infinite loop (creating an alias for the alias). That's true for xla too, but we can just add an exclude guard before calling sync_() in the xla case. We can't do that for functorch, since the DynamicLayer logic overwrites TLS at each layer.

The DynamicLayer overwriting all TLS is a bug that we haven't fixed. It's not supposed to overwrite all TLS; it's supposed to overwrite only the TLS that it needs to (e.g. the DispatchPastAD guard, the Batched guard, etc.)


bdhirsh commented Aug 19, 2021

Some changes I'm going to look into after talking to Ed:

(1) The "preserve input mutations" logic is actually wrong - it doesn't work if the input to the program that we functionalize is actually a view of any existing tensor - the storage-swapping logic that I have there is a little sketchy and breaks down in this case. Instead, at the end of the pass I can just (a) detect when a mutation to an input has occurred, and (b) copy_() the new value onto the input (which will properly affect all pre-existing views).

(2) Not having to special-case the is_alias_of() check to check if we're running the functionalization pass - ideally, impl_->storage().is_alias_of(other.storage()) should just work in all cases.

(3) Properly propagating stride information onto the view tensors, which I don't think pytorch/xla currently does (I'm going to look into pytorch/xla's version a little more closely to understand what the delta is).
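
Point (1) can be illustrated with a toy model, assuming a "tensor" is just a window onto a shared storage object. The `Storage` and `Tensor` classes and the `functional_program` helper below are hypothetical stand-ins, not PyTorch classes. The sketch shows why swapping the input's storage loses the mutation for a pre-existing base, while copying the new value onto the input keeps every existing view in sync.

```python
class Storage:
    """Toy flat buffer shared by a base tensor and its views."""

    def __init__(self, data):
        self.data = list(data)


class Tensor:
    """Toy 1-D tensor: a window [offset, offset + size) onto a Storage."""

    def __init__(self, storage, offset, size):
        self.storage = storage
        self.offset = offset
        self.size = size

    def tolist(self):
        return self.storage.data[self.offset:self.offset + self.size]

    def copy_(self, values):
        # Write through to the shared storage, so every alias sees the change.
        values = list(values)
        if len(values) != self.size:
            raise ValueError("size mismatch")
        self.storage.data[self.offset:self.offset + self.size] = values
        return self


def functional_program(x_values):
    # Stands in for a functionalized program: returns what x would hold
    # after an in-place x.add_(1), without mutating anything.
    return [v + 1 for v in x_values]


def make_inputs():
    base = Tensor(Storage([0, 0, 0, 0]), offset=0, size=4)
    x = Tensor(base.storage, offset=2, size=2)  # the input is a view of base
    return base, x


# Storage swap: x points at fresh memory, so base never sees the mutation.
base, x = make_inputs()
x.storage, x.offset = Storage(functional_program(x.tolist())), 0
swapped_base = base.tolist()

# copy_ at the end of the pass: the new value lands in the shared storage.
base, x = make_inputs()
x.copy_(functional_program(x.tolist()))
copied_base = base.tolist()
```

In the swap case `base` still reads all zeros, while in the `copy_` case its last two elements pick up the mutation.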

TORCH_CHECK(false, "Tried to run the functionalization pass on an unsupported view: ", view_meta.view_type);
}
}
// We want the new tensor to have separate memory from the alias.
Contributor

Why?

Comment on lines +266 to +271
void Alias::SyncUpdateOperations() {
for (auto& update_data: updates_) {
apply_update(update_data);
}
updates_.clear();
}
Contributor

Why do the updates have to be queued? Is there a reason why we can't apply the update directly after an in-place operation?

return c10::MaybeOwned<Tensor>::owned(__dispatch_contiguous(memory_format));
}
}

Collaborator

Note that we already have a class called `ViewInfo` that is stored on the AutogradMeta (`struct TORCH_API ViewInfo {`); this might lead to some confusion in naming.

Contributor Author

I namespaced this into at::functionalization::ViewMeta in the new PR


struct ViewMeta {
// The names of all existing view operators.
enum class Type {
Collaborator

Why do you need to do this?
This feels like it will make adding new views even harder?
Is there any reason why the strategy used by the autograd view tracking is not possible here?

  • Use size/stride info + as_strided for strided Tensors and non-cross-dtype views
  • For other things, use a lambda that captures all the necessary arguments for the current function.

The ViewMeta you have here will need to get specialized for every single view function we have to be able to capture its arguments no?

Contributor Author

The new PR uses lambdas instead of a big enum, I think the amount of boilerplate per view op is (relatively) minimal but lmk what you think!

ENFORCE_SAME_TENSOR_STORAGE = CodeTemplate("""\
if (${tensor_name}_storage_saved.has_value())
AT_ASSERT(${tensor_name}_storage_saved.value().is_alias_of(${out_tensor_name}.storage()));
// AT_ASSERT(${tensor_name}_storage_saved.value().is_alias_of(${out_tensor_name}.storage()));
Collaborator

The new test is not valid because you won't detect if the input was actually modified inplace. Because you read it after it was (potentially) modified.

Contributor Author

Deleted in the new PR. In particular, I added a new `FunctionalStorageImpl` so that calls to `a.storage().is_alias_of(b.storage())` are always valid.


codecov bot commented Aug 25, 2021

Codecov Report

Merging #63048 (87a661c) into gh/bdhirsh/140/base (85a4c72) will increase coverage by 6.70%.
The diff coverage is n/a.

❗ Current head 87a661c differs from pull request most recent head 76f0961. Consider uploading reports for the commit 76f0961 to get more accurate results

@@                   Coverage Diff                   @@
##           gh/bdhirsh/140/base   #63048      +/-   ##
=======================================================
+ Coverage                60.47%   67.18%   +6.70%     
=======================================================
  Files                      684      695      +11     
  Lines                    88467    90528    +2061     
=======================================================
+ Hits                     53503    60820    +7317     
+ Misses                   34964    29708    -5256     

bdhirsh added a commit that referenced this pull request Sep 2, 2021
Original PR description + feedback here: #63048

I've addressed all of the feedback in the original PR and made some pretty large changes, listed below.


**Table of Contents**
- Starting points
- List of the main changes from the original PR
- Next Steps
- Example codegen output (for a view, mutation, and view+mutation op)


**Starting Points**

A good place to start when looking through the PR:
* Alban mentioned that this is a useful mental model (thanks Ed for originally making this clear to me). Semantically, the pass currently does THREE things, which are all needed by functorch - all fused together into one big pass.
  * (a) alias removal, which replaces {view} calls with {view}_copy calls, and manually tracks aliasing information, so that when one tensor is mutated, we re-apply the same mutation to all of the aliases. This is the bulk of the work - once this is done, the next 2 things are trivial to implement.
  * (b) mutation removal, which is easy to do once we know that there are no aliases. Every mutation `a.add_(b)` becomes `a.replace_(a.add(b))`
  * (c) reapplying views: all of the `{view}_copy` calls are replaced with `{view}` calls again. This is an optimization that we can make specifically for functorch (and strided backends), that only care about mutation removal and not alias removal
  * XLA and Vulkan only want (a), or (a) + (b). Later, we'll want to split this out so that you can actually opt into different versions of this logic.
  * There is currently no {view}_copy replacement, because the <replace views with copies> and <replace copies with views> steps have been combined in the pass. Later, we'll want to actually implement {view}_copy variants of each view operator, probably with codegen.
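* documentation breadcrumb 1, in `FunctionalTensorWrapper.cpp`: https://github.com/pytorch/pytorch/pull/64432/files#diff-a0bac99bf205dba5b94cb64fc2466d3d55d991887572f9cd6a02e27b3a91dd60R59 (you might have to expand the `FunctionalTensorWrapper.cpp` file, which GitHub closes by default because it's large)
* documentation breadcrumb 2, in `FunctionalTensorWrapper.h`: https://github.com/pytorch/pytorch/pull/64432/files#diff-c945c71a4ccac65871f24a912e8904f9a5088b24a32e636727ea9c8fe920708aR12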
* Reading through the codegen output at the bottom of this description.
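
Step (b) above, where `a.add_(b)` becomes `a.replace_(a.add(b))`, can be sketched in a few lines of plain Python. The `FunctionalWrapper` class is a hypothetical stand-in for a wrapper tensor and only mirrors the shape of the rewrite; the real `replace_()` lives on the C++ tensor wrapper.

```python
class FunctionalWrapper:
    """Toy wrapper whose inner value is immutable (a tuple)."""

    def __init__(self, value):
        self.value = tuple(value)

    def add(self, other):
        # Out-of-place op: builds a fresh wrapper and mutates nothing.
        return FunctionalWrapper(a + b for a, b in zip(self.value, other.value))

    def replace_(self, other):
        # Swap the wrapped value; the wrapper object keeps its identity.
        self.value = other.value
        return self

    def add_(self, other):
        # What the pass turns an in-place add into: an out-of-place add
        # followed by replace_().
        return self.replace_(self.add(other))


a = FunctionalWrapper([1, 2, 3])
b = FunctionalWrapper([10, 10, 10])
original = a.value
result = a.add_(b)
```

Callers still observe `a` changing, but the only op that reaches the layer below is the out-of-place `add`.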


**Main changes from the original PR**

(1)  I use lambdas instead of a giant enum to handle all of the different views.

This results in less boilerplate per view op (and more stuff that can be codegen'd). Every `ViewMeta` object now contains a `forward` and `reverse` lambda, that knows how to replay the view and its inverse. This makes the actual code that executes the replaying logic a lot less boilerplate-y (see `Alias::sync_update_operations` and `FunctionalTensorWrapper::sync_`)
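
A rough sketch of the lambda-based design, again with Python lists standing in for tensors. The `forward`/`reverse` pair follows the description above, while `slice_meta`, the `Alias` class layout, and `add_update` are assumptions made for illustration and do not claim to match the C++ code. The point is that the sync loop never branches on the kind of view.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class ViewMeta:
    forward: Callable  # base -> view
    reverse: Callable  # (base, mutated_view) -> updated base


def slice_meta(start, stop):
    """Build the forward/reverse closures for a toy 1-D slice view."""
    def forward(base):
        return list(base[start:stop])

    def reverse(base, mutated_view):
        out = list(base)
        out[start:stop] = mutated_view
        return out

    return ViewMeta(forward, reverse)


class Alias:
    def __init__(self, base):
        self.base = list(base)
        self.updates = []  # queued (new_view_value, view_metas) pairs

    def add_update(self, new_value, view_metas):
        self.updates.append((list(new_value), list(view_metas)))

    def sync_update_operations(self):
        for value, metas in self.updates:
            # Recover the parent of each view by replaying the chain forward.
            parents = [self.base]
            for meta in metas[:-1]:
                parents.append(meta.forward(parents[-1]))
            # Replay the inverses from the innermost view back to the base.
            for meta, parent in zip(reversed(metas), reversed(parents)):
                value = meta.reverse(parent, value)
            self.base = value
        self.updates.clear()


alias = Alias(range(6))
chain = [slice_meta(2, 6), slice_meta(1, 3)]  # base[2:6][1:3]
alias.add_update([30, 40], chain)
alias.sync_update_operations()
synced = alias.base
```

Adding a new view in this sketch means writing one factory like `slice_meta`; `sync_update_operations` stays untouched.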

(2) Every tensor during the functionalization pass is always wrapped in a `FunctionalTensorWrapper`.

This is potentially unnecessary for Vulkan/XLA, and will have a mild perf impact, but for now this PR just targets the functorch use case. I previously had a complicated design (a `FunctionalTensorImplBase` class) to avoid needing the wrapper for XLA, but it had some subtleties that are gonna require more thought to fix, so I'm pushing that off for now.

(3) `FunctionalTensorWrapper` objects accurately report stride information.

It's a little annoying to do this though, because the logic that calculates stride info for each view isn't easily separated from the actual view kernels in core, `at::native::{view}`. I do this by adding logic in each `at::functionalization::{view}` kernel to call the reference implementation `at::native::{view}`. I don't do anything with the output aside from taking its size/stride/storage_offset to set the actual output tensor's size/stride/storage_offset correctly. There's another annoying part to this: I'm pretty sure that we want to pass in the actual *wrapper* tensors directly into the native kernels, not their inner unwrapped values. But there are some `at::native::{view}` kernels that call other tensor methods, which re-invokes the dispatcher, calling functionalization/functorch kernels that try to do the unwrapping.

To work around this, right now I have an `AutoDispatchDirectlyToNative` guard that basically ensures that any tensor methods called inside of the `at::native::{view}` op always redispatch straight to the CPU kernel (which will be another `at::native::` kernel). This feels kind of heavy-handed, but I'm not sure of a better way to do it.


(4) `FunctionalTensorWrapper` objects accurately report aliasing information.

There's a new `FunctionalStorageImpl` class (subclass of `StorageImpl`) that allows tensors in the functionalization pass to accurately alias storage. If two tensors `a` and `b` in a functionalized program are views of one another, then `a.storage().is_alias_of(b.storage())` should return true. I added this in a pretty similar way to how meta tensors allocate storage, although I don't pass in an actual allocator (I think this is fine because you should never resize a functional tensor's storage).
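
As a toy illustration of that contract, the sketch below gives every wrapper a storage handle that carries identity only and no usable data. `FunctionalStorage`, `Wrapper`, `view_copy`, and `clone` are hypothetical names, and comparing by object identity is an assumption of this sketch rather than a description of how `StorageImpl` aliasing is implemented.

```python
class FunctionalStorage:
    """Toy storage: holds no usable data, only identity for alias checks."""

    def is_alias_of(self, other):
        return self is other


class Wrapper:
    def __init__(self, value, storage=None):
        self.value = list(value)
        # A fresh tensor gets its own storage; a view reuses its base's.
        self._storage = storage if storage is not None else FunctionalStorage()

    def storage(self):
        return self._storage

    def view_copy(self, start, stop):
        # The data is copied, but the storage handle is shared, so the
        # aliasing relationship stays visible to callers.
        return Wrapper(self.value[start:stop], storage=self._storage)

    def clone(self):
        return Wrapper(self.value)


a = Wrapper([1, 2, 3, 4])
b = a.view_copy(1, 3)  # "view" of a
c = a.clone()          # independent tensor
```

Here `a.storage().is_alias_of(b.storage())` holds even though `b` owns a separate copy of the data, while the clone `c` reports no aliasing.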

One thing I'm not sure about - should `FunctionalTensorWrapper` set `storage_access_should_throw_`: (a) always, (b) never, (c) only if its wrapped tensor has it set.

Right now I have it not set, mostly because calling the reference view functions (`at::native::{view}`) requires looking at the storage. But that means that if you try to access storage from python in a functionalized program, you'll get silent garbage instead of an error. Related question: are we planning on exposing meta tensor storage to python in the future (even though it contains garbage)?

(5) better docs :)

**View operator coverage**

(6) The functionalization pass now gets math-composite view ops for free.

I didn't add the `Functionalize` dispatch key to the composite set, because I don't want composite ops like `torch.ones` to get decomposed before hitting the functionalization pass. Instead, I added codegen to manually register the `at::native::` kernels of composite view ops. This is a little hairy, because the names of the `at::native::` kernels aren't easily accessible. They're stored in a `Dict[DispatchKey, BackendIndex]`. I made a best-effort attempt to get each view kernel's name, basically by assuming that every view op has either a composite or cpu implementation.
There's also a hardcoded list of composite view ops in `gen_inplace_or_view_type.py`, but it looks like it's wrong. This is probably worth rationalizing later, but instead I created a new list of the "complete" set of composite view ops, and preserved the old set by hardcoding the delta between the two sets.

(7) I've added codegen for ops that are both views AND mutations, like `transpose_()` (why do we even have these 😢).

From some light testing, it looks like they work correctly with one caveat: I had a hard time ensuring that functorch programs that mutate their inputs using ops like `transpose_()` preserve the input mutations after the program finishes running. For now (in my corresponding functorch branch), I emit a warning when this happens and just don't preserve the mutation.

(8) I added `{view}_inverse` implementations for every view op, in `FunctionalInverses.cpp`.

These are needed to take mutations made to views and replay them back onto the base. To reduce boilerplate, the codegen generates function declarations for each `{view}_inverse` function, so you get a nice compiler error when someone eventually adds a new view op.

The only view ops currently not supported are (a) as_strided, and (b) the sparse view ops (values()/indices()).

I can add support for as_strided, but it needs an `as_strided_inverse()` function. That will look really similar to the `as_strided_backward()` function in FunctionsManual.cpp, but it has some noticeable differences: we basically want an `as_strided_embed` for autograd and `as_strided_scatter` for functionalization. We also will probably need them to be primitives w.r.t. autograd, since the current implementation for autograd uses view().copy_() calls that XLA won't be able to handle. I'm wondering if anyone has any objections, but otherwise I can make those changes (which will require writing backward formulas for `as_strided_embed` and `as_strided_scatter`).

I did a bunch of manual testing that all looks pretty good, but it's definitely not fully tested. Ed pointed out that once XLA uses this pass (or at least once there's a POC), we can just run the existing xla view test suite. Hopefully that delay is okay - if it's not, maybe we can think about using OpInfos similar to how functorch uses them for testing.

Note: there's some duplication with autograd's view code. Every `{view}_inverse` implementation is really similar to the implementation for that view listed in `derivatives.yaml`. There are some major differences though:
* the autograd implementations of those backward functions (like `permute_backwards()`, in `FunctionsManual.cpp`) internally call other view ops. For functionalization, we want them to (eventually) call `{view}_copy` operators.
* For view ops that take a subset of the original storage, like `slice/select/diagonal/as_strided()`, the autograd backward functions fill the "spaces" in the inverse call with zeroes. For functionalization, we want to fill them with the value of `base` at those positions. It looks like this currently applies to 6 total ops (since we can ignore composites):
  * select
  * slice
  * diagonal
  * as_strided
  * split
  * split_with_sizes
A nice end state would probably be for the autograd + functionalization codegen to both look at the same yaml (either `derivatives.yaml`, or something else), and automatically generate the right thing. I left that out of scope for this PR though.
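
The zero-fill versus base-fill difference is easy to see on a toy 1-D slice. The two helpers below are hypothetical Python stand-ins (deliberately prefixed `toy_` so they are not mistaken for the real ATen functions): the autograd-style backward embeds the incoming values into zeros, while the functionalization-style inverse embeds them into a copy of `base`.

```python
def toy_slice_backward(grad, base_len, start, stop):
    """Autograd-style: positions outside the slice are filled with zeros."""
    out = [0] * base_len
    out[start:stop] = grad
    return out


def toy_slice_inverse(base, mutated_view, start, stop):
    """Functionalization-style: positions outside the slice keep base's values."""
    out = list(base)
    out[start:stop] = mutated_view
    return out


base = [1, 2, 3, 4, 5]
mutated_view = [20, 30]  # new value for base[1:3]

as_gradient = toy_slice_backward(mutated_view, len(base), 1, 3)
as_new_base = toy_slice_inverse(base, mutated_view, 1, 3)
```

Using the zero-filled version to rebuild the base would wipe out every element the view never touched, which is why the backward formulas cannot be reused as-is.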



**Current State + Next Steps**

There are a bunch of followups after this PR eventually lands. Roughly in order:
* Use the current pass to register problematic composite ops in functorch. Also, nested `functionalize()` calls aren't supported yet (I mostly just need to remove some debug asserts and test it).
* Work on freeing up dispatch key space by deduplicating the `{backend}`/`Autograd{backend}`/`Sparse{backend}`/`Quantized{backend}` keys
* Once we have more dispatch keys, split up this pass into 3 pieces - it's currently fused, and doesn't do the right thing for vulkan/XLA. Specifically, all of the `{view}` calls in the current pass's view-replay logic should turn into `{view}_copy` calls that vulkan/XLA know how to implement, and there will be separate passes for (a) removing mutations, and (b) turning `{view}_copy` calls back into `{view}` calls. For Vulkan, we eventually want a pass that ONLY removes aliasing and view calls, and doesn't remove mutations. We can also probably make the 2 new passes user dispatch keys to save dispatch key space, if they'll only be used by functorch anyway.
* Do more of a dive on perf for the vulkan/xla use cases. There are several areas to improve perf with varying levels of effort required. The simplest one that I'll probably do regardless is to codegen the out-of-place kernels instead of using a boxed fallback.


**Example Codegen Output**

View Op:
```
::std::vector<at::Tensor> split_Tensor(c10::DispatchKeySet ks, const at::Tensor & self, int64_t split_size, int64_t dim) {

      auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
      ::std::vector<at::Tensor> out;
      {
        at::AutoDispatchBelowFunctionalize guard;
        auto tmp_output = at::redispatch::split(ks & c10::after_func_keyset, self_, split_size, dim);
        out = at::functionalization::impl::wrapFunctionalTensor(tmp_output);
        // I'm fusing the [alias removal], [mutation removal], [add views back] passes together.
        // Later, we'll want to turn them into separate passes (since e.g. vulkan only cares about alias removal).
      }

      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        [split_size, dim](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
          return base.split(split_size, dim)[mutated_view_idx];
        },
        [split_size, dim](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
          return at::functionalization::impl::split_inverse(base, mutated_view, mutated_view_idx, split_size, dim);
        }
      );
      at::functionalization::impl::set_view_meta(out, self, view_meta);

      at::AutoDispatchDirectlyToNative native_guard;
      ::std::vector<at::Tensor> reference_tensor_output = at::native::split(self, split_size, dim);
      at::functionalization::impl::set_strides(out, reference_tensor_output);
      return out;

}
```
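
For intuition about what the `reverse` lambda above has to compute, here is a toy 1-D version in Python. The argument order loosely follows the `split_inverse(base, mutated_view, mutated_view_idx, split_size, dim)` call in the generated code, but the `dim` argument is dropped and the body is an assumption of this sketch, not the implementation in `FunctionalInverses.cpp`.

```python
def toy_split(base, split_size):
    """Split a list into chunks of `split_size`; the last chunk may be shorter."""
    return [base[i:i + split_size] for i in range(0, len(base), split_size)]


def toy_split_inverse(base, mutated_view, mutated_view_idx, split_size):
    """Return a copy of `base` with chunk `mutated_view_idx` replaced."""
    start = mutated_view_idx * split_size
    out = list(base)
    out[start:start + len(mutated_view)] = mutated_view
    return out


base = [0, 1, 2, 3, 4, 5, 6]
chunks = toy_split(base, 3)
mutated = [v * 10 for v in chunks[2]]  # mutate only the last output
new_base = toy_split_inverse(base, mutated, 2, 3)
```

The `mutated_view_idx` argument is what lets a single inverse serve all outputs of a multi-output view: each output carries the same closure plus its own index.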

Mutation Op:
```
at::Tensor & add__Tensor(c10::DispatchKeySet ks, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {

      at::functionalization::impl::sync(self);
      at::functionalization::impl::sync(other);
      auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
      auto other_ = at::functionalization::impl::unwrapFunctionalTensor(other);
      at::Tensor tmp_output;
      {
          at::AutoDispatchBelowFunctionalize guard;
          // The functionalization pass explicitly doesn't pass out= parameters to the redispatch
          tmp_output = at::redispatch::add(
            ks & c10::after_func_keyset, self_, other_, alpha);
      }

      self.replace_(tmp_output);
      at::functionalization::impl::maybe_add_update(self);
      return self;
}
```

View + Mutation Op:
```
at::Tensor & transpose_(c10::DispatchKeySet ks, at::Tensor & self, int64_t dim0, int64_t dim1) {


      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        [dim0, dim1](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
          return base.transpose(dim0, dim1);
        },
        [dim0, dim1](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
          return at::functionalization::impl::transpose_inverse(base, mutated_view, dim0, dim1);
        }
      );
      at::functionalization::impl::mutate_view_meta(self, view_meta);
      // See  Note [Propagating strides in the functionalization pass]
      // Directly update the sizes/strides/storage_offset fields on self using the inplace call.
      // I need the guard because I don't want the at::native kernel to end up calling more functionalization/functorch kernels.
      // Its only job is to directly compute the output size/stride/storage_offset metadata.
      at::AutoDispatchDirectlyToNative native_guard;
      at::native::transpose_(self, dim0, dim1);
      return self;

}
```





ezyang commented Sep 8, 2021

IIUC, this got obsoleted by #64432

bdhirsh added a commit that referenced this pull request Sep 14, 2021
Original PR description + feedback here: #63048

I've addressed all of the feedback in the original PR and made some pretty large changes, listed below.


**Table of Contents**
- Starting points
- List of the main changes from the original PR
- Next Steps
- Example codegen output (for a view, mutation, and view+mutation op)


**Starting Points**

A good place to start when looking through the PR:
* Alban mentioned that this is a useful mental model (thanks Ed for originally making this clear to me). Semantically, the pass currently does THREE things, which are all needed by functorch - all fused together into one big pass.
  * (a) alias removal, which replaces {view} calls with {view}_copy calls, and manually tracks aliasing information, so that when one tensor is mutated, we re-apply the same mutation to all of the aliases. This is the bulk of the work - once this is done, the next 2 things are trivial to implement.
  * (b) mutation removal, which is easy to do once we know that there are no aliases. Every mutation `a.add_(b)` becomes `a.replace_(a.add(b))`
  * (c) reapplying views: all of the `{view}_copy` calls are replaced with `{view}` calls again. This is an optimization that we can make specifically for functorch (and strided backends), that only care about mutation removal and not alias removal
  * XLA and Vulkan only want (a), or (a) + (b). Later, we'll want to split this out so that you can actually opt into different versions of this logic.
  * There is currently no {view}_copy replacement, because the pass just <replace views with copies> and <replace copies with views> steps have been combined. Later, we'll want to actually implement {view}_copy variants of each view operator, probably with codegen.
* documentation breadcrumb 1, in `FunctionalTensorWrapper.cpp`: https://github.com/pytorch/pytorch/pull/64432/files#diff-a0bac99bf205dba5b94cb64fc2466d3d55d991887572f9cd6a02e27b3a91dd60R59 (you might have to expand the `FunctionalTensorWrapper.cpp` file, which GitHub closes by default because it's large)
* documentation breadcrumb 2, in `FunctionalTensorWrapper.h`: https://github.com/pytorch/pytorch/pull/64432/files#diff-c945c71a4ccac65871f24a912e8904f9a5088b24a32e636727ea9c8fe920708aR12
* Reading through the codegen output at the bottom of this description.


**Main changes from the original PR**

(1)  I use lambdas instead of a giant enum to handle all of the different views.

This results in less boilerplate per view op (and more stuff that can be codegen'd). Every `ViewMeta` object now contains a `forward` and `reverse` lambda, that knows how to replay the view and its inverse. This makes the actual code that executes the replaying logic a lot less boilerplate-y (see `Alias::sync_update_operations` and `FunctionalTensorWrapper::sync_`)

(2) Every tensor during the functionalization pass is always wrapped in a `FunctionalTensorWrapper`.

This is potentially unnecessary for Vulkan/XLA, and will have a mild perf impact, but for now this PR just targets the functorch use case. I previously had a complicated design a (`FunctionalTensorImplBase` class) to avoid needing the wrapper for XLA, but it had some subtleties that are gonna require more thought to fix, so I'm pushing that off for now.

(3) `FunctionalTensorWrapper` objects accurately report stride information.

It's a little annoying to do this though, because the logic that calculates stride info for each view isn't easily separated from the actual view kernels in core, `at::native::{view}`. I do this by adding logic in each `at::functionalization::{view}` kernel to call the reference implementation `at::native::{view}`. I don't do anything with the output aside from taking it's size/stride/storage_offset to set the actual output tensor's size/stride/storage_offset correctly. There's another annoying part to this: I'm pretty sure that we want to pass in the actual *wrapper* tensors directly into the native kernels, not their inner unwrapped values. But there are some `at::native::{view}` kernels that call other tensor methods, which re-invokes the dispatcher, calling functionalization/functorch kernels that try do the unwrapping.

To do this, right now I have an `AutoDispatchDirectlyToNative` guard that basically ensures that any tensor methods called inside of the at::native::{view} op always redispatch straight to the CPU kernel (which will be another at::native:: kernel). This feels kind of heavy handed, but I'm not sure of a better way to do it.


(4) `FunctionalTensorWrapper` objects accurately report aliasing information.

There's a new `FunctionalStorageImpl` class (subclass of `StorageImpl`) that allows tensors in the functionalization pass to accurately alias storage. If two tensors `a` and `b` in a functionalized program are views of one another, then `a.storage.is_alias_of(b.storage)` should return true. I added this in a pretty similar way to how meta tensors allocate storage, although I don't pass in an actual allocator (I think this is fine because you should never resize a functional tensor's storage). 

One thing I'm not sure about - should `FunctionalTensorWrapper` set `storage_access_should_throw_`: (a) always, (b) never, (c) only if its wrapped tensor has it set.

Right now I have it not set, mostly because calling the reference view functions (`at::native::{view}`) requires looking at the storage. But that means that if you try to access storage from python in a functionalized program, you'll get silent garbage instead of an error. Related question: are we planning on exposing meta tensor storage to python in the future (even though it contains garbage)?

(5) better docs :)

**View operator coverage**

(6) The functionalization pass now gets math-composite view ops for free.

I didn't add the `Functionalize` dispatch key to the composite set, because I don't want composite ops like `torch.ones` to get decomposed before hitting the functionalization pass. Instead, I added codegen to manually register the `at::native::` kernels of composite view ops. This is a little hairy, because the names of the `at::native::` kernels aren't easily accessible. They're stored in a `Dict[DispatchKey, BackendIndex]`. I made a best-effort attempt to get each view kernel's name, basically by assuming that every view op has either a composite or cpu implementation.
There's also a hardcoded list of composite view ops in `gen_inplace_or_view_type.py`, but it looks like it's wrong. This is probably worth rationalizing later, but instead I created a new list of the "complete" set of composite view ops, and preserved the old set by hardcoding the delta between the two sets.

(7) I've added codegen for ops that are both views AND mutations, like `transpose_()` (why do we even have these 😢).

From some light testing, it looks like they work correctly with one caveat: I had a hard time ensuring that functorch programs that mutate their inputs using ops like `transpose_()` preserve the input mutations after the program finishes running. For (in my corresponding functorch branch) I emit a warning when this happens, and just don't preserve the mutation

(8) I added `{view}_inverse` implementations for every view op, in `FunctionalInverses.cpp`.

These are needed to take mutations made to views and replay them back onto the base. To reduce boilerplate, the codegen generates function declarations for each `{view}_inverse` function, so you get a nice compiler error when someone eventually adds a new view op.

The only view ops currently not supported are (a) as_strided, and (b) the sparse view ops (values()/indices()).

I can add support for as_strided, but it needs an `as_strided_inverse()` function. That will look really similar to the `as_strided_backward()` function in FunctionsManual.cpp, but it has some noticeable differences: we basically want an `as_strided_embed` for autograd and `as_strided_scatter` for functionalization. We also will probably need them to be primitives w.r.t to autograd, since the currently implementation for autograd uses view().copy_() calls that XLA won't be able to handle. I'm wondering if anyone has any objections, but otherwise I can make those change (which will require writing backward formulas for `as_strided_embed` and `as_strided_scatter`).

I did a bunch of manual testing that all looks pretty good, but it's definitely not fully tested. Ed pointed out that once XLA uses this pass (or at least once there's a POC), we can just run the existing xla view test suite. Hopefully that delay is okay - if it's not, maybe we can think about using OpInfos similar to how functorch uses them for testing.

Note: there's some duplication with autograd's view code. Every `{view}_inverse` implementation is really similar to the implementation for that view listed in `derivatives.yaml`. There are some major differences though:
* The autograd implementations of those backward functions (like `permute_backwards()`, in `FunctionsManual.cpp`) internally call other view ops. For functionalization, we want them to (eventually) call `{view}_copy` operators.
* For view ops that take a subset of the original storage, like `slice/select/diagonal/as_strided()`, the autograd backward functions fill the "spaces" in the inverse call with zeroes. For functionalization, we want to fill them with the value of `base` at those positions. It looks like this currently applies to 6 total ops (since we can ignore composites):
  * select
  * slice
  * diagonal
  * as_strided
  * split
  * split_with_sizes
A nice end state would probably be for the autograd + functionalization codegen to both look at the same yaml (either `derivatives.yaml`, or something else), and automatically generate the right thing. I left that out of scope for this PR though.



**Current State + Next Steps**

There are a bunch of followups after this PR eventually lands. Roughly in order:
* Use the current pass to register problematic composite ops in functorch. Also, nested `functionalize()` calls aren't supported yet (I mostly just need to remove some debug asserts and test it).
* Work on freeing up dispatch key space by deduplicating the `{backend}`/`Autograd{backend}`/`Sparse{backend}`/`Quantized{backend}` keys
* Once we have more dispatch keys, split up this pass into 3 pieces - it's currently fused, and doesn't do the right thing for vulkan/XLA. Specifically, all of the `{view}` calls in the current pass's view-replay logic should turn into `{view}_copy` calls that vulkan/XLA know how to implement, and there will be separate passes for (a) removing mutations, and (b) turning `{view}_copy` calls back into `{view}` calls. For Vulkan, we eventually want a pass that ONLY removes aliasing and view calls, and doesn't remove mutations. We can also probably make the 2 new passes user dispatch keys to save dispatch key space, if they'll only be used by functorch anyway.
* Do more of a dive on perf for the vulkan/xla use cases. There are several areas to improve perf with varying levels of effort required. The simplest one that I'll probably do regardless is to codegen the out-of-place kernels instead of using a boxed fallback. Getting a POC working for xla will also be useful to test the view operator coverage.


**Example Codegen Output**

View Op:
```
::std::vector<at::Tensor> split_Tensor(c10::DispatchKeySet ks, const at::Tensor & self, int64_t split_size, int64_t dim) {

      auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
      ::std::vector<at::Tensor> out;
      {
        at::AutoDispatchBelowFunctionalize guard;
        auto tmp_output = at::redispatch::split(ks & c10::after_func_keyset, self_, split_size, dim);
        out = at::functionalization::impl::wrapFunctionalTensor(tmp_output);
        // I'm fusing the [alias removal], [mutation removal], [add views back] passes together.
        // Later, we'll want to turn them into separate passes (since e.g. vulkan only cares about alias removal).
      }

      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        [split_size, dim](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
          return base.split(split_size, dim)[mutated_view_idx];
        },
        [split_size, dim](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
          return at::functionalization::impl::split_inverse(base, mutated_view, mutated_view_idx, split_size, dim);
        }
      );
      at::functionalization::impl::set_view_meta(out, self, view_meta);

      at::AutoDispatchDirectlyToNative native_guard;
      ::std::vector<at::Tensor> reference_tensor_output = at::native::split(self, split_size, dim);
      at::functionalization::impl::set_strides(out, reference_tensor_output);
      return out;

}
```

Mutation Op:
```
at::Tensor & add__Tensor(c10::DispatchKeySet ks, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {

      at::functionalization::impl::sync(self);
      at::functionalization::impl::sync(other);
      auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
      auto other_ = at::functionalization::impl::unwrapFunctionalTensor(other);
      at::Tensor tmp_output;
      {
          at::AutoDispatchBelowFunctionalize guard;
          // The functionalization pass explicitly doesn't pass out= parameters to the redispatch
          tmp_output = at::redispatch::add(
            ks & c10::after_func_keyset, self_, other_, alpha);
      }

      self.replace_(tmp_output);
      at::functionalization::impl::maybe_add_update(self);
      return self;
}
```

View + Mutation Op:
```
at::Tensor & transpose_(c10::DispatchKeySet ks, at::Tensor & self, int64_t dim0, int64_t dim1) {


      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        [dim0, dim1](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
          return base.transpose(dim0, dim1);
        },
        [dim0, dim1](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
          return at::functionalization::impl::transpose_inverse(base, mutated_view, dim0, dim1);
        }
      );
      at::functionalization::impl::mutate_view_meta(self, view_meta);
      // See  Note [Propagating strides in the functionalization pass]
      // Directly update the sizes/strides/storage_offset fields on self using the inplace call.
      // I need the guard because I don't want the at::native kernel to end up calling more functionalization/functorch kernels.
      // Its only job is to directly compute the output size/stride/storage_offset metadata.
      at::AutoDispatchDirectlyToNative native_guard;
      at::native::transpose_(self, dim0, dim1);
      return self;

}
```




[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Sep 14, 2021
Original PR description + feedback here: #63048

I've addressed all of the feedback in the original PR and made some pretty large changes, listed below.


**Table of Contents**
- Starting points
- List of the main changes from the original PR
- Next Steps
- Example codegen output (for a view, mutation, and view+mutation op)


**Starting Points**

A good place to start when looking through the PR:
* Alban mentioned that this is a useful mental model (thanks Ed for originally making this clear to me). Semantically, the pass currently does THREE things, which are all needed by functorch - all fused together into one big pass.
  * (a) alias removal, which replaces {view} calls with {view}_copy calls, and manually tracks aliasing information, so that when one tensor is mutated, we re-apply the same mutation to all of the aliases. This is the bulk of the work - once this is done, the next 2 things are trivial to implement.
  * (b) mutation removal, which is easy to do once we know that there are no aliases. Every mutation `a.add_(b)` becomes `a.replace_(a.add(b))`
  * (c) reapplying views: all of the `{view}_copy` calls are replaced with `{view}` calls again. This is an optimization that we can make specifically for functorch (and strided backends), which only care about mutation removal and not alias removal.
  * XLA and Vulkan only want (a), or (a) + (b). Later, we'll want to split this out so that you can actually opt into different versions of this logic.
  * There is currently no {view}_copy replacement, because the <replace views with copies> and <replace copies with views> steps have been combined in the pass. Later, we'll want to actually implement {view}_copy variants of each view operator, probably with codegen.
* documentation breadcrumb 1, in `FunctionalTensorWrapper.cpp`: https://github.com/pytorch/pytorch/pull/64432/files#diff-a0bac99bf205dba5b94cb64fc2466d3d55d991887572f9cd6a02e27b3a91dd60R59 (you might have to expand the `FunctionalTensorWrapper.cpp` file, which GitHub closes by default because it's large)
* documentation breadcrumb 2, in `FunctionalTensorWrapper.h`: https://github.com/pytorch/pytorch/pull/64432/files#diff-c945c71a4ccac65871f24a912e8904f9a5088b24a32e636727ea9c8fe920708aR12
* Reading through the codegen output at the bottom of this description.
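
To make (b) concrete, here's a minimal pure-Python sketch of the `a.add_(b)` -> `a.replace_(a.add(b))` rewrite. It's a toy model under stated assumptions, not the real implementation: `ToyFunctionalWrapper` is a hypothetical stand-in for the wrapper class, and tuples stand in for tensors.

```python
class ToyFunctionalWrapper:
    """Hypothetical stand-in for a functional wrapper around an immutable value."""

    def __init__(self, value):
        self.value = tuple(value)

    def add(self, other):
        # Out-of-place op: builds a fresh value and never touches self.
        return tuple(x + y for x, y in zip(self.value, other.value))

    def replace_(self, new_value):
        # The only "mutation" left: swap which value the wrapper points to.
        self.value = tuple(new_value)
        return self

    def add_(self, other):
        # Mutation removal: a.add_(b) becomes a.replace_(a.add(b)).
        return self.replace_(self.add(other))


a = ToyFunctionalWrapper([1, 2, 3])
b = ToyFunctionalWrapper([10, 20, 30])
old_value = a.value  # the original value object is never modified
a.add_(b)
```

From the caller's side `a` still looks mutated, but the only ops that ran underneath were out-of-place ones.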


**Main changes from the original PR**

(1)  I use lambdas instead of a giant enum to handle all of the different views.

This results in less boilerplate per view op (and more stuff that can be codegen'd). Every `ViewMeta` object now contains a `forward` and a `reverse` lambda, which know how to replay the view and its inverse. This makes the actual code that executes the replaying logic a lot less boilerplate-y (see `Alias::sync_update_operations` and `FunctionalTensorWrapper::sync_`).
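
As a rough sketch of the idea (in Python rather than C++, with lists standing in for tensors), a forward/reverse lambda pair for a slice view could look like this. `ToyViewMeta` and `make_slice_meta` are hypothetical names for illustration only; the real `ViewMeta` lives in the C++ code in this PR.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class ToyViewMeta:
    # forward(base) -> view: replays the view on a (possibly updated) base.
    forward: Callable[[List[int]], List[int]]
    # reverse(base, mutated_view) -> new base: writes a mutated view back into the base.
    reverse: Callable[[List[int], List[int]], List[int]]


def make_slice_meta(start: int, end: int) -> ToyViewMeta:
    # The lambdas capture the view's arguments, like the captures in the codegen output.
    return ToyViewMeta(
        forward=lambda base: base[start:end],
        reverse=lambda base, view: base[:start] + view + base[end:],
    )


base = [0, 1, 2, 3, 4, 5]
meta = make_slice_meta(2, 4)

view = meta.forward(base)
mutated_view = [v + 100 for v in view]       # out-of-place stand-in for view.add_(100)
new_base = meta.reverse(base, mutated_view)  # replay the mutation onto the base
replayed_view = meta.forward(new_base)       # regenerate the view from the updated base
```

The replay code only ever calls `forward` and `reverse`, so it doesn't need to know which view op produced the alias.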

(2) Every tensor during the functionalization pass is always wrapped in a `FunctionalTensorWrapper`.

This is potentially unnecessary for Vulkan/XLA, and will have a mild perf impact, but for now this PR just targets the functorch use case. I previously had a complicated design (a `FunctionalTensorImplBase` class) to avoid needing the wrapper for XLA, but it had some subtleties that are gonna require more thought to fix, so I'm pushing that off for now.

(3) `FunctionalTensorWrapper` objects accurately report stride information.

It's a little annoying to do this though, because the logic that calculates stride info for each view isn't easily separated from the actual view kernels in core, `at::native::{view}`. I do this by adding logic in each `at::functionalization::{view}` kernel to call the reference implementation `at::native::{view}`. I don't do anything with the output aside from taking its size/stride/storage_offset to set the actual output tensor's size/stride/storage_offset correctly. There's another annoying part to this: I'm pretty sure that we want to pass in the actual *wrapper* tensors directly into the native kernels, not their inner unwrapped values. But there are some `at::native::{view}` kernels that call other tensor methods, which re-invokes the dispatcher, calling functionalization/functorch kernels that try to do the unwrapping.

To work around this, right now I have an `AutoDispatchDirectlyToNative` guard that basically ensures that any tensor methods called inside of the `at::native::{view}` op always redispatch straight to the CPU kernel (which will be another `at::native::` kernel). This feels kind of heavy-handed, but I'm not sure of a better way to do it.
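
For intuition about what "size/stride metadata" the reference kernel is being used for, here's a small pure-Python sketch of the metadata change that a transpose implies. This is only an illustration with hypothetical helper names; the PR itself gets these values by calling the `at::native::{view}` kernel, not by reimplementing the arithmetic.

```python
def contiguous_strides(sizes):
    """Row-major strides for a contiguous tensor with the given sizes."""
    strides = []
    running = 1
    for size in reversed(sizes):
        strides.append(running)
        running *= size
    return tuple(reversed(strides))


def toy_transpose_metadata(sizes, strides, dim0, dim1):
    # A transpose never moves data: it only swaps two entries of sizes and strides.
    sizes, strides = list(sizes), list(strides)
    sizes[dim0], sizes[dim1] = sizes[dim1], sizes[dim0]
    strides[dim0], strides[dim1] = strides[dim1], strides[dim0]
    return tuple(sizes), tuple(strides)


sizes = (2, 3, 4)
strides = contiguous_strides(sizes)
new_sizes, new_strides = toy_transpose_metadata(sizes, strides, 0, 2)
```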


(4) `FunctionalTensorWrapper` objects accurately report aliasing information.

There's a new `FunctionalStorageImpl` class (subclass of `StorageImpl`) that allows tensors in the functionalization pass to accurately alias storage. If two tensors `a` and `b` in a functionalized program are views of one another, then `a.storage.is_alias_of(b.storage)` should return true. I added this in a pretty similar way to how meta tensors allocate storage, although I don't pass in an actual allocator (I think this is fine because you should never resize a functional tensor's storage). 

One thing I'm not sure about - should `FunctionalTensorWrapper` set `storage_access_should_throw_`: (a) always, (b) never, (c) only if its wrapped tensor has it set.

Right now I have it not set, mostly because calling the reference view functions (`at::native::{view}`) requires looking at the storage. But that means that if you try to access storage from python in a functionalized program, you'll get silent garbage instead of an error. Related question: are we planning on exposing meta tensor storage to python in the future (even though it contains garbage)?

(5) better docs :)

**View operator coverage**

(6) The functionalization pass now gets math-composite view ops for free.

I didn't add the `Functionalize` dispatch key to the composite set, because I don't want composite ops like `torch.ones` to get decomposed before hitting the functionalization pass. Instead, I added codegen to manually register the `at::native::` kernels of composite view ops. This is a little hairy, because the names of the `at::native::` kernels aren't easily accessible. They're stored in a `Dict[DispatchKey, BackendIndex]`. I made a best-effort attempt to get each view kernel's name, basically by assuming that every view op has either a composite or cpu implementation.
There's also a hardcoded list of composite view ops in `gen_inplace_or_view_type.py`, but it looks like it's wrong. This is probably worth rationalizing later, but instead I created a new list of the "complete" set of composite view ops, and preserved the old set by hardcoding the delta between the two sets.

(7) I've added codegen for ops that are both views AND mutations, like `transpose_()` (why do we even have these 😢).

From some light testing, it looks like they work correctly with one caveat: I had a hard time ensuring that functorch programs that mutate their inputs using ops like `transpose_()` preserve the input mutations after the program finishes running. For now (in my corresponding functorch branch), I emit a warning when this happens and just don't preserve the mutation.

(8) I added `{view}_inverse` implementations for every view op, in `FunctionalInverses.cpp`.

These are needed to take mutations made to views and replay them back onto the base. To reduce boilerplate, the codegen generates function declarations for each `{view}_inverse` function, so you get a nice compiler error when someone eventually adds a new view op.

The only view ops currently not supported are (a) as_strided, and (b) the sparse view ops (values()/indices()).

I can add support for as_strided, but it needs an `as_strided_inverse()` function. That will look really similar to the `as_strided_backward()` function in FunctionsManual.cpp, but it has some noticeable differences: we basically want an `as_strided_embed` for autograd and `as_strided_scatter` for functionalization. We will also probably need them to be primitives w.r.t. autograd, since the current implementation for autograd uses view().copy_() calls that XLA won't be able to handle. I'm wondering if anyone has any objections, but otherwise I can make those changes (which will require writing backward formulas for `as_strided_embed` and `as_strided_scatter`).
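
To illustrate the difference between the two, here's a toy pure-Python sketch over flat lists. The function names mirror the proposed ops but are hypothetical, the signatures are my assumption, and the sketch only handles 2-D views with non-overlapping elements.

```python
def as_strided_indices(sizes, strides, storage_offset):
    """Flat storage index of every element of a 2-D strided view, in row-major order."""
    rows, cols = sizes
    row_stride, col_stride = strides
    return [
        storage_offset + i * row_stride + j * col_stride
        for i in range(rows)
        for j in range(cols)
    ]


def toy_as_strided_scatter(base, view_values, sizes, strides, storage_offset=0):
    # Functionalization flavor: start from a copy of base, so positions the
    # view doesn't cover keep base's values.
    out = list(base)
    for idx, value in zip(as_strided_indices(sizes, strides, storage_offset), view_values):
        out[idx] = value
    return out


def toy_as_strided_embed(numel, view_values, sizes, strides, storage_offset=0):
    # Autograd flavor: positions the view doesn't cover are filled with zeros.
    return toy_as_strided_scatter([0] * numel, view_values, sizes, strides, storage_offset)


base = [10, 11, 12, 13, 14, 15, 16, 17, 18]  # flat storage of a 3x3 tensor
sizes, strides = (2, 2), (4, 1)              # view covers storage indices 0, 1, 4, 5
mutated_view = [-1, -2, -3, -4]

scattered = toy_as_strided_scatter(base, mutated_view, sizes, strides)
embedded = toy_as_strided_embed(len(base), mutated_view, sizes, strides)
```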

I did a bunch of manual testing that all looks pretty good, but it's definitely not fully tested. Ed pointed out that once XLA uses this pass (or at least once there's a POC), we can just run the existing xla view test suite. Hopefully that delay is okay - if it's not, maybe we can think about using OpInfos similar to how functorch uses them for testing.

Note: there's some duplication with autograd's view code. Every `{view}_inverse` implementation is really similar to the implementation for that view listed in `derivatives.yaml`. There are some major differences though:
* The autograd implementations of those backward functions (like `permute_backwards()`, in `FunctionsManual.cpp`) internally call other view ops. For functionalization, we want them to (eventually) call `{view}_copy` operators.
* For view ops that take a subset of the original storage, like `slice/select/diagonal/as_strided()`, the autograd backward functions fill the "spaces" in the inverse call with zeroes. For functionalization, we want to fill them with the value of `base` at those positions. It looks like this currently applies to 6 total ops (since we can ignore composites):
  * select
  * slice
  * diagonal
  * as_strided
  * split
  * split_with_sizes
A nice end state would probably be for the autograd + functionalization codegen to both look at the same yaml (either `derivatives.yaml`, or something else), and automatically generate the right thing. I left that out of scope for this PR though.



**Current State + Next Steps**

There are a bunch of followups after this PR eventually lands. Roughly in order:
* Use the current pass to register problematic composite ops in functorch. Also, nested `functionalize()` calls aren't supported yet (I mostly just need to remove some debug asserts and test it).
* Work on freeing up dispatch key space by deduplicating the `{backend}`/`Autograd{backend}`/`Sparse{backend}`/`Quantized{backend}` keys
* Once we have more dispatch keys, split up this pass into 3 pieces - it's currently fused, and doesn't do the right thing for vulkan/XLA. Specifically, all of the `{view}` calls in the current pass's view-replay logic should turn into `{view}_copy` calls that vulkan/XLA know how to implement, and there will be separate passes for (a) removing mutations, and (b) turning `{view}_copy` calls back into `{view}` calls. For Vulkan, we eventually want a pass that ONLY removes aliasing and view calls, and doesn't remove mutations. We can also probably make the 2 new passes user dispatch keys to save dispatch key space, if they'll only be used by functorch anyway.
* Do more of a dive on perf for the vulkan/xla use cases. There are several areas to improve perf with varying levels of effort required. The simplest one that I'll probably do regardless is to codegen the out-of-place kernels instead of using a boxed fallback. Getting a POC working for xla will also be useful to test the view operator coverage.


**Example Codegen Output**

View Op:
```
::std::vector<at::Tensor> split_Tensor(c10::DispatchKeySet ks, const at::Tensor & self, int64_t split_size, int64_t dim) {

      auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
      ::std::vector<at::Tensor> out;
      {
        at::AutoDispatchBelowFunctionalize guard;
        auto tmp_output = at::redispatch::split(ks & c10::after_func_keyset, self_, split_size, dim);
        out = at::functionalization::impl::wrapFunctionalTensor(tmp_output);
        // I'm fusing the [alias removal], [mutation removal], [add views back] passes together.
        // Later, we'll want to turn them into separate passes (since e.g. vulkan only cares about alias removal).
      }

      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        [split_size, dim](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
          return base.split(split_size, dim)[mutated_view_idx];
        },
        [split_size, dim](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
          return at::functionalization::impl::split_inverse(base, mutated_view, mutated_view_idx, split_size, dim);
        }
      );
      at::functionalization::impl::set_view_meta(out, self, view_meta);

      at::AutoDispatchDirectlyToNative native_guard;
      ::std::vector<at::Tensor> reference_tensor_output = at::native::split(self, split_size, dim);
      at::functionalization::impl::set_strides(out, reference_tensor_output);
      return out;

}
```

Mutation Op:
```
at::Tensor & add__Tensor(c10::DispatchKeySet ks, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {

      at::functionalization::impl::sync(self);
      at::functionalization::impl::sync(other);
      auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
      auto other_ = at::functionalization::impl::unwrapFunctionalTensor(other);
      at::Tensor tmp_output;
      {
          at::AutoDispatchBelowFunctionalize guard;
          // The functionalization pass explicitly doesn't pass out= parameters to the redispatch
          tmp_output = at::redispatch::add(
            ks & c10::after_func_keyset, self_, other_, alpha);
      }

      self.replace_(tmp_output);
      at::functionalization::impl::maybe_add_update(self);
      return self;
}
```

View + Mutation Op:
```
at::Tensor & transpose_(c10::DispatchKeySet ks, at::Tensor & self, int64_t dim0, int64_t dim1) {


      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        [dim0, dim1](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
          return base.transpose(dim0, dim1);
        },
        [dim0, dim1](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
          return at::functionalization::impl::transpose_inverse(base, mutated_view, dim0, dim1);
        }
      );
      at::functionalization::impl::mutate_view_meta(self, view_meta);
      // See  Note [Propagating strides in the functionalization pass]
      // Directly update the sizes/strides/storage_offset fields on self using the inplace call.
      // I need the guard because I don't want the at::native kernel to end up calling more functionalization/functorch kernels.
      // Its only job is to directly compute the output size/stride/storage_offset metadata.
      at::AutoDispatchDirectlyToNative native_guard;
      at::native::transpose_(self, dim0, dim1);
      return self;

}
```




[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Sep 15, 2021
Original PR description + feedback here: #63048

I've addressed all of the feedback in the original PR and made some pretty large changes, listed below.


**Table of Contents**
- Starting points
- List of the main changes from the original PR
- Next Steps
- Example codegen output (for a view, mutation, and view+mutation op)


**Starting Points**

A good place to start when looking through the PR:
* Alban mentioned that this is a useful mental model (thanks Ed for originally making this clear to me). Semantically, the pass currently does THREE things, which are all needed by functorch - all fused together into one big pass.
  * (a) alias removal, which replaces {view} calls with {view}_copy calls, and manually tracks aliasing information, so that when one tensor is mutated, we re-apply the same mutation to all of the aliases. This is the bulk of the work - once this is done, the next 2 things are trivial to implement.
  * (b) mutation removal, which is easy to do once we know that there are no aliases. Every mutation `a.add_(b)` becomes `a.replace_(a.add(b))`
  * (c) reapplying views: all of the `{view}_copy` calls are replaced with `{view}` calls again. This is an optimization that we can make specifically for functorch (and strided backends), which only care about mutation removal and not alias removal.
  * XLA and Vulkan only want (a), or (a) + (b). Later, we'll want to split this out so that you can actually opt into different versions of this logic.
  * There is currently no {view}_copy replacement, because the <replace views with copies> and <replace copies with views> steps have been combined in the pass. Later, we'll want to actually implement {view}_copy variants of each view operator, probably with codegen.
* documentation breadcrumb 1, in `FunctionalTensorWrapper.cpp`: https://github.com/pytorch/pytorch/pull/64432/files#diff-a0bac99bf205dba5b94cb64fc2466d3d55d991887572f9cd6a02e27b3a91dd60R59 (you might have to expand the `FunctionalTensorWrapper.cpp` file, which GitHub closes by default because it's large)
* documentation breadcrumb 2, in `FunctionalTensorWrapper.h`: https://github.com/pytorch/pytorch/pull/64432/files#diff-c945c71a4ccac65871f24a912e8904f9a5088b24a32e636727ea9c8fe920708aR12
* Reading through the codegen output at the bottom of this description.


**Main changes from the original PR**

(1)  I use lambdas instead of a giant enum to handle all of the different views.

This results in less boilerplate per view op (and more stuff that can be codegen'd). Every `ViewMeta` object now contains a `forward` and a `reverse` lambda, which know how to replay the view and its inverse. This makes the actual code that executes the replaying logic a lot less boilerplate-y (see `Alias::sync_update_operations` and `FunctionalTensorWrapper::sync_`).

(2) Every tensor during the functionalization pass is always wrapped in a `FunctionalTensorWrapper`.

This is potentially unnecessary for Vulkan/XLA, and will have a mild perf impact, but for now this PR just targets the functorch use case. I previously had a complicated design (a `FunctionalTensorImplBase` class) to avoid needing the wrapper for XLA, but it had some subtleties that are gonna require more thought to fix, so I'm pushing that off for now.

(3) `FunctionalTensorWrapper` objects accurately report stride information.

It's a little annoying to do this though, because the logic that calculates stride info for each view isn't easily separated from the actual view kernels in core, `at::native::{view}`. I do this by adding logic in each `at::functionalization::{view}` kernel to call the reference implementation `at::native::{view}`. I don't do anything with the output aside from taking its size/stride/storage_offset to set the actual output tensor's size/stride/storage_offset correctly. There's another annoying part to this: I'm pretty sure that we want to pass in the actual *wrapper* tensors directly into the native kernels, not their inner unwrapped values. But there are some `at::native::{view}` kernels that call other tensor methods, which re-invokes the dispatcher, calling functionalization/functorch kernels that try to do the unwrapping.

To work around this, right now I have an `AutoDispatchDirectlyToNative` guard that basically ensures that any tensor methods called inside of the `at::native::{view}` op always redispatch straight to the CPU kernel (which will be another `at::native::` kernel). This feels kind of heavy-handed, but I'm not sure of a better way to do it.


(4) `FunctionalTensorWrapper` objects accurately report aliasing information.

There's a new `FunctionalStorageImpl` class (subclass of `StorageImpl`) that allows tensors in the functionalization pass to accurately alias storage. If two tensors `a` and `b` in a functionalized program are views of one another, then `a.storage.is_alias_of(b.storage)` should return true. I added this in a pretty similar way to how meta tensors allocate storage, although I don't pass in an actual allocator (I think this is fine because you should never resize a functional tensor's storage). 

One thing I'm not sure about - should `FunctionalTensorWrapper` set `storage_access_should_throw_`: (a) always, (b) never, (c) only if its wrapped tensor has it set.

Right now I have it not set, mostly because calling the reference view functions (`at::native::{view}`) requires looking at the storage. But that means that if you try to access storage from python in a functionalized program, you'll get silent garbage instead of an error. Related question: are we planning on exposing meta tensor storage to python in the future (even though it contains garbage)?

(5) better docs :)

**View operator coverage**

(6) The functionalization pass now gets math-composite view ops for free.

I didn't add the `Functionalize` dispatch key to the composite set, because I don't want composite ops like `torch.ones` to get decomposed before hitting the functionalization pass. Instead, I added codegen to manually register the `at::native::` kernels of composite view ops. This is a little hairy, because the names of the `at::native::` kernels aren't easily accessible. They're stored in a `Dict[DispatchKey, BackendIndex]`. I made a best-effort attempt to get each view kernel's name, basically by assuming that every view op has either a composite or cpu implementation.
There's also a hardcoded list of composite view ops in `gen_inplace_or_view_type.py`, but it looks like it's wrong. This is probably worth rationalizing later, but instead I created a new list of the "complete" set of composite view ops, and preserved the old set by hardcoding the delta between the two sets.

(7) I've added codegen for ops that are both views AND mutations, like `transpose_()` (why do we even have these 😢).

From some light testing, it looks like they work correctly with one caveat: I had a hard time ensuring that functorch programs that mutate their inputs using ops like `transpose_()` preserve the input mutations after the program finishes running. For now (in my corresponding functorch branch), I emit a warning when this happens and just don't preserve the mutation.

(8) I added `{view}_inverse` implementations for every view op, in `FunctionalInverses.cpp`.

These are needed to take mutations made to views and replay them back onto the base. To reduce boilerplate, the codegen generates function declarations for each `{view}_inverse` function, so you get a nice compiler error when someone eventually adds a new view op.

The only view ops currently not supported are (a) as_strided, and (b) the sparse view ops (values()/indices()).

I can add support for as_strided, but it needs an `as_strided_inverse()` function. That will look really similar to the `as_strided_backward()` function in FunctionsManual.cpp, but it has some noticeable differences: we basically want an `as_strided_embed` for autograd and `as_strided_scatter` for functionalization. We will also probably need them to be primitives w.r.t. autograd, since the current implementation for autograd uses view().copy_() calls that XLA won't be able to handle. I'm wondering if anyone has any objections, but otherwise I can make those changes (which will require writing backward formulas for `as_strided_embed` and `as_strided_scatter`).

I did a bunch of manual testing that all looks pretty good, but it's definitely not fully tested. Ed pointed out that once XLA uses this pass (or at least once there's a POC), we can just run the existing xla view test suite. Hopefully that delay is okay - if it's not, maybe we can think about using OpInfos similar to how functorch uses them for testing.

Note: there's some duplication with autograd's view code. Every `{view}_inverse` implementation is really similar to the implementation for that view listed in `derivatives.yaml`. There are some major differences though:
* The autograd implementations of those backward functions (like `permute_backwards()`, in `FunctionsManual.cpp`) internally call other view ops. For functionalization, we want them to (eventually) call `{view}_copy` operators.
* For view ops that take a subset of the original storage, like `slice/select/diagonal/as_strided()`, the autograd backward functions fill the "spaces" in the inverse call with zeroes. For functionalization, we want to fill them with the value of `base` at those positions. It looks like this currently applies to 6 total ops (since we can ignore composites):
  * select
  * slice
  * diagonal
  * as_strided
  * split
  * split_with_sizes
A nice end state would probably be for the autograd + functionalization codegen to both look at the same yaml (either `derivatives.yaml`, or something else), and automatically generate the right thing. I left that out of scope for this PR though.



**Current State + Next Steps**

There are a bunch of followups after this PR eventually lands. Roughly in order:
* Use the current pass to register problematic composite ops in functorch. Also, nested `functionalize()` calls aren't supported yet (I mostly just need to remove some debug asserts and test it).
* Work on freeing up dispatch key space by deduplicating the `{backend}`/`Autograd{backend}`/`Sparse{backend}`/`Quantized{backend}` keys
* Once we have more dispatch keys, split up this pass into 3 pieces - it's currently fused, and doesn't do the right thing for vulkan/XLA. Specifically, all of the `{view}` calls in the current pass's view-replay logic should turn into `{view}_copy` calls that vulkan/XLA know how to implement, and there will be separate passes for (a) removing mutations, and (b) turning `{view}_copy` calls back into `{view}` calls. For Vulkan, we eventually want a pass that ONLY removes aliasing and view calls, and doesn't remove mutations. We can also probably make the 2 new passes user dispatch keys to save dispatch key space, if they'll only be used by functorch anyway.
* Do more of a dive on perf for the vulkan/xla use cases. There are several areas to improve perf with varying levels of effort required. The simplest one that I'll probably do regardless is to codegen the out-of-place kernels instead of using a boxed fallback. Getting a POC working for xla will also be useful to test the view operator coverage.


**Example Codegen Output**

View Op:
```
::std::vector<at::Tensor> split_Tensor(c10::DispatchKeySet ks, const at::Tensor & self, int64_t split_size, int64_t dim) {

      auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
      ::std::vector<at::Tensor> out;
      {
        at::AutoDispatchBelowFunctionalize guard;
        auto tmp_output = at::redispatch::split(ks & c10::after_func_keyset, self_, split_size, dim);
        out = at::functionalization::impl::wrapFunctionalTensor(tmp_output);
        // I'm fusing the [alias removal], [mutation removal], [add views back] passes together.
        // Later, we'll want to turn them into separate passes (since e.g. vulkan only cares about alias removal).
      }

      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        [split_size, dim](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
          return base.split(split_size, dim)[mutated_view_idx];
        },
        [split_size, dim](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
          return at::functionalization::impl::split_inverse(base, mutated_view, mutated_view_idx, split_size, dim);
        }
      );
      at::functionalization::impl::set_view_meta(out, self, view_meta);

      at::AutoDispatchDirectlyToNative native_guard;
      ::std::vector<at::Tensor> reference_tensor_output = at::native::split(self, split_size, dim);
      at::functionalization::impl::set_strides(out, reference_tensor_output);
      return out;

}
```

Mutation Op:
```
at::Tensor & add__Tensor(c10::DispatchKeySet ks, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {

      at::functionalization::impl::sync(self);
      at::functionalization::impl::sync(other);
      auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
      auto other_ = at::functionalization::impl::unwrapFunctionalTensor(other);
      at::Tensor tmp_output;
      {
          at::AutoDispatchBelowFunctionalize guard;
          // The functionalization pass explicitly doesn't pass out= parameters to the redispatch
          tmp_output = at::redispatch::add(
            ks & c10::after_func_keyset, self_, other_, alpha);
      }

      self.replace_(tmp_output);
      at::functionalization::impl::maybe_add_update(self);
      return self;
}
```

View + Mutation Op:
```
at::Tensor & transpose_(c10::DispatchKeySet ks, at::Tensor & self, int64_t dim0, int64_t dim1) {


      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        [dim0, dim1](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
          return base.transpose(dim0, dim1);
        },
        [dim0, dim1](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
          return at::functionalization::impl::transpose_inverse(base, mutated_view, dim0, dim1);
        }
      );
      at::functionalization::impl::mutate_view_meta(self, view_meta);
      // See  Note [Propagating strides in the functionalization pass]
      // Directly update the sizes/strides/storage_offset fields on self using the inplace call.
      // I need the guard because I don't want the at::native kernel to end up calling more functionalization/functorch kernels.
      // Its only job is to directly compute the output size/stride/storage_offset metadata.
      at::AutoDispatchDirectlyToNative native_guard;
      at::native::transpose_(self, dim0, dim1);
      return self;

}
```




[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Sep 21, 2021
Original PR description + feedback here: #63048

I've addressed all of the feedback in the original PR and made some pretty large changes, listed below.


**Table of Contents**
- Starting points
- List of the main changes from the original PR
- Next Steps
- Example codegen output (for a view, mutation, and view+mutation op)


**Starting Points**

A good place to start when looking through the PR:
* Alban mentioned that this is a useful mental model (thanks Ed for originally making this clear to me). Semantically, the pass currently does THREE things, which are all needed by functorch - all fused together into one big pass.
  * (a) alias removal, which replaces {view} calls with {view}_copy calls, and manually tracks aliasing information, so that when one tensor is mutated, we re-apply the same mutation to all of the aliases. This is the bulk of the work - once this is done, the next 2 things are trivial to implement.
  * (b) mutation removal, which is easy to do once we know that there are no aliases. Every mutation `a.add_(b)` becomes `a.replace_(a.add(b))`
  * (c) reapplying views: all of the `{view}_copy` calls are replaced with `{view}` calls again. This is an optimization that we can make specifically for functorch (and strided backends), which only care about mutation removal and not alias removal.
  * XLA and Vulkan only want (a), or (a) + (b). Later, we'll want to split this out so that you can actually opt into different versions of this logic.
  * There is currently no {view}_copy replacement, because the <replace views with copies> and <replace copies with views> steps have been combined. Later, we'll want to actually implement {view}_copy variants of each view operator, probably with codegen.
* documentation breadcrumb 1, in `FunctionalTensorWrapper.cpp`: https://github.com/pytorch/pytorch/pull/64432/files#diff-a0bac99bf205dba5b94cb64fc2466d3d55d991887572f9cd6a02e27b3a91dd60R59 (you might have to expand the `FunctionalTensorWrapper.cpp` file, which GitHub closes by default because it's large)
* documentation breadcrumb 2, in `FunctionalTensorWrapper.h`: https://github.com/pytorch/pytorch/pull/64432/files#diff-c945c71a4ccac65871f24a912e8904f9a5088b24a32e636727ea9c8fe920708aR12
* Reading through the codegen output at the bottom of this description.
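
To make step (b) above concrete, here is a minimal pure-Python sketch of the `a.add_(b)` to `a.replace_(a.add(b))` rewrite. `FunctionalBox` is a hypothetical stand-in for a functional wrapper, not the real `FunctionalTensorWrapper`, and it models values as immutable tuples purely for illustration.

```python
class FunctionalBox:
    """Hypothetical stand-in for a functional tensor wrapper (illustration only)."""

    def __init__(self, value):
        # The wrapped value is immutable, so no op can mutate it in place.
        self.value = tuple(value)

    def add(self, other):
        # Out-of-place op: builds a new tuple and leaves self.value untouched.
        return tuple(x + y for x, y in zip(self.value, other.value))

    def replace_(self, new_value):
        # The only "mutation" left: swap which immutable value the box points to.
        self.value = tuple(new_value)
        return self

    def add_(self, other):
        # Mutation removal: a.add_(b) becomes a.replace_(a.add(b)).
        return self.replace_(self.add(other))


a = FunctionalBox([1, 2, 3])
b = FunctionalBox([10, 20, 30])
old_value = a.value          # keep a handle on the value from before the "mutation"
result = a.add_(b)
```

After `a.add_(b)`, the box observes the new value while the old tuple is unchanged, which is the property a backend without in-place ops needs.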


**Main changes from the original PR**

(1)  I use lambdas instead of a giant enum to handle all of the different views.

This results in less boilerplate per view op (and more stuff that can be codegen'd). Every `ViewMeta` object now contains a `forward` and `reverse` lambda, that knows how to replay the view and its inverse. This makes the actual code that executes the replaying logic a lot less boilerplate-y (see `Alias::sync_update_operations` and `FunctionalTensorWrapper::sync_`)
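
As a rough illustration of the forward/reverse lambda idea, here is a pure-Python sketch that replays a chain of views and then pushes a mutation back onto the base. The `ViewMeta` dataclass, `slice_meta`, `flip_meta`, and the replay helpers are hypothetical names for this sketch and only mirror the shape of the C++ `ViewMeta`; tensors are modeled as flat lists.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class ViewMeta:
    forward: Callable[[list], list]        # base -> view
    reverse: Callable[[list, list], list]  # (base, mutated_view) -> updated base


def slice_meta(start: int, end: int) -> ViewMeta:
    def reverse(base, mutated_view):
        out = list(base)
        out[start:end] = mutated_view  # write the view back, keep the rest of base
        return out
    return ViewMeta(lambda base: base[start:end], reverse)


def flip_meta() -> ViewMeta:
    # Flipping is its own inverse and covers the whole base.
    return ViewMeta(lambda base: base[::-1],
                    lambda base, mutated_view: mutated_view[::-1])


def replay_forward(base: list, metas: List[ViewMeta]) -> list:
    out = base
    for meta in metas:
        out = meta.forward(out)
    return out


def replay_reverse(base: list, metas: List[ViewMeta], mutated_view: list) -> list:
    # Recompute each intermediate view, then unwind from the innermost view to base.
    intermediates = [base]
    for meta in metas[:-1]:
        intermediates.append(meta.forward(intermediates[-1]))
    current = mutated_view
    for meta, inter in zip(reversed(metas), reversed(intermediates)):
        current = meta.reverse(inter, current)
    return current


base = [0, 1, 2, 3, 4, 5]
metas = [slice_meta(1, 5), flip_meta()]
view = replay_forward(base, metas)
mutated = [v * 10 for v in view]  # the mutation, applied out-of-place
new_base = replay_reverse(base, metas, mutated)
```

Because each view carries its own pair of callables, the replay loops never need to switch on the kind of view, which is the boilerplate the enum-based design required.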

(2) Every tensor during the functionalization pass is always wrapped in a `FunctionalTensorWrapper`.

This is potentially unnecessary for Vulkan/XLA, and will have a mild perf impact, but for now this PR just targets the functorch use case. I previously had a complicated design (a `FunctionalTensorImplBase` class) to avoid needing the wrapper for XLA, but it had some subtleties that are gonna require more thought to fix, so I'm pushing that off for now.

(3) `FunctionalTensorWrapper` objects accurately report stride information.

It's a little annoying to do this though, because the logic that calculates stride info for each view isn't easily separated from the actual view kernels in core, `at::native::{view}`. I do this by adding logic in each `at::functionalization::{view}` kernel to call the reference implementation `at::native::{view}`. I don't do anything with the output aside from taking its size/stride/storage_offset to set the actual output tensor's size/stride/storage_offset correctly. There's another annoying part to this: I'm pretty sure that we want to pass in the actual *wrapper* tensors directly into the native kernels, not their inner unwrapped values. But there are some `at::native::{view}` kernels that call other tensor methods, which re-invokes the dispatcher, calling functionalization/functorch kernels that try to do the unwrapping.

To get around this, right now I have an `AutoDispatchDirectlyToNative` guard that basically ensures that any tensor methods called inside of the `at::native::{view}` op always redispatch straight to the CPU kernel (which will be another `at::native::` kernel). This feels kind of heavy handed, but I'm not sure of a better way to do it.


(4) `FunctionalTensorWrapper` objects accurately report aliasing information.

There's a new `FunctionalStorageImpl` class (subclass of `StorageImpl`) that allows tensors in the functionalization pass to accurately alias storage. If two tensors `a` and `b` in a functionalized program are views of one another, then `a.storage.is_alias_of(b.storage)` should return true. I added this in a pretty similar way to how meta tensors allocate storage, although I don't pass in an actual allocator (I think this is fine because you should never resize a functional tensor's storage). 

One thing I'm not sure about - should `FunctionalTensorWrapper` set `storage_access_should_throw_`: (a) always, (b) never, (c) only if its wrapped tensor has it set.

Right now I have it not set, mostly because calling the reference view functions (`at::native::{view}`) requires looking at the storage. But that means that if you try to access storage from python in a functionalized program, you'll get silent garbage instead of an error. Related question: are we planning on exposing meta tensor storage to python in the future (even though it contains garbage)?

(5) better docs :)

**View operator coverage**

(6) The functionalization pass now gets math-composite view ops for free.

I didn't add the `Functionalize` dispatch key to the composite set, because I don't want composite ops like `torch.ones` to get decomposed before hitting the functionalization pass. Instead, I added codegen to manually register the `at::native::` kernels of composite view ops. This is a little hairy, because the names of the `at::native::` kernels aren't easily accessible. They're stored in a `Dict[DispatchKey, BackendIndex]`. I made a best-effort attempt to get each view kernel's name, basically by assuming that every view op has either a composite or cpu implementation.
There's also a hardcoded list of composite view ops in `gen_inplace_or_view_type.py`, but it looks like it's wrong. This is probably worth rationalizing later, but instead I created a new list of the "complete" set of composite view ops, and preserved the old set by hardcoding the delta between the two sets.

(7) I've added codegen for ops that are both views AND mutations, like `transpose_()` (why do we even have these 😢).

From some light testing, it looks like they work correctly with one caveat: I had a hard time ensuring that functorch programs that mutate their inputs using ops like `transpose_()` preserve the input mutations after the program finishes running. For now (in my corresponding functorch branch) I emit a warning when this happens, and just don't preserve the mutation.

(8) I added `{view}_inverse` implementations for every view op, in `FunctionalInverses.cpp`.

These are needed to take mutations made to views and replay them back onto the base. To reduce boilerplate, the codegen generates function declarations for each `{view}_inverse` function, so you get a nice compiler error when someone eventually adds a new view op.
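
For intuition, here is a minimal pure-Python sketch of what an inverse for a multi-output view like `split` has to do. It is a 1-D toy over lists that drops the `dim` argument; the real `split_inverse` in `FunctionalInverses.cpp` operates on tensors, so treat the function bodies here as assumptions.

```python
def split(base, split_size):
    # 1-D analogue of base.split(split_size): consecutive chunks, last may be shorter.
    return [base[i:i + split_size] for i in range(0, len(base), split_size)]


def split_inverse(base, mutated_view, mutated_view_idx, split_size):
    # Write the mutated chunk back into a copy of base at the chunk's offset.
    start = mutated_view_idx * split_size
    out = list(base)
    out[start:start + len(mutated_view)] = mutated_view
    return out


base = [1, 2, 3, 4, 5]
chunks = split(base, 2)
mutated = [c * 100 for c in chunks[2]]  # mutate only the last chunk, out-of-place
new_base = split_inverse(base, mutated, 2, 2)
```

The `mutated_view_idx` argument is what lets a single inverse serve every output of the view: it identifies which chunk was mutated, so only that region of the base is overwritten.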

The only view ops currently not supported are (a) as_strided, and (b) the sparse view ops (values()/indices()).

I can add support for as_strided, but it needs an `as_strided_inverse()` function. That will look really similar to the `as_strided_backward()` function in FunctionsManual.cpp, but it has some noticeable differences: we basically want an `as_strided_embed` for autograd and `as_strided_scatter` for functionalization. We also will probably need them to be primitives w.r.t. autograd, since the current implementation for autograd uses view().copy_() calls that XLA won't be able to handle. I'm wondering if anyone has any objections, but otherwise I can make those changes (which will require writing backward formulas for `as_strided_embed` and `as_strided_scatter`).

I did a bunch of manual testing that all looks pretty good, but it's definitely not fully tested. Ed pointed out that once XLA uses this pass (or at least once there's a POC), we can just run the existing xla view test suite. Hopefully that delay is okay - if it's not, maybe we can think about using OpInfos similar to how functorch uses them for testing.

Note: there's some duplication with autograd's view code. Every `{view}_inverse` implementation is really similar to the implementation for that view listed in `derivatives.yaml`. There are some major differences though:
* the autograd implementations of those backward functions (like `permute_backwards()`, in `FunctionsManual.cpp`) internally call other view ops. For functionalization, we want them to (eventually) call `{view}_copy` operators.
* For view ops that take a subset of the original storage, like `slice/select/diagonal/as_strided()`, the autograd backward functions fill the "spaces" in the inverse call with zeroes. For functionalization, we want to fill them with the value of `base` at those positions. It looks like this currently applies to 6 total ops (since we can ignore composites):
  * select
  * slice
  * diagonal
  * as_strided
  * split
  * split_with_sizes
A nice end state would probably be for the autograd + functionalization codegen to both look at the same yaml (either `derivatives.yaml`, or something else), and automatically generate the right thing. I left that out of scope for this PR though.
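
The zero-fill versus base-fill difference described above can be seen in a small pure-Python sketch for a 1-D slice. Both helper names are hypothetical and chosen to echo the "embed" and "scatter" naming; neither is an existing PyTorch function.

```python
def slice_backward_embed(grad_view, base_len, start, end):
    # Autograd-style inverse: positions outside the slice are filled with zeroes.
    out = [0.0] * base_len
    out[start:end] = grad_view
    return out


def slice_inverse_scatter(base, mutated_view, start, end):
    # Functionalization-style inverse: positions outside the slice keep base's values.
    out = list(base)
    out[start:end] = mutated_view
    return out


base = [1.0, 2.0, 3.0, 4.0, 5.0]
view = base[1:3]                       # stands in for slicing base along dim 0
mutated = [v + 10.0 for v in view]     # the mutation, applied out-of-place

embedded = slice_backward_embed(mutated, len(base), 1, 3)
scattered = slice_inverse_scatter(base, mutated, 1, 3)
```

The two results agree inside the slice and differ everywhere else, which is why the autograd backward formulas cannot be reused as-is for these ops.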



**Current State + Next Steps**

There are a bunch of followups after this PR eventually lands. Roughly in order:
* Use the current pass to register problematic composite ops in functorch. Also, nested `functionalize()` calls aren't supported yet (I mostly just need to remove some debug asserts and test it).
* Work on freeing up dispatch key space by deduplicating the `{backend}`/`Autograd{backend}`/`Sparse{backend}`/`Quantized{backend}` keys
* Once we have more dispatch keys, split up this pass into 3 pieces - it's currently fused, and doesn't do the right thing for vulkan/XLA. Specifically, all of the `{view}` calls in the current pass's view-replay logic should turn into `{view}_copy` calls that vulkan/XLA know how to implement, and there will be separate passes for (a) removing mutations, and (b) turning `{view}_copy` calls back into `{view}` calls. For Vulkan, we eventually want a pass that ONLY removes aliasing and view calls, and doesn't remove mutations. We can also probably make the 2 new passes user dispatch keys to save dispatch key space, if they'll only be used by functorch anyway.
* Do more of a dive on perf for the vulkan/xla use cases. There are several areas to improve perf with varying levels of effort required. The simplest one that I'll probably do regardless is to codegen the out-of-place kernels instead of using a boxed fallback. Getting a POC working for xla will also be useful to test the view operator coverage.


**Example Codegen Output**

View Op:
```
::std::vector<at::Tensor> split_Tensor(c10::DispatchKeySet ks, const at::Tensor & self, int64_t split_size, int64_t dim) {

      auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
      ::std::vector<at::Tensor> out;
      {
        at::AutoDispatchBelowFunctionalize guard;
        auto tmp_output = at::redispatch::split(ks & c10::after_func_keyset, self_, split_size, dim);
        out = at::functionalization::impl::wrapFunctionalTensor(tmp_output);
        // I'm fusing the [alias removal], [mutation removal], [add views back] passes together.
        // Later, we'll want to turn them into separate passes (since e.g. vulkan only cares about alias removal).
      }

      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        [split_size, dim](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
          return base.split(split_size, dim)[mutated_view_idx];
        },
        [split_size, dim](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
          return at::functionalization::impl::split_inverse(base, mutated_view, mutated_view_idx, split_size, dim);
        }
      );
      at::functionalization::impl::set_view_meta(out, self, view_meta);

      at::AutoDispatchDirectlyToNative native_guard;
      ::std::vector<at::Tensor> reference_tensor_output = at::native::split(self, split_size, dim);
      at::functionalization::impl::set_strides(out, reference_tensor_output);
      return out;

}
```

Mutation Op:
```
at::Tensor & add__Tensor(c10::DispatchKeySet ks, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {

      at::functionalization::impl::sync(self);
      at::functionalization::impl::sync(other);
      auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
      auto other_ = at::functionalization::impl::unwrapFunctionalTensor(other);
      at::Tensor tmp_output;
      {
          at::AutoDispatchBelowFunctionalize guard;
          // The functionalization pass explicitly doesn't pass out= parameters to the redispatch
          tmp_output = at::redispatch::add(
            ks & c10::after_func_keyset, self_, other_, alpha);
      }

      self.replace_(tmp_output);
      at::functionalization::impl::maybe_add_update(self);
      return self;
}
```

View + Mutation Op:
```
at::Tensor & transpose_(c10::DispatchKeySet ks, at::Tensor & self, int64_t dim0, int64_t dim1) {


      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        [dim0, dim1](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
          return base.transpose(dim0, dim1);
        },
        [dim0, dim1](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
          return at::functionalization::impl::transpose_inverse(base, mutated_view, dim0, dim1);
        }
      );
      at::functionalization::impl::mutate_view_meta(self, view_meta);
      // See  Note [Propagating strides in the functionalization pass]
      // Directly update the sizes/strides/storage_offset fields on self using the inplace call.
      // I need the guard because I don't want the at::native kernel to end up calling more functionalization/functorch kernels.
      // Its only job is to directly compute the output size/stride/storage_offset metadata.
      at::AutoDispatchDirectlyToNative native_guard;
      at::native::transpose_(self, dim0, dim1);
      return self;

}
```




[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Sep 22, 2021
Original PR description + feedback here: #63048

I've addressed all of the feedback in the original PR and made some pretty large changes, listed below.


**Table of Contents**
- Starting points
- List of the main changes from the original PR
- Next Steps
- Example codegen output (for a view, mutation, and view+mutation op)


**Starting Points**

A good place to start when looking through the PR:
* Alban mentioned that this is a useful mental model (thanks Ed for originally making this clear to me). Semantically, the pass currently does THREE things, which are all needed by functorch - all fused together into one big pass.
  * (a) alias removal, which replaces {view} calls with {view}_copy calls, and manually tracks aliasing information, so that when one tensor is mutated, we re-apply the same mutation to all of the aliases. This is the bulk of the work - once this is done, the next 2 things are trivial to implement.
  * (b) mutation removal, which is easy to do once we know that there are no aliases. Every mutation `a.add_(b)` becomes `a.replace_(a.add(b))`
  * (c) reapplying views: all of the `{view}_copy` calls are replaced with `{view}` calls again. This is an optimization that we can make specifically for functorch (and strided backends), that only care about mutation removal and not alias removal
  * XLA and Vulkan only want (a), or (a) + (b). Later, we'll want to split this out so that you can actually opt into different versions of this logic.
  * There is currently no {view}_copy replacement, because the pass just <replace views with copies> and <replace copies with views> steps have been combined. Later, we'll want to actually implement {view}_copy variants of each view operator, probably with codegen.
* documentation breadcrumb 1, in `FunctionalTensorWrapper.cpp`: https://github.com/pytorch/pytorch/pull/64432/files#diff-a0bac99bf205dba5b94cb64fc2466d3d55d991887572f9cd6a02e27b3a91dd60R59 (you might have to expand the `FunctionalTensorWrapper.cpp` file, which GitHub closes by default because it's large)
* documentation breadcrumb 2, in `FunctionalTensorWrapper.h`: https://github.com/pytorch/pytorch/pull/64432/files#diff-c945c71a4ccac65871f24a912e8904f9a5088b24a32e636727ea9c8fe920708aR12
* Reading through the codegen output at the bottom of this description.


**Main changes from the original PR**

(1)  I use lambdas instead of a giant enum to handle all of the different views.

This results in less boilerplate per view op (and more stuff that can be codegen'd). Every `ViewMeta` object now contains a `forward` and `reverse` lambda, that knows how to replay the view and its inverse. This makes the actual code that executes the replaying logic a lot less boilerplate-y (see `Alias::sync_update_operations` and `FunctionalTensorWrapper::sync_`)

(2) Every tensor during the functionalization pass is always wrapped in a `FunctionalTensorWrapper`.

This is potentially unnecessary for Vulkan/XLA, and will have a mild perf impact, but for now this PR just targets the functorch use case. I previously had a complicated design a (`FunctionalTensorImplBase` class) to avoid needing the wrapper for XLA, but it had some subtleties that are gonna require more thought to fix, so I'm pushing that off for now.

(3) `FunctionalTensorWrapper` objects accurately report stride information.

It's a little annoying to do this though, because the logic that calculates stride info for each view isn't easily separated from the actual view kernels in core, `at::native::{view}`. I do this by adding logic in each `at::functionalization::{view}` kernel to call the reference implementation `at::native::{view}`. I don't do anything with the output aside from taking it's size/stride/storage_offset to set the actual output tensor's size/stride/storage_offset correctly. There's another annoying part to this: I'm pretty sure that we want to pass in the actual *wrapper* tensors directly into the native kernels, not their inner unwrapped values. But there are some `at::native::{view}` kernels that call other tensor methods, which re-invokes the dispatcher, calling functionalization/functorch kernels that try do the unwrapping.

To do this, right now I have an `AutoDispatchDirectlyToNative` guard that basically ensures that any tensor methods called inside of the at::native::{view} op always redispatch straight to the CPU kernel (which will be another at::native:: kernel). This feels kind of heavy handed, but I'm not sure of a better way to do it.


(4) `FunctionalTensorWrapper` objects accurately report aliasing information.

There's a new `FunctionalStorageImpl` class (subclass of `StorageImpl`) that allows tensors in the functionalization pass to accurately alias storage. If two tensors `a` and `b` in a functionalized program are views of one another, then `a.storage.is_alias_of(b.storage)` should return true. I added this in a pretty similar way to how meta tensors allocate storage, although I don't pass in an actual allocator (I think this is fine because you should never resize a functional tensor's storage). 

One thing I'm not sure about - should `FunctionalTensorWrapper` set `storage_access_should_throw_`: (a) always, (b) never, (c) only if its wrapped tensor has it set.

Right now I have it not set, mostly because calling the reference view functions (`at::native::{view}`) requires looking at the storage. But that means that if you try to access storage from python in a functionalized program, you'll get silent garbage instead of an error. Related question: are we planning on exposing meta tensor storage to python in the future (even though it contains garbage)?

(5) better docs :)

**View operator coverage**

(6) The functionalization pass now gets math-composite view ops for free.

I didn't add the `Functionalize` dispatch key to the composite set, because I don't want composite ops like `torch.ones` to get decomposed before hitting the functionalization pass. Instead, I added codegen to manually register the `at::native::` kernels of composite view ops. This is a little hairy, because the names of the `at::native::` kernels aren't easily accessible. They're stored in a `Dict[DispatchKey, BackendIndex]`. I made a best-effort attempt to get each view kernel's name, basically by assuming that every view op has either a composite or cpu implementation.
There's also a hardcoded list of composite view ops in `gen_inplace_or_view_type.py`, but it looks like it's wrong. This is probably worth rationalizing later, but instead I created a new list of the "complete" set of composite view ops, and preserved the old set by hardcoding the delta between the two sets.

(7) I've added codegen for ops that are both views AND mutations, like `transpose_()` (why do we even have these 😢).

From some light testing, it looks like they work correctly with one caveat: I had a hard time ensuring that functorch programs that mutate their inputs using ops like `transpose_()` preserve the input mutations after the program finishes running. For (in my corresponding functorch branch) I emit a warning when this happens, and just don't preserve the mutation

(8) I added `{view}_inverse` implementations for every view op, in `FunctionalInverses.cpp`.

These are needed to take mutations made to views and replay them back onto the base. To reduce boilerplate, the codegen generates function declarations for each `{view}_inverse` function, so you get a nice compiler error when someone eventually adds a new view op.

The only view ops currently not supported are (a) as_strided, and (b) the sparse view ops (values()/indices()).

I can add support for as_strided, but it needs an `as_strided_inverse()` function. That will look really similar to the `as_strided_backward()` function in FunctionsManual.cpp, but it has some noticeable differences: we basically want an `as_strided_embed` for autograd and `as_strided_scatter` for functionalization. We also will probably need them to be primitives w.r.t to autograd, since the currently implementation for autograd uses view().copy_() calls that XLA won't be able to handle. I'm wondering if anyone has any objections, but otherwise I can make those change (which will require writing backward formulas for `as_strided_embed` and `as_strided_scatter`).

I did a bunch of manual testing that all looks pretty good, but it's definitely not fully tested. Ed pointed out that once XLA uses this pass (or at least once there's a POC), we can just run the existing xla view test suite. Hopefully that delay is okay - if it's not, maybe we can think about using OpInfos similar to how functorch uses them for testing.

Note: there's some duplication with autograd's view code. Every `{view}_inverse` implementation is really similar to the implementation for that view listed in `derivatives.yaml`. There are some major differences though:
* the autograd implementations over those backwards functions (like `permute_backwards()`, in `FunctionsManual.cpp`) internally call other view ops. For functoinalization, we want them to (eventually call `{view}_copy` operators).
* For view ops that take a subset of the original storage, like `slice/select/diagonal/as_strided()`, the autograd backward functions fill the "spaces" in the inverse call with zeroes. For functionalizations, we want to fill them with the value of `base` at those positions. It looks like this currently applies to 6 total ops (since we can ignore composites):
  * select
  * slice
  * diagonal
  * as_stridied
  * split
  * split_with_sizes
A nice end state would probably be for the autograd + functoinalization codegen to both look at the same yaml (either `derivatives.yaml`, or something else), and automatically generate the right thing. I didn't leave that in scope for this PR though.



**Current State + Next Steps**

There are a bunch of followups after this PR eventually lands. Roughly in order:
* Use the current pass to register problematic composite ops in functorch. Also, nested `functionalize()` calls aren't supported yet (I mostly just need to remove some debug asserts and test it).
* Work on freeing up dispatch key space in the by deduplicating the `{backend}`/`Autograd{backend}`/`Sparse{backend}`/`Quantized{backend}` keys
* Once we have more dispatch keys, split up this pass into 3 pieces - it's currently fused, and doesn't do the right thing for vulkan/XLA. Specifically, all of the `{view}` calls in the current pass's view-replay logic should turn into `{view}_copy` calls that vulkan/XLA know how to implement, and there will be separate passes for (a) removing mutations, and (b) turning `{view}_copy` calls back into `{view}` calls. For Vulkan, we eventually want a pass that ONLY removes aliasing and view calls, and doesn't remove mutations. We can also probably make the 2 new passes user dispatch keys to save dispatch key space, if they'll only be used by functorch anyway.
* Do more of a dive on perf for the vulkan/xla use cases. There are several areas to improve perf with varying levels of effort required. The simplest one that I'll probably do regardless is to codegen the out-of-place kernels instead of using a boxed fallback. Getting a POC working for xla will also be useful to test the view operator coverage.


**Example Codegen Output**

View Op:
```
::std::vector<at::Tensor> split_Tensor(c10::DispatchKeySet ks, const at::Tensor & self, int64_t split_size, int64_t dim) {

      auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
      ::std::vector<at::Tensor> out;
      {
        at::AutoDispatchBelowFunctionalize guard;
        auto tmp_output = at::redispatch::split(ks & c10::after_func_keyset, self_, split_size, dim);
        out = at::functionalization::impl::wrapFunctionalTensor(tmp_output);
        // I'm fusing the [alias removal], [mutation removal], [add views back] passes together.
        // Later, we'll want to turn them into separate passes (since e.g. vulkan only cares about alias removal).
      }

      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        [split_size, dim](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
          return base.split(split_size, dim)[mutated_view_idx];
        },
        [split_size, dim](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
          return at::functionalization::impl::split_inverse(base, mutated_view, mutated_view_idx, split_size, dim);
        }
      );
      at::functionalization::impl::set_view_meta(out, self, view_meta);

      at::AutoDispatchDirectlyToNative native_guard;
      ::std::vector<at::Tensor> reference_tensor_output = at::native::split(self, split_size, dim);
      at::functionalization::impl::set_strides(out, reference_tensor_output);
      return out;

}
```

Mutation Op:
```
at::Tensor & add__Tensor(c10::DispatchKeySet ks, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {

      at::functionalization::impl::sync(self);
      at::functionalization::impl::sync(other);
      auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
      auto other_ = at::functionalization::impl::unwrapFunctionalTensor(other);
      at::Tensor tmp_output;
      {
          at::AutoDispatchBelowFunctionalize guard;
          // The functionalization pass explicitly doesn't pass out= parameters to the redispatch
          tmp_output = at::redispatch::add(
            ks & c10::after_func_keyset, self_, other_, alpha);
      }

      self.replace_(tmp_output);
      at::functionalization::impl::maybe_add_update(self);
      return self;
}
```

View + Mutation Op:
```
at::Tensor & transpose_(c10::DispatchKeySet ks, at::Tensor & self, int64_t dim0, int64_t dim1) {


      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        [dim0, dim1](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
          return base.transpose(dim0, dim1);
        },
        [dim0, dim1](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
          return at::functionalization::impl::transpose_inverse(base, mutated_view, dim0, dim1);
        }
      );
      at::functionalization::impl::mutate_view_meta(self, view_meta);
      // See  Note [Propagating strides in the functionalization pass]
      // Directly update the sizes/strides/storage_offset fields on self using the inplace call.
      // I need the guard because I don't want the at::native kernel to end up calling more functionalization/functorch kernels.
      // Its only job is to directly compute the output size/stride/storage_offset metadata.
      at::AutoDispatchDirectlyToNative native_guard;
      at::native::transpose_(self, dim0, dim1);
      return self;

}
```




bdhirsh added a commit that referenced this pull request Oct 5, 2021
Original PR description + feedback here: #63048

I've addressed all of the feedback in the original PR and made some pretty large changes, listed below.


**Table of Contents**
- Starting points
- List of the main changes from the original PR
- Next Steps
- Example codegen output (for a view, mutation, and view+mutation op)


**Starting Points**

A good place to start when looking through the PR:
* Alban mentioned that this is a useful mental model (thanks Ed for originally making this clear to me). Semantically, the pass currently does THREE things, which are all needed by functorch - all fused together into one big pass.
  * (a) alias removal, which replaces {view} calls with {view}_copy calls, and manually tracks aliasing information, so that when one tensor is mutated, we re-apply the same mutation to all of the aliases. This is the bulk of the work - once this is done, the next 2 things are trivial to implement.
  * (b) mutation removal, which is easy to do once we know that there are no aliases. Every mutation `a.add_(b)` becomes `a.replace_(a.add(b))`
  * (c) reapplying views: all of the `{view}_copy` calls are replaced with `{view}` calls again. This is an optimization that we can make specifically for functorch (and strided backends), which only care about mutation removal and not alias removal.
  * XLA and Vulkan only want (a), or (a) + (b). Later, we'll want to split this out so that you can actually opt into different versions of this logic.
  * There is currently no {view}_copy replacement, because the <replace views with copies> and <replace copies with views> steps have been combined in this pass. Later, we'll want to actually implement {view}_copy variants of each view operator, probably with codegen.
* documentation breadcrumb 1, in `FunctionalTensorWrapper.cpp`: https://github.com/pytorch/pytorch/pull/64432/files#diff-a0bac99bf205dba5b94cb64fc2466d3d55d991887572f9cd6a02e27b3a91dd60R59 (you might have to expand the `FunctionalTensorWrapper.cpp` file, which GitHub closes by default because it's large)
* documentation breadcrumb 2, in `FunctionalTensorWrapper.h`: https://github.com/pytorch/pytorch/pull/64432/files#diff-c945c71a4ccac65871f24a912e8904f9a5088b24a32e636727ea9c8fe920708aR12
* Reading through the codegen output at the bottom of this description.
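
To make step (b) concrete, here is a minimal pure-Python sketch of the `a.add_(b)` -> `a.replace_(a.add(b))` rewrite. `ToyWrapper` is a hypothetical stand-in for the wrapper tensor (it is not the real `FunctionalTensorWrapper` API), and tuples stand in for immutable tensor data:

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class ToyWrapper:
    """Hypothetical stand-in for a functional wrapper; holds an immutable value."""
    value: Tuple[float, ...]

    def add(self, other: "ToyWrapper") -> "ToyWrapper":
        # Out-of-place op: returns a fresh wrapper and leaves both inputs untouched.
        return ToyWrapper(tuple(x + y for x, y in zip(self.value, other.value)))

    def replace_(self, new: "ToyWrapper") -> "ToyWrapper":
        # The only "mutation" left is swapping which value the wrapper points at.
        self.value = new.value
        return self

    def add_(self, other: "ToyWrapper") -> "ToyWrapper":
        # Mutation removal: a.add_(b) becomes a.replace_(a.add(b)).
        return self.replace_(self.add(other))


a = ToyWrapper((1.0, 1.0))
b = ToyWrapper((2.0, 3.0))
old_value = a.value   # keep a handle on the data that `a` wrapped before the call
result = a.add_(b)
```

The caller still sees `a` change (and `result` is the same object as `a`), but the data that `a` originally wrapped is never written to in place.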


**Main changes from the original PR**

(1)  I use lambdas instead of a giant enum to handle all of the different views.

This results in less boilerplate per view op (and more stuff that can be codegen'd). Every `ViewMeta` object now contains a `forward` and `reverse` lambda, that knows how to replay the view and its inverse. This makes the actual code that executes the replaying logic a lot less boilerplate-y (see `Alias::sync_update_operations` and `FunctionalTensorWrapper::sync_`)
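
As a rough illustration of the forward/reverse idea (not the actual C++ `ViewMeta` class), here is a toy pure-Python sketch. `ToyViewMeta`, `slice_meta`, `apply_view_chain` and `replay_mutation` are hypothetical names, and plain lists stand in for tensors:

```python
from dataclasses import dataclass
from typing import Callable, List

Data = List[float]


@dataclass
class ToyViewMeta:
    forward: Callable[[Data], Data]        # base -> view
    reverse: Callable[[Data, Data], Data]  # (base, mutated_view) -> updated base


def slice_meta(start: int, stop: int) -> ToyViewMeta:
    """Build the forward/reverse pair for a 1-D slice view."""
    def forward(base: Data) -> Data:
        return base[start:stop]

    def reverse(base: Data, mutated_view: Data) -> Data:
        # Write the mutated view back into a copy of base; other positions keep base's values.
        return base[:start] + mutated_view + base[stop:]

    return ToyViewMeta(forward, reverse)


def apply_view_chain(base: Data, metas: List[ToyViewMeta]) -> Data:
    """Replay a chain of views, outermost first, to regenerate the view from the base."""
    out = base
    for meta in metas:
        out = meta.forward(out)
    return out


def replay_mutation(base: Data, metas: List[ToyViewMeta], mutated_view: Data) -> Data:
    """Push a mutated view back through the chain of inverses to get the updated base."""
    # Recompute the intermediate bases that each view in the chain was taken from.
    bases = [base]
    for meta in metas[:-1]:
        bases.append(meta.forward(bases[-1]))
    current = mutated_view
    # Apply the inverses innermost-first.
    for meta, intermediate in zip(reversed(metas), reversed(bases)):
        current = meta.reverse(intermediate, current)
    return current


base = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
metas = [slice_meta(1, 5), slice_meta(1, 3)]   # like base[1:5][1:3]
view = apply_view_chain(base, metas)
new_base = replay_mutation(base, metas, [20.0, 30.0])
```

Because each view op only has to supply its own pair of lambdas, the replay loops above never need to know which view they are dealing with, which is the boilerplate reduction this change is after.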

(2) Every tensor during the functionalization pass is always wrapped in a `FunctionalTensorWrapper`.

This is potentially unnecessary for Vulkan/XLA, and will have a mild perf impact, but for now this PR just targets the functorch use case. I previously had a complicated design (a `FunctionalTensorImplBase` class) to avoid needing the wrapper for XLA, but it had some subtleties that are gonna require more thought to fix, so I'm pushing that off for now.

(3) `FunctionalTensorWrapper` objects accurately report stride information.

It's a little annoying to do this though, because the logic that calculates stride info for each view isn't easily separated from the actual view kernels in core, `at::native::{view}`. I do this by adding logic in each `at::functionalization::{view}` kernel to call the reference implementation `at::native::{view}`. I don't do anything with the output aside from taking its size/stride/storage_offset to set the actual output tensor's size/stride/storage_offset correctly. There's another annoying part to this: I'm pretty sure that we want to pass in the actual *wrapper* tensors directly into the native kernels, not their inner unwrapped values. But there are some `at::native::{view}` kernels that call other tensor methods, which re-invokes the dispatcher, calling functionalization/functorch kernels that try to do the unwrapping.

To handle this, right now I have an `AutoDispatchDirectlyToNative` guard that basically ensures that any tensor methods called inside of the `at::native::{view}` op always redispatch straight to the CPU kernel (which will be another `at::native::` kernel). This feels kind of heavy-handed, but I'm not sure of a better way to do it.
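
For intuition about what "only computing the size/stride metadata" means, here is a toy pure-Python sketch for `transpose`, which touches no data at all. The helper names `contiguous_strides` and `transpose_metadata` are hypothetical and not part of this PR:

```python
from typing import Sequence, Tuple


def contiguous_strides(sizes: Sequence[int]) -> Tuple[int, ...]:
    """Row-major strides for a freshly allocated tensor of the given sizes."""
    strides = []
    running = 1
    for size in reversed(sizes):
        strides.append(running)
        running *= size
    return tuple(reversed(strides))


def transpose_metadata(
    sizes: Sequence[int], strides: Sequence[int], dim0: int, dim1: int
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Metadata-only transpose: swap the two sizes and the two strides."""
    new_sizes, new_strides = list(sizes), list(strides)
    new_sizes[dim0], new_sizes[dim1] = new_sizes[dim1], new_sizes[dim0]
    new_strides[dim0], new_strides[dim1] = new_strides[dim1], new_strides[dim0]
    return tuple(new_sizes), tuple(new_strides)


sizes = (2, 3, 4)
strides = contiguous_strides(sizes)
t_sizes, t_strides = transpose_metadata(sizes, strides, 0, 2)
```

Other views need more involved stride logic, which is why the pass reuses the reference `at::native::{view}` kernels instead of duplicating that logic per op.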


(4) `FunctionalTensorWrapper` objects accurately report aliasing information.

There's a new `FunctionalStorageImpl` class (subclass of `StorageImpl`) that allows tensors in the functionalization pass to accurately alias storage. If two tensors `a` and `b` in a functionalized program are views of one another, then `a.storage.is_alias_of(b.storage)` should return true. I added this in a pretty similar way to how meta tensors allocate storage, although I don't pass in an actual allocator (I think this is fine because you should never resize a functional tensor's storage). 

One thing I'm not sure about - should `FunctionalTensorWrapper` set `storage_access_should_throw_`: (a) always, (b) never, (c) only if its wrapped tensor has it set.

Right now I have it not set, mostly because calling the reference view functions (`at::native::{view}`) requires looking at the storage. But that means that if you try to access storage from python in a functionalized program, you'll get silent garbage instead of an error. Related question: are we planning on exposing meta tensor storage to python in the future (even though it contains garbage)?

(5) better docs :)

**View operator coverage**

(6) The functionalization pass now gets math-composite view ops for free.

I didn't add the `Functionalize` dispatch key to the composite set, because I don't want composite ops like `torch.ones` to get decomposed before hitting the functionalization pass. Instead, I added codegen to manually register the `at::native::` kernels of composite view ops. This is a little hairy, because the names of the `at::native::` kernels aren't easily accessible. They're stored in a `Dict[DispatchKey, BackendIndex]`. I made a best-effort attempt to get each view kernel's name, basically by assuming that every view op has either a composite or cpu implementation.
There's also a hardcoded list of composite view ops in `gen_inplace_or_view_type.py`, but it looks like it's wrong. This is probably worth rationalizing later, but instead I created a new list of the "complete" set of composite view ops, and preserved the old set by hardcoding the delta between the two sets.

(7) I've added codegen for ops that are both views AND mutations, like `transpose_()` (why do we even have these 😢).

From some light testing, it looks like they work correctly with one caveat: I had a hard time ensuring that functorch programs that mutate their inputs using ops like `transpose_()` preserve the input mutations after the program finishes running. For now (in my corresponding functorch branch), I emit a warning when this happens and just don't preserve the mutation.

(8) I added `{view}_inverse` implementations for every view op, in `FunctionalInverses.cpp`.

These are needed to take mutations made to views and replay them back onto the base. To reduce boilerplate, the codegen generates function declarations for each `{view}_inverse` function, so you get a nice compiler error when someone eventually adds a new view op.

The only view ops currently not supported are (a) as_strided, and (b) the sparse view ops (values()/indices()).

I can add support for as_strided, but it needs an `as_strided_inverse()` function. That will look really similar to the `as_strided_backward()` function in FunctionsManual.cpp, but it has some noticeable differences: we basically want an `as_strided_embed` for autograd and `as_strided_scatter` for functionalization. We also will probably need them to be primitives w.r.t. autograd, since the current implementation for autograd uses view().copy_() calls that XLA won't be able to handle. I'm wondering if anyone has any objections, but otherwise I can make those changes (which will require writing backward formulas for `as_strided_embed` and `as_strided_scatter`).

I did a bunch of manual testing that all looks pretty good, but it's definitely not fully tested. Ed pointed out that once XLA uses this pass (or at least once there's a POC), we can just run the existing xla view test suite. Hopefully that delay is okay - if it's not, maybe we can think about using OpInfos similar to how functorch uses them for testing.

Note: there's some duplication with autograd's view code. Every `{view}_inverse` implementation is really similar to the implementation for that view listed in `derivatives.yaml`. There are some major differences though:
* The autograd implementations of those backward functions (like `permute_backwards()`, in `FunctionsManual.cpp`) internally call other view ops. For functionalization, we want them to (eventually) call `{view}_copy` operators.
* For view ops that take a subset of the original storage, like `slice/select/diagonal/as_strided()`, the autograd backward functions fill the "spaces" in the inverse call with zeroes. For functionalization, we want to fill them with the value of `base` at those positions. It looks like this currently applies to 6 total ops (since we can ignore composites):
  * select
  * slice
  * diagonal
  * as_strided
  * split
  * split_with_sizes
A nice end state would probably be for the autograd + functionalization codegen to both look at the same yaml (either `derivatives.yaml`, or something else), and automatically generate the right thing. I didn't leave that in scope for this PR though.
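
The zero-fill vs. base-fill difference can be seen in a toy 1-D `select` example. This is a pure-Python sketch with hypothetical helper names (`select_backward_toy`, `select_inverse_toy`), not the real kernels:

```python
from typing import List


def select_backward_toy(grad: float, size: int, index: int) -> List[float]:
    """Autograd-style: embed the gradient into zeros (untouched positions get 0)."""
    out = [0.0] * size
    out[index] = grad
    return out


def select_inverse_toy(base: List[float], mutated_view: float, index: int) -> List[float]:
    """Functionalization-style: untouched positions keep the values from base."""
    out = list(base)  # copy, so base itself is never mutated
    out[index] = mutated_view
    return out


base = [1.0, 2.0, 3.0]
grad_like = select_backward_toy(10.0, len(base), 1)
scatter_like = select_inverse_toy(base, 10.0, 1)
```

Both functions place the value at the same index. They only differ in what fills the remaining positions, which is why a shared yaml could plausibly drive both codegens.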



**Current State + Next Steps**

There are a bunch of followups after this PR eventually lands. Roughly in order:
* Use the current pass to register problematic composite ops in functorch. Also, nested `functionalize()` calls aren't supported yet (I mostly just need to remove some debug asserts and test it).
* Work on freeing up dispatch key space by deduplicating the `{backend}`/`Autograd{backend}`/`Sparse{backend}`/`Quantized{backend}` keys
* Once we have more dispatch keys, split up this pass into 3 pieces - it's currently fused, and doesn't do the right thing for vulkan/XLA. Specifically, all of the `{view}` calls in the current pass's view-replay logic should turn into `{view}_copy` calls that vulkan/XLA know how to implement, and there will be separate passes for (a) removing mutations, and (b) turning `{view}_copy` calls back into `{view}` calls. For Vulkan, we eventually want a pass that ONLY removes aliasing and view calls, and doesn't remove mutations. We can also probably make the 2 new passes user dispatch keys to save dispatch key space, if they'll only be used by functorch anyway.
* Do more of a dive on perf for the vulkan/xla use cases. There are several areas to improve perf with varying levels of effort required. The simplest one that I'll probably do regardless is to codegen the out-of-place kernels instead of using a boxed fallback. Getting a POC working for xla will also be useful to test the view operator coverage.


**Example Codegen Output**

View Op:
```
::std::vector<at::Tensor> split_Tensor(c10::DispatchKeySet ks, const at::Tensor & self, int64_t split_size, int64_t dim) {

      auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
      ::std::vector<at::Tensor> out;
      {
        at::AutoDispatchBelowFunctionalize guard;
        auto tmp_output = at::redispatch::split(ks & c10::after_func_keyset, self_, split_size, dim);
        out = at::functionalization::impl::wrapFunctionalTensor(tmp_output);
        // I'm fusing the [alias removal], [mutation removal], [add views back] passes together.
        // Later, we'll want to turn them into separate passes (since e.g. vulkan only cares about alias removal).
      }

      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        [split_size, dim](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
          return base.split(split_size, dim)[mutated_view_idx];
        },
        [split_size, dim](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
          return at::functionalization::impl::split_inverse(base, mutated_view, mutated_view_idx, split_size, dim);
        }
      );
      at::functionalization::impl::set_view_meta(out, self, view_meta);

      at::AutoDispatchDirectlyToNative native_guard;
      ::std::vector<at::Tensor> reference_tensor_output = at::native::split(self, split_size, dim);
      at::functionalization::impl::set_strides(out, reference_tensor_output);
      return out;

}
```

Mutation Op:
```
at::Tensor & add__Tensor(c10::DispatchKeySet ks, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {

      at::functionalization::impl::sync(self);
      at::functionalization::impl::sync(other);
      auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
      auto other_ = at::functionalization::impl::unwrapFunctionalTensor(other);
      at::Tensor tmp_output;
      {
          at::AutoDispatchBelowFunctionalize guard;
          // The functionalization pass explicitly doesn't pass out= parameters to the redispatch
          tmp_output = at::redispatch::add(
            ks & c10::after_func_keyset, self_, other_, alpha);
      }

      self.replace_(tmp_output);
      at::functionalization::impl::maybe_add_update(self);
      return self;
}
```

View + Mutation Op:
```
at::Tensor & transpose_(c10::DispatchKeySet ks, at::Tensor & self, int64_t dim0, int64_t dim1) {


      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        [dim0, dim1](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
          return base.transpose(dim0, dim1);
        },
        [dim0, dim1](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
          return at::functionalization::impl::transpose_inverse(base, mutated_view, dim0, dim1);
        }
      );
      at::functionalization::impl::mutate_view_meta(self, view_meta);
      // See  Note [Propagating strides in the functionalization pass]
      // Directly update the sizes/strides/storage_offset fields on self using the inplace call.
      // I need the guard because I don't want the at::native kernel to end up calling more functionalization/functorch kernels.
      // Its only job is to directly compute the output size/stride/storage_offset metadata.
      at::AutoDispatchDirectlyToNative native_guard;
      at::native::transpose_(self, dim0, dim1);
      return self;

}
```




bdhirsh added a commit that referenced this pull request Oct 22, 2021
bdhirsh added a commit that referenced this pull request Oct 22, 2021
Original PR description + feedback here: #63048

I've addressed all of the feedback in the original PR and made some pretty large changes, listed below.


**Table of Contents**
- Starting points
- List of the main changes from the original PR
- Next Steps
- Example codegen output (for a view, mutation, and view+mutation op)


**Starting Points**

A good place to start when looking through the PR:
* Alban mentioned that this is a useful mental model (thanks Ed for originally making this clear to me). Semantically, the pass currently does THREE things, which are all needed by functorch - all fused together into one big pass.
  * (a) alias removal, which replaces {view} calls with {view}_copy calls, and manually tracks aliasing information, so that when one tensor is mutated, we re-apply the same mutation to all of the aliases. This is the bulk of the work - once this is done, the next 2 things are trivial to implement.
  * (b) mutation removal, which is easy to do once we know that there are no aliases. Every mutation `a.add_(b)` becomes `a.replace_(a.add(b))`
  * (c) reapplying views: all of the `{view}_copy` calls are replaced with `{view}` calls again. This is an optimization that we can make specifically for functorch (and strided backends), which only care about mutation removal and not alias removal.
  * XLA and Vulkan only want (a), or (a) + (b). Later, we'll want to split this out so that you can actually opt into different versions of this logic.
  * There is currently no {view}_copy replacement, because the <replace views with copies> and <replace copies with views> steps have been combined into one pass. Later, we'll want to actually implement {view}_copy variants of each view operator, probably with codegen.
* documentation breadcrumb 1, in `FunctionalTensorWrapper.cpp`: https://github.com/pytorch/pytorch/pull/64432/files#diff-a0bac99bf205dba5b94cb64fc2466d3d55d991887572f9cd6a02e27b3a91dd60R59 (you might have to expand the `FunctionalTensorWrapper.cpp` file, which GitHub closes by default because it's large)
* documentation breadcrumb 2, in `FunctionalTensorWrapper.h`: https://github.com/pytorch/pytorch/pull/64432/files#diff-c945c71a4ccac65871f24a912e8904f9a5088b24a32e636727ea9c8fe920708aR12
* Reading through the codegen output at the bottom of this description.
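
To make step (b) concrete, here is a minimal sketch of the `a.add_(b)` to `a.replace_(a.add(b))` rewrite, using plain Python tuples as stand-ins for tensors. The `ToyWrapper` class is hypothetical and only mirrors the idea; it is not the `FunctionalTensorWrapper` from this PR.

```python
class ToyWrapper:
    """Hypothetical stand-in for a functional wrapper around an immutable value."""

    def __init__(self, value):
        self.value = tuple(value)  # the inner value is never mutated in place

    def add(self, other):
        # Out-of-place op: returns a fresh value and leaves both inputs alone.
        return tuple(x + y for x, y in zip(self.value, other.value))

    def replace_(self, new_value):
        # Swap the wrapped value; anyone holding this wrapper sees the update.
        self.value = tuple(new_value)
        return self

    def add_(self, other):
        # Mutation removal: the in-place op becomes out-of-place op + replace_.
        return self.replace_(self.add(other))


a = ToyWrapper([1, 2, 3])
b = ToyWrapper([10, 10, 10])
old_inner = a.value  # keep a handle on the original inner value
a.add_(b)
```

After the call, `a.value` holds the sum while `old_inner` is untouched, which is the property a backend without in-place ops would rely on.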


**Main changes from the original PR**

(1) I use lambdas instead of a giant enum to handle all of the different views.

This results in less boilerplate per view op (and more stuff that can be codegen'd). Every `ViewMeta` object now contains a `forward` and a `reverse` lambda that know how to replay the view and its inverse. This makes the actual code that executes the replaying logic a lot less boilerplate-y (see `Alias::sync_update_operations` and `FunctionalTensorWrapper::sync_`).
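
As a rough illustration of the forward/reverse pairing, the sketch below models a slice view over a Python list. `ToyViewMeta` and `make_slice_meta` are hypothetical names, and the real lambdas also take a `mutated_view_idx` argument that this sketch omits.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class ToyViewMeta:
    # forward: base -> view; reverse: (base, mutated_view) -> updated base
    forward: Callable[[List[int]], List[int]]
    reverse: Callable[[List[int], List[int]], List[int]]


def make_slice_meta(start: int, stop: int) -> ToyViewMeta:
    # Both lambdas capture the view arguments, so replay needs no per-op enum.
    return ToyViewMeta(
        forward=lambda base: base[start:stop],
        reverse=lambda base, mutated_view: base[:start] + mutated_view + base[stop:],
    )


base = [0, 1, 2, 3, 4, 5]
meta = make_slice_meta(2, 4)
view = meta.forward(base)                    # the slice [2, 3]
mutated_view = [v + 100 for v in view]       # out-of-place "mutation" of the view
new_base = meta.reverse(base, mutated_view)  # replay the mutation onto the base
```

Replaying `forward` on `new_base` gives back the mutated view, which is the round-trip property the sync logic depends on.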

(2) Every tensor during the functionalization pass is always wrapped in a `FunctionalTensorWrapper`.

This is potentially unnecessary for Vulkan/XLA, and will have a mild perf impact, but for now this PR just targets the functorch use case. I previously had a complicated design (a `FunctionalTensorImplBase` class) to avoid needing the wrapper for XLA, but it had some subtleties that are gonna require more thought to fix, so I'm pushing that off for now.

(3) `FunctionalTensorWrapper` objects accurately report stride information.

It's a little annoying to do this though, because the logic that calculates stride info for each view isn't easily separated from the actual view kernels in core, `at::native::{view}`. I do this by adding logic in each `at::functionalization::{view}` kernel to call the reference implementation `at::native::{view}`. I don't do anything with the output aside from taking its size/stride/storage_offset to set the actual output tensor's size/stride/storage_offset correctly. There's another annoying part to this: I'm pretty sure that we want to pass in the actual *wrapper* tensors directly into the native kernels, not their inner unwrapped values. But there are some `at::native::{view}` kernels that call other tensor methods, which re-invokes the dispatcher, calling functionalization/functorch kernels that try to do the unwrapping.

To work around this, right now I have an `AutoDispatchDirectlyToNative` guard that basically ensures that any tensor methods called inside of the `at::native::{view}` op always redispatch straight to the CPU kernel (which will be another `at::native::` kernel). This feels kind of heavy handed, but I'm not sure of a better way to do it.
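
For intuition about what "only taking the size/stride metadata" means, here is a small sketch that computes the metadata of a transpose without touching any data. The helper names are hypothetical; the PR gets this information by calling the real `at::native::{view}` kernels rather than reimplementing it.

```python
def contiguous_strides(sizes):
    # Row-major strides: the last dimension has stride 1.
    strides = []
    running = 1
    for size in reversed(sizes):
        strides.append(running)
        running *= size
    return list(reversed(strides))


def transpose_meta(sizes, strides, dim0, dim1):
    # A transpose only swaps the size and stride entries of the two dims.
    sizes, strides = list(sizes), list(strides)
    sizes[dim0], sizes[dim1] = sizes[dim1], sizes[dim0]
    strides[dim0], strides[dim1] = strides[dim1], strides[dim0]
    return sizes, strides


sizes = [2, 3, 4]
strides = contiguous_strides(sizes)
new_sizes, new_strides = transpose_meta(sizes, strides, 0, 2)
```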


(4) `FunctionalTensorWrapper` objects accurately report aliasing information.

There's a new `FunctionalStorageImpl` class (subclass of `StorageImpl`) that allows tensors in the functionalization pass to accurately alias storage. If two tensors `a` and `b` in a functionalized program are views of one another, then `a.storage.is_alias_of(b.storage)` should return true. I added this in a pretty similar way to how meta tensors allocate storage, although I don't pass in an actual allocator (I think this is fine because you should never resize a functional tensor's storage). 
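
The aliasing contract can be pictured with a toy model in which a view shares its base's storage object and `is_alias_of` is an identity check. `ToyStorage` and `ToyTensor` are hypothetical and say nothing about how `FunctionalStorageImpl` is actually implemented.

```python
class ToyStorage:
    def is_alias_of(self, other):
        # Two storages alias only if they are the same object.
        return self is other


class ToyTensor:
    def __init__(self, storage=None):
        # A fresh tensor gets its own storage; a view reuses its base's storage.
        self.storage = storage if storage is not None else ToyStorage()

    def view(self):
        return ToyTensor(self.storage)


a = ToyTensor()
b = a.view()     # shares storage with a
c = ToyTensor()  # unrelated tensor
```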

One thing I'm not sure about - should `FunctionalTensorWrapper` set `storage_access_should_throw_`: (a) always, (b) never, (c) only if its wrapped tensor has it set.

Right now I have it not set, mostly because calling the reference view functions (`at::native::{view}`) requires looking at the storage. But that means that if you try to access storage from python in a functionalized program, you'll get silent garbage instead of an error. Related question: are we planning on exposing meta tensor storage to python in the future (even though it contains garbage)?

(5) better docs :)

**View operator coverage**

(6) The functionalization pass now gets math-composite view ops for free.

I didn't add the `Functionalize` dispatch key to the composite set, because I don't want composite ops like `torch.ones` to get decomposed before hitting the functionalization pass. Instead, I added codegen to manually register the `at::native::` kernels of composite view ops. This is a little hairy, because the names of the `at::native::` kernels aren't easily accessible. They're stored in a `Dict[DispatchKey, BackendIndex]`. I made a best-effort attempt to get each view kernel's name, basically by assuming that every view op has either a composite or cpu implementation.
There's also a hardcoded list of composite view ops in `gen_inplace_or_view_type.py`, but it looks like it's wrong. This is probably worth rationalizing later, but instead I created a new list of the "complete" set of composite view ops, and preserved the old set by hardcoding the delta between the two sets.

(7) I've added codegen for ops that are both views AND mutations, like `transpose_()` (why do we even have these 😢).

From some light testing, it looks like they work correctly with one caveat: I had a hard time ensuring that functorch programs that mutate their inputs using ops like `transpose_()` preserve the input mutations after the program finishes running. For now (in my corresponding functorch branch) I emit a warning when this happens, and just don't preserve the mutation.

(8) I added `{view}_inverse` implementations for every view op, in `FunctionalInverses.cpp`.

These are needed to take mutations made to views and replay them back onto the base. To reduce boilerplate, the codegen generates function declarations for each `{view}_inverse` function, so you get a nice compiler error when someone eventually adds a new view op.

The only view ops currently not supported are (a) as_strided, and (b) the sparse view ops (values()/indices()).

I can add support for as_strided, but it needs an `as_strided_inverse()` function. That will look really similar to the `as_strided_backward()` function in FunctionsManual.cpp, but it has some noticeable differences: we basically want an `as_strided_embed` for autograd and `as_strided_scatter` for functionalization. We also will probably need them to be primitives w.r.t. autograd, since the current implementation for autograd uses view().copy_() calls that XLA won't be able to handle. I'm wondering if anyone has any objections, but otherwise I can make those changes (which will require writing backward formulas for `as_strided_embed` and `as_strided_scatter`).

I did a bunch of manual testing that all looks pretty good, but it's definitely not fully tested. Ed pointed out that once XLA uses this pass (or at least once there's a POC), we can just run the existing xla view test suite. Hopefully that delay is okay - if it's not, maybe we can think about using OpInfos similar to how functorch uses them for testing.

Note: there's some duplication with autograd's view code. Every `{view}_inverse` implementation is really similar to the implementation for that view listed in `derivatives.yaml`. There are some major differences though:
* the autograd implementations of those backward functions (like `permute_backwards()`, in `FunctionsManual.cpp`) internally call other view ops. For functionalization, we want them to (eventually) call `{view}_copy` operators.
* For view ops that take a subset of the original storage, like `slice/select/diagonal/as_strided()`, the autograd backward functions fill the "spaces" in the inverse call with zeroes. For functionalization, we want to fill them with the value of `base` at those positions. It looks like this currently applies to 6 total ops (since we can ignore composites):
  * select
  * slice
  * diagonal
  * as_strided
  * split
  * split_with_sizes

A nice end state would probably be for the autograd + functionalization codegen to both look at the same yaml (either `derivatives.yaml`, or something else), and automatically generate the right thing. I didn't leave that in scope for this PR though.
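
The zero-fill versus base-fill difference can be sketched for a 1-D `select` on Python lists. Both function names are hypothetical and only mimic the shape of the autograd backward and the functionalization inverse; they are not the implementations in `FunctionsManual.cpp` or `FunctionalInverses.cpp`.

```python
def select_backward_like(base_len, grad, index):
    # Autograd style: positions the view did not cover are filled with zeros.
    out = [0] * base_len
    out[index] = grad
    return out


def select_inverse_like(base, mutated_view, index):
    # Functionalization style: untouched positions keep the value from base.
    out = list(base)
    out[index] = mutated_view
    return out


base = [5, 6, 7, 8]
zero_filled = select_backward_like(len(base), 42, 1)
base_filled = select_inverse_like(base, 42, 1)
```

Only the second form yields an updated base that still agrees with the original everywhere outside the view.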



**Current State + Next Steps**

There are a bunch of followups after this PR eventually lands. Roughly in order:
* Use the current pass to register problematic composite ops in functorch. Also, nested `functionalize()` calls aren't supported yet (I mostly just need to remove some debug asserts and test it).
* Work on freeing up dispatch key space by deduplicating the `{backend}`/`Autograd{backend}`/`Sparse{backend}`/`Quantized{backend}` keys.
* Once we have more dispatch keys, split up this pass into 3 pieces - it's currently fused, and doesn't do the right thing for vulkan/XLA. Specifically, all of the `{view}` calls in the current pass's view-replay logic should turn into `{view}_copy` calls that vulkan/XLA know how to implement, and there will be separate passes for (a) removing mutations, and (b) turning `{view}_copy` calls back into `{view}` calls. For Vulkan, we eventually want a pass that ONLY removes aliasing and view calls, and doesn't remove mutations. We can also probably make the 2 new passes user dispatch keys to save dispatch key space, if they'll only be used by functorch anyway.
* Do more of a dive on perf for the vulkan/xla use cases. There are several areas to improve perf with varying levels of effort required. The simplest one that I'll probably do regardless is to codegen the out-of-place kernels instead of using a boxed fallback. Getting a POC working for xla will also be useful to test the view operator coverage.


**Example Codegen Output**

View Op:
```
::std::vector<at::Tensor> split_Tensor(c10::DispatchKeySet ks, const at::Tensor & self, int64_t split_size, int64_t dim) {

      auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
      ::std::vector<at::Tensor> out;
      {
        at::AutoDispatchBelowFunctionalize guard;
        auto tmp_output = at::redispatch::split(ks & c10::after_func_keyset, self_, split_size, dim);
        out = at::functionalization::impl::wrapFunctionalTensor(tmp_output);
        // I'm fusing the [alias removal], [mutation removal], [add views back] passes together.
        // Later, we'll want to turn them into separate passes (since e.g. vulkan only cares about alias removal).
      }

      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        [split_size, dim](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
          return base.split(split_size, dim)[mutated_view_idx];
        },
        [split_size, dim](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
          return at::functionalization::impl::split_inverse(base, mutated_view, mutated_view_idx, split_size, dim);
        }
      );
      at::functionalization::impl::set_view_meta(out, self, view_meta);

      at::AutoDispatchDirectlyToNative native_guard;
      ::std::vector<at::Tensor> reference_tensor_output = at::native::split(self, split_size, dim);
      at::functionalization::impl::set_strides(out, reference_tensor_output);
      return out;

}
```

Mutation Op:
```
at::Tensor & add__Tensor(c10::DispatchKeySet ks, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {

      at::functionalization::impl::sync(self);
      at::functionalization::impl::sync(other);
      auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
      auto other_ = at::functionalization::impl::unwrapFunctionalTensor(other);
      at::Tensor tmp_output;
      {
          at::AutoDispatchBelowFunctionalize guard;
          // The functionalization pass explicitly doesn't pass out= parameters to the redispatch
          tmp_output = at::redispatch::add(
            ks & c10::after_func_keyset, self_, other_, alpha);
      }

      self.replace_(tmp_output);
      at::functionalization::impl::maybe_add_update(self);
      return self;
}
```

View + Mutation Op:
```
at::Tensor & transpose_(c10::DispatchKeySet ks, at::Tensor & self, int64_t dim0, int64_t dim1) {


      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        [dim0, dim1](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
          return base.transpose(dim0, dim1);
        },
        [dim0, dim1](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
          return at::functionalization::impl::transpose_inverse(base, mutated_view, dim0, dim1);
        }
      );
      at::functionalization::impl::mutate_view_meta(self, view_meta);
      // See  Note [Propagating strides in the functionalization pass]
      // Directly update the sizes/strides/storage_offset fields on self using the inplace call.
      // I need the guard because I don't want the at::native kernel to end up calling more functionalization/functorch kernels.
      // Its only job is to directly compute the output size/stride/storage_offset metadata.
      at::AutoDispatchDirectlyToNative native_guard;
      at::native::transpose_(self, dim0, dim1);
      return self;

}
```




[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Oct 25, 2021
[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Oct 26, 2021
Original PR description + feedback here: #63048

I've addressed all of the feedback in the original PR and made some pretty large changes, listed below.


**Table of Contents**
- Starting points
- List of the main changes from the original PR
- Next Steps
- Example codegen output (for a view, mutation, and view+mutation op)


**Starting Points**

A good place to start when looking through the PR:
* Alban mentioned that this is a useful mental model (thanks Ed for originally making this clear to me). Semantically, the pass currently does THREE things, which are all needed by functorch - all fused together into one big pass.
  * (a) alias removal, which replaces {view} calls with {view}_copy calls, and manually tracks aliasing information, so that when one tensor is mutated, we re-apply the same mutation to all of the aliases. This is the bulk of the work - once this is done, the next 2 things are trivial to implement.
  * (b) mutation removal, which is easy to do once we know that there are no aliases. Every mutation `a.add_(b)` becomes `a.replace_(a.add(b))`
  * (c) reapplying views: all of the `{view}_copy` calls are replaced with `{view}` calls again. This is an optimization that we can make specifically for functorch (and strided backends), which only care about mutation removal and not alias removal.
  * XLA and Vulkan only want (a), or (a) + (b). Later, we'll want to split this out so that you can actually opt into different versions of this logic.
  * There is currently no {view}_copy replacement, because the <replace views with copies> and <replace copies with views> steps have been combined into one pass. Later, we'll want to actually implement {view}_copy variants of each view operator, probably with codegen.
* documentation breadcrumb 1, in `FunctionalTensorWrapper.cpp`: https://github.com/pytorch/pytorch/pull/64432/files#diff-a0bac99bf205dba5b94cb64fc2466d3d55d991887572f9cd6a02e27b3a91dd60R59 (you might have to expand the `FunctionalTensorWrapper.cpp` file, which GitHub closes by default because it's large)
* documentation breadcrumb 2, in `FunctionalTensorWrapper.h`: https://github.com/pytorch/pytorch/pull/64432/files#diff-c945c71a4ccac65871f24a912e8904f9a5088b24a32e636727ea9c8fe920708aR12
* Reading through the codegen output at the bottom of this description.


**Main changes from the original PR**

(1) I use lambdas instead of a giant enum to handle all of the different views.

This results in less boilerplate per view op (and more stuff that can be codegen'd). Every `ViewMeta` object now contains a `forward` and a `reverse` lambda that know how to replay the view and its inverse. This makes the actual code that executes the replaying logic a lot less boilerplate-y (see `Alias::sync_update_operations` and `FunctionalTensorWrapper::sync_`).

(2) Every tensor during the functionalization pass is always wrapped in a `FunctionalTensorWrapper`.

This is potentially unnecessary for Vulkan/XLA, and will have a mild perf impact, but for now this PR just targets the functorch use case. I previously had a complicated design (a `FunctionalTensorImplBase` class) to avoid needing the wrapper for XLA, but it had some subtleties that are gonna require more thought to fix, so I'm pushing that off for now.

(3) `FunctionalTensorWrapper` objects accurately report stride information.

It's a little annoying to do this though, because the logic that calculates stride info for each view isn't easily separated from the actual view kernels in core, `at::native::{view}`. I do this by adding logic in each `at::functionalization::{view}` kernel to call the reference implementation `at::native::{view}`. I don't do anything with the output aside from taking its size/stride/storage_offset to set the actual output tensor's size/stride/storage_offset correctly. There's another annoying part to this: I'm pretty sure that we want to pass in the actual *wrapper* tensors directly into the native kernels, not their inner unwrapped values. But there are some `at::native::{view}` kernels that call other tensor methods, which re-invokes the dispatcher, calling functionalization/functorch kernels that try to do the unwrapping.

To do this, right now I have an `AutoDispatchDirectlyToNative` guard that basically ensures that any tensor methods called inside of the at::native::{view} op always redispatch straight to the CPU kernel (which will be another at::native:: kernel). This feels kind of heavy handed, but I'm not sure of a better way to do it.


(4) `FunctionalTensorWrapper` objects accurately report aliasing information.

There's a new `FunctionalStorageImpl` class (subclass of `StorageImpl`) that allows tensors in the functionalization pass to accurately alias storage. If two tensors `a` and `b` in a functionalized program are views of one another, then `a.storage.is_alias_of(b.storage)` should return true. I added this in a pretty similar way to how meta tensors allocate storage, although I don't pass in an actual allocator (I think this is fine because you should never resize a functional tensor's storage). 
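
A toy model of the aliasing check, in pure Python. Both classes here are hypothetical stand-ins (the real `FunctionalStorageImpl` is a C++ subclass of `StorageImpl`); the only point is that views share one storage object, so aliasing reduces to an identity check and the storage never needs to hold real data.

```python
class FunctionalStorage:
    """Toy storage: only its identity matters, it holds no data."""

    def is_alias_of(self, other):
        return self is other


class Wrapper:
    """Toy wrapper tensor that shares its storage with every view of it."""

    def __init__(self, storage=None):
        self.storage = storage if storage is not None else FunctionalStorage()

    def view(self):
        # A view reuses the base's storage object instead of allocating one.
        return Wrapper(self.storage)


a = Wrapper()
b = a.view()
c = Wrapper()
print(a.storage.is_alias_of(b.storage))  # True
print(a.storage.is_alias_of(c.storage))  # False
```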

One thing I'm not sure about - should `FunctionalTensorWrapper` set `storage_access_should_throw_`: (a) always, (b) never, (c) only if its wrapped tensor has it set.

Right now I have it not set, mostly because calling the reference view functions (`at::native::{view}`) requires looking at the storage. But that means that if you try to access storage from python in a functionalized program, you'll get silent garbage instead of an error. Related question: are we planning on exposing meta tensor storage to python in the future (even though it contains garbage)?

(5) better docs :)

**View operator coverage**

(6) The functionalization pass now gets math-composite view ops for free.

I didn't add the `Functionalize` dispatch key to the composite set, because I don't want composite ops like `torch.ones` to get decomposed before hitting the functionalization pass. Instead, I added codegen to manually register the `at::native::` kernels of composite view ops. This is a little hairy, because the names of the `at::native::` kernels aren't easily accessible. They're stored in a `Dict[DispatchKey, BackendIndex]`. I made a best-effort attempt to get each view kernel's name, basically by assuming that every view op has either a composite or cpu implementation.
There's also a hardcoded list of composite view ops in `gen_inplace_or_view_type.py`, but it looks like it's wrong. This is probably worth rationalizing later, but instead I created a new list of the "complete" set of composite view ops, and preserved the old set by hardcoding the delta between the two sets.

(7) I've added codegen for ops that are both views AND mutations, like `transpose_()` (why do we even have these 😢).

From some light testing, it looks like they work correctly with one caveat: I had a hard time ensuring that functorch programs that mutate their inputs using ops like `transpose_()` preserve the input mutations after the program finishes running. For now (in my corresponding functorch branch) I emit a warning when this happens, and just don't preserve the mutation.

(8) I added `{view}_inverse` implementations for every view op, in `FunctionalInverses.cpp`.

These are needed to take mutations made to views and replay them back onto the base. To reduce boilerplate, the codegen generates function declarations for each `{view}_inverse` function, so you get a nice compiler error when someone eventually adds a new view op.

The only view ops currently not supported are (a) as_strided, and (b) the sparse view ops (values()/indices()).

I can add support for as_strided, but it needs an `as_strided_inverse()` function. That will look really similar to the `as_strided_backward()` function in FunctionsManual.cpp, but it has some noticeable differences: we basically want an `as_strided_embed` for autograd and `as_strided_scatter` for functionalization. We also will probably need them to be primitives w.r.t. autograd, since the current implementation for autograd uses view().copy_() calls that XLA won't be able to handle. I'm wondering if anyone has any objections, but otherwise I can make those changes (which will require writing backward formulas for `as_strided_embed` and `as_strided_scatter`).

I did a bunch of manual testing that all looks pretty good, but it's definitely not fully tested. Ed pointed out that once XLA uses this pass (or at least once there's a POC), we can just run the existing xla view test suite. Hopefully that delay is okay - if it's not, maybe we can think about using OpInfos similar to how functorch uses them for testing.

Note: there's some duplication with autograd's view code. Every `{view}_inverse` implementation is really similar to the implementation for that view listed in `derivatives.yaml`. There are some major differences though:
* the autograd implementations of those backward functions (like `permute_backwards()`, in `FunctionsManual.cpp`) internally call other view ops. For functionalization, we want them to (eventually) call `{view}_copy` operators.
* For view ops that take a subset of the original storage, like `slice/select/diagonal/as_strided()`, the autograd backward functions fill the "spaces" in the inverse call with zeroes. For functionalization, we want to fill them with the value of `base` at those positions. It looks like this currently applies to 6 total ops (since we can ignore composites):
  * select
  * slice
  * diagonal
  * as_strided
  * split
  * split_with_sizes
A nice end state would probably be for the autograd + functionalization codegen to both look at the same yaml (either `derivatives.yaml`, or something else), and automatically generate the right thing. I didn't leave that in scope for this PR though.
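
The zero-fill vs. base-fill difference for subset views can be seen in a small pure-Python sketch over 1-D lists. `toy_select_backward` and `toy_select_inverse` are hypothetical helpers written for this example; they are not the signatures of the real functions in `FunctionsManual.cpp` or `FunctionalInverses.cpp`.

```python
def toy_select_backward(grad, size, index):
    # Autograd-style: positions outside the view get zero gradient.
    out = [0] * size
    out[index] = grad
    return out


def toy_select_inverse(base, mutated_view, index):
    # Functionalization-style: positions outside the view keep base's values.
    out = list(base)
    out[index] = mutated_view
    return out


base = [1, 2, 3, 4]
print(toy_select_backward(7, len(base), 2))  # [0, 0, 7, 0]
print(toy_select_inverse(base, 7, 2))        # [1, 2, 7, 4]
```

Both functions write the same value into the same slot; they only disagree on what to put everywhere else, which is why a shared yaml description could plausibly drive both.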



**Current State + Next Steps**

There are a bunch of followups after this PR eventually lands. Roughly in order:
* Use the current pass to register problematic composite ops in functorch. Also, nested `functionalize()` calls aren't supported yet (I mostly just need to remove some debug asserts and test it).
* Work on freeing up dispatch key space by deduplicating the `{backend}`/`Autograd{backend}`/`Sparse{backend}`/`Quantized{backend}` keys
* Once we have more dispatch keys, split up this pass into 3 pieces - it's currently fused, and doesn't do the right thing for vulkan/XLA. Specifically, all of the `{view}` calls in the current pass's view-replay logic should turn into `{view}_copy` calls that vulkan/XLA know how to implement, and there will be separate passes for (a) removing mutations, and (b) turning `{view}_copy` calls back into `{view}` calls. For Vulkan, we eventually want a pass that ONLY removes aliasing and view calls, and doesn't remove mutations. We can also probably make the 2 new passes user dispatch keys to save dispatch key space, if they'll only be used by functorch anyway.
* Do more of a dive on perf for the vulkan/xla use cases. There are several areas to improve perf with varying levels of effort required. The simplest one that I'll probably do regardless is to codegen the out-of-place kernels instead of using a boxed fallback. Getting a POC working for xla will also be useful to test the view operator coverage.


**Example Codegen Output**

View Op:
```
::std::vector<at::Tensor> split_Tensor(c10::DispatchKeySet ks, const at::Tensor & self, int64_t split_size, int64_t dim) {

      auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
      ::std::vector<at::Tensor> out;
      {
        at::AutoDispatchBelowFunctionalize guard;
        auto tmp_output = at::redispatch::split(ks & c10::after_func_keyset, self_, split_size, dim);
        out = at::functionalization::impl::wrapFunctionalTensor(tmp_output);
        // I'm fusing the [alias removal], [mutation removal], [add views back] passes together.
        // Later, we'll want to turn them into separate passes (since e.g. vulkan only cares about alias removal).
      }

      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        [split_size, dim](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
          return base.split(split_size, dim)[mutated_view_idx];
        },
        [split_size, dim](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
          return at::functionalization::impl::split_inverse(base, mutated_view, mutated_view_idx, split_size, dim);
        }
      );
      at::functionalization::impl::set_view_meta(out, self, view_meta);

      at::AutoDispatchDirectlyToNative native_guard;
      ::std::vector<at::Tensor> reference_tensor_output = at::native::split(self, split_size, dim);
      at::functionalization::impl::set_strides(out, reference_tensor_output);
      return out;

}
```

Mutation Op:
```
at::Tensor & add__Tensor(c10::DispatchKeySet ks, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {

      at::functionalization::impl::sync(self);
      at::functionalization::impl::sync(other);
      auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
      auto other_ = at::functionalization::impl::unwrapFunctionalTensor(other);
      at::Tensor tmp_output;
      {
          at::AutoDispatchBelowFunctionalize guard;
          // The functionalization pass explicitly doesn't pass out= parameters to the redispatch
          tmp_output = at::redispatch::add(
            ks & c10::after_func_keyset, self_, other_, alpha);
      }

      self.replace_(tmp_output);
      at::functionalization::impl::maybe_add_update(self);
      return self;
}
```

View + Mutation Op:
```
at::Tensor & transpose_(c10::DispatchKeySet ks, at::Tensor & self, int64_t dim0, int64_t dim1) {


      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        [dim0, dim1](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
          return base.transpose(dim0, dim1);
        },
        [dim0, dim1](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
          return at::functionalization::impl::transpose_inverse(base, mutated_view, dim0, dim1);
        }
      );
      at::functionalization::impl::mutate_view_meta(self, view_meta);
      // See  Note [Propagating strides in the functionalization pass]
      // Directly update the sizes/strides/storage_offset fields on self using the inplace call.
      // I need the guard because I don't want the at::native kernel to end up calling more functionalization/functorch kernels.
      // Its only job is to directly compute the output size/stride/storage_offset metadata.
      at::AutoDispatchDirectlyToNative native_guard;
      at::native::transpose_(self, dim0, dim1);
      return self;

}
```




[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Oct 27, 2021
Original PR description + feedback here: #63048

Differential Revision: [D31942093](https://our.internmc.facebook.com/intern/diff/D31942093)

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Oct 27, 2021
Original PR description + feedback here: #63048

I've addressed all of the feedback in the original PR and made some pretty large changes, listed below.


**Table of Contents**
- Starting points
- List of the main changes from the original PR
- Next Steps
- Example codegen output (for a view, mutation, and view+mutation op)


**Starting Points**

A good place to start when looking through the PR:
* Alban mentioned that this is a useful mental model (thanks Ed for originally making this clear to me). Semantically, the pass currently does THREE things, which are all needed by functorch - all fused together into one big pass.
  * (a) alias removal, which replaces {view} calls with {view}_copy calls, and manually tracks aliasing information, so that when one tensor is mutated, we re-apply the same mutation to all of the aliases. This is the bulk of the work - once this is done, the next 2 things are trivial to implement.
  * (b) mutation removal, which is easy to do once we know that there are no aliases. Every mutation `a.add_(b)` becomes `a.replace_(a.add(b))`
  * (c) reapplying views: all of the `{view}_copy` calls are replaced with `{view}` calls again. This is an optimization that we can make specifically for functorch (and strided backends), that only care about mutation removal and not alias removal
  * XLA and Vulkan only want (a), or (a) + (b). Later, we'll want to split this out so that you can actually opt into different versions of this logic.
  * There is currently no {view}_copy replacement, because the <replace views with copies> and <replace copies with views> steps have been combined. Later, we'll want to actually implement {view}_copy variants of each view operator, probably with codegen.
* documentation breadcrumb 1, in `FunctionalTensorWrapper.cpp`: https://github.com/pytorch/pytorch/pull/64432/files#diff-a0bac99bf205dba5b94cb64fc2466d3d55d991887572f9cd6a02e27b3a91dd60R59 (you might have to expand the `FunctionalTensorWrapper.cpp` file, which GitHub closes by default because it's large)
* documentation breadcrumb 2, in `FunctionalTensorWrapper.h`: https://github.com/pytorch/pytorch/pull/64432/files#diff-c945c71a4ccac65871f24a912e8904f9a5088b24a32e636727ea9c8fe920708aR12
* Reading through the codegen output at the bottom of this description.


**Main changes from the original PR**

(1)  I use lambdas instead of a giant enum to handle all of the different views.

This results in less boilerplate per view op (and more stuff that can be codegen'd). Every `ViewMeta` object now contains a `forward` and `reverse` lambda, that knows how to replay the view and its inverse. This makes the actual code that executes the replaying logic a lot less boilerplate-y (see `Alias::sync_update_operations` and `FunctionalTensorWrapper::sync_`)

(2) Every tensor during the functionalization pass is always wrapped in a `FunctionalTensorWrapper`.

This is potentially unnecessary for Vulkan/XLA, and will have a mild perf impact, but for now this PR just targets the functorch use case. I previously had a complicated design (a `FunctionalTensorImplBase` class) to avoid needing the wrapper for XLA, but it had some subtleties that are gonna require more thought to fix, so I'm pushing that off for now.

(3) `FunctionalTensorWrapper` objects accurately report stride information.

It's a little annoying to do this though, because the logic that calculates stride info for each view isn't easily separated from the actual view kernels in core, `at::native::{view}`. I do this by adding logic in each `at::functionalization::{view}` kernel to call the reference implementation `at::native::{view}`. I don't do anything with the output aside from taking its size/stride/storage_offset to set the actual output tensor's size/stride/storage_offset correctly. There's another annoying part to this: I'm pretty sure that we want to pass in the actual *wrapper* tensors directly into the native kernels, not their inner unwrapped values. But there are some `at::native::{view}` kernels that call other tensor methods, which re-invokes the dispatcher, calling functionalization/functorch kernels that try to do the unwrapping.

To do this, right now I have an `AutoDispatchDirectlyToNative` guard that basically ensures that any tensor methods called inside of the at::native::{view} op always redispatch straight to the CPU kernel (which will be another at::native:: kernel). This feels kind of heavy handed, but I'm not sure of a better way to do it.


(4) `FunctionalTensorWrapper` objects accurately report aliasing information.

There's a new `FunctionalStorageImpl` class (subclass of `StorageImpl`) that allows tensors in the functionalization pass to accurately alias storage. If two tensors `a` and `b` in a functionalized program are views of one another, then `a.storage.is_alias_of(b.storage)` should return true. I added this in a pretty similar way to how meta tensors allocate storage, although I don't pass in an actual allocator (I think this is fine because you should never resize a functional tensor's storage). 

One thing I'm not sure about - should `FunctionalTensorWrapper` set `storage_access_should_throw_`: (a) always, (b) never, (c) only if its wrapped tensor has it set.

Right now I have it not set, mostly because calling the reference view functions (`at::native::{view}`) requires looking at the storage. But that means that if you try to access storage from python in a functionalized program, you'll get silent garbage instead of an error. Related question: are we planning on exposing meta tensor storage to python in the future (even though it contains garbage)?

(5) better docs :)

**View operator coverage**

(6) The functionalization pass now gets math-composite view ops for free.

I didn't add the `Functionalize` dispatch key to the composite set, because I don't want composite ops like `torch.ones` to get decomposed before hitting the functionalization pass. Instead, I added codegen to manually register the `at::native::` kernels of composite view ops. This is a little hairy, because the names of the `at::native::` kernels aren't easily accessible. They're stored in a `Dict[DispatchKey, BackendIndex]`. I made a best-effort attempt to get each view kernel's name, basically by assuming that every view op has either a composite or cpu implementation.
There's also a hardcoded list of composite view ops in `gen_inplace_or_view_type.py`, but it looks like it's wrong. This is probably worth rationalizing later, but instead I created a new list of the "complete" set of composite view ops, and preserved the old set by hardcoding the delta between the two sets.

(7) I've added codegen for ops that are both views AND mutations, like `transpose_()` (why do we even have these 😢).

From some light testing, it looks like they work correctly with one caveat: I had a hard time ensuring that functorch programs that mutate their inputs using ops like `transpose_()` preserve the input mutations after the program finishes running. For now (in my corresponding functorch branch) I emit a warning when this happens, and just don't preserve the mutation.

(8) I added `{view}_inverse` implementations for every view op, in `FunctionalInverses.cpp`.

These are needed to take mutations made to views and replay them back onto the base. To reduce boilerplate, the codegen generates function declarations for each `{view}_inverse` function, so you get a nice compiler error when someone eventually adds a new view op.

The only view ops currently not supported are (a) as_strided, and (b) the sparse view ops (values()/indices()).

I can add support for as_strided, but it needs an `as_strided_inverse()` function. That will look really similar to the `as_strided_backward()` function in FunctionsManual.cpp, but it has some noticeable differences: we basically want an `as_strided_embed` for autograd and `as_strided_scatter` for functionalization. We also will probably need them to be primitives w.r.t. autograd, since the current implementation for autograd uses view().copy_() calls that XLA won't be able to handle. I'm wondering if anyone has any objections, but otherwise I can make those changes (which will require writing backward formulas for `as_strided_embed` and `as_strided_scatter`).

I did a bunch of manual testing that all looks pretty good, but it's definitely not fully tested. Ed pointed out that once XLA uses this pass (or at least once there's a POC), we can just run the existing xla view test suite. Hopefully that delay is okay - if it's not, maybe we can think about using OpInfos similar to how functorch uses them for testing.

Note: there's some duplication with autograd's view code. Every `{view}_inverse` implementation is really similar to the implementation for that view listed in `derivatives.yaml`. There are some major differences though:
* the autograd implementations of those backward functions (like `permute_backwards()`, in `FunctionsManual.cpp`) internally call other view ops. For functionalization, we want them to (eventually) call `{view}_copy` operators.
* For view ops that take a subset of the original storage, like `slice/select/diagonal/as_strided()`, the autograd backward functions fill the "spaces" in the inverse call with zeroes. For functionalization, we want to fill them with the value of `base` at those positions. It looks like this currently applies to 6 total ops (since we can ignore composites):
  * select
  * slice
  * diagonal
  * as_strided
  * split
  * split_with_sizes
A nice end state would probably be for the autograd + functionalization codegen to both look at the same yaml (either `derivatives.yaml`, or something else), and automatically generate the right thing. I didn't leave that in scope for this PR though.
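
The zero-fill versus base-fill difference noted above can be made concrete with a toy 1-D slice. This is a sketch only: plain Python lists stand in for tensors, and `slice_backward` / `slice_inverse` here are hypothetical simplified helpers, not the functions in `FunctionsManual.cpp` or `FunctionalInverses.cpp`.

```python
def slice_backward(grad, input_len, start, end):
    """Autograd-style: embed the slice's gradient into a tensor of zeros."""
    out = [0] * input_len
    out[start:end] = grad
    return out


def slice_inverse(base, mutated_view, start, end):
    """Functionalization-style: keep base's values outside the slice."""
    out = list(base)  # copy, so base itself is never mutated
    out[start:end] = mutated_view
    return out


base = [1, 2, 3, 4, 5]
print(slice_backward([7, 8], len(base), 1, 3))  # [0, 7, 8, 0, 0]
print(slice_inverse(base, [7, 8], 1, 3))        # [1, 7, 8, 4, 5]
print(base)                                     # [1, 2, 3, 4, 5]
```

Both helpers write the same values into the sliced region; they differ only in what fills the positions the view did not cover.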



**Current State + Next Steps**

There are a bunch of followups after this PR eventually lands. Roughly in order:
* Use the current pass to register problematic composite ops in functorch. Also, nested `functionalize()` calls aren't supported yet (I mostly just need to remove some debug asserts and test it).
* Work on freeing up dispatch key space by deduplicating the `{backend}`/`Autograd{backend}`/`Sparse{backend}`/`Quantized{backend}` keys
* Once we have more dispatch keys, split up this pass into 3 pieces - it's currently fused, and doesn't do the right thing for vulkan/XLA. Specifically, all of the `{view}` calls in the current pass's view-replay logic should turn into `{view}_copy` calls that vulkan/XLA know how to implement, and there will be separate passes for (a) removing mutations, and (b) turning `{view}_copy` calls back into `{view}` calls. For Vulkan, we eventually want a pass that ONLY removes aliasing and view calls, and doesn't remove mutations. We can also probably make the 2 new passes user dispatch keys to save dispatch key space, if they'll only be used by functorch anyway.
* Do more of a dive on perf for the vulkan/xla use cases. There are several areas to improve perf with varying levels of effort required. The simplest one that I'll probably do regardless is to codegen the out-of-place kernels instead of using a boxed fallback. Getting a POC working for xla will also be useful to test the view operator coverage.


**Example Codegen Output**

View Op:
```
::std::vector<at::Tensor> split_Tensor(c10::DispatchKeySet ks, const at::Tensor & self, int64_t split_size, int64_t dim) {

      auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
      ::std::vector<at::Tensor> out;
      {
        at::AutoDispatchBelowFunctionalize guard;
        auto tmp_output = at::redispatch::split(ks & c10::after_func_keyset, self_, split_size, dim);
        out = at::functionalization::impl::wrapFunctionalTensor(tmp_output);
        // I'm fusing the [alias removal], [mutation removal], [add views back] passes together.
        // Later, we'll want to turn them into separate passes (since e.g. vulkan only cares about alias removal).
      }

      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        [split_size, dim](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
          return base.split(split_size, dim)[mutated_view_idx];
        },
        [split_size, dim](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
          return at::functionalization::impl::split_inverse(base, mutated_view, mutated_view_idx, split_size, dim);
        }
      );
      at::functionalization::impl::set_view_meta(out, self, view_meta);

      at::AutoDispatchDirectlyToNative native_guard;
      ::std::vector<at::Tensor> reference_tensor_output = at::native::split(self, split_size, dim);
      at::functionalization::impl::set_strides(out, reference_tensor_output);
      return out;

}
```

Mutation Op:
```
at::Tensor & add__Tensor(c10::DispatchKeySet ks, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {

      at::functionalization::impl::sync(self);
      at::functionalization::impl::sync(other);
      auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
      auto other_ = at::functionalization::impl::unwrapFunctionalTensor(other);
      at::Tensor tmp_output;
      {
          at::AutoDispatchBelowFunctionalize guard;
          // The functionalization pass explicitly doesn't pass out= parameters to the redispatch
          tmp_output = at::redispatch::add(
            ks & c10::after_func_keyset, self_, other_, alpha);
      }

      self.replace_(tmp_output);
      at::functionalization::impl::maybe_add_update(self);
      return self;
}
```

View + Mutation Op:
```
at::Tensor & transpose_(c10::DispatchKeySet ks, at::Tensor & self, int64_t dim0, int64_t dim1) {


      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        [dim0, dim1](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
          return base.transpose(dim0, dim1);
        },
        [dim0, dim1](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
          return at::functionalization::impl::transpose_inverse(base, mutated_view, dim0, dim1);
        }
      );
      at::functionalization::impl::mutate_view_meta(self, view_meta);
      // See  Note [Propagating strides in the functionalization pass]
      // Directly update the sizes/strides/storage_offset fields on self using the inplace call.
      // I need the guard because I don't want the at::native kernel to end up calling more functionalization/functorch kernels.
      // Its only job is to directly compute the output size/stride/storage_offset metadata.
      at::AutoDispatchDirectlyToNative native_guard;
      at::native::transpose_(self, dim0, dim1);
      return self;

}
```


Differential Revision: [D31942093](https://our.internmc.facebook.com/intern/diff/D31942093)

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Oct 27, 2021
Original PR description + feedback here: #63048

I've addressed all of the feedback in the original PR and made some pretty large changes, listed below.


**Table of Contents**
- Starting points
- List of the main changes from the original PR
- Next Steps
- Example codegen output (for a view, mutation, and view+mutation op)


**Starting Points**

A good place to start when looking through the PR:
* Alban mentioned that this is a useful mental model (thanks Ed for originally making this clear to me). Semantically, the pass currently does THREE things, which are all needed by functorch - all fused together into one big pass.
  * (a) alias removal, which replaces {view} calls with {view}_copy calls, and manually tracks aliasing information, so that when one tensor is mutated, we re-apply the same mutation to all of the aliases. This is the bulk of the work - once this is done, the next 2 things are trivial to implement.
  * (b) mutation removal, which is easy to do once we know that there are no aliases. Every mutation `a.add_(b)` becomes `a.replace_(a.add(b))`
  * (c) reapplying views: all of the `{view}_copy` calls are replaced with `{view}` calls again. This is an optimization that we can make specifically for functorch (and strided backends), that only care about mutation removal and not alias removal
  * XLA and Vulkan only want (a), or (a) + (b). Later, we'll want to split this out so that you can actually opt into different versions of this logic.
  * There is currently no {view}_copy replacement, because the <replace views with copies> and <replace copies with views> steps have been combined. Later, we'll want to actually implement {view}_copy variants of each view operator, probably with codegen.
* documentation breadcrumb 1, in `FunctionalTensorWrapper.cpp`: https://github.com/pytorch/pytorch/pull/64432/files#diff-a0bac99bf205dba5b94cb64fc2466d3d55d991887572f9cd6a02e27b3a91dd60R59 (you might have to expand the `FunctionalTensorWrapper.cpp` file, which GitHub closes by default because it's large)
* documentation breadcrumb 2, in `FunctionalTensorWrapper.h`: https://github.com/pytorch/pytorch/pull/64432/files#diff-c945c71a4ccac65871f24a912e8904f9a5088b24a32e636727ea9c8fe920708aR12
* Reading through the codegen output at the bottom of this description.
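
To make step (b) of the mental model above concrete, here is a toy sketch in plain Python. It is not the PR's code: `FunctionalWrapper` is a hypothetical stand-in for the wrapper tensor, and immutable tuples stand in for tensor values. It only shows the rewrite `a.add_(b)` into `a.replace_(a.add(b))`, where the wrapper swaps in a freshly computed value instead of mutating the old one.

```python
class FunctionalWrapper:
    """Hypothetical stand-in for a wrapped tensor holding an immutable value."""

    def __init__(self, value):
        self.value = tuple(value)

    def add(self, other):
        # Out-of-place op: builds a new value and touches nothing else.
        return tuple(x + y for x, y in zip(self.value, other.value))

    def replace_(self, new_value):
        # Swap the wrapped value; no element is mutated in place.
        self.value = tuple(new_value)

    def add_(self, other):
        # The mutation-removal rewrite: a.add_(b) -> a.replace_(a.add(b))
        self.replace_(self.add(other))
        return self


a = FunctionalWrapper([1, 2, 3])
b = FunctionalWrapper([10, 20, 30])
old = a.value
a.add_(b)
print(a.value)  # (11, 22, 33)
print(old)      # (1, 2, 3)
```

From the caller's side `a` still looks mutated, while the program underneath only ever ran the out-of-place `add`.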


**Main changes from the original PR**

(1)  I use lambdas instead of a giant enum to handle all of the different views.

This results in less boilerplate per view op (and more stuff that can be codegen'd). Every `ViewMeta` object now contains a `forward` and `reverse` lambda, that knows how to replay the view and its inverse. This makes the actual code that executes the replaying logic a lot less boilerplate-y (see `Alias::sync_update_operations` and `FunctionalTensorWrapper::sync_`)
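
The forward/reverse pairing can be sketched with a toy model. This is an illustration under stated assumptions, not the actual `ViewMeta` class: Python lists stand in for tensors, `slice_meta` and `apply_update` are hypothetical helpers, and the replay loop is only a guess at the general shape of the real sync logic.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class ViewMeta:
    forward: Callable[[List[int]], List[int]]             # base -> view
    reverse: Callable[[List[int], List[int]], List[int]]  # (base, mutated view) -> new base


def slice_meta(start, end):
    def forward(base):
        return base[start:end]

    def reverse(base, mutated_view):
        # Rebuild base with the mutated view written back into its slot.
        return base[:start] + mutated_view + base[end:]

    return ViewMeta(forward, reverse)


def apply_update(base, metas, mutated_view):
    # Recompute each intermediate view, then undo the views innermost-first.
    intermediates = [base]
    for meta in metas[:-1]:
        intermediates.append(meta.forward(intermediates[-1]))
    current = mutated_view
    for meta, parent in zip(reversed(metas), reversed(intermediates)):
        current = meta.reverse(parent, current)
    return current


base = [0, 1, 2, 3, 4, 5]
metas = [slice_meta(1, 5), slice_meta(0, 2)]  # base[1:5][0:2] == [1, 2]
new_base = apply_update(base, metas, [10, 20])
print(new_base)  # [0, 10, 20, 3, 4, 5]
print(base)      # [0, 1, 2, 3, 4, 5]
```

Because each view carries its own pair of closures, the replay loop never needs to know which view op it is handling, which is the boilerplate saving that an enum-based design would not give.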

(2) Every tensor during the functionalization pass is always wrapped in a `FunctionalTensorWrapper`.

This is potentially unnecessary for Vulkan/XLA, and will have a mild perf impact, but for now this PR just targets the functorch use case. I previously had a complicated design (a `FunctionalTensorImplBase` class) to avoid needing the wrapper for XLA, but it had some subtleties that are gonna require more thought to fix, so I'm pushing that off for now.

(3) `FunctionalTensorWrapper` objects accurately report stride information.

It's a little annoying to do this though, because the logic that calculates stride info for each view isn't easily separated from the actual view kernels in core, `at::native::{view}`. I do this by adding logic in each `at::functionalization::{view}` kernel to call the reference implementation `at::native::{view}`. I don't do anything with the output aside from taking its size/stride/storage_offset to set the actual output tensor's size/stride/storage_offset correctly. There's another annoying part to this: I'm pretty sure that we want to pass in the actual *wrapper* tensors directly into the native kernels, not their inner unwrapped values. But there are some `at::native::{view}` kernels that call other tensor methods, which re-invokes the dispatcher, calling functionalization/functorch kernels that try to do the unwrapping.

To work around this, right now I have an `AutoDispatchDirectlyToNative` guard that basically ensures that any tensor methods called inside of the `at::native::{view}` op always redispatch straight to the CPU kernel (which will be another `at::native::` kernel). This feels kind of heavy-handed, but I'm not sure of a better way to do it.


(4) `FunctionalTensorWrapper` objects accurately report aliasing information.

There's a new `FunctionalStorageImpl` class (subclass of `StorageImpl`) that allows tensors in the functionalization pass to accurately alias storage. If two tensors `a` and `b` in a functionalized program are views of one another, then `a.storage.is_alias_of(b.storage)` should return true. I added this in a pretty similar way to how meta tensors allocate storage, although I don't pass in an actual allocator (I think this is fine because you should never resize a functional tensor's storage). 

One thing I'm not sure about - should `FunctionalTensorWrapper` set `storage_access_should_throw_`: (a) always, (b) never, (c) only if its wrapped tensor has it set.

Right now I have it not set, mostly because calling the reference view functions (`at::native::{view}`) requires looking at the storage. But that means that if you try to access storage from python in a functionalized program, you'll get silent garbage instead of an error. Related question: are we planning on exposing meta tensor storage to python in the future (even though it contains garbage)?

(5) better docs :)

**View operator coverage**

(6) The functionalization pass now gets math-composite view ops for free.

I didn't add the `Functionalize` dispatch key to the composite set, because I don't want composite ops like `torch.ones` to get decomposed before hitting the functionalization pass. Instead, I added codegen to manually register the `at::native::` kernels of composite view ops. This is a little hairy, because the names of the `at::native::` kernels aren't easily accessible. They're stored in a `Dict[DispatchKey, BackendIndex]`. I made a best-effort attempt to get each view kernel's name, basically by assuming that every view op has either a composite or cpu implementation.
There's also a hardcoded list of composite view ops in `gen_inplace_or_view_type.py`, but it looks like it's wrong. This is probably worth rationalizing later, but instead I created a new list of the "complete" set of composite view ops, and preserved the old set by hardcoding the delta between the two sets.

(7) I've added codegen for ops that are both views AND mutations, like `transpose_()` (why do we even have these 😢).

From some light testing, it looks like they work correctly with one caveat: I had a hard time ensuring that functorch programs that mutate their inputs using ops like `transpose_()` preserve the input mutations after the program finishes running. For now (in my corresponding functorch branch) I emit a warning when this happens, and just don't preserve the mutation.

(8) I added `{view}_inverse` implementations for every view op, in `FunctionalInverses.cpp`.

These are needed to take mutations made to views and replay them back onto the base. To reduce boilerplate, the codegen generates function declarations for each `{view}_inverse` function, so you get a nice compiler error when someone eventually adds a new view op.

The only view ops currently not supported are (a) as_strided, and (b) the sparse view ops (values()/indices()).

I can add support for as_strided, but it needs an `as_strided_inverse()` function. That will look really similar to the `as_strided_backward()` function in FunctionsManual.cpp, but it has some noticeable differences: we basically want an `as_strided_embed` for autograd and `as_strided_scatter` for functionalization. We also will probably need them to be primitives w.r.t. autograd, since the current implementation for autograd uses view().copy_() calls that XLA won't be able to handle. I'm wondering if anyone has any objections, but otherwise I can make those changes (which will require writing backward formulas for `as_strided_embed` and `as_strided_scatter`).

I did a bunch of manual testing that all looks pretty good, but it's definitely not fully tested. Ed pointed out that once XLA uses this pass (or at least once there's a POC), we can just run the existing xla view test suite. Hopefully that delay is okay - if it's not, maybe we can think about using OpInfos similar to how functorch uses them for testing.

Note: there's some duplication with autograd's view code. Every `{view}_inverse` implementation is really similar to the implementation for that view listed in `derivatives.yaml`. There are some major differences though:
* the autograd implementations of those backward functions (like `permute_backwards()`, in `FunctionsManual.cpp`) internally call other view ops. For functionalization, we want them to (eventually) call `{view}_copy` operators.
* For view ops that take a subset of the original storage, like `slice/select/diagonal/as_strided()`, the autograd backward functions fill the "spaces" in the inverse call with zeroes. For functionalization, we want to fill them with the value of `base` at those positions. It looks like this currently applies to 6 total ops (since we can ignore composites):
  * select
  * slice
  * diagonal
  * as_strided
  * split
  * split_with_sizes
A nice end state would probably be for the autograd + functionalization codegen to both look at the same yaml (either `derivatives.yaml`, or something else), and automatically generate the right thing. I didn't leave that in scope for this PR though.



**Current State + Next Steps**

There are a bunch of followups after this PR eventually lands. Roughly in order:
* Use the current pass to register problematic composite ops in functorch. Also, nested `functionalize()` calls aren't supported yet (I mostly just need to remove some debug asserts and test it).
* Work on freeing up dispatch key space by deduplicating the `{backend}`/`Autograd{backend}`/`Sparse{backend}`/`Quantized{backend}` keys
* Once we have more dispatch keys, split up this pass into 3 pieces - it's currently fused, and doesn't do the right thing for vulkan/XLA. Specifically, all of the `{view}` calls in the current pass's view-replay logic should turn into `{view}_copy` calls that vulkan/XLA know how to implement, and there will be separate passes for (a) removing mutations, and (b) turning `{view}_copy` calls back into `{view}` calls. For Vulkan, we eventually want a pass that ONLY removes aliasing and view calls, and doesn't remove mutations. We can also probably make the 2 new passes user dispatch keys to save dispatch key space, if they'll only be used by functorch anyway.
* Do more of a dive on perf for the vulkan/xla use cases. There are several areas to improve perf with varying levels of effort required. The simplest one that I'll probably do regardless is to codegen the out-of-place kernels instead of using a boxed fallback. Getting a POC working for xla will also be useful to test the view operator coverage.


**Example Codegen Output**

View Op:
```
::std::vector<at::Tensor> split_Tensor(c10::DispatchKeySet ks, const at::Tensor & self, int64_t split_size, int64_t dim) {

      auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
      ::std::vector<at::Tensor> out;
      {
        at::AutoDispatchBelowFunctionalize guard;
        auto tmp_output = at::redispatch::split(ks & c10::after_func_keyset, self_, split_size, dim);
        out = at::functionalization::impl::wrapFunctionalTensor(tmp_output);
        // I'm fusing the [alias removal], [mutation removal], [add views back] passes together.
        // Later, we'll want to turn them into separate passes (since e.g. vulkan only cares about alias removal).
      }

      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        [split_size, dim](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
          return base.split(split_size, dim)[mutated_view_idx];
        },
        [split_size, dim](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
          return at::functionalization::impl::split_inverse(base, mutated_view, mutated_view_idx, split_size, dim);
        }
      );
      at::functionalization::impl::set_view_meta(out, self, view_meta);

      at::AutoDispatchDirectlyToNative native_guard;
      ::std::vector<at::Tensor> reference_tensor_output = at::native::split(self, split_size, dim);
      at::functionalization::impl::set_strides(out, reference_tensor_output);
      return out;

}
```

Mutation Op:
```
at::Tensor & add__Tensor(c10::DispatchKeySet ks, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {

      at::functionalization::impl::sync(self);
      at::functionalization::impl::sync(other);
      auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
      auto other_ = at::functionalization::impl::unwrapFunctionalTensor(other);
      at::Tensor tmp_output;
      {
          at::AutoDispatchBelowFunctionalize guard;
          // The functionalization pass explicitly doesn't pass out= parameters to the redispatch
          tmp_output = at::redispatch::add(
            ks & c10::after_func_keyset, self_, other_, alpha);
      }

      self.replace_(tmp_output);
      at::functionalization::impl::maybe_add_update(self);
      return self;
}
```

View + Mutation Op:
```
at::Tensor & transpose_(c10::DispatchKeySet ks, at::Tensor & self, int64_t dim0, int64_t dim1) {


      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        [dim0, dim1](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
          return base.transpose(dim0, dim1);
        },
        [dim0, dim1](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
          return at::functionalization::impl::transpose_inverse(base, mutated_view, dim0, dim1);
        }
      );
      at::functionalization::impl::mutate_view_meta(self, view_meta);
      // See  Note [Propagating strides in the functionalization pass]
      // Directly update the sizes/strides/storage_offset fields on self using the inplace call.
      // I need the guard because I don't want the at::native kernel to end up calling more functionalization/functorch kernels.
      // Its only job is to directly compute the output size/stride/storage_offset metadata.
      at::AutoDispatchDirectlyToNative native_guard;
      at::native::transpose_(self, dim0, dim1);
      return self;

}
```


Differential Revision: [D31942093](https://our.internmc.facebook.com/intern/diff/D31942093)

[ghstack-poisoned]
facebook-github-bot pushed a commit that referenced this pull request Oct 28, 2021
Summary:
Pull Request resolved: #64432

Original PR description + feedback here: #63048

I've addressed all of the feedback in the original PR and made some pretty large changes, listed below.

**Table of Contents**
- Starting points
- List of the main changes from the original PR
- Next Steps
- Example codegen output (for a view, mutation, and view+mutation op)

**Starting Points**

A good place to start when looking through the PR:
* Alban mentioned that this is a useful mental model (thanks Ed for originally making this clear to me). Semantically, the pass currently does THREE things, which are all needed by functorch - all fused together into one big pass.
  * (a) alias removal, which replaces {view} calls with {view}_copy calls, and manually tracks aliasing information, so that when one tensor is mutated, we re-apply the same mutation to all of the aliases. This is the bulk of the work - once this is done, the next 2 things are trivial to implement.
  * (b) mutation removal, which is easy to do once we know that there are no aliases. Every mutation `a.add_(b)` becomes `a.replace_(a.add(b))`
  * (c) reapplying views: all of the `{view}_copy` calls are replaced with `{view}` calls again. This is an optimization that we can make specifically for functorch (and strided backends), that only care about mutation removal and not alias removal
  * XLA and Vulkan only want (a), or (a) + (b). Later, we'll want to split this out so that you can actually opt into different versions of this logic.
  * There is currently no {view}_copy replacement, because the <replace views with copies> and <replace copies with views> steps have been combined. Later, we'll want to actually implement {view}_copy variants of each view operator, probably with codegen.
* documentation breadcrumb 1, in `FunctionalTensorWrapper.cpp`: https://github.com/pytorch/pytorch/pull/64432/files#diff-a0bac99bf205dba5b94cb64fc2466d3d55d991887572f9cd6a02e27b3a91dd60R59 (you might have to expand the `FunctionalTensorWrapper.cpp` file, which GitHub closes by default because it's large)
* documentation breadcrumb 2, in `FunctionalTensorWrapper.h`: https://github.com/pytorch/pytorch/pull/64432/files#diff-c945c71a4ccac65871f24a912e8904f9a5088b24a32e636727ea9c8fe920708aR12
* Reading through the codegen output at the bottom of this description.

**Main changes from the original PR**

(1)  I use lambdas instead of a giant enum to handle all of the different views.

This results in less boilerplate per view op (and more stuff that can be codegen'd). Every `ViewMeta` object now contains a `forward` and `reverse` lambda, that knows how to replay the view and its inverse. This makes the actual code that executes the replaying logic a lot less boilerplate-y (see `Alias::sync_update_operations` and `FunctionalTensorWrapper::sync_`)

(2) Every tensor during the functionalization pass is always wrapped in a `FunctionalTensorWrapper`.

This is potentially unnecessary for Vulkan/XLA, and will have a mild perf impact, but for now this PR just targets the functorch use case. I previously had a complicated design (a `FunctionalTensorImplBase` class) to avoid needing the wrapper for XLA, but it had some subtleties that are gonna require more thought to fix, so I'm pushing that off for now.

(3) `FunctionalTensorWrapper` objects accurately report stride information.

It's a little annoying to do this though, because the logic that calculates stride info for each view isn't easily separated from the actual view kernels in core, `at::native::{view}`. I do this by adding logic in each `at::functionalization::{view}` kernel to call the reference implementation `at::native::{view}`. I don't do anything with the output aside from taking its size/stride/storage_offset to set the actual output tensor's size/stride/storage_offset correctly. There's another annoying part to this: I'm pretty sure that we want to pass in the actual *wrapper* tensors directly into the native kernels, not their inner unwrapped values. But there are some `at::native::{view}` kernels that call other tensor methods, which re-invokes the dispatcher, calling functionalization/functorch kernels that try to do the unwrapping.

To work around this, right now I have an `AutoDispatchDirectlyToNative` guard that basically ensures that any tensor methods called inside of the `at::native::{view}` op always redispatch straight to the CPU kernel (which will be another `at::native::` kernel). This feels kind of heavy-handed, but I'm not sure of a better way to do it.

(4) `FunctionalTensorWrapper` objects accurately report aliasing information.

There's a new `FunctionalStorageImpl` class (subclass of `StorageImpl`) that allows tensors in the functionalization pass to accurately alias storage. If two tensors `a` and `b` in a functionalized program are views of one another, then `a.storage.is_alias_of(b.storage)` should return true. I added this in a pretty similar way to how meta tensors allocate storage, although I don't pass in an actual allocator (I think this is fine because you should never resize a functional tensor's storage).

One thing I'm not sure about - should `FunctionalTensorWrapper` set `storage_access_should_throw_`: (a) always, (b) never, (c) only if its wrapped tensor has it set.

Right now I have it not set, mostly because calling the reference view functions (`at::native::{view}`) requires looking at the storage. But that means that if you try to access storage from python in a functionalized program, you'll get silent garbage instead of an error. Related question: are we planning on exposing meta tensor storage to python in the future (even though it contains garbage)?

(5) better docs :)

**View operator coverage**

(6) The functionalization pass now gets math-composite view ops for free.

I didn't add the `Functionalize` dispatch key to the composite set, because I don't want composite ops like `torch.ones` to get decomposed before hitting the functionalization pass. Instead, I added codegen to manually register the `at::native::` kernels of composite view ops. This is a little hairy, because the names of the `at::native::` kernels aren't easily accessible. They're stored in a `Dict[DispatchKey, BackendIndex]`. I made a best-effort attempt to get each view kernel's name, basically by assuming that every view op has either a composite or cpu implementation.
There's also a hardcoded list of composite view ops in `gen_inplace_or_view_type.py`, but it looks like it's wrong. This is probably worth rationalizing later, but instead I created a new list of the "complete" set of composite view ops, and preserved the old set by hardcoding the delta between the two sets.

(7) I've added codegen for ops that are both views AND mutations, like `transpose_()` (why do we even have these 😢).

From some light testing, it looks like they work correctly with one caveat: I had a hard time ensuring that functorch programs that mutate their inputs using ops like `transpose_()` preserve the input mutations after the program finishes running. For now (in my corresponding functorch branch) I emit a warning when this happens, and just don't preserve the mutation.

(8) I added `{view}_inverse` implementations for every view op, in `FunctionalInverses.cpp`.

These are needed to take mutations made to views and replay them back onto the base. To reduce boilerplate, the codegen generates function declarations for each `{view}_inverse` function, so you get a nice compiler error when someone eventually adds a new view op.
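To make "replaying a mutation back onto the base" concrete, here is a minimal pure-Python sketch that uses nested lists in place of tensors. The `transpose` and `transpose_inverse` functions below are hypothetical stand-ins that only borrow the naming used in this PR; they are not the kernels in `FunctionalInverses.cpp`.

```python
# Toy model: a 2-D "tensor" is a list of rows. Nothing here uses PyTorch.

def transpose(base):
    """Out-of-place stand-in for a view op: returns the transposed matrix."""
    return [list(col) for col in zip(*base)]

def transpose_inverse(base, mutated_view):
    """Map a mutated transpose-view back to a new value for `base`.

    A transpose covers every element of `base`, so the inverse is just
    another transpose. `base` is accepted only to keep a uniform signature.
    """
    return transpose(mutated_view)

base = [[1, 2, 3],
        [4, 5, 6]]
view = transpose(base)                                   # [[1, 4], [2, 5], [3, 6]]
mutated_view = [[v + 10 for v in row] for row in view]   # functional "view.add_(10)"
new_base = transpose_inverse(base, mutated_view)
print(new_base)  # [[11, 12, 13], [14, 15, 16]]
```

The original `base` list is never modified; the replay produces a new value that the wrapper would adopt as the updated base.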

The only view ops currently not supported are (a) as_strided, and (b) the sparse view ops (values()/indices()).

I can add support for as_strided, but it needs an `as_strided_inverse()` function. That will look really similar to the `as_strided_backward()` function in FunctionsManual.cpp, but with some noticeable differences: we basically want an `as_strided_embed` for autograd and an `as_strided_scatter` for functionalization. We will also probably need them to be primitives w.r.t. autograd, since the current implementation for autograd uses view().copy_() calls that XLA won't be able to handle. I'm wondering if anyone has any objections, but otherwise I can make those changes (which will require writing backward formulas for `as_strided_embed` and `as_strided_scatter`).

I did a bunch of manual testing that all looks pretty good, but it's definitely not fully tested. Ed pointed out that once XLA uses this pass (or at least once there's a POC), we can just run the existing xla view test suite. Hopefully that delay is okay - if it's not, maybe we can think about using OpInfos similar to how functorch uses them for testing.

Note: there's some duplication with autograd's view code. Every `{view}_inverse` implementation is really similar to the implementation for that view listed in `derivatives.yaml`. There are some major differences though:
* The autograd implementations of those backward functions (like `permute_backwards()`, in `FunctionsManual.cpp`) internally call other view ops. For functionalization, we want them to (eventually) call `{view}_copy` operators.
* For view ops that take a subset of the original storage, like `slice/select/diagonal/as_strided()`, the autograd backward functions fill the "spaces" in the inverse call with zeroes. For functionalization, we want to fill them with the value of `base` at those positions. It looks like this currently applies to 6 total ops (since we can ignore composites):
  * select
  * slice
  * diagonal
  * as_strided
  * split
  * split_with_sizes
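
The zero-fill versus base-fill difference can be sketched with plain Python lists. Both function names below are hypothetical and chosen only for this illustration; the real autograd and functionalization kernels operate on tensors and have different signatures.

```python
# Toy 1-D model using plain lists; the function names are hypothetical.

def slice_backward_zero_fill(grad_view, base_len, start, end):
    """Autograd-style: positions outside the slice are filled with zeros."""
    out = [0] * base_len
    out[start:end] = grad_view
    return out

def slice_inverse_base_fill(base, mutated_view, start, end):
    """Functionalization-style: positions outside the slice keep base's values."""
    out = list(base)  # copy, so `base` itself is never mutated
    out[start:end] = mutated_view
    return out

base = [10, 20, 30, 40, 50]
mutated_view = [v + 1 for v in base[1:3]]  # [21, 31]

zero_filled = slice_backward_zero_fill(mutated_view, len(base), 1, 3)
base_filled = slice_inverse_base_fill(base, mutated_view, 1, 3)
print(zero_filled)  # [0, 21, 31, 0, 0]
print(base_filled)  # [10, 21, 31, 40, 50]
```

Only the base-fill result is a valid new value for `base` after a view mutation, which is why the autograd backward functions can't be reused as-is for these ops.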
A nice end state would probably be for the autograd + functionalization codegen to both look at the same yaml (either `derivatives.yaml`, or something else), and automatically generate the right thing. I didn't leave that in scope for this PR though.

**Current State + Next Steps**

There are a bunch of followups after this PR eventually lands. Roughly in order:
* Use the current pass to register problematic composite ops in functorch. Also, nested `functionalize()` calls aren't supported yet (I mostly just need to remove some debug asserts and test it).
* Work on freeing up dispatch key space by deduplicating the `{backend}`/`Autograd{backend}`/`Sparse{backend}`/`Quantized{backend}` keys
* Once we have more dispatch keys, split up this pass into 3 pieces - it's currently fused, and doesn't do the right thing for vulkan/XLA. Specifically, all of the `{view}` calls in the current pass's view-replay logic should turn into `{view}_copy` calls that vulkan/XLA know how to implement, and there will be separate passes for (a) removing mutations, and (b) turning `{view}_copy` calls back into `{view}` calls. For Vulkan, we eventually want a pass that ONLY removes aliasing and view calls, and doesn't remove mutations. We can also probably make the 2 new passes user dispatch keys to save dispatch key space, if they'll only be used by functorch anyway.
* Do more of a dive on perf for the vulkan/xla use cases. There are several areas to improve perf with varying levels of effort required. The simplest one that I'll probably do regardless is to codegen the out-of-place kernels instead of using a boxed fallback. Getting a POC working for xla will also be useful to test the view operator coverage.

**Example Codegen Output**

View Op:
```
::std::vector<at::Tensor> split_Tensor(c10::DispatchKeySet ks, const at::Tensor & self, int64_t split_size, int64_t dim) {

      auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
      ::std::vector<at::Tensor> out;
      {
        at::AutoDispatchBelowFunctionalize guard;
        auto tmp_output = at::redispatch::split(ks & c10::after_func_keyset, self_, split_size, dim);
        out = at::functionalization::impl::wrapFunctionalTensor(tmp_output);
        // I'm fusing the [alias removal], [mutation removal], [add views back] passes together.
        // Later, we'll want to turn them into separate passes (since e.g. vulkan only cares about alias removal).
      }

      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        [split_size, dim](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
          return base.split(split_size, dim)[mutated_view_idx];
        },
        [split_size, dim](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
          return at::functionalization::impl::split_inverse(base, mutated_view, mutated_view_idx, split_size, dim);
        }
      );
      at::functionalization::impl::set_view_meta(out, self, view_meta);

      at::AutoDispatchDirectlyToNative native_guard;
      ::std::vector<at::Tensor> reference_tensor_output = at::native::split(self, split_size, dim);
      at::functionalization::impl::set_strides(out, reference_tensor_output);
      return out;

}
```
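
For readers who don't want to parse the C++, the pair of lambdas stored in the `ViewMeta` above can be modeled roughly as follows. This is a toy Python sketch over lists, assuming equal-sized chunks along a single dimension; the `ViewMeta` dataclass, `split`, and `split_inverse` here are hypothetical stand-ins, not the real classes.

```python
from dataclasses import dataclass
from typing import Callable, List

Vec = List[int]

def split(base: Vec, split_size: int) -> List[Vec]:
    """Out-of-place stand-in for a multi-output view op."""
    return [base[i:i + split_size] for i in range(0, len(base), split_size)]

def split_inverse(base: Vec, mutated_view: Vec, mutated_view_idx: int, split_size: int) -> Vec:
    """Write one mutated chunk back into a copy of `base`."""
    out = list(base)
    start = mutated_view_idx * split_size
    out[start:start + len(mutated_view)] = mutated_view
    return out

@dataclass
class ViewMeta:
    # forward_fn(base, idx) regenerates output `idx` of the view from the base.
    forward_fn: Callable[[Vec, int], Vec]
    # reverse_fn(base, mutated_view, idx) computes the new base after a mutation.
    reverse_fn: Callable[[Vec, Vec, int], Vec]

split_size = 2
view_meta = ViewMeta(
    forward_fn=lambda base, idx: split(base, split_size)[idx],
    reverse_fn=lambda base, mutated_view, idx: split_inverse(base, mutated_view, idx, split_size),
)

base = [1, 2, 3, 4, 5, 6]
idx = 1
view = view_meta.forward_fn(base, idx)                     # [3, 4]
mutated_view = [v * 10 for v in view]                      # [30, 40]
new_base = view_meta.reverse_fn(base, mutated_view, idx)   # [1, 2, 30, 40, 5, 6]
regenerated = view_meta.forward_fn(new_base, idx)          # [30, 40]
print(new_base, regenerated)
```

The `mutated_view_idx` argument matters only for multi-output views like `split`, where the inverse needs to know which output was mutated.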

Mutation Op:
```
at::Tensor & add__Tensor(c10::DispatchKeySet ks, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {

      at::functionalization::impl::sync(self);
      at::functionalization::impl::sync(other);
      auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
      auto other_ = at::functionalization::impl::unwrapFunctionalTensor(other);
      at::Tensor tmp_output;
      {
          at::AutoDispatchBelowFunctionalize guard;
          // The functionalization pass explicitly doesn't pass out= parameters to the redispatch
          tmp_output = at::redispatch::add(
            ks & c10::after_func_keyset, self_, other_, alpha);
      }

      self.replace_(tmp_output);
      at::functionalization::impl::maybe_add_update(self);
      return self;
}
```
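
The shape of the mutation kernel above (unwrap, call the out-of-place op, swap the result into the wrapper) can be sketched in a few lines of Python. `FunctionalWrapper` and the free functions below are hypothetical and only mimic the control flow; they skip `sync()` and the alias bookkeeping done by `maybe_add_update()`.

```python
class FunctionalWrapper:
    """Hypothetical stand-in for a functional tensor wrapper."""

    def __init__(self, value):
        self.value = tuple(value)  # the inner value is immutable

    def replace_(self, new_value):
        # Swap in a freshly computed value instead of mutating the old one.
        self.value = tuple(new_value)

def add(a, b, alpha=1):
    """Out-of-place add on tuples."""
    return tuple(x + alpha * y for x, y in zip(a, b))

def add_(self_w, other_w, alpha=1):
    """'In-place' add: compute out-of-place, then replace the wrapped value."""
    tmp_output = add(self_w.value, other_w.value, alpha)
    self_w.replace_(tmp_output)
    return self_w

x = FunctionalWrapper([1, 2, 3])
y = FunctionalWrapper([10, 20, 30])
old_inner = x.value
result = add_(x, y, alpha=2)
print(result is x)  # True
print(x.value)      # (21, 42, 63)
print(old_inner)    # (1, 2, 3)
```

From the caller's point of view `x` was mutated, but the layer underneath the wrapper only ever saw the out-of-place `add`.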

View + Mutation Op:
```
at::Tensor & transpose_(c10::DispatchKeySet ks, at::Tensor & self, int64_t dim0, int64_t dim1) {

      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        [dim0, dim1](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
          return base.transpose(dim0, dim1);
        },
        [dim0, dim1](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
          return at::functionalization::impl::transpose_inverse(base, mutated_view, dim0, dim1);
        }
      );
      at::functionalization::impl::mutate_view_meta(self, view_meta);
      // See  Note [Propagating strides in the functionalization pass]
      // Directly update the sizes/strides/storage_offset fields on self using the inplace call.
      // I need the guard because I don't want the at::native kernel to end up calling more functionalization/functorch kernels.
      // Its only job is to directly compute the output size/stride/storage_offset metadata.
      at::AutoDispatchDirectlyToNative native_guard;
      at::native::transpose_(self, dim0, dim1);
      return self;

}
```

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D31942093

Pulled By: bdhirsh

fbshipit-source-id: b95598dae35dd1842fa8b1d8d1448332f3afaadf
@github-actions
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.
