In [1]:
# Notebook setup: switch the IPython completer's `use_jedi` option off.
%config Completer.use_jedi = False

In [2]:
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
import numpy as np
import scipy.sparse as sp

In [3]:
import gym
from gym.spaces import MultiDiscrete

In [4]:
from copy import copy, deepcopy

In [5]:
def get_random_queues(n: int, beta: float = 0.8):
    """Sample random, disjoint precedence chains over ``n`` jobs.

    Jobs are visited in a random order. Each still-unused job starts a chain
    that is extended with probability ``beta`` per step by a uniformly drawn
    unused job. Every job therefore has at most one predecessor and at most
    one successor, and the result contains no cycles.

    Args:
        n: Number of jobs (node ids are ``0 .. n-1``).
        beta: Probability of extending the current chain by one more job.

    Returns:
        A list of ``[i, j]`` pairs (plain Python ints) meaning "job ``i``
        precedes job ``j``". The list is empty when no chain was extended.
    """
    free = [True] * n
    out = []
    for i in np.random.permutation(n):
        if not free[i]:
            continue
        free[i] = False

        while np.random.rand() < beta:
            candidates = np.where(free)[0]
            # Every job is already part of a chain: nothing left to append.
            # (Explicit check instead of a bare ``except`` around
            # ``np.random.choice``, which would also hide unrelated errors.)
            if candidates.size == 0:
                return out
            j = np.random.choice(candidates)
            free[j] = False
            out.append([int(i), int(j)])
            i = j  # keep growing the chain from its new tail
    return out

def count_q_length(_from, _to, n):
    """Count, for every job, how many jobs follow it in its precedence chain.

    Edge ``k`` of the edge list is ``_from[k] -> _to[k]``. The count is
    propagated backwards along the edges until it stops changing.

    Args:
        _from: Integer index tensor with the source job of each edge.
        _to: Integer index tensor with the destination job of each edge.
        n: Total number of jobs.

    Returns:
        A float tensor of shape ``(n,)``; entry ``i`` is the number of edges
        on the chain starting at job ``i`` (0 for jobs with no successor).

    Raises:
        ValueError: If the counts do not converge, which means the edge list
            contains a cycle (the original loop would never terminate).
    """
    counts = torch.zeros(n)
    prev_counts = torch.zeros(n)
    counts[_from] = 1
    # Each pass fixes the count for chains one edge longer than before, so
    # an acyclic edge list over n jobs converges within n passes.
    for _ in range(n + 1):
        if torch.equal(counts, prev_counts):
            return counts
        prev_counts = counts.clone()
        counts[_from] = counts[_to] + 1
    raise ValueError("Precedence edges contain a cycle; queue lengths are undefined.")

In [6]:
# Demo data: random precedence chains over 10 jobs, split into edge endpoints.
_from, _to = torch.tensor(get_random_queues(10)).T

# The 'next' and 'processing' relations start with a single (0, 0) edge each.
graph_data = {
    ('job', 'precede', 'job'): (_from, _to),  # A ---before---> B
    ('job', 'next', 'job'): (torch.tensor([0]), torch.tensor([0])),  # jobshop queue
    ('worker', 'processing', 'job'): (torch.tensor([0]), torch.tensor([0])),  # nothing is scheduled
}

In [7]:
# Demo graph with 2 workers and 10 jobs built from `graph_data`.
g = dgl.heterograph(graph_data, num_nodes_dict={'worker': 2, 'job': 10})

# Drop edge 0 of the two scheduling relations so they start out empty.
g.remove_edges(0, etype='processing')
g.remove_edges(0, etype='next')

# Random node features: 7 per job, 3 per worker.
g.nodes['job'].data['hv'] = torch.rand(10, 7)  # TODO
g.nodes['worker'].data['he'] = torch.rand(2, 3)  # torch.eye(2) -- one hot encoding

In [8]:
def construct_readout_graph(g, etype):
    """Build a graph that connects every source node to every destination node.

    Args:
        g: Graph supplying the node counts and node ids.
        etype: Canonical edge type ``(src_type, relation, dst_type)``, as
            listed by ``g.canonical_etypes``.

    Returns:
        A heterograph holding only ``etype``, with one edge per
        (source, destination) pair, ordered source-major.
    """
    src_type, _, dst_type = etype
    n_src = g.num_nodes(src_type)
    n_dst = g.num_nodes(dst_type)

    # Source ids are repeated block-wise, destination ids are tiled, so the
    # edge list enumerates the full cross product.
    src_ids = g.nodes(src_type).repeat_interleave(n_dst)
    dst_ids = g.nodes(dst_type).repeat(n_src)

    node_counts = {src_type: n_src, dst_type: n_dst}
    return dgl.heterograph({etype: (src_ids, dst_ids)}, num_nodes_dict=node_counts)

