In [239]:
from qiskit import QuantumRegister, ClassicalRegister, QuantumCircuit, transpile
from qiskit import Aer
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d

In [32]:
def walsh(order,x):
    res = 1
    x = np.modf(x)[0]
    x *= 2
    while order>0:
        res *= (-1)**((order%2)*np.modf(x)[1])
        order //= 2
        x = np.modf(x)[0]
        x *= 2
    return res

def walsh_coeff(f,order,M):
    sampling_points = np.linspace(0,1,num=M,endpoint=False)
    sampled_values = [walsh(order,i) for i in sampling_points]
    return np.dot(f(sampling_points),sampled_values)/M

def walsh_approximator(f,n,max_terms=1024):
    #Returns a function f^eps, the Walsh approximation of f
    # m = np.floor(np.log2(1/eps))+1
    M = int(2**n)
    # print('This is M',M)
    all_coeffs = [walsh_coeff(f,j,M) for j in range(M)]
    #Keep only largest max_terms coeffs sorted by absolute value
    largest = sorted(zip(np.abs(all_coeffs),all_coeffs,range(M)),reverse=True)[:max_terms]
    coeffs = [coeff for _,coeff,idx in largest]
    orders = [idx for _,coeff,idx in largest]
    def approximated_f(x):
        values = [walsh(j,x) for j in orders]
        return np.dot(coeffs,values)
    return approximated_f, coeffs, orders

def get_interp(amp_list):
    return interp1d(np.linspace(0,1,num=amp_list.shape[0],endpoint=False),amp_list)

In [338]:
def walsh_operator(qc,wires,ancilla,coeff,eps0):
    for q in wires[:-1]:
        qc.ccx(ancilla,q,wires[-1])
    qc.crz(-2*coeff*eps0,ancilla,wires[-1])
    for q in reversed(wires[:-1]):
        qc.ccx(ancilla,q,wires[-1])

def walsh_unitary(f,qc,wires,ancilla,eps0,k,max_terms):
    approx_f,coeffs,orders = walsh_approximator(f,k,max_terms)
    for coeff,order in zip(coeffs,orders):
        bin_ord = list(np.binary_repr(order,width=len(wires)))
        bin_ord.reverse()
        bin_ord = np.array(bin_ord)
        indices = np.where(bin_ord == '1')
        on_wires = np.take_along_axis(np.array(wires),indices[0],0)
        if order>0:
            walsh_operator(qc,on_wires,ancilla,coeff,eps0)
    return np.array(coeffs)[np.nonzero(np.array(orders)==0)]

def walsh_state_prep(wires,ancilla,f,eps0,k,max_terms=1024):
    qc = QuantumCircuit(ancilla,wires)
    # while True:
    for wire in wires:
        qc.h(wire)
    qc.h(ancilla)
    zero_coeff = walsh_unitary(f,qc,wires,ancilla,eps0,k,max_terms)
    if zero_coeff.size > 0:
        qc.rz(1*zero_coeff[0]*eps0,ancilla)
    qc.h(ancilla)
    qc.sdg(ancilla)
        # m_0 = qml.measure(ancilla,reset=True)
        # if m_0 == 1:
        #     break
    return qc

In [339]:
num_qubits = 4
amp_list = np.random.normal(0,1,2**num_qubits)
amp_list = np.abs(amp_list)/np.linalg.norm(amp_list)
print(amp_list)

[0.32625243 0.20549852 0.06385121 0.06388351 0.40718745 0.14377335
 0.13000894 0.1063525  0.19175659 0.15268961 0.18574232 0.39246565
 0.59231528 0.03207959 0.16172522 0.04301536]


In [340]:
reg = QuantumRegister(num_qubits,name='reg')
anc = QuantumRegister(1,name='anc')
cr = ClassicalRegister(num_qubits+1,name='c')
qc = walsh_state_prep(reg,anc,get_interp(amp_list),11,num_qubits)
# for i in range(num_qubits):
#     qc.measure(reg[i],cr[i])
# qc.measure(anc[0],cr[-1])

In [341]:
backend = Aer.get_backend('statevector_simulator')
# print(qc.draw())
qc_t = transpile(qc,backend=backend)
# print(qc_t.draw())

In [352]:
job = qiskit.execute(qc, Aer.get_backend('statevector_simulator'), shots=1,memory=True)
job_result = job.result()
statevector = job_result.get_statevector()
while True:
    res = statevector.measure([
# counts = job_result.get_counts()
# print(counts)
# output = np.zeros(2**num_qubits)
# for key,val in counts.items():
#     if key[0] == '0':
#         output[int(key[1:],2)] = val
# print(output)
# res = np.array([statevector[i] for i in range(1,2**(num_qubits+1),2)])
# # print(res)
# res = np.abs(res)/np.linalg.norm(res)
# print(res)
# test_points = np.linspace(0,1,num=2**num_qubits,endpoint=False)
# plt.plot(test_points,output/np.sum(output))
# plt.plot(test_points,get_interp(amp_list)(test_points))

Statevector([ 0.        +0.j        , -0.27732277-0.23116443j,
              0.        +0.j        , -0.32169445+0.01443983j,
              0.        +0.j        , -0.12130791-0.26395997j,
              0.        +0.j        , -0.02377843+0.03571272j,
              0.        +0.j        , -0.09333856+0.08666426j,
              0.        +0.j        , -0.31484041+0.02458694j,
              0.        +0.j        , -0.22503415+0.09104964j,
              0.        +0.j        , -0.28127149+0.05995947j,
              0.        +0.j        , -0.33473844-0.01028773j,
              0.        +0.j        , -0.26640743+0.07079169j,
              0.        +0.j        , -0.25075732+0.07998299j,
              0.        +0.j        , -0.03921135+0.05182526j,
              0.        +0.j        , -0.09339923+0.08668969j,
              0.        +0.j        , -0.15088601-0.26864171j,
              0.        +0.j        , -0.17795866+0.10061535j,
              0.        +0.j        , -0.05642566+0.065

In [None]:
num_wires = 8
amp_list = np.abs(np.random.normal(0,1,2**num_wires))
amp_list = amp_list/np.linalg.norm(amp_list)
dev = qml.device('default.qubit',wires=num_wires+1)

@qml.qnode(dev)
def circuit(wires,ancilla,f,eps0,k,max_terms=1024):
    walsh_state_prep(wires,ancilla,f,eps0,k,max_terms)
    return qml.state()