In [1]:
% matplotlib inline
import pprint

import numpy as np
import ray
import tensorflow as tf
from ray.rllib.utils.sampler import SyncSampler
# Start a local Ray instance for the remote evaluators created later.
# NOTE(review): `num_workers` is the keyword accepted by the Ray version this
# notebook was written against; newer Ray releases may not accept it — check
# against the installed version.
ray.init(num_workers=32)

Waiting for redis server at 127.0.0.1:19985 to respond...
Waiting for redis server at 127.0.0.1:19181 to respond...
Starting local scheduler with the following resources: {'CPU': 64, 'GPU': 0}.

View the web UI at http://localhost:8889/notebooks/ray_ui54232.ipynb?token=7de1a8f5e4d9154301d266efa13d81b1c2f88b368e1256f9



{'local_scheduler_socket_names': ['/tmp/scheduler90208507'],
 'node_ip_address': '127.0.0.1',
 'object_store_addresses': [ObjectStoreAddress(name='/tmp/plasma_store67042599', manager_name='/tmp/plasma_manager98253140', manager_port=48214)],
 'redis_address': '127.0.0.1:19985',
 'webui_url': 'http://localhost:8889/notebooks/ray_ui54232.ipynb?token=7de1a8f5e4d9154301d266efa13d81b1c2f88b368e1256f9'}

In [2]:
from ray.rllib.envs import create_and_wrap
from ray.rllib.optimizers import Evaluator
from ray.rllib.a3c.common import get_policy_cls
from ray.rllib.utils.filter import get_filter
from ray.rllib.utils.process_rollout import process_rollout

class A3CEvaluator(Evaluator):
    """Actor object that runs environment simulation and returns rollouts.

    Each instance builds its own wrapped environment and its own copy of the
    policy graph. This subclass implements only ``sample`` and
    ``set_weights``; no gradient methods are defined here.

    Attributes:
        env: Environment built by ``create_and_wrap``.
        policy: Copy of graph used for policy. Used by the sampler.
        rew_filter: Reward filter used in rollout post-processing.
        sampler: Component for interacting with environment and generating
            rollouts.
        gamma: Discount factor passed to ``process_rollout``; taken from
            ``config["gamma"]`` when present, otherwise 0.99.
        logdir: Directory for logging.
    """

    def __init__(self, env_creator, config, logdir):
        # NOTE(review): `gym` is not referenced below; possibly kept so gym is
        # imported inside the Ray worker process — confirm before removing.
        import gym
        self.env = env = create_and_wrap(env_creator, config["model"])
        policy_cls = get_policy_cls(config)
        # TODO(rliaw): should change this to be just env.observation_space
        self.policy = policy_cls(env.observation_space.shape, env.action_space)
        obs_filter = get_filter(
            config["observation_filter"], env.observation_space.shape)
        self.rew_filter = get_filter(config["reward_filter"], ())
        self.sampler = SyncSampler(env, self.policy, obs_filter,
                                   config["batch_size"])
        # Was hard-coded in sample(); read it from config (same default) so the
        # discount can be tuned without editing this class.
        self.gamma = config.get("gamma", 0.99)
        self.logdir = logdir

    def sample(self):
        """Collect one batch of experience and post-process it.

        Returns:
            Mapping from field name to array-like values, as produced by
            ``process_rollout`` with discount ``self.gamma``. Every value is
            replaced by its own ``.copy()`` before returning.
        """
        rollout = self.sampler.get_data()
        traj = process_rollout(rollout, self.rew_filter, gamma=self.gamma)
        # NOTE(review): the copies presumably make the returned arrays own
        # their memory instead of aliasing sampler/filter buffers — confirm.
        for k in traj:
            traj[k] = traj[k].copy()
        return traj

    def set_weights(self, weights):
        """Load ``weights`` into this evaluator's policy copy."""
        self.policy.set_weights(weights)

In [3]:
import heapq
from ray.rllib.optimizers import SampleBatch

