Skip to content

Commit

Permalink
remove sigmoid params array, make methods pure virtual in Calibration
Browse files Browse the repository at this point in the history
  • Loading branch information
durovo committed Feb 12, 2018
1 parent 72abcba commit 6f3b546
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 42 deletions.
26 changes: 4 additions & 22 deletions src/shogun/evaluation/Calibration.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,44 +70,26 @@ namespace shogun
* @param targets The true labels corresponding to the predictions
* @return boolean indicating whether the calibration was succesful
**/
virtual bool fit_binary(CBinaryLabels* predictions, CBinaryLabels* targets)
{
SG_NOTIMPLEMENTED

return true;
}
virtual bool fit_binary(CBinaryLabels* predictions, CBinaryLabels* targets) = 0;

/** Calibrate binary predictions based on parameters learned by calling fit.
* @param predictions The predictions outputted by the machine
* @return Calibrated binary labels
**/
virtual CBinaryLabels* calibrate_binary(CBinaryLabels* predictions)
{
SG_NOTIMPLEMENTED
return NULL;
}
virtual CBinaryLabels* calibrate_binary(CBinaryLabels* predictions) = 0;

/** Fit calibration parameters for multiclass labels.
* @param predictions The predictions outputted by the machine
* @param targets The true labels corresponding to the predictions
* @return boolean indicating whether the calibration was succesful
**/
virtual bool fit_multiclass(CMulticlassLabels* predictions, CMulticlassLabels* targets)
{
SG_NOTIMPLEMENTED

return true;
}
virtual bool fit_multiclass(CMulticlassLabels* predictions, CMulticlassLabels* targets) = 0;

/** Calibrate multiclass predictions based on parameters learned by calling fit.
* @param predictions The predictions outputted by the machine
* @return Calibrated binary labels
**/
virtual CMulticlassLabels* calibrate_multiclass(CMulticlassLabels* predictions)
{
SG_NOTIMPLEMENTED
return NULL;
}
virtual CMulticlassLabels* calibrate_multiclass(CMulticlassLabels* predictions) = 0;
};
}
#endif
35 changes: 25 additions & 10 deletions src/shogun/evaluation/SigmoidCalibration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@

using namespace shogun;

CSigmoidCalibration::CSigmoidCalibration()
CSigmoidCalibration::CSigmoidCalibration() : CCalibration()
{
init();
}
Expand All @@ -53,12 +53,15 @@ CSigmoidCalibration::~CSigmoidCalibration()

