In [None]:
# Enable IPython's autoreload extension (mode 2) so edits to imported modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import time
import os
import numpy as np 
from numpy import array
import threading
import webbrowser
import pandas as pd # useful for displaying tables
#import ipysheet as ips
from IPython.display import HTML, display
from matplotlib import pyplot as plt
import PIL

from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
from flatland.envs.observations import GlobalObsForRailEnv
# First of all we import the Flatland rail environment
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env import RailEnvActions
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
# We also include a renderer because we want to visualize what is going on in the environment
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
from flatland.utils.flask_util import simple_flask_server

In [None]:
# Map-size preset selector: 1 = small, 2 = medium, anything else = large.
nSize = 3

# Each preset is (width, height, nr_trains, cities_in_map):
#   width / height - dimensions of the map grid
#   nr_trains      - number of trains that have an assigned task in the env
#   cities_in_map  - number of cities where agents can start or end
width, height, nr_trains, cities_in_map = {
    1: (30, 30, 5, 2),
    2: (8 * 7, 8 * 7, 20, 10),
}.get(nSize, (16 * 7, 9 * 7, 50, 20))

seed = 14  # Random seed handed to the rail generator
grid_distribution_of_cities = False  # If False, cities are placed randomly instead of on a grid
max_rails_between_cities = 2  # Max tracks between two cities, i.e. the number of entry points to a city
max_rail_in_cities = 6  # Max parallel tracks inside a city, representing a realistic train station


In [None]:
# Sparse rail network: a limited number of cities joined by a limited number of tracks.
rail_generator = sparse_rail_generator(
    max_num_cities=cities_in_map,
    seed=seed,
    grid_mode=grid_distribution_of_cities,
    max_rails_between_cities=max_rails_between_cities,
    max_rails_in_city=max_rail_in_cities,
)

# Speed -> share of trains; every class gets 25 %:
#   1    fast passenger train
#   1/2  fast freight train
#   1/3  slow commuter train
#   1/4  slow freight train
speed_ration_map = {1. / divisor: 0.25 for divisor in (1, 2, 3, 4)}

schedule_generator = sparse_schedule_generator(speed_ration_map)

# Malfunction settings: occurrence rate plus min/max malfunction duration.
# NOTE(review): how flatland interprets `malfunction_rate` (a per-step rate vs. a
# mean interval between malfunctions) has differed between releases -- confirm
# that 10000 yields the intended malfunction frequency for the installed version.
stochastic_data = MalfunctionParameters(
    malfunction_rate=10000,
    min_duration=15,
    max_duration=50,
)

# Global observation builder; no predictor is attached.
observation_builder = GlobalObsForRailEnv()

In [None]:
# Assemble the rail environment from the generators and observation builder above.
env = RailEnv(
    width=width,
    height=height,
    number_of_agents=nr_trains,
    rail_generator=rail_generator,
    schedule_generator=schedule_generator,
    malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
    obs_builder_object=observation_builder,
    remove_agents_at_target=True,
)

# Reset once so later cells can read the generated agents (env.agents);
# the trailing ';' suppresses the returned value in the cell output.
env.reset();

In [None]:
# Browser-based renderer bound to localhost; its endpoint URL is linked further down.
env_renderer = RenderTool(
    env,
    gl="BROWSER",
    host="127.0.0.1",
    port=None,
    show_debug=False,
    agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
)

In [None]:
# PIL/SVG renderer whose image (get_image) is used as the map background in the plot below.
render_pil = RenderTool(
    env,
    gl="PILSVG",
    agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
    show_debug=False,
    # Adjust the screen size to fit your resolution.
    screen_width=1300,
    screen_height=1000,
)

In [None]:
# Group agent indices by their initial (row, col) position to find agents that
# share a start cell.
dict_pos_agents = {}
for iAgent, agent in enumerate(env.agents):
    dict_pos_agents.setdefault(agent.initial_position, []).append(iAgent)

agent_pairs = []  # NOTE(review): not read anywhere in this notebook

pd.set_option("display.max_rows", 30)
# list(...) so the constructor gets a plain list of (position, indices) tuples.
df_init = pd.DataFrame(list(dict_pos_agents.items()), columns=["Initial_Posn", "Agent_indices"])
df_init["nAgents"] = df_init.Agent_indices.apply(len)
df_shared_start = df_init[df_init.nAgents > 1]

# Flatten the index lists of the shared start cells. The previous
# np.sum(...values) concatenated lists via object addition but returned the int 0
# when no cell is shared, making sorted() raise TypeError; this yields [] instead.
agents_with_same_start = sorted(
    agent_index
    for indices in df_shared_start.Agent_indices
    for agent_index in indices
)

df_shared_start

In [None]:
# Draw a line from each agent's initial position to its target, label each start
# cell with the indices of the agents starting there, and underlay the rendered map.
r_aspect = width / height
chart_width_inches = 15
fig, ax = plt.subplots(figsize=(chart_width_inches, chart_width_inches / r_aspect))

# (row, col) @ rc2xy == (col, -row); adding `height` to y puts row 0 at the top,
# matching imshow's default 'upper' origin.
rc2xy = array([[0, 1], [-1, 0]]).T
# The map image spans extent [0, width, 0, height], i.e. one plot unit per grid
# cell, so shift by half a cell to mark cell centres rather than cell corners.
xy_offset = array([0.5, height - 0.5])

