
Commit

add submul
xadupre committed Jun 8, 2024
1 parent 31d8feb commit cd66c02
Showing 5 changed files with 313 additions and 8 deletions.
56 changes: 55 additions & 1 deletion operators/cuda/add_mul.h
@@ -141,7 +141,7 @@ struct AddAndMul {
                             input_data_a, input_data_b, input_data_c,
                             output_data,
                             length_a, length_b, length_c,
-                            addition_first, switchMiddelAxis_);
+                            addition_first);
    }
    return {};
  }
@@ -150,4 +150,58 @@ struct AddAndMul {
  bool switchMiddelAxis_;
};

template <typename T, bool subtract_first>
struct SubAndMul {
  template <typename TDict>
  OrtxStatus OnModelAttach(const TDict& dict) {
    // The optional "negative" attribute would be parsed from `dict` here;
    // negative_ keeps its default (false) otherwise.
    return {};
  }
  OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
                     const ortc::Tensor<T>& tensor_a,
                     const ortc::Tensor<T>& tensor_b,
                     const ortc::Tensor<T>& tensor_c,
                     ortc::Tensor<T>& output) const {
    const T* input_data_a = tensor_a.Data();
    const T* input_data_b = tensor_b.Data();
    const T* input_data_c = tensor_c.Data();

    auto length_a = tensor_a.NumberOfElement();
    auto length_b = tensor_b.NumberOfElement();
    auto length_c = tensor_c.NumberOfElement();
    if (input_data_a == nullptr || input_data_b == nullptr || input_data_c == nullptr) {
      return {};
    }

    std::vector<int64_t> dimsA = tensor_a.Shape();
    std::vector<int64_t> dimsB = tensor_b.Shape();
    std::vector<int64_t> dimsC = tensor_c.Shape();

    // Left-pad the shapes with 1s so all three have the same rank, then take
    // the element-wise maximum as the broadcast output shape.
    auto max_rank = std::max(dimsA.size(), std::max(dimsB.size(), dimsC.size()));
    while (dimsA.size() < max_rank)
      dimsA.insert(dimsA.begin(), 1);
    while (dimsB.size() < max_rank)
      dimsB.insert(dimsB.begin(), 1);
    while (dimsC.size() < max_rank)
      dimsC.insert(dimsC.begin(), 1);

    std::vector<int64_t> output_dims(dimsA.size());
    for (size_t i = 0; i < dimsA.size(); ++i) {
      output_dims[i] = std::max(std::max(dimsA[i], dimsB[i]), dimsC[i]);
    }

    T* output_data = output.Allocate(output_dims);
    LaunchSubAndMulKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
                             input_data_a, input_data_b, input_data_c,
                             output_data,
                             length_a, length_b, length_c,
                             subtract_first, negative_);
    return {};
  }

 private:
  bool negative_{false};
};

} // namespace contrib
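
For reference, the numerics the new `SubMul`/`MulSub` operators target, including the effect of the `negative` attribute, fit in a few lines of NumPy. This is an illustrative sketch of the intended semantics, not code from the repository:

import numpy as np

def submul_reference(a, b, c, negative=False):
    # SubMul computes (a - b) * c; with negative set it computes (b - a) * c.
    # Inputs broadcast against each other as in ONNX.
    return (b - a) * c if negative else (a - b) * c

def mulsub_reference(a, b, c, negative=False):
    # MulSub computes a * b - c; with negative set it computes c - a * b.
    return c - a * b if negative else a * b - c

a = np.random.rand(4, 5, 3).astype(np.float32)
b = np.random.rand(4, 5, 3).astype(np.float32)
c = np.random.rand(1, 5, 3).astype(np.float32)  # broadcast over the first axis
expected = submul_reference(a, b, c, negative=True)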
218 changes: 217 additions & 1 deletion operators/cuda/add_mul_impl.cu
@@ -321,7 +321,7 @@ cudaError_t _LaunchAddAndMulKernel(cudaStream_t stream,
                                   const T* pA, const T* pB, const T* pC,
                                   T* output,
                                   int64_t countA, int64_t countB, int64_t countC,
-                                  bool addition_first, bool switchMiddleAxes) {
+                                  bool addition_first) {
  int64_t max_count = std::max(std::max(countA, countB), countC);
  if (max_count == 0)  // special case where there's a dim value of 0 in the output shape
    return cudaGetLastError();
@@ -358,6 +358,22 @@ cudaError_t _LaunchAddAndMulKernel(cudaStream_t stream,
  return cudaGetLastError();
}

template <>
cudaError_t LaunchAddAndMulKernel(cudaStream_t stream, const float* input_a, const float* input_b, const float* input_c,
                                  float* output, int64_t length_a, int64_t length_b, int64_t length_c,
                                  bool addition) {
  return _LaunchAddAndMulKernel(stream, input_a, input_b, input_c, output, length_a, length_b, length_c, addition);
}

template <>
cudaError_t LaunchAddAndMulKernel(cudaStream_t stream,
                                  const ortc::MFloat16* input_a, const ortc::MFloat16* input_b,
                                  const ortc::MFloat16* input_c,
                                  ortc::MFloat16* output, int64_t length_a, int64_t length_b, int64_t length_c,
                                  bool addition) {
  return _LaunchAddAndMulKernel(stream, input_a, input_b, input_c, output, length_a, length_b, length_c, addition);
}

template <typename T>
cudaError_t _LaunchAddAndMulSwitchMiddleAxesKernel(cudaStream_t stream,
                                                   const T* pA, const T* pB, const T* pC,

@@ -399,3 +415,203 @@ cudaError_t _LaunchAddAndMulSwitchMiddleAxesKernel(cudaStream_t stream,
  }
  return cudaGetLastError();
}

template <>
cudaError_t LaunchAddAndMulSwitchMiddleAxesKernel(cudaStream_t stream, const float* input_a, const float* input_b, const float* input_c,
                                                  float* output, int64_t length_a, int64_t length_b, int64_t length_c,
                                                  bool addition,
                                                  int64_t d2, int64_t d3, int64_t d4) {
  return _LaunchAddAndMulSwitchMiddleAxesKernel(stream, input_a, input_b, input_c, output, length_a, length_b, length_c,
                                                addition, d2, d3, d4);
}

template <>
cudaError_t LaunchAddAndMulSwitchMiddleAxesKernel(cudaStream_t stream, const ortc::MFloat16* input_a,
                                                  const ortc::MFloat16* input_b, const ortc::MFloat16* input_c,
                                                  ortc::MFloat16* output, int64_t length_a, int64_t length_b, int64_t length_c,
                                                  bool addition,
                                                  int64_t d2, int64_t d3, int64_t d4) {
  return _LaunchAddAndMulSwitchMiddleAxesKernel(stream, input_a, input_b, input_c, output, length_a, length_b, length_c,
                                                addition, d2, d3, d4);
}

__device__ __forceinline__ void _submul_op(float* address, const float a, const float b,
                                           const float c) {
  *address = (a - b) * c;
}

__device__ __forceinline__ void _submul_op(half* address, const half a, const half b,
                                           const half c) {
#if __CUDA_ARCH__ < 700
  *address = __float2half((__half2float(a) - __half2float(b)) * __half2float(c));
#else
  *address = (a - b) * c;
#endif
}

__device__ __forceinline__ void _submul_neg_op(float* address, const float a, const float b,
                                               const float c) {
  *address = (b - a) * c;
}

__device__ __forceinline__ void _submul_neg_op(half* address, const half a, const half b,
                                               const half c) {
#if __CUDA_ARCH__ < 700
  *address = __float2half((__half2float(b) - __half2float(a)) * __half2float(c));
#else
  *address = (b - a) * c;
#endif
}

__device__ __forceinline__ void _mulsub_op(float* address, const float a, const float b,
                                           const float c) {
  *address = a * b - c;
}

__device__ __forceinline__ void _mulsub_op(half* address, const half a, const half b,
                                           const half c) {
#if __CUDA_ARCH__ < 700
  *address = __float2half(__half2float(a) * __half2float(b) - __half2float(c));
#else
  *address = a * b - c;
#endif
}

__device__ __forceinline__ void _mulsub_neg_op(float* address, const float a, const float b,
                                               const float c) {
  *address = c - a * b;
}

__device__ __forceinline__ void _mulsub_neg_op(half* address, const half a, const half b,
                                               const half c) {
#if __CUDA_ARCH__ < 700
  *address = __float2half(__half2float(c) - __half2float(a) * __half2float(b));
#else
  *address = c - a * b;
#endif
}
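
On devices with `__CUDA_ARCH__ < 700`, the half overloads above compute in float and round once when storing the result. A NumPy sketch of that fallback, assuming NumPy's float16 conversion rounds like `__float2half`:

import numpy as np

def submul_half_fallback(a, b, c):
    # Compute (a - b) * c in float32, then round once to float16, matching
    # __float2half((__half2float(a) - __half2float(b)) * __half2float(c)).
    a32, b32, c32 = (np.float32(x) for x in (a, b, c))
    return np.float16((a32 - b32) * c32)

print(submul_half_fallback(np.float16(1.5), np.float16(0.25), np.float16(2.0)))  # 2.5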

template <typename T>
struct SubMul {
  __device__ __inline__ void operator()(T* address, const T a, const T b, const T c) const {
    _submul_op(address, a, b, c);
  }
};

template <typename T>
struct MulSub {
  __device__ __inline__ void operator()(T* address, const T a, const T b, const T c) const {
    _mulsub_op(address, a, b, c);
  }
};

template <typename T>
struct SubMulNeg {
  __device__ __inline__ void operator()(T* address, const T a, const T b, const T c) const {
    _submul_neg_op(address, a, b, c);
  }
};

template <typename T>
struct MulSubNeg {
  __device__ __inline__ void operator()(T* address, const T a, const T b, const T c) const {
    _mulsub_neg_op(address, a, b, c);
  }
};

template <typename T, typename TFunc, int NumThreadsPerBlock, int NumElementsPerThread>
__global__ void _SubAndMulKernel(T* output_data, const T* pA, const T* pB, const T* pC,
                                 CUDA_LONG nA, CUDA_LONG nB, CUDA_LONG nC, CUDA_LONG N,
                                 const TFunc func) {
  CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
  CUDA_LONG id = start;
#pragma unroll
  for (int i = 0; i < NumElementsPerThread; i++) {
    if (id < N) {
      func(output_data + id, pA[id % nA], pB[id % nB], pC[id % nC]);
      id += NumThreadsPerBlock;
    }
  }
}
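
The kernel broadcasts by flat index: each input is read at `id % n`, which reproduces NumPy-style broadcasting only when every input's element count divides the output's, as is the case for the padded shapes computed in `SubAndMul::Compute`. A NumPy sketch of that indexing scheme (the helper name is illustrative):

import numpy as np

def modulo_broadcast_apply(a, b, c, func):
    # Input i contributes element (id mod size_i) to output element id,
    # mirroring pA[id % nA], pB[id % nB], pC[id % nC] in the kernel.
    n = max(a.size, b.size, c.size)
    fa, fb, fc = a.ravel(), b.ravel(), c.ravel()
    out = np.empty(n, dtype=a.dtype)
    for i in range(n):
        out[i] = func(fa[i % fa.size], fb[i % fb.size], fc[i % fc.size])
    return out

a = np.arange(12, dtype=np.float32)  # 12 output elements
b = np.arange(4, dtype=np.float32)   # wraps every 4 elements
c = np.ones(1, dtype=np.float32)     # behaves like a scalar
y = modulo_broadcast_apply(a, b, c, lambda x, u, v: (x - u) * v)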

template <typename T>
cudaError_t _LaunchSubAndMulKernel(cudaStream_t stream,
                                   const T* pA, const T* pB, const T* pC,
                                   T* output,
                                   int64_t countA, int64_t countB, int64_t countC,
                                   bool subtract_first, bool negative) {
  int64_t max_count = std::max(std::max(countA, countB), countC);
  if (max_count == 0)  // special case where there's a dim value of 0 in the output shape
    return cudaGetLastError();

  const int num_elements_per_thread = 4;
  const int num_threads_per_block = 256;
  const int num_el_th = num_threads_per_block * num_elements_per_thread;

  int blocksPerGrid = (max_count + num_el_th - 1) / num_el_th;

  using TT = typename contrib::CudaT<T>::MappedType;

  if (subtract_first) {
    if (negative) {
      _SubAndMulKernel<TT, SubMulNeg<TT>, num_threads_per_block, num_elements_per_thread>
          <<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
              reinterpret_cast<TT*>(output),
              reinterpret_cast<const TT*>(pA),
              reinterpret_cast<const TT*>(pB),
              reinterpret_cast<const TT*>(pC),
              static_cast<CUDA_LONG>(countA), static_cast<CUDA_LONG>(countB), static_cast<CUDA_LONG>(countC),
              static_cast<CUDA_LONG>(max_count), SubMulNeg<TT>());
    } else {
      _SubAndMulKernel<TT, SubMul<TT>, num_threads_per_block, num_elements_per_thread>
          <<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
              reinterpret_cast<TT*>(output),
              reinterpret_cast<const TT*>(pA),
              reinterpret_cast<const TT*>(pB),
              reinterpret_cast<const TT*>(pC),
              static_cast<CUDA_LONG>(countA), static_cast<CUDA_LONG>(countB), static_cast<CUDA_LONG>(countC),
              static_cast<CUDA_LONG>(max_count), SubMul<TT>());
    }
  } else {
    if (negative) {
      _SubAndMulKernel<TT, MulSubNeg<TT>, num_threads_per_block, num_elements_per_thread>
          <<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
              reinterpret_cast<TT*>(output),
              reinterpret_cast<const TT*>(pA),
              reinterpret_cast<const TT*>(pB),
              reinterpret_cast<const TT*>(pC),
              static_cast<CUDA_LONG>(countA), static_cast<CUDA_LONG>(countB), static_cast<CUDA_LONG>(countC),
              static_cast<CUDA_LONG>(max_count), MulSubNeg<TT>());
    } else {
      _SubAndMulKernel<TT, MulSub<TT>, num_threads_per_block, num_elements_per_thread>
          <<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
              reinterpret_cast<TT*>(output),
              reinterpret_cast<const TT*>(pA),
              reinterpret_cast<const TT*>(pB),
              reinterpret_cast<const TT*>(pC),
              static_cast<CUDA_LONG>(countA), static_cast<CUDA_LONG>(countB), static_cast<CUDA_LONG>(countC),
              static_cast<CUDA_LONG>(max_count), MulSub<TT>());
    }
  }
  return cudaGetLastError();
}
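
Each block covers num_threads_per_block * num_elements_per_thread = 256 * 4 = 1024 elements, so the launcher's grid size is a ceiling division of the largest input's element count. A quick check of that arithmetic:

def blocks_per_grid(max_count, threads_per_block=256, elements_per_thread=4):
    per_block = threads_per_block * elements_per_thread  # 1024 elements per block
    return (max_count + per_block - 1) // per_block      # ceiling division

assert blocks_per_grid(1) == 1
assert blocks_per_grid(1024) == 1
assert blocks_per_grid(1025) == 2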

template <>
cudaError_t LaunchSubAndMulKernel(cudaStream_t stream, const float* input_a, const float* input_b, const float* input_c,
                                  float* output, int64_t length_a, int64_t length_b, int64_t length_c,
                                  bool subtract_first, bool negative) {
  return _LaunchSubAndMulKernel(stream, input_a, input_b, input_c, output, length_a, length_b, length_c, subtract_first, negative);
}

template <>
cudaError_t LaunchSubAndMulKernel(cudaStream_t stream,
                                  const ortc::MFloat16* input_a, const ortc::MFloat16* input_b,
                                  const ortc::MFloat16* input_c,
                                  ortc::MFloat16* output, int64_t length_a, int64_t length_b, int64_t length_c,
                                  bool subtract_first, bool negative) {
  return _LaunchSubAndMulKernel(stream, input_a, input_b, input_c, output, length_a, length_b, length_c, subtract_first, negative);
}
7 changes: 6 additions & 1 deletion operators/cuda/add_mul_impl.cuh
@@ -17,10 +17,15 @@ cudaError_t LaunchAddOrMulTwiceKernel(cudaStream_t stream, const T* input_a, con
template <typename T>
cudaError_t LaunchAddAndMulKernel(cudaStream_t stream, const T* input_a, const T* input_b, const T* input_c,
                                  T* output, int64_t length_a, int64_t length_b, int64_t length_c,
-                                 bool addition, bool switchMiddleAxis);
+                                 bool addition);

template <typename T>
cudaError_t LaunchAddAndMulSwitchMiddleAxesKernel(cudaStream_t stream, const T* input_a, const T* input_b, const T* input_c,
                                                  T* output, int64_t length_a, int64_t length_b, int64_t length_c,
                                                  bool addition,
                                                  int64_t d2, int64_t d3, int64_t d4);

template <typename T>
cudaError_t LaunchSubAndMulKernel(cudaStream_t stream, const T* input_a, const T* input_b, const T* input_c,
                                  T* output, int64_t length_a, int64_t length_b, int64_t length_c,
                                  bool subtract_first, bool negative);
20 changes: 15 additions & 5 deletions operators/cuda/cuda_ops.cc
@@ -16,8 +16,11 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
  using AddTwiceFloat32Type = typename contrib::AddOrMulTwice<float, true>;
  using MulTwiceFloat32Type = typename contrib::AddOrMulTwice<float, false>;

-  using AddAndMulFloat32Type = typename contrib::AddOrMulTwice<float, true>;
-  using MulAndAddFloat32Type = typename contrib::AddOrMulTwice<float, false>;
+  using AddAndMulFloat32Type = typename contrib::AddAndMul<float, true>;
+  using MulAndAddFloat32Type = typename contrib::AddAndMul<float, false>;

  using SubAndMulFloat32Type = typename contrib::SubAndMul<float, true>;
  using MulAndSubFloat32Type = typename contrib::SubAndMul<float, false>;

#if ORT_API_VERSION >= 16
  using AddSharedInputFloat16Type = typename contrib::AddOrMulSharedInput<ortc::MFloat16, true>;

@@ -26,8 +29,11 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
  using AddTwiceFloat16Type = typename contrib::AddOrMulTwice<ortc::MFloat16, true>;
  using MulTwiceFloat16Type = typename contrib::AddOrMulTwice<ortc::MFloat16, false>;

-  using AddAndMulFloat32Type = typename contrib::AddOrMulTwice<ortc::MFloat16, true>;
-  using MulAndAddFloat32Type = typename contrib::AddOrMulTwice<ortc::MFloat16, false>;
+  using AddAndMulFloat16Type = typename contrib::AddAndMul<ortc::MFloat16, true>;
+  using MulAndAddFloat16Type = typename contrib::AddAndMul<ortc::MFloat16, false>;

  using SubAndMulFloat16Type = typename contrib::SubAndMul<ortc::MFloat16, true>;
  using MulAndSubFloat16Type = typename contrib::SubAndMul<ortc::MFloat16, false>;
#endif

  static OrtOpLoader op_loader(

@@ -40,8 +46,10 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
      CustomCudaStructV2("FastGelu", contrib::FastGelu<float>),
      CustomCudaStructV2("MulAdd", MulAndAddFloat32Type),
      CustomCudaStructV2("MulSharedInput", MulSharedInputFloat32Type),
      CustomCudaStructV2("MulSub", MulAndSubFloat32Type),
      CustomCudaStructV2("MulTwice", MulTwiceFloat32Type),
      CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
      CustomCudaStructV2("SubMul", SubAndMulFloat32Type),
#if ORT_API_VERSION >= 16

      CustomCudaStructV2("AddMul", AddAndMulFloat16Type),

@@ -51,8 +59,10 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
      CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>),
      CustomCudaStructV2("MulAdd", MulAndAddFloat16Type),
      CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type),
      CustomCudaStructV2("MulSub", MulAndSubFloat16Type),
      CustomCudaStructV2("MulTwice", MulTwiceFloat16Type),
-      CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>)
+      CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>),
+      CustomCudaStructV2("SubMul", SubAndMulFloat16Type)
#endif
#endif
  );
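
Once registered, the new ops are callable from Python through onnxruntime's custom-ops library. A minimal usage sketch, assuming the ops live in the `ai.onnx.contrib` domain (the default for onnxruntime-extensions), that `SubMul` exposes the `negative` attribute shown in the header, and that a CUDA-enabled build is installed:

import numpy as np
import onnxruntime as ort
from onnx import TensorProto, helper
from onnxruntime_extensions import get_library_path

node = helper.make_node("SubMul", ["A", "B", "C"], ["Y"],
                        domain="ai.onnx.contrib", negative=1)
graph = helper.make_graph(
    [node], "submul_example",
    [helper.make_tensor_value_info(n, TensorProto.FLOAT, [2, 3]) for n in "ABC"],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3])])
model = helper.make_model(graph, opset_imports=[
    helper.make_opsetid("", 18), helper.make_opsetid("ai.onnx.contrib", 1)])

opts = ort.SessionOptions()
opts.register_custom_ops_library(get_library_path())
sess = ort.InferenceSession(model.SerializeToString(), opts,
                            providers=["CUDAExecutionProvider"])
a, b, c = (np.random.rand(2, 3).astype(np.float32) for _ in range(3))
(y,) = sess.run(None, {"A": a, "B": b, "C": c})
np.testing.assert_allclose(y, (b - a) * c, rtol=1e-5)  # negative=1 flips the subtraction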
20 changes: 20 additions & 0 deletions test/cuda/test_cudaops.py
@@ -423,6 +423,26 @@ def test_submul_cuda(self):
        self._addmul_cuda(TensorProto.FLOAT, "Sub", "Mul")
        self._addmul_cuda(TensorProto.FLOAT16, "Sub", "Mul")

    @unittest.skipIf(not has_cuda(), reason="cuda not available")
    def test_submul_cuda_negative(self):
        self._addmul_cuda(TensorProto.FLOAT, "Sub", "Mul", negative=True)
        self._addmul_cuda(TensorProto.FLOAT16, "Sub", "Mul", negative=True)

    @unittest.skipIf(not has_cuda(), reason="cuda not available")
    def test_submul_cuda_broadcast(self):
        self._addmul_cuda(TensorProto.FLOAT, "Sub", "Mul", True)
        self._addmul_cuda(TensorProto.FLOAT16, "Sub", "Mul", True)

    @unittest.skipIf(not has_cuda(), reason="cuda not available")
    def test_mulsub_cuda(self):
        self._addmul_cuda(TensorProto.FLOAT, "Mul", "Sub")
        self._addmul_cuda(TensorProto.FLOAT16, "Mul", "Sub")

    @unittest.skipIf(not has_cuda(), reason="cuda not available")
    def test_mulsub_cuda_negative(self):
        self._addmul_cuda(TensorProto.FLOAT, "Mul", "Sub", negative=True)
        self._addmul_cuda(TensorProto.FLOAT16, "Mul", "Sub", negative=True)


if __name__ == "__main__":
unittest.main()
