# GNN approach

## Required libraries

In [21]:
# Import relevant libraries

%reload_ext autoreload
%autoreload 2

import numpy as np
import os
import pandas as pd
import wandb
import math
from datetime import datetime
from tabulate import tabulate
from ast import literal_eval
import matplotlib
import matplotlib.pyplot as plt
import copy
from enum import IntEnum
import time
import random
import seaborn as sns
import collections
from collections import namedtuple, deque
from typing import List, Optional, Tuple, Union, Callable, Dict, Sequence, NamedTuple
from pathlib import Path
import itertools
from importlib_resources import path
import yaml
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as gnn
from torch_geometric.data import Data, Batch
from torch.nn.modules.container import ParameterList
import torch.optim as optim

import networkx as nx
import pickle



# Base flatland environment
from flatland.envs.line_generators import SparseLineGen
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.rail_generators import sparse_rail_generator, rail_from_file
from flatland.envs.malfunction_generators import (
    MalfunctionParameters,
    ParamMalfunctionGen,
)

from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.distance_map import DistanceMap
import flatland.envs.rail_env_shortest_paths as sp
from flatland.envs.rail_env import RailEnv, RailEnvActions

from flatland.envs.step_utils.states import TrainState
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.persistence import RailEnvPersister
from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder

from flatland.core.grid.grid4_utils import get_new_position, direction_to_point, MOVEMENT_ARRAY
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.rail_env_grid import RailEnvTransitions

from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.utils.ordered_set import OrderedSet

import sys
sys.path.append('src/')

from src import test_utils, training, rewards
from src.observation_utils import normalize_observation
from src.models import *
from src.deep_model_policy import DeepPolicy, PolicyParameters
from training import train_agent

# Visualization
from flatland.utils.rendertools import RenderTool, AgentRenderVariant


In [2]:
## Utils

def get_linear(input_size, output_size, hidden_sizes, nonlinearity="tanh",
               final_nonlinearity=True):
    '''
    Returns a PyTorch Sequential object containing FC layers with
    non-linear activation functions, by following the given input/hidden/output sizes

    Arguments:
    - input_size: number of input features
    - output_size: number of output features
    - hidden_sizes: iterable (list or tuple, possibly empty) of hidden layer sizes
    - nonlinearity: "relu" selects ReLU, any other value selects Tanh
    - final_nonlinearity: if True (default, original behavior) the activation is
      also applied after the last linear layer; pass False to get an unbounded
      output layer (e.g. for Q-values or regression heads)
    '''
    # `list(...)` so that tuples are accepted too (list + tuple raises TypeError)
    sizes = [input_size] + list(hidden_sizes) + [output_size]
    num_layers = len(sizes) - 1
    fc = []
    for i in range(num_layers):
        fc.append(nn.Linear(sizes[i], sizes[i + 1]))
        if final_nonlinearity or i < num_layers - 1:
            fc.append(
                nn.ReLU(inplace=True) if nonlinearity == "relu" else nn.Tanh()
            )
    return nn.Sequential(*fc)


def conv_bn_act(input_channels, output_channels, kernel_size=3,
                stride=1, padding=0, nonlinearity="relu"):
    '''
    Build the triple [Conv2d, BatchNorm2d, activation] as a plain Python list,
    so that callers can append further layers before wrapping it in a Sequential.
    The activation is an in-place ReLU when `nonlinearity == "relu"`, Tanh otherwise
    '''
    convolution = nn.Conv2d(
        input_channels,
        output_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
    )
    normalization = nn.BatchNorm2d(output_channels)
    if nonlinearity == "relu":
        activation = nn.ReLU(inplace=True)
    else:
        activation = nn.Tanh()
    return [convolution, normalization, activation]


def get_conv(input_channels, output_channels, hidden_channels,
             conv_params, pool_params, nonlinearity="relu"):
    '''
    Chain `conv_bn_act` blocks going from `input_channels` through every entry
    of `hidden_channels` to `output_channels`, and return them in a PyTorch
    Sequential. A MaxPool2d layer follows every second block (2nd, 4th, ...).

    `conv_params` and `pool_params` are (kernel_size, stride, padding) triples
    for the convolutional and pooling layers respectively; `hidden_channels`
    must be a non-empty list
    '''
    assert len(hidden_channels) >= 1

    channels = [input_channels] + hidden_channels + [output_channels]
    conv_kernel_size, conv_stride, conv_padding = conv_params
    pool_kernel_size, pool_stride, pool_padding = pool_params

    layers = []
    channel_pairs = zip(channels[:-1], channels[1:])
    for block_number, (in_ch, out_ch) in enumerate(channel_pairs, start=1):
        layers.extend(
            conv_bn_act(
                in_ch, out_ch,
                kernel_size=conv_kernel_size,
                stride=conv_stride,
                padding=conv_padding,
                nonlinearity=nonlinearity,
            )
        )
        if block_number % 2 == 0:
            layers.append(
                nn.MaxPool2d(
                    kernel_size=pool_kernel_size,
                    stride=pool_stride,
                    padding=pool_padding,
                )
            )
    return nn.Sequential(*layers)


def conv_block_output_size(modules, input_width, input_height):
    '''
    Given a sequence of PyTorch modules (e.g. Python list, PyTorch Sequential/ModuleList)
    containing convolution related layers (currently only Conv2d and MaxPool2d are supported),
    returns the output size of the input tensor, after it passes through all the given layers

    Other module types are ignored. Conv2d string paddings ("same"/"valid") and
    MaxPool2d's `ceil_mode` are taken into account.
    Returns an (output_width, output_height) pair of ints
    '''
    def _output_length(length, kernel_size, stride, padding, dilation, ceil_mode):
        # Output size formula from the PyTorch Conv2d / MaxPool2d documentation
        span = length + 2 * padding - dilation * (kernel_size - 1) - 1
        rounding = math.ceil if ceil_mode else math.floor
        out = rounding(span / stride) + 1
        # With ceil_mode, a window that would start inside the right padding is dropped
        if ceil_mode and (out - 1) * stride >= length + padding:
            out -= 1
        return out

    output_width, output_height = input_width, input_height
    for module in modules:
        if isinstance(module, nn.Conv2d):
            kernel_size, stride, padding, dilation = get_conv2d_params(module)
            ceil_mode = False
            if isinstance(padding, str):
                # "same" keeps the spatial size unchanged (PyTorch requires stride 1)
                if padding == "same":
                    continue
                # "valid" means no padding at all
                padding = (0, 0)
        elif isinstance(module, nn.MaxPool2d):
            kernel_size, stride, padding, dilation = get_maxpool2d_params(module)
            ceil_mode = module.ceil_mode
        else:
            continue

        kernel_size_h, kernel_size_w = kernel_size
        stride_h, stride_w = stride
        padding_h, padding_w = padding
        dilation_h, dilation_w = dilation
        output_width = _output_length(
            output_width, kernel_size_w, stride_w, padding_w, dilation_w, ceil_mode
        )
        output_height = _output_length(
            output_height, kernel_size_h, stride_h, padding_h, dilation_h, ceil_mode
        )
    return int(output_width), int(output_height)


def get_conv2d_params(conv):
    '''
    Extract the (kernel_size, stride, padding, dilation) attributes of a Conv2d
    layer, in that order, exactly as they are stored on the module
    '''
    attribute_names = ("kernel_size", "stride", "padding", "dilation")
    return tuple(getattr(conv, name) for name in attribute_names)


def get_maxpool2d_params(pool):
    '''
    Return kernel size, stride, padding and dilation for a MaxPool2d layer,
    each as an (h, w) pair.

    MaxPool2d stores these attributes either as a single int (applied to both
    spatial dimensions) or as a tuple, so both forms are normalized here
    '''
    def _as_pair(value):
        # Scalars apply to both height and width
        return tuple(value) if isinstance(value, (tuple, list)) else (value, value)

    return (
        _as_pair(pool.kernel_size),
        _as_pair(pool.stride),
        _as_pair(pool.padding),
        _as_pair(pool.dilation),
    )

## Model

In [3]:
## DQN

