16 changes: 10 additions & 6 deletions aten/src/ATen/native/ConvUtils.h
@@ -362,20 +362,24 @@ inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
     return false;
   }
 
-  bool can_use_miopen_channels_last_2d = false;
   // TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
   // See #64427
   static std::optional<bool> PYTORCH_MIOPEN_SUGGEST_NHWC = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC");
+  static bool suggest_nhwc = PYTORCH_MIOPEN_SUGGEST_NHWC && *PYTORCH_MIOPEN_SUGGEST_NHWC;
 
   auto input_memory_format = input.suggest_memory_format();
   auto weight_memory_format = weight.suggest_memory_format();
+  auto weight_ndim = weight.ndimension();
 
-  can_use_miopen_channels_last_2d = PYTORCH_MIOPEN_SUGGEST_NHWC && *PYTORCH_MIOPEN_SUGGEST_NHWC && (
-      ( (input_memory_format == at::MemoryFormat::ChannelsLast) ||
-        (weight_memory_format == at::MemoryFormat::ChannelsLast) )
-  );
+  bool can_use_miopen_channels_last_2d = suggest_nhwc && (weight_ndim == 4) && (
+      (input_memory_format == at::MemoryFormat::ChannelsLast) ||
+      (weight_memory_format == at::MemoryFormat::ChannelsLast)
+  );
 
-  bool can_use_miopen_channels_last_3d = false;
+  bool can_use_miopen_channels_last_3d = suggest_nhwc && (weight_ndim == 5) && (
+      (input_memory_format == at::MemoryFormat::ChannelsLast3d) ||
+      (weight_memory_format == at::MemoryFormat::ChannelsLast3d)
+  );
 
   return can_use_miopen_channels_last_2d || can_use_miopen_channels_last_3d;
 }
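Taken together, the ConvUtils.h change folds the environment-variable gate and a new weight-rank check into a single predicate per layout, so 3-D (NDHWC) convolutions can now opt in instead of being hard-coded to false. Below is a minimal standalone sketch of the resulting decision logic; use_channels_last is a hypothetical helper, and the plain enum and getenv check stand in for at::MemoryFormat and c10::utils::check_env:

// Standalone model of the new gating in miopen_conv_use_channels_last():
// 2-D channels-last needs a 4-D weight, 3-D channels-last a 5-D weight,
// and both are gated behind PYTORCH_MIOPEN_SUGGEST_NHWC.
#include <cstdlib>
#include <iostream>
#include <string_view>

enum class MemoryFormat { Contiguous, ChannelsLast, ChannelsLast3d };

bool use_channels_last(bool suggest_nhwc, int weight_ndim,
                       MemoryFormat input_mf, MemoryFormat weight_mf) {
  bool cl_2d = suggest_nhwc && (weight_ndim == 4) &&
               (input_mf == MemoryFormat::ChannelsLast ||
                weight_mf == MemoryFormat::ChannelsLast);
  bool cl_3d = suggest_nhwc && (weight_ndim == 5) &&
               (input_mf == MemoryFormat::ChannelsLast3d ||
                weight_mf == MemoryFormat::ChannelsLast3d);
  return cl_2d || cl_3d;
}

int main() {
  // Simplified stand-in for c10::utils::check_env: treat "1" as enabled.
  const char* env = std::getenv("PYTORCH_MIOPEN_SUGGEST_NHWC");
  bool suggest = env && std::string_view(env) == "1";
  // The 5-D (3-D conv) case, which the old code unconditionally rejected:
  std::cout << use_channels_last(suggest, 5, MemoryFormat::ChannelsLast3d,
                                 MemoryFormat::Contiguous) << "\n";
}

Run with PYTORCH_MIOPEN_SUGGEST_NHWC=1, this prints 1 for the 5-D case that the old code always answered with false.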
2 changes: 1 addition & 1 deletion aten/src/ATen/native/Convolution.cpp
@@ -1422,7 +1422,7 @@ static inline at::MemoryFormat determine_backend_memory_format(
       if (detail::getCUDAHooks().compiledWithMIOpen() && miopen_conv_use_channels_last(input, weight)) {
         TORCH_INTERNAL_ASSERT((k == 4 || k == 5),
             "Expected 4D or 5D input for miopen memory format selection in determine_backend_memory_format()");
-        backend_memory_format = (k == 5) ? at::MemoryFormat::Contiguous /*at::MemoryFormat::ChannelsLast3d*/ : at::MemoryFormat::ChannelsLast;
+        backend_memory_format = (k == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
       }
       break;
     case ConvBackend::Mkldnn:
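This drops the workaround that forced 5-D MIOpen convolutions back to Contiguous even when the layout check passed; k == 5 now maps to ChannelsLast3d, as the commented-out code always intended. A small libtorch check of the layout a 5-D tensor advertises (assumes a libtorch build; torch::empty, Tensor::contiguous, and suggest_memory_format are the public C++ API):

// A 5-D (N, C, D, H, W) tensor laid out channels-last-3d should advertise
// ChannelsLast3d, the format the MIOpen backend will now actually be handed.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::empty({2, 8, 4, 16, 16})
               .contiguous(torch::MemoryFormat::ChannelsLast3d);
  std::cout << (x.suggest_memory_format() ==
                torch::MemoryFormat::ChannelsLast3d)
            << "\n"; // prints 1
}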
17 changes: 14 additions & 3 deletions aten/src/ATen/native/Normalization.cpp
@@ -520,20 +520,31 @@ BatchNormBackend _select_batch_norm_backend(
     return BatchNormBackend::Cudnn;
   }
 
+  // TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM once ROCm officially supports NHWC in MIOpen.
+  // See https://github.com/pytorch/pytorch/issues/64427.
+  // A non-static variable is used so the environment variable can be changed at runtime for testing.
+  // Enabled by default for ROCm >= 7.0.0 with MIOpen 3.5.
+  int miopen_version = detail::getCUDAHooks().compiledWithMIOpen() ? detail::getCUDAHooks().versionMIOpen() : 0;
+  bool is_miopen_3_4 = miopen_version >= 30400; // ROCm 6.4
+  bool is_miopen_3_5 = miopen_version >= 30500; // ROCm 7.0
+  bool PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM").value_or(is_miopen_3_5);
 
   if (
       detail::getCUDAHooks().compiledWithMIOpen()
       && cudnn_enabled
       && input.is_cuda()
       && input.dim() <= MIOPEN_DIM_MAX
       && input.dim() >= 3
       && input.scalar_type() != at::kDouble
-      && (detail::getCUDAHooks().versionMIOpen() >= 30400 || input.scalar_type() != at::kBFloat16)
+      && (is_miopen_3_4 || input.scalar_type() != at::kBFloat16)
+      && weight.scalar_type() == at::kFloat // only FP32 weight for FP32 or FP16/BF16 (mixed) input
       && weight.defined() && bias.defined()
       && ((running_mean.defined() && running_var.defined())
           || (!running_mean.defined() && !running_var.defined() && training))
-      && input.suggest_memory_format() != MemoryFormat::ChannelsLast
-      && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d
+      && (input.suggest_memory_format() == MemoryFormat::Contiguous
+          || (is_miopen_3_5 && PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM &&
+              (input.suggest_memory_format() == MemoryFormat::ChannelsLast
+               || input.suggest_memory_format() == MemoryFormat::ChannelsLast3d)))
   ) {
     return BatchNormBackend::Miopen;
   }
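The batch-norm backend choice now has three inputs: the MIOpen version, the layout the input suggests, and an environment-variable override that defaults on from MIOpen 3.5 (ROCm 7.0). A standalone sketch of that decision, with check_env as a simplified stand-in for c10::utils::check_env and a plain Layout enum in place of MemoryFormat:

// Models the new MIOpen batch-norm gating: contiguous input always
// qualifies; channels-last input qualifies only on MIOpen >= 3.5 and when
// PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM has not been set to 0.
#include <cstdlib>
#include <optional>
#include <string>

enum class Layout { Contiguous, ChannelsLast, ChannelsLast3d };

// Simplified stand-in for c10::utils::check_env: unset -> nullopt, "1" -> true.
std::optional<bool> check_env(const char* name) {
  const char* v = std::getenv(name);
  if (!v) return std::nullopt;
  return std::string(v) == "1";
}

bool miopen_batchnorm_layout_ok(int miopen_version, Layout layout) {
  bool is_miopen_3_5 = miopen_version >= 30500; // ROCm 7.0
  bool suggest_nhwc =
      check_env("PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM").value_or(is_miopen_3_5);
  return layout == Layout::Contiguous ||
         (is_miopen_3_5 && suggest_nhwc &&
          (layout == Layout::ChannelsLast || layout == Layout::ChannelsLast3d));
}

Note that the override only widens eligibility on MIOpen >= 3.5: setting the variable on an older stack still leaves channels-last input ineligible, matching the is_miopen_3_5 && PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM conjunction above.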
12 changes: 6 additions & 6 deletions aten/src/ATen/native/miopen/Conv_miopen.cpp
@@ -762,7 +762,7 @@ Tensor miopen_convolution_forward(
 
   auto memory_format = at::MemoryFormat::Contiguous;
   if (miopen_conv_use_channels_last(*input, *weight)) {
-    memory_format = (weight->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+    memory_format = (weight->ndimension() == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
   }
 
   Tensor output_t = at::detail::empty_cuda(
@@ -870,7 +870,7 @@ Tensor miopen_depthwise_convolution_forward(
 
   auto memory_format = at::MemoryFormat::Contiguous;
   if (miopen_conv_use_channels_last(*input, *weight)) {
-    memory_format = (weight->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+    memory_format = (weight->ndimension() == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
   }
 
   Tensor output_t = at::detail::empty_cuda(
@@ -1070,7 +1070,7 @@ Tensor miopen_depthwise_convolution_backward_weight(
 
   auto memory_format = at::MemoryFormat::Contiguous;
   if (miopen_conv_use_channels_last(*input, *grad_output)) {
-    memory_format = (input->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+    memory_format = (input->ndimension() == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
   }
 
   Tensor grad_output_contig_t = grad_output->contiguous(memory_format);
@@ -1123,7 +1123,7 @@ Tensor miopen_convolution_backward_weight(
 
   auto memory_format = at::MemoryFormat::Contiguous;
   if (miopen_conv_use_channels_last(*input, *grad_output)) {
-    memory_format = (input->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+    memory_format = (input->ndimension() == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
   }
 
   Tensor grad_output_contig_t = grad_output->contiguous(memory_format);
@@ -1276,7 +1276,7 @@ Tensor miopen_convolution_backward_input(
 
   auto memory_format = at::MemoryFormat::Contiguous;
   if (miopen_conv_use_channels_last(*grad_output, *weight)) {
-    memory_format = (weight->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+    memory_format = (weight->ndimension() == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
   }
 
   Tensor grad_input_t = at::detail::empty_cuda(
@@ -1383,7 +1383,7 @@ Tensor miopen_depthwise_convolution_backward_input(
 
   auto memory_format = at::MemoryFormat::Contiguous;
   if (miopen_conv_use_channels_last(*grad_output, *weight)) {
-    memory_format = (weight->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+    memory_format = (weight->ndimension() == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
  }
 
   Tensor grad_input_t = at::detail::empty_cuda(
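All six hunks in Conv_miopen.cpp make the same one-line fix: drop the Contiguous fallback and honor ChannelsLast3d for 5-D tensors. Since the ternary is now repeated six times, it could arguably be hoisted into a shared helper; a minimal sketch (miopen_pick_memory_format is hypothetical, not part of the PR):

// Hypothetical helper collapsing the ternary repeated across the six MIOpen
// entry points: rank-5 tensors map to ChannelsLast3d, rank-4 to ChannelsLast.
#include <ATen/core/Tensor.h>
#include <c10/core/MemoryFormat.h>

inline at::MemoryFormat miopen_pick_memory_format(const at::Tensor& t) {
  return (t.ndimension() == 5) ? at::MemoryFormat::ChannelsLast3d
                               : at::MemoryFormat::ChannelsLast;
}

Each guarded call site would then reduce to memory_format = miopen_pick_memory_format(*weight); (with *input in the backward-weight paths).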