In [6]:
import functools
import json
import math
import pandas as pd
import pennylane as qml
import pennylane.numpy as np
import scipy

def fourier_decomp(layers_params):
    """
    Returns the frequencies and coefficient of our quantum model, specified by layers_params

    Args:
    layers_params: list(list(list(float))). Specifies the number of basic entangling layers and their
    parameters as explained in the statement of the problem.

    Returns: list([float,float,float]). A list three-element list. The first element of each list is the frequency. The second
    element is the real part of the coefficient associated with that frequency in the Fourier decomposition. The third element
    is the imaginary part of such coefficient.
    """

    dev = qml.device("default.qubit", wires=4)

    @qml.qnode(dev)
    def circuit(layers_params, x):
        """
        This function is the quantum circuit made of alternating entangling layers and rotations representing our quantum model
        """
        # Put your code here #
        print(len(layers_params))
        
        for k in range(len(layers_params) - 1):
            for i in range(len(layers_params[k])):
                qml.RX(layers_params[k][i][0],wires=0)
                qml.RX(layers_params[k][i][1],wires=1)
                qml.RX(layers_params[k][i][2],wires=2)
                qml.RX(layers_params[k][i][3],wires=3)

                qml.CNOT(wires=[0,1])
                qml.CNOT(wires=[1,2])
                qml.CNOT(wires=[2,3])
                qml.CNOT(wires=[3,0])

            qml.RX(x,wires=0,id="x")
            qml.RX(x,wires=1,id="x")
            qml.RX(x,wires=2,id="x")
            qml.RX(x,wires=3,id="x")
        
        for j in range(len(layers_params[-1])):
            #print(f'{i}')
            qml.RX(layers_params[-1][j][0],wires=0)
            qml.RX(layers_params[-1][j][1],wires=1)
            qml.RX(layers_params[-1][j][2],wires=2)
            qml.RX(layers_params[-1][j][3],wires=3)
                
            qml.CNOT(wires=[0,1])
            qml.CNOT(wires=[1,2])
            qml.CNOT(wires=[2,3])
            qml.CNOT(wires=[3,0])
        
        # Return a single expectation value!
        return qml.expval(qml.PauliZ(0))
    
    #print(len(layers_params[0]))
    # Use the Fourier module to obtain the coefficients and frequencies. Then return the required list of lists.
    y=0.1
    #print("Example Circuit:")
    #print(qml.draw(circuit)(layers_params,y))
    
    res = qml.fourier.circuit_spectrum(circuit)(layers_params,y)
    #print(res)
    
    partial_circuit = functools.partial(circuit, layers_params)
    
    #print("Frequencies:")
    freqs = list(res.values())
    #print(len(freqs[0]))
    #print(freqs[0])
    coeffs = qml.fourier.coefficients(partial_circuit, 1, int((len(freqs[0])-1)/2))
    #print("Coeffs:")
    #print(coeffs)
    
    answer = []
    
    mid = int((len(freqs[0])-1)/2)
    #print(mid)
    for i in range(mid):
        answer.append([freqs[0][i],coeffs[mid+i+1].real,coeffs[mid+i+1].imag])
                      
    answer.append([freqs[0][mid],coeffs[0].real,coeffs[0].imag])
                  
    for i in range(mid):
        answer.append([freqs[0][mid+1+i],coeffs[i+1].real,coeffs[i+1].imag])
    
    print("Answer:")
    print(answer)
    
    return answer
    
# These functions are responsible for testing the solution.
def run(test_case_input: str) -> str:

    ins = json.loads(test_case_input)
    output = fourier_decomp(ins)

    return str(output)

def check(solution_output: str, expected_output: str) -> None:
    """
    Compare solution with expected.

    Args:
            solution_output: The output from an evaluated solution. Will be
            the same type as returned.
            expected_output: The correct result for the test case.

    Raises:
            ``AssertionError`` if the solution output is incorrect in any way.
    """

    solution_output = json.loads(solution_output)
    expected_output = json.loads(expected_output)
    assert np.allclose(
        solution_output, expected_output, rtol=1e-2
    ), "Your calculated Fourier spectrum isn't quite right."


test_cases = [['[[[2,2,2,2],[1,2,1,1]],[[3,4,5,6]]]', '[[-4.0, -2.4671622769447922e-17, -1.2335811384723961e-17], [-3.0, -0.03395647263976357, 0.010208410500915437], [-2.0, 2.8360500437920326e-17, 1.850371707708594e-17], [-1.0, 0.11762992558035439, -0.13619443127813127], [0.0, 8.018277400070575e-17, 0.0], [1.0, 0.11762992558035439, 0.13619443127813124], [2.0, 3.700743415417188e-17, -1.850371707708594e-17], [3.0, -0.03395647263976357, -0.010208410500915437],[4.0, -3.688877668472405e-18, 1.850371707708594e-17]]'], ['[[[2,2,2,2]],[[3,4,5,6]]]', '[[-4.0, 1.2335811384723961e-17, 3.700743415417188e-17],  [-3.0, 0.022482345076620468, -0.07855141721016852], [-2.0, -1.2335811384723961e-17, -6.536793459209221e-17], [-1.0, -0.13243693333822854, 0.17097830099559677], [0.0, -2.4671622769447922e-17, 0.0], [1.0, -0.13243693333822854, -0.17097830099559677], [2.0, -2.4671622769447922e-17, 7.401486830834377e-17], [3.0, 0.022482345076620468, 0.07855141721016852], [4.0, -1.2335811384723961e-17, -3.331855648569948e-17]]']]

for i, (input_, expected_output) in enumerate(test_cases):
    print(f"Running test case {i} with input '{input_}'...")

    try:
        output = run(input_)

    except Exception as exc:
        print(f"Runtime Error. {exc}")

    else:
        if message := check(output, expected_output):
            print(f"Wrong Answer. Have: '{output}'. Want: '{expected_output}'.")

        else:
            print("Correct!")

Running test case 0 with input '[[[2,2,2,2],[1,2,1,1]],[[3,4,5,6]]]'...
2
2
2
2
2
2
2
2
2
2
Answer:
[[-4.0, 6.167905692361981e-17, 6.167905692361981e-17], [-3.0, -0.03395647263976355, 0.010208410500915468], [-2.0, -4.6421381810329376e-17, -4.1317133731492965e-18], [-1.0, 0.11762992558035412, -0.13619443127813097], [0, 7.401486830834377e-17, 0.0], [1.0, 0.1176299255803541, 0.13619443127813097], [2.0, -4.9343245538895844e-17, 0.0], [3.0, -0.03395647263976355, -0.010208410500915468], [4.0, 7.10930045797773e-17, -5.75473435504705e-17]]
Correct!
Running test case 1 with input '[[[2,2,2,2]],[[3,4,5,6]]]'...
2
2
2
2
2
2
2
2
2
2
Answer:
[[-4.0, -1.2335811384723961e-17, -4.9343245538895844e-17], [-3.0, 0.022482345076620374, -0.07855141721016855], [-2.0, 2.1366252070928486e-17, 2.632430811870764e-17], [-1.0, -0.13243693333822845, 0.17097830099559674], [0, -1.1719020815487764e-16, 0.0], [1.0, -0.13243693333822845, -0.17097830099559674], [2.0, 1.850371707708594e-17, -2.4671622769447922e-17], [3.0,