In [1]:
# Dataset
from d3rlpy.datasets import get_cartpole 
# Algorithm
from d3rlpy.algos import DQN
# Metrics
from d3rlpy.metrics.scorer import td_error_scorer
from d3rlpy.metrics.scorer import average_value_estimation_scorer
from d3rlpy.metrics.scorer import evaluate_on_environment

from sklearn.model_selection import train_test_split

# Model Training

Here we use the CartPole dataset, which is small enough to check training results quickly.

In [2]:
# Load the CartPole offline dataset and its environment, hold out 20% of the
# episodes for evaluation, and create a CPU-only DQN.
# SEED fixes only the train/test episode split; network initialization and
# environment rollouts are not seeded here.
SEED = 42
dataset, env = get_cartpole()
train_episodes, test_episodes = train_test_split(dataset, test_size=0.2, random_state=SEED)
dqn = DQN(use_gpu=False)

Initialize the neural networks using the observation shape and action size of the dataset. This step is not necessary if you call the `fit` or `fit_online` methods directly.

In [3]:
# Build the networks up front so the untrained model can be scored below
# (fit() would otherwise build them automatically).
dqn.build_with_dataset(dataset)

Calculate metrics on the test episodes before training, as a baseline.

In [4]:
# TD error of the not-yet-trained Q-function on the held-out episodes: a
# baseline to compare with the per-epoch 'td_error' metric logged by fit().
td_error = td_error_scorer(dqn, test_episodes)
td_error

Create a scorer function bound to the environment.

In [5]:
# Bind `env` into a scorer callable; it is called as evaluate_scorer(algo)
# and is also passed to fit() as the 'environment' metric.
# NOTE(review): presumably returns the mean episode return over several
# rollouts — confirm the n_trials default in the d3rlpy docs.
evaluate_scorer = evaluate_on_environment(env)

Evaluate the (still untrained) algorithm on the environment.

In [6]:
# Environment score of the not-yet-trained policy: a baseline to compare with
# the per-epoch 'environment' metric logged by fit().
rewards = evaluate_scorer(dqn)
rewards

Start training. The scorers are evaluated and logged after every epoch.

In [7]:
# Metrics logged after every epoch; keys become the metric names in the logs.
epoch_scorers = {
    'td_error': td_error_scorer,
    'value_scale': average_value_estimation_scorer,
    'environment': evaluate_scorer,
}

# Train for 10 epochs. The returned list of (epoch, metrics) pairs is left as
# the cell output.
dqn.fit(
    train_episodes,
    eval_episodes=test_episodes,
    n_epochs=10,
    scorers=epoch_scorers,
)

2022-05-10 12:51.44 [debug    ] RoundIterator is selected.
2022-05-10 12:51.44 [info     ] Directory is created at d3rlpy_logs/DQN_20220510125144
2022-05-10 12:51.44 [info     ] Parameters are saved to d3rlpy_logs/DQN_20220510125144/params.json params={'action_scaler': None, 'batch_size': 32, 'encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'gamma': 0.99, 'generated_maxlen': 100000, 'learning_rate': 6.25e-05, 'n_critics': 1, 'n_frames': 1, 'n_steps': 1, 'optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'q_func_factory': {'type': 'mean', 'params': {'share_encoder': False}}, 'real_ratio': 1.0, 'reward_scaler': None, 'scaler': None, 'target_update_interval': 8000, 'use_gpu': None, 'algorithm': 'DQN', 'observation_shape': (4,), 'action_size': 2}


