# Map-related methods

## Import necessary packages

In [44]:
# Import relevant libraries

%reload_ext autoreload
%autoreload 2

import sys
sys.path.append('src/')

import numpy as np
import os
import pandas as pd
import math
from ast import literal_eval
import matplotlib
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim

import time
import random
import seaborn as sns
import collections
from collections import deque
from typing import List, Optional, Tuple, Union, Callable, Dict, Sequence, NamedTuple
from pathlib import Path

from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import SparseLineGen
from flatland.envs.malfunction_generators import malfunction_from_params
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
from flatland.envs.rail_env import RailEnvActions
from training import train_agent

# Base flatland environment
from flatland.envs.line_generators import SparseLineGen
from flatland.envs.malfunction_generators import (
    MalfunctionParameters,
    ParamMalfunctionGen,
)
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import SparseRailGen
from flatland.envs.observations import GlobalObsForRailEnv

from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.distance_map import DistanceMap
import flatland.envs.rail_env_shortest_paths as sp

from src import test_utils, training, rewards
from src.observation_utils import normalize_observation
from src.models import *
from src.deep_model_policy import DeepPolicy, PolicyParameters

from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4_utils import get_new_position, direction_to_point
from flatland.envs.rail_env import RailEnvActions
from flatland.envs.step_utils.states import TrainState

from importlib_resources import path

# Visualization
from flatland.utils.rendertools import RenderTool


## Create the environment

In [23]:
# Create the environment

env = RailEnv(
    width=20,
    height=20,
    rail_generator=SparseRailGen(
        max_num_cities=2,  # Number of cities
        grid_mode=True,
        max_rails_between_cities=2,
        max_rail_pairs_in_city=1,
    ),
    line_generator=SparseLineGen(speed_ratio_map={1.: 1.}
        ),
    number_of_agents=2, 
    obs_builder_object=TreeObsForRailEnv(max_depth=3),
    malfunction_generator=ParamMalfunctionGen(
        MalfunctionParameters(
            malfunction_rate=0.,  # Rate of malfunction
            min_duration=3,  # Minimal duration
            max_duration=20,  # Max duration
        )
    ),
)

_,_ = env.reset()

## Classic SARSA

In [27]:
# Useful functions for displaying tables

# Display the Q-table as a set of heatmaps, one for each action
def qtable_display(
    q_array: np.ndarray, 
    title: Optional[str] = None, 
    figsize: Tuple[int, int] = (4, 4), 
    annot: bool = True, 
    fmt: str = "0.1f", 
    linewidths: float = .5, 
    square: bool = True, 
    cbar: bool = False, 
    cmap: str = "Reds"
) -> None:
    """
    Display a Q-table as a set of heatmaps, one for each action.

    For the frozen lake environment, there are 16 states and 4 actions thus this function will display 4 heatmaps, one for each action.
    Each heatmap will display the Q-values for each state when performing the action indexed by the heatmap.

    Parameters
    ----------
    q_array : np.ndarray
        The Q-table to display. Each row corresponds to a state, and each column corresponds to an action.
        In the frozen lake environment, there are 16 states and 4 actions thus the Q-table has a shape of (16, 4).
        For instance, q_array[0, 3] is the Q-value (estimation of the expected reward) for performing action 3 ("move up") in state 0 (the top left square).
    title : str, optional
        The title of the plot, by default None
    figsize : tuple, optional
        The size of the figure (in inches), by default (4, 4)
    annot : bool, optional
        If True, write the data value in each cell, by default True
    fmt : str, optional
        The string formatting code to use when adding annotations, by default "0.1f" that will display a single decimal
    linewidths : float, optional
        The width of the lines that will divide each cell, by default .5
    square : bool, optional
        Whether to set the Axes aspect to "equal" so each cell is square-shaped, by default True
    cbar : bool, optional
        Whether to draw a colorbar, by default False
    cmap : str, optional
        The mapping from data values to color space, by default "Reds"

    Returns
    -------
    None
    """
    # Get the number of actions from the shape of the Q-table
    num_actions = q_array.shape[1]

    # Adjust the figure size (in inches) based on the number of actions
    global_figsize = list(figsize)
    global_figsize[0] *= num_actions

    # Create a subplot for each action
    fig, ax_list = plt.subplots(ncols=num_actions, figsize=global_figsize)

    # For each action, display the Q-values for all states as a heatmap
    for action_index in range(num_actions):
        ax = ax_list[action_index]

        # Retrieve the Q-values for each state when performing the action indexed by "action_index".
        # This forms a 1D array, state_vec, where state_vec[i] = Q(i, action_index).
        state_vec = q_array[:,action_index]

        # Display the Q-values for each state when performing the action indexed by "action_index"
        # i.e. display Q(., action_index)
        states_display(
            state_vec,
            title=r"$Q(\cdot,a_{})$".format(action_index),
            #title=r"$Q(\cdot,a_{})$ {}".format(action_index, action_labels[action_index]),
            figsize=figsize, 
            annot=annot, 
            fmt=fmt, 
            linewidths=linewidths, 
            square=square, 
            cbar=cbar, 
            cmap=cmap, 
            ax=ax
        )

    # Set the title for the entire figure
    plt.suptitle(title)
    # Display the figure
    plt.show()

