diff --git a/examples/undocumented/libshogun/so_multiclass.cpp b/examples/undocumented/libshogun/so_multiclass.cpp index 1e067c97658..79daf834668 100644 --- a/examples/undocumented/libshogun/so_multiclass.cpp +++ b/examples/undocumented/libshogun/so_multiclass.cpp @@ -111,7 +111,7 @@ int main(int argc, char ** argv) // Create SO-SVM CPrimalMosekSOSVM* sosvm = new CPrimalMosekSOSVM(model, loss, labels); - CDualLibQPBMSOSVM* bundle = new CDualLibQPBMSOSVM(model, loss, labels, 1000); + CDualLibQPBMSOSVM* bundle = new CDualLibQPBMSOSVM(model, labels, 1000); bundle->set_verbose(false); SG_REF(sosvm); SG_REF(bundle); diff --git a/examples/undocumented/libshogun/so_multiclass_BMRM.cpp b/examples/undocumented/libshogun/so_multiclass_BMRM.cpp index 0dd63e01663..51960fa1ede 100644 --- a/examples/undocumented/libshogun/so_multiclass_BMRM.cpp +++ b/examples/undocumented/libshogun/so_multiclass_BMRM.cpp @@ -14,7 +14,6 @@ #include #include #include -#include #include #include #include @@ -200,14 +199,10 @@ int main(int argc, char * argv[]) // Create structured model CMulticlassModel* model = new CMulticlassModel(features, labels); - // Create loss function - CHingeLoss* loss = new CHingeLoss(); - // Create SO-SVM CDualLibQPBMSOSVM* sosvm = new CDualLibQPBMSOSVM( model, - loss, labels, lambda); SG_REF(sosvm); diff --git a/examples/undocumented/python_modular/structure_discrete_hmsvm_bmrm.py b/examples/undocumented/python_modular/structure_discrete_hmsvm_bmrm.py index eafb3a9f0e0..03022b09be1 100644 --- a/examples/undocumented/python_modular/structure_discrete_hmsvm_bmrm.py +++ b/examples/undocumented/python_modular/structure_discrete_hmsvm_bmrm.py @@ -10,7 +10,6 @@ def structure_discrete_hmsvm_bmrm (m_data_dict=data_dict): from shogun.Features import RealMatrixFeatures - from shogun.Loss import HingeLoss from shogun.Structure import SequenceLabels, HMSVMModel, Sequence, TwoStateModel, SMT_TWO_STATE from shogun.Evaluation import StructuredAccuracy from shogun.Structure import DualLibQPBMSOSVM @@ -23,11 +22,10 @@ def structure_discrete_hmsvm_bmrm (m_data_dict=data_dict): labels = SequenceLabels(labels_array, 250, 500, 2) features = RealMatrixFeatures(m_data_dict['signal'].astype(float), 250, 500) - loss = HingeLoss() num_obs = 4 # given by the data file used model = HMSVMModel(features, labels, SMT_TWO_STATE, num_obs) - sosvm = DualLibQPBMSOSVM(model, loss, labels, 5000.0) + sosvm = DualLibQPBMSOSVM(model, labels, 5000.0) sosvm.train() #print sosvm.get_w() diff --git a/examples/undocumented/python_modular/structure_plif_hmsvm_bmrm.py b/examples/undocumented/python_modular/structure_plif_hmsvm_bmrm.py index e87c42ba836..a3013a52add 100644 --- a/examples/undocumented/python_modular/structure_plif_hmsvm_bmrm.py +++ b/examples/undocumented/python_modular/structure_plif_hmsvm_bmrm.py @@ -4,13 +4,11 @@ def structure_plif_hmsvm_bmrm (num_examples, example_length, num_features, num_noise_features): from shogun.Features import RealMatrixFeatures - from shogun.Loss import HingeLoss from shogun.Structure import TwoStateModel, DualLibQPBMSOSVM from shogun.Evaluation import StructuredAccuracy model = TwoStateModel.simulate_data(num_examples, example_length, num_features, num_noise_features) - loss = HingeLoss() - sosvm = DualLibQPBMSOSVM(model, loss, model.get_labels(), 5000.0) + sosvm = DualLibQPBMSOSVM(model, model.get_labels(), 5000.0) sosvm.train() #print sosvm.get_w() diff --git a/src/shogun/machine/KernelStructuredOutputMachine.cpp b/src/shogun/machine/KernelStructuredOutputMachine.cpp index 9f19ea217bc..09040321894 100644 --- a/src/shogun/machine/KernelStructuredOutputMachine.cpp +++ b/src/shogun/machine/KernelStructuredOutputMachine.cpp @@ -20,10 +20,9 @@ CKernelStructuredOutputMachine::CKernelStructuredOutputMachine() CKernelStructuredOutputMachine::CKernelStructuredOutputMachine( CStructuredModel* model, - CLossFunction* loss, CStructuredLabels* labs, CKernel* kernel) -: CStructuredOutputMachine(model, loss, labs), m_kernel(NULL) +: CStructuredOutputMachine(model, labs), m_kernel(NULL) { set_kernel(kernel); register_parameters(); diff --git a/src/shogun/machine/KernelStructuredOutputMachine.h b/src/shogun/machine/KernelStructuredOutputMachine.h index 12270c277c4..f4ac1c57dfe 100644 --- a/src/shogun/machine/KernelStructuredOutputMachine.h +++ b/src/shogun/machine/KernelStructuredOutputMachine.h @@ -27,11 +27,10 @@ class CKernelStructuredOutputMachine : public CStructuredOutputMachine /** standard constructor * * @param model structured model with application specific functions - * @param loss loss function * @param labs structured labels * @param kernel kernel */ - CKernelStructuredOutputMachine(CStructuredModel* model, CLossFunction* loss, CStructuredLabels* labs, CKernel* kernel); + CKernelStructuredOutputMachine(CStructuredModel* model, CStructuredLabels* labs, CKernel* kernel); /** destructor */ virtual ~CKernelStructuredOutputMachine(); diff --git a/src/shogun/machine/LinearStructuredOutputMachine.cpp b/src/shogun/machine/LinearStructuredOutputMachine.cpp index e46f975c464..eb2e17cbbc3 100644 --- a/src/shogun/machine/LinearStructuredOutputMachine.cpp +++ b/src/shogun/machine/LinearStructuredOutputMachine.cpp @@ -21,9 +21,8 @@ CLinearStructuredOutputMachine::CLinearStructuredOutputMachine() CLinearStructuredOutputMachine::CLinearStructuredOutputMachine( CStructuredModel* model, - CLossFunction* loss, CStructuredLabels* labs) -: CStructuredOutputMachine(model, loss, labs) +: CStructuredOutputMachine(model, labs) { register_parameters(); } diff --git a/src/shogun/machine/LinearStructuredOutputMachine.h b/src/shogun/machine/LinearStructuredOutputMachine.h index 12348a2384f..60b2b219051 100644 --- a/src/shogun/machine/LinearStructuredOutputMachine.h +++ b/src/shogun/machine/LinearStructuredOutputMachine.h @@ -27,10 +27,9 @@ class CLinearStructuredOutputMachine : public CStructuredOutputMachine /** standard constructor * * @param model structured model with application specific functions - * @param loss loss function * @param labs structured labels */ - CLinearStructuredOutputMachine(CStructuredModel* model, CLossFunction* loss, CStructuredLabels* labs); + CLinearStructuredOutputMachine(CStructuredModel* model, CStructuredLabels* labs); /** destructor */ virtual ~CLinearStructuredOutputMachine(); diff --git a/src/shogun/machine/StructuredOutputMachine.cpp b/src/shogun/machine/StructuredOutputMachine.cpp index a9533ad350b..9f145ba3326 100644 --- a/src/shogun/machine/StructuredOutputMachine.cpp +++ b/src/shogun/machine/StructuredOutputMachine.cpp @@ -13,19 +13,17 @@ using namespace shogun; CStructuredOutputMachine::CStructuredOutputMachine() -: CMachine(), m_model(NULL), m_loss(NULL) +: CMachine(), m_model(NULL) { register_parameters(); } CStructuredOutputMachine::CStructuredOutputMachine( CStructuredModel* model, - CLossFunction* loss, CStructuredLabels* labs) -: CMachine(), m_model(model), m_loss(loss) +: CMachine(), m_model(model) { SG_REF(m_model); - SG_REF(m_loss); set_labels(labs); register_parameters(); } @@ -33,7 +31,6 @@ CStructuredOutputMachine::CStructuredOutputMachine( CStructuredOutputMachine::~CStructuredOutputMachine() { SG_UNREF(m_model); - SG_UNREF(m_loss); } void CStructuredOutputMachine::set_model(CStructuredModel* model) @@ -49,23 +46,9 @@ CStructuredModel* CStructuredOutputMachine::get_model() const return m_model; } -void CStructuredOutputMachine::set_loss(CLossFunction* loss) -{ - SG_UNREF(m_loss); - SG_REF(loss); - m_loss = loss; -} - -CLossFunction* CStructuredOutputMachine::get_loss() const -{ - SG_REF(m_loss); - return m_loss; -} - void CStructuredOutputMachine::register_parameters() { SG_ADD((CSGObject**)&m_model, "m_model", "Structured model", MS_NOT_AVAILABLE); - SG_ADD((CSGObject**)&m_loss, "m_loss", "Structured loss", MS_NOT_AVAILABLE); } void CStructuredOutputMachine::set_labels(CLabels* lab) diff --git a/src/shogun/machine/StructuredOutputMachine.h b/src/shogun/machine/StructuredOutputMachine.h index f4a9a52d1c0..c0a3211638a 100644 --- a/src/shogun/machine/StructuredOutputMachine.h +++ b/src/shogun/machine/StructuredOutputMachine.h @@ -14,13 +14,11 @@ #include #include #include -#include #include namespace shogun { class CStructuredModel; -class CLossFunction; /** TODO doc */ class CStructuredOutputMachine : public CMachine @@ -35,10 +33,9 @@ class CStructuredOutputMachine : public CMachine /** standard constructor * * @param model structured model with application specific functions - * @param loss loss function * @param labs structured labels */ - CStructuredOutputMachine(CStructuredModel* model, CLossFunction* loss, CStructuredLabels* labs); + CStructuredOutputMachine(CStructuredModel* model, CStructuredLabels* labs); /** destructor */ virtual ~CStructuredOutputMachine(); @@ -55,18 +52,6 @@ class CStructuredOutputMachine : public CMachine */ CStructuredModel* get_model() const; - /** set loss function - * - * @param loss loss function to set - */ - void set_loss(CLossFunction* loss); - - /** get loss function - * - * @return loss function - */ - CLossFunction* get_loss() const; - /** @return object name */ virtual const char* get_name() const { @@ -87,9 +72,6 @@ class CStructuredOutputMachine : public CMachine /** the model that contains the application dependent modules */ CStructuredModel* m_model; - /** the general loss function */ - CLossFunction* m_loss; - }; /* class CStructuredOutputMachine */ diff --git a/src/shogun/structure/CCSOSVM.cpp b/src/shogun/structure/CCSOSVM.cpp index fe4317c0bff..f524b2b9cfc 100644 --- a/src/shogun/structure/CCSOSVM.cpp +++ b/src/shogun/structure/CCSOSVM.cpp @@ -24,7 +24,7 @@ CCCSOSVM::CCCSOSVM() } CCCSOSVM::CCCSOSVM(CStructuredModel* model, SGVector w) - : CLinearStructuredOutputMachine(model, NULL, model->get_labels()) + : CLinearStructuredOutputMachine(model, model->get_labels()) { init(); diff --git a/src/shogun/structure/DualLibQPBMSOSVM.cpp b/src/shogun/structure/DualLibQPBMSOSVM.cpp index 3909bce0fda..51d2544fa40 100644 --- a/src/shogun/structure/DualLibQPBMSOSVM.cpp +++ b/src/shogun/structure/DualLibQPBMSOSVM.cpp @@ -23,11 +23,10 @@ CDualLibQPBMSOSVM::CDualLibQPBMSOSVM() CDualLibQPBMSOSVM::CDualLibQPBMSOSVM( CStructuredModel* model, - CLossFunction* loss, CStructuredLabels* labs, float64_t _lambda, SGVector< float64_t > W) - : CLinearStructuredOutputMachine(model, loss, labs) + : CLinearStructuredOutputMachine(model, labs) { set_TolRel(0.001); set_TolAbs(0.0); diff --git a/src/shogun/structure/DualLibQPBMSOSVM.h b/src/shogun/structure/DualLibQPBMSOSVM.h index b9b1babf511..63cb6bcf5e0 100644 --- a/src/shogun/structure/DualLibQPBMSOSVM.h +++ b/src/shogun/structure/DualLibQPBMSOSVM.h @@ -53,14 +53,12 @@ class CDualLibQPBMSOSVM : public CLinearStructuredOutputMachine /** constructor * * @param model Structured Model - * @param loss Loss function * @param labs Structured labels * @param _lambda Regularization constant * @param W initial solution of weight vector */ CDualLibQPBMSOSVM( CStructuredModel* model, - CLossFunction* loss, CStructuredLabels* labs, float64_t _lambda, SGVector< float64_t > W=0); diff --git a/src/shogun/structure/PrimalMosekSOSVM.cpp b/src/shogun/structure/PrimalMosekSOSVM.cpp index 361a3c71a1a..fa68f141a76 100644 --- a/src/shogun/structure/PrimalMosekSOSVM.cpp +++ b/src/shogun/structure/PrimalMosekSOSVM.cpp @@ -19,7 +19,7 @@ using namespace shogun; CPrimalMosekSOSVM::CPrimalMosekSOSVM() : CLinearStructuredOutputMachine(), - po_value(0.0) + m_surrogate_loss(NULL), po_value(0.0) { init(); } @@ -28,9 +28,10 @@ CPrimalMosekSOSVM::CPrimalMosekSOSVM( CStructuredModel* model, CLossFunction* loss, CStructuredLabels* labs) -: CLinearStructuredOutputMachine(model, loss, labs), - po_value(0.0) +: CLinearStructuredOutputMachine(model, labs), + m_surrogate_loss(loss), po_value(0.0) { + SG_REF(m_surrogate_loss); init(); } @@ -39,12 +40,14 @@ void CPrimalMosekSOSVM::init() SG_ADD(&m_slacks, "m_slacks", "Slacks vector", MS_NOT_AVAILABLE); //FIXME model selection available for SO machines SG_ADD(&m_regularization, "m_regularization", "Regularization constant", MS_NOT_AVAILABLE); + SG_ADD((CSGObject**)&m_surrogate_loss, "m_surrogate_loss", "Surrogate loss", MS_NOT_AVAILABLE); m_regularization = 1.0; } CPrimalMosekSOSVM::~CPrimalMosekSOSVM() { + SG_UNREF(m_surrogate_loss); } bool CPrimalMosekSOSVM::train_machine(CFeatures* data) @@ -143,7 +146,7 @@ bool CPrimalMosekSOSVM::train_machine(CFeatures* data) result = m_model->argmax(m_w, i); // Compute the loss associated with the prediction - slack = m_loss->loss( compute_loss_arg(result) ); + slack = m_surrogate_loss->loss( compute_loss_arg(result) ); cur_list = (CList*) results->get_element(i); // Update the list of constraints @@ -157,7 +160,7 @@ bool CPrimalMosekSOSVM::train_machine(CFeatures* data) while ( cur_res != NULL ) { max_slack = CMath::max(max_slack, - m_loss->loss( compute_loss_arg(cur_res) )); + m_surrogate_loss->loss( compute_loss_arg(cur_res) )); SG_UNREF(cur_res); cur_res = (CResultSet*) cur_list->get_next_element(); @@ -277,4 +280,17 @@ void CPrimalMosekSOSVM::set_regularization(float64_t C) m_regularization = C; } +void CPrimalMosekSOSVM::set_surrogate_loss(CLossFunction* loss) +{ + SG_UNREF(m_surrogate_loss); + SG_REF(loss); + m_surrogate_loss = loss; +} + +CLossFunction* CPrimalMosekSOSVM::get_surrogate_loss() const +{ + SG_REF(m_surrogate_loss); + return m_surrogate_loss; +} + #endif /* USE_MOSEK */ diff --git a/src/shogun/structure/PrimalMosekSOSVM.h b/src/shogun/structure/PrimalMosekSOSVM.h index 37c55b3bd58..9c378114d14 100644 --- a/src/shogun/structure/PrimalMosekSOSVM.h +++ b/src/shogun/structure/PrimalMosekSOSVM.h @@ -74,6 +74,18 @@ class CPrimalMosekSOSVM : public CLinearStructuredOutputMachine */ void set_regularization(float64_t C); + /** set loss function + * + * @param loss loss function to set + */ + void set_surrogate_loss(CLossFunction* loss); + + /** get loss function + * + * @return loss function + */ + CLossFunction* get_surrogate_loss() const; + protected: /** train primal SO-SVM * @@ -125,6 +137,9 @@ class CPrimalMosekSOSVM : public CLinearStructuredOutputMachine bool add_constraint(CMosek* mosek, CResultSet* result, index_t con_idx, index_t train_idx) const; private: + /** the surrogate loss */ + CLossFunction* m_surrogate_loss; + /** slack variables associated to each training example */ SGVector< float64_t > m_slacks;