class DQN(nn.Module):
    '''
    Vanilla DQN: a fully-connected network mapping a flattened state
    to one Q-value per action
    '''

    def __init__(self, state_size, action_size, params, device="cpu"):
        super(DQN, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.params = params
        self.device = device
        # `params.nonlinearity` may be a plain string or a Struct of flags (the other
        # models in this file call `.get_true_key()` on it). Normalize it to a string:
        # a Struct never equals "relu", so Tanh would always be selected otherwise
        nonlinearity = self.params.nonlinearity
        if hasattr(nonlinearity, "get_true_key"):
            nonlinearity = nonlinearity.get_true_key()
        self.fc = get_linear(
            state_size, action_size, self.params.hidden_sizes,
            nonlinearity=nonlinearity
        )

    def forward(self, states, mask=None):
        '''
        Compute Q-values for a batch of states.

        - states: tensor whose first dimension is the batch; the remaining
          dimensions are flattened and must total `state_size`
        - mask: optional tensor with one entry per state (assumed boolean);
          only masked-in rows go through the network, the others get all-zero Q-values

        Returns a (batch, action_size) tensor
        '''
        states = torch.flatten(states, start_dim=1)
        assert len(states.shape) == 2 and states.shape[1] == self.state_size
        if mask is None:
            mask = torch.ones(
                (states.shape[0],), dtype=torch.bool,
                device=self.device
            )
        mask = torch.flatten(mask)
        q_values = self.fc(states[mask, :])
        # Match the network's dtype so e.g. a double-precision model is not downcast
        mask_q = torch.zeros(
            (states.shape[0], self.action_size),
            dtype=q_values.dtype, device=self.device
        )
        mask_q[mask, :] = q_values
        return mask_q


## Dueling DQN 

class DuelingDQN(DQN):
    '''
    Dueling DQN: Q(s, a) = V(s) + A(s, a) - agg_a' A(s, a'),
    where the advantage stream A is the parent DQN network (`self.fc`),
    V is a separate single-output network (`self.fc_val`) and `agg`
    is either the mean or the max over actions
    '''

    def __init__(self, state_size, action_size, params, device="cpu"):
        super(DuelingDQN, self).__init__(
            state_size, action_size, params, device=device
        )
        self.aggregation = self.params.dueling.aggregation.get_true_key()
        # `params.nonlinearity` may be a Struct of flags: normalize it to a string
        nonlinearity = self.params.nonlinearity
        if hasattr(nonlinearity, "get_true_key"):
            nonlinearity = nonlinearity.get_true_key()
        self.fc_val = get_linear(
            state_size, 1, self.params.hidden_sizes,
            nonlinearity=nonlinearity
        )

    def forward(self, states, mask=None):
        '''
        Compute dueling Q-values for a batch of states.
        Rows that are masked out (see `DQN.forward`) get all-zero Q-values.
        Returns a (batch, action_size) tensor
        '''
        states = torch.flatten(states, start_dim=1)
        assert len(states.shape) == 2 and states.shape[1] == self.state_size
        if mask is None:
            mask = torch.ones(
                (states.shape[0],), dtype=torch.bool,
                device=self.device
            )
        mask = torch.flatten(mask)

        # State-value stream: one scalar per state, from the dedicated network
        val = self.fc_val(states[mask, :])
        mask_val = torch.zeros(
            (states.shape[0], 1), dtype=val.dtype, device=self.device
        )
        mask_val[mask, :] = val

        # Advantage stream: the parent DQN network
        mask_adv = super().forward(states, mask=mask)
        adv = mask_adv[mask, :]
        if self.aggregation == "mean":
            agg = adv.mean(dim=1, keepdim=True)
        else:
            # `max` along a dim returns (values, indices): keep only the values
            agg = adv.max(dim=1, keepdim=True).values
        mask_agg = torch.zeros(
            (states.shape[0], 1), dtype=agg.dtype, device=self.device
        )
        mask_agg[mask, :] = agg

        # (batch, 1) + (batch, actions) - (batch, 1) broadcasts to (batch, actions)
        return mask_val + mask_adv - mask_agg


## Entire graph GNN

class EntireGNN(nn.Module):
    '''
    GNN that applies `depth` GCN convolutions to each graph of a batch and,
    for every graph, gathers the last-layer (pre-activation) embeddings of the
    nodes listed in `graph.pos` into a flat vector of size `pos_size * embedding_size`
    '''

    def __init__(self, state_size, depth, params, device="cpu"):
        super(EntireGNN, self).__init__()
        if depth < 1:
            raise ValueError("EntireGNN requires depth >= 1")
        self.state_size = state_size
        self.depth = depth
        self.params = params
        self.device = device

        self.embedding_size = self.params.embedding_size
        self.hidden_size = self.params.hidden_size
        self.pos_size = self.params.pos_size
        self.dropout = self.params.dropout
        self.nonlinearity = self.params.nonlinearity.get_true_key()

        # Not in-place: `forward` keeps a reference to each layer's pre-activation
        # output, which an in-place ReLU would silently overwrite
        self.nl = (
            nn.ReLU() if self.nonlinearity == "relu" else nn.Tanh()
        )
        self.gnn_conv = nn.ModuleList()
        # `depth` convolutions require `depth + 1` feature sizes
        sizes = (
            [state_size] +
            [self.hidden_size] * (self.depth - 1) +
            [self.embedding_size]
        )
        for i in range(1, len(sizes)):
            self.gnn_conv.append(
                gnn.GCNConv(sizes[i - 1], sizes[i])
            )

    def forward(self, states, **kwargs):
        '''
        - states: a batch exposing `to_data_list()`; each graph is expected to
          carry `x`, `edge_index`, `edge_weight` and `pos` (node indices, with -1
          marking an empty slot) — TODO confirm against the observation builder

        Returns a (num_graphs, pos_size * embedding_size) float tensor; slots
        whose `pos` entry is -1 are filled with `-depth`
        '''
        graphs = states.to_data_list()
        embs = torch.zeros(
            size=(
                len(graphs),
                self.pos_size * self.embedding_size
            ), dtype=torch.float,
            device=self.device
        )

        # For each graph in the batch
        for i, graph in enumerate(graphs):
            x, edge_index, edge_weight, pos = (
                graph.x, graph.edge_index, graph.edge_weight, graph.pos
            )

            # Perform `depth` graph convolutions, keeping the output of the
            # last one before its activation and dropout
            emb = x
            for conv in self.gnn_conv:
                x = conv(x, edge_index, edge_weight=edge_weight)
                emb = x
                x = self.nl(x)
                x = F.dropout(x, p=self.dropout, training=self.training)

            # Extract useful embeddings (torch.full needs a scalar fill value)
            tmp_embs = torch.full(
                (self.pos_size, self.embedding_size),
                float(-self.depth),
                dtype=torch.float,
                device=self.device
            )
            for j, p in enumerate(pos):
                if p != -1:
                    tmp_embs[j] = emb[p.item()]
            embs[i] = torch.flatten(tmp_embs)

        return embs


## Multi agent GNN

class MultiGNN(nn.Module):
    '''
    Multi-agent model: each agent's field-of-view observation is encoded by a
    CNN, compressed by an MLP, and the resulting features are exchanged between
    agents through a stack of GAT layers. `forward` returns the output of the last
    GAT layer before its activation
    '''

    def __init__(self, input_width, input_height, input_channels, params, device="cpu"):
        super(MultiGNN, self).__init__()
        self.input_width = input_width
        self.input_height = input_height
        self.input_channels = input_channels
        self.params = params
        self.device = device

        self.output_channels = self.params.cnn_encoder.output_channels
        self.hidden_channels = self.params.cnn_encoder.hidden_channels
        self.mlp_output_size = self.params.mlp_compression.output_size
        self.mlp_hidden_sizes = self.params.mlp_compression.hidden_sizes
        self.gnn_hidden_sizes = self.params.gnn_communication.hidden_sizes
        self.embedding_size = self.params.gnn_communication.embedding_size
        self.dropout = self.params.gnn_communication.dropout
        self.nonlinearity = self.params.nonlinearity.get_true_key()
        # Not in-place: `forward` returns a reference to the last GAT output, which
        # an in-place ReLU would overwrite (making it post-activation for relu only)
        self.nl = (
            nn.ReLU() if self.nonlinearity == "relu" else nn.Tanh()
        )

        # Encoder
        conv_settings, pool_settings = self.params.cnn_encoder.conv, self.params.cnn_encoder.pool
        conv_params = conv_settings.kernel_size, conv_settings.stride, conv_settings.padding
        pool_params = pool_settings.kernel_size, pool_settings.stride, pool_settings.padding
        self.convs = get_conv(
            self.input_channels, self.output_channels, self.hidden_channels,
            conv_params, pool_params, nonlinearity=self.nonlinearity
        )

        # MLP
        output_width, output_height = conv_block_output_size(
            self.convs, self.input_width, self.input_height
        )
        assert output_width > 0 and output_height > 0
        self.mlp = get_linear(
            output_width * output_height * self.output_channels,
            self.mlp_output_size, self.mlp_hidden_sizes, nonlinearity=self.nonlinearity
        )

        # GNN
        self.gnn_conv = nn.ModuleList()
        sizes = (
            [self.mlp_output_size] +
            list(self.gnn_hidden_sizes) +
            [self.embedding_size]
        )
        for i in range(1, len(sizes)):
            self.gnn_conv.append(
                gnn.GATConv(
                    sizes[i - 1], sizes[i], add_self_loops=False,
                    heads=2, concat=False
                )
            )

    def forward(self, states, **kwargs):
        '''
        - states: a batch object exposing `states` (per-agent FOV tensors fed to
          the CNN) and `edge_index` (agent communication graph)

        Returns one embedding per agent, taken from the last GAT layer before
        its activation and dropout
        '''
        # Encode the FOV observation of each agent with the convolutional encoder
        encoded = self.convs(states.states)
        # Use an MLP from the encoded values to have a consistent number of features
        flattened = torch.flatten(encoded, start_dim=1)
        features = self.mlp(flattened)
        embeddings = None
        for conv in self.gnn_conv:
            features = conv(features, states.edge_index)
            embeddings = features
            features = self.nl(features)
            features = F.dropout(
                features, p=self.dropout, training=self.training
            )

        return embeddings

## Predictors

In [4]:
# Path prediction record. The misspelled `lenght` field name is kept on purpose:
# the predictors below construct and read it by that name
Prediction = namedtuple("Prediction", "lenght path edges positions")


def _empty_prediction():
    '''
    Build a placeholder Prediction standing for "no path": infinite length
    and freshly allocated empty path/edges/positions lists
    '''
    nothing = dict(path=[], edges=[], positions=[])
    return Prediction(lenght=np.inf, **nothing)


class NullPredictor(PredictionBuilder):
    '''
    Predictor that never predicts anything: every query yields None
    '''

    def __init__(self, max_depth=None):
        super().__init__(max_depth)

    def set_env(self, env):
        super().set_env(env)

    def get_many(self):
        '''
        Map the handle of every agent in the environment to a None prediction
        '''
        handles = (agent.handle for agent in self.env.agents)
        return dict.fromkeys(handles)

    def get(self, handle):
        '''
        Return None, whatever agent is requested
        '''
        return None


class ShortestDeviationPathPredictor(PredictionBuilder):
    '''
    Predictor that returns, for each agent, its current shortest path on the
    environment's `railway_encoding` graph plus, for up to `max_deviations`
    edges of that path, the first alternative ("deviation") path.
    Predicted paths are truncated to `max_depth` nodes
    '''

    def __init__(self, max_depth, max_deviations):
        super().__init__(max_depth)
        self.max_deviations = max_deviations

    def set_env(self, env):
        super().set_env(env)

    @staticmethod
    def _is_done(agent):
        '''
        Return True if the agent has reached its target.
        NOTE(review): assumes flatland >= 3 agents expose their TrainState as
        `agent.state`; `agent.status` is only used as a fallback — confirm
        against the installed flatland version
        '''
        state = getattr(agent, "state", None)
        if state is None:
            state = agent.status
        return state == TrainState.DONE

    @staticmethod
    def _is_malfunctioning(agent):
        '''
        Return True if the agent is currently broken down.
        NOTE(review): assumes flatland >= 3 agents track malfunctions in
        `agent.malfunction_handler`; the legacy `agent.malfunction_data` dict
        is used as a fallback — confirm against the installed flatland version
        '''
        handler = getattr(agent, "malfunction_handler", None)
        if handler is not None:
            return handler.malfunction_down_counter != 0
        return agent.malfunction_data["malfunction"] != 0

    def reset(self):
        '''
        Initialize shortest paths for each agent
        '''
        self._shortest_paths = dict()
        for agent in self.env.agents:
            self._shortest_paths[agent.handle] = self.env.railway_encoding.shortest_paths(
                agent.handle
            )

    def get_shortest_path(self, handle):
        '''
        Keep a list of shortest paths for the given agent.
        At each time step, update the already computed paths and delete the ones
        which cannot be followed anymore.
        The returned shortest paths have the agent's position as the first element.
        Returns a (lenght, path) pair; when no path exists, (np.inf, [position, node])
        '''
        position = self.env.railway_encoding.get_agent_cell(handle)
        node, _ = self.env.railway_encoding.next_node(position)
        chosen_path = None
        paths_to_delete = []
        for i, shortest_path in enumerate(self._shortest_paths[handle]):
            lenght, path = shortest_path
            # Delete divergent path
            if node != path[0] and node != path[1]:
                # Prepend, so that deletion below goes from the highest index down
                paths_to_delete = [i] + paths_to_delete
                continue

            # Update agent position
            if path[0] != position:
                lenght -= 1
            path[0] = position

            # If the agent is on a packed graph node, drop it
            if path[0] == path[1]:
                path = path[1:]

            # Agent arrived to target
            if lenght == 0:
                chosen_path = lenght, path
                break

            # Select this path if no other path has been previously selected
            if chosen_path is None:
                chosen_path = lenght, path

            # Update shortest path
            self._shortest_paths[handle][i] = lenght, path

        # Delete divergent paths
        for i in paths_to_delete:
            del self._shortest_paths[handle][i]

        # Compute shortest paths, if no path is already available
        if chosen_path is None:
            self._shortest_paths[handle] = self.env.railway_encoding.shortest_paths(
                handle
            )
            if not self._shortest_paths[handle]:
                if position == node:
                    node = self.env.railway_encoding.get_successors(node)[0]
                return np.inf, [position, node]

            chosen_path = self._shortest_paths[handle][0]

        return chosen_path

    def get_deviation_paths(self, handle, lenght, path):
        '''
        Return one deviation path for at most `max_deviations` nodes in the given path
        and limit the computed path lengths by `max_depth`.

        The result is padded with empty predictions so that it always contains
        (at least) `max_deviations` entries, whatever the length of `path`
        '''
        deviation_paths = []
        if lenght < np.inf:
            start = 0
            depth = min(self.max_deviations, len(path) - 1)
            source, _ = self.env.railway_encoding.next_node(path[0])
            if source != path[0]:
                # The agent is not on a graph node: no deviation from its current cell
                start = 1
                deviation_paths.append(_empty_prediction())
            for i in range(start, depth):
                paths = self.env.railway_encoding.deviation_paths(
                    handle, path[i], path[i + 1]
                )
                if len(paths) > 0:
                    deviation_lenght, deviation_path = paths[0][0], paths[0][1]
                    truncated = deviation_path[:self.max_depth]
                    deviation_paths.append(
                        Prediction(
                            lenght=deviation_lenght,
                            path=truncated,
                            edges=self.env.railway_encoding.edges_from_path(truncated),
                            positions=self.env.railway_encoding.positions_from_path(truncated)
                        )
                    )
                else:
                    deviation_paths.append(_empty_prediction())

        # Pad based on what was actually produced: padding by `len(path)` left
        # short paths one entry shy of `max_deviations`
        deviation_paths.extend(
            [_empty_prediction()] * (self.max_deviations - len(deviation_paths))
        )
        return deviation_paths

    def get_many(self):
        '''
        Build the prediction for every agent (None for malfunctioning agents)
        '''
        prediction_dict = {}
        for agent in self.env.agents:
            prediction_dict[agent.handle] = (
                None if self._is_malfunctioning(agent) else self.get(agent.handle)
            )
        return prediction_dict

    def get(self, handle):
        '''
        Build the prediction for the given agent: a pair
        (shortest path Prediction, list of deviation Predictions),
        or None if the agent is done
        '''
        agent = self.env.agents[handle]
        if self._is_done(agent):
            return None

        # Build predictions
        lenght, path = self.get_shortest_path(handle)
        edges = self.env.railway_encoding.edges_from_path(
            path[:self.max_depth]
        )
        pos = self.env.railway_encoding.positions_from_path(
            path[:self.max_depth]
        )
        shortest_path_prediction = Prediction(
            lenght=lenght, path=path[:self.max_depth], edges=edges, positions=pos
        )
        deviation_paths_prediction = self.get_deviation_paths(
            handle, lenght, path
        )

        # Update GUI
        visited = OrderedSet()
        visited.update(shortest_path_prediction.positions)
        self.env.dev_pred_dict[handle] = visited

        return (shortest_path_prediction, deviation_paths_prediction)

## Utils

In [5]:
from timeit import default_timer

def get_index(arr, elem):
    '''
    Position of the first occurrence of `elem` in `arr`,
    or None when `elem` does not occur in it
    '''
    if elem not in arr:
        return None
    return arr.index(elem)


def is_close(a, b, rtol=1e-03):
    '''
    Return True if `a` and `b` differ by at most `rtol`.

    Note: despite its name, `rtol` is an ABSOLUTE tolerance (|a - b| <= rtol).
    Unlike `math.isclose`'s `rel_tol`, it is not scaled by the magnitude of the
    inputs. The parameter name is kept for compatibility with existing callers
    '''
    return abs(a - b) <= rtol


def reciprocal_sum(a, b):
    '''
    Return 1/a + 1/b
    '''
    inverse_a = 1 / a
    inverse_b = 1 / b
    return inverse_a + inverse_b


def min_max_scaling(values, lower, upper, under, over, known_min=None, known_max=None):
    '''
    Perform min-max scaling over the given array
    (`under` is substituted for -np.inf and `over` for np.inf)

    The bounds are `known_min` / `known_max` when given, otherwise the min / max
    of the finite entries. Special cases:
    - a bound is missing and there are no finite values: no scaling, only the
      infinity substitution is applied
    - min == max != 0: values are divided by min
    - min == max == 0: every value is set to `under`

    A new array is always returned. The input is never modified in place
    (array-likes such as lists are accepted too)
    '''
    values = np.array(values, copy=True)
    finite_values = values[np.isfinite(values)]
    has_finite = finite_values.size > 0

    min_value = known_min
    if min_value is None and has_finite:
        min_value = finite_values.min()
    max_value = known_max
    if max_value is None and has_finite:
        max_value = finite_values.max()

    # Without both bounds there is nothing to scale against
    if min_value is not None and max_value is not None:
        if min_value != max_value:
            values = lower + (
                ((values - min_value) * (upper - lower)) /
                (max_value - min_value)
            )
        elif min_value != 0:
            values = values / min_value
        else:
            values[:] = under

    values[values == -np.inf] = under
    values[values == np.inf] = over
    return values


def extract_fov(matrix, center_index, window_size, pad=0):
    '''
    Extract a (window_size, window_size) patch from the given 2D matrix centered
    around `center_index` (row, col), filling cells that fall outside the
    matrix with `pad`.

    For odd sizes the window spans `window_size // 2` cells on each side of the
    center. For even sizes it spans `window_size // 2` cells before the center
    and one fewer after it, so the result always has the requested shape.
    A window lying completely outside the matrix is returned fully padded
    '''
    m, n = matrix.shape
    offset = window_size // 2
    # Half-open bounds [low, high) of the window in matrix coordinates
    y_low = center_index[0] - offset
    x_low = center_index[1] - offset
    y_high = y_low + window_size
    x_high = x_low + window_size

    # Window is entirely contained in the given matrix
    if y_low >= 0 and x_low >= 0 and y_high <= m and x_high <= n:
        return np.array(matrix[y_low:y_high, x_low:x_high], dtype=matrix.dtype)

    # Window has to be padded: copy the overlapping region, if any
    window = np.full((window_size, window_size), pad, dtype=matrix.dtype)
    src_y_low, src_y_high = max(y_low, 0), min(y_high, m)
    src_x_low, src_x_high = max(x_low, 0), min(x_high, n)
    if src_y_low < src_y_high and src_x_low < src_x_high:
        window[
            src_y_low - y_low:src_y_high - y_low,
            src_x_low - x_low:src_x_high - x_low
        ] = matrix[src_y_low:src_y_high, src_x_low:src_x_high]
    return window


def fix_random(seed):
    '''
    Seed Python's `random`, NumPy and PyTorch (CPU and current CUDA device)
    with `seed`, and put cuDNN in deterministic, non-benchmarking mode
    '''
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def set_num_threads(num_threads):
    '''
    Cap the number of threads PyTorch uses, and export the same cap to
    OpenMP and MKL through their environment variables
    '''
    torch.set_num_threads(num_threads)
    cap = str(num_threads)
    for variable in ("OMP_NUM_THREADS", "MKL_NUM_THREADS"):
        os.environ[variable] = cap


class Timer():
    '''
    Utility to measure times: accumulates the wall-clock duration
    (timeit.default_timer, in seconds) of every start()/end() interval
    '''

    def __init__(self):
        self.total_time = 0.0
        self.start_time = 0.0
        self.end_time = 0.0

    def start(self):
        '''
        Mark the beginning of a new interval
        '''
        self.start_time = default_timer()

    def end(self):
        '''
        Close the current interval and add its duration to the total
        '''
        self.total_time += default_timer() - self.start_time

    def get(self):
        '''
        Return the accumulated time of all closed intervals
        '''
        return self.total_time

    def get_current(self):
        '''
        Return the time elapsed since the last call to start()
        '''
        return default_timer() - self.start_time

    def reset(self):
        '''
        Zero all counters
        '''
        self.__init__()

    def __repr__(self):
        # __repr__ must return a str: returning the float made repr(), str()
        # and f-string formatting of a Timer raise TypeError
        return str(self.get())


class Struct:
    '''
    Attribute-style view over a nested dictionary: each key becomes an
    attribute, and dictionary values are recursively wrapped in Struct
    '''

    def __init__(self, **entries):
        for key, value in entries.items():
            wrapped = Struct(**value) if isinstance(value, dict) else value
            self.__dict__[key] = wrapped

    def get_true_key(self):
        '''
        Name of the single attribute whose value equals True
        (asserts that exactly one such attribute exists)
        '''
        flagged = [name for name, flag in vars(self).items() if flag == True]  # noqa: E712
        assert len(flagged) == 1
        return flagged[0]

    def __str__(self):
        return repr(vars(self))

    __repr__ = __str__

## Deadlock detector

In [6]:
class DeadlocksDetector:
    '''
    Class containing code to track deadlocks during an episode,
    based on https://github.com/AlessandroLombardi/FlatlandChallenge
    '''

    def __init__(self):
        self.deadlocks = dict()
        self.deadlock_turns = dict()

    def reset(self, num_agents):
        '''
        Reset deadlock counters: no agent is deadlocked and no deadlock step is recorded
        '''
        self.deadlocks = {a: False for a in range(num_agents)}
        self.deadlock_turns = {a: None for a in range(num_agents)}

    @staticmethod
    def _is_on_map(agent):
        '''
        Return True if the agent is currently placed on the grid
        (moving, stopped or malfunctioning).
        NOTE(review): flatland 3's `TrainState` defines no `ACTIVE` member (the legacy
        `RailAgentStatus.ACTIVE` meant "on the map"), so the on-map states are
        listed explicitly — confirm against the installed flatland version
        '''
        return agent.state in (
            TrainState.MOVING, TrainState.STOPPED, TrainState.MALFUNCTION
        )

    def step(self, env):
        '''
        Check for new deadlocks among agents on the map and update the counters.
        Returns the pair (deadlocks, deadlock_turns): per-agent deadlock flags and
        the step (`env._elapsed_steps - 1`) at which each deadlock was first recorded
        '''
        agents = []
        for a in range(env.get_num_agents()):
            if self._is_on_map(env.agents[a]):
                agents.append(a)
                if not self.deadlocks[a]:
                    self.deadlocks[a] = self._check_deadlocks(
                        agents, self.deadlocks, env
                    )
                if not self.deadlocks[a]:
                    del agents[-1]
                elif self.deadlock_turns[a] is None:
                    self.deadlock_turns[a] = env._elapsed_steps - 1
            else:
                self.deadlocks[a] = False

        return self.deadlocks, self.deadlock_turns

    def _check_feasible_transitions(self, pos, env):
        '''
        Function used to collect chains of blocked agents.
        `pos` is (row, col, direction). Returns None if at least one cell reachable
        from `pos` is free (or no transition exists); otherwise returns the handle of
        the last agent found occupying a reachable cell
        '''
        transitions = env.rail.get_transitions(*pos)
        n_transitions = 0
        occupied = 0
        agent_in_path = None
        for direction, _ in enumerate(MOVEMENT_ARRAY):
            if transitions[direction] == 1:
                n_transitions += 1
                new_position = get_new_position(pos, direction)
                for agent in range(env.get_num_agents()):
                    if env.agents[agent].position == new_position:
                        occupied += 1
                        agent_in_path = agent
        if n_transitions > occupied:
            return None
        return agent_in_path

    def _check_next_pos(self, agent, env):
        '''
        Check the next pos and the possible transitions of an agent to find deadlocks
        '''
        pos = (*env.agents[agent].position, env.agents[agent].direction)
        return self._check_feasible_transitions(pos, env)

    def _check_deadlocks(self, agents, deadlocks, env):
        '''
        Recursive procedure to find out whether agents in `agents` are in a deadlock.
        Follows the chain of blocking agents starting from `agents[-1]`, marking
        agents along the chain as deadlocked in `deadlocks` when appropriate
        '''
        other_agent = self._check_next_pos(agents[-1], env)

        # No agents in front
        if other_agent is None:
            return False

        # Deadlocked agent in front or loop chain found
        if deadlocks[other_agent] or other_agent in agents:
            return True

        # Investigate further
        agents.append(other_agent)
        deadlocks[other_agent] = self._check_deadlocks(agents, deadlocks, env)

        # If the agent `other_agent` is in deadlock
        # also the last one in `agents` is
        if deadlocks[other_agent]:
            return True

        # Back to previous recursive call
        del agents[-1]
        return False

## RailEnvWrapper

## Environment wrapper, graph encoding of the railway map, and observators

In [66]:
def get_num_actions():
    '''
    Count the names defined by the RailEnvActions enum
    (every entry of `__members__`, aliases included)
    '''
    return len(RailEnvActions.__members__)
    
class RailEnvWrapper(RailEnv):
    '''
    Railway environment wrapper, to handle custom logic: deadlock detection,
    reward shaping, the cell orientation graph encoding of the railway and
    the translation between policy choices and environment actions
    '''

    def __init__(self, params, *args, normalize=True, **kwargs):
        '''
        Parameters:
        - params: configuration object; this class reads
          `params.env.rewards.stop_penalty`,
          `params.policy.type.decentralized_fov` and
          `params.observator.tree.radius`
        - normalize: whether `_normalize_obs` should normalize observations
        - args, kwargs: forwarded unchanged to `RailEnv`
        '''
        super().__init__(*args, **kwargs)
        self.params = params
        self.railway_encoding = None
        self.normalize = normalize
        self.state_size = self._get_state_size()
        self.deadlocks_detector = DeadlocksDetector()
        self.partial_rewards = dict()
        # `arrived_turns` and `stop_actions` become per-agent lists on `reset`
        self.arrived_turns = dict()
        self.stop_actions = dict()
        self.current_info = dict()
        self.num_actions = get_num_actions()

    def _get_state_size(self):
        '''
        Compute the state size based on observation type: the number of
        features per node times the number of nodes of the observation tree
        (a single node for other observation builders)
        '''
        n_features_per_node = self.obs_builder.observation_dim
        n_nodes = 1
        if isinstance(self.obs_builder, TreeObsForRailEnv):
            n_nodes = sum(
                4 ** i for i in range(self.obs_builder.max_depth + 1)
            )
        elif isinstance(self.obs_builder, BinaryTreeObservator):
            n_nodes = sum(2 ** i for i in range(self.obs_builder.max_depth))
        return n_features_per_node * n_nodes

    def get_agents_same_start(self):
        '''
        Return a dictionary indexed by agents starting positions,
        and having a list of handles as values, s.t. agents with
        the same starting position are ordered by decreasing speed.
        Only positions shared by at least two agents are included
        '''
        # Group handles by starting position in a single pass,
        # instead of comparing every pair of agents
        handles_by_start = dict()
        for handle, agent in enumerate(self.agents):
            handles_by_start.setdefault(agent.initial_position, []).append(handle)

        return {
            position: sorted(
                handles, reverse=True,
                key=lambda h: self.agents[h].speed_data['speed']
            )
            for position, handles in handles_by_start.items()
            if len(handles) > 1
        }

    def check_if_all_blocked(self, deadlocks):
        '''
        Checks whether all the agents are blocked (full deadlock situation).

        `deadlocks` maps agent handles to deadlock flags; only agents still
        reported as remaining by the railway encoding are considered
        '''
        remaining_agents = self.railway_encoding.remaining_agents_handles()
        num_deadlocks = sum(
            int(v) for k, v in deadlocks.items()
            if k in remaining_agents
        )
        return num_deadlocks == len(remaining_agents) and len(remaining_agents) > 0

    def save(self, path):
        '''
        Save the given RailEnv environment as pickle, in the directory `path`,
        with a file name built from width, height and random seed
        '''
        filename = os.path.join(
            path, f"{self.width}x{self.height}-{self.random_seed}.pkl"
        )
        RailEnvPersister.save(self, filename)

    def get_renderer(self):
        '''
        Return a renderer for the current environment
        '''
        # Imported locally so that this method does not rely on a
        # module-level import of the rendering utilities
        from flatland.utils.rendertools import RenderTool, AgentRenderVariant

        return RenderTool(
            self,
            agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX,
            show_debug=True,
            screen_height=1080,
            screen_width=1920
        )

    def reset(self, regenerate_rail=True, regenerate_schedule=True,
              activate_agents=False, random_seed=None):
        '''
        Reset the environment

        Parameters:
        - regenerate_rail: whether to generate a new rail (a rail is always
          generated when none exists yet)
        - regenerate_schedule: accepted for signature compatibility,
          not used by this method
        - activate_agents: whether to activate every agent right away
        - random_seed: if given (0 included), re-seed the environment

        Returns a tuple `(observations, info)`
        '''
        # Re-seed when a seed is given; compare against None so that
        # a seed of 0 is honoured as well
        if random_seed is not None:
            self._seed(random_seed)

        # Regenerate the rail, if necessary
        optionals = {}
        if regenerate_rail or self.rail is None:
            rail, optionals = self._generate_rail()
            self.rail = rail
            self.height, self.width = self.rail.grid.shape
            self.obs_builder.set_env(self)

        # Set the distance map
        if optionals and 'distance_map' in optionals:
            self.distance_map.set(optionals['distance_map'])

        # Reset agents positions
        self.agent_positions = np.full(
            (self.height, self.width), -1, dtype=int
        )
        self.reset_agents()
        for i, agent in enumerate(self.agents):
            if activate_agents:
                self.set_agent_active(agent)
            self._break_agent(agent)
            if agent.malfunction_data["malfunction"] > 0:
                agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.DO_NOTHING
            self._fix_agent_after_malfunction(agent)

            # Reset partial rewards
            self.partial_rewards[i] = 0.0

        # Reset common variables
        self.num_resets += 1
        self._elapsed_steps = 0
        self.dones = dict.fromkeys(
            list(range(self.get_num_agents())) + ["__all__"], False
        )
        self.arrived_turns = [None] * self.get_num_agents()
        self.stop_actions = [0] * self.get_num_agents()

        # Build the cell orientation graph
        self.railway_encoding = CellOrientationGraph(
            grid=self.rail.grid, agents=self.agents
        )

        # Reset the state of the observation builder with the new environment
        self.obs_builder.reset()
        self.distance_map.reset(self.agents, self.rail)

        # Empty the episode store of agent positions
        self.cur_episode = []

        # Compute deadlocks
        self.deadlocks_detector.reset(self.get_num_agents())

        # Build the info dict
        self.current_info = {
            'action_required': {}, 'malfunction': {}, 'speed': {},
            'status': {}, 'deadlocks': {}, 'deadlock_turns': {}, 'finished': {},
            'first_time_deadlock': {}, 'first_time_finished': {}
        }
        for i, agent in enumerate(self.agents):
            self.current_info['action_required'][i] = self.action_required(
                agent
            )
            self.current_info['malfunction'][i] = agent.malfunction_data['malfunction']
            self.current_info['speed'][i] = agent.speed_data['speed']
            self.current_info['status'][i] = agent.status
            self.current_info["deadlocks"][i] = self.deadlocks_detector.deadlocks[i]
            self.current_info["deadlock_turns"][i] = self.deadlocks_detector.deadlock_turns[i]
            self.current_info["finished"][i] = self.dones[i] or self.deadlocks_detector.deadlocks[i]
            self.current_info["first_time_deadlock"][i] = (
                self.deadlocks_detector.deadlocks[i] and
                0 == self.deadlocks_detector.deadlock_turns[i]
            )
            self.current_info["first_time_finished"][i] = (
                self.dones[i] and
                0 == self.arrived_turns[i]
            )

        # Return the new observation vectors for each agent
        observation_dict = self._get_observations()
        return (self._normalize_obs(observation_dict), self.current_info)

    def _generate_rail(self):
        '''
        Regenerate the rail, if necessary, by invoking the rail generator
        either as a callable or through its `generate` method.

        Raises ValueError if the generator supports neither
        '''
        if "__call__" in dir(self.rail_generator):
            return self.rail_generator(
                self.width, self.height, self.number_of_agents, self.num_resets, self.np_random
            )
        elif "generate" in dir(self.rail_generator):
            return self.rail_generator.generate(
                self.width, self.height, self.number_of_agents, self.num_resets, self.np_random
            )
        raise ValueError(
            "Could not invoke __call__ or generate on rail_generator"
        )

    def step(self, action_dict_):
        '''
        Perform a step in the environment

        Returns a tuple
        `(observations, rewards, custom_rewards, dones, info)`, where `info`
        is extended with deadlock data and the `finished`,
        `first_time_finished` and `first_time_deadlock` flags
        '''
        current_step = self._elapsed_steps
        agents_in_decision_cells = self.agents_in_decision_cells()
        obs, rewards, dones, info = super().step(action_dict_)
        info["deadlocks"], info["deadlock_turns"] = self.deadlocks_detector.step(
            self
        )

        # Patch dones dict, update arrived agents turns and stop actions
        # and store other data in info dict
        finished, first_time_deadlock, first_time_finished = dict(), dict(), dict()
        remove_all = False
        for agent in range(self.get_num_agents()):
            # Update dones
            if dones[agent] and not self.railway_encoding.is_done(agent):
                dones[agent] = False
                remove_all = True

            # Update arrived agents and first time dones
            if dones[agent] and self.arrived_turns[agent] is None:
                self.arrived_turns[agent] = current_step
                first_time_finished[agent] = True
            else:
                first_time_finished[agent] = False

            # Store the number of consequent stop actions
            if action_dict_[agent] == RailEnvActions.STOP_MOVING.value:
                self.stop_actions[agent] += 1
            else:
                self.stop_actions[agent] = 0

            # Update done or deadlocked agents and first time deadlocked agents
            finished[agent] = dones[agent] or info['deadlocks'][agent]
            first_time_deadlock[agent] = (
                info['deadlocks'][agent] and
                current_step == info['deadlock_turns'][agent]
            )

        # Patch info dict
        info['finished'] = finished
        info['first_time_finished'] = first_time_finished
        info['first_time_deadlock'] = first_time_deadlock

        # If at least one agent is not at target, then
        # the __all__ flag of dones should be False
        if remove_all:
            dones["__all__"] = False

        # Compute custom rewards
        custom_rewards = self._reward_shaping(
            action_dict_, rewards, dones, info, agents_in_decision_cells
        )

        # Update last info dict
        self.current_info = info

        return (self._normalize_obs(obs), rewards, custom_rewards, dones, info)

    def _reward_shaping(self, action_dict_, rewards, dones, info, agents_in_decision_cells):
        '''
        Apply custom reward functions.

        - An agent arriving in the evaluated step gets
          `max_episode_steps - step`
        - An agent becoming deadlocked in the evaluated step gets
          `-max_episode_steps`
        - Otherwise rewards (plus a clipped STOP_MOVING penalty) are
          accumulated in `partial_rewards` and released only when the agent
          is in a decision cell or the episode reaches its last step

        Every returned reward is divided by `max_episode_steps`
        '''
        # The step for which we are evaluating the rewards
        # is the previous one
        step = self._elapsed_steps - 1

        custom_rewards = copy.deepcopy(rewards)
        for agent in range(self.get_num_agents()):
            # Return a positive reward equal to the maximum number of steps
            # minus the current number of steps if the agent has arrived
            if dones[agent] and self.arrived_turns[agent] == step:
                done_reward = self._max_episode_steps - step
                custom_rewards[agent] = done_reward
                self.partial_rewards[agent] = done_reward
            # Return minus the maximum number of steps if the agent is in deadlock
            elif info["deadlocks"][agent] and info["deadlock_turns"][agent] == step:
                deadlock_penalty = -self._max_episode_steps
                custom_rewards[agent] = deadlock_penalty
                self.partial_rewards[agent] = deadlock_penalty
            # Accumulate rewards for choices if the agent is not in a decision cell
            # and add other penalties, such as the stop moving one
            else:
                self.partial_rewards[agent] += custom_rewards[agent]
                # If an agent performed a STOP action, give the agent a reward
                # which is worse than the the reward it could have received
                # by choosing any other action
                if action_dict_[agent] == RailEnvActions.STOP_MOVING.value:
                    weight = self.railway_encoding.stop_moving_worst_alternative_weight(
                        agent
                    )
                    weight *= (1 / self.agents[agent].speed_data['speed'])
                    weight *= (
                        self.stop_actions[agent] *
                        self.params.env.rewards.stop_penalty
                    )
                    # Clip stop reward to be at maximum equal to the deadlock penalty
                    self.partial_rewards[agent] += np.clip(
                        -weight, -self._max_episode_steps, 0
                    )
                custom_rewards[agent] = 0.0

            # Reset the partial rewards counter when an agent is
            # in a decision cell or in the last step
            if agents_in_decision_cells[agent] or self._elapsed_steps == self._max_episode_steps:
                custom_rewards[agent] = self.partial_rewards[agent]
                self.partial_rewards[agent] = 0.0

            # Normalize rewards
            custom_rewards[agent] /= self._max_episode_steps

        return custom_rewards

    def agents_in_decision_cells(self):
        '''
        Return, for each agent handle, whether the agent is on a decision cell
        '''
        return [
            self.railway_encoding.is_real_decision(h)
            for h in range(self.get_num_agents())
        ]

    def agents_adjacency_matrix(self, radius=None):
        '''
        Return the adjacency matrix containing pairwise distances between agents.

        Entry `(i, j)` holds the distance from agent `i` to agent `j` as given
        by the railway encoding, or 0 when the distance is unknown (None), when
        `i == j`, or when it exceeds `radius` (if a radius is given)
        '''
        n_agents = self.get_num_agents()
        adj = np.zeros((n_agents, n_agents))
        for i in range(n_agents):
            for j in range(n_agents):
                if i == j:
                    continue
                distance = self.railway_encoding.get_agents_distance(i, j)
                if distance is not None and (radius is None or distance <= radius):
                    adj[i, j] = distance
        return adj

    def pre_act(self):
        '''
        Return the list of legal actions and choices for each agent and a list
        representing which agent needs to make a choice
        '''
        legal_choices = np.full(
            (self.get_num_agents(), RailEnvChoices.choice_size()),
            RailEnvChoices.default_choices()
        )
        legal_actions = np.full(
            (self.get_num_agents(), self.num_actions), False
        )
        moving_agents = np.full((self.get_num_agents(),), False)

        # Compute which agents need to make a choice
        for agent in self.get_agent_handles():
            legal_actions[
                agent, self.railway_encoding.get_agent_actions(agent)
            ] = True
            if (self.current_info['action_required'][agent] and
                    self.railway_encoding.is_real_decision(agent)):
                legal_choices[agent] = self.railway_encoding.get_legal_choices(
                    agent, legal_actions[agent]
                )
                moving_agents[agent] = True

        return legal_actions, legal_choices, moving_agents

    def post_act(self, choices, is_best, legal_actions, moving_agents):
        '''
        Return the action dictionary (to be given to the environment step)
        and a dictionary of training metadatas.

        Agents that must choose get the action mapped from their choice;
        other active agents at the start of a cell follow their only legal
        action, or DO_NOTHING when several actions are legal
        '''
        action_dict, choice_dict = dict(), dict()
        choices_count = np.zeros((RailEnvChoices.choice_size(),))
        num_exploration_choices = np.zeros_like(choices_count)

        # Assign an action to each agent
        for agent in self.get_agent_handles():
            action = RailEnvActions.DO_NOTHING.value
            if moving_agents[agent]:
                action = self.railway_encoding.map_choice_to_action(
                    choices[agent], legal_actions[agent]
                )
                assert action != RailEnvActions.DO_NOTHING.value, (
                    choices[agent], legal_actions[agent]
                )
                choices_count[choices[agent]] += 1
                num_exploration_choices[choices[agent]] += int(
                    not is_best[agent]
                )
                choice_dict[agent] = choices[agent]
            elif (not self.dones[agent] and
                  self.agents[agent].speed_data['position_fraction'] == 0):
                actions = np.flatnonzero(legal_actions[agent])
                assert actions.shape[0] > 0, actions
                action = actions[0]
                if actions.shape[0] > 1:
                    action = RailEnvActions.DO_NOTHING.value
            action_dict[agent] = action

        # Build the metadata dict
        metadata = {
            'choices_count': choices_count,
            'num_exploration_choices': num_exploration_choices,
            'choice_dict': choice_dict
        }

        return action_dict, metadata

    def pre_step(self, experience):
        '''
        To be called before the policy step function,
        it returns a list of experiences to be passed to the policy step.

        An experience is collected for every agent that updated its values
        or just finished / got deadlocked for the first time
        '''
        (
            prev_obs, prev_choices, custom_rewards,
            obs, legal_choices, update_values
        ) = experience
        finished = np.array(list(self.current_info['finished'].values()))
        experiences = []

        # Gather valuable experiences
        for agent in self.get_agent_handles():
            if (update_values[agent] or
                    self.current_info['first_time_finished'][agent] or
                    self.current_info['first_time_deadlock'][agent]):
                # Check for policy type
                if self.params.policy.type.decentralized_fov:
                    exp = (
                        prev_obs[agent],
                        list(prev_choices[agent].values()),
                        np.array(list(custom_rewards.values())),
                        obs[agent],
                        legal_choices,
                        finished,
                        update_values
                    )
                else:
                    exp = (
                        prev_obs[agent],
                        prev_choices[agent],
                        custom_rewards[agent],
                        obs[agent],
                        legal_choices[agent],
                        finished[agent],
                        update_values[agent]
                    )
                experiences.append(exp)

        return experiences

    def post_step(self, obs, choice_dict, next_obs, update_values, rewards, custom_rewards):
        '''
        To be called after the policy step function,
        it returns a dictionary of training metadatas
        (updated observations, previous observations and choices,
        and the environment and custom scores of this step)
        '''
        prev_obs, prev_choices = dict(), dict()
        score, custom_score = 0.0, 0.0

        for agent in self.get_agent_handles():
            # Update previous observations and choices
            if (update_values[agent] or
                    self.current_info['first_time_finished'][agent] or
                    self.current_info['first_time_deadlock'][agent]):
                prev_obs[agent] = copy_obs(obs[agent])
                if self.params.policy.type.decentralized_fov:
                    prev_choices[agent] = dict(choice_dict)
                else:
                    prev_choices[agent] = choice_dict[agent]

            # Update observation and score
            score += rewards[agent]
            custom_score += custom_rewards[agent]
            if next_obs[agent] is not None:
                if self.params.policy.type.decentralized_fov:
                    obs[agent] = next_obs[agent]
                else:
                    obs[agent] = copy_obs(next_obs[agent])

        # Build and return the metadata dict
        return {
            'obs': obs,
            'prev_obs': prev_obs,
            'prev_choices': prev_choices,
            'score': score,
            'custom_score': custom_score
        }

    def _normalize_obs(self, obs):
        '''
        Normalize observations in place (only tree observations are
        normalized) and return them; a no-op when normalization is disabled
        '''
        if not self.normalize:
            return obs

        for handle in obs:
            if obs[handle] is not None:
                # Normalize tree observation
                if isinstance(self.obs_builder, TreeObsForRailEnv):
                    obs[handle] = normalize_tree_obs(
                        obs[handle], self.obs_builder.max_depth,
                        self.params.observator.tree.radius
                    )

        return obs

## Encoding of the railway map in a cell orientation graph
# Cardinal orientations, in the order used to decode transition bitmaps
TRANS = [
    Grid4TransitionsEnum[direction_name]
    for direction_name in ("NORTH", "EAST", "SOUTH", "WEST")
]

def agent_action(original_dir, final_dir):
    '''
    Return the action performed by an agent, by analyzing
    the starting direction and the final direction of the movement.

    With directions numbered clockwise (NORTH, EAST, SOUTH, WEST),
    a clockwise quarter turn is MOVE_RIGHT, a counter-clockwise one is
    MOVE_LEFT, and keeping (or reversing) the direction is MOVE_FORWARD
    '''
    # Python's % always returns a value in [0, 4), so only the
    # non-negative residues 1 and 3 need to be checked
    turn = (final_dir.value - original_dir.value) % 4
    if turn == 1:
        return RailEnvActions.MOVE_RIGHT
    if turn == 3:
        return RailEnvActions.MOVE_LEFT
    return RailEnvActions.MOVE_FORWARD

class CellOrientationGraph():

    _BITMAP_TO_TRANS = [(t1, t2) for t1 in TRANS for t2 in TRANS]

    def __init__(self, grid, agents):
        self.grid = grid
        self.agents = agents
        self.graph = None
        self._unpacked_graph = None
        self._dead_ends = set()
        self._straight_rails = set()

        # For each target position, store associated agents
        self._targets = dict()
        for agent in agents:
            self._targets.setdefault(agent.target, []).append(agent.handle)

        # Build the packed and unpacked graphs
        self._generate_graph()

        # Store the node to index and index to node mappings of
        # the packed graph
        self.node_to_index, self.index_to_node = self._build_vocab(
            unpacked=False
        )

    def _generate_graph(self):
        '''
        Build the frozen unpacked graph from the grid edges, derive the
        packed graph from a copy of it and set the default attributes
        of the packed graph nodes
        '''
        edge_list = self._generate_edges()
        unpacked = nx.DiGraph()
        unpacked.add_edges_from(edge_list)
        nx.freeze(unpacked)
        self._unpacked_graph = unpacked
        # The packed graph starts as a copy of the unpacked one and is
        # then shrunk in place by `_pack_graph`
        self.graph = nx.DiGraph(unpacked)
        self._pack_graph()
        self._set_nodes_attributes()

    def _generate_edges(self):
        '''
        Translate the environment grid to the unpacked cell orientation graph.

        Every non-empty cell is decoded from its 16-bit transition bitmap:
        bit k (counting from the most significant one) allows entering the
        cell with orientation `TRANS[k // 4]` and leaving it with orientation
        `TRANS[k % 4]`. As a side effect, cells with exactly two transitions
        are recorded in `self._straight_rails` and cells with exactly one in
        `self._dead_ends`.

        Returns a list of `(source, target, data)` triples, where nodes are
        `(row, column, orientation)` and `data` holds the edge `weight`
        (always 1), the `action` performing the move and the matching `choice`
        '''
        # Loop invariant: the number of environment actions
        num_actions = get_num_actions()
        edges = []
        for i, row in enumerate(self.grid):
            for j, trans_int in enumerate(row):
                if trans_int == 0:
                    continue
                trans_bitmap = format(trans_int, '016b')
                num_ones = trans_bitmap.count('1')
                if num_ones == 2:
                    self._straight_rails.add((i, j))
                elif num_ones == 1:
                    self._dead_ends.add((i, j))

                # First pass: collect the edges leaving this cell and, for
                # each entering orientation, the mask of available actions
                tmp_edges, tmp_actions = [], dict()
                for k, bit in enumerate(trans_bitmap):
                    if bit != '1':
                        continue
                    original_dir, final_dir = self._BITMAP_TO_TRANS[k]
                    new_position_x, new_position_y = get_new_position(
                        [i, j], final_dir.value
                    )
                    tmp_action = agent_action(original_dir, final_dir)
                    source = (i, j, original_dir.value)
                    tmp_edges.append((
                        source,
                        (new_position_x, new_position_y, final_dir.value),
                        tmp_action
                    ))
                    if source not in tmp_actions:
                        tmp_actions[source] = np.full((num_actions,), False)
                    tmp_actions[source][tmp_action.value] = True

                # Second pass: map each action to a choice, now that the
                # complete action mask of its source node is known
                for source, target, action in tmp_edges:
                    tmp_choice = self.map_action_to_choice(
                        action, tmp_actions[source]
                    )
                    edges.append((
                        source,
                        target,
                        {
                            'weight': 1,
                            'action': action,
                            'choice': tmp_choice
                        }
                    ))
        return edges

    def _pack_graph(self):
        '''
        Generate a compact version of the cell orientation graph,
        by only keeping junctions, targets and dead ends
        '''
        to_remove = self._straight_rails.difference(
            set(self._targets.keys())
        )
        for cell in to_remove:
            self._remove_cell(cell)

    def _remove_node(self, node):
        '''
        Remove a node from the in-construction packed graph and
        add an edge between the neighboring nodes, while
        also propagating edges data
        '''
        sources = [
            (source, data)
            for source, _, data in self.graph.in_edges(node, data=True)
        ]
        targets = [
            (target, data)
            for _, target, data in self.graph.out_edges(node, data=True)
        ]
        new_edges = [
            (
                source[0], target[0],
                {
                    'weight': source[1]['weight'] + target[1]['weight'],
                    'action': source[1]['action'],
                    'choice': source[1]['choice']
                }
            )
            for source in sources for target in targets
        ]
        self.graph.add_edges_from(new_edges)
        self.graph.remove_node(node)

    def _remove_cell(self, position):
        '''
        Remove the given cell with every direction component,
        in order to build the packed graph
        '''
        nodes = self.get_nodes(position)
        for node in nodes:
            self._remove_node(node)

    def _set_nodes_attribute(self, name, positions=None, value=None, default=None):
        '''
        Set the node attribute `name` on the packed graph.

        - default: when not None, first assigned to every node
        - positions: cells `(row, column)` (expanded to all their nodes)
          and/or nodes `(row, column, orientation)` that receive `value`
        - value: a single value, or a dictionary indexed by the entries
          of `positions`
        Only the default is applied unless both `positions` and `value`
        are given
        '''
        if default is not None:
            nx.set_node_attributes(self.graph, default, name)
        if positions is None or value is None:
            return
        updates = {}
        for pos in positions:
            targets = self.get_nodes(pos) if len(pos) == 2 else [pos]
            for node in targets:
                updates[node] = {
                    name: value[pos] if isinstance(value, dict) else value
                }
        nx.set_node_attributes(self.graph, updates)

    def _set_nodes_attributes(self):
        '''
        Set default attributes for each and every node in the packed graph
        '''
        self._set_nodes_attribute(
            'is_dead_end', positions=self._dead_ends, value=True, default=False
        )
        self._set_nodes_attribute(
            'is_target', positions=set(self._targets.keys()), value=True, default=False
        )
        fork_positions, join_positions = self._compute_decision_types()
        self._set_nodes_attribute(
            'is_fork', positions=fork_positions, value=True, default=False
        )
        self._set_nodes_attribute(
            'is_join', positions=join_positions, value=True, default=False
        )

    def _compute_decision_types(self):
        '''
        Set decision types (at fork and/or at join) for each node in the packed graph
        '''
        fork_positions, join_positions = set(), set()
        for node in self.graph.nodes:
            if not self.graph.nodes[node]['is_dead_end']:
                other_nodes = set(self.get_nodes(node)) - {node}
                # If diamond crossing and/or fork set join for other nodes
                num_successors = len(self.get_successors(node))
                if len(other_nodes) == 3 or num_successors > 1:
                    for other_node in other_nodes:
                        join_positions.add(other_node)
                # Set fork for current node
                if num_successors > 1:
                    fork_positions.add(node)
        return fork_positions, join_positions

    def _build_vocab(self, unpacked=False):
        '''
        Build a vocabulary, mapping nodes to indexes and vice-versa
        '''
        graph = self.graph if not unpacked else self._unpacked_graph
        nodes = sorted(list(graph.nodes()))
        node_to_index = {node: i for i, node in enumerate(nodes)}
        index_to_node = {i: node for i, node in enumerate(nodes)}
        return node_to_index, index_to_node

    def is_straight_rail(self, cell):
        '''
        Check if the given cell is a straight rail
        '''
        if len(cell) > 2:
            cell = cell[:-1]
        return cell in self._straight_rails

    def get_nodes(self, position, unpacked=False):
        '''
        Given a position (row, column), return a list
        of nodes present in the packed (default) or unpacked graph of the type
        [(row, column, NORTH), ..., (row, column, WEST)]
        (only the first two entries of `position` are used)
        '''
        graph = self._unpacked_graph if unpacked else self.graph
        candidates = ((position[0], position[1], d.value) for d in TRANS)
        return [node for node in candidates if graph.has_node(node)]

    def is_node(self, node, unpacked=False):
        '''
        Return true if the given node is present in the packed or
        unpacked graph
        '''
        graph = self._unpacked_graph if unpacked else self.graph
        return node in graph.nodes

    def get_edge_data(self, u, v, t, unpacked=False):
        '''
        Return the feature `t` in edge `(u, v)`
        '''
        graph = self.graph if not unpacked else self._unpacked_graph
        assert (u, v) in graph.edges
        edge_data = graph.get_edge_data(u, v)
        assert t in edge_data
        return edge_data[t]

    def get_predecessors(self, node, unpacked=False):
        '''
        Return the predecessors of the given node in the packed or
        unpacked graph
        '''
        graph = self._unpacked_graph if unpacked else self.graph
        if node not in graph.nodes:
            return []
        return list(graph.predecessors(node))

    def get_successors(self, node, unpacked=False):
        '''
        Return the successors of the given node in the packed or
        unpacked graph
        '''
        graph = self._unpacked_graph if unpacked else self.graph
        if node not in graph.nodes:
            return []
        return list(graph.successors(node))

    def next_node(self, cell):
        '''
        Return the closest node in the packed graph
        w.r.t. the given cell in the unpacked graph,
        in the same direction
        '''
        if cell in self.graph.nodes:
            return cell, 0
        weight = 0
        successors = self._unpacked_graph.successors(cell)
        while True:
            try:
                cell = next(successors)
                weight += 1
                if cell in self.graph.nodes:
                    return cell, weight
                successors = self._unpacked_graph.successors(cell)
            except StopIteration:
                break
        return None

    def previous_node(self, cell):
        '''
        Return the closest node in the packed graph
        w.r.t. the given cell in the unpacked graph,
        in the opposite direction

        Starting from `cell`, unpacked-graph predecessors are explored
        (newly discovered predecessors are queued after the pending ones)
        until a cell `c` is found such that `(c, next)` is an edge of the
        packed graph, `next` being the packed node returned by
        `self.next_node(cell)`.

        :param cell: node of the unpacked graph
        :return: `(cell, 0)` if `cell` already is a packed node; otherwise
            `(packed_node, weight)`, or None if the exploration is exhausted
        '''
        if cell in self.graph.nodes:
            return cell, 0
        weight = 0
        # NOTE(review): this local shadows the `next_node` method name, and the
        # unpacking raises TypeError if `self.next_node` returns None
        next_node, _ = self.next_node(cell)
        predecessors = self._unpacked_graph.predecessors(cell)
        while True:
            try:
                cell = next(predecessors)
                # NOTE(review): incremented once per explored predecessor, so it
                # equals the walked distance only when cells have a single
                # predecessor between packed nodes — confirm this holds
                weight += 1
                edge = (cell, next_node)
                if edge in self.graph.edges:
                    return cell, weight
                # Queue this cell's predecessors after the ones still pending.
                # NOTE(review): there is no visited set, so a cycle without a
                # matching packed edge would never terminate
                predecessors = itertools.chain(
                    predecessors, self._unpacked_graph.predecessors(cell)
                )
            except StopIteration:
                break
        return None

    def get_agent_cell(self, handle):
        '''
        Locate the agent identified by `handle` as an unpacked-graph node.

        :param handle: index of the agent in `self.agents`
        :return: `(row, column, direction)`:
            - READY_TO_DEPART: the initial position and direction
            - MOVING: the current position and direction
            - DONE: the target cell with the current direction
            None for any other status
        '''
        agent = self.agents[handle]
        status = agent.status
        if status == TrainState.READY_TO_DEPART:
            # Not on the map yet: report where it will spawn
            row, col = agent.initial_position[0], agent.initial_position[1]
            return (row, col, agent.initial_direction)
        if status == TrainState.MOVING:
            row, col = agent.position[0], agent.position[1]
            return (row, col, agent.direction)
        if status == TrainState.DONE:
            row, col = agent.target[0], agent.target[1]
            return (row, col, agent.direction)
        return None

    def stop_moving_worst_alternative_weight(self, handle):
        '''
        Return the weight associated with the worst (heaviest) move
        alternative to a stop choice, starting from the position of the agent.

        The agent cell is advanced to the next packed node. If that node is
        a join, it is the only starting point. Otherwise every direct
        unpacked-graph successor of it that is a join becomes a starting
        point, one step further away. For each starting point, every
        outgoing packed edge is considered, and the largest
        `edge weight + weight needed to reach the starting point` is returned.

        :param handle: agent handle
        :return: the maximum total weight, 0 if there is nothing to consider
        '''
        position = self.get_agent_cell(handle)
        node, weight = self.next_node(position)

        if self.is_join(node):
            start_nodes = [(node, weight)]
        else:
            start_nodes = []
            for succ in self.get_successors(node, unpacked=True):
                succ_weight = self.get_edge_data(
                    node, succ, 'weight', unpacked=True
                )
                # Unpacked-graph edges are expected to be single-cell steps
                assert succ_weight == 1
                if self.is_join(succ):
                    start_nodes.append((succ, succ_weight + weight))

        max_weight = 0
        for start_node, start_weight in start_nodes:
            for succ in self.get_successors(start_node, unpacked=False):
                edge_weight = self.get_edge_data(
                    start_node, succ, 'weight', unpacked=False
                )
                # Compare the full alternative weight (edge plus the distance
                # to the starting point): the same quantity that is stored
                max_weight = max(max_weight, edge_weight + start_weight)

        return max_weight

    def is_done(self, handle):
        '''
        Return True iff the agent identified by `handle` arrived at its target.

        :param handle: index of the agent in `self.agents`
        :return: bool
        '''
        # Compare with `==`: `(TrainState.DONE)` is not a tuple, so a membership
        # test against it would raise TypeError instead of checking the status
        return self.agents[handle].status == TrainState.DONE

    def map_choice_to_action(self, choice, actions):
        '''
        Translate a RailEnvChoices value into a RailEnvActions action.

        `actions` is indexed by RailEnvActions values; truthy entries mark
        legal actions. The first legal action in the choice's preference
        order is returned:
        - CHOICE_LEFT: MOVE_LEFT, MOVE_FORWARD, MOVE_RIGHT
        - CHOICE_RIGHT: MOVE_RIGHT, MOVE_FORWARD
        - STOP: STOP_MOVING, unconditionally
        DO_NOTHING is the fallback when nothing preferred is legal or the
        choice is not recognized.
        '''
        if choice == RailEnvChoices.CHOICE_LEFT.value:
            preference = (
                RailEnvActions.MOVE_LEFT,
                RailEnvActions.MOVE_FORWARD,
                RailEnvActions.MOVE_RIGHT,
            )
        elif choice == RailEnvChoices.CHOICE_RIGHT.value:
            preference = (
                RailEnvActions.MOVE_RIGHT,
                RailEnvActions.MOVE_FORWARD,
            )
        elif choice == RailEnvChoices.STOP.value:
            return RailEnvActions.STOP_MOVING
        else:
            preference = ()

        for candidate in preference:
            if actions[candidate.value]:
                return candidate
        return RailEnvActions.DO_NOTHING

    def map_action_to_choice(self, action, actions):
        '''
        Translate a RailEnvActions action back into a RailEnvChoices value.

        `actions` is indexed by RailEnvActions values; truthy entries mark
        legal actions. An action that is not legal maps to STOP.
        '''
        def is_legal(candidate):
            return actions[candidate.value]

        if action == RailEnvActions.MOVE_LEFT and is_legal(RailEnvActions.MOVE_LEFT):
            return RailEnvChoices.CHOICE_LEFT

        if action == RailEnvActions.MOVE_RIGHT and is_legal(RailEnvActions.MOVE_RIGHT):
            legal_count = np.count_nonzero(actions)
            # A lone right turn is exposed as the default (left) choice
            if legal_count > 1:
                return RailEnvChoices.CHOICE_RIGHT
            if legal_count == 1:
                return RailEnvChoices.CHOICE_LEFT

        if action == RailEnvActions.MOVE_FORWARD and is_legal(RailEnvActions.MOVE_FORWARD):
            # Forward is the "right" branch only when a left turn also exists
            if is_legal(RailEnvActions.MOVE_LEFT):
                return RailEnvChoices.CHOICE_RIGHT
            return RailEnvChoices.CHOICE_LEFT

        return RailEnvChoices.STOP

    def get_possible_choices(self, position, actions):
        '''
        Build a boolean mask over RailEnvChoices from the legal RailEnvActions.

        :param position: unpacked-graph node the agent is in
        :param actions: sequence indexed by RailEnvActions values whose
            truthy entries are the legal actions
        :return: numpy boolean array of length `RailEnvChoices.choice_size()`
        '''
        mask = np.zeros(RailEnvChoices.choice_size(), dtype=bool)

        # Stopping is allowed only in front of a join, and never when a single
        # agent remains on the railway
        mask[RailEnvChoices.STOP.value] = (
            self.is_before_join(position) and not self.only_one_agent())

        can_go_forward = actions[RailEnvActions.MOVE_FORWARD.value]
        can_go_left = actions[RailEnvActions.MOVE_LEFT.value]
        can_go_right = actions[RailEnvActions.MOVE_RIGHT.value]
        has_alternatives = np.count_nonzero(actions) > 1

        left_index = RailEnvChoices.CHOICE_LEFT.value
        right_index = RailEnvChoices.CHOICE_RIGHT.value
        if can_go_forward:
            mask[left_index] = True
            if has_alternatives:
                mask[right_index] = True
        if can_go_left:
            mask[left_index] = True
        if can_go_right:
            # A lone right turn is exposed as the default (left) choice
            mask[right_index if has_alternatives else left_index] = True
        return mask

    def get_legal_choices(self, handle, actions):
        '''
        Return the RailEnvChoices mask available to the agent `handle`,
        given the RailEnvActions it can legally perform.

        Arrived agents always get `RailEnvChoices.default_choices()`
        (a workaround for a flatland bug).
        '''
        if not self.is_done(handle):
            agent_cell = self.get_agent_cell(handle)
            return self.get_possible_choices(agent_cell, actions)
        return RailEnvChoices.default_choices()

    def is_fork(self, position):
        '''
        Tell whether `position` is a fork of the packed graph.

        :return: the node's `is_fork` attribute, False for unknown positions
        '''
        packed_nodes = self.graph.nodes
        return packed_nodes[position]['is_fork'] if position in packed_nodes else False

    def is_join(self, position):
        '''
        Tell whether `position` is a join of the packed graph.

        :return: True if the position is a packed node flagged `is_join`,
            False otherwise (including unknown positions)
        '''
        packed_nodes = self.graph.nodes
        if position not in packed_nodes:
            return False
        return bool(packed_nodes[position]['is_join'])

    def is_before_join(self, position):
        '''
        Tell whether any unpacked-graph successor of `position` is a join.
        '''
        upcoming = self.get_successors(position, unpacked=True)
        return any(self.is_join(cell) for cell in upcoming)

    def is_at_fork(self, handle):
        '''
        Tell whether the agent `handle` currently stands on a fork.
        '''
        agent_cell = self.get_agent_cell(handle)
        return self.is_fork(agent_cell)

    def is_at_before_join(self, handle):
        '''
        Tell whether the agent `handle` stands right before a join.
        '''
        agent_cell = self.get_agent_cell(handle)
        return self.is_before_join(agent_cell)

    def remaining_agents_handles(self):
        '''
        Return the set of handles of the agents that have not reached
        their target yet.
        '''
        pending = set()
        for handle, _ in enumerate(self.agents):
            if not self.is_done(handle):
                pending.add(handle)
        return pending

    def remaining_agents(self):
        '''
        Count the agents that have not reached their target yet.
        '''
        pending = self.remaining_agents_handles()
        return len(pending)

    def only_one_agent(self):
        '''
        Tell whether at most one agent is still travelling on the railway.
        '''
        return self.remaining_agents() <= 1

    def is_real_decision(self, handle):
        '''
        Tell whether the agent `handle` faces an actual decision: it is on
        a fork, or it is before a join while other agents are still around.
        '''
        at_fork = self.is_at_fork(handle)
        if at_fork:
            return at_fork
        before_join = self.is_at_before_join(handle)
        if not before_join:
            return before_join
        return not self.only_one_agent()

    def get_actions(self, position):
        '''
        Return the `.value` of the `action` stored on every outgoing
        unpacked-graph edge of `position`, in successor order.
        '''
        edge_data = self._unpacked_graph.get_edge_data
        return [
            edge_data(position, successor)['action'].value
            for successor in self.get_successors(position, unpacked=True)
        ]

    def get_agent_actions(self, handle):
        '''
        Return the action values available from the cell of agent `handle`
        (see `get_actions`).
        '''
        agent_cell = self.get_agent_cell(handle)
        return self.get_actions(agent_cell)

    def action_from_positions(self, source, dest, unpacked=True):
        '''
        Return the `action` attribute of the edge `source -> dest`
        (unpacked graph by default, packed graph when `unpacked=False`),
        or None when the edge does not exist.
        '''
        lookup_graph = self._unpacked_graph if unpacked else self.graph
        if (source, dest) not in lookup_graph.edges:
            return None
        return lookup_graph.get_edge_data(source, dest)['action']

    def position_by_action(self, position, action):
        '''
        Return the first unpacked-graph successor of `position` whose
        connecting edge carries `action`, or None if there is none.
        '''
        edge_data = self._unpacked_graph.get_edge_data
        return next(
            (
                successor
                for successor in self.get_successors(position, unpacked=True)
                if edge_data(position, successor)['action'] == action
            ),
            None,
        )

    def agent_position_by_action(self, handle, action):
        '''
        Return the next node that the agent will occupy if it performs the
        given action, or None if the action leads nowhere from its cell.
        '''
        # The result must be returned: previously it was computed and dropped,
        # so callers always received None
        return self.position_by_action(self.get_agent_cell(handle), action)

    def shortest_paths(self, handle):
        '''
        Compute, for every packed-graph node of the agent's target (one per
        possible arrival direction), the shortest packed-graph path from the
        agent's current cell.

        When the agent is not on a packed node, its own cell is prepended to
        each path and the distance to the next packed node is added.

        :param handle: agent handle
        :return: list of `(length, path)` sorted by increasing length;
            unreachable targets are skipped
        '''
        agent = self.agents[handle]
        start_cell = self.get_agent_cell(handle)
        # NOTE(review): `next_node` may return None, which fails to unpack here
        source, offset = self.next_node(start_cell)
        candidates = []
        for target in self.get_nodes(agent.target):
            try:
                length, path = nx.bidirectional_dijkstra(
                    self.graph, source, target
                )
            except nx.NetworkXNoPath:
                continue
            if path[0] != start_cell:
                path = [start_cell] + path
                length += offset
            candidates.append((length, path))
        return sorted(candidates, key=lambda item: item[0])

    def deviation_paths(self, handle, source, node_to_avoid):
        '''
        Return alternative paths from `source` to the agent's target nodes
        that leave `source` through any packed successor other than
        `node_to_avoid`.

        :return: list of `(length, path)` sorted by increasing length, where
            `path` starts with `source` and `length` includes the first edge
        '''
        target_nodes = self.get_nodes(self.agents[handle].target)
        alternatives = []
        for succ in self.graph.successors(source):
            if succ == node_to_avoid:
                continue
            first_hop = self.graph.edges[(source, succ)]['weight']
            for target in target_nodes:
                try:
                    length, path = nx.bidirectional_dijkstra(
                        self.graph, succ, target
                    )
                except nx.NetworkXNoPath:
                    continue
                alternatives.append((length + first_hop, [source] + path))
        return sorted(alternatives, key=lambda item: item[0])

    def meaningful_subgraph(self, handle):
        '''
        Return the subgraph of the packed graph which could be visited by
        the agent identified by the given handle: the union of the nodes of
        all simple paths from the agent's next packed node to any packed
        node of its target.

        Note: enumerating all simple paths can be exponential in the graph
        size, so this is only practical on small railways.
        '''
        source, _ = self.next_node(self.get_agent_cell(handle))
        # The agent target is a (row, col) cell; the packed graph uses
        # (row, col, direction) nodes, so expand it with `get_nodes`,
        # the same way `shortest_paths` does
        targets = self.get_nodes(self.agents[handle].target)
        # A set (not a dict) is needed to accumulate path nodes
        nodes = set()
        for target in targets:
            for path in nx.all_simple_paths(self.graph, source, target):
                nodes.update(path)
        return nx.subgraph(self.graph, nodes)

    def get_agents_distance(self, handle_one, handle_two):
        '''
        Estimate the distance between two agents.

        The packed-graph shortest-path length between their next packed
        nodes is combined with both agents' distances to those nodes.

        :return: the distance, or None if either agent has no cell or no
            path exists
        '''
        first_cell = self.get_agent_cell(handle_one)
        second_cell = self.get_agent_cell(handle_two)
        if first_cell is None or second_cell is None:
            return None
        first_node, first_offset = self.next_node(first_cell)
        second_node, second_offset = self.next_node(second_cell)
        try:
            between = nx.dijkstra_path_length(
                self.graph, first_node, second_node
            )
        except nx.NetworkXNoPath:
            return None
        # NOTE(review): both offsets are added; the second agent's offset
        # points past its own cell, so confirm this is the intended metric
        return between + first_offset + second_offset

    def get_distance(self, source, dest):
        '''
        Return the minimum distance between the source and destination
        nodes of the unpacked graph.

        :return: the shortest-path length, or `np.inf` when either node is
            unknown or `dest` is unreachable from `source`
        '''
        graph = self._unpacked_graph
        if source not in graph.nodes or dest not in graph.nodes:
            return np.inf
        try:
            return nx.dijkstra_path_length(graph, source, dest)
        except nx.NetworkXNoPath:
            # Unreachable destination: report an infinite distance, the same
            # value used for unknown nodes, instead of propagating the error
            return np.inf

    def get_adjacency_matrix(self, unpacked=False):
        '''
        Return the weighted adjacency matrix of the specified graph,
        as a SciPy sparse COO matrix.

        :param unpacked: use the unpacked graph when True, the packed one
            otherwise
        :return: `scipy.sparse.coo_matrix` holding the `weight` edge
            attributes, rows/columns in `graph.nodes` order
        '''
        graph = self._unpacked_graph if unpacked else self.graph
        options = dict(dtype=np.dtype('long'), weight='weight', format='coo')
        # NetworkX graphs expose no `to_scipy_sparse_matrix` method: the
        # conversion is a module-level function. `nx.to_scipy_sparse_matrix`
        # was removed in NetworkX 3.0 in favour of `nx.to_scipy_sparse_array`,
        # so prefer the latter and convert the result back to a matrix.
        if hasattr(nx, 'to_scipy_sparse_array'):
            from scipy import sparse
            return sparse.coo_matrix(nx.to_scipy_sparse_array(graph, **options))
        return nx.to_scipy_sparse_matrix(graph, **options)

    def get_graph_edges(self, unpacked=False, data=False):
        '''
        Return the edge view of the packed graph (or of the unpacked one when
        `unpacked` is True), including edge attributes when `data` is True.
        '''
        chosen = self._unpacked_graph if unpacked else self.graph
        return chosen.edges(data=data)

    def get_graph_nodes(self, unpacked=False, data=False):
        '''
        Return the node view of the packed graph (or of the unpacked one when
        `unpacked` is True), including node attributes when `data` is True.
        '''
        chosen = self._unpacked_graph if unpacked else self.graph
        return chosen.nodes(data=data)

    def edges_from_path(self, path):
        '''
        Turn a packed-graph path (a sequence of nodes) into the sequence of
        `(u, v, attributes)` edges it traverses.

        If the first element is not a packed node (e.g. the agent's own cell),
        a synthetic first edge is built toward `path[1]`. Its weight is the
        unpacked shortest-path length, its action is that of the first
        unpacked step, and its choice is CHOICE_LEFT. Consecutive repeated
        nodes are skipped.
        '''
        edges = []
        start = 0
        head = path[0]
        if head not in self.graph.nodes:
            second = path[1]
            fake_weight, mini_path = nx.bidirectional_dijkstra(
                self._unpacked_graph, head, second
            )
            first_action = self._unpacked_graph.get_edge_data(
                mini_path[0], mini_path[1]
            )['action']
            edges.append((
                head, second,
                {
                    'weight': fake_weight,
                    'action': first_action,
                    'choice': RailEnvChoices.CHOICE_LEFT
                }
            ))
            start = 1
        for tail_node, next_node in zip(path[start:], path[start + 1:]):
            if tail_node != next_node:
                edges.append((
                    tail_node, next_node,
                    self.graph.get_edge_data(tail_node, next_node)
                ))
        return edges

    def positions_from_path(self, path, max_lenght=None):
        '''
        Expand a packed-graph path into the unpacked-graph nodes it crosses.

        Each pair of consecutive packed nodes is joined through the unpacked
        shortest path between them. The returned items are unpacked-graph
        nodes as-is (no component, e.g. direction, is removed).

        :param path: sequence of nodes, the first one included verbatim
        :param max_lenght: optional cap; once at least this many nodes are
            collected, the first `max_lenght` are returned
        :return: list of unpacked-graph nodes
        '''
        cells = [path[0]]
        for start, end in zip(path, path[1:]):
            _, segment = nx.bidirectional_dijkstra(
                self._unpacked_graph, start, end
            )
            cells.extend(segment[1:])
            if max_lenght is not None and len(cells) >= max_lenght:
                return cells[:max_lenght]
        return cells

    def different_direction_nodes(self, node):
        '''
        Return the packed-graph nodes sharing the row and column of `node`
        (a `(row, col, direction)` tuple) but having another direction.
        '''
        row, col, _ = node
        candidates = (
            (row, col, direction) for direction in range(len(TRANS))
        )
        return [
            candidate for candidate in candidates
            if candidate != node and candidate in self.graph
        ]

    def no_successors_nodes(self, unpacked=False):
        '''
        Return, in node order, the nodes of the packed (or unpacked) graph
        that have no outgoing edge.
        '''
        chosen = self._unpacked_graph if unpacked else self.graph
        return [
            node for node in chosen.nodes
            if not self.get_successors(node, unpacked=unpacked)
        ]

    def draw_graph(self):
        '''
        Plot the packed graph with node labels and display it.
        '''
        packed = self.graph
        nx.draw(packed, with_labels=True)
        plt.show()

    def draw_unpacked_graph(self):
        '''
        Plot the unpacked graph with node labels and display it.
        '''
        unpacked = self._unpacked_graph
        nx.draw(unpacked, with_labels=True)
        plt.show()

    def draw_path(self, path):
        '''
        Plot the packed graph and highlight, in red, the nodes and edges of
        `path`. A leading element that is not a packed node is ignored.
        '''
        if path[0] not in self.graph.nodes:
            path = path[1:]
        layout = nx.spring_layout(self.graph)
        nx.draw(self.graph, layout)
        highlighted_edges = list(zip(path, path[1:]))
        nx.draw_networkx_nodes(
            self.graph, layout, nodelist=path, node_color='r'
        )
        nx.draw_networkx_edges(
            self.graph, layout, edgelist=highlighted_edges,
            edge_color='r', width=5
        )
        plt.axis('equal')
        plt.show()
        
## Binary tree observation ##

# Target range for normalized binary tree observation features
BT_LOWER = -1
BT_UPPER = 1
# Out-of-range markers (below / above the target range)
BT_UNDER = -2
BT_OVER = 2


def dumb_normalize_binary_tree_obs(observation):
    '''
    Replace every infinite entry (both -inf and +inf) with BT_LOWER,
    leaving finite values untouched (no scaling is performed).

    :param observation: numpy array; it is copied, never modified
    :return: the cleaned copy
    '''
    cleaned = observation.copy()
    # Both signs of infinity deliberately map to the lower bound
    infinite_mask = (cleaned == -np.inf) | (cleaned == np.inf)
    cleaned[infinite_mask] = BT_LOWER
    return cleaned


def normalize_binary_tree_obs(observation, remaining_agents, max_malfunction, fixed_radius):
    '''
    Min-max scale each feature group of a binary tree observation into
    [BT_LOWER, BT_UPPER] using known bounds.

    Assumes `observation` is a 3-D array whose last axis holds the 15
    per-node features of the custom binary tree observation; index 13
    (the fork flag) is left unscaled — TODO confirm with the builder.

    :param observation: array to normalize (copied, not modified)
    :param remaining_agents: upper bound of the agent-count features
    :param max_malfunction: upper bound of the malfunction features
    :param fixed_radius: bound of the distance / turn features
    :return: normalized copy with infinities replaced and values clipped
    '''
    normalized_observation = observation.copy()

    # (feature index or slice, known_min, known_max), in scaling order
    scaling_plan = (
        (slice(0, 4), 0, remaining_agents),          # agents on the path
        (slice(6, 8), 0, max_malfunction),           # malfunction turns
        (10, 0, remaining_agents),                   # common nodes
        (11, 0, remaining_agents),                   # deadlocks
        (slice(4, 6), -fixed_radius, fixed_radius),  # distances to agents
        (8, 0, fixed_radius),                        # distance to target
        (9, 0, fixed_radius),                        # turns to reach node
        (12, -fixed_radius, fixed_radius),           # deadlock distances
        (14, 0, fixed_radius),                       # repeated stop actions
    )
    for feature, known_min, known_max in scaling_plan:
        normalized_observation[:, :, feature] = min_max_scaling(
            normalized_observation[:, :, feature],
            BT_LOWER, BT_UPPER, BT_LOWER, BT_UPPER,
            known_min=known_min, known_max=known_max
        )

    # Sanity check: map leftover infinities to the bounds, then clip
    normalized_observation[normalized_observation == -np.inf] = BT_LOWER
    normalized_observation[normalized_observation == np.inf] = BT_UPPER
    normalized_observation = np.clip(
        normalized_observation, BT_LOWER, BT_UPPER
    )

    # Every value must now lie in [BT_LOWER, BT_UPPER] (NaNs would fail here)
    assert np.logical_and(
        normalized_observation >= BT_LOWER,
        normalized_observation <= BT_UPPER
    ).all(), (observation, normalized_observation)

    return normalized_observation


## Tree observation 

# Clipping range applied to normalized tree observations
T_CLIP_MIN = -1
T_CLIP_MAX = 1


def max_lt(seq, val):
    '''
    Return the greatest item of `seq` lying strictly between 0 and `val`.

    Items that are <= 0, >= `val`, or NaN are ignored. When no item
    qualifies (including an empty `seq`), 0 is returned, never None.

    :param seq: indexable sequence of numbers (e.g. list or numpy array)
    :param val: exclusive upper bound
    :return: the greatest qualifying item, or 0
    '''
    # Scan from the end so that, among equal maxima, the last one in `seq`
    # is the one returned (same tie-breaking as a backward manual scan)
    return max((item for item in reversed(seq) if 0 < item < val), default=0)


def min_gt(seq, val):
    '''
    Return the smallest finite item of `seq` that is greater than or equal
    to `val`.

    Infinite items and NaN are ignored. When no item qualifies (including
    an empty `seq`), `np.inf` is returned, never None.

    :param seq: indexable sequence of numbers (e.g. list or numpy array)
    :param val: inclusive lower bound
    :return: the smallest qualifying item, or `np.inf`
    '''
    # Scan from the end so that, among equal minima, the last one in `seq`
    # is the one returned (same tie-breaking as a backward manual scan)
    return min(
        (item for item in reversed(seq) if val <= item < np.inf),
        default=np.inf,
    )


def norm_obs_clip(obs, clip_min=-1, clip_max=1, fixed_radius=0, normalize_to_range=False):
    '''
    Linearly rescale `obs` and clip the result to [clip_min, clip_max].

    Upper reference: `fixed_radius` when positive, otherwise one more than
    the largest item in (0, 1000), and at least 2.
    Lower reference: 0, or, with `normalize_to_range`, the smallest finite
    item >= 0; it is capped at the upper reference.
    The result is `(obs - lower) / |upper - lower|`, or `obs / upper`
    when both references coincide.

    :return: numpy array, clipped
    '''
    values = np.array(obs)
    upper = fixed_radius if fixed_radius > 0 else max(1, max_lt(obs, 1000)) + 1
    lower = min_gt(obs, 0) if normalize_to_range else 0
    if lower > upper:
        lower = upper
    if upper == lower:
        return np.clip(values / upper, clip_min, clip_max)
    span = np.abs(upper - lower)
    return np.clip((values - lower) / span, clip_min, clip_max)


def _split_node_into_feature_groups(node):
    '''
    Split the features of a single tree-observation node into three
    float arrays: distance-type data (6 values), the minimum distance to
    the target (1 value) and agent-related data (4 values).
    '''
    def gather(*fields):
        # Copy the named attributes of `node`, in order, into a float array
        values = np.zeros(len(fields))
        for index, field in enumerate(fields):
            values[index] = getattr(node, field)
        return values

    data = gather(
        'dist_own_target_encountered',
        'dist_other_target_encountered',
        'dist_other_agent_encountered',
        'dist_potential_conflict',
        'dist_unusable_switch',
        'dist_to_next_branch',
    )
    distance = gather('dist_min_to_target')
    agent_data = gather(
        'num_agents_same_direction',
        'num_agents_opposite_direction',
        'num_agents_malfunctioning',
        'speed_min_fractional',
    )
    return data, distance, agent_data


def _split_subtree_into_feature_groups(node, current_tree_depth, max_tree_depth):
    '''
    Recursively flatten the subtree rooted at `node` into the three feature
    groups produced by `_split_node_into_feature_groups`, visiting children
    in `TreeObsForRailEnv.tree_explored_actions_char` order.

    A missing child (`-np.inf`) is padded with `-np.inf` for every node of
    the absent 4-ary subtree down to `max_tree_depth`.
    '''
    if node == -np.inf:
        remaining_depth = max_tree_depth - current_tree_depth
        # 1 + 4 + ... + 4**remaining_depth nodes in a full 4-ary subtree
        num_remaining_nodes = int((4 ** (remaining_depth + 1) - 1) / (4 - 1))
        padding = [-np.inf] * num_remaining_nodes
        return padding * 6, padding, padding * 4

    data, distance, agent_data = _split_node_into_feature_groups(node)
    if not node.childs:
        return data, distance, agent_data

    buckets = ([data], [distance], [agent_data])
    for direction in TreeObsForRailEnv.tree_explored_actions_char:
        child_groups = _split_subtree_into_feature_groups(
            node.childs[direction], current_tree_depth + 1, max_tree_depth
        )
        for bucket, piece in zip(buckets, child_groups):
            bucket.append(piece)
    return tuple(np.concatenate(bucket) for bucket in buckets)


def split_tree_into_feature_groups(tree, max_tree_depth):
    '''
    Flatten a whole tree observation into three arrays (distance data,
    minimum target distance, agent data): the root first, then each child
    subtree in `TreeObsForRailEnv.tree_explored_actions_char` order.
    '''
    root_groups = _split_node_into_feature_groups(tree)
    buckets = tuple([group] for group in root_groups)
    for direction in TreeObsForRailEnv.tree_explored_actions_char:
        child_groups = _split_subtree_into_feature_groups(
            tree.childs[direction], 1, max_tree_depth
        )
        for bucket, piece in zip(buckets, child_groups):
            bucket.append(piece)
    return tuple(np.concatenate(bucket) for bucket in buckets)


def normalize_tree_obs(observation, tree_depth, radius):
    '''
    Flatten and normalize a tree observation for the RL agent.

    Distance data is scaled with the fixed `radius`, the target distance is
    scaled to its own range, and agent data is only clipped; every group
    ends up in [T_CLIP_MIN, T_CLIP_MAX].

    :return: 1-D numpy array (data, then distance, then agent data)
    '''
    data, distance, agent_data = split_tree_into_feature_groups(
        observation, tree_depth
    )
    clip_bounds = dict(clip_min=T_CLIP_MIN, clip_max=T_CLIP_MAX)
    scaled_data = norm_obs_clip(data, fixed_radius=radius, **clip_bounds)
    scaled_distance = norm_obs_clip(
        distance, normalize_to_range=True, **clip_bounds
    )
    clipped_agent_data = np.clip(agent_data, T_CLIP_MIN, T_CLIP_MAX)
    return np.concatenate((scaled_data, scaled_distance, clipped_agent_data))

## Custom binary tree observation

'''
Observation:
    - Structure:
        * Tensor of shape (1 + max_deviations, max_depth, features), where max_depth
          is the maximum number of nodes in the packed graph to consider and
          features is the total amount of features for each node
        * The feature matrix contains the features of the nodes in the shortest path
          as the first row and the features of the nodes in the deviation paths
          (which are exactly max_depth - 1) as the following rows
        * The feature matrix is then shaped like a binary tree, where branches identify
          choices, i.e. CHOICE_LEFT or CHOICE_RIGHT
    - Features:
        1. Number of agents (going in my direction) identified in the subpath
           from the root up to each node in the path
        2. Number of agents (going in a direction different from mine) identified
           in the subpath from the root up to each node in the path
        3. Number of malfunctioning agents (going in my direction) identified in the subpath
           from the root up to each node in the path
        4. Number of malfunctioning agents (going in a direction different from mine) identified
           in the subpath from the root up to each node in the path
        5. Minimum distances from an agent to other agents (going in my direction)
           in each edge of the path
        6. Minimum distances from an agent to other agents (going in a direction
           different from mine) in each edge of the path
        7. Maximum number of malfunctioning turns of other agents (going in my direction),
           in each edge of the path
        8. Maximum number of malfunctioning turns of other agents (going in a direction
           different from mine), in each edge of the path
        9. Distance to the target from each node in the path
        10. Path weights in turns to reach the given node from the root one
        11. Number of agents using the node to reach their target in the shortest path
        12. Number of agents in deadlock in the previous path, assuming that all the
            other agents follow their shortest path
        13. How many turns before a possible deadlock
        14. If the node is a fork or not
        15. Number of consecutive turns in which the stop action has been selected
'''

# SpeedData fields:
# - `times`: total number of turns an agent needs to cross a cell
# - `remaining`: turns still needed by the agent to leave its current cell
SpeedData = namedtuple('SpeedData', ('times', 'remaining'))

# Node fields:
# - `position`: location of the node in the railway
# - `features`: features attached to the node
# - `left`: left child
# - `right`: right child
Node = namedtuple('Node', ('position', 'features', 'left', 'right'))


class BinaryTreeObservator(ObservationBuilder):
    '''
    Observation builder that encodes, for each agent, a binary tree of depth
    `max_depth` rooted at the agent's next node in the packed railway graph.

    Every tree node carries `observation_dim` features computed along the
    agent's shortest path and along its deviation paths (as returned by
    `predictor`). The tree is then linearized in pre-order
    (node, left subtree, right subtree) into a single flat array.
    '''

    def __init__(self, max_depth, predictor):
        '''
        Arguments:
        - `max_depth`: depth of the observation tree, i.e. the number of
          nodes considered along each predicted path
        - `predictor`: object exposing `get_many()`, `reset()`, `set_env()`
          and `max_deviations`, used to compute shortest and deviation paths
        '''
        super().__init__()
        self.max_depth = max_depth
        self.predictor = predictor
        self.observations = dict()
        self.observation_dim = 15

    def _init_agents(self):
        '''
        Store agent-related info:
        - `speed_data`: a SpeedData object for each agent
        - `agent_handles`: set of agent handles
        - `other_agents`: list of other agent's handles for each agent
        - `last_nodes`: list of last visited nodes for each agent
          (along with corresponding weights)
        '''
        self.agent_handles = set(self.env.get_agent_handles())
        self.other_agents = dict()
        self.speed_data = dict()
        self.last_nodes = []
        for handle, agent in enumerate(self.env.agents):
            times_per_cell = int(np.reciprocal(agent.speed_data["speed"]))
            self.speed_data[handle] = SpeedData(
                times=times_per_cell, remaining=0
            )
            self.other_agents[handle] = self.agent_handles - {handle}
            agent_position = self.env.railway_encoding.get_agent_cell(handle)
            prev_node, prev_weight = self.env.railway_encoding.previous_node(
                agent_position
            )
            # Edge weights are expressed in cells: scale them to turns
            self.last_nodes.append(
                (prev_node, prev_weight * times_per_cell)
            )

    def reset(self):
        '''
        Re-initialize per-agent bookkeeping and reset the predictor (if any)
        '''
        self._init_agents()
        if self.predictor is not None:
            self.predictor.reset()

    def set_env(self, env):
        '''
        Attach the environment to this observator and to its predictor
        '''
        super().set_env(env)
        if self.predictor:
            self.predictor.set_env(self.env)

    def _update_shortest(self, handle, prediction):
        '''
        Store shortest paths, shortest positions and shortest cumulative weights
        for the current observation of the given agent
        '''
        # Update speed data: turns still needed to leave the current cell
        # (only meaningful for agents slower than one cell per turn)
        remaining_turns_in_cell = 0
        if self.env.agents[handle].speed_data["speed"] < 1.0:
            remaining_turns_in_cell = int(
                (1 - np.clip(self.env.agents[handle].speed_data["position_fraction"], 0.0, 1.0)) /
                self.env.agents[handle].speed_data["speed"]
            )
        self.speed_data[handle] = SpeedData(
            times=self.speed_data[handle].times,
            remaining=remaining_turns_in_cell
        )

        # Update shortest paths
        shortest_path = np.array(prediction.path, np.dtype('int, int, int'))
        self._shortest_paths[handle, :shortest_path.shape[0]] = shortest_path

        # Update shortest positions (path nodes without the direction component)
        shortest_positions = np.array(
            [node[:-1] for node in prediction.path], np.dtype('int, int')
        )
        self._shortest_positions[handle, :shortest_positions.shape[0]] = (
            shortest_positions
        )

        # Update shortest cumulative weights
        self._shortest_cum_weights[handle] = self.compute_cumulative_weights(
            handle, prediction.lenght, prediction.edges, remaining_turns_in_cell
        )

        # Update last visited node and last positions
        prev_node, prev_weight = self.env.railway_encoding.previous_node(
            prediction.path[0]
        )
        self.last_nodes[handle] = (
            prev_node, prev_weight * self.speed_data[handle].times)

    def get_many(self, handles=None):
        '''
        Run the predictor once for every agent, cache the shortest-path data
        of each agent that has a prediction, then delegate to
        `ObservationBuilder.get_many` to build the requested observations
        '''
        self.predictions = self.predictor.get_many()
        num_handles = len(self.agent_handles)
        # -1 marks unused (padding) slots
        self._shortest_paths = np.full(
            (num_handles, self.max_depth), -1, np.dtype('int, int, int')
        )
        self._shortest_positions = np.full(
            (num_handles, self.max_depth), -1, np.dtype('int, int')
        )
        self._shortest_cum_weights = np.zeros((num_handles, self.max_depth))
        for handle, prediction in self.predictions.items():
            # A None prediction means the agent is at its target
            if prediction is not None:
                shortest_prediction = prediction[0]
                self._update_shortest(handle, shortest_prediction)

        return super().get_many(handles)

    def get(self, handle=0):
        '''
        Build the observation of the given agent.

        Features are recomputed only when the agent has a prediction and
        either faces a real decision or is READY_TO_DEPART; otherwise a flat
        array filled with `BT_UNDER` is returned. The result has
        (2 ** max_depth - 1) * observation_dim entries.
        '''
        dim = sum(2 ** i for i in range(self.max_depth)) * self.observation_dim
        self.observations[handle] = np.full(dim, BT_UNDER)
        # One feature matrix for the shortest path (index 0) plus one
        # per deviation path (indices 1..max_deviations)
        features = np.full(
            (
                1 + self.predictor.max_deviations,
                self.max_depth, self.observation_dim
            ), -np.inf
        )

        # Compute features if necessary
        if (self.predictions[handle] is not None and (
                self.env.railway_encoding.is_real_decision(handle) or
                self.env.agents[handle].status == TrainState.READY_TO_DEPART)):
            shortest_path_prediction, deviation_paths_prediction = self.predictions[handle]
            packed_positions, packed_weights = self._get_shortest_packed_positions()
            shortest_feats = self._fill_path_values(
                handle, shortest_path_prediction, packed_positions, packed_weights
            )
            # Columns 0-3 hold the cumulative agent counters
            prev_num_agents = shortest_feats[:, :4]
            features[0, :, :] = shortest_feats

            # Compute deviation paths features: the i-th deviation starts
            # from the i-th shortest-path node, so it inherits the counters
            # accumulated on the shortest path up to node i - 1
            for i, deviation_prediction in enumerate(deviation_paths_prediction):
                prev_deadlocks = 0
                prev_num_agents_values = None
                if i >= 1:
                    # Column 11 holds the deadlock counter
                    prev_deadlocks = shortest_feats[i - 1, 11]
                    prev_num_agents_values = prev_num_agents[i - 1, :]
                dev_feats = self._fill_path_values(
                    handle, deviation_prediction, packed_positions, packed_weights,
                    turns_to_deviation=self._shortest_cum_weights[handle, i],
                    prev_deadlocks=prev_deadlocks, prev_num_agents=prev_num_agents_values,
                    deviation=True
                )
                features[i + 1, :, :] = dev_feats

            # Normalize features
            features = normalize_binary_tree_obs(
                features,
                self.env.railway_encoding.remaining_agents(),
                self.env.malfunction_generator.get_process_data().max_duration,
                self.env.params.observator.binary_tree.radius
            )

            # Build the binary tree
            binary_tree = self.get_agent_binary_tree(
                handle, self.predictions[handle], features
            )

            # Linearize the binary tree and store it as an observation
            self.observations[handle] = self.concat_nodes(binary_tree)

        return self.observations[handle]

    def _fill_path_values(self, handle, prediction, packed_positions, packed_weights,
                          turns_to_deviation=0, prev_deadlocks=0, prev_num_agents=None, deviation=False):
        '''
        Compute observations for the given prediction and return
        a suitable feature matrix of shape (max_depth, observation_dim)
        '''
        # Adjust weights and positions based on which kind of path
        # we are analyzing (shortest or deviation)
        path = prediction.path
        if not deviation:
            path_weights = self._shortest_cum_weights[handle]
            positions = packed_positions[handle].tolist()[:len(path)]
            positions_weights = packed_weights[handle]
        else:
            path_weights = np.array(
                self.compute_cumulative_weights(
                    handle, prediction.lenght, prediction.edges, turns_to_deviation
                )
            )
            positions = [node[:-1] for node in path]
            positions_weights = path_weights

        # Compute features
        num_agents, agent_distances, malfunctions = self.agents_in_path(
            handle, path, path_weights, prev_num_agents=prev_num_agents
        )
        target_distances = self.distance_from_target(
            handle, prediction.lenght, path, path_weights, turns_to_deviation
        )
        c_nodes = self.common_nodes(handle, positions)
        deadlocks, deadlock_distances = self.find_deadlocks(
            handle, positions, positions_weights, packed_positions, packed_weights,
            prev_deadlocks=prev_deadlocks
        )
        are_forks = self.compute_is_fork(path)
        stop_actions = np.full(are_forks.shape, self.env.stop_actions[handle])

        # Build the feature matrix: 4 + 2 + 2 + 7 = 15 rows, then transpose
        # so that each row describes one node of the path
        feature_matrix = np.vstack([
            num_agents, agent_distances, malfunctions,
            target_distances, path_weights, c_nodes, deadlocks, deadlock_distances,
            are_forks, stop_actions
        ]).T

        return feature_matrix

    def get_binary_tree(self, position, depth, prediction, features, choices=None):
        '''
        Recursive function that builds a binary tree starting from the given
        position and adds the correct feature set to each node.

        `choices` is the sequence of left/right choices leading from the root
        to `position` (None for the root). Returns None once `depth` reaches 0.
        '''
        # Avoid a shared mutable default argument
        if choices is None:
            choices = []
        if depth == 0:
            return None
        children = {"left": None, "right": None}
        if self.env.railway_encoding.is_node(position, unpacked=False):
            successors = self.env.railway_encoding.get_successors(
                position, unpacked=False
            )
            for succ in successors:
                choice = self.env.railway_encoding.get_edge_data(
                    position, succ, 'choice', unpacked=False
                )
                if choice == RailEnvChoices.CHOICE_LEFT:
                    children["left"] = succ
                elif choice == RailEnvChoices.CHOICE_RIGHT:
                    children["right"] = succ
        return Node(
            position=position,
            features=self.get_node_features(
                prediction, features, choices, depth
            ),
            left=self.get_binary_tree(
                children["left"], depth - 1, prediction, features,
                choices=choices + [RailEnvChoices.CHOICE_LEFT]
            ),
            right=self.get_binary_tree(
                children["right"], depth - 1, prediction, features,
                choices=choices + [RailEnvChoices.CHOICE_RIGHT]
            )
        )

    def get_node_features(self, prediction, features, choices, depth):
        '''
        Logically traverse the binary tree based on the given sequence of choices
        and extract the features at that level in the feature matrix.

        The node matches the shortest path if `choices` equals the shortest
        path's choices; it matches the i-th deviation path if `choices` equals
        the first `i` shortest-path choices followed by that deviation's own
        choices. Otherwise `BT_LOWER` filling values are returned.
        '''
        sp_prediction, dp_predictions = prediction
        pos = self.max_depth - depth

        # Root node
        if pos == 0:
            return features[0, 0, :]

        # Non-root node: choices that follow the shortest path down to `pos`
        sp_edges = [edge[2]['choice'] for edge in sp_prediction.edges[:pos]]

        # Node on shortest path
        if choices == sp_edges:
            return features[0, pos, :]

        # Node on deviation path
        for i, dp_prediction in enumerate(dp_predictions[:pos]):
            # A deviation starting past the end of the shortest path cannot be
            # reached through it (the previous version raised IndexError here)
            if i > len(sp_edges):
                break
            dp_edges = sp_edges[:i] + [
                edge[2]['choice'] for edge in dp_prediction.edges[:pos - i]
            ]
            if choices == dp_edges:
                return features[i + 1, pos - i, :]

        # Fallback to default filling values
        return np.full(self.observation_dim, BT_LOWER)

    def get_agent_binary_tree(self, handle, prediction, features):
        '''
        Build the observation binary tree for the given agent and fill
        it with the given features
        '''
        position = self.env.railway_encoding.get_agent_cell(handle)
        node, _ = self.env.railway_encoding.next_node(position)
        return self.get_binary_tree(node, self.max_depth, prediction, features)

    def concat_nodes(self, node):
        '''
        Linearize the given binary tree features in a single array
        (pre-order: node, left subtree, right subtree).
        An empty tree yields an empty array.
        '''
        if node is None:
            return np.empty(0)
        return np.concatenate((
            node.features,
            self.concat_nodes(node.left),
            self.concat_nodes(node.right)
        ))

    def compute_is_fork(self, path):
        '''
        Given a path, returns for each node if it is a fork or not
        (-inf for slots beyond the end of the path)
        '''
        are_forks = np.full((self.max_depth,), -np.inf)
        for ind, node in enumerate(path):
            are_forks[ind] = self.env.railway_encoding.is_fork(node)
        return are_forks

    def compute_cumulative_weights(self, handle, lenght, edges, initial_distance):
        '''
        Given a list of edges, compute the cumulative sum of weights,
        representing the number of turns the given agent must perform
        to reach each node in the path.

        The first entry is `initial_distance`; slots beyond the path are 0,
        or inf when `lenght` is infinite.
        '''
        np_weights = np.zeros((self.max_depth,))
        if lenght == np.inf:
            np_weights = np.full((self.max_depth,), np.inf)
        # Edge weights are in cells: scale them to turns for this agent
        weights = [initial_distance] + [
            e[2]['weight'] * self.speed_data[handle].times for e in edges
        ]
        np_weights[:len(weights)] = np.cumsum(weights)
        return np_weights

    def agents_in_path(self, handle, path, cum_weights, prev_num_agents=None):
        '''
        Return three arrays:
        - Number of agents identified in the subpath from the root up to
          each node in the path (in both directions and both malfunctioning or not),
          with shape (4, max_depth)
        - Minimum distances from an agent to other agent's
          in each edge of the path (in both directions), with shape (2, max_depth)
        - Maximum turns that an agent has to wait because it is malfunctioning,
          in each edge of the path (in both directions), with shape (2, max_depth)

        The directions are considered as:
        - Same direction, if two agents "follow" each other
        - Other direction, otherwise
        '''
        num_agents = np.zeros((self.max_depth, 4))
        if prev_num_agents is not None:
            num_agents[:] = np.array(prev_num_agents)
        distances = np.full((self.max_depth, 2), np.inf)
        malfunctions = np.zeros((self.max_depth, 2))

        # For each agent different than myself
        for agent in self.other_agents[handle]:
            position = self.env.railway_encoding.get_agent_cell(agent)
            # Check if agent is not DONE_REMOVED (position would be None)
            if position is not None:
                # Take the other agent's next node in the packed graph
                node, next_node_distance = self.env.railway_encoding.next_node(
                    position
                )
                # Take every possible direction for the given node in the packed graph
                nodes = self.env.railway_encoding.get_nodes((node[0], node[1]))
                # Check the next nodes of the next node in order to see
                # the other agent's entry direction
                next_nodes = self.env.railway_encoding.get_successors(node)

                # Check if one of the next nodes of the other agent are in my path
                for other_node in nodes:
                    index = get_index(path, other_node)
                    if index is not None:
                        # Initialize distances
                        distance = cum_weights[index]
                        if cum_weights[index] < self.speed_data[handle].times:
                            distance = (
                                self.speed_data[handle].remaining -
                                self.speed_data[handle].times
                            )
                        turns_to_reach_other_agent = abs(
                            (next_node_distance - (self.speed_data[agent].remaining / self.speed_data[agent].times)) *
                            self.speed_data[handle].times
                        )

                        # Check if same direction or other direction
                        different_node = other_node != node
                        more_than_one_choice = len(next_nodes) > 1
                        last_node_in_path = len(path) <= index + 1
                        different_one_choice = (
                            not last_node_in_path and
                            len(next_nodes) > 0 and
                            next_nodes[0] != path[index + 1]
                        )
                        if (different_node and (more_than_one_choice or last_node_in_path or different_one_choice)):
                            direction = 1
                        else:
                            turns_to_reach_other_agent = -turns_to_reach_other_agent
                            direction = 0

                        # Update number of agents
                        num_agents[index:len(path), direction] += 1

                        # Update distances s.t. we always keep the greatest one (if distance is negative),
                        # otherwise the minimum one (if distance is positive)
                        distance += turns_to_reach_other_agent
                        if ((distances[index, direction] == np.inf) or
                            (distance >= 0 and distances[index, direction] > distance) or
                                (distance <= 0 and distances[index, direction] < distance) or
                                (distance >= 0 and distances[index, direction] < 0)):
                            distances[index, direction] = distance

                        # Update malfunctions
                        malfunction = self.env.agents[agent].malfunction_data['malfunction']
                        if malfunction > 0:
                            num_agents[index:len(path), direction + 2] += 1
                        if malfunctions[index, direction] < malfunction:
                            malfunctions[index, direction] = malfunction
                        # Count each other agent at most once
                        break

        return np.transpose(num_agents), np.transpose(distances), np.transpose(malfunctions)

    def distance_from_target(self, handle, lenght, path, cum_weights, turns_to_deviation=0):
        '''
        For a shortest path:
        - `lenght` should be the actual length of the shortest path
        - `cum_weights` should be the cumulative number of turns to reach each node
        - `turns_to_deviation` should be zero

        For a deviation path:
        - `lenght` should be the actual length of the deviation path
        - `cum_weights` should be the cumulative number of turns to reach each node
          (starting from the agent's position instead of the root of the deviation path)
        - `turns_to_deviation` should be the number of turns required to reach the root
           of the deviation path

        Returns the actual distance from each node of the path to its target
        '''
        # If the agent cannot arrive to the target
        if lenght == np.inf:
            return np.full((self.max_depth,), np.inf)

        # Initialize each node with the distance from the agent to the target
        distances = np.zeros((self.max_depth,))
        max_distance = (
            (lenght * self.speed_data[handle].times)
            + turns_to_deviation
        )
        distances[:len(path)] = np.full((len(path),), max_distance)

        # Compute actual distances for each node
        distances -= cum_weights
        return distances

    def common_nodes(self, handle, positions):
        '''
        Given an agent's positions and the shortest positions for every other agent,
        compute the number of agents intersecting at each node
        '''
        c_nodes = np.zeros((self.max_depth,))
        if len(positions) > 0:
            nd_positions = np.array(positions, np.dtype('int, int'))
            computed = np.zeros((len(positions),))
            for row in self.other_agents[handle]:
                # Each other agent contributes at most 1 to every position of
                # mine that also appears in its shortest positions
                computed += np.isin(
                    nd_positions, self._shortest_positions[row, :]
                )
            c_nodes[:computed.shape[0]] = computed
        return c_nodes

    def _get_shortest_packed_positions(self):
        '''
        For each agent's shortest path, substitute the first node for
        its previous node in the packed graph, if it doesn't
        already match with the agent's position

        Return the modified path (without the direction component),
        along with the associated cumulative weights (which are re-computed
        starting from the original cumulative weights)
        '''
        prev_weights = []
        prev_nodes = [node[0] for node in self.last_nodes]
        for agent, path in enumerate(self._shortest_paths):
            # If the agent's position is not on the packed graph, the previous
            # node lies behind it: its weight is negative (turns ago)
            if tuple(path[0]) != prev_nodes[agent]:
                prev_weights.append(
                    - (self.last_nodes[agent][1] +
                       self.speed_data[agent].times -
                       self.speed_data[agent].remaining)
                )
            # If the agent's position is already in the packed path,
            # do not change the cumulative weights of the first node
            else:
                prev_weights.append(self._shortest_cum_weights[agent, 0])

        # Remove the first column of the original shortest positions
        # and replace it with the previous node
        packed_positions = np.hstack([
            np.array(
                [node[:-1] for node in prev_nodes], np.dtype('int, int')
            ).reshape(self._shortest_positions.shape[0], 1),
            self._shortest_positions[:, 1:]
        ])
        # Update the corresponding cumulative weights
        packed_weights = np.hstack([
            np.array(prev_weights).reshape(
                self._shortest_cum_weights.shape[0], 1
            ),
            self._shortest_cum_weights[:, 1:]
        ])
        return packed_positions, packed_weights

    def find_deadlocks(self, handle, positions, cum_weights, packed_positions, packed_weights, prev_deadlocks=0):
        '''
        For a shortest path and a deviation path:
        - `positions` should be the packed positions
        - `cum_weights` should be the packed cumulative weights
        - `packed_positions` should be the list of packed shortest positions for each agent
        - `packed_weights` should be the list of packed cumulative weights for each agent

        Returns two lists:
        - `deadlocks`: the number of possible deadlocks for each node in `path`
        - `crash_turns`: the number of turns to the first deadlock for each node in `path`
        '''
        deadlocks = np.full((self.max_depth,), prev_deadlocks)
        crash_turns = np.full((self.max_depth,), np.inf)

        # For each agent different than myself
        for agent in self.other_agents[handle]:
            deadlock_found = False
            agent_path = packed_positions[agent].tolist()
            # For each node in the other agent's path
            for i in range(len(agent_path) - 1):
                # Avoid non-informative pair of nodes
                if tuple(agent_path[i]) != (-1, -1) and tuple(agent_path[i + 1]) != (-1, -1):
                    # For each node in my path
                    for j in range(len(positions) - 1):
                        source, dest = positions[j], positions[j + 1]
                        # The other agent traverses my edge in the opposite direction
                        from_dest_to_source = (
                            source == agent_path[i + 1] and
                            dest == agent_path[i]
                        )
                        # ...and the two traversal time intervals overlap
                        intersecting_turns = (
                            not cum_weights[j] > packed_weights[agent, i + 1] and
                            not cum_weights[j + 1] < packed_weights[agent, i]
                        )
                        deadlock_found = from_dest_to_source and intersecting_turns
                        if deadlock_found:
                            space = (
                                cum_weights[j + 1] - cum_weights[j]
                            ) / self.speed_data[handle].times

                            # Both agents in same edge: reduce space by how much they
                            # already have traversed
                            if cum_weights[j] < 0 and packed_weights[agent, i] < 0:
                                space += (
                                    cum_weights[j] /
                                    self.speed_data[handle].times
                                )
                                space += (
                                    packed_weights[agent, i] /
                                    self.speed_data[agent].times
                                )
                            # My entry turn is greater than the other agent's entry turn:
                            # reduce space by how the other agent's has already traversed,
                            # by the time my agent enters the edge
                            elif cum_weights[j] > packed_weights[agent, i]:
                                space -= abs(
                                    cum_weights[j] -
                                    abs(packed_weights[agent, i])
                                ) / self.speed_data[agent].times
                            # The opposite of the previous case
                            elif packed_weights[agent, i] > cum_weights[j]:
                                space += abs(
                                    packed_weights[agent, i] -
                                    abs(cum_weights[j])
                                ) / self.speed_data[agent].times

                            # Compute the distance in turns from my agent to
                            # the possible identified deadlock
                            crash_turn = np.ceil(
                                np.clip(cum_weights[j], 0, None) +
                                space / reciprocal_sum(
                                    self.speed_data[agent].times,
                                    self.speed_data[handle].times
                                )
                            )
                            # Store only the minimum distance
                            if crash_turns[j] > crash_turn:
                                crash_turns[j] = crash_turn

                            # Update number of deadlocks
                            deadlocks[j:len(positions)] += 1

                            # If one deadlock is found, do not check any other
                            # between the same pair of agents
                            break

                    # If one deadlock is found, do not check any other
                    # between the same pair of agents
                    if deadlock_found:
                        break

        return deadlocks, crash_turns
    
    
from torch_geometric.utils import add_remaining_self_loops
from torch_geometric.data import Data


class FOVObservator(ObservationBuilder):
    '''
    An Observator that returns a local observation of each agent in the form
    of an array of shape (observation_dim, max_depth, max_depth) centered
    around the agent position (-1 marks cells without information)

    Features:
        0. Cell type of the rail in the agent's FOV
        1. Cell orientation of the rail in the agent's FOV
        2. Distances in shortest path in the agent's FOV
        3. Agents positions in the agent's FOV (direction of each agent)
        4. Agents targets in the agent's FOV (1 agent target, 0 other agent, -1 otherwise)
        5. Agents malfunctioning turns
        6. Agents fractional speeds
        7+. Distances in deviation paths in the agent's FOV

    `get_many` additionally packs every agent's observation into a single
    PyTorch Geometric `Data` object (nodes are agents, edges come from the
    environment's agents adjacency matrix) and returns that same object for
    every handle
    '''

    def __init__(self, max_depth, predictor):
        super().__init__()
        # Always keep an odd number of "squares", so that the agent
        # is centered w.r.t. its FOV
        assert max_depth % 2 != 0, 'FOV window must be an odd number'
        self.max_depth = max_depth
        self.predictor = predictor
        self.observations = dict()
        # 7 fixed feature planes plus one plane per deviation path
        self.observation_dim = 7 + self.predictor.max_deviations
        self.possible_transitions_dict = self.compute_all_possible_transitions()
        self.agent_positions = None
        self.agent_malfunctions = None
        self.agent_speeds = None
        self.agent_targets = None
        self.data_object = None

    def reset(self):
        '''
        Reset the predictor and precompute the per-episode static maps:
        the (cell type, orientation) encoding of the rail and, for each
        agent, a (height, width) map marking its own target with 1, other
        agents' targets with 0 and every other cell with -1
        '''
        if self.predictor is not None:
            self.predictor.reset()
        rail_obs_16_channels = np.zeros((self.env.height, self.env.width, 16))
        for i in range(rail_obs_16_channels.shape[0]):
            for j in range(rail_obs_16_channels.shape[1]):
                # 16-bit transition bitmap of the cell, most significant bit first
                bitlist = [
                    int(digit) for digit in
                    bin(self.env.rail.get_full_transitions(i, j))[2:]
                ]
                bitlist = [0] * (16 - len(bitlist)) + bitlist
                rail_obs_16_channels[i, j] = np.array(bitlist)
        self.rail_obs = self.convert_transitions_map(rail_obs_16_channels)
        self.agent_targets = np.full(
            (self.env.get_num_agents(), self.env.rail.height, self.env.rail.width), -1
        )
        for handle, agent in enumerate(self.env.agents):
            target = agent.target
            if target is not None:
                self.agent_targets[handle, target[0], target[1]] = 1
                other_agents = set(self.env.get_agent_handles()) - {handle}
                for other in other_agents:
                    # Never overwrite an agent's own target marker (1)
                    if self.agent_targets[other, target[0], target[1]] == -1:
                        self.agent_targets[other, target[0], target[1]] = 0

    def set_env(self, env):
        super().set_env(env)
        if self.predictor:
            self.predictor.set_env(self.env)

    def convert_transitions_map(self, obs_transitions_map):
        '''
        Given an np.array of shape (env_height, env_width, 16),
        convert it to (env_height, env_width, 2) where the first channel
        encodes cell types (empty cell 0 is encoded as -1,
        while cell types 1 to 10 are encoded as-is)
        and the second channel orientations (0, 90, 180, 270 as 0, 1, 2, 3)
        '''
        new_transitions_map = np.full(
            (obs_transitions_map.shape[0], obs_transitions_map.shape[1], 2), -1
        )

        for i in range(obs_transitions_map.shape[0]):
            for j in range(obs_transitions_map.shape[1]):
                transition_bitmap = obs_transitions_map[i, j]
                # Read the bit vector back as a decimal integer (MSB first)
                int_transition_bitmap = int(
                    transition_bitmap.dot(
                        2 ** np.arange(transition_bitmap.size)[::-1]
                    )
                )
                if int_transition_bitmap != 0:
                    new_transitions_map[i, j] = (
                        self.possible_transitions_dict[int_transition_bitmap]
                    )

        return new_transitions_map

    def compute_all_possible_transitions(self):
        '''
        Map every transition bitmap (as a decimal int) that a
        RailEnvTransitions cell type can take under 0/90/180/270 degrees
        rotations to the array [cell type index, rotation index]; when
        several (type, rotation) pairs give the same bitmap, the first
        one found is kept
        '''
        transitions = RailEnvTransitions()
        transitions_with_rotation_dict = {}
        rotation_degrees = [0, 90, 180, 270]

        for index, transition in enumerate(transitions.transition_list):
            for rot_type, rot in enumerate(rotation_degrees):
                rot_transition = transitions.rotate_transition(transition, rot)
                if rot_transition not in transitions_with_rotation_dict:
                    transitions_with_rotation_dict[rot_transition] = (
                        np.array([index, rot_type])
                    )
        return transitions_with_rotation_dict

    def extract_path_fov(self, path, lenght, pad=0):
        '''
        Given a path, return a (max_depth, max_depth) matrix marking the
        positions it occupies, assuming the first position is the center
        of the FOV.

        Cells on the path are labelled with a countdown that starts at
        `lenght` and decreases by one each time the path moves to a new
        cell (repeated positions do not decrease it). Every other cell,
        or the whole matrix if `lenght` is infinite or `path` is empty,
        holds `pad`
        '''
        path_fov = np.full((self.max_depth, self.max_depth), pad)
        distance = lenght
        if len(path) == 0 or not distance < np.inf:
            return path_fov

        y, x = self.max_depth // 2, self.max_depth // 2
        prev_pos = path[0]
        for pos in path[1:]:
            if 0 <= y < self.max_depth and 0 <= x < self.max_depth:
                path_fov[y, x] = distance
            # Count every move, even outside the window, so that labels
            # stay correct if the path later re-enters the FOV
            if pos[0] != prev_pos[0] or pos[1] != prev_pos[1]:
                distance -= 1
            y += pos[0] - prev_pos[0]
            x += pos[1] - prev_pos[1]
            prev_pos = pos

        # Add last element
        if 0 <= y < self.max_depth and 0 <= x < self.max_depth:
            path_fov[y, x] = distance

        return path_fov

    def get_many(self, handles=None):
        '''
        Refresh the global agent maps and predictions, compute each agent's
        FOV observation, and pack them into a PyTorch Geometric `Data`
        object whose edges come from the environment's agents adjacency
        matrix (self loops added). The same `Data` object is returned for
        every agent handle
        '''
        self.agent_positions = np.full(
            (self.env.rail.height, self.env.rail.width), -1
        )
        self.agent_malfunctions = np.full(
            (self.env.rail.height, self.env.rail.width), -1
        )
        self.agent_speeds = np.full(
            (self.env.rail.height, self.env.rail.width), -1
        )
        if self.predictor is not None:
            self.predictions = self.predictor.get_many()
        for handle, agent in enumerate(self.env.agents):
            agent_position = self.env.railway_encoding.get_agent_cell(
                handle
            )
            if agent_position is not None:
                # agent_position is (row, col, direction)
                self.agent_positions[
                    agent_position[0], agent_position[1]
                ] = agent_position[2]
                # NOTE(review): `malfunction_data` / `speed_data` are the
                # pre-Flatland-3 agent attributes — confirm they exist on the
                # agents used here
                self.agent_malfunctions[
                    agent_position[0], agent_position[1]
                ] = agent.malfunction_data['malfunction']
                self.agent_speeds[
                    agent_position[0], agent_position[1]
                ] = agent.speed_data['speed']

        # Compute adjacency matrix and store it in a
        # PyTorch Geometric Data object
        adjacency = self.env.agents_adjacency_matrix(
            radius=self.max_depth
        )
        edge_index = torch.from_numpy(
            np.argwhere(adjacency != 0)
        ).long().t().contiguous()
        edge_weight = torch.from_numpy(
            adjacency[np.nonzero(adjacency)]
        ).float()
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value=1, num_nodes=self.env.get_num_agents()
        )

        # Add features to PyTorch Geometric Data object
        states = super().get_many(handles)
        # Stack the per-agent arrays into one ndarray before converting:
        # building a tensor from a list of ndarrays is very slow
        stacked_states = np.array(list(states.values()), dtype=np.float32)
        self.data_object = Data(
            edge_index=edge_index,
            edge_weight=edge_weight,
            num_nodes=self.env.get_num_agents(),
            states=torch.from_numpy(stacked_states)
        )
        self.observations = {
            handle: self.data_object for handle in self.env.get_agent_handles()
        }
        return self.observations

    def get(self, handle=0):
        '''
        Build the FOV observation of agent `handle`, store it in
        `self.observations[handle]` and return it.

        The array has shape (observation_dim, max_depth, max_depth) and is
        filled with -1; its planes (see the class docstring) are populated
        only when the agent has a prediction and occupies a cell.
        Requires `self.predictions`, which is computed by `get_many`
        '''
        obs = np.full(
            (self.observation_dim, self.max_depth, self.max_depth), -1
        )
        self.observations[handle] = obs
        if self.predictions[handle] is None:
            return obs
        agent_position = self.env.railway_encoding.get_agent_cell(handle)
        if agent_position is None:
            return obs

        shortest_pred, deviations_pred = self.predictions[handle]
        num_devs = self.predictor.max_deviations

        # Cell type of the rail in the agent's FOV
        cell_type = extract_fov(
            self.rail_obs[:, :, 0], agent_position, self.max_depth, -1
        )

        # Cell orientation of the rail in the agent's FOV
        cell_orientation = extract_fov(
            self.rail_obs[:, :, 1], agent_position, self.max_depth, -1
        )

        # Distances in shortest path in the agent's FOV
        path_fov = self.extract_path_fov(
            shortest_pred.positions, shortest_pred.lenght, pad=-1
        )

        # Distances in deviation paths in the agent's FOV
        dev_paths_fov = np.full(
            (num_devs, self.max_depth, self.max_depth), -1
        )
        for i, dev_pred in enumerate(deviations_pred):
            dev_path_fov = np.full((self.max_depth, self.max_depth), -1)
            if dev_pred.lenght < np.inf:
                # list.index raises ValueError when the element is missing
                # (it never returns -1)
                try:
                    source_index = shortest_pred.positions.index(
                        dev_pred.positions[0]
                    )
                except ValueError:
                    source_index = None
                if source_index is not None:
                    # Full path: shortest path up to the deviation point,
                    # followed by the deviation itself
                    dev_pos = (
                        shortest_pred.positions[:source_index] +
                        dev_pred.positions
                    )
                    dev_path_fov = self.extract_path_fov(
                        dev_pos, len(dev_pos), pad=-1
                    )
            dev_paths_fov[i] = dev_path_fov

        # Agents positions in the agent's FOV (direction of each agent)
        agents_fov = extract_fov(
            self.agent_positions, agent_position, self.max_depth, -1
        )

        # Agents targets in the agent's FOV (1 agent target, 0 other agent, -1 otherwise)
        targets_fov = extract_fov(
            self.agent_targets[handle, :, :],
            agent_position, self.max_depth, -1
        )

        # Malfunctioning turns
        malf_fov = extract_fov(
            self.agent_malfunctions, agent_position, self.max_depth, -1
        )

        # Speed information
        speed_fov = extract_fov(
            self.agent_speeds, agent_position, self.max_depth, -1
        )

        # Update observations
        obs[0] = cell_type
        obs[1] = cell_orientation
        obs[2] = path_fov
        obs[3] = agents_fov
        obs[4] = targets_fov
        obs[5] = malf_fov
        obs[6] = speed_fov
        obs[7: 7 + num_devs] = dev_paths_fov

        return obs
    
    