void CSigmoidCalibration::init()
{
m_sigmoid_parameters.resize(1);
m_sigmoid_as.resize_vector(1);
m_sigmoid_bs.resize_vector(1);
m_maxiter = 100;
m_minstep = 1E-10;
m_sigma = 1E-12;
m_epsilon = 1E-5;

SG_ADD(&m_sigmoid_as, "m_sigmoid_as", "Vector of paramter A of sigmoid for each class.", MS_NOT_AVAILABLE);
SG_ADD(&m_sigmoid_bs, "m_sigmoid_bs", "Vector of paramter B of sigmoid for each class.", MS_NOT_AVAILABLE);
SG_ADD(&m_maxiter, "m_maxiter", "Maximum number of iteration for search.", MS_NOT_AVAILABLE);
SG_ADD(&m_minstep, "m_minstep", "Minimum step taken in line search.", MS_NOT_AVAILABLE);
SG_ADD(&m_sigma, "m_sigma", "Positive parameter to ensure positive semi-definite Hessian.", MS_NOT_AVAILABLE);
Expand Down Expand Up @@ -107,19 +110,24 @@ float64_t CSigmoidCalibration::get_epsilon()

bool CSigmoidCalibration::fit_binary(CBinaryLabels* predictions, CBinaryLabels* targets)
{
m_sigmoid_parameters.resize(1);

m_sigmoid_parameters[0] = CStatistics::fit_sigmoid(predictions->get_values(), targets->get_labels(),
m_sigmoid_as.resize_vector(1);
m_sigmoid_bs.resize_vector(1);
auto sigmoid_params = CStatistics::fit_sigmoid(predictions->get_values(), targets->get_labels(),
m_maxiter, m_minstep, m_sigma, m_epsilon);

m_sigmoid_as[0] = sigmoid_params.a;
m_sigmoid_bs[0] = sigmoid_params.b;

return true;
}

CBinaryLabels*
CSigmoidCalibration::calibrate_binary(CBinaryLabels* predictions)
{
auto params = m_sigmoid_parameters[0];

CStatistics::SigmoidParamters params;
params.a = m_sigmoid_as[0];
params.b = m_sigmoid_bs[0];

/** Convert predictions to probabilties. */
auto values = calibrate_values(predictions->get_values(), params);

Expand All @@ -133,7 +141,8 @@ bool CSigmoidCalibration::fit_multiclass(CMulticlassLabels* predictions, CMultic
index_t num_classes =
predictions->get_num_classes();

m_sigmoid_parameters.resize(num_classes);
m_sigmoid_bs.resize_vector(num_classes);
m_sigmoid_bs.resize_vector(num_classes);

/** Fit and store parameters for for each class seperately. */

Expand All @@ -144,8 +153,10 @@ bool CSigmoidCalibration::fit_multiclass(CMulticlassLabels* predictions, CMultic
auto pred_values = class_predictions->get_values();
auto target_labels = class_targets->get_labels();

m_sigmoid_parameters[i] = CStatistics::fit_sigmoid(pred_values, target_labels,
auto sigmoid_params = CStatistics::fit_sigmoid(pred_values, target_labels,
m_maxiter, m_minstep, m_sigma, m_epsilon);
m_sigmoid_as[i] = sigmoid_params.a;
m_sigmoid_bs[i] = sigmoid_params.b;
}

return true;
Expand All @@ -165,8 +176,12 @@ CSigmoidCalibration::calibrate_multiclass(CMulticlassLabels* predictions)
auto binary_predictions = predictions->get_binary_for_class(i);
auto class_values = binary_predictions->get_values();

CStatistics::SigmoidParamters sigmoid_params;
sigmoid_params.a = m_sigmoid_as[i];
sigmoid_params.b = m_sigmoid_bs[i];

SGVector<float64_t> calibrated_values =
calibrate_values(class_values, m_sigmoid_parameters[i]);
calibrate_values(class_values, sigmoid_params);

confidence_values.set_column(i, calibrated_values);
}
Expand Down
22 changes: 12 additions & 10 deletions src/shogun/evaluation/SigmoidCalibration.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,42 +97,42 @@ namespace shogun
/** Set maximum number of iterations
* @param maxiter maximum number of iterations
*/
virtual void set_maxiter(index_t maxiter);
void set_maxiter(index_t maxiter);

/** Get max iterations
* @return maximum number of iterations
*/
virtual index_t get_maxiter();
index_t get_maxiter();

/** Set min step
* @param minstep min step taken in line search
*/
virtual void set_minstep(float64_t minstep);
void set_minstep(float64_t minstep);

/** Get min step
* @return minimum steps taken in line search
*/
virtual float64_t get_minstep();
float64_t get_minstep();

/** Set sigma
* @param sigma Set to a value greater than 0 to ensure that the Hessian matrix is positive semi-definite
*/
virtual void set_sigma(float64_t sigma);
void set_sigma(float64_t sigma);

/** Get sigma
* @return sigma
*/
virtual float64_t get_sigma();
float64_t get_sigma();

/** Get epsilon
* @param epsilon stopping criteria
*/
virtual void set_epsilon(float64_t epsilon);
void set_epsilon(float64_t epsilon);

/** Get epsilon
* @return stopping critera
*/
virtual float64_t get_epsilon();
float64_t get_epsilon();

private:
/** Initialize parameters */
Expand All @@ -145,8 +145,10 @@ namespace shogun
SGVector<float64_t> calibrate_values(SGVector<float64_t> values, CStatistics::SigmoidParamters params);

private:
/** Array to store sigmoid parameters for each class. In case of binary labels, only one pair of parameters are stored. */
std::vector<CStatistics::SigmoidParamters> m_sigmoid_parameters;
/** Vector to store parameter A of sigmoid for each class. In case of binary labels, only one pair of parameters are stored. */
SGVector<float64_t> m_sigmoid_as;
/** Vector to store parameter B of sigmoid for each class. */
SGVector<float64_t> m_sigmoid_bs;
/** Maximum number of iterations. */
index_t m_maxiter;
/** Minimum step taken in line search. */
Expand Down

0 comments on commit 6f3b546

Please sign in to comment.