In [1]:
import math
import time
import string
import random
import itertools
import requests

In [2]:
def find_combinations(param_grid):
    """Expand a parameter grid into every possible parameter assignment.

    Args:
        param_grid: dict mapping a parameter name to an iterable of
            candidate values.

    Returns:
        A list of dicts, one per element of the Cartesian product of the
        candidate values. Keys keep the order of ``param_grid`` and the
        combinations follow ``itertools.product`` order (the last key varies
        fastest).
    """
    names = list(param_grid.keys())
    value_lists = [param_grid[name] for name in names]
    return [dict(zip(names, choice)) for choice in itertools.product(*value_lists)]

In [3]:
def create_model_groups(models, workers):
    """Chunk ``models`` into tuples sized by the mean GPU count per worker.

    Args:
        models: sliceable sequence of model identifiers.
        workers: dict mapping a worker id to its list of GPUs.

    Returns:
        A list of tuples of consecutive models. Each tuple holds
        ``ceil(mean number of GPUs per worker)`` models, except possibly the
        last one, which holds the remainder.
    """
    total_gpus = sum(len(gpus) for gpus in workers.values())
    group_size = math.ceil(total_gpus / len(workers))
    starts = range(0, len(models), group_size)
    return [tuple(models[start:start + group_size]) for start in starts]

In [4]:
def get_runnable_model_group(worker, model_groups, model_group_on_worker, mgw_pairs):
    """Pick a model group that can be started on ``worker`` right now.

    A group qualifies when its ``mgw_pairs[group][worker]`` flag is falsy and
    it is not currently placed on any worker
    (``model_group_on_worker[group] == -1``).

    Note: ``model_groups`` is shuffled in place, so the caller's list order
    changes on every call.

    Args:
        worker: worker id, used as an index into each ``mgw_pairs`` entry.
        model_groups: list of model groups (used as dict keys).
        model_group_on_worker: dict model group -> worker id, or -1.
        mgw_pairs: dict model group -> per-worker sequence of flags.

    Returns:
        The first qualifying group in the shuffled order, or -1 if none
        qualifies.
    """
    # Shuffling first turns "first match" into a random pick among candidates.
    random.shuffle(model_groups)
    candidates = (
        group
        for group in model_groups
        if not mgw_pairs[group][worker] and model_group_on_worker[group] == -1
    )
    return next(candidates, -1)

In [5]:
def check_finished(worker, exec_id):
    """Report whether ``exec_id`` is listed in the worker's completion file.

    Reads ``check_finished_<worker>.txt`` from the current working directory
    and looks for a line exactly equal to ``exec_id``.

    Args:
        worker: worker id used to build the file name.
        exec_id: execution id to look for.

    Returns:
        True if one of the file's newline-separated lines equals ``exec_id``;
        False otherwise, including when the file does not exist yet.
    """
    path = "check_finished_" + str(worker) + ".txt"
    try:
        with open(path, "r") as f:
            finished_ids = f.read().split("\n")
    except FileNotFoundError:
        # The file is presumably created by the worker side once it has
        # something to report (TODO confirm); until then nothing is finished.
        # Returning False lets a polling caller keep waiting instead of
        # crashing on its first poll.
        return False
    return exec_id in finished_ids

In [6]:
# Identifiers of the models to train.
models = list("ABCDEFGH")
# Worker id -> GPUs available on that worker (each worker gets its own list).
workers = {worker_id: ["GPU1", "GPU2"] for worker_id in range(4)}
# Names of the training shards.
train_shards = ["shard" + str(shard_no) for shard_no in range(1, 5)]

In [7]:
# Hyper-parameter search space: each key maps to the candidate values to try.
param_grid = dict(
    learning_rate=[1e-2, 1e-3],
    embed_size=[256, 512],
    hidden_size=[256, 512],
    batch_size=[128],
)

