# Soft Actor Critic (SAC)

### Import Libraries

In [None]:
import random,datetime,gym,os,time,psutil,cv2,scipy.signal
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from gym.spaces import Box, Discrete
from matplotlib import animation
from IPython.display import display, HTML
from collections import deque,namedtuple
%matplotlib inline
gym.logger.set_level(40)
print ("gym version:[%s]"%(gym.__version__))
print ("Pytorch:[%s]"%(torch.__version__))

### Util functions

In [None]:
def combined_shape(length, shape=None):
    """
    Build an array shape whose leading dimension is ``length``.

    Args:
        length: size of the leading (batch/time) dimension.
        shape: ``None``, a scalar, or an iterable of trailing dimensions.

    Returns:
        ``(length,)`` when ``shape`` is None, ``(length, shape)`` when
        ``shape`` is a scalar, otherwise ``(length, *shape)``.
    """
    if shape is None:
        return (length,)
    if np.isscalar(shape):
        return (length, shape)
    return (length, *shape)

class SACBuffer:
    """
    Fixed-size replay buffer for an off-policy SAC agent.

    Transitions ``(obs, act, rew, next_obs, done)`` are written into
    preallocated float32 arrays. The write pointer wraps around, so once
    ``max_size`` transitions have been stored each new transition
    overwrites the oldest one. Mini-batches are drawn uniformly at random
    (with replacement) from the filled part of the buffer.
    """

    # Attribute names exposed by get() and consumed, in the same order, by restore().
    _FIELDS = ('obs1_buf', 'obs2_buf', 'acts_buf', 'rews_buf', 'done_buf',
               'ptr', 'size', 'max_size')

    def __init__(self, odim, adim, size=5000):
        """
        Allocate storage for ``size`` transitions.

        Args:
            odim: observation dimension (scalar or shape tuple).
            adim: action dimension (scalar or shape tuple).
            size: maximum number of stored transitions.
        """
        self.obs1_buf = np.zeros(combined_shape(size, odim), dtype=np.float32)
        self.obs2_buf = np.zeros(combined_shape(size, odim), dtype=np.float32)
        self.acts_buf = np.zeros(combined_shape(size, adim), dtype=np.float32)
        self.rews_buf = np.zeros(size, dtype=np.float32)
        self.done_buf = np.zeros(size, dtype=np.float32)
        # ptr: next write index; size: number of valid entries; max_size: capacity.
        self.ptr, self.size, self.max_size = 0, 0, size

    def store(self, obs, act, rew, next_obs, done):
        """
        Append one timestep of agent-environment interaction to the buffer,
        overwriting the oldest entry once the buffer is full.
        """
        # The wrap-around below keeps ptr in [0, max_size); this only trips
        # if the state was made inconsistent from outside (e.g. by restore()).
        assert self.ptr < self.max_size
        idx = self.ptr
        self.obs1_buf[idx] = obs
        self.obs2_buf[idx] = next_obs
        self.acts_buf[idx] = act
        self.rews_buf[idx] = rew
        self.done_buf[idx] = done
        self.ptr = (idx + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample_batch(self, batch_size=32):
        """
        Draw ``batch_size`` transitions uniformly at random, with replacement.

        Returns:
            dict with keys ``obs1``, ``obs2``, ``acts``, ``rews``, ``done``
            mapping to numpy arrays whose leading dimension is ``batch_size``.

        Raises:
            ValueError: if nothing has been stored yet.
        """
        if self.size == 0:
            raise ValueError("cannot sample from an empty replay buffer")
        idxs = np.random.randint(0, self.size, size=batch_size)
        return dict(obs1=self.obs1_buf[idxs],
                    obs2=self.obs2_buf[idxs],
                    acts=self.acts_buf[idxs],
                    rews=self.rews_buf[idxs],
                    done=self.done_buf[idxs])

    def get(self):
        """
        Return the full buffer state as ``(names, vals)``: two parallel lists
        holding the attribute names and their current values (not copies).
        """
        names = list(self._FIELDS)
        vals = [getattr(self, name) for name in names]
        return names, vals

    def restore(self,a):
        """
        Load buffer state from ``a``, a sequence ordered like the ``vals``
        list returned by :meth:`get`.
        """
        for i, name in enumerate(self._FIELDS):
            setattr(self, name, a[i])

def display_animation(anim):
    """
    Convert a matplotlib animation into an IPython ``HTML`` object.

    The figure backing ``anim`` is closed before the animation is rendered
    to a JavaScript/HTML snippet via ``anim.to_jshtml()``.
    """
    figure = anim._fig
    plt.close(figure)
    js_html = anim.to_jshtml()
    return HTML(js_html)

def display_frames_as_gif(frames):
    """
    Show a sequence of image frames as an inline animation.

    Args:
        frames: non-empty sequence of images accepted by ``plt.imshow``;
            the first frame initialises the image, the rest are swapped in
            one per animation tick (10 ms interval).
    """
    image = plt.imshow(frames[0])
    plt.axis('off')

    def _show_frame(i):
        # Swap the pixel data of the single image artist to frame i.
        image.set_data(frames[i])

    n_frames = len(frames)
    anim = animation.FuncAnimation(
        plt.gcf(), _show_frame, frames=n_frames, interval=10)
    rendered = display_animation(anim)
    display(rendered)
    
print ("Done.")        

### Model

In [None]:
def mlp(odim=24, hdims=[256, 256]):
    """
    Build the fully-connected trunk used by the policy and Q networks.

    Layout: ``Linear(odim, hdims[0]) -> ReLU``, then ``Linear -> ReLU`` for
    each intermediate hidden size, then a final ``Linear`` into
    ``hdims[-1]``. No activation is applied after that final layer.
    With a single hidden size the result is
    ``Linear(odim, h) -> ReLU -> Linear(h, h)``.

    Args:
        odim: input feature dimension.
        hdims: non-empty list of hidden layer sizes.

    Returns:
        ``nn.Sequential`` mapping ``(..., odim)`` to ``(..., hdims[-1])``.
    """
    n_hidden = len(hdims)
    modules = [nn.Linear(odim, hdims[0]), nn.ReLU()]
    # Intermediate hidden layers, each followed by a ReLU.
    for k in range(n_hidden - 2):
        modules.extend([nn.Linear(hdims[k], hdims[k + 1]), nn.ReLU()])
    # Final projection into hdims[-1]; intentionally left without activation.
    modules.append(nn.Linear(hdims[n_hidden - 2], hdims[n_hidden - 1]))
    return nn.Sequential(*modules)

def gaussian_loglik(x,mu,log_std):
    """
    Log-density of ``x`` under a diagonal Gaussian, summed over axis 1.

    Args:
        x: samples, shape ``(batch, dim)``.
        mu: Gaussian means, broadcastable to ``x``.
        log_std: log standard deviations, broadcastable to ``x``.

    Returns:
        Tensor of shape ``(batch,)`` with the per-row log-likelihood.
    """
    EPS = 1e-8  # keeps the division finite when exp(log_std) underflows
    std = torch.exp(log_std) + EPS
    z = (x - mu) / std
    per_dim = -0.5*(z**2 + 2*log_std + np.log(2*np.pi))
    return torch.sum(per_dim, axis=1)

class GaussianPolicy(nn.Module):
    """
    Tanh-squashed diagonal Gaussian policy for SAC.

    A shared MLP trunk feeds two linear heads producing the mean and the
    (clipped) log standard deviation of a Gaussian over pre-squash actions.
    Actions are squashed through ``tanh`` so they lie in ``[-1, 1]``.
    """
    # Bounds applied to the predicted log-std before exponentiating.
    LOG_STD_MIN, LOG_STD_MAX = -10.0, +2.0

    def __init__(self,odim,adim,hdims=[256,256]):
        """
        Args:
            odim: observation dimension.
            adim: action dimension.
            hdims: hidden layer sizes of the MLP trunk.
        """
        super(GaussianPolicy, self).__init__()
        self.odim    = odim
        self.adim    = adim
        self.hdims   = hdims
        # Define network
        self.net     = mlp(self.odim,self.hdims)
        self.mu      = nn.Linear(self.hdims[-1],self.adim)
        self.log_std = nn.Linear(self.hdims[-1],self.adim)

    def forward(self, o, get_logprob=True):
        """
        Sample an action for each observation in ``o``.

        Args:
            o: observations, shape ``(batch, odim)``.
            get_logprob: when True, draw a reparameterized (differentiable)
                sample and also return its log-probability; when False,
                draw a plain non-differentiable sample and skip the
                log-probability.

        Returns:
            ``(mu, pi, logp_pi)``: the tanh-squashed mean action, the
            tanh-squashed sampled action, and the log-probability of the
            sampled action (``None`` when ``get_logprob`` is False).
        """
        net_output = self.net(o)
        mu = self.mu(net_output)
        log_std = torch.clip(self.log_std(net_output),
                             self.LOG_STD_MIN, self.LOG_STD_MAX)
        std = torch.exp(log_std)
        dist = torch.distributions.Normal(mu, std)

        if get_logprob:
            # rsample() keeps the sample differentiable w.r.t. mu and std
            # (reparameterization trick). The SAC policy loss needs this so
            # that the gradient of Q(o, pi) reaches the policy parameters;
            # with sample() that term contributes no gradient at all.
            pi = dist.rsample()
            # Compute logprob from Gaussian, and then apply correction for Tanh squashing.
            # NOTE: The correction formula is a numerically-stable equivalent of
            # Eq 21 in appendix C of the original SAC paper (arXiv 1801.01290).
            logp_pi = gaussian_loglik(x=pi, mu=mu, log_std=log_std)
            logp_pi = logp_pi - torch.sum(
                2*(np.log(2) - pi - F.softplus(-2*pi)), axis=1)
        else:
            # Action selection only: a detached sample is enough and can be
            # converted to numpy directly by the caller.
            pi = dist.sample()
            logp_pi = None
        mu, pi = torch.tanh(mu), torch.tanh(pi) # squash actions into [-1, 1]

        return mu, pi, logp_pi

class QFunction(nn.Module):
    """
    State-action value network ``Q(o, a)``.

    The observation and action are concatenated along the last axis, passed
    through an MLP trunk, and projected to a single scalar per sample.
    """
    def __init__(self,odim,adim,hdims=[256,256]):
        super().__init__()
        in_dim = odim + adim
        self.q = mlp(in_dim, hdims=hdims)      # trunk over [o, a]
        self.q2 = nn.Linear(hdims[-1], 1)      # scalar value head

    def forward(self, o, a):
        """Return Q-values of shape ``(batch,)`` for batched ``o`` and ``a``."""
        joint = torch.cat([o, a], dim=-1)
        value = self.q2(self.q(joint))
        # Drop the trailing singleton so the output is one scalar per sample.
        return value.squeeze(1)

class ActorCritic(nn.Module):
    """
    SAC actor-critic: one Gaussian policy and two Q-functions, each with its
    own Adam optimizer, plus the policy and Q update steps.
    """
    def __init__(self,odim,adim,hdims=[256,256],
                 alpha_pi=0.1,alpha_q=0.1,gamma=0.98,lr=3e-4):
        """
        Args:
            odim: observation dimension.
            adim: action dimension.
            hdims: hidden layer sizes shared by policy and Q networks.
            alpha_pi: entropy coefficient used in the policy loss.
            alpha_q: entropy coefficient used in the Q backup.
            gamma: discount factor.
            lr: learning rate for all three Adam optimizers.
        """
        super(ActorCritic,self).__init__()
        self.odim       = odim
        self.adim       = adim
        self.hdims      = hdims

        self.alpha_pi   = alpha_pi
        self.alpha_q    = alpha_q
        self.gamma      = gamma
        self.lr         = lr

        # Define policy and value functions
        self.policy = GaussianPolicy(
            odim=self.odim,adim=self.adim,hdims=self.hdims)
        self.q1 = QFunction(
            odim=self.odim,adim=self.adim,hdims=self.hdims)
        self.q2 = QFunction(
            odim=self.odim,adim=self.adim,hdims=self.hdims)

        # Optimizers
        self.train_pi = optim.Adam(self.policy.parameters(),lr=self.lr)
        self.train_q1 = optim.Adam(self.q1.parameters(), lr=self.lr)
        self.train_q2 = optim.Adam(self.q2.parameters(), lr=self.lr)

    def forward(self, o, deterministic=False):
        """
        Select actions for observations ``o``.

        Returns the squashed mean action when ``deterministic`` is True,
        otherwise a squashed sample. The result carries no autograd graph:
        this path is used for acting only, so the caller can convert it to
        numpy directly.
        """
        with torch.no_grad():
            mu, pi, _ = self.policy(o, False)
        return mu if deterministic else pi

    def update_policy(self, data):
        """
        Take one gradient step on the policy.

        Minimizes ``mean(alpha_pi * logp_pi - min(Q1, Q2))`` over the
        observations in ``data['obs1']``.

        Returns:
            ``(pi_loss, logp_pi, min_q_pi)``.
        """
        o = data['obs1']
        _, pi, logp_pi = self.policy(o)
        q1_pi = self.q1(o, pi)
        q2_pi = self.q2(o, pi)
        min_q_pi = torch.minimum(q1_pi, q2_pi)
        pi_loss = torch.mean(self.alpha_pi*logp_pi - min_q_pi)

        self.train_pi.zero_grad()
        pi_loss.backward()
        self.train_pi.step()

        return pi_loss, logp_pi, min_q_pi

    def update_Q(self, target, data):
        """
        Take one gradient step on each Q-function.

        Args:
            target: object exposing target networks ``q1`` and ``q2``.
            data: batch dict with keys ``obs1``, ``acts``, ``rews``,
                ``obs2``, ``done``.

        Returns:
            ``(value_loss, q1, q2, logp_pi_next, q_backup, q1_targ, q2_targ)``.
        """
        o,a,r,o2,d = data['obs1'],data['acts'],data['rews'],data['obs2'],data['done']

        # The TD target is a constant for the regression below, so build it
        # without tracking gradients through the policy or target networks.
        with torch.no_grad():
            _, pi_next, logp_pi_next = self.policy(o2)
            q1_targ = target.q1(o2, pi_next)
            q2_targ = target.q2(o2, pi_next)
            min_q_targ = torch.minimum(q1_targ, q2_targ)
            q_backup = r + self.gamma*(1-d)*(min_q_targ - self.alpha_q*logp_pi_next)

        q1 = self.q1(o, a)
        q2 = self.q2(o, a)
        q1_loss = 0.5*F.mse_loss(q1,q_backup)
        q2_loss = 0.5*F.mse_loss(q2,q_backup)
        value_loss = q1_loss + q2_loss

        # zero_grad() before each backward also clears any Q-network
        # gradients left over from the policy update.
        self.train_q1.zero_grad()
        q1_loss.backward()
        self.train_q1.step()

        self.train_q2.zero_grad()
        q2_loss.backward()
        self.train_q2.step()

        return value_loss, q1, q2, logp_pi_next, q_backup, q1_targ, q2_targ

print ("Done.")

### SAC Agent

In [None]:
def get_envs():
    """
    Create the training and evaluation ``Ant-v2`` environments.

    Both are built with ``render_mode='rgb_array'``. The evaluation
    environment is reset and stepped three times with random actions
    before being returned (a warm-up for rendering).

    Returns:
        ``(env, eval_env)``.
    """
    env_id, render_mode = 'Ant-v2', 'rgb_array'
    env = gym.make(env_id, render_mode=render_mode)
    eval_env = gym.make(env_id, render_mode=render_mode)

    _obs, _info = eval_env.reset()
    n_warmup_steps = 3
    for _ in range(n_warmup_steps): # dummy run for proper rendering
        random_action = eval_env.action_space.sample()
        _obs, _rew, _done, _trunc, _info = eval_env.step(random_action)
        time.sleep(0.01)
    return env, eval_env

class Agent(object):
    """
    SAC agent: owns the training/evaluation environments, the main and
    target actor-critic models, a long and a short replay buffer, and the
    training loop.
    """
    def __init__(self,hdims=[256,256],alpha_pi=0.1,alpha_q=0.1,gamma=0.98,polyak=0.995,
                 lr=3e-4,seed=1,
                 buffer_size_short=1e5,buffer_size_long=1e6):
        """
        Initialize SAC agent

        Args:
            hdims: hidden layer sizes for policy and Q networks.
            alpha_pi: entropy coefficient in the policy loss.
            alpha_q: entropy coefficient in the Q backup.
            gamma: discount factor.
            polyak: weight kept on the target network in the soft update.
            lr: learning rate.
            seed: seed for torch, numpy and ``random``.
            buffer_size_short: capacity of the short (recent) replay buffer.
            buffer_size_long: capacity of the long replay buffer.
        """
        self.hdims              = hdims
        self.alpha_pi           = alpha_pi
        self.alpha_q            = alpha_q
        self.gamma              = gamma
        self.polyak             = polyak

        self.lr                 = lr
        self.seed               = seed

        self.buffer_size_short  = buffer_size_short
        self.buffer_size_long   = buffer_size_long

        # Environment
        self.env, self.eval_env = get_envs()
        odim, adim    = self.env.observation_space.shape[0],self.env.action_space.shape[0]
        self.odim     = odim
        self.adim     = adim

        # Actor-critic model
        torch.manual_seed(self.seed)
        np.random.seed(self.seed)
        random.seed(self.seed)
        self.model = ActorCritic(self.odim,self.adim,hdims=self.hdims,
                                 alpha_pi=self.alpha_pi,alpha_q=self.alpha_q,gamma=self.gamma,
                                 lr=self.lr)
        self.target = ActorCritic(self.odim,self.adim,hdims=self.hdims,
                                  alpha_pi=self.alpha_pi,alpha_q=self.alpha_q,gamma=self.gamma,
                                  lr=self.lr)
        self.target.load_state_dict(self.model.state_dict())
        # Buffers (sizes may be given as floats such as 1e5)
        self.replay_buffer_long = SACBuffer(odim=self.odim,adim=self.adim,
                                            size=int(self.buffer_size_long))
        self.replay_buffer_short = SACBuffer(odim=self.odim,adim=self.adim,
                                             size=int(self.buffer_size_short))

    def get_action(self, o, deterministic=False):
        """
        Return the model's action for a single observation ``o`` as a
        tensor of shape ``(1, adim)``; no autograd graph is built.
        """
        obs = torch.FloatTensor(o.reshape(1,-1))
        with torch.no_grad():
            return self.model(obs,deterministic)

    def get_weights(self):
        """Return the main model's ``state_dict``."""
        weight_vals = self.model.state_dict()
        return weight_vals

    def set_weights(self, weight_vals):
        """Load ``weight_vals`` (a ``state_dict``) into the main model."""
        return self.model.load_state_dict(weight_vals)

    def _polyak_update(self, main_net, targ_net):
        """Soft-update ``targ_net`` towards ``main_net`` in place."""
        for v_main, v_targ in zip(main_net.parameters(), targ_net.parameters()):
            v_targ.data.copy_(v_main.data * (1 - self.polyak) + v_targ.data * self.polyak)

    def update_sac(self, replay_buffer):
        """
        Run one SAC update (policy step, Q step, target soft-update).

        Args:
            replay_buffer: a sampled batch dict of tensors with keys
                ``obs1``, ``obs2``, ``acts``, ``rews``, ``done`` (despite
                the name, this is a batch, not a buffer object).

        Returns:
            ``(logp_pi, min_q_pi, logp_pi_next, q_backup, q1_targ, q2_targ)``.
        """
        pi_loss, logp_pi, min_q_pi = self.model.update_policy(replay_buffer)
        value_loss, q1, q2, logp_pi_next, q_backup, q1_targ, q2_targ = \
            self.model.update_Q(self.target, replay_buffer)

        # Polyak averaging of value networks
        self._polyak_update(self.model.q1, self.target.q1)
        self._polyak_update(self.model.q2, self.target.q2)

        return logp_pi, min_q_pi, logp_pi_next, q_backup, q1_targ, q2_targ

    @staticmethod
    def _shape_reward(r):
        """Clip negative rewards to zero, then add a constant 0.01 bonus."""
        if r < 0.0: r = 0.0
        return r + 0.01

    def _sample_mixed_batch(self, batch_size):
        """Sample half the batch from the long buffer and half from the short one."""
        half = batch_size//2
        batch_long = self.replay_buffer_long.sample_batch(half)
        batch_short = self.replay_buffer_short.sample_batch(half)
        return {
            key: torch.cat([torch.FloatTensor(batch_long[key]),
                            torch.FloatTensor(batch_short[key])], 0)
            for key in ('obs1', 'obs2', 'acts', 'rews', 'done')
        }

    def _evaluate(self, step, total_steps, n_env_step, start_time,
                  max_ep_len_eval, show_animation):
        """
        Roll out one deterministic episode in the evaluation environment.

        Uses only local variables so the training loop's current
        observation is left untouched. Every fifth frame is kept and, when
        ``show_animation`` is True, displayed as an inline animation.
        """
        ram_percent = psutil.virtual_memory().percent  # memory usage
        print("[Eval. start] step:[%d/%d][%.1f%%] #step:[%.1e] time:[%s] ram:[%.1f%%]." %
              (step + 1, total_steps, step / total_steps * 100,
               n_env_step,
               time.strftime("%H:%M:%S", time.gmtime(time.time() - start_time)),
               ram_percent)
              )
        eval_o,_ = self.eval_env.reset()
        eval_done, ep_ret, ep_len = False, 0, 0
        _ = self.eval_env.render()
        frames = []
        while not (eval_done or (ep_len == max_ep_len_eval)):
            a = self.get_action(eval_o, deterministic=True)
            eval_o,r,terminated,truncated,_ = self.eval_env.step(a.detach().numpy()[0])
            eval_done = terminated or truncated
            r = self._shape_reward(r)
            frame = self.eval_env.render()
            texted_frame = cv2.putText(
                img=np.copy(frame),
                text='tick:[%d]'%(ep_len),
                org=(80,30),fontFace=2,fontScale=0.8,color=(0,0,255),thickness=1)
            if (ep_len%5) == 0:
                frames.append(texted_frame)
            ep_ret += r  # compute return
            ep_len += 1
        if show_animation:
            display_frames_as_gif(frames)
        print("[Eval. done] ep_ret:[%.4f] ep_len:[%d]"% (ep_ret, ep_len))

    def train(self,total_steps=1e6,start_steps=1e4,evaluate_every=1e4,plot_every=1e4,
              batch_size=128,update_count=2,max_ep_len_eval=1000,load_dir=None):
        """
        Train SAC

        Args:
            total_steps: number of environment steps to run.
            start_steps: actions are sampled from the action space until
                ``step`` exceeds this; updates begin once ``step`` reaches it.
            evaluate_every: evaluate whenever ``step + 1`` is a multiple of this.
            plot_every: evaluate and show the animation whenever
                ``step + 1`` is a multiple of this.
            batch_size: total batch size (half from each replay buffer).
            update_count: number of SAC updates per environment step.
            max_ep_len_eval: step cap for an evaluation episode.
            load_dir: currently unused.
        """
        start_time = time.time()
        # Make sure the target starts as an exact copy of the main model.
        for v_main, v_targ in zip(self.model.parameters(), self.target.parameters()):
            v_targ.data.copy_(v_main.data)

        o,_ = self.env.reset()
        n_env_step = 0
        for step in range(int(total_steps)):
            # Step
            if step > start_steps:
                a = self.get_action(o, deterministic=False)
                a = a.detach().numpy()[0]
            else:
                a = self.env.action_space.sample()
            o2,r,terminated,truncated,_ = self.env.step(a)
            r = self._shape_reward(r)

            # Append. Only true termination is stored as `done`, so the Q
            # backup still bootstraps across time-limit truncations.
            self.replay_buffer_long.store(o, a, r, o2, terminated)
            self.replay_buffer_short.store(o, a, r, o2, terminated)
            n_env_step += 1
            o = o2

            # Reset when the episode ended for either reason
            if terminated or truncated:
                o,_ = self.env.reset()

            # Update
            if step >= start_steps:
                for _ in range(update_count):
                    self.update_sac(self._sample_mixed_batch(batch_size))

            # Evaluate
            is_plot_step = (step==0) or (((step+1)%plot_every)==0)
            if is_plot_step or (((step+1)%evaluate_every)==0):
                self._evaluate(step, total_steps, n_env_step, start_time,
                               max_ep_len_eval, show_animation=is_plot_step)
    
print ("Done.")

### Train an Ant agent with SAC

In [None]:
# Build the agent: 256-256 hidden layers, a 5e3-transition short buffer and
# a 1e5-transition long buffer.
A = Agent(hdims=[256,256],alpha_pi=0.1,alpha_q=0.1,gamma=0.98,polyak=0.995,
          lr=1e-4,seed=1,buffer_size_short=5e3,buffer_size_long=1e5)
# Train for 2e5 environment steps. Actions are sampled from the action space
# for roughly the first 1e4 steps; evaluation runs every 1e4 steps and the
# rollout animation is shown every 5e4 steps (and at step 0).
A.train(total_steps=2e5,start_steps=1e4,evaluate_every=1e4,plot_every=5e4,
        batch_size=128,update_count=2,max_ep_len_eval=500)