# Getting set up

In [None]:
#@title Setting up the notebook

### Installing dependencies
!pip install openai tiktoken

!apt-get update
!apt-get install -y iverilog

In [None]:
import os
from google.colab import userdata
# Read the key from Colab's secret store instead of embedding it in the
# notebook. The key that used to be hardcoded here has been exposed to anyone
# with access to this file and must be revoked in the OpenAI dashboard.
os.environ["OPENAI_API_KEY"] = userdata.get('openai_api_key')

In [None]:
#@title Utility functions

import getopt
import json
import os
import re
import shlex
import subprocess
import sys
from abc import ABC, abstractmethod

import openai
import tiktoken

################################################################################
### LOGGING
################################################################################
# Allows us to log the output of the model to a file if logging is enabled
class LogStdoutToFile:
    """Context manager that redirects ``sys.stdout`` to a file while active.

    Args:
        filename: Path of the log file to (over)write. A falsy value disables
            the redirection, making the context manager a no-op.

    On exit the stream that was installed when the block was entered is put
    back, and only the file opened by this object is closed.
    """

    def __init__(self, filename):
        self._filename = filename
        self._original_stdout = sys.stdout
        self._file = None

    def __enter__(self):
        # Snapshot at entry, not construction: the active stream may have been
        # swapped (e.g. by the notebook or another redirect) in between.
        self._original_stdout = sys.stdout
        if self._filename:
            self._file = open(self._filename, 'w')
            sys.stdout = self._file
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore first so a failing close() cannot leave stdout pointing at a
        # closed file.
        sys.stdout = self._original_stdout
        if self._file is not None:
            self._file.close()
            self._file = None

################################################################################
### CONFIG & ARGS
################################################################################
def load_config(config_file="config.json"):
    """Read the JSON config file and split it into its two sections.

    Args:
        config_file: Path of the JSON file to read.

    Returns:
        A ``(general, ensemble)`` tuple. ``general`` is the content of the
        ``"general"`` section. ``ensemble`` is the top-level ``"ensemble"``
        section when ``general["ensemble"]`` is truthy, otherwise ``{}``.

    Raises:
        ValueError: If the file has no ``"general"`` section.
    """
    with open(config_file, 'r') as handle:
        parsed = json.load(handle)

    if 'general' not in parsed:
        raise ValueError("Missing general section in config file")
    general = parsed['general']

    # The ensemble section is only honoured when the general section opts in.
    if general.get('ensemble', False):
        ensemble = parsed.get('ensemble', {})
    else:
        ensemble = {}

    return general, ensemble


def validate_ensemble_config(ensemble_config, max_iterations):
    """Validate an ensemble section and normalise its start iterations.

    Args:
        ensemble_config: Mapping of model name to a dict with the keys
            ``start_iteration``, ``model_family`` and ``model_id``.
        max_iterations: Highest iteration index of the run. Negative
            ``start_iteration`` values count back from ``max_iterations + 1``
            (so ``-1`` means ``max_iterations``).

    Returns:
        A new mapping with the same model names whose ``start_iteration``
        values are all non-negative. An empty input yields an empty mapping.

    Raises:
        ValueError: If a start iteration is out of range, if two models share
            a start iteration, or if a non-empty ensemble has no model
            starting at iteration 0.
    """
    seen_start_iterations = set()
    adjusted_config = {}

    for model_name, model_info in ensemble_config.items():
        start_iteration = model_info['start_iteration']

        # Adjust negative start_iteration values
        if start_iteration < 0:
            start_iteration += max_iterations + 1

        # Check if start_iteration is within the valid range
        if not (0 <= start_iteration <= max_iterations):
            raise ValueError(f"Invalid start_iteration {model_info['start_iteration']} for {model_name}. "
                             f"Must be within the range of 0 to {max_iterations} or valid negative index.")

        # Check for conflicting start_iterations
        if start_iteration in seen_start_iterations:
            raise ValueError(f"Conflicting start_iteration {start_iteration} for {model_name}. "
                             f"Another model already uses this start iteration.")
        seen_start_iterations.add(start_iteration)

        adjusted_config[model_name] = {
            "start_iteration": start_iteration,
            "model_family": model_info['model_family'],
            "model_id": model_info['model_id']
        }

    # This check must run once, after every entry has been seen; doing it per
    # entry rejects valid ensembles whose first listed model starts later.
    if adjusted_config and 0 not in seen_start_iterations:
        raise ValueError("No model starting at iteration 0 in the ensemble. One model must start at iteration 0.")

    return adjusted_config