In [8]:
def init_stuff():
    """Build fresh scheduling state from the module-level configuration.

    Reads the globals ``param_grid``, ``models`` and ``workers``. As a side
    effect it prints the grid-search space, the model -> hyper-parameter
    mapping and the model groups.

    Returns:
        A 4-tuple ``(model_groups, model_group_on_worker,
        worker_running_model_group, mgw_pairs)``:
        - ``model_groups``: list of model tuples from ``create_model_groups``;
        - ``model_group_on_worker``: dict model group -> -1;
        - ``worker_running_model_group``: dict worker id -> -1;
        - ``mgw_pairs``: dict model group -> list of ``len(workers)``
          ``False`` flags.
    """
    search_space = find_combinations(param_grid)
    print("Grid Search space:")
    print(search_space)

    # Pair the i-th model with the i-th combination; this mapping is only
    # printed here, it is not part of the returned state.
    model_id_to_mst_mapping = {
        model_id: search_space[idx] for idx, model_id in enumerate(models)
    }
    print(model_id_to_mst_mapping)

    model_groups = create_model_groups(models, workers)
    print("Model Groups:")
    print(model_groups)

    model_group_on_worker = {group: -1 for group in model_groups}
    worker_running_model_group = {worker_id: -1 for worker_id in workers}

    # One independent flag list per group. A nested
    # ``[[False] * len(workers)] * len(model_groups)`` would make every group
    # share the same inner list.
    mgw_pairs = {group: [False] * len(workers) for group in model_groups}

    return model_groups, model_group_on_worker, worker_running_model_group, mgw_pairs

In [9]:
def launch_hydra_job(epoch, worker, mg):
    """POST a job request for model group ``mg`` to ``worker``.

    A random 32-character id (lowercase letters and digits) is generated and
    sent along with the epoch and the string form of ``mg`` as JSON to
    ``http://localhost:<8000 + worker>/hydra``. The HTTP response is not
    inspected.

    Args:
        epoch: epoch number placed in the request payload.
        worker: integer worker id; added to 8000 to form the port.
        mg: model group; sent as ``str(mg)``.

    Returns:
        The generated execution id string.
    """
    print("Scheduling epoch {} of model_group {} on worker {}".format(epoch, mg, worker))

    alphabet = string.ascii_lowercase + string.digits
    exec_id = ''.join(random.choice(alphabet) for _ in range(32))

    endpoint = "http://localhost:" + str(8000 + worker) + "/hydra"
    payload = {
        "epoch": epoch,
        "models": str(mg),
        "exec_id": str(exec_id)
    }
    requests.post(url=endpoint, json=payload)

    return exec_id

