# Map-related methods

## Import necessary packages

In [1]:
# Import relevant libraries

%reload_ext autoreload
%autoreload 2

import sys
sys.path.append('src/')

import numpy as np
import os
import pandas as pd
from ast import literal_eval
import matplotlib.pyplot as plt
import torch
import time
import random

# Base flatland environment
from flatland.envs.line_generators import SparseLineGen
from flatland.envs.malfunction_generators import (
    MalfunctionParameters,
    ParamMalfunctionGen,
)
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import SparseRailGen
from flatland.envs.observations import GlobalObsForRailEnv

from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.distance_map import DistanceMap
import flatland.envs.rail_env_shortest_paths as sp

from src import test_utils, training, rewards
from src.observation_utils import normalize_observation
from src.models import *
from src.deep_model_policy import DeepPolicy, PolicyParameters

# Visualization
from flatland.utils.rendertools import RenderTool


## Create the environment

In [2]:
# Create the environment

env = RailEnv(
    width=20,
    height=15,
    rail_generator=SparseRailGen(
        max_num_cities=2,  # Number of cities
        grid_mode=True,
        max_rails_between_cities=2,
        max_rail_pairs_in_city=1,
    ),
    line_generator=SparseLineGen(speed_ratio_map={1.: 1.}
        ),
    number_of_agents=2, 
    obs_builder_object=TreeObsForRailEnv(max_depth=3),
    malfunction_generator=ParamMalfunctionGen(
        MalfunctionParameters(
            malfunction_rate=0.,  # Rate of malfunction
            min_duration=3,  # Minimal duration
            max_duration=20,  # Max duration
        )
    ),
)

_,_ = env.reset()

## Classic SARSA

In [None]:
DISPLAY_EVERY_N_EPISODES = 50

# Initialize the history of the Q-table and learning rate
q_array_history = []
alpha_history = []

def sarsa(
    environment: gym.Env, 
    alpha: float = 0.1, 
    alpha_factor: float = 0.9995, 
    gamma: float = 0.99, 
    epsilon: float = 0.5, 
    num_episodes: int = 10000, 
    display: bool = False
) -> np.ndarray:
    """
    Perform SARSA learning on a given environment.

    Parameters
    ----------
    environment : gym.Env
        The environment to learn in.
    alpha : float, optional
        The learning rate, between 0 and 1. By default 0.1
    alpha_factor : float, optional
        The factor to decrease alpha by each episode, by default 0.9995
    gamma : float, optional
        The discount factor, between 0 and 1. By default 0.99
    epsilon : float, optional
        The probability of choosing a random action, by default 0.5
    num_episodes : int, optional
        The number of episodes to run, by default 10000
    display : bool, optional
        Whether to display the Q-table (every DISPLAY_EVERY_N_EPISODES episodes), by default False

    Returns
    -------
    np.ndarray
        The learned Q-table.
        Each row corresponds to a state, and each column corresponds to an action.
        In the frozen lake environment, there are 16 states and 4 actions thus the Q-table has a shape of (16, 4).
        For instance, q_array[0, 3] is the Q-value (estimation of the expected reward) for performing action 3 ("move up") in state 0 (the top left square).
    """
    # Get the number of states and actions in the environment
    num_states = 
    num_actions = 

    # Initialize the Q-table to zeros
    q_array = np.zeros([num_states, num_actions])

    # Loop over the episodes
    for episode_index in range(num_episodes):
        # Display the Q-table every DISPLAY_EVERY_N_EPISODES episodes if display is True
        if display and episode_index % DISPLAY_EVERY_N_EPISODES == 0:
            qtable_display(q_array, title="Q table")
        else:
            print('.', end="")

        # Save the current Q-table and learning rate
        q_array_history_ex4.append(q_array.copy())
        alpha_history_ex4.append(alpha)

        # Decrease the learning rate if alpha_factor is not None
        if alpha_factor is not None:
            alpha = alpha * alpha_factor

        # TODO...
        s = environment.reset()[0]
        done = False
        a = epsilon_greedy_policy(s, q_array, epsilon)
        while done is False:
            s_, r, done, _ , _= environment.step(a)
            a_ = epsilon_greedy_policy(s_, q_array, epsilon)
            q_array[s, a] += alpha * (r + gamma * q_array[s_, a_] - q_array[s, a])
            s = s_
            a = a_        

    # Return the learned Q-table
    return q_array

## Mixed strategy : bitmaps and DQN

In [22]:
import torch
import torch.nn as nn
import torch.optim as optim
from flatland.envs.rail_env import RailEnv
from flatland.envs.distance_map import DistanceMap
from flatland.envs.rail_env_shortest_paths import get_shortest_paths
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.line_generators import sparse_line_generator

# Assuming that these functions are already defined as per the images provided:
# - create_rail_env_bitmap()
# - create_heatmaps()
# - get_agent_action_from_shortest_path()

class DuelingDQN(nn.Module):
    def __init__(self, input_shape, num_actions):
        super(DuelingDQN, self).__init__()
        self.input_shape = input_shape
        self.num_actions = num_actions
        
        # Define network layers here based on the provided network architecture image
        self.conv = nn.Conv2d(in_channels=input_shape[0], out_channels=64, kernel_size=(1, input_shape[2]))
        self.fc_adv = nn.Linear(in_features=self.feature_size(), out_features=512)
        self.fc_val = nn.Linear(in_features=self.feature_size(), out_features=512)
        self.fc_adv2 = nn.Linear(in_features=512, out_features=num_actions)
        self.fc_val2 = nn.Linear(in_features=512, out_features=1)
        
    def forward(self, x):
        x = F.relu(self.conv(x))
        x = x.view(x.size(0), -1)
        
        adv = F.relu(self.fc_adv(x))
        val = F.relu(self.fc_val(x))
        
        adv = self.fc_adv2(adv)
        val = self.fc_val2(val).expand(x.size(0), self.num_actions)
        
        q_values = val + adv - adv.mean(1).unsqueeze(1).expand(x.size(0), self.num_actions)
        return q_values

    def feature_size(self):
        return self.conv(torch.zeros(1, *self.input_shape)).view(1, -1).size(1)

