Skip to content

Commit

Permalink
Add SafeInt bounds checking to memory allocation size calculations. (…
Browse files Browse the repository at this point in the history
…#3022)

* Add SafeInt bounds checking to memory allocation size calculations.

* Fix TensorRT library includes
  • Loading branch information
skottmckay committed Feb 20, 2020
1 parent 21f9a8b commit a1db87b
Show file tree
Hide file tree
Showing 38 changed files with 273 additions and 115 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,6 @@
[submodule "cmake/external/onnx-tensorrt"]
path = cmake/external/onnx-tensorrt
url = https://github.com/stevenlix/onnx-tensorrt.git
[submodule "cmake/external/SafeInt/safeint"]
path = cmake/external/SafeInt/safeint
url = https://github.com/dcleblanc/SafeInt.git
9 changes: 9 additions & 0 deletions cgmanifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,15 @@
},
"type": "git"
}
},
{
"component": {
"git": {
"commitHash": "39de59d5bca3a226a241b29571abe1493d15d07a",
"repositoryUrl": "https://github.com/dcleblanc/SafeInt.git"
},
"type": "git"
}
}
],
"Version": 1
Expand Down
4 changes: 4 additions & 0 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,10 @@ if(onnxruntime_PREFER_SYSTEM_LIB)
find_package(re2)
endif()

set(SAFEINT_INCLUDE_DIR ${REPO_ROOT}/cmake/external/SafeInt)
add_library(safeint_interface INTERFACE)
target_include_directories(safeint_interface INTERFACE ${SAFEINT_INCLUDE_DIR})

if(NOT TARGET re2::re2)
add_subdirectory(external/re2 EXCLUDE_FROM_ALL)
set_target_properties(re2 PROPERTIES FOLDER "External/re2")
Expand Down
1 change: 1 addition & 0 deletions cmake/external/SafeInt/safeint
Submodule safeint added at 39de59
2 changes: 1 addition & 1 deletion cmake/onnxruntime_codegen.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_codegen_common_sr
add_library(onnxruntime_codegen_tvm ${onnxruntime_codegen_common_srcs} ${onnxruntime_codegen_tvm_srcs})
set_target_properties(onnxruntime_codegen_tvm PROPERTIES FOLDER "ONNXRuntime")
target_include_directories(onnxruntime_codegen_tvm PRIVATE ${ONNXRUNTIME_ROOT} ${TVM_INCLUDES} ${MKLML_INCLUDE_DIR} ${eigen_INCLUDE_DIRS})
onnxruntime_add_include_to_target(onnxruntime_codegen_tvm onnxruntime_common onnxruntime_framework onnx onnx_proto protobuf::libprotobuf)
onnxruntime_add_include_to_target(onnxruntime_codegen_tvm onnxruntime_common onnxruntime_framework onnx onnx_proto protobuf::libprotobuf safeint_interface)
target_compile_options(onnxruntime_codegen_tvm PRIVATE ${DISABLED_WARNINGS_FOR_TVM})
# need onnx to build to create headers that this project includes
add_dependencies(onnxruntime_codegen_tvm ${onnxruntime_EXTERNAL_DEPENDENCIES})
2 changes: 1 addition & 1 deletion cmake/onnxruntime_common.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ if (onnxruntime_USE_MIMALLOC)
endif()
endif()

