# Arbitrary State Preparation

## Problem:

- Input: psi: np.ndarray of size $2^{2^n}$, complex, normalized (up to floating error).
- Output: QuantumCircuit on n qubits (no classical bits) that prepares psi from |0^n>.
- Constraint: Only 1Q gates and multi-controlled Rz. We realize uniformly-controlled Ry by basis-changes plus multiplexed Rz.
- How to Use:
  - prepare_state_circuit(psi: np.ndarray) -> QuantumCircuit

## Implementation Idea

1. Magnitudes: build a binary tree of uniformly-controlled rotations on qubits $$0\dots n-1$$ (MSB->LSB). At level k, for each control prefix, apply an $R_y(
\theta)$ on target k splitting the block's norm.
2. Phases: apply a diagonal via uniformly-controlled $Rz(\phi)$ so that each basis state $|x>$ acquires the desired relative phase $arg(\psi_x)$.
3. Gate set: identify $$R_y(\theta) = S^† H R_z(-\theta) H S.$$ Thus we only need local H, S, Sdg, and multi-controlled Rz.

In [15]:
# Imports
%pip install qiskit numpy
import numpy as np
from typing import Dict, List, Sequence, Tuple
from qiskit import QuantumCircuit
from qiskit.quantum_info import Statevector
from qiskit.circuit.library import RZGate, RYGate
np.set_printoptions(precision=5, suppress=True)




In [16]:
# Utilities
def is_power_of_two(m: int) -> bool:
    return (m & (m - 1) == 0) and m > 0

def num_qubits_from_len(m: int) -> int:
    if not is_power_of_two(m):
        raise ValueError(f"Length {m} is not a power of two.")
    return m.bit_length() - 1


def principal(phi):
    """
    Wrap angle(s) to (-pi, pi]. Works for scalars or numpy arrays.
    """
    return np.arctan2(np.sin(phi), np.cos(phi))


In [17]:
# Magnitude angles with LSB prefix matching
def magnitude_angle_table_lsb(psi: np.ndarray) -> List[Dict[int, float]]:
    """
    Given a normalized state vector psi (length 2^n, LSB ordering),
    compute the theta_k,p angles for uniformly controlled RY layers:

        - k = target qubit (0 = LSB, ..., n-1)
        - p = pattern of lower bits 0..k-1 (LSB-encoded integer)

    For each (k,p), theta_k,p is defined by:
        cos(theta/2) = A / sqrt(A^2 + B^2),
    where A^2 and B^2 are the total probabilities in the
    (bit k = 0) and (bit k = 1) halves under that prefix p.
    """
    psi = np.asarray(psi, dtype=complex)
    N = len(psi)
    n = num_qubits_from_len(N)

    thetas: List[Dict[int, float]] = [dict() for _ in range(n)]
    # indices 0,1,...,N-1, used just for bit-twiddling
    indices = np.arange(N, dtype=np.int64)

    for k in range(n):               # target qubit k
        mask_low = (1 << k) - 1      # mask for bits 0..k-1
        for p in range(1 << k):      # all patterns for lower bits
            match = (indices & mask_low) == p
            idx_block = indices[match]

            # Split by bit k
            left  = idx_block[(idx_block & (1 << k)) == 0]  # bit k = 0
            right = idx_block[(idx_block & (1 << k)) != 0]  # bit k = 1

            # Squared norms (probability mass) in each half
            A2 = float(np.sum(np.abs(psi[left])**2))
            B2 = float(np.sum(np.abs(psi[right])**2))
            denom2 = A2 + B2

            if denom2 <= 1e-16:
                # No amplitude in this prefix at all => arbitrary; pick 0
                theta = 0.0
            else:
                c = np.sqrt(A2 / denom2)  # = A / sqrt(A^2 + B^2)
                # Numerical safety: clip into [-1,1]
                c = max(-1.0, min(1.0, c))
                theta = 2.0 * float(np.arccos(c))

            thetas[k][p] = theta

    return thetas

