Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .torch_commit_id

This file was deleted.

174 changes: 21 additions & 153 deletions test/cpp/test_aten_xla_tensor.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include <ATen/ATen.h>
#include <ATen/LegacyTHFunctions.h>
#include <ATen/NativeFunctions.h>
#include <gtest/gtest.h>
#include <torch/csrc/autograd/function.h>
Expand Down Expand Up @@ -381,18 +380,6 @@ TEST_F(AtenXlaTensorTest, TestNeInplace) {
});
}

TEST_F(AtenXlaTensorTest, TestThNe) {
at::Tensor a = at::rand({2, 3}, at::TensorOptions(at::kFloat));
at::Tensor b = at::rand({2, 3}, at::TensorOptions(at::kFloat));
at::Tensor c = at::_th_ne(a, b);
ForEachDevice([&](const Device& device) {
at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
at::Tensor xla_b = bridge::CreateXlaTensor(b, device);
at::Tensor xla_c = at::_th_ne(xla_a, xla_b);
AllClose(c, xla_c);
});
}

TEST_F(AtenXlaTensorTest, TestEq) {
at::Tensor a = at::rand({2, 3}, at::TensorOptions(at::kFloat));
at::Tensor b = a.clone();
Expand All @@ -419,43 +406,6 @@ TEST_F(AtenXlaTensorTest, TestEqInplace) {
});
}

TEST_F(AtenXlaTensorTest, TestThEq) {
at::Tensor a = at::rand({2, 3}, at::TensorOptions(at::kFloat));
at::Tensor b = a.clone();
at::Tensor c = at::_th_eq(a, b);
ForEachDevice([&](const Device& device) {
at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
at::Tensor xla_b = bridge::CreateXlaTensor(b, device);
at::Tensor xla_c = at::_th_eq(xla_a, xla_b);
AllClose(c, xla_c);
});
}

TEST_F(AtenXlaTensorTest, TestThEqScalar) {
at::Tensor a = at::full({}, 1.2, at::TensorOptions(at::kFloat));
at::Tensor b = at::_th_eq(a, 1.2);
ForEachDevice([&](const Device& device) {
at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
at::Tensor xla_b = at::_th_eq(xla_a, 1.2);
AllClose(b, xla_b);
});
}

TEST_F(AtenXlaTensorTest, TestThEqAutograd) {
at::Tensor a = at::rand({2, 3}, at::TensorOptions(at::kFloat));
at::Tensor b = a.clone();
at::Tensor c = at::_th_eq(torch::autograd::make_variable(a, false),
torch::autograd::make_variable(b, false));
ForEachDevice([&](const Device& device) {
at::Tensor xla_a = torch::autograd::make_variable(
bridge::CreateXlaTensor(a, device), false);
at::Tensor xla_b = torch::autograd::make_variable(
bridge::CreateXlaTensor(b, device), false);
at::Tensor xla_c = at::_th_eq(xla_a, xla_b);
EXPECT_TRUE(EqualValues(c, xla_c));
});
}

TEST_F(AtenXlaTensorTest, TestGe) {
at::Tensor a = at::rand({2, 3}, at::TensorOptions(at::kFloat));
at::Tensor b = a.clone();
Expand Down Expand Up @@ -510,18 +460,6 @@ TEST_F(AtenXlaTensorTest, TestLeInplace) {
});
}

TEST_F(AtenXlaTensorTest, TestThLe) {
at::Tensor a = at::rand({2, 3}, at::TensorOptions(at::kFloat));
at::Tensor b = a.clone();
at::Tensor c = at::_th_le(a, b);
ForEachDevice([&](const Device& device) {
at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
at::Tensor xla_b = bridge::CreateXlaTensor(b, device);
at::Tensor xla_c = at::_th_le(xla_a, xla_b);
AllClose(c, xla_c);
});
}

TEST_F(AtenXlaTensorTest, TestGt) {
at::Tensor a = at::rand({2, 3}, at::TensorOptions(at::kFloat));
at::Tensor b = at::add(a.clone(), at::ones_like(a));
Expand Down Expand Up @@ -587,17 +525,6 @@ TEST_F(AtenXlaTensorTest, TestNeScalar) {
});
}

TEST_F(AtenXlaTensorTest, TestThNeScalar) {
at::Tensor input = at::ones({2, 3});
at::Scalar other(float(0));
at::Tensor result = at::_th_ne(input, other);
ForEachDevice([&](const Device& device) {
at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
at::Tensor xla_result = at::_th_ne(xla_input, other);
AllClose(result, xla_result);
});
}

TEST_F(AtenXlaTensorTest, TestEqScalar) {
at::Tensor input = at::ones({2, 3});
at::Scalar other(float(1));
Expand Down Expand Up @@ -655,17 +582,6 @@ TEST_F(AtenXlaTensorTest, TestLeScalarInplace) {
});
}

