# Deep Q-learning using backprop on small grdiworld

In [1]:
import os
os.chdir("..")
from src.gym_kalman.env_Gridworld import GridworldEnv
from pytagi.nn import Linear, OutputUpdater, ReLU, Sequential, EvenExp

In [2]:
# Initialize value function
import numpy as np

# initialize the environment
grid_size = 4
env = GridworldEnv(grid_size=grid_size, reward_std=0.2)
num_states = env.observation_space.n
actions = np.arange(env.action_space.n)

In [3]:
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

# if GPU is to be used
device = torch.device(
    "cuda" if torch.cuda.is_available() else
    # "mps" if torch.backends.mps.is_available() else
    "cpu"
)

In [4]:
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

In [5]:
class TAGI_Net():
    def __init__(self, n_observations, n_actions):
        super(TAGI_Net, self).__init__()
        self.net = Sequential(
                    Linear(n_observations, 128),
                    ReLU(),
                    Linear(128, 128),
                    ReLU(),
                    Linear(128, n_actions * 2),
                    EvenExp()
                    )
        self.n_actions = n_actions
        self.n_observations = n_observations
    def forward(self, mu_x, var_x):
        return self.net.forward(mu_x, var_x)

In [6]:
# BATCH_SIZE is the number of transitions sampled from the replay buffer
# GAMMA is the discount factor as mentioned in the previous section
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# TAU is the update rate of the target network
# LR is the learning rate of the ``AdamW`` optimizer
BATCH_SIZE = 128
GAMMA = 1
EPS_START = 0.9
EPS_END = 0.0001
EPS_DECAY = 1000

TAU = 0.005
LR = 1e-2

# Get number of actions from gym action space
n_actions = env.action_space.n
# Get the number of state observations
state, info = env.reset()
n_observations = 1

policy_net = TAGI_Net(n_observations, n_actions)
target_net = TAGI_Net(n_observations, n_actions)
target_net.net.load_state_dict(policy_net.net.get_state_dict())

memory = ReplayMemory(10000)


steps_done = 0


def select_action(state):
    global steps_done
    steps_done += 1
    policy_net.net.eval()
    state_np = state.numpy()

    state_temp = np.array(state_np)
    state_np = np.repeat(state_temp, BATCH_SIZE, axis=0)

    ma, Sa = policy_net.net(state_np)
    print('?',type(state_np))
    ma = ma.reshape(BATCH_SIZE, policy_net.n_actions*2)
    print('?',ma)
    ma = ma[0]
    action_mean = ma[::2]
    Sa = Sa.reshape(BATCH_SIZE, policy_net.n_actions*2)[0]
    action_var = Sa[::2] + ma[1::2]

    a_sample = np.zeros_like(action_mean)
    # print(state)
    for i in range(len(action_mean)):
        # print(action_mean[i])
        # print(np.sqrt(action_var[i]))
        a_sample[i] = np.random.normal(action_mean[i], np.sqrt(action_var[i]))
    # print('======================')

    action = np.argmax(a_sample, axis=0)

    steps_done += 1
    return torch.tensor([[action]],device=device)

def select_greedy_action(state):
    global steps_done
    steps_done += 1
    policy_net.net.eval()
    state_np = state.numpy()

    state_temp = np.array(state_np)
    state_np = np.repeat(state_temp, BATCH_SIZE, axis=0)

    ma, _ = policy_net.net(state_np)
    ma = ma.reshape(BATCH_SIZE, policy_net.n_actions*2)[0]
    action_mean = ma[::2]

    action = np.argmax(action_mean, axis=0)

    return torch.tensor([[action]],device=device)

In [7]:
# def infer_model():
#     if len(memory) < BATCH_SIZE:
#         return
#     transitions = memory.sample(BATCH_SIZE)
#     # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
#     # detailed explanation). This converts batch-array of Transitions
#     # to Transition of batch-arrays.
#     batch = Transition(*zip(*transitions))

#     policy_net.net.train()
#     target_net.net.eval()

