In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
if '..' not in sys.path:
  sys.path.append('..')

import numpy as np
from llm_compressor import AECompressorLLM

In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM

In [4]:
# Load tokenizer and causal LM from the same checkpoint; the name is kept in one
# place so the two can never drift apart.
MODEL_NAME = "ckiplab/gpt2-base-chinese"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Reuse EOS as the padding token only when the tokenizer actually defines an EOS
# token. Assigning `None` here would erase whatever pad token the tokenizer
# shipped with (the generate() warning saved in this notebook indicates that
# `tokenizer.eos_token_id` is None for this checkpoint).
if tokenizer.eos_token_id is not None:
  tokenizer.pad_token_id = tokenizer.eos_token_id
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

In [5]:
test_text = "這是一個測試"

# Tokenise the prompt once and keep the attention mask next to the ids so that
# generate() does not have to guess it.
prompt = tokenizer("這是", return_tensors="pt")
# Drop the last token appended by the tokenizer (presumably the closing [SEP];
# the decoded output keeps the leading [CLS]) so the model continues the text.
prompt_ids = prompt.input_ids[:, :-1]
prompt_mask = prompt.attention_mask[:, :-1]

# The tokenizer may not define an EOS token; in that case use the EOS id from the
# model's generation config, which is what generate() would fall back to anyway
# (with a warning).
pad_id = tokenizer.eos_token_id
if pad_id is None:
  pad_id = model.generation_config.eos_token_id
  if isinstance(pad_id, (list, tuple)):
    pad_id = pad_id[0]

generated_ids = model.generate(
  input_ids=prompt_ids,
  attention_mask=prompt_mask,
  max_new_tokens=20,
  pad_token_id=pad_id,
)
gentext = tokenizer.batch_decode(generated_ids)[0]
gentext

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:102 for open-end generation.


'[CLS] 這 是 一 個 很 好 的 例 子 。 」 他 說 : 「 我 們 不 能 說, 我'

In [6]:
import torch

# Run the model once over the full test string; no gradients are needed.
input_ids = tokenizer(test_text, return_tensors="pt").input_ids
with torch.no_grad():
  logits = model(input_ids).logits.squeeze()  # assumes batch of 1 -> (seq_len, vocab); TODO confirm

# Per-position distribution over the vocabulary.
probs = logits.softmax(dim=1)

# Row i of `next_token_probs` is the distribution paired with token i:
# row 0 is uniform (nothing precedes the first token), and row i >= 1 is the
# model's prediction made at position i - 1 (probs shifted down by one row).
vocab_size = probs.shape[1]
uniform_prob = torch.ones(vocab_size) / vocab_size
next_token_probs = torch.cat((uniform_prob[None, :], probs[:-1]), dim=0)
                               

In [7]:
# Round-trip the token ids through the compressor and report the size of the
# message relative to a fixed-width encoding of the ids.
BITS_PER_TOKEN = 16  # baseline: each token id stored as a 16-bit integer

compressor = AECompressorLLM()
data_ids = input_ids.squeeze().tolist()

msg = compressor.compress(data_ids, next_token_probs)
recon = compressor.decompress(msg, len(data_ids), next_token_probs)

# zip() stops at the shorter sequence, so an element-wise check alone would pass
# for a truncated (or empty) reconstruction; compare the lengths first.
recon_ids = list(recon)
assert len(recon_ids) == len(data_ids), (
  f"decompressed {len(recon_ids)} tokens, expected {len(data_ids)}")
assert all(a == b for a, b in zip(recon_ids, data_ids))

msg_len = len(msg)
data_len = len(data_ids) * BITS_PER_TOKEN
print(f"message length: {msg_len} bits")
print(f"data length: {data_len} bits")
print(f"compress ratio: {msg_len/data_len:.4f}")


message length: 83 bits
data length: 128 bits
compress ratio: 0.6484


In [13]:
import zlib

# Baseline for comparison: zlib-compress the bytes of the generated text
# (str.encode() with its default encoding).
gentext_bytes = gentext.encode()
zmsg = zlib.compress(gentext_bytes)

In [14]:
# Size of the encoded generated text in bits (8 bits per byte).
8 * len(gentext.encode())

600

In [92]:
# zlib compression ratio: compressed byte count over original byte count.
raw_size = len(gentext.encode())
len(zmsg) / raw_size

0.5277777777777778

In [19]:
# Display the list of token ids that was passed to the compressor.
data_ids

[15496, 11, 616, 3290, 318, 13779]

In [23]:
# Fixed-width size of the token ids in bits, at 16 bits per id.
len(data_ids) * 16

96

## Appendix

In [None]:
## Clear float2repr implementation by Copilot
import struct

def float_repr(num):
    """Print the IEEE-754 single-precision fields of ``num``.

    ``num`` is first rounded to float32 by ``struct.pack``. The sign bit, the
    unbiased exponent and the fraction (mantissa bits scaled into ``[0, 1)``)
    of that float32 are printed. For finite values the fields are then checked
    to rebuild the float32 value exactly.

    Args:
        num: A Python float (or int) within float32 range; ``inf`` and ``nan``
            are accepted as well.

    Returns:
        None. The three fields are written to stdout.

    Raises:
        OverflowError: Raised by ``struct.pack`` when a finite ``num`` is too
            large to be stored as a float32.
    """
    # Use explicit little-endian, standard-size formats so the float view and
    # the integer view always describe the same 4 bytes.
    packed = struct.pack('<f', num)
    bits = struct.unpack('<I', packed)[0]
    # The value actually encoded by `bits` (num after rounding to float32).
    rounded = struct.unpack('<f', packed)[0]

    # Split the 32 bits into 1 sign bit, 8 exponent bits and 23 fraction bits.
    sign = bits >> 31
    exp_bits = (bits >> 23) & 0xff
    frac = float(bits & 0x7fffff) / (1 << 23)

    if exp_bits == 0:
        # Zero and subnormals: the exponent is fixed at -126 and there is no
        # implicit leading 1 in the significand.
        exp = -126
        leading = 0
    else:
        # Normal numbers: remove the bias of 127; the significand is 1.frac.
        exp = exp_bits - 127
        leading = 1

    print("sign: ", sign)
    print("exp: ", exp)
    print("frac: ", frac)

    # An all-ones exponent field encodes inf/nan, which cannot be rebuilt from
    # sign * significand * 2**exp, so the round-trip check only covers finite
    # values. It is compared with the float32-rounded value, not with `num`,
    # because the fields describe the float32 encoding.
    if exp_bits != 0xff:
        assert (-1)**sign * (leading + frac) * 2.0**exp == rounded
