
Commit

fix merge conflicts
xadupre committed Jun 12, 2024
1 parent 6da49bf commit 3846169
Showing 4 changed files with 16 additions and 14 deletions.
5 changes: 0 additions & 5 deletions operators/cuda/cuda_ops.cc
@@ -8,16 +8,12 @@
 #include "cuda/fast_gelu.h"
 #include "cuda/mul_sigmoid.h"
 #include "cuda/negxplus1.h"
-<<<<<<< HEAD
 #include "cuda/replace_zero.h"
-=======
 #include "cuda/scatter_nd_of_shape.h"
->>>>>>> f5055466d5376059c2ea74e3cea46e16a537bc0d
 #include "cuda/transpose_cast.h"
 #endif
 
 FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
-
 using AddSharedInputFloat32Type = typename contrib::AddOrMulSharedInput<float, true>;
 using MulSharedInputFloat32Type = typename contrib::AddOrMulSharedInput<float, false>;
 
@@ -28,7 +24,6 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
 using Transpose2DCastFloat16ToFloat32Type = typename contrib::Transpose2DCast<ortc::MFloat16, float>;
 #endif
 
-
 static OrtOpLoader op_loader(
 []() { return nullptr; }
 #ifdef USE_CUDA
3 changes: 3 additions & 0 deletions operators/cuda/replace_zero.h
@@ -13,6 +13,9 @@ namespace contrib {
 *
 * Y = X.copy()
 * X[X == 0] = c
+*
+* This operation usually appears when a tensor is updated with an operator Equal and Where.
+* This kernel avoids the creation of one null tensor.
 */
 template <typename T>
 struct ReplaceZero {
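For context, a minimal host-side sketch of the semantics this header documents (not part of this commit; the function name and use of std::vector are illustrative only): every element equal to zero is replaced by a constant in a single pass, without materializing the boolean mask that a separate Equal + Where pair would create.

```cpp
// Illustrative reference only -- not from the repository. Mirrors the
// element-wise behavior of the ReplaceZero CUDA kernel on the host.
#include <vector>

template <typename T>
std::vector<T> ReplaceZeroReference(const std::vector<T>& x, T by) {
  std::vector<T> y(x);                     // Y = X.copy()
  for (auto& v : y) {
    if (v == static_cast<T>(0)) v = by;    // replace zeros by the constant
  }
  return y;
}
```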
20 changes: 12 additions & 8 deletions operators/cuda/replace_zero_impl.cu
@@ -12,11 +12,13 @@
 
 using namespace Ort::Custom;
 
-template <typename T> __device__ __inline__ T _replace_zero(const T x, const T by) {
+template <typename T>
+__device__ __inline__ T _replace_zero(const T x, const T by) {
 return x == (T)0 ? by : x;
 }
 
-template <> __device__ __inline__ half _replace_zero(const half x, const half by) {
+template <>
+__device__ __inline__ half _replace_zero(const half x, const half by) {
 #if __CUDA_ARCH__ < 700
 return __half2float(x) == 0 ? by : x;
 #else
@@ -25,31 +27,33 @@ template <> __device__ __inline__ half _replace_zero(const half x, const half by
 }
 
 template <typename T>
-__global__ void ReplaceZeroKernel(T *output_data, const T *input_data, CUDA_LONG N, const T by) {
+__global__ void ReplaceZeroKernel(T* output_data, const T* input_data, CUDA_LONG N, const T by) {
 CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
 if (id >= N)
 return;
 output_data[id] = _replace_zero(input_data[id], by);
 }
 
-template <typename T> T _cvt(float value) { return (T)value; }
+template <typename T>
+T _cast(float value) { return (T)value; }
 
-template <> half _cvt(float value) { return __float2half(value); }
+template <>
+half _cast(float value) { return __float2half(value); }
 
 template <typename T>
 cudaError_t _LaunchReplaceZeroKernel(cudaStream_t stream, int input_length, const T* input_data, T* output_data, float by) {
 if (input_length == 0)
-return cudaGetLastError();
+return cudaGetLastError();
 using TT = typename contrib::CudaT<T>::MappedType;
 
 CUDA_LONG N = static_cast<CUDA_LONG>(input_length);
 
 const int num_threads_per_block = 256;
 const int num_elements_per_thread = (N + num_threads_per_block - 1) / num_threads_per_block;
 
-TT cby = _cvt<TT>(by);
+TT cby = _cast<TT>(by);
 ReplaceZeroKernel<TT><<<num_elements_per_thread, num_threads_per_block, 0, stream>>>(
-reinterpret_cast<TT*>(output_data), reinterpret_cast<const TT*>(input_data), N, cby);
+reinterpret_cast<TT*>(output_data), reinterpret_cast<const TT*>(input_data), N, cby);
 return cudaGetLastError();
 }
 
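A hedged sketch of how the launcher above might be driven from host code. Only the `LaunchReplaceZeroKernel` signature comes from this diff (see replace_zero_impl.cuh below); the buffer names, sizes, include path, and the assumption that a float instantiation of the template is available are illustrative.

```cpp
// Illustrative usage sketch, not from the repository. Assumes the project
// exposes a float instantiation of LaunchReplaceZeroKernel and that the
// header path below is visible to the compiler.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>
#include "replace_zero_impl.cuh"

int main() {
  const int n = 8;
  std::vector<float> host_in = {1.f, 0.f, 2.f, 0.f, 3.f, 0.f, 4.f, 5.f};
  std::vector<float> host_out(n);

  float* dev_in = nullptr;
  float* dev_out = nullptr;
  cudaMalloc((void**)&dev_in, n * sizeof(float));
  cudaMalloc((void**)&dev_out, n * sizeof(float));
  cudaMemcpy(dev_in, host_in.data(), n * sizeof(float), cudaMemcpyHostToDevice);

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Every zero in dev_in is written to dev_out as 9.5f; other values pass through.
  cudaError_t err = LaunchReplaceZeroKernel<float>(stream, n, dev_in, dev_out, 9.5f);
  cudaStreamSynchronize(stream);

  cudaMemcpy(host_out.data(), dev_out, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) printf("%g ", host_out[i]);
  printf("(status: %d)\n", static_cast<int>(err));

  cudaStreamDestroy(stream);
  cudaFree(dev_in);
  cudaFree(dev_out);
  return 0;
}
```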
2 changes: 1 addition & 1 deletion operators/cuda/replace_zero_impl.cuh
@@ -6,4 +6,4 @@
 #include <cuda_runtime.h>
 
 template <typename T>
-cudaError_t LaunchReplaceZeroKernel(cudaStream_t stream, int input_length, const T* input_data, T* output_data, float by);
+cudaError_t LaunchReplaceZeroKernel(cudaStream_t stream, int input_length, const T* input_data, T* output_data, float by);
