In [1]:
from model import NSFrozenLake
from amalearn.agent import AgentBase
from draw_policy import draw_policy
import numpy as np

from CONFIG import *

In [2]:
class Agent(AgentBase):
    
    def __init__(self, id, environment, discount, theta, actions):
        
        self.environment = environment
        self.actions = actions
        self.i_limit, self.j_limit = 4, 4
        self.n_states = self.i_limit*self.j_limit
        
        self.V = {}
        self.Q = {}
        self.init_V_Q()
        
        self.policy = {}
        self.init_policy()
        
        super(Agent, self).__init__(id, environment)
        
        self.discount = discount
        self.theta = theta
        
    def init_V_Q(self):
        
        for i in range(self.i_limit):
            for j in range(self.j_limit):
                if (i, j) == (3, 3):
                    self.V[(i, j)] = 0
                else:
                    self.V[(i, j)] = np.random.rand()
                for a in self.actions:
                    self.Q[((i, j), a)] = 0
                
    def init_policy(self):
        
        for i in range(self.i_limit):
            for j in range(self.j_limit):
                act = np.random.choice([i for i in range(len(ACTIONS))])
                self.policy[(i, j)] = np.zeros(len(ACTIONS))
                self.policy[(i, j)][act] = 1
                
    def calculate_v(self, state, action):
        
        states, probs, fail_probs, dones = self.environment.possible_consequences(action, state)
        new_v = 0
        for next_state, prob, fail_prob, done in zip(states, probs, fail_probs, dones):
            new_v += prob*fail_prob*(FAIL_REWARD+MOVE_REWARD)
            new_v += prob*(1-fail_prob)*(MOVE_REWARD + self.discount*self.V[next_state])
            if done:
                new_v += prob*(1-fail_prob)*(GOAL_REWARD)

                
        return new_v
    
    def value_iteration(self):
        
        epoch = 0
        while True:
            delta = 0
            epoch += 1
            for state, v in self.V.items():
                
                temp_v = v
                max_v, max_act = 0, -1
                for act in ACTIONS:
                    new_v = self.calculate_v(state, act)
                    self.Q[(state, act)] = new_v
                    if new_v > max_v:
                        max_v = new_v
                        max_act = act
                        
                self.V[state] = max_v
                
                old_act = np.argmax(self.policy[state])
                self.policy[state][old_act] = 0
                self.policy[state][max_act] = 1
                
                delta = max(delta, abs(temp_v - max_v))
             
            if delta <= self.theta:
                print(epoch)
                break
    
    def take_action(self) -> (object, float, bool, object):
        pass

In [3]:
environment = NSFrozenLake(studentNum=STUDENT_NUM)

In [4]:
agent = Agent('1', environment, DISCOUNT, 0, ACTIONS)
agent.value_iteration()

311


In [5]:
environment.render()
draw_policy(agent.policy)


------------------------------
| [44m0.000[0m | 0.001 | 0.332 | 0.746 | 
------------------------------
| 0.696 | 0.001 | 0.143 | 0.998 | 
------------------------------
| 0.703 | 0.001 | 0.001 | 0.001 | 
------------------------------
| 0.861 | 0.401 | 0.128 | 0.000 | 
------------------------------
→|↓|↓|←
→|↓|↓|↓
→|→|→|↓
→|→|→|↻


In [6]:
for i in range(4):
    for j in range(4):
        print(f'================== state {(i, j)} ==================')
        for a in ACTIONS:
            print(f'state: {(i, j)}, action: {a}, Q-value: {agent.Q[((i, j), a)]}')

state: (0, 0), action: 0, Q-value: 194.4794221221824
state: (0, 0), action: 1, Q-value: 70.4352357425524
state: (0, 0), action: 2, Q-value: 220.23301811078974
state: (0, 0), action: 3, Q-value: 194.4794221221824
state: (0, 1), action: 0, Q-value: 198.02645484626086
state: (0, 1), action: 1, Q-value: 252.29098315395277
state: (0, 1), action: 2, Q-value: 147.41091822576334
state: (0, 1), action: 3, Q-value: 223.7800508348682
state: (0, 2), action: 0, Q-value: 219.216594896086
state: (0, 2), action: 1, Q-value: 241.58331939898954
state: (0, 2), action: 2, Q-value: 35.32279292750111
state: (0, 2), action: 3, Q-value: 142.84746228698117
state: (0, 3), action: 0, Q-value: 131.21772998055948
state: (0, 3), action: 1, Q-value: -4.822973970027828
state: (0, 3), action: 2, Q-value: 23.693060621079425
state: (0, 3), action: 3, Q-value: 23.693060621079425
state: (1, 0), action: 0, Q-value: 67.94898508737158
state: (1, 0), action: 1, Q-value: 73.97721556140814
state: (1, 0), action: 2, Q-value: 246