Skip to content

Commit

Permalink
Implement array reference with any (#4096)
Browse files Browse the repository at this point in the history
  • Loading branch information
lisitsyn committed Jan 21, 2018
1 parent a66e13f commit 2b3a863
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 24 deletions.
45 changes: 45 additions & 0 deletions src/shogun/lib/any.h
Expand Up @@ -37,6 +37,7 @@

#include <shogun/base/init.h>

#include <algorithm>
#include <limits>
#include <stdexcept>
#include <string.h>
Expand Down Expand Up @@ -80,6 +81,28 @@ namespace shogun
template <class T>
class SGMatrix;

template <class T, class S>
class ArrayReference
{
public:
ArrayReference(T** ptr, S* length) : m_ptr(ptr), m_length(length)
{
}
ArrayReference(const ArrayReference<T, S>& other)
: m_ptr(other.m_ptr), m_length(other.m_length)
{
}
ArrayReference<T, S> operator=(const ArrayReference<T, S>& other)
{
throw std::logic_error("Assignment not supported");
}
bool equals(const ArrayReference<T, S>& other) const;

private:
T** m_ptr;
S* m_length;
};

/** Used to denote an empty Any object */
struct Empty
{
Expand Down Expand Up @@ -239,6 +262,22 @@ namespace shogun
using any_detail::mutable_value_of;
using any_detail::compare;

template <class T, class S>
bool ArrayReference<T, S>::equals(const ArrayReference<T, S>& other) const
{
if (*(m_length) != *(other.m_length))
{
return false;
}
if (*(m_ptr) == *(other.m_ptr))
{
return true;
}
return std::equal(
*(m_ptr), *(m_ptr) + *(m_length), *(other.m_ptr),
[](T lhs, T rhs) -> bool { return any_detail::compare(lhs, rhs); });
}

/** @brief An interface for a policy to store a value.
* Value can be any data like primitive data-types, shogun objects, etc.
* Policy defines how to handle this data. It works with a
Expand Down Expand Up @@ -759,6 +798,12 @@ namespace shogun
return Any(non_owning_policy<T>(), v);
}

template <typename T, typename S>
inline Any make_any_ref_array(T** ptr, S* length)
{
return make_any(ArrayReference<T, S>(ptr, length));
}

/** Tries to recall Any type, fails when type is wrong.
* Any stores type information of an object internally in a BaseAnyPolicy.
* This function returns type-casted value if the internal type information
Expand Down
72 changes: 48 additions & 24 deletions tests/unit/lib/Any_unittest.cc
Expand Up @@ -30,9 +30,8 @@
*/
#include <gtest/gtest.h>

#include <numeric>
#include <shogun/base/SGObject.h>
#include <shogun/lib/SGMatrix.h>
#include <shogun/lib/SGVector.h>
#include <shogun/lib/any.h>
#include <shogun/lib/config.h>
#include <stdexcept>
Expand Down Expand Up @@ -60,6 +59,23 @@ struct Simple
bool cloned = false;
};

struct SimpleValue
{
public:
SimpleValue clone() const
{
SimpleValue copy;
copy.cloned = true;
return copy;
}
bool equals(const SimpleValue& other) const
{
return true;
}

bool cloned = false;
};

TEST(Any, as)
{
int32_t integer = 10;
Expand Down Expand Up @@ -356,34 +372,42 @@ TEST(Any, clone_into_owning_via_copy)
EXPECT_EQ(a_any.as<int>(), other);
}

TEST(Any, clone_sgvector)
TEST(Any, clone_value)
{
auto a = SGVector<float64_t>(3);
a.range_fill();
SGVector<float64_t> b;
ASSERT_FALSE(a.equals(b));
auto a_any = make_any(SimpleValue());
auto b_any = make_any(SimpleValue());
auto cloned_b = b_any.clone_from(a_any).as<SimpleValue>();
}

auto a_any = make_any(a);
auto b_any = make_any(b);
TEST(Any, array_ref)
{
int src_len = 5;
float* src = new float[src_len];
std::iota(src, src + src_len, 9);
int dst_len = 8;
float* dst = new float[dst_len];
std::iota(dst, dst + dst_len, 5);
int other_len = src_len;
float* other = new float[other_len];
std::iota(other, other + other_len, 9);

auto cloned_b = b_any.clone_from(a_any).as<SGVector<float64_t>>();
auto any_src = make_any_ref_array(&src, &src_len);
auto any_dst = make_any_ref_array(&dst, &dst_len);
auto any_other = make_any_ref_array(&other, &other_len);

EXPECT_NE(a.vector, cloned_b.vector);
EXPECT_TRUE(a.equals(cloned_b));
}
EXPECT_EQ(any_src, any_src);
EXPECT_EQ(any_dst, any_dst);

TEST(Any, clone_sgmatrix)
{
auto a = SGMatrix<float64_t>(3, 4);
SGVector<float64_t>(a.matrix, a.num_rows * a.num_cols, false).range_fill();
SGMatrix<float64_t> b;
ASSERT_FALSE(a.equals(b));
EXPECT_NE(any_src, any_dst);
EXPECT_NE(any_dst, any_src);

auto a_any = make_any(a);
auto b_any = make_any(b);
EXPECT_EQ(any_src, any_other);
EXPECT_EQ(any_other, any_src);

auto cloned_b = b_any.clone_from(a_any).as<SGMatrix<float64_t>>();
EXPECT_NE(any_dst, any_other);
EXPECT_NE(any_other, any_dst);

EXPECT_NE(a.matrix, cloned_b.matrix);
EXPECT_TRUE(a.equals(cloned_b));
delete[] src;
delete[] dst;
delete[] other;
}

0 comments on commit 2b3a863

Please sign in to comment.