In [55]:

%load_ext autoreload 
# %aimport rl_envs.grid_world_env

%autoreload 2
import torch
import math
from torch.utils.tensorboard import SummaryWriter # type: ignore

from rl_envs.gym_grid_world_env import GridWorldEnv
from agents.policy_gradient import PGAgent
from tools.helper import *
import  gymnasium  as gym


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [56]:
LEARN_RATE = 1e-2
DISCOUNTED_FACTOR = 0.99

FORBIDDEN_REWARD = -10
HITWALL_REWARD = -10
TARGET_REWARD = 1

In [57]:
# env = GridWorldEnv(size=3,fixed_map = True, forbidden_grids=[(1,1)], target_grids=[(2,2)], forbidden_reward=FORBIDDEN_REWARD, hit_wall_reward=HITWALL_REWARD, target_reward=TARGET_REWARD)
env = GridWorldEnv(fixed_map = True, forbidden_grids=[(1,1),(1,2), (2,2),(3,1),(3,3),(4,1)], target_grids=[(3,2)], forbidden_reward=FORBIDDEN_REWARD, hit_wall_reward=HITWALL_REWARD, target_reward=TARGET_REWARD)

# env = gym.make("CartPole-v1")


In [58]:
episode_rewards = []
episode_lengths = []


In [60]:
# agent = PGAgent(4, 2, lr = LEARN_RATE, discounted_factor=DISCOUNTED_FACTOR)
agent = PGAgent(2, env.action_n, lr = LEARN_RATE, discounted_factor=DISCOUNTED_FACTOR)
writer = SummaryWriter()
num_episodes = 20000
episode_len = 10000
eps = np.finfo(np.float32).eps.item()
# 第一次收集改为随机收集
trajectory = []
obs, _ = env.reset()
# for _ in range(1000):
#     state = tuple(obs['agent'])
#     action = agent.get_behavior_action(state)
#     obs, reward, terminated, truncated, info = env.step(action)
#     trajectory.append((state, action, reward+10))
running_reward = -10
for episode in range(num_episodes):
    # 首先, 根据 policy 生成 episode
    obs, _ = env.reset()
    ep_reward = 0.
    real_episode_len = 0
    rewards = []
    # 初始策略是不是有比较大的影响? 
    for real_episode_len in range(episode_len):
        # state = obs
        state = obs['agent']
        action = agent.get_action(state) # action 这里也有随机性
        obs, reward, terminated, truncated, info = env.step(action)
        rewards.append(reward)
        # trajectory.append((state, action, reward))
        ep_reward += float(reward)
        if terminated or truncated:
            break
    running_reward = 0.05 * ep_reward + (1 - 0.05) * running_reward
    policy_loss = []
    discounted_reward = 0
    returns = []
    for t, reward in enumerate(rewards[::-1]):
        discounted_reward = discounted_reward * agent.discounted_factor + reward
        returns.insert(0, discounted_reward)
        # discounted_reward = sum(DISCOUNTED_FACTOR**i * t[2] for i, t in enumerate(trajectory[t:]))
    toh_returns = torch.tensor(returns)
    toh_returns = (toh_returns- toh_returns.mean()) / (toh_returns.std() + eps)
    for log_prob, R in zip(agent.saved_log_probs, returns):
        policy_loss.append(-log_prob * R)
    # policy update
    """
    特别注意: 这里 log π 中的 π(a|s) 是选择 a 的概率, policy network 得输出一个概率, 而不是什么 a 的值
    当然我们可以用输出的值, 归一化一下作为 action 的概率
    """
    # action_probs = agent.policy_net(torch.tensor(state, dtype=torch.float))
    # agent.q[state][action] = discounted_reward 
    # agent.v[state] = sum([agent.q[state][a] * action_probs[a] for a in agent.q[state].keys()])
    # agent.v[state] = np.mean(agent.q[state])

    # action_probs = actions_val/actions_val.sum()
    agent.optimizer.zero_grad()
    """
    当 discounted reward < 0 时, loss < 0. 若是 action 选择错误, 则 discounted_reward 小, 使得 loss 小 (或者说负地厉害) 
    梯度下降会将 loss 减地更小, 也就使得对应错误 action 的 action_probs[action] 减小

    相反, 当选择正确 action 时, discounted_reward 理想下应该更大, 则 loss 也更大, 梯度下降同样降低 loss,
    使得对应正确 action 的 action_probs[action] 减小. 
    
    关键就在于, 要使得 loss 小的时候梯度下降地比 loss 大的时候要更快.

    (若是训练地成功 下一轮时 discounted_reward 就会变大, 那么 loss 也就是越来越大, 自然就是向上走,
    至于为什么 loss 会趋近于 0, 我猜测是因为 discounted_reward 有一个由负变正的过程, 而在其中当 loss 变为 0 时
    ) 
    """
    loss = torch.cat(policy_loss).sum()
    # loss = -torch.log(action_probs[action]) * (discounted_reward)
    # loss = -torch.log(action_probs[action]) * (discounted_reward - agent.v[state]) # add baselline advantage
    # [parms.grad for name, parms in agent.policy_net.named_parameters()]
    # loss = abs(loss)
    loss.backward()
    # torch.nn.utils.clip_grad.clip_grad_norm_(agent.policy_net.parameters(), 100)
    agent.optimizer.step()

    del agent.saved_log_probs[:]

    writer.add_scalar('Loss', loss, episode)
    writer.add_scalar('episodeReward', discounted_reward, episode)
    writer.add_scalar('ep_reward', ep_reward, episode)
    writer.add_scalar('running_reward', running_reward, episode)


    if episode % 10 == 0:
        print('Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f}\tloss: {:.2f}'.format(
                episode, ep_reward, running_reward, loss))
    if running_reward > 0.9:
        print("Solved! Running reward is now {} and "
                "the last episode runs to {} time steps!".format(running_reward, real_episode_len))
        break


