In [29]:
import numpy as np
import tensorflow as tf
import gym
import time
from spinup.utils.logx import EpochLogger

class ReplayBuffer:
    def __init__(self,obs_dim,act_dim,size):
        self.obs1_buf=np.zeros((size,obs_dim),dtype=np.float32)
        self.obs2_buf=np.zeros((size,obs_dim),dtype=np.float32)
        self.acts_buf=np.zeros((size,act_dim),dtype=np.float32)
        self.rews_buf=np.zeros(size,dtype=np.float32)
        self.done_buf=np.zeros(size,dtype=np.float32)
        self.ptr,self.size,self.max_size=0,0,size
        
    # simple FIFO
    def store(self,obs,act,rew,next_obs,done):
        self.obs1_buf[self.ptr]=obs
        self.acts_buf[self.ptr]=act
        self.rews_buf[self.ptr]=rew
        self.obs2_buf[self.ptr]=next_obs
        self.done_buf[self.ptr]=done
        self.ptr=(self.ptr+1)%self.max_size
        self.size=min(self.size+1,self.max_size)
        
    def sample_batch(self,batch_size=32):
        idx=np.random.randint(0,self.size,size=batch_size)
        return dict(obs1=self.obs1_buf[idx],
                    obs2=self.obs2_buf[idx],
                    acts=self.acts_buf[idx],
                    rews=self.rews_buf[idx],
                    done=self.done_buf[idx])
 


"""

td3: twin delayed DDPG

"""
def mlp(x,hidden_sizes=[32,],activation=tf.tanh,output_activation=None):
    for i in hidden_sizes[:-1]:
        x=tf.layers.dense(x,units=i,activation=activation)
    return tf.layers.dense(x,units=hidden_sizes[-1],activation=output_activation)


def mlp_actor_critic(x,a,hidden_sizes=(400,300),activation=tf.nn.relu,
                     output_activation=tf.tanh,action_space=None):
    act_dim=a.shape.as_list()[-1]
    act_limit=action_space.high[0]
    with tf.variable_scope('pi'):
        pi=act_limit*mlp(x,list(hidden_sizes)+[act_dim],activation,output_activation)
    with tf.variable_scope('q1'):
        q1=tf.squeeze(mlp(tf.concat([x,a],axis=-1),list(hidden_sizes)+[1],activation,None),axis=1)
    with tf.variable_scope('q2'):
        q2=tf.squeeze(mlp(tf.concat([x,a],axis=-1),list(hidden_sizes)+[1],activation,None),axis=1)
    with tf.variable_scope('q1',reuse=True):
        q1_pi=tf.squeeze(mlp(tf.concat([x,pi],axis=-1),list(hidden_sizes)+[1],activation,None),axis=1)
    return pi,q1,q2,q1_pi


def get_vars(scope):
    return [v for v in tf.global_variables() if scope in v.name]


def count_vars(scope):
    v=get_vars(scope)
    return sum([np.prod(var.shape.as_list()) for var in v])
        
    
