Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix observers for the interfaces. #4615

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -55,7 +55,7 @@ SGMatrix<float64_t> calculate_weights(
for (auto o : range(obs.get_num_observations()))
{
auto obs_storage = obs.get_observation(o);
for (auto i : range(obs_storage->get<index_t("num_folds")))
for (auto i : range(obs_storage->get<index_t>("num_folds")))
{
auto fold = obs_storage->get("folds", i);
CMKLClassification* machine =
Expand Down
29 changes: 16 additions & 13 deletions src/shogun/base/SGObject.cpp
Expand Up @@ -15,14 +15,14 @@
#include <shogun/base/Parameter.h>
#include <shogun/base/SGObject.h>
#include <shogun/base/Version.h>
#include <shogun/base/class_list.h>
#include <shogun/io/SerializableFile.h>
#include <shogun/lib/DynamicObjectArray.h>
#include <shogun/lib/Map.h>
#include <shogun/lib/SGMatrix.h>
#include <shogun/lib/SGStringList.h>
#include <shogun/lib/SGVector.h>
#include <shogun/lib/observers/ParameterObserver.h>
#include <shogun/base/class_list.h>

#include <stdio.h>
#include <stdlib.h>
Expand Down Expand Up @@ -776,33 +776,36 @@ void CSGObject::subscribe(ParameterObserver* obs)

// Create an observable which emits values only if they are about
// parameters selected by the observable.
rxcpp::subscription subscription = m_observable_params
->filter([obs](Some<ObservedValue> v) {
return obs->filter(v->get<std::string>("name"));
})
.timestamp()
.subscribe(sub);
rxcpp::subscription subscription =
m_observable_params
->filter([obs](Some<ObservedValue> v) {
return obs->filter(v->get<std::string>("name"));
})
.timestamp()
.subscribe(sub);

// Insert the subscription in the list
m_subscriptions.insert(
std::make_pair<int64_t, rxcpp::subscription>(
std::move(m_next_subscription_index),
std::move(subscription)));
std::make_pair<int64_t, rxcpp::subscription>(
std::move(m_next_subscription_index), std::move(subscription)));

obs->put("subscription_id", m_next_subscription_index);

m_next_subscription_index++;
}

void CSGObject::unsubscribe(ParameterObserver *obs) {
void CSGObject::unsubscribe(ParameterObserver* obs)
{

int64_t index = obs->get<int64_t>("subscription_id");

// Check if we have such subscription
auto it = m_subscriptions.find(index);
if (it == m_subscriptions.end())
SG_ERROR("The object %s does not have any registered parameter observer with index %i",
this->get_name(), index);
SG_ERROR(
"The object %s does not have any registered parameter observer "
"with index %i",
this->get_name(), index);

it->second.unsubscribe();
m_subscriptions.erase(index);
Expand Down
111 changes: 56 additions & 55 deletions src/shogun/base/SGObject.h
Expand Up @@ -438,7 +438,8 @@ class CSGObject
}

#ifndef SWIG
/** Typed array getter for an object array class parameter of a Shogun base class
/** Typed array getter for an object array class parameter of a Shogun base
* class
* type, identified by a name and an index.
*
* Returns nullptr if parameter of desired type does not exist.
Expand All @@ -448,7 +449,7 @@ class CSGObject
* @return desired element
*/
template <class T,
class X = typename std::enable_if<is_sg_base<T>::value>::type>
class X = typename std::enable_if<is_sg_base<T>::value>::type>
T* get(const std::string& name, index_t index, std::nothrow_t) const
{
CSGObject* result = nullptr;
Expand All @@ -467,14 +468,15 @@ class CSGObject
}

template <class T,
class X = typename std::enable_if<is_sg_base<T>::value>::type>
class X = typename std::enable_if<is_sg_base<T>::value>::type>
T* get(const std::string& name, index_t index) const
{
auto result = this->get<T>(name, index, std::nothrow);
if (!result) {
SG_ERROR("Could not get array parameter %s::%s[%d] of type %s\n",
get_name(), name.c_str(), index, demangled_type<T>().c_str());

if (!result)
{
SG_ERROR(
"Could not get array parameter %s::%s[%d] of type %s\n",
get_name(), name.c_str(), index, demangled_type<T>().c_str());
}
return result;
};
Expand Down Expand Up @@ -673,7 +675,8 @@ class CSGObject

/**
* Detach an observer from the current SGObject.
* @param subscription_index the index obtained by calling the subscribe procedure
* @param subscription_index the index obtained by calling the subscribe
* procedure
*/
void unsubscribe(ParameterObserver* obs);

Expand Down Expand Up @@ -967,7 +970,6 @@ class CSGObject
Unique<ParameterObserverList> param_obs_list;

protected:

/**
* Return total subscriptions
* @return total number of subscriptions
Expand Down Expand Up @@ -1014,70 +1016,69 @@ class CSGObject
void register_observable(
const std::string& name, const std::string& description);

/**
* Get the current step for the observed values.
*/
/**
* Get the current step for the observed values.
*/
#ifndef SWIG
SG_FORCED_INLINE int64_t get_step() const
{
int64_t step = -1;
Tag<int64_t> tag("current_iteration");
if (has(tag))
SG_FORCED_INLINE int64_t get_step() const
{
step = get(tag);
int64_t step = -1;
Tag<int64_t> tag("current_iteration");
if (has(tag))
{
step = get(tag);
}
return step;
}
return step;
}
#endif

/** mapping from strings to enum for SWIG interface */
stringToEnumMapType m_string_to_enum_map;
/** mapping from strings to enum for SWIG interface */
stringToEnumMapType m_string_to_enum_map;

public:
/** io */
SGIO* io;
public:
/** io */
SGIO* io;

/** parallel */
Parallel* parallel;
/** parallel */
Parallel* parallel;

/** version */
Version* version;
/** version */
Version* version;

/** parameters */
Parameter* m_parameters;
/** parameters */
Parameter* m_parameters;

/** model selection parameters */
Parameter* m_model_selection_parameters;
/** model selection parameters */
Parameter* m_model_selection_parameters;

/** parameters wrt which we can compute gradients */
Parameter* m_gradient_parameters;
/** parameters wrt which we can compute gradients */
Parameter* m_gradient_parameters;

/** Hash of parameter values*/
uint32_t m_hash;
/** Hash of parameter values*/
uint32_t m_hash;

private:

EPrimitiveType m_generic;
bool m_load_pre_called;
bool m_load_post_called;
bool m_save_pre_called;
bool m_save_post_called;
private:
EPrimitiveType m_generic;
bool m_load_pre_called;
bool m_load_post_called;
bool m_save_pre_called;
bool m_save_post_called;

RefCount* m_refcount;
RefCount* m_refcount;

/** Subject used to create the params observer */
SGSubject* m_subject_params;
/** Subject used to create the params observer */
SGSubject* m_subject_params;

/** Parameter Observable */
SGObservable* m_observable_params;
/** Parameter Observable */
SGObservable* m_observable_params;

/** Subscriber used to call onNext, onComplete etc.*/
SGSubscriber* m_subscriber_params;
/** Subscriber used to call onNext, onComplete etc.*/
SGSubscriber* m_subscriber_params;

/** List of subscription for this SGObject */
std::map<int64_t, rxcpp::subscription> m_subscriptions;
int64_t m_next_subscription_index;
};
/** List of subscription for this SGObject */
std::map<int64_t, rxcpp::subscription> m_subscriptions;
int64_t m_next_subscription_index;
};

#ifndef SWIG
#ifndef DOXYGEN_SHOULD_SKIP_THIS
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/base/base_types.h
Expand Up @@ -54,7 +54,7 @@ namespace shogun
std::is_same<CMeanFunction, T>::value ||
std::is_same<CLossFunction, T>::value ||
std::is_same<CTokenizer, T>::value ||
std::is_same<CEvaluationResult, T>::value>
std::is_same<CEvaluationResult, T>::value>
{
};

Expand Down
2 changes: 1 addition & 1 deletion src/shogun/classifier/Perceptron.cpp
Expand Up @@ -73,7 +73,7 @@ void CPerceptron::iteration()
{
converged = false;
const auto gradient = learn_rate * true_label;
put("bias", bias+gradient);
put("bias", bias + gradient);
v.add(gradient, w);
put("w", w);
}
Expand Down
9 changes: 6 additions & 3 deletions src/shogun/evaluation/CrossValidation.cpp
Expand Up @@ -99,7 +99,9 @@ CEvaluationResult* CCrossValidation::evaluate_impl()
CrossValidationStorage* storage = new CrossValidationStorage();
SG_REF(storage)
storage->put("num_runs", utils::safe_convert<index_t>(m_num_runs));
storage->put("num_folds", utils::safe_convert<index_t>(m_splitting_strategy->get_num_subsets()));
storage->put(
"num_folds", utils::safe_convert<index_t>(
m_splitting_strategy->get_num_subsets()));
storage->put("labels", m_labels);
storage->post_init();
SG_DEBUG("Ending CrossValidationStorage initilization.\n")
Expand All @@ -109,8 +111,9 @@ CEvaluationResult* CCrossValidation::evaluate_impl()
SG_DEBUG("result of cross-validation run %d is %f\n", i, results[i])

/* Emit the value */
observe(i, "cross_validation_run", "One run of CrossValidation",
storage->as<CEvaluationResult>());
observe(
i, "cross_validation_run", "One run of CrossValidation",
storage->as<CEvaluationResult>());

SG_UNREF(storage)
}
Expand Down
44 changes: 24 additions & 20 deletions src/shogun/evaluation/CrossValidationStorage.cpp
Expand Up @@ -50,27 +50,29 @@ CrossValidationFoldStorage::CrossValidationFoldStorage() : CEvaluationResult()
m_test_true_result = NULL;