# Define the environment
env = RailEnv(
    width=20,
    height=15,
    rail_generator=SparseRailGen(
        max_num_cities=2,  # Number of cities
        grid_mode=True,
        max_rails_between_cities=2,
        max_rail_pairs_in_city=1,
    ),
    line_generator=SparseLineGen(speed_ratio_map={1.: 1.}
        ),
    number_of_agents=2, 
    obs_builder_object=TreeObsForRailEnv(max_depth=3),
    malfunction_generator=ParamMalfunctionGen(
        MalfunctionParameters(
            malfunction_rate=0.,  # Rate of malfunction
            min_duration=3,  # Minimal duration
            max_duration=20,  # Max duration
        )
    ),
)

# Reset environment to get the initial observation
obs, _ = env.reset()
input_shape = (3, 20, 15)  # Assuming the input shape of the network is (3, 20, 15)
num_actions = env.action_space[0]  # Assuming action space is defined and accessible

# Initialize network and optimizer
model = DuelingDQN(input_shape, num_actions)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training loop
num_episodes = 1000
for episode in range(num_episodes):
    state = env.reset()
    done = False
    while not done:
        # Get action from model
        q_values = DuelingDQN.forward(state)
        action = q_values.max(1)[1].view(1, 1)

        # Step the environment
        next_obs, reward, done, _ = env.step(action.item())
        next_state = get_state(next_obs)  # Convert observations to model input state

        # Compute the transition loss
        loss = compute_loss(state, action, reward, next_state, done)

        # Update model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Move to the next state
        state = next_state

TypeError: forward() missing 1 required positional argument: 'x'

## Q-Learning from bitmaps

**Structure:**

A "rail occupancy bitmap" shows on which rail and in which direction the agent is traveling at every timestep and is obtained :
1. A directed graph representation of the railway network is generated through BFS, each node is a switch and each edge is a rail between two switches.
2. The shortest path for each agent is computed
3. The path is transformed into a bitmap with the timesteps as columns and the rails as rows. The direction is 1 if the agent is traveling the edge from the source node to the destination node or -1 otherwise.

Heatmaps are used to provide information about how the traffic is distributed across the rails over time.
Each agent computes 2 heatmaps, one positive and one negative, both are generated summing the bitmaps of all the other agents.

The general architecture is a Dueling DQN, where the input is a Conv2D layer that processes a concatenation of the agent bitmap, the positive and the negative heatmaps. Then data goes through two separate streams, the value and the advantage to be recombined in the final output Q values.

In [3]:
import torch.nn as nn
import torch.nn.functional as F

'''
Dueling DQN model
'''
def dim_output(input_dim, filter_dim, stride_dim):
    return (input_dim - filter_dim) // stride_dim + 1

class Dueling_DQN(nn.Module):
    def __init__(self, width, height, action_space):
        super(Dueling_DQN, self).__init__()
        self.action_space = action_space
        # input shape (batch_size, in_channels = height/num_rails, width/prediction_depth + 1) 
        # self.conv1 = nn.Conv1d(in_channels=height, out_channels=64, kernel_size=1)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=(1, width))

        # output shape (batch_size, out_channels, conv_width)
        # conv_width = dim_output(input_dim=width, filter_dim=20, stride_dim=1)

        # in_features = conv_width * out_channels (feature maps/number of kernels, arbitrary)
        # after last Conv1d
        self.fc1_adv = nn.Linear(in_features=64 * height, out_features=512) 
        self.fc1_val = nn.Linear(in_features=64 * height, out_features=512)
        self.fc2_adv = nn.Linear(in_features=512, out_features=action_space)
        self.fc2_val = nn.Linear(in_features=512, out_features=1)

        self.relu = nn.ReLU()

    def forward(self, x): # 
        # batch_size = x.size(0)
        x = self.relu(self.conv1(x))
        x = x.view(x.size(0), -1)

        adv = self.relu(self.fc1_adv(x))
        val = self.relu(self.fc1_val(x))

        adv = self.fc2_adv(adv)
        val = self.fc2_val(val).expand(x.size(0), self.action_space)

        x = val + adv - adv.mean(1).unsqueeze(1).expand(x.size(0), self.action_space)
        return x

Create bitmaps

In [4]:
# Preprocessing

class ObsPreprocessor:
    def __init__(self, max_rails, reorder_rails):
        self.max_rails = max_rails
        self.reorder_rails = reorder_rails

    def _fill_padding(self, obs, max_rails):
        """
        
        :param obs: Agent state 
        :param max_rails: Maximum number of rails in environment 
        :return: Observation padded with 0s along first axis (until max_rails)
        
        """
        prediction_depth = obs.shape[1]
        
        pad_agent_obs = np.zeros((max_rails, prediction_depth))
        pad_agent_obs[:obs.shape[0], :obs.shape[1]] = obs
        
        return pad_agent_obs

    # (agents x rails x depth)
    def _get_heatmap(self, handle, bitmaps, max_rails):
        temp_bitmaps = np.copy(bitmaps)
        temp_bitmaps[handle, :, :] = 0
        pos_dir = np.sum(np.where(temp_bitmaps > 0, temp_bitmaps, 0), axis=0)
        neg_dir = np.abs(np.sum(np.where(temp_bitmaps < 0, temp_bitmaps, 0), axis=0))
        
        return pos_dir, neg_dir

    def _swap_rails(self, bitmap, swap):
        bitmap[range(len(swap))] = bitmap[swap]
        bitmap[len(swap):, :] = 0
        return bitmap


    def _reorder_rails(self, bitmap, pos_map, neg_map):
        swap = np.array([], dtype=int)

        ts = 0 
        rail = np.argmax(np.absolute(bitmap[:, ts]))
        # If agent not departed
        if bitmap[rail, ts] == 0:
            ts = 1
            rail = np.argmax(np.absolute(bitmap[:, ts]))
        
        # While the bitmap is not empty
        while bitmap[rail, ts] != 0:
            swap = np.append(swap, rail)
            ts += np.argmax(bitmap[rail, ts:] == 0)
            rail = np.argmax(np.absolute(bitmap[:, ts]))

        if len(swap) > 0:
            bitmap = self._swap_rails(bitmap, swap)
            pos_map = self._swap_rails(pos_map, swap)
            neg_map = self._swap_rails(neg_map, swap)
        
        return bitmap, pos_map, neg_map

    def get_obs(self, handle, bitmap, maps):
        # Select subset of conflicting paths in bitmap
        pos_map, neg_map = self._get_heatmap(handle, maps, self.max_rails)

        if self.reorder_rails:
            bitmap, pos_map, neg_map = self._reorder_rails(bitmap, pos_map, neg_map)

        state = np.concatenate([
            self._fill_padding(bitmap, self.max_rails),
            self._fill_padding(pos_map, self.max_rails),
            self._fill_padding(neg_map, self.max_rails)
        ])
        
        return state # (prediction_depth + 1, max_cas * max_rails)

