# Deep Q-learning using backprop on small gridworld

In [1]:
import os
# Move the working directory one level up before importing the project package.
# NOTE(review): presumably the notebook lives in a subfolder of the repository root
# so that `src` becomes importable — confirm. The path is relative, so re-running
# this cell moves up one more directory each time.
os.chdir("..")
from src.gym_kalman.env_Gridworld import GridworldEnv

In [2]:
# Build the gridworld environment and read the sizes of its state/action spaces
import numpy as np

# initialize the environment
# grid_size is also used later to reshape per-state arrays into a grid_size x grid_size grid
grid_size = 4
# NOTE(review): reward_std=0. presumably disables reward noise — confirm against GridworldEnv
env = GridworldEnv(grid_size=grid_size, reward_std=0.)
# Number of discrete states reported by the environment
num_states = env.observation_space.n
# Array of action indices 0 .. n_actions-1
actions = np.arange(env.action_space.n)

In [3]:
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

# set up matplotlib
# Detect an inline (notebook) backend by looking for 'inline' in the backend name
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

# Switch matplotlib to interactive mode
plt.ion()

# set seed
# Seed both NumPy's global RNG and the stdlib `random` module (the latter drives
# the epsilon-greedy draws in select_action).
# NOTE(review): whether the environment uses either of these generators is not
# visible here — confirm if full reproducibility is required.
np.random.seed(0)
random.seed(0)


In [4]:
# Tabular action-value estimates: one row per state, one column per action,
# all initialised to zero.
Q_table = np.zeros((num_states, len(actions)))

# Discount factor multiplied with the best next-state value in the Q update
GAMMA = 1
# Epsilon-greedy schedule used by select_action:
#   eps = EPS_END + (EPS_START - EPS_END) * exp(-steps_done / EPS_DECAY)
# i.e. the exploration probability starts at EPS_START and decays exponentially
# towards EPS_END, with EPS_DECAY as the decay constant in select_action calls.
EPS_START = 0.9
EPS_END = 0.0001
EPS_DECAY = 1000

# Number of select_action calls made so far; incremented by select_action
steps_done = 0

def select_action(state):
    """Choose an action for ``state`` with an epsilon-greedy rule.

    The exploration probability decays exponentially from ``EPS_START``
    towards ``EPS_END`` as the module-level ``steps_done`` counter grows;
    the counter is incremented on every call.

    Args:
        state: Row index into ``Q_table``.

    Returns:
        A uniformly sampled entry of ``actions`` when exploring, otherwise
        the index of the largest Q-value stored for ``state``.
    """
    global steps_done  # Q_table is only read, so it needs no global declaration

    draw = random.random()
    decay = math.exp(-1. * steps_done / EPS_DECAY)
    eps_threshold = EPS_END + (EPS_START - EPS_END) * decay
    steps_done += 1

    # Exploit unless the draw falls below the current exploration threshold
    if draw >= eps_threshold:
        return np.argmax(Q_table[state])
    return random.choice(actions)

def select_greedy_action(state):
    """Return the index of the highest-valued action stored in ``Q_table`` for ``state``.

    Ties resolve to the lowest action index, as with ``np.argmax``.
    """
    action_values = Q_table[state]
    return np.argmax(action_values)

In [5]:
# Extract current policy
def extract_policy(num_states, episode_i, terminal_state=15):
    """Print the greedy policy implied by the current Q-table as a grid.

    Each state's entry is the action returned by ``select_greedy_action``;
    the terminal state is shown with the marker value 10 instead of an action.

    Args:
        num_states: Number of states to evaluate; must equal
            ``grid_size * grid_size`` so the result can be reshaped.
        episode_i: Episode index, used only in the printed heading.
        terminal_state: Index of the terminal state. Defaults to 15, the
            value that used to be hard-coded (last cell of the 4x4 grid);
            pass a different index when the grid size changes.

    Returns:
        None. The policy grid is printed to stdout.
    """
    policy = np.zeros(num_states)
    for state in range(num_states):
        if state == terminal_state:
            policy[state] = 10  # marker value, not an action index
            continue
        policy[state] = select_greedy_action(state)

    policy_grid = policy.reshape((grid_size, grid_size))
    # NOTE(review): this unconditionally overwrites the displayed action of
    # cell (0, 2) with action 1, regardless of the Q-table. Kept to preserve
    # the existing output — confirm whether it is intentional.
    policy_grid[0, 2] = 1
    # Print title of the plot
    print(f"Episode {episode_i}'s policy")
    print(policy_grid)

