In [1]:
from d3rlpy.datasets import get_cartpole

In [2]:
dataset, env = get_cartpole()

In [3]:
from sklearn.model_selection import train_test_split

train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)

In [4]:
from d3rlpy.algos import DQN

# if you don't use GPU, set use_gpu=False instead.
dqn = DQN(use_gpu=False)

# initialize neural networks with the given observation shape and action size.
# this is not necessary when you directly call fit or fit_online method.
dqn.build_with_dataset(dataset)

In [5]:
from d3rlpy.metrics.scorer import td_error_scorer
from d3rlpy.metrics.scorer import average_value_estimation_scorer

# calculate metrics with test dataset
td_error = td_error_scorer(dqn, test_episodes)

In [6]:
from d3rlpy.metrics.scorer import evaluate_on_environment

# set environment in scorer function
evaluate_scorer = evaluate_on_environment(env)

# evaluate algorithm on the environment
rewards = evaluate_scorer(dqn)

In [7]:
env

<TimeLimit<CartPoleEnv<CartPole-v0>>>

In [8]:
dqn.fit(train_episodes,
        eval_episodes=test_episodes,
        n_epochs=10,
        scorers={
            'td_error': td_error_scorer,
            'value_scale': average_value_estimation_scorer,
            'environment': evaluate_scorer
        })

2023-03-10 13:16:18 [debug    ] RoundIterator is selected.
2023-03-10 13:16:18 [info     ] Directory is created at d3rlpy_logs/DQN_20230310131618
2023-03-10 13:16:18 [info     ] Parameters are saved to d3rlpy_logs/DQN_20230310131618/params.json params={'action_scaler': None, 'batch_size': 32, 'encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'gamma': 0.99, 'generated_maxlen': 100000, 'learning_rate': 6.25e-05, 'n_critics': 1, 'n_frames': 1, 'n_steps': 1, 'optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'q_func_factory': {'type': 'mean', 'params': {'share_encoder': False}}, 'real_ratio': 1.0, 'reward_scaler': None, 'scaler': None, 'target_update_interval': 8000, 'use_gpu': None, 'algorithm': 'DQN', 'observation_shape': (4,), 'action_size': 2}


Epoch 1/10:   0%|          | 0/2457 [00:00<?, ?it/s]

2023-03-10 13:16:28 [info     ] DQN_20230310131618: epoch=1 step=2457 epoch=1 metrics={'time_sample_batch': 0.00013033626620052402, 'time_algorithm_update': 0.003288626913428549, 'loss': 0.011455701024676515, 'time_step': 0.003508702302590395, 'td_error': 1.0296351130590442, 'value_scale': 1.2064248184526873, 'environment': 9.8} step=2457
2023-03-10 13:16:28 [info     ] Model parameters are saved to d3rlpy_logs/DQN_20230310131618/model_2457.pt


Epoch 2/10:   0%|          | 0/2457 [00:00<?, ?it/s]

2023-03-10 13:16:38 [info     ] DQN_20230310131618: epoch=2 step=4914 epoch=2 metrics={'time_sample_batch': 0.00011680200908258769, 'time_algorithm_update': 0.003217718685648526, 'loss': 0.0004052328704691272, 'time_step': 0.0034202942287334835, 'td_error': 1.0257553684281568, 'value_scale': 1.195757424426908, 'environment': 10.4} step=4914
2023-03-10 13:16:38 [info     ] Model parameters are saved to d3rlpy_logs/DQN_20230310131618/model_4914.pt


Epoch 3/10:   0%|          | 0/2457 [00:00<?, ?it/s]

2023-03-10 13:16:48 [info     ] DQN_20230310131618: epoch=3 step=7371 epoch=3 metrics={'time_sample_batch': 0.00011697599545309201, 'time_algorithm_update': 0.0032616272950783754, 'loss': 0.0003899326607791523, 'time_step': 0.0034639358908653646, 'td_error': 1.014149315410479, 'value_scale': 1.1918088400457698, 'environment': 10.3} step=7371
2023-03-10 13:16:48 [info     ] Model parameters are saved to d3rlpy_logs/DQN_20230310131618/model_7371.pt


