In [14]:
from functools import reduce
from collections import deque


def parent(idx):
    return idx & (idx - 1)

def calcStartIndices(n_summands):
    idx = 0
    result = list()
    for n in n_summands:
        result.append(idx)
        idx += n
    return result

def rankFromIndex(idx, startIndices):
    guard = [(len(startIndices), 1)]
    offset = [(rank, startIdx - idx) for rank, startIdx in enumerate(startIndices)] + guard
    # find first positive valued offset
    offset_is_positive = lambda rank_offset_tuple: rank_offset_tuple[1] > 0
    rank = next(filter(offset_is_positive, offset))[0]
    return rank - 1


"""
Given a list of summands that are assigned to each rank, how many partial sums must be sent out to another rank.
"""
def rankIntersectingNumbers(n_summands, startIndices):
    n_parts = sum(n_summands)
    is_rank_intersecting = lambda idx: rankFromIndex(idx, startIndices) != rankFromIndex(parent(idx), startIndices)
    
    rank_intersecting_summands = filter(is_rank_intersecting, range(0, n_parts))
    return list(rank_intersecting_summands)

def rankIntersectionCount(n_summands, startIndices):
    is_rank_intersecting = lambda startIndex, idx: parent(idx) < startIndex
    # calculate a boolean list of rank intersecting numbers for each rank
    rankIntersectingFlagsPerRank = map(lambda t:
           [is_rank_intersecting(t[0], i) for i in range(t[0], t[0] + t[1])], zip(startIndices, n_summands))
    
    countTrueValues = lambda x: reduce(lambda acc, flag: acc + 1 if flag else acc, x, 0)
    return sum(map(countTrueValues, rankIntersectingFlagsPerRank))


def subtree_size(i):
    largest_child_idx = i | (i - 1)
    return largest_child_idx + 1 - i

"""
Determine the number of messages in O(n) where n is the number of messages.

Arguments:
  - ns must be a list with the number of values assigned to each rank.
    Each rank must be assigned at least 1 number!
"""
def message_count(ns):
    assert(min(ns) > 0)

    # No communication needed if there is only one rank
    if len(ns) <= 1:
        return 0

    # calculate the startIndices for each rank, add a guard element at the end.
    startIndices = deque([sum(ns[:i]) for i in range(len(ns)+1)])
    guard = startIndices[-1]
    
    # We can omit the first rank, since it does not need to send its results anywhere
    startIndices.popleft()
    
    idx = startIndices[0]
    message_count = 0
    
    while idx < guard:
        # Fetch the limits for the current rank
        begin = startIndices[0]
        end = startIndices[1]
        
        # Through the subtree_iteration we can assume that each index update yields
        # another rank-intersecting index.
        assert(parent(idx) < begin)
        message_count += 1
        
        # Update the index. If the whole subtree is local we can skip over it completely,
        # since its result will be sent in only one message. If the subtree is split over
        # one or more ranks, we update the index to the start index of the next rank.
        idx = min(idx + subtree_size(idx), end)
        
        if idx == end:
            startIndices.popleft()
    
    return message_count
    
    
# Calculate an ns array which sums n parts over m ranks.
def even_distribution(n, m):
    partsPerRank = int(n // m)
    remainder = int(n % m)
    
    return [partsPerRank + 1] * remainder + [partsPerRank] * (m - remainder)

In [17]:
import numpy as np
import random
from tqdm import tqdm

for m, n in tqdm([(random.randint(128, 1024), random.randint(1024, 2**10)) for _ in (range(5_000))]):
    ns = even_distribution(n, m)
    assert(rankIntersectionCount(ns, calcStartIndices(ns))) == message_count(ns)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5000/5000 [00:21<00:00, 229.19it/s]