In [5]:
import collections
from typing import Optional, List, Dict, Tuple, NamedTuple

Waypoint = NamedTuple(
    'Waypoint', [('position', Tuple[int, int]), ('direction', int)])

CardinalNode = \
	NamedTuple('CardinalNode', [('id_node', int), ('cardinal_point', int)])

Now, we implement a class that returns the rails occupancy as a bitmap with rails on y-axis and timesteps on x-axis. Rails are edges and the 1/-1 in the bitmap indicate the direction of the agent on the rail.

In [6]:
"""
        --- timesteps --->
rail 0: 1 1 1       -1-1
rail 1:      1 1
rail 2:         -1-1
.
.
rail n:
"""

'\n        --- timesteps --->\nrail 0: 1 1 1       -1-1\nrail 1:      1 1\nrail 2:         -1-1\n.\n.\nrail n:\n'

In [7]:
from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4_utils import get_new_position, direction_to_point
from flatland.envs.rail_env import RailEnvActions
from flatland.envs.step_utils.states import TrainState

class RailObsForRailEnv(ObservationBuilder):

	def __init__(self, predictor):
		"""
		predictor: class that predicts the path.
		"""
		super(RailObsForRailEnv, self).__init__()
		self.predictor = predictor
		
		self.num_agents = None
		self.num_rails = None # Depends on the map, must be computed in reset()
		self.max_time_steps = self.predictor.max_depth

		# Not all of them are necessary
		self.cell_to_id_node = {} # Map cell position : id_node
		self.id_node_to_cell = {} # Map id_node to cell position
		self.info = {} # Map id_edge : tuple (CardinalNode1, CardinalNode2, edge_length)
		self.id_edge_to_cells = {} # Map id_edge : list of tuples (cell pos, crossing dir) in rail (nodes are not counted)
		self.nodes = set() # Set of node ids
		self.edges = set() # Set of edge ids

		self.recompute_bitmap = True

	def set_env(self, env: Environment):
		super().set_env(env)
		if self.predictor:
			# Use set_env available in PredictionBuilder (parent class)
			self.predictor.set_env(self.env)

	def reset(self):
		self.cell_to_id_node = {}
		self.id_node_to_cell = {}
		self.info = {}
		self.id_edge_to_cells = {}
		self.nodes = set()
		self.edges = set()
		self._map_to_graph()
		self.recompute_bitmap = True

		self.num_agents = len(self.env.agents)

		# Calculate agents timesteps per cell
		self.tpc = dict()
		for a in range(self.num_agents):
			agent_speed = self.env.agents[a].speed_data['speed']
			self.tpc[a] = int(np.reciprocal(agent_speed))
		
	def get_many(self, handles: Optional[List[int]] = None):
		maps = None
		# Compute bitmaps from shortest paths
		if self.recompute_bitmap:
			self.recompute_bitmap = False

			prediction_dict = self.predictor.get()
			self.paths = self.predictor.shortest_paths
			cells_sequence = self.predictor.compute_cells_sequence(prediction_dict)

			maps = np.zeros((self.num_agents, self.num_rails, self.max_time_steps + 1), dtype=int)
			for a in range(self.num_agents):
				maps[a, :, :] = self._bitmap_from_cells_seq(a, cells_sequence[a])

			maps = np.roll(maps, 1)
			maps[:, :, 0] = 0

		return maps

	def get_altmaps(self, handle):
		agent = self.env.agents[handle]
		altpaths, cells_seqs = self.predictor.get_altpaths(handle, self.cell_to_id_node)
		maps = []
		for i in range(len(cells_seqs)):
			bitmap = self._bitmap_from_cells_seq(handle, cells_seqs[i])

			# If agent not departed, add 0 at the beginning
			if agent.status == TrainState.READY_TO_DEPART:
				bitmap[:, -1] = 0
				bitmap = np.roll(bitmap, 1)

			maps.append(bitmap)

		return maps, altpaths

	def get_agent_action(self, handle):
		agent = self.env.agents[handle]
		action = RailEnvActions.DO_NOTHING
		
		if agent.status == TrainState.READY_TO_DEPART:
			action = RailEnvActions.MOVE_FORWARD

		elif agent.status == TrainState.MOVING:
			# This can return None when rails are disconnected or there was an error in the DistanceMap
			if self.paths[handle] is None or len(self.paths[handle]) == 0:  # Railway disrupted
				print('[WARN] AGENT {} RAIL DISRUPTED'.format(handle))
				action = RailEnvActions.STOP_MOVING
			else:
				# Get action
				step = self.paths[handle][0]
				next_action_element = step.next_action_element.action  # Get next_action_element

				# Just to use the correct form/name
				if next_action_element == 1:
					action = RailEnvActions.MOVE_LEFT
				elif next_action_element == 2:
					action = RailEnvActions.MOVE_FORWARD
				elif next_action_element == 3:
					action = RailEnvActions.MOVE_RIGHT
				
				self.paths[handle] = self.paths[handle][1:]

		return action

	def is_before_switch(self, a):
		agent = self.env.agents[a]
		before_switch = False

		if agent.state == TrainState.MOVING :
			if len(self.paths[a]) > 0:
				curr_pos = agent.position
				next_pos = self.paths[a][0].next_action_element.next_position
				curr_rail, _ = self._get_edge_from_cell(curr_pos)
				next_rail, _ = self._get_edge_from_cell(next_pos)
				before_switch = curr_rail != -1 and next_rail == -1
			else:
				# This shouldn't happen, but it may happen
				print('[WARN] agent\'s {} path run out'.format(a))
				# Force path recalc
				before_switch = True

		return before_switch

	def _get_rail_dir(self, a, maps, ts=0):
		rail = np.argmax(np.absolute(maps[a, :, ts]))
		direction = maps[a, rail, ts]
		return rail, direction

	# This should only be used by a train to delay itself
	def _delay(self, a, maps, rail, direction, delay):
		tpc = self.tpc[a]

		old_rail, old_dir = self._get_rail_dir(a, maps)

		maps[a] = np.roll(maps[a], delay)
		# Reset the first bits
		maps[a, :, 0:delay+tpc] = 0
		# Fill the first with the current rail info
		maps[a, old_rail, 0:tpc] = old_dir
		# Add delay to the next rail
		maps[a, rail, tpc:tpc+delay] = direction
		
		return maps

	def _is_cell_occupied(self, a, cell):
		occupied = False

		for other in range(self.env.get_num_agents()):
			if other != a and self.env.agents[other].position == cell:
				occupied = True
				break
		
		return occupied

	def _check_headon_crash(self, a, rail, direction, maps):
		crash = False

		# Check if rail is already occupied to compute new exit time
		last, last_exit = self._last_train_on_rail(a, rail, maps)

		if last_exit > 0:
			# last_exit-1 instead of 0, because in 0 it may be crossing the last
			# cell before the switch
			last_dir = maps[last, rail, last_exit - 1]
			crash = last_dir != direction

		return crash

	def check_crash(self, a, maps, is_before_switch=False):
		crash = False
		agent = self.env.agents[a]

		if agent.status == TrainState.READY_TO_DEPART:
			# init_pos not occupied
			next_pos = agent.initial_position
			crash = self._is_cell_occupied(a, next_pos)

			if not crash:
				# We should skip the first bit that is 0
				rail, direction = self._get_rail_dir(a, maps, ts=1)
				crash = self._check_headon_crash(a, rail, direction, maps)

		elif is_before_switch:
			tpc = self.tpc[a]
			next_rail, next_dir = self._get_rail_dir(a, maps, ts=tpc)
			crash = self._check_headon_crash(a, next_rail, next_dir, maps)

		else: # action_required
			if len(self.paths[a]) > 0:
				next_pos = self.paths[a][0].next_action_element.next_position
				crash = self._is_cell_occupied(a, next_pos)

		return crash

	def update_bitmaps(self, a, maps, is_before_switch=False):
		# Calculate exit time when switching rail
		if is_before_switch:
			tpc = self.tpc[a]
			next_rail, next_dir = self._get_rail_dir(a, maps, ts=tpc)

			# Check if rail is already occupied to compute new exit time
			_, last_exit = self._last_train_on_rail(a, next_rail, maps)
			if last_exit > 0:
				# tpc: skips the first bits that are curr_rail
				curr_exit = np.argmax(maps[a, next_rail, tpc:] == 0)
				# Also consider the last cell of curr_rail
				curr_exit += tpc
				# TODO? something changes if the last id is > or <  ?
				if curr_exit <= last_exit:
					delay = last_exit + tpc - curr_exit
					maps = self._delay(a, maps, next_rail, next_dir, delay)

		maps[a, :, 0] = 0
		maps[a] = np.roll(maps[a], -1)
		return maps
	
	def set_agent_path(self, a, path):
		self.paths[a] = path

	def _last_train_on_rail(self, a, rail, maps):
		"""
		Find train preceding agent 'handle' on rail.
		:param maps: 
		:param rail: 
		:param handle: 
		:return: 
		"""
		last, last_exit = 0, 0 # Final train, its expected exit time

		for other in range(self.env.get_num_agents()):
			if other == a or self.env.agents[other].status == TrainState.READY_TO_DEPART:
				continue

			tpc = self.tpc[other]

			# If agent is already on this rail
			if maps[other, rail, 0] != 0:
				# Add the estimated exit time
				other_exit = np.argmax(maps[other, rail, :] == 0)

				if other_exit > last_exit:
					last, last_exit = other, other_exit

			# We use tpc-1, to skip the first bits of trains that have decided
			# to enter rail, but are still crossing the cell before
			# If an agent has not yet decided in tpc-1 it will be in the old rail 
			elif maps[other, rail, tpc - 1] != 0:
				other_rail, _ = self._get_rail_dir(other, maps)
				other_exit = 0

				# Consider the time to cross the current cell
				if other_rail != rail:
					other_exit = np.argmax(maps[other, other_rail, :] == 0)

				# TODO! CHECK
				# Add the estimated exit time
				other_exit += np.argmax(maps[other, rail, other_exit:] == 0)

				if other_exit > last_exit:
					last, last_exit = other, other_exit

		return last, last_exit

	def _get_trains_on_rails(self, maps, rail, handle):
		trains = []
		for a in range(self.env.get_num_agents()):
			if not (maps[a, rail, 0] == 0 or a == handle):
				expected_exit_time = np.argmax(maps[a, rail, :] == 0) # Takes index/ts of last bit in a row
				trains.append((a, expected_exit_time))
		trains.sort()
		
		return trains

	def _get_edge_from_cell(self, cell):
		"""
		:param cell: Cell for which we want to find the associated rail id.
		:return: A tuple (id rail, dist) where dist is the distance as offset from the beginning of the rail.
		"""
		for edge in self.id_edge_to_cells.keys():
			cells = [cell[0] for cell in self.id_edge_to_cells[edge]] 
			if cell in cells:
				return edge, cells.index(cell)

		return -1, -1  # Node

	def _bitmap_from_cells_seq(self, handle, path) -> np.ndarray:
		"""
		Compute bitmap for agent handle, given a selected path.
		:param handle: 
		:return: 
		"""
		bitmap = np.zeros((self.num_rails, self.max_time_steps + 1), dtype=int)  # Max steps in the future + current ts
		agent = self.env.agents[handle]
		# Truncate path in the future, after reaching target
		target_index = [i for i, pos in enumerate(path) if pos[0] == agent.target[0] and pos[1] == agent.target[1]]
		if len(target_index) != 0:
			target_index = target_index[0]
			path = path[:target_index + 1]

		# Add 0 at first ts - for 'not departed yet'
		rail, _ = self._get_edge_from_cell(path[0])

		# Agent's cardinal node, where it entered the last edge
		agent_entry_node = None
		# Calculate initial edge entry point
		i = 0
		rail, _ = self._get_edge_from_cell(path[i])
		if rail != -1: # If it's on an edge
			initial_rail = rail
			# Search first switch
			while rail != -1:
				i += 1
				rail, _ = self._get_edge_from_cell(path[i])

			src, dst, _ = self.info[initial_rail]
			node_id = self.cell_to_id_node[path[i]]
			# Reversed because we want the switch's cp
			entry_cp = self._reverse_dir(direction_to_point(path[i-1], path[i]))
			# If we reach the dst node
			if (node_id, entry_cp) == dst:
				# We entered from the src node (cross_dir = 1)
				agent_entry_node = src
			# Otherwise the opposite
			elif (node_id, entry_cp) == src: 
				agent_entry_node = dst
		else:
			#Handle the case you call this while on a switch before a rail
			node_id = self.cell_to_id_node[path[i]]
			# Calculate exit direction (that's the entry cp for the next edge)
			cp = direction_to_point(path[0], path[1]) # it's ok
			# Not reversed because it's already relative to a switch
			agent_entry_node = CardinalNode(node_id, cp)

		holes = 0
		# Fill rail occupancy according to predicted position at ts
		for ts in range(0, len(path)):
			cell = path[ts]
			# Find rail associated to cell
			rail, _ = self._get_edge_from_cell(cell)
			# Find crossing direction
			if rail == -1: # Agent is on a switch
				holes += 1
				# Skip duplicated cells (for agents with fractional speed)
				if ts+1 < len(path) and cell != path[ts+1]:
					node_id = self.cell_to_id_node[cell]
					# Calculate exit direction (that's the entry cp for the next edge)
					cp = direction_to_point(cell, path[ts+1])
					# Not reversed because it's already relative to a switch
					agent_entry_node = CardinalNode(node_id, cp)
			else: # Agent is on a rail
				crossing_dir = None
				src, dst, _ = self.info[rail]
				if agent_entry_node == dst:
					crossing_dir = 1
				elif agent_entry_node == src: 
					crossing_dir = -1

				assert crossing_dir != None

				bitmap[rail, ts] = crossing_dir

				if holes > 0:
					bitmap[rail, ts-holes:ts] = crossing_dir
					holes = 0

		assert(holes == 0, "All the cells of the bitmap should be filled")

		temp = np.any(bitmap[:, 1:(len(path)-1)] != 0, axis=0)
		assert(np.all(temp), "Thee agent's bitmap shouldn't have holes ")
		return bitmap

	# Slightly modified wrt to the other
	def _map_to_graph(self):
		"""
		Build the representation of the map as a graph.
		:return: 
		"""
		id_node_counter = 0
		connections = {}
		# targets = [agent.target for agent in self.env.agents]

		# Identify cells hat are nodes (switches or diamond crossings)
		for i in range(self.env.height):
			for j in range(self.env.width):

				is_switch = False
				is_crossing = False
				# is_target = False
				connections_matrix = np.zeros((4, 4))  # Matrix NESW x NESW

				# Check if diamond crossing
				transitions_bit = bin(self.env.rail.get_full_transitions(i, j))
				if int(transitions_bit, 2) == int('1000010000100001', 2):
					is_crossing = True
					connections_matrix[0, 2] = connections_matrix[2, 0] = 1
					connections_matrix[1, 3] = connections_matrix[3, 1] = 1

				else:
					# Check if target
					# if (i, j) in targets:
					#	is_target = True
					# Check if switch
					for direction in (0, 1, 2, 3):  # 0:N, 1:E, 2:S, 3:W
						possible_transitions = self.env.rail.get_transitions(i, j, direction)
						for t in range(4):  # Check groups of bits
							if possible_transitions[t]:
								inv_direction = (direction + 2) % 4
								connections_matrix[inv_direction, t] = connections_matrix[t, inv_direction] = 1
						num_transitions = np.count_nonzero(possible_transitions)
						if num_transitions > 1:
							is_switch = True

				if is_switch or is_crossing: #or is_target:
					# Add node - keep info on cell position
					# Update only for nodes that are switches
					connections.update({id_node_counter: connections_matrix})
					self.id_node_to_cell.update({id_node_counter: (i, j)})
					self.cell_to_id_node.update({(i, j): id_node_counter})
					id_node_counter += 1

		# Enumerate edges from these nodes
		id_edge_counter = 0
		# Start from connections of one node and follow path until next switch is found
		nodes = connections.keys()  # ids
		visited = set()  # Keeps set of CardinalNodes that were already visited
		for n in nodes:
			for cp in range(4):  # Check edges from the 4 cardinal points
				if np.count_nonzero(connections[n][cp, :]) > 0:
					visited.add(CardinalNode(n, cp))  # Add to visited
					cells_sequence = []
					node_found = False
					edge_length = 0
					# Keep going until another node is found
					direction = cp
					pos = self.id_node_to_cell[n]
					while not node_found:
						neighbour_pos = get_new_position(pos, direction)
						cells_sequence.append((neighbour_pos, direction))
						if neighbour_pos in self.cell_to_id_node:  # If neighbour is a node
							# node_found = True
							# Build edge, mark visited
							id_node1 = n
							cp1 = cp
							id_node2 = self.cell_to_id_node[neighbour_pos]
							cp2 = self._reverse_dir(direction)
							if CardinalNode(id_node2, cp2) not in visited:
								self.info.update({id_edge_counter:
									                  (CardinalNode(id_node1, cp1),
									                   CardinalNode(id_node2, cp2),
									                   edge_length)})
								cells_sequence.pop()  # Don't include this node in the edge
								self.id_edge_to_cells.update({id_edge_counter: cells_sequence})
								id_edge_counter += 1
								visited.add(CardinalNode(id_node2, cp2))
							break
						edge_length += 1  # Not considering switches in the count
						# Update pos and dir
						pos = neighbour_pos
						exit_dir = self._reverse_dir(direction)
						possible_transitions = np.array(self.env.rail.get_transitions(pos[0], pos[1], direction))
						possible_transitions[exit_dir] = 0  # Don't consider direction from which I entered
						# t = 2
						t = np.argmax(possible_transitions)  # There's only one possible transition except the one that I took to get in
						temp_pos = get_new_position(pos, t)
						if 0 <= temp_pos[0] < self.env.height and 0 <= temp_pos[1] < self.env.width:  # Patch - check if this cell is a rail
							# Entrance dir is always opposite to exit dir
							direction = t
						else:
							break

		self.nodes = nodes # Set of nodes
		self.edges = self.info.keys() # Set of edges
		self.num_rails = len(self.edges)

	@staticmethod
	def _reverse_dir(direction):
		"""
		Invert direction (int) of one agent.
		:param direction: 
		:return: 
		"""
		return int((direction + 2) % 4)
		pass

  assert(holes == 0, "All the cells of the bitmap should be filled")
  assert(np.all(temp), "Thee agent's bitmap shouldn't have holes ")


