In [None]:
!pip install swig
!pip install gym[box2d]

In [73]:
import tensorflow as tf
from tensorflow.keras import layers
import tensorflow_probability as tfp

In [74]:
import gym
from tqdm import tqdm_notebook
import numpy as np
from collections import deque

In [75]:
class PolicyNetwork(tf.keras.Model):
    def __init__(self, observation_space, action_space):
        super(PolicyNetwork, self).__init__()
        # Initialize layers
        self.input_layer = tf.keras.layers.Dense(128, input_dim=observation_space, activation='relu')
        self.output_layer = tf.keras.layers.Dense(action_space, activation='softmax')

    def call(self, inputs):
        # Forward pass
        x = self.input_layer(inputs)
        action_probs = self.output_layer(x)
        return action_probs

In [76]:
class StateValueNetwork(tf.keras.Model):
    def __init__(self, observation_space):
        super(StateValueNetwork, self).__init__()
        self.input_layer = layers.Dense(128, activation='relu', input_shape=(observation_space,))
        self.output_layer = layers.Dense(1, activation=None)  # typically no activation for value output

    def call(self, inputs):
        x = self.input_layer(inputs)
        state_value = self.output_layer(x)
        return state_value

In [77]:
def select_action(network, state):
    ''' Selects an action given current state in TensorFlow
    Args:
    - network (tf.keras.Model): network to process state
    - state (Array): Array of action space in an environment

    Return:
    - (int): action that is selected
    - (float): log probability of selecting that action given state and network
    '''
    # convert state to float tensor, add 1 dimension
    state = tf.convert_to_tensor(state, dtype=tf.float32)
    state = tf.expand_dims(state, 0)

    # use network to predict action probabilities
    action_probs = network(state)

    # sample an action using the probability distribution
    action_dist = tfp.distributions.Categorical(probs=action_probs)
    action = action_dist.sample()

    # compute log probability of the selected action
    log_prob = action_dist.log_prob(action)

    # return action and log probability
    return action.numpy()[0], log_prob.numpy()[0]

In [78]:
env = gym.make(
    "LunarLander-v2"
)

# Instantiate the networks
policy_network = PolicyNetwork(env.observation_space.shape[0], env.action_space.n)
stateval_network = StateValueNetwork(env.observation_space.shape[0])