## Graph Observator 
class GraphObservator(ObservationBuilder):
    '''
    Observator that encodes the packed railway graph as a PyTorch Geometric
    `Data` object for each agent.

    The edge structure (edge_index, edge_weight) is built once per episode
    in `reset` and shared by all agents. Node features are:
        0. Distance from the node to the observing agent's target,
           read from the environment's distance map
        1. Whether the node is occupied by another agent (0/1)
    `pos` holds, in order, the node indexes of the left successor, the right
    successor and the agent's (packed) node, with -1 where not available
    '''

    def __init__(self, max_depth, predictor):
        super().__init__()
        self.max_depth = max_depth
        self.predictor = predictor
        self.observations = dict()
        self.observation_dim = 2
        self.data_object = None
        self.predictions = None

    def reset(self):
        '''
        Rebuild the shared graph structure and reset the predictor, if any
        '''
        self.data_object = self._init_graph()
        if self.predictor is not None:
            self.predictor.reset()

    def set_env(self, env):
        super().set_env(env)
        if self.predictor:
            self.predictor.set_env(self.env)

    def get_many(self, handles=None):
        '''
        Refresh the predictions (when a predictor is configured) and return
        the per-agent observations computed by `get`
        '''
        # The predictor is optional (see reset), so only query it if present
        if self.predictor is not None:
            self.predictions = self.predictor.get_many()
        return super().get_many(handles)

    def get(self, handle=0):
        '''
        Build the graph observation of agent `handle`: node features `x`,
        the shared edges, and the `pos` indexes used to read the GNN
        embeddings of the relevant nodes
        '''
        encoding = self.env.railway_encoding

        # Compute node features
        nodes = encoding.get_graph_nodes(unpacked=False, data=True)
        # The distance map is the same for every node: fetch it only once
        distance_map = self.env.distance_map.get()
        agents_positions = {
            encoding.get_agent_cell(h)
            for h in range(len(self.env.agents)) if h != handle
        }
        x = [None] * len(nodes)
        for n, _ in nodes:
            target_distance = distance_map[handle, n[0], n[1], n[2]]
            is_occupied = n in agents_positions
            x[encoding.node_to_index[n]] = [
                # d["is_dead_end"], d["is_fork"], d["is_join"], d["is_target"],
                target_distance, is_occupied
            ]
        x = torch.tensor(x, dtype=torch.float)

        # Store a list of important positions, so that the DQN is called with
        # the GNN embeddings of these nodes
        agent_position = encoding.get_agent_cell(handle)
        agent_pos_index = -1
        successors = []
        if agent_position is not None:
            agent_in_packed = encoding.is_node(agent_position, unpacked=False)
            if agent_in_packed:
                successors = encoding.get_successors(
                    agent_position, unpacked=False
                )
            else:
                # The agent is between packed nodes: use the previous packed
                # node as its position and the next one as its only successor
                actual_agent_position = tuple(agent_position)
                agent_position, _ = encoding.previous_node(
                    actual_agent_position
                )
                successor, _ = encoding.next_node(actual_agent_position)
                successors = [successor]
            agent_pos_index = encoding.node_to_index[agent_position]

        successors_indexes = {"left": -1, "right": -1}
        for succ in successors:
            succ_index = encoding.node_to_index[succ]
            succ_choice = encoding.get_edge_data(
                agent_position, succ, 'choice', unpacked=False
            )
            if succ_choice == RailEnvChoices.CHOICE_LEFT:
                successors_indexes["left"] = succ_index
            elif succ_choice == RailEnvChoices.CHOICE_RIGHT:
                successors_indexes["right"] = succ_index
        pos = torch.tensor([
            successors_indexes["left"],
            successors_indexes["right"],
            agent_pos_index
        ], dtype=torch.long)

        # Create a PyTorch Geometric Data object
        self.observations[handle] = Data(
            edge_index=self.data_object.edge_index,
            edge_weight=self.data_object.edge_weight,
            pos=pos, x=x
        )
        return self.observations[handle]

    def _init_graph(self):
        '''
        Initialize the graph structure, which is the same
        for all agents in an episode: edge_index has shape
        (2, num_edges) and edge_weight shape (num_edges,)
        '''
        # Compute edges and edges attributes
        edges = self.env.railway_encoding.get_graph_edges(
            unpacked=False, data=True
        )
        edge_index, edge_weight = [], []
        for u, v, d in edges:
            edge_index.append([
                self.env.railway_encoding.node_to_index[u],
                self.env.railway_encoding.node_to_index[v]
            ])
            edge_weight.append(d['weight'])
        # view(-1, 2) keeps the (2, 0) shape even when there are no edges
        edge_index = torch.tensor(
            edge_index, dtype=torch.long
        ).view(-1, 2).t().contiguous()
        edge_weight = torch.tensor(edge_weight, dtype=torch.float)
        return Data(
            edge_index=edge_index, edge_weight=edge_weight
        )
        