In [28]:
def states_display(
    state_seq: Sequence[float], 
    title: Optional[str] = None, 
    figsize: Tuple[int, int] = (5, 5), 
    annot: bool = True, 
    fmt: str = "0.1f", 
    linewidths: float = .5, 
    square: bool = True, 
    cbar: bool = False, 
    cmap: str = "Reds", 
    ax: Optional[matplotlib.axes.Axes] = None
) -> Optional[matplotlib.axes.Axes]:
    """
    Display the expected values of all states as a heatmap.

    Parameters
    ----------
    state_seq : Sequence[float]
        The sequence of expected values to display. This can be a list, a 1D array, etc.
        Each element is the estimation of the expected value of the corresponding state.
        For example, state_seq[0] is the estimation of the expected value of the first state.
        There are 16 elements in this sequence for the frozenlake environment, i.e., one per state of the environment.
    title : str, optional
        The title of the plot, by default None
    figsize : tuple, optional
        The size of the figure (in inches), by default (5, 5)
    annot : bool, optional
        If True, write the data value in each cell, by default True
    fmt : str, optional
        The string formatting code to use when adding annotations, by default "0.1f"
    linewidths : float, optional
        The width of the lines that will divide each cell, by default .5
    square : bool, optional
        Whether to set the Axes aspect to "equal" so each cell is square-shaped, by default True
    cbar : bool, optional
        Whether to draw a colorbar, by default False
    cmap : str, optional
        The mapping from data values to color space, by default "Reds"
    ax : matplotlib.axes.Axes, optional
        The axes object to draw the heatmap on, by default None

    Returns
    -------
    matplotlib.axes.Axes, optional
        The axes object with the heatmap if one was provided, otherwise None.
    """
    # Calculate the size of the state array
    size = int(math.sqrt(len(state_seq)))
    # Convert the state sequence to a numpy array (if it isn't already one)
    state_array = np.array(state_seq)
    # Reshape the state array into a square matrix
    # (we assume here that the environment has a state space that can be visualized as a square grid)
    state_array = state_array.reshape(size, size)

    # If no axes object is provided, create a new figure and axes
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)

    # Create a heatmap of the state array on the axes
    ax = sns.heatmap(
        state_array, 
        annot=annot, 
        fmt=fmt, 
        linewidths=linewidths, 
        square=square, 
        cbar=cbar, 
        cmap=cmap,
        ax=ax
    )
    
    # If a title is provided, set the title of the plot
    if title is not None:
        ax.set_title(title)

    # If no axes object was provided, display the plot
    # Otherwise, return the axes object with the heatmap
    if ax is None:
        plt.show()
    else:
        return ax

In [29]:
# Define epsilon-greedy policy

def greedy_policy(state: int, q_array: np.ndarray) -> int:
    """
    Determine the action that maximizes the Q-value for a given state.

    Parameters
    ----------
    state : int
        The current state.
    q_array : np.ndarray
        The Q-table.

    Returns
    -------
    int
        The action that maximizes the Q-value for the given state.
    """
    action = np.argmax(q_array[state, :])

    return action


def epsilon_greedy_policy(state: int, q_array: np.ndarray, epsilon: float) -> int:
    """
    Determine the action to take based on an epsilon-greedy policy.

    Parameters
    ----------
    state : int
        The current state.
    q_array : np.ndarray
        The Q-table.
    epsilon : float
        The probability of choosing a random action.

    Returns
    -------
    int
        The action to take.
    """
    if np.random.rand() < epsilon:
        # With probability epsilon, choose a random action
        action = np.random.choice(len(q_array[state, :]))
    else:
        # With probability 1 - epsilon, choose the action with the highest Q-value
        action = np.argmax(q_array[state, :])
        
    return action



In [7]:
DISPLAY_EVERY_N_EPISODES = 50

# Initialize the history of the Q-table and learning rate
q_array_history = []
alpha_history = []

