Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…sions into elem
  • Loading branch information
xadupre committed Jun 18, 2024
2 parents 76a8857 + bef5f07 commit 8c9b408
Show file tree
Hide file tree
Showing 19 changed files with 272 additions and 30 deletions.
17 changes: 10 additions & 7 deletions .pipelines/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,12 @@ stages:
inputs:
versionSpec: '3.x'
disableDownloadFromRegistry: true
addToPath: false
addToPath: true
architecture: 'x64'

- script: |
python -m pip install --upgrade setuptools pip
python -m pip install numpy
python -m pip install 'numpy < 2.0.0'
export OCOS_NO_OPENCV=1
export OCOS_SCB_DEBUG=1
CPU_NUMBER=8 python -m pip install -e .
Expand Down Expand Up @@ -322,6 +322,7 @@ stages:
python -m pip install --upgrade pip
python -m pip install --upgrade setuptools
python -m pip install --upgrade wheel
python -m pip install 'numpy < 2.0.0'
python -m pip install onnxruntime==$(ort.version)
displayName: Install requirements
Expand Down Expand Up @@ -507,13 +508,13 @@ stages:
inputs:
versionSpec: '3.x'
disableDownloadFromRegistry: true
addToPath: false
addToPath: true
architecture: 'x64'
displayName: Use ADO python task

- script: |
python -m pip install --upgrade setuptools pip
python -m pip install numpy
python -m pip install "numpy < 2.0.0"
set OCOS_NO_OPENCV=1
set OCOS_SCB_DEBUG=1
python -m pip install -v -e .
Expand Down Expand Up @@ -570,7 +571,9 @@ stages:

- script: |
set CUDA_PATH=$(Agent.TempDirectory)\v11.8
call .\build.bat -T cuda="%CUDA_PATH%" -DOCOS_ENABLE_CTEST=ON -DOCOS_USE_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=70;86^
call .\build.bat -T cuda="%CUDA_PATH%" -DOCOS_ENABLE_CTEST=ON^
-DCMAKE_CUDA_FLAGS_INIT=-allow-unsupported-compiler^
-DOCOS_USE_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=70;86^
-DOCOS_ONNXRUNTIME_VERSION="$(ORT_VERSION)" -DONNXRUNTIME_PKG_DIR=.\onnxruntime-win-x64-gpu-$(ORT_VERSION)
displayName: build the customop library with onnxruntime
Expand All @@ -583,14 +586,14 @@ stages:
inputs:
versionSpec: '3.x'
disableDownloadFromRegistry: true
addToPath: false
addToPath: true
architecture: 'x64'
displayName: Use ADO python task

- script: |
set CUDA_PATH=$(Agent.TempDirectory)\v11.8
python -m pip install --upgrade setuptools pip
python -m pip install numpy coloredlogs flatbuffers packaging protobuf sympy
python -m pip install "numpy < 2.0.0" coloredlogs flatbuffers packaging protobuf sympy
python -m pip install onnxruntime-gpu==$(ORT_VERSION)
python -m pip install -v --config-settings "ortx-user-option=use-cuda,cuda_archs=70;86" .
displayName: Build and install onnxruntime-extensions CUDA package.
Expand Down
4 changes: 2 additions & 2 deletions .pipelines/ci_optional.yml
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ stages:
inputs:
versionSpec: '3.x'
disableDownloadFromRegistry: true
addToPath: false
addToPath: true
architecture: 'x64'
displayName: Use ADO python task

Expand Down Expand Up @@ -172,7 +172,7 @@ stages:
inputs:
versionSpec: '3.x'
disableDownloadFromRegistry: true
addToPath: false
addToPath: true
architecture: 'x64'
displayName: Use ADO python task

Expand Down
2 changes: 1 addition & 1 deletion .pipelines/wheels_linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ parameters:
- name: ExtraEnv
displayName: 'Extra env variable set to CIBW_ENVIRONMENT, in form of "A=1 B=2 C=3"'
type: string
default: 'OCOS_ENABLE_AZURE=1'
default: 'OCOS_ENABLE_AZURE=0'

