Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port classifier meta examples #4361

Merged
merged 1 commit on Jul 8, 2018 (branch names lost in page extraction)
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions cmake/FindMetaExamples.cmake
Expand Up @@ -35,6 +35,8 @@ function(get_excluded_meta_examples)
statistical_testing/quadratic_time_maximum_mean_discrepancy.sg
gaussian_process/classifier.sg
gaussian_process/regression.sg
binary/svmlin.sg
binary/svmsgd.sg
)
ENDIF()

Expand Down
2 changes: 1 addition & 1 deletion data
5 changes: 2 additions & 3 deletions examples/meta/src/binary/mpdsvm.sg
Expand Up @@ -5,16 +5,15 @@ File f_labels_test = csv_file("../../data/classifier_binary_2d_linear_labels_tes

Features feats_train = features(f_feats_train)
Features feats_test = features(f_feats_test)
BinaryLabels labels_train(f_labels_train)
Labels labels_train = labels(f_labels_train)
Labels labels_test = labels(f_labels_test)

Kernel gaussian = kernel("GaussianKernel", log_width=0.01)
gaussian.init(feats_train, feats_train)

real C=1.0
Machine svm = machine("MPDSVM", C1=C, C2=C, kernel=gaussian, epsilon=0.00001)
Machine svm = machine("MPDSVM", C1=C, C2=C, kernel=gaussian, epsilon=0.00001, labels=labels_train)

svm.set_labels(labels_train)
svm.train()

Labels labels_predict = svm.apply(feats_test)
Expand Down
29 changes: 0 additions & 29 deletions examples/undocumented/python/classifier_svmlin.py

This file was deleted.

28 changes: 0 additions & 28 deletions examples/undocumented/python/classifier_svmsgd.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/gpl
17 changes: 9 additions & 8 deletions src/shogun/classifier/svm/MPDSVM.cpp
Expand Up @@ -30,13 +30,13 @@ CMPDSVM::~CMPDSVM()

bool CMPDSVM::train_machine(CFeatures* data)
{
ASSERT(m_labels)
ASSERT(m_labels->get_label_type() == LT_BINARY)
auto labels = binary_labels(m_labels);

ASSERT(kernel)

if (data)
{
if (m_labels->get_num_labels() != data->get_num_vectors())
if (labels->get_num_labels() != data->get_num_vectors())
SG_ERROR("Number of training vectors does not match number of labels\n")
kernel->init(data, data);
}
Expand All @@ -48,7 +48,7 @@ bool CMPDSVM::train_machine(CFeatures* data)
const int64_t maxiter = 1L<<30;
//const bool nustop=false;
//const int32_t k=2;
const int32_t n=m_labels->get_num_labels();
const int32_t n = labels->get_num_labels();
ASSERT(n>0)
//const float64_t d = 1.0/n/nu; //NUSVC
const float64_t d = get_C1(); //CSVC
Expand Down Expand Up @@ -84,9 +84,9 @@ bool CMPDSVM::train_machine(CFeatures* data)
for (int32_t i=0; i<n; i++)
{
alphas[i]=0;
F[i]=((CBinaryLabels*) m_labels)->get_label(i);
F[i] = labels->get_label(i);
//F[i+n]=-1;
hessres[i]=((CBinaryLabels*) m_labels)->get_label(i);
hessres[i] = labels->get_label(i);
//hessres[i+n]=-1;
//dalphas[i]=F[i+n]*etas[1]; //NUSVC
dalphas[i]=-1; //CSVC
Expand Down Expand Up @@ -149,7 +149,8 @@ bool CMPDSVM::train_machine(CFeatures* data)
{
obj-=alphas[i];
for (int32_t j=0; j<n; j++)
obj+=0.5*((CBinaryLabels*) m_labels)->get_label(i)*((CBinaryLabels*) m_labels)->get_label(j)*alphas[i]*alphas[j]*kernel->kernel(i,j);
obj += 0.5 * labels->get_label(i) * labels->get_label(j) *
alphas[i] * alphas[j] * kernel->kernel(i, j);
}

SG_DEBUG("obj:%f pviol:%f dviol:%f maxpidx:%d iter:%d\n", obj, maxpviol, maxdviol, maxpidx, niter)
Expand Down Expand Up @@ -250,7 +251,7 @@ bool CMPDSVM::train_machine(CFeatures* data)
if (alphas[i]>0)
{
//set_alpha(j, alphas[i]*labels->get_label(i)/etas[1]);
set_alpha(j, alphas[i]*((CBinaryLabels*) m_labels)->get_label(i));
set_alpha(j, alphas[i] * labels->get_label(i));
set_support_vector(j, i);
j++;
}
Expand Down