Skip to content

Commit

Permalink
Update for tapkee library
Browse files Browse the repository at this point in the history
- Fix for compilation issue with clang -
  struct template specialization instead of
  boolean value specialization
- Lazy parameter checks to avoid runtime errors with
  parameters unrelated to the used method
- Proper license for vptree (missed Laurens whose code it is based on)
  • Loading branch information
lisitsyn committed May 9, 2013
1 parent 024eca5 commit 5fa05de
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 40 deletions.
14 changes: 9 additions & 5 deletions src/shogun/lib/tapkee/neighbors/covertree_point.hpp
Expand Up @@ -124,8 +124,8 @@ struct CoverTreePoint
ScalarType norm_;
}; /* struct JLCoverTreePoint */

template <bool, class RandomAccessIterator, class Callback>
struct distance_impl_if_kernel;
template <class Type, class RandomAccessIterator, class Callback>
struct distance_impl;

/** Functions declared out of the class definition to respect CoverTree
* structure */
Expand All @@ -138,11 +138,13 @@ inline ScalarType distance(Callback& cb, const CoverTreePoint<RandomAccessIterat
if (l.iter_==r.iter_)
return 0.0;

return distance_impl_if_kernel<Callback::is_kernel,RandomAccessIterator,Callback>()(cb,l,r,upper_bound);
return distance_impl<typename Callback::type,RandomAccessIterator,Callback>()(cb,l,r,upper_bound);
}

struct KernelType;

template <class RandomAccessIterator, class Callback>
struct distance_impl_if_kernel<true,RandomAccessIterator,Callback>
struct distance_impl<KernelType,RandomAccessIterator,Callback>
{
inline ScalarType operator()(Callback& cb, const CoverTreePoint<RandomAccessIterator>& l,
const CoverTreePoint<RandomAccessIterator>& r, ScalarType /*upper_bound*/)
Expand All @@ -151,8 +153,10 @@ struct distance_impl_if_kernel<true,RandomAccessIterator,Callback>
}
};

struct DistanceType;

template <class RandomAccessIterator, class Callback>
struct distance_impl_if_kernel<false,RandomAccessIterator,Callback>
struct distance_impl<DistanceType,RandomAccessIterator,Callback>
{
inline ScalarType operator()(Callback& cb, const CoverTreePoint<RandomAccessIterator>& l,
const CoverTreePoint<RandomAccessIterator>& r, ScalarType /*upper_bound*/)
Expand Down
16 changes: 10 additions & 6 deletions src/shogun/lib/tapkee/neighbors/neighbors.hpp
Expand Up @@ -36,6 +36,10 @@ struct distances_comparator
}
};

struct KernelType
{
};

template <class RandomAccessIterator, class Callback>
struct KernelDistance
{
Expand All @@ -48,11 +52,13 @@ struct KernelDistance
{
return sqrt(callback.kernel(*l,*l) - 2*callback.kernel(*l,*r) + callback.kernel(*r,*r));
}
static const bool is_kernel;
typedef KernelType type;
Callback callback;
};
template <class RandomAccessIterator, class Callback>
const bool KernelDistance<RandomAccessIterator, Callback>::is_kernel = true;

struct DistanceType
{
};

template <class RandomAccessIterator, class Callback>
struct PlainDistance
Expand All @@ -66,11 +72,9 @@ struct PlainDistance
{
return callback.distance(*l,*r);
}
static const bool is_kernel;
typedef DistanceType type;
Callback callback;
};
template <class RandomAccessIterator, class Callback>
const bool PlainDistance<RandomAccessIterator, Callback>::is_kernel = false;

#ifdef TAPKEE_USE_LGPL_COVERTREE
template <class RandomAccessIterator, class Callback>
Expand Down
18 changes: 11 additions & 7 deletions src/shogun/lib/tapkee/neighbors/vptree.hpp
@@ -1,12 +1,12 @@
/* This software is distributed under BSD 3-clause license (see LICENSE file).
*
* Copyright (c) 2012-2013 Sergey Lisitsyn
* Copyright (c) 2012-2013 Laurens van der Maaten, Sergey Lisitsyn
*/

#ifndef TAPKEE_VPTREE_H_
#define TAPKEE_VPTREE_H_

/* Tapkee include */
/* Tapkee includes */
#include <shogun/lib/tapkee/tapkee_defines.hpp>
/* End of Tapkee includes */

Expand All @@ -20,8 +20,8 @@ namespace tapkee
namespace tapkee_internal
{

template<bool, class RandomAccessIterator, class DistanceCallback>
struct compare_if_kernel;
template<class Type, class RandomAccessIterator, class DistanceCallback>
struct compare_impl;

template<class RandomAccessIterator, class DistanceCallback>
struct DistanceComparator
Expand All @@ -32,13 +32,15 @@ struct DistanceComparator
callback(c), item(i) {}
inline bool operator()(const RandomAccessIterator& a, const RandomAccessIterator& b)
{
return compare_if_kernel<DistanceCallback::is_kernel,RandomAccessIterator,DistanceCallback>()
return compare_impl<typename DistanceCallback::type,RandomAccessIterator,DistanceCallback>()
(callback,item,a,b);
}
};

