Skip to content

Commit

Permalink
Support clone of SGObjects' arrays (#4239)
Browse files Browse the repository at this point in the history
  • Loading branch information
lisitsyn authored and karlnapf committed Apr 12, 2018
1 parent aac83db commit 3afb987
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 2 deletions.
14 changes: 14 additions & 0 deletions src/shogun/lib/any.cpp
Expand Up @@ -40,6 +40,20 @@ namespace shogun
namespace any_detail
{

inline CSGObject* copy_object(CSGObject* object)
{
if (!object)
{
return nullptr;
}
return object->clone();
}

void copy_array(CSGObject** begin, CSGObject** end, CSGObject** dst)
{
std::transform(begin, end, dst, copy_object);
}

#ifndef REAL_COMPARE_IMPL
#define REAL_COMPARE_IMPL(real_t) \
template <> \
Expand Down
20 changes: 18 additions & 2 deletions src/shogun/lib/any.h
Expand Up @@ -327,18 +327,34 @@ namespace shogun
template <class T, class S>
inline auto free_array(T* ptr, S size)
{
if (!ptr)
{
return;
}
SG_FREE(ptr);
}

template <class S>
inline auto free_array(CSGObject** ptr, S size)
{
if (!ptr)
{
return;
}
for (S i = 0; i < size; ++i)
{
ptr[i]->unref();
}
SG_FREE(ptr);
}

template <class T>
inline void copy_array(T* begin, T* end, T* dst)
{
std::copy(begin, end, dst);
}

void copy_array(CSGObject** begin, CSGObject** end, CSGObject** dst);
}

using any_detail::typed_pointer;
Expand Down Expand Up @@ -371,7 +387,7 @@ namespace shogun
any_detail::free_array(dst, len);
dst = new T[len];
*(this->m_length) = len;
std::copy(src, src + len, dst);
any_detail::copy_array(src, src + len, dst);
}

template <class T, class S>
Expand Down Expand Up @@ -403,7 +419,7 @@ namespace shogun
dst = new T[rows * cols];
*(this->m_rows) = rows;
*(this->m_cols) = cols;
std::copy(src, src + (rows * cols), dst);
any_detail::copy_array(src, src + (rows * cols), dst);
}

/** @brief An interface for a policy to store a value.
Expand Down
59 changes: 59 additions & 0 deletions tests/unit/lib/Any_unittest.cc
Expand Up @@ -84,6 +84,11 @@ class Object : public CSGObject
{
return "Object";
}

CSGObject* create_empty() const override
{
return new Object();
}
};

TEST(Any, as)
Expand Down Expand Up @@ -507,6 +512,60 @@ TEST(Any, clone_array2d)
delete[] src;
}

TEST(Any, clone_array_sgobject)
{
int src_len = 3;
CSGObject** src = new CSGObject*[src_len];
std::generate(src, src + src_len, []() { return new Object(); });
auto any_src = make_any_ref(&src, &src_len);

CSGObject** dst = nullptr;
int dst_len = 0;
auto any_dst = make_any_ref(&dst, &dst_len);
any_dst.clone_from(any_src);

EXPECT_EQ(src_len, dst_len);
EXPECT_NE(src, dst);
for (int i = 0; i < dst_len; i++)
{
EXPECT_NE(src[i], dst[i]);
EXPECT_TRUE(src[i]->equals(dst[i]));
}
EXPECT_NE(dst, nullptr);
EXPECT_EQ(any_src, any_dst);

delete[] src;
}

TEST(Any, clone_array2d_sgobject)
{
int src_rows = 5;
int src_cols = 4;
int src_size = src_rows * src_cols;
CSGObject** src = new CSGObject*[src_size];
std::generate(src, src + src_size, []() { return new Object(); });
auto any_src = make_any_ref(&src, &src_rows, &src_cols);

int dst_rows = 0;
int dst_cols = 0;
CSGObject** dst = nullptr;
auto any_dst = make_any_ref(&dst, &dst_rows, &dst_cols);
any_dst.clone_from(any_src);

EXPECT_EQ(src_rows, dst_rows);
EXPECT_EQ(src_cols, dst_cols);
EXPECT_NE(src, dst);
for (int i = 0; i < dst_rows * dst_cols; i++)
{
EXPECT_NE(src[i], dst[i]);
EXPECT_TRUE(src[i]->equals(dst[i]));
}
EXPECT_NE(dst, nullptr);
EXPECT_EQ(any_src, any_dst);

delete[] src;
}

TEST(Any, free_array_simple)
{
auto size = 4;
Expand Down

0 comments on commit 3afb987

Please sign in to comment.