In [1]:
import math
import random
from collections import defaultdict
from itertools import starmap

import numpy as np
import gymnasium as gym
from gymnasium.spaces import Discrete
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt

import games

In [51]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # common layers
        self.conv1 = nn.Conv2d(4, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        # action policy layers
        self.act_conv1 = nn.Conv2d(128, 4, kernel_size=1)
        self.act_fc1 = nn.Linear(4*3*3, 3*3)
        # state value layers
        self.val_conv1 = nn.Conv2d(128, 4, kernel_size=1)
        self.val_fc1 = nn.Linear(4*3*3, 1)
    
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x_act = F.relu(self.act_conv1(x))
        x_act = x_act.view(-1, 4*3*3)
        x_act = F.log_softmax(self.act_fc1(x_act))
        x_val = F.relu(self.val_conv1(x))
        x_val = x_val.view(-1, 4*3*3)
        x_val = F.relu(self.val_fc1(x_val))
        return x_act, x_val
    
class PolicyValueNet:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.net = Net().to(self.device)
        self.lr = 1e-3
        self.c = 1e-4
        self.optimizer = optim.Adam(self.net.parameters(), lr=self.lr, weight_decay=self.c)
    
    def policy_value(self, state):
        state = torch.as_tensor(state, dtype=torch.float32, device=self.device)
        with torch.no_grad():
            log_action_p, value = self.net(state)
        action_p = np.exp(log_action_p.cpu().numpy())
        value = value.cpu().numpy()
        return action_p, value
    
    # def policy_value(self, state):
    #     state = torch.as_tensor(state, dtype=torch.float32, device=self.device)
    #     with torch.no_grad():
    #         log_action_p, value = self.net(state)
    #     action_p = np.exp(log_action_p.cpu().numpy())
    #     value = value.cpu().numpy()
    #     return action_p, value        
    
    def update(self, state, mcts_p, is_winner):
        state = torch.tensor(state, dtype=torch.float32, device=self.device)
        mcts_p = torch.tensor(mcts_p, dtype=torch.float32, device=self.device)
        is_winner = torch.tensor(is_winner, dtype=torch.float32, device=self.device)
        self.optimizer.zero_grad()
        log_action_p, value = self.net(state)
        # mse loss for the state value
        value_loss = F.mse_loss(value.view(-1), is_winner)
        # cross entorpy loss for the search probabilities 
        policy_loss = torch.mean(torch.sum(mcts_p * log_action_p, -1))
        # total loss
        loss = value_loss - policy_loss
        loss.backward()
        self.optimizer.step()
        return value_loss.item(), -1 * policy_loss.item()
            
    def save(self, filename):
        torch.save(self.state_dict(), filename)
    
    @classmethod
    def load(cls, filename):
        policy_value_net = cls()
        policy_value_net.net.load_state_dict(torch.load(filename))
        return policy_value_net

In [34]:
# envs = gym.vector.SyncVectorEnv(
#     [
#     lambda: gym.make('games/TicTacToe') for _ in range(10)
#     ]
# )
net = Net()
env =  gym.make('games/TicTacToe')
observation, info = env.reset()
states = torch.tensor(observation, dtype=torch.float32)
action_spaces = env.action_space

In [41]:
states = torch.zeros([10,4,3,3])

In [None]:
class AgentNode:
    def __init__(self, env, parent, prior, reward=0.0):
        self.env = env
        self.parent = parent
        self.prior = prior
        self.reward = reward
        self.terminated = False
        self.n = 0
        self.q = None
        self.u = None
        self.children = {}
        
    def select(self):
        if not self.terminated:
            return self
        else:
            pass
    
    def expand(self):
        if not self.terminated:
            self.children[action] = []
            return self.get_outcome_child(action, reward, next_state)
        else:
            return self
    
    
    
    def get_outcome_child(self, action, reward, next_state):
        # Find the corresponding state and return if this already exists.
        for child in self.children[action]:
            if next_state == child.state:
                return child
        # Create one if it is not occured from this state-action pair previously.
        new_child = AgentNode(
            self,
            next_state,
            reward
        )
        self.children[action] += [new_child]
        return child
        
        
    def bandit_select(self):
        

In [295]:
class MCTS:
    def __init__(self, envs, bandit, qfunction):
        self.envs = envs
        self.bandit = bandit
        self.qfunction = qfunction
        
    def self_play(self, root_nodes=None):
        if root_nodes is None:
            root_nodes = self.create_root_nodes()
            
        non_final_mask = np.full(self.envs.num_envs, True)
        while non_final_mask.any():
            pass
        
    def create_root_nodes(self, num_agents):
        return [AgentNode(self.bandit, self.qfunction), for _ in range(num_agents)]
        
        

SyntaxError: expected ':' (2435132113.py, line 6)