-
-
Notifications
You must be signed in to change notification settings - Fork 1k
/
ContingencyTableEvaluation.cpp
121 lines (110 loc) · 2.76 KB
/
ContingencyTableEvaluation.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Written (W) 2011 Sergey Lisitsyn
* Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
*/
#include <shogun/evaluation/ContingencyTableEvaluation.h>
#include <shogun/labels/BinaryLabels.h>
using namespace shogun;
/** Evaluate predicted labels against ground-truth labels using the
 * contingency-table measure selected by m_type.
 *
 * Both label objects must be of type LT_BINARY (checked with ASSERT).
 * ensure_valid() is called on the ground truth only; the call on the
 * predictions is intentionally disabled (see note in the body).
 * The contingency table is filled by compute_scores() and the selected
 * measure is then read from it.
 *
 * @param predicted labels produced by a machine
 * @param ground_truth reference labels
 * @return value of the measure selected by m_type; for an unhandled
 *         m_type SG_NOTIMPLEMENTED is triggered and the trailing
 *         placeholder value (42) is only reached if that macro returns
 */
float64_t CContingencyTableEvaluation::evaluate(CLabels* predicted, CLabels* ground_truth)
{
	ASSERT(predicted->get_label_type()==LT_BINARY)
	ASSERT(ground_truth->get_label_type()==LT_BINARY)

	/* commented out: what if a machine only returns +1 in apply() ??
	 * Heiko Strathamn */
	// predicted->ensure_valid();
	ground_truth->ensure_valid();

	// Both label types were asserted to be LT_BINARY above, so these
	// downcasts are checked by intent; named casts make that explicit.
	compute_scores(static_cast<CBinaryLabels*>(predicted),
			static_cast<CBinaryLabels*>(ground_truth));

	switch (m_type)
	{
		case ACCURACY:
			return get_accuracy();
		case ERROR_RATE:
			return get_error_rate();
		case BAL:
			return get_BAL();
		case WRACC:
			return get_WRACC();
		case F1:
			return get_F1();
		case CROSS_CORRELATION:
			return get_cross_correlation();
		case RECALL:
			return get_recall();
		case PRECISION:
			return get_precision();
		case SPECIFICITY:
			return get_specificity();
		case CUSTOM:
			return get_custom_score();
	}

	SG_NOTIMPLEMENTED
	// placeholder value, kept for backward compatibility
	return 42;
}
/** Tell model-selection code whether the measure chosen by m_type is
 * better when larger or when smaller.
 *
 * ERROR_RATE and BAL are reported as ED_MINIMIZE. ACCURACY, WRACC, F1,
 * CROSS_CORRELATION, RECALL, PRECISION and SPECIFICITY are reported as
 * ED_MAXIMIZE. CUSTOM defers to get_custom_direction(). Any other value
 * triggers SG_NOTIMPLEMENTED.
 *
 * @return optimisation direction for the selected measure
 */
EEvaluationDirection CContingencyTableEvaluation::get_evaluation_direction() const
{
	switch (m_type)
	{
		// lower-is-better measures
		// NOTE(review): BAL presumably denotes a balanced *error*; confirm
		// against get_BAL() before changing its direction.
		case ERROR_RATE:
		case BAL:
			return ED_MINIMIZE;

		// higher-is-better measures
		case ACCURACY:
		case WRACC:
		case F1:
		case CROSS_CORRELATION:
		case RECALL:
		case PRECISION:
		case SPECIFICITY:
			return ED_MAXIMIZE;

		// user-supplied measure decides for itself
		case CUSTOM:
			return get_custom_direction();

		default:
			SG_NOTIMPLEMENTED
	}

	// fallback, only reached if SG_NOTIMPLEMENTED hands control back
	return ED_MINIMIZE;
}
/** Fill the contingency table by comparing predictions with ground truth
 * element by element.
 *
 * For every index i, a label equal to 1 counts as positive and any other
 * value counts as negative. The counts are stored in m_TP, m_FP, m_TN and
 * m_FN, and m_N receives the number of labels. m_computed is set to true
 * once the table has been filled.
 *
 * @param predicted predicted labels (asserted to be LT_BINARY)
 * @param ground_truth reference labels (asserted to be LT_BINARY); must have
 *        the same number of labels as predicted, otherwise SG_ERROR is raised
 */
void CContingencyTableEvaluation::compute_scores(CBinaryLabels* predicted, CBinaryLabels* ground_truth)
{
	ASSERT(ground_truth->get_label_type() == LT_BINARY)
	ASSERT(predicted->get_label_type() == LT_BINARY)

	// Query the (virtual) label count once and reuse it for the size check,
	// m_N and the loop bound instead of re-calling it on every iteration.
	const int32_t num_labels = predicted->get_num_labels();

	if (num_labels!=ground_truth->get_num_labels())
	{
		SG_ERROR("%s::compute_scores(): Number of predicted labels (%d) is not "
				"equal to number of ground truth labels (%d)!\n", get_name(),
				num_labels, ground_truth->get_num_labels());
	}

	m_TP = 0.0;
	m_FP = 0.0;
	m_TN = 0.0;
	m_FN = 0.0;
	m_N = num_labels;

	for (int32_t i=0; i<num_labels; i++)
	{
		if (ground_truth->get_label(i)==1)
		{
			// actual positive: correctly found or missed
			if (predicted->get_label(i)==1)
				m_TP += 1.0;
			else
				m_FN += 1.0;
		}
		else
		{
			// actual negative: false alarm or correctly rejected
			if (predicted->get_label(i)==1)
				m_FP += 1.0;
			else
				m_TN += 1.0;
		}
	}

	m_computed = true;
}