Skip to content

Commit

Permalink
fix to pull request #54...
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Apr 19, 2013
1 parent 4b0f1f7 commit 4f79aef
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions learn/src/main/scala/breeze/classify/NNetClassifier.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ class NNetClassifier[L, T](nnet: NeuralNetwork,
}

object NNetClassifier {
class CounterTrainer[L, T](opt: OptParams = OptParams()) extends Classifier.Trainer[L, Counter[T, Double]] {
class CounterTrainer[L, T](opt: OptParams = OptParams(),layersIn:Array[Int] = Array(100)) extends Classifier.Trainer[L, Counter[T, Double]] {
type MyClassifier = NNetClassifier[L, Counter[T, Double]]

def train(data: Iterable[Example[L, Counter[T, Double]]],layersIn:Array[Int] = Array(100)) = {
def train(data: Iterable[Example[L, Counter[T, Double]]]) = {
val labels = Index[L]()
data foreach { labels index _.label}
val featureIndex = Index[T]()
Expand All @@ -48,7 +48,7 @@ object NNetClassifier {
val obj = new NNObjective(processedData.toIndexedSeq, errorFun, layers)
val guess = obj.initialWeightVector
val weights = opt.minimize(obj,guess)
new NNetClassifier(obj.extract(weights), {fEncoder.encodeDense(_, true)}, labels)
new NNetClassifier(obj.extract(weights), {fEncoder.encodeDense(_:Counter[T, Double], true)}, labels)
}
}
}
2 changes: 1 addition & 1 deletion project/build.properties
Original file line number Diff line number Diff line change
@@ -1 +1 @@
sbt.version=0.12.0
sbt.version=0.12.3

0 comments on commit 4f79aef

Please sign in to comment.