Skip to content

Commit

Permalink
Merge pull request #224 from snehaa8/sn/color_temp_tensor
Browse files Browse the repository at this point in the history
Color Temperature fixes
  • Loading branch information
r-abishek committed Jan 29, 2024
2 parents 9944e02 + 88ab397 commit 45be1f8
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 83 deletions.
75 changes: 51 additions & 24 deletions src/modules/cpu/kernel/color_temperature.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,14 @@ RppStatus color_temperature_u8_u8_host_tensor(Rpp8u *srcPtr,
}
for (; vectorLoopCount < bufferLength; vectorLoopCount++)
{
dstPtrTemp[0] = (Rpp8u) RPPPIXELCHECK(*srcPtrTempR++ + adjustmentValue);
dstPtrTemp[1] = (Rpp8u) RPPPIXELCHECK(*srcPtrTempG++);
dstPtrTemp[2] = (Rpp8u) RPPPIXELCHECK(*srcPtrTempB++ - adjustmentValue);
dstPtrTemp[0] = (Rpp8u) RPPPIXELCHECK(*srcPtrTempR + adjustmentValue);
dstPtrTemp[1] = (Rpp8u) RPPPIXELCHECK(*srcPtrTempG);
dstPtrTemp[2] = (Rpp8u) RPPPIXELCHECK(*srcPtrTempB - adjustmentValue);

dstPtrTemp += 3;
srcPtrTempR++;
srcPtrTempG++;
srcPtrTempB++;
}

srcPtrRowR += srcDescPtr->strides.hStride;
Expand Down Expand Up @@ -216,10 +219,13 @@ RppStatus color_temperature_u8_u8_host_tensor(Rpp8u *srcPtr,
}
for (; vectorLoopCount < bufferLength; vectorLoopCount++)
{
*dstPtrTempR++ = (Rpp8u) RPPPIXELCHECK(*srcPtrTempR++ + adjustmentValue);
*dstPtrTempG++ = (Rpp8u) RPPPIXELCHECK(*srcPtrTempG++);
*dstPtrTempB++ = (Rpp8u) RPPPIXELCHECK(*srcPtrTempB++ - adjustmentValue);
*dstPtrTempR++ = (Rpp8u) RPPPIXELCHECK(*srcPtrTempR + adjustmentValue);
*dstPtrTempG++ = (Rpp8u) RPPPIXELCHECK(*srcPtrTempG);
*dstPtrTempB++ = (Rpp8u) RPPPIXELCHECK(*srcPtrTempB - adjustmentValue);

srcPtrTempR++;
srcPtrTempG++;
srcPtrTempB++;
}

srcPtrRowR += srcDescPtr->strides.hStride;
Expand Down Expand Up @@ -352,11 +358,14 @@ RppStatus color_temperature_f32_f32_host_tensor(Rpp32f *srcPtr,
}
for (; vectorLoopCount < bufferLength; vectorLoopCount++)
{
dstPtrTemp[0] = RPPPIXELCHECKF32(*srcPtrTempR++ + adjustmentValue);
dstPtrTemp[1] = RPPPIXELCHECKF32(*srcPtrTempG++);
dstPtrTemp[2] = RPPPIXELCHECKF32(*srcPtrTempB++ - adjustmentValue);
dstPtrTemp[0] = RPPPIXELCHECKF32(*srcPtrTempR + adjustmentValue);
dstPtrTemp[1] = RPPPIXELCHECKF32(*srcPtrTempG);
dstPtrTemp[2] = RPPPIXELCHECKF32(*srcPtrTempB - adjustmentValue);

dstPtrTemp += 3;
srcPtrTempR++;
srcPtrTempG++;
srcPtrTempB++;
}

