In [32]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import gymnasium as gym

In [36]:
# Environment Setup
# Build the CartPole-v1 environment used for training, evaluation and the
# final rollout. render_mode="rgb_array" is requested here; `env.render()` is
# only called from `deploy_policy`.
env = gym.make("CartPole-v1", render_mode="rgb_array")

# Preprocessing
def preprocess_state(state):
    """Convert a raw environment observation into a float32 torch tensor.

    The observation is first materialised as a NumPy array and then copied
    into a new tensor, so the result does not alias the input.
    """
    observation = np.array(state)
    as_tensor = torch.tensor(observation, dtype=torch.float32)
    return as_tensor

# Train-Test Split
# Randomly partition the episode indices 0..num_episodes-1 into an 80% / 20%
# split. The shuffle draws from NumPy's global RNG, which is not seeded in
# this cell, so the partition differs between runs.
num_episodes = 1000
all_indices = np.arange(num_episodes)
np.random.shuffle(all_indices)
train_size = int(0.8 * num_episodes)
train_episodes, test_episodes = all_indices[:train_size], all_indices[train_size:]

# Actor-Critic Model
class ActorCritic(nn.Module):
    """Two-headed policy/value network with a single shared hidden layer.

    Args:
        state_dim: Size of the input observation vector.
        action_dim: Number of discrete actions (size of the policy output).
        hidden_dim: Width of the shared hidden layer. Defaults to 128, the
            value that was previously hard-coded.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        # Feature extractor shared by the actor and critic heads.
        self.shared_layer = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU()
        )
        self.actor = nn.Linear(hidden_dim, action_dim)
        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return ``(action_probs, state_value)`` for input ``x``.

        ``action_probs`` is a softmax over the last dimension of the actor
        head's output; ``state_value`` is the critic head's output and keeps
        its trailing dimension of size 1.
        """
        features = self.shared_layer(x)
        action_probs = torch.softmax(self.actor(features), dim=-1)
        state_value = self.critic(features)
        return action_probs, state_value

# Size the network from the environment's observation and action spaces.
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
model = ActorCritic(state_dim, action_dim)
# One Adam optimizer over all model parameters (shared layer and both heads);
# the same optimizer instance is passed to both training functions below.
optimizer = optim.Adam(model.parameters(), lr=0.01)

def compute_returns(rewards, gamma=0.99):
    """Compute the discounted return for every step of an episode.

    For step ``t`` the return is ``rewards[t] + gamma * return[t + 1]``, with
    the return after the final step taken to be 0.

    Args:
        rewards: Sequence of per-step rewards in chronological order.
        gamma: Discount factor applied to future rewards.

    Returns:
        A 1-D float32 tensor with one return per reward, in the same order as
        ``rewards``. An empty ``rewards`` yields an empty tensor.
    """
    returns = []
    running = 0
    # Accumulate back-to-front with O(1) appends and reverse once at the end;
    # inserting at index 0 on every step is quadratic in episode length.
    for reward in reversed(rewards):
        running = reward + gamma * running
        returns.append(running)
    returns.reverse()
    return torch.tensor(returns, dtype=torch.float32)

# Train Base Model (REINFORCE)
def reinforce_train(env, model, optimizer, episodes=500):
    """Train ``model`` with one Monte-Carlo policy-gradient update per episode.

    Each episode is rolled out by sampling actions from the policy head. The
    loss is the negated sum over steps of ``log_prob * discounted_return``,
    followed by a single optimizer step. The value head's output is ignored.

    Args:
        env: Environment exposing ``reset()`` and ``step(action)``.
        model: Callable returning ``(action_probs, state_value)`` for a state.
        optimizer: Optimizer used to update the model.
        episodes: Number of episodes to run.
    """
    for _episode in range(episodes):
        observation, _info = env.reset()
        state = preprocess_state(observation)
        log_probs = []
        rewards = []

        # Roll out until the environment reports termination or truncation.
        while True:
            action_probs, _value = model(state)
            policy = torch.distributions.Categorical(action_probs)
            action = policy.sample()
            log_probs.append(policy.log_prob(action))

            observation, reward, terminated, truncated, _info = env.step(action.item())
            rewards.append(reward)
            state = preprocess_state(observation)
            if terminated or truncated:
                break

        returns = compute_returns(rewards)
        weighted_log_probs = torch.stack(log_probs) * returns
        loss = -weighted_log_probs.sum()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Train A2C Model
