303 changes: 134 additions & 169 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -5357,130 +5357,6 @@ TEST_F(AtenXlaTensorTest, TestAdaptiveAvgPool2DNoBatch) {
}
}

TEST_F(AtenXlaTensorTest, TestConv2D) {
int in_channels = 3;
int out_channels = 7;
int kernel_size = 5;
torch::Tensor input = torch::rand({4, in_channels, 28, 28},
torch::TensorOptions(torch::kFloat));
torch::Tensor weight =
torch::rand({out_channels, in_channels, kernel_size, kernel_size},
torch::TensorOptions(torch::kFloat));
torch::Tensor bias =
torch::rand({out_channels}, torch::TensorOptions(torch::kFloat));
torch::Tensor bias_undef;
for (int stride = 1; stride <= 3; ++stride) {
for (int padding = 0; padding <= 2; ++padding) {
for (bool with_bias : {true, false}) {
// Test dilation through the CPU interop.
for (int dilation = 1; dilation <= 2; ++dilation) {
torch::Tensor output =
torch::conv2d(input, weight, with_bias ? bias : bias_undef,
/*stride=*/{stride, stride},
/*padding=*/{padding, padding},
/*dilation=*/{dilation, dilation});
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
torch::Tensor xla_weight = CopyToDevice(weight, device);
torch::Tensor xla_bias = CopyToDevice(bias, device);
torch::Tensor xla_output = torch::conv2d(
xla_input, xla_weight, with_bias ? xla_bias : bias_undef,
/*stride=*/{stride, stride},
/*padding=*/{padding, padding},
/*dilation=*/{dilation, dilation});
AllClose(output, xla_output);
});
}
};
}
}
}

TEST_F(AtenXlaTensorTest, TestTransposedConv2D) {
int in_channels = 3;
int out_channels = 7;
int kernel_size = 5;
torch::Tensor input = torch::rand({4, out_channels, 28, 28},
torch::TensorOptions(torch::kFloat));
torch::Tensor weight =
torch::rand({out_channels, in_channels, kernel_size, kernel_size},
torch::TensorOptions(torch::kFloat));
torch::Tensor bias =
torch::rand({in_channels}, torch::TensorOptions(torch::kFloat));
torch::Tensor bias_undef;
for (int stride = 1; stride <= 3; ++stride) {
for (int padding = 0; padding <= 2; ++padding) {
for (int dilation = 1; dilation <= 2; ++dilation) {
for (int output_padding = 0;
output_padding < std::min(stride, dilation); ++output_padding) {
for (bool with_bias : {true, false}) {
// Test dilation through the CPU interop.
torch::Tensor output = torch::conv_transpose2d(
input, weight, with_bias ? bias : bias_undef,
/*stride=*/{stride, stride},
/*padding=*/{padding, padding},
/*output_padding=*/output_padding,
/*groups=*/1,
/*dilation=*/{dilation, dilation});
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
torch::Tensor xla_weight = CopyToDevice(weight, device);
torch::Tensor xla_bias = CopyToDevice(bias, device);
torch::Tensor xla_output = torch::conv_transpose2d(
xla_input, xla_weight, with_bias ? xla_bias : bias_undef,
/*stride=*/{stride, stride},
/*padding=*/{padding, padding},
/*output_padding=*/output_padding,
/*groups=*/1,
/*dilation=*/{dilation, dilation});
AllClose(output, xla_output);
});
}
};
}
}
}
}

TEST_F(AtenXlaTensorTest, TestConv2DNonSquare) {
int in_channels = 3;
int out_channels = 7;
int kernel_size = 5;
torch::Tensor input = torch::rand({4, in_channels, 28, 28},
torch::TensorOptions(torch::kFloat));
torch::Tensor weight =
torch::rand({out_channels, in_channels, kernel_size, kernel_size},
torch::TensorOptions(torch::kFloat));
torch::Tensor bias =
torch::rand({out_channels}, torch::TensorOptions(torch::kFloat));
torch::Tensor bias_undef;
for (int stride = 1; stride <= 3; ++stride) {
for (int padding = 0; padding <= 2; ++padding) {
for (bool with_bias : {true, false}) {
// Test dilation through the CPU interop.
for (int dilation = 1; dilation <= 2; ++dilation) {
torch::Tensor output =
torch::conv2d(input, weight, with_bias ? bias : bias_undef,
/*stride=*/{stride, stride + 1},
/*padding=*/{padding, padding + 1},
/*dilation=*/{dilation, dilation});
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
torch::Tensor xla_weight = CopyToDevice(weight, device);
torch::Tensor xla_bias = CopyToDevice(bias, device);
torch::Tensor xla_output = torch::conv2d(
xla_input, xla_weight, with_bias ? xla_bias : bias_undef,
/*stride=*/{stride, stride + 1},
/*padding=*/{padding, padding + 1},
/*dilation=*/{dilation, dilation});
AllClose(output, xla_output);
});
}
}
}
}
}

TEST_F(AtenXlaTensorTest, TestNllLoss) {
int batch = 6;
int classes = 2;
@@ -6872,45 +6748,47 @@ TEST_F(AtenXlaTensorTest, TestAdaptiveAvgPool2DNoBatchBackward) {
}

TEST_F(AtenXlaTensorTest, TestConv2DBackward) {
int in_channels = 3;
int out_channels = 7;
int in_channels = 9;
int out_channels = 3;
int kernel_size = 5;
for (int stride = 1; stride <= 3; ++stride) {
for (int padding = 0; padding <= 2; ++padding) {
for (bool with_bias : {true, false}) {
// Test dilation through the CPU interop.
for (int dilation = 1; dilation <= 2; ++dilation) {
auto testfn =
[&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
return torch::conv2d(inputs[0], inputs[1], inputs[2],
/*stride=*/{stride, stride},
/*padding=*/{padding, padding},
/*dilation=*/{dilation, dilation});
};
for (int groups : {1, 3}) {
auto testfn =
[&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
return torch::conv2d(inputs[0], inputs[1], inputs[2],
/*stride=*/{stride, stride},
/*padding=*/{padding, padding},
/*dilation=*/{dilation, dilation}, groups);
};

ForEachDevice([&](const torch::Device& device) {
torch::Tensor bias =
with_bias ? torch::rand({out_channels},
torch::TensorOptions(torch::kFloat))
: torch::Tensor();
TestBackward(
{torch::rand(
{4, in_channels, 32, 32},
torch::TensorOptions(torch::kFloat).requires_grad(true)),
torch::rand(
{out_channels, in_channels, kernel_size, kernel_size},
torch::TensorOptions(torch::kFloat).requires_grad(true)),
bias},
device, testfn);
});
}
};
ForEachDevice([&](const torch::Device& device) {
torch::Tensor bias =
with_bias ? torch::rand({out_channels},
torch::TensorOptions(torch::kFloat))
: torch::Tensor();
TestBackward(
{torch::rand(
{4, in_channels, 16, 16},
torch::TensorOptions(torch::kFloat).requires_grad(true)),
torch::rand(
{out_channels, in_channels / groups, kernel_size,
kernel_size},
torch::TensorOptions(torch::kFloat).requires_grad(true)),
bias},
device, testfn);
});
}
};
}
}
}
}
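
The groups loop added here shrinks the weight's second dimension to in_channels / groups and forwards groups to torch::conv2d. A minimal standalone sketch of that shape relationship, outside the test harness (assumes only libtorch; the local names are illustrative, not part of the test):

#include <torch/torch.h>
#include <iostream>

int main() {
  // Grouped 2-D convolution: weight is {out_channels, in_channels / groups, kH, kW},
  // mirroring the shapes exercised by TestConv2DBackward above.
  int64_t in_channels = 9, out_channels = 3, kernel_size = 5, groups = 3;
  torch::Tensor input = torch::rand({4, in_channels, 16, 16});
  torch::Tensor weight = torch::rand(
      {out_channels, in_channels / groups, kernel_size, kernel_size});
  torch::Tensor bias = torch::rand({out_channels});
  torch::Tensor output = torch::conv2d(input, weight, bias,
                                       /*stride=*/{1, 1},
                                       /*padding=*/{2, 2},
                                       /*dilation=*/{1, 1}, groups);
  std::cout << output.sizes() << std::endl;  // [4, 3, 16, 16]
  return 0;
}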

TEST_F(AtenXlaTensorTest, TestTransposedConv2DBackward) {
int in_channels = 2;
int in_channels = 9;
int out_channels = 3;
int kernel_size = 5;
for (int stride = 1; stride <= 2; ++stride) {
@@ -6919,29 +6797,72 @@ TEST_F(AtenXlaTensorTest, TestTransposedConv2DBackward) {
for (int output_padding = 0;
output_padding < std::min(stride, dilation); ++output_padding) {
for (bool with_bias : {true, false}) {
// Test dilation through the CPU interop.
for (int groups : {1, 3}) {
auto testfn = [&](const std::vector<torch::Tensor>& inputs)
-> torch::Tensor {
return torch::conv_transpose2d(
inputs[0], inputs[1], inputs[2],
/*stride=*/{stride, stride + 1},
/*padding=*/{padding, padding + 1},
/*output_padding=*/output_padding,
/*groups=*/groups,
/*dilation=*/{dilation, dilation + 1});
};
ForEachDevice([&](const torch::Device& device) {
torch::Tensor input = torch::rand(
{4, out_channels, 14, 14},
torch::TensorOptions(torch::kFloat).requires_grad(true));
torch::Tensor weight = torch::rand(
{out_channels, in_channels / groups, kernel_size,
kernel_size},
torch::TensorOptions(torch::kFloat).requires_grad(true));
torch::Tensor bias =
with_bias ? torch::rand({in_channels},
torch::TensorOptions(torch::kFloat)
.requires_grad(true))
: torch::Tensor();
TestBackward({input, weight, bias}, device, testfn);
});
}
};
}
}
}
}
}
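
In the transposed-convolution tests the weight layout is {C_in, C_out / groups, kH, kW} and the bias length is C_out, which is why the test feeds an input with out_channels channels and a bias with in_channels entries. A short standalone sketch of that convention (assumes only libtorch; the c_in/c_out names are illustrative):

#include <torch/torch.h>
#include <iostream>

int main() {
  // Transposed convolution: weight is {c_in, c_out / groups, kH, kW} and the
  // bias has c_out entries, so the channel roles are swapped relative to conv2d.
  int64_t c_in = 3, c_out = 9, kernel_size = 5, groups = 3;
  torch::Tensor input = torch::rand({4, c_in, 14, 14});
  torch::Tensor weight = torch::rand(
      {c_in, c_out / groups, kernel_size, kernel_size});
  torch::Tensor bias = torch::rand({c_out});
  torch::Tensor output = torch::conv_transpose2d(input, weight, bias,
                                                 /*stride=*/{1, 2},
                                                 /*padding=*/{0, 1},
                                                 /*output_padding=*/0,
                                                 /*groups=*/groups,
                                                 /*dilation=*/{1, 2});
  std::cout << output.sizes() << std::endl;  // [4, 9, 18, 33]
  return 0;
}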

TEST_F(AtenXlaTensorTest, TestConv3DBackward) {
int in_channels = 9;
int out_channels = 3;
int kernel_size = 5;
for (int stride = 1; stride <= 3; ++stride) {
for (int padding = 1; padding <= 2; ++padding) {
for (bool with_bias : {true, false}) {
for (int dilation = 1; dilation <= 2; ++dilation) {
for (int groups : {1, 3}) {
auto testfn =
[&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
return torch::conv_transpose2d(inputs[0], inputs[1], inputs[2],
/*stride=*/{stride, stride},
/*padding=*/{padding, padding},
/*output_padding=*/output_padding,
/*groups=*/1,
/*dilation=*/{dilation, dilation});
return torch::conv3d(inputs[0], inputs[1], inputs[2],
/*stride=*/{stride, stride, stride},
/*padding=*/{padding, padding, padding},
/*dilation=*/{dilation, dilation, dilation},
groups);
};

ForEachDevice([&](const torch::Device& device) {
torch::Tensor input = torch::rand(
{4, out_channels, 14, 14},
torch::TensorOptions(torch::kFloat).requires_grad(true));
torch::Tensor weight = torch::rand(
{out_channels, in_channels, kernel_size, kernel_size},
torch::TensorOptions(torch::kFloat).requires_grad(true));
torch::Tensor bias =
with_bias ? torch::rand({in_channels},
torch::TensorOptions(torch::kFloat)
.requires_grad(true))
with_bias ? torch::rand({out_channels},
torch::TensorOptions(torch::kDouble))
: torch::Tensor();
TestBackward({input, weight, bias}, device, testfn);
TestBackward({torch::rand({4, in_channels, 14, 14, 14},
torch::TensorOptions(torch::kDouble)
.requires_grad(true)),
torch::rand({out_channels, in_channels / groups,
kernel_size, kernel_size, kernel_size},
torch::TensorOptions(torch::kDouble)
.requires_grad(true)),
bias},
device, testfn);
});
}
};
@@ -6950,6 +6871,50 @@ TEST_F(AtenXlaTensorTest, TestTransposedConv2DBackward) {
}
}

TEST_F(AtenXlaTensorTest, TestTransposedConv3DBackward) {
int in_channels = 9;
int out_channels = 3;
int kernel_size = 5;
for (int stride = 1; stride <= 2; ++stride) {
for (int padding = 0; padding <= 1; ++padding) {
for (int dilation = 1; dilation <= 2; ++dilation) {
for (int output_padding = 0;
output_padding < std::min(stride, dilation); ++output_padding) {
for (bool with_bias : {true, false}) {
for (int groups : {1, 3}) {
auto testfn = [&](const std::vector<torch::Tensor>& inputs)
-> torch::Tensor {
return torch::conv_transpose3d(
inputs[0], inputs[1], inputs[2],
/*stride=*/{stride, stride + 1, stride},
/*padding=*/{padding, padding + 1, stride},
/*output_padding=*/output_padding,
/*groups=*/groups,
/*dilation=*/{dilation, dilation + 1, dilation});
};
ForEachDevice([&](const torch::Device& device) {
torch::Tensor input = torch::rand(
{4, out_channels, 14, 14, 14},
torch::TensorOptions(torch::kDouble).requires_grad(true));
torch::Tensor weight = torch::rand(
{out_channels, in_channels / groups, kernel_size,
kernel_size, kernel_size},
torch::TensorOptions(torch::kDouble).requires_grad(true));
torch::Tensor bias =
with_bias ? torch::rand({in_channels},
torch::TensorOptions(torch::kDouble)
.requires_grad(true))
: torch::Tensor();
TestBackward({input, weight, bias}, device, testfn);
});
}
};
}
}
}
}
}

TEST_F(AtenXlaTensorTest, TestMaxPool2DBackward) {
int kernel_size = 3;
for (int stride = 1; stride <= 2; ++stride) {