In [None]:
# Use IPython's built-in completer instead of Jedi for tab completion.
%config Completer.use_jedi = False

In [None]:
import os
import sys

# Make the local rlkit checkout importable; the project modules imported below
# (env, actor, utils, path_collector) are presumably found there.
# Override with the RLKIT_PATH environment variable on other machines instead
# of editing a hard-coded user path.
RLKIT_PATH = os.environ.get('RLKIT_PATH', '/home/victorialena/rlkit')
if RLKIT_PATH not in sys.path:  # keep the cell idempotent across re-runs
    sys.path.append(RLKIT_PATH)

import pdb

In [None]:
import dgl
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from copy import copy, deepcopy
from torch.optim import Adam

In [None]:
from env.job_shop import jobShopScheduling
from actor.job_shop import *
from utils.job_shop import * # custom replay buffer and episode sampler

from path_collector import MdpPathCollector

#### Helpers

In [None]:
def scientific_notation(x):
    """Render a number in scientific notation with two decimals (e.g. '1.23e+04')."""
    return "{:.2e}".format(x)

def get_scores(g, scores):
    """Greedy bootstrap value for every graph in a batch.

    For each graph, take the best worker score of every job, then the best of
    those over the jobs whose job-feature column 3 equals 0 (presumably jobs
    that are still unscheduled). A graph with no such job gets a value of 0.

    Args:
        g: batched graph; ``g.ndata['hv']['job']`` holds the job-node features
            of all graphs, stacked in batch order.
        scores: tensor of shape (n_graphs, n_jobs, n_workers), e.g. the
            Q-network output.

    Returns:
        Tensor of shape (n_graphs,) with the same dtype/device as ``scores``.
    """
    n = scores.shape[0]
    # NOTE(review): the view requires every graph in the batch to have the same
    # number of jobs; column 3 == 0 is taken as "selectable" — confirm with env.
    available = (g.ndata['hv']['job'][:, 3] == 0).view(n, -1)

    best_per_job = scores.max(dim=-1).values  # (n_graphs, n_jobs)
    if best_per_job.shape[-1] == 0:
        # No jobs at all: max over an empty dim would raise, value is 0.
        return best_per_job.new_zeros(n)

    # Exclude unavailable jobs from the max without a Python loop.
    masked = best_per_job.masked_fill(~available, float('-inf'))
    best = masked.max(dim=-1).values  # (n_graphs,)
    # Graphs with nothing left to choose contribute a zero future value.
    return torch.where(available.any(dim=-1), best, torch.zeros_like(best))

def mean_reward(paths):
    """Average undiscounted return over a list of rollouts.

    Each path's ``'rewards'`` entry is summed into that episode's return and
    the returns are averaged. Summing per path (instead of stacking all paths
    into one tensor) also works when episodes end early and paths have
    different lengths.

    Raises:
        ValueError: if ``paths`` is empty.
    """
    if not paths:
        raise ValueError("mean_reward() needs at least one path")
    returns = [np.sum(p['rewards']) for p in paths]
    return float(np.mean(returns))

#### Q learning

In [None]:
# Seed every RNG the pipeline may draw from so runs are reproducible.
# Python's `random` is included in case the project helpers (replay buffer,
# policies) use it — NOTE(review): confirm which RNGs they actually use.
import random

seed = 42
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)

In [None]:
# Problem size: number of jobs to schedule and number of workers.
njobs = 5
nworkers = 2

env = jobShopScheduling(njobs, nworkers)
g0 = env.reset()  # initial state graph; call env.render() to inspect it

In [None]:
# Online Q-network and its exploration policy (second argument presumably epsilon).
qf = hgnn()
expl_policy = epsilonGreedyPolicy(qf, .1)

# Target Q-network. Start it as an exact copy of the online network (standard
# DQN initialisation) so that the first epoch's bootstrap targets and
# evaluation rollouts are not driven by an unrelated random initialisation.
target_qf = hgnn()
target_qf.load_state_dict(qf.state_dict())
eval_policy = epsilonGreedyPolicy(target_qf, 0.)

expl_path_collector = MdpPathCollector(env, expl_policy, rollout_fn=sample_episode, parallelize=False)
eval_path_collector = MdpPathCollector(env, eval_policy, rollout_fn=sample_episode, parallelize=False)

replay_buffer_cap = 2000 #10000
replay_buffer = replayBuffer(replay_buffer_cap, prioritized=True)

```python
# Debug scratch kept in a markdown fence (not executed by Run All):
# roll out one exploration episode (third argument presumably the max length),
# inspect the last terminal flag, render the env, then push the path into the
# replay buffer and draw a sample batch from it.
path = rollout(env, expl_policy, 2500)
path['terminals'][-1]
env.render()

replay_buffer.add_path(path, env.g)

replay_buffer.random_batch(50)
```

In [None]:
# Optimiser and regression loss for the online Q-network.
optimizer = Adam(qf.parameters(), lr=1e-4, weight_decay=0.01)
qf_criterion = nn.MSELoss()

# Training schedule.
max_len = njobs + 1   # max steps per collected path
n_samples = 128       # exploration paths collected per epoch
n_epoch = 200         # training epochs
n_iter = 64           # gradient steps per epoch
batch_size = 32       # transitions per gradient step
gamma = 1.0           # discount factor

# Logs filled by the training loop (loss per gradient step, the rest per epoch).
loss, avg_r_train, avg_r_eval, success_rates = [], [], [], []

