In [1]:
import os, wave, struct
import numpy as np
import pandas as pd

# class BitWriter:
#     def __init__(self, fileobj):
#         self.f = fileobj
#         self.buf = 0
#         self.nbits = 0
#     def _putbit(self, b):
#         self.buf = (self.buf << 1) | (1 if b else 0)
#         self.nbits += 1
#         if self.nbits == 8:
#             self.f.write(bytes([self.buf]))
#             self.buf = 0
#             self.nbits = 0
#     def write_unary(self, q):
#         for _ in range(q): self._putbit(1)
#         self._putbit(0)
#     def write_kbits(self, val, k):
#         for i in range(k - 1, -1, -1):
#             self._putbit((val >> i) & 1)
#     def close(self):
#         if self.nbits:
#             self.f.write(bytes([self.buf << (8 - self.nbits)]))
#             self.buf = 0
#             self.nbits = 0

# ---- Bit I/O helpers for streaming mixed blocks ----
class BitWriter:
    """MSB-first bit packer writing whole bytes to a binary file object.

    Bits accumulate in ``buf``; every 8th bit emits one byte. ``flush_to_byte``
    zero-pads a partial byte so the caller can interleave raw byte writes on
    the underlying file.
    """

    def __init__(self, f):
        self.f = f
        self.buf = 0    # pending bits, right-aligned
        self.nbits = 0  # how many bits are pending in buf

    def _putbit(self, b):
        self.buf = (self.buf << 1) | int(bool(b))
        self.nbits += 1
        if self.nbits < 8:
            return
        self.f.write(bytes([self.buf]))
        self.buf = self.nbits = 0

    def write_unary(self, q):
        """Emit ``q`` one-bits followed by a terminating zero-bit."""
        for _ in range(q):
            self._putbit(1)
        self._putbit(0)

    def write_kbits(self, val, k):
        """Emit the low ``k`` bits of ``val``, most significant first."""
        for shift in reversed(range(k)):
            self._putbit((val >> shift) & 1)

    def flush_to_byte(self):
        """Zero-pad and write any partial byte; no-op when already aligned."""
        if not self.nbits:
            return
        self.f.write(bytes([self.buf << (8 - self.nbits)]))
        self.buf = self.nbits = 0

    def close(self):
        """Flush the trailing partial byte (the file itself is not closed)."""
        self.flush_to_byte()

class BitReader:
    """MSB-first bit reader over a binary file object.

    Bits are pulled one byte at a time. ``align_to_byte`` discards the unread
    remainder of the current byte so ``read_bytes`` can read raw bytes
    straight from the file.
    """

    def __init__(self, f):
        self.f = f
        self.cur = 0   # byte currently being consumed
        self.left = 0  # unread bits remaining in cur

    def _fill(self):
        chunk = self.f.read(1)
        if not chunk:
            raise EOFError("BitReader exhausted")
        self.cur, self.left = chunk[0], 8

    def getbit(self):
        """Return the next bit (0 or 1); raises EOFError at end of file."""
        if not self.left:
            self._fill()
        self.left -= 1
        return (self.cur >> self.left) & 1

    def getbits(self, k):
        """Return the next ``k`` bits as an unsigned integer (MSB first)."""
        out = 0
        for _ in range(k):
            out = (out << 1) | self.getbit()
        return out

    def align_to_byte(self):
        """Drop whatever is left of a partially consumed byte."""
        if self.left not in (0, 8):
            self.left = 0

    def read_bytes(self, n):
        """Byte-align, then read up to ``n`` raw bytes from the file."""
        self.align_to_byte()
        return self.f.read(n)

# ---- Rice helpers (streaming versions) ----
def rice_write_num(bw: BitWriter, s: int, K: int):
    """Write one int16 residual as a Rice(K) codeword through ``bw``.

    The value is ZigZag-folded to unsigned, then split into a unary quotient
    (``u >> K``) and a K-bit binary remainder (``u mod 2**K``).
    """
    folded = _zz16_enc(s)
    bw.write_unary(folded >> K)
    bw.write_kbits(folded & ((1 << K) - 1), K)

def rice_read_num(br: BitReader, K: int):
    """Read one Rice(K) codeword from ``br`` and return the decoded int16 residual."""
    quotient = 0
    while br.getbit():  # count leading ones; the terminating 0 is consumed here
        quotient += 1
    remainder = br.getbits(K)
    return _zz16_dec((quotient << K) | remainder)

def _estimate_rice_bits_block(block_int16: np.ndarray, K: int) -> int:
    x = block_int16.astype(np.int32)
    u = ((x << 1) ^ (x >> 15)) & 0xFFFF
    return int(np.sum((u >> K) + 1 + K))  # unary(q)=q+1, plus K remainder bits

