In [1]:
import time
import itertools
import tqdm
from collections import deque
from typing import NamedTuple
import numpy as np
import random as py_random

import jax
import jax.numpy as jnp
from jax import jit, grad, random

import optax

import flax.linen as nn
from flax.training.train_state import TrainState

import gymnasium as gym
import ale_py
from gymnasium.wrappers.atari_preprocessing import AtariPreprocessing

In [2]:
class DQN(nn.Module):
    action_dim: int

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(8, 8), strides=(4, 4))(x)
        x = nn.relu(x)
        x = nn.Conv(features=64, kernel_size=(4, 4), strides=(2, 2))(x)
        x = nn.relu(x)
        x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1))(x)
        x = nn.relu(x)
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=512)(x)
        x = nn.relu(x)
        x = nn.Dense(self.action_dim)(x)
        return x

In [3]:
class Transition(NamedTuple):
    observation: np.ndarray
    observation_next: np.ndarray
    action: int
    reward: float
    done: bool

In [4]:
max_episode_steps=1000

def create_env():
    env: gym.Env = gym.make(
        id="ALE/Pong-v5",
        max_episode_steps=max_episode_steps,
        autoreset=False
    )
    return AtariPreprocessing(
        env,
        noop_max=30,
        frame_skip=1,  # Pong-v5 has frame_skip=4 by default
        screen_size=84,
        terminal_on_life_loss=False,
        grayscale_obs=True,
        grayscale_newaxis=True,
        scale_obs=False,
    )

env = create_env()

A.L.E: Arcade Learning Environment (version 0.9.0+750d7f9)
[Powered by Stella]


In [5]:
model = DQN(env.action_space.n)
rng = random.key(0)
variables = model.init(rng, jnp.zeros((1, *env.observation_space.shape)))

tx = optax.adam(1e-4)
state = TrainState.create(
    apply_fn=model.apply,
    params=variables["params"],
    tx=tx,
)

In [6]:
def loss_fn(params, batch:Transition):
    x = jnp.array(batch.observation).astype(jnp.float32) / 255.0
    x_next = jnp.array(batch.observation_next).astype(jnp.float32) / 255.0
    rewards = jnp.array(batch.reward)
    dones = jnp.array(batch.done)
    actions = jnp.array(batch.action)
    
    state_action_values = state.apply_fn({"params": state.params}, x)
    next_state_values = state.apply_fn({"params": state.params}, x_next)
    next_state_values = jnp.max(next_state_values, axis=1)
    next_state_values = jnp.where(dones, jnp.zeros_like(next_state_values), next_state_values)
    
    expected = rewards + 0.99 * next_state_values
    actual = state_action_values[jnp.arange(x.shape[0]), actions]
    loss = jnp.mean((expected - actual) ** 2)
    return loss

@jit
def update(state:TrainState, batch):
    loss, grads = jax.value_and_grad(loss_fn, allow_int=True)(state.params, batch)
    state = state.apply_gradients(grads=grads)
    return state, loss

In [9]:
epsilon = 1
epsilon_decay = 0.95

def get_action(observation):
    if np.random.rand() < epsilon:
        return env.action_space.sample()

    observation = jnp.array(observation).astype(jnp.float32) / 255.0
    q_values = state.apply_fn({"params": state.params}, observation[None, ...])
    return jnp.argmax(q_values).item()

In [10]:
num_episodes = 10
batch_size = 32

memory = []
memory_size = 10_000

for episode in range(num_episodes):
    observation, _ = env.reset()
    done = False
    rewards = 0
    losses = []
    for i in tqdm.trange(max_episode_steps):
        action = get_action(observation)
        observation_next, reward, done, truncated, info = env.step(action)
        done = done or truncated
        transition = Transition(observation, observation_next, action, reward, done)
        
        memory.append(transition)
        if len(memory) > memory_size:
            memory.pop(0)
        
        rewards += reward
        observation = observation_next

        if len(memory) >= batch_size and i % 16 == 0:
            batch = py_random.sample(memory, batch_size)
            batch = Transition(*zip(*batch))
            state, loss = update(state, batch)
            losses.append(loss)

        if done:
            break
    
    epsilon *= epsilon_decay
    print(f"Episode {episode} rewards: {rewards}, losses: {np.mean(losses)}, epsilon: {epsilon}")

 95%|█████████▍| 947/1000 [00:01<00:00, 729.59it/s]


Episode 0 rewards: -20.0, losses: 0.03066929802298546, epsilon: 0.95


 95%|█████████▌| 950/1000 [00:16<00:00, 56.69it/s] 


Episode 1 rewards: -20.0, losses: 0.030415983870625496, epsilon: 0.9025


 90%|█████████ | 905/1000 [00:18<00:01, 47.85it/s] 


Episode 2 rewards: -20.0, losses: 0.028568534180521965, epsilon: 0.8573749999999999


 93%|█████████▎| 934/1000 [00:38<00:02, 24.32it/s]


Episode 3 rewards: -21.0, losses: 0.028507616370916367, epsilon: 0.8145062499999999


 83%|████████▎ | 828/1000 [00:34<00:07, 24.25it/s]


Episode 4 rewards: -21.0, losses: 0.02430008165538311, epsilon: 0.7737809374999999


 95%|█████████▌| 952/1000 [00:45<00:02, 21.04it/s]


Episode 5 rewards: -20.0, losses: 0.02806958183646202, epsilon: 0.7350918906249998


 81%|████████  | 812/1000 [00:50<00:11, 16.23it/s]


Episode 6 rewards: -21.0, losses: 0.03618348762392998, epsilon: 0.6983372960937497


 98%|█████████▊| 976/1000 [00:59<00:01, 16.37it/s]


Episode 7 rewards: -19.0, losses: 0.02941839210689068, epsilon: 0.6634204312890623


 92%|█████████▏| 918/1000 [02:02<00:10,  7.49it/s]


Episode 8 rewards: -21.0, losses: 0.03079039230942726, epsilon: 0.6302494097246091


 87%|████████▋ | 871/1000 [02:05<00:18,  6.91it/s]


Episode 9 rewards: -21.0, losses: 0.028599319979548454, epsilon: 0.5987369392383786