Epoch 4/10:   0%|          | 0/2457 [00:00<?, ?it/s]

2023-03-10 13:16:57 [info     ] DQN_20230310131618: epoch=4 step=9828 epoch=4 metrics={'time_sample_batch': 0.00011550928937937605, 'time_algorithm_update': 0.003230897983400425, 'loss': 0.008893519954619985, 'time_step': 0.003434514339541729, 'td_error': 1.0371447571413561, 'value_scale': 2.200204801626109, 'environment': 11.5} step=9828
2023-03-10 13:16:57 [info     ] Model parameters are saved to d3rlpy_logs/DQN_20230310131618/model_9828.pt


Epoch 5/10:   0%|          | 0/2457 [00:00<?, ?it/s]

2023-03-10 13:17:07 [info     ] DQN_20230310131618: epoch=5 step=12285 epoch=5 metrics={'time_sample_batch': 0.00011579574100554936, 'time_algorithm_update': 0.003362585191179399, 'loss': 0.009158847931193047, 'time_step': 0.0035680884212607234, 'td_error': 1.0039447103984074, 'value_scale': 2.2080282534753954, 'environment': 13.0} step=12285
2023-03-10 13:17:07 [info     ] Model parameters are saved to d3rlpy_logs/DQN_20230310131618/model_12285.pt


Epoch 6/10:   0%|          | 0/2457 [00:00<?, ?it/s]

2023-03-10 13:17:17 [info     ] DQN_20230310131618: epoch=6 step=14742 epoch=6 metrics={'time_sample_batch': 0.00011215046939686833, 'time_algorithm_update': 0.0032328870367314888, 'loss': 0.008866018023772635, 'time_step': 0.003432592823585942, 'td_error': 1.0100685505583245, 'value_scale': 2.222634844396501, 'environment': 12.1} step=14742
2023-03-10 13:17:17 [info     ] Model parameters are saved to d3rlpy_logs/DQN_20230310131618/model_14742.pt


Epoch 7/10:   0%|          | 0/2457 [00:00<?, ?it/s]

2023-03-10 13:17:26 [info     ] DQN_20230310131618: epoch=7 step=17199 epoch=7 metrics={'time_sample_batch': 0.00010873371864849831, 'time_algorithm_update': 0.003118905586394829, 'loss': 0.0164472367380149, 'time_step': 0.0033100684774955403, 'td_error': 1.0175753088570716, 'value_scale': 3.1998974605822785, 'environment': 36.7} step=17199
2023-03-10 13:17:26 [info     ] Model parameters are saved to d3rlpy_logs/DQN_20230310131618/model_17199.pt


Epoch 8/10:   0%|          | 0/2457 [00:00<?, ?it/s]

2023-03-10 13:17:36 [info     ] DQN_20230310131618: epoch=8 step=19656 epoch=8 metrics={'time_sample_batch': 0.00010406587683771217, 'time_algorithm_update': 0.0030436170358312388, 'loss': 0.021311579033812297, 'time_step': 0.003224560629317175, 'td_error': 1.0198858178035943, 'value_scale': 3.1757803345255327, 'environment': 175.4} step=19656
2023-03-10 13:17:36 [info     ] Model parameters are saved to d3rlpy_logs/DQN_20230310131618/model_19656.pt


Epoch 9/10:   0%|          | 0/2457 [00:00<?, ?it/s]

2023-03-10 13:17:45 [info     ] DQN_20230310131618: epoch=9 step=22113 epoch=9 metrics={'time_sample_batch': 0.00010369810865912245, 'time_algorithm_update': 0.0030478935296039634, 'loss': 0.020873517754649407, 'time_step': 0.0032280461789147854, 'td_error': 1.0218380605179647, 'value_scale': 3.2124304441962352, 'environment': 49.7} step=22113
2023-03-10 13:17:45 [info     ] Model parameters are saved to d3rlpy_logs/DQN_20230310131618/model_22113.pt


Epoch 10/10:   0%|          | 0/2457 [00:00<?, ?it/s]

