Skip to content

Commit

Permalink
implement cross-validated calibration
Browse files Browse the repository at this point in the history
  • Loading branch information
durovo committed Dec 10, 2017
1 parent 6671ec0 commit 3b08714
Show file tree
Hide file tree
Showing 8 changed files with 996 additions and 0 deletions.
244 changes: 244 additions & 0 deletions src/shogun/evaluation/Calibration.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2011-2012 Heiko Strathmann
* Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
*/

#include <shogun/evaluation/Calibration.h>
#include <shogun/evaluation/CalibrationMethod.h>
#include <shogun/lib/config.h>
#include <shogun/machine/Machine.h>
#include <shogun/mathematics/Math.h>
#include <shogun/mathematics/Statistics.h>

using namespace shogun;

CBinaryLabels* CCalibration::apply_binary(CFeatures* features)
{

CBinaryLabels* result_labels = (CBinaryLabels*)m_machine->apply(features);
CCalibrationMethod* method =
(CCalibrationMethod*)m_calibration_machines->get_element(0);
SGVector<float64_t> confidence_values =
method->apply_binary(result_labels->get_values());
result_labels->set_values(confidence_values);

return result_labels;
}

CMulticlassLabels* CCalibration::apply_multiclass(CFeatures* features)
{
index_t num_calibration_machines =
m_calibration_machines->get_num_elements();
CMulticlassLabels* result_labels =
(CMulticlassLabels*)m_machine->apply(features);
for (index_t i = 0; i < num_calibration_machines; ++i)
{
CCalibrationMethod* method =
(CCalibrationMethod*)m_calibration_machines->get_element(i);
SGVector<float64_t> confidence_values =
method->apply_binary(result_labels->get_multiclass_confidences(i));
result_labels->set_multiclass_confidences(i, confidence_values);
}

SGVector<float64_t> temp_confidences;
index_t num_classes = result_labels->get_num_classes();
index_t num_samples;

// normalize the probabilities
for (index_t i = 0; i < num_classes; ++i)
{
SGVector<float64_t> confidence_values =
result_labels->get_multiclass_confidences(i);
if (i == 0)
{
temp_confidences = confidence_values;
num_samples = temp_confidences.vlen;
}
else
{
for (index_t j = 0; j < num_samples; ++j)
{
temp_confidences[j] += confidence_values[j];
}
}
}
for (index_t i = 0; i < num_classes; ++i)
{
SGVector<float64_t> confidence_values =
result_labels->get_multiclass_confidences(i);
for (index_t j = 0; j < num_samples; ++j)
{
confidence_values[j] = confidence_values[j] / temp_confidences[j];
}
result_labels->set_multiclass_confidences(i, confidence_values);
}

return result_labels;
}

CMulticlassLabels*
CCalibration::apply_locked_multiclass(SGVector<index_t> subset_indices)
{
index_t num_calibration_machines =
m_calibration_machines->get_num_elements();
CMulticlassLabels* result_labels =
(CMulticlassLabels*)m_machine->apply_locked(subset_indices);
for (index_t i = 0; i < num_calibration_machines; ++i)
{
CCalibrationMethod* method =
(CCalibrationMethod*)m_calibration_machines->get_element(i);
SGVector<float64_t> confidence_values =
method->apply_binary(result_labels->get_multiclass_confidences(i));
result_labels->set_multiclass_confidences(i, confidence_values);
}
SGVector<float64_t> temp_confidences;
index_t num_classes = result_labels->get_num_classes();
index_t num_samples;

// normalize the probabilities
for (index_t i = 0; i < num_classes; ++i)
{
SGVector<float64_t> confidence_values =
result_labels->get_multiclass_confidences(i);
if (i == 0)
{
temp_confidences = confidence_values;
num_samples = temp_confidences.vlen;
continue;
}

for (index_t j = 0; j < num_samples; ++j)
{
temp_confidences[j] += confidence_values[j];
}
}
for (index_t i = 0; i < num_classes; ++i)
{
SGVector<float64_t> confidence_values =
result_labels->get_multiclass_confidences(i);
for (index_t j = 0; j < num_samples; ++j)
{
confidence_values[j] /= temp_confidences[j];
}
result_labels->set_multiclass_confidences(i, confidence_values);
}

return result_labels;
}

CBinaryLabels*
CCalibration::apply_locked_binary(SGVector<index_t> subset_indices)
{

CBinaryLabels* result_labels =
(CBinaryLabels*)m_machine->apply_locked(subset_indices);
CCalibrationMethod* method =
(CCalibrationMethod*)m_calibration_machines->get_element(0);
SGVector<float64_t> confidence_values =
method->apply_binary(result_labels->get_values());
result_labels->set_values(confidence_values);

return result_labels;
}