# Initialize the optimizers
policy_optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
stateval_optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)

  deprecation(
  deprecation(


In [79]:
#discount factor for future utilities
DISCOUNT_FACTOR = 0.99

#number of episodes to run
NUM_EPISODES = 500

#max steps per episode
MAX_STEPS = 10000

#score agent needs for environment to be solved
SOLVED_SCORE = 195

In [None]:
policy_network.trainable_variables

[<tf.Variable 'policy_network_5/dense_18/kernel:0' shape=(8, 128) dtype=float32, numpy=
 array([[-0.04808192, -0.1171983 ,  0.06812976, ..., -0.15145896,
          0.00117312,  0.03061894],
        [ 0.17905344,  0.07582764, -0.15012689, ...,  0.08952861,
          0.15179394, -0.00740643],
        [ 0.08402084,  0.00120749,  0.08395649, ..., -0.05509317,
          0.08067052, -0.17857836],
        ...,
        [-0.03884403, -0.10525813,  0.14588685, ..., -0.19521789,
         -0.06740102,  0.192192  ],
        [-0.0200502 ,  0.1033379 ,  0.11397909, ...,  0.13289063,
         -0.14636038,  0.10663916],
        [ 0.05271553,  0.19778956,  0.06622379, ...,  0.11798503,
          0.11922942,  0.03053285]], dtype=float32)>,
 <tf.Variable 'policy_network_5/dense_18/bias:0' shape=(128,) dtype=float32, numpy=
 array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0

In [80]:
import numpy as np
from collections import deque
from tqdm import tqdm

scores = []
recent_scores = deque(maxlen=100)

for episode in tqdm(range(NUM_EPISODES)):
    state = env.reset()
    done = False
    score = 0
    I = 1.0

    for step in range(MAX_STEPS):
        action, lp = select_action(policy_network, state)
        new_state, reward, done, _ = env.step(action)
        score += reward


        with tf.GradientTape() as policy_tape, tf.GradientTape() as value_tape:
            state_val = stateval_network(tf.convert_to_tensor([state], dtype=tf.float32))
            new_state_val = stateval_network(tf.convert_to_tensor([new_state], dtype=tf.float32))

            if done:
                new_state_val = tf.constant([[0.0]])

            # 计算损失
            val_loss = tf.reduce_mean(tf.square(reward + DISCOUNT_FACTOR * new_state_val - state_val))
            advantage = reward + DISCOUNT_FACTOR * tf.squeeze(new_state_val) - tf.squeeze(state_val)
            #advantage = tf.convert_to_tensor(advantage)
            lp_tensor = tf.convert_to_tensor(lp, dtype=tf.float32)
            policy_loss = -lp_tensor * advantage
            policy_loss *= I

        # 反向传播和优化
        gradients = policy_tape.gradient(policy_loss, policy_network.trainable_variables)
        policy_optimizer.apply_gradients(zip(gradients, policy_network.trainable_variables))

        gradients = value_tape.gradient(val_loss, stateval_network.trainable_variables)
        stateval_optimizer.apply_gradients(zip(gradients, stateval_network.trainable_variables))

        if done:
            break

        state = new_state
        I *= DISCOUNT_FACTOR

    scores.append(score)
    recent_scores.append(score)

    if np.mean(recent_scores) >= SOLVED_SCORE:
        break


  if not isinstance(terminated, (bool, np.bool8)):
  0%|          | 0/500 [00:00<?, ?it/s]


ValueError: No gradients provided for any variable: (['policy_network_6/dense_22/kernel:0', 'policy_network_6/dense_22/bias:0', 'policy_network_6/dense_23/kernel:0', 'policy_network_6/dense_23/bias:0'],). Provided `grads_and_vars` is ((None, <tf.Variable 'policy_network_6/dense_22/kernel:0' shape=(8, 128) dtype=float32, numpy=
array([[ 0.19794314, -0.0693457 , -0.19295913, ...,  0.15815432,
        -0.05465674,  0.07426123],
       [ 0.1701522 , -0.03206649,  0.04759721, ...,  0.04679726,
        -0.10077876, -0.10581645],
       [ 0.17518122, -0.0837068 ,  0.19845034, ...,  0.06641267,
        -0.1842982 , -0.17843804],
       ...,
       [-0.11206261, -0.11143719, -0.04583035, ..., -0.12017128,
        -0.01050779,  0.14847769],
       [ 0.15997095, -0.02252905, -0.03360023, ...,  0.16476949,
        -0.10409091,  0.16908588],
       [-0.00198098,  0.02651316, -0.13287775, ..., -0.07062525,
         0.10262774, -0.20046265]], dtype=float32)>), (None, <tf.Variable 'policy_network_6/dense_22/bias:0' shape=(128,) dtype=float32, numpy=
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>), (None, <tf.Variable 'policy_network_6/dense_23/kernel:0' shape=(128, 4) dtype=float32, numpy=
array([[ 1.06330603e-01, -7.62804747e-02,  8.71686339e-02,
        -3.06594819e-02],
       [ 4.23246026e-02, -7.03360438e-02, -2.00807154e-01,
         1.69835448e-01],
       [ 1.64523661e-01,  5.05516529e-02,  1.86014622e-01,
         1.39171332e-01],
       [ 1.99945509e-01,  1.34308040e-01, -1.03957653e-01,
        -1.43412173e-01],
       [-1.70135513e-01, -1.85653150e-01, -2.60870159e-03,
        -7.92016834e-02],
       [-1.80171192e-01,  3.73846292e-03, -7.71536976e-02,
        -2.40430683e-02],
       [-1.97584257e-01, -8.12276006e-02, -1.36408329e-01,
        -1.33610368e-02],
       [-1.83209807e-01, -1.48252204e-01,  1.66556507e-01,
         4.59938347e-02],
       [-1.11017525e-01, -1.86573043e-01,  1.22067213e-01,
         1.56150490e-01],
       [-9.45316032e-02,  1.81597084e-01,  5.29719889e-03,
         3.69188786e-02],
       [-9.19125974e-03,  1.23067260e-01, -9.56045091e-03,
        -1.36263207e-01],
       [-7.22934306e-03, -1.07087165e-01,  1.63225889e-01,
         1.64157122e-01],
       [-4.02531773e-02,  1.16891235e-01, -1.45043954e-01,
        -1.06571943e-01],
       [ 5.09708524e-02, -1.94732040e-01,  9.35523808e-02,
        -1.94493532e-01],
       [ 2.11138606e-01,  1.02280140e-01, -1.52217075e-01,
         5.03168702e-02],
       [-8.97288844e-02,  7.14983940e-02, -9.10276175e-03,
         1.98337466e-01],
       [ 7.81753063e-02, -1.79849774e-01,  1.45690233e-01,
         9.75645781e-02],
       [-2.16089189e-02, -3.82962823e-02, -1.91314161e-01,
         1.47616774e-01],
       [-1.79549575e-01,  1.73332334e-01, -1.81627437e-01,
        -1.51302159e-01],
       [-1.61707982e-01, -1.88915849e-01,  1.72510505e-01,
         4.80477512e-02],
       [-1.92128420e-01, -6.60528243e-03,  1.76554978e-01,
        -1.89498320e-01],
       [-1.17404237e-01,  9.01598930e-02,  2.82741934e-02,
        -1.34737760e-01],
       [ 2.01421380e-02, -9.91848260e-02,  7.11948872e-02,
         8.65353644e-02],
       [-2.10304424e-01,  9.51729715e-03, -1.50332302e-01,
         3.99405658e-02],
       [ 1.57243192e-01,  1.71725899e-02, -1.85346752e-01,
         6.40893579e-02],
       [ 1.23017550e-01, -8.94089043e-02,  6.13142848e-02,
        -2.24143863e-02],
       [ 1.50053561e-01, -8.13153386e-03, -1.69534236e-01,
        -6.78670257e-02],
       [-1.39053106e-01, -3.67056280e-02,  2.10406840e-01,
         4.42417860e-02],
       [ 1.49842411e-01,  1.60798311e-01,  5.60019910e-03,
         2.09390640e-01],
       [ 1.03222132e-01,  1.24323189e-01,  1.77772105e-01,
        -6.23135865e-02],
       [-1.15693979e-01, -1.94578618e-01, -1.28901199e-01,
         5.80451488e-02],
       [ 2.11376458e-01,  1.59913957e-01, -6.07385933e-02,
        -1.96273074e-01],
       [ 2.76094228e-02, -1.66658774e-01,  6.43445253e-02,
        -2.04343870e-01],
       [-1.90600947e-01,  1.10443950e-01, -8.67892802e-02,
         8.83081555e-02],
       [-3.82841974e-02, -4.00606841e-02, -6.12005591e-03,
        -1.07019715e-01],
       [ 3.65974605e-02, -8.83988887e-02, -1.65737927e-01,
         1.21758521e-01],
       [ 1.61015868e-01, -3.02245319e-02,  1.13287181e-01,
         1.38407111e-01],
       [ 1.18250966e-01,  1.86343700e-01,  1.98873222e-01,
         1.36401862e-01],
       [ 1.44938737e-01,  7.39272535e-02,  4.14124429e-02,
        -1.77236661e-01],
       [-1.48533702e-01,  1.08207464e-02,  1.46355182e-02,
         1.72323346e-01],
       [-1.38949782e-01, -1.12002224e-01, -1.08594894e-02,
        -7.52977133e-02],
       [-1.49808750e-01, -1.63368538e-01, -1.64799631e-01,
         1.05948657e-01],
       [-1.28717482e-01,  1.04395717e-01,  3.52273583e-02,
         1.77274197e-02],
       [ 1.51114136e-01,  1.33358300e-01,  3.85555327e-02,
        -6.00708723e-02],
       [ 1.28580153e-01, -4.14508134e-02, -4.47023213e-02,
        -1.98886141e-01],
       [ 2.88581848e-02, -6.99032694e-02,  6.37893081e-02,
        -4.77938652e-02],
       [ 2.71584541e-02, -9.37154070e-02, -1.96547106e-01,
        -9.72436368e-02],
       [-3.35771888e-02,  1.30341291e-01,  7.88375139e-02,
        -8.32007527e-02],
       [-1.00751132e-01,  1.80264980e-01, -2.36517191e-03,
        -2.01647788e-01],
       [-1.38182834e-01,  9.05860960e-02, -9.16926414e-02,
         7.34997094e-02],
       [-3.98940593e-02, -1.58407629e-01, -7.71466792e-02,
         4.27599549e-02],
       [ 5.26261032e-02, -1.19656049e-01, -4.57668751e-02,
         1.34795457e-01],
       [-1.79478198e-01, -9.24279615e-02,  8.54624212e-02,
         3.62902433e-02],
       [-1.77766681e-02,  1.18063360e-01,  1.29462510e-01,
        -1.78109631e-01],
       [ 1.38703585e-02, -1.71708986e-01,  2.00262815e-01,
        -1.48620009e-01],
       [-5.82385659e-02,  1.38480186e-01, -2.01725319e-01,
         8.41534734e-02],
       [ 8.69933069e-02, -1.56721979e-01,  1.35901630e-01,
         1.93520516e-01],
       [ 2.84231752e-02, -5.01396060e-02,  7.23111331e-02,
         2.09705770e-01],
       [-1.31066903e-01, -8.65731984e-02, -1.33068413e-02,
        -1.78021953e-01],
       [-1.15318745e-02,  8.17998052e-02, -1.64541557e-01,
         1.65297359e-01],
       [ 2.11355954e-01, -6.60710186e-02,  1.82593077e-01,
        -1.37107968e-01],
       [-1.65402234e-01, -2.03320950e-01, -7.22186267e-03,
        -1.40096068e-01],
       [ 1.27942622e-01, -4.78529781e-02,  1.10772520e-01,
        -3.04414779e-02],
       [-1.25009879e-01, -7.90849328e-03, -4.47642803e-03,
        -1.35223299e-01],
       [ 1.84497058e-01,  1.09852880e-01,  1.91296458e-01,
        -9.60217118e-02],
       [-6.73592091e-03,  1.30437315e-03, -7.76143372e-03,
        -7.50616193e-03],
       [-7.36220181e-02,  5.35337925e-02,  1.10164881e-01,
         1.92249089e-01],
       [-2.07921207e-01,  1.77693635e-01,  9.63269174e-02,
         1.54321641e-01],
       [ 1.43261760e-01,  1.79507434e-01,  1.97145283e-01,
         1.97301835e-01],
       [-5.02301753e-03,  1.51138902e-01, -1.03302903e-01,
         9.75871980e-02],
       [ 7.83579350e-02,  7.28254914e-02,  1.81492090e-01,
        -1.76267520e-01],
       [-1.88843012e-01,  1.01403207e-01, -1.09601371e-01,
         1.55131638e-01],
       [-1.41087472e-01, -1.81775570e-01,  3.63094062e-02,
         1.11919075e-01],
       [ 1.24859214e-02, -2.06852838e-01,  1.40403539e-01,
         3.57478261e-02],
       [-1.15933865e-02,  7.23766387e-02,  4.02269661e-02,
         4.84882593e-02],
       [ 6.27103150e-02,  1.56307846e-01,  1.96936727e-01,
         6.58337474e-02],
       [ 1.16093725e-01, -1.43838853e-01,  1.01455092e-01,
         1.05234832e-01],
       [-1.73065573e-01, -1.37539014e-01,  2.06036419e-02,
        -1.02131903e-01],
       [ 2.02448845e-01, -2.70310193e-02, -1.28366098e-01,
        -5.43069392e-02],
       [-4.44632024e-02, -1.14134535e-01, -1.91864043e-01,
        -8.17696601e-02],
       [-4.21836376e-02,  1.23209149e-01,  1.70508862e-01,
        -2.93294489e-02],
       [ 9.32331383e-02, -1.17717534e-02, -1.21456385e-01,
         3.00101191e-02],
       [-7.11470991e-02,  1.43404752e-01,  3.26597840e-02,
        -8.50632638e-02],
       [ 1.50284529e-01, -1.59732893e-01,  1.31208360e-01,
        -1.24481536e-01],
       [-7.56400973e-02,  1.77500486e-01,  9.05575454e-02,
         1.74867332e-01],
       [ 1.68252945e-01, -1.09682977e-03,  8.75681639e-02,
         1.53979391e-01],
       [ 6.71943724e-02,  1.01254910e-01, -2.02452600e-01,
         9.03354883e-02],
       [ 1.43554449e-01, -2.54893601e-02, -3.46149504e-02,
        -7.50625134e-02],
       [ 2.06687778e-01, -8.90859962e-03,  9.56845582e-02,
         1.93929553e-01],
       [ 7.33094513e-02,  1.09358341e-01,  9.20066833e-02,
         1.84794635e-01],
       [ 1.74333155e-01, -1.41464800e-01,  1.86514854e-01,
         1.19621068e-01],
       [ 6.53931499e-03,  1.41471922e-02,  4.34999466e-02,
         1.93931907e-01],
       [ 6.03076816e-02,  6.03216887e-03, -1.45713359e-01,
         8.42989683e-02],
       [-1.92899674e-01,  1.61642462e-01, -1.17075972e-01,
         1.40048802e-01],
       [ 1.81811899e-02, -1.87134773e-01,  1.68589890e-01,
         2.04959035e-01],
       [ 5.84782362e-02, -1.19051211e-01, -1.03262663e-02,
         2.00820565e-02],
       [ 1.16196036e-01,  1.10352486e-01, -2.98827887e-02,
         1.53115690e-02],
       [-1.66746467e-01,  1.89272523e-01, -1.44556582e-01,
        -1.31497532e-01],
       [-9.62741897e-02,  9.40487981e-02, -3.49715799e-02,
         8.60926509e-02],
       [ 7.56886899e-02,  2.44906843e-02,  1.99107200e-01,
         1.32660091e-01],
       [-4.68522161e-02,  7.35274255e-02,  1.28319830e-01,
        -2.42428482e-03],
       [-1.07012749e-01,  1.75431132e-01, -5.05063534e-02,
        -1.78925514e-01],
       [ 2.99513340e-06, -1.23999611e-01,  1.81973338e-01,
        -1.57221183e-01],
       [-1.24859065e-01,  8.51984620e-02,  1.40994161e-01,
         1.94356233e-01],
       [ 1.07270777e-01,  2.11702585e-01,  1.67596757e-01,
        -1.41149729e-01],
       [ 3.30773145e-02, -1.30993143e-01, -2.04981998e-01,
         1.19993359e-01],
       [-1.14761636e-01, -1.54987276e-02, -4.82487977e-02,
        -9.42703336e-02],
       [ 4.02080417e-02,  1.45393580e-01, -2.35997736e-02,
        -1.45435154e-01],
       [ 1.97052062e-02,  4.72647548e-02,  1.69342816e-01,
        -1.15585625e-02],
       [ 9.36512649e-02,  6.70159161e-02,  7.50075579e-02,
        -1.87226176e-01],
       [ 5.39800525e-03, -3.11894715e-03,  1.60317451e-01,
        -1.39260769e-01],
       [-5.99365681e-02,  9.17688012e-03,  1.16972774e-01,
        -6.62216842e-02],
       [ 1.03269160e-01,  1.79090917e-01, -1.96864337e-01,
         5.39747179e-02],
       [ 1.19186878e-01,  5.26302755e-02, -2.06176937e-01,
        -1.34198695e-01],
       [ 8.72676373e-02, -4.25116569e-02, -1.28264874e-02,
         1.05408877e-01],
       [ 4.60210741e-02, -1.60351723e-01, -5.27526736e-02,
        -3.27737480e-02],
       [-1.70659781e-01,  6.62590861e-02,  7.94691443e-02,
         9.18562710e-02],
       [ 1.22167706e-01, -1.83340549e-01,  1.88295186e-01,
        -2.68437117e-02],
       [ 1.18517131e-02,  1.87831223e-01,  7.75187314e-02,
         4.32993770e-02],
       [-1.72440112e-02, -1.83416605e-02, -1.36687592e-01,
        -2.94106752e-02],
       [ 5.29761910e-02, -8.40286911e-02, -4.86193001e-02,
        -8.58268440e-02],
       [-1.18614778e-01, -5.67149073e-02,  1.28274858e-01,
         1.80760264e-01],
       [ 1.91902876e-01, -9.28475708e-02, -1.87610000e-01,
        -3.47717106e-02],
       [ 2.67806798e-02, -1.83247074e-01,  1.94213092e-02,
        -1.65749788e-02],
       [ 9.24155712e-02,  1.74851865e-01, -1.52744547e-01,
        -7.55293965e-02],
       [-8.27703178e-02,  1.71894014e-01, -1.37792498e-01,
        -9.40673649e-02],
       [-1.36231586e-01, -1.15223691e-01,  9.92234051e-02,
        -8.06782246e-02],
       [-1.56964332e-01, -8.03944319e-02,  7.51347840e-03,
        -1.08092859e-01]], dtype=float32)>), (None, <tf.Variable 'policy_network_6/dense_23/bias:0' shape=(4,) dtype=float32, numpy=array([0., 0., 0., 0.], dtype=float32)>)).

In [None]:
policy_network.trainable_variables

[<tf.Variable 'policy_network_5/dense_18/kernel:0' shape=(8, 128) dtype=float32, numpy=
 array([[-0.04808192, -0.1171983 ,  0.06812976, ..., -0.15145896,
          0.00117312,  0.03061894],
        [ 0.17905344,  0.07582764, -0.15012689, ...,  0.08952861,
          0.15179394, -0.00740643],
        [ 0.08402084,  0.00120749,  0.08395649, ..., -0.05509317,
          0.08067052, -0.17857836],
        ...,
        [-0.03884403, -0.10525813,  0.14588685, ..., -0.19521789,
         -0.06740102,  0.192192  ],
        [-0.0200502 ,  0.1033379 ,  0.11397909, ...,  0.13289063,
         -0.14636038,  0.10663916],
        [ 0.05271553,  0.19778956,  0.06622379, ...,  0.11798503,
          0.11922942,  0.03053285]], dtype=float32)>,
 <tf.Variable 'policy_network_5/dense_18/bias:0' shape=(128,) dtype=float32, numpy=
 array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0