In [13]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Normal
import tqdm
from collections import deque

In [14]:
def moving_average(data, *, window_size = 50):
    """Smooths 1-D data array using a trailing moving average.

    Element ``i`` of the result is the mean of the samples
    ``data[max(0, i - window_size + 1) : i + 1]``, so the first
    ``window_size - 1`` outputs average over fewer samples rather than
    being padded with zeros.

    Args:
        data: 1-D numpy.array
        window_size: Size of the smoothing window (must be >= 1)

    Returns:
        smooth_data: A 1-d numpy.array with the same size as data

    Raises:
        ValueError: If window_size is smaller than 1.
    """
    assert data.ndim == 1
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    if data.size == 0:
        # np.convolve rejects empty inputs; an empty series smooths to itself.
        return np.array([], dtype=float)

    kernel = np.ones(window_size)
    # Numerator: windowed sums. Denominator: how many real samples fell into
    # each window, which normalises the partially-filled windows at the start.
    smooth_data = np.convolve(data, kernel) / np.convolve(
        np.ones_like(data), kernel
    )
    # The full convolution has data.size + window_size - 1 entries; keep the
    # leading data.size. Slicing with ``-window_size + 1`` would evaluate to
    # ``[:0]`` (an empty array) for window_size == 1.
    return smooth_data[: data.size]

In [15]:
def plot_curves(arr_list, legend_list, color_list, ylabel, fig_title, smoothing = True):
    """Draw one mean curve per result array, each with a shaded error band.

    Every array is indexed as ``arr[run, step]``: the mean and the standard
    error are taken over axis 0 and plotted against ``range(arr.shape[1])``.
    The band half-width is 1.96 standard errors of the raw (unsmoothed) data.

    NOTE(review): this function uses a global ``plt`` (presumably
    ``matplotlib.pyplot``); no import for it is visible in the notebook's
    import cell -- confirm it is in scope before calling.

    Args:
        arr_list (list): List of results arrays to plot
        legend_list (list): Legend label for each result array
        color_list (list): Line/band colour for each result array
        ylabel (string): Label of the vertical axis
        fig_title (string): Title placed above the figure
        smoothing (bool): If True, the mean curve is passed through
            ``moving_average`` before plotting

        arr_list, legend_list and color_list are consumed in lockstep, so
        their elements must be given in the same order.
    """
    fig, ax = plt.subplots(figsize=(12, 8))

    # Axis labels; the vertical one is caller-supplied.
    ax.set_ylabel(ylabel)
    ax.set_xlabel("Time Steps")

    handles = []
    for results, label, colour in zip(arr_list, legend_list, color_list):
        num_runs, num_steps = results.shape[0], results.shape[1]
        steps = range(num_steps)

        # Half-width of the band: 1.96 x standard error of the raw data.
        band = (results.std(axis=0) / np.sqrt(num_runs)) * 1.96

        # Centre line: the across-run mean, optionally smoothed.
        centre = results.mean(axis=0)
        if smoothing:
            centre = moving_average(centre)

        (line,) = ax.plot(steps, centre, color=colour, label=label)
        ax.fill_between(steps, centre - band, centre + band, alpha=0.3,
                        color=colour)

        # Keep the line handle so the legend lists curves only, not bands.
        handles.append(line)

    ax.set_title(f"{fig_title}")
    ax.legend(handles=handles)

    plt.show()

### Car racing A2C NN

