# Arbitrary State Preparation

## Problem:

- Input: psi: np.ndarray of size $2^{2^n}$, complex, normalized (up to floating error).
- Output: QuantumCircuit on n qubits (no classical bits) that prepares psi from |0^n>.
- Constraint: Only 1Q gates and multi-controlled Rz. We realize uniformly-controlled Ry by basis-changes plus multiplexed Rz.
- How to Use:
  - prepare_state_circuit(psi: np.ndarray) -> QuantumCircuit

## Implementation Idea

1. Magnitudes: build a binary tree of uniformly-controlled rotations on qubits $$0\dots n-1$$ (MSB->LSB). At level k, for each control prefix, apply an $R_y(
\theta)$ on target k splitting the block's norm.
2. Phases: apply a diagonal via uniformly-controlled $Rz(\phi)$ so that each basis state $|x>$ acquires the desired relative phase $arg(\psi_x)$.
3. Gate set: identify $$R_y(\theta) = S^† H R_z(\theta) H S.$$ Thus we only need local H, S, Sdg, and multi-controlled Rz.

In [51]:
# Imports
%pip install qiskit numpy
import numpy as np
from typing import Dict, List, Sequence, Tuple
from qiskit import QuantumCircuit
from qiskit.quantum_info import Statevector
from qiskit.circuit.library import RZGate
np.set_printoptions(precision=5, suppress=True)



In [52]:
# Utilities
def is_power_of_two(m: int) -> bool:
    return (m & (m - 1) == 0) and m > 0

def num_qubits_from_len(m: int) -> int:
    if not is_power_of_two(m):
        raise ValueError(f"Length {m} is not a power of two.")
    return m.bit_length() - 1

def normalize(psi: np.ndarray) -> np.ndarray:
    psi = np.asarray(psi, dtype=complex)
    nrm = np.linalg.norm(psi)
    if nrm == 0:
        raise ValueError("Zero vector.")
    return psi / nrm

In [53]:
# Magnitude angles with LSB prefix matching
def magnitude_angle_table_lsb(psi: np.ndarray) -> List[Dict[int, float]]:
    psi = np.asarray(psi, dtype=complex)
    N = len(psi); n = num_qubits_from_len(N)
    thetas: List[Dict[int, float]] = [dict() for _ in range(n)]
    indices = np.arange(N, dtype=np.int64)
    for k in range(n):  # target qubit k
        for p in range(1 << k):  # controls are qubits 0..k-1
            mask_low = (1 << k) - 1
            match = (indices & mask_low) == p
            idx_block = indices[match]
            left = idx_block[(idx_block & (1 << k)) == 0]   # bit k = 0
            right = idx_block[(idx_block & (1 << k)) != 0]  # bit k = 1
            A = np.linalg.norm(psi[left]); B = np.linalg.norm(psi[right])
            denom = max(A*A + B*B, 1e-16)**0.5
            c = (A/denom) if denom>0 else 1.0
            c = min(1.0, max(-1.0, c))
            theta = 2*float(np.arccos(c))
            thetas[k][p] = theta
    return thetas

In [73]:
# Exact phase angles with LSB ordering
def phase_angle_table_exact_lsb(psi: np.ndarray) -> List[Dict[int, float]]:
    """
    Exact diagonal synthesis in LSB convention via bottom-up residuals.
    - LSB: qubit k corresponds to bit k of the basis index.
    - For each level k, we:
        1) compute φ_k(p) by averaging ALL pairwise differences for pattern p
           from a snapshot of residual unit phasors u (no in-level updates yet);
        2) update residuals for all pairs using the half-angle action of Rz(φ_k(p)).
    """
    psi = np.asarray(psi, dtype=complex)
    N = len(psi); n = (N.bit_length() - 1)
    # Work with unit phasors to avoid unwrap trouble
    u = np.exp(1j * np.angle(psi))  # shape (N,)
    phi_levels: List[Dict[int, float]] = [dict() for _ in range(n)]

    def principal(x: float) -> float:
        return (x + np.pi) % (2*np.pi) - np.pi

    for k in range(n):                    # k=0 (LSB) .. n-1 (MSB)
        block = 1 << (k + 1)              # block size for bit k
        half  = 1 << k                    # half-block stride
        num_patterns = 1 << k if k > 0 else 1

        # ---------- PASS 1: compute all φ_k(p) from a snapshot ----------
        # snapshot to prevent in-level contamination
        u_snap = u.copy()
        for p in range(num_patterns):
            acc = 0.0 + 0.0j
            for start in range(0, N, block):
                left  = start + p
                right = left + half
                acc += u_snap[right] * np.conj(u_snap[left])   # e^{i(rR - rL)}
            phi = 0.0 if acc == 0 else float(np.angle(acc))
            phi_levels[k][p] = principal(phi)

        # ---------- PASS 2: update residuals for ALL pairs ----------
        for p in range(num_patterns):
            phi = phi_levels[k][p]

            # We are "undoing" the Rz(phi) gate, so we apply Rz(-phi).
            eL = np.exp(+1j * phi / 2.0)   # left action of Rz(-φ)
            eR = np.exp(-1j * phi / 2.0)   # right action of Rz(-φ)

            for start in range(0, N, block):
                left  = start + p
                right = left + half
                # Apply the inverse rotation directly to the residuals.
                u[left]  *= eL
                u[right] *= eR

    return phi_levels


In [55]:
# Builders (unchanged API)
def basis_change_for_Ry(qc: QuantumCircuit, q: int, inverse: bool=False):
    if not inverse:
        qc.sdg(q); qc.h(q)
    else:
        qc.h(q); qc.s(q)

