In [1]:
import gymnasium as gym
from gymnasium import spaces, Space

import random
import pickle
import numpy as np
from tqdm import tqdm
import time
import d3rlpy

In [2]:
np.random.seed(42)
random.seed(42)

### Discrete Actions

In [3]:
class CustomActionSpace(Space):
    def __init__(self, shape=None, dtype=None):
        super().__init__(shape, dtype)
        actions = np.arange(0.5, 5.5, 0.5)
        self.actions_map = {idx: action for idx, action in enumerate(actions)}
        self.actions = list(self.actions_map.keys())
    
class MovieLensEnv(gym.Env):
    
    def __init__(self, data, use_prev_temp_as_feature=False, van_specific_embeddings=None, pbar=None):
        # print("__init__ method")
        # with open('../gym/data/mlens/mlens-test-trajectories-v1.pkl', 'rb') as f:
        # with open(test_traj_path, 'rb') as f:
        self.dataset = data

        super(MovieLensEnv, self).__init__()
        actions = np.arange(0.5, 5.5, 0.5)
        self.actions_map = {idx: action for idx, action in enumerate(actions)}
        self.current_step = 0
        self.max_steps = sum(len(traj['observations']) for traj in self.dataset)
        self.action_space = spaces.Discrete(10)  # You need to define CustomActionSpace
        self.observation_space = spaces.Box(low=0, high=1, shape=(self.dataset[0]['observations'].shape[1],), dtype=np.float32)
        self.sampled_idx = None
        self.action = None
        self.reward = None
        self.pbar = pbar
        self.total_steps = 0
        self.use_prev_temp = use_prev_temp_as_feature
        # self.idx_of_prev_temp_feat = np.where(self.dataset[0]['features'] == 'd_prev_target_temp')[0][0]
        self.personalized_features = van_specific_embeddings

    def step(self, action):
        self.action = action
        target_rating = self.dataset[self.sampled_idx]['targets'][self.current_step]
        # print(f"action taken in step: {action}")
        # print(f"type of action: {type(action)}")
        pred_rating = self.actions_map[action]
        # print(f"pred_rating: {pred_rating}")
        
        # print(f"action: {self.action} | pred_rating: {pred_rating} | original_rating: {target_rating}")
        acc = 0
        if pred_rating == target_rating:
            acc = 1

        # Rewards scheme 5
        # -------------------------------
        # error = abs(target_rating - pred_rating)
        # self.reward = (1- (error / 4.5)) ** 2
        # # -------------------------------
        
        # Binary Rewards scheme 
        # -------------------------------
        if target_rating >= 3.5 and pred_rating >= 3.5:
            self.reward = 1
        else:
            self.reward = 0
        
        # # -------------------------------
        # # Reward for special cases
        # if target_rating != pred_rating:
        #     special_reward = reward
        # else:
        #     special_reward = 0
        
        done = False
        
        # if self.pbar is not None:
        #     self.pbar.set_description(f"(idx, step): ({self.sampled_idx}, {self.current_step}) | True rating: {target_rating} | Predicted rating: {pred_rating} | reward: {self.reward:.2f}")
        #     # time.sleep(0.25)
        self.current_step += 1
        obs, done = self._next_observation()
        self.total_steps += 1
        # return obs, self.reward, done, acc, target_rating, pred_rating, self.total_steps
        return obs, self.reward, done, None, {}
    
    def reset(self, seed=None):
        self.sampled_idx = random.randint(0, len(self.dataset) - 1)
        self.current_step = 0
        traj = self.dataset[self.sampled_idx]
        user_id = traj['user_id']


        obs = traj['observations'][self.current_step]

        if self.personalized_features is not None:
            obs = np.hstack((obs, self.personalized_features[user_id]))
        
        return obs, None
    
    def _next_observation(self):
        if self.dataset[self.sampled_idx]['terminals'][self.current_step]:
            done = True
            obs, _ = self.reset()
            return obs, done
        
        traj = self.dataset[self.sampled_idx]
        user_id = traj['user_id']
        obs = traj['observations'][self.current_step]
        if self.personalized_features is not None:
            obs = np.hstack((obs, self.personalized_features[van_id]))
        done = False
        return obs, done

    def eval(self):
        self.training = False
        
    def get_true_temperature(self):
        target_temperature = self.dataset[self.sampled_idx]['actions'][self.current_step]
        target_temperature = self.actions_map[target_temperature]
        return target_temperature
        


