# SAC with ADT Continuous Env

In [1]:
import datetime,gym,time,os,psutil,ray
import numpy as np
import tensorflow as tf
from util import open_txt,write_txt
from sac import ReplayBuffer,create_sac_model,create_sac_graph,\
    save_sac_model_and_buffers,restore_sac_model_and_buffers
from episci.environment_wrappers.tactical_action_adt_env_continuous \
    import CustomADTEnvContinuous
from episci.agents.utils.constants import Agents,RewardType,StateInfo
# Confirm that every import above succeeded and report the TensorFlow version in use
print ("Packaged loaded. TF version is [%s]."%tf.__version__)

Packaged loaded. TF version is [1.15.0].


### Hyperparameters

In [2]:
# Worker
exp_name = 'sac_adt_cont' # experiment tag used in the log and checkpoint paths
n_cpu = 31 # number of CPUs handed to ray.init()
n_workers = 30 # number of remote rollout workers

# Environment
action_length = 5 # 50/5 = 10HZ
# Red agent distribution for training
# (every worker rolls out against each of these opponents in turn)
red_list_train = [
    Agents.ZOMBIE,
    Agents.SPOT_RANDOM,
    Agents.BUD_FSM,
    Agents.EXPERT_SYSTEM
]
# Red agent distribution for evaluation
# The base list is repeated and then truncated so that there is exactly one
# opponent per worker (evaluation worker i faces red_list_eval[i]).
red_list_eval = [
    Agents.ZOMBIE,
    Agents.ROSIE, 
    Agents.BUD, 
    Agents.BUD_FSM, 
    Agents.EXPERT_SYSTEM
]*n_workers
red_list_eval = red_list_eval[:n_workers]
num_eval = len(red_list_eval) # evaluation

# Steps
total_steps = 5000 # number of outer training iterations
evaluate_every = 5 # evaluate (and checkpoint) every this many iterations
print_every = 5 # print training statistics every this many iterations
ep_len_rollout = 3000 # 15,000/5
# Per-worker transition buffer: one segment of 'ep_len_rollout' steps per
# training opponent. Derived from 'ep_len_rollout' (instead of repeating the
# literal 3000) so that the two values cannot drift apart.
buffer_size = int(ep_len_rollout*len(red_list_train))

# Network configuration
hdims = [64,32,16] # hidden layer sizes
actv = tf.nn.relu # hidden activation
batch_size = int(2**15) # batch size per update (half long-term, half short-term buffer)
update_count = 1000 # number of gradient updates per training iteration
lr = 1e-5 # 1e-3
epsilon = 1e-5 # forwarded to create_sac_graph()
# SAC
gamma = 0.99 # discount 0.99
alpha_q,alpha_pi = 0.5,0.5 # forwarded to create_sac_graph()
polyak = 0.995 # 0.995
# Buffer
buffer_sz_long,buffer_sz_short = 1e6,1e4 # 1e5,1e5 (cast to int where the buffers are created)
# Temperature
# Each rollout draws its exploration temperature uniformly from [temp_min,temp_max]
temp_min,temp_max = 0.0,0.1

### Environment

In [3]:
def get_env(red_distribution=None):
    """
    Create a CustomADTEnvContinuous environment that uses shaped rewards.

    red_distribution: forwarded unchanged to the environment configuration
        under the key "red_distribution" (None by default).
    Returns the newly constructed environment; the action length is taken
    from the notebook-level 'action_length' setting.
    """
    # NOTE(review): the imports are repeated locally, presumably so that the
    # function also works inside Ray worker processes - confirm before removing.
    from episci.environment_wrappers.tactical_action_adt_env_continuous \
        import CustomADTEnvContinuous
    from episci.agents.utils.constants import Agents, RewardType
    config = dict(red_distribution=red_distribution,
                  reward_type=RewardType.SHAPED)
    env = CustomADTEnvContinuous(config,action_length=action_length)
    return env

