/
float.go
141 lines (132 loc) · 4.56 KB
/
float.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
package filters
import (
"fmt"
"github.com/sjwhitworth/golearn/base"
)
// FloatConvertFilter converts a given DataGrid into one which
// only contains FloatAttributes.
//
// FloatAttributes are passed through unchanged.
//
// BinaryAttributes are converted into identically-named FloatAttributes
// holding 0.0 (if the stored byte is 0) or 1.0 (otherwise).
//
// CategoricalAttributes with at most two values are converted into a
// single identically-named FloatAttribute holding 0.0 or 1.0; those with
// more values are expanded into one indicator FloatAttribute per value.
//
// Create instances with NewFloatConvertFilter, which initialises the
// internal slices and maps.
type FloatConvertFilter struct {
// attrs are the input Attributes registered via AddAttribute.
attrs []base.Attribute
// converted holds the (input Attribute, output Attribute) pairs built by Train.
converted []base.FilteredAttribute
// twoValuedCategoricalAttributes marks CategoricalAttributes with at most
// two values, each of which maps onto a single 0/1 FloatAttribute.
twoValuedCategoricalAttributes map[base.Attribute]bool
// nValuedCategoricalAttributeMap maps each CategoricalAttribute with more
// than two values to its per-value indicator FloatAttributes, keyed by the
// value's index in GetValues().
nValuedCategoricalAttributeMap map[base.Attribute]map[uint64]base.Attribute
}
// NewFloatConvertFilter returns a FloatConvertFilter that has no Attributes
// yet and whose internal slices and maps are already allocated, ready for
// AddAttribute and Train.
func NewFloatConvertFilter() *FloatConvertFilter {
	return &FloatConvertFilter{
		attrs:                          []base.Attribute{},
		converted:                      []base.FilteredAttribute{},
		twoValuedCategoricalAttributes: map[base.Attribute]bool{},
		nValuedCategoricalAttributeMap: map[base.Attribute]map[uint64]base.Attribute{},
	}
}
// AddAttribute registers a as an input Attribute to be converted the next
// time Train is called. It always returns nil.
func (f *FloatConvertFilter) AddAttribute(a base.Attribute) error {
	f.attrs = append(f.attrs, a)
	return nil
}
// GetAttributesAfterFiltering returns the (input Attribute, output Attribute)
// pairs built by Train. The filter's internal slice is returned directly,
// not a copy, so callers should treat it as read-only.
func (f *FloatConvertFilter) GetAttributesAfterFiltering() []base.FilteredAttribute {
	result := f.converted
	return result
}
// String returns a short human-readable summary of the filter that reports
// how many input Attributes have been added to it.
func (f *FloatConvertFilter) String() string {
	numAttrs := len(f.attrs)
	return fmt.Sprintf("FloatConvertFilter(%d Attribute(s))", numAttrs)
}
// Transform converts the bytes of one stored value of the original Attribute
// a into the bytes of the corresponding value of the output Attribute n.
// The result is a packed float (0.0 or 1.0), except for FloatAttributes,
// whose bytes are passed through:
//
//   - CategoricalAttribute with at most two values (as recorded by Train):
//     the packed uint64 category index is mapped to 0.0 if it is zero and
//     to 1.0 otherwise; n is not consulted.
//   - CategoricalAttribute with more than two values: n should be one of the
//     per-value FloatAttributes created by Train for a; the result is 1.0 if
//     the category index maps onto n and 0.0 otherwise.
//   - FloatAttribute: attrBytes itself is returned (not a copy).
//   - BinaryAttribute: 1.0 if the first byte of attrBytes is non-zero, 0.0
//     otherwise.
//
// Transform has no error return, so it panics if a is of an unsupported
// type, if a categorical Attribute was not registered by Train, or if the
// category index has no corresponding output Attribute.
func (f *FloatConvertFilter) Transform(a base.Attribute, n base.Attribute, attrBytes []byte) []byte {
	// indicator packs a boolean as the float 1.0 (true) or 0.0 (false).
	indicator := func(set bool) []byte {
		if set {
			return base.PackFloatToBytes(1.0)
		}
		return base.PackFloatToBytes(0.0)
	}

	switch a.(type) {
	case *base.CategoricalAttribute:
		val := base.UnpackBytesToU64(attrBytes)
		if f.twoValuedCategoricalAttributes[a] {
			return indicator(val > 0)
		}
		perValue, ok := f.nValuedCategoricalAttributeMap[a]
		if !ok {
			panic(fmt.Sprintf("Not a recognised Attribute %v", a))
		}
		target, ok := perValue[val]
		if !ok {
			panic("Categorical value not defined!")
		}
		return indicator(target.Equals(n))
	case *base.FloatAttribute:
		// Already a float: pass the stored bytes straight through.
		return attrBytes
	case *base.BinaryAttribute:
		// Only the first byte is inspected; any non-zero byte counts as set.
		return indicator(attrBytes[0] > 0)
	default:
		panic(fmt.Sprintf("Unrecognised Attribute: %v", a))
	}
}
// Train computes the output Attributes for every Attribute registered via
// AddAttribute; afterwards they are available from
// GetAttributesAfterFiltering. Every output is a FloatAttribute:
//
//   - A FloatAttribute is kept as is: it is paired with itself.
//   - A BinaryAttribute is replaced by an identically-named FloatAttribute.
//   - A CategoricalAttribute with at most two values is replaced by a single
//     identically-named FloatAttribute; Transform maps the value with index
//     0 to 0.0 and any other value to 1.0.
//   - A CategoricalAttribute with more than two (n) values is expanded into
//     n FloatAttributes named "<name>_<value>", each of which Transform sets
//     to 1.0 only when that value is observed.
//
// All derived state is rebuilt on every call, so Train may be called again
// (for example after further AddAttribute calls) without duplicating output
// Attributes. An error is returned if an Attribute of any other type was
// added; the outputs computed before that point are kept.
func (f *FloatConvertFilter) Train() error {
	// Start from a clean slate: appending onto a previous result would
	// duplicate outputs, and a categorical Attribute whose number of values
	// has changed since the last call must not keep its old classification.
	// A fresh slice (rather than f.converted[:0]) also leaves any slice
	// previously returned by GetAttributesAfterFiltering untouched, and
	// fresh maps make a zero-value FloatConvertFilter trainable.
	f.converted = make([]base.FilteredAttribute, 0, len(f.attrs))
	f.twoValuedCategoricalAttributes = make(map[base.Attribute]bool)
	f.nValuedCategoricalAttributeMap = make(map[base.Attribute]map[uint64]base.Attribute)

	for _, a := range f.attrs {
		switch attr := a.(type) {
		case *base.CategoricalAttribute:
			vals := attr.GetValues()
			if len(vals) <= 2 {
				newAttr := base.NewFloatAttribute(attr.GetName())
				f.converted = append(f.converted, base.FilteredAttribute{attr, newAttr})
				f.twoValuedCategoricalAttributes[a] = true
				continue
			}
			// Keys are indices into vals; Transform looks them up by the
			// unpacked uint64 category value.
			perValue := make(map[uint64]base.Attribute, len(vals))
			for i, v := range vals {
				newAttr := base.NewFloatAttribute(fmt.Sprintf("%s_%s", attr.GetName(), v))
				f.converted = append(f.converted, base.FilteredAttribute{attr, newAttr})
				perValue[uint64(i)] = newAttr
			}
			f.nValuedCategoricalAttributeMap[a] = perValue
		case *base.FloatAttribute:
			f.converted = append(f.converted, base.FilteredAttribute{attr, attr})
		case *base.BinaryAttribute:
			newAttr := base.NewFloatAttribute(attr.GetName())
			f.converted = append(f.converted, base.FilteredAttribute{attr, newAttr})
		default:
			return fmt.Errorf("Unsupported Attribute type: %v", a)
		}
	}
	return nil
}