diff --git a/magma/bits.py b/magma/bits.py index b4b788670d..7e00c780c9 100644 --- a/magma/bits.py +++ b/magma/bits.py @@ -179,10 +179,10 @@ def __getitem__(cls, index): def __str__(cls): if cls.isinput(): - return "In(UInt({}))".format(cls.N) + return "In(UInt[{}])".format(cls.N) if cls.isoutput(): - return "Out(UInt({}))".format(cls.N) - return "UInt({})".format(cls.N) + return "Out(UInt[{}])".format(cls.N) + return "UInt[{}]".format(cls.N) def qualify(cls, direction): if cls.T.isoriented(direction): @@ -247,10 +247,10 @@ def __getitem__(cls, index): def __str__(cls): if cls.isinput(): - return "In(SInt({}))".format(cls.N) + return "In(SInt[{}])".format(cls.N) if cls.isoutput(): - return "Out(SInt({}))".format(cls.N) - return "SInt({})".format(cls.N) + return "Out(SInt[{}])".format(cls.N) + return "SInt[{}]".format(cls.N) def qualify(cls, direction): if cls.T.isoriented(direction): @@ -321,10 +321,10 @@ def __getitem__(cls, index): def __str__(cls): if cls.isinput(): - return "In(BFloat({}))".format(cls.N) + return "In(BFloat[{}])".format(cls.N) if cls.isoutput(): - return "Out(BFloat({}))".format(cls.N) - return "BFloat({})".format(cls.N) + return "Out(BFloat[{}])".format(cls.N) + return "BFloat[{}]".format(cls.N) def qualify(cls, direction): if cls.T.isoriented(direction): diff --git a/magma/util.py b/magma/util.py index 61264e2b9a..71ab72ed85 100644 --- a/magma/util.py +++ b/magma/util.py @@ -7,3 +7,31 @@ def BitOrBits(width): if not isinstance(width, int): raise ValueError(f"Expected width to be None or int, got {width}") return m.Bits[width] + + +def pretty_str(t): + if isinstance(t, m.TupleKind): + args = [] + for i in range(t.N): + key_str = str(t.Ks[i]) + val_str = pretty_str(t.Ts[i]) + indent = " " * 4 + val_str = f"\n{indent}".join(val_str.splitlines()) + args.append(f"{key_str} = {val_str}") + # Pretty print by using newlines + indent + joiner = ",\n " + result = joiner.join(args) + # Insert first newline + indent and last newline + result = "\n " + result + "\n" + s = f"Tuple({result})" + elif isinstance(t, m.BitsKind): + s = str(t) + elif isinstance(t, m.ArrayKind): + s = f"Array[{t.N}, {pretty_str(t.T)}]" + else: + s = str(t) + return s + + +def pretty_print_type(t): + print(pretty_str(t)) diff --git a/tests/test_type/test_pretty_print.py b/tests/test_type/test_pretty_print.py new file mode 100644 index 0000000000..de35dbde27 --- /dev/null +++ b/tests/test_type/test_pretty_print.py @@ -0,0 +1,63 @@ +import magma as m + + +def test_pretty_print_tuple(): + t = m.Tuple(a=m.Bit, b=m.Bit, c=m.Bit) + assert m.util.pretty_str(t) == """\ +Tuple( + a = Bit, + b = Bit, + c = Bit +)\ +""" + + +def test_pretty_print_tuple_recursive(): + t = m.Tuple(a=m.Bit, b=m.Bit, c=m.Bit) + u = m.Tuple(x=t, y=t) + assert m.util.pretty_str(u) == """\ +Tuple( + x = Tuple( + a = Bit, + b = Bit, + c = Bit + ), + y = Tuple( + a = Bit, + b = Bit, + c = Bit + ) +)\ +""" + + +def test_pretty_print_array_of_tuple(): + t = m.Tuple(a=m.Bit, b=m.Bit, c=m.Bit) + u = m.Array[3, t] + assert m.util.pretty_str(u) == """\ +Array[3, Tuple( + a = Bit, + b = Bit, + c = Bit +)]\ +""" + + +def test_pretty_print_array_of_nested_tuple(): + t = m.Tuple(a=m.Bits[5], b=m.UInt[3], c=m.SInt[4]) + u = m.Tuple(x=t, y=t) + v = m.Array[3, u] + assert m.util.pretty_str(v) == """\ +Array[3, Tuple( + x = Tuple( + a = Bits[5], + b = UInt[3], + c = SInt[4] + ), + y = Tuple( + a = Bits[5], + b = UInt[3], + c = SInt[4] + ) +)]\ +"""