forked from sjwhitworth/golearn
/
linearsvc.go
217 lines (191 loc) · 5.64 KB
/
linearsvc.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
package linear_models
import "C"
import (
"fmt"
"github.com/sjwhitworth/golearn/base"
"unsafe"
)
// LinearSVCParams represents all available LinearSVC options.
//
// SolverType selects the liblinear solver and can be one of
// linear_models.L2R_L1LOSS_SVC_DUAL, L2R_L2LOSS_SVC_DUAL, L2R_L2LOSS_SVC
// or L1R_L2LOSS_SVC. It is normally configured through SetKindFromStrings,
// which also keeps Dual consistent with the chosen solver.
//
// ClassWeights describes how each class is weighted, and can
// be used in class-imbalanced scenarios. If this is nil, then
// all classes will be weighted the same unless WeightClassesAutomatically
// is true.
//
// C is a float64 representing the misclassification penalty.
//
// Eps is a float64 convergence threshold.
//
// Dual indicates whether the dual (true) or primal (false) formulation
// is solved.
//
// The field order is relied upon by positional struct literals; do not
// reorder the fields.
type LinearSVCParams struct {
	SolverType                 int
	ClassWeights               []float64
	C                          float64
	Eps                        float64
	WeightClassesAutomatically bool
	Dual                       bool
}
// Copy returns an independent copy of these parameters.
//
// All scalar options are duplicated. A non-nil ClassWeights slice is cloned
// into a fresh backing array of the same length, so the copy and the
// original never share weight storage. A nil ClassWeights stays nil.
func (p *LinearSVCParams) Copy() *LinearSVCParams {
	var weights []float64
	if p.ClassWeights != nil {
		weights = make([]float64, len(p.ClassWeights))
		copy(weights, p.ClassWeights)
	}
	return &LinearSVCParams{
		SolverType:                 p.SolverType,
		ClassWeights:               weights,
		C:                          p.C,
		Eps:                        p.Eps,
		WeightClassesAutomatically: p.WeightClassesAutomatically,
		Dual:                       p.Dual,
	}
}
// SetKindFromStrings configures SolverType (and, where the solver requires
// it, Dual) from a loss and a penalty name. Each must be "l1" or "l2".
//
// Supported combinations:
//
//	penalty l2, loss l1: L2R_L1LOSS_SVC_DUAL (forces Dual = true)
//	penalty l2, loss l2: L2R_L2LOSS_SVC_DUAL if Dual, else L2R_L2LOSS_SVC
//	penalty l1, loss l2: L1R_L2LOSS_SVC (forces Dual = false)
//
// An l1 penalty combined with an l1 loss is rejected. SolverType is reset to
// 0 before validation, so it is 0 after any rejection.
//
// When Dual has to be flipped, the solver is still fully configured, but a
// non-nil "Important: ..." error is returned describing the change. Callers
// such as NewLinearSVC treat any non-nil result as failure.
func (p *LinearSVCParams) SetKindFromStrings(loss, penalty string) error {
	p.SolverType = 0

	switch loss {
	case "l1", "l2":
		// valid; the combination is checked below
	default:
		return fmt.Errorf("loss must be \"l1\" or \"l2\"")
	}

	var notice error
	switch {
	case penalty == "l2" && loss == "l1":
		// L1 loss with an L2 penalty is only available in dual form.
		if !p.Dual {
			notice = fmt.Errorf("Important: changed to dual form")
		}
		p.SolverType = L2R_L1LOSS_SVC_DUAL
		p.Dual = true
	case penalty == "l2" && p.Dual:
		p.SolverType = L2R_L2LOSS_SVC_DUAL
	case penalty == "l2":
		p.SolverType = L2R_L2LOSS_SVC
	case penalty == "l1" && loss == "l2":
		// L1 penalty is only available in primal form.
		if p.Dual {
			notice = fmt.Errorf("Important: changed to primary form")
		}
		p.Dual = false
		p.SolverType = L1R_L2LOSS_SVC
	case penalty == "l1":
		return fmt.Errorf("Must have L2 loss with L1 penalty")
	default:
		return fmt.Errorf("Penalty must be \"l1\" or \"l2\"")
	}

	// Defensive: every accepted branch above assigns a solver constant.
	if p.SolverType == 0 {
		return fmt.Errorf("Invalid parameter combination")
	}
	return notice
}
// convertToNativeFormat builds the liblinear parameter object for these
// options.
//
// Only SolverType, C and Eps are forwarded here. Class weights are applied
// separately by LinearSVC.Fit.
func (p *LinearSVCParams) convertToNativeFormat() *Parameter {
	solver, penalty, tolerance := p.SolverType, p.C, p.Eps
	return NewParameter(solver, penalty, tolerance)
}
// LinearSVC represents a linear support-vector classifier backed by
// liblinear.
type LinearSVC struct {
// param holds the liblinear-native parameters built once from Param by
// NewLinearSVCFromParams. Later changes to Param's SolverType, C or Eps are
// therefore not reflected here. Fit writes the class-weight vectors into it
// before training.
param *Parameter
// model is the trained liblinear model. It is nil until Fit has run.
model *Model
// Param is the user-facing configuration. NewLinearSVCFromParams stores the
// caller's pointer here without copying it.
Param *LinearSVCParams
}
// NewLinearSVC creates a new, untrained linear support-vector classifier.
//
// loss and penalty: see LinearSVCParams.SetKindFromStrings. Any error it
// returns, including its "Important: changed to ..." notices, aborts
// construction.
//
// dual, C and eps: see LinearSVCParams (Dual, C and Eps).
//
// Class weights are left unset (nil) and automatic weighting is disabled.
func NewLinearSVC(loss, penalty string, dual bool, C float64, eps float64) (*LinearSVC, error) {
	params := &LinearSVCParams{
		C:    C,
		Eps:  eps,
		Dual: dual,
	}
	if err := params.SetKindFromStrings(loss, penalty); err != nil {
		return nil, err
	}
	return NewLinearSVCFromParams(params)
}
// NewLinearSVCFromParams constructs an untrained LinearSVC from params.
//
// The solver settings are converted to liblinear's native parameter format
// immediately. params itself is stored by reference, not copied, in the
// Param field. The returned error is currently always nil.
func NewLinearSVCFromParams(params *LinearSVCParams) (*LinearSVC, error) {
	native := params.convertToNativeFormat()
	return &LinearSVC{param: native, Param: params, model: nil}, nil
}
// Fit trains the classifier on X.
//
// The per-class weight vector is chosen as follows:
//   - Param.ClassWeights, when it is non-nil;
//   - generateClassWeightVectorFromDist(X), when Param.WeightClassesAutomatically
//     is true;
//   - generateClassWeightVectorFromFixed(X), otherwise.
//
// Weight i is applied to class label i. An empty weight vector sets
// nr_weight to 0, which liblinear treats as "no per-class reweighting".
// Previously an empty vector panicked with an index out of range.
//
// X is converted to liblinear's problem format and the resulting model
// replaces any previous one. Fit currently always returns nil.
func (lr *LinearSVC) Fit(X base.FixedDataGrid) error {
	var weightVec []float64
	switch {
	case lr.Param.ClassWeights != nil:
		weightVec = lr.Param.ClassWeights
	case lr.Param.WeightClassesAutomatically:
		weightVec = generateClassWeightVectorFromDist(X)
	default:
		weightVec = generateClassWeightVectorFromFixed(X)
	}

	// Labels 0..len(weightVec)-1, one per weight entry.
	weightClasses := make([]C.int, len(weightVec))
	for i := range weightVec {
		weightClasses[i] = C.int(i)
	}

	// Convert the problem.
	problemVec := convertInstancesToProblemVec(X)
	labelVec := convertInstancesToLabelVec(X)
	prob := NewProblem(problemVec, labelVec, 0)

	// NOTE(review): the weight slices are Go memory whose addresses are
	// stored in lr.param.c_param and then handed to C by Train. Depending on
	// how c_param is allocated, this may conflict with the cgo
	// pointer-passing rules — confirm.
	lr.param.c_param.nr_weight = C.int(len(weightVec))
	if len(weightVec) == 0 {
		// Nothing to reweight. Taking &slice[0] here would panic.
		lr.param.c_param.weight_label = nil
		lr.param.c_param.weight = nil
	} else {
		lr.param.c_param.weight_label = &weightClasses[0]
		lr.param.c_param.weight = (*C.double)(unsafe.Pointer(&weightVec[0]))
	}

	lr.model = Train(prob, lr.param)
	return nil
}
// Predict issues predictions from a trained LinearSVC.
//
// X must have exactly one class Attribute. The features for each row are
// X's non-class float Attributes, and each row's prediction is written to
// the class Attribute of the returned grid.
//
// An error is returned in three cases:
//   - X does not have exactly one class Attribute (previously a panic);
//   - the classifier has not been trained with Fit (previously a nil
//     dereference);
//   - iterating over X's rows fails (previously silently ignored).
func (lr *LinearSVC) Predict(X base.FixedDataGrid) (base.FixedDataGrid, error) {
	// Only support 1 class Attribute.
	classAttrs := X.AllClassAttributes()
	if len(classAttrs) != 1 {
		return nil, fmt.Errorf("linearsvc: expected exactly 1 class attribute, got %d", len(classAttrs))
	}
	if lr.model == nil {
		return nil, fmt.Errorf("linearsvc: Predict called before Fit")
	}

	// Generate return structure.
	ret := base.GeneratePredictionVector(X)
	classAttrSpecs := base.ResolveAttributes(ret, classAttrs)

	// Retrieve numeric non-class Attributes.
	numericAttrs := base.NonClassFloatAttributes(X)
	numericAttrSpecs := base.ResolveAttributes(X, numericAttrs)

	// One row buffer is reused for every row. It is fully overwritten each
	// iteration.
	row := make([]float64, len(numericAttrSpecs))
	err := X.MapOverRows(numericAttrSpecs, func(rowBytes [][]byte, rowNo int) (bool, error) {
		for i, r := range rowBytes {
			row[i] = base.UnpackBytesToFloat(r)
		}
		val := Predict(lr.model, row)
		ret.Set(classAttrSpecs[0], rowNo, base.PackFloatToBytes(val))
		return true, nil
	})
	if err != nil {
		return nil, fmt.Errorf("linearsvc: predicting rows: %w", err)
	}
	return ret, nil
}
// String returns a human-readable name for the classifier.
//
// It previously returned "LogisticSVC", a copy-paste slip that did not match
// the type.
func (lr *LinearSVC) String() string {
	return "LinearSVC"
}