<a href="https://colab.research.google.com/github/sktoyo/miscellaneous_codes/blob/master/grid_world_MCP_TD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import random
import numpy as np
from itertools import product
from tqdm.notebook import tqdm

In [None]:
class GridWorld():
  """Deterministic grid world with impassable cells and a fixed goal.

  The agent starts at (0, 0); the episode ends when it reaches
  (grid_x - 1, grid_y - 1). Every step costs a reward of -1.

  Args:
    grid_x: number of cells along x (valid x is 0 .. grid_x - 1).
    grid_y: number of cells along y (valid y is 0 .. grid_y - 1).
    block_list: cells the agent cannot enter, each given as [x, y] or
      (x, y). Defaults to no blocked cells.
  """

  def __init__(self, grid_x=4, grid_y=4, block_list=None):
    self.x_lim = grid_x
    self.y_lim = grid_y
    self.block_list = [] if block_list is None else block_list
    self.x = 0
    self.y = 0

  def step(self, a):
    """Apply action `a` and return ((x, y), reward, done).

    Action codes follow a numeric keypad: 4 left, 6 right, 8 up (+y),
    2 down (-y). An unknown code leaves the agent in place but still
    costs the step reward.
    """
    moves = {4: self.move_left, 6: self.move_right, 8: self.move_up, 2: self.move_down}
    move = moves.get(a)
    if move is not None:
      move()

    reward = -1  # the reward is always -1, so shorter paths score higher
    done = self.is_done()

    return (self.x, self.y), reward, done

  def _is_blocked(self, x, y):
    """Return True if (x, y) appears in block_list (entries may be lists or tuples)."""
    # Compare as tuples: `[x, y] in block_list` would silently miss tuple entries.
    return any(tuple(cell) == (x, y) for cell in self.block_list)

  def _try_move(self, dx, dy):
    """Move by (dx, dy) unless that leaves the grid or enters a blocked cell."""
    nx, ny = self.x + dx, self.y + dy
    inside = 0 <= nx < self.x_lim and 0 <= ny < self.y_lim
    if inside and not self._is_blocked(nx, ny):
      self.x, self.y = nx, ny

  def move_left(self):
    """Move one cell towards x = 0 if possible."""
    self._try_move(-1, 0)

  def move_right(self):
    """Move one cell towards x = x_lim - 1 if possible."""
    self._try_move(1, 0)

  def move_up(self):
    """Move one cell towards y = y_lim - 1 if possible."""
    self._try_move(0, 1)

  def move_down(self):
    """Move one cell towards y = 0 if possible."""
    self._try_move(0, -1)

  def is_done(self):
    """Return True when the agent stands on the goal cell (x_lim - 1, y_lim - 1)."""
    return self.x == self.x_lim - 1 and self.y == self.y_lim - 1

  def get_state(self):
    """Return the current position as (x, y)."""
    return (self.x, self.y)

  def reset(self):
    """Return the agent to (0, 0) and return that start state."""
    self.x = 0
    self.y = 0

    return (self.x, self.y)

class Agent():
  """Uniform random policy over the four grid-world moves."""

  # Action codes in the order of the quarter of [0, 1) that selects them.
  _ACTIONS = (4, 6, 8, 2)

  def select_action(self):
    """Return 4, 6, 8 or 2, each with probability 1/4."""
    # Scaling by 4 is exact in binary floating point, so the bucket index
    # reproduces the 0.25 / 0.5 / 0.75 thresholds exactly.
    bucket = int(random.random() * 4)
    return self._ACTIONS[bucket]


