# TODO

* find the valid horizon

In [1]:
import argparse
import importlib
import json
import logging
import pickle
from copy import deepcopy

import gym
from gym import ObservationWrapper
import ray

# NOTE: this import is shadowed by the GAILTrainer class defined later in
# this notebook.
from gail import GAILTrainer
from ray import tune
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG
from ray.rllib.agents import Trainer
from ray.rllib.evaluation import PolicyEvaluator, SampleBatch, MultiAgentBatch
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.offline.json_reader import JsonReader
from ray.tune.registry import register_env
from ray.tune.logger import pretty_print
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import override
from ray.rllib.evaluation.postprocessing import discount
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal

from flow.utils.registry import make_create_env
from flow.utils.rllib import FlowParamsEncoder, get_flow_params
logger = logging.getLogger(__name__)

In [2]:
# Parallelism: CPUs handed to Ray and number of rollouts per training batch.
num_cpus, num_rollouts = 3, 3
# Environment steps per rollout.
horizon = 750
# Optimisation hyperparameters: GAE lambda, learning rate, iteration count.
gae_lambda, step_size, num_iter = 0.97, 5e-4, 10
# Which Flow benchmark to load and the experiment's label.
benchmark_name, exp_name = "multi_merge", "test_ir"

In [3]:
# Start (or reuse, if already running) the local Ray runtime.
# logging.ERROR is the numeric level 40, so only errors are logged.
ray.init(
    num_cpus=num_cpus,
    logging_level=logging.ERROR,
    ignore_reinit_error=True,
)

{'node_ip_address': '169.237.32.118',
 'object_store_address': '/tmp/ray/session_2019-05-29_06-18-58_17631/sockets/plasma_store',
 'raylet_socket_name': '/tmp/ray/session_2019-05-29_06-18-58_17631/sockets/raylet',
 'redis_address': '169.237.32.118:36754',
 'webui_url': None}

In [4]:
# Trainer configuration: copy RLlib's PPO defaults (so DEFAULT_CONFIG itself
# is not mutated) and override the fields below.
config = deepcopy(DEFAULT_CONFIG)
config["num_workers"] = min(num_cpus, num_rollouts)
config["train_batch_size"] = horizon * num_rollouts
# Integer division: `horizon / 2` is a float (375.0) under Python 3, but the
# sample batch size is a number of environment steps.
config["sample_batch_size"] = horizon // 2
config["use_gae"] = True
config["horizon"] = horizon
config["lambda"] = gae_lambda
# Also used as the learning rate of the discriminator's Adam optimizer.
config["lr"] = step_size
config["vf_clip_param"] = 1e6
config["num_sgd_iter"] = 10
config['clip_actions'] = False  # FIXME(ev) temporary ray bug
config["model"]["fcnet_hiddens"] = [128, 64, 32]
config["observation_filter"] = "NoFilter"
config["entropy_coeff"] = 0.0
# GAIL-specific keys, not part of the PPO defaults.
config["expert_path"] = '/headless/rl_project/flow_codes/ModelBased/expert_sample'
config["discrim_hidden_size"] = 128

# Load the benchmark module and take its GAIL-specific Flow params.
benchmark = importlib.import_module("flow.benchmarks.%s" % benchmark_name)
flow_params = benchmark.gail_flow_params

# save the flow params for replay
flow_json = json.dumps(
    flow_params, cls=FlowParamsEncoder, sort_keys=True, indent=4)
config['env_config']['flow_params'] = flow_json

In [5]:
# Build an env factory from the Flow params, register it under its name, and
# instantiate one env locally to read the observation/action spaces.
create_env, env_name = make_create_env(params=flow_params, version=0)
register_env(env_name, create_env)
env = create_env()

# A single PPO policy shared by every agent id.
POLICY_ID = DEFAULT_POLICY_ID
default_policy = (
    PPOPolicyGraph,
    env.observation_space,
    env.action_space,
    {},
)
policy_graph = {POLICY_ID: default_policy}
config["multiagent"] = dict(
    policy_graphs=policy_graph,
    policy_mapping_fn=tune.function(lambda agent_id: POLICY_ID),
    policies_to_train=[POLICY_ID],
)