In [9]:
import dgl.function as fn

class dotProductPredictor(nn.Module):
    """Scores edges by the dot product of their endpoint features.

    The result has one row per destination node (job) and one column per
    source node (worker).
    """

    def forward(self, graph, hv, he, _etype):
        """Compute the score matrix for the edges of ``_etype``.

        Args:
            graph: Graph whose ``_etype`` edges are scored.
            hv: Features stored on the destination node type.
            he: Features stored on the source node type.
            _etype: Canonical edge type ``(src_type, relation, dst_type)``.

        Returns:
            The per-edge scores reshaped to ``(-1, hv.shape[0])`` and
            transposed.
        """
        src_type, relation, dst_type = _etype
        n_dst = hv.shape[0]
        # local_scope keeps the temporary features and scores off the caller's graph.
        with graph.local_scope():
            graph.nodes[dst_type].data['hv'] = hv
            graph.nodes[src_type].data['he'] = he
            graph.apply_edges(fn.u_dot_v('he', 'hv', 'score'), etype=relation)
            edge_scores = graph.edges[relation].data['score']
            return edge_scores.view(-1, n_dst).T

```python
g.nodes('job')
g.ntypes
g.etypes
g.canonical_etypes
```

In [10]:
class agent(nn.Module):
    """Policy network: a graph encoder followed by a dot-product readout."""

    def __init__(self, in_features, hidden_features, out_features, rel_names):
        """Create the encoder (``RGCN``) and the scoring head.

        All four arguments are forwarded unchanged to ``RGCN``.
        """
        super().__init__()
        self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
        self.pred = dotProductPredictor()

    def forward(self, g, x, etype):
        """Score every (worker, job) pair of ``g``.

        Args:
            g: Scheduling graph; worker features are read from
                ``g.nodes['worker'].data['he']``.
            x: Input features handed to the encoder.
            etype: Currently unused; the readout always runs on
                ``('worker', 'processing', 'job')``.

        Returns:
            The output of ``dotProductPredictor`` on the readout graph.
        """
        readout_etype = ('worker', 'processing', 'job')
        job_repr = self.sage(g, x)
        worker_repr = g.nodes['worker'].data['he']
        readout_graph = construct_readout_graph(g, readout_etype)
        return self.pred(readout_graph, job_repr, worker_repr, readout_etype)