def parse_args_and_config():
    """Load ``config.json``, check required keys and fill in defaults.

    No command-line arguments are read; the usage text below is only appended
    to the error raised for a missing key.

    Returns:
        A ``(config_values, ensemble_config, logfile)`` tuple, where
        ``logfile`` is ``outdir`` joined with ``log``, or ``None`` when ``log``
        is empty.

    Raises:
        ValueError: If a required key is missing from the general section, or
            if the ensemble section fails validation.
    """
    usage = """Usage: auto_create_verilog.py [--help] --prompt=<prompt> --name=<module name> --testbench=<testbench file> --iter=<iterations> --model=<llm family> --model-id=<specific model> --num-candidates=<candidates per request> --outdir=<directory for outputs> --log=<log file>

	-h|--help: Prints this usage message

	-p|--prompt: The initial design prompt for the Verilog module

	-n|--name: The module name, must match the testbench expected module name

	-t|--testbench: The testbench file to be run

	-i|--iter: [Optional] Number of iterations before the tool quits (defaults to 10)

	-m|--model: The LLM family to use. Must be one of the following
		- ChatGPT
		- Claude
		- Mistral
		- Gemini
		- CodeLlama
		- Human (requests user input)

	--model-id: The specific model to use for the model family

	--num-candidates: The number of candidates to rank per tree level

	-o|--outdir: Directory to place all run-specific files in

	-l|--log: [Optional] Log the output of the model to the given file
"""

    general, ensemble = load_config("config.json")

    # A single-model run must name its model; an ensemble names models itself.
    mandatory = ['prompt', 'name', 'testbench', 'outdir', 'log']
    if not ensemble:
        mandatory += ['model_family', 'model_id']

    for key in mandatory:
        if key not in general:
            raise ValueError(f"Missing {key} in general section\n{usage}")

    # Defaults for the optional settings
    general.setdefault('num_candidates', 1)
    general.setdefault('iterations', 10)

    if ensemble:
        ensemble = validate_ensemble_config(ensemble, general['iterations'])

    # Ensure outdir exists
    outdir = general['outdir']
    if outdir:
        os.makedirs(outdir, exist_ok=True)

    if general['log']:
        logfile = os.path.join(outdir, general['log'])
    else:
        logfile = None

    return general, ensemble, logfile



################################################################################
### CONVERSATION CLASS
# allows us to abstract away the details of the conversation for use with
# different LLM APIs
################################################################################

class Conversation:
    """Ordered list of chat messages, optionally mirrored to a log file.

    Each message is stored as ``{'role': ..., 'content': ...}``.

    Args:
        log_file: Optional path. If the file already exists it is truncated
            on construction; every added message is then appended to it.
    """

    def __init__(self, log_file=None):
        self.messages = []
        self.log_file = log_file

        # Start from an empty log so runs do not interleave.
        if self.log_file and os.path.exists(self.log_file):
            with open(self.log_file, 'w'):
                pass

    def _in_range(self, index):
        """Return True when ``index`` (positive or negative) addresses a message."""
        return -len(self.messages) <= index < len(self.messages)

    def add_message(self, role, content):
        """Add a new message to the conversation."""
        self.messages.append({'role': role, 'content': content})

        if self.log_file:
            with open(self.log_file, 'a') as file:
                file.write(f"{role}: {content}\n")

    def get_messages(self):
        """Retrieve the entire conversation."""
        return self.messages

    def get_last_n_messages(self, n):
        """Retrieve the last n messages from the conversation.

        ``n <= 0`` yields an empty list; a bare ``messages[-0:]`` slice would
        return the whole history instead.
        """
        if n <= 0:
            return []
        return self.messages[-n:]

    def remove_message(self, index):
        """Remove a specific message by index; out-of-range indices are ignored."""
        if self._in_range(index):
            del self.messages[index]

    def get_message(self, index=-1):
        """Retrieve a specific message by index (default: the last one).

        Returns:
            The message dict, or ``None`` if ``index`` is out of range.
        """
        return self.messages[index] if self._in_range(index) else None

    def clear_messages(self):
        """Clear all messages from the conversation."""
        self.messages = []

    def __str__(self):
        """Return the conversation in a string format."""
        return "\n".join([f"{msg['role']}: {msg['content']}" for msg in self.messages])

