In [8]:
import numpy as np

from pathlib import Path

In [2]:

def sample_state_fast(state, pauli_labels, *, rng=None):
    """
    Single-shot measurement of |state⟩ in the tensor-product basis defined by
    pauli_labels, without constructing the 4^n-sized eigenbasis.

    Qubits are measured one at a time.  After each single-qubit outcome the
    state is projected onto the matching eigenvector of that Pauli operator
    and renormalised, so every later qubit is sampled from the correct
    conditional distribution (correlations between qubits are preserved).

    Qubit ordering: ``pauli_labels[q]`` acts on axis ``q`` of
    ``state.reshape((2,) * n)``, i.e. on bit ``n - 1 - q`` of the flat index,
    so qubit 0 is the most-significant bit (the same convention as
    ``1 << (num_qubits - 1 - j)`` and ``f"{i:0{n}b}"``).  Character ``q`` of
    the returned string is the outcome of qubit ``q``.

    Parameters
    ----------
    state : np.ndarray, shape=(2**n,), complex
        State vector.  Will be *modified in place*: on return it holds the
        normalised post-measurement state.  It must be a complex array whose
        reshape to ``(2,) * n`` is a view (numpy returns a view for 1-D
        inputs).  Outcome probabilities are taken relative to the current
        norm, so small normalisation drift does not bias the sampling.
    pauli_labels : Sequence[str]
        One character per qubit: 'I', 'X', 'Y' or 'Z'.
    rng : np.random.Generator, optional
        Random generator; a fresh ``np.random.default_rng()`` is used when
        omitted.  One ``rng.random()`` draw is consumed per non-'I' qubit.

    Returns
    -------
    str
        Measurement bit-string ('0' = +1 eigenvalue, '1' = −1).  'I' qubits
        are not measured and always report '0'.

    Raises
    ------
    ValueError
        If a label is not one of 'I', 'X', 'Y', 'Z', or if the state has zero
        norm when a qubit is measured.
    """
    if rng is None:
        rng = np.random.default_rng()

    inv_sqrt2 = 1.0 / np.sqrt(2.0)
    # Components (c0, c1) of the +1 and −1 eigenvectors c0|0⟩ + c1|1⟩ of each Pauli.
    eigenvectors = {
        'Z': ((1.0, 0.0), (0.0, 1.0)),
        'X': ((inv_sqrt2, inv_sqrt2), (inv_sqrt2, -inv_sqrt2)),
        'Y': ((inv_sqrt2, 1j * inv_sqrt2), (inv_sqrt2, -1j * inv_sqrt2)),
    }

    n = len(pauli_labels)
    psi = state.reshape((2,) * n)           # view |ψ⟩ as n-dimensional tensor
    outcome_bits = []

    for q, pl in enumerate(pauli_labels):
        if pl == 'I':
            outcome_bits.append('0')
            continue
        if pl not in eigenvectors:
            raise ValueError(f"Invalid Pauli label: {pl}")

        # Views (not copies) of the qubit-q = |0⟩ / |1⟩ slices, so the
        # assignments below collapse `state` itself (np.take would copy).
        # `[k, ...]` keeps a writable 0-d view even when n == 1.
        by_qubit = np.moveaxis(psi, q, 0)
        psi0, psi1 = by_qubit[0, ...], by_qubit[1, ...]

        # --- probabilities of the two outcomes ------------------------------
        # amp = ⟨e|ψ⟩ on qubit q = conj(c0)·ψ0 + conj(c1)·ψ1 (fresh arrays).
        plus_vec, minus_vec = eigenvectors[pl]
        amp_plus = np.conj(plus_vec[0]) * psi0 + np.conj(plus_vec[1]) * psi1
        amp_minus = np.conj(minus_vec[0]) * psi0 + np.conj(minus_vec[1]) * psi1
        p_plus = np.vdot(amp_plus, amp_plus).real
        p_minus = np.vdot(amp_minus, amp_minus).real
        norm_sq = p_plus + p_minus          # = ‖ψ‖², the eigenbasis is orthonormal
        if norm_sq <= 0.0:
            raise ValueError("Cannot measure a state with zero norm")

        is_zero = rng.random() < p_plus / norm_sq
        outcome_bits.append('0' if is_zero else '1')

        # --- collapse & renormalise: ψ ← |e⟩ ⊗ ⟨e|ψ⟩ / ‖⟨e|ψ⟩‖ -------------
        # The chosen outcome always has prob > 0, so the division is safe.
        if is_zero:
            (c0, c1), amp, prob = plus_vec, amp_plus, p_plus
        else:
            (c0, c1), amp, prob = minus_vec, amp_minus, p_minus
        scale = 1.0 / np.sqrt(prob)
        psi0[...] = (c0 * scale) * amp
        psi1[...] = (c1 * scale) * amp

    return ''.join(outcome_bits)


