In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
if '../src' not in sys.path:
  sys.path.append('../src')
import os
os.environ["BNB_CUDA_VERSION"] = "115"
import numpy as np
from llm_compressor import AECompressorLLM
from transformers import AutoTokenizer, AutoModelForCausalLM

In [3]:
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id
model = AutoModelForCausalLM.from_pretrained("gpt2")


BNB_CUDA_VERSION=XXX can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.
If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=
If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH
For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<path_to_cuda_dir/lib64
Loading CUDA version: BNB_CUDA_VERSION=115


  warn((f'\n\n{"="*80}\n'


In [4]:

# prompt_ids = tokenizer("This", return_tensors="pt").input_ids
# gentext = tokenizer.batch_decode(
#             model.generate(input_ids=prompt_ids, 
#             max_new_tokens=20, pad_token_id=tokenizer.eos_token_id))[0]
# gentext

In [5]:
import torch
test_text = "The hypernym of cat is animal"
input_ids = tokenizer(test_text, return_tensors="pt").input_ids
# input_ids = input_ids[:, :-1]
with torch.no_grad():
  logits = model(input_ids).logits.squeeze()
probs = torch.softmax(logits, dim=1)
uniform_prob = torch.ones(probs.shape[1]) / probs.shape[1]
next_token_probs = torch.concat([uniform_prob.unsqueeze(0), probs[:-1, :]], dim=0)
uniform_nt_probs = torch.ones_like(probs) / probs.shape[1]

In [6]:
# next_token_probs: (seq_len, vocab_size)
# input_ids: (batch_size, seq_len)
print(input_ids.squeeze().unsqueeze(1).shape)
print(next_token_probs.gather(dim=1, index=input_ids.squeeze().unsqueeze(1)).squeeze())
print(tokenizer.convert_ids_to_tokens(input_ids.squeeze()))

torch.Size([8, 1])
tensor([1.9898e-05, 4.2627e-05, 1.1066e-05, 1.2249e-01, 3.8000e-02, 1.4243e-04,
        1.6088e-03, 1.0009e-04])
['The', 'Ġhyper', 'ny', 'm', 'Ġof', 'Ġcat', 'Ġis', 'Ġanimal']


In [7]:
compressor = AECompressorLLM()
data_ids = input_ids.squeeze().tolist()

msg = compressor.compress(data_ids, next_token_probs)
recon = compressor.decompress(msg, len(data_ids), next_token_probs)
## uniform probs baseline
# msg = compressor.compress(data_ids, uniform_nt_probs)
# recon = compressor.decompress(msg, len(data_ids), uniform_nt_probs)

assert all(a==b for a, b in zip(recon, data_ids))
msg_len = len(msg)
data_len = len(data_ids) * 16
print(f"message length: {msg_len} bits")
print(f"data length: {data_len} bits")
print(f"compress ratio: {msg_len/data_len:.4f}")


message length: 90 bits
data length: 128 bits
compress ratio: 0.7031


In [11]:
dbg = compressor.encoder.dbg

In [12]:
dbg = compressor.encoder.dbg
dbg
# prob = dbg["prob"]
# sym = dbg["symbol"]
# total = dbg["total"]
# r = dbg["range"]
# prob[sym], total, r

{'prob': [[0, -2147483648),
  [-2147483648, 0),
  [0, -2147483648),
  [-2147483648, 0),
  [0, -2147483648),
  [-2147483648, 0),
  [0, -2147483648),
  [-2147483648, 0),
  [0, -2147483648),
  [-2147483648, 0),
  [0, -2147483648),
  [-2147483648, 0),
  [0, -2147483648),
  [-2147483648, 0),
  [0, -2147483648),
  [-2147483648, 0),
  [0, -2147483648),
  [-2147483648, 0),
  [0, -2147483648),
  [-2147483648, 0),
  [0, -2147483648),
  [-2147483648, 0),
  [0, -2147483648),
  [-2147483648, 0),
  [0, -2147483648),
  [-2147483648, 0),
  [0, -2147483648),
  [-2147483648, 0),
  [0, -2147483648),
  [-2147483648, 0),
  [0, -2147483648),
  [-2147483648, 0),
  [0, -2147483648),
  [-2147483648, 0),
  [0, -2147483648),
  [-2147483648, 0),
  [0, -2147483648),
  [-2147483648, 0),
  [0, -2147483648),
  [-2147483648, 0),
  [0, -2147483648),
  [-2147483648, 0),
  [0, -2147483648),
  [-2147483648, 0),
  [0, -2147483648),
  [-2147483648, 0),
  [0, -2147483648),
  [-2147483648, 0),
  [0, -2147483648),
  [-21474836

In [None]:
import zlib
zmsg = zlib.compress(test_text.encode())

In [None]:
len(zmsg)/len(test_text.encode())

1.2758620689655173

## Appendix

In [None]:
## clear float2repr implmentation by Copilot
import struct

def float_repr(num):
    # pack the float into a bytes object
    packed = struct.pack('f', num)
    
    # unpack the bytes object to get the exponent and fractional part
    bits = struct.unpack('I', packed)[0]    
    sign = bits >> 31
    exp = (bits >> 23) & 0xff
    frac = bits & 0x7fffff
    
    # convert the exponent to a signed integer
    if exp == 0:
        exp = -126
    else:
        exp -= 127
    
    # convert the fractional part to a float
    frac = float(frac) / (1 << 23)
    
    # apply the sign, exponent, and fractional part to get the final representation
    print("sign: ", sign)
    print("exp: ", exp)
    print("frac: ", frac)
    assert (-1)**sign * (1 + frac) * 2**exp == num