################################################################################
### LLM CLASSES
# Defines an interface for using different LLMs so we can easily swap them out
################################################################################
class AbstractLLM(ABC):
    """Common interface implemented by every language-model backend."""

    def __init__(self):
        """Nothing to set up here; subclasses create their own clients."""

    @abstractmethod
    def generate(self, conversation: Conversation, num_candidates=1):
        """Produce ``num_candidates`` replies to the given conversation."""


class ChatGPT(AbstractLLM):
    """Backend that sends the conversation to the OpenAI chat completions API."""

    def __init__(self, model_id="gpt-4o-mini"):
        """Create the OpenAI client.

        Args:
            model_id: Name of the OpenAI model to request.

        Raises:
            KeyError: If ``OPENAI_API_KEY`` is not set in the environment.
        """
        super().__init__()
        openai.api_key = os.environ['OPENAI_API_KEY']
        self.client = openai.OpenAI()
        self.model_id = model_id

    def generate(self, conversation: Conversation, num_candidates=1):
        """Request ``num_candidates`` completions and return their texts.

        Returns:
            A list with the message content of each returned choice.
        """
        # Forward only role and content of each stored message.
        payload = []
        for entry in conversation.get_messages():
            payload.append({"role": entry["role"], "content": entry["content"]})

        completion = self.client.chat.completions.create(
            model=self.model_id,
            n=num_candidates,
            messages=payload,
            temperature=0.1,
        )

        replies = []
        for choice in completion.choices:
            replies.append(choice.message.content)
        return replies

class LLMResponse():
    """One candidate reply from the LLM together with its evaluation result.

    ``rank`` orders candidates (higher is better). Values assigned by this
    class:
        -3    not evaluated yet
        -2    compilation timed out
        -1    compilation failed
        -0.5  compiled with warnings
         0.0  compiled, but the simulation failed or timed out
         1.0  testbench passed
    """
    def __init__(self, iteration, response_num, full_text):
        self.iteration = iteration
        self.response_num = response_num

        self.full_text = full_text
        self.tokens = 0

        self.parsed_text = ""
        self.parsed_length = 0

        self.feedback = ""
        self.compiled = False
        self.rank = -3
        self.message = ""

    def set_parsed_text(self, parsed_text):
        """Store already-extracted Verilog text and update its cached length."""
        self.parsed_text = parsed_text
        self.parsed_length = len(parsed_text)

    def parse_verilog(self):
        """Extract every Verilog module from ``full_text`` into ``parsed_text``.

        The result replaces any previous ``parsed_text``, so calling this
        twice does not duplicate the modules.
        """
        module_list = find_verilog_modules(self.full_text)
        if not module_list:
            print("No modules found in response")
        self.parsed_text = "".join(module + "\n\n" for module in module_list)
        self.parsed_length = len(self.parsed_text)

    def calculate_rank(self, outdir, module, testbench, top_module="tb_sequence_detector"):
        """Compile and simulate the parsed module, then set rank and feedback.

        Args:
            outdir: Directory receiving ``<module>.sv`` and ``<module>.vvp``.
            module: Module name, used for the generated file names.
            testbench: Path of the testbench source passed to iverilog.
            top_module: Name of the testbench top-level module given to
                ``iverilog -s``.

        Side effects:
            Updates ``rank``, ``compiled``, ``feedback`` and ``message``, and
            prints the tool output.
        """
        filename = os.path.join(outdir, module + ".sv")
        vvp_file = os.path.join(outdir, module + ".vvp")

        # The commands run through a shell, so quote every path to survive
        # spaces and shell metacharacters.
        compiler_cmd = (
            "iverilog -Wall -Winfloop -Wno-timescale -g2012 "
            f"-s {shlex.quote(top_module)} -o {shlex.quote(vvp_file)} "
            f"{shlex.quote(filename)} {shlex.quote(testbench)}"
        )
        simulator_cmd = f"vvp -n {shlex.quote(vvp_file)}"

        try:
            comp_return, comp_err, _ = compile_iverilog(outdir, module, compiler_cmd, self)
        except ValueError as e:
            print(e)
            self.rank = -2
            return

        # -----------------------------
        # Compilation handling
        # -----------------------------
        if comp_return != 0:
            self.feedback = comp_err
            self.compiled = False
            print("Compilation error")
            print(comp_err)
            self.message = (
                "The design failed to compile. Please fix the module. "
                "The output of iverilog is as follows:\n" + comp_err
            )
            self.rank = -1
            return

        if comp_err != "":
            self.feedback = comp_err
            self.compiled = True
            print("Compilation warning")
            print(comp_err)
            self.message = (
                "The design compiled with warnings. Please fix the module. "
                "The output of iverilog is as follows:\n" + comp_err
            )
            self.rank = -0.5
            return

        # -----------------------------
        # Simulation handling
        # -----------------------------
        try:
            _, _, sim_out = simulate_iverilog(simulator_cmd)
        except ValueError as e:
            # A hung simulation must not abort the whole run; score it like
            # any other failed simulation and report the timeout as feedback.
            print(e)
            self.compiled = True
            self.feedback = str(e)
            self.message = (
                "The design compiled, but the simulation did not finish. "
                "Please fix the module. Output:\n" + str(e)
            )
            self.rank = 0.0
            return

        print("Simulation output:")
        print(sim_out)

        # Success case (ChipChat testbench style)
        if "All test cases passed." in sim_out:
            self.compiled = True
            self.feedback = ""
            self.message = "The testbench completed successfully."
            self.rank = 1.0
            print("Testbench ran successfully")
            return

        # Failure case
        self.compiled = True
        self.feedback = sim_out
        self.message = (
            "The testbench simulated, but had errors. "
            "Please fix the module. Output:\n" + sim_out
        )
        self.rank = 0.0
        print("Simulation failed")


