## MuZero

Based on MuZero pseudocode:

 * https://arxiv.org/src/1911.08265v2/anc/pseudocode.py
 
TicTacToe

In [None]:
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
class TicTacToe:
    """Two-player Tic-Tac-Toe environment that also records self-play data.

    Players are encoded as +1 and -1; +1 moves first. Empty cells are 0.
    There is no discounting: the only reward is the final outcome.
    """

    def __init__(self):
        self.action_space_size = 9
        self.raw_board = [0] * self.action_space_size
        self.to_play = 1  # flips between +1 and -1 after every move
        self.terminated = False
        self.raw_reward = 0  # +1 / -1 names the winning player, 0 means no winner

        # Empty board as seen by the first player; served for state_index == -1.
        self.initial_board = self.board(perspective="to_play")

        # board_history[i] is the board right after move i, from the
        # perspective of the player who moves next.
        self.board_history = []
        # Reward seen by the mover after each move; rewritten with the final
        # outcome by backfill_rewards() once the game is over.
        self.rewards = []
        self.history = []  # actions in the order they were played
        self.player_history = []  # player that executed each action

        self.child_visits = []
        self.root_values = []

    def backfill_rewards(self):
        """Replace every stored reward by the final outcome as seen by that move's player."""
        self.rewards = [self.reward(perspective=player) for player in self.player_history]

    def terminal(self):
        """Return whether the game has ended."""
        return self.terminated

    def legal_actions(self):
        """Return the indices of the empty cells, or an empty list once the game is over."""
        if self.terminated:
            return []
        return [cell for cell, mark in enumerate(self.raw_board) if mark == 0]

    def apply(self, action):
        """Play `action` for the current player, record the move, and return self."""
        self.player_history.append(self.to_play)
        self.history.append(action)
        self.rewards.append(self.step(action))
        # step() has already switched sides, so this is the next player's view.
        self.board_history.append(self.board(perspective="to_play"))
        return self

    def store_search_statistics(self, root):
        """Record the root's normalised child visit counts and its value estimate."""
        total_visits = sum(child.visit_count for child in root.children.values())
        if total_visits == 0:
            total_visits = 1  # avoid dividing by zero; all fractions come out as 0
        distribution = []
        for a in range(self.action_space_size):
            if a in root.children:
                distribution.append(root.children[a].visit_count / total_visits)
            else:
                distribution.append(0)
        self.child_visits.append(distribution)
        self.root_values.append(root.value())

    def make_image(self, state_index=None):
        """Return a batch of one board observation.

        state_index None -> most recent board (initial board if nothing played yet)
        state_index -1   -> initial board
        otherwise        -> board_history[state_index], i.e. the board after
                            that move, seen by the player who moves next
        """
        if state_index is None:
            if not self.board_history:
                return [self.initial_board]
            return [self.board_history[-1]]
        if state_index == -1:
            return [self.initial_board]
        return [self.board_history[state_index]]

    def make_target(self, state_index, num_unroll_steps):
        """Return (value, reward, policy) targets for the unrolled positions.

        The stored reward is used for both the value and the reward slot, and
        the policy target is the recorded visit distribution. Positions for
        which no search statistics were stored are left out, so the result
        may be shorter than num_unroll_steps + 1.
        """
        recorded = len(self.root_values)
        last_index = state_index + num_unroll_steps
        return [
            (self.rewards[idx], self.rewards[idx], self.child_visits[idx])
            for idx in range(state_index, last_index + 1)
            if idx < recorded
        ]

    def board(self, perspective=None):
        """Return the board, optionally sign-flipped for a player's point of view.

        perspective:
          None      -> the raw board list itself (not a copy)
          "played"  -> view of the opponent of the player currently to move
          "to_play" -> view of the player currently to move
          +1 / -1   -> view of that player
        """
        if perspective is None:
            return self.raw_board
        if perspective == "played":
            sign = -self.to_play
        elif perspective == "to_play":
            sign = self.to_play
        else:
            sign = perspective
        return [mark * sign for mark in self.raw_board]

    def _eval_win(self):
        """Return whether self.to_play owns a complete row, column or diagonal."""
        win_positions = np.array([
            [0, 1, 2], [3, 4, 5], [6, 7, 8],
            [0, 3, 6], [1, 4, 7], [2, 5, 8],
            [0, 4, 8], [2, 4, 6],
        ])
        # Indexing with the (8, 3) table yields the marks of each line.
        marks_per_line = np.array(self.raw_board)[win_positions]
        return (marks_per_line == self.to_play).all(axis=1).any()

    def step(self, action, debug=False):
        """Advance the environment by one move and return the mover's reward.

        Playing an occupied cell ends the game in favour of the opponent.
        """
        mover = self.to_play

        if self.raw_board[action] != 0:
            if debug:
                print("step: invalid action")
            if not self.terminated:
                # automatic win for the other player
                self.raw_reward = -mover
            self.terminated = True
        else:
            self.raw_board[action] = mover

        # self.to_play is still the mover here, so this checks the mover's lines.
        if self._eval_win():
            if debug:
                print("step: eval_win")
            if not self.terminated:
                self.raw_reward = mover
            self.terminated = True

        if not self.legal_actions():
            if debug:
                print("step: no more valid actions")
            # board is full (or the game already ended): stop here
            self.terminated = True

        # the move is done, hand over to the other player
        self.to_play = -mover
        return self.reward(perspective="played")

    def reward(self, perspective=None):
        """Return the outcome, optionally sign-flipped for a player's point of view.

        Accepts the same perspective values as board().
        """
        if perspective is None:
            return self.raw_reward
        if perspective == "played":
            sign = -self.to_play
        elif perspective == "to_play":
            sign = self.to_play
        else:
            sign = perspective
        return self.raw_reward * sign

    def render(self):
        """Print the board as three rows; +1 is drawn as O and -1 as X."""
        symbols = {1: "O", -1: "X"}
        for row in range(3):
            cells = self.raw_board[row * 3:row * 3 + 3]
            print("[", *(symbols.get(mark, " ") for mark in cells), "]")