EProblemType CCalibration::get_machine_problem_type() const
{
return m_machine->get_machine_problem_type();
}

bool CCalibration::train(CFeatures* features)
{
CCalibrationMethod* calibration_machine = NULL;
if (get_machine_problem_type() == PT_MULTICLASS)
{
SGVector<float64_t> confidences;
index_t num_calibration_machines =
((CMulticlassLabels*)get_labels())->get_num_classes();
m_calibration_machines =
new CDynamicObjectArray(num_calibration_machines);
m_machine->train(features);
CMulticlassLabels* result_labels =
(CMulticlassLabels*)m_machine->apply(features);
for (index_t i = 0; i < num_calibration_machines; ++i)
{
confidences = result_labels->get_multiclass_confidences(i);

calibration_machine = (CCalibrationMethod*)m_method->clone();
if (!calibration_machine->train(confidences))
{
return false;
}
m_calibration_machines->set_element(calibration_machine, i);
}
}
else
{
SGVector<float64_t> confidences;
m_calibration_machines = new CDynamicObjectArray(1);
m_machine->train(features);
CBinaryLabels* result_labels =
(CBinaryLabels*)m_machine->apply_binary(features);

confidences = result_labels->get_values();

calibration_machine = (CCalibrationMethod*)m_method->clone();
if (!calibration_machine->train(confidences))
{
return false;
}
m_calibration_machines->set_element(calibration_machine, 0);
}

return true;
}

bool CCalibration::train_locked(SGVector<index_t> subset_indices)
{
CBinaryLabels* m_result_labels =
(CBinaryLabels*)m_machine->apply_locked(subset_indices);
CStatistics::SigmoidParamters params =
CStatistics::fit_sigmoid(m_result_labels->get_values());
a = params.a;
b = params.b;

return true;
}

CCalibration::CCalibration() : CMachine()
{
init();
}

void CCalibration::set_calibration_method(CCalibrationMethod* method)
{
m_method = method;
}

void CCalibration::set_machine(CMachine* machine)
{
m_machine = machine;
}

void CCalibration::init()
{
m_machine = NULL;
m_labels = NULL;
}

CCalibration::~CCalibration()
{

SG_UNREF(m_calibration_machines);
SG_UNREF(m_machine);
SG_UNREF(m_labels);
}

CMachine* CCalibration::get_machine()
{
return m_machine;
}
66 changes: 66 additions & 0 deletions src/shogun/evaluation/Calibration.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2011-2012 Heiko Strathmann
* Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
*/

#include <shogun/evaluation/CalibrationMethod.h>
#include <shogun/lib/config.h>
#include <shogun/machine/Machine.h>

#ifndef _CALIBRATION_H__
#define _CALIBRATION_H__

namespace shogun
{

class CCalibration : public CMachine
{
public:
virtual const char* get_name() const
{
return "Calibration";
}

virtual EProblemType get_machine_problem_type() const;

virtual bool train(CFeatures* data = NULL);

virtual bool train_locked(SGVector<index_t> subset_indices);

virtual CBinaryLabels* apply_binary(CFeatures* features);

virtual CMulticlassLabels* apply_multiclass(CFeatures* features);

virtual CMulticlassLabels*
apply_locked_multiclass(SGVector<index_t> subset_indices);

virtual CBinaryLabels*
apply_locked_binary(SGVector<index_t> subset_indices);

/** constructor
*/
CCalibration();

~CCalibration();

void init();

void set_machine(CMachine* machine);

void set_calibration_method(CCalibrationMethod* calibration_method);

CMachine* get_machine();

private:
CMachine* m_machine;
float64_t a, b;
CDynamicObjectArray* m_calibration_machines;
CCalibrationMethod* m_method;
};
}
#endif
45 changes: 45 additions & 0 deletions src/shogun/evaluation/CalibrationMethod.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2011-2012 Heiko Strathmann
* Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
*/

#include <shogun/evaluation/CalibrationMethod.h>
#include <shogun/machine/Machine.h>

using namespace shogun;

SGVector<float64_t> CCalibrationMethod::apply_binary(SGVector<float64_t> values)
{
SG_NOTIMPLEMENTED
return NULL;
}

bool CCalibrationMethod::train(SGVector<float64_t> values)
{
SG_NOTIMPLEMENTED

return true;
}

CCalibrationMethod::CCalibrationMethod() : CMachine()
{
}

void CCalibrationMethod::set_target_values(SGVector<float64_t> target_values)
{
m_target_values = target_values;
}

CCalibrationMethod::CCalibrationMethod(SGVector<float64_t> target_values)
{
m_target_values = target_values;
}

CCalibrationMethod::~CCalibrationMethod()
{
}

0 comments on commit 3b08714

Please sign in to comment.