diff --git a/conftest.py b/conftest.py index e97039e138..b157df43ef 100644 --- a/conftest.py +++ b/conftest.py @@ -1,5 +1,5 @@ import pytest -from magma.circuit import magma_clear_circuit_cache +from magma.circuit import magma_clear_circuit_database from magma import clear_cachedFunctions import magma.backend.coreir_ @@ -8,6 +8,6 @@ def magma_test(): import magma.config magma.config.set_compile_dir('callee_file_dir') - magma_clear_circuit_cache() + magma_clear_circuit_database() clear_cachedFunctions() magma.backend.coreir_.__reset_context() diff --git a/magma/backend/coreir_.py b/magma/backend/coreir_.py index 5e2c76a6b1..dfcfcceb42 100644 --- a/magma/backend/coreir_.py +++ b/magma/backend/coreir_.py @@ -189,7 +189,8 @@ def compile_instance(self, instance, module_definition): lib = self.libs[instance.coreir_lib] logger.debug(instance.name, type(instance)) if instance.coreir_genargs is None: - if hasattr(instance, "wrappedModule"): + if hasattr(instance, "wrappedModule") and \ + instance.wrappedModule.context == self.context: module = instance.wrappedModule else: module = lib.modules[name] @@ -402,7 +403,8 @@ def compile_dependencies(self, defn): continue if key.is_definition: # don't try to compile if already have definition - if hasattr(key, 'wrappedModule'): + if hasattr(key, 'wrappedModule') and \ + key.wrappedModule.context == self.context: self.modules[key.name] = key.wrappedModule else: self.modules[key.name] = self.compile_definition(key) @@ -414,7 +416,8 @@ def compile(self, defn_or_declaration): if defn_or_declaration.is_definition: self.compile_dependencies(defn_or_declaration) # don't try to compile if already have definition - if hasattr(defn_or_declaration, 'wrappedModule'): + if hasattr(defn_or_declaration, 'wrappedModule') and \ + defn_or_declaration.wrappedModule.context == self.context: self.modules[defn_or_declaration.name] = defn_or_declaration.wrappedModule self.libs_used |= defn_or_declaration.coreir_wrapped_modules_libs_used else: diff --git a/magma/circuit.py b/magma/circuit.py index 6830f80fb9..d5c62157af 100644 --- a/magma/circuit.py +++ b/magma/circuit.py @@ -17,6 +17,7 @@ from .logging import warning from .port import report_wiring_warning from .is_definition import isdefinition +from .circuit_database import CircuitDatabase __all__ = ['AnonymousCircuitType'] __all__ += ['AnonymousCircuit'] @@ -26,7 +27,7 @@ __all__ += ['DeclareCircuit'] __all__ += ['DefineCircuit', 'EndDefine', 'EndCircuit'] __all__ += ['getCurrentDefinition'] -__all__ += ['magma_clear_circuit_cache'] +__all__ += ['magma_clear_circuit_database'] __all__ += ['CopyInstance'] __all__ += ['circuit_type_method'] @@ -35,6 +36,11 @@ circuit_type_method = namedtuple('circuit_type_method', ['name', 'definition']) +circuit_database = CircuitDatabase() + +def magma_clear_circuit_database(): + circuit_database.clear() + def circuit_to_html(cls): if isdefinition(cls): # Avoid circular dependency so dot backend can use passes @@ -138,6 +144,25 @@ def __repr__(cls): def _repr_html_(cls): return circuit_to_html(cls) + def rename(cls, new_name): + old_name = cls.name + cls.name = new_name + cls.coreir_name = new_name + cls.verilog_name = new_name + cls.__name__ = new_name + + # NOTE(rsetaluri): This is a very hacky way to try to rename wrapped + # verilog. We simply replace the first instance of "module " + # with "module ". This ignores the possibility of "module + # " existing anywhere else, most likely in comments etc. The + # more robust way to do this would to modify the AST directly and + # generate the new verilog code. + if cls.verilogFile: + find_str = f"module {old_name}" + replace_str = f"module {new_name}" + assert cls.verilogFile.find(find_str) != -1 + cls.verilogFile = cls.verilogFile.replace(find_str, replace_str, 1) + def find(cls, defn): name = cls.__name__ if not isdefinition(cls): @@ -389,12 +414,6 @@ def popDefinition(): else: currentDefinition = None -# a map from circuitDefinition names to circuit definition objects -definitionCache = {} - -def magma_clear_circuit_cache(): - definitionCache.clear() - class DefineCircuitKind(CircuitKind): def __new__(metacls, name, bases, dct): @@ -415,14 +434,6 @@ def __new__(metacls, name, bases, dct): self = CircuitKind.__new__(metacls, name, bases, dct) - if (hasattr(self, 'definition') or dct.get('is_definition', False) \ - or hasattr(self, 'wrappedModule')) and not getattr(self, '__magma_no_cache__', False): - if name in definitionCache: - return definitionCache[name] - else: - #print('creating',name) - definitionCache[name] = self - self.verilog = None self.verilogFile = None self.verilogLib = None @@ -511,8 +522,7 @@ def DefineCircuit(name, *decl, **args): coreir_genargs = args.get('coreir_genargs', None), coreir_configargs = args.get('coreir_configargs', None), default_kwargs = args.get('default_kwargs', {}), - renamed_ports = args.get('renamed_ports', {}), - __magma_no_cache__ = args.get('__magma_no_cache__', False)) + renamed_ports = args.get('renamed_ports', {})) currentDefinition = DefineCircuitKind( name, (Circuit,), dct) return currentDefinition @@ -522,7 +532,10 @@ def EndDefine(): debug_info = get_callee_frame_info() currentDefinition.end_circuit_filename = debug_info[0] currentDefinition.end_circuit_lineno = debug_info[1] + circuit_database.insert(currentDefinition) popDefinition() + else: + raise Exception("EndDefine called without Define/DeclareCircuit") EndCircuit = EndDefine diff --git a/magma/circuit_database.py b/magma/circuit_database.py new file mode 100644 index 0000000000..4dcb7ad017 --- /dev/null +++ b/magma/circuit_database.py @@ -0,0 +1,67 @@ +from abc import ABC, abstractmethod +import tempfile +import uuid +from .compile import compile +import coreir +from .logging import warning + + +class CircuitDatabaseInterface(ABC): + @abstractmethod + def insert(self, circuit): + pass + + @abstractmethod + def clear(self): + pass + + +class ConservativeCircuitDatabase(CircuitDatabaseInterface): + def insert(self, circuit): + name = circuit.name + new_name = name + "-" + str(uuid.uuid4()) + type(circuit).rename(circuit, new_name) + + def clear(self): + pass + + +class CircuitDatabase(CircuitDatabaseInterface): + class Entry: + def __init__(self, name): + self.name = name + self.circuits = {} + + def hash(self, circuit): + with tempfile.TemporaryDirectory() as tempdir: + try: + compile(tempdir + "/circuit", circuit, output="coreir", context=coreir.Context()) + json_str = open(tempdir + "/circuit.json").read() + except Exception as e: + warning(f"Could not compile circuit: '{str(e)}'. Uniquifying anyway.") + json_str = uuid.uuid4() + return hash(json_str) + + def add_circuit(self, circuit): + hash_ = self.hash(circuit) + if hash_ in self.circuits: + index = self.circuits[hash_][0] + else: + index = len(self.circuits) + self.circuits[hash_] = (index, circuit) + if index > 0: + type(circuit).rename(circuit, circuit.name + "_unq" + str(index)) + + def __repr__(self): + return repr(self.circuits) + + def __init__(self): + self.entries = {} + + def insert(self, circuit): + name = circuit.name + entry = self.entries.setdefault(name, CircuitDatabase.Entry(name)) + entry.add_circuit(circuit) + + def clear(self): + self.entries = {} diff --git a/magma/compile.py b/magma/compile.py index e275bf4c7b..455f5c49ad 100644 --- a/magma/compile.py +++ b/magma/compile.py @@ -87,7 +87,8 @@ def __compile_to_coreir(main, file_name, opts): # Underscore so our coreir module doesn't conflict with coreir bindings # package. from .backend import coreir_ - backend = coreir_.CoreIRBackend() + context = opts.get("context", None) + backend = coreir_.CoreIRBackend(context) backend.compile(main) if opts.get("passes", False): backend.context.run_passes(opts["passes"], ["global"]) @@ -131,7 +132,6 @@ def compile(basename, main, output='verilog', **kwargs): opts["output_verilog"] = True output = "coreir" - check_definitions_are_unique(main) if get_compile_dir() == 'callee_file_dir': (_, filename, _, _, _, _) = inspect.getouterframes(inspect.currentframe())[1] file_path = os.path.dirname(filename) diff --git a/magma/frontend/coreir_.py b/magma/frontend/coreir_.py index 3575beb3b4..ae556994ef 100644 --- a/magma/frontend/coreir_.py +++ b/magma/frontend/coreir_.py @@ -1,5 +1,5 @@ from magma.backend.coreir_ import CoreIRBackend -from magma.circuit import DefineCircuitKind, Circuit, definitionCache +from magma.circuit import DefineCircuitKind, Circuit from magma import cache_definition from coreir.generator import Generator @@ -18,8 +18,6 @@ def definition(cls): def DefineCircuitFromGeneratorWrapper(cirb: CoreIRBackend, namespace: str, generator: str, uniqueName: str, dependentNamespaces: list = [], genargs: dict = {}, runGenerators = True): - if uniqueName in definitionCache: - return definitionCache[uniqueName] moduleToWrap = cirb.context.import_generator(namespace,generator)(**genargs) if runGenerators: cirb.context.run_passes(["rungenerators"], [namespace] + dependentNamespaces) diff --git a/tests/test_cache.py b/tests/test_cache.py deleted file mode 100644 index fb535686c9..0000000000 --- a/tests/test_cache.py +++ /dev/null @@ -1,44 +0,0 @@ -import magma as m - - -def test_cache(): - class Main0(m.Circuit): - name = "Main" - IO = ["I", m.In(m.Bits(2)), "O", m.Out(m.Bits(2))] - @classmethod - def definition(io): - m.wire(io.I, io.O) - - class Main1(m.Circuit): - name = "Main" - IO = ["I", m.In(m.UInt(2)), "O", m.Out(m.UInt(2))] - @classmethod - def definition(io): - m.wire(io.I, io.O) - - assert Main0 is Main1, "Main1 should be the cached version of Main0 since it has the same name" - - -def test_no_cache(): - class Main0(m.Circuit): - name = "Main" - IO = ["I", m.In(m.Bits(2)), "O", m.Out(m.Bits(2))] - @classmethod - def definition(io): - m.wire(io.I, io.O) - - class Main1(m.Circuit): - __magma_no_cache__ = True - name = "Main" - IO = ["I", m.In(m.UInt(2)), "O", m.Out(m.UInt(2))] - @classmethod - def definition(io): - m.wire(io.I, io.O) - - assert Main0 is not Main1, "__magma_no_cache__ is set so they should not be the same" - - Main2 = m.DefineCircuit("Main", "I", m.In(m.SInt(2)), "O", - m.Out(m.SInt(2)), __magma_no_cache__=True) - m.EndDefine() - - assert Main0 is not Main2, "__magma_no_cache__ is set so they should not be the same" diff --git a/tests/test_circuit/test_define.py b/tests/test_circuit/test_define.py index 5b736227d0..aa5ce6b52c 100644 --- a/tests/test_circuit/test_define.py +++ b/tests/test_circuit/test_define.py @@ -118,8 +118,8 @@ def test_unwired_ports_warnings(caplog): m.compile("build/test_unwired_output", main) assert check_files_equal(__file__, f"build/test_unwired_output.v", f"gold/test_unwired_output.v") - assert caplog.records[0].msg == "main.And2_inst0.I0 not connected" - assert caplog.records[1].msg == "main.O is unwired" + assert caplog.records[-2].msg == "main.And2_inst0.I0 not connected" + assert caplog.records[-1].msg == "main.O is unwired" def test_2d_array_error(caplog):