Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement array reference with any #4096

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
38 changes: 38 additions & 0 deletions src/shogun/lib/any.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include <string.h>
#include <string>
#include <typeinfo>
#include <algorithm>
#ifdef HAVE_CXA_DEMANGLE
#include <cxxabi.h>
#endif
Expand Down Expand Up @@ -80,6 +81,37 @@ namespace shogun
template <class T>
class SGMatrix;

template <class T, class S>
class ArrayReference
{
public:
ArrayReference(T** ptr, S* length) : m_ptr(ptr), m_length(length)
{
}
ArrayReference(const ArrayReference<T, S>& other) : m_ptr(other.m_ptr), m_length(other.m_length)
{
}
ArrayReference<T, S> operator=(const ArrayReference<T, S>& other)
{
throw std::logic_error("Assignment not supported");
}
bool equals(const ArrayReference<T, S>& other) const
{
if (*(m_length) != *(other.m_length))
{
return false;
}
if (*(m_ptr) == *(other.m_ptr))
{
return true;
}
return std::equal(*(m_ptr), *(m_ptr) + *(m_length), *(other.m_ptr));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

needs to be specialized for floats I guess?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact need to check whether the elements support ".equals" or "->equals"

}
private:
T** m_ptr;
S* m_length;
};

/** Used to denote an empty Any object */
struct Empty
{
Expand Down Expand Up @@ -759,6 +791,12 @@ namespace shogun
return Any(non_owning_policy<T>(), v);
}

template <typename T, typename S>
inline Any make_any_ref_array(T** ptr, S* length)
{
return make_any(ArrayReference<T, S>(ptr, length));
}

/** Tries to recall Any type, fails when type is wrong.
* Any stores type information of an object internally in a BaseAnyPolicy.
* This function returns type-casted value if the internal type information
Expand Down
72 changes: 48 additions & 24 deletions tests/unit/lib/Any_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,10 @@
#include <gtest/gtest.h>

#include <shogun/base/SGObject.h>
#include <shogun/lib/SGMatrix.h>
#include <shogun/lib/SGVector.h>
#include <shogun/lib/any.h>
#include <shogun/lib/config.h>
#include <stdexcept>
#include <numeric>

using namespace shogun;

Expand All @@ -60,6 +59,23 @@ struct Simple
bool cloned = false;
};

struct SimpleValue
{
public:
SimpleValue clone() const
{
SimpleValue copy;
copy.cloned = true;
return copy;
}
bool equals(const SimpleValue& other) const
{
return true;
}

bool cloned = false;
};

TEST(Any, as)
{
int32_t integer = 10;
Expand Down Expand Up @@ -356,34 +372,42 @@ TEST(Any, clone_into_owning_via_copy)
EXPECT_EQ(a_any.as<int>(), other);
}

TEST(Any, clone_sgvector)
TEST(Any, clone_value)
{
auto a = SGVector<float64_t>(3);
a.range_fill();
SGVector<float64_t> b;
ASSERT_FALSE(a.equals(b));
auto a_any = make_any(SimpleValue());
auto b_any = make_any(SimpleValue());
auto cloned_b = b_any.clone_from(a_any).as<SimpleValue>();
}

auto a_any = make_any(a);
auto b_any = make_any(b);
TEST(Any, array_ref)
{
int src_len = 5;
float* src = new float[src_len];
std::iota(src, src+src_len, 9);
int dst_len = 8;
float* dst = new float[dst_len];
std::iota(dst, dst+dst_len, 5);
int other_len = src_len;
float* other = new float[other_len];
std::iota(other, other+other_len, 9);

auto cloned_b = b_any.clone_from(a_any).as<SGVector<float64_t>>();
auto any_src = make_any_ref_array(&src, &src_len);
auto any_dst = make_any_ref_array(&dst, &dst_len);
auto any_other = make_any_ref_array(&other, &other_len);

EXPECT_NE(a.vector, cloned_b.vector);
EXPECT_TRUE(a.equals(cloned_b));
}
EXPECT_EQ(any_src, any_src);
EXPECT_EQ(any_dst, any_dst);

TEST(Any, clone_sgmatrix)
{
auto a = SGMatrix<float64_t>(3, 4);
SGVector<float64_t>(a.matrix, a.num_rows * a.num_cols, false).range_fill();
SGMatrix<float64_t> b;
ASSERT_FALSE(a.equals(b));
EXPECT_NE(any_src, any_dst);
EXPECT_NE(any_dst, any_src);

auto a_any = make_any(a);
auto b_any = make_any(b);
EXPECT_EQ(any_src, any_other);
EXPECT_EQ(any_other, any_src);

auto cloned_b = b_any.clone_from(a_any).as<SGMatrix<float64_t>>();
EXPECT_NE(any_dst, any_other);
EXPECT_NE(any_other, any_dst);

EXPECT_NE(a.matrix, cloned_b.matrix);
EXPECT_TRUE(a.equals(cloned_b));
delete[] src;
delete[] dst;
delete[] other;
}