# ---- NEW: mixed per-block writer (Rice or verbatim residuals) ----
# File header stays <4sBI> but we set K=250 to signal "mixed blocks" and
# immediately store one extra byte: K_global.
def _write_ex2_mixed_blocks_streaming(path, K_global, nvals, resid, blk_vals, log_every_blocks=10):
    """Write residuals as a K==250 "mixed blocks" .ex2 file.

    Layout: the usual ``_HDR`` header with K=250, one byte holding
    ``K_global``, then per block of up to ``blk_vals`` residuals a byte-aligned
    ``<BI`` header (mode, count) followed by either Rice(K_global) codes
    (mode 1) or raw int16 residuals (mode 0). Rice is chosen only when its
    estimated size is strictly below 16 bits per value.
    """
    with open(path, "wb") as f:
        f.write(struct.pack(_HDR, b"EX2\0", 250, nvals))  # K=250 marks mixed blocks
        f.write(bytes([K_global]))                          # extra byte after header

        bw = BitWriter(f)
        total_blocks = (nvals + blk_vals - 1) // blk_vals

        for done_blocks, start in enumerate(range(0, nvals, blk_vals), start=1):
            blk = resid[start:start + blk_vals]
            use_rice = _estimate_rice_bits_block(blk, K_global) < blk.size * 16

            # block header must start on a byte boundary
            bw.flush_to_byte()
            f.write(struct.pack("<BI", 1 if use_rice else 0, blk.size))

            if use_rice:
                for s in blk:
                    rice_write_num(bw, int(s), K_global)
            else:
                bw.flush_to_byte()
                f.write(np.asarray(blk, np.int16).tobytes())

            if (done_blocks % log_every_blocks == 0) or (done_blocks == total_blocks):
                print(f"Encoding progress: {done_blocks}/{total_blocks} "
                      f"({done_blocks/total_blocks*100:.1f}%)", flush=True)
        bw.close()

def encode_wav_to_ex2_FLAClite(wav_path, K=4, block_size=4096, use_mid_side=True,
                               order=1, blockwise_adapt=True, guarantee_nonexpanding=True):
    """Encode a 16-bit PCM WAV file to ``<root>_K{K}_ord{order}_Enc.ex2``.

    Pipeline: optional stereo mid/side -> fixed FLAC-style predictor of
    ``order`` -> residuals written either as mixed Rice-or-verbatim blocks
    (``blockwise_adapt=True``, header K=250) or as one fixed-K Rice stream.
    The coded stream goes to ``<outp>.tmp`` first and is renamed into place
    only when complete; the temp file is removed if writing fails.

    Args:
        wav_path: input WAV path.
        K: Rice parameter; stored in a uint8 header byte.
        block_size: frames per block (a block holds ``block_size * nchannels``
            interleaved residuals).
        use_mid_side: apply mid/side to 2-channel input (decoder must match).
        order: fixed predictor order (decoder must match).
        blockwise_adapt: mixed per-block mode instead of a single fixed-K stream.
        guarantee_nonexpanding: if the coded file is larger than a whole-file
            verbatim container (``_HSZ`` header with K=255 + 2 bytes/sample),
            write that container instead. Its 9-byte header is smaller than a
            canonical 44-byte WAV header, so the output never exceeds the input.

    Returns:
        ``(outp, params)``: output path and the WAV params needed to decode.

    Raises:
        ValueError: ``K`` does not fit the header byte, or (fixed-K mode)
            equals 250/255, which the decoder reads as layout markers.
    """
    if not 0 <= K <= 255:
        raise ValueError(f"K must fit in one byte (0..255), got {K}")
    if not blockwise_adapt and K in (250, 255):
        raise ValueError(f"K={K} is reserved as an .ex2 layout marker in fixed-Rice mode")

    params, data = read_wav(wav_path)
    nch = params.nchannels

    frames = data.reshape(-1, nch)
    if nch == 2 and use_mid_side:
        frames = to_mid_side(frames)
    inter = frames.reshape(-1).astype(np.int16, copy=False)

    resid = fixed_pred_residual_interleaved(inter, nch, order=order)
    blk_vals = block_size * nch

    root, _ = os.path.splitext(wav_path)
    outp = f"{root}_K{K}_ord{order}_Enc.ex2"
    tmp = outp + ".tmp"

    try:
        if blockwise_adapt:
            print("Writing .ex2 (mixed blocks: Rice-or-raw per block)...", flush=True)
            _write_ex2_mixed_blocks_streaming(tmp, K, len(resid), resid, blk_vals, log_every_blocks=10)
        else:
            print("Writing .ex2 (fixed Rice)...", flush=True)
            _write_ex2_streaming(
                tmp, K, len(resid),
                resid_iter=(int(v) for v in resid),
                log_blocks=blk_vals,
                total_blocks=(len(resid) + blk_vals - 1) // blk_vals
            )
    except BaseException:
        # don't leave a half-written temp file behind
        if os.path.exists(tmp):
            os.remove(tmp)
        raise

    # Keep the coded stream only if it beats the whole-file verbatim container.
    verbatim_size = _HSZ + 2 * inter.size
    if guarantee_nonexpanding and os.path.getsize(tmp) > verbatim_size:
        print("Coded stream larger than whole-file verbatim; rewriting as verbatim.", flush=True)
        _write_ex2_verbatim(outp, inter.size, inter)  # K=255 path
        os.remove(tmp)
        return outp, params

    os.replace(tmp, outp)
    return outp, params

# ---- Encoder: use mixed per-block mode ----
# def encode_wav_to_ex2_FLAClite(wav_path, K=4, block_size=4096, use_mid_side=True, order=1,
#                                blockwise_adapt=True, guarantee_nonexpanding=True):
#     params, data = read_wav(wav_path)
#     nch, sw, fr, nframes, ct, cn = params
#     frames = data.reshape(-1, nch)
#     if nch==2 and use_mid_side:
#         frames = to_mid_side(frames)
#     inter = frames.reshape(-1).astype(np.int16, copy=False)