def sarsa(
    environment: RailEnv, 
    alpha: float = 0.1, 
    alpha_factor: float = 0.9995, 
    gamma: float = 0.99, 
    epsilon: float = 0.5, 
    num_episodes: int = 1500, 
    display: bool = False
) -> np.ndarray:
    """
    Perform SARSA learning on a given environment.

    Parameters
    ----------
    environment : RailEnv
        The environment to learn in.
    alpha : float, optional
        The learning rate, between 0 and 1. By default 0.1
    alpha_factor : float, optional
        The factor to decrease alpha by each episode, by default 0.9995
    gamma : float, optional
        The discount factor, between 0 and 1. By default 0.99
    epsilon : float, optional
        The probability of choosing a random action, by default 0.5
    num_episodes : int, optional
        The number of episodes to run, by default 1500
    display : bool, optional
        Whether to display the Q-table (every DISPLAY_EVERY_N_EPISODES episodes), by default False

    Returns
    -------
    np.ndarray
        The learned Q-table.
    """
    
    # Get the number of states and actions in the environment
    num_states = env.height * env.width
    num_actions = 5    # MOVE_LEFT, MOVE_FORWARD, MOVE_RIGHT, STOP_MOVING, DO_NOTHING

    # Initialize the Q-table to zeros
    q_array = np.zeros([num_states, num_actions])

    # Loop over the episodes
    for episode_index in range(num_episodes):
        # Display the Q-table every DISPLAY_EVERY_N_EPISODES episodes if display is True
        if display and episode_index % DISPLAY_EVERY_N_EPISODES == 0:
            qtable_display(q_array, title="Q table")
        else:
            print('.', end="")

        # Save the current Q-table and learning rate
        q_array_history.append(q_array.copy())
        alpha_history.append(alpha)

        # Decrease the learning rate if alpha_factor is not None
        if alpha_factor is not None:
            alpha = alpha * alpha_factor

        s = environment.reset()[0]
        done = False 
        a = epsilon_greedy_policy(s, q_array, epsilon)
        while done is False:
            s_, r, done, _ , _= environment.step(a)
            a_ = epsilon_greedy_policy(s_, q_array, epsilon)
            q_array[s, a] += alpha * (r + gamma * q_array[s_, a_] - q_array[s, a])
            s = s_
            a = a_        

    # Return the learned Q-table
    return q_array

## Mixed strategy : bitmaps and DQN

## Q-Learning from bitmaps

**Structure:** (from Devid Farinelli and Giulia Cantini)

A "rail occupancy bitmap" indicated on which rail and in which direction the agent is traveling at every timestep:
1. A directed graph representation of the railway network is generated through BFS (breadth-firth search), each node is a switch and each edge is a rail between two switches.
2. The shortest path for each agent is computed
3. The path is transformed into a bitmap with the timesteps as columns and the rails as rows. The direction is 1 if the agent is traveling the edge from the source node to the destination node or -1 otherwise.

Heatmaps are used to provide information about how the traffic is distributed across the rails over time.
Each agent computes 2 heatmaps, one positive and one negative, both are generated summing the bitmaps of all the other agents.

The network architecture is a Dueling DQN, with a Conv2D layer as input processing the concatenated agent bitmap, positive and negative heatmaps. The data goes through two separate streams (value and advantage) to be recombined in the final output Q values.

In [3]:
# Dueling DQN model

def dim_output(input_dim, filter_dim, stride_dim):
    return (input_dim - filter_dim) // stride_dim + 1

class Dueling_DQN(nn.Module):
    def __init__(self, width, height, action_space):
        super(Dueling_DQN, self).__init__()
        
        self.action_space = action_space
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=(1, width))
        self.fc1_adv = nn.Linear(in_features=64 * height, out_features=512) 
        self.fc1_val = nn.Linear(in_features=64 * height, out_features=512)
        self.fc2_adv = nn.Linear(in_features=512, out_features=action_space)
        self.fc2_val = nn.Linear(in_features=512, out_features=1)
        self.relu = nn.ReLU()

    def forward(self, x): # x is the input
        x = self.relu(self.conv1(x))
        x = x.view(x.size(0), -1)

        adv = self.relu(self.fc1_adv(x))
        val = self.relu(self.fc1_val(x))

        adv = self.fc2_adv(adv)
        val = self.fc2_val(val).expand(x.size(0), self.action_space)

        x = val + adv - adv.mean(1).unsqueeze(1).expand(x.size(0), self.action_space)
        return x

Create bitmaps

In [31]:
# Preprocessing the heatmaps and bitmap

