Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA] Add disable option to 2 XLA CC tests #30217

Merged
merged 3 commits into from Jul 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 4 additions & 3 deletions tensorflow/compiler/xla/tests/fmax_fmin_test.cc
Expand Up @@ -19,14 +19,15 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/platform/test.h"

namespace xla {
namespace {

class FmaxSimpleTest : public ClientLibraryTestBase {};

TEST_F(FmaxSimpleTest, FmaxTenValues) {
XLA_TEST_F(FmaxSimpleTest, FmaxTenValues) {
SetFastMathDisabled(true);
XlaBuilder builder(TestName());
auto x = ConstantR1<float>(
Expand All @@ -40,7 +41,7 @@ TEST_F(FmaxSimpleTest, FmaxTenValues) {
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}

TEST_F(FmaxSimpleTest, FmaxEdgeCases) {
XLA_TEST_F(FmaxSimpleTest, FmaxEdgeCases) {
SetFastMathDisabled(true);
XlaBuilder builder(TestName());
XlaOp param0, param1;
Expand All @@ -62,7 +63,7 @@ TEST_F(FmaxSimpleTest, FmaxEdgeCases) {
ErrorSpec(0.0001));
}

TEST_F(FmaxSimpleTest, FminEdgeCases) {
XLA_TEST_F(FmaxSimpleTest, FminEdgeCases) {
SetFastMathDisabled(true);
XlaBuilder builder(TestName());
XlaOp param0, param1;
Expand Down
40 changes: 20 additions & 20 deletions tensorflow/compiler/xla/tests/reduce_window_test.cc
Expand Up @@ -105,7 +105,7 @@ class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
XlaBuilder builder_;
};

TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) {
XLA_TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) {
const auto input = CreateConstantFromLiteral(
LiteralUtil::CreateR1<float>({1, 1, 1, 1}), &builder_);
const auto init_value =
Expand All @@ -122,7 +122,7 @@ TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) {
}

// Regression test for b/68964348.
TEST_P(ReduceWindowTest, R0ReduceWindow) {
XLA_TEST_P(ReduceWindowTest, R0ReduceWindow) {
const auto input =
CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(42.0), &builder_);
const auto init =
Expand All @@ -134,15 +134,15 @@ TEST_P(ReduceWindowTest, R0ReduceWindow) {
ErrorSpec(0.00001));
}

TEST_P(ReduceWindowTest, Min3In5Stride2) {
XLA_TEST_P(ReduceWindowTest, Min3In5Stride2) {
const auto input = CreateConstantFromLiteral(
LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
ReduceWindowMin(input, {3}, {2}, Padding::kValid);
ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>({100, 1}),
{}, ErrorSpec(0.00001));
}

TEST_P(ReduceWindowTest, Min3In5Stride1WithSamePadding) {
XLA_TEST_P(ReduceWindowTest, Min3In5Stride1WithSamePadding) {
const auto input = CreateConstantFromLiteral(
LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
ReduceWindowMin(input, /*window_dimensions=*/{3}, /*window_strides=*/{1},
Expand All @@ -165,7 +165,7 @@ XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) {
DefaultErrorSpec());
}

TEST_P(ReduceWindowTest, NonSquareSmall) {
XLA_TEST_P(ReduceWindowTest, NonSquareSmall) {
Array4D<float> input_array(1, 2, 2, 1);
input_array.FillRandom(2.f, 2.f);
const auto input = CreateConstantFromArray(input_array, &builder_);
Expand All @@ -180,7 +180,7 @@ TEST_P(ReduceWindowTest, NonSquareSmall) {
DefaultErrorSpec());
}

TEST_P(ReduceWindowTest, MiddleDimsSmall) {
XLA_TEST_P(ReduceWindowTest, MiddleDimsSmall) {
Array4D<float> input_array(1, 3, 3, 1);
input_array.FillRandom(2.f, 2.f);
const auto input = CreateConstantFromArray(input_array, &builder_);
Expand All @@ -194,7 +194,7 @@ TEST_P(ReduceWindowTest, MiddleDimsSmall) {
DefaultErrorSpec());
}

TEST_P(ReduceWindowTest, Along2ndMinorDim) {
XLA_TEST_P(ReduceWindowTest, Along2ndMinorDim) {
Array4D<float> input_array(3, 6, 7, 32);
input_array.FillRandom(2.f, 2.f);
const auto input = CreateConstantFromArray(input_array, &builder_);
Expand All @@ -211,7 +211,7 @@ TEST_P(ReduceWindowTest, Along2ndMinorDim) {
DefaultErrorSpec());
}

TEST_P(ReduceWindowTest, AmongMajor2DimsAdd) {
XLA_TEST_P(ReduceWindowTest, AmongMajor2DimsAdd) {
Array4D<float> input_array(4, 4, 6, 8);
input_array.FillWithMinorDimNum();
const auto input_data_handle =
Expand All @@ -233,7 +233,7 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsAdd) {
DefaultErrorSpec());
}

TEST_P(ReduceWindowTest, AmongMajor2DimsMax) {
XLA_TEST_P(ReduceWindowTest, AmongMajor2DimsMax) {
Array4D<float> input_array(3, 3, 2, 1);
input_array.FillWithMinorDimNum();
const auto input_data_handle =
Expand All @@ -247,7 +247,7 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMax) {
ComputeAndCompare(&builder_, {}, DefaultErrorSpec());
}

TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) {
XLA_TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) {
Array4D<float> input_array(9, 12, 4, 89);
input_array.FillRandom(2.f, 2.f);

Expand All @@ -272,7 +272,7 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) {

// Tests the super windowing logic w.r.t handling prime number of windows in a
// major dimension with reduction.
TEST_P(ReduceWindowTest, PrimeWindowsInReductionDimension) {
XLA_TEST_P(ReduceWindowTest, PrimeWindowsInReductionDimension) {
Array4D<float> input_array(15, 15, 4, 128);
input_array.FillRandom(2.f, 4.f);

Expand All @@ -295,7 +295,7 @@ TEST_P(ReduceWindowTest, PrimeWindowsInReductionDimension) {
DefaultErrorSpec());
}

TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) {
XLA_TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) {
Array4D<float> input_array(19, 17, 8, 256);
input_array.FillWithMinorDimNum();

Expand Down Expand Up @@ -350,7 +350,7 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) {
{}, DefaultErrorSpec());
}

TEST_P(ReduceWindowTest, R4UnitWindow) {
XLA_TEST_P(ReduceWindowTest, R4UnitWindow) {
Array4D<float> input_array(13, 12, 8, 15);
input_array.FillRandom(2.f, 2.f);
Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
Expand Down Expand Up @@ -471,7 +471,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) {
{input_data.get()}, DefaultErrorSpec());
}

TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) {
XLA_TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) {
Array4D<float> input_array(6, 4, 10, 130);
input_array.FillRandom(2.0f);

Expand Down Expand Up @@ -538,7 +538,7 @@ XLA_TEST_P(ReduceWindowTest, Add128In128) {
}

// Regression test for a bug that appeared in Inception (b/34784899).
TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) {
XLA_TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) {
Array2D<float> input_array(14, 14, 1.0f);
const auto input = CreateConstantFromArray(input_array, &builder_);
int win_len = 3;
Expand All @@ -548,7 +548,7 @@ TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) {
ComputeAndCompare(&builder_, {}, DefaultErrorSpec());
}

TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) {
XLA_TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) {
Array2D<float> input_array(6, 4, 1.0f);
XlaOp input = Broadcast(
CreateConstantFromLiteral(LiteralUtil::One(F32), &builder_), {6, 4});
Expand Down Expand Up @@ -668,7 +668,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
}
};

TEST_P(R4ReduceWindowTest, DoIt) { DoIt(); }
XLA_TEST_P(R4ReduceWindowTest, DoIt) { DoIt(); }

// base_bounds, window_bounds, strides, pad_low, pad_high
const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
Expand Down Expand Up @@ -1009,7 +1009,7 @@ class R3ReduceWindowTest : public ReduceWindowTestBase,
R3ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
};

TEST_P(R3ReduceWindowTest, DoIt) {
XLA_TEST_P(R3ReduceWindowTest, DoIt) {
XlaBuilder b(TestName());
const auto& param = ::testing::get<0>(GetParam());

Expand Down Expand Up @@ -1261,7 +1261,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
}
};

TEST_P(R2ReduceWindowTest, DoIt) { DoIt(); }
XLA_TEST_P(R2ReduceWindowTest, DoIt) { DoIt(); }

INSTANTIATE_TEST_CASE_P(
R2ReduceWindowTestInstantiation, R2ReduceWindowTest,
Expand Down Expand Up @@ -1409,7 +1409,7 @@ class R1ReduceWindowTest : public ReduceWindowTestBase,
R1ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
};

TEST_P(R1ReduceWindowTest, DoIt) {
XLA_TEST_P(R1ReduceWindowTest, DoIt) {
XlaBuilder b(TestName());
const auto& param = ::testing::get<0>(GetParam());
CHECK(param.reducer == kAdd || param.reducer == kMax);
Expand Down