<a href="https://colab.research.google.com/github/zoujiulong/Reinforcement-Learning/blob/main/PPO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

import gymnasium as gym
import torch
import torch.nn as nn
from torch.distributions import Categorical
import numpy as np
from torch.optim import Adam
import matplotlib.pyplot as plt
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class Actor(nn.Module):
    """Policy network mapping a state to a probability distribution over actions.

    Two fully connected hidden layers of width 256 with ReLU activations feed a
    linear layer of size ``action_dim``; a Softmax over the last dimension turns
    that into action probabilities, so ``forward`` returns probabilities, not
    logits.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden = 256
        layers = [
            nn.Linear(state_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, action_dim), nn.Softmax(dim=-1),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return action probabilities for the state tensor ``x``."""
        return self.model(x)

class Critic(nn.Module):
    """State-value network producing one scalar estimate per input state.

    Two ReLU hidden layers of width 256 followed by an unactivated linear head
    with a single output unit, so the result keeps a trailing dimension of 1.
    """

    def __init__(self, state_dim):
        super().__init__()
        width = 256
        layers = []
        in_features = state_dim
        for _ in range(2):
            layers += [nn.Linear(in_features, width), nn.ReLU()]
            in_features = width
        layers.append(nn.Linear(in_features, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the value estimate(s) for the state tensor ``x``."""
        return self.model(x)

def compute_gae(rewards, values, dones, gamma=0.99, lam=0.95, last_value=0.0):
    """Compute GAE(lambda) return targets for a sequence of transitions.

    Args:
        rewards: Per-step rewards r_t, length T.
        values: Value estimates V(s_t) for the same T steps. Any sequence
            (list, tuple, 1-D NumPy array) is accepted; it is copied, never
            mutated.
        dones: Per-step episode-end flags. A truthy flag disables both the
            bootstrap from V(s_{t+1}) and the propagation of the advantage
            across that step.
        gamma: Discount factor.
        lam: GAE smoothing parameter lambda.
        last_value: Bootstrap value for the state following the final step,
            used only if that step is not done. Defaults to 0; pass the value
            estimate of the next state when the rollout was cut mid-episode.

    Returns:
        A list of length T holding the return targets A_t + V(s_t).
    """
    # Copy into a plain list so a NumPy array is extended rather than broadcast
    # (``ndarray + [0]`` would add 0 elementwise instead of appending it).
    extended = list(values) + [last_value]
    gae = 0.0
    returns = []
    for i in reversed(range(len(rewards))):
        not_done = 1 - dones[i]
        delta = rewards[i] + gamma * extended[i + 1] * not_done - extended[i]
        gae = delta + gamma * lam * not_done * gae
        returns.append(gae + extended[i])
    # Built back-to-front with O(1) appends, then flipped once, instead of
    # repeated O(n) ``insert(0, ...)`` calls.
    returns.reverse()
    return returns

def _ppo_update(actor, critic, actor_opt, critic_opt, memory, *, gamma, lam,
                eps_clip, epochs, beta):
    """Run ``epochs`` full-batch PPO-clip updates on the transitions in ``memory``.

    Each entry of ``memory`` is a dict with keys ``state``, ``next_state``,
    ``action``, ``reward``, ``terminated``, ``truncated`` and ``log_prob``, as
    appended by ``train``. The actor and critic are updated from one combined
    loss: clipped surrogate + 0.5 * value MSE - beta * entropy.
    """
    # Stack into one ndarray first; building a tensor from a list of ndarrays is very slow.
    states = torch.as_tensor(np.array([m['state'] for m in memory]),
                             dtype=torch.float32, device=device)
    next_states = torch.as_tensor(np.array([m['next_state'] for m in memory]),
                                  dtype=torch.float32, device=device)
    actions = torch.as_tensor([m['action'] for m in memory], dtype=torch.long, device=device)
    old_log_probs = torch.as_tensor([m['log_prob'] for m in memory],
                                    dtype=torch.float32, device=device)
    dones = [m['terminated'] or m['truncated'] for m in memory]

    with torch.no_grad():
        values = critic(states).squeeze(-1)
        next_values = critic(next_states).squeeze(-1).tolist()

    # compute_gae bootstraps each non-done step from the value of the next
    # *stored* state, and from 0 after the last step. That is wrong for a step
    # truncated by the time limit (its successor is not in the buffer) and for
    # the last step of a buffer cut mid-episode. For those steps fold
    # gamma * V(s') into the reward so the TD error is r + gamma * V(s') - V(s)
    # while the advantage recursion is still cut at episode boundaries.
    last = len(memory) - 1
    rewards = []
    for i, m in enumerate(memory):
        reward = m['reward']
        if not m['terminated'] and (m['truncated'] or i == last):
            reward += gamma * next_values[i]
        rewards.append(reward)

    # Compute all return targets and advantages for the buffer.
    returns = torch.as_tensor(
        compute_gae(rewards, values.tolist(), dones, gamma=gamma, lam=lam),
        dtype=torch.float32, device=device)
    advantages = returns - values
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    for _ in range(epochs):
        dist = Categorical(actor(states))
        ratio = torch.exp(dist.log_prob(actions) - old_log_probs)
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * advantages
        actor_loss = -torch.min(surr1, surr2).mean()
        critic_loss = nn.MSELoss()(critic(states).squeeze(-1), returns)
        loss = actor_loss + 0.5 * critic_loss - beta * dist.entropy().mean()

        actor_opt.zero_grad()
        critic_opt.zero_grad()
        loss.backward()
        actor_opt.step()
        critic_opt.step()


def train():
    """Train a PPO-clip agent on Gymnasium's ``LunarLander-v3`` and plot progress.

    Transitions are collected into a rolling buffer. Every ``update_interval``
    steps the buffer is used for ``epochs`` full-batch PPO updates and then
    cleared. Training stops after ``max_episodes`` episodes, or as soon as the
    mean reward of the last 10 episodes exceeds 200. The per-episode reward and
    its 10-episode running mean are then plotted. The environment is always
    closed, even if training raises.
    """
    env = gym.make("LunarLander-v3")
    try:
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.n

        actor = Actor(state_dim, action_dim).to(device)
        critic = Critic(state_dim).to(device)
        actor_opt = Adam(actor.parameters(), lr=3e-4)
        critic_opt = Adam(critic.parameters(), lr=1e-3)

        max_episodes = 1000
        update_interval = 2048
        eps_clip = 0.2
        epochs = 10
        beta = 0.01  # entropy-bonus coefficient
        gamma = 0.99
        lam = 0.95

        memory = []
        episode_rewards = []
        avg = []
        for episode in range(max_episodes):
            state, _ = env.reset()
            total_reward = 0
            done = False

            while not done:
                # Acting needs no gradients; don't build an autograd graph per step.
                with torch.no_grad():
                    state_tensor = torch.as_tensor(state, dtype=torch.float32, device=device)
                    dist = Categorical(actor(state_tensor))
                    action = dist.sample()
                    log_prob = dist.log_prob(action).item()
                next_state, reward, terminated, truncated, _ = env.step(action.item())
                # A time-limit truncation ends the episode too; otherwise the loop
                # would keep stepping an environment whose episode is over.
                done = terminated or truncated

                memory.append({
                    'state': state,
                    'next_state': next_state,
                    'action': action.item(),
                    'reward': reward,
                    'terminated': terminated,
                    'truncated': truncated,
                    'log_prob': log_prob,
                })

                total_reward += reward
                state = next_state

                if len(memory) >= update_interval:
                    _ppo_update(actor, critic, actor_opt, critic_opt, memory,
                                gamma=gamma, lam=lam, eps_clip=eps_clip,
                                epochs=epochs, beta=beta)
                    memory.clear()

            episode_rewards.append(total_reward)
            avg_reward = np.mean(episode_rewards[-10:])
            avg.append(avg_reward)
            print(f"Episode {episode}: reward={total_reward:.1f}, avg10={avg_reward:.1f}")

            if avg_reward > 200:
                print("🎉 Solved!")
                break

        # Plot against the episodes actually run. An early "Solved!" break leaves
        # fewer than max_episodes points, so a fixed-length x axis would raise.
        episodes = range(len(episode_rewards))
        plt.plot(episodes, episode_rewards, label='reward')
        plt.plot(episodes, avg, label='avg_reward')
        plt.legend()
        plt.show()
    finally:
        env.close()

# Script entry point: train only when executed directly, not when imported.
if __name__ == "__main__":
    train()