In [3]:
# Phase-augmented W state: (1/sqrt(N)) * sum_j exp(i*theta_j) |0...1_j...0>

num_qubits = 20

state_dim = 2 ** num_qubits                      # Hilbert-space dimension (same as 1 << num_qubits)
w_aug = np.zeros(state_dim, dtype=np.complex128) # all-zero amplitude vector

# Only the num_qubits one-hot basis states carry amplitude, so one random phase per qubit suffices.
rng = np.random.default_rng(42)
thetas = rng.uniform(low=0, high=2 * np.pi, size=num_qubits)

# Qubit j lives at bit (num_qubits - 1 - j) of the index, i.e. qubit 0 is the most-significant bit.
one_hot_indices = 1 << (num_qubits - 1 - np.arange(num_qubits))
w_aug[one_hot_indices] = np.exp(1j * thetas) / np.sqrt(num_qubits)


def format_bytes(b):
    """Format a byte count as a human-readable string with two decimals.

    Units step by factors of 1024: B, KB, MB, GB, TB, PB.  Counts of 1024 TB
    or more are expressed in PB (previously the loop fell through and the
    function returned ``None``).
    """
    units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
    for u in units[:-1]:
        if b < 1024:
            return f"{b:.2f} {u}"
        b /= 1024
    # Largest unit: never fall off the end, however big the count is.
    return f"{b:.2f} {units[-1]}"


print(f"Size of state vector in memory: {format_bytes(w_aug.nbytes)} \n")

# Show the first few amplitudes as real + imaginary parts.
# (Formatting the complex value itself with :.2f already prints "a+bj", which
# previously produced output like "0.10-0.20j + -0.20j".)
for i in range(10):
    print(f"{i:0{num_qubits}b}: {w_aug[i].real:.2f} + {w_aug[i].imag:.2f}j")

Size of state vector in memory: 16.00 MB 

00000000000000000000: 0.00+0.00j + 0.00j
00000000000000000001: -0.15-0.16j + -0.16j
00000000000000000010: 0.10-0.20j + -0.20j
00000000000000000011: 0.00+0.00j + 0.00j
00000000000000000100: 0.21+0.09j + 0.09j
00000000000000000101: 0.00+0.00j + 0.00j
00000000000000000110: 0.00+0.00j + 0.00j
00000000000000000111: 0.00+0.00j + 0.00j
00000000000000001000: -0.21-0.08j + -0.08j
00000000000000001001: 0.00+0.00j + 0.00j


In [4]:
# One random Pauli direction per qubit, drawn from the shared generator.
pauli_dirs = rng.choice(a=['X', 'Y', 'Z'], size=num_qubits)
print(pauli_dirs)


['X' 'Z' 'Z' 'Y' 'X' 'Z' 'Y' 'Z' 'Z' 'Z' 'Z' 'X' 'Y' 'Y' 'Y' 'X' 'Y' 'X'
 'Z' 'Z']


In [7]:
# Sample the state once in the random basis drawn above.
# sample_state_fast projects its input in place, so hand it a copy: w_aug must
# stay the untouched W state for the data-generation cells below.
result = sample_state_fast(w_aug.copy(), pauli_dirs, rng=rng)

# Join for display only; keep `pauli_dirs` as drawn so re-running this cell is idempotent.
print("".join(pauli_dirs))
print(result)

XZZYXZYZZZZXYYYXYXZZ
00000000000010001000


