Skip to content

Commit

Permalink
enhance: Refactor channel dist manager interface
Browse files Browse the repository at this point in the history
Signed-off-by: Wei Liu <wei.liu@zilliz.com>
  • Loading branch information
weiliu1031 committed Mar 14, 2024
1 parent b1ff8e2 commit eaa7298
Show file tree
Hide file tree
Showing 10 changed files with 57 additions and 126 deletions.
10 changes: 5 additions & 5 deletions internal/querycoordv2/balance/rowcount_based_balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func (b *RowCountBasedBalancer) convertToNodeItemsByChannel(nodeIDs []int64) []*
ret := make([]*nodeItem, 0, len(nodeIDs))
for _, nodeInfo := range b.getNodes(nodeIDs) {
node := nodeInfo.ID()
channels := b.dist.ChannelDistManager.GetByNode(node)
channels := b.dist.ChannelDistManager.GetByFilter(meta.WithNodeID2Channel(node))

// more channel num, less priority
nodeItem := newNodeItem(len(channels), node)
Expand Down Expand Up @@ -276,7 +276,7 @@ func (b *RowCountBasedBalancer) genSegmentPlan(replica *meta.Replica, onlineNode

segmentsToMove = lo.Filter(segmentsToMove, func(s *meta.Segment, _ int) bool {
// if the segment is redundant, skip its balance for now
return len(b.dist.GetByFilter(meta.WithReplica(replica), meta.WithSegmentID(s.GetID()))) == 1
return len(b.dist.SegmentDistManager.GetByFilter(meta.WithReplica(replica), meta.WithSegmentID(s.GetID()))) == 1
})

if len(nodesWithLessRow) == 0 || len(segmentsToMove) == 0 {
Expand All @@ -295,7 +295,7 @@ func (b *RowCountBasedBalancer) genSegmentPlan(replica *meta.Replica, onlineNode
func (b *RowCountBasedBalancer) genStoppingChannelPlan(replica *meta.Replica, onlineNodes []int64, offlineNodes []int64) []ChannelAssignPlan {
channelPlans := make([]ChannelAssignPlan, 0)
for _, nodeID := range offlineNodes {
dmChannels := b.dist.ChannelDistManager.GetByCollectionAndNode(replica.GetCollectionID(), nodeID)
dmChannels := b.dist.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(nodeID))
plans := b.AssignChannel(dmChannels, onlineNodes)
for i := range plans {
plans[i].From = nodeID
Expand All @@ -310,7 +310,7 @@ func (b *RowCountBasedBalancer) genChannelPlan(replica *meta.Replica, onlineNode
channelPlans := make([]ChannelAssignPlan, 0)
if len(onlineNodes) > 1 {
// start to balance channels on all available nodes
channelDist := b.dist.ChannelDistManager.GetChannelDistByReplica(replica)
channelDist := b.dist.ChannelDistManager.GetByFilter(meta.WithReplica2Channel(replica))
if len(channelDist) == 0 {
return nil
}
Expand All @@ -320,7 +320,7 @@ func (b *RowCountBasedBalancer) genChannelPlan(replica *meta.Replica, onlineNode
nodeWithLessChannel := make([]int64, 0)
channelsToMove := make([]*meta.DmChannel, 0)
for _, node := range onlineNodes {
channels := b.dist.ChannelDistManager.GetByCollectionAndNode(replica.GetCollectionID(), node)
channels := b.dist.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(node))

if len(channels) <= average {
nodeWithLessChannel = append(nodeWithLessChannel, node)
Expand Down
2 changes: 1 addition & 1 deletion internal/querycoordv2/balance/score_based_balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ func (b *ScoreBasedBalancer) genSegmentPlan(replica *meta.Replica, onlineNodes [

// if the segment is redundant, skip its balance for now
segmentsToMove = lo.Filter(segmentsToMove, func(s *meta.Segment, _ int) bool {
return len(b.dist.GetByFilter(meta.WithReplica(replica), meta.WithSegmentID(s.GetID()))) == 1
return len(b.dist.SegmentDistManager.GetByFilter(meta.WithReplica(replica), meta.WithSegmentID(s.GetID()))) == 1
})

if len(segmentsToMove) == 0 {
Expand Down
4 changes: 2 additions & 2 deletions internal/querycoordv2/balance/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ func PrintCurrentReplicaDist(replica *meta.Replica,
// 3. print stopping nodes channel distribution
distInfo += "[stoppingNodesChannelDist:"
for stoppingNodeID := range stoppingNodesSegments {
stoppingNodeChannels := channelManager.GetByCollectionAndNode(replica.GetCollectionID(), stoppingNodeID)
stoppingNodeChannels := channelManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(stoppingNodeID))
distInfo += fmt.Sprintf("[nodeID:%d, count:%d,", stoppingNodeID, len(stoppingNodeChannels))
distInfo += "channels:["
for _, stoppingChan := range stoppingNodeChannels {
Expand All @@ -189,7 +189,7 @@ func PrintCurrentReplicaDist(replica *meta.Replica,
// 4. print normal nodes channel distribution
distInfo += "[normalNodesChannelDist:"
for normalNodeID := range nodeSegments {
normalNodeChannels := channelManager.GetByCollectionAndNode(replica.GetCollectionID(), normalNodeID)
normalNodeChannels := channelManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(normalNodeID))
distInfo += fmt.Sprintf("[nodeID:%d, count:%d,", normalNodeID, len(normalNodeChannels))
distInfo += "channels:["
for _, normalNodeChan := range normalNodeChannels {
Expand Down
4 changes: 2 additions & 2 deletions internal/querycoordv2/checkers/channel_checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func (c *ChannelChecker) Check(ctx context.Context) []task.Task {
}
}

channels := c.dist.ChannelDistManager.GetAll()
channels := c.dist.ChannelDistManager.GetByFilter(nil)
released := utils.FilterReleased(channels, collectionIDs)
releaseTasks := c.createChannelReduceTasks(ctx, released, -1)
task.SetReason("collection released", releaseTasks...)
Expand Down Expand Up @@ -163,7 +163,7 @@ func (c *ChannelChecker) getDmChannelDiff(collectionID int64,
func (c *ChannelChecker) getChannelDist(replica *meta.Replica) []*meta.DmChannel {
dist := make([]*meta.DmChannel, 0)
for _, nodeID := range replica.GetNodes() {
dist = append(dist, c.dist.ChannelDistManager.GetByCollectionAndNode(replica.GetCollectionID(), nodeID)...)
dist = append(dist, c.dist.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(nodeID))...)
}
return dist
}
Expand Down
1 change: 0 additions & 1 deletion internal/querycoordv2/checkers/segment_checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ func (c *SegmentChecker) getGrowingSegmentDiff(collectionID int64,
zap.Int64("replicaID", replica.ID))

leaders := c.dist.ChannelDistManager.GetShardLeadersByReplica(replica)
// distMgr.LeaderViewManager.
for channelName, node := range leaders {
view := c.dist.LeaderViewManager.GetLeaderShardView(node, channelName)
if view == nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/querycoordv2/dist/dist_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func (dh *distHandler) getDistribution(ctx context.Context) (*querypb.GetDataDis
defer dh.mu.Unlock()

channels := make(map[string]*msgpb.MsgPosition)
for _, channel := range dh.dist.ChannelDistManager.GetByNode(dh.nodeID) {
for _, channel := range dh.dist.ChannelDistManager.GetByFilter(meta.WithNodeID2Channel(dh.nodeID)) {
targetChannel := dh.target.GetDmChannel(channel.GetCollectionID(), channel.GetChannelName(), meta.CurrentTarget)
if targetChannel == nil {
continue
Expand Down
2 changes: 1 addition & 1 deletion internal/querycoordv2/job/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func waitCollectionReleased(dist *meta.DistributionManager, checkerController *c
return partitionSet.Contain(segment.GetPartitionID())
})
} else {
channels = dist.ChannelDistManager.GetByCollection(collection)
channels = dist.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(collection))
}

if len(channels)+len(segments) == 0 {
Expand Down
101 changes: 37 additions & 64 deletions internal/querycoordv2/meta/channel_dist_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,32 @@ import (
. "github.com/milvus-io/milvus/pkg/util/typeutil"
)

// ChannelDistFilter is a predicate over a DmChannel. GetByFilter keeps a
// channel only when every non-nil filter passed to it returns true.
type ChannelDistFilter = func(ch *DmChannel) bool

// WithCollectionID2Channel builds a ChannelDistFilter that accepts only
// channels whose collection ID equals collectionID.
func WithCollectionID2Channel(collectionID int64) ChannelDistFilter {
	inCollection := func(channel *DmChannel) bool {
		return collectionID == channel.GetCollectionID()
	}
	return inCollection
}

// WithNodeID2Channel builds a ChannelDistFilter that accepts only channels
// whose Node field equals nodeID.
func WithNodeID2Channel(nodeID int64) ChannelDistFilter {
	onNode := func(channel *DmChannel) bool {
		return nodeID == channel.Node
	}
	return onNode
}

// WithReplica2Channel builds a ChannelDistFilter that accepts a channel only
// when its collection ID equals the replica's collection ID and
// replica.Contains reports true for the channel's Node.
func WithReplica2Channel(replica *Replica) ChannelDistFilter {
	return func(channel *DmChannel) bool {
		// Compare collections first; Contains is only consulted on a match.
		if channel.GetCollectionID() != replica.GetCollectionID() {
			return false
		}
		return replica.Contains(channel.Node)
	}
}

// WithChannelName2Channel builds a ChannelDistFilter that accepts only
// channels whose channel name equals channelName.
func WithChannelName2Channel(channelName string) ChannelDistFilter {
	hasName := func(channel *DmChannel) bool {
		return channelName == channel.GetChannelName()
	}
	return hasName
}

type DmChannel struct {
*datapb.VchannelInfo
Node int64
Expand Down Expand Up @@ -58,33 +84,7 @@ func NewChannelDistManager() *ChannelDistManager {
}
}

// GetByNode returns the channel list stored for nodeID. It holds the
// manager's read lock for the duration of the lookup and delegates the
// lookup itself to getByNode.
func (m *ChannelDistManager) GetByNode(nodeID UniqueID) []*DmChannel {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()

return m.getByNode(nodeID)
}

// getByNode looks up nodeID in m.channels and returns the stored slice, or
// nil when the node has no entry. It acquires no lock itself.
func (m *ChannelDistManager) getByNode(nodeID UniqueID) []*DmChannel {
	if nodeChannels, found := m.channels[nodeID]; found {
		return nodeChannels
	}
	return nil
}

// GetAll returns a newly allocated slice containing the channels of every
// node in m.channels, collected under the read lock. The result is non-nil
// even when there are no channels; ordering across nodes follows map
// iteration and is therefore unspecified.
func (m *ChannelDistManager) GetAll() []*DmChannel {
	m.rwmutex.RLock()
	defer m.rwmutex.RUnlock()

	all := make([]*DmChannel, 0)
	for nodeID := range m.channels {
		all = append(all, m.channels[nodeID]...)
	}
	return all
}

// TODO(liuwei): handle the case where duplicate leaders exist.
// GetShardLeader returns the node within the given replicaNodes that subscribes to the given shard,
// returns (0, false) if not found.
func (m *ChannelDistManager) GetShardLeader(replica *Replica, shard string) (int64, bool) {
Expand All @@ -103,6 +103,7 @@ func (m *ChannelDistManager) GetShardLeader(replica *Replica, shard string) (int
return 0, false
}

// TODO(liuwei): handle the case where duplicate leaders exist.
func (m *ChannelDistManager) GetShardLeadersByReplica(replica *Replica) map[string]int64 {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
Expand All @@ -119,56 +120,28 @@ func (m *ChannelDistManager) GetShardLeadersByReplica(replica *Replica) map[stri
return ret
}

// GetChannelDistByReplica maps each channel name to the replica nodes that
// hold a channel of that name. Only nodes returned by replica.GetNodes are
// inspected, and only channels whose collection ID equals the replica's
// collection ID are counted. The read lock is held for the whole scan.
func (m *ChannelDistManager) GetChannelDistByReplica(replica *Replica) map[string][]int64 {
	m.rwmutex.RLock()
	defer m.rwmutex.RUnlock()

	nodesByChannel := make(map[string][]int64)
	for _, nodeID := range replica.GetNodes() {
		for _, channel := range m.channels[nodeID] {
			if channel.GetCollectionID() != replica.GetCollectionID() {
				continue
			}
			// Appending to the missing (nil) entry creates the list on first use.
			name := channel.GetChannelName()
			nodesByChannel[name] = append(nodesByChannel[name], nodeID)
		}
	}
	return nodesByChannel
}

func (m *ChannelDistManager) GetByCollection(collectionID UniqueID) []*DmChannel {
// GetByFilter returns all channels that match every given filter; nil filters are skipped.
func (m *ChannelDistManager) GetByFilter(filters ...ChannelDistFilter) []*DmChannel {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()

ret := make([]*DmChannel, 0)
for _, channels := range m.channels {
for _, channel := range channels {
if channel.CollectionID == collectionID {
allMatch := true
for _, fn := range filters {
if fn != nil && !fn(channel) {
allMatch = false
}
}
if allMatch {
ret = append(ret, channel)
}
}
}
return ret
}

// GetByCollectionAndNode returns the channels stored for nodeID whose
// CollectionID equals collectionID. The lookup runs under the read lock and
// the result is always a non-nil slice, empty when nothing matches.
func (m *ChannelDistManager) GetByCollectionAndNode(collectionID, nodeID UniqueID) []*DmChannel {
	m.rwmutex.RLock()
	defer m.rwmutex.RUnlock()

	matched := make([]*DmChannel, 0)
	candidates := m.getByNode(nodeID)
	for i := range candidates {
		if candidates[i].CollectionID != collectionID {
			continue
		}
		matched = append(matched, candidates[i])
	}
	return matched
}

func (m *ChannelDistManager) Update(nodeID UniqueID, channels ...*DmChannel) {
m.rwmutex.Lock()
defer m.rwmutex.Unlock()
Expand Down
55 changes: 7 additions & 48 deletions internal/querycoordv2/meta/channel_dist_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,36 +66,36 @@ func (suite *ChannelDistManagerSuite) TestGetBy() {
dist := suite.dist

// Test GetAll
channels := dist.GetAll()
channels := dist.GetByFilter(nil)
suite.Len(channels, 4)

// Test GetByNode
for _, node := range suite.nodes {
channels := dist.GetByNode(node)
channels := dist.GetByFilter(WithNodeID2Channel(node))
suite.AssertNode(channels, node)
}

// Test GetByCollection
channels = dist.GetByCollection(suite.collection)
channels = dist.GetByFilter(WithCollectionID2Channel(suite.collection))
suite.Len(channels, 4)
suite.AssertCollection(channels, suite.collection)
channels = dist.GetByCollection(-1)
channels = dist.GetByFilter(WithCollectionID2Channel(-1))
suite.Len(channels, 0)

// Test GetByNodeAndCollection
// 1. Valid node and valid collection
for _, node := range suite.nodes {
channels := dist.GetByCollectionAndNode(suite.collection, node)
channels := dist.GetByFilter(WithCollectionID2Channel(suite.collection), WithNodeID2Channel(node))
suite.AssertNode(channels, node)
suite.AssertCollection(channels, suite.collection)
}

// 2. Valid node and invalid collection
channels = dist.GetByCollectionAndNode(-1, suite.nodes[1])
channels = dist.GetByFilter(WithCollectionID2Channel(-1), WithNodeID2Channel(suite.nodes[1]))
suite.Len(channels, 0)

// 3. Invalid node and valid collection
channels = dist.GetByCollectionAndNode(suite.collection, -1)
channels = dist.GetByFilter(WithCollectionID2Channel(suite.collection), WithNodeID2Channel(-1))
suite.Len(channels, 0)
}

Expand Down Expand Up @@ -148,47 +148,6 @@ func (suite *ChannelDistManagerSuite) TestGetShardLeader() {
suite.Equal(leaders["dmc1"], suite.nodes[1])
}

// TestGetChannelDistByReplica registers "test-channel1" on nodes 11 and 22
// and "test-channel2" on node 33, then checks that GetChannelDistByReplica
// reports two nodes for the first channel and one for the second.
func (suite *ChannelDistManagerSuite) TestGetChannelDistByReplica() {
	// newChannel builds a version-1 channel of the suite's collection that
	// is placed on the given node.
	newChannel := func(name string, node int64) *DmChannel {
		return &DmChannel{
			VchannelInfo: &datapb.VchannelInfo{
				CollectionID: suite.collection,
				ChannelName:  name,
			},
			Node:    node,
			Version: 1,
		}
	}

	replica := NewReplica(
		&querypb.Replica{
			CollectionID: suite.collection,
		},
		typeutil.NewUniqueSet(11, 22, 33),
	)

	suite.dist.Update(11, newChannel("test-channel1", 11))
	suite.dist.Update(22, newChannel("test-channel1", 22))
	suite.dist.Update(33, newChannel("test-channel2", 33))

	channelNodes := suite.dist.GetChannelDistByReplica(replica)
	suite.Len(channelNodes["test-channel1"], 2)
	suite.Len(channelNodes["test-channel2"], 1)
}

func (suite *ChannelDistManagerSuite) AssertNames(channels []*DmChannel, names ...string) bool {
for _, channel := range channels {
hasChannel := false
Expand Down
2 changes: 1 addition & 1 deletion internal/querycoordv2/observers/replica_observer.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ func (ob *ReplicaObserver) checkNodesInReplica() {
)

for node := range outboundNodes {
channels := ob.distMgr.ChannelDistManager.GetByCollectionAndNode(collectionID, node)
segments := ob.distMgr.SegmentDistManager.GetByFilter(meta.WithCollectionID(collectionID), meta.WithNodeID(node))
channels := ob.distMgr.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(node))

if len(channels) == 0 && len(segments) == 0 {
replica.RemoveNode(node)
Expand Down

0 comments on commit eaa7298

Please sign in to comment.