In [None]:
class Network:
    """The MuZero model for one game: representation, dynamics, value and policy MLPs.

    Every sub-network is a two-layer perceptron with a ReLU in between. There
    is no reward head; both inference methods return None in the reward slot.
    """

    def __init__(self, game_class, state_size=32, hidden_size=64):
        # Input/output sizes are read from freshly constructed game instances.
        self.observation_size = len(game_class().board())
        self.action_size = game_class().action_space_size
        self.state_size = state_size
        self.hidden_size = hidden_size

        # observation -> hidden state
        self.representation_nn = self._mlp(self.observation_size, self.state_size)
        # (hidden state, one-hot action) -> next hidden state
        self.dynamics_nn = self._mlp(
            self.state_size + self.action_size, self.state_size
        )
        # hidden state -> value squashed by tanh
        self.value_nn = self._mlp(self.state_size, 1, nn.Tanh())
        # hidden state -> probabilities over actions (softmax over dim 1)
        self.policy_nn = self._mlp(
            self.state_size, self.action_size, nn.Softmax(dim=1)
        )

    def _mlp(self, in_features, out_features, *output_activation):
        """Build Linear -> ReLU -> Linear, followed by any given output activation."""
        return nn.Sequential(
            nn.Linear(in_features, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, out_features),
            *output_activation
        )

    def initial_inference(self, image):
        """Encode an observation batch and evaluate it.

        Returns [value, None, policy, hidden_state].
        """
        observation = torch.tensor(image, dtype=torch.float32)
        state = self.representation_nn(observation)
        return [self.value_nn(state), None, self.policy_nn(state), state]

    def recurrent_inference(self, hidden_state, action_onehot):
        """Advance a hidden state by a one-hot encoded action and evaluate the result.

        `hidden_state` is used as given (it is concatenated with the action
        along dim 1). Returns [value, None, policy, next_hidden_state].
        """
        action = torch.tensor(action_onehot, dtype=torch.float32)
        next_state = self.dynamics_nn(torch.cat([hidden_state, action], dim=1))
        return [
            self.value_nn(next_state),
            None,
            self.policy_nn(next_state),
            next_state,
        ]

