Skip to content

Commit

Permalink
fix build
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed May 6, 2024
1 parent 09601fc commit 06e0780
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 212 deletions.
133 changes: 10 additions & 123 deletions operators/cuda/scatter_nd_of_shape.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@
#include <cublas_v2.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include "exceptions.h"
#include "custom_op/onnxruntime_f16.h"

namespace ortc = Ort::Custom;

#include "cuda_type.h"

namespace contrib {

Expand Down Expand Up @@ -65,116 +71,6 @@ addition_inplace_kernel(T* __restrict__ output_data, const int64_t* __restrict__
}
}

//////////////////
// ScatterNDOfShapeOp...
//////////////////

template <typename T>
void* ScatterNDOfShapeOp<T>::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return std::make_unique<ScatterNDOfShapeKernel<T>>(api, info).release();
}

template <typename T>
const char* ScatterNDOfShapeOp<T>::GetName() const {
return "ScatterNDOfShape";
}

template <typename T>
const char* ScatterNDOfShapeOp<T>::GetExecutionProviderType() const {
return "CUDAExecutionProvider";
}

template <typename T>
size_t ScatterNDOfShapeOp<T>::GetInputTypeCount() const { return 3; };

template <>
ONNXTensorElementDataType ScatterNDOfShapeOp<float>::GetInputType(std::size_t index) const {
switch (index) {
case 0:
case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
case 2:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
default:
ORTX_CXX_API_THROW("Wrong input index.", ORT_RUNTIME_EXCEPTION);
}
}

template <>
ONNXTensorElementDataType ScatterNDOfShapeOp<ortc::MFloat16>::GetInputType(std::size_t index) const {
switch (index) {
case 0:
case 1:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
case 2:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
default:
ORTX_CXX_API_THROW("Wrong input index.", ORT_RUNTIME_EXCEPTION);
}
}

template <typename T>
OrtMemType ScatterNDOfShapeOp<T>::GetInputMemoryType(std::size_t index) const {
switch (index) {
case 0:
return OrtMemTypeCPUInput;
case 1:
case 2:
return OrtMemTypeDefault;
default:
ORTX_CXX_API_THROW("Wrong input index.", ORT_RUNTIME_EXCEPTION);
}
}

template <typename T>
OrtCustomOpInputOutputCharacteristic
ScatterNDOfShapeOp<T>::GetInputCharacteristic(std::size_t index) const {
switch (index) {
case 0:
case 1:
case 2:
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
default:
ORTX_CXX_API_THROW("Wrong output index.", ORT_RUNTIME_EXCEPTION);
}
}

template <typename T>
size_t ScatterNDOfShapeOp<T>::GetOutputTypeCount() const { return 1; }

template <>
ONNXTensorElementDataType ScatterNDOfShapeOp<float>::GetOutputType(std::size_t index) const {
// D, scale D
switch (index) {
case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
default:
ORTX_CXX_API_THROW("Wrong output index.", ORT_RUNTIME_EXCEPTION);
}
}

template <>
ONNXTensorElementDataType ScatterNDOfShapeOp<ortc::MFloat16>::GetOutputType(std::size_t index) const {
// D, scale D
switch (index) {
case 0:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
default:
ORTX_CXX_API_THROW("Wrong output index.", ORT_RUNTIME_EXCEPTION);
}
}

template <typename T>
OrtCustomOpInputOutputCharacteristic
ScatterNDOfShapeOp<T>::GetOutputCharacteristic(std::size_t index) const {
switch (index) {
case 0:
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
default:
ORTX_CXX_API_THROW("Wrong output index", ORT_RUNTIME_EXCEPTION);
}
}