In [18]:
# Exact phase angles with LSB ordering
def phase_angle_table_exact_lsb(psi: np.ndarray) -> List[Dict[int, float]]:
    """
    Given a target state psi (length 2^n, LSB ordering), compute
    φ_k,p such that a ladder of uniformly controlled RZ(φ_k,p) gates
    applied AFTER the magnitude-prep RY ladder produces psi.

    We work with a phase-only vector u_j = e^{i arg psi_j} and
    peel off diagonal phases from the TOP qubit down: k = n-1..0.
    """
    psi = np.asarray(psi, dtype=complex)
    N = len(psi)
    n = num_qubits_from_len(N)

    # Residual phase vector (unit phasors)
    u = np.exp(1j * np.angle(psi))
    phi_levels: List[Dict[int, float]] = [dict() for _ in range(n)]

    for k in reversed(range(n)):          # KEY: go MSB -> LSB
        mask_low = (1 << k) - 1          # bits 0..k-1

        # 1) Read φ_k,p from current residual phases u
        for p in range(1 << k):
            acc = 0.0 + 0.0j
            for j in range(N):
                # select pairs with low bits = p and bit k = 0
                if ((j & mask_low) == p) and (((j >> k) & 1) == 0):
                    jL = j
                    jR = j | (1 << k)
                    acc += u[jR] * np.conj(u[jL])

            if abs(acc) < 1e-15:
                phi = 0.0
            else:
                phi = principal(np.angle(acc))
            phi_levels[k][p] = phi

        # 2) Apply inverse Rz(-φ_k,p) to u to peel off this layer
        for j in range(N):
            p = j & mask_low
            phi = phi_levels[k][p]
            if ((j >> k) & 1) == 0:
                u[j] *= np.exp(+1j * phi / 2.0)  # bit k = 0
            else:
                u[j] *= np.exp(-1j * phi / 2.0)  # bit k = 1

    return phi_levels

In [19]:
# Builders (Using direct RYGate)

def apply_pattern_controls(qc, controls, pattern: int):
    """
    Flip those control qubits whose desired value is 0, so that
    a multi-controlled gate (which triggers on all-1) triggers
    exactly when (control bits) == pattern.

    controls[j] corresponds to bit j of 'pattern' (LSB).
    """
    for j, c in enumerate(controls):
        desired_bit = (pattern >> j) & 1  # LSB of pattern -> controls[0]
        if desired_bit == 0:
            qc.x(c)

def undo_pattern_controls(qc, controls, pattern: int):
    # Just undo the X's in the same way
    for j, c in enumerate(controls):
        desired_bit = (pattern >> j) & 1
        if desired_bit == 0:
            qc.x(c)

# def mcry_on_pattern(qc, controls, target, angle: float):
#     if abs(angle) < 1e-15:
#         return
#     num_ctrl = len(controls)
#     if num_ctrl == 0:
#         qc.ry(angle, target)
#     else:
#         gate = RYGate(angle).control(num_ctrl)
#         qc.append(gate, list(controls) + [target])

def mcrz_on_pattern(qc, controls, target, angle: float):
    if abs(angle) < 1e-15:
        return
    num_ctrl = len(controls)
    if num_ctrl == 0:
        qc.rz(angle, target)
    else:
        gate = RZGate(angle).control(num_ctrl)
        qc.append(gate, list(controls) + [target])


def RY_via_RZ(theta):
    """
    Return a 1-qubit gate that implements R_y(theta)
    using only S, H, and Rz.
    """
    qc = QuantumCircuit(1, name="Ry_via_Rz")
    qc.s(0)
    qc.h(0)
    qc.rz(-theta, 0)
    qc.h(0)
    qc.sdg(0)
    return qc.to_gate(label=f"Ry({theta:.3f})")

def mcry_on_pattern(qc, controls, target, angle: float, eps: float = 1e-15):
    if abs(angle) < eps:
        return

    num_ctrl = len(controls)
    base = RY_via_RZ(angle)  # our custom R_y gate

    if num_ctrl == 0:
        # Just apply the 1-qubit Ry_via_Rz gate
        qc.append(base, [target])
    else:
        # Multi-controlled version of that composite gate
        gate = base.control(num_ctrl)
        qc.append(gate, list(controls) + [target])

