Skip to content

Commit

Permalink
remove partial connect
Browse files Browse the repository at this point in the history
  • Loading branch information
ekiwi committed Aug 21, 2023
1 parent 271512b commit 272f959
Show file tree
Hide file tree
Showing 22 changed files with 19 additions and 141 deletions.
1 change: 0 additions & 1 deletion src/main/antlr4/FIRRTL.g4
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ stmt
| 'inst' id 'of' id info?
| 'node' id '=' exp info?
| ref '<=' exp info?
| ref '<-' exp info?
| ref 'is' 'invalid' info?
| when
| 'stop(' exp exp intLit ')' stmtName? info?
Expand Down
21 changes: 10 additions & 11 deletions src/main/scala/firrtl2/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -630,17 +630,16 @@ object Utils extends LazyLogging {
case ex => throwInternalError(s"flow: shouldn't be here - $e")
}
def get_flow(s: Statement): Flow = s match {
case sx: DefWire => DuplexFlow
case sx: DefRegister => DuplexFlow
case sx: DefNode => SourceFlow
case sx: DefInstance => SourceFlow
case sx: DefMemory => SourceFlow
case sx: Block => UnknownFlow
case sx: Connect => UnknownFlow
case sx: PartialConnect => UnknownFlow
case sx: Stop => UnknownFlow
case sx: Print => UnknownFlow
case sx: IsInvalid => UnknownFlow
case sx: DefWire => DuplexFlow
case sx: DefRegister => DuplexFlow
case sx: DefNode => SourceFlow
case sx: DefInstance => SourceFlow
case sx: DefMemory => SourceFlow
case sx: Block => UnknownFlow
case sx: Connect => UnknownFlow
case sx: Stop => UnknownFlow
case sx: Print => UnknownFlow
case sx: IsInvalid => UnknownFlow
case EmptyStmt => UnknownFlow
}
def get_flow(p: Port): Flow = if (p.direction == Input) SourceFlow else SinkFlow
Expand Down
1 change: 0 additions & 1 deletion src/main/scala/firrtl2/Visitor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,6 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w
case _ =>
ctx.getChild(1).getText match {
case "<=" => Connect(info, visitRef(ctx.ref), visitExp(ctx_exp(0)))
case "<-" => PartialConnect(info, visitRef(ctx.ref), visitExp(ctx_exp(0)))
case "is" => IsInvalid(info, visitRef(ctx.ref))
case "mport" =>
CDefMPort(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,6 @@ private class ModuleToTransitionSystem(
}
case s: ir.Conditionally =>
error(s"When conditions are not supported. Please run ExpandWhens: ${s.serialize}")
case s: ir.PartialConnect =>
error(s"PartialConnects are not supported. Please run ExpandConnects: ${s.serialize}")
case s: ir.Attach =>
error(s"Analog wires are not supported in the SMT backend: ${s.serialize}")
case s: ir.Stop =>
Expand Down
18 changes: 2 additions & 16 deletions src/main/scala/firrtl2/ir/IR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ abstract class Expression extends FirrtlNode {
/** Represents reference-like expression nodes: SubField, SubIndex, SubAccess and Reference
* The following fields can be cast to RefLikeExpression in every well formed firrtl AST:
* - SubField.expr, SubIndex.expr, SubAccess.expr
* - IsInvalid.expr, Connect.loc, PartialConnect.loc
* - IsInvalid.expr, Connect.loc
* - Attach.exprs
*/
sealed trait RefLikeExpression extends Expression { def flow: Flow }
Expand Down Expand Up @@ -667,21 +667,7 @@ case class Block(stmts: Seq[Statement]) extends Statement with UseSerializer {
def foreachString(f: String => Unit): Unit = ()
def foreachInfo(f: Info => Unit): Unit = ()
}
case class PartialConnect(info: Info, loc: Expression, expr: Expression)
extends Statement
with HasInfo
with UseSerializer {
def mapStmt(f: Statement => Statement): Statement = this
def mapExpr(f: Expression => Expression): Statement = PartialConnect(info, f(loc), f(expr))
def mapType(f: Type => Type): Statement = this
def mapString(f: String => String): Statement = this
def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
def foreachStmt(f: Statement => Unit): Unit = ()
def foreachExpr(f: Expression => Unit): Unit = { f(loc); f(expr) }
def foreachType(f: Type => Unit): Unit = ()
def foreachString(f: String => Unit): Unit = ()
def foreachInfo(f: Info => Unit): Unit = f(info)
}

case class Connect(info: Info, loc: Expression, expr: Expression) extends Statement with HasInfo with UseSerializer {
def mapStmt(f: Statement => Statement): Statement = this
def mapExpr(f: Expression => Expression): Statement = Connect(info, f(loc), f(expr))
Expand Down
3 changes: 1 addition & 2 deletions src/main/scala/firrtl2/ir/Serializer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,7 @@ object Serializer {
writers.foreach { w => b ++= "writer => "; b ++= w; newLineAndIndent(1) }
readwriters.foreach { r => b ++= "readwriter => "; b ++= r; newLineAndIndent(1) }
b ++= "read-under-write => "; b ++= readUnderWrite.toString
case PartialConnect(info, loc, expr) => s(loc); b ++= " <- "; s(expr); s(info)
case Attach(info, exprs) =>
case Attach(info, exprs) =>
// exprs should never be empty since the attach statement takes *at least* two signals according to the spec
b ++= "attach ("; s(exprs, ", "); b += ')'; s(info)
case veri @ Verification(op, info, clk, pred, en, msg, _) =>
Expand Down
1 change: 0 additions & 1 deletion src/main/scala/firrtl2/ir/StructuralHash.scala
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,6 @@ class StructuralHash private (h: Hasher, renameModule: String => String) {
hash(writers.length); writers.foreach(hash)
hash(readwriters.length); readwriters.foreach(hash)
hash(readUnderWrite)
case PartialConnect(_, loc, expr) => id(31); hash(loc); hash(expr)
case Attach(_, exprs) => id(32); hash(exprs.length); exprs.foreach(hash)
// WIR
case firrtl2.CDefMemory(_, name, tpe, size, seq, readUnderWrite) =>
Expand Down
4 changes: 0 additions & 4 deletions src/main/scala/firrtl2/passes/CInferMDir.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,6 @@ object CInferMDir extends Pass {
infer_mdir_e(mports, MRead)(sx.expr)
infer_mdir_e(mports, MWrite)(sx.loc)
sx
case sx: PartialConnect =>
infer_mdir_e(mports, MRead)(sx.expr)
infer_mdir_e(mports, MWrite)(sx.loc)
sx
case sx => sx.map(infer_mdir_s(mports)).map(infer_mdir_e(mports, MRead))
}

Expand Down
3 changes: 0 additions & 3 deletions src/main/scala/firrtl2/passes/CheckFlows.scala
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,6 @@ object CheckFlows extends Pass {
s.args.foreach(check_flow(info, mname, flows, SourceFlow))
check_flow(info, mname, flows, SourceFlow)(s.en)
check_flow(info, mname, flows, SourceFlow)(s.clk)
case (s: PartialConnect) =>
check_flow(info, mname, flows, SinkFlow)(s.loc)
check_flow(info, mname, flows, SourceFlow)(s.expr)
case (s: Conditionally) =>
check_flow(info, mname, flows, SourceFlow)(s.pred)
case (s: Stop) =>
Expand Down
7 changes: 3 additions & 4 deletions src/main/scala/firrtl2/passes/CheckHighForm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -299,10 +299,9 @@ trait CheckHighFormLike { this: Pass =>
errors.append(new MemWithFlipException(info, mname, sx.name))
if (sx.depth <= 0)
errors.append(new NegMemSizeException(info, mname))
case sx: DefInstance => checkInstance(info, mname, sx.module)
case sx: Connect => checkValidLoc(info, mname, sx.loc)
case sx: PartialConnect => checkValidLoc(info, mname, sx.loc)
case sx: Print => checkFstring(info, mname, sx.string, sx.args.length)
case sx: DefInstance => checkInstance(info, mname, sx.module)
case sx: Connect => checkValidLoc(info, mname, sx.loc)
case sx: Print => checkFstring(info, mname, sx.string, sx.args.length)
case _: CDefMemory => errorOnChirrtl(info, mname, s).foreach { e => errors.append(e) }
case mport: CDefMPort =>
errorOnChirrtl(info, mname, s).foreach { e => errors.append(e) }
Expand Down
6 changes: 0 additions & 6 deletions src/main/scala/firrtl2/passes/CheckTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,6 @@ object CheckTypes extends Pass {

def validConnect(con: Connect): Boolean = validConnect(con.loc.tpe, con.expr.tpe)

def validPartialConnect(con: PartialConnect): Boolean =
bulk_equals(con.loc.tpe, con.expr.tpe, Default, Default)

//;---------------- Helper Functions --------------
private val UIntUnknown = UIntType(UnknownWidth)
def ut: UIntType = UIntUnknown
Expand Down Expand Up @@ -366,9 +363,6 @@ object CheckTypes extends Pass {
case sx: Connect if !validConnect(sx) =>
val conMsg = sx.copy(info = NoInfo).serialize
errors.append(new InvalidConnect(info, mname, conMsg, sx.loc, sx.expr))
case sx: PartialConnect if !validPartialConnect(sx) =>
val conMsg = sx.copy(info = NoInfo).serialize
errors.append(new InvalidConnect(info, mname, conMsg, sx.loc, sx.expr))
case sx: DefRegister =>
sx.tpe match {
case AnalogType(_) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name))
Expand Down
4 changes: 0 additions & 4 deletions src/main/scala/firrtl2/passes/ConvertFixedToSInt.scala
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,6 @@ object ConvertFixedToSInt extends Pass {
val point = calcPoint(Seq(loc))
val newExp = alignArg(exp, point)
Connect(info, loc, newExp).map(updateExpType)
case PartialConnect(info, loc, exp) =>
val point = calcPoint(Seq(loc))
val newExp = alignArg(exp, point)
PartialConnect(info, loc, newExp).map(updateExpType)
// check Connect case, need to shl
case s => (s.map(updateStmtType)).map(updateExpType)
}
Expand Down
16 changes: 0 additions & 16 deletions src/main/scala/firrtl2/passes/ExpandConnects.scala
Original file line number Diff line number Diff line change
Expand Up @@ -62,22 +62,6 @@ object ExpandConnects extends Pass {
case Flip => Connect(sx.info, expx, locx)
}
})
case sx: PartialConnect =>
val ls = get_valid_points(sx.loc.tpe, sx.expr.tpe, Default, Default)
val locs = create_exps(sx.loc)
val exps = create_exps(sx.expr)
val stmts = ls.map {
case (x, y) =>
locs(x).tpe match {
case AnalogType(_) => Attach(sx.info, Seq(locs(x), exps(y)))
case _ =>
to_flip(flow(locs(x))) match {
case Default => Connect(sx.info, locs(x), exps(y))
case Flip => Connect(sx.info, exps(y), locs(x))
}
}
}
Block(stmts)
case sx => sx.map(expand_s)
}
}
Expand Down
14 changes: 0 additions & 14 deletions src/main/scala/firrtl2/passes/InferBinaryPoints.scala
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,6 @@ class InferBinaryPoints extends Pass {
}
}
c
case pc: PartialConnect =>
val ls = get_valid_points(pc.loc.tpe, pc.expr.tpe, Default, Default)
val locs = create_exps(pc.loc)
val exps = create_exps(pc.expr)
ls.foreach {
case (x, y) =>
val loc = locs(x)
val exp = exps(y)
to_flip(flow(loc)) match {
case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
}
}
pc
case r: DefRegister =>
addTypeConstraints(mt.ref(r.name), Target.asTarget(mt)(r.init))(r.tpe, r.init.tpe)
r
Expand Down
14 changes: 0 additions & 14 deletions src/main/scala/firrtl2/passes/InferWidths.scala
Original file line number Diff line number Diff line change
Expand Up @@ -139,20 +139,6 @@ class InferWidths extends Transform with ResolvedAnnotationPaths with Dependency
}
}
c
case pc: PartialConnect =>
val ls = get_valid_points(pc.loc.tpe, pc.expr.tpe, Default, Default)
val locs = create_exps(pc.loc)
val exps = create_exps(pc.expr)
ls.foreach {
case (x, y) =>
val loc = locs(x)
val exp = exps(y)
to_flip(flow(loc)) match {
case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
}
}
pc
case r: DefRegister =>
if (r.reset.tpe != AsyncResetType) {
addTypeConstraints(Target.asTarget(mt)(r.reset), mt.ref("1"))(r.reset.tpe, UIntType(IntWidth(1)))
Expand Down
2 changes: 0 additions & 2 deletions src/main/scala/firrtl2/passes/LowerTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,6 @@ object LowerTypes extends Transform with DependencyAPIMigration {
// We can get multiple refs on the lhs because of ground-type memory ports like "clk" which can get duplicated.
val lhs = symbols.getReferences(loc.asInstanceOf[RefLikeExpression])
Block(lhs.map(loc => Connect(info, loc, rhs)))
case p: PartialConnect =>
throw new RuntimeException(s"LowerTypes expects PartialConnects to be resolved! $p")
case IsInvalid(info, expr) =>
if (!expr.tpe.isInstanceOf[GroundType]) {
throw new RuntimeException(s"LowerTypes expects IsInvalids to have been expanded! ${expr.tpe.serialize}")
Expand Down
19 changes: 0 additions & 19 deletions src/main/scala/firrtl2/passes/RemoveCHIRRTL.scala
Original file line number Diff line number Diff line change
Expand Up @@ -276,25 +276,6 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
}
}
if (stmts.isEmpty) sx else Block(sx +: stmts.toSeq)
case PartialConnect(info, loc, expr) =>
val locx = remove_chirrtl_e(SinkFlow)(loc)
val rocx = remove_chirrtl_e(SourceFlow)(expr)
val sx = PartialConnect(info, locx, rocx)
val stmts = ArrayBuffer[Statement]()
has_read_mport match {
case None =>
case Some(en) => stmts += Connect(info, en, one)
}
if (has_write_mport) {
val ls = get_valid_points(loc.tpe, expr.tpe, Default, Default)
val locs = create_exps(get_mask(refs)(loc))
stmts ++= (ls.map { case (x, _) => Connect(info, locs(x), one) })
has_readwrite_mport match {
case None =>
case Some(wmode) => stmts += Connect(info, wmode, one)
}
}
if (stmts.isEmpty) sx else Block(sx +: stmts.toSeq)
case sx => sx.map(remove_chirrtl_s(refs, raddrs)).map(remove_chirrtl_e(SourceFlow))
}
}
Expand Down
2 changes: 0 additions & 2 deletions src/main/scala/firrtl2/passes/ResolveFlows.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ object ResolveFlows extends Pass {
IsInvalid(info, resolve_e(SinkFlow)(expr))
case Connect(info, loc, expr) =>
Connect(info, resolve_e(SinkFlow)(loc), resolve_e(SourceFlow)(expr))
case PartialConnect(info, loc, expr) =>
PartialConnect(info, resolve_e(SinkFlow)(loc), resolve_e(SourceFlow)(expr))
case sx => sx.map(resolve_e(SourceFlow)).map(resolve_s)
}

Expand Down
5 changes: 0 additions & 5 deletions src/main/scala/firrtl2/passes/TrimIntervals.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,6 @@ class TrimIntervals extends Pass {
case IntervalType(_, _, p) => Connect(info, loc, fixBP(p)(expr))
case _ => c
}
case c @ PartialConnect(info, loc, expr) =>
loc.tpe match {
case IntervalType(_, _, p) => PartialConnect(info, loc, fixBP(p)(expr))
case _ => c
}
case other => other.map(alignStmtBP)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,9 @@ class SeparateWriteClocks extends Transform with DependencyAPIMigration {
}
replaceExprs ++= clockWireMap.map { case (pClk, clkWire) => pClk -> Reference(clkWire) }
Block(mem +: clockStmts)
case Connect(i, lhs, rhs) => Connect(i, onExpr(replaceExprs)(lhs), rhs)
case PartialConnect(i, lhs, rhs) => PartialConnect(i, onExpr(replaceExprs)(lhs), rhs)
case IsInvalid(i, invalidated) => IsInvalid(i, onExpr(replaceExprs)(invalidated))
case s => s.mapStmt(onStmt(replaceExprs, ns))
case Connect(i, lhs, rhs) => Connect(i, onExpr(replaceExprs)(lhs), rhs)
case IsInvalid(i, invalidated) => IsInvalid(i, onExpr(replaceExprs)(invalidated))
case s => s.mapStmt(onStmt(replaceExprs, ns))
}