class ObsPreprocessor:
    def __init__(self, max_rails, reorder_rails):
        self.max_rails = max_rails
        self.reorder_rails = reorder_rails

    def _fill_padding(self, obs, max_rails):
        """
        
        :param obs: Agent state 
        :param max_rails: Maximum number of rails in environment 
        :return: Observation padded with 0s along first axis (until max_rails)
        
        """
        prediction_depth = obs.shape[1]
        
        pad_agent_obs = np.zeros((max_rails, prediction_depth))
        pad_agent_obs[:obs.shape[0], :obs.shape[1]] = obs
        
        return pad_agent_obs

    def _get_heatmap(self, handle, bitmaps, max_rails):
        temp_bitmaps = np.copy(bitmaps)
        temp_bitmaps[handle, :, :] = 0
        pos_dir = np.sum(np.where(temp_bitmaps > 0, temp_bitmaps, 0), axis=0)
        neg_dir = np.abs(np.sum(np.where(temp_bitmaps < 0, temp_bitmaps, 0), axis=0))
    
        return pos_dir, neg_dir

    def _swap_rails(self, bitmap, swap):
        bitmap[range(len(swap))] = bitmap[swap]
        bitmap[len(swap):, :] = 0
        return bitmap

    def _reorder_rails(self, bitmap, pos_map, neg_map): 
        swap = np.array([], dtype=int) 

        ts = 0 
        rail = np.argmax(np.absolute(bitmap[:, ts]))
        # If agent not departed
        if bitmap[rail, ts] == 0: 
            ts = 1
            rail = np.argmax(np.absolute(bitmap[:, ts]))
        
        # While the bitmap is not empty
        while bitmap[rail, ts] != 0: 
            swap = np.append(swap, rail)
            ts += np.argmax(bitmap[rail, ts:] == 0) 
            rail = np.argmax(np.absolute(bitmap[:, ts]))

        if len(swap) > 0: 
            bitmap = self._swap_rails(bitmap, swap)
            pos_map = self._swap_rails(pos_map, swap)
            neg_map = self._swap_rails(neg_map, swap)
        
        return bitmap, pos_map, neg_map

    def get_obs(self, handle, bitmap, maps):
        pos_map, neg_map = self._get_heatmap(handle, maps, self.max_rails)

        if self.reorder_rails:
            bitmap, pos_map, neg_map = self._reorder_rails(bitmap, pos_map, neg_map)

        state = np.concatenate([
            self._fill_padding(bitmap, self.max_rails),
            self._fill_padding(pos_map, self.max_rails),
            self._fill_padding(neg_map, self.max_rails)
        ])
        
        return state

In [46]:
Waypoint = NamedTuple(
    'Waypoint', [('position', Tuple[int, int]), ('direction', int)])

CardinalNode = \
	NamedTuple('CardinalNode', [('id_node', int), ('cardinal_point', int)])

Now, we implement a class that returns the rails occupancy as a bitmap with rails on y-axis and timesteps on x-axis. Rails are edges and the 1/-1 in the bitmap indicate the direction of the agent on the rail.

In [6]:
"""
        --- timesteps --->
rail 0: 1 1 1       -1-1
rail 1:      1 1
rail 2:         -1-1
.
.
rail n:
"""

'\n        --- timesteps --->\nrail 0: 1 1 1       -1-1\nrail 1:      1 1\nrail 2:         -1-1\n.\n.\nrail n:\n'

