Skip to content

Commit

Permalink
[OPENCL] Fix reshape arguments register. test=develop (PaddlePaddle#5060
Browse files Browse the repository at this point in the history
)
  • Loading branch information
zhaoyang-star committed Dec 26, 2020
1 parent 6dcbd4a commit fdf8c72
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions lite/kernels/opencl/reshape_image_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,10 @@ REGISTER_LITE_KERNEL(reshape,
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.BindInput("ShapeTensor", {LiteType::GetTensorTy(TARGET(kOpenCL))})
.BindInput("Shape", {LiteType::GetTensorTy(TARGET(kOpenCL))})
.BindInput("ShapeTensor",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindInput("Shape",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
Expand All @@ -217,9 +219,12 @@ REGISTER_LITE_KERNEL(reshape2,
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.BindInput("ShapeTensor", {LiteType::GetTensorTy(TARGET(kOpenCL))})
.BindInput("Shape", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("ShapeTensor",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindInput("Shape",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindOutput("XShape",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
Expand All @@ -236,7 +241,8 @@ REGISTER_LITE_KERNEL(flatten,
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.BindInput("Shape", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Shape",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
Expand All @@ -253,8 +259,10 @@ REGISTER_LITE_KERNEL(flatten2,
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.BindInput("Shape", {LiteType::GetTensorTy(TARGET(kOpenCL))})
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Shape",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindOutput("XShape",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
Expand Down

0 comments on commit fdf8c72

Please sign in to comment.