In [1]:
%config Completer.use_jedi = False

In [2]:
import sys
sys.path.append('/home/victorialena/rlkit')

import pdb

In [3]:
import dgl
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from copy import copy, deepcopy
from torch.optim import Adam

In [4]:
from env.job_shop import jobShopScheduling
from actor.job_shop import *
from utils.job_shop import * # custom replay buffer and episode sampler
from utils.jsp_makespan import *

from path_collector import MdpPathCollector

  from collections import MutableMapping
  from collections import Sized


In [5]:
class hgnn(nn.Module):
    def __init__(self, embedding_dim=16, k=2):
        super().__init__()
        
        self.embedding = dglnn.HeteroLinear({'job': 7, 'worker':3}, embedding_dim)
        self.conv = dglnn.HeteroGraphConv({
            'precede' : dglnn.GraphConv(embedding_dim, embedding_dim),
            'next' : dglnn.GraphConv(embedding_dim, embedding_dim),
            'processing' : dglnn.SAGEConv((embedding_dim, embedding_dim), embedding_dim, 'mean')},
            aggregate='sum')
               
        self.pred = dotProductPredictor()
        self.num_loops = k
        
    def forward(self, g):
        h0 = {**g.ndata['hv'], **g.ndata['he']}
        hv = self.embedding(h0)
        hw = hv['worker']
        for _ in range(self.num_loops):
            hv = {'job': self.conv(g, hv)['job'],
                  'worker': hw}
            
        rg = construct_readout_graph(g, ('worker', 'processing', 'job'))
        return self.pred(rg, hv['job'], hw, ('worker', 'processing', 'job'))    

#### Helpers

In [6]:
scientific_notation =  lambda x:"{:.2e}".format(x)

def get_scores(g, scores):
    n = scores.shape[0]
    idx = (g.ndata['hv']['job'][:, 3] == 0).view(n, -1)
    
    values, workers = scores.max(-1, keepdims=False)
    return torch.stack([values[i][idx[i]].max() if sum(idx[i]).item()>0 else torch.tensor(0.) for i in range(n)])

def mean_reward(paths):
    return torch.tensor([p['rewards'] for p in paths]).sum(1).mean().item()

def mean_makespan(paths):
    "Returns the average makespan successful paths from given list. Returns *nan* if no path was successful."
    return torch.tensor([p['makespan'] for p in paths if p['success']]).mean().item()

def relative_makespan_error(paths):
    """ From initial conditions of each path, evaluate optimal makespan for each path and compare against
    sampled trajectory. """
    err = []
    for p in paths:
        if not p['success']:
            continue
        jdata = g2jobdata(p['observations'][0], p['actions'])
        makespan, status = get_makespan(jdata)
        if makespan > -1: # feasible
            relative_error = p['makespan']/makespan - 1
            err.append(relative_error)
        else: 
            print("This should not be possible, check for bugs!!")
            
    return torch.tensor(err).mean().item()

#### Q Learning

In [7]:
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

In [8]:
njobs, nworkers = 25, 10 # 10, 2 #5, 2
env = jobShopScheduling(njobs, nworkers)
g0 = env.reset()

In [9]:
layersz = 16

qf = hgnn(layersz)
expl_policy = epsilonGreedyPolicy(qf, .1)

target_qf = hgnn(layersz)
eval_policy = epsilonGreedyPolicy(target_qf, 0.)

expl_path_collector = MdpPathCollector(env, expl_policy, rollout_fn=sample_episode, parallelize=False)
eval_path_collector = MdpPathCollector(env, eval_policy, rollout_fn=sample_episode, parallelize=False)

replay_buffer_cap = 5000 #10000
replay_buffer = replayBuffer(replay_buffer_cap, prioritized=True)

```python
path = rollout(env, expl_policy, 2500)
path['terminals'][-1]
env.render()

replay_buffer.add_path(path, env.g)

replay_buffer.random_batch(50)
```

In [10]:
learning_rate = 8e-5 #1e-4

In [11]:
optimizer = Adam(qf.parameters(), lr=learning_rate, weight_decay=0.01)
qf_criterion = nn.MSELoss()

max_len = njobs+1
n_samples = 128 
n_epoch = 400 #300 #425 #300 #200
n_iter = 64
batch_size = 64 #32
gamma = 1.0

loss = []
avg_r_train = []
avg_r_eval = []
success_rates = []
relative_errors = []