#     # residuals
#     resid = fixed_pred_residual_interleaved(inter, nch, order=order)

    # root,_ = os.path.splitext(wav_path)
    # outp   = f"{root}_K{K}_ord{order}_Enc.ex2"

    # blk_vals = block_size * nch
    # if blockwise_adapt:
    #     print("Writing .ex2 (mixed blocks: Rice-or-raw per block)...", flush=True)
    #     _write_ex2_mixed_blocks_streaming(outp, K, len(resid), resid, blk_vals, log_every_blocks=10)
    #     print("Encode done.", flush=True)
    #     return outp, params

    # # (fallback: your previous fixed-K streaming path)
    # print("Writing .ex2 (fixed Rice)...", flush=True)
    # _write_ex2_streaming(
    #     outp, K, len(resid),
    #     resid_iter=(int(v) for v in resid),
    #     log_blocks=blk_vals,
    #     total_blocks=(len(resid)+blk_vals-1)//blk_vals
    # )
    # print("Encode done.", flush=True)
    
    # return outp, params
    # root,_ = os.path.splitext(wav_path)
    # outp   = f"{root}_K{K}_ord{order}_Enc.ex2"
    # tmp    = outp + ".tmp"

    # if blockwise_adapt:
    #     _write_ex2_mixed_blocks_streaming(tmp, K, len(resid), resid, blk_vals, log_every_blocks=10)
    # else:
    #     _write_ex2_streaming(tmp, K, len(resid), (int(v) for v in resid),
    #                          log_blocks=blk_vals, total_blocks=(len(resid)+blk_vals-1)//blk_vals)

    # # --- non-expanding guard vs original WAV size ---
    # if guarantee_nonexpanding:
    #     if os.path.getsize(tmp) > os.path.getsize(wav_path):
    #         print("Mixed blocks expanded vs WAV; rewriting as whole-file verbatim.")
    #         _write_ex2_verbatim(outp, inter.size, inter)   # K=255 path
    #         os.remove(tmp)
    #         return outp, params

    # os.replace(tmp, outp)
    # return outp, params

# ---- Decoder: handle K==250 mixed blocks ----
def _ex2_read_mixed_residuals(f, nvals):
    """Read a K==250 payload (file positioned just after the header) into int16 residuals."""
    K_global = f.read(1)[0]  # extra byte written right after the header
    br = BitReader(f)
    vals = np.empty(nvals, dtype=np.int16)
    j = 0
    while j < nvals:
        # byte-aligned block header: mode (uint8), residual count (uint32)
        hdr = br.read_bytes(5)
        if len(hdr) != 5:
            raise EOFError("truncated mixed-block header")
        mode, blk_n = struct.unpack("<BI", hdr)
        if blk_n > nvals - j:
            raise ValueError(f"mixed block of {blk_n} values overruns the {nvals} declared")
        if mode == 1:    # Rice-coded residuals
            for _ in range(blk_n):
                vals[j] = rice_read_num(br, K_global)
                j += 1
        elif mode == 0:  # verbatim int16 residuals
            raw = br.read_bytes(blk_n * 2)
            vals[j:j + blk_n] = np.frombuffer(raw, dtype=np.int16, count=blk_n)
            j += blk_n
        else:
            raise ValueError(f"unknown mixed-block mode {mode}")
    return vals


def _ex2_read_fixed_rice_residuals(f, nvals, K):
    """Stream-decode ``nvals`` Rice(K) residuals from the rest of ``f`` (legacy layout)."""
    br = BitReader(f)
    return np.fromiter((rice_read_num(br, K) for _ in range(nvals)),
                       dtype=np.int16, count=nvals)


def _ex2_finish_wav(ex2_path, params_hint, inter, use_mid_side):
    """Undo mid/side (stereo only), write ``<stem>_Dec.wav`` and return its path."""
    nch = params_hint.nchannels
    frames = inter.reshape(-1, nch)
    if nch == 2 and use_mid_side:
        frames = from_mid_side(frames)
    out = os.path.splitext(ex2_path)[0] + "_Dec.wav"
    write_wav(out, params_hint, frames.reshape(-1))
    return out


def decode_ex2_to_wav_FLAClite(ex2_path, params_hint, use_mid_side=True, order=1):
    """Decode an .ex2 file to ``<stem>_Dec.wav`` and return that path.

    The header's K byte selects the payload layout:
      * 255: whole-file verbatim int16 samples (after mid/side, before
        prediction), so the predictor is not inverted;
      * 250: one K_global byte, then byte-aligned ``<BI`` block headers, each
        followed by Rice(K_global) (mode 1) or raw int16 (mode 0) residuals;
      * otherwise: legacy single Rice(K) stream of residuals.

    The legacy stream is decoded with the streaming ``BitReader`` rather than
    by expanding the payload into a '0'/'1' string. This uses 1/8 of the
    memory and also handles K=0.

    Args:
        ex2_path: encoded file.
        params_hint: WAV params of the original (as returned by ``read_wav``).
        use_mid_side: must match the encoder.
        order: fixed predictor order; must match the encoder.

    Raises:
        ValueError: bad magic, unknown block mode, or block count overrun.
        EOFError: the payload ends early.
    """
    with open(ex2_path, "rb") as f:
        magic, K, nvals = struct.unpack(_HDR, f.read(_HSZ))
        if magic != b"EX2\0":
            raise ValueError(f"{ex2_path!r} is not an .ex2 file (magic={magic!r})")
        if K == 255:
            inter = np.frombuffer(f.read(), dtype=np.int16, count=nvals)
            return _ex2_finish_wav(ex2_path, params_hint, inter, use_mid_side)
        if K == 250:
            resid = _ex2_read_mixed_residuals(f, nvals)
        else:
            resid = _ex2_read_fixed_rice_residuals(f, nvals, K)

    inter = fixed_pred_inverse_interleaved(resid, params_hint.nchannels, order=order)
    return _ex2_finish_wav(ex2_path, params_hint, inter, use_mid_side)