In [6]:
class Discriminator(nn.Module):
    """MLP that maps a (state, action) feature vector to a probability.

    Architecture: two tanh hidden layers of width ``hidden_size`` and a single
    sigmoid output unit, so every output lies in (0, 1). The output layer's
    weights are scaled by 0.1 and its bias zeroed at construction to keep the
    initial logits small.

    Args:
        num_inputs: size of the last dimension of the input tensor.
        hidden_size: width of both hidden layers.
    """

    def __init__(self, num_inputs, hidden_size):
        super(Discriminator, self).__init__()

        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, 1)
        # In-place re-initialisation of the output layer; done under no_grad
        # instead of through the discouraged `.data` attribute.
        with torch.no_grad():
            self.linear3.weight.mul_(0.1)
            self.linear3.bias.zero_()

    def forward(self, x):
        """Return sigmoid(MLP(x)); the last dimension of the result is 1."""
        # torch.tanh / torch.sigmoid replace the deprecated F.tanh / F.sigmoid.
        x = torch.tanh(self.linear1(x))
        x = torch.tanh(self.linear2(x))
        prob = torch.sigmoid(self.linear3(x))
        return prob

In [7]:
class CustomEnvPolicyEvaluator(PolicyEvaluator):
    """PolicyEvaluator that forwards discriminator management to its env.

    Both methods simply delegate to ``self.env``; the env class (not visible
    here) is assumed to implement ``init_discriminator`` and
    ``set_state_dict``.
    """

    def init_discriminator(self, hidden_size):
        """Ask the wrapped env to build a discriminator of width *hidden_size*."""
        wrapped_env = self.env
        wrapped_env.init_discriminator(hidden_size)

    def set_state_dict(self, state_dict):
        """Pass *state_dict* (the trainer sends its discriminator's) to the env."""
        wrapped_env = self.env
        wrapped_env.set_state_dict(state_dict)

