1 change: 1 addition & 0 deletions test/cpp/CMakeLists.txt
@@ -53,6 +53,7 @@ set(TORCH_XLA_TEST_SOURCES
test_ir.cpp
test_mayberef.cpp
test_tensor.cpp
+ torch_xla_test.cpp
)

add_executable(test_ptxla ${TORCH_XLA_TEST_SOURCES})
33 changes: 18 additions & 15 deletions test/cpp/test_tensor.cpp
@@ -6,11 +6,14 @@
#include "cpp_test_util.h"
#include "tensor.h"
#include "torch/csrc/autograd/variable.h"
#include "torch_xla_test.h"

namespace torch_xla {
namespace cpp_test {

-TEST(TensorTest, TestAdd) {
+using TensorTest = TorchXlaTest;
+
+TEST_F(TensorTest, TestAdd) {
at::Tensor a = at::rand({2, 2}, at::TensorOptions(at::kFloat));
at::Tensor b = at::rand({2, 2}, at::TensorOptions(at::kFloat));
auto c = a.add(b, 1.0);
@@ -24,7 +27,7 @@ TEST(TensorTest, TestAdd) {
});
}

-TEST(TensorTest, TestIntegerAdd) {
+TEST_F(TensorTest, TestIntegerAdd) {
std::vector<at::ScalarType> types(
{at::kByte, at::kChar, at::kShort, at::kInt, at::kLong});

@@ -43,7 +46,7 @@ TEST(TensorTest, TestIntegerAdd) {
});
}

-TEST(TensorTest, TestSize) {
+TEST_F(TensorTest, TestSize) {
at::Tensor input = at::rand({2, 1, 4, 6}, at::TensorOptions(at::kFloat));
int rank = input.dim();
ForEachDevice([&](const Device& device) {
@@ -54,7 +57,7 @@ TEST(TensorTest, TestSize) {
});
}

-TEST(TensorTest, TestRelu) {
+TEST_F(TensorTest, TestRelu) {
at::Tensor input = at::rand({2, 1, 4, 6}, at::TensorOptions(at::kFloat));
auto output = input.relu();
ForEachDevice([&](const Device& device) {
@@ -64,7 +67,7 @@ TEST(TensorTest, TestRelu) {
});
}

-TEST(TensorTest, TestThreshold) {
+TEST_F(TensorTest, TestThreshold) {
at::Tensor input = at::rand({2, 1, 4, 6}, at::TensorOptions(at::kFloat));
float threshold = 0.4;
float value = 20;
@@ -76,7 +79,7 @@ TEST(TensorTest, TestThreshold) {
});
}

-TEST(TensorTest, TestAddMatMul) {
+TEST_F(TensorTest, TestAddMatMul) {
int in_channels = 32;
int out_channels = 320;
int labels = 50;
@@ -97,7 +100,7 @@ TEST(TensorTest, TestAddMatMul) {
});
}

-TEST(TensorTest, TestTranspose) {
+TEST_F(TensorTest, TestTranspose) {
at::Tensor input = at::rand({2, 3}, at::TensorOptions(at::kFloat));
auto output = input.t();
ForEachDevice([&](const Device& device) {
@@ -107,7 +110,7 @@ TEST(TensorTest, TestTranspose) {
});
}

-TEST(TensorTest, TestView) {
+TEST_F(TensorTest, TestView) {
at::Tensor input = at::rand({32, 20, 4, 4}, at::TensorOptions(at::kFloat));
auto output = input.view({-1, 320});
ForEachDevice([&](const Device& device) {
@@ -117,7 +120,7 @@ TEST(TensorTest, TestView) {
});
}

-TEST(TensorTest, TestLogSoftmax) {
+TEST_F(TensorTest, TestLogSoftmax) {
at::Tensor input = at::rand({5, 3, 4, 2}, at::TensorOptions(at::kFloat));
ForEachDevice([&](const Device& device) {
auto dev_input = XLATensor::Create(input, device, /*requires_grad=*/false);
@@ -129,7 +132,7 @@ TEST(TensorTest, TestLogSoftmax) {
});
}

-TEST(TensorTest, TestMaxPool2D) {
+TEST_F(TensorTest, TestMaxPool2D) {
at::Tensor input = at::rand({1, 64, 112, 112}, at::TensorOptions(at::kFloat));
int kernel_size = 3;
for (int stride = 1; stride <= 2; ++stride) {
@@ -152,7 +155,7 @@ TEST(TensorTest, TestMaxPool2D) {
}
}

-TEST(TensorTest, TestMaxPool2DNonSquare) {
+TEST_F(TensorTest, TestMaxPool2DNonSquare) {
at::Tensor input = at::rand({1, 64, 112, 112}, at::TensorOptions(at::kFloat));
int kernel_size = 4;
for (int stride = 1; stride <= 2; ++stride) {
@@ -175,7 +178,7 @@ TEST(TensorTest, TestMaxPool2DNonSquare) {
}
}

-TEST(TensorTest, TestAvgPool2D) {
+TEST_F(TensorTest, TestAvgPool2D) {
at::Tensor input = at::rand({4, 1, 28, 28}, at::TensorOptions(at::kFloat));
int kernel_size = 2;
for (int stride = 1; stride <= 2; ++stride) {
@@ -200,7 +203,7 @@ TEST(TensorTest, TestAvgPool2D) {
}
}

-TEST(TensorTest, TestAvgPool2DNonSquare) {
+TEST_F(TensorTest, TestAvgPool2DNonSquare) {
at::Tensor input = at::rand({4, 1, 28, 28}, at::TensorOptions(at::kFloat));
int kernel_size = 4;
for (int stride = 1; stride <= 2; ++stride) {
@@ -226,7 +229,7 @@ TEST(TensorTest, TestAvgPool2DNonSquare) {
}
}

-TEST(TensorTest, TestConv2D) {
+TEST_F(TensorTest, TestConv2D) {
int in_channels = 3;
int out_channels = 7;
int kernel_size = 5;
@@ -270,7 +273,7 @@ TEST(TensorTest, TestConv2D) {
}
}

-TEST(TensorTest, TestConv2DNonSquare) {
+TEST_F(TensorTest, TestConv2DNonSquare) {
int in_channels = 3;
int out_channels = 7;
int kernel_size = 5;
10 changes: 10 additions & 0 deletions test/cpp/torch_xla_test.cpp
@@ -0,0 +1,10 @@
#include "torch_xla_test.h"
#include <ATen/ATen.h>

namespace torch_xla {
namespace cpp_test {

void TorchXlaTest::SetUp() { at::manual_seed(42); }

} // namespace cpp_test
} // namespace torch_xla
14 changes: 14 additions & 0 deletions test/cpp/torch_xla_test.h
@@ -0,0 +1,14 @@
#pragma once

#include <gtest/gtest.h>

namespace torch_xla {
namespace cpp_test {

class TorchXlaTest : public ::testing::Test {
protected:
void SetUp() override;
};

} // namespace cpp_test
} // namespace torch_xla
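
With the fixture in place, every test that derives from TorchXlaTest starts from the same RNG state (at::manual_seed(42) in SetUp), so the at::rand inputs used in test_tensor.cpp are reproducible from run to run. A minimal sketch of how another C++ test file could opt in, following the same pattern as test_tensor.cpp above; the file name, fixture alias, and test body below are hypothetical illustrations, not part of this PR:

// Hypothetical test/cpp/test_example.cpp -- illustration only, not in the PR.
#include <ATen/ATen.h>
#include <gtest/gtest.h>

#include "torch_xla_test.h"

namespace torch_xla {
namespace cpp_test {

// Reuse the fixture under a local name, as test_tensor.cpp does above.
using ExampleTest = TorchXlaTest;

TEST_F(ExampleTest, SeededRand) {
  // SetUp() has already called at::manual_seed(42), so these draws come from
  // the same generator state on every run of the test binary.
  at::Tensor a = at::rand({2, 2}, at::TensorOptions(at::kFloat));
  at::Tensor b = at::rand({2, 2}, at::TensorOptions(at::kFloat));
  EXPECT_EQ(a.dim(), 2);
  EXPECT_FALSE(a.equal(b));  // consecutive draws still differ within a run
}

}  // namespace cpp_test
}  // namespace torch_xla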