In [1]:
import torch
# Seed PyTorch's global RNG so random tensors created later in the notebook
# (e.g. via torch.randn) are repeatable across runs.
torch.manual_seed(0)

from chop.ir.graph.mase_graph import MaseGraph

# Analysis passes used below to attach metadata to the graph and report on it.
from chop.passes.graph.analysis import (
    init_metadata_analysis_pass,
    add_common_metadata_analysis_pass,
    add_hardware_metadata_analysis_pass,
    report_node_type_analysis_pass,
)

# Transform passes used below for quantization and hardware/testbench emission.
from chop.passes.graph.transforms import (
    emit_verilog_top_transform_pass,
    emit_internal_rtl_transform_pass,
    emit_bram_transform_pass,
    emit_cocotb_transform_pass,
    quantize_transform_pass,
)

from chop.tools.logger import set_logging_verbosity

# Debug-level logging: the passes below print the traced graph and every emitted file.
set_logging_verbosity("debug")

import toml
# NOTE(review): `torch` is already imported at the top of this cell; this repeat is harmless.
import torch
import torch.nn as nn
import os

# Machine-specific override, left disabled: uncomment and adapt the path if the
# `verilator` binary is not already on PATH.
# os.environ["PATH"] = "/vol/bitbucket/oa321/verilator/verilator/bin:" + os.environ["PATH"]
# NOTE(review): presumably read by the cocotb-based simulation flow to select the
# test module named "top" -- confirm against the simulate() implementation.
os.environ["MODULE"] = "top"
!verilator --version

  from .autonotebook import tqdm as notebook_tqdm