In [8]:
class GAILTrainer(Trainer):
    """GAIL trainer: PPO policy updates plus a torch state-action discriminator.

    Each iteration collects policy samples, draws a fresh batch of expert
    samples from ``config["expert_path"]``, runs PPO SGD on the policy samples,
    takes one optimizer step on the discriminator (policy pairs labelled 1,
    expert pairs labelled 0), and pushes the updated discriminator weights to
    every evaluator's env via ``set_state_dict``.
    """

    _allow_unknown_configs = True
    _name = "GAIL"
    _default_config = DEFAULT_CONFIG
    _policy_graph = PPOPolicyGraph

    @override(Trainer)
    def _init(self, config, env_name):
        """Create evaluators, the discriminator and its optimizer.

        Args:
            config: trainer config (this method reads ``self.config``).
            env_name: env handle supplied by ``Trainer``; forwarded unchanged
                to the evaluator factories.
        """
        self.train_batch_size = self.config["train_batch_size"]
        self.num_sgd_iter = self.config["num_sgd_iter"]
        hidden_size = self.config["discrim_hidden_size"]

        # load expert trajectory
        self.expert_reader = JsonReader(self.config["expert_path"])
        self.expert_samples = self.expert_reader.next()

        # set evaluators
        self.local_evaluator = self.make_local_evaluator(
            env_name, self._policy_graph, self.config)
        self.remote_evaluators = self.make_remote_evaluators(
            env_name, self._policy_graph, self.config["num_workers"])

        # Discriminator input is the concatenation [observation | action].
        num_inputs = self.local_evaluator.env.observation_space.shape[0]
        num_outputs = self.local_evaluator.env.action_space.shape[0]
        self.discrim_criterion = nn.BCELoss()
        self.discriminator = Discriminator(num_inputs + num_outputs,
                                           hidden_size)
        self.optimizer_discrim = optim.Adam(self.discriminator.parameters(),
                                            lr=self.config["lr"])

        # Build a discriminator inside each evaluator env, then copy weights.
        self.local_evaluator.init_discriminator(hidden_size)
        for e in self.remote_evaluators:
            e.init_discriminator.remote(hidden_size)
        self.set_state_dict()

    def set_state_dict(self):
        """Send the current discriminator weights to all evaluator envs."""
        state_dict = self.discriminator.state_dict()
        self.local_evaluator.set_state_dict(state_dict)
        for e in self.remote_evaluators:
            e.set_state_dict.remote(state_dict)

    def get_state_action_from_samples(self, samples):
        """Return ``[obs | actions]`` of *samples* as a float32 tensor.

        ``np.hstack`` joins along the second axis, so ``obs`` and ``actions``
        must both be 2-D with the same number of rows.
        """
        state_action = np.hstack((samples["obs"], samples["actions"]))
        return torch.as_tensor(state_action, dtype=torch.float32)

    def sample(self, sample_size):
        """Collect and concatenate batches until >= *sample_size* steps.

        The local policy weights are broadcast to the remote evaluators first.
        With no remote evaluators (``num_workers == 0``) the local evaluator
        samples instead; previously the loop never terminated in that case.
        """
        weights = ray.put(self.local_evaluator.get_weights())
        for e in self.remote_evaluators:
            e.set_weights.remote(weights)

        batches = []
        collected = 0
        while collected < sample_size:
            if self.remote_evaluators:
                new_batches = ray.get(
                    [e.sample.remote() for e in self.remote_evaluators])
            else:
                new_batches = [self.local_evaluator.sample()]
            batches.extend(new_batches)
            # Running total instead of re-summing every batch each pass.
            collected += sum(b.count for b in new_batches)
        return SampleBatch.concat_samples(batches)

    def train_policy_by_samples(self, samples):
        """Run ``num_sgd_iter`` PPO updates on *samples*, then update KL coeffs."""
        # Pre-bind so `update` below works even when num_sgd_iter == 0.
        fetches = {}
        for _ in range(self.num_sgd_iter):
            fetches = self.local_evaluator.learn_on_batch(samples)

        def update(pi, pi_id):
            if pi_id in fetches:
                pi.update_kl(fetches[pi_id]['learner_stats']["kl"])
            else:
                logger.debug(
                    "No data for {}, not updating kl".format(pi_id))
        self.local_evaluator.foreach_trainable_policy(update)

    def train_discriminator_by_state_action(self, state_action,
                                            expert_state_action):
        """Take one optimizer step on the discriminator and return its loss.

        Labels: policy pairs -> 1, expert pairs -> 0 (a perfect discriminator
        outputs fake == 1, real == 0). The returned value is the summed BCE
        loss tensor.
        """
        fake = self.discriminator(state_action)
        real = self.discriminator(expert_state_action)
        self.optimizer_discrim.zero_grad()
        # *_like keeps label shape/device identical to the predictions.
        discrim_loss = self.discrim_criterion(fake, torch.ones_like(fake))
        discrim_loss = discrim_loss + self.discrim_criterion(
            real, torch.zeros_like(real))
        discrim_loss.backward()
        self.optimizer_discrim.step()

        return discrim_loss

    @override(Trainer)
    def _train(self):
        """One GAIL iteration; returns RLlib metrics plus ``discrim_loss``."""
        samples = self.sample(self.train_batch_size)
        samples.shuffle()
        self.expert_samples = self.expert_reader.next()
        self.expert_samples.shuffle()
        state_action = self.get_state_action_from_samples(samples)
        expert_state_action = self.get_state_action_from_samples(
            self.expert_samples)

        self.train_policy_by_samples(samples)
        discrim_loss = self.train_discriminator_by_state_action(
            state_action, expert_state_action)
        # Re-broadcast the updated discriminator; previously this only ran in
        # _init, so the evaluator envs never saw trained weights.
        self.set_state_dict()

        res = collect_metrics(self.local_evaluator, self.remote_evaluators)
        res["custom_metrics"]["discrim_loss"] = discrim_loss.item()
        # pretty_print only formats a string; log it rather than discard it.
        logger.debug(pretty_print(res))
        return res

    @override(Trainer)
    def __getstate__(self):
        state = super().__getstate__()
        state["discrim_state_dict"] = self.discriminator.state_dict()
        return state

    @override(Trainer)
    def __setstate__(self, state):
        super().__setstate__(state)
        self.discriminator.load_state_dict(state["discrim_state_dict"])

    def make_local_evaluator(self,
                             env_creator,
                             policy_graph,
                             extra_config=None):
        """Convenience method to return configured local evaluator."""

        return self._make_evaluator(
            CustomEnvPolicyEvaluator,
            env_creator,
            policy_graph,
            0,
            merge_dicts(
                # important: allow local tf to use more CPUs for optimization
                merge_dicts(
                    self.config, {
                        "tf_session_args": self.
                        config["local_evaluator_tf_session_args"]
                    }),
                extra_config or {}))

    def make_remote_evaluators(self, env_creator, policy_graph, count):
        """Convenience method to return a number of remote evaluators."""

        remote_args = {
            "num_cpus": self.config["num_cpus_per_worker"],
            "num_gpus": self.config["num_gpus_per_worker"],
            "resources": self.config["custom_resources_per_worker"],
        }

        cls = CustomEnvPolicyEvaluator.as_remote(**remote_args).remote

        return [
            self._make_evaluator(cls, env_creator, policy_graph, i + 1,
                                 self.config) for i in range(count)
        ]
       

