diff --git a/internal/querycoordv2/balance/rowcount_based_balancer.go b/internal/querycoordv2/balance/rowcount_based_balancer.go index 1dfc1aa115edb..dc496f1cdef3c 100644 --- a/internal/querycoordv2/balance/rowcount_based_balancer.go +++ b/internal/querycoordv2/balance/rowcount_based_balancer.go @@ -133,7 +133,7 @@ func (b *RowCountBasedBalancer) convertToNodeItemsByChannel(nodeIDs []int64) []* ret := make([]*nodeItem, 0, len(nodeIDs)) for _, nodeInfo := range b.getNodes(nodeIDs) { node := nodeInfo.ID() - channels := b.dist.ChannelDistManager.GetByNode(node) + channels := b.dist.ChannelDistManager.GetByFilter(meta.WithNodeID2Channel(node)) // more channel num, less priority nodeItem := newNodeItem(len(channels), node) @@ -276,7 +276,7 @@ func (b *RowCountBasedBalancer) genSegmentPlan(replica *meta.Replica, onlineNode segmentsToMove = lo.Filter(segmentsToMove, func(s *meta.Segment, _ int) bool { // if the segment are redundant, skip it's balance for now - return len(b.dist.GetByFilter(meta.WithReplica(replica), meta.WithSegmentID(s.GetID()))) == 1 + return len(b.dist.SegmentDistManager.GetByFilter(meta.WithReplica(replica), meta.WithSegmentID(s.GetID()))) == 1 }) if len(nodesWithLessRow) == 0 || len(segmentsToMove) == 0 { @@ -295,7 +295,7 @@ func (b *RowCountBasedBalancer) genSegmentPlan(replica *meta.Replica, onlineNode func (b *RowCountBasedBalancer) genStoppingChannelPlan(replica *meta.Replica, onlineNodes []int64, offlineNodes []int64) []ChannelAssignPlan { channelPlans := make([]ChannelAssignPlan, 0) for _, nodeID := range offlineNodes { - dmChannels := b.dist.ChannelDistManager.GetByCollectionAndNode(replica.GetCollectionID(), nodeID) + dmChannels := b.dist.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(nodeID)) plans := b.AssignChannel(dmChannels, onlineNodes) for i := range plans { plans[i].From = nodeID @@ -310,7 +310,7 @@ func (b *RowCountBasedBalancer) genChannelPlan(replica *meta.Replica, 
onlineNode channelPlans := make([]ChannelAssignPlan, 0) if len(onlineNodes) > 1 { // start to balance channels on all available nodes - channelDist := b.dist.ChannelDistManager.GetChannelDistByReplica(replica) + channelDist := b.dist.ChannelDistManager.GetByFilter(meta.WithReplica2Channel(replica)) if len(channelDist) == 0 { return nil } @@ -320,7 +320,7 @@ func (b *RowCountBasedBalancer) genChannelPlan(replica *meta.Replica, onlineNode nodeWithLessChannel := make([]int64, 0) channelsToMove := make([]*meta.DmChannel, 0) for _, node := range onlineNodes { - channels := b.dist.ChannelDistManager.GetByCollectionAndNode(replica.GetCollectionID(), node) + channels := b.dist.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(node)) if len(channels) <= average { nodeWithLessChannel = append(nodeWithLessChannel, node) diff --git a/internal/querycoordv2/balance/score_based_balancer.go b/internal/querycoordv2/balance/score_based_balancer.go index 8369728e3d842..5fba4bc2d2ba4 100644 --- a/internal/querycoordv2/balance/score_based_balancer.go +++ b/internal/querycoordv2/balance/score_based_balancer.go @@ -306,7 +306,7 @@ func (b *ScoreBasedBalancer) genSegmentPlan(replica *meta.Replica, onlineNodes [ // if the segment are redundant, skip it's balance for now segmentsToMove = lo.Filter(segmentsToMove, func(s *meta.Segment, _ int) bool { - return len(b.dist.GetByFilter(meta.WithReplica(replica), meta.WithSegmentID(s.GetID()))) == 1 + return len(b.dist.SegmentDistManager.GetByFilter(meta.WithReplica(replica), meta.WithSegmentID(s.GetID()))) == 1 }) if len(segmentsToMove) == 0 { diff --git a/internal/querycoordv2/balance/utils.go b/internal/querycoordv2/balance/utils.go index 8ec0104f48619..be00b7aabeee6 100644 --- a/internal/querycoordv2/balance/utils.go +++ b/internal/querycoordv2/balance/utils.go @@ -176,7 +176,7 @@ func PrintCurrentReplicaDist(replica *meta.Replica, // 3. 
print stopping nodes channel distribution distInfo += "[stoppingNodesChannelDist:" for stoppingNodeID := range stoppingNodesSegments { - stoppingNodeChannels := channelManager.GetByCollectionAndNode(replica.GetCollectionID(), stoppingNodeID) + stoppingNodeChannels := channelManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(stoppingNodeID)) distInfo += fmt.Sprintf("[nodeID:%d, count:%d,", stoppingNodeID, len(stoppingNodeChannels)) distInfo += "channels:[" for _, stoppingChan := range stoppingNodeChannels { @@ -189,7 +189,7 @@ func PrintCurrentReplicaDist(replica *meta.Replica, // 4. print normal nodes channel distribution distInfo += "[normalNodesChannelDist:" for normalNodeID := range nodeSegments { - normalNodeChannels := channelManager.GetByCollectionAndNode(replica.GetCollectionID(), normalNodeID) + normalNodeChannels := channelManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(normalNodeID)) distInfo += fmt.Sprintf("[nodeID:%d, count:%d,", normalNodeID, len(normalNodeChannels)) distInfo += "channels:[" for _, normalNodeChan := range normalNodeChannels { diff --git a/internal/querycoordv2/checkers/channel_checker.go b/internal/querycoordv2/checkers/channel_checker.go index 3c23416e71f49..55ce218131110 100644 --- a/internal/querycoordv2/checkers/channel_checker.go +++ b/internal/querycoordv2/checkers/channel_checker.go @@ -91,7 +91,7 @@ func (c *ChannelChecker) Check(ctx context.Context) []task.Task { } } - channels := c.dist.ChannelDistManager.GetAll() + channels := c.dist.ChannelDistManager.GetByFilter(nil) released := utils.FilterReleased(channels, collectionIDs) releaseTasks := c.createChannelReduceTasks(ctx, released, -1) task.SetReason("collection released", releaseTasks...) 
@@ -163,7 +163,7 @@ func (c *ChannelChecker) getDmChannelDiff(collectionID int64, func (c *ChannelChecker) getChannelDist(replica *meta.Replica) []*meta.DmChannel { dist := make([]*meta.DmChannel, 0) for _, nodeID := range replica.GetNodes() { - dist = append(dist, c.dist.ChannelDistManager.GetByCollectionAndNode(replica.GetCollectionID(), nodeID)...) + dist = append(dist, c.dist.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(nodeID))...) } return dist } diff --git a/internal/querycoordv2/checkers/segment_checker.go b/internal/querycoordv2/checkers/segment_checker.go index 3b980cc806549..03284f434740b 100644 --- a/internal/querycoordv2/checkers/segment_checker.go +++ b/internal/querycoordv2/checkers/segment_checker.go @@ -150,7 +150,6 @@ func (c *SegmentChecker) getGrowingSegmentDiff(collectionID int64, zap.Int64("replicaID", replica.ID)) leaders := c.dist.ChannelDistManager.GetShardLeadersByReplica(replica) - // distMgr.LeaderViewManager. 
for channelName, node := range leaders { view := c.dist.LeaderViewManager.GetLeaderShardView(node, channelName) if view == nil { diff --git a/internal/querycoordv2/dist/dist_handler.go b/internal/querycoordv2/dist/dist_handler.go index 8d0f5578833f1..77bdf19c485a6 100644 --- a/internal/querycoordv2/dist/dist_handler.go +++ b/internal/querycoordv2/dist/dist_handler.go @@ -212,7 +212,7 @@ func (dh *distHandler) getDistribution(ctx context.Context) (*querypb.GetDataDis defer dh.mu.Unlock() channels := make(map[string]*msgpb.MsgPosition) - for _, channel := range dh.dist.ChannelDistManager.GetByNode(dh.nodeID) { + for _, channel := range dh.dist.ChannelDistManager.GetByFilter(meta.WithNodeID2Channel(dh.nodeID)) { targetChannel := dh.target.GetDmChannel(channel.GetCollectionID(), channel.GetChannelName(), meta.CurrentTarget) if targetChannel == nil { continue diff --git a/internal/querycoordv2/job/utils.go b/internal/querycoordv2/job/utils.go index bc6b3f7b65d0b..85faa81ee768c 100644 --- a/internal/querycoordv2/job/utils.go +++ b/internal/querycoordv2/job/utils.go @@ -49,7 +49,7 @@ func waitCollectionReleased(dist *meta.DistributionManager, checkerController *c return partitionSet.Contain(segment.GetPartitionID()) }) } else { - channels = dist.ChannelDistManager.GetByCollection(collection) + channels = dist.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(collection)) } if len(channels)+len(segments) == 0 { diff --git a/internal/querycoordv2/meta/channel_dist_manager.go b/internal/querycoordv2/meta/channel_dist_manager.go index c46041fc2addb..0c5922a530b44 100644 --- a/internal/querycoordv2/meta/channel_dist_manager.go +++ b/internal/querycoordv2/meta/channel_dist_manager.go @@ -25,6 +25,32 @@ import ( . 
"github.com/milvus-io/milvus/pkg/util/typeutil" ) +type ChannelDistFilter = func(ch *DmChannel) bool + +func WithCollectionID2Channel(collectionID int64) ChannelDistFilter { + return func(ch *DmChannel) bool { + return ch.GetCollectionID() == collectionID + } +} + +func WithNodeID2Channel(nodeID int64) ChannelDistFilter { + return func(ch *DmChannel) bool { + return ch.Node == nodeID + } +} + +func WithReplica2Channel(replica *Replica) ChannelDistFilter { + return func(ch *DmChannel) bool { + return ch.GetCollectionID() == replica.GetCollectionID() && replica.Contains(ch.Node) + } +} + +func WithChannelName2Channel(channelName string) ChannelDistFilter { + return func(ch *DmChannel) bool { + return ch.GetChannelName() == channelName + } +} + type DmChannel struct { *datapb.VchannelInfo Node int64 @@ -58,33 +84,7 @@ func NewChannelDistManager() *ChannelDistManager { } } -func (m *ChannelDistManager) GetByNode(nodeID UniqueID) []*DmChannel { - m.rwmutex.RLock() - defer m.rwmutex.RUnlock() - - return m.getByNode(nodeID) -} - -func (m *ChannelDistManager) getByNode(nodeID UniqueID) []*DmChannel { - channels, ok := m.channels[nodeID] - if !ok { - return nil - } - - return channels -} - -func (m *ChannelDistManager) GetAll() []*DmChannel { - m.rwmutex.RLock() - defer m.rwmutex.RUnlock() - - result := make([]*DmChannel, 0) - for _, channels := range m.channels { - result = append(result, channels...) - } - return result -} - +// todo by liuwei: should consider the case of duplicate leader exists // GetShardLeader returns the node whthin the given replicaNodes and subscribing the given shard, // returns (0, false) if not found. 
func (m *ChannelDistManager) GetShardLeader(replica *Replica, shard string) (int64, bool) { @@ -103,6 +103,7 @@ func (m *ChannelDistManager) GetShardLeader(replica *Replica, shard string) (int return 0, false } +// todo by liuwei: should consider the case of duplicate leader exists func (m *ChannelDistManager) GetShardLeadersByReplica(replica *Replica) map[string]int64 { m.rwmutex.RLock() defer m.rwmutex.RUnlock() @@ -119,35 +120,21 @@ func (m *ChannelDistManager) GetShardLeadersByReplica(replica *Replica) map[stri return ret } -func (m *ChannelDistManager) GetChannelDistByReplica(replica *Replica) map[string][]int64 { - m.rwmutex.RLock() - defer m.rwmutex.RUnlock() - - ret := make(map[string][]int64) - for _, node := range replica.GetNodes() { - channels := m.channels[node] - for _, dmc := range channels { - if dmc.GetCollectionID() == replica.GetCollectionID() { - channelName := dmc.GetChannelName() - _, ok := ret[channelName] - if !ok { - ret[channelName] = make([]int64, 0) - } - ret[channelName] = append(ret[channelName], node) - } - } - } - return ret -} - -func (m *ChannelDistManager) GetByCollection(collectionID UniqueID) []*DmChannel { +// GetByFilter returns all channels that match every given filter; nil filters are ignored. +func (m *ChannelDistManager) GetByFilter(filters ...ChannelDistFilter) []*DmChannel { m.rwmutex.RLock() defer m.rwmutex.RUnlock() ret := make([]*DmChannel, 0) for _, channels := range m.channels { for _, channel := range channels { - if channel.CollectionID == collectionID { + allMatch := true + for _, fn := range filters { + if fn != nil && !fn(channel) { + allMatch = false + } + } + if allMatch { ret = append(ret, channel) } } @@ -155,20 +142,6 @@ func (m *ChannelDistManager) GetByCollection(collectionID UniqueID) []*DmChannel return ret } -func (m *ChannelDistManager) GetByCollectionAndNode(collectionID, nodeID UniqueID) []*DmChannel { - m.rwmutex.RLock() - defer m.rwmutex.RUnlock() - - channels := make([]*DmChannel, 0) - for _, channel := range
m.getByNode(nodeID) { - if channel.CollectionID == collectionID { - channels = append(channels, channel) - } - } - - return channels -} - func (m *ChannelDistManager) Update(nodeID UniqueID, channels ...*DmChannel) { m.rwmutex.Lock() defer m.rwmutex.Unlock() diff --git a/internal/querycoordv2/meta/channel_dist_manager_test.go b/internal/querycoordv2/meta/channel_dist_manager_test.go index fbd4afe2c3fd5..be67db31ac043 100644 --- a/internal/querycoordv2/meta/channel_dist_manager_test.go +++ b/internal/querycoordv2/meta/channel_dist_manager_test.go @@ -66,36 +66,36 @@ func (suite *ChannelDistManagerSuite) TestGetBy() { dist := suite.dist // Test GetAll - channels := dist.GetAll() + channels := dist.GetByFilter(nil) suite.Len(channels, 4) // Test GetByNode for _, node := range suite.nodes { - channels := dist.GetByNode(node) + channels := dist.GetByFilter(WithNodeID2Channel(node)) suite.AssertNode(channels, node) } // Test GetByCollection - channels = dist.GetByCollection(suite.collection) + channels = dist.GetByFilter(WithCollectionID2Channel(suite.collection)) suite.Len(channels, 4) suite.AssertCollection(channels, suite.collection) - channels = dist.GetByCollection(-1) + channels = dist.GetByFilter(WithCollectionID2Channel(-1)) suite.Len(channels, 0) // Test GetByNodeAndCollection // 1. Valid node and valid collection for _, node := range suite.nodes { - channels := dist.GetByCollectionAndNode(suite.collection, node) + channels := dist.GetByFilter(WithCollectionID2Channel(suite.collection), WithNodeID2Channel(node)) suite.AssertNode(channels, node) suite.AssertCollection(channels, suite.collection) } // 2. Valid node and invalid collection - channels = dist.GetByCollectionAndNode(-1, suite.nodes[1]) + channels = dist.GetByFilter(WithCollectionID2Channel(-1), WithNodeID2Channel(suite.nodes[1])) suite.Len(channels, 0) // 3. 
Invalid node and valid collection - channels = dist.GetByCollectionAndNode(suite.collection, -1) + channels = dist.GetByFilter(WithCollectionID2Channel(suite.collection), WithNodeID2Channel(-1)) suite.Len(channels, 0) } @@ -148,47 +148,6 @@ func (suite *ChannelDistManagerSuite) TestGetShardLeader() { suite.Equal(leaders["dmc1"], suite.nodes[1]) } -func (suite *ChannelDistManagerSuite) TestGetChannelDistByReplica() { - replica := NewReplica( - &querypb.Replica{ - CollectionID: suite.collection, - }, - typeutil.NewUniqueSet(11, 22, 33), - ) - - ch1 := &DmChannel{ - VchannelInfo: &datapb.VchannelInfo{ - CollectionID: suite.collection, - ChannelName: "test-channel1", - }, - Node: 11, - Version: 1, - } - ch2 := &DmChannel{ - VchannelInfo: &datapb.VchannelInfo{ - CollectionID: suite.collection, - ChannelName: "test-channel1", - }, - Node: 22, - Version: 1, - } - ch3 := &DmChannel{ - VchannelInfo: &datapb.VchannelInfo{ - CollectionID: suite.collection, - ChannelName: "test-channel2", - }, - Node: 33, - Version: 1, - } - suite.dist.Update(11, ch1) - suite.dist.Update(22, ch2) - suite.dist.Update(33, ch3) - - dist := suite.dist.GetChannelDistByReplica(replica) - suite.Len(dist["test-channel1"], 2) - suite.Len(dist["test-channel2"], 1) -} - func (suite *ChannelDistManagerSuite) AssertNames(channels []*DmChannel, names ...string) bool { for _, channel := range channels { hasChannel := false diff --git a/internal/querycoordv2/observers/replica_observer.go b/internal/querycoordv2/observers/replica_observer.go index dc92164645bec..6816f45c8cfdd 100644 --- a/internal/querycoordv2/observers/replica_observer.go +++ b/internal/querycoordv2/observers/replica_observer.go @@ -98,8 +98,8 @@ func (ob *ReplicaObserver) checkNodesInReplica() { ) for node := range outboundNodes { - channels := ob.distMgr.ChannelDistManager.GetByCollectionAndNode(collectionID, node) segments := ob.distMgr.SegmentDistManager.GetByFilter(meta.WithCollectionID(collectionID), meta.WithNodeID(node)) + channels := 
ob.distMgr.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(node)) if len(channels) == 0 && len(segments) == 0 { replica.RemoveNode(node)