diff --git a/magma/array.py b/magma/array.py index f2b472ff67..4a67d965a6 100644 --- a/magma/array.py +++ b/magma/array.py @@ -257,6 +257,10 @@ def const(self): def flatten(self): return sum([t.flatten() for t in self.ts], []) + @classmethod + def concat(cls, *args): + return concat(*args) + # def Array(N, T): # assert isinstance(N, IntegerTypes) @@ -268,4 +272,4 @@ def flatten(self): # Workaround for circular dependency -from .conversions import array # nopep8 +from .conversions import array, concat # nopep8 diff --git a/tests/test_type/test_bits.py b/tests/test_type/test_bits.py index 5dface62e3..90b0cbe740 100644 --- a/tests/test_type/test_bits.py +++ b/tests/test_type/test_bits.py @@ -1,76 +1,135 @@ -from magma import * +""" +Test the `m.Bits` type +""" -Array2 = Array[2, Bit] -Array4 = Array[4, Bit] +import magma as m -def test(): +ARRAY2 = m.Array[2, m.Bit] +ARRAY4 = m.Array[4, m.Bit] - A2 = Bits[2] - B2 = In(Bits[2]) - C2 = Out(Bits[2]) - assert A2 == A2 - assert B2 == B2 - assert C2 == C2 - assert A2 != B2 - assert A2 != C2 - assert B2 != C2 +def test_bits_basic(): + """ + Basic bits tests + """ + bits_2 = m.Bits[2] + bits_in_2 = m.In(m.Bits[2]) + bits_out_2 = m.Out(m.Bits[2]) + assert bits_2 == m.Bits[2] + assert bits_in_2 == m.In(bits_2) + assert bits_out_2 == m.Out(bits_2) - A4 = Bits[4] - assert A4 == Array4 - assert A2 != A4 + assert bits_2 != bits_in_2 + assert bits_2 != bits_out_2 + assert bits_in_2 != bits_out_2 + + bits_4 = m.Bits[4] + assert bits_4 == ARRAY4 + assert bits_2 != bits_4 def test_val(): - Array4In = In(Bits[4]) - Array4Out = Out(Bits[4]) + """ + Test instances of Bits[4] work correctly + """ + bits_4_in = m.In(m.Bits[4]) + bits_4_out = m.Out(m.Bits[4]) + + assert m.Flip(bits_4_in) == bits_4_out + assert m.Flip(bits_4_out) == bits_4_in - assert Flip(Array4In) == Array4Out - assert Flip(Array4Out) == Array4In + a_0 = bits_4_out(name='a0') + print(a_0) - a0 = Array4Out(name='a0') - print(a0) + a_1 = bits_4_in(name='a1') + print(a_1) - a1 = Array4In(name='a1') - print(a1) + a_1.wire(a_0) - a1.wire(a0) + b_0 = a_1[0] + assert b_0 is a_1[0], "getitem failed" - b0 = a1[0] + a_3 = a_1[0:2] + assert a_3 == a_1[0:2], "getitem of slice failed" - a3 = a1[0:2] def test_flip(): - Bits2 = Bits[2] - AIn = In(Bits2) - AOut = Out(Bits2) + """ + Test flip interface + """ + bits_2 = m.Bits[2] + a_in = m.In(bits_2) + a_out = m.Out(bits_2) - print(AIn) - print(AOut) + print(a_in) + print(a_out) - assert AIn != Array2 - assert AOut != Array2 - assert AIn != AOut + assert a_in != ARRAY2 + assert a_out != ARRAY2 + assert a_in != a_out - A = In(AOut) - assert A == AIn - print(A) + in_a_out = m.In(a_out) + assert in_a_out == a_in + print(in_a_out) - A = Flip(AOut) - assert A == AIn + a_out_flipped = m.Flip(a_out) + assert a_out_flipped == a_in - A = Out(AIn) - assert A == AOut + out_a_in = m.Out(a_in) + assert out_a_in == a_out + + a_in_flipped = m.Flip(a_in) + assert a_in_flipped == a_out + print(a_in_flipped) - A = Flip(AIn) - assert A == AOut - print(A) def test_construct(): - a1 = bits([1,1]) - print(type(a1)) - assert isinstance(a1, BitsType) + """ + Test `m.bits` interface + """ + a_1 = m.bits([1, 1]) + print(type(a_1)) + assert isinstance(a_1, m.BitsType) + def test_const(): - Data = Bits[16] - zero = Data(0) + """ + Test constant constructor interface + """ + data = m.Bits[16] + zero = data(0) + assert zero == m.bits(0, 16) + + +def test_setitem_bfloat(): + """ + Test constant constructor interface + """ + class TestCircuit(m.Circuit): + IO = ["I", m.In(m.BFloat[16]), "O", m.Out(m.BFloat[16])] + @classmethod + def definition(io): + a = io.I + b = m.BFloat.concat(a[0:-1], m.bits(0, 1)) + io.O <= b + print(repr(TestCircuit)) + assert repr(TestCircuit) == """\ +TestCircuit = DefineCircuit("TestCircuit", "I", In(BFloat(16)), "O", Out(BFloat(16))) +wire(TestCircuit.I[0], TestCircuit.O[0]) +wire(TestCircuit.I[1], TestCircuit.O[1]) +wire(TestCircuit.I[2], TestCircuit.O[2]) +wire(TestCircuit.I[3], TestCircuit.O[3]) +wire(TestCircuit.I[4], TestCircuit.O[4]) +wire(TestCircuit.I[5], TestCircuit.O[5]) +wire(TestCircuit.I[6], TestCircuit.O[6]) +wire(TestCircuit.I[7], TestCircuit.O[7]) +wire(TestCircuit.I[8], TestCircuit.O[8]) +wire(TestCircuit.I[9], TestCircuit.O[9]) +wire(TestCircuit.I[10], TestCircuit.O[10]) +wire(TestCircuit.I[11], TestCircuit.O[11]) +wire(TestCircuit.I[12], TestCircuit.O[12]) +wire(TestCircuit.I[13], TestCircuit.O[13]) +wire(TestCircuit.I[14], TestCircuit.O[14]) +wire(0, TestCircuit.O[15]) +EndCircuit()\ +""" # noqa