TEST_F(AtenXlaTensorTest, TestThLeScalar) {
at::Tensor input = at::ones({2, 3});
at::Scalar other(float(1));
at::Tensor result = at::_th_le(input, other);
ForEachDevice([&](const Device& device) {
at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
at::Tensor xla_result = at::_th_le(xla_input, other);
AllClose(result, xla_result);
});
}

TEST_F(AtenXlaTensorTest, TestGtScalar) {
at::Tensor input = at::ones({2, 3});
at::Scalar other(float(0.5));
Expand Down Expand Up @@ -4459,21 +4375,21 @@ TEST_F(AtenXlaTensorTest, TestMaxPool3DIncompleteAttributes) {
for (bool ceil_mode : {false, true}) {
// Test dilation through the CPU interop.
for (int dilation = 1; dilation <= 2; ++dilation) {
at::Tensor output =
at::max_pool3d(input, /*kernel_size=*/{kernel_size},
/*stride=*/{},
/*padding=*/{padding},
/*dilation=*/{dilation, dilation, dilation},
/*ceil_mode=*/ceil_mode);
at::Tensor output = at::max_pool3d(
input, /*kernel_size=*/{kernel_size, kernel_size, kernel_size},
/*stride=*/{},
/*padding=*/{padding},
/*dilation=*/{dilation, dilation, dilation},
/*ceil_mode=*/ceil_mode);
ForEachDevice([&](const Device& device) {
at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
at::Tensor xla_output =
at::max_pool3d(xla_input,
/*kernel_size=*/{kernel_size},
/*stride=*/{},
/*padding=*/{padding},
/*dilation=*/{dilation, dilation, dilation},
/*ceil_mode=*/ceil_mode);
at::Tensor xla_output = at::max_pool3d(
xla_input,
/*kernel_size=*/{kernel_size, kernel_size, kernel_size},
/*stride=*/{},
/*padding=*/{padding},
/*dilation=*/{dilation, dilation, dilation},
/*ceil_mode=*/ceil_mode);
AllClose(output, xla_output);
});
}
Expand Down Expand Up @@ -5928,56 +5844,8 @@ TEST_F(AtenXlaTensorTest, TestBitwiseXorScalarInPlace) {
});
}

TEST_F(AtenXlaTensorTest, TestBitwiseAndAutograd) {
at::Tensor lhs = at::randint(0, std::numeric_limits<int32_t>::max(), {4, 2},
at::TensorOptions(at::kInt));
at::Tensor rhs = at::randint(0, std::numeric_limits<int32_t>::max(), {4, 2},
at::TensorOptions(at::kInt));
at::Tensor result = at::legacy::th::__and__(lhs, rhs);
ForEachDevice([&](const Device& device) {
at::Tensor xla_lhs = torch::autograd::make_variable(
bridge::CreateXlaTensor(lhs, device), false);
at::Tensor xla_rhs = torch::autograd::make_variable(
bridge::CreateXlaTensor(rhs, device), false);
at::Tensor xla_result = at::legacy::th::__and__(xla_lhs, xla_rhs);
AllClose(result, xla_result);
});
}

TEST_F(AtenXlaTensorTest, TestBitwiseOrAutograd) {
at::Tensor lhs = at::randint(0, std::numeric_limits<int32_t>::max(), {4, 2},
at::TensorOptions(at::kInt));
at::Tensor rhs = at::randint(0, std::numeric_limits<int32_t>::max(), {4, 2},
at::TensorOptions(at::kInt));
at::Tensor result = at::legacy::th::__or__(lhs, rhs);
ForEachDevice([&](const Device& device) {
at::Tensor xla_lhs = torch::autograd::make_variable(
bridge::CreateXlaTensor(lhs, device), false);
at::Tensor xla_rhs = torch::autograd::make_variable(
bridge::CreateXlaTensor(rhs, device), false);
at::Tensor xla_result = at::legacy::th::__or__(xla_lhs, xla_rhs);
AllClose(result, xla_result);
});
}

TEST_F(AtenXlaTensorTest, TestBitwiseXorAutograd) {
at::Tensor lhs = at::randint(0, std::numeric_limits<int32_t>::max(), {4, 2},
at::TensorOptions(at::kInt));
at::Tensor rhs = at::randint(0, std::numeric_limits<int32_t>::max(), {4, 2},
at::TensorOptions(at::kInt));
at::Tensor result = at::legacy::th::__xor__(lhs, rhs);
ForEachDevice([&](const Device& device) {
at::Tensor xla_lhs = torch::autograd::make_variable(
bridge::CreateXlaTensor(lhs, device), false);
at::Tensor xla_rhs = torch::autograd::make_variable(
bridge::CreateXlaTensor(rhs, device), false);
at::Tensor xla_result = at::legacy::th::__xor__(xla_lhs, xla_rhs);
AllClose(result, xla_result);
});
}

TEST_F(AtenXlaTensorTest, TestLshift) {
at::Tensor input = at::randn({4, 2}, at::TensorOptions(at::kFloat));
at::Tensor input = at::ones({4, 2}, at::TensorOptions(at::kFloat));
at::Tensor shift_amount = at::randint(16, input.sizes());
at::Tensor result = at::__lshift__(input, shift_amount);
ForEachDevice([&](const Device& device) {
Expand All @@ -5989,7 +5857,7 @@ TEST_F(AtenXlaTensorTest, TestLshift) {
}

TEST_F(AtenXlaTensorTest, TestLshiftInPlace) {
at::Tensor input = at::randn({4, 2}, at::TensorOptions(at::kFloat));
at::Tensor input = at::ones({4, 2}, at::TensorOptions(at::kFloat));
ForEachDevice([&](const Device& device) {
at::Tensor xla_input = bridge::CreateXlaTensor(input.clone(), device);
at::Tensor shift_amount = at::randint(16, input.sizes());
Expand All @@ -6002,7 +5870,7 @@ TEST_F(AtenXlaTensorTest, TestLshiftInPlace) {
}

TEST_F(AtenXlaTensorTest, TestLshiftScalar) {
at::Tensor input = at::randn({4, 2}, at::TensorOptions(at::kFloat));
at::Tensor input = at::ones({4, 2}, at::TensorOptions(at::kFloat));
at::Scalar shift_amount = 3;
at::Tensor result = at::__lshift__(input, shift_amount);
ForEachDevice([&](const Device& device) {
Expand All @@ -6013,7 +5881,7 @@ TEST_F(AtenXlaTensorTest, TestLshiftScalar) {
}

TEST_F(AtenXlaTensorTest, TestLshiftScalarInPlace) {
at::Tensor input = at::randn({4, 2}, at::TensorOptions(at::kFloat));
at::Tensor input = at::ones({4, 2}, at::TensorOptions(at::kFloat));
at::Scalar shift_amount = 3;
ForEachDevice([&](const Device& device) {
at::Tensor xla_input = bridge::CreateXlaTensor(input.clone(), device);
Expand All @@ -6025,7 +5893,7 @@ TEST_F(AtenXlaTensorTest, TestLshiftScalarInPlace) {
}

TEST_F(AtenXlaTensorTest, TestRshift) {
at::Tensor input = at::randn({4, 2}, at::TensorOptions(at::kFloat));
at::Tensor input = at::ones({4, 2}, at::TensorOptions(at::kFloat));
at::Tensor shift_amount = at::randint(16, input.sizes());
at::Tensor result = at::__rshift__(input, shift_amount);
ForEachDevice([&](const Device& device) {
Expand All @@ -6037,7 +5905,7 @@ TEST_F(AtenXlaTensorTest, TestRshift) {
}

TEST_F(AtenXlaTensorTest, TestRshiftInPlace) {
at::Tensor input = at::randn({4, 2}, at::TensorOptions(at::kFloat));
at::Tensor input = at::ones({4, 2}, at::TensorOptions(at::kFloat));
ForEachDevice([&](const Device& device) {
at::Tensor xla_input = bridge::CreateXlaTensor(input.clone(), device);
at::Tensor shift_amount = at::randint(16, input.sizes());
Expand All @@ -6050,7 +5918,7 @@ TEST_F(AtenXlaTensorTest, TestRshiftInPlace) {
}

TEST_F(AtenXlaTensorTest, TestRshiftScalar) {
at::Tensor input = at::randn({4, 2}, at::TensorOptions(at::kFloat));
at::Tensor input = at::ones({4, 2}, at::TensorOptions(at::kFloat));
at::Scalar shift_amount = 3;
at::Tensor result = at::__rshift__(input, shift_amount);
ForEachDevice([&](const Device& device) {
Expand All @@ -6061,7 +5929,7 @@ TEST_F(AtenXlaTensorTest, TestRshiftScalar) {
}

TEST_F(AtenXlaTensorTest, TestRshiftScalarInPlace) {
at::Tensor input = at::randn({4, 2}, at::TensorOptions(at::kFloat));
at::Tensor input = at::ones({4, 2}, at::TensorOptions(at::kFloat));
at::Scalar shift_amount = 3;
ForEachDevice([&](const Device& device) {
at::Tensor xla_input = bridge::CreateXlaTensor(input.clone(), device);
Expand Down
Loading