-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
random.go
132 lines (114 loc) · 3.61 KB
/
random.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
package trees
import (
"fmt"
"github.com/sjwhitworth/golearn/base"
"math/rand"
)
// RandomTreeRuleGenerator is used to generate decision rules for Random Trees.
// Its GenerateSplitRule method collects a random selection of attributes and
// delegates the choice among them to internalRule.
type RandomTreeRuleGenerator struct {
// Attributes is the number of distinct randomly-chosen attributes that
// GenerateSplitRule collects before choosing a split.
Attributes int
// internalRule chooses the split rule from the randomly-selected attributes
// (via GetSplitRuleFromSelection).
internalRule InformationGainRuleGenerator
}
// GenerateSplitRule returns the best attribute out of those randomly chosen
// which maximises Information Gain.
//
// Candidate attributes are those returned by
// base.AttributeDifferenceReferences for f's full attribute list and its class
// attributes. Distinct candidates are drawn uniformly at random (rejecting
// repeats) until r.Attributes of them have been collected, and the selection
// is then passed to the internal rule generator.
//
// The number of attributes requested is capped at the number of candidates.
// Without the cap, asking for more attributes than exist would make the
// rejection loop spin forever, and an empty candidate list would make
// rand.Intn panic. For r.Attributes within range the sequence of random draws
// is unchanged.
func (r *RandomTreeRuleGenerator) GenerateSplitRule(f base.FixedDataGrid) *DecisionTreeRule {
	// First step is to work out which attributes may be drawn from.
	allAttributes := base.AttributeDifferenceReferences(f.AllAttributes(), f.AllClassAttributes())
	maximumAttribute := len(allAttributes)

	// Never ask for more distinct attributes than are available.
	wanted := r.Attributes
	if wanted > maximumAttribute {
		wanted = maximumAttribute
	}

	// Left nil when nothing is wanted, exactly as before.
	var consideredAttributes []base.Attribute

	// NOTE(review): termination assumes the candidates are pairwise distinct
	// under Equals — confirm against AttributeDifferenceReferences.
	for len(consideredAttributes) < wanted {
		selectedAttribute := allAttributes[rand.Intn(maximumAttribute)]

		// Reject attributes which have already been picked.
		matched := false
		for _, a := range consideredAttributes {
			if a.Equals(selectedAttribute) {
				matched = true
				break
			}
		}
		if matched {
			continue
		}
		consideredAttributes = append(consideredAttributes, selectedAttribute)
	}

	return r.internalRule.GetSplitRuleFromSelection(consideredAttributes, f)
}
// RandomTree builds a decision tree by considering a fixed number
// of randomly-chosen attributes at each node.
type RandomTree struct {
// BaseClassifier is embedded from the base package.
base.BaseClassifier
// Root is the root node of the tree. NewRandomTree leaves it nil; Fit and
// LoadWithPrefix assign it.
Root *DecisionTreeNode
// Rule is the rule generator handed to InferID3Tree by Fit.
Rule *RandomTreeRuleGenerator
}
// NewRandomTree constructs a RandomTree whose rule generator considers attrs
// randomly chosen attributes at each node. The returned tree has no root yet;
// Fit or Load populates it.
func NewRandomTree(attrs int) *RandomTree {
	generator := &RandomTreeRuleGenerator{
		Attributes:   attrs,
		internalRule: InformationGainRuleGenerator{},
	}
	return &RandomTree{
		BaseClassifier: base.BaseClassifier{},
		Root:           nil,
		Rule:           generator,
	}
}
// Fit builds the tree from the supplied data by calling InferID3Tree with this
// tree's rule generator, and stores the result as the root. It always returns
// nil.
func (rt *RandomTree) Fit(from base.FixedDataGrid) error {
	root := InferID3Tree(from, rt.Rule)
	rt.Root = root
	return nil
}
// Predict returns a set of Instances containing predictions, produced by the
// root node's Predict for the supplied data grid.
//
// A tree which has not yet been fitted or loaded has no root; in that case an
// error is returned instead of calling a method on a nil node.
func (rt *RandomTree) Predict(from base.FixedDataGrid) (base.FixedDataGrid, error) {
	if rt.Root == nil {
		return nil, fmt.Errorf("trees: RandomTree has no root; call Fit or Load before Predict")
	}
	return rt.Root.Predict(from)
}
// String renders the tree for humans: the root node formatted with %s and
// wrapped in "RandomTree(...)".
func (rt *RandomTree) String() string {
	const layout = "RandomTree(%s)"
	return fmt.Sprintf(layout, rt.Root)
}
// Prune hands the supplied data grid (with) to the root node's Prune, which is
// intended to remove nodes that are detrimental to accuracy on that set.
func (rt *RandomTree) Prune(with base.FixedDataGrid) {
	root := rt.Root
	root.Prune(with)
}
// Save writes this model to the file at filePath. It creates a serialized
// classifier stub carrying this tree's metadata, writes the tree with an empty
// prefix, and closes the writer on the way out. Anything returned by Close is
// discarded.
func (rt *RandomTree) Save(filePath string) error {
	metadata := rt.GetMetadata()
	writer, err := base.CreateSerializedClassifierStub(filePath, metadata)
	if err != nil {
		return err
	}
	defer writer.Close()

	const noPrefix = ""
	return rt.SaveWithPrefix(writer, noPrefix)
}
// SaveWithPrefix outputs this model through writer, storing it under the given
// prefix by delegating to the root node's SaveWithPrefix.
//
// A tree which has not yet been fitted or loaded has no root; in that case an
// error is returned instead of calling a method on a nil node.
func (rt *RandomTree) SaveWithPrefix(writer *base.ClassifierSerializer, prefix string) error {
	if rt.Root == nil {
		return fmt.Errorf("trees: RandomTree has no root; call Fit or Load before saving")
	}
	return rt.Root.SaveWithPrefix(writer, prefix)
}
// Load reads this model back from the file at filePath: it opens a serialized
// classifier stub and loads the tree from it with an empty prefix. An error
// from either step is returned as-is.
//
// NOTE(review): the reader is not closed here — confirm whether
// ClassifierDeserializer holds a resource that needs releasing.
func (rt *RandomTree) Load(filePath string) error {
	reader, openErr := base.ReadSerializedClassifierStub(filePath)
	if openErr != nil {
		return openErr
	}

	const noPrefix = ""
	return rt.LoadWithPrefix(reader, noPrefix)
}
// LoadWithPrefix retrieves this random tree from reader using the given
// prefix. The current root is replaced with a fresh node before loading, so
// the previous root is discarded even if loading fails.
func (rt *RandomTree) LoadWithPrefix(reader *base.ClassifierDeserializer, prefix string) error {
	root := new(DecisionTreeNode)
	rt.Root = root
	return root.LoadWithPrefix(reader, prefix)
}
// GetMetadata returns the serialization metadata which Save stores alongside
// the model: format version 1, classifier version "1.0" and no extra
// classifier metadata.
//
// The classifier name used to be "KNN", which mislabelled saved random trees
// as a different classifier; it now identifies this type. Load in this file
// does not inspect the name, so previously saved models are still read the
// same way.
func (rt *RandomTree) GetMetadata() base.ClassifierMetadataV1 {
	return base.ClassifierMetadataV1{
		FormatVersion:      1,
		ClassifierName:     "RandomTree",
		ClassifierVersion:  "1.0",
		ClassifierMetadata: nil,
	}
}