# In [None]:
# Reference (AlphaZero tutorial):
# https://web.stanford.edu/~surag/posts/alphazero.html




# In [None]:
from random import choice
from math import sqrt

# In [None]:
class Mcts:
    """Monte Carlo Tree Search guided by a policy/value network (AlphaZero style).

    Statistics are kept per state ``s`` (states must be hashable) and, for
    ``Q`` and ``N``, per action ``a``:

    * ``P[s]``    -- prior policy returned by ``nnet.predict(s)``, indexed by action.
    * ``Q[s][a]`` -- running mean of the values backed up through ``(s, a)``.
    * ``N[s][a]`` -- number of simulations that took action ``a`` from ``s``.
    """

    def __init__(self, c_puct=0.5):
        # Exploration constant of the PUCT formula; 0.5 is an untuned default.
        self.c_puct = c_puct
        self.visited = set()  # states already expanded with the network
        self.Q = {}  # state -> {action: mean backed-up value}
        self.N = {}  # state -> {action: visit count}
        self.P = {}  # state -> prior policy from nnet.predict

    def search(self, s, game, nnet):
        """Run one MCTS simulation from ``s`` and back its result up the path.

        Why every return value is negated: values are expressed from the point
        of view of the player to move at ``s``. The caller is the parent node,
        where the *other* player is to move, so the value is flipped before
        being returned. This assumes a two-player, zero-sum, alternating-turn
        game and that ``game.gameReward(s)`` and ``nnet.predict(s)`` report the
        value for the player to move at ``s`` -- confirm against those
        implementations.

        Returns:
            The value of ``s`` from the parent's point of view.

        Raises:
            ValueError: if an expanded, non-terminal state has no valid actions.
        """
        if game.gameEnded(s):
            return -game.gameReward(s)

        if s not in self.visited:
            # Leaf: instead of a random rollout, ask the network for a prior
            # policy and a value estimate, and start empty edge statistics.
            self.visited.add(s)
            self.P[s], v = nnet.predict(s)
            actions = game.getvalidActions(s)
            self.Q[s] = {a: 0.0 for a in actions}
            self.N[s] = {a: 0 for a in actions}
            return -v

        q, n, prior = self.Q[s], self.N[s], self.P[s]
        # Total visits of s. When it is 0, the exploration term vanishes and the
        # choice falls back to Q (all zeros), i.e. the first action.
        sqrt_total = sqrt(sum(n.values()))

        def puct(a):
            # Exploit a high Q; explore actions with a high prior and few visits.
            return q[a] + self.c_puct * prior[a] * sqrt_total / (1 + n[a])

        # max() keeps the first of equal scores, matching a strict ``>`` scan.
        a = max(n, key=puct)

        v = self.search(game.nextState(s, a), game, nnet)

        # Incremental running mean: Q <- (N * Q + v) / (N + 1), then N <- N + 1.
        q[a] = (n[a] * q[a] + v) / (n[a] + 1)
        n[a] += 1
        return -v

    def policy(self, s):
        """Return the improved policy at ``s`` as normalized visit counts.

        Returns:
            A dict ``{action: probability}``. If ``s`` was expanded but no
            action has been visited yet, every action gets equal probability.
            An unexpanded state yields an empty dict.
        """
        counts = self.N.get(s, {})
        total = sum(counts.values())
        if total == 0:
            # Empty ``counts`` makes this comprehension empty, so there is no
            # division by zero.
            return {a: 1 / len(counts) for a in counts}
        return {a: c / total for a, c in counts.items()}

 


# In [None]:
# Training loop outline:
#   1. Initialize the neural network with random weights, i.e. start with a
#      random policy and value network.
#   2. Play a number of games of self-play.
#   3. On each turn of a game, perform a fixed number of MCTS simulations
#      from the current state.
#   4. Pick a move by sampling from the improved policy.

class Train:

    def __init__(self,numIters,numEps,threshold,numMCTSsims):
        self.numIters = numIters
        self.numEps = numEps
        self.threshold = threshold
        self.numMCTSsim = numMCTSsims

    def policyIterSp(self, game):
        """Run policy iteration through self-play and return the final network.

        Each of the ``self.numIters`` rounds does the following:

        * collects ``self.numEps`` self-play episodes with the current network;
        * trains a candidate network on every example gathered so far;
        * promotes the candidate when ``self.pit`` reports a win fraction
          strictly above ``self.threshold``.
        """
        nnet = initNet()  # TODO: network construction is not implemented yet
        examples = []  # accumulates across rounds; never reset
        for _round in range(self.numIters):
            for _episode in range(self.numEps):
                examples.extend(self.executeEpisode(game, nnet))
            candidate = self.trainNNet(examples)
            # Play the candidate against the current network.
            win_fraction = self.pit(candidate, nnet)
            nnet = candidate if win_fraction > self.threshold else nnet
        return nnet


    def executeEpisode(self, game, nnet):
        """Play one self-play game and return its training examples.

        Before every move, ``self.numMCTSsims`` MCTS simulations are run from
        the current state. The improved policy is recorded, and the next move
        is sampled from it. Each example is ``[state, policy, value]``; the
        value is ``None`` until the game ends, when ``self.assignRewards``
        receives the examples together with the final ``game.gameReward``.
        """
        # Weighted sampling: ``random.choice`` accepts no probabilities.
        from random import choices

        trainingData = []
        s = game.startState()
        mcts = Mcts()  # fresh search statistics for every episode

        # Checking first also covers a start state that is already terminal.
        while not game.gameEnded(s):
            for _ in range(self.numMCTSsims):
                mcts.search(s, game, nnet)

            pi = mcts.policy(s)  # computed once, used for the record and the move
            trainingData.append([s, pi, None])

            # Accept either {action: prob} or a sequence of probs indexed by action.
            if isinstance(pi, dict):
                actions, weights = list(pi.keys()), list(pi.values())
            else:
                actions, weights = list(range(len(pi))), list(pi)
            a = choices(actions, weights=weights)[0]
            s = game.nextState(s, a)

        return self.assignRewards(trainingData, game.gameReward(s))

    # Backward-compatible alias for the original, misspelled method name.
    executeEpisde = executeEpisode

    def assignRewards(self,trainingData,reward):

        for i in trainingData:
            # assign to each training example here