-
-
Notifications
You must be signed in to change notification settings - Fork 1k
/
NearestCentroid.cpp
118 lines (96 loc) · 2.95 KB
/
NearestCentroid.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2012 Philippe Tillet
*/
#include <shogun/classifier/NearestCentroid.h>
#include <shogun/labels/MulticlassLabels.h>
#include <shogun/features/Features.h>
#include <shogun/features/FeatureTypes.h>
namespace shogun{
/* Default constructor: builds an untrained machine with no distance or
 * labels attached; all member setup is delegated to init().
 */
CNearestCentroid::CNearestCentroid()
    : CDistanceMachine()
{
    init();
}
/* Convenience constructor: builds an untrained machine and attaches the
 * given distance and training labels.
 *
 * d        distance to use; must not be NULL
 * trainlab training labels; must not be NULL
 */
CNearestCentroid::CNearestCentroid(CDistance* d, CLabels* trainlab)
    : CDistanceMachine()
{
    // Members are initialised first, exactly as in the default constructor.
    init();

    // Both arguments are mandatory; check them in order before storing.
    ASSERT(d)
    ASSERT(trainlab)

    set_distance(d);
    set_labels(trainlab);
}
/* Destructor.
 *
 * An untrained machine still holds the centroid container created by
 * init() and deletes it directly. A trained machine has handed the
 * centroids to the distance as its left-hand side (see train_machine),
 * so it asks the distance to drop that left-hand side instead.
 */
CNearestCentroid::~CNearestCentroid()
{
    if (!m_is_trained)
    {
        delete m_centroids;
        return;
    }

    distance->remove_lhs();
}
/* Puts the members into their initial, untrained state: an empty centroid
 * feature container, the trained flag cleared and shrinking set to zero.
 */
void CNearestCentroid::init()
{
    m_centroids = new CDenseFeatures<float64_t>();
    m_is_trained = false;
    m_shrinking = 0;
}
/* Trains the classifier by computing one centroid per class.
 *
 * The feature vectors of every class are summed up column-wise into a
 * (num_features x num_classes) matrix, each column is rescaled by the
 * number of vectors of that class, and the result is installed as the
 * left-hand side of the distance.
 *
 * data  training features (dense); if NULL, the features currently set
 *       as left-hand side of the distance are used instead
 *
 * Returns true on success. Raises an error (SG_ERROR / ASSERT) if labels
 * or distance are missing, the labels are not multiclass, the features are
 * not dense, the number of labels does not match the number of vectors, or
 * a label lies outside [0, num_classes).
 */
bool CNearestCentroid::train_machine(CFeatures* data)
{
    ASSERT(m_labels)
    ASSERT(m_labels->get_label_type() == LT_MULTICLASS)
    ASSERT(distance)

    if (data)
    {
        // Only dereference data once it is known to be non-NULL.
        ASSERT(data->get_feature_class() == C_DENSE)
        if (m_labels->get_num_labels() != data->get_num_vectors())
            SG_ERROR("Number of training vectors does not match number of labels\n")
        distance->init(data, data);
    }
    else
    {
        // No features passed in: fall back to those held by the distance.
        data = distance->get_lhs();
        ASSERT(data)
        ASSERT(data->get_feature_class() == C_DENSE)
    }

    index_t num_vectors = data->get_num_vectors();
    index_t num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes();
    index_t num_feats = ((CDenseFeatures<float64_t>*) data)->get_num_features();

    // One column per class, accumulated as a running sum first.
    SGMatrix<float64_t> centroids(num_feats,num_classes);
    centroids.zero();

    m_centroids->set_num_features(num_feats);
    m_centroids->set_num_vectors(num_classes);

    // Number of training vectors seen for each class.
    int64_t* num_per_class = new int64_t[num_classes];
    for (index_t i=0 ; i<num_classes ; i++)
    {
        num_per_class[i]=0;
    }

    for (index_t idx=0 ; idx<num_vectors ; idx++)
    {
        index_t current_len;
        bool current_free;
        index_t current_class = ((CMulticlassLabels*) m_labels)->get_label(idx);

        // The label is used as a column index below; reject anything that
        // would write outside the centroid matrix or the counter array.
        if (current_class<0 || current_class>=num_classes)
        {
            delete[] num_per_class;
            SG_ERROR("Label of training vector %d is outside the valid range [0, %d)\n", idx, num_classes)
        }

        float64_t* target = centroids.matrix + num_feats*current_class;
        float64_t* current = ((CDenseFeatures<float64_t>*)data)->get_feature_vector(idx,current_len,current_free);
        // target = 1.0*target + 1.0*current
        SGVector<float64_t>::add(target,1.0,target,1.0,current,current_len);
        num_per_class[current_class]++;
        ((CDenseFeatures<float64_t>*)data)->free_feature_vector(current, current_len, current_free);
    }

    for (index_t i=0 ; i<num_classes ; i++)
    {
        float64_t* target = centroids.matrix + num_feats*i;
        index_t total = num_per_class[i];
        float64_t scale = 0;
        // NOTE(review): for total>1 the sum is divided by (total-1), not by
        // total, so the stored vector is not the arithmetic mean of the
        // class. Kept unchanged here to preserve existing numerical results;
        // confirm whether this is intended.
        // NOTE(review): a class without any training vector (total==0)
        // yields a division by zero in the else branch.
        if(total>1)
            scale = 1.0/((float64_t)(total-1));
        else
            scale = 1.0/(float64_t)total;
        SGVector<float64_t>::scale_vector(scale,target,num_feats);
    }

    // Allocated with new[], so it must be released with delete[].
    delete[] num_per_class;

    m_centroids->free_feature_matrix();
    m_centroids->set_feature_matrix(centroids);

    m_is_trained=true;
    // From now on the distance compares against the centroids.
    distance->init(m_centroids,distance->get_rhs());

    return true;
}
}