################################################################################
### PARSING AND TEXT MANIPULATION FUNCTIONS
################################################################################
# Price per one million tokens for each model-family key understood by
# calculate_cost(). verilog_loop() maps model ids onto these keys:
# "gpt-4o" -> GPT4, "gpt-4o-mini" -> GPT4M, "gpt-3.5-turbo" -> GPT, and the
# Claude family -> CLAUDE.
# NOTE(review): costs are printed with a "$" prefix, so these are presumably
# USD; the figures are hardcoded - confirm against current provider pricing.
COST_PER_MILLION_INPUT_TOKENS_GPT4 = 5.0
COST_PER_MILLION_OUTPUT_TOKENS_GPT4 = 15.0

COST_PER_MILLION_INPUT_TOKENS_GPT4M = 0.15
COST_PER_MILLION_OUTPUT_TOKENS_GPT4M = 0.60

COST_PER_MILLION_INPUT_TOKENS_GPT = 0.50
COST_PER_MILLION_OUTPUT_TOKENS_GPT = 1.50

COST_PER_MILLION_INPUT_TOKENS_CLAUDE = 0.25
COST_PER_MILLION_OUTPUT_TOKENS_CLAUDE = 1.25

# Function to count tokens
def count_tokens(model_family, text):
    """Count the tokens of ``text`` for the given model family.

    Args:
        model_family: ``"GPT"``, ``"GPT4"`` or ``"GPT4M"`` (counted with the
            ``cl100k_base`` tiktoken encoding) or ``"claude"``.
        text: The string to tokenize.

    Returns:
        The number of tokens as an int.

    Raises:
        ValueError: If ``model_family`` is not one of the supported keys.
        ImportError: If ``"claude"`` is requested but the ``anthropic``
            package is not installed.
    """
    if model_family in ("GPT", "GPT4", "GPT4M"):
        return len(tiktoken.get_encoding("cl100k_base").encode(text))
    if model_family == "claude":
        # The notebook never imports anthropic at module level, so the name
        # was undefined here. Import lazily to keep it an optional dependency.
        # NOTE(review): confirm the installed SDK still provides
        # Client.count_tokens.
        import anthropic
        return anthropic.Client().count_tokens(text)
    raise ValueError(f"Unsupported model family: {model_family}")