In [None]:
class Node:
    """One node of the search tree.

    Attributes are filled in by the search: `children` maps an action to the
    child Node, `value_sum` / `visit_count` accumulate backed-up values, and
    `hidden_state` holds the network state attached when the node is expanded.
    """

    def __init__(self, prior, node_id=None):
        # Accept a one-element tensor as well as a plain number.
        if isinstance(prior, torch.Tensor):
            prior = prior.item()

        self.node_id = node_id  # debugging purposes

        self.visit_count = 0
        self.to_play = -1
        self.prior = prior
        self.value_sum = 0
        self.children = {}
        self.hidden_state = None
        self.reward = 0

    def expanded(self):
        """Return True once the node has at least one child."""
        return len(self.children) > 0

    def value(self):
        """Return the mean backed-up value, or 0 for an unvisited node."""
        if self.visit_count == 0:
            return 0
        return self.value_sum / self.visit_count

    def __str__(self):
        # value_sum is a float accumulator, so it is shown with decimals
        # (an integer format would silently truncate it).
        return "%s: prior: %5.2f, n_child: %d, n_visit: %d, value_sum: %5.2f" % (
            str(self.node_id), self.prior, len(self.children),
            self.visit_count, self.value_sum,
        )

    def print_tree(self, depth=999):
        """Print this node and its descendants, one node per line.

        Direct children are always printed; level k >= 2 is printed only
        when depth >= k. Each level is indented two more spaces.
        """
        print(self)
        self._print_children(1, depth)

    def _print_children(self, level, depth):
        """Print the children at `level` and recurse while `depth` allows."""
        if level > 1 and depth < level:
            return
        indent = " " * (2 * level - 1)
        for child in self.children.values():
            print(indent, child)
            child._print_children(level + 1, depth)
                        
def add_exploration_noise(node):
    """Mix Dirichlet noise into the priors of `node`'s children, in place.

    Each prior becomes prior * (1 - frac) + noise * frac. The noise vector is
    drawn from a symmetric Dirichlet distribution, so it sums to 1 and priors
    that formed a probability distribution still do afterwards. A node
    without children is left untouched.
    """
    root_dirichlet_alpha = 0.5  # concentration of the Dirichlet noise
    root_exploration_fraction = .25  # share of each prior replaced by noise

    actions = list(node.children)
    if not actions:
        return

    noise = np.random.dirichlet([root_dirichlet_alpha] * len(actions))
    frac = root_exploration_fraction
    for a, n in zip(actions, noise):
        child = node.children[a]
        child.prior = float(child.prior * (1 - frac) + n * frac)

def expand_node(node, to_play, actions, network_output):
    """Attach the network's output to `node` and create one child per action.

    network_output is indexed as [value, reward, policy, hidden_state]; only
    the policy (first batch row) and the hidden state are used, the reward is
    ignored. Child priors are the policy entries of `actions`, renormalised
    to sum to 1 over those actions.
    """
    node.to_play = to_play
    node.hidden_state = network_output[3]

    policy_row = network_output[2][0]
    # Plain floats keep the priors detached from the network's autograd graph.
    policy = {a: float(policy_row[a]) for a in actions}
    policy_sum = sum(policy.values())
    for action, p in policy.items():
        # Fall back to a uniform prior if the selected entries carry no mass.
        prior = p / policy_sum if policy_sum > 0 else 1.0 / len(policy)
        # str() keeps this working when the parent has no node_id (None).
        child_id = str(node.node_id) + "-" + str(action)
        node.children[action] = Node(prior, child_id)

def ucb1(parent, child):
    """Score `child` for selection below `parent`.

    The score is value_score + prior_score, where
      prior_score = child.prior * sqrt(parent.visit_count) / (child.visit_count + 1)
      value_score = -child.value() for a visited child, 0 otherwise.
    """
    prior_score = child.prior * np.sqrt(parent.visit_count) / (child.visit_count + 1)
    if child.visit_count > 0:
        # The value of the child is from the perspective of the opposing player
        value_score = -child.value()
    else:
        value_score = 0

    return value_score + prior_score

def backpropagate(search_path, value, to_play):
    """Add `value` to every node on `search_path` and count the visit.

    `value` is given from the point of view of `to_play`: nodes whose
    to_play matches receive +value, all others receive -value. A
    one-element tensor is converted to a plain number first.
    """
    if isinstance(value, torch.Tensor):
        value = value.item()
    for node in search_path:
        node.value_sum += value if node.to_play == to_play else -value
        node.visit_count += 1