In [10]:
# TODO: add validation stuff: is_last_worker, etc.
def scheduler(epoch, model_groups, workers, model_group_on_worker, mgw_pairs,
              worker_running_model_group, poll_interval=0.1):
    """Run one pass in which every model group is built once on every worker.

    Repeatedly sweeps over the workers. An idle worker is handed a runnable
    model group (``get_runnable_model_group`` + ``launch_hydra_job``); a busy
    worker is polled with ``check_finished``. A model group is finished once
    its ``mgw_pairs`` flag is set for every worker index, and the function
    returns when all groups are finished.

    ``model_group_on_worker``, ``mgw_pairs`` and ``worker_running_model_group``
    are updated in place.

    Args:
        epoch: epoch number forwarded to ``launch_hydra_job``.
        model_groups: list of hashable model groups; reordered in place by
            ``get_runnable_model_group``.
        workers: dict keyed by worker id. Ids are also used as indices into
            the ``mgw_pairs`` lists.
        model_group_on_worker: dict model group -> worker id, or -1 when the
            group is not placed anywhere.
        mgw_pairs: dict model group -> list of booleans indexed by worker id;
            True once the group has completed on that worker.
        worker_running_model_group: dict worker id -> model group, or -1 when
            the worker is idle.
        poll_interval: seconds to sleep between sweeps while work is still
            outstanding, so the loop does not spin at full speed re-reading
            completion files. Zero or a negative value disables the pause.
    """
    model_groups_to_build = set(model_groups)
    # Execution id of the job currently running on each worker.
    exec_id_on_worker = {worker: None for worker in workers}

    while len(model_groups_to_build) > 0:
        for worker in workers:
            if worker_running_model_group[worker] == -1:
                # Idle worker: try to give it a group it has not built yet.
                mg = get_runnable_model_group(worker, model_groups, model_group_on_worker, mgw_pairs)
                if mg != -1:
                    exec_id = launch_hydra_job(epoch, worker, mg)

                    model_group_on_worker[mg] = worker
                    worker_running_model_group[worker] = mg
                    exec_id_on_worker[worker] = exec_id
                    print("Sent models {} to build on worker {} ".format(
                            str(mg), str(worker)))
            else:
                # Busy worker: check whether its current job has completed.
                mg = worker_running_model_group[worker]
                exec_id = exec_id_on_worker[worker]
                completed = check_finished(worker, exec_id)

                if completed:
                    print("Received Models {} built on worker {}".format(str(mg), str(worker)))
                    # Release both the group and the worker, and record the pair.
                    model_group_on_worker[mg] = -1
                    worker_running_model_group[worker] = -1
                    mgw_pairs[mg][worker] = True

                    # The group is done once every worker index is flagged.
                    model_group_done = all(mgw_pairs[mg][i] for i in range(len(workers)))
                    if model_group_done:
                        model_groups_to_build.remove(mg)

        # Pause between sweeps instead of busy-waiting on the completion files.
        if model_groups_to_build and poll_interval > 0:
            time.sleep(poll_interval)
    # NOTE(review): this runs at the end of every scheduler() call, so a caller
    # invoking scheduler() once per epoch prints it once per epoch.
    print("Done with all epochs :)")

In [11]:
def grid_search(num_epochs=2):
    """Run the scheduler once per epoch over the module-level ``workers``.

    Args:
        num_epochs: number of epochs to run. Defaults to 2, the value that
            was previously hard-coded.
    """
    for epoch in range(num_epochs):
        print("EPOCH: " + str(epoch + 1))
        # State is rebuilt each epoch so every (model group, worker) flag
        # starts out False again.
        model_groups, model_group_on_worker, worker_running_model_group, mgw_pairs = init_stuff()
        scheduler(epoch, model_groups, workers, model_group_on_worker, mgw_pairs, worker_running_model_group)

In [12]:
# Run the grid search. Jobs are POSTed to http://localhost:<8000 + worker id>/hydra,
# so this cell needs those endpoints to be reachable.
grid_search()

EPOCH: 1
Grid Search space:
[{'learning_rate': 0.01, 'embed_size': 256, 'hidden_size': 256, 'batch_size': 128}, {'learning_rate': 0.01, 'embed_size': 256, 'hidden_size': 512, 'batch_size': 128}, {'learning_rate': 0.01, 'embed_size': 512, 'hidden_size': 256, 'batch_size': 128}, {'learning_rate': 0.01, 'embed_size': 512, 'hidden_size': 512, 'batch_size': 128}, {'learning_rate': 0.001, 'embed_size': 256, 'hidden_size': 256, 'batch_size': 128}, {'learning_rate': 0.001, 'embed_size': 256, 'hidden_size': 512, 'batch_size': 128}, {'learning_rate': 0.001, 'embed_size': 512, 'hidden_size': 256, 'batch_size': 128}, {'learning_rate': 0.001, 'embed_size': 512, 'hidden_size': 512, 'batch_size': 128}]
{'A': {'learning_rate': 0.01, 'embed_size': 256, 'hidden_size': 256, 'batch_size': 128}, 'B': {'learning_rate': 0.01, 'embed_size': 256, 'hidden_size': 512, 'batch_size': 128}, 'C': {'learning_rate': 0.01, 'embed_size': 512, 'hidden_size': 256, 'batch_size': 128}, 'D': {'learning_rate': 0.01, 'embed_si