In [1]:
import os
from typing import Dict, List, Tuple

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from IPython.display import clear_output
from torch.distributions import Categorical
import math
import torch.nn as nn

In [2]:
# Create the CartPole-v1 environment used by every later cell.
env = gym.make("CartPole-v1", render_mode="rgb_array")

# Use CUDA when PyTorch reports it as available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Network input/output sizes are read from the environment's spaces.
obs_dim, action_dim = env.observation_space.shape[0], env.action_space.n
print(env.observation_space.shape, obs_dim, action_dim)

(4,) 4 2


In [5]:
class ActionNetwork(nn.Module):
    """Policy network mapping an input vector to action probabilities.

    Args:
        in_dim: Size of the last dimension of the input tensor.
        out_dim: Number of outputs, i.e. one probability per action.
    """

    def __init__(self, in_dim: int, out_dim: int):
        """Build an MLP with two hidden layers of 128 units and ReLU activations."""
        super(ActionNetwork, self).__init__()

        self.layers = nn.Sequential(
            nn.Linear(in_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a probability distribution over actions for ``x``.

        Args:
            x: Tensor whose last dimension has size ``in_dim``; any number of
                leading (batch) dimensions is accepted.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``out_dim`` that sums to 1.
        """
        # Normalise explicitly over the action axis. Calling softmax without
        # ``dim`` is deprecated and picks the axis implicitly from the input
        # rank, which is axis 0 (not the action axis) for 3-D inputs.
        return F.softmax(self.layers(x), dim=-1)
    
class QValueNetwork(nn.Module):
    """Value network: scores an input vector with a single scalar.

    The output is used to judge how good the chosen actions are.

    Args:
        in_dim: Size of the last dimension of the input tensor.
    """

    def __init__(self, in_dim: int):
        """Build an MLP with two hidden layers of 128 units and one output unit."""
        super().__init__()
        hidden = 128
        # Modules are created in the same order as they are applied, so the
        # sequential indices (0, 2, 4 for the linear layers) stay stable.
        stack = [
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the scalar value estimate for ``x`` (last dimension of size 1)."""
        value = self.layers(x)
        return value

In [8]:
# Instantiate the policy (actor) and value (critic) networks and place them on
# the same device that select_action() moves observations to; otherwise the
# forward pass fails with a device mismatch whenever CUDA is available.
actor = ActionNetwork(obs_dim, action_dim).to(device)
critic = QValueNetwork(obs_dim).to(device)

# Optimizers are built after the modules are moved so they hold references to
# the parameters that live on `device`.
actor_optimizer = optim.Adam(actor.parameters(), lr=0.0001)
critic_optimizer = optim.Adam(critic.parameters(), lr=0.0001)

In [9]:
def select_action(state):
    """Sample an action from the policy network for the given state.

    Args:
        state: Observation accepted by ``torch.FloatTensor`` (e.g. the array
            returned by the environment).

    Returns:
        A pair ``(action, log_prob)``: the sampled action as a Python number
        and the log-probability of that action under the current policy
        (a tensor that is still attached to the actor's graph).
    """
    observation = torch.FloatTensor(state).to(device)
    policy = Categorical(actor(observation))
    chosen = policy.sample()
    log_prob = policy.log_prob(chosen)
    return chosen.item(), log_prob

In [None]:
# Episode-collection loop: runs `max_epoch` episodes, sampling actions from the
# policy via select_action() and recording each transition.
# NOTE(review): within the lines visible here, no network update is performed;
# the update steps appear only as placeholder strings at the end of the loop.

# Presumably the discount factor; it is not referenced in the visible lines.
gamma = 0.98

steps = []        # number of steps taken in each finished episode
U_s = []          # not written to in the visible lines
view_losses = []  # not written to in the visible lines
max_epoch = 1000  # number of episodes to run
for i in range(max_epoch):
    score = 0  # sum of rewards collected in the current episode
    step = 0   # steps taken in the current episode
    
    # NOTE(review): the same seed is passed on every reset, so every episode
    # presumably starts from the same initial state — confirm this is intended.
    state, _ = env.reset(seed=3)
    trajectories = []  # [state, action, reward, log_prob] per step of this episode
    while True:
        # `loss` is the second value returned by select_action(), i.e. the
        # log-probability of the sampled action, not a loss value.
        action, loss = select_action(state)
        # (string below) Observe the reward and the new state from the environment.
        '''从环境中观测到奖励和新的状态'''
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        
        step += 1
        score += reward
        
        trajectories.append([state, action, reward, loss])
        # Stop when the environment ends the episode or once more than 200
        # steps have been taken (the check runs after the increment).
        if done or step > 200:
            steps.append(step)
            break
        
        # The strings below are placeholders describing the intended update
        # steps; no code for them exists in the visible lines:
        #   1. Decide with the policy network, without executing the action.
        #   2. Let the value network score.
        #   3. Compute the TD target and the TD error.
        #   4. Update the value network.
        #   5. Update the policy network.
        '''根据策略网络做决策,但不让智能体执行动作'''
        '''让价值网络打分'''
        '''计算 TD 目标和 TD 误差'''
        '''更新价值网络'''
        '''更新策略网络'''
        state = next_state