In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal
import numpy as np
import gymnasium as gym
from collections import deque
import tqdm

### PPO NN

In [8]:
class PPO(nn.Module):
    def __init__(self, action_dim):
        super(PPO, self).__init__()

        # Convolutional layers for feature extraction
        self.conv1 = nn.Conv2d(3, 32, kernel_size=5, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, stride=2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=5, stride=2)

        # Fully connected layers
        self.fc1 = nn.Linear(128 * 9 * 9, 512)

        # Actor head
        self.actor_mean = nn.Linear(512, action_dim)
        self.actor_log_std = nn.Parameter(torch.zeros(action_dim))

        # Critic head
        self.critic = nn.Linear(512, 1)

        self.activation = nn.ReLU()

    def forward(self, x):
        # Feature extraction
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        x = x.view(x.size(0), -1)  # Flatten
        x = self.activation(self.fc1(x))

        # Actor: mean and log_std
        mean = self.actor_mean(x)
        log_std = self.actor_log_std.expand_as(mean)
        std = torch.exp(log_std)

        # Critic: state value
        value = self.critic(x)

        return mean, std, value

### PPO Agent

In [12]:
# PPO Agent
class PPOAgent:
    def __init__(self, env, params):
        self.env = env
        self.gamma = params["gamma"]
        self.eps_clip = params["eps_clip"]
        self.update_epochs = params["update_epochs"]

        # Initialize actor-critic network
        self.actor_critic = PPO(env.action_space.shape[0])
        self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=params["lr"])

        # Storage for trajectories
        self.states = []
        self.actions = []
        self.log_probs = []
        self.rewards = []
        self.values = []
        self.dones = []

    def select_action(self, state):
        state_tensor = torch.tensor(state).permute(2, 0, 1).unsqueeze(0).float()
        with torch.no_grad():
            mean, std, value = self.actor_critic(state_tensor)
            dist = Normal(mean, std)
            action = dist.sample()
            log_prob = dist.log_prob(action).sum(dim=-1)

        # Debugging: Print selected action and value
        # print(f"Selected Action: {action.squeeze().numpy()}, Log Prob: {log_prob.item()}, Value: {value.item()}")

        return action.squeeze().numpy(), log_prob.item(), value.item()

    def compute_returns_and_advantages(self, last_value, normalize=True):
        """Compute discounted returns and advantages using GAE."""
        returns = []
        advantages = []
        g = last_value
        adv = 0
        for step in reversed(range(len(self.rewards))):
            delta = (
                self.rewards[step]
                + self.gamma * (1 - self.dones[step]) * g
                - self.values[step]
            )
            adv = delta + self.gamma * adv
            g = self.rewards[step] + self.gamma * (1 - self.dones[step]) * g
            returns.insert(0, g)
            advantages.insert(0, adv)

        returns = torch.tensor(returns)
        advantages = torch.tensor(advantages)

        if normalize:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Debugging: Print returns and advantages
        print(f"Returns: {returns}, Advantages: {advantages}")

        return returns, advantages

    def optimize(self, states, actions, log_probs, returns, advantages):
        """Optimize actor-critic network."""
        for _ in range(self.update_epochs):
            mean, std, values = self.actor_critic(states)
            dist = Normal(mean, std)

            new_log_probs = dist.log_prob(actions).sum(dim=-1)
            entropy = dist.entropy().mean()
            values = values.squeeze()

            # PPO objective
            ratio = torch.exp(new_log_probs - log_probs)
            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
            policy_loss = -torch.min(surr1, surr2).mean()

            # Critic loss
            value_loss = ((returns - values) ** 2).mean()

            # Total loss
            loss = policy_loss + 0.5 * value_loss - 0.01 * entropy

            # Debugging: Print losses
            print(f"Policy Loss: {policy_loss.item()}, Value Loss: {value_loss.item()}, Total Loss: {loss.item()}")

            # Backpropagation
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def train(self, num_episodes):
        total_rewards = []
        for episode in tqdm.trange(num_episodes, desc="Training Progress"):
            state, _ = self.env.reset()
            episode_reward = 0
            done = False

            # Rollout
            while not done:
                action, log_prob, value = self.select_action(state)
                next_state, reward, done, trunc, _ = self.env.step(action)
                self.states.append(torch.tensor(state).permute(2, 0, 1).unsqueeze(0).float())
                self.actions.append(torch.tensor(action))
                self.log_probs.append(log_prob)
                self.rewards.append(reward)
                self.values.append(value)
                self.dones.append(done)

                state = next_state
                episode_reward += reward

                if done or trunc:
                    # Calculate returns and advantages
                    with torch.no_grad():
                        last_value = self.actor_critic(
                            torch.tensor(next_state)
                            .permute(2, 0, 1)
                            .unsqueeze(0)
                            .float()
                        )[2].item()
                    returns, advantages = self.compute_returns_and_advantages(
                        last_value
                    )

                    # Optimize policy and value networks
                    self.optimize(
                        torch.cat(self.states),
                        torch.stack(self.actions),
                        torch.tensor(self.log_probs),
                        returns,
                        advantages,
                    )

                    # Clear storage
                    self.states = []
                    self.actions = []
                    self.log_probs = []
                    self.rewards = []
                    self.values = []
                    self.dones = []

                    break

            total_rewards.append(episode_reward)
            avg_reward = np.mean(total_rewards[-100:]) if len(total_rewards) > 100 else np.mean(total_rewards)
            tqdm.tqdm.write(f"Episode {episode} | Reward: {episode_reward:.2f} | Avg Reward: {avg_reward:.2f}")

        return total_rewards

