Add unsigned integer dtypes to PyTorch #116594

Closed · wants to merge 6 commits
3 changes: 3 additions & 0 deletions aten/src/ATen/DLConvertor.cpp
@@ -10,6 +10,9 @@ DLDataType getDLDataType(const Tensor& t) {
dtype.bits = t.element_size() * 8;
switch (t.scalar_type()) {
case ScalarType::Byte:
case ScalarType::UInt16:
case ScalarType::UInt32:
case ScalarType::UInt64:
dtype.code = DLDataTypeCode::kDLUInt;
break;
case ScalarType::Char:
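For context, a minimal sketch of what the mapping above yields for one of the new dtypes. This is illustrative only: it assumes the standard DLPack DLDataType fields (code/bits/lanes) and that the k-style dtype constant kUInt16 is generated for the new entry; none of this code is part of the diff.

#include <ATen/ATen.h>
#include <ATen/DLConvertor.h>

void uint16_dlpack_sketch() {
  // A uint16 tensor should now map to an unsigned DLPack dtype.
  at::Tensor t = at::empty({4}, at::dtype(at::kUInt16));
  DLDataType dt = at::getDLDataType(t);
  // Expected: dt.code == kDLUInt, dt.bits == 16 (element_size() * 8),
  // dt.lanes == 1.
}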
4 changes: 4 additions & 0 deletions c10/core/ScalarType.cpp
@@ -74,6 +74,10 @@ ScalarType promoteTypes(ScalarType a, ScalarType b) {
toString(b));
}

if (isBarebonesUnsignedType(a) || isBarebonesUnsignedType(b)) {
return ScalarType::Undefined;
}

auto ix_a = dtype2index[static_cast<int64_t>(a)];
TORCH_INTERNAL_ASSERT(ix_a != -1);
auto ix_b = dtype2index[static_cast<int64_t>(b)];
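The observable effect of this early return, sketched below. Hedged: per this diff, promoteTypes yields Undefined for any pairing that involves a barebones unsigned type, and it is left to callers to turn that sentinel into an error.

#include <c10/core/ScalarType.h>

void promotion_sketch() {
  auto r = c10::promoteTypes(c10::ScalarType::UInt16, c10::ScalarType::Int);
  // r == c10::ScalarType::Undefined: there is no valid mixed promotion
  // for the new unsigned dtypes yet.
}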
50 changes: 43 additions & 7 deletions c10/core/ScalarType.h
@@ -26,11 +26,24 @@
namespace c10 {

// For the macros below:
//
// For users: If you want to macro some code for all non-QInt scalar types
// (i.e., types with complete information), you probably want one of the
// AT_FORALL_SCALAR_TYPES / AT_FORALL_SCALAR_TYPES_AND macros below, which are
// designed to behave similarly to the Dispatch macros with the same name.
//
// For adding a new dtype: In the beginning, we had an idea that there was a
// list of all scalar types, and you could use AT_FORALL_SCALAR_TYPES to
// iterate over them. But over the years we added weird types which couldn't
// be handled uniformly everywhere, so we ended up with a mishmash of helper
// macros, with most use sites making their own call about which dtypes they
// can or can't support. So if you want to add a new dtype, the preferred
// approach is to find a dtype similar to what you want, grep for it, and
// edit all the sites you find that way. If you need to add a completely new
// kind of dtype, you will have to laboriously audit all of the sites
// everywhere to figure out how it should work. Consulting some old PRs where
// we added new dtypes (check the history of this file) can help give you an
// idea where to start.

// NB: Order matters for this macro; it is relied upon in
// _promoteTypesLookup and the serialization format.
@@ -61,11 +74,18 @@ namespace c10 {
_(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \
_(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \
_(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \
_(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */ \
_(uint16_t, UInt16) /* 27 */ \
_(uint32_t, UInt32) /* 28 */ \
_(uint64_t, UInt64) /* 29 */
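To make the X-macro pattern concrete, here is a hedged sketch of how a use site consumes one of these lists, assuming the list above is AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS (the macro's name is scrolled out of the hunk); the helper below is illustrative and not part of the PR.

// Each use site supplies its own _(ctype, Name) body and applies the list:
#define DEFINE_SIZE_CASE(ctype, name) \
  case c10::ScalarType::name:         \
    return sizeof(ctype);

inline size_t sketch_element_size(c10::ScalarType t) {
  switch (t) {
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SIZE_CASE)
    default:
      return 0;
  }
}
#undef DEFINE_SIZE_CASE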

// If you want to support ComplexHalf for real, add ComplexHalf
// into this macro (and change the name). But beware: convert()
// doesn't work for all the conversions you need...
//
// TODO: To add unsigned int types here, we must define accumulate type.
// But uint8 currently accumulates into int64, so we would have to make
// an inconsistent choice for the larger types. Difficult.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ(_) \
_(uint8_t, Byte) \
_(int8_t, Char) \
@@ -82,6 +102,8 @@ namespace c10 {
_(at::Float8_e5m2, Float8_e5m2) \
_(at::Float8_e4m3fn, Float8_e4m3fn)
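The accumulate-type difficulty in the TODO above is visible in eager mode today; a short sketch (the uint8 behavior is current PyTorch behavior, and what the larger unsigned types should do is the open question):

#include <ATen/ATen.h>

void accumulate_sketch() {
  // uint8 reductions accumulate into int64:
  at::Tensor t = at::ones({4}, at::dtype(at::kByte));
  at::Tensor s = t.sum();  // s.scalar_type() == at::ScalarType::Long
  // Having uint64 accumulate into int64 as well could overflow, while
  // accumulating into the unsigned types themselves would be inconsistent
  // with uint8; hence the difficulty noted above.
}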

// This macro controls many of our C++ APIs, including constructors
// for Scalar as well as the data() and item() accessors on Tensor
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
_(uint8_t, Byte) \
_(int8_t, Char) \
@@ -157,6 +179,8 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)

#undef SPECIALIZE_CppTypeToScalarType

// NB: despite their generic-sounding names, the macros that don't take _AND
// are mostly only used by tensorexpr
#define AT_FORALL_INT_TYPES(_) \
_(uint8_t, Byte) \
_(int8_t, Char) \
@@ -173,6 +197,11 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
_(float, Float) \
_(double, Double)

// These macros often control how many template instantiations we create
// for kernels. It is typically inappropriate to add new dtypes here;
// instead, new types should be added at use sites on a case-by-case basis.
// We are generally not accepting new dtypes here due to binary size concerns.

#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
@@ -384,7 +413,9 @@ static inline size_t elementSize(ScalarType t) {
static inline bool isIntegralType(ScalarType t, bool includeBool) {
bool isIntegral =
(t == ScalarType::Byte || t == ScalarType::Char || t == ScalarType::Int ||
t == ScalarType::Long || t == ScalarType::Short);
t == ScalarType::Long || t == ScalarType::Short ||
t == ScalarType::UInt16 || t == ScalarType::UInt32 ||
t == ScalarType::UInt64);

return isIntegral || (includeBool && t == ScalarType::Bool);
}
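A small sketch of the widened predicate (illustrative calls only):

#include <c10/core/ScalarType.h>

void integral_sketch() {
  using c10::ScalarType;
  // The new barebones unsigned dtypes now count as integral:
  bool u = c10::isIntegralType(ScalarType::UInt32, /*includeBool=*/false); // true
  bool b = c10::isIntegralType(ScalarType::Bool, /*includeBool=*/false);   // false
  bool c = c10::isIntegralType(ScalarType::Bool, /*includeBool=*/true);    // true
}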
@@ -428,6 +459,11 @@ static inline bool isBitsType(ScalarType t) {
t == ScalarType::Bits16;
}

static inline bool isBarebonesUnsignedType(ScalarType t) {
  return t == ScalarType::UInt16 || t == ScalarType::UInt32 ||
      t == ScalarType::UInt64;
}

Review thread on isBarebonesUnsignedType:

Collaborator: Should UInt8 be in here as well? I guess not, given how you use it for the type promotion checks. But this might be a confusing API for our C++ devs?

Contributor Author: "Bare bones" here is defined to mean "we have only minimal kernel support in traditional C++ eager mode". Since uint8 is grandfathered in from the Lua days with lots of kernels, it shouldn't be included.

Collaborator: What is the plan for documenting that, to make sure user expectations for these dtypes are appropriate?

Contributor Author: The most logical place to document that these are bare bones is in torch.Tensor, where we list supported dtypes. Note that the f8 dtypes are not documented right now, so we could also keep these under the radar, undocumented, until enough stuff is working. I am also half expecting a select few eager kernels to sprout arbitrary unsigned support; gather-like operations are what I think people are most likely to want.

Collaborator: One could argue that the fp8 dtypes are not very well documented themselves and shouldn't be the example here. Also, there is strong alignment that we will not implement any arithmetic ops for those dtypes (no scale means no compute with the value). These unsigned dtypes, by contrast, could support the full set of operations, so we most likely want some details about that?

Contributor Author: I'll post a proposed doc PR and we can discuss it there.

static inline ScalarType toQIntType(ScalarType t) {
switch (t) {
case ScalarType::Byte:
1 change: 0 additions & 1 deletion c10/util/typeid.h
@@ -670,7 +670,6 @@ inline std::ostream& operator<<(
}

CAFFE_DECLARE_KNOWN_TYPE(std::string, std_string)
CAFFE_DECLARE_KNOWN_TYPE(uint16_t, uint16_t)
CAFFE_DECLARE_KNOWN_TYPE(char, char)
CAFFE_DECLARE_KNOWN_TYPE(std::unique_ptr<std::mutex>, std_unique_ptr_std_mutex)
CAFFE_DECLARE_KNOWN_TYPE(
6 changes: 6 additions & 0 deletions torch/csrc/utils/tensor_dtypes.cpp
@@ -16,6 +16,12 @@ std::pair<std::string, std::string> getDtypeNames(at::ScalarType scalarType) {
// no "byte" because byte is signed in numpy and we overload
// byte to mean bool often
return std::make_pair("uint8", "");
case at::ScalarType::UInt16:
return std::make_pair("uint16", "");
case at::ScalarType::UInt32:
return std::make_pair("uint32", "");
case at::ScalarType::UInt64:
return std::make_pair("uint64", "");
case at::ScalarType::Char:
// no "char" because it is not consistently signed or unsigned; we want
// to move to int8
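These names are what the Python binding layer ultimately exposes. A hedged sketch of the resulting mapping, assuming getDtypeNames sits in torch::utils as the rest of this file does:

#include <torch/csrc/utils/tensor_dtypes.h>

void dtype_name_sketch() {
  auto names = torch::utils::getDtypeNames(at::ScalarType::UInt16);
  // names.first == "uint16"  -> surfaces in Python as torch.uint16
  // names.second == ""       -> no legacy alias (cf. the "byte" comment above)
}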