Skip to content

Commit

Permalink
Add transformation tests
Browse files Browse the repository at this point in the history
  • Loading branch information
terrorfisch committed Feb 26, 2019
1 parent b4946cf commit 9ea8c68
Showing 1 changed file with 43 additions and 1 deletion.
44 changes: 43 additions & 1 deletion tests/_program/transformation_tests.py
Expand Up @@ -4,7 +4,7 @@
import numpy as np

from qupulse._program.transformation import LinearTransformation, Transformation, IdentityTransformation,\
ChainedTransformation, chain_transformations
ChainedTransformation, ParallelConstantChannelTransformation, chain_transformations


class TransformationStub(Transformation):
Expand Down Expand Up @@ -234,6 +234,48 @@ def test_chain(self):
chain_transformations.assert_called_once_with(*trafos, trafo)


class ParallelConstantChannelTransformationTests(unittest.TestCase):
def test_init(self):
channels = {'X': 2, 'Y': 4.4}

trafo = ParallelConstantChannelTransformation(channels)

self.assertEqual(trafo._channels, channels)
self.assertTrue(all(isinstance(v, float) for v in trafo._channels.values()))

self.assertEqual(trafo.compare_key, (('X', 2.), ('Y', 4.4)))

self.assertEqual(trafo.get_input_channels(set()), set())
self.assertEqual(trafo.get_input_channels({'X'}), set())
self.assertEqual(trafo.get_input_channels({'Z'}), {'Z'})
self.assertEqual(trafo.get_input_channels({'X', 'Z'}), {'Z'})

self.assertEqual(trafo.get_output_channels(set()), {'X', 'Y'})
self.assertEqual(trafo.get_output_channels({'X'}), {'X', 'Y'})
self.assertEqual(trafo.get_output_channels({'X', 'Z'}), {'X', 'Y', 'Z'})

def test_trafo(self):
channels = {'X': 2, 'Y': 4.4}
trafo = ParallelConstantChannelTransformation(channels)

n_points = 17
time = np.arange(17, dtype=float)

expected_overwrites = {'X': np.full((n_points,), 2.),
'Y': np.full((n_points,), 4.4)}

empty_input_result = trafo(time, {})
np.testing.assert_equal(empty_input_result, expected_overwrites)

z_input_result = trafo(time, {'Z': np.sin(time)})
np.testing.assert_equal(z_input_result, {'Z': np.sin(time), **expected_overwrites})

x_input_result = trafo(time, {'X': np.cos(time)})
np.testing.assert_equal(empty_input_result, expected_overwrites)

x_z_input_result = trafo(time, {'X': np.cos(time), 'Z': np.sin(time)})
np.testing.assert_equal(z_input_result, {'Z': np.sin(time), **expected_overwrites})


class TestChaining(unittest.TestCase):
def test_identity_result(self):
Expand Down

0 comments on commit 9ea8c68

Please sign in to comment.