def a2c_train(env, model, optimizer, episodes=500, gamma=0.99):
    """Train ``model`` with one episodic advantage actor-critic update per episode.

    Each episode is rolled out by sampling from the policy head while recording
    the critic's value estimate at every visited state. The advantage is the
    discounted return minus that estimate. The actor loss is the negated sum of
    ``log_prob * advantage`` (advantage detached), the critic loss is the mean
    squared advantage, and both are optimised together in one step.

    If the episode ends by truncation rather than termination, the returns are
    bootstrapped from the critic's estimate of the state reached after the last
    step, so the critic is not taught that the time limit is worth zero.

    Args:
        env: Environment exposing ``reset()`` and ``step(action)``.
        model: Callable returning ``(action_probs, state_value)`` for a state.
        optimizer: Optimizer used to update the model.
        episodes: Number of episodes to run.
        gamma: Discount factor for the returns.
    """
    for _ in range(episodes):
        state, _ = env.reset()
        state = preprocess_state(state)
        log_probs, values, rewards = [], [], []
        terminated = truncated = False

        while not (terminated or truncated):
            probs, value = model(state)
            action_dist = torch.distributions.Categorical(probs)
            action = action_dist.sample()
            log_probs.append(action_dist.log_prob(action))
            values.append(value.squeeze())
            next_state, reward, terminated, truncated, _ = env.step(action.item())
            rewards.append(reward)
            state = preprocess_state(next_state)

        returns = compute_returns(rewards, gamma)
        if truncated and not terminated:
            # The episode was cut short, not finished: add the discounted
            # critic estimate of the state after the final step. The estimate
            # is a fixed target, so no gradient flows through it.
            with torch.no_grad():
                _, tail_value = model(state)
            # Step t is (T - t) transitions away from the bootstrap state.
            steps_to_end = torch.arange(len(rewards), 0, -1, dtype=torch.float32)
            returns = returns + tail_value.squeeze() * gamma ** steps_to_end

        values = torch.stack(values)
        advantage = returns - values

        # The actor must not push gradients into the critic via the advantage.
        actor_loss = -(torch.stack(log_probs) * advantage.detach()).sum()
        critic_loss = advantage.pow(2).mean()
        loss = actor_loss + critic_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Evaluation
def evaluate_policy(env, model, episodes=100):
    """Return the mean undiscounted episode reward of the greedy policy.

    At every step the action with the highest probability from the policy head
    is taken (no sampling), and rewards are summed until the environment
    reports termination or truncation.

    Args:
        env: Environment exposing ``reset()`` and ``step(action)``.
        model: Callable returning ``(action_probs, state_value)`` for a state.
        episodes: Number of evaluation episodes to average over.

    Returns:
        The mean of the per-episode reward totals, as computed by ``np.mean``.
    """
    total_rewards = []
    # Evaluation never backpropagates, so skip autograd graph construction.
    with torch.no_grad():
        for _ in range(episodes):
            state, _ = env.reset()
            state = preprocess_state(state)
            episode_reward = 0
            done = False

            while not done:
                probs, _ = model(state)
                action = torch.argmax(probs).item()
                next_state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                episode_reward += reward
                state = preprocess_state(next_state)

            total_rewards.append(episode_reward)
    return np.mean(total_rewards)

# Deploy Policy
def deploy_policy(env, model, return_frames=False):
    """Run one greedy episode, rendering every step, then close ``env``.

    Args:
        env: Environment exposing ``reset()``, ``step(action)``, ``render()``
            and ``close()``. It is closed before this function returns, even
            if the rollout raises.
        model: Callable returning ``(action_probs, state_value)`` for a state.
        return_frames: When True, collect whatever ``env.render()`` returns at
            each step and return the list. Defaults to False, which returns
            None as before.

    Returns:
        A list with one ``env.render()`` result per step when
        ``return_frames`` is True, otherwise None.
    """
    frames = []
    try:
        state, _ = env.reset()
        state = preprocess_state(state)
        done = False

        # Pure inference: no gradients are needed for the rollout.
        with torch.no_grad():
            while not done:
                # NOTE: per the Gymnasium docs, an env created with
                # render_mode="rgb_array" returns the frame from render()
                # instead of displaying it, so it is only kept on request.
                frame = env.render()
                if return_frames:
                    frames.append(frame)
                probs, _ = model(state)
                action = torch.argmax(probs).item()
                next_state, _, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                state = preprocess_state(next_state)
    finally:
        env.close()
    return frames if return_frames else None

# Run the full pipeline. REINFORCE and then A2C are applied to the same
# `model` and `optimizer` instances, so A2C starts from the REINFORCE weights.
reinforce_train(env, model, optimizer, episodes=500)
a2c_train(env, model, optimizer, episodes=500)
# Greedy (argmax) evaluation averaged over 100 episodes.
# NOTE(review): the recorded output for this cell is 9.45 — presumably the
# policy did not learn to balance the pole; verify training before relying on it.
print("Average Reward After Training:", evaluate_policy(env, model, episodes=100))
# deploy_policy closes `env` when it finishes, so `env` cannot be reused afterwards.
deploy_policy(env, model)

Average Reward After Training: 9.45
