In [1]:
import numpy as np
import os
import time
from pathlib import Path # For data directory handling

# --- Configuration ---
# Problem size. Small values keep test runs quick (e.g. 4-6 qubits);
# the paper uses 20 qubits and 6400 samples per basis.
NUM_QUBITS = 6
NUM_SAMPLES_PER_BASIS = 20
RNG_SEED = 42  # Base seed, fixed for reproducibility

# --- Output directory (created if it does not exist yet) ---
DATA_DIR = Path("./w_aug_tomography_data")
DATA_DIR.mkdir(parents=True, exist_ok=True)
print(f"Data will be saved to: {DATA_DIR.resolve()}")

# --- Random number generators ---
# Two separate generators so that drawing the W-state phases and drawing the
# measurement outcomes do not share a stream: the phase generator is seeded
# with RNG_SEED, the sampling generator with RNG_SEED + 1.
rng_phases, rng_sampling = (
    np.random.default_rng(RNG_SEED + seed_offset) for seed_offset in (0, 1)
)

print(
    f"Number of qubits: {NUM_QUBITS}",
    f"Samples per basis: {NUM_SAMPLES_PER_BASIS}",
    f"RNG Seed for phases: {RNG_SEED}, for sampling: {RNG_SEED + 1}",
    sep="\n",
)

Data will be saved to: /Users/Tonni/Desktop/master-code/neural-quantum-tomo/case_studies/w_phase_augmented/w_aug_tomography_data
Number of qubits: 6
Samples per basis: 20
RNG Seed for phases: 42, for sampling: 43


In [2]:
def format_bytes(b):
    """
    Render a byte count as a human-readable string with two decimals.

    The value is divided by 1024 until its magnitude drops below 1024 or the
    units 'B', 'KB', 'MB', 'GB', 'TB' are exhausted; anything larger is
    reported in 'PB'.
    """
    units = ('B', 'KB', 'MB', 'GB', 'TB')
    value = b
    position = 0
    while position < len(units):
        if abs(value) < 1024.0:
            return f"{value:.2f} {units[position]}"
        value = value / 1024.0
        position += 1
    # Only reached for values of at least 1024 TB.
    return f"{value:.2f} PB"

