Skip to content

Commit

Permalink
Merge pull request #2913 from yorkerlin/develop
Browse files Browse the repository at this point in the history
support sparse penalty (eg, L1)
  • Loading branch information
yorkerlin committed Sep 4, 2015
2 parents 264b0ae + 8b5ccd0 commit 01a4511
Show file tree
Hide file tree
Showing 14 changed files with 850 additions and 324 deletions.
6 changes: 0 additions & 6 deletions src/shogun/optimization/ConstLearningRate.h
Expand Up @@ -76,9 +76,6 @@ class ConstLearningRate: public LearningRate
}

/** Update a context object to store mutable variables
*
* This method will be called by
* DescendUpdaterWithCorrection::update_context()
*
* @param context, a context object
*/
Expand All @@ -90,9 +87,6 @@ class ConstLearningRate: public LearningRate
/** Load the given context object to restore mutable variables
* Usually it is used in deserialization.
*
* This method will be called by
* DescendUpdaterWithCorrection::load_from_context(CMinimizerContext* context)
*
* @param context, a context object
*/
virtual void load_from_context(CMinimizerContext* context)
Expand Down
146 changes: 146 additions & 0 deletions src/shogun/optimization/ElasticNetPenalty.h
@@ -0,0 +1,146 @@
/*
* Copyright (c) The Shogun Machine Learning Toolbox
* Written (w) 2015 Wu Lin
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* The views and conclusions contained in the software and documentation are those
* of the authors and should not be interpreted as representing official policies,
* either expressed or implied, of the Shogun Development Team.
*
*/

#ifndef ELASTICNETPENALTY_H
#define ELASTICNETPENALTY_H
#include <shogun/optimization/SparsePenalty.h>
#include <shogun/optimization/L1Penalty.h>
#include <shogun/optimization/L2Penalty.h>
#include <shogun/lib/config.h>
#include <shogun/mathematics/Math.h>
namespace shogun
{
/** @brief This is the base class for ElasticNet penalty/regularization within the FirstOrderMinimizer framework.
 *
 * For ElasticNet penalty, \f$ElasticNet(w)\f$
 * \f[
 * ElasticNet(w)= \lambda \| w \|_1 + (1.0-\lambda) \| w \|_2
 * \f]
 * where \f$\lambda\f$ is the l1_ratio.
 *
 * The L1 term is computed by an internal L1Penalty object and the L2 term by
 * an internal L2Penalty object; both are created in the constructor and
 * deleted in the destructor.
 *
 * NOTE(review): the class holds raw pointers it deletes but declares no copy
 * constructor / assignment operator, so copying an instance would delete the
 * helpers twice -- confirm instances are never copied.
 */

class ElasticNetPenalty: public SparsePenalty
{
public:
	/** Default constructor
	 *
	 * The l1_ratio starts at 0, which check_ratio() rejects, so
	 * set_l1_ratio() has to be called before the penalty is used.
	 */
	ElasticNetPenalty()
		:SparsePenalty() {init();}

	/** Destructor, releases the internal L1 and L2 penalty objects */
	virtual ~ElasticNetPenalty()
	{
		delete m_l1_penalty;
		delete m_l2_penalty;
	}

	/** Set the weight of the L1 term; the L2 term is weighted by (1.0-ratio)
	 *
	 * @param ratio weight of the L1 term, must lie strictly between 0.0 and 1.0
	 */
	virtual void set_l1_ratio(float64_t ratio)
	{
		REQUIRE(ratio>0.0 && ratio<1.0,
			"l1_ratio (%f) must be in the range (0.0, 1.0), exclusive\n", ratio);
		m_l1_ratio=ratio;
	}

	/** Given the value of a target variable,
	 * this method returns the penalty of the variable
	 *
	 * @param variable value of the variable
	 * @return penalty of the variable
	 */
	virtual float64_t get_penalty(float64_t variable)
	{
		check_ratio();
		float64_t penalty=m_l1_ratio*m_l1_penalty->get_penalty(variable);
		penalty+=(1.0-m_l1_ratio)*m_l2_penalty->get_penalty(variable);
		return penalty;
	}