srcPtrRowR += srcDescPtr->strides.hStride;
Expand Down Expand Up @@ -449,9 +458,13 @@ RppStatus color_temperature_f32_f32_host_tensor(Rpp32f *srcPtr,
}
for (; vectorLoopCount < bufferLength; vectorLoopCount++)
{
*dstPtrTempR++ = RPPPIXELCHECKF32(*srcPtrTempR++ + adjustmentValue);
*dstPtrTempG++ = RPPPIXELCHECKF32(*srcPtrTempG++);
*dstPtrTempB++ = RPPPIXELCHECKF32(*srcPtrTempB++ - adjustmentValue);
*dstPtrTempR++ = RPPPIXELCHECKF32(*srcPtrTempR + adjustmentValue);
*dstPtrTempG++ = RPPPIXELCHECKF32(*srcPtrTempG);
*dstPtrTempB++ = RPPPIXELCHECKF32(*srcPtrTempB - adjustmentValue);

srcPtrTempR++;
srcPtrTempG++;
srcPtrTempB++;
}

srcPtrRowR += srcDescPtr->strides.hStride;
Expand Down Expand Up @@ -611,11 +624,14 @@ RppStatus color_temperature_f16_f16_host_tensor(Rpp16f *srcPtr,
}
for (; vectorLoopCount < bufferLength; vectorLoopCount++)
{
dstPtrTemp[0] = (Rpp16f) RPPPIXELCHECKF32(*srcPtrTempR++ + adjustmentValue);
dstPtrTemp[1] = (Rpp16f) RPPPIXELCHECKF32(*srcPtrTempG++);
dstPtrTemp[2] = (Rpp16f) RPPPIXELCHECKF32(*srcPtrTempB++ - adjustmentValue);
dstPtrTemp[0] = (Rpp16f) RPPPIXELCHECKF32(*srcPtrTempR + adjustmentValue);
dstPtrTemp[1] = (Rpp16f) RPPPIXELCHECKF32(*srcPtrTempG);
dstPtrTemp[2] = (Rpp16f) RPPPIXELCHECKF32(*srcPtrTempB - adjustmentValue);

dstPtrTemp += 3;
srcPtrTempR++;
srcPtrTempG++;
srcPtrTempB++;
}

srcPtrRowR += srcDescPtr->strides.hStride;
Expand Down Expand Up @@ -733,9 +749,13 @@ RppStatus color_temperature_f16_f16_host_tensor(Rpp16f *srcPtr,
}
for (; vectorLoopCount < bufferLength; vectorLoopCount++)
{
*dstPtrTempR++ = (Rpp16f) RPPPIXELCHECKF32(*srcPtrTempR++ + adjustmentValue);
*dstPtrTempG++ = (Rpp16f) RPPPIXELCHECKF32(*srcPtrTempG++);
*dstPtrTempB++ = (Rpp16f) RPPPIXELCHECKF32(*srcPtrTempB++ - adjustmentValue);
*dstPtrTempR++ = (Rpp16f) RPPPIXELCHECKF32(*srcPtrTempR + adjustmentValue);
*dstPtrTempG++ = (Rpp16f) RPPPIXELCHECKF32(*srcPtrTempG);
*dstPtrTempB++ = (Rpp16f) RPPPIXELCHECKF32(*srcPtrTempB - adjustmentValue);

srcPtrTempR++;
srcPtrTempG++;
srcPtrTempB++;
}

srcPtrRowR += srcDescPtr->strides.hStride;
Expand Down Expand Up @@ -868,11 +888,14 @@ RppStatus color_temperature_i8_i8_host_tensor(Rpp8s *srcPtr,
}
for (; vectorLoopCount < bufferLength; vectorLoopCount++)
{
dstPtrTemp[0] = (Rpp8s) RPPPIXELCHECKI8(*srcPtrTempR++ + adjustmentValue);
dstPtrTemp[1] = (Rpp8s) RPPPIXELCHECKI8(*srcPtrTempG++);
dstPtrTemp[2] = (Rpp8s) RPPPIXELCHECKI8(*srcPtrTempB++ - adjustmentValue);
dstPtrTemp[0] = (Rpp8s) RPPPIXELCHECKI8(*srcPtrTempR + adjustmentValue);
dstPtrTemp[1] = (Rpp8s) RPPPIXELCHECKI8(*srcPtrTempG);
dstPtrTemp[2] = (Rpp8s) RPPPIXELCHECKI8(*srcPtrTempB - adjustmentValue);

dstPtrTemp += 3;
srcPtrTempR++;
srcPtrTempG++;
srcPtrTempB++;
}