### Logger

In [4]:
# Open a time-stamped text log under ../report/log/<exp_name>/
log_stamp = datetime.datetime.now().strftime("%b-%d-%Y-%H:%M:%S")
txt_path = '../report/log/%s/log_%s.txt'%(exp_name,log_stamp)
f = open_txt(txt_path) # handle that is later passed to write_txt()
print ("[%s] created."%(txt_path))
time.sleep(1)

[../report/log/sac_adt_cont] created.
[../report/log/sac_adt_cont/log_Aug-16-2020-03:15:49.txt] created.


### Rollout Workers

In [5]:
class RolloutWorkerClass(object):
    """
    Worker without RAY (for update purposes)

    Holds the full SAC model together with its training graph ('step_ops')
    in the local process. The training loop runs the gradient updates on this
    object and distributes its main-network weights to the Ray workers through
    get_weights() / set_weights().
    """
    def __init__(self,hdims=None,actv=tf.nn.relu,
                 lr=1e-3,gamma=0.99,alpha_q=0.1,alpha_pi=0.1,polyak=0.995,epsilon=1e-2,
                 seed=1):
        """
        hdims: hidden layer sizes (defaults to [256,256])
        actv: hidden activation function
        lr,gamma,alpha_q,alpha_pi,polyak,epsilon: forwarded to create_sac_graph()
        seed: seed for the TensorFlow graph and for NumPy
        """
        # Avoid a shared mutable default argument for the layer sizes
        hdims = [256,256] if hdims is None else hdims
        self.seed = seed
        # Each worker should maintain its own environment
        import gym
        from util import suppress_tf_warning
        suppress_tf_warning() # suppress TF warnings
        gym.logger.set_level(40) 
        self.env = get_env()
        odim,adim = self.env.observation_space.shape[0],self.env.action_space.shape[0]
        self.odim = odim
        self.adim = adim
        _ = self.env.reset(red=Agents.SPOT_RANDOM)
        
        # Seed BEFORE the graph is built: the graph-level seed only influences
        # random ops (e.g., weight initializers) that are created after it is
        # set, so seeding after model creation had no effect on initialization.
        tf.set_random_seed(self.seed)
        np.random.seed(self.seed)
        
        # Create SAC model and computational graph 
        self.model,self.sess = create_sac_model(
            odim=self.odim,adim=self.adim,hdims=hdims,actv=actv)
        self.step_ops,self.target_init = \
            create_sac_graph(self.model,lr=lr,gamma=gamma,alpha_q=alpha_q,alpha_pi=alpha_pi,
                             polyak=polyak,epsilon=epsilon)
        
        # Initialize model 
        self.sess.run(tf.global_variables_initializer())
        self.sess.run(self.target_init)
        
        # Flag to initialize assign operations for 'set_weights()'
        self.FIRST_SET_FLAG = True
    
    def get_action(self,o,deterministic=False):
        """
        Get an action for a single observation 'o' (NumPy array).
        Uses the model's 'mu' output if deterministic, otherwise its 'pi' output.
        """
        act_op = self.model['mu'] if deterministic else self.model['pi']
        return self.sess.run(act_op, feed_dict={self.model['o_ph']:o.reshape(1,-1)})[0]
    
    def set_weights(self,weight_vals):
        """
        Set weights without memory leakage

        weight_vals: one array per variable in model['main_vars'], in the same
            order (i.e., what get_weights() returns).
        Raises ValueError if the number of arrays does not match.
        """
        if self.FIRST_SET_FLAG:
            # Build the placeholders / assign ops only once; creating new ops
            # on every call would keep growing the graph.
            self.FIRST_SET_FLAG = False
            self.assign_placeholders = []
            self.assign_ops = []
            for weight_tf_var in self.model['main_vars']:
                assign_placeholder = tf.placeholder(
                    weight_tf_var.dtype, shape=weight_tf_var.get_shape())
                self.assign_placeholders.append(assign_placeholder)
                self.assign_ops.append(weight_tf_var.assign(assign_placeholder))
        if len(weight_vals) != len(self.assign_ops):
            raise ValueError("Expected [%d] weight arrays but got [%d]."%
                             (len(self.assign_ops),len(weight_vals)))
        # Run every assign in a single session call instead of one call per variable
        feed_dict = dict(zip(self.assign_placeholders,weight_vals))
        self.sess.run(self.assign_ops,feed_dict=feed_dict)

    def get_weights(self):
        """
        Get weights (current values of model['main_vars'], as a list)
        """
        weight_vals = self.sess.run(self.model['main_vars'])
        return weight_vals
    
