In [12]:
import gym, pickle, argparse, json, logging
from copy import deepcopy
import ray
from gail import GailTrainer

from ray import tune
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG
from ray.rllib.agents import Trainer
from ray.rllib.evaluation import PolicyEvaluator, SampleBatch, MultiAgentSampleBatchBuilder
from ray.rllib.offline.json_writer import JsonWriter
from ray.rllib.offline.json_reader import JsonReader
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.evaluation.metrics import collect_metrics
from ray.tune.registry import register_env
from ray.rllib.utils.annotations import override

from flow.utils.registry import make_create_env
from flow.utils.rllib import FlowParamsEncoder, get_flow_params
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Start Ray. ignore_reinit_error=True allows this cell to be re-run in a live
# kernel: a repeated call only logs "Calling ray.init() again after it has
# already been called." (visible in the saved output of this cell).
ray.init(ignore_reinit_error=True)

2019-05-26 07:49:55,841	ERROR worker.py:1343 -- Calling ray.init() again after it has already been called.


In [2]:
# --- Experiment constants (consumed by the config cell below) ---

# Which flow.benchmarks.<name> module to load, and a label for this run.
benchmark_name = "multi_merge"
exp_name = "test_ir"

# Rollout sizing: the worker count is min(num_cpus, num_rollouts) and the
# train batch size is horizon * num_rollouts.
num_cpus = 3
num_rollouts = 3
horizon = 750

# Number of iterations.
num_iter = 10

# Hyperparameters: step_size is stored as config["lr"], gae_lambda as
# config["lambda"].
step_size = 5e-4
gae_lambda = 0.97

In [3]:
# --- PPO configuration -------------------------------------------------------
# Work on a private copy so the library-level DEFAULT_CONFIG is not mutated.
config = deepcopy(DEFAULT_CONFIG)
config["num_workers"] = min(num_cpus, num_rollouts)
config["train_batch_size"] = horizon * num_rollouts
# Floor division keeps this an int: `horizon / 2` is true division on
# Python 3 and would store the float 375.0 for a step-count setting.
config["sample_batch_size"] = horizon // 2
config["use_gae"] = True
config["horizon"] = horizon
config["lambda"] = gae_lambda
config["lr"] = step_size
config["vf_clip_param"] = 1e6
config["num_sgd_iter"] = 10
config['clip_actions'] = False  # FIXME(ev) temporary ray bug
config["model"]["fcnet_hiddens"] = [128, 64, 32]
config["observation_filter"] = "NoFilter"
config["entropy_coeff"] = 0.0

# --- Environment -------------------------------------------------------------
# `benchmark_name` comes from the constants cell; it is deliberately not
# re-assigned here so there is a single place to change it.
# A non-empty `fromlist` makes __import__ return the submodule itself rather
# than the top-level `flow` package.
benchmark = __import__(
    "flow.benchmarks.%s" % benchmark_name, fromlist=["flow_params"])
flow_params = benchmark.buffered_obs_flow_params
# NOTE: this edits the benchmark module's params object in place.
flow_params["env"].additional_params["buf_length"] = 1
create_env, env_name = make_create_env(params=flow_params, version=0)
# Instantiate the env once; its observation/action spaces are needed for the
# policy spec below.
env = create_env()
register_env(env_name, create_env)

# --- Multi-agent setup: every agent id maps to one shared PPO policy ---------
default_policy = (PPOPolicyGraph, env.observation_space, env.action_space, {})
policy_graph = {DEFAULT_POLICY_ID: default_policy}
config["multiagent"] = {
    'policy_graphs': policy_graph,
    'policy_mapping_fn': tune.function(lambda agent_id: DEFAULT_POLICY_ID),
}

In [4]:
# Build the trainer from the PPO-derived config and the registered env name.
# GailTrainer is imported from the local `gail` module; its constructor
# contract is not visible in this notebook.
agent = GailTrainer(config, env_name)

2019-05-26 07:47:23,872	INFO policy_evaluator.py:311 -- Creating policy evaluation worker 0 on CPU (please ignore any CUDA init errors)
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
2019-05-26 07:47:25,210	INFO policy_evaluator.py:728 -- Built policy map: {'default_policy': <ray.rllib.agents.ppo.ppo_policy_graph.PPOPolicyGraph object at 0x7fb6d2f82400>}
2019-05-26 07:47:25,211	INFO policy_evaluator.py:729 -- Built preprocessor map: {'default_policy': <ray.rllib.models.preprocessors.NoPreprocessor object at 0x7fb6d2f82048>}
2019-05-26 07:47:25,212	INFO policy_evaluator.py:343 -- Built filter map: {'default_policy': <ray.rllib.utils.filter.NoFilter object at 0x7fb6d2fb66a0>}


In [5]:
# Restore previously saved policy weights from disk.
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
with open('1buf.wgt', 'rb') as f:
    weights = pickle.load(f)

# The loaded mapping is expected to hold the policy under 'default'; move it
# to DEFAULT_POLICY_ID, the key used in this notebook's policy map.
weights[DEFAULT_POLICY_ID] = weights.pop('default')

In [6]:
# Hand the weights loaded from '1buf.wgt' to the trainer.
agent.set_weights(weights)

In [21]:
# JSON writer targeting the ./expart_sample directory (the same path is read
# back by the JsonReader cell further down).
writer = JsonWriter("./expart_sample")

In [22]:
# Collect one batch from the trainer's local evaluator, then display the
# batch's `count` as the cell output.
samples = agent.local_evaluator.sample()
samples.count

In [23]:
# Persist the collected batch through the JSON writer (the saved log shows a
# new output-*.json file being created under expart_sample/).
writer.write(sample_batch=samples)

2019-05-26 07:53:14,732	INFO json_writer.py:97 -- Writing to new output file <_io.TextIOWrapper name='/headless/rl_project/flow_codes/ModelBased/expart_sample/output-2019-05-26_07-53-14_worker-0_0.json' mode='w' encoding='UTF-8'>


In [24]:
# Read the batch back from ./expart_sample and display its `count`
# (742 in the saved output).
reader = JsonReader("./expart_sample")
sample = reader.next()
sample.count

2019-05-26 07:53:18,468	INFO json_reader.py:65 -- Found 1 input files.


742

In [27]:
# Look up the single shared policy object by its id.
policy = agent.policies[DEFAULT_POLICY_ID]

In [28]:
# Run the policy's trajectory postprocessing on the batch read back from disk.
# The returned SampleBatch is not stored; it is only shown as the cell output.
policy.postprocess_trajectory(sample)

SampleBatch({'agent_index': array([4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
       4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
       4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
       6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
       6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
       6, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0