diff --git a/osscluster.go b/osscluster.go index 17f98d9dc..ac40a5038 100644 --- a/osscluster.go +++ b/osscluster.go @@ -48,6 +48,9 @@ type ClusterOptions struct { // Allows routing read-only commands to the random master or slave node. // It automatically enables ReadOnly. RouteRandomly bool + // Allows routing read-only commands to the replica nodes in ronund-robin. + // It automatically enables ReadOnly + RouteRoundRobinReplicas bool // Optional function that returns cluster slots information. // It is useful to manually create cluster of standalone Redis servers @@ -98,7 +101,7 @@ func (opt *ClusterOptions) init() { opt.MaxRedirects = 3 } - if opt.RouteByLatency || opt.RouteRandomly { + if opt.RouteByLatency || opt.RouteRandomly || opt.RouteRoundRobinReplicas { opt.ReadOnly = true } @@ -584,6 +587,9 @@ func (c *clusterNodes) Random() (*clusterNode, error) { type clusterSlot struct { start, end int nodes []*clusterNode + + // Allows node selection to use round-robin selection strategy. + next uint32 } type clusterSlotSlice []*clusterSlot @@ -767,7 +773,44 @@ func (c *clusterState) slotRandomNode(slot int) (*clusterNode, error) { return nodes[randomNodes[0]], nil } +// slotRoundRobinReplicaNode tries to select a node from the list of replica nodes. +// if no replica nodes are available, returns the primary node. +func (c *clusterState) slotRoundRobinReplicaNode(slot int) (*clusterNode, error) { + cs := c.slotCluster(slot) + if cs == nil { + return c.nodes.Random() + } + + switch len(cs.nodes) { + case 0: + return c.nodes.Random() + case 1: + return cs.nodes[0], nil + case 2: + if replica := cs.nodes[1]; !replica.Failing() { + return replica, nil + } + return cs.nodes[0], nil + default: + var replica *clusterNode + for i := 0; i < 10; i++ { + next := atomic.AddUint32(&cs.next, 1) + n := (int(next))%(len(cs.nodes)-1) + 1 + replica = cs.nodes[n] + if !replica.Failing() { + return replica, nil + } + } + // All slaves are loading - use master. + return cs.nodes[0], nil + } +} + func (c *clusterState) slotNodes(slot int) []*clusterNode { + return c.slotCluster(slot).nodes +} + +func (c *clusterState) slotCluster(slot int) *clusterSlot { i := sort.Search(len(c.slots), func(i int) bool { return c.slots[i].end >= slot }) @@ -776,8 +819,9 @@ func (c *clusterState) slotNodes(slot int) []*clusterNode { } x := c.slots[i] if slot >= x.start && slot <= x.end { - return x.nodes + return x } + return nil } @@ -1824,6 +1868,9 @@ func (c *ClusterClient) slotReadOnlyNode(state *clusterState, slot int) (*cluste if c.opt.RouteRandomly { return state.slotRandomNode(slot) } + if c.opt.RouteRoundRobinReplicas { + return state.slotRoundRobinReplicaNode(slot) + } return state.slotSlaveNode(slot) } diff --git a/osscluster_test.go b/osscluster_test.go index 3d2f80711..40cfcc762 100644 --- a/osscluster_test.go +++ b/osscluster_test.go @@ -1282,6 +1282,66 @@ var _ = Describe("ClusterClient", func() { assertClusterClient() }) + + Describe("ClusterClient with RouteRoundRobinReplicas and ClusterSlots with multiple nodes per slot", func() { + BeforeEach(func() { + failover = true + + opt = redisClusterOptions() + opt.RouteRoundRobinReplicas = true + opt.ClusterSlots = func(ctx context.Context) ([]redis.ClusterSlot, error) { + slots := []redis.ClusterSlot{{ + Start: 0, + End: 4999, + Nodes: []redis.ClusterNode{{ + Addr: ":8220", + }, { + Addr: ":8223", + }}, + }, { + Start: 5000, + End: 9999, + Nodes: []redis.ClusterNode{{ + Addr: ":8221", + }, { + Addr: ":8224", + }}, + }, { + Start: 10000, + End: 16383, + Nodes: []redis.ClusterNode{{ + Addr: ":8222", + }, { + Addr: ":8225", + }}, + }} + return slots, nil + } + client = cluster.newClusterClient(ctx, opt) + + err := client.ForEachMaster(ctx, func(ctx context.Context, master *redis.Client) error { + return master.FlushDB(ctx).Err() + }) + Expect(err).NotTo(HaveOccurred()) + + err = client.ForEachSlave(ctx, func(ctx context.Context, slave *redis.Client) error { + Eventually(func() int64 { + return client.DBSize(ctx).Val() + }, 30*time.Second).Should(Equal(int64(0))) + return nil + }) + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + failover = false + + err := client.Close() + Expect(err).NotTo(HaveOccurred()) + }) + + assertClusterClient() + }) }) var _ = Describe("ClusterClient without nodes", func() {