In [None]:
%matplotlib inline
%reset -f
import dataclasses
import random
from dataclasses import dataclass

import matplotlib.pyplot as plt
import numpy as np

In [None]:
@dataclass()
class Patch:
    """One unit of load that is owned by exactly one node.

    Constructor arguments, in order: ``id``, ``data_count``, ``owner_node``,
    ``_1d_pos``.
    """

    id: int  # patch identifier (gen_patch_list numbers patches 0..num-1)
    data_count: int  # load carried by this patch; summed per node to get node load
    owner_node: int  # index of the owning node; used to index per-node lists
    _1d_pos: float  # presumably a 1-D coordinate (name-based; not read anywhere in this notebook)

In [None]:
def sort_patch_list_in_node(world_size,ptch_list):
    """Group patch objects by the node that owns them.

    Returns a list of ``world_size`` lists; list ``k`` holds, in input order,
    the patches whose ``owner_node`` indexes slot ``k``.
    """
    buckets = []
    for _ in range(world_size):
        buckets.append([])

    for patch in ptch_list:
        buckets[patch.owner_node].append(patch)

    return buckets

def sort_patch_list_datacount(world_size,ptch_list):
    """Group patch data counts by owning node.

    Returns a list of ``world_size`` lists; list ``k`` holds, in input order,
    the ``data_count`` of every patch whose ``owner_node`` indexes slot ``k``.
    Rows may have different lengths (see ``fill_with_0``).
    """
    counts_per_node = [list() for _ in range(world_size)]

    for patch in ptch_list:
        node_counts = counts_per_node[patch.owner_node]
        node_counts.append(patch.data_count)

    return counts_per_node

def fill_with_0(world_size,ptch_lst_node):
    """Return a rectangular copy of a ragged per-node table, right-padded with 0.

    Parameters
    ----------
    world_size : int
        Unused; kept so the signature matches the other per-node helpers.
    ptch_lst_node : list of lists
        One row per node (e.g. from ``sort_patch_list_datacount``).

    Returns
    -------
    list of lists
        New rows, each padded with ``0`` to the length of the longest row.
        The input rows are left untouched (the previous shallow ``.copy()``
        padded the caller's inner lists in place). An empty input gives ``[]``.
    """
    width = max((len(row) for row in ptch_lst_node), default=0)
    return [list(row) + [0] * (width - len(row)) for row in ptch_lst_node]

In [None]:
def plot_patch_list(world_size,ptch_list):
    """Show a stacked bar chart of the load on each node.

    Parameters
    ----------
    world_size : int
        Number of nodes; only rows ``0 .. world_size-1`` of ``ptch_list`` are drawn.
    ptch_list : list of lists of numbers
        Rectangular per-node table (pad ragged rows with ``fill_with_0``);
        entry ``[node][k]`` is the load of the k-th patch on ``node``.
        Each column ``k`` is drawn as one stacked layer.
    """
    # shape (world_size, patches per node); float so non-integer loads stack too
    # (the old int accumulator raised a casting error on float loads)
    loads = np.asarray([ptch_list[node] for node in range(world_size)], dtype=float)

    fig, ax = plt.subplots()
    node_ids = np.arange(world_size)
    bottom = np.zeros(world_size)

    for layer in loads.T:
        ax.bar(node_ids, layer, bottom=bottom)
        bottom = bottom + layer

    ax.set_xlabel("nodes id")
    ax.set_ylabel("load")

    plt.show()
    

