diff --git a/lib/net/ldap/connection.rb b/lib/net/ldap/connection.rb index f0e5519d..ac47e6e0 100644 --- a/lib/net/ldap/connection.rb +++ b/lib/net/ldap/connection.rb @@ -111,6 +111,45 @@ def close @conn = nil end + # Internal: Reads messages by ID from a queue, falling back to reading from + # the connected socket until a message matching the ID is read. Any messages + # with mismatched IDs gets queued for subsequent reads by the origin of that + # message ID. + # + # Returns a Net::LDAP::PDU object or nil. + def queued_read(message_id) + if pdu = message_queue[message_id].shift + return pdu + end + + # read messages until we have a match for the given message_id + while pdu = read + if pdu.message_id == message_id + return pdu + else + message_queue[pdu.message_id].push pdu + next + end + end + + pdu + end + + # Internal: The internal queue of messages, read from the socket, grouped by + # message ID. + # + # Used by `queued_read` to return messages sent by the server with the given + # ID. If no messages are queued for that ID, `queued_read` will `read` from + # the socket and queue messages that don't match the given ID for other + # readers. + # + # Returns the message queue Hash. + def message_queue + @message_queue ||= Hash.new do |hash, key| + hash[key] = [] + end + end + # Internal: Reads and parses data from the configured connection. # # - syntax: the BER syntax to use to parse the read data with @@ -146,9 +185,9 @@ def read(syntax = Net::LDAP::AsnSyntax) # # Returns the return value from writing to the connection, which in some # cases is the Integer number of bytes written to the socket. - def write(request, controls = nil) + def write(request, controls = nil, message_id = next_msgid) instrument "write.net_ldap_connection" do |payload| - packet = [next_msgid.to_ber, request, controls].compact.to_ber_sequence + packet = [message_id.to_ber, request, controls].compact.to_ber_sequence payload[:content_length] = @conn.write(packet) end end @@ -356,7 +395,10 @@ def search(args = {}) result_pdu = nil n_results = 0 + message_id = next_msgid + instrument "search.net_ldap_connection", + :message_id => message_id, :filter => search_filter, :base => search_base, :scope => scope, @@ -403,12 +445,12 @@ def search(args = {}) controls << sort_control if sort_control controls = controls.empty? ? nil : controls.to_ber_contextspecific(0) - write(request, controls) + write(request, controls, message_id) result_pdu = nil controls = [] - while pdu = read + while pdu = queued_read(message_id) case pdu.app_tag when Net::LDAP::PDU::SearchReturnedData n_results += 1 @@ -476,6 +518,14 @@ def search(args = {}) result_pdu || OpenStruct.new(:status => :failure, :result_code => 1, :message => "Invalid search") end # instrument + ensure + # clean up message queue for this search + messages = message_queue.delete(message_id) + + unless messages.empty? + instrument "search_messages_unread.net_ldap_connection", + message_id: message_id, messages: messages + end end MODIFY_OPERATIONS = { #:nodoc: diff --git a/test/test_ldap_connection.rb b/test/test_ldap_connection.rb index 0c3c5f34..7ed75113 100644 --- a/test/test_ldap_connection.rb +++ b/test/test_ldap_connection.rb @@ -185,20 +185,21 @@ def test_bind_net_ldap_connection_event def test_search_net_ldap_connection_event # search data - search_data_ber = Net::BER::BerIdentifiedArray.new([2, [ + search_data_ber = Net::BER::BerIdentifiedArray.new([1, [ "uid=user1,ou=OrgUnit2,ou=OrgUnitTop,dc=openldap,dc=ghe,dc=local", [ ["uid", ["user1"]] ] ]]) search_data_ber.ber_identifier = Net::LDAP::PDU::SearchReturnedData - search_data = [2, search_data_ber] + search_data = [1, search_data_ber] # search result (end of results) search_result_ber = Net::BER::BerIdentifiedArray.new([0, "", ""]) search_result_ber.ber_identifier = Net::LDAP::PDU::SearchResult - search_result = [2, search_result_ber] + search_result = [1, search_result_ber] @tcp_socket.should_receive(:read_ber).and_return(search_data). and_return(search_result) events = @service.subscribe "search.net_ldap_connection" + unread = @service.subscribe "search_messages_unread.net_ldap_connection" result = @connection.search(filter: "(uid=user1)") assert result.success?, "should be success" @@ -209,5 +210,8 @@ def test_search_net_ldap_connection_event assert payload.has_key?(:filter) assert_equal "(uid=user1)", payload[:filter].to_s assert result + + # ensure no unread + assert unread.empty?, "should not have any leftover unread messages" end end