Skip to content

Commit

Permalink
first refactor draft [SKIP CI]
Browse files Browse the repository at this point in the history
  • Loading branch information
gf712 committed Apr 2, 2019
1 parent 7f1f6b8 commit 2d49dfb
Show file tree
Hide file tree
Showing 9 changed files with 962 additions and 155 deletions.
4 changes: 4 additions & 0 deletions src/interfaces/swig/ModelSelection.i
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
%newobject CParameterCombination::leaf_sets_multiplication();
%newobject CModelSelectionParameters::get_combinations();
%newobject CModelSelectionParameters::get_single_combination();
%newobject ParameterNode::attach();
%newobject ParameterNode::next_combination();

/* what about parameter_set_multiplication returns new DynArray<Parameter*>? */

Expand All @@ -22,12 +24,14 @@
%rename(ModelSelectionBase) CModelSelection;
%rename(ModelSelectionParameters) CModelSelectionParameters;
%rename(ParameterCombination) CParameterCombination;
%rename(ParameterNode) CParameterNode;

%include <shogun/modelselection/ModelSelection.h>
%include <shogun/modelselection/GridSearchModelSelection.h>
%include <shogun/modelselection/RandomSearchModelSelection.h>
%include <shogun/modelselection/ParameterCombination.h>
%include <shogun/modelselection/ModelSelectionParameters.h>
%include <shogun/modelselection/NewGridSearch.h>
#ifdef USE_GPL_SHOGUN
%include <shogun/modelselection/GradientModelSelection.h>
#endif //USE_GPL_SHOGUN
1 change: 1 addition & 0 deletions src/interfaces/swig/ModelSelection_includes.i
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@
#include <shogun/modelselection/GradientModelSelection.h>
#endif //USE_GPL_SHOGUN
#include <shogun/modelselection/ParameterCombination.h>
#include <shogun/modelselection/NewGridSearch.h>
%}
12 changes: 12 additions & 0 deletions src/interfaces/swig/shogun.i
Original file line number Diff line number Diff line change
Expand Up @@ -300,5 +300,17 @@ PUT_ADD(CTokenizer)
%template(kernel) kernel<float64_t, float64_t>;
%template(features) features<float64_t>;

%template(attach) ParameterNode::attach<int32_t>;
#ifndef SWIGJAVA
%template(attach) ParameterNode::attach<int64_t>;
#endif // SWIGJAVA
%template(attach) ParameterNode::attach<float64_t>;
%template(attach) ParameterNode::attach<bool>;
%template(attach) GridParameters::attach<bool>;
%template(attach) GridParameters::attach<int32_t>;
%template(attach) GridParameters::attach<float64_t>;
#ifndef SWIGJAVA
%template(attach) GridParameters::attach<int64_t>;
#endif // SWIGJAVA

} // namespace shogun
127 changes: 9 additions & 118 deletions src/shogun/base/SGObject.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -822,124 +822,6 @@ bool CSGObject::has(const std::string& name) const
return has_parameter(BaseTag(name));
}

class ToStringVisitor : public AnyVisitor
{
public:
ToStringVisitor(std::stringstream* ss) : AnyVisitor(), m_stream(ss)
{
}

virtual void on(bool* v)
{
stream() << (*v ? "true" : "false");
}
virtual void on(int32_t* v)
{
stream() << *v;
}
virtual void on(int64_t* v)
{
stream() << *v;
}
virtual void on(float* v)
{
stream() << *v;
}
virtual void on(double* v)
{
stream() << *v;
}
virtual void on(long double* v)
{
stream() << *v;
}
virtual void on(CSGObject** v)
{
if (*v)
{
stream() << (*v)->get_name() << "(...)";
}
else
{
stream() << "null";
}
}
virtual void on(SGVector<int>* v)
{
to_string(v);
}
virtual void on(SGVector<float>* v)
{
to_string(v);
}
virtual void on(SGVector<double>* v)
{
to_string(v);
}
virtual void on(SGMatrix<int>* mat)
{
to_string(mat);
}
virtual void on(SGMatrix<float>* mat)
{
to_string(mat);
}
virtual void on(SGMatrix<double>* mat)
{
to_string(mat);
}

private:
std::stringstream& stream()
{
return *m_stream;
}

template <class T>
void to_string(SGMatrix<T>* m)
{
if (m)
{
stream() << "Matrix<" << demangled_type<T>() << ">(" << m->num_rows
<< "," << m->num_cols << "): [";
for (auto col : range(m->num_cols))
{
stream() << "[";
for (auto row : range(m->num_rows))
{
stream() << (*m)(row, col);
if (row < m->num_rows - 1)
stream() << ",";
}
stream() << "]";
if (col < m->num_cols)
stream() << ",";
}
stream() << "]";
}
}

template <class T>
void to_string(SGVector<T>* v)
{
if (v)
{
stream() << "Vector<" << demangled_type<T>() << ">(" << v->vlen
<< "): [";
for (auto i : range(v->vlen))
{
stream() << (*v)[i];
if (i < v->vlen - 1)
stream() << ",";
}
stream() << "]";
}
}

private:
std::stringstream* m_stream;
};

