diff --git a/extension/pytree/pytree.h b/extension/pytree/pytree.h index 78e2305fe3e..515ba8b7402 100644 --- a/extension/pytree/pytree.h +++ b/extension/pytree/pytree.h @@ -15,6 +15,7 @@ #include #include #include +#include // NB: This is a local, pytree FunctionRef and not from the ExecuTorch runtime. #include @@ -55,29 +56,28 @@ using KeyInt = int32_t; struct Key { enum class Kind : uint8_t { None, Int, Str } kind_; - KeyInt as_int_ = {}; - KeyStr as_str_ = {}; + private: + std::variant repr_; - Key() : kind_(Kind::None) {} - /*implicit*/ Key(KeyInt key) : kind_(Kind::Int), as_int_(std::move(key)) {} - /*implicit*/ Key(KeyStr key) : kind_(Kind::Str), as_str_(std::move(key)) {} + public: + Key() {} + /*implicit*/ Key(KeyInt key) : repr_(key) {} + /*implicit*/ Key(KeyStr key) : repr_(std::move(key)) {} - const Kind& kind() const { - return kind_; + Kind kind() const { + return static_cast(repr_.index()); } - const KeyInt& as_int() const { - pytree_assert(kind_ == Key::Kind::Int); - return as_int_; + KeyInt as_int() const { + return std::get(repr_); } - operator const KeyInt&() const { + operator KeyInt() const { return as_int(); } const KeyStr& as_str() const { - pytree_assert(kind_ == Key::Kind::Str); - return as_str_; + return std::get(repr_); } operator const KeyStr&() const { @@ -85,21 +85,7 @@ struct Key { } bool operator==(const Key& rhs) const { - if (kind_ != rhs.kind_) { - return false; - } - switch (kind_) { - case Kind::Str: { - return as_str_ == rhs.as_str_; - } - case Kind::Int: { - return as_int_ == rhs.as_int_; - } - case Kind::None: { - return true; - } - } - pytree_unreachable(); + return repr_ == rhs.repr_; } bool operator!=(const Key& rhs) const { @@ -153,6 +139,9 @@ struct ContainerHandle { /*implicit*/ ContainerHandle(container_type* c) : handle(c) {} + /*implicit*/ ContainerHandle(std::unique_ptr c) + : handle(std::move(c)) {} + void set_leaf(leaf_type* leaf) { pytree_assert(handle->kind == Kind::Leaf); handle->leaf = leaf; @@ -500,10 +489,10 @@ TreeSpec from_str_internal( read_idx++; auto layout = read_node_layout(spec, read_idx); const auto size = layout.size(); - auto c = new TreeSpecContainer(kind, size); + auto c = std::make_unique>(kind, size); if (Kind::Custom == kind) { - c->custom_type = custom_type; + c->custom_type = std::move(custom_type); } size_t child_idx = 0; @@ -523,14 +512,14 @@ TreeSpec from_str_internal( read_idx++; } c->leaves_num = leaves_offset; - return c; + return TreeSpec(std::move(c)); } case Config::kDict: { read_idx++; auto layout = read_node_layout(spec, read_idx); const auto size = layout.size(); - auto c = new TreeSpecContainer(Kind::Dict, size); + auto c = std::make_unique>(Kind::Dict, size); size_t child_idx = 0; size_t leaves_offset = 0; @@ -563,7 +552,7 @@ TreeSpec from_str_internal( read_idx++; } c->leaves_num = leaves_offset; - return c; + return TreeSpec(std::move(c)); } case Config::kLeaf: