In [1]:
from spark_env.env import Environment
import numpy as np
import networkx as nx
from matplotlib import pyplot as plt
import dgl
import torch
from spark_env.job_dag import JobDAG
from spark_env.node import Node
import torch.nn as nn
from dgl.nn.pytorch import GraphConv

class GCN(nn.Module):
    """Stack of DGL ``GraphConv`` layers that produces one embedding per node.

    Args:
        features: size of each input node feature vector.
        hidden_layer_size: width of every intermediate layer.
        embedding_size: size of each output node embedding.
        layers: total number of ``GraphConv`` layers (must be >= 1).

    Raises:
        ValueError: if ``layers`` < 1.
    """

    def __init__(self, features=5, hidden_layer_size=10, embedding_size=10, layers=3):
        super(GCN, self).__init__()
        if layers < 1:
            raise ValueError(f"layers must be >= 1, got {layers}")
        self.layers = layers
        # Layer i maps dims[i] -> dims[i + 1]; with a single layer this is
        # simply features -> embedding_size.
        dims = [features] + [hidden_layer_size] * (layers - 1) + [embedding_size]
        # nn.ModuleList (not a plain list) registers the layers as submodules,
        # so their weights appear in .parameters()/state_dict() and follow .to().
        self.conv = nn.ModuleList(
            GraphConv(in_dim, out_dim) for in_dim, out_dim in zip(dims[:-1], dims[1:])
        )

    def forward(self, g, inputs):
        """Apply every layer to graph ``g``; ReLU between layers, none after the last."""
        h = inputs
        last = len(self.conv) - 1
        for i, conv in enumerate(self.conv):
            h = conv(g, h)
            if i != last:
                h = torch.relu(h)
        return h

class EnvironmentWrapper:
    """Wraps the Spark scheduling ``Environment`` and turns its raw
    observations into fixed-size tensors for the agent.

    An observation is the flattened GCN embeddings of up to ``self.range - 1``
    frontier nodes, zero-padded to ``OBS_LOGITS`` values, followed by the
    executor limit divided by ``self.max_exec``.
    """

    # Length of the (zero-padded) embedding part of an observation.
    OBS_LOGITS = 500

    def __init__(self) -> None:
        self.env = Environment()
        # Rotation offset into the frontier used when selecting nodes.
        self.offset = 0
        # Success probability of the geometric draw for the episode time limit.
        self.reset_prob = 5e-7
        self.net = GCN()
        # Divisor applied to executor limits in observations.
        self.max_exec = 100
        # Frontier selection window: at most `range - 1` nodes are kept.
        self.range = 50
        self.reset()
        # Frontier nodes chosen by the latest observation, in embedding order.
        self.frontier_nodes = []
        self.action_map = {}
        # Global node indices matching `frontier_nodes`.
        self.required_indices = []

    def translate_state(self, job_dags, source_job, num_source_exec, exec_commit, moving_executors, node_input_dim=5, job_input_dim=3):
        """
        Translate the observation to matrix form.

        Builds one feature row per node across all ``job_dags`` (DAG order,
        then node order):
          [0] executors attributed to the node's job (held + moving + committed
              from another job) / 20
          [1] +2 if the job is ``source_job`` else -2
          [2] num_source_exec / 20
          [3] remaining tasks * duration of the node's last task / 1e5
          [4] remaining tasks / 200

        Returns:
            np.ndarray of shape (total_num_nodes, node_input_dim).

        Raises:
            ValueError: if ``exec_commit.commit`` has a key that is not a
                JobDAG, a Node or None.
        """
        # compute total number of nodes
        total_num_nodes = int(np.sum([job_dag.num_nodes for job_dag in job_dags]))

        # job and node inputs to feed
        node_inputs = np.zeros([total_num_nodes, node_input_dim])
        job_inputs = np.zeros([len(job_dags), job_input_dim])

        # executors currently held by each job
        exec_map = {job_dag: len(job_dag.executors) for job_dag in job_dags}
        # count in moving executors
        for node in moving_executors.moving_executors.values():
            exec_map[node.job_dag] += 1
        # count in executor commit (commits within the same job are skipped)
        for s in exec_commit.commit:
            if isinstance(s, JobDAG):
                j = s
            elif isinstance(s, Node):
                j = s.job_dag
            elif s is None:
                j = None
            else:
                # Raise instead of exit(1), which would kill the interpreter/kernel.
                raise ValueError(f"source {s!r} unknown")
            for n in exec_commit.commit[s]:
                if n is not None and n.job_dag != j:
                    exec_map[n.job_dag] += exec_commit.commit[s][n]

        # gather job level inputs
        for job_idx, job_dag in enumerate(job_dags):
            job_inputs[job_idx, 0] = exec_map[job_dag] / 20.0
            job_inputs[job_idx, 1] = 2 if job_dag is source_job else -2
            job_inputs[job_idx, 2] = num_source_exec / 20.0

        # gather node level inputs
        node_idx = 0
        for job_idx, job_dag in enumerate(job_dags):
            for node in job_dag.nodes:
                # copy the features from the job inputs first
                node_inputs[node_idx, :3] = job_inputs[job_idx, :3]
                remaining = node.num_tasks - node.next_task_idx
                # remaining work on the node
                node_inputs[node_idx, 3] = remaining * node.tasks[-1].duration / 100000.0
                # number of tasks left
                node_inputs[node_idx, 4] = remaining / 200.0
                node_idx += 1

        return node_inputs

    def reset(self):
        """Reseed the environment and reset it with a geometric random time limit."""
        self.env.seed(np.random.randint(12, 512))
        self.env.reset(max_time=np.random.geometric(self.reset_prob))
        self.offset = 0

    def observe(self):
        """Return the current observation as a 1-D tensor of ``OBS_LOGITS + 1`` values.

        An all-zero observation is returned when the frontier is empty.
        """
        job_graph, frontier_nodes, executor_limits, action_map, \
            job_dags, source_job, num_source_exec, \
            exec_commit, moving_executors = self.env.new_observation()

        if len(frontier_nodes) == 0:
            return torch.zeros(self.OBS_LOGITS + 1)

        node_inputs = self.translate_state(job_dags, source_job, num_source_exec, exec_commit, moving_executors)
        # Nothing back-propagates into the GCN through observations; detach so
        # the result can be converted to numpy or stored in replay memory.
        logits = self.get_graph_embedding(job_graph, node_inputs, frontier_nodes, action_map).detach().flatten()
        executor_limits /= self.max_exec
        if len(logits) < self.OBS_LOGITS:
            logits = torch.cat([logits, torch.zeros(self.OBS_LOGITS - len(logits))])
        return torch.cat([logits, torch.tensor([executor_limits])])

    def step(self, action):
        """Decode an agent action ``(direction, job, limit)`` into Python scalars.

        Components may be tensors / numpy scalars (converted with ``.item()``)
        or plain Python numbers (used as-is).

        TODO(review): the decoded action is not applied to ``self.env`` yet;
        this method currently returns None.
        """
        direction, job, limit = (
            component.item() if hasattr(component, "item") else component
            for component in action
        )

    def _select_frontier_window(self, frontier_nodes, action_map):
        """Pick up to ``self.range - 1`` frontier nodes in rotated order.

        Walks the frontier starting at position ``self.offset + 1`` (mod the
        frontier size) and pairs each chosen node with its global node index
        from ``action_map.inverse_map``.

        Returns:
            (selected_nodes, selected_indices): parallel lists; both empty when
            ``frontier_nodes`` is empty.
        """
        nodes = list(frontier_nodes)
        if not nodes:
            return [], []
        indices = [action_map.inverse_map[node] for node in nodes]

        selected_nodes, selected_indices = [], []
        for i in range(1, self.range):
            pos = (self.offset + i) % len(nodes)
            # `pos` is a position within the frontier, so look the node up in
            # the frontier itself (not in action_map.map, which is keyed by
            # global node index).
            selected_nodes.append(nodes[pos])
            selected_indices.append(indices[pos])
            if len(selected_indices) == len(indices):
                break
        return selected_nodes, selected_indices

    def get_graph_embedding(self, job_graph, node_inputs, frontier_nodes, action_map):
        """Embed every node of ``job_graph`` with the GCN, then keep only the
        rows of the selected frontier nodes (see ``_select_frontier_window``).

        Side effects: updates ``self.frontier_nodes`` and ``self.required_indices``.
        """
        # convert the observation to the format DGL / the GCN expects
        G = dgl.add_self_loop(dgl.from_networkx(job_graph))
        inputs = torch.as_tensor(node_inputs, dtype=torch.float32)

        # compute the embeddings for the whole graph
        logits = self.net(G, inputs)

        # keep only the frontier nodes, in the order they are reported to the agent
        self.frontier_nodes, required_indices = self._select_frontier_window(frontier_nodes, action_map)
        self.required_indices = required_indices
        if not required_indices:
            return logits[:0]
        return logits[required_indices]



