Skip to content

Commit

Permalink
Optimize weighted random forest
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanbressler committed May 6, 2014
1 parent 746a287 commit 5f21f6e
Showing 1 changed file with 45 additions and 18 deletions.
63 changes: 45 additions & 18 deletions wrftarget.go
Expand Up @@ -40,33 +40,60 @@ func (target *WRFTarget) SplitImpurity(l *[]int, r *[]int, m *[]int, allocs *Bes
return
}

//UpdateSImpFromAllocs willl be called when splits are being built by moving cases from r to l as in learning from numerical variables.
//Here it just wraps SplitImpurity but it can be implemented to provide further optimization.
//UpdateSImpFromAllocs willl be called when splits are being built by moving cases from r to l
//to avoid recalulatign the entire split impurity.
func (target *WRFTarget) UpdateSImpFromAllocs(l *[]int, r *[]int, m *[]int, allocs *BestSplitAllocs, movedRtoL *[]int) (impurityDecrease float64) {
return target.SplitImpurity(l, r, m, allocs)
var cat, i int
lcounter := *allocs.LCounter
rcounter := *allocs.RCounter
for _, i = range *movedRtoL {

//most expensive statement:
cat = target.Geti(i)
lcounter[cat]++
rcounter[cat]--
//counter[target.Geti(i)]++

}
nl := float64(len(*l))
nr := float64(len(*r))
nm := 0.0

impurityDecrease = nl * target.ImpFromCounts(allocs.LCounter)
impurityDecrease += nr * target.ImpFromCounts(allocs.RCounter)
if m != nil && len(*m) > 0 {
nm = float64(len(*m))
impurityDecrease += nm * target.ImpFromCounts(allocs.Counter)
}

impurityDecrease /= nl + nr + nm
return
}

//Impurity is Gini impurity that uses the weights specified in WRFTarget.weights.
func (target *WRFTarget) Impurity(cases *[]int, counter *[]int) (e float64) {

target.CountPerCat(cases, counter)

return target.ImpFromCounts(counter)
}

//ImpFromCounts recalculates gini impurity from class counts for us in intertive updates.
func (target *WRFTarget) ImpFromCounts(counter *[]int) (e float64) {

total := 0.0
counts := *counter
for i := range counts {
counts[i] = 0
}
for _, i := range *cases {
if !target.IsMissing(i) {
cati := target.Geti(i)
counts[cati]++
total += target.Weights[cati]
}
}
e = 1.0
t := float64(total * total)
for i, v := range counts {
for i, v := range *counter {
w := target.Weights[i]
e -= float64(v*v) * w * w / t
total += float64(v) * w

e -= float64(v*v) * w * w
}

e /= float64(total * total)
e++

return

}

//FindPredicted finds the predicted target as the weighted catagorical Mode.
Expand Down

0 comments on commit 5f21f6e

Please sign in to comment.