Shortest path prediction builder (Romain's one)

In [8]:
env.reset()

def get_shortest_paths(env, vis=False):
    distance_map = DistanceMap(env.agents, env.width, env.height)
    distance_map.reset(env.agents, env.rail)
    distance_map.get()

    # Visualize the distance map
    if vis:
        sp.visualize_distance_map(distance_map, 0)
        sp.visualize_distance_map(distance_map, 1)

    shortest_paths = sp.get_shortest_paths(distance_map)
    for handle in shortest_paths.keys():
        if len(shortest_paths) <= 1:
            shortest_paths[handle] = 2
        elif env.agents[handle].position is None:
            shortest_paths[handle] = 2 # Forward = start moving in the map
        else:
            next_cell = shortest_paths[handle][1] # Next cell to visit
            shortest_paths[handle] = sp.get_action_for_move(env.agents[handle].position, 
                                                        env.agents[handle].direction,
                                                        next_cell.position,
                                                        next_cell.direction,
                                                        env.rail)
    return shortest_paths



In [23]:
class SingleAgentShortest(TreeObsForRailEnv):
    '''Implements shortest path observation for the agents.'''
    def __init__(self):
        super().__init__(max_depth=0)

    def reset(self):
        super().reset()

    def get(self, handle):
        return get_shortest_paths(self.env)[handle]

Final process

In [14]:
# Create the environment

env = RailEnv(
    width=20,
    height=15,
    rail_generator=SparseRailGen(
        max_num_cities=2,  # Number of cities
        grid_mode=True,
        max_rails_between_cities=2,
        max_rail_pairs_in_city=1,
    ),
    line_generator=SparseLineGen(speed_ratio_map={1.: 1.}
        ),
    number_of_agents=2, 
    obs_builder_object=TreeObsForRailEnv(max_depth=3),
    malfunction_generator=ParamMalfunctionGen(
        MalfunctionParameters(
            malfunction_rate=0.,  # Rate of malfunction
            min_duration=3,  # Minimal duration
            max_duration=20,  # Max duration
        )
    ),
)

_,_ = env.reset()

In [34]:
from importlib_resources import path
from flatland.envs.observation_utils import GraphObsForRailEnv

prediction_depth = 40
observation_builder = GraphObsForRailEnv(bfs_depth=4, predictor=ShortestPathPredictorForRailEnv(max_depth=prediction_depth))

state_size = prediction_depth + 5
network_action_size = 2
controller = Agent('fc', state_size, network_action_size)
railenv_action_dict = dict()

    
evaluation_number = 0
while True:
    
    evaluation_number += 1
    time_start = time.time()
    
    obs, info = env.reset()
    if not obs:
        break
    
    print("Test Number : {}".format(evaluation_number))

    local_env = env
    number_of_agents = len(env.agents)
    
    time_taken_by_controller = []
    time_taken_per_step = []
    steps = 0
    # First random action
    for a in range(number_of_agents):
        action = 2
        railenv_action_dict.update({a:action})
    obs, all_rewards, done, info = env.step(railenv_action_dict)

    while True:
        # Evaluation of a single episode
    
        time_start = time.time()
        # Pick actions
        for a in range(number_of_agents):
            if info['action_required'][a]:
                network_action = controller.act(obs[a])
                railenv_action = observation_builder.choose_railenv_action(a, network_action)
            else:
                railenv_action = 0
            railenv_action_dict.update({a: railenv_action})
                
        time_taken = time.time() - time_start
        time_taken_by_controller.append(time_taken)

        time_start = time.time()
        # Perform env step
        obs, all_rewards, done, info = env.step(railenv_action_dict)
        steps += 1
        time_taken = time.time() - time_start
        time_taken_per_step.append(time_taken)
        
        if done['__all__']:
            print("Reward : ", sum(list(all_rewards.values())))

            break

    np_time_taken_by_controller = np.array(time_taken_by_controller)
    np_time_taken_per_step = np.array(time_taken_per_step)
    print("=" * 100)
    print("=" * 100)
    print("Evaluation Number : ", evaluation_number)
    print("Current Env Path : ", env.path)
    print("Env Creation Time : ", env_creation_time)
    print("Number of Steps : ", steps)
    print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(),
          np_time_taken_by_controller.std())
    print("Mean/Std of Time per Step : ", np_time_taken_per_step.mean(), np_time_taken_per_step.std())
    print("=" * 100)