## UTILS ##

# Observation builder class to use for each policy/observation type
OBSERVATORS = dict(
    tree=TreeObsForRailEnv,
    binary_tree=BinaryTreeObservator,
    graph=GraphObservator,
    decentralized_fov=FOVObservator,
)
# Predictor class paired with each observation type
PREDICTORS = dict(
    tree=ShortestPathPredictorForRailEnv,
    binary_tree=ShortestDeviationPathPredictor,
    graph=NullPredictor,
    decentralized_fov=ShortestDeviationPathPredictor,
)


class RailEnvChoices(IntEnum):
    '''
    High-level choices available to an agent: take the left branch,
    take the right branch, or stop
    '''

    CHOICE_LEFT = 0
    CHOICE_RIGHT = 1
    STOP = 2

    @staticmethod
    def value_of(value):
        '''
        Return the RailEnvChoices member matching `value`, which can be
        either an integer choice value (e.g. 2) or a member name in any
        letter case (e.g. "stop", "choice_left").
        Return None if no member matches
        '''
        if isinstance(value, str):
            return RailEnvChoices.__members__.get(value.upper())
        try:
            return RailEnvChoices(value)
        except (ValueError, TypeError):
            return None

    @staticmethod
    def values():
        '''
        Return a list of every possible RailEnvChoices
        '''
        return list(RailEnvChoices)

    @staticmethod
    def choice_size():
        '''
        Return the number of values that can be assigned
        to a RailEnvChoices instance
        '''
        return len(RailEnvChoices.values())

    @staticmethod
    def default_choices():
        '''
        Return a mask of choices (one boolean per member, in definition
        order), s.t. the only choice that can always be applied is STOP
        '''
        return [choice == RailEnvChoices.STOP for choice in RailEnvChoices]