In [11]:
class jobShopScheduling(gym.Env):
    """
    ### Description
    Learning to schedule a successful sequence of "jobs" to multiple workers while respecting given
    constraints.

    ### Action Space
    An action is a pair ``(worker, job)``. By adding an edge from a worker to an unscheduled job, the
    job gets queued to that worker. The resulting sequence cannot be changed in hindsight.

    ### State Space
    A disjunctive heterogeneous graph g = (V, C U D). Each node represents a "job" or a "worker".
    Edges in C denote succession requirements for jobs, edges in D denote which jobs were assigned to
    which worker.

    ### Rewards
    The system receives a positive unit reward for each executed job, and a penalty per time step.

    ### Starting State
    A random set of n jobs, including time requirements and succession constraints, e.g., task i
    requires completion of task j.

    ### Episode Termination
    The episode terminates when all jobs have been scheduled; the action space has then shrunk to
    size 0. The final reward tallies up the remaining rewards to be paid out (w/o time discounting).

    ### Arguments
    ``njobs`` and ``nworkers`` set the number of job and worker nodes. No additional arguments are
    currently supported.
    """

    # Remaining times at or below this value are treated as "job finished";
    # repeated float subtraction of the time step rarely lands exactly on 0.
    _EPS = 1e-6

    def __init__(self, njobs: int, nworkers: int):
        """Store the problem size and the fixed simulation constants.

        Raises:
            ValueError: If ``njobs`` or ``nworkers`` is smaller than 1 (``reset``
                needs node 0 of both types to exist).
        """
        if njobs < 1 or nworkers < 1:
            raise ValueError("njobs and nworkers must both be at least 1.")
        self._njobs = njobs
        self._nworkers = nworkers
        self._jfeat = 7
        self._wfeat = 3
        self._dt = 0.1
        self._time_penalty = -.1

        self._state = None

    def reward(self, a):
        """Placeholder; rewards are computed inside ``step``."""
        # Explicit raise instead of ``assert False`` so it also fires under ``python -O``.
        raise AssertionError("Not implemented. Do not call.")

    def terminal(self):
        """Return True once every job carries the 'scheduled' flag (column 3)."""
        # Terminal state is reached when all the jobs have been scheduled. |A| is zero.
        return bool((self._state.nodes['job'].data['hv'][:, 3] == 1).all())

    def worker_features(self):
        """Names of the worker feature columns."""
        return ('n queued', 'expected run time', 'efficiency rate')

    def job_features(self):
        """Names of the job feature columns (the status entry spans columns 3-5)."""
        return ('time req',
                'completion%', #1
                'nr of child nodes', #2
                'status (one hot: scheduled, processing, finished)', #3-4-5
                'remaining time') #6

    def valid_action(self, a):
        """Return True if the job of action ``a = (worker, job)`` is still unscheduled."""
        _, j = a
        return bool(self._state.nodes['job'].data['hv'][j, 3] == 0)

    def check_job_requirements(self, j):
        """Return True if job ``j`` has no incoming 'precede' edge left."""
        _, dst = self._state.edges(etype='precede')
        return bool((dst != j).all())

    @staticmethod
    def _member_mask(values, members):
        """Boolean mask over ``values`` marking the entries that occur in ``members``."""
        return (values.unsqueeze(1) == members.unsqueeze(0)).any(dim=1)

    def _seed_all(self, seed):
        """Seed gym's generator and the global NumPy / torch generators.

        ``reset`` draws from ``np.random`` (via ``get_random_queues``) and from
        torch's global generator, so seeding gym alone would not make episodes
        reproducible.
        """
        super().reset(seed=seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    def rollout(self, verbose=False):
        """Simulate the queued jobs without scheduling anything new.

        Repeatedly takes the job at the head of each worker queue whose
        prerequisites are done, completes the one with the least remaining
        time, and stops when no head can run (all done, or gridlock). The
        environment state itself is not modified. No discounting is applied.

        Args:
            verbose: Print each executed job and its worker.

        Returns:
            ``(reward, all_done)``: the accumulated reward as a float (one unit
            per completed job minus the time that elapsed before it finished)
            and whether every job ended up done.
        """
        state_hv = self._state.nodes['job'].data['hv'].detach().clone()
        jdone = state_hv[:, 5] == 1

        reward = 0.
        src, dst = self._state.edges(etype='processing')
        sreq, dreq = self._state.edges(etype='precede')

        while True:
            # Head of every non-empty worker queue (edges are in insertion order).
            heads = [dst[src == w][0].item() for w in src.unique().tolist()]
            runnable = [j for j in heads if bool(jdone[sreq[dreq == j]].all())]
            if len(runnable) == 0:
                break # nothing left to run, or gridlock

            # The runnable job with the smallest remaining time finishes first.
            j = runnable[state_hv[runnable, 6].argmin().item()]
            if verbose:
                print("executing job", j, "on worker", src[dst==j].item())
            jdone[j] = True
            elapsed = state_hv[j, 6].clone()
            reward += 1 - elapsed.item()
            state_hv[runnable, 6] -= elapsed # all running jobs progress by the same amount

            # remove job from queue
            keep = dst != j
            src, dst = src[keep], dst[keep]

        return float(reward), bool(jdone.all())

    def step(self, a):
        """Schedule job ``j`` on worker ``w`` and advance the simulation by one time step.

        Args:
            a: Pair ``(w, j)``; ``j`` must not be scheduled yet.

        Returns:
            ``(state, reward, done, info)`` where ``state`` is a deep copy of the
            updated graph, ``reward`` is the time penalty plus one unit per job
            finished in this step (plus the rollout tally on the final step),
            ``done`` is True when all jobs are scheduled and ``info`` is empty.
        """
        assert self.valid_action(a), "Invalid action taken."
        w, j = (int(v) for v in a)
        state = self._state

        # Jobs that were already running before this step; only these make progress now.
        prev_state_hv = state.nodes['job'].data['hv']
        processing_mask = prev_state_hv[:, 4] == 1
        state_hv = prev_state_hv.detach().clone()
        state_he = state.nodes['worker'].data['he'].detach().clone()

        # 1) Schedule job j for worker w:
        #    a) link it behind the last job in w's queue ('next' edge),
        #    b) add the worker -> job 'processing' edge,
        #    c) update worker info (queue length, run time estimate),
        #    d) mark the job as scheduled.
        src, dst = state.edges(etype='processing')
        queue = dst[src == w]
        if len(queue) > 0:
            state.add_edges(queue[-1].item(), j, etype='next')
        state.add_edges(w, j, etype='processing')

        state_he[w, 0] += 1. # add job to work queue length
        state_he[w, 1] += state_hv[j, 0] # update worker's run time estimate
        state_hv[j, 3] = 1. # mark as scheduled

        # 2) The first job of every queue starts processing if none of its
        #    prerequisites is pending (its completion % stays 0 for this step).
        #    Edges are re-read so the job scheduled above is taken into account.
        src, dst, eids = state.edges('all', etype='processing')
        heads = [dst[src == wk][0].item() for wk in src.unique().tolist()]
        ready = [h for h in heads if self.check_job_requirements(h)]
        if ready:
            state_hv[ready, 4] = 1.

        # 3a) Progress time for the jobs that were running, and for the workers.
        time_req = state_hv[processing_mask, 0]
        remaining = torch.clamp(state_hv[processing_mask, 6] - self._dt, min=0)
        finished_now = remaining <= self._EPS
        remaining = torch.where(finished_now, torch.zeros_like(remaining), remaining)
        # completion = 1 - remaining / required; zero-length jobs are guarded by the clamp.
        progress = torch.clamp(1 - remaining / torch.clamp(time_req, min=self._EPS), min=0, max=1)
        state_hv[processing_mask, 6] = remaining
        state_hv[processing_mask, 1] = torch.where(finished_now, torch.ones_like(progress), progress)
        state_he[:, 1] = torch.clamp(state_he[:, 1] - self._dt, min=0)

        # 3b) Flag the jobs that just terminated and release them from their worker.
        done_mask = torch.zeros_like(processing_mask)
        done_mask[processing_mask] = finished_now
        state_hv[done_mask, 5] = 1. # finished
        state_hv[done_mask, 4] = 0. # no longer processing
        finished = torch.nonzero(done_mask).flatten() # job ids just terminated
        if finished.numel() > 0:
            hit = self._member_mask(dst, finished)
            if bool(hit.any()):
                state_he[src[hit], 0] -= 1 # remove job from job count
                state.remove_edges(eids[hit], etype='processing')

            # 3c) Remove outgoing 'next' and 'precede' edges of terminated jobs;
            #     this is what unblocks their successors.
            for relation in ('next', 'precede'):
                rel_src, _, rel_eids = state.edges('all', etype=relation)
                hit = self._member_mask(rel_src, finished)
                if bool(hit.any()):
                    state.remove_edges(rel_eids[hit], etype=relation)

        # 4) Write the updated feature vectors back to the graph.
        state.nodes['job'].data['hv'] = state_hv
        state.nodes['worker'].data['he'] = state_he

        # 5) Reward and terminal state. Every step pays the time penalty plus one
        #    unit per finished job; once all jobs are scheduled the outstanding
        #    rewards of the queued jobs are tallied by a rollout, as described in
        #    the class docstring.
        reward = self._time_penalty + int(finished.numel())
        done = self.terminal()
        if done:
            remaining_reward, _ = self.rollout()
            reward += remaining_reward

        return deepcopy(self._state), float(reward), done, {}

    def reset(self, seed: int = None, topology: str = 'random'):
        """Start a new episode with random jobs and precedence chains.

        Args:
            seed: If given, seeds all random generators used to build the episode.
            topology: Currently unused.

        Returns:
            A deep copy of the initial state graph.
        """
        if seed is not None:
            self._seed_all(seed)

        nw, nj = self._nworkers, self._njobs
        pairs = get_random_queues(nj)
        has_precedence = len(pairs) > 0
        if has_precedence:
            _from, _to = torch.tensor(pairs).T
        else:
            # No constraint was drawn: use the same placeholder edge as below.
            _from, _to = torch.tensor([0]), torch.tensor([0])
        graph_data = {
           ('job', 'precede', 'job'): (_from, _to), # A ---before---> B
           ('job', 'next', 'job'): (torch.tensor([0]), torch.tensor([0])), # jobshop queue
           ('worker', 'processing', 'job'): (torch.tensor([0]), torch.tensor([0])) # nothing is scheduled
        }

        self._state = dgl.heterograph(graph_data, num_nodes_dict={'worker': nw, 'job': nj})
        # hack: can not init null vector for edges
        self._state.remove_edges(0, etype='processing')
        self._state.remove_edges(0, etype='next')
        if not has_precedence:
            self._state.remove_edges(0, etype='precede')

        times = 0.1*torch.randint(10, (nj,1)) # torch.rand(nj,1)
        if has_precedence:
            counts = count_q_length(_from, _to, nj).view(-1, 1) # column vector for the concat below
        else:
            counts = torch.zeros(nj, 1)
        # Columns: time req | completion % | nr of child nodes | status (3) | remaining time
        self._state.nodes['job'].data['hv'] = torch.cat((times, torch.zeros(nj, 1), counts, torch.zeros(nj, 3), times), 1)
        self._state.nodes['worker'].data['he'] = torch.cat((torch.zeros(nw,2), torch.ones(nw,1)), 1)

        return deepcopy(self._state)

    def render(self):
        """
        Rendering is not implemented. Sketch for a networkx-based version:

        import networkx as nx
        import matplotlib.pyplot as plt
        G = dgl.to_homogeneous(g).to_networkx()
        options = { 'node_color': 'black', 'node_size': 20, 'width': 1,  }
        nx.draw(G, **options)
        """
        pass

    def seed(self, n: int):
        """Seed the random generators used by the environment with ``n``."""
        self._seed_all(n)