struct KernelType;

template<class RandomAccessIterator, class DistanceCallback>
struct compare_if_kernel<true,RandomAccessIterator,DistanceCallback>
struct compare_impl<KernelType,RandomAccessIterator,DistanceCallback>
{
inline bool operator()(DistanceCallback& callback, const RandomAccessIterator& item,
const RandomAccessIterator& a, const RandomAccessIterator& b)
Expand All @@ -47,8 +49,10 @@ struct compare_if_kernel<true,RandomAccessIterator,DistanceCallback>
}
};

struct DistanceType;

template<class RandomAccessIterator, class DistanceCallback>
struct compare_if_kernel<false,RandomAccessIterator,DistanceCallback>
struct compare_impl<DistanceType,RandomAccessIterator,DistanceCallback>
{
inline bool operator()(DistanceCallback& callback, const RandomAccessIterator& item,
const RandomAccessIterator& a, const RandomAccessIterator& b)
Expand Down
52 changes: 30 additions & 22 deletions src/shogun/lib/tapkee/parameters/parameters.hpp
Expand Up @@ -14,10 +14,6 @@
#include <vector>
#include <map>

using std::vector;
using std::string;
using std::stringstream;

namespace tapkee
{

Expand All @@ -37,7 +33,7 @@ struct Message
return ss.str();
}

stringstream ss;
std::stringstream ss;
};

class ParametersSet;
Expand All @@ -53,6 +49,7 @@ class Parameter

template <typename T>
Parameter(const ParameterName& pname, const T& value) :
valid(true), invalidity_reason(),
parameter_name(pname), keeper(tapkee_internal::ValueKeeper(value))
{
}
Expand All @@ -65,11 +62,15 @@ class Parameter
return Parameter(name, value);
}

Parameter() : parameter_name("unknown"), keeper(tapkee_internal::ValueKeeper())
Parameter() :
valid(false), invalidity_reason(),
parameter_name("unknown"), keeper(tapkee_internal::ValueKeeper())
{
}

Parameter(const Parameter& p) : parameter_name(p.name()), keeper(p.keeper)
Parameter(const Parameter& p) :
valid(p.valid), invalidity_reason(p.invalidity_reason),
parameter_name(p.name()), keeper(p.keeper)
{
}

Expand All @@ -90,6 +91,10 @@ class Parameter
template <typename T>
inline operator T()
{
if (!valid)
{
throw wrong_parameter_error(invalidity_reason);
}
try
{
return getValue<T>();
Expand Down Expand Up @@ -192,8 +197,17 @@ class Parameter
return keeper.isTypeCorrect<T>();
}

inline void invalidate(const std::string& reason)
{
valid = false;
invalidity_reason = reason;
}

private:

bool valid;
std::string invalidity_reason;

ParameterName parameter_name;

tapkee_internal::ValueKeeper keeper;
Expand All @@ -205,16 +219,10 @@ class CheckedParameter

public:

explicit CheckedParameter(const Parameter& p) : parameter(p)
explicit CheckedParameter(Parameter& p) : parameter(p)
{
}

template <typename T>
inline operator T() const
{
return parameter.getValue<T>();
}

inline operator const Parameter&()
{
return parameter;
Expand All @@ -237,11 +245,11 @@ class CheckedParameter
{
if (!parameter.isInRange(lower, upper))
{
std::string error_message =
(Message() << "Value " << parameter.name() << " "
std::string reason =
(Message() << "Value of " << parameter.name() << " "
<< parameter.getValue<T>() << " doesn't fit the range [" <<
lower << ", " << upper << ")");
throw tapkee::wrong_parameter_error(error_message);
parameter.invalidate(reason);
}
return *this;
}
Expand All @@ -250,9 +258,9 @@ class CheckedParameter
{
if (!parameter.isPositive())
{
std::string error_message =
std::string reason =
(Message() << "Value of " << parameter.name() << " is not positive");
throw tapkee::wrong_parameter_error(error_message);
parameter.invalidate(reason);
}
return *this;
}
Expand All @@ -261,17 +269,17 @@ class CheckedParameter
{
if (!parameter.isNonNegative())
{
std::string error_message =
std::string reason =
(Message() << "Value of " << parameter.name() << " is negative");
throw tapkee::wrong_parameter_error(error_message);
parameter.invalidate(reason);
}
return *this;
}


private:

Parameter parameter;
Parameter& parameter;

};

Expand Down

0 comments on commit 5fa05de

Please sign in to comment.