def get_num_actions():
    '''
    Return how many members RailEnvActions defines, i.e. the size of the
    low-level action space
    '''
    return len(RailEnvActions.__members__)


def create_rail_env(args, load_env=""):
    '''
    Build a RailEnvWrapper object with the specified parameters,
    as described in the .yml file (`args`).

    If `load_env` is a non-empty path, the rail is read from that file,
    otherwise it is generated with `sparse_rail_generator`. Predictor and
    observator classes are picked from PREDICTORS / OBSERVATORS using
    `args.policy.type`. Malfunctions and variable speeds are enabled
    according to `args.env`
    '''
    # Check if an environment file is provided
    if load_env:
        rail_generator = rail_from_file(load_env)
    else:
        rail_generator = sparse_rail_generator(
            max_num_cities=args.env.max_cities,
            grid_mode=args.env.grid,
            max_rails_between_cities=args.env.max_rails_between_cities,
            seed=args.env.seed
        )

    # Build predictor and observator
    obs_type = args.policy.type.get_true_key()
    predictor_cls = PREDICTORS[obs_type]
    if predictor_cls is ShortestDeviationPathPredictor:
        # NOTE(review): max_deviations is read from args.predictor.max_depth
        # here but from args.predictor.max_deviations below — confirm intended
        predictor = predictor_cls(
            max_depth=args.observator.max_depth,
            max_deviations=args.predictor.max_depth
        )
    elif predictor_cls is ShortestPathPredictorForRailEnv:
        # Flatland's built-in predictor only accepts `max_depth`
        predictor = predictor_cls(max_depth=args.predictor.max_depth)
    else:
        predictor = predictor_cls(
            max_depth=args.predictor.max_depth,
            max_deviations=args.predictor.max_deviations
        )
    observator = OBSERVATORS[obs_type](args.observator.max_depth, predictor)

    # Initialize malfunctions
    malfunctions = None
    if args.env.malfunctions.enabled:
        malfunctions = ParamMalfunctionGen(
            MalfunctionParameters(
                malfunction_rate=args.env.malfunctions.rate,
                min_duration=args.env.malfunctions.min_duration,
                max_duration=args.env.malfunctions.max_duration
            )
        )

    # Initialize agents speeds: when enabled, speeds 1, 1/2, 1/3 and 1/4
    # are equally likely
    speed_map = None
    if args.env.variable_speed:
        speed_map = {
            1.: 0.25,
            1. / 2.: 0.25,
            1. / 3.: 0.25,
            1. / 4.: 0.25
        }
    line_generator = sparse_line_generator(
        speed_map, seed=args.env.seed
    )

    # Build the environment; the line generator must be passed explicitly,
    # otherwise the speed map and seed above are never used
    return RailEnvWrapper(
        params=args,
        width=args.env.width,
        height=args.env.height,
        rail_generator=rail_generator,
        line_generator=line_generator,
        number_of_agents=args.env.num_trains,
        obs_builder_object=observator,
        malfunction_generator=malfunctions,
        remove_agents_at_target=True,
        random_seed=args.env.seed
    )


