fix merge conflicts
xadupre committed Jun 6, 2024
2 parents d58c35c + 79f3b04 commit e946e6a
Showing 6 changed files with 183 additions and 6 deletions.
2 changes: 1 addition & 1 deletion operators/cuda/add_mul.h
@@ -29,7 +29,7 @@ struct AddOrMulSharedInput {
auto length_c = tensor_c.NumberOfElement();

T* output_data_ab = output_ab.Allocate(length_a <= length_b ? tensor_b.Shape() : tensor_a.Shape());
T* output_data_ac = output_ab.Allocate(length_a <= length_c ? tensor_c.Shape() : tensor_a.Shape());
T* output_data_ac = output_ac.Allocate(length_a <= length_c ? tensor_c.Shape() : tensor_a.Shape());

if (0 == input_data_a || 0 == input_data_b || 0 == input_data_c) {
return {};
7 changes: 6 additions & 1 deletion operators/cuda/cuda_ops.cc
@@ -8,6 +8,7 @@
#include "cuda/fast_gelu.h"
#include "cuda/negxplus1.h"
#include "cuda/replace_zero.h"
#include "cuda/transpose_cast.h"
#endif

FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
@@ -18,6 +19,8 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
#if ORT_API_VERSION >= 16
using AddSharedInputFloat16Type = typename contrib::AddOrMulSharedInput<ortc::MFloat16, true>;
using MulSharedInputFloat16Type = typename contrib::AddOrMulSharedInput<ortc::MFloat16, false>;
using Transpose2DCastFloat32ToFloat16Type = typename contrib::Transpose2DCast<float, ortc::MFloat16>;
using Transpose2DCastFloat16ToFloat32Type = typename contrib::Transpose2DCast<ortc::MFloat16, float>;
#endif


@@ -37,7 +40,9 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>),
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type),
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>),
CustomCudaStructV2("ReplaceZero", contrib::ReplaceZero<ortc::MFloat16>)
CustomCudaStructV2("ReplaceZero", contrib::ReplaceZero<ortc::MFloat16>),
CustomCudaStructV2("Transpose2DCastFP16", Transpose2DCastFloat32ToFloat16Type),
CustomCudaStructV2("Transpose2DCastFP32", Transpose2DCastFloat16ToFloat32Type)
#endif
#endif
);
39 changes: 39 additions & 0 deletions operators/cuda/transpose_cast.h
@@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "ocos.h"
#include "transpose_cast_impl.cuh"
#include "ortx_common.h"

namespace contrib {

template <typename TIN, typename TOUT>
struct Transpose2DCast {
template <typename TDict>
OrtxStatus OnModelAttach(const TDict& /*dict*/) {
return {};
}
OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
const ortc::Tensor<TIN>& input,
ortc::Tensor<TOUT>& output) const {
const TIN* input_data = input.Data();
auto shape = input.Shape();
if (shape.size() != 2) {
ORTX_CXX_API_THROW("Input must be a 2D tensor", ORT_RUNTIME_EXCEPTION);
}
int n_rows = static_cast<int>(shape[0]);
int n_cols = static_cast<int>(shape[1]);

std::vector<int64_t> new_shape{static_cast<int64_t>(n_cols), static_cast<int64_t>(n_rows)};
TOUT* output_data = output.Allocate(new_shape);
if (0 == n_rows || 0 == n_cols) {
return {};
}
LaunchTranspose2DCastKernel<TIN, TOUT>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
n_rows, n_cols, input_data, output_data);
return {};
}
};

} // namespace contrib
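
As a point of reference (not part of the commit), the operator's semantics can be sketched in a few lines of NumPy: transpose a 2D tensor and cast it to the other floating-point type in one pass. The helper name below is illustrative only.

import numpy as np

def transpose2d_cast_reference(x, out_dtype):
    # NumPy equivalent of Transpose2DCast: transpose a 2D array, then cast.
    assert x.ndim == 2, "Input must be a 2D tensor"
    return np.ascontiguousarray(x.T).astype(out_dtype)

x = np.arange(6, dtype=np.float32).reshape(2, 3)
y = transpose2d_cast_reference(x, np.float16)  # shape (3, 2), dtype float16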
56 changes: 56 additions & 0 deletions operators/cuda/transpose_cast_impl.cu
@@ -0,0 +1,56 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "device_prop.cuh"
#include "utils.cuh"
#include "transpose_cast_impl.cuh"
#include "cuda_type.h"

using namespace Ort::Custom;

#define TILE_DIM 32
#define BLOCK_ROWS 8

template <typename TOUT, typename TIN>
__global__ void Transpose2DCastKernel(TOUT *output_data, const TIN *input_data, int n_rows, int n_cols) {
__shared__ TIN tile[TILE_DIM][TILE_DIM + 1];

int x = blockIdx.x * TILE_DIM + threadIdx.x;
int y = blockIdx.y * TILE_DIM + threadIdx.y;
// int width = gridDim.x * TILE_DIM;

for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS)
tile[threadIdx.y + j][threadIdx.x] = input_data[(y + j) * n_cols + x];

__syncthreads();

x = blockIdx.y * TILE_DIM + threadIdx.x; // transpose block offset
y = blockIdx.x * TILE_DIM + threadIdx.y;

for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS)
output_data[(y + j) * n_rows + x] = (TOUT)(tile[threadIdx.x][threadIdx.y + j]);
}

template <typename TIN, typename TOUT>
cudaError_t _LaunchTranspose2DCastKernel(cudaStream_t stream, int n_rows, int n_cols,
const TIN* input, TOUT* output) {
dim3 dimGrid((n_cols + TILE_DIM - 1) / TILE_DIM, (n_rows + TILE_DIM - 1) / TILE_DIM, 1);
dim3 dimBlock(TILE_DIM, BLOCK_ROWS, 1);
using TTIN = typename contrib::CudaT<TIN>::MappedType;
using TTOUT = typename contrib::CudaT<TOUT>::MappedType;
Transpose2DCastKernel<TTOUT, TTIN><<<dimGrid, dimBlock, TILE_DIM * TILE_DIM + TILE_DIM, stream>>>(
reinterpret_cast<TTOUT*>(output), reinterpret_cast<const TTIN*>(input), n_rows, n_cols);
return cudaGetLastError();
}

template <>
cudaError_t LaunchTranspose2DCastKernel<float, ortc::MFloat16>(cudaStream_t stream, int n_rows, int n_cols,
const float* input, ortc::MFloat16* output) {
return _LaunchTranspose2DCastKernel(stream, n_rows, n_cols, input, output);
}

template <>
cudaError_t LaunchTranspose2DCastKernel<ortc::MFloat16, float>(cudaStream_t stream, int n_rows, int n_cols,
const ortc::MFloat16* input, float* output) {
return _LaunchTranspose2DCastKernel(stream, n_rows, n_cols, input, output);
}
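
The launcher splits the matrix into 32x32 tiles; each 32x8 thread block moves one tile through shared memory padded by one extra column (TILE_DIM + 1) to avoid bank conflicts. Because the kernel indexes full tiles without bounds checks, the dimensions are effectively assumed to be multiples of TILE_DIM, which the test below respects (32 x 96). A minimal sketch of the launch geometry mirroring _LaunchTranspose2DCastKernel; illustrative only, not part of the commit:

TILE_DIM = 32
BLOCK_ROWS = 8

def launch_geometry(n_rows, n_cols):
    # One block per 32x32 tile; each block is 32x8 threads, so every thread
    # copies TILE_DIM / BLOCK_ROWS = 4 elements of its tile.
    grid = ((n_cols + TILE_DIM - 1) // TILE_DIM, (n_rows + TILE_DIM - 1) // TILE_DIM, 1)
    block = (TILE_DIM, BLOCK_ROWS, 1)
    return grid, block

print(launch_geometry(32, 96))  # ((3, 1, 1), (32, 8, 1))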
9 changes: 9 additions & 0 deletions operators/cuda/transpose_cast_impl.cuh
@@ -0,0 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include <cuda.h>
#include <cuda_runtime.h>

template <typename TIN, typename TOUT>
cudaError_t LaunchTranspose2DCastKernel(cudaStream_t stream, int n_rows, int n_cols, const TIN* input, TOUT* output);
76 changes: 72 additions & 4 deletions test/cuda/test_cudaops.py
@@ -21,6 +21,20 @@ def _run(self, X):
return (1 - X,)


class Transpose2DCastFP16(OpRun):
op_domain = "ai.onnx.contrib"

def _run(self, X):
return (X.T.astype(np.float16),)


class Transpose2DCastFP32(OpRun):
op_domain = "ai.onnx.contrib"

def _run(self, X):
return (X.T.astype(np.float32),)


class TestCudaOps(unittest.TestCase):
@staticmethod
def _create_negpos_test_model(domain="ai.onnx.contrib"):
@@ -151,8 +165,6 @@ def test_cuda_negxplus1(self):
self._negxplus1_cuda(TensorProto.FLOAT16)

def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3, 2, 3), shapec=(3, 2, 3)):
from ai.onnx.contrib import get_ort_ext_libs

model1 = helper.make_model(
helper.make_graph(
[
@@ -212,7 +224,7 @@ def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3,
expected = ref.run(None, feeds1)

opts = _ort.SessionOptions()
opts.register_custom_ops_library(get_ort_ext_libs()[0])
opts.register_custom_ops_library(_get_library_path())
sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
got = sess.run(None, feeds1)
for i in range(2):
@@ -262,6 +274,62 @@ def test_add_shared_input_cuda_broadcast2(self):
shapec=(3, 2, 3),
)

def _transpose_cast_cuda(self, itype):
dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
itype2 = TensorProto.FLOAT if itype == TensorProto.FLOAT16 else TensorProto.FLOAT16
model1 = helper.make_model(
helper.make_graph(
[
helper.make_node("Transpose", ["X"], ["t"], perm=[1, 0]),
helper.make_node("Cast", ["t"], ["Y"], to=itype2),
],
"nd",
[helper.make_tensor_value_info("X", itype, [None, None])],
[helper.make_tensor_value_info("Y", itype2, [None, None])],
),
opset_imports=[helper.make_opsetid("", 18)],
ir_version=9,
)

model2 = helper.make_model(
helper.make_graph(
[
helper.make_node(
("Transpose2DCastFP16" if itype2 == TensorProto.FLOAT16 else "Transpose2DCastFP32"),
["X"],
["Y"],
domain="ai.onnx.contrib",
)
],
"nd",
[helper.make_tensor_value_info("X", itype, [None, None])],
[helper.make_tensor_value_info("Y", itype2, [None, None])],
),
opset_imports=[
helper.make_opsetid("", 18),
helper.make_opsetid("ai.onnx.contrib", 1),
],
ir_version=9,
)

dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
x = (np.arange(32 * 32 * 3) + 1).reshape((32, 32 * 3)).astype(dtype)

feeds1 = dict(X=x)
ref = ReferenceEvaluator(model1, new_ops=[Transpose2DCastFP16, Transpose2DCastFP32])
expected = ref.run(None, feeds1)[0]

opts = _ort.SessionOptions()
opts.register_custom_ops_library(_get_library_path())
sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
got = sess.run(None, feeds1)[0]
assert_almost_equal(expected, got, decimal=5)

@unittest.skipIf(not has_cuda(), reason="cuda not available")
def test_transpose_cast_cuda(self):
self._transpose_cast_cuda(TensorProto.FLOAT)
self._transpose_cast_cuda(TensorProto.FLOAT16)

def _replace_zero_cuda(self, itype):
dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
model1 = helper.make_model(
@@ -324,4 +392,4 @@ def test_replace_zero_cuda(self):


if __name__ == "__main__":
unittest.main()
unittest.main(verbosity=2)
