Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reland fast TypeMeta/ScalarType conversion #45544

Closed
wants to merge 12 commits into from
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions aten/src/ATen/native/DispatchStub.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
#include <c10/core/Backend.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>

#include <type_traits>
#include <atomic>

// Implements instruction set specific function dispatch.
//
Expand Down
1 change: 1 addition & 0 deletions aten/src/ATen/templates/TensorBody.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <c10/core/Stream.h>
#include <c10/core/Scalar.h>
#include <c10/core/ScalarType.h>
#include <c10/core/ScalarTypeToTypeMeta.h>
#include <c10/core/Storage.h>
#include <ATen/core/TensorAccessor.h>
#include <c10/core/TensorImpl.h>
Expand Down
1 change: 1 addition & 0 deletions aten/src/TH/THStorageFunctions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <TH/THStorageFunctions.h>

#include <c10/core/ScalarType.h>
#include <c10/core/ScalarTypeToTypeMeta.h>

// Note [Weak references for intrusive refcounting]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
4 changes: 2 additions & 2 deletions c10/core/DefaultDtype.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

namespace c10 {
static auto default_dtype = caffe2::TypeMeta::Make<float>();
static auto default_dtype_as_scalartype = typeMetaToScalarType(default_dtype);
static auto default_dtype_as_scalartype = default_dtype.toScalarType();
static auto default_complex_dtype = caffe2::TypeMeta::Make<c10::complex<float>>();

void set_default_dtype(caffe2::TypeMeta dtype) {
default_dtype = std::move(dtype);
default_dtype_as_scalartype = typeMetaToScalarType(default_dtype);
default_dtype_as_scalartype = default_dtype.toScalarType();
if(default_dtype_as_scalartype == ScalarType::Double) {
default_complex_dtype = std::move(caffe2::TypeMeta::Make<c10::complex<double>>());
} else {
Expand Down
67 changes: 7 additions & 60 deletions c10/core/ScalarType.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
#include <c10/util/ArrayRef.h>
#include <c10/util/complex.h>
#include <c10/util/Half.h>
#include <c10/util/qint32.h>
#include <c10/util/qint8.h>
#include <c10/util/quint8.h>
#include <c10/util/BFloat16.h>
#include <c10/util/quint4x2.h>
#include <c10/util/Optional.h>
#include <c10/util/typeid.h>

#include <complex>
#include <cstdint>
Expand Down Expand Up @@ -68,6 +71,8 @@ enum class ScalarType : int8_t {
NumOptions
};

constexpr uint16_t NumScalarTypes = static_cast<uint16_t>(ScalarType::NumOptions);
jeffdaily marked this conversation as resolved.
Show resolved Hide resolved

namespace impl {

// These are used to map ScalarTypes to C++ types.
Expand All @@ -94,7 +99,7 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType)

#undef SPECIALIZE_ScalarTypeToCPPType

}
} // namespace impl

template <typename T>
struct CppTypeToScalarType;
Expand Down Expand Up @@ -162,64 +167,6 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
_(c10::complex<float>, ComplexFloat) \
_(c10::complex<double>, ComplexDouble)

static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) {
#define DEFINE_CASE(ctype, name) \
case ScalarType::name: \
return caffe2::TypeMeta::Make<ctype>();

switch (scalar_type) {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
case ScalarType::Undefined:
return caffe2::TypeMeta();
default:
AT_ERROR(
"Unrecognized Scalartype ",
scalar_type,
" (please report this error)");
}
#undef DEFINE_CASE
}

static inline c10::optional<ScalarType> tryTypeMetaToScalarType(
caffe2::TypeMeta dtype) {
#define DEFINE_IF(ctype, name) \
if (dtype == caffe2::TypeMeta::Make<ctype>()) { \
return {ScalarType::name}; \
}
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_IF)
#undef DEFINE_IF
if (dtype == caffe2::TypeMeta()) {
return {ScalarType::Undefined};
}
return c10::nullopt;
}

static inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) {
if (auto scalar_type = tryTypeMetaToScalarType(dtype)) {
return *scalar_type;
}
AT_ERROR(
"Unsupported TypeMeta in ATen: ", dtype, " (please report this error)");
}

inline optional<at::ScalarType> optTypeMetaToScalarType(optional<caffe2::TypeMeta> type_meta) {
if (!type_meta.has_value()) {
return c10::nullopt;
}
return typeMetaToScalarType(*type_meta);
}

static inline bool operator==(ScalarType t, caffe2::TypeMeta m) {
if (auto mt = tryTypeMetaToScalarType(m)) {
return (*mt) == t;
}
return false;
}

static inline bool operator==(caffe2::TypeMeta m, ScalarType t) {
return t == m;
}

#define DEFINE_CONSTANT(_, name) \
constexpr ScalarType k##name = ScalarType::name;

Expand Down
47 changes: 47 additions & 0 deletions c10/core/ScalarTypeToTypeMeta.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#pragma once

#include <c10/core/ScalarType.h>
#include <c10/util/typeid.h>

// these just expose TypeMeta/ScalarType bridge functions in c10
// TODO move to typeid.h (or codemod away) when TypeMeta et al
// are moved from caffe2 to c10 (see note at top of typeid.h)

namespace c10 {

/**
* convert ScalarType enum values to TypeMeta handles
*/
static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) {
return caffe2::TypeMeta::fromScalarType(scalar_type);
}

/**
* convert TypeMeta handles to ScalarType enum values
*/
static inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) {
return dtype.toScalarType();
}

/**
* typeMetaToScalarType(), lifted to optional
*/
static inline optional<at::ScalarType> optTypeMetaToScalarType(optional<caffe2::TypeMeta> type_meta) {
if (!type_meta.has_value()) {
return c10::nullopt;
}
return type_meta->toScalarType();
}

/**
* convenience: equality across TypeMeta/ScalarType conversion
*/
static inline bool operator==(ScalarType t, caffe2::TypeMeta m) {
return m.isScalarType(t);
}

static inline bool operator==(caffe2::TypeMeta m, ScalarType t) {
return t == m;
}

} // namespace c10
3 changes: 1 addition & 2 deletions c10/core/TensorImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2::
data_type_(data_type),
device_opt_(device_opt) {
if (!key_set.empty()) {
AT_ASSERT(data_type.id() == caffe2::TypeIdentifier::uninitialized() ||
device_opt_.has_value());
TORCH_INTERNAL_ASSERT(data_type == ScalarType::Undefined || device_opt_.has_value());
// UndefinedTensorImpl is a singleton, so we skip logging it
C10_LOG_API_USAGE_ONCE("tensor.create");
}
Expand Down
4 changes: 2 additions & 2 deletions c10/core/TensorImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1781,13 +1781,13 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
// strides SmallVector (pre-allocated 4)
// storage offset
// numel
// data type pointer
// data type
// (optional) device
// tensor type id
// miscellaneous bitfield
//
static_assert(sizeof(void*) != sizeof(int64_t) || // if 64-bit...
sizeof(TensorImpl) == sizeof(int64_t) * 31,
sizeof(TensorImpl) == sizeof(int64_t) * 30,
"You changed the size of TensorImpl on 64-bit arch."
"See Note [TensorImpl size constraints] on how to proceed.");
} // namespace c10
1 change: 1 addition & 0 deletions c10/core/TensorOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <c10/core/Backend.h>
#include <c10/core/Layout.h>
#include <c10/core/ScalarType.h>
#include <c10/core/ScalarTypeToTypeMeta.h>
#include <c10/core/Device.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/DispatchKeySet.h>
Expand Down
2 changes: 0 additions & 2 deletions c10/core/UndefinedTensorImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ struct C10_API UndefinedTensorImpl final : public TensorImpl {
private:
UndefinedTensorImpl();
static UndefinedTensorImpl _singleton;
public:
friend struct UndefinedType;
};

} // namespace c10
63 changes: 29 additions & 34 deletions c10/util/typeid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,42 +14,41 @@ namespace detail {
C10_EXPORT void _ThrowRuntimeTypeLogicError(const string& msg) {
// In earlier versions it used to be std::abort() but it's a bit hard-core
// for a library
AT_ERROR(msg);
TORCH_CHECK(false, msg);
}
} // namespace detail

