Parallelized RNN.
tsteczniewski committed Jun 4, 2015
1 parent 740f255 commit dafa7e7
Showing 6 changed files with 151 additions and 253 deletions.
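In outline, this commit replaces the per-example stochastic gradient descent loop with a data-parallel batch step: the labeled trees become a Spark `RDD[(Tree, Int)]`, each partition accumulates a local `RNN.Gradient`, the partial gradients are merged with `RNN.mergeGradients`, and the result is applied on the driver with AdaGrad, which is now the only update rule (hence the removal of `useAdaGrad`). A minimal driver-side sketch of the resulting API, not part of the commit; the sample tree, the sizes (`inSize = 10`), the assumed `Node(children, label)` / `Leaf(word, label)` constructors, and the local master are assumptions, while the names come from the diff below:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.template.rnn.{Leaf, Node, RNN, Tree}

// Hedged sketch: one tiny labeled tree, assumed constructors.
val sc = new SparkContext(new SparkConf().setMaster("local[*]").setAppName("rnn"))
val labeledTrees = sc.parallelize(Seq[(Tree, Int)]((Node(List(Leaf("spark", "NN")), "NP"), 1)))
val rnn = new RNN(10, 5, 0.95, 0.001, labeledTrees) // inSize, outSize, alpha, reg. coeff, data
for (i <- 0 until 1000) {
  println(s"Iteration $i: ${rnn.forwardPropagateError(labeledTrees)}") // distributed error sum
  rnn.fit() // mapPartitions -> mergeGradients -> AdaGrad update on the driver
}
```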
1 change: 0 additions & 1 deletion engine.json
@@ -14,7 +14,6 @@
     "outSize": 5,
     "alpha": 0.95,
     "regularizationCoeff": 0.001,
-    "useAdaGrad": true,
     "steps": 1000
   }
 }
7 changes: 2 additions & 5 deletions src/main/scala/Algorithm.scala
@@ -7,14 +7,12 @@ import org.apache.spark.SparkContext
 
 import grizzled.slf4j.Logger
 
-import scala.util.Random
 
 case class AlgorithmParams(
   inSize: Int,
   outSize: Int,
   alpha: Double,
   regularizationCoeff: Double,
-  useAdaGrad: Boolean,
   steps: Int
 ) extends Params
 
@@ -24,16 +22,15 @@ class Algorithm(val ap: AlgorithmParams)
   @transient lazy val logger = Logger[this.type]
 
   def train(sc: SparkContext, data: PreparedData): Model = {
-    val rnn = new RNN(ap.inSize, ap.outSize, ap.alpha, ap.regularizationCoeff, ap.useAdaGrad)
+    val rnn = new RNN(ap.inSize, ap.outSize, ap.alpha, ap.regularizationCoeff, data.labeledTrees)
     for(i <- 0 until ap.steps) {
       logger.info(s"Iteration $i: ${rnn.forwardPropagateError(data.labeledTrees)}")
-      rnn.stochasticGradientDescent(Random.shuffle(data.labeledTrees))
+      rnn.fit()
     }
     Model(rnn)
   }
 
   def predict(model: Model, query: Query): PredictedResult = {
-    // parser
     val parser = Parser(query.content.length)
     val pennFormatted = parser.pennFormatted(query.content)
     val tree = Tree.fromPennTreeBankFormat(pennFormatted)
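One consequence of the new loop worth noting: every iteration makes two passes over the RDD, one for the logged error and one inside `fit()`. A hedged variant that logs the error only periodically (the interval of 100 is an arbitrary assumption):

```scala
for (i <- 0 until ap.steps) {
  if (i % 100 == 0) // skip the extra full forward pass on most iterations
    logger.info(s"Iteration $i: ${rnn.forwardPropagateError(data.labeledTrees)}")
  rnn.fit()
}
```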
5 changes: 3 additions & 2 deletions src/main/scala/Preparator.scala
@@ -4,6 +4,7 @@ import grizzled.slf4j.Logger
 import io.prediction.controller.PPreparator
 
 import org.apache.spark.SparkContext
+import org.apache.spark.rdd.RDD
 
 class Preparator
   extends PPreparator[TrainingData, PreparedData] {
@@ -31,10 +32,10 @@ class Preparator
 
     val labeledTrees = maybeLabeledTrees.filter(_ != None).map(_.get)
 
-    PreparedData(labeledTrees.toVector)
+    PreparedData(sc.parallelize(labeledTrees))
   }
 }
 
 case class PreparedData(
-  labeledTrees : Vector[(Tree, Int)]
+  labeledTrees : RDD[(Tree, Int)]
 ) extends Serializable
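`sc.parallelize` is what makes the prepared trees distributable: it splits the locally collected sequence across executor partitions, so each partition can later accumulate its own gradient. A small sketch (the explicit partition count is an assumption; Spark picks a default otherwise):

```scala
import org.apache.spark.rdd.RDD

val distributed: RDD[(Tree, Int)] = sc.parallelize(labeledTrees, 8) // 8 partitions, assumed
```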
219 changes: 107 additions & 112 deletions src/main/scala/RNN.scala
@@ -1,9 +1,9 @@
 package org.template.rnn
 
-import breeze.linalg.{argmax, sum, DenseVector, DenseMatrix}
-import breeze.stats.distributions.Uniform
+import breeze.linalg.{Vector => _, _}
+import org.apache.spark.rdd.RDD
 import scala.collection.mutable.Map
-import scala.collection.mutable.HashMap
+import breeze.stats.distributions.Uniform
 import scala.math.{exp, log, sqrt}
 
 object RNN {
@@ -17,74 +17,84 @@ object RNN {
   def logDerivative(x: Double) = 1 / x
 
   def regularization(m: DenseMatrix[Double]) = sum(m.map(x => x * x))
-
-  def weightedMean(a: RNN, b: RNN, aq: Double, bq: Double): RNN = {
-    assert(a.inSize == b.inSize && a.outSize == b.outSize && a.alpha == b.alpha && a.regularizationCoeff == b.regularizationCoeff && a.useAdaGrad == b.useAdaGrad)
-    val sum = aq + bq
-    val rnn = new RNN(a.inSize, a.outSize, a.alpha, a.regularizationCoeff, a.useAdaGrad)
-    // judge
-    rnn.judge = ((aq * a.judge) + (bq * b.judge)) / sum
-    // combinator
-    if (true) {
-      val aKeySet = a.labelToCombinatorMap.keySet
-      val bKeySet = b.labelToCombinatorMap.keySet
-      for (key <- aKeySet.diff(bKeySet)) rnn.labelToCombinatorMap.put(key, (aq * a.labelToCombinatorMap(key)) / sum)
-      for (key <- bKeySet.diff(bKeySet)) rnn.labelToCombinatorMap.put(key, (bq * b.labelToCombinatorMap(key)) / sum)
-      for (key <- aKeySet.intersect(bKeySet)) rnn.labelToCombinatorMap.put(key, ((aq * a.labelToCombinatorMap(key)) + (bq * b.labelToCombinatorMap(key))) / sum)
-    }
-    // word vec
-    if (true) {
-      val aKeySet = a.wordToVecMap.keySet
-      val bKeySet = b.wordToVecMap.keySet
-      for (key <- aKeySet.diff(bKeySet)) rnn.wordToVecMap.put(key, (aq * a.wordToVecMap(key)) / sum)
-      for (key <- bKeySet.diff(bKeySet)) rnn.wordToVecMap.put(key, (bq * b.wordToVecMap(key)) / sum)
-      for (key <- aKeySet.intersect(bKeySet)) rnn.wordToVecMap.put(key, (aq * a.wordToVecMap(key) + bq * b.wordToVecMap(key)) / sum)
-    }
-    rnn
-  }
+  def regularization(m: DenseVector[Double]) = sum(m.map(x => x * x))
 
   def maxClass(v: DenseVector[Double]): Int = argmax(v)
 
-  def removeNans(m: DenseMatrix[Double]): DenseMatrix[Double] = m.map(x => if (x.isNaN) 0 else x)
-  def removeNans(v: DenseVector[Double]): DenseVector[Double] = v.map(x => if (x.isNaN) 0 else x)
+  def clearNans(m: DenseVector[Double]): DenseVector[Double] = m.map(x => if (x.isNaN) 0 else x)
+  def clearNans(m: DenseMatrix[Double]): DenseMatrix[Double] = m.map(x => if (x.isNaN) 0 else x)
 
   val fudgeFactor = 1e-6
   def adaGrad(g: DenseMatrix[Double], gh: DenseMatrix[Double]): DenseMatrix[Double] = {
-    gh += removeNans(g :* g)
+    gh += clearNans(g :* g)
     g :/ gh.map(x => fudgeFactor + sqrt(x))
   }
   def adaGrad(g: DenseVector[Double], gh: DenseVector[Double]): DenseVector[Double] = {
-    gh += removeNans(g :* g)
+    gh += clearNans(g :* g)
     g :/ gh.map(x => fudgeFactor + sqrt(x))
   }
+
+  case class Gradient(
+    judgeGradient: DenseMatrix[Double],
+    labelToCombinatorGradientMap: Map[(String, Int), DenseMatrix[Double]],
+    wordToVecGradientMap: Map[(String, String), DenseVector[Double]]
+  ) extends Serializable
+
+  def mergeCombinatorGradientMaps(a: Map[(String, Int), DenseMatrix[Double]], b: Map[(String, Int), DenseMatrix[Double]]): Map[(String, Int), DenseMatrix[Double]] = {
+    // let b be the smaller map
+    if (a.size < b.size) return mergeCombinatorGradientMaps(b, a)
+    for (key <- b.keySet intersect a.keySet) a.get(key).get += b.get(key).get
+    for (key <- b.keySet diff a.keySet) a.put(key, b.get(key).get)
+    a
+  }
+  def mergeVecGradientMaps(a: Map[(String, String), DenseVector[Double]], b: Map[(String, String), DenseVector[Double]]): Map[(String, String), DenseVector[Double]] = {
+    // let b be the smaller map
+    if (a.size < b.size) return mergeVecGradientMaps(b, a)
+    for (key <- b.keySet intersect a.keySet) a.get(key).get += b.get(key).get
+    for (key <- b.keySet diff a.keySet) a.put(key, b.get(key).get)
+    a
+  }
+  def mergeGradients(a: Gradient, b: Gradient) =
+    Gradient(
+      a.judgeGradient + b.judgeGradient,
+      mergeCombinatorGradientMaps(a.labelToCombinatorGradientMap, b.labelToCombinatorGradientMap),
+      mergeVecGradientMaps(a.wordToVecGradientMap, b.wordToVecGradientMap)
+    )
 }
 
 case class RNN (
   inSize: Int,
   outSize: Int,
   alpha: Double,
   regularizationCoeff: Double,
-  useAdaGrad: Boolean
+  @transient labeledTrees: RDD[(Tree, Int)]
 ) extends Serializable {
 
   var judge = RNN.randomMatrix(outSize, inSize + 1)
-  var labelToCombinatorMap = HashMap[(String, Int), DenseMatrix[Double]]()
-  var wordToVecMap = HashMap[(String, String), DenseVector[Double]]()
-
-  @transient var judgeGradient: DenseMatrix[Double] = null
-  @transient var labelToCombinatorGradientMap: Map[(String, Int), DenseMatrix[Double]] = null
-  @transient var wordToVecGradientMap: Map[(String, String), DenseVector[Double]] = null
+  var labelToCombinatorMap = Map[(String, Int), DenseMatrix[Double]]()
+  var wordToVecMap = Map[(String, String), DenseVector[Double]]()
 
   @transient var judgeGradientHistory = RNN.randomMatrix(outSize, inSize + 1)
-  @transient var labelToCombinatorGradientHistoryMap = HashMap[(String, Int), DenseMatrix[Double]]()
-  @transient var wordToVecGradientHistoryMap = HashMap[(String, String), DenseVector[Double]]()
-
-  def clearCache() = {
-    judgeGradient = DenseMatrix.zeros(judge.rows, judge.cols)
-    labelToCombinatorGradientMap = Map[(String, Int), DenseMatrix[Double]]()
-    wordToVecGradientMap = Map[(String, String), DenseVector[Double]]()
-  }
+  @transient var labelToCombinatorGradientHistoryMap = Map[(String, Int), DenseMatrix[Double]]()
+  @transient var wordToVecGradientHistoryMap = Map[(String, String), DenseVector[Double]]()
+
+  def initializeMaps(trees: Array[Tree]): Unit = {
+    val nodeTypes = trees.map(Tree.nodeTypes(_)).reduce((a, b) => a union b)
+    val leafTypes = trees.map(Tree.leafTypes(_)).reduce((a, b) => a union b)
+    for ((label, childrenLength) <- nodeTypes) {
+      labelToCombinatorMap.put((label, childrenLength), RNN.randomMatrix(inSize, inSize * childrenLength + 1))
+      labelToCombinatorGradientHistoryMap.put((label, childrenLength), DenseMatrix.zeros[Double](inSize, inSize * childrenLength + 1))
+    }
+    for ((word, label) <- leafTypes) {
+      wordToVecMap.put((word, label), RNN.randomVector(inSize))
+      wordToVecGradientHistoryMap.put((word, label), DenseVector.zeros[Double](inSize))
+    }
+  }
+
+  // initialize maps
+  if (labeledTrees != null) // null only in tests
+    initializeMaps(labeledTrees.map(_._1).collect())
 
   def label(i: Int) = {
     val m = DenseVector.zeros[Double](outSize)
     m(i) = 1
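`Gradient` and `mergeGradients` above are what make the per-partition results reducible; since `RDD.reduce` combines them in arbitrary order, the merge has to be associative and commutative, which per-key elementwise addition is. A hedged sketch of one merge, assuming `inSize = 10` and `outSize = 5` (so the judge is 5×11 and a two-child combinator is 10×21):

```scala
import breeze.linalg.{DenseMatrix, DenseVector}
import scala.collection.mutable.Map

// Two hypothetical partial gradients that share the ("NP", 2) combinator key.
val g1 = RNN.Gradient(DenseMatrix.zeros[Double](5, 11),
  Map(("NP", 2) -> DenseMatrix.ones[Double](10, 21)), Map.empty)
val g2 = RNN.Gradient(DenseMatrix.zeros[Double](5, 11),
  Map(("NP", 2) -> (DenseMatrix.ones[Double](10, 21) * 2.0)),
  Map(("dog", "NN") -> DenseVector.ones[Double](10)))
val merged = RNN.mergeGradients(g1, g2)
// merged: ("NP", 2) sums to all threes; ("dog", "NN") is carried over unchanged.
// Note the merge mutates the larger input's maps in place to avoid copying.
```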
@@ -93,31 +103,34 @@ case class RNN (
   def forwardPropagateTree(tree: Tree): ForwardPropagatedTree = tree match {
     case Node(children, label) =>
-      val fpChildren = for(child <- children) yield forwardPropagateTree(child)
-      val vsChildren = for(fpChild <- fpChildren) yield fpChild.value
-      val joined: DenseVector[Double] = DenseVector.vertcat(vsChildren:_*)
+      val fpChildren = for (child <- children) yield forwardPropagateTree(child)
+      val vsChildren = for (fpChild <- fpChildren) yield fpChild.value
+      val joined: DenseVector[Double] = DenseVector.vertcat(vsChildren: _*)
       val biased: DenseVector[Double] = DenseVector.vertcat(joined, DenseVector.ones(1))
-      val combinator = labelToCombinatorMap.getOrElseUpdate((label, children.length), RNN.randomMatrix(inSize, inSize * children.length + 1))
-      val transformed: DenseVector[Double] = combinator * biased
+      val combinator = labelToCombinatorMap.getOrElse((label, children.length), RNN.randomMatrix(inSize, inSize * children.length + 1))
+      val transformed: DenseVector[Double] = combinator * biased
       ForwardPropagatedNode(fpChildren, label, transformed.map(RNN.sigmoid), transformed.map(RNN.sigmoidDerivative))
     case Leaf(word, label) =>
-      val vec = wordToVecMap.getOrElseUpdate((word, label), RNN.randomVector(inSize))
+      val vec = wordToVecMap.getOrElse((word, label), RNN.randomVector(inSize))
       ForwardPropagatedLeaf(word, label, vec.map(RNN.sigmoid), vec.map(RNN.sigmoidDerivative))
   }
 
-  def backwardPropagateTree(tree: ForwardPropagatedTree, y: DenseVector[Double]): Unit = tree match {
+  def backwardPropagateTree(tree: ForwardPropagatedTree, y: DenseVector[Double], gradient: RNN.Gradient): Unit = tree match {
     case ForwardPropagatedNode(children, label, _, d) =>
       val z = y :* d
-      val vsChildren = for(child <- children) yield child.value
-      val joined: DenseVector[Double] = DenseVector.vertcat(vsChildren:_*)
+      // combinator
+      val vsChildren = for (child <- children) yield child.value
+      val joined: DenseVector[Double] = DenseVector.vertcat(vsChildren: _*)
       val biased: DenseVector[Double] = DenseVector.vertcat(joined, DenseVector.ones(1))
       val combinator = labelToCombinatorMap.get((label, children.length)).get
-      val combinatorGradient = labelToCombinatorGradientMap.getOrElseUpdate((label, children.length), DenseMatrix.zeros(inSize, inSize * children.length + 1))
-      combinatorGradient += z * biased.t
+      // update gradient
+      val combinatorGradient = gradient.labelToCombinatorGradientMap.getOrElseUpdate((label, children.length), DenseMatrix.zeros(inSize, inSize * children.length + 1))
+      // backward propagate children
       val biasedGradient: DenseVector[Double] = combinator.t * z
-      for(i <- children.indices) backwardPropagateTree(children(i), biasedGradient(i * inSize to (i + 1) * inSize - 1))
+      for (i <- children.indices) backwardPropagateTree(children(i), biasedGradient(i * inSize to (i + 1) * inSize - 1), gradient)
+      combinatorGradient += z * biased.t
     case ForwardPropagatedLeaf(word, label, _, d) =>
-      val vecGradient = wordToVecGradientMap.getOrElseUpdate((word, label), DenseVector.zeros(inSize))
+      val vecGradient = gradient.wordToVecGradientMap.getOrElseUpdate((word, label), DenseVector.zeros[Double](inSize))
       vecGradient += y :* d
   }
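Two things happen in this hunk. First, the forward pass switches from `getOrElseUpdate` to `getOrElse`: the parameter maps are pre-built by `initializeMaps`, so executors only read them (an unseen key yields a throwaway random matrix instead of a map insertion). Second, each `Node` computes σ(W·[c₁;…;c_k;1]) with a per-(label, arity) matrix W; a hedged shape check with assumed `inSize = 10` and two children:

```scala
import breeze.linalg.{DenseMatrix, DenseVector}

val inSize = 10 // assumed
val combinator = DenseMatrix.rand(inSize, inSize * 2 + 1) // W for some ("NP", 2)
val biased = DenseVector.vertcat(
  DenseVector.rand(inSize),     // child 1 value
  DenseVector.rand(inSize),     // child 2 value
  DenseVector.ones[Double](1))  // bias term
val transformed = combinator * biased // length inSize
val value = transformed.map(x => 1.0 / (1.0 + math.exp(-x))) // sigmoid, as in RNN.sigmoid
```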

@@ -129,18 +142,18 @@ case class RNN (
     activated
   }
 
-  def backwardPropagateJudgement(tree: ForwardPropagatedTree, y: DenseVector[Double]): Unit = tree match {
+  def backwardPropagateJudgement(tree: ForwardPropagatedTree, y: DenseVector[Double], gradient: RNN.Gradient) = tree match {
     case ForwardPropagatedTree(_, v, _) =>
       val biased = DenseVector.vertcat(v, DenseVector.ones[Double](1))
       val judged: DenseVector[Double] = judge * biased
-      val gradient: DenseVector[Double] = judged.map(RNN.sigmoidDerivative)
-      val z = y :* gradient
-      judgeGradient += z * biased.t
+      val activatedGradient: DenseVector[Double] = judged.map(RNN.sigmoidDerivative)
+      val z = y :* activatedGradient
+      gradient.judgeGradient += z * biased.t
       val biasedGradient: DenseVector[Double] = judge.t * z
-      backwardPropagateTree(tree, biasedGradient(0 to inSize - 1))
+      backwardPropagateTree(tree, biasedGradient(0 to inSize - 1), gradient)
   }
 
-  def forwardPropagateError(tree: ForwardPropagatedTree, expected: DenseVector[Double]): Double = {
+  def forwardPropagateError(tree: ForwardPropagatedTree, expected: DenseVector[Double]) = {
     val oneMinusExpected = DenseVector.ones[Double](outSize) - expected
     val actual = forwardPropagateJudgment(tree)
     val logActual = actual.map(log)
@@ -149,71 +162,53 @@
     -(expected.t * logActual + oneMinusExpected.t * logOneMinusActual)
   }
 
-  def backwardPropagateError(tree: ForwardPropagatedTree, expected: DenseVector[Double]): Unit = {
+  def backwardPropagateError(tree: ForwardPropagatedTree, expected: DenseVector[Double], gradient: RNN.Gradient) = {
     val oneMinusExpected = DenseVector.ones[Double](outSize) - expected
     val actual = forwardPropagateJudgment(tree)
     val logActualGradient = actual.map(RNN.logDerivative)
     val oneMinusActual = DenseVector.ones[Double](outSize) - actual
-    val logOneMinusActualGradient = - oneMinusActual.map(RNN.logDerivative)
-    val judgementGradient = - ((expected :* logActualGradient) + (oneMinusExpected :* logOneMinusActualGradient))
-    backwardPropagateJudgement(tree, judgementGradient)
+    val logOneMinusActualGradient = -oneMinusActual.map(RNN.logDerivative)
+    val judgementGradient = -((expected :* logActualGradient) + (oneMinusExpected :* logOneMinusActualGradient))
+    backwardPropagateJudgement(tree, judgementGradient, gradient)
   }
 
-  def forwardPropagateError(labeledTrees: Vector[(Tree, Int)]): Double =
-    labeledTrees.foldLeft(0.0)((acc, labeledTree) => {
-      val (t, i) = labeledTree
-      acc + forwardPropagateError(forwardPropagateTree(t), label(i))
-    })
+  def forwardPropagateError(labeledTrees: RDD[(Tree, Int)]): Double =
+    labeledTrees.map(labeledTree => forwardPropagateError(forwardPropagateTree(labeledTree._1), label(labeledTree._2))).reduce((x, y) => x + y)
 
   def forwardPropagateRegularizationError(influence: Double): Double = {
     var regularization = RNN.regularization(judge)
-    for(combinator <- labelToCombinatorMap.values) regularization += RNN.regularization(combinator)
-    for(vec <- wordToVecMap.values) regularization += RNN.regularization(vec.asDenseMatrix)
+    for (combinator <- labelToCombinatorMap.values) regularization += RNN.regularization(combinator)
+    for (vec <- wordToVecMap.values) regularization += RNN.regularization(vec)
     regularizationCoeff * influence * regularization
   }
 
   // calculates gradient only on touched matrices
-  def backwardPropagateRegularizationError(influence: Double): Unit = {
-    val coeff = regularizationCoeff * influence * 2.0
-    judgeGradient += coeff * judge
-    for((key, combinatorGradient) <- labelToCombinatorGradientMap) combinatorGradient += coeff * labelToCombinatorMap.get(key).get
-    for((key, vecGradient) <- wordToVecGradientMap) vecGradient += coeff * wordToVecMap.get(key).get
+  def backwardPropagateRegularizationError(gradient: RNN.Gradient): Unit = {
+    val coeff = regularizationCoeff * 2.0
+    gradient.judgeGradient += coeff * judge
+    for ((key, combinatorGradient) <- gradient.labelToCombinatorGradientMap) combinatorGradient += coeff * labelToCombinatorMap.get(key).get
+    for ((key, vecGradient) <- gradient.wordToVecGradientMap) vecGradient += coeff * wordToVecMap.get(key).get
   }
 
-  def applyGradientWithoutAdaGrad() = {
-    judge -= RNN.removeNans(alpha * judgeGradient)
-    for((key, combinatorGradient) <- labelToCombinatorGradientMap) labelToCombinatorMap.get(key).get -= RNN.removeNans(alpha * combinatorGradient)
-    for((key, vecGradient) <- wordToVecGradientMap) wordToVecMap.get(key).get -= RNN.removeNans(alpha * vecGradient)
-  }
-
-  def applyGradientWithAdaGrad() = {
-    judge -= RNN.removeNans(alpha * RNN.adaGrad(judgeGradient, judgeGradientHistory))
-    for((key, combinatorGradient) <- labelToCombinatorGradientMap) {
-      val gradientHistory = labelToCombinatorGradientHistoryMap.getOrElseUpdate(key, DenseMatrix.zeros(combinatorGradient.rows, combinatorGradient.cols))
-      labelToCombinatorMap.get(key).get -= RNN.removeNans(alpha * RNN.adaGrad(combinatorGradient, gradientHistory))
+  def applyGradient(gradient: RNN.Gradient): Unit = {
+    judge -= RNN.clearNans(alpha * RNN.adaGrad(gradient.judgeGradient, judgeGradientHistory))
+    for ((key, combinatorGradient) <- gradient.labelToCombinatorGradientMap) {
+      val gradientHistory = labelToCombinatorGradientHistoryMap.get(key).get
+      labelToCombinatorMap.get(key).get -= RNN.clearNans(alpha * RNN.adaGrad(combinatorGradient, gradientHistory))
     }
 
-    for((key, vecGradient) <- wordToVecGradientMap) {
-      val gradientHistory = wordToVecGradientHistoryMap.getOrElseUpdate(key, DenseVector.zeros(inSize))
-      wordToVecMap.get(key).get -= RNN.removeNans(alpha * RNN.adaGrad(vecGradient, gradientHistory))
+    for ((key, vecGradient) <- gradient.wordToVecGradientMap) {
+      val gradientHistory = wordToVecGradientHistoryMap.get(key).get
+      wordToVecMap.get(key).get -= RNN.clearNans(alpha * RNN.adaGrad(vecGradient, gradientHistory))
     }
   }
 
-  def fit(labeledTrees: Vector[(Tree, Int)]): Unit = {
-    clearCache()
-    for((t, i) <- labeledTrees) backwardPropagateError(forwardPropagateTree(t), label(i))
-    backwardPropagateRegularizationError(1)
-    if (useAdaGrad) applyGradientWithAdaGrad()
-    else applyGradientWithoutAdaGrad()
-  }
-
-  def stochasticGradientDescent(labeledTrees: Vector[(Tree, Int)]): Unit = {
-    for((t, i) <- labeledTrees) {
-      clearCache()
-      backwardPropagateError(forwardPropagateTree(t), label(i))
-      backwardPropagateRegularizationError(1 / labeledTrees.length)
-      if (useAdaGrad) applyGradientWithAdaGrad()
-      else applyGradientWithoutAdaGrad()
-    }
+  def fit(): Unit = {
+    val gradient = labeledTrees.mapPartitions(labeledTrees => {
+      val gradient = RNN.Gradient(DenseMatrix.zeros[Double](outSize, inSize + 1), Map.empty, Map.empty)
+      labeledTrees.foreach(labeledTree => backwardPropagateError(forwardPropagateTree(labeledTree._1), label(labeledTree._2), gradient))
+      Iterator(gradient)
+    }).reduce(RNN.mergeGradients)
+    backwardPropagateRegularizationError(gradient)
+    applyGradient(gradient)
   }
 }
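For reference, the error minimized per root is the elementwise cross-entropy −(y·log ŷ + (1−y)·log(1−ŷ)), exactly the expression in `forwardPropagateError` above. A worked example with assumed `outSize = 5`, true class 2, and a hypothetical judged output:

```scala
import breeze.linalg.DenseVector
import scala.math.log

val expected = DenseVector(0.0, 0.0, 1.0, 0.0, 0.0)  // label(2)
val actual   = DenseVector(0.1, 0.1, 0.6, 0.1, 0.1)  // hypothetical judgment
val oneMinusExpected = DenseVector.ones[Double](5) - expected
val oneMinusActual   = DenseVector.ones[Double](5) - actual
val error = -(expected.t * actual.map(log) + oneMinusExpected.t * oneMinusActual.map(log))
// error = -(log 0.6 + 4 * log 0.9) ≈ 0.93
```

Note also that `fit()` is now one batch step over the whole RDD per call, and `backwardPropagateRegularizationError` regularizes only the entries the batch actually touched, matching the comment above it.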
10 changes: 10 additions & 0 deletions src/main/scala/Tree.scala
@@ -66,6 +66,16 @@ object Tree {
     val Closed(tree) = stack.top
     tree
   }
+
+  def nodeTypes(tree: Tree): Set[(String, Int)] = tree match {
+    case Node(children, label) => children.map(nodeTypes(_)).fold(Set())((a, b) => a union b) + ((label, children.length))
+    case Leaf(word, label) => Set()
+  }
+
+  def leafTypes(tree: Tree): Set[(String, String)] = tree match {
+    case Node(children, label) => children.map(leafTypes(_)).fold(Set())((a, b) => a union b)
+    case Leaf(word, label) => Set((word, label))
+  }
 }
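These two helpers enumerate every (label, arity) and (word, label) pair in the corpus so that `RNN.initializeMaps` can pre-allocate all combinator matrices and word vectors before training starts. A hedged example, assuming `Node(children, label)` takes a `List[Tree]`:

```scala
val tree: Tree = Node(List(Leaf("dog", "NN"), Leaf("barks", "VBZ")), "S")
Tree.nodeTypes(tree) // Set(("S", 2))
Tree.leafTypes(tree) // Set(("dog", "NN"), ("barks", "VBZ"))
```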


