In [1]:
!pip install gymnasium torch_geometric



In [2]:
import torch

In [21]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# Later cells pass this device to the GNN layers and to the training loop.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Reinforcement Learning Environment

In [27]:
# Build the RL environment
from collections import Counter
import copy
import itertools
from typing import Tuple

import gymnasium as gym
from torch_geometric.data import HeteroData

class PersonnelScheduleEnv(gym.Env):
    """Gymnasium environment in which an agent builds a staff schedule edge by edge.

    The state is a bipartite ``HeteroData`` graph with ``employee`` and ``shift``
    nodes. An ``("employee", "assigned", "shift")`` edge means that the employee
    works that shift. Every action index maps to one (employee, shift) pair and
    adds the corresponding edge if it does not exist yet; existing assignments
    are never removed.
    """

    def __init__(self, employees: torch.Tensor, shifts: torch.Tensor, assignments: torch.Tensor, device) -> None:
        """Build the initial graph and the discrete action space.

        Args:
            employees: Employee node features, one row per employee.
            shifts: Shift node features, one row per shift.
            assignments: Initial edge index of shape ``(2, num_edges)``; row 0
                holds employee indices, row 1 holds shift indices.
            device: Torch device the graph is moved to.

        Raises:
            ValueError: If the assembled graph does not validate.
        """
        # setup graph
        self.device = device
        self.initial_state = HeteroData()
        self.initial_state["employee"].x = employees
        self.initial_state["shift"].x = shifts
        self.initial_state["employee", "assigned", "shift"].edge_index = assignments
        self.initial_state = self.initial_state.to(self.device)
        if self.initial_state.validate(raise_on_error=True):
            print("Data validation successful")
        else:
            raise ValueError("Data validation not successful")
        print("The graph has ", self.initial_state.num_nodes, " nodes")
        print("The graph has ", self.initial_state.num_edges, " edges")

        self.num_employees = self.initial_state["employee"].x.shape[0]
        self.num_shifts = self.initial_state["shift"].x.shape[0]
        # Action i corresponds to edge_space[i]; pairs are ordered employee-major.
        self.edge_space = list(itertools.product(range(self.num_employees), range(self.num_shifts)))
        self.action_space = gym.spaces.Discrete(len(self.edge_space))

    def _get_num_consecutive_violations(self) -> int:
        """Count employees that are planned on two shifts with adjacent indices."""
        planning = self.get_current_planning()
        num_consecutive_violations = 0
        for i in range(self.num_shifts - 1):
            # Multiset intersection: employees present in shift i and in shift i+1.
            overlap = Counter(planning[i]) & Counter(planning[i + 1])
            # sum(values()) instead of Counter.total(), which needs Python 3.10+.
            num_consecutive_violations += sum(overlap.values())
        return num_consecutive_violations

    def idx2edge(self, idx: int) -> Tuple[int, int]:
        """Return the (employee, shift) pair that belongs to action index ``idx``."""
        return self.edge_space[idx]

    def edge2idx(self, edge: Tuple[int, int]) -> int:
        """Return the action index of an (employee, shift) pair.

        Raises:
            ValueError: If the pair is outside the action table.
        """
        employee, shift = edge
        if not (0 <= employee < self.num_employees and 0 <= shift < self.num_shifts):
            raise ValueError(f"{edge!r} is not in the edge space")
        # edge_space is the employee-major product, so the index is closed-form.
        return employee * self.num_shifts + shift

    def get_num_planned_shifts(self) -> int:
        """Return the number of assignment edges in the current state."""
        return self.state["assigned"].edge_index.shape[1]

    def get_current_planning(self) -> dict:
        """Return a dict mapping every shift index to the list of assigned employee indices."""
        planning = {shift: [] for shift in range(self.num_shifts)}
        # One bulk transfer instead of one .item() call (and device sync) per edge.
        employee_ids, shift_ids = self.state["assigned"].edge_index.tolist()
        for employee, shift in zip(employee_ids, shift_ids):
            planning[shift].append(employee)
        return planning

    def info(self) -> dict:
        """Return the auxiliary info dict (currently always empty)."""
        return {}

    def step(self, action: int) -> tuple[HeteroData, float, bool, bool, dict]:
        """Add the assignment selected by ``action`` if it is not planned yet.

        Invalid action indices and already existing assignments leave the
        state unchanged. The returned observation is always a copy, so callers
        can store it without it being changed by later steps.

        Returns:
            Tuple ``(state, reward, terminated, truncated, info)``.
        """
        if not self.action_space.contains(action):
            return (copy.deepcopy(self.state), self.reward(), self.terminated(), self.truncated(), self.info())

        edge_index = self.state["assigned"].edge_index
        # Create the new edge on the same device/dtype as the stored edges so
        # the comparison and the concatenation also work on CUDA.
        action_edge = torch.tensor(self.idx2edge(action), dtype=edge_index.dtype, device=edge_index.device)

        # A column equal to action_edge means the assignment already exists.
        already_assigned = bool((edge_index == action_edge[:, None]).all(dim=0).any())

        # if no shift exists, create one
        if not already_assigned:
            self.state["assigned"].edge_index = torch.hstack((edge_index, action_edge[:, None]))

        return (copy.deepcopy(self.state), self.reward(), self.terminated(), self.truncated(), self.info())

    def reset(self, seed: int = None, options: dict = None) -> tuple[HeteroData, dict]:
        """Restore the initial graph and return a copy of it together with the info dict.

        Args:
            seed: Forwarded to ``gym.Env.reset`` to seed the environment RNG.
            options: Accepted for Gymnasium API compatibility; not used.
        """
        super().reset(seed=seed)
        # Copy, otherwise step() would write its new edges into initial_state
        # and a later reset would not restore the initial schedule.
        self.state = copy.deepcopy(self.initial_state)
        return (copy.deepcopy(self.state), self.info())

    def reward(self) -> float:
        """Return +1 per planned assignment, -10 per consecutive violation and +1000 once terminated."""
        reward = 0
        reward = reward + self.get_num_planned_shifts()
        reward = reward - 10 * self._get_num_consecutive_violations()
        if self.terminated():
            reward = reward + 1000
        return reward

    def terminated(self) -> bool:
        """Return True when every shift has at least two assigned employees."""
        planning = self.get_current_planning()
        # Check all shifts; the earlier range(num_shifts-1) skipped the last one.
        return all(len(planning[i]) >= 2 for i in range(self.num_shifts))

    def truncated(self) -> bool:
        """Return False; the environment itself never truncates an episode."""
        return False