# --------------------------
# Helpers: unary + k-bit I/O
# --------------------------
def _unary(q:int)->str: return "1"*q + "0"
def _u_decode(b,i):
    q=0; n=len(b)
    while i<n and b[i]=="1": q+=1; i+=1
    return q, i+1  # skip '0'

def _kbits(x,k):  return format(x, f"0{k}b")
def _from_k(b,i,k): return int(b[i:i+k],2), i+k

# Signed<->unsigned (ZigZag) for int16 residuals
def _zz16_enc(x):
    s = int(np.int16(x))
    return ((s<<1) ^ (s>>15)) & 0xFFFF
def _zz16_dec(u):
    s = (int(u)>>1) ^ -(int(u)&1)
    return np.int16(s)

# --------------------------
# Rice code for one number
# --------------------------
def rice_enc_num(s, K):
    """Return the Rice(K) codeword for residual ``s`` as a '0'/'1' string."""
    modulus = 1 << K
    folded = _zz16_enc(s)
    quotient, remainder = divmod(folded, modulus)
    return "".join((_unary(quotient), _kbits(remainder, K)))

def rice_dec_num(bitstr, i, K):
    """Decode one Rice(K) codeword from ``bitstr`` at bit index ``i``.

    Returns ``(value, next_index)``; value is an ``np.int16``.
    """
    scale = 1 << K
    quotient, pos = _u_decode(bitstr, i)
    remainder, pos = _from_k(bitstr, pos, K)
    return _zz16_dec(quotient * scale + remainder), pos

# Pack/unpack bitstring
def _pack_bits(bitstr:str)->bytes:
    out = bytearray(); cur=0; cnt=0
    for ch in bitstr:
        cur = (cur<<1) | (ch=="1")
        cnt+=1
        if cnt==8: out.append(cur); cur=0; cnt=0
    if cnt: out.append(cur<<(8-cnt))
    return bytes(out)

def _unpack_bits(bb:bytes)->str:
    return "".join(format(b, "08b") for b in bb)

# Tiny .ex2 header: [magic "EX2\0"][uint8 K][uint32 n_values]
_HDR = "<4sBI"; _HSZ = struct.calcsize(_HDR)
def _write_ex2(path, K, nvals, bitstr):
    """Write a legacy .ex2 file: ``_HDR`` header, then ``bitstr`` packed MSB-first."""
    with open(path, "wb") as out:
        out.write(struct.pack(_HDR, b"EX2\0", K, nvals))
        payload = _pack_bits(bitstr)
        out.write(payload)
def _read_ex2(path):
    """Read a legacy .ex2 file; return ``(K, n_values, payload_as_bitstring)``."""
    with open(path, "rb") as src:
        header = src.read(_HSZ)
        magic, K, n = struct.unpack(_HDR, header)
        assert magic == b"EX2\0"
        body = src.read()
    return K, n, _unpack_bits(body)

# --------------------------
# WAV I/O (16‑bit PCM only)
# --------------------------
def read_wav(path):
    """Read an uncompressed 16-bit PCM WAV file.

    Returns:
        ``(params, data)``: ``params`` is the ``wave`` params tuple
        (nchannels, sampwidth, framerate, nframes, comptype, compname);
        ``data`` is a flat ``np.int16`` array of interleaved samples.

    Raises:
        ValueError: the file is not 16-bit PCM. This is an explicit check, not
            an ``assert``, so ``python -O`` cannot strip it and let other sample
            widths be misread as int16.
    """
    with wave.open(path, "rb") as wf:
        params = wf.getparams()
        if params.sampwidth != 2 or params.comptype != "NONE":
            raise ValueError(
                f"Expect 16-bit PCM WAV, got sampwidth={params.sampwidth}, "
                f"comptype={params.comptype!r} ({path})"
            )
        data = np.frombuffer(wf.readframes(params.nframes), dtype=np.int16)
    return params, data

def write_wav(path, params, data_int16):
    """Write interleaved samples to ``path`` as a WAV file.

    ``params`` is unpacked positionally as
    (nchannels, sampwidth, framerate, nframes, comptype, compname). nframes is
    not set explicitly; the ``wave`` writer records the frames actually written.
    Samples are cast to int16 before writing.
    """
    nchannels, sampwidth, framerate, _nframes, comptype, compname = params
    with wave.open(path, "wb") as out:
        out.setnchannels(nchannels)
        out.setsampwidth(sampwidth)
        out.setframerate(framerate)
        out.setcomptype(comptype, compname)
        pcm = np.asarray(data_int16, np.int16).tobytes()
        out.writeframes(pcm)