	/** Return the gradient of the penalty wrt a target variable
	 *
	 * The result is the l1_ratio-weighted sum of the gradients reported by
	 * the internal L1 and L2 penalty objects; both arguments are forwarded
	 * to them unchanged.
	 *
	 * @param variable value of the variable
	 * @param gradient_of_variable gradient of the variable
	 * @return gradient of the penalty wrt the variable
	 */
	virtual float64_t get_penalty_gradient(float64_t variable,
		float64_t gradient_of_variable)
	{
		check_ratio();
		float64_t grad=m_l1_ratio*m_l1_penalty->get_penalty_gradient(variable, gradient_of_variable);
		grad+=(1.0-m_l1_ratio)*m_l2_penalty->get_penalty_gradient(variable, gradient_of_variable);
		return grad;
	}

	/** Set the rounding epsilon of the internal L1 penalty
	 *
	 * @param eplison rounding epsilon, forwarded to the L1 penalty
	 */
	virtual void set_rounding_eplison(float64_t eplison)
	{
		m_l1_penalty->set_rounding_eplison(eplison);
	}

	/** Apply the sparse (L1) part of the penalty to the variable
	 *
	 * Only the L1 penalty takes part in this step; the amount passed to it
	 * is penalty_delta scaled by the l1_ratio.
	 *
	 * @param variable variable to be updated
	 * @param penalty_delta amount of penalty before scaling by the l1_ratio
	 */
	virtual void update_sparse_variable(SGVector<float64_t> variable,
		float64_t penalty_delta)
	{
		check_ratio();
		m_l1_penalty->update_sparse_variable(variable, penalty_delta*m_l1_ratio);
	}

	/** Update a context object to store mutable variables
	 *
	 * The context is passed on to the internal L1 and L2 penalty objects.
	 *
	 * @param context, a context object
	 */
	virtual void update_context(CMinimizerContext* context)
	{
		REQUIRE(context, "Context must set\n");
		m_l1_penalty->update_context(context);
		m_l2_penalty->update_context(context);
	}

	/** Load the given context object to restore mutable variables
	 *
	 * The context is passed on to the internal L1 and L2 penalty objects.
	 *
	 * @param context, a context object
	 */
	virtual void load_from_context(CMinimizerContext* context)
	{
		REQUIRE(context, "Context must set\n");
		m_l1_penalty->load_from_context(context);
		m_l2_penalty->load_from_context(context);
	}
protected:
	/** Ensure the l1_ratio has been set (it is 0 until set_l1_ratio() is called) */
	virtual void check_ratio()
	{
		REQUIRE(m_l1_ratio>0, "l1_ratio must set\n");
	}

	/** weight of the L1 term */
	float64_t m_l1_ratio;

	/** internal L1 penalty, created in init() and deleted in the destructor */
	L1Penalty* m_l1_penalty;

	/** internal L2 penalty, created in init() and deleted in the destructor */
	L2Penalty* m_l2_penalty;

private:
	/** Set the l1_ratio to the "not set" value 0 and create the internal penalties */
	void init()
	{
		m_l1_ratio=0;
		m_l1_penalty=new L1Penalty();
		m_l2_penalty=new L2Penalty();
	}
};

}

#endif
25 changes: 23 additions & 2 deletions src/shogun/optimization/FirstOrderMinimizer.h
Expand Up @@ -100,14 +100,24 @@ class FirstOrderMinimizer
*
* @return a context object
*/
virtual CMinimizerContext* save_to_context()=0;
virtual CMinimizerContext* save_to_context()
{
	// Start from an empty context and let update_context() fill it in
	// before handing it back to the caller.
	CMinimizerContext* context=new CMinimizerContext();
	update_context(context);
	return context;
}

