diff --git a/magma/interface.py b/magma/interface.py index 74a4696a3e..3413bd9c78 100644 --- a/magma/interface.py +++ b/magma/interface.py @@ -1,74 +1,70 @@ -from __future__ import division from itertools import chain from collections import OrderedDict +from .conversions import array from .ref import AnonRef, InstRef, DefnRef -from .t import Type, Kind, In, Out, Flip -from .bit import BitKind, MakeBit +from .t import Type, Kind from .port import INPUT, OUTPUT, INOUT -#from .bit import * from .clock import ClockType, ClockTypes from .array import ArrayType from .tuple import TupleType from .compatibility import IntegerTypes, StringTypes + __all__ = ['DeclareInterface'] __all__ += ['Interface'] __all__ += ['InterfaceKind'] -# flatten an iterable of iterables to list -def flatten(l): - return list(chain(*l)) - -# -# parse argument declaration of the form -# -# (name0, type0, name1, type1, ..., namen, typen) -# -def parse(decl): - #print(decl) - n = len(decl) - assert n % 2 == 0 - - names = [] - ports = [] - for i in range(0,n,2): - name = decl[i] # name - if not name: - name = i//2 - port = decl[i+1] # type - assert isinstance(port, Kind) or isinstance(port, Type) +def _flatten(l): + """ + Flat an iterable of iterables to list. + """ + return list(chain(*l)) - names.append(name) - ports.append(port) +def parse(decl): + """ + Parse argument declaration of the form: + + (name0, type0, name1, type1, ..., namen, typen) + """ + if len(decl) % 2: + raise ValueError(f"Expected even number of arguments, got {len(decl)}") + + names = decl[::2] + ports = decl[1::2] + # If name is empty, convert to the index. + names = [name if name else str(i) for i, name in enumerate(names)] + # Check that all ports are given as instances of Kind or Type. + if not all(isinstance(port, (Kind, Type)) for port in ports): + raise ValueError(f"Expected kinds or types, got {ports}") return names, ports -# -# Abstract Base Class for an Interface -# -class _Interface(Type): +class _Interface(Type): + """ + Abstract Base Class for an Interface. + """ def __str__(self): return str(type(self)) def __repr__(self): s = "" for name, input in self.ports.items(): - if input.isinput(): - output = input.value() - if isinstance(output, ArrayType) \ - or isinstance(output, TupleType): - if not output.iswhole(output.ts): - for i in range(len(input)): - iname = repr( input[i] ) - oname = repr( output[i] ) - s += 'wire({}, {})\n'.format(oname, iname) - continue - iname = repr( input ) - oname = repr( output ) - s += 'wire({}, {})\n'.format(oname, iname) + if not input.isinput(): + continue + output = input.value() + if isinstance(output, (ArrayType, TupleType)): + if not output.iswhole(output.ts): + for i in range(len(input)): + iname = repr(input[i]) + oname = repr(output[i]) + s += f"wire({oname}, {iname})\n" + continue + iname = repr(input) + oname = repr(output) + s += f"wire({oname}, {iname})\n" return s @classmethod @@ -82,125 +78,102 @@ def __len__(self): return len(self.ports.keys()) def __getitem__(self, key): - if isinstance(key, int): - if isinstance(key,slice): - return array([self[i] for i in range(*key.indices(len(self)))]) - else: - n = len(self) - assert -n < key and key < n, "key: %d, self.N: %d" %(key,len(self)) - return self.arguments()[key] - else: - assert isinstance(key, str) + if isinstance(key, str): return self.ports[key] + if isinstance(key, int): + return self.arguments()[key] + if isinstance(key, slice): + return array([self[i] for i in range(*key.indices(len(self)))]) + raise ValueError(f"Expected key as str, int, or slice, got {key} " + f"({type(key)})") - # return all the argument ports def arguments(self): - return [port for name, port in self.ports.items()] + """Return all the argument ports.""" + return list(self.ports.values()) - # return all the argument input ports def inputs(self, include_clocks=False): - return [port for name, port in self.ports.items() \ - if port.isinput() and (not isinstance(port, ClockTypes) or include_clocks) ] -# name not in ['SET', 'CIN']] + """Return all the argument input ports.""" + fn = lambda port: port.isinput() and \ + (not isinstance(port, ClockTypes) or include_clocks) + return list(filter(fn, self.ports.values())) - # return all the argument output ports def outputs(self): - return [port for name, port in self.ports.items() if port.isoutput()] - + """Return all the argument output ports.""" + return list(filter(lambda port: port.isoutput(), self.ports.values())) - # return all the arguments as name, port def args(self): - return flatten([name, port] for name, port in self.ports.items()) + """Return all the arguments as name, port.""" + return _flatten(self.ports.items()) - # return all the arguments as name, flip(port) - # same as the declaration def decl(self): - return flatten([name, type(port).flip()] \ - for name, port in self.ports.items() ) + """ + Return all the arguments as name, flip(port) (same as the declaration). + """ + return _flatten([name, type(port).flip()] + for name, port in self.ports.items()) - - # return all the input arguments as name, port def inputargs(self): - return flatten( \ - [name, port] for name, port in self.ports.items() \ - if port.isinput() and not isinstance(port, ClockTypes) ) -# name not in ['SET', 'CIN']] ) + """Return all the input arguments as name, port.""" + return _flatten([name, port] for name, port in self.ports.items() + if port.isinput() and not isinstance(port, ClockTypes)) - # return all the output arguments as name, port def outputargs(self): - return flatten( [name, port] for name, port in self.ports.items() \ - if port.isoutput() ) + """Return all the output arguments as name, port.""" + return _flatten([name, port] for name, port in self.ports.items() + if port.isoutput()) - # return all the clock arguments as name, port def clockargs(self): - return flatten( [name, port] for name, port in self.ports.items() \ - if isinstance(port, ClockTypes) ) -# or name in ['SET'] ] ) + """Return all the clock arguments as name, port.""" + return _flatten([name, port] for name, port in self.ports.items() + if isinstance(port, ClockTypes)) - # return all the clock argument names def clockargnames(self): - return [name for name, port in self.ports.items() \ - if isinstance(port, ClockTypes) ] -# or name in ['SET'] ] - + """Return all the clock argument names.""" + return [name for name, port in self.ports.items() + if isinstance(port, ClockTypes)] - # return True if this interface has a Clock def isclocked(self): - for name, port in self.ports.items(): - if isinstance(port, ClockType): - return True - return False - -# -# Interface class -# -# This function assumes the port instances are provided -# -# e.g. Interface('I0', In(Bit)(), 'I1', In(Bit)(), 'O', Out(Bit)()) -# + """Return True if this interface has a Clock.""" + return any(isinstance(port, ClockType) for + port in self.ports.values()) + class Interface(_Interface): + """Interface class.""" def __init__(self, decl, renamed_ports={}): - + """ + This function assumes the port instances are provided: + e.g. Interface('I0', In(Bit)(), 'I1', In(Bit)(), 'O', Out(Bit)()) + """ names, ports = parse(decl) - - # setup ports args = OrderedDict() - for name, port in zip(names, ports): - if isinstance(name, IntegerTypes): - name = str(name) # convert integer to str, e.g. 0 to "0" - + name = str(name) # convert integer to str, e.g. 0 to "0" if name in renamed_ports: - raise NotImplementedError() - + raise NotImplementedError("Port renaming not implemented") args[name] = port - self.ports = args def __str__(self): - return f'Interface({", ".join(f"{k}: {v}" for k, v in self.ports.items())})' - - -# -# _DeclareInterface class -# -# First, an Interface is declared -# -# Interface = DeclareInterface('I0', In(Bit), 'I1', In(Bit), 'O', Out(Bit)) -# -# Then, the interface is instanced -# -# interface = Interface() -# + s = ", ".join(f"{k}: {v}" for k, v in self.ports.items()) + return f"Interface({s})" + + class _DeclareInterface(_Interface): - def __init__(self, renamed_ports={}, inst=None, defn=None): + """ + _DeclareInterface class. - # parse the class Interface declaration + First, an Interface is declared: + Interface = DeclareInterface('I0', In(Bit), 'I1', In(Bit), 'O', Out(Bit)) + + Then, the interface is instanced: + interface = Interface() + """ + def __init__(self, renamed_ports={}, inst=None, defn=None): + # Parse the class Interface declaration. names, ports = parse(self.Decl) args = OrderedDict() - for name, port in zip(names, ports): if inst: ref = InstRef(inst, name) elif defn: ref = DefnRef(defn, name) @@ -208,14 +181,13 @@ def __init__(self, renamed_ports={}, inst=None, defn=None): if name in renamed_ports: ref.name = renamed_ports[name] - if defn: port = port.flip() - args[name] = port(name=ref) self.ports = args + class InterfaceKind(Kind): def __init__(cls, *args, **kwargs): super().__init__(*args, **kwargs) @@ -238,27 +210,20 @@ def __str__(cls): args = [] for i, arg in enumerate(cls.Decl): if i % 2 == 0: - args.append('"{}"'.format(arg)) + args.append(f"\"{arg}\"") else: args.append(str(arg)) - return ', '.join(args) + return ", ".join(args) def __eq__(cls, rhs): - if not isinstance(rhs, InterfaceKind): return False - - if cls.Decl != rhs.Decl: return False - return True + return cls.Decl == rhs.Decl __ne__=Kind.__ne__ __hash__=Kind.__hash__ -# -# Interface factory -# def DeclareInterface(*decl, **kwargs): - name = '%s(%s)' % ('Interface', ', '.join([str(a) for a in decl])) - #print('DeclareInterface', name) + """Interface factory function.""" + name = f"Interface({', '.join([str(a) for a in decl])})" dct = dict(Decl=decl, **kwargs) return InterfaceKind(name, (_DeclareInterface,), dct) - diff --git a/magma/port.py b/magma/port.py index 8eefaa38f4..5c29e6814c 100644 --- a/magma/port.py +++ b/magma/port.py @@ -1,39 +1,39 @@ -from magma.config import get_debug_mode +from .config import get_debug_mode from .logging import error, warning, get_source_line from .backend.util import make_relative -from .ref import DefnRef, InstRef + __all__ = ['INPUT', 'OUTPUT', 'INOUT'] __all__ += ['flip'] __all__ += ['Port'] + INPUT = 'input' OUTPUT = 'output' INOUT = 'inout' -def report_wiring_error(message, debug_info): - if debug_info: - error(f"\033[1m{make_relative(debug_info[0])}:{debug_info[1]}: {message}", - include_wire_traceback=True) - try: - error(get_source_line(debug_info[0], debug_info[1])) - except FileNotFoundError: - error(f" Could not find file {debug_info[0]}") - else: +def _report_wiring_messgae(fn, message, debug_info): + if not debug_info: error(message) + return + file = debug_info[0] + line = debug_info[1] + message = f"\033[1m{make_relative(file)}:{line}: {message}" + fn(message, include_wire_traceback=True) + try: + fn(get_source_line(file, line)) + except FileNotFoundError: + fn(f" Could not file file {file}") + + +def report_wiring_error(message, debug_info): + _report_wiring_messgae(error, message, debug_info) def report_wiring_warning(message, debug_info): - # TODO: Include wire traceback support - if debug_info: - warning(f"\033[1m{make_relative(debug_info[0])}:{debug_info[1]}: {message}") - try: - warning(get_source_line(debug_info[0], debug_info[1])) - except FileNotFoundError: - warning(f" Could not find file {debug_info[0]}") - else: - warning(message) + # TODO(rsetaluri): Include wire traceback support. + _report_wiring_messgae(warning, message, debug_info) def flip(direction): @@ -42,7 +42,8 @@ def flip(direction): elif direction == OUTPUT: return INPUT elif direction == INOUT: return INOUT -def mergewires(new, old, debug_info): + +def merge_wires(new, old, debug_info): oldinputs = set(old.inputs) newinputs = set(new.inputs) oldoutputs = set(old.outputs) @@ -55,70 +56,71 @@ def mergewires(new, old, debug_info): for o in oldoutputs - newoutputs: if len(new.outputs) > 0: outputs = [o.bit.debug_name for o in new.outputs] - report_wiring_error(f"Connecting more than one output ({outputs}) to an input `{i.bit.debug_name}`", debug_info) # noqa + report_wiring_error(f"Connecting more than one output ({outputs}) to " + f"an input `{i.bit.debug_name}`", debug_info) new.outputs.append(o) o.wires = new -def fast_mergewires(w, i, o): +def fast_merge_wires(w, i, o): w.inputs = i.wires.inputs + o.wires.inputs w.outputs = i.wires.outputs + o.wires.outputs w.inputs = list(set(w.inputs)) w.outputs = list(set(w.outputs)) if len(w.outputs) > 1: outputs = [o.bit.debug_name for o in w.outputs] - # use w.inputs[0] as i, similar to {i.bit.debug_name} - report_wiring_error(f"Connecting more than one output ({outputs}) to an input `{w.inputs[0].bit.debug_name}`", debug_info) # noqa + # Use w.inputs[0] as i, similar to {i.bit.debug_name}. + report_wiring_error(f"Connecting more than one output ({outputs}) to " + f"an input `{w.inputs[0].bit.debug_name}`", + debug_info) for p in w.inputs: p.wires = w for p in w.outputs: p.wires = w -# -# A Wire has a list of input and output Ports. -# + class Wire: + """ + A Wire has a list of input and output Ports. + """ def __init__(self): self.inputs = [] self.outputs = [] def connect( self, o, i , debug_info): + """ + Anon Ports are added to the input or output list of this wire. - # anon Ports are added to the input or output list of this wire - # - # connecting to a non-anonymous port to an anonymous port - # add the non-anonymous port to the wire associated with the - # anonymous port - - #print(str(o), o.anon(), o.bit.isinput(), o.bit.isoutput()) - #print(str(i), i.anon(), i.bit.isinput(), i.bit.isoutput()) + Connecting to a non-anonymous port to an anonymous port add the + non-anonymous port to the wire associated with the anonymous port. + """ if not o.anon(): - #assert o.bit.direction is not None if o.bit.isinput(): - report_wiring_error(f"Using `{o.bit.debug_name}` (an input) as an output", debug_info) + report_wiring_error(f"Using `{o.bit.debug_name}` (an input) as " + f"an output", debug_info) return if o not in self.outputs: if len(self.outputs) != 0: - warn_str = "Adding the output `{}` to the wire `{}` which already has output(s) `[{}]`".format(o.bit.debug_name, i.bit.debug_name, ", ".join(output.bit.debug_name for output in self.outputs)) - report_wiring_warning(warn_str, debug_info) # noqa - #print('adding output', o) + output_str = ", ".join([output.bit.debug_name \ + for output in self.outputs]) + msg = (f"Adding the output `{o.bit.debug_name}` to the " + f"wire `{i.bit.debug_name}` which already has " + f"output(s) `[{output_str}]`") + report_wiring_warning(msg, debug_info) self.outputs.append(o) if not i.anon(): - #assert i.bit.direction is not None if i.bit.isoutput(): - report_wiring_error(f"Using `{i.bit.debug_name}` (an output) as an input", debug_info) + report_wiring_error(f"Using `{i.bit.debug_name}` (an output) " + f"as an input", debug_info) return if i not in self.inputs: - #print('adding input', i) self.inputs.append(i) - # print(o.wires,i.wires,self,self.outputs,self.inputs) - - # always update wires + # Always update wires. o.wires = self i.wires = self @@ -132,23 +134,22 @@ def check(self): error("Input in the wire outputs: {}".format(o)) return False - # check that this wire is only driven by a single output + # Check that this wire is only driven by a single output. if len(self.outputs) > 1: error("Multiple outputs on a wire: {}".format(self.outputs)) return False return True -# -# Port implements wiring -# -# Each port is represented by a Bit() -# + class Port: - def __init__(self, bit): + """ + Ports implement wiring. + Each port is represented by a Bit(). + """ + def __init__(self, bit): self.bit = bit - self.wires = Wire() def __repr__(self): @@ -160,25 +161,17 @@ def __str__(self): def anon(self): return self.bit.anon() - # wire a port to a port def wire(i, o, debug_info): - #if o.bit.direction is None: - # o.bit.direction = OUTPUT - #if i.bit.direction is None: - # i.bit.direction = INPUT - - #print("Wiring", o.bit.direction, str(o), "->", i.bit.direction, str(i)) - + """ + Wire a port to a port. + """ if i.wires and o.wires and i.wires is not o.wires: - # print('merging', i.wires.inputs, i.wires.outputs) - # print('merging', o.wires.inputs, o.wires.outputs) w = Wire() if get_debug_mode(): - mergewires(w, i.wires, debug_info) - mergewires(w, o.wires, debug_info) + merge_wires(w, i.wires, debug_info) + merge_wires(w, o.wires, debug_info) else: - fast_mergewires(w, i, o) - # print('after merge', w.inputs, w.outputs) + fast_merge_wires(w, i, o) elif o.wires: w = o.wires elif i.wires: @@ -188,49 +181,44 @@ def wire(i, o, debug_info): w.connect(o, i, debug_info) - #print("after",o,"->",i, w) - - - # if the port is an input or inout, return the output - # if the port is an output, return the first input def trace(self): + """ + If the port is an input or inout, return the output. + If the port is an output, return the first input. + """ if not self.wires: return None if self in self.wires.inputs: if len(self.wires.outputs) < 1: - # print('Warning:', str(self), 'is not connected to an output') return None assert len(self.wires.outputs) == 1 return self.wires.outputs[0] if self in self.wires.outputs: if len(self.wires.inputs) < 1: - # print('Warning:', str(self), 'is not connected to an input') return None assert len(self.wires.inputs) == 1 return self.wires.inputs[0] return None - # if the port is in the inputs, return the output def value(self): + """ + If the port is in the inputs, return the output. + """ if not self.wires: return None if self in self.wires.inputs: if len(self.wires.outputs) < 1: - # print('Warning:', str(self), 'is not connected to an output') return None - #assert len(self.wires.outputs) == 1 return self.wires.outputs[0] return None - def driven(self): return self.value() is not None def wired(self): return self.trace() is not None - diff --git a/magma/ref.py b/magma/ref.py index fc4519eb65..31af690afb 100644 --- a/magma/ref.py +++ b/magma/ref.py @@ -1,7 +1,9 @@ from .compatibility import IntegerTypes + __all__ = ['AnonRef', 'InstRef', 'DefnRef', 'ArrayRef', 'TupleRef'] + class Ref: def __str__(self): return str(self.name) @@ -9,6 +11,7 @@ def __str__(self): def __repr__(self): return self.qualifiedname() + class AnonRef(Ref): def __init__(self, name=""): self.name = name @@ -22,63 +25,67 @@ def qualifiedname(self, sep='.'): def anon(self): return False if self.name else True + class InstRef(Ref): def __init__(self, inst, name): - assert inst - self.inst = inst # Inst + if not inst: + raise ValueError(f"Bad inst: {inst}") + self.inst = inst self.name = name - def qualifiedname(self, sep='.'): + def qualifiedname(self, sep="."): name = self.name if isinstance(self.name, IntegerTypes): - # Hack, Hack, Hack - if sep == '.': - return self.inst.name + '[%d]' % self.name + # Hack, Hack, Hack! + if sep == ".": + return f"{self.inst.name}[{self.name}]" return self.inst.name + sep + str(name) def anon(self): return False + class DefnRef(Ref): def __init__(self, defn, name): - assert defn - self.defn = defn # Definition + if not defn: + raise ValueError(f"Bad defn: {defn}") + self.defn = defn self.name = name - def qualifiedname(self, sep='.'): - if sep == '.': + def qualifiedname(self, sep="."): + if sep == ".": return self.defn.__name__ + sep + self.name - else: - return self.name + return self.name def anon(self): return False + class ArrayRef(Ref): def __init__(self, array, index): - self.array = array # Array + self.array = array self.index = index def __str__(self): return self.qualifiedname() - def qualifiedname(self, sep='.'): - return self.array.name.qualifiedname(sep=sep) + '[%d]' % self.index + def qualifiedname(self, sep="."): + return f"{self.array.name.qualifiedname(sep=sep)}[{self.index}]" def anon(self): return self.array.name.anon() + class TupleRef(Ref): def __init__(self, tuple, index): - self.tuple = tuple # Tuple + self.tuple = tuple self.index = index def __str__(self): return self.qualifiedname() - def qualifiedname(self, sep='.'): + def qualifiedname(self, sep="."): return self.tuple.name.qualifiedname(sep=sep) + sep + str(self.index) def anon(self): return self.tuple.name.anon() - diff --git a/magma/wire.py b/magma/wire.py index dc1742bd01..6f994228c3 100644 --- a/magma/wire.py +++ b/magma/wire.py @@ -1,45 +1,36 @@ -import inspect -from collections.abc import Sequence -from .port import INPUT, OUTPUT, INOUT from .compatibility import IntegerTypes -from .t import Type from .debug import debug_wire -from .logging import info, warning, error from .port import report_wiring_error + __all__ = ['wire'] @debug_wire def wire(o, i, debug_info): - - # Wire(o, Circuit) + # Wire(o, Circuit). if hasattr(i, 'interface'): i.wire(o, debug_info) return - # replace output Circuit with its output (should only be 1 output) + # Replace output Circuit with its output (should only be 1 output). if hasattr(o, 'interface'): - # if wiring a Circuit to a Port - # then circuit should have 1 output + # If wiring a Circuit to a Port then circuit should have 1 output. o_orig = o o = o.interface.outputs() if len(o) != 1: - report_wiring_error(f'Can only wire circuits with one output. Argument 0 to wire `{o_orig.debug_name}` has outputs {o}', debug_info) # noqa + report_wiring_error(f"Can only wire circuits with one output. " + f"Argument 0 to wire `{o_orig.debug_name}` has " + f"outputs {o}", debug_info) return o = o[0] - # if o is an input + # If o is an input. if not isinstance(o, IntegerTypes) and o.isinput(): - # if i is not an input + # If i is not an input. if isinstance(i, IntegerTypes) or not i.isinput(): - # flip i and o + # Flip i and o. i, o = o, i - #if hasattr(i, 'wire'): - # error('Wiring Error: The input must have a wire method - {} to {}'.format(o, i)) - # return - - # Wire(o, Type) + # Wire(o, Type). i.wire(o, debug_info) - diff --git a/tests/test_wire/test_arg.py b/tests/test_wire/test_arg.py index e92760b11a..2d7be25ad7 100644 --- a/tests/test_wire/test_arg.py +++ b/tests/test_wire/test_arg.py @@ -40,6 +40,20 @@ def test_pos(): compile("build/pos", main, output="verilog") assert check_files_equal(__file__, "build/pos.v", "gold/pos.v") + +def test_pos_slice(): + Buf = DeclareCircuit("Buf", "I0", In(Bit), "I1", In(Bit), "O", Out(Bits[2])) + + main = DefineCircuit("main", "I", In(Array[2, Bit]), "O", Out(Bit)) + buf = Buf() + wire(buf[0:2], main.I) + wire(buf.O, main.O) + EndDefine() + + compile("build/pos_slice", main, output="verilog") + assert check_files_equal(__file__, "build/pos.v", "gold/pos.v") + + def test_arg_array1(): def DefineAndN(n): name = 'AndN%d' % n