onnxruntime_add_include_to_target(onnxruntime_common date_interface)
onnxruntime_add_include_to_target(onnxruntime_common date_interface safeint_interface)
target_include_directories(onnxruntime_common PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT}
PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/external/nsync/public")
if(NOT WIN32)
Expand Down
2 changes: 1 addition & 1 deletion cmake/onnxruntime_framework.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ if(onnxruntime_ENABLE_INSTRUMENT)
target_compile_definitions(onnxruntime_framework PRIVATE ONNXRUNTIME_ENABLE_INSTRUMENT)
endif()
target_include_directories(onnxruntime_framework PRIVATE ${ONNXRUNTIME_ROOT} PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
onnxruntime_add_include_to_target(onnxruntime_framework onnxruntime_common onnx onnx_proto protobuf::libprotobuf)
onnxruntime_add_include_to_target(onnxruntime_framework onnxruntime_common onnx onnx_proto protobuf::libprotobuf safeint_interface)
set_target_properties(onnxruntime_framework PROPERTIES FOLDER "ONNXRuntime")
# need onnx to build to create headers that this project includes
add_dependencies(onnxruntime_framework ${onnxruntime_EXTERNAL_DEPENDENCIES})
Expand Down
6 changes: 3 additions & 3 deletions cmake/onnxruntime_providers.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ add_library(onnxruntime_providers ${onnxruntime_providers_src})
if (MSVC AND NOT CMAKE_SIZEOF_VOID_P EQUAL 8)
target_compile_options(onnxruntime_providers PRIVATE "/wd4244")
endif()
onnxruntime_add_include_to_target(onnxruntime_providers onnxruntime_common onnxruntime_framework onnx onnx_proto protobuf::libprotobuf)
onnxruntime_add_include_to_target(onnxruntime_providers onnxruntime_common onnxruntime_framework onnx onnx_proto protobuf::libprotobuf safeint_interface)

if (onnxruntime_USE_FEATURIZERS)
add_dependencies(onnxruntime_providers onnxruntime_featurizers onnxruntime_featurizers_comp)
Expand Down Expand Up @@ -250,7 +250,7 @@ if (onnxruntime_USE_TENSORRT)
source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_tensorrt_cc_srcs})
add_library(onnxruntime_providers_tensorrt ${onnxruntime_providers_tensorrt_cc_srcs})
target_link_libraries(onnxruntime_providers_tensorrt ${onnxparser_link_libs} ${trt_link_libs})
onnxruntime_add_include_to_target(onnxruntime_providers_tensorrt onnxruntime_common onnxruntime_framework onnx onnx_proto protobuf::libprotobuf)
onnxruntime_add_include_to_target(onnxruntime_providers_tensorrt onnxruntime_common onnxruntime_framework onnx onnx_proto protobuf::libprotobuf safeint_interface)
add_dependencies(onnxruntime_providers_tensorrt ${onnxruntime_EXTERNAL_DEPENDENCIES})
target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ONNXRUNTIME_ROOT} ${onnxruntime_CUDNN_HOME}/include ${eigen_INCLUDE_DIRS} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/tensorrt DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core/providers)
Expand Down Expand Up @@ -382,7 +382,7 @@ if (onnxruntime_USE_NUPHAR)

source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_nuphar_cc_srcs})
add_library(onnxruntime_providers_nuphar ${onnxruntime_providers_nuphar_cc_srcs})
onnxruntime_add_include_to_target(onnxruntime_providers_nuphar onnxruntime_common onnxruntime_framework onnx onnx_proto protobuf::libprotobuf)
onnxruntime_add_include_to_target(onnxruntime_providers_nuphar onnxruntime_common onnxruntime_framework onnx onnx_proto protobuf::libprotobuf safeint_interface)
set_target_properties(onnxruntime_providers_nuphar PROPERTIES FOLDER "ONNXRuntime")
target_include_directories(onnxruntime_providers_nuphar PRIVATE ${ONNXRUNTIME_ROOT} ${TVM_INCLUDES} ${eigen_INCLUDE_DIRS})
set_target_properties(onnxruntime_providers_nuphar PROPERTIES LINKER_LANGUAGE CXX)
Expand Down
2 changes: 1 addition & 1 deletion cmake/onnxruntime_session.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_session_srcs})

add_library(onnxruntime_session ${onnxruntime_session_srcs})
install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/session DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core)
onnxruntime_add_include_to_target(onnxruntime_session onnxruntime_common onnxruntime_framework onnx onnx_proto protobuf::libprotobuf)
onnxruntime_add_include_to_target(onnxruntime_session onnxruntime_common onnxruntime_framework onnx onnx_proto protobuf::libprotobuf safeint_interface)
if(onnxruntime_ENABLE_INSTRUMENT)
target_compile_definitions(onnxruntime_session PUBLIC ONNXRUNTIME_ENABLE_INSTRUMENT)
endif()
Expand Down
4 changes: 2 additions & 2 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ function(AddTest)
else()
target_link_libraries(${_UT_TARGET} PRIVATE ${_UT_LIBS} GTest::gtest GTest::gmock ${onnxruntime_EXTERNAL_LIBRARIES})
endif()
onnxruntime_add_include_to_target(${_UT_TARGET} date_interface)
onnxruntime_add_include_to_target(${_UT_TARGET} date_interface safeint_interface)
target_include_directories(${_UT_TARGET} PRIVATE ${TEST_INC_DIR})
if (onnxruntime_USE_CUDA)
target_include_directories(${_UT_TARGET} PRIVATE ${CUDA_INCLUDE_DIRS} ${onnxruntime_CUDNN_HOME}/include)
Expand Down Expand Up @@ -511,7 +511,7 @@ if (MSVC AND NOT CMAKE_SIZEOF_VOID_P EQUAL 8)
#TODO: fix the warnings, they are dangerous
target_compile_options(onnx_test_runner_common PRIVATE "/wd4244")
endif()
onnxruntime_add_include_to_target(onnx_test_runner_common onnxruntime_common onnxruntime_framework onnxruntime_test_utils onnx onnx_proto re2::re2)
onnxruntime_add_include_to_target(onnx_test_runner_common onnxruntime_common onnxruntime_framework onnxruntime_test_utils onnx onnx_proto re2::re2 safeint_interface)

