In [2]:
import numpy as np
import os
from pathlib import Path
import json # For potentially saving/loading parameters if we had them

# --- Configuration (should match the data generation script) ---
NUM_QUBITS = 6 # Or whatever NUM_QUBITS you used for data generation
RNG_SEED = 42   # Same seed to regenerate the original W-state

# --- Define Data Directory (where measurements are stored) ---
DATA_DIR = Path("./w_aug_tomography_data") # Must match previous notebook
if not DATA_DIR.exists():
    print(f"ERROR: Data directory {DATA_DIR.resolve()} not found!")
    print("Please run the data generation notebook first.")
    # You might want to raise an error or stop execution here in a real script
else:
    print(f"Reading data from: {DATA_DIR.resolve()}")

# --- Initialize RNG for regenerating the original W-state ---
rng_phases_original = np.random.default_rng(RNG_SEED)

print(f"Number of qubits for reconstruction: {NUM_QUBITS}")

Reading data from: /Users/Tonni/Desktop/master-code/neural-quantum-tomo/case_studies/w_phase_augmented/w_aug_tomography_data
Number of qubits for reconstruction: 6


In [1]:
def format_bytes(b):
    """Utility function to format bytes into KB, MB, GB, etc."""
    for u in ['B', 'KB', 'MB', 'GB', 'TB']:
        if abs(b) < 1024.0:
            return f"{b:.2f} {u}"
        b /= 1024.0
    return f"{b:.2f} PB"

In [None]:
print(f"Regenerating the original {NUM_QUBITS}-qubit phase-augmented W state for comparison...")
state_dim = 1 << NUM_QUBITS
w_aug_original = np.zeros(state_dim, dtype=complex)

# Generate random phases using the same RNG seed as in data generation
thetas_original = rng_phases_original.uniform(0, 2 * np.pi, size=NUM_QUBITS)

for k in range(NUM_QUBITS):
    idx = 1 << (NUM_QUBITS - 1 - k)
    w_aug_original[idx] = np.exp(1j * thetas_original[k]) / np.sqrt(NUM_QUBITS)

print(f"Size of original W state vector: {format_bytes(w_aug_original.nbytes)}")
print(f"Original W state norm: {np.linalg.norm(w_aug_original):.6f}")

In [3]:
all_measurement_bases_strings = []
amplitude_basis_str = 'Z' * NUM_QUBITS
all_measurement_bases_strings.append(amplitude_basis_str)

for i in range(NUM_QUBITS - 1):
    basis_list = ['Z'] * NUM_QUBITS; basis_list[i] = 'X'; basis_list[i+1] = 'X'
    all_measurement_bases_strings.append("".join(basis_list))
for i in range(NUM_QUBITS - 1):
    basis_list = ['Z'] * NUM_QUBITS; basis_list[i] = 'X'; basis_list[i+1] = 'Y'
    all_measurement_bases_strings.append("".join(basis_list))

print(f"Expecting data for {len(all_measurement_bases_strings)} bases.")

Expecting data for 11 bases.


In [4]:
# We need to know NUM_SAMPLES_PER_BASIS from the filename if not hardcoded
# For simplicity, let's try to infer it from one of the filenames or assume it's known.
# This is a bit fragile; ideally, metadata would be saved with the data.

# Try to find a file to infer NUM_SAMPLES_PER_BASIS
num_samples_inferred = None
first_basis_str = all_measurement_bases_strings[0]
expected_pattern = f"w_aug_{first_basis_str}_*.txt"
try:
    # Find one file matching the pattern for the first basis
    matching_files = list(DATA_DIR.glob(expected_pattern))
    if matching_files:
        fname_parts = matching_files[0].name.split('_') # e.g., ['w', 'aug', 'ZZZZZZ', '20.txt']
        num_samples_inferred = int(fname_parts[-1].replace('.txt', ''))
        print(f"Inferred NUM_SAMPLES_PER_BASIS = {num_samples_inferred} from filenames.")
    else:
        print(f"Warning: Could not find files matching {expected_pattern} to infer NUM_SAMPLES_PER_BASIS.")
        # Fallback or error
        num_samples_inferred = 20 # Example: Set a default if not found
        print(f"Using default NUM_SAMPLES_PER_BASIS = {num_samples_inferred}")

