In [5]:
from __future__ import annotations

import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

import gymnasium as gym
import gym_2048
from gym_2048.wrappers import ConvObservation, PrintScores

In [2]:
class PolicyNetwork(nn.Module):
    """Convolutional network producing one raw logit for each of 4 actions.

    Two unpadded 2x2 convolutions (18 -> 256 -> 512 channels) feed three
    fully connected layers (2048 -> 1024 -> 256 -> 4). The flatten step
    hard-codes 2048 features per sample, which is what 512 channels on a
    2x2 map give, i.e. a 4x4 input after the two convolutions.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor.
        self.conv1 = nn.Conv2d(in_channels=18, out_channels=256, kernel_size=2, stride=1, padding=0)
        self.conv2 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=2, stride=1, padding=0)
        # Fully connected head.
        self.fc1 = nn.Linear(in_features=2048, out_features=1024)
        self.fc2 = nn.Linear(in_features=1024, out_features=256)
        self.fc_out = nn.Linear(in_features=256, out_features=4)

    def forward(self, x):
        """Return unnormalised action logits with shape (-1, 4)."""
        feature_map = F.relu(self.conv1(x))
        feature_map = F.relu(self.conv2(feature_map))
        flat = feature_map.view(-1, 2048)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc_out(hidden)

class QNetwork(nn.Module):
    """Convolutional network producing one value for each of 4 actions.

    Layer stack: Conv2d(18->256, 2x2) -> Conv2d(256->512, 2x2) ->
    Linear(2048->1024) -> Linear(1024->256) -> Linear(256->4), with ReLU
    after every layer except the last.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(18, 256, kernel_size=2, stride=1, padding=0)
        self.conv2 = nn.Conv2d(256, 512, kernel_size=2, stride=1, padding=0)
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc_out = nn.Linear(256, 4)

    def forward(self, x):
        """Return per-action values with shape (-1, 4)."""
        activations = x
        for conv in (self.conv1, self.conv2):
            activations = F.relu(conv(activations))
        # 512 channels * 2 * 2 spatial positions = 2048 features per sample.
        activations = activations.view(-1, 2048)
        for dense in (self.fc1, self.fc2):
            activations = F.relu(dense(activations))
        return self.fc_out(activations)

