
[feature request] np.packbits / np.unpackbits, general BitTensors (maybe can be just tensors with dtype torch.bits8 or have a new dtype torch.bits introduced) and bit packed tensors utilities for saving memory / accesses, support for BitTensors wherever BoolTensors are used #32867

Open
vadimkantorov opened this issue Jan 31, 2020 · 77 comments
Labels
feature A request for a proper, new feature. high priority module: boolean tensor triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@vadimkantorov
Contributor

vadimkantorov commented Jan 31, 2020

A use case: storing a full backtracking pointer matrix for Needleman-Wunsch / CTC alignment becomes affordable (a 4x memory saving compared to a uint8 representation) if a 2-bit data type is used. Currently it's possible to do this with bit-manipulation magic, but probably not very efficiently (store and load require masking and shifting that are not fused).

Another use case: compressed BoolTensors for binary neural networks.

Another use case: extremely low-bit quantized representations.

Is something like this already implemented for quantization? A simple version of this feature could be to provide some explicit utility functions, e.g. for calculating the size of the holder uint8 tensor, plus fused store and load functions (potentially explicitly batched, e.g. the actual store is delayed until some aligned number of memory lines has arrived).

In NumPy the related functionality is np.packbits and np.unpackbits; however, these are designed to work only with a 1-bit contained type. 2-bit / 4-bit support would be nice as well.
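
For reference, the NumPy 1-bit API looks like this (the axis and count arguments already cover the dim / original-size concerns that come up later in this thread):

import numpy as np

a = np.array([[1, 0, 1, 1, 0, 0, 1, 0]], dtype=np.uint8)
packed = np.packbits(a, axis=1)                             # array([[178]], dtype=uint8): 8 elements -> 1 byte
unpacked = np.unpackbits(packed, axis=1, count=a.shape[1])  # count undoes the padding to a multiple of 8
assert (unpacked == a).all()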

On the 1-bit side, another related project is RoaringBitmap https://github.com/RoaringBitmap/RoaringBitmap (http://roaringbitmap.org/), which provides compressed bitsets for set operations.

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @anjali411 @izdeby

@ezyang ezyang added feature A request for a proper, new feature. module: boolean tensor triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Feb 3, 2020
@ezyang
Contributor

ezyang commented Feb 3, 2020

One difficulty is that at many points in our code we assume all tensor elements are addressable (for example, for views), and this would not be the case with bit-packed tensors.

@vadimkantorov
Contributor Author

vadimkantorov commented Feb 3, 2020

I wonder if we could design some explicit pack/unpack/load/store/index util methods that would be enough for basic usage (like numpy does with packbits/unpackbits)

Maybe we could have some unpack method that is optimal if the user themselves provides dtype-aligned indexes.

These new bit tensors wouldn't be first-class objects, but utility methods could still be enough for initial experimentation.

Maybe the unpack method could be some variant of narrow that performs a single memory access if the index/length are aligned with the container dtype. NestedTensor/vmap could be used to represent the returned list of unpacked byte tensors.
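
(For illustration, a minimal sketch of the aligned case, assuming uint8 containers and little-endian bit order within each byte; the helper name is made up:)

import torch

def unpack_aligned_slice(packed, bit_start, bit_len):
    # byte-aligned case only: a single narrow on the packed uint8 tensor, then expand just that chunk
    assert bit_start % 8 == 0 and bit_len % 8 == 0
    chunk = packed.narrow(-1, bit_start // 8, bit_len // 8)
    shifts = torch.arange(8, dtype=torch.uint8, device=packed.device)
    return ((chunk.unsqueeze(-1) >> shifts) & 1).flatten(-2).bool()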

@vadimkantorov vadimkantorov changed the title [feature request] Bit packed tensors [feature request] Bit packed tensors utilities Feb 4, 2020
@vadimkantorov
Contributor Author

A simple interface could be packbits / unpackbits like in NumPy, with an additional bitness argument (to support 1-bit, 2-bit and 4-bit) and a dim argument. It should maybe also support an out argument for the unpacked uint8 tensor. The unpacked dimension could always be a new zeroth dimension.

@vadimkantorov
Contributor Author

https://www.microsoft.com/en-us/research/uploads/prod/2018/02/KoPhiliposeTashevZarar_ICASSP_2018.pdf suggests that XNOR and POPCNT functionality is useful for 1-bit networks
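
As a rough illustration of why those two ops suffice (a sketch, assuming both operands are already bit-packed into uint8 along the last dim, with ±1 values encoded as 1/0 bits; PyTorch has no popcount op, so a 256-entry lookup table stands in for POPCNT):

import torch

_POPCOUNT = torch.tensor([bin(i).count("1") for i in range(256)], dtype=torch.int64)

def xnor_popcount_dot(a_packed, b_packed):
    x = torch.bitwise_xor(a_packed, b_packed)                # bits where the operands differ
    n_diff = _POPCOUNT.to(x.device)[x.long()].sum(dim=-1)    # popcount, reduced over the packed dim
    n_bits = 8 * a_packed.shape[-1]
    return n_bits - 2 * n_diff                               # XNOR-accumulate binary dot product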

@vadimkantorov
Contributor Author

vadimkantorov commented Feb 14, 2020

Arguments that can be helpful for packbits/unpackbits:

1. mask - a bit-mask integer specifying the pack/unpack compress mask, like in compress/expand instructions (this is slightly more flexible than a single nbits=1|2|4 argument)

2. dim - packing / unpacking along a given dim (during unpacking it can then only be done across an already existing dim; maybe that's fine for dense dimensions)

3. target dim size - may be needed for unpacking to undo the padding

4. out argument

I guess on CPU packbits/unpackbits could be implemented with those compress/expand SIMD instructions if the op is performed across a contiguous dimension (actually, contiguity doesn't matter once the load into a vector register is done).

@vadimkantorov
Contributor Author

In my code I'd do something like: torch.packbits(something.argmax(dim = -1), mask = 0b11, dim = -1, out = my_uint8_array[k])

@vadimkantorov
Contributor Author

vadimkantorov commented Feb 14, 2020

One can think of this feature request as surfacing compress/expand SIMD functionality to user land and reimplementing it on the GPU.

@ezyang ezyang added the needs research We need to decide whether or not this merits inclusion, based on research world label Feb 19, 2020
@gchanan
Contributor

gchanan commented Feb 25, 2020

Pack and unpack seem worth doing. The other parts (i.e. compress/expand) could be useful, but I'm not sure they're worth doing -- it seems like at that point you'd be writing specialized ops in C++ anyway.

@gchanan gchanan removed the needs research We need to decide whether or not this merits inclusion, based on research world label Feb 25, 2020
@vadimkantorov
Contributor Author

I thought that, given a mask, pack/unpack are precisely equivalent to SIMD compress/expand. Aren't they?

@ezyang
Contributor

ezyang commented Feb 27, 2020

We haven't looked! You're probably right :)

@vadimkantorov
Contributor Author

vadimkantorov commented Feb 27, 2020

General int4 support request in #33859 is also related

@vadimkantorov
Contributor Author

Another useful piece of functionality would be scatter/gather-like ops for compressing index tensors.

In a practical use case this can help compress hybrid sparse+dense tensor indices by a lot: https://discuss.pytorch.org/t/sparse-torch-topk/71832/4. For instance, a set of selected positions along a dimension could be stored as 1 bit per position instead of 64 bits per index, as in the sketch below.
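
(A sketch with a hypothetical helper name; the little-endian bit order within each byte is an arbitrary choice:)

import torch

def indices_to_packed_mask(indices, size):
    mask = torch.zeros(size, dtype=torch.bool, device=indices.device)
    mask[indices] = True
    mask = torch.cat([mask, mask.new_zeros((-size) % 8)])      # pad to a multiple of 8
    weights = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8, device=mask.device)
    return (mask.view(-1, 8).to(torch.uint8) * weights).sum(dim=1).to(torch.uint8)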

@vadimkantorov
Contributor Author

vadimkantorov commented Mar 15, 2020

It seems that sub-8-bit quantization is starting to appear (#34783), which seems somewhat related to this discussion.

@ezyang
Contributor

ezyang commented Mar 16, 2020

cc @jspark1105

@vadimkantorov
Contributor Author

related: #36380

@vadimkantorov vadimkantorov changed the title [feature request] Bit packed tensors utilities [feature request] BitTensors and bit packed tensors utilities Jul 7, 2020
@vadimkantorov
Contributor Author

I renamed the issue to enlarge the scope a little bit :) BitTensors could be very helpful for binary neural networks. Even if only a few operators on them are supported (such as packbits/unpackbits, binary operations, popcount), they are already useful for reducing memory footprint, e.g. for storing masks instead of full inputs when that is sufficient for the backward ops. E.g. in #41034, if a mask is stored, the silu/swish operation would become bijective if an additional bit is stored to represent the direction (half-space) away from the function's minimum.
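
To make the mask-for-backward idea concrete, here is a minimal sketch with ReLU standing in for silu/swish (the silu case would additionally need the extra direction bit mentioned above); packing uses plain uint8 arithmetic with an assumed little-endian bit order:

import torch

class PackedMaskReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        mask = x > 0                                            # BoolTensor: 1 byte per element
        flat = torch.cat([mask.reshape(-1), mask.new_zeros((-mask.numel()) % 8)])
        weights = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8, device=x.device)
        packed = (flat.view(-1, 8).to(torch.uint8) * weights).sum(dim=1).to(torch.uint8)
        ctx.save_for_backward(packed)                           # 1 bit per element instead of 8
        ctx.shape = x.shape
        return x.clamp_min(0)

    @staticmethod
    def backward(ctx, grad_out):
        (packed,) = ctx.saved_tensors
        shifts = torch.arange(8, dtype=torch.uint8, device=packed.device)
        bits = ((packed.unsqueeze(1) >> shifts) & 1).reshape(-1)
        mask = bits[: grad_out.numel()].view(ctx.shape).to(torch.bool)
        return grad_out * mask

x = torch.randn(5, 7, requires_grad=True)
PackedMaskReLU.apply(x).sum().backward()
assert torch.equal(x.grad, (x > 0).float())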

@vadimkantorov
Contributor Author

vadimkantorov commented Jul 17, 2020

I made a draft (https://gist.github.com/vadimkantorov/30ea6d278bc492abf6ad328c6965613a):

import math
import torch

def tensor_dim_slice(tensor, dim, s):
    return tensor[(slice(None),) * (dim if dim >= 0 else dim + tensor.dim()) + (s, )]

def packshape(shape, dim, mask, dtype):
    nbits_element = torch.iinfo(dtype).bits
    nbits = 1 if mask == 0b00000001 else 2 if mask == 0b00000011 else 4 if mask == 0b00001111 else 8 if mask == 0b11111111  else None
    assert nbits is not None and nbits <= nbits_element and nbits_element % nbits == 0
    packed_size = nbits_element // nbits
    shape = list(shape)
    shape[dim] = int(math.ceil(shape[dim] / packed_size))
    return shape, packed_size, nbits

def packbits(tensor, dim = -1, mask = 0b00000001, out = None, dtype = torch.uint8):
    shape, packed_size, nbits = packshape(tensor.shape, dim = dim, mask = mask, dtype = dtype)
    out = out.zero_() if out is not None else torch.zeros(shape, device = tensor.device, dtype = dtype)
    assert tuple(out.shape) == tuple(shape)
    for e in range(packed_size):
        sliced_input = tensor_dim_slice(tensor, dim, slice(e, None, packed_size))
        compress = (sliced_input << (nbits * (packed_size - e - 1)))
        sliced_output = out.narrow(dim, 0, sliced_input.shape[dim])
        sliced_output |= compress
    return out

def unpackbits(tensor, shape, dim = -1, mask = 0b00000001, out = None, dtype = torch.uint8):
    _, packed_size, nbits = packshape(shape, dim = dim, mask = mask, dtype = tensor.dtype)
    out = out.zero_() if out is not None else torch.zeros(shape, device = tensor.device, dtype = dtype)
    assert tuple(out.shape) == tuple(shape)
    for e in range(packed_size):
        sliced_output = tensor_dim_slice(out, dim, slice(e, None, packed_size))
        expand = (tensor >> (nbits * (packed_size - e - 1))) & ((1 << nbits) - 1)
        sliced_input = expand.narrow(dim, 0, sliced_output.shape[dim])
        sliced_output.copy_(sliced_input)
    return out

if __name__ == '__main__':
    shape = (10, 17)
    K = 10
    for nbits in [1, 2, 4, 8]:
        mask = (1 << nbits) - 1
        for dtype in [torch.uint8, torch.int32, torch.int64]:
            for k in range(K):
                x = torch.randint(0, 1 << nbits, shape, dtype = dtype)
                y = packbits(x, mask = mask)
                z = unpackbits(y, mask = mask, dtype = x.dtype, shape = x.shape)
                assert torch.allclose(x, z)
                                             

Discussion about including tensor_slice in core is at https://discuss.pytorch.org/t/use-python-like-slice-indexing-across-a-given-dimension/89606/7

@vadimkantorov
Contributor Author

vadimkantorov commented Jul 21, 2020

Any advice on how to fuse this properly, including for CUDA? Would torch.jit.script just work?

@vadimkantorov
Contributor Author

Very interesting! I think core would be happy to have the basic kernels for packbits/unpackbits at least

@Felix-Petersen

My current (unreleased) version of packing and unpacking is:

CUDA

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)                                                                                                 \
    CHECK_CUDA(x);                                                                                                     \
    CHECK_CONTIGUOUS(x)

// adapted from https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api
#define gpuErrchk(ans)                                                                                                 \
    { gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(const cudaError_t code, const char *const file, const int line, const bool abort = true) {
    if (code != cudaSuccess) {
        fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
        if (abort)
            exit(code);
    }
}

template <typename T> T ceil_div(const T x, const T y) { return x / y + !!(x % y); }

/**********************************************************************************************************************/

template <typename scalar_t>
__global__ void tensor_packbits_cuda_kernel(
    torch::PackedTensorAccessor64<bool, 2, torch::RestrictPtrTraits> t,
    torch::PackedTensorAccessor64<scalar_t, 2, torch::RestrictPtrTraits> b
) {

    for (
        auto row = blockIdx.y * blockDim.y + threadIdx.y;
        row < b.size(0);
        row += blockDim.y * gridDim.y
    ) {
        for (
            auto col = blockIdx.x * blockDim.x + threadIdx.x;
            col < b.size(1);
            col += blockDim.x * gridDim.x
        ) {

            typedef typename std::make_unsigned<scalar_t>::type unsigned_scalar_t;
            union {
                unsigned_scalar_t unsigned_scalar;
                scalar_t signed_scalar;
            } val;
            constexpr int bit_count = std::numeric_limits<unsigned_scalar_t>::digits;
            val.signed_scalar = b[row][col];
            for (unsigned int i = 0; i < bit_count; ++i) {
                const int64_t idx = static_cast<int64_t>(bit_count) * col + i;
                if (idx >= t.size(1))  // padded tail of the last packed element: no input bits to read
                    break;
                const unsigned_scalar_t bit_mask = static_cast<unsigned_scalar_t>(t[row][idx]) << i;
                val.unsigned_scalar = val.unsigned_scalar | bit_mask;
            }
            b[row][col] = val.signed_scalar;
        }
    }
}

std::tuple<torch::Tensor, int> tensor_packbits_cuda(
    torch::Tensor t,
    const int bit_count
) {
    CHECK_INPUT(t);

    const auto batch_in_size = t.size(1);
    const auto batch_out_size = ceil_div(batch_in_size, static_cast<int64_t>(bit_count));
    const auto out_size = t.size(0);
    const auto pad_len = (bit_count - batch_in_size % bit_count) % bit_count;

    dim3 threads_per_block(32, 32);

    const dim3 blocks_per_grid(
        min(static_cast<int64_t>(65535), ceil_div(batch_out_size, static_cast<int64_t>(threads_per_block.x))),
        min(static_cast<int64_t>(65535), ceil_div(out_size, static_cast<int64_t>(threads_per_block.y)))
    );

    auto dispatch_type = [bit_count]() {
        switch (bit_count) {
        case 8:
            return torch::kInt8;
        case 16:
            return torch::kInt16;
        case 32:
            return torch::kInt32;
        case 64:
            return torch::kInt64;
        default:
            throw std::invalid_argument("`bit_count` has to be in { 8, 16, 32, 64 }");
        }
    }();
    auto b = torch::zeros({out_size, batch_out_size}, torch::dtype(dispatch_type).device(t.device()));

    AT_DISPATCH_INTEGRAL_TYPES(b.type(), "tensor_packbits_cuda_kernel", ([&] {
                                   tensor_packbits_cuda_kernel<scalar_t><<<blocks_per_grid, threads_per_block>>>(t.packed_accessor64<bool, 2, torch::RestrictPtrTraits>(),
                                                                                                                 b.packed_accessor64<scalar_t, 2, torch::RestrictPtrTraits>());
                               }));
    gpuErrchk(cudaPeekAtLastError());
    gpuErrchk(cudaDeviceSynchronize());

    return {b, pad_len};
}


/**********************************************************************************************************************/


template <typename scalar_t>
__global__ void tensor_unpackbits_cuda_kernel(
    torch::PackedTensorAccessor64<scalar_t, 2, torch::RestrictPtrTraits> t,
    torch::PackedTensorAccessor64<bool, 2, torch::RestrictPtrTraits> b,
    const int bit_count
) {

    for (
        auto row = blockIdx.y * blockDim.y + threadIdx.y;
        row < b.size(0);
        row += blockDim.y * gridDim.y
    ) {
        for (
            auto col = blockIdx.x * blockDim.x + threadIdx.x;
            col < b.size(1);
            col += blockDim.x * gridDim.x
        ) {

            const auto bit = (t[row][col / bit_count] >> (col % bit_count)) & 1;
            b[row][col] = static_cast<bool>(bit);

        }
    }
}

torch::Tensor tensor_unpackbits_cuda(
    torch::Tensor t,
    const int bit_count,
    const int pad_len
) {
    CHECK_INPUT(t);

    const auto batch_in_size = t.size(1);
    const auto batch_out_size = batch_in_size * bit_count - pad_len;
    const auto out_size = t.size(0);

    dim3 threads_per_block(32, 32);

    const dim3 blocks_per_grid(
        min(static_cast<int64_t>(65535), ceil_div(batch_out_size, static_cast<int64_t>(threads_per_block.x))),
        min(static_cast<int64_t>(65535), ceil_div(out_size, static_cast<int64_t>(threads_per_block.y)))
    );

    auto b = torch::zeros({out_size, batch_out_size}, torch::dtype(torch::kBool).device(t.device()));

    AT_DISPATCH_INTEGRAL_TYPES(t.type(), "tensor_unpackbits_cuda_kernel", ([&] {
                                   tensor_unpackbits_cuda_kernel<scalar_t><<<blocks_per_grid, threads_per_block>>>(t.packed_accessor64<scalar_t, 2, torch::RestrictPtrTraits>(),
                                                                                                                   b.packed_accessor64<bool, 2, torch::RestrictPtrTraits>(),
                                                                                                                   bit_count
                                                                                                                   );
                               }));
    gpuErrchk(cudaPeekAtLastError());
    gpuErrchk(cudaDeviceSynchronize());

    return b;
}

CPP

std::tuple<torch::Tensor, int> tensor_packbits_cuda(
    torch::Tensor t,
    const int bit_count
);
torch::Tensor tensor_unpackbits_cuda(
    torch::Tensor t,
    const int bit_count,
    const int pad_len
);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "tensor_packbits_cuda",
        [](torch::Tensor t, const int bit_count) {
            return tensor_packbits_cuda(t, bit_count);
        },
        "tensor_packbits_cuda (CUDA)");
    m.def(
        "tensor_unpackbits_cuda",
        [](torch::Tensor t, const int bit_count, const int pad_len) {
            return tensor_unpackbits_cuda(t, bit_count, pad_len);
        },
        "tensor_unpackbits_cuda (CUDA)");
}

Python

def pack_along_dim_1(t, bit_count=64, device='cuda'):
    assert len(t.shape) == 2, t.shape
    assert device == 'cuda', device

    t = t.contiguous().to(device)

    return difflogic_cuda.tensor_packbits_cuda(t, bit_count)

def unpack_along_dim_1(t, pad_len, device='cuda'):
    assert len(t.shape) == 2, t.shape
    assert device == 'cuda', device

    t = t.contiguous().to(device)

    return difflogic_cuda.tensor_unpackbits_cuda(t, torch.iinfo(t.dtype).bits, pad_len)

if __name__ == '__main__':
    # Test pack_along_dim_1 / unpack_along_dim_1

    # t = torch.rand(123, 3910).round().bool()
    t0 = torch.rand(128, 2097152).round().bool().cuda()
    print(t0.shape, t0.dtype)
    t, pad = pack_along_dim_1(t0, 64)
    print(t.shape, t.dtype, pad)
    t2 = unpack_along_dim_1(t[::20], pad)
    print(t2.shape, t2.dtype)
    t = unpack_along_dim_1(t, pad)
    print(t.shape, t.dtype)

    print((t0 == t).float().mean())
    assert (t0 == t).all()

@vadimkantorov
Contributor Author

vadimkantorov commented Sep 13, 2023

Matlab's functions for packing binary images into uint32: https://www.mathworks.com/help/images/ref/bwpack.html and https://www.mathworks.com/help/images/ref/bwunpack.html

@vadimkantorov
Contributor Author

Maybe torch.bits8 or torch.bits16 could be used as the target for such pack/unpack functions. It would also be nice to support creating such bits tensors from a Python list of bools (torch.tensor([True]*8, dtype = torch.bits8) currently gives RuntimeError: invalid type) and maybe .tolist() back to a Python list of bools; only the endianness needs to be decided (or e.g. only big-endian could be supported for such packing).

@vadimkantorov vadimkantorov changed the title [feature request] np.packbits / np.unpackbits, general BitTensors and bit packed tensors utilities for saving memory / accesses, support for BitTensors wherever BoolTensors are used [feature request] np.packbits / np.unpackbits, general BitTensors (maybe can be just tensors with dtype torch.bits8 or have a new dtype torch.bits introduced) and bit packed tensors utilities for saving memory / accesses, support for BitTensors wherever BoolTensors are used Oct 8, 2023
@vadimkantorov
Contributor Author

vadimkantorov commented Jan 10, 2024

Do you think it's better to have a separate tensor dtype with "compressed bool tensor"/bitmap semantics, or to just (ab)use torch.bits8 for this? I guess the important questions are how to treat indexing bool_tensor[4] versus bit_tensor[4] (maybe some special function could be implemented on bittensor/bits8 for this?) and how to treat cases where the original BoolTensor's shape is not divisible by 8 (one option might be to simply not support this, or to always pad and let the user save the original BoolTensor's shape somewhere else). Then it would be nice to use it for save_for_backward'ing the compressed mask in the old eager dropout impl and to allow fused impls like torch.gt(float_tensor, 4, dtype = torch.bits)
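
(For the indexing question, a sketch of what element access could look like on a plain uint8-packed buffer; the helper names are hypothetical, and a contiguous packed tensor with little-endian bit order within each byte is assumed:)

def bit_getitem(packed, i):
    return bool((packed.reshape(-1)[i // 8] >> (i % 8)) & 1)

def bit_setitem_(packed, i, value):
    flat = packed.reshape(-1)                   # a view, since packed is assumed contiguous
    if value:
        flat[i // 8] |= 1 << (i % 8)
    else:
        flat[i // 8] &= 0xFF ^ (1 << (i % 8))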

@ezyang
Contributor

ezyang commented Jan 10, 2024

So we just agreed internally that we are going to implement a uint1 type, so you can use that. Note that sub-byte dtypes need a lot of extra infra, which we plan to implement using Python tensor subclasses.

@vadimkantorov
Contributor Author

I think packbits/unpackbits + bit ops + some fused variants of kernels that currently produce only BoolTensors would already be useful. Going forward, it would be nice to have everything that currently supports torch.bool also support torch.uint1. Btw, if we had some generic tiled dtypes, the dispatch system for sub-byte dtypes could also become simpler, as they are naturally tiled (e.g. consider the torch.gt(float_tensor, 4, dtype = torch.bitmap) use case, where a tile of 8 floats could ideally be read at once and processed into a single output byte, i.e. a tile of 8 bits). These are probably already dispatched in a similar way for the vectorized kernels on CPU (as tiled vec types exist)?

Also, I must bikeshed-confess that uint1 is a super-unintuitive name for the high-level bitmap/bitmask/bittensor concept. Maybe it's worth naming it bitmap? Or bit/bits?

For the subclass, I think it would be good to have auto-upcast to torch.bool, so that existing bool kernels can consume bitmaps right away, hiding a reallocation in eager mode (and ideally optimized away by Inductor in the compiled case?)
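
A toy sketch of that auto-upcast behavior using the __torch_function__ protocol (a plain wrapper class, not the planned uint1 subclass; the names and the little-endian bit order are assumptions), just to illustrate that existing bool kernels could consume bitmaps transparently at the cost of an unpack:

import math
import torch

class PackedBool:
    def __init__(self, packed, shape):
        self.packed, self.shape = packed, shape          # uint8 storage, logical bool shape

    @classmethod
    def from_bool(cls, b):
        flat = torch.cat([b.reshape(-1), b.new_zeros((-b.numel()) % 8)])
        w = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8, device=b.device)
        return cls((flat.view(-1, 8).to(torch.uint8) * w).sum(dim=1).to(torch.uint8), b.shape)

    def to_bool(self):
        shifts = torch.arange(8, dtype=torch.uint8, device=self.packed.device)
        bits = ((self.packed.unsqueeze(1) >> shifts) & 1).reshape(-1)
        return bits[: math.prod(self.shape)].view(self.shape).to(torch.bool)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # auto-upcast: expand any PackedBool argument to a regular BoolTensor, then run the op
        unwrap = lambda a: a.to_bool() if isinstance(a, PackedBool) else a
        return func(*map(unwrap, args), **{k: unwrap(v) for k, v in (kwargs or {}).items()})

m = PackedBool.from_bool(torch.rand(4, 16) > 0.5)        # 8 bytes of storage instead of 64
print(torch.count_nonzero(m), torch.logical_not(m).shape)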

@vadimkantorov
Contributor Author

torch.bit is in, thanks to @jerryzh168: #117208 :)

@Felix-Petersen

I might be missing the actual changes, but as I understand it, #117208 only declares the existence of uint1 etc. at a high level. Am I misunderstanding this?

@ezyang
Contributor

ezyang commented Jan 13, 2024

Yeah, none of the actual implementation exists yet. Also, the implementation is going to be all in Python, so to get low-overhead performance you are going to have to use torch.compile eventually.

@Felix-Petersen

Felix-Petersen commented Jan 13, 2024

Would we really expect torch.compile to be able to perform efficient bit packing? Single-bit accesses seem like something that ideally should be optimized at a very low level (eventually, potentially at the instruction level). At the end of the day, a lot depends on the specifics of the memory accesses, but for any low-bit case we need to access at least 8 bits to read a single bit, and address-wise the data type should probably not exist outside of 1+-dimensional tensors. I believe it requires significant consideration to support properly. (And hardware considerations should be taken into account if it becomes part of a stable release.)

Just to clarify, the plan is still that we would be storing 8 bits in 1 byte (in contrast to the current 1 bit per byte of bool)?

@ezyang
Contributor

ezyang commented Jan 13, 2024

Just to clarify, the plan is still that we would be storing 8 bits in 1 byte (in contrast to the current 1 bit per byte of bool)?

Yes.

Would we really expect torch.compile to be able to perform efficient bit packing? Single-bit accesses seem like something that ideally should be optimized at a very low level (eventually, potentially at the instruction level). At the end of the day, a lot depends on the specifics of the memory accesses, but for any low-bit case we need to access at least 8 bits to read a single bit, and address-wise the data type should probably not exist outside of 1+-dimensional tensors.

So, at least in the short term, the way I would expect to implement these operators is in terms of regular bitwise operations on the bit tensor reinterpreted as a uint8_t (or larger) tensor. The planned implementation for these sub-byte tensors doesn't really allow for non-aligned accesses anyway (see https://github.com/albanD/subclass_zoo/blob/main/uint4_tensor.py for a PoC).
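
For example (just a sketch of that reinterpretation idea, not the planned implementation), elementwise bool ops on two identically packed bit tensors reduce to ordinary bitwise kernels on their uint8 storage:

import torch

def packed_and(a, b): return torch.bitwise_and(a, b)
def packed_or(a, b):  return torch.bitwise_or(a, b)
def packed_xor(a, b): return torch.bitwise_xor(a, b)
def packed_not(a):    return torch.bitwise_not(a)   # note: also flips any padding bits in the last byte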

@vadimkantorov
Contributor Author

vadimkantorov commented Jan 13, 2024

Also, it could be good to support many ops relevant to torch.bit tensors (bit ops / pack ops etc.) on all tensors as well, via reinterpretation, but to encourage/advertise their use only in torch.bit examples (e.g. in #105465 it might make sense to support bit flips on float tensors too, as they could be used for flipping the sign or for setting the LSB to some external value; see the sketch at the end of this comment).

At the end of the day, IMHO, what matters is the existence of these ops with nice, clear, specific naming and descriptions. So maybe forcing a subclass or a reinterpret could even be avoided, although the op examples could promote the use of torch.bit for bitset/bitmap/compressed-BoolTensor use cases.

I think sub-byte indexing helpers per se are not the most pressing op anyway :)
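
For example, the bit-flip-on-floats case from #105465 can already be emulated today via an integer view (a small sketch; flipping bit 31 of a float32 flips its sign):

import torch

x = torch.randn(4)
sign_bit = torch.tensor(-2**31, dtype=torch.int32)             # 0x80000000 as int32
flipped = (x.view(torch.int32) ^ sign_bit).view(torch.float32)
assert torch.equal(flipped, -x)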

@vadimkantorov
Contributor Author

Also, the implementation is going to be all in Python

I guess it would be good to have benchmarks in the tests comparing, at least initially, these Python-compiled bit ops with some manual CUDA loops, to be confident in them: unrolling multiple per-bit loop iterations so that there is a single memory store per byte (or even per several bytes) might be a non-trivial optimization. Maybe for this it would be good to have wide dtypes in Python as well (e.g. float32x8): it would then be something like reading an 8-tuple of floats via float32_tensor.view(torch.float32x8) and then storing the produced output byte with a single store. Maybe even wider dtypes as outputs (and inputs) would be a good abstraction to help torch.compile produce effective code for these kinds of patterns (read, compute/pack/reduce, write).
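
A hedged sketch of that read-8-floats, write-1-byte pattern, leaving the fusion to torch.compile (whether the generated code really issues one byte-wide store per 8 inputs is exactly what such a benchmark would need to check):

import torch

@torch.compile
def gt_packed(x, threshold):
    # x: float tensor whose number of elements is divisible by 8
    bits = (x.reshape(-1, 8) > threshold).to(torch.uint8)
    weights = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8, device=x.device)
    return (bits * weights).sum(dim=1).to(torch.uint8)         # one output byte per 8 inputs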

@vadimkantorov
Contributor Author

vadimkantorov commented Jan 25, 2024

Regarding sub-byte access, I would propose treating this by introducing some variants of getbit/setbit/togglebit ops available for any dtype (#105465); at the least, this would unblock things and make access more elegant and less error-prone.

@LemonPi

LemonPi commented Jan 25, 2024

To chime in, my use case for this is to efficiently represent occupancy voxel grids where each voxel is either occupied or not (so 8 bits of info aren't needed). I'm running into memory issues when batching these voxel grids, so better memory efficiency would be great.

However, these voxel grids are also frequently read from/written to, so performance is also a concern when considering packing/unpacking.

@skywolf829

Chiming in to show support for the feature. Similar use case to LemonPi above me: mask tensors using 8 bits per element do chew up more memory than needed. In addition, I'm passing the mask to a CUDA kernel through pybind11, and it's a little weird having to access it as mask_tensor.contiguous().data<uint8_t>() instead of as bool. It makes the code slightly less readable to a new dev looking at my code ("why is this boolean tensor read as uint8??").

@vadimkantorov
Contributor Author

Somewhat related (although it uses ternary weights) - BitNet: https://arxiv.org/abs/2402.17764

@vadimkantorov
Contributor Author

vadimkantorov commented Apr 7, 2024

Given the revival of interest in extremely low-bit quantization methods, it might be nice to also include some eager methods in core (mainly for correct, tested semantics, not for actual speed; speed / Linear op support may come later). I personally am more in favor of adding them first as eager methods working with dense/bits tensors, without adding tensor subclasses. Right now there is a lot of experimentation with quantization methods; the area is developing very fast, and it's unclear whether any particular method will win. A lot of this experimentation works with very simple Linear module swaps, and for that, having some eager method under the hood and manually passing around a few tensors representing the qparams state is completely fine.

@vadimkantorov
Contributor Author

Related, on packing/unpacking tuples of floats: pytorch/ao#208. It would be nice to have various performant bit pack/unpack utils in core PyTorch (both in eager mode for semantics experiments and for fusion with Triton codegen).