except Exception as e:
    print(f"Error inferring NUM_SAMPLES_PER_BASIS: {e}")
    num_samples_inferred = 20 # Fallback
    print(f"Using default NUM_SAMPLES_PER_BASIS = {num_samples_inferred}")


all_datasets = {} # Dictionary to store data: {basis_string: [list_of_measurements]}

print("\n--- Ingesting Measurement Data ---")
for basis_str in all_measurement_bases_strings:
    # Filename format: w_aug_<basis_string>_<num_samples>.txt
    filename = DATA_DIR / f"w_aug_{basis_str}_{num_samples_inferred}.txt"
    dataset_for_basis = []
    if filename.exists():
        print(f"Reading data for basis: {basis_str} from {filename.name}")
        with open(filename, 'r') as f_in:
            for line in f_in:
                # Measurements are stored as 'XxY...'
                dataset_for_basis.append(line.strip())
        all_datasets[basis_str] = dataset_for_basis
        print(f"  Read {len(dataset_for_basis)} samples.")
    else:
        print(f"Warning: Measurement file not found for basis {basis_str}: {filename}")

if not all_datasets:
    print("ERROR: No measurement data was loaded. Cannot proceed.")
else:
    print("\n--- Data Ingestion Complete ---")
    print(f"Loaded data for {len(all_datasets)} bases.")
    # Example: print data for one basis
    # example_basis = list(all_datasets.keys())[0]
    # print(f"First 5 samples for basis {example_basis}: {all_datasets[example_basis][:5]}")

Inferred NUM_SAMPLES_PER_BASIS = 20 from filenames.

--- Ingesting Measurement Data ---
Reading data for basis: ZZZZZZ from w_aug_ZZZZZZ_20.txt
  Read 20 samples.
Reading data for basis: XXZZZZ from w_aug_XXZZZZ_20.txt
  Read 20 samples.
Reading data for basis: ZXXZZZ from w_aug_ZXXZZZ_20.txt
  Read 20 samples.
Reading data for basis: ZZXXZZ from w_aug_ZZXXZZ_20.txt
  Read 20 samples.
Reading data for basis: ZZZXXZ from w_aug_ZZZXXZ_20.txt
  Read 20 samples.
Reading data for basis: ZZZZXX from w_aug_ZZZZXX_20.txt
  Read 20 samples.
Reading data for basis: XYZZZZ from w_aug_XYZZZZ_20.txt
  Read 20 samples.
Reading data for basis: ZXYZZZ from w_aug_ZXYZZZ_20.txt
  Read 20 samples.
Reading data for basis: ZZXYZZ from w_aug_ZZXYZZ_20.txt
  Read 20 samples.
Reading data for basis: ZZZXYZ from w_aug_ZZZXYZ_20.txt
  Read 20 samples.
Reading data for basis: ZZZZXY from w_aug_ZZZZXY_20.txt
  Read 20 samples.

--- Data Ingestion Complete ---
Loaded data for 11 bases.


In [5]:
print("\n--- Estimating Properties from Measurement Statistics ---")