def create_save_env(path, width, height, num_trains, max_cities,
                    max_rails_between_cities, max_rails_in_cities, grid=False, seed=0):
    '''
    Create a RailEnv environment with the given settings and save it to
    `path` via RailEnvPersister.

    `seed` is used both for the rail generator and as the environment's
    random seed, as done in create_rail_env
    '''
    rail_generator = sparse_rail_generator(
        max_num_cities=max_cities,
        seed=seed,
        grid_mode=grid,
        max_rails_between_cities=max_rails_between_cities,
        max_rails_in_city=max_rails_in_cities,
    )
    env = RailEnv(
        width=width,
        height=height,
        rail_generator=rail_generator,
        number_of_agents=num_trains,
        random_seed=seed
    )
    # Flatland generates the rail and the agents in reset(), not in the
    # constructor, so reset before persisting
    env.reset()
    RailEnvPersister.save(env, path)


def get_seed(env, seed=None):
    '''
    Let the environment's `_seed` method pick (or apply) a random seed and
    return the first value it reports
    '''
    return env._seed(seed)[0]


def copy_obs(obs):
    '''
    Return a copy of the given observation: the object's own `copy()`
    method is used when it has one, otherwise `copy.deepcopy`
    '''
    if not hasattr(obs, "copy"):
        return copy.deepcopy(obs)
    return obs.copy()


def agent_action(original_dir, final_dir):
    '''
    Return the action performed by an agent, by analyzing
    the starting direction and the final direction of the movement.

    Directions can be enum members exposing `.value` or plain ints.
    The direction change is taken modulo 4: a change of +1 is reported
    as MOVE_RIGHT, +3 (i.e. -1) as MOVE_LEFT, and anything else (no
    change or a change of 2) as MOVE_FORWARD
    '''
    original = getattr(original_dir, "value", original_dir)
    final = getattr(final_dir, "value", final_dir)
    # Python's % always returns a value in [0, 3] here, so negative
    # differences are already wrapped around
    delta = (final - original) % 4
    if delta == 1:
        return RailEnvActions.MOVE_RIGHT
    if delta == 3:
        return RailEnvActions.MOVE_LEFT
    return RailEnvActions.MOVE_FORWARD

## Policy

Policy utils

In [40]:
class MaskedMSELoss(nn.Module):
    '''
    MSE loss with masked inputs/targets.

    Without a mask this is `F.mse_loss`. With a mask, inputs, targets and
    mask are flattened, squared errors are weighted by the mask and summed;
    with reduction='mean' the sum is divided by the total mask weight
    (an all-zero mask yields 0 instead of NaN), with any other reduction
    the weighted sum is returned
    '''

    def __init__(self, reduction='mean'):
        super(MaskedMSELoss, self).__init__()
        self.reduction = reduction

    def forward(self, input, target, mask=None):
        if mask is None:
            return F.mse_loss(input, target, reduction=self.reduction)

        flattened_mask = torch.flatten(mask).float()
        diff = ((
            torch.flatten(input) - torch.flatten(target)
        ) ** 2.0) * flattened_mask
        total = torch.sum(diff)
        if self.reduction != 'mean':
            return total
        # Clamp the normaliser so that an empty mask gives 0 / eps = 0
        # rather than 0 / 0 = NaN (which would poison the optimizer step)
        mask_sum = torch.sum(flattened_mask).clamp_min(
            torch.finfo(flattened_mask.dtype).eps
        )
        return total / mask_sum


class MaskedHuberLoss(nn.Module):
    '''
    Huber (smooth L1) loss with masked inputs/targets.

    Without a mask this is `F.smooth_l1_loss`. With a mask, inputs, targets
    and mask are flattened and the elementwise loss
        0.5 * e ** 2 / beta   if e < beta
        e - 0.5 * beta        otherwise        (e = |input - target|)
    is weighted once by the mask and summed (as in MaskedMSELoss); with
    reduction='mean' the sum is divided by the total mask weight (an
    all-zero mask yields 0 instead of NaN), with any other reduction the
    weighted sum is returned
    '''

    def __init__(self, reduction='mean', beta=1.0):
        super(MaskedHuberLoss, self).__init__()
        self.reduction = reduction
        self.beta = float(beta)

    def forward(self, input, target, mask=None):
        if mask is None:
            return F.smooth_l1_loss(
                input, target, reduction=self.reduction, beta=self.beta
            )

        flattened_mask = torch.flatten(mask).float()
        errors = torch.abs(torch.flatten(input) - torch.flatten(target))
        # Index-based assignment (instead of torch.where) so that the
        # quadratic branch is never evaluated when beta == 0
        losses = torch.zeros_like(errors)
        quadratic = errors < self.beta
        losses[quadratic] = 0.5 * (errors[quadratic] ** 2) / self.beta
        losses[~quadratic] = errors[~quadratic] - 0.5 * self.beta
        total = torch.sum(losses * flattened_mask)
        if self.reduction != 'mean':
            return total
        # Clamp the normaliser so that an empty mask gives 0 instead of NaN
        mask_sum = torch.sum(flattened_mask).clamp_min(
            torch.finfo(flattened_mask.dtype).eps
        )
        return total / mask_sum


class Sequential(nn.Sequential):
    '''
    Drop-in replacement for torch.nn.Sequential whose forward also accepts
    keyword arguments, which are passed unchanged to every sub-module
    '''

    def forward(self, input, **kwargs):
        output = input
        for layer in self:
            output = layer(output, **kwargs)
        return output


def masked_softmax(vec, mask, dim=1, temperature=1):
    '''
    Softmax of `vec / temperature` along `dim`, computed only on the
    entries where `mask` is truthy (masked-out entries get probability 0).
    Every slice along `dim` must contain at least one valid entry.

    The largest valid logit is subtracted before exponentiating, so large
    scores no longer overflow to inf / NaN
    '''
    assert vec.shape == mask.shape
    valid = mask.astype(bool)
    assert np.all(valid.any(axis=dim)), mask

    # exp(-inf) == 0, so masked entries vanish without a separate pass
    logits = np.where(valid, vec / temperature, -np.inf)
    logits = logits - logits.max(axis=dim, keepdims=True)
    exps = np.exp(logits)
    return exps / exps.sum(axis=dim, keepdims=True)


def masked_max(vec, mask, dim=1):
    '''
    Max of `vec` along `dim` (keeping the reduced axis with size 1),
    considering only entries where `mask` is truthy. Every slice along
    `dim` must contain at least one valid entry. Integer and boolean
    inputs are supported; their result is returned as floats
    '''
    assert vec.shape == mask.shape
    valid = mask.astype(bool)
    assert np.all(valid.any(axis=dim)), mask

    # NaN marks invalid entries, so the working copy must be floating point
    if np.issubdtype(vec.dtype, np.floating):
        res = vec.copy()
    else:
        res = vec.astype(float)
    res[~valid] = np.nan
    return np.nanmax(res, axis=dim, keepdims=True)


def masked_argmax(vec, mask, dim=1):
    '''
    Argmax of `vec` along `dim`, considering only entries where `mask` is
    truthy. Every slice along `dim` must contain at least one valid entry.

    For dim > 0 the reduced axis is kept with size 1; for dim == 0 it is
    dropped, so a 1-D input yields a scalar index. Integer and boolean
    inputs are supported
    '''
    assert vec.shape == mask.shape
    valid = mask.astype(bool)
    assert np.all(valid.any(axis=dim)), mask

    # NaN marks invalid entries, so the working copy must be floating point
    if np.issubdtype(vec.dtype, np.floating):
        res = vec.copy()
    else:
        res = vec.astype(float)
    res[~valid] = np.nan
    argmax_arr = np.nanargmax(res, axis=dim)

    # Restore the reduced axis (only for dim > 0, as callers rely on
    # dim == 0 returning a scalar for 1-D inputs)
    if dim > 0:
        new_shape = list(res.shape)
        new_shape[dim] = 1
        argmax_arr = argmax_arr.reshape(tuple(new_shape))

    return argmax_arr

Action selectors

In [41]:
## Action selection

class ParameterDecay:
    '''
    Base class for schedules that decay a scalar parameter from
    `parameter_start` towards `parameter_end`.

    Either an explicit `parameter_decay` or both `total_episodes` and
    `decaying_episodes` must be supplied; subclasses implement `decay`
    '''

    def __init__(self, parameter_start, parameter_end,
                 parameter_decay=None, total_episodes=None, decaying_episodes=None):
        has_explicit_decay = parameter_decay is not None
        has_episode_budget = all(
            value is not None for value in (total_episodes, decaying_episodes)
        )
        assert has_explicit_decay or has_episode_budget

        self.parameter_start, self.parameter_end = parameter_start, parameter_end
        self.parameter_decay = parameter_decay

    def decay(self, parameter):
        raise NotImplementedError()


class NullParameterDecay(ParameterDecay):
    '''
    Constant schedule: start and end coincide and `decay` leaves the
    parameter untouched. Extra positional arguments are ignored
    '''

    def __init__(self, parameter_start, *_ignored):
        super().__init__(
            parameter_start, parameter_start, parameter_decay=0
        )

    def decay(self, parameter):
        # Nothing to decay
        return parameter


class LinearParameterDecay(ParameterDecay):
    '''
    Subtract a constant step from the parameter at every `decay` call,
    never going below `parameter_end`. When no explicit step is given, it
    is chosen so that the start-to-end distance is covered in
    total_episodes * decaying_episodes calls
    '''

    def __init__(self, parameter_start, parameter_end,
                 parameter_decay=None, total_episodes=None, decaying_episodes=None):
        super().__init__(
            parameter_start, parameter_end,
            parameter_decay=parameter_decay,
            total_episodes=total_episodes,
            decaying_episodes=decaying_episodes
        )
        if self.parameter_decay is None:
            decay_steps = total_episodes * decaying_episodes
            span = self.parameter_start - self.parameter_end
            self.parameter_decay = span / decay_steps

    def decay(self, parameter):
        decayed = parameter - self.parameter_decay
        return decayed if decayed > self.parameter_end else self.parameter_end


class ExponentialParameterDecay(ParameterDecay):
    '''
    Multiply the parameter by a constant factor at every `decay` call,
    never going below `parameter_end`. When no explicit factor is given,
    it is chosen so that parameter_start * factor ** steps == parameter_end
    with steps = total_episodes * decaying_episodes
    '''

    def __init__(self, parameter_start, parameter_end,
                 parameter_decay=None, total_episodes=None, decaying_episodes=None):
        super(ExponentialParameterDecay, self).__init__(
            parameter_start, parameter_end,
            parameter_decay=parameter_decay,
            total_episodes=total_episodes,
            decaying_episodes=decaying_episodes
        )
        if self.parameter_decay is None:
            decay_steps = total_episodes * decaying_episodes
            # `**` is exponentiation (`^` is bitwise XOR and fails on floats)
            self.parameter_decay = (
                (self.parameter_end / self.parameter_start) **
                (1 / decay_steps)
            )

    def decay(self, parameter):
        return max(
            self.parameter_end, parameter * self.parameter_decay
        )


