(sec:deep-q-learinng)=
# 深層Q学習

In [1]:
"""
Google Colabの準備
"""

IN_COLAB = True
try:
    import google.colab

    print("You are running the code in Google Colab.")
except ImportError:
    IN_COLAB = False
    print("You are running the code on the local computer.")

if IN_COLAB:
    # Gymnasiumのインストール
    !pip install "gymnasium[classic-control]"
    pass

You are running the code on the local computer.


In [None]:
import random
from collections import deque

import numpy as np
import seaborn as sns
import IPython.display as display
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from matplotlib.animation import ArtistAnimation

try:
    from myst_nb import glue
except ImportError:
    glue = lambda *args, **kwargs: None

# Parameters: number of training episodes (also handed to `glue`)
n_episodes = 100
glue("n_episodes", n_episodes)

# Fix the seeds of Python's `random` module and NumPy's global generator
seed = 31415
random.seed(seed)
np.random.seed(seed)

# Plot appearance, grouped by the part of the figure each option affects
_axes_style = {
    "axes.linewidth": 1,
    "axes.edgecolor": "black",
}
_grid_style = {
    "grid.color": "gray",
    "grid.linestyle": "--",
    "grid.linewidth": 0.5,
}
_tick_style = {
    "xtick.major.size": 2,
    "ytick.major.size": 2,
}
_legend_style = {
    "legend.frameon": True,
    "legend.borderpad": 0.5,
    "legend.facecolor": "white",
    "legend.edgecolor": "black",
    "legend.framealpha": 1.0,
}
rc = {
    "figure.dpi": 150,
    **_axes_style,
    **_grid_style,
    **_tick_style,
    **_legend_style,
}
sns.set_theme(style="whitegrid", palette="colorblind", rc=rc)

100

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class Network(nn.Sequential):
    """Fully connected network used as the Q-function approximator.

    The stack is two hidden stages (128 and 64 units), each made of a
    bias-free linear layer, batch normalization and an in-place ReLU,
    followed by a final linear layer with bias.

    Args:
        n_inputs: Number of input features.
        n_outputs: Number of output features.
    """

    def __init__(self, n_inputs, n_outputs):
        layers = []
        in_features = n_inputs
        # Hidden stages are created in order so that the module indices
        # (0..5) and the parameter initialization order stay the same.
        for width in (128, 64):
            layers.append(nn.Linear(in_features, width, bias=False))
            layers.append(nn.BatchNorm1d(width))
            layers.append(nn.ReLU(inplace=True))
            in_features = width
        # Output layer (index 6)
        layers.append(nn.Linear(in_features, n_outputs))
        super().__init__(*layers)

In [None]:
class ReplayMemoryDataset(torch.utils.data.Dataset):
    """Dataset view over a replay memory.

    The memory object is stored by reference (not copied), so items added to
    it later are visible through this dataset as well.

    Args:
        memory: Indexable, sized container of stored items.
    """

    def __init__(self, memory):
        self.memory = memory

    def __len__(self):
        """Return the number of items currently held in the memory."""
        return len(self.memory)

    def __getitem__(self, idx):
        """Return the item stored at position ``idx``, unchanged."""
        return self.memory[idx]

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Seed PyTorch's global generator (same value as the `random` / NumPy seeds)
# so that the weight initialization below is repeatable across runs.
torch.manual_seed(31415)

# Online network: the one updated by the optimizer.
# Target network: a periodically refreshed copy used to compute the
# bootstrap targets.
# 4 inputs / 2 outputs -- presumably the observation size and action count of
# the CartPole environment used later in this notebook.
q_net_online = Network(4, 2)
q_net_target = Network(4, 2)

# Start the target network from the same weights as the online network.
# Otherwise the targets would come from an unrelated random network until the
# first periodic synchronization.
q_net_target.load_state_dict(q_net_online.state_dict())

q_net_online.to(device)
q_net_target.to(device)

Network(
  (0): Linear(in_features=4, out_features=128, bias=False)
  (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): Linear(in_features=128, out_features=64, bias=False)
  (4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): ReLU(inplace=True)
  (6): Linear(in_features=64, out_features=2, bias=True)
)

In [None]:
# Discount factor applied to the bootstrapped next-state value
gamma = 0.99

# Capacity of the replay memory (number of stored transitions)
memory_size = 10_000

# Mini-batch size and number of mini-batch updates performed after each episode
batch_size = 32
steps_per_episode = 1_000

In [None]:
import gymnasium as gym

# Create the game environment. With render_mode="rgb_array", env.render()
# hands back the frame as an image array, which is later collected for the
# animation.
env_id = "CartPole-v1"
env = gym.make(env_id, render_mode="rgb_array")

In [None]:
# Replay memory: bounded FIFO of (state, action, reward, next_state, done)
# transitions; the oldest entries are dropped once `memory_size` is reached.
replay_memory = deque(maxlen=memory_size)
optim = torch.optim.Adam(q_net_online.parameters(), lr=1.0e-3)

# Exploration rate, annealed linearly from e0 (first episode) to e1 (last)
e0 = 0.5
e1 = 0.005
epsilons = np.linspace(e0, e1, n_episodes)

# Number of episodes between two synchronizations of the target network
target_update_interval = 5

# The target network is never trained directly, it only provides bootstrap
# targets. Put it in eval mode once instead of on every mini-batch.
q_net_target.eval()

