In [1]:
import torch
import torch._prims_common as utils
import torch.utils._pytree as pytree
from torch.library import impl, Library
import lovely_tensors as lt
lt.monkey_patch()

def down_size(size, factor: int = 4):
    """Return the packed shape for ``size`` when ``factor`` values share one element.

    Args:
        size: a shape (tuple of ints or ``torch.Size``).
        factor: number of values packed into each element of the last dim;
            defaults to 4 (four 2-bit codes per uint8 byte).

    Returns:
        A tuple equal to ``size`` with the last dimension divided by ``factor``.

    Raises:
        AssertionError: if the last dimension is not divisible by ``factor``.
    """
    assert size[-1] % factor == 0, f"{size} last dim not divisible by {factor}"
    return (*size[:-1], size[-1] // factor)

def up_size(size, factor: int = 4):
    """Return the unpacked shape for ``size`` when each element holds ``factor`` values.

    Args:
        size: a shape (tuple of ints or ``torch.Size``).
        factor: number of values stored in each element of the last dim;
            defaults to 4 (four 2-bit codes per uint8 byte).

    Returns:
        A tuple equal to ``size`` with the last dimension multiplied by ``factor``.
    """
    return (*size[:-1], size[-1] * factor)

def unpack_uint2(uint8_data) -> torch.Tensor:
    """Unpack four 2-bit codes from every byte of ``uint8_data``.

    Codes are read most-significant bits first (bits 7-6, 5-4, 3-2, 1-0),
    which is the order ``pack_uint2`` writes them in.

    Args:
        uint8_data: integer tensor of packed bytes with shape ``(..., n)``.

    Returns:
        uint8 tensor of shape ``(..., 4 * n)`` with values in ``[0, 3]``.
    """
    # One shift per 2-bit slot, highest slot first. Broadcasting against a new
    # trailing axis extracts all four slots at once -> shape (..., n, 4).
    shifts = torch.tensor([6, 4, 2, 0], dtype=uint8_data.dtype, device=uint8_data.device)
    codes = (uint8_data.unsqueeze(-1) >> shifts) & 0b11
    # Merge the trailing (n, 4) axes so each byte's four codes stay adjacent.
    return codes.flatten(-2).to(torch.uint8)

def pack_uint2(uint8_data) -> torch.Tensor:
    """Pack each group of four 2-bit codes along the last dim into one uint8 byte.

    Within a group, the first code goes into the most-significant bits (7-6)
    and the fourth into the least-significant bits (1-0).

    Args:
        uint8_data: integer tensor of shape ``(..., n)`` with ``n % 4 == 0``;
            values are expected to lie in ``[0, 3]``.

    Returns:
        uint8 tensor of shape ``(..., n // 4)``.

    Raises:
        AssertionError: if the last dimension is not divisible by 4.
    """
    shape = uint8_data.shape
    assert shape[-1] % 4 == 0, f"{shape} last dim not divisible by four"
    # Convert to uint8 so the result is one byte per four codes whatever the
    # input dtype, and mask to 2 bits so an out-of-range value cannot spill
    # into a neighbouring code's bit slot.
    codes = (uint8_data.to(torch.uint8) & 0b11).reshape(*shape[:-1], shape[-1] // 4, 4)
    return (codes[..., 0] << 6) | (codes[..., 1] << 4) | (codes[..., 2] << 2) | codes[..., 3]

In [2]:
### uint4 counterpart (two 4-bit codes per byte), kept here for reference:
def pack_uint4(uint8_data) -> torch.Tensor:
    """Pack each pair of 4-bit codes along the last dim into one uint8 byte.

    The first code of a pair goes into the high nibble, the second into the
    low nibble.

    Args:
        uint8_data: integer tensor of shape ``(..., n)`` with ``n % 2 == 0``;
            values are expected to lie in ``[0, 15]``.

    Returns:
        uint8 tensor of shape ``(..., n // 2)``.

    Raises:
        AssertionError: if the last dimension is not divisible by 2.
    """
    shape = uint8_data.shape
    assert shape[-1] % 2 == 0, f"{shape} last dim not divisible by two"
    # down_size() in this notebook divides by 4 (the uint2 factor), so the
    # packed shape is computed here with the uint4 factor of 2 instead.
    # Convert to uint8 and mask to 4 bits so values cannot overlap nibbles.
    nibbles = (uint8_data.to(torch.uint8) & 0b1111).reshape(*shape[:-1], shape[-1] // 2, 2)
    return (nibbles[..., 0] << 4) | nibbles[..., 1]

In [3]:
### uint4 counterpart (two 4-bit codes per byte), kept here for reference:
def unpack_uint4(uint8_data) -> torch.Tensor:
    """Unpack two 4-bit codes from every byte of ``uint8_data``, high nibble first.

    Args:
        uint8_data: integer tensor of packed bytes with shape ``(..., n)``.

    Returns:
        uint8 tensor of shape ``(..., 2 * n)`` with values in ``[0, 15]``.
    """
    # Shift elements down 4 (high nibble) or 0 (low nibble) and keep the bottom
    # 4 bits; broadcasting over a new trailing axis gives shape (..., n, 2).
    shifts = torch.tensor([4, 0], dtype=uint8_data.dtype, device=uint8_data.device)
    nibbles = (uint8_data.unsqueeze(-1) >> shifts) & 0b1111
    # up_size() in this notebook multiplies by 4 (the uint2 factor); flattening
    # the trailing (n, 2) axes yields the correct 2 * n length.
    return nibbles.flatten(-2).to(torch.uint8)

In [4]:
# Bytes needed for the (1024, 16, 8) uint2 test tensor: four 2-bit codes per byte
(1024 * 16 * 8) / 4

32768.0

In [5]:
# Round-trip random 2-bit codes through pack_uint2 / unpack_uint2.
test_tensor = torch.randint(0, 3, (1024, 16, 8), dtype=torch.uint8)
print(test_tensor)
packed = pack_uint2(test_tensor)
unpacked = unpack_uint2(packed)
# The codes are integers, so require exact equality instead of allclose tolerances.
roundtrip_ok = torch.equal(unpacked, test_tensor)
print(roundtrip_ok)
assert roundtrip_ok
# Four codes per byte: the last dim shrinks from 8 to 2.
assert packed.shape == (1024, 16, 2), packed.shape

# randint(0, 3) only draws {0, 1, 2}; also check code 3 (0b11), the top of the 2-bit range.
all_codes = torch.arange(4, dtype=torch.uint8).repeat(2)
assert torch.equal(unpack_uint2(pack_uint2(all_codes)), all_codes)

tensor[1024, 16, 8] u8 n=131072 (0.1Mb) x∈[0, 2] μ=0.999 σ=0.815
True


In [6]:
def roundclip(x, a, b):
    """Round ``x`` to the nearest integer, then clip the result to ``[a, b]``.

    Computes ``max(a, min(b, round(x)))`` element-wise.

    Args:
        x: tensor to round (floating point expected).
        a: lower bound (Python number).
        b: upper bound (Python number).

    Returns:
        Tensor with the same shape as ``x``.
    """
    # clamp takes the bounds as plain numbers, so no temporary tensors are
    # allocated for ``a`` and ``b``.
    return torch.clamp(torch.round(x), min=a, max=b)

def quantize_per_tensor_uint2_trinary(weights, eps: float = 1e-8, return_scale: bool = False):
    """Quantize ``weights`` to ternary codes stored as uint8 values in {0, 1, 2}.

    The tensor is divided by its mean absolute value ``gamma`` (plus ``eps``),
    rounded and clipped to {-1, 0, +1}, then shifted by +1 so the codes are
    non-negative and fit an unsigned 2-bit field.

    Args:
        weights: floating-point tensor of any shape.
        eps: added to ``gamma`` so an all-zero tensor does not divide by zero.
        return_scale: if True, also return the scale ``gamma + eps`` that was
            divided out; it is needed to dequantize, approximately
            ``weights ~= (codes.float() - 1) * scale``.

    Returns:
        uint8 tensor with the same shape as ``weights`` holding codes in
        {0, 1, 2}; or a ``(codes, scale)`` tuple when ``return_scale`` is True.
    """
    # Compute the average absolute value of the weight tensor
    gamma = torch.mean(torch.abs(weights))
    scale = gamma + eps

    # Scale, then round each weight to the nearest integer in {-1, 0, +1}
    quantized_weights = roundclip(weights / scale, -1, 1)

    # Shift {-1, 0, 1} -> {0, 1, 2} so we can pack into a uint and not deal with signs
    codes = (quantized_weights + 1.0).to(torch.uint8)
    if return_scale:
        return codes, scale
    return codes

In [7]:
# Synthetic layer weights, uniform over [-250, 250)
test_layer = torch.rand(1024, 16, 8).mul(500.0).sub(250.0)
test_layer

tensor[1024, 16, 8] n=131072 (0.5Mb) x∈[-250.000, 249.996] μ=-0.710 σ=144.426

In [8]:
quantized_fake_layer = quantize_per_tensor_uint2_trinary(test_layer)
print(quantized_fake_layer)
# Invariants: shape is preserved and every code is a shifted ternary value.
assert quantized_fake_layer.shape == test_layer.shape
assert set(quantized_fake_layer.unique().tolist()) <= {0, 1, 2}

tensor[1024, 16, 8] u8 n=131072 (0.1Mb) x∈[0, 2] μ=0.999 σ=0.867


In [9]:
packed = pack_uint2(quantized_fake_layer)
unpacked = unpack_uint2(packed)
# The codes are integers, so require exact equality instead of allclose tolerances.
roundtrip_ok = torch.equal(unpacked, quantized_fake_layer)
print(roundtrip_ok)
assert roundtrip_ok
# Four 2-bit codes per byte: the packed tensor holds a quarter as many elements.
assert packed.numel() * 4 == quantized_fake_layer.numel()

True