SG_ADD(
&m_current_run_index, "run_index",
"The current run index of this fold", ParameterProperties::HYPER);
&m_current_run_index, "run_index", "The current run index of this fold",
ParameterProperties::HYPER);
SG_ADD(
&m_current_fold_index, "fold_index", "The current fold index",
ParameterProperties::HYPER);
SG_ADD(
&m_trained_machine, "trained_machine",
"The machine trained by this fold", ParameterProperties::HYPER);
SG_ADD(
&m_test_result, "predicted_labels",
"The test result of this fold", ParameterProperties::HYPER);
&m_test_result, "predicted_labels", "The test result of this fold",
ParameterProperties::HYPER);
SG_ADD(
&m_test_true_result, "ground_truth_labels",
"The true test result for this fold", ParameterProperties::HYPER);
SG_ADD(&m_train_indices, "train_indices",
"Indices used for training", ParameterProperties::HYPER);
SG_ADD(&m_test_indices, "test_indices",
"Indices used for testing", ParameterProperties::HYPER);
SG_ADD(&m_evaluation_result, "evaluation_result",
"Result of the evaluation", ParameterProperties::HYPER);

SG_ADD(
&m_train_indices, "train_indices", "Indices used for training",
ParameterProperties::HYPER);
SG_ADD(
&m_test_indices, "test_indices", "Indices used for testing",
ParameterProperties::HYPER);
SG_ADD(
&m_evaluation_result, "evaluation_result", "Result of the evaluation",
ParameterProperties::HYPER);
}

