In [0]:
%load_ext autoreload
%autoreload 2

import matplotlib
from matplotlib import pyplot
# Switch pyplot to interactive mode and use the dark theme for every figure
# drawn while this notebook runs.
pyplot.ion()
pyplot.style.use('dark_background')

import torch as th
# NOTE(review): later cells use `environment`, `util`, `Dqn`, `ReplayBuffer`,
# `sars_filter` and `np` without importing them explicitly, so they presumably
# enter the namespace through this star import — confirm, and prefer explicit
# imports so a fresh-kernel run does not depend on dqn's internals.
from rl.core.algs.dqn import *

In [2]:
from rl.core.algs import experiment
from rl.core.algs import model
from rl.core.envs import grid_world

# Hyper-parameters for the DQN run in this cell.
BATCH_SIZE = 32                    # transitions drawn per rb.sample() call
TRAIN_START_STEP = BATCH_SIZE * 3  # env steps collected before the first update
NUM_TRAIN_STEPS = 10_000  # 17
REPLAY_SIZE = 100_000              # argument handed to ReplayBuffer(...)
TARGET_NETWORK_UPDATE_FREQ = 500   # forwarded as 'target_update_freq'


# Environment under test. grid_shape is presumably (rows, cols, channels) —
# TODO confirm against GridWorldWrapped.
env = environment.GridWorldWrapped(grid_shape=(4, 4, 1))

# Convolutional layers, listed in the order they should be stacked.
# This has to be an ordered sequence: a set literal ({...}) has no defined
# iteration order and silently collapses equal elements, so with more than one
# layer the stack could be reordered or lose a repeated spec.
# Per the printed model, ConvSpec(32, 4, 1) becomes
# Conv2d(1, 32, kernel_size=(4, 4), stride=(1, 1)).
conv_specs = (
  model.ConvSpec(32, 4, 1),
)
# Widths of the hidden fully connected layers (a single 64-unit layer here).
fc_specs = (64,)
# Keyword bundle handed to experiment.Experiment together with model.QNetwork.
model_params = {
  'conv_specs': conv_specs,
  'fc_specs': fc_specs,
}

# Exploration schedules. util.Schedule receives a tuple of step breakpoints and
# a tuple of matching values; presumably it interpolates between them — verify
# against util.Schedule. Note `NUM_TRAIN_STEPS / 1` is a float breakpoint.
EPS_DECAY = util.Schedule(
    (0, TRAIN_START_STEP, NUM_TRAIN_STEPS / 1, NUM_TRAIN_STEPS),
    (1, 1, 0.0, 0.0),
)
# Constant epsilon of 1.0 from step 0 onwards.
EPS_RANDOM = util.Schedule((0,), (1.0,))

# Optimizer descriptions: 'fn' is the optimizer class, 'kwargs' its keyword
# arguments. Only rmsprop_params is wired into policy_params below.
adam_params = dict(
    fn=th.optim.Adam,
    kwargs=dict(lr=0.005),
)
rmsprop_params = dict(
    fn=th.optim.RMSprop,
    kwargs=dict(lr=0.00025, momentum=0.95, alpha=0.95, eps=0.01),
)

# Settings forwarded to the Dqn policy by experiment.Experiment.
policy_params = dict(
    opt_params=rmsprop_params,
    eps_sched=EPS_DECAY,  # TODO: change
    target_update_freq=TARGET_NETWORK_UPDATE_FREQ,
)
# Assemble the experiment from the environment, the Q-network class with its
# parameters, and the Dqn policy class with its parameters.
exp = experiment.Experiment(
    env,
    model.QNetwork,
    model_params,
    Dqn,
    policy_params,
)

# Show the constructed network architecture.
print(exp.policy.model)

rb = ReplayBuffer(REPLAY_SIZE)

# Return accumulated over the episode currently in progress. It is tracked from
# the very first environment step (warm-up included) so that the episode which
# straddles the warm-up/training boundary is not logged with a truncated
# return.
eps_rew = 0

# Warm-up: fill the replay buffer with TRAIN_START_STEP transitions before any
# gradient update is made.
for step in range(TRAIN_START_STEP):
  sars = env.step(exp.policy.get_action)
  rb.add(sars)
  eps_rew += sars.r
  if sars.s1 is None:  # s1 is None marks the end of an episode
    eps_rew = 0

# Training: one environment step followed by one policy update per iteration.
for step in range(TRAIN_START_STEP, NUM_TRAIN_STEPS):
  sars = env.step(exp.policy.get_action)
  rb.add(sars)

  experience_batch = rb.sample(BATCH_SIZE)
  m = exp.policy.update(experience_batch)  # mapping of metric name -> value

  eps_rew += sars.r
  episode_done = sars.s1 is None
  if episode_done:
    # Attach the finished episode's return to this step's metrics.
    m['r_per_eps'] = eps_rew
    eps_rew = 0
  # Plot every 100th step, and always at episode boundaries so that no
  # 'r_per_eps' sample is dropped.
  if step % 100 == 0 or episode_done:
    for k, v in m.items():
      exp.plt.add_data(k, step, v)

# Expect about r_per_eps=-50 in pure random agent.

(1, 1, 32)
QNetwork(
  (convs): Sequential(
    (0): Conv2d(1, 32, kernel_size=(4, 4), stride=(1, 1))
    (1): ReLU()
  )
  (fc): Sequential(
    (0): Linear(in_features=32, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=4, bias=True)
  )
)