override def execute(state: CircuitState): CircuitState = {
Expand Down
3 changes: 0 additions & 3 deletions src/main/scala/firrtl2/transforms/Dedup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -436,9 +436,6 @@ object DedupModules extends LazyLogging {
case Connect(_, lhs, rhs) =>
markAggregatePorts(lhs)
markAggregatePorts(rhs)
case PartialConnect(_, lhs, rhs) =>
markAggregatePorts(lhs)
markAggregatePorts(rhs)
case _ =>
}
}
Expand Down
7 changes: 0 additions & 7 deletions src/main/scala/firrtl2/transforms/InferResets.scala
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,6 @@ class InferResets extends Transform with DependencyAPIMigration {
for ((loc, exp) <- locs.zip(exps)) {
markResetDriver(loc, exp)
}
case PartialConnect(_, lhs, rhs) =>
val points = Utils.get_valid_points(lhs.tpe, rhs.tpe, Default, Default)
val locs = Utils.create_exps(lhs)
val exps = Utils.create_exps(rhs)
for ((i, j) <- points) {
markResetDriver(locs(i), exps(j))
}
case IsInvalid(_, lhs) =>
val exprs = Utils.create_exps(lhs)
for (expr <- exprs) {
Expand Down

0 comments on commit 272f959

Please sign in to comment.