///////////////////
// ScatterNDOfShapeKernel
///////////////////
Expand All @@ -201,7 +97,7 @@ template <typename T>
void ScatterNDOfShapeKernel<T>::Compute(OrtKernelContext* context) {
Ort::KernelContext ctx(context);

int n_inputs = ctx.GetInputCount();
int n_inputs = ctx.GetInputCount(); // crashes here... core dumped.
_ENFORCE(n_inputs == 3, "Expecting 3 inputs.");
Ort::ConstValue shape = ctx.GetInput(0);
Ort::ConstValue indices = ctx.GetInput(1);
Expand Down Expand Up @@ -244,18 +140,9 @@ void _ComputeNoAtomic(cudaStream_t stream, const std::vector<int64_t>& input_sha
int threads_per_block, int blocks_per_grid, size_t indice_size, size_t nrows, size_t stride) {
dim3 threads(threads_per_block);
dim3 blocks(blocks_per_grid);
addition_inplace_kernel<T><<<blocks, threads, 0, stream>>>(output_data, indices_data, updates_data, indice_size, nrows, stride);
}

template <>
void _ComputeNoAtomic<ortc::MFloat16>(cudaStream_t stream, const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& indices_shape, ortc::MFloat16* output_data,
const int64_t* indices_data, const ortc::MFloat16* updates_data,
int threads_per_block, int blocks_per_grid, size_t indice_size, size_t nrows, size_t stride) {

dim3 threads(threads_per_block);
dim3 blocks(blocks_per_grid);
addition_inplace_kernel<half><<<blocks, threads, 0, stream>>>((half*)output_data, indices_data, (const half*)updates_data, indice_size, nrows, stride);
using TT = typename CudaT<T>::MappedType;
addition_inplace_kernel<TT><<<blocks, threads, 0, stream>>>((TT*)output_data, indices_data,
(TT*)updates_data, indice_size, nrows, stride);
}

template <typename T>
Expand Down
46 changes: 29 additions & 17 deletions operators/cuda/scatter_nd_of_shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,20 @@

#pragma once

#include "ocos.h"
// #include "cublas_v2.h"
#include <cuda_runtime.h>

#define ORT_API_MANUAL_INIT
#ifdef ORT_SWIFT_PACKAGE_MANAGER_BUILD
#include "onnxruntime/onnxruntime_c_api.h"
#include "onnxruntime/onnxruntime_cxx_api.h"
#else
#include "onnxruntime_c_api.h"
#include "onnxruntime_cxx_api.h"
#endif
#undef ORT_API_MANUAL_INIT
#include "custom_op/onnxruntime_f16.h"

// #include "ocos.h"
// #include "cublas_v2.h"
#include <cuda_runtime.h>

namespace contrib {

Expand All @@ -23,10 +28,17 @@ enum class Reduction : int {
Max = 4,
};

template <typename T>
inline ONNXTensorElementDataType onnx_type();
template <>
inline ONNXTensorElementDataType onnx_type<float>() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }
template <>
inline ONNXTensorElementDataType onnx_type<Ort::Custom::MFloat16>() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; }

/**
* This kernel implementation the fusion of ConstantOfShape and ScatterND.
* The implementation does not use OrtLiteCustom as the input shape (first input)
* is expected to be on CPU wheeras the other outputs are expected to be on CUDA.
* is expected to be on CPU whereas the other outputs are expected to be on CUDA.
*/
template <typename T>
struct ScatterNDOfShapeKernel {
Expand All @@ -46,18 +58,18 @@ template <typename T>
struct ScatterNDOfShapeOp : Ort::CustomOpBase<ScatterNDOfShapeOp<T>, ScatterNDOfShapeKernel<T>> {
typedef Ort::CustomOpBase<ScatterNDOfShapeOp<T>, ScatterNDOfShapeKernel<T>> parent_type;
ScatterNDOfShapeOp() : parent_type() {}
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
const char* GetName() const;
const char* GetExecutionProviderType() const;

std::size_t GetInputTypeCount() const;
ONNXTensorElementDataType GetInputType(std::size_t index) const;
OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(std::size_t index) const;
OrtMemType GetInputMemoryType(std::size_t index) const;

std::size_t GetOutputTypeCount() const;
ONNXTensorElementDataType GetOutputType(std::size_t index) const;
OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(std::size_t index) const;
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const { return std::make_unique<ScatterNDOfShapeKernel<T>>(api, info).release(); }
const char* GetName() const { return "ScatterNDOfShape"; }
const char* GetExecutionProviderType() const { return "CUDAExecutionProvider"; }

std::size_t GetInputTypeCount() const { return 3; }
ONNXTensorElementDataType GetInputType(std::size_t index) const { return index == 2 ? onnx_type<T>() : ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; }
OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(std::size_t index) const { return INPUT_OUTPUT_REQUIRED; }
OrtMemType GetInputMemoryType(std::size_t index) const { return index == 0 ? OrtMemTypeCPUInput : OrtMemTypeDefault; }

std::size_t GetOutputTypeCount() const { return 1; }
ONNXTensorElementDataType GetOutputType(std::size_t index) const { return onnx_type<T>(); }
OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(std::size_t index) const { return INPUT_OUTPUT_REQUIRED; }
};

} // namespace contrib
Loading

0 comments on commit 06e0780

Please sign in to comment.