# Monte Carlo Tree Search for Teaching

Monte Carlo Tree Search (MCTS) is a popular method for determining good decision strategies in large decision spaces (i.e., those with a large branching factor).  This notebook is a little example code for learning the basics of the algorithm.  It draws heavily from the tutorial by Jeff Bradberry published here: https://jeffbradberry.com/posts/2015/09/intro-to-monte-carlo-tree-search/

In [288]:
import numpy as np
import random
import datetime # for limiting calculation to wall clock time
from math import log, sqrt

---
**How to turn this into a teaching situation?**

The `Student` class below gives the basic structure of this teacher-student interaction.  The teacher model provides instructions to the student which influences the state of the student.  The state is directly observable.  The `Student` class is just a model in the sense that it is a fiction in the mind of the teacher.  The student might not exactly behave this way.  However, if the student *did* act this way then the MCTS planning algorithm should be able to decide what the best sequence of teaching actions is.

1. **Uncooperative student in a grid world**: The student is an agent in a grid world.  Objective is to reach a particular goal state for a large positive reward.  What instruction should the teacher give?  Assume additionally the student isn't perfect so might not follow directions perfectly... The teacher should anticipate this. (Does the teacher understand the maze or no? No means give instructions and simply observes the outcomes and decides the policy directly during planning.)

