In [1]:
import numpy as np
import gym
from gym.spaces import Discrete

In [2]:
class SnakeEnv(gym.Env):
    SIZE=100
  
    def __init__(self, ladder_num, dices):
        self.ladder_num = ladder_num
        self.dices = dices
        self.ladders = dict(np.random.randint(1, self.SIZE, size=(self.ladder_num, 2)))
        self.observation_space=Discrete(self.SIZE+1)
        self.action_space=Discrete(len(dices))

        for k,v in list(self.ladders.items()):
            self.ladders[v] = k
        self.pos = 1

    def reset(self):
        self.pos = 1
        return self.pos

    def step(self, a):
        step = np.random.randint(1, self.dices[a] + 1)
        self.pos += step
        if self.pos == 100:
            return 100, 100, 1, {}
        elif self.pos > 100:
            self.pos = 200 - self.pos

        if self.pos in self.ladders:
            self.pos = self.ladders[self.pos]
        return self.pos, -1, 0, {}

    def reward(self, s):
        if s == 100:
            return 100
        else:
            return -1

    def render(self):
        pass

class ModelFreeAgent(object):
    def __init__(self, env):
        # no num_ladders, since we assume that the transition probabilities are unknown
        
        self.state_size = env.observation_space.n
        self.action_size = env.action_space.n
        
        self.pi = np.zeros((self.state_size), dtype=np.int)
        self.value_sa = np.zeros((self.state_size, self.action_size), dtype=np.float)
        self.value_n = np.zeros((self.state_size, self.action_size), dtype=np.int) # N: accumulated quantity
        
        self.gamma = 0.8
    
    def play(self, state, epsilon=0):
        if np.random.rand() < epsilon:
            return np.random.randint(self.action_size)
        else:
            return self.pi[state]

def eval_game(env, policy):
    state = env.reset()
    return_val = 0
    while True:
        if isinstance(policy, ModelFreeAgent):
            act = policy.play(state)
        elif isinstance(policy, list):
            act = policy[state]
        else:
            raise Error('Illegal policy')
        state, reward, terminate, _ = env.step(act)
        # print(state)
        return_val += reward
        if terminate:
            break
    return return_val



In [3]:
class MonteCarlo(object):
    @staticmethod
    def _monte_carlo_eval(agent, env, epsilon):
        '''update action-value function'''
        # get (state, action, reward) list
        s = env.reset()
        sample = list()
        while True:
            a = agent.play(s, epsilon)
            s, r, done, _ = env.step(a)
            sample.append((s, a, r))
            if done:
                break
        
        # discounted reward
        new_sample = list()
        return_val = 0
        for s, a, r in reversed(sample):
            return_val = return_val * agent.gamma + r
            new_sample.append((s, a, return_val))
        sample = reversed(new_sample)
        
        # update update action-value function
        for s, a, r in sample:
            agent.value_n[s, a] += 1 # increase by 1
            N = agent.value_n[s, a]
            agent.value_sa[s, a] += (r - agent.value_sa[s, a]) / N
        
    @staticmethod
    def _policy_improvement(agent):
        '''update policy'''
        new_policy = np.zeros_like(agent.pi) # (action_size, state_size, state_size)
        for s in range(1, agent.state_size):
            new_policy[s] = np.argmax(agent.value_sa[s, :]) # select the max action !!! not [s, a], but [s, :]
        if np.all(np.equal(new_policy, agent.pi)):
            return True # converge
        else:
            agent.pi = new_policy
            return False # not converge
    
    @staticmethod
    def monte_carlo_opt(agent, env, epsilon=0.0):
        for i in range(10):
            for j in range(100):
                MonteCarlo._monte_carlo_eval(agent, env, epsilon=epsilon)
            MonteCarlo._policy_improvement(agent)

In [4]:
np.random.seed(3)
env = SnakeEnv(0, [3, 6])
agent = ModelFreeAgent(env)
MonteCarlo.monte_carlo_opt(agent, env)
print('return:', eval_game(env, agent))
print(agent.pi)

return: 72
[0 0 1 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0]