# --------------------------
# FLAC‑lite building blocks
# --------------------------

# 1) Blocking (works on interleaved samples). We’ll just stream blocks; no per‑block header needed here.
def _iter_blocks(arr, block_size):
    for i in range(0, len(arr), block_size):
        yield arr[i:i+block_size]

# 2) Optional stereo mid/side (lossless integer transform)
def to_mid_side(frames_int16):
    """Stereo -> mid/side: M = floor((L + R) / 2), S = L - R, both stored as int16.

    Args:
        frames_int16: array of shape (num_frames, 2) holding L and R columns.

    Returns:
        (num_frames, 2) int16 array of [M, S].

    Raises:
        ValueError: some |L - R| exceeds the int16 range. S needs 17 bits in
            general, and wrapping it to int16 makes distinct (L, R) pairs
            collide (e.g. (32767, -32768) and (-1, 0) both give M=-1, S=-1).
            That would silently break lossless decoding, so encode such input
            with mid/side disabled.
    """
    L = frames_int16[:, 0].astype(np.int32)
    R = frames_int16[:, 1].astype(np.int32)
    S = L - R
    lim = np.iinfo(np.int16)
    if S.size and (S.min() < lim.min or S.max() > lim.max):
        raise ValueError("mid/side side channel overflows int16; "
                         "encode this file with use_mid_side=False")
    M = (L + R) >> 1  # floor((L+R)/2); always within int16
    return np.column_stack([M.astype(np.int16), S.astype(np.int16)])

def from_mid_side(ms_int16):
    """Invert mid/side: (M, S) -> (L, R), FLAC-style.

    The forward transform stores M = floor((L + R) / 2) and S = L - R. The bit
    dropped by the floor equals the parity of S, because L + R and L - R always
    share parity. So:

        mid2 = 2*M + (S & 1)      # == L + R exactly
        L    = (mid2 + S) >> 1
        R    = (mid2 - S) >> 1

    (``L = M + (S >> 1)`` would be off by one whenever S is odd.)

    Args:
        ms_int16: array of shape (num_frames, 2) holding M and S columns.

    Returns:
        (num_frames, 2) int16 array of [L, R].
    """
    M = ms_int16[:, 0].astype(np.int32)
    S = ms_int16[:, 1].astype(np.int32)
    mid2 = (M << 1) | (S & 1)
    L = (mid2 + S) >> 1
    R = (mid2 - S) >> 1
    return np.column_stack([L.astype(np.int16), R.astype(np.int16)])

# 3) Fixed predictor (order 1): residual = x[n] - x[n-1] per channel
# def fixed_pred_residual_interleaved(data_int16, nch):
#     frames = data_int16.reshape(-1, nch).astype(np.int16, copy=False)
#     resid  = np.empty_like(frames)
#     resid[0,:]  = frames[0,:]
#     resid[1:,:] = (frames[1:,:].astype(np.int16) - frames[:-1,:].astype(np.int16))
#     return resid.reshape(-1)

# def fixed_pred_inverse_interleaved(resid_int16, nch):
#     frames = resid_int16.reshape(-1, nch).astype(np.int16, copy=False)
#     recon  = np.cumsum(frames.astype(np.int32), axis=0).astype(np.int16)
#     return recon.reshape(-1)
# ---- Fixed predictor coefficients (FLAC-style) ----
# pred[n] = c1*x[n-1] + c2*x[n-2] + ... + cN*x[n-N]
# N=1: [1]
# N=2: [2,-1]
# N=3: [3,-3,1]
# N=4: [4,-6,4,-1]
from math import comb
def flac_fixed_coeffs(order: int):
    """Coefficients c1..cN of the order-N FLAC fixed predictor.

    c_k = (-1)**(k+1) * C(N, k), e.g. N=2 -> [2, -1]. Orders below 1 return [0].
    """
    if order < 1:
        return [0]
    coeffs = []
    sign = 1
    for k in range(1, order + 1):
        coeffs.append(sign * comb(order, k))
        sign = -sign
    return coeffs

def fixed_pred_residual_interleaved(data_int16, nch, order=1):
    """Residuals of the fixed FLAC predictor of ``order``, per channel.

    The first ``order`` frames are stored raw (warm-up). Every later frame n
    stores ``x[n] - sum_k c_k * x[n-k]`` wrapped to int16.

    The prediction is computed with whole-array slices instead of a Python
    loop per sample. Results are identical because each residual depends only
    on the original samples, never on earlier residuals.

    Args:
        data_int16: interleaved samples (length a multiple of ``nch``).
        nch: channel count.
        order: predictor order.

    Returns:
        Flat interleaved int16 residual array, same length as the input.
    """
    frames = data_int16.reshape(-1, nch).astype(np.int16, copy=False)
    coeffs = flac_fixed_coeffs(order)
    p = order
    n_frames = frames.shape[0]
    resid = np.empty_like(frames, dtype=np.int16)

    # warm-up: first p samples per channel stored raw
    resid[:p, :] = frames[:p, :]

    if n_frames > p:  # otherwise every frame is warm-up
        x = frames.astype(np.int32)  # 32-bit to avoid overflow while predicting
        pred = np.zeros_like(x[p:])
        for k, c in enumerate(coeffs, start=1):
            if c:  # order<1 yields [0]: no prediction
                pred += c * x[p - k:n_frames - k]
        # residual = x - pred, wrapped to the int16 domain
        resid[p:, :] = (x[p:] - pred).astype(np.int16)
    return resid.reshape(-1)