#     final_mask = torch.tensor(tuple(map(lambda s: s is None,
#                                 batch.next_state)), device=device, dtype=torch.bool)

#     state_mean_batch = torch.cat(batch.state).numpy()
#     action_batch = torch.cat(batch.action).numpy()
#     reward_batch = torch.cat(batch.reward).numpy()
#     next_state_mean_batch = torch.cat([s if s is not None
#                                         else torch.tensor([100.0])
#                                         for s in batch.next_state]).numpy()

#     # Add uncertainty to the state
#     state_batch = {'mu': state_mean_batch, 'var': np.zeros_like(state_mean_batch)}
#     next_state_batch = {'mu': next_state_mean_batch, 'var': np.zeros_like(next_state_mean_batch)}

#     # Get the next state values from target net
    # next_state_values_mu_f, next_state_values_var_f = target_net.net(next_state_batch['mu'])

    # # Reshape to 2D
    # next_state_values_mu_f = next_state_values_mu_f.reshape(BATCH_SIZE, target_net.n_actions*2)
    # next_state_values_var_f = next_state_values_var_f.reshape(BATCH_SIZE, target_net.n_actions*2)

    # # Along the first axis, select the first and the third columns of the 2D array next_state_values_mu
    # next_state_values_mu = next_state_values_mu_f[:, [0, 2, 4, 6]]
    # next_state_values_var = next_state_values_var_f[:, [0, 2, 4, 6]] + next_state_values_mu_f[:, [1, 3, 5, 7]]

    # next_state_values_samples = np.zeros((BATCH_SIZE, target_net.n_actions))
    # for i in range(BATCH_SIZE):
    #     for j in range(target_net.n_actions):
    #         next_state_values_samples[i, j] = np.random.normal(next_state_values_mu[i, j], np.sqrt(next_state_values_var[i, j]))

    # # Keep the maximum next state value according to the samples
    # max_indices = np.argmax(next_state_values_samples, axis=1)
    # next_state_values_mu = next_state_values_mu[np.arange(BATCH_SIZE), max_indices]
    # next_state_values_var = next_state_values_var[np.arange(BATCH_SIZE), max_indices]

    # # Set the next state values of final states to 0 if next state is final
    # next_state_values_mu_tensor = torch.tensor(next_state_values_mu, device=device)
    # next_state_values_var_tensor = torch.tensor(next_state_values_var, device=device)
    # next_state_values_mu_tensor[final_mask] = 0.0
    # next_state_values_var_tensor[final_mask] = 1e-4
    # next_state_values_mu = next_state_values_mu_tensor.numpy()
    # next_state_values_var = next_state_values_var_tensor.numpy()

    # # Compute the expected Q values
    # # Scale the reward so that the Q value is bounded between 0 and 1
    # expected_state_values_mu = np.array((next_state_values_mu * GAMMA) + reward_batch)
    # expected_state_values_var = np.array((next_state_values_var * GAMMA**2))

    # # Infer the policy network using the expected Q values
    # expected_state_action_values_mu_f, expected_state_action_values_var_f = policy_net.net(state_batch['mu'])

    # # # Only change the expected Q values where actions are taken
    # expected_state_action_values_mu_f = expected_state_action_values_mu_f.reshape(BATCH_SIZE, policy_net.n_actions*2)
    # expected_state_action_values_var_f = expected_state_action_values_var_f.reshape(BATCH_SIZE, policy_net.n_actions*2)
    # expected_state_action_values_mu = expected_state_action_values_mu_f[:, [0,2,4,6]]
    # expected_state_action_values_var = expected_state_action_values_var_f[:, [0,2,4,6]] + expected_state_action_values_mu_f[:, [1,3,5,7]]
    # expected_state_action_values_mu[np.arange(BATCH_SIZE), action_batch.flatten()] = expected_state_values_mu
    # expected_state_action_values_var[np.arange(BATCH_SIZE), action_batch.flatten()] = expected_state_values_var
    # # expected_state_action_values_mu[np.arange(self.batchsize), 1-action_batch.flatten()] = np.nan
    # expected_state_action_values_var[np.arange(BATCH_SIZE), 1-action_batch.flatten()] = 1e8

    # print(expected_state_action_values_mu_f)
    # print(expected_state_action_values_var_f)
    # print(expected_state_action_values_mu)
    # print(expected_state_action_values_var)

    # expected_state_action_values_mu = expected_state_action_values_mu.flatten()
    # expected_state_action_values_var = expected_state_action_values_var.flatten()

    # # Update output layer
    # out_updater = OutputUpdater(policy_net.net.device)
    # out_updater.update_heteros(
    #     output_states = policy_net.net.output_z_buffer,
    #     mu_obs = expected_state_action_values_mu,
    #     var_obs = expected_state_action_values_var,
    #     delta_states = policy_net.net.input_delta_z_buffer,
    # )

    # # Feed backward
    # policy_net.net.backward()
    # policy_net.net.step()

    # policy_net.net.eval()
    # dummy_mean, dummy_var = policy_net.net(state_batch['mu'])
    # print(dummy_mean)
    # print(dummy_var)
    # print('======================')

    # # For numerical stability: clip the variance of the parameters to 1e-8
    # policy_net_param_temp = policy_net.net.get_state_dict()
    # for key in policy_net_param_temp:
    #     policy_net_param_temp[key]['var_w']=np.clip(policy_net_param_temp[key]['var_w'], 1e-8, None).tolist()
    #     policy_net_param_temp[key]['var_b']=np.clip(policy_net_param_temp[key]['var_b'], 1e-8, None).tolist()
    #     # Clip the policy_net_param_temp[key]['mu_w'] at 1e8 if it is possitive and -1e8 if it is negative
    #     # policy_net_param_temp[key]['mu_w']=np.sign(policy_net_param_temp[key]['mu_w'])*np.clip(np.abs(policy_net_param_temp[key]['mu_w']), 1e-8, None).tolist()
    #     # policy_net_param_temp[key]['mu_b']=np.sign(policy_net_param_temp[key]['mu_b'])*np.clip(np.abs(policy_net_param_temp[key]['mu_b']), 1e-8, None).tolist()
    # policy_net.net.load_state_dict(policy_net_param_temp)

