From 043c7af3cf483953a0ffd77be01f6418cad723e4 Mon Sep 17 00:00:00 2001 From: Github Executorch Date: Tue, 14 Jan 2025 11:08:21 -0800 Subject: [PATCH] Replace pytree_assert with production pytree_check. Remove pytree_unreachable When handling untrusted input, it's not appropriate to use debug-only checks; we should be checking in prod as these are not programmer errors. pytree_unreachable was similarly being used for input validation. Differential Revision: [D68166301](https://our.internmc.facebook.com/intern/diff/D68166301/) [ghstack-poisoned] --- extension/pytree/pytree.h | 58 +++++++++++++++++---------------------- 1 file changed, 25 insertions(+), 33 deletions(-) diff --git a/extension/pytree/pytree.h b/extension/pytree/pytree.h index 5b88fca3e41..5e261b4b90a 100644 --- a/extension/pytree/pytree.h +++ b/extension/pytree/pytree.h @@ -24,8 +24,10 @@ namespace executorch { namespace extension { namespace pytree { -inline void pytree_assert(bool must_be_true) { - assert(must_be_true); +void pytree_check(bool must_be_true) { + if (!must_be_true) { + throw std::runtime_error("pytree assertion failed"); + } } #ifdef _MSC_VER @@ -36,18 +38,6 @@ inline void pytree_assert(bool must_be_true) { #define EXECUTORCH_ALWAYS_INLINE inline #endif -[[noreturn]] EXECUTORCH_ALWAYS_INLINE void pytree_unreachable() { - assert(false); -#if defined(__GNUC__) - __builtin_unreachable(); -#elif defined(_MSC_VER) - __assume(0); -#else - while (!0) - ; -#endif -} - enum class Kind : uint8_t { List, Tuple, NamedTuple, Dict, Leaf, Custom, None }; using KeyStr = std::string; @@ -143,45 +133,45 @@ struct ContainerHandle { : handle(std::move(c)) {} void set_leaf(leaf_type* leaf) { - pytree_assert(handle->kind == Kind::Leaf); + pytree_check(handle->kind == Kind::Leaf); handle->leaf = leaf; } operator leaf_type() const { - pytree_assert(handle->kind == Kind::Leaf); + pytree_check(handle->kind == Kind::Leaf); return *handle->leaf; } const leaf_type& leaf() const { - pytree_assert(handle->kind == Kind::Leaf); + pytree_check(handle->kind == Kind::Leaf); return *handle->leaf; } leaf_type& leaf() { - pytree_assert(handle->kind == Kind::Leaf); + pytree_check(handle->kind == Kind::Leaf); return *handle->leaf; } const leaf_type* leaf_ptr() const { - pytree_assert(handle->kind == Kind::Leaf); + pytree_check(handle->kind == Kind::Leaf); return handle->leaf; } leaf_type* leaf_ptr() { - pytree_assert(handle->kind == Kind::Leaf); + pytree_check(handle->kind == Kind::Leaf); return handle->leaf; } const ContainerHandle& operator[](size_t idx) const { - pytree_assert(idx < handle->size); + pytree_check(idx < handle->size); return handle->items[idx]; } ContainerHandle& operator[](size_t idx) { - pytree_assert(idx < handle->size); + pytree_check(idx < handle->size); return handle->items[idx]; } bool contains(const KeyStr& lookup_key) const { - pytree_assert(isDict()); + pytree_check(isDict()); for (size_t i = 0; i < handle->size; ++i) { if (handle->keys[i] == lookup_key) { return true; @@ -191,13 +181,13 @@ struct ContainerHandle { } const ContainerHandle& at(const Key& lookup_key) const { - pytree_assert(isDict()); + pytree_check(isDict()); for (size_t i = 0; i < handle->size; ++i) { if (handle->keys[i] == lookup_key) { return handle->items[i]; } } - pytree_unreachable(); + throw std::runtime_error("Dict::at lookup failed"); } const ContainerHandle& at(const KeyInt& lookup_key) const { @@ -209,11 +199,11 @@ struct ContainerHandle { } const Key& key(size_t idx) const { - pytree_assert(isDict()); + pytree_check(isDict()); return handle->keys[idx]; } Key& key(size_t idx) { - pytree_assert(isDict()); + pytree_check(isDict()); return handle->keys[idx]; } @@ -398,7 +388,8 @@ StrTreeSpec to_str_internal(const TreeSpec& spec) { s.append(key.as_str()); s.push_back(Config::kDictStrKeyQuote); } else { - pytree_unreachable(); + throw std::runtime_error( + "invalid key in pytree dict; must be int or string"); } s.push_back(Config::kDictKeyValueSep); s.append(to_str_internal(spec[i])); @@ -472,6 +463,9 @@ struct arr { inline size_t read_number(const StrTreeSpec& spec, size_t& read_idx) { size_t num = 0; + if (!isdigit(spec.at(read_idx))) { + throw std::runtime_error("expected a number while decoding pytree"); + } while (isdigit(spec.at(read_idx))) { num = 10 * num + (spec[read_idx] - '0'); read_idx++; @@ -577,7 +571,6 @@ TreeSpec from_str_internal( c->keys[child_idx] = spec.substr(read_idx, key_len); read_idx = key_delim_idx + 2; } else { - pytree_assert(isdigit(spec[read_idx])); size_t key = read_number(spec, read_idx); c->keys[child_idx] = KeyInt(key); read_idx += 1; @@ -598,7 +591,6 @@ TreeSpec from_str_internal( case Config::kLeaf: return new TreeSpecContainer(nullptr); } - pytree_unreachable(); return new TreeSpecContainer(Kind::None); } @@ -610,17 +602,17 @@ struct stack final { T data[SIZE]; void push(T&& item) { - pytree_assert(size_ < SIZE); + pytree_check(size_ < SIZE); data[size_++] = std::move(item); } T pop() { - pytree_assert(size_ > 0); + pytree_check(size_ > 0); return data[--size_]; } T& top() { - pytree_assert(size_ > 0); + pytree_check(size_ > 0); return data[size_ - 1]; }