Skip to content

Commit

Permalink
Handle C++ exceptions raised during finfo/iinfo calls
Browse files Browse the repository at this point in the history
ghstack-source-id: 6be25539981a8e8292b6d54d760309cac5c9f6dc
Pull Request resolved: #109743
  • Loading branch information
malfet committed Sep 20, 2023
1 parent 1ab9b61 commit 9947a47
Showing 1 changed file with 28 additions and 5 deletions.
33 changes: 28 additions & 5 deletions torch/csrc/TypeInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ static PyObject* THPDTypeInfo_bits(THPDTypeInfo* self, void*) {
}

static PyObject* THPFInfo_eps(THPFInfo* self, void*) {
HANDLE_TH_ERRORS
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
at::kHalf,
at::ScalarType::BFloat16,
Expand All @@ -125,9 +126,11 @@ static PyObject* THPFInfo_eps(THPFInfo* self, void*) {
std::numeric_limits<
at::scalar_value_type<scalar_t>::type>::epsilon());
});
END_HANDLE_TH_ERRORS
}

static PyObject* THPFInfo_max(THPFInfo* self, void*) {
HANDLE_TH_ERRORS
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
at::kHalf,
at::ScalarType::BFloat16,
Expand All @@ -139,9 +142,11 @@ static PyObject* THPFInfo_max(THPFInfo* self, void*) {
return PyFloat_FromDouble(
std::numeric_limits<at::scalar_value_type<scalar_t>::type>::max());
});
END_HANDLE_TH_ERRORS
}

static PyObject* THPFInfo_min(THPFInfo* self, void*) {
HANDLE_TH_ERRORS
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
at::kHalf,
at::ScalarType::BFloat16,
Expand All @@ -154,9 +159,11 @@ static PyObject* THPFInfo_min(THPFInfo* self, void*) {
std::numeric_limits<
at::scalar_value_type<scalar_t>::type>::lowest());
});
END_HANDLE_TH_ERRORS
}

static PyObject* THPIInfo_max(THPIInfo* self, void*) {
HANDLE_TH_ERRORS
if (at::isIntegralType(self->type, /*includeBool=*/false)) {
return AT_DISPATCH_INTEGRAL_TYPES(self->type, "max", [] {
return THPUtils_packInt64(std::numeric_limits<scalar_t>::max());
Expand All @@ -166,9 +173,11 @@ static PyObject* THPIInfo_max(THPIInfo* self, void*) {
return AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(self->type, "max", [] {
return THPUtils_packInt64(std::numeric_limits<underlying_t>::max());
});
END_HANDLE_TH_ERRORS
}

static PyObject* THPIInfo_min(THPIInfo* self, void*) {
HANDLE_TH_ERRORS
if (at::isIntegralType(self->type, /*includeBool=*/false)) {
return AT_DISPATCH_INTEGRAL_TYPES(self->type, "min", [] {
return THPUtils_packInt64(std::numeric_limits<scalar_t>::lowest());
Expand All @@ -178,16 +187,20 @@ static PyObject* THPIInfo_min(THPIInfo* self, void*) {
return AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(self->type, "min", [] {
return THPUtils_packInt64(std::numeric_limits<underlying_t>::lowest());
});
END_HANDLE_TH_ERRORS
}

static PyObject* THPIInfo_dtype(THPIInfo* self, void*) {
HANDLE_TH_ERRORS
auto primary_name = torch::utils::getDtypeNames(self->type).first;
return AT_DISPATCH_INTEGRAL_TYPES(self->type, "dtype", [&primary_name] {
return PyUnicode_FromString(primary_name.data());
});
END_HANDLE_TH_ERRORS
}

static PyObject* THPFInfo_smallest_normal(THPFInfo* self, void*) {
HANDLE_TH_ERRORS
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
at::kHalf,
at::ScalarType::BFloat16,
Expand All @@ -199,6 +212,7 @@ static PyObject* THPFInfo_smallest_normal(THPFInfo* self, void*) {
return PyFloat_FromDouble(
std::numeric_limits<at::scalar_value_type<scalar_t>::type>::min());
});
END_HANDLE_TH_ERRORS
}

static PyObject* THPFInfo_tiny(THPFInfo* self, void*) {
Expand All @@ -207,6 +221,7 @@ static PyObject* THPFInfo_tiny(THPFInfo* self, void*) {
}

static PyObject* THPFInfo_resolution(THPFInfo* self, void*) {
HANDLE_TH_ERRORS
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
at::kHalf,
at::ScalarType::BFloat16,
Expand All @@ -220,9 +235,11 @@ static PyObject* THPFInfo_resolution(THPFInfo* self, void*) {
-std::numeric_limits<
at::scalar_value_type<scalar_t>::type>::digits10));
});
END_HANDLE_TH_ERRORS
}

static PyObject* THPFInfo_dtype(THPFInfo* self, void*) {
HANDLE_TH_ERRORS
auto primary_name = torch::utils::getDtypeNames(self->type).first;
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4(
at::kHalf,
Expand All @@ -232,10 +249,12 @@ static PyObject* THPFInfo_dtype(THPFInfo* self, void*) {
self->type,
"dtype",
[&primary_name] { return PyUnicode_FromString(primary_name.data()); });
END_HANDLE_TH_ERRORS
}

PyObject* THPFInfo_str(THPFInfo* self) {
std::ostringstream oss;
const auto dtypeStr = THPFInfo_dtype(self, nullptr);
oss << "finfo(resolution="
<< PyFloat_AsDouble(THPFInfo_resolution(self, nullptr));
oss << ", min=" << PyFloat_AsDouble(THPFInfo_min(self, nullptr));
Expand All @@ -244,19 +263,23 @@ PyObject* THPFInfo_str(THPFInfo* self) {
oss << ", smallest_normal="
<< PyFloat_AsDouble(THPFInfo_smallest_normal(self, nullptr));
oss << ", tiny=" << PyFloat_AsDouble(THPFInfo_tiny(self, nullptr));
oss << ", dtype=" << PyUnicode_AsUTF8(THPFInfo_dtype(self, nullptr)) << ")";

return THPUtils_packString(oss.str().c_str());
if (dtypeStr != nullptr) {
oss << ", dtype=" << PyUnicode_AsUTF8(dtypeStr) << ")";
}
return !PyErr_Occurred() ? THPUtils_packString(oss.str().c_str()) : nullptr;
}

PyObject* THPIInfo_str(THPIInfo* self) {
std::ostringstream oss;

const auto dtypeStr = THPIInfo_dtype(self, nullptr);
oss << "iinfo(min=" << PyLong_AsDouble(THPIInfo_min(self, nullptr));
oss << ", max=" << PyLong_AsDouble(THPIInfo_max(self, nullptr));
oss << ", dtype=" << PyUnicode_AsUTF8(THPIInfo_dtype(self, nullptr)) << ")";
if (dtypeStr) {
oss << ", dtype=" << PyUnicode_AsUTF8(dtypeStr) << ")";
}

return THPUtils_packString(oss.str().c_str());
return !PyErr_Occurred() ? THPUtils_packString(oss.str().c_str()) : nullptr;
}

// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays)
Expand Down

0 comments on commit 9947a47

Please sign in to comment.