# dType Data Balancing on Shards

In [21]:
import random

In [22]:
# Shards: how many exist, and the exclusive upper bound used when drawing
# each shard's random initial load
shard_count = 20
max_shard_load = 400

# dtypes: how many exist, and the exclusive upper bound used when drawing
# each dtype's random load
dtype_count = 100
max_dtype_load = 2000

# Multiplier applied to the mean per-shard load to obtain the balancing target.
# Increase average_coef if there are not enough shards for all dtypes
average_coef = 1.2

In [23]:
# Draw a random starting load for every shard, then for every dtype, and keep
# each one as an (index, load) pair.
# NOTE(review): the lower bound passed to randrange is the item's own index, so
# later items have a higher minimum load, and randrange raises ValueError once
# an index reaches the max_*_load bound — confirm this is intended.
shard_loads_initial = [
    (shard_idx, random.randrange(shard_idx, max_shard_load))
    for shard_idx in range(shard_count)
]
dtype_loads_initial = [
    (dtype_idx, random.randrange(dtype_idx, max_dtype_load))
    for dtype_idx in range(dtype_count)
]

# One (initially empty) list of assigned dtype indexes per shard
shards = [[] for _ in range(shard_count)]

print('shard_loads_initial', shard_loads_initial)
print('dtype_loads_initial', dtype_loads_initial)

shard_loads_initial [(0, 53), (1, 137), (2, 217), (3, 308), (4, 325), (5, 175), (6, 11), (7, 94), (8, 90), (9, 372), (10, 335), (11, 100), (12, 330), (13, 286), (14, 142), (15, 79), (16, 201), (17, 352), (18, 19), (19, 43)]
dtype_loads_initial [(0, 610), (1, 615), (2, 1642), (3, 134), (4, 1758), (5, 859), (6, 1788), (7, 1776), (8, 1181), (9, 702), (10, 83), (11, 828), (12, 1573), (13, 833), (14, 448), (15, 101), (16, 1699), (17, 1242), (18, 97), (19, 709), (20, 1404), (21, 1051), (22, 475), (23, 928), (24, 1975), (25, 719), (26, 643), (27, 1611), (28, 1323), (29, 1333), (30, 335), (31, 124), (32, 587), (33, 1775), (34, 690), (35, 213), (36, 1933), (37, 80), (38, 1897), (39, 318), (40, 949), (41, 860), (42, 1280), (43, 1289), (44, 698), (45, 1792), (46, 1182), (47, 649), (48, 252), (49, 1997), (50, 916), (51, 297), (52, 89), (53, 220), (54, 214), (55, 97), (56, 587), (57, 1989), (58, 426), (59, 536), (60, 1524), (61, 993), (62, 1236), (63, 1617), (64, 113), (65, 306), (66, 969), (67, 13

In [24]:
# Cursors into the sorted lists below: the "next" pair starts at the front,
# the "last" pair marks the final valid position.
next_index_s = next_index_dt = 0
last_index_s = len(shard_loads_initial) - 1
last_index_dt = len(dtype_loads_initial) - 1

# Order shards lightest-first and dtypes heaviest-first (element [1] is the load)
shard_loads = sorted(shard_loads_initial, key=lambda pair: pair[1])
dtype_loads = sorted(dtype_loads_initial, key=lambda pair: pair[1], reverse=True)

# Target load per shard: the combined dtype and shard load spread evenly over
# the shards, then scaled by average_coef
total_load = sum(load for _, load in dtype_loads) + sum(load for _, load in shard_loads)
average_load_shard = total_load / shard_count * average_coef
print('average_load_shard', average_load_shard)

# Give every dtype whose load reaches the per-shard target a shard of its own,
# starting from the lightest shard.
for i, dload in dtype_loads:
    if dload < average_load_shard:
        # dtype_loads is sorted heaviest-first, so no later dtype can qualify
        break
    if next_index_s > last_index_s:
        # Out of shards: stop instead of indexing past the end of `shards`
        # (this used to raise IndexError). next_index_dt is left pointing at
        # this unplaced dtype, so the pairing loop that follows reports
        # 'Needs more shards' on its first iteration.
        break
    shards[next_index_s].append(i)
    next_index_s += 1
    next_index_dt += 1


# Fill the remaining shards one at a time: each takes the heaviest dtype still
# unplaced, then is topped up from the light end of dtype_loads for as long as
# the per-shard target allows.
for heavy_id, heavy_load in dtype_loads[next_index_dt:]:
    if next_index_s > last_index_s:
        print('Needs more shards. Increase average_coef')
        break

    # shards[k] collects the dtypes for the shard described by shard_loads[k]
    bucket = shards[next_index_s]
    bucket.append(heavy_id)

    # Load already committed to this shard: its starting load plus the heavy dtype
    committed = shard_loads[next_index_s][1] + heavy_load

    # Peek at the lightest unplaced dtype and take it only if it still fits;
    # stop once the back cursor reaches the heavy dtype that was just placed
    while last_index_dt > next_index_dt:
        light_id, light_load = dtype_loads[last_index_dt]
        if committed + light_load > average_load_shard:
            break
        bucket.append(light_id)
        committed += light_load
        last_index_dt -= 1

    next_index_s += 1
    next_index_dt += 1
    if next_index_dt > last_index_dt:
        # Front and back cursors have crossed: every dtype has been placed
        break

print('(shard_index, shard_load, dtype_indexes)')

# shards[k] holds the dtypes given to the shard described by shard_loads[k],
# so the two lists are walked in step.
# The reported shard_load is the sum of the assigned dtypes' initial loads only;
# the shard's own starting load is not added to it.
final_shards = [
    (shard_id, sum(dtype_loads_initial[dtype_id][1] for dtype_id in assigned), assigned)
    for (shard_id, _), assigned in zip(shard_loads, shards)
]

print('final_shards', final_shards)

average_load_shard 5655.839999999999
(shard_index, shard_load, dtype_indexes)
final_shards [(6, 5619, [49, 37, 10, 52, 55, 18, 15, 64, 31, 3, 98, 35, 54, 53, 73, 87, 48, 70, 51, 84, 65]), (18, 5137, [57, 39, 86, 30, 72, 58, 97, 14, 22]), (19, 5419, [24, 59, 78, 75, 56, 32, 0]), (0, 5140, [36, 1, 88, 26, 47, 90]), (15, 5415, [38, 34, 44, 9, 19, 25]), (8, 4964, [45, 74, 99, 96, 11]), (7, 5232, [6, 13, 5, 41, 89]), (11, 5512, [7, 50, 23, 93, 40]), (1, 4695, [33, 66, 77, 83]), (14, 4807, [4, 61, 71, 85]), (5, 5092, [82, 21, 79, 8]), (16, 5272, [16, 69, 46, 81]), (2, 5430, [80, 62, 17, 92]), (13, 4230, [91, 42, 43]), (3, 4246, [2, 76, 95]), (4, 4249, [63, 67, 68]), (12, 4267, [27, 28, 29]), (10, 4345, [12, 94, 20]), (17, 1524, [60]), (9, 0, [])]
