diff --git a/balancer.go b/balancer.go
index 7a50cc1ce..7076078d4 100644
--- a/balancer.go
+++ b/balancer.go
@@ -288,3 +288,86 @@ func murmur2(data []byte) uint32 {
 
 	return h
 }
+
+// StickyPartitioner is a Balancer implementation in which the Message key is NOT used as part of the balancing
+// strategy, so messages with the same key are not guaranteed to be sent to the same partition.
+// If a partition is specified in the Message, use it.
+// Otherwise choose the sticky partition, which changes once a batch has been completed.
+type StickyPartitioner struct {
+	spc stickyPartitionCache
+}
+
+// Balance satisfies the Balancer interface.
+func (sp *StickyPartitioner) Balance(msg Message, partitions ...int) int {
+	return sp.balance(msg, partitions)
+}
+
+func (sp *StickyPartitioner) balance(msg Message, partitions []int) int {
+	return sp.spc.partition(msg, partitions)
+}
+
+// OnNewBatch changes the sticky partition if a batch was completed for the current sticky partition.
+// Alternately, if no sticky partition has been determined, it sets one.
+func (sp *StickyPartitioner) OnNewBatch(msg Message, partitions []int, prevPartition int) int {
+	return sp.spc.nextPartition(msg, partitions, prevPartition)
+}
+
+// stickyPartitionCache implements a cache used for sticky partitioning behavior. The cache tracks the current sticky
+// partition for any given topic.
+type stickyPartitionCache struct {
+	lock       sync.Mutex
+	indexCache map[string]int
+}
+
+// nextPartition picks a new sticky partition for the topic of msg, avoiding the previous sticky partition when more
+// than one partition is available, and records the choice in the cache.
+func (spc *stickyPartitionCache) nextPartition(msg Message, partitions []int, prevPartition int) int {
+	oldPartition, prs := spc.getIndex(msg)
+	if !prs {
+		// No sticky partition has been chosen for this topic yet.
+		oldPartition = -1
+	}
+	newPartition := oldPartition
+
+	// Only change the sticky partition if the completed batch was for the current sticky partition.
+	if prs && oldPartition != prevPartition {
+		return oldPartition
+	}
+
+	if len(partitions) == 1 {
+		newPartition = partitions[0]
+	} else {
+		for newPartition == -1 || newPartition == oldPartition {
+			newPartition = rand.Intn(len(partitions))
+		}
+	}
+	spc.setIndex(msg, newPartition)
+
+	return newPartition
+}
+
+// partition returns the current sticky partition for the topic of msg, choosing one first if necessary.
+func (spc *stickyPartitionCache) partition(msg Message, partitions []int) int {
+	partition, prs := spc.getIndex(msg)
+	if prs {
+		return partition
+	}
+	return spc.nextPartition(msg, partitions, -1)
+}
+
+func (spc *stickyPartitionCache) getIndex(msg Message) (int, bool) {
+	spc.lock.Lock()
+	defer spc.lock.Unlock()
+	index, prs := spc.indexCache[msg.Topic]
+	return index, prs
+}
+
+func (spc *stickyPartitionCache) setIndex(msg Message, index int) {
+	spc.lock.Lock()
+	defer spc.lock.Unlock()
+	// Lazily create the cache under the lock so the balancer is safe for concurrent use.
+	if spc.indexCache == nil {
+		spc.indexCache = make(map[string]int)
+	}
+	spc.indexCache[msg.Topic] = index
+}
diff --git a/balancer_test.go b/balancer_test.go
index acdfe54eb..8d021b261 100644
--- a/balancer_test.go
+++ b/balancer_test.go
@@ -352,3 +352,110 @@ func TestLeastBytes(t *testing.T) {
 		})
 	}
 }
+
+func TestStickyPartitionerWithTwoPartitions(t *testing.T) {
+	testCases := map[string]struct {
+		messages   []Message
+		partitions []int
+	}{
+		"two partitions, one topic": {
+			messages: []Message{
+				{
+					Topic: "test",
+				},
+			},
+			partitions: []int{
+				0, 1,
+			},
+		},
+	}
+
+	for label, test := range testCases {
+		t.Run(label, func(t *testing.T) {
+			sp := &StickyPartitioner{}
+			partitionCounter := make(map[int]int)
+
+			// Send 50 messages to the initial sticky partition.
+			part := 0
+			for i := 0; i < 50; i++ {
+				part = sp.Balance(test.messages[0], test.partitions...)
+				partitionCounter[part]++
+			}
+			// Completing a batch moves the sticky partition, so the next 50 messages land on the other partition.
+			sp.OnNewBatch(test.messages[0], test.partitions, part)
+			for i := 0; i < 50; i++ {
+				part = sp.Balance(test.messages[0], test.partitions...)
+				partitionCounter[part]++
+			}
+			if partitionCounter[0] != partitionCounter[1] || partitionCounter[0] != 50 {
+				t.Errorf("expected an even 50/50 split between the two partitions, got %d/%d", partitionCounter[0], partitionCounter[1])
+			}
+		})
+	}
+}
+
+func TestStickyPartitionerWithThreePartitions(t *testing.T) {
+	testCases := map[string]struct {
+		messages   []Message
+		partitions []int
+	}{
+		"three partitions, two topics": {
+			messages: []Message{
+				{
+					Topic: "A",
+				},
+				{
+					Topic: "B",
+				},
+			},
+			partitions: []int{
+				0, 1, 2,
+			},
+		},
+	}
+
+	for label, test := range testCases {
+		t.Run(label, func(t *testing.T) {
+			sp := &StickyPartitioner{}
+			partitionCounter := make(map[int]int)
+
+			// Send 30 messages for topic A, interleaving topic B messages that must not disturb topic A's sticky partition.
+			part := 0
+			for i := 0; i < 30; i++ {
+				part = sp.Balance(test.messages[0], test.partitions...)
+				partitionCounter[part]++
+				if i%5 == 0 {
+					sp.Balance(test.messages[1], test.partitions...)
+				}
+			}
+			// A completed batch for the current sticky partition moves topic A to a new partition.
+			sp.OnNewBatch(test.messages[0], test.partitions, part)
+			oldPartition := part
+			for i := 0; i < 30; i++ {
+				part = sp.Balance(test.messages[0], test.partitions...)
+				partitionCounter[part]++
+				if i%5 == 0 {
+					sp.Balance(test.messages[1], test.partitions...)
+				}
+			}
+			newPartition := part
+
+			// Reporting a completed batch for the old partition is a no-op, so the sticky partition stays put.
+			sp.OnNewBatch(test.messages[0], test.partitions, oldPartition)
+			for i := 0; i < 30; i++ {
+				part = sp.Balance(test.messages[0], test.partitions...)
+				partitionCounter[part]++
+				if i%5 == 0 {
+					sp.Balance(test.messages[1], test.partitions...)
+				}
+			}
+
+			if partitionCounter[oldPartition] != 30 {
+				t.Errorf("old sticky partition received %d messages, expected 30", partitionCounter[oldPartition])
+			}
+			if partitionCounter[newPartition] != 60 {
+				t.Errorf("new sticky partition received %d messages, expected 60", partitionCounter[newPartition])
+			}
+		})
+	}
+}
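For context, a minimal sketch of how the new balancer behaves from a caller's point of view. It assumes the `github.com/segmentio/kafka-go` import path and a throwaway `main` package (both assumptions, not part of this change), and it invokes `OnNewBatch` by hand because nothing in this diff wires batch-completion signals into the `Writer`; only `Balance` is part of the `Balancer` interface.

```go
package main

import (
	"fmt"

	kafka "github.com/segmentio/kafka-go" // assumed import path, not confirmed by this diff
)

func main() {
	sp := &kafka.StickyPartitioner{}
	msg := kafka.Message{Topic: "events", Value: []byte("hello")}
	partitions := []int{0, 1, 2}

	// Messages for the same topic stick to one partition...
	first := sp.Balance(msg, partitions...)
	again := sp.Balance(msg, partitions...)
	fmt.Println(first == again) // true

	// ...until a completed batch is reported for that partition,
	// after which a different partition becomes sticky.
	sp.OnNewBatch(msg, partitions, first)
	next := sp.Balance(msg, partitions...)
	fmt.Println(next != first) // true
}
```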