Skip to content

Commit

Permalink
Add finfo properties for float8 dtypes (#109744)
Browse files Browse the repository at this point in the history
Add float8 finfo checks to `test_type_info.py`
Fixes #109737
Pull Request resolved: #109744
Approved by: https://github.com/drisspg

(cherry picked from commit cddd0db)
  • Loading branch information
malfet committed Sep 21, 2023
1 parent e534243 commit 385c12d
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 11 deletions.
16 changes: 16 additions & 0 deletions aten/src/ATen/Dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,22 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))

#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)

#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, \
NAME, \
AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4( \
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))

#define AT_DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
Expand Down
20 changes: 20 additions & 0 deletions test/test_type_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,5 +72,25 @@ def test_finfo(self):
# Restore the default type to ensure that the test has no side effect
torch.set_default_dtype(initial_default_type)

# Special test case for Float8_E5M2
xinfo = torch.finfo(torch.float8_e5m2)
self.assertEqual(xinfo.bits, 8)
self.assertEqual(xinfo.max, 57344.0)
self.assertEqual(xinfo.min, -57344.0)
self.assertEqual(xinfo.eps, .25)
self.assertEqual(xinfo.tiny, 6.10352e-05)
self.assertEqual(xinfo.resolution, 1.0)
self.assertEqual(xinfo.dtype, "float8_e5m2")

# Special test case for Float8_E4M3FN
xinfo = torch.finfo(torch.float8_e4m3fn)
self.assertEqual(xinfo.bits, 8)
self.assertEqual(xinfo.max, 448.0)
self.assertEqual(xinfo.min, -448.0)
self.assertEqual(xinfo.eps, .125)
self.assertEqual(xinfo.tiny, 0.015625)
self.assertEqual(xinfo.resolution, 1.0)
self.assertEqual(xinfo.dtype, "float8_e4m3fn")

if __name__ == '__main__':
run_tests()
54 changes: 43 additions & 11 deletions torch/csrc/TypeInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,25 +113,43 @@ static PyObject* THPDTypeInfo_bits(THPDTypeInfo* self, void*) {
}

static PyObject* THPFInfo_eps(THPFInfo* self, void*) {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
at::kHalf, at::ScalarType::BFloat16, self->type, "epsilon", [] {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
at::kHalf,
at::ScalarType::BFloat16,
at::ScalarType::Float8_e4m3fn,
at::ScalarType::Float8_e5m2,
self->type,
"epsilon",
[] {
return PyFloat_FromDouble(
std::numeric_limits<
at::scalar_value_type<scalar_t>::type>::epsilon());
});
}

static PyObject* THPFInfo_max(THPFInfo* self, void*) {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
at::kHalf, at::ScalarType::BFloat16, self->type, "max", [] {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
at::kHalf,
at::ScalarType::BFloat16,
at::ScalarType::Float8_e4m3fn,
at::ScalarType::Float8_e5m2,
self->type,
"max",
[] {
return PyFloat_FromDouble(
std::numeric_limits<at::scalar_value_type<scalar_t>::type>::max());
});
}

static PyObject* THPFInfo_min(THPFInfo* self, void*) {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
at::kHalf, at::ScalarType::BFloat16, self->type, "lowest", [] {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
at::kHalf,
at::ScalarType::BFloat16,
at::ScalarType::Float8_e4m3fn,
at::ScalarType::Float8_e5m2,
self->type,
"lowest",
[] {
return PyFloat_FromDouble(
std::numeric_limits<
at::scalar_value_type<scalar_t>::type>::lowest());
Expand Down Expand Up @@ -170,8 +188,14 @@ static PyObject* THPIInfo_dtype(THPIInfo* self, void*) {
}

static PyObject* THPFInfo_smallest_normal(THPFInfo* self, void*) {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
at::kHalf, at::ScalarType::BFloat16, self->type, "min", [] {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
at::kHalf,
at::ScalarType::BFloat16,
at::ScalarType::Float8_e4m3fn,
at::ScalarType::Float8_e5m2,
self->type,
"smallest",
[] {
return PyFloat_FromDouble(
std::numeric_limits<at::scalar_value_type<scalar_t>::type>::min());
});
Expand All @@ -183,8 +207,14 @@ static PyObject* THPFInfo_tiny(THPFInfo* self, void*) {
}

static PyObject* THPFInfo_resolution(THPFInfo* self, void*) {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
at::kHalf, at::ScalarType::BFloat16, self->type, "digits10", [] {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
at::kHalf,
at::ScalarType::BFloat16,
at::ScalarType::Float8_e4m3fn,
at::ScalarType::Float8_e5m2,
self->type,
"digits10",
[] {
return PyFloat_FromDouble(std::pow(
10,
-std::numeric_limits<
Expand All @@ -194,9 +224,11 @@ static PyObject* THPFInfo_resolution(THPFInfo* self, void*) {

static PyObject* THPFInfo_dtype(THPFInfo* self, void*) {
auto primary_name = torch::utils::getDtypeNames(self->type).first;
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
at::kHalf,
at::ScalarType::BFloat16,
at::ScalarType::Float8_e4m3fn,
at::ScalarType::Float8_e5m2,
self->type,
"dtype",
[&primary_name] { return PyUnicode_FromString(primary_name.data()); });
Expand Down

0 comments on commit 385c12d

Please sign in to comment.