In [8]:
# Extract current policy
def extract_policy(num_states, episode_i):
    policy = np.zeros(num_states)
    for state in range(num_states):
        if state == 15:  # Terminal state
            policy[state] = 10
            continue
        state_tensor = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
        suggested_action = select_greedy_action(state_tensor)
        policy[state] = suggested_action

    policy_grid = np.array(policy).reshape((grid_size, grid_size))
    # Print title of the plot
    print(f"Episode {episode_i}'s policy")
    print(policy_grid)
    return

In [9]:
if torch.cuda.is_available() or torch.backends.mps.is_available():
    num_episodes = 500
else:
    num_episodes = 500

for i_episode in range(num_episodes):
    # Initialize the environment and get its state
    state, info = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    for t in count():
        # NORMALIZED YOUR STATE!!!
        # NORMALIZED YOUR Q!!!!!!!
        action = select_action(state)
        observation, reward, terminated, truncated, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        done = terminated or truncated

        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

        # Store the transition in memory
        memory.push(state, action, next_state, reward)

        # Move to the next state
        state = next_state

        # Perform one step of the optimization (on the policy network)
        # infer_model()

        # Soft update of the target network's weights
        # θ′ ← τ θ + (1 −τ )θ′
        target_net_state_dict = target_net.net.get_state_dict()
        policy_net_state_dict = policy_net.net.get_state_dict()
        for key in policy_net_state_dict:
            for key2 in policy_net_state_dict[key]:
                target_net_state_dict[key][key2] = (np.asarray(policy_net_state_dict[key][key2])*TAU +
                                                    np.asarray(target_net_state_dict[key][key2])*(1-TAU)).tolist()
        target_net.net.load_state_dict(target_net_state_dict)

        if done:
            break

    # Extract policy
    extract_policy(num_states, i_episode)