Using backend: pytorch


In [79]:
import copy
from collections import deque
import random
import time, datetime

class Net(nn.Module):
    """MLP Q-network held twice: an online copy (``static``) and a target copy (``dynamic``).

    ``dynamic`` starts as an independent deep copy of ``static`` with gradients
    disabled; callers refresh it by loading ``static``'s state dict.
    ``online``/``target`` are read-only aliases for ``static``/``dynamic``.
    """

    def __init__(self, input_dim=500, output_dim=3):
        super().__init__()
        self.static = nn.Sequential(
            nn.Linear(input_dim, 500),
            nn.ReLU(),
            nn.Linear(500, 250),
            nn.ReLU(),
            nn.Linear(250, 200),
            nn.ReLU(),
            nn.Linear(200, 150),
            nn.ReLU(),
            nn.Linear(150, output_dim)
        )

        # Target network: an independent copy of the online network that is
        # never updated by the optimizer, only by explicit state-dict syncs.
        self.dynamic = copy.deepcopy(self.static)
        for param in self.dynamic.parameters():
            param.requires_grad = False

    @property
    def online(self):
        """Alias for ``static`` (the trained, online network)."""
        return self.static

    @property
    def target(self):
        """Alias for ``dynamic`` (the frozen target network)."""
        return self.dynamic

    def forward(self, input, model="static"):
        """Run ``input`` through ``static`` when ``model == "static"``, otherwise through ``dynamic``."""
        if model == "static":
            return self.static(input)

        return self.dynamic(input)