print("Evaluation of all environments complete...")

print(remote_client.submit())

ModuleNotFoundError: No module named 'flatland.envs.observation_utils'

In [33]:
# -*- coding: utf-8 -*-
import sys
from pathlib import Path

from collections import deque

from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import SparseLineGen
from flatland.envs.malfunction_generators import malfunction_from_params
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
from flatland.envs.rail_env import RailEnvActions



def main(args):
	
	# Show options and values
	print(' ' * 26 + 'Options')
	for k, v in vars(args).items():
		print(' ' * 26 + k + ': ' + str(v))
	# Where to save models
	results_dir = os.path.join('results', args.model_id)
	if not os.path.exists(results_dir):
		os.makedirs(results_dir)
	
	rail_generator = sparse_rail_generator(max_num_cities=args.max_num_cities,
	                                       seed=args.seed,
	                                       grid_mode=args.grid_mode,
	                                       max_rails_between_cities=args.max_rails_between_cities,
	                                       max_rails_in_city=args.max_rails_in_city,
	                                       )

	# Maps speeds to % of appearance in the env
	speed_ration_map = {1.: 1}  # Fast passenger train
	
	if args.multi_speed:
		speed_ration_map = {1.: 0.25,  # Fast passenger train
							1. / 2.: 0.25,  # Fast freight train
							1. / 3.: 0.25,  # Slow commuter train
							1. / 4.: 0.25}  # Slow freight train

	schedule_generator = SparseLineGen(speed_ration_map)
	
	prediction_builder = ShortestPathPredictorForRailEnv(max_depth=args.prediction_depth)
	obs_builder = RailObsForRailEnv(predictor=prediction_builder)

	env = RailEnv(width=args.width,
	              height=args.height,
	              rail_generator=rail_generator,
	              random_seed=0,
	              schedule_generator=schedule_generator,
	              number_of_agents=args.num_agents,
	              obs_builder_object=obs_builder,
	              malfunction_generator_and_process_data=malfunction_from_params(
		              parameters={
			              'malfunction_rate': args.malfunction_rate,
			              'min_duration': args.min_duration,
			              'max_duration': args.max_duration
		              })
	              )

	if args.render:
		env_renderer = RenderTool(
			env,
			agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX,
			show_debug=True,
			screen_height=800,
			screen_width=800)

	if args.plot:
		writer = SummaryWriter(log_dir='runs/' + args.model_id)

	max_rails = 100 # TODO Must be a parameter of the env (estimated)
	# max_steps = env.compute_max_episode_steps(env.width, env.height)
	max_steps = 200

	preprocessor = ObsPreprocessor(max_rails, args.reorder_rails)

	dqn = DQNAgent(args, bitmap_height=max_rails * 3, action_space=2)

	if args.load_path:
		file = os.path.isfile(args.load_path)
		if file:
			dqn.qnetwork_local.load_state_dict(torch.load(args.load_path))
			print('WEIGHTS LOADED from: ', args.load_path)

	eps = args.start_eps
	railenv_action_dict = {}
	network_action_dict = {}
	# Metrics
	done_window = deque(maxlen=args.window_size) # Env dones over last window_size episodes
	done_agents_window = deque(maxlen=args.window_size) # Fraction of done agents over last ...
	reward_window = deque(maxlen=args.window_size) # Cumulative rewards over last window_size episodes
	norm_reward_window = deque(maxlen=args.window_size) # Normalized cum. rewards over last window_size episodes
	# Track means over windows of window_size episodes
	mean_dones = [] 
	mean_agent_dones = []
	mean_rewards = []
	mean_norm_rewards = []
	# Episode rewards/dones/norm rewards since beginning of training TODO
	#env_dones = []
	
	crash = [False] * args.num_agents
	update_values = [False] * args.num_agents
	buffer_obs = [[]] * args.num_agents

	############ Main loop
	for ep in range(args.num_episodes):
		cumulative_reward = 0
		env_done = 0
		altmaps = [None] * args.num_agents
		altpaths = [[]] * args.num_agents
		buffer_rew = [0] * args.num_agents
		buffer_done = [False] * args.num_agents
		curr_obs = [None] * args.num_agents

		maps, info = env.reset()
		if args.print:
			debug.print_bitmaps(maps)

		if args.render:
			env_renderer.reset()

		for step in range(max_steps - 1):
			# Save a copy of maps at the beginning
			buffer_maps = maps.copy()
			# rem first bit is 0 for agent not departed
			for a in range(env.get_num_agents()):
				agent = env.agents[a]
				crash[a] = False
				update_values[a] = False
				network_action = None
				action = None

				# If agent is arrived
				if agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED:
					# TODO if agent !removed you should leave a bit in the bitmap
					# TODO? set bitmap only the first time
					maps[a, :, :] = 0
					network_action = 0
					action = RailEnvActions.DO_NOTHING

				# If agent is not departed
				elif agent.status == RailAgentStatus.READY_TO_DEPART:
					update_values[a] = True
					obs = preprocessor.get_obs(a, maps[a], buffer_maps)
					curr_obs[a] = obs.copy()

					# Network chooses action
					q_values = dqn.act(obs).cpu().data.numpy()
					if np.random.random() > eps:
						network_action = np.argmax(q_values)
					else:
						network_action = np.random.choice([0, 1])

					if network_action == 0:
						action = RailEnvActions.DO_NOTHING
					else: # Go
						crash[a] = obs_builder.check_crash(a, maps)
						
						if crash[a]:
							network_action = 0
							action = RailEnvActions.STOP_MOVING
						else:
							maps = obs_builder.update_bitmaps(a, maps)
							action = obs_builder.get_agent_action(a)
		
				# If the agent is entering a switch
				elif obs_builder.is_before_switch(a) and info['action_required'][a]:
					# If the altpaths cache is empty or already contains
					# the altpaths from the current agent's position
					if len(altpaths[a]) == 0 or agent.position != altpaths[a][0][0].position:
						altmaps[a], altpaths[a] = obs_builder.get_altmaps(a)

					if len(altmaps[a]) > 0:
						update_values[a] = True
						altobs = [None] * len(altmaps[a])
						q_values = np.array([])
						for i in range(len(altmaps[a])):
							altobs[i] = preprocessor.get_obs(a, altmaps[a][i], buffer_maps)
							q_values = np.concatenate([q_values, dqn.act(altobs[i]).cpu().data.numpy()])

						# Epsilon-greedy action selection
						if np.random.random() > eps:
							argmax = np.argmax(q_values)
							network_action = argmax % 2
							best_i = argmax // 2
						else:
							network_action = np.random.choice([0, 1])
							best_i = np.random.choice(np.arange(len(altmaps[a])))
						
						# Use new bitmaps and paths
						maps[a, :, :] = altmaps[a][best_i]
						obs_builder.set_agent_path(a, altpaths[a][best_i])
						curr_obs[a] = altobs[best_i].copy()

					else:
						print('[ERROR] NO ALTHPATHS EP: {} STEP: {} AGENT: {}'.format(ep, step, a))
						network_action = 0

					if network_action == 0:
						action = RailEnvActions.STOP_MOVING
					else:
						crash[a] = obs_builder.check_crash(a, maps, is_before_switch=True)
						
						if crash[a]:
							network_action = 0
							action = RailEnvActions.STOP_MOVING
						else:
							action = obs_builder.get_agent_action(a)
							maps = obs_builder.update_bitmaps(a, maps, is_before_switch=True)
		
				# If the agent is following a rail
				elif info['action_required'][a]:
					crash[a] = obs_builder.check_crash(a, maps)

					if crash[a]:
						network_action = 0
						action = RailEnvActions.STOP_MOVING
					else:
						network_action = 1
						action = obs_builder.get_agent_action(a)
						maps = obs_builder.update_bitmaps(a, maps)

				else: # not action_required
					network_action = 1
					action = RailEnvActions.DO_NOTHING
					maps = obs_builder.update_bitmaps(a, maps)
				
				network_action_dict.update({a: network_action})
				railenv_action_dict.update({a: action})

			# Obs is computed from bitmaps while state is computed from env step
			_, reward, done, info = env.step(railenv_action_dict)
			
			# Update replay buffer and train agent
			if args.train:
				for a in range(env.get_num_agents()):
					if args.crash_penalty and crash[a]:
						# Store bad experience
						dqn.step(curr_obs[a], 1, -100, curr_obs[a], True)

					if not args.switch2switch:
						if update_values[a] and not buffer_done[a]:
							next_obs = preprocessor.get_obs(a, maps[a], maps)
							dqn.step(curr_obs[a], network_action_dict[a], reward[a], next_obs, done[a])

					else:
						if update_values[a] and not buffer_done[a]:
							# If I had an obs from a previous switch
							if len(buffer_obs[a]) != 0:
								dqn.step(buffer_obs[a], 1, buffer_rew[a], curr_obs[a], done[a])
								buffer_obs[a] = []
								buffer_rew[a] = 0

							if network_action_dict[a] == 0:
								dqn.step(curr_obs[a], 1, reward[a], curr_obs[a], False)
							elif network_action_dict[a] == 1:
								# I store the obs and update at the next switch
								buffer_obs[a] = curr_obs[a].copy()

						# Cache reward only if we have an obs from a prev switch
						if len(buffer_obs[a]) != 0:
							buffer_rew[a] += reward[a]

					# Now update the done cache to avoid adding experience many times
					buffer_done[a] = done[a]

			for a in range(env.get_num_agents()):	
				cumulative_reward += reward[a] 
			 
			if done['__all__']:
				env_done = 1
				break

		################### End of the episode
		eps = max(args.end_eps, args.eps_decay * eps)  # Decrease epsilon
		# Metrics
		done_window.append(env_done) # Save done in this episode
		
		num_agents_done = 0  # Num of agents that reached their target in the last episode
		for a in range(env.get_num_agents()): 
			if done[a]:
				num_agents_done += 1
		done_agents_window.append(num_agents_done / env.get_num_agents())
		reward_window.append(cumulative_reward)  # Save cumulative reward in this episode
		normalized_reward = cumulative_reward / (env.compute_max_episode_steps(env.width, env.height) + env.get_num_agents())
		norm_reward_window.append(normalized_reward)
		
		mean_dones.append((np.mean(done_window)))
		mean_agent_dones.append((np.mean(done_agents_window)))
		mean_rewards.append(np.mean(reward_window))
		mean_norm_rewards.append(np.mean(norm_reward_window))

		# Print training results info
		print(
			'\r{} Agents on ({},{}). Episode: {}\t Mean done agents: {:.2f}\t Mean reward: {:.2f}\t Mean normalized reward: {:.2f}\t Done agents in last episode: {:.2f}%\t Epsilon: {:.2f}'.format(
				env.get_num_agents(), args.width, args.height,
				ep,
				mean_agent_dones[-1],  # Fraction of done agents
				mean_rewards[-1],
				mean_norm_rewards[-1],
				(num_agents_done / args.num_agents),
				eps), end=" ")

In [None]:
main()