In [4]:
# Measurement files to load, one per field value h. Every name follows the
# pattern "tfim_h{h:.2f}_3x3_10000.txt", so the list is built from the h values
# directly (formatting with two decimals reproduces e.g. "2.80", "3.30").
file_names = [
    f"tfim_h{h:.2f}_3x3_10000.txt"
    for h in (1.00, 2.00, 2.80, 3.00, 3.30, 3.60, 4.00, 5.00, 6.00, 7.00)
]

In [5]:
# full Jupyter cell: load, stack, create loader, sample first batch

from pathlib import Path

import jax
import jax.numpy as jnp
from jax.random import PRNGKey

from lib.data_loading import load_measurements, MixedDataLoader  # adjust import path

# 1) specify your data directory (file_names is defined in the previous cell)
data_dir = Path("./data")

# 2) load & concatenate all samples from every file in file_names
if not file_names:
    # Fail early with a clear message; jnp.concatenate on an empty list would
    # otherwise raise a less descriptive error further down.
    raise ValueError("file_names is empty; nothing to load")

bits_list = []
field_list = []
for fn in file_names:
    # NOTE(review): assumes load_measurements returns (bits, field) with bits of
    # shape (N_i, num_qubits) and field of shape (N_i,) — confirm against
    # lib.data_loading.
    bits, field = load_measurements(data_dir / fn)
    # Every bit-string row needs exactly one field label; a mismatch would
    # silently misalign samples and labels once everything is stacked.
    if len(bits) != len(field):
        raise ValueError(
            f"{fn}: got {len(bits)} bit-string rows but {len(field)} field values"
        )
    bits_list.append(bits)
    field_list.append(field)

# Stack the rows of all len(file_names) files in file order, one contiguous
# block of rows per file.
all_bits  = jnp.concatenate(bits_list,  axis=0)  # (sum_i N_i, num_qubits)
all_field = jnp.concatenate(field_list, axis=0)  # (sum_i N_i,)

# 3) create the mixed data loader; rng_key is handed to the loader as its rng
#    (presumably used to seed the shuffle — confirm in MixedDataLoader)
batch_size = 128
rng_key    = PRNGKey(42)

loader     = MixedDataLoader(
    bits=all_bits,
    field=all_field,
    batch_size=batch_size,
    shuffle=True,
    drop_last=False,
    rng=rng_key,
)

# 4) grab the very first batch via next()
it = iter(loader)
batch_bits, batch_field = next(it)

# 5) inspect shapes — expected (batch_size, num_qubits) and (batch_size,) as
#    long as the stacked dataset holds at least batch_size samples
print("batch_bits.shape: ", batch_bits.shape)
print("batch_field.shape:", batch_field.shape)


Parsing measurements: 100%|██████████| 10000/10000 [00:00<00:00, 17090.08it/s]
Parsing measurements: 100%|██████████| 10000/10000 [00:00<00:00, 19282.69it/s]
Parsing measurements: 100%|██████████| 10000/10000 [00:00<00:00, 19348.87it/s]
Parsing measurements: 100%|██████████| 10000/10000 [00:00<00:00, 19404.55it/s]
Parsing measurements: 100%|██████████| 10000/10000 [00:00<00:00, 19595.13it/s]
Parsing measurements: 100%|██████████| 10000/10000 [00:00<00:00, 18759.13it/s]
Parsing measurements: 100%|██████████| 10000/10000 [00:00<00:00, 18597.63it/s]
Parsing measurements: 100%|██████████| 10000/10000 [00:00<00:00, 19668.70it/s]
Parsing measurements: 100%|██████████| 10000/10000 [00:00<00:00, 19663.98it/s]
Parsing measurements: 100%|██████████| 10000/10000 [00:00<00:00, 19693.08it/s]


batch_bits.shape:  (128, 9)
batch_field.shape: (128,)


In [6]:
# Show the field labels of the first ten samples in the first batch, converted
# to plain Python floats via .tolist().
print(
    "First 10 field values:",
    batch_field[slice(None, 10)].tolist(),
)

First 10 field values: [3.5999999046325684, 2.0, 6.0, 2.0, 3.299999952316284, 1.0, 5.0, 1.0, 3.5999999046325684, 5.0]
