In [1]:
from dataclasses import dataclass
import os

import torch 
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.nn import RGCNConv

from custom_modules import Training

2024-09-02 14:15:50.481714: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-09-02 14:15:50.623316: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-02 14:15:50.623367: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-02 14:15:50.623402: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-02 14:15:50.641007: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-09-02 14:15:50.643049: I tensorflow/core/platform/cpu_feature_guard.cc:182] This Tens

In [2]:
@dataclass
class TrainingConfig:
    """Run settings and hyperparameters shared by the cells of this notebook.

    Every attribute is type-annotated so that ``@dataclass`` registers it as a
    real field. Without annotations the decorator generates an ``__init__``
    that takes no arguments, so individual settings could not be overridden
    per run (e.g. ``TrainingConfig(num_epoch=100)``) and the generated
    ``__repr__``/``__eq__`` would ignore all of them.
    """
    # Directory handed to the TensorBoard writer and to Training as output_dir.
    output_dir: str = "RL_PersSched"
    num_epoch: int = 220000  # 10'000 needs 1 hour with 3x CPU, 3GB RAM and 1x 1080ti
    max_steps: int = 20
    batch_size: int = 128
    # Evaluated once, when the class body runs: CUDA if available, else CPU.
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    replay_size: int = 10000
    eval_every_n_epochs: int = 20
    lr: float = 1e-4
    gamma: float = 0.99
    # Width of the projected node features and of the GNN input/output channels.
    hidden_dim: int = 3
    num_message_passing: int = 4
    clip_norm: float = 100

config = TrainingConfig()

## Training

In [3]:
# Make sure the run directory exists, then point the TensorBoard writer at it.
os.makedirs(config.output_dir, exist_ok=True)
tb_summary = SummaryWriter(log_dir=config.output_dir, purge_step=0)

In [7]:
# Input widths of the two linear projections below
# (presumably the raw per-employee and per-shift feature sizes -- TODO confirm).
dim_employee = 1
dim_shift = 7

# Relational graph convolution with a single relation type; source and target
# node features both have width config.hidden_dim, as does the output.
gnn = RGCNConv(
    in_channels=(config.hidden_dim, config.hidden_dim),
    out_channels=config.hidden_dim,
    num_relations=1,
).to(config.device)

# Linear layers mapping each input width onto the shared hidden width.
projection_employees = torch.nn.Linear(dim_employee, config.hidden_dim).to(config.device)
projection_shifts = torch.nn.Linear(dim_shift, config.hidden_dim).to(config.device)

In [24]:
from custom_modules import ReplayMemory, ActorMemoryWrapper, DataGenerator, PersonnelScheduleEnv, RLagent, Actor

# Build a scheduling environment from generated employees, the week's shifts
# and an empty assignment set, placed on the configured device.
env = PersonnelScheduleEnv(
    employees=DataGenerator.get_random_employees(),
    shifts=DataGenerator.get_week_shifts(),
    assignments=DataGenerator.get_empty_assignments(),
    device=config.device
    )
env.reset()

# NOTE(review): the trailing 4 equals config.num_message_passing and is
# presumably the number of message-passing rounds -- confirm against
# RLagent's signature before replacing the literal with the config value.
agent = RLagent(gnn, projection_employees, projection_shifts, 4)

# Read the step limit from the shared config (previously a hard-coded 20) so
# this interactive actor cannot drift from the value Training is given.
actor = Actor(agent, env=env, max_steps=config.max_steps)

In [25]:
# Display the policy's probability vector for the current environment state
# (here: right after env.reset(), before any step has been taken).
agent.get_policy(env.state).probs

tensor([0.0710, 0.0712, 0.0707, 0.0709, 0.0713, 0.0716, 0.0710, 0.0713, 0.0719,
        0.0721, 0.0713, 0.0716, 0.0719, 0.0722], grad_fn=<SoftmaxBackward0>)

In [26]:
# Take action index 1 and unpack the five values env.step returns.
state, reward, terminated, truncated, info = env.step(1)

In [27]:
# Display the policy again after stepping with index 1; in the recorded output
# entry 1 is 0.0000 (presumably already-taken actions are masked -- TODO confirm).
agent.get_policy(env.state).probs