/** Load the given context object to restores mutable variables
* Usually it is used in deserialization.
*
* @param context, a context object
*/
virtual void load_from_context(CMinimizerContext* context)=0;
virtual void load_from_context(CMinimizerContext* context)
{
	REQUIRE(context,"Context must set\n");

	// Without a penalty type there is nothing to restore at this level.
	if(!m_penalty_type)
		return;

	m_penalty_type->load_from_context(context);
}

/** Set the weight of penalty
*
Expand All @@ -130,6 +140,17 @@ class FirstOrderMinimizer
}
protected:

/** Store the mutable variables of this minimizer in a context object
 *
 * At this level only the penalty type, if one is set, writes to the context.
 *
 * @param context, a context object
 */
virtual void update_context(CMinimizerContext* context)
{
	REQUIRE(context,"Context must set\n");

	// Without a penalty type there is nothing to store at this level.
	if(!m_penalty_type)
		return;

	m_penalty_type->update_context(context);
}

/** Get the penalty given target variables
* For L2 penalty,
* the target variable is \f$w\f$
Expand Down
14 changes: 2 additions & 12 deletions src/shogun/optimization/FirstOrderStochasticMinimizer.h
Expand Up @@ -135,25 +135,14 @@ class FirstOrderStochasticMinimizer: public FirstOrderMinimizer
{
REQUIRE(context,"Context must set\n");
REQUIRE(m_gradient_updater,"Descend updater must set\n");
FirstOrderMinimizer::load_from_context(context);
m_gradient_updater->load_from_context(context);
if(m_learning_rate)
m_learning_rate->load_from_context(context);
std::string key="FirstOrderStochasticMinimizer::m_iter_counter";
m_iter_counter=context->get_data_int32(key);
}

/** Create a context object holding the mutable variables
 * Usually it is used in serialization.
 *
 * @return a context object
 */
virtual CMinimizerContext* save_to_context()
{
	// Start from an empty context and let update_context() fill it in
	// before handing it back to the caller.
	CMinimizerContext* context=new CMinimizerContext();
	update_context(context);
	return context;
}