def calculate_cost(model_family,input_strings,output_strings):
    """Estimate the cost of a request from its prompt and reply texts.

    Args:
        model_family: ``"GPT"``, ``"GPT4"``, ``"GPT4M"`` or ``"claude"``.
        input_strings: Iterable of texts billed as input tokens.
        output_strings: Iterable of texts billed as output tokens.

    Returns:
        ``(total_cost, input_tokens, output_tokens)``.

    Raises:
        ValueError: If ``model_family`` is not supported.
    """
    input_tokens = sum(count_tokens(model_family, chunk) for chunk in input_strings)
    output_tokens = sum(count_tokens(model_family, chunk) for chunk in output_strings)

    # (family key, input rate, output rate) - rates are per million tokens.
    price_table = (
        ("GPT", COST_PER_MILLION_INPUT_TOKENS_GPT, COST_PER_MILLION_OUTPUT_TOKENS_GPT),
        ("GPT4", COST_PER_MILLION_INPUT_TOKENS_GPT4, COST_PER_MILLION_OUTPUT_TOKENS_GPT4),
        ("GPT4M", COST_PER_MILLION_INPUT_TOKENS_GPT4M, COST_PER_MILLION_OUTPUT_TOKENS_GPT4M),
        ("claude", COST_PER_MILLION_INPUT_TOKENS_CLAUDE, COST_PER_MILLION_OUTPUT_TOKENS_CLAUDE),
    )

    for family, input_rate, output_rate in price_table:
        if model_family == family:
            cost_input = (input_tokens / 1_000_000) * input_rate
            cost_output = (output_tokens / 1_000_000) * output_rate
            break
    else:
        raise ValueError(f"Unsupported model family: {model_family}")

    return cost_input + cost_output, input_tokens, output_tokens


def format_message(role, content):
    """Render one message as a dict-like log line preceded by a newline."""
    template = "\n{{role : '{}', content : '{}'}}"
    return template.format(role, content)

def find_verilog_modules(markdown_string):
    """Return every ``module ... endmodule`` definition found in the text.

    Modules with and without a ``#( ... )`` parameter list are matched.
    Markdown-escaped underscores (``\\_``) in each match are turned back into
    plain underscores.
    """
    module_regex = re.compile(
        r'\bmodule\b\s+[\w\\_]+\s*(?:#\s*\([^)]*\))?\s*\([^)]*\)\s*;.*?endmodule\b',
        re.DOTALL,
    )

    cleaned = []
    for found in module_regex.findall(markdown_string):
        cleaned.append(found.replace('\\_', '_'))
    return cleaned

def write_code_blocks_to_file(markdown_string, module_name, filename):
    """Write every Verilog module found in ``markdown_string`` to ``filename``.

    Args:
        markdown_string: Text to search for ``module ... endmodule`` blocks.
        module_name: Unused; kept so existing callers keep working.
        filename: Destination file, overwritten if it exists.

    Raises:
        ValueError: If the text contains no Verilog module. An exception
            (rather than ``exit``) lets the caller score the response instead
            of tearing down the interpreter or notebook kernel.
    """
    modules = find_verilog_modules(markdown_string)

    if not modules:
        print("No code blocks found in response")
        raise ValueError("No code blocks found in response")

    # One module per chunk, each followed by a newline.
    with open(filename, 'w') as file:
        file.writelines(code_block + '\n' for code_block in modules)


def generate_verilog(conv, model_type, model_id=""):
    """Generate a reply to ``conv`` with the requested model.

    Args:
        conv: The conversation to send.
        model_type: Model family; only ``"ChatGPT"`` is supported.
        model_id: Specific model to use. When empty, the backend's own
            default model is used.

    Returns:
        Whatever the backend's ``generate`` returns for a single candidate.

    Raises:
        ValueError: If ``model_type`` is not supported.
    """
    if model_type == "ChatGPT":
        # Honour the caller's model_id instead of always using the default.
        model = ChatGPT(model_id) if model_id else ChatGPT()
    else:
        raise ValueError("Invalid model type")
    return model.generate(conv)

