Commit

…sions into misc
xadupre committed Jun 12, 2024
2 parents d3b6f5d + 690bed7 commit c1ecc67
Showing 5 changed files with 335 additions and 1 deletion.
5 changes: 5 additions & 0 deletions operators/cuda/cuda_ops.cc
@@ -6,6 +6,7 @@
#ifdef USE_CUDA
#include "cuda/add_mul.h"
#include "cuda/fast_gelu.h"
#include "cuda/mul_sigmoid.h"
#include "cuda/negxplus1.h"
#include "cuda/scatter_nd_of_shape.h"
#include "cuda/transpose_cast.h"
@@ -32,6 +33,8 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
CustomCudaStructV2("FastGelu", contrib::FastGelu<float>),
CustomCudaStructV2("MaskedScatterNDOfShape", contrib::MaskedScatterNDOfShape<float>),
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat32Type),
CustomCudaStructV2("MulMulSigmoid", contrib::MulMulSigmoid<float>),
CustomCudaStructV2("MulSigmoid", contrib::MulSigmoid<float>),
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<float>),
#if ORT_API_VERSION >= 16
@@ -41,6 +44,8 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>),
CustomCudaStructV2("MaskedScatterNDOfShape", contrib::MaskedScatterNDOfShape<ortc::MFloat16>),
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type),
CustomCudaStructV2("MulMulSigmoid", contrib::MulMulSigmoid<ortc::MFloat16>),
CustomCudaStructV2("MulSigmoid", contrib::MulSigmoid<ortc::MFloat16>),
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>),
CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<ortc::MFloat16>),
CustomCudaStructV2("Transpose2DCastFP16", Transpose2DCastFloat32ToFloat16Type),
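These registrations expose the new kernels to onnxruntime under the ai.onnx.contrib custom-op domain, for float and (with ORT_API_VERSION >= 16) float16. For orientation before the implementation files, here is a minimal sketch of invoking MulSigmoid end to end; it assumes a CUDA build of onnxruntime-extensions and the package-level get_library_path() helper (the tests below use an internal _get_library_path() instead):

import numpy as np
from onnx import TensorProto, helper
import onnxruntime as _ort
from onnxruntime_extensions import get_library_path  # assumed public helper

# Single-node model calling the custom op from the ai.onnx.contrib domain.
node = helper.make_node("MulSigmoid", ["X"], ["Y"], domain="ai.onnx.contrib")
graph = helper.make_graph(
    [node],
    "g",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, [None])],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [None])],
)
model = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("", 18), helper.make_opsetid("ai.onnx.contrib", 1)],
    ir_version=9,
)

opts = _ort.SessionOptions()
opts.register_custom_ops_library(get_library_path())
sess = _ort.InferenceSession(model.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
print(sess.run(None, {"X": np.array([1.0, -2.0], dtype=np.float32)})[0])  # x * sigmoid(x)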
72 changes: 72 additions & 0 deletions operators/cuda/mul_sigmoid.h
@@ -0,0 +1,72 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "ocos.h"
#include "mul_sigmoid_impl.cuh"
#include "ortx_common.h"

namespace contrib {

/**
 * MulSigmoid(X) = X * Sigmoid(X)
 * No shape broadcasting is supported.
 */
template <typename T>
struct MulSigmoid {
template <typename TDict>
OrtxStatus OnModelAttach(const TDict& /*dict*/) {
return {};
}
OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
const ortc::Tensor<T>& input,
ortc::Tensor<T>& output) const {
const T* input_data = input.Data();
T* output_data = output.Allocate(input.Shape());
auto input_length = input.NumberOfElement();
if (0 == input_length) {
return {};
}
LaunchMulSigmoidKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
input_length,
input_data,
output_data);
return {};
}
};

/**
 * MulMulSigmoid(X, Y) = X * Y * Sigmoid(Y)
 * No shape broadcasting is supported.
 */
template <typename T>
struct MulMulSigmoid {
template <typename TDict>
OrtxStatus OnModelAttach(const TDict& /*dict*/) {
return {};
}
OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
const ortc::Tensor<T>& input_x,
const ortc::Tensor<T>& input_y,
ortc::Tensor<T>& output) const {
const T* input_data_x = input_x.Data();
const T* input_data_y = input_y.Data();
auto input_length_x = input_x.NumberOfElement();
auto input_length_y = input_y.NumberOfElement();
if (0 == input_length_x || 0 == input_length_y) {
return {};
}
T* output_data = output.Allocate(input_length_x > input_length_y ? input_x.Shape() : input_y.Shape());
LaunchMulMulSigmoidKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
input_length_x,
input_length_y,
input_data_x,
input_data_y,
output_data);
return {};
}
};

} // namespace contrib
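Numerically, MulSigmoid is the SiLU/Swish activation, and MulMulSigmoid fuses one extra elementwise multiply onto it. As a reference for the semantics documented above, a minimal NumPy sketch using the same sign-split sigmoid as the CUDA implementation below:

import numpy as np

def sigmoid(a):
    # Sign-split form mirroring sigmoid() in mul_sigmoid_impl.cu:
    # exp() is only ever evaluated on a non-positive argument.
    z = np.exp(-np.abs(a))
    return np.where(a > 0, 1.0 / (1.0 + z), 1.0 - 1.0 / (1.0 + z))

def mul_sigmoid(x):
    # MulSigmoid(X) = X * Sigmoid(X), i.e. SiLU/Swish
    return x * sigmoid(x)

def mul_mul_sigmoid(x, y):
    # MulMulSigmoid(X, Y) = X * Y * Sigmoid(Y); the kernel itself does
    # no general broadcasting, inputs are indexed modulo their lengths
    return x * y * sigmoid(y)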
119 changes: 119 additions & 0 deletions operators/cuda/mul_sigmoid_impl.cu
@@ -0,0 +1,119 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "device_prop.cuh"
#include "utils.cuh"
#include "mul_sigmoid_impl.cuh"
#include "cuda_type.h"

#ifndef CUDA_LONG
#define CUDA_LONG int32_t
#endif

using namespace Ort::Custom;

template <typename T> __device__ __inline__ T _exp_typed(const T x);

template <> __device__ __inline__ float _exp_typed(const float x) { return expf(x); }

#if __CUDA_ARCH__ < 700
template <> __device__ __inline__ half _exp_typed(const half x) {
return __float2half(expf(__half2float(x)));
}
#else
template <> __device__ __inline__ half _exp_typed(const half x) { return hexp(x); }
#endif

template <typename T> __device__ __inline__ T sigmoid(const T a) {
return a > T(0) ? (T)1 / ((T)1. + _exp_typed<T>(-a))
: (T)1 - (T)1 / ((T)1 + _exp_typed<T>(a));
}

#if __CUDA_ARCH__ < 700
template <> __device__ __inline__ half sigmoid(const half a) {
return __float2half(sigmoid(__half2float(a)));
}
#endif

template <typename T> __device__ __inline__ T mul_sigmoid(const T a) { return a * sigmoid(a); }

#if __CUDA_ARCH__ < 700
template <> __device__ __inline__ half mul_sigmoid(const half a) {
float x = __half2float(a);
return __float2half(x * sigmoid(x));
}
#endif

template <typename T> __device__ __inline__ T mul_mul_sigmoid(const T x, const T y) {
return x * y * sigmoid(y);
}

#if __CUDA_ARCH__ < 700
template <> __device__ __inline__ half mul_mul_sigmoid(const half x, const half y) {
float hy = __half2float(y);
return __float2half(__half2float(x) * hy * sigmoid(hy));
}
#endif

template <typename T>
__global__ void MulSigmoidKernel(T *output_data, const T *input_data, CUDA_LONG N) {
CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
if (id >= N)
return;
output_data[id] = mul_sigmoid(input_data[id]);
}

template <typename T>
__global__ void MulMulSigmoidKernel(T *output_data, const T *px, const T *py, CUDA_LONG N,
CUDA_LONG Nx, CUDA_LONG Ny) {
CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
if (id >= N)
return;
output_data[id] = mul_mul_sigmoid(px[id % Nx], py[id % Ny]);
}

template <typename T>
cudaError_t _LaunchMulSigmoidKernel(cudaStream_t stream, int input_length, const T* input, T* output) {
constexpr int blockSize = 256;
const int gridSize = (input_length + blockSize - 1) / blockSize;
using TT = typename contrib::CudaT<T>::MappedType;
MulSigmoidKernel<TT><<<gridSize, blockSize, 0, stream>>>(reinterpret_cast<TT*>(output), reinterpret_cast<const TT*>(input), input_length);
return cudaGetLastError();
}

template <>
cudaError_t LaunchMulSigmoidKernel<float>(cudaStream_t stream, int input_length, const float* input, float* output) {
return _LaunchMulSigmoidKernel(stream, input_length, input, output);
}

template <>
cudaError_t LaunchMulSigmoidKernel<ortc::MFloat16>(cudaStream_t stream, int input_length, const ortc::MFloat16* input, ortc::MFloat16* output) {
return _LaunchMulSigmoidKernel(stream, input_length, input, output);
}

template <typename T>
cudaError_t _LaunchMulMulSigmoidKernel(cudaStream_t stream, int input_length_x, int input_length_y,
const T* input_data_x, const T* input_data_y, T* output) {
int input_length = std::max(input_length_x, input_length_y);
constexpr int blockSize = 256;
const int gridSize = (input_length + blockSize - 1) / blockSize;
using TT = typename contrib::CudaT<T>::MappedType;
MulMulSigmoidKernel<TT><<<gridSize, blockSize, 0, stream>>>(reinterpret_cast<TT*>(output),
reinterpret_cast<const TT*>(input_data_x),
reinterpret_cast<const TT*>(input_data_y),
input_length, input_length_x, input_length_y);
return cudaGetLastError();
}

template <>
cudaError_t LaunchMulMulSigmoidKernel<float>(cudaStream_t stream, int input_length_x, int input_length_y,
const float* input_data_x, const float* input_data_y, float* output) {
return _LaunchMulMulSigmoidKernel(stream, input_length_x, input_length_y, input_data_x, input_data_y, output);
}

template <>
cudaError_t LaunchMulMulSigmoidKernel<ortc::MFloat16>(cudaStream_t stream, int input_length_x, int input_length_y,
const ortc::MFloat16* input_data_x, const ortc::MFloat16* input_data_y,
ortc::MFloat16* output) {
return _LaunchMulMulSigmoidKernel(stream, input_length_x, input_length_y, input_data_x, input_data_y, output);
}
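Note that MulMulSigmoidKernel indexes its inputs modulo their lengths (px[id % Nx], py[id % Ny]) rather than applying general ONNX broadcasting. This coincides with NumPy broadcasting only when the flattened smaller input tiles the larger one exactly, e.g. the leading-axis case (1, 2, 3) vs (3, 2, 3) exercised by the tests below. A short NumPy check of that claim:

import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

x = (np.arange(6, dtype=np.float32) + 1).reshape(1, 2, 3)   # leading-axis broadcast
y = (np.arange(18, dtype=np.float32) + 2).reshape(3, 2, 3)

# Modulo indexing over flattened data, as MulMulSigmoidKernel does.
xf, yf = x.ravel(), y.ravel()
idx = np.arange(max(xf.size, yf.size))
kernel_like = xf[idx % xf.size] * yf[idx % yf.size] * sigmoid(yf[idx % yf.size])

# Full NumPy broadcasting as the reference.
reference = (x * y * sigmoid(y)).ravel()
assert np.allclose(kernel_like, reference)

# A trailing-axis broadcast such as (3, 2, 1) vs (3, 2, 3) would NOT match:
# there the repeated-element pattern is not a plain tiling of the flat array.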
13 changes: 13 additions & 0 deletions operators/cuda/mul_sigmoid_impl.cuh
@@ -0,0 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include <cuda.h>
#include <cuda_runtime.h>

template <typename T>
cudaError_t LaunchMulSigmoidKernel(cudaStream_t stream, int input_length, const T* input, T* output);

template <typename T>
cudaError_t LaunchMulMulSigmoidKernel(cudaStream_t stream, int input_length_x, int input_length_y,
const T* input_data_x, const T* input_data_y, T* output);
127 changes: 126 additions & 1 deletion test/cuda/test_cudaops.py
@@ -1,6 +1,6 @@
import unittest
import numpy as np
from numpy.testing import assert_almost_equal
from numpy.testing import assert_almost_equal, assert_allclose
from onnx import helper, numpy_helper, onnx_pb as onnx_proto, TensorProto
from onnx.reference import ReferenceEvaluator
from onnx.reference.op_run import OpRun
@@ -128,6 +128,131 @@ def test_cuda_fastgelu_f16(self):
else:
print("CUDAExecutionProvider not available, test_cuda_fastgelu_f16 skipped.")

def _mulmulsigmoid_cuda(self, itype, broad=False, atol=1e-5, rtol=1e-3):
model1 = helper.make_model(
helper.make_graph(
[
helper.make_node("Mul", ["X", "Y"], ["xy"]),
helper.make_node("Sigmoid", ["Y"], ["sy"]),
helper.make_node("Mul", ["xy", "sy"], ["final"]),
],
"nd",
[
helper.make_tensor_value_info("X", itype, [None, None, None]),
helper.make_tensor_value_info("Y", itype, [None, None, None]),
],
[helper.make_tensor_value_info("final", itype, [None, None, None])],
),
opset_imports=[helper.make_opsetid("", 18)],
ir_version=9,
)

model2 = helper.make_model(
helper.make_graph(
[
helper.make_node(
"MulMulSigmoid",
["X", "Y"],
["final"],
domain="ai.onnx.contrib",
)
],
"nd",
[
helper.make_tensor_value_info("X", itype, [None, None, None]),
helper.make_tensor_value_info("Y", itype, [None, None, None]),
],
[helper.make_tensor_value_info("final", itype, [None, None, None])],
),
opset_imports=[
helper.make_opsetid("", 18),
helper.make_opsetid("ai.onnx.contrib", 1),
],
ir_version=9,
)

dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
shapex = (1, 2, 3) if broad else (3, 2, 3)
shapey = (3, 2, 3)
x = (np.arange(np.prod(shapex)) + 1).reshape(shapex).astype(dtype)
y = (np.arange(np.prod(shapey)) + 2).reshape(shapey).astype(dtype)
x /= x.size
y /= y.size

feeds1 = dict(X=x, Y=y)
ref = ReferenceEvaluator(model1)
expected = ref.run(None, feeds1)[0]

opts = _ort.SessionOptions()
opts.register_custom_ops_library(_get_library_path())
sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
got = sess.run(None, feeds1)[0]
assert_allclose(expected, got, atol=atol, rtol=rtol)

@unittest.skipIf(not has_cuda(), reason="cuda not available")
def test_mulmulsigmoid_cuda(self):
self._mulmulsigmoid_cuda(TensorProto.FLOAT)
self._mulmulsigmoid_cuda(TensorProto.FLOAT16)

@unittest.skipIf(not has_cuda(), reason="cuda not available")
def test_mulmulsigmoid_cuda_broadcast(self):
self._mulmulsigmoid_cuda(TensorProto.FLOAT, True)
self._mulmulsigmoid_cuda(TensorProto.FLOAT16, True)

def _mul_sigmoid_cuda(self, itype):
model1 = helper.make_model(
helper.make_graph(
[
helper.make_node("Sigmoid", ["X"], ["sx"]),
helper.make_node("Mul", ["X", "sx"], ["Y"]),
],
"nd",
[helper.make_tensor_value_info("X", itype, [None, None, None])],
[helper.make_tensor_value_info("Y", itype, [None, None, None])],
),
opset_imports=[helper.make_opsetid("", 18)],
ir_version=9,
)

model2 = helper.make_model(
helper.make_graph(
[
helper.make_node(
"MulSigmoid",
["X"],
["Y"],
domain="ai.onnx.contrib",
)
],
"nd",
[helper.make_tensor_value_info("X", itype, [None, None, None])],
[helper.make_tensor_value_info("Y", itype, [None, None, None])],
),
opset_imports=[
helper.make_opsetid("", 18),
helper.make_opsetid("ai.onnx.contrib", 1),
],
ir_version=9,
)

dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
x = (np.arange(18) + 1).reshape((3, 2, 3)).astype(dtype)

feeds1 = dict(X=x)
ref = ReferenceEvaluator(model1)
expected = ref.run(None, feeds1)[0]

opts = _ort.SessionOptions()
opts.register_custom_ops_library(_get_library_path())
sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
got = sess.run(None, feeds1)[0]
assert_allclose(expected, got, atol=1e-5 if itype == TensorProto.FLOAT else 1e-2)

@unittest.skipIf(not has_cuda(), reason="cuda not available")
def test_mul_sigmoid_cuda(self):
self._mul_sigmoid_cuda(TensorProto.FLOAT)
self._mul_sigmoid_cuda(TensorProto.FLOAT16)

def _negxplus1_cuda(self, itype):
dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
model1 = helper.make_model(