2023-03-10 13:17:55 [info     ] DQN_20230310131618: epoch=10 step=24570 epoch=10 metrics={'time_sample_batch': 0.00010482829229396003, 'time_algorithm_update': 0.003110367542464024, 'loss': 0.024985878717610165, 'time_step': 0.0032936076007822833, 'td_error': 1.0538374609166938, 'value_scale': 4.1673449580561615, 'environment': 198.3} step=24570
2023-03-10 13:17:55 [info     ] Model parameters are saved to d3rlpy_logs/DQN_20230310131618/model_24570.pt


[(1,
  {'time_sample_batch': 0.00013033626620052402,
   'time_algorithm_update': 0.003288626913428549,
   'loss': 0.011455701024676515,
   'time_step': 0.003508702302590395,
   'td_error': 1.0296351130590442,
   'value_scale': 1.2064248184526873,
   'environment': 9.8}),
 (2,
  {'time_sample_batch': 0.00011680200908258769,
   'time_algorithm_update': 0.003217718685648526,
   'loss': 0.0004052328704691272,
   'time_step': 0.0034202942287334835,
   'td_error': 1.0257553684281568,
   'value_scale': 1.195757424426908,
   'environment': 10.4}),
 (3,
  {'time_sample_batch': 0.00011697599545309201,
   'time_algorithm_update': 0.0032616272950783754,
   'loss': 0.0003899326607791523,
   'time_step': 0.0034639358908653646,
   'td_error': 1.014149315410479,
   'value_scale': 1.1918088400457698,
   'environment': 10.3}),
 (4,
  {'time_sample_batch': 0.00011550928937937605,
   'time_algorithm_update': 0.003230897983400425,
   'loss': 0.008893519954619985,
   'time_step': 0.003434514339541729,
   't

In [11]:
observation = env.reset()

# return actions based on the greedy-policy
action = dqn.predict([observation])[0]

# estimate action-values
value = dqn.predict_value([observation], [action])

In [12]:
value

array([4.2053056], dtype=float32)

In [13]:
dqn.predict([observation])

array([1])

### Generate offline data

In [14]:
import d3rlpy

# setup algorithm
random_policy = d3rlpy.algos.DiscreteRandomPolicy()

# prepare experience replay buffer
buffer = d3rlpy.online.buffers.ReplayBuffer(maxlen=10000, env=env)

# start data collection
random_policy.collect(env, buffer, n_steps=10000)

# export as MDPDataset
dataset = buffer.to_mdp_dataset()

# save MDPDataset
dataset.dump("random_policy_dataset.h5")

2023-03-10 13:28:43 [debug    ] Building model...
2023-03-10 13:28:43 [debug    ] Model has been built.


  0%|          | 0/100000 [00:00<?, ?it/s]

In [15]:
ls

Intro_d3rlpy.ipynb        [34md3rlpy_logs[m[m/              pyproject.toml
[34md3rlpy_data[m[m/              poetry.lock               random_policy_dataset.h5


In [16]:
# start data collection
dqn.collect(env, buffer, n_steps=100000)

# export as MDPDataset
dataset = buffer.to_mdp_dataset()



  0%|          | 0/100000 [00:00<?, ?it/s]

In [17]:
# save MDPDataset
dataset.dump("trained_policy_dataset.h5")

In [41]:
import numpy as np

dataset2 = d3rlpy.online.buffers.ReplayBuffer(maxlen=100000, env=env)
dataset2 = buffer.to_mdp_dataset()
dataset2.load("trained_policy_dataset.h5")

<d3rlpy.dataset.MDPDataset at 0x14fb21c40>

In [42]:
dataset2

<d3rlpy.dataset.MDPDataset at 0x14f2ca0a0>

In [43]:
dataset2.actions

array([1, 1, 1, ..., 1, 1, 1], dtype=int32)

In [44]:
dataset2.rewards

array([1., 1., 1., ..., 1., 1., 1.], dtype=float32)

In [45]:
dataset2.observations

array([[ 0.0028865 ,  0.02860445,  0.03013965,  0.03935649],
       [ 0.00345859,  0.22328153,  0.03092678, -0.24366678],
       [ 0.00792422,  0.4179484 ,  0.02605345, -0.52643645],
       ...,
       [-0.04982188,  0.18751903,  0.0407802 , -0.24097829],
       [-0.0460715 ,  0.38203543,  0.03596063, -0.52052426],
       [-0.0384308 ,  0.5766332 ,  0.02555015, -0.801662  ]],
      dtype=float32)

### Train on offline data

In [46]:
from d3rlpy.dataset import MDPDataset
from d3rlpy.algos import DiscreteCQL

# setup CQL algorithm (discrete version)
cql = DiscreteCQL(n_frames=4, scaler='pixel', use_gpu=False)

# split train and test episodes
train_episodes, test_episodes = train_test_split(dataset2, test_size=0.2)

In [47]:
from d3rlpy.metrics.scorer import discounted_sum_of_advantage_scorer
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.metrics.scorer import td_error_scorer
from d3rlpy.metrics.scorer import average_value_estimation_scorer

# start training
cql.fit(train_episodes,
        eval_episodes=test_episodes,
        n_epochs=10,
        scorers={
            'environment': evaluate_on_environment(env), # Cartpole environment
            'advantage': discounted_sum_of_advantage_scorer, # smaller is better
            'td_error': td_error_scorer, # smaller is better
            'value_scale': average_value_estimation_scorer # smaller is better
        })

2023-03-10 13:58:19 [debug    ] RoundIterator is selected.
2023-03-10 13:58:19 [info     ] Directory is created at d3rlpy_logs/DiscreteCQL_20230310135819
2023-03-10 13:58:19 [debug    ] Fitting scaler...              scaler=pixel
2023-03-10 13:58:19 [debug    ] Building models...
2023-03-10 13:58:19 [debug    ] Models have been built.
2023-03-10 13:58:19 [info     ] Parameters are saved to d3rlpy_logs/DiscreteCQL_20230310135819/params.json params={'action_scaler': None, 'alpha': 1.0, 'batch_size': 32, 'encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'gamma': 0.99, 'generated_maxlen': 100000, 'learning_rate': 6.25e-05, 'n_critics': 1, 'n_frames': 4, 'n_steps': 1, 'optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'q_func_factory': {'type': 'mean', 'params': {'share_encoder': False}}, 'real_ratio': 1.0, 'reward_scaler': None, 'scaler': {'type': 'pixel', '

Epoch 1/10:   0%|          | 0/2475 [00:00<?, ?it/s]

2023-03-10 13:58:33 [info     ] DiscreteCQL_20230310135819: epoch=1 step=2475 epoch=1 metrics={'time_sample_batch': 0.00011373096042209202, 'time_algorithm_update': 0.0048319306999746, 'loss': 0.7044372973056755, 'time_step': 0.005021619122437757, 'environment': 9.7, 'advantage': -0.7184256813428528, 'td_error': 1.003873027455517, 'value_scale': 1.0141254830004633} step=2475
2023-03-10 13:58:33 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20230310135819/model_2475.pt


Epoch 2/10:   0%|          | 0/2475 [00:00<?, ?it/s]

2023-03-10 13:58:47 [info     ] DiscreteCQL_20230310135819: epoch=2 step=4950 epoch=2 metrics={'time_sample_batch': 0.00010732737454501066, 'time_algorithm_update': 0.004808726936879784, 'loss': 0.6926807059904542, 'time_step': 0.004982993386008523, 'environment': 9.6, 'advantage': -0.8237623272107137, 'td_error': 1.0083632118101098, 'value_scale': 1.016894909017989} step=4950
2023-03-10 13:58:47 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20230310135819/model_4950.pt


Epoch 3/10:   0%|          | 0/2475 [00:00<?, ?it/s]

2023-03-10 13:58:59 [info     ] DiscreteCQL_20230310135819: epoch=3 step=7425 epoch=3 metrics={'time_sample_batch': 0.0001019903626104798, 'time_algorithm_update': 0.004526491646814828, 'loss': 0.692025963369042, 'time_step': 0.0046929629162104445, 'environment': 9.2, 'advantage': -1.4125308991548584, 'td_error': 1.0304378637299039, 'value_scale': 1.0262357572453116} step=7425
2023-03-10 13:58:59 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20230310135819/model_7425.pt


Epoch 4/10:   0%|          | 0/2475 [00:00<?, ?it/s]

2023-03-10 13:59:13 [info     ] DiscreteCQL_20230310135819: epoch=4 step=9900 epoch=4 metrics={'time_sample_batch': 0.00010997589188392716, 'time_algorithm_update': 0.004905999308884746, 'loss': 0.6936927360236043, 'time_step': 0.005086273906206844, 'environment': 22.6, 'advantage': -0.3006639996273257, 'td_error': 0.9743382934822494, 'value_scale': 2.0003800274591748} step=9900
2023-03-10 13:59:13 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20230310135819/model_9900.pt


Epoch 5/10:   0%|          | 0/2475 [00:00<?, ?it/s]

2023-03-10 13:59:27 [info     ] DiscreteCQL_20230310135819: epoch=5 step=12375 epoch=5 metrics={'time_sample_batch': 0.0001084217880711411, 'time_algorithm_update': 0.004895942187068438, 'loss': 0.6897469291542515, 'time_step': 0.00507350459243312, 'environment': 11.9, 'advantage': -1.7485878148839962, 'td_error': 1.0283506860403937, 'value_scale': 2.0208384595941564} step=12375
2023-03-10 13:59:27 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20230310135819/model_12375.pt


Epoch 6/10:   0%|          | 0/2475 [00:00<?, ?it/s]

2023-03-10 13:59:40 [info     ] DiscreteCQL_20230310135819: epoch=6 step=14850 epoch=6 metrics={'time_sample_batch': 0.00010091974277688999, 'time_algorithm_update': 0.004574159757055418, 'loss': 0.686699037840872, 'time_step': 0.004739959119546293, 'environment': 25.9, 'advantage': -0.8160245403421296, 'td_error': 1.0020592841492546, 'value_scale': 2.0106483820014636} step=14850
2023-03-10 13:59:40 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20230310135819/model_14850.pt


Epoch 7/10:   0%|          | 0/2475 [00:00<?, ?it/s]

2023-03-10 13:59:52 [info     ] DiscreteCQL_20230310135819: epoch=7 step=17325 epoch=7 metrics={'time_sample_batch': 0.00010110190420439749, 'time_algorithm_update': 0.004474106316614632, 'loss': 0.6857365346195722, 'time_step': 0.004640489154391818, 'environment': 17.5, 'advantage': -2.433293339287986, 'td_error': 1.0552107648111715, 'value_scale': 3.0459035404557544} step=17325
2023-03-10 13:59:52 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20230310135819/model_17325.pt


Epoch 8/10:   0%|          | 0/2475 [00:00<?, ?it/s]

2023-03-10 14:00:06 [info     ] DiscreteCQL_20230310135819: epoch=8 step=19800 epoch=8 metrics={'time_sample_batch': 0.00010647542548902106, 'time_algorithm_update': 0.004866572293368253, 'loss': 0.6805609355791651, 'time_step': 0.005041583186448223, 'environment': 58.7, 'advantage': -2.9189308941082848, 'td_error': 1.061670664012514, 'value_scale': 3.046021322604366} step=19800
2023-03-10 14:00:06 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20230310135819/model_19800.pt


Epoch 9/10:   0%|          | 0/2475 [00:00<?, ?it/s]

2023-03-10 14:00:20 [info     ] DiscreteCQL_20230310135819: epoch=9 step=22275 epoch=9 metrics={'time_sample_batch': 0.00010994246511748343, 'time_algorithm_update': 0.0049418311648898655, 'loss': 0.6768818291991648, 'time_step': 0.005123888940522165, 'environment': 25.4, 'advantage': -2.5767650129123534, 'td_error': 1.087821161304596, 'value_scale': 3.061811562537585} step=22275
2023-03-10 14:00:20 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20230310135819/model_22275.pt


Epoch 10/10:   0%|          | 0/2475 [00:00<?, ?it/s]

2023-03-10 14:00:33 [info     ] DiscreteCQL_20230310135819: epoch=10 step=24750 epoch=10 metrics={'time_sample_batch': 0.00010402534947250829, 'time_algorithm_update': 0.004621366539386788, 'loss': 0.678041718006134, 'time_step': 0.004792209490381106, 'environment': 105.7, 'advantage': -2.5476583313030563, 'td_error': 1.066795705757112, 'value_scale': 4.074616679151892} step=24750
2023-03-10 14:00:33 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20230310135819/model_24750.pt


[(1,
  {'time_sample_batch': 0.00011373096042209202,
   'time_algorithm_update': 0.0048319306999746,
   'loss': 0.7044372973056755,
   'time_step': 0.005021619122437757,
   'environment': 9.7,
   'advantage': -0.7184256813428528,
   'td_error': 1.003873027455517,
   'value_scale': 1.0141254830004633}),
 (2,
  {'time_sample_batch': 0.00010732737454501066,
   'time_algorithm_update': 0.004808726936879784,
   'loss': 0.6926807059904542,
   'time_step': 0.004982993386008523,
   'environment': 9.6,
   'advantage': -0.8237623272107137,
   'td_error': 1.0083632118101098,
   'value_scale': 1.016894909017989}),
 (3,
  {'time_sample_batch': 0.0001019903626104798,
   'time_algorithm_update': 0.004526491646814828,
   'loss': 0.692025963369042,
   'time_step': 0.0046929629162104445,
   'environment': 9.2,
   'advantage': -1.4125308991548584,
   'td_error': 1.0304378637299039,
   'value_scale': 1.0262357572453116}),
 (4,
  {'time_sample_batch': 0.00010997589188392716,
   'time_algorithm_update': 0.0

### Custom environment

In [104]:
import gym
from gym.spaces import Discrete, Box
import numpy as np
import os
import random

class SimpleCorridor_d3rlpy(gym.Env):
    """Example of a custom env in which you have to walk down a corridor.
    Get a reward of -0.1 if you are not at the end, a random reward that is positive if you do.
    Move +1 if you move forward, -1 if you move backward. The total length is 5.
    We should want to reach the end in 5 steps in the perfectly trained world. 
    You can configure the length of the corridor via the env config."""

    def __init__(self, config):
        self.end_pos = config["corridor_length"]
        self.cur_pos = 0
        self.action_space = Discrete(5)
        self.observation_space = Box(0.0, self.end_pos, shape=(1,), dtype=np.float32)
        # Set the seed. This is only used for the final (reach goal) reward.
        self.reset()

    def reset(self, *, seed=None, options=None):
        random.seed(seed)
        self.cur_pos = 0
        return np.array([self.cur_pos])

    def step(self, action):
        assert action in [0, 1, 2, 3, 4], action
        # backward step
        if action == 0 and self.cur_pos > 0:
            self.cur_pos -= 1
            reward = -0.2
        # forward step
        elif action == 1:
            self.cur_pos += 1
            if(self.check_if_water()):
                reward = -0.5
            else:
                reward = 0.1
        # double speed
        elif action == 4:
            if(self.cur_pos <= 3):
                self.cur_pos += 2
            else:
                self.cur_pos += 1
            if(self.check_if_water()):
                reward = -0.5
            else:
                reward = 0.2
        # left or right
        else:
            self.cur_pos = self.cur_pos
            reward = -0.05
            
        done = truncated = self.cur_pos >= self.end_pos
        if(done):
            reward = 2
        # Produce a random reward when we reach the goal.
        return (
            np.array([self.cur_pos]),
            reward, # Setting to 2 instead of random reward has no real impact
            done,
            {},
        )
    
    def check_if_water(self):
        # This is water and will get a negative reward
        if(self.cur_pos == 3):
            return True
        else:
            return False

config={"corridor_length": 5}
env = SimpleCorridor_d3rlpy(config=config)

In [105]:
env

<__main__.SimpleCorridor_d3rlpy at 0x136a18c40>

In [106]:
env.reset()

array([0])

In [108]:
action = env.action_space.sample()
obs, reward, done, info = env.step(action)

In [109]:
reward

0.2

In [102]:
obs.astype

<function ndarray.astype>

In [138]:
# setup algorithm
random_policy = d3rlpy.algos.DiscreteRandomPolicy()

# prepare experience replay buffer
buffer = d3rlpy.online.buffers.ReplayBuffer(maxlen=100000, env=env)

# start data collection
random_policy.collect(env, buffer, n_steps=100000)

# export as MDPDataset
dataset = buffer.to_mdp_dataset()

# save MDPDataset
dataset.dump("random_policy_corridor_dataset.h5")

2023-03-10 15:47:33 [debug    ] Building model...
2023-03-10 15:47:33 [debug    ] Model has been built.


  0%|          | 0/100000 [00:00<?, ?it/s]

### Fit online with custom environment

In [113]:
# setup algorithm
dqn = d3rlpy.algos.DQN()

# prepare experience replay buffer
buffer = d3rlpy.online.buffers.ReplayBuffer(maxlen=100000, env=env)

# prepare exploration strategy if necessary
explorer = d3rlpy.online.explorers.ConstantEpsilonGreedy(0.3)

# start data collection
dqn.fit_online(env, buffer, n_steps=100000)

2023-03-10 14:47:32 [info     ] Directory is created at d3rlpy_logs/DQN_online_20230310144732
2023-03-10 14:47:32 [debug    ] Building model...
2023-03-10 14:47:32 [debug    ] Model has been built.
2023-03-10 14:47:32 [info     ] Parameters are saved to d3rlpy_logs/DQN_online_20230310144732/params.json params={'action_scaler': None, 'batch_size': 32, 'encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'gamma': 0.99, 'generated_maxlen': 100000, 'learning_rate': 6.25e-05, 'n_critics': 1, 'n_frames': 1, 'n_steps': 1, 'optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'q_func_factory': {'type': 'mean', 'params': {'share_encoder': False}}, 'real_ratio': 1.0, 'reward_scaler': None, 'scaler': None, 'target_update_interval': 8000, 'use_gpu': None, 'algorithm': 'DQN', 'observation_shape': (1,), 'action_size': 5}


  0%|          | 0/100000 [00:00<?, ?it/s]

2023-03-10 14:48:12 [info     ] Model parameters are saved to d3rlpy_logs/DQN_online_20230310144732/model_10000.pt
2023-03-10 14:48:12 [info     ] DQN_online_20230310144732: epoch=1 step=10000 epoch=1 metrics={'time_inference': 0.0005066035032272339, 'time_environment_step': 1.2087106704711915e-05, 'time_step': 0.003947368121147155, 'time_sample_batch': 0.0001312981424975613, 'time_algorithm_update': 0.00323518305445813, 'loss': 1.49070639037667e-06} step=10000
2023-03-10 14:48:50 [info     ] Model parameters are saved to d3rlpy_logs/DQN_online_20230310144732/model_20000.pt
2023-03-10 14:48:50 [info     ] DQN_online_20230310144732: epoch=2 step=20000 epoch=2 metrics={'time_inference': 0.0004898247003555298, 'time_environment_step': 1.1702489852905273e-05, 'time_sample_batch': 0.00012943384647369385, 'time_algorithm_update': 0.003080199909210205, 'loss': 7.911302969334136e-07, 'time_step': 0.0037821149587631228} step=20000
2023-03-10 14:49:30 [info     ] Model parameters are saved to d3

In [146]:
for i in range(5):
    obs = np.array([i])
    action = dqn.predict([obs])[0]
    print(action)

### Fit on offline data generated from custom environment

In [147]:
#initialize the algorithm
dqn = d3rlpy.algos.DQN()

dataset2 = d3rlpy.online.buffers.ReplayBuffer(maxlen=100000, env=env)
dataset2 = buffer.to_mdp_dataset()
dataset2.load("random_policy_corridor_dataset.h5")

train_episodes, test_episodes = train_test_split(dataset2, test_size=0.2)

In [None]:
# start training
dqn.fit(train_episodes,
        eval_episodes=test_episodes,
        n_epochs=5,
        scorers={
            'environment': evaluate_on_environment(env), # Cartpole environment
            'advantage': discounted_sum_of_advantage_scorer, # smaller is better
            'td_error': td_error_scorer, # smaller is better
            'value_scale': average_value_estimation_scorer # smaller is better
        })

In [148]:
dqn.fit(
    train_episodes,
    n_epochs=10,
    #scorers=metrics,
    eval_episodes=test_episodes,
)