In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import gym
import numpy as np

In [5]:

class ReplayBuffer():
    def __init__(self):
        self.s=[]
        self.a=[]
        self.r=[]
        self.s_=[]
        self.d=[]
    def add(self,state,action,reward,next_state,done):
        self.s.append(state)
        self.a.append(action)
        self.r.append(reward)
        self.s_.append(next_state)
        self.d.append(done)
        
    def sample(self,batch_size):
        index=np.random.choice(len(self.s),batch_size,replace=True)
        states=[self.s[i] for i in index]
        actions=[self.a[i] for i in index]
        rewards=[self.r[i] for i in index]
        next_states=[self.s_[i] for i in index]
        dones=[self.d[i] for i in index]
        
        return states,actions,rewards,next_states,dones

In [6]:
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim,max_action):
        super(Actor, self).__init__()
        self.max_action=max_action
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, action_dim)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        return torch.tanh(self.fc3(x))*self.max_action

In [7]:

class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

In [8]:

class SAC(object):
    def __init__(self, state_dim, action_dim, max_action, gamma=0.99,alpha=0.2):
        self.actor = Actor(state_dim, action_dim,max_action)
        self.critic = Critic(state_dim, action_dim)

        self.target_entropy = -action_dim
        self.log_alpha = torch.zeros(1, requires_grad=True)
        self.alpha = self.log_alpha.exp()

        self.optimizer_actor = torch.optim.Adam(self.actor.parameters(), lr=3e-4)
        self.optimizer_critic = torch.optim.Adam(self.critic.parameters(), lr=3e-4)
        
        self.gamma=gamma
        self.alpha=alpha

    def select_action(self, state):
        state = torch.FloatTensor(state).unsqueeze(0)
        action = self.actor(state).detach().numpy()[0]
        return action
    

    def update(self, replay_buffer, batch_size=100):
        # Sample a batch of transitions from replay buffer
        state, action, reward, next_state, done = replay_buffer.sample(batch_size)
        
        state=torch.FloatTensor(state)
        action=torch.FloatTensor(action)
        reward=torch.FloatTensor(reward)
        next_state=torch.FloatTensor(next_state)
        done=torch.FloatTensor(done)

        # Compute the target Q-value
        with torch.no_grad():
            next_action = self.actor(next_state)
            next_action=torch.FloatTensor(next_action)
            target_Q = self.critic(next_state, next_action)
            target_Q = reward + (1 - done) * self.gamma * (target_Q + self.alpha * next_action.pow(2).sum(dim=1))

        # Get current Q-value estimates
        current_Q = self.critic(state, action)

        # Compute critic loss
        critic_loss = F.mse_loss(current_Q, target_Q) + F.mse_loss(current_Q, target_Q)+1e-7

        # Compute actor loss
        actor_loss = -self.critic(state, self.actor(state)).mean() + self.alpha * (self.actor(state).pow(2).sum(dim=1) + self.target_entropy).mean()

        # Optimize the actor and critic
        self.optimizer_actor.zero_grad()
        actor_loss.backward()
        self.optimizer_actor.step()

        self.optimizer_critic.zero_grad()
        critic_loss.backward()
        self.optimizer_critic.step()

        # Update alpha
        '''alpha_loss = -(self.log_alpha * (actor_loss + self.target_entropy).detach()).mean()
        self.optimizer_alpha.zero_grad()
        alpha_loss.backward()
        self.optimizer_alpha.step()'''

        self.alpha = self.log_alpha.exp()

In [18]:
# Hyperparameters
gamma = 0.99
tau = 0.005

In [9]:
env=gym.make('Pendulum-v0')

In [None]:
done=False
for _ in range(100):
    for _ in range(200):
        s=env.reset()
        a=env.action_space.sample()
        s,r,done,_=env.step(a)
        env.render()
        if done:
            break

env.close()

In [11]:
state_dim=env.observation_space.shape[0]
action_dim=env.action_space.shape[0]

In [12]:
# Initialize SAC agent
sac = SAC(state_dim, action_dim,2)

In [13]:
rb=ReplayBuffer()

In [14]:
rewards=[]
for i in range(1000):
    s=env.reset()
    done=False
    er=0
    while not done:
        a=sac.select_action(s)
        s_,r,done,_=env.step(a)
        rb.add(s,a,r,s_,done)
        er+=r
        s=s_
    rewards.append(er)
    if i%500==0:
        print('{}/200',i/500)
reward_before=np.mean(rewards)

{}/200 0.0
{}/200 1.0


In [17]:
s,a,r,s_,d=rb.sample(2)

In [20]:
s[0].reshape(-1,1).shape

(3, 1)

In [None]:
state,action,reward,next_state,done=rb.sample(6)
next_state=torch.FloatTensor(next_state)
next_action=sac.actor(next_state)
next_action=torch.FloatTensor(next_action)
print(next_state,next_action)
x=torch.cat([next_state,next_action],dim=-1)
x

In [46]:
for _ in range(100):
    sac.update(rb,10000)

  critic_loss = F.mse_loss(current_Q, target_Q) + F.mse_loss(current_Q, target_Q)+1e-7


In [48]:
rewards=[]
for _ in range(1000):
    s=env.reset()
    done=False
    er=0
    while not done:
        a=sac.select_action(s)
        s_,r,done,_=env.step(a)
        rb.add(s,a,r,s_,done)
        er+=r
        s=s_
    rewards.append(er)
reward_after=np.mean(rewards)

In [49]:
reward_after

-1555.4407565465579