In [21]:
import librosa
import numpy as np
import math

audio_path = "lib/audio/61-70968-0000.flac" 
chunks = []

CHUNK_DURATION_SECONDS = 1.0
TARGET_SAMPLE_RATE = 8000

# resample audio to mono
data_float, sr = librosa.load(audio_path, sr=TARGET_SAMPLE_RATE, mono=True)

if 'data_float' in locals():
    print(f"File loaded and resampled to {sr} Hz, Mono.")

    # convert from 32-bit float [-1.0, 1.0] to 8-bit unsigned [0, 255]
    data_8bit = ((data_float + 1.0) * 0.5 * 255).astype(np.uint8)
    
    print(f"Converted to 8-bit. Total bytes/samples: {len(data_8bit)}")

    samples_per_chunk = int(TARGET_SAMPLE_RATE * CHUNK_DURATION_SECONDS) 
    total_samples = len(data_8bit)
    num_chunks = math.ceil(total_samples / samples_per_chunk)

    print(f"Splitting into {num_chunks} chunks")

    for i in range(num_chunks):
        start_sample = i * samples_per_chunk
        end_sample = start_sample + samples_per_chunk
        audio_slice = data_8bit[start_sample:end_sample]
        frames_as_bytes = audio_slice.tobytes()
        chunks.append(frames_as_bytes)

print(f"Created {len(chunks)} chunks")
print(f"Bytes per chunk (approx): {len(chunks[0]) if chunks else 0}")

File loaded and resampled to 8000 Hz, Mono.
Converted to 8-bit. Total bytes/samples: 39240
Splitting into 5 chunks
Created 5 chunks
Bytes per chunk (approx): 8000


In [15]:
for chunk in chunks:
    print(chunk)

b'\xcf\xfe\xc5\xfe\xd7\xfe\xf3\xfe\xeb\xfe\xeb\xfe\xc6\xfe\xbd\xfe\x8f\xfet\xfer\xfe@\xfe:\xfe\x12\xfe\xe6\xfd\x03\xfe\xea\xfd\xda\xfd\xa3\xfd\xa7\xfd\xb4\xfd\xfb\xfd\xf4\xfd\xcb\xfd\xd5\xfd\xfa\xfd\xfc\xfd\x05\xfe\x1a\xfe\xf1\xfd\x0b\xfe\xd5\xfd\xac\xfd\xc1\xfd\xc1\xfd\xc5\xfd\xb8\xfd\x80\xfdr\xfd\x8c\xfd\x9a\xfd\x80\xfd\x9c\xfd\xaa\xfd\xbf\xfd\xbe\xfd\xb3\xfd\xa1\xfd\xb5\xfd\xc6\xfd\xbc\xfd\xa3\xfd\xe7\xfd\xe4\xfd\xce\xfd\xde\xfd\xb7\xfd\xb4\xfd\xdc\xfd)\xfeN\xfeI\xfe4\xfe&\xfe>\xfe1\xfe\x0b\xfe\'\xfeH\xfe5\xfe-\xfeM\xfe\x05\xfe\xf6\xfd\xe0\xfd\xff\xfd\x02\xfe,\xfe\x17\xfe\xf3\xfd\xf1\xfd\xe3\xfd\x07\xfe1\xfe:\xfeJ\xfeo\xfeS\xfeE\xfeC\xfe=\xfeZ\xfe~\xfer\xfe\x82\xfe{\xfe~\xfe\\\xfe_\xfeR\xfeM\xfeb\xfee\xfe_\xfe\x82\xfe\x93\xfed\xfe\\\xfeD\xfe4\xfe6\xfe4\xfe_\xfe\x83\xfe\x9e\xfe\xac\xfe\xbf\xfe\xb5\xfe\xc4\xfe\xd5\xfe\xbb\xfe\xb6\xfe\xfb\xfe)\xff7\xff"\xff#\xff4\xff)\xff(\xffD\xffv\xff\xb7\xff\xd7\xff\xab\xff\x7f\xff\x83\xff\xb0\xff\xda\xff\xc5\xff\xe8\xff\xf4\xff\xe1\xff\xdf\xff\xf3\

In [4]:
# Making the model

import torch