tensor([0.0289, 0.0000, 0.0282, 0.0705, 0.0263, 0.0657, 0.0210, 0.0524, 0.0683,
        0.1709, 0.0357, 0.0893, 0.0979, 0.2449], grad_fn=<SoftmaxBackward0>)

In [28]:
# Take action index 9 and unpack the five values env.step returns.
state, reward, terminated, truncated, info = env.step(9)

In [29]:
# Display the policy after stepping with index 9; in the recorded output
# entries 1 and 9 are both 0.0000.
agent.get_policy(env.state).probs

tensor([0.0361, 0.0000, 0.0353, 0.0861, 0.0328, 0.0799, 0.0264, 0.0643, 0.0828,
        0.0000, 0.0442, 0.1078, 0.1176, 0.2867], grad_fn=<SoftmaxBackward0>)

In [30]:
# Take action index 13 and unpack the five values env.step returns.
state, reward, terminated, truncated, info = env.step(13)

In [31]:
# Display the policy after stepping with index 13; in the recorded output
# entries 1, 9 and 13 are all 0.0000.
agent.get_policy(env.state).probs

tensor([0.0509, 0.0000, 0.0497, 0.1201, 0.0463, 0.1120, 0.0373, 0.0902, 0.1163,
        0.0000, 0.0623, 0.1505, 0.1645, 0.0000], grad_fn=<SoftmaxBackward0>)

In [5]:
# A single AdamW optimizer over the parameters of the GNN and both projection
# layers, collected in that order.
optimizer = torch.optim.AdamW(
    [
        param
        for module in (gnn, projection_employees, projection_shifts)
        for param in module.parameters()
    ],
    lr=config.lr,
    amsgrad=True,
)

In [6]:
# The model, optimizer, projections and TensorBoard writer go in positionally;
# every remaining setting is forwarded by keyword from `config`.
training = Training(
    gnn, optimizer, projection_employees, projection_shifts, tb_summary,
    # where to run and where to write
    device=config.device,
    output_dir=config.output_dir,
    # model / objective settings
    num_message_passing=config.num_message_passing,
    gamma=config.gamma,
    clip_norm=config.clip_norm,
    # loop sizes and schedule
    max_steps=config.max_steps,
    num_epoch=config.num_epoch,
    batch_size=config.batch_size,
    replay_size=config.replay_size,
    eval_every_n_epochs=config.eval_every_n_epochs,
)
training.start_training()

  0%|                                                                                        | 0/220000 [00:00<?, ?it/s]

False
52
0
False
55
0
False
55
0
False
53
0
False
43
0
False
4
0
False
55
0
False
10
0
False
55
0
False
46
0
False
4
0
False
2
0
False
4
0
False
13
0
False
43
0
False
11
0
False
50
0
False
55
0
False
55
0
False
13
0
False
3
0
False
7
0
False
3
0
False
5
0
False
5
0
False
5
0
False
5
0
False
3
0
False
5
0
False
7
0
False
12
0
False
3
0
False
5
0
False
12
0
False
3
0
False
13
0
False
12
0
False
12
0
False
3
0
False
12
0
False
10
0
False
13
0
False
3
0
False
10
0
False
18
0
False
18
0
False
27
0
False
17
0
False
24
0
False
3
0
False
13
0
False
19
0
False
17
0
False
25
0
False
18
0
False
18
0
False
13
0
False
9
0
False
1
0
False
13
0


  0%|                                                                             | 2/220000 [00:01<33:12:48,  1.84it/s]

False
20
0
False
20
0
False
27
0
False
27
0
False
27
0
False
27
0
False
27
0
False
27
0
False
27
0
False
27
0
False
59
0
False
69
0
False
62
0
False
69
0
False
59
0
False
27
0
False
27
0
False
61
0
False
69
0
False
17
0
False
47
0
False
47
0
False
47
0
False
47
0
False
47
0
False
47
0
False
61
0
False
61
0
False
47
0
False
61
0
False
47
0
False
47
0
False
47
0
False
61
0
False
54
0
False
54
0
False
61
0
False
47
0
False
68
0
False
48
0


  0%|                                                                             | 3/220000 [00:01<25:07:10,  2.43it/s]

