In [1]:
!apt-get update -qq
!apt-get install -y swig > /dev/null
!pip install pygame opencv-python imageio gymnasium[box2d] torch > /dev/null

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from collections import deque
import cv2, os, base64, datetime, imageio
from IPython.display import HTML, display

class FrameStackWrapper(gym.ObservationWrapper):
    """Convert RGB frames to grayscale, resize them, and stack the last ``k``.

    ``reset`` and ``step`` return ``uint8`` arrays of shape ``(k, size, size)``
    with the oldest frame first (frames are appended to a ``maxlen=k`` deque
    and stacked along axis 0).

    Args:
        env: Environment whose observations are RGB images (presumably
            ``(H, W, 3)`` uint8, as ``cv2.COLOR_RGB2GRAY`` expects — confirm
            for environments other than CarRacing).
        k: Number of consecutive frames to stack.
        size: Side length, in pixels, of each resized square frame.
    """

    def __init__(self, env, k, size=96):
        super().__init__(env)
        self.k = k
        self.size = size
        self.frames = deque([], maxlen=k)
        # Declare the space the stacked output actually has. No probe reset is
        # needed: the shape is fully determined by ``k`` and ``size``.
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(k, size, size), dtype=np.uint8
        )

    def observation(self, observation):
        """Return ``observation`` as one grayscale ``size`` x ``size`` frame."""
        gray = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)
        return cv2.resize(gray, (self.size, self.size), interpolation=cv2.INTER_AREA)

    def _stacked(self):
        """Stack the buffered frames into a single ``(k, size, size)`` array."""
        return np.stack(self.frames, axis=0)

    def reset(self, **kwargs):
        """Reset the env and fill the stack with ``k`` copies of the first frame.

        Returns:
            ``(stacked_obs, info)`` where ``info`` comes from the wrapped env.
        """
        obs, info = self.env.reset(**kwargs)
        processed = self.observation(obs)
        # Fill the whole buffer so the first stacked observation has k frames.
        for _ in range(self.k):
            self.frames.append(processed)
        return self._stacked(), info

    def step(self, action):
        """Step the env and push the newest processed frame onto the stack.

        Returns:
            ``(stacked_obs, reward, terminated, truncated, info)``.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.frames.append(self.observation(obs))
        return self._stacked(), reward, terminated, truncated, info

class CNNPolicy(nn.Module):
    """Actor-critic network with a shared convolutional trunk.

    Three conv layers feed two separate MLP heads: the actor outputs one
    logit per discrete action, the critic outputs a single value.

    Args:
        obs_shape: ``(channels, height, width)`` of one observation.
        n_actions: Number of discrete actions.
    """

    def __init__(self, obs_shape, n_actions):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(obs_shape[0], 32, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),
        )
        conv_out_size = self._get_conv_out(obs_shape)
        self.actor = nn.Sequential(nn.Linear(conv_out_size, 256), nn.ReLU(), nn.Linear(256, n_actions))
        self.critic = nn.Sequential(nn.Linear(conv_out_size, 256), nn.ReLU(), nn.Linear(256, 1))

    def _get_conv_out(self, shape):
        """Return the flattened conv-trunk output size for one input of ``shape``."""
        with torch.no_grad():  # shape probe only; no autograd graph needed
            out = self.conv(torch.zeros(1, *shape))
        return int(out.numel())

    def forward(self, x):
        """Compute action logits and state values for a batch.

        Args:
            x: Float tensor ``(batch, C, H, W)``; values are divided by 255,
                so presumably raw pixel intensities in ``[0, 255]``.

        Returns:
            ``(logits, value)`` with shapes ``(batch, n_actions)`` and ``(batch, 1)``.
        """
        x = x / 255.0
        features = torch.flatten(self.conv(x), start_dim=1)
        return self.actor(features), self.critic(features)

    def act(self, x):
        """Sample an action from the policy for each observation in ``x``.

        Returns:
            ``(action, log_prob, entropy, value)``.
        """
        logits, value = self.forward(x)
        # Build the distribution from logits directly: numerically more stable
        # than softmax -> probs, and log_prob avoids log(softmax(...)).
        dist = torch.distributions.Categorical(logits=logits)
        action = dist.sample()
        return action, dist.log_prob(action), dist.entropy(), value

def make_env(env_id="CarRacing-v3", k=4):
    """Create ``env_id`` with RGB-array rendering, wrapped in a ``k``-frame stack.

    Args:
        env_id: Gymnasium environment id passed to ``gym.make``.
        k: Number of frames ``FrameStackWrapper`` stacks per observation.

    Returns:
        The environment wrapped in ``FrameStackWrapper``.
    """
    env = gym.make(env_id, render_mode="rgb_array")
    return FrameStackWrapper(env, k=k)

def show_video(path, width=480):
    """Embed the MP4 at ``path`` inline in the notebook as a base64 data URL.

    Args:
        path: Path to an MP4 file.
        width: Value of the ``<video>`` element's ``width`` attribute.
    """
    # Context manager so the file handle is closed promptly.
    with open(path, "rb") as f:
        mp4 = f.read()
    data_url = "data:video/mp4;base64," + base64.b64encode(mp4).decode()
    display(HTML(f'<video width={width} controls><source src="{data_url}" type="video/mp4"></video>'))

# Discrete policy output -> CarRacing continuous action [steering, gas, brake].
# Built once here rather than on every step of the rollout loop.
ACTION_MAP = {
    0: np.array([-1.0, 0.0, 0.0]),  # steer full left
    1: np.array([1.0, 0.0, 0.0]),   # steer full right
    2: np.array([0.0, 1.0, 0.0]),   # full gas
    3: np.array([0.0, 0.0, 0.8]),   # brake at 0.8
    4: np.array([0.0, 0.0, 0.0]),   # no-op
}

# Fall back to CPU so the rollout still runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

env = make_env()
obs_shape = env.observation_space.shape
n_actions = len(ACTION_MAP)  # keep the actor head in sync with the action table
# NOTE: no checkpoint is loaded, so the policy has freshly initialized weights
# and the recorded video shows untrained behaviour.
policy = CNNPolicy(obs_shape, n_actions).to(device)
policy.eval()

video_filename = "car_racing_run.mp4"
frames = []
max_steps = 900

try:
    obs, _ = env.reset(seed=42)
    obs = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)

    for step in range(max_steps):
        with torch.no_grad():
            action, _, _, _ = policy.act(obs)
        action_mapped = ACTION_MAP[action.item()]

        # Capture the full-colour rendered frame (before stepping) for the video.
        obs_raw = env.render()
        frames.append(obs_raw)

        next_obs, reward, terminated, truncated, _ = env.step(action_mapped)
        obs = torch.as_tensor(next_obs, dtype=torch.float32, device=device).unsqueeze(0)

        if terminated or truncated:
            break
finally:
    # Close the environment even if the rollout raises.
    env.close()

imageio.mimsave(video_filename, frames, fps=30)
show_video(video_filename)


W: Skipping acquire of configured file 'main/source/Sources' as repository 'https://r2u.stat.illinois.edu/ubuntu jammy InRelease' does not seem to provide it (sources.list entry misspelt?)


