Skip to content

Commit

Permalink
Add correct type traits for SGVector<T> and SGMatrix<T>.
Browse files Browse the repository at this point in the history
  • Loading branch information
geektoni committed May 3, 2019
1 parent ddd4222 commit e540d7e
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 11 deletions.
21 changes: 10 additions & 11 deletions src/shogun/base/SGObject.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
#include <unordered_map>
#include <utility>
#include <vector>
#include <shogun/lib/type_case.h>

/** \namespace shogun
* @brief all of classes and functions are contained in the shogun namespace
Expand Down Expand Up @@ -1007,14 +1006,14 @@ class CSGObject
*/
template<class T,
typename std::enable_if_t<
!type_internal::is_sg_vector<T>::value &&
!type_internal::is_sg_matrix<T>::value>* = nullptr>
!is_sg_vector<T>::value &&
!is_sg_matrix<T>::value>* = nullptr>
void observe(const int64_t step, const std::string& name) const;

template<class T,
typename std::enable_if_t<
type_internal::is_sg_vector<T>::value ||
type_internal::is_sg_matrix<T>::value>* = nullptr>
is_sg_vector<T>::value ||
is_sg_matrix<T>::value>* = nullptr>
void observe(const int64_t step, const std::string& name) const;

/**
Expand Down Expand Up @@ -1316,10 +1315,10 @@ void CSGObject::observe(
this->observe(obs);
}

template<class T,
typename std::enable_if_t<
!type_internal::is_sg_vector<T>::value &&
!type_internal::is_sg_matrix<T>::value>* = nullptr>
template<class T,
typename std::enable_if_t<
!is_sg_vector<T>::value &&
!is_sg_matrix<T>::value>* = nullptr>
void CSGObject::observe(const int64_t step, const std::string& name) const
{
auto param = this->get_parameter(BaseTag(name));
Expand All @@ -1330,8 +1329,8 @@ void CSGObject::observe(const int64_t step, const std::string& name) const

template<class T,
typename std::enable_if_t<
type_internal::is_sg_vector<T>::value ||
type_internal::is_sg_matrix<T>::value>* = nullptr>
is_sg_vector<T>::value ||
is_sg_matrix<T>::value>* = nullptr>
void CSGObject::observe(const int64_t step, const std::string& name) const
{
auto param = this->get_parameter(BaseTag(name));
Expand Down
25 changes: 25 additions & 0 deletions src/shogun/base/base_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,31 @@ namespace shogun
const char*, typename std::decay<T>::type>::value>
{
};

// General type traits to recognize SGMatrix and SGVectors.
template <typename T> class SGMatrix;
template <typename T> class SGVector;

template <typename>
struct is_sg_vector : std::false_type
{
};

template <typename T>
struct is_sg_vector<SGVector<T>> : std::true_type
{
};

template <typename>
struct is_sg_matrix : std::false_type
{
};

template <typename T>
struct is_sg_matrix<SGMatrix<T>> : std::true_type
{
};

} // namespace shogun

#endif // BASE_TYPES__H

0 comments on commit e540d7e

Please sign in to comment.