class Agent():
    """Epsilon-greedy DQN-style agent.

    Args:
        state_dim: length of the observation vector fed to ``Net``.
        action_dim: ``(direction, job, executor)`` sizes. ``act`` samples each
            component uniformly from ``range(size)`` when exploring, and the
            network has ``len(action_dim)`` outputs.
        save_dir: directory (str or Path) where checkpoints are written.

    NOTE(review): ``td_estimate``/``td_target`` index Q-values with ``action``
    as if it were a single discrete action index, but ``act`` returns a
    three-component action; the learning step needs rework before it can
    train — confirm the intended action encoding.
    """

    def __init__(self, state_dim=500, action_dim=(3, 50, 100), save_dir="."):
        from pathlib import Path

        self.state_dim = state_dim
        self.action_dim = action_dim
        # Path so that save() can join with "/" even when given a str.
        self.save_dir = Path(save_dir)

        self.use_cuda = torch.cuda.is_available()
        # Fall back to CPU instead of unconditionally requiring CUDA.
        self.device = torch.device("cuda" if self.use_cuda else "cpu")

        # DNN to predict the most optimal action
        self.net = Net(self.state_dim, len(self.action_dim)).float()
        self.net = self.net.to(device=self.device)

        self.exploration_rate = 1
        self.exploration_rate_decay = 0.99999975
        self.exploration_rate_min = 0.1
        self.curr_step = 0

        self.memory = deque(maxlen=100000)
        self.batch_size = 32

        self.save_every = 5e3  # no. of experiences

        self.gamma = 0.9

        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=0.00025)
        # TD targets are real-valued, so use a regression (Huber) loss rather
        # than a classification loss.
        self.loss_fn = torch.nn.SmoothL1Loss()

        self.burnin = 1e4  # min. experiences before training
        self.learn_every = 3  # no. of experiences between updates to Q_online
        self.sync_every = 1e4  # no. of experiences between Q_target & Q_online sync

    def _to_device(self, value):
        """Copy ``value`` (tensor, array or nested sequence) onto ``self.device``, detached from any graph."""
        return torch.as_tensor(value).detach().clone().to(self.device)

    def act(self, state):
        """Choose an action for ``state`` and decay the exploration rate.

        Explore: a tuple of three random ints, one per ``action_dim`` component.
        Exploit: the network's three outputs (``dynamic`` copy), rounded to
        3 decimals, as a 1-D tensor.
        """
        if np.random.rand() < self.exploration_rate:
            # EXPLORE: one uniformly random index per action component.
            direction, job, executor = self.action_dim
            action_idx = (np.random.randint(direction), np.random.randint(job), np.random.randint(executor))
        else:
            # EXPLOIT: float32 to match the network weights; no graph needed.
            with torch.no_grad():
                state = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
                # Drop the batch dim so both branches yield three components.
                action_values = self.net(state, model="dynamic").squeeze(0)
            action_idx = (action_values * 10**3).round() / (10**3)

        # decrease exploration_rate
        self.exploration_rate *= self.exploration_rate_decay
        self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)

        # increment step
        self.curr_step += 1
        return action_idx

    def cache(self, state, next_state, action, reward, done):
        """Store one transition in replay memory as tensors on ``self.device``."""
        state = self._to_device(state)
        next_state = self._to_device(next_state)
        # Leading batch-of-one dim; works for tuple and tensor actions alike.
        action = self._to_device(action).reshape(1, -1)
        reward = self._to_device([reward])
        done = self._to_device([done])
        self.memory.append((state, next_state, action, reward, done,))

    def recall(self):
        """Sample ``batch_size`` transitions and stack them field-wise."""
        batch = random.sample(self.memory, self.batch_size)
        state, next_state, action, reward, done = map(torch.stack, zip(*batch))
        return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()

    def td_estimate(self, state, action):
        current_Q = self.net(state, model="static")[
            np.arange(0, self.batch_size), action
        ]
        return current_Q

    @torch.no_grad()
    def td_target(self, reward, next_state, done):
        # Double-DQN: pick the action with the online net, value it with the target net.
        next_state_Q = self.net(next_state, model="static")
        best_action = torch.argmax(next_state_Q, axis=1)
        next_Q = self.net(next_state, model="dynamic")[
            np.arange(0, self.batch_size), best_action
        ]
        return (reward + (1 - done.float()) * self.gamma * next_Q).float()

    def update_Q_online(self, td_estimate, td_target):
        loss = self.loss_fn(td_estimate, td_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def sync_Q_target(self):
        """Copy the online (``static``) weights into the target (``dynamic``) network."""
        self.net.dynamic.load_state_dict(self.net.static.state_dict())

    def save(self):
        """Write the network weights and exploration rate to a checkpoint in ``save_dir``."""
        from pathlib import Path

        save_path = (
            Path(self.save_dir) / f"sched_net_{int(self.curr_step // self.save_every)}.chkpt"
        )
        torch.save(
            dict(model=self.net.state_dict(), exploration_rate=self.exploration_rate),
            save_path,
        )
        print(f"Sched_net saved to {save_path} at step {self.curr_step}")

    def learn(self):
        """Periodically sync/save, and after burn-in run one TD update.

        Returns ``(mean TD estimate, loss)``, or ``(None, None)`` on steps
        where no update happens.
        """
        if self.curr_step % self.sync_every == 0:
            self.sync_Q_target()

        if self.curr_step % self.save_every == 0:
            self.save()

        if self.curr_step < self.burnin:
            return None, None

        if self.curr_step % self.learn_every != 0:
            return None, None

        # Sample from memory
        state, next_state, action, reward, done = self.recall()

        # Get TD Estimate
        td_est = self.td_estimate(state, action)

        # Get TD Target
        td_tgt = self.td_target(reward, next_state, done)

        # Backpropagate loss through Q_online
        loss = self.update_Q_online(td_est, td_tgt)

        return (td_est.mean().item(), loss)

class MetricLogger:
    """Accumulates per-step and per-episode training metrics, appends a
    fixed-width summary row to ``<save_dir>/log`` and saves moving-average
    plots on every ``record`` call.

    Args:
        save_dir: output directory, as a str or Path.
    """

    def __init__(self, save_dir):
        from pathlib import Path

        # Accept str as well as Path; the "/" joins below need a Path.
        save_dir = Path(save_dir)
        self.save_log = save_dir / "log"
        with open(self.save_log, "w") as f:
            f.write(
                f"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}"
                f"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}"
                f"{'TimeDelta':>15}{'Time':>20}\n"
            )
        self.ep_rewards_plot = save_dir / "reward_plot.jpg"
        self.ep_lengths_plot = save_dir / "length_plot.jpg"
        self.ep_avg_losses_plot = save_dir / "loss_plot.jpg"
        self.ep_avg_qs_plot = save_dir / "q_plot.jpg"

        # History metrics
        self.ep_rewards = []
        self.ep_lengths = []
        self.ep_avg_losses = []
        self.ep_avg_qs = []

        # Moving averages, added for every call to record()
        self.moving_avg_ep_rewards = []
        self.moving_avg_ep_lengths = []
        self.moving_avg_ep_avg_losses = []
        self.moving_avg_ep_avg_qs = []

        # Current episode metric
        self.init_episode()

        # Timing
        self.record_time = time.time()

    def log_step(self, reward, loss, q):
        """Add one step's reward; ``loss``/``q`` are None on steps without a learning update."""
        self.curr_ep_reward += reward
        self.curr_ep_length += 1
        # Compare against None explicitly: a genuine loss of 0.0 must still count.
        if loss is not None:
            self.curr_ep_loss += loss
            self.curr_ep_q += q
            self.curr_ep_loss_length += 1

    def log_episode(self):
        "Mark end of episode"
        self.ep_rewards.append(self.curr_ep_reward)
        self.ep_lengths.append(self.curr_ep_length)
        if self.curr_ep_loss_length == 0:
            ep_avg_loss = 0
            ep_avg_q = 0
        else:
            ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5)
            ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5)
        self.ep_avg_losses.append(ep_avg_loss)
        self.ep_avg_qs.append(ep_avg_q)

        self.init_episode()

    def init_episode(self):
        """Reset the running totals for the current episode."""
        self.curr_ep_reward = 0.0
        self.curr_ep_length = 0
        self.curr_ep_loss = 0.0
        self.curr_ep_q = 0.0
        self.curr_ep_loss_length = 0

    def record(self, episode, epsilon, step):
        """Print and log means over the last 100 episodes, then refresh the plots."""
        mean_ep_reward = np.round(np.mean(self.ep_rewards[-100:]), 3)
        mean_ep_length = np.round(np.mean(self.ep_lengths[-100:]), 3)
        mean_ep_loss = np.round(np.mean(self.ep_avg_losses[-100:]), 3)
        mean_ep_q = np.round(np.mean(self.ep_avg_qs[-100:]), 3)
        self.moving_avg_ep_rewards.append(mean_ep_reward)
        self.moving_avg_ep_lengths.append(mean_ep_length)
        self.moving_avg_ep_avg_losses.append(mean_ep_loss)
        self.moving_avg_ep_avg_qs.append(mean_ep_q)

        last_record_time = self.record_time
        self.record_time = time.time()
        time_since_last_record = np.round(self.record_time - last_record_time, 3)

        print(
            f"Episode {episode} - "
            f"Step {step} - "
            f"Epsilon {epsilon} - "
            f"Mean Reward {mean_ep_reward} - "
            f"Mean Length {mean_ep_length} - "
            f"Mean Loss {mean_ep_loss} - "
            f"Mean Q Value {mean_ep_q} - "
            f"Time Delta {time_since_last_record} - "
            f"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}"
        )

        with open(self.save_log, "a") as f:
            f.write(
                f"{episode:8d}{step:8d}{epsilon:10.3f}"
                f"{mean_ep_reward:15.3f}{mean_ep_length:15.3f}{mean_ep_loss:15.3f}{mean_ep_q:15.3f}"
                f"{time_since_last_record:15.3f}"
                f"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):>20}\n"
            )

        # Use a dedicated figure per plot and close it, so the notebook's
        # current pyplot figure is neither drawn on nor left open.
        for metric in ["ep_rewards", "ep_lengths", "ep_avg_losses", "ep_avg_qs"]:
            fig, ax = plt.subplots()
            ax.plot(getattr(self, f"moving_avg_{metric}"))
            fig.savefig(getattr(self, f"{metric}_plot"))
            plt.close(fig)