srcPtrRowR += srcDescPtr->strides.hStride;
Expand Down Expand Up @@ -965,9 +988,13 @@ RppStatus color_temperature_i8_i8_host_tensor(Rpp8s *srcPtr,
}
for (; vectorLoopCount < bufferLength; vectorLoopCount++)
{
*dstPtrTempR++ = (Rpp8s) RPPPIXELCHECKI8(*srcPtrTempR++ + adjustmentValue);
*dstPtrTempG++ = (Rpp8s) RPPPIXELCHECKI8(*srcPtrTempG++);
*dstPtrTempB++ = (Rpp8s) RPPPIXELCHECKI8(*srcPtrTempB++ - adjustmentValue);
*dstPtrTempR++ = (Rpp8s) RPPPIXELCHECKI8(*srcPtrTempR + adjustmentValue);
*dstPtrTempG++ = (Rpp8s) RPPPIXELCHECKI8(*srcPtrTempG);
*dstPtrTempB++ = (Rpp8s) RPPPIXELCHECKI8(*srcPtrTempB - adjustmentValue);

srcPtrTempR++;
srcPtrTempG++;
srcPtrTempB++;
}

srcPtrRowR += srcDescPtr->strides.hStride;
Expand Down
96 changes: 43 additions & 53 deletions src/modules/hip/kernel/color_temperature.hpp
Original file line number Diff line number Diff line change
@@ -1,47 +1,39 @@
#include <hip/hip_runtime.h>
#include "rpp_hip_common.hpp"

__device__ void color_temperature_hip_compute(uchar *srcPtr, d_float24 *pix_f24, float4 *adjustmentValue_f4)
{
pix_f24->f4[0] += *adjustmentValue_f4;
pix_f24->f4[1] += *adjustmentValue_f4;
pix_f24->f4[4] -= *adjustmentValue_f4;
pix_f24->f4[5] -= *adjustmentValue_f4;
}

__device__ void color_temperature_hip_compute(float *srcPtr, d_float24 *pix_f24, float4 *adjustmentValue_f4)
{
float4 adjustment_f4 = *adjustmentValue_f4 * (float4) ONE_OVER_255;
pix_f24->f4[0] += adjustment_f4;
pix_f24->f4[1] += adjustment_f4;
pix_f24->f4[4] -= adjustment_f4;
pix_f24->f4[5] -= adjustment_f4;
}

__device__ void color_temperature_hip_compute(signed char *srcPtr, d_float24 *pix_f24, float4 *adjustmentValue_f4)
{
pix_f24->f4[0] += *adjustmentValue_f4;
pix_f24->f4[1] += *adjustmentValue_f4;
pix_f24->f4[4] -= *adjustmentValue_f4;
pix_f24->f4[5] -= *adjustmentValue_f4;
}

__device__ void color_temperature_hip_compute(half *srcPtr, d_float24 *pix_f24, float4 *adjustmentValue_f4)
template <typename T>
__device__ void color_temperature_hip_compute(T *srcPtr, d_float24 *pix_f24, float4 *adjustmentValue_f4)
{
float4 adjustment_f4 = *adjustmentValue_f4 * (float4) ONE_OVER_255;
pix_f24->f4[0] += adjustment_f4;
pix_f24->f4[1] += adjustment_f4;
pix_f24->f4[4] -= adjustment_f4;
pix_f24->f4[5] -= adjustment_f4;
float4 adjustment_f4;
if constexpr ((std::is_same<T, float>::value) || (std::is_same<T, half>::value))
{
adjustment_f4 = *adjustmentValue_f4 * (float4) ONE_OVER_255;
rpp_hip_math_add8_const(&pix_f24->f8[0], &pix_f24->f8[0], adjustment_f4);
rpp_hip_math_subtract8_const(&pix_f24->f8[2], &pix_f24->f8[2], adjustment_f4);
}
else if constexpr (std::is_same<T, schar>::value)
{
adjustment_f4 = *adjustmentValue_f4;
rpp_hip_math_add24_const(pix_f24, pix_f24, (float4)128);
rpp_hip_math_add8_const(&pix_f24->f8[0], &pix_f24->f8[0], adjustment_f4);
rpp_hip_math_subtract8_const(&pix_f24->f8[2], &pix_f24->f8[2], adjustment_f4);
rpp_hip_pixel_check_0to255(pix_f24);
rpp_hip_math_subtract24_const(pix_f24, pix_f24, (float4)128);
}
else
{
rpp_hip_math_add8_const(&pix_f24->f8[0], &pix_f24->f8[0], *adjustmentValue_f4);
rpp_hip_math_subtract8_const(&pix_f24->f8[2], &pix_f24->f8[2], *adjustmentValue_f4);
}
}