In [4]:
import heapq

from torch import distributions
from torch_geometric.nn import RGCNConv 

class RL_agent():
    """Policy that scores every (employee, shift) pair from GNN node embeddings."""

    def __init__(self, gnn_employees, gnn_shifts):
        """Store the two message-passing layers.

        Args:
            gnn_employees: Layer that passes messages from employees to shifts
                and therefore returns shift embeddings.
            gnn_shifts: Layer that passes messages from shifts to employees
                and therefore returns employee embeddings.
        """
        self.gnn_employees = gnn_employees
        self.gnn_shifts = gnn_shifts

    def decode(self, emb_employees, emb_shifts):
        """Return the dot product of every employee embedding with every shift embedding.

        The result is a 1-D tensor of length ``num_employees * num_shifts`` in
        employee-major order, i.e. entry ``e * num_shifts + s`` scores the pair
        ``(e, s)``.
        """
        # Broadcast to (num_employees, num_shifts, dim) and reduce over the
        # embedding axis. No squeeze(): it collapsed the pair axis when there
        # was only a single (employee, shift) pair and made sum(dim=1) fail.
        pairwise = emb_employees[:, None, :] * emb_shifts[None, :, :]
        return pairwise.sum(dim=-1).reshape(-1)

    def encode(self, state):
        """Embed the employee and shift nodes of ``state``.

        Returns:
            Tuple ``(emb_employees, emb_shifts)``.
        """
        edge_index = state["assigned"]["edge_index"]
        # Single relation type; created on the device of the edges so the call
        # also works when the graph lives on the GPU.
        edge_type = torch.zeros(edge_index.shape[1], dtype=torch.int64, device=edge_index.device)

        # Use the layers stored on the agent, not notebook-level globals.
        emb_shifts = self.gnn_employees(
            x=(state.x_dict["employee"], state.x_dict["shift"]),
            edge_index=edge_index,
            edge_type=edge_type,
        )

        # Reverse the edge direction to pass messages from shifts to employees.
        assignments_flipped = torch.vstack((edge_index[1], edge_index[0]))
        emb_employees = self.gnn_shifts(
            x=(state.x_dict["shift"], state.x_dict["employee"]),
            edge_index=assignments_flipped,
            edge_type=edge_type,
        )

        return emb_employees, emb_shifts

    def get_policy(self, state):
        """Return a categorical distribution over all (employee, shift) actions."""
        logits = self.decode(*self.encode(state))
        policy_distribution = distributions.categorical.Categorical(logits=logits)
        return policy_distribution