# Episode loop
pbar = tqdm(total=n_episodes * steps_per_episode)
for epi in range(n_episodes):
    # Reset the game environment
    s0, _ = env.reset()
    eps = epsilons[epi]

    # ---- Experience collection -------------------------------------------
    # Inference is done on single states, so the network must be in eval mode
    # (BatchNorm then uses its running statistics instead of batch statistics).
    q_net_online.eval()
    while True:
        # epsilon-greedy action selection
        if np.random.rand() < eps:
            a0 = int(env.action_space.sample())
        else:
            # Only build the input tensor when the network is actually queried
            inputs = torch.as_tensor(s0, dtype=torch.float32, device=device).unsqueeze(0)
            with torch.no_grad():
                q_values = q_net_online(inputs)
            a0 = int(q_values.argmax(dim=1).item())

        # Gymnasium's step() returns (obs, reward, terminated, truncated, info).
        # `terminated`: the task itself ended (failure state reached).
        # `truncated`: the episode was cut off from outside (e.g. time limit).
        s1, reward, terminated, truncated, _ = env.step(a0)

        # Record the transition. Only `terminated` is stored as the done flag:
        # a truncated episode did not reach a terminal state, so its next-state
        # value must still be bootstrapped in the TD target.
        replay_memory.append((s0, a0, reward, s1, terminated))

        # Move on to the next state
        s0 = s1

        # Stop on either signal; ignoring `truncated` would let a well-trained
        # policy keep the episode running indefinitely.
        if terminated or truncated:
            break

    # ---- Dataset preparation ---------------------------------------------
    # Sampling with replacement yields exactly `steps_per_episode` full
    # mini-batches regardless of how many transitions are stored so far.
    memory_dataset = ReplayMemoryDataset(replay_memory)
    memory_sampler = torch.utils.data.RandomSampler(
        memory_dataset,
        replacement=True,
        num_samples=batch_size * steps_per_episode,
    )
    memory_loader = torch.utils.data.DataLoader(
        memory_dataset,
        batch_size=batch_size,
        sampler=memory_sampler,
    )

    # ---- Training loop ---------------------------------------------------
    q_net_online.train()
    for i, batch in enumerate(memory_loader):
        # Separate names for the batch tensors so the rollout variables
        # (s0, a0, ...) are not silently overwritten.
        s0_batch, a0_batch, reward_batch, s1_batch, done_batch = batch

        s0_batch = s0_batch.float().to(device)
        a0_batch = a0_batch.long().to(device)
        reward_batch = reward_batch.float().to(device)
        s1_batch = s1_batch.float().to(device)
        done_batch = done_batch.float().to(device)

        # Q(s0, a0) predicted by the online network for the actions taken
        q_values = q_net_online(s0_batch)
        q0 = torch.gather(q_values, 1, a0_batch.unsqueeze(1)).squeeze(-1)

        # TD target: r + gamma * max_a Q_target(s1, a), with the bootstrap
        # term masked out for terminal transitions
        with torch.no_grad():
            q1 = q_net_target(s1_batch)
            q_max = torch.max(q1, dim=1)[0]
            td_target = reward_batch + gamma * q_max * (1 - done_batch)

        loss = F.smooth_l1_loss(q0, td_target)

        optim.zero_grad()
        loss.backward()
        optim.step()

        if i % 100 == 0:
            pbar.set_description(f"Episode {epi+1}/{n_episodes}, Loss: {loss.item():.3f}")

        pbar.update()

    # Periodically copy the online weights into the target network
    if (epi + 1) % target_update_interval == 0:
        q_net_target.load_state_dict(q_net_online.state_dict())

# Release the progress bar once training has finished
pbar.close()

  0%|          | 0/100000 [00:00<?, ?it/s]

In [None]:
# Play one episode greedily with the trained network and record every frame
frames = []
obsrv, _ = env.reset()

# Single-state inference: the network must be in eval mode
q_net_online.eval()
while True:
    # Capture the current frame before acting
    frames.append(env.render())

    # Compute the Q-values with the Q-network
    inputs = torch.as_tensor(obsrv, dtype=torch.float32, device=device).unsqueeze(0)
    with torch.no_grad():
        q_values = q_net_online(inputs)

    # Choose the action with the largest Q-value
    a = int(q_values.argmax(dim=1).item())

    # Gymnasium's step() returns (obs, reward, terminated, truncated, info).
    # Stop on either flag: a policy that never fails ends only by truncation
    # (time limit), and ignoring that flag would grow `frames` without bound.
    obsrv, reward, terminated, truncated, _ = env.step(a)
    if terminated or truncated:
        break

In [None]:
# Draw the animation
fig, ax = plt.subplots(dpi=100)
ax.set(xticks=[], yticks=[])

# One [image, caption] pair of artists per recorded frame
draw = []
for i, f in enumerate(frames):
    ims = ax.imshow(f)
    txt = ax.text(20, 30, f"frame #{i+1:d}")
    draw.append([ims, txt])

# The layout does not depend on the frame content, so compute it once
# instead of once per frame.
fig.tight_layout()

# Build the animation and embed it as interactive HTML
ani = ArtistAnimation(fig, draw, interval=100, blit=True)
html = display.HTML(ani.to_jshtml())
display.display(html)

# Close this specific Matplotlib figure
plt.close(fig)

: 