# Augmented Random Search

In [None]:
import datetime,gym,os,pybullet_envs,time,os,psutil,cv2
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from collections import deque
from gym.spaces import Box, Discrete
# Notebook-wide setup: print NumPy arrays with two decimals and raise the
# gym logger threshold (40 is presumably the ERROR level — TODO confirm).
np.set_printoptions(precision=2)
gym.logger.set_level(40)  # gym logger
print("Packaged loaded. TF version is [%s]." % (tf.__version__))

In [None]:
class MLP(tf.keras.Model):
    """
    Fully connected policy network: observation batch in, action batch out.

    Args:
        odim: width of the input (observation) vector.
        adim: width of the output (action) vector.
        hdims: sizes of the hidden Dense layers, in order.
        actv: activation applied by every hidden layer.
        out_actv: activation applied by the output layer.
    """
    def __init__(self,odim=24,adim=8,hdims=(256,256),actv='relu',out_actv='relu'):
        super(MLP, self).__init__()
        # The default used to be a mutable list shared by every call; it is
        # now an immutable tuple and each instance keeps its own copy.
        self.hdims = list(hdims)
        self.layers_ = tf.keras.Sequential()
        # One truncated-normal initializer instance is shared by all layers
        ki = tf.keras.initializers.truncated_normal(stddev=0.1)
        self.layers_.add(tf.keras.layers.InputLayer(input_shape=(odim,)))
        for hdim in self.hdims:
            linear = tf.keras.layers.Dense(hdim, kernel_initializer=ki, activation=actv)
            self.layers_.add(linear)
        linear_out = tf.keras.layers.Dense(adim, kernel_initializer=ki, activation=out_actv)
        self.layers_.add(linear_out)

    @tf.function
    def call(self, obs):
        """
        Run the stacked layers on `obs` and return the output activations.
        """
        x = obs
        mu = self.layers_(x)
        return mu

def get_noises_from_weights(weights, nu=0.01):
    """
    Draw one Gaussian perturbation per weight.

    Args:
        weights: iterable of objects exposing a `.shape` attribute.
        nu: scale multiplied onto each standard-normal sample.

    Returns:
        A list with one perturbation per entry of `weights`, each shaped
        like the corresponding weight.
    """
    # Samples are drawn in the order of `weights`, one randn call per weight
    return [nu * np.random.randn(*w.shape) for w in weights]

# Signal that the definitions in this cell were evaluated
print("Done.")

In [None]:
class UpdateWorker(object):
    """
    Holder of the central policy model.

    Exposes the model's trainable weights for reading and overwriting, plus
    checkpoint save/load helpers.
    """
    def __init__(self, args, seed=1):
        self.seed = seed
        # Only the first environment returned by get_envs() is kept
        self.env, _ = get_envs()
        self.odim = self.env.observation_space.shape[0]
        self.adim = self.env.action_space.shape[0]
        # Seed TensorFlow and NumPy before the model is built
        tf.random.set_seed(self.seed)
        np.random.seed(self.seed)
        # ARS Model
        self.model = MLP(
            self.odim,
            self.adim,
            hdims=args['hdims'],
            actv=args['actv'],
            out_actv=args['out_actv'],
        )

    @tf.function
    def get_action(self, o):
        """
        Evaluate the model on `o` and return the first row of its output.
        """
        mu = self.model(o)
        return mu[0]

    @tf.function
    def get_weights(self):
        """
        Return the model's trainable weights.
        """
        return self.model.trainable_weights

    @tf.function
    def set_weights(self, weight_vals):
        """
        Assign `weight_vals` onto the model's trainable weights, pairwise.
        """
        targets = self.model.trainable_weights
        for target, value in zip(targets, weight_vals):
            target.assign(value)

    def save_weight(self, log_path):
        """
        Save the model weights under `log_path`.
        """
        checkpoint_prefix = log_path + "/weights/weights"
        self.model.save_weights(checkpoint_prefix)

    def load_weight(self, checkpoint):
        """
        Load model weights from `checkpoint`.
        """
        self.model.load_weights(checkpoint)

