Skip to content

Commit

Permalink
Add distance interface factory. Port Euclidean meta example to new API.
Browse files Browse the repository at this point in the history
  • Loading branch information
iglesias committed Apr 14, 2018
1 parent 8ca2ee4 commit 3130714
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 15 deletions.
18 changes: 9 additions & 9 deletions examples/meta/src/distance/euclidean.sg
@@ -1,24 +1,24 @@
CSVFile f_feats_a("../../data/fm_train_real.dat")
CSVFile f_feats_b("../../data/fm_test_real.dat")
File f_features_a = csv_file("../../data/fm_train_real.dat")
File f_features_b = csv_file("../../data/fm_test_real.dat")

#![create_features]
RealFeatures features_a(f_feats_a)
RealFeatures features_b(f_feats_b)
Features features_a = features(f_features_a)
Features features_b = features(f_features_b)
#![create_features]

#![create_instance]
EuclideanDistance distance(features_a, features_a)
Distance d = distance("EuclideanDistance", lhs=features_a, rhs=features_a)
#![create_instance]

#![extract_distance]
RealMatrix distance_matrix_aa = distance.get_distance_matrix()
RealMatrix distance_matrix_aa = d.get_distance_matrix()
#![extract_distance]

#![refresh_distance]
distance.init(features_a, features_b)
d.init(features_a, features_b)
#![refresh_distance]

#![extract_sq_distance]
distance.set_disable_sqrt(True)
RealMatrix distance_matrix_ab = distance.get_distance_matrix()
d.put("disable_sqrt", True)
RealMatrix sq_distance_matrix_ab = d.get_distance_matrix()
#![extract_sq_distance]
1 change: 1 addition & 0 deletions src/interfaces/swig/shogun.i
Expand Up @@ -178,6 +178,7 @@ namespace shogun
%template(put) CSGObject::put_scalar_dispatcher<int64_t, int64_t>;
#endif // SWIGJAVA
%template(put) CSGObject::put_scalar_dispatcher<float64_t, float64_t>;
%template(put) CSGObject::put_scalar_dispatcher<bool, bool>;


#ifndef SWIGJAVA
Expand Down
6 changes: 2 additions & 4 deletions src/shogun/distance/Distance.cpp
Expand Up @@ -259,10 +259,8 @@ void CDistance::init()
num_lhs=0;
num_rhs=0;

m_parameters->add((CSGObject**) &lhs, "lhs",
"Feature vectors to occur on left hand side.");
m_parameters->add((CSGObject**) &rhs, "rhs",
"Feature vectors to occur on right hand side.");
SG_ADD(&lhs, "lhs", "Left hand side features.", MS_NOT_AVAILABLE);
SG_ADD(&rhs, "rhs", "Right hand side features.", MS_NOT_AVAILABLE);
}

template <class T>
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/distance/EuclideanDistance.cpp
@@ -1,7 +1,7 @@
/*
* This software is distributed under BSD 3-clause license (see LICENSE file).
*
* Authors: Saurabh Mahindre, Soumyajit De, Chiyuan Zhang, Viktor Gal,
* Authors: Saurabh Mahindre, Soumyajit De, Chiyuan Zhang, Viktor Gal,
* Björn Esser, Soeren Sonnenburg
*/

Expand Down
6 changes: 5 additions & 1 deletion src/shogun/util/factory.h
@@ -1,21 +1,24 @@
/*
* This software is distributed under BSD 3-clause license (see LICENSE file).
*
* Authors: Heiko Strathmann
* Authors: Heiko Strathmann, Fernando Iglesias
*/
#ifndef FACTORY_H_
#define FACTORY_H_

#include <shogun/base/class_list.h>
#include <shogun/distance/Distance.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/io/CSVFile.h>
#include <shogun/io/SGIO.h>
#include <shogun/kernel/Kernel.h>
#include <shogun/labels/DenseLabels.h>
#include <shogun/machine/Machine.h>

namespace shogun
{

CDistance* distance(const std::string& name);
CKernel* kernel(const std::string& name);
CMachine* machine(const std::string& name);

Expand All @@ -25,6 +28,7 @@ namespace shogun
return create_object<T>(name.c_str()); \
}

BASE_CLASS_FACTORY(CDistance, distance)
BASE_CLASS_FACTORY(CKernel, kernel)
BASE_CLASS_FACTORY(CMachine, machine)

Expand Down

0 comments on commit 3130714

Please sign in to comment.