from lib.bgpt.config import *
from lib.bgpt.utils import bGPTLMHeadModel

from transformers import  GPT2Config


if torch.cuda.is_available():    
   device = torch.device("cuda")
else:
   device = torch.device("cpu")


patch_config = GPT2Config(num_hidden_layers=PATCH_NUM_LAYERS, 
                  max_length=PATCH_LENGTH, 
                  max_position_embeddings=PATCH_LENGTH,
                  hidden_size=HIDDEN_SIZE,
                  n_head=HIDDEN_SIZE//64,
                  vocab_size=1)

byte_config = GPT2Config(num_hidden_layers=BYTE_NUM_LAYERS, 
                  max_length=PATCH_SIZE+1, 
                  max_position_embeddings=PATCH_SIZE+1,
                  hidden_size=HIDDEN_SIZE,
                  n_head=HIDDEN_SIZE//64,
                  vocab_size=257)

model = bGPTLMHeadModel(patch_config, byte_config)


model_weights_path = "pretrained/weights-audio.pth"

checkpoint = torch.load(model_weights_path, map_location=torch.device(device))
model.load_state_dict(checkpoint['model'], strict=False)
model = model.to(device)
model.eval()

  from .autonotebook import tqdm as notebook_tqdm


bGPTLMHeadModel(
  (patch_level_decoder): PatchLevelDecoder(
    (patch_embedding): Linear(in_features=4112, out_features=768, bias=True)
    (base): GPT2Model(
      (wte): Embedding(1, 768)
      (wpe): Embedding(512, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0-11): 12 x GPT2Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2Attention(
            (c_attn): Conv1D(nf=2304, nx=768)
            (c_proj): Conv1D(nf=768, nx=768)
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D(nf=3072, nx=768)
            (c_proj): Conv1D(nf=768, nx=3072)
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (ln_f): LayerNorm((768,), eps=1e-05

In [5]:
import logging
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import Iterator
from arithmetic_coder import arithmetic_coder, ac_utils

In [6]:

logger = logging.getLogger()
logger.setLevel(logging.INFO)

PATCH_SIZE = 16


class Metric:
    def __init__(self):
        self.total_length = 0
        self.compressed_length = 0

    def compute_ratio(self):
        if self.total_length != 0 and self.compressed_length != 0:
            return (
                self.total_length / self.compressed_length,
                self.compressed_length / self.total_length,
            )
        else:
            return 0, 0

    def accumulate(self, compressed, original):
        if isinstance(compressed, list):
            self.compressed_length += len(compressed)
        elif isinstance(compressed, int):
            self.compressed_length += compressed
        else:
            raise ValueError(f"Unsupported compressed length type: {type(compressed)}")

        if isinstance(original, list):
            self.total_length += len(original)
        elif isinstance(original, int):
            self.total_length += original
        else:
            raise ValueError(f"Unsupported original length type: {type(original)}")


def compress(compress_input, logits, metric):
    """
    :param compress_input: symbols to be compressed
    :param logits: generation probabilities from the model
    :param metric: compression metrics
    :return: compressed result, a floating number
    """
    output = []
    # Initialize a Encoder Object
    # Precision is for the encoder, not the model
    # You must have the same precision for encoder and decoder
    # Tricky things here: Though theoratically prefill == decode, but in practice there are numerical problems
    encoder = arithmetic_coder.Encoder(
        base=2,
        precision=64,
        output_fn=output.append,
    )
    # the first symbol should be saved for generation in decoding
    start_symbol = compress_input[:, :1]
    probs = logits.softmax(dim=-1).to(torch.float32)
    pd = torch.gather(probs, dim=-1, index=compress_input[:, 1:].unsqueeze(-1)).squeeze(
        -1
    )

    probs = np.vstack(probs.detach().cpu().numpy().squeeze())

    sequence_array = compress_input[:, 1:].detach().cpu().numpy().reshape(-1)

    pd = pd.squeeze()

    # compress the sequence
    for symbol, prob, pd_prob in zip(sequence_array, probs, pd):
        encoder.encode(
            ac_utils.normalize_pdf_for_arithmetic_coding(prob, np.float32), symbol
        )
    encoder.terminate()

    # to visualize and compute metrics, map to str
    compressed_bits = "".join(map(str, output))
    # you can only save in bytes, so need to pad some bits
    compressed_bytes, num_padded_bits = ac_utils.bits_to_bytes(compressed_bits)
    metric.accumulate(len(compressed_bytes) + num_padded_bits, len(sequence_array))

    compress_rate, compress_ratio = metric.compute_ratio()
    logger.info(f"compressed length: {metric.compressed_length}")
    logger.info(f"original length: {metric.total_length}")
    logger.info(f"compression ratio: {compress_ratio:.6f}")
    logger.info(f"compression rate: {compress_rate:.6f}")

    return compressed_bytes, num_padded_bits, start_symbol, sequence_array, pd, probs


def decode(
    compressed_bytes,
    num_padded_bits,
    model,
    start_symbol,
    device,
    original_seq_len,
    original_sequence=None,
    pd=None,
    probs=None,
    do_test=True,
):
    """

    :param compressed_bytes:  compressed data
    :param num_padded_bits:  padded bits
    :param model: same model as encoder
    :param start_symbol: first symbol to generate
    :param original_sequence: original symbol sequence, for testing purpose
    :param pd: actually not needed, used for testing
    :param probs:
    :param device:
    :return:
    """
    # convert bytes back to bit stream
    data_iter = iter(
        ac_utils.bytes_to_bits(compressed_bytes, num_padded_bits=num_padded_bits)
    )

    # utils function to read bits
    def _input_fn(bit_sequence: Iterator[str] = data_iter) -> int | None:
        try:
            return int(next(bit_sequence))
        except StopIteration:
            return None

    # initialize a Decoder Object
    decoder = arithmetic_coder.Decoder(
        base=2,
        precision=64,
        input_fn=_input_fn,
    )

    sequence_array_de = start_symbol.squeeze(0).detach().cpu().numpy()
    
    current_patch_num = 0

    start_patch_list = np.pad(sequence_array_de, (0, PATCH_SIZE - 1), 'constant', constant_values=0)
    start_patch_tensor = torch.tensor([start_patch_list], dtype=torch.long, device=device).unsqueeze(0) # [1, 1, S]

    with torch.no_grad():
        encoded_patches_history = model.patch_level_decoder(start_patch_tensor)["last_hidden_state"]

    tokens_in_current_patch = torch.tensor([model.special_token_id], device=device)

    for i in range(original_seq_len):
        with torch.no_grad():

            current_patch_feature = encoded_patches_history[0, -1, :]
            
            # 2. Call generation function
            prob_de_next = model.byte_level_decoder.generate(
                current_patch_feature, 
                tokens_in_current_patch
            ).cpu().numpy()
        

        # decode the next token using the arithmetic decoder
        de_token = decoder.decode(
            ac_utils.normalize_pdf_for_arithmetic_coding(prob_de_next, np.float32)
        )
        
        # add the decoded token to our sequences
        sequence_array_de = np.append(sequence_array_de, de_token)
        tokens_in_current_patch = torch.cat((
            tokens_in_current_patch, 
            torch.tensor([de_token], device=device)
        ), dim=0)

        # check if we just finished a patch
        # The byte_level_decoder input has a special token, so S+1
        if len(tokens_in_current_patch) == PATCH_SIZE + 1:
            print(f"--- Finished patch {current_patch_num}, starting next ---")

            # Get all bytes decoded so far
            current_len = len(sequence_array_de)
            num_to_pad = (PATCH_SIZE - (current_len % PATCH_SIZE)) % PATCH_SIZE
            padded_byte_list = np.pad(sequence_array_de, (0, num_to_pad), 'constant', constant_values=0)
            
            # Create 'patches' tensor
            patches_so_far = torch.tensor([padded_byte_list], dtype=torch.long, device=device).reshape(1, -1, PATCH_SIZE)
            
            with torch.no_grad():
                encoded_patches_history = model.patch_level_decoder(patches_so_far)["last_hidden_state"]
            
            # Reset the byte-level token buffer
            tokens_in_current_patch = torch.tensor([model.special_token_id], device=device)
            current_patch_num += 1

        if do_test:
            prob_orig_i = probs[i] 
            top_indices_de = prob_de_next.argsort()[-5:][::-1]
            top_indices = prob_orig_i.argsort()[-5:][::-1]
            original_token = original_sequence[i]
            
            if original_token != de_token:
                print(f"!!! MISMATCH AT INDEX {i} !!!")
                print(f"Should be: {original_token}, Got: {de_token}")
                import pdb
                pdb.set_trace()
    
    # Return decoded sequence
    final_decoded_tensor = torch.tensor(sequence_array_de, dtype=torch.long, device=device).unsqueeze(0)
    return final_decoded_tensor

In [7]:
def write_padded_bytes(filename: str, data: bytes, num_padded_bits: int, original_length: int):
    """
    file format:
    - 1 byte: number of padded bits (0-7)
    - 2 bytes: original length (max 65535)
    - 1 byte: start_symbol (0-255)
    - subsequent bytes: actual bytes data

    :param filename: output file name
    :param data: bytes data to write
    :param padding_bits: number of padded bits (must be between 0 and 7)
    :param original_length: original length of the uncompressed data (in tokens)
    """

    if not 0 <= num_padded_bits <= 7:
        raise ValueError("num_padded_bits must be between 0 and 7.")

    if not 0 <= original_length <= 65535:
        raise ValueError("original_length must be between 0 and 65535.")

    if not isinstance(data, bytes):
        raise TypeError("data must be of bytes type.")

    with open(filename, 'wb') as f:
        padding_byte = num_padded_bits.to_bytes(1, 'big')
        f.write(padding_byte)
        f.write(original_length.to_bytes(2, 'big'))
        f.write(data)

def read_padded_bytes(filename: str) -> tuple[bytes, int]:
    """
    Read data and padding bits from a file.

    :param filename: The name of the file to read.
    :return: A tuple containing (bytes data, number of padded bits).
             May raise an error if the file is empty or improperly formatted.
    """

    with open(filename, 'rb') as f:
        # the first byte indicates the number of padded bits
        padding_byte = f.read(1)

        # If the file is empty, f.read(1) will return an empty bytes object b''
        if not padding_byte:
            raise EOFError("File is empty or improperly formatted: unable to read padding bits byte.")

        original_length_bytes = f.read(2)
        if not original_length_bytes:
            raise EOFError("File is empty or improperly formatted: unable to read original length bytes.")
    
        padding_bits = int.from_bytes(padding_byte, 'big')
        original_length = int.from_bytes(original_length_bytes, 'big')

        data = f.read()
        
        return data, padding_bits, original_length

In [24]:
PATCH_SIZE = 16 

logger = logging.getLogger()
logger.setLevel(logging.INFO)

compression_metric = Metric()
print(f"\n=== Starting In-Memory Compression Test ===")

# Loop through each audio chunk
for i, chunk_bytes in enumerate(chunks):
    if not chunk_bytes:
        print(f"Skipping empty chunk {i}")
        continue

    byte_list = list(chunk_bytes)
    original_len = len(byte_list)
    
    # This tensor holds all the bytes for checking
    input_tensor = torch.tensor([byte_list], dtype=torch.long, device=device)

    print(f"--- Processing Chunk {i} ({original_len} bytes) ---")

    # --- Start of new autoregressive compression ---
    output = [] # holds compressed bits 
    encoder = arithmetic_coder.Encoder(
        base=2,
        precision=64,
        output_fn=output.append,
    )
    
    start_symbol_int = byte_list[0] # save first byte 
    start_symbol_tensor = torch.tensor([[start_symbol_int]], dtype=torch.long, device=device)
    
    # 'sequence_array' is the sequence we are compressing (all bytes after first)
    sequence_array = np.array(byte_list[1:])
    original_seq_len_to_compress = len(sequence_array)

    sequence_so_far = np.array([start_symbol_int])
    
    # Pad the start symbol to a full patch
    start_patch_list = np.pad(sequence_so_far, (0, PATCH_SIZE - 1), 'constant', constant_values=0)
    start_patch_tensor = torch.tensor([start_patch_list], dtype=torch.long, device=device).unsqueeze(0) # [1, 1, S]
    
    with torch.no_grad():
        encoded_patches_history = model.patch_level_decoder(start_patch_tensor)["last_hidden_state"]
        
    tokens_in_current_patch = torch.tensor([model.special_token_id], device=device)

    # Loop for each byte we need to compress
    for j in range(original_seq_len_to_compress):
        with torch.no_grad():
            current_patch_feature = encoded_patches_history[0, -1, :]
            
            # get the probability for the next byte
            prob_de_next = model.byte_level_decoder.generate(
                current_patch_feature, 
                tokens_in_current_patch
            ).cpu().numpy()

        # Get the (actual) next byte we are trying to compress
        symbol_to_compress = sequence_array[j]
        
        # encode that byte using the probabilities we just generated
        encoder.encode(
            ac_utils.normalize_pdf_for_arithmetic_coding(prob_de_next, np.float32), 
            symbol_to_compress
        )
        
        sequence_so_far = np.append(sequence_so_far, symbol_to_compress)
        tokens_in_current_patch = torch.cat((
            tokens_in_current_patch, 
            torch.tensor([symbol_to_compress], device=device)
        ), dim=0)

        # check if we just finished a patch
        if len(tokens_in_current_patch) == PATCH_SIZE + 1:
            current_len = len(sequence_so_far)
            num_to_pad = (PATCH_SIZE - (current_len % PATCH_SIZE)) % PATCH_SIZE
            padded_byte_list = np.pad(sequence_so_far, (0, num_to_pad), 'constant', constant_values=0)
            
            patches_so_far = torch.tensor([padded_byte_list], dtype=torch.long, device=device).reshape(1, -1, PATCH_SIZE)
            
            with torch.no_grad():
                encoded_patches_history = model.patch_level_decoder(patches_so_far)["last_hidden_state"]
            
            tokens_in_current_patch = torch.tensor([model.special_token_id], device=device)
    
    encoder.terminate()
    
    compressed_bits = "".join(map(str, output))
    compressed_bytes, num_padded_bits = ac_utils.bits_to_bytes(compressed_bits)
    
    compression_metric.accumulate(len(compressed_bytes), len(sequence_array))

    print(f"Chunk {i} compressed.")
    
    print(f"Decompressing and verifying chunk {i}...")
    decoded_tensor = decode(
        compressed_bytes=compressed_bytes,
        num_padded_bits=num_padded_bits,
        model=model,
        start_symbol=start_symbol_tensor,
        device=device,
        original_seq_len=len(sequence_array),
        original_sequence=sequence_array, 
        probs=None,                    
        do_test=False                     
    )
    
    print(f"Chunk {i} decompressed and verified.")


=== Starting In-Memory Compression Test ===
--- Processing Chunk 0 (8000 bytes) ---
Chunk 0 compressed.
Decompressing and verifying chunk 0...
--- Finished patch 0, starting next ---
--- Finished patch 1, starting next ---
--- Finished patch 2, starting next ---
--- Finished patch 3, starting next ---
--- Finished patch 4, starting next ---
--- Finished patch 5, starting next ---
--- Finished patch 6, starting next ---
--- Finished patch 7, starting next ---
--- Finished patch 8, starting next ---
--- Finished patch 9, starting next ---
--- Finished patch 10, starting next ---
--- Finished patch 11, starting next ---
--- Finished patch 12, starting next ---
--- Finished patch 13, starting next ---
--- Finished patch 14, starting next ---
--- Finished patch 15, starting next ---
--- Finished patch 16, starting next ---
--- Finished patch 17, starting next ---
--- Finished patch 18, starting next ---
--- Finished patch 19, starting next ---
--- Finished patch 20, starting next ---
--- F

In [25]:
print(decoded_tensor)

tensor([[127, 127, 127,  ..., 127, 128, 127]])


In [26]:
final_rate, final_ratio = compression_metric.compute_ratio()

print("=== FINAL COMPRESSION RESULTS ===")
print(f"Total Original Bytes:   {compression_metric.total_length}")
print(f"Total Compressed Bytes: {compression_metric.compressed_length}")
print(f"Final Compression Ratio:  {final_ratio:.6f}")
print(f"Final Compression Rate:   {final_rate:.6f}x")

=== FINAL COMPRESSION RESULTS ===
Total Original Bytes:   39235
Total Compressed Bytes: 21308
Final Compression Ratio:  0.543087
Final Compression Rate:   1.841327x
