In [2]:
import os
import numpy as np
import re

import jax
import jax.lax
from jax.random import PRNGKey
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training import train_state

import functools

from pathlib import Path


import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import matplotlib.patches as mpatches

data_dir = "./data"

print(f"Data resides in        : {data_dir}")

Data resides in        : ./data


In [6]:
class MultiBasisDataLoader:
    def __init__(self, data_dict: dict[str, jnp.ndarray],
                 batch_size: int = 128,
                 shuffle: bool = True,
                 drop_last: bool = False,
                 seed: int = 0):
        lengths = [len(v) for v in data_dict.values()]
        if len(set(lengths)) != 1:
            raise ValueError(f"All arrays must have the same length, got: {lengths}")

        self.data = data_dict
        self.n = lengths[0]
        self.bs = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.rng = np.random.default_rng(seed)

        self.idx_slices = [
            (i, i + batch_size)
            for i in range(0, self.n, batch_size)
            if not drop_last or i + batch_size <= self.n
        ]

    def __iter__(self):
        self.order = np.arange(self.n)
        if self.shuffle:
            self.rng.shuffle(self.order)
        self.slice_idx = 0
        return self

    def __next__(self):
        if self.slice_idx >= len(self.idx_slices):
            raise StopIteration
        s, e = self.idx_slices[self.slice_idx]
        self.slice_idx += 1
        return {k: v[self.order[s:e]] for k, v in self.data.items()}


def load_measurements(folder: str, file_pattern: str = "w_*.txt") -> dict[str, jnp.ndarray]:
    out: dict[str, jnp.ndarray] = {}

    for fp in Path(folder).glob(file_pattern):
        basis = fp.stem.split("_")[2]

        bitstrings = []
        with fp.open() as f:
            for line in f:
                bitstring = np.fromiter((c.islower() for c in line.strip()), dtype=np.float32)
                bitstrings.append(bitstring)

        arr = jnp.asarray(np.stack(bitstrings))
        if basis in out:
            out[basis] = jnp.concatenate([out[basis], arr], axis=0)
        else:
            out[basis] = arr

    return out

In [7]:
data_dict = load_measurements("data/", "w_*.txt")

# 2. Create two basis key groups (regex matched)
amp_keys = [k for k in data_dict if re.fullmatch(r"^Z+$", k)]
pha_keys = [k for k in data_dict if re.fullmatch(r"^(?!Z+$).*", k)]

amp_dict = {k: data_dict[k] for k in amp_keys}
pha_dict = {k: data_dict[k] for k in pha_keys}

amp_loader           = MultiBasisDataLoader(amp_dict, batch_size=128)
pha_loader           = MultiBasisDataLoader(pha_dict, batch_size=128)

In [22]:
# iterate in parallel and print the keys
for amp_batch, pha_batch in zip(amp_loader, pha_loader):
    print(f"Amplitude batch keys: {list(amp_batch.keys())}")
    amp_key, amp_val = next(iter(amp_batch.items()))
    print(f"Amp[{amp_key}] shape: {amp_val.shape}, dtype: {amp_val.dtype}")

    print(f"Phase batch keys: {list(pha_batch.keys())}")
    
    
    #phase_Sample_count = pha_batch['ZZ'].shape[0]
    break  # just to show the first batch

Amplitude batch keys: ['ZZZZZZZZZZZZZZZ']
Amp[ZZZZZZZZZZZZZZZ] shape: (128, 15), dtype: float32
Phase batch keys: ['ZZZZZZXYZZZZZZZ', 'ZZZZZZZZZXXZZZZ', 'ZZZZZZZZZZZZZXX', 'ZZZZZZZZZZZZZXY', 'ZZZZZZZZZZXXZZZ', 'ZZZZZZZZZZZXYZZ', 'ZZZXYZZZZZZZZZZ', 'ZZZZZZZZZXYZZZZ', 'XXZZZZZZZZZZZZZ', 'ZZZZZZXXZZZZZZZ', 'ZZZZZZZZXYZZZZZ', 'ZZXYZZZZZZZZZZZ', 'ZZZZXYZZZZZZZZZ', 'ZZXXZZZZZZZZZZZ', 'ZZZZZZZXXZZZZZZ', 'ZZZZZZZZXXZZZZZ', 'ZZZZZXYZZZZZZZZ', 'ZZZZXXZZZZZZZZZ', 'ZZZZZZZZZZXYZZZ', 'ZXXZZZZZZZZZZZZ', 'ZZZZZZZZZZZZXYZ', 'ZZZZZZZXYZZZZZZ', 'ZZZZZZZZZZZXXZZ', 'XYZZZZZZZZZZZZZ', 'ZXYZZZZZZZZZZZZ', 'ZZZZZXXZZZZZZZZ', 'ZZZZZZZZZZZZXXZ', 'ZZZXXZZZZZZZZZZ']