In [5]:
# Disabled graph visualisation: the code is wrapped in a string literal, so the
# cell only displays the text and does not execute it.
# NOTE(review): no `env` variable is defined above this cell - create an
# environment and call reset() before re-enabling this snippet.
"""
import networkx as nx 
from torch_geometric.utils import to_networkx

g=to_networkx(env.state)
nx.draw(g)"""

'\nimport networkx as nx \nfrom torch_geometric.utils import to_networkx\n\ng=to_networkx(env.state)\nnx.draw(g)'

## Training environment

In [6]:
import random

import torch
def get_random_employees(max_n=5):
    """Return feature rows for a random number of identical employees.

    Args:
        max_n: Upper bound (inclusive) for the number of employees; the lower
            bound is 1.

    Returns:
        Float tensor of shape ``(n, 1)`` filled with ones, ``1 <= n <= max_n``.
    """
    n = random.randint(1, max_n)
    # Allocate real storage. expand() returned a stride-0 view in which all
    # rows share one memory cell, so an in-place write to one row changed all.
    return torch.ones((n, 1), dtype=torch.float)
def get_week_shifts():
    """Return the feature matrix of the 14 shifts of one week.

    Rows are ordered Monday day, Monday night, Tuesday day, ..., Sunday night.
    Each row has 7 features: columns 0-5 one-hot encode Tuesday to Sunday
    (Monday is the all-zero reference) and column 6 is 1 for a night shift.

    Returns:
        Float tensor of shape ``(14, 7)``.
    """
    num_features = 7
    rows = []
    for day in range(7):
        for is_night in (0, 1):
            features = [0] * num_features
            if day > 0:
                # Monday (day 0) has no indicator column of its own.
                features[day - 1] = 1
            features[-1] = is_night
            rows.append(features)
    return torch.tensor(rows, dtype=torch.float)
def get_empty_assignments():
    """Return an edge index without any edges (int64 tensor of shape ``(2, 0)``)."""
    no_edges = torch.empty((2, 0), dtype=torch.int64)
    return no_edges

In [12]:
class Actor:
    """Runs an agent in an environment by sampling actions from its policy."""

    def __init__(self, agent, max_steps = 20, env = None, device = None):
        """Store the agent and the environment it acts in.

        Args:
            agent: Object with a ``get_policy(state)`` method returning a
                distribution whose ``sample()`` yields the action index.
            max_steps: Maximum number of steps per sampled episode.
            env: Environment to act in. If omitted, a one-week environment
                with a random number of employees is created.
            device: Device for the default environment; CPU when omitted.
                Ignored if ``env`` is given.
        """
        self.agent = agent
        self.max_steps = max_steps
        if env is not None:
            self.env = env
        else:
            if device is None:
                device = torch.device("cpu")
            # PersonnelScheduleEnv requires a device; it was missing here.
            self.env = PersonnelScheduleEnv(employees=get_random_employees(), shifts=get_week_shifts(), assignments=get_empty_assignments(), device=device)

    def sample_episode(self):
        """Play one episode until the environment terminates or ``max_steps`` is reached."""
        state, _ = self.env.reset()
        terminated = False
        i = 0
        while not terminated and i < self.max_steps:
            # Keep the termination flag; discarding it made every episode run
            # for the full max_steps.
            _, state, _, terminated = self._step(state)
            i = i + 1

    def _step(self, current_state):
        """Sample one action for ``current_state`` and apply it.

        Returns:
            Tuple ``(action, state, reward, terminated)``.
        """
        action = self.agent.get_policy(current_state).sample().item()

        state, reward, terminated, _, _ = self.env.step(action)
        return action, state, reward, terminated
        

In [8]:
from collections import namedtuple

# Experience replay for RL: one stored step of an episode.
#   state          - observation the action was chosen in
#   action         - index of the chosen action
#   reward         - reward returned by the environment for that step
#   future_returns - undiscounted sum of the rewards from this step to the end
#                    of the episode (filled in by MemoryWrapper.sample_episode)
Transition = namedtuple("Transition", ("state", "action", "reward", "future_returns"))