# 1. Amplitude Distribution from 'ZZ...Z' basis
z_basis_str = 'Z' * NUM_QUBITS
if z_basis_str in all_datasets:
    z_measurements = all_datasets[z_basis_str]

    # Convert 'Z'/'z' to binary strings '1'/'0' (or any consistent numeric representation)
    # 'Z' (+1 eigenvalue) -> maps to computational |0> if spin-up is |0>
    # 'z' (-1 eigenvalue) -> maps to computational |1> if spin-down is |1>
    # Let's assume 'Z' means outcome '0' (for +1 of sigma_z) and 'z' means '1' (for -1 of sigma_z)
    # This matches the output_bits of sample_state_fast

    binary_outcomes_z_basis = []
    for m_str in z_measurements:
        bits = []
        for char_idx, char_val in enumerate(m_str):
            if char_val == 'Z': bits.append('0')
            elif char_val == 'z': bits.append('1')
            else: bits.append('?') # Should not happen for ZZZ...Z basis
        binary_outcomes_z_basis.append("".join(bits))

    # Count frequencies
    from collections import Counter
    z_outcome_counts = Counter(binary_outcomes_z_basis)

    print(f"\n--- Statistics for ZZZ...Z basis ({len(z_measurements)} samples) ---")
    print(f"Number of unique outcomes observed: {len(z_outcome_counts)}")

    # Expected W-state outcomes (binary representations)
    expected_w_binary_outcomes = []
    for k in range(NUM_QUBITS):
        bits = ['0'] * NUM_QUBITS
        bits[k] = '1' # If W-state is sum of |..1..> where 1 is at k-th position (MSB=0)
        # And if '1' in binary_outcomes maps to where W-state amp is.
        # Original W-state: idx = 1 << (NUM_QUBITS - 1 - k) means 1 at (NUM_QUBITS-1-k) from LSB.
        # If binary_outcomes_z_basis[k] is k-th char, and k=0 is MSB.
        # A '1' at position `k` (MSB) means index 2^(N-1-k).
        expected_w_binary_outcomes.append("".join(bits))

    # This mapping needs to be careful:
    # Our w_aug_initial[idx] = value, where idx = 1 << (NUM_QUBITS - 1 - k_msb_pos_of_1)
    # A binary string "b0b1b2..." (b0 is MSB) corresponds to int(val, 2)
    # If W-state is |100> + |010> + |001>
    # k_msb_pos_of_1 = 0 -> "100..." -> index 2^(N-1)
    # k_msb_pos_of_1 = 1 -> "010..." -> index 2^(N-2)
    # So, the expected binary strings are those with a single '1'.

    print("Top 10 most frequent outcomes in Z-basis:")
    total_z_samples = len(z_measurements)
    for outcome, count in z_outcome_counts.most_common(10):
        is_w_component = (outcome.count('1') == 1) # Simple check for W-state like components
        prob = count / total_z_samples
        print(f"  Outcome: {outcome} (W-like: {is_w_component}) - Count: {count} (Prob: {prob:.4f})")

    # Check sum of probabilities for expected W-components
    prob_sum_w_components = 0
    for k in range(NUM_QUBITS):
        # Construct the binary string for |...1 (at k_th MSB)...>
        temp_bits = ['0'] * NUM_QUBITS
        temp_bits[k] = '1'
        w_outcome_k_str = "".join(temp_bits)
        prob_sum_w_components += z_outcome_counts.get(w_outcome_k_str, 0) / total_z_samples
    print(f"Sum of probabilities for single '1' (W-like) components in Z-basis: {prob_sum_w_components:.4f} (expected close to 1.0)")

else:
    print("Warning: Data for ZZZ...Z basis not found in datasets.")


