From ef071c989511fbbd084f25a1814be3213d98dc1a Mon Sep 17 00:00:00 2001 From: Chi Zhang Date: Fri, 20 Jul 2018 16:48:29 -0700 Subject: [PATCH] [quantization][opencl]: Support quantization for SplatInst --- lib/Backends/OpenCL/OpenCL.cpp | 11 ++++++++++- lib/Backends/OpenCL/OpenCL.h | 1 + lib/Backends/OpenCL/kernels.cl | 1 + 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/lib/Backends/OpenCL/OpenCL.cpp b/lib/Backends/OpenCL/OpenCL.cpp index 6b01f4d48b..47ae580390 100644 --- a/lib/Backends/OpenCL/OpenCL.cpp +++ b/lib/Backends/OpenCL/OpenCL.cpp @@ -566,7 +566,16 @@ void OpenCLFunction::execute() { if (auto *SI = dyn_cast(&I)) { // Pass the splat as a parameter. - setKernelArg(kernel, ++numArgs, SI->getValue()); + if (!isQuantized) { + setKernelArg(kernel, ++numArgs, SI->getValue()); + } else { + auto *destTy = SI->getDest()->getType(); + TensorQuantizationParams destQ{destTy->getScale(), + destTy->getOffset()}; + float val = SI->getValue(); + int8_t int8Val = quantization::quantize(val, destQ); + setKernelArg(kernel, ++numArgs, int8Val); + } } else if (auto *EPI = dyn_cast(&I)) { // Pass the exp as a parameter. setKernelArg(kernel, ++numArgs, EPI->getExp()); diff --git a/lib/Backends/OpenCL/OpenCL.h b/lib/Backends/OpenCL/OpenCL.h index 37fef2ad24..dd5b3a3f08 100644 --- a/lib/Backends/OpenCL/OpenCL.h +++ b/lib/Backends/OpenCL/OpenCL.h @@ -186,6 +186,7 @@ class OCLBackend final : public Backend { case Kinded::Kind::MinNodeKind: case Kinded::Kind::MulNodeKind: case Kinded::Kind::QuantizeNodeKind: + case Kinded::Kind::SplatNodeKind: case Kinded::Kind::SubNodeKind: case Kinded::Kind::TransposeNodeKind: return true; diff --git a/lib/Backends/OpenCL/kernels.cl b/lib/Backends/OpenCL/kernels.cl index 070e40e19e..8760c8e625 100644 --- a/lib/Backends/OpenCL/kernels.cl +++ b/lib/Backends/OpenCL/kernels.cl @@ -467,6 +467,7 @@ DEFINE_OPENCL_TERNARY_DATA_PARALLEL_KERNEL(elementselect, float, DEFINE_OPENCL_UNARY_DATA_PARALLEL_KERNEL_WITH_IMM_OPERAND(splat, float, SRC) DEFINE_OPENCL_UNARY_DATA_PARALLEL_KERNEL_WITH_IMM_OPERAND(splat_u, ulong, SRC) +DEFINE_OPENCL_UNARY_DATA_PARALLEL_KERNEL_WITH_IMM_OPERAND(splat_i8, char, SRC) #undef DEFINE_OPENCL_BINARY_DATA_PARALLEL_KERNEL_WITH_IMM_OPERAND #undef DEFINE_OPENCL_UNARY_DATA_PARALLEL_KERNEL_WITH_IMM_OPERAND