In [4]:
def get_episode_actions(dataset):
    for episode in dataset.episodes:
        print(episode.actions)
        
def evaluate_mean_episode_return(model, env):
    evaluate_scorer = evaluate_on_environment(env, n_trials=1)
    mean_episode_return = evaluate_scorer(model)
    return(mean_episode_return)

def predict_actions_for_state(model):
    for i in range(5):
        obs = np.array([i])
        action = model.predict([obs])[0]
        print(action)

def save_DQN_model(model_name='Online_trained_corridor_6000.pt'):
    dqn.save_model(model_name)

def load_DQN_model(dataset, model_name='Online_trained_corridor_6000.pt'):
    dqn_loaded = DQN()
    dqn_loaded.build_with_dataset(dataset)
    dqn_loaded.load_model(model_name)
    return(dqn_loaded)

In [5]:
import gym
from gym.spaces import Discrete, Box
import numpy as np
import os
import random

class SimpleCorridor_d3rlpy(gym.Env):
    """Example of a custom env in which you have to walk down a corridor.
    Get a reward of -0.1 if you are not at the end, a random reward that is positive if you do.
    Move +1 if you move forward, -1 if you move backward. The total length is 5.
    We should want to reach the end in 5 steps in the perfectly trained world. 
    You can configure the length of the corridor via the env config."""

    def __init__(self, config):
        self.end_pos = config["corridor_length"]
        self.cur_pos = 0
        self.action_space = Discrete(5)
        self.observation_space = Box(0.0, self.end_pos, shape=(1,), dtype=np.float32)
        # Set the seed. This is only used for the final (reach goal) reward.
        self.reset()

    def reset(self, *, seed=None, options=None):
        random.seed(seed)
        self.cur_pos = 0
        return np.array([self.cur_pos])

    def step(self, action):
        assert action in [0, 1, 2, 3, 4], action
        # backward step
        if action == 0 and self.cur_pos > 0:
            self.cur_pos -= 1
            reward = -0.2
        # forward step
        elif action == 1:
            self.cur_pos += 1
            if(self.check_if_water()):
                reward = -0.5
            else:
                reward = 0.1
        # double speed
        elif action == 4:
            if(self.cur_pos <= 3):
                self.cur_pos += 2
            else:
                self.cur_pos += 1
            if(self.check_if_water()):
                reward = -0.5
            else:
                reward = 0.2
        # left or right
        else:
            self.cur_pos = self.cur_pos
            reward = -0.05
            
        done = truncated = self.cur_pos >= self.end_pos
        if(done):
            reward = 2
        # Produce a random reward when we reach the goal.
        return (
            np.array([self.cur_pos]),
            reward, # Setting to 2 instead of random reward has no real impact
            done,
            {},
        )
    
    def check_if_water(self):
        # This is water and will get a negative reward
        if(self.cur_pos == 3):
            return True
        else:
            return False

### Online learning and collection of data while learning

In [6]:
# This setup works and is used for online collection of data
from gym.wrappers import TimeLimit
from d3rlpy.algos import DQN
from d3rlpy.online.buffers import ReplayBuffer
from d3rlpy.online.explorers import LinearDecayEpsilonGreedy, ConstantEpsilonGreedy

config={"corridor_length": 5}
env_corridor = SimpleCorridor_d3rlpy(config=config)

env = TimeLimit(env_corridor, max_episode_steps=2000)

#env = gym.make('CartPole-v0')
#eval_env = gym.make('CartPole-v0')

# setup algorithm
dqn = DQN(batch_size=128,
          learning_rate=2.5e-5,
          target_update_interval=100,
          use_gpu=False)

# setup replay buffer
buffer = ReplayBuffer(maxlen=100000, env=env)

# setup explorers
explorer = LinearDecayEpsilonGreedy(start_epsilon=1.0,
                                    end_epsilon=0.1,
                                    duration=10000)

# start training
dqn.fit_online(env,
               buffer,
               explorer=explorer, # you don't need this with probablistic policy algorithms
               eval_env=env,
               n_steps=6000, # the number of total steps to train.
               n_steps_per_epoch=1000,
               update_interval=10) # update parameters every 10 steps.

# export online dataset as MDPDataset
dataset_online = buffer.to_mdp_dataset()

