Skip to content

Commit

Permalink
Add matmul_v2 kernel in pten (PaddlePaddle#36844)
Browse files Browse the repository at this point in the history
* initial tensor design & sign kernel demo

* add move constructor for meta & add lodtensor

* add dirs & sign xpu kernel

* add mean cpu&cuda kernel impl

* move sign & mean xpu & npu kernel

* add selected_rows basic impl

* refactor design, BaseTensor to DenseTensor, etc.

* add scale mkldnn kernel

* polish xpu & npu impl details

* fix mkldnn reuse compile failed

* change tensor operation lib name

* rename util filename

* add more comments

* change TensorImplInterface to TensorInterface

* add kernel key and factory

* remove MKLDNNTensorMeta, add MKLDNNDenseTensor

* change XXDeviceContext to XXContext

* add base kernel registrar utils & test on sign

* replace boost::any by paddle::any

* fix several ci failed

* fix npu compile error

* add ordered map util

* fix multiple ordered_map compile errors

* move dev into include dir

* support sign op in static op run

* fix static op run error

* fix new executor compile failed

* add dygraph branch & remove sign_op.h

* fix test_infer_no_need_buffer_slots

* fix rocm compile link error

* fix unitybuild error & clear glog

* fix npu compile failed

* skip quant trans test

* fix part windows compile problem

* fix xpu enforce error

* fix inference test failed

* remove ordered_map to solve quant failed

* fix part of rcom compile faild

* add more register kernels

* revert scale kernel temporarily

* fix code format error

* add new kernel registrar marco

* rename top to tcmpt

* revert xpu, npu, mkldnn impl & remove op def

* add kernel args parse functor to auto parse args

* revert some change & add scale kernels

* add op proto in dygraph kernelcontext building

* polish kernel dispatch logic & nameing rule

* fix scale kernel match error

* fix scale test failed

* add mean API and unittest

* test mean api success

* add branch to solve compiled error

* skip clang format error

* add mean skip rule in op_library

* add dot kernel, api and unittest (PaddlePaddle#6)

* remove old kernel and add symbol link

* fix dot compiled failed

* add merco for module declare

* fix npu and xpu compile error

* revert sign, mean, scale, dot kernel removing

* add comment for keeping old kernel impl

* fix mutable_data error

* fix bfloat16 conflit

* fix inference undef error

* adapt to msvc compile rules

* polish comment for template inst

* add cmake template instantiation for win

* fix backend to place device id bug

* fix ifdef error

* Op2functor (PaddlePaddle#7)

* add kernel args maker class

* make args maker non-const

* remove debug log

* modify codes by review options

* split constructPrKernelContext function

* fix output name bug

* fix test_mean_op test_sign_op failed

* fill_any_like kernel refactor (PaddlePaddle#10)

* fill_any_like kernel refactor

* remove useless code of full_like c++ api

* skip dtype for fill_any_like

* add attrs for kernel key constrcut

* add use_pt_kernel Flags to control whether to use pt kernel (PaddlePaddle#13)

* add use_pt_kernel Flags to control whether to use pt kernel

* change the default value to true for cheking pt kernels

* fix mutable_data cuda place error

* move high level apis into hapi

* remove selectedrows adapting temporarily

* Support Scalar in Tensor Compute Library (PaddlePaddle#14)

* fill_any_like kernel refactor

* remove useless code of full_like c++ api

* Support Scalar in Tensor Compute Library

* add scalar in dygraph and static graph mode

* keep the basic type for attr, instead of using scalar for all

* merge the code

* remove mkldnn tensor & polish details

* use flat_hash_map and small_vector in kernel factory

* Refactor flatten kernel (PaddlePaddle#12)

* refactor flatten kernel

* update infershape function

* fix compile bugs

* fix bugs when merge

* fix compiler bugs

* fix bugs when run test_flatten_api

* fix bugs when run test

* Revert "use flat_hash_map and small_vector in kernel factory"

This reverts commit 2309149.

* Move cpu, cuda and other device code into kernels (PaddlePaddle#15)

* fill_any_like kernel refactor

* remove useless code of full_like c++ api

* Support Scalar in Tensor Compute Library

* add scalar in dygraph and static graph mode

* keep the basic type for attr, instead of using scalar for all

* merge the code

* start refactor matmul

* move cpu, cuda and other device modules into kernels

* merge code

* polish code in operator.cc

* Perfect unitests (PaddlePaddle#16)

* perfect unittest

* update license

* replace with flat_hash_map, small_vector (PaddlePaddle#19)

* fix small_vector build error on windows platform

* replace with flat_hash_map, small_vector

* remove todo

* Perfect unitests (PaddlePaddle#20)

* perfect unittest

* update license

* fix bug when run tcmpt_utils_test

* refactor execution adapting impl

* fix insert conflit

* Fix CI bug of test_yolov3 (PaddlePaddle#21)

* fill_any_like kernel refactor

* remove useless code of full_like c++ api

* Support Scalar in Tensor Compute Library

* add scalar in dygraph and static graph mode

* keep the basic type for attr, instead of using scalar for all

* merge the code

* start refactor matmul

* move cpu, cuda and other device modules into kernels

* merge code

* polish code in operator.cc

* Fix CI bug of test_yolov3

* add the tensor base class, test=develop (PaddlePaddle#17)

* update the tensor base class, test=develop

* remove two funcs, test=develop

* update the error msg, test=develop

Co-authored-by: Chen Weihang <chenweihang@baidu.com>

* [no-verify] commit backend and tensor signature changes

* Rename tcmpt to pten (PaddlePaddle#23)

* rename tcmpt to pten

* update omitted files for rename to pten

* update omitted file for rename to pten

* remove k of all enum var

* remove kernel_instantiate (PaddlePaddle#26)

* remove symbols and spatial_tensor

* change common to functions

* readd share tensor impl methods

* add a candidate dense tensor class, test=develop (PaddlePaddle#28)

* change all Pt to Pten

* resolve conflit with xiaowei

* Op2functor opt1 (PaddlePaddle#27)

* replace to small vector and change to const &

* add std::move

Co-authored-by: Chen Weihang <chenweihang@baidu.com>

* polish kernel factory and kernel registry

* fix operator test error msg mismatch

* remove tensor signature and backend set member

* move scalar and polish enforce

* revert dtype layout change to fix error

* fix enum operator override error

* add several base unittests

* add pten utils tests

* polish some details

* Dev/op2func refactor 3 (PaddlePaddle#30)

* add a candidate dense tensor class, test=develop

* remove TensorBase::backend(), test=develop

* remove some ops, test=develop

* cherry-pick the pr of tensor meta, test=develop

* moves the dense tensor and some ops, test=develop

* update the linalg operator, test=develop

* update other operators, test=develop

* fix errors, test=develop

* fix bugs, test=develop

* try to resolve the problem of windows ci, test=develop

* updates codes, test=develop

* fix the tensor_utils.cc, test=develop

* modify the dense tensor, test=develop

* fix the data type, test=develop

Co-authored-by: shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>

* polish some details

* polish kernel signature details

* fix a bug about offsets of the tensor, test=develop (PaddlePaddle#31)

Co-authored-by: shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>

* add matmul kernel in pten

* add unittest for new matmul_v2 kernel

* fix bug of CI compile

* fix bug of CI compile

* merge conflict

* remove useless file

Co-authored-by: Chen Weihang <chenweihang@baidu.com>
Co-authored-by: chentianyu03 <ctychentianyu@gmail.com>
Co-authored-by: YuanRisheng <yuanrisheng@baidu.com>
Co-authored-by: 石晓伟 <39303645+Shixiaowei02@users.noreply.github.com>
  • Loading branch information
5 people authored and piotrekobi committed Nov 3, 2021
1 parent 5d86ed9 commit 356a64e
Show file tree
Hide file tree
Showing 16 changed files with 866 additions and 12 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/framework/op_call_stack.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/unused_var_check.h"
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/imperative/prepared_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/imperative/infer_shape_context.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/utils/small_vector.h"
Expand Down
25 changes: 16 additions & 9 deletions paddle/fluid/operators/matmul_v2_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ limitations under the License. */
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"

// only can include the headers in paddle/pten/api dirs
#include "paddle/pten/api/include/core.h"
#include "paddle/pten/api/include/linalg.h"
#include "paddle/pten/hapi/lib/utils/tensor_utils.h"

#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#endif
Expand Down Expand Up @@ -380,15 +385,17 @@ class MatMulV2Kernel : public framework::OpKernel<T> {
auto* Out = ctx.Output<Tensor>("Out");
bool trans_x = ctx.Attr<bool>("trans_x");
bool trans_y = ctx.Attr<bool>("trans_y");
PADDLE_ENFORCE_NE(framework::product(X->dims()), 0,
platform::errors::InvalidArgument(
"The Input(X) dims size must not be equal 0,"
" but reviced dims size is 0. "));
PADDLE_ENFORCE_NE(framework::product(Y->dims()), 0,
platform::errors::InvalidArgument(
"The Input(Y) dims size must not be equal 0,"
" but reviced dims size is 0. "));
MatMulFunction<DeviceContext, T>(X, Y, Out, trans_x, trans_y, ctx);

auto& dev_ctx = ctx.device_context<DeviceContext>();
Out->mutable_data<T>(X->place());

auto pt_x = paddle::experimental::MakePtenDenseTensor(*X);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*Y);
auto pt_out = paddle::experimental::MakePtenDenseTensor(*Out);

// call new kernel
pten::Matmul<T>(dev_ctx, *pt_x.get(), *pt_y.get(), trans_x, trans_y,
pt_out.get());
}
};

Expand Down
5 changes: 5 additions & 0 deletions paddle/pten/hapi/include/linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,10 @@ namespace experimental {

Tensor dot(const Tensor& x, const Tensor& y);

Tensor matmul(const Tensor& x,
const Tensor& y,
bool transpose_x,
bool transpose_y);

} // namespace experimental
} // namespace paddle
43 changes: 42 additions & 1 deletion paddle/pten/hapi/lib/linalg.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ limitations under the License. */
#include "paddle/pten/core/kernel_context.h"
#include "paddle/pten/hapi/lib/kernel_dispatch.h"
#include "paddle/pten/hapi/lib/utils/allocator.h"
#include "paddle/pten/infershape/binary.h"

namespace paddle {
namespace experimental {
Expand Down Expand Up @@ -65,5 +64,47 @@ Tensor dot(const Tensor& x, const Tensor& y) {
return out;
}

Tensor matmul(const Tensor& x,
const Tensor& y,
bool transpose_x,
bool transpose_y) {
// 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x, y);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
"matmul_v2", kernel_key);

// 2. Get Device Context
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
auto kernel_context = pten::KernelContext(*dev_ctx);

// 3. Auto data transform
auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
auto dense_y = std::dynamic_pointer_cast<pten::DenseTensor>(y.impl());
kernel_context.EmplaceBackInput(dense_x);
kernel_context.EmplaceBackInput(dense_y);
kernel_context.EmplaceBackAttr(transpose_x);
kernel_context.EmplaceBackAttr(transpose_y);
// TODO(chenweihang): add transform impl

// 4. InferShape
auto out_meta = MatmulInferShape(
dense_x->meta(), dense_y->meta(), transpose_x, transpose_y);

// 5. Prepare outputs
const auto allocator = std::make_shared<DefaultAllocator>(
pten::TransToFluidPlace(kernel_key.backend()));
auto dense_out = std::make_shared<pten::DenseTensor>(allocator, out_meta);
kernel_context.EmplaceBackOutput(dense_out);

Tensor out;
out.set_impl(dense_out);

// 6. Call kernel
kernel(&kernel_context);

return out;
}

} // namespace experimental
} // namespace paddle
70 changes: 70 additions & 0 deletions paddle/pten/infershape/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,74 @@ DenseTensorMeta DotInferShape(const DenseTensorMeta& x_meta,
return return_meta;
}

DenseTensorMeta MatmulInferShape(const DenseTensorMeta& x_meta,
const DenseTensorMeta& y_meta,
bool trans_x,
bool trans_y) {
std::vector<int64_t> dims_x = paddle::framework::vectorize(x_meta.dims);
std::vector<int64_t> dims_y = paddle::framework::vectorize(y_meta.dims);
auto ndims_x = dims_x.size();
auto ndims_y = dims_y.size();
PADDLE_ENFORCE_GT(ndims_x,
0,
paddle::platform::errors::InvalidArgument(
"The Input(x) dims size must be greater than 0,"
" but reviced dims size is 0. "));
PADDLE_ENFORCE_GT(ndims_y,
0,
paddle::platform::errors::InvalidArgument(
"The Input(y) dims size must be greater than 0,"
" but reviced dims size is 0. "));

bool x_broadcasted = false, y_broadcasted = false;
if (ndims_x == 1) {
dims_x.insert(dims_x.begin(), 1);
ndims_x = 2;
x_broadcasted = true;
}

if (ndims_y == 1) {
dims_y.push_back(1);
ndims_y = 2;
y_broadcasted = true;
}

size_t M, N;
if (trans_x) {
M = dims_x[ndims_x - 1];
} else {
M = dims_x[ndims_x - 2];
}
if (trans_y) {
N = dims_y[ndims_y - 2];
} else {
N = dims_y[ndims_y - 1];
}

std::vector<int64_t> new_dims;
if (ndims_x > ndims_y) {
new_dims.assign(dims_x.begin(), dims_x.end() - 2);
} else if (ndims_x < ndims_y) {
new_dims.assign(dims_y.begin(), dims_y.end() - 2);
} else {
new_dims.reserve(ndims_x);
for (size_t i = 0; i < ndims_x - 2; ++i) {
new_dims.push_back(std::max(dims_x[i], dims_y[i]));
}
}
if (!x_broadcasted) {
new_dims.push_back(M);
}
if (!y_broadcasted) {
new_dims.push_back(N);
}
if (x_broadcasted && y_broadcasted) {
new_dims.push_back(1);
}

auto ddim_out = paddle::framework::make_ddim(new_dims);

return {x_meta.type, ddim_out, x_meta.layout};
}

} // namespace pten
5 changes: 5 additions & 0 deletions paddle/pten/infershape/binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,9 @@ namespace pten {
DenseTensorMeta DotInferShape(const DenseTensorMeta& x_meta,
const DenseTensorMeta& y_meta);

DenseTensorMeta MatmulInferShape(const DenseTensorMeta& x_meta,
const DenseTensorMeta& y_meta,
bool trans_x,
bool trans_y);

} // namespace pten
27 changes: 27 additions & 0 deletions paddle/pten/kernels/cpu/linalg.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/complex.h"

#include "paddle/pten/kernels/functions/math/matmul_func.h"

namespace pten {

template <typename T>
Expand All @@ -45,6 +47,27 @@ void Dot(const CPUContext& dev_ctx,
}
}

template <typename T>
void Matmul(const CPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
bool transpose_x,
bool transpose_y,
DenseTensor* out) {
PADDLE_ENFORCE_NE(paddle::framework::product(x.dims()),
0,
paddle::platform::errors::InvalidArgument(
"The Input(X) dims size must not be equal 0,"
" but reviced dims size is 0. "));
PADDLE_ENFORCE_NE(paddle::framework::product(y.dims()),
0,
paddle::platform::errors::InvalidArgument(
"The Input(Y) dims size must not be equal 0,"
" but reviced dims size is 0. "));
math::MatMulFunction<CPUContext, T>(
dev_ctx, x, y, out, transpose_x, transpose_y);
}

} // namespace pten

PT_REGISTER_MODULE(LinalgCPU);
Expand All @@ -62,3 +85,7 @@ PT_REGISTER_KERNEL("dot",
int64_t,
complex64,
complex128) {}

PT_REGISTER_KERNEL(
"matmul_v2", CPU, ANY, pten::Matmul, float, double, complex64, complex128) {
}
2 changes: 1 addition & 1 deletion paddle/pten/kernels/cpu/linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ void Dot(const CPUContext& dev_ctx,
DenseTensor* out);

template <typename T>
void matmul(const CPUContext& dev_ctx,
void Matmul(const CPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
bool transpose_x,
Expand Down
33 changes: 33 additions & 0 deletions paddle/pten/kernels/cuda/linalg.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/functions/eigen/dot.h"
#include "paddle/pten/kernels/functions/math/matmul_func.h"

// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/complex.h"
Expand All @@ -30,10 +31,32 @@ void Dot(const CUDAContext& dev_ctx,
eigen::Dot<CUDAContext, T>(dev_ctx, x, y, out);
}

template <typename T>
void Matmul(const CUDAContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
bool transpose_x,
bool transpose_y,
DenseTensor* out) {
PADDLE_ENFORCE_NE(paddle::framework::product(x.dims()),
0,
paddle::platform::errors::InvalidArgument(
"The Input(X) dims size must not be equal 0,"
" but reviced dims size is 0. "));
PADDLE_ENFORCE_NE(paddle::framework::product(y.dims()),
0,
paddle::platform::errors::InvalidArgument(
"The Input(Y) dims size must not be equal 0,"
" but reviced dims size is 0. "));
math::MatMulFunction<CUDAContext, T>(
dev_ctx, x, y, out, transpose_x, transpose_y);
}

} // namespace pten

PT_REGISTER_MODULE(LinalgCUDA);

using float16 = paddle::platform::float16;
using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>;

Expand All @@ -47,3 +70,13 @@ PT_REGISTER_KERNEL("dot",
int64_t,
complex64,
complex128) {}

PT_REGISTER_KERNEL("matmul_v2",
CUDA,
ANY,
pten::Matmul,
float,
double,
float16,
complex64,
complex128) {}
8 changes: 8 additions & 0 deletions paddle/pten/kernels/cuda/linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ void Dot(const CUDAContext& dev_ctx,
const DenseTensor& y,
DenseTensor* out);

template <typename T>
void Matmul(const CUDAContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
bool transpose_x,
bool transpose_y,
DenseTensor* out);

} // namespace pten

#endif
5 changes: 5 additions & 0 deletions paddle/pten/kernels/cuda/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ limitations under the License. */

#pragma once

// CUDA and HIP use same api
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)

#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"

Expand All @@ -26,3 +29,5 @@ using CUDAContext = paddle::platform::CUDADeviceContext;
void Copy(const CUDAContext& dev_ctx, const DenseTensor& src, DenseTensor* dst);

} // namespace pten

#endif
Loading

0 comments on commit 356a64e

Please sign in to comment.