Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions ydb/core/kqp/common/result_set_format/kqp_formats_arrow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <yql/essentials/minikql/mkql_type_helper.h>
#include <yql/essentials/minikql/mkql_type_ops.h>
#include <yql/essentials/parser/pg_wrapper/interface/codec.h>
#include <yql/essentials/public/udf/arrow/block_type_helper.h>
#include <yql/essentials/types/binary_json/read.h>
#include <yql/essentials/types/binary_json/write.h>
Expand Down Expand Up @@ -567,6 +568,21 @@ void AppendElement(NUdf::TUnboxedValue value, arrow::ArrayBuilder* builder, cons
}
}

// Appends one PostgreSQL-typed value to an Arrow string builder.
//
// The value is rendered to text with NYql::NCommon::PgValueToNativeText using the
// type id carried by pgType. An empty (falsy) value is appended as an Arrow null.
//
// @param value   value to append; an empty value produces a null slot
// @param builder Arrow builder; must report arrow::Type::STRING
// @param pgType  Pg type descriptor that supplies the type id for the text conversion
void AppendElement(NUdf::TUnboxedValue value, arrow::ArrayBuilder* builder, const NMiniKQL::TPgType* pgType) {
    YQL_ENSURE(builder->type()->id() == arrow::Type::STRING, "Unexpected builder type");
    // The type id has just been verified, so this is an ordinary base-to-derived
    // downcast: static_cast is the appropriate named cast, not reinterpret_cast.
    auto* stringBuilder = static_cast<arrow::StringBuilder*>(builder);

    if (!value) {
        auto status = stringBuilder->AppendNull();
        YQL_ENSURE(status.ok(), "Failed to append null pg value: " << status.ToString());
        return;
    }

    auto textValue = NYql::NCommon::PgValueToNativeText(value, pgType->GetTypeId());
    auto status = stringBuilder->Append(textValue.data(), textValue.size());
    YQL_ENSURE(status.ok(), "Failed to append pg value: " << status.ToString());
}

} // namespace

