In [62]:
import pennylane as qml
import numpy as np
import warnings as warn
# Absolute tolerance shared by the normalization check in StatePrep and the tests.
tol = 1E-6
# Six-wire state-vector simulator on which every QNode in this notebook runs.
backend = qml.device('default.qubit', wires=6)

In [63]:
# Prepares a|110000> + b|001100> + c|000011> + d|100100> with a, b, c, d real
# and a^2 + b^2 + c^2 + d^2 = 1.
def StatePrep(a,b,c,d=1):
    """Prepare a|110000> + b|001100> + c|000011> + d|100100> on wires 0-5.

    Args:
        a, b, c: Real amplitudes of |110000>, |001100> and |000011>.
        d: Real amplitude of |100100>. If a^2 + b^2 + c^2 + d^2 differs from 1
            by more than ``tol``, ``d`` is replaced by the non-negative value
            sqrt(1 - a^2 - b^2 - c^2). Leaving ``d`` at its default of 1 makes
            that replacement silent; any other value emits a warning first.

    Returns:
        ``qml.state()`` of the prepared six-wire register.

    Raises:
        ValueError: if a^2 + b^2 + c^2 exceeds 1 by more than ``tol``, so no
            real ``d`` can complete the normalization.
    """
    if np.abs(np.square(a) + np.square(b) + np.square(c) + np.square(d) - 1) > tol:
        remainder = 1 - np.square(a) - np.square(b) - np.square(c)
        if remainder < -tol:
            # Previously sqrt(negative) produced nan and a garbage circuit.
            raise ValueError('Cannot normalize: a^2 + b^2 + c^2 > 1, so no real d exists')
        if d != 1:
            warn.warn('Inconsistent amplitudes: a^2 + b^2 + c^2 + d^2 != 1. Continuing with d = sqrt(1 - a^2 - b^2 - c^2)')
        # Clamp tiny negative round-off so sqrt never returns nan.
        d = np.sqrt(np.maximum(remainder, 0.0))

    # Norm still unassigned after peeling off a (aBar) and then c (acBar);
    # clamped so round-off cannot make the radicand negative.
    aBar = np.sqrt(np.maximum(1 - np.square(a), 0.0))
    acBar = np.sqrt(np.maximum(1 - np.square(a) - np.square(c), 0.0))

    # Start from |001100>.
    qml.BasisState(np.array([0,0,1,1,0,0]), wires=[0,1,2,3,4,5])
    # |001100> -> aBar|001100> + a|110000>; arctan2(a, aBar) == arcsin(a).
    qml.DoubleExcitation(-2*np.arctan2(a, aBar), wires=[2, 3, 0, 1])
    # aBar|001100> -> acBar|001100> + c|000011>. arctan2(c, acBar) equals
    # arcsin(c/aBar) but stays finite when aBar == 0 (|a| == 1).
    qml.DoubleExcitation(-2*np.arctan2(c, acBar), wires=[2, 3, 4, 5])
    # Controlled on wire 3, which at this point is set only in |001100>:
    # acBar|001100> -> b|001100> + d|100100>. arctan2(d, b) equals the old
    # arcsin(d/acBar) for b >= 0, but also reproduces a negative b, which the
    # arcsin form silently turned into +|b|.
    qml.ctrl(qml.SingleExcitation, control=3)(-2*np.arctan2(d, b), wires=[2, 0])

    return qml.state()


def GenTestVals(rng=None, verbose=True):
    """Draw random non-negative amplitudes (a, b, c, d) with unit 2-norm.

    Args:
        rng: Optional source of randomness exposing ``uniform(size=...)``,
            e.g. ``np.random.default_rng(seed)``, so test runs can be
            reproduced. Defaults to the legacy global ``np.random`` state
            (the previous behaviour).
        verbose: If True (the default, as before), print the drawn vector.

    Returns:
        Tuple ``(a, b, c, d)`` of numpy floats in [0, 1] whose squares sum to 1.
    """
    source = np.random if rng is None else rng
    v = source.uniform(size=4)
    v = v / np.linalg.norm(v)
    if verbose:
        print(v)
    return v[0], v[1], v[2], v[3]


