-
Notifications
You must be signed in to change notification settings - Fork 175
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Improved Global Dead Code Elimination (#549)
Performs DCE by constructing a global dependency graph starting with top-level outputs, external module ports, and simulation constructs as circuit sinks. External modules can optionally be eligible for DCE via the OptimizableExtModuleAnnotation. Dead code is eliminated across module boundaries. Wires, ports, registers, and memories are all eligible for removal. Components marked with a DontTouchAnnotation will be treated as a circuit sink and thus anything that drives such a marked component will NOT be removed. This transform preserves deduplication. All instances of a given DefModule are treated as the same individual module. Thus, while certain instances may have dead code due to the circumstances of their instantiation in their parent module, they will still not be removed. To remove such modules, use the NoDedupAnnotation to prevent deduplication.
- Loading branch information
1 parent
fba12e0
commit cf22636
Showing
9 changed files
with
813 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
296 changes: 296 additions & 0 deletions
296
src/main/scala/firrtl/transforms/DeadCodeElimination.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,296 @@ | ||
|
||
package firrtl.transforms | ||
|
||
import firrtl._ | ||
import firrtl.ir._ | ||
import firrtl.passes._ | ||
import firrtl.annotations._ | ||
import firrtl.graph._ | ||
import firrtl.analyses.InstanceGraph | ||
import firrtl.Mappers._ | ||
import firrtl.WrappedExpression._ | ||
import firrtl.Utils.{throwInternalError, toWrappedExpression, kind} | ||
import firrtl.MemoizedHash._ | ||
import wiring.WiringUtils.getChildrenMap | ||
|
||
import collection.mutable | ||
import java.io.{File, FileWriter} | ||
|
||
/** Dead Code Elimination (DCE) | ||
* | ||
* Performs DCE by constructing a global dependency graph starting with top-level outputs, external | ||
* module ports, and simulation constructs as circuit sinks. External modules can optionally be | ||
* eligible for DCE via the [[OptimizableExtModuleAnnotation]]. | ||
* | ||
* Dead code is eliminated across module boundaries. Wires, ports, registers, and memories are all | ||
* eligible for removal. Components marked with a [[DontTouchAnnotation]] will be treated as a | ||
* circuit sink and thus anything that drives such a marked component will NOT be removed. | ||
* | ||
* This transform preserves deduplication. All instances of a given [[DefModule]] are treated as | ||
* the same individual module. Thus, while certain instances may have dead code due to the | ||
* circumstances of their instantiation in their parent module, they will still not be removed. To | ||
* remove such modules, use the [[NoDedupAnnotation]] to prevent deduplication. | ||
*/ | ||
class DeadCodeElimination extends Transform { | ||
def inputForm = LowForm | ||
def outputForm = LowForm | ||
|
||
/** Based on LogicNode ins CheckCombLoops, currently kind of faking it */ | ||
private type LogicNode = MemoizedHash[WrappedExpression] | ||
private object LogicNode { | ||
def apply(moduleName: String, expr: Expression): LogicNode = | ||
WrappedExpression(Utils.mergeRef(WRef(moduleName), expr)) | ||
def apply(moduleName: String, name: String): LogicNode = apply(moduleName, WRef(name)) | ||
def apply(component: ComponentName): LogicNode = { | ||
// Currently only leaf nodes are supported TODO implement | ||
val loweredName = LowerTypes.loweredName(component.name.split('.')) | ||
apply(component.module.name, WRef(loweredName)) | ||
} | ||
/** External Modules are representated as a single node driven by all inputs and driving all | ||
* outputs | ||
*/ | ||
def apply(ext: ExtModule): LogicNode = LogicNode(ext.name, ext.name) | ||
} | ||
|
||
/** Expression used to represent outputs in the circuit (# is illegal in names) */ | ||
private val circuitSink = LogicNode("#Top", "#Sink") | ||
|
||
/** Extract all References and SubFields from a possibly nested Expression */ | ||
def extractRefs(expr: Expression): Seq[Expression] = { | ||
val refs = mutable.ArrayBuffer.empty[Expression] | ||
def rec(e: Expression): Expression = { | ||
e match { | ||
case ref @ (_: WRef | _: WSubField) => refs += ref | ||
case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested map rec | ||
case ignore @ (_: Literal) => // Do nothing | ||
case unexpected => throwInternalError | ||
} | ||
e | ||
} | ||
rec(expr) | ||
refs | ||
} | ||
|
||
// Gets all dependencies and constructs LogicNodes from them | ||
private def getDepsImpl(mname: String, | ||
instMap: collection.Map[String, String]) | ||
(expr: Expression): Seq[LogicNode] = | ||
extractRefs(expr).map { e => | ||
if (kind(e) == InstanceKind) { | ||
val (inst, tail) = Utils.splitRef(e) | ||
LogicNode(instMap(inst.name), tail) | ||
} else { | ||
LogicNode(mname, e) | ||
} | ||
} | ||
|
||
|
||
/** Construct the dependency graph within this module */ | ||
private def setupDepGraph(depGraph: MutableDiGraph[LogicNode], | ||
instMap: collection.Map[String, String]) | ||
(mod: Module): Unit = { | ||
def getDeps(expr: Expression): Seq[LogicNode] = getDepsImpl(mod.name, instMap)(expr) | ||
|
||
def onStmt(stmt: Statement): Unit = stmt match { | ||
case DefRegister(_, name, _, clock, reset, init) => | ||
val node = LogicNode(mod.name, name) | ||
depGraph.addVertex(node) | ||
Seq(clock, reset, init).flatMap(getDeps(_)).foreach(ref => depGraph.addEdge(node, ref)) | ||
case DefNode(_, name, value) => | ||
val node = LogicNode(mod.name, name) | ||
depGraph.addVertex(node) | ||
getDeps(value).foreach(ref => depGraph.addEdge(node, ref)) | ||
case DefWire(_, name, _) => | ||
depGraph.addVertex(LogicNode(mod.name, name)) | ||
case mem: DefMemory => | ||
// Treat DefMems as a node with outputs depending on the node and node depending on inputs | ||
// From perpsective of the module or instance, MALE expressions are inputs, FEMALE are outputs | ||
val memRef = WRef(mem.name, MemPortUtils.memType(mem), ExpKind, FEMALE) | ||
val exprs = Utils.create_exps(memRef).groupBy(Utils.gender(_)) | ||
val sources = exprs.getOrElse(MALE, List.empty).flatMap(getDeps(_)) | ||
val sinks = exprs.getOrElse(FEMALE, List.empty).flatMap(getDeps(_)) | ||
val memNode = getDeps(memRef) match { case Seq(node) => node } | ||
depGraph.addVertex(memNode) | ||
sinks.foreach(sink => depGraph.addEdge(sink, memNode)) | ||
sources.foreach(source => depGraph.addEdge(memNode, source)) | ||
case Attach(_, exprs) => // Add edge between each expression | ||
exprs.flatMap(getDeps(_)).toSet.subsets(2).map(_.toList).foreach { | ||
case Seq(a, b) => | ||
depGraph.addEdge(a, b) | ||
depGraph.addEdge(b, a) | ||
} | ||
case Connect(_, loc, expr) => | ||
// This match enforces the low Firrtl requirement of expanded connections | ||
val node = getDeps(loc) match { case Seq(elt) => elt } | ||
getDeps(expr).foreach(ref => depGraph.addEdge(node, ref)) | ||
// Simulation constructs are treated as top-level outputs | ||
case Stop(_,_, clk, en) => | ||
Seq(clk, en).flatMap(getDeps(_)).foreach(ref => depGraph.addEdge(circuitSink, ref)) | ||
case Print(_, _, args, clk, en) => | ||
(args :+ clk :+ en).flatMap(getDeps(_)).foreach(ref => depGraph.addEdge(circuitSink, ref)) | ||
case Block(stmts) => stmts.foreach(onStmt(_)) | ||
case ignore @ (_: IsInvalid | _: WDefInstance | EmptyStmt) => // do nothing | ||
case other => throw new Exception(s"Unexpected Statement $other") | ||
} | ||
|
||
// Add all ports as vertices | ||
mod.ports.foreach { | ||
case Port(_, name, _, _: GroundType) => depGraph.addVertex(LogicNode(mod.name, name)) | ||
case other => throwInternalError | ||
} | ||
onStmt(mod.body) | ||
} | ||
|
||
// TODO Make immutable? | ||
private def createDependencyGraph(instMaps: collection.Map[String, collection.Map[String, String]], | ||
doTouchExtMods: Set[String], | ||
c: Circuit): MutableDiGraph[LogicNode] = { | ||
val depGraph = new MutableDiGraph[LogicNode] | ||
c.modules.foreach { | ||
case mod: Module => setupDepGraph(depGraph, instMaps(mod.name))(mod) | ||
case ext: ExtModule => | ||
// Connect all inputs to all outputs | ||
val node = LogicNode(ext) | ||
ext.ports.foreach { | ||
case Port(_, pname, _, AnalogType(_)) => | ||
depGraph.addEdge(LogicNode(ext.name, pname), node) | ||
depGraph.addEdge(node, LogicNode(ext.name, pname)) | ||
case Port(_, pname, Output, _) => | ||
val portNode = LogicNode(ext.name, pname) | ||
depGraph.addEdge(portNode, node) | ||
// Don't touch external modules *unless* they are specifically marked as doTouch | ||
if (!doTouchExtMods.contains(ext.name)) depGraph.addEdge(circuitSink, portNode) | ||
case Port(_, pname, Input, _) => depGraph.addEdge(node, LogicNode(ext.name, pname)) | ||
} | ||
} | ||
// Connect circuitSink to ALL top-level ports (we don't want to change the top-level interface) | ||
val topModule = c.modules.find(_.name == c.main).get | ||
val topOutputs = topModule.ports.foreach { port => | ||
depGraph.addEdge(circuitSink, LogicNode(c.main, port.name)) | ||
} | ||
|
||
depGraph | ||
} | ||
|
||
private def deleteDeadCode(instMap: collection.Map[String, String], | ||
deadNodes: Set[LogicNode], | ||
moduleMap: collection.Map[String, DefModule], | ||
renames: RenameMap) | ||
(mod: DefModule): Option[DefModule] = { | ||
def getDeps(expr: Expression): Seq[LogicNode] = getDepsImpl(mod.name, instMap)(expr) | ||
|
||
var emptyBody = true | ||
renames.setModule(mod.name) | ||
|
||
def onStmt(stmt: Statement): Statement = { | ||
val stmtx = stmt match { | ||
case inst: WDefInstance => | ||
moduleMap.get(inst.module) match { | ||
case Some(instMod) => inst.copy(tpe = Utils.module_type(instMod)) | ||
case None => | ||
renames.delete(inst.name) | ||
EmptyStmt | ||
} | ||
case decl: IsDeclaration => | ||
val node = LogicNode(mod.name, decl.name) | ||
if (deadNodes.contains(node)) { | ||
renames.delete(decl.name) | ||
EmptyStmt | ||
} | ||
else decl | ||
case con: Connect => | ||
val node = getDeps(con.loc) match { case Seq(elt) => elt } | ||
if (deadNodes.contains(node)) EmptyStmt else con | ||
case Attach(info, exprs) => // If any exprs are dead then all are | ||
val dead = exprs.flatMap(getDeps(_)).forall(deadNodes.contains(_)) | ||
if (dead) EmptyStmt else Attach(info, exprs) | ||
case block: Block => block map onStmt | ||
case other => other | ||
} | ||
stmtx match { // Check if module empty | ||
case EmptyStmt | _: Block => | ||
case other => emptyBody = false | ||
} | ||
stmtx | ||
} | ||
|
||
val (deadPorts, portsx) = mod.ports.partition(p => deadNodes.contains(LogicNode(mod.name, p.name))) | ||
deadPorts.foreach(p => renames.delete(p.name)) | ||
|
||
mod match { | ||
case Module(info, name, _, body) => | ||
val bodyx = onStmt(body) | ||
if (emptyBody && portsx.isEmpty) None else Some(Module(info, name, portsx, bodyx)) | ||
case ext: ExtModule => | ||
if (portsx.isEmpty) None | ||
else { | ||
if (ext.ports != portsx) throwInternalError // Sanity check | ||
Some(ext.copy(ports = portsx)) | ||
} | ||
} | ||
|
||
} | ||
|
||
def run(state: CircuitState, | ||
dontTouches: Seq[LogicNode], | ||
doTouchExtMods: Set[String]): CircuitState = { | ||
val c = state.circuit | ||
val moduleMap = c.modules.map(m => m.name -> m).toMap | ||
val iGraph = new InstanceGraph(c) | ||
val moduleDeps = iGraph.graph.edges.map { case (k,v) => | ||
k.module -> v.map(i => i.name -> i.module).toMap | ||
} | ||
val topoSortedModules = iGraph.graph.transformNodes(_.module).linearize.reverse.map(moduleMap(_)) | ||
|
||
val depGraph = { | ||
val dGraph = createDependencyGraph(moduleDeps, doTouchExtMods, c) | ||
for (dontTouch <- dontTouches) { | ||
dGraph.getVertices.find(_ == dontTouch) match { | ||
case Some(node) => dGraph.addEdge(circuitSink, node) | ||
case None => | ||
val (root, tail) = Utils.splitRef(dontTouch.e1) | ||
DontTouchAnnotation.errorNotFound(root.serialize, tail.serialize) | ||
} | ||
} | ||
DiGraph(dGraph) | ||
} | ||
|
||
val liveNodes = depGraph.reachableFrom(circuitSink) + circuitSink | ||
val deadNodes = depGraph.getVertices -- liveNodes | ||
val renames = RenameMap() | ||
renames.setCircuit(c.main) | ||
|
||
// As we delete deadCode, we will delete ports from Modules and somtimes complete modules | ||
// themselves. We iterate over the modules in a topological order from leaves to the top. The | ||
// current status of the modulesxMap is used to either delete instances or update their types | ||
val modulesxMap = mutable.HashMap.empty[String, DefModule] | ||
topoSortedModules.foreach { case mod => | ||
deleteDeadCode(moduleDeps(mod.name), deadNodes, modulesxMap, renames)(mod) match { | ||
case Some(m) => modulesxMap += m.name -> m | ||
case None => renames.delete(ModuleName(mod.name, CircuitName(c.main))) | ||
} | ||
} | ||
|
||
// Preserve original module order | ||
val newCircuit = c.copy(modules = c.modules.flatMap(m => modulesxMap.get(m.name))) | ||
|
||
state.copy(circuit = newCircuit, renames = Some(renames)) | ||
} | ||
|
||
def execute(state: CircuitState): CircuitState = { | ||
val (dontTouches: Seq[LogicNode], doTouchExtMods: Seq[String]) = | ||
state.annotations match { | ||
case Some(aMap) => | ||
// TODO Do with single walk over annotations | ||
val dontTouches = aMap.annotations.collect { | ||
case DontTouchAnnotation(component) => LogicNode(component) | ||
} | ||
val optExtMods = aMap.annotations.collect { | ||
case OptimizableExtModuleAnnotation(ModuleName(name, _)) => name | ||
} | ||
(dontTouches, optExtMods) | ||
case None => (Seq.empty, Seq.empty) | ||
} | ||
run(state, dontTouches, doTouchExtMods.toSet) | ||
} | ||
} |
48 changes: 48 additions & 0 deletions
48
src/main/scala/firrtl/transforms/OptimizationAnnotations.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
|
||
package firrtl | ||
package transforms | ||
|
||
import firrtl.annotations._ | ||
import firrtl.passes.PassException | ||
|
||
/** A component that should be preserved | ||
* | ||
* DCE treats the component as a top-level sink of the circuit | ||
*/ | ||
object DontTouchAnnotation { | ||
private val marker = "DONTtouch!" | ||
def apply(target: ComponentName): Annotation = Annotation(target, classOf[Transform], marker) | ||
|
||
def unapply(a: Annotation): Option[ComponentName] = a match { | ||
case Annotation(component: ComponentName, _, value) if value == marker => Some(component) | ||
case _ => None | ||
} | ||
|
||
class DontTouchNotFoundException(module: String, component: String) extends PassException( | ||
s"Component marked DONT Touch ($module.$component) not found!\n" + | ||
"Perhaps it is an aggregate type? Currently only leaf components are supported.\n" + | ||
"Otherwise it was probably accidentally deleted. Please check that your custom passes are not" + | ||
"responsible and then file an issue on Github." | ||
) | ||
|
||
def errorNotFound(module: String, component: String) = | ||
throw new DontTouchNotFoundException(module, component) | ||
} | ||
|
||
/** An [[firrtl.ir.ExtModule]] that can be optimized | ||
* | ||
* Firrtl does not know the semantics of an external module. This annotation provides some | ||
* "greybox" information that the external module does not have any side effects. In particular, | ||
* this means that the external module can be Dead Code Eliminated. | ||
* | ||
* @note Unlike [[DontTouchAnnotation]], we don't care if the annotation is deleted | ||
*/ | ||
object OptimizableExtModuleAnnotation { | ||
private val marker = "optimizableExtModule!" | ||
def apply(target: ModuleName): Annotation = Annotation(target, classOf[Transform], marker) | ||
|
||
def unapply(a: Annotation): Option[ModuleName] = a match { | ||
case Annotation(component: ModuleName, _, value) if value == marker => Some(component) | ||
case _ => None | ||
} | ||
} |
Oops, something went wrong.