In [3]:
def sample_state_fast(state_vector_to_measure, pauli_labels, *, rng_instance):
    """
    Single-shot measurement of |state_vector_to_measure> in the tensor-product
    basis defined by pauli_labels.

    Qubits are measured one after another in the order of pauli_labels. After
    each single-qubit measurement the state is projected onto the measured
    eigenvector and renormalised, so later outcomes are conditioned on the
    earlier ones.

    Qubit convention: the state is reshaped to (2,) * n and pauli_labels[k]
    acts on axis k, i.e. qubit 0 is the most significant bit of the flat index.

    Parameters
    ----------
    state_vector_to_measure : np.ndarray, shape=(2**n,), complex
        Normalised state vector. This function *modifies this array in place*
        (it holds the collapsed state afterwards), so pass a copy if the
        original is still needed.
    pauli_labels : Sequence[str]
        One character per qubit: 'X', 'Y' or 'Z'. (No 'I' expected for tomography bases)
    rng_instance : np.random.Generator
        RNG instance to use for sampling. Exactly one uniform number is drawn
        per qubit.

    Returns
    -------
    str
        Measurement bit-string ('0' = +1 eigenvalue, '1' = -1 eigenvalue),
        one character per entry of pauli_labels.

    Raises
    ------
    ValueError
        If any label is not 'X', 'Y' or 'Z'. Labels are checked before the
        state is touched, so the input array is left unmodified in that case.
    """
    n = len(pauli_labels)

    # Validate everything up front so a bad label cannot leave the caller's
    # array partially collapsed.
    for pauli_op in pauli_labels:
        if pauli_op not in ('X', 'Y', 'Z'):
            raise ValueError(f"Invalid Pauli label: {pauli_op}. Expected 'X', 'Y', or 'Z'.")

    # Reshaping a contiguous vector yields a view, so the in-place updates
    # below write through to state_vector_to_measure.
    psi = state_vector_to_measure.reshape((2,) * n)
    sqrt2 = np.sqrt(2.0)

    # X and Y eigenvectors have the form (|0> +/- c|1>)/sqrt(2) with
    # c = 1 for X and c = i for Y ('+' is eigenvalue +1, '-' is eigenvalue -1).
    eigvec_phase = {'X': 1.0, 'Y': 1j}

    outcome_bits = []

    for q_idx, pauli_op in enumerate(pauli_labels):
        current_axis = q_idx

        # Index tuples selecting the sub-tensors where this qubit is |0> / |1>.
        selector_q_is_0 = (slice(None),) * current_axis + (0,)
        selector_q_is_1 = (slice(None),) * current_axis + (1,)

        # Snapshot both halves: the collapse step overwrites psi and still
        # needs the pre-measurement amplitudes.
        psi_slice_0 = psi[selector_q_is_0].copy()
        psi_slice_1 = psi[selector_q_is_1].copy()

        # --- Probability of measuring eigenvalue +1 ---
        if pauli_op == 'Z':
            prob_plus_one = np.sum(np.abs(psi_slice_0) ** 2)
        else:
            phase = eigvec_phase[pauli_op]
            # Overlap with the +1 eigenvector. The bra carries the *conjugate*
            # coefficient: <+| = (<0| + conj(c)<1|)/sqrt(2). For Y this is
            # (psi0 - i*psi1)/sqrt(2).
            amp_plus = (psi_slice_0 + np.conj(phase) * psi_slice_1) / sqrt2
            prob_plus_one = np.sum(np.abs(amp_plus) ** 2)

        # Guard against floating point round-off pushing the value outside [0, 1].
        prob_plus_one = np.clip(prob_plus_one, 0.0, 1.0)

        # --- Sample the outcome ---
        measured_plus_one = rng_instance.random() < prob_plus_one
        outcome_bits.append('0' if measured_plus_one else '1')

        # --- Collapse onto the measured eigenvector ---
        if pauli_op == 'Z':
            if measured_plus_one:
                psi[selector_q_is_1] = 0.0  # keep only the |0> branch
            else:
                psi[selector_q_is_0] = 0.0  # keep only the |1> branch
        else:
            sign = 1.0 if measured_plus_one else -1.0
            # Amplitude of the rest of the system in the measured branch ...
            branch = (psi_slice_0 + sign * np.conj(phase) * psi_slice_1) / sqrt2
            # ... tensored with the eigenvector (|0> + sign*c|1>)/sqrt(2).
            psi[selector_q_is_0] = branch / sqrt2
            psi[selector_q_is_1] = sign * phase * branch / sqrt2

        # --- Renormalise by sqrt(probability of the observed outcome) ---
        if measured_plus_one:
            renorm_factor = np.sqrt(prob_plus_one)
        else:
            renorm_factor = np.sqrt(1.0 - prob_plus_one)

        if renorm_factor > 1e-9:
            psi /= renorm_factor
        else:
            # The sampled outcome had (near) zero probability. Fall back to
            # normalising by the actual remaining norm when there is any left;
            # otherwise leave the (all but zero) state as it is.
            current_norm_sq = np.sum(np.abs(psi) ** 2)
            if current_norm_sq > 1e-9:
                psi /= np.sqrt(current_norm_sq)

    return "".join(outcome_bits)

In [4]:
print(f"Constructing {NUM_QUBITS}-qubit phase-augmented W state...")

# Dense state vector over all 2**NUM_QUBITS computational basis states.
state_dim = 1 << NUM_QUBITS
w_aug_initial = np.zeros(state_dim, dtype=complex)

# One phase per qubit, drawn uniformly from [0, 2*pi) with the dedicated phase RNG.
thetas = rng_phases.uniform(low=0, high=2 * np.pi, size=NUM_QUBITS)

# Qubit 0 is the most significant bit, so the basis state in which only
# qubit k is |1> (|100...>, |010...>, ...) has flat index 2**(NUM_QUBITS - 1 - k).
# Each of these NUM_QUBITS states gets amplitude exp(i * theta_k) / sqrt(NUM_QUBITS).
for k, idx in enumerate(1 << shift for shift in reversed(range(NUM_QUBITS))):
    w_aug_initial[idx] = np.exp(1j * thetas[k]) / np.sqrt(NUM_QUBITS)

print(f"Size of initial W state vector in memory: {format_bytes(w_aug_initial.nbytes)}")
# Sanity check: the norm should be 1 up to rounding.
print(f"Initial W state norm: {np.linalg.norm(w_aug_initial):.6f}")

Constructing 6-qubit phase-augmented W state...
Size of initial W state vector in memory: 1.00 KB
Initial W state norm: 1.000000


In [8]:
# Measurement bases, in the order they are processed:
#   1. the all-Z (amplitude) basis,
#   2. NUM_QUBITS - 1 bases with X on the adjacent qubits i and i+1, Z elsewhere,
#   3. NUM_QUBITS - 1 bases with X on qubit i and Y on qubit i+1, Z elsewhere.
amplitude_basis_str = 'Z' * NUM_QUBITS
all_measurement_bases_strings = [amplitude_basis_str]

for pair_ops in ('XX', 'XY'):
    for i in range(NUM_QUBITS - 1):
        basis_list = list(amplitude_basis_str)
        # Overwrite positions i and i+1 with the two letters of pair_ops.
        basis_list[i:i + 2] = pair_ops
        all_measurement_bases_strings.append("".join(basis_list))

