In [1]:
from dataclasses import dataclass
import os

import torch 
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.nn import RGCNConv

from custom_modules import Training

2024-08-28 14:01:29.865480: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-08-28 14:01:29.916098: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-28 14:01:29.916183: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-28 14:01:29.916227: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-08-28 14:01:29.925586: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-08-28 14:01:29.926487: I tensorflow/core/platform/cpu_feature_guard.cc:182] This Tens

In [2]:
@dataclass
class TrainingConfig:
    """Hyperparameters and paths for this training run.

    Every attribute is annotated, so each one is a real dataclass field and can
    be overridden at construction time, e.g. ``TrainingConfig(lr=3e-4)``.
    """

    output_dir: str = "RL_PersSched"  # TensorBoard log dir; also passed to Training as output_dir
    num_epoch: int = 220_000  # ~10,000 epochs take about 1 hour with 3 CPUs, 3 GB RAM and 1x GTX 1080 Ti
    max_steps: int = 20  # NOTE(review): presumably the step cap per episode — confirm in Training
    batch_size: int = 128
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # CUDA when available, else CPU
    replay_size: int = 10_000  # NOTE(review): presumably the replay-buffer capacity — confirm in Training
    eval_every_n_epochs: int = 20
    lr: float = 1e-4  # learning rate shared by both AdamW optimizers
    gamma: float = 0.9  # presumably the RL discount factor (passed to Training as gamma)


config = TrainingConfig()

## Training

In [3]:
# Create the output directory first, then open the TensorBoard writer inside it.
# purge_step=0: if this log dir is reused, previously logged events with
# global_step >= 0 are hidden in TensorBoard, so a rerun starts with a clean history.
os.makedirs(config.output_dir, exist_ok=True)
tb_summary = SummaryWriter(config.output_dir, purge_step=0)

In [4]:
# Two single-relation R-GCN layers with 2 output channels each. A tuple
# `in_channels` gives RGCNConv the (source, target) node feature sizes, so the
# two layers mirror each other.
# NOTE(review): which node type carries 1 vs. 7 features comes from the
# environment's graph encoding — confirm against custom_modules.
_rgcn_shared = {"out_channels": 2, "num_relations": 1}
gnn_employees = RGCNConv(in_channels=(1, 7), **_rgcn_shared).to(config.device)
gnn_shifts = RGCNConv(in_channels=(7, 1), **_rgcn_shared).to(config.device)

In [5]:
def _build_optimizer(module):
    """Return an AMSGrad-variant AdamW optimizer over `module`'s parameters at `config.lr`."""
    return torch.optim.AdamW(module.parameters(), lr=config.lr, amsgrad=True)


optimizer_employees = _build_optimizer(gnn_employees)
optimizer_shifts = _build_optimizer(gnn_shifts)

In [6]:
# Hand both GNNs, their optimizers and the TensorBoard writer to the project's
# Training loop (custom_modules.Training), then run it.
training = Training(
    gnn_employees,
    gnn_shifts,
    optimizer_employees,
    optimizer_shifts,
    tb_summary,
    device=config.device,
    gamma=config.gamma,
    max_steps=config.max_steps,
    num_epoch=config.num_epoch,
    batch_size=config.batch_size,
    replay_size=config.replay_size,
    eval_every_n_epochs=config.eval_every_n_epochs,
    output_dir=config.output_dir,
)
try:
    training.start_training()
finally:
    # Runs are often stopped early (e.g. KeyboardInterrupt). Flush and close the
    # writer so already-logged TensorBoard events reach disk. The exception
    # itself still propagates.
    tb_summary.close()

  0%|                                                                           | 219/220000 [02:50<47:39:59,  1.28it/s]

KeyboardInterrupt



## Save as python script

In [None]:
# Export this notebook to a Python script, then cut this export section (its
# markdown header and this cell) out of the generated file so the script does
# not re-run the export. The cut point is located by the header text rather
# than a fixed line count, so editing this cell cannot silently trim the wrong lines.
import subprocess
from pathlib import Path

notebook_path = Path("RL_env.ipynb")
script_path = notebook_path.with_suffix(".py")
# nbconvert's script exporter renders markdown cells as "# "-prefixed comment lines.
export_header = "# ## Save as python script"

# check=True: stop here if nbconvert fails, instead of trimming a stale or missing script.
subprocess.run(["jupyter", "nbconvert", "--to", "script", str(notebook_path)], check=True)

script_lines = script_path.read_text(encoding="utf-8").splitlines(keepends=True)
# Search from the end so the last occurrence of the header (this section) is used.
header_idx = next(
    (i for i in reversed(range(len(script_lines))) if script_lines[i].strip() == export_header),
    None,
)
if header_idx is None:
    raise RuntimeError(
        f"export header {export_header!r} not found in {script_path}; script left untrimmed"
    )
script_path.write_text("".join(script_lines[:header_idx]), encoding="utf-8")