[cudnn nhwc support] #23861

Closed. Wants to merge 15 commits (showing changes from 7 commits).
2 changes: 1 addition & 1 deletion aten/src/ATen/core/Tensor.h
@@ -185,7 +185,7 @@ class CAFFE2_API Tensor {
}

at::MemoryFormat suggest_memory_format() const {
if (impl_->is_strides_like_channels_last()) {
if (!impl_->is_contiguous() && impl_->is_strides_like_channels_last()) {
return at::MemoryFormat::ChannelsLast;
}
return at::MemoryFormat::Contiguous;
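
A minimal Python sketch (not part of this diff) of why the extra !impl_->is_contiguous() guard matters; the expected outputs in the comments are what a recent PyTorch build reports and may differ on older versions:

    import torch

    # Plain NCHW tensor: default-contiguous, strides do not look channels-last.
    x = torch.randn(2, 4, 3, 3)
    print(x.stride())                                          # (36, 9, 3, 1)
    print(x.is_contiguous())                                   # True
    print(x.is_contiguous(memory_format=torch.channels_last))  # False

    # Explicit channels-last copy: the channel dimension gets stride 1.
    y = x.contiguous(memory_format=torch.channels_last)
    print(y.stride())                                          # (36, 1, 12, 4)

    # Degenerate shape: with H == W == 1 the very same strides satisfy both
    # layout checks, so a hint based only on "strides look channels-last"
    # could steer ordinary NCHW tensors onto the NHWC path. The added
    # is_contiguous() test keeps such tensors on the Contiguous path.
    z = torch.randn(2, 4, 1, 1)
    print(z.is_contiguous())                                   # True
    print(z.is_contiguous(memory_format=torch.channels_last))  # True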
8 changes: 6 additions & 2 deletions aten/src/ATen/cudnn/Descriptors.cpp
@@ -110,7 +110,7 @@ void FilterDescriptor::set(const at::Tensor &t, int64_t pad) {
throw std::runtime_error("cuDNN supports only up to " STR(CUDNN_DIM_MAX) " dimensions");
#undef _STR
#undef STR
if (!t.is_contiguous()) {
if (!t.is_contiguous(t.suggest_memory_format())) {
// NB: It is possible for this test to be insufficient, because the
// Tensor passed in to set the filter descriptor may not be the actual
// Tensor whose data pointer is passed to cuDNN. Nevertheless,
@@ -125,7 +125,11 @@ void FilterDescriptor::set(const at::Tensor &t, int64_t pad) {
size[i] = (int) 1;
}
dim = std::max(dim, pad);
set(getDataType(t), (int) dim, size);
cudnnTensorFormat_t filter_format = CUDNN_TENSOR_NCHW;
if (t.suggest_memory_format() == at::MemoryFormat::ChannelsLast) {
filter_format = CUDNN_TENSOR_NHWC;
}
set(getDataType(t), (int) dim, size, filter_format);
}

}}
4 changes: 2 additions & 2 deletions aten/src/ATen/cudnn/Descriptors.h
@@ -139,8 +139,8 @@ class FilterDescriptor
void set(const at::Tensor &t, int64_t pad = 0);

private:
void set(cudnnDataType_t dataType, int dim, int* size) {
AT_CUDNN_CHECK(cudnnSetFilterNdDescriptor(mut_desc(), dataType, CUDNN_TENSOR_NCHW, dim, size));
void set(cudnnDataType_t dataType, int dim, int* size, cudnnTensorFormat_t filter_format) {
AT_CUDNN_CHECK(cudnnSetFilterNdDescriptor(mut_desc(), dataType, filter_format, dim, size));
}
};

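For context, a hedged Python sketch (not from this PR) of what the new filter_format argument describes: converting a Conv2d weight to torch.channels_last puts the input-channel dimension at stride 1, which is the NHWC-style filter layout announced by CUDNN_TENSOR_NHWC, while the default layout corresponds to CUDNN_TENSOR_NCHW.

    import torch

    # Conv2d weight layout is (out_channels, in_channels, kH, kW).
    w = torch.randn(4, 8, 3, 3)
    print(w.stride())        # (72, 9, 3, 1)  -> NCHW-style filter

    w_nhwc = w.contiguous(memory_format=torch.channels_last)
    print(w_nhwc.stride())   # (72, 1, 24, 8) -> in_channels fastest, NHWC-style

    # Same values, different physical order; the descriptor has to tell cuDNN
    # which order the data pointer actually uses.
    print(torch.equal(w, w_nhwc))   # True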
22 changes: 10 additions & 12 deletions aten/src/ATen/native/Convolution.cpp
@@ -503,9 +503,6 @@ at::Tensor _convolution(

const bool input_is_mkldnn = input_r.is_mkldnn();
auto input = input_r;
if (!input_is_mkldnn) {
input = input.contiguous();
}
auto weight = weight_r;
auto bias = bias_r;
auto k = weight.ndimension();
@@ -547,15 +544,15 @@ at::Tensor _convolution(
auto dilation = params.dilation;
if (params.use_cudnn_depthwise(input, weight)) {
output = at::cudnn_convolution(
input, weight, bias,
input.contiguous(input.suggest_memory_format()), weight, bias,
padding, stride, dilation, params.groups, params.benchmark, params.deterministic);

} else if (params.use_miopen(input)){
output = at::miopen_depthwise_convolution(
input, weight, bias,
input.contiguous(), weight, bias,
padding, stride, dilation, params.groups, params.benchmark, params.deterministic);
} else {
output = at::thnn_conv_depthwise2d(input, weight, kernel_size, bias, stride, padding, dilation);
output = at::thnn_conv_depthwise2d(input.contiguous(), weight, kernel_size, bias, stride, padding, dilation);
}
} else if (params.use_cudnn(input)) {
TORCH_CHECK(input.type() == weight.type(),
@@ -567,11 +564,11 @@

if (params.transposed) {
output = at::cudnn_convolution_transpose(
input, weight, bias,
input.contiguous(input.suggest_memory_format()), weight, bias,
params.padding, params.output_padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
} else {
output = at::cudnn_convolution(
input, weight, bias,
input.contiguous(input.suggest_memory_format()), weight, bias,
params.padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
}
} else if (params.use_miopen(input)) {
@@ -584,11 +581,11 @@

if (params.transposed) {
output = at::miopen_convolution_transpose(
input, weight, bias,
input.contiguous(), weight, bias,
params.padding, params.output_padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
} else {
output = at::miopen_convolution(
input, weight, bias,
input.contiguous(), weight, bias,
params.padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
}
} else if (params.use_mkldnn(input)) {
@@ -600,7 +597,7 @@
"Input type (", input.type().toString(), ") and bias type (", bias.type().toString(),
") should be the same");
if (!input_is_mkldnn) {
output = at::mkldnn_convolution(input, weight.contiguous(), bias.defined() ? bias.contiguous() : bias,
output = at::mkldnn_convolution(input.contiguous(), weight.contiguous(), bias.defined() ? bias.contiguous() : bias,
params.padding, params.stride, params.dilation, params.groups);
} else {
// do not call contiguous on mkldnn tensor
@@ -611,9 +608,10 @@
} else {
if (params.groups == 1) {
output = at::_convolution_nogroup(
input, weight, bias, params.stride, params.padding, params.dilation, params.transposed, params.output_padding);
input.contiguous(), weight, bias, params.stride, params.padding, params.dilation, params.transposed, params.output_padding);
} else {
std::vector<Tensor> outputs(params.groups);
input = input.contiguous();
for (int g = 0; g < params.groups; ++g) {
auto input_g = subtensor(input, 1, params.groups, g);
auto weight_g = subtensor(weight, 0, params.groups, g);
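The dispatch pattern above, where only the cuDNN branches use input.contiguous(input.suggest_memory_format()) while the other backends keep a plain .contiguous(), can be mimicked from Python (a rough sketch, not this file's code): asking for contiguity in the layout a tensor already has is a no-op, while the default .contiguous() would copy an NHWC tensor back to NCHW.

    import torch

    x = torch.randn(2, 8, 4, 4).contiguous(memory_format=torch.channels_last)

    # Layout-preserving "make contiguous": nothing to do, same storage.
    same = x.contiguous(memory_format=torch.channels_last)
    print(same.data_ptr() == x.data_ptr())   # True

    # Plain .contiguous() targets the default NCHW layout, so it copies.
    nchw = x.contiguous()
    print(nchw.data_ptr() == x.data_ptr())   # False
    print(nchw.stride())                     # (128, 16, 4, 1)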
4 changes: 3 additions & 1 deletion aten/src/ATen/native/Normalization.cpp
@@ -372,7 +372,7 @@ std::tuple<Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
if (use_cudnn && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()) {
return std::tuple_cat(
at::cudnn_batch_norm(
input.contiguous(), weight.contiguous(),
input.contiguous(input.suggest_memory_format()), weight.contiguous(),
bias.contiguous(),
running_mean.defined() ? running_mean.contiguous() : running_mean,
running_var.defined() ? running_var.contiguous() : running_var,
@@ -415,6 +415,8 @@ std::tuple<Tensor, Tensor, Tensor> _batch_norm_impl_index_backward(
if (impl_index == 0) {
return at::native_batch_norm_backward(grad_output, input, weight, running_mean, running_var, save_mean, save_var_transform, train, epsilon, output_mask);
} else if (impl_index == 1) {
// TODO: _batch_norm_impl_index_backward is only used by the JIT; the cuDNN
// NHWC format conversion is done inside cudnn_batch_norm_backward instead.
return at::cudnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, epsilon);
} else if (impl_index == 2) {
return at::miopen_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, epsilon);
21 changes: 15 additions & 6 deletions aten/src/ATen/native/cudnn/BatchNorm.cpp
@@ -74,7 +74,10 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm(
}
checkAllSameType(c, {weight, bias, running_mean, running_var});
// TODO: is weight required to be contiguous?
checkAllContiguous(c, {input, weight, bias, running_mean, running_var});
checkAllContiguous(c, {weight, bias, running_mean, running_var});
// TODO: the TensorArg checks should start handling memory format
TORCH_CHECK(input->is_contiguous(input->suggest_memory_format()));

checkDimRange(c, input, 2, 6 /* exclusive */);
auto num_features = input->size(1);
for (auto t : {weight, bias, running_mean, running_var}) {
@@ -94,7 +97,7 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm(
// video R(2+1)D. We will fall back to the normal CUDNN_BATCHNORM_SPATIAL
}

auto output_t = at::empty(input->sizes(), input->options());
auto output_t = at::empty_like(*input, input->options(), input->suggest_memory_format());
TensorArg output{ output_t, "output", 0 };

auto handle = getCudnnHandle();
@@ -153,8 +156,10 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm_backward(
const Tensor& save_mean_t, const Tensor& save_var_t,
double epsilon)
{
// TODO: Is the contiguous call here worth it, or should we just go with
// whatever format grad_output is given in?
TensorArg input{ input_t, "input", 1 },
grad_output{ grad_output_t, "grad_output", 2 },
grad_output{ grad_output_t.contiguous(input_t.suggest_memory_format()), "grad_output", 2 },
weight{ weight_t, "weight", 3 },
save_mean{ save_mean_t, "save_mean", 4 },
save_var{ save_var_t, "save_var", 5 };
@@ -171,7 +176,10 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm_backward(
checkAllSameType(c, {input, grad_output});
checkAllSameType(c, {weight, save_mean, save_var});
// TODO: is weight required to be contiguous?
checkAllContiguous(c, {input, grad_output, save_mean, save_var});
checkAllContiguous(c, {save_mean, save_var});
// TODO: the TensorArg checks should start handling memory format
TORCH_CHECK(input->is_contiguous(input->suggest_memory_format()));
TORCH_CHECK(grad_output->is_contiguous(grad_output->suggest_memory_format()));
checkDimRange(c, input, 2, 6 /* exclusive */);
checkSameSize(c, input, grad_output);
auto num_features = input->size(1);
@@ -190,14 +198,15 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm_backward(
mode = CUDNN_BATCHNORM_SPATIAL;
}

auto grad_input_t = at::empty(input->sizes(), input->options());
auto grad_input_t = at::empty(input->sizes(), input->options(), input->suggest_memory_format());
auto grad_weight_t = at::empty(weight->sizes(), weight->options());
auto grad_bias_t = at::empty(weight->sizes(), weight->options());

auto handle = getCudnnHandle();
auto dataType = getCudnnDataType(*input);

TensorDescriptor idesc{ *input, 4 }; // input, grad_input descriptor
TensorDescriptor odesc{ *grad_output, 4 }; // grad_output descriptor
TensorDescriptor wdesc{ expandScale(*weight, input->dim()), 4 }; // descriptor for weight, bias, save_mean, etc.

Constant one(dataType, 1);
Expand All @@ -206,7 +215,7 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm_backward(
AT_CUDNN_CHECK(cudnnBatchNormalizationBackward(
handle, mode, &one, &zero, &one, &zero,
idesc.desc(), input->data_ptr(),
idesc.desc(), grad_output->data_ptr(),
odesc.desc(), grad_output->data_ptr(),
idesc.desc(), grad_input_t.data_ptr(),
wdesc.desc(), weight->data_ptr(),
grad_weight_t.data_ptr(),
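A hedged Python sketch (not this file's code) of the situation the new odesc and the grad_output_t.contiguous(input_t.suggest_memory_format()) call deal with: in backward, the incoming gradient is not guaranteed to share the input's memory format, so it is re-laid out to match and described by its own cuDNN descriptor.

    import torch

    # Forward ran in channels_last, but the gradient produced by downstream
    # ops may still arrive in the default NCHW layout.
    inp = torch.randn(2, 8, 4, 4).contiguous(memory_format=torch.channels_last)
    grad_out = torch.randn(2, 8, 4, 4)

    print(inp.is_contiguous(memory_format=torch.channels_last))       # True
    print(grad_out.is_contiguous(memory_format=torch.channels_last))  # False

    # Bring the gradient into the input's layout before building descriptors.
    grad_out = grad_out.contiguous(memory_format=torch.channels_last)
    print(grad_out.stride() == inp.stride())   # True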
15 changes: 8 additions & 7 deletions aten/src/ATen/native/cudnn/Conv.cpp
@@ -902,14 +902,15 @@ Tensor cudnn_convolution_forward(
auto output_t = at::empty(
conv_output_size(input->sizes(), weight->sizes(),
padding, stride, dilation, groups),
input->options());
input->options(),
input->suggest_memory_format());

// Avoid ambiguity of "output" when this is being used as backwards
TensorArg output{ output_t, "result", 0 };
convolution_shape_check(c, input, weight, output, padding, stride, dilation, groups);

// See #4500
Tensor weight_contig = weight->contiguous();
Tensor weight_contig = weight->contiguous(input->suggest_memory_format());

raw_cudnn_convolution_forward_out(
*output, *input, weight_contig,
@@ -956,7 +957,7 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> cudnn_convolution_transpose_backward(
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic, std::array<bool,3> output_mask) {

Tensor grad_output = grad_output_t.contiguous();
Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format());

Tensor grad_input, grad_weight, grad_bias;
if (output_mask[0]) {
@@ -1035,14 +1036,14 @@ Tensor cudnn_convolution_backward_input(
checkAllSameType(c, {grad_output, weight});
checkAllSameGPU(c, {grad_output, weight});

auto grad_input_t = at::empty(input_size, grad_output->options());
auto grad_input_t = at::empty(input_size, grad_output->options(), grad_output->suggest_memory_format());

// Avoid "grad_input" when this is being used as transposed convolution
TensorArg grad_input{ grad_input_t, "result", 0 };
convolution_shape_check(c, grad_input, weight, grad_output, padding, stride, dilation, groups);

// See #4500
Tensor weight_contig = weight->contiguous();
Tensor weight_contig = weight->contiguous(grad_output->suggest_memory_format());

raw_cudnn_convolution_backward_input_out(
*grad_input, *grad_output, weight_contig,
@@ -1082,7 +1083,7 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> cudnn_convolution_backward(
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic, std::array<bool,3> output_mask) {

Tensor grad_output = grad_output_t.contiguous();
Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format());

Tensor grad_input, grad_weight, grad_bias;
if (output_mask[0]) {
@@ -1165,7 +1166,7 @@ Tensor cudnn_convolution_backward_weight(
checkAllSameType(c, {grad_output, input});
checkAllSameGPU(c, {grad_output, input});

auto grad_weight_t = at::empty(weight_size, grad_output->options());
auto grad_weight_t = at::empty(weight_size, grad_output->options(), grad_output->suggest_memory_format());

// For uniformity with everything else, although it seems grad_weight
// would be unambiguous too.
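The at::empty(..., suggest_memory_format()) changes above have a direct Python analogue (a sketch under the assumption that outputs should inherit the input's layout): torch.empty and torch.empty_like accept a memory_format argument, so output and gradient buffers can be allocated channels-last up front instead of being created NCHW and converted afterwards.

    import torch

    x = torch.randn(2, 8, 4, 4).contiguous(memory_format=torch.channels_last)

    # Allocate a fresh output directly in the NHWC layout.
    out = torch.empty(2, 4, 2, 2, memory_format=torch.channels_last)
    print(out.is_contiguous(memory_format=torch.channels_last))   # True

    # Or simply preserve whatever layout the reference tensor has.
    grad_in = torch.empty_like(x, memory_format=torch.preserve_format)
    print(grad_in.stride() == x.stride())   # True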
2 changes: 1 addition & 1 deletion aten/src/ATen/templates/Tensor.h
@@ -185,7 +185,7 @@ class CAFFE2_API Tensor {
}

at::MemoryFormat suggest_memory_format() const {
if (impl_->is_strides_like_channels_last()) {
if (!impl_->is_contiguous() && impl_->is_strides_like_channels_last()) {
return at::MemoryFormat::ChannelsLast;
}
return at::MemoryFormat::Contiguous;
54 changes: 54 additions & 0 deletions test/test_nn.py
@@ -7292,6 +7292,34 @@ def func(root):
gradcheck(func, [v])
gradgradcheck(func, [v])

@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@unittest.skipIf(not TEST_CUDNN, "needs cudnn")
@skipIfRocm
def test_batchnorm_cudnn_nhwc(self):
input = torch.randint(1, 10, (4, 8, 2, 2), dtype=torch.float32, device="cuda", requires_grad=True).contiguous(memory_format=torch.channels_last)
input.retain_grad()
grad = torch.randint(1, 10, (4, 8, 2, 2), dtype=torch.float32, device="cuda").contiguous(memory_format=torch.channels_last)
bn = nn.BatchNorm2d(8).cuda().float()
bn.weight.data.uniform_()
bn.bias.data.uniform_()

ref_input = input.detach().clone().contiguous().requires_grad_(True)
ref_grad = grad.detach().clone().contiguous()
ref_bn = nn.BatchNorm2d(8).cuda().float()
ref_bn.load_state_dict(bn.state_dict())

out = bn(input)
out.backward(grad)
ref_out = ref_bn(ref_input)
ref_out.backward(ref_grad)

self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
self.assertTrue(ref_out.is_contiguous())
self.assertEqual(out, ref_out)
self.assertEqual(bn.weight.grad, ref_bn.weight.grad)
self.assertEqual(bn.bias.grad, ref_bn.bias.grad)
self.assertEqual(input.grad, ref_input.grad)

@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_batchnorm_cudnn_half(self):
# THNN
@@ -8682,6 +8710,32 @@ def test_cudnn_noncontiguous_weight(self):
self.assertEqual(F.conv1d(input, weights1, bias=None, stride=2, dilation=2),
F.conv1d(input, weights2, bias=None, stride=2, dilation=2))

@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@unittest.skipIf(not TEST_CUDNN, "needs cudnn")
@skipIfRocm
def test_conv_cudnn_nhwc(self):
input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device="cuda", requires_grad=True).contiguous(memory_format=torch.channels_last)
input.retain_grad()
grad = torch.rand(2, 4, 2, 2, dtype=torch.float32, device="cuda").contiguous(memory_format=torch.channels_last)
conv = nn.Conv2d(8, 4, 3).cuda().float()

ref_input = input.detach().clone().contiguous().requires_grad_(True)
ref_grad = grad.detach().clone().contiguous()
ref_conv = nn.Conv2d(8, 4, 3).cuda().float()
ref_conv.load_state_dict(conv.state_dict())

out = conv(input)
out.backward(grad)
ref_out = ref_conv(ref_input)
ref_out.backward(ref_grad)

self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
self.assertTrue(ref_out.is_contiguous())
self.assertEqual(out, ref_out)
self.assertEqual(conv.weight.grad, ref_conv.weight.grad)
self.assertEqual(conv.bias.grad, ref_conv.bias.grad)
self.assertEqual(input.grad, ref_input.grad)

@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@repeat_test_for_types(DOUBLE_TENSORTYPES)
def test_conv_double_backward_cuda(self, dtype=torch.double):