In [61]:
import json
import pennylane as qml
import pennylane.numpy as np
import math 

# Write any helper functions you need here
dev = qml.device('default.qubit', wires=[0,1,2])
dev1 = qml.device('default.qubit', wires=[0])

@qml.qnode(dev)
def cloning_machine(coefficients, wire):
    c0, c1 = coefficients
    alpha = np.array([0, (c0+c1)/np.sqrt(2), c1/np.sqrt(2), 0, c0/np.sqrt(2)])
    
    gates = qml.MottonenStatePreparation.compute_decomposition(alpha[1:], wires=[1, 2])

    qml.apply(gates)

    qml.CNOT(wires=[0,1])
    qml.CNOT(wires=[0,2])
    qml.CNOT(wires=[1,0])
    qml.CNOT(wires=[2,0])

    return qml.density_matrix(wires=wire)

@qml.qnode(dev1)
def zero_state():
    return qml.density_matrix(wires=0)

def fidelity(coefficients):
    cloning_node = qml.QNode(cloning_machine,dev)(coefficients, 0)
    cloning_node_1 = qml.QNode(cloning_machine,dev)(coefficients, 1)
    zero_states = qml.QNode(zero_state, dev1)()
    print ("fidelity --> state - ", np.round(cloning_machine(coefficients, 0), 10))
    # qml.draw_mpl(cloning_machine, decimals=2)(coefficients, 0)
    return np.array([qml.math.fidelity(cloning_node, zero_states),qml.math.fidelity(cloning_node_1, zero_states)]) 


# These functions are responsible for testing the solution.


def run(test_case_input: str) -> str:
    ins = json.loads(test_case_input)
    outs = fidelity(ins).tolist()
    
    return str(outs)


def check(solution_output: str, expected_output: str) -> None:
    solution_output = json.loads(solution_output)
    expected_output = json.loads(expected_output)
    print (solution_output, expected_output)
    u = cloning_machine([1/np.sqrt(3),1/np.sqrt(3)],1)
    for op in cloning_machine.tape.operations:
        assert (isinstance(op, qml.RX) or isinstance(op, qml.RY) or isinstance(op, qml.CNOT)), "You are using forbidden gates!"
    assert np.allclose(solution_output,expected_output, atol = 1e-4), "Not the correct fidelities"


# These are the public test cases
test_cases = [
    ('[0.5773502691896258, 0.5773502691896257]', '[0.83333333, 0.83333333]'),
    ('[0.2, 0.8848857801796105]', '[0.60848858, 0.98]')
]

# This will run the public test cases locally
for i, (input_, expected_output) in enumerate(test_cases):
    print(f"Running test case {i} with input '{input_}'...")

    try:
        output = run(input_)
    except Exception as exc:
        print(f"Runtime Error. {exc}")

    else:
        if message := check(output, expected_output):
            print(f"Wrong Answer. Have: '{output}'. Want: '{expected_output}'.")

        else:
            print("Correct!")

# [7.07106781e-01+0.j 3.06161700e-17+0.j 5.00000000e-01+0.j
#  5.00000000e-01+0.j 0.00000000e+00+0.j 0.00000000e+00+0.j
#  0.00000000e+00+0.j 0.00000000e+00+0.j]
            
#  [0.81649658+0.j 0.40824829+0.j 0.        +0.j 0.40824829+0.j
#  0.        +0.j 0.        +0.j 0.        +0.j 0.        +0.j]          

# [8.16496581e-01+0.j 2.49979981e-17+0.j 4.08248290e-01+0.j
#  4.08248290e-01+0.j 0.00000000e+00+0.j 0.00000000e+00+0.j
#  0.00000000e+00+0.j 0.00000000e+00+0.j]            

Running test case 0 with input '[0.5773502691896258, 0.5773502691896257]'...
fidelity --> state -  [[0.83333333+0.j 0.        +0.j]
 [0.        +0.j 0.16666667+0.j]]
[0.8333333333333334, 0.8333333333333334] [0.83333333, 0.83333333]
Correct!
Running test case 1 with input '[0.2, 0.8848857801796105]'...
fidelity --> state -  [[0.60848858+0.j 0.        +0.j]
 [0.        +0.j 0.39151142+0.j]]
[0.608488578017961, 0.9800000000000001] [0.60848858, 0.98]
Correct!
