1. 导入库

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from typing import Tuple
from collections import namedtuple
from torch.distributions import Normal

Buffer_Capacity=8192
min_Val=torch.tensor(1e-7).float()
gradient_steps=1
batch_size=128

2.创建Actor&Critic网络

In [2]:
class Actor(nn.Module):
    def __init__(self, state_dim,hidden_dim,max_action,min_log_std=-20,max_log_std=2):
        super(Actor,self).__init__()
        self.fc1=nn.Linear(state_dim,hidden_dim)
        self.fc2=nn.Linear(hidden_dim,hidden_dim)
        self.mu_head=nn.Linear(hidden_dim,1)
        self.log_std_head=nn.Linear(hidden_dim,1)
        self.max_action=max_action
        
        self.min_log_std=min_log_std
        self.max_log_std=max_log_std
    
    def forward(self,x):
        x=F.relu(self.fc1(x))
        x=F.relu(self.fc2(x))
        mu=self.mu_head(x)
        log_std_head=F.relu(self.log_std_head(x))
        log_std_head=torch.clamp(log_std_head,self.min_log_std,self.max_log_std)
        return mu,log_std_head

In [3]:
class Q(nn.Module):
    def __init__(self,state_dim,action_dim,hidden_dim):
        super(Q,self).__init__()
        self.q=nn.Sequential(
            nn.Linear(state_dim+action_dim,hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim,hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim,1)
        )
        
    def forward(self,state,action):
        x=torch.cat([state,action],dim=-1)
        return self.q(x)

In [4]:
class Critic(nn.Module):
    def __init__(self,state_dim,hidden_dim):
        super(Critic,self).__init__()
        self.critic=nn.Sequential(
            nn.Linear(state_dim,hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim,hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim,1)
        )
        
    def forward(self,state):
        return self.critic(state)

3.Actor&Critic网络测试

In [5]:
state=torch.randn(5)
print('state:',state)

actor=Actor(5,32,1)
critic=Critic(5,32)
q=Q(5,2,32)

action=actor.forward(state)
value=critic.forward(state)
#qvalue=q.forward(state,action)

print('action:',action)
print('state value:',value)
#print('qvalue:',qvalue)

state: tensor([-1.0421,  0.5407, -0.0612,  1.6671, -0.7643])
action: (tensor([-0.1809], grad_fn=<ViewBackward0>), tensor([0.], grad_fn=<ClampBackward1>))
state value: tensor([-0.1126], grad_fn=<ViewBackward0>)


4. 构建SAC类

