[feature-wip](arrow-flight)(step4) Support other DML and DDL statements, besides `Select` (apache#25919)

Design documentation is linked to apache#25514.
xinyiZzz authored and seawinde committed Nov 12, 2023
1 parent 5efe10d commit 41931f7
Showing 30 changed files with 1,137 additions and 600 deletions.
26 changes: 13 additions & 13 deletions be/src/runtime/result_buffer_mgr.cpp
@@ -109,21 +109,21 @@ std::shared_ptr<BufferControlBlock> ResultBufferMgr::find_control_block(const TU
     return std::shared_ptr<BufferControlBlock>();
 }
 
-void ResultBufferMgr::register_row_descriptor(const TUniqueId& query_id,
-                                              const RowDescriptor& row_desc) {
-    std::unique_lock<std::shared_mutex> wlock(_row_descriptor_map_lock);
-    _row_descriptor_map.insert(std::make_pair(query_id, row_desc));
+void ResultBufferMgr::register_arrow_schema(const TUniqueId& query_id,
+                                            const std::shared_ptr<arrow::Schema>& arrow_schema) {
+    std::unique_lock<std::shared_mutex> wlock(_arrow_schema_map_lock);
+    _arrow_schema_map.insert(std::make_pair(query_id, arrow_schema));
 }
 
-RowDescriptor ResultBufferMgr::find_row_descriptor(const TUniqueId& query_id) {
-    std::shared_lock<std::shared_mutex> rlock(_row_descriptor_map_lock);
-    RowDescriptorMap::iterator iter = _row_descriptor_map.find(query_id);
+std::shared_ptr<arrow::Schema> ResultBufferMgr::find_arrow_schema(const TUniqueId& query_id) {
+    std::shared_lock<std::shared_mutex> rlock(_arrow_schema_map_lock);
+    auto iter = _arrow_schema_map.find(query_id);
 
-    if (_row_descriptor_map.end() != iter) {
+    if (_arrow_schema_map.end() != iter) {
         return iter->second;
     }
 
-    return RowDescriptor();
+    return nullptr;
 }
 
 void ResultBufferMgr::fetch_data(const PUniqueId& finst_id, GetResultBatchCtx* ctx) {
@@ -162,11 +162,11 @@ Status ResultBufferMgr::cancel(const TUniqueId& query_id) {
     }
 
     {
-        std::unique_lock<std::shared_mutex> wlock(_row_descriptor_map_lock);
-        RowDescriptorMap::iterator row_desc_iter = _row_descriptor_map.find(query_id);
+        std::unique_lock<std::shared_mutex> wlock(_arrow_schema_map_lock);
+        auto arrow_schema_iter = _arrow_schema_map.find(query_id);
 
-        if (_row_descriptor_map.end() != row_desc_iter) {
-            _row_descriptor_map.erase(row_desc_iter);
+        if (_arrow_schema_map.end() != arrow_schema_iter) {
+            _arrow_schema_map.erase(arrow_schema_iter);
         }
     }
 
15 changes: 8 additions & 7 deletions be/src/runtime/result_buffer_mgr.h
@@ -29,12 +29,12 @@
 
 #include "common/status.h"
 #include "gutil/ref_counted.h"
-#include "runtime/descriptors.h"
 #include "util/countdown_latch.h"
 #include "util/hash_util.hpp"
 
 namespace arrow {
 class RecordBatch;
+class Schema;
 } // namespace arrow
 
 namespace doris {
@@ -66,8 +66,9 @@ class ResultBufferMgr {
     // fetch data result to Arrow Flight Server
     Status fetch_arrow_data(const TUniqueId& finst_id, std::shared_ptr<arrow::RecordBatch>* result);
 
-    void register_row_descriptor(const TUniqueId& query_id, const RowDescriptor& row_desc);
-    RowDescriptor find_row_descriptor(const TUniqueId& query_id);
+    void register_arrow_schema(const TUniqueId& query_id,
+                               const std::shared_ptr<arrow::Schema>& arrow_schema);
+    std::shared_ptr<arrow::Schema> find_arrow_schema(const TUniqueId& query_id);
 
     // cancel
     Status cancel(const TUniqueId& fragment_id);
@@ -78,7 +79,7 @@
 private:
     using BufferMap = std::unordered_map<TUniqueId, std::shared_ptr<BufferControlBlock>>;
     using TimeoutMap = std::map<time_t, std::vector<TUniqueId>>;
-    using RowDescriptorMap = std::unordered_map<TUniqueId, RowDescriptor>;
+    using ArrowSchemaMap = std::unordered_map<TUniqueId, std::shared_ptr<arrow::Schema>>;
 
     std::shared_ptr<BufferControlBlock> find_control_block(const TUniqueId& query_id);
 
@@ -90,10 +91,10 @@
     std::shared_mutex _buffer_map_lock;
     // buffer block map
    BufferMap _buffer_map;
-    // lock for descriptor map
-    std::shared_mutex _row_descriptor_map_lock;
+    // lock for arrow schema map
+    std::shared_mutex _arrow_schema_map_lock;
     // for arrow flight
-    RowDescriptorMap _row_descriptor_map;
+    ArrowSchemaMap _arrow_schema_map;
 
     // lock for timeout map
     std::mutex _timeout_lock;
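Taken together, the two files above replace the per-query RowDescriptor registry with a registry of Arrow schemas keyed by query id. Below is a self-contained sketch of that pattern — a shared_mutex-guarded map that the result sink writes once and the Arrow Flight side reads — using a plain string key and a placeholder SchemaRegistry class as stand-ins for Doris's TUniqueId and ResultBufferMgr (assumptions for illustration, not the actual types).

#include <arrow/api.h>

#include <iostream>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <unordered_map>

// Standalone stand-in for ResultBufferMgr's schema registry; the key type and
// class name are placeholders, only the locking/lookup pattern mirrors the diff.
class SchemaRegistry {
public:
    void register_schema(const std::string& query_id, std::shared_ptr<arrow::Schema> schema) {
        std::unique_lock<std::shared_mutex> wlock(_lock);  // exclusive: single writer
        _schemas.emplace(query_id, std::move(schema));
    }

    std::shared_ptr<arrow::Schema> find_schema(const std::string& query_id) {
        std::shared_lock<std::shared_mutex> rlock(_lock);  // shared: concurrent readers
        auto iter = _schemas.find(query_id);
        return iter == _schemas.end() ? nullptr : iter->second;
    }

    void erase_schema(const std::string& query_id) {
        std::unique_lock<std::shared_mutex> wlock(_lock);
        _schemas.erase(query_id);  // called on query cancel/cleanup
    }

private:
    std::shared_mutex _lock;
    std::unordered_map<std::string, std::shared_ptr<arrow::Schema>> _schemas;
};

int main() {
    SchemaRegistry registry;
    registry.register_schema("query-1", arrow::schema({arrow::field("id", arrow::int64())}));
    if (auto schema = registry.find_schema("query-1")) {
        std::cout << schema->ToString() << std::endl;
    }
    registry.erase_schema("query-1");
    std::cout << (registry.find_schema("query-1") == nullptr) << std::endl;  // 1: gone after cancel
    return 0;
}

As in the diff, the lookup returns nullptr when the query is unknown or already canceled, which is exactly the condition the Flight batch reader and the fetch_arrow_flight_schema RPC check for.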
16 changes: 5 additions & 11 deletions be/src/service/arrow_flight/arrow_flight_batch_reader.cpp
@@ -39,18 +39,12 @@ ArrowFlightBatchReader::ArrowFlightBatchReader(std::shared_ptr<QueryStatement> s
 arrow::Result<std::shared_ptr<ArrowFlightBatchReader>> ArrowFlightBatchReader::Create(
         const std::shared_ptr<QueryStatement>& statement_) {
     // Make sure that FE send the fragment to BE and creates the BufferControlBlock before returning ticket
-    // to the ADBC client, so that the row_descriptor and control block can be found.
-    RowDescriptor row_desc =
-            ExecEnv::GetInstance()->result_mgr()->find_row_descriptor(statement_->query_id);
-    if (row_desc.equals(RowDescriptor())) {
+    // to the ADBC client, so that the schema and control block can be found.
+    auto schema = ExecEnv::GetInstance()->result_mgr()->find_arrow_schema(statement_->query_id);
+    if (schema == nullptr) {
         ARROW_RETURN_NOT_OK(arrow::Status::Invalid(fmt::format(
-                "Schema RowDescriptor Not Found, queryid: {}", print_id(statement_->query_id))));
-    }
-    std::shared_ptr<arrow::Schema> schema;
-    auto st = convert_to_arrow_schema(row_desc, &schema);
-    if (UNLIKELY(!st.ok())) {
-        LOG(WARNING) << st.to_string();
-        ARROW_RETURN_NOT_OK(to_arrow_status(st));
+                "not found arrow flight schema, maybe query has been canceled, queryid: {}",
+                print_id(statement_->query_id))));
     }
     std::shared_ptr<ArrowFlightBatchReader> result(new ArrowFlightBatchReader(statement_, schema));
     return result;
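With the schema registered up front, Create() can fail fast at the Arrow level when the lookup misses (for example, after a cancel). A minimal sketch of that arrow::Result error path; find_schema_for is a hypothetical stand-in for ExecEnv::GetInstance()->result_mgr()->find_arrow_schema:

#include <arrow/api.h>
#include <arrow/result.h>

#include <iostream>
#include <memory>
#include <string>

// Hypothetical lookup standing in for the ResultBufferMgr schema registry.
std::shared_ptr<arrow::Schema> find_schema_for(const std::string& query_id) {
    return nullptr;  // simulate "query unknown or already canceled"
}

arrow::Result<std::shared_ptr<arrow::Schema>> get_flight_schema(const std::string& query_id) {
    auto schema = find_schema_for(query_id);
    if (schema == nullptr) {
        // Mirrors the diff: surface the miss as arrow::Status::Invalid so the
        // ADBC/Flight client gets a descriptive error instead of an empty result.
        return arrow::Status::Invalid(
                "not found arrow flight schema, maybe query has been canceled, queryid: " + query_id);
    }
    return schema;
}

int main() {
    auto result = get_flight_schema("query-42");
    if (!result.ok()) {
        std::cerr << result.status().ToString() << std::endl;
    }
    return 0;
}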
20 changes: 8 additions & 12 deletions be/src/service/internal_service.cpp
@@ -718,23 +718,19 @@ void PInternalServiceImpl::fetch_arrow_flight_schema(google::protobuf::RpcContro
                                                      google::protobuf::Closure* done) {
     bool ret = _light_work_pool.try_offer([request, result, done]() {
         brpc::ClosureGuard closure_guard(done);
-        RowDescriptor row_desc = ExecEnv::GetInstance()->result_mgr()->find_row_descriptor(
-                UniqueId(request->finst_id()).to_thrift());
-        if (row_desc.equals(RowDescriptor())) {
-            auto st = Status::NotFound("not found row descriptor");
-            st.to_protobuf(result->mutable_status());
-            return;
-        }
-
-        std::shared_ptr<arrow::Schema> schema;
-        auto st = convert_to_arrow_schema(row_desc, &schema);
-        if (UNLIKELY(!st.ok())) {
+        std::shared_ptr<arrow::Schema> schema =
+                ExecEnv::GetInstance()->result_mgr()->find_arrow_schema(
+                        UniqueId(request->finst_id()).to_thrift());
+        if (schema == nullptr) {
+            LOG(INFO) << "not found arrow flight schema, maybe query has been canceled";
+            auto st = Status::NotFound(
+                    "not found arrow flight schema, maybe query has been canceled");
             st.to_protobuf(result->mutable_status());
             return;
         }
 
         std::string schema_str;
-        st = serialize_arrow_schema(row_desc, &schema, &schema_str);
+        auto st = serialize_arrow_schema(&schema, &schema_str);
         if (st.ok()) {
             result->set_schema(std::move(schema_str));
         }
32 changes: 24 additions & 8 deletions be/src/util/arrow/row_batch.cpp
@@ -39,6 +39,8 @@
 #include "runtime/types.h"
 #include "util/arrow/block_convertor.h"
 #include "vec/core/block.h"
+#include "vec/exprs/vexpr.h"
+#include "vec/exprs/vexpr_context.h"
 
 namespace doris {
 
@@ -163,6 +165,22 @@ Status convert_to_arrow_schema(const RowDescriptor& row_desc,
     return Status::OK();
 }
 
+Status convert_expr_ctxs_arrow_schema(const vectorized::VExprContextSPtrs& output_vexpr_ctxs,
+                                      std::shared_ptr<arrow::Schema>* result) {
+    std::vector<std::shared_ptr<arrow::Field>> fields;
+    for (auto expr_ctx : output_vexpr_ctxs) {
+        std::shared_ptr<arrow::DataType> arrow_type;
+        auto root_expr = expr_ctx->root();
+        RETURN_IF_ERROR(convert_to_arrow_type(root_expr->type(), &arrow_type));
+        auto field_name = root_expr->is_slot_ref() ? root_expr->expr_name()
+                                                   : root_expr->data_type()->get_name();
+        fields.push_back(
+                std::make_shared<arrow::Field>(field_name, arrow_type, root_expr->is_nullable()));
+    }
+    *result = arrow::schema(std::move(fields));
+    return Status::OK();
+}
+
 Status serialize_record_batch(const arrow::RecordBatch& record_batch, std::string* result) {
     // create sink memory buffer outputstream with the computed capacity
     int64_t capacity;

Status serialize_record_batch(const arrow::RecordBatch& record_batch, std::string* result) {
// create sink memory buffer outputstream with the computed capacity
int64_t capacity;
@@ -206,15 +224,13 @@ Status serialize_record_batch(const arrow::RecordBatch& record_batch, std::strin
     return Status::OK();
 }
 
-Status serialize_arrow_schema(RowDescriptor row_desc, std::shared_ptr<arrow::Schema>* schema,
-                              std::string* result) {
-    std::vector<SlotDescriptor*> slots;
-    for (auto tuple_desc : row_desc.tuple_descriptors()) {
-        slots.insert(slots.end(), tuple_desc->slots().begin(), tuple_desc->slots().end());
+Status serialize_arrow_schema(std::shared_ptr<arrow::Schema>* schema, std::string* result) {
+    auto make_empty_result = arrow::RecordBatch::MakeEmpty(*schema);
+    if (!make_empty_result.ok()) {
+        return Status::InternalError("serialize_arrow_schema failed, reason: {}",
+                                     make_empty_result.status().ToString());
     }
-    auto block = vectorized::Block(slots, 0);
-    std::shared_ptr<arrow::RecordBatch> batch;
-    RETURN_IF_ERROR(convert_to_arrow_batch(block, *schema, arrow::default_memory_pool(), &batch));
+    auto batch = make_empty_result.ValueOrDie();
     return serialize_record_batch(*batch, result);
 }
 
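serialize_arrow_schema now transmits the schema by serializing an empty RecordBatch built from it, instead of materializing a Block from slot descriptors. The sketch below shows the same idea against the stock Arrow C++ API, using arrow::ipc::SerializeRecordBatch where Doris calls its own serialize_record_batch helper; the two example fields are made up for illustration.

#include <arrow/api.h>
#include <arrow/ipc/api.h>

#include <iostream>

int main() {
    // A schema comparable to what convert_expr_ctxs_arrow_schema produces.
    auto schema = arrow::schema({arrow::field("id", arrow::int64(), /*nullable=*/false),
                                 arrow::field("name", arrow::utf8(), /*nullable=*/true)});

    // Wrap the schema in an empty RecordBatch, mirroring the serialize_arrow_schema change.
    auto empty = arrow::RecordBatch::MakeEmpty(schema);
    if (!empty.ok()) {
        std::cerr << empty.status().ToString() << std::endl;
        return 1;
    }

    // Serialize the (schema-only) batch to an IPC buffer that a client can read back.
    auto buffer = arrow::ipc::SerializeRecordBatch(*empty.ValueOrDie(),
                                                   arrow::ipc::IpcWriteOptions::Defaults());
    if (!buffer.ok()) {
        std::cerr << buffer.status().ToString() << std::endl;
        return 1;
    }
    std::cout << "serialized " << buffer.ValueOrDie()->size() << " bytes" << std::endl;
    return 0;
}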
7 changes: 5 additions & 2 deletions be/src/util/arrow/row_batch.h
@@ -23,6 +23,7 @@
 #include "common/status.h"
 #include "runtime/types.h"
 #include "vec/core/block.h"
+#include "vec/exprs/vexpr_fwd.h"
 
 // This file will convert Doris RowBatch to/from Arrow's RecordBatch
 // RowBatch is used by Doris query engine to exchange data between
@@ -49,9 +50,11 @@ Status convert_to_arrow_schema(const RowDescriptor& row_desc,
 Status convert_block_arrow_schema(const vectorized::Block& block,
                                   std::shared_ptr<arrow::Schema>* result);
 
+Status convert_expr_ctxs_arrow_schema(const vectorized::VExprContextSPtrs& output_vexpr_ctxs,
+                                      std::shared_ptr<arrow::Schema>* result);
+
 Status serialize_record_batch(const arrow::RecordBatch& record_batch, std::string* result);
 
-Status serialize_arrow_schema(RowDescriptor row_desc, std::shared_ptr<arrow::Schema>* schema,
-                              std::string* result);
+Status serialize_arrow_schema(std::shared_ptr<arrow::Schema>* schema, std::string* result);
 
 } // namespace doris
13 changes: 5 additions & 8 deletions be/src/vec/sink/varrow_flight_result_writer.cpp
@@ -27,23 +27,20 @@
 namespace doris {
 namespace vectorized {
 
-VArrowFlightResultWriter::VArrowFlightResultWriter(BufferControlBlock* sinker,
-                                                   const VExprContextSPtrs& output_vexpr_ctxs,
-                                                   RuntimeProfile* parent_profile,
-                                                   const RowDescriptor& row_desc)
+VArrowFlightResultWriter::VArrowFlightResultWriter(
+        BufferControlBlock* sinker, const VExprContextSPtrs& output_vexpr_ctxs,
+        RuntimeProfile* parent_profile, const std::shared_ptr<arrow::Schema>& arrow_schema)
         : _sinker(sinker),
           _output_vexpr_ctxs(output_vexpr_ctxs),
           _parent_profile(parent_profile),
-          _row_desc(row_desc) {}
+          _arrow_schema(arrow_schema) {}
 
 Status VArrowFlightResultWriter::init(RuntimeState* state) {
     _init_profile();
     if (nullptr == _sinker) {
         return Status::InternalError("sinker is NULL pointer.");
     }
     _is_dry_run = state->query_options().dry_run_query;
-    // generate the arrow schema
-    RETURN_IF_ERROR(convert_to_arrow_schema(_row_desc, &_arrow_schema));
     return Status::OK();
 }
 
@@ -100,7 +97,7 @@ bool VArrowFlightResultWriter::can_sink() {
     return _sinker->can_sink();
 }
 
-Status VArrowFlightResultWriter::close(Status) {
+Status VArrowFlightResultWriter::close(Status st) {
     COUNTER_SET(_sent_rows_counter, _written_rows);
     COUNTER_UPDATE(_bytes_sent_counter, _bytes_sent);
     return Status::OK();
5 changes: 2 additions & 3 deletions be/src/vec/sink/varrow_flight_result_writer.h
@@ -31,15 +31,15 @@
 namespace doris {
 class BufferControlBlock;
 class RuntimeState;
-class RowDescriptor;
 
 namespace vectorized {
 class Block;
 
 class VArrowFlightResultWriter final : public ResultWriter {
 public:
     VArrowFlightResultWriter(BufferControlBlock* sinker, const VExprContextSPtrs& output_vexpr_ctxs,
-                             RuntimeProfile* parent_profile, const RowDescriptor& row_desc);
+                             RuntimeProfile* parent_profile,
+                             const std::shared_ptr<arrow::Schema>& arrow_schema);
 
     Status init(RuntimeState* state) override;
 
@@ -72,7 +72,6 @@ class VArrowFlightResultWriter final : public ResultWriter {
 
     uint64_t _bytes_sent = 0;
 
-    const RowDescriptor& _row_desc;
     std::shared_ptr<arrow::Schema> _arrow_schema;
 };
 } // namespace vectorized
2 changes: 0 additions & 2 deletions be/src/vec/sink/vmemory_scratch_sink.cpp
@@ -56,8 +56,6 @@ Status MemoryScratchSink::_prepare_vexpr(RuntimeState* state) {
     RETURN_IF_ERROR(VExpr::create_expr_trees(_t_output_expr, _output_vexpr_ctxs));
     // Prepare the exprs to run.
     RETURN_IF_ERROR(VExpr::prepare(_output_vexpr_ctxs, state, _row_desc));
-    // generate the arrow schema
-    RETURN_IF_ERROR(convert_to_arrow_schema(_row_desc, &_arrow_schema));
     return Status::OK();
 }
 
2 changes: 0 additions & 2 deletions be/src/vec/sink/vmemory_scratch_sink.h
@@ -65,8 +65,6 @@ class MemoryScratchSink final : public DataSink {
 private:
     Status _prepare_vexpr(RuntimeState* state);
 
-    std::shared_ptr<arrow::Schema> _arrow_schema;
-
    BlockQueueSharedPtr _queue;
 
     // Owned by the RuntimeState.
12 changes: 8 additions & 4 deletions be/src/vec/sink/vresult_sink.cpp
@@ -33,6 +33,7 @@
 #include "runtime/exec_env.h"
 #include "runtime/result_buffer_mgr.h"
 #include "runtime/runtime_state.h"
+#include "util/arrow/row_batch.h"
 #include "util/runtime_profile.h"
 #include "util/telemetry/telemetry.h"
 #include "vec/exprs/vexpr.h"
@@ -98,12 +99,15 @@ Status VResultSink::prepare(RuntimeState* state) {
         _writer.reset(new (std::nothrow)
                               VMysqlResultWriter(_sender.get(), _output_vexpr_ctxs, _profile));
         break;
-    case TResultSinkType::ARROW_FLIGHT_PROTOCAL:
-        state->exec_env()->result_mgr()->register_row_descriptor(state->fragment_instance_id(),
-                                                                 _row_desc);
+    case TResultSinkType::ARROW_FLIGHT_PROTOCAL: {
+        std::shared_ptr<arrow::Schema> arrow_schema;
+        RETURN_IF_ERROR(convert_expr_ctxs_arrow_schema(_output_vexpr_ctxs, &arrow_schema));
+        state->exec_env()->result_mgr()->register_arrow_schema(state->fragment_instance_id(),
+                                                               arrow_schema);
         _writer.reset(new (std::nothrow) VArrowFlightResultWriter(_sender.get(), _output_vexpr_ctxs,
-                                                                  _profile, _row_desc));
+                                                                  _profile, arrow_schema));
         break;
+    }
     default:
         return Status::InternalError("Unknown result sink type");
     }
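For TResultSinkType::ARROW_FLIGHT_PROTOCAL the schema is now derived from the sink's output expressions (field name, Arrow type, nullability) and registered before the writer is created. A standalone sketch of that mapping; OutputExpr, build_arrow_schema, and the example query are illustrative stand-ins rather than Doris APIs.

#include <arrow/api.h>

#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Illustrative stand-in for the metadata taken from each output expression's
// root in convert_expr_ctxs_arrow_schema (expr_name / data_type / is_nullable).
struct OutputExpr {
    std::string name;
    std::shared_ptr<arrow::DataType> type;
    bool nullable;
};

std::shared_ptr<arrow::Schema> build_arrow_schema(const std::vector<OutputExpr>& exprs) {
    std::vector<std::shared_ptr<arrow::Field>> fields;
    fields.reserve(exprs.size());
    for (const auto& e : exprs) {
        // One arrow::Field per output expression, carrying the nullability flag.
        fields.push_back(arrow::field(e.name, e.type, e.nullable));
    }
    return arrow::schema(std::move(fields));
}

int main() {
    // e.g. output expressions of a query like `SELECT k1, sum(v1) FROM t GROUP BY k1`
    auto schema = build_arrow_schema({{"k1", arrow::int32(), false},
                                      {"sum(v1)", arrow::int64(), true}});
    std::cout << schema->ToString() << std::endl;
    return 0;
}

Registering the schema at prepare time is what allows fetch_arrow_flight_schema and the Flight batch reader to answer schema requests before any result batch has been produced.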
@@ -1204,7 +1204,10 @@ public enum ErrorCode {
             "the auto increment must be BIGINT type."),
 
     ERR_AUTO_INCREMENT_COLUMN_IN_AGGREGATE_TABLE(5096, new byte[]{'4', '2', '0', '0', '0'},
-            "the auto increment is only supported in duplicate table and unique table.");
+            "the auto increment is only supported in duplicate table and unique table."),
+
+    ERR_ARROW_FLIGHT_SQL_MUST_ONLY_RESULT_STMT(5097, new byte[]{'4', '2', '0', '0', '0'},
+            "There can only be one stmt that returns the result and it is at the end.");
 
     // This is error code
     private final int code;
@@ -22,6 +22,7 @@
 import org.apache.doris.qe.ConnectContext;
 import org.apache.doris.qe.ConnectProcessor;
 import org.apache.doris.qe.ConnectScheduler;
+import org.apache.doris.qe.MysqlConnectProcessor;
 
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
@@ -81,7 +82,7 @@ public void handleEvent(AcceptingChannel<StreamConnection> channel) {
                     context.getEnv().getAuth().getQueryTimeout(context.getQualifiedUser()));
             context.setUserInsertTimeout(
                     context.getEnv().getAuth().getInsertTimeout(context.getQualifiedUser()));
-            ConnectProcessor processor = new ConnectProcessor(context);
+            ConnectProcessor processor = new MysqlConnectProcessor(context);
             context.startAcceptQuery(processor);
         } catch (AfterConnectedException e) {
             // do not need to print log for this kind of exception.
@@ -23,6 +23,7 @@
 import java.util.Map;
 
 // MySQL protocol text command
+// Reused by arrow flight protocol
 public enum MysqlCommand {
     COM_SLEEP("Sleep", 0),
     COM_QUIT("Quit", 1),
