In [70]:
from spark_env.env import Environment
import numpy as np
import networkx as nx
from matplotlib import pyplot as plt
import dgl
import torch
from spark_env.job_dag import JobDAG
from spark_env.node import Node
import torch.nn as nn
from dgl.nn.pytorch import GraphConv

class GCN(nn.Module):
    """Stack of ``GraphConv`` layers with a ReLU between consecutive layers.

    The layer widths are ``features -> hidden_layer_size -> ... ->
    hidden_layer_size -> embedding_size``; no activation is applied after the
    last layer.

    Args:
        features: width of the input node features.
        hidden_layer_size: width of every intermediate layer.
        embedding_size: width of the output node embedding.
        layers: total number of graph convolutions (must be >= 1). With
            ``layers=1`` a single ``features -> embedding_size`` convolution
            is built.

    Raises:
        ValueError: if ``layers`` is smaller than 1.
    """

    def __init__(self, features=5, hidden_layer_size=10, embedding_size=10, layers=3):
        super(GCN, self).__init__()
        if layers < 1:
            raise ValueError("GCN needs at least one layer, got %r" % (layers,))
        self.layers = layers

        # Consecutive entries of `widths` are the (in, out) sizes of each layer.
        widths = [features] + [hidden_layer_size] * (layers - 1) + [embedding_size]

        # nn.ModuleList (not a plain list) so the convolutions are registered
        # as sub-modules: their weights then show up in parameters() and
        # state_dict() and follow the module on .to(device).
        self.conv = nn.ModuleList(
            [GraphConv(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])]
        )

    def forward(self, g, inputs):
        """Run every convolution on graph ``g`` starting from ``inputs``.

        Returns the output of the last convolution (no final activation).
        """
        h = inputs
        last = len(self.conv) - 1
        for i, conv in enumerate(self.conv):
            h = conv(g, h)
            if i != last:
                h = torch.relu(h)
        return h

def translate_state(obs, node_input_dim=5, job_input_dim=3):
    """
    Translate the observation to matrix form.

    Args:
        obs: 8-tuple ``(job_dags, source_job, num_source_exec, frontier_nodes,
            executor_limits, exec_commit, moving_executors, action_map)``.
            Only ``job_dags``, ``source_job``, ``num_source_exec``,
            ``exec_commit`` and ``moving_executors`` are read here.
        node_input_dim: number of columns of the returned matrix; columns
            0..4 are written, so it must be at least 5.
        job_input_dim: number of per-job features; columns 0..2 are written,
            so it must be at least 3.

    Returns:
        ``numpy`` array of shape ``[total_num_nodes, node_input_dim]`` with one
        row per node, in the order jobs and their nodes are iterated:
          col 0: executors counted for the node's job / 20
          col 1: +2 if the node's job is ``source_job``, else -2
          col 2: ``num_source_exec`` / 20
          col 3: remaining tasks * duration of the node's last task / 100000
          col 4: remaining tasks / 200

    Raises:
        ValueError: if a key of ``exec_commit.commit`` is neither a ``JobDAG``,
            a ``Node`` nor ``None``.
    """
    job_dags, source_job, num_source_exec, _frontier_nodes, _executor_limits, \
        exec_commit, moving_executors, _action_map = obs

    # compute total number of nodes
    total_num_nodes = int(np.sum([job_dag.num_nodes for job_dag in job_dags]))

    # job and node inputs to feed
    node_inputs = np.zeros([total_num_nodes, node_input_dim])
    job_inputs = np.zeros([len(job_dags), job_input_dim])

    # executors currently attached to each job
    exec_map = {job_dag: len(job_dag.executors) for job_dag in job_dags}
    # count in moving executors
    for node in moving_executors.moving_executors.values():
        exec_map[node.job_dag] += 1
    # count in executor commit
    for s in exec_commit.commit:
        if isinstance(s, JobDAG):
            j = s
        elif isinstance(s, Node):
            j = s.job_dag
        elif s is None:
            j = None
        else:
            # Raise instead of print + exit(1): exiting would take the whole
            # kernel/process down with no traceback.
            raise ValueError('source %r unknown' % (s,))
        for n in exec_commit.commit[s]:
            # only commitments that leave the source's own job are counted
            if n is not None and n.job_dag != j:
                exec_map[n.job_dag] += exec_commit.commit[s][n]

    # gather job level inputs
    for job_idx, job_dag in enumerate(job_dags):
        # number of executors in the job
        job_inputs[job_idx, 0] = exec_map[job_dag] / 20.0
        # the current executor belongs to this job or not
        job_inputs[job_idx, 1] = 2 if job_dag is source_job else -2
        # number of source executors
        job_inputs[job_idx, 2] = num_source_exec / 20.0

    # gather node level inputs; node_idx keeps counting across jobs
    node_idx = 0
    for job_idx, job_dag in enumerate(job_dags):
        for node in job_dag.nodes:
            tasks_left = node.num_tasks - node.next_task_idx

            # copy the feature from job_input first
            node_inputs[node_idx, :3] = job_inputs[job_idx, :3]

            # remaining work on the node
            node_inputs[node_idx, 3] = \
                tasks_left * node.tasks[-1].duration / 100000.0

            # number of tasks left
            node_inputs[node_idx, 4] = tasks_left / 200.0

            node_idx += 1

    return node_inputs


# Node-level network: 5 input features per node -> 10-dim node embedding.
net = GCN(features=5, hidden_layer_size=10, embedding_size=10, layers=3)
# Summary-graph network: consumes 10-dim inputs and produces 100-dim outputs.
env_net = GCN(features=10, hidden_layer_size=100, embedding_size=100, layers=3)