In [None]:
class Gumbel(nn.Module):
    """Actor-critic agent that selects actions with a Gumbel-perturbed argmax.

    The actor (``PolicyNetwork``) yields one logit per action and the critic
    (``QNetwork``) one Q-value per action. For every state the chosen action
    maximises ``gumbel_noise + logit + c_visit * c_scale * normalised_q``
    over that state's legal actions.
    """

    def __init__(
        self, 
        device: torch.device,
        critic_lr: float,
        actor_lr: float,
        n_envs: int,
    ):
        """Build both networks on ``device`` with one Adam optimiser each.

        Args:
            device: Device the networks and all returned tensors live on.
            critic_lr: Learning rate of the critic's optimiser.
            actor_lr: Learning rate of the actor's optimiser.
            n_envs: Number of parallel environments; used as the second
                dimension of the TD-error buffer in ``get_losses``.
        """
        super().__init__()
        self.device = device
        self.n_envs = n_envs
        
        self.actor = PolicyNetwork().to(self.device)
        self.critic = QNetwork().to(self.device)

        self.actor_optim = optim.Adam(params=self.actor.parameters(), lr=actor_lr)
        self.critic_optim = optim.Adam(params=self.critic.parameters(), lr=critic_lr)

        # Their product weights the normalised Q-values in ``select_action``.
        self.c_visit = 50
        self.c_scale = 0.1

    def forward(self, x: np.ndarray):
        """Evaluate critic and actor on a batch of observations.

        Args:
            x: Batch of observations. The networks expect 18 input channels;
                presumably shape (batch, 18, 4, 4) -- confirm against the
                environment wrapper.

        Returns:
            ``(q_values, action_logits)``, one row per observation.
        """
        # torch.tensor copies, so the graph never aliases the caller's buffer.
        x = torch.tensor(x, dtype=torch.float32).to(self.device)
        q_values = self.critic(x)
        action_logits = self.actor(x)
        return (q_values, action_logits)

    @staticmethod
    def _normalize_q_values(qs, eps: float = 1e-8):
        """Min-max scale ``qs`` to [0, 1]; all-equal inputs map to zeros."""
        q_min = torch.min(qs)
        q_range = torch.clamp(torch.max(qs) - q_min, min=eps)
        return (qs - q_min) / q_range

    def select_action(self, x: np.ndarray, legal_actions):
        """Pick one legal action per state.

        Args:
            x: Batch of observations, forwarded to ``forward``.
            legal_actions: Sequence with one entry per state, each a
                non-empty sequence of legal action indices.

        Returns:
            ``(selected_actions, action_logprobs, selected_q_values)``, each
            of shape (batch,). The log-probabilities are taken from a
            categorical distribution over all of the actor's logits (not
            only the legal ones); the Q-values are the critic's predictions
            for the selected actions.
        """
        batch_size = len(x)
        q_values, action_logits = self.forward(x)
        selected_actions = torch.zeros(size=(batch_size,), dtype=torch.int32, device=self.device)
        action_logprobs = torch.zeros(size=(batch_size,), dtype=torch.float32, device=self.device)
        selected_q_values = torch.zeros(size=(batch_size,), dtype=torch.float32, device=self.device)

        q_weight = self.c_visit * self.c_scale

        for i in range(batch_size):
            legal = legal_actions[i]
            gumbel_noise = torch.tensor(np.random.gumbel(size=(len(legal),)), device=self.device)
            logits = action_logits[i][legal]
            qs = q_values[i][legal]
            # The previous ``qs / (max - min)`` produced inf/NaN whenever a
            # single action was legal or all legal Q-values were equal,
            # which made the argmax ignore the logits and the noise.
            normalized_q_values = self._normalize_q_values(qs)
            action_idx = int(torch.argmax(gumbel_noise + logits + q_weight * normalized_q_values))

            selected_actions[i] = int(legal[action_idx])
            action_pd = torch.distributions.Categorical(logits=action_logits[i])
            action_logprobs[i] = action_pd.log_prob(selected_actions[i])
            selected_q_values[i] = q_values[i][selected_actions[i]]

        return selected_actions, action_logprobs, selected_q_values
    
    def get_losses(
        self,
        rewards,
        action_log_probs,
        q_value_preds,
        masks,
        device,
    ):
        """Compute the critic and actor losses for one rollout.

        Args:
            rewards: Tensor of shape (T, n_envs).
            action_log_probs: Log-probabilities of the selected actions.
            q_value_preds: Tensor of shape (T, n_envs) with the critic's
                predictions for the selected actions.
            masks: Tensor of shape (T, n_envs); multiplies the bootstrap
                term, so 0 disables bootstrapping for that step.
            device: Device on which the TD-error buffer is allocated.

        Returns:
            ``(critic_loss, actor_loss)``. The critic loss is the mean
            squared one-step TD error ``r_t + mask_t * Q_{t+1} - Q_t`` with a
            detached target; the last step has no successor and contributes
            a zero error to the mean. The actor loss is the negated mean
            log-probability of the selected actions.
        """
        T = len(rewards)
        td_errors = torch.zeros(T, self.n_envs, device=device)
        if T > 1:
            # Vectorised over time; row T-1 stays zero as before.
            targets = (rewards[:-1] + masks[:-1] * q_value_preds[1:]).detach()
            td_errors[:-1] = targets - q_value_preds[:-1]
        
        critic_loss = td_errors.pow(2).mean()
        actor_loss = -action_log_probs.mean()

        return critic_loss, actor_loss
    
    def update_parameters(self, critic_loss, actor_loss):
        """Take one optimiser step for the actor, then one for the critic."""
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()

In [None]:
# --- Hyperparameters --------------------------------------------------------
n_envs = 10
n_updates = 10000
n_steps_per_update = 128

actor_lr = 0.001
critic_lr = 0.005

seed = 42