"""
main function
"""
def td3(env_fn,ac_kwargs=dict(),act_noise=0.1, target_noise=0.2, noise_clip=0.5,gamma=0.99,pi_lr=1e-3, 
        q_lr=1e-3, polyak=0.995,replay_size=int(1e6),steps_per_epoch=5000,epochs=100,start_steps=10000,
        max_ep_len=1000,batch_size=100,policy_delay=2,logger_kwargs=dict()):

    logger=EpochLogger(**logger_kwargs)
    # 以dict保存在该行之前出现的局部变量（即参数列表中的超参数）和传入的logger_kwargs信息，并打印
    logger.save_config(locals())
    
    tf.set_random_seed(0)
    np.random.seed(0)
    
    env,test_env=env_fn(),env_fn()
    ac_kwargs['action_space']=env.action_space
    obs_dim=env.observation_space.shape[0]
    # 连续动作
    act_dim=env.action_space.shape[0]
    # 连续动作最大值
    act_limit=env.action_space.high[0]
    
    
    # input to computation graph
    x_ph=tf.placeholder(tf.float32,[None,obs_dim])
    a_ph=tf.placeholder(tf.float32,[None,act_dim])
    x2_ph=tf.placeholder(tf.float32,[None,obs_dim])
    r_ph=tf.placeholder(tf.float32,[None,])
    d_ph=tf.placeholder(tf.float32,[None,])
    
    # main output from computation graph
    with tf.variable_scope('main'):
        pi,q1,q2,q1_pi=mlp_actor_critic(x_ph,a_ph,**ac_kwargs)
    
    # target policy net
    with tf.variable_scope('target'):
        pi_targ,_,_,_,=mlp_actor_critic(x2_ph,a_ph,**ac_kwargs)
        
    # target Q net
    with tf.variable_scope('target',reuse=True):
        epsilon=tf.random_normal(tf.shape(pi_targ),stddev=target_noise)
        epsilon=tf.clip_by_value(epsilon,-noise_clip,noise_clip)
        a2=pi_targ+epsilon
        a2=tf.clip_by_value(a2,-act_limit,act_limit)
        _,q1_targ,q2_targ,_=mlp_actor_critic(x2_ph,a2,**ac_kwargs)
    
    
    # 需要在定义完网络马上计算，防止后面Adam对网络参数进行复制，重复计算
    # print parameter number of each net
    var_counts=tuple(count_vars(scope) for scope in ['main/pi','main/q1','main/q2','main'])
    print('\nNumber of parameters: pi %d, q1: %d, q2: %d, total: %d\n'%var_counts)
    
    #Bellman backup for Q-learning
    min_q_targ=tf.minimum(q1_targ,q2_targ)
    backup=tf.stop_gradient(r_ph+gamma*(1-d_ph)*min_q_targ)
    
    # TD3 losses
    pi_loss=-tf.reduce_mean(q1_pi)
    q1_loss=tf.reduce_mean((q1-backup)**2)
    q2_loss=tf.reduce_mean((q2-backup)**2)
    q_loss=q1+q2
    
    # train ops for pi and q
    pi_optimizer=tf.train.AdamOptimizer(learning_rate=pi_lr)
    q_optimizer=tf.train.AdamOptimizer(learning_rate=q_lr)
    train_pi_op=pi_optimizer.minimize(pi_loss,var_list=get_vars('main/pi'))
    train_q_op=q_optimizer.minimize(q_loss,var_list=get_vars('main/q'))
    
    # Polyak averaging for target variables
    target_update=tf.group([tf.assign(v_targ,polyak*v_targ+(1-polyak)*v_main) 
                            for v_main,v_targ in zip(get_vars('main'),get_vars('target'))])
    
    # initializing target variables to match the main
    target_init=tf.group([tf.assign(v_targ,v_main) 
                          for v_main, v_targ in zip(get_vars('main'),get_vars('target'))])
    
    
    sess=tf.Session()
    sess.run(tf.global_variables_initializer())
    sess.run(target_init)
    
    # setup model saving
    logger.setup_tf_saver(sess,inputs={'x':x_ph,'a':a_ph},outputs={'pi':pi,'q1':q1,'q2':q2})
    
    replay_buffer=ReplayBuffer(obs_dim,act_dim,replay_size)
    
    def get_action(o,act_noise=0):
        a=sess.run(pi,feed_dict={x_ph:o.reshape(1,-1)})[0]
        a+=act_noise*np.random.randn(act_dim)
        return np.clip(a,-act_limit,act_limit)
    
    def test_agent(n=10):
        for j in range(n):
            o,r,d,ep_ret,ep_len=test_env.reset(),0,False,0,0
            while not(d or (ep_len==max_ep_len)):
                o,r,d,_=test_env.step(get_action(o))
                ep_ret+=r
                ep_len+=1
            logger.store(TestEpRet=ep_ret,TestEpLen=ep_len)
            
    
    start_time=time.time()
    o,r,d,ep_ret,ep_len=env.reset(),0,False,0,0
    total_steps=steps_per_epoch*epochs
    for t in range(total_steps):
        
        if t>start_steps:
            a=get_action(o,act_noise)
        else:
            a=env.action_space.sample()
         
        o2,r,d,_=env.step(a)
        ep_ret+=r
        ep_len+=1
        
        d=False if ep_len==max_ep_len else d
        replay_buffer.store(o,a,r,o2,d)
        o=o2
        
        # after each episode(trajectory) begin to train
        if d or (ep_len==max_ep_len):
            
            for j in range(ep_len):
                batch=replay_buffer.sample_batch(batch_size)
                feed_dict={x_ph:batch['obs1'],
                           a_ph:batch['acts'],
                           r_ph:batch['rews'],
                           x2_ph:batch['obs2'],
                           d_ph:batch['done']}
                
                # Q nets update
                q_step_ops=[q_loss,q1,q2,train_q_op]
                outs=sess.run(q_step_ops,feed_dict)
                logger.store(LossQ=outs[0],Q1Vals=outs[1],Q2Vals=outs[2])
                
                # delayed policy and all target nets update
                if j%policy_delay==0:
                    pi_step_ops=[pi_loss,train_pi_op,target_update]
                    outs=sess.run(pi_step_ops,feed_dict)
                    logger.store(LossPi=outs[0])
                    
            logger.store(EpRet=ep_ret,EpLen=ep_len)        
            o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
            
        # log after each epoch
        if t>0 and t%steps_per_epoch==0:
            epoch=t//steps_per_epoch
            
            # # Save model
            # if (epoch % save_freq == 0) or (epoch == epochs-1):
            #    logger.save_state({'env': env}, None)
            
            # test performance
            test_agent()
            
            # log info about this epoch
            logger.log_tabular('Epoch',epoch)
            logger.log_tabular('EpRet',with_min_and_max=True)
            logger.log_tabular('TestEpRet',with_min_and_max=True)
            logger.log_tabular('EpLen',average_only=True)
            logger.log_tabular('TestEpLen',average_only=True)
            logger.log_tabular('TotalEnvInteracts',t)
            logger.log_tabular('Q1Vals',with_min_and_max=True)
            logger.log_tabular('Q2Vals',with_min_and_max=True)
            logger.log_tabular('LossPi',average_only=True)
            logger.log_tabular('LossQ',average_only=True)
            logger.log_tabular('Time',time.time()-start_time)
            logger.dump_tabular()
            
            
            
   