In [6]:
num_episodes = 30

# Tabular Q-learning: each transition overwrites Q(s, a) with its one-step target.
for i_episode in range(num_episodes):
    # Start a new episode and read the initial state
    state, info = env.reset()
    for t in count():
        action = select_action(state)
        # .item() turns the NumPy integer into a plain Python int for the environment
        observation, reward, terminated, truncated, _ = env.step(action.item())
        done = terminated or truncated

        # A terminated episode has no successor state to bootstrap from
        next_state = None if terminated else observation

        # Target is the reward alone at termination, otherwise the reward plus
        # the discounted best Q-value of the successor state.
        Q_table[state][action] = (
            reward
            if next_state is None
            else reward + GAMMA * np.max(Q_table[next_state])
        )

        # Advance to the successor state
        state = next_state

        # Debug aid: uncomment to dump the full Q-table after every step
        # print(np.round(Q_table.T.reshape((grid_size, grid_size, 4)),2))
        # print('-------------------------')

        if done:
            break

    # After every episode, show the best Q-value per state laid out as the grid
    best_per_state = np.max(Q_table, axis=1)
    print(np.round(best_per_state.reshape((grid_size, grid_size)), 2))
    print('-------------------------')

# Print the greedy policy once, after the last episode
extract_policy(num_states, i_episode)
print('=========================')

print('Complete')

[[-2. -1.  0.  0.]
 [-1.  0.  0.  0.]
 [-1. -1.  0.  0.]
 [-1.  0.  0.  0.]]
-------------------------
[[-2. -2. -1. -2.]
 [-2. -2. -2. -1.]
 [-2. -1. -1.  0.]
 [-1. -1. -1.  0.]]
-------------------------
[[-2. -2. -2. -2.]
 [-3. -2. -2. -1.]
 [-2. -1. -1. -1.]
 [-1. -1. -1.  0.]]
-------------------------
[[-3. -3. -3. -3.]
 [-3. -2. -2. -2.]
 [-2. -1. -1. -1.]
 [-1. -1. -1.  0.]]
-------------------------
[[-3. -3. -3. -3.]
 [-3. -2. -2. -2.]
 [-2. -1. -2. -1.]
 [-1. -1. -1.  0.]]
-------------------------
[[-4. -3. -3. -3.]
 [-3. -2. -2. -2.]
 [-2. -2. -2. -1.]
 [-2. -2. -1.  0.]]
-------------------------
[[-4. -3. -3. -3.]
 [-3. -2. -2. -2.]
 [-3. -2. -2. -1.]
 [-2. -2. -1.  0.]]
-------------------------
[[-4. -3. -3. -3.]
 [-3. -2. -2. -2.]
 [-3. -3. -2. -1.]
 [-3. -2. -1.  0.]]
-------------------------
[[-4. -3. -3. -3.]
 [-4. -3. -3. -2.]
 [-3. -3. -2. -1.]
 [-3. -2. -1.  0.]]
-------------------------
[[-4. -3. -3. -3.]
 [-4. -3. -3. -2.]
 [-4. -3. -2. -1.]
 [-3. -2. -1.  0

In [7]:
# Greedy state values: V(s) = Q(s, a*) with a* the greedy action for s.
values = np.zeros(num_states)
for state in range(num_states):
    if state != 15:  # state 15 is the terminal state and keeps value 0
        suggested_action = select_greedy_action(state)
        values[state] = Q_table[state][suggested_action]
        # Alternative summary: values[state] = np.mean(Q_table[state])

# Lay the per-state values out as the grid and round for display
value_grid = np.round(values.reshape((grid_size, grid_size)), 2)
print("\nState values:")
print(value_grid)


State values:
[[-6. -5. -4. -3.]
 [-5. -4. -3. -2.]
 [-4. -3. -2. -1.]
 [-3. -2. -1.  0.]]