2025-03-07 15:22:00.252145: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1741360920.266462  123314 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1741360920.270368  123314 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-07 15:22:00.285960: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
[32mINFO    [0m [34mSet logging level to debug[0m


Verilator 5.024 2024-04-05 rev v5.024


In [2]:
class MLP(torch.nn.Module):
    """
    Minimal single-layer network: flatten, Linear(4 -> 8), ReLU.

    Every dimension after the batch dimension is flattened, so each sample
    must contain exactly 4 values. The output has shape (batch, 8) and is
    non-negative because of the final ReLU.
    """

    def __init__(self) -> None:
        super().__init__()

        # The only learnable layer; maps 4 input features to 8 outputs.
        self.fc1 = nn.Linear(in_features=4, out_features=8)

    def forward(self, x):
        # Collapse all non-batch dimensions into a single feature axis.
        flat = torch.flatten(x, start_dim=1, end_dim=-1)
        projected = self.fc1(flat)
        return torch.nn.functional.relu(projected)

In [3]:
from chop.dataset import MaseDataModule, get_dataset_info
from chop.tools.logger import set_logging_verbosity

from chop.passes.graph.interface import save_node_meta_param_interface_pass
from chop.passes.graph.analysis import (
    report_node_meta_param_analysis_pass,
    profile_statistics_analysis_pass,
)
from chop.passes.graph import (
    add_common_metadata_analysis_pass,
    init_metadata_analysis_pass,
    add_software_metadata_analysis_pass,
)
from chop.tools.get_input import InputGenerator
from chop.tools.checkpoint_load import load_model
from chop.ir.graph.mase_graph import MaseGraph

from chop.models import get_model_info, get_model



In [4]:
from chop.dataset.nerf import get_nerf_dataset
from chop.dataset import get_dataset_info
from pathlib import Path

# Identifiers shared by this cell and the ones below; change them here only.
batch_size = 8
model_name = "nerf"
dataset_name = "nerf-lego"

# Look up the registered metadata for the chosen dataset and model by name.
dataset_info = get_dataset_info(dataset_name)
model_info = get_model_info(model_name)

# Build the data module from the identifiers above. `model_name` is passed as the
# variable (rather than a repeated "nerf" literal) so the data module and
# `model_info` cannot silently drift apart when the model is changed.
data_module = MaseDataModule(
    name=dataset_name,
    batch_size=batch_size,
    model_name=model_name,
    num_workers=0,
)

# Prepare and set up the dataset before any dataloader is requested from it.
data_module.prepare_data()
data_module.setup()

In [6]:
# Wrap the data module's "train" dataloader in an iterator of model inputs
# for the "nerf" task.
input_generator = InputGenerator(
    data_module=data_module,
    model_info=model_info,
    task="nerf",
    which_dataloader="train",
)

# Take the first batch as the dummy input; it is indexed like a dict below
# (e.g. dummy_in['rays']).
dummy_in = next(iter(input_generator))

In [7]:
# Sanity check on the sampled batch: shape of the 'rays' tensor
# (the recorded output is torch.Size([8, 8]) with batch_size = 8).
print(dummy_in['rays'].shape)

torch.Size([8, 8])


In [20]:
from chop.models.nerf.nerf_vision import  NeRFVision

# Instantiate the NeRF model with its default constructor arguments and wrap it
# in a MaseGraph; the passes in the following cells all operate on `mg`.
nerf = NeRFVision()
mg = MaseGraph(model=nerf)

# Render the graph for visual inspection.
mg.draw()

In [10]:
# NOTE(review): repeats the earlier shape check on dummy_in['rays']; looks like a
# leftover debugging cell -- consider removing it.
print(dummy_in['rays'].shape)

torch.Size([8, 8])


In [12]:

# Provide a dummy input for the graph so it can be used for tracing.

# The traced graph printed below has a single placeholder named `x`, which is
# split into chunks of [3, 3] along the last dimension -- hence the trailing size of 6.
x = torch.randn((batch_size, 6, 6))
# NOTE(review): this mutates the batch dict obtained from the dataloader in place;
# its existing entries (e.g. 'rays') are kept alongside the new 'x' key.
dummy_in['x'] = x

# Initialise per-node metadata, then populate the common metadata using the dummy input.
mg, _ = init_metadata_analysis_pass(mg, None)
mg, _ = add_common_metadata_analysis_pass(
    mg, {"dummy_in": dummy_in, "add_value": False}
)


[36mDEBUG   [0m [34mgraph():
    %x : [num_users=1] = placeholder[target=x]
    %split : [num_users=2] = call_function[target=torch.functional.split](args = (%x, [3, 3]), kwargs = {dim: -1})
    %getitem : [num_users=2] = call_function[target=operator.getitem](args = (%split, 0), kwargs = {})
    %getitem_1 : [num_users=0] = call_function[target=operator.getitem](args = (%split, 1), kwargs = {})
    %pts_linears_0 : [num_users=1] = call_module[target=pts_linears.0](args = (%getitem,), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.nn.functional.relu](args = (%pts_linears_0,), kwargs = {inplace: False})
    %pts_linears_1 : [num_users=1] = call_module[target=pts_linears.1](args = (%relu,), kwargs = {})
    %relu_1 : [num_users=1] = call_function[target=torch.nn.functional.relu](args = (%pts_linears_1,), kwargs = {inplace: False})
    %pts_linears_2 : [num_users=1] = call_module[target=pts_linears.2](args = (%relu_1,), kwargs = {})
    %relu_2 : [num_users=1] = ca

In [13]:
# Path to the fixed-point quantization config. os.path.abspath("") resolves to the
# current working directory, so this expects `configs/` to live one level above
# the directory the notebook is run from.
config_file = os.path.join(
    os.path.abspath(""),
    "..",
    "configs",
    "tests",
    "quantize",
    "fixed.toml",
)
# Only the [passes.quantize] table of the TOML file is passed to the quantize pass.
with open(config_file, "r") as f:
    quan_args = toml.load(f)["passes"]["quantize"]
mg, _ = quantize_transform_pass(mg, quan_args)

# Log a per-node table of op and value types (see the output below) to check
# which nodes ended up as "fixed".
_ = report_node_type_analysis_pass(mg)

# Update the metadata: mark every dict-valued argument and result entry of every
# node as fixed-point with precision [8, 3]
# (presumably [width, fractional width] -- TODO confirm).
for node in mg.fx_graph.nodes:
    common_meta = node.meta["mase"]["common"]
    # Arguments first, then results, for each node.
    for section in ("args", "results"):
        for entry in common_meta[section].values():
            if not isinstance(entry, dict):
                # Non-dict entries carry no type/precision fields to update.
                continue
            entry["type"] = "fixed"
            entry["precision"] = [8, 3]

[32mINFO    [0m [34mInspecting graph [add_common_node_type_analysis_pass][0m
[32mINFO    [0m [34m
Node name      Fx Node op     Mase type            Mase op      Value type
-------------  -------------  -------------------  -----------  -------------------------------------------------------
x              placeholder    placeholder          placeholder  NA
split          call_function  implicit_func        split        float
getitem        call_function  builtin_func         getitem      float
getitem_1      call_function  builtin_func         getitem      float
pts_linears_0  call_module    module_related_func  linear       fixed
relu           call_function  module_related_func  relu         fixed
pts_linears_1  call_module    module_related_func  linear       fixed
relu_1         call_function  module_related_func  relu         fixed
pts_linears_2  call_module    module_related_func  linear       fixed
relu_2         call_function  module_related_func  relu         fixed
pts

In [19]:
# Attach hardware metadata to every node. `max_parallelism` is passed as [4, 4, 4];
# NOTE(review): presumably an upper bound on per-dimension parallelism -- confirm
# against the pass documentation.
mg, _ = add_hardware_metadata_analysis_pass(mg, pass_args={'max_parallelism': [4] * 3})

In [15]:
from pathlib import  Path

# Emit the Verilog for the graph (logged as "Emitting Verilog...").
mg, _ = emit_verilog_top_transform_pass(mg)
# NOTE(review): emission of the internal RTL is disabled here; confirm the
# simulation step does not depend on it before relying on this flow.
# mg, _ = emit_internal_rtl_transform_pass(mg)

[32mINFO    [0m [34mEmitting Verilog...[0m


In [16]:
# Emit parameter storage: per the debug log below, each node parameter (weight/bias)
# gets a `<node>_<param>_source.sv` ROM module and a `<node>_<param>_rom.dat`
# init file under ~/.mase/top/hardware/rtl/.
mg, _ = emit_bram_transform_pass(mg)

[32mINFO    [0m [34mEmitting BRAM...[0m
[36mDEBUG   [0m [34mEmitting DAT file for node: pts_linears_0, parameter: weight[0m
[36mDEBUG   [0m [34mROM module weight successfully written into /home/omar/.mase/top/hardware/rtl/pts_linears_0_weight_source.sv[0m
[36mDEBUG   [0m [34mInit data weight successfully written into /home/omar/.mase/top/hardware/rtl/pts_linears_0_weight_rom.dat[0m
[36mDEBUG   [0m [34mEmitting DAT file for node: pts_linears_0, parameter: bias[0m
[36mDEBUG   [0m [34mROM module bias successfully written into /home/omar/.mase/top/hardware/rtl/pts_linears_0_bias_source.sv[0m
[36mDEBUG   [0m [34mInit data bias successfully written into /home/omar/.mase/top/hardware/rtl/pts_linears_0_bias_rom.dat[0m
[36mDEBUG   [0m [34mEmitting DAT file for node: pts_linears_1, parameter: weight[0m
[36mDEBUG   [0m [34mROM module weight successfully written into /home/omar/.mase/top/hardware/rtl/pts_linears_1_weight_source.sv[0m
[36mDEBUG   [0m [34mInit 

In [17]:
# Emit the cocotb testbench for the generated design (logged as "Emitting testbench...").
mg, _ = emit_cocotb_transform_pass(mg)

[32mINFO    [0m [34mEmitting testbench...[0m


In [None]:
from chop.actions import simulate

# Run the simulation of the emitted design with neither the build nor the test
# stage skipped. NOTE(review): `waves=True` presumably enables waveform dumping -- confirm.
simulate(skip_build=False, skip_test=False, waves=True)