def run_mcts(root, action_history, action_space, network, num_simulations=5):
    """Run `num_simulations` tree-search simulations starting at `root`.

    Each simulation walks down the tree by repeatedly taking the child with
    the highest ucb1 score, expands the leaf it reaches with the network's
    dynamics step, and backs the predicted value up along the path.

    `root` must already be expanded (with its to_play and hidden_state set).
    Non-root nodes are expanded over the full `action_space`.
    """
    for _ in range(num_simulations):
        history = action_history.copy()
        node = root
        search_path = [node]

        while node.expanded():
            _, action, child = max(
                (ucb1(node, child), action, child)
                for action, child in node.children.items()
            )
            history.append(action)
            node = child
            search_path.append(node)

        # Step the parent's hidden state forward with the action that led to the leaf.
        parent = search_path[-2]
        action_onehot = [0] * len(action_space)
        action_onehot[history[-1]] = 1
        # Search only reads numbers out of the network; no gradients are needed.
        with torch.no_grad():
            network_output = network.recurrent_inference(
                parent.hidden_state, [action_onehot]
            )

        # Players are encoded as +1 / -1 and alternate every move, so the
        # player to move at the leaf follows from the root's player and the
        # leaf depth. (Using 0/1 here would never compare equal to the
        # root's -1 during backpropagation.)
        depth = len(search_path) - 1
        to_play = root.to_play if depth % 2 == 0 else -root.to_play

        expand_node(node, to_play, action_space, network_output)
        backpropagate(search_path, network_output[0], to_play)
        
def get_visit_counts(node, argmax=True, action_space_size=None):
    """Return (next_action, visit_counts) for the children of `node`.

    visit_counts is a list indexed by action. Its length is
    `action_space_size` when given; otherwise it is taken from
    node.restore_game() if the node offers that method, and else it is just
    large enough to hold the highest child action.

    next_action is None when the node has no children. With argmax=True it
    is the most visited action; otherwise it is sampled in proportion to the
    visit counts (uniformly over the children if nothing was visited yet).
    """
    if action_space_size is None:
        if hasattr(node, "restore_game"):
            action_space_size = len(node.restore_game().raw_board)
        else:
            action_space_size = max(node.children, default=-1) + 1

    visit_counts = [0] * action_space_size
    for action, child in node.children.items():
        visit_counts[action] = child.visit_count

    if not node.children:
        return None, visit_counts

    if argmax is True:
        return np.argmax(visit_counts), visit_counts

    weights = np.array(visit_counts, dtype=float)
    if weights.sum() == 0:
        # No visits recorded: every existing child is equally likely.
        weights[list(node.children)] = 1.0
    next_action = np.random.choice(len(visit_counts), p=weights / weights.sum())
    return next_action, visit_counts

def select_action(num_moves, node, network, argmax=False):
    """Choose an action from the visit counts of `node`'s children.

    With argmax=True the most visited action wins (ties go to the larger
    action); otherwise the action is sampled in proportion to the visit
    counts, uniformly if no child has been visited. `num_moves` and
    `network` are accepted but not used.
    """
    pairs = [(child.visit_count, action) for action, child in node.children.items()]
    weights = np.array(list(zip(*pairs))[0])
    if np.sum(weights) == 0:
        # nothing visited yet: give every child the same weight
        weights = weights + 1

    if argmax is True:
        # pairs are unique because actions are, so the largest pair is unambiguous
        return max(pairs)[1]

    picked = np.random.choice(len(pairs), p=weights / np.sum(weights))
    return pairs[picked][1]

In [None]:
# Display the action history of `game`.
# NOTE(review): `game` is only assigned by cells further down, so running
# the notebook top to bottom will fail here — confirm this cell is wanted.
game.history

In [None]:
def play_game(game_class, network, explore_threshold=4, num_simulations=100):
    """Play one self-play game and return (game, nodes).

    For every move a fresh root is expanded from the current observation,
    exploration noise is added, and `num_simulations` search simulations are
    run. The first `explore_threshold` moves are sampled from the visit
    counts; later moves take the most visited action.

    `nodes` holds the search root of every move, in order. The game's
    rewards are back-filled with the final outcome before returning.
    """
    game = game_class()
    action_space = list(range(game.action_space_size))
    nodes = []

    while not game.terminal():
        # At the root of the search tree we use the representation function to
        # obtain a hidden state given the current observation.
        root = Node(0, "root")
        current_observation = game.make_image()
        # Self-play only reads numbers out of the network; skip autograd.
        with torch.no_grad():
            root_output = network.initial_inference(current_observation)
        expand_node(root, game.to_play, game.legal_actions(), root_output)
        add_exploration_noise(root)

        run_mcts(root, game.history, action_space, network, num_simulations)

        moves_played = len(game.history)
        greedy = moves_played >= explore_threshold
        action = select_action(moves_played, root, network, argmax=greedy)

        game.apply(action)
        game.store_search_statistics(root)
        nodes.append(root)

    game.backfill_rewards()
    return game, nodes

In [None]:
# Smoke test: play a single self-play game with a freshly initialised network.
game_class = TicTacToe
network = Network(game_class)
game, nodes = play_game(game_class, network)

In [None]:
# Actions played in the game above, in order.
game.history

In [None]:
# Print the final board of the game above.
game.render()

In [None]:
# Search roots returned by play_game, one per move.
nodes

In [None]:
# Show the first move's search root together with its direct children.
nodes[0].print_tree(depth=1)

In [None]:
class ReplayBuffer:
    """Sliding window over the most recent self-play games."""

    def __init__(self, window_size=256):
        self.window_size = window_size  # maximum number of games kept
        self.buffer = []  # oldest game first

    def save_game(self, game):
        """Store `game`, dropping the oldest games beyond `window_size`."""
        self.buffer.append(game)
        # Trim after appending so the buffer never holds more than
        # window_size games.
        overflow = len(self.buffer) - self.window_size
        if overflow > 0:
            del self.buffer[:overflow]

    def sample_batch(self, num_unroll_steps, batch_size=512):
        """Return `batch_size` training samples of (image, actions, targets).

        For a sampled position i of a sampled game: `image` is the board
        before action i is played, `actions` are the up to `num_unroll_steps`
        actions starting at i, and `targets` comes from
        game.make_target(i, num_unroll_steps).
        """
        games = [self.sample_game() for _ in range(batch_size)]
        game_pos = [(g, self.sample_position(g)) for g in games]
        # make_image -1 gives initial board, 0 gives after first action
        return [(
            g.make_image(i - 1),
            g.history[i:i + num_unroll_steps],
            g.make_target(i, num_unroll_steps)
        ) for (g, i) in game_pos]

    def sample_game(self):
        """Return a stored game chosen uniformly at random."""
        game_idx = np.random.choice(len(self.buffer))
        return self.buffer[game_idx]

    def sample_position(self, game):
        """Return a uniformly random index into the game's action history."""
        return np.random.choice(len(game.history))

In [None]:
def train_network(network, replay_buffer):
    """Run 50 weight updates on batches drawn from `replay_buffer`.

    Each batch is sampled with 5 unroll steps. The optimizer is taken from
    the module-level name `optimizer`. Returns the same `network` object.
    """
    num_unroll_steps = 5
    for _ in range(50):
        update_weights(
            optimizer, network, replay_buffer.sample_batch(num_unroll_steps)
        )
    return network

def update_weights(optimizer, network, batch):
    """Perform one optimizer step on `batch` and return the two loss values.

    `batch` is a list of (image, actions, targets) samples. For each sample
    the network is unrolled along `actions`; every prediction that has a
    matching target contributes to
      * a mean-squared-error loss between predicted and target value, and
      * a cross-entropy loss between the predicted policy and the target
        visit distribution.
    The network has no reward head, so the reward target is ignored.

    Returns (value_loss, policy_loss) as floats, or None if the batch
    produced no prediction/target pair (no step is taken then).
    """
    value_preds, value_targets = [], []
    policy_preds, policy_targets = [], []

    for image, actions, targets in batch:
        value, _, policy_output, hidden_state = network.initial_inference(image)
        predictions = [(value, policy_output)]

        for action in actions:
            # The dynamics network expects a batch of one-hot encoded actions.
            action_onehot = [0] * network.action_size
            action_onehot[action] = 1
            value, _, policy_output, hidden_state = network.recurrent_inference(
                hidden_state, [action_onehot]
            )
            predictions.append((value, policy_output))

        # zip() stops at the shorter list: unroll steps past the end of the
        # game have no target and are skipped.
        for (value, policy_output), target in zip(predictions, targets):
            target_value, _, target_policy = target
            value_preds.append(value[0])
            value_targets.append(torch.tensor([target_value], dtype=torch.float32))
            policy_preds.append(policy_output[0])
            policy_targets.append(torch.tensor(target_policy, dtype=torch.float32))

    if not value_preds:
        return None

    loss_value = F.mse_loss(torch.stack(value_preds), torch.stack(value_targets))
    # The policy head already outputs probabilities; clamp before the log so
    # a probability of exactly 0 cannot turn the loss into NaN.
    log_policy = torch.log(torch.stack(policy_preds).clamp_min(1e-8))
    loss_policy = torch.mean(
        -torch.sum(torch.stack(policy_targets) * log_policy, dim=1)
    )

    optimizer.zero_grad()
    (loss_value + loss_policy).backward()
    optimizer.step()
    return loss_value.item(), loss_policy.item()

In [None]:
# Fresh network, replay buffer and optimizer for the training run below.
game_class = TicTacToe
network = Network(game_class)
replay_buffer = ReplayBuffer(192)

