diff --git a/lite/kernels/opencl/reshape_image_compute.cc b/lite/kernels/opencl/reshape_image_compute.cc index 7812631e97c..fc1786f193f 100644 --- a/lite/kernels/opencl/reshape_image_compute.cc +++ b/lite/kernels/opencl/reshape_image_compute.cc @@ -199,8 +199,10 @@ REGISTER_LITE_KERNEL(reshape, {LiteType::GetTensorTy(TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kImageDefault))}) - .BindInput("ShapeTensor", {LiteType::GetTensorTy(TARGET(kOpenCL))}) - .BindInput("Shape", {LiteType::GetTensorTy(TARGET(kOpenCL))}) + .BindInput("ShapeTensor", + {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))}) + .BindInput("Shape", + {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kOpenCL), PRECISION(kFP16), @@ -217,9 +219,12 @@ REGISTER_LITE_KERNEL(reshape2, {LiteType::GetTensorTy(TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kImageDefault))}) - .BindInput("ShapeTensor", {LiteType::GetTensorTy(TARGET(kOpenCL))}) - .BindInput("Shape", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("ShapeTensor", + {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))}) + .BindInput("Shape", + {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))}) + .BindOutput("XShape", + {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kOpenCL), PRECISION(kFP16), @@ -236,7 +241,8 @@ REGISTER_LITE_KERNEL(flatten, {LiteType::GetTensorTy(TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kImageDefault))}) - .BindInput("Shape", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Shape", + {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kOpenCL), PRECISION(kFP16), @@ -253,8 +259,10 @@ REGISTER_LITE_KERNEL(flatten2, {LiteType::GetTensorTy(TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kImageDefault))}) - .BindInput("Shape", {LiteType::GetTensorTy(TARGET(kOpenCL))}) - .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Shape", + {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))}) + .BindOutput("XShape", + {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kOpenCL), PRECISION(kFP16),