Set up rendering dependencies for Google Colaboratory.

In [None]:
!apt-get install -y xvfb ffmpeg > /dev/null 2>&1
!pip install pyvirtualdisplay pygame moviepy > /dev/null 2>&1

Install d3rlpy and PyTorch with TPU support! The XLA dependency often fails to install on the first attempt. If that happens, simply restart the runtime and run this cell again.

In [None]:
!pip install d3rlpy torch~=2.0.0 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp310-cp310-linux_x86_64.whl
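Since the install may need a second attempt, a quick check can tell you whether a retry is needed. This is a minimal sketch using only the standard library; `missing_packages` is a hypothetical helper, not part of d3rlpy. It only checks that the packages can be located, so a broken `torch_xla` build may still fail at import time.

```python
import importlib.util


def missing_packages(names):
    """Return the names that cannot be located in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Top-level packages that the later cells import.
missing = missing_packages(["d3rlpy", "torch", "torch_xla"])
if missing:
    print("Not found:", ", ".join(missing))
    print("Restart the runtime and rerun the install cell.")
else:
    print("All packages were found.")
```

If `torch_xla` shows up in the list, that is the case described above: restart the runtime and rerun the install cell.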

Set up the CartPole dataset.

In [None]:
import d3rlpy

# get CartPole dataset
dataset, env = d3rlpy.datasets.get_cartpole()

Set up a data-driven deep reinforcement learning algorithm on the TPU. Currently, training this small architecture on a TPU is very slow, but you can at least see that it works.

In [None]:
# get TPU device
import torch_xla.core.xla_model as xm
device = xm.xla_device()

# setup CQL algorithm
cql = d3rlpy.algos.DiscreteCQLConfig().create(device=str(device))

# start training
cql.fit(
    dataset,
    n_steps=10000,
    n_steps_per_epoch=1000,
    evaluators={
        'environment': d3rlpy.metrics.EnvironmentEvaluator(env), # evaluate with CartPole-v1 environment
    },
)
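To see how the two step arguments above relate, here is a small arithmetic sketch. It assumes that `fit` groups training steps into epochs of `n_steps_per_epoch` and runs the evaluators once at the end of each epoch; `epoch_schedule` is a hypothetical helper for illustration, not a d3rlpy function.

```python
def epoch_schedule(n_steps, n_steps_per_epoch):
    """Return the number of epochs and the step count at which each epoch ends."""
    n_epochs = n_steps // n_steps_per_epoch
    epoch_ends = [n_steps_per_epoch * (i + 1) for i in range(n_epochs)]
    return n_epochs, epoch_ends


# Same values as the fit() call above.
n_epochs, epoch_ends = epoch_schedule(10000, 1000)
print(n_epochs)        # 10
print(epoch_ends[:3])  # [1000, 2000, 3000]
```

Under that assumption, the run above evaluates on CartPole 10 times. Since TPU training is slow here, lowering `n_steps` is the easiest way to shorten a trial run.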

Record a video of the trained policy!

In [None]:
import gym
from gym.wrappers import RecordVideo

# start virtual display
d3rlpy.notebook_utils.start_virtual_display()

# wrap the environment with RecordVideo
env = RecordVideo(gym.make("CartPole-v1", render_mode="rgb_array"), './video')

# evaluate
d3rlpy.metrics.evaluate_qlearning_with_environment(cql, env)

Let's see how it works!

In [None]:
d3rlpy.notebook_utils.render_video("video/rl-video-episode-0.mp4")
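The path above is hardcoded to the first recorded episode. The wrapper may save more than one file, depending on its recording schedule, so it can help to list what was actually written. This is a minimal sketch using only the standard library; `list_recorded_videos` is a hypothetical helper, and `./video` is the folder passed to `RecordVideo` earlier.

```python
from pathlib import Path


def list_recorded_videos(video_dir="./video"):
    """Return the .mp4 files in video_dir, sorted lexicographically by path."""
    folder = Path(video_dir)
    if not folder.is_dir():
        # Nothing has been recorded yet (or the folder name differs).
        return []
    return sorted(str(path) for path in folder.glob("*.mp4"))


for video_path in list_recorded_videos():
    print(video_path)
```

Any of the printed paths can be passed to `d3rlpy.notebook_utils.render_video` in place of the hardcoded one.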