jobs:
- job: linux_x86_64
Expand Down
26 changes: 16 additions & 10 deletions .pyproject/cmdclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ def build_cmake(self, extension):
if sys.platform == "win32":
cuda_path = os.environ.get("CUDA_PATH")
cmake_args += [f'-T cuda={cuda_path}']
# TODO: temporarily add a flag for MSVC 19.40
cmake_args += ['-DCMAKE_CUDA_FLAGS_INIT=-allow-unsupported-compiler']
f_ver = ext_fullpath.parent / "_version.py"
with f_ver.open('a') as _f:
_f.writelines(["\n", f"cuda = \"{cuda_ver}\"", "\n"])
Expand All @@ -235,7 +237,8 @@ def build_cmake(self, extension):
else:
smi = _load_nvidia_smi()
if not smi:
raise RuntimeError(f"Cannot detect the CUDA archs from your machine, please specify it by yourself.")
raise RuntimeError(
"Cannot detect the CUDA archs from your machine, please specify it manually.")
cmake_args += ['-DCMAKE_CUDA_ARCHITECTURES=' + smi]

# CMake lets you override the generator - we need to check this.
Expand Down Expand Up @@ -274,7 +277,6 @@ def build_cmake(self, extension):
cmake_args += [
"-DCMAKE_OSX_ARCHITECTURES={}".format(";".join(archs))]


# overwrite the Python module info if the auto-detection doesn't work.
# export Python3_INCLUDE_DIRS=/opt/python/cp38-cp38
# export Python3_LIBRARIES=/opt/python/cp38-cp38
Expand All @@ -292,14 +294,18 @@ def build_cmake(self, extension):
'--parallel' + ('' if cpu_number is None else ' ' + cpu_number)
]
cmake_exe = 'cmake'
# unlike Linux/macOS, cmake pip package on Windows fails to build some 3rd party dependencies.
# so we have to use the cmake installed from Visual Studio.
if os.environ.get(VSINSTALLDIR_NAME):
cmake_exe = os.environ[VSINSTALLDIR_NAME] + \
'Common7\\IDE\\CommonExtensions\\Microsoft\\CMake\\CMake\\bin\\cmake.exe'
# Add this cmake directory into PATH to make sure the child-process still find it.
os.environ['PATH'] = os.path.dirname(
cmake_exe) + os.pathsep + os.environ['PATH']
# if sys.platform == "win32":
# # unlike Linux/macOS, cmake pip package on Windows fails to build some 3rd party dependencies.
# # so we have to use the cmake from a standalone installation or the one from Visual Studio.
# standalone_cmake = os.path.join(os.environ.get("ProgramFiles"), "\\CMake\\bin\\cmake.exe")
# if os.path.exists(standalone_cmake):
# cmake_exe = standalone_cmake
# elif os.environ.get(VSINSTALLDIR_NAME):
# cmake_exe = os.environ[VSINSTALLDIR_NAME] + \
# 'Common7\\IDE\\CommonExtensions\\Microsoft\\CMake\\CMake\\bin\\cmake.exe'
# # Add this cmake directory into PATH to make sure the child-process still find it.
# os.environ['PATH'] = os.path.dirname(
# cmake_exe) + os.pathsep + os.environ['PATH']

