
[cudnn nhwc support] #23861

Closed · wants to merge 15 commits

Conversation

@jjsjann123 (Collaborator) commented Aug 6, 2019

Added nhwc support for:

  1. cudnn_batch_norm & cudnn_batch_norm_backward
  2. cudnn_convolution_forward & cudnn_convolution_backward
  3. cudnn_convolution_transpose & cudnn_convolution_transpose_backward

patching suggest_memory_format for convolution

suggest_memory_format has ambiguous meaning for two cases:

  1. tensor with NCHW where C == 1.
     We could use the stride of C as a hint to infer the intended memory format.
  2. tensor with NCHW where H == W == 1.
     There is no way to identify the intended memory format from the strides.

Currently we fall back to NCHW whenever we see a contiguous tensor, which avoids
the ambiguity in these special cases.
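For illustration, here is a minimal Python sketch (using current PyTorch APIs, not code from this PR) showing why strides cannot disambiguate the intended format in these cases:

```python
import torch

# Case 1: NCHW with C == 1. Size-1 dimensions make the stride checks
# degenerate, so the very same tensor passes both contiguity checks.
a = torch.randn(2, 1, 4, 4)
print(a.is_contiguous())                                   # True
print(a.is_contiguous(memory_format=torch.channels_last))  # True

# Case 2: NCHW with H == W == 1. Again both checks pass, so strides alone
# cannot tell whether the tensor was meant to be NCHW or NHWC.
b = torch.randn(2, 4, 1, 1)
print(b.is_contiguous())                                   # True
print(b.is_contiguous(memory_format=torch.channels_last))  # True
```

Since both checks pass, suggest_memory_format has to pick a default, which is why the fallback to NCHW for contiguous tensors is used here.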

Added cudnn nhwc support for:
1. batch norm
2. convolution
3. convolution_transpose
@pytorchbot added the module: cudnn (Related to torch.backends.cudnn, and CuDNN support) and module: operators labels Aug 6, 2019
@jjsjann123 (Collaborator, Author)

This is to support #23403. Passing my local tests and triggering the correct nhwc kernels.
cc'ing @csarofeen @VitalyFedyunin @ptrblck for visibility

Saw some breaking conv tests. Will handle that in this PR as well.
Also need to put in tests / code cleaning.

@pytorchbot added the module: nn (Related to torch.nn) label Aug 7, 2019
@jjsjann123 (Collaborator, Author)

The code should be good to review. I can't get any useful info from the previous failing tests.

@li-roy added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Aug 7, 2019
@VitalyFedyunin (Contributor) left a comment:

Still reviewing the C++ parts, but better tests will simplify the process.

Review threads on test/test_nn.py (outdated, resolved)
// b. Tensor with both spatial size == 1
// It causes mismatch memory format for data & filter in convolution. Hence we
// check for contiguous here to fallback to NCHW in those cases.
if (!t.is_contiguous() && t.suggest_memory_format() == at::MemoryFormat::ChannelsLast) {
Review comment (Contributor):

No need to check !t.is_contiguous() as it is mutually exclusive with t.suggest_memory_format() == at::MemoryFormat::ChannelsLast

@jjsjann123 (Collaborator, Author) commented Aug 8, 2019:

This is the special case that the comment above refers to.
is_contiguous() == True and suggest_memory_format() == ChannelsLast are not mutually exclusive (or identical, which I believe is what you were actually saying) in the two cases listed above.

I agree that it would be better not to have the check here and to put it inside TensorImpl instead; we could make the tags mutually exclusive there: https://github.com/pytorch/pytorch/blob/master/c10/core/TensorImpl.h#L1514-L1515

@jjsjann123 (Collaborator, Author) commented Aug 8, 2019:

Is this an oversight, or am I missing something? We are not copying is_channels_last_contiguous_ & is_channels_last_:

pytorch/c10/core/TensorImpl.h, lines 1524 to 1546 in 32efb43:

```cpp
static void copy_tensor_metadata(
    const TensorImpl* src_impl,
    TensorImpl* dest_impl,
    const c10::VariableVersion& version_counter,
    bool allow_tensor_metadata_change) {
  dest_impl->storage_ = src_impl->storage_;
  dest_impl->sizes_ = src_impl->sizes_;
  dest_impl->strides_ = src_impl->strides_;
  dest_impl->storage_offset_ = src_impl->storage_offset_;
  dest_impl->data_type_ = src_impl->data_type_;
  dest_impl->device_opt_ = src_impl->device_opt_;
  dest_impl->type_id_ = src_impl->type_id_;
  dest_impl->is_contiguous_ = src_impl->is_contiguous_;
  dest_impl->is_wrapped_number_ = src_impl->is_wrapped_number_;
  dest_impl->reserved_ = src_impl->reserved_;
  dest_impl->set_version_counter(version_counter);
  dest_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
#ifdef BUILD_NAMEDTENSOR
  if (src_impl->named_tensor_meta_ != nullptr) {
    dest_impl->named_tensor_meta_ = src_impl->named_tensor_meta_->clone();
  }
#endif
}
```
But even after patching that, my fix is still not working. :/ For some reason, calling contiguous(channels_last) on a non-contiguous (NC11) tensor ends up with both flags being True. Tracing down the code path, I am stepping into empty_like, which I know you are working on in a different thread. I'll double-check the issue I mentioned earlier after your #23899.

On second thought, the flags are not necessarily exclusive. I'll go back to your earlier suggestion and handle the fallback to ChannelsLast in suggest_memory_format instead (just like my code is doing right now).
But this means that we cannot represent an NC11 kernel with the NHWC flag. I'll open an issue to track this in case we want to revisit the design later.
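A concrete illustration of that limitation (again using current PyTorch Python APIs, not code from this PR):

```python
import torch

# A non-contiguous NC11 tensor, explicitly converted to channels_last.
x = torch.randn(2, 3, 5, 5)[:, :, :1, :1]      # shape (2, 3, 1, 1), not contiguous
y = x.contiguous(memory_format=torch.channels_last)

# Both flags end up True because every stride check degenerates on the
# size-1 spatial dimensions, so a fallback-to-NCHW policy in
# suggest_memory_format cannot preserve the NHWC intent for NC11 tensors.
print(x.is_contiguous())                                   # False
print(y.is_contiguous())                                   # True
print(y.is_contiguous(memory_format=torch.channels_last))  # True
```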

@pytorchbot added the module: internals (Related to internal abstractions in c10 and ATen) label Aug 8, 2019
@jjsjann123 (Collaborator, Author)

ROCm doesn't seem to support nhwc output. :/
Should I disable nhwc in the ROCm tests, and if so, how?
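One possible way to guard such tests (a hedged sketch, not necessarily how this PR resolved it) is to detect ROCm via torch.version.hip and skip the channels-last cases there; the test name and shapes below are illustrative:

```python
import unittest
import torch

# torch.version.hip is a version string on ROCm builds and None otherwise.
IS_ROCM = torch.version.hip is not None

class TestCudnnNhwc(unittest.TestCase):
    @unittest.skipIf(IS_ROCM, "channels-last (nhwc) path not supported on ROCm")
    def test_conv2d_channels_last(self):
        conv = torch.nn.Conv2d(8, 16, kernel_size=3, padding=1).cuda()
        conv = conv.to(memory_format=torch.channels_last)
        x = torch.randn(2, 8, 16, 16, device="cuda").to(memory_format=torch.channels_last)
        out = conv(x)
        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))

if __name__ == "__main__":
    unittest.main()
```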

@apaszke previously requested changes Aug 9, 2019
Review thread on aten/src/ATen/core/Tensor.h (outdated, resolved)
@VitalyFedyunin (Contributor)

Looks good; let's hold off on merging until we land the controls from #23899.

@VitalyFedyunin (Contributor)

It fails during the backward pass on the #25102 branch, here:

Workspace workspace = chooseAlgorithm(args, benchmark, &bwdDataAlgPerf);

With inputs:

```
weight       strides [512, 1, 1, 1]           sizes [2048, 512, 1, 1]  contiguous 1  channels last 1
grad_output  strides [100352, 1, 14336, 2048] sizes [64, 2048, 7, 7]   contiguous 0  channels last 1
grad_input   strides [25088, 1, 3584, 512]    sizes [64, 512, 7, 7]    contiguous 0  channels last 1
Padding [0, 0]
stride [1, 1]
dilation [1, 1]
groups 1
benchmark 1
deterministic 0
```
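For context, a hypothetical repro consistent with the shapes in that log (the layer, batch size, and memory formats below are guesses, not the actual failing test):

```python
import torch
import torch.nn as nn

# From the log: input (64, 512, 7, 7), weight (2048, 512, 1, 1),
# grad_output (64, 2048, 7, 7), all in channels-last layout.
conv = nn.Conv2d(512, 2048, kernel_size=1, bias=False).cuda()
conv = conv.to(memory_format=torch.channels_last)

x = torch.randn(64, 512, 7, 7, device="cuda", requires_grad=True)
out = conv(x.to(memory_format=torch.channels_last))
out.sum().backward()   # exercises the cudnn backward-data/backward-weight paths
```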

@jjsjann123 (Collaborator, Author)

The failed test seems to be related to some core changes. It looks like cuda.comm is calling is_contiguous on a sparse tensor. I don't see it in other changes; I'll merge master to see if it goes away.

@jjsjann123 (Collaborator, Author)

As @VitalyFedyunin mentioned earlier, the BC breakage from the additional arguments required by the cudnn Ex API is getting flagged by the CI test:

```
Oct 30 07:17:39 processing existing schema:  aten::_batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, int)
Oct 30 07:17:39 Can NOT find backward compatible schemas after changes for schema aten::_batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, int) from the following candidates:
Oct 30 07:17:39 [
Oct 30 07:17:39 aten::_batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)
```

How do we want to proceed here?

@VitalyFedyunin (Contributor)

Let me verify; adding arguments should not trigger the BC breakage tests, as it is (usually) only FC-breaking.

@VitalyFedyunin (Contributor)

Can you please rebase onto master?

@VitalyFedyunin (Contributor)

Also you need to whitelist _batch_norm_impl_index and _batch_norm_impl_index_backward inside of check_backward_compatibility.py
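For reference, a hedged sketch of what such whitelist entries might look like, assuming the file keeps a list of (schema name, expiration date) tuples; the exact file layout and dates here are assumptions, not the actual change:

```python
# Hypothetical excerpt from check_backward_compatibility.py
import datetime

white_list = [
    # ... existing entries ...
    ('aten::_batch_norm_impl_index', datetime.date(2019, 12, 1)),
    ('aten::_batch_norm_impl_index_backward', datetime.date(2019, 12, 1)),
]
```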

@jjsjann123 (Collaborator, Author)

Will do.
BTW, the failure from calling is_contiguous on a sparse tensor looks real. Is there a PR in your pipeline that fixes it?

@jjsjann123 (Collaborator, Author)

Quick note to myself:
The current Conv implementation respects the layout of the input (the layout of the input tensor determines the layout of the output tensor).

I should update this behavior in a future PR so that the layout of the weight is dominant instead. That would make it easier to convert a model from NCHW to NHWC.

@VitalyFedyunin (Contributor)

> I should update this behavior in a future PR so that the layout of the weight is dominant instead. That would make it easier to convert a model from NCHW to NHWC.

I recommend adding an expectedFailure test for it, which will cover the edge case and explain what we are trying to achieve.
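A sketch of what such an expectedFailure test could look like (not the test actually added in this PR; names and shapes are illustrative). It encodes the intended future behavior that the weight's layout, not the input's, determines the output layout:

```python
import unittest
import torch
import torch.nn as nn

class TestConvMixedLayout(unittest.TestCase):
    @unittest.expectedFailure
    def test_weight_layout_dominates(self):
        # Intended future behavior: channels-last weight + NCHW input
        # should produce a channels-last output.
        conv = nn.Conv2d(8, 16, kernel_size=3, padding=1).cuda()
        conv = conv.to(memory_format=torch.channels_last)   # nhwc weight
        x = torch.randn(4, 8, 32, 32, device="cuda")         # NCHW-contiguous input
        out = conv(x)
        self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))

if __name__ == "__main__":
    unittest.main()
```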

@VitalyFedyunin dismissed their stale review October 31, 2019 15:20

Sorry, wrong button; I still need to figure out what is going on with Sparse and how it even got affected.

1. fixing BC compatibility check date;
2. fixing is_contiguous call on sparse tensor;
3. added expectedFailure test to show intended behavior for Conv2d layer with
   mixed input/weight layouts;
@facebook-github-bot (Contributor) left a comment:

@VitalyFedyunin has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@VitalyFedyunin (Contributor)

Everything looks good; the only thing we want to add is a TORCH_WARN_ONCE warning if we dispatch a channels-last tensor to the old cudnn function.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Nov 4, 2019
Summary:
Added nhwc support for:
1. cudnn_batch_norm & cudnn_batch_norm_backward
2. cudnn_convolution_forward & cudnn_convolution_backward
3. cudnn_convolution_transpose & cudnn_convolution_transpose_backward

patching suggest_memory_format for convolution

suggest_memory_format has ambiguous meaning for two cases:
1. tensor with NCHW where C == 1.
   We could use the stride of C as a hint to infer the intended memory format.
2. tensor with NCHW where H == W == 1.
   There is no way to identify the intended memory format from the strides.

Currently we fall back to NCHW whenever we see a contiguous tensor, which avoids
the ambiguity in these special cases.
Pull Request resolved: pytorch/pytorch#23861

Differential Revision: D18263434

Pulled By: VitalyFedyunin

fbshipit-source-id: dd9f69576ec12fec879cd87a3d446931371360d9
@facebook-github-bot (Contributor)

@VitalyFedyunin merged this pull request in 8160f39.

@gchanan (Contributor) commented Nov 5, 2019

@jjsjann123 can you not put the title of the PR in brackets "[]"? It doesn't display in the GitHub UI (e.g. https://github.com/pytorch/pytorch/commits/master); it only shows the commit number, which makes it much more difficult to figure out what changes are involved.

@jjsjann123 (Collaborator, Author)

@gchanan noted. I'll update that for my other in-flight PR and do so for my future PRs.

VitalyFedyunin added a commit to VitalyFedyunin/pytorch that referenced this pull request Nov 6, 2019
```python
import torch

x = torch.randn(192, 16, 50).cuda()
# After permute -> contiguous -> permute, x has shape (192, 16, 50) with
# strides (800, 1, 16): channels are the fastest-moving dimension, so the
# tensor is no longer NCHW-contiguous.
x = x.permute(0, 2, 1).contiguous().permute(0, 2, 1)
m = torch.nn.Conv1d(
    in_channels=16,
    out_channels=32,
    kernel_size=2,
    bias=True,
).cuda()

m(x)
```

This reverts commit 8160f39.
facebook-github-bot pushed a commit that referenced this pull request Nov 7, 2019
Summary:
Broken case:

```python
x = torch.randn(192, 16, 50).cuda()
x = x.permute(0, 2, 1).contiguous().permute(0, 2, 1)
m = torch.nn.Conv1d(
    in_channels=16,
    out_channels=32,
    kernel_size=2,
    bias=True,
).cuda()

m(x)
```

This reverts commit 8160f39.
Pull Request resolved: #29329

Differential Revision: D18357674

Pulled By: VitalyFedyunin

fbshipit-source-id: cdd7e77e8dcbfc5f2ab3df54eb53ccfbf703b245