# PPO with CustomADTEnvContinuous

In [None]:
# Third-party / stdlib packages used throughout the notebook.
# NOTE(review): `os` appears twice on this line; pybullet_envs is never referenced below,
# presumably imported only for its side effect of registering gym environments — confirm.
import datetime,gym,os,pybullet_envs,time,os,psutil,ray
import numpy as np
import tensorflow as tf
# Project-local helpers: TF session / warning utilities and the PPO building blocks.
from util import gpu_sess,suppress_tf_warning
from ppo import PPOBuffer,create_ppo_model,create_ppo_graph,update_ppo,\
    save_ppo_model,restore_ppo_model
np.set_printoptions(precision=2)
# Silence TF and gym (40 == ERROR level) before the environment package is imported.
suppress_tf_warning() # suppress warning 
gym.logger.set_level(40) # gym logger 

# Module-level names also used elsewhere at top level (e.g. `Agents` in the red_list config).
from episci.environment_wrappers.tactical_action_adt_env_continuous import CustomADTEnvContinuous
from episci.agents.utils.constants import Agents, RewardType
print ("Packaged loaded. TF version is [%s]."%(tf.__version__))

### Environments

In [None]:
def get_env():
    """
    Create a CustomADTEnvContinuous with the training opponent mix and a shaped reward.

    The episci imports are done locally, presumably so the function also works inside the
    Ray worker processes that call it from their constructor — confirm.
    """
    from episci.environment_wrappers.tactical_action_adt_env_continuous import CustomADTEnvContinuous
    from episci.agents.utils.constants import Agents, RewardType

    # (opponent, value) pairs. The values rise monotonically to 1.0, so they look like
    # cumulative sampling thresholds — NOTE(review): confirm against CustomADTEnvContinuous.
    opponent_table = [
        (Agents.SPOT_4G, 0.15),
        (Agents.SPOT_5G, 0.30),
        (Agents.SPOT_RANDOM, 0.45),
        (Agents.EXPERT_SYSTEM_TRIAL_2, 0.6),
        (Agents.EXPERT_SYSTEM_TRIAL_3_SCRIMMAGE_4, 0.75),
        (Agents.EXPERT_SYSTEM, 1.0),
    ]
    config = {"red_distribution": dict(opponent_table),
              "reward_type": RewardType.SHAPED}
    return CustomADTEnvContinuous(config)

def get_eval_env():
    """
    Create the evaluation environment (used for `eval_env` and by RolloutWorkerClass).

    It uses exactly the same opponent distribution and shaped reward as the training
    environment. It therefore delegates to get_env() instead of keeping a second, silently
    drifting copy of the configuration. This also gets the same local episci imports instead
    of relying on module-level globals. To evaluate against a different mix, build a custom
    config here.
    """
    return get_env()


In [None]:
# Model
hdims = [256,256] # hidden layer sizes passed to create_ppo_model

# Graph
# (Trailing `# x` comments on the lines below record alternative values.)
clip_ratio = 0.2
pi_lr = 1e-4 # 3e-4
vf_lr = 1e-3 # 1e-3
epsilon = 1e-5 # 1e-2

# Buffer
gamma = 0.99 # 0.99
lam = 0.95 # 0.95

# Update
train_pi_iters = 100 # max policy minibatch steps per iteration (early stop on KL)
train_v_iters = 100 # value-function minibatch steps per iteration
target_kl = 0.005 # 0.01
epochs = 1000 # NOTE(review): not referenced by the training loop

# Worker Config
n_cpu,n_workers = 16,15
# total_steps is the number of outer training iterations (rollout + update), not env steps
total_steps,evaluate_every,print_every = 5000,20,5
ep_len_rollout = 5000 # 10000 ; env steps each Ray worker collects per rollout
batch_size = int(2**12) # minibatch size sampled per gradient step