self.spawn([cmake_exe, '-S', str(project_dir),
'-B', str(build_temp)] + cmake_args)
Expand Down
14 changes: 9 additions & 5 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ if(CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0")
cmake_policy(SET CMP0077 NEW)
endif()

# Keep the pre-3.30 behavior of FetchContent_Populate() to avoid the warning
# "Calling FetchContent_Populate(GSL) is deprecated".
# cmake_policy() requires the SET keyword before the policy id; without it
# CMake rejects the call with an "unknown first argument" error.
if(CMAKE_VERSION VERSION_GREATER_EQUAL "3.30.0")
  cmake_policy(SET CMP0169 OLD)
endif()

# Needed for Java
set(CMAKE_C_STANDARD 99)

Expand Down Expand Up @@ -90,12 +95,11 @@ set(OCOS_ONNXRUNTIME_PKG_URI "" CACHE STRING
"Specify the onnxruntime C++ shared library zip package path, like ./onnxruntime-win-x64-1.16.0.zip")
set(OCOS_BUILD_PRESET "" CACHE STRING
"Specify the build preset cmake settings file path, like 'token_api_only' which includes ./cmake/presets/token_api_only.cmake")
# TODO: Remove the following statements if AzureOp build is enabled by default.
# If build_buildid environment varaible is set, which means this is a CI build, then always enable AzureOp.
# or it is enabled when OCOS_ENABLE_AZURE is set, which means the user explicitly enables it.
if ((DEFINED ENV{BUILD_BUILDID}) OR (DEFINED ENV{OCOS_ENABLE_AZURE}))

# AzureOp is enabled when the environment variable OCOS_ENABLE_AZURE is defined (e.g. OCOS_ENABLE_AZURE=1)
if (DEFINED ENV{OCOS_ENABLE_AZURE})
set(OCOS_ENABLE_AZURE ON CACHE INTERNAL "" FORCE)
message(STATUS "=> AzureOp is enabled by default.")
message(STATUS "=> AzureOp is enabled env variable.")
endif()

function(disable_all_operators)
Expand Down
11 changes: 10 additions & 1 deletion build.bat
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
@ECHO OFF
SETLOCAL ENABLEDELAYEDEXPANSION

IF NOT EXIST "%ProgramFiles%\CMake\bin\cmake.exe" GOTO :FIND_VS
set cmake_exe="%ProgramFiles%\CMake\bin\cmake.exe"

:FIND_VS
IF DEFINED VSINSTALLDIR GOTO :VSDEV_CMD
set _VSFINDER=%~dp0tools\get_vsdevcmd.ps1
for /f "tokens=* USEBACKQ" %%i in (
`powershell -NoProfile -ExecutionPolicy Bypass -File "%_VSFINDER%"`) do call "%%i"

IF NOT DEFINED VSINSTALLDIR GOTO :NOT_FOUND

IF DEFINED cmake_exe GOTO :CMAKE_DEF
set cmake_exe="%VSINSTALLDIR%Common7\IDE\CommonExtensions\Microsoft\CMake\CMake\bin\cmake.exe"

:CMAKE_DEF
IF "%1" == "-A" GOTO :VSDEV_CMD
set GEN_PLATFORM=-A x64

Expand All @@ -16,8 +25,8 @@ IF "%VisualStudioVersion:~0,2%" == "16" GOTO :START_BUILD
set GENERATOR="Visual Studio 17 2022"

:START_BUILD
set cmake_exe="%VSINSTALLDIR%Common7\IDE\CommonExtensions\Microsoft\CMake\CMake\bin\cmake.exe"
mkdir .\out\Windows\ 2>NUL
ECHO %cmake_exe% -G %GENERATOR% %GEN_PLATFORM% %* -B out\Windows -S .
%cmake_exe% -G %GENERATOR% %GEN_PLATFORM% %* -B out\Windows -S .
IF %ERRORLEVEL% NEQ 0 EXIT /B %ERRORLEVEL%
%cmake_exe% --build out\Windows --config RelWithDebInfo
Expand Down
13 changes: 13 additions & 0 deletions cmake/externals/opencv-no-rtti.patch
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,19 @@ index d95e5db163..db185453df 100644
include(cmake/OpenCVCompilerOptions.cmake)

ocv_cmake_hook(POST_COMPILER_OPTIONS)
diff --git a/cmake/OpenCVDetectCXXCompiler.cmake b/cmake/OpenCVDetectCXXCompiler.cmake
index 7f229cde96..92e204a5b9 100644
--- a/cmake/OpenCVDetectCXXCompiler.cmake
+++ b/cmake/OpenCVDetectCXXCompiler.cmake
@@ -171,7 +171,7 @@ elseif(MSVC)
set(OpenCV_RUNTIME vc15)
elseif(MSVC_VERSION MATCHES "^192[0-9]$")
set(OpenCV_RUNTIME vc16)
- elseif(MSVC_VERSION MATCHES "^193[0-9]$")
+ elseif(MSVC_VERSION MATCHES "^19[34][0-9]$")
set(OpenCV_RUNTIME vc17)
else()
message(WARNING "OpenCV does not recognize MSVC_VERSION \"${MSVC_VERSION}\". Cannot set OpenCV_RUNTIME")
diff --git a/modules/core/include/opencv2/core/ocl.hpp b/modules/core/include/opencv2/core/ocl.hpp
index 4503fa00dd..642b0508d0 100644
--- a/modules/core/include/opencv2/core/ocl.hpp
Expand Down
2 changes: 1 addition & 1 deletion docs/custom_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -1312,7 +1312,7 @@ expect(node, inputs=[text, pattern, rewrite], outputs=[y],
## Azure operators
Starting from onnxruntime-extensions 0.12, these Azure operators will be removed from the official onnxruntime-extensions packages. However, they can still be built from source using `cmake -DOCOS_ENABLE_AZURE=ON ...`.
### OpenAIAudioToText
<details>
Expand Down
3 changes: 3 additions & 0 deletions operators/cuda/cuda_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "cuda/fast_gelu.h"
#include "cuda/mul_sigmoid.h"
#include "cuda/negxplus1.h"
#include "cuda/replace_zero.h"
#include "cuda/scatter_nd_of_shape.h"
#include "cuda/transpose_cast.h"
#endif
Expand Down Expand Up @@ -58,6 +59,7 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
CustomCudaStructV2("MulSigmoid", contrib::MulSigmoid<float>),
CustomCudaStructV2("MulSub", MulAndSubFloat32Type),
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
CustomCudaStructV2("ReplaceZero", contrib::ReplaceZero<float>),
CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<float>),
CustomCudaStructV2("SubMul", SubAndMulFloat32Type),
#if ORT_API_VERSION >= 16
Expand All @@ -75,6 +77,7 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
CustomCudaStructV2("MulSigmoid", contrib::MulSigmoid<ortc::MFloat16>),
CustomCudaStructV2("MulSub", MulAndSubFloat16Type),
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>),
CustomCudaStructV2("ReplaceZero", contrib::ReplaceZero<ortc::MFloat16>),
CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<ortc::MFloat16>),
CustomCudaStructV2("SubMul", SubAndMulFloat16Type),
CustomCudaStructV2("Transpose2DCastFP16", Transpose2DCastFloat32ToFloat16Type),
Expand Down
3 changes: 3 additions & 0 deletions operators/cuda/negxplus1.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

