Skip to content

Commit

Permalink
Fix for 3-class training of multiclass libsvm
Browse files Browse the repository at this point in the history
  • Loading branch information
lisitsyn committed Nov 1, 2012
1 parent 276a2ff commit a55b25e
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 15 deletions.
4 changes: 3 additions & 1 deletion src/NEWS
Expand Up @@ -13,11 +13,13 @@
- New concept of artificial data generator classes: Based on streaming
features. First implemented instance is CMeanShiftDataGenerator.
Use above new concepts to get non-streaming data if desired.
- Accelerated projected gradient multiclass logistic regression classifier by Sergey Lisitsyn
- Accelerated projected gradient multiclass logistic regression classifier by Sergey Lisitsyn
* Bugfixes:
- Fix for shallow copy of gaussian kernel by Matt Aasted
- Fixed a bug when using StringFeatures along with kernel machines in
cross-validation which cause an assertion error. Thanks to Eric (yoo)!
- Fix for 3-class case training of MulticlassLibSVM reported by Arya Iranmehr
suggested by Oksana Bayda
* Cleanup and API Changes:
- SGString and SGStringList are now based on SGReferencedData
- "confidences" in context of CLabel and subclasses are now "values"
Expand Down
35 changes: 21 additions & 14 deletions src/shogun/multiclass/MulticlassLibSVM.cpp
Expand Up @@ -103,6 +103,7 @@ bool CMulticlassLibSVM::train_machine(CFeatures* data)
model->nr_class, num_classes);
}
ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef));
SG_PRINT("Num classes = %d\n",num_classes);
create_multiclass_svm(num_classes);

int32_t* offsets=SG_MALLOC(int32_t, num_classes);
Expand Down Expand Up @@ -153,24 +154,30 @@ bool CMulticlassLibSVM::train_machine(CFeatures* data)

int32_t idx=0;

if (sgn>0)
if (num_classes > 3)
{
for (k=0; k<model->label[i]; k++)
idx+=num_classes-k-1;

for (l=model->label[i]+1; l<model->label[j]; l++)
idx++;
if (sgn>0)
{
for (k=0; k<model->label[i]; k++)
idx+=num_classes-k-1;

for (l=model->label[i]+1; l<model->label[j]; l++)
idx++;
}
else
{
for (k=0; k<model->label[j]; k++)
idx+=num_classes-k-1;

for (l=model->label[j]+1; l<model->label[i]; l++)
idx++;
}
}
else
{
for (k=0; k<model->label[j]; k++)
idx+=num_classes-k-1;

for (l=model->label[j]+1; l<model->label[i]; l++)
idx++;
idx = model->label[j]+model->label[i] - 3;
}


//
// if (sgn>0)
// idx=((num_classes-1)*model->label[i]+model->label[j])/2;
// else
Expand All @@ -181,7 +188,7 @@ bool CMulticlassLibSVM::train_machine(CFeatures* data)
s, num_sv, model->l, bias, model->label[i],
model->label[j], idx);

set_svm(idx, svm);
REQUIRE(set_svm(idx, svm),"SVM set failed");
s++;
}
}
Expand Down

0 comments on commit a55b25e

Please sign in to comment.