285 changes: 285 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -1145,6 +1145,78 @@ TEST_F(AtenXlaTensorTest, TestFrobeniusNormInDims) {
}
}

TEST_F(AtenXlaTensorTest, TestGroupNorm) {
int num_channels = 6;
at::Tensor input =
at::rand({20, num_channels, 10, 10}, at::TensorOptions(at::kFloat));
at::Tensor weight = at::rand({num_channels}, at::TensorOptions(at::kFloat));
at::Tensor bias = at::rand({num_channels}, at::TensorOptions(at::kFloat));
double eps = 1e-05;
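  // num_channels is 6, so every tested group count divides it evenly: three
  // groups of two channels, one group per channel, and a single group.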
for (int num_groups : {3, 6, 1}) {
at::Tensor output = at::group_norm(input, num_groups, weight, bias, eps,
/*cudnn_enabled=*/false);
ForEachDevice([&](const Device& device) {
at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
at::Tensor xla_weight = bridge::CreateXlaTensor(weight, device);
at::Tensor xla_bias = bridge::CreateXlaTensor(bias, device);
at::Tensor xla_output =
at::group_norm(xla_input, num_groups, xla_weight, xla_bias, eps,
/*cudnn_enabled=*/false);
AllClose(output, xla_output, /*rtol=*/1e-3, /*atol=*/1e-5);
});
}
}

TEST_F(AtenXlaTensorTest, TestInstanceNorm) {
int batch = 5;
int num_channels = 20;
at::Tensor input =
at::rand({batch, num_channels, 10, 10}, at::TensorOptions(at::kFloat));
at::Tensor weight = at::rand({num_channels}, at::TensorOptions(at::kFloat));
at::Tensor bias = at::rand({num_channels}, at::TensorOptions(at::kFloat));
at::Tensor running_mean =
at::zeros({num_channels}, at::TensorOptions(at::kFloat));
at::Tensor running_var =
at::ones({num_channels}, at::TensorOptions(at::kFloat));
double momentum = 0.1;
double eps = 1e-05;
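  // use_input_stats=true makes the op normalize with statistics computed from
  // the input rather than with the supplied running_mean/running_var.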
at::Tensor output = at::instance_norm(
input, weight, bias, running_mean, running_var,
/*use_input_stats=*/true, momentum, eps, /*cudnn_enabled=*/false);
ForEachDevice([&](const Device& device) {
at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
at::Tensor xla_weight = bridge::CreateXlaTensor(weight, device);
at::Tensor xla_bias = bridge::CreateXlaTensor(bias, device);
at::Tensor xla_running_mean = bridge::CreateXlaTensor(running_mean, device);
at::Tensor xla_running_var = bridge::CreateXlaTensor(running_var, device);
at::Tensor xla_output = at::instance_norm(
xla_input, xla_weight, xla_bias, xla_running_mean, xla_running_var,
/*use_input_stats=*/true, momentum, eps, /*cudnn_enabled=*/false);
AllClose(output, xla_output, /*rtol=*/1e-3, /*atol=*/1e-5);
});
}

TEST_F(AtenXlaTensorTest, TestLayerNorm) {
int num_channels = 5;
std::vector<int64_t> normalized_shape = {10, 10};
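  // Normalization is applied over the trailing {10, 10} dimensions, so weight
  // and bias share that shape.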
at::Tensor input =
at::rand({20, num_channels, 10, 10}, at::TensorOptions(at::kFloat));
at::Tensor weight = at::rand(normalized_shape, at::TensorOptions(at::kFloat));
at::Tensor bias = at::rand(normalized_shape, at::TensorOptions(at::kFloat));
double eps = 1e-05;
at::Tensor output = at::layer_norm(input, normalized_shape, weight, bias, eps,
/*cudnn_enabled=*/false);
ForEachDevice([&](const Device& device) {
at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
at::Tensor xla_weight = bridge::CreateXlaTensor(weight, device);
at::Tensor xla_bias = bridge::CreateXlaTensor(bias, device);
at::Tensor xla_output =
at::layer_norm(xla_input, normalized_shape, xla_weight, xla_bias, eps,
/*cudnn_enabled=*/false);
AllClose(output, xla_output, /*rtol=*/1e-3, /*atol=*/1e-5);
});
}

TEST_F(AtenXlaTensorTest, TestNuclearNorm) {
at::Tensor a = at::rand({4, 3}, at::TensorOptions(at::kFloat));
at::Tensor b = at::nuclear_norm(a);
@@ -1157,6 +1229,171 @@ TEST_F(AtenXlaTensorTest, TestNuclearNorm) {
}
}

TEST_F(AtenXlaTensorTest, TestPairwiseDistance) {
at::Tensor x1 = at::rand({4, 3}, at::TensorOptions(at::kFloat));
at::Tensor x2 = at::rand({4, 3}, at::TensorOptions(at::kFloat));
double eps = 1e-6;
for (bool keepdim : {false, true}) {
for (double p : {1, 2, 3, 4}) {
ForEachDevice([&](const Device& device) {
at::Tensor output = at::pairwise_distance(x1, x2, p, eps, keepdim);
at::Tensor xla_x1 = bridge::CreateXlaTensor(x1, device);
at::Tensor xla_x2 = bridge::CreateXlaTensor(x2, device);
at::Tensor xla_output =
at::pairwise_distance(xla_x1, xla_x2, p, eps, keepdim);
AllClose(output, xla_output);
});
}
}
}

TEST_F(AtenXlaTensorTest, TestCosineSimilarity) {
at::Tensor x1 = at::rand({4, 3}, at::TensorOptions(at::kFloat));
at::Tensor x2 = at::rand({4, 3}, at::TensorOptions(at::kFloat));
double eps = 1e-8;
int rank = x1.dim();
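  // Cover every valid dimension index, including the negative aliases.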
for (int dim = -rank; dim < rank; ++dim) {
ForEachDevice([&](const Device& device) {
at::Tensor output = at::cosine_similarity(x1, x2, dim, eps);
at::Tensor xla_x1 = bridge::CreateXlaTensor(x1, device);
at::Tensor xla_x2 = bridge::CreateXlaTensor(x2, device);
at::Tensor xla_output = at::cosine_similarity(xla_x1, xla_x2, dim, eps);
AllClose(output, xla_output);
});
}
}

TEST_F(AtenXlaTensorTest, TestCosineEmbeddingLoss) {
at::Tensor input1 = at::rand({4, 3}, at::TensorOptions(at::kFloat));
at::Tensor input2 = at::rand({4, 3}, at::TensorOptions(at::kFloat));
at::Tensor target = at::rand({4}, at::TensorOptions(at::kFloat));
for (Reduction::Reduction reduction : {Reduction::Mean, Reduction::Sum}) {
for (double margin : {0., 0.2}) {
ForEachDevice([&](const Device& device) {
at::Tensor output = at::cosine_embedding_loss(input1, input2, target,
margin, reduction);
at::Tensor xla_input1 = bridge::CreateXlaTensor(input1, device);
at::Tensor xla_input2 = bridge::CreateXlaTensor(input2, device);
at::Tensor xla_target = bridge::CreateXlaTensor(target, device);
at::Tensor xla_output = at::cosine_embedding_loss(
xla_input1, xla_input2, xla_target, margin, reduction);
AllClose(output, xla_output);
});
}
}
}

TEST_F(AtenXlaTensorTest, TestHingeEmbeddingLoss) {
at::Tensor input = at::rand({4, 3}, at::TensorOptions(at::kFloat));
at::Tensor target = at::rand({4, 3}, at::TensorOptions(at::kFloat));
for (Reduction::Reduction reduction : {Reduction::Mean, Reduction::Sum}) {
for (double margin : {0., 0.2}) {
ForEachDevice([&](const Device& device) {
at::Tensor output =
at::hinge_embedding_loss(input, target, margin, reduction);
at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
at::Tensor xla_target = bridge::CreateXlaTensor(target, device);
at::Tensor xla_output =
at::hinge_embedding_loss(xla_input, xla_target, margin, reduction);
AllClose(output, xla_output);
});
}
}
}

TEST_F(AtenXlaTensorTest, TestTripletMarginLoss) {
at::Tensor anchor = at::rand({4, 3}, at::TensorOptions(at::kFloat));
at::Tensor positive =
at::abs(at::rand({4, 3}, at::TensorOptions(at::kFloat)));
at::Tensor negative =
at::neg(at::abs(at::rand({4, 3}, at::TensorOptions(at::kFloat))));
double eps = 1e-6;
for (double margin : {0., 0.2}) {
for (double p : {1, 2, 3, 4}) {
for (bool swap : {false, true}) {
for (Reduction::Reduction reduction :
{Reduction::Mean, Reduction::Sum}) {
ForEachDevice([&](const Device& device) {
at::Tensor output = at::triplet_margin_loss(
anchor, positive, negative, margin, p, eps, swap, reduction);
at::Tensor xla_anchor = bridge::CreateXlaTensor(anchor, device);
at::Tensor xla_positive = bridge::CreateXlaTensor(positive, device);
at::Tensor xla_negative = bridge::CreateXlaTensor(negative, device);
at::Tensor xla_output =
at::triplet_margin_loss(xla_anchor, xla_positive, xla_negative,
margin, p, eps, swap, reduction);
AllClose(output, xla_output);
});
}
}
}
}
}

TEST_F(AtenXlaTensorTest, TestMarginRankingLoss) {
at::Tensor input1 = at::rand({4, 3}, at::TensorOptions(at::kFloat));
at::Tensor input2 = at::rand({4, 3}, at::TensorOptions(at::kFloat));
at::Tensor target = at::rand({4, 3}, at::TensorOptions(at::kFloat));
for (Reduction::Reduction reduction : {Reduction::Mean, Reduction::Sum}) {
for (double margin : {0., 0.2}) {
ForEachDevice([&](const Device& device) {
at::Tensor output =
at::margin_ranking_loss(input1, input2, target, margin, reduction);
at::Tensor xla_input1 = bridge::CreateXlaTensor(input1, device);
at::Tensor xla_input2 = bridge::CreateXlaTensor(input2, device);
at::Tensor xla_target = bridge::CreateXlaTensor(target, device);
at::Tensor xla_output = at::margin_ranking_loss(
xla_input1, xla_input2, xla_target, margin, reduction);
AllClose(output, xla_output);
});
}
}
}

TEST_F(AtenXlaTensorTest, TestBCEWithLogits) {
int batch = 10;
int classes = 5;
at::Tensor input = at::rand({batch, classes}, at::TensorOptions(at::kFloat));
at::Tensor target = at::rand({batch, classes}, at::TensorOptions(at::kFloat));
at::Tensor weight = at::rand({classes}, at::TensorOptions(at::kFloat));
at::Tensor pos_weight = at::rand({classes}, at::TensorOptions(at::kFloat));
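  // An undefined tensor stands in for weight/pos_weight to exercise the
  // optional-argument code paths.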
at::Tensor undef;
for (Reduction::Reduction reduction : {Reduction::Mean, Reduction::Sum}) {
for (bool undef_weight : {false, true}) {
for (bool undef_pos_weight : {false, true}) {
ForEachDevice([&](const Device& device) {
at::Tensor output = at::binary_cross_entropy_with_logits(
input, target, undef_weight ? undef : weight,
undef_pos_weight ? undef : pos_weight, reduction);
at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
at::Tensor xla_target = bridge::CreateXlaTensor(target, device);
at::Tensor xla_weight =
undef_weight ? undef : bridge::CreateXlaTensor(weight, device);
at::Tensor xla_pos_weight =
undef_pos_weight ? undef
: bridge::CreateXlaTensor(pos_weight, device);
at::Tensor xla_output = at::binary_cross_entropy_with_logits(
              xla_input, xla_target, xla_weight, xla_pos_weight, reduction);
          AllClose(output, xla_output);
        });
}
}
}
}

TEST_F(AtenXlaTensorTest, TestKlDiv) {
at::Tensor input = at::rand({4, 3}, at::TensorOptions(at::kFloat));
at::Tensor target = at::rand({4, 3}, at::TensorOptions(at::kFloat));
for (Reduction::Reduction reduction : {Reduction::Mean, Reduction::Sum}) {
ForEachDevice([&](const Device& device) {
at::Tensor output = at::kl_div(input, target, reduction);
at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
at::Tensor xla_target = bridge::CreateXlaTensor(target, device);
at::Tensor xla_output = at::kl_div(xla_input, xla_target, reduction);
AllClose(output, xla_output);
});
}
}

TEST_F(AtenXlaTensorTest, TestProd) {
at::Tensor a = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
at::Tensor b = at::prod(a);
@@ -5749,5 +5986,53 @@ TEST_F(AtenXlaTensorTest, TestBatchNorm2DBackward) {
}
}

TEST_F(AtenXlaTensorTest, TestBCEWithLogitsBackward) {
int batch = 10;
int classes = 5;
at::Tensor undef;
for (Reduction::Reduction reduction :
{Reduction::None, Reduction::Mean, Reduction::Sum}) {
auto testfn = [&](const std::vector<at::Tensor>& inputs) -> at::Tensor {
return at::binary_cross_entropy_with_logits(
/*input=*/inputs[0], /*target=*/inputs[1], /*weight=*/inputs[2],
/*pos_weight=*/inputs[3],
/*reduction=*/reduction);
};
for (bool undef_weight : {false, true}) {
for (bool undef_pos_weight : {false, true}) {
at::Tensor input =
at::rand({batch, classes}, at::TensorOptions(at::kFloat));
at::Tensor target =
at::rand({batch, classes}, at::TensorOptions(at::kFloat));
at::Tensor weight =
undef_weight ? undef
: at::rand({classes}, at::TensorOptions(at::kFloat));
at::Tensor pos_weight =
undef_pos_weight
? undef
: at::rand({classes}, at::TensorOptions(at::kFloat));
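        // Only input and target require gradients; weight and pos_weight are
        // treated as constants.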
ForEachDevice([&](const Device& device) {
TestBackward({input, target, weight, pos_weight}, device, testfn,
/*rtol=*/1e-5, /*atol=*/1e-7,
/*inputs_require_grad=*/{true, true, false, false});
});
}
}
}
}

TEST_F(AtenXlaTensorTest, TestKlDivBackward) {
at::Tensor input = at::rand({4, 3}, at::TensorOptions(at::kFloat));
at::Tensor target = at::rand({4, 3}, at::TensorOptions(at::kFloat));
for (Reduction::Reduction reduction : {Reduction::Mean, Reduction::Sum}) {
auto testfn = [&](const std::vector<at::Tensor>& inputs) -> at::Tensor {
return at::kl_div(/*self=*/inputs[0], /*target=*/inputs[1], reduction);
};
ForEachDevice([&](const Device& device) {
TestBackward({input, target}, device, testfn);
});
}
}

} // namespace cpp_test
} // namespace torch_xla