writer.flush()
writer.close()

Episode 0	Last reward: -370.00	Average reward: -28.00	loss: -18408.95
Episode 10	Last reward: -39.00	Average reward: -75.29	loss: -272.46
Episode 20	Last reward: -339.00	Average reward: -116.05	loss: -6955.47
Episode 30	Last reward: -119.00	Average reward: -146.61	loss: -834.45
Episode 40	Last reward: -9.00	Average reward: -161.51	loss: -1.10
Episode 50	Last reward: -399.00	Average reward: -182.39	loss: -11357.44
Episode 60	Last reward: -9.00	Average reward: -143.13	loss: -32.51
Episode 70	Last reward: -9.00	Average reward: -116.92	loss: -32.04
Episode 80	Last reward: -159.00	Average reward: -103.92	loss: -1218.61
Episode 90	Last reward: -99.00	Average reward: -87.43	loss: -297.33
Episode 100	Last reward: -19.00	Average reward: -72.44	loss: -16.63
Episode 110	Last reward: -890.00	Average reward: -175.66	loss: -6340.11
Episode 120	Last reward: -610.00	Average reward: -244.46	loss: -10037.96
Episode 130	Last reward: -280.00	Average reward: -217.66	loss: -10515.49
Episode 140	Last reward:

KeyboardInterrupt: 

In [61]:
# visualize_in_gym(agent, "CartPole-v1")
policy = agent.generate_policy_table(env.height, env.width)

print_by_dict(env, policy)

for i in range(env.height):
    print("[", end=" ")
    for j in range(env.width):
        state = (i,j)
        action = np.argmax(policy[state])
        print(env.action_mappings[action], end=" ")
    print("]")

[ [[1.2380129e-32 1.2891946e-43 1.0000000e+00 1.4328711e-30 6.8152844e-35]] [[0. 0. 1. 0. 0.]] [[0. 0. 1. 0. 0.]] [[0. 0. 1. 0. 0.]] [[0. 0. 1. 0. 0.]] ]
[ [[2.31e-43 0.00e+00 1.00e+00 8.41e-45 1.49e-43]] [[0. 0. 1. 0. 0.]] [[0. 0. 1. 0. 0.]] [[0. 0. 1. 0. 0.]] [[0. 0. 1. 0. 0.]] ]
[ [[0. 0. 1. 0. 0.]] [[0. 0. 1. 0. 0.]] [[0. 0. 1. 0. 0.]] [[0. 0. 1. 0. 0.]] [[0. 0. 1. 0. 0.]] ]
[ [[0. 0. 1. 0. 0.]] [[1.1259211e-25 6.2022967e-32 1.0000000e+00 7.4105486e-27 3.6603810e-19]] [[0. 0. 1. 0. 0.]] [[0. 0. 1. 0. 0.]] [[0. 0. 1. 0. 0.]] ]
[ [[0. 0. 1. 0. 0.]] [[0. 0. 1. 0. 0.]] [[0. 0. 1. 0. 0.]] [[0. 0. 1. 0. 0.]] [[0. 0. 1. 0. 0.]] ]
[  ↑   ↑   ↑   ↑   ↑  ]
[  ↑   ↑   ↑   ↑   ↑  ]
[  ↑   ↑   ↑   ↑   ↑  ]
[  ↑   ↑   ↑   ↑   ↑  ]
[  ↑   ↑   ↑   ↑   ↑  ]


In [None]:
# env.max_steps = 10
gridworld_demo(agent, env, repeat_times=500)
# gridworld_demo(agent, forbidden_reward=FORBIDDEN_REWARD, hit_wall_reward=HITWALL_REWARD, target_reward=TARGET_REWARD)
# visualize_in_gym(agent, "CartPole-v1")


reward: -20, distance: [array([3, 2]), array([2, 2]), array([1, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), a