Skip to content

Commit

Permalink
Merge pull request #361 from hkapp/topic/optimizations/gvn
Browse files Browse the repository at this point in the history
Add GlobalValueNumbering pass
  • Loading branch information
densh committed Nov 2, 2016
2 parents 399b21b + 2436578 commit 473cac1
Show file tree
Hide file tree
Showing 6 changed files with 300 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ final class Compiler(opts: Opts) {
private lazy val passCompanions: Seq[PassCompanion] = Seq(
pass.GlobalBoxingElimination,
pass.DeadCodeElimination,
pass.GlobalValueNumbering,
pass.MainInjection,
pass.ExternHoisting,
pass.ModuleLowering,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@ object ControlFlow {
final case class Edge(val from: Block, val to: Block, val next: Next)

final case class Block(name: Local, params: Seq[Val.Local], insts: Seq[Inst]) {
val pred = mutable.UnrolledBuffer.empty[Edge]
val succ = mutable.UnrolledBuffer.empty[Edge]
val inEdges = mutable.UnrolledBuffer.empty[Edge]
val outEdges = mutable.UnrolledBuffer.empty[Edge]

def pred = inEdges.map(_.from)
def succ = outEdges.map(_.to)

def label = Inst.Label(name, params)
}

Expand All @@ -33,7 +37,7 @@ object ControlFlow {
val node = worklist.pop()
if (!visited.contains(node)) {
visited += node
node.succ.foreach(e => worklist.push(e.to))
node.outEdges.foreach(e => worklist.push(e.to))
f(node)
}
}
Expand Down Expand Up @@ -79,8 +83,8 @@ object ControlFlow {

def edge(from: Block, to: Block, next: Next) = {
val e = new Edge(from, to, next)
from.succ += e
to.pred += e
from.outEdges += e
to.inEdges += e
}

val blocks: Seq[Block] = insts.zipWithIndex.collect {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@ import ControlFlow.Block

object DominatorTree {

private def pred(block: Block) = block.pred.map(_.from)
private def succ(block: Block) = block.succ.map(_.to)

/** Fixpoint-based method to build the dominator tree
* from the CFG. The dominator tree is simply represented
* as a Map from a CFG block to the set of blocks dominating
Expand All @@ -24,7 +21,7 @@ object DominatorTree {
val (block, dequeued) = workList.dequeue
workList = dequeued.filterNot(_ == block) // remove duplicates

val visitedPreds = pred(block).filter(domination.contains)
val visitedPreds = block.pred.filter(domination.contains)
val predDomination =
visitedPreds.toList.map(pred => domination.getOrElse(pred, Set.empty))
val correctPredDomination = predDomination.filterNot(
Expand All @@ -44,7 +41,7 @@ object DominatorTree {

if (oldDomination != newDomination) {
domination += (block -> newDomination)
workList ++= succ(block)
workList ++= block.succ
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package scala.scalanative
package compiler
package analysis

import ControlFlow.Block
import nir.Shows._

object Shows {

private def blockToString(block: Block): String =
showLocal(block.name).toString

def showCFG(cfg: ControlFlow.Graph): String = {
cfg.all.map { block =>
val succStr =
block.succ.map(blockToString).mkString("(", ",", ")")
val predStr =
block.pred.map(blockToString).mkString("(", ",", ")")
s"${blockToString(block)} -> ${succStr}, pred = ${predStr}"
}.mkString("\n")
}

def showDominatorTree(domination: Map[Block, Set[Block]]): String = {
domination.toSeq
.sortBy(_._1.name.id)
.map {
case (block, set) =>
s"${blockToString(block)} -> ${set.map(blockToString).mkString("(", ",", ")")}"
}
.mkString("\n")
}

def cfgToDot(cfg: ControlFlow.Graph): String = {
def blockToDot(block: Block): String = {
val successors = block.succ
val blockID = block.name.id
if (successors.nonEmpty)
successors
.map(succ => succ.name.id.toString)
.mkString(s"${blockID} -> {", " ", "};")
else
s"${blockID} [ shape=doublecircle ];"
}

s"""
|digraph {
| node [shape=circle, width=0.6, fixedsize=true];
|${cfg.map(blockToDot).mkString("\n")}
|}
""".stripMargin
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ class GenTextualLLVM(assembly: Seq[Defn]) extends GenShow(assembly) {
val prologue: Show.Result =
if (isEntry) s()
else {
val shows = block.pred match {
val shows = block.inEdges match {
case ExSucc(branches) =>
params.zipWithIndex.map {
case (Val.Local(name, ty), n) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,243 @@ package scala.scalanative
package compiler
package pass

import nir._, Shows._
import util.sh
import scala.collection.mutable
import scala.util.hashing.MurmurHash3

object GlobalValueNumbering {
import analysis.ControlFlow
import analysis.ControlFlow.Block
import analysis.DominatorTree

import nir._

class GlobalValueNumbering extends Pass {
import GlobalValueNumbering._

override def preDefn = {
case defn: Defn.Define =>
val cfg = ControlFlow.Graph(defn.insts)
val domination = DominatorTree.build(cfg)

val newInsts = performSimpleValueNumbering(cfg, domination)

Seq(defn.copy(insts = newInsts))
}

private def performSimpleValueNumbering(
cfg: ControlFlow.Graph,
domination: Map[Block, Set[Block]]): Seq[Inst] = {

val variableVN = mutable.HashMap.empty[Local, Hash]
val instructions = mutable.HashMap.empty[Hash, List[Inst.Let]]
val localDefs = mutable.HashMap.empty[Local, Inst]

val hash = new HashFunction(variableVN)
val deepEquals = new DeepEquals(localDefs)

def blockDominatedByDef(dominatedBlock: Block,
dominatingDef: Local): Boolean = {

domination(dominatedBlock).exists { dominatingBlock =>
val foundInParam = dominatingBlock.params.exists {
case Val.Local(paramName, _) => (paramName == dominatingDef)
}
val foundInInsts = dominatingBlock.insts.exists {
case Inst.Let(name, _) => (name == dominatingDef)
case _ => false
}

foundInParam || foundInInsts
}
}

val newInsts = cfg.map { block =>
variableVN ++= block.params.map(lval =>
(lval.name, HashFunction.rawLocal(lval.name)))
localDefs ++= block.params.map(lval => (lval.name, block.label))

val newBlockInsts = block.insts.map {

case inst: Inst.Let => {
val idempotent = isIdempotent(inst.op)

val instHash =
if (idempotent)
hash(inst.op)
else
inst.hashCode // hash the assigned variable as well, so a = op(b) and c = op(b) don't have the same hash

variableVN += (inst.name -> instHash)
localDefs += (inst.name -> inst)

if (idempotent) {
val hashEqualInstrs = instructions.getOrElse(instHash, Nil)
instructions += (instHash -> (inst :: hashEqualInstrs))

val equalInstrs =
hashEqualInstrs.filter(otherInst =>
deepEquals.eqInst(inst, otherInst))
val redundantInstrs = equalInstrs.filter(eqInst =>
blockDominatedByDef(block, eqInst.name)) // only redundant if the current block is dominated by the block in which the equal instruction occurs

val newInstOpt = redundantInstrs.headOption.map(
redInst =>
Inst.Let(inst.name,
Op.Copy(Val.Local(redInst.name, redInst.op.resty))))
newInstOpt.getOrElse(inst)
} else {
inst
}
}

case otherInst @ _ =>
otherInst
}

block.label +: newBlockInsts
}

newInsts.flatten
}

}

object GlobalValueNumbering extends PassCompanion {
def apply(ctx: Ctx) = new GlobalValueNumbering()

def isIdempotent(op: Op): Boolean = {
import Op._
op match {
// Always idempotent:
case (_: Pure | _: Method | _: As | _: Is | _: Copy | _: Sizeof |
_: Module | _: Field) =>
true

// Never idempotent:
case (_: Load | _: Store | _: Stackalloc | _: Classalloc | _: Call |
_: Closure) =>
false
}
}

class DeepEquals(localDefs: Local => Inst) {

def eqInst(instA: Inst.Let, instB: Inst.Let): Boolean = {
(instA.name == instB.name) || eqOp(instA.op, instB.op)
}

def eqOp(opA: Op, opB: Op): Boolean = {
import Op._
if (!(isIdempotent(opA) && isIdempotent(opB)))
false
else {
(opA, opB) match {

case (Elem(tyA, ptrA, indexesA), Elem(tyB, ptrB, indexesB)) =>
eqType(tyA, tyB) && eqVal(ptrA, ptrB) && eqVals(indexesA, indexesB)

case (Extract(aggrA, indexesA), Extract(aggrB, indexesB)) =>
eqVal(aggrA, aggrB) && (indexesA == indexesB)

case (Insert(aggrA, valueA, indexesA),
Insert(aggrB, valueB, indexesB)) =>
eqVal(aggrA, aggrB) && eqVal(valueA, valueB) && (indexesA == indexesB)

// TODO handle commutativity of some bins
case (Bin(binA, tyA, lA, rA), Bin(binB, tyB, lB, rB)) =>
eqBin(binA, binB) && eqType(tyA, tyB) && eqVal(lA, lB) && eqVal(rA,
rB)

case (Comp(compA, tyA, lA, rA), Comp(compB, tyB, lB, rB)) =>
eqComp(compA, compB) && eqType(tyA, tyB) && eqVal(lA, lB) && eqVal(
rA,
rB)

case (Conv(convA, tyA, valueA), Conv(convB, tyB, valueB)) =>
eqConv(convA, convB) && eqType(tyA, tyB) && eqVal(valueA, valueB)

case (Select(condA, thenvA, elsevA),
Select(condB, thenvB, elsevB)) =>
eqVals(Seq(condA, thenvA, elsevA), Seq(condB, thenvB, elsevB))

case (Field(tyA, objA, nameA), Field(tyB, objB, nameB)) =>
eqType(tyA, tyB) && eqVal(objA, objB) && eqGlobal(nameA, nameB)

case (Method(tyA, objA, nameA), Method(tyB, objB, nameB)) =>
eqType(tyA, tyB) && eqVal(objA, objB) && eqGlobal(nameA, nameB)

case (Module(nameA), Module(nameB)) =>
eqGlobal(nameA, nameB)

case (As(tyA, objA), As(tyB, objB)) =>
eqType(tyA, tyB) && eqVal(objA, objB)

case (Is(tyA, objA), Is(tyB, objB)) =>
eqType(tyA, tyB) && eqVal(objA, objB)

case (Copy(valueA), Copy(valueB)) =>
eqVal(valueA, valueB)

case (Sizeof(tyA), Sizeof(tyB)) =>
eqType(tyA, tyB)

case _ => false // non-matching pairs of ops, or not idempotent ones
}
}
}

def eqVal(valueA: Val, valueB: Val): Boolean = {
import Val._
(valueA, valueB) match {
case (Struct(nameA, valuesA), Struct(nameB, valuesB)) =>
eqGlobal(nameA, nameB) && eqVals(valuesA, valuesB)

case (Array(elemtyA, valuesA), Array(elemtyB, valuesB)) =>
eqType(elemtyA, elemtyB) && eqVals(valuesA, valuesB)

case (Const(valueA), Const(valueB)) =>
eqVal(valueA, valueB)

case (Local(nameA, valtyA), Local(nameB, valtyB)) =>
lazy val eqDefs = (localDefs(nameA), localDefs(nameB)) match {
case (_: Inst.Label, _: Inst.Label) => (nameA == nameB)
case (instA: Inst.Let, instB: Inst.Let) => eqInst(instA, instB)
case _ => false
}
eqType(valtyA, valtyB) && ((nameA == nameB) || eqDefs)

case _ =>
valueA == valueB
}
}

def eqVals(valsA: Seq[Val], valsB: Seq[Val]): Boolean = {
val sizeEqual = (valsA.size == valsB.size)
lazy val contentEqual =
valsA.zip(valsB).forall { case (a, b) => eqVal(a, b) }
sizeEqual && contentEqual
}

def eqType(tyA: Type, tyB: Type): Boolean = {
tyA == tyB
}

def eqGlobal(globalA: Global, globalB: Global): Boolean = {
globalA == globalB
}

def eqBin(binA: Bin, binB: Bin): Boolean = {
binA == binB
}

def eqComp(compA: Comp, compB: Comp): Boolean = {
compA == compB
}

def eqConv(convA: Conv, convB: Conv): Boolean = {
convA == convB
}

}

type Hash = Int

Expand Down

0 comments on commit 473cac1

Please sign in to comment.