In [55]:
Transition = namedtuple('Transition', ['s', 'a', 'r', 's_', 'd'])
class SAC:
    def __init__(self,state_dim,action_dim,hidden_dim,max_action,device,actor_lr=1e-3,critic_lr=1e-3,gamma=0.99,tau=0.005,alpha=0.2):
        
        self.device=device
        
        self.policy_net=Actor(state_dim,action_dim,hidden_dim,max_action).to(self.device)
        self.value_net=Critic(state_dim,hidden_dim).to(self.device)
        self.Q_net=Q(state_dim,action_dim,hidden_dim).to(self.device)
        self.target_value_net=Critic(state_dim,hidden_dim).to(self.device)
        
        self.replay_buffer=[Transition]*Buffer_Capacity
        self.policy_optimizer=torch.optim.Adam(self.policy_net.parameters(),lr=actor_lr)
        self.value_optimizer=torch.optim.Adam(self.value_net.parameters(),lr=critic_lr)
        self.Q_optimizer=torch.optim.Adam(self.Q_net.parameters(),lr=critic_lr)
        self.num_transition=0
        self.num_training=1
        
        self.value_criterion=nn.MSELoss()
        self.Q_criterion=nn.MSELoss()
        
        for target_param,param in zip(self.target_value_net.parameters(),self.value_net.parameters()):
            target_param.data.copy_(param.data)
            
        self.gamma=gamma
        self.tau=tau
        self.alpha=alpha
            
    def select_action(self,state):
        state=torch.FloatTensor(state).to(self.device)
        mu,log_sigma=self.policy_net(state)
        sigma=torch.exp(log_sigma)
        dist=Normal(mu,sigma)
        z=dist.sample()
        action=torch.tanh(z).detach().cpu().numpy()
        return action
    
    def store(self,s,a,r,s_,d):
        index=self.num_transition%Buffer_Capacity
        transition=Transition(s,a,r,s_,d)
        self.replay_buffer[index]=transition
        self.num_transition+=1
        
    def get_action_log_prob(self,state):
        batch_mu,batch_log_sigma=self.policy_net(state)
        batch_sigma=torch.exp(batch_log_sigma)
        dist=Normal(batch_mu,batch_sigma)
        z=dist.sample()
        action=torch.tanh(z)
        log_prob=dist.log_prob(z)-torch.log(1-action.pow(2)+min_Val)
        return action,log_prob,z,batch_mu,batch_log_sigma
    
    def update(self):
        if self.num_training%500==0:
            print("Training ...{}",format(self.num_training))
        s=torch.tensor([t.s for t in self.replay_buffer]).float().to(self.device)
        a=torch.tensor([t.a for t in self.replay_buffer]).float().to(self.device)
        r=torch.tensor([t.r for t in self.replay_buffer]).float().to(self.device)
        s_=torch.tensor([t.s_ for t in self.replay_buffer]).float().to(self.device)
        d=torch.tensor([t.d for t in self.replay_buffer]).int().to(self.device)
        
        for _ in range(gradient_steps):
            index=np.random.choice(range(Buffer_Capacity),batch_size,replace=False)
            bn_s=s[index]
            bn_a=a[index].reshape(-1,1)
            bn_r=r[index].reshape(-1,1)
            bn_s_ = s_[index]
            bn_d = d[index].reshape(-1, 1)
            
            target_value=self.target_value_net(bn_s_)
            next_q_value=bn_r+(1-bn_d)*self.gamma*target_value
            
            expected_value=self.value_net(bn_s)
            print(bn_s.shape,bn_a.shape)
            expected_Q=self.Q_net(bn_s,bn_a)
            
            sample_action,log_prob,z,batch_mu,batch_log_sigma=self.get_action_log_prob(bn_s)
            expected_new_Q=self.Q_net(bn_s,sample_action)
            next_value=expected_new_Q-log_prob
            
            v_loss=self.value_criterion(expected_value,next_value.detach())
            v_loss=v_loss.mean()
            
            Q_loss=self.Q_criterion(expected_Q,next_q_value.detach())
            Q_loss = Q_loss.mean()

            log_policy_target = expected_new_Q - expected_value

            pi_loss = log_prob * (log_prob- log_policy_target).detach()
            pi_loss = pi_loss.mean()
            
            self.value_optimizer.zero_grad()
            v_loss.backward(retain_graph=True)
            nn.utils.clip_grad_norm_(self.value_net.parameters(), 0.5)
            self.value_optimizer.step()

            self.Q_optimizer.zero_grad()
            Q_loss.backward(retain_graph = True)
            nn.utils.clip_grad_norm_(self.Q_net.parameters(), 0.5)
            self.Q_optimizer.step()

            self.policy_optimizer.zero_grad()
            pi_loss.backward(retain_graph = True)
            nn.utils.clip_grad_norm_(self.policy_net.parameters(), 0.5)
            self.policy_optimizer.step()
            
            for target_param, param in zip(self.Target_value_net.parameters(), self.value_net.parameters()):
                target_param.data.copy_(target_param * (1 - self.tau) + param * self.tau)

            self.num_training += 1


5. 测试SAC模块

5.1 初始化SAC

In [56]:
state=np.random.randn(5)
print('state:',state)
sac=SAC(state_dim=5,action_dim=1,hidden_dim=32,max_action=1,device='cuda')

state: [-1.60781207  1.21274204  0.19062339  0.48187438 -1.79303331]


5.2 测试select_action

In [41]:
act=sac.select_action(state)
print(act)

[0.9627316]


5.3 测试Replay_Buffer

In [58]:
for _ in range(10000):
    s=np.random.randn(5)
    a=sac.select_action(s)
    r=np.random.randn(1)
    s_=np.random.randn(5)
    d=np.random.randint(0,1)
    sac.store(s,a,r,s_,d)

5.4.1 测试update的replay模块

In [59]:
sac.update()

torch.Size([128, 5]) torch.Size([128, 1])
<class 'torch.Tensor'>


RuntimeError: Found dtype Double but expected Float

In [44]:
index=np.random.choice(Buffer_Capacity,2,replace=False)
s=torch.tensor([t.s for t in sac.replay_buffer]).float()
a=torch.tensor([t.a for t in sac.replay_buffer]).float()

In [45]:
a[index]

tensor([[ 0.9975],
        [-1.0000]])

In [47]:
q=Q(5,1,32)
q(s[index],a[index])

tensor([[0.1830],
        [0.2313]], grad_fn=<AddmmBackward0>)