Skip to content

Commit

Permalink
Update on "Use object identity for deepcopy memo"
Browse files Browse the repository at this point in the history
Copy of #126089, with some fixes & tests (TODO)

[TODO description]


[ghstack-poisoned]
  • Loading branch information
davidberard98 committed May 14, 2024
1 parent 2a7cfe5 commit 2326b7d
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 0 deletions.
1 change: 1 addition & 0 deletions test/cpp/api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ set(TORCH_API_TEST_SOURCES
${TORCH_API_TEST_DIR}/functional.cpp
${TORCH_API_TEST_DIR}/init.cpp
${TORCH_API_TEST_DIR}/integration.cpp
${TORCH_API_TEST_DIR}/ivalue.cpp
${TORCH_API_TEST_DIR}/jit.cpp
${TORCH_API_TEST_DIR}/memory.cpp
${TORCH_API_TEST_DIR}/meta_tensor.cpp
Expand Down
58 changes: 58 additions & 0 deletions test/cpp/api/ivalue.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#include <gtest/gtest.h>

#include <ATen/core/ivalue.h>

#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <c10/util/tempfile.h>

#include <torch/torch.h>

#include <test/cpp/api/support.h>

#include <cstdio>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

using namespace torch::test;
using namespace torch::nn;
using namespace torch::optim;

TEST(IValueTest, DeepcopyTensors) {
torch::Tensor t0 = torch::randn({2, 3});
torch::Tensor t1 = torch::randn({3, 4});
torch::Tensor t2 = t0.detach();
torch::Tensor t3 = t0;
torch::Tensor t4 = t1.as_strided({2, 3}, {3, 1}, 2);
std::vector<torch::Tensor> tensor_vector = {t0, t1, t2, t3, t4};
c10::List<torch::Tensor> tensor_list(tensor_vector);
torch::IValue tensor_list_ivalue(tensor_list);

c10::IValue::CompIdentityIValues ivalue_compare;

// Make sure our setup configuration is correct
ASSERT_TRUE(ivalue_compare(tensor_list[0].get(), tensor_list[3].get()));
ASSERT_FALSE(ivalue_compare(tensor_list[0].get(), tensor_list[1].get()));
ASSERT_FALSE(ivalue_compare(tensor_list[0].get(), tensor_list[2].get()));
ASSERT_FALSE(ivalue_compare(tensor_list[1].get(), tensor_list[4].get()));
ASSERT_TRUE(tensor_list[0].get().isAliasOf(tensor_list[2].get()));

c10::IValue copied_ivalue = tensor_list_ivalue.deepcopy();
c10::List<torch::IValue> copied_list = copied_ivalue.toList();

// Make sure our setup configuration is correct
ASSERT_TRUE(ivalue_compare(copied_list[0].get(), copied_list[3].get()));
ASSERT_FALSE(ivalue_compare(copied_list[0].get(), copied_list[1].get()));
ASSERT_FALSE(ivalue_compare(copied_list[0].get(), copied_list[2].get()));
ASSERT_FALSE(ivalue_compare(copied_list[1].get(), copied_list[4].get()));
// NOTE: this is actually incorrect. Ideally, these _should_ be aliases.
ASSERT_FALSE(copied_list[0].get().isAliasOf(copied_list[2].get()));

ASSERT_TRUE(copied_list[0].get().toTensor().allclose(tensor_list[0].get().toTensor()));
ASSERT_TRUE(copied_list[1].get().toTensor().allclose(tensor_list[1].get().toTensor()));
ASSERT_TRUE(copied_list[2].get().toTensor().allclose(tensor_list[2].get().toTensor()));
ASSERT_TRUE(copied_list[3].get().toTensor().allclose(tensor_list[3].get().toTensor()));
ASSERT_TRUE(copied_list[4].get().toTensor().allclose(tensor_list[4].get().toTensor()));
}

0 comments on commit 2326b7d

Please sign in to comment.