# MCTS Agent #

Toy environment.

In [None]:
from tqdm import tqdm
import gym
import itertools
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sys
from gym import error, spaces, utils
from gym.utils import seeding
from enum import Enum
import collections
from itertools import combinations
import random


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import gym
from gym import error, spaces, utils
from gym.utils import seeding
from enum import Enum

class Plant:
    """One plant sown in the field.

    Attributes:
        species: label of the plant; also used as its printed representation.
        maturity: age threshold that the field code compares `age` against.
        age: starts at 0 on creation.

    NOTE(review): an earlier comment here said maturity had been changed
    "to 10 from 110" to be comparable to q-learning, but the default in the
    code is 1 -- confirm which value is intended.
    """

    def __init__(self, species, maturity=1):
        self.species, self.maturity = species, maturity   # consider 'days_to_maturity'
        self.age = 0

    def __repr__(self):
        return f"{self.species}"
    
class Field(gym.Env):
    """Gym environment in which each step sows one plant on a square field.

    An action is an index into `action_list`, whose entries are
    (species index, x index, y index) triples.  Every step sows one plant,
    advances the calendar by one and ages every plant by one.  The reward is
    0 until the season ends; the step that ends the season returns the total
    yield computed by `get_reward()`.

    NOTE(review): `observation_space` is declared as a 3-vector, but `step()`
    and `reset()` return the raw field array (rows of [x, y, Plant]).
    NOTE(review): `sow_limit` is stored but not read by any method here.
    """

    # species sown for each value of the first action component
    SPECIES = ('Maize', 'Bean', 'Squash')

    def __init__(self, size=5, sow_limit=200, season=120, calendar=0):
        # added to define action and observation spaces
        vals = [range(len(self.SPECIES)), range(360), range(360)]

        # mapping possible planting formations into discrete action space
        self.action_list = list(itertools.product(*vals))
        # len() avoids materialising the whole action list as an array
        self.num_actions = len(self.action_list)
        self.action_space = spaces.Discrete(self.num_actions)

        # reduced observation space to plant count 1) standardize output and 2) remain consistent with genetic algo
        self.observation_space = spaces.Box(np.array([0, 0, 0]), np.array([360, 360, 360]), dtype=np.int64)

        # parameters for overall field character
        self.size = size
        self.sow_limit = sow_limit
        self.season = season
        self.calendar = calendar

        # constants for computing end-of-season reward---distances represent meters
        self.crowding_dist = .02
        self.maize_maize_dist = .1
        self.bean_support_dist = .1
        self.crowding_penalty = .1
        self.maize_maize_penalty = .9
        self.bean_support_bonus = .6

        # field is initialized by calling reset()
        self.field = None

    def is_terminal(self):
        """Return True once the calendar has reached the end of the season.

        Uses `>=` so that an environment stepped past the season still
        reports as finished instead of running forever.
        """
        return self.calendar >= self.season

    def step(self, action):
        """Sow the plant encoded by `action` and advance time by one day.

        Args:
            action: index into `self.action_list`.

        Returns:
            (field, reward, done, info): the field array, 0 or the
            end-of-season yield, the terminal flag and an empty dict.
        """
        kind, x_index, y_index = self.action_list[action]
        # NOTE(review): `truncate` is not defined or imported anywhere in the
        # visible part of this file; it has to be supplied elsewhere (e.g. by
        # an earlier notebook cell), otherwise this line raises NameError.
        new_row = [[self.size * truncate(x_index),
                    self.size * truncate(y_index),
                    Plant(self.SPECIES[int(kind)])]]
        self.field = np.append(self.field, new_row, axis=0)

        self.calendar += 1
        # every plant ages, including the one sown on this step
        for row in self.field:
            row[2].age += 1

        done = self.is_terminal()
        reward = self.get_reward() if done else 0
        return self.field, reward, done, {}

    def reset(self):
        """Start a new season and return the initial field array."""
        # field is initialized with one random corn plant in order to make sowing (by np.append) work
        self.field = np.array([[self.size * np.random.random(),
                                self.size * np.random.random(),
                                Plant('Maize')]])
        # timekeeping is reset
        self.calendar = 0

        # added to avoid returning none type
        return self.field

    def _select_rows(self, species, mature):
        """Return the field rows of `species` whose maturity status equals `mature`.

        When no row matches, a single [None, None] row is returned so that
        the column indexing done by render() still works.
        """
        rows = np.array([row for row in self.field
                         if row[2].species == species
                         and (row[2].age >= row[2].maturity) == mature])
        if rows.size == 0:
            rows = np.array([[None, None]])
        return rows

    def render(self, mode='human'):
        """Scatter-plot the field by species and maturity, then print the current yield."""
        # (species, mature?, colour, marker size, alpha); immature plants are drawn fainter
        layers = [
            ('Maize', True, 'green', 200, .5),
            ('Bean', True, 'brown', 150, .5),
            ('Squash', True, 'orange', 400, .5),
            ('Maize', False, 'green', 200, .1),
            ('Bean', False, 'brown', 200, .1),
            ('Squash', False, 'orange', 200, .1),
        ]

        plt.figure(figsize=(10, 10))
        for species, mature, colour, marker_size, alpha in layers:
            rows = self._select_rows(species, mature)
            plt.scatter(rows[:, 0], rows[:, 1], c=colour, s=marker_size, marker='o',
                        alpha=alpha, edgecolor='#303030')

        plt.show()

        print("Total yield in Calories is {}.\n---\n".format(round(self.get_reward(), 1)))

    def close(self):
        # unneeded right now? AFAICT this is only used to shut down realtime movie visualizations
        pass

    def _maize_yield(self, i, plants, distances):
        """Return the yield of the maize plant at row `i`, adjusted for its neighbours.

        Neighbours are processed in field order because the additive bean
        bonus and the multiplicative penalties do not commute.
        """
        cal = 1
        for j, other in enumerate(plants):
            dist = distances[i, j]
            if other.species == 'Bean' and dist < self.bean_support_dist:
                cal += self.bean_support_bonus
            if other.species == 'Maize' and i != j and dist < self.maize_maize_dist:
                cal *= self.maize_maize_penalty
            # strictly positive distance: coincident points (including the plant itself) are skipped
            if 0 < dist < self.crowding_dist:
                cal *= self.crowding_penalty
        return cal

    def get_reward(self):
        """Return the total yield of all mature plants currently in the field."""
        # array of plant coordinates for computing distances
        xy_array = np.array([[row[0], row[1]] for row in self.field])

        # distances[m,n] is distance from mth to nth plant in field
        distances = np.linalg.norm(xy_array - xy_array[:, None], axis=-1)

        plants = self.field[:, 2]
        reward = 0
        for i, plant in enumerate(plants):
            if plant.age < plant.maturity:
                # immature plants contribute nothing
                continue
            if plant.species == 'Maize':
                reward += self._maize_yield(i, plants, distances)
            elif plant.species == 'Bean':
                reward += .25
            elif plant.species == 'Squash':
                reward += 3
        return reward