template <typename T>
__global__ void color_temperature_pkd_hip_tensor(T *srcPtr,
uint2 srcStridesNH,
T *dstPtr,
uint2 dstStridesNH,
int *adjustmentValueTensor,
RpptROIPtr roiTensorPtrSrc)
uint2 srcStridesNH,
T *dstPtr,
uint2 dstStridesNH,
int *adjustmentValueTensor,
RpptROIPtr roiTensorPtrSrc)
{
int id_x = (hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x) * 8;
int id_y = hipBlockIdx_y * hipBlockDim_y + hipThreadIdx_y;
Expand All @@ -66,11 +58,11 @@ __global__ void color_temperature_pkd_hip_tensor(T *srcPtr,

template <typename T>
__global__ void color_temperature_pln_hip_tensor(T *srcPtr,
uint3 srcStridesNCH,
T *dstPtr,
uint3 dstStridesNCH,
int *adjustmentValueTensor,
RpptROIPtr roiTensorPtrSrc)
uint3 srcStridesNCH,
T *dstPtr,
uint3 dstStridesNCH,
int *adjustmentValueTensor,
RpptROIPtr roiTensorPtrSrc)
{
int id_x = (hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x) * 8;
int id_y = hipBlockIdx_y * hipBlockDim_y + hipThreadIdx_y;
Expand All @@ -95,11 +87,11 @@ __global__ void color_temperature_pln_hip_tensor(T *srcPtr,

template <typename T>
__global__ void color_temperature_pkd3_pln3_hip_tensor(T *srcPtr,
uint2 srcStridesNH,
T *dstPtr,
uint3 dstStridesNCH,
int *adjustmentValueTensor,
RpptROIPtr roiTensorPtrSrc)
uint2 srcStridesNH,
T *dstPtr,
uint3 dstStridesNCH,
int *adjustmentValueTensor,
RpptROIPtr roiTensorPtrSrc)
{
int id_x = (hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x) * 8;
int id_y = hipBlockIdx_y * hipBlockDim_y + hipThreadIdx_y;
Expand All @@ -124,11 +116,11 @@ __global__ void color_temperature_pkd3_pln3_hip_tensor(T *srcPtr,

template <typename T>
__global__ void color_temperature_pln3_pkd3_hip_tensor(T *srcPtr,
uint3 srcStridesNCH,
T *dstPtr,
uint2 dstStridesNH,
int *adjustmentValueTensor,
RpptROIPtr roiTensorPtrSrc)
uint3 srcStridesNCH,
T *dstPtr,
uint2 dstStridesNH,
int *adjustmentValueTensor,
RpptROIPtr roiTensorPtrSrc)
{
int id_x = (hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x) * 8;
int id_y = hipBlockIdx_y * hipBlockDim_y + hipThreadIdx_y;
Expand Down Expand Up @@ -171,7 +163,6 @@ RppStatus hip_exec_color_temperature_tensor(T *srcPtr,

if ((srcDescPtr->layout == RpptLayout::NHWC) && (dstDescPtr->layout == RpptLayout::NHWC))
{
globalThreads_x = (dstDescPtr->strides.hStride / 3 + 7) >> 3;
hipLaunchKernelGGL(color_temperature_pkd_hip_tensor,
dim3(ceil((float)globalThreads_x/LOCAL_THREADS_X), ceil((float)globalThreads_y/LOCAL_THREADS_Y), ceil((float)globalThreads_z/LOCAL_THREADS_Z)),
dim3(LOCAL_THREADS_X, LOCAL_THREADS_Y, LOCAL_THREADS_Z),
Expand Down Expand Up @@ -214,7 +205,6 @@ RppStatus hip_exec_color_temperature_tensor(T *srcPtr,
}
else if ((srcDescPtr->layout == RpptLayout::NCHW) && (dstDescPtr->layout == RpptLayout::NHWC))
{
globalThreads_x = (srcDescPtr->strides.hStride + 7) >> 3;
hipLaunchKernelGGL(color_temperature_pln3_pkd3_hip_tensor,
dim3(ceil((float)globalThreads_x/LOCAL_THREADS_X), ceil((float)globalThreads_y/LOCAL_THREADS_Y), ceil((float)globalThreads_z/LOCAL_THREADS_Z)),
dim3(LOCAL_THREADS_X, LOCAL_THREADS_Y, LOCAL_THREADS_Z),
Expand Down
12 changes: 6 additions & 6 deletions src/modules/rppt_tensor_color_augmentations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ RppStatus rppt_color_temperature_host(RppPtr_t srcPtr,
{
color_temperature_f16_f16_host_tensor(reinterpret_cast<Rpp16f*>(static_cast<Rpp8u*>(srcPtr) + srcDescPtr->offsetInBytes),
srcDescPtr,
(Rpp16f*) (static_cast<Rpp8u*>(dstPtr) + dstDescPtr->offsetInBytes),
reinterpret_cast<Rpp16f*>(static_cast<Rpp8u*>(dstPtr) + dstDescPtr->offsetInBytes),
dstDescPtr,
adjustmentValueTensor,
roiTensorPtrSrc,
Expand All @@ -713,7 +713,7 @@ RppStatus rppt_color_temperature_host(RppPtr_t srcPtr,
{
color_temperature_f32_f32_host_tensor(reinterpret_cast<Rpp32f*>(static_cast<Rpp8u*>(srcPtr) + srcDescPtr->offsetInBytes),
srcDescPtr,
(Rpp32f*) (static_cast<Rpp8u*>(dstPtr) + dstDescPtr->offsetInBytes),
reinterpret_cast<Rpp32f*>(static_cast<Rpp8u*>(dstPtr) + dstDescPtr->offsetInBytes),
dstDescPtr,
adjustmentValueTensor,
roiTensorPtrSrc,
Expand Down Expand Up @@ -1300,19 +1300,19 @@ RppStatus rppt_color_temperature_gpu(RppPtr_t srcPtr,
}
else if ((srcDescPtr->dataType == RpptDataType::F16) && (dstDescPtr->dataType == RpptDataType::F16))
{
hip_exec_color_temperature_tensor((half*) (static_cast<Rpp8u*>(srcPtr) + srcDescPtr->offsetInBytes),
hip_exec_color_temperature_tensor(reinterpret_cast<half*>(static_cast<Rpp8u*>(srcPtr) + srcDescPtr->offsetInBytes),
srcDescPtr,
(half*) (static_cast<Rpp8u*>(dstPtr) + dstDescPtr->offsetInBytes),
reinterpret_cast<half*>(static_cast<Rpp8u*>(dstPtr) + dstDescPtr->offsetInBytes),
dstDescPtr,
roiTensorPtrSrc,
roiType,
rpp::deref(rppHandle));
}
else if ((srcDescPtr->dataType == RpptDataType::F32) && (dstDescPtr->dataType == RpptDataType::F32))
{
hip_exec_color_temperature_tensor((Rpp32f*) (static_cast<Rpp8u*>(srcPtr) + srcDescPtr->offsetInBytes),
hip_exec_color_temperature_tensor(reinterpret_cast<Rpp32f*>(static_cast<Rpp8u*>(srcPtr) + srcDescPtr->offsetInBytes),
srcDescPtr,
(Rpp32f*) (static_cast<Rpp8u*>(dstPtr) + dstDescPtr->offsetInBytes),
reinterpret_cast<Rpp32f*>(static_cast<Rpp8u*>(dstPtr) + dstDescPtr->offsetInBytes),
dstDescPtr,
roiTensorPtrSrc,
roiType,
Expand Down
Binary file not shown.

0 comments on commit 45be1f8

Please sign in to comment.