In [1]:
import site
from pathlib import Path
site.addsitedir(str(Path.cwd().parents[2]))


from time import time
import numpy as np

import jax.numpy as jnp
from jax import random

from presentation.measurement.measurement import MultiQubitMeasurement
from presentation.measurement.formatting import format_bytes

# Single seed for the notebook; the JAX PRNG keys created in later cells are built from it.
rng_seed: int = 42
print(f"Random seed is {rng_seed}")

Random seed is 42


In [2]:
#### PHASE AUGMENTED GHZ STATE
# Build an n-qubit GHZ-type state supported only on |0…0⟩ and |1…1⟩, where each
# of the two amplitudes is e^{iθ}/√2 with θ drawn uniformly from [0, 2π).

num_qubits = 4
# With 0 qubits both GHZ indices below would collapse onto index 0.
assert num_qubits >= 1, "GHZ state needs at least one qubit"

state_dim = 1 << num_qubits  # same as 2**num_qubits

rng_phase = random.PRNGKey(rng_seed)
ghz_thetas = random.uniform(rng_phase, shape=(2,), minval=0.0, maxval=2 * jnp.pi)
ghz_amplitudes = jnp.exp(1j * ghz_thetas) / jnp.sqrt(2)

# GHZ has support only at |0…0⟩ and |1…1⟩ → indices 0 and 2**num_qubits - 1
ghz_indices = jnp.array([0, state_dim - 1])

w_aug = jnp.zeros(state_dim, dtype=jnp.complex64).at[ghz_indices].set(ghz_amplitudes)

# Sanity check: two amplitudes of modulus 1/√2 each must give a unit-norm state.
assert jnp.isclose(jnp.linalg.norm(w_aug), 1.0), "state vector is not normalized"

# Bound the preview by the vector length: with fewer than 4 qubits there are fewer
# than 10 amplitudes, and an out-of-range JAX read either raises or (JAX's default
# for out-of-bounds gathers) silently returns the last element under a bogus label.
num_shown = min(10, state_dim)
print(f"First {num_shown} amplitudes:")
for idx in range(num_shown):
    print(f"{idx:0{num_qubits}b}: {w_aug[idx]:.8f}")

print(f"\nSize of state vector in memory: {format_bytes(w_aug.nbytes)} \n")

First 10 amplitudes:
0000: 0.70396447+0.06658886j
0001: 0.00000000+0.00000000j
0010: 0.00000000+0.00000000j
0011: 0.00000000+0.00000000j
0100: 0.00000000+0.00000000j
0101: 0.00000000+0.00000000j
0110: 0.00000000+0.00000000j
0111: 0.00000000+0.00000000j
1000: 0.00000000+0.00000000j
1001: 0.00000000+0.00000000j

Size of state vector in memory: 128.00 B 



In [3]:
def save_state_vector_columns(w_aug: jnp.ndarray, file_path: str):
    """Write a state vector to a text file as two whitespace-separated columns.

    Row k holds the real part and the imaginary part of ``w_aug[k]``, each
    formatted with 10 decimal places.

    Args:
        w_aug: State-vector amplitudes (complex; a real input yields a zero
            imaginary column).
        file_path: Destination path, passed straight to ``np.savetxt``.
    """
    real_col = jnp.real(w_aug)
    imag_col = jnp.imag(w_aug)
    # Stack along axis 1 so each amplitude becomes one (Re, Im) row.
    table = np.array(jnp.stack((real_col, imag_col), axis=1))
    np.savetxt(file_path, table, fmt="%.10f")


file_name = "ghz_state.txt"
# Persist the state as (Re, Im) columns, one row per computational basis state.
save_state_vector_columns(w_aug=w_aug, file_path=file_name)

In [4]:
# Uniform bases: the same Pauli axis on every qubit (ZZ…Z, XX…X, YY…Y).
measurement_bases = [[axis] * num_qubits for axis in ("Z", "X", "Y")]