class Store(object):
    """Bounded replay store that keeps the highest-return rollouts.

    Each rollout is ranked by the key ``(sum of rewards, rollout length,
    hash(batch))``. ``self.replay`` is a ``heapq`` min-heap of those keys, so
    once more than ``size`` keys are held the lowest-ranked rollouts are
    evicted. On every ``add`` a uniformly random stored rollout is also
    discarded with probability ``drop_prob``.

    Attributes:
        replay: Min-heap (``heapq`` list) of ranking keys.
        map: Maps ``hash(batch)`` to the stored batch.
        dataset: Cached concatenation of all stored batches, or None.
        size: Capacity; may be changed after construction (shrinking takes
            effect on the next ``add``).
        dirty: True when ``dataset`` must be rebuilt before sampling.
        drop_prob: Per-``add`` probability of a random eviction.
    """

    def __init__(self, size, drop_prob=0.005):
        self.replay = []
        self.map = {}
        self.dataset = None
        self.size = size
        self.dirty = True
        self.drop_prob = drop_prob

    def add(self, batch):
        """Insert ``batch`` and evict the lowest-ranked entries over capacity.

        ``batch`` must be hashable and support ``batch["rewards"]``.
        """
        batch_hash = hash(batch)
        self.map[batch_hash] = batch
        key = (sum(batch["rewards"]), len(batch["rewards"]), batch_hash)
        # The random draw happens on every call (as before); the emptiness
        # check avoids np.random.choice(0), which raises ValueError.
        if np.random.rand() < self.drop_prob and self.replay:
            drop = np.random.choice(len(self.replay))
            _, _, dropped_hash = self.replay.pop(drop)
            del self.map[dropped_hash]
            # Removing an arbitrary index breaks the heap invariant that
            # heappop relies on to find the minimum; restore it.
            heapq.heapify(self.replay)
        heapq.heappush(self.replay, key)
        # Equivalent to heappushpop when exactly at capacity; the loop also
        # shrinks the store if `size` was lowered after construction.
        while len(self.replay) > self.size:
            _, _, evicted_hash = heapq.heappop(self.replay)
            del self.map[evicted_hash]
        self.dirty = True

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct rows from all stored rollouts.

        Rebuilds the concatenated dataset first if the store changed.

        Returns:
            ``(observations, actions)`` indexed by the same random rows.

        Raises:
            ValueError: (from numpy) if ``batch_size`` exceeds the number of
                stored rows.
        """
        if self.dirty or self.dataset is None:
            print("updating dataset")
            self.dataset = SampleBatch.concat_samples(list(self.map.values()))
            self.dirty = False
        num_rows = self.dataset["actions"].shape[0]
        permutation = np.random.choice(num_rows, size=batch_size, replace=False)
        return (self.dataset["observations"][permutation],
                self.dataset["actions"][permutation])
    

In [4]:
import copy

import gym
import numpy as np
from ray.rllib.a3c import DEFAULT_CONFIG

# Deep copy: setup() edits the nested config["model"] dict, and a shallow
# .copy() would share that dict with DEFAULT_CONFIG, silently modifying the
# library default.
config = copy.deepcopy(DEFAULT_CONFIG)

def setup(num_evaluators=10, env_name="PongDeterministic-v4",
          logdir="/tmp/ray/results", lr=0.01):
    """Build remote rollout actors and a local policy with a supervised loss.

    Mutates the module-level ``config`` (``batch_size`` 5000, LSTM off,
    ``model.dim`` 80) and passes it to every evaluator and the local model.

    Args:
        num_evaluators: Number of remote ``A3CEvaluator`` actors to start.
        env_name: Gym environment id passed to ``gym.make``.
        logdir: Log directory handed to each evaluator.
        lr: Learning rate of the plain gradient-descent optimizer.

    Returns:
        ``(local_model, loss, optimize, labels, evaluators)``:
        - ``local_model``: the local policy;
        - ``loss``: mean sparse softmax cross-entropy between its logits and
          ``labels``;
        - ``optimize``: the op minimizing ``loss``;
        - ``labels``: the int32 ``[None]`` label placeholder;
        - ``evaluators``: the list of actor handles.
    """
    config["batch_size"] = 5000
    config["use_lstm"] = False
    # `config` may be a shallow copy of DEFAULT_CONFIG (e.g. via .copy()), in
    # which case config["model"] is the library's own dict; copy it before
    # mutating so the default is left untouched.
    config["model"] = dict(config["model"])
    config["model"]["dim"] = 80
    remote_evaluator_cls = ray.remote(A3CEvaluator)
    env_creator = lambda: gym.make(env_name)
    evaluators = [remote_evaluator_cls.remote(env_creator, config, logdir)
                  for _ in range(num_evaluators)]
    # A local evaluator is built only to obtain a policy graph/session.
    local_model = A3CEvaluator(env_creator, config, logdir).policy

    with local_model.g.as_default():
        labels = tf.placeholder("int32", [None])
        opt = tf.train.GradientDescentOptimizer(lr)
        loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=local_model.logits, labels=labels))
        optimize = opt.minimize(loss)
        # Initializes every variable in the graph, including the policy's.
        local_model.sess.run(tf.global_variables_initializer())
    return local_model, loss, optimize, labels, evaluators

# Start the remote evaluators and build the local policy plus its loss/train ops.
local_model, loss, optimize, labels, evaluators = setup()

[2017-12-18 10:03:09,580] Making new env: PongDeterministic-v4


Observation shape is (210, 160, 3)
Assuming Atari pixel env, using AtariPixelPreprocessor.
Setting up loss


In [5]:
# def onehot(a):
#     b = np.zeros((a.size, a.max()+1))
#     b[np.arange(a.size), a] = 1
#     return b

def sgd(model, obs, actions, loss_op=None, optimize_op=None, labels_ph=None):
    """Run one supervised gradient step of ``model`` towards ``actions``.

    Prints the mean loss, as before.

    Args:
        model: Policy exposing ``sess`` and the observation placeholder ``x``.
        obs: Observation batch fed to ``model.x``.
        actions: Integer action labels fed to the label placeholder.
        loss_op: Loss tensor to fetch; defaults to the notebook-level ``loss``.
        optimize_op: Training op to run; defaults to the notebook-level
            ``optimize``.
        labels_ph: Label placeholder; defaults to the notebook-level
            ``labels``.

    Returns:
        The loss value returned by the same ``sess.run`` call that applies
        the update.
    """
    if loss_op is None:
        loss_op = loss
    if optimize_op is None:
        optimize_op = optimize
    if labels_ph is None:
        labels_ph = labels
    # Feed through `model` (not a global) so the function trains whatever
    # model it is given.
    _, loss_value = model.sess.run([optimize_op, loss_op], feed_dict={
        model.x: obs,
        labels_ph: actions})
    print(np.mean(loss_value))
    return loss_value

def update_store(storage, evaluators, num_rollouts):
    """Collect exactly ``num_rollouts`` rollouts from ``evaluators`` into ``storage``.

    Keeps at most one ``sample`` task in flight per evaluator and launches
    only as many tasks as rollouts are requested. As a result, no sample
    tasks are left running, with their results discarded, when this returns.

    Args:
        storage: Object with an ``add`` method taking a ``SampleBatch``
            (e.g. ``Store``).
        evaluators: Ray actor handles exposing ``sample``.
        num_rollouts: Number of rollouts to collect.

    Returns:
        Mean over collected rollouts of ``sum(traj["rewards"])``; ``nan``
        (numpy mean of an empty list) if ``num_rollouts`` <= 0.

    Raises:
        ValueError: if rollouts are requested but ``evaluators`` is empty.
    """
    pending = {}
    launched = 0
    for e in evaluators:
        if launched >= num_rollouts:
            break
        pending[e.sample.remote()] = e
        launched += 1
    if num_rollouts > 0 and not pending:
        raise ValueError("update_store needs at least one evaluator")

    all_rets = []
    while pending:
        [fut], _ = ray.wait(list(pending))
        e = pending.pop(fut)
        # Refill this evaluator only while more rollouts are still needed.
        if launched < num_rollouts:
            pending[e.sample.remote()] = e
            launched += 1
        traj = ray.get(fut)
        all_rets.append(sum(traj["rewards"]))
        storage.add(SampleBatch(**traj))
    return np.mean(all_rets)
    
def train_setup(model, evaluators):
    """Broadcast ``model``'s current weights to every remote evaluator.

    The weights are put into the object store once and shared by all actors.

    Returns:
        List of object IDs of the ``set_weights`` calls; callers may
        ``ray.get`` them to block until every actor has the new weights.
    """
    weights_id = ray.put(model.get_weights())
    return [e.set_weights.remote(weights_id) for e in evaluators]

In [6]:
# Initial fill of a 20-entry store from 100 rollouts. train_setup() has not
# been called yet, so each evaluator still samples with its own initially
# constructed policy weights rather than local_model's.
store = Store(20)
# The returned mean of per-rollout reward sums is not kept here.
update_store(store, evaluators, 100)

# Show retained (reward sum, length, hash) keys and the stored batch hashes.
store.replay, store.map.keys()

([(-20.0, 992, -9223363289569234309),
  (-20.0, 1080, 8747285541601),
  (-20.0, 1042, -9223363289569234386),
  (-20.0, 1090, -9223363289569234260),
  (-19.0, 1014, 8747285541384),
  (-19.0, 997, 8747285541433),
  (-20.0, 1079, -9223363289570314408),
  (-19.0, 984, 8747285541510),
  (-19.0, 921, -9223363289569247309),
  (-18.0, 1056, -9223363289569234337),
  (-19.0, 1062, 8747285541412),
  (-19.0, 1089, -9223363289569247274),
  (-17.0, 1191, -9223363289569234344),
  (-19.0, 966, -9223363289569234407),
  (-18.0, 1104, -9223363289570314471),
  (-18.0, 1185, 8747285541489),
  (-18.0, 1190, -9223363289569234393),
  (-18.0, 1111, -9223363289569234302),
  (-19.0, 997, 8747285541538),
  (-18.0, 1212, -9223363289569234323)],
 dict_keys([8747285541433, -9223363289569234323, -9223363289569234309, 8747285541510, -9223363289569234260, 8747285541538, -9223363289569234337, -9223363289570314408, 8747285541601, 8747285541489, -9223363289569234344, -9223363289569234393, -9223363289569234302, 87472855413

In [None]:
# pprint is used below; on a fresh kernel it was otherwise only imported in a
# later cell, which raised NameError here.
import pprint

NUM_ITERATIONS = 30       # outer loop: fit, sync weights, collect rollouts
SGD_STEPS_PER_ITER = 20   # supervised steps on the store per iteration
SGD_BATCH_SIZE = 256      # rows sampled from the store per step
ROLLOUTS_PER_ITER = 300   # rollouts collected after each weight sync

# Grow capacity from the 20 entries used for the initial fill.
store.size = 40
for iteration in range(NUM_ITERATIONS):
    for _ in range(SGD_STEPS_PER_ITER):
        obs, actions = store.sample(SGD_BATCH_SIZE)
        sgd(local_model, obs, actions)

    train_setup(local_model, evaluators)
    average_return = update_store(store, evaluators, ROLLOUTS_PER_ITER)
    print("Average rollout:", average_return)

    pprint.pprint((store.replay, store.map.keys()))

updating dataset
1.79273
1.79393
1.79451
1.79193
1.79128
1.78981
1.79281
1.79185
1.79374
1.79108
1.79325
1.79326
1.79191
1.79214
1.79055
1.79188
1.79137
1.79184
1.7911
1.79173
Average rollout: -20.37
([(-18.0, 1089, 8747285541731),
  (-18.0, 1105, -9223363289576704466),
  (-18.0, 1099, -9223363289569234137),
  (-18.0, 1147, -9223363289576704438),
  (-18.0, 1105, -9223363289576704354),
  (-18.0, 1121, 8747284461383),
  (-17.0, 1210, 8747284461397),
  (-17.0, 1142, 8747278071479),
  (-18.0, 1179, -9223363289569234067),
  (-17.0, 1157, -9223363289569234253),
  (-17.0, 1270, -9223363289569234365),
  (-18.0, 1217, 8747284461334),
  (-18.0, 1193, 8747284461425),
  (-17.0, 1222, -9223363289569234109),
  (-17.0, 1214, 8747285541717),
  (-17.0, 1270, 8747285541531),
  (-17.0, 1252, -9223363289570314457),
  (-17.0, 1213, -9223363289570314408),
  (-18.0, 1342, -9223363289576704298),
  (-17.0, 1196, 8747278071339),
  (-17.0, 1200, -9223363289569247281),
  (-17.0, 1245, -9223363289569234407),
  (-1

1.78946
1.79218
1.79181
1.79011
1.78912
1.79085
1.79213
1.78731
1.79219
1.79219
1.79032
1.79068
1.7926
1.79246
1.78993
1.79285
1.79031
1.79089
1.78994
1.79093
Average rollout: -20.4066666667
([(-18.0, 1217, 8747284461334),
  (-18.0, 1218, -9223363289569234344),
  (-18.0, 1362, -9223363289569234137),
  (-18.0, 1246, -9223363289576704438),
  (-18.0, 1362, -9223363289569234302),
  (-17.0, 1213, -9223363289570314408),
  (-17.0, 1157, -9223363289569234253),
  (-18.0, 1272, 8747278071374),
  (-18.0, 1313, -9223363289569234372),
  (-17.0, 1142, 8747278071479),
  (-17.0, 1214, 8747285541717),
  (-17.0, 1252, -9223363289570314457),
  (-18.0, 1175, -9223363289576704333),
  (-17.0, 1196, 8747278071339),
  (-17.0, 1200, -9223363289569247281),
  (-17.0, 1191, -9223363289465453308),
  (-17.0, 1270, -9223363289569234365),
  (-18.0, 1342, -9223363289576704298),
  (-18.0, 1363, 8747278071458),
  (-17.0, 1150, -9223363289569234067),
  (-17.0, 1246, 8747285541461),
  (-17.0, 1249, -9223363289569234351),


1.79141
1.78982
1.79199
1.79301
1.79073
1.78803
1.79135
1.79079
1.78772
1.79216
1.7909
1.7905
1.79052
1.79665
1.79156
1.78976
1.7902
1.78994
1.79014
1.79125
Average rollout: -20.3733333333
([(-18.0, 1271, -9223363289569247274),
  (-18.0, 1272, 8747278071374),
  (-18.0, 1342, -9223363289576704298),
  (-18.0, 1363, 8747278071458),
  (-18.0, 1302, 8747285541524),
  (-17.0, 1157, -9223363289569234253),
  (-17.0, 1191, -9223363289465453308),
  (-17.0, 1245, -9223363289569234407),
  (-17.0, 1142, 8747278071479),
  (-18.0, 1362, -9223363289569234137),
  (-17.0, 1252, -9223363289570314457),
  (-18.0, 1175, -9223363289576704333),
  (-17.0, 1196, 8747278071339),
  (-17.0, 1200, -9223363289569247281),
  (-17.0, 1270, -9223363289569234365),
  (-17.0, 1274, -9223363289576704480),
  (-17.0, 1246, -9223363289576704319),
  (-18.0, 1313, -9223363289569234372),
  (-17.0, 1181, -9223363289569234344),
  (-17.0, 1213, -9223363289570314408),
  (-16.0, 1407, 8747278071437),
  (-17.0, 1432, 8747285541447),
  

1.79289
1.7884
1.79061
1.79012
1.78959
1.79263
1.78948
1.79155
1.79166
1.78826
1.7905
1.78887
1.79197
1.79412
1.79133
1.79019
1.78956
1.79059
1.79291
1.79205
Average rollout: -20.36
([(-17.0, 1118, 8747284461327),
  (-17.0, 1142, 8747278071479),
  (-17.0, 1136, 8747285541717),
  (-17.0, 1120, -9223363289569234309),
  (-17.0, 1157, -9223363289569234253),
  (-17.0, 1191, -9223363289465453308),
  (-17.0, 1245, -9223363289569234407),
  (-18.0, 1362, -9223363289569234302),
  (-17.0, 1181, -9223363289569234344),
  (-17.0, 1174, 8747284461369),
  (-17.0, 1164, -9223363289576704361),
  (-17.0, 1196, 8747278071339),
  (-17.0, 1200, -9223363289569247281),
  (-17.0, 1270, -9223363289569234365),
  (-18.0, 1313, -9223363289569234372),
  (-18.0, 1472, 8747275648737),
  (-17.0, 1213, -9223363289570314408),
  (-17.0, 1252, -9223363289570314457),
  (-17.0, 1186, 8747275648751),
  (-17.0, 1185, -9223363289569247239),
  (-16.0, 1453, -9223363289569234267),
  (-17.0, 1466, 8747278071528),
  (-17.0, 1538, 

In [10]:
# Evaluate (no update is run) the current loss on a larger batch from the store.
obs, actions = store.sample(1024)
# `labels` is an int32 [None] placeholder holding sparse labels, so feed the
# integer actions directly, as sgd() does. One-hot labels would not match that
# shape, and the onehot helper is commented out anyway.
cost = np.mean(local_model.sess.run(
    loss, feed_dict={local_model.x: obs, labels: actions}))
print(cost)


updating dataset
1.79157


In [22]:
# NOTE: despite the name, these are the policy's raw logits for `obs`, not
# sampled actions.
samples = local_model.sess.run(local_model.logits,  feed_dict={local_model.x: obs})

In [23]:
# Display the logit array computed in the previous cell.
samples

array([[ 0.01429953,  0.02751439, -0.01381383, -0.00753811, -0.0008716 ,
         0.01298719],
       [ 0.01569938,  0.02824248, -0.01411416, -0.00628738, -0.00153024,
         0.01059628],
       [ 0.01528305,  0.02772135, -0.01380849, -0.00699508, -0.00136004,
         0.01342287],
       ..., 
       [ 0.01580028,  0.02654697, -0.01469329, -0.00802835, -0.00080956,
         0.01403822],
       [ 0.01590056,  0.02639264, -0.01451672, -0.00747626, -0.00021473,
         0.01345342],
       [ 0.01626278,  0.02708155, -0.01331143, -0.00774199, -0.00036311,
         0.01301794]], dtype=float32)

In [26]:
# Scratch check: with an int argument, np.random.choice draws one value from range(5).
np.random.choice(5)

0

In [11]:
# NOTE(review): scratch cell. `pprint` is also used by the training-loop cell
# (In [None]) earlier in the notebook. Make sure it is imported before that
# cell runs on a fresh kernel.
import pprint
pprint.pprint({3: 5})

{3: 5}