[[noreturn]] void TypeMeta::error_unsupported_typemeta(caffe2::TypeMeta dtype) {
TORCH_CHECK(false, "Unsupported TypeMeta in ATen: ", dtype, " (please report this error)");
}

} // namespace detail
// see TypeMeta::addTypeMetaData
std::atomic<uint16_t> TypeMeta::nextTypeIndex(NumScalarTypes);

template <>
EXPORT_IF_NOT_GCC const detail::TypeMetaData* TypeMeta::_typeMetaDataInstance<
detail::_Uninitialized>() noexcept {
static constexpr detail::TypeMetaData singleton = detail::TypeMetaData(
0,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
TypeIdentifier::uninitialized(),
"nullptr (uninitialized)");
return &singleton;
// fixed length array of TypeMetaData instances
detail::TypeMetaData* TypeMeta::typeMetaDatas() {
static detail::TypeMetaData instances[MaxTypeIndex + 1] = {
#define SCALAR_TYPE_META(T, name) \
/* ScalarType::name */ \
detail::TypeMetaData( \
sizeof(T), \
detail::_PickNew<T>(), \
detail::_PickPlacementNew<T>(), \
detail::_PickCopy<T>(), \
detail::_PickPlacementDelete<T>(), \
detail::_PickDelete<T>(), \
TypeIdentifier::Get<T>(), \
c10::util::get_fully_qualified_type_name<T>()),
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SCALAR_TYPE_META)
#undef SCALAR_TYPE_META
// The remainder of the array is padded with TypeMetaData blanks.
// The first of these is the entry for ScalarType::Undefined.
// The rest are consumed by CAFFE_KNOWN_TYPE entries.
};
return instances;
}

CAFFE_KNOWN_TYPE(uint8_t)
CAFFE_KNOWN_TYPE(int8_t)
CAFFE_KNOWN_TYPE(int16_t)
CAFFE_KNOWN_TYPE(int)
CAFFE_KNOWN_TYPE(int64_t)
CAFFE_KNOWN_TYPE(at::Half)
CAFFE_KNOWN_TYPE(float)
CAFFE_KNOWN_TYPE(double)
CAFFE_KNOWN_TYPE(c10::complex<c10::Half>)
CAFFE_KNOWN_TYPE(c10::complex<float>)
CAFFE_KNOWN_TYPE(c10::complex<double>)
// 11 = undefined type id
// 12 = Tensor (defined in tensor.cc)
CAFFE_KNOWN_TYPE(std::string)
CAFFE_KNOWN_TYPE(bool)
CAFFE_KNOWN_TYPE(uint16_t)
CAFFE_KNOWN_TYPE(char)
CAFFE_KNOWN_TYPE(std::unique_ptr<std::mutex>)
Expand Down Expand Up @@ -79,15 +78,11 @@ using _guard_long_unique = std::conditional_t<
_guard_long_unique_dummy<T>,
T>;
} // namespace detail

CAFFE_KNOWN_TYPE(detail::_guard_long_unique<long>);
CAFFE_KNOWN_TYPE(detail::_guard_long_unique<std::vector<long>>)

CAFFE_KNOWN_TYPE(float*)
CAFFE_KNOWN_TYPE(at::Half*)
CAFFE_KNOWN_TYPE(c10::qint8)
CAFFE_KNOWN_TYPE(c10::quint8)
CAFFE_KNOWN_TYPE(c10::qint32)
CAFFE_KNOWN_TYPE(at::BFloat16)
CAFFE_KNOWN_TYPE(c10::quint4x2)

} // namespace caffe2