diff --git a/options.go b/options.go
index dff52ae8b..39cb1714d 100644
--- a/options.go
+++ b/options.go
@@ -147,6 +147,11 @@ type Options struct {
 
 	// Add suffix to client name. Default is empty.
 	IdentitySuffix string
+
+	// Use connections from the pool instead of creating new ones. Note that after use these connections will not be
+	// returned to the pool. Useful for managing the total Redis connection limit for a mix of Pubsub & other commands.
+	// Applies only to non-cluster clients. Default is false.
+	PubsubFromPool bool
 }
 
 func (opt *Options) init() {
diff --git a/pubsub.go b/pubsub.go
index 5df537c42..764cb23a7 100644
--- a/pubsub.go
+++ b/pubsub.go
@@ -12,6 +12,9 @@ import (
 	"github.com/redis/go-redis/v9/internal/proto"
 )
 
+type PubsubNewConnFunc func(ctx context.Context, channels []string) (*pool.Conn, error)
+type PubsubCloseConnFunc func(*pool.Conn) error
+
 // PubSub implements Pub/Sub commands as described in
 // http://redis.io/topics/pubsub. Message receiving is NOT safe
 // for concurrent use by multiple goroutines.
@@ -21,8 +24,8 @@ type PubSub struct {
 	opt *Options
 
-	newConn   func(ctx context.Context, channels []string) (*pool.Conn, error)
-	closeConn func(*pool.Conn) error
+	newConn   PubsubNewConnFunc
+	closeConn PubsubCloseConnFunc
 
 	mu sync.Mutex
 	cn *pool.Conn
diff --git a/pubsub_test.go b/pubsub_test.go
index a76100659..553f65b0a 100644
--- a/pubsub_test.go
+++ b/pubsub_test.go
@@ -567,4 +567,104 @@ var _ = Describe("PubSub", func() {
 		Expect(msg.Channel).To(Equal("mychannel"))
 		Expect(msg.Payload).To(Equal(text))
 	})
+
+	It("should not use connections from pool", func() {
+		statsBefore := client.PoolStats()
+
+		pubsub := client.Subscribe(ctx, "mychannel")
+		defer pubsub.Close()
+
+		stats := client.PoolStats()
+		// A connection has been created
+		Expect(stats.TotalConns - statsBefore.TotalConns).To(Equal(uint32(1)))
+		// But it's not taken from the pool
+		poolFetchesBefore := statsBefore.Hits + statsBefore.Misses
+		poolFetchesAfter := stats.Hits + stats.Misses
+		Expect(poolFetchesAfter - poolFetchesBefore).To(Equal(uint32(0)))
+
+		pubsub.Close()
+
+		stats = client.PoolStats()
+		// The connection no longer exists
+		Expect(stats.TotalConns - statsBefore.TotalConns).To(Equal(uint32(0)))
+		Expect(stats.IdleConns - statsBefore.IdleConns).To(Equal(uint32(0)))
+	})
+})
+
+var _ = Describe("PubSub with PubsubFromPool set", func() {
+	var client *redis.Client
+
+	BeforeEach(func() {
+		opt := redisOptions()
+		opt.MinIdleConns = 0
+		opt.ConnMaxLifetime = 0
+		opt.PubsubFromPool = true
+		// a zero PoolTimeout falls back to the default, so use a small value instead
+		opt.PoolTimeout = time.Microsecond
+		client = redis.NewClient(opt)
+		Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred())
+	})
+
+	AfterEach(func() {
+		Expect(client.Close()).NotTo(HaveOccurred())
+	})
+
+	It("should use connection from pool", func() {
+		statsBefore := client.PoolStats()
+
+		pubsub := client.Subscribe(ctx, "mychannel")
+		defer pubsub.Close()
+
+		stats := client.PoolStats()
+		// A connection has been taken from the pool
+		Expect(stats.Hits - statsBefore.Hits).To(Equal(uint32(1)))
+		statsDuring := client.PoolStats()
+
+		pubsub.Close()
+
+		stats = client.PoolStats()
+		// It's not returned to the idle pool ..
+		Expect(statsDuring.IdleConns - stats.IdleConns).To(Equal(uint32(0)))
+		// .. and has been terminated
+		Expect(statsDuring.TotalConns - stats.TotalConns).To(Equal(uint32(1)))
+	})
+
+	It("should respect pool size limit", func() {
+		poolSize := client.Options().PoolSize
+		statsBefore := client.PoolStats()
+
+		var pubsubs []*redis.PubSub
+		for i := 0; i < poolSize; i++ {
+			pubsub := client.Subscribe(ctx, "mychannel")
+			defer pubsub.Close()
+			pubsubs = append(pubsubs, pubsub)
+		}
+
+		statsDuring := client.PoolStats()
+		poolFetchesBefore := statsBefore.Hits + statsBefore.Misses
+		poolFetchesAfter := statsDuring.Hits + statsDuring.Misses
+
+		// A total of poolSize connections should have been taken from the pool (new or existing)
+		Expect(poolFetchesAfter - poolFetchesBefore).To(Equal(uint32(poolSize)))
+
+		// The next pubsub connection should fail to connect (it times out waiting for the pool)
+		extraPubsub := client.Subscribe(ctx, "mychannel")
+		defer extraPubsub.Close()
+		Expect(client.PoolStats().Timeouts - statsDuring.Timeouts).To(Equal(uint32(1)))
+
+		// Retries should time out as well
+		err := extraPubsub.Ping(ctx)
+		Expect(err).To(MatchError(ContainSubstring("connection pool timeout")))
+		Expect(client.PoolStats().Timeouts - statsDuring.Timeouts).To(Equal(uint32(2)))
+
+		for _, pubsub := range pubsubs {
+			pubsub.Close()
+		}
+
+		stats := client.PoolStats()
+		// Connections are not returned to the idle pool ..
+		Expect(statsDuring.IdleConns - stats.IdleConns).To(Equal(uint32(0)))
+		// .. and have been terminated
+		Expect(statsDuring.TotalConns - stats.TotalConns).To(Equal(uint32(poolSize)))
+	})
 })
diff --git a/redis.go b/redis.go
index 4dd862b84..185082332 100644
--- a/redis.go
+++ b/redis.go
@@ -199,6 +199,9 @@ type baseClient struct {
 	opt      *Options
 	connPool pool.Pooler
 
+	pubsubNewConn   PubsubNewConnFunc
+	pubsubCloseConn PubsubCloseConnFunc
+
 	onClose func() error // hook called when client is closed
 }
 
@@ -368,6 +371,13 @@ func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error)
 	}
 }
 
+func (c *baseClient) removeConn(ctx context.Context, cn *pool.Conn, err error) {
+	if c.opt.Limiter != nil {
+		c.opt.Limiter.ReportResult(err)
+	}
+	c.connPool.Remove(ctx, cn, err)
+}
+
 func (c *baseClient) withConn(
 	ctx context.Context, fn func(context.Context, *pool.Conn) error,
 ) error {
@@ -649,6 +659,28 @@ func (c *Client) init() {
 		pipeline:   c.baseClient.processPipeline,
 		txPipeline: c.baseClient.processTxPipeline,
 	})
+
+	if c.opt.PubsubFromPool {
+		// Take connections from the pool and remove them from the pool afterwards. (Pubsub & other connections are
+		// managed together.)
+		c.pubsubNewConn = func(ctx context.Context, channels []string) (*pool.Conn, error) {
+			return c.getConn(ctx)
+		}
+		c.pubsubCloseConn = func(conn *pool.Conn) error {
+			c.removeConn(context.TODO(), conn, nil)
+			return nil
+		}
+	} else {
+		// Make a brand new connection from the pool and close it afterwards. (Pubsub & other connections are managed
+		// independently, other than that a pubsub connection can no longer be created once the pool is full.)
+		c.pubsubNewConn = func(ctx context.Context, channels []string) (*pool.Conn, error) {
+			return c.newConn(ctx)
+		}
+		// wrapping in closure since pool has not been initialised yet
+		c.pubsubCloseConn = func(conn *pool.Conn) error {
+			return c.connPool.CloseConn(conn)
+		}
+	}
 }
 
 func (c *Client) WithTimeout(timeout time.Duration) *Client {
@@ -720,10 +752,8 @@ func (c *Client) pubSub() *PubSub {
 	pubsub := &PubSub{
 		opt: c.opt,
 
-		newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
-			return c.newConn(ctx)
-		},
-		closeConn: c.connPool.CloseConn,
+		newConn:   c.pubsubNewConn,
+		closeConn: c.pubsubCloseConn,
 	}
 	pubsub.init()
 	return pubsub
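For context, a minimal usage sketch of the new option with the patch applied (not part of the diff itself; the address, pool size and channel name below are illustrative):

package main

import (
	"context"
	"fmt"

	"github.com/redis/go-redis/v9"
)

func main() {
	ctx := context.Background()

	// With PubsubFromPool enabled, each Subscribe draws a connection from the
	// regular pool (counting against PoolSize) and removes it on Close instead
	// of returning it to the idle pool.
	client := redis.NewClient(&redis.Options{
		Addr:           "localhost:6379", // illustrative address
		PoolSize:       10,
		PubsubFromPool: true,
	})
	defer client.Close()

	pubsub := client.Subscribe(ctx, "mychannel")
	defer pubsub.Close()

	msg, err := pubsub.ReceiveMessage(ctx)
	if err != nil {
		panic(err)
	}
	fmt.Println(msg.Channel, msg.Payload)
}

Because pubsubCloseConn goes through removeConn (connPool.Remove) rather than returning the connection, long-lived subscribers keep counting toward PoolSize while open and free their slot only on Close.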