-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
109 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from abc import ABC, ABCMeta, abstractmethod | ||
from dataclasses import dataclass | ||
import inspect | ||
import typing | ||
from collections import OrderedDict | ||
|
||
|
||
__all__ = ["GeneratorBase"] | ||
__all__ += ["wrap_generator"] | ||
|
||
|
||
class _DataclassMeta(ABCMeta): | ||
def __new__(metacls, name, bases, dct): | ||
cls = super().__new__(metacls, name, bases, dct) | ||
cls = dataclass(cls) | ||
cls._cache_ = {} | ||
|
||
elaborate = cls.elaborate | ||
if getattr(elaborate, "__isabstractmethod__", False): | ||
return cls | ||
|
||
fields = cls.__dataclass_fields__ | ||
def _elaborate_wrapped(self_): | ||
key = frozenset({getattr(self_, f) for f in fields}) | ||
elaborated = cls._cache_.get(key, None) | ||
if elaborated is None: | ||
elaborated = elaborate(self_) | ||
cls._cache_[key] = elaborated | ||
return elaborated | ||
cls.elaborate = _elaborate_wrapped | ||
|
||
return cls | ||
|
||
|
||
class GeneratorBase(metaclass=_DataclassMeta): | ||
@abstractmethod | ||
def elaborate(self): | ||
pass | ||
|
||
|
||
def wrap_generator(generator, name=None): | ||
sig = inspect.signature(generator) | ||
params = OrderedDict() | ||
defaults = OrderedDict() | ||
for param_name, param in sig.parameters.items(): | ||
typ = typing.Any | ||
if param.annotation is not inspect.Parameter.empty: | ||
typ = param.annotation | ||
params[param_name] = typ | ||
if param.default is not inspect.Parameter.empty: | ||
defaults[param_name] = param.default | ||
if name is None: | ||
name = generator.__name__ | ||
bases = (GeneratorBase,) | ||
def _elaborate(self_): | ||
args = [getattr(self_, p) for p in sig.parameters] | ||
return generator(*args) | ||
dct = { | ||
"elaborate": _elaborate, | ||
"__annotations__": dict(params), | ||
"__module__": generator.__module__, | ||
} | ||
dct.update(defaults) | ||
cls = type(name, bases, dct) | ||
return dataclass(cls) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import magma as m | ||
|
||
|
||
class SimpleGenerator(m.GeneratorBase): | ||
width: int = 16 | ||
|
||
def elaborate(self): | ||
class _Simple(m.Circuit): | ||
IO = ["I", m.In(m.Bits[self.width]), | ||
"O", m.Out(m.Bits[self.width]),] | ||
name = f"Simple{self.width}" | ||
|
||
@classmethod | ||
def definition(io): | ||
m.wire(io.O, io.I) | ||
|
||
return _Simple | ||
|
||
|
||
def test_simple_generator(): | ||
gen = SimpleGenerator(width=10) | ||
|
||
# Check a simple invocation of the generator. | ||
Simple10 = gen.elaborate() | ||
assert repr(Simple10) == """\ | ||
Simple10 = DefineCircuit("Simple10", "I", In(Bits[10]), "O", Out(Bits[10])) | ||
wire(Simple10.I, Simple10.O) | ||
EndCircuit()""" | ||
|
||
# Check that changing the parameter on the generator instance results in the | ||
# correct elaboration. | ||
gen.width = 20 | ||
Simple20 = gen.elaborate() | ||
assert repr(Simple20) == """\ | ||
Simple20 = DefineCircuit("Simple20", "I", In(Bits[20]), "O", Out(Bits[20])) | ||
wire(Simple20.I, Simple20.O) | ||
EndCircuit()""" | ||
|
||
# Check that caching works as expected. | ||
assert Simple20 is not Simple10 | ||
gen.width = 10 | ||
assert gen.elaborate() is Simple10 |