@ray.remote
class RayRolloutWorkerClass(object):
    """
    Rollout Worker with RAY

    Each actor owns its own environment and its own copy of the SAC model.
    Weights are pushed in through set_weights(); rollout() collects
    transitions for training and evaluate() runs one deterministic episode.
    """
    def __init__(self,worker_id=0,hdims=None,actv=tf.nn.relu,
                 ep_len_rollout=1000,buffer_size=1000):
        """
        worker_id: index used in the ready message
        hdims: hidden layer sizes (defaults to [256,256])
        actv: hidden activation function
        ep_len_rollout: steps per red policy in rollout() and the step cap in evaluate()
        buffer_size: initial capacity (in transitions) of the rollout buffers
        """
        # Avoid a shared mutable default argument for the layer sizes
        hdims = [256,256] if hdims is None else hdims
        # Parse
        self.worker_id = worker_id
        self.ep_len_rollout = ep_len_rollout
        self.buffer_size = buffer_size
        # Each worker should maintain its own environment
        import gym
        from util import suppress_tf_warning
        suppress_tf_warning() # suppress TF warnings
        gym.logger.set_level(40) 
        self.env = get_env()
        odim,adim = self.env.observation_space.shape[0],self.env.action_space.shape[0]
        self.odim = odim
        self.adim = adim
        _ = self.env.reset(red=Agents.SPOT_RANDOM)
        
        # Replay buffers to pass
        self._allocate_buffers(self.buffer_size)
        
        # Create SAC model
        self.model,self.sess = create_sac_model(
            odim=self.odim,adim=self.adim,hdims=hdims,actv=actv)
        self.sess.run(tf.global_variables_initializer())
        print ("Ray Worker [%d] Ready."%(self.worker_id))
        
        # Flag to initialize assign operations for 'set_weights()'
        self.FIRST_SET_FLAG = True
        
        # Flag to initialize rollout
        self.FIRST_ROLLOUT_FLAG = True
        
    def _allocate_buffers(self,size):
        """
        (Re)allocate the zero-filled transition buffers for 'size' transitions
        """
        self.buffer_size = size
        self.o_buffer = np.zeros((size,self.odim))
        self.a_buffer = np.zeros((size,self.adim))
        self.r_buffer = np.zeros((size))
        self.o2_buffer = np.zeros((size,self.odim))
        self.d_buffer = np.zeros((size))
        
    def get_action(self,o,deterministic=False,temperature=1.0):
        """
        Get action (if temperature is 0, it becomes deterministic)

        Returns the model's 'mu' output if deterministic; otherwise the blend
        temperature*pi + (1-temperature)*mu of the 'pi' and 'mu' outputs.
        """
        feed_dict = {self.model['o_ph']:o.reshape(1,-1)}
        if deterministic:
            # Only 'mu' is needed; skip evaluating 'pi'
            return self.sess.run(self.model['mu'],feed_dict=feed_dict)[0]
        # Fetch both outputs with a single session call
        a_mu,a_pi = self.sess.run([self.model['mu'],self.model['pi']],
                                  feed_dict=feed_dict)
        return temperature*a_pi[0] + (1-temperature)*a_mu[0]
    
    def set_weights(self,weight_vals):
        """
        Set weights without memory leakage

        weight_vals: one array per variable in model['main_vars'], in the same order.
        Raises ValueError if the number of arrays does not match.
        """
        if self.FIRST_SET_FLAG:
            # Build the placeholders / assign ops only once; creating new ops
            # on every call would keep growing the graph.
            self.FIRST_SET_FLAG = False
            self.assign_placeholders = []
            self.assign_ops = []
            for weight_tf_var in self.model['main_vars']:
                assign_placeholder = tf.placeholder(
                    weight_tf_var.dtype, shape=weight_tf_var.get_shape())
                self.assign_placeholders.append(assign_placeholder)
                self.assign_ops.append(weight_tf_var.assign(assign_placeholder))
        if len(weight_vals) != len(self.assign_ops):
            raise ValueError("Expected [%d] weight arrays but got [%d]."%
                             (len(self.assign_ops),len(weight_vals)))
        # Run every assign in a single session call instead of one call per variable
        feed_dict = dict(zip(self.assign_placeholders,weight_vals))
        self.sess.run(self.assign_ops,feed_dict=feed_dict)
            
    def rollout(self,temperature=1.0,
                red_list=(Agents.SPOT_RANDOM,Agents.EXPERT_SYSTEM)):
        """
        Rollout

        Runs 'ep_len_rollout' environment steps against each red policy in
        'red_list' and records every transition.
        Returns (o_buffer,a_buffer,r_buffer,o2_buffer,d_buffer,r_avg), where the
        buffers contain the collected transitions and r_avg is the mean reward
        per step (0.0 if nothing was collected).
        """
        if self.FIRST_ROLLOUT_FLAG:
            self.FIRST_ROLLOUT_FLAG = False
            self.o = self.env.reset(red=Agents.SPOT_RANDOM) # reset environment
        red_list = list(red_list)
        # Grow the buffers when the request does not fit (e.g., the default
        # two red policies with ep_len_rollout == buffer_size); otherwise the
        # writes below would run past the end of the arrays.
        n_required = len(red_list)*self.ep_len_rollout
        if n_required > self.buffer_size:
            self._allocate_buffers(n_required)
        # Loop
        r_sum,cnt = 0,0
        for red in red_list: # for each red policy
            self.o = self.env.reset(red=red) # reset environment
            for _ in range(self.ep_len_rollout):
                a = self.get_action(self.o,deterministic=False,temperature=temperature)
                o2,r,d,_info = self.env.step(a)
                r_sum += r
                # Append
                self.o_buffer[cnt,:] = self.o
                self.a_buffer[cnt,:] = a
                self.r_buffer[cnt] = r
                self.o2_buffer[cnt,:] = o2
                self.d_buffer[cnt] = d
                cnt += 1
                # Save next state 
                self.o = o2
                if d: 
                    # Reset when done, and keep playing against the SAME red
                    # policy; resetting with a fixed SPOT_RANDOM opponent here
                    # would silently replace the requested opponent for the
                    # remainder of this segment.
                    self.o = self.env.reset(red=red)
        o_buffer = self.o_buffer[:cnt,:]
        a_buffer = self.a_buffer[:cnt,:]
        r_buffer = self.r_buffer[:cnt]
        o2_buffer = self.o2_buffer[:cnt,:]
        d_buffer = self.d_buffer[:cnt]
        r_avg = r_sum / cnt if cnt > 0 else 0.0 # guard against an empty red_list
        return o_buffer,a_buffer,r_buffer,o2_buffer,d_buffer,r_avg
    
    def evaluate(self,red=None):
        """
        Evaluate

        Runs one episode with deterministic actions against 'red', stopping
        when the environment reports done or after 'ep_len_rollout' steps.
        Returns [ep_ret,ep_len,blue_health,red_health].
        """
        o,d,ep_ret,ep_len = self.env.reset(red=red),False,0,0
        while not(d or (ep_len == self.ep_len_rollout)):
            a = self.get_action(o,deterministic=True)
            o,r,d,_ = self.env.step(a)
            ep_ret += r # compute return 
            ep_len += 1
        blue_health,red_health = self.env.blue_health,self.env.red_health
        eval_res = [ep_ret,ep_len,blue_health,red_health] # evaluation result 
        return eval_res
    