In [34]:
class RailObsForRailEnv(ObservationBuilder):

	def __init__(self, predictor):
		"""
		predictor: class that predicts the path.
		"""
		super(RailObsForRailEnv, self).__init__()
  
		self.predictor = predictor
		self.num_agents = None
		self.num_rails = None # computed in reset()
		self.max_time_steps = self.predictor.max_depth

		
		self.cell_to_id_node = {} # Map cell position : id_node
		self.id_node_to_cell = {} # Map id_node to cell position
		self.info = {} # Map id_edge : tuple (CardinalNode1, CardinalNode2, edge_length)
		self.id_edge_to_cells = {} # Map id_edge : list of tuples (cell pos, crossing dir) in rail
		self.nodes = set() # Set of node ids
		self.edges = set() # Set of edge ids

		self.recompute_bitmap = True

	def set_env(self, env: Environment):
		super().set_env(env)
		if self.predictor:
			self.predictor.set_env(self.env)

	def reset(self):
		self.cell_to_id_node = {}
		self.id_node_to_cell = {}
		self.info = {}
		self.id_edge_to_cells = {}
		self.nodes = set()
		self.edges = set()
		self._map_to_graph()
		self.recompute_bitmap = True
		self.num_agents = len(self.env.agents)

		# Calculate agents timesteps per cell
		self.tpc = dict()
		for a in range(self.num_agents):
			agent_speed = self.env.agents[a].speed_data['speed']
			self.tpc[a] = int(np.reciprocal(agent_speed))
		
	def get_many(self, handles: Optional[List[int]] = None):
		maps = None
		# Compute bitmaps from shortest paths
		if self.recompute_bitmap:
			self.recompute_bitmap = False

			prediction_dict = self.predictor.get()
			self.paths = self.predictor.shortest_paths
			cells_sequence = self.predictor.compute_cells_sequence(prediction_dict)

			maps = np.zeros((self.num_agents, self.num_rails, self.max_time_steps + 1), dtype=int)
			for a in range(self.num_agents):
				maps[a, :, :] = self._bitmap_from_cells_seq(a, cells_sequence[a])

			maps = np.roll(maps, 1)
			maps[:, :, 0] = 0

		return maps

	def get_altmaps(self, handle):
		agent = self.env.agents[handle]
		altpaths, cells_seqs = self.predictor.get_altpaths(handle, self.cell_to_id_node)
		maps = []
		for i in range(len(cells_seqs)):
			bitmap = self._bitmap_from_cells_seq(handle, cells_seqs[i])

			# If agent not departed, add 0 at the beginning
			if agent.status == TrainState.READY_TO_DEPART:
				bitmap[:, -1] = 0
				bitmap = np.roll(bitmap, 1)

			maps.append(bitmap)

		return maps, altpaths

	def get_agent_action(self, handle):
		agent = self.env.agents[handle]
		action = RailEnvActions.DO_NOTHING
		
		if agent.status == TrainState.READY_TO_DEPART:
			action = RailEnvActions.MOVE_FORWARD

		elif agent.status == TrainState.MOVING:
			if self.paths[handle] is None or len(self.paths[handle]) == 0:  # Railway disrupted
				print('[WARN] AGENT {} RAIL DISRUPTED'.format(handle))
				action = RailEnvActions.STOP_MOVING
			else:
				# Get action
				step = self.paths[handle][0]
				next_action_element = step.next_action_element.action  # Get next_action_element

				if next_action_element == 1:
					action = RailEnvActions.MOVE_LEFT
				elif next_action_element == 2:
					action = RailEnvActions.MOVE_FORWARD
				elif next_action_element == 3:
					action = RailEnvActions.MOVE_RIGHT
				
				self.paths[handle] = self.paths[handle][1:]

		return action

	def is_before_switch(self, a):
		agent = self.env.agents[a]
		before_switch = False

		if agent.state == TrainState.MOVING :
			if len(self.paths[a]) > 0:
				curr_pos = agent.position
				next_pos = self.paths[a][0].next_action_element.next_position
				curr_rail, _ = self._get_edge_from_cell(curr_pos)
				next_rail, _ = self._get_edge_from_cell(next_pos)
				before_switch = curr_rail != -1 and next_rail == -1
			else:
				print('[WARN] agent\'s {} path run out'.format(a))
				before_switch = True

		return before_switch

	def _get_rail_dir(self, a, maps, ts=0):
		rail = np.argmax(np.absolute(maps[a, :, ts]))
		direction = maps[a, rail, ts]
		return rail, direction

	def _delay(self, a, maps, rail, direction, delay):
		tpc = self.tpc[a]
		old_rail, old_dir = self._get_rail_dir(a, maps)

		maps[a] = np.roll(maps[a], delay)
		maps[a, :, 0:delay+tpc] = 0				 # Reset the first bits
		maps[a, old_rail, 0:tpc] = old_dir       # Fill the first with the current rail info
		maps[a, rail, tpc:tpc+delay] = direction # Add delay to the next rail
		
		return maps

	def _is_cell_occupied(self, a, cell):
		occupied = False

		for other in range(self.env.get_num_agents()):
			if other != a and self.env.agents[other].position == cell:
				occupied = True
				break
		
		return occupied

	def update_bitmaps(self, a, maps, is_before_switch=False):
		# Calculate exit time when switching rail
		if is_before_switch:
			tpc = self.tpc[a]
   
			next_rail, next_dir = self._get_rail_dir(a, maps, ts=tpc)
			_, last_exit = self._last_train_on_rail(a, next_rail, maps)
   
			if last_exit > 0: # Check if rail is already occupied
				# tpc: skips the first bits that are curr_rail
				curr_exit = np.argmax(maps[a, next_rail, tpc:] == 0)
				# Also consider the last cell of curr_rail
				curr_exit += tpc
			
				if curr_exit <= last_exit:
					delay = last_exit + tpc - curr_exit
					maps = self._delay(a, maps, next_rail, next_dir, delay)

		maps[a, :, 0] = 0
		maps[a] = np.roll(maps[a], -1)
		return maps
	
	def set_agent_path(self, a, path):
		self.paths[a] = path

	def _last_train_on_rail(self, a, rail, maps):
		"""
		Find train preceding agent 'handle' on rail.
		:param maps: 
		:param rail: 
		:param handle: 
		:return: 
		"""
		last, last_exit = 0, 0 # Final train, its expected exit time

		for other in range(self.env.get_num_agents()):
			if other == a or self.env.agents[other].status == TrainState.READY_TO_DEPART:
				continue

			tpc = self.tpc[other]

			
			if maps[other, rail, 0] != 0:  # If agent is already on this rail
				other_exit = np.argmax(maps[other, rail, :] == 0)

				if other_exit > last_exit:
					last, last_exit = other, other_exit

			# We use tpc-1, to skip the first bits of trains that have decided to enter rail, but are still crossing the cell before
		
			elif maps[other, rail, tpc - 1] != 0:
				other_rail, _ = self._get_rail_dir(other, maps)
				other_exit = 0

				# Consider the time to cross the current cell
				if other_rail != rail:
					other_exit = np.argmax(maps[other, other_rail, :] == 0)

				# Add the estimated exit time
				other_exit += np.argmax(maps[other, rail, other_exit:] == 0)

				if other_exit > last_exit:
					last, last_exit = other, other_exit

		return last, last_exit

	def _get_trains_on_rails(self, maps, rail, handle):
		trains = []
		for a in range(self.env.get_num_agents()):
			if not (maps[a, rail, 0] == 0 or a == handle):
				expected_exit_time = np.argmax(maps[a, rail, :] == 0) 
				trains.append((a, expected_exit_time))
		trains.sort()
		
		return trains

	def _get_edge_from_cell(self, cell):
		"""
		:param cell: Cell for which we want to find the associated rail id.
		:return: A tuple (id rail, dist) where dist is the distance as offset from the beginning of the rail.
		"""
		for edge in self.id_edge_to_cells.keys():
			cells = [cell[0] for cell in self.id_edge_to_cells[edge]] 
			if cell in cells:
				return edge, cells.index(cell)

		return -1, -1

	def _bitmap_from_cells_seq(self, handle, path) -> np.ndarray:
		"""
		Compute bitmap for agent handle, given a selected path.
		:param handle: 
		:return: 
		"""
		bitmap = np.zeros((self.num_rails, self.max_time_steps + 1), dtype=int)  # Max steps in the future + current ts
		agent = self.env.agents[handle]
		# Truncate path in the future, after reaching target
		target_index = [i for i, pos in enumerate(path) if pos[0] == agent.target[0] and pos[1] == agent.target[1]]
		if len(target_index) != 0:
			target_index = target_index[0]
			path = path[:target_index + 1]

		# Add 0 at first ts - for 'not departed yet'
		rail, _ = self._get_edge_from_cell(path[0])

		# Agent's cardinal node, where it entered the last edge
		agent_entry_node = None
		# Calculate initial edge entry point
		i = 0
		rail, _ = self._get_edge_from_cell(path[i])
		if rail != -1: # If it's on an edge
			initial_rail = rail
			# Search first switch
			while rail != -1:
				i += 1
				rail, _ = self._get_edge_from_cell(path[i])

			src, dst, _ = self.info[initial_rail]
			node_id = self.cell_to_id_node[path[i]]
			# Reversed because we want the switch's cp
			entry_cp = self._reverse_dir(direction_to_point(path[i-1], path[i]))
			# If we reach the dst node
			if (node_id, entry_cp) == dst:
				# We entered from the src node (cross_dir = 1)
				agent_entry_node = src
			# Otherwise the opposite
			elif (node_id, entry_cp) == src: 
				agent_entry_node = dst
		else:
			#Handle the case you call this while on a switch before a rail
			node_id = self.cell_to_id_node[path[i]]
			# Calculate exit direction (that's the entry cp for the next edge)
			cp = direction_to_point(path[0], path[1]) # it's ok
			# Not reversed because it's already relative to a switch
			agent_entry_node = CardinalNode(node_id, cp)

		holes = 0
		# Fill rail occupancy according to predicted position at ts
		for ts in range(0, len(path)):
			cell = path[ts]
			# Find rail associated to cell
			rail, _ = self._get_edge_from_cell(cell)
			# Find crossing direction
			if rail == -1: # Agent is on a switch
				holes += 1
				# Skip duplicated cells (for agents with fractional speed)
				if ts+1 < len(path) and cell != path[ts+1]:
					node_id = self.cell_to_id_node[cell]
					# Calculate exit direction (that's the entry cp for the next edge)
					cp = direction_to_point(cell, path[ts+1])
					# Not reversed because it's already relative to a switch
					agent_entry_node = CardinalNode(node_id, cp)
			else: # Agent is on a rail
				crossing_dir = None
				src, dst, _ = self.info[rail]
				if agent_entry_node == dst:
					crossing_dir = 1
				elif agent_entry_node == src: 
					crossing_dir = -1

				assert crossing_dir != None

				bitmap[rail, ts] = crossing_dir

				if holes > 0:
					bitmap[rail, ts-holes:ts] = crossing_dir
					holes = 0

		assert(holes == 0, "All the cells of the bitmap should be filled")

		temp = np.any(bitmap[:, 1:(len(path)-1)] != 0, axis=0)
		assert(np.all(temp), "Thee agent's bitmap shouldn't have holes ")
		return bitmap

	def _map_to_graph(self):
		"""
		Build the representation of the map as a graph.
		:return: 
		"""
		id_node_counter = 0
		connections = {}
		# targets = [agent.target for agent in self.env.agents]

		# Identify cells hat are nodes (switches or diamond crossings)
		for i in range(self.env.height):
			for j in range(self.env.width):

				is_switch = False
				is_crossing = False
				# is_target = False
				connections_matrix = np.zeros((4, 4))  # Matrix NESW x NESW

				# Check if diamond crossing
				transitions_bit = bin(self.env.rail.get_full_transitions(i, j))
				if int(transitions_bit, 2) == int('1000010000100001', 2):
					is_crossing = True
					connections_matrix[0, 2] = connections_matrix[2, 0] = 1
					connections_matrix[1, 3] = connections_matrix[3, 1] = 1

				else:
					# Check if target
					# if (i, j) in targets:
					#	is_target = True
					# Check if switch
					for direction in (0, 1, 2, 3):  # 0:N, 1:E, 2:S, 3:W
						possible_transitions = self.env.rail.get_transitions(i, j, direction)
						for t in range(4):  # Check groups of bits
							if possible_transitions[t]:
								inv_direction = (direction + 2) % 4
								connections_matrix[inv_direction, t] = connections_matrix[t, inv_direction] = 1
						num_transitions = np.count_nonzero(possible_transitions)
						if num_transitions > 1:
							is_switch = True

				if is_switch or is_crossing: #or is_target:
					# Add node - keep info on cell position
					# Update only for nodes that are switches
					connections.update({id_node_counter: connections_matrix})
					self.id_node_to_cell.update({id_node_counter: (i, j)})
					self.cell_to_id_node.update({(i, j): id_node_counter})
					id_node_counter += 1

		# Enumerate edges from these nodes
		id_edge_counter = 0
		# Start from connections of one node and follow path until next switch is found
		nodes = connections.keys()  # ids
		visited = set()  # Keeps set of CardinalNodes that were already visited
		for n in nodes:
			for cp in range(4):  # Check edges from the 4 cardinal points
				if np.count_nonzero(connections[n][cp, :]) > 0:
					visited.add(CardinalNode(n, cp))  # Add to visited
					cells_sequence = []
					node_found = False
					edge_length = 0
					# Keep going until another node is found
					direction = cp
					pos = self.id_node_to_cell[n]
					while not node_found:
						neighbour_pos = get_new_position(pos, direction)
						cells_sequence.append((neighbour_pos, direction))
						if neighbour_pos in self.cell_to_id_node:  # If neighbour is a node
							# node_found = True
							# Build edge, mark visited
							id_node1 = n
							cp1 = cp
							id_node2 = self.cell_to_id_node[neighbour_pos]
							cp2 = self._reverse_dir(direction)
							if CardinalNode(id_node2, cp2) not in visited:
								self.info.update({id_edge_counter:
									                  (CardinalNode(id_node1, cp1),
									                   CardinalNode(id_node2, cp2),
									                   edge_length)})
								cells_sequence.pop()  # Don't include this node in the edge
								self.id_edge_to_cells.update({id_edge_counter: cells_sequence})
								id_edge_counter += 1
								visited.add(CardinalNode(id_node2, cp2))
							break
						edge_length += 1  # Not considering switches in the count
						# Update pos and dir
						pos = neighbour_pos
						exit_dir = self._reverse_dir(direction)
						possible_transitions = np.array(self.env.rail.get_transitions(pos[0], pos[1], direction))
						possible_transitions[exit_dir] = 0  # Don't consider direction from which I entered
						# t = 2
						t = np.argmax(possible_transitions)  # There's only one possible transition except the one that I took to get in
						temp_pos = get_new_position(pos, t)
						if 0 <= temp_pos[0] < self.env.height and 0 <= temp_pos[1] < self.env.width:  # Patch - check if this cell is a rail
							# Entrance dir is always opposite to exit dir
							direction = t
						else:
							break

		self.nodes = nodes # Set of nodes
		self.edges = self.info.keys() # Set of edges
		self.num_rails = len(self.edges)

	@staticmethod
	def _reverse_dir(direction):
		"""
		Invert direction (int) of one agent.
		:param direction: 
		:return: 
		"""
		return int((direction + 2) % 4)
		pass

  assert(holes == 0, "All the cells of the bitmap should be filled")
  assert(np.all(temp), "Thee agent's bitmap shouldn't have holes ")