def _run_with_retries(cmd, timeout_message, attempts=3, timeout=120):
    """Run a shell command, retrying when it exceeds ``timeout`` seconds.

    Returns:
        ``(returncode, stderr, stdout)`` of the first run that finishes.

    Raises:
        ValueError: With ``timeout_message`` if every attempt timed out.
    """
    for _ in range(attempts):
        try:
            proc = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, timeout=timeout)
        except subprocess.TimeoutExpired:
            continue
        return proc.returncode, proc.stderr, proc.stdout
    raise ValueError(timeout_message)


def compile_iverilog(outdir,module,compiler_cmd,response:"LLMResponse"):
    """Write the response's Verilog to ``<outdir>/<module>.sv`` and compile it.

    Args:
        outdir: Directory that receives the generated source file.
        module: Module name used for the file name.
        compiler_cmd: Complete shell command line that runs the compiler.
        response: Response whose ``parsed_text`` holds the Verilog to write.

    Returns:
        ``(returncode, stderr, stdout)`` of the compiler.

    Raises:
        ValueError: If all three compilation attempts time out.
    """
    filename = os.path.join(outdir,module+".sv")
    write_code_blocks_to_file(response.parsed_text, "module", filename)

    return _run_with_retries(compiler_cmd, "Compilation attempts timed out")


def simulate_iverilog(simulation_cmd):
    """Run the simulation command and return its output.

    Args:
        simulation_cmd: Complete shell command line that runs the simulator.

    Returns:
        ``(returncode, stderr, stdout)`` of the simulator.

    Raises:
        ValueError: If all three simulation attempts time out.
    """
    return _run_with_retries(simulation_cmd, "Simulation attempts timed out")

def generate_verilog_responses(conv, model_type, model_id="", num_candidates=1):
    """Ask the selected backend for ``num_candidates`` replies to ``conv``.

    Raises:
        ValueError: If ``model_type`` is anything other than ``"ChatGPT"``.
    """
    if model_type == "ChatGPT":
        backend = ChatGPT(model_id)
    else:
        raise ValueError("Invalid model type")

    return backend.generate(conversation=conv, num_candidates=num_candidates)

def get_iteration_ensemble(iteration, ensemble_config):
    """Pick the ensemble model that is active at the given iteration.

    The active model is the one with the largest ``start_iteration`` that is
    not greater than ``iteration``.

    Returns:
        ``(model_family, model_id)``, or ``(None, None)`` if no entry has
        started yet.
    """
    latest_first = sorted(
        ensemble_config.values(),
        key=lambda entry: entry['start_iteration'],
        reverse=True,
    )

    active = next(
        (entry for entry in latest_first if iteration >= entry['start_iteration']),
        None,
    )
    if active is None:
        return None, None
    return active['model_family'], active['model_id']


In [None]:
import subprocess
import sys
import os

