diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp index 6c505f8b656c..3086fa18add6 100644 --- a/aten/src/ATen/core/ivalue.cpp +++ b/aten/src/ATen/core/ivalue.cpp @@ -887,12 +887,12 @@ c10::intrusive_ptr ivalue::Object::create( } IValue IValue::deepcopy(std::optional device) const { - IValue::HashAliasedIValueMap memo; + IValue::HashIdentityIValueMap memo; return deepcopy(memo, device); } IValue IValue::deepcopy( - IValue::HashAliasedIValueMap& memo, + IValue::HashIdentityIValueMap& memo, std::optional device) const { if (memo.count(*this)) { return memo.at(*this); @@ -1028,12 +1028,12 @@ c10::intrusive_ptr ivalue::Object::copy_to_weak_compilation_ref( c10::intrusive_ptr ivalue::Object::deepcopy( std::optional device) const { - IValue::HashAliasedIValueMap memo; + IValue::HashIdentityIValueMap memo; return deepcopy(memo, device); } c10::intrusive_ptr ivalue::Object::deepcopy( - IValue::HashAliasedIValueMap& memo, + IValue::HashIdentityIValueMap& memo, std::optional device) const { auto cu = type_.cu_; auto object = ivalue::Object::create(WeakOrStrongTypePtr(type_.cu_, type_.type_), type()->numAttributes()); diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index 7715ffbe3c31..922b10b8efeb 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -1117,6 +1117,23 @@ struct TORCH_API IValue final { using HashAliasedIValueMap = std::unordered_map; + struct HashIdentityIValue { + size_t operator()(const IValue& val) const { + return val.payload.u.as_int; + } + }; + + struct CompIdentityIValues { + bool operator()(const IValue& lhs, const IValue& rhs) const { + return lhs.is(rhs); + } + }; + + using HashIdentityIValues = + std::unordered_set; + using HashIdentityIValueMap = + std::unordered_map; + // Chechs if this and rhs has a subvalues in common. // [t1,t2] and [t2, t3] returns true. bool overlaps(const IValue& rhs) const; @@ -1130,7 +1147,7 @@ struct TORCH_API IValue final { void visit(const std::function& visitor) const; IValue deepcopy(std::optional device = c10::nullopt) const; IValue deepcopy( - HashAliasedIValueMap& memo, + HashIdentityIValueMap& memo, std::optional device = c10::nullopt) const; private: diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h index b1124c12cfb3..b99229f2759c 100644 --- a/aten/src/ATen/core/ivalue_inl.h +++ b/aten/src/ATen/core/ivalue_inl.h @@ -1589,7 +1589,7 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target { std::optional device = c10::nullopt) const; c10::intrusive_ptr deepcopy( - IValue::HashAliasedIValueMap& memo, + IValue::HashIdentityIValueMap& memo, std::optional device = c10::nullopt) const; bool is_weak_compilation_ref() const { diff --git a/test/cpp/api/CMakeLists.txt b/test/cpp/api/CMakeLists.txt index 42b67d8cb25c..b0e296ad2309 100644 --- a/test/cpp/api/CMakeLists.txt +++ b/test/cpp/api/CMakeLists.txt @@ -10,6 +10,7 @@ set(TORCH_API_TEST_SOURCES ${TORCH_API_TEST_DIR}/functional.cpp ${TORCH_API_TEST_DIR}/init.cpp ${TORCH_API_TEST_DIR}/integration.cpp + ${TORCH_API_TEST_DIR}/ivalue.cpp ${TORCH_API_TEST_DIR}/jit.cpp ${TORCH_API_TEST_DIR}/memory.cpp ${TORCH_API_TEST_DIR}/meta_tensor.cpp diff --git a/test/cpp/api/ivalue.cpp b/test/cpp/api/ivalue.cpp new file mode 100644 index 000000000000..fa8dcc25cd4d --- /dev/null +++ b/test/cpp/api/ivalue.cpp @@ -0,0 +1,63 @@ +#include + +#include + +#include +#include +#include + +#include + +#include + +#include +#include +#include +#include +#include + +using namespace torch::test; +using namespace torch::nn; +using namespace torch::optim; + +TEST(IValueTest, DeepcopyTensors) { + torch::Tensor t0 = torch::randn({2, 3}); + torch::Tensor t1 = torch::randn({3, 4}); + torch::Tensor t2 = t0.detach(); + torch::Tensor t3 = t0; + torch::Tensor t4 = t1.as_strided({2, 3}, {3, 1}, 2); + std::vector tensor_vector = {t0, t1, t2, t3, t4}; + c10::List tensor_list(tensor_vector); + torch::IValue tensor_list_ivalue(tensor_list); + + c10::IValue::CompIdentityIValues ivalue_compare; + + // Make sure our setup configuration is correct + ASSERT_TRUE(ivalue_compare(tensor_list[0].get(), tensor_list[3].get())); + ASSERT_FALSE(ivalue_compare(tensor_list[0].get(), tensor_list[1].get())); + ASSERT_FALSE(ivalue_compare(tensor_list[0].get(), tensor_list[2].get())); + ASSERT_FALSE(ivalue_compare(tensor_list[1].get(), tensor_list[4].get())); + ASSERT_TRUE(tensor_list[0].get().isAliasOf(tensor_list[2].get())); + + c10::IValue copied_ivalue = tensor_list_ivalue.deepcopy(); + c10::List copied_list = copied_ivalue.toList(); + + // Make sure our setup configuration is correct + ASSERT_TRUE(ivalue_compare(copied_list[0].get(), copied_list[3].get())); + ASSERT_FALSE(ivalue_compare(copied_list[0].get(), copied_list[1].get())); + ASSERT_FALSE(ivalue_compare(copied_list[0].get(), copied_list[2].get())); + ASSERT_FALSE(ivalue_compare(copied_list[1].get(), copied_list[4].get())); + // NOTE: this is actually incorrect. Ideally, these _should_ be aliases. + ASSERT_FALSE(copied_list[0].get().isAliasOf(copied_list[2].get())); + + ASSERT_TRUE(copied_list[0].get().toTensor().allclose( + tensor_list[0].get().toTensor())); + ASSERT_TRUE(copied_list[1].get().toTensor().allclose( + tensor_list[1].get().toTensor())); + ASSERT_TRUE(copied_list[2].get().toTensor().allclose( + tensor_list[2].get().toTensor())); + ASSERT_TRUE(copied_list[3].get().toTensor().allclose( + tensor_list[3].get().toTensor())); + ASSERT_TRUE(copied_list[4].get().toTensor().allclose( + tensor_list[4].get().toTensor())); +} diff --git a/torch/csrc/jit/api/module.cpp b/torch/csrc/jit/api/module.cpp index 1b9932ed34d4..45b99eb8e47a 100644 --- a/torch/csrc/jit/api/module.cpp +++ b/torch/csrc/jit/api/module.cpp @@ -323,7 +323,7 @@ Module Module::deepcopy(std::optional device) const { Module Module::clone(bool inplace) const { std::unordered_map type_remap; - IValue::HashAliasedIValueMap memo; + IValue::HashIdentityIValueMap memo; const std::unordered_set ignored_methods; const std::unordered_set ignored_attributes; return clone_impl( @@ -335,7 +335,7 @@ Module Module::clone( const std::unordered_set& ignored_methods, const std::unordered_set& ignored_attributes) const { std::unordered_map type_remap; - IValue::HashAliasedIValueMap memo; + IValue::HashIdentityIValueMap memo; return clone_impl( type_remap, inplace, memo, ignored_methods, ignored_attributes); } @@ -343,7 +343,7 @@ Module Module::clone( Module Module::clone_impl( std::unordered_map& type_remap, bool inplace, - IValue::HashAliasedIValueMap memo, + IValue::HashIdentityIValueMap memo, const std::unordered_set& ignored_methods, const std::unordered_set& ignored_attributes) const { // Create a new _ivalue in the same compilation unit. diff --git a/torch/csrc/jit/api/module.h b/torch/csrc/jit/api/module.h index 0787210a4aef..e779542e315f 100644 --- a/torch/csrc/jit/api/module.h +++ b/torch/csrc/jit/api/module.h @@ -301,7 +301,7 @@ struct TORCH_API Module : public Object { Module clone_impl( std::unordered_map& type_remap, bool inplace, - IValue::HashAliasedIValueMap memo, + IValue::HashIdentityIValueMap memo, const std::unordered_set& ignored_methods, const std::unordered_set& ignored_attributes) const; diff --git a/torch/csrc/jit/passes/quantization/insert_observers.cpp b/torch/csrc/jit/passes/quantization/insert_observers.cpp index e5df64f1929c..de1cff1ba9d1 100644 --- a/torch/csrc/jit/passes/quantization/insert_observers.cpp +++ b/torch/csrc/jit/passes/quantization/insert_observers.cpp @@ -92,7 +92,7 @@ class ModuleCloneHelper { const ModuleQConfigMap& module_qconfig_map, bool inplace = false) { std::unordered_map type_remap; - IValue::HashAliasedIValueMap memo; + IValue::HashIdentityIValueMap memo; return clone_impl( module, module_qconfig_map, type_remap, inplace, std::move(memo)); } @@ -103,7 +103,7 @@ class ModuleCloneHelper { const ModuleQConfigMap& module_qconfig_map, std::unordered_map& type_remap, bool inplace, - IValue::HashAliasedIValueMap memo) { + IValue::HashIdentityIValueMap memo) { auto qconfig = module_qconfig_map.at(module._ivalue()); auto type = module.type(); // Create a new _ivalue in the same compilation unit. diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index 971b6c76ca47..c46762a88615 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -668,13 +668,13 @@ static constexpr std::array magic_method_names = { }; struct DeepCopyMemoTable { - std::shared_ptr map; + std::shared_ptr map; }; IValue pyIValueDeepcopy(const IValue& ivalue, const py::dict& memo) { if (!memo.contains(py::str("__torch_script_memo_table"))) { memo["__torch_script_memo_table"] = - DeepCopyMemoTable{std::make_shared()}; + DeepCopyMemoTable{std::make_shared()}; } auto& ivalue_memo = *py::cast(memo["__torch_script_memo_table"]).map;