class Qagent():
  """Tabular epsilon-greedy agent trained by MC control, SARSA or Q-learning.

  Q-values live in ``q_table[x, y, i]`` where ``i`` is the index of the
  action code in ``action_list``.

  Args:
    grid_x: grid size along x.
    grid_y: grid size along y.
    action_counts: number of actions; must equal len(action_list) (4).
    agent_type: update rule, one of "mcc", "sarsa", "qlearning".
    eps: initial exploration rate; ``anneal_eps`` decays it towards a floor.
    alpha: learning rate.
    gamma: discount factor.

  Raises:
    ValueError: for an unknown agent_type or a mismatched action_counts.
  """

  AGENT_TYPES = ("mcc", "sarsa", "qlearning")

  def __init__(self, grid_x, grid_y, action_counts, agent_type,
               eps=0.9, alpha=0.1, gamma=1):
    self.action_list = [2, 4, 6, 8]
    if agent_type not in self.AGENT_TYPES:
      raise ValueError(
          f"unknown agent_type {agent_type!r}; expected one of {self.AGENT_TYPES}")
    if action_counts != len(self.action_list):
      # Every Q-table column must map to an entry of action_list; otherwise
      # argmax can select a column that has no action (IndexError).
      raise ValueError(
          f"action_counts must be {len(self.action_list)}, got {action_counts}")
    self.x_lim = grid_x
    self.y_lim = grid_y
    self.q_table = np.zeros((self.x_lim, self.y_lim, action_counts))
    # Start exploratory: anneal_eps decays towards 0.1 (0.2 for Q-learning),
    # which only has an effect when eps starts above that floor.
    self.eps = eps
    self.alpha = alpha
    self.gamma = gamma
    self.agent_type = agent_type

  def select_action(self, s):
    """Pick an action code for state s = (x, y) epsilon-greedily."""
    x, y = s
    coin = random.random()
    if coin < self.eps:
      return random.choice(self.action_list)
    action_val = self.q_table[x, y, :]
    if not action_val.any():
      # Untrained state: choose randomly instead of always taking action_list[0].
      return random.choice(self.action_list)
    return self.action_list[int(np.argmax(action_val))]

  def update_table(self, history):
    """Update the Q-table from a list of transitions.

    Each transition is (s, a, r, s_prime). For SARSA it may also be
    (s, a, r, s_prime, a_prime), where a_prime is the action actually taken
    in s_prime. Without it, a_prime is sampled from the current policy,
    which may differ from the action the caller executes next.

    For "mcc", `history` must be one complete episode in time order;
    it is processed backwards to accumulate returns.

    Raises:
      ValueError: if agent_type is not a supported update rule.
    """
    cum_reward = 0
    for transition in reversed(history):
      s, a, r, s_prime = transition[:4]
      x, y = s
      a_idx = self.action_list.index(a)
      if self.agent_type == "mcc":
        # G_t = r_t + gamma * G_{t+1}: the return from (s, a) includes
        # this step's reward, so fold it in before updating.
        cum_reward = r + self.gamma * cum_reward
        target = cum_reward
      elif self.agent_type == "sarsa":
        x_prime, y_prime = s_prime
        if len(transition) > 4:
          a_prime = transition[4]
        else:
          a_prime = self.select_action(s_prime)
        a_prime_idx = self.action_list.index(a_prime)
        target = r + self.gamma * self.q_table[x_prime, y_prime, a_prime_idx]
      elif self.agent_type == "qlearning":
        x_prime, y_prime = s_prime
        target = r + self.gamma * self.q_table[x_prime, y_prime, :].max()
      else:
        raise ValueError(f"unknown agent_type {self.agent_type!r}")
      self.q_table[x, y, a_idx] += self.alpha * (target - self.q_table[x, y, a_idx])

  def anneal_eps(self):
    """Decay eps by one episode's step, never below the floor and never upwards.

    Q-learning decays by 0.01 to a floor of 0.2; the others decay by 0.03
    to a floor of 0.1.
    """
    if self.agent_type == 'qlearning':
      step, floor = 0.01, 0.2
    else:
      step, floor = 0.03, 0.1
    # Guard: an eps already at or below the floor is left unchanged rather
    # than being raised to the floor.
    if self.eps > floor:
      self.eps = max(self.eps - step, floor)

  def show_table(self):
    """Print the greedy action code for every cell.

    The first print lays the grid out spatially: rows run from
    y = y_lim - 1 down to y = 0 and columns are x. The second prints the
    raw array indexed [x, y].
    """
    codes = np.array(self.action_list, dtype=float)
    data = codes[np.argmax(self.q_table, axis=2)]
    print(np.transpose(data)[::-1])
    print(data)


