In [4]:
import torch
import torch._prims_common as utils
import torch.utils._pytree as pytree
from torch.library import impl, Library
from functools import reduce
import lovely_tensors as lt
lt.monkey_patch()

In [74]:
def unpack_uint4(uint_data) -> torch.Tensor:
    """Expand a tensor of packed 4-bit fields into one element per field.

    Every element of ``uint_data`` is read as ``element_size() * 8 // 4``
    consecutive 4-bit fields, most-significant field first (the layout
    written by ``pack_uint4`` when its ``dtype_size`` equals the element
    bit width).

    Args:
        uint_data: integer tensor with at least one dimension.

    Returns:
        A tensor with the same dtype and device as ``uint_data`` whose last
        dimension is ``scale`` times longer. Every value is in ``0..15``.
        Any padding fields added while packing are still present; the
        caller slices them off.
    """
    bits = uint_data.element_size() * 8
    scale = bits // 4  # how many 4-bit fields fit in one element
    # Build the result from the input itself so it inherits dtype and device
    # (a bare torch.empty would land on the default device instead).
    # The mask also discards sign-extension bits from `>>` on signed dtypes.
    fields = torch.stack(
        [(uint_data >> (bits - 4 * (i + 1))) & 0b1111 for i in range(scale)],
        dim=-1,
    )
    shape = uint_data.shape
    # Fold the new trailing axis into the last dimension.
    return fields.reshape(*shape[:-1], shape[-1] * scale)


def pack_uint4(uint_data, dtype_size=8) -> torch.Tensor:
    """Pack 4-bit values along the last dimension into wider integers.

    Consecutive groups of ``dtype_size // 4`` elements are combined into a
    single element, with the first element of each group in the
    most-significant 4-bit field. If the last dimension is not a multiple
    of the group size it is right-padded with zeros first.

    Args:
        uint_data: integer tensor with at least one dimension. Only the low
            4 bits of each element are used.
        dtype_size: number of bits to fill per output element. For the
            result to round-trip through ``unpack_uint4`` this must equal
            the bit width of ``uint_data``'s dtype (8 for uint8, 16 for
            int16, ...).

    Returns:
        A tensor with the same dtype and device as ``uint_data`` whose last
        dimension is ``ceil(n / (dtype_size // 4))``.
    """
    scale = dtype_size // 4  # how many 4-bit fields fit in dtype_size bits
    pad = (-uint_data.shape[-1]) % scale
    if pad:
        # new_zeros inherits dtype *and* device, so the cat below also works
        # for tensors that do not live on the default device.
        padding = uint_data.new_zeros((*uint_data.shape[:-1], pad))
        uint_data = torch.cat([uint_data, padding], dim=-1)

    shape = uint_data.shape
    # Put each group of `scale` neighbours on its own trailing axis.
    # Explicit sizes (no -1) keep this valid for zero-element tensors.
    grouped = uint_data.reshape(*shape[:-1], shape[-1] // scale, scale)
    # Keep only the low nibble so an out-of-range value cannot overwrite the
    # fields that belong to its neighbours.
    nibbles = grouped & 0b1111

    packed = nibbles[..., 0] << (dtype_size - 4)
    for i in range(1, scale):
        packed = packed | (nibbles[..., i] << (dtype_size - 4 * (i + 1)))
    return packed

def down_size(size, amt):
    """Return ``size`` as a tuple with its last entry divided by ``amt``.

    The last entry must be an exact multiple of ``amt``; otherwise an
    ``AssertionError`` is raised.
    """
    last = size[-1]
    quotient, remainder = divmod(last, amt)
    assert remainder == 0, f"{size} last dim not divisible by {amt}"
    return tuple(size[:-1]) + (quotient,)


def up_size(size, amt):
    """Return ``size`` as a tuple with its last entry multiplied by ``amt``."""
    leading = tuple(size[:-1])
    scaled_last = size[-1] * amt
    return leading + (scaled_last,)

In [81]:
# Round-trip check: pack -> unpack -> drop padding must reproduce the input
# exactly, for every supported container width.
# A local generator keeps the cell reproducible without touching the global
# RNG state that other cells may rely on.
rng = torch.Generator().manual_seed(0)

# (shape, dtype, bit width of dtype). The last dims 7 and 9 are deliberately
# not multiples of the group size, so the padding path is exercised too.
round_trip_cases = [
    ((8, 8, 7), torch.uint8, 8),
    ((5, 1, 4), torch.int16, 16),
    ((3, 1, 9), torch.int32, 32),
]

for case_shape, case_dtype, case_bits in round_trip_cases:
    # Draw from the full 4-bit range (0..15) so every bit of a nibble is checked.
    test_tensor = torch.randint(0, 16, case_shape, generator=rng, dtype=case_dtype)
    packed = pack_uint4(test_tensor, case_bits)
    unpacked = unpack_uint4(packed)
    # Unpacking returns the zero padding as well; slice it off before comparing.
    unpadded = unpacked[..., :test_tensor.shape[-1]]
    # Integer data: compare exactly rather than with a floating-point tolerance.
    assert torch.equal(unpadded, test_tensor), (
        f"round trip failed for shape={case_shape}, dtype={case_dtype}"
    )