# Decay schedule class for each configuration name
PARAMETER_DECAYS = dict(
    none=NullParameterDecay,
    linear=LinearParameterDecay,
    exponential=ExponentialParameterDecay,
)


class ActionSelector:
    '''
    Base class for strategies that turn per-action scores into a chosen
    action, optionally driven by a ParameterDecay schedule
    '''

    def __init__(self, decay_schedule):
        assert isinstance(decay_schedule, ParameterDecay)
        self.decay_schedule = decay_schedule

    def select(self, actions, legal_actions=None, training=False):
        raise NotImplementedError()

    def select_many(self, actions, moving_agents, legal_actions, training=False):
        '''
        Call `select` for every moving agent. `actions` and `legal_actions`
        are (num_agents, num_actions), `moving_agents` is (num_agents,).
        Return (choices, is_best): non-moving agents get choice -1 and
        is_best False
        '''
        assert len(moving_agents.shape) == 1
        assert len(legal_actions.shape) == 2
        assert len(actions.shape) == 2
        assert moving_agents.shape[0] == legal_actions.shape[0] == actions.shape[0]
        assert legal_actions.shape[1] == actions.shape[1]

        n_agents = moving_agents.shape[0]
        choices = np.full((n_agents,), -1)
        is_best = np.full((n_agents,), False)
        for agent, is_moving in enumerate(moving_agents):
            if not is_moving:
                continue
            choice, best = self.select(
                actions[agent], legal_actions[agent], training=training
            )
            choices[agent] = choice
            is_best[agent] = best
        return choices, is_best

    def decay(self):
        return None

    def reset(self):
        return None

    def get_parameter(self):
        return None


class EpsilonGreedyActionSelector(ActionSelector):
    '''
    Epsilon-greedy selection: during training, with probability epsilon pick
    a uniformly random legal action, otherwise the best legal one; outside
    training, always act greedily. Epsilon follows `decay_schedule`
    '''

    def __init__(self, decay_schedule):
        super(EpsilonGreedyActionSelector, self).__init__(decay_schedule)
        self.epsilon = decay_schedule.parameter_start

    def select(self, actions, legal_actions=None, training=False):
        '''
        Return (action, is_best), where is_best tells whether the chosen
        action is also the greedy one
        '''
        if legal_actions is None:
            legal_actions = np.ones_like(actions, dtype=bool)
        # A boolean mask is required for the indexing below: an int 0/1 mask
        # would be interpreted as a list of indices into np.arange
        legal_mask = np.asarray(legal_actions).astype(bool)
        max_action = masked_argmax(actions, legal_mask, dim=0)
        if not training or random.random() > self.epsilon:
            return max_action, True
        random_action = np.random.choice(
            np.arange(actions.size)[legal_mask]
        )
        return (random_action, max_action == random_action)

    def decay(self):
        self.epsilon = self.decay_schedule.decay(self.epsilon)

    def reset(self):
        self.epsilon = self.decay_schedule.parameter_start

    def get_parameter(self):
        return self.epsilon


class RandomActionSelector(EpsilonGreedyActionSelector):
    '''
    Epsilon-greedy selector with epsilon fixed at 1: during training every
    choice is a uniformly random legal action. Arguments are ignored
    '''

    def __init__(self, *_ignored):
        constant_one = NullParameterDecay(parameter_start=1)
        super().__init__(constant_one)


class GreedyActionSelector(EpsilonGreedyActionSelector):
    '''
    Epsilon-greedy selector with epsilon fixed at 0: always picks the best
    legal action. Arguments are ignored
    '''

    def __init__(self, *_ignored):
        constant_zero = NullParameterDecay(parameter_start=0)
        super().__init__(constant_zero)


class BoltzmannActionSelector(ActionSelector):
    '''
    Boltzmann (softmax) exploration: during training, sample a legal action
    with probability given by a masked softmax of the scores at the current
    temperature; outside training, act greedily. The temperature follows
    `decay_schedule`
    '''

    def __init__(self, decay_schedule):
        super(BoltzmannActionSelector, self).__init__(decay_schedule)
        self.temperature = decay_schedule.parameter_start

    def select(self, actions, legal_actions=None, training=False):
        '''
        Return (action, is_best), where is_best tells whether the chosen
        action is also the greedy one
        '''
        if legal_actions is None:
            legal_actions = np.ones_like(actions, dtype=bool)
        # A boolean mask is required for the indexing below: an int 0/1 mask
        # would be interpreted as a list of indices
        legal_mask = np.asarray(legal_actions).astype(bool)
        max_action = masked_argmax(actions, legal_mask, dim=0)
        if not training:
            return max_action, True
        dist = masked_softmax(
            actions, legal_mask, dim=0, temperature=self.temperature
        )
        random_action = np.random.choice(
            np.arange(actions.size)[legal_mask], p=dist[legal_mask]
        )
        is_equal = max_action == random_action
        return random_action, is_equal

    def decay(self):
        self.temperature = self.decay_schedule.decay(self.temperature)

    def reset(self):
        self.temperature = self.decay_schedule.parameter_start

    def get_parameter(self):
        return self.temperature


class CategoricalActionSelector(BoltzmannActionSelector):
    '''
    Boltzmann selector with the temperature fixed at 1, i.e. sampling from
    the plain softmax of the scores. Arguments are ignored
    '''

    def __init__(self, *_ignored):
        unit_temperature = NullParameterDecay(parameter_start=1)
        super().__init__(unit_temperature)


# Action selector class for each configuration name
ACTION_SELECTORS = dict(
    eps_greedy=EpsilonGreedyActionSelector,
    random=RandomActionSelector,
    greedy=GreedyActionSelector,
    boltzmann=BoltzmannActionSelector,
    categorical=CategoricalActionSelector,
)

Replay buffers

In [42]:
# One transition stored in the replay buffer
Experience = namedtuple(
    "Experience",
    "state choice reward next_state next_legal_choices finished moving"
)


class ReplayBuffer:
    '''
    Fixed-size buffer to store experience tuples

    Once `buffer_size` experiences are stored, adding a new one evicts
    the oldest one (FIFO).
    '''

    def __init__(self, choice_size, batch_size, buffer_size, device):
        '''
        Initialize a ReplayBuffer object

        Parameters:
        - choice_size: number of available choices (stored, not used here)
        - batch_size: number of experiences returned by `sample`
        - buffer_size: maximum number of stored experiences
        - device: torch device on which sampled tensors are allocated
        '''
        self.choice_size = choice_size
        self.batch_size = batch_size
        self.memory = deque(maxlen=buffer_size)
        self.device = device

    def add(self, experience):
        '''
        Add a new experience (an iterable with the Experience fields) to memory
        '''
        self.memory.append(Experience(*experience))

    def sample(self):
        '''
        Randomly sample a batch of experiences from memory.
        Each returned tensor has shape (batch_size, *); graph states are
        returned as a PyTorch Geometric `Batch`.
        '''
        states, choices, rewards, next_states, next_legal_choices, finished, moving = zip(
            *random.sample(self.memory, k=self.batch_size)
        )

        # Check for PyTorch Geometric
        if isinstance(states[0], np.ndarray):
            # Stack with NumPy first: building a tensor directly from a
            # sequence of ndarrays is an extremely slow path in PyTorch
            states = torch.tensor(
                np.array(states), dtype=torch.float32, device=self.device
            )
            next_states = torch.tensor(
                np.array(next_states), dtype=torch.float32, device=self.device
            )
        elif isinstance(states[0], Data):
            states = Batch.from_data_list(states).to(self.device)
            next_states = Batch.from_data_list(next_states).to(self.device)

        choices = torch.tensor(
            choices, dtype=torch.int64, device=self.device
        )
        rewards = torch.tensor(
            rewards, dtype=torch.float32, device=self.device
        )
        next_legal_choices = torch.tensor(
            next_legal_choices, dtype=torch.bool, device=self.device
        )
        finished = torch.tensor(
            finished, dtype=torch.uint8, device=self.device
        )
        moving = torch.tensor(
            moving, dtype=torch.bool, device=self.device
        )

        return states, choices, rewards, next_states, next_legal_choices, finished, moving

    def can_sample(self):
        '''
        Check if there are enough samples in the replay buffer
        '''
        return len(self.memory) >= self.batch_size

    def save(self, filename):
        '''
        Save the current replay buffer to a pickle file
        '''
        with open(filename, 'wb') as f:
            pickle.dump(list(self.memory), f)

    def load(self, filename):
        '''
        Load the current replay buffer from the given pickle file

        The buffer keeps its capacity: if the file contains more experiences
        than `buffer_size`, only the most recent ones are retained.
        NOTE: this uses pickle, so only load files you trust.
        '''
        with open(filename, 'rb') as f:
            experiences = pickle.load(f)
        # Re-wrap in a bounded deque: assigning the raw list would drop
        # `maxlen` and let the buffer grow without limit afterwards
        self.memory = deque(experiences, maxlen=self.memory.maxlen)

    def __len__(self):
        '''
        Return the current size of internal memory
        '''
        return len(self.memory)

## Policies

In [43]:
# Registry mapping the configuration name of a loss to a (masked) loss instance
LOSSES = dict(
    huber=MaskedHuberLoss(),
    mse=MaskedMSELoss(),
)

class Policy:
    '''
    Policy abstract class

    Stores the shared configuration; concrete policies must override
    `act`, `step`, `save` and `load`.
    '''

    def __init__(self, params=None, state_size=None, choice_size=None, choice_selector=None, training=False):
        self.params = params
        self.state_size, self.choice_size = state_size, choice_size
        self.choice_selector = choice_selector
        self.training = training

    def act(self, state, legal_choices=None, training=False):
        '''Select choices for the given state(s); must be overridden'''
        raise NotImplementedError

    def step(self, experience):
        '''Consume experience(s), possibly learning; must be overridden'''
        raise NotImplementedError

    def save(self, filename):
        '''Persist the policy; must be overridden'''
        raise NotImplementedError

    def load(self, filename):
        '''Restore the policy; must be overridden'''
        raise NotImplementedError


class RandomPolicy(Policy):
    '''
    Policy which chooses random moves

    Every choice gets the same (zero) value, so the decision is left
    entirely to a RandomActionSelector. The `choice_selector` argument is
    accepted for interface compatibility but ignored.
    '''

    def __init__(self, params=None, state_size=None, choice_selector=None, training=False):
        num_choices = RailEnvChoices.choice_size()
        super().__init__(
            params,
            state_size,
            choice_size=num_choices,
            choice_selector=RandomActionSelector(),
            training=training,
        )

    def act(self, states, legal_choices, moving_agents, training=False):
        '''Select one choice per agent through the random selector'''
        num_agents = moving_agents.shape[0]
        uniform_values = np.zeros((num_agents, self.choice_size))
        explore = training and self.training
        return self.choice_selector.select_many(
            uniform_values, moving_agents, np.array(legal_choices),
            training=explore,
        )

    def step(self, experience):
        '''Nothing to learn'''
        pass

    def save(self, filename):
        '''Nothing to persist'''
        pass

    def load(self, filename):
        '''Nothing to restore'''
        pass


class DQNPolicy(Policy):
    '''
    DQN policy

    Q-values come from a (dueling) DQN network and are turned into choices
    by the given action selector. When `training` is True the policy also
    owns a target network, an Adam optimizer, a loss criterion and a
    replay buffer.
    '''

    def __init__(self, params, state_size, choice_selector, training=False):
        '''
        Initialize DQNPolicy object

        Parameters:
        - params: configuration object (attribute-style access)
        - state_size: input size of the Q-network
        - choice_selector: ActionSelector turning Q-values into choices
        - training: whether to build training-only components
        '''
        super(DQNPolicy, self).__init__(
            params, state_size, choice_size=RailEnvChoices.choice_size(),
            choice_selector=choice_selector, training=training
        )
        assert isinstance(
            choice_selector, ActionSelector
        ), "The choice selection object must be an instance of ActionSelector"

        # Parameters
        self.device = torch.device("cpu")
        if self.params.generic.use_gpu and torch.cuda.is_available():
            self.device = torch.device("cuda:0")
            print("🐇 Using GPU")

        # Q-Network
        net = DuelingDQN if self.params.model.dqn.dueling.enabled else DQN
        self.qnetwork_local = net(
            self.state_size, RailEnvChoices.choice_size(),
            self.params.model.dqn, device=self.device
        ).to(self.device)

        # Training parameters
        if self.training:
            self.time_step = 0
            self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
            self.optimizer = optim.Adam(
                self.qnetwork_local.parameters(), lr=self.params.learning.learning_rate
            )
            self.criterion = LOSSES[self.params.learning.loss.get_true_key()]
            self.loss = torch.tensor(0.0)
            self.memory = ReplayBuffer(
                RailEnvChoices.choice_size(), self.params.replay_buffer.batch_size,
                self.params.replay_buffer.size, self.device
            )

    def enable_wandb(self):
        '''
        Log gradients and parameters to wandb (requires a training policy,
        since it uses the loss criterion)
        '''
        wandb.watch(
            self.qnetwork_local, self.criterion,
            log="all", log_freq=self.params.generic.wandb_gradients.checkpoint
        )

    def act(self, states, legal_choices, moving_agents, training=False):
        '''
        Perform action selection based on the Q-values returned by the network

        Exploration is enabled only if both `training` and `self.training`
        are True. Returns whatever the action selector's `select_many` returns.
        '''
        choice_values = np.zeros((moving_agents.shape[0], self.choice_size),)
        # If no agent is active, then skip computations
        if moving_agents.any():

            # Batch the observations in the format expected by the network
            if self.params.policy.type.decentralized_fov:
                states = Batch.from_data_list([states[0]]).to(self.device)
            elif self.params.policy.type.graph:
                states = Batch.from_data_list(states).to(self.device)
            else:
                states = torch.tensor(
                    states, dtype=torch.float, device=self.device
                )

            # Convert moving agents to tensor
            t_moving_agents = torch.from_numpy(
                moving_agents
            ).bool().to(self.device)

            # Call the network
            self.qnetwork_local.eval()
            with torch.no_grad():
                choice_values = self.qnetwork_local(
                    states, mask=t_moving_agents
                ).detach().cpu().numpy()

        # Select a legal choice based on the action selector
        return self.choice_selector.select_many(
            choice_values, moving_agents, np.array(legal_choices),
            training=(training and self.training)
        )

    def step(self, experiences):
        '''
        Add experiences to memory and eventually perform a training step
        '''
        assert self.training, "Policy has been initialized for evaluation only"
        for experience in experiences:
            # Save experience in replay memory
            self.memory.add(experience)

            # Learn every `checkpoint` time steps
            # (if enough samples are available in memory, get random subset and learn)
            self.time_step = (
                self.time_step + 1
            ) % self.params.replay_buffer.checkpoint
            if self.time_step == 0 and self.memory.can_sample():
                self.qnetwork_local.train()
                self._learn()

    def _learn(self):
        '''
        Perform a learning step on a batch sampled from the replay buffer,
        then soft-update the target network
        '''
        # Sample a batch of experiences
        experiences = self.memory.sample()
        states, choices, rewards, next_states, next_legal_choices, finished, moving = experiences

        # Get expected Q-values from local model
        q_expected = self.qnetwork_local(states, mask=moving).gather(
            1, choices.flatten().unsqueeze(1)
        ).squeeze(1)

        # Get Q-values of the next states; they are only used as targets,
        # so there is no need to build an autograd graph for them
        with torch.no_grad():
            q_targets_next = torch.from_numpy(
                self._get_q_targets_next(
                    next_states, next_legal_choices.cpu().numpy(), moving
                )
            ).squeeze(1).to(self.device)

        # Compute Q-targets for current states (no bootstrap on finished agents)
        q_targets = (
            torch.flatten(rewards) + (
                self.params.learning.discount *
                q_targets_next * (1 - torch.flatten(finished))
            )
        )

        # Compute and minimize the loss
        self.loss = self.criterion(q_expected, q_targets, mask=moving)
        self.optimizer.zero_grad()
        self.loss.backward()
        if self.params.learning.gradient.clip_norm:
            nn.utils.clip_grad.clip_grad_norm_(
                self.qnetwork_local.parameters(), self.params.learning.gradient.max_norm
            )
        elif self.params.learning.gradient.clamp_values:
            nn.utils.clip_grad.clip_grad_value_(
                self.qnetwork_local.parameters(), self.params.learning.gradient.value_limit
            )
        self.optimizer.step()

        # Update target network
        self._soft_update(self.qnetwork_local, self.qnetwork_target)

    def _get_q_targets_next(self, next_states, next_legal_choices, moving):
        '''
        Get expected Q-values of the next states from the target network
        (Double DQN or plain DQN, standard or softmax Bellman backup)
        '''

        def _double_dqn():
            q_targets_next = self.qnetwork_target(
                next_states, mask=moving
            ).detach().cpu().numpy()
            q_locals_next = self.qnetwork_local(
                next_states, mask=moving
            ).detach().cpu().numpy()

            # Softmax Bellman
            if self.params.learning.softmax_bellman.enabled:
                return np.sum(
                    q_targets_next * masked_softmax(
                        q_locals_next,
                        next_legal_choices.reshape(q_locals_next.shape),
                        temperature=self.params.learning.softmax_bellman.temperature
                    ), axis=1, keepdims=True
                )

            # Standard Bellman: evaluate the local network's best legal choice
            best_choices = masked_argmax(
                q_locals_next,
                next_legal_choices.reshape(q_locals_next.shape)
            )
            return np.take_along_axis(q_targets_next, best_choices, axis=1)

        def _dqn():
            # Use the same moving-agents mask as the local network in `_learn`
            # and as `_double_dqn`, so rows line up with `q_expected`
            q_targets_next = self.qnetwork_target(
                next_states, mask=moving
            ).detach().cpu().numpy()

            # Standard or softmax Bellman
            return (
                masked_max(
                    q_targets_next,
                    next_legal_choices.reshape(q_targets_next.shape)
                )
                if not self.params.learning.softmax_bellman.enabled
                else np.sum(
                    q_targets_next * masked_softmax(
                        q_targets_next,
                        next_legal_choices.reshape(q_targets_next.shape),
                        temperature=self.params.learning.softmax_bellman.temperature
                    ), axis=1, keepdims=True
                )
            )

        return _double_dqn() if self.params.model.dqn.double else _dqn()

    def _soft_update(self, local_model, target_model):
        '''
        Soft update model parameters: θ_target = τ * θ_local + (1 - τ) * θ_target
        '''
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(
                self.params.learning.tau * local_param.data +
                (1.0 - self.params.learning.tau) * target_param.data
            )

    def save(self, filename):
        '''
        Save the local network parameters and, when training, the target ones
        '''
        torch.save(self.qnetwork_local.state_dict(), filename + ".local")
        # The target network only exists for policies built for training
        if self.training:
            torch.save(self.qnetwork_target.state_dict(), filename + ".target")

    def load(self, filename):
        '''
        Load only the local network if evaluating,
        otherwise load both local and target networks
        '''
        if os.path.exists(filename + ".local"):
            self.qnetwork_local.load_state_dict(
                torch.load(filename + ".local", map_location=self.device)
            )
            if self.training and os.path.exists(filename + ".target"):
                self.qnetwork_target.load_state_dict(
                    torch.load(filename + ".target", map_location=self.device)
                )
        else:
            print("Model not found. Please, check the given path.")

    def save_replay_buffer(self, filename):
        '''
        Save the current replay buffer
        '''
        self.memory.save(filename)

    def load_replay_buffer(self, filename):
        '''
        Load a stored representation of the replay buffer
        '''
        self.memory.load(filename)


class DQNGNNPolicy(DQNPolicy):
    '''
    DQN + GNN policy

    An EntireGNN encoder is prepended to the DQN network built by DQNPolicy,
    so observations are embedded by the GNN before Q-values are computed.
    '''

    def __init__(self, params, state_size, choice_selector, training=False):
        '''
        Initialize DQNGNNPolicy object

        `state_size` is given to the GNN encoder, while the DQN head is sized
        on the GNN output (embedding_size * pos_size).
        '''
        super(DQNGNNPolicy, self).__init__(
            params, (
                params.model.entire_gnn.embedding_size *
                params.model.entire_gnn.pos_size
            ),
            choice_selector, training=training
        )

        self.qnetwork_local = Sequential(
            EntireGNN(
                state_size, self.params.observator.max_depth,
                self.params.model.entire_gnn, device=self.device
            ).to(self.device),
            self.qnetwork_local
        )

        if training:
            self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
            # The optimizer created by DQNPolicy only tracks the DQN head:
            # rebuild it so the GNN encoder parameters are trained as well
            self.optimizer = optim.Adam(
                self.qnetwork_local.parameters(),
                lr=self.params.learning.learning_rate
            )


class DecentralizedFOVDQNPolicy(DQNPolicy):
    '''
    Decentralized FOV CNN + GNN + DQN policy

    A MultiGNN encoder is prepended to the DQN network built by DQNPolicy;
    the DQN head is sized on the GNN communication embedding.
    '''

    def __init__(self, params, state_size, choice_selector, training=False):
        '''
        Initialize DecentralizedFOVDQNPolicy object
        '''
        super(DecentralizedFOVDQNPolicy, self).__init__(
            params, params.model.multi_gnn.gnn_communication.embedding_size,
            choice_selector, training=training
        )

        self.qnetwork_local = Sequential(
            MultiGNN(
                self.params.observator.max_depth,
                self.params.observator.max_depth,
                state_size, self.params.model.multi_gnn,
                device=self.device
            ).to(self.device),
            self.qnetwork_local
        )

        if training:
            self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
            # The optimizer created by DQNPolicy only tracks the DQN head:
            # rebuild it so the MultiGNN parameters are trained as well
            self.optimizer = optim.Adam(
                self.qnetwork_local.parameters(),
                lr=self.params.learning.learning_rate
            )


# Registry mapping the configured policy type to its policy class
POLICIES = dict(
    tree=DQNPolicy,
    binary_tree=DQNPolicy,
    graph=DQNGNNPolicy,
    decentralized_fov=DecentralizedFOVDQNPolicy,
    random=RandomPolicy,
)

## Training

In [68]:

from torch.utils.tensorboard import SummaryWriter


def tensorboard_log(writer, name, x, y, plot=('min', 'max', 'mean', 'std', 'hist')):
    '''
    Log the given x/y values to tensorboard

    Scalars are logged as-is under `name`. Sequences (list, tuple or
    ndarray) are summarized with the statistics listed in `plot`
    ('min', 'max', 'mean', 'std' as `{name}_{stat}` scalars, 'hist' as a
    histogram under `name`); empty sequences are skipped.

    Parameters:
    - writer: object exposing `add_scalar` / `add_histogram` (e.g. SummaryWriter)
    - name: tag under which values are logged
    - x: scalar value or sequence of values
    - y: global step associated with the values
    - plot: collection of summaries to log for sequences
    '''
    if not isinstance(x, (list, tuple, np.ndarray)):
        writer.add_scalar(name, x, y)
        return

    num_values = x.size if isinstance(x, np.ndarray) else len(x)
    if num_values == 0:
        return

    reducers = {'min': np.min, 'max': np.max, 'mean': np.mean, 'std': np.std}
    for stat, reduce in reducers.items():
        if stat in plot:
            writer.add_scalar(f"{name}_{stat}", reduce(x), y)
    if 'hist' in plot:
        writer.add_histogram(name, np.array(x), y)


def format_choices_probabilities(choices_probabilities):
    '''
    Render choice probabilities as "<symbol> <percent> " segments
    (left, right, stop), each value rounded to 3 decimals first.
    '''
    symbols = ("←", "→", "◼")
    rounded = np.round(choices_probabilities, 3)
    return "".join(
        f"{symbol} {probability:^4.2%} "
        for symbol, probability in zip(symbols, rounded)
    )


