diff --git a/aten/src/ATen/native/vulkan/ops/Convolution.cpp b/aten/src/ATen/native/vulkan/ops/Convolution.cpp
index 5bec92abb53d..7cf3b4fe5137 100644
--- a/aten/src/ATen/native/vulkan/ops/Convolution.cpp
+++ b/aten/src/ATen/native/vulkan/ops/Convolution.cpp
@@ -1,8 +1,7 @@
-#include
+#include <ATen/native/vulkan/ops/Convolution.h>
 #include
 #include
 #include
-#include

 namespace at {
 namespace native {
@@ -10,74 +9,6 @@ namespace vulkan {
 namespace ops {
 namespace {

-class Context final : public torch::jit::CustomClassHolder {
- public:
-  static Context create(
-      api::Resource::Pool& pool,
-      const Tensor& weight,
-      const c10::optional<Tensor>& bias,
-      IntArrayRef stride,
-      IntArrayRef padding,
-      IntArrayRef dilation,
-      bool transposed,
-      IntArrayRef output_padding,
-      int64_t groups,
-      c10::optional<Scalar> output_min = c10::nullopt,
-      c10::optional<Scalar> output_max = c10::nullopt);
-
-  using State = std::tuple<
-      Tensor,
-      c10::optional<Tensor>,
-      std::vector<int64_t>,
-      std::vector<int64_t>,
-      std::vector<int64_t>,
-      int64_t,
-      c10::optional<Scalar>,
-      c10::optional<Scalar>>;
-
-  Tensor run(const Tensor& input) const;
-  State unpack() const;
-
- private:
-  Context(
-      api::Resource::Pool& pool,
-      const Tensor& weight,
-      const c10::optional<Tensor>& bias,
-      IntArrayRef stride,
-      IntArrayRef padding,
-      IntArrayRef dilation,
-      bool transposed,
-      IntArrayRef output_padding,
-      int64_t groups,
-      c10::optional<Scalar> output_min = c10::nullopt,
-      c10::optional<Scalar> output_max = c10::nullopt);
-
- private:
-  struct {
-    vTensor v_weight;
-    vTensor v_bias;
-    std::array<int64_t, 4> filter;
-    std::array<int64_t, 2> stride;
-    std::array<int64_t, 2> padding;
-    std::array<int64_t, 2> dilation;
-    int32_t groups;
-    float output_min;
-    float output_max;
-  } packed_;
-
-  struct {
-    Tensor weight;
-    c10::optional<Tensor> bias;
-    std::vector<int64_t> filter;
-    std::vector<int64_t> stride;
-    std::vector<int64_t> padding;
-    std::vector<int64_t> dilation;
-    int64_t groups;
-    c10::optional<Scalar> output_min;
-    c10::optional<Scalar> output_max;
-  } unpacked_;
-};
-
 inline bool is_depthwise(
     const IntArrayRef filter,
     const int64_t groups) {
@@ -263,42 +194,6 @@ std::array<int64_t, 2> pack_params(const std::vector<int64_t>& vector) {
   };
 }

-Context::Context(
-    api::Resource::Pool& pool,
-    const Tensor& weight,
-    const c10::optional<Tensor>& bias,
-    const IntArrayRef stride,
-    const IntArrayRef padding,
-    const IntArrayRef dilation,
-    const bool /* transposed */,
-    const IntArrayRef /* output_padding */,
-    const int64_t groups,
-    const c10::optional<Scalar> output_min,
-    const c10::optional<Scalar> output_max)
-  : packed_{
-      pack_weights(pool, weight, groups),
-      pack_biases(pool, bias, weight),
-      pack_filter(weight, expand_param_if_needed(dilation, "dilation", 2)),
-      pack_params(expand_param_if_needed(stride, "stride", 2)),
-      pack_params(expand_param_if_needed(padding, "padding", 2)),
-      pack_params(expand_param_if_needed(dilation, "dilation", 2)),
-      groups,
-      output_min ? output_min->template to<float>() : -std::numeric_limits<float>::infinity(),
-      output_max ? output_max->template to<float>() : +std::numeric_limits<float>::infinity(),
-    },
-    unpacked_{
-      weight,
-      bias,
-      weight.sizes().vec(),
-      stride.vec(),
-      padding.vec(),
-      dilation.vec(),
-      groups,
-      output_min,
-      output_max,
-    } {
-}
-
 bool available(
     const Tensor& weight,
     const c10::optional<Tensor>& bias,
@@ -349,56 +244,6 @@ bool available(
       true;
 }

-Context Context::create(
-    api::Resource::Pool& pool,
-    const Tensor& weight,
-    const c10::optional<Tensor>& bias,
-    const IntArrayRef stride_arg,
-    const IntArrayRef padding_arg,
-    const IntArrayRef dilation_arg,
-    const bool transposed,
-    const IntArrayRef output_padding_arg,
-    const int64_t groups,
-    const c10::optional<Scalar> output_min,
-    const c10::optional<Scalar> output_max) {
-  const auto stride = expand_param_if_needed(stride_arg, "stride", 2);
-  const auto padding = expand_param_if_needed(padding_arg, "padding", 2);
-  const auto dilation = expand_param_if_needed(dilation_arg, "dilation", 2);
-  const auto output_padding = output_padding_arg; // TODO: Deconvolutions
-
-  TORCH_CHECK(
-      available(
-          weight,
-          bias,
-          stride,
-          padding,
-          dilation,
-          transposed,
-          output_padding,
-          groups,
-          output_min,
-          output_max),
-      "Vulkan::convolution not available! "
-      "Reason: The provided (weight, bias, stride, padding, dilation, groups, "
-      "transposed, output_padding, output_min, output_max) parameters are either "
-      "invalid individually or their combination is not supported by Vulkan impl.");
-
-  // Pass in the originals
-  return Context{
-      pool,
-      weight,
-      bias,
-      stride_arg,
-      padding_arg,
-      dilation_arg,
-      transposed,
-      output_padding_arg,
-      groups,
-      output_min,
-      output_max,
-  };
-}
-
 bool usable(const Tensor& input) {
   // Input
   return (4 == input.ndimension()) &&
@@ -632,7 +477,126 @@ void conv2d(
   }
 }

-Tensor Context::run(const Tensor& input_arg) const {
+Tensor convolution(
+    const Tensor& input,
+    const Tensor& weight,
+    const c10::optional<Tensor>& bias,
+    const IntArrayRef stride,
+    const IntArrayRef padding,
+    const IntArrayRef dilation,
+    const bool transposed,
+    const IntArrayRef output_padding,
+    const int64_t groups) {
+  return Conv2dOpContext::create(
+      api::context()->resource().pool,
+      weight,
+      bias,
+      stride,
+      padding,
+      dilation,
+      transposed,
+      output_padding,
+      groups
+  ).run(input);
+}
+
+#ifdef USE_VULKAN_API
+
+TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
+  m.impl_UNBOXED("convolution_overrideable", convolution);
+}
+
+#endif /* USE_VULKAN_API */
+
+} // namespace
+
+Conv2dOpContext::Conv2dOpContext(
+    api::Resource::Pool& pool,
+    const Tensor& weight,
+    const c10::optional<Tensor>& bias,
+    const IntArrayRef stride,
+    const IntArrayRef padding,
+    const IntArrayRef dilation,
+    const bool /* transposed */,
+    const IntArrayRef /* output_padding */,
+    const int64_t groups,
+    const c10::optional<Scalar> output_min,
+    const c10::optional<Scalar> output_max)
+  : packed_{
+      pack_weights(pool, weight, groups),
+      pack_biases(pool, bias, weight),
+      pack_filter(weight, expand_param_if_needed(dilation, "dilation", 2)),
+      pack_params(expand_param_if_needed(stride, "stride", 2)),
+      pack_params(expand_param_if_needed(padding, "padding", 2)),
+      pack_params(expand_param_if_needed(dilation, "dilation", 2)),
+      groups,
+      output_min ? output_min->template to<float>() : -std::numeric_limits<float>::infinity(),
+      output_max ? output_max->template to<float>() : +std::numeric_limits<float>::infinity(),
+    },
+    unpacked_{
+      weight,
+      bias,
+      weight.sizes().vec(),
+      stride.vec(),
+      padding.vec(),
+      dilation.vec(),
+      groups,
+      output_min,
+      output_max,
+    } {
+}
+
+Conv2dOpContext Conv2dOpContext::create(
+    api::Resource::Pool& pool,
+    const Tensor& weight,
+    const c10::optional<Tensor>& bias,
+    const IntArrayRef stride_arg,
+    const IntArrayRef padding_arg,
+    const IntArrayRef dilation_arg,
+    const bool transposed,
+    const IntArrayRef output_padding_arg,
+    const int64_t groups,
+    const c10::optional<Scalar> output_min,
+    const c10::optional<Scalar> output_max) {
+  const auto stride = expand_param_if_needed(stride_arg, "stride", 2);
+  const auto padding = expand_param_if_needed(padding_arg, "padding", 2);
+  const auto dilation = expand_param_if_needed(dilation_arg, "dilation", 2);
+  const auto output_padding = output_padding_arg; // TODO: Deconvolutions
+
+  TORCH_CHECK(
+      available(
+          weight,
+          bias,
+          stride,
+          padding,
+          dilation,
+          transposed,
+          output_padding,
+          groups,
+          output_min,
+          output_max),
+      "Vulkan::convolution not available! "
+      "Reason: The provided (weight, bias, stride, padding, dilation, groups, "
+      "transposed, output_padding, output_min, output_max) parameters are either "
+      "invalid individually or their combination is not supported by Vulkan impl.");
+
+  // Pass in the originals
+  return Conv2dOpContext{
+      pool,
+      weight,
+      bias,
+      stride_arg,
+      padding_arg,
+      dilation_arg,
+      transposed,
+      output_padding_arg,
+      groups,
+      output_min,
+      output_max,
+  };
+}
+
+Tensor Conv2dOpContext::run(const Tensor& input_arg) const {
   api::Context* const context = api::context();

   const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan();
@@ -708,8 +672,8 @@ Tensor Context::run(const Tensor& input_arg) const {
   return convert(v_output);
 }

-Context::State Context::unpack() const {
-  return Context::State{
+Conv2dOpContext::State Conv2dOpContext::unpack() const {
+  return Conv2dOpContext::State{
     unpacked_.weight,
     unpacked_.bias,
     unpacked_.stride,
@@ -721,7 +685,7 @@
   };
 }

-c10::intrusive_ptr<Context> conv2_clamp_prepack(
+c10::intrusive_ptr<Conv2dOpContext> conv2d_clamp_prepack(
     Tensor&& weight,
     c10::optional<Tensor>&& bias,
     std::vector<int64_t>&& stride,
@@ -730,8 +694,8 @@ c10::intrusive_ptr<Context> conv2_clamp_prepack(
     const int64_t groups,
     const c10::optional<Scalar> output_min,
     const c10::optional<Scalar> output_max) {
-  return c10::make_intrusive<Context>(
-      Context::create(
+  return c10::make_intrusive<Conv2dOpContext>(
+      Conv2dOpContext::create(
           persistent()->pool,
           std::move(weight),
           std::move(bias),
@@ -747,78 +711,10 @@ c10::intrusive_ptr<Context> conv2_clamp_prepack(

 Tensor conv2d_clamp_run(
     const Tensor& input,
-    const c10::intrusive_ptr<Context>& context) {
+    const c10::intrusive_ptr<Conv2dOpContext>& context) {
   return context->run(input);
 }

-Tensor convolution(
-    const Tensor& input,
-    const Tensor& weight,
-    const c10::optional<Tensor>& bias,
-    const IntArrayRef stride,
-    const IntArrayRef padding,
-    const IntArrayRef dilation,
-    const bool transposed,
-    const IntArrayRef output_padding,
-    const int64_t groups) {
-  return Context::create(
-      api::context()->resource().pool,
-      weight,
-      bias,
-      stride,
-      padding,
-      dilation,
-      transposed,
-      output_padding,
-      groups
-  ).run(input);
-}
-
-TORCH_LIBRARY(vulkan, m) {
-  m.class_<Context>("Conv2dOpContext")
-      .def_pickle(
-          // __getstate__
-          [](const c10::intrusive_ptr<Context>& context) {
-            return context->unpack();
-          },
-          // __setstate__
-          [](Context::State state) {
-            return conv2_clamp_prepack(
-                std::move(std::get<0>(state)),
-                std::move(std::get<1>(state)),
-                std::move(std::get<2>(state)),
-                std::move(std::get<3>(state)),
-                std::move(std::get<4>(state)),
-                std::move(std::get<5>(state)),
-                std::move(std::get<6>(state)),
-                std::move(std::get<7>(state)));
-          });
-}
-
-TORCH_LIBRARY(vulkan_prepack, m) {
-  m.def(
-      "conv2d_clamp_prepack(Tensor W, Tensor? B, int[2] stride, "
-      "int[2] padding, int[2] dilation, int groups, "
-      "Scalar? output_min=None, Scalar? output_max=None) "
-      "-> __torch__.torch.classes.vulkan.Conv2dOpContext");
-  m.def(
-      "conv2d_clamp_run(Tensor X, "
-      "__torch__.torch.classes.vulkan.Conv2dOpContext W_prepack) -> Tensor Y");
-}
-
-TORCH_LIBRARY_IMPL(vulkan_prepack, CPU, m) {
-  m.impl("conv2d_clamp_prepack", TORCH_FN(conv2_clamp_prepack));
-}
-
-TORCH_LIBRARY_IMPL(vulkan_prepack, Vulkan, m) {
-  m.impl("conv2d_clamp_run", conv2d_clamp_run);
-}
-
-TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
-  m.impl_UNBOXED("convolution_overrideable", convolution);
-}
-
-} // namespace
 } // namespace ops
 } // namespace vulkan
 } // namespace native
diff --git a/aten/src/ATen/native/vulkan/ops/Convolution.h b/aten/src/ATen/native/vulkan/ops/Convolution.h
new file mode 100644
index 000000000000..2bab7091d4ab
--- /dev/null
+++ b/aten/src/ATen/native/vulkan/ops/Convolution.h
@@ -0,0 +1,99 @@
+#pragma once
+#ifdef USE_VULKAN
+
+#include <ATen/native/vulkan/ops/Common.h>
+#include <torch/custom_class.h>
+
+namespace at {
+namespace native {
+namespace vulkan {
+namespace ops {
+
+class Conv2dOpContext final : public torch::jit::CustomClassHolder {
+ public:
+  static Conv2dOpContext create(
+      api::Resource::Pool& pool,
+      const Tensor& weight,
+      const c10::optional<Tensor>& bias,
+      IntArrayRef stride,
+      IntArrayRef padding,
+      IntArrayRef dilation,
+      bool transposed,
+      IntArrayRef output_padding,
+      int64_t groups,
+      c10::optional<Scalar> output_min = c10::nullopt,
+      c10::optional<Scalar> output_max = c10::nullopt);
+
+  using State = std::tuple<
+      Tensor,
+      c10::optional<Tensor>,
+      std::vector<int64_t>,
+      std::vector<int64_t>,
+      std::vector<int64_t>,
+      int64_t,
+      c10::optional<Scalar>,
+      c10::optional<Scalar>>;
+
+  Tensor run(const Tensor& input) const;
+  State unpack() const;
+
+ private:
+  Conv2dOpContext(
+      api::Resource::Pool& pool,
+      const Tensor& weight,
+      const c10::optional<Tensor>& bias,
+      IntArrayRef stride,
+      IntArrayRef padding,
+      IntArrayRef dilation,
+      bool transposed,
+      IntArrayRef output_padding,
+      int64_t groups,
+      c10::optional<Scalar> output_min = c10::nullopt,
+      c10::optional<Scalar> output_max = c10::nullopt);
+
+ private:
+  struct {
+    vTensor v_weight;
+    vTensor v_bias;
+    std::array<int64_t, 4> filter;
+    std::array<int64_t, 2> stride;
+    std::array<int64_t, 2> padding;
+    std::array<int64_t, 2> dilation;
+    int32_t groups;
+    float output_min;
+    float output_max;
+  } packed_;
+
+  struct {
+    Tensor weight;
+    c10::optional<Tensor> bias;
+    std::vector<int64_t> filter;
+    std::vector<int64_t> stride;
+    std::vector<int64_t> padding;
+    std::vector<int64_t> dilation;
+    int64_t groups;
+    c10::optional<Scalar> output_min;
+    c10::optional<Scalar> output_max;
+  } unpacked_;
+};
+
+Tensor conv2d_clamp_run(
+    const Tensor& input,
+    const c10::intrusive_ptr<Conv2dOpContext>& context);
+
+c10::intrusive_ptr<Conv2dOpContext> conv2d_clamp_prepack(
+    Tensor&& weight,
+    c10::optional<Tensor>&& bias,
+    std::vector<int64_t>&& stride,
+    std::vector<int64_t>&& padding,
+    std::vector<int64_t>&& dilation,
+    const int64_t groups,
+    const c10::optional<Scalar> output_min,
+    const c10::optional<Scalar> output_max);
+
+} // namespace ops
+} // namespace vulkan
+} // namespace native
+} // namespace at
+
+#endif /* USE_VULKAN */
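Not part of the patch — an illustrative sketch of how the context exposed by the new header can be driven directly from C++, mirroring the `convolution()` wrapper added above. Tensor shapes and the choice of the ephemeral resource pool are assumptions for the example:

```cpp
#include <ATen/ATen.h>
#include <ATen/native/vulkan/ops/Convolution.h>

using namespace at::native::vulkan;

// Pack the weights and bias once, then run against the input --
// the same pack-then-run split convolution() performs internally.
at::Tensor conv2d_example() {
  const auto weight = at::rand({8, 3, 3, 3});   // OIHW, hypothetical shape
  const auto bias = at::rand({8});
  const auto input = at::rand({1, 3, 64, 64});  // NCHW, hypothetical shape

  return ops::Conv2dOpContext::create(
             api::context()->resource().pool,
             weight,
             bias,
             /*stride=*/{1, 1},
             /*padding=*/{1, 1},
             /*dilation=*/{1, 1},
             /*transposed=*/false,
             /*output_padding=*/{0, 0},
             /*groups=*/1)
      .run(input);
}
```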
diff --git a/aten/src/ATen/native/vulkan/ops/Mm.cpp b/aten/src/ATen/native/vulkan/ops/Mm.cpp
index 185f66226e15..ca342e70a7b8 100644
--- a/aten/src/ATen/native/vulkan/ops/Mm.cpp
+++ b/aten/src/ATen/native/vulkan/ops/Mm.cpp
@@ -1,5 +1,5 @@
-#include
-#include
+#include <ATen/native/vulkan/ops/Mm.h>
+#include <ATen/native/vulkan/ops/Persistent.h>

 namespace at {
 namespace native {
@@ -7,100 +7,113 @@ namespace vulkan {
 namespace ops {
 namespace {

-Tensor addmm(
-    const Tensor& self_arg,
-    const Tensor& mat1_arg,
-    const Tensor& mat2_arg,
-    const Scalar beta,
-    const Scalar alpha) {
-  api::Context* const context = api::context();
+vTensor pack_weights(api::Resource::Pool& pool, const Tensor& weight_arg) {
+  return convert(weight_arg.vulkan());
+}

-  const Tensor self = self_arg.is_vulkan() ? self_arg : self_arg.vulkan();
-  const vTensor& v_self = convert(self);
+vTensor pack_biases(
+    api::Resource::Pool& pool,
+    const c10::optional<Tensor>& bias_arg,
+    const Tensor& weight_arg) {
+  if (bias_arg) {
+    return convert(bias_arg->vulkan());
+  } else {
+    vTensor v_bias{
+        api::context(),
+        &pool,
+        {weight_arg.size(Layout::Parameter::width)},
+        weight_arg.options(),
+    };

-  const Tensor mat1 = mat1_arg.is_vulkan() ? mat1_arg : mat1_arg.vulkan();
-  const vTensor& v_mat1 = convert(mat1);
+    using Future = vTensor::Future<float, vTensor::Access::Write>;
+    Future v_bias_future = v_bias.host<float, vTensor::Access::Write>();
+    Future::Payload v_bias_payload = v_bias_future.wait();

-  const Tensor mat2 = mat2_arg.is_vulkan() ? mat2_arg : mat2_arg.vulkan();
-  const vTensor& v_mat2 = convert(mat2);
-
-  const auto self_sizes = self.sizes();
-  const auto mat1_sizes = mat1.sizes();
-  const auto mat2_sizes = mat2.sizes();
+    memset(
+        v_bias_payload.get(),
+        // 2's complement integers and IEEE-754 floating point numbers both
+        // have identical bit representations for 0, so can use memset which
+        // only accepts uint8_t parameter.
+        0,
+        v_bias.nbytes());

-  if (self_sizes.size() >= 2) {
-    TORCH_CHECK(
-        (mat1_sizes[Layout::Parameter::width] ==
-         mat2_sizes[Layout::Parameter::height]) &&
-        (self_sizes[Layout::Parameter::height] ==
-         mat1_sizes[Layout::Parameter::height]) &&
-        (self_sizes[Layout::Parameter::width] ==
-         mat2_sizes[Layout::Parameter::width]),
-        "Incompatible matrix dimensions!");
+    return v_bias;
   }
-  else {
-    TORCH_CHECK(
-        (mat1_sizes[Layout::Parameter::width] ==
-         mat2_sizes[Layout::Parameter::height]) &&
-        ((self_sizes[Layout::Parameter::height] ==
-          mat1_sizes[Layout::Parameter::height]) ||
-         (self_sizes[Layout::Parameter::height] ==
-          mat2_sizes[Layout::Parameter::width])),
-        "Incompatible matrix dimensions!");
+}
+
+bool available(const Tensor& weight, const c10::optional<Tensor>& bias) {
+  bool valid = true;
+  if (bias && bias->ndimension() > 1) {
+    valid =
+        (bias->sizes()[Layout::Parameter::width] ==
+         weight.sizes()[Layout::Parameter::width]);
   }
+  return api::available() && valid;
+}

-  vTensor v_output{
-      context,
-      {
-          mat1_sizes[Layout::Parameter::height],
-          mat2_sizes[Layout::Parameter::width],
-      },
-      self.options(),
-  };
+bool usable(
+    const Tensor& input,
+    const Tensor& weight,
+    const c10::optional<Tensor>& bias) {
+  return (input.sizes()[Layout::Parameter::width] ==
+          weight.sizes()[Layout::Parameter::height]);
+}

-  api::Command::Buffer command_buffer = context->command().pool.allocate();
-  command_buffer.begin();
-  {
-    if (v_self.has_image()) {
-      const struct {
-        float beta, alpha;
-      } block {
-        alpha.to<float>(),
-        beta.to<float>(),
-      };
+void addmm_impl(
+    api::Context* const context,
+    api::Command::Buffer& command_buffer,
+    vTensor& v_output,
+    const vTensor& v_self,
+    const vTensor& v_mat1,
+    const vTensor& v_mat2,
+    const float beta,
+    const float alpha) {
+  if (v_output.has_image() && v_self.has_image() && v_mat1.has_image() &&
+      v_mat2.has_image()) {
+    const struct {
+      float alpha, beta;
+    } block{
+        alpha,
+        beta,
+    };

-      context->dispatch(
-          command_buffer,
-          {
+    context->dispatch(
+        command_buffer,
+        {
          VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
          VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
          VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
          VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
          VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
-          },
-          VK_KERNEL(addmm),
-          v_output.extents(),
-          // Write-only access bypasses synchronization but inserts appropriate
-          // barriers if necessary.
-          v_output.image(command_buffer, vTensor::Access::Write),
-          // Read-only access is implied on const tensors and triggers an async
-          // synchronization if necessary.
-          v_mat1.image(command_buffer),
-          // Read-only access is implied on const tensors and triggers an async
-          // synchronization if necessary.
-          v_mat2.image(command_buffer),
-          // Read-only access is implied on const tensors and triggers an async
-          // synchronization if necessary.
-          v_self.image(command_buffer),
-          context->resource().pool.uniform(block).object);
-    } else {
-      TORCH_CHECK(false, "Not implemented!");
-    }
+        },
+        VK_KERNEL(addmm),
+        v_output.extents(),
+        // Write-only access bypasses synchronization but inserts appropriate
+        // barriers if necessary.
+        v_output.image(command_buffer, vTensor::Access::Write),
+        // Read-only access is implied on const tensors and triggers an async
+        // synchronization if necessary.
+        v_mat1.image(command_buffer),
+        // Read-only access is implied on const tensors and triggers an async
+        // synchronization if necessary.
+        v_mat2.image(command_buffer),
+        // Read-only access is implied on const tensors and triggers an async
+        // synchronization if necessary.
+        v_self.image(command_buffer),
+        context->resource().pool.uniform(block).object);
+  } else {
+    TORCH_CHECK(false, "Not implemented!");
   }
-  command_buffer.end();
-  command_buffer.submit(context->gpu().queue);
+}

-  return convert(v_output);
+Tensor addmm(
+    const Tensor& self,
+    const Tensor& mat1,
+    const Tensor& mat2,
+    const Scalar beta,
+    const Scalar alpha) {
+  return LinearOpContext::create(api::context()->resource().pool, mat2, self)
+      .run(mat1, beta.to<float>(), alpha.to<float>());
 }

 Tensor mm(const Tensor& self_arg, const Tensor& mat2_arg) {
@@ -121,12 +134,12 @@ Tensor mm(const Tensor& self_arg, const Tensor& mat2_arg) {
       "Incompatible matrix dimensions!");

   vTensor v_output{
-    context,
-    {
-      mat1_sizes[Layout::Parameter::height],
-      mat2_sizes[Layout::Parameter::width],
-    },
-    mat1.options(),
+      context,
+      {
+          mat1_sizes[Layout::Parameter::height],
+          mat2_sizes[Layout::Parameter::width],
+      },
+      mat1.options(),
   };

   api::Command::Buffer command_buffer = context->command().pool.allocate();
@@ -136,9 +149,9 @@ Tensor mm(const Tensor& self_arg, const Tensor& mat2_arg) {
     context->dispatch(
         command_buffer,
         {
-          VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
-          VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
-          VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
+            VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
+            VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
+            VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
         },
         VK_KERNEL(mm),
         v_output.extents(),
@@ -171,6 +184,100 @@ TORCH_LIBRARY_IMPL(aten, Vulkan, m) {

 #endif /* USE_VULKAN_API */

 } // namespace
+
+LinearOpContext::LinearOpContext(
+    api::Resource::Pool& pool,
+    const Tensor& weight,
+    const c10::optional<Tensor>& bias)
+  : packed_{
+      pack_weights(pool, weight),
+      pack_biases(pool, bias, weight),
+    },
+    unpacked_{
+      weight,
+      bias,
+    } {
+}
+
+LinearOpContext LinearOpContext::create(
+    api::Resource::Pool& pool,
+    const Tensor& weight,
+    const c10::optional<Tensor>& bias) {
+  TORCH_CHECK(available(weight, bias))
+  // Pass in the originals
+  return LinearOpContext{
+      pool,
+      weight,
+      bias,
+  };
+}
+
+Tensor LinearOpContext::run(const Tensor& input_arg, float beta, float alpha)
+    const {
+  api::Context* const context = api::context();
+
+  const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan();
+  const vTensor& v_input = convert(input);
+
+  TORCH_CHECK(
+      usable(input, unpacked_.weight, unpacked_.bias),
+      "Vulkan Linear not usable! "
+      "Reason: The provided input tensor is either invalid or unsupported by Vulkan impl.");
+
+  vTensor v_output{
+      context,
+      {
+          input_arg.sizes()[Layout::Parameter::height],
+          packed_.v_weight.sizes()[Layout::Parameter::width],
+      },
+      input.options(),
+  };
+
+  api::Command::Buffer command_buffer = context->command().pool.allocate();
+  command_buffer.begin();
+  {
+    if (input_arg.ndimension() == 2) {
+      addmm_impl(
+          context,
+          command_buffer,
+          v_output,
+          packed_.v_bias,
+          v_input,
+          packed_.v_weight,
+          beta,
+          alpha);
+    } else {
+      TORCH_CHECK(
+          false, "linear_run does not yet support inputs with ndim > 2!")
+    }
+  }
+  command_buffer.end();
+  command_buffer.submit(context->gpu().queue);
+
+  return convert(v_output);
+}
+
+LinearOpContext::State LinearOpContext::unpack() const {
+  return LinearOpContext::State{
+      unpacked_.weight,
+      unpacked_.bias,
+  };
+}
+
+c10::intrusive_ptr<LinearOpContext> linear_prepack(
+    Tensor&& weight,
+    c10::optional<Tensor>&& bias) {
+  return c10::make_intrusive<LinearOpContext>(LinearOpContext::create(
+      persistent()->pool, std::move(weight), std::move(bias)));
+}
+
+Tensor linear_run(
+    const Tensor& input,
+    const c10::intrusive_ptr<LinearOpContext>& context) {
+  return context->run(input, 1.0, 1.0);
+}
+
 } // namespace ops
 } // namespace vulkan
 } // namespace native
diff --git a/aten/src/ATen/native/vulkan/ops/Mm.h b/aten/src/ATen/native/vulkan/ops/Mm.h
new file mode 100644
index 000000000000..08c84967d00f
--- /dev/null
+++ b/aten/src/ATen/native/vulkan/ops/Mm.h
@@ -0,0 +1,55 @@
+#pragma once
+#ifdef USE_VULKAN
+
+#include <ATen/native/vulkan/ops/Common.h>
+#include <torch/custom_class.h>
+
+namespace at {
+namespace native {
+namespace vulkan {
+namespace ops {
+
+class LinearOpContext final : public torch::jit::CustomClassHolder {
+ public:
+  static LinearOpContext create(
+      api::Resource::Pool& pool,
+      const Tensor& weight,
+      const c10::optional<Tensor>& bias);
+
+  using State = std::tuple<Tensor, c10::optional<Tensor>>;
+
+  Tensor run(const Tensor& input, float beta, float alpha) const;
+  State unpack() const;
+
+ private:
+  LinearOpContext(
+      api::Resource::Pool& pool,
+      const Tensor& weight,
+      const c10::optional<Tensor>& bias);
+
+ private:
+  struct {
+    vTensor v_weight;
+    vTensor v_bias;
+  } packed_;
+
+  struct {
+    Tensor weight;
+    c10::optional<Tensor> bias;
+  } unpacked_;
+};
+
+c10::intrusive_ptr<LinearOpContext> linear_prepack(
+    Tensor&& weight,
+    c10::optional<Tensor>&& bias);
+
+Tensor linear_run(
+    const Tensor& input,
+    const c10::intrusive_ptr<LinearOpContext>& context);
+
+} // namespace ops
+} // namespace vulkan
+} // namespace native
+} // namespace at
+
+#endif /* USE_VULKAN */
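Again not part of the patch — a minimal usage sketch for the linear context (shapes hypothetical). Note that `linear_prepack` expects the weight already transposed to `[in_features, out_features]`, matching the `aten::t(%weight)` inserted by the graph rewrite further down:

```cpp
#include <ATen/ATen.h>
#include <ATen/native/vulkan/ops/Mm.h>

using namespace at::native::vulkan::ops;

at::Tensor linear_example() {
  // Weight in transposed [in_features, out_features] layout.
  auto weight_t = at::rand({128, 64});
  c10::optional<at::Tensor> bias = at::rand({64});
  const auto input = at::rand({1, 128});

  // Pack once, then run; linear_run fixes beta == alpha == 1.0.
  const auto packed = linear_prepack(std::move(weight_t), std::move(bias));
  return linear_run(input, packed);
}
```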
diff --git a/aten/src/ATen/native/vulkan/ops/RegisterOpContextClass.cpp b/aten/src/ATen/native/vulkan/ops/RegisterOpContextClass.cpp
new file mode 100644
index 000000000000..699944b7c48e
--- /dev/null
+++ b/aten/src/ATen/native/vulkan/ops/RegisterOpContextClass.cpp
@@ -0,0 +1,80 @@
+#ifdef USE_VULKAN
+
+#include <ATen/native/vulkan/ops/Common.h>
+#include <ATen/native/vulkan/ops/Convolution.h>
+#include <ATen/native/vulkan/ops/Mm.h>
+#include <ATen/native/vulkan/ops/Persistent.h>
+#include <torch/custom_class.h>
+
+namespace at {
+namespace native {
+namespace vulkan {
+namespace ops {
+namespace {
+
+TORCH_LIBRARY(vulkan, m) {
+  m.class_<Conv2dOpContext>("Conv2dOpContext")
+      .def_pickle(
+          // __getstate__
+          [](const c10::intrusive_ptr<Conv2dOpContext>& context) {
+            return context->unpack();
+          },
+          // __setstate__
+          [](Conv2dOpContext::State state) {
+            return conv2d_clamp_prepack(
+                std::move(std::get<0>(state)),
+                std::move(std::get<1>(state)),
+                std::move(std::get<2>(state)),
+                std::move(std::get<3>(state)),
+                std::move(std::get<4>(state)),
+                std::move(std::get<5>(state)),
+                std::move(std::get<6>(state)),
+                std::move(std::get<7>(state)));
+          });
+  m.class_<LinearOpContext>("LinearOpContext")
+      .def_pickle(
+          // __getstate__
+          [](const c10::intrusive_ptr<LinearOpContext>& context) {
+            return context->unpack();
+          },
+          // __setstate__
+          [](LinearOpContext::State state) {
+            return linear_prepack(
+                std::move(std::get<0>(state)), std::move(std::get<1>(state)));
+          });
+}
+
+TORCH_LIBRARY(vulkan_prepack, m) {
+  m.def(
+      "conv2d_clamp_prepack(Tensor W, Tensor? B, int[2] stride, "
+      "int[2] padding, int[2] dilation, int groups, "
+      "Scalar? output_min=None, Scalar? output_max=None) "
+      "-> __torch__.torch.classes.vulkan.Conv2dOpContext");
+  m.def(
+      "conv2d_clamp_run(Tensor X, "
+      "__torch__.torch.classes.vulkan.Conv2dOpContext W_prepack) -> Tensor Y");
+  m.def(
+      "linear_prepack(Tensor W, Tensor? B) "
+      "-> __torch__.torch.classes.vulkan.LinearOpContext");
+  m.def(
+      "linear_run(Tensor X, "
+      "__torch__.torch.classes.vulkan.LinearOpContext BW_prepack) -> Tensor Y");
+}
+
+TORCH_LIBRARY_IMPL(vulkan_prepack, CPU, m) {
+  m.impl("conv2d_clamp_prepack", TORCH_FN(conv2d_clamp_prepack));
+  m.impl("linear_prepack", TORCH_FN(linear_prepack));
+}
+
+TORCH_LIBRARY_IMPL(vulkan_prepack, Vulkan, m) {
+  m.impl("conv2d_clamp_run", TORCH_FN(conv2d_clamp_run));
+  m.impl("linear_run", TORCH_FN(linear_run));
+}
+
+} // namespace
+} // namespace ops
+} // namespace vulkan
+} // namespace native
+} // namespace at
+
+#endif /* USE_VULKAN */
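The `def_pickle` hooks above are what make the op contexts serializable: `__getstate__` returns the original (unpacked) arguments and `__setstate__` re-runs prepacking on load. A sketch of the round-trip for the linear context, with hypothetical tensors (illustrative only, not part of the patch):

```cpp
#include <ATen/ATen.h>
#include <ATen/native/vulkan/ops/Mm.h>

using namespace at::native::vulkan::ops;

void pickle_roundtrip_example() {
  auto ctx = linear_prepack(at::rand({16, 8}), at::rand({8}));

  // __getstate__: recover the unpacked weight and bias.
  LinearOpContext::State state = ctx->unpack();

  // __setstate__: rebuild the packed context from the saved state.
  auto restored = linear_prepack(
      std::move(std::get<0>(state)), std::move(std::get<1>(state)));
  (void)restored;
}
```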
diff --git a/torch/csrc/jit/passes/vulkan_rewrite.cpp b/torch/csrc/jit/passes/vulkan_rewrite.cpp
index 0b4e90f3e1aa..f8f8794eccc1 100644
--- a/torch/csrc/jit/passes/vulkan_rewrite.cpp
+++ b/torch/csrc/jit/passes/vulkan_rewrite.cpp
@@ -22,6 +22,51 @@ namespace jit {

 namespace {

+void insertPrePackedLinearOp(std::shared_ptr<Graph>& graph) {
+  // fuse decomposed linear into aten::linear
+  FuseLinear(graph);
+
+  std::string linear_before_inline = R"(
+    graph(%linear, %input, %weight, %bias):
+        %r = prim::CallFunction(%linear, %input, %weight, %bias)
+        return (%r))";
+  std::string prepacked_ops_pattern_before_inline = R"(
+    graph(%linear, %input, %weight, %bias):
+        %weight_t = aten::t(%weight)
+        %packed_weight_bias = vulkan_prepack::linear_prepack(
+            %weight_t, %bias)
+        %res = vulkan_prepack::linear_run(%input, %packed_weight_bias)
+        return (%res))";
+  std::string linear_pattern = R"(
+    graph(%input, %weight, %bias):
+        %r = aten::linear(%input, %weight, %bias)
+        return (%r))";
+  std::string prepacked_ops_pattern = R"(
+    graph(%input, %weight, %bias):
+        %weight_t = aten::t(%weight)
+        %packed_weight_bias = vulkan_prepack::linear_prepack(
+            %weight_t, %bias)
+        %res = vulkan_prepack::linear_run(%input, %packed_weight_bias)
+        return (%res))";
+
+  const auto filter = [](const Match& match,
+                         const std::unordered_map<std::string, Value*>& vmap) {
+    const auto& match_vmap = match.values_map;
+    const auto linear_value = match_vmap.at(vmap.at("linear"));
+    const auto func_name = graph_rewrite_helper::getFuncName(linear_value);
+    return (func_name == "linear");
+  };
+
+  SubgraphRewriter linear_call_fn_rewriter;
+  linear_call_fn_rewriter.RegisterRewritePattern(
+      linear_before_inline, prepacked_ops_pattern_before_inline);
+  linear_call_fn_rewriter.runOnGraph(graph, filter);
+
+  SubgraphRewriter linear_rewriter;
+  linear_rewriter.RegisterRewritePattern(linear_pattern, prepacked_ops_pattern);
+  linear_rewriter.runOnGraph(graph);
+}
+
 void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
   graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);

@@ -131,6 +176,7 @@ void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
 } // namespace

 void vulkanInsertPrePackedOps(std::shared_ptr<Graph>& graph) {
+  insertPrePackedLinearOp(graph);
   insertPrePackedConv2dOp(graph);
 }

@@ -153,8 +199,10 @@ void vulkanFusePrePackedConvWithClamp(script::Module& module) {
 void vulkanFoldPrePackingOps(script::Module& m) {
   PrePackingOpsFilterFn filter_fn = [](const Node* n) -> bool {
     return (
-        n->kind() ==
-        Symbol::fromQualString("vulkan_prepack::conv2d_clamp_prepack"));
+        (n->kind() ==
+         Symbol::fromQualString("vulkan_prepack::conv2d_clamp_prepack")) ||
+        (n->kind() ==
+         Symbol::fromQualString("vulkan_prepack::linear_prepack")));
   };
   PrePackingOpsFolder(m, filter_fn, "prepack_folding");
 }
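Taken together, a scripted model is prepared for Vulkan by running the new linear rewrite alongside the existing conv2d one, then folding the prepack calls. A rough sketch of the pass pipeline (`prepareForVulkan` is a hypothetical helper, not part of the patch; error handling omitted):

```cpp
#include <torch/csrc/jit/passes/vulkan_rewrite.h>
#include <torch/script.h>

// Rewrite aten::linear / conv into vulkan_prepack:: ops, fuse clamps
// into the conv context, then fold the constant prepack calls.
torch::jit::script::Module prepareForVulkan(
    torch::jit::script::Module module) {
  auto graph = module.get_method("forward").graph();
  torch::jit::vulkanInsertPrePackedOps(graph);
  torch::jit::vulkanFusePrePackedConvWithClamp(module);
  torch::jit::vulkanFoldPrePackingOps(module);
  return module;
}
```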