In [4]:
from copy import deepcopy
with open("../data/dt-datasets/movielens/processed-data/all_trajectories_with_concatenated_movname_genres_tags_userid_reward_of_scale_5.pkl", 'rb') as f:
    all_trajectories = pickle.load(f)
# all_trajs_copy = deepcopy(all_trajectories)

In [5]:
# Calculate the size for the training set
np.random.seed(42)
trajectories = all_trajectories
indices = {i for i in range(len(trajectories))}
train_indices = list(np.random.choice(list(indices), size=round(0.7*len(indices)), replace=False))
remaining_indices = indices.difference(train_indices)
test_indices = remaining_indices

print(f"total train users: {len(train_indices)}")
print(f"total test users: {len(test_indices)}")

train_trajectories = [trajectories[idx]for idx in train_indices]
test_trajectories = [trajectories[idx]for idx in test_indices]

print("Train set:", len(train_trajectories))
print("Test set:", len(test_trajectories))

total train users: 427
total test users: 183
Train set: 427
Test set: 183


#### Reward scheme 3: Binary reward: 1 or 0; 1 if liked 0 if not

In [6]:
train_trajectories_copy = deepcopy(train_trajectories)
threshold = 3.5
for traj in train_trajectories_copy:
    name_and_genre_embeds = traj['observations'][:, 0:768] + traj['observations'][:, 768:2*768]
    traj['observations'] = np.concatenate((name_and_genre_embeds, traj['observations'][:, 3*768:]), axis=1)
    # print(traj['observations'].shape)
    # errors = abs(highest_rating - traj['targets'])
    traj['rewards'] = (traj['targets'] >= threshold).astype(int)


test_trajectories_copy = deepcopy(test_trajectories)
for traj in test_trajectories_copy:
    name_and_genre_embeds = traj['observations'][:, 0:768] + traj['observations'][:, 768:2*768]
    traj['observations'] = np.concatenate((name_and_genre_embeds, traj['observations'][:, 3*768:]), axis=1)
    # print(traj['observations'].shape)
    # errors = abs(highest_rating - traj['targets'])
    traj['rewards'] = (traj['targets'] >= threshold).astype(int)

train_data_with_binary_rewards = train_trajectories_copy
test_data_with_binary_rewards = test_trajectories_copy

with open('data/train_data_with_binary_rewards.pkl', 'wb') as f:
    pickle.dump(train_data_with_binary_rewards, f)

with open('data/test_data_with_binary_rewards.pkl', 'wb') as f:
    pickle.dump(test_data_with_binary_rewards, f)

In [7]:
# Data preparation for Discrete CQL
import numpy as np
observations_mlens = []
observations_mlens = np.concatenate([ep['observations'] for ep in train_data_with_binary_rewards])
actions_mlens = np.concatenate([ep['actions'] for ep in train_data_with_binary_rewards])
rewards_mlens = np.concatenate([ep['rewards'] for ep in train_data_with_binary_rewards])
terminals_mlens = np.concatenate([ep['terminals'] for ep in train_data_with_binary_rewards])

timeouts = None

In [8]:
from d3rlpy.dataset import EpisodeGenerator

episode_generator = EpisodeGenerator(
    observations=observations_mlens,
    actions=actions_mlens,
    rewards=rewards_mlens,
    terminals=terminals_mlens,
    timeouts=timeouts,
)

episodes_generated_mlens = episode_generator()

In [9]:
from d3rlpy.dataset import ReplayBuffer, InfiniteBuffer

dataset = ReplayBuffer(
    InfiniteBuffer(),
    episodes=episodes_generated_mlens,
    transition_picker=None,
    trajectory_slicer=None,
)

