# **RISC-V Thunder Core (Using Amaranth HDL)**


|Name|Affiliation| IEEE Member | SSCS Member |
|:-----------------:|:----------:|:----------:|:----------:|
| Muddassir Ali Siddiqui| Usman Institute of Technology | No | No  |
| Syed Ali Ahmed Naqvi  | Usman Institute of Technology | No | No  |
| Muhammad Fayz         | Usman Institute of Technology | No | No  |
| Ammar Saleem          | Usman Institute of Technology | No | No  |
| Zeeshan Rafique (Advisor)    | Usman Institute of Technology | Yes | Yes |


![RV-Thunder Logo](https://i.imgur.com/oeywSSS.png)

 ## Introduction

 RISC-V Thunder Core is a 32-bit CPU core that currently implements the [RISC-V](https://riscv.org/) RV32I instruction set. Its microarchitecture is described in plain Python using [Amaranth HDL](https://amaranth-lang.org/docs/amaranth/latest/). The main file is [rv-thunder.py](src/rv-thunder.py).

 ## Features

 * Written in Python using the Amaranth HDL framework
 * Supports the RISC-V RV32I instruction set architecture
 * Separate instruction and data memories, each 8192 words of 32 bits
 * 32 general purpose registers
 * Test bench (using Amaranth HDL)

 ## Dependencies
 [amaranth 0.4.dev197+g11d5bb1.editable](https://amaranth-lang.org/docs/amaranth/latest/install.html)  
 [amaranth-boards](https://github.com/amaranth-lang/amaranth-boards)  
 [Python 3.10.12](https://www.python.org/downloads/release/python-31013/)  
 [Yosys](https://yosyshq.net/yosys/download.html)



 ## Block Diagram
 ![RV-Thunder Block Diagram](https://imgur.com/1kFOHdm.png)

 ## Code Blocks

 ### Top Level

from amaranth import *
from amaranth.sim import Simulator

# Import the defined modules
from fetch import *
from control import *
from regfile import *
from alu import *
from mem import *
from branch import *

# Create a top-level module that connects the modules
class TopModule(Elaboratable):
    def elaborate(self, platform):
        m = Module()

        # Instantiate each module
        fetch_unit = FetchUnit()
        control_unit = control()
        reg_file = regfile()
        branch_unit = branch()
        alu = ALU()
        inst_memory_unit = instr_mem()
        data_memory_unit = data_mem()

        # Connect modules together
        m.submodules.fetch_unit = fetch_unit
        m.submodules.control_unit = control_unit
        m.submodules.reg_file = reg_file
        m.submodules.alu = alu
        m.submodules.branch_unit = branch_unit
        m.submodules.inst_memory_unit = inst_memory_unit
        m.submodules.data_memory_unit = data_memory_unit

#===========================< Instruction memory connection >===========================
        m.d.comb += [
            inst_memory_unit.adr.eq(fetch_unit.pc[2:15]),
            control_unit.instr_dat.eq(inst_memory_unit.dat_r),  
            alu.aluop.eq(control_unit.aluop),
#===========================< Registers Connections >===========================
            reg_file.rs1.eq(control_unit.rs1),
            reg_file.rs2.eq(control_unit.rs2),
            reg_file.rd.eq(control_unit.rd),

            reg_file.we.eq(control_unit.we),
            # alu.inp1.eq(reg_file.rf_out1),

            branch_unit.op1.eq(reg_file.rf_out1),
            branch_unit.op2.eq(reg_file.rf_out2),
            branch_unit.func3.eq(control_unit.funct3),

            data_memory_unit.adr.eq(alu.alu_out[2:15]),
            data_memory_unit.dmem_we.eq(control_unit.dmem_we),
        ]
#==========================< Store into memory >========================
        with m.If(control_unit.dmem_we == 1):
            m.d.comb += data_memory_unit.dmem_din.eq(reg_file.rf_out2)

#==========================< Operand b select >========================
        with m.If (control_unit.op_b_sel == 1):
            m.d.comb += alu.inp2.eq(control_unit.imm)
        with m.Else ():
            m.d.comb += alu.inp2.eq(reg_file.rf_out2)

#==========================< Operand a select >========================
        with m.If (control_unit.op_a_sel == 0):
            m.d.comb += alu.inp1.eq(reg_file.rf_out1)
        with m.Elif (control_unit.op_a_sel == 1):
            m.d.comb += alu.inp1.eq(fetch_unit.pc[0:12])
        with m.Elif (control_unit.op_a_sel == 2):
            m.d.comb += alu.inp1.eq(fetch_unit.pc)
        with m.Else ():
            m.d.comb += alu.inp1.eq(0x00000000)

#==========================< Update Pc and Branch select >========================

        with m.If (control_unit.op == 0b1100011):
            m.d.comb += [
                fetch_unit.branch.eq(control_unit.br & branch_unit.br_out),     #branch 
                fetch_unit.branch_tar.eq(alu.alu_out),
                ]
        with m.Elif (control_unit.op == 0b1100111):
            m.d.comb += [
                fetch_unit.branch.eq(1),    #jalr signal 
                fetch_unit.branch_tar.eq(alu.alu_out),
            ]

        with m.Elif (control_unit.op == 0b1101111):
            m.d.comb += [
                fetch_unit.branch.eq(1),    #jal signal
                fetch_unit.branch_tar.eq(alu.alu_out),
            ]

#==========================< load data from memory Or store address of next_pc/ jal/ jalr in regfile >========================
        with m.If (control_unit.ld_wd == 1):
            m.d.comb += reg_file.wb_data.eq(data_memory_unit.dmem_dout)
        
        with m.Else ():
            with m.If (control_unit.ld_adr == 1):
                m.d.comb += reg_file.wb_data.eq(fetch_unit.pc + 4)

            with m.Else ():
                m.d.comb += reg_file.wb_data.eq(alu.alu_out)

        return m

# Simulate the top module
dut = TopModule()
def bench():
    # Let the core run for 60 clock cycles
    for _ in range(60):
        yield
# We can provide initial values for signals above

sim = Simulator(dut)
sim.add_clock(1e-6)  #Add clock
sim.add_sync_process(bench)
with sim.write_vcd("sim.vcd"): # Generate Vcd, which is useful to see a result in GTKwave
    sim.run()

from amaranth.back import verilog

top = TopModule()
with open("output/verilog/rv-thunder.v", "w") as f:
    f.write(verilog.convert(top, ports=[]))
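
The multiplexers in `TopModule` can be cross-checked against a small pure-Python model. The sketch below mirrors the select priorities as they are written in the top level above; the function names `select_operands` and `select_writeback` are hypothetical helpers introduced for illustration and are not part of the repository.

```python
MASK32 = 0xFFFFFFFF

def select_operands(op_a_sel, op_b_sel, rs1_val, rs2_val, pc, imm):
    """Return (inp1, inp2) as chosen by the operand A and operand B muxes."""
    # Operand A: 0 = rs1, 1 = low 12 bits of pc, 2 = full pc, anything else = 0
    a = {0: rs1_val, 1: pc & 0xFFF, 2: pc}.get(op_a_sel, 0)
    # Operand B: immediate when op_b_sel is set, otherwise rs2
    b = imm if op_b_sel else rs2_val
    return a & MASK32, b & MASK32

def select_writeback(ld_wd, ld_adr, mem_data, pc, alu_out):
    """Return the value presented to the register file write port."""
    if ld_wd:                      # load: data memory output has priority
        return mem_data & MASK32
    if ld_adr:                     # link address for jumps
        return (pc + 4) & MASK32
    return alu_out & MASK32        # default: ALU result
```

A model like this can be compared against values read from `sim.vcd` when a write-back value looks wrong.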

### ALU 

from amaranth import *

class ALU(Elaboratable):

    def __init__(self):

        self.aluop = Signal(4)
        self.inp1 = Signal(32)
        self.inp2 = Signal(32)
        self.alu_out = Signal(32)
       
    def elaborate(self, platform):
        m = Module() 

        forshft = self.inp2 & 0x1F
        forshft1 = forshft[0:5]
        
        with m.If(self.aluop == 0b0000): #Add,ADDI
            m.d.comb += self.alu_out.eq(self.inp1 + self.inp2)

        with m.Elif(self.aluop == 0b0001): #SLL,SLLI
            m.d.comb += self.alu_out.eq(self.inp1 << (forshft1))

        with m.Elif(self.aluop == 0b0010): #SLT,SLTI (signed, strictly less than)
            m.d.comb += self.alu_out.eq(self.inp1.as_signed() < self.inp2.as_signed())

        with m.Elif(self.aluop == 0b0011): #SLTU,SLTIU (unsigned, strictly less than)
            m.d.comb += self.alu_out.eq(self.inp1.as_unsigned() < self.inp2.as_unsigned())

        with m.Elif(self.aluop == 0b0100): #XOR,XORI
            m.d.comb += self.alu_out.eq(self.inp1 ^ self.inp2)

        with m.Elif(self.aluop == 0b0101): #SRL,SRLI
            m.d.comb += self.alu_out.eq(self.inp1 >> (forshft1))
        
        with m.Elif(self.aluop == 0b0110): #OR,ORI
            m.d.comb += self.alu_out.eq(self.inp1 | self.inp2)

        with m.Elif(self.aluop == 0b0111): #AND,ANDI
            m.d.comb += self.alu_out.eq(self.inp1 & self.inp2)  

        with m.Elif(self.aluop == 0b1000): #SUB
            m.d.comb += self.alu_out.eq(self.inp1 - self.inp2)

        with m.Elif(self.aluop == 0b1101): #SRA,SRAI (shifting a signed value replicates the sign bit)
            m.d.comb += self.alu_out.eq(self.inp1.as_signed() >> forshft1)

        return m
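
When checking ALU waveforms it helps to have a golden model to compare against. The following is a minimal pure-Python sketch of the RV32I results expected for each `aluop` encoding used above; `alu_ref` and `to_signed` are hypothetical names introduced for this example, and unassigned encodings are assumed to produce 0, as a combinational signal without a driver would.

```python
MASK32 = 0xFFFFFFFF

def to_signed(x):
    """Interpret a 32-bit pattern as a two's-complement integer."""
    x &= MASK32
    return x - (1 << 32) if x & 0x80000000 else x

def alu_ref(aluop, a, b):
    """Expected 32-bit ALU result for a 4-bit aluop (bit 30 above funct3)."""
    a &= MASK32
    b &= MASK32
    shamt = b & 0x1F                       # only the low 5 bits select the shift amount
    if aluop == 0b0000:                    # ADD, ADDI
        return (a + b) & MASK32
    if aluop == 0b0001:                    # SLL, SLLI
        return (a << shamt) & MASK32
    if aluop == 0b0010:                    # SLT, SLTI (signed)
        return int(to_signed(a) < to_signed(b))
    if aluop == 0b0011:                    # SLTU, SLTIU (unsigned)
        return int(a < b)
    if aluop == 0b0100:                    # XOR, XORI
        return a ^ b
    if aluop == 0b0101:                    # SRL, SRLI
        return a >> shamt
    if aluop == 0b0110:                    # OR, ORI
        return a | b
    if aluop == 0b0111:                    # AND, ANDI
        return a & b
    if aluop == 0b1000:                    # SUB
        return (a - b) & MASK32
    if aluop == 0b1101:                    # SRA, SRAI (sign bit is replicated)
        return (to_signed(a) >> shamt) & MASK32
    return 0                               # assumption: unused encodings yield 0
```

For example, `alu_ref(0b1101, 0x80000000, 4)` keeps the sign bit while `alu_ref(0b0101, 0x80000000, 4)` shifts in zeros, which is the difference between SRA and SRL.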

### Memory

from amaranth import *

class instr_mem(Elaboratable):
    def __init__(self):
        # Define the inputs and outputs of the instruction memory module
        self.adr = Signal(13)  # 13-bit word address for 8192 (2^13) instructions
        self.dat_r = Signal(32)  # 32-bit RISC-V instruction output
        # Load the program image: a text file with one hexadecimal word per line
        with open('src/memory.txt', 'r') as file:
            mem_init_file = file.readlines()
            toint = [int(value, 16) for value in mem_init_file]  # Memory init needs integers, not strings
        # Initialise the instruction memory with the program image
        self.mem = Memory(width=32, depth=8192, init= toint)

    def elaborate(self, platform):
        m = Module()
        # Create a read port for the instruction memory
        m.submodules.rdport = rdport = self.mem.read_port(domain="comb")
        m.d.comb += [
            rdport.addr.eq(self.adr),
            self.dat_r.eq(rdport.data)
        ]

        return m

class data_mem(Elaboratable):
    def __init__(self):
        self.adr = Signal(13)
        self.dmem_din = Signal(32)
        self.dmem_dout = Signal(32)
        self.dmem_we = Signal()
        # Create a memory with the specified depth (replace this with your actual data)
        self.memory = Memory(width=32, depth=8192)

    def elaborate(self, platform):
        m = Module()
        # Create a read and write port for the data memory
        m.submodules.rdport = rdport = self.memory.read_port(domain="comb")
        m.submodules.wrport = wrport = self.memory.write_port()
        # Connect the address and data signals
        m.d.comb += [
            rdport.addr.eq(self.adr),
            self.dmem_dout.eq(rdport.data),
            wrport.addr.eq(self.adr),
            wrport.data.eq(self.dmem_din),
            wrport.en.eq(self.dmem_we)  # Enable write operation
        ]

        return m
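
The instruction memory above reads `src/memory.txt` as one hexadecimal word per line, and the top level addresses both memories with bits 2 to 14 of a byte address. The sketch below shows one possible way to produce such an image and to compute the word index. The helper names are hypothetical, and skipping blank lines on read is an added tolerance that the loader above does not have.

```python
import io

def write_hex_image(words, stream):
    """Write one 32-bit word per line as 8 hexadecimal digits."""
    for word in words:
        stream.write(f"{word & 0xFFFFFFFF:08x}\n")

def read_hex_image(stream):
    """Parse an image back into integers, ignoring blank lines."""
    return [int(line, 16) for line in stream if line.strip()]

def word_index(byte_addr, depth=8192):
    """Word address used by the memories: byte address bits 2..14."""
    return (byte_addr >> 2) % depth

# Example program: addi x1,x0,5 ; addi x2,x0,10 ; add x3,x1,x2
program = [0x00500093, 0x00A00113, 0x002081B3]

buf = io.StringIO()          # stands in for a real file such as src/memory.txt
write_hex_image(program, buf)
buf.seek(0)
loaded = read_hex_image(buf)
```

Because only 13 address bits are connected, byte addresses at or above `0x8000` wrap around to the start of the memory.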

### Control

from amaranth import *

class control(Elaboratable):

    def __init__(self):

        self.instr_dat = Signal(32)  # Input Instruction
        self.funct3 = Signal(3) # Function 3 
        self.funct7 = Signal() # Bit 30 of the instruction (funct7[5]), selects SUB/SRA
        self.rs1 = Signal(5) # Source Register 1
        self.rs2 = Signal(5) # Source Register 2
        self.rd = Signal(5) # Destination Register
        self.op = Signal(7) # Opcode

        self.we = Signal() # Register write enable (1 for instructions that write rd, 0 for stores and branches)
        self.ld_wd = Signal() # Write back the data memory output (loads)
        self.aluop = Signal(4) #ALU Operation
        self.dmem_we = Signal() # Data memory write enable (1 only for store instructions)
        self.ld_adr = Signal() # Write back pc + 4 (link address for jal/jalr)
        self.br = Signal() # Instruction is a conditional branch

        self.imm = Signal(32) #Immediate

        self.iimm = Signal(12) # I type immediate

        self.simm = Signal(12) # S-type full immediate
        self.simm1 = Signal(5) # Sub1 immediate of S type
        self.simm2 = Signal(7) # Sub2 immediate of S type

        self.uimm = Signal(20)  # U type immediate

        self.sbimm0 = Signal()
        self.sbimm1 = Signal(4)
        self.sbimm2 = Signal(6)
        self.sbimm3 = Signal()
        self.sbimm4 = Signal()
        self.sbimm = Signal(13) # SB type immediate

        self.ujimm0 = Signal()
        self.ujimm1 = Signal(10)
        self.ujimm2 = Signal()
        self.ujimm3 = Signal(8)
        self.ujimm4 = Signal()
        self.ujimm = Signal(21) # UJ type immediate

        self.op_b_sel = Signal() # Operand B select bit for mux (Useful when there is an immediate)
        self.op_a_sel = Signal(2)

#==========================< Instr Decode >===========================
    def elaborate(self, platform):
        m = Module()

        m.d.comb += [
            self.op.eq (self.instr_dat[0:7]), # op is 7 bits (0 to 6)
            self.rd.eq (self.instr_dat[7:12]), # rd is 5 bits (7 to 11)
            self.funct3.eq (self.instr_dat[12:15]), # funct3 is 3 bits (12 to 14)
            self.rs1.eq (self.instr_dat[15:20]), # rs1 is 5 bits (15 to 19)
            self.rs2.eq (self.instr_dat[20:25]), # rs2 is 5 bits (20 to 24)
            self.funct7.eq (self.instr_dat[30]), # only bit 30 of funct7 is needed

#====================================Immediate for I type Instruction=========================== 
            self.iimm.eq (self.instr_dat[20:]), # iimm is of 12 bits so (20 to 31)

#====================================Immediate for S type Instruction===========================
            self.simm1.eq (self.instr_dat[7:]), # simm1 is of 5 bits so (7 to 11)
            self.simm2.eq (self.instr_dat[25:]), # simm2 is of 7 bits so (25 to 31)
            self.simm.eq (Cat(self.simm1, self.simm2)), # simm is of 12 bits , make simm by concatenating both simm1 and simm2

#====================================Immediate for SB type Instruction===========================
            self.sbimm0.eq (0),
            self.sbimm1.eq (self.instr_dat[8:]),
            self.sbimm2.eq (self.instr_dat[25:]),
            self.sbimm3.eq (self.instr_dat[7]),
            self.sbimm4.eq (self.instr_dat[31]),
            self.sbimm.eq (Cat(self.sbimm0,self.sbimm1, self.sbimm2, self.sbimm3, self.sbimm4)),

#====================================Immediate for UJ type Instruction===========================
            self.ujimm0.eq (0),
            self.ujimm1.eq (self.instr_dat[21:]),
            self.ujimm2.eq (self.instr_dat[20]),
            self.ujimm3.eq (self.instr_dat[12:]),
            self.ujimm4.eq (self.instr_dat[31]),
            self.ujimm.eq (Cat(self.ujimm0,self.ujimm1, self.ujimm2, self.ujimm3, self.ujimm4)),

            self.uimm.eq (self.instr_dat[12:]),

            self.aluop.eq (Cat(self.funct3, self.funct7))
            ] # aluop is of 4 bits, make aluop by concatenating both funct3 and funct7
        
        m.d.comb += [
            self.we.eq(0),
            self.ld_wd.eq(0),
            self.dmem_we.eq(0),
            self.ld_adr.eq(0),
            self.op_a_sel.eq(0),
            self.op_b_sel.eq(0),
            self.br.eq(0)
        ]

#=====================================< R-Type 33 >=====================================
        with m.Switch(self.op):
            with m.Case(0b0110011): # opcode of R-Type

                m.d.comb += [
                    self.we.eq(1),
                    self.op_a_sel.eq(0),
                    self.dmem_we.eq(0)
                    ]

#=====================================< I-Type 13 >=====================================
            with m.Case(0b0010011):# opcode of I-Type
                m.d.comb += self.imm[0:12].eq(self.iimm)# put 12 bit iimm in first 12 bits of imm
                with m.If (self.imm[11] == 1):# sign extension: copy bit 11 into bits 12 to 31
                    m.d.comb += self.imm[12:32].eq(0b11111111111111111111)

                with m.Else ():
                    m.d.comb += self.imm[12:32].eq(0b00000000000000000000)

                # Bit 30 belongs to the immediate except for shifts (funct3 = 101),
                # so ignore it otherwise; ADDI with imm[10] set would decode as SUB
                with m.If (self.funct3 != 0b101):
                    m.d.comb += self.aluop.eq(self.funct3)

                m.d.comb += [
                    self.we.eq(1),
                    self.op_b_sel.eq(1),
                    self.dmem_we.eq(0)
                ]

#=====================================< S-Type 23 >=====================================
            with m.Case(0b0100011): # opcode of S-Type
                m.d.comb += self.imm[0:12].eq(self.simm) #put 12 bit simm in first 12 bits of imm
                with m.If (self.imm[11] == 1):
                    m.d.comb += self.imm[12:32].eq(0b11111111111111111111)

                with m.Else ():
                    m.d.comb += self.imm[12:32].eq(0b00000000000000000000)

                m.d.comb += [
                    self.we.eq(0),
                    self.aluop.eq(0b0000),
                    self.op_b_sel.eq(1),
                    self.dmem_we.eq(1)                
                    ]
                
#=================================< ld_wd 3 >========================================
            with m.Case(0b0000011): # opcode of Load Instruction
                m.d.comb += self.imm[0:12].eq(self.iimm)
                with m.If (self.imm[11] == 1):
                    m.d.comb += self.imm[12:32].eq(0b11111111111111111111)

                with m.Else ():
                    m.d.comb += self.imm[12:32].eq(0b00000000000000000000)
                m.d.comb += [
                    self.ld_wd.eq(1),
                    self.aluop.eq(0b0000),
                    self.op_b_sel.eq(1),
                    self.we.eq(1)
                ]

#=====================================< U-Type 17 & 27 >=====================================
            with m.Case(0b0010111):     #AUIPC
                m.d.comb += self.imm[12:32].eq(self.uimm)
                m.d.comb += self.imm[0:12].eq(0b000000000000)

                m.d.comb += [
                    self.ld_adr.eq(0),   # write back the ALU result (pc + imm), not pc + 4
                    self.aluop.eq(0b0000),
                    self.op_b_sel.eq(1),
                    self.op_a_sel.eq(2), # operand A is the full pc
                    self.we.eq(1)
                ]
            
            with m.Case(0b0110111):     #LUI
                m.d.comb += self.imm[12:32].eq(self.uimm)
                m.d.comb += self.imm[0:12].eq(0b000000000000)

                m.d.comb += [
                    self.ld_adr.eq(0),   # write back the ALU result (0 + imm)
                    self.aluop.eq(0b0000),
                    self.op_b_sel.eq(1),
                    self.op_a_sel.eq(3), # operand A is zero; bits 15 to 19 are immediate bits, not rs1
                    self.we.eq(1)
                ]

#=====================================< SB-Type 63 >=====================================
            with m.Case(0b1100011):
                m.d.comb += self.imm[0:13].eq(self.sbimm)

                with m.If (self.imm[12] == 1):
                    m.d.comb += self.imm[13:32].eq(0b1111111111111111111)

                with m.Else ():
                    m.d.comb += self.imm[13:32].eq(0b0000000000000000000)

                m.d.comb += [
                    self.br.eq(1),
                    self.aluop.eq(0b0000),
                    self.op_a_sel.eq(2),
                    self.op_b_sel.eq(1),
                ]

#=====================================< jalr & jal >=====================================
            with m.Case(0b1100111): # jalr
                m.d.comb += self.imm[0:12].eq(self.iimm)# put 12 bit iimm in first 12 bits of imm
                with m.If (self.imm[11] == 1):#check for sign extension, if it's 1 then convert (13 to 32) bits of imm to 1 otherwise 0
                    m.d.comb += self.imm[12:32].eq(0b11111111111111111111)

                with m.Else ():
                    m.d.comb += self.imm[12:32].eq(0b00000000000000000000)

                m.d.comb += [
                    self.ld_adr.eq(1),
                    self.op_a_sel.eq(0),
                    self.op_b_sel.eq(1),
                    self.aluop.eq(0b0000),
                    self.we.eq(1),
                ]

            with m.Case(0b1101111): # jal
                m.d.comb += self.imm[0:21].eq(self.ujimm)
                with m.If (self.imm[20] == 1):
                    m.d.comb += self.imm[21:32].eq(0b11111111111)

                with m.Else ():
                    m.d.comb += self.imm[21:32].eq(0b00000000000)

                m.d.comb += [
                    self.ld_adr.eq(1),
                    self.op_a_sel.eq(2),
                    self.op_b_sel.eq(1),
                    self.aluop.eq(0b0000),
                    self.we.eq(1),
                ]

        return m
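
The immediate formats decoded above follow the RV32I encoding, and they are easy to get wrong by one bit. As a cross-check, here is a pure-Python sketch of the same field extraction and sign extension. The function names (`field`, `sign_extend`, `imm_i`, and so on) are hypothetical and only serve this example.

```python
MASK32 = 0xFFFFFFFF

def field(word, lo, width):
    """Extract `width` bits of `word` starting at bit `lo`."""
    return (word >> lo) & ((1 << width) - 1)

def sign_extend(value, bits):
    """Sign-extend a `bits`-wide value to a 32-bit pattern."""
    sign = 1 << (bits - 1)
    return ((value & (sign - 1)) - (value & sign)) & MASK32

def imm_i(instr):   # I-type: instr[31:20]
    return sign_extend(field(instr, 20, 12), 12)

def imm_s(instr):   # S-type: instr[31:25] : instr[11:7]
    return sign_extend((field(instr, 25, 7) << 5) | field(instr, 7, 5), 12)

def imm_b(instr):   # SB-type: instr[31] : instr[7] : instr[30:25] : instr[11:8] : 0
    raw = ((field(instr, 31, 1) << 12) | (field(instr, 7, 1) << 11)
           | (field(instr, 25, 6) << 5) | (field(instr, 8, 4) << 1))
    return sign_extend(raw, 13)

def imm_u(instr):   # U-type: instr[31:12] placed in the upper 20 bits
    return instr & 0xFFFFF000

def imm_j(instr):   # UJ-type: instr[31] : instr[19:12] : instr[20] : instr[30:21] : 0
    raw = ((field(instr, 31, 1) << 20) | (field(instr, 12, 8) << 12)
           | (field(instr, 20, 1) << 11) | (field(instr, 21, 10) << 1))
    return sign_extend(raw, 21)
```

For instance, `beq x0, x0, -4` is encoded as `0xFE000EE3`, and `imm_b` recovers the offset as the 32-bit pattern `0xFFFFFFFC`.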

### Fetch

from amaranth import *

class FetchUnit(Elaboratable):
    def __init__(self):
        self.branch = Signal()
        self.branch_tar = Signal(32)
        self.pc = Signal(32)

    def elaborate(self, platform):
        m = Module()
        with m.If (self.branch):
            m.d.sync += self.pc.eq(self.branch_tar)

        with m.Else():
            # Increment the program counter.
            m.d.sync += self.pc.eq(self.pc + 4)

        return m
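
The fetch unit is a single register with a two-way choice each clock cycle. A minimal pure-Python sketch of that behaviour, using the hypothetical helpers `next_pc` and `trace`, looks like this:

```python
def next_pc(pc, branch, branch_tar):
    """Value the pc register takes on the next clock edge."""
    return (branch_tar if branch else pc + 4) & 0xFFFFFFFF

def trace(steps, pc=0):
    """Replay (branch, branch_tar) pairs, one per cycle, and record every pc."""
    history = [pc]
    for branch, target in steps:
        pc = next_pc(pc, branch, target)
        history.append(pc)
    return history

# Two sequential fetches, a taken branch to 0x20, then sequential again
pcs = trace([(0, 0), (0, 0), (1, 0x20), (0, 0)])
```

Comparing such a trace with the `pc` signal in GTKWave is a quick way to confirm that branches and jumps redirect the fetch as intended.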

### Branch

from amaranth import *

class branch(Elaboratable):        

    def __init__(self):
        self.op1 = Signal(32) # connect with reg file out1
        self.op2 = Signal(32) # connect with reg file out2
        self.func3 = Signal(3) # connect with control unit func3
        self.br_out = Signal()

    def elaborate(self, platform):
        m = Module()

        with m.If(self.func3 == 0b000): #beq
            m.d.comb += self.br_out.eq(self.op1 == self.op2)

        with m.Elif(self.func3 == 0b001): #bne
            m.d.comb += self.br_out.eq(self.op1 != self.op2)

        with m.Elif(self.func3 == 0b100): #blt (signed comparison)
            m.d.comb += self.br_out.eq(self.op1.as_signed() < self.op2.as_signed())

        with m.Elif(self.func3 == 0b101): #bge (signed comparison)
            m.d.comb += self.br_out.eq(self.op1.as_signed() >= self.op2.as_signed())

        with m.Elif(self.func3 == 0b110): #bltu
            m.d.comb += self.br_out.eq(self.op1.as_unsigned() < self.op2.as_unsigned())

        with m.Elif(self.func3 == 0b111): #bgeu
            m.d.comb += self.br_out.eq(self.op1.as_unsigned() >= self.op2.as_unsigned())
    
        return m
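
BLT and BGE compare their operands as signed values, while BLTU and BGEU compare them as unsigned values. The pure-Python sketch below captures the expected RV32I outcome for each `func3` encoding; `branch_taken` is a hypothetical name, and the unused encodings `010` and `011` are assumed to never branch.

```python
MASK32 = 0xFFFFFFFF

def to_signed(x):
    """Interpret a 32-bit pattern as a two's-complement integer."""
    x &= MASK32
    return x - (1 << 32) if x & 0x80000000 else x

def branch_taken(func3, op1, op2):
    """Expected br_out for the given func3 and 32-bit register values."""
    u1, u2 = op1 & MASK32, op2 & MASK32
    s1, s2 = to_signed(op1), to_signed(op2)
    outcomes = {
        0b000: u1 == u2,   # beq
        0b001: u1 != u2,   # bne
        0b100: s1 < s2,    # blt  (signed)
        0b101: s1 >= s2,   # bge  (signed)
        0b110: u1 < u2,    # bltu (unsigned)
        0b111: u1 >= u2,   # bgeu (unsigned)
    }
    return outcomes.get(func3, False)  # assumption: unused encodings never branch
```

The value `0xFFFFFFFF` makes a good test operand: it is -1 for `blt` and the largest possible value for `bltu`.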

### Register File

from amaranth import *

class regfile(Elaboratable):
     def __init__(self):
          self.rs1 = Signal(5)
          self.rs2 = Signal(5)
          self.rd = Signal(5)
          self.rf_out1 = Signal(32)
          self.rf_out2 = Signal(32)
          self.wb_data = Signal(32)
          self.regfile = Memory(width = 32, depth = 32)
          self.we = Signal()
          
     def elaborate(self, platform):
        
        m = Module()
        m.d.comb += self.rf_out1.eq(self.regfile[self.rs1])
        m.d.comb += self.rf_out2.eq(self.regfile[self.rs2])
        # x0 is hard-wired to zero in RISC-V, so writes to register 0 are ignored
        with m.If(self.we & (self.rd != 0)):
            m.d.sync += self.regfile[self.rd].eq(self.wb_data)

        return m
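
RV32I defines 32 general purpose registers, with `x0` always reading as zero. A small pure-Python model of that contract can serve as a reference when checking register values in simulation. `RegFileModel` is a hypothetical class written for this example, and it assumes that writes to `x0` are silently discarded.

```python
class RegFileModel:
    """Reference model: 32 registers of 32 bits, x0 hard-wired to zero."""

    def __init__(self):
        self.regs = [0] * 32

    def read(self, index):
        """Combinational read of one register."""
        return self.regs[index & 0x1F]

    def write(self, index, value, we=True):
        """Clocked write; ignored when we is low or the target is x0."""
        index &= 0x1F
        if we and index != 0:
            self.regs[index] = value & 0xFFFFFFFF

rf = RegFileModel()
rf.write(1, 0x1_0000_0005)   # value is truncated to 32 bits
rf.write(0, 0xDEADBEEF)      # discarded: x0 stays zero
rf.write(2, 7, we=False)     # discarded: write enable is low
```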


## Testing

 A testbench for RISC-V Thunder Core is available [here](https://github.com/merledu/rv-thunder/tree/main/test).

 ## About Amaranth HDL

 Amaranth HDL (previously known as nMigen) is a Python-based hardware description language for designing digital circuits and systems. It lets hardware engineers specify the behavior and structure of a digital design, which can then be translated into Verilog or synthesized into hardware with tools like [Yosys](https://github.com/YosysHQ/yosys). It is used for FPGA and ASIC design.
 
 ## Prerequisites

 Before working on this project, ensure you have the following prerequisites:
 * The Python library [Amaranth HDL](https://amaranth-lang.org/docs/amaranth/latest/) 
 * [iVerilog](https://github.com/steveicarus/iverilog)
 * [GTKWave](https://gtkwave.sourceforge.net/) 

 ## Amaranth HDL docs

 Install Amaranth HDL and the supporting tools (GTKWave, etc.) by following the installation instructions in the [Language guide](https://amaranth-lang.org/docs/amaranth/latest/), which also describe how to clone the Git repository.
 
  For a basic understanding of Amaranth HDL, see [Robert Baruch's introduction](https://github.com/RobertBaruch/nmigen-tutorial).

 ## Acknowledgement
 We want to express our gratitude to the RISC-V community for their valuable contributions to the open-source hardware ecosystem. Additionally, thanks to the Amaranth HDL developers for providing a platform for hardware design.
 
 The project repository is available at [rv-thunder](https://github.com/merledu/rv-thunder).