diff --git a/Makefile b/Makefile index 24bcf8c..a26d72d 100644 --- a/Makefile +++ b/Makefile @@ -11,7 +11,7 @@ test_short: go test ./... -short test_race: - go test ./... -short -race + go test ./... -race test_stress: go test -v -tags=stress -timeout=45m ./... diff --git a/go.mod b/go.mod index bd919ba..c2bab3d 100644 --- a/go.mod +++ b/go.mod @@ -9,20 +9,27 @@ require ( github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.8.1 github.com/vmihailenco/msgpack v4.0.4+incompatible + github.com/wk8/go-error-buffer v0.0.0-20230515211523-1bb61b128a10 ) require ( + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/friendsofgo/errors v0.9.2 // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/google/uuid v1.3.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/lithammer/shortuuid/v3 v3.0.7 // indirect + github.com/mailru/easyjson v0.7.7 // indirect github.com/oklog/ulid v1.3.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.7 // indirect golang.org/x/net v0.5.0 // indirect + golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 // indirect google.golang.org/appengine v1.6.7 // indirect google.golang.org/protobuf v1.28.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 07e799d..fee7d32 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,10 @@ github.com/Rican7/retry v0.3.1 h1:scY4IbO8swckzoA/11HgBwaZRJEyY9vaNJshcdhp1Mc= github.com/Rican7/retry v0.3.1/go.mod h1:CxSDrhAyXmTMeEuRAnArMu1FHu48vtfjLREWqVl7Vw0= github.com/ThreeDotsLabs/watermill v1.2.0 h1:TU3TML1dnQ/ifK09F2+4JQk2EKhmhXe7Qv7eb5ZpTS8= github.com/ThreeDotsLabs/watermill v1.2.0/go.mod h1:IuVxGk/kgCN0cex2S94BLglUiB0PwOm8hbUhm6g2Nx4= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -9,6 +13,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/friendsofgo/errors v0.9.2 h1:X6NYxef4efCBdwI7BgS820zFaN7Cphrmb+Pljdzjtgk= +github.com/friendsofgo/errors v0.9.2/go.mod h1:yCvFW5AkDIL9qn7suHVLiI/gH228n7PC4Pn44IGoTOI= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= @@ -26,8 +32,11 @@ 
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/lithammer/shortuuid/v3 v3.0.7 h1:trX0KTHy4Pbwo/6ia8fscyHoGA+mf1jWbPJVuvyJQQ8= github.com/lithammer/shortuuid/v3 v3.0.7/go.mod h1:vMk8ke37EmiewwolSO1NLW8vP4ZaKlRuDIi8tWWmAts= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= @@ -46,6 +55,10 @@ github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKs github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/vmihailenco/msgpack v4.0.4+incompatible h1:dSLoQfGFAo3F6OoNhwUmLwVgaUXK79GlxNBwueZn0xI= github.com/vmihailenco/msgpack v4.0.4+incompatible/go.mod h1:fy3FlTQTDXWkZ7Bh6AcGMlsjHatGryHQYUTf1ShIgkk= +github.com/wk8/go-error-buffer v0.0.0-20230515211523-1bb61b128a10 h1:YX7AEbO6/OLivzT2piLGjZyxv30cmMs2weQ587w/JQI= +github.com/wk8/go-error-buffer v0.0.0-20230515211523-1bb61b128a10/go.mod h1:N0jirnKcRGOtdZiyUpIq/yxrLNzZL7snWCjqTayQTaQ= +github.com/wk8/go-ordered-map/v2 v2.1.7 h1:aUZ1xBMdbvY8wnNt77qqo4nyT3y0pX4Usat48Vm+hik= +github.com/wk8/go-ordered-map/v2 v2.1.7/go.mod h1:9Xvgm2mV2kSq2SAm0Y608tBmu8akTzI7c2bz7/G7ZN4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.5.0 h1:GyT4nK/YDHSqa1c4753ouYCDajOYKTja9Xb/OHtgvSw= @@ -56,6 +69,7 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.6.0 h1:3XmdazWV+ubf7QgHSTWeykHOci5oeekaGJBLkrkaw4k= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= diff --git a/pkg/redisstream/publisher.go b/pkg/redisstream/publisher.go index 1c45879..8d74208 100644 --- a/pkg/redisstream/publisher.go +++ b/pkg/redisstream/publisher.go @@ -2,21 +2,16 @@ package redisstream import ( "context" - "sync" - "github.com/ThreeDotsLabs/watermill" "github.com/ThreeDotsLabs/watermill/message" - "github.com/pkg/errors" "github.com/go-redis/redis/v8" + "github.com/pkg/errors" ) type Publisher struct { config PublisherConfig client redis.UniversalClient logger watermill.LoggerAdapter - - closed bool - closeMutex sync.Mutex } // NewPublisher creates a new redis stream Publisher. 
@@ -35,7 +30,6 @@ func NewPublisher(config PublisherConfig, logger watermill.LoggerAdapter) (*Publ config: config, client: config.Client, logger: logger, - closed: false, }, nil } @@ -69,10 +63,6 @@ func (c *PublisherConfig) Validate() error { // Publish is blocking and wait for redis response // When one of messages delivery fails - function is interrupted. func (p *Publisher) Publish(topic string, msgs ...*message.Message) error { - if p.closed { - return errors.New("publisher closed") - } - logFields := make(watermill.LogFields, 3) logFields["topic"] = topic @@ -108,17 +98,5 @@ func (p *Publisher) Publish(topic string, msgs ...*message.Message) error { } func (p *Publisher) Close() error { - p.closeMutex.Lock() - defer p.closeMutex.Unlock() - - if p.closed { - return nil - } - p.closed = true - - if err := p.client.Close(); err != nil { - return err - } - return nil } diff --git a/pkg/redisstream/pubsub_test.go b/pkg/redisstream/pubsub_test.go index 75df910..18edabb 100644 --- a/pkg/redisstream/pubsub_test.go +++ b/pkg/redisstream/pubsub_test.go @@ -18,6 +18,9 @@ import ( "github.com/stretchr/testify/require" ) +// should be long enough to be robust even for CI boxes +const testInterval = 250 * time.Millisecond + func redisClient() (redis.UniversalClient, error) { client := redis.NewClient(&redis.Options{ Addr: "127.0.0.1:6379", @@ -64,8 +67,7 @@ func createPubSubWithConsumerGroup(t *testing.T, consumerGroup string) (message. Client: redisClientOrFail(t), Consumer: watermill.NewShortUUID(), ConsumerGroup: consumerGroup, - BlockTime: 10 * time.Millisecond, - ClaimInterval: 3 * time.Second, + ClaimInterval: 10 * time.Millisecond, MaxIdleTime: 5 * time.Second, }) } @@ -129,79 +131,7 @@ func TestSubscriber(t *testing.T) { require.NoError(t, subscriber.Close()) } -func TestFanOut(t *testing.T) { - topic := watermill.NewShortUUID() - - subscriber1, err := NewSubscriber( - SubscriberConfig{ - Client: redisClientOrFail(t), - Consumer: watermill.NewShortUUID(), - ConsumerGroup: "", - }, - watermill.NewStdLogger(true, false), - ) - require.NoError(t, err) - - subscriber2, err := NewSubscriber( - SubscriberConfig{ - Client: redisClientOrFail(t), - Consumer: watermill.NewShortUUID(), - ConsumerGroup: "", - }, - watermill.NewStdLogger(true, false), - ) - require.NoError(t, err) - - publisher, err := NewPublisher( - PublisherConfig{ - Client: redisClientOrFail(t), - }, - watermill.NewStdLogger(false, false), - ) - require.NoError(t, err) - for i := 0; i < 10; i++ { - require.NoError(t, publisher.Publish(topic, message.NewMessage(watermill.NewShortUUID(), []byte("test"+strconv.Itoa(i))))) - } - - messages1, err := subscriber1.Subscribe(context.Background(), topic) - require.NoError(t, err) - messages2, err := subscriber2.Subscribe(context.Background(), topic) - require.NoError(t, err) - - // wait for initial XREAD before publishing messages to avoid message loss - time.Sleep(2 * DefaultBlockTime) - for i := 10; i < 50; i++ { - require.NoError(t, publisher.Publish(topic, message.NewMessage(watermill.NewShortUUID(), []byte("test"+strconv.Itoa(i))))) - } - - for i := 10; i < 50; i++ { - msg := <-messages1 - if msg == nil { - t.Fatal("msg nil") - } - t.Logf("subscriber 1: %v %v %v", msg.UUID, msg.Metadata, string(msg.Payload)) - require.Equal(t, string(msg.Payload), "test"+strconv.Itoa(i)) - msg.Ack() - } - for i := 10; i < 50; i++ { - msg := <-messages2 - if msg == nil { - t.Fatal("msg nil") - } - t.Logf("subscriber 2: %v %v %v", msg.UUID, msg.Metadata, string(msg.Payload)) - require.Equal(t, 
string(msg.Payload), "test"+strconv.Itoa(i)) - msg.Ack() - } - - require.NoError(t, publisher.Close()) - require.NoError(t, subscriber1.Close()) - require.NoError(t, subscriber2.Close()) -} - func TestClaimIdle(t *testing.T) { - // should be long enough to be robust even for CI boxes - testInterval := 250 * time.Millisecond - topic := watermill.NewShortUUID() consumerGroup := watermill.NewShortUUID() testLogger := watermill.NewStdLogger(true, false) @@ -228,7 +158,7 @@ func TestClaimIdle(t *testing.T) { // handles loop variables in function literals subID := subscriberID - suscriber, err := NewSubscriber( + subscriber, err := NewSubscriber( SubscriberConfig{ Client: redisClientOrFail(t), Consumer: strconv.Itoa(subID), @@ -264,7 +194,7 @@ func TestClaimIdle(t *testing.T) { router.AddNoPublisherHandler( strconv.Itoa(subID), topic, - suscriber, + subscriber, func(msg *message.Message) error { msgID, err := strconv.Atoi(string(msg.Payload)) require.NoError(t, err) @@ -343,3 +273,87 @@ func TestClaimIdle(t *testing.T) { assert.GreaterOrEqual(t, nMsgsWithRetries, 3) } + +// this test checks that even workers that are idle for a while will +// try to claim messages that have been idle for too long, which is not covered by TestClaimIdle +func TestMessagesGetClaimedEvenByIdleWorkers(t *testing.T) { + topic := watermill.NewShortUUID() + consumerGroup := watermill.NewShortUUID() + testLogger := watermill.NewStdLogger(true, false) + + router, err := message.NewRouter(message.RouterConfig{ + CloseTimeout: testInterval, + }, testLogger) + require.NoError(t, err) + + receivedCh := make(chan int) + payload := message.Payload("coucou toi") + + // let's create a few subscribers, that just wait for a while each time they receive anything + nSubscribers := 8 + for subscriberID := 0; subscriberID < nSubscribers; subscriberID++ { + subID := subscriberID + + subscriber, err := NewSubscriber( + SubscriberConfig{ + Client: redisClientOrFail(t), + Consumer: strconv.Itoa(subID), + ConsumerGroup: consumerGroup, + ClaimInterval: testInterval, + MaxIdleTime: testInterval, + }, + testLogger, + ) + require.NoError(t, err) + + router.AddNoPublisherHandler( + strconv.Itoa(subID), + topic, + subscriber, + func(msg *message.Message) error { + assert.Equal(t, msg.Payload, payload) + + receivedCh <- subID + time.Sleep(time.Duration(nSubscribers+2) * testInterval) + + return nil + }, + ) + } + + runCtx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + require.NoError(t, router.Run(runCtx)) + }() + + // now let's push only one message + publisher, err := NewPublisher( + PublisherConfig{ + Client: redisClientOrFail(t), + }, + testLogger, + ) + require.NoError(t, err) + msg := message.NewMessage(watermill.NewShortUUID(), payload) + require.NoError(t, publisher.Publish(topic, msg)) + + // it should get retried by all subscribers + seenSubscribers := make([]bool, nSubscribers) + for receivedCount := 0; receivedCount != nSubscribers; receivedCount++ { + select { + case subscriberID := <-receivedCh: + assert.False(t, seenSubscribers[subscriberID], "subscriber %d seen more than once", subscriberID) + seenSubscribers[subscriberID] = true + + case <-time.After(time.Duration(nSubscribers) * 2 * testInterval): + t.Fatalf("timed out waiting for new messages, only received %d messages", receivedCount) + } + } + + // shut everything down + cancel() + wg.Wait() +} diff --git a/pkg/redisstream/subscriber.go b/pkg/redisstream/subscriber.go index dd1269a..7a5b9c8 100644 --- 
a/pkg/redisstream/subscriber.go +++ b/pkg/redisstream/subscriber.go @@ -2,7 +2,6 @@ package redisstream import ( "context" - "net" "sync" "time" @@ -11,59 +10,26 @@ import ( "github.com/ThreeDotsLabs/watermill/message" "github.com/go-redis/redis/v8" "github.com/pkg/errors" + errorbuffer "github.com/wk8/go-error-buffer" ) const ( - groupStartid = ">" - redisBusyGroup = "BUSYGROUP Consumer Group name already exists" -) - -const ( - // NoSleep can be set to SubscriberConfig.NackResendSleep - NoSleep time.Duration = -1 - - DefaultBlockTime = time.Millisecond * 100 + DefaultBlockTime = 100 * time.Millisecond - // How often to check for dead workers to claim pending messages from - DefaultClaimInterval = time.Second * 5 + // DefaultClaimInterval is how often to check for dead workers to claim pending messages from + DefaultClaimInterval = 5 * time.Second DefaultClaimBatchSize = int64(100) - // Default max idle time for pending message. + // DefaultMaxIdleTime is the default max idle time for pending message. // After timeout, the message will be claimed and its idle consumer will be removed from consumer group - DefaultMaxIdleTime = time.Second * 60 -) - -type Subscriber struct { - config SubscriberConfig - client redis.UniversalClient - logger watermill.LoggerAdapter - closing chan struct{} - subscribersWg sync.WaitGroup - - closed bool - closeMutex sync.Mutex -} - -// NewSubscriber creates a new redis stream Subscriber -func NewSubscriber(config SubscriberConfig, logger watermill.LoggerAdapter) (*Subscriber, error) { - config.setDefaults() + DefaultMaxIdleTime = time.Minute - if err := config.Validate(); err != nil { - return nil, err - } + DefaultRedisErrorsMaxCount = uint(3) + DefaultRedisErrorsWindow = time.Minute - if logger == nil { - logger = &watermill.NopLogger{} - } - - return &Subscriber{ - config: config, - client: config.Client, - logger: logger, - closing: make(chan struct{}), - }, nil -} + redisBusyGroup = "BUSYGROUP Consumer Group name already exists" +) type SubscriberConfig struct { Client redis.UniversalClient @@ -72,16 +38,14 @@ type SubscriberConfig struct { // Redis stream consumer id, paired with ConsumerGroup Consumer string - // When empty, fan-out mode will be used + + // Cannot be empty ConsumerGroup string // How long after Nack message should be redelivered NackResendSleep time.Duration - // Block to wait next redis stream message - BlockTime time.Duration - - // Claim idle pending message interval + // Claim idle pending message maximum interval ClaimInterval time.Duration // How many pending messages are claimed at most each claim interval @@ -90,10 +54,11 @@ type SubscriberConfig struct { // How long should we treat a consumer as offline MaxIdleTime time.Duration - // Start consumption from the specified message ID - // When using "0", the consumer group will consume from the very first message - // When using "$", the consumer group will consume from the latest message - OldestId string + // RedisErrorsMaxCount is how many redis errors in a RedisErrorsWindow duration is too many + // When there are that many errors in the given duration, the subscriber will give up and close + // its output channel. 
+ RedisErrorsMaxCount uint + RedisErrorsWindow time.Duration // If this is set, it will be called to decide whether messages that // have been idle for longer than MaxIdleTime should actually be re-claimed, @@ -115,24 +80,25 @@ func (sc *SubscriberConfig) setDefaults() { if sc.Consumer == "" { sc.Consumer = watermill.NewShortUUID() } - if sc.NackResendSleep == 0 { - sc.NackResendSleep = NoSleep - } - if sc.BlockTime == 0 { - sc.BlockTime = DefaultBlockTime - } - if sc.ClaimInterval == 0 { + if sc.ClaimInterval <= 0 { sc.ClaimInterval = DefaultClaimInterval } - if sc.ClaimBatchSize == 0 { + if sc.ClaimBatchSize <= 0 { sc.ClaimBatchSize = DefaultClaimBatchSize } - if sc.MaxIdleTime == 0 { + if sc.MaxIdleTime <= 0 { sc.MaxIdleTime = DefaultMaxIdleTime } - // Consume from scratch by default - if sc.OldestId == "" { - sc.OldestId = "0" + if sc.RedisErrorsMaxCount == 0 { + sc.RedisErrorsMaxCount = DefaultRedisErrorsMaxCount + } + if sc.RedisErrorsWindow <= 0 { + sc.RedisErrorsWindow = DefaultRedisErrorsWindow + } + if sc.ShouldClaimPendingMessage == nil { + sc.ShouldClaimPendingMessage = func(_ redis.XPendingExt) bool { + return true + } } } @@ -140,18 +106,68 @@ func (sc *SubscriberConfig) Validate() error { if sc.Client == nil { return errors.New("redis client is empty") } + if sc.ConsumerGroup == "" { + return errors.New("consumer group is empty") + } + if sc.ClaimBatchSize < 2 { + return errors.New("claim batch size must be at least 2") + } + return nil } +type Subscriber struct { + config *SubscriberConfig + client redis.UniversalClient + baseLogger watermill.LoggerAdapter + + baseCtx context.Context + baseCtxCancel context.CancelFunc + topicSubscribers sync.WaitGroup +} + +var _ message.Subscriber = &Subscriber{} + +// NewSubscriber creates a new redis stream Subscriber +func NewSubscriber(config SubscriberConfig, logger watermill.LoggerAdapter) (*Subscriber, error) { + config.setDefaults() + + if err := config.Validate(); err != nil { + return nil, err + } + + if logger == nil { + logger = &watermill.NopLogger{} + } + + baseCtx, baseCtxCancel := context.WithCancel(context.Background()) + + return &Subscriber{ + config: &config, + client: config.Client, + baseLogger: logger, + baseCtx: baseCtx, + baseCtxCancel: baseCtxCancel, + }, nil +} + +type topicSubscriber struct { + config *SubscriberConfig + topic string + client redis.UniversalClient + logger watermill.LoggerAdapter + redisErrors *errorbuffer.ErrorBuffer +} + func (s *Subscriber) Subscribe(ctx context.Context, topic string) (<-chan *message.Message, error) { - s.closeMutex.Lock() - closed := s.closed - s.closeMutex.Unlock() - if closed { + if err := s.baseCtx.Err(); err != nil { return nil, errors.New("subscriber closed") } - s.subscribersWg.Add(1) + // create consumer group + if _, err := s.client.XGroupCreateMkStream(ctx, topic, s.config.ConsumerGroup, "0").Result(); err != nil && err.Error() != redisBusyGroup { + return nil, errors.Wrap(err, "unable to create redis consumer group") + } logFields := watermill.LogFields{ "provider": "redis", @@ -159,409 +175,271 @@ func (s *Subscriber) Subscribe(ctx context.Context, topic string) (<-chan *messa "consumer_group": s.config.ConsumerGroup, "consumer_uuid": s.config.Consumer, } - s.logger.Info("Subscribing to redis stream topic", logFields) - - // we don't want to have buffered channel to not consume messsage from redis stream when consumer is not consuming - output := make(chan *message.Message) - - consumeClosed, err := s.consumeMessages(ctx, topic, output, logFields) - if err != nil 
{ - s.subscribersWg.Done() - return nil, err + logger := s.baseLogger.With(logFields) + logger.Debug("Subscribing to redis stream topic", nil) + + subscriber := &topicSubscriber{ + config: s.config, + topic: topic, + client: s.client, + logger: logger, + redisErrors: errorbuffer.NewErrorBuffer(s.config.RedisErrorsMaxCount, s.config.RedisErrorsWindow), } - go func() { - <-consumeClosed - close(output) - s.subscribersWg.Done() - }() - - return output, nil -} - -func (s *Subscriber) consumeMessages(ctx context.Context, topic string, output chan *message.Message, logFields watermill.LogFields) (consumeMessageClosed chan struct{}, err error) { - s.logger.Info("Starting consuming", logFields) + output := make(chan *message.Message) - ctx, cancel := context.WithCancel(ctx) + runCtx, cancel := context.WithCancel(ctx) go func() { select { - case <-s.closing: - s.logger.Debug("Closing subscriber, cancelling consumeMessages", logFields) + case <-runCtx.Done(): + case <-s.baseCtx.Done(): cancel() - case <-ctx.Done(): - // avoid goroutine leak } }() - if s.config.ConsumerGroup != "" { - // create consumer group - if _, err := s.client.XGroupCreateMkStream(ctx, topic, s.config.ConsumerGroup, s.config.OldestId).Result(); err != nil && err.Error() != redisBusyGroup { - return nil, err - } - } - - consumeMessageClosed, err = s.consumeStreams(ctx, topic, output, logFields) - if err != nil { - s.logger.Debug( - "Starting consume failed, cancelling context", - logFields.Add(watermill.LogFields{"err": err}), - ) - cancel() - return nil, err - } - return consumeMessageClosed, nil -} - -func (s *Subscriber) consumeStreams(ctx context.Context, stream string, output chan *message.Message, logFields watermill.LogFields) (chan struct{}, error) { - messageHandler := s.createMessageHandler(output) - consumeMessageClosed := make(chan struct{}) + s.topicSubscribers.Add(1) go func() { - defer close(consumeMessageClosed) - - readChannel := make(chan *redis.XStream, 1) - go s.read(ctx, stream, readChannel, logFields) + defer s.topicSubscribers.Done() + defer cancel() + defer close(output) - for { - select { - case xs := <-readChannel: - if xs == nil { - s.logger.Debug("readStreamChannel is closed, stopping readStream", logFields) - return - } - if err := messageHandler.processMessage(ctx, xs.Stream, &xs.Messages[0], logFields); err != nil { - s.logger.Error("processMessage fail", err, logFields) - return - } - case <-s.closing: - s.logger.Debug("Subscriber is closing, stopping readStream", logFields) - return - case <-ctx.Done(): - s.logger.Debug("Ctx was cancelled, stopping readStream", logFields) - return - } + if err := subscriber.run(runCtx, output); !isContextDoneErr(err) { + logger.Error("subscriber exited with error", err, nil) } }() - return consumeMessageClosed, nil + return output, nil } -func (s *Subscriber) read(ctx context.Context, stream string, readChannel chan<- *redis.XStream, logFields watermill.LogFields) { - wg := &sync.WaitGroup{} - claimCtx, claimCancel := context.WithCancel(ctx) - defer func() { - claimCancel() - wg.Wait() - close(readChannel) - }() - var ( - streamsGroup = []string{stream, groupStartid} - - fanOutStartid = "$" - countFanOut int64 = 0 - blockTime time.Duration = 0 - - xss []redis.XStream - xs *redis.XStream - err error - ) - - if s.config.ConsumerGroup != "" { - // 1. get pending message from idle consumer - wg.Add(1) - s.claim(claimCtx, stream, readChannel, false, wg, logFields) - - // 2. 
background - wg.Add(1) - go s.claim(claimCtx, stream, readChannel, true, wg, logFields) - } +func (s *Subscriber) Close() error { + s.baseCtxCancel() + s.topicSubscribers.Wait() + s.baseLogger.Debug("Redis stream subscriber closed", nil) + return nil +} +// blocking call; never returns nil +func (s *topicSubscriber) run(ctx context.Context, output chan *message.Message) error { for { - select { - case <-s.closing: - return - case <-ctx.Done(): - return - default: - if s.config.ConsumerGroup != "" { - xss, err = s.client.XReadGroup( - ctx, - &redis.XReadGroupArgs{ - Group: s.config.ConsumerGroup, - Consumer: s.config.Consumer, - Streams: streamsGroup, - Count: 1, - Block: blockTime, - }).Result() - } else { - xss, err = s.client.XRead( - ctx, - &redis.XReadArgs{ - Streams: []string{stream, fanOutStartid}, - Count: countFanOut, - Block: blockTime, - }).Result() - } - if err == redis.Nil { - continue - } else if err != nil { - s.logger.Error("read fail", err, logFields) - } - if len(xss) < 1 || len(xss[0].Messages) < 1 { - continue - } - // update last delivered message - xs = &xss[0] - if s.config.ConsumerGroup == "" { - fanOutStartid = xs.Messages[0].ID - countFanOut = 1 - } - - blockTime = s.config.BlockTime + // first process any pending messages + redisMessage, claimedFrom, err := s.claimPendingMessage(ctx) + if err != nil { + return err + } - select { - case <-s.closing: - return - case <-ctx.Done(): - return - case readChannel <- xs: + if redisMessage == nil { + // read directly from the message queue + redisMessage, err = s.read(ctx) + if err != nil { + return err } } - } -} -func (s *Subscriber) claim(ctx context.Context, stream string, readChannel chan<- *redis.XStream, keep bool, wg *sync.WaitGroup, logFields watermill.LogFields) { - defer wg.Done() - var ( - xps []redis.XPendingExt - err error - xp redis.XPendingExt - xm []redis.XMessage - tick = time.NewTicker(s.config.ClaimInterval) - initCh = make(chan byte, 1) - ) - defer func() { - tick.Stop() - close(initCh) - }() - if !keep { // if not keep, run immediately - initCh <- 1 - } + if redisMessage == nil { + continue + } -OUTER_LOOP: - for { - select { - case <-s.closing: - return - case <-ctx.Done(): - return - case <-tick.C: - case <-initCh: + if err := s.processMessage(ctx, redisMessage, output); err != nil { + return err } - xps, err = s.client.XPendingExt(ctx, &redis.XPendingExtArgs{ - Stream: stream, + if err := s.deleteIdleConsumer(ctx, claimedFrom); err != nil { + return err + } + } +} + +// claimPendingMessage tries to claim a message from the ones that are pending, if any +// If it returns a non-nil message, it also returns which consumer it was claimed from. 
+func (s *topicSubscriber) claimPendingMessage(ctx context.Context) (*redis.XMessage, string, error) { + for startID := "0"; ; { + pendingMessages, err := s.client.XPendingExt(ctx, &redis.XPendingExtArgs{ + Stream: s.topic, Group: s.config.ConsumerGroup, Idle: s.config.MaxIdleTime, - Start: "0", + Start: startID, End: "+", Count: s.config.ClaimBatchSize, }).Result() if err != nil { - s.logger.Error( - "xpendingext fail", - err, - logFields, - ) - continue + if isContextDoneErr(err) { + return nil, "", err + } else { + s.logger.Error("xpendingext failed", err, nil) + return nil, "", s.redisErrors.Add(err) + } } - for _, xp = range xps { - shouldClaim := xp.Idle >= s.config.MaxIdleTime - if shouldClaim && s.config.ShouldClaimPendingMessage != nil { - shouldClaim = s.config.ShouldClaimPendingMessage(xp) + + for _, pendingMessage := range pendingMessages { + if !s.config.ShouldClaimPendingMessage(pendingMessage) { + continue } - if shouldClaim { - // assign the ownership of a pending message to the current consumer - xm, err = s.client.XClaim(ctx, &redis.XClaimArgs{ - Stream: stream, - Group: s.config.ConsumerGroup, - Consumer: s.config.Consumer, - // this is important: it ensures that 2 concurrent subscribers - // won't claim the same pending message at the same time - MinIdle: s.config.MaxIdleTime, - Messages: []string{xp.ID}, - }).Result() - if err != nil { - s.logger.Error( - "xclaim fail", - err, - logFields.Add(watermill.LogFields{"xp": xp}), - ) - continue OUTER_LOOP + // try to claim this message + claimed, err := s.client.XClaim(ctx, &redis.XClaimArgs{ + Stream: s.topic, + Group: s.config.ConsumerGroup, + Consumer: s.config.Consumer, + // this is important: it ensures that 2 concurrent subscribers + // won't claim the same pending message at the same time + MinIdle: s.config.MaxIdleTime, + Messages: []string{pendingMessage.ID}, + }).Result() + if err != nil { + if isContextDoneErr(err) { + return nil, "", err + } else { + s.logger.Error("xclaim failed", err, watermill.LogFields{"pending": pendingMessage}) + return nil, "", s.redisErrors.Add(err) } + } - // delete idle consumer - if err = s.client.XGroupDelConsumer(ctx, stream, s.config.ConsumerGroup, xp.Consumer).Err(); err != nil { - s.logger.Error( - "xgroupdelconsumer fail", - err, - logFields.Add(watermill.LogFields{"xp": xp}), - ) - continue OUTER_LOOP - } - if len(xm) > 0 { - select { - case <-s.closing: - return - case <-ctx.Done(): - return - case readChannel <- &redis.XStream{Stream: stream, Messages: xm}: - } + if len(claimed) != 0 { + if len(claimed) != 1 { + // shouldn't happen, we only tried to claim one message + err := errors.Errorf("claimed %d messages", len(claimed)) + s.logger.Error("claimed more than 1 messages", err, watermill.LogFields{"claimed": claimed}) } + + return &claimed[0], pendingMessage.Consumer, nil } } - if len(xps) == 0 || int64(len(xps)) < s.config.ClaimBatchSize { // done - if !keep { - return - } - continue + + if int64(len(pendingMessages)) < s.config.ClaimBatchSize { + return nil, "", nil } - } -} -func (s *Subscriber) createMessageHandler(output chan *message.Message) messageHandler { - return messageHandler{ - outputChannel: output, - rc: s.client, - consumerGroup: s.config.ConsumerGroup, - unmarshaller: s.config.Unmarshaller, - nackResendSleep: s.config.NackResendSleep, - logger: s.logger, - closing: s.closing, + startID = pendingMessages[len(pendingMessages)-1].ID } } -func (s *Subscriber) Close() error { - s.closeMutex.Lock() - defer s.closeMutex.Unlock() - - if s.closed { - return nil +// 
read reads directly from the queue +func (s *topicSubscriber) read(ctx context.Context) (*redis.XMessage, error) { + streams, err := s.client.XReadGroup(ctx, &redis.XReadGroupArgs{ + Group: s.config.ConsumerGroup, + Consumer: s.config.Consumer, + Streams: []string{s.topic, ">"}, + Count: 1, + Block: s.config.ClaimInterval, + }).Result() + if err != nil && err != redis.Nil { + if isContextDoneErr(err) { + return nil, err + } else { + s.logger.Error("read failed", err, nil) + return nil, s.redisErrors.Add(err) + } } - s.closed = true - close(s.closing) - s.subscribersWg.Wait() - - // the errors.Is(err, net.ErrClosed) bit is because there is a race condition that's hard to - // fix here when closing a subscriber: it makes read return, which in turn cancels the claim context, - // which also tries to connection (see - // https://github.com/redis/go-redis/blob/v8.11.5/redis.go#L295) - if err := s.client.Close(); err != nil && !errors.Is(err, net.ErrClosed) { - return err + if len(streams) == 0 { + return nil, nil + } + if len(streams) != 1 { + // shouldn't happen, we only queried one stream + err := errors.Errorf("read from %d streams", len(streams)) + names := make([]string, len(streams)) + for i, stream := range streams { + names[i] = stream.Stream + } + s.logger.Error("read from more than 1 stream", err, watermill.LogFields{"stream_names": names}) } - s.logger.Debug("Redis stream subscriber closed", nil) - - return nil -} - -type messageHandler struct { - outputChannel chan<- *message.Message - rc redis.UniversalClient - consumerGroup string - unmarshaller Unmarshaller - - nackResendSleep time.Duration + if len(streams[0].Messages) == 0 { + return nil, nil + } + if len(streams[0].Messages) != 1 { + // same, we only asked for 1 message + err := errors.Errorf("read %d messages", len(streams[0].Messages)) + s.logger.Error("read more than 1 message", err, nil) + } - logger watermill.LoggerAdapter - closing chan struct{} + return &streams[0].Messages[0], nil } -func (h *messageHandler) processMessage(ctx context.Context, stream string, xm *redis.XMessage, messageLogFields watermill.LogFields) error { - receivedMsgLogFields := messageLogFields.Add(watermill.LogFields{ - "xid": xm.ID, - }) - - h.logger.Trace("Received message from redis stream", receivedMsgLogFields) +func (s *topicSubscriber) processMessage(ctx context.Context, redisMessage *redis.XMessage, output chan *message.Message) error { + logger := s.logger.With(watermill.LogFields{"message_id": redisMessage.ID}) + logger.Trace("Processing redis message", nil) - msg, err := h.unmarshaller.Unmarshal(xm.Values) + msg, err := s.config.Unmarshaller.Unmarshal(redisMessage.Values) if err != nil { return errors.Wrapf(err, "message unmarshal failed") } - ctx, cancelCtx := context.WithCancel(ctx) - msg.SetContext(ctx) - defer cancelCtx() - - receivedMsgLogFields = receivedMsgLogFields.Add(watermill.LogFields{ - "message_uuid": msg.UUID, - "stream": stream, - "xid": xm.ID, - }) + msgCtx, cancel := context.WithCancel(ctx) + msg.SetContext(msgCtx) + defer cancel() -ResendLoop: for { select { - case h.outputChannel <- msg: - h.logger.Trace("Message sent to consumer", receivedMsgLogFields) - case <-h.closing: - h.logger.Trace("Closing, message discarded", receivedMsgLogFields) - return nil + case output <- msg: + logger.Trace("message sent to consumer", nil) case <-ctx.Done(): - h.logger.Trace("Closing, ctx cancelled before sent to consumer", receivedMsgLogFields) - return nil + return ctx.Err() } select { case <-msg.Acked(): - // deadly retry ack - p 
:= h.rc.Pipeline() - if h.consumerGroup != "" { - p.XAck(ctx, stream, h.consumerGroup, xm.ID) - } - err := retry.Retry(func(attempt uint) error { - _, err := p.Exec(ctx) - return err - }, func(attempt uint) bool { - if attempt != 0 { - time.Sleep(time.Millisecond * 100) - } - return true - }, func(attempt uint) bool { + var redisError error + + return retry.Retry( + func(attempt uint) error { + xAckErr := s.client.XAck(ctx, s.topic, s.config.ConsumerGroup, redisMessage.ID).Err() + + if xAckErr == nil { + logger.Trace("message successfully acked", nil) + } else if !isContextDoneErr(xAckErr) { + redisError = s.redisErrors.Add(xAckErr) + logger.Error("message ack failed", xAckErr, watermill.LogFields{"attempt": attempt, "retrying": redisError == nil}) + if redisError == nil { + time.Sleep(100 * time.Millisecond) + } else { + xAckErr = redisError + } + } + + return xAckErr + }, + func(_ uint) bool { + return ctx.Err() == nil && redisError == nil + }, + ) + + case <-msg.Nacked(): + logger.Trace("message nacked", nil) + + if s.config.NackResendSleep > 0 { select { - case <-h.closing: + case <-time.After(s.config.NackResendSleep): case <-ctx.Done(): - default: - return true + return ctx.Err() } - return false - }) - if err != nil { - h.logger.Error("Message Acked fail", err, receivedMsgLogFields) - } else { - h.logger.Trace("Message Acked", receivedMsgLogFields) } - break ResendLoop - case <-msg.Nacked(): - h.logger.Trace("Message Nacked", receivedMsgLogFields) // reset acks, etc. msg = msg.Copy() - if h.nackResendSleep != NoSleep { - time.Sleep(h.nackResendSleep) - } - continue ResendLoop - case <-h.closing: - h.logger.Trace("Closing, message discarded before ack", receivedMsgLogFields) - return nil case <-ctx.Done(): - h.logger.Trace("Closing, ctx cancelled before ack", receivedMsgLogFields) - return nil + return ctx.Err() } } +} - return nil +func (s *topicSubscriber) deleteIdleConsumer(ctx context.Context, consumerName string) error { + if consumerName == "" { + return nil + } + + err := s.client.XGroupDelConsumer(ctx, s.topic, s.config.ConsumerGroup, consumerName).Err() + + if err != nil && !isContextDoneErr(err) { + s.logger.Error("xgroupdelconsumer failed", err, watermill.LogFields{"consumerName": consumerName}) + err = s.redisErrors.Add(err) + } + + return err +} + +func isContextDoneErr(err error) bool { + return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) }