diff --git a/src/main/scala/Cpp.scala b/src/main/scala/Cpp.scala index e98996b4..cb0b2a1b 100644 --- a/src/main/scala/Cpp.scala +++ b/src/main/scala/Cpp.scala @@ -332,8 +332,8 @@ class CppBackend extends Backend { // schedule before Reg updates in case a MemWrite input is a Reg if (m.inputs.length == 2) return "" - def wmask(w: Int) = "(-" + emitLoWordRef(m.cond) + (if (m.isMasked) " & " + emitWordRef(m.wmask, w) else "") + ")" - block((0 until words(m)).map(i => emitRef(m.mem) + ".put(" + emitLoWordRef(m.addr) + ", " + i + ", (" + emitWordRef(m.data, i) + " & " + wmask(i) + ") | (" + emitRef(m.mem) + ".get(" + emitLoWordRef(m.addr) + ", " + i + ") & ~" + wmask(i) + "))")) + def mask(w: Int) = "(-" + emitLoWordRef(m.cond) + (if (m.isMasked) " & " + emitWordRef(m.mask, w) else "") + ")" + block((0 until words(m)).map(i => emitRef(m.mem) + ".put(" + emitLoWordRef(m.addr) + ", " + i + ", (" + emitWordRef(m.data, i) + " & " + mask(i) + ") | (" + emitRef(m.mem) + ".get(" + emitLoWordRef(m.addr) + ", " + i + ") & ~" + mask(i) + "))")) case _ => "" diff --git a/src/main/scala/FPGA.scala b/src/main/scala/FPGA.scala index 17ef0e47..ef65ec74 100644 --- a/src/main/scala/FPGA.scala +++ b/src/main/scala/FPGA.scala @@ -46,7 +46,7 @@ class FPGABackend extends VerilogBackend val i = "i" + emitTmp(m) (if (mw) " wire [" + (m.mem.width-1) + ":0] " + emitRef(m.mem) + "_w" + me + " = " + writeMap(m.mem, me).map(_ + "[" + emitRef(m.addr) + "]").reduceLeft(_+" ^ "+_) + ";\n" else "") + (if (m.isMasked) { - val bm = m.mem.width % 8 == 0 && useByteMask(m.wmask) + val bm = m.mem.width % 8 == 0 && useByteMask(m.mask) val max = if (bm) m.mem.width/8 else m.mem.width val maskIdx = if(bm) i+"*8" else i val dataIdx = if (bm) i+"*8+7:"+i+"*8" else i @@ -54,7 +54,7 @@ class FPGABackend extends VerilogBackend " genvar " + i + ";\n" + " for (" + i + " = 0; " + i + " < " + max + "; " + i + " = " + i + " + 1) begin: f" + emitTmp(m) + "\n" + " always @(posedge clk)\n" + - " if (" + emitRef(m.cond) + " && " + emitRef(m.wmask) + "["+maskIdx+"])\n" + + " if (" + emitRef(m.cond) + " && " + emitRef(m.mask) + "["+maskIdx+"])\n" + " " + meStr + "["+emitRef(m.addr)+"]["+dataIdx+"] <= " + emitRef(m.data) + "["+dataIdx+"]" + (if (mw) " ^ " + emitRef(m.mem) + "_w" + me + "["+dataIdx+"]" else "") + ";\n" + " end\n" + " endgenerate\n" diff --git a/src/main/scala/Mem.scala b/src/main/scala/Mem.scala index 28d5341f..23b1c50c 100644 --- a/src/main/scala/Mem.scala +++ b/src/main/scala/Mem.scala @@ -166,10 +166,10 @@ class PutativeMemWrite(mem: Mem[_], addri: Bits) extends Node with proc { class MemReadWrite(val read: MemSeqRead, val write: MemWrite) extends MemAccess(read.mem, null) { override def cond = throw new Exception("") - override def getPortType = "rw" + override def getPortType = if (write.isMasked) "mrw" else "rw" } -class MemWrite(mem: Mem[_], condi: Bool, addri: Node, datai: Node, wmaski: Node) extends MemAccess(mem, addri) { +class MemWrite(mem: Mem[_], condi: Bool, addri: Node, datai: Node, maski: Node) extends MemAccess(mem, addri) { inputs += condi override def cond = inputs(1) @@ -180,8 +180,8 @@ class MemWrite(mem: Mem[_], condi: Bool, addri: Node, datai: Node, wmaski: Node) b } inputs += wrap(datai) - if (wmaski != null) - inputs += wrap(wmaski) + if (maski != null) + inputs += wrap(maski) } override def forceMatchingWidths = { @@ -206,9 +206,9 @@ class MemWrite(mem: Mem[_], condi: Bool, addri: Node, datai: Node, wmaski: Node) wp.find(wc => rp.exists(rc => isNegOf(rc, wc) || isNegOf(wc, rc))) } def data = inputs(2) - def wmask = inputs(3) + def mask = inputs(3) def isMasked = inputs.length > 3 override def toString: String = mem + "[" + addr + "] = " + data + " COND " + cond - override def getPortType: String = "write" + override def getPortType: String = if (isMasked) "mwrite" else "write" override def isRamWriteInput(n: Node) = inputs.contains(n) } diff --git a/src/main/scala/Verilog.scala b/src/main/scala/Verilog.scala index e4620e40..cbe1376d 100644 --- a/src/main/scala/Verilog.scala +++ b/src/main/scala/Verilog.scala @@ -68,32 +68,34 @@ class VerilogBackend extends Backend { } def emitPortDef(m: MemAccess, idx: Int): String = { + def str(prefix: String, ports: (String, String)*) = + ports.toList.filter(_._2 != null) + .map(p => " ." + prefix + idx + p._1 + "(" + p._2 + ")") + .reduceLeft(_ + ",\n" + _) + m match { case r: MemSeqRead => - " .R" + idx + "A(" + emitRef(r.addr) + "),\n" + - " .R" + idx + "E(" + emitRef(r.cond) + "),\n" + - " .R" + idx + "O(" + emitTmp(r) + ")" + val addr = ("A", emitRef(r.addr)) + val en = ("E", emitRef(r.cond)) + val out = ("O", emitTmp(r)) + str("R", addr, en, out) case w: MemWrite => - val mask = if (w.isMasked) emitRef(w.wmask) else "{"+w.mem.width+"{1'b1}}" - - " .W" + idx + "A(" + emitRef(w.addr) + "),\n" + - " .W" + idx + "E(" + emitRef(w.cond) + "),\n" + - " .W" + idx + "M(" + mask + "),\n" + - " .W" + idx + "I(" + emitRef(w.data) + ")" + val addr = ("A", emitRef(w.addr)) + val en = ("E", emitRef(w.cond)) + val data = ("I", emitRef(w.data)) + val mask = ("M", if (w.isMasked) emitRef(w.mask) else null) + str("W", addr, en, data, mask) case rw: MemReadWrite => val (r, w) = (rw.read, rw.write) - val en = emitRef(r.cond) + " || " + emitRef(w.cond) - val addr = emitRef(w.cond) + " ? " + emitRef(w.addr) + " : " + emitRef(r.addr) - val mask = if (w.isMasked) emitRef(w.wmask) else "{"+rw.mem.width+"{1'b1}}" - - " .RW" + idx + "A(" + addr + "),\n" + - " .RW" + idx + "E(" + en + "),\n" + - " .RW" + idx + "W(" + emitRef(w.cond) + "),\n" + - " .RW" + idx + "M(" + mask + "),\n" + - " .RW" + idx + "I(" + emitRef(w.data) + "),\n" + - " .RW" + idx + "O(" + emitTmp(r) + ")" + val addr = ("A", emitRef(w.cond) + " ? " + emitRef(w.addr) + " : " + emitRef(r.addr)) + val en = ("E", emitRef(r.cond) + " || " + emitRef(w.cond)) + val write = ("W", emitRef(w.cond)) + val data = ("I", emitRef(w.data)) + val mask = ("M", if (w.isMasked) emitRef(w.mask) else null) + val out = ("O", emitTmp(r)) + str("RW", addr, en, write, data, mask, out) } } @@ -477,7 +479,7 @@ class VerilogBackend extends Backend { val i = "i" + emitTmp(m) if (m.isMasked) (0 until m.mem.width).map(i => - " if (" + emitRef(m.cond) + " && " + emitRef(m.wmask) + "[" + i + "])\n" + + " if (" + emitRef(m.cond) + " && " + emitRef(m.mask) + "[" + i + "])\n" + " " + emitRef(m.mem) + "[" + emitRef(m.addr) + "][" + i + "] <= " + emitRef(m.data) + "[" + i + "];\n" ).reduceLeft(_+_) else