From 4e2b3e753dbb360cb56c94ba3038b61220602d67 Mon Sep 17 00:00:00 2001 From: Github Executorch Date: Tue, 14 Jan 2025 11:08:18 -0800 Subject: [PATCH] add a bunch of bounds checking to pytree It's possible to pass arbitrary string input to pytree from Python; let's not have a bunch of low-hanging memory safety issues. Differential Revision: [D68166303](https://our.internmc.facebook.com/intern/diff/D68166303/) [ghstack-poisoned] --- extension/pytree/pytree.h | 58 ++++++++++++++++++++------- extension/pytree/test/test_pytree.cpp | 19 +++++++++ 2 files changed, 63 insertions(+), 14 deletions(-) diff --git a/extension/pytree/pytree.h b/extension/pytree/pytree.h index b0f2ca1fab6..5b88fca3e41 100644 --- a/extension/pytree/pytree.h +++ b/extension/pytree/pytree.h @@ -60,7 +60,7 @@ struct Key { std::variant repr_; public: - Key() {} + Key() = default; /*implicit*/ Key(KeyInt key) : repr_(key) {} /*implicit*/ Key(KeyStr key) : repr_(std::move(key)) {} @@ -131,7 +131,7 @@ struct ContainerHandle { using leaf_type = T; std::unique_ptr handle; - ContainerHandle() {} + ContainerHandle() = default; template ContainerHandle(Args... args) @@ -427,6 +427,20 @@ struct arr { return data_[idx]; } + T& at(size_t idx) { + if (idx >= size()) { + throw std::out_of_range("bounds check failed in pytree arr"); + } + return data_[idx]; + } + + const T& at(size_t idx) const { + if (idx >= size()) { + throw std::out_of_range("bounds check failed in pytree arr"); + } + return data_[idx]; + } + inline T* data() { return data_.get(); } @@ -458,7 +472,7 @@ struct arr { inline size_t read_number(const StrTreeSpec& spec, size_t& read_idx) { size_t num = 0; - while (isdigit(spec[read_idx])) { + while (isdigit(spec.at(read_idx))) { num = 10 * num + (spec[read_idx] - '0'); read_idx++; } @@ -470,19 +484,22 @@ inline arr read_node_layout(const StrTreeSpec& spec, size_t& read_idx) { arr ret(child_num); size_t child_idx = 0; - while (spec[read_idx] == Config::kChildrenDataSep) { + while (spec.at(read_idx) == Config::kChildrenDataSep) { ++read_idx; - ret[child_idx++] = read_number(spec, read_idx); + ret.at(child_idx++) = read_number(spec, read_idx); } return ret; } +// spec_data comes from pre_parse, which guarantees 1) +// spec_data.size() == spec.size() and 2) contents of spec_data are +// in-bounds indices for spec, so we omit bounds checks for spec_data. template TreeSpec from_str_internal( const StrTreeSpec& spec, size_t read_idx, const arr& spec_data) { - const auto kind_char = spec[read_idx]; + const auto kind_char = spec.at(read_idx); switch (kind_char) { case Config::kTuple: case Config::kNamedTuple: @@ -496,7 +513,7 @@ TreeSpec from_str_internal( } else if (Config::kCustom == kind_char) { kind = Kind::Custom; read_idx++; - assert(spec[read_idx] == '('); + assert(spec.at(read_idx) == '('); auto type_str_end = spec_data[read_idx]; read_idx++; custom_type = spec.substr(read_idx, type_str_end - read_idx); @@ -515,10 +532,14 @@ TreeSpec from_str_internal( size_t leaves_offset = 0; if (size > 0) { - while (spec[read_idx] != Config::kNodeDataEnd) { + while (spec.at(read_idx) != Config::kNodeDataEnd) { // NOLINTNEXTLINE auto next_delim_idx = spec_data[read_idx]; read_idx++; + if (child_idx >= size) { + throw std::out_of_range( + "bounds check failed writing to pytree item"); + } c->items[child_idx] = from_str_internal(spec, read_idx, spec_data); read_idx = next_delim_idx; @@ -541,11 +562,14 @@ TreeSpec from_str_internal( size_t leaves_offset = 0; if (size > 0) { - while (spec[read_idx] != Config::kNodeDataEnd) { + while (spec.at(read_idx) != Config::kNodeDataEnd) { // NOLINTNEXTLINE auto next_delim_idx = spec_data[read_idx]; read_idx++; - if (spec[read_idx] == Config::kDictStrKeyQuote) { + if (child_idx >= size) { + throw std::out_of_range("bounds check failed decoding pytree dict"); + } + if (spec.at(read_idx) == Config::kDictStrKeyQuote) { auto key_delim_idx = spec_data[read_idx]; read_idx++; const size_t key_len = key_delim_idx - read_idx; @@ -562,7 +586,7 @@ TreeSpec from_str_internal( c->items[child_idx] = from_str_internal(spec, read_idx, spec_data); read_idx = next_delim_idx; - leaves_offset += layout[child_idx++]; + leaves_offset += layout.at(child_idx++); } } else { read_idx++; @@ -605,7 +629,9 @@ struct stack final { } }; +// We guarantee indicies in the result are in bounds. inline arr pre_parse(const StrTreeSpec& spec) { + // Invariant: indices in stack are in bounds. stack> stack; size_t i = 0; const size_t size = spec.size(); @@ -627,11 +653,15 @@ inline arr pre_parse(const StrTreeSpec& spec) { case Config::kDictStrKeyQuote: { size_t idx = i; i++; - while (spec[i] != Config::kDictStrKeyQuote) { + while (spec.at(i) != Config::kDictStrKeyQuote) { i++; } - ret[idx] = i; - ret[i] = idx; + if (i >= size) { + throw std::out_of_range( + "bounds check failed while parsing dictionary key"); + } + ret.at(idx) = i; + ret.at(i) = idx; break; } case Config::kChildrenSep: { diff --git a/extension/pytree/test/test_pytree.cpp b/extension/pytree/test/test_pytree.cpp index 0101bca3f55..39b58d8ce2d 100644 --- a/extension/pytree/test/test_pytree.cpp +++ b/extension/pytree/test/test_pytree.cpp @@ -183,3 +183,22 @@ TEST(pytree, FlattenNestedDict) { ASSERT_EQ(*leaves[i], items[i]); } } + +TEST(pytree, EmptySpec) { + Leaf items[1] = {9}; + EXPECT_THROW(unflatten("", items), std::out_of_range); +} + +TEST(pytree, BoundsCheckListLayout) { + // Malformed: layout one child, have two + std::string spec = "L1#1($,$)"; + Leaf items[2] = {11, 12}; + EXPECT_THROW(unflatten(spec, items), std::out_of_range); +} + +TEST(pytree, BoundsCheckDictLayout) { + // Malformed: layout one child, have two. + std::string spec = "D1#1('key0':$,'key1':$)"; + Leaf items[2] = {11, 12}; + EXPECT_THROW(unflatten(spec, items), std::out_of_range); +}