In [25]:
class ActorCriticNet(nn.Module):
    """Shared-trunk actor-critic network for 3-channel image input.

    A three-layer convolutional stack followed by one fully connected layer
    produces a 512-dimensional feature vector. Two linear heads read it:
    the actor head gives a 3-dimensional action mean, the critic head gives
    a scalar state value.
    """

    def __init__(self):
        super().__init__()

        # Shared feature extractor.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        # 4096 = 64 channels * 8 * 8, which is what the conv stack above
        # yields for 96x96 inputs; other input sizes need a different value.
        self.fc1 = nn.Linear(4096, 512)
        self.common_activation = nn.ReLU()

        # Actor head: action mean plus a state-independent spread parameter.
        self.actor_fc = nn.Linear(512, 3)
        # forward() exponentiates this parameter, so it stores the log of the
        # standard deviation; the initial value 1 gives a std of e.
        self.actor_std = nn.Parameter(torch.ones(1, 3))

        # Critic head: scalar state value.
        self.critic_fc = nn.Linear(512, 1)

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Tensor of shape (batch, 3, height, width).

        Returns:
            Tuple ``(action_mean, action_std, state_value)`` where
            action_mean is (batch, 3) and squashed by tanh, action_std is
            (1, 3) and strictly positive, and state_value is (batch, 1).
        """
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = torch.relu(conv(features))

        # Collapse (channels, height, width) into one vector per sample.
        features = features.reshape(features.shape[0], -1)
        features = self.common_activation(self.fc1(features))

        mean = torch.tanh(self.actor_fc(features))
        std = self.actor_std.exp()
        value = self.critic_fc(features)

        return mean, std, value

### Car racing A2C agent

In [17]:
class ActorCriticAgent:
    """Agent that samples actions from the Gaussian policy of an ActorCriticNet."""

    def __init__(self):
        self.policy_net = ActorCriticNet()

    def get_action(self, state):
        """Sample an action for a single observation.

        Args:
            state: 3-D array-like observation whose last axis is moved to
                the channel position before it is fed to the network.
                Values are divided by 255 (presumably 8-bit pixels --
                confirm against the environment).

        Returns:
            Tuple ``(action, log_prob, state_value)``: the sampled action as
            a numpy array without the batch axis, the log-probability of
            that action summed over action dimensions (still attached to the
            autograd graph), and the network's value estimate as returned.
        """
        frame = torch.tensor(state, dtype=torch.float32)
        # Add a batch axis, then reorder to (batch, channel, height, width).
        batch = frame.unsqueeze(0).permute(0, 3, 1, 2) / 255.0

        mean, std, value = self.policy_net(batch)
        policy = Normal(mean, std)
        sampled = policy.sample()
        # Sum per-dimension log-probabilities into one value per sample.
        joint_log_prob = policy.log_prob(sampled).sum(dim=-1)

        action_np = sampled.squeeze(0).detach().numpy()
        return action_np, joint_log_prob, value

### A2C training loop

In [18]:
class ActorCriticAgentTrainer:
    """Episode-by-episode actor-critic trainer.

    Each training iteration plays one full episode with the agent, then does
    a single optimizer step on the summed policy and value losses computed
    from the discounted returns of that episode.
    """

    def __init__(self, agent, env, params):
        """
        Args:
            agent: Object exposing ``policy_net`` (a torch module) and
                ``get_action(state) -> (action, log_prob, state_value)``.
            env: Environment with ``reset()`` returning ``(state, info)``
                and ``step(action)`` returning a 5-tuple
                ``(next_state, reward, done, trunc, info)``.
            params: Dict with keys 'gamma', 'learning_rate' and
                'num_episodes'.
        """
        self.agent = agent
        self.env = env
        self.params = params
        self.gamma = params['gamma']
        self.optimizer = torch.optim.Adam(self.agent.policy_net.parameters(), lr=params['learning_rate'])

        # Per-episode buffers; created here so that calling
        # update_agent_policy_network() before rollout() fails with a clear
        # error instead of an AttributeError.
        self.saved_rewards = []
        self.saved_log_probs = []
        self.saved_state_values = []

    def rollout(self):
        """Play one episode and record rewards, log-probs and value estimates."""
        self.saved_rewards = []
        self.saved_log_probs = []
        self.saved_state_values = []

        state, _ = self.env.reset()
        is_done = False

        while not is_done:
            action, log_prob, state_value = self.agent.get_action(state)
            next_state, reward, done, trunc, _ = self.env.step(action)
            # The episode ends on either termination or truncation.
            is_done = done or trunc

            self.saved_rewards.append(reward)
            self.saved_log_probs.append(log_prob)
            self.saved_state_values.append(state_value)

            state = next_state

    def update_agent_policy_network(self):
        """Do one gradient step from the episode stored by rollout().

        Returns:
            Tuple ``(G0, loss)``: the discounted return from the first step
            of the episode and the scalar total loss, both Python floats.

        Raises:
            RuntimeError: If no transitions are stored.
        """
        if not self.saved_rewards:
            raise RuntimeError(
                "No transitions stored; call rollout() before update_agent_policy_network()."
            )

        # Discounted returns, accumulated backwards: G_t = r_t + gamma * G_{t+1}.
        discounted = []
        g_return = 0
        for reward in reversed(self.saved_rewards):
            g_return = reward + self.gamma * g_return
            discounted.append(g_return)
        discounted.reverse()

        returns = torch.tensor(discounted, dtype=torch.float32)
        policy_loss = []
        value_loss = []

        for log_prob, value, ret in zip(self.saved_log_probs, self.saved_state_values, returns):
            advantage = ret - value
            # The advantage is a fixed weight for the policy gradient. Without
            # detach() the actor term would also back-propagate into the
            # critic (d/dV of -log_prob * (G - V) is +log_prob) and fight the
            # value regression below.
            policy_loss.append(-log_prob * advantage.detach())
            # The critic is trained by squared error against the return.
            value_loss.append(advantage.pow(2))

        total_loss = torch.stack(policy_loss).sum() + torch.stack(value_loss).sum()

        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()

        # Drop references to the episode's autograd graph.
        del self.saved_rewards[:]
        del self.saved_log_probs[:]
        del self.saved_state_values[:]

        return returns[0].item(), total_loss.item()

    def train(self):
        """Run ``params['num_episodes']`` rollout/update iterations.

        Returns:
            Tuple ``(train_returns, train_losses)``: per-episode lists of the
            values returned by update_agent_policy_network().
        """
        train_returns = []
        train_losses = []
        ep_bar = tqdm.trange(self.params['num_episodes'])

        for ep in ep_bar:
            self.rollout()
            G, loss = self.update_agent_policy_network()

            train_returns.append(G)
            train_losses.append(loss)
            ep_bar.set_description(f"Episode: {ep} | Return: {G:.2f} | Loss: {loss:.2f}")

        return train_returns, train_losses

In [26]:
if __name__ == "__main__":
    # Train an actor-critic agent on continuous-control CarRacing and save
    # the learned weights.
    env = gym.make('CarRacing-v3', continuous=True, render_mode=None)
    params = {
        'num_episodes': 1000,
        'learning_rate': 1e-4,
        'gamma': 0.99
    }

    agent = ActorCriticAgent()
    trainer = ActorCriticAgentTrainer(agent, env, params)
    try:
        returns, losses = trainer.train()
    finally:
        # Training runs for hours; release the simulator even when the run
        # is interrupted or fails part-way through.
        env.close()

    # Save model
    torch.save(agent.policy_net.state_dict(), "car_racing_actor_critic.pth")

Episode: 160 | Return: 3.63 | Loss: -8054.37:  16%|█▌        | 161/1000 [40:09<4:07:16, 17.68s/it]  