def verilog_loop(design_prompt, module, testbench, max_iterations, model_type, model_id="", num_candidates=5, outdir="", log=None, ensemble_config=None):
    """Iteratively ask an LLM for a Verilog module until the testbench passes.

    Every iteration requests ``num_candidates`` replies, evaluates each one
    with ``LLMResponse.calculate_rank``, keeps the best-ranked reply in the
    conversation and, unless it passed, sends its tool output back as
    feedback. Per-response logs go to ``<outdir>/iter<i>/response<j>/log.txt``.

    Args:
        design_prompt: Design specification sent as the first user message.
        module: Module name, used for generated file names.
        testbench: Path of the testbench file; its text is appended to the
            prompt and it is compiled with each candidate.
        max_iterations: The loop stops after the iteration whose index
            reaches this value, so at most ``max_iterations + 1`` rounds run.
        model_type: Model family used when no ensemble is configured.
        model_id: Specific model used when no ensemble is configured.
        num_candidates: Number of replies requested per iteration.
        outdir: Base directory for per-response output.
        log: Optional conversation log file.
        ensemble_config: Optional validated ensemble mapping; when non-empty
            it selects the model for each iteration.

    Returns:
        The best ``LLMResponse`` seen across all iterations.
    """
    # None instead of a shared mutable default dict.
    if ensemble_config is None:
        ensemble_config = {}

    if outdir != "":
        outdir = outdir + "/"

    conv = Conversation(log_file=log)

    conv.add_message("system", """
    You are an expert Verilog RTL engineer.

    Strict rules:
    - Use plain Verilog-2001 only.
    - Do NOT use typedef.
    - Do NOT use enum.
    - Do NOT use logic type.
    - Use reg and wire only.
    - Use parameter for state encoding.
    - Use posedge clk.
    - Active-low synchronous reset.
    - Fully synthesizable.

    You must implement an 8-state deterministic FSM.

    Detect exactly this sequence:

    001 → 101 → 110 → 000 → 110 → 110 → 011 → 101

    CRITICAL:
    The output sequence_found must be asserted
    in the SAME clock cycle that the final 3'b101 is received.

    It must NOT be delayed to the next clock cycle.

    Return ONLY one complete Verilog module inside triple backticks.
    """)

    with open(testbench, 'r') as file:
        testbench_text = file.read()

    full_prompt = design_prompt + "\n\nThe module will be tested with the following testbench:\n\n" + testbench_text

    conv.add_message("user", full_prompt)

    success = False
    timeout = False

    iterations = 0

    # Placeholder that any evaluated response outranks.
    global_max_response = LLMResponse(-3,-3,"")

    while not (success or timeout):

        if ensemble_config:
            print("Getting model from ensemble")
            model_type, model_id = get_iteration_ensemble(iterations, ensemble_config)

        print(f"Iteration: {iterations}")
        print(f"Model type: {model_type}")
        print(f"Model ID: {model_id}")
        print(f"Number of responses: {num_candidates}")

        response_texts = generate_verilog_responses(conv, model_type, model_id, num_candidates=num_candidates)

        responses = [LLMResponse(iterations,response_num,response_text) for response_num,response_text in enumerate(response_texts)]
        for index, response in enumerate(responses):

            response_outdir = os.path.join(outdir, f"iter{str(iterations)}/response{index}/")
            os.makedirs(response_outdir, exist_ok=True)

            # Stay at zero for models without a pricing entry below.
            response_cost = 0
            input_tokens = 0
            output_tokens = 0

            response.parse_verilog()
            if response.parsed_text == "":
                response.rank = -2
                response.message = "No modules found in response"
            else:
                response.calculate_rank(response_outdir, module, testbench)

            input_messages = [msg['content'] for msg in conv.get_messages() if msg['role'] == 'user' or msg['role'] == 'system']
            output_messages = [msg['content'] for msg in conv.get_messages() if msg['role'] == 'assistant']
            output_messages.append(response.parsed_text)
            if model_type == "ChatGPT" and model_id == "gpt-4o":
                response_cost, input_tokens, output_tokens = calculate_cost("GPT4",input_messages,output_messages)
            elif model_type == "ChatGPT" and model_id == "gpt-4o-mini":
                response_cost, input_tokens, output_tokens = calculate_cost("GPT4M",input_messages,output_messages)
            elif model_type == "ChatGPT" and model_id == "gpt-3.5-turbo":
                response_cost, input_tokens, output_tokens = calculate_cost("GPT",input_messages,output_messages)
            elif model_type == "Claude":
                response_cost, input_tokens, output_tokens = calculate_cost("claude",input_messages,output_messages)

            print(f"Cost for response {index}: ${response_cost:.10f}")

            with open(os.path.join(response_outdir, "log.txt"), 'w') as file:
                file.write('\n'.join(str(i) for i in conv.get_messages()))
                file.write(format_message("assistant", response.full_text))
                file.write('\n\n Iteration rank: ' + str(response.rank) + '\n')

                file.write(f"\n Model: {model_id}")
                file.write(f"\n Input tokens: {input_tokens}")
                file.write(f"\n Output tokens: {output_tokens}")
                file.write(f"\nTotal cost: ${response_cost:.10f}\n")

        ## RANK RESPONSES
        # Within an iteration: highest rank wins, shorter code breaks ties.
        max_rank_response = max(responses, key=lambda resp: (resp.rank, -resp.parsed_length))
        # NOTE(review): across iterations a tie is broken towards LONGER code,
        # the opposite of the per-iteration rule above - confirm the intent.
        if max_rank_response.rank > global_max_response.rank:
            global_max_response = max_rank_response
        elif max_rank_response.rank == global_max_response.rank and max_rank_response.parsed_length > global_max_response.parsed_length:
            global_max_response = max_rank_response

        print(f"Response ranks: {[resp.rank for resp in responses]}")
        print(f"Response lengths: {[resp.parsed_length for resp in responses]}")

        conv.add_message("assistant", max_rank_response.parsed_text)

        if max_rank_response.rank == 1:
            success = True

        if not success:
            # Keep only the latest attempt: drop the previous assistant reply
            # and the feedback that followed it (both sit at index 2).
            if iterations > 0:
                conv.remove_message(2)
                conv.remove_message(2)

            # Some failures (e.g. no module found in the reply) leave
            # `feedback` empty; fall back to the human-readable message so the
            # model is always told what went wrong.
            failure_details = max_rank_response.feedback or max_rank_response.message

            feedback_message = f"""
            The previous FSM failed simulation.

            Simulation error:
            {failure_details}

            The FSM asserted sequence_found too early
            or did not complete the full 8-cycle sequence.

            Fix the FSM so that:
            - It advances through exactly 8 states.
            - It only asserts on the final state transition.
            - It returns to S0 on any mismatch.

            Return the full corrected module.
            """

            conv.add_message("user", feedback_message)

        if iterations >= max_iterations:
            timeout = True

        iterations += 1

    return global_max_response



