Skip to content

Commit

Permalink
added metric evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
gf712 committed May 29, 2019
1 parent b045a8a commit 7842b61
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 25 deletions.
12 changes: 7 additions & 5 deletions src/shogun/io/openml/OpenMLRun.cpp
Expand Up @@ -4,7 +4,6 @@
* Authors: Gil Hoben
*/

#include <shogun/evaluation/CrossValidationStorage.h>
#include <shogun/io/openml/OpenMLFile.h>
#include <shogun/io/openml/OpenMLRun.h>
#include <shogun/io/openml/ShogunOpenML.h>
Expand Down Expand Up @@ -47,7 +46,7 @@ std::shared_ptr<OpenMLRun> OpenMLRun::run_flow_on_task(
SG_SERROR("INTERNAL ERROR: failed to cast model to machine!\n")
}

auto* xval_storage = new CrossValidationStorage();
auto xval_storage = std::make_shared<CrossValidationStorage>();

if (task->get_split()->contains_splits())
{
Expand Down Expand Up @@ -93,9 +92,7 @@ std::shared_ptr<OpenMLRun> OpenMLRun::run_flow_on_task(
std::string{}, // setup_id
std::string{}, // setup_string
std::string{}, // parameter_settings
std::vector<float64_t>{}, // evaluations
std::vector<float64_t>{}, // fold_evaluations
std::vector<float64_t>{}, // sample_evaluations
xval_storage, // xval_storage
std::string{}, // data_content
std::vector<std::string>{}, // output_files
task, // task
Expand Down Expand Up @@ -123,3 +120,8 @@ void OpenMLRun::publish() const
{
SG_SNOTIMPLEMENTED
}

std::unique_ptr<std::ostream> OpenMLRun::to_xml() const {

return std::unique_ptr<std::ostream>();
}
15 changes: 6 additions & 9 deletions src/shogun/io/openml/OpenMLRun.h
Expand Up @@ -8,6 +8,7 @@
#define SHOGUN_OPENMLRUN_H

#include <shogun/base/SGObject.h>
#include <shogun/evaluation/CrossValidationStorage.h>

#include <shogun/io/openml/OpenMLFlow.h>
#include <shogun/io/openml/OpenMLTask.h>
Expand All @@ -20,9 +21,7 @@ namespace shogun {
const std::string& uploader, const std::string& uploader_name,
const std::string& setup_id, const std::string& setup_string,
const std::string& parameter_settings,
std::vector<float64_t> evaluations,
std::vector<float64_t> fold_evaluations,
std::vector<float64_t> sample_evaluations,
std::shared_ptr<CrossValidationStorage> xval_storage,
const std::string& data_content,
std::vector<std::string> output_files,
std::shared_ptr<OpenMLTask> task, std::shared_ptr<OpenMLFlow> flow,
Expand All @@ -31,9 +30,7 @@ namespace shogun {
: m_uploader(uploader), m_uploader_name(uploader_name),
m_setup_id(setup_id), m_setup_string(setup_string),
m_parameter_settings(parameter_settings),
m_evaluations(std::move(evaluations)),
m_fold_evaluations(std::move(fold_evaluations)),
m_sample_evaluations(std::move(sample_evaluations)),
m_xval_storage(xval_storage),
m_data_content(data_content),
m_output_files(std::move(output_files)), m_task(std::move(task)),
m_flow(std::move(flow)), m_run_id(run_id),
Expand All @@ -55,6 +52,8 @@ namespace shogun {

void to_filesystem(const std::string& directory) const;

std::unique_ptr<std::ostream> to_xml() const;

void publish() const;

private:
Expand All @@ -63,9 +62,7 @@ namespace shogun {
std::string m_setup_id;
std::string m_setup_string;
std::string m_parameter_settings;
std::vector<float64_t> m_evaluations;
std::vector<float64_t> m_fold_evaluations;
std::vector<float64_t> m_sample_evaluations;
std::shared_ptr<CrossValidationStorage> m_xval_storage;
std::string m_data_content;
std::vector<std::string> m_output_files;
std::shared_ptr<OpenMLTask> m_task;
Expand Down
72 changes: 61 additions & 11 deletions src/shogun/io/openml/ShogunOpenML.cpp
Expand Up @@ -6,6 +6,8 @@
* Authors: Gil Hoben
*/

#include <shogun/evaluation/ContingencyTableEvaluation.h>
#include <shogun/evaluation/MeanAbsoluteError.h>
#include <shogun/util/factory.h>

#include <shogun/io/openml/ShogunOpenML.h>
Expand Down Expand Up @@ -310,6 +312,30 @@ std::unique_ptr<CrossValidationFoldStorage> ShogunOpenML::run_model_on_fold(
{
auto task_type = task->get_task_type();

CEvaluation* evaluation_criterion = nullptr;

switch (task_type)
{
case OpenMLTask::TaskType::SUPERVISED_CLASSIFICATION:
evaluation_criterion = new CAccuracyMeasure();
break;
case OpenMLTask::TaskType::SUPERVISED_REGRESSION:
evaluation_criterion = new CMeanAbsoluteError();
break;
case OpenMLTask::TaskType::LEARNING_CURVE:
SG_SNOTIMPLEMENTED
case OpenMLTask::TaskType::SUPERVISED_DATASTREAM_CLASSIFICATION:
SG_SNOTIMPLEMENTED
case OpenMLTask::TaskType::CLUSTERING:
SG_SNOTIMPLEMENTED
case OpenMLTask::TaskType::MACHINE_LEARNING_CHALLENGE:
SG_SNOTIMPLEMENTED
case OpenMLTask::TaskType::SURVIVAL_ANALYSIS:
SG_SNOTIMPLEMENTED
case OpenMLTask::TaskType::SUBGROUP_DISCOVERY:
SG_SNOTIMPLEMENTED
}

switch (task_type)
{
case OpenMLTask::TaskType::SUPERVISED_CLASSIFICATION:
Expand All @@ -324,8 +350,6 @@ std::unique_ptr<CrossValidationFoldStorage> ShogunOpenML::run_model_on_fold(
// shared
auto* features_clone = features->clone()->as<CFeatures>();
auto* labels_clone = labels->clone()->as<CLabels>();
// auto* evaluation_criterion =
// (CEvaluation*)m_evaluation_criterion->clone();

/* evtl. update xvalidation output class */
fold->set_run_index(repeat_idx);
Expand Down Expand Up @@ -371,8 +395,10 @@ std::unique_ptr<CrossValidationFoldStorage> ShogunOpenML::run_model_on_fold(
SG_REF(result_labels);

/* evaluate */
// results[i] = evaluation_criterion->evaluate(result_labels, labels);
// SG_DEBUG("result on fold %d is %f\n", i, results[i])
auto result =
evaluation_criterion->evaluate(result_labels, labels_clone);
SG_SINFO(
"result on repeat %d fold %d is %f\n", repeat_idx, fold_idx, result)

/* evtl. update xvalidation output class */
fold->set_test_indices(test_idx);
Expand All @@ -381,18 +407,17 @@ std::unique_ptr<CrossValidationFoldStorage> ShogunOpenML::run_model_on_fold(
fold->set_test_true_result(true_labels);
SG_UNREF(true_labels)
fold->post_update_results();
// fold->set_evaluation_result(results[i]);
fold->set_evaluation_result(result);

/* clean up, remove subsets */
labels->remove_subset();
SG_UNREF(cloned_machine);
SG_UNREF(features_clone);
SG_UNREF(labels_clone);
// SG_UNREF(evaluation_criterion);
SG_UNREF(result_labels);
delete evaluation_criterion;
return fold;
}
break;
case OpenMLTask::TaskType::LEARNING_CURVE:
SG_SNOTIMPLEMENTED
case OpenMLTask::TaskType::SUPERVISED_DATASTREAM_CLASSIFICATION:
Expand All @@ -417,6 +442,30 @@ std::unique_ptr<CrossValidationFoldStorage> ShogunOpenML::run_model_on_fold(
{
auto task_type = task->get_task_type();

CEvaluation* evaluation_criterion = nullptr;

switch (task_type)
{
case OpenMLTask::TaskType::SUPERVISED_CLASSIFICATION:
evaluation_criterion = new CAccuracyMeasure();
break;
case OpenMLTask::TaskType::SUPERVISED_REGRESSION:
evaluation_criterion = new CMeanAbsoluteError();
break;
case OpenMLTask::TaskType::LEARNING_CURVE:
SG_SNOTIMPLEMENTED
case OpenMLTask::TaskType::SUPERVISED_DATASTREAM_CLASSIFICATION:
SG_SNOTIMPLEMENTED
case OpenMLTask::TaskType::CLUSTERING:
SG_SNOTIMPLEMENTED
case OpenMLTask::TaskType::MACHINE_LEARNING_CHALLENGE:
SG_SNOTIMPLEMENTED
case OpenMLTask::TaskType::SURVIVAL_ANALYSIS:
SG_SNOTIMPLEMENTED
case OpenMLTask::TaskType::SUBGROUP_DISCOVERY:
SG_SNOTIMPLEMENTED
}

switch (task_type)
{
case OpenMLTask::TaskType::SUPERVISED_CLASSIFICATION:
Expand Down Expand Up @@ -446,23 +495,24 @@ std::unique_ptr<CrossValidationFoldStorage> ShogunOpenML::run_model_on_fold(
SG_SDEBUG("finished evaluation\n")

/* evaluate */
// results[i] = evaluation_criterion->evaluate(result_labels, labels);
// SG_DEBUG("result on fold %d is %f\n", i, results[i])
auto result =
evaluation_criterion->evaluate(result_labels, labels_clone);
SG_SINFO("result is %f\n", result)

/* evtl. update xvalidation output class */
fold->set_test_result(result_labels);
auto* true_labels = (CLabels*)labels->clone();
fold->set_test_true_result(true_labels);
SG_UNREF(true_labels)
fold->post_update_results();
// fold->set_evaluation_result(results[i]);
fold->set_evaluation_result(result);

// cleanup
SG_UNREF(cloned_machine);
SG_UNREF(features_clone);
SG_UNREF(labels_clone);
// SG_UNREF(evaluation_criterion);
SG_UNREF(result_labels);
delete evaluation_criterion;
return fold;
}
case OpenMLTask::TaskType::LEARNING_CURVE:
Expand Down

0 comments on commit 7842b61

Please sign in to comment.