def train_agents(args, writer):
    '''
    Train and evaluate agents on the specified environments

    Runs `args.training.train_env.episodes + 1` training episodes. The policy
    (and optionally the replay buffer) is saved every
    `args.training.checkpoint` episodes and at the end. It is evaluated on a
    separate environment through `eval_policy` at each checkpoint and on the
    final episode.

    Parameters:
    - args: parsed configuration (attribute-style access)
    - writer: tensorboard writer forwarded to `eval_policy`
    '''
    # Initialize threads and seeds
    set_num_threads(args.generic.num_threads)
    if args.generic.fix_random:
        fix_random(args.generic.random_seed)

    # Setup the environments
    train_env = create_rail_env(
        args, load_env=args.training.train_env.load
    )
    eval_env = create_rail_env(
        args, load_env=args.training.eval_env.load
    )

    # Define "static" random seeds for evaluation purposes
    eval_seeds = [args.env.seed] * args.training.eval_env.episodes
    if args.training.eval_env.all_random:
        eval_seeds = [
            get_seed(eval_env)
            for _ in range(args.training.eval_env.episodes)
        ]

    # Pick action selector and parameter decay
    pd_type = args.parameter_decay.type.get_true_key()
    parameter_decay = PARAMETER_DECAYS[pd_type](
        parameter_start=args.parameter_decay.start,
        parameter_end=args.parameter_decay.end,
        total_episodes=args.training.train_env.episodes,
        decaying_episodes=args.parameter_decay.decaying_episodes
    )
    as_type = args.action_selector.type.get_true_key()
    action_selector = ACTION_SELECTORS[as_type](parameter_decay)

    # Initialize the agents policy
    policy_type = args.policy.type.get_true_key()
    policy = POLICIES[policy_type](
        args, train_env.state_size, action_selector, training=True
    )
    if args.generic.enable_wandb and args.generic.wandb_gradients.enabled:
        policy.enable_wandb()

    # Handle replay buffer
    if args.replay_buffer.load:
        try:
            policy.load_replay_buffer(args.replay_buffer.load)
        except RuntimeError as e:
            print(
                "\n🛑 Couldn't load replay buffer, were the experiences generated using the same depth?"
            )
            print(e)
            # `exit()` is a site/IPython helper; SystemExit is always available
            raise SystemExit(1)
    print("\n💾 Replay buffer status: {}/{} experiences".format(
        len(policy.memory), args.replay_buffer.size
    ))

    # Set the unique ID for this training
    now = datetime.now()
    training_id = now.strftime('%Y%m%d-%H%M%S')
    if args.training.renderer.training and args.training.renderer.save_frames:
        frames_dir = f"tmp/frames/{training_id}"
        os.makedirs(frames_dir, exist_ok=True)

    # Make sure output folders exist, otherwise saving checkpoints fails
    os.makedirs('./checkpoints', exist_ok=True)
    if args.replay_buffer.save:
        os.makedirs('./replay_buffers', exist_ok=True)

    # Print initial training info
    training_timer = Timer()
    training_timer.start()
    print("\n🚉 Starting training \t Training {} trains on {}x{} grid for {} episodes \tEvaluating on {} episodes every {} episodes".format(
        args.env.num_trains,
        args.env.width, args.env.height,
        args.training.train_env.episodes,
        args.training.eval_env.episodes,
        args.training.checkpoint
    ))
    print(f"\n🧠 Model with training id {training_id}\n")

    # Do the specified number of episodes
    scores, custom_scores, completions, steps, deadlocks = [], [], [], [], []
    choices_taken = np.zeros((args.training.train_env.episodes + 1,))
    for episode in range(args.training.train_env.episodes + 1):

        # Initialize timers
        step_timer = Timer()
        reset_timer = Timer()
        learn_timer = Timer()
        inference_timer = Timer()

        # Reset environment and renderer
        reset_timer.start()
        if not args.training.train_env.all_random:
            obs, info = train_env.reset(random_seed=args.env.seed)
        else:
            obs, info = train_env.reset(
                regenerate_rail=True, regenerate_schedule=True
            )
        reset_timer.end()
        if args.training.renderer.training and episode % args.training.renderer.train_checkpoint == 0:
            env_renderer = train_env.get_renderer()

        # Compute agents with same source
        agents_with_same_start = train_env.get_agents_same_start()

        # Create data structures for training info
        score, custom_score, final_step = 0.0, 0.0, 0
        choices_count = np.zeros((RailEnvChoices.choice_size(),))
        num_exploration_choices = np.zeros_like(choices_count)
        legal_choices = np.full(
            (args.env.num_trains, RailEnvChoices.choice_size()),
            RailEnvChoices.default_choices()
        )
        legal_actions = np.full(
            (args.env.num_trains, get_num_actions()), False
        )
        moving_agents = np.full((args.env.num_trains,), False)
        action_dict, choice_dict = dict(), dict()
        prev_obs, prev_choices = dict(), dict()

        # Initialize data structures
        for handle in range(args.env.num_trains):
            legal_actions[handle] = train_env.railway_encoding.get_agent_actions(handle)
            legal_choices[handle] = train_env.railway_encoding.get_legal_choices(handle, legal_actions[handle])
            choice_dict.update({handle: RailEnvChoices.CHOICE_LEFT.value})

            if obs[handle] is not None:
                prev_obs[handle] = copy_obs(obs[handle])

        # Update initial previous choices based on the policy type
        for handle in range(args.env.num_trains):
            if args.policy.type.decentralized_fov:
                prev_choices[handle] = dict(choice_dict)
            else:
                prev_choices[handle] = choice_dict[handle]

        # Do an episode
        for step in range(train_env._max_episode_steps):

            # Prioritize entry of faster agents in the environment
            for position in agents_with_same_start:
                if len(agents_with_same_start[position]) > 0:
                    del agents_with_same_start[position][0]
                    for agent in agents_with_same_start[position]:
                        info['action_required'][agent] = False

            # Policy act
            inference_timer.start()
            legal_actions, legal_choices, moving_agents = train_env.pre_act()
            choices, is_best = policy.act(
                list(obs.values()), legal_choices,
                moving_agents, training=True
            )

            # Update training info after policy act
            action_dict, metadata = train_env.post_act(
                choices, is_best, legal_actions, moving_agents
            )
            current_choices_count = metadata['choices_count']
            choices_count += current_choices_count
            choices_taken[episode] += np.sum(current_choices_count)
            num_exploration_choices += metadata['num_exploration_choices']
            choice_dict.update(metadata['choice_dict'])
            inference_timer.end()

            # Environment step
            step_timer.start()
            next_obs, rewards, custom_rewards, done, info = train_env.step(
                action_dict
            )
            step_timer.end()

            # Render an episode at some interval
            if args.training.renderer.training and episode % args.training.renderer.train_checkpoint == 0:
                env_renderer.render_env(
                    show=True, show_observations=False, show_predictions=True, show_rowcols=True
                )
                # Save renderer frame
                if args.training.renderer.save_frames:
                    env_renderer.gl.save_image(
                        "{:s}/{:04d}.png".format(frames_dir, step)
                    )

            # Policy step
            learn_timer.start()
            experience = (
                prev_obs,
                prev_choices,
                custom_rewards,
                obs,
                legal_choices,
                moving_agents
            )
            experiences = train_env.pre_step(experience)
            policy.step(experiences)
            learn_timer.end()

            # Update training info after policy step
            metadata = train_env.post_step(
                obs, choice_dict, next_obs,
                moving_agents, rewards, custom_rewards
            )
            obs.update(metadata['obs'])
            prev_obs.update(metadata['prev_obs'])
            prev_choices.update(metadata['prev_choices'])
            score += metadata['score']
            custom_score += metadata['custom_score']

            # Break if every agent arrived
            final_step = step
            if done['__all__'] or train_env.check_if_all_blocked(info["deadlocks"]):
                break

        # Close window
        if args.training.renderer.training and episode % args.training.renderer.train_checkpoint == 0:
            env_renderer.close_window()

        # Parameter decay
        policy.choice_selector.decay()

        # Save final scores
        scores.append(
            score / (
                train_env._max_episode_steps *
                train_env.get_num_agents()
            )
        )
        custom_scores.append(custom_score / train_env.get_num_agents())
        completions.append(
            sum(done[idx] for idx in train_env.get_agent_handles()) /
            train_env.get_num_agents()
        )
        steps.append(final_step)
        deadlocks.append(
            sum(int(v) for v in info["deadlocks"].values()) /
            train_env.get_num_agents()
        )
        # Avoid a 0/0 (NaN + RuntimeWarning) when no choice was taken
        total_choices = np.sum(choices_count)
        choices_probs = (
            choices_count / total_choices if total_choices > 0
            else np.zeros_like(choices_count)
        )

        # Save model and replay buffer at checkpoint
        if episode % args.training.checkpoint == 0:
            policy.save(f'./checkpoints/{training_id}-{episode}')

            # Save replay buffer
            if args.replay_buffer.save:
                policy.save_replay_buffer(
                    f'./replay_buffers/{training_id}-{episode}.pkl'
                )

        # Print episode info
        print(
            '\r🚂 Episode {:4n}'
            '\t 🏆 Score: {:<+5.4f}'
            ' Avg: {:>+5.4f}'
            '\t 🏅 Custom score: {:<+5.4f}'
            ' Avg: {:>+5.4f}'
            '\t 💯 Done: {:<7.2%}'
            ' Avg: {:>7.2%}'
            '\t 💀 Deadlocks: {:<7.2%}'
            ' Avg: {:>7.2%}'
            '\t 🦶 Steps: {:4n}/{:4n}'
            '\t 🎲 Exploration prob: {:4.3f} '
            '\t 🤔 Choices: {:4n}'
            '\t 🤠 Exploration: {:3n}'
            '\t 🔀 Choices probs: {:^}'.format(
                episode,
                scores[-1],
                np.mean(scores),
                custom_scores[-1],
                np.mean(custom_scores),
                completions[-1],
                np.mean(completions),
                deadlocks[-1],
                np.mean(deadlocks),
                steps[-1],
                train_env._max_episode_steps,
                policy.choice_selector.get_parameter(),
                choices_taken[episode],
                np.sum(num_exploration_choices),
                format_choices_probabilities(choices_probs)
            ), end="\n"
        )

        # Evaluate policy and log results at some interval
        # (always evaluate the final episode)
        if (args.training.eval_env.episodes > 0 and
            ((episode > 0 and episode % args.training.checkpoint == 0) or
             (episode == args.training.train_env.episodes))):
            eval_policy(args, writer, eval_env, policy, eval_seeds, episode)

    # Print final training info
    print("\n\r🏁 Training ended \tTrained {} trains on {}x{} grid for {} episodes \t Evaluated on {} episodes every {} episodes".format(
        args.env.num_trains,
        args.env.width, args.env.height,
        args.training.train_env.episodes,
        args.training.eval_env.episodes,
        args.training.checkpoint
    ))
    print(
        f"\n💾 Replay buffer status: {len(policy.memory)}/{args.replay_buffer.size} experiences"
    )

    # Save trained models
    print(f"\n🧠 Saving model with training id {training_id}")
    policy.save(f'./checkpoints/{training_id}-latest')
    if args.generic.enable_wandb:
        wandb.save(f'./checkpoints/{training_id}-latest.local')
    if args.replay_buffer.save:
        policy.save_replay_buffer(
            f'./replay_buffers/{training_id}-latest.pkl'
        )


def eval_policy(args, writer, env, policy, eval_seeds, train_episode):
    '''
    Perform a validation round with the given policy
    in the specified environment

    One episode is played per seed in `eval_seeds` without exploration.
    Per-episode and aggregated results are printed. When `writer` is not
    None, the aggregated metrics are also logged to tensorboard under
    `eval/*`, using `train_episode` as the global step.
    '''
    choices_taken = np.zeros((len(eval_seeds),))
    scores, custom_scores, completions, steps, deadlocks = [], [], [], [], []

    # Do the specified number of episodes
    print('\nStarting validation:')
    for episode, seed in enumerate(eval_seeds):
        score, custom_score, final_step = 0.0, 0.0, 0

        # Reset environment and renderer
        if not args.training.eval_env.all_random:
            obs, info = env.reset(random_seed=seed)
        else:
            obs, info = env.reset(
                regenerate_rail=True, regenerate_schedule=True,
            )
        if args.training.renderer.evaluation and episode % args.training.renderer.eval_checkpoint == 0:
            env_renderer = env.get_renderer()

        # Compute agents with same source
        agents_with_same_start = env.get_agents_same_start()

        # Do an episode
        for step in range(env._max_episode_steps):

            # Prioritize entry of faster agents in the environment
            for position in agents_with_same_start:
                if len(agents_with_same_start[position]) > 0:
                    del agents_with_same_start[position][0]
                    for agent in agents_with_same_start[position]:
                        info['action_required'][agent] = False

            # Policy act (no exploration)
            legal_actions, legal_choices, moving_agents = env.pre_act()
            choices, is_best = policy.act(
                list(obs.values()), legal_choices,
                moving_agents, training=False
            )
            action_dict, metadata = env.post_act(
                choices, is_best, legal_actions, moving_agents
            )
            current_choices_count = metadata['choices_count']
            choices_taken[episode] += np.sum(current_choices_count)

            # Environment step
            obs, rewards, custom_rewards, done, info = env.step(
                action_dict
            )

            # Render an episode at some interval
            if args.training.renderer.evaluation and episode % args.training.renderer.eval_checkpoint == 0:
                env_renderer.render_env(
                    show=True, show_observations=False, show_predictions=True, show_rowcols=True
                )

            # Update agents scores
            for agent in env.get_agent_handles():
                score += rewards[agent]
                custom_score += custom_rewards[agent]

            # Break if every agent arrived
            final_step = step
            if done['__all__'] or env.check_if_all_blocked(info["deadlocks"]):
                break

        # Close window
        if args.training.renderer.evaluation and episode % args.training.renderer.eval_checkpoint == 0:
            env_renderer.close_window()

        # Save final scores
        scores.append(
            score / (
                env._max_episode_steps *
                env.get_num_agents()
            )
        )
        custom_scores.append(custom_score / env.get_num_agents())
        completions.append(
            sum(done[idx] for idx in env.get_agent_handles()) /
            env.get_num_agents()
        )
        steps.append(final_step)
        deadlocks.append(
            sum(int(v) for v in info["deadlocks"].values()) /
            env.get_num_agents()
        )

        # Print evaluation results on one episode
        print(
            '\r🚂 Validation {:3n}'
            '\t 🏆 Score: {:+5.4f}'
            '\t 🏅 Custom score: {:+5.4f}'
            '\t 💯 Done: {:7.2%}'
            '\t 💀 Deadlocks: {:7.2%}'
            '\t 🦶 Steps: {:4n}/{:4n}'
            '\t 🤔 Choices: {:4n}'.format(
                episode,
                scores[-1],
                custom_scores[-1],
                completions[-1],
                deadlocks[-1],
                steps[-1],
                env._max_episode_steps,
                choices_taken[episode]
            ), end="\n"
        )

    # Print validation results
    print(
        '\r✅ Validation ended'
        '\t 🏆 Avg score: {:+5.2f}'
        '\t 🏅 Avg custom score: {:+5.2f}'
        '\t 💯 Avg done: {:7.2%}'
        '\t 💀 Avg deadlocks: {:7.2%}'
        '\t 🦶 Avg steps: {:5.2f}'
        '\t 🤔 Avg choices: {:5.2f}'.format(
            np.mean(scores),
            np.mean(custom_scores),
            np.mean(completions),
            np.mean(deadlocks),
            np.mean(steps),
            np.mean(choices_taken)
        ), end="\n\n"
    )

    # Log the validation metrics against the training episode
    # (the writer was previously accepted but never used)
    if writer is not None:
        for metric, values in (
            ("score", scores),
            ("custom_score", custom_scores),
            ("completions", completions),
            ("deadlocks", deadlocks),
            ("steps", steps),
            ("choices", choices_taken),
        ):
            tensorboard_log(writer, f"eval/{metric}", values, train_episode)


def main():
    '''
    Train environment with custom observation and prediction

    Loads the configuration from `lelia/parameters.yml`, runs
    `train_agents` while logging to a `SummaryWriter`, and always closes the
    writer afterwards (also when training raises) so pending events are
    flushed and the writer's resources are released.
    '''
    with open('lelia/parameters.yml', 'r') as conf:
        args = yaml.load(conf, Loader=yaml.FullLoader)
    args = Struct(**args)

    # Create the writer only once the configuration is valid, so a broken
    # config does not leave an empty run directory / open writer behind
    writer = SummaryWriter()
    try:
        train_agents(args, writer)
    finally:
        writer.close()


if __name__ == "__main__":
    main()


💾 Replay buffer status: 0/100000 experiences

🚉 Starting training 	 Training 7 trains on 48x27 grid for 7500 episodes 	Evaluating on 20 episodes every 500 episodes

🧠 Model with training id 20240310-222743



IndexError: list index out of range

## Test

In [28]:
def _format_agent_position(agent):
    '''
    Return the printable position of an agent for the status table.

    Agents ready to depart are shown at their initial position and finished
    agents as 'DONE'. Any other agent is shown at its current position, or
    as 'OFF MAP' when `agent.position` is None (agent not on the grid),
    instead of crashing on indexing `None`.
    '''
    if agent.status == TrainState.READY_TO_DEPART:
        return (
            agent.initial_position[0],
            agent.initial_position[1],
            agent.direction
        )
    if agent.status == TrainState.DONE:
        return 'DONE'
    if agent.position is None:
        return 'OFF MAP'
    return (agent.position[0], agent.position[1], agent.direction)


def print_agents_info(env, info, actions):
    '''
    Print information for each agent in a specific step

    Parameters
    ----------
    env : environment whose `agents` list is indexed by agent handle
    info : step info dict; `info["deadlocks"]` is indexed by agent handle
    actions : container indexed by agent handle holding the action taken
    '''
    headers = [
        "Handle", "Status", "Speed", "Position fraction",
        "Position", "Target", "Action Taken", "Malfunction", "Deadlock"
    ]
    status_table = [
        [
            handle,
            agent.status,
            agent.speed_data["speed"],
            agent.speed_data['position_fraction'],
            _format_agent_position(agent),
            agent.target,
            actions[handle],
            agent.malfunction_data['malfunction'],
            info["deadlocks"][handle]
        ]
        for handle, agent in enumerate(env.agents)
    ]
    print(tabulate(
        status_table,
        headers,
        colalign=["center"] * len(headers)
    ))


def _print_agents_tasks(env, episode):
    '''
    Print the source/target table of every agent at the start of an episode.
    '''
    tasks_table = [
        [
            handle,
            agent.status,
            agent.speed_data["speed"],
            (
                agent.initial_position[0],
                agent.initial_position[1],
                agent.direction
            ),
            agent.target
        ]
        for handle, agent in enumerate(env.agents)
    ]
    print(f"Episode {episode}")
    print(tabulate(
        tasks_table,
        ["Handle", "Status", "Speed", "Source", "Target"],
        colalign=["center"] * 5
    ))
    print()


def test_agents(args):
    '''
    Test agents on the specified environment

    Runs `args.testing.episodes` episodes, regenerating rail and schedule at
    each reset. Actions come from the model at `args.testing.model` (greedy
    action selection) or from a random policy when no model is given.
    Per-episode and aggregated statistics are printed: normalized score,
    custom score, completion rate, deadlock rate, steps and choices taken.

    Parameters
    ----------
    args : Struct
        Parsed configuration; reads the `generic`, `env`, `testing`,
        `parameter_decay` and `policy` sections.
    '''
    choices_taken = np.zeros((args.testing.episodes,))
    scores, custom_scores, completions, steps, deadlocks = [], [], [], [], []

    # Initialize threads and seeds
    set_num_threads(args.generic.num_threads)
    if args.generic.fix_random:
        fix_random(args.generic.random_seed)

    # Create railway environment
    env = create_rail_env(args, load_env=args.testing.load)

    # Load the model if provided
    if args.testing.model:
        parameter_decay = PARAMETER_DECAYS["none"](
            parameter_start=args.parameter_decay.start
        )
        action_selector = ACTION_SELECTORS["greedy"](parameter_decay)
        policy_type = args.policy.type.get_true_key()
        policy = POLICIES[policy_type](
            args, env.state_size, action_selector, training=False
        )
        policy.load(args.testing.model)
    else:
        policy = POLICIES["random"]()

    print("\n🚉 Starting testing \t Testing {} trains on {}x{} grid for {} episodes".format(
        args.env.num_trains,
        args.env.width, args.env.height,
        args.testing.episodes,
    ))

    renderer_enabled = args.testing.renderer.enabled
    save_frames = renderer_enabled and args.testing.renderer.save_frames

    # Perform the given number of episodes
    for episode in range(args.testing.episodes):
        score, custom_score, final_step = 0.0, 0.0, 0

        # Generate a new railway and renderer
        obs, info = env.reset(
            regenerate_rail=True, regenerate_schedule=True
        )
        env_renderer = env.get_renderer() if renderer_enabled else None
        num_agents = env.get_num_agents()

        # Print agents tasks
        if args.testing.verbose:
            _print_agents_tasks(env, episode)

        # Create frames directory
        test_id = datetime.now().strftime('%Y%m%d-%H%M%S')
        if save_frames:
            frames_dir = f"tmp/frames/{test_id}"
            os.makedirs(frames_dir, exist_ok=True)

        # Compute agents with same source
        agents_with_same_start = env.get_agents_same_start()

        try:
            for step in range(env._max_episode_steps):

                # Prioritize entry of faster agent in the environment:
                # each step releases the first queued agent of every shared
                # start position and flags the remaining ones as not
                # requiring an action
                for position in agents_with_same_start:
                    if len(agents_with_same_start[position]) > 0:
                        del agents_with_same_start[position][0]
                        for agent in agents_with_same_start[position]:
                            info['action_required'][agent] = False

                # Policy act
                legal_actions, legal_choices, moving_agents = env.pre_act()
                choices, is_best = policy.act(
                    list(obs.values()), legal_choices,
                    moving_agents, training=False
                )
                action_dict, metadata = env.post_act(
                    choices, is_best, legal_actions, moving_agents
                )
                current_choices_count = metadata['choices_count']
                choices_taken[episode] += np.sum(current_choices_count)

                # Environment step
                obs, rewards, custom_rewards, done, info = env.step(
                    action_dict
                )

                # Render the current step
                if renderer_enabled:
                    env_renderer.render_env(
                        show=True, show_observations=False,
                        show_predictions=True, show_rowcols=True
                    )

                    # Wait to observe the current frame
                    if args.testing.renderer.sleep > 0:
                        time.sleep(args.testing.renderer.sleep)

                    # Save renderer frame
                    if save_frames:
                        env_renderer.gl.save_image(
                            "{:s}/{:04d}.png".format(frames_dir, step)
                        )

                # Update agents score
                for handle in range(num_agents):
                    score += rewards[handle]
                    custom_score += custom_rewards[handle]

                # Compute statistics
                if args.testing.verbose:
                    normalized_score = (
                        score / (env._max_episode_steps * num_agents)
                    )
                    normalized_custom_score = custom_score / num_agents
                    print(
                        f"Score: {round(normalized_score, 4)} / "
                        f"Custom score: {round(normalized_custom_score, 4)}"
                    )
                    print_agents_info(env, info, action_dict)
                    print()

                # Check if every agent is arrived
                final_step = step
                if done['__all__'] or env.check_if_all_blocked(info["deadlocks"]):
                    break
        finally:
            # Close the window even when the episode raised, so the
            # renderer is not left open
            if env_renderer is not None:
                env_renderer.close_window()

        # Save final scores
        scores.append(score / (env._max_episode_steps * num_agents))
        custom_scores.append(custom_score / num_agents)
        completions.append(
            sum(done[idx] for idx in env.get_agent_handles()) / num_agents
        )
        steps.append(final_step)
        deadlocks.append(
            sum(int(v) for v in info["deadlocks"].values()) / num_agents
        )

        # Print episode info
        print(
            '\r🚂 Test {:4n}'
            '\t 🏆 Score: {:<+5.4f}'
            ' Avg: {:>+5.4f}'
            '\t 🏅 Custom score: {:<+5.4f}'
            ' Avg: {:>+5.4f}'
            '\t 💯 Done: {:<7.2%}'
            ' Avg: {:>7.2%}'
            '\t 💀 Deadlocks: {:<7.2%}'
            ' Avg: {:>7.2%}'
            '\t 🦶 Steps: {:4n}/{:4n}'
            '\t 🤔 Choices: {:4n}'.format(
                episode,
                scores[-1],
                np.mean(scores),
                custom_scores[-1],
                np.mean(custom_scores),
                completions[-1],
                np.mean(completions),
                deadlocks[-1],
                np.mean(deadlocks),
                steps[-1],
                env._max_episode_steps,
                choices_taken[episode]
            ), end="\n"
        )

    # Print final testing info
    print("\n\r🏁 Testing ended \tTested {} trains on {}x{} grid for {} episodes".format(
        args.env.num_trains, args.env.width, args.env.height, args.testing.episodes
    ))

    # Print final testing results (rates as percentages, like the
    # per-episode line above)
    print(
        '\r✅ Testing ended'
        '\t 🏆 Avg score: {:+7.4f}'
        '\t 🏅 Avg custom score: {:+7.4f}'
        '\t 💯 Avg done: {:7.2%}'
        '\t 💀 Avg deadlocks: {:7.2%}'
        '\t 🦶 Avg steps: {:5.2f}'
        '\t 🤔 Avg choices: {:5.2f}'.format(
            np.mean(scores),
            np.mean(custom_scores),
            np.mean(completions),
            np.mean(deadlocks),
            np.mean(steps),
            np.mean(choices_taken)
        ), end="\n\n"
    )


def main():
    '''
    Load the YAML configuration and evaluate the agents it describes
    '''
    config_path = 'lelia/parameters.yml'
    with open(config_path, 'r') as config_file:
        raw_config = yaml.load(config_file, Loader=yaml.FullLoader)
    test_agents(Struct(**raw_config))


if __name__ == "__main__":
    main()


🚉 Starting testing 	 Testing 7 trains on 48x27 grid for 500 episodes


AttributeError: 'RailEnvWrapper' object has no attribute 'schedule_generator'

In [None]:
## Other training based on the methods used in the project

# NOTE(review): `train_agent`, `model` and `policy` are not defined at module
# level in the visible part of this notebook (only `train_agents(args, writer)`
# and a function-local `policy` inside `test_agents` exist), and the argument
# list appears unfinished, so this cell would likely raise NameError as-is.
# TODO: complete the call or remove the cell.
train_agent(model, policy, )