In [6]:
# Build the trainer; GAILTrainer._init creates the evaluators, the
# discriminator, and copies the discriminator weights to the evaluator envs.
agent = GAILTrainer(config, env_name)

2019-05-29 06:19:07,834	INFO json_reader.py:65 -- Found 1 input files.
2019-05-29 06:19:08,956	INFO policy_evaluator.py:311 -- Creating policy evaluation worker 0 on CPU (please ignore any CUDA init errors)
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
2019-05-29 06:19:10,411	INFO policy_evaluator.py:728 -- Built policy map: {'default_policy': <ray.rllib.agents.ppo.ppo_policy_graph.PPOPolicyGraph object at 0x7f0ee4305f98>}
2019-05-29 06:19:10,413	INFO policy_evaluator.py:729 -- Built preprocessor map: {'default_policy': <ray.rllib.models.preprocessors.NoPreprocessor object at 0x7f0ee4305be0>}
2019-05-29 06:19:10,414	INFO policy_evaluator.py:343 -- Built filter map: {'default_policy': <ray.rllib.utils.filter.NoFilter object at 0x7f0ee43846a0>}


# Debug

In [7]:
# Debug: run a single training iteration (dispatches to GAILTrainer._train).
agent.train()

[2m[36m(pid=17666)[0m Loading configuration... done.
[2m[36m(pid=17666)[0m Success.
[2m[36m(pid=17666)[0m Loading configuration... done.
[2m[36m(pid=17669)[0m Loading configuration... done.
[2m[36m(pid=17669)[0m Success.
[2m[36m(pid=17669)[0m Loading configuration... done.
[2m[36m(pid=17668)[0m Loading configuration... done.
[2m[36m(pid=17668)[0m Success.
[2m[36m(pid=17668)[0m Loading configuration... done.
[2m[36m(pid=17666)[0m 2019-05-29 06:19:22,148	INFO policy_evaluator.py:311 -- Creating policy evaluation worker 1 on CPU (please ignore any CUDA init errors)
[2m[36m(pid=17666)[0m 2019-05-29 06:19:22.150866: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 AVX512F FMA
[2m[36m(pid=17669)[0m 2019-05-29 06:19:22,203	INFO policy_evaluator.py:311 -- Creating policy evaluation worker 2 on CPU (please ignore any CUDA init errors)
[2m[36m(pid=17669)[0m 2019-05-

[2m[36m(pid=17669)[0m Loading configuration... done.
[2m[36m(pid=17669)[0m Success.
[2m[36m(pid=17669)[0m Loading configuration... done.
[2m[36m(pid=17666)[0m 2019-05-29 06:19:51,135	INFO policy_evaluator.py:474 -- Completed sample batch:
[2m[36m(pid=17666)[0m 
[2m[36m(pid=17666)[0m { 'data': { 'action_prob': np.ndarray((1000,), dtype=float32, min=0.004, max=0.399, mean=0.282),
[2m[36m(pid=17666)[0m             'actions': np.ndarray((1000, 1), dtype=float32, min=-2.873, max=3.049, mean=0.016),
[2m[36m(pid=17666)[0m             'advantages': np.ndarray((1000,), dtype=float32, min=0.684, max=17.627, mean=14.484),
[2m[36m(pid=17666)[0m             'agent_index': np.ndarray((1000,), dtype=int64, min=0.0, max=4.0, mean=2.377),
[2m[36m(pid=17666)[0m             'behaviour_logits': np.ndarray((1000, 2), dtype=float32, min=-0.006, max=0.007, mean=0.002),
[2m[36m(pid=17666)[0m             'dones': np.ndarray((1000,), dtype=bool, min=0.0, max=1.0, mean=0.006),
[

2019-05-29 06:19:52,047	INFO policy_evaluator.py:564 -- Training on concatenated sample batches:

{ 'data': { 'action_prob': np.ndarray((3058,), dtype=float32, min=0.0, max=0.399, mean=0.283),
            'actions': np.ndarray((3058, 1), dtype=float32, min=-3.603, max=4.024, mean=0.011),
            'advantages': np.ndarray((3058,), dtype=float32, min=0.684, max=17.627, mean=14.409),
            'agent_index': np.ndarray((3058,), dtype=int64, min=0.0, max=5.0, mean=2.465),
            'behaviour_logits': np.ndarray((3058, 2), dtype=float32, min=-0.007, max=0.008, mean=0.002),
            'dones': np.ndarray((3058,), dtype=bool, min=0.0, max=1.0, mean=0.006),
            'eps_id': np.ndarray((3058,), dtype=int64, min=64280840.0, max=1742552928.0, mean=1279435003.714),
            'infos': np.ndarray((3058,), dtype=object, head={'cost1': 0.0396621877457352, 'cost2': 0.0, 'mean_vel': 1.039616562488652, 'outflow': 468.0}),
            'new_obs': np.ndarray((3058, 12), dtype=float32, min=-0

{'config': {'batch_mode': 'truncate_episodes',
  'callbacks': {'on_episode_end': None,
   'on_episode_start': None,
   'on_episode_step': None,
   'on_postprocess_traj': None,
   'on_sample_end': None,
   'on_train_result': None},
  'clip_actions': False,
  'clip_param': 0.3,
  'clip_rewards': None,
  'collect_metrics_timeout': 180,
  'compress_observations': False,
  'custom_resources_per_worker': {},
  'discrim_hidden_size': 128,
  'entropy_coeff': 0.0,
  'env': 'MultiWaveAttenuationMergePOEnvGAIL-v0',
  'expert_path': '/headless/rl_project/flow_codes/ModelBased/expert_sample',
  'gamma': 0.99,
  'grad_clip': None,
  'horizon': 750,
  'ignore_worker_failures': False,
  'input': 'sampler',
  'input_evaluation': ['is', 'wis'],
  'kl_coeff': 0.2,
  'kl_target': 0.01,
  'lambda': 0.97,
  'local_evaluator_tf_session_args': {'inter_op_parallelism_threads': 8,
   'intra_op_parallelism_threads': 8},
  'log_level': 'INFO',
  'lr': 0.0005,
  'lr_schedule': None,
  'metrics_smoothing_episodes':

In [7]:
# Debug: collect at least train_batch_size steps of policy samples to inspect.
samples = agent.sample(agent.train_batch_size)

[2m[36m(pid=7456)[0m Loading configuration... done.
[2m[36m(pid=7456)[0m Success.
[2m[36m(pid=7456)[0m Loading configuration... done.
[2m[36m(pid=7454)[0m Loading configuration... done.
[2m[36m(pid=7454)[0m Success.
[2m[36m(pid=7454)[0m Loading configuration... done.
[2m[36m(pid=7457)[0m Loading configuration... done.
[2m[36m(pid=7457)[0m Success.
[2m[36m(pid=7457)[0m Loading configuration... done.
[2m[36m(pid=7456)[0m 2019-05-27 21:18:22,497	INFO policy_evaluator.py:311 -- Creating policy evaluation worker 1 on CPU (please ignore any CUDA init errors)
[2m[36m(pid=7456)[0m 2019-05-27 21:18:22.499366: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 AVX512F FMA
[2m[36m(pid=7454)[0m 2019-05-27 21:18:22,702	INFO policy_evaluator.py:311 -- Creating policy evaluation worker 3 on CPU (please ignore any CUDA init errors)
[2m[36m(pid=7454)[0m 2019-05-27 21:18:22.7

[2m[36m(pid=7454)[0m Loading configuration... done.
[2m[36m(pid=7454)[0m Success.
[2m[36m(pid=7454)[0m Loading configuration... done.
[2m[36m(pid=7457)[0m Loading configuration... done.
[2m[36m(pid=7457)[0m Success.
[2m[36m(pid=7457)[0m Loading configuration... done.
[2m[36m(pid=7456)[0m 2019-05-27 21:18:31,288	INFO policy_evaluator.py:474 -- Completed sample batch:
[2m[36m(pid=7456)[0m 
[2m[36m(pid=7456)[0m { 'count': 375,
[2m[36m(pid=7456)[0m   'policy_batches': { 'rl': { 'data': { 'action_prob': np.ndarray((1200,), dtype=float32, min=0.006, max=0.401, mean=0.28),
[2m[36m(pid=7456)[0m                                         'actions': np.ndarray((1200, 1), dtype=float32, min=-2.804, max=2.912, mean=-0.02),
[2m[36m(pid=7456)[0m                                         'advantages': np.ndarray((1200,), dtype=float32, min=0.839, max=94.008, mean=39.505),
[2m[36m(pid=7456)[0m                                         'agent_index': np.ndarray((1200,), 

# Train

In [14]:
# Run one training iteration (dispatches to GAILTrainer._train).
agent.train()

[2m[36m(pid=9192)[0m Loading configuration... done.
[2m[36m(pid=9192)[0m Success.
[2m[36m(pid=9192)[0m Loading configuration... done.
[2m[36m(pid=9189)[0m Loading configuration... done.
[2m[36m(pid=9189)[0m Success.
[2m[36m(pid=9189)[0m Loading configuration... done.
[2m[36m(pid=9191)[0m Loading configuration... done.
[2m[36m(pid=9191)[0m Success.
[2m[36m(pid=9191)[0m Loading configuration... done.
[2m[36m(pid=9189)[0m Loading configuration... done.
[2m[36m(pid=9189)[0m Success.
[2m[36m(pid=9189)[0m Loading configuration... done.
[2m[36m(pid=9191)[0m Loading configuration... done.
[2m[36m(pid=9191)[0m Success.
[2m[36m(pid=9191)[0m Loading configuration... done.
[2m[36m(pid=9192)[0m Loading configuration... done.
[2m[36m(pid=9192)[0m Success.
[2m[36m(pid=9192)[0m Loading configuration... done.


2019-05-28 20:25:11,840	INFO policy_evaluator.py:564 -- Training on concatenated sample batches:

{ 'count': 2250,
  'policy_batches': { 'rl': { 'data': { 'action_prob': np.ndarray((6064,), dtype=float32, min=0.001, max=0.4, mean=0.284),
                                        'actions': np.ndarray((6064, 1), dtype=float32, min=-3.446, max=3.51, mean=0.005),
                                        'advantages': np.ndarray((6064,), dtype=float32, min=-23.186, max=-0.288, mean=-15.745),
                                        'agent_index': np.ndarray((6064,), dtype=int64, min=0.0, max=5.0, mean=2.5),
                                        'behaviour_logits': np.ndarray((6064, 2), dtype=float32, min=-0.003, max=0.004, mean=-0.0),
                                        'dones': np.ndarray((6064,), dtype=bool, min=0.0, max=1.0, mean=0.005),
                                        'eps_id': np.ndarray((6064,), dtype=int64, min=34288929.0, max=1873683181.0, mean=757703806.976),
           

{'config': {'batch_mode': 'truncate_episodes',
  'callbacks': {'on_episode_end': None,
   'on_episode_start': None,
   'on_episode_step': None,
   'on_postprocess_traj': None,
   'on_sample_end': None,
   'on_train_result': None},
  'clip_actions': False,
  'clip_param': 0.3,
  'clip_rewards': None,
  'collect_metrics_timeout': 180,
  'compress_observations': False,
  'custom_resources_per_worker': {},
  'entropy_coeff': 0.0,
  'env': 'MultiWaveAttenuationMergePOEnv-v0',
  'expert_path': '/headless/rl_project/flow_codes/ModelBased/expert_sample',
  'gamma': 0.99,
  'grad_clip': None,
  'horizon': 750,
  'ignore_worker_failures': False,
  'input': 'sampler',
  'input_evaluation': ['is', 'wis'],
  'kl_coeff': 0.2,
  'kl_target': 0.01,
  'lambda': 0.97,
  'local_evaluator_tf_session_args': {'inter_op_parallelism_threads': 8,
   'intra_op_parallelism_threads': 8},
  'log_level': 'INFO',
  'lr': 0.0005,
  'lr_schedule': None,
  'metrics_smoothing_episodes': 100,
  'model': {'conv_activation