1. **Growing a plant**: The student is a plant.  Objective is to keep alive until it reaches fruiting stage.  One action is ignore the plant (which can actually be good).  Other's are like apply water, apply fertilizer, apply sunlight, etc...  It is kind of like the two player game because you don't know what the other agent will do.  No partial observability issues here, yet, technically because you can just look at the plant height (let's say).

1. **Transparent sequence learner**:  The student is a person learning a digit string.  Objective is to get the student to know the sequence.  Actions are <?>.  Agent can forget.  No partial observability because each time you do a teaching episode the student gives a complete read-out of their current memory (as the state) so you always know where you are in terms of the student's knowledge.



In [None]:
class Student(object):
    
    def start(self):
        # Returns a representation of the starting state of the game.
        pass

    def next_state(self, state, instruction):
        # Takes the game state, and the move to be applied.
        # Returns the new game state.
        pass

    def teaching_actions(self, state_history):
        # Takes a sequence of learners states representing the full
        # teaching history, and returns the full list of actions that
        # are legal teaching actions
        # Question: why does the full teaching histoy influence the actions?
        # Answer: in games like checkers the possible moves limited by past plays.  not an issue
        # in all problems.
        pass

    def reward(self, state_history):
        # Takes a sequence of learner states representing the full
        # teaching history.  If the "game" is now won, return a large
        #positive reward. If the game is still ongoing, return zero.
        pass

### The uncooperative student

The uncooperative student starts at a random spot on a $N$x$M$ maze (see sepecific maze below).  The goal is to get to a goal state while avoiding obstacles.  Actions are expressed as instructions we give to the student (and the student therefore interpret the actions and may not always follow them).  The action space is "move up", "move down", "move left", and "move right".  The actions are available in all states at all times BUT if you run into an obsticle such as a wall or another block then your action simply returns you to the same state.

<img src="images/gridworld.png" width="300">

In this particular instance of the uncooperative student, the student has a tendency to move down no matter what advice it gets ($\epsilon$ probability of following impulse and ignoring advice).  I tried to abstract this definition somewhat in the code so you could change this to be more of an intolerable student.

In [231]:
class GridWorld():
    
    def __init__(self, gridmap):
        self.gridmap = gridmap
        self.gridmap_flat = [item for sublist in gridmap for item in sublist]
        self.nrows = len(self.gridmap)
        self.ncols = len(self.gridmap[0])
        self.all_states = [] # includes all states, indexable
        self.all_states_rev = {}
        self.valid_states = {} # don't include impossible start states
        idx = 0
        for i in range(self.nrows):
            for j in range(self.ncols):
                self.all_states.append((i,j))
                self.all_states_rev[(i,j)]=idx
                idx += 1
                if self.gridmap[i][j] == 'o':
                    self.valid_states[idx]=(i,j)

    def coord_to_index(self, coord):
        return self.all_states_rev[coord]
    
    def index_to_coord(self, index):
        return self.all_states[index]
    
    def raw_print(self):
        for i in range(self.nrows):
            for j in range(self.ncols):
                print("%s"%self.gridmap[i][j],end='\t')
            print ("\n")

    def index_print(self):
        for i in range(self.nrows):
            for j in range(self.ncols):
                print("(%s,%s)"%(i,j),end=' ')
            print ("\n")

    def coord_print(self):
        for i in range(self.nrows):
            for j in range(self.ncols):
                print("%s"%self.coord_to_index((i,j)),end=' ')
            print ("\n")


    def up(self, state):
        i,j = self.index_to_coord(state)
        # if in top row just stay where you are
        # OR if you'll hit a wall
        # OR if you are a wall or goal state (do nothing)
        if (i==0) or self.gridmap[i-1][j]=='x' or self.gridmap[i][j]=='x' or self.gridmap[i][j]=='g':
            return self.coord_to_index((i,j))
        else:
            return self.coord_to_index((i-1,j))

    def down(self, state):
        i,j = self.index_to_coord(state)
        # if in bottom row just stay where you are
        # OR if you'll hit a wall
        # OR if you are a wall or goal state(do nothing)
        if (i==self.nrows-1) or self.gridmap[i+1][j]=='x' or self.gridmap[i][j]=='x' or self.gridmap[i][j]=='g':
            return self.coord_to_index((i,j))
        else:
            return self.coord_to_index((i+1,j))
    
    def left(self, state):
        i,j = self.index_to_coord(state)
        # if in left-most column just stay where you are
        # OR if you'll hit a wall
        # OR if you are a wall or goal state (do nothing)
        if (j==0) or self.gridmap[i][j-1]=='x' or self.gridmap[i][j]=='x' or self.gridmap[i][j]=='g':
            return self.coord_to_index((i,j))
        else:
            return self.coord_to_index((i,j-1))

    def right(self, state):
        i,j = self.index_to_coord(state)
        # if in right-most column just stay where you are
        # OR if you'll hit a wall
        # OR if you are a wall or goal state (do nothing)
        if (j==self.ncols-1) or self.gridmap[i][j+1]=='x' or self.gridmap[i][j]=='x' or self.gridmap[i][j]=='g':
            return self.coord_to_index((i,j))
        else:
            return self.coord_to_index((i,j+1))

In [232]:
gridworld = [
       [ 'o', 'o', 'o', 'o', 'o', 'o', 'o', 'x', 'g'],
       [ 'o', 'o', 'x', 'o', 'o', 'o', 'o', 'x', 'o'],
       [ 'o', 'o', 'x', 'o', 'o', 'o', 'o', 'x', 'o'],
       [ 'o', 'o', 'x', 'o', 'o', 'o', 'o', 'o', 'o'],
       [ 'o', 'o', 'o', 'o', 'o', 'x', 'o', 'o', 'o'],
       [ 'o', 'o', 'o', 'o', 'o', 'o', 'o', 'o', 'o']
    ]

mygrid = GridWorld(gridworld)
mygrid.raw_print()  # print out the grid world
mygrid.index_print() # print out the indicies of each state
mygrid.coord_print() # print out the coordinates

print(mygrid.all_states) # all tuples states as a flat list
print(mygrid.all_states_rev) # maps from tuples to indicies
mygrid.valid_states # hash from indicies to tuples of non terminal, non-barrier states

o	o	o	o	o	o	o	x	g	

o	o	x	o	o	o	o	x	o	

o	o	x	o	o	o	o	x	o	

o	o	x	o	o	o	o	o	o	

o	o	o	o	o	x	o	o	o	

o	o	o	o	o	o	o	o	o	

(0,0) (0,1) (0,2) (0,3) (0,4) (0,5) (0,6) (0,7) (0,8) 

(1,0) (1,1) (1,2) (1,3) (1,4) (1,5) (1,6) (1,7) (1,8) 

(2,0) (2,1) (2,2) (2,3) (2,4) (2,5) (2,6) (2,7) (2,8) 

(3,0) (3,1) (3,2) (3,3) (3,4) (3,5) (3,6) (3,7) (3,8) 

(4,0) (4,1) (4,2) (4,3) (4,4) (4,5) (4,6) (4,7) (4,8) 

(5,0) (5,1) (5,2) (5,3) (5,4) (5,5) (5,6) (5,7) (5,8) 

0 1 2 3 4 5 6 7 8 

9 10 11 12 13 14 15 16 17 

18 19 20 21 22 23 24 25 26 

27 28 29 30 31 32 33 34 35 

36 37 38 39 40 41 42 43 44 

45 46 47 48 49 50 51 52 53 

[(0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), (0, 8), (1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (2, 0), (2, 1), (2, 2), (2, 3), (2, 4), (2, 5), (2, 6), (2, 7), (2, 8), (3, 0), (3, 1), (3, 2), (3, 3), (3, 4), (3, 5), (3, 6), (3, 7), (3, 8), (4, 0), (4, 1), (4, 2), (4, 3), (4, 4), (4, 5), (4, 6), (4, 7), (4, 8), (5, 0), (5, 1), (5,

{1: (0, 0),
 2: (0, 1),
 3: (0, 2),
 4: (0, 3),
 5: (0, 4),
 6: (0, 5),
 7: (0, 6),
 10: (1, 0),
 11: (1, 1),
 13: (1, 3),
 14: (1, 4),
 15: (1, 5),
 16: (1, 6),
 18: (1, 8),
 19: (2, 0),
 20: (2, 1),
 22: (2, 3),
 23: (2, 4),
 24: (2, 5),
 25: (2, 6),
 27: (2, 8),
 28: (3, 0),
 29: (3, 1),
 31: (3, 3),
 32: (3, 4),
 33: (3, 5),
 34: (3, 6),
 35: (3, 7),
 36: (3, 8),
 37: (4, 0),
 38: (4, 1),
 39: (4, 2),
 40: (4, 3),
 41: (4, 4),
 43: (4, 6),
 44: (4, 7),
 45: (4, 8),
 46: (5, 0),
 47: (5, 1),
 48: (5, 2),
 49: (5, 3),
 50: (5, 4),
 51: (5, 5),
 52: (5, 6),
 53: (5, 7),
 54: (5, 8)}

In [351]:
class UncooperativeStudent(object):

    def __init__(self, gridworld, epsilon):
        self.world = GridWorld(gridworld)
        self.EPS = epsilon

    def start(self):
        # choose an initially random state
        start = random.choice(list(self.world.valid_states.keys()))
        start = 45
        print("Starting at state %s:%s"%(start,self.world.all_states[start]))
        return start

    def next_state_agent(self, state, action):
        # Takes the game state, and the move to be applied.
        # Returns the new game state.
        if action=='up':
            return self.world.up(state)
        elif action=='down':
            return self.world.down(state)
        elif action=='left':
            return self.world.left(state)
        elif action=='right':
            return self.world.right(state)
        else:
            raise Exception("Invalid instruction")


    def next_state(self, state, instruction):
        # Takes the game state, and the move to be applied.
        # Returns the new game state.
        action = self.choice(instruction)
        return self.next_state_agent(state, action)

    def choice(self, instruction):
        if random.random() < self.EPS:
            return 'up'
        else:
            return instruction
        
    def teaching_actions(self, state_history):
        # Takes a sequence of learners states representing the full
        # teaching history, and returns the full list of actions that
        # are legal teaching actions
        # Question: why does the full teaching history influence the actions in an MDP?
        # this applies to games where the actions might change based on the game
        # state
        return ['up', 'down', 'left', 'right']

    def reward(self, state_history):
        # Takes a sequence of learner states representing the full
        # teaching history.  If the "game" is now won, return a large
        #positive reward. If the game is still ongoing, return zero.
        reward = 0
        done = False
        for h in state_history[1:]:
            if self.world.gridmap_flat[h]!='g':
                reward += -1
            else:
                reward += 25
                done = True
        return reward, done

In [281]:
bill = UncooperativeStudent(gridworld, 0.1)
bill.world.coord_print() 
s = bill.start()
print(s)
print(bill.next_state(s, 'left'))

bill.reward([45, 46, 47])

0 1 2 3 4 5 6 7 8 

9 10 11 12 13 14 15 16 17 

18 19 20 21 22 23 24 25 26 

27 28 29 30 31 32 33 34 35 

36 37 38 39 40 41 42 43 44 

45 46 47 48 49 50 51 52 53 

Starting at state 31:(3, 4)
31
30


(-2, False)

## Monte-carlo Tree Search

In [352]:
class MCTS(object):
    
    def __init__(self, student, **kwargs):
        # takes an instance of a Board and optionally some keyword
        # arguments. initializes the list of game states and the
        # statistics tables

        self.student = student
        self.states = []
        seconds = kwargs.get('time', 30)
        self.max_moves = kwargs.get('max_moves', 100)
        self.C = kwargs.get('C', 1.4)
        self.calculation_time = datetime.timedelta(seconds=seconds)
        self.rewards = {}
        self.plays = {}
    
    def update(self, state):
        # takes a game state, and appends it to the history
        self.states.append(state)
    
    def get_play(self):
        # causes the AI to calculate the best move from the 
        # current game state and return it
        self.max_depth = 0
        state = self.states[-1]
        legal = self.student.teaching_actions(self.states[:])
        
        if not legal:  # none of this is needed now because actions size is fixed
            return
        if len(legal)==1:
            return legal[0]
        
        games = 0
        begin = datetime.datetime.utcnow()
        while (datetime.datetime.utcnow() - begin) < self.calculation_time:
            self.run_simulation()
            games+=1
        
        
        # statistics here in terms of next states assume this is 
        # deterministic and not a q-value.
        moves_states = [(state, a) for a in legal]
        
        # display the number of calls of `run_simulation` and the time elapsed
        print(games, datetime.datetime.utcnow() - begin)
        
        # pick the move with the highest average reward
        percent_wins, move = max(
            (self.rewards.get((s,a), 0) / self.plays.get((s,a), 1), a) 
            for s, a in moves_states
        )
        
        # display the stats for each possible play
        for x in sorted(
            ((100 * self.rewards.get((s,a), 0) / self.plays.get((s,a), 1),
             self.rewards.get((s,a), 0), self.plays.get((s,a), 0), 
             a)
             for s,a in moves_states),
            reverse=True
        ):
            print("{3}: {0:.2f}% ({1} / {2})".format(*x))
    
        print("Maximum depth search:", self.max_depth)
        return move
    
    def run_simulation(self):
        # plays out a "random" game from the current position,
        # then updates the statistics tables with the result.
        plays, rewards = self.plays, self.rewards # for speed
        
        visited_qs = set()
        states_copy = self.states[:]
        state = states_copy[-1]

        expand = True  # you only expand once
        for t in range(self.max_moves):
            legal = self.student.teaching_actions(states_copy) # get a valid action
            
            moves = [(state, a) for a in legal]
            
            if all(plays.get((s,a)) for s,a in moves):
                # if we have stats on all the legal move, use them
                log_total = 2.0*log(
                    sum(plays[(s,a)] for s,a in moves)
                )
                
                # value of best
                value, ins = max(
                    ((rewards[(s,a)] / plays[(s,a)]) +
                     self.C * sqrt(log_total / plays[(s,a)]),a)
                    for s,a in moves
                )
            else:
                ins = random.choice(legal) # choose one
                            
            if expand and (state,ins) not in self.plays: # if expanding and this is new
                expand = False # stop the expansion in this run
                self.plays[(state,ins)]=0 # initialize
                self.rewards[(state,ins)]=0
                if t> self.max_depth:
                    self.max_depth = t
                    
            visited_qs.add((state,ins)) # add this state as visited
            
            state = self.student.next_state(state, ins) # get next state
            states_copy.append(state) # record

            reward, done = self.student.reward(states_copy) # compute reward if any
            #print(states_copy, reward, done)
            if done: # if done then top this simulation
                break
                
        #print(visited_states, reward)
        for q in visited_qs: # for each visited state
            if q not in self.plays: # if state not in the table of statistics yet
                continue
            self.plays[q]+=1 #  increase places
            self.rewards[q]+=reward # add up the reward you got
        

In [356]:
gridworld = [
       [ 'o', 'o', 'o', 'o', 'o', 'o', 'o', 'x', 'g'],
       [ 'o', 'o', 'x', 'o', 'o', 'o', 'o', 'x', 'o'],
       [ 'o', 'o', 'x', 'o', 'o', 'o', 'o', 'x', 'o'],
       [ 'o', 'o', 'x', 'o', 'o', 'o', 'o', 'o', 'o'],
       [ 'o', 'o', 'o', 'o', 'o', 'x', 'o', 'o', 'o'],
       [ 'o', 'o', 'o', 'o', 'o', 'o', 'o', 'o', 'o']
    ]

bill = UncooperativeStudent(gridworld, 0.0)
bill.world.coord_print()
tree = MCTS(bill, time=3., C=6.0, max_moves = 100)

action_seq = []
state = bill.start()
print("Starting at state %s"%state)

while bill.world.gridmap_flat[state]!='g':
    tree.update(state)
    action = tree.get_play()
    action_seq.append(action)
    state = bill.next_state(state, action)
    print("Taking action %s now at %s"% (action, state))
    
print(action_seq)

0 1 2 3 4 5 6 7 8 

9 10 11 12 13 14 15 16 17 

18 19 20 21 22 23 24 25 26 

27 28 29 30 31 32 33 34 35 

36 37 38 39 40 41 42 43 44 

45 46 47 48 49 50 51 52 53 

Starting at state 45:(5, 0)
Starting at state 45
1926 0:00:03.000960
up: -9963.57% (-69745 / 700)
right: -9976.89% (-52678 / 528)
left: -10000.00% (-35100 / 351)
down: -10000.00% (-35100 / 351)
Maximum depth search: 28
Taking action up now at 36
11251 0:00:03.000051
right: 199.13% (22251 / 11174)
up: -10051.84% (-27341 / 272)
left: -10051.84% (-27341 / 272)
down: -10051.84% (-27341 / 272)
Maximum depth search: 24
Taking action right now at 37
18498 0:00:03.000065
right: 827.95% (241313 / 29146)
down: -9967.21% (-24619 / 247)
left: -9981.37% (-20362 / 204)
up: -9983.00% (-19966 / 200)
Maximum depth search: 0
Taking action right now at 38
19935 0:00:03.000079
right: 862.03% (422911 / 49060)
down: -9977.61% (-6685 / 67)
up: -10028.89% (-4513 / 45)
left: -10028.89% (-4513 / 45)
Maximum depth search: 0
Taking action right now at 