class RolloutWorker(object):
    """
    Worker that runs one episode with a perturbed copy of the policy.
    """
    def __init__(self,args,worker_id=0,):
        self.worker_id = worker_id
        # Only the first environment returned by get_envs() is kept
        self.env, _ = get_envs()
        self.odim = self.env.observation_space.shape[0]
        self.adim = self.env.action_space.shape[0]
        # ARS Model
        self.model = MLP(
            self.odim,
            self.adim,
            hdims=args['hdims'],
            actv=args['actv'],
            out_actv=args['out_actv'],
        )

    @tf.function
    def get_action(self, o):
        """
        Evaluate the model on `o` and return the first row of its output.
        """
        mu = self.model(o)
        return mu[0]

    # NOTE: runs eagerly; the tf.function decorator is not applied here
    def set_weights(self, weight_vals, noise_vals, noise_sign=+1):
        """
        Set every trainable weight to `weight + noise_sign * noise`.

        `weight_vals` and `noise_vals` are indexed in the order of the
        model's trainable weights.
        """
        for idx, target in enumerate(self.model.trainable_weights):
            perturbed = weight_vals[idx] + noise_sign * noise_vals[idx]
            target.assign(perturbed)

    def rollout(self,len_rollout=1000):
        """
        Run one episode of at most `len_rollout` steps from a fresh reset.

        Returns:
            (r_sum, step): summed reward and number of environment steps taken.
        """
        self.o = self.env.reset() # reset always
        r_sum = 0
        step = 0
        for _ in range(len_rollout):
            # The observation is reshaped to a single-row batch for the model
            self.a = self.get_action(self.o.reshape(1, -1))
            self.o2, self.r, self.d, _ = self.env.step(self.a)
            # The next observation becomes the current one
            self.o = self.o2
            r_sum += self.r
            step += 1
            if self.d:
                break
        return r_sum, step


In [None]:
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import display, HTML