def apply_uniformly_controlled_Ry_direct(qc, target, controls, angle_map):
    """
    For fixed k: apply RY(theta_{k,p}) to 'target' qubit,
    controlled on 'controls' in pattern p (LSB pattern).
    angle_map is thetas[k]: dict p -> theta_{k,p}.
    """
    for p, theta in angle_map.items():
        if abs(theta) < 1e-15:
            continue
        apply_pattern_controls(qc, controls, p)
        mcry_on_pattern(qc, controls, target, theta)
        undo_pattern_controls(qc, controls, p)

def apply_uniformly_controlled_Rz_direct(qc, target, controls, angle_map):
    """
    Same as above but for RZ and phis.
    """
    for p, phi in angle_map.items():
        if abs(phi) < 1e-15:
            continue
        apply_pattern_controls(qc, controls, p)
        mcrz_on_pattern(qc, controls, target, phi)
        undo_pattern_controls(qc, controls, p)

def normalize(psi: np.ndarray) -> np.ndarray:
    psi = np.asarray(psi, dtype=complex)
    nrm = np.linalg.norm(psi)
    if nrm == 0:
        raise ValueError("Zero vector.")
    return psi / nrm

def prepare_state_circuit_lsb(psi: np.ndarray):
    """
    Build a circuit U such that U |0...0> = psi,
    using LSB-based uniformly controlled RY (magnitudes)
    and RZ (phases).

    Returns:
        qc    : QuantumCircuit
        thetas: List[Dict[int, float]]  # magnitude angles
        phis  : List[Dict[int, float]]  # phase angles
    """
    psi = normalize(psi)
    N = len(psi)
    n = num_qubits_from_len(N)

    qc = QuantumCircuit(n, name="Prep(|psi>)")

    # 1) Magnitudes: RY layers
    thetas = magnitude_angle_table_lsb(psi)
    for k in range(n):
        controls = list(range(k))   # qubits 0..k-1
        target = k                  # qubit k
        apply_uniformly_controlled_Ry_direct(qc, target, controls, thetas[k])

    # 2) Phases: RZ layers
    phis = phase_angle_table_exact_lsb(psi)
    for k in range(n):
        controls = list(range(k))
        target = k
        apply_uniformly_controlled_Rz_direct(qc, target, controls, phis[k])

    return qc, thetas, phis


In [20]:

# Assemble (LSB convention)

def global_phase_align(vec: np.ndarray, ref: np.ndarray) -> np.ndarray:
    """
    Multiply `vec` by a global phase so that it best aligns with `ref`
    in L2 sense. Both can be unnormalized; we only care about direction.
    """
    v = np.asarray(vec, dtype=complex)
    r = np.asarray(ref, dtype=complex)

    # Inner product <v, r> = v^† r
    ip = np.vdot(v, r)
    if np.abs(ip) < 1e-15:
        # Almost orthogonal => no meaningful global phase to align
        return v

    phase = ip / np.abs(ip)   # unit-modulus complex number
    return v * phase


def check_prep_lsb(psi: np.ndarray, verbose: bool = True):
    """
    Build the LSB-based prep circuit for `psi`, apply it to |0...0>,
    align global phase, and report the L2 error.

    Returns:
        error : float
        qc    : QuantumCircuit
        thetas: List[Dict[int, float]]
        phis  : List[Dict[int, float]]
    """
    psi = normalize(psi)
    n = num_qubits_from_len(len(psi))

    qc, thetas, phis = prepare_state_circuit_lsb(psi)

    sv0 = Statevector.from_label('0' * n)
    out = sv0.evolve(qc).data

    out_aligned = global_phase_align(out, psi)
    err = np.linalg.norm(out_aligned - psi)

    if verbose:
        print(qc.draw())
        print("theta tables (LSB):", thetas)
        print("phi tables (LSB):", phis)
        print("||error||_2 =", float(err))

    return float(err), qc, thetas, phis