# Action selection draws Gumbel noise from NumPy's global RNG; seed it (and
# torch, for the weight initialisation) so that a re-run is repeatable.
np.random.seed(seed)
torch.manual_seed(seed)

envs = gym.vector.make("TwentyFortyEight-v0", num_envs=n_envs, wrappers=ConvObservation)

# Use the GPU when there is one instead of failing on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

agent = Gumbel(device, critic_lr, actor_lr, n_envs)

envs_wrapper = PrintScores(envs, deque_size=n_envs * n_updates)

critic_losses = []
actor_losses = []
# NOTE: nothing in this cell records policy entropy, so this list stays empty
# and the "Entropy" panel below is left blank.
entropies = []

# The environments are reset once; the rollout continues across updates.
states, infos = envs_wrapper.reset(seed=seed)

# --- Training loop ----------------------------------------------------------
for sample_phase in tqdm(range(n_updates)):
    # Per-update rollout buffers, one row per step and one column per env.
    ep_q_value_preds = torch.zeros(n_steps_per_update, n_envs, device=device)
    ep_rewards = torch.zeros(n_steps_per_update, n_envs, device=device)
    ep_action_log_probs = torch.zeros(n_steps_per_update, n_envs, device=device)
    masks = torch.zeros(n_steps_per_update, n_envs, device=device)

    for step in range(n_steps_per_update):
        actions, action_log_probs, q_value_preds = agent.select_action(states, infos["legal actions"])

        states, rewards, terminated, truncated, infos = envs_wrapper.step(actions.cpu().numpy())

        ep_q_value_preds[step] = torch.squeeze(q_value_preds)
        ep_rewards[step] = torch.as_tensor(rewards, dtype=torch.float32, device=device)
        ep_action_log_probs[step] = action_log_probs
        # 1.0 while the episode continues, 0.0 on termination (no bootstrap).
        masks[step] = torch.tensor(np.logical_not(terminated), dtype=torch.float32, device=device)

    critic_loss, actor_loss = agent.get_losses(
        ep_rewards,
        ep_action_log_probs,
        ep_q_value_preds,
        masks,
        device,
    )

    agent.update_parameters(critic_loss, actor_loss)
    critic_losses.append(critic_loss.item())
    actor_losses.append(actor_loss.item())

# --- Plots ------------------------------------------------------------------
rolling_length = 20
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(12, 5))

# (axes, title, x-label, series, divisor applied to the x positions)
panels = [
    (axs[0][0], "Episode Returns", "Number of episodes",
     np.array(envs_wrapper.return_queue).flatten(), n_envs),
    (axs[1][0], "Entropy", "Number of updates", np.array(entropies), 1),
    (axs[0][1], "Critic Loss", "Number of updates", np.array(critic_losses).flatten(), 1),
    (axs[1][1], "Actor Loss", "Number of updates", np.array(actor_losses).flatten(), 1),
]

for ax, title, xlabel, series, x_divisor in panels:
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    if series.size == 0:
        # np.convolve raises ValueError on empty input; leave the panel empty
        # instead of aborting before the figure is saved.
        continue
    # Never use a window longer than the series itself.
    window = min(rolling_length, series.size)
    moving_average = np.convolve(series, np.ones(window), mode="valid") / window
    ax.plot(np.arange(len(moving_average)) / x_divisor, moving_average)

plt.tight_layout()
plt.savefig("gumbel.png")

## Afterstateモデル

