Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 35 additions & 26 deletions test/cpp/test_aten_xla_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,21 +309,25 @@ TEST_F(AtenXlaTensorTest, TestMaxPool2D) {
for (int padding = 0; padding <= 1; ++padding) {
// Test ceil_mode=true through the CPU interop.
for (bool ceil_mode : {false, true}) {
at::Tensor output =
at::max_pool2d(input, /*kernel_size=*/{kernel_size, kernel_size},
/*stride=*/{stride, stride},
/*padding=*/{padding, padding}, /*dilation=*/{1, 1},
/*ceil_mode=*/ceil_mode);
ForEachDevice([&](const Device& device) {
at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
at::Tensor xla_output = at::max_pool2d(
xla_input,
/*kernel_size=*/{kernel_size, kernel_size},
// Test dilation through the CPU interop.
for (int dilation = 1; dilation <= 2; ++dilation) {
at::Tensor output = at::max_pool2d(
input, /*kernel_size=*/{kernel_size, kernel_size},
/*stride=*/{stride, stride},
/*padding=*/{padding, padding}, /*dilation=*/{1, 1},
/*padding=*/{padding, padding}, /*dilation=*/{dilation, dilation},
/*ceil_mode=*/ceil_mode);
AllClose(output, xla_output);
});
ForEachDevice([&](const Device& device) {
at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
at::Tensor xla_output =
at::max_pool2d(xla_input,
/*kernel_size=*/{kernel_size, kernel_size},
/*stride=*/{stride, stride},
/*padding=*/{padding, padding},
/*dilation=*/{dilation, dilation},
/*ceil_mode=*/ceil_mode);
AllClose(output, xla_output);
});
}
}
}
}
Expand All @@ -336,21 +340,26 @@ TEST_F(AtenXlaTensorTest, TestMaxPool2DNonSquare) {
for (int padding = 0; padding <= 1; ++padding) {
// Test ceil_mode=true through the CPU interop.
for (bool ceil_mode : {false, true}) {
at::Tensor output = at::max_pool2d(
input, /*kernel_size=*/{kernel_size, kernel_size + 1},
/*stride=*/{stride, stride + 1},
/*padding=*/{padding, padding + 1}, /*dilation=*/{1, 1},
/*ceil_mode=*/ceil_mode);
ForEachDevice([&](const Device& device) {
at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
at::Tensor xla_output = at::max_pool2d(
xla_input,
/*kernel_size=*/{kernel_size, kernel_size + 1},
// Test dilation through the CPU interop.
for (int dilation = 1; dilation <= 2; ++dilation) {
at::Tensor output = at::max_pool2d(
input, /*kernel_size=*/{kernel_size, kernel_size + 1},
/*stride=*/{stride, stride + 1},
/*padding=*/{padding, padding + 1}, /*dilation=*/{1, 1},
/*padding=*/{padding, padding + 1},
/*dilation=*/{dilation, dilation},
/*ceil_mode=*/ceil_mode);
AllClose(output, xla_output);
});
ForEachDevice([&](const Device& device) {
at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
at::Tensor xla_output =
at::max_pool2d(xla_input,
/*kernel_size=*/{kernel_size, kernel_size + 1},
/*stride=*/{stride, stride + 1},
/*padding=*/{padding, padding + 1},
/*dilation=*/{dilation, dilation},
/*ceil_mode=*/ceil_mode);
AllClose(output, xla_output);
});
}
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla_client/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

#include <functional>
#include <memory>
#include <vector>
#include <set>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
Expand Down
19 changes: 13 additions & 6 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,16 @@
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"

namespace torch_xla {
namespace {

// Returns true if dilation is non-trivial (not 1) in at least one dimension.
// Used to decide whether an op must fall back to the base (CPU interop)
// implementation, since dilated variants are not lowered to XLA yet.
bool IsNonTrivialDilation(at::IntList dilation) {
  for (const int64_t dim_dilation : dilation) {
    if (dim_dilation != 1) {
      return true;
    }
  }
  return false;
}

}  // namespace

bool AtenXlaType::s_use_full_conv_precision_ = false;

Expand Down Expand Up @@ -68,11 +78,8 @@ at::Tensor AtenXlaType::conv2d(const at::Tensor& input,
const at::Tensor& weight, const at::Tensor& bias,
at::IntList stride, at::IntList padding,
at::IntList dilation, int64_t groups) const {
bool has_dilation =
std::any_of(dilation.begin(), dilation.end(),
[](const int64_t dim_dilation) { return dim_dilation != 1; });
// Dilated or grouped convolutions aren't lowered to XLA yet.
if (has_dilation || groups != 1) {
if (IsNonTrivialDilation(dilation) || groups != 1) {
return AtenXlaTypeBase::conv2d(input, weight, bias, stride, padding,
dilation, groups);
}
Expand Down Expand Up @@ -118,8 +125,8 @@ at::Tensor AtenXlaType::max_pool2d(const at::Tensor& self,
at::IntList kernel_size, at::IntList stride,
at::IntList padding, at::IntList dilation,
bool ceil_mode) const {
// Lowering when ceil_mode is set not supported yet.
if (ceil_mode) {
// Lowering when dilation is non-trivial or ceil_mode is set not supported.
if (ceil_mode || IsNonTrivialDilation(dilation)) {
return AtenXlaTypeBase::max_pool2d(self, kernel_size, stride, padding,
dilation, ceil_mode);
}
Expand Down