def fixed_pred_inverse_interleaved(resid_int16, nch, order=1):
    """Rebuild interleaved samples from fixed-predictor residuals.

    Inverse of ``fixed_pred_residual_interleaved``. The first ``order`` frames
    are copied; each later frame is ``prediction + residual`` wrapped to int16.
    The prediction uses the already reconstructed history, so the loop runs
    frame by frame.
    """
    res = resid_int16.reshape(-1, nch).astype(np.int16, copy=False)
    taps = flac_fixed_coeffs(order)
    warm = order

    # int32 history of reconstructed samples (values always within int16)
    hist = np.empty(res.shape, dtype=np.int32)
    hist[:warm, :] = res[:warm, :]

    for n in range(warm, res.shape[0]):
        acc = np.zeros((nch,), dtype=np.int32)
        for lag, coef in enumerate(taps, start=1):
            acc += coef * hist[n - lag, :]
        # wrap to int16 exactly as the forward pass wrapped the residual
        hist[n, :] = np.int16(acc + res[n, :])
    return hist.astype(np.int16).reshape(-1)


# --------------------------
# Encode/decode (K fixed, block_size, mid/side option)
# --------------------------
# def encode_wav_to_ex2_FLAClite(wav_path, K=4, block_size=4096, use_mid_side=True):
#     params, data = read_wav(wav_path)
#     nch, sw, fr, nframes, ct, cn = params
#     frames = data.reshape(-1, nch)

#     if nch==2 and use_mid_side:
#         frames = to_mid_side(frames)
#     # back to interleaved for predictor
#     inter = frames.reshape(-1).astype(np.int16, copy=False)

#     # residuals (fixed predictor order-1)
#     resid = fixed_pred_residual_interleaved(inter, nch)

#     # Rice‑encode blockwise with fixed K (simplest). (We still just concatenate; decoder does same.)
#     bits = []
#     for blk in _iter_blocks(resid, block_size*nch):
#         # encode each residual in block
#         for v in blk:
#             bits.append(rice_enc_num(int(v), K))
#     bitstr = "".join(bits)

#     root,_ = os.path.splitext(wav_path)
#     outp   = f"{root}_K{K}_Enc.ex2"
#     _write_ex2(outp, K, len(resid), bitstr)
#     return outp, params

# def decode_ex2_to_wav_FLAClite(ex2_path, params_hint, use_mid_side=True):
#     K, nvals, bitstr = _read_ex2(ex2_path)
#     # Rice‑decode stream back to residuals
#     vals=[]; i=0
#     for _ in range(nvals):
#         v,i = rice_dec_num(bitstr, i, K)
#         vals.append(v)
#     resid = np.array(vals, dtype=np.int16)

#     nch = params_hint.nchannels
#     # invert fixed predictor
#     inter = fixed_pred_inverse_interleaved(resid, nch)

#     # back to (frames, nch)
#     frames = inter.reshape(-1, nch)
#     if nch==2 and use_mid_side:
#         frames = from_mid_side(frames)

#     wav_out = os.path.splitext(ex2_path)[0] + "_Dec.wav"
#     write_wav(wav_out, params_hint, frames.reshape(-1))
#     return wav_out

def _write_ex2_streaming(path, K, nvals, resid_iter, log_blocks=None, total_blocks=None):
    """Stream residuals into a legacy fixed-K Rice .ex2 file.

    Writes the ``_HDR`` header with K, then one Rice(K) codeword per residual
    via ``rice_write_num`` and ``BitWriter`` (no intermediate bit-string).

    Progress is printed only when both ``log_blocks`` and ``total_blocks`` are
    given. There is one line per ``log_blocks`` residuals, plus a final line
    for a trailing partial block so the report reaches 100%.
    """
    report = log_blocks is not None and total_blocks is not None
    with open(path, "wb") as f:
        f.write(struct.pack(_HDR, b"EX2\0", K, nvals))
        bw = BitWriter(f)

        blk = 0
        in_blk = 0
        for s in resid_iter:
            rice_write_num(bw, int(s), K)
            if report:
                in_blk += 1
                if in_blk == log_blocks:
                    blk += 1
                    in_blk = 0
                    print(f"Encoding progress: {(blk/total_blocks)*100:.1f}% ({blk}/{total_blocks})", flush=True)

        if report and in_blk:  # trailing partial block
            blk += 1
            print(f"Encoding progress: {(blk/total_blocks)*100:.1f}% ({blk}/{total_blocks})", flush=True)

        bw.close()
        
def _write_ex2_verbatim(path, nvals, payload_int16):
    """Write a whole-file verbatim .ex2: header with the K=255 marker, then raw int16 samples."""
    with open(path, "wb") as out:
        header = struct.pack(_HDR, b"EX2\0", 255, nvals)  # K=255 means verbatim
        out.write(header)
        out.write(np.asarray(payload_int16, np.int16).tobytes())

