Use an optimization method to find the rotation parameters of the U3 gate in the circuit.

In [82]:
import json
import pennylane as qml
import pennylane.numpy as np

# Simulator whose two wires are addressed by label rather than by integer index.
dev = qml.device("default.qubit", wires=["atom", "cat"])

@qml.qnode(dev)
def evolve_atom_cat(unitary, params):
    """Apply ``unitary`` to both wires, then a U3 rotation on the 'atom' wire.

    Args:
        unitary: Matrix handed to ``qml.QubitUnitary`` acting on the
            'atom' and 'cat' wires.
        params: Indexable collection whose first three entries are used as
            the three U3 rotation angles.

    Returns:
        The state of the device after both operations, as given by
        ``qml.state()``.
    """
    qml.QubitUnitary(unitary, wires=["atom", "cat"])

    # Name the three angles before handing them to the rotation gate.
    theta = params[0]
    phi = params[1]
    delta = params[2]
    qml.U3(theta, phi, delta, wires="atom")

    return qml.state()

def u3_parameters(unitary, *, steps=1000, stepsize=0.01):
    """Find U3 angles that equalise the first two amplitudes of the final state.

    Gradient descent minimises ``|state[0] - state[1]|``, where ``state`` is
    the output of ``evolve_atom_cat(unitary, params)``.

    Args:
        unitary: Matrix forwarded unchanged to ``evolve_atom_cat``.
        steps: Number of gradient-descent updates to perform.
        stepsize: Learning rate given to ``qml.GradientDescentOptimizer``.

    Returns:
        Array with the three optimised U3 angles.
    """
    opt = qml.GradientDescentOptimizer(stepsize=stepsize)

    # Trainable parameters are created as floats and explicitly marked
    # differentiable, instead of relying on an integer array being promoted
    # to float by the first update.
    params = np.array([0.0, 0.0, 0.0], requires_grad=True)

    def cost(angles):
        # Distance between the first two amplitudes of the output state.
        state = evolve_atom_cat(unitary, angles)
        return np.abs(state[0] - state[1])

    for _ in range(steps):
        params = opt.step(cost, params)

    return np.array(params)


# These functions are responsible for testing the solution.
def run(test_case_input: str) -> str:
    """Optimise the U3 angles for one test case and report the outcome.

    Args:
        test_case_input: JSON text encoding the unitary passed to
            ``u3_parameters`` and ``evolve_atom_cat``.

    Returns:
        "Cat state generated" when the first two amplitudes of the final
        state agree within an absolute tolerance of 5e-2, otherwise
        "Cat state not generated".
    """
    unitary = json.loads(test_case_input)
    angles = u3_parameters(unitary).tolist()

    # Execute the circuit once and read both amplitudes from that result,
    # rather than simulating the same circuit twice.
    state = evolve_atom_cat(unitary, angles)

    if np.isclose(state[0], state[1], atol=5e-2):
        return "Cat state generated"
    return "Cat state not generated"

def check(solution_output: str, expected_output: str) -> None:
    """Validate the circuit against a reference value and compare outputs.

    Args:
        solution_output: String produced by ``run``.
        expected_output: String the test case expects.

    Raises:
        AssertionError: If the first amplitude of ``evolve_atom_cat`` for the
            reference unitary and angles ``[1, 1, 1]`` is not close to
            0.62054458, or if the two strings differ.
    """
    def unitary_circ():
        qml.Hadamard(wires=0)
        qml.CNOT(wires=[0, 1])

    # Pass the wire order explicitly: it matches the order the wires appear in
    # the circuit, and recent PennyLane releases require it when building the
    # matrix of a quantum function.
    U1 = qml.matrix(unitary_circ, wire_order=[0, 1])()

    # Run the reference circuit once and reuse the amplitude for both the
    # diagnostic print and the assertion.
    amplitude = evolve_atom_cat(U1, [1, 1, 1])[0]
    print ("check =", amplitude)
    assert np.isclose(amplitude, 0.62054458), "Your evolve_atom_cat circuit does not do what is expected."
    assert solution_output == expected_output, "Your parameters do not generate a Schrodinger cat"


# These are the public test cases
# Each entry pairs a JSON-encoded 4x4 matrix with the expected verdict string.
test_cases = [
    (
        '[[ 0.70710678,  0 ,  0.70710678,  0], [0 ,0.70710678, 0, 0.70710678], [ 0,  0.70710678,  0, -0.70710678], [ 0.70710678,  0, -0.70710678,  0]]',
        'Cat state generated',
    ),
    (
        '[[-0.00202114,  0.99211964, -0.05149589, -0.11420469], [-0.13637119, -0.1236727 , -0.30532593, -0.93428263], [0.89775373,  0.00794205, -0.363445  ,  0.24876274], [ 0.41885207, -0.01845563, -0.8786535 ,  0.22845207]]',
        'Cat state generated',
    ),
]

# This will run the public test cases locally
# Execute every public test case: optimise, print the verdict, then validate.
for i, (input_, expected_output) in enumerate(test_cases):
    print(f"Running test case {i} with input '{input_}'...")

    try:
        output = run(input_)
        print (output)
    except Exception as exc:
        # A failure inside run() is reported and the remaining cases still run.
        print(f"Runtime Error. {exc}")
        continue

    # check() is called outside the try block, so a failed assertion inside
    # it propagates instead of being reported as a runtime error.
    message = check(output, expected_output)
    if message:
        print(f"Wrong Answer. Have: '{output}'. Want: '{expected_output}'.")
    else:
        print("Correct!")

Running test case 0 with input '[[ 0.70710678,  0 ,  0.70710678,  0], [0 ,0.70710678, 0, 0.70710678], [ 0,  0.70710678,  0, -0.70710678], [ 0.70710678,  0, -0.70710678,  0]]'...
Cat state generated
check = (0.6205445805637455+0j)
Correct!
Running test case 1 with input '[[-0.00202114,  0.99211964, -0.05149589, -0.11420469], [-0.13637119, -0.1236727 , -0.30532593, -0.93428263], [0.89775373,  0.00794205, -0.363445  ,  0.24876274], [ 0.41885207, -0.01845563, -0.8786535 ,  0.22845207]]'...
Cat state generated
check = (0.6205445805637455+0j)
Correct!