In [81]:
# Raw Spark scheduling environment used by the training loop in the next cell.
# NOTE(review): that loop calls env.step(action) with one argument and feeds
# observations straight to the agent, while the commented-out experiment below
# calls env.step(node, eq); EnvironmentWrapper may be what was intended here —
# confirm.
env = Environment()   

In [None]:
from pathlib import Path

# Where checkpoints and metric logs/plots are written. Agent.save and
# MetricLogger join paths with "/", so pass a Path rather than a str.
SAVE_DIR = Path(".")
# Record moving-average metrics every RECORD_EVERY episodes (and after the last).
RECORD_EVERY = 20

# action_dim must be the (direction, job, executor) size tuple: Agent.act
# unpacks it into three values and Agent.__init__ uses len(action_dim).
# NOTE(review): EnvironmentWrapper.observe returns 501 values (500 embedding
# values + executor limit), not 500 — confirm state_dim once `env` is settled.
agent = Agent(state_dim=500, action_dim=(3, 50, 100), save_dir=SAVE_DIR)

logger = MetricLogger(SAVE_DIR)

episodes = 10
for e in range(episodes):

    # Initial observation; `state` must exist before the first agent.act call.
    # TODO(review): nothing resets `env` between episodes — confirm whether an
    # env.reset(...) belongs here.
    state = env.observe()

    # Play the game!
    while True:

        # Run agent on the state
        action = agent.act(state)

        # Agent performs action
        next_state, reward, done = env.step(action)

        # Remember
        agent.cache(state, next_state, action, reward, done)

        # Learn
        q, loss = agent.learn()

        # Logging
        logger.log_step(reward, loss, q)

        # Update state
        state = next_state

        # Check if end of game
        if done:
            break

    logger.log_episode()

    # Also record after the final episode so short runs still get logged.
    if e % RECORD_EVERY == 0 or e == episodes - 1:
        logger.record(episode=e, epsilon=agent.exploration_rate, step=agent.curr_step)