# 2. Two-qubit correlators
# Example: <sigma_0^X sigma_1^X> from basis "XXZ...Z"
xx_basis_str_list = ['Z'] * NUM_QUBITS
if NUM_QUBITS >= 2:
    xx_basis_str_list[0] = 'X'
    xx_basis_str_list[1] = 'X'
    xx_basis_str = "".join(xx_basis_str_list)

    if xx_basis_str in all_datasets:
        xx_measurements = all_datasets[xx_basis_str]
        count_XX = 0; count_Xx = 0; count_xX = 0; count_xx = 0
        for m_str in xx_measurements:
            q0_outcome = m_str[0] # Assuming first char is for qubit 0
            q1_outcome = m_str[1] # Assuming second char is for qubit 1

            if q0_outcome == 'X' and q1_outcome == 'X': count_XX += 1
            elif q0_outcome == 'X' and q1_outcome == 'x': count_Xx += 1
            elif q0_outcome == 'x' and q1_outcome == 'X': count_xX += 1
            elif q0_outcome == 'x' and q1_outcome == 'x': count_xx += 1

        if len(xx_measurements) > 0:
            corr_X0X1 = (count_XX - count_Xx - count_xX + count_xx) / len(xx_measurements)
            print(f"\n--- Statistics for {xx_basis_str} basis ({len(xx_measurements)} samples) ---")
            print(f"  Counts for (q0, q1): XX={count_XX}, Xx={count_Xx}, xX={count_xX}, xx={count_xx}")
            print(f"  Estimated <sigma_0^X sigma_1^X>: {corr_X0X1:.4f}")
        else:
            print(f"Warning: No samples found for {xx_basis_str} basis, though key exists.")

    else:
        print(f"Warning: Data for {xx_basis_str} basis not found.")
else:
    print("Skipping two-qubit correlator example (NUM_QUBITS < 2).")

# Example: <sigma_0^X sigma_1^Y> from basis "XYZ...Z"
xy_basis_str_list = ['Z'] * NUM_QUBITS
if NUM_QUBITS >= 2:
    xy_basis_str_list[0] = 'X'
    xy_basis_str_list[1] = 'Y'
    xy_basis_str = "".join(xy_basis_str_list)

    if xy_basis_str in all_datasets:
        xy_measurements = all_datasets[xy_basis_str]
        # For <X0 Y1>:
        # XY (+1,+1) -> +1 contribution to correlator
        # Xy (+1,-1) -> -1
        # xY (-1,+1) -> -1
        # xy (-1,-1) -> +1
        count_XY = 0; count_Xy = 0; count_xY = 0; count_xy = 0
        for m_str in xy_measurements:
            q0_outcome = m_str[0]
            q1_outcome = m_str[1]

            if q0_outcome == 'X' and q1_outcome == 'Y': count_XY += 1
            elif q0_outcome == 'X' and q1_outcome == 'y': count_Xy += 1
            elif q0_outcome == 'x' and q1_outcome == 'Y': count_xY += 1
            elif q0_outcome == 'x' and q1_outcome == 'y': count_xy += 1

        if len(xy_measurements) > 0:
            corr_X0Y1 = (count_XY - count_Xy - count_xY + count_xy) / len(xy_measurements)
            print(f"\n--- Statistics for {xy_basis_str} basis ({len(xy_measurements)} samples) ---")
            print(f"  Counts for (q0, q1): XY={count_XY}, Xy={count_Xy}, xY={count_xY}, xy={count_xy}")
            print(f"  Estimated <sigma_0^X sigma_1^Y>: {corr_X0Y1:.4f}")
        else:
            print(f"Warning: No samples found for {xy_basis_str} basis, though key exists.")
    else:
        print(f"Warning: Data for {xy_basis_str} basis not found.")


--- Estimating Properties from Measurement Statistics ---

--- Statistics for ZZZ...Z basis (20 samples) ---
Number of unique outcomes observed: 1
Top 10 most frequent outcomes in Z-basis:
  Outcome: ?????? (W-like: False) - Count: 20 (Prob: 1.0000)
Sum of probabilities for single '1' (W-like) components in Z-basis: 0.0000 (expected close to 1.0)

--- Statistics for XXZZZZ basis (20 samples) ---
  Counts for (q0, q1): XX=7, Xx=2, xX=6, xx=5
  Estimated <sigma_0^X sigma_1^X>: 0.2000

--- Statistics for XYZZZZ basis (20 samples) ---
  Counts for (q0, q1): XY=0, Xy=0, xY=0, xy=0
  Estimated <sigma_0^X sigma_1^Y>: 0.0000