std::string CSGObject::to_string() const
{
std::stringstream ss;
Expand Down Expand Up @@ -981,6 +863,15 @@ std::map<std::string, std::shared_ptr<const AnyParameter>> CSGObject::get_params
}
return result;
}

std::map<std::string, std::shared_ptr<const AnyParameter>> CSGObject::get_params(ParameterProperties props) const
{
std::map<std::string, std::shared_ptr<const AnyParameter>> result;
for (auto const& each: self->filter(props)) {
result.emplace(each.first.name(), std::make_shared<const AnyParameter>(each.second));
}
return result;
}
#endif

bool CSGObject::equals(const CSGObject* other) const
Expand Down
125 changes: 125 additions & 0 deletions src/shogun/base/SGObject.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
#include <shogun/lib/config.h>
#include <shogun/lib/exception/ShogunException.h>
#include <shogun/lib/tag.h>
#include <shogun/lib/SGVector.h>
#include <shogun/lib/SGMatrix.h>
#include <shogun/base/range.h>


#include <map>
#include <unordered_map>
Expand Down Expand Up @@ -613,6 +617,9 @@ class CSGObject
*/
#ifndef SWIG // SWIG should skip this part
std::map<std::string, std::shared_ptr<const AnyParameter>> get_params() const;

std::map<std::string, std::shared_ptr<const AnyParameter>> get_params(ParameterProperties) const;

#endif
/** Specializes a provided object to the specified type.
* Throws exception if the object cannot be specialized.
Expand Down Expand Up @@ -1022,5 +1029,123 @@ class CSGObject
/** Subscriber used to call onNext, onComplete etc.*/
SGSubscriber* m_subscriber_params;
};

class ToStringVisitor : public AnyVisitor
{
public:
ToStringVisitor(std::stringstream* ss) : AnyVisitor(), m_stream(ss)
{
}

virtual void on(bool* v)
{
stream() << (*v ? "true" : "false");
}
virtual void on(int32_t* v)
{
stream() << *v;
}
virtual void on(int64_t* v)
{
stream() << *v;
}
virtual void on(float* v)
{
stream() << *v;
}
virtual void on(double* v)
{
stream() << *v;
}
virtual void on(long double* v)
{
stream() << *v;
}
virtual void on(CSGObject** v)
{
if (*v)
{
stream() << (*v)->get_name() << "(...)";
}
else
{
stream() << "null";
}
}
virtual void on(SGVector<int>* v)
{
to_string(v);
}
virtual void on(SGVector<float>* v)
{
to_string(v);
}
virtual void on(SGVector<double>* v)
{
to_string(v);
}
virtual void on(SGMatrix<int>* mat)
{
to_string(mat);
}
virtual void on(SGMatrix<float>* mat)
{
to_string(mat);
}
virtual void on(SGMatrix<double>* mat)
{
to_string(mat);
}

private:
std::stringstream& stream()
{
return *m_stream;
}

template <class T>
void to_string(SGMatrix<T>* m)
{
if (m)
{
stream() << "Matrix<" << demangled_type<T>() << ">(" << m->num_rows
<< "," << m->num_cols << "): [";
for (auto col : range(m->num_cols))
{
stream() << "[";
for (auto row : range(m->num_rows))
{
stream() << (*m)(row, col);
if (row < m->num_rows - 1)
stream() << ",";
}
stream() << "]";
if (col < m->num_cols)
stream() << ",";
}
stream() << "]";
}
}

template <class T>
void to_string(SGVector<T>* v)
{
if (v)
{
stream() << "Vector<" << demangled_type<T>() << ">(" << v->vlen
<< "): [";
for (auto i : range(v->vlen))
{
stream() << (*v)[i];
if (i < v->vlen - 1)
stream() << ",";
}
stream() << "]";
}
}

private:
std::stringstream* m_stream;
};
}
#endif // __SGOBJECT_H__

0 comments on commit 2d49dfb

Please sign in to comment.