In [13]:
if __name__ == "__main__":
    # Initialize environment
    env = gym.make("CarRacing-v3", continuous=True)

    # Training parameters
    params = {
        "lr": 3e-4,
        "gamma": 0.99,
        "eps_clip": 0.2,
        "update_epochs": 10,
        "num_episodes": 500,
    }

    # Initialize and train agent
    agent = PPOAgent(env, params)
    rewards = agent.train(num_episodes=params["num_episodes"])

    # Save the trained model
    torch.save(agent.actor_critic.state_dict(), "ppo_car_racing.pth")

    print("Training complete. Evaluate the agent in the environment.")

Training Progress:   0%|          | 0/500 [00:00<?, ?it/s]

Returns: tensor([ 0.2150, -7.1365, -7.1076, -7.0783, -7.0488, -7.0190, -6.9889, -6.9585,
        -6.9278, -6.8967, -6.8654, -6.8337, -6.8017, -6.7694, -6.7368, -6.7038,
        -6.6705, -6.6369, -6.6029, -6.5686, -6.5340, -6.4990, -6.4636, -6.4279,
        -6.3918, -6.3553, -6.3185, -6.2813, -6.2438, -6.2058, -6.1675, -6.1288,
        -6.0897, -6.0502, -6.0103, -5.9700, -5.9293, -5.8882, -5.8467, -5.8047,
        -5.7623, -5.7195, -5.6763, -5.6326, -5.5885, -5.5439, -5.4989, -5.4535,
        -5.4075, -5.3611, -5.3143, -5.2670, -5.2191, -5.1709, -5.1221, -5.0728,
        -8.7503, -8.7377, -8.7250, -8.7121, -8.6991, -8.6859, -8.6727, -8.6593,
        -8.6457, -8.6320, -8.6182, -8.6043, -8.5902, -8.5759, -8.5615, -8.5470,
        -8.5323, -8.5175, -8.5025, -8.4874, -8.4721, -8.4567, -8.4411, -8.4254,
        -8.4095, -8.3934, -8.3772, -8.3608, -8.3442, -8.3275, -8.3106, -8.2935,
        -8.2763, -8.2589, -8.2413, -8.2235, -8.2056, -8.1875, -8.1691, -8.1507,
        -8.1320, -8.1131, -8.09

Training Progress:   0%|          | 1/500 [00:32<4:31:15, 32.62s/it]

Episode 0 | Reward: -74.17 | Avg Reward: -74.17
Returns: tensor([-0.0312, -7.2237, -7.1957, -7.1673, -7.1387, -7.1098, -7.0806, -7.0511,
        -7.0213, -6.9913, -6.9609, -6.9302, -6.8992, -6.8678, -6.8362, -6.8042,
        -6.7720, -6.7394, -6.7064, -6.6732, -6.6395, -6.6056, -6.5713, -6.5367,
        -6.5017, -6.4664, -6.4307, -6.3946, -6.3582, -6.3214, -6.2843, -6.2467,
        -6.2088, -6.1705, -6.1318, -6.0928, -6.0533, -6.0134, -5.9732, -5.9325,
        -5.8914, -5.8499, -5.8080, -5.7656, -5.7229, -5.6797, -5.6360, -5.5919,
        -5.5474, -5.5024, -5.4570, -5.4111, -5.3648, -5.3179, -5.2707, -5.2229,
        -5.1746, -5.1259, -8.7232, -8.7103, -8.6973, -8.6841, -8.6709, -8.6574,
        -8.6439, -8.6302, -8.6163, -8.6024, -8.5882, -8.5740, -8.5596, -8.5450,
        -8.5303, -8.5155, -8.5005, -8.4853, -8.4700, -8.4546, -8.4390, -8.4232,
        -8.4073, -8.3912, -8.3749, -8.3585, -8.3420, -8.3252, -8.3083, -8.2912,
        -8.2739, -8.2565, -8.2389, -8.2211, -8.2031, -8.1850, -

Training Progress:   0%|          | 2/500 [01:03<4:24:41, 31.89s/it]

Episode 1 | Reward: -71.12 | Avg Reward: -72.64
Returns: tensor([-0.3416, -6.5179, -6.4828, -6.4472, -6.4113, -6.3751, -6.3385, -6.3015,
        -6.2641, -6.2264, -6.1883, -6.1498, -6.1109, -6.0716, -6.0319, -5.9918,
        -5.9514, -5.9105, -5.8691, -5.8274, -5.7853, -5.7427, -5.6997, -5.6563,
        -5.6124, -5.5681, -5.5233, -5.4781, -8.5694, -8.5549, -8.5403, -8.5256,
        -8.5107, -8.4956, -8.4804, -8.4651, -8.4496, -8.4339, -8.4181, -8.4021,
        -8.3860, -8.3697, -8.3532, -8.3366, -8.3198, -8.3028, -8.2857, -8.2683,
        -8.2509, -8.2332, -8.2153, -8.1973, -8.1791, -8.1607, -8.1421, -8.1234,
        -8.1044, -8.0853, -8.0659, -8.0464, -8.0267, -8.0067, -7.9866, -7.9662,
        -7.9457, -7.9250, -7.9040, -7.8828, -7.8614, -7.8398, -7.8180, -7.7960,
        -7.7737, -7.7512, -7.7285, -7.7056, -7.6824, -7.6590, -7.6353, -7.6114,
        -7.5873, -7.5629, -7.5383, -7.5135, -7.4884, -7.4630, -7.4374, -7.4115,
        -7.3853, -7.3589, -7.3322, -7.3053, -7.2781, -7.2506, -

Training Progress:   1%|          | 3/500 [01:43<4:51:43, 35.22s/it]

Episode 2 | Reward: -75.16 | Avg Reward: -73.48
Returns: tensor([-0.7280, -7.2364, -7.2084, -7.1802, -7.1518, -7.1230, -7.0939, -7.0646,
        -7.0349, -7.0050, -6.9747, -6.9442, -6.9133, -6.8821, -6.8506, -6.8188,
        -6.7867, -6.7542, -6.7214, -6.6883, -6.6549, -6.6211, -6.5870, -6.5525,
        -6.5177, -6.4825, -6.4469, -6.4111, -6.3748, -6.3382, -6.3012, -6.2638,
        -6.2261, -6.1880, -6.1495, -6.1106, -6.0713, -6.0316, -5.9915, -5.9510,
        -5.9101, -5.8688, -5.8271, -5.7849, -5.7424, -9.0003, -8.9902, -8.9800,
        -8.9697, -8.9593, -8.9488, -8.9382, -8.9275, -8.9166, -8.9057, -8.8947,
        -8.8835, -8.8722, -8.8608, -8.8493, -8.8377, -8.8259, -8.8141, -8.8021,
        -8.7900, -8.7778, -8.7654, -8.7530, -8.7404, -8.7277, -8.7148, -8.7018,
        -8.6887, -8.6755, -8.6621, -8.6486, -8.6349, -8.6211, -8.6072, -8.5931,
        -8.5789, -8.5646, -8.5501, -8.5354, -8.5206, -8.5057, -8.4906, -8.4753,
        -8.4599, -8.4444, -8.4287, -8.4128, -8.3968, -8.3806, -

Training Progress:   1%|          | 4/500 [02:14<4:36:58, 33.50s/it]

Episode 3 | Reward: -70.59 | Avg Reward: -72.76
Returns: tensor([-0.4276, -7.1329, -7.1039, -7.0747, -7.0451, -7.0153, -6.9851, -6.9547,
        -6.9239, -6.8928, -6.8615, -6.8298, -6.7977, -6.7654, -6.7327, -6.6997,
        -6.6664, -6.6327, -6.5987, -6.5643, -6.5296, -6.4946, -6.4592, -6.4234,
        -6.3873, -6.3508, -6.3139, -6.2767, -6.2391, -6.2011, -6.1627, -6.1240,
        -6.0848, -6.0453, -6.0053, -5.9650, -5.9242, -5.8830, -5.8415, -5.7994,
        -5.7570, -5.7142, -5.6709, -5.6271, -5.5830, -5.5384, -5.4933, -5.4478,
        -5.4018, -5.3553, -8.7094, -8.6964, -8.6832, -8.6699, -8.6565, -8.6429,
        -8.6292, -8.6154, -8.6014, -8.5873, -8.5730, -8.5586, -8.5440, -8.5293,
        -8.5144, -8.4994, -8.4843, -8.4690, -8.4535, -8.4379, -8.4221, -8.4062,
        -8.3901, -8.3738, -8.3574, -8.3408, -8.3240, -8.3071, -8.2900, -8.2727,
        -8.2553, -8.2377, -8.2199, -8.2019, -8.1837, -8.1654, -8.1468, -8.1281,
        -8.1092, -8.0901, -8.0708, -8.0513, -8.0316, -8.0118, -

Training Progress:   1%|          | 5/500 [02:44<4:27:11, 32.39s/it]

Episode 4 | Reward: -69.70 | Avg Reward: -72.15
Returns: tensor([-0.1343, -6.8367, -6.8048, -6.7725, -6.7399, -6.7070, -6.6737, -6.6401,
        -6.6062, -6.5719, -6.5373, -6.5023, -6.4669, -6.4313, -6.3952, -6.3588,
        -6.3220, -6.2849, -6.2473, -6.2094, -6.1711, -6.1325, -6.0934, -6.0539,
        -6.0141, -5.9738, -5.9332, -5.8921, -5.8506, -5.8087, -5.7663, -5.7236,
        -5.6804, -5.6367, -5.5927, -5.5481, -5.5032, -5.4578, -5.4119, -5.3655,
        -5.3187, -5.2714, -5.2237, -8.5764, -8.5621, -8.5475, -8.5329, -8.5180,
        -8.5031, -8.4879, -8.4727, -8.4572, -8.4417, -8.4259, -8.4100, -8.3940,
        -8.3777, -8.3614, -8.3448, -8.3281, -8.3112, -8.2941, -8.2769, -8.2595,
        -8.2419, -8.2242, -8.2062, -8.1881, -8.1698, -8.1513, -8.1326, -8.1138,
        -8.0947, -8.0755, -8.0560, -8.0364, -8.0166, -7.9965, -7.9763, -7.9559,
        -7.9352, -7.9144, -7.8933, -7.8720, -7.8505, -7.8288, -7.8069, -7.7847,
        -7.7623, -7.7397, -7.7169, -7.6938, -7.6705, -7.6470, -

Training Progress:   1%|          | 6/500 [03:18<4:30:00, 32.79s/it]

Episode 5 | Reward: -69.70 | Avg Reward: -71.74


Training Progress:   1%|          | 6/500 [03:21<4:36:38, 33.60s/it]


KeyboardInterrupt: 