# Evaluation
red_list = [Agents.SPOT_4G,Agents.SPOT_5G,Agents.SPOT_RANDOM,
            Agents.EXPERT_SYSTEM_TRIAL_2,Agents.EXPERT_SYSTEM_TRIAL_3_SCRIMMAGE_4,
            Agents.EXPERT_SYSTEM]
# Integer step cap: it is compared against an integer episode-length counter.
num_eval,max_ep_len_eval = len(red_list),int(15e3)
# Parallel evaluation sends opponent i to workers[i]; fail now rather than after the first rollout.
if num_eval > n_workers:
    raise ValueError("num_eval (%d) must not exceed n_workers (%d)."%(num_eval,n_workers))

In [None]:
class RolloutWorkerClass(object):
    """
    Worker without RAY (for update purposes)

    Holds the central PPO model and its training graph. The main loop trains it via
    self.graph and broadcasts its weights (get_weights()) to the Ray rollout workers.
    """
    def __init__(self,seed=1):
        self.seed = seed
        # Each worker should maintain its own environment
        import gym
        from util import suppress_tf_warning
        suppress_tf_warning() # suppress TF warnings
        gym.logger.set_level(40) # gym logger 
        self.env = get_eval_env()
        self.odim = self.env.observation_space.shape[0]
        self.adim = self.env.action_space.shape[0]
        _ = self.env.reset()

        # Seed BEFORE building any TF op. In TF1 the graph-level seed only affects random ops
        # created after tf.set_random_seed(), so seeding after create_ppo_model() left the
        # weight initializers unseeded. (Assumes create_ppo_model builds into the default
        # graph — the notebook calls tf.reset_default_graph() right before constructing this.)
        tf.set_random_seed(self.seed)
        np.random.seed(self.seed)

        # Initialize PPO
        self.model,self.sess = create_ppo_model(env=self.env,hdims=hdims,output_actv=tf.nn.tanh)
        self.graph = create_ppo_graph(self.model,
                                      clip_ratio=clip_ratio,pi_lr=pi_lr,vf_lr=vf_lr,epsilon=epsilon)
        # Initialize model
        self.sess.run(tf.global_variables_initializer())

        # Flag to initialize assign operations for 'set_weights()'
        self.FIRST_SET_FLAG = True

    def get_action(self,o,deterministic=False):
        """
        Return an action for a single observation `o`.

        Uses model['mu'] when deterministic, otherwise the sampled model['pi'].
        """
        act_op = self.model['mu'] if deterministic else self.model['pi']
        return self.sess.run(act_op, feed_dict={self.model['o_ph']:o.reshape(1,-1)})[0]

    def get_weights(self):
        """
        Get weights

        Returns the current values of pi_vars followed by v_vars (the order set_weights expects).
        """
        weight_vals = self.sess.run(self.model['pi_vars']+self.model['v_vars'])
        return weight_vals

    def set_weights(self,weight_vals):
        """
        Set weights without memory leakage

        The placeholders and assign ops are created once, on the first call, and then reused,
        so repeated calls do not keep adding nodes to the graph. `weight_vals` must be ordered
        like get_weights(): pi_vars followed by v_vars.
        """
        if self.FIRST_SET_FLAG:
            self.FIRST_SET_FLAG = False
            self.assign_placeholders = []
            self.assign_ops = []
            for var in self.model['pi_vars']+self.model['v_vars']:
                ph = tf.placeholder(var.dtype, shape=var.get_shape())
                self.assign_placeholders.append(ph)
                self.assign_ops.append(var.assign(ph))
        # Apply every assignment in a single session call instead of one call per variable
        self.sess.run(self.assign_ops,
                      feed_dict=dict(zip(self.assign_placeholders,weight_vals)))
    