def apply_pattern_controls(qc: QuantumCircuit, controls: Sequence[int], pattern: int, k: int):
    # controls are qubits 0..k-1; pattern's jth bit corresponds to qubit j
    for j, c in enumerate(controls):
        desired = (pattern >> j) & 1
        if desired == 0:
            qc.x(c)

def undo_pattern_controls(qc: QuantumCircuit, controls: Sequence[int], pattern: int, k: int):
    for j, c in enumerate(controls):
        desired = (pattern >> j) & 1
        if desired == 0:
            qc.x(c)

def mcrz_on_pattern(qc: QuantumCircuit, controls: Sequence[int], target: int, angle: float):
    num_ctrl = len(controls)
    if num_ctrl == 0:
        qc.rz(angle, target); return
    gate = RZGate(angle).control(num_ctrl)
    qc.append(gate, list(controls) + [target])

def apply_uniformly_controlled_Rz(qc: QuantumCircuit, target: int, controls: Sequence[int], angle_map: Dict[int, float]):
    k = len(controls)
    for pattern, angle in angle_map.items():
        if abs(angle) < 1e-15:
            continue
        apply_pattern_controls(qc, controls, pattern, k)
        mcrz_on_pattern(qc, controls, target, angle)
        undo_pattern_controls(qc, controls, pattern, k)

def apply_uniformly_controlled_Ry(qc: QuantumCircuit, target: int, controls: Sequence[int], angle_map: Dict[int, float]):
    basis_change_for_Ry(qc, target, inverse=False)
    apply_uniformly_controlled_Rz(qc, target, controls, angle_map)
    basis_change_for_Ry(qc, target, inverse=True)

In [70]:
# Assemble (LSB convention)
def prepare_state_circuit_lsb(psi: np.ndarray):
    psi = normalize(psi)
    N = len(psi); n = num_qubits_from_len(N)
    qc = QuantumCircuit(n, name="Prep(|psi>)")
    thetas = magnitude_angle_table_lsb(psi)
    for k in range(n):
        controls = list(range(k))  # [0..k-1]
        target = k
        apply_uniformly_controlled_Ry(qc, target, controls, thetas[k])
    phis = phase_angle_table_exact_lsb(psi)
    for k in range(n):
        controls = list(range(k))
        target = k
        apply_uniformly_controlled_Rz(qc, target, controls, phis[k])
    return qc, thetas, phis

def global_phase_align(vec: np.ndarray, ref: np.ndarray) -> np.ndarray:
    ip = np.vdot(ref, vec)
    if np.abs(ip) < 1e-15: return vec
    return vec * (np.conj(ip)/np.abs(ip))

def check_prep_lsb(psi: np.ndarray, verbose=True):
    psi = normalize(psi)
    n = num_qubits_from_len(len(psi))
    qc, thetas, phis = prepare_state_circuit_lsb(psi)
    sv0 = Statevector.from_label('0'*n)
    out = sv0.evolve(qc).data
    out_aligned = global_phase_align(out, psi)
    err = np.linalg.norm(out_aligned - psi)
    if verbose:
        print(qc)
        print("theta tables (LSB):", thetas)
        print("phi tables (LSB):", phis)
        print("||error||_2 =", float(err))
    return float(err), qc, thetas, phis

In [75]:
# Fixed tests (same states)

psi1 = np.array([0.5, (np.sqrt(3)/2)*np.exp(1j*0.4)], dtype=complex)
e1, qc1, th1, ph1 = check_prep_lsb(psi1, verbose=True)

psi2 = np.array([1.0, 1.0j, -1.0, np.exp(1j*0.3)], dtype=complex)/2.0
e2, qc2, th2, ph2 = check_prep_lsb(psi2, verbose=True)

phases = np.zeros(8, dtype=float)
for idx in range(8):
    b0 = (idx>>0)&1; b1=(idx>>1)&1; b2=(idx>>2)&1  # LSB order
    phases[idx] = 0.1*b2 + 0.2*b1 + 0.05*b0 + (0.08 if (b2 and b1) else 0.0)
psi3 = np.exp(1j*phases) / np.sqrt(8)
e3, qc3, th3, ph3 = check_prep_lsb(psi3, verbose=True)

   ┌─────┐┌───┐┌──────────┐┌───┐┌───┐┌─────────┐
q: ┤ Sdg ├┤ H ├┤ Rz(2π/3) ├┤ H ├┤ S ├┤ Rz(0.4) ├
   └─────┘└───┘└──────────┘└───┘└───┘└─────────┘
theta tables (LSB): [{0: 2.0943951023931957}]
phi tables (LSB): [{0: 0.3999999999999999}]
||error||_2 = 3.002735944869808e-16
     ┌─────┐┌───┐┌─────────┐┌───┐┌───┐┌───┐           ┌───┐           »
q_0: ┤ Sdg ├┤ H ├┤ Rz(π/2) ├┤ H ├┤ S ├┤ X ├─────■─────┤ X ├─────■─────»
     ├─────┤├───┤└─────────┘└───┘└───┘└───┘┌────┴────┐└───┘┌────┴────┐»
q_1: ┤ Sdg ├┤ H ├──────────────────────────┤ Rz(π/2) ├─────┤ Rz(π/2) ├»
     └─────┘└───┘                          └─────────┘     └─────────┘»
«     ┌────────────┐┌───┐          ┌───┐               
«q_0: ┤ Rz(2.5062) ├┤ X ├────■─────┤ X ├───────■───────
«     └───┬───┬────┘├───┤┌───┴────┐└───┘┌──────┴──────┐
«q_1: ────┤ H ├─────┤ S ├┤ Rz(-π) ├─────┤ Rz(-1.2708) ├
«         └───┘     └───┘└────────┘     └─────────────┘
theta tables (LSB): [{0: 1.5707963267948966}, {0: 1.5707963267948968, 1: 1.570796326794