From d82576fc032028744b16ff6cbb5f3b44cf3bea5b Mon Sep 17 00:00:00 2001 From: Dave Bort Date: Tue, 2 Jan 2024 16:15:53 -0800 Subject: [PATCH 1/3] Non-fatal error for unknown KernelTypes type Summary: Don't fail fatally if the .pte file contains an unknown type enum value. Discovered by lionhead fuzzing. Differential Revision: D52493415 --- runtime/executor/method.cpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/runtime/executor/method.cpp b/runtime/executor/method.cpp index 70a6f7c1991..1bb7ea34710 100644 --- a/runtime/executor/method.cpp +++ b/runtime/executor/method.cpp @@ -388,12 +388,11 @@ Error Method::parse_values() { // subtract one to keep the output in 0 based indexing for a // disgruntled debugger seeing this error message and checking // schema.fbs - ET_CHECK_MSG( - false, - "Enum KernelTypes type: %" PRIu32 - " not supported. Please look in executorch/schema/program.fbs " - "to see which type this is.", + ET_LOG( + Error, + "Unknown KernelTypes value %" PRIu32, static_cast(serialization_value->val_type()) - 1); + return Error::InvalidProgram; } // ~Method() will try to clean up n_value_ entries in the values_ array. From eb40380600073d2140fa9543576930849b0cafec Mon Sep 17 00:00:00 2001 From: Dave Bort Date: Tue, 2 Jan 2024 16:15:53 -0800 Subject: [PATCH 2/3] Check for missing arrays in Program Summary: Flatbuffer array fields can be missing, so we need to check for `nullptr` before calling `size()` on them. Discovered by lionhead fuzzing. 
Differential Revision: D52493423 --- runtime/executor/program.cpp | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/runtime/executor/program.cpp b/runtime/executor/program.cpp index 37a3b8b2f2e..acdce187ed5 100644 --- a/runtime/executor/program.cpp +++ b/runtime/executor/program.cpp @@ -146,30 +146,33 @@ Result get_execution_plan( // Constant data may live inside the flatbuffer data (constant_buffer) or in a // separate segment (constant_segment). It should not be in both. - const auto& constant_buffer = flatbuffer_program->constant_buffer(); - const auto& constant_segment = flatbuffer_program->constant_segment(); - - // Check if the constant data is inside a separate segment. - if (constant_segment != nullptr && constant_segment->offsets()->size() > 0) { + const auto* constant_segment = flatbuffer_program->constant_segment(); + if (constant_segment != nullptr && constant_segment->offsets() != nullptr && + constant_segment->offsets()->size() > 0) { + // The constant data is inside a separate segment. + const auto* constant_buffer = flatbuffer_program->constant_buffer(); ET_CHECK_OR_RETURN_ERROR( - constant_buffer->size() == 0, - InvalidState, - "constant_buffer contains %u items, constant_segment.offsets contains %u items. Only one should be used.", + constant_buffer == nullptr || constant_buffer->size() == 0, + InvalidProgram, + "constant_buffer contains %u items, " + "constant_segment.offsets contains %u items. Only one should be used.", constant_buffer->size(), constant_segment->offsets()->size()); + const auto* segments = flatbuffer_program->segments(); + ET_CHECK_OR_RETURN_ERROR( + segments != nullptr, InvalidProgram, "No segments in program"); // Load constant segment. // TODO(T171839323): Add test for segment_index > num available segments. 
ET_CHECK_OR_RETURN_ERROR( - constant_segment->segment_index() < - flatbuffer_program->segments()->size(), - InvalidArgument, + constant_segment->segment_index() < segments->size(), + InvalidProgram, "Constant segment index %d invalid for program segments range %d", constant_segment->segment_index(), - flatbuffer_program->segments()->size()); + segments->size()); const executorch_flatbuffer::DataSegment* data_segment = - flatbuffer_program->segments()->Get(constant_segment->segment_index()); + segments->Get(constant_segment->segment_index()); Result constant_segment_data = loader->Load( segment_base_offset + data_segment->offset(), data_segment->size()); if (!constant_segment_data.ok()) { @@ -199,7 +202,12 @@ Result get_execution_plan( size_t Program::num_methods() const { auto internal_program = static_cast(internal_program_); - return internal_program->execution_plan()->size(); + const auto execution_plan = internal_program->execution_plan(); + if (execution_plan != nullptr) { + return execution_plan->size(); + } else { + return 0; + } } Result Program::get_method_name(size_t plan_index) const { From 94735dd0b6cb3ea4ad3a81f46052f616f1235780 Mon Sep 17 00:00:00 2001 From: Dave Bort Date: Tue, 2 Jan 2024 16:15:53 -0800 Subject: [PATCH 3/3] Catch invalid scalar type when parsing tensors Summary: Fail non-fatally when encountering an unknown/unhandled `ScalarType` in a `.pte` file. As part of this: - Move the "types not supported yet" logic out of `scalar_type_util` and into `tensor_parser`, since that decision is an aspect of the runtime and not a fundamental aspect of `ScalarType`. - Remove the now-duplicate `sizeof_scalar_type` function, which is the same as the existing `elementSize` function. Before this diff, `sizeof_scalar_type` did the "unsupported" checks that have now moved. - Add an `isValid()` function to let users of `ScalarType` know whether a given enum value is legit. 
This makes it possible to avoid the fatal error when calling `elementSize` on a bad value. - Add unit tests for the new `isValid()`. Differential Revision: D52451738 --- .../core/exec_aten/util/scalar_type_util.h | 42 +++++-------------- .../util/test/scalar_type_util_test.cpp | 18 ++++++++ runtime/core/portable_type/tensor_impl.cpp | 8 ++-- runtime/executor/method_meta.cpp | 2 +- runtime/executor/tensor_parser_aten.cpp | 6 +++ runtime/executor/tensor_parser_portable.cpp | 16 ++++++- 6 files changed, 54 insertions(+), 38 deletions(-) diff --git a/runtime/core/exec_aten/util/scalar_type_util.h b/runtime/core/exec_aten/util/scalar_type_util.h index 3a480b83326..b66bde5f534 100644 --- a/runtime/core/exec_aten/util/scalar_type_util.h +++ b/runtime/core/exec_aten/util/scalar_type_util.h @@ -253,6 +253,16 @@ ET_FORALL_SCALAR_TYPES(SPECIALIZE_CppTypeToScalarType) // Utility functions to retrieve metadata for a given ScalarType // +/** + * Returns true if the parameter is one of the values covered by + * ET_FORALL_SCALAR_TYPES. + */ +inline bool isValid(exec_aten::ScalarType type) { + return static_cast(type) >= 0 && + type < exec_aten::ScalarType::NumOptions && + type != exec_aten::ScalarType::Undefined; +} + /** * Returns the name of a ScalarType as a C string. * @@ -541,38 +551,6 @@ inline exec_aten::ScalarType promoteTypes( return _promoteTypesLookup[static_cast(a)][static_cast(b)]; } -/** - * Return the size of corresponding ctype given ScalarType. - */ -inline size_t sizeof_scalar_type(exec_aten::ScalarType type) { - // Reject types that are not yet supported or are out of bounds. 
- ET_CHECK_MSG( - type != exec_aten::ScalarType::Half && - type != exec_aten::ScalarType::ComplexHalf && - type != exec_aten::ScalarType::ComplexFloat && - type != exec_aten::ScalarType::ComplexDouble && - type != exec_aten::ScalarType::BFloat16 && - type != exec_aten::ScalarType::Undefined, - "Invalid or unsupported ScalarType %" PRId8, - static_cast(type)); - - size_t type_size = 0; -#define SCALAR_TYPE_SIZE(ctype, dtype) \ - case exec_aten::ScalarType::dtype: \ - type_size = sizeof(ctype); \ - break; - - switch (type) { - ET_FORALL_SCALAR_TYPES(SCALAR_TYPE_SIZE) - default: - ET_CHECK_MSG( - false, "Invalid input ScalarType %" PRId8, static_cast(type)); - } -#undef SCALAR_TYPE_SIZE - - return type_size; -} - // // Helper macros for switch case macros (see below) // diff --git a/runtime/core/exec_aten/util/test/scalar_type_util_test.cpp b/runtime/core/exec_aten/util/test/scalar_type_util_test.cpp index 333327a578c..5ee11d6ace8 100644 --- a/runtime/core/exec_aten/util/test/scalar_type_util_test.cpp +++ b/runtime/core/exec_aten/util/test/scalar_type_util_test.cpp @@ -70,6 +70,24 @@ TEST(ScalarTypeUtilTest, ElementSize) { } } +TEST(ScalarTypeUtilTest, IsValid) { + // Some valid types. + EXPECT_TRUE(torch::executor::isValid(ScalarType::Byte)); + EXPECT_TRUE(torch::executor::isValid(ScalarType::Float)); + EXPECT_TRUE(torch::executor::isValid(ScalarType::ComplexFloat)); + EXPECT_TRUE(torch::executor::isValid(ScalarType::Bits16)); + + // Undefined, which is sort of a special case since it's not part of the + // iteration macros but is still a part of the enum. + EXPECT_FALSE(torch::executor::isValid(ScalarType::Undefined)); + + // Some out-of-range types, also demonstrating that NumOptions is not really a + // scalar type. 
+ EXPECT_FALSE(torch::executor::isValid(ScalarType::NumOptions)); + EXPECT_FALSE(torch::executor::isValid(static_cast(127))); + EXPECT_FALSE(torch::executor::isValid(static_cast(-1))); +} + TEST(ScalarTypeUtilTest, UnknownTypeElementSizeDies) { // Undefined, which is sort of a special case since it's not part of the // iteration macros but is still a part of the enum. diff --git a/runtime/core/portable_type/tensor_impl.cpp b/runtime/core/portable_type/tensor_impl.cpp index 21c2061129b..f0e51929367 100644 --- a/runtime/core/portable_type/tensor_impl.cpp +++ b/runtime/core/portable_type/tensor_impl.cpp @@ -47,12 +47,12 @@ TensorImpl::TensorImpl( data_(data), dim_(dim), numel_(compute_numel(sizes, dim)), - capacity_(numel_ * sizeof_scalar_type(type)), + capacity_(numel_ * elementSize(type)), type_(type), shape_dynamism_(dynamism) {} size_t TensorImpl::nbytes() const { - return numel_ * sizeof_scalar_type(type_); + return numel_ * elementSize(type_); } ssize_t TensorImpl::size(ssize_t dim) const { @@ -78,7 +78,7 @@ ScalarType TensorImpl::scalar_type() const { // Return the size of one element of the tensor ssize_t TensorImpl::element_size() const { - return sizeof_scalar_type(type_); + return elementSize(type_); } const ArrayRef TensorImpl::sizes() const { @@ -145,7 +145,7 @@ Error TensorImpl::internal_resize_contiguous(ArrayRef new_sizes) { // Upper bounded tensors can be reshaped but not beyond upper bound if (shape_dynamism_ == TensorShapeDynamism::DYNAMIC_BOUND) { - auto new_nbytes = new_numel * sizeof_scalar_type(type_); + auto new_nbytes = new_numel * elementSize(type_); ET_CHECK_OR_RETURN_ERROR( new_nbytes <= capacity_, NotSupported, diff --git a/runtime/executor/method_meta.cpp b/runtime/executor/method_meta.cpp index 0b8bd239910..1820cd0b8b6 100644 --- a/runtime/executor/method_meta.cpp +++ b/runtime/executor/method_meta.cpp @@ -59,7 +59,7 @@ size_t calculate_nbytes( for (ssize_t i = 0; i < sizes.size(); i++) { n *= sizes[i]; } - return n * 
sizeof_scalar_type(scalar_type); + return n * torch::executor::elementSize(scalar_type); } } // namespace diff --git a/runtime/executor/tensor_parser_aten.cpp b/runtime/executor/tensor_parser_aten.cpp index 8200d5db2fa..a2a72a9f57e 100644 --- a/runtime/executor/tensor_parser_aten.cpp +++ b/runtime/executor/tensor_parser_aten.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -43,6 +44,11 @@ Result parseTensor( // get metadata at::ScalarType type = static_cast(s_tensor->scalar_type()); + ET_CHECK_OR_RETURN_ERROR( + isValid(type), + InvalidProgram, + "Invalid ScalarType %" PRId8, + static_cast(type)); auto options = at::CPU(type).options(); // convert int32 in serialization to int64 for aten diff --git a/runtime/executor/tensor_parser_portable.cpp b/runtime/executor/tensor_parser_portable.cpp index 5c01fd5f6d4..ac04f5e54aa 100644 --- a/runtime/executor/tensor_parser_portable.cpp +++ b/runtime/executor/tensor_parser_portable.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -33,6 +34,19 @@ Result parseTensor( "Non-zero storage offset %" PRId32 " not supported", s_tensor->storage_offset()); + ScalarType scalar_type = static_cast(s_tensor->scalar_type()); + ET_CHECK_OR_RETURN_ERROR( + isValid(scalar_type) && + // Types not yet supported by ExecuTorch. + scalar_type != exec_aten::ScalarType::Half && + scalar_type != exec_aten::ScalarType::ComplexHalf && + scalar_type != exec_aten::ScalarType::ComplexFloat && + scalar_type != exec_aten::ScalarType::ComplexDouble && + scalar_type != exec_aten::ScalarType::BFloat16, + InvalidProgram, + "Invalid or unsupported ScalarType %" PRId8, + static_cast(scalar_type)); + TensorShapeDynamism dynamism = static_cast(s_tensor->shape_dynamism()); // TODO(T133200526): Remove this check once fully dynamic shapes are @@ -90,7 +104,7 @@ Result parseTensor( // Placement new on the allocated memory space. 
Note that we create this first // with null data so we can find its expected size before getting its memory. new (tensor_impl) torch::executor::TensorImpl( - static_cast(s_tensor->scalar_type()), + scalar_type, dim, sizes, /*data=*/nullptr,