class Agent(object):
    """
    Augmented Random Search trainer.

    Owns one UpdateWorker (the central policy) and `n_workers` RolloutWorkers.
    Each training iteration draws one noise set per worker, runs a rollout
    with +noise and one with -noise, and moves the central weights along the
    best-ranked noise directions.

    Keys read from `args`: 'n_workers', 'total_steps', 'evaluate_every',
    'print_every', 'num_eval', 'alpha', 'nu' (the workers read further keys).
    """
    def __init__(self, args, seed=1):
        # Config
        print(args)
        self.n_workers = args['n_workers']
        self.total_steps = args['total_steps']
        self.evaluate_every = args['evaluate_every']
        self.print_every = args['print_every']
        self.num_eval = args['num_eval']
        self.alpha = args['alpha']
        self.nu = args['nu']

        self.seed = seed
        # Environment
        self.env, self.eval_env = get_envs()
        odim, adim = self.env.observation_space.shape[0],self.env.action_space.shape[0]
        self.odim = odim
        self.adim = adim

        # NOTE(review): the update worker is seeded with 0 rather than `seed`;
        # kept unchanged so existing runs stay reproducible.
        self.R = UpdateWorker(args, seed=0)
        self.workers = [RolloutWorker(
            args, worker_id=i,
        ) for i in range(self.n_workers)]

    def _run_rollouts(self, weights, noises_list, noise_sign, len_rollout):
        """
        Load `weights + noise_sign * noise` into every worker and roll out.

        Returns:
            (rollout_vals, n_step): per-worker summed rewards as an array and
            the total number of environment steps taken.
        """
        # All workers are perturbed first, then all rollouts are executed
        for worker, noises in zip(self.workers, noises_list):
            worker.set_weights(weights, noises, noise_sign=noise_sign)
        results = [worker.rollout(len_rollout=len_rollout) for worker in self.workers]

        rollout_vals = np.zeros(self.n_workers)
        n_step = 0
        for r_idx, (rew, eplen) in enumerate(results):
            rollout_vals[r_idx] = rew
            n_step += eplen
        return rollout_vals, n_step

    @staticmethod
    def _ars_step(weights, noises_list, rollout_pos_vals, rollout_neg_vals, alpha, nu, b):
        """
        Compute the updated weights from paired rollout rewards.

        Directions are ranked by max(pos, neg) reward and only the top `b`
        contribute. Each contributes (pos - neg) * noise / nu, and the sum is
        scaled by alpha / (b * sigma_R), where sigma_R is the standard
        deviation of all rewards.

        Returns:
            (weights_updated, rollout_max_val, rollout_delta_max_val, sigma_R)
        """
        rollout_concat_vals = np.concatenate((rollout_pos_vals, rollout_neg_vals))
        rollout_delta_vals = rollout_pos_vals - rollout_neg_vals  # pos-neg
        rollout_max_vals = np.maximum(rollout_pos_vals, rollout_neg_vals)
        rollout_max_val = np.max(rollout_max_vals)  # single maximum
        rollout_delta_max_val = np.max(np.abs(rollout_delta_vals))

        # Best direction first
        sort_idx = np.argsort(-rollout_max_vals)

        sigma_R = np.std(rollout_concat_vals)
        denom = b * sigma_R
        # With zero spread every delta is zero too; skip the division so the
        # weights stay finite instead of turning into NaN.
        step_size = (alpha / denom) if denom > 0 else 0.0

        weights_updated = []
        for w_idx, weight in enumerate(weights):  # for each weight
            delta_weight_sum = np.zeros_like(weight)
            for idx_k in sort_idx[:b]:
                # Use the ranked index for BOTH the reward delta and the
                # noise, so the update follows the top-b directions.
                noise_k = (1 / nu) * noises_list[idx_k][w_idx]
                delta_weight_sum += rollout_delta_vals[idx_k] * noise_k
            weights_updated.append(weight + step_size * delta_weight_sum)
        return weights_updated, rollout_max_val, rollout_delta_max_val, sigma_R

    def _evaluate(self):
        """
        Run `num_eval` episodes with the central policy, rendering each one.
        """
        for eval_idx in range(self.num_eval):
            o, d, ep_ret, ep_len = self.eval_env.reset(), False, 0, 0
            frames = []
            while not (d):
                a = self.R.get_action(o.reshape(1, -1))
                o, r, d, _ = self.eval_env.step(a)
                frame = self.eval_env.render(mode='rgb_array')
                texted_frame = cv2.putText(
                    img=np.copy(frame),
                    text='tick:[%d]'%(ep_len),
                    org=(80,30),fontFace=2,fontScale=0.8,color=(0,0,255),thickness=1)
                # Keep every fifth frame only
                if (ep_len%5) == 0:
                    frames.append(texted_frame)
                ep_ret += r  # compute return
                ep_len += 1
            display_frames_as_gif(frames)
            print(" [Evaluate] [%d/%d] ep_ret:[%.4f] ep_len:[%d]"
                  % (eval_idx, self.num_eval, ep_ret, ep_len))

    def train(self,len_rollout=1000):
        """
        Run `total_steps` ARS iterations with periodic logging and evaluation.

        Args:
            len_rollout: maximum number of environment steps per rollout.
        """
        start_time = time.time()
        n_env_step = 0
        # Number of top directions used per update; at least one so that a
        # small worker count cannot cause a division by zero.
        b = max(1, self.n_workers // 5)

        for t in range(int(self.total_steps)):
            # Distribute worker weights
            weights = self.R.get_weights()
            noises_list = [get_noises_from_weights(weights, nu=self.nu)
                           for _ in range(self.n_workers)]

            # Positive rollouts (noise_sign=+1)
            rollout_pos_vals, n_step = self._run_rollouts(
                weights, noises_list, +1, len_rollout)
            n_env_step += n_step

            # Negative rollouts (noise_sign=-1)
            rollout_neg_vals, n_step = self._run_rollouts(
                weights, noises_list, -1, len_rollout)
            n_env_step += n_step

            # Scale reward
            rollout_pos_vals, rollout_neg_vals = rollout_pos_vals / 100, rollout_neg_vals / 100

            # Update
            weights_updated, rollout_max_val, rollout_delta_max_val, sigma_R = self._ars_step(
                weights, noises_list, rollout_pos_vals, rollout_neg_vals,
                self.alpha, self.nu, b)

            # Set weight
            self.R.set_weights(weights_updated)

            # Print
            if (t == 0) or (((t + 1) % self.print_every) == 0):
                print("[%d/%d] rollout_max_val:[%.2f] rollout_delta_max_val:[%.2f] sigma_R:[%.2f] " %
                      (t, self.total_steps, rollout_max_val, rollout_delta_max_val, sigma_R))

            # Evaluate
            if (t == 0) or (((t + 1) % self.evaluate_every) == 0) or (t == (self.total_steps - 1)):
                ram_percent = psutil.virtual_memory().percent  # memory usage
                print("[Evaluate] step:[%d/%d][%.1f%%] #step:[%.1e] time:[%s] ram:[%.1f%%]." %
                      (t + 1, self.total_steps, t / self.total_steps * 100,
                       n_env_step,
                       time.strftime("%H:%M:%S", time.gmtime(time.time() - start_time)),
                       ram_percent)
                      )
                self._evaluate()

        print("Done.")
        self.eval_env.close()

def get_envs(env_name='AntBulletEnv-v0'):
    """
    Create a training environment and a separate evaluation environment.

    Args:
        env_name: gym environment id passed to `gym.make` for both
            environments. Defaults to the id that was previously hard-coded.

    Returns:
        (env, eval_env): two independent instances of `env_name`. The
        evaluation environment has already been reset and stepped three
        times with random actions.
    """
    env = gym.make(env_name)
    eval_env = gym.make(env_name)
    eval_env.reset()
    for _ in range(3): # dummy run for proper rendering
        eval_env.step(eval_env.action_space.sample())
        time.sleep(0.01)
    return env,eval_env

def display_animation(anim):
    """
    Close the animation's figure and return the animation wrapped as HTML.
    """
    fig = anim._fig
    plt.close(fig)
    js_markup = anim.to_jshtml()
    return HTML(js_markup)

def display_frames_as_gif(frames):
    """
    Show a sequence of image frames as an inline animation.

    Args:
        frames: sequence of images accepted by `plt.imshow`. An empty
            sequence is ignored (nothing is displayed).
    """
    # Nothing to draw; indexing frames[0] below would raise IndexError
    if len(frames) == 0:
        return
    patch = plt.imshow(frames[0])
    plt.axis('off')

    def animate(i):
        # Swap the image data in place for frame i
        patch.set_data(frames[i])

    anim = animation.FuncAnimation(
        plt.gcf(),animate,frames=len(frames),interval=10)
    display(display_animation(anim))

### Train

In [None]:
# Run configuration; key order matches the order in which Agent prints it
args = {
    'total_steps': 1000,       # number of training iterations
    'hdims': [256, 256],       # hidden layer sizes of the policy MLP
    'num_eval': 1,             # evaluation episodes per evaluation round
    'print_every': 10,         # iterations between progress prints
    'actv': 'tanh',            # hidden-layer activation
    'out_actv': 'tanh',        # output-layer activation
    'evaluate_every': 50,      # iterations between evaluation rounds
    'alpha': 0.01,             # scale applied to the weight update
    'nu': 0.06,                # scale of the sampled weight noise
    'n_workers': 50,           # number of rollout workers
}

a = Agent(args)
print("Start training")
a.train(len_rollout=500)