# One Adam optimizer over the parameters of all four sub-networks.
optimizer = optim.Adam(
    [
        parameter
        for sub_network in (
            network.representation_nn,
            network.dynamics_nn,
            network.policy_nn,
            network.value_nn,
        )
        for parameter in sub_network.parameters()
    ],
    lr=.005,
)

In [None]:
# Training loop: every iteration plays 128 self-play games into the replay
# buffer, then takes one optimizer step on a batch of 256 sampled positions
# (unrolled for up to 5 steps). Loss values are collected for plotting.
value_losses = []
policy_losses = []
for j in range(1200):
    print(j, end=" ")
    for _ in range(128):
        print(".", end="")
        game, nodes = play_game(game_class, network)
        replay_buffer.save_game(game)
    print()

    batch = replay_buffer.sample_batch(5, 256)

    target_value_batch = []
    value_batch = []
    target_policy_batch = []
    policy_batch = []

    for image, actions, targets in batch:
        value, reward, policy_output, hidden_state = network.initial_inference(image)
        predictions = [(value, policy_output)]

        for action in actions:
            # one-hot encode the action over the network's action space
            action_onehot = [0] * network.action_size
            action_onehot[action] = 1
            value, reward, policy_output, hidden_state = network.recurrent_inference(
                hidden_state, [action_onehot]
            )
            predictions.append((value, policy_output))

        # zip() stops at the shorter list, so unroll steps without a target
        # (past the end of the game) are skipped.
        for (value, policy_output), target in zip(predictions, targets):
            target_value, target_reward, target_policy = target
            target_value_ = torch.tensor([target_value], dtype=torch.float32)
            target_policy_ = torch.tensor(target_policy, dtype=torch.float32)

            target_value_batch.append(target_value_)
            value_batch.append(value[0])
            target_policy_batch.append(target_policy_)
            policy_batch.append(policy_output[0])

    target_value_batch = torch.stack(target_value_batch)
    value_batch = torch.stack(value_batch)
    target_policy_batch = torch.stack(target_policy_batch)
    policy_batch = torch.stack(policy_batch)

    optimizer.zero_grad()

    loss_value = F.mse_loss(
        value_batch, target_value_batch
    )
    # Cross-entropy against the visit distribution. The policy head outputs
    # probabilities, so clamp before the log: a probability of exactly 0
    # would otherwise give 0 * -inf = NaN.
    loss_policy = torch.mean(-torch.sum(
        target_policy_batch * torch.log(policy_batch.clamp_min(1e-8)),
        dim=1,
    ))
    print(loss_value, loss_policy)
    value_losses.append(loss_value.item())
    policy_losses.append(loss_policy.item())
    loss = loss_value + loss_policy

    loss.backward()
    optimizer.step()

In [None]:
# Plot the value and policy loss recorded at every training iteration.
# NOTE(review): `plt` is not imported anywhere in the visible notebook;
# presumably `import matplotlib.pyplot as plt` is intended — confirm and add it.
plt.plot(value_losses)
plt.plot(policy_losses)

In [None]:
# Interactive game: the human moves first, the trained network answers.
# The loop ends as soon as the game is over after either side's move.
action_sequence = []
game = game_class()
action_space = list(range(game.action_space_size))
while True:
    game.render()
    print("avail:", game.legal_actions())
    if game.terminal():
        print("reward:", game.reward())
        break

    print()

    # Human move: keep asking until a cell index on the board is entered.
    # Without this check a non-number crashes the cell, a negative number
    # silently wraps around to another cell, and a too-large one raises
    # IndexError. An occupied cell is still passed on to the game, which
    # applies its own invalid-move rule.
    while True:
        try:
            action = int(input())
        except ValueError:
            print("please enter a cell index from 0 to %d" % (len(action_space) - 1))
            continue
        if action in action_space:
            break
        print("please enter a cell index from 0 to %d" % (len(action_space) - 1))
    action_sequence.append(action)
    game.apply(action)

    game.render()
    print("avail:", game.legal_actions())
    if game.terminal():
        print("reward:", game.reward())
        break

    print()

    # Computer move: search from the current position (no exploration
    # noise) and play the most visited action.
    root = Node(0, "root")
    current_observation = game.make_image()
    expand_node(root, game.to_play, game.legal_actions(),
                network.initial_inference(current_observation))
    run_mcts(root, game.history, action_space, network, 100)
    action = select_action(len(game.history), root, network, argmax=True)

    action_sequence.append(action)
    print("comp:", action)
    game.apply(action)