print('Complete')

? <class 'numpy.ndarray'>
? [[ 0.12924881  2.8269396   0.01303672 ...  0.6467122  -0.6074352
   1.2426939 ]
 [ 0.12924881  2.8269396   0.01303672 ...  0.6467122  -0.6074352
   1.2426939 ]
 [ 0.12924881  2.8269396   0.01303672 ...  0.6467122  -0.6074352
   1.2426939 ]
 ...
 [ 0.12924881  2.8269396   0.01303672 ...  0.6467122  -0.6074352
   1.2426939 ]
 [ 0.12924881  2.8269396   0.01303672 ...  0.6467122  -0.6074352
   1.2426939 ]
 [ 0.12924881  2.8269396   0.01303672 ...  0.6467122  -0.6074352
   1.2426939 ]]
? <class 'numpy.ndarray'>
? [[ 0.12924881  2.8269396   0.01303672 ...  0.6467122  -0.6074352
   1.2426939 ]
 [ 0.12924881  2.8269396   0.01303672 ...  0.6467122  -0.6074352
   1.2426939 ]
 [ 0.12924881  2.8269396   0.01303672 ...  0.6467122  -0.6074352
   1.2426939 ]
 ...
 [ 0.12924881  2.8269396   0.01303672 ...  0.6467122  -0.6074352
   1.2426939 ]
 [ 0.12924881  2.8269396   0.01303672 ...  0.6467122  -0.6074352
   1.2426939 ]
 [ 0.12924881  2.8269396   0.01303672 ...  0.6467122 

KeyboardInterrupt: 

In [10]:
values = np.zeros(num_states)
for state in range(num_states):
    if state == 15:  # Terminal state
        continue
    state_tensor = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    suggested_action = select_greedy_action(state_tensor)
    with torch.no_grad():
        values[state] = policy_net(state_tensor.view(-1, 1))[0].tolist()[suggested_action]

value_grid = np.array(values).reshape((grid_size, grid_size))
value_grid = np.round(value_grid, 2)
print("\nState values:")
print(value_grid)


State values:
[[-5.82 -5.57 -5.34 -5.12]
 [-4.92 -4.72 -4.52 -4.32]
 [-3.94 -3.26 -2.64 -3.03]
 [-3.01 -1.99 -1.    0.  ]]


In [11]:
for state in range(num_states):
    if state == 15:  # Terminal state
        continue
    state_tensor = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    suggested_action = select_greedy_action(state_tensor)
    with torch.no_grad():
        print(state)
        print(policy_net(state_tensor.view(-1, 1))[0].tolist())

0
[-6.869808197021484, -5.8179931640625, -6.8171772956848145, -6.659472942352295]
1
[-6.687108993530273, -5.568826675415039, -6.579085350036621, -6.490907669067383]
2
[-6.621176719665527, -5.342931270599365, -6.384941101074219, -6.26676607131958]
3
[-6.552700042724609, -5.1189680099487305, -6.18874979019165, -6.032923221588135]
4
[-6.453420639038086, -4.918376445770264, -5.967729568481445, -5.681536674499512]
5
[-6.3541412353515625, -4.717785358428955, -5.746709823608398, -5.330151081085205]
6
[-6.254861831665039, -4.5171942710876465, -5.525690078735352, -4.978765487670898]
7
[-6.155582427978516, -4.316601753234863, -5.304670333862305, -4.627379417419434]
8
[-6.016765594482422, -3.942415237426758, -5.034659385681152, -4.341019153594971]
9
[-5.807121276855469, -3.2572550773620605, -4.676888465881348, -4.171145915985107]
10
[-5.583076477050781, -2.63578200340271, -4.333233833312988, -3.9852073192596436]
11
[-5.129392623901367, -3.0299558639526367, -4.2146897315979, -3.5430755615234375]
1