# Use Ray with TensorFlow

In [1]:
# One import per line, in the same order as the original single-line import.
import datetime
import gym
import os
import pybullet_envs
import time
import ray
import numpy as np
import tensorflow as tf
from util import suppress_tf_warning

# Show numpy arrays with two decimal places in the printed outputs below.
np.set_printoptions(precision=2)
# Helper from the local `util` module; presumably silences TF log noise -- see util.
suppress_tf_warning()
print ("Packaged loaded. TF version is [%s]."%(tf.__version__))

Packaged loaded. TF version is [1.15.0].


### Initialize Ray

In [2]:
# Value passed to ray.init as `num_cpus`.
n_cpus = 5
# ignore_reinit_error=True keeps this cell re-runnable: without it, executing
# the cell a second time in a live kernel raises because Ray is already running.
ray.init(num_cpus=n_cpus,ignore_reinit_error=True)
print ("RAY initialized with [%d] cpus."%(n_cpus))

2020-08-19 11:01:31,346	INFO resource_spec.py:231 -- Starting Ray with 37.01 GiB memory available for workers and up to 18.52 GiB for objects. You can adjust these settings with ray.init(memory=<bytes>, object_store_memory=<bytes>).
2020-08-19 11:01:31,803	INFO services.py:1193 -- View the Ray dashboard at [1m[32mlocalhost:8265[39m[22m


RAY initialized with [5] cpus.


### TF Model Creator. Note that ```import tensorflow as tf``` should be inside the function.

In [3]:
def create_model(x_dim,y_dim):
    """
    Create a small TF (graph-mode) model with its own graph and session.

    The network is: float32 placeholder [None,x_dim] -> dense(32, relu)
    -> dense(y_dim, no activation), with both layers created under the
    'main' variable scope.

    The model is built in a private tf.Graph instead of the process-wide
    default graph. This way, calling create_model() more than once in the
    same process (e.g. re-running a notebook cell) returns only the
    variables of the model that was just built, not those of earlier models.

    Args:
        x_dim: size of the input feature dimension.
        y_dim: number of output units.

    Returns:
        Tuple (ph_x, y, g_vars, sess):
            ph_x:   input placeholder of shape [None, x_dim].
            y:      output tensor of the last dense layer.
            g_vars: list of this model's variables under the 'main' scope,
                    in creation order.
            sess:   tf.Session bound to this model's graph, with all of its
                    variables already initialized.
    """
    # Imports are kept inside the function, as the notebook requires;
    # presumably so the function is self-contained inside Ray worker processes.
    import tensorflow as tf
    from util import suppress_tf_warning
    suppress_tf_warning()

    graph = tf.Graph()
    with graph.as_default():
        # Build a simple two-layer model
        ph_x = tf.placeholder(tf.float32,shape=[None,x_dim])
        with tf.variable_scope('main'):
            net = tf.layers.dense(inputs=ph_x,units=32,activation=tf.nn.relu)
            y = tf.layers.dense(inputs=net,units=y_dim,activation=None)
        # Prefix match (not substring match) on this graph's variables only
        g_vars = [v for v in tf.compat.v1.global_variables() if v.name.startswith('main/')]
        init_op = tf.global_variables_initializer()

    # Have own session, attached to the private graph
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True # allocate GPU memory on demand
    sess = tf.Session(graph=graph,config=config)

    # Initialize weights
    sess.run(init_op)
    return ph_x,y,g_vars,sess
print ("Ready.")

Ready.


### Rollout Worker with a TF Model inside

In [4]:
@ray.remote
class RolloutWorkerClass(object):
    """
    Ray actor that owns one TF model (built by create_model) and its session.

    The ops used to overwrite the model variables are created once, in
    __init__. Creating a new tf.assign op on every set_weights() call would
    add nodes to the graph each time, so the graph would keep growing.
    """
    def __init__(self,worker_id=0,x_dim=5,y_dim=2):
        """
        Args:
            worker_id: identifier stored on the actor as `self.worker_id`.
            x_dim: input dimension forwarded to create_model.
            y_dim: output dimension forwarded to create_model.
        """
        # Local import, following the notebook's convention for code that runs
        # inside Ray worker processes.
        import tensorflow as tf
        self.worker_id = worker_id
        # Make TF Model
        self.ph_x,self.y,self.g_vars,self.sess = create_model(x_dim=x_dim,y_dim=y_dim)
        # One placeholder and one assign op per variable, built a single time
        # in the graph that the session actually runs.
        with self.sess.graph.as_default():
            self.ph_weights = [
                tf.placeholder(g_var.dtype.base_dtype,shape=g_var.shape)
                for g_var in self.g_vars]
            self.assign_ops = [
                tf.assign(g_var,ph_w)
                for g_var,ph_w in zip(self.g_vars,self.ph_weights)]
    def get_weights(self):
        """
        Get weights of 'g_vars'

        Returns:
            List with the current value of each variable in `self.g_vars`,
            in the same order.
        """
        return self.sess.run(self.g_vars)
    def set_weights(self,weight_list):
        """
        Set weights of 'g_vars'

        Args:
            weight_list: sequence of values, one per variable in
                `self.g_vars` and in the same order. As before, entries
                beyond len(self.g_vars) are ignored and a shorter sequence
                raises IndexError.
        """
        feed_dict = {ph_w:weight_list[g_idx]
                     for g_idx,ph_w in enumerate(self.ph_weights)}
        # Reuse the pre-built assign ops; nothing is added to the graph here.
        self.sess.run(self.assign_ops,feed_dict=feed_dict)
    def rollout(self,x):
        """
        Run a forward pass of the model.

        Args:
            x: input batch fed to the `ph_x` placeholder.

        Returns:
            Value of the model output `y` for the given input.
        """
        return self.sess.run(self.y,feed_dict={self.ph_x:x})

print("Ready.")

Ready.


### Initialize Workers

In [5]:
# Model input/output sizes, shared by the workers and the central model.
x_dim = 64
y_dim = 4
n_workers = 5
# Instantiate one RolloutWorkerClass actor per worker id.
workers = [
    RolloutWorkerClass.remote(worker_id=worker_id,x_dim=x_dim,y_dim=y_dim)
    for worker_id in range(n_workers)
]
print ("[%d] workers initialized."%(n_workers))

[5] workers initialized.


### Initialize a Central Worker

In [6]:
# Build the driver-side model; only its variable list and session are used here.
central_model = create_model(x_dim=x_dim,y_dim=y_dim)
g_vars,sess = central_model[2:]
# Snapshot of the central model's current variable values.
weights = sess.run(fetches=g_vars)

### Rollouts and Check the results

In [7]:
# A single random input row of width x_dim, uniform in [0, 1); reused by every rollout below.
x_rand = np.random.random_sample((1,x_dim))

In [8]:
# Dispatch the same input to every actor; each .remote() call returns at once.
rollout_list = [
    actor.rollout.remote(x=x_rand)
    for actor in workers
]
# Block until all rollouts are done and collect their outputs in worker order.
rollout_res_list = ray.get(rollout_list)

## All Rollout results are DIFFERENT as weights are all different!

In [9]:
# Print each worker's output next to its index.
for r_idx in range(len(rollout_res_list)):
    rollout_res = rollout_res_list[r_idx]
    print ("Rollout result of [%d] worker is:\n %s"%(r_idx,rollout_res))

Rollout result of [0] worker is:
 [[-0.82 -0.66 -0.73  0.71]]
Rollout result of [1] worker is:
 [[-0.16  0.12 -0.36 -0.72]]
Rollout result of [2] worker is:
 [[-0.74  0.35 -0.26 -0.12]]
Rollout result of [3] worker is:
 [[ 0.35  0.21 -0.16  0.47]]
Rollout result of [4] worker is:
 [[ 0.26  0.35 -0.63  0.39]]


## Assign the same weights to all workers

In [10]:
# Send the central weights to every worker.
set_weights_list = [worker.set_weights.remote(weights) for worker in workers] # non-block
# Wait for all updates to finish; fetching the results also re-raises here any
# exception thrown inside a worker's set_weights, which would otherwise be lost.
ray.get(set_weights_list)
get_weights_list = [worker.get_weights.remote() for worker in workers] # non-block
weights_list = ray.get(get_weights_list)
# Sanity check: every worker must now hold exactly the central weights.
for w_idx,worker_weights in enumerate(weights_list):
    assert all(np.array_equal(w_c,w_w) for w_c,w_w in zip(weights,worker_weights)), \
        "Weights of worker [%d] differ from the central weights."%(w_idx)

### Rollouts and Check the results

In [11]:
# Repeat the rollout with the same input now that the weights are synchronized.
rollout_list = [
    actor.rollout.remote(x=x_rand)
    for actor in workers
]
rollout_res_list = ray.get(rollout_list)
# Print each worker's output next to its index.
for r_idx in range(len(rollout_res_list)):
    rollout_res = rollout_res_list[r_idx]
    print ("Rollout result of [%d] worker is:\n %s"%(r_idx,rollout_res))

Rollout result of [0] worker is:
 [[0.32 0.17 0.55 0.39]]
Rollout result of [1] worker is:
 [[0.32 0.17 0.55 0.39]]
Rollout result of [2] worker is:
 [[0.32 0.17 0.55 0.39]]
Rollout result of [3] worker is:
 [[0.32 0.17 0.55 0.39]]
Rollout result of [4] worker is:
 [[0.32 0.17 0.55 0.39]]


### Shutdown Ray

In [12]:
# Stop the Ray runtime that was started by ray.init() at the top of the notebook.
ray.shutdown()
print("RAY shutdown.")

RAY shutdown.
