Skip to content

Commit

Permalink
cleanup and fixes in CARTree
Browse files Browse the repository at this point in the history
convert whatever is possible to const and use linalg wherever possible
optimize set_const for Eigen backend
CMath::argsort: use a lambda instead of a comparator class
optimize CMath::pow(2,n) in case of integer
fix #4282: segfault in CARTree
  • Loading branch information
vigsterkr committed May 17, 2018
1 parent 271e673 commit bd69a3b
Show file tree
Hide file tree
Showing 9 changed files with 334 additions and 356 deletions.
12 changes: 3 additions & 9 deletions src/shogun/lib/SGVector.cpp
Expand Up @@ -871,15 +871,9 @@ float32_t SGVector<float32_t>::sum_abs(float32_t* vec, int32_t len)
/** Sort an array in place and collapse runs of equal values.
 *
 * After the call the first <return value> entries of output hold the
 * distinct values in ascending order; the remaining entries are left in an
 * unspecified state.
 *
 * @param output array to deduplicate (modified in place)
 * @param size number of elements in output
 * @return number of distinct values found in output
 */
template <class T>
int32_t SGVector<T>::unique(T* output, int32_t size)
{
	// nothing to sort or compare; also keeps output+size from stepping
	// backwards when a negative size is passed
	if (size <= 0)
		return 0;

	std::sort(output, output+size);
	auto last = std::unique(output, output+size);

	// std::distance returns a ptrdiff_t; the value is bounded by size, so the
	// narrowing to int32_t is safe and is made explicit here
	return static_cast<int32_t>(std::distance(output, last));
}

template <>
Expand Down
50 changes: 19 additions & 31 deletions src/shogun/mathematics/Math.h
Expand Up @@ -18,6 +18,7 @@
#include <shogun/mathematics/Random.h>
#include <shogun/lib/SGVector.h>
#include <algorithm>
#include <numeric>

#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
Expand Down Expand Up @@ -468,12 +469,17 @@ class CMath : public CSGObject
* @param x base (integer)
* @param n exponent (integer)
*/
static inline int32_t pow(int32_t x, int32_t n)
template<typename T, std::enable_if_t<std::numeric_limits<T>::is_integer, T>* = nullptr>
static inline T pow(T x, T n)
{
ASSERT(n>=0)
int32_t result=1;
// power of integer 2 is basically a bitshift...
if (x == 2)
return (1 << n);

T result = 1;
while (n--)
result*=x;
result *= x;

return result;
}
Expand Down Expand Up @@ -1229,42 +1235,25 @@ class CMath : public CSGObject
}
#endif

/** Helper functor for the function argsort
 *
 * Orders two indices by the values they refer to. Only the raw data pointer
 * of the vector is stored; the elements themselves are not copied here.
 */
template<class T>
struct IndexSorter
{
	/** constructor
	 *
	 * @param vec vector whose elements are compared through their indices
	 */
	IndexSorter(const SGVector<T> *vec) : data(vec->vector) { }

	/** access operator
	 *
	 * @param i index of the left-hand element
	 * @param j index of the right-hand element
	 * @return true if element i is strictly smaller than element j
	 */
	bool operator() (index_t i, index_t j) const
	{
		// A plain "<" is used on purpose: a tolerance-based comparison is not
		// a strict weak ordering (a~b and b~c do not imply a~c), which
		// std::sort requires of its comparator.
		return data[i]<data[j];
	}

	/** data */
	const T* data;
};

#ifndef SWIG // SWIG should skip this part
/** Get sorted index.
 *
 * idx = v.argsort() is similar to Matlab [~, idx] = sort(v)
 *
 * The returned indices list the elements of v in ascending order, i.e.
 * v[idx[0]] <= v[idx[1]] <= ... ; the relative order of equal elements is
 * not specified (std::sort is not stable).
 *
 * @param v vector to be sorted
 * @return sorted index for this vector
 */
template<class T, class = typename std::enable_if<std::is_arithmetic<T>::value>::type>
static SGVector<index_t> argsort(SGVector<T> v)
{
	// start from the identity permutation 0, 1, ..., vlen-1
	SGVector<index_t> idx(v.vlen);
	std::iota(idx.begin(), idx.end(), 0);

	// Compare indices by the values they refer to. A plain "<" is used
	// because std::sort needs a strict weak ordering; an epsilon-tolerant
	// comparison does not provide one (its equivalence is not transitive).
	std::sort(idx.begin(), idx.end(),
		[&v](index_t i1, index_t i2)
		{
			return v[i1]<v[i2];
		});
	return idx;
}

Expand Down Expand Up @@ -1394,7 +1383,6 @@ class CMath : public CSGObject
template <class T1,class T2>
static void* parallel_qsort_index(void* p);


/** Finds the smallest element in output and puts that element as the
* first element
* @param output element array
Expand Down
8 changes: 6 additions & 2 deletions src/shogun/mathematics/linalg/LinalgBackendEigen.h
Expand Up @@ -615,8 +615,12 @@ namespace shogun
scale_impl(const SGMatrix<T>& a, T alpha, SGMatrix<T>& result) const;

/** Eigen3 set const method
 *
 * Assigns the scalar value to every element of a.
 */
template <typename T, template <typename> class Container>
void set_const_impl(Container<T>& a, T value) const;
template <typename T>
void set_const_impl(SGVector<T>& a, T value) const;

/** Eigen3 set matrix to const
 *
 * Assigns the scalar value to every element of the matrix a.
 */
template <typename T>
void set_const_impl(SGMatrix<T>& a, T value) const;

/** Eigen3 softmax method */
template <typename T, template <typename> class Container>
Expand Down
19 changes: 13 additions & 6 deletions src/shogun/mathematics/linalg/backend/eigen/Misc.cpp
Expand Up @@ -117,18 +117,25 @@ void LinalgBackendEigen::identity_impl(SGMatrix<T>& identity_matrix) const
I_eig.setIdentity();
}

template <typename T, template <typename> class Container>
void LinalgBackendEigen::range_fill_impl(Container<T>& a, const T start) const
template <typename T>
void LinalgBackendEigen::set_const_impl(SGVector<T>& a, T value) const
{
for (decltype(a.size()) i = 0; i < a.size(); ++i)
a[i] = start + T(i);
typename SGVector<T>::EigenVectorXtMap a_eig = a;
a_eig.setConstant(value);
}

template <typename T>
void LinalgBackendEigen::set_const_impl(SGMatrix<T>& a, T value) const
{
typename SGMatrix<T>::EigenMatrixXtMap a_eig = a;
a_eig.setConstant(value);
}

template <typename T, template <typename> class Container>
void LinalgBackendEigen::set_const_impl(Container<T>& a, T value) const
void LinalgBackendEigen::range_fill_impl(Container<T>& a, const T start) const
{
for (decltype(a.size()) i = 0; i < a.size(); ++i)
a[i] = value;
a[i] = start + T(i);
}

template <typename T>
Expand Down

0 comments on commit bd69a3b

Please sign in to comment.