In [None]:
def MCP():
  """Estimate state values of a uniform random policy by Monte Carlo prediction.

  Runs 50000 episodes of `Agent` in an open 4x4 `GridWorld` (goal at
  (3, 3)) and prints the value table, one list per x coordinate.
  """
  # GridWorld requires its size and block list; the value table below is
  # sized from the same grid so the two always agree.
  env = GridWorld(4, 4, [])
  agent = Agent()
  data = [[0] * env.y_lim for _ in range(env.x_lim)]
  gamma = 1.0
  alpha = 0.001

  for _ in range(50000):
    done = False
    history = []

    while not done:
      action = agent.select_action()
      (x, y), reward, done = env.step(action)
      # Record the state *reached* by this step with the reward for reaching it.
      # NOTE(review): the start state is therefore only updated when the agent
      # re-enters it (e.g. after bumping into a wall).
      history.append((x, y, reward))
    env.reset()

    # Walk backwards: cum_reward is the return collected after arriving at
    # (x, y), so update V(x, y) first, then fold in the reward that led there.
    cum_reward = 0
    for x, y, reward in reversed(history):
      data[x][y] += alpha * (cum_reward - data[x][y])
      cum_reward = reward + gamma * cum_reward

  for row in data:
    print(row)

def TD():
  """Estimate state values of a uniform random policy with TD(0).

  Runs 50000 episodes of `Agent` in an open 4x4 `GridWorld` (goal at
  (3, 3)) and prints the value table, one list per x coordinate.
  """
  # GridWorld requires its size and block list; the value table below is
  # sized from the same grid so the two always agree.
  env = GridWorld(4, 4, [])
  agent = Agent()
  data = [[0] * env.y_lim for _ in range(env.x_lim)]
  gamma = 1.0
  alpha = 0.001

  for _ in range(50000):
    done = False

    while not done:
      x, y = env.get_state()
      action = agent.select_action()
      (x_prime, y_prime), reward, done = env.step(action)
      # TD(0): bootstrap from the current estimate of the next state.
      data[x][y] += alpha * (reward + gamma * data[x_prime][y_prime] - data[x][y])
    env.reset()

  for row in data:
    print(row)

In [None]:
def MCC():
  """Train a Monte Carlo control agent on the 5x7 walled grid and print its policy."""
  grid_x, grid_y = 5, 7
  # Two partial walls: along y = 2 for x in 0..2 and along y = 4 for x in 2..4.
  block_list = [[0, 2], [1, 2], [2, 2], [2, 4], [3, 4], [4, 4]]
  env = GridWorld(grid_x, grid_y, block_list)
  agent = Qagent(grid_x, grid_y, 4, 'mcc')

  for _episode in tqdm(range(10000)):
    episode = []
    state = env.reset()
    finished = False
    while not finished:
      action = agent.select_action(state)
      next_state, reward, finished = env.step(action)
      episode.append((state, action, reward, next_state))
      state = next_state
    # Monte Carlo control updates only once the whole episode is known.
    agent.update_table(episode)
    agent.anneal_eps()

  agent.show_table()

In [None]:
# Train the Monte Carlo control agent and print its greedy action per cell.
MCC()

HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))


[[6. 2. 2. 6. 2.]
 [6. 6. 6. 6. 8.]
 [8. 8. 2. 2. 2.]
 [8. 8. 4. 4. 4.]
 [2. 2. 2. 6. 8.]
 [4. 8. 8. 8. 8.]
 [6. 6. 6. 6. 8.]]
[[6. 4. 2. 8. 8. 6. 6.]
 [6. 8. 2. 8. 8. 6. 2.]
 [6. 8. 2. 4. 2. 6. 2.]
 [6. 8. 6. 4. 2. 6. 6.]
 [8. 8. 8. 4. 2. 8. 2.]]


