Calculate a Hessian matrix using a quantum circuit.

- https://pennylane.ai/blog/2021/04/how-to-calculate-the-hessian-of-a-classical-quantum-hybrid-model/
- https://en.wikipedia.org/wiki/Hessian_matrix
- https://docs.pennylane.ai/en/stable/code/api/pennylane.gradients.param_shift_hessian.html

In [36]:
import json
import pennylane as qml
import pennylane.numpy as np

def compute_hessian(num_wires, w):
    """Return the parameter-shift Hessian of a ring-entangled variational circuit.

    The circuit applies, in order:
      * ``RX(w[i])`` on every wire ``i`` in ``0 .. num_wires-1``,
      * a ring of CNOTs ``(0,1), (1,2), ..., (num_wires-1, 0)``,
      * ``RY(w[num_wires])`` on wire 1,
      * the same CNOT ring again,
      * ``RX(w[num_wires+1])`` on the last wire,
    and measures the expectation value of ``Z_0 @ Z_{num_wires-1}``.

    Args:
        num_wires: Number of qubits. Must be at least 2 because the ``RY``
            gate acts on wire 1.
        w: Trainable parameters; must contain exactly ``num_wires + 2`` entries.

    Returns:
        The Hessian produced by ``qml.gradients.param_shift_hessian``,
        rounded to 5 decimal places.

    Raises:
        ValueError: If ``num_wires < 2`` or ``len(w) != num_wires + 2``.
    """
    # Validate the caller's wire count against the weights instead of silently
    # replacing it with ``len(w) - 2``.
    if num_wires < 2:
        raise ValueError(f"num_wires must be at least 2 (RY acts on wire 1), got {num_wires}")
    if len(w) != num_wires + 2:
        raise ValueError(
            f"expected {num_wires + 2} weights for {num_wires} wires, got {len(w)}"
        )

    dev = qml.device("default.qubit", wires=num_wires)

    def _cnot_ring():
        # Chain CNOT(0,1), CNOT(1,2), ..., then close the ring with CNOT(n-1, 0).
        for wire in range(num_wires - 1):
            qml.CNOT(wires=[wire, wire + 1])
        qml.CNOT(wires=[num_wires - 1, 0])

    @qml.qnode(dev)
    def variational_circuit(params):
        for wire in range(num_wires):
            qml.RX(params[wire], wires=wire)

        _cnot_ring()
        qml.RY(params[num_wires], wires=1)
        _cnot_ring()

        qml.RX(params[num_wires + 1], wires=num_wires - 1)

        return [qml.expval(qml.PauliZ(0) @ qml.PauliZ(num_wires - 1))]

    return np.round(qml.gradients.param_shift_hessian(variational_circuit)(w), 5)

def run(test_case_input: str) -> str:
    """Parse a test case, compute its Hessian, and return it as a nested-list string.

    Args:
        test_case_input: JSON text of the form ``[num_wires, [w0, w1, ...]]``.

    Returns:
        ``str()`` of the Hessian as a nested Python list; ``check`` parses it
        back with ``json.loads``.
    """
    ins = json.loads(test_case_input)
    num_wires = ins[0]
    weights = np.array(ins[1], requires_grad=True)
    output = compute_hessian(num_wires, weights)

    if isinstance(output, tuple):
        # If the Hessian comes back as a tuple, stack it into one array and
        # round to 3 decimal places before converting to lists.
        output = np.array(output).numpy().round(3)
        return str([elem.tolist() for elem in output])

    # Covers ``np.tensor`` as well as plain arrays/lists, so this function
    # always returns a string rather than falling through to ``None``.
    return str(np.array(output).tolist())
    
def check(solution_output: str, expected_output: str) -> None:
    """Compare two JSON-encoded Hessians element-wise with an absolute tolerance of 1e-2.

    Args:
        solution_output: JSON text of the computed Hessian.
        expected_output: JSON text of the reference Hessian.

    Raises:
        AssertionError: If any element differs by more than the tolerance.
    """
    solution = json.loads(solution_output)
    expected = json.loads(expected_output)

    # Raise explicitly rather than using ``assert`` so the check is not
    # stripped when Python runs with ``-O``.
    if not np.allclose(solution, expected, atol=1e-2):
        raise AssertionError("Your function does not calculate the Hessian correctly.")


