diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift b/extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift
index 5873787b423..a3873794f9d 100644
--- a/extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift
+++ b/extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift
@@ -1095,3 +1095,10 @@ public extension Tensor {
     ))
   }
 }
+
+@available(*, deprecated, message: "This API is experimental.")
+extension Tensor: CustomStringConvertible {
+  public var description: String {
+    self.anyTensor.description
+  }
+}
diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm
index d38fb277bff..3cf06207b45 100644
--- a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm
+++ b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm
@@ -17,6 +17,86 @@
 using namespace executorch::extension;
 using namespace executorch::runtime;
 
+static inline NSString *dataTypeDescription(ExecuTorchDataType dataType) {
+  switch (dataType) {
+    case ExecuTorchDataTypeByte:
+      return @"byte";
+    case ExecuTorchDataTypeChar:
+      return @"char";
+    case ExecuTorchDataTypeShort:
+      return @"short";
+    case ExecuTorchDataTypeInt:
+      return @"int";
+    case ExecuTorchDataTypeLong:
+      return @"long";
+    case ExecuTorchDataTypeHalf:
+      return @"half";
+    case ExecuTorchDataTypeFloat:
+      return @"float";
+    case ExecuTorchDataTypeDouble:
+      return @"double";
+    case ExecuTorchDataTypeComplexHalf:
+      return @"complexHalf";
+    case ExecuTorchDataTypeComplexFloat:
+      return @"complexFloat";
+    case ExecuTorchDataTypeComplexDouble:
+      return @"complexDouble";
+    case ExecuTorchDataTypeBool:
+      return @"bool";
+    case ExecuTorchDataTypeQInt8:
+      return @"qint8";
+    case ExecuTorchDataTypeQUInt8:
+      return @"quint8";
+    case ExecuTorchDataTypeQInt32:
+      return @"qint32";
+    case ExecuTorchDataTypeBFloat16:
+      return @"bfloat16";
+    case ExecuTorchDataTypeQUInt4x2:
+      return @"quint4x2";
+    case ExecuTorchDataTypeQUInt2x4:
+      return @"quint2x4";
+    case ExecuTorchDataTypeBits1x8:
+      return @"bits1x8";
+    case ExecuTorchDataTypeBits2x4:
+      return @"bits2x4";
+    case ExecuTorchDataTypeBits4x2:
+      return @"bits4x2";
+    case ExecuTorchDataTypeBits8:
+      return @"bits8";
+    case ExecuTorchDataTypeBits16:
+      return @"bits16";
+    case ExecuTorchDataTypeFloat8_e5m2:
+      return @"float8_e5m2";
+    case ExecuTorchDataTypeFloat8_e4m3fn:
+      return @"float8_e4m3fn";
+    case ExecuTorchDataTypeFloat8_e5m2fnuz:
+      return @"float8_e5m2fnuz";
+    case ExecuTorchDataTypeFloat8_e4m3fnuz:
+      return @"float8_e4m3fnuz";
+    case ExecuTorchDataTypeUInt16:
+      return @"uint16";
+    case ExecuTorchDataTypeUInt32:
+      return @"uint32";
+    case ExecuTorchDataTypeUInt64:
+      return @"uint64";
+    default:
+      return @"undefined";
+  }
+}
+
+static inline NSString *shapeDynamismDescription(ExecuTorchShapeDynamism dynamism) {
+  switch (dynamism) {
+    case ExecuTorchShapeDynamismStatic:
+      return @"static";
+    case ExecuTorchShapeDynamismDynamicBound:
+      return @"dynamicBound";
+    case ExecuTorchShapeDynamismDynamicUnbound:
+      return @"dynamicUnbound";
+    default:
+      return @"undefined";
+  }
+}
+
 NSInteger ExecuTorchSizeOfDataType(ExecuTorchDataType dataType) {
   return elementSize(static_cast<ScalarType>(dataType));
 }
@@ -150,6 +230,70 @@ - (BOOL)isEqual:(nullable id)other {
   return [self isEqualToTensor:(ExecuTorchTensor *)other];
 }
 
+- (NSString *)description {
+  std::ostringstream os;
+  os << "Tensor {";
+  os << "\n  dataType: " << dataTypeDescription(static_cast<ExecuTorchDataType>(_tensor->scalar_type())).UTF8String << ",";
+  os << "\n  shape: [";
+  const auto& sizes = _tensor->sizes();
+  for (size_t index = 0; index < sizes.size(); ++index) {
+    if (index > 0) {
+      os << ",";
+    }
+    os << sizes[index];
+  }
+  os << "],";
+  os << "\n  strides: [";
+  const auto& strides = _tensor->strides();
+  for (size_t index = 0; index < strides.size(); ++index) {
+    if (index > 0) {
+      os << ",";
+    }
+    os << strides[index];
+  }
+  os << "],";
+  os << "\n  dimensionOrder: [";
+  const auto& dim_order = _tensor->dim_order();
+  for (size_t index = 0; index < dim_order.size(); ++index) {
+    if (index > 0) {
+      os << ",";
+    }
+    os << static_cast<int32_t>(dim_order[index]);
+  }
+  os << "],";
+  os << "\n  shapeDynamism: " << shapeDynamismDescription(static_cast<ExecuTorchShapeDynamism>(_tensor->shape_dynamism())).UTF8String << ",";
+  auto const count = _tensor->numel();
+  os << "\n  count: " << count << ",";
+  os << "\n  scalars: [";
+  ET_SWITCH_REALHBBF16_TYPES(
+    static_cast<ScalarType>(_tensor->scalar_type()),
+    nullptr,
+    "description",
+    CTYPE,
+    [&] {
+      auto const *pointer = reinterpret_cast<CTYPE *>(_tensor->unsafeGetTensorImpl()->data());
+      auto const countToPrint = std::min(count, (ssize_t)100);
+      for (size_t index = 0; index < countToPrint; ++index) {
+        if (index > 0) {
+          os << ",";
+        }
+        if constexpr (std::is_same_v<CTYPE, uint8_t> ||
+                      std::is_same_v<CTYPE, int8_t>) {
+          os << static_cast<int32_t>(pointer[index]);
+        } else {
+          os << pointer[index];
+        }
+      }
+      if (count > countToPrint) {
+        os << ",...";
+      }
+    }
+  );
+  os << "]";
+  os << "\n}";
+  return @(os.str().c_str());
+}
+
 @end
 
 @implementation ExecuTorchTensor (BytesNoCopy)
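
Usage note: a minimal sketch of how the new Swift conformance is exercised. The Tensor<Float> initializer shown is an assumption for illustration only; what this diff adds is the description property forwarding to anyTensor.description (backed by the new Objective-C -description), and the @available attribute marks the conformance as experimental.

    import ExecuTorch

    // Assumed initializer form, not introduced by this diff: build a small
    // 2x2 float tensor from scalars.
    let tensor = Tensor<Float>([1.0, 2.0, 3.0, 4.0], shape: [2, 2])

    // The new CustomStringConvertible conformance forwards to
    // anyTensor.description, so print() shows dataType, shape, strides,
    // dimensionOrder, shapeDynamism, count, and up to 100 scalars.
    print(tensor)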