In [1]:
%load_ext autoreload
%autoreload 2
from evolution_algos import cem_uncorrelated, saes_1_1, ObjectiveFunction
from evolution_policy import NeuralNetworkPolicy, LogisticRegression

import numpy as np
from test_utils import RenderWrapper
from flatland.envs.line_generators import SparseLineGen
from flatland.envs.malfunction_generators import (
    MalfunctionParameters,
    ParamMalfunctionGen,
)
from flatland.envs.persistence import RailEnvPersister
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import SparseRailGen
# from flatland_cutils import TreeObsForRailEnv as TreeCutils
from flatland.envs.observations import GlobalObsForRailEnv

from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.distance_map import DistanceMap
import flatland.envs.rail_env_shortest_paths as sp

from flatland.utils.rendertools import RenderTool

from observation_utils import normalize_observation

In [2]:
# Create the environment
observation_tree_depth = 1
observation_radius = 10
observation_max_path_depth = 20

env = RailEnv(
    width=20,
    height=15,
    rail_generator=SparseRailGen(
        seed=0,  # Random seed
        max_num_cities=2,  # Number of cities
        grid_mode=True,
        max_rails_between_cities=2,
        max_rail_pairs_in_city=1,
    ),
    line_generator=SparseLineGen(speed_ratio_map={1.: 1.}
        ),
    number_of_agents=2, # Only one agent
    obs_builder_object=TreeObsForRailEnv(max_depth=observation_tree_depth),
    malfunction_generator=ParamMalfunctionGen(
        MalfunctionParameters(
            malfunction_rate=0.,  # Rate of malfunction
            min_duration=3,  # Minimal duration
            max_duration=20,  # Max duration
        )
    ),
)

In [3]:
# Initializing policy
n_features_per_node = env.obs_builder.observation_dim
n_nodes = sum([np.power(4, i) for i in range(observation_tree_depth + 1)])
state_size = n_features_per_node * n_nodes
policy = LogisticRegression(state_size, 5)
print(policy.num_params)

275


In [4]:
theta = np.random.rand(policy.num_params)

obs, info = env.reset()
agent_obs = [None] * env.get_num_agents()
for agent in env.get_agent_handles():
    if obs[agent]:
        agent_obs[agent] = normalize_observation(obs[agent], tree_depth=observation_tree_depth, observation_radius=observation_radius)
actions = {}
for agent in env.get_agent_handles():
    actions[agent] = policy.act(agent_obs[agent], theta)

obs, all_rewards, done, _ = env.step(actions)
print(all_rewards)
    

{0: 0, 1: 0}


In [5]:
# Cross entropy method

# Create the objective function
objective_function = ObjectiveFunction(env, policy,observation_tree_depth=observation_tree_depth, observation_radius=observation_radius, num_episodes=1 , max_time_steps=200)

init_mean_array = np.random.random(size = policy.num_params)
init_var_array = np.ones(shape=policy.num_params) * 100.
hist_dict = {}
cem = cem_uncorrelated(
    objective_function,
    mean_array = init_mean_array,
    var_array = init_var_array,
    max_iterations=50,
    sample_size=50,
    elite_frac=0.2,
    print_every=1,
    success_score=10,
    hist_dict=hist_dict
)

iteration :  0
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: -26, 1: -20}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: 0, 1: 0}
{0: -1

In [6]:
cem

array([  4.39600732,  21.10187922,   2.04887412,  -3.45067182,
        -8.28518118,   9.02779407,  -7.64676398,  -0.78644161,
         8.69079509,   0.21578269, -17.56267894,   5.12839776,
       -15.77023344,  -0.12320932,  15.69880216,   5.58898924,
         5.75041878,   5.97482581,  -0.33422858,   1.01102485,
        12.49609174,  -5.35203671,  -1.34532606,  25.9918216 ,
         6.65269118,   3.89340805, -17.03109631,   4.26935104,
         1.28320249,   7.56967396, -11.30704727,   9.48365719,
       -15.75313884,  16.68185735,  13.04679112,  16.55406254,
        14.83050024,   0.4231417 ,   2.25701552,   4.21741206,
         8.62005738,  13.31863856,  14.08438442,  -4.60719054,
       -12.85384766,   2.90036137,   1.94149114,  -4.55970064,
        12.87607525,   7.62139628,  13.45531539,  10.31991436,
       -11.45594114,   4.33199176,  11.44891091,  19.02909024,
         6.99648275, -21.06169184,  -4.9097203 ,  -7.177343  ,
         7.92068634,   8.63366639,  -6.73515345, -20.61