@ray.remote
class RayRolloutWorkerClass(object):
    """
    Rollout Worker with RAY

    Each actor owns its own environment (from get_env()), a local copy of the PPO policy and
    a PPOBuffer. The driver pushes the central weights with set_weights(), collects experience
    with rollout(), and plays one deterministic episode against a chosen opponent with evaluate().
    """
    def __init__(self,worker_id=0,ep_len_rollout=1000):
        # Parse
        self.worker_id = worker_id
        self.ep_len_rollout = ep_len_rollout # env steps collected per rollout() call
        # Each worker should maintain its own environment
        import gym
        from util import suppress_tf_warning
        suppress_tf_warning() # suppress TF warnings
        gym.logger.set_level(40) # gym logger 
        self.env = get_env()
        self.odim = self.env.observation_space.shape[0]
        self.adim = self.env.action_space.shape[0]
        _ = self.env.reset()

        # Create PPO model (the driver overwrites these weights through set_weights())
        self.model,self.sess = create_ppo_model(env=self.env,hdims=hdims,output_actv=tf.nn.tanh)
        self.sess.run(tf.global_variables_initializer())
        # Buffer sized for one rollout of this worker's own length
        self.buf = PPOBuffer(odim=self.odim,adim=self.adim,
                             size=self.ep_len_rollout,gamma=gamma,lam=lam)

        # Flag to initialize assign operations for 'set_weights()'
        self.FIRST_SET_FLAG = True

        # Flag to (re)start the next rollout from a fresh env.reset()
        self.FIRST_ROLLOUT_FLAG = True

    def get_action(self,o,deterministic=False):
        """
        Return an action for a single observation `o`.

        Uses model['mu'] when deterministic, otherwise the sampled model['pi'].
        """
        act_op = self.model['mu'] if deterministic else self.model['pi']
        return self.sess.run(act_op, feed_dict={self.model['o_ph']:o.reshape(1,-1)})[0]

    def set_weights(self,weight_vals):
        """
        Set weights without memory leakage

        The placeholders and assign ops are created once, on the first call, and then reused.
        `weight_vals` must be ordered as pi_vars followed by v_vars.
        """
        if self.FIRST_SET_FLAG:
            self.FIRST_SET_FLAG = False
            self.assign_placeholders = []
            self.assign_ops = []
            for var in self.model['pi_vars']+self.model['v_vars']:
                ph = tf.placeholder(var.dtype, shape=var.get_shape())
                self.assign_placeholders.append(ph)
                self.assign_ops.append(var.assign(ph))
        # Apply every assignment in a single session call
        self.sess.run(self.assign_ops,
                      feed_dict=dict(zip(self.assign_placeholders,weight_vals)))

    def rollout(self):
        """
        Collect self.ep_len_rollout environment steps with the current (sampled) policy.

        The env state persists across calls through self.o. The env is reset:
        - on the first call,
        - whenever an episode ends,
        - after evaluate() has used the env.
        Finished episodes are closed with last_val=0. The trailing partial episode is
        bootstrapped with the value estimate of the current observation.

        Returns self.buf.get(): obs_buf, act_buf, adv_buf, ret_buf, logp_buf.
        """
        if self.FIRST_ROLLOUT_FLAG:
            self.FIRST_ROLLOUT_FLAG = False
            self.o = self.env.reset() # reset environment
        # Loop over this worker's own rollout length. Previously the global ep_len_rollout was
        # used, which mismatches the buffer if the actor was built with another length.
        for _ in range(self.ep_len_rollout):
            a,v_t,logp_t = self.sess.run(
                self.model['get_action_ops'],feed_dict={self.model['o_ph']:self.o.reshape(1,-1)})
            o2, r, d, _ = self.env.step(a[0])
            # save and log
            self.buf.store(self.o,a,r,v_t,logp_t)
            # Update obs (critical!)
            self.o = o2
            if d:
                self.buf.finish_path(last_val=0.0)
                self.o = self.env.reset() # reset when done

        last_val = self.sess.run(self.model['v'],
                                 feed_dict={self.model['o_ph']:self.o.reshape(1,-1)})
        self.buf.finish_path(last_val)
        return self.buf.get() # obs_buf, act_buf, adv_buf, ret_buf, logp_buf

    def evaluate(self,red=None):
        """
        Play one deterministic episode against opponent `red`.

        The episode is capped at max_ep_len_eval steps.

        Returns [ep_ret, ep_len, blue_health, red_health].
        """
        o,d,ep_ret,ep_len = self.env.reset(red=red),False,0,0
        while not d and ep_len < max_ep_len_eval:
            a = self.get_action(o,deterministic=True)
            o,r,d,_ = self.env.step(a)
            ep_ret += r # compute return
            ep_len += 1
        blue_health,red_health = self.env.blue_health,self.env.red_health
        # This episode took over the env that rollout() resumes from via self.o. Without a
        # reset, the next rollout would step a finished env from a stale observation.
        self.FIRST_ROLLOUT_FLAG = True
        eval_res = [ep_ret,ep_len,blue_health,red_health] # evaluation result
        return eval_res

