Implement LocalResponseNormalization{,Grad}Inst (#3814)
Summary:
Implement LocalResponseNormalization{,Grad}Inst for OpenCL.

Documentation:

Trivial implementation: copied the interpreter version and converted the outer loops to work-items.
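
As a rough illustration of that conversion (a sketch only, not the interpreter's actual code; normalizeChannelsAt is a hypothetical stand-in for the per-position LRN body):

    // Interpreter-style: plain nested loops over the batch and spatial dims.
    for (unsigned n = 0; n < dim.n; n++)
      for (unsigned h = 0; h < dim.h; h++)
        for (unsigned w = 0; w < dim.w; w++)
          normalizeChannelsAt(n, h, w); // the per-channel loop stays inside

    // OpenCL version: the three outer loops become a 3-D NDRange of size
    // {dim.n, dim.h, dim.w}, so each work-item handles one (n, h, w)
    // position and keeps only the per-channel loop, as in kernels.cl below.
    unsigned n = get_global_id(0);
    unsigned h = get_global_id(1);
    unsigned w = get_global_id(2);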

Fixes #3802
Pull Request resolved: #3814

Test Plan: Passes 'ninja check'.

Differential Revision: D19165421

Pulled By: jfix71

fbshipit-source-id: c3aa479c523a714152a28337f9eae6d50737ef44
pjaaskel authored and facebook-github-bot committed Dec 19, 2019
1 parent ba494df commit de133ee
Showing 4 changed files with 136 additions and 4 deletions.
2 changes: 1 addition & 1 deletion examples/training/resnet50/main.cpp
@@ -120,7 +120,7 @@ using FloatIndexPair = std::pair<float, size_t>;
/// (topKCount) [float, index] pairs, i.e. the pairs with the highest floats.
template <typename ElemTy>
static std::vector<FloatIndexPair> getTopKPairs(Handle<ElemTy> H,
size_t topKCount) {
dim_t topKCount) {
DCHECK_LE(topKCount, H.size()) << "Function requires k < number of labels.";
DCHECK_EQ(H.dims().size(), 1) << "H must be a Handle of a 1d Tensor.";

44 changes: 43 additions & 1 deletion lib/Backends/OpenCL/OpenCL.cpp
@@ -161,7 +161,7 @@ cl_kernel OpenCLFunction::createKernel(const std::string &name,
DCHECK(program) << "program cannot be null.";
cl_int err = CL_SUCCESS;
cl_kernel kernel = clCreateKernel(program, name.c_str(), &err);
CHECK(kernel) << "clCreateKernel Failed.";
CHECK(kernel) << "clCreateKernel Failed for " << name;
CHECK_EQ(err, CL_SUCCESS) << "clCreateKernel Failed.";
return kernel;
}
@@ -963,6 +963,46 @@ Error OpenCLFunction::execute(ExecutionContext *context) {
continue;
}

if (auto *LRN = dyn_cast<LocalResponseNormalizationGradInst>(&I)) {
cl_kernel kernel = createKernel(kernelName, program);
setKernelArg(kernel, 0, deviceBuffer);

size_t numArgs = setKernelArgsForBuffers(kernel, I, 1, runtimeBundle_);
ShapeNHWC dim(LRN->getDest()->getType()->dims());

uint32_t halfWindowSize = LRN->getHalfWindowSize();
uint32_t windowSize = 2 * halfWindowSize + 1;
setKernelArg(kernel, ++numArgs, dim);
setKernelArg(kernel, ++numArgs, halfWindowSize);
setKernelArg(kernel, ++numArgs, LRN->getK());
setKernelArg(kernel, ++numArgs, LRN->getBeta());
setKernelArg(kernel, ++numArgs, LRN->getAlpha() / windowSize);

enqueueKernel(I.getName(), commands, kernel, deviceId,
{dim.n, dim.h, dim.w}, kernelLaunches);
continue;
}

if (auto *LRN = dyn_cast<LocalResponseNormalizationInst>(&I)) {
cl_kernel kernel = createKernel(kernelName, program);
setKernelArg(kernel, 0, deviceBuffer);

size_t numArgs = setKernelArgsForBuffers(kernel, I, 1, runtimeBundle_);
ShapeNHWC dim(LRN->getDest()->getType()->dims());

uint32_t halfWindowSize = LRN->getHalfWindowSize();
uint32_t windowSize = 2 * halfWindowSize + 1;
setKernelArg(kernel, ++numArgs, dim);
setKernelArg(kernel, ++numArgs, halfWindowSize);
setKernelArg(kernel, ++numArgs, LRN->getK());
setKernelArg(kernel, ++numArgs, LRN->getBeta());
setKernelArg(kernel, ++numArgs, LRN->getAlpha() / windowSize);

enqueueKernel(I.getName(), commands, kernel, deviceId,
{dim.n, dim.h, dim.w}, kernelLaunches);
continue;
}

if (auto *CC = dyn_cast<ConvolutionInst>(&I)) {
if (CC->getLayout() == NCHW) {
executeNCHWConvolution(CC, context, clBindings);
@@ -1540,6 +1580,8 @@ bool OCLBackend::isOpSupported(const NodeInfo &NI) const {
{ElemKind::FloatTy, ElemKind::Int8QTy, IndexElemKind});

case Kinded::Kind::PowNodeKind:
case Kinded::Kind::LocalResponseNormalizationNodeKind:
case Kinded::Kind::LocalResponseNormalizationGradNodeKind:
case Kinded::Kind::BatchedReduceAddNodeKind:
case Kinded::Kind::TanhNodeKind:
case Kinded::Kind::SigmoidNodeKind:
92 changes: 92 additions & 0 deletions lib/Backends/OpenCL/kernels.cl
@@ -895,6 +895,98 @@ __kernel void matmul_i8W(__global void *mem, cl_uint32_t dest, cl_uint32_t lhs,
destScaleParams.scale);
}

//__attribute__((reqd_work_group_size(1, 1, 1)))
__kernel void localresponsenormalizationW(__global void *mem, unsigned dest,
unsigned src, unsigned scaleC,
ShapeNHWC dim,
unsigned halfWindowSize, float k,
float beta, float normedAlpha) {

global float *outW = (global float *)&mem[dest];
global float *inW = (global float *)&mem[src];
global float *scaleCache = (global float *)&mem[scaleC];
unsigned n = get_global_id(0);
unsigned h = get_global_id(1);
unsigned w = get_global_id(2);
// For every channel:
for (unsigned c = 0; c < dim.c; c++) {
float squareSum = 0.0;
for (unsigned i = (c >= halfWindowSize ? c - halfWindowSize : 0);
i <= min(c + halfWindowSize, (unsigned)dim.c - 1); i++) {
float val = inW[getNHWC(dim, n, h, w, i)];
squareSum += val * val;
}

float scale = k + normedAlpha * squareSum;

// This will be used to accelerate the backward pass.
scaleCache[getNHWC(dim, n, h, w, c)] = scale;

float normFactor = pow(scale, -beta);
outW[getNHWC(dim, n, h, w, c)] = inW[getNHWC(dim, n, h, w, c)] * normFactor;
}
}

__kernel void localresponsenormalizationgradW(__global void *mem, unsigned dest,
unsigned src, unsigned scaleC,
unsigned destGrad,
unsigned srcGrad, ShapeNHWC dim,
unsigned halfWindowSize, float k,
float beta, float normedAlpha) {

global float *outW = (global float *)&mem[dest];
global float *outG = (global float *)&mem[destGrad];
global float *inW = (global float *)&mem[src];
global float *inG = (global float *)&mem[srcGrad];
global float *scaleCache = (global float *)&mem[scaleC];

unsigned n = get_global_id(0);
unsigned h = get_global_id(1);
unsigned w = get_global_id(2);
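// Running sum of outG * (outW / scale) over the channels in the current
// normalization window; the loop below updates it incrementally (subtracting
// the channel that leaves the window and adding the one that enters it)
// instead of recomputing the whole window sum for every channel.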

float sum = 0.0;

// Compute sum for first channel.
for (unsigned c = 0; c <= halfWindowSize && c < dim.c; c++) {
float outw = outW[getNHWC(dim, n, h, w, c)];
float scale = scaleCache[getNHWC(dim, n, h, w, c)];
float outg = outG[getNHWC(dim, n, h, w, c)];
sum += (outg * (outw / scale));
}

// For every channel:
for (unsigned c = 0; c < dim.c; c++) {
float outg = outG[getNHWC(dim, n, h, w, c)];
float scale = scaleCache[getNHWC(dim, n, h, w, c)];
float inw = inW[getNHWC(dim, n, h, w, c)];

inG[getNHWC(dim, n, h, w, c)] =
outg * pow(scale, -beta) - 2 * normedAlpha * beta * inw * sum;

// Modify sum for next channel.
unsigned subIndex = c - halfWindowSize;
unsigned addIndex = c + halfWindowSize + 1;

if (c >= halfWindowSize) {
float outw = outW[getNHWC(dim, n, h, w, subIndex)];
float scale = scaleCache[getNHWC(dim, n, h, w, subIndex)];
float outg = outG[getNHWC(dim, n, h, w, subIndex)];

// Subtract "rear" end of this window.
sum -= (outg * (outw / scale));
}

if (addIndex < dim.c) {
float outw = outW[getNHWC(dim, n, h, w, addIndex)];
float scale = scaleCache[getNHWC(dim, n, h, w, addIndex)];
float outg = outG[getNHWC(dim, n, h, w, addIndex)];

// Add "front" end of next window.
sum += (outg * (outw / scale));
}
}
}

__kernel void softmaxK(__global float *dest, __global float *src,
__global float *e_cache, cl_uint32_t sliceSize) {
size_t i = get_global_id(0);
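For reference, reading the arithmetic off the two new kernels above: per (n, h, w) position and channel c, with windowSize = 2 * halfWindowSize + 1 and normedAlpha = alpha / windowSize (the value the host code in OpenCL.cpp passes), the forward kernel computes

    scale_c = k + normedAlpha * sum(in_i^2 for i in [c - halfWindowSize, c + halfWindowSize], clamped to valid channels)
    out_c   = in_c * scale_c^(-beta)

and the gradient kernel computes

    inGrad_c = outGrad_c * scale_c^(-beta) - 2 * normedAlpha * beta * in_c * sum(outGrad_i * out_i / scale_i over the same window)

with the window sum maintained incrementally in the kernel's running 'sum' variable.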
2 changes: 0 additions & 2 deletions lib/Backends/OpenCL/tests/OpenCLBackendCorrectnessTest.cpp
@@ -20,8 +20,6 @@ using namespace glow;
std::set<std::string> glow::backendTestBlacklist = {
// Requires the CPU target due to the use of MockCPUBackend.
"dataParallelStackingTest/0",
"localResponseNormalizationTest/0",
"localResponseNormalizationGradTest/0",
"AvgPoolGradTest/0",
"intLookupTable/0",
};