In [None]:
def balance_load(world_size,ptch_list):
    """Reassign patches to nodes, in list order, so each node gets ~an equal share of load.

    Patches are walked in the given order while a running sum of
    ``data_count`` is kept. Once that sum exceeds
    ``(node + 1) * total / world_size``, the following patches go to the next
    node. The patch that crosses the threshold stays on the current node.

    Parameters
    ----------
    world_size : int
        Number of nodes (must be > 0).
    ptch_list : list of dataclass instances with ``data_count`` / ``owner_node``
        Patches to place. They are NOT modified.

    Returns
    -------
    list
        New patch objects (copies via ``dataclasses.replace``) in input order,
        each with ``owner_node`` set to its node in ``[0, world_size - 1]``.

    Raises
    ------
    ZeroDivisionError
        If ``world_size`` is 0 (with plain int/float data counts).
    """
    target_datacnt = sum(patch.data_count for patch in ptch_list) / world_size

    new_patch_list = []
    current_dtcnt = 0
    current_node = 0

    for patch in ptch_list:
        # Copy instead of mutating: the caller's list (e.g. the initial
        # distribution) must stay intact so it can be re-tabulated/re-plotted.
        new_patch_list.append(dataclasses.replace(patch, owner_node=current_node))

        current_dtcnt += patch.data_count

        # Never step past the last node: rounding in (current_node + 1) * target_datacnt
        # could otherwise yield index world_size, out of range for per-node tables.
        if current_node < world_size - 1 and current_dtcnt > (current_node + 1) * target_datacnt:
            current_node += 1

    return new_patch_list

In [None]:
def gen_patch_list(world_size,data_count_max,num,rng=None):
    """Create ``num`` patches with random owners and random loads.

    Parameters
    ----------
    world_size : int
        Owners are drawn uniformly from ``0 .. world_size - 1``.
    data_count_max : int
        Loads are drawn uniformly from ``1 .. data_count_max`` (must be >= 1).
    num : int
        Number of patches; ids and ``_1d_pos`` are ``0 .. num - 1``.
    rng : random.Random, optional
        Source of randomness. Defaults to the global ``random`` module
        (previous behavior); pass ``random.Random(seed)`` for reproducible trials.

    Returns
    -------
    list of Patch
    """
    if rng is None:
        rng = random

    # owner_node is drawn before data_count, matching the original call order
    return [
        Patch(
            id=i,
            owner_node=rng.randint(0, world_size - 1),
            data_count=rng.randint(1, data_count_max),
            _1d_pos=i,
        )
        for i in range(num)
    ]

In [None]:
def compare_distrib(init,final):
    """Summarise per-node load before and after balancing.

    ``init`` and ``final`` are rectangular per-node tables (one row per node,
    e.g. from ``fill_with_0``); each row is summed to get that node's load.

    Returns ``(mean_init, std_init, mean_final, std_final)`` of the node loads,
    using ``np.std``'s default population standard deviation.
    """
    def _load_stats(table):
        node_loads = np.asarray(table).sum(axis=1)
        return node_loads.mean(), node_loads.std()

    mean_init, std_init = _load_stats(init)
    mean_final, std_final = _load_stats(final)
    return mean_init, std_init, mean_final, std_final

In [None]:
def metric_load_balancing(world_size,load_curve):
    """Load-balancing efficiency: ideal per-node load divided by the largest node load.

    ``load_curve`` holds one load value per node. The ideal load is the total
    spread evenly over ``world_size`` nodes. The result is 1.0 for a perfectly
    even split and drops toward 0 as one node becomes the bottleneck.
    """
    ideal_load = np.sum(load_curve) / world_size
    bottleneck_load = np.max(load_curve)
    return ideal_load / bottleneck_load

In [None]:
def get_stddev_sample(world_size,data_count_max,patch_cnt):
    """Run one random balancing trial and return its efficiency.

    Despite the name, the return value is ``metric_load_balancing`` (ideal
    per-node load / largest node load) after ``balance_load``, not a standard
    deviation.

    Parameters
    ----------
    world_size : int
        Number of nodes.
    data_count_max : int
        Upper bound of each patch's random ``data_count``. This is now actually
        forwarded; it used to be ignored in favor of a hard-coded 4.
    patch_cnt : int
        Number of patches generated for the trial.
    """
    patch_list = gen_patch_list(world_size, data_count_max, patch_cnt)
    balanced = balance_load(world_size, patch_list)
    per_node = fill_with_0(world_size, sort_patch_list_datacount(world_size, balanced))

    load_curve = np.sum(per_node, axis=1)

    return metric_load_balancing(world_size, load_curve)

