diff --git a/operators/cuda/add_mul.h b/operators/cuda/add_mul.h
index 3f6cc1e6..43a1b97e 100644
--- a/operators/cuda/add_mul.h
+++ b/operators/cuda/add_mul.h
@@ -63,21 +63,91 @@ struct AddOrMulTwice {
     auto length_c = tensor_c.NumberOfElement();
 
     T* output_data_ab = output_ab.Allocate(
-        length_a <= length_b
-            ? lenght_c <= length_b ? tensor_b.Shape() : tensor_c.Shape()
-            : lenght_a <= length_b ? tensor_b.Shape() : tensor_a.Shape());
+        length_a <= length_b
+            ? length_c <= length_b ? tensor_b.Shape() : tensor_c.Shape()
+            : length_a <= length_b ? tensor_b.Shape()
+                                   : tensor_a.Shape());
 
     if (0 == input_data_a || 0 == input_data_b || 0 == input_data_c) {
       return {};
     }
 
     LaunchAddOrMulTwiceKernel(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
-                                 input_data_a, input_data_b, input_data_c,
-                                 output_data,
-                                 length_a, length_b, length_c,
-                                 addition);
+                              input_data_a, input_data_b, input_data_c,
+                              output_data,
+                              length_a, length_b, length_c,
+                              addition);
     return {};
   }
 };
 
+template <typename T, bool addition_first>
+struct AddAndMul {
+  template <typename TDict>
+  OrtxStatus OnModelAttach(const TDict& dict) {
+    return {};
+  }
+  OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
+                     const ortc::Tensor<T>& tensor_a,
+                     const ortc::Tensor<T>& tensor_b,
+                     const ortc::Tensor<T>& tensor_c,
+                     ortc::Tensor<T>& output) const {
+    const T* input_data_a = tensor_a.Data();
+    const T* input_data_b = tensor_b.Data();
+    const T* input_data_c = tensor_c.Data();
+
+    auto length_a = tensor_a.NumberOfElement();
+    auto length_b = tensor_b.NumberOfElement();
+    auto length_c = tensor_c.NumberOfElement();
+    if (0 == input_data_a || 0 == input_data_b || 0 == input_data_c) {
+      return {};
+    }
+
+    std::vector<int64_t> dimsA = tensor_a.Shape();
+    std::vector<int64_t> dimsB = tensor_b.Shape();
+    std::vector<int64_t> dimsC = tensor_c.Shape();
+
+    auto max_length = std::max(length_a, std::max(length_b, length_c));
+
+    auto max_rank = std::max(dimsA.size(), std::max(dimsB.size(), dimsC.size()));
+    while (dimsA.size() < max_rank)
+      dimsA.insert(dimsA.begin(), 1);
+    while (dimsB.size() < max_rank)
+      dimsB.insert(dimsB.begin(), 1);
+    while (dimsC.size() < max_rank)
+      dimsC.insert(dimsC.begin(), 1);
+
+    std::vector<int64_t> output_dims(dimsA.size());
+    for (size_t i = 0; i < dimsA.size(); ++i) {
+      output_dims[i] = std::max(std::max(dimsA[i], dimsB[i]), dimsC[i]);
+    }
+
+    if (switchMiddleAxis_) {
+      if (output_dims.size() != 4) {
+        ORTX_CXX_API_THROW("switchMiddleAxes only works with 4D tensors", ORT_RUNTIME_EXCEPTION);
+      }
+      int64_t d4 = output_dims[output_dims.size() - 1];
+      int64_t d3 = output_dims[output_dims.size() - 2];
+      int64_t d2 = output_dims[output_dims.size() - 3];
+      output_dims[1] = d3;
+      output_dims[2] = d2;
+      T* output_data = output.Allocate(output_dims);
+      LaunchAddAndMulSwitchMiddleAxesKernel(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
+                                            input_data_a, input_data_b, input_data_c,
+                                            output_data,
+                                            length_a, length_b, length_c,
+                                            addition_first, d2, d3, d4);
+    } else {
+      T* output_data = output.Allocate(output_dims);
+      LaunchAddAndMulKernel(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
+                            input_data_a, input_data_b, input_data_c,
+                            output_data,
+                            length_a, length_b, length_c,
+                            addition_first, switchMiddleAxis_);
+    }
+    return {};
+  }
+
+ private:
+  bool switchMiddleAxis_{false};
+};
 }  // namespace contrib
\ No newline at end of file
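
A sketch outside the patch itself: the new `AddAndMul::Compute` above derives the output shape by left-padding each input shape with 1s to a common rank and then taking the per-axis maximum, i.e. multidirectional (NumPy-style) broadcasting across the three inputs. A minimal Python rendering of that shape rule, with illustrative names that are not part of the patch:

# Illustrative only: mirrors the dimsA/dimsB/dimsC rank alignment in AddAndMul::Compute.
import numpy as np


def broadcast_output_shape(*shapes):
    # Left-pad every shape with 1s to the largest rank, then take the per-axis maximum.
    max_rank = max(len(s) for s in shapes)
    padded = [(1,) * (max_rank - len(s)) + tuple(s) for s in shapes]
    return tuple(max(dims) for dims in zip(*padded))


assert broadcast_output_shape((3, 2, 3), (1, 2, 3), (3, 2, 3)) == (3, 2, 3)
assert broadcast_output_shape((2, 3), (4, 2, 3), (1, 3)) == (4, 2, 3)
# For compatible shapes this agrees with NumPy's own broadcasting rule.
assert np.broadcast_shapes((3, 2, 3), (1, 2, 3), (3, 2, 3)) == (3, 2, 3)
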
diff --git a/operators/cuda/add_mul_impl.cu b/operators/cuda/add_mul_impl.cu
index 0373ddf2..ff52e00c 100644
--- a/operators/cuda/add_mul_impl.cu
+++ b/operators/cuda/add_mul_impl.cu
@@ -126,12 +126,12 @@ cudaError_t LaunchAddOrMulSharedInputKernel(cudaStream_t stream,
                                             length_a, length_b, length_c, addition);
 }
 
-__device__ __forceinline__ void _add3_op(float *address, const float a, const float b,
+__device__ __forceinline__ void _add3_op(float* address, const float a, const float b,
                                          const float c) {
   *address = a + b + c;
 }
 
-__device__ __forceinline__ void _add3_op(half *address, const half a, const half b,
+__device__ __forceinline__ void _add3_op(half* address, const half a, const half b,
                                          const half c) {
 #if __CUDA_ARCH__ < 700
   *address = __float2half(__half2float(a) + __half2float(b) + __half2float(c));
@@ -140,12 +140,12 @@ __device__ __forceinline__ void _add3_op(half *address, const half a, const half
 #endif
 }
 
-__device__ __forceinline__ void _mul3_op(float *address, const float a, const float b,
+__device__ __forceinline__ void _mul3_op(float* address, const float a, const float b,
                                          const float c) {
   *address = a * b * c;
 }
 
-__device__ __forceinline__ void _mul3_op(half *address, const half a, const half b,
+__device__ __forceinline__ void _mul3_op(half* address, const half a, const half b,
                                          const half c) {
 #if __CUDA_ARCH__ < 700
   *address = __float2half(__half2float(a) * __half2float(b) * __half2float(c));
@@ -154,14 +154,16 @@ __device__ __forceinline__ void _mul3_op(half *address, const half a, const half
 #endif
 }
 
-template <typename T> struct Mul3Op {
-  __device__ __inline__ void operator()(T *address, const T a, const T b, const T c) const {
+template <typename T>
+struct Mul3Op {
+  __device__ __inline__ void operator()(T* address, const T a, const T b, const T c) const {
     _mul3_op(address, a, b, c);
   }
 };
 
-template <typename T> struct Add3Op {
-  __device__ __inline__ void operator()(T *address, const T a, const T b, const T c) const {
+template <typename T>
+struct Add3Op {
+  __device__ __inline__ void operator()(T* address, const T a, const T b, const T c) const {
     _add3_op(address, a, b, c);
   }
 };
@@ -201,7 +203,7 @@ cudaError_t _LaunchAddOrMulTwiceKernel(cudaStream_t stream,
   if (addition) {
     AddMulTwiceKernel<TT, Add3SharedOp<TT>, num_threads_per_block, num_elements_per_thread>
         <<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
-            reinterpret_cast<TT *>(output),
+            reinterpret_cast<TT*>(output),
             reinterpret_cast<const TT*>(pA), reinterpret_cast<const TT*>(pB), reinterpret_cast<const TT*>(pC),
             static_cast<CUDA_LONG>(countA), static_cast<CUDA_LONG>(countB), static_cast<CUDA_LONG>(countC),
             static_cast<CUDA_LONG>(max_count), Add3SharedOp<TT>());
@@ -236,3 +238,164 @@ cudaError_t LaunchAddOrMulSharedInputKernel(cudaStream_t stream,
                                             length_a, length_b, length_c, addition);
 }
 
+__device__ __forceinline__ void _addmul_op(float* address, const float a, const float b,
+                                           const float c) {
+  *address = (a + b) * c;
+}
+
+__device__ __forceinline__ void _addmul_op(half* address, const half a, const half b,
+                                           const half c) {
+#if __CUDA_ARCH__ < 700
+  *address = __float2half((__half2float(a) + __half2float(b)) * __half2float(c));
+#else
+  *address = (a + b) * c;
+#endif
+}
+
+__device__ __forceinline__ void _muladd_op(float* address, const float a, const float b,
+                                           const float c) {
+  *address = a * b + c;
+}
+
+__device__ __forceinline__ void _muladd_op(half* address, const half a, const half b,
+                                           const half c) {
+#if __CUDA_ARCH__ < 700
+  *address = __float2half(__half2float(a) * __half2float(b) + __half2float(c));
+#else
+  *address = a * b + c;
+#endif
+}
+
+template <typename T>
+struct AddMul {
+  __device__ __inline__ void operator()(T* address, const T a, const T b, const T c) const {
+    _addmul_op(address, a, b, c);
+  }
+};
+
+template <typename T>
+struct MulAdd {
+  __device__ __inline__ void operator()(T* address, const T a, const T b, const T c) const {
+    _muladd_op(address, a, b, c);
+  }
+};
+
+template <typename T, typename TFunc, int NumThreadsPerBlock, int NumElementsPerThread>
+__global__ void _AddAndMulKernel(T* output_data, const T* pA, const T* pB, const T* pC,
+                                 CUDA_LONG nA, CUDA_LONG nB, CUDA_LONG nC, CUDA_LONG N,
+                                 const TFunc func) {
+  CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
+  CUDA_LONG id = start;
+#pragma unroll
+  for (int i = 0; i < NumElementsPerThread; i++) {
+    if (id < N) {
+      func(output_data + id, pA[id % nA], pB[id % nB], pC[id % nC]);
+      id += NumThreadsPerBlock;
+    }
+  }
+}
+
+template <typename T, typename TFunc, int NumThreadsPerBlock, int NumElementsPerThread>
+__global__ void _AddAndMulSwitchMiddleAxesKernel(T* output_data, const T* pA, const T* pB,
+                                                 const T* pC, CUDA_LONG nA, CUDA_LONG nB,
+                                                 CUDA_LONG nC, CUDA_LONG N,
+                                                 const TFunc func, CUDA_LONG d2,
+                                                 CUDA_LONG d3, CUDA_LONG d4) {
+  CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
+  CUDA_LONG id = start;
+  CUDA_LONG k, j, ido;
+#pragma unroll
+  for (int i = 0; i < NumElementsPerThread; i++) {
+    if (id < N) {
+      k = (id / d4) % d3;
+      j = (id / (d4 * d3)) % d2;
+      ido = id + d4 * ((k * d2 + j) - (j * d3 + k));
+      func(output_data + ido, pA[id % nA], pB[id % nB], pC[id % nC]);
+      id += NumThreadsPerBlock;
+    }
+  }
+}
+
+template <typename T>
+cudaError_t _LaunchAddAndMulKernel(cudaStream_t stream,
+                                   const T* pA, const T* pB, const T* pC,
+                                   T* output,
+                                   int64_t countA, int64_t countB, int64_t countC,
+                                   bool addition_first, bool switchMiddleAxes) {
+  int64_t max_count = std::max(std::max(countA, countB), countC);
+  if (max_count == 0)  // special case where there's a dim value of 0 in the output shape
+    return cudaGetLastError();
+
+  const int num_elements_per_thread = 4;
+  const int num_threads_per_block = 256;
+  const int num_el_th = num_threads_per_block * num_elements_per_thread;
+
+  int blocksPerGrid = (max_count + num_el_th - 1) / num_el_th;
+
+  using TT = typename contrib::CudaT<T>::MappedType;
+
+  if (addition_first) {
+    _AddAndMulKernel<TT, AddMul<TT>, num_threads_per_block, num_elements_per_thread>
+        <<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
+            reinterpret_cast<TT*>(output),
+            reinterpret_cast<const TT*>(pA),
+            reinterpret_cast<const TT*>(pB),
+            reinterpret_cast<const TT*>(pC),
+            countA, countB, countC,
+            max_count, AddMul<TT>());
+  } else {
+    _AddAndMulKernel<TT, MulAdd<TT>, num_threads_per_block, num_elements_per_thread>
+        <<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
+            reinterpret_cast<TT*>(output),
+            reinterpret_cast<const TT*>(pA),
+            reinterpret_cast<const TT*>(pB),
+            reinterpret_cast<const TT*>(pC),
+            countA, countB, countC,
+            max_count, MulAdd<TT>());
+  }
+  return cudaGetLastError();
+}
+
+template <typename T>
+cudaError_t _LaunchAddAndMulSwitchMiddleAxesKernel(cudaStream_t stream,
+                                                   const T* pA, const T* pB, const T* pC,
+                                                   T* output,
+                                                   int64_t countA, int64_t countB, int64_t countC,
+                                                   bool addition_first, int64_t d2, int64_t d3, int64_t d4) {
+  int64_t max_count = std::max(std::max(countA, countB), countC);
+  if (max_count == 0)  // special case where there's a dim value of 0 in the output shape
+    return cudaGetLastError();
+
+  const int num_elements_per_thread = 4;
+  const int num_threads_per_block = 256;
+  const int num_el_th = num_threads_per_block * num_elements_per_thread;
+
+  int blocksPerGrid = (max_count + num_el_th - 1) / num_el_th;
+
+  using TT = typename contrib::CudaT<T>::MappedType;
+
+  if (addition_first) {
+    _AddAndMulSwitchMiddleAxesKernel<TT, AddMul<TT>, num_threads_per_block, num_elements_per_thread>
+        <<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
+            reinterpret_cast<TT*>(output),
+            reinterpret_cast<const TT*>(pA),
+            reinterpret_cast<const TT*>(pB),
+            reinterpret_cast<const TT*>(pC),
+            countA, countB, countC,
+            max_count, AddMul<TT>(),
+            d2, d3, d4);
+  } else {
+    _AddAndMulSwitchMiddleAxesKernel<TT, MulAdd<TT>, num_threads_per_block, num_elements_per_thread>
+        <<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
+            reinterpret_cast<TT*>(output),
+            reinterpret_cast<const TT*>(pA),
+            reinterpret_cast<const TT*>(pB),
+            reinterpret_cast<const TT*>(pC),
+            countA, countB, countC,
+            max_count, MulAdd<TT>(),
+            d2, d3, d4);
+  }
+  return cudaGetLastError();
+}
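
Another aside, not from the patch: the `ido` arithmetic in `_AddAndMulSwitchMiddleAxesKernel` writes element `id` of the flat broadcast result to the slot it would occupy after swapping axes 1 and 2 of a row-major 4D tensor. A small Python check of that index mapping against `numpy.transpose`:

# Illustrative only: replays the ido = id + d4 * ((k * d2 + j) - (j * d3 + k)) remapping
# on the CPU and compares it with a (0, 2, 1, 3) transpose.
import numpy as np

d1, d2, d3, d4 = 2, 3, 4, 5
x = np.arange(d1 * d2 * d3 * d4).reshape(d1, d2, d3, d4)

out = np.empty(d1 * d2 * d3 * d4, dtype=x.dtype)
for idx, value in enumerate(x.ravel()):
    k = (idx // d4) % d3
    j = (idx // (d4 * d3)) % d2
    ido = idx + d4 * ((k * d2 + j) - (j * d3 + k))
    out[ido] = value

# The scattered writes reproduce the middle-axes swap.
np.testing.assert_array_equal(out.reshape(d1, d3, d2, d4), np.transpose(x, (0, 2, 1, 3)))
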
diff --git a/operators/cuda/add_mul_impl.cuh b/operators/cuda/add_mul_impl.cuh
index 6fdf9369..98bb0616 100644
--- a/operators/cuda/add_mul_impl.cuh
+++ b/operators/cuda/add_mul_impl.cuh
@@ -12,4 +12,15 @@ cudaError_t LaunchAddOrMulSharedInputKernel(cudaStream_t stream, const T* input_
 
 template <typename T>
 cudaError_t LaunchAddOrMulTwiceKernel(cudaStream_t stream, const T* input_a, const T* input_b, const T* input_c,
-                                      T* output, int64_t length_a, int64_t length_b, int64_t length_c, bool addition);
\ No newline at end of file
+                                      T* output, int64_t length_a, int64_t length_b, int64_t length_c, bool addition);
+
+template <typename T>
+cudaError_t LaunchAddAndMulKernel(cudaStream_t stream, const T* input_a, const T* input_b, const T* input_c,
+                                  T* output, int64_t length_a, int64_t length_b, int64_t length_c,
+                                  bool addition, bool switchMiddleAxis);
+
+template <typename T>
+cudaError_t LaunchAddAndMulSwitchMiddleAxesKernel(cudaStream_t stream, const T* input_a, const T* input_b, const T* input_c,
+                                                  T* output, int64_t length_a, int64_t length_b, int64_t length_c,
+                                                  bool addition,
+                                                  int64_t d2, int64_t d3, int64_t d4);
diff --git a/operators/cuda/cuda_ops.cc b/operators/cuda/cuda_ops.cc
index 8c4808a5..9ba3a732 100644
--- a/operators/cuda/cuda_ops.cc
+++ b/operators/cuda/cuda_ops.cc
@@ -10,38 +10,46 @@
 #endif
 
 FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
-
   using AddSharedInputFloat32Type = typename contrib::AddOrMulSharedInput<float, true>;
   using MulSharedInputFloat32Type = typename contrib::AddOrMulSharedInput<float, false>;
 
   using AddTwiceFloat32Type = typename contrib::AddOrMulTwice<float, true>;
   using MulTwiceFloat32Type = typename contrib::AddOrMulTwice<float, false>;
 
+  using AddAndMulFloat32Type = typename contrib::AddAndMul<float, true>;
+  using MulAndAddFloat32Type = typename contrib::AddAndMul<float, false>;
+
 #if ORT_API_VERSION >= 16
   using AddSharedInputFloat16Type = typename contrib::AddOrMulSharedInput<ortc::MFloat16, true>;
   using MulSharedInputFloat16Type = typename contrib::AddOrMulSharedInput<ortc::MFloat16, false>;
 
   using AddTwiceFloat16Type = typename contrib::AddOrMulTwice<ortc::MFloat16, true>;
   using MulTwiceFloat16Type = typename contrib::AddOrMulTwice<ortc::MFloat16, false>;
-#endif
+
+  using AddAndMulFloat16Type = typename contrib::AddAndMul<ortc::MFloat16, true>;
+  using MulAndAddFloat16Type = typename contrib::AddAndMul<ortc::MFloat16, false>;
+#endif
 
   static OrtOpLoader op_loader(
       []() { return nullptr; }
 #ifdef USE_CUDA
       ,
+      CustomCudaStructV2("AddMul", AddAndMulFloat32Type),
       CustomCudaStructV2("AddSharedInput", AddSharedInputFloat32Type),
       CustomCudaStructV2("AddTwice", AddTwiceFloat32Type),
       CustomCudaStructV2("FastGelu", contrib::FastGelu<float>),
+      CustomCudaStructV2("MulAdd", MulAndAddFloat32Type),
       CustomCudaStructV2("MulSharedInput", MulSharedInputFloat32Type),
       CustomCudaStructV2("MulTwice", MulTwiceFloat32Type),
       CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
 #if ORT_API_VERSION >= 16
+      CustomCudaStructV2("AddMul", AddAndMulFloat16Type),
       CustomCudaStructV2("AddSharedInput", AddSharedInputFloat16Type),
       CustomCudaStructV2("AddTwice", AddTwiceFloat16Type),
       CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::MFloat16>),
       CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>),
+      CustomCudaStructV2("MulAdd", MulAndAddFloat16Type),
       CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type),
       CustomCudaStructV2("MulTwice", MulTwiceFloat16Type),
       CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>)
diff --git a/test/cuda/test_cudaops.py b/test/cuda/test_cudaops.py
index 7955795d..db11b212 100644
--- a/test/cuda/test_cudaops.py
+++ b/test/cuda/test_cudaops.py
@@ -339,6 +339,90 @@ def test_add_shared_input_cuda_broadcast2(self):
             shapec=(3, 2, 3),
         )
 
+    def _addmul_cuda(self, itype, op_type1, op_type2, broad=False, negative=False):
+        model1 = helper.make_model(
+            helper.make_graph(
+                [
+                    helper.make_node(op_type1, ["Y", "X"] if negative else ["X", "Y"], ["xy"]),
+                    helper.make_node(op_type2, ["Z", "xy"] if negative else ["xy", "Z"], ["final"]),
+                ],
+                "nd",
+                [
+                    helper.make_tensor_value_info("X", itype, [None, None, None]),
+                    helper.make_tensor_value_info("Y", itype, [None, None, None]),
+                    helper.make_tensor_value_info("Z", itype, [None, None, None]),
+                ],
+                [helper.make_tensor_value_info("final", itype, [None, None, None])],
+            ),
+            opset_imports=[helper.make_opsetid("", 18)],
+            ir_version=9,
+        )
+
+        kwargs = {"negative": 1} if negative else {}
+        model2 = helper.make_model(
+            helper.make_graph(
+                [
+                    helper.make_node(
+                        f"{op_type1}{op_type2}",
+                        ["X", "Y", "Z"],
+                        ["final"],
+                        domain="ai.onnx.contrib",
+                        **kwargs,
+                    )
+                ],
+                "nd",
+                [
+                    helper.make_tensor_value_info("X", itype, [None, None, None]),
+                    helper.make_tensor_value_info("Y", itype, [None, None, None]),
+                    helper.make_tensor_value_info("Z", itype, [None, None, None]),
+                ],
+                [helper.make_tensor_value_info("final", itype, [None, None, None])],
+            ),
+            opset_imports=[
+                helper.make_opsetid("", 18),
+                helper.make_opsetid("ai.onnx.contrib", 1),
+            ],
+            ir_version=9,
+        )
+
+        dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
+        shapex = (1, 2, 3) if broad else (3, 2, 3)
+        shapey = (3, 2, 3)
+        shapez = (1, 2, 3) if broad else (3, 2, 3)
+        x = (np.arange(np.prod(shapex)) + 1).reshape(shapex).astype(dtype)
+        y = (np.arange(np.prod(shapey)) + 1).reshape(shapey).astype(dtype)
+        z = (np.arange(np.prod(shapez)) + 1).reshape(shapez).astype(dtype)
+
+        feeds1 = dict(X=x, Y=y, Z=z)
+        ref = ReferenceEvaluator(model1, verbose=0)
+        expected = ref.run(None, feeds1)[0]
+
+        opts = _ort.SessionOptions()
+        opts.register_custom_ops_library(_get_library_path())
+        sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
+        got = sess.run(None, feeds1)[0]
+        assert_almost_equal(expected, got)
+
+    @unittest.skipIf(not has_cuda(), reason="cuda not available")
+    def test_addmul_cuda(self):
+        self._addmul_cuda(TensorProto.FLOAT, "Add", "Mul")
+        self._addmul_cuda(TensorProto.FLOAT16, "Add", "Mul")
+
+    @unittest.skipIf(not has_cuda(), reason="cuda not available")
+    def test_addmul_cuda_broadcast(self):
+        self._addmul_cuda(TensorProto.FLOAT, "Add", "Mul", True)
+        self._addmul_cuda(TensorProto.FLOAT16, "Add", "Mul", True)
+
+    @unittest.skipIf(not has_cuda(), reason="cuda not available")
+    def test_muladd_cuda(self):
+        self._addmul_cuda(TensorProto.FLOAT, "Mul", "Add")
+        self._addmul_cuda(TensorProto.FLOAT16, "Mul", "Add")
+
+    @unittest.skipIf(not has_cuda(), reason="cuda not available")
+    def test_submul_cuda(self):
+        self._addmul_cuda(TensorProto.FLOAT, "Sub", "Mul")
+        self._addmul_cuda(TensorProto.FLOAT16, "Sub", "Mul")
+
 
 if __name__ == "__main__":
     unittest.main()
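
For reference, and not part of the patch: the tests above compare the fused ops against a plain ONNX Add/Mul pair, so `AddMul(X, Y, Z)` is expected to match `(X + Y) * Z` and `MulAdd(X, Y, Z)` to match `X * Y + Z`, with broadcasting across the three inputs. A NumPy sketch of those expected values for the broadcast case used in `_addmul_cuda`:

# Illustrative only: broadcast case from _addmul_cuda(broad=True); X and Z are (1, 2, 3), Y is (3, 2, 3).
import numpy as np

x = (np.arange(6, dtype=np.float32) + 1).reshape(1, 2, 3)
y = (np.arange(18, dtype=np.float32) + 1).reshape(3, 2, 3)
z = (np.arange(6, dtype=np.float32) + 1).reshape(1, 2, 3)

expected_addmul = (x + y) * z  # what the Add -> Mul reference graph ("AddMul") produces
expected_muladd = x * y + z    # what the Mul -> Add reference graph ("MulAdd") produces
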