print(f"\nTotal number of measurement bases to process: {len(all_measurement_bases_strings)}")
# Sanity check against the expected count 1 + 2 * (NUM_QUBITS - 1).
if not len(all_measurement_bases_strings) == (1 + 2 * (NUM_QUBITS - 1)):
    print(f"Warning: Expected {1 + 2 * (NUM_QUBITS - 1)} bases, but got {len(all_measurement_bases_strings)}")


Total number of measurement bases to process: 11


In [7]:
# Draw NUM_SAMPLES_PER_BASIS single-shot measurements of the W state in every
# measurement basis and write them to one text file per basis, one sample per
# line.
#
# Output encoding, one character per qubit: the measured Pauli letter in upper
# case for outcome bit '0' (eigenvalue +1) and in lower case for outcome bit
# '1' (eigenvalue -1), e.g. "ZzXxZZ".

print("\n--- Starting Measurement Generation ---")
# perf_counter() is a monotonic, high-resolution clock intended for measuring
# elapsed intervals; time.time() can jump when the system clock is adjusted.
total_start_time = time.perf_counter()

# Refresh the progress line about 20 times per basis (and on the first sample).
progress_stride = max(1, NUM_SAMPLES_PER_BASIS // 20)

for basis_idx, pauli_dirs_str in enumerate(all_measurement_bases_strings):
    basis_start_time = time.perf_counter()
    print(f"\nProcessing Basis {basis_idx + 1}/{len(all_measurement_bases_strings)}: {pauli_dirs_str}")

    filename = DATA_DIR / f"w_aug_{pauli_dirs_str}_{NUM_SAMPLES_PER_BASIS}.txt"
    pauli_labels_for_sampler = list(pauli_dirs_str) # e.g., ['Z', 'X', 'X', 'Z', ...]

    with open(filename, 'w') as f_out:
        for sample_num in range(NUM_SAMPLES_PER_BASIS):
            if (sample_num + 1) % progress_stride == 0 or sample_num == 0:
                print(f"  Sample {sample_num + 1}/{NUM_SAMPLES_PER_BASIS}...", end='\r')

            # The sampler collapses its input in place, so every shot must
            # start from a fresh copy of the initial state.
            state_to_measure_copy = w_aug_initial.copy()
            measurement_bit_string = sample_state_fast(
                state_to_measure_copy,
                pauli_labels_for_sampler,
                rng_instance=rng_sampling
            )

            # Upper case = eigenvalue +1, lower case = eigenvalue -1.
            # No fallback character is needed: sample_state_fast raises
            # ValueError for any label other than 'X', 'Y' or 'Z'.
            formatted_measurement_eigenvalues = "".join(
                measured_pauli if outcome_bit == '0' else measured_pauli.lower()
                for measured_pauli, outcome_bit in zip(pauli_labels_for_sampler, measurement_bit_string)
            )
            f_out.write(formatted_measurement_eigenvalues + "\n")

        # Trailing spaces overwrite what is left of the '\r' progress line.
        print(f"  Finished {NUM_SAMPLES_PER_BASIS} samples. Saved to {filename.name}.{' '*20}")

    basis_end_time = time.perf_counter()
    print(f"  Time for this basis: {basis_end_time - basis_start_time:.2f} seconds.")

total_end_time = time.perf_counter()
print(f"\n--- All Measurements Generated (Corrected Encoding) ---")
print(f"Total time: {total_end_time - total_start_time:.2f} seconds.")
print(f"Output files are in the directory: {DATA_DIR.resolve()}")


--- Starting Measurement Generation ---

Processing Basis 1/11: ZZZZZZ
  Finished 20 samples. Saved to w_aug_ZZZZZZ_20.txt.                    
  Time for this basis: 0.00 seconds.

Processing Basis 2/11: XXZZZZ
  Finished 20 samples. Saved to w_aug_XXZZZZ_20.txt.                    
  Time for this basis: 0.00 seconds.

Processing Basis 3/11: ZXXZZZ
  Finished 20 samples. Saved to w_aug_ZXXZZZ_20.txt.                    
  Time for this basis: 0.00 seconds.

Processing Basis 4/11: ZZXXZZ
  Finished 20 samples. Saved to w_aug_ZZXXZZ_20.txt.                    
  Time for this basis: 0.00 seconds.

Processing Basis 5/11: ZZZXXZ
  Finished 20 samples. Saved to w_aug_ZZZXXZ_20.txt.                    
  Time for this basis: 0.00 seconds.

Processing Basis 6/11: ZZZZXX
  Finished 20 samples. Saved to w_aug_ZZZZXX_20.txt.                    
  Time for this basis: 0.00 seconds.

Processing Basis 7/11: XYZZZZ
  Finished 20 samples. Saved to w_aug_XYZZZZ_20.txt.                    
  Time fo