virtual void set_learning_rate(LearningRate *learning_rate)
{
m_learning_rate=learning_rate;
Expand All @@ -170,6 +159,7 @@ class FirstOrderStochasticMinimizer: public FirstOrderMinimizer
{
REQUIRE(context,"Context must set\n");
REQUIRE(m_gradient_updater,"Descend updater must set\n");
FirstOrderMinimizer::update_context(context);
m_gradient_updater->update_context(context);
if(m_learning_rate)
m_learning_rate->update_context(context);
Expand Down
6 changes: 0 additions & 6 deletions src/shogun/optimization/InverseScalingLearningRate.h
Expand Up @@ -113,9 +113,6 @@ class InverseScalingLearningRate: public LearningRate
}

/** Update a context object to store mutable variables
*
* This method will be called by
* DescendUpdaterWithCorrection::update_context()
*
* @param context, a context object
*/
Expand All @@ -127,9 +124,6 @@ class InverseScalingLearningRate: public LearningRate
/** Load the given context object to restore mutable variables
* Usually it is used in deserialization.
*
* This method will be called by
* DescendUpdaterWithCorrection::load_from_context(CMinimizerContext* context)
*
* @param context, a context object
*/
virtual void load_from_context(CMinimizerContext* context)
Expand Down
129 changes: 129 additions & 0 deletions src/shogun/optimization/L1Penalty.h
@@ -0,0 +1,129 @@
/*
* Copyright (c) The Shogun Machine Learning Toolbox
* Written (w) 2015 Wu Lin
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* The views and conclusions contained in the software and documentation are those
* of the authors and should not be interpreted as representing official policies,
* either expressed or implied, of the Shogun Development Team.
*
*/

#ifndef L1PENALTY_H
#define L1PENALTY_H
#include <shogun/optimization/SparsePenalty.h>
#include <shogun/lib/config.h>
#include <shogun/mathematics/Math.h>
namespace shogun
{
/** @brief This is the base class for L1 penalty/regularization within the FirstOrderMinimizer framework.
 *
 * For L1 penalty, \f$L1(w)\f$
 * \f[
 * L1(w)=\| w \|_1 = \sum_i | w_i |
 * \f]
 */

class L1Penalty: public SparsePenalty
{
public:
	/** Default constructor, sets the rounding epsilon to 1e-8 */
	L1Penalty()
		:SparsePenalty() {init();}

	virtual ~L1Penalty() {}

	/** Given the value of a target variable,
	 * this method returns the penalty of the variable
	 *
	 * @param variable value of the variable
	 * @return penalty of the variable, i.e. its absolute value
	 */
	virtual float64_t get_penalty(float64_t variable) {return CMath::abs(variable);}

	/** Return the gradient of the penalty wrt a target variable
	 *
	 * This implementation ignores both arguments and always returns 0.0.
	 * NOTE(review): presumably the L1 term is meant to be applied through
	 * update_sparse_variable() instead -- confirm against the minimizers.
	 *
	 * @param variable value of the variable (unused)
	 * @param gradient_of_variable gradient of the variable (unused)
	 * @return 0.0
	 */
	virtual float64_t get_penalty_gradient(float64_t variable,
		float64_t gradient_of_variable) {return 0.0;}

	/** Set the rounding epsilon
	 *
	 * Updated values whose magnitude is below this threshold are set to
	 * exactly zero in get_sparse_variable().
	 *
	 * @param eplison rounding epsilon, must be non-negative
	 */
	virtual void set_rounding_eplison(float64_t eplison)
	{
		REQUIRE(eplison>=0,"Rounding eplison (%f) should be non-negative\n", eplison);
		m_rounding_eplison=eplison;
	}

	/** Apply get_sparse_variable() to every element of the variable
	 *
	 * The elements are overwritten through the given SGVector.
	 * NOTE(review): presumably the SGVector shares its storage with the
	 * caller's vector, so the caller sees the update -- confirm.
	 *
	 * @param variable variable to be updated
	 * @param penalty_delta amount by which each element is moved towards zero
	 */
	virtual void update_sparse_variable(SGVector<float64_t> variable,
		float64_t penalty_delta)
	{
		for(index_t idx=0; idx<variable.vlen; idx++)
			variable[idx]=get_sparse_variable(variable[idx], penalty_delta);
	}

	/** Update a context object to store mutable variables
	 *
	 * This class only validates the context; it writes nothing to it.
	 *
	 * @param context, a context object
	 */
	virtual void update_context(CMinimizerContext* context)
	{
		REQUIRE(context, "Context must set\n");
	}

	/** Load the given context object to restore mutable variables
	 *
	 * This class only validates the context; it reads nothing from it.
	 *
	 * @param context, a context object
	 */
	virtual void load_from_context(CMinimizerContext* context)
	{
		REQUIRE(context, "Context must set\n");
	}
protected:
	/** magnitude below which an updated value is rounded to zero */
	float64_t m_rounding_eplison;

	/** Move a single value towards zero by penalty_delta
	 *
	 * A positive value is decreased by penalty_delta and clamped at zero from
	 * below; any other value is increased by penalty_delta and clamped at
	 * zero from above. A result whose magnitude is smaller than the rounding
	 * epsilon is then set to exactly zero.
	 *
	 * NOTE(review): this assumes penalty_delta is non-negative; a negative
	 * value would push the variable away from zero -- confirm the callers.
	 *
	 * @param variable value of the variable
	 * @param penalty_delta amount by which the value is moved towards zero
	 * @return the updated value
	 */
	virtual float64_t get_sparse_variable(float64_t variable, float64_t penalty_delta)
	{
		if (variable>0.0)
		{
			variable-=penalty_delta;
			if (variable<0.0)
				variable=0.0;
		}
		else
		{
			variable+=penalty_delta;
			if (variable>0.0)
				variable=0.0;
		}
		if (CMath::abs(variable)<m_rounding_eplison)
			variable=0.0;
		return variable;
	}

private:
	/** Set the default rounding epsilon */
	void init()
	{
		m_rounding_eplison=1e-8;
	}
};

}

#endif

0 comments on commit 01a4511

Please sign in to comment.