tf.reset_default_graph()
td3(lambda:gym.make('HalfCheetah-v2'),steps_per_epoch=5000,epochs=10,logger_kwargs=dict(exp_name='td3'))
    

[32;1mLogging data to /tmp/experiments/1559656756/progress.txt[0m
[36;1mSaving config:
[0m
{
    "ac_kwargs":	{},
    "act_noise":	0.1,
    "batch_size":	100,
    "env_fn":	"<function <lambda> at 0x131bf3f28>",
    "epochs":	10,
    "exp_name":	"td3",
    "gamma":	0.99,
    "logger":	{
        "<spinup.utils.logx.EpochLogger object at 0x13388eb00>":	{
            "epoch_dict":	{},
            "exp_name":	"td3",
            "first_row":	true,
            "log_current_row":	{},
            "log_headers":	[],
            "output_dir":	"/tmp/experiments/1559656756",
            "output_file":	{
                "<_io.TextIOWrapper name='/tmp/experiments/1559656756/progress.txt' mode='w' encoding='UTF-8'>":	{
                    "mode":	"w"
                }
            }
        }
    },
    "logger_kwargs":	{
        "exp_name":	"td3"
    },
    "max_ep_len":	1000,
    "noise_clip":	0.5,
    "pi_lr":	0.001,
    "policy_delay":	2,
    "polyak":	0.995,
    "q_lr":	0.001,
    "replay_size

---------------------------------------
|             Epoch |               9 |
|      AverageEpRet |            -555 |
|          StdEpRet |            1.06 |
|          MaxEpRet |            -553 |
|          MinEpRet |            -556 |
|  AverageTestEpRet |            -600 |
|      StdTestEpRet |           0.651 |
|      MaxTestEpRet |            -599 |
|      MinTestEpRet |            -602 |
|             EpLen |           1e+03 |
|         TestEpLen |           1e+03 |
| TotalEnvInteracts |         4.5e+04 |
|     AverageQ1Vals |       -4.84e+10 |
|         StdQ1Vals |        1.05e+10 |
|         MaxQ1Vals |       -3.28e+05 |
|         MinQ1Vals |       -1.51e+11 |
|     AverageQ2Vals |       -5.27e+10 |
|         StdQ2Vals |        1.13e+10 |
|         MaxQ2Vals |       -3.58e+05 |
|         MinQ2Vals |       -1.58e+11 |
|            LossPi |        5.23e+10 |
|             LossQ |       -1.01e+11 |
|              Time |             281 |
---------------------------------------
