diff --git a/.pipelines/ci.yml b/.pipelines/ci.yml index b1aff406f..8ad782daa 100644 --- a/.pipelines/ci.yml +++ b/.pipelines/ci.yml @@ -121,12 +121,12 @@ stages: inputs: versionSpec: '3.x' disableDownloadFromRegistry: true - addToPath: false + addToPath: true architecture: 'x64' - script: | python -m pip install --upgrade setuptools pip - python -m pip install numpy + python -m pip install 'numpy < 2.0.0' export OCOS_NO_OPENCV=1 export OCOS_SCB_DEBUG=1 CPU_NUMBER=8 python -m pip install -e . @@ -322,6 +322,7 @@ stages: python -m pip install --upgrade pip python -m pip install --upgrade setuptools python -m pip install --upgrade wheel + python -m pip install 'numpy < 2.0.0' python -m pip install onnxruntime==$(ort.version) displayName: Install requirements @@ -507,13 +508,13 @@ stages: inputs: versionSpec: '3.x' disableDownloadFromRegistry: true - addToPath: false + addToPath: true architecture: 'x64' displayName: Use ADO python task - script: | python -m pip install --upgrade setuptools pip - python -m pip install numpy + python -m pip install "numpy < 2.0.0" set OCOS_NO_OPENCV=1 set OCOS_SCB_DEBUG=1 python -m pip install -v -e . @@ -570,7 +571,9 @@ stages: - script: | set CUDA_PATH=$(Agent.TempDirectory)\v11.8 - call .\build.bat -T cuda="%CUDA_PATH%" -DOCOS_ENABLE_CTEST=ON -DOCOS_USE_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=70;86^ + call .\build.bat -T cuda="%CUDA_PATH%" -DOCOS_ENABLE_CTEST=ON^ + -DCMAKE_CUDA_FLAGS_INIT=-allow-unsupported-compiler^ + -DOCOS_USE_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=70;86^ -DOCOS_ONNXRUNTIME_VERSION="$(ORT_VERSION)" -DONNXRUNTIME_PKG_DIR=.\onnxruntime-win-x64-gpu-$(ORT_VERSION) displayName: build the customop library with onnxruntime @@ -583,14 +586,14 @@ stages: inputs: versionSpec: '3.x' disableDownloadFromRegistry: true - addToPath: false + addToPath: true architecture: 'x64' displayName: Use ADO python task - script: | set CUDA_PATH=$(Agent.TempDirectory)\v11.8 python -m pip install --upgrade setuptools pip - python -m pip install numpy coloredlogs flatbuffers packaging protobuf sympy + python -m pip install "numpy < 2.0.0" coloredlogs flatbuffers packaging protobuf sympy python -m pip install onnxruntime-gpu==$(ORT_VERSION) python -m pip install -v --config-settings "ortx-user-option=use-cuda,cuda_archs=70;86" . displayName: Build and install onnxruntime-extensions CUDA package. diff --git a/.pipelines/ci_optional.yml b/.pipelines/ci_optional.yml index 3afcf6a4d..7eaf689a7 100644 --- a/.pipelines/ci_optional.yml +++ b/.pipelines/ci_optional.yml @@ -132,7 +132,7 @@ stages: inputs: versionSpec: '3.x' disableDownloadFromRegistry: true - addToPath: false + addToPath: true architecture: 'x64' displayName: Use ADO python task @@ -172,7 +172,7 @@ stages: inputs: versionSpec: '3.x' disableDownloadFromRegistry: true - addToPath: false + addToPath: true architecture: 'x64' displayName: Use ADO python task diff --git a/.pipelines/wheels_linux.yml b/.pipelines/wheels_linux.yml index a68849093..6464c8645 100644 --- a/.pipelines/wheels_linux.yml +++ b/.pipelines/wheels_linux.yml @@ -2,7 +2,7 @@ parameters: - name: ExtraEnv displayName: 'Extra env variable set to CIBW_ENVIRONMENT, in form of "A=1 B=2 C=3"' type: string - default: 'OCOS_ENABLE_AZURE=1' + default: 'OCOS_ENABLE_AZURE=0' jobs: - job: linux_x86_64 diff --git a/.pyproject/cmdclass.py b/.pyproject/cmdclass.py index 5b38b02c0..f56ad8c1b 100644 --- a/.pyproject/cmdclass.py +++ b/.pyproject/cmdclass.py @@ -226,6 +226,8 @@ def build_cmake(self, extension): if sys.platform == "win32": cuda_path = os.environ.get("CUDA_PATH") cmake_args += [f'-T cuda={cuda_path}'] + # TODO: temporarily add a flag for MSVC 19.40 + cmake_args += ['-DCMAKE_CUDA_FLAGS_INIT=-allow-unsupported-compiler'] f_ver = ext_fullpath.parent / "_version.py" with f_ver.open('a') as _f: _f.writelines(["\n", f"cuda = \"{cuda_ver}\"", "\n"]) @@ -235,7 +237,8 @@ def build_cmake(self, extension): else: smi = _load_nvidia_smi() if not smi: - raise RuntimeError(f"Cannot detect the CUDA archs from your machine, please specify it by yourself.") + raise RuntimeError( + "Cannot detect the CUDA archs from your machine, please specify it manually.") cmake_args += ['-DCMAKE_CUDA_ARCHITECTURES=' + smi] # CMake lets you override the generator - we need to check this. @@ -274,7 +277,6 @@ def build_cmake(self, extension): cmake_args += [ "-DCMAKE_OSX_ARCHITECTURES={}".format(";".join(archs))] - # overwrite the Python module info if the auto-detection doesn't work. # export Python3_INCLUDE_DIRS=/opt/python/cp38-cp38 # export Python3_LIBRARIES=/opt/python/cp38-cp38 @@ -292,14 +294,18 @@ def build_cmake(self, extension): '--parallel' + ('' if cpu_number is None else ' ' + cpu_number) ] cmake_exe = 'cmake' - # unlike Linux/macOS, cmake pip package on Windows fails to build some 3rd party dependencies. - # so we have to use the cmake installed from Visual Studio. - if os.environ.get(VSINSTALLDIR_NAME): - cmake_exe = os.environ[VSINSTALLDIR_NAME] + \ - 'Common7\\IDE\\CommonExtensions\\Microsoft\\CMake\\CMake\\bin\\cmake.exe' - # Add this cmake directory into PATH to make sure the child-process still find it. - os.environ['PATH'] = os.path.dirname( - cmake_exe) + os.pathsep + os.environ['PATH'] + # if sys.platform == "win32": + # # unlike Linux/macOS, cmake pip package on Windows fails to build some 3rd party dependencies. + # # so we have to use the cmake from a standalone installation or the one from Visual Studio. + # standalone_cmake = os.path.join(os.environ.get("ProgramFiles"), "\\CMake\\bin\\cmake.exe") + # if os.path.exists(standalone_cmake): + # cmake_exe = standalone_cmake + # elif os.environ.get(VSINSTALLDIR_NAME): + # cmake_exe = os.environ[VSINSTALLDIR_NAME] + \ + # 'Common7\\IDE\\CommonExtensions\\Microsoft\\CMake\\CMake\\bin\\cmake.exe' + # # Add this cmake directory into PATH to make sure the child-process still find it. + # os.environ['PATH'] = os.path.dirname( + # cmake_exe) + os.pathsep + os.environ['PATH'] self.spawn([cmake_exe, '-S', str(project_dir), '-B', str(build_temp)] + cmake_args) diff --git a/CMakeLists.txt b/CMakeLists.txt index 066e204fb..3b9cb0ead 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,6 +36,11 @@ if(CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0") cmake_policy(SET CMP0077 NEW) endif() +# Avoid warning of Calling FetchContent_Populate(GSL) is deprecated +if(CMAKE_VERSION VERSION_GREATER_EQUAL "3.30.0") + cmake_policy(CMP0169 OLD) +endif() + # Needed for Java set(CMAKE_C_STANDARD 99) @@ -90,12 +95,11 @@ set(OCOS_ONNXRUNTIME_PKG_URI "" CACHE STRING "Specify the onnxruntime C++ shared library zip package path, like ./onnxruntime-win-x64-1.16.0.zip") set(OCOS_BUILD_PRESET "" CACHE STRING "Specify the build preset cmake settings file path, like 'token_api_only' which includes ./cmake/presets/token_api_only.cmake") -# TODO: Remove the following statements if AzureOp build is enabled by default. -# If build_buildid environment varaible is set, which means this is a CI build, then always enable AzureOp. -# or it is enabled when OCOS_ENABLE_AZURE is set, which means the user explicitly enables it. -if ((DEFINED ENV{BUILD_BUILDID}) OR (DEFINED ENV{OCOS_ENABLE_AZURE})) + +# AzureOp can be enabled by environment varaible OCOS_ENABLE_AZURE == 1 +if (DEFINED ENV{OCOS_ENABLE_AZURE}) set(OCOS_ENABLE_AZURE ON CACHE INTERNAL "" FORCE) - message(STATUS "=> AzureOp is enabled by default.") + message(STATUS "=> AzureOp is enabled env variable.") endif() function(disable_all_operators) diff --git a/build.bat b/build.bat index bdd39a1d7..1a5367a7a 100644 --- a/build.bat +++ b/build.bat @@ -1,5 +1,10 @@ @ECHO OFF SETLOCAL ENABLEDELAYEDEXPANSION + +IF NOT EXIST "%ProgramFiles%\CMake\bin\cmake.exe" GOTO :FIND_VS +set cmake_exe="%ProgramFiles%\CMake\bin\cmake.exe" + +:FIND_VS IF DEFINED VSINSTALLDIR GOTO :VSDEV_CMD set _VSFINDER=%~dp0tools\get_vsdevcmd.ps1 for /f "tokens=* USEBACKQ" %%i in ( @@ -7,6 +12,10 @@ for /f "tokens=* USEBACKQ" %%i in ( IF NOT DEFINED VSINSTALLDIR GOTO :NOT_FOUND +IF DEFINED cmake_exe GOTO :CMAKE_DEF +set cmake_exe="%VSINSTALLDIR%Common7\IDE\CommonExtensions\Microsoft\CMake\CMake\bin\cmake.exe" + +:CMAKE_DEF IF "%1" == "-A" GOTO :VSDEV_CMD set GEN_PLATFORM=-A x64 @@ -16,8 +25,8 @@ IF "%VisualStudioVersion:~0,2%" == "16" GOTO :START_BUILD set GENERATOR="Visual Studio 17 2022" :START_BUILD -set cmake_exe="%VSINSTALLDIR%Common7\IDE\CommonExtensions\Microsoft\CMake\CMake\bin\cmake.exe" mkdir .\out\Windows\ 2>NUL +ECHO %cmake_exe% -G %GENERATOR% %GEN_PLATFORM% %* -B out\Windows -S . %cmake_exe% -G %GENERATOR% %GEN_PLATFORM% %* -B out\Windows -S . IF %ERRORLEVEL% NEQ 0 EXIT /B %ERRORLEVEL% %cmake_exe% --build out\Windows --config RelWithDebInfo diff --git a/cmake/externals/opencv-no-rtti.patch b/cmake/externals/opencv-no-rtti.patch index 487f9296e..c94f89d5d 100644 --- a/cmake/externals/opencv-no-rtti.patch +++ b/cmake/externals/opencv-no-rtti.patch @@ -39,6 +39,19 @@ index d95e5db163..db185453df 100644 include(cmake/OpenCVCompilerOptions.cmake) ocv_cmake_hook(POST_COMPILER_OPTIONS) +diff --git a/cmake/OpenCVDetectCXXCompiler.cmake b/cmake/OpenCVDetectCXXCompiler.cmake +index 7f229cde96..92e204a5b9 100644 +--- a/cmake/OpenCVDetectCXXCompiler.cmake ++++ b/cmake/OpenCVDetectCXXCompiler.cmake +@@ -171,7 +171,7 @@ elseif(MSVC) + set(OpenCV_RUNTIME vc15) + elseif(MSVC_VERSION MATCHES "^192[0-9]$") + set(OpenCV_RUNTIME vc16) +- elseif(MSVC_VERSION MATCHES "^193[0-9]$") ++ elseif(MSVC_VERSION MATCHES "^19[34][0-9]$") + set(OpenCV_RUNTIME vc17) + else() + message(WARNING "OpenCV does not recognize MSVC_VERSION \"${MSVC_VERSION}\". Cannot set OpenCV_RUNTIME") diff --git a/modules/core/include/opencv2/core/ocl.hpp b/modules/core/include/opencv2/core/ocl.hpp index 4503fa00dd..642b0508d0 100644 --- a/modules/core/include/opencv2/core/ocl.hpp diff --git a/docs/custom_ops.md b/docs/custom_ops.md index 1634531a8..d53ec5653 100644 --- a/docs/custom_ops.md +++ b/docs/custom_ops.md @@ -1312,7 +1312,7 @@ expect(node, inputs=[text, pattern, rewrite], outputs=[y], ## Azure operators - +Starting from onnxruntime-extensions 0.12, these Azure operators will be removed from the official onnxruntime-extensions packages. However, they can still be built from source using `cmake -DOCOS_ENABLE_AZURE=ON ...`. ### OpenAIAudioToText
diff --git a/operators/cuda/cuda_ops.cc b/operators/cuda/cuda_ops.cc index 1488caa14..fd4bb9f90 100644 --- a/operators/cuda/cuda_ops.cc +++ b/operators/cuda/cuda_ops.cc @@ -8,6 +8,7 @@ #include "cuda/fast_gelu.h" #include "cuda/mul_sigmoid.h" #include "cuda/negxplus1.h" +#include "cuda/replace_zero.h" #include "cuda/scatter_nd_of_shape.h" #include "cuda/transpose_cast.h" #endif @@ -58,6 +59,7 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { CustomCudaStructV2("MulSigmoid", contrib::MulSigmoid), CustomCudaStructV2("MulSub", MulAndSubFloat32Type), CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1), + CustomCudaStructV2("ReplaceZero", contrib::ReplaceZero), CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape), CustomCudaStructV2("SubMul", SubAndMulFloat32Type), #if ORT_API_VERSION >= 16 @@ -75,6 +77,7 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { CustomCudaStructV2("MulSigmoid", contrib::MulSigmoid), CustomCudaStructV2("MulSub", MulAndSubFloat16Type), CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1), + CustomCudaStructV2("ReplaceZero", contrib::ReplaceZero), CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape), CustomCudaStructV2("SubMul", SubAndMulFloat16Type), CustomCudaStructV2("Transpose2DCastFP16", Transpose2DCastFloat32ToFloat16Type), diff --git a/operators/cuda/negxplus1.h b/operators/cuda/negxplus1.h index 5460c37a2..5ff53d357 100644 --- a/operators/cuda/negxplus1.h +++ b/operators/cuda/negxplus1.h @@ -8,6 +8,9 @@ namespace contrib { +/** +* NegXPlus1(X) = 1 - X +*/ template struct NegXPlus1 { template diff --git a/operators/cuda/replace_zero.h b/operators/cuda/replace_zero.h new file mode 100644 index 000000000..e7974739d --- /dev/null +++ b/operators/cuda/replace_zero.h @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "ocos.h" +#include "replace_zero_impl.cuh" +#include "ortx_common.h" + +namespace contrib { + +/** +* Y = ReplaceZero(X, by=c) is equivalent to: +* +* Y = X.copy() +* X[X == 0] = c +* +* This operation usually appears when a tensor is updated with an operator Equal and Where. +* This kernel avoids the creation of one null tensor. +*/ +template +struct ReplaceZero { + template + OrtxStatus OnModelAttach(const TDict& dict) { + float default_value=0; + by_ = dict.TryToGetAttributeWithDefault("by", default_value); + return {}; + } + OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx, + const ortc::Tensor& input, + ortc::Tensor& output) const { + const T* input_data = input.Data(); + auto input_shape = input.Shape(); + T* output_data = output.Allocate(input_shape); + auto input_length = input.NumberOfElement(); + if (0 == input_length) { + return {}; + } + + LaunchReplaceZeroKernel(reinterpret_cast(ctx->GetCudaStream()), + input_length, + input_data, + output_data, + by_); + return {}; + } + + private: + float by_; +}; + +} // namespace contrib \ No newline at end of file diff --git a/operators/cuda/replace_zero_impl.cu b/operators/cuda/replace_zero_impl.cu new file mode 100644 index 000000000..43952c303 --- /dev/null +++ b/operators/cuda/replace_zero_impl.cu @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "device_prop.cuh" +#include "utils.cuh" +#include "replace_zero_impl.cuh" +#include "cuda_type.h" + +#ifndef CUDA_LONG +#define CUDA_LONG int32_t +#endif + +using namespace Ort::Custom; + +template +__device__ __inline__ T _replace_zero(const T x, const T by) { + return x == (T)0 ? by : x; +} + +template <> +__device__ __inline__ half _replace_zero(const half x, const half by) { +#if __CUDA_ARCH__ < 700 + return __half2float(x) == 0 ? by : x; +#else + return x == (half)0 ? by : x; +#endif +} + +template +__global__ void ReplaceZeroKernel(T* output_data, const T* input_data, CUDA_LONG N, const T by) { + CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x; + if (id >= N) + return; + output_data[id] = _replace_zero(input_data[id], by); +} + +template +T _cast(float value) { return (T)value; } + +template <> +half _cast(float value) { return __float2half(value); } + +template +cudaError_t _LaunchReplaceZeroKernel(cudaStream_t stream, int input_length, const T* input_data, T* output_data, float by) { + if (input_length == 0) + return cudaGetLastError(); + using TT = typename contrib::CudaT::MappedType; + + CUDA_LONG N = static_cast(input_length); + + const int num_threads_per_block = 256; + const int num_elements_per_thread = (N + num_threads_per_block - 1) / num_threads_per_block; + + TT cby = _cast(by); + ReplaceZeroKernel<<>>( + reinterpret_cast(output_data), reinterpret_cast(input_data), N, cby); + return cudaGetLastError(); +} + +template <> +cudaError_t LaunchReplaceZeroKernel(cudaStream_t stream, int input_length, const float* input_data, float* output_data, float by) { + return _LaunchReplaceZeroKernel(stream, input_length, input_data, output_data, by); +} + +template <> +cudaError_t LaunchReplaceZeroKernel(cudaStream_t stream, int input_length, const ortc::MFloat16* input_data, ortc::MFloat16* output_data, float by) { + return _LaunchReplaceZeroKernel(stream, input_length, input_data, output_data, by); +} diff --git a/operators/cuda/replace_zero_impl.cuh b/operators/cuda/replace_zero_impl.cuh new file mode 100644 index 000000000..7d975d4d5 --- /dev/null +++ b/operators/cuda/replace_zero_impl.cuh @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include + +template +cudaError_t LaunchReplaceZeroKernel(cudaStream_t stream, int input_length, const T* input_data, T* output_data, float by); diff --git a/operators/cuda/scatter_nd_of_shape.h b/operators/cuda/scatter_nd_of_shape.h index 239c2b5e6..610454d42 100644 --- a/operators/cuda/scatter_nd_of_shape.h +++ b/operators/cuda/scatter_nd_of_shape.h @@ -8,6 +8,9 @@ namespace contrib { +/** +* ScatterNDOfShape(shape, indices, updates) = ScatterND(ConstantOfShape(shape, value=0), indices, updates) +*/ template struct ScatterNDOfShape { OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) { @@ -71,6 +74,12 @@ struct ScatterNDOfShape { }; +/** +* MaskedScatterNDOfShape(shape, indices, updates) = ScatterND(ConstantOfShape(shape, value=0), +* indices[indices != maskedValue], +* updates[indices != maskedValue]) +* +*/ template struct MaskedScatterNDOfShape { OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) { diff --git a/operators/cuda/transpose_cast.h b/operators/cuda/transpose_cast.h index 6ffae51c2..92e1f8a23 100644 --- a/operators/cuda/transpose_cast.h +++ b/operators/cuda/transpose_cast.h @@ -8,6 +8,9 @@ namespace contrib { +/** +* Transpose2DCast(X, to=to) = Cast(Transpose(X, perm=[1, 0]), to=to) +*/ template struct Transpose2DCast { template diff --git a/requirements-dev.txt b/requirements-dev.txt index 47114f13f..6d5867be1 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,6 @@ pytest -onnx >= 1.9.0 +numpy < 2.0.0 +onnx >=1.9.0 protobuf < 4.0.0 # multiple versions of onnxruntime are supported, but only one can be installed at a time onnxruntime >=1.12.0 diff --git a/test/cuda/test_cudaops.py b/test/cuda/test_cudaops.py index 800df6a4d..fabc4e349 100644 --- a/test/cuda/test_cudaops.py +++ b/test/cuda/test_cudaops.py @@ -833,6 +833,66 @@ def test_transpose_cast_cuda(self): self._transpose_cast_cuda(TensorProto.FLOAT) self._transpose_cast_cuda(TensorProto.FLOAT16) + def _replace_zero_cuda(self, itype): + dtype = np.float32 if itype == TensorProto.FLOAT else np.float16 + model1 = helper.make_model( + helper.make_graph( + [ + helper.make_node("Equal", ["X", "zero"], ["cond"]), + helper.make_node("Where", ["cond", "cst", "X"], ["Y"]), + ], + "nd", + [helper.make_tensor_value_info("X", itype, [None, None, None])], + [helper.make_tensor_value_info("Y", itype, [None, None, None])], + [ + numpy_helper.from_array(np.array([0], dtype=dtype), name="zero"), + numpy_helper.from_array(np.array([1.67], dtype=dtype), name="cst"), + ], + ), + opset_imports=[helper.make_opsetid("", 18)], + ir_version=9, + ) + + model2 = helper.make_model( + helper.make_graph( + [ + helper.make_node( + "ReplaceZero", + ["X"], + ["Y"], + by=1.67, + domain="ai.onnx.contrib", + ) + ], + "nd", + [helper.make_tensor_value_info("X", itype, [None, None, None])], + [helper.make_tensor_value_info("Y", itype, [None, None, None])], + ), + opset_imports=[ + helper.make_opsetid("", 18), + helper.make_opsetid("ai.onnx.contrib", 1), + ], + ir_version=9, + ) + + dtype = np.float32 if itype == TensorProto.FLOAT else np.float16 + x = (np.arange(18) - 4).reshape((3, 2, 3)).astype(dtype) + + feeds1 = dict(X=x) + ref = ReferenceEvaluator(model1) + expected = ref.run(None, feeds1)[0] + + opts = _ort.SessionOptions() + opts.register_custom_ops_library(_get_library_path()) + sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"]) + got = sess.run(None, feeds1)[0] + assert_allclose(expected, got, atol=1e-5) + + @unittest.skipIf(not has_cuda(), reason="cuda not available") + def test_replace_zero_cuda(self): + self._replace_zero_cuda(TensorProto.FLOAT) + self._replace_zero_cuda(TensorProto.FLOAT16) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tools/test_cibuildwheel.bat b/tools/test_cibuildwheel.bat index ca101d96e..ca38b983a 100644 --- a/tools/test_cibuildwheel.bat +++ b/tools/test_cibuildwheel.bat @@ -2,7 +2,7 @@ if "%OCOS_ENABLE_AZURE%"=="1" ( pushd %1\test python -m pip install coloredlogs flatbuffers numpy packaging protobuf sympy - python -m pip install -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ ort-nightly==1.16.0.dev20230820005 + python -m pip install onnxruntime==1.18 python test_azure_ops.py popd ) diff --git a/tools/test_cibuildwheel.sh b/tools/test_cibuildwheel.sh index 8197364b1..07d748879 100755 --- a/tools/test_cibuildwheel.sh +++ b/tools/test_cibuildwheel.sh @@ -4,7 +4,7 @@ if [[ "$OCOS_ENABLE_AZURE" == "1" ]] then pushd $1/test python -m pip install coloredlogs flatbuffers numpy packaging protobuf sympy - python -m pip install -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ ort-nightly==1.16.0.dev20230820005 + python -m pip install onnxruntime==1.18 python ./test_azure_ops.py popd fi