add_dependencies(onnx_test_runner_common onnx_test_data_proto ${onnxruntime_EXTERNAL_DEPENDENCIES})
target_include_directories(onnx_test_runner_common PRIVATE ${eigen_INCLUDE_DIRS} ${RE2_INCLUDE_DIR} ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}/onnx ${ONNXRUNTIME_ROOT})
Expand Down
45 changes: 19 additions & 26 deletions include/onnxruntime/core/framework/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,19 +163,32 @@ class IAllocator {
virtual FencePtr CreateFence(const SessionState* /*unused*/) { return nullptr; }

static bool CalcMemSizeForArray(size_t nmemb, size_t size, size_t* out) noexcept {
return CalcMemSizeForArrayWithAlignment<0>(nmemb, size, out);
return CalcMemSizeForArrayWithAlignment(nmemb, size, 0, out);
}

/**
* Calculate the memory size for an array. The size is bounds checked using SafeInt.
* \tparam alignment must be power of 2
* \param nmemb Number of members or elements in the array
* \param size Size of each element
* \param out Total size required after any alignment is applied
* \return true, successful. false, overflow
*/
static bool CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t alignment, size_t* out) noexcept ORT_MUST_USE_RESULT;

/**
* https://cwe.mitre.org/data/definitions/190.html
* \tparam alignment must be power of 2
* \param nmemb
* \param size
* \param out
* \param nmemb Number of members or elements in the array
* \param size Size of each element
* \param out Total size required after any alignment is applied
* \return true, successful. false, overflow
* \remarks This was the original API and was implemented in the header. Replaced with the above version
* implemented in the .cc file so that the SafeInt dependency is internal.
*/
template <size_t alignment>
static bool CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t* out) noexcept ORT_MUST_USE_RESULT;

