# Exercise 2 — Lossless Audio Compression with Rice Coding

**Goal.** Build a bit-exact WAV to EX2 encoder/decoder that:
1) predicts samples (fixed linear prediction),  
2) ZigZag maps signed residuals to unsigned,  
3) compresses with Rice coding,  
4) decodes back to a WAV identical to the original.
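
As a minimal sketch of step 1, the simplest fixed predictor (order 1) guesses that each sample equals the previous one, so the residual is just the sample-to-sample difference. The helper names `order1_residual` and `order1_reconstruct` are hypothetical, and the sketch ignores the int16 wrap-around that the notebook's own predictor applies.

```python
import numpy as np

def order1_residual(x):
    """Residual of the order-1 fixed predictor pred[n] = x[n-1] (hypothetical helper)."""
    x = np.asarray(x, dtype=np.int32)
    r = x.copy()
    r[1:] = x[1:] - x[:-1]   # the first sample is kept raw as warm-up
    return r

def order1_reconstruct(r):
    """Inverse of order1_residual: a running sum rebuilds the samples."""
    return np.cumsum(np.asarray(r, dtype=np.int32))

samples = np.array([100, 102, 105, 104, 104], dtype=np.int16)
resid = order1_residual(samples)
print(resid.tolist())   # [100, 2, 3, -1, 0]
```

Apart from the warm-up sample, the residuals are much smaller than the samples, which is what makes the later Rice coding step effective.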

The notebook reports the compressed size for K = 2 and K = 4. To avoid expansion on hard-to-compress audio, the encoder falls back to storing samples verbatim whenever Rice coding would produce a larger output.


## Environment, I/O & Constraints

- **Language/Libraries:** Python, `numpy`, `wave`, `struct`, `pandas`.
- **Input:** 16-bit PCM WAV only (mono or stereo).  
  This is asserted in `read_wav(...)`.
- **Output container (.ex2):** tiny header + payload (bitstream or verbatim PCM, depending on mode).
- **Interleaving:** Samples are processed frame-wise (e.g., stereo L0,R0,L1,R1,…). This is reshaped to `[frames, nch]` when predicting.
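
A small illustration of the interleaving convention, assuming a stereo stream stored as a flat int16 array (the sample values are made up):

```python
import numpy as np

# Interleaved stereo samples: L0, R0, L1, R1, L2, R2
inter = np.array([10, -10, 11, -11, 12, -12], dtype=np.int16)
nch = 2

frames = inter.reshape(-1, nch)          # shape [frames, nch]
left, right = frames[:, 0], frames[:, 1]
print(left.tolist(), right.tolist())     # [10, 11, 12] [-10, -11, -12]

flat = frames.reshape(-1)                # back to the interleaved layout
```

Each row of `frames` is one frame and each column is one channel, so per-channel prediction can index `frames[n - k, :]` directly.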


### BitWriter and BitReader Classes

#### BitWriter Class
The BitWriter class allows writing individual bits to a binary file, which normally only accepts full bytes. It buffers bits until it collects 8 of them, then writes a full byte to the file.

- `self.buf`: An 8-bit buffer that stores bits temporarily.
- `self.nbits`: Tracks how many bits are currently stored.
- `_putbit(b)`: Adds a single bit (0 or 1) to the buffer. Once the buffer has 8 bits, it writes them as a byte.
- `write_unary(q)`: Writes `q` ones followed by a single 0 (unary coding).
- `write_kbits(val, k)`: Writes the `k`-bit binary representation of `val`, from most to least significant bit.
- `flush_to_byte()`: Pads any remaining bits with zeros and writes them as a final byte.
- `close()`: Calls `flush_to_byte()` to finalize writing.
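
The buffering behaviour described above can be sketched without a file object. The function `pack_bits_msb_first` is a hypothetical stand-in for illustration, not part of the notebook; it assumes the same MSB-first order and zero padding as the description.

```python
def pack_bits_msb_first(bits):
    """Pack an iterable of 0/1 values into bytes, MSB first, zero-padding the last byte."""
    out = bytearray()
    buf, nbits = 0, 0
    for b in bits:
        buf = (buf << 1) | (1 if b else 0)   # shift the new bit in from the right
        nbits += 1
        if nbits == 8:                        # a full byte is ready
            out.append(buf)
            buf, nbits = 0, 0
    if nbits:                                 # flush: pad the remaining bits with zeros
        out.append((buf << (8 - nbits)) & 0xFF)
    return bytes(out)

packed = pack_bits_msb_first([1, 1, 0, 1, 0, 1])
print(packed.hex())   # d4  (110101 padded to 11010100)
```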

#### BitReader Class
The BitReader class reads bits one at a time from a binary file. It loads bytes as needed and allows bit-level access.

- `self.cur`: The current byte being read.
- `self.left`: Number of unread bits left in `self.cur`.
- `_fill()`: Loads the next byte from the file and resets the counter.
- `getbit()`: Returns the next bit by shifting and masking `self.cur`.
- `getbits(k)`: Returns `k` bits as an integer, bit by bit.
- `align_to_byte()`: Skips any unread leftover bits to realign to the next byte.
- `read_bytes(n)`: Reads `n` bytes directly after byte alignment.

Both classes provide the bit-level I/O that Rice coding needs. The encoder and decoder use them to write and read the block-wise audio data.
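
For the reading direction, a minimal sketch that works on an in-memory `bytes` object instead of a file. The generator `iter_bits_msb_first` is a hypothetical helper that follows the same MSB-first convention as the description.

```python
def iter_bits_msb_first(data):
    """Yield the bits of a bytes object one at a time, most significant bit first."""
    for byte in data:
        for shift in range(7, -1, -1):
            yield (byte >> shift) & 1

bits = list(iter_bits_msb_first(b"\xd4"))
print(bits)   # [1, 1, 0, 1, 0, 1, 0, 0]
```

The two trailing zeros are padding from the writer's flush. A reader that knows how many values to expect simply ignores them, which is the job of `align_to_byte()`.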


In [1]:
import os, wave, struct
import numpy as np
import pandas as pd

# ---- Bit I/O helpers for streaming mixed blocks ----
class BitWriter:
    def __init__(self, f):
        self.f, self.buf, self.nbits = f, 0, 0
    def _putbit(self, b):
        self.buf = (self.buf << 1) | (1 if b else 0); self.nbits += 1
        if self.nbits == 8:
            self.f.write(bytes([self.buf])); self.buf = 0; self.nbits = 0
    def write_unary(self, q):
        for _ in range(q): self._putbit(1)
        self._putbit(0)
    def write_kbits(self, val, k):
        for i in range(k-1, -1, -1): self._putbit((val>>i) & 1)
    def flush_to_byte(self):
        if self.nbits:
            self.f.write(bytes([self.buf << (8 - self.nbits)]))
            self.buf = 0; self.nbits = 0
    def close(self): self.flush_to_byte()

class BitReader:
    def __init__(self, f):
        self.f, self.cur, self.left = f, 0, 0
    def _fill(self):
        b = self.f.read(1)
        if not b: raise EOFError("BitReader exhausted")
        self.cur = b[0]; self.left = 8
    def getbit(self):
        if self.left == 0: self._fill()
        self.left -= 1
        return (self.cur >> self.left) & 1
    def getbits(self, k):
        v = 0
        for _ in range(k): v = (v<<1) | self.getbit()
        return v
    def align_to_byte(self):
        if self.left and self.left != 8:
            self.left = 0  # drop remaining bits in this byte
    def read_bytes(self, n):
        self.align_to_byte()
        return self.f.read(n)


### Rice Coding Helper Functions

The helper functions implement Rice coding using `BitWriter` and `BitReader`.

---

#### rice_write_num(bw, s, K)
This function encodes a signed 16-bit integer `s` using Rice coding with bit parameter `K` and writes the result to the `BitWriter`:

1. `M = 1 << K`: Compute Rice divisor `M = 2^K`.
2. `u = _zz16_enc(s)`: Convert signed int to unsigned using ZigZag encoding.
3. `q, r = divmod(u, M)`: Compute quotient and remainder of `u / M`.
4. `bw.write_unary(q)`: Write `q` using unary coding.
5. `bw.write_kbits(r, K)`: Write the remainder `r` using `K` bits.
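
The five steps can be traced on a single value by building the codeword as a text string rather than writing real bits. The helpers `zigzag16` and `rice_bits` are hypothetical names for this sketch. They assume the same ZigZag formula as the notebook's `_zz16_enc`.

```python
def zigzag16(s):
    """Map a signed 16-bit value to unsigned: 0, -1, 1, -2, ... -> 0, 1, 2, 3, ..."""
    return ((s << 1) ^ (s >> 15)) & 0xFFFF

def rice_bits(s, K):
    """Return the Rice codeword of s as a string of '0'/'1' characters."""
    u = zigzag16(s)
    q, r = divmod(u, 1 << K)                      # quotient and remainder for M = 2^K
    remainder = format(r, f"0{K}b") if K > 0 else ""
    return "1" * q + "0" + remainder              # unary(q), terminator, K-bit remainder

print(rice_bits(-3, 2))   # 1001  (u = 5, q = 1, r = 1)
print(rice_bits(0, 2))    # 000   (u = 0, q = 0, r = 0)
```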

---

#### rice_read_num(br, K)
This function reads a Rice-coded number from the `BitReader` using parameter `K`, and decodes it back into a signed int:

1. Initialize `q = 0`.
2. While reading bits, count the number of `1`s until the first `0`. This is the unary-encoded `q`.
3. Read the next `K` bits as the remainder `r`.
4. Reconstruct the encoded value using `u = q * 2^K + r`.
5. Apply the ZigZag decoding to return the original signed value.
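
A sketch of the same decoding steps applied to a codeword given as a text string. `rice_decode_bits` is a hypothetical helper that returns the decoded value together with the number of bits consumed.

```python
def rice_decode_bits(bitstring, K):
    """Decode one Rice codeword from a '0'/'1' string; return (value, bits_consumed)."""
    pos, q = 0, 0
    while bitstring[pos] == "1":       # count the unary 1s
        q += 1
        pos += 1
    pos += 1                           # skip the terminating 0
    r = int(bitstring[pos:pos + K], 2) if K > 0 else 0
    pos += K
    u = (q << K) | r                   # u = q * 2^K + r
    s = (u >> 1) ^ -(u & 1)            # ZigZag decode back to signed
    return s, pos

print(rice_decode_bits("1001", 2))     # (-3, 4)
```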

---

#### _estimate_rice_bits_block(block_int16, K)
This function estimates the total number of bits needed to Rice-encode an array of int16 samples with parameter `K`:

1. Convert `block_int16` to 32-bit integers for safe processing.
2. Apply ZigZag encoding to make all numbers non-negative.
3. For each value `u`, compute the required bits as `bits = q + 1 + K = (u >> K) + 1 + K`:
   - `q = u // (1 << K)`, computed as `u >> K`
   - `+1` is for the terminating `0` of the unary code
   - `+K` is for the fixed-length remainder

This function is useful for deciding whether Rice coding is more efficient than storing the block in raw 16-bit form.
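
To illustrate that decision, the sketch below applies the same bit-count formula to two made-up blocks. `estimate_rice_bits` is a hypothetical standalone copy of the estimator, and the sample values are invented.

```python
import numpy as np

def estimate_rice_bits(block, K):
    """Total Rice-coded size in bits: (u >> K) + 1 + K per ZigZag-mapped value u."""
    x = np.asarray(block, dtype=np.int32)
    u = ((x << 1) ^ (x >> 15)) & 0xFFFF
    return int(np.sum((u >> K) + 1 + K))

quiet = np.array([0, 1, -1, 2, -2, 3], dtype=np.int16)          # small residuals
loud = np.array([12000, -15000, 9000, -7000], dtype=np.int16)   # large residuals

for name, blk in (("quiet", quiet), ("loud", loud)):
    rice, raw = estimate_rice_bits(blk, 2), blk.size * 16
    print(name, rice, raw, "Rice" if rice < raw else "verbatim")
# quiet 20 96 Rice
# loud 21510 64 verbatim
```

With large residuals the unary part dominates, so the block is cheaper to store as raw 16-bit values.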

---

Together, these functions implement block-wise, lossless audio compression with Rice coding.

In [None]:
# ---- Rice helpers ----
def rice_write_num(bw: BitWriter, s: int, K: int):
    M = 1 << K
    u = _zz16_enc(s)
    q, r = divmod(u, M)
    bw.write_unary(q); bw.write_kbits(r, K)

def rice_read_num(br: BitReader, K: int):
    # unary q
    q = 0
    while br.getbit() == 1: q += 1
    # the '0' has just been consumed
    r = br.getbits(K)
    return _zz16_dec(q*(1<<K) + r)

def _estimate_rice_bits_block(block_int16: np.ndarray, K: int) -> int:
    x = block_int16.astype(np.int32)
    u = ((x << 1) ^ (x >> 15)) & 0xFFFF
    return int(np.sum((u >> K) + 1 + K))  # unary(q)=q+1, plus K remainder bits

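# ---- Mixed per-block writer (Rice or verbatim residuals) ----
# File header stays <4sBI> (magic, K, nvals), but we set K=250 to signal
# "mixed blocks" and immediately store one extra byte: K_global.
_HDR = "<4sBI"                 # header layout: 4-byte magic, uint8 K, uint32 nvals
_HSZ = struct.calcsize(_HDR)   # header size in bytes (9)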
def _write_ex2_mixed_blocks(path, K_global, nvals, resid, blk_vals):
    with open(path, "wb") as f:
        # header: magic, K=250 (mixed), total nvals
        f.write(struct.pack(_HDR, b"EX2\0", 250, nvals))
        f.write(bytes([K_global]))  # extra byte after header

        bw = BitWriter(f)
        total_blocks = (nvals + blk_vals - 1) // blk_vals
        done_blocks = 0

        for start in range(0, nvals, blk_vals):
            blk = resid[start:start+blk_vals]
            # decide Rice vs verbatim
            rice_bits = _estimate_rice_bits_block(blk, K_global)
            raw_bits  = blk.size * 16
            mode = 1 if rice_bits < raw_bits else 0  # 1=Rice, 0=Verbatim

            # block header (byte-aligned): mode (uint8), nvals_block (uint32)
            bw.flush_to_byte()
            f.write(struct.pack("<BI", mode, blk.size))

            if mode == 1:
                # Rice-code this block
                for s in blk:
                    rice_write_num(bw, int(s), K_global)
            else:
                # Verbatim residuals (int16), byte-aligned
                bw.flush_to_byte()
                f.write(np.asarray(blk, np.int16).tobytes())

            done_blocks += 1
            if (done_blocks % 10 == 0) or (done_blocks == total_blocks):
                print(f"Encoding progress: {done_blocks}/{total_blocks} "
                      f"({done_blocks/total_blocks*100:.1f}%)", flush=True)
        bw.close()
        
def encode_wav_to_ex2(wav_path, K=4, block_size=4096, order=1):
    params, data = read_wav(wav_path)
    nch, sw, fr, nframes, ct, cn = params

    frames = data.reshape(-1, nch)
    if nch == 2:
        frames = to_mid_side(frames)
    inter = frames.reshape(-1).astype(np.int16, copy=False)

    # residuals
    resid = fixed_pred_residual_interleaved(inter, nch, order=order)

    # values per block = frames per block * channels
    blk_vals = block_size * nch

    root, _ = os.path.splitext(wav_path)
    outp = f"{root}_K{K}_ord{order}_Enc.ex2"
    tmp  = outp + ".tmp"

    print("Writing .ex2...", flush=True)
    _write_ex2_mixed_blocks(...)


    # final non-expanding safeguard vs original WAV size
    if os.path.getsize(tmp) > os.path.getsize(wav_path):
        print("Mixed blocks expanded vs WAV; rewriting as whole-file verbatim.", flush=True)
        _write_ex2_verbatim(outp, inter.size, inter)  # K=255 path
        os.remove(tmp)
        return outp, params

    os.replace(tmp, outp)
    return outp, params

# ---- Decoder: handle K==250 mixed blocks ----
def decode_ex2_to_wav(ex2_path, params_hint, order=1):
    with open(ex2_path, "rb") as f:
        magic, K, nvals = struct.unpack(_HDR, f.read(_HSZ))
        assert magic == b"EX2\0"
        if K == 255:
            # whole-file verbatim (original path)
            payload = f.read()
            inter = np.frombuffer(payload, dtype=np.int16, count=nvals)
            frames = inter.reshape(-1, params_hint.nchannels)
            if params_hint.nchannels==2:
                frames = from_mid_side(frames)
            out = os.path.splitext(ex2_path)[0] + "_Dec.wav"
            write_wav(out, params_hint, frames.reshape(-1))
            return out

        if K == 250:
            # mixed-blocks (Rice-or-raw residuals)
            K_global = f.read(1)[0]
            br = BitReader(f)
            vals = np.empty(nvals, dtype=np.int16)
            j = 0
            while j < nvals:
                # read block header
                br.align_to_byte()
                mode = br.read_bytes(1)[0]
                blk_n = struct.unpack("<I", br.read_bytes(4))[0]

                if mode == 1:  # Rice
                    for _ in range(blk_n):
                        vals[j] = rice_read_num(br, K_global); j += 1
                else:          # Verbatim residuals
                    raw = br.read_bytes(blk_n * 2)
                    vals[j:j+blk_n] = np.frombuffer(raw, dtype=np.int16, count=blk_n); j += blk_n

            resid = vals
            nch = params_hint.nchannels
            inter = fixed_pred_inverse_interleaved(resid, nch, order=order)
            frames = inter.reshape(-1, nch)
            if nch==2:
                frames = from_mid_side(frames)
            out = os.path.splitext(ex2_path)[0] + "_Dec.wav"
            write_wav(out, params_hint, frames.reshape(-1))
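            return out

        # any other K marker is not a format this decoder understands
        raise ValueError(f"Unsupported EX2 mode byte K={K}")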

# Signed<->unsigned (ZigZag) for int16 residuals
def _zz16_enc(x):
    s = int(np.int16(x))
    return ((s<<1) ^ (s>>15)) & 0xFFFF
def _zz16_dec(u):
    s = (int(u)>>1) ^ -(int(u)&1)
    return np.int16(s)

# --------------------------
# WAV I/O (16‑bit PCM only)
# --------------------------
def read_wav(path):
    with wave.open(path, "rb") as wf:
        params = wf.getparams()
        nch, sw, fr, nframes, ct, cn = params
        assert sw==2 and ct=="NONE", "Expect 16‑bit PCM WAV"
        data = np.frombuffer(wf.readframes(nframes), dtype=np.int16)
    return params, data

def write_wav(path, params, data_int16):
    nch, sw, fr, _, ct, cn = params
    with wave.open(path, "wb") as wf:
        wf.setnchannels(nch); wf.setsampwidth(sw); wf.setframerate(fr)
        wf.setcomptype(ct, cn); wf.writeframes(np.asarray(data_int16, np.int16).tobytes())

# --------------------------
# building blocks
# --------------------------

# 1) Blocking (works on interleaved samples). We’ll just stream blocks; no per‑block header needed here.
def _iter_blocks(arr, block_size):
    for i in range(0, len(arr), block_size):
        yield arr[i:i+block_size]

# 2) Optional stereo mid/side (lossless integer transform)
def to_mid_side(frames_int16):
    # frames shape: [num_frames, 2]
    L = frames_int16[:,0].astype(np.int32)
    R = frames_int16[:,1].astype(np.int32)
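    M = ((L + R) >> 1).astype(np.int16)     # floor((L+R)/2)
    # NOTE: L - R can need 17 bits; the int16 cast wraps when L - R falls
    # outside [-32768, 32767], and such frames will not round-trip exactly.
    S = (L - R).astype(np.int16)            # difference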
    return np.column_stack([M,S])

def from_mid_side(ms_int16):
    # inverse of the above (FLAC-style)
    M = ms_int16[:,0].astype(np.int32)
    S = ms_int16[:,1].astype(np.int32)
    # L + R = 2*M + (S & 1) because L + R and L - R share parity,
    # so L = M + ceil(S/2) = M + ((S + 1) >> 1); S >> 1 is off by one for odd S
    L = M + ((S + 1) >> 1)
    R = L - S
    return np.column_stack([L.astype(np.int16), R.astype(np.int16)])

# ---- Fixed predictor coefficients (FLAC-style) ----
# pred[n] = c1*x[n-1] + c2*x[n-2] + ... + cN*x[n-N]
# N=1: [1]
# N=2: [2,-1]
# N=3: [3,-3,1]
# N=4: [4,-6,4,-1]
from math import comb
def flac_fixed_coeffs(order:int):
    if order < 1: return [0]   # order 0 would be "no prediction"; not used here
    return [ ((-1)**(k+1)) * comb(order, k) for k in range(1, order+1) ]

def fixed_pred_residual_interleaved(data_int16, nch, order=1):
    frames = data_int16.reshape(-1, nch).astype(np.int16, copy=False)
    coeffs = flac_fixed_coeffs(order)
    p = order
    resid = np.empty_like(frames, dtype=np.int16)

    # warm-up: first p samples per channel stored raw
    resid[:p, :] = frames[:p, :]

    x = frames.astype(np.int32)  # work in 32-bit to avoid overflow while predicting
    for n in range(p, len(frames)):
        pred = np.zeros((nch,), dtype=np.int32)
        for k, c in enumerate(coeffs, start=1):
            pred += c * x[n-k, :]
        # residual = x - pred (wrapped to int16 domain)
        resid[n, :] = np.int16(x[n, :] - pred)
    return resid.reshape(-1)

def fixed_pred_inverse_interleaved(resid_int16, nch, order=1):
    frames = resid_int16.reshape(-1, nch).astype(np.int16, copy=False)
    coeffs = flac_fixed_coeffs(order)
    p = order
    recon = np.empty_like(frames, dtype=np.int16)

    # warm-up: copy first p samples per channel
    recon[:p, :] = frames[:p, :]

    x = recon.astype(np.int32)
    for n in range(p, len(frames)):
        pred = np.zeros((nch,), dtype=np.int32)
        for k, c in enumerate(coeffs, start=1):
            pred += c * x[n-k, :]
        # x = pred + residual (wrapped to int16, then lifted back to int32 history)
        x[n, :] = np.int32(np.int16(pred + frames[n, :]))
        recon[n, :] = np.int16(x[n, :])
    return recon.reshape(-1)


# --------------------------
# Encode/decode (K fixed, block_size, mid/side option)
# --------------------------
        
def _write_ex2_verbatim(path, nvals, payload_int16):
    with open(path, "wb") as f:
        f.write(struct.pack(_HDR, b"EX2\0", 255, nvals))  # K=255 means verbatim
        f.write(np.asarray(payload_int16, np.int16).tobytes())

def _estimate_rice_bits(resid_int16, K):
    x = resid_int16.astype(np.int32)
    u = ((x << 1) ^ (x >> 15)) & 0xFFFF  # zigzag (vectorized for int16)
    return int(np.sum((u >> K) + 1 + K))  # unary(q)=q+1, plus K bits for remainder


# --------------------------
# Tiny driver to produce the table (K=4 and K=2)
# --------------------------
def run_one(wav_path, Ks=(4,2), block_size=4096):
    params,_ = read_wav(wav_path)
    row = {"File": os.path.basename(wav_path), "Original size": os.path.getsize(wav_path)}
    nm = 0
    for K in Ks:
        ex2, p = encode_wav_to_ex2(wav_path, K, block_size, order=4)
        dec = decode_ex2_to_wav(ex2, p, order=4)
        nm+=1
        print(nm)
        print(wav_path + " " + str(K))
        # verify lossless
        _, a = read_wav(wav_path)
        _, b = read_wav(dec)
        assert np.array_equal(a, b), f"Round‑trip mismatch for {wav_path} (K={K})"
        row[f"Rice (K={K})"] = os.path.getsize(ex2)
        row[f"% Compression (K={K})"] = 100.0*(row["Original size"]-row[f"Rice (K={K})"])/row["Original size"]
    return row

# Example batch (put Sound1.wav / Sound2.wav next to the notebook)
targets = ["Sound1.wav","Sound2.wav"]
results=[]
for t in targets:
    if os.path.exists(t):
        results.append(run_one(t, Ks=(4,2), block_size=4096))
    else:
        print(f"⚠️ Missing {t} — skipping")
df = pd.DataFrame(results, columns=["File","Original size","Rice (K=4)","Rice (K=2)","% Compression (K=4)","% Compression (K=2)"])
df