From 3afb9870d66fece8494c3f54756288f66ecb01c5 Mon Sep 17 00:00:00 2001 From: Sergey Lisitsyn Date: Thu, 12 Apr 2018 10:28:07 +0300 Subject: [PATCH] Support clone of SGObjects' arrays (#4239) --- src/shogun/lib/any.cpp | 14 ++++++++ src/shogun/lib/any.h | 20 ++++++++++-- tests/unit/lib/Any_unittest.cc | 59 ++++++++++++++++++++++++++++++++++ 3 files changed, 91 insertions(+), 2 deletions(-) diff --git a/src/shogun/lib/any.cpp b/src/shogun/lib/any.cpp index 9d055b9c164..fffe492080b 100644 --- a/src/shogun/lib/any.cpp +++ b/src/shogun/lib/any.cpp @@ -40,6 +40,20 @@ namespace shogun namespace any_detail { + inline CSGObject* copy_object(CSGObject* object) + { + if (!object) + { + return nullptr; + } + return object->clone(); + } + + void copy_array(CSGObject** begin, CSGObject** end, CSGObject** dst) + { + std::transform(begin, end, dst, copy_object); + } + #ifndef REAL_COMPARE_IMPL #define REAL_COMPARE_IMPL(real_t) \ template <> \ diff --git a/src/shogun/lib/any.h b/src/shogun/lib/any.h index 164e7ba3b4e..cd3157310ec 100644 --- a/src/shogun/lib/any.h +++ b/src/shogun/lib/any.h @@ -327,18 +327,34 @@ namespace shogun template inline auto free_array(T* ptr, S size) { + if (!ptr) + { + return; + } SG_FREE(ptr); } template inline auto free_array(CSGObject** ptr, S size) { + if (!ptr) + { + return; + } for (S i = 0; i < size; ++i) { ptr[i]->unref(); } SG_FREE(ptr); } + + template + inline void copy_array(T* begin, T* end, T* dst) + { + std::copy(begin, end, dst); + } + + void copy_array(CSGObject** begin, CSGObject** end, CSGObject** dst); } using any_detail::typed_pointer; @@ -371,7 +387,7 @@ namespace shogun any_detail::free_array(dst, len); dst = new T[len]; *(this->m_length) = len; - std::copy(src, src + len, dst); + any_detail::copy_array(src, src + len, dst); } template @@ -403,7 +419,7 @@ namespace shogun dst = new T[rows * cols]; *(this->m_rows) = rows; *(this->m_cols) = cols; - std::copy(src, src + (rows * cols), dst); + any_detail::copy_array(src, src + (rows * cols), dst); } /** @brief An interface for a policy to store a value. diff --git a/tests/unit/lib/Any_unittest.cc b/tests/unit/lib/Any_unittest.cc index eadbd0150a2..314d547c999 100644 --- a/tests/unit/lib/Any_unittest.cc +++ b/tests/unit/lib/Any_unittest.cc @@ -84,6 +84,11 @@ class Object : public CSGObject { return "Object"; } + + CSGObject* create_empty() const override + { + return new Object(); + } }; TEST(Any, as) @@ -507,6 +512,60 @@ TEST(Any, clone_array2d) delete[] src; } +TEST(Any, clone_array_sgobject) +{ + int src_len = 3; + CSGObject** src = new CSGObject*[src_len]; + std::generate(src, src + src_len, []() { return new Object(); }); + auto any_src = make_any_ref(&src, &src_len); + + CSGObject** dst = nullptr; + int dst_len = 0; + auto any_dst = make_any_ref(&dst, &dst_len); + any_dst.clone_from(any_src); + + EXPECT_EQ(src_len, dst_len); + EXPECT_NE(src, dst); + for (int i = 0; i < dst_len; i++) + { + EXPECT_NE(src[i], dst[i]); + EXPECT_TRUE(src[i]->equals(dst[i])); + } + EXPECT_NE(dst, nullptr); + EXPECT_EQ(any_src, any_dst); + + delete[] src; +} + +TEST(Any, clone_array2d_sgobject) +{ + int src_rows = 5; + int src_cols = 4; + int src_size = src_rows * src_cols; + CSGObject** src = new CSGObject*[src_size]; + std::generate(src, src + src_size, []() { return new Object(); }); + auto any_src = make_any_ref(&src, &src_rows, &src_cols); + + int dst_rows = 0; + int dst_cols = 0; + CSGObject** dst = nullptr; + auto any_dst = make_any_ref(&dst, &dst_rows, &dst_cols); + any_dst.clone_from(any_src); + + EXPECT_EQ(src_rows, dst_rows); + EXPECT_EQ(src_cols, dst_cols); + EXPECT_NE(src, dst); + for (int i = 0; i < dst_rows * dst_cols; i++) + { + EXPECT_NE(src[i], dst[i]); + EXPECT_TRUE(src[i]->equals(dst[i])); + } + EXPECT_NE(dst, nullptr); + EXPECT_EQ(any_src, any_dst); + + delete[] src; +} + TEST(Any, free_array_simple) { auto size = 4;