In [3]:
from pathlib import Path
import sys
import numpy as np
from collections import defaultdict
import torch
from torch.utils.tensorboard import SummaryWriter
from rl_envs.gym_grid_world_env import GridWorldEnv
import numpy as np
from tools.helper import *
# from rl_envs.episodic_grid_world_env import EpisodicGridWorldEnv
# from rl_envs.grid_world_env import GridWorldEnv
from ReplayMemory import *
from agents.DQN import DeepQLearningAgent
%load_ext autoreload 
# %aimport rl_envs.grid_world_env

%autoreload 2

In [4]:

def print_actions(agent, env, get_optimal = False):
    with torch.no_grad():
        action_mapping = [" ↓ "," ↑ "," → "," ← "," ↺ "]
        for i in range(env.height):
            print("[", end=" ")
            for j in range(env.width):
                state = torch.tensor((i,j), dtype=torch.float).unsqueeze(0)
                action = agent.get_action(state)
                print(action_mapping[action.item()], end=" ")
            print("]")

def state_normalize(env, state):
    return ((state[0] - (env.height-1)/2.0)/env.height,(state[1] - (env.width-1)/2.0)/env.width)




BATCHSIZE = 100
LEARN_RATE = 1e-4 # change to 1e-4
TRUE_RANDOM_STATE_VALUE = [
    [-3.8, -3.8, -3.6, -3.1, -3.2],
    [-3.8, -3.8, -3.8, -3.1, -2.9],
    [-3.6, -3.9, -3.4, -3.2, -2.9],
    [-3.9, -3.6, -3.4, -2.9, -3.2],
    [-4.5, -4.2, -3.4, -3.4, -3.5],         
]

def calculate_state_value_error(env:GridWorldEnv,agent):
    # offline policy have 2 policies, I am using the behavior(random) policy for calculating
    with torch.no_grad():
        state_value_error = 0
        for i in range(env.height):
            for j in range(env.width):
                state = torch.tensor((i,j), dtype=torch.float).unsqueeze(0)
                output = agent.policy_net(state)
                state_value = output.sum()/env.action_n
                state_value_error += (state_value - TRUE_RANDOM_STATE_VALUE[i][j])
    return state_value_error




In [5]:

env = GridWorldEnv(fixed_map = True, forbidden_grids=[(1,1),(1,2), (2,2),(3,1),(3,3),(4,1)], target_grids=[(3,2)], forbidden_reward=-1, hit_wall_reward=-1, target_reward=10)
agent = DeepQLearningAgent(state_space_n= 2, action_space_n=env.action_n, lr = LEARN_RATE)
writer = SummaryWriter()


In [6]:
"""
generate samples to replay buffer
"""
replay_buffer = ReplayMemory(10000000)

episode_num = 1000000
for _ in range(episode_num):
    state, info = env.reset()
    for _ in range(100):
        action = random.randint(0,int(env.action_n)-1)
        # action = agent.get_behavior_acion(state)
        next_state, reward, terminated , truncated, info = env.step(action)
        # replay_buffer.push(torch.tensor(state_normalize(env,state), dtype=torch.float), torch.tensor(action, dtype=torch.int64).unsqueeze(0), torch.tensor(reward, dtype=torch.float).unsqueeze(0), torch.tensor(state_normalize(env,next_state), dtype=torch.float))
        replay_buffer.push(torch.tensor(state['agent'], dtype=torch.float), torch.tensor(action, dtype=torch.int64).unsqueeze(0), torch.tensor(reward, dtype=torch.float).unsqueeze(0), torch.tensor(next_state['agent'], dtype=torch.float))
        state = next_state
        if terminated or truncated:
            break


In [7]:



"""
perform DQN
"""
# iter_counter = 0
# q_value = target_value = loss = []
for iter_counter in range(1000000):
    transitions  = replay_buffer.sample(BATCHSIZE)
    batch = Transition(*zip(*transitions))
    state = torch.stack(batch.state)
    next_state = torch.stack(batch.next_state)
    reward = torch.cat(batch.reward)
    action_indices = torch.cat(batch.action)
    
    loss, q_value, target_value = agent.update_Q_network(state, action_indices, reward, next_state)

    if iter_counter % 50 == 0:
        # copy target network every C=50 iteration
        # state_value_estimated = output.sum(dim=1) / env.possible_actions 
        writer.add_scalar('TD error', (q_value - target_value).sum(), iter_counter)         
        writer.add_scalar('Loss', loss.sum(), iter_counter)
        writer.add_scalar('State value error', calculate_state_value_error(env,agent), iter_counter)


        # iter_counter+=1
        # agent.target_net.load_state_dict(agent.policy_net.state_dict())
        agent.sync_target_network()
    # print(loss)


In [None]:

writer.flush()
writer.close()
print(env)

print_actions(agent, env, True)

print()

for i in range(env.height):
    print("[", end=" ")
    for j in range(env.width):
        state = torch.tensor((i,j), dtype=torch.float).unsqueeze(0)
        output = agent.policy_net(state)
        state_value = output.sum()/env.action_n
        state_value_error = (state_value - TRUE_RANDOM_STATE_VALUE[i][j])
        print(state_value_error, end=" ")
    print("]")

# print()


<GridWorldEnv instance>
[  ←   ←   ←   ↓   ↓  ]
[  ←   ←   ←   ↓   ↓  ]
[  ←   ←   ↑   ↑   ↓  ]
[  ←   ←   ↑   ↑   ↑  ]
[  ←   ←   ↑   ↑   ↑  ]

[ tensor(3.7849, grad_fn=<SubBackward0>) tensor(3.8066, grad_fn=<SubBackward0>) tensor(3.6185, grad_fn=<SubBackward0>) tensor(3.1156, grad_fn=<SubBackward0>) tensor(3.2161, grad_fn=<SubBackward0>) ]
[ tensor(3.7790, grad_fn=<SubBackward0>) tensor(3.7925, grad_fn=<SubBackward0>) tensor(3.7787, grad_fn=<SubBackward0>) tensor(3.0724, grad_fn=<SubBackward0>) tensor(2.8702, grad_fn=<SubBackward0>) ]
[ tensor(3.5921, grad_fn=<SubBackward0>) tensor(3.8787, grad_fn=<SubBackward0>) tensor(3.3621, grad_fn=<SubBackward0>) tensor(3.1422, grad_fn=<SubBackward0>) tensor(2.8398, grad_fn=<SubBackward0>) ]
[ tensor(3.9038, grad_fn=<SubBackward0>) tensor(3.5691, grad_fn=<SubBackward0>) tensor(3.3587, grad_fn=<SubBackward0>) tensor(2.8253, grad_fn=<SubBackward0>) tensor(3.1060, grad_fn=<SubBackward0>) ]
[ tensor(4.5127, grad_fn=<SubBackward0>) tensor(4.1706, gra

In [None]:
Q = {}
for y in range(env.size):
    for x in range(env.size):
       state = (y,x)
       q_values = agent.policy_net(state)
       Q[state] = q_values
print_by_dict(env,Q)
                                                                                                                                                                                                                     
                                 

TypeError: linear(): argument 'input' (position 1) must be Tensor, not tuple

In [None]:
V = {}
for state in Q.keys():
    V[state] = torch.max(Q[state]).item()
print_by_dict(env, V)


In [None]:
plot_value_function(V)