/**
* allocate memory for an array which has nmemb items of data, each size bytes long
*/
Expand All @@ -192,7 +205,7 @@ class IAllocator {
template <size_t alignment>
void* AllocArrayWithAlignment(size_t nmemb, size_t size) {
size_t len;
if (!CalcMemSizeForArrayWithAlignment<alignment>(nmemb, size, &len))
if (!CalcMemSizeForArrayWithAlignment(nmemb, size, alignment, &len))
return nullptr;
return Alloc(len);
}
Expand Down Expand Up @@ -229,27 +242,7 @@ class IAllocator {

template <size_t alignment>
bool IAllocator::CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t* out) noexcept {
static constexpr size_t max_allowed = (size_t(1) << (size_t(std::numeric_limits<size_t>::digits >> 1))) - alignment;
static constexpr size_t max_size = std::numeric_limits<size_t>::max() - alignment;
static constexpr size_t alignment_mask = alignment - 1;

//Indeed, we only need to check if max_size / nmemb < size
//max_allowed is for avoiding unnecessary DIV.
if (nmemb >= max_allowed && max_size / nmemb < size) {
return false;
}

if (size >= max_allowed &&
nmemb > 0 && max_size / nmemb < size) {
return false;
}

if (alignment == 0)
*out = size * nmemb;
else
*out = (size * nmemb + alignment_mask) & ~static_cast<size_t>(alignment_mask);

return true;
return CalcMemSizeForArrayWithAlignment(nmemb, size, alignment, out);
}

/**
Expand Down
11 changes: 7 additions & 4 deletions onnxruntime/contrib_ops/cpu/bert/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "core/providers/cpu/math/gemm_helper.h"
#include "core/providers/cpu/math/softmax.h"
#include "core/providers/cpu/tensor/transpose.h"
#include "core/common/safeint.h"

namespace onnxruntime {
namespace contrib {
Expand Down Expand Up @@ -121,7 +122,7 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));

// STEP.1: gemm_data(BS, 3NH) = input(BS, NH) x weights(NH, 3NH) + bias(3NH)
auto gemm_data = allocator->Alloc(batch_size * sequence_length * 3 * hidden_size * element_size);
auto gemm_data = allocator->Alloc(SafeInt<size_t>(batch_size) * sequence_length * 3 * hidden_size * element_size);
BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(allocator));
auto Q = reinterpret_cast<T*>(gemm_data);
auto K = Q + batch_size * sequence_length * hidden_size;
Expand Down Expand Up @@ -177,11 +178,12 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
}

// STEP.2: scratch(B, N, S, S) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, S, H -> B, N, H, S) + 1 x mask_index(B -> B, 1, 1, 1)
auto scratch_data = allocator->Alloc(batch_size * num_heads_ * sequence_length * sequence_length * element_size);
auto scratch_data = allocator->Alloc(
SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * sequence_length * element_size);
BufferUniquePtr scratch_buffer(scratch_data, BufferDeleter(allocator));

{
auto scratch_broadcast_data = allocator->Alloc(batch_size * sequence_length * element_size);
auto scratch_broadcast_data = allocator->Alloc(SafeInt<size_t>(batch_size) * sequence_length * element_size);
BufferUniquePtr scratch_broadcast_buffer(scratch_broadcast_data, BufferDeleter(allocator));
memset(scratch_broadcast_data, 0, batch_size * sequence_length * element_size);
T* p_scratch_broadcast_current_data = reinterpret_cast<T*>(scratch_broadcast_data);
Expand Down Expand Up @@ -269,7 +271,8 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
}

// STEP.4: out_tmp(B, N, S, H) = P(B, N, S, S) x V(B, N, S, H)
auto out_tmp_data = allocator->Alloc(batch_size * num_heads_ * sequence_length * head_size * element_size);
auto out_tmp_data = allocator->Alloc(
SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * head_size * element_size);
BufferUniquePtr out_tmp_buffer(out_tmp_data, BufferDeleter(allocator));

concurrency::ThreadPool::TryParallelFor(context->GetOperatorThreadPool(), batch_size * num_heads_, [&](int i) {
Expand Down
8 changes: 5 additions & 3 deletions onnxruntime/contrib_ops/cpu/bert/bias_gelu_fusion.cc
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#include "bias_gelu_fusion.h"

#include "core/util/math_cpuonly.h"
#include "core/platform/threadpool.h"
#include "core/common/safeint.h"
#include "core/mlas/inc/mlas.h"
#include "core/platform/threadpool.h"
#include "core/util/math_cpuonly.h"

namespace onnxruntime {
namespace contrib {
Expand Down Expand Up @@ -42,7 +43,8 @@ Status BiasGelu<T>::Compute(OpKernelContext* ctx) const {
AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc));

BufferUniquePtr temp_data_buf_ptr = BufferUniquePtr(alloc->Alloc(sizeof(T) * X->Shape().Size()), BufferDeleter(alloc));
BufferUniquePtr temp_data_buf_ptr = BufferUniquePtr(alloc->Alloc(SafeInt<size_t>(sizeof(T)) * X->Shape().Size()),
BufferDeleter(alloc));
T* tmp_data = static_cast<T*>(temp_data_buf_ptr.get());

const T* X_data = X->template Data<T>();
Expand Down
5 changes: 3 additions & 2 deletions onnxruntime/contrib_ops/cpu/layer_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "layer_norm.h"

#include "core/common/safeint.h"
#include "core/framework/tensor.h"
#include "core/platform/threadpool.h"
#include "core/providers/common.h"
Expand Down Expand Up @@ -70,7 +71,7 @@ Status LayerNorm<T>::Compute(OpKernelContext* p_ctx) const {
if (mean != nullptr) {
mean_data = mean->template MutableData<T>();
} else {
auto mean_data_buf = alloc->Alloc(sizeof(T) * norm_count);
auto mean_data_buf = alloc->Alloc(SafeInt<size_t>(sizeof(T)) * norm_count);
mean_data_buf_ptr = BufferUniquePtr(mean_data_buf, BufferDeleter(alloc));
mean_data = static_cast<T*>(mean_data_buf_ptr.get());
}
Expand All @@ -82,7 +83,7 @@ Status LayerNorm<T>::Compute(OpKernelContext* p_ctx) const {
if (inv_std_var != nullptr) {
inv_std_var_data = inv_std_var->template MutableData<T>();
} else {
auto inv_std_var_data_buf = alloc->Alloc(sizeof(T) * norm_count);
auto inv_std_var_data_buf = alloc->Alloc(SafeInt<size_t>(sizeof(T)) * norm_count);
inv_std_var_data_buf_ptr = BufferUniquePtr(inv_std_var_data_buf, BufferDeleter(alloc));
inv_std_var_data = static_cast<T*>(inv_std_var_data_buf_ptr.get());
}
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/codegen/common/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "core/codegen/common/utils.h"
#include "core/common/cpuid_info.h"
#include "core/common/make_unique.h"
#include "core/common/safeint.h"

#include <stdlib.h>
#include <string.h>
Expand Down Expand Up @@ -46,7 +47,7 @@ bool IsEnvVarDefined(const char* var) {
}

int64_t TotalSize(const std::vector<int64_t>& shape) {
int64_t total = 1;
SafeInt<int64_t> total = 1;
for (auto s : shape) {
total *= s;
}
Expand Down
27 changes: 27 additions & 0 deletions onnxruntime/core/common/safeint.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/common/exceptions.h"

// Setup the infrastructure to throw ORT exceptions so they're caught by existing handlers.

template <typename E>
class SafeIntExceptionHandler;

template <>
class SafeIntExceptionHandler<onnxruntime::OnnxRuntimeException> {
public:
static void SafeIntOnOverflow() {
ORT_THROW("Integer overflow");
}

static void SafeIntOnDivZero() {
ORT_THROW("Divide by zero");
}
};

#define SAFEINT_EXCEPTION_HANDLER_CPP 1
#define SafeIntDefaultExceptionHandler SafeIntExceptionHandler<onnxruntime::OnnxRuntimeException>
#include "safeint/SafeInt.hpp"
22 changes: 22 additions & 0 deletions onnxruntime/core/framework/allocator.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/common/safeint.h"
#include "core/framework/allocator.h"
#include "core/framework/allocatormgr.h"
#include "core/framework/utils.h"
Expand All @@ -10,6 +11,27 @@

namespace onnxruntime {

// private helper for calculation so SafeInt usage doesn't bleed into the public allocator.h header
bool IAllocator::CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t alignment, size_t* out) noexcept {
bool ok = true;

try {
SafeInt<size_t> alloc_size(size);
if (alignment == 0) {
*out = alloc_size * nmemb;
} else {
size_t alignment_mask = alignment - 1;
*out = (alloc_size * nmemb + alignment_mask) & ~static_cast<size_t>(alignment_mask);
}
} catch (const OnnxRuntimeException& ex) {
// overflow in calculating the size thrown by SafeInt.
LOGS_DEFAULT(ERROR) << ex.what();
ok = false;
}

return ok;
}

#ifdef USE_MIMALLOC
void* MiMallocAllocator::Alloc(size_t size) {
return mi_malloc(size);
Expand Down
5 changes: 3 additions & 2 deletions onnxruntime/core/framework/error_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "core/session/onnxruntime_c_api.h"
#include "core/session/ort_apis.h"
#include "core/common/status.h"
#include "core/common/safeint.h"
#include "core/framework/error_code_helper.h"
#include <cassert>
using onnxruntime::common::Status;
Expand All @@ -16,7 +17,7 @@ struct OrtStatus {
//Even we say it may not return NULL, indeed it may.
ORT_EXPORT _Check_return_ _Ret_notnull_ OrtStatus* ORT_API_CALL OrtApis::CreateStatus(OrtErrorCode code, _In_ const char* msg) NO_EXCEPTION {
assert(!(code == 0 && msg != nullptr));
size_t clen = strlen(msg);
SafeInt<size_t> clen(strlen(msg));
OrtStatus* p = reinterpret_cast<OrtStatus*>(::malloc(sizeof(OrtStatus) + clen));
if (p == nullptr) return nullptr; // OOM. What we can do here? abort()?
p->code = code;
Expand All @@ -29,7 +30,7 @@ namespace onnxruntime {
OrtStatus* ToOrtStatus(const Status& st) {
if (st.IsOK())
return nullptr;
size_t clen = st.ErrorMessage().length();
SafeInt<size_t> clen(st.ErrorMessage().length());
OrtStatus* p = reinterpret_cast<OrtStatus*>(::malloc(sizeof(OrtStatus) + clen));
p->code = static_cast<OrtErrorCode>(st.Code());
memcpy(p->msg, st.ErrorMessage().c_str(), clen);
Expand Down
Loading

0 comments on commit a1db87b

Please sign in to comment.