Shortest path prediction builder (Romain's one)

In [41]:
env.reset()

def get_shortest_paths(env, vis=False):
    distance_map = DistanceMap(env.agents, env.width, env.height)
    distance_map.reset(env.agents, env.rail)
    distance_map.get()

    # Visualize the distance map
    if vis:
        sp.visualize_distance_map(distance_map, 0)
        sp.visualize_distance_map(distance_map, 1)

    shortest_paths = sp.get_shortest_paths(distance_map)
    for handle in shortest_paths.keys():
        if len(shortest_paths) <= 1:
            shortest_paths[handle] = 2
        elif env.agents[handle].position is None:
            shortest_paths[handle] = 2 # Forward = start moving in the map
        else:
            next_cell = shortest_paths[handle][1] # Next cell to visit
            shortest_paths[handle] = sp.get_action_for_move(env.agents[handle].position, 
                                                        env.agents[handle].direction,
                                                        next_cell.position,
                                                        next_cell.direction,
                                                        env.rail)
    return shortest_paths



In [42]:
class SingleAgentShortest(TreeObsForRailEnv):
    '''Implements shortest path observation for the agents.'''
    def __init__(self):
        super().__init__(max_depth=0)

    def reset(self):
        super().reset()

    def get(self, handle):
        return get_shortest_paths(self.env)[handle]

Final process

In [43]:
# Create the environment

env = RailEnv(
    width=20,
    height=15,
    rail_generator=SparseRailGen(
        max_num_cities=2, 
        grid_mode=True,
        max_rails_between_cities=2,
        max_rail_pairs_in_city=1,
    ),
    line_generator=SparseLineGen(speed_ratio_map={1.: 1.}
        ),
    number_of_agents=2, 
    obs_builder_object=TreeObsForRailEnv(max_depth=3),
    malfunction_generator=ParamMalfunctionGen(
        MalfunctionParameters(
            malfunction_rate=0.,  # Rate of malfunction
            min_duration=3,  # Minimal duration
            max_duration=20,  # Max duration
        )
    ),
)

_,_ = env.reset()

In [45]:
prediction_depth = 40
observation_builder = GlobalObsForRailEnv(bfs_depth=4, predictor=sp(max_depth=prediction_depth))

state_size = prediction_depth + 5
network_action_size = 2 # we limit ourselves to 2 actions : stop or move forward
controller = Dueling_DQN(20, 15, network_action_size)
railenv_action_dict = dict()
    
evaluation_number = 0
while True:
    
    evaluation_number += 1
    obs, info = env.reset()
    if not obs:
        break
    
    print("Test Number : {}".format(evaluation_number))

    number_of_agents = len(env.agents)
    steps = 0
    for a in range(number_of_agents):
        action = 2
        railenv_action_dict.update({a:action})
    obs, all_rewards, done, info = env.step(railenv_action_dict)
    
    while done['__all__'] == False:
        for a in range(number_of_agents):
            if info['action_required'][a]:
                network_action = controller.act(obs[a])
                railenv_action = RailObsForRailEnv.get_agent_action(a)
            else:
                railenv_action = 0
            railenv_action_dict.update({a: railenv_action})
            
        obs, all_rewards, done, info = env.step(railenv_action_dict)
        steps += 1
        
        if done['__all__']:
            print("Reward : ", sum(list(all_rewards.values())))
            break

TypeError: 'module' object is not callable

In [None]:
# Train the policy
obs_params = {
    "observation_tree_depth": 3,
    "observation_radius": 10,
}
train_params = {
    "eps_start": 1.0,
    "eps_end": 0.01,
    "eps_decay": 0.99,
    "n_episodes": 40,
    "checkpoint_interval": 50,
    "n_eval_episodes": 10,
    "restore_replay_buffer": False,
    "save_replay_buffer": False,
    "render": False,
    "buffer_size": int(1e5),
    "LSTM" : True
}


model = Dueling_DQN(env.width, env.height, 5)
policy = 

train_agent(model, policy, train_params, obs_params)

In [None]:
# Test the policy on the best seed
env_renderer = test_utils.render_one_test(env, policy, obs_params, seed=5, real_time_render=False, force_gif=True)
env_renderer.make_gif('test')