[cudnn nhwc support] #23861

Closed · wants to merge 15 commits

Changes from all commits
8 changes: 6 additions & 2 deletions aten/src/ATen/cudnn/Descriptors.cpp
@@ -110,7 +110,7 @@ void FilterDescriptor::set(const at::Tensor &t, int64_t pad) {
throw std::runtime_error("cuDNN supports only up to " STR(CUDNN_DIM_MAX) " dimensions");
#undef _STR
#undef STR
- if (!t.is_contiguous()) {
+ if (!t.is_contiguous(t.suggest_memory_format())) {
// NB: It is possible for this test to be insufficient, because the
// Tensor passed in to set the filter descriptor may not be the actual
// Tensor whose data pointer is passed to cuDNN. Nevertheless,
@@ -125,7 +125,11 @@ void FilterDescriptor::set(const at::Tensor &t, int64_t pad) {
size[i] = (int) 1;
}
dim = std::max(dim, pad);
- set(getDataType(t), (int) dim, size);
+ cudnnTensorFormat_t filter_format = CUDNN_TENSOR_NCHW;
+ if (t.suggest_memory_format() == at::MemoryFormat::ChannelsLast) {
+   filter_format = CUDNN_TENSOR_NHWC;
+ }
+ set(getDataType(t), (int) dim, size, filter_format);
}

}}
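As a point of reference, the memory-format API that the rewritten check relies on can be exercised standalone. A minimal sketch, not part of this diff, assuming a recent libtorch where Tensor::suggest_memory_format() and the MemoryFormat overload of is_contiguous() are available:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // A 4-D tensor with logical NCHW sizes but channels-last (NHWC) strides.
  at::Tensor t = at::empty({8, 3, 32, 32},
                           at::TensorOptions().dtype(at::kFloat),
                           at::MemoryFormat::ChannelsLast);

  // The old check: fails for NHWC-strided tensors, forcing a repack.
  std::cout << t.is_contiguous() << "\n";  // 0

  // The new check: contiguous with respect to the suggested format.
  std::cout << t.is_contiguous(t.suggest_memory_format()) << "\n";  // 1

  // suggest_memory_format() inspects the strides and reports ChannelsLast,
  // which is what flips filter_format to CUDNN_TENSOR_NHWC above.
  std::cout << (t.suggest_memory_format() == at::MemoryFormat::ChannelsLast)
            << "\n";  // 1
  return 0;
}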
4 changes: 2 additions & 2 deletions aten/src/ATen/cudnn/Descriptors.h
@@ -140,8 +140,8 @@ class FilterDescriptor
void set(const at::Tensor &t, int64_t pad = 0);

private:
- void set(cudnnDataType_t dataType, int dim, int* size) {
-   AT_CUDNN_CHECK(cudnnSetFilterNdDescriptor(mut_desc(), dataType, CUDNN_TENSOR_NCHW, dim, size));
+ void set(cudnnDataType_t dataType, int dim, int* size, cudnnTensorFormat_t filter_format) {
+   AT_CUDNN_CHECK(cudnnSetFilterNdDescriptor(mut_desc(), dataType, filter_format, dim, size));
}
};

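The new overload simply forwards the format flag to cuDNN. A cuDNN-only sketch of that underlying call, assuming CUDA and cuDNN are installed (error checking omitted for brevity); note that filterDimA is given in K, C, R, S order regardless of format, since the format flag alone tells cuDNN the data is physically laid out NHWC:

#include <cudnn.h>

int main() {
  cudnnFilterDescriptor_t desc;
  cudnnCreateFilterDescriptor(&desc);

  // 64 output channels, 3 input channels, 7x7 kernel; dims stay K, C, R, S
  // even though the memory layout is declared NHWC via the format flag.
  int filterDimA[4] = {64, 3, 7, 7};
  cudnnSetFilterNdDescriptor(desc, CUDNN_DATA_FLOAT,
                             CUDNN_TENSOR_NHWC, 4, filterDimA);

  cudnnDestroyFilterDescriptor(desc);
  return 0;
}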
22 changes: 10 additions & 12 deletions aten/src/ATen/native/Convolution.cpp
@@ -539,9 +539,6 @@ at::Tensor _convolution(

const bool input_is_mkldnn = input_r.is_mkldnn();
auto input = input_r;
- if (!input_is_mkldnn) {
-   input = input.contiguous();
- }
auto weight = weight_r;
auto bias = bias_r;
auto k = weight.ndimension();
@@ -583,15 +580,15 @@ at::Tensor _convolution(
auto dilation = params.dilation;
if (params.use_cudnn_depthwise(input, weight)) {
output = at::cudnn_convolution(
-   input, weight, bias,
+   input.contiguous(input.suggest_memory_format()), weight, bias,
padding, stride, dilation, params.groups, params.benchmark, params.deterministic);

} else if (params.use_miopen(input)){
output = at::miopen_depthwise_convolution(
-   input, weight, bias,
+   input.contiguous(), weight, bias,
padding, stride, dilation, params.groups, params.benchmark, params.deterministic);
} else {
- output = at::thnn_conv_depthwise2d(input, weight, kernel_size, bias, stride, padding, dilation);
+ output = at::thnn_conv_depthwise2d(input.contiguous(), weight, kernel_size, bias, stride, padding, dilation);
}
} else if (params.use_cudnn(input)) {
TORCH_CHECK(input.type() == weight.type(),
@@ -603,11 +600,11 @@ at::Tensor _convolution(

if (params.transposed) {
output = at::cudnn_convolution_transpose(
-   input, weight, bias,
+   input.contiguous(input.suggest_memory_format()), weight, bias,
params.padding, params.output_padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
} else {
output = at::cudnn_convolution(
-   input, weight, bias,
+   input.contiguous(input.suggest_memory_format()), weight, bias,
params.padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
}
} else if (params.use_miopen(input)) {
@@ -620,11 +617,11 @@ at::Tensor _convolution(

if (params.transposed) {
output = at::miopen_convolution_transpose(
-   input, weight, bias,
+   input.contiguous(), weight, bias,
params.padding, params.output_padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
} else {
output = at::miopen_convolution(
-   input, weight, bias,
+   input.contiguous(), weight, bias,
params.padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
}
} else if (params.use_mkldnn(input)) {
@@ -636,7 +633,7 @@ at::Tensor _convolution(
"Input type (", input.type().toString(), ") and bias type (", bias.type().toString(),
") should be the same");
if (!input_is_mkldnn) {
- output = at::mkldnn_convolution(input, weight.contiguous(), bias.defined() ? bias.contiguous() : bias,
+ output = at::mkldnn_convolution(input.contiguous(), weight.contiguous(), bias.defined() ? bias.contiguous() : bias,
params.padding, params.stride, params.dilation, params.groups);
} else {
// do not call contiguous on mkldnn tensor
@@ -650,9 +647,10 @@ at::Tensor _convolution(
input.device().type(), input, weight, bias, params.padding, params.stride, params.groups);
} else if (params.groups == 1) {
output = at::_convolution_nogroup(
-   input, weight, bias, params.stride, params.padding, params.dilation, params.transposed, params.output_padding);
+   input.contiguous(), weight, bias, params.stride, params.padding, params.dilation, params.transposed, params.output_padding);
} else {
std::vector<Tensor> outputs(params.groups);
+ input = input.contiguous();
for (int g = 0; g < params.groups; ++g) {
auto input_g = subtensor(input, 1, params.groups, g);
auto weight_g = subtensor(weight, 0, params.groups, g);
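The call sites above use two different repacking behaviors. A small sketch, assuming libtorch (CPU tensors for brevity): contiguous(suggest_memory_format()) on the cuDNN paths is a no-op for a channels-last input, while the plain contiguous() kept on the MIOpen/THNN/fallback paths repacks it to NCHW:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Start from an NHWC-strided input.
  at::Tensor input =
      at::randn({2, 4, 8, 8}).contiguous(at::MemoryFormat::ChannelsLast);

  // cuDNN path: preserves the channels-last strides; no copy is made.
  at::Tensor for_cudnn = input.contiguous(input.suggest_memory_format());
  std::cout << (for_cudnn.data_ptr() == input.data_ptr()) << "\n";  // 1

  // MIOpen/THNN path: plain contiguous() repacks into NCHW order.
  at::Tensor for_thnn = input.contiguous();
  std::cout << (for_thnn.data_ptr() == input.data_ptr()) << "\n";  // 0
  std::cout << for_thnn.is_contiguous() << "\n";  // 1 (NCHW-contiguous copy)
  return 0;
}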
14 changes: 10 additions & 4 deletions aten/src/ATen/native/Normalization.cpp
@@ -354,7 +354,7 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(const Tensor
// of backends, while enabling it to keep the information about the used backend, so that it can
// use its corresponding backward implementation.
// XXX: The indices of backends need to be kept synchronized between this function and its _backward.
- std::tuple<Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
+ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
const Tensor& input, const Tensor& weight /* optional */, const Tensor& bias /* optional */,
const Tensor& running_mean /* optional */, const Tensor& running_var /* optional */,
bool training, double momentum, double eps, bool cudnn_enabled) {
@@ -390,14 +390,16 @@ std::tuple<Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
if (use_cudnn && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()) {
return std::tuple_cat(
at::cudnn_batch_norm(
-     input.contiguous(), weight.contiguous(),
+     input.contiguous(input.suggest_memory_format()), weight.contiguous(),
bias.contiguous(),
running_mean.defined() ? running_mean.contiguous() : running_mean,
running_var.defined() ? running_var.contiguous() : running_var,
training, momentum, eps),
std::make_tuple(1));
}

+ Tensor reserve = at::empty({0}, input.options().dtype(kByte));

bool use_miopen = (input.is_cuda()
&& input.dim() <= MIOPEN_DIM_MAX
&& input.scalar_type() != at::kDouble
@@ -415,12 +417,14 @@ std::tuple<Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
running_mean.defined() ? running_mean.contiguous() : running_mean,
running_var.defined() ? running_var.contiguous() : running_var,
training, momentum, eps),
+     std::tuple<Tensor>(reserve),
std::make_tuple(2));
}

return std::tuple_cat(
at::native_batch_norm(
input, weight, bias, running_mean, running_var, training, momentum, eps),
+   std::tuple<Tensor>(reserve),
std::make_tuple(0));
}

@@ -429,11 +433,13 @@ std::tuple<Tensor, Tensor, Tensor> _batch_norm_impl_index_backward(
const Tensor& input, const Tensor& grad_output, const Tensor& weight /* optional */,
const Tensor& running_mean /* optional */, const Tensor& running_var /* optional */,
const Tensor& save_mean /* optional */, const Tensor& save_var_transform /* optional */,
- bool train, double epsilon, std::array<bool, 3> output_mask) {
+ bool train, double epsilon, std::array<bool, 3> output_mask, const Tensor &reservedSpace) {
if (impl_index == 0) {
return at::native_batch_norm_backward(grad_output, input, weight, running_mean, running_var, save_mean, save_var_transform, train, epsilon, output_mask);
} else if (impl_index == 1) {
-   return at::cudnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, epsilon);
+   // TODO: _batch_norm_impl_index_backward is only used in JIT. cudnn NHWC
+   // format conversion is done inside cudnn_batch_norm_backward instead
+   return at::cudnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, epsilon, reservedSpace);
} else if (impl_index == 2) {
return at::miopen_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, epsilon);
}
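Since _batch_norm_impl_index now returns five values, the non-cuDNN branches splice an empty byte tensor into the tuple where cuDNN would return its reserved workspace. A simplified sketch of that tuple plumbing, with a stand-in for the native kernel rather than the real batch-norm implementation:

#include <ATen/ATen.h>
#include <cstdint>
#include <tuple>

// Hypothetical stand-in mirroring the shape of _batch_norm_impl_index:
// (output, save_mean, save_var, reserve, impl_index).
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, int64_t>
batch_norm_impl_index_sketch(const at::Tensor& input) {
  // Stand-ins for the three outputs of at::native_batch_norm.
  auto native_out = std::make_tuple(at::empty_like(input),
                                    at::empty({0}, input.options()),
                                    at::empty({0}, input.options()));

  // Placeholder reserve tensor, exactly as in the diff: empty, dtype byte.
  at::Tensor reserve = at::empty({0}, input.options().dtype(at::kByte));

  // Backend index 0 marks the native (non-cuDNN, non-MIOpen) path.
  return std::tuple_cat(native_out,
                        std::tuple<at::Tensor>(reserve),
                        std::make_tuple(static_cast<int64_t>(0)));
}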