diff --git a/extension/tensor/tensor_ptr.h b/extension/tensor/tensor_ptr.h index 692ef9eec64..8e7869bfa55 100644 --- a/extension/tensor/tensor_ptr.h +++ b/extension/tensor/tensor_ptr.h @@ -36,17 +36,13 @@ class TensorPtr { return static_cast(tensor_impl_); } - exec_aten::Tensor* get() const { - return tensor_impl_ ? &tensor_ : nullptr; - } - exec_aten::Tensor* operator->() const { - return get(); + return tensor_impl_ ? &tensor_ : nullptr; } exec_aten::Tensor& operator*() const { ET_DCHECK(*this != nullptr); - return *get(); + return *operator->(); } void reset() { diff --git a/extension/tensor/test/tensor_ptr_test.cpp b/extension/tensor/test/tensor_ptr_test.cpp index 800955dcbfb..47b9c20dfc6 100644 --- a/extension/tensor/test/tensor_ptr_test.cpp +++ b/extension/tensor/test/tensor_ptr_test.cpp @@ -26,21 +26,17 @@ TEST_F(TensorPtrTest, BasicSmartPointerAccess) { TensorPtr p; EXPECT_FALSE(p); EXPECT_EQ(p, nullptr); - EXPECT_EQ(p.get(), nullptr); EXPECT_EQ(p.operator->(), nullptr); TensorPtr p2 = make_tensor_ptr({1}, nullptr, {}, {}); EXPECT_TRUE(p2); EXPECT_NE(p2, nullptr); - ASSERT_NE(p2.get(), nullptr); ASSERT_NE(p2.operator->(), nullptr); - EXPECT_EQ(p2.get(), p2.operator->()); EXPECT_EQ(p2->dim(), 1); EXPECT_EQ((*p2).dim(), 1); EXPECT_NE(p, p2); p2.reset(); EXPECT_FALSE(p2); EXPECT_EQ(p2, nullptr); - EXPECT_EQ(p2.get(), nullptr); EXPECT_EQ(p2.operator->(), nullptr); EXPECT_EQ(p, p2); } diff --git a/extension/training/examples/XOR/train.cpp b/extension/training/examples/XOR/train.cpp index 26ab3f9c67a..bca433fd889 100644 --- a/extension/training/examples/XOR/train.cpp +++ b/extension/training/examples/XOR/train.cpp @@ -86,8 +86,8 @@ int main(int argc, char** argv) { for (int i = 0; i < num_epochs; i++) { int index = dist(URBG); auto& data = data_set[index]; - const auto& results = mod.execute_forward_backward( - "forward", {*data.first.get(), *data.second.get()}); + const auto& results = + mod.execute_forward_backward("forward", {*data.first, *data.second}); if (results.error() != Error::Ok) { ET_LOG(Error, "Failed to execute forward_backward"); return 1;