In [None]:
def parent(idx):
    """Return ``idx`` with its lowest set bit cleared.

    The rest of this file treats the result as the parent of ``idx`` in the
    summation tree. ``parent(0)`` is ``0``.
    """
    # idx - 1 flips the lowest set bit and every bit below it, so AND-ing
    # with idx keeps only the bits above the lowest set one.
    below = idx - 1
    return idx & below

def largestChild(idx):
    """Return ``idx`` with every bit below its lowest set bit switched on.

    For ``idx > 0`` this is the largest index of the subtree rooted at
    ``idx`` (following ``parent`` links). For ``idx == 0`` the result is ``-1``.
    """
    # idx - 1 has all bits below the lowest set bit of idx turned on; OR-ing
    # merges them into idx.
    below = idx - 1
    return idx | below

def calcStartIndices(n_summands):
    """Return the first global index of every rank.

    Args:
        n_summands: Iterable with the number of summands held by each rank.

    Returns:
        A list of the same length whose i-th entry is the sum of all counts
        before position i (an exclusive prefix sum, so the first entry is 0).
    """
    def _running_offsets():
        offset = 0
        for count in n_summands:
            # Emit the offset *before* adding this rank's own count.
            yield offset
            offset += count

    return list(_running_offsets())

def rankFromIndex(idx, startIndices):
    """Return the rank that owns the global index ``idx``.

    Args:
        idx: Global index to look up.
        startIndices: Start index of every rank in rank order, as produced by
            ``calcStartIndices``. Must be non-decreasing.

    Returns:
        The position of the first start index greater than ``idx``, minus one.
        That is the last rank starting at or before ``idx``: ``-1`` when
        ``idx`` precedes every start index, and ``len(startIndices) - 1``
        when no start index exceeds ``idx``.
    """
    # Function-scope import keeps this definition self-contained.
    from bisect import bisect_right

    # bisect_right gives the position of the first start index > idx, which
    # is what the previous linear scan computed, without allocating a list
    # of (rank, offset) tuples on every call.
    return bisect_right(startIndices, idx) - 1


"""
Given a list with the number of summands assigned to each rank, determine how many partial sums must be sent out to another rank.
"""
def rankIntersectingNumbers(n_summands, startIndices):
    """Return every index whose parent index is owned by a different rank.

    Args:
        n_summands: Number of summands per rank; only the total is used here.
        startIndices: Start index of every rank, passed to ``rankFromIndex``.

    Returns:
        Ascending list of the indices in ``range(sum(n_summands))`` for which
        ``rankFromIndex`` gives different ranks for the index and its parent.
    """
    total = sum(n_summands)
    crossing = []
    for idx in range(total):
        own_rank = rankFromIndex(idx, startIndices)
        parent_rank = rankFromIndex(parent(idx), startIndices)
        if own_rank != parent_rank:
            crossing.append(idx)
    return crossing

def rankIntersectionCount(n_summands, startIndices):
    """Count the indices whose parent lies before the start of their own rank.

    Args:
        n_summands: Number of summands held by each rank.
        startIndices: Start index of each rank, paired position-wise with
            ``n_summands`` (extra entries of the longer one are ignored).

    Returns:
        The total, over all ranks, of indices ``idx`` in
        ``[start, start + count)`` with ``parent(idx) < start``.
    """
    # The previous version folded the flags with reduce(), which is not a
    # builtin in Python 3 and was never imported, so calling it raised
    # NameError. Summing over a generator needs no import.
    return sum(
        1
        for start, count in zip(startIndices, n_summands)
        for idx in range(start, start + count)
        if parent(idx) < start
    )

In [None]:
# Named problem sizes. A value from this table is passed to messageBundling
# as `totalParts`, i.e. the total number of summands to distribute.
# NOTE(review): the keys look like dataset names — confirm where the sizes
# come from.
dataSize = {
    "multi100": 767,
    "dna_rokasD1": 1327505,
    "aa_rokasA8": 504850,
    "fusob": 1602,
    "dna_PeteD8": 3011099,
    "dna_rokasD4": 239763,
    "aa_rokasA4": 1806035,
    "354": 460,
    "prim": 898
}
# Number of ranks the summands are split across when calling messageBundling.
ranks = 80

