Skip to content

Commit

Permalink
Add support for wires in ConstProp
Browse files Browse the repository at this point in the history
This requires a quick second pass to back propagate constant wires but
the QoR win is substantial. We also only need to count back propagations
in determining whether to run ConstProp again which shaves off an
iteration in the common case.
  • Loading branch information
jackkoenig committed Jun 27, 2017
1 parent 0fca90f commit f8572ba
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 7 deletions.
28 changes: 24 additions & 4 deletions src/main/scala/firrtl/passes/ConstProp.scala
Expand Up @@ -234,34 +234,54 @@ object ConstProp extends Pass {
case _ => r
}

// Two pass process
// 1. Propagate constants in expressions and forward propagate references
// 2. Propagate references again for backwards reference (Wires)
// TODO Replacing all wires with nodes makes the second pass unnecessary
@tailrec
private def constPropModule(m: Module): Module = {
var nPropagated = 0L
val nodeMap = collection.mutable.HashMap[String, Expression]()

def backPropExpr(expr: Expression): Expression = {
val old = expr map backPropExpr
val propagated = old match {
case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) =>
constPropNodeRef(ref, nodeMap(rname))
case x => x
}
if (old ne propagated) {
nPropagated += 1
}
propagated
}
def backPropStmt(stmt: Statement): Statement = stmt map backPropStmt map backPropExpr

def constPropExpression(e: Expression): Expression = {
val old = e map constPropExpression
val propagated = old match {
case p: DoPrim => constPropPrim(p)
case m: Mux => constPropMux(m)
case r: WRef if nodeMap contains r.name => constPropNodeRef(r, nodeMap(r.name))
case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) =>
constPropNodeRef(ref, nodeMap(rname))
case x => x
}
if (old ne propagated)
nPropagated += 1
propagated
}

def constPropStmt(s: Statement): Statement = {
val stmtx = s map constPropStmt map constPropExpression
stmtx match {
case x: DefNode => nodeMap(x.name) = x.value
case Connect(_, WRef(wname, wtpe, WireKind, _), expr) =>
val exprx = constPropExpression(pad(expr, wtpe))
nodeMap(wname) = exprx
case _ =>
}
stmtx
}

val res = Module(m.info, m.name, m.ports, constPropStmt(m.body))
val res = Module(m.info, m.name, m.ports, backPropStmt(constPropStmt(m.body)))
if (nPropagated > 0) constPropModule(res) else res
}

Expand Down
7 changes: 4 additions & 3 deletions src/test/scala/firrtlTests/AnnotationTests.scala
Expand Up @@ -272,7 +272,7 @@ class AnnotationTests extends AnnotationSpec with Matchers {
anno("n.a"), anno("n.b[0]"), anno("n.b[1]"),
anno("r.a"), anno("r.b[0]"), anno("r.b[1]"),
anno("write.a"), anno("write.b[0]"), anno("write.b[1]"),
dontTouch("Top.r")
dontTouch("Top.r"), dontTouch("Top.w")
)
val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
val resultAnno = result.annotations.get.annotations
Expand Down Expand Up @@ -326,7 +326,8 @@ class AnnotationTests extends AnnotationSpec with Matchers {
| out <= n
| reg r: {a: UInt<3>, b: UInt<3>[2]}, clk
|""".stripMargin
val annos = Seq(anno("in"), anno("out"), anno("w"), anno("n"), anno("r"), dontTouch("Top.r"))
val annos = Seq(anno("in"), anno("out"), anno("w"), anno("n"), anno("r"), dontTouch("Top.r"),
dontTouch("Top.w"))
val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
val resultAnno = result.annotations.get.annotations
resultAnno should contain (anno("in_a"))
Expand Down Expand Up @@ -362,7 +363,7 @@ class AnnotationTests extends AnnotationSpec with Matchers {
| reg r: {a: UInt<3>, b: UInt<3>[2]}, clk
|""".stripMargin
val annos = Seq(anno("in.b"), anno("out.b"), anno("w.b"), anno("n.b"), anno("r.b"),
dontTouch("Top.r"))
dontTouch("Top.r"), dontTouch("Top.w"))
val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
val resultAnno = result.annotations.get.annotations
resultAnno should contain (anno("in_b_0"))
Expand Down

0 comments on commit f8572ba

Please sign in to comment.