Skip to content

Commit

Permalink
Add missing transformation tests
Browse files Browse the repository at this point in the history
  • Loading branch information
terrorfisch committed Aug 18, 2018
1 parent 6409bf2 commit da7f069
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions tests/_program/transformation_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ def test_output_channels(self):
chans = {'a', 'b'}
self.assertIs(IdentityTransformation().get_output_channels(chans), chans)

def test_input_channels(self):
chans = {'a', 'b'}
self.assertIs(IdentityTransformation().get_input_channels(chans), chans)

def test_chain(self):
trafo = TransformationStub()
self.assertIs(IdentityTransformation().chain(trafo), trafo)
Expand Down Expand Up @@ -164,6 +168,22 @@ def test_get_output_channels(self):
get_output_channels_1.assert_called_once_with({1})
get_output_channels_2.assert_called_once_with({2})

def test_get_input_channels(self):
trafos = TransformationStub(), TransformationStub(), TransformationStub()
chained = ChainedTransformation(*trafos)
chans = {1}, {2}, {3}

# note reverse trafos order
with mock.patch.object(trafos[2], 'get_input_channels', return_value=chans[0]) as get_input_channels_0,\
mock.patch.object(trafos[1], 'get_input_channels', return_value=chans[1]) as get_input_channels_1,\
mock.patch.object(trafos[0], 'get_input_channels', return_value=chans[2]) as get_input_channels_2:
outs = chained.get_input_channels({0})

self.assertIs(outs, chans[2])
get_input_channels_0.assert_called_once_with({0})
get_input_channels_1.assert_called_once_with({1})
get_input_channels_2.assert_called_once_with({2})

def test_call(self):
trafos = TransformationStub(), TransformationStub(), TransformationStub()
chained = ChainedTransformation(*trafos)
Expand Down Expand Up @@ -212,6 +232,15 @@ def test_single_transformation(self):
self.assertIs(chain_transformations(trafo), trafo)
self.assertIs(chain_transformations(trafo, IdentityTransformation()), trafo)

def test_denesting(self):
trafo = TransformationStub()
chained = ChainedTransformation(TransformationStub(), TransformationStub())

expected = ChainedTransformation(trafo, *chained.transformations, trafo)
result = chain_transformations(trafo, chained, trafo)

self.assertEqual(expected, result)

def test_chaining(self):
trafo = TransformationStub()

Expand Down

0 comments on commit da7f069

Please sign in to comment.