### Given a quantum state, find the universal gate parameters $\alpha, \beta, \gamma, \phi$

In [118]:
import json
import pennylane as qml
import pennylane.numpy as np

np.random.seed(1967)

def get_matrix(params):
    """Return the matrix of the parametrised single-qubit gate.

    ``params`` unpacks into ``(alpha, beta, gamma, phi)``. The circuit applies
    ``RZ(alpha)``, then ``RX(beta)``, then ``RZ(gamma)`` on wire 0, and the
    resulting matrix is multiplied by ``cos(phi) + 1j*sin(phi)``.
    """
    alpha, beta, gamma, phi = params
    dev = qml.device('default.qubit', wires=1)

    def rotations(angles):
        # Three rotations applied in sequence to the single wire.
        qml.RZ(angles[0], wires=0)
        qml.RX(angles[1], wires=0)
        qml.RZ(angles[2], wires=0)
        return qml.expval(qml.PauliZ(wires=0))

    # Build the QNode explicitly, then ask PennyLane for its matrix transform.
    to_matrix = qml.matrix(qml.qnode(dev)(rotations))

    angles = np.array([alpha, beta, gamma], requires_grad=True)
    rotation_matrix = to_matrix(angles)

    # Phase factor written out as cos + i*sin.
    phase_factor = np.cos(phi) + 1j * np.sin(phi)
    return rotation_matrix * phase_factor

def error(U, params):
    """Return the sum of element-wise absolute differences between
    ``get_matrix(params)`` and the target ``U``."""
    candidate = get_matrix(params)
    elementwise_gap = np.absolute(candidate - U)
    return np.sum(elementwise_gap)

def train_parameters(U, epochs=1000, lr=0.01):
    """Fit ``(alpha, beta, gamma, phi)`` to the target matrix ``U``.

    Runs plain gradient descent on ``error(U, params)`` starting from a
    random point drawn with ``np.random.rand(4) * np.pi``.

    Args:
        U: Target matrix that ``get_matrix(params)`` should reproduce.
        epochs: Number of gradient-descent steps (default 1000).
        lr: Step size applied to each gradient (default 0.01).

    Returns:
        The parameter array after ``epochs`` updates.
    """
    # Differentiate with respect to the second positional argument (params),
    # not the fixed target U.
    grad = qml.grad(error, argnum=1)
    params = np.random.rand(4) * np.pi

    for _ in range(epochs):
        params -= lr * grad(U, params)

    return params

def run(test_case_input: str) -> str:
    """Decode the JSON-encoded target, train parameters for it, and return
    the trained parameters as a JSON list of floats."""
    target = json.loads(test_case_input)
    trained = train_parameters(target)
    return json.dumps([float(value) for value in trained])

def check(solution_output: str, expected_output: str) -> None:
    """Assert that the parameters in ``solution_output`` reproduce the matrix
    in ``expected_output``.

    Both arguments are JSON strings. Raises ``AssertionError`` on failure and
    returns ``None`` otherwise.
    """
    candidate = get_matrix(json.loads(solution_output))
    target = json.loads(expected_output)

    # Sanity guard: two independent random parameter draws must not yield
    # matrices that compare as close.
    first_random = get_matrix(np.random.rand(4))
    second_random = get_matrix(np.random.rand(4))
    assert not np.allclose(first_random, second_random)

    # Accept the solution when every entry is within an absolute tolerance of 0.2.
    assert np.allclose(candidate, target, atol=0.2)

# Each test case is an (input JSON, expected JSON) pair; here the expected
# matrix is the same as the input matrix, so each literal is named once.
_HADAMARD_JSON = '[[ 0.70710678,  0.70710678], [ 0.70710678, -0.70710678]]'
_PAULI_Z_JSON = '[[ 1,  0], [ 0, -1]]'

test_cases = [
    (_HADAMARD_JSON, _HADAMARD_JSON),
    (_PAULI_Z_JSON, _PAULI_Z_JSON),
]

# Drive every test case through run() and check(), printing one verdict each.
for i, (input_, expected_output) in enumerate(test_cases):
    print(f"Running test case {i} with input '{input_}'...")

    try:
        output = run(input_)

    except Exception as exc:
        print(f"Runtime Error. {exc}")
        continue

    # check() reports a mismatch by raising AssertionError and returns None on
    # success. Catch it here so a failing case is reported as "Wrong Answer"
    # and the remaining cases still run, instead of aborting with a traceback.
    try:
        message = check(output, expected_output)

    except AssertionError as exc:
        message = str(exc) or "assertion failed"

    if message:
        print(f"Wrong Answer. Have: '{output}'. Want: '{expected_output}'.")

    else:
        print("Correct!")

Running test case 0 with input '[[ 0.70710678,  0.70710678], [ 0.70710678, -0.70710678]]'...
Correct!
Running test case 1 with input '[[ 1,  0], [ 0, -1]]'...
Correct!