CrossValidationFoldStorage::~CrossValidationFoldStorage()
Expand All @@ -84,8 +86,8 @@ void CrossValidationFoldStorage::post_update_results()
{
}

void CrossValidationFoldStorage::print_result() {

void CrossValidationFoldStorage::print_result()
{
}

/** CrossValidationStorage **/
Expand All @@ -100,12 +102,14 @@ CrossValidationStorage::CrossValidationStorage() : CEvaluationResult()
&m_num_runs, "num_runs", "The total number of cross-validation runs",
ParameterProperties::HYPER);
SG_ADD(
&m_num_folds, "num_folds",
"The total number of cross-validation folds", ParameterProperties::HYPER);
&m_num_folds, "num_folds", "The total number of cross-validation folds",
ParameterProperties::HYPER);
SG_ADD(
&m_original_labels, "labels",
"The labels used for this cross-validation", ParameterProperties::HYPER);
this->watch_param("folds", &m_folds_results, AnyParameterProperties("Fold results"));
&m_original_labels, "labels",
"The labels used for this cross-validation",
ParameterProperties::HYPER);
this->watch_param(
"folds", &m_folds_results, AnyParameterProperties("Fold results"));
}

CrossValidationStorage::~CrossValidationStorage()
Expand All @@ -126,6 +130,6 @@ void CrossValidationStorage::append_fold_result(
m_folds_results.push_back(result);
}

void CrossValidationStorage::print_result() {

void CrossValidationStorage::print_result()
{
}