In [70]:
# Build the QNode once; the test cells below reuse this `circuit`.
circuit = qml.QNode(StatePrep, backend)
a, b, c, d = GenTestVals()
output = circuit(a, b, c, d=d)
# (amplitude, basis label) for every non-negligible component. Comparing
# against `tol` rather than exact `!= 0` keeps round-off noise out of the list.
states = [(amplitude, np.binary_repr(index, width=6))
          for index, amplitude in enumerate(output) if np.abs(amplitude) > tol]
print(states)


[0.25254433 0.67117984 0.54174628 0.43846317]
[(tensor(0.54174628+0.j, requires_grad=True), '000011'), (tensor(0.67117984+0.j, requires_grad=True), '001100'), (tensor(0.43846317+0.j, requires_grad=True), '100100'), (tensor(0.25254433+0.j, requires_grad=True), '110000')]


In [69]:
def testValidInput():
    """Check that `circuit` reproduces a random normalized (a, b, c, d).

    Relies on the notebook-level ``circuit`` QNode, ``GenTestVals`` and ``tol``.

    Raises:
        AssertionError: if a basis state outside the four expected ones is
            populated, or if any expected amplitude differs from the
            requested value by ``tol`` or more (in either direction).
    """
    a, b, c, d = GenTestVals()
    resultDict = {'110000': a, '001100': b, '000011': c, '100100': d}
    output = circuit(a, b, c, d=d)
    # Basis label -> amplitude for every non-negligible component.
    observed = {np.binary_repr(i, width=6): output[i]
                for i in range(len(output)) if np.abs(output[i]) > tol}
    unexpected = set(observed) - set(resultDict)
    assert not unexpected, f'unexpected basis states populated: {sorted(unexpected)}'
    for label, amplitude in resultDict.items():
        # Missing labels count as amplitude 0 (requested value below tol).
        got = observed.get(label, 0)
        # abs() is essential: the old one-sided `expected - got < tol`
        # passed whenever the circuit overshot the requested amplitude.
        assert np.abs(got - amplitude) < tol, f'{label}: expected {amplitude}, got {got}'

def test_InvalidD():
    """Check that an inconsistent explicit d warns and is replaced.

    The expected replacement value is sqrt(1 - a^2 - b^2 - c^2). Relies on
    the notebook-level ``circuit``, ``GenTestVals`` and ``tol``.
    """
    a, b, c, d = GenTestVals()
    # 2.0 can never complete a normalized vector and is not the silent
    # default sentinel 1, so StatePrep must take its warning branch.
    with warn.catch_warnings(record=True) as caught:
        warn.simplefilter('always')
        output = circuit(a, b, c, d=2.0)
    assert any('a^2 + b^2 + c^2 + d^2 != 1' in str(w.message) for w in caught), \
        'expected a normalization warning from StatePrep'
    # After the replacement, the state must match the normalized draw exactly.
    expected = {'110000': a, '001100': b, '000011': c, '100100': d}
    for label, amplitude in expected.items():
        got = output[int(label, 2)]
        assert np.abs(got - amplitude) < tol, f'{label}: expected {amplitude}, got {got}'


N_TRIALS = 10  # number of random draws for the valid-input test
for _ in range(N_TRIALS):
    testValidInput()
test_InvalidD()


[0.11500876 0.6038128  0.53030199 0.58392028]
[0.81480438 0.51710464 0.25974205 0.03508103]
[0.75217145 0.60071745 0.14228234 0.23050464]
[0.34473451 0.1946629  0.91829295 0.00158956]
[0.03111296 0.75964343 0.22528281 0.6092795 ]
[0.66721582 0.31585303 0.15638463 0.65620405]
[0.41650338 0.82839    0.11884604 0.35520496]
[0.23562847 0.41073251 0.6467498  0.59790695]
[0.36167711 0.65851655 0.5584849  0.35162515]
[0.18806872 0.69584555 0.18764042 0.66724823]
