<a href="https://colab.research.google.com/github/yhk775206/2023.RL/blob/main/BE_DP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import numpy as np

In [5]:
# 1. MDP 정의
# (1) S
shape = (4, 4)
terminals = [(0, 0), (3, 3)]

# (2) A
numa = 4
actions = ['up', 'down', 'left', 'right']

# (4) R
reward = -1 * np.ones(shape)
for terminal in terminals:
  reward[terminal] = 0

# (5) gamma
gamma = 1.0

In [2]:
# 1. MDP 정의
# (3) P
def P(state, action):
  if action == 'up':
    next_state = (max(0, state[0]-1), state[1])
  elif action == 'down':
    next_state = (min(shape[0]-1, state[0]+1), state[1])
  elif action == 'left':
    next_state = (state[0], max(0, state[1]-1))
  elif action == 'right':
    next_state = (state[0], min(shape[1]-1, state[1]+1))
  return next_state


In [6]:
# 2. value iteration
# (1) Initialize the value function
V = np.zeros(shape)

# (2) Value iteration
while True:
    delta = 0
    for i in range(shape[0]):
        for j in range(shape[1]):
            if (i, j) in terminals:
                continue
            v = V[i, j]
            V[i, j] = sum((reward[i, j] + gamma * V[P((i, j), a)]) for a in actions) / numa
            delta = max(delta, abs(v - V[i, j]))
    if delta < 1e-4:
        break


In [7]:
# 2. value iteration
# (3) Extract the optimal policy
optimal_policy = {}
for i in range(shape[0]):
    for j in range(shape[1]):
        if (i, j) in terminals:
            optimal_policy[i, j] = 'terminal'
        else:
            optimal_policy[i, j] = actions[np.argmax([(reward[i, j]
			     + gamma * V[P((i, j), a)]) for a in actions])]


In [10]:
print(optimal_policy)
print(V)

{(0, 0): 'terminal', (0, 1): 'left', (0, 2): 'left', (0, 3): 'left', (1, 0): 'up', (1, 1): 'up', (1, 2): 'left', (1, 3): 'down', (2, 0): 'up', (2, 1): 'up', (2, 2): 'right', (2, 3): 'down', (3, 0): 'up', (3, 1): 'right', (3, 2): 'right', (3, 3): 'terminal'}
[[  0.         -13.99931242 -19.99901152 -21.99891199]
 [-13.99931242 -17.99915625 -19.99908389 -19.99909436]
 [-19.99901152 -19.99908389 -17.99922697 -13.99942284]
 [-21.99891199 -19.99909436 -13.99942284   0.        ]]


In [8]:
# 3. 결과 출력
print("Optimal policy is:")
for i in range(shape[0]):
    for j in range(shape[1]):
        print(f"({i},{j}): {optimal_policy[i,j]}")

print("\nOptimal value function is:")
for i in range(shape[0]):
    for j in range(shape[1]):
        print(f"V({i},{j}): {V[i,j]}")


Optimal policy is:
(0,0): terminal
(0,1): left
(0,2): left
(0,3): left
(1,0): up
(1,1): up
(1,2): left
(1,3): down
(2,0): up
(2,1): up
(2,2): right
(2,3): down
(3,0): up
(3,1): right
(3,2): right
(3,3): terminal

Optimal value function is:
V(0,0): 0.0
V(0,1): -13.999312424461952
V(0,2): -19.999011518162753
V(0,3): -21.998911992496346
V(1,0): -13.999312424461952
V(1,1): -17.999156254598965
V(1,2): -19.99908388638086
V(1,3): -19.99909436158647
V(2,0): -19.999011518162757
V(2,1): -19.99908388638086
V(2,2): -17.99922696784339
V(2,3): -13.999422844683943
V(3,0): -21.99891199249635
V(3,1): -19.999094361586472
V(3,2): -13.999422844683945
V(3,3): 0.0