In [3]:
# Reference value for a 4x4 grid world, to compare against the r_per_eps curve.
reference_world = grid_world.GridWorld((4, 4))
opt = reference_world.optimal_expected_reward()
print(f'optimal expected r: {opt}')

optimal expected r: 1.6666666666666665


In [4]:
# Dump what the online network has learned: the output below alternates a
# value grid with a grid of action arrows.
q_network = exp.policy.model
env.show_model_policy(q_network)

[[ 0.          0.03255194 -1.04763699 -2.02158523]
 [ 0.00506255 -1.02143931 -2.02258563 -2.91643786]
 [-0.96860361 -1.98987949 -2.90923166 -3.95237851]
 [-2.02682257 -2.93438935 -3.96420765 -4.74878883]]
[['⟲' '←' '←' '←']
 ['↑' '←' '←' '↑']
 ['↑' '↑' '←' '↑']
 ['↑' '↑' '↑' '↑']]
[[-0.11723417  0.          0.02510482 -0.93119252]
 [-1.12678456 -0.07661006 -0.92180407 -1.81782329]
 [-1.9832226  -1.02196336 -1.95192206 -2.90122223]
 [-2.96776533 -2.01089549 -2.92617917 -3.76400995]]
[['→' '⟲' '←' '←']
 ['↑' '↑' '↑' '↑']
 ['↑' '↑' '↑' '↑']
 ['↑' '↑' '↑' '↑']]
[[-1.15663433e+00 -1.28930330e-01  0.00000000e+00  4.73703742e-02]
 [-2.13065648e+00 -9.85872924e-01  2.36055255e-03 -9.80775714e-01]
 [-2.92562771e+00 -2.08812714e+00 -1.03411794e+00 -1.91628015e+00]
 [-3.98718643e+00 -2.95133638e+00 -2.13784361e+00 -3.07564831e+00]]
[['→' '→' '⟲' '←']
 ['↑' '→' '↑' '↑']
 ['↑' '↑' '↑' '↑']
 ['↑' '↑' '↑' '↑']]
[[-2.02695680e+00 -9.63680148e-01 -4.50737476e-02  0.00000000e+00]
 [-2.92734337e+00 -2.01

In [5]:
# Debug the action-value function for a given state.
board = [[0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 1, 0],
         [0, 0, 2, 0]]
# Add a leading and a trailing singleton axis: (4, 4) -> (1, 4, 4, 1).
s = np.asarray(board)[np.newaxis, ..., np.newaxis]

# Print the outputs of the online network first, then of the target network.
for net in (exp.policy.model, exp.policy.target_model):
  print(net(util.to_tensor(s)))

tensor([[-1.8173, -1.8327,  0.0533, -1.8684]], device='cuda:0',
       grad_fn=<AddmmBackward>)
tensor([[-2.0231, -1.8632, -0.0575, -1.8431]], device='cuda:0',
       grad_fn=<AddmmBackward>)


In [6]:
# Examine the replay buffer: count the stored transitions whose reward is 0.
zero_reward = sars_filter(r=0)
terminals = rb.filter_by(zero_reward)
print(len(terminals))

# The action-value at 's' looked wrong. Count how often the buffer contains
# transitions starting in 's', and how many of those take action 2. If the
# relevant transition is rare or absent, the network had little data to fit
# for it.
starts_in_s = sars_filter(s=s)
trans_s = rb.filter_by(starts_in_s)
starts_in_s_with_a = sars_filter(s=s, a=2)
trans_s_a = rb.filter_by(starts_in_s_with_a)
print(f'Num with (s): {len(trans_s)}. Num with (s, a): {len(trans_s_a)}')


1721
Num with (s): 72. Num with (s, a): 45


In [7]:
# Roll the trained policy through the environment for 20 steps; the output
# below shows a sequence of grid arrays.
trained_policy = exp.policy
env.visualize(trained_policy, steps=20)

array([[0, 0, 0, 0],
       [0, 0, 0, 0],
       [0, 2, 1, 0],
       [0, 0, 0, 0]])
array([[0, 0, 0, 0],
       [0, 0, 0, 0],
       [0, 1, 0, 0],
       [0, 0, 0, 0]])
array([[0, 0, 0, 0],
       [2, 0, 1, 0],
       [0, 0, 0, 0],
       [0, 0, 0, 0]])
array([[0, 0, 0, 0],
       [2, 1, 0, 0],
       [0, 0, 0, 0],
       [0, 0, 0, 0]])
array([[0, 0, 0, 0],
       [1, 0, 0, 0],
       [0, 0, 0, 0],
       [0, 0, 0, 0]])
array([[0, 2, 1, 0],
       [0, 0, 0, 0],
       [0, 0, 0, 0],
       [0, 0, 0, 0]])
array([[0, 1, 0, 0],
       [0, 0, 0, 0],
       [0, 0, 0, 0],
       [0, 0, 0, 0]])
array([[0, 0, 0, 0],
       [0, 0, 0, 0],
       [0, 0, 0, 2],
       [0, 0, 1, 0]])
array([[0, 0, 0, 0],
       [0, 0, 0, 0],
       [0, 0, 0, 2],
       [0, 0, 0, 1]])
array([[0, 0, 0, 0],
       [0, 0, 0, 0],
       [0, 0, 0, 1],
       [0, 0, 0, 0]])
array([[0, 0, 0, 0],
       [0, 0, 0, 0],
       [0, 1, 0, 0],
       [0, 0, 0, 2]])
array([[0, 0, 0, 0],
       [0, 0, 0, 0],
       [0, 0, 1, 0],
  