class ReplayMemory():
    """Fixed-capacity buffer of transitions; the oldest entries are dropped first."""

    def __init__(self, capacity):
        """Create an empty buffer that keeps at most ``capacity`` transitions."""
        # Imported here because the notebook imports deque only in a later
        # cell; the class must not depend on that cell having run.
        from collections import deque
        self.memory = deque([], maxlen=capacity)

    def push(self, transition):
        """Append a transition, evicting the oldest one when the buffer is full."""
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen at random.

        Raises:
            ValueError: If ``batch_size`` exceeds the number of stored transitions.
        """
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Return the number of stored transitions."""
        return len(self.memory)

In [9]:
from collections import deque

class MemoryWrapper():
    """Wraps an actor so that every sampled episode is recorded in a replay memory."""

    def __init__(self, actor, memory):
        """Remember the wrapped ``actor`` and the ``memory`` that receives its transitions."""
        self.actor = actor
        self.memory = memory

    def sample_episode(self):
        """Play one episode with the wrapped actor and store all of its steps.

        Each step is stored as a ``Transition`` whose ``future_returns`` is the
        undiscounted sum of the rewards from that step to the end of the
        episode. Transitions are pushed last step first.
        """
        state, _ = self.actor.env.reset()

        # Roll out the episode, remembering the state each action was taken in.
        trajectory = []
        done = False
        while not done and len(trajectory) < self.actor.max_steps:
            action, next_state, reward, done = self.actor._step(state)
            trajectory.append((state, action, reward))
            state = next_state

        # Walk backwards so the return-to-go is a running sum.
        return_to_go = 0
        recorded = []
        for past_state, past_action, past_reward in reversed(trajectory):
            return_to_go = return_to_go + past_reward
            recorded.append(Transition(past_state, past_action, past_reward, return_to_go))

        for item in recorded:
            self.memory.push(item)

In [29]:
from tqdm import tqdm
from torch import distributions

class Training():
    """Policy-gradient training loop for the scheduling agent.

    Every epoch samples one episode into a replay memory and then performs one
    gradient update on a random batch of stored transitions.
    """

    def __init__(self, gnn_employees, gnn_shifts, tensorboard, device, max_steps=20, num_epoch=100, batch_size=128, learning_rate=1e-3):
        """Create the environment, agent, recording actor and optimizer.

        Args:
            gnn_employees: Layer producing shift embeddings (see ``RL_agent``).
            gnn_shifts: Layer producing employee embeddings (see ``RL_agent``).
            tensorboard: Writer with an ``add_scalar(tag, value, step)`` method.
            device: Torch device used for the environment and the batches.
            max_steps: Maximum number of steps per sampled episode.
            num_epoch: Number of episode/update iterations.
            batch_size: Number of transitions per gradient update.
            learning_rate: Step size of the Adam optimizer.
        """
        self.gnn_employees = gnn_employees
        self.gnn_shifts = gnn_shifts
        self.tensorboard = tensorboard
        self.device = device
        self.num_epoch = num_epoch
        self.batch_size = batch_size
        # Defined up front so _gradient_update can log even when it is called
        # outside of start_training.
        self.epoch = 0
        self.memory = ReplayMemory(10000)
        env = PersonnelScheduleEnv(employees=get_random_employees(), shifts=get_week_shifts(), assignments=get_empty_assignments(), device=device)
        self.agent = RL_agent(gnn_employees, gnn_shifts)
        self.actor = MemoryWrapper(Actor(self.agent, max_steps=max_steps, env=env), self.memory)
        # Without an optimizer the gradients were computed but never applied.
        self.optimizer = torch.optim.Adam(
            list(gnn_employees.parameters()) + list(gnn_shifts.parameters()),
            lr=learning_rate,
        )

    def _gradient_update(self):
        """Run one policy-gradient step on a random batch from the replay memory.

        Does nothing until the memory holds at least ``batch_size`` transitions.
        """
        if len(self.memory) < self.batch_size:
            return

        transitions = self.memory.sample(self.batch_size)
        # converts batch_array of Transitions to Transition of batch_arrays
        batch = Transition(*zip(*transitions))

        state_batch = batch.state
        action_batch = torch.tensor(batch.action, dtype=torch.int).to(self.device)
        reward_batch = torch.tensor(batch.reward, dtype=torch.float).to(self.device)
        future_returns_batch = torch.tensor(batch.future_returns, dtype=torch.float).to(self.device)

        # The graphs are not batched, so each state is encoded separately.
        logits_list = [
            self.agent.decode(*self.agent.encode(state))
            for state in state_batch
        ]
        logits_batch = torch.vstack(logits_list).to(self.device)

        policy_distribution = distributions.categorical.Categorical(logits=logits_batch)
        log_probs = policy_distribution.log_prob(action_batch)

        # Weight the log-probability of every taken action by its return-to-go.
        loss = -(log_probs * future_returns_batch).sum() / self.batch_size

        # Clear stale gradients, backpropagate and apply the update.
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.tensorboard.add_scalar("loss", loss.detach().cpu(), self.epoch)
        self.tensorboard.add_scalar("avg_reward", reward_batch.mean().detach().cpu(), self.epoch)
        self.tensorboard.add_scalar("avg_future_returns", future_returns_batch.mean().detach().cpu(), self.epoch)

    def start_training(self):
        """Alternate between sampling one episode and one gradient update for ``num_epoch`` epochs."""
        for epoch in tqdm(range(self.num_epoch)):
            self.epoch = epoch
            self.actor.sample_episode()
            self._gradient_update()