def _estimate_rice_bits(resid_int16, K):
    x = resid_int16.astype(np.int32)
    u = ((x << 1) ^ (x >> 15)) & 0xFFFF  # zigzag (vectorized for int16)
    return int(np.sum((u >> K) + 1 + K))  # unary(q)=q+1, plus K bits for remainder


# def encode_wav_to_ex2_FLAClite(wav_path, K=4, block_size=4096, use_mid_side=True, order=1):
#     params, data = read_wav(wav_path)
#     nch, sw, fr, nframes, ct, cn = params
#     frames = data.reshape(-1, nch)
#     if nch==2 and use_mid_side:
#         frames = to_mid_side(frames)
#     inter = frames.reshape(-1).astype(np.int16, copy=False)

#     resid = fixed_pred_residual_interleaved(inter, nch, order=order)

#     bits = []
#     total_blocks = (len(resid) + block_size*nch - 1) // (block_size*nch)
#     for blk_i, blk in enumerate(_iter_blocks(resid, block_size*nch)):
#         for v in blk:
#             bits.append(rice_enc_num(int(v), K))
#         if blk_i % 5 == 0:
#             percent = blk_i / total_blocks * 100
#             print(f"Encoding progress: {percent:.1f}% ({blk_i}/{total_blocks})")
#     bitstr = "".join(bits)

#     root,_ = os.path.splitext(wav_path)
#     outp   = f"{root}_K{K}_ord{order}_Enc.ex2"
#     _write_ex2(outp, K, len(resid), bitstr)
#     return outp, params

# def encode_wav_to_ex2_FLAClite(wav_path, K=4, block_size=4096, use_mid_side=True, order=1):
#     params, data = read_wav(wav_path)
#     nch, sw, fr, nframes, ct, cn = params
#     frames = data.reshape(-1, nch)
#     if nch==2 and use_mid_side:
#         frames = to_mid_side(frames)
#     inter = frames.reshape(-1).astype(np.int16, copy=False)

#     # residuals
#     resid = fixed_pred_residual_interleaved(inter, nch, order=order)

#     # streaming write with progress every N residuals-as-a-block
#     total_blocks = (len(resid) + (block_size*nch) - 1) // (block_size*nch)
#     root,_ = os.path.splitext(wav_path)
#     outp   = f"{root}_K{K}_ord{order}_Enc.ex2"
#     print("Writing .ex2 (streaming)...", flush=True)
#     _write_ex2_streaming(
#         outp, K, len(resid),
#         resid_iter=(int(v) for v in resid),
#         log_blocks=block_size*nch,    # report every encode block
#         total_blocks=total_blocks
#     )
#     print("Encode done.", flush=True)
#     return outp, params

# def encode_wav_to_ex2_FLAClite(wav_path, K=4, block_size=4096, use_mid_side=True, order=1):
#     params, data = read_wav(wav_path)
#     nch, sw, fr, nframes, ct, cn = params
#     frames = data.reshape(-1, nch)
#     if nch==2 and use_mid_side:
#         frames = to_mid_side(frames)
#     inter = frames.reshape(-1).astype(np.int16, copy=False)

#     resid = fixed_pred_residual_interleaved(inter, nch, order=order)

#     # --- non-expanding guard ---
#     rice_bits = _estimate_rice_bits(resid, K)
#     raw_bits  = inter.size * 16
#     root,_ = os.path.splitext(wav_path)
#     outp = f"{root}_K{K}_ord{order}_Enc.ex2"
#     if rice_bits >= raw_bits:
#         print(f"⚠️ Rice(K={K}) would expand ({rice_bits/8:.0f}B > {raw_bits/8:.0f}B). Falling back to verbatim.")
#         _write_ex2_verbatim(outp, inter.size, inter)
#         return outp, params
#     # ---------------------------

#     # Otherwise, stream Rice as before
#     total_blocks = (len(resid) + (block_size*nch) - 1) // (block_size*nch)
#     print("Writing .ex2 (streaming)...", flush=True)
#     _write_ex2_streaming(
#         outp, K, len(resid),
#         resid_iter=(int(v) for v in resid),
#         log_blocks=block_size*nch,
#         total_blocks=total_blocks
#     )
#     print("Encode done.", flush=True)
#     return outp, params



# def decode_ex2_to_wav_FLAClite(ex2_path, params_hint, use_mid_side=True, order=1,
#                                log_interval=100_000):
#     K, nvals, bitstr = _read_ex2(ex2_path)

#     vals = np.empty(nvals, dtype=np.int16)
#     i = 0  # bit index
#     for j in range(nvals):
#         v, i = rice_dec_num(bitstr, i, K)
#         vals[j] = v
#         if (j+1) % log_interval == 0 or (j+1) == nvals:
#             pct = (j+1) / nvals * 100
#             print(f"Decoding: {pct:.1f}% ({j+1}/{nvals})")

#     resid = vals
#     nch = params_hint.nchannels
#     inter = fixed_pred_inverse_interleaved(resid, nch, order=order)
#     frames = inter.reshape(-1, nch)
#     if nch==2 and use_mid_side:
#         frames = from_mid_side(frames)

#     wav_out = os.path.splitext(ex2_path)[0] + "_Dec.wav"
#     write_wav(wav_out, params_hint, frames.reshape(-1))
#     return wav_out

