This repository has been archived by the owner on Dec 11, 2020. It is now read-only.

Confusion about the updateEdgeStats() #166

Open
neoql opened this issue May 1, 2020 · 0 comments

neoql commented May 1, 2020

In the `updateEdgeStats()` function, the reward is updated with `edge.reward += reward`, which is consistent with the formula in the paper "Mastering the game of Go without human knowledge".
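For reference, the backup step in the paper's Methods section can be summarized as follows. For each edge $(s_t, a_t)$ on the simulated path, $t \le L$, with leaf value $v$:

```latex
\begin{aligned}
N(s_t, a_t) &\leftarrow N(s_t, a_t) + 1 \\
W(s_t, a_t) &\leftarrow W(s_t, a_t) + v \\
Q(s_t, a_t) &= \frac{W(s_t, a_t)}{N(s_t, a_t)}
\end{aligned}
```

Read literally, the same $v$ is added to $W$ on every edge of the path, which is the behaviour `edge.reward += reward` mirrors.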

But many other popular unofficial implementations, e.g. junxiaosong/AlphaZero_Gomoku and suragnair/alpha-zero-general, add v to the edge reward when the current node belongs to the current player, and add -v when it belongs to the other player. These implementations have achieved good results even though their update rule differs from the description in the original paper.
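To make the difference concrete, here is a minimal sketch of the value each edge on a root-to-leaf path receives under the two rules. The function names are hypothetical and not taken from either repository. It assumes `leaf_value` is from the perspective of the player to move at the leaf and that the players alternate every ply.

```python
from typing import List


def backup_same_sign(path_len: int, leaf_value: float) -> List[float]:
    """Value added to each edge (root first) if every edge simply adds v."""
    return [leaf_value] * path_len


def backup_alternating(path_len: int, leaf_value: float) -> List[float]:
    """Value added to each edge (root first) if the sign flips at every ply.

    The edge directly above the leaf was chosen by the opponent of the
    player to move at the leaf, so it gets -v; the next one up gets +v, etc.
    """
    contributions = []
    for depth in range(path_len):
        # Number of plies between this edge's parent node and the leaf.
        plies_to_leaf = path_len - depth
        sign = -1.0 if plies_to_leaf % 2 else 1.0
        contributions.append(sign * leaf_value)
    return contributions


print(backup_same_sign(3, 0.5))    # [0.5, 0.5, 0.5]
print(backup_alternating(3, 0.5))  # [-0.5, 0.5, -0.5]
```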

I think this implementation is more intuitive than the method described in the original paper: the Q value of each node then represents the value of that node for the player it belongs to. This leaves me with two questions:

1. Why, in the original paper's description, is Q the average of all v in the subtree, regardless of which player each v is a reward for?
2. Why are both methods effective, and what are the differences between them?

The following code is taken from suragnair/alpha-zero-general and junxiaosong/AlphaZero_Gomoku, respectively.

https://github.com/suragnair/alpha-zero-general/blob/2b7725aeb0868253e1b9492661d78ab76c466e68/MCTS.py#L50

```python
def search(self, canonicalBoard):
    """
    This function performs one iteration of MCTS. It is recursively called
    till a leaf node is found. The action chosen at each node is one that
    has the maximum upper confidence bound as in the paper.
    Once a leaf node is found, the neural network is called to return an
    initial policy P and a value v for the state. This value is propagated
    up the search path. In case the leaf node is a terminal state, the
    outcome is propagated up the search path. The values of Ns, Nsa, Qsa are
    updated.
    NOTE: the return values are the negative of the value of the current
    state. This is done since v is in [-1,1] and if v is the value of a
    state for the current player, then its value is -v for the other player.
    Returns:
        v: the negative of the value of the current canonicalBoard
    """

    s = self.game.stringRepresentation(canonicalBoard)

    if s not in self.Es:
        self.Es[s] = self.game.getGameEnded(canonicalBoard, 1)
    if self.Es[s]!=0:
        # terminal node
        return -self.Es[s]

    if s not in self.Ps:
        # leaf node
        self.Ps[s], v = self.nnet.predict(canonicalBoard)
        valids = self.game.getValidMoves(canonicalBoard, 1)
        self.Ps[s] = self.Ps[s]*valids      # masking invalid moves
        sum_Ps_s = np.sum(self.Ps[s])
        if sum_Ps_s > 0:
            self.Ps[s] /= sum_Ps_s    # renormalize
        else:
            # if all valid moves were masked make all valid moves equally probable

            # NB! All valid moves may be masked if either your NNet architecture is insufficient or you've get overfitting or something else.
            # If you have got dozens or hundreds of these messages you should pay attention to your NNet and/or training process.
            print("All valid moves were masked, do workaround.")
            self.Ps[s] = self.Ps[s] + valids
            self.Ps[s] /= np.sum(self.Ps[s])

        self.Vs[s] = valids
        self.Ns[s] = 0
        return -v

    valids = self.Vs[s]
    cur_best = -float('inf')
    best_act = -1

    # pick the action with the highest upper confidence bound
    for a in range(self.game.getActionSize()):
        if valids[a]:
            if (s,a) in self.Qsa:
                u = self.Qsa[(s,a)] + self.args.cpuct*self.Ps[s][a]*math.sqrt(self.Ns[s])/(1+self.Nsa[(s,a)])
            else:
                u = self.args.cpuct*self.Ps[s][a]*math.sqrt(self.Ns[s] + EPS)     # Q = 0 ?

            if u > cur_best:
                cur_best = u
                best_act = a

    a = best_act
    next_s, next_player = self.game.getNextState(canonicalBoard, 1, a)
    next_s = self.game.getCanonicalForm(next_s, next_player)

    v = self.search(next_s)

    if (s,a) in self.Qsa:
        self.Qsa[(s,a)] = (self.Nsa[(s,a)]*self.Qsa[(s,a)] + v)/(self.Nsa[(s,a)]+1)
        self.Nsa[(s,a)] += 1

    else:
        self.Qsa[(s,a)] = v
        self.Nsa[(s,a)] = 1

    self.Ns[s] += 1
    return -v
```
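The `Qsa` update near the end of `search()` is the incremental form of a plain average, so it stores the same kind of quantity as the paper's $Q = W/N$; the difference between the two conventions lies in which `v` is averaged, not in how. A minimal sketch, where the helper name `running_mean_update` and the sample values are made up for illustration:

```python
from typing import Tuple


def running_mean_update(q: float, n: int, v: float) -> Tuple[float, int]:
    """One visit update in the form used by search(): Q <- (N*Q + v) / (N+1)."""
    return (n * q + v) / (n + 1), n + 1


values = [1.0, -1.0, 0.5, 0.25]  # hypothetical values backed up through one edge
q, n = 0.0, 0
w = 0.0
for v in values:
    q, n = running_mean_update(q, n, v)
    w += v  # the paper's W accumulates the same values

print(q, w / n)  # 0.1875 0.1875
```

The first-visit branch in `search()` (`Qsa = v`, `Nsa = 1`) is the same update starting from `q = 0, n = 0`.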

https://github.com/junxiaosong/AlphaZero_Gomoku/blob/a2555b26e38aaaa08270e0731c53135e6222ef46/mcts_alphaZero.py#L61

```python
def update_recursive(self, leaf_value):
    """Like a call to update(), but applied recursively for all ancestors.
    """
    # If it is not root, this node's parent should be updated first.
    if self._parent:
        self._parent.update_recursive(-leaf_value)
    self.update(leaf_value)
```
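To see the sign alternation of `update_recursive()` in isolation, here is a self-contained sketch built around it. The `TreeNode` class and its running-mean `update()` are hypothetical stand-ins written for this illustration, not copied from junxiaosong/AlphaZero_Gomoku.

```python
from typing import Optional


class TreeNode:
    """Hypothetical minimal node: visit count, running-mean Q, parent link."""

    def __init__(self, parent: Optional["TreeNode"] = None) -> None:
        self._parent = parent
        self._n_visits = 0
        self._Q = 0.0

    def update(self, leaf_value: float) -> None:
        # Running mean of every value backed up through this node.
        self._n_visits += 1
        self._Q += (leaf_value - self._Q) / self._n_visits

    def update_recursive(self, leaf_value: float) -> None:
        # Update ancestors first, flipping the sign at each level because
        # consecutive levels belong to opposing players.
        if self._parent is not None:
            self._parent.update_recursive(-leaf_value)
        self.update(leaf_value)


root = TreeNode()
child = TreeNode(parent=root)
grandchild = TreeNode(parent=child)

grandchild.update_recursive(1.0)
print(root._Q, child._Q, grandchild._Q)  # 1.0 -1.0 1.0
```

With this rule, each node's Q is averaged from the point of view of a single player, in contrast to adding the same v at every level.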