diff --git a/ydb/apps/ydb/CHANGELOG.md b/ydb/apps/ydb/CHANGELOG.md index 4ada8812bd88..ede436cb2557 100644 --- a/ydb/apps/ydb/CHANGELOG.md +++ b/ydb/apps/ydb/CHANGELOG.md @@ -1,4 +1,5 @@ * Added a simple progress bar for non-interactive stderr. +* The `ydb workload vector` now supports `import files` to populate table from CSV and parquet ## 2.27.0 ## diff --git a/ydb/library/workload/vector/vector_data_generator.cpp b/ydb/library/workload/vector/vector_data_generator.cpp new file mode 100644 index 000000000000..37ec108b03fa --- /dev/null +++ b/ydb/library/workload/vector/vector_data_generator.cpp @@ -0,0 +1,267 @@ +#include "vector_data_generator.h" + +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace NYdbWorkload { + +namespace { + +class TTransformingDataGenerator final: public IBulkDataGenerator { +private: + std::shared_ptr InnerDataGenerator; + const TString EmbeddingSourceField; + +private: + static std::pair, std::shared_ptr> Deserialize(TDataPortion::TArrow* data) { + arrow::ipc::DictionaryMemo dictionary; + + arrow::io::BufferReader schemaBuffer(arrow::util::string_view(data->Schema.data(), data->Schema.size())); + const std::shared_ptr schema = arrow::ipc::ReadSchema(&schemaBuffer, &dictionary).ValueOrDie(); + + arrow::io::BufferReader recordBatchBuffer(arrow::util::string_view(data->Data.data(), data->Data.size())); + const std::shared_ptr recordBatch = arrow::ipc::ReadRecordBatch(schema, &dictionary, {}, &recordBatchBuffer).ValueOrDie(); + + return std::make_pair(schema, recordBatch); + } + + std::shared_ptr Deserialize(TDataPortion::TCsv* data) { + Ydb::Formats::CsvSettings csvSettings; + if (Y_UNLIKELY(!csvSettings.ParseFromString(data->FormatString))) { + ythrow yexception() << "Unable to parse CsvSettings"; + } + + arrow::csv::ReadOptions readOptions = arrow::csv::ReadOptions::Defaults(); + readOptions.skip_rows = csvSettings.skip_rows(); + if (data->Data.size() > NKikimr::NFormats::TArrowCSV::DEFAULT_BLOCK_SIZE) { + ui32 blockSize = NKikimr::NFormats::TArrowCSV::DEFAULT_BLOCK_SIZE; + blockSize *= data->Data.size() / blockSize + 1; + readOptions.block_size = blockSize; + } + + arrow::csv::ParseOptions parseOptions = arrow::csv::ParseOptions::Defaults(); + const auto& quoting = csvSettings.quoting(); + if (Y_UNLIKELY(quoting.quote_char().length() > 1)) { + ythrow yexception() << "Cannot read CSV: Wrong quote char '" << quoting.quote_char() << "'"; + } + const char qchar = quoting.quote_char().empty() ? '"' : quoting.quote_char().front(); + parseOptions.quoting = false; + parseOptions.quote_char = qchar; + parseOptions.double_quote = !quoting.double_quote_disabled(); + if (csvSettings.delimiter()) { + if (Y_UNLIKELY(csvSettings.delimiter().size() != 1)) { + ythrow yexception() << "Cannot read CSV: Invalid delimitr in csv: " << csvSettings.delimiter(); + } + parseOptions.delimiter = csvSettings.delimiter().front(); + } + + arrow::csv::ConvertOptions convertOptions = arrow::csv::ConvertOptions::Defaults(); + if (csvSettings.null_value()) { + convertOptions.null_values = { std::string(csvSettings.null_value().data(), csvSettings.null_value().size()) }; + convertOptions.strings_can_be_null = true; + convertOptions.quoted_strings_can_be_null = false; + } + + auto bufferReader = std::make_shared(arrow::util::string_view(data->Data.data(), data->Data.size())); + auto csvReader = arrow::csv::TableReader::Make( + arrow::io::default_io_context(), + bufferReader, + readOptions, + parseOptions, + convertOptions + ).ValueOrDie(); + + return csvReader->Read().ValueOrDie(); + } + + void TransformArrow(TDataPortion::TArrow* data) { + const auto [schema, batch] = Deserialize(data); + + // id + const auto idColumn = batch->GetColumnByName("id"); + const auto newIdColumn = arrow::compute::Cast(idColumn, arrow::uint64()).ValueOrDie().make_array(); + + // embedding + const auto embeddingColumn = std::dynamic_pointer_cast(batch->GetColumnByName(EmbeddingSourceField)); + arrow::StringBuilder newEmbeddingsBuilder; + for (int64_t row = 0; row < batch->num_rows(); ++row) { + const auto embeddingFloatList = std::static_pointer_cast(embeddingColumn->value_slice(row)); + + TStringBuilder buffer; + NKnnVectorSerialization::TSerializer serializer(&buffer.Out); + for (int64_t i = 0; i < embeddingFloatList->length(); ++i) { + serializer.HandleElement(embeddingFloatList->Value(i)); + } + serializer.Finish(); + + if (const auto status = newEmbeddingsBuilder.Append(buffer.MutRef()); !status.ok()) { + status.Abort(); + } + } + std::shared_ptr newEmbeddingColumn; + if (const auto status = newEmbeddingsBuilder.Finish(&newEmbeddingColumn); !status.ok()) { + status.Abort(); + } + + const auto newSchema = arrow::schema({ + arrow::field("id", arrow::uint64()), + arrow::field("embedding", arrow::utf8()), + }); + const auto newRecordBatch = arrow::RecordBatch::Make( + newSchema, + batch->num_rows(), + { + newIdColumn, + newEmbeddingColumn, + } + ); + data->Schema = arrow::ipc::SerializeSchema(*newSchema).ValueOrDie()->ToString(); + data->Data = arrow::ipc::SerializeRecordBatch(*newRecordBatch, arrow::ipc::IpcWriteOptions{}).ValueOrDie()->ToString(); + } + + void TransformCsv(TDataPortion::TCsv* data) { + const auto table = Deserialize(data); + + // id + const auto idColumn = table->GetColumnByName("id"); + + // embedding + const auto embeddingColumn = table->GetColumnByName(EmbeddingSourceField); + arrow::StringBuilder newEmbeddingsBuilder; + for (int64_t row = 0; row < table->num_rows(); ++row) { + const auto embeddingListString = std::static_pointer_cast(embeddingColumn->Slice(row, 1)->chunk(0))->Value(0); + + TStringBuf buffer(embeddingListString.data(), embeddingListString.size()); + buffer.SkipPrefix("["); + buffer.ChopSuffix("]"); + TMemoryInput input(buffer); + + TStringBuilder newEmbeddingBuilder; + NKnnVectorSerialization::TSerializer serializer(&newEmbeddingBuilder.Out); + while (!input.Exhausted()) { + float val; + input >> val; + input.Skip(1); + serializer.HandleElement(val); + } + serializer.Finish(); + + if (const auto status = newEmbeddingsBuilder.Append(newEmbeddingBuilder.MutRef()); !status.ok()) { + status.Abort(); + } + } + std::shared_ptr newEmbeddingColumn; + if (const auto status = newEmbeddingsBuilder.Finish(&newEmbeddingColumn); !status.ok()) { + status.Abort(); + } + + const auto newSchema = arrow::schema({ + arrow::field("id", arrow::uint64()), + arrow::field("embedding", arrow::utf8()), + }); + const auto newTable = arrow::Table::Make( + newSchema, + { + idColumn, + arrow::ChunkedArray::Make({newEmbeddingColumn}).ValueOrDie(), + } + ); + auto outputStream = arrow::io::BufferOutputStream::Create().ValueOrDie(); + if (const auto status = arrow::csv::WriteCSV(*newTable, arrow::csv::WriteOptions::Defaults(), outputStream.get()); !status.ok()) { + status.Abort(); + } + data->FormatString = ""; + data->Data = outputStream->Finish().ValueOrDie()->ToString(); + } + + void Transform(TDataPortion::TDataType& data) { + if (auto* value = std::get_if(&data)) { + TransformArrow(value); + } + if (auto* value = std::get_if(&data)) { + TransformCsv(value); + } + } + +public: + TTransformingDataGenerator(std::shared_ptr innerDataGenerator, const TString embeddingSourceField) + : IBulkDataGenerator(innerDataGenerator->GetName(), innerDataGenerator->GetSize()) + , InnerDataGenerator(innerDataGenerator) + , EmbeddingSourceField(embeddingSourceField) + {} + + virtual TDataPortions GenerateDataPortion() override { + TDataPortions portions = InnerDataGenerator->GenerateDataPortion(); + for (auto portion : portions) { + Transform(portion->MutableData()); + } + return portions; + } +}; + +} + +TWorkloadVectorFilesDataInitializer::TWorkloadVectorFilesDataInitializer(const TVectorWorkloadParams& params) + : TWorkloadDataInitializerBase("files", "Import vectors from files", params) + , Params(params) +{ } + +void TWorkloadVectorFilesDataInitializer::ConfigureOpts(NLastGetopt::TOpts& opts) { + opts.AddLongOption('i', "input", + "File or Directory with dataset. If directory is set, all its available files will be used. " + "Supports zipped and unzipped csv, tsv files and parquet ones that may be downloaded here: " + "https://huggingface.co/datasets/Cohere/wikipedia-22-12-simple-embeddings. " + "For better performance you may split it into some parts for parallel upload." + ).Required().StoreResult(&DataFiles); + opts.AddLongOption('t', "transform", + "Perform transformation of input data. " + "Parquet: leave only required fields, cast to expected types, convert list of floats into serialized representation. " + "CSV: leave only required fields, parse float list from string and serialize. " + "Reference for embedding serialization: https://ydb.tech/docs/yql/reference/udf/list/knn#functions-convert" + ).Optional().StoreTrue(&DoTransform); + opts.AddLongOption( + "transform-embedding-source-field", + "Specify field that contains list of floats to be converted into YDB embedding format." + ).DefaultValue(EmbeddingSourceField).StoreResult(&EmbeddingSourceField); +} + +TBulkDataGeneratorList TWorkloadVectorFilesDataInitializer::DoGetBulkInitialData() { + auto dataGenerator = std::make_shared( + *this, + Params.TableName, + 0, + Params.TableName, + DataFiles, + Params.GetColumns(), + TDataGenerator::EPortionSizeUnit::Line + ); + + if (DoTransform) { + return {std::make_shared(dataGenerator, EmbeddingSourceField)}; + } + return {dataGenerator}; +} + +} // namespace NYdbWorkload diff --git a/ydb/library/workload/vector/vector_data_generator.h b/ydb/library/workload/vector/vector_data_generator.h new file mode 100644 index 000000000000..3f66feb891c3 --- /dev/null +++ b/ydb/library/workload/vector/vector_data_generator.h @@ -0,0 +1,24 @@ +#pragma once + +#include "vector_workload_params.h" + +#include +#include + +namespace NYdbWorkload { + +class TWorkloadVectorFilesDataInitializer : public TWorkloadDataInitializerBase { +private: + const TVectorWorkloadParams& Params; + TString DataFiles; + bool DoTransform = false; + TString EmbeddingSourceField = "embedding"; + +public: + TWorkloadVectorFilesDataInitializer(const TVectorWorkloadParams& params); + + virtual void ConfigureOpts(NLastGetopt::TOpts& opts) override; + virtual TBulkDataGeneratorList DoGetBulkInitialData() override; +}; + +} // namespace NYdbWorkload diff --git a/ydb/library/workload/vector/vector_workload_params.cpp b/ydb/library/workload/vector/vector_workload_params.cpp index cba0776bc93e..82cef18fa903 100644 --- a/ydb/library/workload/vector/vector_workload_params.cpp +++ b/ydb/library/workload/vector/vector_workload_params.cpp @@ -1,3 +1,4 @@ +#include "vector_data_generator.h" #include "vector_enums.h" #include "vector_workload_params.h" #include "vector_workload_generator.h" @@ -55,6 +56,9 @@ void TVectorWorkloadParams::ConfigureOpts(NLastGetopt::TOpts& opts, const EComma ConfigureCommonOpts(opts); addInitParam(); break; + case TWorkloadParams::ECommandType::Import: + ConfigureCommonOpts(opts); + break; case TWorkloadParams::ECommandType::Run: ConfigureCommonOpts(opts); switch (static_cast(workloadType)) { @@ -91,6 +95,15 @@ void TVectorWorkloadParams::ConfigureIndexOpts(NLastGetopt::TOpts& opts) { .Required().StoreResult(&KmeansTreeClusters); } +TVector TVectorWorkloadParams::GetColumns() const { + TVector result(KeyColumns.begin(), KeyColumns.end()); + result.emplace_back(EmbeddingColumn); + if (PrefixColumn.has_value()) { + result.emplace_back(PrefixColumn.value()); + } + return result; +} + void TVectorWorkloadParams::Init() { const TString tablePath = GetFullTableName(TableName.c_str()); @@ -193,6 +206,12 @@ THolder TVectorWorkloadParams::CreateGenerator() const return MakeHolder(this); } +TWorkloadDataInitializer::TList TVectorWorkloadParams::CreateDataInitializers() const { + return { + std::make_shared(*this) + }; +} + TString TVectorWorkloadParams::GetWorkloadName() const { return "vector"; } diff --git a/ydb/library/workload/vector/vector_workload_params.h b/ydb/library/workload/vector/vector_workload_params.h index 958609418263..34961704787c 100644 --- a/ydb/library/workload/vector/vector_workload_params.h +++ b/ydb/library/workload/vector/vector_workload_params.h @@ -18,6 +18,7 @@ class TVectorWorkloadParams final: public TWorkloadBaseParams { public: void ConfigureOpts(NLastGetopt::TOpts& opts, const ECommandType commandType, int workloadType) override; THolder CreateGenerator() const override; + TWorkloadDataInitializer::TList CreateDataInitializers() const override; TString GetWorkloadName() const override; void Validate(const ECommandType commandType, int workloadType) override; @@ -26,6 +27,8 @@ class TVectorWorkloadParams final: public TWorkloadBaseParams { void ConfigureCommonOpts(NLastGetopt::TOpts& opts); void ConfigureIndexOpts(NLastGetopt::TOpts& opts); + TVector GetColumns() const; + TString TableName; TString QueryTableName; TString IndexName; diff --git a/ydb/library/workload/vector/ya.make b/ydb/library/workload/vector/ya.make index 5f16bc5c166e..20c79fb68cce 100644 --- a/ydb/library/workload/vector/ya.make +++ b/ydb/library/workload/vector/ya.make @@ -2,6 +2,7 @@ LIBRARY() SRCS( vector_command_index.cpp + vector_data_generator.cpp vector_recall_evaluator.cpp vector_sampler.cpp vector_sql.cpp @@ -11,7 +12,9 @@ SRCS( ) PEERDIR( + contrib/libs/apache/arrow ydb/library/workload/abstract + ydb/public/api/protos ) GENERATE_ENUM_SERIALIZATION_WITH_HEADER(vector_enums.h)