add channels last support for thnn_conv2d (non-dilated)
ghstack-source-id: 420d9cdfbcf948e6695a89018643071ac1409d51
Pull Request resolved: #49582
mingfeima committed Dec 21, 2020
1 parent 4cd3b2b commit 4c6a701
Showing 7 changed files with 506 additions and 45 deletions.
aten/src/ATen/native/Convolution.cpp (12 additions & 3 deletions)
@@ -535,8 +535,9 @@ static at::Tensor subtensor(at::Tensor& tensor, int dim, int groups, int g) {
   if (!tensor.defined()) {
     return at::Tensor();
   }
+  auto memory_format = tensor.suggest_memory_format();
   int64_t n = tensor.sizes()[dim] / groups;
-  return tensor.narrow(dim, n * g, n).contiguous();
+  return tensor.narrow(dim, n * g, n).contiguous(memory_format);
 }
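
Why the subtensor() change matters for grouped convolution: narrow() returns a view, and the old unconditional contiguous() call defaulted to torch::MemoryFormat::Contiguous, so the per-group slices of a channels-last input were silently converted back to NCHW. A minimal libtorch sketch of the two behaviors (illustrative only, not part of this commit; shapes and variable names are arbitrary):

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      // A 4-D tensor materialized with channels-last (NHWC) strides.
      auto t = torch::randn({2, 8, 4, 4})
                   .contiguous(torch::MemoryFormat::ChannelsLast);

      const int64_t dim = 1, groups = 2, g = 1;
      const int64_t n = t.size(dim) / groups;

      // Old behavior: contiguous() defaults to MemoryFormat::Contiguous,
      // so the per-group slice loses its channels-last layout.
      auto before = t.narrow(dim, n * g, n).contiguous();

      // New behavior: forwarding the suggested memory format keeps the
      // slice's channels-last strides.
      auto after = t.narrow(dim, n * g, n).contiguous(t.suggest_memory_format());

      std::cout << std::boolalpha
                << before.is_contiguous(torch::MemoryFormat::ChannelsLast)  // false
                << " "
                << after.is_contiguous(torch::MemoryFormat::ChannelsLast)   // true
                << std::endl;
    }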


@@ -776,12 +777,20 @@ at::Tensor _convolution(
         params.stride,
         params.padding);
   } else if (input.device().type() == c10::DeviceType::CPU || input.device().type() == c10::DeviceType::CUDA) {
+    bool is_channels_last_supported = !params.transposed && (input.ndimension() == 4) &&
+        !params.use_nnpack(input, weight) && (input.device().type() == c10::DeviceType::CPU) &&
+        !params.is_dilated();
+    if (is_channels_last_supported) {
+      auto memory_format = input.suggest_memory_format();
+      input = input.contiguous(memory_format);
+    } else {
+      input = input.contiguous();
+    }
     if (params.groups == 1) {
       output = at::_convolution_nogroup(
-          input.contiguous(), weight, bias, params.stride, params.padding, params.dilation, params.transposed, params.output_padding);
+          input, weight, bias, params.stride, params.padding, params.dilation, params.transposed, params.output_padding);
     } else {
       std::vector<Tensor> outputs(params.groups);
-      input = input.contiguous();
       for (int g = 0; g < params.groups; ++g) {
         auto input_g = subtensor(input, 1, params.groups, g);
         auto weight_g = subtensor(weight, 0, params.groups, g);
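
With the new gating, a non-transposed, non-dilated, 4-D CPU convolution that does not take the NNPACK path keeps the input's suggested memory format instead of forcing NCHW before dispatch. An end-to-end libtorch usage sketch (illustrative; whether the output itself comes back channels last depends on the kernel changes in the other six files of this commit):

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      torch::NoGradGuard no_grad;

      // Grouped conv exercises the subtensor() path above: groups(2) splits
      // the 8 input channels into two per-group slices.
      auto conv = torch::nn::Conv2d(
          torch::nn::Conv2dOptions(/*in_channels=*/8, /*out_channels=*/16,
                                   /*kernel_size=*/3).groups(2));

      // Channels-last input; suggest_memory_format() will report ChannelsLast,
      // so _convolution no longer converts it to NCHW on this path.
      auto input = torch::randn({1, 8, 32, 32})
                       .contiguous(torch::MemoryFormat::ChannelsLast);

      auto output = conv->forward(input);

      // Expected to print true when the NHWC kernel path is taken.
      std::cout << std::boolalpha
                << output.is_contiguous(torch::MemoryFormat::ChannelsLast)
                << std::endl;
    }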
