enable cat for cuda bits types #115044

Closed · wants to merge 10 commits
Changes from 4 commits
47 changes: 35 additions & 12 deletions aten/src/ATen/native/cuda/Shape.cu
@@ -239,7 +239,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
int nDims, c10::MemoryFormat memory_format) {
// First, let's set up our kernel parameters. We start with a raw pointer to
// the storage for the output Tensor.
scalar_t *data = out.mutable_data_ptr<scalar_t>();
scalar_t *data = (scalar_t *) out.mutable_data_ptr();
CatArrInputTensorMetadata<scalar_t, unsigned int, batch_size, stride_size> catMetaData;
TensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> outputParam;

@@ -289,7 +289,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
dimSize = inputs[i+batchCounter].get().size(dimension);
}

catMetaData.input[batchCounter] = inputs[i+batchCounter].get().const_data_ptr<scalar_t>();
catMetaData.input[batchCounter] = (scalar_t*)(inputs[i+batchCounter].get().const_data_ptr());
catMetaData.offset[batchCounter] = offset;
catMetaData.dimSize[batchCounter] = dimSize;
catMetaData.nElements[batchCounter] = inputs[i+batchCounter].get().numel();
@@ -375,6 +375,10 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
#undef HANDLE_CASE
}
}
// The kernels are templated on an opaque, self-aligned type of the correct
// size to avoid redundant kernels for different types of the same size.
template <int N> struct alignas(N) OpaqueType { char data[N]; };

} // namespace

TORCH_IMPL_FUNC(cat_out_cuda)
@@ -412,29 +416,48 @@ TORCH_IMPL_FUNC(cat_out_cuda)
// memory. Therefore, we could pass more inputs to cuda threads.
// For non-contiguous, we reduce the number of inputs passed to cuda kernel due to the limitation
// of constant memory.



if (materialized.size() > 1 &&
result.dim() <= CAT_ARRAY_MAX_INPUT_DIMS &&
at::cuda::detail::canUse32BitIndexMath(result) &&
all_contiguous &&
all32BitIndexable &&
all_same_dtype) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
kComplexHalf, kHalf, kBool, kBFloat16,
result.scalar_type(), "cat_cuda", [&]() {
parallel_cat<scalar_t, CAT_ARRAY_BATCH_SIZE, 1>(result, materialized, dim, nDims, memory_format);
});
if (isBitsType(result.scalar_type())) {
AT_DISPATCH_BIT_TYPES(result.scalar_type(), "cat_cuda", [&]() {
using dtype = OpaqueType<sizeof(scalar_t)>;
parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE, 1>(result, materialized, dim, nDims, memory_format);
});
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
kComplexHalf, kHalf, kBool, kBFloat16,
result.scalar_type(), "cat_cuda", [&]() {
using dtype = OpaqueType<sizeof(scalar_t)>;
parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE, 1>(result, materialized, dim, nDims, memory_format);
});
}
} else if (materialized.size() > 1 &&
result.dim() <= CAT_ARRAY_MAX_INPUT_DIMS &&
at::cuda::detail::canUse32BitIndexMath(result) &&
nDims <= CAT_ARRAY_MAX_INPUT_DIMS &&
all32BitIndexable &&
all_same_dtype &&
memory_format == c10::MemoryFormat::Contiguous) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
kComplexHalf, kHalf, kBool, kBFloat16,
result.scalar_type(), "cat_cuda", [&]() {
parallel_cat<scalar_t, CAT_ARRAY_BATCH_SIZE/2, CAT_ARRAY_BATCH_SIZE/2>(result, materialized, dim, nDims, memory_format);
});
if (isBitsType(result.scalar_type())) {
AT_DISPATCH_BIT_TYPES(result.scalar_type(), "cat_cuda", [&]() {
using dtype = OpaqueType<sizeof(scalar_t)>;
parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE/2, CAT_ARRAY_BATCH_SIZE/2>(result, materialized, dim, nDims, memory_format);
});
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
kComplexHalf, kHalf, kBool, kBFloat16,
result.scalar_type(), "cat_cuda", [&]() {
using dtype = OpaqueType<sizeof(scalar_t)>;
parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE/2, CAT_ARRAY_BATCH_SIZE/2>(result, materialized, dim, nDims, memory_format);
});
}
} else {
int64_t offset = 0;
for (const Tensor& t : materialized) {
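
In the dispatch above, bits dtypes are routed through AT_DISPATCH_BIT_TYPES, and every branch now instantiates parallel_cat on OpaqueType<sizeof(scalar_t)> rather than on scalar_t itself, so all dtypes with the same element size share a single kernel. Below is a minimal sketch of the user-visible behavior this enables, assuming a CUDA device and a build that exposes the experimental bits dtypes; it mirrors the test added in test_bits.py further down.

    import torch

    # Bits tensors carry no arithmetic semantics; they are created by viewing an
    # integer tensor of the same element size.
    x = torch.randint(100, (4, 4), dtype=torch.int8, device="cuda").view(torch.bits8)
    y = torch.randint(100, (4, 4), dtype=torch.int8, device="cuda").view(torch.bits8)

    # Concatenating CUDA bits tensors is what this PR enables.
    z = torch.cat([x, y], dim=0)
    assert z.dtype == torch.bits8

    # Check against the plain int8 reference by viewing the bit payload back.
    z_ref = torch.cat([x.view(torch.int8), y.view(torch.int8)], dim=0)
    assert torch.equal(z.view(torch.int8), z_ref)
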
30 changes: 26 additions & 4 deletions test/quantization/core/experimental/test_bits.py
@@ -1,10 +1,14 @@
# Owner(s): ["oncall: quantization"]

import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests

from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils._mode_utils import no_dispatch
from torch.utils._pytree import tree_map

import itertools

class Int16Tensor(torch.Tensor):
def __new__(cls, elem):
assert elem.dtype == torch.bits16
@@ -41,24 +45,42 @@ def __repr__(self) -> str:


class TestBits(TestCase):
def test_types(self):
def test_types(self, device):
bits_types = [torch.bits1x8, torch.bits2x4, torch.bits4x2, torch.bits8, torch.bits16]
for bits_type in bits_types:
_ = torch.zeros(20, dtype=torch.int32).view(bits_type)
_ = torch.empty(20, dtype=bits_type)
x = torch.randint(100, (20, 20), dtype=torch.int8).view(bits_type)
_ = torch.zeros(20, dtype=torch.int32, device=device).view(bits_type)
_ = torch.empty(20, dtype=bits_type, device=device)
x = torch.randint(100, (20, 20), dtype=torch.int8, device=device).view(bits_type)
y = x.t().contiguous()
view_type = torch.int8 if x.element_size() == 1 else torch.int16
self.assertEqual(x.t().view(view_type), y.view(view_type))
y = x.t().clone()
self.assertEqual(x.t().view(view_type), y.view(view_type))

def test_cat(self, device):
bits_types = [torch.bits1x8, torch.bits2x4, torch.bits4x2, torch.bits8, torch.bits16]
for bits_type in bits_types:
view_type = torch.int8 if bits_type.itemsize == 1 else torch.int16
x_int = torch.randint(100, (512, 512), dtype=view_type, device=device)
x = x_int.view(bits_type)
y_int = torch.randint(100, (512, 512), dtype=view_type, device=device)
y = y_int.view(bits_type)
for dim, transpose in itertools.product(range(x_int.ndim), (True, False)):
y_ref = y_int.t() if transpose else y_int
y_b = y.t() if transpose else y
z_ref = torch.cat([x_int, y_ref], dim=dim)
z = torch.cat([x, y_b], dim=dim)
self.assertEqual(z_ref, z.view(view_type))


def test_subclass(self):
t = torch.zeros(20, dtype=torch.int16).view(torch.bits16)
s = Int16Tensor(t)
s = s + 1 - 1
self.assertTrue(torch.allclose(s, torch.zeros(20, dtype=torch.bits16)))

instantiate_device_type_tests(TestBits, globals())


if __name__ == '__main__':
run_tests()
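
The test's view_type choice relies on grouping the bits dtypes by element size, which is also what makes the OpaqueType trick in Shape.cu sufficient: every 1-byte bits dtype shares one kernel instantiation, bits16 another. A small sketch of that grouping follows; the -k invocation in the trailing comment is the usual PyTorch test-runner usage, stated here as an assumption.

    import torch

    # All 1-byte bits dtypes round-trip through int8; bits16 round-trips through int16.
    one_byte = [torch.bits1x8, torch.bits2x4, torch.bits4x2, torch.bits8]
    assert all(d.itemsize == 1 for d in one_byte)
    assert torch.bits16.itemsize == 2

    # To run only the new test from the repo root (assumed invocation):
    #   python test/quantization/core/experimental/test_bits.py -k test_cat
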
9 changes: 8 additions & 1 deletion test/test_quantization.py
@@ -153,7 +153,14 @@
logging.warning(e)

# Experimental functionality
from quantization.core.experimental.test_bits import TestBits # noqa: F401
try:
Contributor: This makes Steve sad...
from quantization.core.experimental.test_bits import TestBitsCPU # noqa: F401
except ImportError as e:
logging.warning(e)
try:
from quantization.core.experimental.test_bits import TestBitsCUDA # noqa: F401
except ImportError as e:
logging.warning(e)
try:
from quantization.core.experimental.test_float8 import TestFloat8DtypeCPU # noqa: F401
except ImportError as e:
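
The import churn here follows from the instantiate_device_type_tests call added in test_bits.py: it replaces the device-generic TestBits with per-device classes, so test_quantization.py now imports TestBitsCPU and TestBitsCUDA separately (the duplication the inline review comment refers to). A rough sketch of that mechanism, under the assumption that a CPU test base is always present and a CUDA one only when available:

    from torch.testing._internal.common_device_type import instantiate_device_type_tests
    from torch.testing._internal.common_utils import TestCase

    class TestBits(TestCase):
        # Device-generic test: the instantiated variants pass a concrete device string.
        def test_noop(self, device):
            self.assertTrue(True)

    # Adds TestBitsCPU (and TestBitsCUDA when a CUDA test base is available) to this
    # module's namespace; the generic TestBits is not meant to be collected directly.
    instantiate_device_type_tests(TestBits, globals())

    print(sorted(name for name in globals() if name.startswith("TestBits")))
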