diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp index 5d0ea43c02a9..e7fd317066d9 100644 --- a/test/cpp/test_aten_xla_tensor.cpp +++ b/test/cpp/test_aten_xla_tensor.cpp @@ -5357,130 +5357,6 @@ TEST_F(AtenXlaTensorTest, TestAdaptiveAvgPool2DNoBatch) { } } -TEST_F(AtenXlaTensorTest, TestConv2D) { - int in_channels = 3; - int out_channels = 7; - int kernel_size = 5; - torch::Tensor input = torch::rand({4, in_channels, 28, 28}, - torch::TensorOptions(torch::kFloat)); - torch::Tensor weight = - torch::rand({out_channels, in_channels, kernel_size, kernel_size}, - torch::TensorOptions(torch::kFloat)); - torch::Tensor bias = - torch::rand({out_channels}, torch::TensorOptions(torch::kFloat)); - torch::Tensor bias_undef; - for (int stride = 1; stride <= 3; ++stride) { - for (int padding = 0; padding <= 2; ++padding) { - for (bool with_bias : {true, false}) { - // Test dilation through the CPU interop. - for (int dilation = 1; dilation <= 2; ++dilation) { - torch::Tensor output = - torch::conv2d(input, weight, with_bias ? bias : bias_undef, - /*stride=*/{stride, stride}, - /*padding=*/{padding, padding}, - /*dilation=*/{dilation, dilation}); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_input = CopyToDevice(input, device); - torch::Tensor xla_weight = CopyToDevice(weight, device); - torch::Tensor xla_bias = CopyToDevice(bias, device); - torch::Tensor xla_output = torch::conv2d( - xla_input, xla_weight, with_bias ? xla_bias : bias_undef, - /*stride=*/{stride, stride}, - /*padding=*/{padding, padding}, - /*dilation=*/{dilation, dilation}); - AllClose(output, xla_output); - }); - } - }; - } - } -} - -TEST_F(AtenXlaTensorTest, TestTransposedConv2D) { - int in_channels = 3; - int out_channels = 7; - int kernel_size = 5; - torch::Tensor input = torch::rand({4, out_channels, 28, 28}, - torch::TensorOptions(torch::kFloat)); - torch::Tensor weight = - torch::rand({out_channels, in_channels, kernel_size, kernel_size}, - torch::TensorOptions(torch::kFloat)); - torch::Tensor bias = - torch::rand({in_channels}, torch::TensorOptions(torch::kFloat)); - torch::Tensor bias_undef; - for (int stride = 1; stride <= 3; ++stride) { - for (int padding = 0; padding <= 2; ++padding) { - for (int dilation = 1; dilation <= 2; ++dilation) { - for (int output_padding = 0; - output_padding < std::min(stride, dilation); ++output_padding) { - for (bool with_bias : {true, false}) { - // Test dilation through the CPU interop. - torch::Tensor output = torch::conv_transpose2d( - input, weight, with_bias ? bias : bias_undef, - /*stride=*/{stride, stride}, - /*padding=*/{padding, padding}, - /*output_padding=*/output_padding, - /*groups=*/1, - /*dilation=*/{dilation, dilation}); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_input = CopyToDevice(input, device); - torch::Tensor xla_weight = CopyToDevice(weight, device); - torch::Tensor xla_bias = CopyToDevice(bias, device); - torch::Tensor xla_output = torch::conv_transpose2d( - xla_input, xla_weight, with_bias ? 
xla_bias : bias_undef, - /*stride=*/{stride, stride}, - /*padding=*/{padding, padding}, - /*output_padding=*/output_padding, - /*groups=*/1, - /*dilation=*/{dilation, dilation}); - AllClose(output, xla_output); - }); - } - }; - } - } - } -} - -TEST_F(AtenXlaTensorTest, TestConv2DNonSquare) { - int in_channels = 3; - int out_channels = 7; - int kernel_size = 5; - torch::Tensor input = torch::rand({4, in_channels, 28, 28}, - torch::TensorOptions(torch::kFloat)); - torch::Tensor weight = - torch::rand({out_channels, in_channels, kernel_size, kernel_size}, - torch::TensorOptions(torch::kFloat)); - torch::Tensor bias = - torch::rand({out_channels}, torch::TensorOptions(torch::kFloat)); - torch::Tensor bias_undef; - for (int stride = 1; stride <= 3; ++stride) { - for (int padding = 0; padding <= 2; ++padding) { - for (bool with_bias : {true, false}) { - // Test dilation through the CPU interop. - for (int dilation = 1; dilation <= 2; ++dilation) { - torch::Tensor output = - torch::conv2d(input, weight, with_bias ? bias : bias_undef, - /*stride=*/{stride, stride + 1}, - /*padding=*/{padding, padding + 1}, - /*dilation=*/{dilation, dilation}); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_input = CopyToDevice(input, device); - torch::Tensor xla_weight = CopyToDevice(weight, device); - torch::Tensor xla_bias = CopyToDevice(bias, device); - torch::Tensor xla_output = torch::conv2d( - xla_input, xla_weight, with_bias ? xla_bias : bias_undef, - /*stride=*/{stride, stride + 1}, - /*padding=*/{padding, padding + 1}, - /*dilation=*/{dilation, dilation}); - AllClose(output, xla_output); - }); - } - } - } - } -} - TEST_F(AtenXlaTensorTest, TestNllLoss) { int batch = 6; int classes = 2; @@ -6872,45 +6748,47 @@ TEST_F(AtenXlaTensorTest, TestAdaptiveAvgPool2DNoBatchBackward) { } TEST_F(AtenXlaTensorTest, TestConv2DBackward) { - int in_channels = 3; - int out_channels = 7; + int in_channels = 9; + int out_channels = 3; int kernel_size = 5; for (int stride = 1; stride <= 3; ++stride) { for (int padding = 0; padding <= 2; ++padding) { for (bool with_bias : {true, false}) { - // Test dilation through the CPU interop. for (int dilation = 1; dilation <= 2; ++dilation) { - auto testfn = - [&](const std::vector& inputs) -> torch::Tensor { - return torch::conv2d(inputs[0], inputs[1], inputs[2], - /*stride=*/{stride, stride}, - /*padding=*/{padding, padding}, - /*dilation=*/{dilation, dilation}); - }; + for (int groups : {1, 3}) { + auto testfn = + [&](const std::vector& inputs) -> torch::Tensor { + return torch::conv2d(inputs[0], inputs[1], inputs[2], + /*stride=*/{stride, stride}, + /*padding=*/{padding, padding}, + /*dilation=*/{dilation, dilation}, groups); + }; - ForEachDevice([&](const torch::Device& device) { - torch::Tensor bias = - with_bias ? torch::rand({out_channels}, - torch::TensorOptions(torch::kFloat)) - : torch::Tensor(); - TestBackward( - {torch::rand( - {4, in_channels, 32, 32}, - torch::TensorOptions(torch::kFloat).requires_grad(true)), - torch::rand( - {out_channels, in_channels, kernel_size, kernel_size}, - torch::TensorOptions(torch::kFloat).requires_grad(true)), - bias}, - device, testfn); - }); - } - }; + ForEachDevice([&](const torch::Device& device) { + torch::Tensor bias = + with_bias ? 
torch::rand({out_channels}, + torch::TensorOptions(torch::kFloat)) + : torch::Tensor(); + TestBackward( + {torch::rand( + {4, in_channels, 16, 16}, + torch::TensorOptions(torch::kFloat).requires_grad(true)), + torch::rand( + {out_channels, in_channels / groups, kernel_size, + kernel_size}, + torch::TensorOptions(torch::kFloat).requires_grad(true)), + bias}, + device, testfn); + }); + } + }; + } } } } TEST_F(AtenXlaTensorTest, TestTransposedConv2DBackward) { - int in_channels = 2; + int in_channels = 9; int out_channels = 3; int kernel_size = 5; for (int stride = 1; stride <= 2; ++stride) { @@ -6919,29 +6797,72 @@ TEST_F(AtenXlaTensorTest, TestTransposedConv2DBackward) { for (int output_padding = 0; output_padding < std::min(stride, dilation); ++output_padding) { for (bool with_bias : {true, false}) { - // Test dilation through the CPU interop. + for (int groups : {1, 3}) { + auto testfn = [&](const std::vector& inputs) + -> torch::Tensor { + return torch::conv_transpose2d( + inputs[0], inputs[1], inputs[2], + /*stride=*/{stride, stride + 1}, + /*padding=*/{padding, padding + 1}, + /*output_padding=*/output_padding, + /*groups=*/groups, + /*dilation=*/{dilation, dilation + 1}); + }; + ForEachDevice([&](const torch::Device& device) { + torch::Tensor input = torch::rand( + {4, out_channels, 14, 14}, + torch::TensorOptions(torch::kFloat).requires_grad(true)); + torch::Tensor weight = torch::rand( + {out_channels, in_channels / groups, kernel_size, + kernel_size}, + torch::TensorOptions(torch::kFloat).requires_grad(true)); + torch::Tensor bias = + with_bias ? torch::rand({in_channels}, + torch::TensorOptions(torch::kFloat) + .requires_grad(true)) + : torch::Tensor(); + TestBackward({input, weight, bias}, device, testfn); + }); + } + }; + } + } + } + } +} + +TEST_F(AtenXlaTensorTest, TestConv3DBackward) { + int in_channels = 9; + int out_channels = 3; + int kernel_size = 5; + for (int stride = 1; stride <= 3; ++stride) { + for (int padding = 1; padding <= 2; ++padding) { + for (bool with_bias : {true, false}) { + for (int dilation = 1; dilation <= 2; ++dilation) { + for (int groups : {1, 3}) { auto testfn = [&](const std::vector& inputs) -> torch::Tensor { - return torch::conv_transpose2d(inputs[0], inputs[1], inputs[2], - /*stride=*/{stride, stride}, - /*padding=*/{padding, padding}, - /*output_padding=*/output_padding, - /*groups=*/1, - /*dilation=*/{dilation, dilation}); + return torch::conv3d(inputs[0], inputs[1], inputs[2], + /*stride=*/{stride, stride, stride}, + /*padding=*/{padding, padding, padding}, + /*dilation=*/{dilation, dilation, dilation}, + groups); }; + ForEachDevice([&](const torch::Device& device) { - torch::Tensor input = torch::rand( - {4, out_channels, 14, 14}, - torch::TensorOptions(torch::kFloat).requires_grad(true)); - torch::Tensor weight = torch::rand( - {out_channels, in_channels, kernel_size, kernel_size}, - torch::TensorOptions(torch::kFloat).requires_grad(true)); torch::Tensor bias = - with_bias ? torch::rand({in_channels}, - torch::TensorOptions(torch::kFloat) - .requires_grad(true)) + with_bias ? 
torch::rand({out_channels}, + torch::TensorOptions(torch::kDouble)) : torch::Tensor(); - TestBackward({input, weight, bias}, device, testfn); + TestBackward({torch::rand({4, in_channels, 14, 14, 14}, + torch::TensorOptions(torch::kDouble) + .requires_grad(true)), + torch::rand({out_channels, in_channels / groups, + kernel_size, kernel_size, kernel_size}, + torch::TensorOptions(torch::kDouble) + .requires_grad(true)), + bias}, + device, testfn); }); } }; @@ -6950,6 +6871,50 @@ TEST_F(AtenXlaTensorTest, TestTransposedConv2DBackward) { } } +TEST_F(AtenXlaTensorTest, TestTransposedConv3DBackward) { + int in_channels = 9; + int out_channels = 3; + int kernel_size = 5; + for (int stride = 1; stride <= 2; ++stride) { + for (int padding = 0; padding <= 1; ++padding) { + for (int dilation = 1; dilation <= 2; ++dilation) { + for (int output_padding = 0; + output_padding < std::min(stride, dilation); ++output_padding) { + for (bool with_bias : {true, false}) { + for (int groups : {1, 3}) { + auto testfn = [&](const std::vector& inputs) + -> torch::Tensor { + return torch::conv_transpose3d( + inputs[0], inputs[1], inputs[2], + /*stride=*/{stride, stride + 1, stride}, + /*padding=*/{padding, padding + 1, stride}, + /*output_padding=*/output_padding, + /*groups=*/groups, + /*dilation=*/{dilation, dilation + 1, dilation}); + }; + ForEachDevice([&](const torch::Device& device) { + torch::Tensor input = torch::rand( + {4, out_channels, 14, 14, 14}, + torch::TensorOptions(torch::kDouble).requires_grad(true)); + torch::Tensor weight = torch::rand( + {out_channels, in_channels / groups, kernel_size, + kernel_size, kernel_size}, + torch::TensorOptions(torch::kDouble).requires_grad(true)); + torch::Tensor bias = + with_bias ? torch::rand({in_channels}, + torch::TensorOptions(torch::kDouble) + .requires_grad(true)) + : torch::Tensor(); + TestBackward({input, weight, bias}, device, testfn); + }); + } + }; + } + } + } + } +} + TEST_F(AtenXlaTensorTest, TestMaxPool2DBackward) { int kernel_size = 3; for (int stride = 1; stride <= 2; ++stride) { diff --git a/test/cpp/test_tensor.cpp b/test/cpp/test_tensor.cpp index 0c6e7bf3ae40..a4a65886b275 100644 --- a/test/cpp/test_tensor.cpp +++ b/test/cpp/test_tensor.cpp @@ -421,78 +421,279 @@ TEST_F(TensorTest, TestBatchNorm1D) { } TEST_F(TensorTest, TestConv2D) { - int in_channels = 3; - int out_channels = 7; + int in_channels = 9; + int out_channels = 3; int kernel_size = 5; at::Tensor input = - at::rand({4, in_channels, 28, 28}, at::TensorOptions(at::kFloat)); - at::Tensor weight = - at::rand({out_channels, in_channels, kernel_size, kernel_size}, - at::TensorOptions(at::kFloat)); + at::rand({4, in_channels, 32, 32}, at::TensorOptions(at::kFloat)); at::Tensor bias = at::rand({out_channels}, at::TensorOptions(at::kFloat)); at::Tensor no_bias; for (int stride = 1; stride <= 3; ++stride) { for (int padding = 0; padding <= 2; ++padding) { for (bool with_bias : {true, false}) { - at::Tensor output = - at::native::conv2d(input, weight, with_bias ? 
bias : no_bias, - /*stride=*/{stride, stride}, - /*padding=*/{padding, padding}); - ForEachDevice([&](const Device& device) { - XLATensor dev_input = XLATensor::Create(input, device); - XLATensor dev_weight = XLATensor::Create(weight, device); - XLATensor dev_output; - if (with_bias) { - XLATensor dev_bias = XLATensor::Create(bias, device); - dev_output = XLATensor::conv2d(dev_input, dev_weight, dev_bias, - /*stride=*/{stride, stride}, - /*padding=*/{padding, padding}); - } else { - dev_output = XLATensor::conv2d(dev_input, dev_weight, - /*stride=*/{stride, stride}, - /*padding=*/{padding, padding}); + for (int dilation = 1; dilation <= 2; ++dilation) { + for (int groups : {1, 3}) { + for (bool transposed : {true, false}) { + for (int output_padding = 0; + output_padding < std::min(stride, dilation); + ++output_padding) { + at::Tensor weight = + transposed ? at::rand({in_channels, out_channels / groups, + kernel_size, kernel_size}) + : at::rand({out_channels, in_channels / groups, + kernel_size, kernel_size}, + at::TensorOptions(at::kFloat)); + + at::Tensor output = at::native::_convolution( + input, weight, with_bias ? bias : no_bias, + /*stride=*/{stride, stride}, + /*padding=*/{padding, padding}, + /*dilation=*/{dilation, dilation}, + /*transposed=*/transposed, + /*output_padding=*/{output_padding, output_padding}, + /*groups=*/groups, false, false, false); + ForEachDevice([&](const Device& device) { + XLATensor dev_input = XLATensor::Create(input, device); + XLATensor dev_weight = XLATensor::Create(weight, device); + XLATensor dev_output; + if (with_bias) { + XLATensor dev_bias = XLATensor::Create(bias, device); + dev_output = XLATensor::convolution_overrideable( + dev_input, dev_weight, dev_bias, + /*stride=*/{stride, stride}, + /*padding=*/{padding, padding}, + /*dilation=*/{dilation, dilation}, + /*transposed=*/transposed, + /*output_padding=*/{output_padding, output_padding}, + /*groups=*/groups); + } else { + dev_output = XLATensor::convolution_overrideable( + dev_input, dev_weight, + /*stride=*/{stride, stride}, + /*padding=*/{padding, padding}, + /*dilation=*/{dilation, dilation}, + /*transposed=*/transposed, + /*output_padding=*/{output_padding, output_padding}, + /*groups=*/groups); + } + AllClose(output, dev_output); + }); + }; + } } - AllClose(output, dev_output); - }); - }; + } + } } } } TEST_F(TensorTest, TestConv2DNonSquare) { int in_channels = 3; - int out_channels = 7; + int out_channels = 6; int kernel_size = 5; at::Tensor input = - at::rand({4, in_channels, 28, 28}, at::TensorOptions(at::kFloat)); - at::Tensor weight = - at::rand({out_channels, in_channels, kernel_size, kernel_size}, - at::TensorOptions(at::kFloat)); + at::rand({4, in_channels, 26, 26}, at::TensorOptions(at::kFloat)); + at::Tensor bias = at::rand({out_channels}, at::TensorOptions(at::kFloat)); + at::Tensor no_bias; + + for (int stride = 1; stride <= 3; ++stride) { + for (int padding = 0; padding <= 0; ++padding) { + for (bool with_bias : {true, false}) { + for (int dilation = 1; dilation <= 2; ++dilation) { + for (int groups : {1, 3}) { + for (bool transposed : {true, false}) { + for (int output_padding = 0; + output_padding < std::min(stride, dilation); + ++output_padding) { + at::Tensor weight = + transposed ? at::rand({in_channels, out_channels / groups, + kernel_size, kernel_size}) + : at::rand({out_channels, in_channels / groups, + kernel_size, kernel_size}, + at::TensorOptions(at::kFloat)); + + at::Tensor output = at::native::_convolution( + input, weight, with_bias ? 
bias : no_bias, + /*stride=*/{stride, stride + 1}, + /*padding=*/{padding, padding + 1}, + /*dilation=*/{dilation, dilation + 1}, + /*transposed=*/transposed, + /*output_padding=*/{output_padding, output_padding + 1}, + /*groups=*/groups, false, false, false); + + ForEachDevice([&](const Device& device) { + XLATensor dev_input = XLATensor::Create(input, device); + XLATensor dev_weight = XLATensor::Create(weight, device); + XLATensor dev_output; + if (with_bias) { + XLATensor dev_bias = XLATensor::Create(bias, device); + dev_output = XLATensor::convolution_overrideable( + dev_input, dev_weight, dev_bias, + /*stride=*/{stride, stride + 1}, + /*padding=*/{padding, padding + 1}, + /*dilation=*/{dilation, dilation + 1}, + /*transposed=*/transposed, + /*output_padding=*/{output_padding, output_padding + 1}, + /*groups=*/groups); + + } else { + dev_output = XLATensor::convolution_overrideable( + dev_input, dev_weight, + /*stride=*/{stride, stride + 1}, + /*padding=*/{padding, padding + 1}, + /*dilation=*/{dilation, dilation + 1}, + /*transposed=*/transposed, + /*output_padding=*/{output_padding, output_padding + 1}, + /*groups=*/groups); + } + AllClose(output, dev_output); + }); + } + } + } + } + } + } + } +} + +TEST_F(TensorTest, TestConv3D) { + int in_channels = 9; + int out_channels = 3; + int kernel_size = 5; + at::Tensor input = + at::rand({4, in_channels, 28, 28, 28}, at::TensorOptions(at::kFloat)); at::Tensor bias = at::rand({out_channels}, at::TensorOptions(at::kFloat)); at::Tensor no_bias; for (int stride = 1; stride <= 3; ++stride) { for (int padding = 0; padding <= 2; ++padding) { for (bool with_bias : {true, false}) { - at::Tensor output = - at::native::conv2d(input, weight, with_bias ? bias : no_bias, - /*stride=*/{stride, stride + 1}, - /*padding=*/{padding, padding + 1}); - ForEachDevice([&](const Device& device) { - XLATensor dev_input = XLATensor::Create(input, device); - XLATensor dev_weight = XLATensor::Create(weight, device); - XLATensor dev_output; - if (with_bias) { - XLATensor dev_bias = XLATensor::Create(bias, device); - dev_output = XLATensor::conv2d(dev_input, dev_weight, dev_bias, - /*stride=*/{stride, stride + 1}, - /*padding=*/{padding, padding + 1}); - } else { - dev_output = XLATensor::conv2d(dev_input, dev_weight, - /*stride=*/{stride, stride + 1}, - /*padding=*/{padding, padding + 1}); + for (int dilation = 1; dilation <= 1; ++dilation) { + for (int groups : {1, 3}) { + for (bool transposed : {true, false}) { + for (int output_padding = 0; + output_padding < std::min(stride, dilation); + ++output_padding) { + at::Tensor weight = + transposed + ? at::rand({in_channels, out_channels / groups, + kernel_size, kernel_size, kernel_size}) + : at::rand({out_channels, in_channels / groups, + kernel_size, kernel_size, kernel_size}, + at::TensorOptions(at::kFloat)); + + at::Tensor output = at::native::_convolution( + input, weight, with_bias ? 
bias : no_bias, + /*stride=*/{stride, stride, stride}, + /*padding=*/{padding, padding, padding}, + /*dilation=*/{dilation, dilation, dilation}, + /*transposed=*/transposed, + /*output_padding=*/ + {output_padding, output_padding, output_padding}, + /*groups=*/groups, false, false, false); + ForEachDevice([&](const Device& device) { + XLATensor dev_input = XLATensor::Create(input, device); + XLATensor dev_weight = XLATensor::Create(weight, device); + XLATensor dev_output; + if (with_bias) { + XLATensor dev_bias = XLATensor::Create(bias, device); + dev_output = XLATensor::convolution_overrideable( + dev_input, dev_weight, dev_bias, + /*stride=*/{stride, stride, stride}, + /*padding=*/{padding, padding, padding}, + /*dilation=*/{dilation, dilation, dilation}, + /*transposed=*/transposed, + /*output_padding=*/ + {output_padding, output_padding, output_padding}, + /*groups=*/groups); + } else { + dev_output = XLATensor::convolution_overrideable( + dev_input, dev_weight, + /*stride=*/{stride, stride, stride}, + /*padding=*/{padding, padding, padding}, + /*dilation=*/{dilation, dilation, dilation}, + /*transposed=*/transposed, + /*output_padding=*/ + {output_padding, output_padding, output_padding}, + /*groups=*/groups); + } + AllClose(output, dev_output); + }); + }; + } } - AllClose(output, dev_output); - }); + } + } + } + } +} + +TEST_F(TensorTest, TestConv3DNonSquare) { + int in_channels = 9; + int out_channels = 3; + int kernel_size = 5; + at::Tensor input = + at::rand({4, in_channels, 28, 28, 28}, at::TensorOptions(at::kFloat)); + at::Tensor bias = at::rand({out_channels}, at::TensorOptions(at::kFloat)); + at::Tensor no_bias; + for (int stride = 1; stride <= 3; ++stride) { + for (int padding = 0; padding <= 2; ++padding) { + for (bool with_bias : {true, false}) { + for (int dilation = 1; dilation <= 1; ++dilation) { + for (int groups : {1, 3}) { + for (bool transposed : {true, false}) { + for (int output_padding = 0; + output_padding < std::min(stride, dilation); + ++output_padding) { + at::Tensor weight = + transposed + ? at::rand({in_channels, out_channels / groups, + kernel_size, kernel_size, kernel_size}) + : at::rand({out_channels, in_channels / groups, + kernel_size, kernel_size, kernel_size}, + at::TensorOptions(at::kFloat)); + + at::Tensor output = at::native::_convolution( + input, weight, with_bias ? 
bias : no_bias, + /*stride=*/{stride, stride + 1, stride + 1}, + /*padding=*/{padding, padding + 1, padding + 1}, + /*dilation=*/{dilation, dilation + 1, dilation + 1}, + /*transposed=*/transposed, + /*output_padding=*/ + {output_padding, output_padding + 1, output_padding}, + /*groups=*/groups, false, false, false); + ForEachDevice([&](const Device& device) { + XLATensor dev_input = XLATensor::Create(input, device); + XLATensor dev_weight = XLATensor::Create(weight, device); + XLATensor dev_output; + if (with_bias) { + XLATensor dev_bias = XLATensor::Create(bias, device); + dev_output = XLATensor::convolution_overrideable( + dev_input, dev_weight, dev_bias, + /*stride=*/{stride, stride + 1, stride + 1}, + /*padding=*/{padding, padding + 1, padding + 1}, + /*dilation=*/{dilation, dilation + 1, dilation + 1}, + /*transposed=*/transposed, + /*output_padding=*/ + {output_padding, output_padding + 1, output_padding}, + /*groups=*/groups); + } else { + dev_output = XLATensor::convolution_overrideable( + dev_input, dev_weight, + /*stride=*/{stride, stride + 1, stride + 1}, + /*padding=*/{padding, padding + 1, padding + 1}, + /*dilation=*/{dilation, dilation + 1, dilation + 1}, + /*transposed=*/transposed, + /*output_padding=*/ + {output_padding, output_padding + 1, output_padding}, + /*groups=*/groups); + } + AllClose(output, dev_output); + }); + }; + } + } + } } } } diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index d66f51aa056c..41bd5304b4b9 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -824,51 +824,45 @@ at::Tensor AtenXlaType::contiguous(const at::Tensor& self, return self; } -at::Tensor AtenXlaType::conv2d(const at::Tensor& input, - const at::Tensor& weight, const at::Tensor& bias, - at::IntArrayRef stride, at::IntArrayRef padding, - at::IntArrayRef dilation, int64_t groups) { - // Dilated or grouped convolutions aren't lowered to XLA yet. - if (IsNonTrivialDilation(dilation) || groups != 1) { - return AtenXlaTypeDefault::conv2d(input, weight, bias, stride, padding, - dilation, groups); - } +// This function covers the whole convolution lowering. 
+at::Tensor AtenXlaType::convolution_overrideable( + const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias, + at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, + bool transposed, at::IntArrayRef output_padding, int64_t groups) { if (bias.defined()) { - return bridge::AtenFromXlaTensor(XLATensor::conv2d( + return bridge::AtenFromXlaTensor(XLATensor::convolution_overrideable( bridge::GetXlaTensor(input), bridge::GetXlaTensor(weight), bridge::GetXlaTensor(bias), XlaHelpers::I64List(stride), - XlaHelpers::I64List(padding))); + XlaHelpers::I64List(padding), XlaHelpers::I64List(dilation), transposed, + XlaHelpers::I64List(output_padding), groups)); } else { - return bridge::AtenFromXlaTensor(XLATensor::conv2d( + return bridge::AtenFromXlaTensor(XLATensor::convolution_overrideable( bridge::GetXlaTensor(input), bridge::GetXlaTensor(weight), - XlaHelpers::I64List(stride), XlaHelpers::I64List(padding))); + XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), + XlaHelpers::I64List(dilation), transposed, + XlaHelpers::I64List(output_padding), groups)); } } -at::Tensor AtenXlaType::conv_transpose2d( - const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias, - at::IntArrayRef stride, at::IntArrayRef padding, - at::IntArrayRef output_padding, int64_t groups, at::IntArrayRef dilation) { - // Dilated or grouped transposed convolutions aren't lowered to XLA yet. - if (IsNonTrivialPadding(output_padding) || IsNonTrivialDilation(dilation) || - groups != 1) { - return AtenXlaTypeDefault::conv_transpose2d( - input, weight, bias, stride, padding, output_padding, groups, dilation); - } - if (bias.defined()) { - return bridge::AtenFromXlaTensor(XLATensor::conv_transpose2d( - /*input=*/bridge::GetXlaTensor(input), - /*weight=*/bridge::GetXlaTensor(weight), - /*bias=*/bridge::GetXlaTensor(bias), - /*stride=*/XlaHelpers::I64List(stride), - /*padding=*/XlaHelpers::I64List(padding))); - } else { - return bridge::AtenFromXlaTensor(XLATensor::slow_conv_transpose2d( - /*input=*/bridge::GetXlaTensor(input), - /*weight=*/bridge::GetXlaTensor(weight), - /*stride=*/XlaHelpers::I64List(stride), - /*padding=*/XlaHelpers::I64List(padding))); - } +// This function covers the whole convolution backward lowering. +std::tuple<at::Tensor, at::Tensor, at::Tensor> +AtenXlaType::convolution_backward_overrideable( + const at::Tensor& grad_output, const at::Tensor& input, + const at::Tensor& weight, at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, + int64_t groups, std::array<bool, 3> output_mask) { + auto gradients = XLATensor::convolution_backward_overrideable( + bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(input), + bridge::GetXlaTensor(weight), XlaHelpers::I64List(stride), + XlaHelpers::I64List(padding), XlaHelpers::I64List(dilation), transposed, + XlaHelpers::I64List(output_padding), groups); + return std::make_tuple( + output_mask[0] ? bridge::AtenFromXlaTensor(std::get<0>(gradients)) + : at::Tensor(), + output_mask[1] ? bridge::AtenFromXlaTensor(std::get<1>(gradients)) + : at::Tensor(), + output_mask[2] ? 
bridge::AtenFromXlaTensor(std::get<2>(gradients)) + : at::Tensor()); } at::Tensor& AtenXlaType::copy_(at::Tensor& self, const at::Tensor& src, @@ -2423,34 +2417,6 @@ at::Tensor AtenXlaType::slice(const at::Tensor& self, int64_t dim, XLATensor::slice(bridge::GetXlaTensor(self), dim, start, end, step)); } -std::tuple -AtenXlaType::slow_conv_transpose2d_backward( - const at::Tensor& grad_output, const at::Tensor& self, - const at::Tensor& weight, at::IntArrayRef kernel_size, - at::IntArrayRef stride, at::IntArrayRef padding, - at::IntArrayRef output_padding, at::IntArrayRef dilation, - const at::Tensor& columns, const at::Tensor& ones, - std::array output_mask) { - // Dilated or grouped transposed convolutions aren't lowered to XLA yet. - if (IsNonTrivialPadding(output_padding) || IsNonTrivialDilation(dilation)) { - return AtenXlaTypeDefault::slow_conv_transpose2d_backward( - grad_output, self, weight, kernel_size, stride, padding, output_padding, - dilation, columns, ones, output_mask); - } - at::Tensor undefined; - auto gradients = XLATensor::slow_conv_transpose2d_backward( - bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self), - bridge::GetXlaTensor(weight), XlaHelpers::I64List(stride), - XlaHelpers::I64List(padding)); - return std::make_tuple( - output_mask[0] ? bridge::AtenFromXlaTensor(std::get<0>(gradients)) - : undefined, - output_mask[1] ? bridge::AtenFromXlaTensor(std::get<1>(gradients)) - : undefined, - output_mask[2] ? bridge::AtenFromXlaTensor(std::get<2>(gradients)) - : undefined); -} - at::Tensor AtenXlaType::smooth_l1_loss(const at::Tensor& self, const at::Tensor& target, int64_t reduction) { @@ -2672,42 +2638,6 @@ at::Tensor AtenXlaType::tensordot(const at::Tensor& self, return at::native::tensordot(self, other, dims_self, dims_other); } -std::tuple -AtenXlaType::thnn_conv2d_backward( - const at::Tensor& grad_output, const at::Tensor& self, - const at::Tensor& weight, at::IntArrayRef kernel_size, - at::IntArrayRef stride, at::IntArrayRef padding, const at::Tensor& finput, - const at::Tensor& fgrad_input, std::array output_mask) { - at::Tensor undefined; - auto gradients = XLATensor::conv2d_backward( - /*out_backprop=*/bridge::GetXlaTensor(grad_output), - /*input=*/bridge::GetXlaTensor(self), - /*weight=*/bridge::GetXlaTensor(weight), - /*stride=*/XlaHelpers::I64List(stride), - /*padding=*/XlaHelpers::I64List(padding)); - return std::make_tuple( - output_mask[0] ? bridge::AtenFromXlaTensor(std::get<0>(gradients)) - : undefined, - output_mask[1] ? bridge::AtenFromXlaTensor(std::get<1>(gradients)) - : undefined, - output_mask[2] ? bridge::AtenFromXlaTensor(std::get<2>(gradients)) - : undefined); -} - -std::tuple AtenXlaType::thnn_conv2d_forward( - const at::Tensor& self, const at::Tensor& weight, - at::IntArrayRef kernel_size, const at::Tensor& bias, at::IntArrayRef stride, - at::IntArrayRef padding) { - at::Tensor undefined = at::empty({}); - // TODO(asuhan): double check it's ok to return undefined for finput and - // fgrad_input. 
- return std::make_tuple( - conv2d(/*input=*/self, /*weight=*/weight, /*bias=*/bias, - /*stride=*/stride, /*padding=*/padding, /*dilation=*/{1, 1}, - /*groups=*/1), - undefined, undefined); -} - at::Tensor AtenXlaType::threshold(const at::Tensor& self, at::Scalar threshold, at::Scalar value) { return bridge::AtenFromXlaTensor(XLATensor::threshold( diff --git a/torch_xla/csrc/aten_xla_type.h b/torch_xla/csrc/aten_xla_type.h index af14e5b1ffc4..e267eed40ce7 100644 --- a/torch_xla/csrc/aten_xla_type.h +++ b/torch_xla/csrc/aten_xla_type.h @@ -292,15 +292,17 @@ class AtenXlaType { static at::Tensor contiguous(const at::Tensor& self, at::MemoryFormat memory_format); - static at::Tensor conv2d(const at::Tensor& input, const at::Tensor& weight, - const at::Tensor& bias, at::IntArrayRef stride, - at::IntArrayRef padding, at::IntArrayRef dilation, - int64_t groups); + static std::tuple + convolution_backward_overrideable( + const at::Tensor& grad_output, const at::Tensor& input, + const at::Tensor& weight, at::IntArrayRef stride, at::IntArrayRef padding, + at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, + int64_t groups, std::array output_mask); - static at::Tensor conv_transpose2d( + static at::Tensor convolution_overrideable( const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias, - at::IntArrayRef stride, at::IntArrayRef padding, - at::IntArrayRef output_padding, int64_t groups, at::IntArrayRef dilation); + at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, + bool transposed, at::IntArrayRef output_padding, int64_t groups); static at::Tensor& copy_(at::Tensor& self, const at::Tensor& src, bool non_blocking); @@ -916,15 +918,6 @@ class AtenXlaType { static at::Tensor slice(const at::Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step); - static std::tuple - slow_conv_transpose2d_backward( - const at::Tensor& grad_output, const at::Tensor& self, - const at::Tensor& weight, at::IntArrayRef kernel_size, - at::IntArrayRef stride, at::IntArrayRef padding, - at::IntArrayRef output_padding, at::IntArrayRef dilation, - const at::Tensor& columns, const at::Tensor& ones, - std::array output_mask); - static at::Tensor smooth_l1_loss(const at::Tensor& self, const at::Tensor& target, int64_t reduction); @@ -1019,17 +1012,6 @@ class AtenXlaType { at::IntArrayRef dims_self, at::IntArrayRef dims_other); - static std::tuple thnn_conv2d_backward( - const at::Tensor& grad_output, const at::Tensor& self, - const at::Tensor& weight, at::IntArrayRef kernel_size, - at::IntArrayRef stride, at::IntArrayRef padding, const at::Tensor& finput, - const at::Tensor& fgrad_input, std::array output_mask); - - static std::tuple thnn_conv2d_forward( - const at::Tensor& self, const at::Tensor& weight, - at::IntArrayRef kernel_size, const at::Tensor& bias, - at::IntArrayRef stride, at::IntArrayRef padding); - static at::Tensor threshold(const at::Tensor& self, at::Scalar threshold, at::Scalar value); diff --git a/torch_xla/csrc/convolution.cpp b/torch_xla/csrc/convolution.cpp index f4e211ea5377..b0028455229b 100644 --- a/torch_xla/csrc/convolution.cpp +++ b/torch_xla/csrc/convolution.cpp @@ -16,12 +16,13 @@ namespace { tensorflow::ConvOpAttrs MakeConvOpAttrs( tensorflow::gtl::ArraySlice spatial_stride, tensorflow::gtl::ArraySlice spatial_padding, - tensorflow::gtl::ArraySlice spatial_dilation) { + tensorflow::gtl::ArraySlice spatial_dilation, + bool depthwise) { int num_spatial_dims = spatial_stride.size(); XLA_CHECK_EQ(spatial_padding.size(), 
num_spatial_dims); XLA_CHECK_EQ(spatial_dilation.size(), num_spatial_dims); tensorflow::ConvOpAttrs conv_op_attrs; - conv_op_attrs.depthwise = false; + conv_op_attrs.depthwise = depthwise; conv_op_attrs.num_spatial_dims = num_spatial_dims; // Stride, dilation and padding must be set for the batch and feature in the // TF convolution metadata. Set them to 1 (stride and dilation) or 0 (padding) @@ -33,6 +34,8 @@ tensorflow::ConvOpAttrs MakeConvOpAttrs( std::copy(spatial_stride.begin(), spatial_stride.end(), std::back_inserter(conv_op_attrs.strides)); conv_op_attrs.padding = tensorflow::Padding::EXPLICIT; + // https://github.com/tensorflow/tensorflow/blob/ec81825aaf7e848d9f8ddffdf1e0d20aebe9172c/tensorflow/core/util/padding.cc#L40 + // explicit_paddings must have (spatial_dims + 2) * 2 elements conv_op_attrs.explicit_paddings.resize(4); for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) { conv_op_attrs.explicit_paddings.push_back(spatial_padding[spatial_dim]); @@ -42,41 +45,93 @@ tensorflow::ConvOpAttrs MakeConvOpAttrs( return conv_op_attrs; } -std::vector<xla::int64> FilterTransposePermutation() { return {2, 3, 1, 0}; } +// Transpose filter shape to have [channel, batch] as last two dimensions. +// 4D case: (N, C, H, W) -> (H, W, C, N) +const std::vector<xla::int64>& FilterTransposePermutation(const xla::int64 k) { + if (k == 4) { + static std::vector<xla::int64>* permutation = + new std::vector<xla::int64>({2, 3, 1, 0}); + return *permutation; + } else if (k == 5) { + static std::vector<xla::int64>* permutation = + new std::vector<xla::int64>({2, 3, 4, 1, 0}); + return *permutation; + } else { + XLA_ERROR() << "Invalid rank: " << k; + } +} + +// Bias broadcast based on output shape produces: +// (N, H, W) + (C,) = (N, H, W, C) +// This permutation does (N, H, W, C) -> (N, C, H, W) +const std::vector<xla::int64>& BiasTransposePermutation(const xla::int64 k) { + if (k == 4) { + static std::vector<xla::int64>* permutation = + new std::vector<xla::int64>({0, 3, 1, 2}); + return *permutation; + } else if (k == 5) { + static std::vector<xla::int64>* permutation = + new std::vector<xla::int64>({0, 4, 1, 2, 3}); + return *permutation; + } else { + XLA_ERROR() << "Invalid rank: " << k; + } +} + +// Reduce bias from (N, C, H, W) to (C,) +const std::vector<xla::int64>& BiasReduceDimensions(const xla::int64 k) { + if (k == 4) { + static std::vector<xla::int64>* reduce_dim = + new std::vector<xla::int64>({0, 2, 3}); + return *reduce_dim; + } else if (k == 5) { + static std::vector<xla::int64>* reduce_dim = + new std::vector<xla::int64>({0, 2, 3, 4}); + return *reduce_dim; + } else { + XLA_ERROR() << "Invalid rank: " << k; + } +} // Computes the input gradient for a convolution. 
-xla::XlaOp BuildThnnConv2dBackwardInput( +xla::XlaOp BuildConvBackwardInput( const xla::XlaOp& grad_output, const xla::XlaOp& kernel, const xla::Shape& input_shape, tensorflow::gtl::ArraySlice spatial_stride, - tensorflow::gtl::ArraySlice spatial_padding) { - tensorflow::ConvOpAttrs conv_op_attrs = - MakeConvOpAttrs(spatial_stride, spatial_padding, {1, 1}); + tensorflow::gtl::ArraySlice spatial_padding, + tensorflow::gtl::ArraySlice spatial_dilation, + xla::int64 groups) { + bool depthwise = groups == input_shape.dimensions(1); + tensorflow::ConvOpAttrs conv_op_attrs = MakeConvOpAttrs( + spatial_stride, spatial_padding, spatial_dilation, depthwise); xla::XlaOp kernel_transposed = - xla::Transpose(kernel, FilterTransposePermutation()); + xla::Transpose(kernel, FilterTransposePermutation(input_shape.rank())); xla::PrecisionConfig precision_config = XlaHelpers::BuildPrecisionConfig(XlaHelpers::mat_mul_precision()); return ConsumeValue(tensorflow::MakeXlaBackpropInputConvOp( - "thnn_conv2d_backward", input_shape, kernel_transposed, grad_output, + "conv_backward", input_shape, kernel_transposed, grad_output, conv_op_attrs, &precision_config)); } // Computes the kernel gradient for a convolution. -xla::XlaOp BuildThnnConv2dBackwardWeight( +xla::XlaOp BuildConv2dBackwardWeight( const xla::XlaOp& grad_output, const xla::XlaOp& input, const xla::Shape& kernel_shape, tensorflow::gtl::ArraySlice spatial_stride, - tensorflow::gtl::ArraySlice spatial_padding) { - tensorflow::ConvOpAttrs conv_op_attrs = - MakeConvOpAttrs(spatial_stride, spatial_padding, {1, 1}); + tensorflow::gtl::ArraySlice spatial_padding, + tensorflow::gtl::ArraySlice spatial_dilation, + xla::int64 groups) { + bool depthwise = groups == XlaHelpers::ShapeOfXlaOp(input).dimensions(1); + tensorflow::ConvOpAttrs conv_op_attrs = MakeConvOpAttrs( + spatial_stride, spatial_padding, spatial_dilation, depthwise); auto inv_transpose_permutation = - xla::InversePermutation(FilterTransposePermutation()); + xla::InversePermutation(FilterTransposePermutation(kernel_shape.rank())); xla::Shape transposed_weight_shape = xla::ShapeUtil::PermuteDimensions( inv_transpose_permutation, kernel_shape); xla::PrecisionConfig precision_config = XlaHelpers::BuildPrecisionConfig(XlaHelpers::mat_mul_precision()); xla::XlaOp conv = ConsumeValue(tensorflow::MakeXlaBackpropFilterConvOp( - "thnn_conv2d_backward", input, transposed_weight_shape, grad_output, + "conv2d_backward", input, transposed_weight_shape, grad_output, conv_op_attrs, &precision_config)); // Reorder the dimensions of the filter gradient to match the NCHW convention @@ -95,7 +150,7 @@ xla::XlaOp BuildGradBias(xla::XlaOp grad_output) { XlaHelpers::ScalarValue(0, grad_output_shape.element_type(), grad_output.builder()), XlaHelpers::CreateAddComputation(grad_output_shape.element_type()), - {0, 2, 3}); + BiasReduceDimensions(grad_output_shape.rank())); } std::vector> MakePadding( @@ -109,91 +164,137 @@ std::vector> MakePadding( } // namespace -xla::XlaOp BuildConvolution( +xla::XlaOp BuildTransposedConvolution( const xla::XlaOp& input, const xla::XlaOp& kernel, tensorflow::gtl::ArraySlice stride, - tensorflow::gtl::ArraySlice padding) { - const auto dims_padding = MakePadding(padding); - xla::PrecisionConfig precision_config = - XlaHelpers::BuildPrecisionConfig(XlaHelpers::mat_mul_precision()); - return xla::ConvWithGeneralPadding( - input, kernel, stride, dims_padding, - /*feature_group_count*/ 1, /*batch_group_count=*/1, &precision_config); + tensorflow::gtl::ArraySlice padding, + 
tensorflow::gtl::ArraySlice dilation, + tensorflow::gtl::ArraySlice output_padding, + xla::int64 groups) { + xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input); + xla::Shape kernel_shape = XlaHelpers::ShapeOfXlaOp(kernel); + xla::int64 num_spatial = input_shape.rank() - 2; + // We only support 2D or 3D convolution. + XLA_CHECK(num_spatial == 2 || num_spatial == 3) << num_spatial; + // Fold group into input_size feature dimension + xla::int64 feature_dim = kernel_shape.dimensions(1) * groups; + std::vector input_size{input_shape.dimensions(0), feature_dim}; + for (int spatial_dim = 0; spatial_dim < num_spatial; ++spatial_dim) { + input_size.push_back( + (input_shape.dimensions(2 + spatial_dim) - 1) * stride[spatial_dim] - + 2 * padding[spatial_dim] + + dilation[spatial_dim] * (kernel_shape.dimensions(2 + spatial_dim) - 1) + + output_padding[spatial_dim] + 1); + } + return BuildConvBackwardInput( + input, kernel, + xla::ShapeUtil::MakeShape(input_shape.element_type(), input_size), stride, + padding, dilation, /*groups=*/1); } -xla::XlaOp BuildConvolutionBias( +xla::XlaOp BuildTransposedConvolutionBias( const xla::XlaOp& input, const xla::XlaOp& kernel, const xla::XlaOp& bias, tensorflow::gtl::ArraySlice stride, - tensorflow::gtl::ArraySlice padding) { - xla::XlaOp conv = BuildConvolution(input, kernel, stride, padding); + tensorflow::gtl::ArraySlice padding, + tensorflow::gtl::ArraySlice dilation, + tensorflow::gtl::ArraySlice output_padding, + xla::int64 groups) { + xla::XlaOp conv = BuildTransposedConvolution( + input, kernel, stride, padding, dilation, output_padding, groups); auto broadcast_sizes = XlaHelpers::SizesOfXlaOp(conv); - XLA_CHECK_EQ(broadcast_sizes.size(), 4); // Remove the channels dimension. broadcast_sizes.erase(broadcast_sizes.begin() + 1); // Make the bias match the output dimensions. 
xla::XlaOp bias_broadcast = - xla::Transpose(xla::Broadcast(bias, broadcast_sizes), {0, 3, 1, 2}); + xla::Transpose(xla::Broadcast(bias, broadcast_sizes), + BiasTransposePermutation(broadcast_sizes.size() + 1)); return conv + bias_broadcast; } -Conv2DGrads BuildConv2dBackward( +ConvGrads BuildTransposedConvolutionBackward( const xla::XlaOp& grad_output, const xla::XlaOp& input, const xla::XlaOp& kernel, tensorflow::gtl::ArraySlice stride, - tensorflow::gtl::ArraySlice padding) { - xla::XlaOp grad_input = BuildThnnConv2dBackwardInput( - grad_output, kernel, XlaHelpers::ShapeOfXlaOp(input), stride, padding); - xla::XlaOp grad_weight = BuildThnnConv2dBackwardWeight( - grad_output, input, XlaHelpers::ShapeOfXlaOp(kernel), stride, padding); + tensorflow::gtl::ArraySlice padding, + tensorflow::gtl::ArraySlice dilation, + tensorflow::gtl::ArraySlice output_padding, + xla::int64 groups) { + xla::XlaOp grad_input = + BuildConvolutionOverrideable(grad_output, kernel, stride, padding, + dilation, false, output_padding, groups); + xla::XlaOp grad_weight = BuildConv2dBackwardWeight( + input, grad_output, XlaHelpers::ShapeOfXlaOp(kernel), stride, padding, + dilation, groups); xla::XlaOp grad_bias = BuildGradBias(grad_output); return {grad_input, grad_weight, grad_bias}; } -xla::XlaOp BuildTransposedConvolution( +xla::XlaOp BuildConvolutionOverrideable( const xla::XlaOp& input, const xla::XlaOp& kernel, tensorflow::gtl::ArraySlice stride, - tensorflow::gtl::ArraySlice padding) { - xla::Shape input_shape = XlaHelpers::ShapeOfXlaOp(input); - xla::Shape kernel_shape = XlaHelpers::ShapeOfXlaOp(kernel); - std::vector input_size{input_shape.dimensions(0), - kernel_shape.dimensions(1)}; - for (int spatial_dim = 0; spatial_dim < 2; ++spatial_dim) { - input_size.push_back( - (input_shape.dimensions(2 + spatial_dim) - 1) * stride[spatial_dim] - - 2 * padding[spatial_dim] + kernel_shape.dimensions(2 + spatial_dim)); + tensorflow::gtl::ArraySlice padding, + tensorflow::gtl::ArraySlice dilation, bool transposed, + tensorflow::gtl::ArraySlice output_padding, + xla::int64 groups) { + if (transposed) { + return BuildTransposedConvolution(input, kernel, stride, padding, dilation, + output_padding, groups); + } else { + auto dims_padding = MakePadding(padding); + xla::PrecisionConfig precision_config = + XlaHelpers::BuildPrecisionConfig(XlaHelpers::mat_mul_precision()); + return xla::ConvGeneralDilated( + input, kernel, stride, dims_padding, + /*lhs_dilation*/ {}, + /*rhs_dilation*/ dilation, + /*dimension_numbers*/ + xla::XlaBuilder::CreateDefaultConvDimensionNumbers(stride.size()), + /*feature_group_count*/ groups, + /*batch_group_count=*/1, &precision_config); } - return BuildThnnConv2dBackwardInput( - input, kernel, - xla::ShapeUtil::MakeShape(input_shape.element_type(), input_size), stride, - padding); } -xla::XlaOp BuildTransposedConvolutionBias( +xla::XlaOp BuildConvolutionOverrideableBias( const xla::XlaOp& input, const xla::XlaOp& kernel, const xla::XlaOp& bias, tensorflow::gtl::ArraySlice stride, - tensorflow::gtl::ArraySlice padding) { - xla::XlaOp conv = BuildTransposedConvolution(input, kernel, stride, padding); + tensorflow::gtl::ArraySlice padding, + tensorflow::gtl::ArraySlice dilation, bool transposed, + tensorflow::gtl::ArraySlice output_padding, + xla::int64 groups) { + xla::XlaOp conv = + BuildConvolutionOverrideable(input, kernel, stride, padding, dilation, + transposed, output_padding, groups); auto broadcast_sizes = XlaHelpers::SizesOfXlaOp(conv); - XLA_CHECK_EQ(broadcast_sizes.size(), 4); // Remove 
the channels dimension. broadcast_sizes.erase(broadcast_sizes.begin() + 1); // Make the bias match the output dimensions. xla::XlaOp bias_broadcast = - xla::Transpose(xla::Broadcast(bias, broadcast_sizes), {0, 3, 1, 2}); + xla::Transpose(xla::Broadcast(bias, broadcast_sizes), + BiasTransposePermutation(broadcast_sizes.size() + 1)); return conv + bias_broadcast; } -Conv2DGrads BuildTransposedConvolutionBackward( +ConvGrads BuildConvolutionBackwardOverrideable( const xla::XlaOp& grad_output, const xla::XlaOp& input, const xla::XlaOp& kernel, tensorflow::gtl::ArraySlice stride, - tensorflow::gtl::ArraySlice padding) { - xla::XlaOp grad_input = - BuildConvolution(grad_output, kernel, stride, padding); - xla::XlaOp grad_weight = BuildThnnConv2dBackwardWeight( - input, grad_output, XlaHelpers::ShapeOfXlaOp(kernel), stride, padding); - xla::XlaOp grad_bias = BuildGradBias(grad_output); - return {grad_input, grad_weight, grad_bias}; + tensorflow::gtl::ArraySlice padding, + tensorflow::gtl::ArraySlice dilation, bool transposed, + tensorflow::gtl::ArraySlice output_padding, + xla::int64 groups) { + if (transposed) { + return BuildTransposedConvolutionBackward(grad_output, input, kernel, + stride, padding, dilation, + output_padding, groups); + } else { + xla::XlaOp grad_input = BuildConvBackwardInput( + grad_output, kernel, XlaHelpers::ShapeOfXlaOp(input), stride, padding, + dilation, groups); + xla::XlaOp grad_weight = BuildConv2dBackwardWeight( + grad_output, input, XlaHelpers::ShapeOfXlaOp(kernel), stride, padding, + dilation, groups); + xla::XlaOp grad_bias = BuildGradBias(grad_output); + return {grad_input, grad_weight, grad_bias}; + } } - } // namespace torch_xla diff --git a/torch_xla/csrc/convolution.h b/torch_xla/csrc/convolution.h index ae08b33984e3..7220e63817fd 100644 --- a/torch_xla/csrc/convolution.h +++ b/torch_xla/csrc/convolution.h @@ -7,44 +7,37 @@ namespace torch_xla { // Computes the convolution of the given input and kernel with the given // precision, with the given stride and padding. -xla::XlaOp BuildConvolution( +xla::XlaOp BuildConvolutionOverrideable( const xla::XlaOp& input, const xla::XlaOp& kernel, tensorflow::gtl::ArraySlice stride, - tensorflow::gtl::ArraySlice padding); + tensorflow::gtl::ArraySlice padding, + tensorflow::gtl::ArraySlice dilation, bool transposed, + tensorflow::gtl::ArraySlice output_padding, + xla::int64 groups); // Same as above, then broadcasts the bias and adds it to the result. -xla::XlaOp BuildConvolutionBias( +xla::XlaOp BuildConvolutionOverrideableBias( const xla::XlaOp& input, const xla::XlaOp& kernel, const xla::XlaOp& bias, tensorflow::gtl::ArraySlice stride, - tensorflow::gtl::ArraySlice padding); + tensorflow::gtl::ArraySlice padding, + tensorflow::gtl::ArraySlice dilation, bool transposed, + tensorflow::gtl::ArraySlice output_padding, + xla::int64 groups); -xla::XlaOp BuildTransposedConvolution( - const xla::XlaOp& input, const xla::XlaOp& kernel, - tensorflow::gtl::ArraySlice stride, - tensorflow::gtl::ArraySlice padding); - -xla::XlaOp BuildTransposedConvolutionBias( - const xla::XlaOp& input, const xla::XlaOp& kernel, const xla::XlaOp& bias, - tensorflow::gtl::ArraySlice stride, - tensorflow::gtl::ArraySlice padding); - -struct Conv2DGrads { +struct ConvGrads { xla::XlaOp grad_input; xla::XlaOp grad_weight; xla::XlaOp grad_bias; }; // Computes the gradients for a convolution with the given stride and padding. 
-Conv2DGrads BuildConv2dBackward( - const xla::XlaOp& grad_output, const xla::XlaOp& input, - const xla::XlaOp& kernel, - tensorflow::gtl::ArraySlice stride, - tensorflow::gtl::ArraySlice padding); - -Conv2DGrads BuildTransposedConvolutionBackward( +ConvGrads BuildConvolutionBackwardOverrideable( const xla::XlaOp& grad_output, const xla::XlaOp& input, const xla::XlaOp& kernel, tensorflow::gtl::ArraySlice stride, - tensorflow::gtl::ArraySlice padding); + tensorflow::gtl::ArraySlice padding, + tensorflow::gtl::ArraySlice dilation, bool transposed, + tensorflow::gtl::ArraySlice output_padding, + xla::int64 groups); } // namespace torch_xla diff --git a/torch_xla/csrc/ops/conv2d.cpp b/torch_xla/csrc/ops/conv2d.cpp deleted file mode 100644 index 09bba84a048f..000000000000 --- a/torch_xla/csrc/ops/conv2d.cpp +++ /dev/null @@ -1,79 +0,0 @@ -#include "torch_xla/csrc/ops/conv2d.h" - -#include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" -#include "torch_xla/csrc/convolution.h" -#include "torch_xla/csrc/lowering_context.h" -#include "torch_xla/csrc/ops/infer_output_shape.h" - -namespace torch_xla { -namespace ir { -namespace ops { -namespace { - -// The bias doesn't matter for shape inference. -xla::Shape NodeOutputShape( - const Value& input, const Value& weight, - tensorflow::gtl::ArraySlice stride, - tensorflow::gtl::ArraySlice padding) { - auto lower_for_shape_fn = - [stride, padding](tensorflow::gtl::ArraySlice operands) - -> xla::XlaOp { - XLA_CHECK(operands.size() == 2 || operands.size() == 3) - << "Unexpected number of operands: " << operands.size(); - return BuildConvolution(operands[0], operands[1], stride, padding); - }; - return InferOutputShape({input.shape(), weight.shape()}, lower_for_shape_fn); -} - -} // namespace - -Conv2d::Conv2d(const Value& input, const Value& weight, const Value& bias, - std::vector stride, std::vector padding) - : Node(ir::OpKind(at::aten::convolution), {input, weight, bias}, - [&]() { return NodeOutputShape(input, weight, stride, padding); }, - /*num_outputs=*/1, xla::util::MHash(stride, padding)), - stride_(std::move(stride)), - padding_(std::move(padding)) {} - -Conv2d::Conv2d(const Value& input, const Value& weight, - std::vector stride, std::vector padding) - : Node(ir::OpKind(at::aten::convolution), {input, weight}, - [&]() { return NodeOutputShape(input, weight, stride, padding); }, - /*num_outputs=*/1, xla::util::MHash(stride, padding)), - stride_(std::move(stride)), - padding_(std::move(padding)) {} - -NodePtr Conv2d::Clone(OpList operands) const { - return operands.size() == 3 - ? 
MakeNode(operands.at(0), operands.at(1), operands.at(2), - stride_, padding_) - : MakeNode(operands.at(0), operands.at(1), stride_, - padding_); -} - -XlaOpVector Conv2d::Lower(LoweringContext* loctx) const { - xla::XlaOp input = loctx->GetOutputOp(operand(0)); - xla::XlaOp kernel = loctx->GetOutputOp(operand(1)); - xla::XlaOp output; - if (operands().size() == 3) { - xla::XlaOp bias = loctx->GetOutputOp(operand(2)); - output = BuildConvolutionBias(input, kernel, bias, stride_, padding_); - } else { - XLA_CHECK_EQ(operands().size(), 2); - output = BuildConvolution(input, kernel, stride_, padding_); - } - return ReturnOp(output, loctx); -} - -std::string Conv2d::ToString() const { - std::stringstream ss; - ss << Node::ToString() << ", stride=[" << absl::StrJoin(stride_, ", ") - << "], padding=[" << absl::StrJoin(padding_, ", ") << "]"; - return ss.str(); -} - -} // namespace ops -} // namespace ir -} // namespace torch_xla diff --git a/torch_xla/csrc/ops/conv2d.h b/torch_xla/csrc/ops/conv2d.h deleted file mode 100644 index e031d092373b..000000000000 --- a/torch_xla/csrc/ops/conv2d.h +++ /dev/null @@ -1,38 +0,0 @@ -#pragma once - -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "torch_xla/csrc/ir.h" - -namespace torch_xla { -namespace ir { -namespace ops { - -// IR node for 2D convolutions with or without bias. -class Conv2d : public Node { - public: - Conv2d(const Value& input, const Value& weight, const Value& bias, - std::vector stride, std::vector padding); - - Conv2d(const Value& input, const Value& weight, - std::vector stride, std::vector padding); - - NodePtr Clone(OpList operands) const override; - - XlaOpVector Lower(LoweringContext* loctx) const override; - - std::string ToString() const override; - - const std::vector& stride() const { return stride_; } - - const std::vector& padding() const { return padding_; } - - private: - // The parameters of the convolution. - std::vector stride_; - std::vector padding_; -}; - -} // namespace ops -} // namespace ir -} // namespace torch_xla diff --git a/torch_xla/csrc/ops/conv2d_backward.cpp b/torch_xla/csrc/ops/conv2d_backward.cpp deleted file mode 100644 index 171368af14e0..000000000000 --- a/torch_xla/csrc/ops/conv2d_backward.cpp +++ /dev/null @@ -1,75 +0,0 @@ -#include "torch_xla/csrc/ops/conv2d_backward.h" - -#include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" -#include "torch_xla/csrc/convolution.h" -#include "torch_xla/csrc/lowering_context.h" -#include "torch_xla/csrc/ops/infer_output_shape.h" - -namespace torch_xla { -namespace ir { -namespace ops { -namespace { - -xla::Shape NodeOutputShape( - const Value& grad_output, const Value& input, const Value& weight, - tensorflow::gtl::ArraySlice stride, - tensorflow::gtl::ArraySlice padding) { - auto lower_for_shape_fn = - [stride, padding](tensorflow::gtl::ArraySlice operands) - -> xla::XlaOp { - XLA_CHECK_EQ(operands.size(), 3) - << "Unexpected number of operands: " << operands.size(); - // The precision doesn't matter for shape inference. 
- Conv2DGrads grads = BuildConv2dBackward(operands[0], operands[1], - operands[2], stride, padding); - return xla::Tuple(operands[0].builder(), - {grads.grad_input, grads.grad_weight, grads.grad_bias}); - }; - return InferOutputShape({grad_output.shape(), input.shape(), weight.shape()}, - lower_for_shape_fn); -} - -} // namespace - -Conv2dBackward::Conv2dBackward(const Value& grad_output, const Value& input, - const Value& weight, - std::vector stride, - std::vector padding) - : Node(ir::OpKind(at::aten::thnn_conv2d_backward), - {grad_output, input, weight}, - [&]() { - return NodeOutputShape(grad_output, input, weight, stride, - padding); - }, - /*num_outputs=*/3, xla::util::MHash(stride, padding)), - stride_(std::move(stride)), - padding_(std::move(padding)) {} - -NodePtr Conv2dBackward::Clone(OpList operands) const { - return MakeNode(operands.at(0), operands.at(1), - operands.at(2), stride_, padding_); -} - -XlaOpVector Conv2dBackward::Lower(LoweringContext* loctx) const { - xla::XlaOp grad_output = loctx->GetOutputOp(operand(0)); - xla::XlaOp input = loctx->GetOutputOp(operand(1)); - xla::XlaOp weight = loctx->GetOutputOp(operand(2)); - auto grads = - BuildConv2dBackward(grad_output, input, weight, stride_, padding_); - return ReturnOps({std::move(grads.grad_input), std::move(grads.grad_weight), - std::move(grads.grad_bias)}, - loctx); -} - -std::string Conv2dBackward::ToString() const { - std::stringstream ss; - ss << Node::ToString() << ", stride=[" << absl::StrJoin(stride_, ", ") - << "], padding=[" << absl::StrJoin(padding_, ", ") << "]"; - return ss.str(); -} - -} // namespace ops -} // namespace ir -} // namespace torch_xla diff --git a/torch_xla/csrc/ops/conv2d_backward.h b/torch_xla/csrc/ops/conv2d_backward.h deleted file mode 100644 index 52e0d507b4a1..000000000000 --- a/torch_xla/csrc/ops/conv2d_backward.h +++ /dev/null @@ -1,35 +0,0 @@ -#pragma once - -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "torch_xla/csrc/ir.h" - -namespace torch_xla { -namespace ir { -namespace ops { - -class Conv2dBackward : public Node { - public: - Conv2dBackward(const Value& grad_output, const Value& input, - const Value& weight, std::vector stride, - std::vector padding); - - NodePtr Clone(OpList operands) const override; - - XlaOpVector Lower(LoweringContext* loctx) const override; - - std::string ToString() const override; - - const std::vector& stride() const { return stride_; } - - const std::vector& padding() const { return padding_; } - - private: - // The parameters of the convolution. - std::vector stride_; - std::vector padding_; -}; - -} // namespace ops -} // namespace ir -} // namespace torch_xla diff --git a/torch_xla/csrc/ops/conv_transpose2d.cpp b/torch_xla/csrc/ops/conv_transpose2d.cpp deleted file mode 100644 index fbb02f11d0fe..000000000000 --- a/torch_xla/csrc/ops/conv_transpose2d.cpp +++ /dev/null @@ -1,84 +0,0 @@ -#include "torch_xla/csrc/ops/conv_transpose2d.h" - -#include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" -#include "torch_xla/csrc/convolution.h" -#include "torch_xla/csrc/lowering_context.h" -#include "torch_xla/csrc/ops/infer_output_shape.h" - -namespace torch_xla { -namespace ir { -namespace ops { -namespace { - -// The bias doesn't matter for shape inference. 
-xla::Shape NodeOutputShape( - const Value& input, const Value& weight, - tensorflow::gtl::ArraySlice stride, - tensorflow::gtl::ArraySlice padding) { - auto lower_for_shape_fn = - [stride, padding](tensorflow::gtl::ArraySlice operands) - -> xla::XlaOp { - XLA_CHECK(operands.size() == 2 || operands.size() == 3) - << "Unexpected number of operands: " << operands.size(); - return BuildTransposedConvolution(operands[0], operands[1], stride, - padding); - }; - return InferOutputShape({input.shape(), weight.shape()}, lower_for_shape_fn); -} - -} // namespace - -ConvTranspose2d::ConvTranspose2d(const Value& input, const Value& weight, - const Value& bias, - std::vector stride, - std::vector padding) - : Node(ir::OpKind(at::aten::slow_conv_transpose2d), {input, weight, bias}, - [&]() { return NodeOutputShape(input, weight, stride, padding); }, - /*num_outputs=*/1, xla::util::MHash(stride, padding)), - stride_(std::move(stride)), - padding_(std::move(padding)) {} - -ConvTranspose2d::ConvTranspose2d(const Value& input, const Value& weight, - std::vector stride, - std::vector padding) - : Node(ir::OpKind(at::aten::convolution), {input, weight}, - [&]() { return NodeOutputShape(input, weight, stride, padding); }, - /*num_outputs=*/1, xla::util::MHash(stride, padding)), - stride_(std::move(stride)), - padding_(std::move(padding)) {} - -NodePtr ConvTranspose2d::Clone(OpList operands) const { - return operands.size() == 3 - ? MakeNode(operands.at(0), operands.at(1), - operands.at(2), stride_, padding_) - : MakeNode(operands.at(0), operands.at(1), - stride_, padding_); -} - -XlaOpVector ConvTranspose2d::Lower(LoweringContext* loctx) const { - xla::XlaOp input = loctx->GetOutputOp(operand(0)); - xla::XlaOp kernel = loctx->GetOutputOp(operand(1)); - xla::XlaOp output; - if (operands().size() == 3) { - xla::XlaOp bias = loctx->GetOutputOp(operand(2)); - output = - BuildTransposedConvolutionBias(input, kernel, bias, stride_, padding_); - } else { - XLA_CHECK_EQ(operands().size(), 2); - output = BuildTransposedConvolution(input, kernel, stride_, padding_); - } - return ReturnOp(output, loctx); -} - -std::string ConvTranspose2d::ToString() const { - std::stringstream ss; - ss << Node::ToString() << ", stride=[" << absl::StrJoin(stride_, ", ") - << "], padding=[" << absl::StrJoin(padding_, ", ") << "]"; - return ss.str(); -} - -} // namespace ops -} // namespace ir -} // namespace torch_xla diff --git a/torch_xla/csrc/ops/conv_transpose2d.h b/torch_xla/csrc/ops/conv_transpose2d.h deleted file mode 100644 index ea1cfefded90..000000000000 --- a/torch_xla/csrc/ops/conv_transpose2d.h +++ /dev/null @@ -1,39 +0,0 @@ -#pragma once - -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "torch_xla/csrc/ir.h" - -namespace torch_xla { -namespace ir { -namespace ops { - -class ConvTranspose2d : public Node { - public: - ConvTranspose2d(const Value& input, const Value& weight, const Value& bias, - std::vector stride, - std::vector padding); - - ConvTranspose2d(const Value& input, const Value& weight, - std::vector stride, - std::vector padding); - - NodePtr Clone(OpList operands) const override; - - XlaOpVector Lower(LoweringContext* loctx) const override; - - std::string ToString() const override; - - const std::vector& stride() const { return stride_; } - - const std::vector& padding() const { return padding_; } - - private: - // The parameters of the convolution. 
- std::vector<xla::int64> stride_; - std::vector<xla::int64> padding_; -}; - -} // namespace ops -} // namespace ir -} // namespace torch_xla diff --git a/torch_xla/csrc/ops/convolution_backward_overrideable.cpp b/torch_xla/csrc/ops/convolution_backward_overrideable.cpp new file mode 100644 index 000000000000..7d3350bae8e5 --- /dev/null +++ b/torch_xla/csrc/ops/convolution_backward_overrideable.cpp @@ -0,0 +1,93 @@ +#include "torch_xla/csrc/ops/convolution_backward_overrideable.h" + +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "tensorflow/compiler/xla/xla_client/util.h" +#include "torch_xla/csrc/convolution.h" +#include "torch_xla/csrc/lowering_context.h" +#include "torch_xla/csrc/ops/infer_output_shape.h" + +namespace torch_xla { +namespace ir { +namespace ops { +namespace { + +xla::Shape NodeOutputShape( + const Value& grad_output, const Value& input, const Value& weight, + tensorflow::gtl::ArraySlice<const xla::int64> stride, + tensorflow::gtl::ArraySlice<const xla::int64> padding, + tensorflow::gtl::ArraySlice<const xla::int64> dilation, bool transposed, + tensorflow::gtl::ArraySlice<const xla::int64> output_padding, + xla::int64 groups) { + auto lower_for_shape_fn = + [stride, padding, dilation, transposed, output_padding, + groups](tensorflow::gtl::ArraySlice<const xla::XlaOp> operands) + -> xla::XlaOp { + XLA_CHECK_EQ(operands.size(), 3); + // The precision doesn't matter for shape inference. + ConvGrads grads = BuildConvolutionBackwardOverrideable( + operands[0], operands[1], operands[2], stride, padding, dilation, + transposed, output_padding, groups); + return xla::Tuple(operands[0].builder(), + {grads.grad_input, grads.grad_weight, grads.grad_bias}); + }; + return InferOutputShape({grad_output.shape(), input.shape(), weight.shape()}, + lower_for_shape_fn); +} + +} // namespace + +ConvolutionBackwardOverrideable::ConvolutionBackwardOverrideable( + const Value& grad_output, const Value& input, const Value& weight, + std::vector<xla::int64> stride, std::vector<xla::int64> padding, + std::vector<xla::int64> dilation, bool transposed, + std::vector<xla::int64> output_padding, xla::int64 groups) + : Node(ir::OpKind(at::aten::convolution_backward_overrideable), + {grad_output, input, weight}, + [&]() { + return NodeOutputShape(grad_output, input, weight, stride, padding, + dilation, transposed, output_padding, + groups); + }, + /*num_outputs=*/3, + xla::util::MHash(stride, padding, transposed, output_padding, + groups)), + stride_(std::move(stride)), + padding_(std::move(padding)), + dilation_(std::move(dilation)), + transposed_(transposed), + output_padding_(std::move(output_padding)), + groups_(groups) {} + +NodePtr ConvolutionBackwardOverrideable::Clone(OpList operands) const { + return MakeNode<ConvolutionBackwardOverrideable>( + operands.at(0), operands.at(1), operands.at(2), stride_, padding_, + dilation_, transposed_, output_padding_, groups_); +} + +XlaOpVector ConvolutionBackwardOverrideable::Lower( + LoweringContext* loctx) const { + xla::XlaOp grad_output = loctx->GetOutputOp(operand(0)); + xla::XlaOp input = loctx->GetOutputOp(operand(1)); + xla::XlaOp weight = loctx->GetOutputOp(operand(2)); + auto grads = BuildConvolutionBackwardOverrideable( + grad_output, input, weight, stride_, padding_, dilation_, transposed_, + output_padding_, groups_); + return ReturnOps({std::move(grads.grad_input), std::move(grads.grad_weight), + std::move(grads.grad_bias)}, + loctx); +} + +std::string ConvolutionBackwardOverrideable::ToString() const { + std::stringstream ss; + ss << Node::ToString() << ", stride=[" << absl::StrJoin(stride_, ", ") + << "], padding=[" << absl::StrJoin(padding_, ", ") << "], dilation=[" + <<
absl::StrJoin(dilation_, ", ") << "], transpose=" << transposed_ + << ", output_padding=[" << absl::StrJoin(output_padding_, ", ") + << "], groups=" << groups_; + return ss.str(); +} + +} // namespace ops +} // namespace ir +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/convolution_backward_overrideable.h b/torch_xla/csrc/ops/convolution_backward_overrideable.h new file mode 100644 index 000000000000..3c92424dbe8a --- /dev/null +++ b/torch_xla/csrc/ops/convolution_backward_overrideable.h @@ -0,0 +1,50 @@ +#pragma once + +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "torch_xla/csrc/ir.h" + +namespace torch_xla { +namespace ir { +namespace ops { + +class ConvolutionBackwardOverrideable : public Node { + public: + ConvolutionBackwardOverrideable( + const Value& grad_output, const Value& input, const Value& weight, + std::vector<xla::int64> stride, std::vector<xla::int64> padding, + std::vector<xla::int64> dilation, bool transposed, + std::vector<xla::int64> output_padding, xla::int64 groups); + + NodePtr Clone(OpList operands) const override; + + XlaOpVector Lower(LoweringContext* loctx) const override; + + std::string ToString() const override; + + const std::vector<xla::int64>& stride() const { return stride_; } + + const std::vector<xla::int64>& padding() const { return padding_; } + + const std::vector<xla::int64>& dilation() const { return dilation_; } + + bool transposed() const { return transposed_; } + + const std::vector<xla::int64>& output_padding() const { + return output_padding_; + } + + xla::int64 groups() const { return groups_; } + + private: + std::vector<xla::int64> stride_; + std::vector<xla::int64> padding_; + std::vector<xla::int64> dilation_; + std::vector<xla::int64> output_padding_; + bool transposed_; + xla::int64 groups_; +}; + +} // namespace ops +} // namespace ir +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/convolution_overrideable.cpp b/torch_xla/csrc/ops/convolution_overrideable.cpp new file mode 100644 index 000000000000..6084b8e8ab0f --- /dev/null +++ b/torch_xla/csrc/ops/convolution_overrideable.cpp @@ -0,0 +1,117 @@ +#include "torch_xla/csrc/ops/convolution_overrideable.h" + +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "tensorflow/compiler/xla/xla_client/util.h" +#include "torch_xla/csrc/convolution.h" +#include "torch_xla/csrc/lowering_context.h" +#include "torch_xla/csrc/ops/infer_output_shape.h" + +namespace torch_xla { +namespace ir { +namespace ops { +namespace { + +// The bias doesn't matter for shape inference.
+xla::Shape NodeOutputShape( + const Value& input, const Value& weight, + tensorflow::gtl::ArraySlice<const xla::int64> stride, + tensorflow::gtl::ArraySlice<const xla::int64> padding, + tensorflow::gtl::ArraySlice<const xla::int64> dilation, bool transposed, + tensorflow::gtl::ArraySlice<const xla::int64> output_padding, + xla::int64 groups) { + auto lower_for_shape_fn = + [stride, padding, dilation, output_padding, transposed, + groups](tensorflow::gtl::ArraySlice<const xla::XlaOp> operands) + -> xla::XlaOp { + XLA_CHECK(operands.size() == 2 || operands.size() == 3); + return BuildConvolutionOverrideable(operands[0], operands[1], stride, + padding, dilation, transposed, + output_padding, groups); + }; + return InferOutputShape({input.shape(), weight.shape()}, lower_for_shape_fn); +} + +} // namespace + +ConvolutionOverrideable::ConvolutionOverrideable( + const Value& input, const Value& weight, const Value& bias, + std::vector<xla::int64> stride, std::vector<xla::int64> padding, + std::vector<xla::int64> dilation, bool transposed, + std::vector<xla::int64> output_padding, xla::int64 groups) + : Node(ir::OpKind(at::aten::convolution_overrideable), + {input, weight, bias}, + [&]() { + return NodeOutputShape(input, weight, stride, padding, dilation, + transposed, output_padding, groups); + }, + /*num_outputs=*/1, + xla::util::MHash(stride, padding, dilation, transposed, + output_padding, groups)), + stride_(std::move(stride)), + padding_(std::move(padding)), + dilation_(std::move(dilation)), + transposed_(transposed), + output_padding_(std::move(output_padding)), + groups_(groups) {} + +ConvolutionOverrideable::ConvolutionOverrideable( + const Value& input, const Value& weight, std::vector<xla::int64> stride, + std::vector<xla::int64> padding, std::vector<xla::int64> dilation, + bool transposed, std::vector<xla::int64> output_padding, xla::int64 groups) + : Node(ir::OpKind(at::aten::convolution_overrideable), {input, weight}, + [&]() { + return NodeOutputShape(input, weight, stride, padding, dilation, + transposed, output_padding, groups); + }, + /*num_outputs=*/1, + xla::util::MHash(stride, padding, dilation, transposed, + output_padding, groups)), + stride_(std::move(stride)), + padding_(std::move(padding)), + dilation_(std::move(dilation)), + transposed_(transposed), + output_padding_(std::move(output_padding)), + groups_(groups) {} + +NodePtr ConvolutionOverrideable::Clone(OpList operands) const { + return operands.size() == 3 + ?
MakeNode<ConvolutionOverrideable>( + operands.at(0), operands.at(1), operands.at(2), stride_, + padding_, dilation_, transposed_, output_padding_, groups_) + : MakeNode<ConvolutionOverrideable>( + operands.at(0), operands.at(1), stride_, padding_, dilation_, + transposed_, output_padding_, groups_); +} + +XlaOpVector ConvolutionOverrideable::Lower(LoweringContext* loctx) const { + xla::XlaOp input = loctx->GetOutputOp(operand(0)); + xla::XlaOp kernel = loctx->GetOutputOp(operand(1)); + xla::XlaOp output; + if (operands().size() == 3) { + xla::XlaOp bias = loctx->GetOutputOp(operand(2)); + output = BuildConvolutionOverrideableBias(input, kernel, bias, stride_, + padding_, dilation_, transposed_, + output_padding_, groups_); + } else { + XLA_CHECK_EQ(operands().size(), 2); + output = BuildConvolutionOverrideable(input, kernel, stride_, padding_, + dilation_, transposed_, + output_padding_, groups_); + } + return ReturnOp(output, loctx); +} + +std::string ConvolutionOverrideable::ToString() const { + std::stringstream ss; + ss << Node::ToString() << ", stride=[" << absl::StrJoin(stride_, ", ") + << "], padding=[" << absl::StrJoin(padding_, ", ") << "], dilation=[" + << absl::StrJoin(dilation_, ", ") << "], transpose=" << transposed_ + << ", output_padding=[" << absl::StrJoin(output_padding_, ", ") + << "], groups=" << groups_; + return ss.str(); +} + +} // namespace ops +} // namespace ir +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/convolution_overrideable.h b/torch_xla/csrc/ops/convolution_overrideable.h new file mode 100644 index 000000000000..686911105f2f --- /dev/null +++ b/torch_xla/csrc/ops/convolution_overrideable.h @@ -0,0 +1,59 @@ +#pragma once + +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "torch_xla/csrc/ir.h" + +namespace torch_xla { +namespace ir { +namespace ops { + +// IR node for 2D & 3D convolutions with or without bias.
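+// The transposed, output_padding and groups attributes let the same node also +// represent transposed and grouped convolutions, matching the parameters of +// at::convolution_overrideable.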
+class ConvolutionOverrideable : public Node { + public: + ConvolutionOverrideable(const Value& input, const Value& weight, + const Value& bias, std::vector<xla::int64> stride, + std::vector<xla::int64> padding, + std::vector<xla::int64> dilation, bool transposed, + std::vector<xla::int64> output_padding, + xla::int64 groups); + + ConvolutionOverrideable(const Value& input, const Value& weight, + std::vector<xla::int64> stride, + std::vector<xla::int64> padding, + std::vector<xla::int64> dilation, bool transposed, + std::vector<xla::int64> output_padding, + xla::int64 groups); + + NodePtr Clone(OpList operands) const override; + + XlaOpVector Lower(LoweringContext* loctx) const override; + + std::string ToString() const override; + + const std::vector<xla::int64>& stride() const { return stride_; } + + const std::vector<xla::int64>& padding() const { return padding_; } + + const std::vector<xla::int64>& dilation() const { return dilation_; } + + bool transposed() const { return transposed_; } + + const std::vector<xla::int64>& output_padding() const { + return output_padding_; + } + + xla::int64 groups() const { return groups_; } + + private: + std::vector<xla::int64> stride_; + std::vector<xla::int64> padding_; + std::vector<xla::int64> dilation_; + std::vector<xla::int64> output_padding_; + bool transposed_; + xla::int64 groups_; +}; + +} // namespace ops +} // namespace ir +} // namespace torch_xla diff --git a/torch_xla/csrc/ops/slow_conv_transpose2d_backward.cpp b/torch_xla/csrc/ops/slow_conv_transpose2d_backward.cpp deleted file mode 100644 index fc0363725658..000000000000 --- a/torch_xla/csrc/ops/slow_conv_transpose2d_backward.cpp +++ /dev/null @@ -1,72 +0,0 @@ -#include "torch_xla/csrc/ops/slow_conv_transpose2d_backward.h" - -#include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" -#include "torch_xla/csrc/convolution.h" -#include "torch_xla/csrc/lowering_context.h" -#include "torch_xla/csrc/ops/infer_output_shape.h" - -namespace torch_xla { -namespace ir { -namespace ops { -namespace { - -xla::Shape NodeOutputShape( - const Value& grad_output, const Value& input, const Value& weight, - tensorflow::gtl::ArraySlice<const xla::int64> stride, - tensorflow::gtl::ArraySlice<const xla::int64> padding) { - auto lower_for_shape_fn = - [stride, padding](tensorflow::gtl::ArraySlice<const xla::XlaOp> operands) - -> xla::XlaOp { - XLA_CHECK_EQ(operands.size(), 3); - Conv2DGrads grads = BuildTransposedConvolutionBackward( - operands[0], operands[1], operands[2], stride, padding); - return xla::Tuple(operands[0].builder(), - {grads.grad_input, grads.grad_weight, grads.grad_bias}); - }; - return InferOutputShape({grad_output.shape(), input.shape(), weight.shape()}, - lower_for_shape_fn); -} - -} // namespace - -ConvTranspose2dBackward::ConvTranspose2dBackward( - const Value& grad_output, const Value& input, const Value& weight, - std::vector<xla::int64> stride, std::vector<xla::int64> padding) - : Node(ir::OpKind(at::aten::slow_conv_transpose2d_backward), - {grad_output, input, weight}, - [&]() { - return NodeOutputShape(grad_output, input, weight, stride, - padding); - }, - /*num_outputs=*/3, xla::util::MHash(stride, padding)), - stride_(std::move(stride)), - padding_(std::move(padding)) {} - -NodePtr ConvTranspose2dBackward::Clone(OpList operands) const { - return MakeNode<ConvTranspose2dBackward>(operands.at(0), operands.at(1), - operands.at(2), stride_, padding_); -} - -XlaOpVector ConvTranspose2dBackward::Lower(LoweringContext* loctx) const { - xla::XlaOp grad_output = loctx->GetOutputOp(operand(0)); - xla::XlaOp input = loctx->GetOutputOp(operand(1)); - xla::XlaOp kernel = loctx->GetOutputOp(operand(2)); - Conv2DGrads grads = BuildTransposedConvolutionBackward( - grad_output, input,
kernel, stride_, padding_); - return ReturnOps({std::move(grads.grad_input), std::move(grads.grad_weight), - std::move(grads.grad_bias)}, - loctx); -} - -std::string ConvTranspose2dBackward::ToString() const { - std::stringstream ss; - ss << Node::ToString() << ", stride=[" << absl::StrJoin(stride_, ", ") - << "], padding=[" << absl::StrJoin(padding_, ", ") << "]"; - return ss.str(); -} - -} // namespace ops -} // namespace ir -} // namespace torch_xla diff --git a/torch_xla/csrc/ops/slow_conv_transpose2d_backward.h b/torch_xla/csrc/ops/slow_conv_transpose2d_backward.h deleted file mode 100644 index 7fc28dcba172..000000000000 --- a/torch_xla/csrc/ops/slow_conv_transpose2d_backward.h +++ /dev/null @@ -1,34 +0,0 @@ -#pragma once - -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "torch_xla/csrc/ir.h" - -namespace torch_xla { -namespace ir { -namespace ops { - -class ConvTranspose2dBackward : public Node { - public: - ConvTranspose2dBackward(const Value& grad_output, const Value& input, - const Value& weight, std::vector stride, - std::vector padding); - - NodePtr Clone(OpList operands) const override; - - XlaOpVector Lower(LoweringContext* loctx) const override; - - std::string ToString() const override; - - const std::vector& stride() const { return stride_; } - - const std::vector& padding() const { return padding_; } - - private: - std::vector stride_; - std::vector padding_; -}; - -} // namespace ops -} // namespace ir -} // namespace torch_xla diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index 4a3583ba0d4b..e1d99e985aa3 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -295,24 +295,25 @@ class XLATensor { const XLATensor& input, tensorflow::gtl::ArraySlice pad, at::Scalar value); - static XLATensor conv2d(const XLATensor& input, const XLATensor& weight, - const XLATensor& bias, std::vector stride, - std::vector padding); - - static XLATensor conv2d(const XLATensor& input, const XLATensor& weight, - std::vector stride, - std::vector padding); + static XLATensor convolution_overrideable( + const XLATensor& input, const XLATensor& weight, const XLATensor& bias, + std::vector stride, std::vector padding, + std::vector dilation, bool transposed, + std::vector output_padding, xla::int64 groups); - static std::tuple conv2d_backward( + static std::tuple + convolution_backward_overrideable( const XLATensor& out_backprop, const XLATensor& input, const XLATensor& weight, std::vector stride, - std::vector padding); + std::vector padding, std::vector dilation, + bool transposed, std::vector output_padding, + xla::int64 groups); - static XLATensor conv_transpose2d(const XLATensor& input, - const XLATensor& weight, - const XLATensor& bias, - std::vector stride, - std::vector padding); + static XLATensor convolution_overrideable( + const XLATensor& input, const XLATensor& weight, + std::vector stride, std::vector padding, + std::vector dilation, bool transposed, + std::vector output_padding, xla::int64 groups); static XLATensor cos(const XLATensor& input); static void cos_(XLATensor& input); @@ -740,18 +741,6 @@ class XLATensor { static XLATensor slice(const XLATensor& input, xla::int64 dim, xla::int64 start, xla::int64 end, xla::int64 step); - static XLATensor slow_conv_transpose2d(const XLATensor& input, - const XLATensor& weight, - std::vector stride, - std::vector padding); - - static std::tuple - slow_conv_transpose2d_backward(const XLATensor& out_backprop, - const XLATensor& input, - const 
XLATensor& weight, - std::vector stride, - std::vector padding); - // Computes a loss that uses a squared term if the absolute element-wise error // falls below 1 and an L1 term otherwise. static XLATensor smooth_l1_loss(const XLATensor& input, diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 7c84d5302261..85a4f96d1547 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -29,9 +29,8 @@ #include "torch_xla/csrc/ops/cholesky.h" #include "torch_xla/csrc/ops/constant.h" #include "torch_xla/csrc/ops/constant_pad_nd.h" -#include "torch_xla/csrc/ops/conv2d.h" -#include "torch_xla/csrc/ops/conv2d_backward.h" -#include "torch_xla/csrc/ops/conv_transpose2d.h" +#include "torch_xla/csrc/ops/convolution_backward_overrideable.h" +#include "torch_xla/csrc/ops/convolution_overrideable.h" #include "torch_xla/csrc/ops/cross_replica_sum.h" #include "torch_xla/csrc/ops/cumprod.h" #include "torch_xla/csrc/ops/cumsum.h" @@ -75,7 +74,6 @@ #include "torch_xla/csrc/ops/scatter.h" #include "torch_xla/csrc/ops/scatter_add.h" #include "torch_xla/csrc/ops/shrink_backward.h" -#include "torch_xla/csrc/ops/slow_conv_transpose2d_backward.h" #include "torch_xla/csrc/ops/softmax.h" #include "torch_xla/csrc/ops/softshrink.h" #include "torch_xla/csrc/ops/split.h" @@ -671,32 +669,41 @@ XLATensor XLATensor::constant_pad_nd( input.GetIrValue(), complete_pad, value)); } -XLATensor XLATensor::conv2d(const XLATensor& input, const XLATensor& weight, - const XLATensor& bias, - std::vector stride, - std::vector padding) { - ir::NodePtr ir_value = ir::MakeNode( +XLATensor XLATensor::convolution_overrideable( + const XLATensor& input, const XLATensor& weight, const XLATensor& bias, + std::vector stride, std::vector padding, + std::vector dilation, bool transposed, + std::vector output_padding, xla::int64 groups) { + ir::NodePtr ir_value = ir::MakeNode( input.GetIrValue(), weight.GetIrValue(), bias.GetIrValue(), - std::move(stride), std::move(padding)); + std::move(stride), std::move(padding), std::move(dilation), transposed, + std::move(output_padding), groups); return input.CreateFrom(ir_value); } -XLATensor XLATensor::conv2d(const XLATensor& input, const XLATensor& weight, - std::vector stride, - std::vector padding) { - ir::NodePtr ir_value = - ir::MakeNode(input.GetIrValue(), weight.GetIrValue(), - std::move(stride), std::move(padding)); +XLATensor XLATensor::convolution_overrideable( + const XLATensor& input, const XLATensor& weight, + std::vector stride, std::vector padding, + std::vector dilation, bool transposed, + std::vector output_padding, xla::int64 groups) { + ir::NodePtr ir_value = ir::MakeNode( + input.GetIrValue(), weight.GetIrValue(), std::move(stride), + std::move(padding), std::move(dilation), transposed, + std::move(output_padding), groups); return input.CreateFrom(ir_value); } -std::tuple XLATensor::conv2d_backward( +std::tuple +XLATensor::convolution_backward_overrideable( const XLATensor& out_backprop, const XLATensor& input, const XLATensor& weight, std::vector stride, - std::vector padding) { - ir::NodePtr node = ir::MakeNode( + std::vector padding, std::vector dilation, + bool transposed, std::vector output_padding, + xla::int64 groups) { + ir::NodePtr node = ir::MakeNode( out_backprop.GetIrValue(), input.GetIrValue(), weight.GetIrValue(), - std::move(stride), std::move(padding)); + std::move(stride), std::move(padding), std::move(dilation), transposed, + std::move(output_padding), groups); XLATensor grad_input = 
out_backprop.CreateFrom(ir::Value(node, 0)); XLATensor grad_weight = out_backprop.CreateFrom(ir::Value(node, 1)); XLATensor grad_bias = out_backprop.CreateFrom(ir::Value(node, 2)); @@ -704,17 +711,6 @@ std::tuple XLATensor::conv2d_backward( std::move(grad_bias)); } -XLATensor XLATensor::conv_transpose2d(const XLATensor& input, - const XLATensor& weight, - const XLATensor& bias, - std::vector stride, - std::vector padding) { - ir::NodePtr node = ir::MakeNode( - input.GetIrValue(), weight.GetIrValue(), bias.GetIrValue(), - std::move(stride), std::move(padding)); - return input.CreateFrom(node); -} - XLATensor XLATensor::cos(const XLATensor& input) { return input.CreateFrom(ir::ops::Cos(input.GetIrValue())); } @@ -1837,32 +1833,6 @@ XLATensor XLATensor::slice(const XLATensor& input, xla::int64 dim, return input.CreateViewTensor(std::move(view_info)); } -XLATensor XLATensor::slow_conv_transpose2d(const XLATensor& input, - const XLATensor& weight, - std::vector stride, - std::vector padding) { - ir::NodePtr node = ir::MakeNode( - input.GetIrValue(), weight.GetIrValue(), std::move(stride), - std::move(padding)); - return input.CreateFrom(node); -} - -std::tuple -XLATensor::slow_conv_transpose2d_backward(const XLATensor& out_backprop, - const XLATensor& input, - const XLATensor& weight, - std::vector stride, - std::vector padding) { - ir::NodePtr node = ir::MakeNode( - out_backprop.GetIrValue(), input.GetIrValue(), weight.GetIrValue(), - std::move(stride), std::move(padding)); - XLATensor grad_input = out_backprop.CreateFrom(ir::Value(node, 0)); - XLATensor grad_weight = out_backprop.CreateFrom(ir::Value(node, 1)); - XLATensor grad_bias = out_backprop.CreateFrom(ir::Value(node, 2)); - return std::make_tuple(std::move(grad_input), std::move(grad_weight), - std::move(grad_bias)); -} - XLATensor XLATensor::smooth_l1_loss(const XLATensor& input, const XLATensor& target, xla::int64 reduction) {