-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
gr.go
76 lines (63 loc) · 2.34 KB
/
gr.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
package trees
import (
"github.com/sjwhitworth/golearn/base"
"math"
)
//
// Information gain ratio generator
//
// InformationGainRatioRuleGenerator generates DecisionTreeRules which
// maximise the information gain ratio at each node.
type InformationGainRatioRuleGenerator struct {
}
// GenerateSplitRule returns a DecisionTreeRule which maximises information
// gain ratio considering every available Attribute.
//
// IMPORTANT: passing a base.Instances with no Attributes other than the class
// variable will panic()
func (r *InformationGainRatioRuleGenerator) GenerateSplitRule(f base.FixedDataGrid) *DecisionTreeRule {
	// Every Attribute except the class Attribute(s) is a split candidate.
	nonClass := base.AttributeDifferenceReferences(f.AllAttributes(), f.AllClassAttributes())
	return r.GetSplitRuleFromSelection(nonClass, f)
}
// GetSplitRuleFromSelection returns the DecisionRule which maximizes information gain,
// considering only a subset of Attributes.
//
// IMPORTANT: passing a zero-length consideredAttributes parameter will panic()
func (r *InformationGainRatioRuleGenerator) GetSplitRuleFromSelection(consideredAttributes []base.Attribute, f base.FixedDataGrid) *DecisionTreeRule {
	var selectedAttribute base.Attribute
	var selectedVal float64
	// Parameter check
	if len(consideredAttributes) == 0 {
		panic("More Attributes should be considered")
	}
	// Track the best ratio seen so far; -Inf ensures the first finite
	// candidate is always accepted.
	maxRatio := math.Inf(-1)
	// Entropy of the class distribution before any split.
	classDist := base.GetClassDistribution(f)
	baseEntropy := getBaseEntropy(classDist)
	// Evaluate splitting on each candidate Attribute in turn.
	for _, s := range consideredAttributes {
		var localEntropy float64
		var splitVal float64
		if fAttr, ok := s.(*base.FloatAttribute); ok {
			// Numeric attributes also yield the threshold value to split on.
			localEntropy, splitVal = getNumericAttributeEntropy(f, fAttr)
		} else {
			// Categorical attributes split on each distinct value;
			// splitVal stays 0 for these.
			proposedClassDist := base.GetClassDistributionAfterSplit(f, s)
			localEntropy = getSplitEntropy(proposedClassDist)
		}
		informationGain := baseEntropy - localEntropy
		// NOTE(review): this "ratio" divides by the post-split entropy,
		// not by C4.5's split information (intrinsic value) — confirm
		// this is intentional.
		//
		// Guard the division: the bare informationGain/localEntropy is
		// NaN (0/0) when both entropies are zero, which makes every
		// comparison below fail and returns a rule with a nil Attribute.
		var informationGainRatio float64
		switch {
		case localEntropy > 0:
			informationGainRatio = informationGain / localEntropy
		case informationGain > 0:
			// Perfect split with positive gain: matches the original
			// x/0 == +Inf behaviour.
			informationGainRatio = math.Inf(1)
		default:
			// Degenerate node (no entropy, no gain): treat as neutral
			// rather than NaN so an Attribute is still selected.
			informationGainRatio = 0
		}
		if informationGainRatio > maxRatio {
			maxRatio = informationGainRatio
			selectedAttribute = s
			selectedVal = splitVal
		}
	}
	// Return the rule for the Attribute maximising the gain ratio.
	return &DecisionTreeRule{selectedAttribute, selectedVal}
}