## Test suite

In [30]:
from torch.utils.tensorboard import SummaryWriter
# TensorBoard writer used by Training to log loss and reward statistics.
# The event files are written to the current working directory ("./").
tb_summary = SummaryWriter("./", purge_step=0)

In [31]:
# Bipartite relational graph convolutions with a single relation type and
# 2-dimensional output embeddings. in_channels is (source dim, target dim):
# employees have 1 feature, shifts have 7.
# gnn_employees passes messages employee -> shift (yields shift embeddings),
# gnn_shifts passes messages shift -> employee (yields employee embeddings).
gnn_employees = RGCNConv(in_channels = (1,7), out_channels=2, num_relations=1).to(device)
gnn_shifts = RGCNConv(in_channels = (7,1), out_channels=2, num_relations=1).to(device)
# Build the training setup and run it; this is the long-running cell.
training = Training(gnn_employees, gnn_shifts, tb_summary, device, num_epoch=10000)
training.start_training()

Data validation successful
The graph has  16  nodes
The graph has  0  edges


  0%|▏                                                                             | 31/10000 [00:21<1:52:41,  1.47it/s]


KeyboardInterrupt: 

In [None]:
# Inspect the (employee, shift) action table of the training environment.
# training.actor is the MemoryWrapper, its .actor is the wrapped Actor.
sorted(training.actor.actor.env.edge_space)

In [None]:
# Check the embedding shapes produced by the encoder. The cell uses the agent
# and environment of the training run above instead of the `agent` and `state`
# variables, which are only defined in a later cell.
state, _ = training.actor.actor.env.reset()
emb_employees, emb_shifts = training.agent.encode(state)
print(emb_employees.shape)
print(emb_shifts.shape)

In [None]:
# Manual smoke test of agent, environment, actor and replay memory.
# It only uses methods that RL_agent actually provides (encode, decode,
# get_policy) and passes the device that PersonnelScheduleEnv requires.
agent = RL_agent(gnn_employees, gnn_shifts)
env = PersonnelScheduleEnv(employees=get_random_employees(), shifts=get_week_shifts(), assignments=get_empty_assignments(), device=device)
actor = Actor(agent, max_steps=20, env=env)

state, _ = env.reset()
print("state: ", state)
print("encode state: ", agent.encode(state))

# One score per (employee, shift) pair, in the order of env.edge_space.
scores = agent.decode(*agent.encode(state))
print("scores: ", scores)
best_action = scores.argmax().item()
print("best action: ", best_action)
print("corresponding edge: ", env.idx2edge(best_action))

# Take two sampled steps and show how state and scores change.
for _ in range(2):
    action = agent.get_policy(state).sample().item()
    print("action: ", action)
    state, reward, terminated, _, _ = env.step(action)
    print("state: ", state)
    print("reward: ", reward)
    print("terminated: ", terminated)

    scores = agent.decode(*agent.encode(state))
    print("scores: ", scores)

# sample_episode() returns nothing; it is run here only to check that it works.
actor.sample_episode()

# Record an episode into a replay memory and show the stored transitions.
memory = ReplayMemory(100)
actor = MemoryWrapper(actor, memory)
actor.sample_episode()
print(memory.memory)

## Save as Python script

In [32]:
import os

!jupyter nbconvert --to script "RL_env.ipynb"
filename = "RL_env.py"

# delete this cell from python file
lines = []
with open(filename, 'r') as fp:
    lines = fp.readlines()
with open(filename, 'w') as fp:
    for number, line in enumerate(lines):
        if number < len(lines)-17: 
            fp.write(line)

[NbConvertApp] Converting notebook RL_env.ipynb to script
[NbConvertApp] Writing 14066 bytes to RL_env.py