In [9]:
# Output folder for the measurement files, relative to the notebook's working directory.
data_dir = Path(".") / "data"
data_dir.mkdir(exist_ok=True, parents=True)

print(f"Data will be saved to {data_dir.resolve()}")

Data will be saved to /Users/Tonni/Desktop/master-code/neural-quantum-tomo/case_studies/w_phase_augmented/data


In [10]:
# --- Define all Measurement Bases ---
# 1. The all-Z amplitude basis.
# 2./3. For every neighbouring pair (i, i+1), with all other qubits in Z:
#       an "XX" basis (N-1 of them), then an "XY" basis (N-1 of them).
amplitude_basis = 'Z' * num_qubits
all_measurement_bases = [amplitude_basis]

for pair in ('XX', 'XY'):
    for site in range(num_qubits - 1):
        all_measurement_bases.append('Z' * site + pair + 'Z' * (num_qubits - site - 2))

expected_num_bases = 1 + 2 * (num_qubits - 1)
print(f"Total number of measurement bases: {len(all_measurement_bases)}")
if len(all_measurement_bases) != expected_num_bases:
    print(f"Warning: Expected {expected_num_bases} bases, but got {len(all_measurement_bases)}")
print("-" * 30)

Total number of measurement bases: 39
------------------------------


In [None]:
# --- Generate and Store Measurements ---
# Writes one text file per basis in `all_measurement_bases`, one shot per line.

import os
import time
from collections import Counter

output_directory = data_dir

# Shots written per basis.
# NOTE(review): this name was referenced below but never defined anywhere in the
# notebook (a fresh-kernel run raised NameError). 1_000 is a placeholder; set the
# intended value here.
num_samples_per_basis = 1_000

# Must be an int: it is offset by 1 below (a Generator here raised TypeError).
rng_seed = 42
rng_for_sampling = np.random.default_rng(rng_seed + 1)  # Different seed for sampling

# sample_state_fast modifies its input in place; cheaply check that no earlier
# cell has already altered w_aug before generating the data set from it.
assert np.isclose(np.vdot(w_aug, w_aug).real, 1.0), "w_aug is not normalised - re-run the state-construction cell"

# Report progress about every 10% of the shots (every shot when there are fewer than 10).
progress_every = max(num_samples_per_basis // 10, 1)

# File format: '0' (+1 eigenvalue) -> 'X', '1' (-1 eigenvalue) -> 'x' for every
# qubit, whatever its basis; the basis itself is recorded in the filename.
bit_to_char = str.maketrans('01', 'Xx')

total_start_time = time.perf_counter()

for basis_idx, pauli_dirs_str in enumerate(all_measurement_bases):
    basis_start_time = time.perf_counter()
    print(f"Processing basis {basis_idx + 1}/{len(all_measurement_bases)}: {pauli_dirs_str}")

    # Output filename format: w_aug_<basis_string>_<num_samples>.txt
    filename = os.path.join(output_directory, f"w_aug_{pauli_dirs_str}_{num_samples_per_basis}.txt")

    with open(filename, 'w') as f_out:
        for sample_num in range(num_samples_per_basis):
            if (sample_num + 1) % progress_every == 0:
                print(f"  Sample {sample_num + 1}/{num_samples_per_basis} for basis {pauli_dirs_str}", end='\r')

            # Each shot needs a fresh copy of the state because sample_state_fast
            # collapses its input in place (the dominant per-shot cost for large N).
            measurement_bits = sample_state_fast(w_aug.copy(), pauli_dirs_str, rng=rng_for_sampling)
            f_out.write(measurement_bits.translate(bit_to_char) + "\n")
        print(f"  Finished {num_samples_per_basis} samples for basis {pauli_dirs_str}. Saved to {filename}.")
    basis_end_time = time.perf_counter()
    print(f"  Time for basis {pauli_dirs_str}: {basis_end_time - basis_start_time:.2f} seconds.")
    print("-" * 10)

total_end_time = time.perf_counter()
print(f"\nAll measurements generated. Total time: {total_end_time - total_start_time:.2f} seconds.")
print(f"Output files are in the directory: {output_directory}")