In [None]:
for i in range(n_epoch):
    # --- Data collection: both networks in inference mode. ---
    # NOTE(review): target_qf drives the evaluation policy; eval mode only
    # matters if hgnn contains dropout / batch-norm layers.
    qf.train(False)
    target_qf.train(False)
    train_paths = expl_path_collector.collect_new_paths(n_samples, max_len, False)
    train_r = mean_reward(train_paths)
    avg_r_train.append(train_r)
    replay_buffer.add_paths(train_paths)

    # Evaluation rollouts with the target network's policy.
    paths = eval_path_collector.collect_new_paths(n_samples//4, max_len, False)
    eval_r = mean_reward(paths)
    avg_r_eval.append(eval_r)

    success_rate = np.mean([p['success'] for p in paths])
    success_rates.append(success_rate)

    # --- Q-learning updates on replayed transitions. ---
    qf.train(True)
    for _ in range(n_iter):
        batch = replay_buffer.random_batch(batch_size)
        n_batch = len(batch)  # may be smaller than batch_size for a small buffer

        rewards = torch.tensor([b.r for b in batch])
        terminals = torch.tensor([b.d for b in batch]).float()
        actions = torch.tensor([b.a for b in batch])

        states = batch_graphs([b.s for b in batch])
        next_s = batch_graphs([b.sp for b in batch])

        # Bootstrap target from the frozen target network. It must not receive
        # gradients, so build this part without autograd.
        with torch.no_grad():
            out = target_qf(next_s) # shape = (|G|, |J|, |W|)
            target_q_values = get_scores(next_s, out)
            y_target = rewards + (1. - terminals) * gamma * target_q_values

        out = qf(states)
        # Given the (|G|, |J|, |W|) layout: actions[:, 1] indexes jobs,
        # actions[:, 0] indexes workers.
        y_pred = out[torch.arange(n_batch), actions.T[1], actions.T[0]]
        # Align dtypes before the loss (rewards may arrive as float64).
        qf_loss = qf_criterion(y_pred, y_target.to(y_pred.dtype))

        loss.append(qf_loss.item())

        optimizer.zero_grad()
        qf_loss.backward()
        optimizer.step()

    # Hard target update; load_state_dict copies the tensors, no deepcopy needed.
    target_qf.load_state_dict(qf.state_dict())
    err = 8  # decimal places in the log line
    print("epoch", i+1, #"| lr:", scientific_notation(optimizer.param_groups[0]["lr"]) ,
          " -> loss:", round(np.mean(loss[-n_iter:]), err),
          "| rewards: (train)", round(train_r, err), "(test)", round(eval_r, err),
          "| success rate:", round(success_rate, err))

```python
# Debug scratch (markdown fence, not executed): reset the env, inspect the
# 'precede' edges and the Q-values of the initial state, then take one step.
g1 = env.reset()
qf.eval()
print(g1.edges(etype='precede'))
print(qf(g1))

# `env2` was never defined in this notebook; step the env that was just reset.
g2, r, d, info = env.step((0, 1))

# paths = eval_path_collector.collect_new_paths(10, 3)
```

#### Plot

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Mean loss per epoch: the training loop logs exactly n_iter losses per epoch.
losses = [np.mean(loss[start:start + n_iter]) for start in range(0, n_epoch * n_iter, n_iter)]

In [None]:
import os

x = np.arange(n_epoch)

plt.figure(figsize=(15, 4))

# Panel 1: mean Q-loss per epoch.
plt.subplot(131)
plt.plot(x, losses)
plt.ylabel('Loss')
plt.xlabel('Epoch')
# Panel 2: mean path reward of the exploration (train) and evaluation rollouts.
plt.subplot(132)
plt.plot(x, avg_r_train, label="train")
plt.plot(x, avg_r_eval, 'r--', label="eval")
plt.legend()
plt.ylabel('Train/Test Rewards [path, avg]')
plt.xlabel('Epoch')
# Panel 3: success rate of the evaluation rollouts.
plt.subplot(133)
plt.plot(x, success_rates)
plt.ylabel('Success Rate')
plt.xlabel('Epoch')
plt.suptitle('Training Performance Summary')
# savefig does not create missing directories.
os.makedirs('figs/job_shop', exist_ok=True)
plt.savefig('figs/job_shop/j%d-w%d_hor' % (njobs, nworkers), dpi=300)

In [None]:
import os

# Same curves as the horizontal summary, stacked in one column.
ylabels = ['Loss', 'Train/Test Rewards', 'Success Rate']

fig, axs = plt.subplots(3, 1, figsize=(6, 12))
plt.suptitle('Training Performance Summary', y=.92)
axs[0].plot(x, losses)
axs[1].plot(x, avg_r_train, label="train")
axs[1].plot(x, avg_r_eval, 'r--', label="eval")
axs[1].legend()
axs[2].plot(x, success_rates)

for label, ax in zip(ylabels, axs.flat):
    ax.set(xlabel='Epoch', ylabel=label)

# Only keep the "outer" tick labels / axis labels (the x label shows on the bottom panel only).
for ax in axs.flat:
    ax.label_outer()
# savefig does not create missing directories.
os.makedirs('figs/job_shop', exist_ok=True)
plt.savefig('figs/job_shop/j%d-w%d_ver' % (njobs, nworkers), dpi=300)