In [None]:
for i in range(n_epoch):
    qf.train(False)
    paths = expl_path_collector.collect_new_paths(n_samples, max_len, False)
    train_r = mean_reward(paths)
    avg_r_train.append(train_r)
    replay_buffer.add_paths(paths)
    
    paths = eval_path_collector.collect_new_paths(n_samples//4, max_len, False)
    eval_r = mean_reward(paths)
    avg_r_eval.append(eval_r)
    
    success_rate = np.mean([p['success'] for p in paths])
    success_rates.append(success_rate)
    
    avg_makespan = mean_makespan(paths)
    relative_err = relative_makespan_error(paths)
    relative_errors.append(relative_err)

    qf.train(True)
    if i==200:
        expl_policy.eps = 0.05
        
    for _ in range(n_iter):
        batch = replay_buffer.random_batch(batch_size)

        rewards = torch.tensor([b.r for b in batch])
        terminals = torch.tensor([b.d for b in batch]).float()
        actions = torch.tensor([b.a for b in batch])
        
        states = batch_graphs([b.s for b in batch])
        next_s = batch_graphs([b.sp for b in batch])        

        out = target_qf(next_s) # shape = (|G|, |J|, |W|)
        target_q_values = get_scores(next_s, out)
        y_target = rewards + (1. - terminals) * gamma * target_q_values 
        
        out = qf(states)
        y_pred = out[torch.arange(batch_size), actions.T[1], actions.T[0]]
        qf_loss = qf_criterion(y_pred, y_target).to(torch.float)

        loss.append(qf_loss.item())

        optimizer.zero_grad()
        qf_loss.backward()
        optimizer.step()

    target_qf.load_state_dict(deepcopy(qf.state_dict()))
    err = 3
    print("Epoch", i+1, #"| lr:", scientific_notation(optimizer.param_groups[0]["lr"]) ,
          " -> Loss:", round(np.mean(loss[-n_iter:]), err),
          "| Rewards: (train)", round(train_r, err), "(test)", round(eval_r, err),
          "| Success rate:", round(success_rate, err), 
          "| Makespan:", round(avg_makespan, err), 
          "| Rel. error:", round(relative_err, err), )



Epoch 1  -> Loss: 719.501 | Rewards: (train) 11.893 (test) -0.44 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 2  -> Loss: 186.599 | Rewards: (train) 3.998 (test) 4.06 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 3  -> Loss: 212.482 | Rewards: (train) 4.606 (test) 4.101 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 4  -> Loss: 229.917 | Rewards: (train) 4.692 (test) 3.848 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 5  -> Loss: 108.61 | Rewards: (train) 4.518 (test) 3.76 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 6  -> Loss: 98.464 | Rewards: (train) 4.08 (test) 4.561 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 7  -> Loss: 128.805 | Rewards: (train) 3.74 (test) 2.693 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 8  -> Loss: 87.335 | Rewards: (train) 3.84 (test) 3.118 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 9  -> Loss: 80.239 | Rewards: (train) 3.634 (test) 3.483 

Epoch 71  -> Loss: 7.238 | Rewards: (train) 2.088 (test) 1.58 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 72  -> Loss: 4.598 | Rewards: (train) 2.568 (test) 2.309 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 73  -> Loss: 4.137 | Rewards: (train) 3.218 (test) 3.231 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 74  -> Loss: 5.183 | Rewards: (train) 2.274 (test) 2.425 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 75  -> Loss: 5.511 | Rewards: (train) 3.096 (test) 2.523 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 76  -> Loss: 3.054 | Rewards: (train) 2.978 (test) 3.316 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 77  -> Loss: 14.732 | Rewards: (train) 2.768 (test) 2.301 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 78  -> Loss: 3.952 | Rewards: (train) 2.185 (test) 1.198 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 79  -> Loss: 6.352 | Rewards: (train) 2.633 (test) 2.145 |

Epoch 141  -> Loss: 27.328 | Rewards: (train) 5.111 (test) 6.23 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 142  -> Loss: 0.633 | Rewards: (train) 6.3 (test) 7.172 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 143  -> Loss: 20.609 | Rewards: (train) 6.228 (test) 6.777 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 144  -> Loss: 2.054 | Rewards: (train) 6.506 (test) 7.835 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 145  -> Loss: 31.864 | Rewards: (train) 6.632 (test) 9.119 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 146  -> Loss: 0.528 | Rewards: (train) 5.177 (test) 7.627 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 147  -> Loss: 59.439 | Rewards: (train) 4.634 (test) 6.657 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 148  -> Loss: 41.436 | Rewards: (train) 11.634 (test) 14.349 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 149  -> Loss: 119.937 | Rewards: (train) 11.43

Epoch 210  -> Loss: 4.293 | Rewards: (train) 4.063 (test) 4.368 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 211  -> Loss: 2.091 | Rewards: (train) 4.785 (test) 3.998 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 212  -> Loss: 1.219 | Rewards: (train) 4.406 (test) 4.202 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 213  -> Loss: 25.549 | Rewards: (train) 4.093 (test) 4.552 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 214  -> Loss: 18.006 | Rewards: (train) 4.536 (test) 4.571 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 215  -> Loss: 54.431 | Rewards: (train) 4.771 (test) 5.013 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 216  -> Loss: 3.433 | Rewards: (train) 4.52 (test) 4.683 | Success rate: 0.0 | Makespan: nan | Rel. error: nan
Epoch 217  -> Loss: 6.269 | Rewards: (train) 3.961 (test) 4.487 | Success rate: 0.0 | Makespan: nan | Rel. error: nan


In [None]:
target_qf.eval()
torch.save(target_qf.state_dict(), "jobshop_qf_j%d-w%d_4x4_multilayer" % (njobs, nworkers))

#### Plot

In [None]:
import matplotlib.pyplot as plt

In [None]:
losses = [np.mean(loss[i*n_iter:(i+1)*n_iter]) for i in range(n_epoch)]

In [None]:
x = np.arange(n_epoch)

plt.figure(figsize=(15, 10))

plt.subplot(221)
plt.plot(x, losses)QQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQ
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.subplot(222)
plt.plot(x, avg_r_train, label="test")
plt.plot(x, avg_r_eval, 'r--', label="eval")
plt.legend()
plt.ylabel('Train/Test Rewards [path, avg]')
plt.xlabel('Epoch')
plt.subplot(223)
plt.plot(x, success_rates)
plt.ylabel('Success Rate')
plt.xlabel('Epoch')
plt.subplot(224)
plt.plot(x, [0.0]*len(x), 'lightgray', linestyle='--')
plt.plot(x, relative_errors)
plt.ylabel('Relative Error')
plt.xlabel('Epoch')
plt.suptitle('Training Performance Summary', y=.95)
plt.savefig('figs/job_shop/j%d-w%d_4x4_multilayer' % (njobs, nworkers), dpi=300)