Skip to content

Commit

Permalink
Add staged generator class
Browse files Browse the repository at this point in the history
  • Loading branch information
rsetaluri committed Mar 29, 2019
1 parent 4021009 commit c44466e
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 0 deletions.
2 changes: 2 additions & 0 deletions magma/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def set_mantle_target(t):
from .is_primitive import isprimitive
from .is_definition import isdefinition

from .generator import *

from .product import Product


Expand Down
65 changes: 65 additions & 0 deletions magma/generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from abc import ABC, ABCMeta, abstractmethod
from dataclasses import dataclass
import inspect
import typing
from collections import OrderedDict


__all__ = ["GeneratorBase"]
__all__ += ["wrap_generator"]


class _DataclassMeta(ABCMeta):
def __new__(metacls, name, bases, dct):
cls = super().__new__(metacls, name, bases, dct)
cls = dataclass(cls)
cls._cache_ = {}

elaborate = cls.elaborate
if getattr(elaborate, "__isabstractmethod__", False):
return cls

fields = cls.__dataclass_fields__
def _elaborate_wrapped(self_):
key = frozenset({getattr(self_, f) for f in fields})
elaborated = cls._cache_.get(key, None)
if elaborated is None:
elaborated = elaborate(self_)
cls._cache_[key] = elaborated
return elaborated
cls.elaborate = _elaborate_wrapped

return cls


class GeneratorBase(metaclass=_DataclassMeta):
@abstractmethod
def elaborate(self):
pass


def wrap_generator(generator, name=None):
sig = inspect.signature(generator)
params = OrderedDict()
defaults = OrderedDict()
for param_name, param in sig.parameters.items():
typ = typing.Any
if param.annotation is not inspect.Parameter.empty:
typ = param.annotation
params[param_name] = typ
if param.default is not inspect.Parameter.empty:
defaults[param_name] = param.default
if name is None:
name = generator.__name__
bases = (GeneratorBase,)
def _elaborate(self_):
args = [getattr(self_, p) for p in sig.parameters]
return generator(*args)
dct = {
"elaborate": _elaborate,
"__annotations__": dict(params),
"__module__": generator.__module__,
}
dct.update(defaults)
cls = type(name, bases, dct)
return dataclass(cls)
42 changes: 42 additions & 0 deletions tests/test_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import magma as m


class SimpleGenerator(m.GeneratorBase):
width: int = 16

def elaborate(self):
class _Simple(m.Circuit):
IO = ["I", m.In(m.Bits[self.width]),
"O", m.Out(m.Bits[self.width]),]
name = f"Simple{self.width}"

@classmethod
def definition(io):
m.wire(io.O, io.I)

return _Simple


def test_simple_generator():
gen = SimpleGenerator(width=10)

# Check a simple invocation of the generator.
Simple10 = gen.elaborate()
assert repr(Simple10) == """\
Simple10 = DefineCircuit("Simple10", "I", In(Bits[10]), "O", Out(Bits[10]))
wire(Simple10.I, Simple10.O)
EndCircuit()"""

# Check that changing the parameter on the generator instance results in the
# correct elaboration.
gen.width = 20
Simple20 = gen.elaborate()
assert repr(Simple20) == """\
Simple20 = DefineCircuit("Simple20", "I", In(Bits[20]), "O", Out(Bits[20]))
wire(Simple20.I, Simple20.O)
EndCircuit()"""

# Check that caching works as expected.
assert Simple20 is not Simple10
gen.width = 10
assert gen.elaborate() is Simple10

0 comments on commit c44466e

Please sign in to comment.