# save MDPDataset
dataset_online.dump("trained_corridor_policy_dataset_10000_online.h5")

2023-03-15 13:40:20 [info     ] Directory is created at d3rlpy_logs/DQN_online_20230315134020
2023-03-15 13:40:20 [debug    ] Building model...
2023-03-15 13:40:20 [debug    ] Model has been built.
2023-03-15 13:40:20 [info     ] Parameters are saved to d3rlpy_logs/DQN_online_20230315134020/params.json params={'action_scaler': None, 'batch_size': 128, 'encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'gamma': 0.99, 'generated_maxlen': 100000, 'learning_rate': 2.5e-05, 'n_critics': 1, 'n_frames': 1, 'n_steps': 1, 'optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'q_func_factory': {'type': 'mean', 'params': {'share_encoder': False}}, 'real_ratio': 1.0, 'reward_scaler': None, 'scaler': None, 'target_update_interval': 100, 'use_gpu': None, 'algorithm': 'DQN', 'observation_shape': (1,), 'action_size': 5}


  0%|          | 0/6000 [00:00<?, ?it/s]

2023-03-15 13:40:21 [info     ] Model parameters are saved to d3rlpy_logs/DQN_online_20230315134020/model_1000.pt
2023-03-15 13:40:21 [info     ] DQN_online_20230315134020: epoch=1 step=1000 epoch=1 metrics={'time_inference': 0.0004335432052612305, 'time_environment_step': 8.30221176147461e-06, 'time_step': 0.0008761329650878906, 'rollout_return': 1.4559139784946238, 'time_sample_batch': 0.00016294555230574176, 'time_algorithm_update': 0.004333558407696811, 'loss': 0.10008559016172182, 'evaluation': 1.6} step=1000
2023-03-15 13:40:22 [info     ] Model parameters are saved to d3rlpy_logs/DQN_online_20230315134020/model_2000.pt
2023-03-15 13:40:22 [info     ] DQN_online_20230315134020: epoch=2 step=2000 epoch=2 metrics={'time_inference': 0.0004269955158233643, 'time_environment_step': 8.482217788696289e-06, 'time_step': 0.0009129078388214111, 'rollout_return': 1.6578125000000004, 'time_sample_batch': 0.00018132448196411132, 'time_algorithm_update': 0.00417576789855957, 'loss': 0.14245763

In [7]:
get_episode_actions(dataset_online)

[4 2 4 4]
[3 4 1 1 3 1]
[3 1 3 1 2 4 2 2 2 3 3 2 0 0 0 1 1 0 2 2 2 0 3 2 4 3 4]
[1 4 3 4]
[1 2 2 3 3 0 1 4 4]
[3 1 3 4 1 4]
[0 3 3 3 0 1 3 0 1 1 4 0 0 0 2 3 2 4 3 4]
[1 1 1 2 0 1 4]
[1 2 0 2 4 4 1]
[2 2 4 0 0 0 0 0 2 3 3 1 2 0 2 2 2 2 1 0 0 3 1 3 4 1 3 2 0 4]
[1 2 1 4 1]
[3 2 0 1 1 2 0 1 3 3 4 2 0 3 2 4]
[3 2 4 2 3 3 2 3 0 3 1 1 4]
[0 0 3 4 0 1 2 4 3 2 0 4]
[1 1 2 0 0 0 3 4 0 1 1 2 0 1 4]
[2 2 2 1 4 3 1 4]
[1 1 3 4 1]
[0 3 1 0 2 4 3 1 3 2 3 3 2 4]
[2 2 1 4 4]
[1 3 3 3 4 2 0 3 0 1 0 3 2 0 4 0 3 4 2 3 1 2 3 0 2 3 0 0 4 0 1 3 0 4 3 4]
[4 0 1 2 2 2 3 3 1 3 3 2 3 2 3 2 2 1 3 0 0 2 3 2 2 0 4 2 1 0 3 0 0 2 0 3 1
 1 0 3 0 4 3 4 3 0 2 2 1 0 2 3 0 2 2 4 1]
[2 3 0 1 0 4 3 2 4 4]
[4 2 4 1]
[2 1 3 0 0 1 4 4]
[1 2 4 4]
[2 1 1 4 0 1 4]
[3 4 1 4]
[2 1 4 2 1 3 4]
[1 0 2 0 2 3 2 2 2 3 3 1 1 4 2 2 3 3 0 3 1 1]
[2 2 0 3 3 2 2 4 3 1 2 3 1 0 0 4 2 1]
[4 2 3 4 0 3 4]
[2 1 0 2 0 3 1 1 0 1 2 0 4 0 4 2 0 0 1 3 0 4 2 0 4]
[0 4 3 1 4]
[1 1 3 0 1 2 0 4 1 2 0 0 3 0 0 4 4 2 1]
[1 0 1 0 2 3 2 0 3 1 3 4 2 1 2 4]
[0 1 

In [10]:
predict_actions_for_state(dqn)

4
1
4
4
4


### Continue training

In [None]:
# This setup works and is used for online collection of data
# setup replay buffer
buffer = ReplayBuffer(maxlen=100000, env=env)

# setup explorers
explorer = LinearDecayEpsilonGreedy(start_epsilon=1.0,
                                    end_epsilon=0.1,
                                    duration=10000)

# start training
dqn.fit_online(env,
               buffer,
               explorer=explorer, # you don't need this with probablistic policy algorithms
               eval_env=env,
               n_steps=6000, # the number of total steps to train.
               n_steps_per_epoch=1000,
               update_interval=10) # update parameters every 10 steps.

# export online dataset as MDPDataset
dataset_online = buffer.to_mdp_dataset()

# save MDPDataset
dataset_online.dump("trained_corridor_policy_dataset_10000_continuation.h5")

### Collect data from a trained environment

In [11]:
import d3rlpy

# prepare experience replay buffer
buffer = d3rlpy.online.buffers.ReplayBuffer(maxlen=30000, env=env)

# start data collection
dqn.collect(env, buffer, explorer=explorer, deterministic=False, n_steps=10000)

# export as MDPDataset
dataset = buffer.to_mdp_dataset()

# save MDPDataset
dataset.dump("trained_corridor_policy_dataset_10000_collected.h5")



  0%|          | 0/10000 [00:00<?, ?it/s]

In [12]:
get_episode_actions(dataset)

[0 2 3 2 3 4 0 1 4 4]
[3 4 1 4]
[2 0 0 2 4 0 1 4 1]
[4 1 2 0 3 1 3 3 2 3 2 4]
[3 4 2 2 0 2 0 3 1 4 3 2 0 4 2 3 3 1]
[4 0 1 4 1]
[4 4 2 2 0 3 2 1 1]
[1 3 0 0 3 1 3 1 1 3 1 3 4]
[0 2 3 1 4 1 4]
[0 4 1 0 2 0 2 3 3 0 1 1 1 0 2 3 0 2 2 1 3 1 3 1 2 0 1 0 2 0 3 3 3 4 2 1]
[4 2 1 4]
[3 3 0 0 4 2 1 1 2 1]
[0 1 0 0 0 3 0 1 0 3 2 1 3 1 3 0 0 3 2 4 1 4]
[4 1 1 2 0 2 3 2 1 2 2 4]
[3 4 2 2 3 1 4]
[0 0 0 1 0 0 3 0 2 3 0 1 2 3 4 4]
[2 4 2 0 0 1 1 1 1 1]
[2 1 3 1 4 0 3 2 1 2 4]
[0 0 0 4 0 4 0 4 0 2 3 4]
[3 4 2 1 2 4]
[2 0 0 0 0 4 4 0 4]
[1 0 2 3 0 0 4 3 4 4]
[1 4 3 0 4 1]
[1 4 0 0 3 1 2 4 1]
[1 3 4 2 2 1 3 2 3 3 2 1]
[0 4 4 0 1 4]
[2 2 2 2 2 4 1 0 2 1 1 2 0 4]
[3 4 4 3 2 2 2 0 0 2 0 2 4 1 0 3 4]
[3 1 0 0 1 0 2 3 4 0 2 3 3 0 0 0 0 1 3 1 1 3 3 3 3 0 4 0 1 1]
[4 0 4 1 3 2 2 1]
[1 2 3 1 2 1 4]
[2 3 3 3 1 3 1 3 2 4 3 3 2 3 1]
[0 3 1 0 4 1 4]
[4 0 4 0 4 0 2 4]
[4 2 0 2 0 1 4 1 4]
[3 2 3 2 1 2 4 0 1 3 4]
[4 4 2 1]
[0 3 4 4 0 4]
[2 1 4 4]
[2 2 2 0 1 1 4 4]
[3 3 4 3 3 3 1 2 4]
[2 4 3 0 2 1 1 1 4]
[4 4 2 0 2 4]


### Load offline data and train

In [13]:
### Use trained_corridor_policy_dataset_10000_online.h5 which actualy has 6000 samples
### This was generated from a model that predicted the correct actions at each step
#dataset_trained = dataset.load("trained_corridor_policy_dataset_10000_online.h5")
from sklearn.model_selection import train_test_split

dataset_trained = dataset.load("trained_corridor_policy_dataset_10000_collected.h5")
train_episodes, test_episodes = train_test_split(dataset_trained, test_size=0.3)

In [17]:
from d3rlpy.algos import DQN
from d3rlpy.metrics.scorer import td_error_scorer
from d3rlpy.metrics.scorer import average_value_estimation_scorer
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.metrics.scorer import discounted_sum_of_advantage_scorer

# if you don't use GPU, set use_gpu=False instead.
dqn1 = DQN(batch_size=32,
          learning_rate=2.5e-4,
          target_update_interval=100,
          use_gpu=False)

# initialize neural networks with the given observation shape and action size.
# this is not necessary when you directly call fit or fit_online method.
dqn1.build_with_dataset(dataset_trained)

# calculate metrics with test dataset
td_error = td_error_scorer(dqn1, test_episodes)

# set environment in scorer function, returns mean return per episode
evaluate_scorer = evaluate_on_environment(env)

# evaluate algorithm on the environment
rewards = evaluate_scorer(dqn1)

# Use the output of learning, this is the mean reward and the most useful metric to measure learning
dqn1.fit(train_episodes,
        eval_episodes=test_episodes,
        n_epochs=10,
        scorers={
            'td_error': td_error_scorer,
            'advantage': discounted_sum_of_advantage_scorer,
            'value_scale': average_value_estimation_scorer,
            'environment': evaluate_scorer
        })

2023-03-15 13:45:54 [debug    ] RoundIterator is selected.
2023-03-15 13:45:54 [info     ] Directory is created at d3rlpy_logs/DQN_20230315134554
2023-03-15 13:45:54 [info     ] Parameters are saved to d3rlpy_logs/DQN_20230315134554/params.json params={'action_scaler': None, 'batch_size': 32, 'encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'gamma': 0.99, 'generated_maxlen': 100000, 'learning_rate': 0.00025, 'n_critics': 1, 'n_frames': 1, 'n_steps': 1, 'optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'q_func_factory': {'type': 'mean', 'params': {'share_encoder': False}}, 'real_ratio': 1.0, 'reward_scaler': None, 'scaler': None, 'target_update_interval': 100, 'use_gpu': None, 'algorithm': 'DQN', 'observation_shape': (1,), 'action_size': 5}


Epoch 1/10:   0%|          | 0/218 [00:00<?, ?it/s]

2023-03-15 13:46:03 [info     ] DQN_20230315134554: epoch=1 step=218 epoch=1 metrics={'time_sample_batch': 0.00010271684839091169, 'time_algorithm_update': 0.0027943313668627257, 'loss': 0.04616735788829061, 'time_step': 0.002965638396936819, 'td_error': 0.3232573659145911, 'advantage': -1.8548000507347877, 'value_scale': 2.4893798647661454, 'environment': -99.49999999999649} step=218
2023-03-15 13:46:03 [info     ] Model parameters are saved to d3rlpy_logs/DQN_20230315134554/model_218.pt


Epoch 2/10:   0%|          | 0/218 [00:00<?, ?it/s]

2023-03-15 13:46:12 [info     ] DQN_20230315134554: epoch=2 step=436 epoch=2 metrics={'time_sample_batch': 0.00010560958757313019, 'time_algorithm_update': 0.0028960573563881972, 'loss': 0.009923529355821985, 'time_step': 0.0030747170842021975, 'td_error': 0.007370942917948001, 'advantage': -0.34069460412132463, 'value_scale': 2.286115983439589, 'environment': -99.49999999999649} step=436
2023-03-15 13:46:12 [info     ] Model parameters are saved to d3rlpy_logs/DQN_20230315134554/model_436.pt


Epoch 3/10:   0%|          | 0/218 [00:00<?, ?it/s]

2023-03-15 13:46:14 [info     ] DQN_20230315134554: epoch=3 step=654 epoch=3 metrics={'time_sample_batch': 9.955725538621255e-05, 'time_algorithm_update': 0.002808462589158924, 'loss': 0.0022977970143616706, 'time_step': 0.0029764262908095612, 'td_error': 0.0037947744771046213, 'advantage': -0.344099419051331, 'value_scale': 2.2781946193221887, 'environment': 2.3999999999999995} step=654
2023-03-15 13:46:14 [info     ] Model parameters are saved to d3rlpy_logs/DQN_20230315134554/model_654.pt


Epoch 4/10:   0%|          | 0/218 [00:00<?, ?it/s]

2023-03-15 13:46:17 [info     ] DQN_20230315134554: epoch=4 step=872 epoch=4 metrics={'time_sample_batch': 9.886934122908006e-05, 'time_algorithm_update': 0.0029269697469308836, 'loss': 0.001359577327894598, 'time_step': 0.0030917016738051667, 'td_error': 0.0032065449443952725, 'advantage': -0.41444467465847895, 'value_scale': 2.2627561591632688, 'environment': 2.3999999999999995} step=872
2023-03-15 13:46:17 [info     ] Model parameters are saved to d3rlpy_logs/DQN_20230315134554/model_872.pt


Epoch 5/10:   0%|          | 0/218 [00:00<?, ?it/s]

2023-03-15 13:46:19 [info     ] DQN_20230315134554: epoch=5 step=1090 epoch=5 metrics={'time_sample_batch': 9.465655055614787e-05, 'time_algorithm_update': 0.002856420814444166, 'loss': 0.0010565040684681613, 'time_step': 0.0030117811412986267, 'td_error': 0.0049247810423745375, 'advantage': -0.4667467900519489, 'value_scale': 2.2817097056822284, 'environment': 2.3999999999999995} step=1090
2023-03-15 13:46:19 [info     ] Model parameters are saved to d3rlpy_logs/DQN_20230315134554/model_1090.pt


Epoch 6/10:   0%|          | 0/218 [00:00<?, ?it/s]

2023-03-15 13:46:22 [info     ] DQN_20230315134554: epoch=6 step=1308 epoch=6 metrics={'time_sample_batch': 9.603128520720596e-05, 'time_algorithm_update': 0.0028364220890430137, 'loss': 0.0008738680189541368, 'time_step': 0.002999278383517484, 'td_error': 0.006125795937536988, 'advantage': -0.41815833966817456, 'value_scale': 2.2392286623903574, 'environment': 2.3999999999999995} step=1308
2023-03-15 13:46:22 [info     ] Model parameters are saved to d3rlpy_logs/DQN_20230315134554/model_1308.pt


Epoch 7/10:   0%|          | 0/218 [00:00<?, ?it/s]

2023-03-15 13:46:25 [info     ] DQN_20230315134554: epoch=7 step=1526 epoch=7 metrics={'time_sample_batch': 0.00010264138562963643, 'time_algorithm_update': 0.002951959951208272, 'loss': 0.00048234424151597947, 'time_step': 0.0031232844798936756, 'td_error': 0.0021500588008478293, 'advantage': -0.3710322117498041, 'value_scale': 2.231515239739108, 'environment': 2.3999999999999995} step=1526
2023-03-15 13:46:25 [info     ] Model parameters are saved to d3rlpy_logs/DQN_20230315134554/model_1526.pt


Epoch 8/10:   0%|          | 0/218 [00:00<?, ?it/s]

2023-03-15 13:46:33 [info     ] DQN_20230315134554: epoch=8 step=1744 epoch=8 metrics={'time_sample_batch': 9.90377653629408e-05, 'time_algorithm_update': 0.0028923432761376057, 'loss': 0.0005771722624041279, 'time_step': 0.003059131289840838, 'td_error': 0.001669326893372673, 'advantage': -0.3372361128440092, 'value_scale': 2.2373080023847924, 'environment': -99.74999999999648} step=1744
2023-03-15 13:46:33 [info     ] Model parameters are saved to d3rlpy_logs/DQN_20230315134554/model_1744.pt


Epoch 9/10:   0%|          | 0/218 [00:00<?, ?it/s]

2023-03-15 13:46:36 [info     ] DQN_20230315134554: epoch=9 step=1962 epoch=9 metrics={'time_sample_batch': 9.855655355191012e-05, 'time_algorithm_update': 0.0029158187568734547, 'loss': 0.00022660944185173913, 'time_step': 0.003082862687767099, 'td_error': 0.0002133706158833212, 'advantage': -0.31592696386394836, 'value_scale': 2.21144193273737, 'environment': 2.3999999999999995} step=1962
2023-03-15 13:46:36 [info     ] Model parameters are saved to d3rlpy_logs/DQN_20230315134554/model_1962.pt


Epoch 10/10:   0%|          | 0/218 [00:00<?, ?it/s]

2023-03-15 13:46:39 [info     ] DQN_20230315134554: epoch=10 step=2180 epoch=10 metrics={'time_sample_batch': 0.00010371208190917969, 'time_algorithm_update': 0.0031040130405251038, 'loss': 0.0001034649562157988, 'time_step': 0.0032796509768984734, 'td_error': 0.00030494066015425273, 'advantage': -0.3019286309389186, 'value_scale': 2.207078338742614, 'environment': 2.3999999999999995} step=2180
2023-03-15 13:46:39 [info     ] Model parameters are saved to d3rlpy_logs/DQN_20230315134554/model_2180.pt


[(1,
  {'time_sample_batch': 0.00010271684839091169,
   'time_algorithm_update': 0.0027943313668627257,
   'loss': 0.04616735788829061,
   'time_step': 0.002965638396936819,
   'td_error': 0.3232573659145911,
   'advantage': -1.8548000507347877,
   'value_scale': 2.4893798647661454,
   'environment': -99.49999999999649}),
 (2,
  {'time_sample_batch': 0.00010560958757313019,
   'time_algorithm_update': 0.0028960573563881972,
   'loss': 0.009923529355821985,
   'time_step': 0.0030747170842021975,
   'td_error': 0.007370942917948001,
   'advantage': -0.34069460412132463,
   'value_scale': 2.286115983439589,
   'environment': -99.49999999999649}),
 (3,
  {'time_sample_batch': 9.955725538621255e-05,
   'time_algorithm_update': 0.002808462589158924,
   'loss': 0.0022977970143616706,
   'time_step': 0.0029764262908095612,
   'td_error': 0.0037947744771046213,
   'advantage': -0.344099419051331,
   'value_scale': 2.2781946193221887,
   'environment': 2.3999999999999995}),
 (4,
  {'time_sample_

In [20]:
evaluate_mean_episode_return(dqn1, env)

2.4

In [18]:
predict_actions_for_state(dqn1)

4
1
4
1
1


### Use CQL

In [19]:
from d3rlpy.dataset import MDPDataset
from d3rlpy.algos import DiscreteCQL
from d3rlpy.datasets import get_cartpole
from sklearn.model_selection import train_test_split
from d3rlpy.metrics.scorer import discounted_sum_of_advantage_scorer
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.metrics.scorer import td_error_scorer
from d3rlpy.metrics.scorer import average_value_estimation_scorer

# setup CQL algorithm (discrete version)
cql = DiscreteCQL(n_frames=4, use_gpu=False)

# start training
cql.fit(train_episodes,
        eval_episodes=test_episodes,
        n_epochs=10,
        scorers={
            'environment': evaluate_on_environment(env), # Cartpole environment
            'advantage': discounted_sum_of_advantage_scorer, # smaller is better
            'td_error': td_error_scorer, # smaller is better
            'value_scale': average_value_estimation_scorer # smaller is better
        })

2023-03-15 13:56:44 [debug    ] RoundIterator is selected.
2023-03-15 13:56:44 [info     ] Directory is created at d3rlpy_logs/DiscreteCQL_20230315135644
2023-03-15 13:56:44 [debug    ] Building models...
2023-03-15 13:56:44 [debug    ] Models have been built.
2023-03-15 13:56:44 [info     ] Parameters are saved to d3rlpy_logs/DiscreteCQL_20230315135644/params.json params={'action_scaler': None, 'alpha': 1.0, 'batch_size': 32, 'encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'gamma': 0.99, 'generated_maxlen': 100000, 'learning_rate': 6.25e-05, 'n_critics': 1, 'n_frames': 4, 'n_steps': 1, 'optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'q_func_factory': {'type': 'mean', 'params': {'share_encoder': False}}, 'real_ratio': 1.0, 'reward_scaler': None, 'scaler': None, 'target_update_interval': 8000, 'use_gpu': None, 'algorithm': 'DiscreteCQL', 'observation

Epoch 1/10:   0%|          | 0/218 [00:00<?, ?it/s]

2023-03-15 13:56:47 [info     ] DiscreteCQL_20230315135644: epoch=1 step=218 epoch=1 metrics={'time_sample_batch': 0.00010322868277173524, 'time_algorithm_update': 0.00437508373085512, 'loss': 1.544770058141936, 'time_step': 0.004544807136605639, 'environment': 2.3999999999999995, 'advantage': -2.5090578793870866, 'td_error': 0.856239937198643, 'value_scale': 0.880070686648392} step=218
2023-03-15 13:56:47 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20230315135644/model_218.pt


Epoch 2/10:   0%|          | 0/218 [00:00<?, ?it/s]

2023-03-15 13:56:51 [info     ] DiscreteCQL_20230315135644: epoch=2 step=436 epoch=2 metrics={'time_sample_batch': 0.00010435624953803667, 'time_algorithm_update': 0.004549850017652599, 'loss': 1.4747750578670327, 'time_step': 0.004713061752669308, 'environment': 2.3999999999999995, 'advantage': -2.5003623050301313, 'td_error': 0.9662334391609376, 'value_scale': 0.8541801982137234} step=436
2023-03-15 13:56:51 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20230315135644/model_436.pt


Epoch 3/10:   0%|          | 0/218 [00:00<?, ?it/s]

2023-03-15 13:56:54 [info     ] DiscreteCQL_20230315135644: epoch=3 step=654 epoch=3 metrics={'time_sample_batch': 9.817595875591314e-05, 'time_algorithm_update': 0.004445455489902321, 'loss': 1.4605473405724272, 'time_step': 0.004599369993997277, 'environment': 2.3999999999999995, 'advantage': -2.3171698987098166, 'td_error': 0.9785533055632676, 'value_scale': 0.8253534857001095} step=654
2023-03-15 13:56:54 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20230315135644/model_654.pt


Epoch 4/10:   0%|          | 0/218 [00:00<?, ?it/s]

2023-03-15 13:56:57 [info     ] DiscreteCQL_20230315135644: epoch=4 step=872 epoch=4 metrics={'time_sample_batch': 9.516072929452319e-05, 'time_algorithm_update': 0.0042372438885750026, 'loss': 1.4572687313097332, 'time_step': 0.004385384944600797, 'environment': 2.3999999999999995, 'advantage': -2.3267922089023205, 'td_error': 1.0172885564684113, 'value_scale': 0.7888513813121353} step=872
2023-03-15 13:56:57 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20230315135644/model_872.pt


Epoch 5/10:   0%|          | 0/218 [00:00<?, ?it/s]

2023-03-15 13:56:59 [info     ] DiscreteCQL_20230315135644: epoch=5 step=1090 epoch=5 metrics={'time_sample_batch': 9.36044465511217e-05, 'time_algorithm_update': 0.004180771495224139, 'loss': 1.4517310835899564, 'time_step': 0.004331515469682326, 'environment': 2.3999999999999995, 'advantage': -2.3650106837628333, 'td_error': 1.0795315006971744, 'value_scale': 0.8329517482697626} step=1090
2023-03-15 13:56:59 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20230315135644/model_1090.pt


Epoch 6/10:   0%|          | 0/218 [00:00<?, ?it/s]

2023-03-15 13:57:02 [info     ] DiscreteCQL_20230315135644: epoch=6 step=1308 epoch=6 metrics={'time_sample_batch': 9.772427585146843e-05, 'time_algorithm_update': 0.004191802182328811, 'loss': 1.4481568823166944, 'time_step': 0.0043426303688539275, 'environment': 2.3999999999999995, 'advantage': -2.480090078419778, 'td_error': 1.1618752905794008, 'value_scale': 0.8556614027666962} step=1308
2023-03-15 13:57:02 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20230315135644/model_1308.pt


Epoch 7/10:   0%|          | 0/218 [00:00<?, ?it/s]

2023-03-15 13:57:06 [info     ] DiscreteCQL_20230315135644: epoch=7 step=1526 epoch=7 metrics={'time_sample_batch': 0.00010789862466514657, 'time_algorithm_update': 0.0047841039272623325, 'loss': 1.4426602015801526, 'time_step': 0.004953367994465959, 'environment': 2.3999999999999995, 'advantage': -2.4403183947149647, 'td_error': 1.1244752856071452, 'value_scale': 0.835108272974771} step=1526
2023-03-15 13:57:06 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20230315135644/model_1526.pt


Epoch 8/10:   0%|          | 0/218 [00:00<?, ?it/s]

2023-03-15 13:57:09 [info     ] DiscreteCQL_20230315135644: epoch=8 step=1744 epoch=8 metrics={'time_sample_batch': 0.0001051119708139962, 'time_algorithm_update': 0.004752359259019204, 'loss': 1.439144038279122, 'time_step': 0.004920841357029906, 'environment': 2.3999999999999995, 'advantage': -2.34811731371349, 'td_error': 1.0744722281786758, 'value_scale': 0.8117563516800902} step=1744
2023-03-15 13:57:09 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20230315135644/model_1744.pt


Epoch 9/10:   0%|          | 0/218 [00:00<?, ?it/s]

2023-03-15 13:57:12 [info     ] DiscreteCQL_20230315135644: epoch=9 step=1962 epoch=9 metrics={'time_sample_batch': 9.022065258901054e-05, 'time_algorithm_update': 0.004113776968159807, 'loss': 1.4366861133400453, 'time_step': 0.004258794522066729, 'environment': 2.3999999999999995, 'advantage': -2.4919205157610853, 'td_error': 1.1656133821392485, 'value_scale': 0.8270141491582589} step=1962
2023-03-15 13:57:12 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20230315135644/model_1962.pt


Epoch 10/10:   0%|          | 0/218 [00:00<?, ?it/s]

2023-03-15 13:57:15 [info     ] DiscreteCQL_20230315135644: epoch=10 step=2180 epoch=10 metrics={'time_sample_batch': 9.795175779850111e-05, 'time_algorithm_update': 0.004405923939626151, 'loss': 1.4329861752483823, 'time_step': 0.004562744306861807, 'environment': 2.3999999999999995, 'advantage': -2.4479490878155725, 'td_error': 1.1254472196108043, 'value_scale': 0.8244722468449935} step=2180
2023-03-15 13:57:15 [info     ] Model parameters are saved to d3rlpy_logs/DiscreteCQL_20230315135644/model_2180.pt


[(1,
  {'time_sample_batch': 0.00010322868277173524,
   'time_algorithm_update': 0.00437508373085512,
   'loss': 1.544770058141936,
   'time_step': 0.004544807136605639,
   'environment': 2.3999999999999995,
   'advantage': -2.5090578793870866,
   'td_error': 0.856239937198643,
   'value_scale': 0.880070686648392}),
 (2,
  {'time_sample_batch': 0.00010435624953803667,
   'time_algorithm_update': 0.004549850017652599,
   'loss': 1.4747750578670327,
   'time_step': 0.004713061752669308,
   'environment': 2.3999999999999995,
   'advantage': -2.5003623050301313,
   'td_error': 0.9662334391609376,
   'value_scale': 0.8541801982137234}),
 (3,
  {'time_sample_batch': 9.817595875591314e-05,
   'time_algorithm_update': 0.004445455489902321,
   'loss': 1.4605473405724272,
   'time_step': 0.004599369993997277,
   'environment': 2.3999999999999995,
   'advantage': -2.3171698987098166,
   'td_error': 0.9785533055632676,
   'value_scale': 0.8253534857001095}),
 (4,
  {'time_sample_batch': 9.51607292

In [22]:
predict_actions_for_state(cql)

4
1
4
4
4
