Skip to content

Commit

Permalink
BUGFIX:
Browse files Browse the repository at this point in the history
Fixed a bug where the kernel did not process all of the items when batch * height * width exceeded the number of threads spawned. The kernel body is now wrapped in a for-loop so that each thread strides over the remaining items.

**Two places with room for a minor performance increase:**
1. Instead of taking `hue_delta` as a pointer into global memory, pass in the value that `hue_delta` points to, eliminating an unnecessary global memory read.
2. Make the copy performed in the `if (!AdjustHue && !AdjustSaturation && !AdjustV)` branch use coalesced global memory accesses.
  • Loading branch information
ThisIsIsaac committed May 14, 2019
1 parent 1fb966e commit c73a146
Showing 1 changed file with 33 additions and 37 deletions.
70 changes: 33 additions & 37 deletions tensorflow/core/kernels/adjust_hsv_gpu.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,45 +99,41 @@ __global__ void adjust_hsv_nhwc(const int64 number_elements,
const float* const value_scale) {
// multiply by 3 since we're dealing with contiguous RGB bytes for each pixel
// (NHWC)
const int64 idx = (blockDim.x * blockIdx.x + threadIdx.x) * 3;
// bounds check
if (idx > number_elements - 1) {
return;
}
if (!AdjustHue && !AdjustSaturation && !AdjustV) {
output[idx] = input[idx];
output[idx + 1] = input[idx + 1];
output[idx + 2] = input[idx + 2];
return;
}
const HsvTuple hsv = rgb2hsv_cuda(static_cast<float>(input[idx]),
static_cast<float>(input[idx + 1]),
static_cast<float>(input[idx + 2]));
float new_h = hsv.h;
float new_s = hsv.s;
float new_v = hsv.v;
// hue adjustment
if (AdjustHue) {
const float delta = *hue_delta;
new_h = fmodf(hsv.h + delta, 1.0f);
if (new_h < 0.0f) {
new_h = fmodf(1.0f + new_h, 1.0f);
for (int64 idx = (blockDim.x * blockIdx.x + threadIdx.x) * 3; idx < number_elements; idx+=blockDim.x*gridDim.x*3){
if (!AdjustHue && !AdjustSaturation && !AdjustV) {
output[idx] = input[idx];
output[idx + 1] = input[idx + 1];
output[idx + 2] = input[idx + 2];
}
const HsvTuple hsv = rgb2hsv_cuda(static_cast<float>(input[idx]),
static_cast<float>(input[idx + 1]),
static_cast<float>(input[idx + 2]));
float new_h = hsv.h;
float new_s = hsv.s;
float new_v = hsv.v;
// hue adjustment
if (AdjustHue) {
const float delta = *hue_delta;
new_h = fmodf(hsv.h + delta, 1.0f);
if (new_h < 0.0f) {
new_h = fmodf(1.0f + new_h, 1.0f);
}
}
// saturation adjustment
if (AdjustSaturation && saturation_scale != nullptr) {
const float scale = *saturation_scale;
new_s = fminf(1.0f, fmaxf(0.0f, hsv.s * scale));
}
// value adjustment
if (AdjustV && value_scale != nullptr) {
const float scale = *value_scale;
new_v = hsv.v * scale;
}
const RgbTuple rgb = hsv2rgb_cuda(new_h, new_s, new_v);
output[idx] = static_cast<T>(rgb.r);
output[idx + 1] = static_cast<T>(rgb.g);
output[idx + 2] = static_cast<T>(rgb.b);
}
// saturation adjustment
if (AdjustSaturation && saturation_scale != nullptr) {
const float scale = *saturation_scale;
new_s = fminf(1.0f, fmaxf(0.0f, hsv.s * scale));
}
// value adjustment
if (AdjustV && value_scale != nullptr) {
const float scale = *value_scale;
new_v = hsv.v * scale;
}
const RgbTuple rgb = hsv2rgb_cuda(new_h, new_s, new_v);
output[idx] = static_cast<T>(rgb.r);
output[idx + 1] = static_cast<T>(rgb.g);
output[idx + 2] = static_cast<T>(rgb.b);
}

} // namespace internal
Expand Down

0 comments on commit c73a146

Please sign in to comment.