Skip to content

Commit

Permalink
Annotate temporary nodes in when statements
Browse files Browse the repository at this point in the history
  • Loading branch information
rameloni committed Jul 9, 2024
1 parent 2d57816 commit 41e0211
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 28 deletions.
104 changes: 87 additions & 17 deletions core/src/main/scala/chisel3/tywavesinternal/TywavesAnnotation.scala
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package chisel3.tywavesinternal

import chisel3.{Data, Record, Vec, VecLike}
import chisel3.{Data, MemBase, Record, Vec, VecLike}
import chisel3.experimental.{BaseModule, ChiselAnnotation}
import chisel3.internal.HasId
import chisel3.internal.{HasId, NamedComponent}
import chisel3.internal.firrtl.ir._
import chisel3.properties.{DynamicObject, StaticObject}
import firrtl.annotations.{Annotation, IsMember, SingleTargetAnnotation}

import scala.collection.mutable

// TODO: if the code touches a lot of Chisel internals, it might be better to put it into
// - core
// otherwise:
Expand Down Expand Up @@ -45,11 +48,31 @@ private[chisel3] case class TywavesAnnotation[T <: IsMember](
}

object TywavesChiselAnnotation {

private val annoCreated = new mutable.HashSet[IsMember]()
private def createTywavesChiselAnno[T <: IsMember](
target: T,
name: String,
paramsOpt: Option[Seq[ClassParam]]
): Option[ChiselAnnotation] = {

if (annoCreated.contains(target)) {
None
} else {
annoCreated.add(target)
Some(new ChiselAnnotation {
override def toFirrtl: Annotation = TywavesAnnotation(target, name, paramsOpt)
})
}
}

def generate(circuit: Circuit): Seq[ChiselAnnotation] = {
// TODO: iterate over a circuit and generate TywavesAnnotation
val typeAliases: Seq[String] = circuit.typeAliases.map(_.name)

circuit.components.flatMap(c => generate(c, typeAliases))
val result = circuit.components.flatMap(c => generate(c, typeAliases))
annoCreated.clear()
result
// circuit.layers
// circuit.options

Expand All @@ -59,18 +82,18 @@ object TywavesChiselAnnotation {
def generate(component: Component, typeAliases: Seq[String]): Seq[ChiselAnnotation] = component match {
case ctx @ DefModule(id, name, public, layers, ports, cmds) =>
// TODO: Add tywaves annotation: components, ports, commands, layers
Seq(createAnno(id)) ++ (ports ++ ctx.secretPorts).flatMap(p =>
createAnno(id) ++ (ports ++ ctx.secretPorts).flatMap(p =>
generate(p, typeAliases)
) ++ (cmds ++ ctx.secretCommands).flatMap(c => generate(c, typeAliases))
case ctx @ DefBlackBox(id, name, ports, topDir, params) =>
// TODO: Add tywaves annotation, ports, ?params?
Seq(createAnno(id)) ++ (ports ++ ctx.secretPorts).flatMap(p => generate(p, typeAliases))
createAnno(id) ++ (ports ++ ctx.secretPorts).flatMap(p => generate(p, typeAliases))
case ctx @ DefIntrinsicModule(id, name, ports, topDir, params) =>
// TODO: Add tywaves annotation: ports, ?params?
Seq(createAnno(id)) ++ (ports ++ ctx.secretPorts).flatMap(p => generate(p, typeAliases))
createAnno(id) ++ (ports ++ ctx.secretPorts).flatMap(p => generate(p, typeAliases))
case ctx @ DefClass(id, name, ports, cmds) =>
// TODO: Add tywaves annotation: ports, commands
Seq(createAnno(id)) ++ (ports ++ ctx.secretPorts).flatMap(p => generate(p, typeAliases)) ++ cmds.flatMap(c =>
createAnno(id) ++ (ports ++ ctx.secretPorts).flatMap(p => generate(p, typeAliases)) ++ cmds.flatMap(c =>
generate(c, typeAliases)
)
case ctx => throw new Exception(s"Failed to generate TywavesAnnotation. Unknown component type: $ctx")
Expand All @@ -84,9 +107,10 @@ object TywavesChiselAnnotation {
val name = s"$binding[${dataToTypeName(innerType)}[$size]]"
// TODO: what if innerType is a Vec or a Bundle?

Seq(new ChiselAnnotation {
override def toFirrtl: Annotation = TywavesAnnotation(target.toTarget, name, None)
}) //++ createAnno(chisel3.Wire(innerType))
createTywavesChiselAnno(target.toTarget, name, None).toSeq
// Seq(new ChiselAnnotation {
// override def toFirrtl: Annotation = TywavesAnnotation(target.toTarget, name, None)
// }) //++ createAnno(chisel3.Wire(innerType))
}
command match {
case e: DefPrim[_] => Seq.empty // TODO: check prim
Expand All @@ -98,7 +122,7 @@ object TywavesChiselAnnotation {
case e @ FirrtlMemory(info, id, t, size, readPortNames, writePortNames, readwritePortNames) =>
createAnnoMem(id, id.getClass.getSimpleName, size, t)
case e @ DefMemPort(info, id, source, dir, idx, clock) => createAnno(id)
case Connect(info, loc, exp) => Seq.empty // TODO: check connect
case Connect(info, loc, exp) => createAnno(exp)
case PropAssign(info, loc, exp) => ???
case Attach(info, locs) => ???
case DefInvalid(info, arg) => Seq.empty // TODO: check invalid
Expand All @@ -113,6 +137,11 @@ object TywavesChiselAnnotation {
case e @ ProbeForce(sourceInfo, clock, cond, probe, value) => ???
case e @ ProbeRelease(sourceInfo, clock, cond, probe) => ???
case e @ Verification(_, op, info, clk, pred, pable) => ???
case e @ When(info, arg, ifRegion, elseRegion) =>
println(s"$ifRegion")
println(s"$elseRegion")
ifRegion.flatMap(generate(_, typeAliases)) ++ elseRegion
.flatMap(generate(_, typeAliases))
case e =>
println(s"Unknown command: $e") // TODO: replace with logger
Seq.empty
Expand Down Expand Up @@ -316,18 +345,59 @@ object TywavesChiselAnnotation {
case _ => getConstructorParamsOpt(target)
}

annotations :+ new ChiselAnnotation {
override def toFirrtl: Annotation = TywavesAnnotation(target.toTarget, name, paramsOpt)
}
annotations ++
createTywavesChiselAnno(target.toTarget, name, paramsOpt).toSeq
// new ChiselAnnotation {
// override def toFirrtl: Annotation = TywavesAnnotation(target.toTarget, name, paramsOpt)
// }
}

private def createAnno(target: BaseModule): ChiselAnnotation = {
private def createAnno(target: BaseModule): Seq[ChiselAnnotation] = {
val name = target.desiredName
val paramsOpt = getConstructorParamsOpt(target)
// val name = target.getClass.getTypeName
new ChiselAnnotation {
override def toFirrtl: Annotation = TywavesAnnotation(target.toTarget, name, paramsOpt)
createTywavesChiselAnno(target.toTarget, name, paramsOpt).toSeq

// new ChiselAnnotation {
// override def toFirrtl: Annotation = TywavesAnnotation(target.toTarget, name, paramsOpt)
// }
}

// TODO: replace ??? with a nice logger to avoid unexpected crashes
private def createAnno(target: HasId): Seq[ChiselAnnotation] = {
target match {
case t: Data => createAnno(t)
case t: BaseModule => createAnno(t)
case t: MemBase[_] => ???
case t: NamedComponent => ???
case t: VecLike[_] => ???
case t if t.isInstanceOf[DynamicObject] => ???
case t if t.isInstanceOf[StaticObject] => ???
}
}
// TODO: check all the cases for Arg
private def createAnno(target: Arg): Seq[ChiselAnnotation] = {
target match {
case t @ Node(id) => createAnno(id)
case t @ ModuleIO(mod, name) => ???
case t @ ILit(n) => ???
case t @ Ref(name) => ???
case t @ PropertyLit(propertyType, lit) => ???
case t @ PropExpr(sourceInfo, tpe, op, args) => ???
case t @ Slot(imm, name) => ???
case t @ ProbeExpr(probe) => ???
case t @ ProbeRead(probe) => ???
case t @ RWProbeExpr(probe) => ???
case t @ Index(imm, value) => ???
case t @ ModuleCloneIO(mod, name) => ???
case t @ OpaqueSlot(imm) => ???
case t: Component => ???
case t: LitArg => Seq.empty // Ignore
case t =>
println(s"Unknown Arg type: $t")
Seq.empty
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,60 @@ import org.scalatest.matchers.should.Matchers
/** Utility functions for testing [[chisel3.tywavesinternal.TywavesAnnotation]] */
object TestUtils extends Matchers {

def getSubOccurency(mainString: String, subString: String): List[String] = { // Split the text T into lines
val lines = mainString.split("\n").toList

// Compile a regex pattern to find the string X
val pattern = (".*" + subString + ".*").r

// Find the index of the line that matches the pattern
val matchIndex = lines.indexWhere(line => pattern.matches(line))

// If a match is found, return the line and the immediate next line
if (matchIndex >= 0) {
lines.slice(matchIndex, matchIndex + 2)
} else {
List() // Return an empty list if no match is found
}
}

def getMissingSubOccurency(mainString: String, subString: String, expectedMatches: Seq[String]): String = {
// Split the text into lines
val lines = mainString.split("\n").toSeq

// Compile a regex pattern to find the string X
val pattern = (".*" + subString + ".*").r

// Filter lines that match the pattern and get also the next line
val matchingIndexes = lines.indices.filter(i => pattern.findFirstIn(lines(i)).isDefined)

// Get the lines at these indexes and the next line
val linesWithNext = matchingIndexes.flatMap { idx =>
if (idx < lines.length - 1) List(lines(idx), lines(idx + 1))
else List(lines(idx)) // Handle edge case where last line matches
}.distinct // Remove duplicates in case same line is matched more than once

val expectedRegex = expectedMatches.map(_.r)
// Filter out lines that are in expectedMatches
val missingLines = linesWithNext.filterNot { line =>
expectedRegex.exists(_.findFirstIn(line).isDefined)
}
missingLines.mkString("\n")

}

def countSubstringOccurrences(mainString: String, subString: String): Int = {
val pattern = subString.r
pattern.findAllMatchIn(mainString).length
}

// Return target and expected regex string
def createExpected(
target: String,
typeName: String,
binding: String = "",
params: Option[Seq[ClassParam]] = None
): String = {
): (String, String) = {
val realTypeName = binding match {
case "" => typeName
case _ => s"$binding\\[$typeName\\]"
Expand All @@ -43,16 +86,25 @@ object TestUtils extends Matchers {
}.mkString(",\\s+")}\\s+\\]"""
case None => ""
}
s"""\"target\":\"$target\",\\s+\"typeName\":\"$realTypeName\"$realParams\\s*}""".stripMargin
(target, s"""\"target\":\"$target\",\\s+\"typeName\":\"$realTypeName\"$realParams\\s*""".stripMargin)
}

def checkAnno(expectedMatches: Seq[(String, Int)], refString: String, includeConstructor: Boolean = false): Unit = {
def checkAnno(
expectedMatches: Seq[((String, String), Int)],
refString: String,
includeConstructor: Boolean = false
): Unit = {
def totalAnnoCheck(n: Int): (String, Int) =
(""""class":"chisel3.tywavesinternal.TywavesAnnotation"""", if (includeConstructor) n else n + 1)

(expectedMatches :+ totalAnnoCheck(expectedMatches.map(_._2).sum)).foreach {
val targetStrings = expectedMatches.map(p => {
val s = p._1._1
"\"target\":\"" + s + "\""
})
(expectedMatches.map(p => (p._1._2, p._2)) :+ totalAnnoCheck(expectedMatches.map(_._2).sum)).foreach {
case (pattern, count) =>
(countSubstringOccurrences(refString, pattern) should be(count)).withClue(s"Pattern: $pattern")
(countSubstringOccurrences(refString, pattern) should be(count)).withClue(
s"Pattern: $pattern: ${getMissingSubOccurency(refString, pattern, targetStrings)}"
)
}
}
}
Expand Down Expand Up @@ -237,6 +289,29 @@ object TywavesAnnotationCircuits {
class TopCircuitTypeInSubmodule(bindingChoice: BindingChoice) extends RawModule {
val mod = Module(new TopCircuitGroundTypes(bindingChoice))
}

// Test temporary values declared inside when and otherwise blocks
class TopCircuitWhenElse extends RawModule {
// Internally implement a MUX
val inSeq = IO(Input(Vec(8, UInt(8.W))))
val out = IO(Output(UInt(8.W)))
val sel = IO(Input(UInt(math.sqrt(8).ceil.toInt.W)))

when(sel % 2.U === 0.U) {
val outTmp = inSeq(sel)
val evenSel = outTmp + 1.U
out := evenSel
}.elsewhen(sel === 1.U) {
val outTmp = inSeq(sel)
val selIsOne = outTmp + 1.U
out := selIsOne
}.otherwise {
val outTmp = inSeq(sel)
val oddSel = outTmp + 1.U
out := oddSel
}

}
}

object MemCircuits {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,13 @@ class TypeAnnotationDataTypesSpec extends AnyFunSpec with Matchers with chiselTe
(createExpected("~TopCircuitTypeInSubmodule\\|TopCircuitGroundTypes>sint", "SInt<8>", b.toString), 1),
(createExpected("~TopCircuitTypeInSubmodule\\|TopCircuitGroundTypes>bool", "Bool", b.toString), 1),
(createExpected("~TopCircuitTypeInSubmodule\\|TopCircuitGroundTypes>bits", "UInt<8>", b.toString), 1),
(""""target":"~TopCircuitTypeInSubmodule\|TopCircuitGroundTypes",\s+"typeName":"TopCircuitGroundTypes"""", 1)
(
(
"""~TopCircuitTypeInSubmodule\|TopCircuitGroundTypes""",
""""target":"~TopCircuitTypeInSubmodule\|TopCircuitGroundTypes",\s+"typeName":"TopCircuitGroundTypes""""
),
1
)
) ++ addClockReset("TopCircuitTypeInSubmodule", Some("TopCircuitGroundTypes")) ++ analog
checkAnno(expectedMatches, string)
}
Expand Down Expand Up @@ -185,4 +191,26 @@ class TypeAnnotationDataTypesSpec extends AnyFunSpec with Matchers with chiselTe
typeTests(args, targetDir, RegBinding)
}

describe("Tmp Values Annotations") {
val targetDir = os.pwd / "test_run_dir" / "TywavesAnnotationSpec" / "Tmp Values Annotations"
val args: Array[String] = Array("--target", "chirrtl", "--target-dir", targetDir.toString)
// format: off
it("should annotate tmp value in when") {
(new ChiselStage(true)).execute(args, Seq(ChiselGeneratorAnnotation(() => new TopCircuitWhenElse)))
val string = os.read(targetDir / "TopCircuitWhenElse.fir")
val expectedMatches = Seq(
(createExpected("~TopCircuitWhenElse\\|TopCircuitWhenElse>inSeq", "UInt<8>\\[8\\]", "IO",
params = Some(Seq(ClassParam("gen", "=> T", None), ClassParam("length", "Int", Some("8"))))), 1),
(createExpected("~TopCircuitWhenElse\\|TopCircuitWhenElse>inSeq\\[0\\]", "UInt<8>", "IO"), 1),
(createExpected("~TopCircuitWhenElse\\|TopCircuitWhenElse>out", "UInt<8>", "IO"), 1),
(createExpected("~TopCircuitWhenElse\\|TopCircuitWhenElse>sel", "UInt<3>", "IO"), 1),
// Tmp
(createExpected("~TopCircuitWhenElse\\|TopCircuitWhenElse>evenSel", "UInt<8>", "OpResult"), 1),
(createExpected("~TopCircuitWhenElse\\|TopCircuitWhenElse>oddSel", "UInt<8>", "OpResult"), 1),
(createExpected("~TopCircuitWhenElse\\|TopCircuitWhenElse>selIsOne", "UInt<8>", "OpResult"), 1)
)
checkAnno(expectedMatches, string)
// format: on
}
}
}
Loading

0 comments on commit 41e0211

Please sign in to comment.