[2m2024-02-11 13:27.15[0m [[32m[1minfo     [0m] [1mSignatures have been automatically determined.[0m [36maction_signature[0m=[35mSignature(dtype=[dtype('int64')], shape=[(1,)])[0m [36mobservation_signature[0m=[35mSignature(dtype=[dtype('float32')], shape=[(800,)])[0m [36mreward_signature[0m=[35mSignature(dtype=[dtype('int64')], shape=[(1,)])[0m
[2m2024-02-11 13:27.15[0m [[32m[1minfo     [0m] [1mAction-space has been automatically determined.[0m [36maction_space[0m=[35m<ActionSpace.DISCRETE: 2>[0m
[2m2024-02-11 13:27.15[0m [[32m[1minfo     [0m] [1mAction size has been automatically determined.[0m [36maction_size[0m=[35m10[0m


In [10]:
dataset.episodes[0].observations.shape

(227, 800)

In [11]:
env = MovieLensEnv(test_data_with_binary_rewards)

#### Discrete CQL

In [12]:
# # start training
# cql = d3rlpy.algos.DiscreteCQLConfig().create(device='cuda')
# cql.fit(
#     dataset,
#     n_steps=10000,
#     n_steps_per_epoch=1000,
#     evaluators={
#         'environment': d3rlpy.metrics.EnvironmentEvaluator(env),
#     },
# )

# # evaluate
# rewards = []
# for _ in range(10):
#     reward = d3rlpy.metrics.evaluate_qlearning_with_environment(cql, env)
#     rewards.append(reward)
# # print(np.round(rewards, 2))
    
# for r in np.round(rewards, 2):
#     print(r)

#### DQN

In [15]:
# start training
dqn = d3rlpy.algos.DQNConfig().create(device='cuda')
dqn.fit(
    dataset,
    n_steps=10000,
    n_steps_per_epoch=1000,
    evaluators={
        'environment': d3rlpy.metrics.EnvironmentEvaluator(env),
    },
)

# evaluate
rewards = []
for _ in range(10):
    reward = d3rlpy.metrics.evaluate_qlearning_with_environment(dqn, env)
    rewards.append(reward)
# print(np.round(rewards, 2))
    
for r in np.round(rewards, 2):
    print(r)

[2m2024-02-11 13:28.55[0m [[32m[1minfo     [0m] [1mdataset info                  [0m [36mdataset_info[0m=[35mDatasetInfo(observation_signature=Signature(dtype=[dtype('float32')], shape=[(800,)]), action_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), reward_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), action_space=<ActionSpace.DISCRETE: 2>, action_size=10)[0m
[2m2024-02-11 13:28.55[0m [[32m[1minfo     [0m] [1mDirectory is created at d3rlpy_logs/DQN_20240211132855[0m
[2m2024-02-11 13:28.55[0m [[32m[1mdebug    [0m] [1mBuilding models...            [0m
[2m2024-02-11 13:28.55[0m [[32m[1mdebug    [0m] [1mModels have been built.       [0m
[2m2024-02-11 13:28.55[0m [[32m[1minfo     [0m] [1mParameters                    [0m [36mparams[0m=[35m{'observation_shape': [800], 'action_size': 10, 'config': {'type': 'dqn', 'params': {'batch_size': 32, 'gamma': 0.99, 'observation_scaler': {'type': 'none', 'params': {}}, 'action_scaler': 

Epoch 1/10:   0%|          | 0/1000 [00:00<?, ?it/s]

[2m2024-02-11 13:29.02[0m [[32m[1minfo     [0m] [1mDQN_20240211132855: epoch=1 step=1000[0m [36mepoch[0m=[35m1[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001644219398498535, 'time_algorithm_update': 0.003269291639328003, 'loss': 0.4139505652189255, 'time_step': 0.004974134922027588, 'environment': 54.770023420049824}[0m [36mstep[0m=[35m1000[0m
[2m2024-02-11 13:29.02[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/DQN_20240211132855/model_1000.d3[0m


Epoch 2/10:   0%|          | 0/1000 [00:00<?, ?it/s]

[2m2024-02-11 13:29.07[0m [[32m[1minfo     [0m] [1mDQN_20240211132855: epoch=2 step=2000[0m [36mepoch[0m=[35m2[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0016619021892547608, 'time_algorithm_update': 0.003310084581375122, 'loss': 0.35927419032156466, 'time_step': 0.005024294853210449, 'environment': 66.85420451869669}[0m [36mstep[0m=[35m2000[0m
[2m2024-02-11 13:29.07[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/DQN_20240211132855/model_2000.d3[0m


Epoch 3/10:   0%|          | 0/1000 [00:00<?, ?it/s]

[2m2024-02-11 13:29.12[0m [[32m[1minfo     [0m] [1mDQN_20240211132855: epoch=3 step=3000[0m [36mepoch[0m=[35m3[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0016223554611206055, 'time_algorithm_update': 0.003258695125579834, 'loss': 0.34989781664311886, 'time_step': 0.004934022903442383, 'environment': 72.180401167175}[0m [36mstep[0m=[35m3000[0m
[2m2024-02-11 13:29.12[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/DQN_20240211132855/model_3000.d3[0m


Epoch 4/10:   0%|          | 0/1000 [00:00<?, ?it/s]

[2m2024-02-11 13:29.18[0m [[32m[1minfo     [0m] [1mDQN_20240211132855: epoch=4 step=4000[0m [36mepoch[0m=[35m4[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0016711032390594482, 'time_algorithm_update': 0.0033316402435302735, 'loss': 0.34295177841186525, 'time_step': 0.005059114694595337, 'environment': 76.15577479369787}[0m [36mstep[0m=[35m4000[0m
[2m2024-02-11 13:29.18[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/DQN_20240211132855/model_4000.d3[0m


Epoch 5/10:   0%|          | 0/1000 [00:00<?, ?it/s]

[2m2024-02-11 13:29.23[0m [[32m[1minfo     [0m] [1mDQN_20240211132855: epoch=5 step=5000[0m [36mepoch[0m=[35m5[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001614142656326294, 'time_algorithm_update': 0.003193758726119995, 'loss': 0.337283948764205, 'time_step': 0.004860852956771851, 'environment': 56.21985578277819}[0m [36mstep[0m=[35m5000[0m
[2m2024-02-11 13:29.23[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/DQN_20240211132855/model_5000.d3[0m


Epoch 6/10:   0%|          | 0/1000 [00:00<?, ?it/s]

[2m2024-02-11 13:29.29[0m [[32m[1minfo     [0m] [1mDQN_20240211132855: epoch=6 step=6000[0m [36mepoch[0m=[35m6[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0016687748432159424, 'time_algorithm_update': 0.0032659831047058105, 'loss': 0.3398749498873949, 'time_step': 0.004987427473068237, 'environment': 48.57794643683288}[0m [36mstep[0m=[35m6000[0m
[2m2024-02-11 13:29.29[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/DQN_20240211132855/model_6000.d3[0m


Epoch 7/10:   0%|          | 0/1000 [00:00<?, ?it/s]

[2m2024-02-11 13:29.35[0m [[32m[1minfo     [0m] [1mDQN_20240211132855: epoch=7 step=7000[0m [36mepoch[0m=[35m7[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0015975892543792724, 'time_algorithm_update': 0.0032206614017486574, 'loss': 0.33341740249097346, 'time_step': 0.004870854139328003, 'environment': 59.759058723358315}[0m [36mstep[0m=[35m7000[0m
[2m2024-02-11 13:29.35[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/DQN_20240211132855/model_7000.d3[0m


Epoch 8/10:   0%|          | 0/1000 [00:00<?, ?it/s]

[2m2024-02-11 13:29.40[0m [[32m[1minfo     [0m] [1mDQN_20240211132855: epoch=8 step=8000[0m [36mepoch[0m=[35m8[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0016170566082000732, 'time_algorithm_update': 0.003234290361404419, 'loss': 0.3350099585056305, 'time_step': 0.004902865171432495, 'environment': 65.86207723587265}[0m [36mstep[0m=[35m8000[0m
[2m2024-02-11 13:29.40[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/DQN_20240211132855/model_8000.d3[0m


Epoch 9/10:   0%|          | 0/1000 [00:00<?, ?it/s]

[2m2024-02-11 13:29.46[0m [[32m[1minfo     [0m] [1mDQN_20240211132855: epoch=9 step=9000[0m [36mepoch[0m=[35m9[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0016836612224578858, 'time_algorithm_update': 0.003324442148208618, 'loss': 0.06315795130934566, 'time_step': 0.005070282220840454, 'environment': 63.933878117131826}[0m [36mstep[0m=[35m9000[0m
[2m2024-02-11 13:29.46[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/DQN_20240211132855/model_9000.d3[0m


Epoch 10/10:   0%|          | 0/1000 [00:00<?, ?it/s]

[2m2024-02-11 13:29.51[0m [[32m[1minfo     [0m] [1mDQN_20240211132855: epoch=10 step=10000[0m [36mepoch[0m=[35m10[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0017283413410186768, 'time_algorithm_update': 0.003416623830795288, 'loss': 0.05479653418343514, 'time_step': 0.005207204103469849, 'environment': 67.17517251775607}[0m [36mstep[0m=[35m10000[0m
[2m2024-02-11 13:29.51[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/DQN_20240211132855/model_10000.d3[0m
66.35
68.9
67.68
69.06
50.98
71.11
46.64
62.26
69.63
70.77


#### DDQN

In [14]:
# start training
ddqn = d3rlpy.algos.DoubleDQNConfig().create(device='cuda')
ddqn.fit(
    dataset,
    n_steps=10000,
    n_steps_per_epoch=1000,
    evaluators={
        'environment': d3rlpy.metrics.EnvironmentEvaluator(env),
    },
)

# evaluate
rewards = []
for _ in range(10):
    reward = d3rlpy.metrics.evaluate_qlearning_with_environment(ddqn, env)
    rewards.append(reward)
# print(np.round(rewards, 2))
    
for r in np.round(rewards, 2):
    print(r)

[2m2024-02-11 13:27.17[0m [[32m[1minfo     [0m] [1mdataset info                  [0m [36mdataset_info[0m=[35mDatasetInfo(observation_signature=Signature(dtype=[dtype('float32')], shape=[(800,)]), action_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), reward_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), action_space=<ActionSpace.DISCRETE: 2>, action_size=10)[0m
[2m2024-02-11 13:27.17[0m [[32m[1minfo     [0m] [1mDirectory is created at d3rlpy_logs/DoubleDQN_20240211132717[0m
[2m2024-02-11 13:27.17[0m [[32m[1mdebug    [0m] [1mBuilding models...            [0m
[2m2024-02-11 13:27.17[0m [[32m[1mdebug    [0m] [1mModels have been built.       [0m
[2m2024-02-11 13:27.17[0m [[32m[1minfo     [0m] [1mParameters                    [0m [36mparams[0m=[35m{'observation_shape': [800], 'action_size': 10, 'config': {'type': 'double_dqn', 'params': {'batch_size': 32, 'gamma': 0.99, 'observation_scaler': {'type': 'none', 'params': {}}, 'act

Epoch 1/10:   0%|          | 0/1000 [00:00<?, ?it/s]

[2m2024-02-11 13:27.23[0m [[32m[1minfo     [0m] [1mDoubleDQN_20240211132717: epoch=1 step=1000[0m [36mepoch[0m=[35m1[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0016401112079620361, 'time_algorithm_update': 0.0034022080898284913, 'loss': 0.45806796592473986, 'time_step': 0.005103318452835083, 'environment': 61.12494848765058}[0m [36mstep[0m=[35m1000[0m
[2m2024-02-11 13:27.23[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/DoubleDQN_20240211132717/model_1000.d3[0m


Epoch 2/10:   0%|          | 0/1000 [00:00<?, ?it/s]

[2m2024-02-11 13:27.29[0m [[32m[1minfo     [0m] [1mDoubleDQN_20240211132717: epoch=2 step=2000[0m [36mepoch[0m=[35m2[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0016365442276000976, 'time_algorithm_update': 0.003261301517486572, 'loss': 0.34124353030323984, 'time_step': 0.0049538564682006835, 'environment': 64.23304194764127}[0m [36mstep[0m=[35m2000[0m
[2m2024-02-11 13:27.29[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/DoubleDQN_20240211132717/model_2000.d3[0m


Epoch 3/10:   0%|          | 0/1000 [00:00<?, ?it/s]

[2m2024-02-11 13:27.34[0m [[32m[1minfo     [0m] [1mDoubleDQN_20240211132717: epoch=3 step=3000[0m [36mepoch[0m=[35m3[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0016311161518096924, 'time_algorithm_update': 0.0032949364185333252, 'loss': 0.3580922721847892, 'time_step': 0.004983624219894409, 'environment': 66.25463734213118}[0m [36mstep[0m=[35m3000[0m
[2m2024-02-11 13:27.34[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/DoubleDQN_20240211132717/model_3000.d3[0m


Epoch 4/10:   0%|          | 0/1000 [00:00<?, ?it/s]

[2m2024-02-11 13:27.40[0m [[32m[1minfo     [0m] [1mDoubleDQN_20240211132717: epoch=4 step=4000[0m [36mepoch[0m=[35m4[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0016907517910003662, 'time_algorithm_update': 0.0033993911743164062, 'loss': 0.3473753787204623, 'time_step': 0.005151326179504394, 'environment': 50.23706123935607}[0m [36mstep[0m=[35m4000[0m
[2m2024-02-11 13:27.40[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/DoubleDQN_20240211132717/model_4000.d3[0m


Epoch 5/10:   0%|          | 0/1000 [00:00<?, ?it/s]

[2m2024-02-11 13:27.46[0m [[32m[1minfo     [0m] [1mDoubleDQN_20240211132717: epoch=5 step=5000[0m [36mepoch[0m=[35m5[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0017340548038482667, 'time_algorithm_update': 0.0034393789768218993, 'loss': 0.31005658972263334, 'time_step': 0.0052343418598175045, 'environment': 73.60489603758148}[0m [36mstep[0m=[35m5000[0m
[2m2024-02-11 13:27.46[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/DoubleDQN_20240211132717/model_5000.d3[0m


Epoch 6/10:   0%|          | 0/1000 [00:00<?, ?it/s]

[2m2024-02-11 13:27.51[0m [[32m[1minfo     [0m] [1mDoubleDQN_20240211132717: epoch=6 step=6000[0m [36mepoch[0m=[35m6[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0016828913688659668, 'time_algorithm_update': 0.003341737508773804, 'loss': 0.28209242078661917, 'time_step': 0.005082767248153686, 'environment': 68.28980869199984}[0m [36mstep[0m=[35m6000[0m
[2m2024-02-11 13:27.51[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/DoubleDQN_20240211132717/model_6000.d3[0m


Epoch 7/10:   0%|          | 0/1000 [00:00<?, ?it/s]

[2m2024-02-11 13:27.57[0m [[32m[1minfo     [0m] [1mDoubleDQN_20240211132717: epoch=7 step=7000[0m [36mepoch[0m=[35m7[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0016415410041809083, 'time_algorithm_update': 0.0032884111404418944, 'loss': 0.2926566868647933, 'time_step': 0.0049905295372009275, 'environment': 73.46979155534862}[0m [36mstep[0m=[35m7000[0m
[2m2024-02-11 13:27.57[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/DoubleDQN_20240211132717/model_7000.d3[0m


Epoch 8/10:   0%|          | 0/1000 [00:00<?, ?it/s]

[2m2024-02-11 13:28.02[0m [[32m[1minfo     [0m] [1mDoubleDQN_20240211132717: epoch=8 step=8000[0m [36mepoch[0m=[35m8[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001708214521408081, 'time_algorithm_update': 0.003382113456726074, 'loss': 0.28276054468005896, 'time_step': 0.005149126529693604, 'environment': 66.55570588755165}[0m [36mstep[0m=[35m8000[0m
[2m2024-02-11 13:28.02[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/DoubleDQN_20240211132717/model_8000.d3[0m


Epoch 9/10:   0%|          | 0/1000 [00:00<?, ?it/s]

[2m2024-02-11 13:28.08[0m [[32m[1minfo     [0m] [1mDoubleDQN_20240211132717: epoch=9 step=9000[0m [36mepoch[0m=[35m9[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0017132432460784913, 'time_algorithm_update': 0.003343533992767334, 'loss': 0.019949583027279005, 'time_step': 0.005121784687042236, 'environment': 68.69245364081135}[0m [36mstep[0m=[35m9000[0m
[2m2024-02-11 13:28.08[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/DoubleDQN_20240211132717/model_9000.d3[0m


Epoch 10/10:   0%|          | 0/1000 [00:00<?, ?it/s]

[2m2024-02-11 13:28.15[0m [[32m[1minfo     [0m] [1mDoubleDQN_20240211132717: epoch=10 step=10000[0m [36mepoch[0m=[35m10[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0017888224124908447, 'time_algorithm_update': 0.0034689466953277586, 'loss': 0.00654171628144104, 'time_step': 0.005318193912506103, 'environment': 66.24656360124017}[0m [36mstep[0m=[35m10000[0m
[2m2024-02-11 13:28.15[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/DoubleDQN_20240211132717/model_10000.d3[0m
63.05
69.88
67.38
55.7
56.58
73.0
72.5
57.91
60.85
69.05


#### Continuous CQL

In [None]:
import d3rlpy

# prepare dataset
dataset, env = d3rlpy.datasets.get_d4rl('hopper-medium-v0')

# # prepare algorithm
# cql = d3rlpy.algos.CQLConfig().create(device='cuda:0')

# # train
# cql.fit(
#     dataset,
#     n_steps=100000,
#     evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
# )

In [None]:
env.step(0.45)

In [None]:
dataset.episodes[0].observations.shape

In [None]:
import d4rl

In [None]:
a = [1, 2, 3]


In [None]:
a[-2:]