def get_stddev_map(data_count_max,world_size_arr,patch_cnt_arr,sample):
    """Average ``get_stddev_sample`` over ``sample`` trials on a parameter grid.

    Rows follow ``world_size_arr`` and columns follow ``patch_cnt_arr``. Each
    entry is the mean of ``sample`` independent trials (a load-balancing
    efficiency, not a standard deviation, despite the name).
    """
    def one_trial_grid():
        return np.array([
            [get_stddev_sample(ws, data_count_max, pc) for pc in patch_cnt_arr]
            for ws in world_size_arr
        ])

    grid = np.array([[0. for _ in patch_cnt_arr] for _ in world_size_arr])

    for _ in range(sample):
        # accumulate each trial pre-divided so the sum is the mean
        grid += one_trial_grid() / sample

    return grid




In [None]:
# Build a random initial distribution of patches over the nodes.
# Seed the global RNG so this cell and the ones below are reproducible
# under Restart & Run All.
random.seed(0)

world_size = 100
patch_list = gen_patch_list(world_size, 4, 1000)  # data counts in 1..4, 1000 patches

initial_bar_plot = fill_with_0(world_size, sort_patch_list_datacount(world_size, patch_list))

In [None]:
# Reassign patches with balance_load, then tabulate the per-node data counts.
final_bar_plot = fill_with_0(
    world_size,
    sort_patch_list_datacount(
        world_size,
        balance_load(world_size, patch_list),
    ),
)

In [None]:
# Total load per node after balancing (row sums of the padded table).
load_curve = np.asarray(final_bar_plot).sum(axis=1)

In [None]:
# Stacked per-node load before balancing.
plot_patch_list(world_size=world_size, ptch_list=initial_bar_plot)

In [None]:
# Stacked per-node load after balancing.
plot_patch_list(world_size=world_size, ptch_list=final_bar_plot)

In [None]:
# Label the summary so the output is readable without knowing compare_distrib's tuple order.
mean_init, std_init, mean_final, std_final = compare_distrib(initial_bar_plot, final_bar_plot)
print(f"per-node load before balancing: mean={mean_init:.3f}, std={std_init:.3f}")
print(f"per-node load after balancing:  mean={mean_final:.3f}, std={std_final:.3f}")
print(f"load-balancing efficiency (ideal / max node load): {metric_load_balancing(world_size, load_curve):.4f}")

In [None]:
# Sweep world size (2..99) and patch count (2..992, step 10); each grid entry
# is the efficiency averaged over `sample` trials (a single trial here).
world_size_arr = list(range(2, 100))
patch_cnt_arr  = list(range(2, 1000, 10))

map = get_stddev_map(
    data_count_max=4,
    world_size_arr=world_size_arr,
    patch_cnt_arr=patch_cnt_arr,
    sample=1,
)


In [None]:
def fmt(x):
    """Format a contour level to two decimals, space-padded for inline contour labels.

    NOTE(review): the usetex branch pads with three leading spaces and the
    plain branch with two; this looks accidental but is preserved as-is.
    """
    label = f"{x:.2f}"
    if plt.rcParams["text.usetex"]:
        return "   " + label + "   "
    return "  " + label + "   "



extent = (min(patch_cnt_arr), max(patch_cnt_arr), min(world_size_arr), max(world_size_arr))

fig, ax = plt.subplots()

# Row 0 of `map` is world_size_arr[0] (the smallest world size). origin="lower"
# puts it at the bottom of the y axis. The default origin="upper" drew it at the
# top, so the image was vertically flipped against its own tick labels.
im = ax.imshow(map,
    extent = extent,
    origin = "lower",
    aspect = "auto",
    cmap = "gist_ncar",
    vmin=0,
    vmax=1)

fig.colorbar(im, ax=ax, label="ideal / max node load")

levels = [0.5,0.75,0.9,0.95,0.99]

# Same origin/extent as imshow so the contour lines sit on top of the matching pixels.
CS = ax.contour(map, levels, extent=extent, origin="lower", colors="k")

ax.clabel(CS, CS.levels, inline=True, fmt=fmt, fontsize=10)

ax.set_xlabel("patch count")
ax.set_ylabel("world size")
ax.set_title("load-balancing efficiency")

plt.show()