2

In [56]:
# obs = env.observe()
# job_dags, source_job, num_source_exec, frontier_nodes, executor_limits, exec_commit, moving_executors, action_map = obs
# total_reward = 0
# loops = 0

# dag_list = []
# dag_map = {}

# while not done:
    
#     eq = np.random.randint(1, 5)
#     eq = min(eq, num_source_exec)

#     if(len(frontier_nodes)) == 0:
#         break

#     for node in frontier_nodes:

#         if len(dag_list) == 0:
#             dag_map[node.job_dag] = True
#             dag_list.append((node, action_map.inverse_map[node]))
#         elif node.job_dag in dag_map:
#             dag_list.append((node, action_map.inverse_map[node]))

#         obs, reward, done = env.step(node, eq)
#         break

#     total_reward += reward
#     if not done:
#         job_dags, source_job, num_source_exec, frontier_nodes, executor_limits, exec_commit, moving_executors, action_map = obs
#         print(moving_executors)
#     input()

# print(total_reward)   

In [88]:
# from matplotlib import pylab

# def save_graph(graph, file_name):
#     plt.figure(num=None, figsize=(20, 20), dpi=80)
#     plt.axis('off')
#     fig = plt.figure(1)
#     pos = nx.spring_layout(graph)
#     nx.draw_networkx_nodes(graph,pos)
#     nx.draw_networkx_edges(graph,pos)
#     nx.draw_networkx_labels(graph,pos)

#     plt.savefig(file_name,bbox_inches="tight")
#     pylab.close()
#     del fig

# save_graph(job_graph, "job_dag")