In [71]:
# Build the scheduling environment; `done` is the episode-finished flag.
env = Environment()
done = False

# Seed the environment before starting the episode.
env.seed(234)

# The episode time limit is sampled from a geometric distribution with
# success probability `reset_prob` (drawn from numpy's global RNG).
reset_prob = 5e-7
env.reset(max_time=np.random.geometric(reset_prob))
# nx.draw(job_graph, with_labels=True)
# plt.show()

In [67]:
# get the first observation
job_graph, frontier_nodes, executor_limits, exec_graph, action_map = env.new_observation()
node_inputs = translate_state(env.observe())

# convert the observation to the format DGL / torch expect
dgl_graph = dgl.from_networkx(job_graph)
node_inputs = torch.tensor(node_inputs)
G = dgl.add_self_loop(dgl_graph)
inputs = node_inputs.type(torch.FloatTensor)

# compute the per-node logits for the job graph
logits = net(G, inputs)

# indices (into `logits`) of the frontier nodes, in frontier iteration order
required_indices = [action_map.inverse_map[node] for node in frontier_nodes]

# Build the summary graph: one vertex per frontier node plus one summary
# vertex that every frontier vertex points to.
# Vertices use LOCAL ids 0..k-1 (same order as `required_indices`) and the
# summary vertex uses id k. Using the global node ids here could collide with
# the summary id and would not line up with the row order of `e_input` after
# dgl.from_networkx relabels the vertices.
schedulable_nodes = len(frontier_nodes)
summary_graph = nx.DiGraph()
summary_graph.add_nodes_from(range(schedulable_nodes + 1))
summary_graph.add_edges_from(
    (local_idx, schedulable_nodes) for local_idx in range(schedulable_nodes)
)

# features for the summary graph: the frontier logits followed by one row of
# ones for the summary vertex (width taken from `logits`, not hard-coded)
e_input = logits[required_indices]
e_input = torch.cat((e_input, logits.new_ones(1, logits.shape[1])), 0)
G_s = dgl.from_networkx(summary_graph)
G_s = dgl.add_self_loop(G_s)

# convert the summary graph to logits
logits_s = env_net(G_s, e_input)

In [None]:
class Net(nn.Module):
    """Placeholder network used by ``Agent``.

    It is an ``nn.Module`` so that ``Net(state_dim, action_dim).float()`` and
    ``.to(device=...)`` work; no layers are defined yet.

    Args:
        state_dim: size of the state representation (stored, optional).
        action_dim: number of actions (stored, optional).
    """

    def __init__(self, state_dim=None, action_dim=None) -> None:
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        # TODO: define the layers and a forward() once the architecture is chosen.

In [None]:
from collections import deque

class Agent():
    """Holds the network and the exploration / bookkeeping settings.

    Args:
        state_dim: passed to ``Net`` as its first argument.
        action_dim: passed to ``Net`` as its second argument.
        save_dir: stored on the instance; not used in the code shown here.
    """

    def __init__(self, state_dim=100, action_dim=2, save_dir="."):
        self.state_dim, self.action_dim, self.save_dir = state_dim, action_dim, save_dir

        self.use_cuda = torch.cuda.is_available()

        # Build the network in float32 and move it to the GPU when one exists.
        network = Net(self.state_dim, self.action_dim).float()
        self.net = network.to(device="cuda") if self.use_cuda else network

        # Exploration settings: start value, per-step decay factor and floor.
        # NOTE(review): the decay itself is not applied anywhere in view.
        self.exploration_rate = 1
        self.exploration_rate_decay = 0.99999975
        self.exploration_rate_min = 0.1
        self.curr_step = 0

        # number of experiences between saves of the network
        self.save_every = 5e5

In [58]:
from collections import namedtuple, deque
import random

# One stored experience: (state, action, next_state, reward).
Transition = namedtuple('Transition', ['state', 'action', 'next_state', 'reward'])


class ReplayMemory:
    """Bounded FIFO buffer of ``Transition`` records.

    Once ``capacity`` items are stored, each new item pushes out the oldest.
    """

    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Save a transition built from the positional ``args``."""
        record = Transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen at random."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

In [89]:
# obs = env.observe()
# job_dags, source_job, num_source_exec, frontier_nodes, executor_limits, exec_commit, moving_executors, action_map = obs
# total_reward = 0
# loops = 0

# dag_list = []
# dag_map = {}

# while not done:
    
#     eq = np.random.randint(1, 5)
#     eq = min(eq, num_source_exec)

#     if(len(frontier_nodes)) == 0:
#         break

#     for node in frontier_nodes:

#         if len(dag_list) == 0:
#             dag_map[node.job_dag] = True
#             dag_list.append((node, action_map.inverse_map[node]))
#         elif node.job_dag in dag_map:
#             dag_list.append((node, action_map.inverse_map[node]))

#         obs, reward, done = env.step(node, eq)
#         break

#     total_reward += reward
#     if not done:
#         job_dags, source_job, num_source_exec, frontier_nodes, executor_limits, exec_commit, moving_executors, action_map = obs
#         print(moving_executors)
#     input()

# print(total_reward)   

In [88]:
# from matplotlib import pylab

# def save_graph(graph, file_name):
#     plt.figure(num=None, figsize=(20, 20), dpi=80)
#     plt.axis('off')
#     fig = plt.figure(1)
#     pos = nx.spring_layout(graph)
#     nx.draw_networkx_nodes(graph,pos)
#     nx.draw_networkx_edges(graph,pos)
#     nx.draw_networkx_labels(graph,pos)

#     plt.savefig(file_name,bbox_inches="tight")
#     pylab.close()
#     del fig

# save_graph(job_graph, "job_dag")