Skip to content

Commit

Permalink
[quantization][opencl]: Support quantization for TransposeInst
Browse files Browse the repository at this point in the history
  • Loading branch information
Chi Zhang authored and ZchiPitt committed Jul 20, 2018
1 parent e81c1eb commit 2f1c086
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 45 deletions.
1 change: 1 addition & 0 deletions lib/Backends/OpenCL/OpenCL.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ class OCLBackend final : public Backend {
case Kinded::Kind::MulNodeKind:
case Kinded::Kind::QuantizeNodeKind:
case Kinded::Kind::SubNodeKind:
case Kinded::Kind::TransposeNodeKind:
return true;
default:
return false;
Expand Down
76 changes: 31 additions & 45 deletions lib/Backends/OpenCL/kernels.cl
Original file line number Diff line number Diff line change
Expand Up @@ -1056,54 +1056,40 @@ __kernel void oclpoolavgW(__global void *mem, cl_uint32_t dest, cl_uint32_t src,
oclpoolavgK(&mem[dest], &mem[src], filterSize, stride, pads, odim, idim);
}

__kernel void transposeK(__global float *dest, __global float *src,
ShapeNHWC odim, ShapeNHWC idim, ShapeNHWC shuffle) {
size_t res[4];
size_t d0 = get_global_id(0);
size_t d1 = get_global_id(1);
res[0] = d0;
res[1] = d1;
for (size_t d2 = 0; d2 < idim.w; d2++) {
res[2] = d2;
for (size_t d3 = 0; d3 < idim.c; d3++) {
res[3] = d3;
size_t dstIdx = getNHWC(odim, res[shuffle.n], res[shuffle.h],
res[shuffle.w], res[shuffle.c]);
size_t srcIdx = getNHWC(idim, d0, d1, d2, d3);
dest[dstIdx] = src[srcIdx];
}
/// Macro to define a kernel for transpose operations. The body of
/// the kernel is auto-generated by the macro.
/// \p type the type of the tensor elements and of the return value
#define DEFINE_OPENCL_TRANSPOSE_KERNEL(name, type) \
__kernel void name##K(__global type *dest, __global type *src, \
ShapeNHWC odim, ShapeNHWC idim, \
ShapeNHWC shuffle) { \
size_t res[4]; \
size_t d0 = get_global_id(0); \
size_t d1 = get_global_id(1); \
res[0] = d0; \
res[1] = d1; \
for (size_t d2 = 0; d2 < idim.w; d2++) { \
res[2] = d2; \
for (size_t d3 = 0; d3 < idim.c; d3++) { \
res[3] = d3; \
size_t dstIdx = getNHWC(odim, res[shuffle.n], res[shuffle.h], \
res[shuffle.w], res[shuffle.c]); \
size_t srcIdx = getNHWC(idim, d0, d1, d2, d3); \
dest[dstIdx] = src[srcIdx]; \
} \
} \
} \
__kernel void name##W(__global void *mem, cl_uint32_t dest, \
cl_uint32_t src, ShapeNHWC odim, \
ShapeNHWC idim, ShapeNHWC shuffle) { \
name##K(&mem[dest], &mem[src], odim, idim, shuffle); \
}
}

__kernel void transposeW(__global void *mem, cl_uint32_t dest, cl_uint32_t src,
ShapeNHWC odim, ShapeNHWC idim, ShapeNHWC shuffle) {
transposeK(&mem[dest], &mem[src], odim, idim, shuffle);
}

__kernel void transposeK_u(__global cl_uint64_t *dest,
__global cl_uint64_t *src, ShapeNHWC odim,
ShapeNHWC idim, ShapeNHWC shuffle) {
size_t res[4];
size_t d0 = get_global_id(0);
size_t d1 = get_global_id(1);
res[0] = d0;
res[1] = d1;
for (size_t d2 = 0; d2 < idim.w; d2++) {
res[2] = d2;
for (size_t d3 = 0; d3 < idim.c; d3++) {
res[3] = d3;
size_t dstIdx = getNHWC(odim, res[shuffle.n], res[shuffle.h],
res[shuffle.w], res[shuffle.c]);
size_t srcIdx = getNHWC(idim, d0, d1, d2, d3);
dest[dstIdx] = src[srcIdx];
}
}
}
DEFINE_OPENCL_TRANSPOSE_KERNEL(transpose_i8, cl_int8_t)
DEFINE_OPENCL_TRANSPOSE_KERNEL(transpose_u, cl_uint64_t)
DEFINE_OPENCL_TRANSPOSE_KERNEL(transpose, float)

__kernel void transpose_uW(__global void *mem, cl_uint32_t dest, cl_uint32_t src,
ShapeNHWC odim, ShapeNHWC idim, ShapeNHWC shuffle) {
transposeK_u(&mem[dest], &mem[src], odim, idim, shuffle);
}
#undef DEFINE_OPENCL_TRANSPOSE_KERNEL

__kernel void inserttensorK(__global float *dest, __global float *src,
ShapeNHWC odim, ShapeNHWC idim, ShapeNHWC offset,
Expand Down
20 changes: 20 additions & 0 deletions tests/unittests/OperatorTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1346,6 +1346,26 @@ TEST_P(InterpAndCPU, QuantizedArithmeticRescaled) {
}
}

TEST_P(Operator, QuantizedTranspose) {
auto *A = mod_.createVariable(ElemKind::FloatTy, {2, 3}, "A",
VisibilityKind::Public);
auto *B = mod_.createVariable(ElemKind::FloatTy, {3, 2}, "B",
VisibilityKind::Public);
A->getPayload().getHandle() = {1, 1.2f, 0.5f, 1.3f, 2.7f, 5.8f};
A->getPayload().transpose(&B->getPayload(), {1, 0});
auto qType = mod_.uniqueType(ElemKind::Int8QTy, {2, 3}, 0.05, -138);
auto *quantizeA = F_->createQuantize("quantize", A, qType);
auto *tr = F_->createTranspose("tr", quantizeA, {1, 0});
auto *dequantize = F_->createDequantize("dequantize", tr);
auto *result = F_->createSave("ret", dequantize);
auto *fpTr = F_->createTranspose("fpTr", A, {1, 0});
auto *fpResult = F_->createSave("fpRet", fpTr);
EE_.compile(CompilationMode::Infer, F_);
EE_.run({}, {});
EXPECT_TRUE(result->getVariable()->getPayload().isEqual(B->getPayload()));
EXPECT_TRUE(fpResult->getVariable()->getPayload().isEqual(B->getPayload()));
}

TEST_P(Operator, QuantizedArithmeticUnrescaled) {
const size_t len = 100;

Expand Down

0 comments on commit 2f1c086

Please sign in to comment.