In [None]:
def SARSA():
  """Train a SARSA agent on the 5x7 walled grid and print its greedy policy."""
  grid_x, grid_y = 5, 7
  # Two partial walls: along y = 2 for x in 0..2 and along y = 4 for x in 2..4.
  block_list = [[0, 2], [1, 2], [2, 2], [2, 4], [3, 4], [4, 4]]
  env = GridWorld(grid_x, grid_y, block_list)
  agent = Qagent(grid_x, grid_y, 4, 'sarsa')

  for _episode in tqdm(range(10000)):
    state = env.reset()
    finished = False
    while not finished:
      action = agent.select_action(state)
      next_state, reward, finished = env.step(action)
      # Online update after every step.
      # NOTE(review): update_table picks its own a' for the target, which may
      # differ from the action selected at the top of the next iteration.
      agent.update_table([(state, action, reward, next_state)])
      state = next_state
    agent.anneal_eps()

  agent.show_table()

In [None]:
# Train the SARSA agent and print its greedy action per cell.
SARSA()

HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))


[[6. 6. 6. 6. 2.]
 [6. 6. 6. 6. 8.]
 [8. 8. 2. 2. 2.]
 [8. 8. 4. 4. 4.]
 [2. 2. 2. 8. 8.]
 [6. 6. 6. 8. 8.]
 [6. 6. 6. 8. 8.]]
[[6. 6. 2. 8. 8. 6. 6.]
 [6. 6. 2. 8. 8. 6. 6.]
 [6. 6. 2. 4. 2. 6. 6.]
 [8. 8. 8. 4. 2. 6. 6.]
 [8. 8. 8. 4. 2. 8. 2.]]


In [None]:
def QLearning():
  """Train a Q-learning agent on the 5x7 walled grid and print its greedy policy."""
  grid_x, grid_y = 5, 7
  # Two partial walls: along y = 2 for x in 0..2 and along y = 4 for x in 2..4.
  block_list = [[0, 2], [1, 2], [2, 2], [2, 4], [3, 4], [4, 4]]
  env = GridWorld(grid_x, grid_y, block_list)
  agent = Qagent(grid_x, grid_y, 4, 'qlearning')

  for _episode in tqdm(range(10000)):
    state = env.reset()
    finished = False
    while not finished:
      action = agent.select_action(state)
      next_state, reward, finished = env.step(action)
      # Online update after every step, bootstrapping from max_a' Q(s', a').
      agent.update_table([(state, action, reward, next_state)])
      state = next_state
    agent.anneal_eps()

  agent.show_table()

In [None]:
# Train the Q-learning agent and print its greedy action per cell.
QLearning()

HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))


[[6. 6. 6. 6. 2.]
 [6. 6. 6. 6. 8.]
 [6. 8. 2. 2. 2.]
 [6. 8. 4. 4. 4.]
 [2. 2. 2. 8. 4.]
 [6. 6. 6. 8. 4.]
 [6. 6. 6. 8. 8.]]
[[6. 6. 2. 6. 6. 6. 6.]
 [6. 6. 2. 8. 8. 6. 6.]
 [6. 6. 2. 4. 2. 6. 6.]
 [8. 8. 8. 4. 2. 6. 6.]
 [8. 4. 4. 4. 2. 8. 2.]]


In [None]:
# Estimate random-policy state values with Monte Carlo prediction and print them.
MCP()

[-60.45583312994574, -58.7068948004123, -57.834527171581364, -55.27104432760525]
[-58.36313001297977, -56.12548745599916, -51.367199359634775, -46.54291342820892]
[-54.42082590854048, -52.580422628147126, -44.58003341471624, -31.965271802219128]
[-50.47593972547579, -45.35002512545413, -29.336746137071888, 0.0]


In [None]:
# Estimate random-policy state values with TD(0) and print them.
TD()

[-58.32395083199417, -56.3132411180893, -53.159072823067845, -50.61408950276734]
[-56.42126106096235, -53.42240022443133, -48.49326596987988, -44.215838324121776]
[-53.18629056777815, -48.47436662077926, -39.426202903500766, -28.755940968890684]
[-50.677969781872655, -43.81782057096022, -28.960638146910778, 0]