# Optional mixed bases: all X except a single Y at qubit position `pos`.
# NOTE(review): presumably added to pin down the relative phase between
# |0…0⟩ and |1…1⟩ — confirm intent.
measurement_bases += [
    ['Y' if qubit == pos else 'X' for qubit in range(num_qubits)]
    for pos in range(num_qubits)
]

for i, basis in enumerate(measurement_bases):
    print(f"Basis {i:2d}: {''.join(basis)}")

Basis  0: ZZZZ
Basis  1: XXXX
Basis  2: YYYY
Basis  3: YXXX
Basis  4: XYXX
Basis  5: XXYX
Basis  6: XXXY


In [5]:
file_name = "ghz_unique_bases.txt"
# One line per distinct basis, Pauli labels space-separated; the trailing
# " \n" matches the line format of the measurement files written later.
with open(file_name, "w") as bases_out:
    bases_out.writelines(" ".join(basis) + " \n" for basis in measurement_bases)

In [6]:
samples_per_basis = 5000
# rng_phase is also random.PRNGKey(rng_seed); fold in a constant so the measurement
# sampling stream is not bit-identical to the key that drew the GHZ phases.
rng_samples = random.fold_in(random.PRNGKey(rng_seed), 1)

value_file_name = "ghz_meas_values.txt"
basis_file_name = "ghz_meas_bases.txt"

# Writes two line-aligned files: line k of the values file is a measured bitstring,
# line k of the bases file is the basis it was measured in.
with open(value_file_name, "w") as f_meas, open(basis_file_name, "w") as f_basis:
    for measurement_dirs in measurement_bases:
        basis_str = " ".join(measurement_dirs) + " \n"
        print(f"Sampling {samples_per_basis} samples for basis {''.join(measurement_dirs)}...")

        measurement = MultiQubitMeasurement(measurement_dirs)

        # Fresh subkey per basis: passing the same key to every call would give each
        # basis the same underlying random draws, correlating the datasets.
        rng_samples, basis_key = random.split(rng_samples)

        start = time()
        # np.asarray forces any pending (asynchronously dispatched) JAX work to finish
        # inside the timed region and makes per-row iteration cheap.
        # NOTE(review): assumes sample_state returns an array-like with one row per
        # shot — confirm against MultiQubitMeasurement.
        samples = np.asarray(measurement.sample_state(w_aug, samples_per_basis, rng=basis_key))
        print(f"  Done in {time() - start:.2f} seconds.")

        f_meas.writelines(" ".join(map(str, bitstring)) + " \n" for bitstring in samples)
        f_basis.write(basis_str * len(samples))

Sampling 5000 samples for basis ZZZZ...


Constructing basis ZZZZ: 100%|██████████| 16/16 [00:00<00:00, 305.36it/s]


  Done in 0.28 seconds.
Sampling 5000 samples for basis XXXX...


Constructing basis XXXX: 100%|██████████| 16/16 [00:00<00:00, 3557.89it/s]


  Done in 0.00 seconds.
Sampling 5000 samples for basis YYYY...


Constructing basis YYYY: 100%|██████████| 16/16 [00:00<00:00, 3377.74it/s]

  Done in 0.00 seconds.





Sampling 5000 samples for basis YXXX...


Constructing basis YXXX: 100%|██████████| 16/16 [00:00<00:00, 3381.14it/s]


  Done in 0.00 seconds.
Sampling 5000 samples for basis XYXX...


Constructing basis XYXX: 100%|██████████| 16/16 [00:00<00:00, 3204.51it/s]

  Done in 0.00 seconds.





Sampling 5000 samples for basis XXYX...


Constructing basis XXYX: 100%|██████████| 16/16 [00:00<00:00, 3583.54it/s]

  Done in 0.00 seconds.





Sampling 5000 samples for basis XXXY...


Constructing basis XXXY: 100%|██████████| 16/16 [00:00<00:00, 2930.52it/s]

  Done in 0.00 seconds.