namespace contrib {

/**
* NegXPlus1(X) = 1 - X
*/
template <typename T>
struct NegXPlus1 {
template <typename TDict>
Expand Down
51 changes: 51 additions & 0 deletions operators/cuda/replace_zero.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "ocos.h"
#include "replace_zero_impl.cuh"
#include "ortx_common.h"

namespace contrib {

/**
 * Y = ReplaceZero(X, by=c) is equivalent to:
 *
 *     Y = X.copy()
 *     Y[X == 0] = c
 *
 * This pattern usually appears when a tensor is updated with the operators Equal and Where.
 * This kernel avoids the creation of one null tensor.
 *
 * The replacement value is read from the float attribute "by".
 */
template <typename T>
struct ReplaceZero {
  /**
   * Reads the attribute "by" from the node attributes, passing 0 as the
   * fallback value to TryToGetAttributeWithDefault.
   */
  template <typename TDict>
  OrtxStatus OnModelAttach(const TDict& dict) {
    float default_value = 0;
    by_ = dict.TryToGetAttributeWithDefault("by", default_value);
    return {};
  }

  /**
   * Allocates `output` with the shape of `input` and launches the CUDA kernel
   * on the stream of `ctx`.
   *
   * Returns an empty (success) status when the input has no element or when
   * the launch succeeds, and an error status when the element count does not
   * fit the launcher's `int` length parameter or when the launcher reports a
   * CUDA error.
   */
  OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
                     const ortc::Tensor<T>& input,
                     ortc::Tensor<T>& output) const {
    const T* input_data = input.Data();
    auto input_shape = input.Shape();
    T* output_data = output.Allocate(input_shape);
    auto input_length = input.NumberOfElement();
    if (0 == input_length) {
      return {};
    }

    // LaunchReplaceZeroKernel takes the length as an `int`: reject tensors whose
    // element count would be silently truncated by the conversion.
    const int launch_length = static_cast<int>(input_length);
    if (static_cast<decltype(input_length)>(launch_length) != input_length) {
      return OrtxStatus(kOrtxErrorInvalidArgument,
                        "ReplaceZero: the input tensor has too many elements for the CUDA kernel.");
    }

    // The launcher returns the CUDA launch status; surface it instead of dropping it.
    const cudaError_t err = LaunchReplaceZeroKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
                                                       launch_length,
                                                       input_data,
                                                       output_data,
                                                       by_);
    if (err != cudaSuccess) {
      return OrtxStatus(kOrtxErrorInternal,
                        std::string("ReplaceZero: CUDA kernel launch failed: ") + cudaGetErrorString(err));
    }
    return {};
  }

 private:
  // Replacement value for zero elements; overwritten by OnModelAttach.
  float by_{0.0f};
};

} // namespace contrib
68 changes: 68 additions & 0 deletions operators/cuda/replace_zero_impl.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "device_prop.cuh"
#include "utils.cuh"
#include "replace_zero_impl.cuh"
#include "cuda_type.h"