for agent in env.agents:
    rc12 = np.stack([array(agent.initial_position), array(agent.target)])
    xy12 = np.matmul(rc12, rc2xy) + xy_offset
    ax.plot(xy12[:, 0], xy12[:, 1])

for rcStart, liAgent in dict_pos_agents.items():
    xyStart = np.matmul(array(rcStart), rc2xy) + xy_offset
    ax.annotate(str(liAgent), xyStart)

ax.set_xticks(range(0, width, 5))
ax.set_yticks(range(0, height, 5))
ax.grid()
ax.set_title("Agent Indices at initial positions")

render_pil.render_env(show=False, show_observations=False)
# Trailing ';' keeps the AxesImage repr out of the cell output; the figure still renders.
ax.imshow(render_pil.get_image(), extent=[0, width, 0, height]);

In [None]:
# Link to the browser renderer's endpoint. The href value is quoted so the markup
# stays valid for any URL, and target="_blank" opens the viewer in a new tab.
url = env_renderer.get_endpoint_URL()
HTML('<a href="{0}" target="_blank"> Click here to open viewer at {0}</a>'.format(url))

In [None]:
# Step the environment once, sending an action only to agents that share a start cell.
# The original used the bare integer 1, which in flatland's RailEnvActions is MOVE_LEFT;
# the name is now explicit with the same value.
# NOTE(review): the intent was "try to move" -- confirm MOVE_FORWARD was not meant.
action_dict = {agent_id: RailEnvActions.MOVE_LEFT for agent_id in agents_with_same_start}
env.step(action_dict);

In [None]:
class RandomAgent:
    """Baseline controller that picks a random movement action on every call.

    It exposes an act/step/save/load interface, but step, save and load do
    nothing, so it can stand in where a trained policy would be used.

    Args:
        state_size: size of the observation; stored but not used.
        action_size: size of the action space; stored but not used.
        seed: optional seed for a private ``np.random.RandomState``. When None
            (the default) the global ``np.random`` state is used, exactly as before.
    """

    def __init__(self, state_size, action_size, seed=None):
        self.state_size = state_size
        self.action_size = action_size
        # A private RandomState gives a reproducible action sequence without
        # touching the global RNG; None keeps the original global-RNG behaviour.
        self._rng = np.random if seed is None else np.random.RandomState(seed)

    def act(self, state):
        """Return a random action drawn uniformly; ``state`` is ignored.

        The choices are MOVE_FORWARD, MOVE_RIGHT, MOVE_LEFT and STOP_MOVING
        (DO_NOTHING is not among them).
        """
        return self._rng.choice([RailEnvActions.MOVE_FORWARD,
                                 RailEnvActions.MOVE_RIGHT,
                                 RailEnvActions.MOVE_LEFT,
                                 RailEnvActions.STOP_MOVING])

    def step(self, memories):
        """No-op: a random policy has nothing to learn."""
        return

    def save(self, filename):
        """No-op: there is no model state to persist."""
        return

    def load(self, filename):
        """No-op: there is no model state to restore."""
        return
    
# Seed NumPy's global RNG so the random actions below are reproducible on
# Restart & Run All; `seed` is otherwise only passed to the rail generator.
np.random.seed(seed)
# state_size (218) is stored but never read by RandomAgent; action_size comes
# from the environment.
controller = RandomAgent(218, env.action_space[0])

In [None]:
# Take one random step for every agent. A fresh dict is built here instead of
# updating the one left over from an earlier cell, so this cell does not depend
# on which cells ran before it.
action_dict = {a: controller.act(0) for a in range(env.get_num_agents())}

observations, rewards, dones, information = env.step(action_dict)

In [None]:
#url = env_renderer.get_endpoint_URL()
#HTML("<iframe src='{}/index.html' height=500 width=100% />".format(url))

In [None]:
# Run a short random-policy episode, rendering every step in the browser viewer
# and accumulating the summed reward of all agents.
nSteps = 5

score = 0
frame_step = 0

for step in range(nSteps):
    action_dict = {a: controller.act(observations[a])
                   for a in range(env.get_num_agents())}

    # Environment step: per-agent observations, rewards and done flags.
    next_obs, all_rewards, done, _ = env.step(action_dict)

    # `done` also holds the aggregate '__all__' entry; counting it inflated both
    # the numerator and the denominator, so only per-agent flags are counted.
    agent_done = [is_done for handle, is_done in done.items() if handle != '__all__']
    print('{}/{} agents done.'.format(sum(agent_done), len(agent_done)))

    env_renderer.render_env(show=False,
                            show_observations=True,
                            show_predictions=False)
    frame_step += 1

    # Accumulate this step's rewards over all agents (no learning happens here).
    for a in range(env.get_num_agents()):
        score += all_rewards[a]

    observations = next_obs.copy()
    # Report before the termination check so the final step's score is shown too.
    print('Step {}\t Sum scores={}'.format(frame_step, score))
    if done['__all__']:
        print("All done!")
        break

    # Presumably paces the loop so the browser animation is watchable.
    time.sleep(0.1)


<app-root></app-root>
<script>
try{
    window.nodeRequire = require;
    delete window.require;
    delete window.exports;
    delete window.module;
} catch (e) {}
</script>

<script src="static/runtime.js" defer></script>
<script src="static/polyfills-es5.js" nomodule defer></script>
<script src="static/polyfills.js" defer></script>
<script src="static/styles.js" defer></script>
<script src="static/scripts.js" defer></script>
<script src="static/vendor.js" defer></script>
<script src="static/main.js" defer></script>
</body>