### Initialize Workers

In [6]:
# Start Ray, then build the local learner followed by the remote rollout actors
ray.init(num_cpus=n_cpu)
tf.reset_default_graph()
# Local (non-Ray) worker: owns the training graph and is the source of weights
learner_kwargs = dict(hdims=hdims,actv=actv,
                      lr=lr,gamma=gamma,alpha_q=alpha_q,alpha_pi=alpha_pi,
                      polyak=polyak,epsilon=epsilon)
R = RolloutWorkerClass(seed=0,**learner_kwargs)
# Remote workers: one Ray actor per worker index
workers = []
for worker_idx in range(n_workers):
    remote_worker = RayRolloutWorkerClass.remote(
        worker_id=worker_idx,hdims=hdims,actv=actv,
        ep_len_rollout=ep_len_rollout,buffer_size=buffer_size)
    workers.append(remote_worker)
print ("RAY initialized with [%d] cpus and [%d] workers."%
       (n_cpu,n_workers))

2020-08-16 03:15:50,111	INFO resource_spec.py:212 -- Starting Ray with 159.77 GiB memory available for workers and up to 72.47 GiB for objects. You can adjust these settings with ray.init(memory=<bytes>, object_store_memory=<bytes>).
2020-08-16 03:15:50,661	INFO services.py:1165 -- View the Ray dashboard at [1m[32mlocalhost:8265[39m[22m


RAY initialized with [31] cpus and [30] workers.


### Replay Buffers

In [7]:
# Two replay buffers that share the learner's observation/action sizes and
# differ only in capacity; every training batch later mixes samples from both.
buffer_dims = dict(odim=R.odim,adim=R.adim)
replay_buffer_long = ReplayBuffer(size=int(buffer_sz_long),**buffer_dims)
replay_buffer_short = ReplayBuffer(size=int(buffer_sz_short),**buffer_dims)

### Loop

In [8]:
# Optional warm start: put a checkpoint path here to resume from it;
# with the empty string nothing is restored.
npz_path = ''
restore_kwargs = dict(R=R,
                      replay_buffer_long=replay_buffer_long,
                      replay_buffer_short=replay_buffer_short,
                      VERBOSE=False,IGNORE_BUFFERS=True)
if npz_path:
    restore_sac_model_and_buffers(npz_path=npz_path,**restore_kwargs)

In [None]:
# Main training loop. Each iteration:
#   1) pushes the learner weights to the Ray workers,
#   2) collects rollouts from every worker into both replay buffers,
#   3) runs 'update_count' SAC updates on the learner,
#   4) pushes the updated weights to the workers,
#   5) every 'evaluate_every' iterations: evaluates, logs and checkpoints,
#   6) right after an evaluation: rolls back to the last accepted checkpoint
#      if the evaluation return collapsed.
npz_path_list,ep_ret_avg_list = [],[] # accepted checkpoints / evaluation returns
start_time = time.time()
n_env_step = 0 # number of environment steps
for t in range(int(total_steps)):
    
    # 1. Synchronize worker weights
    # The returned futures are not fetched; this relies on Ray running the
    # calls submitted to one actor in submission order, so each worker
    # applies the weights before its next rollout/evaluate call.
    weights = R.get_weights()
    set_weights_list = [worker.set_weights.remote(weights) for worker in workers] 
    
    # 2. Make rollout and accumulate to Buffers
    t_start = time.time()
    ops = [worker.rollout.remote(
        temperature=temp_min+(temp_max-temp_min)*np.random.rand(),
        red_list=red_list_train # <= with the list of pre-defined red agent policies
    )
           for worker in workers]
    rollout_vals = ray.get(ops)
    r_sum = 0
    for rollout_val in rollout_vals:
        o_buffer,a_buffer,r_buffer,o2_buffer,d_buffer,r_rollout_avg = rollout_val
        r_sum += r_rollout_avg
        # Store exactly the transitions the worker returned (instead of
        # assuming that it always returns 'buffer_size' of them)
        for o,a,r,o2,d in zip(o_buffer,a_buffer,r_buffer,o2_buffer,d_buffer):
            replay_buffer_long.store(o, a, r, o2, d) 
            replay_buffer_short.store(o, a, r, o2, d) 
        n_env_step += len(r_buffer)
    r_avg = r_sum / len(rollout_vals)
    sec_rollout = time.time() - t_start
    
    # 3. Update the SAC model
    t_start = time.time()
    avg_qs = np.zeros(int(update_count))
    for c_idx in range(int(update_count)):
        # Half of every batch comes from the long buffer, half from the short one
        batch_long = replay_buffer_long.sample_batch(batch_size//2) 
        batch_short = replay_buffer_short.sample_batch(batch_size//2) 
        feed_dict = {R.model['o_ph']: np.concatenate((batch_long['obs1'],batch_short['obs1'])),
                     R.model['o2_ph']: np.concatenate((batch_long['obs2'],batch_short['obs2'])),
                     R.model['a_ph']: np.concatenate((batch_long['acts'],batch_short['acts'])),
                     R.model['r_ph']: np.concatenate((batch_long['rews'],batch_short['rews'])),
                     R.model['d_ph']: np.concatenate((batch_long['done'],batch_short['done']))
                    }
        outs = R.sess.run(R.step_ops, feed_dict) # update 
        q1_vals,q2_vals = outs[3],outs[4]
        avg_q = 0.5*np.mean(q1_vals)+0.5*np.mean(q2_vals)
        avg_qs[c_idx] = avg_q
    sec_update = time.time() - t_start
    
    # 4. Synchronize worker weights (after update)
    weights = R.get_weights()
    set_weights_list = [worker.set_weights.remote(weights) for worker in workers] 
    
    # Print
    if (t == 0) or (((t+1)%print_every) == 0): 
        print ("[%d/%d] n_env_step:[%.1e] rollout:[%.1f]s update:[%.1f]s r_avg:[%.4f] avg_q:[%.3f]."%
               (t+1,total_steps,n_env_step,sec_rollout,sec_update,r_avg,np.mean(avg_qs)))
    
    # 5. Evaluate
    if (t == 0) or (((t+1)%evaluate_every) == 0): 
        ram_percent = psutil.virtual_memory().percent # memory usage
        print ("[Eval. start] step:[%d/%d][%.1f%%] #step:[%.1e] time:[%s] ram:[%.1f%%]."%
               (t+1,total_steps,t/total_steps*100,
                n_env_step,
                time.strftime("day:[%d] %H:%M:%S", time.gmtime(time.time()-start_time)),
                ram_percent)
              )
        # Worker i plays one deterministic episode against red_list_eval[i]
        ops = []
        for i_idx in range(num_eval):
            worker,red = workers[i_idx],red_list_eval[i_idx]
            ops.append(worker.evaluate.remote(red=red))
        eval_vals = ray.get(ops)
        ep_ret_sum = 0
        for i_idx in range(num_eval):
            red,eval_val = red_list_eval[i_idx],eval_vals[i_idx]
            ep_ret,ep_len,blue_health,red_health = eval_val[0],eval_val[1],eval_val[2],eval_val[3]
            ep_ret_sum += ep_ret
            print (" [%d/%d] [%s] ep_ret:[%.4f] ep_len:[%d]. blue health:[%.2f] red health:[%.2f]"
                %(i_idx,len(eval_vals),red,ep_ret,ep_len,blue_health,red_health))
        ep_ret_avg = ep_ret_sum / num_eval
        print ("[Eval. done] time:[%s] ep_ret_avg:[%.3f].\n"%
               (time.strftime("day:[%d] %H:%M:%S", time.gmtime(time.time()-start_time)),ep_ret_avg))
        write_txt(f,"%.2f, r_train:%.4f, ret_eval:%.4f"%(time.time()-start_time,r_avg,ep_ret_avg),
                  ADD_NEWLINE=True,DO_PRINT=False)
        # Save current SAC model and replay buffers 
        npz_path = '../report/net/%s/model_and_buffers_%d.npz'%(exp_name,t+1)
        save_sac_model_and_buffers(npz_path,R,replay_buffer_long,replay_buffer_short,
                                   VERBOSE=False,IGNORE_BUFFERS=True)
        
        # 6. If something went bad, restore
        # This check runs only right after a fresh evaluation. Running it on
        # every iteration re-used the stale 'ep_ret_avg'/'npz_path' values and,
        # after one bad evaluation, kept restoring (alternating between the
        # good and the bad checkpoint) until the next evaluation.
        ep_ret_avg_list.append(ep_ret_avg)
        ep_ret_avg_max = np.max(ep_ret_avg_list)
        # "Half of the best return so far", written so that it also works for
        # negative returns: identical to 0.5*max when max >= 0, and 1.5*max when
        # max < 0 (where 0.5*max would lie ABOVE the best return and therefore
        # always trigger a restore).
        restore_threshold = ep_ret_avg_max - 0.5*abs(ep_ret_avg_max)
        if (ep_ret_avg < restore_threshold) and npz_path_list:
            # Roll back to the most recent checkpoint that passed this check;
            # the checkpoint saved just above is not recorded as a restore target.
            npz_path = npz_path_list[-1]
            restore_sac_model_and_buffers(npz_path=npz_path,R=R,
                                          replay_buffer_long=replay_buffer_long,
                                          replay_buffer_short=replay_buffer_short,
                                          VERBOSE=False,IGNORE_BUFFERS=True)
        else:
            npz_path_list.append(npz_path)

print ("Done.")

[2m[36m(pid=67452)[0m 
[2m[36m(pid=67452)[0m 
[2m[36m(pid=67452)[0m      JSBSim Flight Dynamics Model v1.1.0.dev1 Jul 11 2020 05:35:14
[2m[36m(pid=67423)[0m 
[2m[36m(pid=67423)[0m 
[2m[36m(pid=67423)[0m      JSBSim Flight Dynamics Model v1.1.0.dev1 Jul 11 2020 05:35:14
[2m[36m(pid=67426)[0m 
[2m[36m(pid=67426)[0m 
[2m[36m(pid=67426)[0m      JSBSim Flight Dynamics Model v1.1.0.dev1 Jul 11 2020 05:35:14
[2m[36m(pid=67445)[0m 
[2m[36m(pid=67445)[0m 
[2m[36m(pid=67445)[0m      JSBSim Flight Dynamics Model v1.1.0.dev1 Jul 11 2020 05:35:14
[2m[36m(pid=67429)[0m 
[2m[36m(pid=67429)[0m 
[2m[36m(pid=67429)[0m      JSBSim Flight Dynamics Model v1.1.0.dev1 Jul 11 2020 05:35:14
[2m[36m(pid=67422)[0m 
[2m[36m(pid=67422)[0m 
[2m[36m(pid=67422)[0m      JSBSim Flight Dynamics Model v1.1.0.dev1 Jul 11 2020 05:35:14
[2m[36m(pid=67444)[0m 
[2m[36m(pid=67444)[0m 
[2m[36m(pid=67444)[0m      JSBSim Flight Dynamics Model v1.1.0.dev1 Jul 11 2020 0