bool NeedWrapByExternalOptional(const NMiniKQL::TType* type) {
Expand All @@ -576,7 +592,8 @@ bool NeedWrapByExternalOptional(const NMiniKQL::TType* type) {
case NMiniKQL::TType::EKind::EmptyList:
case NMiniKQL::TType::EKind::EmptyDict:
case NMiniKQL::TType::EKind::Optional:
case NMiniKQL::TType::EKind::Variant: {
case NMiniKQL::TType::EKind::Variant:
case NMiniKQL::TType::EKind::Pg: {
return true;
}

Expand All @@ -597,7 +614,6 @@ bool NeedWrapByExternalOptional(const NMiniKQL::TType* type) {
case NMiniKQL::TType::EKind::Flow:
case NMiniKQL::TType::EKind::ReservedKind:
case NMiniKQL::TType::EKind::Block:
case NMiniKQL::TType::EKind::Pg:
case NMiniKQL::TType::EKind::Multi:
case NMiniKQL::TType::EKind::Linear: {
YQL_ENSURE(false, "Unsupported type: " << type->GetKindAsStr());
Expand Down Expand Up @@ -651,6 +667,10 @@ std::shared_ptr<arrow::DataType> GetArrowType(const NMiniKQL::TType* type) {
return GetArrowType(static_cast<const NMiniKQL::TTaggedType*>(type)->GetBaseType());
}

case NMiniKQL::TType::EKind::Pg: {
return arrow::utf8();
}

case NMiniKQL::TType::EKind::Type:
case NMiniKQL::TType::EKind::Stream:
case NMiniKQL::TType::EKind::Callable:
Expand All @@ -659,7 +679,6 @@ std::shared_ptr<arrow::DataType> GetArrowType(const NMiniKQL::TType* type) {
case NMiniKQL::TType::EKind::Flow:
case NMiniKQL::TType::EKind::ReservedKind:
case NMiniKQL::TType::EKind::Block:
case NMiniKQL::TType::EKind::Pg:
case NMiniKQL::TType::EKind::Multi:
case NMiniKQL::TType::EKind::Linear: {
YQL_ENSURE(false, "Unsupported type: " << type->GetKindAsStr());
Expand All @@ -674,7 +693,8 @@ bool IsArrowCompatible(const NKikimr::NMiniKQL::TType* type) {
case NMiniKQL::TType::EKind::Void:
case NMiniKQL::TType::EKind::EmptyList:
case NMiniKQL::TType::EKind::EmptyDict:
case NMiniKQL::TType::EKind::Data: {
case NMiniKQL::TType::EKind::Data:
case NMiniKQL::TType::EKind::Pg: {
return true;
}

Expand Down Expand Up @@ -739,7 +759,6 @@ bool IsArrowCompatible(const NKikimr::NMiniKQL::TType* type) {
case NMiniKQL::TType::EKind::Flow:
case NMiniKQL::TType::EKind::ReservedKind:
case NMiniKQL::TType::EKind::Block:
case NMiniKQL::TType::EKind::Pg:
case NMiniKQL::TType::EKind::Multi:
case NMiniKQL::TType::EKind::Linear: {
return false;
Expand Down Expand Up @@ -807,6 +826,11 @@ void AppendElement(NUdf::TUnboxedValue value, arrow::ArrayBuilder* builder, cons
break;
}

case NMiniKQL::TType::EKind::Pg: {
AppendElement(value, builder, static_cast<const NMiniKQL::TPgType*>(type));
break;
}

case NMiniKQL::TType::EKind::Type:
case NMiniKQL::TType::EKind::Stream:
case NMiniKQL::TType::EKind::Callable:
Expand All @@ -815,7 +839,6 @@ void AppendElement(NUdf::TUnboxedValue value, arrow::ArrayBuilder* builder, cons
case NMiniKQL::TType::EKind::Flow:
case NMiniKQL::TType::EKind::ReservedKind:
case NMiniKQL::TType::EKind::Block:
case NMiniKQL::TType::EKind::Pg:
case NMiniKQL::TType::EKind::Multi:
case NMiniKQL::TType::EKind::Linear: {
YQL_ENSURE(false, "Unsupported type: " << type->GetKindAsStr());
Expand Down
5 changes: 3 additions & 2 deletions ydb/core/kqp/common/result_set_format/kqp_formats_arrow.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ bool SwitchMiniKQLDataTypeToArrowType(NUdf::EDataSlot typeId, TFunc&& callback)
* @param type The MiniKQL type to check
* @return true if the type needs external Optional wrapping, false otherwise
*
* @note Types that need wrapping: Void, Null, Variant, Optional, EmptyList, EmptyDict
* @note Types that need wrapping: Void, Null, Variant, Optional, EmptyList, EmptyDict, Pg
*/
bool NeedWrapByExternalOptional(const NMiniKQL::TType* type);

Expand All @@ -136,6 +136,7 @@ bool NeedWrapByExternalOptional(const NMiniKQL::TType* type);
* - Variant: converted to arrow::DenseUnionType
* - Optional: nested optionals are flattened and represented via struct wrapping
* - Tagged: converted to inner type
* - Pg: converted to arrow::StringType
*
* @param type The MiniKQL type to convert
* @return Shared pointer to corresponding Arrow DataType, or arrow::NullType if unsupported
Expand All @@ -152,7 +153,7 @@ std::shared_ptr<arrow::DataType> GetArrowType(const NMiniKQL::TType* type);
* @param type The MiniKQL type to validate
* @return true if the type can be converted to Arrow format, false otherwise
*
* @note Incompatible types: Type, Stream, Callable, Any, Resource, Flow, Block, Pg, Multi, Linear
* @note Incompatible types: Type, Stream, Callable, Any, Resource, Flow, Block, Multi, Linear
*/
bool IsArrowCompatible(const NMiniKQL::TType* type);

Expand Down
133 changes: 133 additions & 0 deletions ydb/core/kqp/common/result_set_format/ut/kqp_formats_arrow_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

#include <yql/essentials/minikql/computation/mkql_value_builder.h>
#include <yql/essentials/minikql/mkql_string_util.h>
#include <yql/essentials/parser/pg_wrapper/interface/codec.h>
#include <yql/essentials/parser/pg_wrapper/interface/compare.h>
#include <yql/essentials/parser/pg_wrapper/postgresql/src/backend/catalog/pg_type_d.h>
#include <yql/essentials/public/udf/arrow/defs.h>
#include <yql/essentials/types/binary_json/read.h>
#include <yql/essentials/types/binary_json/write.h>
Expand All @@ -22,6 +25,7 @@ using namespace NYql;

inline static constexpr size_t TEST_ARRAY_DATATYPE_SIZE = 1 << 16;
inline static constexpr size_t TEST_ARRAY_NESTED_SIZE = 1 << 8;
inline static constexpr size_t TEST_ARRAY_PG_SIZE = TEST_ARRAY_DATATYPE_SIZE;
inline static constexpr ui8 DECIMAL_PRECISION = 35;
inline static constexpr ui8 DECIMAL_SCALE = 10;
inline static constexpr ui32 VARIANT_NESTED_SIZE = 260;
Expand Down Expand Up @@ -800,6 +804,18 @@ struct TTestContext {
return values;
}

// Returns the type Optional<Pg(pgTypeId)>, created in this context's TypeEnv.
TType* GetOptionalPgValueType(ui32 pgTypeId) {
    auto* const innerPgType = GetPgType(pgTypeId);
    return TOptionalType::Create(innerPgType, TypeEnv);
}

// Produces `quantity` values for an Optional<Pg> column: even positions hold the
// generated Pg value wrapped with MakeOptional(), odd positions hold an empty value.
//
// NOTE(review): CreatePgValues emits its NULL entries at positions where i % 4 == 3,
// which are all odd and therefore overwritten here, so a present optional never wraps
// a Pg NULL in this data set.
TUnboxedValueVector CreateOptionalsPgValue(ui32 quantity, ui32 pgTypeId) {
    auto optionals = CreatePgValues(quantity, pgTypeId);
    for (size_t index = 0; index < optionals.size(); ++index) {
        const bool keepValue = (index % 2 == 0);
        if (keepValue) {
            optionals[index] = optionals[index].MakeOptional();
        } else {
            optionals[index] = NUdf::TUnboxedValuePod();
        }
    }
    return optionals;
}

TType* GetOptionalOptionalValueType() {
return TOptionalType::Create(GetOptionalDataValueType(), TypeEnv);
}
Expand Down Expand Up @@ -1151,6 +1167,37 @@ struct TTestContext {
}
return values;
}

// Creates the Pg type with the given type id in this context's TypeEnv.
TType* GetPgType(ui32 typeId) {
    auto* const pgType = TPgType::Create(typeId, TypeEnv);
    return pgType;
}

// Generates `quantity` Pg values of the given type id for conversion tests.
//
// Every position with index % 4 == 3 is an empty value (Pg NULL). All other positions
// are parsed from a textual form derived from the index via
// NYql::NCommon::PgValueFromNativeText:
//   BOOLOID -> "1" for even indices, "0" for odd ones
//   INT8OID -> the decimal index
//   TEXTOID -> "text" followed by the decimal index
// Any other type id fails the test with a message asking for a new case.
TUnboxedValueVector CreatePgValues(ui32 quantity, ui32 typeId) {
    TUnboxedValueVector values;
    // The final size is known up front, so avoid repeated reallocation while filling.
    values.reserve(quantity);

    for (ui64 value = 0; value < quantity; ++value) {
        if (value % 4 == 3) {
            values.emplace_back(NUdf::TUnboxedValuePod());
            continue;
        }

        std::string stringValue;
        switch (typeId) {
            case BOOLOID:
                // Spelled out explicitly instead of relying on bool -> int promotion
                // inside std::to_string; the produced text is the same.
                stringValue = (value % 2 == 0) ? "1" : "0";
                break;
            case INT8OID:
                stringValue = std::to_string(value);
                break;
            case TEXTOID:
                stringValue = "text" + std::to_string(value);
                break;
            default:
                UNIT_ASSERT_C(false, "You need to add a new case for type " << typeId);
        }
        values.emplace_back(NYql::NCommon::PgValueFromNativeText(stringValue, typeId));
    }
    return values;
}
};

void AssertUnboxedValuesAreEqual(NUdf::TUnboxedValue& left, NUdf::TUnboxedValue& right, TType* type) {
Expand Down Expand Up @@ -1298,6 +1345,13 @@ void AssertUnboxedValuesAreEqual(NUdf::TUnboxedValue& left, NUdf::TUnboxedValue&
break;
}

case TType::EKind::Pg: {
auto pgType = static_cast<const TPgType*>(type);
auto equate = MakePgEquate(pgType);
UNIT_ASSERT(equate->Equals(left, right));
break;
}

default: {
UNIT_ASSERT_C(false, TStringBuilder() << "Unsupported type: " << type->GetKindAsStr());
}
Expand Down Expand Up @@ -1420,6 +1474,48 @@ void TestSingularTypeConversion() {
}
}

template <ui32 PgTypeId>
void TestPgTypeConversion() {
TTestContext context;

auto pgType = context.GetPgType(PgTypeId);
auto values = context.CreatePgValues(TEST_ARRAY_PG_SIZE, PgTypeId);

UNIT_ASSERT(IsArrowCompatible(pgType));

auto array = MakeArrowArray(values, pgType);
UNIT_ASSERT_C(array->ValidateFull().ok(), array->ValidateFull().ToString());
UNIT_ASSERT_VALUES_EQUAL(array->length(), values.size());

UNIT_ASSERT(array->type_id() == arrow::Type::STRING);
auto stringArray = static_pointer_cast<arrow::StringArray>(array);
UNIT_ASSERT_VALUES_EQUAL(stringArray->length(), values.size());

if (stringArray->length() > 1) {
switch (PgTypeId) {
case BOOLOID:
UNIT_ASSERT_VALUES_EQUAL(stringArray->GetString(0), "t");
UNIT_ASSERT_VALUES_EQUAL(stringArray->GetString(1), "f");
break;
case INT8OID:
UNIT_ASSERT_VALUES_EQUAL(stringArray->GetString(0), "0");
UNIT_ASSERT_VALUES_EQUAL(stringArray->GetString(1), "1");
break;
case TEXTOID:
UNIT_ASSERT_VALUES_EQUAL(stringArray->GetString(0), "text0");
UNIT_ASSERT_VALUES_EQUAL(stringArray->GetString(1), "text1");
break;
default:
UNIT_ASSERT_C(false, "You need to add a new case for type " << PgTypeId);
}
}

for (size_t i = 0; i < values.size(); ++i) {
auto arrowValue = ExtractUnboxedValue(array, i, pgType, context.HolderFactory);
AssertUnboxedValuesAreEqual(arrowValue, values[i], pgType);
}
}

} // namespace

Y_UNIT_TEST_SUITE(KqpFormats_Arrow_Conversion) {
Expand Down Expand Up @@ -2332,6 +2428,29 @@ Y_UNIT_TEST_SUITE(KqpFormats_Arrow_Conversion) {
}
}

// Optional<Pg int8> must convert to a single-field STRUCT whose field is a STRING array,
// and every row must survive the round trip back to an unboxed value.
Y_UNIT_TEST(NestedType_Optional_PgValue) {
    TTestContext context;

    auto* optionalPgType = context.GetOptionalPgValueType(INT8OID);
    auto sourceValues = context.CreateOptionalsPgValue(TEST_ARRAY_NESTED_SIZE, INT8OID);

    UNIT_ASSERT(IsArrowCompatible(optionalPgType));

    auto array = MakeArrowArray(sourceValues, optionalPgType);
    const auto validation = array->ValidateFull();
    UNIT_ASSERT_C(validation.ok(), validation.ToString());
    UNIT_ASSERT_VALUES_EQUAL(array->length(), sourceValues.size());
    UNIT_ASSERT(array->type_id() == arrow::Type::STRUCT);

    auto wrapper = static_pointer_cast<arrow::StructArray>(array);
    UNIT_ASSERT_VALUES_EQUAL(wrapper->num_fields(), 1);
    UNIT_ASSERT(wrapper->field(0)->type_id() == arrow::Type::STRING);

    for (size_t row = 0; row < sourceValues.size(); ++row) {
        auto roundTripped = ExtractUnboxedValue(array, row, optionalPgType, context.HolderFactory);
        AssertUnboxedValuesAreEqual(roundTripped, sourceValues[row], optionalPgType);
    }
}

Y_UNIT_TEST(NestedType_Optional_OptionalValue) {
TTestContext context;

Expand Down Expand Up @@ -2788,6 +2907,20 @@ Y_UNIT_TEST_SUITE(KqpFormats_Arrow_Conversion) {
AssertUnboxedValuesAreEqual(arrowValue, values[i], taggedType);
}
}

// Pg types
// Every Pg value is converted to Arrow through NYql::NCommon::PgValueToNativeText,
// independent of its concrete type, so a few representative type ids (bool, int8, text)
// are exercised here instead of covering every Pg type.

// Round trip for the Pg bool type.
Y_UNIT_TEST(PgType_Bool) {
TestPgTypeConversion<BOOLOID>();
}

// Round trip for the Pg int8 type.
Y_UNIT_TEST(PgType_Int8) {
TestPgTypeConversion<INT8OID>();
}

// Round trip for the Pg text type.
Y_UNIT_TEST(PgType_Text) {
TestPgTypeConversion<TEXTOID>();
}
}

} // namespace NKikimr::NKqp::NFormats
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <yql/essentials/minikql/mkql_string_util.h>
#include <yql/essentials/minikql/mkql_type_helper.h>
#include <yql/essentials/parser/pg_wrapper/interface/codec.h>
#include <yql/essentials/types/dynumber/dynumber.h>
#include <yql/essentials/types/binary_json/write.h>
#include <yql/essentials/utils/yql_panic.h>
Expand Down Expand Up @@ -334,6 +335,18 @@ NUdf::TUnboxedValue ExtractUnboxedValue(const std::shared_ptr<arrow::Array>& arr
return holderFactory.CreateVariantHolder(value.Release(), variantIndex);
}

// Reads row `row` of an Arrow string array back into a PostgreSQL-typed value.
//
// A null slot yields an empty value. Otherwise the stored text is parsed with
// NYql::NCommon::PgValueFromNativeText using the type id carried by pgType.
//
// @param array  Arrow array; must report arrow::Type::STRING
// @param row    index of the element to read
// @param pgType Pg type descriptor that supplies the type id for parsing
// @return the parsed value, or an empty value for a null slot
NUdf::TUnboxedValue ExtractUnboxedValue(const std::shared_ptr<arrow::Array>& array, ui64 row, const NMiniKQL::TPgType* pgType) {
    YQL_ENSURE(array->type_id() == arrow::Type::STRING, "Unexpected array type");
    // Downcast by reference: the caller's shared_ptr keeps the array alive for the
    // duration of the call, so copying the shared_ptr (an atomic refcount round trip
    // on every row) is unnecessary.
    const auto& stringArray = static_cast<const arrow::StringArray&>(*array);

    if (stringArray.IsNull(row)) {
        return NUdf::TUnboxedValuePod();
    }

    const auto data = stringArray.GetView(row);
    return NYql::NCommon::PgValueFromNativeText(NUdf::TStringRef(data.data(), data.size()), pgType->GetTypeId());
}

} // namespace

std::unique_ptr<arrow::ArrayBuilder> MakeArrowBuilder(const NMiniKQL::TType* type) {
Expand Down Expand Up @@ -404,6 +417,10 @@ NUdf::TUnboxedValue ExtractUnboxedValue(const std::shared_ptr<arrow::Array>& arr
return ExtractUnboxedValue(array, row, static_cast<const NMiniKQL::TTaggedType*>(itemType)->GetBaseType(), holderFactory);
}

case NMiniKQL::TType::EKind::Pg: {
return ExtractUnboxedValue(array, row, static_cast<const NMiniKQL::TPgType*>(itemType));
}

case NMiniKQL::TType::EKind::Type:
case NMiniKQL::TType::EKind::Stream:
case NMiniKQL::TType::EKind::Callable:
Expand All @@ -412,7 +429,6 @@ NUdf::TUnboxedValue ExtractUnboxedValue(const std::shared_ptr<arrow::Array>& arr
case NMiniKQL::TType::EKind::Flow:
case NMiniKQL::TType::EKind::ReservedKind:
case NMiniKQL::TType::EKind::Block:
case NMiniKQL::TType::EKind::Pg:
case NMiniKQL::TType::EKind::Multi:
case NMiniKQL::TType::EKind::Linear: {
YQL_ENSURE(false, "Unsupported type: " << itemType->GetKindAsStr());
Expand Down
2 changes: 1 addition & 1 deletion ydb/core/kqp/common/result_set_format/ut/ya.make
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ YQL_LAST_ABI_VERSION()
PEERDIR(
library/cpp/testing/unittest
yql/essentials/public/udf/service/exception_policy
yql/essentials/sql/pg_dummy
yql/essentials/parser/pg_wrapper
)

END()
3 changes: 2 additions & 1 deletion ydb/core/kqp/ut/arrow/ya.make
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ SRCS(
PEERDIR(
ydb/core/kqp
ydb/core/kqp/ut/common
yql/essentials/sql/pg_dummy
ydb/public/sdk/cpp/src/client/arrow
yql/essentials/sql/pg
yql/essentials/parser/pg_wrapper
)

YQL_LAST_ABI_VERSION()
Expand Down
Loading