Skip to content

Commit

Permalink
Add custom kernel ScatterNDOfShape (microsoft#705)
Browse files Browse the repository at this point in the history
* first draft

* clang

* Draft for ScatterNFOfShape

* fix build

* disable test when cuda is missing

* fix implementation

* update test

* add MaskedScatterNdOfShape

* fix merge conflicts
  • Loading branch information
xadupre committed Jun 11, 2024
1 parent 79f3b04 commit f505546
Show file tree
Hide file tree
Showing 5 changed files with 646 additions and 0 deletions.
5 changes: 5 additions & 0 deletions operators/cuda/cuda_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "cuda/add_mul.h"
#include "cuda/fast_gelu.h"
#include "cuda/negxplus1.h"
#include "cuda/scatter_nd_of_shape.h"
#include "cuda/transpose_cast.h"
#endif

Expand All @@ -29,15 +30,19 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
,
CustomCudaStructV2("AddSharedInput", AddSharedInputFloat32Type),
CustomCudaStructV2("FastGelu", contrib::FastGelu<float>),
CustomCudaStructV2("MaskedScatterNDOfShape", contrib::MaskedScatterNDOfShape<float>),
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat32Type),
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<float>),
#if ORT_API_VERSION >= 16

CustomCudaStructV2("AddSharedInput", AddSharedInputFloat16Type),
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::MFloat16>),
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>),
CustomCudaStructV2("MaskedScatterNDOfShape", contrib::MaskedScatterNDOfShape<ortc::MFloat16>),
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type),
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>),
CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<ortc::MFloat16>),
CustomCudaStructV2("Transpose2DCastFP16", Transpose2DCastFloat32ToFloat16Type),
CustomCudaStructV2("Transpose2DCastFP32", Transpose2DCastFloat16ToFloat32Type)
#endif
Expand Down
143 changes: 143 additions & 0 deletions operators/cuda/scatter_nd_of_shape.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "ocos.h"
#include "string_utils.h"
#include "scatter_nd_of_shape_impl.cuh"

namespace contrib {

template <typename T>
struct ScatterNDOfShape {
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
std::string value;
OrtStatusPtr status = OrtW::GetOpAttribute(info, "reduction", value);
if (status != nullptr)
return status;

if (value == "add")
reduction_ = ScatterReduction::Add;
else if (value == "mul")
reduction_ = ScatterReduction::Mul;
else if (value == "min")
reduction_ = ScatterReduction::Min;
else if (value == "max")
reduction_ = ScatterReduction::Max;
else
ORTX_CXX_API_THROW("Unexpected reduction, only Add is implemented.", ORT_RUNTIME_EXCEPTION);

return nullptr;
}

OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx,
const ortc::Tensor<int64_t>& output_shape,
const ortc::Tensor<int64_t>& indices,
const ortc::Tensor<T>& updates,
ortc::Tensor<T>& output) const {
auto& output_shape_shape = output_shape.Shape();
auto& indices_shape = indices.Shape();
auto& updates_shape = updates.Shape();

if (output_shape_shape.size() != 1 || output_shape_shape[0] == 0) {
ORTX_CXX_API_THROW("output shape must be a 1D tensor", ORT_RUNTIME_EXCEPTION);
}
if (indices_shape[indices_shape.size() - 1] != 1) {
ORTX_CXX_API_THROW("last dimension of the indices tensor should be one", ORT_RUNTIME_EXCEPTION);
}

const int64_t* shape_data = output_shape.Data(); // CPU pointer
const int64_t* indices_data = indices.Data(); // GPU pointer
const T* updates_data = updates.Data(); // GPU pointer
std::vector<int64_t> voutput_shape(shape_data, shape_data + output_shape_shape[0]);
T* output_data = output.Allocate(voutput_shape); // GPU pointer
LaunchScatterNDOfShapeKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
voutput_shape,
indices_shape,
indices_data,
updates_data,
output_data,
reduction_);
return nullptr;
}

static OrtMemType GetInputMemoryType(size_t input_index) {
if (input_index == 0) // shape
return OrtMemType::OrtMemTypeCPUInput;
return OrtMemType::OrtMemTypeDefault;
}

ScatterReduction reduction_;
};


template <typename T>
struct MaskedScatterNDOfShape {
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
std::string value;
OrtStatusPtr status = OrtW::GetOpAttribute(info, "reduction", value);
if (status != nullptr)
return status;

if (value == "add")
reduction_ = ScatterReduction::Add;
else if (value == "mul")
reduction_ = ScatterReduction::Mul;
else if (value == "min")
reduction_ = ScatterReduction::Min;
else if (value == "max")
reduction_ = ScatterReduction::Max;
else
ORTX_CXX_API_THROW("Unexpected reduction, only Add is implemented.", ORT_RUNTIME_EXCEPTION);

status = OrtW::GetOpAttribute(info, "maskedValue", masked_value_);
if (status != nullptr)
return status;

return nullptr;
}

OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx,
const ortc::Tensor<int64_t>& output_shape,
const ortc::Tensor<int64_t>& indices,
const ortc::Tensor<T>& updates,
ortc::Tensor<T>& output) const {
auto& output_shape_shape = output_shape.Shape();
auto& indices_shape = indices.Shape();
auto& updates_shape = updates.Shape();

if (output_shape_shape.size() != 1 || output_shape_shape[0] == 0) {
ORTX_CXX_API_THROW("output shape must be a 1D tensor", ORT_RUNTIME_EXCEPTION);
}
if (indices_shape[indices_shape.size() - 1] != 1) {
ORTX_CXX_API_THROW("last dimension of the indices tensor should be one", ORT_RUNTIME_EXCEPTION);
}

const int64_t* shape_data = output_shape.Data(); // CPU pointer
const int64_t* indices_data = indices.Data(); // GPU pointer
const T* updates_data = updates.Data(); // GPU pointer
std::vector<int64_t> voutput_shape(shape_data, shape_data + output_shape_shape[0]);
T* output_data = output.Allocate(voutput_shape); // GPU pointer
LaunchMaskedScatterNDOfShapeKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
voutput_shape,
indices_shape,
indices_data,
updates_data,
output_data,
reduction_,
masked_value_);
return nullptr;
}

static OrtMemType GetInputMemoryType(size_t input_index) {
if (input_index == 0) // shape
return OrtMemType::OrtMemTypeCPUInput;
return OrtMemType::OrtMemTypeDefault;
}

private:
ScatterReduction reduction_;
int64_t masked_value_;
};

} // namespace contrib
Loading

0 comments on commit f505546

Please sign in to comment.