In [7]:
class AfterstatePolicyNetwork(nn.Module):
    """Convolutional network that maps each input to a single policy logit.

    Architecture: two unpadded 2x2 convolutions (18 -> 256 -> 512 channels),
    then Linear(2048 -> 1024 -> 256 -> 1), with ReLU between layers.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(18, 256, kernel_size=2, stride=1, padding=0)
        self.conv2 = nn.Conv2d(256, 512, kernel_size=2, stride=1, padding=0)
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc_out = nn.Linear(256, 1)

    def forward(self, x):
        """Return one unnormalised logit per input, shape (-1, 1)."""
        conv_out = F.relu(self.conv2(F.relu(self.conv1(x))))
        # Flatten to the 2048 features expected by the first linear layer.
        flattened = conv_out.view(-1, 2048)
        dense_out = F.relu(self.fc2(F.relu(self.fc1(flattened))))
        logit = self.fc_out(dense_out)
        return logit

class AfterstateValueNetwork(nn.Module):
    """Convolutional network that maps each input to a single scalar value.

    Same layer layout as the other networks in this notebook, ending in a
    one-unit linear output.
    """

    def __init__(self):
        super().__init__()
        # Convolutional trunk.
        self.conv1 = nn.Conv2d(in_channels=18, out_channels=256, kernel_size=2, stride=1, padding=0)
        self.conv2 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=2, stride=1, padding=0)
        # Dense head ending in one output.
        self.fc1 = nn.Linear(in_features=2048, out_features=1024)
        self.fc2 = nn.Linear(in_features=1024, out_features=256)
        self.fc_out = nn.Linear(in_features=256, out_features=1)

    def forward(self, x):
        """Return one value per input, shape (-1, 1)."""
        trunk = x
        for conv in (self.conv1, self.conv2):
            trunk = F.relu(conv(trunk))
        head = trunk.view(-1, 2048)
        for dense in (self.fc1, self.fc2):
            head = F.relu(dense(head))
        return self.fc_out(head)

In [11]:
def make_conv_feature(inputs):
    """
    One-hot encode a batch of 4x4 integer boards into 18 channels.

    inputs: np.ndarray of shape(n, 4, 4) whose integer entries are used as
        channel indices (0..17; negative values wrap as in normal NumPy
        indexing, larger values raise IndexError). Presumably the entries
        are tile codes supplied by the environment -- confirm against
        gym_2048.
    output: np.ndarray of shape(n, 18, 4, 4), float64, with
        output[i, inputs[i, x, y], x, y] == 1.0 and 0.0 everywhere else.
    """
    n_boards = len(inputs)
    conv_feature = np.zeros(shape=(n_boards, 18, 4, 4))
    if n_boards == 0:
        # Nothing to encode; also avoids slicing a 1-D empty array below.
        return conv_feature

    # Only the leading 4x4 cells are read.
    boards = np.asarray(inputs)[:, :4, :4]

    # One vectorised assignment instead of n * 16 Python-level writes.
    board_idx = np.arange(n_boards)[:, None, None]
    row_idx = np.arange(4)[None, :, None]
    col_idx = np.arange(4)[None, None, :]
    conv_feature[board_idx, boards, row_idx, col_idx] = 1.0

    return conv_feature

In [28]:
class AfterstateGumbel(nn.Module):
    """Afterstate actor-critic agent with Gumbel-perturbed action selection.

    Every legal action of every state is represented by its afterstate. The
    actor (``AfterstatePolicyNetwork``) assigns one logit and the critic
    (``AfterstateValueNetwork``) one value to each afterstate; the Q-value of
    an action is its afterstate value plus the supplied afterstate reward.
    """

    def __init__(
        self, 
        device: torch.device,
        critic_lr: float,
        actor_lr: float,
        n_envs: int,
    ):
        """Build both networks on ``device`` with one Adam optimiser each.

        Args:
            device: Device the networks and all returned tensors live on.
            critic_lr: Learning rate of the critic's optimiser.
            actor_lr: Learning rate of the actor's optimiser.
            n_envs: Number of parallel environments; used as the second
                dimension of the TD-error buffer in ``get_losses``.
        """
        super().__init__()
        self.device = device
        self.n_envs = n_envs
        
        self.actor = AfterstatePolicyNetwork().to(self.device)
        self.critic = AfterstateValueNetwork().to(self.device)

        self.actor_optim = optim.Adam(params=self.actor.parameters(), lr=actor_lr)
        self.critic_optim = optim.Adam(params=self.critic.parameters(), lr=critic_lr)

        # Their product weights the normalised Q-values in ``select_action``.
        self.c_visit = 50
        self.c_scale = 0.1

    def forward(self, afterstates: np.ndarray):
        """Evaluate critic and actor on all afterstates of all states.

        Args:
            afterstates: Sequence with one array of afterstate boards per
                state; the arrays are concatenated along axis 0 and encoded
                with ``make_conv_feature``.

        Returns:
            ``(afterstate_values, afterstate_policy_logits)``, both 1-D with
            one entry per afterstate, ordered state by state.
        """
        concatenated_afterstates = np.concatenate(afterstates, axis=0)
        conv_afterstates = torch.tensor(make_conv_feature(concatenated_afterstates), dtype=torch.float32, device=self.device)
        afterstate_values = self.critic(conv_afterstates)
        afterstate_policy_logits = self.actor(conv_afterstates)
        # squeeze(-1) rather than squeeze(): with exactly one afterstate in
        # total a bare squeeze would return 0-dim tensors and break the
        # len()/slicing done in ``select_action``.
        return (afterstate_values.squeeze(-1), afterstate_policy_logits.squeeze(-1))

    @staticmethod
    def _normalize_q_values(qs, eps: float = 1e-8):
        """Min-max scale ``qs`` to [0, 1]; all-equal inputs map to zeros."""
        q_min = torch.min(qs)
        q_range = torch.clamp(torch.max(qs) - q_min, min=eps)
        return (qs - q_min) / q_range

    def select_action(self, legal_actions, afterstates, afterstate_rewards):
        """Pick one legal action per state.

        Args:
            legal_actions: Sequence with one entry per state, each a
                non-empty sequence of legal action indices.
            afterstates: Per-state afterstate boards, in the same order as
                ``legal_actions``; forwarded to ``forward``.
            afterstate_rewards: Per-state arrays with one reward per legal
                action, concatenated and added to the afterstate values.

        Returns:
            ``(selected_actions, action_logprobs, selected_q_values)``, each
            of shape (batch,). The log-probabilities come from a categorical
            distribution over the legal actions' logits only.
        """
        batch_size = len(legal_actions)
        afterstate_values, afterstate_policy_logits = self.forward(afterstates)
        q_values = afterstate_values + torch.tensor(np.concatenate(afterstate_rewards, axis=0), dtype=torch.float32, device=self.device)
        gumbel_noise = torch.tensor(np.random.gumbel(size=(len(afterstate_values),)), device=self.device)
      
        selected_actions = torch.zeros(size=(batch_size,), dtype=torch.int32, device=self.device)
        action_logprobs = torch.zeros(size=(batch_size,), dtype=torch.float32, device=self.device)
        selected_q_values = torch.zeros(size=(batch_size,), dtype=torch.float32, device=self.device)

        q_weight = self.c_visit * self.c_scale

        # Offset of the current state's first afterstate in the flat tensors.
        idx_cnt = 0

        for i in range(batch_size):
            legal = legal_actions[i]
            n_legal_actions = len(legal)
            segment = slice(idx_cnt, idx_cnt + n_legal_actions)
            noise = gumbel_noise[segment]
            policy_logits = afterstate_policy_logits[segment]
            qs = q_values[segment]
            # The previous ``qs / (max - min)`` produced inf/NaN whenever a
            # single action was legal or all legal Q-values were equal,
            # which made the argmax ignore the logits and the noise.
            normalized_q_values = self._normalize_q_values(qs)

            action_idx = torch.argmax(noise + policy_logits + q_weight * normalized_q_values)

            selected_actions[i] = int(legal[int(action_idx)])
            action_pd = torch.distributions.Categorical(logits=policy_logits)
            action_logprobs[i] = action_pd.log_prob(action_idx)
            selected_q_values[i] = qs[action_idx]

            idx_cnt += n_legal_actions

        return selected_actions, action_logprobs, selected_q_values
    
    def get_losses(
        self,
        rewards,
        action_log_probs,
        q_value_preds,
        masks,
        device,
    ):
        """Compute the critic and actor losses for one rollout.

        Args:
            rewards: Tensor of shape (T, n_envs).
            action_log_probs: Log-probabilities of the selected actions.
            q_value_preds: Tensor of shape (T, n_envs) with the Q-values of
                the selected actions.
            masks: Tensor of shape (T, n_envs); multiplies the bootstrap
                term, so 0 disables bootstrapping for that step.
            device: Device on which the TD-error buffer is allocated.

        Returns:
            ``(critic_loss, actor_loss)``. The critic loss is the mean
            squared one-step TD error ``r_t + mask_t * Q_{t+1} - Q_t`` with a
            detached target; the last step has no successor and contributes
            a zero error to the mean. The actor loss is the negated mean
            log-probability of the selected actions.
        """
        T = len(rewards)
        td_errors = torch.zeros(T, self.n_envs, device=device)
        if T > 1:
            # Vectorised over time; row T-1 stays zero as before.
            targets = (rewards[:-1] + masks[:-1] * q_value_preds[1:]).detach()
            td_errors[:-1] = targets - q_value_preds[:-1]
        
        critic_loss = td_errors.pow(2).mean()
        actor_loss = -action_log_probs.mean()

        return critic_loss, actor_loss
    
    def update_parameters(self, critic_loss, actor_loss):
        """Take one optimiser step for the actor, then one for the critic."""
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()

In [29]:
# --- Hyperparameters --------------------------------------------------------
n_envs = 10
n_updates = 10000
n_steps_per_update = 128

actor_lr = 0.001
critic_lr = 0.005

seed = 42

# Action selection draws Gumbel noise from NumPy's global RNG; seed it (and
# torch, for the weight initialisation) so that a re-run is repeatable.
np.random.seed(seed)
torch.manual_seed(seed)

envs = gym.vector.make("TwentyFortyEight-v0", num_envs=n_envs, afterstate=True)

# Use the GPU when there is one instead of failing on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

agent = AfterstateGumbel(device, critic_lr, actor_lr, n_envs)

envs_wrapper = PrintScores(envs)

critic_losses = []
actor_losses = []
# NOTE: nothing in this cell records policy entropy, so this list stays empty
# and the "Entropy" panel below is left blank.
entropies = []

# The environments are reset once; the rollout continues across updates.
states, infos = envs_wrapper.reset(seed=seed)

# --- Training loop ----------------------------------------------------------
for sample_phase in tqdm(range(n_updates)):
    # Per-update rollout buffers, one row per step and one column per env.
    ep_q_value_preds = torch.zeros(n_steps_per_update, n_envs, device=device)
    ep_rewards = torch.zeros(n_steps_per_update, n_envs, device=device)
    ep_action_log_probs = torch.zeros(n_steps_per_update, n_envs, device=device)
    masks = torch.zeros(n_steps_per_update, n_envs, device=device)

    for step in range(n_steps_per_update):
        # The agent acts on the afterstate information in `infos`, not on
        # the raw observation.
        actions, action_log_probs, q_value_preds = agent.select_action(infos["legal actions"], infos["afterstates"], infos["rewards"])

        states, rewards, terminated, truncated, infos = envs_wrapper.step(actions.cpu().numpy())

        ep_q_value_preds[step] = torch.squeeze(q_value_preds)
        ep_rewards[step] = torch.as_tensor(rewards, dtype=torch.float32, device=device)
        ep_action_log_probs[step] = action_log_probs
        # 1.0 while the episode continues, 0.0 on termination (no bootstrap).
        masks[step] = torch.tensor(np.logical_not(terminated), dtype=torch.float32, device=device)

    critic_loss, actor_loss = agent.get_losses(
        ep_rewards,
        ep_action_log_probs,
        ep_q_value_preds,
        masks,
        device,
    )

    agent.update_parameters(critic_loss, actor_loss)
    critic_losses.append(critic_loss.item())
    actor_losses.append(actor_loss.item())

# --- Plots ------------------------------------------------------------------
rolling_length = 20
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(12, 5))

# (axes, title, x-label, series, divisor applied to the x positions)
panels = [
    (axs[0][0], "Episode Returns", "Number of episodes",
     np.array(envs_wrapper.return_queue).flatten(), n_envs),
    (axs[1][0], "Entropy", "Number of updates", np.array(entropies), 1),
    (axs[0][1], "Critic Loss", "Number of updates", np.array(critic_losses).flatten(), 1),
    (axs[1][1], "Actor Loss", "Number of updates", np.array(actor_losses).flatten(), 1),
]

for ax, title, xlabel, series, x_divisor in panels:
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    if series.size == 0:
        # np.convolve raises ValueError on empty input; leave the panel empty
        # instead of aborting before the figure is saved.
        continue
    # Never use a window longer than the series itself.
    window = min(rolling_length, series.size)
    moving_average = np.convolve(series, np.ones(window), mode="valid") / window
    ax.plot(np.arange(len(moving_average)) / x_divisor, moving_average)

plt.tight_layout()
# Own file name so this figure does not overwrite the one saved by the
# non-afterstate experiment ("gumbel.png").
plt.savefig("afterstate_gumbel.png")

  0%|          | 0/10000 [00:00<?, ?it/s]

[592.]
[848.]


  0%|          | 1/10000 [00:01<4:09:55,  1.50s/it]

[1388.]
[1396.]
[1432.]
[668.]
[1964.]
[2204.]
[2756.]
[1328.]


  0%|          | 2/10000 [00:02<3:47:28,  1.37s/it]

[3140.]
[3484.]
[1372.]
[1340.]
[1768.]
[616.]
[544.]
[1220.]
[1660.]
[ 580. 1232.]
[2720.]
[2388.]


  0%|          | 3/10000 [00:04<3:41:57,  1.33s/it]

[712.]
[720.]
[1364.]
[1316. 2336.]
[1356.]
[1264.]
[1408.]


  0%|          | 4/10000 [00:05<3:39:09,  1.32s/it]

[888.]
[580.]
[704.]
[532.]
[1128.]
[2364.]
[2816.]
[712.]
[1060.]
[1432.]
[304.]


  0%|          | 5/10000 [00:06<3:37:42,  1.31s/it]

[392.]
[1028.]
[1252.]
[1384.]
[1344.]
[1640.]
[1868.]
[656.]


  0%|          | 6/10000 [00:07<3:36:23,  1.30s/it]

[1352.]
[2432.]
[580.]
[1180.]
[984.]
[2676.]
[2756.]


  0%|          | 7/10000 [00:09<3:36:04,  1.30s/it]

[2316.]
[2612.]
[2104.]
[1572.]
[1324.]
[3332.]
[2640.]
[1636.]


  0%|          | 9/10000 [00:11<3:36:34,  1.30s/it]

[3212.]
[3604.]
[3092.]


  0%|          | 10/10000 [00:13<3:36:40,  1.30s/it]

[7016.]
[1504.]


  0%|          | 11/10000 [00:14<3:35:42,  1.30s/it]

[1780.]
[6528.]
[6544.]
[7344.]
[7924.]
[7992.]


  0%|          | 12/10000 [00:15<3:34:39,  1.29s/it]

[3560.]
[3336.]
[12384.]


  0%|          | 13/10000 [00:16<3:34:25,  1.29s/it]

[2644.]
[1620.]
[3840.]
[3244.]
[3180.]
[3184.]


  0%|          | 14/10000 [00:18<3:34:03,  1.29s/it]

[4420.]
[3048.]
[1768.]
[2560.]
[3420.]


  0%|          | 15/10000 [00:19<3:33:47,  1.28s/it]

[4760.]
[1528.]
[3560.]
[2840.]


  0%|          | 16/10000 [00:20<3:34:11,  1.29s/it]

[1576.]
[1052.]
[2092.]
[4224.]
[1476.]
[4096.]
[3436. 2400.]


  0%|          | 17/10000 [00:22<3:34:31,  1.29s/it]

[1748.]
[2488.]
[2624.]


  0%|          | 18/10000 [00:23<3:34:54,  1.29s/it]

[3460.]
[3716.]
[2984.]
[3832.]


  0%|          | 19/10000 [00:24<3:34:03,  1.29s/it]

[3780.]
[748.]
[3748.]


  0%|          | 19/10000 [00:25<3:43:39,  1.34s/it]

[5116.]
[4012.]





KeyboardInterrupt: 