HBox(children=(FloatProgress(value=0.0, description='Epoch 1/10', max=2514.0, style=ProgressStyle(description_…


2022-05-10 12:51.51 [info     ] DQN_20220510125144: epoch=1 step=2514 epoch=1 metrics={'time_sample_batch': 0.00011620724495195071, 'time_algorithm_update': 0.0022142446505805284, 'loss': 0.011194559461501734, 'time_step': 0.0024188707627848015, 'td_error': 0.9844668525807032, 'value_scale': 1.0447659955444977, 'environment': 11.2} step=2514
2022-05-10 12:51.51 [info     ] Model parameters are saved to d3rlpy_logs/DQN_20220510125144/model_2514.pt


HBox(children=(FloatProgress(value=0.0, description='Epoch 2/10', max=2514.0, style=ProgressStyle(description_…


2022-05-10 12:51.58 [info     ] DQN_20220510125144: epoch=2 step=5028 epoch=2 metrics={'time_sample_batch': 0.00012481535059550309, 'time_algorithm_update': 0.0021508126953052164, 'loss': 3.73246886335691e-05, 'time_step': 0.002360104186817494, 'td_error': 0.9862890182657684, 'value_scale': 1.0468920218224318, 'environment': 9.8} step=5028
2022-05-10 12:51.58 [info     ] Model parameters are saved to d3rlpy_logs/DQN_20220510125144/model_5028.pt


HBox(children=(FloatProgress(value=0.0, description='Epoch 3/10', max=2514.0, style=ProgressStyle(description_…


2022-05-10 12:52.05 [info     ] DQN_20220510125144: epoch=3 step=7542 epoch=3 metrics={'time_sample_batch': 0.00011057036209409922, 'time_algorithm_update': 0.002106637165673498, 'loss': 3.580718892105383e-05, 'time_step': 0.0022986553544171195, 'td_error': 0.9853889863979285, 'value_scale': 1.0489737394936889, 'environment': 9.3} step=7542
2022-05-10 12:52.05 [info     ] Model parameters are saved to d3rlpy_logs/DQN_20220510125144/model_7542.pt


HBox(children=(FloatProgress(value=0.0, description='Epoch 4/10', max=2514.0, style=ProgressStyle(description_…


2022-05-10 12:52.12 [info     ] DQN_20220510125144: epoch=4 step=10056 epoch=4 metrics={'time_sample_batch': 0.00010690493724036994, 'time_algorithm_update': 0.002202211340552962, 'loss': 0.007335323230942474, 'time_step': 0.0023907648729917636, 'td_error': 0.9779033269410278, 'value_scale': 2.0433203454677478, 'environment': 15.4} step=10056
2022-05-10 12:52.12 [info     ] Model parameters are saved to d3rlpy_logs/DQN_20220510125144/model_10056.pt


HBox(children=(FloatProgress(value=0.0, description='Epoch 5/10', max=2514.0, style=ProgressStyle(description_…


2022-05-10 12:52.19 [info     ] DQN_20220510125144: epoch=5 step=12570 epoch=5 metrics={'time_sample_batch': 0.00010073042712518878, 'time_algorithm_update': 0.002336400031284766, 'loss': 0.006751903670184446, 'time_step': 0.0025214459465530627, 'td_error': 0.984716553432055, 'value_scale': 2.031477987710578, 'environment': 15.2} step=12570
2022-05-10 12:52.19 [info     ] Model parameters are saved to d3rlpy_logs/DQN_20220510125144/model_12570.pt


HBox(children=(FloatProgress(value=0.0, description='Epoch 6/10', max=2514.0, style=ProgressStyle(description_…


2022-05-10 12:52.26 [info     ] DQN_20220510125144: epoch=6 step=15084 epoch=6 metrics={'time_sample_batch': 0.00010266537313529586, 'time_algorithm_update': 0.0022703881479959782, 'loss': 0.006578138327177328, 'time_step': 0.0024565265779942867, 'td_error': 0.9957254867581926, 'value_scale': 2.0490504242197494, 'environment': 13.6} step=15084
2022-05-10 12:52.26 [info     ] Model parameters are saved to d3rlpy_logs/DQN_20220510125144/model_15084.pt


HBox(children=(FloatProgress(value=0.0, description='Epoch 7/10', max=2514.0, style=ProgressStyle(description_…


2022-05-10 12:52.33 [info     ] DQN_20220510125144: epoch=7 step=17598 epoch=7 metrics={'time_sample_batch': 9.534410189899453e-05, 'time_algorithm_update': 0.0022269415305356897, 'loss': 0.015864678236087983, 'time_step': 0.002395195342665546, 'td_error': 1.0053853179919539, 'value_scale': 3.0196696636053826, 'environment': 14.7} step=17598
2022-05-10 12:52.33 [info     ] Model parameters are saved to d3rlpy_logs/DQN_20220510125144/model_17598.pt


HBox(children=(FloatProgress(value=0.0, description='Epoch 8/10', max=2514.0, style=ProgressStyle(description_…


2022-05-10 12:52.41 [info     ] DQN_20220510125144: epoch=8 step=20112 epoch=8 metrics={'time_sample_batch': 9.41707866004088e-05, 'time_algorithm_update': 0.002229144768574547, 'loss': 0.018684730965853845, 'time_step': 0.0023954501679328668, 'td_error': 1.0079978265937157, 'value_scale': 3.0245120766055704, 'environment': 200.0} step=20112
2022-05-10 12:52.41 [info     ] Model parameters are saved to d3rlpy_logs/DQN_20220510125144/model_20112.pt


HBox(children=(FloatProgress(value=0.0, description='Epoch 9/10', max=2514.0, style=ProgressStyle(description_…


2022-05-10 12:52.48 [info     ] DQN_20220510125144: epoch=9 step=22626 epoch=9 metrics={'time_sample_batch': 9.942898792033169e-05, 'time_algorithm_update': 0.002313308139215316, 'loss': 0.01835772474139758, 'time_step': 0.0024843246290014204, 'td_error': 1.0099606888588577, 'value_scale': 3.025052481606954, 'environment': 36.9} step=22626
2022-05-10 12:52.48 [info     ] Model parameters are saved to d3rlpy_logs/DQN_20220510125144/model_22626.pt


HBox(children=(FloatProgress(value=0.0, description='Epoch 10/10', max=2514.0, style=ProgressStyle(description…


2022-05-10 12:52.56 [info     ] DQN_20220510125144: epoch=10 step=25140 epoch=10 metrics={'time_sample_batch': 9.496636872454684e-05, 'time_algorithm_update': 0.002296350072467716, 'loss': 0.02487943174007617, 'time_step': 0.002466608155978134, 'td_error': 1.0427339407797607, 'value_scale': 3.983210044635811, 'environment': 200.0} step=25140
2022-05-10 12:52.56 [info     ] Model parameters are saved to d3rlpy_logs/DQN_20220510125144/model_25140.pt


[(1,
  {'time_sample_batch': 0.00011620724495195071,
   'time_algorithm_update': 0.0022142446505805284,
   'loss': 0.011194559461501734,
   'time_step': 0.0024188707627848015,
   'td_error': 0.9844668525807032,
   'value_scale': 1.0447659955444977,
   'environment': 11.2}),
 (2,
  {'time_sample_batch': 0.00012481535059550309,
   'time_algorithm_update': 0.0021508126953052164,
   'loss': 3.73246886335691e-05,
   'time_step': 0.002360104186817494,
   'td_error': 0.9862890182657684,
   'value_scale': 1.0468920218224318,
   'environment': 9.8}),
 (3,
  {'time_sample_batch': 0.00011057036209409922,
   'time_algorithm_update': 0.002106637165673498,
   'loss': 3.580718892105383e-05,
   'time_step': 0.0022986553544171195,
   'td_error': 0.9853889863979285,
   'value_scale': 1.0489737394936889,
   'environment': 9.3}),
 (4,
  {'time_sample_batch': 0.00010690493724036994,
   'time_algorithm_update': 0.002202211340552962,
   'loss': 0.007335323230942474,
   'time_step': 0.0023907648729917636,
   

Make decisions with the trained model

In [8]:
# Start a fresh episode to obtain an initial observation for the queries below.
observation = env.reset()

In [19]:
# predict() works on batches: wrap the single observation, then unwrap the
# single action it returns.
obs_batch = [observation]
action = dqn.predict(obs_batch)[0]
action

1

In [14]:
# predict_value() takes batched observations and actions; index 0 unwraps the
# estimate for the single (observation, action) pair.
value_batch = dqn.predict_value([observation], [action])
value = value_batch[0]
value

4.058762

In [10]:
# Export the trained policy to 'policy.pt' in the current working directory.
# NOTE(review): d3rlpy presumably picks the export format (TorchScript/ONNX)
# from the file extension — confirm against the d3rlpy docs.
dqn.save_policy('policy.pt')

# Policy Application

In [20]:
# Roll out the trained policy in the environment, printing every observation.
# For a random-action baseline, replace the dqn.predict(...) call with
# env.action_space.sample().
N_EPISODES = 1
MAX_STEPS = 10000  # safety cap on the number of steps per episode

try:
    for i_episode in range(N_EPISODES):
        observation = env.reset()
        for t in range(MAX_STEPS):
            print(observation)
            # predict() is batched: wrap the observation and unwrap the action.
            action = dqn.predict([observation])[0]
            observation, reward, done, info = env.step(action)
            if done:
                print("Episode finished after {} timesteps".format(t + 1))
                break
        else:
            # Reached only when the cap is hit without `done`; previously silent.
            print("Episode did not finish within {} timesteps".format(MAX_STEPS))
finally:
    # Release the environment even if a step or prediction raises.
    env.close()

[-0.03849623  0.02210191 -0.03195776 -0.04903256]
[-0.0380542   0.21766718 -0.03293841 -0.35162467]
[-0.03370085  0.02302875 -0.03997091 -0.06950752]
[-0.03324028  0.21870028 -0.04136106 -0.3745287 ]
[-0.02886627  0.02418951 -0.04885163 -0.09516876]
[-0.02838248 -0.17019947 -0.05075501  0.18171018]
[-0.03178647 -0.3645598  -0.0471208   0.45795967]
[-0.03907767 -0.55898504 -0.03796161  0.73542543]
[-0.05025737 -0.75356261 -0.0232531   1.01592348]
[-0.06532862 -0.55813844 -0.00293463  0.71603067]
[-0.07649139 -0.75321965  0.01138598  1.00778844]
[-0.09155578 -0.55825154  0.03154175  0.71870265]
[-0.10272081 -0.36357992  0.0459158   0.43611219]
[-0.10999241 -0.169137    0.05463805  0.15824972]
[-0.11337515  0.02516186  0.05780304 -0.11670805]
[-0.11287191  0.21941     0.05546888 -0.390609  ]
[-0.10848371  0.41370264  0.0476567  -0.66530023]
[-0.10020966  0.60813044  0.0343507  -0.94260505]
[-0.08804705  0.80277311  0.01549859 -1.2242997 ]
[-0.07199159  0.60745502 -0.0089874  -0.92680135]
