diff --git a/spec/pg/connection_spec.cr b/spec/pg/connection_spec.cr index 3dedb40d..ff8967bc 100644 --- a/spec/pg/connection_spec.cr +++ b/spec/pg/connection_spec.cr @@ -104,3 +104,35 @@ describe PG, "#read_next_row_start" do end end end + +record PG::ConnectionSpec::TestUser, id : Int32, name : String do + include DB::Serializable +end + +describe PG, "#pipeline" do + it "allows pipelined queries" do + with_connection do |db| + result_sets = db.pipeline do |pipe| + pipe.query "SELECT 42" + pipe.query "SELECT $1::int4 AS exchange, $2::int8 AS suffix", 867, 5309 + pipe.query "SELECT * FROM generate_series(1, 3)" + pipe.query <<-SQL + SELECT + generate_series AS id, + 'Person #' || generate_series AS name + FROM generate_series(1, 3) + SQL + 50.times { |i| pipe.query "SELECT $1::int4 AS index", i } + end + result_sets.scalar(Int32).should eq 42 + result_sets.read_one({Int32, Int64}).should eq({867, 5309}) + result_sets.read_all(Int32).should eq [1, 2, 3] + result_sets.read_all(PG::ConnectionSpec::TestUser).should eq [ + PG::ConnectionSpec::TestUser.new(1, "Person #1"), + PG::ConnectionSpec::TestUser.new(2, "Person #2"), + PG::ConnectionSpec::TestUser.new(3, "Person #3"), + ] + 50.times { |i| result_sets.scalar(Int32).should eq i } + end + end +end diff --git a/src/pg/connection.cr b/src/pg/connection.cr index c84cad92..118b503c 100644 --- a/src/pg/connection.cr +++ b/src/pg/connection.cr @@ -1,4 +1,6 @@ require "../pq/*" +require "./statement" +require "./result_set" module PG class Connection < ::DB::Connection @@ -25,9 +27,15 @@ module PG Statement.new(self, query) end + def pipeline + pipeline = Pipeline.new(self) + yield pipeline + pipeline.results + end + # Execute several statements. No results are returned. def exec_all(query : String) : Nil - PQ::SimpleQuery.new(@connection, query) + PQ::SimpleQuery.new(@connection, query).exec nil end @@ -68,4 +76,73 @@ module PG end end end + + struct Pipeline + def initialize(@connection : Connection) + @queries = [] of PQ::ExtendedQuery + end + + def query(query, *args_, args : Array? = nil) : self + ext_query = PQ::ExtendedQuery.new(@connection.connection, query, DB::EnumerableConcat.build(args_, args)) + @queries << ext_query.tap(&.send) + self + end + + def results + @iterator ||= Results.new(@connection, @queries.each) + end + + struct Results + def initialize(@connection : Connection, @result_sets : Iterator(PQ::ExtendedQuery)) + end + + def scalar(type : T.class) forall T + each type do |value| + return value + end + end + + def read_one(type : T.class) forall T + each(type) { |value| return value } + end + + def read_one(types : Tuple) + each(*types) { |value| return value } + end + + def read_all(type : T.class) forall T + results = Array(T).new + + each(type) do |row| + results << row + end + results + end + + def each(*type) forall T + rs = self.next + + begin + rs.each do + yield rs.read(*type) + end + ensure + rs.close + end + end + + def next + case result = @result_sets.next + when PQ::ExtendedQuery + Statement::Pipelined.new(@connection, result.query).perform_query(result.params) + else + raise "Vespene geyser exhausted" + end + end + end + + def close + each + end + end end diff --git a/src/pg/statement.cr b/src/pg/statement.cr index 97a496b6..8c4afd15 100644 --- a/src/pg/statement.cr +++ b/src/pg/statement.cr @@ -8,13 +8,34 @@ class PG::Statement < ::DB::Statement end protected def perform_query(args : Enumerable) : ResultSet + send_query args + ResultSet.new(self, receive_fields) + rescue IO::Error + raise DB::ConnectionLost.new(connection) + end + + protected def perform_exec(args : Enumerable) : ::DB::ExecResult + result = perform_query(args) + result.each { } + ::DB::ExecResult.new( + rows_affected: result.rows_affected, + last_insert_id: 0_i64 # postgres doesn't support this + ) + rescue IO::Error + raise DB::ConnectionLost.new(connection) + end + + protected def send_query(args) params = args.map { |arg| PQ::Param.encode(arg) } - conn = self.conn conn.send_parse_message(command) conn.send_bind_message params conn.send_describe_portal_message conn.send_execute_message conn.send_sync_message + end + + protected def receive_fields + conn.flush conn.expect_frame PQ::Frame::ParseComplete conn.expect_frame PQ::Frame::BindComplete frame = conn.read @@ -26,19 +47,45 @@ class PG::Statement < ::DB::Statement else raise "expected RowDescription or NoData, got #{frame}" end - ResultSet.new(self, fields) - rescue IO::Error - raise DB::ConnectionLost.new(connection) end - protected def perform_exec(args : Enumerable) : ::DB::ExecResult - result = perform_query(args) - result.each { } - ::DB::ExecResult.new( - rows_affected: result.rows_affected, - last_insert_id: 0_i64 # postgres doesn't support this - ) - rescue IO::Error - raise DB::ConnectionLost.new(connection) + class Pipelined < self + protected def perform_query(args : Enumerable) : ResultSet + conn.flush + case frame = conn.expect_frame(PQ::Frame::ParseComplete | PQ::Frame::CommandComplete) + when PQ::Frame::ParseComplete + conn.expect_frame PQ::Frame::BindComplete + when PQ::Frame::CommandComplete + conn.expect_frame PQ::Frame::ReadyForQuery + conn.expect_frame PQ::Frame::ParseComplete + conn.expect_frame PQ::Frame::BindComplete + else + raise "Unexpected frame: #{frame.inspect}" + end + + frame = conn.read + case frame + when PQ::Frame::RowDescription + fields = frame.fields + when PQ::Frame::NoData + fields = nil + else + raise "expected RowDescription or NoData, got #{frame}" + end + ResultSet.new(self, fields) + rescue IO::Error + raise DB::ConnectionLost.new(connection) + end + + protected def perform_exec(args : Enumerable) : ::DB::ExecResult + result = perform_query(args) + result.each { } + ::DB::ExecResult.new( + rows_affected: result.rows_affected, + last_insert_id: 0_i64 # postgres doesn't support this + ) + rescue IO::Error + raise DB::ConnectionLost.new(connection) + end end end diff --git a/src/pq/connection.cr b/src/pq/connection.cr index ee6b13c3..505b40a4 100644 --- a/src/pq/connection.cr +++ b/src/pq/connection.cr @@ -162,7 +162,7 @@ module PQ end end - yield row + row end def read @@ -435,10 +435,8 @@ module PQ def read_all_data_rows type = soc.read_char - loop do - break unless type == 'D' - read_data_row { |row| yield row } - type = soc.read_char + while read_next_row_start + yield read_data_row end expect_frame Frame::CommandComplete, type end @@ -522,12 +520,15 @@ module PQ def send_sync_message write_chr 'S' write_i32 4 - soc.flush end def send_terminate_message write_chr 'X' write_i32 4 end + + def flush + soc.flush + end end end diff --git a/src/pq/query.cr b/src/pq/query.cr index 7e18a2eb..f0163e22 100644 --- a/src/pq/query.cr +++ b/src/pq/query.cr @@ -1,57 +1,44 @@ module PQ # :nodoc: - class ExtendedQuery - getter conn, query, params, fields + struct ExtendedQuery + getter conn, query, params - def initialize(conn, query, params) - encoded_params = params.map { |v| Param.encode(v) } - initialize(conn, query, encoded_params) + def self.new(conn, query, params) + encoded_params = params.map { |v| Param.encode(v).as(Param) } + new(conn, query, encoded_params.to_a) end def initialize(@conn : Connection, @query : String, @params : Array(Param)) + end + + def exec + send + # TODO: How should we process the result? SHOULD we process it here? + end + + def send conn.send_parse_message query conn.send_bind_message params conn.send_describe_portal_message conn.send_execute_message conn.send_sync_message - conn.expect_frame Frame::ParseComplete - conn.expect_frame Frame::BindComplete - - frame = conn.read - if frame.is_a?(Frame::RowDescription) - @fields = frame.fields - @has_data = true - elsif frame.is_a?(Frame::NoData) - @fields = [] of PQ::Field - conn.expect_frame Frame::CommandComplete | Frame::EmptyQueryResponse - conn.expect_frame Frame::ReadyForQuery - @has_data = false - else - raise "expected RowDescription or NoData, got #{frame}" - end - @got_data = false - end - - def get_data - raise "already read data" if @got_data - if @has_data - conn.read_all_data_rows { |row| yield row } - conn.expect_frame Frame::ReadyForQuery - end - @got_data = true end end # :nodoc: - class SimpleQuery + struct SimpleQuery getter conn, query def initialize(@conn : Connection, @query : String) + end + + def exec conn.send_query_message(query) - # read_all_data_rows { |row| yield row } while !conn.read.is_a?(Frame::ReadyForQuery) end + + nil end end end