In [259]:
from ACLScheduleTest.Simulate import Simulator
from ACLScheduleTest.Analysis import Analyst
from ACLScheduleTest.data import RandomDsk
import numpy as np
from time import time

rd = RandomDsk(
    mute=True, node_num=100,
    st_node_num=10,
    min_time=1, time_lv=0, time_rv=0, min_time_step=0,
    avg_mem_peek=1, mem_pv=0, mem_sv=1
)
dsk = rd.build()
sim = Simulator()

In [260]:
num_workers, num_nodes = 10, 100

information = np.ones((num_nodes, num_workers))
def topoOrder(workers, cache, av_mem, factor_node_dict, calculated_node_dict, dsk, global_dict):
    topoQueue = []
    matrix = np.zeros((num_nodes, num_nodes))
    nodeList = list(dsk.keys())
    for idx, node in enumerate(nodeList):
        for n in dsk[node][1:]:
            if n in nodeList:
                id = nodeList.index(n)
                matrix[idx, id] = 1 # idx depends on id
    while len(topoQueue) < len(nodeList):
        for idx, node in enumerate(nodeList):
            if node in topoQueue:
                continue
            if matrix[idx].sum() == 0:
                topoQueue.append(node)
                matrix[:, idx] = 0
    return topoQueue

import math
def reward(x):
    baseline = 5
    return math.exp(x - baseline)

## Ant as Agent

In [261]:
current_worker_id = 0
iter = 0
from schedule_func_backup import schedule_out_func
def search(dsk):
    global current_worker_id, iter
    current_worker_id, iter = 0, 0
    worker_schedules = []
    worker_traces = np.zeros((num_nodes, num_workers))
    for idx in range(num_workers):
        worker_schedule = [idx]
        worker_schedules.append(worker_schedule)
        worker_traces[idx, idx] = 1
    for node_id in range(num_workers, num_nodes):
        prob = information[node_id] / information[node_id].sum()
        select_worker = np.random.choice(range(num_workers), p=prob)
        worker_schedules[select_worker].append(node_id)
        worker_traces[node_id, select_worker] = 1
    #print(worker_schedules)
    def search_n_func(workers, cache, av_mem, factor_node_dict, calculated_node_dict, dsk, global_dict):
        global current_worker_id, iter
        ori_id = current_worker_id

        current_worker_id += 1
        if current_worker_id >= num_workers:
            current_worker_id = 0
            iter += 1
            #print(f'\niter {iter}')
        #print(current_worker_id)
        if av_mem == 0:
            return None
        if len(worker_schedules[ori_id])==0:
            return None
        node_id = worker_schedules[ori_id][0]
        node = f'node{node_id}'
        nodeList = set(dsk.keys())
        #print(node)

        pre_nodes = [arg for arg in dsk[node][1:] if arg in nodeList]
        #print(pre_nodes)
        for arg in pre_nodes:
            if arg in nodeList and arg not in cache.keys():
                #print('None\n', '-'*100)
                return None
        worker_schedules[ori_id].remove(node_id)
        #print(node, '\n', '-'*100)
        return node

    result = sim.simulate(search_n_func, schedule_out_func, mem_bound=50, worker_bound=10,dsk=dsk,rd=rd, compare_mode=True)
    analyst = Analyst()
    sucess_rate, speed_up = analyst.run(result=[[result]])
    speed_up = speed_up.values.item()
    rewarded = np.array([reward(speed_up)]) * sucess_rate
    return speed_up, rewarded * worker_traces


## Train

In [262]:
rd = RandomDsk(mute=True)
def train(num_epochs, num_agent, alpha):
    global  information
    for epoch in range(num_epochs):
        start_time = time()
        new_information = np.zeros((num_nodes, num_workers))
        speed_up_total = 0
        for agent_idx in range(num_agent):
            speed_up , rewarded = search(dsk)
            new_information += rewarded
            speed_up_total += speed_up
            information = alpha * information + new_information
        print(f'epoch {epoch+1}: speed_up = {speed_up_total / num_agent:.3f}. Use time: {time()-start_time}')
train(20, 50, 0.2)

epoch 1: speed_up = 5.274. Use time: 11.021054983139038
epoch 2: speed_up = 5.445. Use time: 10.648623943328857
epoch 3: speed_up = 5.626. Use time: 10.647620916366577
epoch 4: speed_up = 5.876. Use time: 9.904424905776978
epoch 5: speed_up = 5.882. Use time: 9.708353042602539
epoch 6: speed_up = 5.882. Use time: 10.11555004119873
epoch 7: speed_up = 5.882. Use time: 9.555318832397461
epoch 8: speed_up = 5.882. Use time: 9.5466628074646
epoch 9: speed_up = 5.882. Use time: 9.926537990570068
epoch 10: speed_up = 5.882. Use time: 9.721480369567871
epoch 11: speed_up = 5.882. Use time: 9.839972019195557
epoch 12: speed_up = 5.882. Use time: 9.602741956710815
epoch 13: speed_up = 5.882. Use time: 9.579833030700684
epoch 14: speed_up = 5.882. Use time: 9.581772089004517
epoch 15: speed_up = 5.882. Use time: 9.568605184555054
epoch 16: speed_up = 5.882. Use time: 9.672024965286255
epoch 17: speed_up = 5.882. Use time: 9.660619020462036
epoch 18: speed_up = 5.882. Use time: 9.7033531665802
ep

## schedule_in_func 

## schedule_out_func