In [None]:
# Run configuration. A later cell writes it to config.json, which
# parse_args_and_config() reads back and validates.
# - "iterations" and "num_candidates" are optional there (defaults 10 and 1).
# - "prompt" must be present to pass validation, but the run cell passes the
#   inline `design_prompt` string to verilog_loop() rather than this value.
config = {
  "general": {
    "name": "sequence_detector",
    "prompt": "prompt.txt",
    "testbench": "sequence_detector_tb.v",
    "iterations": 6,
    "num_candidates": 3,
    "outdir": "out",
    "log": "trajectory.log",
    "model_family": "ChatGPT",
    "model_id": "gpt-4o-mini"
  }
}


In [None]:
# Serialize once so the file content and the echoed text are the same string.
serialized_config = json.dumps(config, indent=2)

with open("config.json", "w") as f:
    f.write(serialized_config)

print(serialized_config)


In [None]:
# Design specification for the module under test. verilog_loop() sends it as
# the first user message, followed by the full text of the testbench.
design_prompt = """
Implement module sequence_detector.

Interface:

module sequence_detector(
    input  wire clk,
    input  wire reset_n,
    input  wire [2:0] data,
    output reg sequence_found
);

Design an 8-state FSM (S0–S7).

Transitions:
S0: expect 001
S1: expect 101
S2: expect 110
S3: expect 000
S4: expect 110
S5: expect 110
S6: expect 011
S7: expect 101

Only when the final 101 is received in S7,
sequence_found must be 1 for one cycle.

All other cycles must output 0.
Return to S0 on mismatch.
"""


In [None]:
!curl -L -o sequence_detector_tb.v \
https://raw.githubusercontent.com/FCHXWH823/LLM4ChipDesign/main/VerilogGenBenchmark/TestBench/sequence_detector_tb.v

with open("sequence_detector_tb.v") as f:
    print(f.read())


In [None]:
config_values, ensemble_config, logfile = parse_args_and_config()

# model_family / model_id are only mandatory when no ensemble is configured,
# so read them with .get(): indexing would raise KeyError for an ensemble-only
# config. With an ensemble, verilog_loop() chooses the model per iteration.
result = verilog_loop(
    design_prompt=design_prompt,
    module=config_values["name"],
    testbench=config_values["testbench"],
    max_iterations=config_values["iterations"],
    model_type=config_values.get("model_family"),
    model_id=config_values.get("model_id", ""),
    num_candidates=config_values["num_candidates"],
    outdir=config_values["outdir"],
    log=logfile,
    ensemble_config=ensemble_config
)

print("Final Best Rank:", result.rank)