In [53]:
import torch
import torch.nn as nn 
import numpy as np
import gym
import random

from torch.distributions import Normal
from copy import deepcopy

In [None]:
class SAC(nn.Module):
    """Soft Actor-Critic agent with twin Q-networks and a diagonal Gaussian policy.

    The agent keeps a replay memory of transitions. Every call to :meth:`fit`
    stores one transition and, once the memory holds more than ``batch_size``
    transitions, performs one update of both critics followed by one update of
    the policy.

    Args:
        state_dim: Size of a flat observation vector.
        action_dim: Size of a flat action vector.
        batch_size: Number of transitions sampled per update.
        gamma: Discount factor used in the critic target.
        alpha: Coefficient of the log-probability (entropy) term in both the
            critic target and the policy loss.
        tau: Interpolation weight of the soft target-network update.
        pi_lr: Adam learning rate for the policy network.
        q_lr: Adam learning rate for both Q-networks.
    """

    def __init__(self, state_dim, action_dim, batch_size=64, gamma=0.99, alpha=0.1, tau=1e-2, pi_lr=1e-3, q_lr=1e-3):
        super().__init__()

        # The policy head emits 2 * action_dim values: action_dim means followed
        # by action_dim log-stds. The final Tanh bounds both halves to [-1, 1].
        self.pi_model = nn.Sequential(nn.Linear(state_dim, 128), nn.ReLU(),
                                      nn.Linear(128, 128), nn.ReLU(),
                                      nn.Linear(128, 2 * action_dim), nn.Tanh())

        # Twin critics: each maps a concatenated (state, action) row to one Q-value.
        self.q1_model = nn.Sequential(nn.Linear(state_dim + action_dim, 128), nn.ReLU(),
                                      nn.Linear(128, 128), nn.ReLU(),
                                      nn.Linear(128, 1))

        self.q2_model = nn.Sequential(nn.Linear(state_dim + action_dim, 128), nn.ReLU(),
                                      nn.Linear(128, 128), nn.ReLU(),
                                      nn.Linear(128, 1))

        # Replay memory of [state, action, reward, done, next_state] entries.
        self.memory = []
        self.batch_size = batch_size

        self.pi_optimizer = torch.optim.Adam(self.pi_model.parameters(), pi_lr)
        self.q1_optimizer = torch.optim.Adam(self.q1_model.parameters(), q_lr)
        self.q2_optimizer = torch.optim.Adam(self.q2_model.parameters(), q_lr)

        # Target critics start as exact copies and are only moved by the soft
        # update in update_model; no optimizer owns their parameters.
        self.q1_target_model = deepcopy(self.q1_model)
        self.q2_target_model = deepcopy(self.q2_model)

        self.gamma = gamma
        self.alpha = alpha
        self.tau = tau

    def get_action(self, state):
        """Sample one action for a single observation.

        Args:
            state: Flat observation (array-like of length ``state_dim``).

        Returns:
            ``numpy.ndarray`` of shape ``(action_dim,)`` sampled from the policy.
        """
        states = torch.as_tensor(np.asarray(state, dtype=np.float32)).unsqueeze(0)
        # Acting needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            action, _ = self.predict_actions(states)
        # Drop the batch axis (not the action axis) so any action_dim works.
        return action.squeeze(0).numpy()

    def fit(self, state, action, reward, done, next_state):
        """Store one transition and, if enough data is available, run one update.

        Args:
            state: Observation the action was taken in.
            action: Action that was executed.
            reward: Scalar reward received.
            done: Whether ``next_state`` is terminal (disables bootstrapping).
            next_state: Observation reached after the action.
        """
        self.memory.append([state, action, reward, done, next_state])

        if len(self.memory) <= self.batch_size:
            return

        states, actions, rewards, dones, next_states = self._sample_batch()

        # --- critic update -------------------------------------------------
        targets = self._compute_targets(rewards, dones, next_states)
        states_and_actions = torch.cat((states, actions), dim=1)
        q1_loss = torch.mean((self.q1_model(states_and_actions) - targets) ** 2)
        q2_loss = torch.mean((self.q2_model(states_and_actions) - targets) ** 2)
        self.update_model(q1_loss, self.q1_optimizer, self.q1_model, self.q1_target_model)
        self.update_model(q2_loss, self.q2_optimizer, self.q2_model, self.q2_target_model)

        # --- policy update -------------------------------------------------
        # Re-sample actions with rsample so the gradient of the critics' value
        # flows back into the policy parameters.
        pred_actions, log_probs = self.predict_actions(states)
        states_and_pred_actions = torch.cat((states, pred_actions), dim=1)
        min_q_values = torch.min(self.q1_model(states_and_pred_actions),
                                 self.q2_model(states_and_pred_actions))
        pi_loss = -torch.mean(min_q_values - self.alpha * log_probs)
        self.update_model(pi_loss, self.pi_optimizer)

    def _sample_batch(self):
        """Draw ``batch_size`` transitions and return them as 2-D float tensors."""
        batch = random.sample(self.memory, self.batch_size)
        # Stack each column through one ndarray first: building tensors from a
        # tuple of separate ndarrays is slow.
        states, actions, rewards, dones, next_states = (
            torch.as_tensor(np.asarray(column, dtype=np.float32)) for column in zip(*batch)
        )
        n = len(batch)
        return (states.reshape(n, -1), actions.reshape(n, -1),
                rewards.reshape(n, 1), dones.reshape(n, 1),
                next_states.reshape(n, -1))

    def _compute_targets(self, rewards, dones, next_states):
        """Return the soft Bellman target ``r + gamma * (1 - d) * (min Q' - alpha * log pi)``."""
        # Targets are regression constants, so nothing here needs gradients.
        with torch.no_grad():
            next_actions, next_log_probs = self.predict_actions(next_states)
            next_states_and_actions = torch.cat((next_states, next_actions), dim=1)
            next_min_q_values = torch.min(self.q1_target_model(next_states_and_actions),
                                          self.q2_target_model(next_states_and_actions))
            return rewards + self.gamma * (1 - dones) * (next_min_q_values - self.alpha * next_log_probs)

    def update_model(self, loss, optimizer, model=None, target_model=None):
        """Take one optimizer step on ``loss`` and optionally soft-update a target.

        Args:
            loss: Scalar tensor to minimise.
            optimizer: Optimizer that owns the parameters to be stepped.
            model: Online network whose weights feed the target update.
            target_model: Target network to move towards ``model``. The soft
                update runs only when both ``model`` and ``target_model`` are given.
        """
        # Clear gradients first: the policy loss also back-propagates through
        # the Q-networks, and those stale gradients must not leak into the
        # next critic step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if model is not None and target_model is not None:
            with torch.no_grad():
                for param, target_param in zip(model.parameters(), target_model.parameters()):
                    target_param.copy_((1 - self.tau) * target_param + self.tau * param)

    def predict_actions(self, states):
        """Sample reparameterised actions and their log-probabilities.

        Args:
            states: Float tensor of shape ``(batch, state_dim)``.

        Returns:
            Tuple ``(actions, log_probs)`` with shapes ``(batch, action_dim)``
            and ``(batch, 1)``; ``log_probs`` is summed over action dimensions.
        """
        # First half of the policy output is the mean, second half the log-std.
        mean, log_std = self.pi_model(states).chunk(2, dim=-1)
        dist = Normal(mean, torch.exp(log_std))
        action = dist.rsample()
        # Independent dimensions: the joint log-density is the per-dimension sum.
        log_probs = dist.log_prob(action).sum(dim=-1, keepdim=True)
        return action, log_probs

In [58]:
# Roll out a single Pendulum episode (at most 200 steps) while training the agent online.
env = gym.make('Pendulum-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]

agent = SAC(state_dim, action_dim)

total_reward = 0

# gym >= 0.26 returns (observation, info) from reset(); older versions return
# the observation alone.
reset_result = env.reset()
state = reset_result[0] if isinstance(reset_result, tuple) else reset_result

for t in range(200):
    action = agent.get_action(state)

    step_result = env.step(action)
    if len(step_result) == 5:
        # New API: time-limit truncation is reported separately from termination.
        next_state, reward, terminated, truncated, _ = step_result
        done = terminated or truncated
    else:
        # Old API: a single flag covers both termination and truncation.
        next_state, reward, done, _ = step_result
        terminated = done

    # Only a true terminal state should switch off bootstrapping in the critic target.
    agent.fit(state, action, reward, terminated, next_state)

    total_reward += reward
    state = next_state

    # Stepping an environment after the episode has ended is undefined behaviour.
    if done:
        break

total_reward

  logger.warn(


np.float64(-1045.3491777449074)