False
32
0
False
54
0
False
46
0
False
31
0
False
32
0
False
39
0
False
31
0
False
46
0
False
45
0
False
39
0
False
41
0
False
54
0
False
41
0
False
24
0
False
38
0
False
52
0
False
22
0
False
36
0
False
36
0
False
50
0


  0%|                                                                             | 4/220000 [00:01<23:33:31,  2.59it/s]

False
27
0
False
27
0
False
27
0
False
27
0
False
27
0
False
27
0
False
27
0
False
27
0
False
27
0
False
56
0
False
14
0
False
14
0
False
68
0
False
61
0
False
27
0
False
27
0
False
61
0
False
56
0
False
69
0
False
2
0


  0%|                                                                             | 5/220000 [00:02<20:01:04,  3.05it/s]

False
32
0
False
67
0
False
39
0
False
58
0
False
41
0
False
30
0
False
30
0
False
68
0
False
40
0
False
68
0
False
69
0
False
41
0
False
40
0
False
41
0
False
17
0
False
27
0
False
8
0
False
22
0
False
64
0
False
31
0


  0%|                                                                             | 6/220000 [00:02<19:00:32,  3.21it/s]

False
47
0
False
57
0
False
43
0
False
55
0
False
58
0
False
44
0
False
44
0
False
44
0
False
57
0
False
69
0
False
57
0
False
68
0
False
47
0
False
48
0
False
47
0
False
55
0
False
47
0
False
32
0
False
44
0
False
57
0
False
1
0
False
3
0
False
1
0
False
51
0
False
9
0
False
9
0
False
46
0
False
7
0
False
46
0
False
45
0
False
51
0
False
1
0
False
43
0
False
49
0
False
45
0
False
51
0
False
46
0
False
13
0
False
46
0
False
3
0


  0%|                                                                             | 7/220000 [00:04<51:27:20,  1.19it/s]

False
8
0
False
5
0
False
5
0
False
8
0
False
5
0
False
5
0
False
13
0
False
5
0
False
5
0
False
5
0
False
13
0
False
8
0
False
13
0
False
5
0
False
5
0
False
12
0
False
12
0
False
12
0
False
34
0
False
6
0


  0%|                                                                            | 8/220000 [00:09<130:22:36,  2.13s/it]

False
49
0
False
55
0
False
55
0
False
49
0
False
55
0
False
33
0
False
35
0
False
23
0
False
51
0
False
33
0
False
55
0
False
55
0
False
41
0
False
37
0
False
33
0
False
37
0
False
48
0
False
48
0
False
55
0
False
55
0


  0%|                                                                            | 9/220000 [00:13<163:18:27,  2.67s/it]

False
17
0
False
40
0
False
40
0
False
62
0
False
34
0
False
20
0
False
34
0
False
61
0
False
17
0
False
19
0
False
40
0
False
26
0
False
68
0
False
40
0
False
64
0
False
40
0
False
40
0
False
40
0
False
40
0
False
36
0


  0%|                                                                           | 10/220000 [00:17<203:45:16,  3.33s/it]

False
29
0
False
29
0
False
29
0
False
34
0
False
22
0
False
15
0
False
22
0
False
30
0
False
22
0
False
16
0
False
30
0
False
29
0
False
29
0
False
1
0
False
27
0
False
34
0
False
40
0
False
26
0
False
33
0
False
34
0


  0%|                                                                           | 11/220000 [00:19<174:39:25,  2.86s/it]

False
25
0
False
19
0
False
22
0
False
19
0
False
17
0
False
61
0
False
59
0
False
64
0
False
64
0
False
19
0
False
25
0
False
64
0
False
19
0
False
64
0
False
67
0
False
25
0
False
61
0
False
61
0
False
19
0
False
59
0


  0%|                                                                           | 11/220000 [00:20<113:09:17,  1.85s/it]

KeyboardInterrupt



## Save as python script

In [None]:
import os

!jupyter nbconvert --to script "RL_env.ipynb"
filename = "RL_env.py"

# delete this cell from python file
lines = []
with open(filename, 'r') as fp:
    lines = fp.readlines()
with open(filename, 'w') as fp:
    for number, line in enumerate(lines):
        if number < len(lines)-17: 
            fp.write(line)