# These are the public test cases
test_cases = [
    ('[3,[0.1,0.2,0.1,0.2,0.7]]', '[[0.013, 0.0, 0.013, 0.006, 0.002], [0.0, -0.621, 0.077, 0.125, -0.604], [0.013, 0.077, -0.608, -0.628, -0.073], [0.006, 0.125, -0.628, 0.138, -0.044], [0.002, -0.604, -0.073, -0.044, -0.608]]'),
    ('[2,[0.3, 1.1, 0.4, 1.3]]', '[[0.0, 0.0, -0.0, 0.0], [0.0, -0.121, 0.0, 0.859], [0.0, 0.0, 0.0, 0.0], [0.0, 0.859, 0.0, -0.121]]')
]

# This will run the public test cases locally
for i, (input_, expected_output) in enumerate(test_cases):
    print(f"Running test case {i} with input '{input_}'...")

    try:
        output = run(input_)
        print (output)
    except Exception as exc:
        print(f"Runtime Error. {exc}")

    else:
        if message := check(output, expected_output):
            print(f"Wrong Answer. Have: '{output}'. Want: '{expected_output}'.")

        else:
            print("Correct!")

Running test case 0 with input '[3,[0.1,0.2,0.1,0.2,0.7]]'...
[[0.01271, 0.0, 0.01271, 0.00629, 0.00151], [0.0, -0.62104, 0.07695, 0.12481, -0.60364], [0.01271, 0.07695, -0.60833, -0.62762, -0.07254], [0.00629, 0.12481, -0.62762, 0.13752, -0.04442], [0.00151, -0.60364, -0.07254, -0.04442, -0.60833]]
Correct!
Running test case 1 with input '[2,[0.3, 1.1, 0.4, 1.3]]'...
[[-0.0, 0.0, -0.0, 0.0], [0.0, -0.12134, 0.0, 0.85873], [-0.0, 0.0, -0.0, 0.0], [0.0, 0.85873, 0.0, -0.12134]]
Correct!


In [27]:
num_wires = 3
dev = qml.device("default.qubit", wires=num_wires)


@qml.qnode(dev)
def variational_circuit(params):
    """Same gate layout as the circuit inside ``compute_hessian``, on ``num_wires`` qubits.

    ``params[0 .. num_wires-1]`` feed the initial ``RX`` layer,
    ``params[num_wires]`` feeds the ``RY`` on wire 1, and
    ``params[num_wires + 1]`` feeds the final ``RX`` on the last wire.
    Returns a one-element list with the expectation of ``Z_0 @ Z_{num_wires-1}``.
    """
    # Ring of CNOT (control, target) pairs: (0,1), (1,2), ..., (n-1, 0).
    # The modulo wraps the last wire's target back to wire 0.
    ring = [(q, (q + 1) % num_wires) for q in range(num_wires)]

    for q in range(num_wires):
        qml.RX(params[q], wires=q)

    for control, target in ring:
        qml.CNOT(wires=[control, target])

    qml.RY(params[num_wires], wires=1)

    for control, target in ring:
        qml.CNOT(wires=[control, target])

    qml.RX(params[num_wires + 1], wires=num_wires - 1)

    return [qml.expval(qml.PauliZ(0) @ qml.PauliZ(num_wires - 1))]


print(qml.draw(variational_circuit)([1, 2, 3, 4, 5]))

0: ──RX(1.00)─╭●────╭X───────────╭●────╭X───────────┤ ╭<Z@Z>
1: ──RX(2.00)─╰X─╭●─│───RY(4.00)─╰X─╭●─│────────────┤ │     
2: ──RX(3.00)────╰X─╰●──────────────╰X─╰●──RX(5.00)─┤ ╰<Z@Z>