In [None]:
# Build the environment with its default parameters
# (size=5, sow_limit=200, season=120, calendar=0).
env = Field()

In [None]:
# Show the action space and the observation space, one per line
print(env.action_space, env.observation_space, sep="\n")

In [None]:
# testing a freshly reset environment: reset() seeds the field with a single
# randomly placed maize plant, so the field is not literally empty
env.reset()
env.render()

Brute Force Search (baseline: episodes played with randomly sampled actions)

In [None]:
# Play nine episodes with actions sampled uniformly from the action space and
# render the field at the end of each one.
for episode in range(1, 10):
    state = env.reset()
    score, done = 0, False
    while not done:
        action = env.action_space.sample()
        n_state, reward, done, info = env.step(action)
        score = score + reward
    env.render()


Monte Carlo Tree Search (MCTS)

In [None]:
class MCTSNode:
    """Node of a Monte Carlo search tree built over one shared environment.

    Every node holds a reference to the same mutable environment object, so
    a node's position is defined by the chain of `parent_action` values that
    leads to it from the root.  Whenever a node needs the environment to be
    in its own state, `_restore_state()` resets the environment and replays
    that chain.

    NOTE(review): the environment's reset() in this file places a random
    seed plant, so a restored state reproduces the sown actions but not the
    seed plant's position.
    """

    # a node counts as fully expanded once it has this many children --
    # change value to determine expansion rate
    max_children = 6

    def __init__(self, state, terminal=False, parent=None, parent_action=None):
        self.state = state
        self.parent = parent
        self.parent_action = parent_action
        self.children = []
        self._number_of_visits = 0
        # sum of all rollout results backpropagated through this node
        self._total_reward = 0.0
        self.terminal = terminal

    def is_terminal_node(self):
        """Return the terminal flag recorded when this node was created.

        The shared environment cannot be asked directly because it may be
        sitting in another node's state.
        """
        return self.terminal

    def n(self):
        """Return the number of rollouts that passed through this node."""
        return self._number_of_visits

    def _path_from_root(self):
        """Return the actions leading from the root to this node, in play order."""
        path = []
        node = self
        # test the parent link, not the action: action index 0 is falsy
        while node.parent is not None:
            path.append(node.parent_action)
            node = node.parent
        path.reverse()
        return path

    def _restore_state(self):
        """Reset the shared environment, replay this node's path and return it."""
        state = self.state
        state.reset()
        for action in self._path_from_root():
            state.step(action)
        return state

    def expand(self):
        """Add one child reached by a uniformly random action and return it."""
        # make sure the action is applied to this node's state, not to
        # whatever state the shared environment was left in
        state = self._restore_state()
        action = random.randrange(len(state.action_list))
        next_state, reward, done, info = state.step(action)
        child_node = MCTSNode(state, terminal=done, parent=self, parent_action=action)
        self.children.append(child_node)
        return child_node

    def calc_avg(self):
        """Return the average backpropagated reward (0.0 if never visited)."""
        if not self._number_of_visits:
            return 0.0
        return self._total_reward / self._number_of_visits

    def rollout(self):
        """Simulate one Monte Carlo rollout from this node and return its reward.

        Random actions are played until the environment is terminal; the
        reward is read at that point.  The environment is then put back into
        this node's state.
        """
        state = self._restore_state()
        while not state.is_terminal():
            state.step(state.action_space.sample())
        reward = state.get_reward()
        self._restore_state()
        return reward

    def backpropagate(self, result):
        """Record `result` on this node and on every ancestor up to the root."""
        node = self
        while node is not None:
            node._number_of_visits += 1
            node._total_reward += result
            node = node.parent

    def is_fully_expanded(self):
        """Return True once the node has `max_children` children."""
        return len(self.children) >= self.max_children

    def best_child(self, c_param=0.1):
        """Return the child with the highest UCB score.

        The score is the child's average reward plus
        `c_param * sqrt(2 * ln(parent visits) / child visits)`.  Children
        that were never visited score infinity so they are tried first.

        Raises:
            ValueError: if the node has no children.
        """
        if not self.children:
            raise ValueError("best_child() called on a node without children")
        # max(..., 1) keeps the logarithm finite for an unvisited parent
        log_visits = np.log(max(self.n(), 1))
        choices_weights = []
        for child in self.children:
            if child.n() == 0:
                choices_weights.append(float("inf"))
            else:
                explore = np.sqrt(2 * log_visits / child.n())
                choices_weights.append(child.calc_avg() + c_param * explore)
        return self.children[int(np.argmax(choices_weights))]

    def _tree_policy(self):
        """Select the node to simulate from: a new child, or a terminal node."""
        current_node = self
        while not current_node.is_terminal_node():
            if not current_node.is_fully_expanded():
                return current_node.expand()
            current_node = current_node.best_child()
        return current_node

    def best_action(self, simulation_no=1000):
        """Grow the tree and return the node at the end of the best path.

        Args:
            simulation_no: number of select/rollout/backpropagate rounds --
                increase to add more nodes.

        Returns:
            The node reached by following `best_child()` from this node
            until a node without children is reached.
        """
        for _ in tqdm(range(simulation_no)):
            v = self._tree_policy()
            reward = v.rollout()
            v.backpropagate(reward)
        child = self.best_child()
        # stop at the first childless node; calling best_child() on it would fail
        while child.children:
            child = child.best_child()
        return child
  
# reference code: https://ai-boson.github.io/mcts/

In [None]:
# Running search
# Start from a freshly reset field and use it as the state of the root node.
env.reset()
root = MCTSNode(state = env)
# Run the simulations and keep the node returned by the final descent through the tree.
selected_node = root.best_action()

In [None]:
# Determining optimal path
# Walk from the selected node up to the root, collecting the action that led
# to each node; best_path therefore ends up ordered leaf -> root.
best_path = []
curr = selected_node
# Stop at the root (the node without a parent).  Testing `curr.parent_action`
# for truthiness would also stop at action index 0 and cut the path short.
while curr.parent is not None:
    best_path.append(curr.parent_action)
    curr = curr.parent

In [None]:
# Determining depth of tree
# (number of actions collected in best_path, i.e. how deep the selected node sits)
len(best_path)

In [None]:
# Rendering best path
# Replay the collected actions from last to first on a freshly reset field,
# then draw the result.
env.reset()
for action in reversed(best_path):
    env.step(action)
env.render()