In [21]:
rng = np.random.default_rng(123)
from qiskit.quantum_info import random_statevector
def rand_state_qiskit(n: int, seed: int = None) -> np.ndarray:
    sv = random_statevector(2**n, seed=seed)
    return sv.data

psi_r_qiskit = rand_state_qiskit(3, seed=3)

def show_case(label, psi):
    psi = normalize(psi)
    e, qc, th, ph = check_prep_lsb(psi, verbose=True)
    n = num_qubits_from_len(len(psi))
    out = Statevector.from_label('0'*n).evolve(qc).data
    print(f"\n[{label}]")
    print("psi (target):", psi)
    print("psi_out     :", out)
    return e, qc, th, ph, out

# ---- Fixed non-uniform case (small, n=2) ----
# Unequal magnitudes + phases to stress φ estimation & parent update
psi_fixed_n2 = np.array([
    0.80 + 0.00j,
    0.10 * np.exp(1j*0.9),
    0.30 * np.exp(1j*0.2),
    0.40 * np.exp(-1j*1.1),
], dtype=complex)
e_fix2, qc_fix2, th_fix2, ph_fix2, out_fix2 = show_case("FIXED n=2 (non-uniform)", psi_fixed_n2)
# ---- Fixed non-uniform (n=3) ----
psi_fixed_n3 = np.array([ 0.29454-0.35757j, -0.13775-0.36323j, -0.08682-0.02085j,
        0.3097 +0.18046j,  0.28911-0.03889j,  0.24013+0.31339j,
        0.17903+0.05203j,  0.18739-0.43264j], dtype=complex)
e_fix3, qc_fix3, th_fix3, ph_fix3, out_fix3 = show_case("FIXED n=3 (non-uniform, REF)", psi_fixed_n3)
# ---- Fixed non-uniform (n=4) ----
psi_fixed_n4 = np.array([
    0.30+0.05j,  -0.05+0.20j,   0.12-0.08j,   0.05+0.02j,
   -0.22-0.01j,  -0.11-0.09j,   0.01+0.19j,  -0.07-0.03j,
   -0.02-0.05j,  -0.14-0.02j,   0.10+0.09j,   0.16-0.01j,
    0.03+0.07j,   0.18-0.02j,   0.32+0.04j,  -0.12-0.10j
], dtype=complex)
e_fix4, qc_fix4, th_fix4, ph_fix4, out_fix4 = show_case("FIXED n=4 (non-uniform, REF)", psi_fixed_n4)

print("random case")
# ---- Random checks (unchanged) ----
e3, qc3, th3, ph3, out3 = show_case("n=3 (random)", psi_r_qiskit)

     ┌───────────┐┌───┐             ┌───┐             ┌──────────┐┌───┐»
q_0: ┤ Ry(0.899) ├┤ X ├──────■──────┤ X ├──────■──────┤ Rz(-0.2) ├┤ X ├»
     └───────────┘└───┘┌─────┴─────┐└───┘┌─────┴─────┐└──────────┘└───┘»
q_1: ──────────────────┤ Ry(0.718) ├─────┤ Ry(2.652) ├─────────────────»
                       └───────────┘     └───────────┘                 »
«                ┌───┐          
«q_0: ─────■─────┤ X ├────■─────
«     ┌────┴────┐└───┘┌───┴────┐
«q_1: ┤ Rz(0.2) ├─────┤ Rz(-2) ├
«     └─────────┘     └────────┘
theta tables (LSB): [{0: 0.8992181501769816}, {0: 0.7175413405411443, 1: 2.651635327336065}]
phi tables (LSB): [{0: np.float64(-0.20000000000000007)}, {0: np.float64(0.20000000000000004), 1: np.float64(-2.0)}]
||error||_2 = 1.5561871272885063e-15

[FIXED n=2 (non-uniform)]
psi (target): [0.84327+0.j      0.06552+0.08257j 0.30992+0.06282j 0.19125-0.37577j]
psi_out     : [0.84327+0.j      0.06552+0.08257j 0.30992+0.06282j 0.19125-0.37577j]
     ┌───────────┐┌───┐     