### Initialize Env

In [None]:
eval_env = get_eval_env()
# Flat dimensions of the action and observation spaces
adim = eval_env.action_space.shape[0]
odim = eval_env.observation_space.shape[0]
print ("Environment Ready. odim:[%d] adim:[%d]."%(odim,adim))

### Initialize Workers

In [None]:
# Start Ray with explicit resource limits (sizes in bytes, 1024**3 = 1 GiB).
ray.init(
    num_cpus=n_cpu,
    memory=5*1024**3,
    object_store_memory=10*1024**3,
    driver_object_store_memory=1024**3,
)
# Central learner lives in the driver process, built on a fresh default graph.
tf.reset_default_graph()
R = RolloutWorkerClass(seed=0)
# One Ray actor per rollout worker, each with its own environment and policy copy.
workers = [
    RayRolloutWorkerClass.remote(worker_id=worker_idx,ep_len_rollout=ep_len_rollout)
    for worker_idx in range(n_workers)
]
print ("RAY initialized with [%d] cpus and [%d] workers."%(n_cpu,n_workers))

In [None]:
# Short pause after creating the Ray actors.
# NOTE(review): presumably meant to let actor start-up settle; .remote() returns immediately,
# and this sleep does not wait for the actors' __init__ to finish — confirm it is still needed.
time.sleep(1)

### Training Loop

