In [3]:
import copy
import numpy as np
import gymnasium as gym

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import games

In [4]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # common layers
        self.conv1 = nn.Conv2d(2, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        # action policy layers
        self.act_conv1 = nn.Conv2d(128, 4, kernel_size=1)
        self.act_fc1 = nn.Linear(4*3*3, 3*3)
        # state value layers
        self.val_conv1 = nn.Conv2d(128, 4, kernel_size=1)
        self.val_fc1 = nn.Linear(4*3*3, 1)
    
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x_act = F.relu(self.act_conv1(x))
        x_act = x_act.view(-1, 4*3*3)
        x_act = F.log_softmax(self.act_fc1(x_act), dim=-1)
        x_val = F.relu(self.val_conv1(x))
        x_val = x_val.view(-1, 4*3*3)
        x_val = F.relu(self.val_fc1(x_val))
        return x_act, x_val
    
class PolicyValueNet:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.net = Net().to(self.device)
        self.lr = 1e-3
        self.c = 1e-4
        self.optimizer = optim.Adam(self.net.parameters(), lr=self.lr, weight_decay=self.c)
    
    def policy_value(self, state):
        state = torch.as_tensor(state, dtype=torch.float32, device=self.device)
        with torch.no_grad():
            log_action_p, value = self.net(state)
        action_p = np.exp(log_action_p.cpu().numpy())
        value = value.cpu().numpy()
        return action_p, value   
    
    def update(self, state, mcts_p, is_winner):
        state = torch.tensor(state, dtype=torch.float32, device=self.device)
        mcts_p = torch.tensor(mcts_p, dtype=torch.float32, device=self.device)
        is_winner = torch.tensor(is_winner, dtype=torch.float32, device=self.device)
        self.optimizer.zero_grad()
        log_action_p, value = self.net(state)
        # mse loss for the state value
        value_loss = F.mse_loss(value.view(-1), is_winner)
        # cross entorpy loss for the search probabilities 
        policy_loss = torch.mean(torch.sum(mcts_p * log_action_p, -1))
        # total loss
        loss = value_loss - policy_loss
        loss.backward()
        self.optimizer.step()
        return value_loss.item(), -1 * policy_loss.item()
            
    def save(self, filename):
        torch.save(self.state_dict(), filename)
    
    @classmethod
    def load(cls, filename):
        policy_value_net = cls()
        policy_value_net.net.load_state_dict(torch.load(filename))
        return policy_value_net

In [5]:
def to_state(observation, info, agent_mark_mapping):
    agent_index = info['agent_index']
    # mark = agent_mark_mapping[info['agent_index']]
    mark_list = list(agent_mark_mapping.values())
    num_agents = len(mark_list)
    array_list = []
    for i in range(num_agents):
        index = (agent_index + i) % num_agents
        mark = mark_list[index]
        array_list.append(observation == mark)
    state = np.stack(array_list)
    return state

In [60]:
class AgentNode:
    """
    Each state_action pair (s, a) stores a set of statistics, {N(s, a), W(s, a), Q(s, a), P(s, a)},
    where N(s, a) is the visit count, W(s, a) is the total action-value, Q(s, a) is the mean action-value,
    and P(s, a) is the prior probability of selecting a in s.
    """
    def __init__(self, parent, num_actions, P, N=0, W=0):
        
        self.parent = parent
        self.num_actions = num_actions
        self.P = P
        self.N = N
        self.W = W
        self.children = {}
        self.child_N = np.zeros(num_actions, dtype=np.float32)
        self.child_W = np.zeros(num_actions, dtype=np.float32)
        self.child_P = None
        self.agent_index = None
        self.is_expanded = False

    def select(self, c_puct=1):
        if not self.is_expand:
            raise ValueError('Expand leaf node first.')
        
        Q = self.child_W / np.where(self.child_N > 0, self.child_N, 1)
        U = c_puct * self.child_P * np.sqrt(self.N) / (1 + self.child_N)
        action = np.argmax(Q + U)
        if action not in self.children.keys():
            self.children[action] = AgentNode(
                self,
                self.num_actions,
                self.child_P[action]
            )
        return action, self.children[action]
    
    def expand(self, agent_index, next_P):
        self.agent_index = agent_index
        self.child_P = next_P
        self.is_expand = True
        # for action in range(action_space.start, action_space.n):
        #     if action not in self.children.keys():
        #         self.children[action] = AgentNode(
        #             self,
        #             next_P[action]
        #         )

    def back_propagate(self, value):
        self.N += 1
        self.W += value
        if self.parent is not None:
            self.parent.back_propagate(-value)

In [117]:
env = gym.make('games/TicTacToe', max_episode_steps=100)
net = PolicyValueNet()

action_space = env.unwrapped.action_space
agent_mark_mapping = env.unwrapped.agent_mark_mapping

observation, info = env.reset()
root_node = AgentNode(
    parent = None,
    num_actions = action_space.n,
    P = 1
)

state = to_state(observation, info, agent_mark_mapping)
prior_p, value = net.policy_value(state)
root_node.expand(info['agent_index'], prior_p[0])
root_node.back_propagate(value[0])


num_simulations = 1000
while root_node.N < num_simulations:
    sim_env = copy.deepcopy(env)
    node = root_node
    
    is_end = False
    while not is_end:
        action, node = node.select()
        observation, reward, terminated, truncated, info = sim_env.step(action)
        state = to_state(observation, info, agent_mark_mapping)
        prior_p, value = net.policy_value(state)
        node.expand(info['agent_index'], prior_p[0])
        node.back_propagate(value[0])
        is_end = terminated or truncated
    # observation, info = env.reset()
#     is_end = False
#     while node.is_expanded
print(root_node.W)

[-0.30607694]


In [92]:
root_node.children[8].W

array([[0.]], dtype=float32)

In [65]:
env = gym.make('games/TicTacToe')
action_space = env.unwrapped.action_space
agent_index_space = env.unwrapped.agent_index_space
agent_mark_mapping = env.unwrapped.agent_mark_mapping

agent_list = [Agent(mark=agent_mark_mapping[i], action_space=action_space) for i in range(agent_index_space.start, agent_index_space.n)]

In [69]:
observation, info = env.reset()
is_end = False
while not is_end:
    print(observation)
    agent = agent_list[info['agent_index']]
    action, q_values = agent.play(observation)
    observation, reward, terminated, truncated, info = env.step(action)
    is_end = terminated or truncated

[[0 0 0]
 [0 0 0]
 [0 0 0]]
[[0 0 0]
 [0 0 0]
 [0 0 2]]
[[0 0 1]
 [0 0 0]
 [0 0 2]]
[[2 0 1]
 [0 0 0]
 [0 0 2]]
[[2 0 1]
 [0 0 1]
 [0 0 2]]
[[2 0 1]
 [0 0 1]
 [0 2 2]]
[[2 0 1]
 [1 0 1]
 [0 2 2]]
