Relax some limitations of InferenceMode. #54403


Closed
wants to merge 26 commits

Conversation

ailzhang

@ailzhang ailzhang commented Mar 22, 2021

Stack from ghstack:

A few important points about InferenceMode behavior:

  1. All tensors created in InferenceMode are inference tensors, except for views of normal tensors.
    • View ops produce outputs with the same is_inference_tensor property as their input.
      Namely, a view of a normal tensor inside InferenceMode produces a normal tensor,
      exactly the same as creating a view inside NoGradMode, and a view of an
      inference tensor outside InferenceMode produces an inference tensor as output.
    • All inference tensors have requires_grad=False and is_leaf=True.
  2. All ops are allowed inside InferenceMode, and they run faster than in normal mode.
  3. Inference tensors cannot be saved for backward.
  4. Inference tensors don't have a version counter.
  5. There's no way to change an existing tensor from normal to inference,
    or vice versa.
  6. Leaf inference tensors with requires_grad=true can still have gradients.

Differential Revision: D27316483
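
A minimal C++ sketch of the behavior described above (an illustration, not code from the PR; it assumes the `c10::InferenceMode` RAII guard and the `c10/core/InferenceMode.h` header introduced by this stack):

```cpp
#include <torch/torch.h>
#include <c10/core/InferenceMode.h>

void inference_mode_sketch() {
  // A normal tensor created outside InferenceMode.
  torch::Tensor normal = torch::ones({2, 2}, torch::requires_grad(true));
  {
    c10::InferenceMode guard;                         // point 2: all ops allowed, with less autograd overhead
    torch::Tensor inf = torch::ones({2, 2});          // point 1: created inside InferenceMode -> inference tensor
    torch::Tensor view_of_normal = normal.view({4});  // point 1: view of a normal tensor stays a normal tensor
    // inf has requires_grad=False, is_leaf=True and no version counter (points 1 and 4).
  }
  // Points 3 and 5: outside InferenceMode, an inference tensor can't be saved for
  // backward, and there is no way to turn it back into a normal tensor.
}
```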

@facebook-github-bot

facebook-github-bot commented Mar 22, 2021

💊 CI failures summary and remediations

As of commit 972e766 (more details on the Dr. CI page):


  • 4/4 failures possibly* introduced in this PR
    • 2/4 non-scanned failure(s)

🕵️ 2 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_xla_linux_bionic_py3_6_clang9_test (1/2)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Apr 09 04:26:07 ERROR [0.006s]: TestViewOpsXLA (unittest.loader._FailedTest)
Apr 09 04:26:05 + XLA_EXPERIMENTAL=nonzero:masked_select
Apr 09 04:26:05 + run_test python3 /var/lib/jenkins/workspace/xla/test/../../test/test_view_ops.py -v TestViewOpsXLA
Apr 09 04:26:05 + python3 /var/lib/jenkins/workspace/xla/test/../../test/test_view_ops.py -v TestViewOpsXLA
Apr 09 04:26:07 Test results will be stored in test-reports/python-unittest/.var.lib.jenkins.workspace.xla.test.......test.test_view_ops
Apr 09 04:26:07 
Apr 09 04:26:07 Running tests...
Apr 09 04:26:07 ----------------------------------------------------------------------
Apr 09 04:26:07   TestViewOpsXLA (unittest.loader._FailedTest) ... ERROR (0.006s)
Apr 09 04:26:07 
Apr 09 04:26:07 ======================================================================
Apr 09 04:26:07 ERROR [0.006s]: TestViewOpsXLA (unittest.loader._FailedTest)
Apr 09 04:26:07 ----------------------------------------------------------------------
Apr 09 04:26:07 AttributeError: module '__main__' has no attribute 'TestViewOpsXLA'
Apr 09 04:26:07 
Apr 09 04:26:07 ----------------------------------------------------------------------
Apr 09 04:26:07 Ran 1 test in 0.007s
Apr 09 04:26:07 
Apr 09 04:26:07 FAILED (errors=1)
Apr 09 04:26:07 
Apr 09 04:26:07 Generating XML reports...
Apr 09 04:26:07 Generated XML report: test-reports/python-unittest/.var.lib.jenkins.workspace.xla.test.......test.test_view_ops/TEST-unittest.loader._FailedTest-20210409042607.xml

See CircleCI build pytorch_linux_bionic_rocm3_9_py3_6_build (2/2)

Step: "Build" (full log | diagnosis details | 🔁 rerun)

Apr 09 02:43:47 Error generating file
Apr 09 02:43:47         AT_CUDA_CHECK(cub::DeviceSegmentedReduce::Reduce(
Apr 09 02:43:47                       ^
Apr 09 02:43:47 /var/lib/jenkins/workspace/aten/src/ATen/native/hip/SegmentReduce.hip:90:23: error: use of undeclared identifier 'cub'
Apr 09 02:43:47         AT_CUDA_CHECK(cub::DeviceSegmentedReduce::Reduce(
Apr 09 02:43:47                       ^
Apr 09 02:43:47 /var/lib/jenkins/workspace/aten/src/ATen/native/hip/SegmentReduce.hip:105:23: error: use of undeclared identifier 'cub'
Apr 09 02:43:47         AT_CUDA_CHECK(cub::DeviceSegmentedReduce::Reduce(
Apr 09 02:43:47                       ^
Apr 09 02:43:47 18 errors generated when compiling for gfx900.
Apr 09 02:43:47 CMake Error at torch_hip_generated_SegmentReduce.hip.o.cmake:192 (message):
Apr 09 02:43:47   Error generating file
Apr 09 02:43:47   /var/lib/jenkins/workspace/build/caffe2/CMakeFiles/torch_hip.dir/__/aten/src/ATen/native/hip/./torch_hip_generated_SegmentReduce.hip.o
Apr 09 02:43:47 
Apr 09 02:43:47 
Apr 09 02:43:47 caffe2/CMakeFiles/torch_hip.dir/build.make:1195: recipe for target 'caffe2/CMakeFiles/torch_hip.dir/__/aten/src/ATen/native/hip/torch_hip_generated_SegmentReduce.hip.o' failed
Apr 09 02:43:47 make[2]: *** [caffe2/CMakeFiles/torch_hip.dir/__/aten/src/ATen/native/hip/torch_hip_generated_SegmentReduce.hip.o] Error 1
Apr 09 02:43:47 make[2]: *** Waiting for unfinished jobs....
Apr 09 02:43:48 In file included from /var/lib/jenkins/workspace/aten/src/ATen/native/hip/ForeachUnaryOp.hip:4:
Apr 09 02:43:48 In file included from /var/lib/jenkins/workspace/aten/src/ATen/native/hip/ForeachFunctors.cuh:3:
Apr 09 02:43:48 In file included from /var/lib/jenkins/workspace/aten/src/ATen/native/hip/MultiTensorApply.cuh:6:
Apr 09 02:43:48 In file included from /var/lib/jenkins/workspace/aten/src/ATen/native/hip/Loops.cuh:18:

ci.pytorch.org: 1 failed


This comment was automatically generated by Dr. CI (expand for details). Follow this link to opt out of these comments for your Pull Requests.

Please report bugs/suggestions to the (internal) Dr. CI Users group.

ailzhang pushed a commit that referenced this pull request Mar 22, 2021
doesn't bump version.

ghstack-source-id: 188f95c
Pull Request resolved: #54403
… long as it"

doesn't bump version.

[ghstack-poisoned]
ailzhang pushed a commit that referenced this pull request Mar 22, 2021
doesn't bump version.

ghstack-source-id: 25e1406
Pull Request resolved: #54403
@ailzhang ailzhang requested review from ezyang and bhosmer March 22, 2021 17:32
… long as it"

doesn't bump version.

[ghstack-poisoned]
ailzhang pushed a commit that referenced this pull request Mar 23, 2021
doesn't bump version.

ghstack-source-id: 22e561d
Pull Request resolved: #54403
@ezyang

ezyang commented Mar 24, 2021

I think there's a correctness problem with erroring on bumping, versus erroring on read. From correspondence:

It seems to me that mutating an inference tensor outside of inference mode is safe, as long as you haven't saved it for backwards.

But the reverse isn't safe: you can't save an inference tensor for backwards, because you might reenter an inference region (where in-place mutation could occur without bumping the VC at all).

Also, zeroing out the VC in the TensorImpl constructor seems a bit smelly to me, because that's not where the final setting of the VC gets done anyway for views (it can't be, because we need to share the VC in that situation)
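
A sketch of the hazard being described (a hypothetical sequence using the guard from this stack; the multiplication is commented out because it is exactly the case that must be rejected):

```cpp
#include <torch/torch.h>
#include <c10/core/InferenceMode.h>

void save_for_backward_hazard() {
  torch::Tensor inf;
  {
    c10::InferenceMode guard;
    inf = torch::ones({2, 2});  // inference tensor: its version counter is disabled
  }
  torch::Tensor w = torch::ones({2, 2}, torch::requires_grad(true));
  // If this were allowed, autograd would save `inf` for the backward of mul:
  // torch::Tensor out = w * inf;
  {
    c10::InferenceMode guard;
    inf.add_(1);  // in-place update inside InferenceMode does not bump any VC
  }
  // out.sum().backward();  // backward would silently use the mutated `inf`
}
```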

A few important points about InferenceMode behavior:
1. Inference tensors can only be created inside InferenceMode. But not
all tensors created in InferenceMode are inference tensors. Namely, a view
of a normal tensor inside InferenceMode produces a normal tensor, which is
exactly the same as creating a view inside NoGradMode.
2. All ops are allowed inside InferenceMode, and they run faster than in normal mode.
3. Inference tensors cannot be in-place updated outside InferenceMode.
4. It's not allowed to take views of an inference tensor outside
InferenceMode. (This can be relaxed if needed. See the comment in the
generated InplaceOrView_x.cpp.)

```
// Theoretically we can allow view ops on inference tensors in normal mode
// as long as we also mark the output as an inference tensor in this kernel.
// But it'll break an invariant we currently have: inference tensors can
// only be created inside InferenceMode.
// This invariant makes inference tensors easier for users to understand,
// so we should only break it when there's a valid use case.
TORCH_CHECK(!self.unsafeGetTensorImpl()->is_inference_tensor(),
  "Calling view ops on inference tensor outside InferenceMode is not allowed, ",
  "consider doing it inside inference mode to work around this error. ",
  "If you have a valid use case, please make a feature request to PyTorch.");
```

[ghstack-poisoned]
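
A sketch of the pattern that the TORCH_CHECK quoted above rejects (hypothetical example, assuming the same `c10::InferenceMode` guard from this stack):

```cpp
#include <torch/torch.h>
#include <c10/core/InferenceMode.h>

void view_of_inference_tensor_outside() {
  torch::Tensor inf;
  {
    c10::InferenceMode guard;
    inf = torch::ones({2, 2});  // inference tensor
  }
  // Outside InferenceMode, a view op on an inference tensor hits the check above:
  // torch::Tensor v = inf.view({4});  // "Calling view ops on inference tensor outside InferenceMode is not allowed..."
}
```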
@ailzhang ailzhang changed the title Allow mixing normal and inference tensor in normal mode as long as it Relax some limitations of InferenceMode. Mar 24, 2021
ailzhang pushed a commit that referenced this pull request Mar 24, 2021
A few important points about InferenceMode behavior:
1. Inference tensors can only be created inside InferenceMode. But not
all tensors created in InferenceMode are inference tensors. Namely, a view
of a normal tensor inside InferenceMode produces a normal tensor, which is
exactly the same as creating a view inside NoGradMode.
2. All ops are allowed inside InferenceMode, and they run faster than in normal mode.
3. Inference tensors cannot be in-place updated outside InferenceMode.
4. It's not allowed to take views of an inference tensor outside
InferenceMode. (This can be relaxed if needed. See the comment in the
as_view() function in VariableTypeUtils.h.)

```
// Theoretically we can allow view ops on inference tensors in normal mode
// as long as we also mark the output as an inference tensor in this kernel.
// But it'll break an invariant we currently have: inference tensors can
// only be created inside InferenceMode.
// This invariant makes inference tensors easier for users to understand,
// so we should only break it when there's a valid use case.
TORCH_CHECK(!self.unsafeGetTensorImpl()->is_inference_tensor(),
  "Calling view ops on inference tensor outside InferenceMode is not allowed, ",
  "consider doing it inside inference mode to work around this error. ",
  "If you have a valid use case, please make a feature request to PyTorch.");
```

ghstack-source-id: c520df2
Pull Request resolved: #54403
@ailzhang ailzhang requested a review from ezyang March 24, 2021 07:25
A few important points about InferenceMode behavior:
1. All tensors created in InferenceMode are inference tensors except
   for view ops.
   - View ops produce outputs with the same is_inference_tensor property
     as their input. Namely, a view of a normal tensor inside InferenceMode produces a normal tensor,
     exactly the same as creating a view inside NoGradMode, and a view of an
     inference tensor outside InferenceMode produces an inference tensor as
     output.
2. All ops are allowed inside InferenceMode, and they run faster than in normal mode.
3. Inference tensors cannot be saved for backward.

[ghstack-poisoned]
@ailzhang ailzhang changed the title Relax some limitations of InferenceMode. [WIP]Relax some limitations of InferenceMode. Mar 25, 2021
@@ -401,7 +415,9 @@ void TensorImpl::copy_tensor_metadata(
     const c10::VariableVersion& version_counter,
     bool allow_tensor_metadata_change) {
   copy_tensor_metadata_except_version_counter(src_impl, dest_impl, allow_tensor_metadata_change);
-  dest_impl->set_version_counter(version_counter);
+  if (!dest_impl->is_inference_tensor()) {
+    dest_impl->set_version_counter(version_counter);
Contributor

Perhaps this is another one for subsequent refactoring, but this also looks a bit suspicious, because if I'm copying the VC from source to dest, then if source was an inference tensor the VC is empty and it's harmless to set version counter here. This should be a consequence of https://github.com/pytorch/pytorch/pull/54403/files#r607472380

I think one reason this might not hold is because of this indirect call of copy tensor metadata

Tensor VariableHooks::variable_data(const Tensor& self) const {
  TORCH_CHECK(self.defined(), "cannot call variable_data() on undefined tensor");
  auto self_impl_copy = self.unsafeGetTensorImpl()->shallow_copy_and_detach(
    /*version_counter=*/0,
    /*allow_tensor_metadata_change=*/false);

so this is freshly creating a new version counter and would trigger the error here. This is kind of suspicious anyway because we don't actually want to allocate a version counter and then throw it out, so maybe just testing if self is an inference tensor in this function would solve this problem.

Contributor Author

Yeah, this one will be considered in the subsequent refactor; essentially we have some code calling shallow_copy_and_detach and it's all suspicious. E.g. https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/variable.h#L688 is called when you jit-load a model inside InferenceMode and causes this check to fail.
In as_view we also call shallow_copy_and_detach, so the follow-up refactor will focus on rationalizing the uses of shallow_copy_and_detach and the VariableVersions passed from there.

@@ -401,7 +415,9 @@ void TensorImpl::copy_tensor_metadata(
     const c10::VariableVersion& version_counter,
     bool allow_tensor_metadata_change) {
   copy_tensor_metadata_except_version_counter(src_impl, dest_impl, allow_tensor_metadata_change);
-  dest_impl->set_version_counter(version_counter);
+  if (!dest_impl->is_inference_tensor()) {
Collaborator

It is surprising to have some inference-specific logic in here, no?
Why can't this setter just be a no-op for such Tensors?

Contributor Author

Yeah, the reason I keep it here is that this is the only remaining callsite of set_version_counter() on an inference tensor with an enabled target version counter, and it requires a larger refactor to get rid of. In the ideal end state, set_version_counter on an inference tensor is valid as long as the target version_counter is also disabled. I'll add a TODO here!
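
A sketch of that ideal end state (hypothetical, not this PR's code; it assumes a `VariableVersion::enabled()` predicate matching the disabled version counters used for inference tensors in this stack):

```cpp
#include <c10/core/TensorImpl.h>

// Hypothetical helper illustrating the ideal end state described above: setting a
// version counter on an inference tensor is acceptable as long as the incoming
// counter is also disabled.
void set_version_counter_checked(c10::TensorImpl* impl,
                                 const c10::VariableVersion& vc) {
  TORCH_CHECK(
      !(impl->is_inference_tensor() && vc.enabled()),
      "Cannot set an enabled version counter on an inference tensor");
  impl->set_version_counter(vc);
}
```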

torch::Tensor out = view_op(inference_tensor); // go through kernels: InplaceOrView, CPU
ASSERT_TRUE(is_inference_tensor(out));
ASSERT_FALSE(out.requires_grad());
ASSERT_FALSE(out.is_view());
Collaborator

Why is out not a view? Didn't we go through the InplaceOrView kernel?

Contributor Author

We do go through the InplaceOrView kernel, but note it's a no-op if the base is an inference tensor.
I think this is better since it matches what you get from taking a view of an inference tensor inside InferenceMode (is_view=false). Since the output is an inference tensor we don't use the ViewMeta anyway, so it's much cleaner to just not create it.
(Another reason is that if we had ViewMeta in autograd_meta, it would suddenly make the tensor require grad, which is super confusing to users.)

Collaborator

I guess it's ok to skip all the view logic for these.
This change is starting to get very intrusive and hard to think about...

> (Another reason is that if we had ViewMeta in autograd_meta, it would suddenly make the tensor require grad, which is super confusing to users.)

Not sure about that. This is only true if the base does require gradients.

ailzhang pushed a commit that referenced this pull request Apr 7, 2021
ghstack-source-id: f74582a
Pull Request resolved: #54403
@ailzhang ailzhang requested review from ezyang and albanD April 7, 2021 23:25
Ailing Zhang added 2 commits April 7, 2021 23:29
@ezyang ezyang left a comment

Upcoming refactors will make this better, but I think this is good enough to go today

@facebook-github-bot

@ailzhang merged this pull request in 6842da6.

@facebook-github-bot facebook-github-bot deleted the gh/ailzhang/52/head branch April 13, 2021 14:16
krshrimali pushed a commit to krshrimali/pytorch that referenced this pull request May 19, 2021
Summary:
Pull Request resolved: pytorch#54403

A few important points about InferenceMode behavior:
1. All tensors created in InferenceMode are inference tensors except for view ops.
   - View ops produce outputs with the same is_inference_tensor property as their input.
     Namely, a view of a normal tensor inside InferenceMode produces a normal tensor,
     exactly the same as creating a view inside NoGradMode, and a view of an
     inference tensor outside InferenceMode produces an inference tensor as output.
2. All ops are allowed inside InferenceMode, and they run faster than in normal mode.
3. Inference tensors cannot be saved for backward.

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D27316483

Pulled By: ailzhang

fbshipit-source-id: e03248a66d42e2d43cfe7ccb61e49cc4afb2923b