In [None]:
import gym
from gym import envs
d = envs.registry
for k, v in d.items():
    print(f"{k} -> {v}")

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque

class DQN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class DQNAgent:
    def __init__(self, input_dim, hidden_dim, output_dim, lr=1e-3, gamma=0.99, epsilon=1.0, epsilon_min=0.01, epsilon_decay=0.995):
        self.q_net = DQN(input_dim, hidden_dim, output_dim)
        self.target_net = DQN(input_dim, hidden_dim, output_dim)
        self.target_net.load_state_dict(self.q_net.state_dict())
        self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)
        self.memory = deque(maxlen=2000)
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.output_dim = output_dim

    def get_action(self, state):
        if np.random.rand() < self.epsilon:
            return random.randrange(self.output_dim)
        state = torch.FloatTensor(state).unsqueeze(0)
        with torch.no_grad():
            q_values = self.q_net(state)
        return q_values.argmax().item()

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def update(self, batch_size=64):
        if len(self.memory) < batch_size:
            return
        batch = random.sample(self.memory, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        states = torch.FloatTensor(states)
        actions = torch.LongTensor(actions).unsqueeze(1)
        rewards = torch.FloatTensor(rewards).unsqueeze(1)
        next_states = torch.FloatTensor(next_states)
        dones = torch.FloatTensor(dones).unsqueeze(1)

        q_values = self.q_net(states).gather(1, actions)
        next_q_values = self.target_net(next_states).max(1)[0].unsqueeze(1)
        target = rewards + self.gamma * next_q_values * (1 - dones)

        loss = nn.MSELoss()(q_values, target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target(self):
        self.target_net.load_state_dict(self.q_net.state_dict())
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay
            
            
# ...existing code...
env = gym.make('CartPole-v1')
input_dim = env.observation_space.shape[0]
output_dim = env.action_space.n
hidden_dim = 64
agent = DQNAgent(input_dim, hidden_dim, output_dim)

episodes = 200
batch_size = 64
for ep in range(episodes):
    state, _ = env.reset()
    total_reward = 0
    for t in range(500):
        action = agent.get_action(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        agent.remember(state, action, reward, next_state, done)
        agent.update(batch_size)
        state = next_state
        total_reward += reward
        if done:
            break
    agent.update_target()
    print(f"Episode {ep}, Reward: {total_reward}, Epsilon: {agent.epsilon:.2f}")
# ...existing code...

Episode 0, Reward: 28.0, Epsilon: 0.99
Episode 1, Reward: 18.0, Epsilon: 0.99
Episode 2, Reward: 19.0, Epsilon: 0.99
Episode 3, Reward: 31.0, Epsilon: 0.98


  states = torch.FloatTensor(states)


Episode 4, Reward: 16.0, Epsilon: 0.98
Episode 5, Reward: 28.0, Epsilon: 0.97
Episode 6, Reward: 30.0, Epsilon: 0.97
Episode 7, Reward: 15.0, Epsilon: 0.96
Episode 8, Reward: 20.0, Epsilon: 0.96
Episode 9, Reward: 16.0, Epsilon: 0.95
Episode 10, Reward: 11.0, Epsilon: 0.95
Episode 11, Reward: 14.0, Epsilon: 0.94
Episode 12, Reward: 15.0, Epsilon: 0.94
Episode 13, Reward: 26.0, Epsilon: 0.93
Episode 14, Reward: 27.0, Epsilon: 0.93
Episode 15, Reward: 35.0, Epsilon: 0.92
Episode 16, Reward: 24.0, Epsilon: 0.92
Episode 17, Reward: 61.0, Epsilon: 0.91
Episode 18, Reward: 18.0, Epsilon: 0.91
Episode 19, Reward: 20.0, Epsilon: 0.90
Episode 20, Reward: 12.0, Epsilon: 0.90
Episode 21, Reward: 28.0, Epsilon: 0.90
Episode 22, Reward: 17.0, Epsilon: 0.89
Episode 23, Reward: 18.0, Epsilon: 0.89
Episode 24, Reward: 21.0, Epsilon: 0.88
Episode 25, Reward: 16.0, Epsilon: 0.88
Episode 26, Reward: 21.0, Epsilon: 0.87
Episode 27, Reward: 22.0, Epsilon: 0.87
Episode 28, Reward: 37.0, Epsilon: 0.86
Episod

In [None]:
import matplotlib.pyplot as plt
from IPython import display
env = gym.make('CartPole-v1', render_mode='rgb_array')    
state, _ = env.reset()  # 处理Gym v0.26+的返回值
frame = env.render()
plt.ion()  # 打开交互模式
fig, ax = plt.subplots()
img = ax.imshow(frame)
display.display(plt.gcf())

for _ in range(512):
    try:
        frame = env.render()
    except Exception as e:
        print(f"渲染错误: {e}")
        break
    img.set_data(frame)  # 只更新图像数据，而不是重建
    display.clear_output(wait=True)
    display.display(plt.gcf())
    action = agent.get_action(state)
    # if state[2] < 0:
    #     action = 0
    # else:
    #     action = 1
    next_state, reward, terminated, truncated, _ = env.step(action)  # Gym v0.26+的返回参数
    state = next_state
    if  truncated:
        state, _ = env.reset()  # 重置环境并获取初始状态
plt.ioff()
env.close()