From 0ee3906b0fa41e1cb7e1e96c44818581bf942406 Mon Sep 17 00:00:00 2001 From: Tim Chepeleff Date: Thu, 20 Nov 2025 11:02:26 -0500 Subject: [PATCH 1/4] Add auth timeout --- README.md | 10 ++++++ lib/redis_client.rb | 19 +++++++++-- lib/redis_client/config.rb | 6 ++++ test/redis_client_test.rb | 64 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 97 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 1fa96db..afc161e 100644 --- a/README.md +++ b/README.md @@ -503,6 +503,16 @@ RedisClient.config( All timeout values are specified in seconds. +You can also configure a specific timeout to apply only to authentication during the connection handshake: + +```ruby +RedisClient.config( + username: "app", + password: "secret", + auth_timeout: 0.2, # applies to AUTH (or HELLO ... AUTH) during connect +).new +``` + ### Reconnection `redis-client` support automatic reconnection after network errors via the `reconnect_attempts:` configuration option. diff --git a/lib/redis_client.rb b/lib/redis_client.rb index d7b813b..ef8b35f 100644 --- a/lib/redis_client.rb +++ b/lib/redis_client.rb @@ -823,22 +823,37 @@ def connect @raw_connection.retry_attempt = @retry_attempt prelude = config.connection_prelude.dup + timeouts = nil + if (auth_timeout = config.respond_to?(:auth_timeout) ? config.auth_timeout : nil) + unless auth_timeout.nil? + timeouts = Array.new(prelude.size) + prelude.each_with_index do |cmd, idx| + if cmd && !cmd.empty? + if cmd.first == "AUTH" || (cmd.first == "HELLO" && cmd.size >= 3 && cmd[2] == "AUTH") + timeouts[idx] = auth_timeout + end + end + end + end + end if id prelude << ["CLIENT", "SETNAME", id] + timeouts << nil if timeouts end # The connection prelude is deliberately not sent to Middlewares if config.sentinel? prelude << ["ROLE"] + timeouts << nil if timeouts role, = @middlewares.call_pipelined(prelude, config) do - @raw_connection.call_pipelined(prelude, nil).last + @raw_connection.call_pipelined(prelude, timeouts).last end config.check_role!(role) else unless prelude.empty? @middlewares.call_pipelined(prelude, config) do - @raw_connection.call_pipelined(prelude, nil) + @raw_connection.call_pipelined(prelude, timeouts) end end end diff --git a/lib/redis_client/config.rb b/lib/redis_client/config.rb index 94252ab..7b0be2a 100644 --- a/lib/redis_client/config.rb +++ b/lib/redis_client/config.rb @@ -27,6 +27,7 @@ def initialize( read_timeout: timeout, write_timeout: timeout, connect_timeout: timeout, + auth_timeout: nil, ssl: nil, custom: {}, ssl_params: nil, @@ -54,6 +55,7 @@ def initialize( @connect_timeout = connect_timeout @read_timeout = read_timeout @write_timeout = write_timeout + @auth_timeout = auth_timeout @driver = driver ? RedisClient.driver(driver) : RedisClient.default_driver @@ -123,6 +125,10 @@ def username @username || DEFAULT_USERNAME end + def auth_timeout + @auth_timeout + end + def resolved? true end diff --git a/test/redis_client_test.rb b/test/redis_client_test.rb index 8b1e361..edee862 100644 --- a/test/redis_client_test.rb +++ b/test/redis_client_test.rb @@ -15,6 +15,70 @@ def test_preselect_database assert_includes client.call("CLIENT", "INFO"), " db=5 " end + def test_auth_timeout_applied_resp3 + capturing_driver = Class.new(RedisClient::RubyConnection) do + class << self + attr_accessor :last_timeouts + end + def call_pipelined(commands, timeouts, exception: true) + self.class.last_timeouts = timeouts + Array.new(commands.size, "OK") + end + end + client = new_client( + driver: capturing_driver, + username: "user", + password: "pass", + auth_timeout: 0.123, + protocol: 3 + ) + client.call("PING") + assert_equal [0.123], capturing_driver.last_timeouts + end + + def test_auth_timeout_applied_resp2 + capturing_driver = Class.new(RedisClient::RubyConnection) do + class << self + attr_accessor :last_timeouts + end + def call_pipelined(commands, timeouts, exception: true) + self.class.last_timeouts = timeouts + Array.new(commands.size, "OK") + end + end + client = new_client( + driver: capturing_driver, + username: "user", + password: "pass", + auth_timeout: 0.456, + protocol: 2 + ) + client.call("PING") + assert_equal [0.456], capturing_driver.last_timeouts + end + + def test_auth_timeout_only_applies_to_auth_commands + capturing_driver = Class.new(RedisClient::RubyConnection) do + class << self + attr_accessor :last_timeouts + end + def call_pipelined(commands, timeouts, exception: true) + self.class.last_timeouts = timeouts + Array.new(commands.size, "OK") + end + end + client = new_client( + driver: capturing_driver, + username: "user", + password: "pass", + db: 5, + auth_timeout: 0.789, + protocol: 2 + ) + client.call("PING") + assert_equal [0.789, nil], capturing_driver.last_timeouts + end + def test_set_client_id client = new_client(id: "peter") assert_includes client.call("CLIENT", "INFO"), " name=peter " From 359e447e9de7bda40d8faaa5647fcfb39613712c Mon Sep 17 00:00:00 2001 From: Tim Chepeleff Date: Thu, 20 Nov 2025 11:21:32 -0500 Subject: [PATCH 2/4] Add auth timeout --- lib/redis_client.rb | 31 ++++++++++++++++++------------- lib/redis_client/config.rb | 1 + 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/lib/redis_client.rb b/lib/redis_client.rb index ef8b35f..1880d33 100644 --- a/lib/redis_client.rb +++ b/lib/redis_client.rb @@ -823,19 +823,7 @@ def connect @raw_connection.retry_attempt = @retry_attempt prelude = config.connection_prelude.dup - timeouts = nil - if (auth_timeout = config.respond_to?(:auth_timeout) ? config.auth_timeout : nil) - unless auth_timeout.nil? - timeouts = Array.new(prelude.size) - prelude.each_with_index do |cmd, idx| - if cmd && !cmd.empty? - if cmd.first == "AUTH" || (cmd.first == "HELLO" && cmd.size >= 3 && cmd[2] == "AUTH") - timeouts[idx] = auth_timeout - end - end - end - end - end + timeouts = build_prelude_timeouts(prelude, config.auth_timeout) if id prelude << ["CLIENT", "SETNAME", id] @@ -872,6 +860,23 @@ def connect raise end end + + # Build the per-command timeouts for the connection prelude. + # Only AUTH-related steps should be bounded by auth_timeout. + # Returns nil if no timeout applies so downstream can skip passing it. + def build_prelude_timeouts(prelude, auth_timeout) + return nil unless auth_timeout + + timeouts = Array.new(prelude.size) + prelude.each_with_index do |command, index| + next if !command || command.empty? + name = command.first + if name == "AUTH" || (name == "HELLO" && command.include?("AUTH")) + timeouts[index] = auth_timeout + end + end + timeouts + end end require "redis_client/pooled" diff --git a/lib/redis_client/config.rb b/lib/redis_client/config.rb index 7b0be2a..cf3a04d 100644 --- a/lib/redis_client/config.rb +++ b/lib/redis_client/config.rb @@ -56,6 +56,7 @@ def initialize( @read_timeout = read_timeout @write_timeout = write_timeout @auth_timeout = auth_timeout + @auth_timeout = nil if @auth_timeout && @auth_timeout <= 0 @driver = driver ? RedisClient.driver(driver) : RedisClient.default_driver From 416117519431fad766c46e4326518b396c29fe33 Mon Sep 17 00:00:00 2001 From: Tim Chepeleff Date: Thu, 20 Nov 2025 11:29:41 -0500 Subject: [PATCH 3/4] Add auth timeout --- lib/redis_client.rb | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/lib/redis_client.rb b/lib/redis_client.rb index 1880d33..6ad25d9 100644 --- a/lib/redis_client.rb +++ b/lib/redis_client.rb @@ -867,15 +867,30 @@ def connect def build_prelude_timeouts(prelude, auth_timeout) return nil unless auth_timeout - timeouts = Array.new(prelude.size) - prelude.each_with_index do |command, index| - next if !command || command.empty? - name = command.first - if name == "AUTH" || (name == "HELLO" && command.include?("AUTH")) - timeouts[index] = auth_timeout + auth_seen = false + timeouts = prelude.map do |command| + if auth_command?(command) + auth_seen = true + auth_timeout + else + nil end end - timeouts + + auth_seen ? timeouts : nil + end + + def auth_command?(command) + return false unless command&.any? + + case command.first + when "AUTH" + true + when "HELLO" + command.size >= 3 && command[2] == "AUTH" + else + false + end end end From 83bbcf577b806b6300a4f4364649049a5752f2a1 Mon Sep 17 00:00:00 2001 From: Tim Chepeleff Date: Fri, 21 Nov 2025 10:52:32 -0500 Subject: [PATCH 4/4] Address comments --- README.md | 14 +++++---- lib/redis_client.rb | 43 ++++++++++++++------------- lib/redis_client/middlewares.rb | 33 ++++++++++++++++++++ test/redis_client/middlewares_test.rb | 30 +++++++++++++++++++ test/redis_client_test.rb | 12 ++++---- 5 files changed, 100 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index afc161e..422a809 100644 --- a/README.md +++ b/README.md @@ -419,17 +419,21 @@ module MyGlobalRedisInstrumentation MyMonitoringService.instrument("redis.connect") { super } end - def call(command, redis_config) + def call(command, redis_config, context = nil) MyMonitoringService.instrument("redis.query") { super } end - def call_pipelined(commands, redis_config) + def call_pipelined(commands, redis_config, context = nil) MyMonitoringService.instrument("redis.pipeline") { super } end end RedisClient.register(MyGlobalRedisInstrumentation) ``` +Middleware callbacks can optionally accept a third `context` Hash. When present it carries extra information about the current stage of the client. +For instance, during the connection prelude the context includes `stage: :connection_prelude` and the underlying `connection`, which allows a middleware +to temporarily tweak socket timeouts around the initial `AUTH/HELLO` handshake without affecting other commands. + Note that `RedisClient.register` is global and apply to all `RedisClient` instances. To add middlewares to only a single client, you can provide them when creating the config: @@ -447,11 +451,11 @@ module MyGlobalRedisInstrumentation MyMonitoringService.instrument("redis.connect", tags: redis_config.custom[:tags]) { super } end - def call(command, redis_config) + def call(command, redis_config, context = nil) MyMonitoringService.instrument("redis.query", tags: redis_config.custom[:tags]) { super } end - def call_pipelined(commands, redis_config) + def call_pipelined(commands, redis_config, context = nil) MyMonitoringService.instrument("redis.pipeline", tags: redis_config.custom[:tags]) { super } end end @@ -469,7 +473,7 @@ In many cases you may want to ignore retriable errors, or report them differentl ```ruby module MyGlobalRedisInstrumentation - def call(command, redis_config) + def call(command, redis_config, context = nil) super rescue RedisClient::Error => error if error.final? diff --git a/lib/redis_client.rb b/lib/redis_client.rb index 6ad25d9..bf7238a 100644 --- a/lib/redis_client.rb +++ b/lib/redis_client.rb @@ -326,7 +326,7 @@ def pubsub def measure_round_trip_delay ensure_connected do |connection| - @middlewares.call(["PING"], config) do + @middlewares.call_with_context(["PING"], config) do connection.measure_round_trip_delay end end @@ -335,7 +335,7 @@ def measure_round_trip_delay def call(*command, **kwargs) command = @command_builder.generate(command, kwargs) result = ensure_connected do |connection| - @middlewares.call(command, config) do + @middlewares.call_with_context(command, config) do connection.call(command, nil) end end @@ -350,7 +350,7 @@ def call(*command, **kwargs) def call_v(command) command = @command_builder.generate(command) result = ensure_connected do |connection| - @middlewares.call(command, config) do + @middlewares.call_with_context(command, config) do connection.call(command, nil) end end @@ -365,7 +365,7 @@ def call_v(command) def call_once(*command, **kwargs) command = @command_builder.generate(command, kwargs) result = ensure_connected(retryable: false) do |connection| - @middlewares.call(command, config) do + @middlewares.call_with_context(command, config) do connection.call(command, nil) end end @@ -380,7 +380,7 @@ def call_once(*command, **kwargs) def call_once_v(command) command = @command_builder.generate(command) result = ensure_connected(retryable: false) do |connection| - @middlewares.call(command, config) do + @middlewares.call_with_context(command, config) do connection.call(command, nil) end end @@ -396,7 +396,7 @@ def blocking_call(timeout, *command, **kwargs) command = @command_builder.generate(command, kwargs) error = nil result = ensure_connected do |connection| - @middlewares.call(command, config) do + @middlewares.call_with_context(command, config) do connection.call(command, timeout) end rescue ReadTimeoutError => error @@ -416,7 +416,7 @@ def blocking_call_v(timeout, command) command = @command_builder.generate(command) error = nil result = ensure_connected do |connection| - @middlewares.call(command, config) do + @middlewares.call_with_context(command, config) do connection.call(command, timeout) end rescue ReadTimeoutError => error @@ -490,7 +490,7 @@ def pipelined(exception: true) else results = ensure_connected(retryable: pipeline._retryable?) do |connection| commands = pipeline._commands - @middlewares.call_pipelined(commands, config) do + @middlewares.call_pipelined_with_context(commands, config) do connection.call_pipelined(commands, pipeline._timeouts, exception: exception) end end @@ -510,7 +510,7 @@ def multi(watch: nil, &block) begin if transaction = build_transaction(&block) commands = transaction._commands - results = @middlewares.call_pipelined(commands, config) do + results = @middlewares.call_pipelined_with_context(commands, config) do connection.call_pipelined(commands, nil) end.last else @@ -529,7 +529,7 @@ def multi(watch: nil, &block) else ensure_connected(retryable: transaction._retryable?) do |connection| commands = transaction._commands - @middlewares.call_pipelined(commands, config) do + @middlewares.call_pipelined_with_context(commands, config) do connection.call_pipelined(commands, nil) end.last end @@ -805,13 +805,14 @@ def raw_connection def connect @pid = PIDCache.pid + connect_context = { stage: :connect } if @raw_connection&.revalidate - @middlewares.connect(config) do + @middlewares.connect_with_context(config, connect_context) do @raw_connection.reconnect end else - @raw_connection = @middlewares.connect(config) do + @raw_connection = @middlewares.connect_with_context(config, connect_context) do config.driver.new( config, connect_timeout: connect_timeout, @@ -830,17 +831,19 @@ def connect timeouts << nil if timeouts end - # The connection prelude is deliberately not sent to Middlewares + prelude_context = { stage: :connection_prelude, connection: @raw_connection } + + # The connection prelude goes through middlewares with a dedicated context. if config.sentinel? prelude << ["ROLE"] timeouts << nil if timeouts - role, = @middlewares.call_pipelined(prelude, config) do + role, = @middlewares.call_pipelined_with_context(prelude, config, prelude_context) do @raw_connection.call_pipelined(prelude, timeouts).last end config.check_role!(role) else unless prelude.empty? - @middlewares.call_pipelined(prelude, config) do + @middlewares.call_pipelined_with_context(prelude, config, prelude_context) do @raw_connection.call_pipelined(prelude, timeouts) end end @@ -869,12 +872,10 @@ def build_prelude_timeouts(prelude, auth_timeout) auth_seen = false timeouts = prelude.map do |command| - if auth_command?(command) - auth_seen = true - auth_timeout - else - nil - end + next unless auth_command?(command) + + auth_seen = true + auth_timeout end auth_seen ? timeouts : nil diff --git a/lib/redis_client/middlewares.rb b/lib/redis_client/middlewares.rb index f090bd2..618e1cd 100644 --- a/lib/redis_client/middlewares.rb +++ b/lib/redis_client/middlewares.rb @@ -16,6 +16,39 @@ def call(command, _config) yield command end alias_method :call_pipelined, :call + + # These helpers keep backward compatibility with two-argument middlewares + # while allowing newer ones to accept a third `context` parameter. + def connect_with_context(config, context = nil, &block) + invoke_with_optional_context(:connect, [config], context, &block) + end + + def call_with_context(command, config, context = nil, &block) + invoke_with_optional_context(:call, [command, config], context, &block) + end + + def call_pipelined_with_context(commands, config, context = nil, &block) + invoke_with_optional_context(:call_pipelined, [commands, config], context, &block) + end + + private + + def invoke_with_optional_context(method_name, args, context, &block) + method_obj = method(method_name) + if context && accepts_extra_positional_arg?(method_obj, args.length) + method_obj.call(*args, context, &block) + else + method_obj.call(*args, &block) + end + end + + def accepts_extra_positional_arg?(method_obj, required_args) + parameters = method_obj.parameters + return true if parameters.any? { |type, _| type == :rest } + + positional_count = parameters.count { |type, _| type == :req || type == :opt } + positional_count >= (required_args + 1) + end end class Middlewares < BasicMiddleware diff --git a/test/redis_client/middlewares_test.rb b/test/redis_client/middlewares_test.rb index f3f0374..3b0ba14 100644 --- a/test/redis_client/middlewares_test.rb +++ b/test/redis_client/middlewares_test.rb @@ -223,12 +223,42 @@ def call_pipelined(commands, _config, &_) end end + module PreludeContextMiddleware + class << self + attr_accessor :contexts, :client + end + @contexts = [] + + def initialize(client) + super + PreludeContextMiddleware.client = client + end + + def call_pipelined(commands, config, context = nil, &block) + PreludeContextMiddleware.contexts << context if context + super + end + end + def test_instance_middleware second_client = new_client(middlewares: [DummyMiddleware]) assert_equal ["GET", "2"], second_client.call("GET", 2) assert_equal([["GET", "2"]], second_client.pipelined { |p| p.call("GET", 2) }) end + def test_prelude_context_is_exposed + client = new_client(middlewares: [PreludeContextMiddleware]) + client.call("PING") + + context = PreludeContextMiddleware.contexts.find { |ctx| ctx && ctx[:stage] == :connection_prelude } + refute_nil context + assert_equal :connection_prelude, context[:stage] + refute_nil context[:connection] + assert_kind_of RedisClient, PreludeContextMiddleware.client + ensure + PreludeContextMiddleware.contexts.clear + end + private def assert_call(call) diff --git a/test/redis_client_test.rb b/test/redis_client_test.rb index edee862..0bdcf79 100644 --- a/test/redis_client_test.rb +++ b/test/redis_client_test.rb @@ -20,7 +20,7 @@ def test_auth_timeout_applied_resp3 class << self attr_accessor :last_timeouts end - def call_pipelined(commands, timeouts, exception: true) + def call_pipelined(commands, timeouts, exception: true) # rubocop:disable Lint/UnusedMethodArgument self.class.last_timeouts = timeouts Array.new(commands.size, "OK") end @@ -30,7 +30,7 @@ def call_pipelined(commands, timeouts, exception: true) username: "user", password: "pass", auth_timeout: 0.123, - protocol: 3 + protocol: 3, ) client.call("PING") assert_equal [0.123], capturing_driver.last_timeouts @@ -41,7 +41,7 @@ def test_auth_timeout_applied_resp2 class << self attr_accessor :last_timeouts end - def call_pipelined(commands, timeouts, exception: true) + def call_pipelined(commands, timeouts, exception: true) # rubocop:disable Lint/UnusedMethodArgument self.class.last_timeouts = timeouts Array.new(commands.size, "OK") end @@ -51,7 +51,7 @@ def call_pipelined(commands, timeouts, exception: true) username: "user", password: "pass", auth_timeout: 0.456, - protocol: 2 + protocol: 2, ) client.call("PING") assert_equal [0.456], capturing_driver.last_timeouts @@ -62,7 +62,7 @@ def test_auth_timeout_only_applies_to_auth_commands class << self attr_accessor :last_timeouts end - def call_pipelined(commands, timeouts, exception: true) + def call_pipelined(commands, timeouts, exception: true) # rubocop:disable Lint/UnusedMethodArgument self.class.last_timeouts = timeouts Array.new(commands.size, "OK") end @@ -73,7 +73,7 @@ def call_pipelined(commands, timeouts, exception: true) password: "pass", db: 5, auth_timeout: 0.789, - protocol: 2 + protocol: 2, ) client.call("PING") assert_equal [0.789, nil], capturing_driver.last_timeouts