@@ -387,14 +387,19 @@ inline bool isComplexType(exec_aten::ScalarType t) {
387387 t == exec_aten::ScalarType::ComplexDouble);
388388}
389389
390- inline bool isBitsType (exec_aten::ScalarType t) {
390+ constexpr bool isBitsType (exec_aten::ScalarType t) {
391391 return t == exec_aten::ScalarType::Bits1x8 ||
392392 t == exec_aten::ScalarType::Bits2x4 ||
393393 t == exec_aten::ScalarType::Bits4x2 ||
394394 t == exec_aten::ScalarType::Bits8 || t == exec_aten::ScalarType::Bits16;
395395}
396396
397- inline bool isQIntType (exec_aten::ScalarType t) {
397+ template <typename T>
398+ struct is_bits_type
399+ : std::integral_constant<bool , isBitsType(CppTypeToScalarType<T>::value)> {
400+ };
401+
402+ constexpr bool isQIntType (exec_aten::ScalarType t) {
398403 // Don't forget to extend this when adding new QInt types
399404 return t == exec_aten::ScalarType::QInt8 ||
400405 t == exec_aten::ScalarType::QUInt8 ||
@@ -403,6 +408,11 @@ inline bool isQIntType(exec_aten::ScalarType t) {
403408 t == exec_aten::ScalarType::QUInt2x4;
404409}
405410
411+ template <typename T>
412+ struct is_qint_type
413+ : std::integral_constant<bool , isQIntType(CppTypeToScalarType<T>::value)> {
414+ };
415+
406416inline exec_aten::ScalarType toQIntType (exec_aten::ScalarType t) {
407417 switch (t) {
408418 case exec_aten::ScalarType::Byte:
@@ -550,6 +560,225 @@ To convert(From val) {
550560 return static_cast <To>(val);
551561}
552562
563+ namespace internal {
564+ template <typename T1, typename T2>
565+ struct promote_types_lookup ;
566+
567+ template <typename T1>
568+ struct promote_types_lookup <T1, T1> {
569+ using type = T1;
570+ };
571+
572+ using U1 = typename ScalarTypeToCppType<exec_aten::ScalarType::Byte>::type;
573+ using I1 = typename ScalarTypeToCppType<exec_aten::ScalarType::Char>::type;
574+ using I2 = typename ScalarTypeToCppType<exec_aten::ScalarType::Short>::type;
575+ using I4 = typename ScalarTypeToCppType<exec_aten::ScalarType::Int>::type;
576+ using I8 = typename ScalarTypeToCppType<exec_aten::ScalarType::Long>::type;
577+ using F2 = typename ScalarTypeToCppType<exec_aten::ScalarType::Half>::type;
578+ using F4 = typename ScalarTypeToCppType<exec_aten::ScalarType::Float>::type;
579+ using F8 = typename ScalarTypeToCppType<exec_aten::ScalarType::Double>::type;
580+ using C2 =
581+ typename ScalarTypeToCppType<exec_aten::ScalarType::ComplexHalf>::type;
582+ using C4 =
583+ typename ScalarTypeToCppType<exec_aten::ScalarType::ComplexFloat>::type;
584+ using C8 =
585+ typename ScalarTypeToCppType<exec_aten::ScalarType::ComplexDouble>::type;
586+ using B1 = typename ScalarTypeToCppType<exec_aten::ScalarType::Bool>::type;
587+
588+ #define TABLE_ENTRY (key1, key2, value ) \
589+ template <> \
590+ struct promote_types_lookup <key1, key2> { \
591+ using type = value; \
592+ }
593+
594+ /* promote_types_lookup is a compile-time-accessible version of the
595+ * table in promoteTypes below; we cannot make promoteTypes constexpr
596+ * and use it directly because we are on C++11 and thus don't have
597+ * C++17 relaxed constexpr. The below series of entries is generated
598+ * by genScalarTypeTable.py. */
599+ TABLE_ENTRY (U1, U1, U1);
600+ TABLE_ENTRY (U1, I1, I2);
601+ TABLE_ENTRY (U1, I2, I2);
602+ TABLE_ENTRY (U1, I4, I4);
603+ TABLE_ENTRY (U1, I8, I8);
604+ TABLE_ENTRY (U1, F2, F2);
605+ TABLE_ENTRY (U1, F4, F4);
606+ TABLE_ENTRY (U1, F8, F8);
607+ TABLE_ENTRY (U1, C2, C2);
608+ TABLE_ENTRY (U1, C4, C4);
609+ TABLE_ENTRY (U1, C8, C8);
610+ TABLE_ENTRY (U1, B1, U1);
611+ TABLE_ENTRY (I1, U1, I2);
612+ TABLE_ENTRY (I1, I1, I1);
613+ TABLE_ENTRY (I1, I2, I2);
614+ TABLE_ENTRY (I1, I4, I4);
615+ TABLE_ENTRY (I1, I8, I8);
616+ TABLE_ENTRY (I1, F2, F2);
617+ TABLE_ENTRY (I1, F4, F4);
618+ TABLE_ENTRY (I1, F8, F8);
619+ TABLE_ENTRY (I1, C2, C2);
620+ TABLE_ENTRY (I1, C4, C4);
621+ TABLE_ENTRY (I1, C8, C8);
622+ TABLE_ENTRY (I1, B1, I1);
623+ TABLE_ENTRY (I2, U1, I2);
624+ TABLE_ENTRY (I2, I1, I2);
625+ TABLE_ENTRY (I2, I2, I2);
626+ TABLE_ENTRY (I2, I4, I4);
627+ TABLE_ENTRY (I2, I8, I8);
628+ TABLE_ENTRY (I2, F2, F2);
629+ TABLE_ENTRY (I2, F4, F4);
630+ TABLE_ENTRY (I2, F8, F8);
631+ TABLE_ENTRY (I2, C2, C2);
632+ TABLE_ENTRY (I2, C4, C4);
633+ TABLE_ENTRY (I2, C8, C8);
634+ TABLE_ENTRY (I2, B1, I2);
635+ TABLE_ENTRY (I4, U1, I4);
636+ TABLE_ENTRY (I4, I1, I4);
637+ TABLE_ENTRY (I4, I2, I4);
638+ TABLE_ENTRY (I4, I4, I4);
639+ TABLE_ENTRY (I4, I8, I8);
640+ TABLE_ENTRY (I4, F2, F2);
641+ TABLE_ENTRY (I4, F4, F4);
642+ TABLE_ENTRY (I4, F8, F8);
643+ TABLE_ENTRY (I4, C2, C2);
644+ TABLE_ENTRY (I4, C4, C4);
645+ TABLE_ENTRY (I4, C8, C8);
646+ TABLE_ENTRY (I4, B1, I4);
647+ TABLE_ENTRY (I8, U1, I8);
648+ TABLE_ENTRY (I8, I1, I8);
649+ TABLE_ENTRY (I8, I2, I8);
650+ TABLE_ENTRY (I8, I4, I8);
651+ TABLE_ENTRY (I8, I8, I8);
652+ TABLE_ENTRY (I8, F2, F2);
653+ TABLE_ENTRY (I8, F4, F4);
654+ TABLE_ENTRY (I8, F8, F8);
655+ TABLE_ENTRY (I8, C2, C2);
656+ TABLE_ENTRY (I8, C4, C4);
657+ TABLE_ENTRY (I8, C8, C8);
658+ TABLE_ENTRY (I8, B1, I8);
659+ TABLE_ENTRY (F2, U1, F2);
660+ TABLE_ENTRY (F2, I1, F2);
661+ TABLE_ENTRY (F2, I2, F2);
662+ TABLE_ENTRY (F2, I4, F2);
663+ TABLE_ENTRY (F2, I8, F2);
664+ TABLE_ENTRY (F2, F2, F2);
665+ TABLE_ENTRY (F2, F4, F4);
666+ TABLE_ENTRY (F2, F8, F8);
667+ TABLE_ENTRY (F2, C2, C2);
668+ TABLE_ENTRY (F2, C4, C4);
669+ TABLE_ENTRY (F2, C8, C8);
670+ TABLE_ENTRY (F2, B1, F2);
671+ TABLE_ENTRY (F4, U1, F4);
672+ TABLE_ENTRY (F4, I1, F4);
673+ TABLE_ENTRY (F4, I2, F4);
674+ TABLE_ENTRY (F4, I4, F4);
675+ TABLE_ENTRY (F4, I8, F4);
676+ TABLE_ENTRY (F4, F2, F4);
677+ TABLE_ENTRY (F4, F4, F4);
678+ TABLE_ENTRY (F4, F8, F8);
679+ TABLE_ENTRY (F4, C2, C4);
680+ TABLE_ENTRY (F4, C4, C4);
681+ TABLE_ENTRY (F4, C8, C8);
682+ TABLE_ENTRY (F4, B1, F4);
683+ TABLE_ENTRY (F8, U1, F8);
684+ TABLE_ENTRY (F8, I1, F8);
685+ TABLE_ENTRY (F8, I2, F8);
686+ TABLE_ENTRY (F8, I4, F8);
687+ TABLE_ENTRY (F8, I8, F8);
688+ TABLE_ENTRY (F8, F2, F8);
689+ TABLE_ENTRY (F8, F4, F8);
690+ TABLE_ENTRY (F8, F8, F8);
691+ TABLE_ENTRY (F8, C2, C8);
692+ TABLE_ENTRY (F8, C4, C8);
693+ TABLE_ENTRY (F8, C8, C8);
694+ TABLE_ENTRY (F8, B1, F8);
695+ TABLE_ENTRY (C2, U1, C2);
696+ TABLE_ENTRY (C2, I1, C2);
697+ TABLE_ENTRY (C2, I2, C2);
698+ TABLE_ENTRY (C2, I4, C2);
699+ TABLE_ENTRY (C2, I8, C2);
700+ TABLE_ENTRY (C2, F2, C2);
701+ TABLE_ENTRY (C2, F4, C4);
702+ TABLE_ENTRY (C2, F8, C8);
703+ TABLE_ENTRY (C2, C2, C2);
704+ TABLE_ENTRY (C2, C4, C4);
705+ TABLE_ENTRY (C2, C8, C8);
706+ TABLE_ENTRY (C2, B1, C2);
707+ TABLE_ENTRY (C4, U1, C4);
708+ TABLE_ENTRY (C4, I1, C4);
709+ TABLE_ENTRY (C4, I2, C4);
710+ TABLE_ENTRY (C4, I4, C4);
711+ TABLE_ENTRY (C4, I8, C4);
712+ TABLE_ENTRY (C4, F2, C4);
713+ TABLE_ENTRY (C4, F4, C4);
714+ TABLE_ENTRY (C4, F8, C8);
715+ TABLE_ENTRY (C4, C2, C4);
716+ TABLE_ENTRY (C4, C4, C4);
717+ TABLE_ENTRY (C4, C8, C8);
718+ TABLE_ENTRY (C4, B1, C4);
719+ TABLE_ENTRY (C8, U1, C8);
720+ TABLE_ENTRY (C8, I1, C8);
721+ TABLE_ENTRY (C8, I2, C8);
722+ TABLE_ENTRY (C8, I4, C8);
723+ TABLE_ENTRY (C8, I8, C8);
724+ TABLE_ENTRY (C8, F2, C8);
725+ TABLE_ENTRY (C8, F4, C8);
726+ TABLE_ENTRY (C8, F8, C8);
727+ TABLE_ENTRY (C8, C2, C8);
728+ TABLE_ENTRY (C8, C4, C8);
729+ TABLE_ENTRY (C8, C8, C8);
730+ TABLE_ENTRY (C8, B1, C8);
731+ TABLE_ENTRY (B1, U1, U1);
732+ TABLE_ENTRY (B1, I1, I1);
733+ TABLE_ENTRY (B1, I2, I2);
734+ TABLE_ENTRY (B1, I4, I4);
735+ TABLE_ENTRY (B1, I8, I8);
736+ TABLE_ENTRY (B1, F2, F2);
737+ TABLE_ENTRY (B1, F4, F4);
738+ TABLE_ENTRY (B1, F8, F8);
739+ TABLE_ENTRY (B1, C2, C2);
740+ TABLE_ENTRY (B1, C4, C4);
741+ TABLE_ENTRY (B1, C8, C8);
742+ TABLE_ENTRY (B1, B1, B1);
743+
744+ } // namespace internal
745+
746+ template <typename T1, typename T2, bool half_to_float = false >
747+ struct promote_types {
748+ private:
749+ static_assert (
750+ std::is_same<T1, T2>::value ||
751+ (!is_qint_type<T1>::value && !is_qint_type<T2>::value),
752+ " promote_types not valid for quantized dtypes" );
753+ static_assert (
754+ std::is_same<T1, T2>::value ||
755+ (!is_bits_type<T1>::value && !is_bits_type<T2>::value),
756+ " promote_types not valid for bits dtypes" );
757+
758+ static_assert (
759+ !std::is_same<
760+ T1,
761+ typename ScalarTypeToCppType<exec_aten::ScalarType::BFloat16>::type>::
762+ value &&
763+ !std::is_same<
764+ T2,
765+ typename ScalarTypeToCppType<
766+ exec_aten::ScalarType::BFloat16>::type>::value,
767+ " promote_types not valid for BFloat16" );
768+ using promoted_type_not_respecting_half_to_float =
769+ typename internal::promote_types_lookup<T1, T2>::type;
770+
771+ public:
772+ using type = typename std::conditional<
773+ half_to_float &&
774+ std::is_same<
775+ promoted_type_not_respecting_half_to_float,
776+ typename ScalarTypeToCppType<exec_aten::ScalarType::Half>::type>::
777+ value,
778+ typename ScalarTypeToCppType<exec_aten::ScalarType::Float>::type,
779+ promoted_type_not_respecting_half_to_float>::type;
780+ };
781+
553782/* *
554783 * Implements type promotion rules that are consistent with ATen behaviour,
555784 * which in turn is consistent with NumPy's promote_types.
@@ -589,6 +818,10 @@ inline exec_aten::ScalarType promoteTypes(
589818 ET_CHECK_MSG (false , " promoteTypes not valid for bits dtypes" );
590819 }
591820
821+ ET_CHECK_MSG (
822+ a != exec_aten::ScalarType::BFloat16 &&
823+ b != exec_aten::ScalarType::BFloat16,
824+ " promoteTypes not valid for BFloat16" );
592825 // 12 types are handled by this function, see the constexpr definitions above
593826 const int NUM_PROMOTE_TYPES = 12 ;
594827
0 commit comments