In [None]:
def messageBundling(ranks, totalParts):
    """Print every rank-crossing index for an even split of the summands.

    ``totalParts`` summands are distributed over ``ranks`` ranks as evenly as
    possible, and the resulting per-rank counts are printed. Then one line is
    printed per index reported by ``rankIntersectingNumbers``: the owning
    rank, the index, the rank owning its parent, and the effort
    ``largestChild(idx) + 1 - idx``. A separator line of 80 ``=`` characters
    is printed whenever the owning rank changes.

    Args:
        ranks: Number of ranks to distribute over.
        totalParts: Total number of summands.
    """
    # The first `remainder` ranks receive one summand more than the others.
    nPerRank, remainder = divmod(totalParts, ranks)
    n_summands = [nPerRank + 1] * remainder + [nPerRank] * (ranks - remainder)
    assert len(n_summands) == ranks
    assert sum(n_summands) == totalParts
    print(n_summands)

    startIndices = calcStartIndices(n_summands)

    prev_rank = 0
    for ri in rankIntersectingNumbers(n_summands, startIndices):
        rank = rankFromIndex(ri, startIndices)
        if prev_rank != rank:
            # Visual break between the output of two different ranks.
            print("=" * 80)
            prev_rank = rank
        target = rankFromIndex(parent(ri), startIndices)
        # Number of indices in the inclusive range [ri, largestChild(ri)].
        computationalEffort = largestChild(ri) - ri + 1
        print(f"Rank {rank:3}: {ri:5}, target: {target:3}, effort: {computationalEffort:5}")
        

In [None]:
# Print the rank-crossing indices for the "dna_rokasD4" size split over `ranks` ranks.
messageBundling(ranks, dataSize["dna_rokasD4"])

In [None]:
from math import log2



def plot_tree(n_summands):
    """Render the index tree for ``sum(n_summands)`` elements as an SVG.

    Every index ``0 .. total - 1`` is drawn as a label on the bottom row.
    Index 0 gets a vertical line to the top edge. Every other index gets a
    vertical segment up to its level, followed by a horizontal segment to the
    x position of ``parent(idx)``. The level is derived from the number of
    trailing zero bits of the index: more trailing zeros means further up.

    The markup is written to ``test.svg`` in the current working directory
    and displayed inline through IPython.

    Args:
        n_summands: Iterable of per-rank summand counts; only the total is used.

    Raises:
        ValueError: If the counts do not sum to a positive number.
    """
    total_parts = sum(n_summands)
    if total_parts <= 0:
        # log2() below would otherwise fail with an opaque "math domain error".
        raise ValueError("plot_tree needs at least one summand to draw")

    max_digits = len(str(total_parts - 1))
    distance = 40  # horizontal spacing between neighbouring indices, in px
    font_height = 15
    tree_height = log2(total_parts) + 1  # float, generally not an integer
    maxX = total_parts * distance
    maxY = tree_height * 50

    def trailing_zeros(x):
        # Only called with x >= 1: x & -x isolates the lowest set bit, whose
        # bit position equals the number of trailing zeros.
        return (x & -x).bit_length() - 1

    # get the x-coordinate for a given index
    def idx2coordinate(idx):
        return idx * distance

    # Collect the fragments and join once instead of growing a string with +=.
    fragments = [
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="-1 0 {maxX} {maxY+2*font_height}">'
    ]
    # NOTE(review): SVG <text> is aligned with `text-anchor`, not the CSS
    # `text-align` property — confirm whether centred labels were intended.
    fragments.append("""
        <style>
            text {
                font-size: 15px;
                text-align: center;
            }
        </style>
        """)

    # Both values are the same for every index, so compute them once.
    text_y = maxY + font_height
    center_y = text_y - font_height

    for idx in range(total_parts):
        text_x = idx2coordinate(idx)
        fragments.append(f'<text x="{text_x}" y="{text_y}">{idx:{max_digits}}</text>')

        center_x = text_x
        if idx == 0:
            # straight line all the way up
            fragments.append(
                f'<line x1="{center_x}" y1="{center_y}" x2="{center_x}" y2="0" stroke="black" />'
            )
        else:
            target_x = idx2coordinate(parent(idx))
            # Fraction of the total height at which this index joins its parent.
            level = (tree_height - trailing_zeros(idx) - 1) / tree_height
            target_y = level * maxY

            fragments.append(f"""
                <path stroke="black" fill="none"
                d="M {center_x},{center_y} L {center_x},{target_y} L {target_x},{target_y}" /> 
            """)

    fragments.append("</svg>")
    svg = "".join(fragments)

    # Function-scope import, so IPython is only required when plotting.
    from IPython.display import HTML, display

    with open("test.svg", "w", encoding="utf-8") as f:
        f.write(svg)
    display(HTML(svg))

#plot_tree([1,1,1,1,1,1,1,1,1,1,1])
#plot_tree([21, 21, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20])
#plot_tree([2998, 2998, 2998, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997, 2997])