diff --git a/src/shogun/lib/tapkee/neighbors/covertree_point.hpp b/src/shogun/lib/tapkee/neighbors/covertree_point.hpp index e42ab3e0916..bedb5350bb8 100644 --- a/src/shogun/lib/tapkee/neighbors/covertree_point.hpp +++ b/src/shogun/lib/tapkee/neighbors/covertree_point.hpp @@ -124,8 +124,8 @@ struct CoverTreePoint ScalarType norm_; }; /* struct JLCoverTreePoint */ -template -struct distance_impl_if_kernel; +template +struct distance_impl; /** Functions declared out of the class definition to respect CoverTree * structure */ @@ -138,11 +138,13 @@ inline ScalarType distance(Callback& cb, const CoverTreePoint()(cb,l,r,upper_bound); + return distance_impl()(cb,l,r,upper_bound); } +struct KernelType; + template -struct distance_impl_if_kernel +struct distance_impl { inline ScalarType operator()(Callback& cb, const CoverTreePoint& l, const CoverTreePoint& r, ScalarType /*upper_bound*/) @@ -151,8 +153,10 @@ struct distance_impl_if_kernel } }; +struct DistanceType; + template -struct distance_impl_if_kernel +struct distance_impl { inline ScalarType operator()(Callback& cb, const CoverTreePoint& l, const CoverTreePoint& r, ScalarType /*upper_bound*/) diff --git a/src/shogun/lib/tapkee/neighbors/neighbors.hpp b/src/shogun/lib/tapkee/neighbors/neighbors.hpp index 67d9164d430..b783d76264f 100644 --- a/src/shogun/lib/tapkee/neighbors/neighbors.hpp +++ b/src/shogun/lib/tapkee/neighbors/neighbors.hpp @@ -36,6 +36,10 @@ struct distances_comparator } }; +struct KernelType +{ +}; + template struct KernelDistance { @@ -48,11 +52,13 @@ struct KernelDistance { return sqrt(callback.kernel(*l,*l) - 2*callback.kernel(*l,*r) + callback.kernel(*r,*r)); } - static const bool is_kernel; + typedef KernelType type; Callback callback; }; -template -const bool KernelDistance::is_kernel = true; + +struct DistanceType +{ +}; template struct PlainDistance @@ -66,11 +72,9 @@ struct PlainDistance { return callback.distance(*l,*r); } - static const bool is_kernel; + typedef DistanceType type; Callback callback; }; -template -const bool PlainDistance::is_kernel = false; #ifdef TAPKEE_USE_LGPL_COVERTREE template diff --git a/src/shogun/lib/tapkee/neighbors/vptree.hpp b/src/shogun/lib/tapkee/neighbors/vptree.hpp index e501491f280..f39c3ca3108 100644 --- a/src/shogun/lib/tapkee/neighbors/vptree.hpp +++ b/src/shogun/lib/tapkee/neighbors/vptree.hpp @@ -1,12 +1,12 @@ /* This software is distributed under BSD 3-clause license (see LICENSE file). * - * Copyright (c) 2012-2013 Sergey Lisitsyn + * Copyright (c) 2012-2013 Laurens van der Maaten, Sergey Lisitsyn */ #ifndef TAPKEE_VPTREE_H_ #define TAPKEE_VPTREE_H_ -/* Tapkee include */ +/* Tapkee includes */ #include /* End of Tapkee includes */ @@ -20,8 +20,8 @@ namespace tapkee namespace tapkee_internal { -template -struct compare_if_kernel; +template +struct compare_impl; template struct DistanceComparator @@ -32,13 +32,15 @@ struct DistanceComparator callback(c), item(i) {} inline bool operator()(const RandomAccessIterator& a, const RandomAccessIterator& b) { - return compare_if_kernel() + return compare_impl() (callback,item,a,b); } }; +struct KernelType; + template -struct compare_if_kernel +struct compare_impl { inline bool operator()(DistanceCallback& callback, const RandomAccessIterator& item, const RandomAccessIterator& a, const RandomAccessIterator& b) @@ -47,8 +49,10 @@ struct compare_if_kernel } }; +struct DistanceType; + template -struct compare_if_kernel +struct compare_impl { inline bool operator()(DistanceCallback& callback, const RandomAccessIterator& item, const RandomAccessIterator& a, const RandomAccessIterator& b) diff --git a/src/shogun/lib/tapkee/parameters/parameters.hpp b/src/shogun/lib/tapkee/parameters/parameters.hpp index 7566f5633cf..b1bbd3a7755 100644 --- a/src/shogun/lib/tapkee/parameters/parameters.hpp +++ b/src/shogun/lib/tapkee/parameters/parameters.hpp @@ -14,10 +14,6 @@ #include #include -using std::vector; -using std::string; -using std::stringstream; - namespace tapkee { @@ -37,7 +33,7 @@ struct Message return ss.str(); } - stringstream ss; + std::stringstream ss; }; class ParametersSet; @@ -53,6 +49,7 @@ class Parameter template Parameter(const ParameterName& pname, const T& value) : + valid(true), invalidity_reason(), parameter_name(pname), keeper(tapkee_internal::ValueKeeper(value)) { } @@ -65,11 +62,15 @@ class Parameter return Parameter(name, value); } - Parameter() : parameter_name("unknown"), keeper(tapkee_internal::ValueKeeper()) + Parameter() : + valid(false), invalidity_reason(), + parameter_name("unknown"), keeper(tapkee_internal::ValueKeeper()) { } - Parameter(const Parameter& p) : parameter_name(p.name()), keeper(p.keeper) + Parameter(const Parameter& p) : + valid(p.valid), invalidity_reason(p.invalidity_reason), + parameter_name(p.name()), keeper(p.keeper) { } @@ -90,6 +91,10 @@ class Parameter template inline operator T() { + if (!valid) + { + throw wrong_parameter_error(invalidity_reason); + } try { return getValue(); @@ -192,8 +197,17 @@ class Parameter return keeper.isTypeCorrect(); } + inline void invalidate(const std::string& reason) + { + valid = false; + invalidity_reason = reason; + } + private: + bool valid; + std::string invalidity_reason; + ParameterName parameter_name; tapkee_internal::ValueKeeper keeper; @@ -205,16 +219,10 @@ class CheckedParameter public: - explicit CheckedParameter(const Parameter& p) : parameter(p) + explicit CheckedParameter(Parameter& p) : parameter(p) { } - template - inline operator T() const - { - return parameter.getValue(); - } - inline operator const Parameter&() { return parameter; @@ -237,11 +245,11 @@ class CheckedParameter { if (!parameter.isInRange(lower, upper)) { - std::string error_message = - (Message() << "Value " << parameter.name() << " " + std::string reason = + (Message() << "Value of " << parameter.name() << " " << parameter.getValue() << " doesn't fit the range [" << lower << ", " << upper << ")"); - throw tapkee::wrong_parameter_error(error_message); + parameter.invalidate(reason); } return *this; } @@ -250,9 +258,9 @@ class CheckedParameter { if (!parameter.isPositive()) { - std::string error_message = + std::string reason = (Message() << "Value of " << parameter.name() << " is not positive"); - throw tapkee::wrong_parameter_error(error_message); + parameter.invalidate(reason); } return *this; } @@ -261,9 +269,9 @@ class CheckedParameter { if (!parameter.isNonNegative()) { - std::string error_message = + std::string reason = (Message() << "Value of " << parameter.name() << " is negative"); - throw tapkee::wrong_parameter_error(error_message); + parameter.invalidate(reason); } return *this; } @@ -271,7 +279,7 @@ class CheckedParameter private: - Parameter parameter; + Parameter& parameter; };