In [None]:
def _elapsed_str():
    """
    Format the wall-clock time since `start_time` as 'day:[D] HH:MM:SS'.

    D counts full elapsed days starting at 0. time.gmtime()'s day-of-month (%d), used
    previously, is already 1 at zero elapsed time.
    """
    sec = time.time()-start_time
    return "day:[%d] %s"%(int(sec//86400),time.strftime("%H:%M:%S",time.gmtime(sec)))

def _minibatch_feeds(bufs,n_total):
    """
    Build a feed_dict for R.model['all_phs'] from a random minibatch of `bufs`.

    `bufs` is [obs, act, adv, ret, logp], stacked over all workers. Up to `batch_size` rows
    are drawn without replacement.
    """
    rand_idx = np.random.permutation(n_total)[:batch_size]
    return {ph:buf[rand_idx] for ph,buf in zip(R.model['all_phs'],bufs)}

start_time = time.time()
n_env_step = 0 # number of environment steps collected so far (all workers)
npz_path = '../data/net/ppo_adt_continuous/model.npz'
for t in range(int(total_steps)):

    # Synchronize worker weights: store them in the object store once, not once per worker.
    # ray.get also surfaces any exception raised inside set_weights.
    weights_ref = ray.put(R.get_weights())
    ray.get([worker.set_weights.remote(weights_ref) for worker in workers])

    # Make rollouts in parallel; each returns (obs, act, adv, ret, logp) arrays
    t_start = time.time()
    rollout_vals = ray.get([worker.rollout.remote() for worker in workers])
    sec_rollout = time.time() - t_start

    # Stack each field across workers with one concatenate per field
    # (instead of repeatedly growing the arrays pairwise)
    t_start = time.time() # tic
    bufs = [np.concatenate([rval[k] for rval in rollout_vals],axis=0) for k in range(5)]
    n_val_total = bufs[0].shape[0]
    n_env_step += n_val_total # was never incremented before, so '#step' always printed 0

    # Mini-batch policy update with early stopping on the approximate KL divergence
    for pi_iter in range(train_pi_iters):
        _,kl,pi_loss,ent = R.sess.run(
            [R.graph['train_pi'],R.graph['approx_kl'],R.graph['pi_loss'],R.graph['approx_ent']],
            feed_dict=_minibatch_feeds(bufs,n_val_total))
        if kl > 1.5 * target_kl:
            break
    # Mini-batch value-function update
    for _ in range(train_v_iters):
        R.sess.run(R.graph['train_v'],feed_dict=_minibatch_feeds(bufs,n_val_total))
    sec_update = time.time() - t_start # toc

    # Print (pi_iter+1 = number of policy steps actually taken, including the one that hit the KL stop)
    if (t == 0) or (((t+1)%print_every) == 0):
        print ("[%d/%d] rollout:[%.1f]s pi_iter:[%d/%d] update:[%.1f]s kl:[%.4f] target_kl:[%.4f]."%
               (t+1,total_steps,sec_rollout,pi_iter+1,train_pi_iters,sec_update,kl,target_kl))
        print ("   pi_loss:[%.4f], entropy:[%.4f]"%
               (pi_loss,ent))

    # Evaluate
    if (t == 0) or (((t+1)%evaluate_every) == 0):
        ram_percent = psutil.virtual_memory().percent # memory usage
        print ("[Eval. start] step:[%d/%d][%.1f%%] #step:[%.1e] time:[%s] ram:[%.1f%%]."%
               (t+1,total_steps,t/total_steps*100,n_env_step,_elapsed_str(),ram_percent))

        LOCAL_EVAL = 0
        if LOCAL_EVAL:
            # Sequential evaluation with the central learner's own environment
            eval_vals = []
            for red in red_list[:num_eval]:
                o,d,ep_ret,ep_len = R.env.reset(red=red),False,0,0
                while not d and ep_len < max_ep_len_eval:
                    a = R.get_action(o,deterministic=True)
                    o,r,d,_ = R.env.step(a)
                    ep_ret += r # compute return
                    ep_len += 1
                eval_vals.append([ep_ret,ep_len,R.env.blue_health,R.env.red_health])
        else: # parallel evaluation with Ray: opponent i is played by workers[i]
            eval_vals = ray.get([workers[i_idx].evaluate.remote(red=red_list[i_idx])
                                 for i_idx in range(num_eval)])

        ep_ret_sum = 0
        for i_idx,(ep_ret,ep_len,blue_health,red_health) in enumerate(eval_vals):
            ep_ret_sum += ep_ret
            print (" [%d/%d] [%s] ep_ret:[%.4f] ep_len:[%d]. blue health:[%.2f] red health:[%.2f]"
                %(i_idx,num_eval,red_list[i_idx],ep_ret,ep_len,blue_health,red_health))
        ep_ret_avg = ep_ret_sum / num_eval
        print ("[Eval. done] time:[%s] ep_ret_avg:[%.3f]."%(_elapsed_str(),ep_ret_avg))

        # Save (make sure the target directory exists first)
        os.makedirs(os.path.dirname(npz_path),exist_ok=True)
        save_ppo_model(npz_path,R,VERBOSE=False)

print ("Done.")

### Close Environment and Ray

In [None]:
# Release the standalone evaluation environment created in the "Initialize Env" cell.
eval_env.close()

In [None]:
# Stop Ray; this also terminates the remote rollout-worker actors.
ray.shutdown()