# def decode_ex2_to_wav_FLAClite(ex2_path, params_hint, use_mid_side=True, order=1):
#     with open(ex2_path, "rb") as f:
#         magic, K, nvals = struct.unpack(_HDR, f.read(_HSZ))
#         payload = f.read()
#     assert magic == b"EX2\0"

#     nch = params_hint.nchannels

#     if K == 255:
#         # verbatim payload of int16 samples AFTER optional mid/side
#         inter = np.frombuffer(payload, dtype=np.int16, count=nvals)
#         frames = inter.reshape(-1, nch)
#         if nch==2 and use_mid_side:
#             frames = from_mid_side(frames)
#         wav_out = os.path.splitext(ex2_path)[0] + "_Dec.wav"
#         write_wav(wav_out, params_hint, frames.reshape(-1))
#         return wav_out

#     # Rice-coded path (unchanged)
#     bitstr = _unpack_bits(payload)
#     vals=[]; i=0
#     for _ in range(nvals):
#         v,i = rice_dec_num(bitstr, i, K)
#         vals.append(v)
#     resid = np.array(vals, dtype=np.int16)
#     inter = fixed_pred_inverse_interleaved(resid, nch, order=order)
#     frames = inter.reshape(-1, nch)
#     if nch==2 and use_mid_side:
#         frames = from_mid_side(frames)
#     wav_out = os.path.splitext(ex2_path)[0] + "_Dec.wav"
#     write_wav(wav_out, params_hint, frames.reshape(-1))
#     return wav_out




# --------------------------
# Tiny driver to produce the table (K=4 and K=2)
# --------------------------
def run_one_FLAClite(wav_path, Ks=(4,2), block_size=4096, use_mid_side=True, order=4):
    """Encode/decode ``wav_path`` once per Rice K and return one summary-table row.

    Every round trip is checked sample-for-sample against the original; a
    mismatch raises AssertionError.

    Returns:
        dict with "File", "Original size", and per K both "Rice (K=..)"
        (encoded bytes) and "% Compression (K=..)".
    """
    _, original = read_wav(wav_path)  # read once; reused for every K
    orig_size = os.path.getsize(wav_path)
    row = {"File": os.path.basename(wav_path), "Original size": orig_size}
    for K in Ks:
        ex2, p = encode_wav_to_ex2_FLAClite(wav_path, K, block_size, use_mid_side, order=order,
                                           blockwise_adapt=True, guarantee_nonexpanding=True)
        dec = decode_ex2_to_wav_FLAClite(ex2, p, use_mid_side, order=order)

        # verify lossless
        _, roundtrip = read_wav(dec)
        assert np.array_equal(original, roundtrip), f"Round‑trip mismatch for {wav_path} (K={K})"

        enc_size = os.path.getsize(ex2)
        row[f"Rice (K={K})"] = enc_size
        row[f"% Compression (K={K})"] = 100.0 * (orig_size - enc_size) / orig_size
        print(f"{wav_path}: K={K} round-trip OK ({enc_size} bytes)", flush=True)
    return row

# Example batch (put Sound1.wav / Sound2.wav next to the notebook).
# The table columns are derived from KS so they always match the Ks actually run.
KS = (4, 2)
targets = ["Sound1.wav", "Sound2.wav"]
results = []
for t in targets:
    if os.path.exists(t):
        results.append(run_one_FLAClite(t, Ks=KS, block_size=4096, use_mid_side=True))
    else:
        print(f"⚠️ Missing {t} — skipping")
columns = (["File", "Original size"]
           + [f"Rice (K={K})" for K in KS]
           + [f"% Compression (K={K})" for K in KS])
df = pd.DataFrame(results, columns=columns)
df


Writing .ex2 (mixed blocks: Rice-or-raw per block)...
Encoding progress: 10/123 (8.1%)
Encoding progress: 20/123 (16.3%)
Encoding progress: 30/123 (24.4%)
Encoding progress: 40/123 (32.5%)
Encoding progress: 50/123 (40.7%)
Encoding progress: 60/123 (48.8%)
Encoding progress: 70/123 (56.9%)
Encoding progress: 80/123 (65.0%)
Encoding progress: 90/123 (73.2%)
Encoding progress: 100/123 (81.3%)
Encoding progress: 110/123 (89.4%)
Encoding progress: 120/123 (97.6%)
Encoding progress: 123/123 (100.0%)
1
Sound1.wav 4
Writing .ex2 (mixed blocks: Rice-or-raw per block)...
Encoding progress: 10/123 (8.1%)
Encoding progress: 20/123 (16.3%)
Encoding progress: 30/123 (24.4%)
Encoding progress: 40/123 (32.5%)
Encoding progress: 50/123 (40.7%)
Encoding progress: 60/123 (48.8%)
Encoding progress: 70/123 (56.9%)
Encoding progress: 80/123 (65.0%)
Encoding progress: 90/123 (73.2%)
Encoding progress: 100/123 (81.3%)
Encoding progress: 110/123 (89.4%)
Encoding progress: 120/123 (97.6%)
Encoding progress: 12

Unnamed: 0,File,Original size,Rice (K=4),Rice (K=2),% Compression (K=4),% Compression (K=2)
0,Sound1.wav,1002088,562910,682303,43.826291,31.911868
1,Sound2.wav,1008044,1008009,1008009,0.003472,0.003472
