From 06e07802b079f09b94292f572a4e7d387a5a0f28 Mon Sep 17 00:00:00 2001
From: xadupre
Date: Mon, 6 May 2024 11:16:21 +0000
Subject: [PATCH] fix build

---
 operators/cuda/scatter_nd_of_shape.cu | 133 ++------------------------
 operators/cuda/scatter_nd_of_shape.h  |  46 +++++----
 test/cuda/test_cudaops.py             | 111 ++++++++-------------
 3 files changed, 78 insertions(+), 212 deletions(-)

diff --git a/operators/cuda/scatter_nd_of_shape.cu b/operators/cuda/scatter_nd_of_shape.cu
index 90c60b36..64f0d450 100644
--- a/operators/cuda/scatter_nd_of_shape.cu
+++ b/operators/cuda/scatter_nd_of_shape.cu
@@ -5,6 +5,12 @@
 #include
 #include
 #include
+#include "exceptions.h"
+#include "custom_op/onnxruntime_f16.h"
+
+namespace ortc = Ort::Custom;
+
+#include "cuda_type.h"
 
 namespace contrib {
 
@@ -65,116 +71,6 @@ addition_inplace_kernel(T* __restrict__ output_data, const int64_t* __restrict__
   }
 }
 
-//////////////////
-// ScatterNDOfShapeOp...
-//////////////////
-
-template <typename T>
-void* ScatterNDOfShapeOp<T>::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
-  return std::make_unique<ScatterNDOfShapeKernel<T>>(api, info).release();
-}
-
-template <typename T>
-const char* ScatterNDOfShapeOp<T>::GetName() const {
-  return "ScatterNDOfShape";
-}
-
-template <typename T>
-const char* ScatterNDOfShapeOp<T>::GetExecutionProviderType() const {
-  return "CUDAExecutionProvider";
-}
-
-template <typename T>
-size_t ScatterNDOfShapeOp<T>::GetInputTypeCount() const { return 3; };
-
-template <>
-ONNXTensorElementDataType ScatterNDOfShapeOp<float>::GetInputType(std::size_t index) const {
-  switch (index) {
-    case 0:
-    case 1:
-      return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
-    case 2:
-      return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
-    default:
-      ORTX_CXX_API_THROW("Wrong input index.", ORT_RUNTIME_EXCEPTION);
-  }
-}
-
-template <>
-ONNXTensorElementDataType ScatterNDOfShapeOp<ortc::MFloat16>::GetInputType(std::size_t index) const {
-  switch (index) {
-    case 0:
-    case 1:
-      return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
-    case 2:
-      return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
-    default:
-      ORTX_CXX_API_THROW("Wrong input index.", ORT_RUNTIME_EXCEPTION);
-  }
-}
-
-template <typename T>
-OrtMemType ScatterNDOfShapeOp<T>::GetInputMemoryType(std::size_t index) const {
-  switch (index) {
-    case 0:
-      return OrtMemTypeCPUInput;
-    case 1:
-    case 2:
-      return OrtMemTypeDefault;
-    default:
-      ORTX_CXX_API_THROW("Wrong input index.", ORT_RUNTIME_EXCEPTION);
-  }
-}
-
-template <typename T>
-OrtCustomOpInputOutputCharacteristic
-ScatterNDOfShapeOp<T>::GetInputCharacteristic(std::size_t index) const {
-  switch (index) {
-    case 0:
-    case 1:
-    case 2:
-      return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
-    default:
-      ORTX_CXX_API_THROW("Wrong output index.", ORT_RUNTIME_EXCEPTION);
-  }
-}
-
-template <typename T>
-size_t ScatterNDOfShapeOp<T>::GetOutputTypeCount() const { return 1; }
-
-template <>
-ONNXTensorElementDataType ScatterNDOfShapeOp<float>::GetOutputType(std::size_t index) const {
-  // D, scale D
-  switch (index) {
-    case 0:
-      return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
-    default:
-      ORTX_CXX_API_THROW("Wrong output index.", ORT_RUNTIME_EXCEPTION);
-  }
-}
-
-template <>
-ONNXTensorElementDataType ScatterNDOfShapeOp<ortc::MFloat16>::GetOutputType(std::size_t index) const {
-  // D, scale D
-  switch (index) {
-    case 0:
-      return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
-    default:
-      ORTX_CXX_API_THROW("Wrong output index.", ORT_RUNTIME_EXCEPTION);
-  }
-}
-
-template <typename T>
-OrtCustomOpInputOutputCharacteristic
-ScatterNDOfShapeOp<T>::GetOutputCharacteristic(std::size_t index) const {
-  switch (index) {
-    case 0:
-      return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
-    default:
-      ORTX_CXX_API_THROW("Wrong output index", ORT_RUNTIME_EXCEPTION);
-  }
-}
-
 ///////////////////
 // ScatterNDOfShapeKernel
 ///////////////////
@@ -201,7 +97,7 @@ template <typename T>
 void ScatterNDOfShapeKernel<T>::Compute(OrtKernelContext* context) {
   Ort::KernelContext ctx(context);
 
-  int n_inputs = ctx.GetInputCount();
+  int n_inputs = ctx.GetInputCount();  // crashes here... core dumped.
   _ENFORCE(n_inputs == 3, "Expecting 3 inputs.");
   Ort::ConstValue shape = ctx.GetInput(0);
   Ort::ConstValue indices = ctx.GetInput(1);
@@ -244,18 +140,9 @@ void _ComputeNoAtomic(cudaStream_t stream, const std::vector<int64_t>& input_sha
                       int threads_per_block, int blocks_per_grid, size_t indice_size, size_t nrows, size_t stride) {
   dim3 threads(threads_per_block);
   dim3 blocks(blocks_per_grid);
-  addition_inplace_kernel<<<blocks, threads, 0, stream>>>(output_data, indices_data, updates_data, indice_size, nrows, stride);
-}
-
-template <>
-void _ComputeNoAtomic(cudaStream_t stream, const std::vector<int64_t>& input_shape,
-                      const std::vector<int64_t>& indices_shape, ortc::MFloat16* output_data,
-                      const int64_t* indices_data, const ortc::MFloat16* updates_data,
-                      int threads_per_block, int blocks_per_grid, size_t indice_size, size_t nrows, size_t stride) {
-
-  dim3 threads(threads_per_block);
-  dim3 blocks(blocks_per_grid);
-  addition_inplace_kernel<<<blocks, threads, 0, stream>>>((half*)output_data, indices_data, (const half*)updates_data, indice_size, nrows, stride);
+  using TT = typename CudaT<T>::MappedType;
+  addition_inplace_kernel<<<blocks, threads, 0, stream>>>((TT*)output_data, indices_data,
+                                                          (TT*)updates_data, indice_size, nrows, stride);
 }
 
 template <typename T>
diff --git a/operators/cuda/scatter_nd_of_shape.h b/operators/cuda/scatter_nd_of_shape.h
index 54bfe3f4..54c42be9 100644
--- a/operators/cuda/scatter_nd_of_shape.h
+++ b/operators/cuda/scatter_nd_of_shape.h
@@ -3,15 +3,20 @@
 
 #pragma once
 
-#include "ocos.h"
-// #include "cublas_v2.h"
-#include
-
+#define ORT_API_MANUAL_INIT
 #ifdef ORT_SWIFT_PACKAGE_MANAGER_BUILD
+#include "onnxruntime/onnxruntime_c_api.h"
 #include "onnxruntime/onnxruntime_cxx_api.h"
 #else
+#include "onnxruntime_c_api.h"
 #include "onnxruntime_cxx_api.h"
 #endif
+#undef ORT_API_MANUAL_INIT
+#include "custom_op/onnxruntime_f16.h"
+
+// #include "ocos.h"
+// #include "cublas_v2.h"
+#include
 
 namespace contrib {
 
@@ -23,10 +28,17 @@ enum class Reduction : int {
   Max = 4,
 };
 
+template <typename T>
+inline ONNXTensorElementDataType onnx_type();
+template <>
+inline ONNXTensorElementDataType onnx_type<float>() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }
+template <>
+inline ONNXTensorElementDataType onnx_type<ortc::MFloat16>() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; }
+
 /**
  * This kernel implementation the fusion of ConstantOfShape and ScatterND.
  * The implementation does not use OrtLiteCustom as the input shape (first input)
- * is expected to be on CPU wheeras the other outputs are expected to be on CUDA.
+ * is expected to be on CPU whereas the other outputs are expected to be on CUDA.
  */
 template <typename T>
 struct ScatterNDOfShapeKernel {
@@ -46,18 +58,18 @@ template <typename T>
 struct ScatterNDOfShapeOp : Ort::CustomOpBase<ScatterNDOfShapeOp<T>, ScatterNDOfShapeKernel<T>> {
   typedef Ort::CustomOpBase<ScatterNDOfShapeOp<T>, ScatterNDOfShapeKernel<T>> parent_type;
   ScatterNDOfShapeOp() : parent_type() {}
-  void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
-  const char* GetName() const;
-  const char* GetExecutionProviderType() const;
-
-  std::size_t GetInputTypeCount() const;
-  ONNXTensorElementDataType GetInputType(std::size_t index) const;
-  OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(std::size_t index) const;
-  OrtMemType GetInputMemoryType(std::size_t index) const;
-
-  std::size_t GetOutputTypeCount() const;
-  ONNXTensorElementDataType GetOutputType(std::size_t index) const;
-  OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(std::size_t index) const;
+  void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const { return std::make_unique<ScatterNDOfShapeKernel<T>>(api, info).release(); }
+  const char* GetName() const { return "ScatterNDOfShape"; }
+  const char* GetExecutionProviderType() const { return "CUDAExecutionProvider"; }
+
+  std::size_t GetInputTypeCount() const { return 3; }
+  ONNXTensorElementDataType GetInputType(std::size_t index) const { return index == 2 ? onnx_type<T>() : ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; }
+  OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(std::size_t index) const { return INPUT_OUTPUT_REQUIRED; }
+  OrtMemType GetInputMemoryType(std::size_t index) const { return index == 0 ? OrtMemTypeCPUInput : OrtMemTypeDefault; }
+
+  std::size_t GetOutputTypeCount() const { return 1; }
+  ONNXTensorElementDataType GetOutputType(std::size_t index) const { return onnx_type<T>(); }
+  OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(std::size_t index) const { return INPUT_OUTPUT_REQUIRED; }
 };
 
 } // namespace contrib
diff --git a/test/cuda/test_cudaops.py b/test/cuda/test_cudaops.py
index f8867cd6..1c946077 100644
--- a/test/cuda/test_cudaops.py
+++ b/test/cuda/test_cudaops.py
@@ -20,25 +20,19 @@ def _run(self, shape, indices, updates, reduction=None, strategy=None):
 
         return (y,)
 
-
 class TestCudaOps(unittest.TestCase):
     @staticmethod
-    def _create_negpos_test_model(domain='ai.onnx.contrib'):
+    def _create_negpos_test_model(domain="ai.onnx.contrib"):
         nodes = [
-            helper.make_node('Identity', ['x'], ['identity1']),
-            helper.make_node(
-                'NegPos', ['identity1'], ['neg', 'pos'],
-                domain=domain)
+            helper.make_node("Identity", ["x"], ["identity1"]),
+            helper.make_node("NegPos", ["identity1"], ["neg", "pos"], domain=domain),
         ]
 
-        input0 = helper.make_tensor_value_info(
-            'x', onnx_proto.TensorProto.FLOAT, [None, None])
-        output1 = helper.make_tensor_value_info(
-            'neg', onnx_proto.TensorProto.FLOAT, [None, None])
-        output2 = helper.make_tensor_value_info(
-            'pos', onnx_proto.TensorProto.FLOAT, [None, None])
+        input0 = helper.make_tensor_value_info("x", onnx_proto.TensorProto.FLOAT, [None, None])
+        output1 = helper.make_tensor_value_info("neg", onnx_proto.TensorProto.FLOAT, [None, None])
+        output2 = helper.make_tensor_value_info("pos", onnx_proto.TensorProto.FLOAT, [None, None])
 
-        graph = helper.make_graph(nodes, 'test0', [input0], [output1, output2])
+        graph = helper.make_graph(nodes, "test0", [input0], [output1, output2])
         model = make_onnx_model(graph)
         return model
 
@@ -47,88 +41,67 @@ def test_cuda_negpos(self):
         so.register_custom_ops_library(_get_library_path())
         onnx_model = self._create_negpos_test_model()
         self.assertIn('op_type: "NegPos"', str(onnx_model))
-        sess = _ort.InferenceSession(onnx_model.SerializeToString(),
-                                     so,
-                                     providers=['CUDAExecutionProvider'])
-        x = np.array([[0., 1., 1.5], [7., 8., -5.5]]).astype(np.float32)
-        neg, pos = sess.run(None, {'x': x})
+        sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=["CUDAExecutionProvider"])
+        x = np.array([[0.0, 1.0, 1.5], [7.0, 8.0, -5.5]]).astype(np.float32)
+        neg, pos = sess.run(None, {"x": x})
         diff = x - (neg + pos)
         assert_almost_equal(diff, np.zeros(diff.shape))
 
     @staticmethod
-    def _create_fastgelu_test_model(domain='ai.onnx.contrib'):
-        nodes = [
-            helper.make_node(
-                'FastGelu', ['x', 'bias'], ['y'],
-                domain=domain)
-        ]
+    def _create_fastgelu_test_model(domain="ai.onnx.contrib"):
+        nodes = [helper.make_node("FastGelu", ["x", "bias"], ["y"], domain=domain)]
 
-        input0 = helper.make_tensor_value_info(
-            'x', onnx_proto.TensorProto.FLOAT, [])
-        input1 = helper.make_tensor_value_info(
-            'bias', onnx_proto.TensorProto.FLOAT, [])
-        output0 = helper.make_tensor_value_info(
-            'y', onnx_proto.TensorProto.FLOAT, [])
+        input0 = helper.make_tensor_value_info("x", onnx_proto.TensorProto.FLOAT, [])
+        input1 = helper.make_tensor_value_info("bias", onnx_proto.TensorProto.FLOAT, [])
+        output0 = helper.make_tensor_value_info("y", onnx_proto.TensorProto.FLOAT, [])
 
-        graph = helper.make_graph(nodes, 'test1', [input0, input1], [output0])
+        graph = helper.make_graph(nodes, "test1", [input0, input1], [output0])
         model = make_onnx_model(graph)
         return model
 
     @staticmethod
-    def _create_fastgelu_test_model_f16(domain='ai.onnx.contrib'):
-        nodes = [
-            helper.make_node(
-                'FastGelu', ['x', 'bias'], ['y'],
-                domain=domain)
-        ]
+    def _create_fastgelu_test_model_f16(domain="ai.onnx.contrib"):
+        nodes = [helper.make_node("FastGelu", ["x", "bias"], ["y"], domain=domain)]
 
-        input0 = helper.make_tensor_value_info(
-            'x', onnx_proto.TensorProto.FLOAT16, [])
-        input1 = helper.make_tensor_value_info(
-            'bias', onnx_proto.TensorProto.FLOAT16, [])
-        output0 = helper.make_tensor_value_info(
-            'y', onnx_proto.TensorProto.FLOAT16, [])
+        input0 = helper.make_tensor_value_info("x", onnx_proto.TensorProto.FLOAT16, [])
+        input1 = helper.make_tensor_value_info("bias", onnx_proto.TensorProto.FLOAT16, [])
+        output0 = helper.make_tensor_value_info("y", onnx_proto.TensorProto.FLOAT16, [])
 
-        graph = helper.make_graph(nodes, 'test1', [input0, input1], [output0])
+        graph = helper.make_graph(nodes, "test1", [input0, input1], [output0])
         model = make_onnx_model(graph)
         return model
 
     def test_cuda_fastgelu(self):
         eps = _ort.get_available_providers()
-        if 'CUDAExecutionProvider' in eps:
+        if "CUDAExecutionProvider" in eps:
             so = _ort.SessionOptions()
             so.register_custom_ops_library(_get_library_path())
             onnx_model = self._create_fastgelu_test_model()
             self.assertIn('op_type: "FastGelu"', str(onnx_model))
-            sess = _ort.InferenceSession(onnx_model.SerializeToString(),
-                                         so,
-                                         providers=['CUDAExecutionProvider'])
-            x = np.array([0., 1., 2., 3., 4., 5.]).astype(np.float32)
+            sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=["CUDAExecutionProvider"])
+            x = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]).astype(np.float32)
             bias = np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5]).astype(np.float32)
-            expected_y = np.array([0., 0.9505811, 2.1696784, 3.298689, 4.399991, 5.5]).astype(np.float32)
-            y = sess.run(None, {'x': x, 'bias':bias})[0]
+            expected_y = np.array([0.0, 0.9505811, 2.1696784, 3.298689, 4.399991, 5.5]).astype(np.float32)
+            y = sess.run(None, {"x": x, "bias": bias})[0]
             assert_almost_equal(y, expected_y)
         else:
-            print ('CUDAExecutionProvider not available, test_cuda_fastgelu skipped.')
+            print("CUDAExecutionProvider not available, test_cuda_fastgelu skipped.")
 
     def test_cuda_fastgelu_f16(self):
         eps = _ort.get_available_providers()
-        if 'CUDAExecutionProvider' in eps:
+        if "CUDAExecutionProvider" in eps:
             so = _ort.SessionOptions()
             so.register_custom_ops_library(_get_library_path())
             onnx_model = self._create_fastgelu_test_model_f16()
             self.assertIn('op_type: "FastGelu"', str(onnx_model))
-            sess = _ort.InferenceSession(onnx_model.SerializeToString(),
-                                         so,
-                                         providers=['CUDAExecutionProvider'])
-            x = np.array([0., 1., 2., 3., 4., 5.]).astype(np.float16)
+            sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=["CUDAExecutionProvider"])
+            x = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]).astype(np.float16)
             bias = np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5]).astype(np.float16)
-            expected_y = np.array([0., 0.95, 2.17, 3.299, 4.4, 5.5]).astype(np.float16)
-            y = sess.run(None, {'x': x, 'bias':bias})[0]
+            expected_y = np.array([0.0, 0.95, 2.17, 3.299, 4.4, 5.5]).astype(np.float16)
+            y = sess.run(None, {"x": x, "bias": bias})[0]
             assert_almost_equal(y, expected_y)
         else:
-            print ('CUDAExecutionProvider not available, test_cuda_fastgelu_f16 skipped.')
-
+            print("CUDAExecutionProvider not available, test_cuda_fastgelu_f16 skipped.")
 
     def _scatternd_of_shape_cuda(self, reduction, line, itype):
         import onnxruntime
@@ -146,9 +119,7 @@ def _scatternd_of_shape_cuda(self, reduction, line, itype):
                 "nd",
                 [
                     helper.make_tensor_value_info("data", itype, [None, None, None]),
-                    helper.make_tensor_value_info(
-                        "indices", TensorProto.INT64, [None, None]
-                    ),
+                    helper.make_tensor_value_info("indices", TensorProto.INT64, [None, None]),
                     helper.make_tensor_value_info("updates", itype, [None, None, None]),
                 ],
                 [helper.make_tensor_value_info("y", itype, [None, None, None])],
@@ -168,22 +139,20 @@ def _scatternd_of_shape_cuda(self, reduction, line, itype):
                        inputs=["shape", "indices", "updates"],
                        outputs=["y"],
                        reduction=reduction,
-                       domain="onnx_extended.ortops.optim.cuda",
+                       domain="ai.onnx.contrib",
                    )
                ],
                "nd",
                [
                    helper.make_tensor_value_info("shape", TensorProto.INT64, [None]),
-                   helper.make_tensor_value_info(
-                       "indices", TensorProto.INT64, [None, None]
-                   ),
+                   helper.make_tensor_value_info("indices", TensorProto.INT64, [None, None]),
                    helper.make_tensor_value_info("updates", itype, [None, None, None]),
                ],
                [helper.make_tensor_value_info("y", itype, [None, None, None])],
            ),
            opset_imports=[
                helper.make_opsetid("", 18),
-               helper.make_opsetid("onnx_extended.ortops.optim.cuda", 1),
+               helper.make_opsetid("ai.onnx.contrib", 1),
            ],
            ir_version=9,
        )
@@ -195,9 +164,7 @@ def _scatternd_of_shape_cuda(self, reduction, line, itype):
         updates = np.arange(18).reshape((3, 2, 3)).astype(dtype)
 
         feeds1 = dict(data=data, indices=indices, updates=updates)
-        feeds2 = dict(
-            shape=np.array([2, 2, 3], dtype=np.int64), indices=indices, updates=updates
-        )
+        feeds2 = dict(shape=np.array([2, 2, 3], dtype=np.int64), indices=indices, updates=updates)
 
         ref = ReferenceEvaluator(model1, new_ops=[ScatterNDOfShape])
         expected = ref.run(None, feeds1)[0]
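
Note: the rewritten _ComputeNoAtomic leans on CudaT<T>::MappedType from "cuda_type.h",
a header this patch references but does not show. A minimal sketch of the trait it is
assumed to provide, mapping the ORT-facing element type to the native CUDA type used
inside the kernel (names and include paths below are assumptions inferred from the
call sites, not the actual header):

// cuda_type.h -- sketch of assumed contents, not part of this patch
#include <cuda_fp16.h>
#include "custom_op/onnxruntime_f16.h"

namespace contrib {

template <typename T>
struct CudaT {
  using MappedType = T;  // native types (float, ...) map to themselves
};

template <>
struct CudaT<Ort::Custom::MFloat16> {
  using MappedType = half;  // ORT's float16 wrapper maps to CUDA's half
};

}  // namespace contrib

With such a trait, a single templated _ComputeNoAtomic replaces the per-type
specializations deleted above: the cast to TT* plays the role the removed
(half*) casts played for the float16 path only.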
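For checking the CUDA results by hand, the operator built in _scatternd_of_shape_cuda
creates a zero tensor of the requested shape (the ConstantOfShape half of the fusion)
and then accumulates each row of updates at the row index given by indices (the
ScatterND half). A host-side reference, a sketch only, assuming reduction="add" and
indices of shape [n, 1] addressing the first axis as in the tests; the helper is
illustrative and not part of the patch:

// Reference semantics of ScatterNDOfShape(reduction="add") for [n, 1] indices.
#include <cstdint>
#include <vector>

template <typename T>
std::vector<T> scatter_nd_of_shape_ref(const std::vector<int64_t>& shape,    // input 0: output shape
                                       const std::vector<int64_t>& indices,  // input 1: n row indices
                                       const std::vector<T>& updates) {      // input 2: n * stride values
  size_t total = 1;
  for (int64_t d : shape) total *= static_cast<size_t>(d);
  const size_t stride = total / static_cast<size_t>(shape[0]);  // elements per output row
  std::vector<T> out(total, T(0));                              // ConstantOfShape: start from zeros
  for (size_t i = 0; i < indices.size(); ++i)                   // ScatterND with add-reduction
    for (size_t j = 0; j < stride; ++j)
      out[static_cast<size_t>(indices[i]) * stride + j] += updates[i * stride + j];
  return out;
}

With shape=[2, 2, 3] and updates=arange(18) as in feeds2 above, three update rows land
in two output rows, so some index necessarily repeats and accumulates into the same
row -- exactly the case the kernel's atomic and non-atomic strategies must handle.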