In [1]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
from collections import deque

In [4]:
# Build the CartPole-v1 environment. render_mode='rgb_array' is requested here,
# although no render call appears in the visible code.
env = gym.make('CartPole-v1', render_mode='rgb_array')
# Length of the observation vector; used as the input width of both networks.
obs_space = env.observation_space.shape[0]
# Number of discrete actions; used as the output width of the Actor.
action_space = env.action_space.n

class Actor(nn.Module):
    """Policy network mapping an observation to action probabilities.

    A single hidden layer of 64 ReLU units followed by a softmax over the
    last dimension, so each output row is a probability distribution over
    ``action_space`` actions.
    """

    def __init__(self, obs_space, action_space):
        super().__init__()
        hidden_width = 64
        # Layers are created in forward order and stored under ``fc`` so the
        # parameter names in the state dict stay ``fc.<index>.*``.
        layers = [
            nn.Linear(obs_space, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, action_space),
            nn.Softmax(dim=-1),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return action probabilities for the observation batch ``x``."""
        action_probs = self.fc(x)
        return action_probs

class Critic(nn.Module):
    """State-value network mapping an observation to a single scalar.

    A single hidden layer of 64 ReLU units followed by a linear output of
    width 1, so a batch of N observations yields an (N, 1) tensor.
    """

    def __init__(self, obs_space):
        super().__init__()
        hidden_width = 64
        # Stored under ``fc`` so parameter names in the state dict stay
        # ``fc.<index>.*``.
        layers = [
            nn.Linear(obs_space, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, 1),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the value estimate for the observation batch ``x``."""
        value = self.fc(x)
        return value

# --- Hyperparameters -------------------------------------------------------
# Discount factor. NOTE(review): the training loop below calls
# compute_returns(rewards, dones) without passing this value, so the
# function's own default (also 0.99) is what is actually used.
gamma = 0.99
# Half-width of the ratio clipping interval used in ppo_update.
clip_epsilon = 0.2
# Adam learning rates for the policy and value networks.
lr_actor = 3e-4
lr_critic = 1e-3
# Number of full-batch gradient passes ppo_update makes over each rollout.
train_epochs = 10
# Number of environment steps collected per training iteration
# (passed as num_steps to collect_trajectories).
batch_size = 256

# --- Networks and optimizers (read as module-level globals by ppo_update) ---
actor = Actor(obs_space, action_space)
critic = Critic(obs_space)
optimizer_actor = optim.Adam(actor.parameters(), lr=lr_actor)
optimizer_critic = optim.Adam(critic.parameters(), lr=lr_critic)

def preprocess_state(state):
    """Convert one observation into a float32 tensor with a leading batch axis.

    Args:
        state: A single observation (anything ``torch.tensor`` accepts).

    Returns:
        A new float32 tensor whose first dimension has size 1.
    """
    as_float = torch.tensor(state, dtype=torch.float32)
    return torch.unsqueeze(as_float, dim=0)

def collect_trajectories(env, policy, num_steps):
    """Roll out ``policy`` in ``env`` for exactly ``num_steps`` environment steps.

    Actions are sampled from the categorical distribution defined by the
    policy's output probabilities. Whenever an episode ends, the environment
    is reset and collection continues, so the returned lists may span several
    episodes and may end in the middle of one.

    Args:
        env: Environment whose ``reset()`` returns ``(obs, info)`` and whose
            ``step(action)`` returns
            ``(obs, reward, terminated, truncated, info)``.
        policy: Callable mapping a (1, obs_dim) float tensor to a
            (1, n_actions) tensor of action probabilities.
        num_steps: Number of environment steps to collect.

    Returns:
        Tuple ``(states, actions, rewards, dones, log_probs)`` of equal-length
        lists. ``dones[i]`` is True when step ``i`` ended its episode, either
        by termination or by truncation.
    """
    states, actions, rewards, dones, log_probs = [], [], [], [], []
    state, _ = env.reset()
    for _ in range(num_steps):
        state_tensor = preprocess_state(state)
        with torch.no_grad():
            probs = policy(state_tensor).squeeze(0)
        dist = Categorical(probs)
        action = dist.sample()

        next_state, reward, terminated, truncated, _ = env.step(action.item())
        # An episode is over on termination OR truncation (e.g. a time limit).
        # Ignoring truncation would keep stepping an environment that should
        # have been reset and would leak reward across the episode boundary
        # when returns are computed from ``dones``.
        done = bool(terminated or truncated)

        states.append(state)
        actions.append(action.item())
        rewards.append(reward)
        dones.append(done)
        log_probs.append(dist.log_prob(action).item())

        if done:
            state, _ = env.reset()
        else:
            state = next_state

    return states, actions, rewards, dones, log_probs

def compute_returns(rewards, dones, gamma=0.99):
    """Compute discounted reward-to-go for every step of a rollout.

    The rollout is scanned backwards; the running return is cut off at every
    step whose ``done`` flag is set, so reward never flows across an episode
    boundary. The value after the final step is taken to be 0.

    Args:
        rewards: Sequence of per-step rewards.
        dones: Sequence of per-step episode-end flags, same length as
            ``rewards``.
        gamma: Discount factor.

    Returns:
        List of discounted returns aligned with ``rewards`` (empty for empty
        input).
    """
    returns = []
    R = 0
    for reward, done in zip(reversed(rewards), reversed(dones)):
        # Zero the bootstrapped tail when this step ended an episode.
        not_done = 0.0 if done else 1.0
        R = reward + gamma * R * not_done
        returns.append(R)
    # Built back-to-front; one reverse is O(n), whereas inserting at index 0
    # on every step is O(n^2).
    returns.reverse()
    return returns

def ppo_update(states, actions, returns, old_log_probs, entropy_coef=0.01):
    """Run ``train_epochs`` full-batch PPO updates on the global actor and critic.

    Uses the module-level ``actor``, ``critic``, ``optimizer_actor``,
    ``optimizer_critic``, ``train_epochs`` and ``clip_epsilon``.

    Args:
        states: Sequence of observations, one per collected step.
        actions: Sequence of integer actions taken at those steps.
        returns: Sequence of discounted returns used as value targets.
        old_log_probs: Log-probabilities of ``actions`` under the policy that
            collected the data.
        entropy_coef: Weight of the entropy bonus subtracted from the actor
            loss.

    Returns:
        None. The networks are updated in place; an empty batch is a no-op.
    """
    if len(states) == 0:
        return

    # Stack into one ndarray first: building a tensor directly from a list of
    # per-step arrays is slow and triggers a PyTorch warning.
    states = torch.tensor(np.asarray(states, dtype=np.float32))
    actions = torch.tensor(actions, dtype=torch.int64)
    returns = torch.tensor(returns, dtype=torch.float32)
    old_log_probs = torch.tensor(old_log_probs, dtype=torch.float32)

    value_loss_fn = nn.MSELoss()

    for _ in range(train_epochs):
        probs = actor(states)
        dist = Categorical(probs)
        log_probs = dist.log_prob(actions)
        entropy = dist.entropy().mean()

        # squeeze(-1) drops only the trailing value dimension, so a batch of
        # one stays 1-D and keeps the same shape as ``returns``.
        values = critic(states).squeeze(-1)
        # NOTE(review): the advantage is recomputed every epoch from the
        # critic as it is being updated; PPO is usually run with advantages
        # fixed before the update epochs - confirm this is intended.
        advantage = returns - values.detach()

        # Clipped surrogate objective: take the more pessimistic of the raw
        # and the clipped probability-ratio terms.
        ratio = torch.exp(log_probs - old_log_probs)
        surrogate1 = ratio * advantage
        surrogate2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantage
        actor_loss = -torch.min(surrogate1, surrogate2).mean() - entropy_coef * entropy
        critic_loss = value_loss_fn(values, returns)

        optimizer_actor.zero_grad()
        actor_loss.backward()
        optimizer_actor.step()

        optimizer_critic.zero_grad()
        critic_loss.backward()
        optimizer_critic.step()

# Training: three iterations of (collect batch_size steps -> compute returns ->
# PPO update). compute_returns is called with its default gamma.
for iteration in range(3):
    states, actions, rewards, dones, log_probs = collect_trajectories(env, actor, batch_size)
    returns = compute_returns(rewards, dones)

    ppo_update(states, actions, returns, log_probs)
    print(f"Iteration {iteration + 1} completed.")

# Halve the actor learning rate and build a brand-new Adam optimizer with it.
# NOTE(review): no further ppo_update call appears after this point in the
# visible code, so the new optimizer is never used here; constructing a new
# optimizer also starts from fresh optimizer state - confirm this is intended.
lr_actor /= 2
optimizer_actor = optim.Adam(actor.parameters(), lr=lr_actor)

def evaluate_policy(env, policy, episodes=3):
    """Run ``policy`` greedily for ``episodes`` episodes and average the reward.

    The highest-probability action is taken at every step (no sampling).

    Args:
        env: Environment whose ``reset()`` returns ``(obs, info)`` and whose
            ``step(action)`` returns
            ``(obs, reward, terminated, truncated, info)``.
        policy: Callable mapping a (1, obs_dim) float tensor to a
            (1, n_actions) tensor of action probabilities.
        episodes: Number of evaluation episodes.

    Returns:
        Mean of the per-episode total rewards, as computed by ``np.mean``.
    """
    episode_rewards = []
    for _ in range(episodes):
        state, _ = env.reset()
        total_reward = 0
        done = False
        while not done:
            state_tensor = preprocess_state(state)
            with torch.no_grad():
                probs = policy(state_tensor).squeeze(0)
            action = torch.argmax(probs).item()
            state, reward, terminated, truncated, _ = env.step(action)
            # Stop on truncation as well as termination; otherwise the loop
            # keeps stepping past the environment's time limit and the
            # episode total is no longer the reward of a single episode.
            done = bool(terminated or truncated)
            total_reward += reward
        episode_rewards.append(total_reward)
    return np.mean(episode_rewards)

# Greedy evaluation of the trained actor over the default number of episodes.
mean_reward = evaluate_policy(env, actor)
print(f"Mean evaluation reward: {mean_reward}")

# Persist only the actor's parameters (state dict); the critic is not saved.
torch.save(actor.state_dict(), "ppo_actor.pth")
print("Policy deployed and saved.")

Iteration 1 completed.
Iteration 2 completed.
Iteration 3 completed.
Mean evaluation reward: 8.666666666666666
Policy deployed and saved.
