102 changes: 98 additions & 4 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -960,6 +960,16 @@ TEST_F(AtenXlaTensorTest, TestMean) {
});
}

TEST_F(AtenXlaTensorTest, TestMeanCast) {
at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
at::Tensor b = at::mean(a, at::kDouble);
ForEachDevice([&](const Device& device) {
at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
at::Tensor xla_b = at::mean(xla_a, at::kDouble);
AllClose(b, xla_b);
});
}

TEST_F(AtenXlaTensorTest, TestMeanInDim) {
at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
int rank = a.dim();
@@ -985,6 +995,18 @@ TEST_F(AtenXlaTensorTest, TestMeanInDims) {
}
}

TEST_F(AtenXlaTensorTest, TestMeanInDimsKeepCast) {
at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
for (auto dims : std::vector<std::vector<int64_t>>{{0, 1}, {-3, -2}}) {
    at::Tensor b = at::mean(a, dims, /*keepdim=*/true, at::kDouble);
ForEachDevice([&](const Device& device) {
at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
      at::Tensor xla_b = at::mean(xla_a, dims, /*keepdim=*/true, at::kDouble);
AllClose(b, xla_b);
});
}
}

TEST_F(AtenXlaTensorTest, TestSum) {
at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
at::Tensor b = at::sum(a);
@@ -995,6 +1017,16 @@ TEST_F(AtenXlaTensorTest, TestSum) {
});
}

TEST_F(AtenXlaTensorTest, TestSumCast) {
at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
at::Tensor b = at::sum(a, at::kDouble);
ForEachDevice([&](const Device& device) {
at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
at::Tensor xla_b = at::sum(xla_a, at::kDouble);
AllClose(b, xla_b);
});
}

TEST_F(AtenXlaTensorTest, TestSumU8) {
at::Tensor a = at::ones({256}, at::TensorOptions(at::kByte));
at::Tensor b = at::sum(a);
@@ -1042,6 +1074,18 @@ TEST_F(AtenXlaTensorTest, TestSumInDimsKeep) {
}
}

TEST_F(AtenXlaTensorTest, TestSumInDimsKeepCast) {
at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
for (auto dims : std::vector<std::vector<int64_t>>{{0, 1}, {-3, -2}}) {
at::Tensor b = at::sum(a, dims, /*keepdim=*/true, at::kDouble);
ForEachDevice([&](const Device& device) {
at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
at::Tensor xla_b = at::sum(xla_a, dims, /*keepdim=*/true, at::kDouble);
AllClose(b, xla_b);
});
}
}

TEST_F(AtenXlaTensorTest, TestMaxInDim) {
at::Tensor input = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
int rank = input.dim();
@@ -1435,6 +1479,16 @@ TEST_F(AtenXlaTensorTest, TestProd) {
});
}

TEST_F(AtenXlaTensorTest, TestProdCast) {
at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
at::Tensor b = at::prod(a, at::kDouble);
ForEachDevice([&](const Device& device) {
at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
at::Tensor xla_b = at::prod(xla_a, at::kDouble);
AllClose(b, xla_b);
});
}

TEST_F(AtenXlaTensorTest, TestProdInDim) {
at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
int rank = a.dim();
@@ -1448,6 +1502,19 @@ TEST_F(AtenXlaTensorTest, TestProdInDim) {
}
}

TEST_F(AtenXlaTensorTest, TestProdInDimKeepCast) {
at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
int rank = a.dim();
for (int dim = -rank; dim < rank; ++dim) {
at::Tensor b = at::prod(a, dim, /*keepdim=*/true, at::kDouble);
ForEachDevice([&](const Device& device) {
at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
at::Tensor xla_b = at::prod(xla_a, dim, /*keepdim=*/true, at::kDouble);
AllClose(b, xla_b);
});
}
}

TEST_F(AtenXlaTensorTest, TestProdInDimKeep) {
at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
int rank = a.dim();
@@ -1478,10 +1545,11 @@ TEST_F(AtenXlaTensorTest, TestCumSumCast) {
at::Tensor input = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
int rank = input.dim();
for (int dim = -rank; dim < rank; ++dim) {
-    at::Tensor result = at::cumsum(input, dim, at::ScalarType::Int);
+    at::Tensor result = at::cumsum(input, dim, at::kDouble);
    ForEachDevice([&](const Device& device) {
      at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
-      at::Tensor xla_result = at::cumsum(xla_input, dim, at::ScalarType::Int);
+      at::Tensor xla_result = at::cumsum(xla_input, dim, at::kDouble);
EXPECT_TRUE(EqualValues(result, xla_result));
});
}
@@ -1531,10 +1599,10 @@ TEST_F(AtenXlaTensorTest, TestCumProdCast) {
at::mul(at::rand({4, 3, 4}, at::TensorOptions(at::kFloat)), 10);
int rank = input.dim();
for (int dim = -rank; dim < rank; ++dim) {
-    at::Tensor result = at::cumprod(input, dim, at::ScalarType::Int);
+    at::Tensor result = at::cumprod(input, dim, at::kDouble);
    ForEachDevice([&](const Device& device) {
      at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
-      at::Tensor xla_result = at::cumprod(xla_input, dim, at::ScalarType::Int);
+      at::Tensor xla_result = at::cumprod(xla_input, dim, at::kDouble);
EXPECT_TRUE(EqualValues(result, xla_result));
});
}
@@ -4307,6 +4375,19 @@ TEST_F(AtenXlaTensorTest, TestLogSoftmax) {
});
}

TEST_F(AtenXlaTensorTest, TestLogSoftmaxCast) {
at::Tensor input = at::rand({5, 3, 4, 2}, at::TensorOptions(at::kFloat));
ForEachDevice([&](const Device& device) {
at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
int rank = input.dim();
for (int dim = -rank; dim < rank; ++dim) {
at::Tensor output = at::log_softmax(input, dim, at::kDouble);
at::Tensor xla_output = at::log_softmax(xla_input, dim, at::kDouble);
AllClose(output, xla_output, /*rtol=*/1e-3);
}
});
}

TEST_F(AtenXlaTensorTest, TestSoftmax) {
at::Tensor input = at::rand({10, 8, 24, 16}, at::TensorOptions(at::kFloat));
ForEachDevice([&](const Device& device) {
@@ -4320,6 +4401,19 @@ TEST_F(AtenXlaTensorTest, TestSoftmax) {
});
}

TEST_F(AtenXlaTensorTest, TestSoftmaxCast) {
at::Tensor input = at::rand({10, 8, 24, 16}, at::TensorOptions(at::kFloat));
ForEachDevice([&](const Device& device) {
at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
int rank = input.dim();
for (int dim = -rank; dim < rank; ++dim) {
at::Tensor output = at::softmax(input, dim, at::kDouble);
at::Tensor xla_output = at::softmax(xla_input, dim, at::kDouble);
AllClose(output, xla_output, /*rtol=*/1e-3);
}
});
}

TEST_F(AtenXlaTensorTest, TestSoftmaxWrapper) {
at::Tensor input = at::rand({10, 8, 24, 16}, at::TensorOptions(at::kFloat));
ForEachDevice([&](const Device& device) {
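
The new *Cast tests above all follow the same pattern: call the dtype-taking overload of a reduction on the CPU tensor, mirror the call on the XLA tensor, and compare. Below is a minimal standalone sketch of what these overloads are expected to do, runnable against CPU ATen alone; the cast-then-reduce equivalence is my reading of the overload's semantics, not something this diff itself asserts:

#include <ATen/ATen.h>

int main() {
  at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
  // Dtype overload, as exercised by TestSumCast: result comes back as kDouble.
  at::Tensor direct = at::sum(a, at::kDouble);
  // Assumed-equivalent formulation: cast the input first, then reduce.
  at::Tensor manual = at::sum(a.to(at::kDouble));
  AT_ASSERT(direct.scalar_type() == at::kDouble);
  AT_ASSERT(at::allclose(direct, manual));
  return 0;
}
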
2 changes: 1 addition & 1 deletion test/cpp/test_tensor.cpp
@@ -243,7 +243,7 @@ TEST_F(TensorTest, TestLogSoftmax) {
auto dev_input = XLATensor::Create(input, device);
for (int dim = 0; dim < input.dim(); ++dim) {
auto output = input.log_softmax(dim);
-      auto dev_output = XLATensor::log_softmax(dev_input, dim);
+      auto dev_output = XLATensor::log_softmax(dev_input, dim, c10::nullopt);
AllClose(output, dev_output, /*rtol=*/1e-3);
}
});
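
The added c10::nullopt argument implies that XLATensor::log_softmax grew an optional result-dtype parameter in this PR. The declaration itself is not part of this diff, so the following is only a presumed sketch, with parameter types guessed from the call sites above:

// Presumed new signature (not shown in this diff). Passing c10::nullopt
// keeps the input's dtype, i.e. the pre-change behavior this test relies on;
// passing e.g. at::kDouble would widen the result like the ATen overloads.
static XLATensor log_softmax(const XLATensor& input, int64_t dim,
                             c10::optional<at::ScalarType> dtype);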