forked from sjwhitworth/golearn
/
attributes.go
267 lines (238 loc) · 8.12 KB
/
attributes.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
package base
import "fmt"
import "strconv"
// Type identifiers returned by Attribute.GetType.
const (
// CategoricalType is for Attributes which represent values distinctly.
// It is the first constant of this iota block, so its value is 0.
CategoricalType = iota
// Float64Type should be replaced with a FractionalNumeric type [DEPRECATED].
// It is the second constant of this iota block, so its value is 1.
Float64Type
)
// Attribute disambiguates a column of the feature matrix and declares its type.
type Attribute interface {
// GetType returns the general characteristics of this Attribute
// (e.g. CategoricalType or Float64Type), to avoid the overhead of casting.
GetType() int
// GetName returns the human-readable name of this Attribute.
GetName() string
// SetName sets the human-readable name of this Attribute.
SetName(string)
// String gets a human-readable overview of this Attribute for debugging.
String() string
// GetSysValFromString converts a value given as a human-readable string
// into a system representation. For example, a CategoricalAttribute with
// values ["iris-setosa", "iris-virginica"] would return the float64
// representation of 0 when given "iris-setosa".
GetSysValFromString(string) float64
// GetStringFromSysVal converts a given value from a system representation
// into a human representation. For example, a CategoricalAttribute with
// values ["iris-setosa", "iris-virginica"] might return "iris-setosa"
// when given 0.0 as the argument.
GetStringFromSysVal(float64) string
// Equals tests for equality with another Attribute. Other Attributes are
// considered equal if:
// * They have the same type (i.e. FloatAttribute <> CategoricalAttribute)
// * They have the same name
// * If applicable, they have the same categorical values (though not
// necessarily in the same order).
//
// NOTE(review): CategoricalAttribute.Equals in this file compares the
// values position by position, i.e. it DOES require the same order —
// confirm which of the two is the intended contract.
Equals(Attribute) bool
}
// FloatAttribute is an Attribute implementation which stores floating point
// representations of numbers.
type FloatAttribute struct {
// Name is the human-readable name of this Attribute.
Name string
// Precision is the number of decimal places GetStringFromSysVal uses
// when formatting a value as a string.
Precision int
}
// NewFloatAttribute creates a FloatAttribute with an empty name and a
// default precision of 2 decimal places.
func NewFloatAttribute() *FloatAttribute {
	return &FloatAttribute{
		Name:      "",
		Precision: 2,
	}
}
// Equals reports whether other is a FloatAttribute carrying the same name
// as this one.
//
// It returns false when other is not a *FloatAttribute or when the two
// names differ. Precision is not part of the comparison.
func (Attr *FloatAttribute) Equals(other Attribute) bool {
	// A different concrete type can never be equal.
	if _, sameType := other.(*FloatAttribute); !sameType {
		return false
	}
	return Attr.GetName() == other.GetName()
}
// GetName returns this FloatAttribute's human-readable name,
// i.e. the current value of its Name field.
func (Attr *FloatAttribute) GetName() string {
return Attr.Name
}
// SetName sets this FloatAttribute's human-readable name by
// overwriting its Name field with name.
func (Attr *FloatAttribute) SetName(name string) {
Attr.Name = name
}
// GetType returns Float64Type, the type identifier for
// floating point Attributes.
func (Attr *FloatAttribute) GetType() int {
return Float64Type
}
// String produces a short human-readable description of this Attribute
// built from its name, e.g. "FloatAttribute(Sepal Width)".
func (Attr *FloatAttribute) String() string {
	const summaryFormat = "FloatAttribute(%s)"
	return fmt.Sprintf(summaryFormat, Attr.Name)
}
// CheckSysValFromString attempts to parse rawVal as a 64-bit float.
//
// On success it returns the parsed value and a nil error. On failure it
// returns 0.0 together with the error reported by strconv.ParseFloat.
func (Attr *FloatAttribute) CheckSysValFromString(rawVal string) (float64, error) {
	// sysVal stays at zero unless parsing succeeds: ParseFloat may hand back
	// a non-zero value (e.g. ±Inf on a range error) alongside its error, and
	// that value must not leak to the caller.
	var sysVal float64
	parsed, err := strconv.ParseFloat(rawVal, 64)
	if err == nil {
		sysVal = parsed
	}
	return sysVal, err
}
// GetSysValFromString converts rawVal into a float64 and returns it.
//
// The parsed float64 is itself the system representation, so no further
// mapping is applied.
// IMPORTANT: This function panic()s with the strconv.ParseFloat error if
// rawVal is not a valid float. Use CheckSysValFromString to validate the
// input first.
func (Attr *FloatAttribute) GetSysValFromString(rawVal string) float64 {
	sysVal, err := strconv.ParseFloat(rawVal, 64)
	if err == nil {
		return sysVal
	}
	panic(err)
}
// GetStringFromSysVal converts a given system value to a string in
// fixed-point notation with Attr.Precision digits after the decimal point
// (NewFloatAttribute sets this to 2).
//
// A negative Precision selects the shortest representation that parses back
// to the same float64 (strconv.FormatFloat semantics) instead of producing a
// malformed format directive.
func (Attr *FloatAttribute) GetStringFromSysVal(rawVal float64) string {
	// strconv.FormatFloat with 'f' yields the same text as fmt's "%.Nf" for
	// N >= 0, without building a format string at runtime.
	return strconv.FormatFloat(rawVal, 'f', Attr.Precision, 64)
}
// GetSysVal returns the system representation of userVal.
//
// Because FloatAttribute represents float64 types, the user and system
// representations coincide and this just returns its argument unchanged.
func (Attr *FloatAttribute) GetSysVal(userVal float64) float64 {
return userVal
}
// GetUsrVal returns the user representation of sysVal.
//
// Because FloatAttribute represents float64 types, the system and user
// representations coincide and this just returns its argument unchanged.
func (Attr *FloatAttribute) GetUsrVal(sysVal float64) float64 {
return sysVal
}
// CategoricalAttribute is an Attribute implementation
// which stores discrete string values
// - useful for representing classes.
type CategoricalAttribute struct {
// Name is the human-readable name of this Attribute.
Name string
// values holds the known category strings. A value's index in this
// slice is its system representation (see GetSysVal and
// GetSysValFromString).
values []string
}
// NewCategoricalAttribute creates a CategoricalAttribute with an empty name
// and an empty, non-nil list of values.
func NewCategoricalAttribute() *CategoricalAttribute {
	return &CategoricalAttribute{
		Name:   "",
		values: []string{},
	}
}
// GetName returns the human-readable name assigned to this attribute,
// i.e. the current value of its Name field.
func (Attr *CategoricalAttribute) GetName() string {
return Attr.Name
}
// SetName sets the human-readable name on this attribute by
// overwriting its Name field with name.
func (Attr *CategoricalAttribute) SetName(name string) {
Attr.Name = name
}
// GetType returns CategoricalType, the type identifier for categorical
// Attributes, to avoid casting overhead.
func (Attr *CategoricalAttribute) GetType() int {
return CategoricalType
}
// GetSysVal looks userVal up in the values slice and returns the index of
// its first occurrence as a float64.
//
// It returns -1 when userVal is not present. Unlike GetSysValFromString,
// it never modifies the attribute.
func (Attr *CategoricalAttribute) GetSysVal(userVal string) float64 {
	for position := range Attr.values {
		if Attr.values[position] != userVal {
			continue
		}
		return float64(position)
	}
	return -1
}
// GetUsrVal maps sysVal back to the category string stored at that index.
// sysVal is converted to an int (dropping any fractional part) before use.
//
// IMPORTANT: no bounds check is performed here, so an index outside the
// values slice causes a runtime panic.
func (Attr *CategoricalAttribute) GetUsrVal(sysVal float64) string {
	return Attr.values[int(sysVal)]
}
// GetSysValFromString returns the system representation of rawVal, which is
// its index in the values slice.
//
// IMPORTANT: If rawVal is not yet in the values slice, this function appends
// it and returns the index of the new entry. If you only need to find out
// whether rawVal exists, use GetSysVal and check for a -1 return value.
//
// Example: if the CategoricalAttribute contains the values ["iris-setosa",
// "iris-virginica"] and "iris-versicolor" is provided as the argument,
// the values slice becomes ["iris-setosa", "iris-virginica", "iris-versicolor"]
// and 2.00 is returned as the system representation.
func (Attr *CategoricalAttribute) GetSysValFromString(rawVal string) float64 {
	// Already known: return the index of the first match.
	for known := range Attr.values {
		if Attr.values[known] == rawVal {
			return float64(known)
		}
	}
	// Unknown: the new entry lands at the current length of the slice.
	newIndex := len(Attr.values)
	Attr.values = append(Attr.values, rawVal)
	return float64(newIndex)
}
// String produces a human-readable description of this Attribute.
//
// The result contains the attribute's name in double quotes followed by the
// list of human-readable values this CategoricalAttribute can take.
func (Attr *CategoricalAttribute) String() string {
	const summaryFormat = `CategoricalAttribute("%s", %s)`
	return fmt.Sprintf(summaryFormat, Attr.Name, Attr.values)
}
// GetStringFromSysVal returns a human-readable value from the given
// system-representation value val. val is converted to an int (dropping any
// fractional part) and used as an index into the values slice.
//
// IMPORTANT: This function calls panic() with an "Out of range" message if
// the resulting index is negative or not smaller than the number of values.
// A negative index arises, for example, when the -1 "not found" result of
// GetSysVal is passed in.
// TODO: Return a user-configurable default instead.
func (Attr *CategoricalAttribute) GetStringFromSysVal(val float64) string {
	convVal := int(val)
	// Check both bounds so every invalid index gets the same descriptive
	// panic rather than a bare runtime index error for negative values.
	if convVal < 0 || convVal >= len(Attr.values) {
		panic(fmt.Sprintf("Out of range: %d in %d", convVal, len(Attr.values)))
	}
	return Attr.values[convVal]
}
// Equals reports whether other is a CategoricalAttribute equivalent to
// this one.
//
// Two CategoricalAttributes are considered equal when they have the same
// name and hold the same values in the same order. Any other Attribute
// type is never equal.
func (Attr *CategoricalAttribute) Equals(other Attribute) bool {
	peer, isCategorical := other.(*CategoricalAttribute)
	// Cases are evaluated top to bottom, so peer is only touched once the
	// type assertion is known to have succeeded.
	switch {
	case !isCategorical:
		return false
	case Attr.GetName() != peer.GetName():
		return false
	case len(Attr.values) != len(peer.values):
		return false
	}
	// Same length: compare the values position by position.
	for i := range Attr.values {
		if Attr.values[i] != peer.values[i] {
			return false
		}
	}
	return true
}