// Integer type used for element counts and indices by the kernels of this file.
// Falls back to a 32-bit signed integer unless CUDA_LONG is already defined.
#ifndef CUDA_LONG
#define CUDA_LONG int32_t
#endif

using namespace Ort::Custom;

// Returns `by` when `x` compares equal to zero, otherwise returns `x` unchanged.
template <typename T>
__device__ __inline__ T _replace_zero(const T x, const T by) {
  if (x == static_cast<T>(0)) {
    return by;
  }
  return x;
}

// Half-precision specialization: returns `by` when `x` is zero, otherwise `x`.
// For __CUDA_ARCH__ >= 700 the comparison is done directly on half values;
// otherwise the value is converted to float first.
template <>
__device__ __inline__ half _replace_zero(const half x, const half by) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
  const bool is_zero = (x == static_cast<half>(0));
#else
  const bool is_zero = (__half2float(x) == 0);
#endif
  return is_zero ? by : x;
}

// Element-wise kernel: each thread handles at most one element.
// Thread `idx` writes output_data[idx] = _replace_zero(input_data[idx], by);
// threads whose index is not below N do nothing.
template <typename T>
__global__ void ReplaceZeroKernel(T* output_data, const T* input_data, CUDA_LONG N, const T by) {
  const CUDA_LONG idx = blockDim.x * blockIdx.x + threadIdx.x;
  if (idx < N) {
    output_data[idx] = _replace_zero(input_data[idx], by);
  }
}

// Converts a float to the target type T.
template <typename T>
T _cast(float value) {
  return static_cast<T>(value);
}

// Half-precision specialization: converts through __float2half.
template <>
half _cast(float value) {
  const half converted = __float2half(value);
  return converted;
}

// Launches ReplaceZeroKernel on `stream` for `input_length` elements.
//
// T is the tensor element type seen by the caller; it is mapped to the device
// type through contrib::CudaT<T>::MappedType and the pointers are reinterpreted
// accordingly. `by` is converted to that device type once on the host.
//
// Launch layout: 1-D grid of ceil(input_length / 256) blocks of 256 threads,
// no dynamic shared memory, one element per thread.
//
// Returns cudaSuccess without launching anything when input_length <= 0,
// otherwise the result of cudaGetLastError() right after the launch.
template <typename T>
cudaError_t _LaunchReplaceZeroKernel(cudaStream_t stream, int input_length, const T* input_data, T* output_data, float by) {
  // Nothing to do: do not query cudaGetLastError() here, it would pop (and
  // report) a pending error that does not belong to this operator.
  if (input_length <= 0)
    return cudaSuccess;
  using TT = typename contrib::CudaT<T>::MappedType;

  const CUDA_LONG N = static_cast<CUDA_LONG>(input_length);

  constexpr int num_threads_per_block = 256;
  // Ceil-division written without `N + num_threads_per_block - 1`, which
  // overflows a signed 32-bit integer when N is close to INT_MAX.
  const int num_blocks = static_cast<int>(N / num_threads_per_block +
                                          (N % num_threads_per_block != 0 ? 1 : 0));

  const TT cby = _cast<TT>(by);
  ReplaceZeroKernel<TT><<<num_blocks, num_threads_per_block, 0, stream>>>(
      reinterpret_cast<TT*>(output_data), reinterpret_cast<const TT*>(input_data), N, cby);
  // Kernel launches do not return a status; launch-configuration errors are
  // retrieved through cudaGetLastError().
  return cudaGetLastError();
}

// float entry point: forwards every argument to the shared templated launcher.
template <>
cudaError_t LaunchReplaceZeroKernel<float>(cudaStream_t stream, int input_length,
                                           const float* input_data, float* output_data, float by) {
  const cudaError_t status = _LaunchReplaceZeroKernel<float>(stream, input_length, input_data, output_data, by);
  return status;
}

// float16 entry point: forwards every argument to the shared templated launcher.
template <>
cudaError_t LaunchReplaceZeroKernel<ortc::MFloat16>(cudaStream_t stream, int input_length,
                                                    const ortc::MFloat16* input_data,
                                                    ortc::MFloat16* output_data, float by) {
  const cudaError_t status =
      _LaunchReplaceZeroKernel<ortc::MFloat16>(stream, input_length, input_data, output_data, by);
  return status;
}
Loading

0 comments on commit 8c9b408

Please sign in to comment.