Skip to content

Commit

Permalink
Merge pull request apache#158 from rafael-telles/flight-sql-cpp-param…
Browse files Browse the repository at this point in the history
…eter-binding

[Flight SQL C++] Implement parameter binding on execute of PreparedStatement on Flight SQL Client.
  • Loading branch information
jcralmeida authored Oct 8, 2021
2 parents a8188fd + 8f94998 commit f79e28b
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 4 deletions.
16 changes: 16 additions & 0 deletions cpp/src/arrow/flight/flight-sql/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ namespace internal {
template <class T = arrow::flight::FlightClient>
class ARROW_EXPORT PreparedStatementT {
pb::sql::ActionCreatePreparedStatementResult prepared_statement_result;
std::shared_ptr<RecordBatch> parameter_binding;
FlightCallOptions options;
bool is_closed;
T* client;
Expand All @@ -59,6 +60,21 @@ class ARROW_EXPORT PreparedStatementT {
/// \return Status.
Status Execute(std::unique_ptr<FlightInfo>* info);

/// \brief Retrieve the parameter schema from the query.
/// \param schema The parameter schema from the query.
/// \return Status.
Status GetParameterSchema(std::shared_ptr<Schema>* schema);

/// \brief Retrieve the ResultSet schema from the query.
/// \param schema The ResultSet schema from the query.
/// \return Status.
Status GetResultSetSchema(std::shared_ptr<Schema>* schema);

/// \brief Set a RecordBatch that contains the parameters that will be bind.
/// \param parameter_binding_ The parameters that will be bind.
/// \return Status.
Status SetParameters(std::shared_ptr<RecordBatch> parameter_binding);

/// \brief Closes the prepared statement.
/// \param[in] options RPC-layer hints for this call.
/// \return Status.
Expand Down
53 changes: 51 additions & 2 deletions cpp/src/arrow/flight/flight-sql/client_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
#include <arrow/buffer.h>
#include <arrow/flight/flight-sql/FlightSql.pb.h>
#include <arrow/flight/types.h>
#include <arrow/io/memory.h>
#include <arrow/ipc/reader.h>
#include <arrow/testing/gtest_util.h>
#include <google/protobuf/any.pb.h>
#include <google/protobuf/message_lite.h>

Expand Down Expand Up @@ -271,7 +274,8 @@ Status FlightSqlClientT<T>::Prepare(

prepared_result.UnpackTo(&prepared_statement_result);

prepared_statement->reset(new PreparedStatementT<T>(client.get(), query, prepared_statement_result, options));
prepared_statement->reset(
new PreparedStatementT<T>(client.get(), query, prepared_statement_result, options));

return Status::OK();
}
Expand All @@ -291,16 +295,61 @@ Status PreparedStatementT<T>::Execute(std::unique_ptr<FlightInfo>* info) {
any.PackFrom(execute_query_command);

const std::string& string = any.SerializeAsString();
const FlightDescriptor& descriptor = FlightDescriptor::Command(string);
const FlightDescriptor descriptor = FlightDescriptor::Command(string);

if (parameter_binding && parameter_binding->num_rows() > 0) {
std::unique_ptr<FlightStreamWriter> writer;
std::unique_ptr<FlightMetadataReader> reader;
client->DoPut(options, descriptor, parameter_binding->schema(), &writer, &reader);

ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*parameter_binding));
ARROW_RETURN_NOT_OK(writer->DoneWriting());
// Wait for the server to ack the result
std::shared_ptr<Buffer> buffer;
ARROW_RETURN_NOT_OK(reader->ReadMetadata(&buffer));
}

return client->GetFlightInfo(options, descriptor, info);
}

template <class T>
Status PreparedStatementT<T>::SetParameters(std::shared_ptr<RecordBatch> parameter_binding_) {
parameter_binding = std::move(parameter_binding_);

return Status::OK();
}

template <class T>
bool PreparedStatementT<T>::IsClosed() const {
return is_closed;
}

template <class T>
Status PreparedStatementT<T>::GetResultSetSchema(std::shared_ptr<Schema> *schema) {
auto &args = prepared_statement_result.dataset_schema();
std::shared_ptr<Buffer> schema_buffer = std::make_shared<Buffer>(args);

io::BufferReader reader(schema_buffer);

ipc::DictionaryMemo in_memo;
ARROW_ASSIGN_OR_RAISE(*schema, ReadSchema(&reader, &in_memo))

return Status::OK();
}

template <class T>
Status PreparedStatementT<T>::GetParameterSchema(std::shared_ptr<Schema>* schema) {
auto &args = prepared_statement_result.parameter_schema();
std::shared_ptr<Buffer> schema_buffer = std::make_shared<Buffer>(args);

io::BufferReader reader(schema_buffer);

ipc::DictionaryMemo in_memo;
ARROW_ASSIGN_OR_RAISE(*schema, ReadSchema(&reader, &in_memo))

return Status::OK();
}

template <class T>
Status PreparedStatementT<T>::Close() {
if (is_closed) {
Expand Down
114 changes: 114 additions & 0 deletions cpp/src/arrow/flight/flight-sql/client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
#include <google/protobuf/any.pb.h>
#include <gtest/gtest.h>

#include <utility>

namespace pb = arrow::flight::protocol;
using ::testing::_;
using ::testing::Ref;
Expand Down Expand Up @@ -62,6 +64,45 @@ class FlightMetadataReaderMock : public FlightMetadataReader {
}
};

class FlightStreamWriterMock : public FlightStreamWriter {
public:
explicit FlightStreamWriterMock() = default;

Status DoneWriting() override {
return Status::OK();
}

Status WriteMetadata(std::shared_ptr<Buffer> app_metadata) override {
return Status::OK();
}

Status Begin(const std::shared_ptr<Schema> &schema,
const ipc::IpcWriteOptions &options) override {
return Status::OK();
}

Status Begin(const std::shared_ptr<Schema> &schema) override {
return MetadataRecordBatchWriter::Begin(schema);
}

ipc::WriteStats stats() const override {
return ipc::WriteStats();
}

Status WriteWithMetadata(const RecordBatch &batch,
std::shared_ptr<Buffer> app_metadata) override {
return Status::OK();
}

Status Close() override {
return Status::OK();
}

Status WriteRecordBatch(const RecordBatch &batch) override {
return Status::OK();
}
};

FlightDescriptor getDescriptor(google::protobuf::Message& command) {
google::protobuf::Any any;
any.PackFrom(command);
Expand Down Expand Up @@ -271,6 +312,79 @@ TEST(TestFlightSqlClient, TestPreparedStatementExecute) {
(void)preparedStatement->Execute(&flight_info);
}

TEST(TestFlightSqlClient, TestPreparedStatementExecuteParameterBinding) {
auto* client_mock = new FlightClientMock();
std::unique_ptr<FlightClientMock> client_mock_ptr(client_mock);
FlightSqlClientT<FlightClientMock> sqlClient(client_mock_ptr);
FlightCallOptions call_options;

const std::string query = "query";

ON_CALL(*client_mock, DoAction)
.WillByDefault([](const FlightCallOptions& options, const Action& action,
std::unique_ptr<ResultStream>* results) {
google::protobuf::Any command;

pb::sql::ActionCreatePreparedStatementResult prepared_statement_result;

prepared_statement_result.set_prepared_statement_handle("query");

auto schema = arrow::schema({arrow::field("id", int64())});

std::shared_ptr<Buffer> schema_buffer;
const arrow::Result<std::shared_ptr<Buffer>> &result = arrow::ipc::SerializeSchema(
*schema);

ARROW_ASSIGN_OR_RAISE(schema_buffer, result);

prepared_statement_result.set_parameter_schema(schema_buffer->ToString());

command.PackFrom(prepared_statement_result);

*results = std::unique_ptr<ResultStream>(new SimpleResultStream(
{Result{Buffer::FromString(command.SerializeAsString())}}));

return Status::OK();
});

std::shared_ptr<Buffer> buffer_ptr;
ON_CALL(*client_mock, DoPut)
.WillByDefault([&buffer_ptr](const FlightCallOptions& options,
const FlightDescriptor& descriptor1,
const std::shared_ptr<Schema>& schema,
std::unique_ptr<FlightStreamWriter>* writer,
std::unique_ptr<FlightMetadataReader>* reader) {

writer->reset(new FlightStreamWriterMock());
reader->reset(new FlightMetadataReaderMock(&buffer_ptr));

return Status::OK();
});


std::unique_ptr<FlightInfo> flight_info;
EXPECT_CALL(*client_mock, DoAction(_, _, _)).Times(2);
EXPECT_CALL(*client_mock, DoPut(_, _, _, _, _));

std::shared_ptr<internal::PreparedStatementT<FlightClientMock>> prepared_statement;
ASSERT_OK(sqlClient.Prepare(call_options, query, &prepared_statement));

std::shared_ptr<Schema> parameter_schema;
ASSERT_OK(prepared_statement->GetParameterSchema(&parameter_schema));

arrow::Int64Builder int_builder;
ASSERT_OK(int_builder.Append(1));
std::shared_ptr<arrow::Array> int_array;
ASSERT_OK(int_builder.Finish(&int_array));
std::shared_ptr<arrow::RecordBatch> result;
result = arrow::RecordBatch::Make(parameter_schema, 1, {int_array});
ASSERT_OK(prepared_statement->SetParameters(result));

EXPECT_CALL(*client_mock, GetFlightInfo(_, _, &flight_info));

ASSERT_OK(prepared_statement->Execute(&flight_info));
}

TEST(TestFlightSqlClient, TestExecuteUpdate) {
auto* client_mock = new FlightClientMock();
std::unique_ptr<FlightClientMock> client_mock_ptr(client_mock);
Expand Down
26 changes: 24 additions & 2 deletions cpp/src/arrow/flight/flight-sql/test_app.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

#include <arrow/array/builder_binary.h>
#include <arrow/flight/api.h>
#include <arrow/flight/flight-sql/api.h>
#include <arrow/io/memory.h>
Expand All @@ -23,6 +24,7 @@
#include <arrow/table.h>
#include <gflags/gflags.h>

#include <boost/algorithm/string.hpp>
#include <iostream>
#include <memory>

Expand Down Expand Up @@ -135,7 +137,26 @@ Status RunMain() {
ARROW_RETURN_NOT_OK(sqlClient.Prepare(call_options, fLS::FLAGS_query, &prepared_statement));
ARROW_RETURN_NOT_OK(prepared_statement->Execute(&info));
ARROW_RETURN_NOT_OK(PrintResults(sqlClient, call_options, info));
} else if (fLS::FLAGS_command == "GetSchemas") {
} else if (fLS::FLAGS_command == "PreparedStatementExecuteParameterBinding") {
std::shared_ptr<arrow::flight::sql::PreparedStatement> prepared_statement;
ARROW_RETURN_NOT_OK(sqlClient.Prepare({}, fLS::FLAGS_query, &prepared_statement));
std::shared_ptr<Schema> parameter_schema;
std::shared_ptr<Schema> result_set_schema;
ARROW_RETURN_NOT_OK(prepared_statement->GetParameterSchema(&parameter_schema));
ARROW_RETURN_NOT_OK(prepared_statement->GetResultSetSchema(&result_set_schema));

std::cout << result_set_schema->ToString(false) << std::endl;
arrow::Int64Builder int_builder;
ARROW_RETURN_NOT_OK(int_builder.Append(1));
std::shared_ptr<arrow::Array> int_array;
ARROW_RETURN_NOT_OK(int_builder.Finish(&int_array));
std::shared_ptr<arrow::RecordBatch> result;
result = arrow::RecordBatch::Make(parameter_schema, 1, {int_array});

ARROW_RETURN_NOT_OK(prepared_statement->SetParameters(result));
ARROW_RETURN_NOT_OK(prepared_statement->Execute(&info));
ARROW_RETURN_NOT_OK(PrintResults(sqlClient, call_options, info));
}else if (fLS::FLAGS_command == "GetSchemas") {
ARROW_RETURN_NOT_OK(sqlClient.GetSchemas(call_options, &fLS::FLAGS_catalog,
&fLS::FLAGS_schema, &info));
} else if (fLS::FLAGS_command == "GetTableTypes") {
Expand All @@ -158,7 +179,8 @@ Status RunMain() {
call_options, &fLS::FLAGS_catalog, &fLS::FLAGS_schema, fLS::FLAGS_table, &info));
}

if (info != NULLPTR && fLS::FLAGS_command != "PreparedStatementExecute") {
if (info != NULLPTR &&
!boost::istarts_with(fLS::FLAGS_command, "PreparedStatementExecute")) {
return PrintResults(sqlClient, call_options, info);
}

Expand Down

0 comments on commit f79e28b

Please sign in to comment.