Skip to content

Commit

Permalink
Fixes for version 0.16.0 (#73)
Browse files Browse the repository at this point in the history
* draft fix and test

* add more cases

* comments

* tests for bit2dec

* changelog and version number

* Update pennylane_forest/wavefunction.py

Co-authored-by: Theodor <theodor@xanadu.ai>

* change test name and docstring

* Update CHANGELOG.md

Co-authored-by: Theodor <theodor@xanadu.ai>

* dtype

Co-authored-by: Theodor <theodor@xanadu.ai>
  • Loading branch information
antalszava and thisac committed Jun 15, 2021
1 parent 9ef214f commit 4b1c899
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 34 deletions.
13 changes: 6 additions & 7 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
# Release 0.16.0-dev

### New features

### Improvements
# Release 0.16.0

### Bug fixes

### Breaking changes
* Fixed a bug caused by the `expand_state` method always assuming that

### Documentation
inactive wires are the least significant bits.
[(#73)](https://github.com/PennyLaneAI/pennylane-forest/pull/73)

### Contributors

This release contains contributions from (in alphabetical order):

Antal Száva.

---

# Release 0.15.0
Expand Down
2 changes: 1 addition & 1 deletion pennylane_forest/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
Version number (convention major.minor.patch[-label])
"""

__version__ = "0.16.0-dev"
__version__ = "0.16.0"
44 changes: 34 additions & 10 deletions pennylane_forest/wavefunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,22 @@ def apply(self, operations, **kwargs):
self._state = self._state.reshape([2] * len(self._active_wires)).T.flatten()
self.expand_state()

@staticmethod
def bit2dec(x):
"""Auxiliary method that converts a bitstring to a decimal integer
using the PennyLane convention of bit ordering.
Args:
x (Iterable): bit string
Returns:
int: decimal value of the bitstring
"""
y = 0
for i, j in enumerate(x[::-1]):
y += j << i
return y

def expand_state(self):
"""The pyQuil wavefunction simulator initializes qubits dymnically as they are requested.
This method expands the state to the full number of wires in the device."""
Expand All @@ -98,20 +114,28 @@ def expand_state(self):
# all wires in the device have been initialised
return

num_inactive_wires = len(self.wires) - len(self._active_wires)

# translate active wires to the device's labels
device_active_wires = self.map_wires(self._active_wires)

# place the inactive subsystems in the vacuum state
other_subsystems = np.zeros([2 ** num_inactive_wires])
other_subsystems[0] = 1
inactive_wires = [x for x in range(len(self.wires)) if x not in device_active_wires]

# initialize the entire new expanded state to zeros
expanded_state = np.zeros([2 ** len(self.wires)], dtype=self.C_DTYPE)

# expand the state of the device into a length-num_wire state vector
expanded_state = np.kron(self._state, other_subsystems).reshape([2] * self.num_wires)
expanded_state = np.moveaxis(
expanded_state, range(len(device_active_wires)), device_active_wires.labels
# gather the bit strings for the subsystem made up of the active qubits
subsystem_bit_strings = self.states_to_binary(
np.arange(2 ** len(self._active_wires)), len(self._active_wires)
)
expanded_state = expanded_state.flatten()

for string, amplitude in zip(subsystem_bit_strings, self._state):
for w in inactive_wires:

# expand the bitstring by inserting a zero bit for each inactive qubit
string = np.insert(string, w, 0)

# calculate the decimal value of the bit string, that gives the
# index of the amplitude in the state vector
decimal_val = self.bit2dec(string)
expanded_state[decimal_val] = amplitude

self._state = expanded_state
61 changes: 45 additions & 16 deletions tests/test_wavefunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,6 @@
class TestWavefunctionBasic(BaseTest):
"""Unit tests for the wavefunction simulator."""

def test_expand_state(self):
"""Test that a multi-qubit state is correctly expanded for a N-qubit device"""
dev = plf.WavefunctionDevice(wires=3)

# expand a two qubit state to the 3 qubit device
dev._state = np.array([0, 1, 1, 0]) / np.sqrt(2)
dev._active_wires = Wires([0, 2])
dev.expand_state()
self.assertAllEqual(dev._state, np.array([0, 1, 0, 0, 1, 0, 0, 0]) / np.sqrt(2))

# expand a three qubit state to the 3 qubit device
dev._state = np.array([0, 1, 1, 0, 0, 1, 1, 0]) / 2
dev._active_wires = Wires([0, 1, 2])
dev.expand_state()
self.assertAllEqual(dev._state, np.array([0, 1, 1, 0, 0, 1, 1, 0]) / 2)

def test_var(self, tol, qvm):
"""Tests for variance calculation"""
dev = plf.WavefunctionDevice(wires=2)
Expand Down Expand Up @@ -411,3 +395,48 @@ def circuit(x, y, z):

expected_var = np.sqrt(1 / shots)
self.assertAlmostEqual(np.mean(runs), np.cos(a) * np.sin(b), delta=expected_var)

class TestExpandState(BaseTest):
"""Test the expand_state method"""

def test_expand_state(self):
"""Test that a multi-qubit state is correctly expanded for a N-qubit device"""
dev = plf.WavefunctionDevice(wires=3)

# expand a two qubit state to the 3 qubit device
dev._state = np.array([0, 1, 1, 0]) / np.sqrt(2)
dev._active_wires = Wires([0, 2])
dev.expand_state()
self.assertAllEqual(dev._state, np.array([0, 1, 0, 0, 1, 0, 0, 0]) / np.sqrt(2))

# expand a three qubit state to the 3 qubit device
dev._state = np.array([0, 1, 1, 0, 0, 1, 1, 0]) / 2
dev._active_wires = Wires([0, 1, 2])
dev.expand_state()
self.assertAllEqual(dev._state, np.array([0, 1, 1, 0, 0, 1, 1, 0]) / 2)


@pytest.mark.parametrize("bitstring, dec", [([1], 1), ([0,0,1], 1), ([0,1,0], 2), ([1,1,1], 7)])
def test_bit2dec(self, bitstring, dec):
"""Test that the bit2dec method produces the correct decimal values for
bitstrings"""
assert plf.WavefunctionDevice.bit2dec(bitstring) == dec

@pytest.mark.parametrize("num_wires", [3,4] )
@pytest.mark.parametrize("wire_idx1, wire_idx2", [(0, 0), (1, 0), (1, 1)] )
def test_expanded_matches_default_qubit(self, wire_idx1, wire_idx2, num_wires):
"""Test that the statevector is expanded correctly by checking the
output of a QNode with default.qubit."""
dev_default_qubit = qml.device('default.qubit', wires=num_wires)
dev_wavefunction = qml.device('forest.wavefunction', wires=num_wires)

def circuit():
qml.RY(np.pi/2, wires=[wire_idx1 + 1])
qml.RY(np.pi, wires=[wire_idx2 + 0])
qml.CNOT(wires=[wire_idx1 + 1, wire_idx2 + 0])
qml.CNOT(wires=[wire_idx1 + 1, wire_idx2 + 0])
return qml.probs(range(num_wires))

qnode1 = qml.QNode(circuit, dev_wavefunction)
qnode2 = qml.QNode(circuit, dev_default_qubit)
assert np.allclose(qnode1(), qnode2())

0 comments on commit 4b1c899

Please sign in to comment.