In [5]:
import sys
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import os
# import wandb
from transformers import ViTMAEConfig

from custom_models.CustomViT import CustomViT
from custom_models.CustomViTMAE import CustomViTMAE
from transformers.models.vit_mae.modeling_vit_mae import ViTMAEModel
# from tem_dataloader import MultimodalDatasetPerTrajectory
from torch.utils.data import DataLoader

from d3rlpy.algos import CQL
from d3rlpy.dataset import Episode, MDPDataset, Transition
# wandb.login() 

ModuleNotFoundError: No module named 'wandb'

In [7]:
from custom_models.CustomViT import CustomViT
from custom_models.CustomViTMAE import CustomViTMAE
import torch
# call CustomViT
from transformers import AutoImageProcessor, ViTMAEForPreTraining, ViTMAEConfig
from PIL import Image

output_dir='/home/ubuntu/camelmera'
# trained_model_name = 'multimodal'
# output_dir='/home/ubuntu/weights/' + trained_model_name

# Initialize a new CustomViT model
model_name = "facebook/vit-mae-base"
vit_config = ViTMAEConfig.from_pretrained(model_name)
vit_config.output_hidden_states=True
vit_model = CustomViT(config=vit_config)

# Initialize a new CustomViTMAE model
model_name = "facebook/vit-mae-base"
config = ViTMAEConfig.from_pretrained(model_name)
config.output_hidden_states=True
custom_model = CustomViTMAE(config=config)
custom_model.vit = vit_model

# Load the state_dict from the saved model
state_dict = torch.load(f"{output_dir}/pytorch_model.bin")
custom_model.load_state_dict(state_dict)

# don't need decoders
vit_encoder = custom_model.vit

In [2]:
from tem_dataloader import MultimodalDataset
import functools

environment_name = 'AbandonedFactoryExposure'
environemnt_directory = f'/mnt/temp_mount/{environment_name}/Data_easy'
dataset = MultimodalDataset(environemnt_directory)

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
Processing folder: /mnt/data/tartanairv2filtered/AbandonedFactoryExposure/Data_easy/P001
Processing folder: /mnt/data/tartanairv2filtered/AbandonedFactoryExposure/Data_easy/P004
Processing folder: /mnt/data/tartanairv2filtered/AbandonedFactoryExposure/Data_easy/P006
Processing folder: /mnt/data/tartanairv2filtered/AbandonedFactoryExposure/Data_easy/P005
Processing folder: /mnt/data/tartanairv2filtered/AbandonedFactoryExposure/Data_easy/P008
Processing folder: /mnt/data/tartanairv2filtered/AbandonedFactoryExposure/Data_easy/P002
Processing folder: /mnt/data/tartanairv2filtered/AbandonedFactoryExposure/Data_easy/P009
Processing folder: /mnt/data/tartanairv2filtered/AbandonedFactoryExposure/Data_easy/P003
Processing folder: /mnt/data/tartanairv2filtered/AbandonedFactoryExposure/Data_easy/P000
Number of images: 8185
Number 

In [13]:
import numpy as np

def reward_function(state_embedding, goal_embedding, threshold=0.05, goal_reward=100):
    distance = np.linalg.norm(state_embedding - goal_embedding)

    if distance <= threshold:
        # Give a large positive reward when the goal is reached
        reward = goal_reward
    else:
        # Give a negative reward proportional to the distance otherwise
        reward = -distance

    return reward

In [14]:
from tem_dataloader import MultimodalDatasetPerTrajectory
import functools
import os
from torch.utils.data import Dataset, DataLoader

environment_name = 'AbandonedFactoryExposure'
environemnt_directory = f'/mnt/temp_mount/{environment_name}/Data_easy'
OBSERVATION_SIZE = 768
ACTION_SIZE = 7
BATCH_SIZE = 64

for i in range(0,10):
    if i==7:
        continue
    trajectory_folder_path = os.path.join(environemnt_directory, f'P00{i}')
    my_dataset = MultimodalDatasetPerTrajectory(trajectory_folder_path)
    train_dataloader = DataLoader(my_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Initialize empty arrays for observations, actions, rewards, and terminals
    all_observations = np.empty((0, OBSERVATION_SIZE))
    all_actions = np.empty((0, ACTION_SIZE))
    all_rewards = np.empty(0)
    all_terminals = np.empty(0, dtype=bool)

    for batch_idx, data in enumerate(train_dataloader):
        # get embedding
        vit_encoder.cuda()
        vit_encoder.eval()
        pixel_values = data["pixel_values"].cuda()
        pixel_values1 = data["pixel_values1"].cuda()
        pixel_values2 = data["pixel_values2"].cuda()
        outputs = vit_encoder(pixel_values,pixel_values1,pixel_values2,noise=None)
        embedding = outputs.last_hidden_state[:,0,:]
        observation = embedding.cpu().detach().numpy()
        # get action
        pose = data["pose_values"]
        action = torch.diff(pose,axis = 0).numpy()
        action = np.concatenate((action, np.zeros((1,7))), axis=0)
        # get reward
        goal = observation[-1]
        partial_function = functools.partial(reward_function, goal_embedding=goal)
        reward = np.apply_along_axis(partial_function, 1, observation)
        # get terminals
        terminals = np.zeros_like(reward, dtype=int)
        terminals[reward == 100] = 1

        # Concatenate observations, actions, rewards, and terminals
        all_observations = np.vstack((all_observations, observation))
        all_actions = np.vstack((all_actions, action))
        all_rewards = np.hstack((all_rewards, reward))
        all_terminals = np.hstack((all_terminals, terminals))

    print("All observations shape:", all_observations.shape)
    print("All actions shape:", all_actions.shape)
    print("All rewards shape:", all_rewards.shape)
    print("All terminals shape:", all_terminals.shape)
    
    np.save(f'trajectories/all_observations_P00{i}.npy', all_observations)
    np.save(f'trajectories/all_actions_P00{i}.npy', all_actions)
    np.save(f'trajectories/all_rewards_P00{i}.npy', all_rewards)
    np.save(f'trajectories/all_terminals_P00{i}.npy', all_terminals)

Processing folder: /mnt/temp_mount/AbandonedFactoryExposure/Data_easy/P000
Number of images: 1480
Number of depth: 1480
Number of lidar: 1480
Number of pose: 1480
All observations shape: (1480, 768)
All actions shape: (1480, 7)
All rewards shape: (1480,)
All terminals shape: (1480,)
Processing folder: /mnt/temp_mount/AbandonedFactoryExposure/Data_easy/P001
Number of images: 1055
Number of depth: 1055
Number of lidar: 1055
Number of pose: 1055
All observations shape: (1055, 768)
All actions shape: (1055, 7)
All rewards shape: (1055,)
All terminals shape: (1055,)
Processing folder: /mnt/temp_mount/AbandonedFactoryExposure/Data_easy/P002
Number of images: 698
Number of depth: 698
Number of lidar: 698
Number of pose: 698
All observations shape: (698, 768)
All actions shape: (698, 7)
All rewards shape: (698,)
All terminals shape: (698,)
Processing folder: /mnt/temp_mount/AbandonedFactoryExposure/Data_easy/P003
Number of images: 1009
Number of depth: 1009
Number of lidar: 1009
Number of pose

In [22]:
from tem_dataloader import MultimodalDatasetPerTrajectory
# environment_name = 'AmericanDinerExposure'
# environemnt_directory = f'/media/tyz/3B6FFE7354FF3296/11_777/tartanairv2filtered/{environment_name}/Data_easy'
# my_dataset = MultimodalDatasetPerTrajectory(environemnt_directory)
environment_name = 'AbandonedFactoryExposure'
environemnt_directory = f'/mnt/data/tartanairv2filtered/{environment_name}/Data_easy'
OBSERVATION_SIZE = 768
ACTION_SIZE = 7
BATCH_SIZE = 64

# for folder in os.listdir(environemnt_directory):
trajectory_folder_path = os.path.join(environemnt_directory, 'P000')
my_dataset = MultimodalDatasetPerTrajectory(trajectory_folder_path)
train_dataloader = DataLoader(my_dataset, batch_size=BATCH_SIZE, shuffle=False)

Processing folder: /mnt/data/tartanairv2filtered/AbandonedFactoryExposure/Data_easy/P000
Number of images: 1480
Number of depth: 1480
Number of lidar: 1480
Number of pose: 1480


In [24]:
import functools
import numpy as np
import torch

# Initialize empty arrays for observations, actions, rewards, and terminals
all_observations = np.empty((0, OBSERVATION_SIZE))
all_actions = np.empty((0, ACTION_SIZE))
all_rewards = np.empty(0)
all_terminals = np.empty(0, dtype=bool)

for batch_idx, data in enumerate(train_dataloader):
    # get embedding
    vit_encoder.cuda()
    vit_encoder.eval()
    pixel_values = data["pixel_values"].cuda()
    pixel_values1 = data["pixel_values1"].cuda()
    pixel_values2 = data["pixel_values2"].cuda()
    outputs = vit_encoder(pixel_values,pixel_values1,pixel_values2,noise=None)
    embedding = outputs.last_hidden_state[:,0,:]
    observation = embedding.cpu().detach().numpy()
    # get action
    pose = data["pose_values"]
    action = torch.diff(pose,axis = 0).numpy()
    action = np.concatenate((action, np.zeros((1,7))), axis=0)
    # get reward
    goal = observation[-1]
    partial_function = functools.partial(reward_function, goal_embedding=goal)
    reward = np.apply_along_axis(partial_function, 1, observation)
    # get terminals
    terminals = [False]*BATCH_SIZE
    terminals[-1]=True
    terminals = np.array(terminals)
    if batch_idx==0:
        print(observation.shape)
        print(action.shape)
        print(reward.shape)
        print(terminals.shape)

(64, 768)
(64, 7)
(64,)
(64,)


In [25]:
print("All observations shape:", all_observations.shape)
print("All actions shape:", all_actions.shape)
print("All rewards shape:", all_rewards.shape)
print("All terminals shape:", all_terminals.shape)

All observations shape: (1480, 768)
All actions shape: (1480, 7)
All rewards shape: (1480,)
All terminals shape: (1536,)


In [17]:
np.save('all_observations.npy', all_observations)
np.save('all_actions.npy', all_actions)
np.save('all_rewards.npy', all_rewards)
np.save('all_terminals.npy', all_terminals)

In [15]:
'''
Args:
        observations (numpy.ndarray): N-D array. If the
            observation is a vector, the shape should be
            `(N, dim_observation)`. If the observations is an image, the shape
            should be `(N, C, H, W)`.
        actions (numpy.ndarray): N-D array. If the actions-space is
            continuous, the shape should be `(N, dim_action)`. If the
            action-space is discrete, the shape should be `(N,)`.
        rewards (numpy.ndarray): array of scalar rewards. The reward function
            should be defined as :math:`r_t = r(s_t, a_t)`.
        terminals (numpy.ndarray): array of binary terminal flags.
        episode_terminals (numpy.ndarray): array of binary episode terminal
            flags. The given data will be splitted based on this flag.
            This is useful if you want to specify the non-environment
            terminations (e.g. timeout). If ``None``, the episode terminations
            match the environment terminations.
        discrete_action (bool): flag to use the given actions as discrete
            action-space actions. If ``None``, the action type is automatically
            determined.
    '''
cql_dataset = MDPDataset(observations=all_observations,actions=all_actions,rewards=all_rewards,terminals=all_terminals,episode_terminals=all_terminals)

In [16]:
from d3rlpy.algos import CQL

# setup CQL algorithm
cql = CQL(use_gpu=True)

# split train and test episodes
# train_episodes, test_episodes = train_test_split(cql_dataset, test_size=0.25)

# start training
cql.fit(cql_dataset,
        eval_episodes=None,
        n_epochs=10,
        scorers=None)

[2m2023-05-04 01:27:46[0m [[32m[1mdebug    [0m] [1mRoundIterator is selected.[0m
[2m2023-05-04 01:27:46[0m [[32m[1minfo     [0m] [1mDirectory is created at d3rlpy_logs/CQL_20230504012746[0m
[2m2023-05-04 01:27:46[0m [[32m[1mdebug    [0m] [1mBuilding models...[0m
[2m2023-05-04 01:27:46[0m [[32m[1mdebug    [0m] [1mModels have been built.[0m
[2m2023-05-04 01:27:46[0m [[32m[1minfo     [0m] [1mParameters are saved to d3rlpy_logs/CQL_20230504012746/params.json[0m [36mparams[0m=[35m{'action_scaler': None, 'actor_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'actor_learning_rate': 0.0001, 'actor_optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'alpha_learning_rate': 0.0001, 'alpha_optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'alpha_threshold': 10.0, 'bat

Epoch 1/10:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:27:48[0m [[32m[1minfo     [0m] [1mCQL_20230504012746: epoch=1 step=31[0m [36mepoch[0m=[35m1[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009980278630410471, 'time_algorithm_update': 0.08748865127563477, 'temp_loss': 11.744194215343844, 'temp': 0.9984005458893315, 'alpha_loss': -31.277802190473004, 'alpha': 1.0016018421419206, 'critic_loss': 3721.7289094002017, 'actor_loss': -10.316245063658684, 'time_step': 0.0886292073034471}[0m [36mstep[0m=[35m31[0m
[2m2023-05-04 01:27:48[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504012746/model_31.pt[0m


Epoch 2/10:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:27:50[0m [[32m[1minfo     [0m] [1mCQL_20230504012746: epoch=2 step=62[0m [36mepoch[0m=[35m2[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.000903790996920678, 'time_algorithm_update': 0.03586419936149351, 'temp_loss': 11.724647952664283, 'temp': 0.995309462470393, 'alpha_loss': -31.375866490025675, 'alpha': 1.0047140698279104, 'critic_loss': 3234.1929419732865, 'actor_loss': -25.97842062673261, 'time_step': 0.03689156809160786}[0m [36mstep[0m=[35m62[0m
[2m2023-05-04 01:27:50[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504012746/model_62.pt[0m


Epoch 3/10:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:27:51[0m [[32m[1minfo     [0m] [1mCQL_20230504012746: epoch=3 step=93[0m [36mepoch[0m=[35m3[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009151427976546749, 'time_algorithm_update': 0.03396520306987147, 'temp_loss': 11.691720900997039, 'temp': 0.9922334763311571, 'alpha_loss': -31.483734069331998, 'alpha': 1.0078407295288578, 'critic_loss': 3119.4801301033267, 'actor_loss': -26.33940456759545, 'time_step': 0.03499096439730737}[0m [36mstep[0m=[35m93[0m
[2m2023-05-04 01:27:51[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504012746/model_93.pt[0m


Epoch 4/10:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:27:52[0m [[32m[1minfo     [0m] [1mCQL_20230504012746: epoch=4 step=124[0m [36mepoch[0m=[35m4[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0008699817042196951, 'time_algorithm_update': 0.03454456790801017, 'temp_loss': 11.659830677893854, 'temp': 0.9891713492331966, 'alpha_loss': -31.584238790696666, 'alpha': 1.0109822057908582, 'critic_loss': 3020.9976530997983, 'actor_loss': -27.604707779422885, 'time_step': 0.03553482024900375}[0m [36mstep[0m=[35m124[0m
[2m2023-05-04 01:27:52[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504012746/model_124.pt[0m


Epoch 5/10:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:27:53[0m [[32m[1minfo     [0m] [1mCQL_20230504012746: epoch=5 step=155[0m [36mepoch[0m=[35m5[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0008454092087284211, 'time_algorithm_update': 0.03654890675698557, 'temp_loss': 11.622486237556704, 'temp': 0.9861233984270403, 'alpha_loss': -31.666741217336348, 'alpha': 1.0141376026215092, 'critic_loss': 2904.5417086693546, 'actor_loss': -29.150966029013357, 'time_step': 0.0375160324958063}[0m [36mstep[0m=[35m155[0m
[2m2023-05-04 01:27:53[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504012746/model_155.pt[0m


Epoch 6/10:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:27:55[0m [[32m[1minfo     [0m] [1mCQL_20230504012746: epoch=6 step=186[0m [36mepoch[0m=[35m6[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00135833986343876, 'time_algorithm_update': 0.04469979193902785, 'temp_loss': 11.579890897197108, 'temp': 0.983090054604315, 'alpha_loss': -31.782677435105846, 'alpha': 1.0173068546479749, 'critic_loss': 2793.4649776335687, 'actor_loss': -30.826529349050215, 'time_step': 0.04619258449923608}[0m [36mstep[0m=[35m186[0m
[2m2023-05-04 01:27:55[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504012746/model_186.pt[0m


Epoch 7/10:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:27:56[0m [[32m[1minfo     [0m] [1mCQL_20230504012746: epoch=7 step=217[0m [36mepoch[0m=[35m7[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012105972536148564, 'time_algorithm_update': 0.04299592202709567, 'temp_loss': 11.538329616669685, 'temp': 0.9800722118346922, 'alpha_loss': -31.878970792216638, 'alpha': 1.0204919538190287, 'critic_loss': 2683.6039645287296, 'actor_loss': -32.308014162125126, 'time_step': 0.04436281419569446}[0m [36mstep[0m=[35m217[0m
[2m2023-05-04 01:27:56[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504012746/model_217.pt[0m


Epoch 8/10:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:27:57[0m [[32m[1minfo     [0m] [1mCQL_20230504012746: epoch=8 step=248[0m [36mepoch[0m=[35m8[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0015095664608863093, 'time_algorithm_update': 0.041245768147130164, 'temp_loss': 11.50996232801868, 'temp': 0.9770682415654582, 'alpha_loss': -31.977023093931138, 'alpha': 1.023691254277383, 'critic_loss': 2558.870573966734, 'actor_loss': -33.564017018964215, 'time_step': 0.043018541028422695}[0m [36mstep[0m=[35m248[0m
[2m2023-05-04 01:27:57[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504012746/model_248.pt[0m


Epoch 9/10:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:27:59[0m [[32m[1minfo     [0m] [1mCQL_20230504012746: epoch=9 step=279[0m [36mepoch[0m=[35m9[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009979432629000757, 'time_algorithm_update': 0.036409478033742594, 'temp_loss': 11.470276709525816, 'temp': 0.974076553698509, 'alpha_loss': -32.07833819235525, 'alpha': 1.0269050905781407, 'critic_loss': 2458.1851609753026, 'actor_loss': -35.75754042594664, 'time_step': 0.037541420229019656}[0m [36mstep[0m=[35m279[0m
[2m2023-05-04 01:27:59[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504012746/model_279.pt[0m


Epoch 10/10:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:28:00[0m [[32m[1minfo     [0m] [1mCQL_20230504012746: epoch=10 step=310[0m [36mepoch[0m=[35m10[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0007801055908203125, 'time_algorithm_update': 0.030945539474487305, 'temp_loss': 11.448118209838867, 'temp': 0.9710988998413086, 'alpha_loss': -32.18758958385837, 'alpha': 1.0301338165037093, 'critic_loss': 2336.2837937878026, 'actor_loss': -37.29167212209394, 'time_step': 0.031834071682345484}[0m [36mstep[0m=[35m310[0m
[2m2023-05-04 01:28:00[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504012746/model_310.pt[0m


[(1,
  {'time_sample_batch': 0.0009980278630410471,
   'time_algorithm_update': 0.08748865127563477,
   'temp_loss': 11.744194215343844,
   'temp': 0.9984005458893315,
   'alpha_loss': -31.277802190473004,
   'alpha': 1.0016018421419206,
   'critic_loss': 3721.7289094002017,
   'actor_loss': -10.316245063658684,
   'time_step': 0.0886292073034471}),
 (2,
  {'time_sample_batch': 0.000903790996920678,
   'time_algorithm_update': 0.03586419936149351,
   'temp_loss': 11.724647952664283,
   'temp': 0.995309462470393,
   'alpha_loss': -31.375866490025675,
   'alpha': 1.0047140698279104,
   'critic_loss': 3234.1929419732865,
   'actor_loss': -25.97842062673261,
   'time_step': 0.03689156809160786}),
 (3,
  {'time_sample_batch': 0.0009151427976546749,
   'time_algorithm_update': 0.03396520306987147,
   'temp_loss': 11.691720900997039,
   'temp': 0.9922334763311571,
   'alpha_loss': -31.483734069331998,
   'alpha': 1.0078407295288578,
   'critic_loss': 3119.4801301033267,
   'actor_loss': -26.3

In [15]:
import numpy as np
from d3rlpy.dataset import Episode, MDPDataset, Transition

OBSERVATION_SIZE = 768
ACTION_SIZE = 7
BATCH_SIZE = 64

all_observations = np.empty((0, OBSERVATION_SIZE))
all_actions = np.empty((0, ACTION_SIZE))
all_rewards = np.empty(0)
all_terminals = np.empty(0, dtype=bool)

for i in range(0,10):
    if i==7:
        continue
    observation = np.load(f'trajectories/all_observations_P00{i}.npy')
    action = np.load(f'trajectories/all_actions_P00{i}.npy')
    reward = np.load(f'trajectories/all_rewards_P00{i}.npy')
    terminals = np.load(f'trajectories/all_terminals_P00{i}.npy')

    all_observations = np.vstack((all_observations, observation))
    all_actions = np.vstack((all_actions, action))
    all_rewards = np.hstack((all_rewards, reward))
    all_terminals = np.hstack((all_terminals, terminals))

    print("All observations shape:", all_observations.shape)
    print("All actions shape:", all_actions.shape)
    print("All rewards shape:", all_rewards.shape)
    print("All terminals shape:", all_terminals.shape)
cql_dataset = MDPDataset(observations=all_observations,actions=all_actions,rewards=all_rewards,terminals=all_terminals,episode_terminals=all_terminals)

All observations shape: (1480, 768)
All actions shape: (1480, 7)
All rewards shape: (1480,)
All terminals shape: (1480,)
All observations shape: (2535, 768)
All actions shape: (2535, 7)
All rewards shape: (2535,)
All terminals shape: (2535,)
All observations shape: (3233, 768)
All actions shape: (3233, 7)
All rewards shape: (3233,)
All terminals shape: (3233,)
All observations shape: (4242, 768)
All actions shape: (4242, 7)
All rewards shape: (4242,)
All terminals shape: (4242,)
All observations shape: (5049, 768)
All actions shape: (5049, 7)
All rewards shape: (5049,)
All terminals shape: (5049,)
All observations shape: (5743, 768)
All actions shape: (5743, 7)
All rewards shape: (5743,)
All terminals shape: (5743,)
All observations shape: (6744, 768)
All actions shape: (6744, 7)
All rewards shape: (6744,)
All terminals shape: (6744,)
All observations shape: (7469, 768)
All actions shape: (7469, 7)
All rewards shape: (7469,)
All terminals shape: (7469,)
All observations shape: (8185, 7

In [4]:
from d3rlpy.algos import CQL
cql01 = CQL(use_gpu=False)
cql01.build_with_dataset(cql_dataset)
cql01.load_model('/home/ubuntu/camelmera/models/gym/multimodal/d3rlpy_logs/CQL_20230503011241/model_40.pt')

In [18]:
from d3rlpy.algos import CQL

cql = CQL.from_json('d3rlpy_logs/CQL_20230504012746/params.json')

# ready to load
cql.load_model('d3rlpy_logs/CQL_20230504012746/model_310.pt')



In [19]:
# start training
cql.fit(cql_dataset,
        eval_episodes=None,
        n_epochs=30,
        scorers=None)

[2m2023-05-04 01:35:10[0m [[32m[1mdebug    [0m] [1mRoundIterator is selected.[0m
[2m2023-05-04 01:35:10[0m [[32m[1minfo     [0m] [1mDirectory is created at d3rlpy_logs/CQL_20230504013510[0m
[2m2023-05-04 01:35:10[0m [[32m[1minfo     [0m] [1mParameters are saved to d3rlpy_logs/CQL_20230504013510/params.json[0m [36mparams[0m=[35m{'action_scaler': None, 'actor_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'actor_learning_rate': 0.0001, 'actor_optim_factory': {'optim_cls': 'Adam', 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'alpha_learning_rate': 0.0001, 'alpha_optim_factory': {'optim_cls': 'Adam', 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'alpha_threshold': 10.0, 'batch_size': 256, 'conservative_weight': 5.0, 'critic_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': No

Epoch 1/30:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:35:20[0m [[32m[1minfo     [0m] [1mCQL_20230504013510: epoch=1 step=31[0m [36mepoch[0m=[35m1[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010161707478184854, 'time_algorithm_update': 0.3290915950652092, 'temp_loss': 11.405332565307617, 'temp': 0.9681331226902623, 'alpha_loss': -32.28081783171623, 'alpha': 1.0333774628177765, 'critic_loss': 2233.6551868069555, 'actor_loss': -38.560919607839274, 'time_step': 0.33023206649288056}[0m [36mstep[0m=[35m31[0m
[2m2023-05-04 01:35:20[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504013510/model_31.pt[0m


Epoch 2/30:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:35:27[0m [[32m[1minfo     [0m] [1mCQL_20230504013510: epoch=2 step=62[0m [36mepoch[0m=[35m2[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.000981999981787897, 'time_algorithm_update': 0.2366661871633222, 'temp_loss': 11.369608602216166, 'temp': 0.9651807642752125, 'alpha_loss': -32.39259682932207, 'alpha': 1.0366356103650984, 'critic_loss': 2119.0245991368447, 'actor_loss': -40.40380416377898, 'time_step': 0.23776971909307665}[0m [36mstep[0m=[35m62[0m
[2m2023-05-04 01:35:27[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504013510/model_62.pt[0m


Epoch 3/30:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:35:35[0m [[32m[1minfo     [0m] [1mCQL_20230504013510: epoch=3 step=93[0m [36mepoch[0m=[35m3[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009399690935688634, 'time_algorithm_update': 0.2364673614501953, 'temp_loss': 11.33581853681995, 'temp': 0.9622422341377505, 'alpha_loss': -32.4925417746267, 'alpha': 1.039909182056304, 'critic_loss': 2022.7026839717741, 'actor_loss': -41.9716670128607, 'time_step': 0.2375329617531069}[0m [36mstep[0m=[35m93[0m
[2m2023-05-04 01:35:35[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504013510/model_93.pt[0m


Epoch 4/30:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:35:43[0m [[32m[1minfo     [0m] [1mCQL_20230504013510: epoch=4 step=124[0m [36mepoch[0m=[35m4[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009722478928104524, 'time_algorithm_update': 0.24814252699575118, 'temp_loss': 11.292844926157306, 'temp': 0.9593167708766076, 'alpha_loss': -32.59616494947864, 'alpha': 1.0431973511172878, 'critic_loss': 1931.2279643397178, 'actor_loss': -43.57937314433436, 'time_step': 0.24923853720388106}[0m [36mstep[0m=[35m124[0m
[2m2023-05-04 01:35:43[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504013510/model_124.pt[0m


Epoch 5/30:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:35:50[0m [[32m[1minfo     [0m] [1mCQL_20230504013510: epoch=5 step=155[0m [36mepoch[0m=[35m5[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009500057466568485, 'time_algorithm_update': 0.24592191173184302, 'temp_loss': 11.26609605358493, 'temp': 0.9564051378157831, 'alpha_loss': -32.70433019822644, 'alpha': 1.0465002175300353, 'critic_loss': 1825.3979098412299, 'actor_loss': -45.37621467344223, 'time_step': 0.24699524141127063}[0m [36mstep[0m=[35m155[0m
[2m2023-05-04 01:35:50[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504013510/model_155.pt[0m


Epoch 6/30:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:35:58[0m [[32m[1minfo     [0m] [1mCQL_20230504013510: epoch=6 step=186[0m [36mepoch[0m=[35m6[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009497904008434664, 'time_algorithm_update': 0.23711489092919133, 'temp_loss': 11.233154296875, 'temp': 0.9535041989818696, 'alpha_loss': -32.813171509773504, 'alpha': 1.0498186465232604, 'critic_loss': 1737.4400713520665, 'actor_loss': -46.349911228302986, 'time_step': 0.2381860363867975}[0m [36mstep[0m=[35m186[0m
[2m2023-05-04 01:35:58[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504013510/model_186.pt[0m


Epoch 7/30:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:36:05[0m [[32m[1minfo     [0m] [1mCQL_20230504013510: epoch=7 step=217[0m [36mepoch[0m=[35m7[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010021501971829321, 'time_algorithm_update': 0.23692472519413119, 'temp_loss': 11.194594352476058, 'temp': 0.9506164077789553, 'alpha_loss': -32.91840017995527, 'alpha': 1.0531521228051954, 'critic_loss': 1650.4468049080142, 'actor_loss': -48.454989525579634, 'time_step': 0.23804723831915087}[0m [36mstep[0m=[35m217[0m
[2m2023-05-04 01:36:05[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504013510/model_217.pt[0m


Epoch 8/30:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:36:13[0m [[32m[1minfo     [0m] [1mCQL_20230504013510: epoch=8 step=248[0m [36mepoch[0m=[35m8[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009970434250370149, 'time_algorithm_update': 0.23665251270417245, 'temp_loss': 11.160284565341088, 'temp': 0.9477410316467285, 'alpha_loss': -33.02992433117282, 'alpha': 1.0565006617576844, 'critic_loss': 1565.895255796371, 'actor_loss': -49.37596105760144, 'time_step': 0.23776928071052797}[0m [36mstep[0m=[35m248[0m
[2m2023-05-04 01:36:13[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504013510/model_248.pt[0m


Epoch 9/30:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:36:20[0m [[32m[1minfo     [0m] [1mCQL_20230504013510: epoch=9 step=279[0m [36mepoch[0m=[35m9[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001005272711476972, 'time_algorithm_update': 0.23492783884848317, 'temp_loss': 11.128739110885128, 'temp': 0.9448788839001809, 'alpha_loss': -33.12550181727256, 'alpha': 1.059864221080657, 'critic_loss': 1480.8202298072076, 'actor_loss': -51.20974423808436, 'time_step': 0.23605735071243777}[0m [36mstep[0m=[35m279[0m
[2m2023-05-04 01:36:20[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504013510/model_279.pt[0m


Epoch 10/30:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:36:27[0m [[32m[1minfo     [0m] [1mCQL_20230504013510: epoch=10 step=310[0m [36mepoch[0m=[35m10[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.000932870372649162, 'time_algorithm_update': 0.23629223146746237, 'temp_loss': 11.103324582499843, 'temp': 0.9420276822582367, 'alpha_loss': -33.22306676064768, 'alpha': 1.0632420586001488, 'critic_loss': 1411.4829455960182, 'actor_loss': -52.45648464079826, 'time_step': 0.23734592622326267}[0m [36mstep[0m=[35m310[0m
[2m2023-05-04 01:36:27[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504013510/model_310.pt[0m


Epoch 11/30:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:36:35[0m [[32m[1minfo     [0m] [1mCQL_20230504013510: epoch=11 step=341[0m [36mepoch[0m=[35m11[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009609422376078944, 'time_algorithm_update': 0.23752570152282715, 'temp_loss': 11.062899466483824, 'temp': 0.9391875939984475, 'alpha_loss': -33.32871972360918, 'alpha': 1.0666343704346688, 'critic_loss': 1333.3225885206652, 'actor_loss': -53.656602551860196, 'time_step': 0.2386100753661125}[0m [36mstep[0m=[35m341[0m
[2m2023-05-04 01:36:35[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504013510/model_341.pt[0m


Epoch 12/30:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:36:42[0m [[32m[1minfo     [0m] [1mCQL_20230504013510: epoch=12 step=372[0m [36mepoch[0m=[35m12[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009121510290330456, 'time_algorithm_update': 0.23629838420498755, 'temp_loss': 11.03089480246267, 'temp': 0.9363601496142726, 'alpha_loss': -33.447303279753655, 'alpha': 1.0700419679764779, 'critic_loss': 1266.8433286605343, 'actor_loss': -55.27484709216702, 'time_step': 0.23733001370583812}[0m [36mstep[0m=[35m372[0m
[2m2023-05-04 01:36:42[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504013510/model_372.pt[0m


Epoch 13/30:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:36:50[0m [[32m[1minfo     [0m] [1mCQL_20230504013510: epoch=13 step=403[0m [36mepoch[0m=[35m13[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010057880032447078, 'time_algorithm_update': 0.23721263485570107, 'temp_loss': 10.995390830501433, 'temp': 0.9335450126278785, 'alpha_loss': -33.562891437161355, 'alpha': 1.0734655895540792, 'critic_loss': 1205.1854504000755, 'actor_loss': -56.87439210953251, 'time_step': 0.23833936260592553}[0m [36mstep[0m=[35m403[0m
[2m2023-05-04 01:36:50[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504013510/model_403.pt[0m


Epoch 14/30:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:36:57[0m [[32m[1minfo     [0m] [1mCQL_20230504013510: epoch=14 step=434[0m [36mepoch[0m=[35m14[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009629957137569304, 'time_algorithm_update': 0.24336472634346254, 'temp_loss': 10.958394142889208, 'temp': 0.9307423868486958, 'alpha_loss': -33.66192860757151, 'alpha': 1.0769043891660628, 'critic_loss': 1140.2056372857862, 'actor_loss': -57.696526189004224, 'time_step': 0.24444963086035945}[0m [36mstep[0m=[35m434[0m
[2m2023-05-04 01:36:57[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504013510/model_434.pt[0m


Epoch 15/30:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:37:05[0m [[32m[1minfo     [0m] [1mCQL_20230504013510: epoch=15 step=465[0m [36mepoch[0m=[35m15[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009717018373550907, 'time_algorithm_update': 0.23802389637116464, 'temp_loss': 10.929554200941517, 'temp': 0.927951701225773, 'alpha_loss': -33.76268952892673, 'alpha': 1.0803579476571852, 'critic_loss': 1073.052716655116, 'actor_loss': -59.173760567941976, 'time_step': 0.23911464598871046}[0m [36mstep[0m=[35m465[0m
[2m2023-05-04 01:37:05[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504013510/model_465.pt[0m


Epoch 16/30:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:37:12[0m [[32m[1minfo     [0m] [1mCQL_20230504013510: epoch=16 step=496[0m [36mepoch[0m=[35m16[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009324704447100239, 'time_algorithm_update': 0.23868568481937533, 'temp_loss': 10.894349498133506, 'temp': 0.9251717732798669, 'alpha_loss': -33.869208674277026, 'alpha': 1.0838258304903585, 'critic_loss': 1022.2078424269154, 'actor_loss': -60.36362506497291, 'time_step': 0.23973795675462292}[0m [36mstep[0m=[35m496[0m
[2m2023-05-04 01:37:12[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504013510/model_496.pt[0m


Epoch 17/30:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:37:20[0m [[32m[1minfo     [0m] [1mCQL_20230504013510: epoch=17 step=527[0m [36mepoch[0m=[35m17[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009100514073525706, 'time_algorithm_update': 0.23907737578115157, 'temp_loss': 10.871522226641256, 'temp': 0.9224030856163271, 'alpha_loss': -33.98112770818895, 'alpha': 1.0873085606482722, 'critic_loss': 967.6675454416583, 'actor_loss': -61.4046996331984, 'time_step': 0.24010705178783787}[0m [36mstep[0m=[35m527[0m
[2m2023-05-04 01:37:20[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504013510/model_527.pt[0m


Epoch 18/30:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:37:27[0m [[32m[1minfo     [0m] [1mCQL_20230504013510: epoch=18 step=558[0m [36mepoch[0m=[35m18[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009086516595655872, 'time_algorithm_update': 0.23862931805272256, 'temp_loss': 10.830902868701566, 'temp': 0.9196449806613307, 'alpha_loss': -34.09441511092648, 'alpha': 1.0908071648690008, 'critic_loss': 920.6823966733871, 'actor_loss': -62.730490284581336, 'time_step': 0.23965970931514616}[0m [36mstep[0m=[35m558[0m
[2m2023-05-04 01:37:27[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504013510/model_558.pt[0m


Epoch 19/30:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:37:35[0m [[32m[1minfo     [0m] [1mCQL_20230504013510: epoch=19 step=589[0m [36mepoch[0m=[35m19[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009056445091001449, 'time_algorithm_update': 0.23877423809420678, 'temp_loss': 10.80337542872275, 'temp': 0.9168986351259293, 'alpha_loss': -34.206383489793346, 'alpha': 1.094320958660495, 'critic_loss': 872.3691268428679, 'actor_loss': -63.63703081684728, 'time_step': 0.23980018400376843}[0m [36mstep[0m=[35m589[0m
[2m2023-05-04 01:37:35[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504013510/model_589.pt[0m


Epoch 20/30:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:37:42[0m [[32m[1minfo     [0m] [1mCQL_20230504013510: epoch=20 step=620[0m [36mepoch[0m=[35m20[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009058983095230595, 'time_algorithm_update': 0.23777965576417984, 'temp_loss': 10.768212103074596, 'temp': 0.9141628857581846, 'alpha_loss': -34.31017008135396, 'alpha': 1.0978503111870057, 'critic_loss': 834.5302045268397, 'actor_loss': -65.17218718990203, 'time_step': 0.23880695527599705}[0m [36mstep[0m=[35m620[0m
[2m2023-05-04 01:37:42[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504013510/model_620.pt[0m


Epoch 21/30:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:37:50[0m [[32m[1minfo     [0m] [1mCQL_20230504013510: epoch=21 step=651[0m [36mepoch[0m=[35m21[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0008913163215883317, 'time_algorithm_update': 0.24057940513856949, 'temp_loss': 10.738499672182146, 'temp': 0.9114385862504283, 'alpha_loss': -34.41144832488029, 'alpha': 1.1013942495469125, 'critic_loss': 787.6799572360131, 'actor_loss': -65.89005710232642, 'time_step': 0.24159186886202905}[0m [36mstep[0m=[35m651[0m
[2m2023-05-04 01:37:50[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504013510/model_651.pt[0m


Epoch 22/30:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:37:57[0m [[32m[1minfo     [0m] [1mCQL_20230504013510: epoch=22 step=682[0m [36mepoch[0m=[35m22[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009052753448486328, 'time_algorithm_update': 0.2391151382077125, 'temp_loss': 10.70498844885057, 'temp': 0.9087259211847859, 'alpha_loss': -34.52637198663527, 'alpha': 1.1049530236951766, 'critic_loss': 742.9421760805192, 'actor_loss': -67.03671215426537, 'time_step': 0.24013933827800135}[0m [36mstep[0m=[35m682[0m
[2m2023-05-04 01:37:57[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504013510/model_682.pt[0m


Epoch 23/30:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:38:05[0m [[32m[1minfo     [0m] [1mCQL_20230504013510: epoch=23 step=713[0m [36mepoch[0m=[35m23[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009247103045063634, 'time_algorithm_update': 0.23759798849782637, 'temp_loss': 10.671110307016681, 'temp': 0.9060236311727955, 'alpha_loss': -34.63786291307019, 'alpha': 1.1085272758237776, 'critic_loss': 707.0968096333165, 'actor_loss': -67.92995969710812, 'time_step': 0.2386554979508923}[0m [36mstep[0m=[35m713[0m
[2m2023-05-04 01:38:05[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504013510/model_713.pt[0m


Epoch 24/30:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:38:12[0m [[32m[1minfo     [0m] [1mCQL_20230504013510: epoch=24 step=744[0m [36mepoch[0m=[35m24[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.000959404053226594, 'time_algorithm_update': 0.23914167188828991, 'temp_loss': 10.641364097595215, 'temp': 0.9033328986937, 'alpha_loss': -34.74183236399004, 'alpha': 1.1121167713595974, 'critic_loss': 667.648201234879, 'actor_loss': -69.10765814012096, 'time_step': 0.24022455369272538}[0m [36mstep[0m=[35m744[0m
[2m2023-05-04 01:38:12[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504013510/model_744.pt[0m


Epoch 25/30:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:38:20[0m [[32m[1minfo     [0m] [1mCQL_20230504013510: epoch=25 step=775[0m [36mepoch[0m=[35m25[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009265022893105784, 'time_algorithm_update': 0.23794101899670017, 'temp_loss': 10.612345541677167, 'temp': 0.9006518586989372, 'alpha_loss': -34.86379352692635, 'alpha': 1.1157212411203692, 'critic_loss': 633.1039448399698, 'actor_loss': -69.79104122038811, 'time_step': 0.2389881457051923}[0m [36mstep[0m=[35m775[0m
[2m2023-05-04 01:38:20[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504013510/model_775.pt[0m


Epoch 26/30:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:38:27[0m [[32m[1minfo     [0m] [1mCQL_20230504013510: epoch=26 step=806[0m [36mepoch[0m=[35m26[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0008840791640743132, 'time_algorithm_update': 0.236156617441485, 'temp_loss': 10.5820494005757, 'temp': 0.8979811610714081, 'alpha_loss': -34.96720455538842, 'alpha': 1.1193415964803388, 'critic_loss': 602.1876860587828, 'actor_loss': -70.81548850767074, 'time_step': 0.23716093647864558}[0m [36mstep[0m=[35m806[0m
[2m2023-05-04 01:38:27[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504013510/model_806.pt[0m


Epoch 27/30:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:38:35[0m [[32m[1minfo     [0m] [1mCQL_20230504013510: epoch=27 step=837[0m [36mepoch[0m=[35m27[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009458141942178049, 'time_algorithm_update': 0.23853624251581007, 'temp_loss': 10.5531617441485, 'temp': 0.8953207788928863, 'alpha_loss': -35.081697648571385, 'alpha': 1.1229768107014317, 'critic_loss': 574.0574390042212, 'actor_loss': -71.66199813350555, 'time_step': 0.23960494226024998}[0m [36mstep[0m=[35m837[0m
[2m2023-05-04 01:38:35[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504013510/model_837.pt[0m


Epoch 28/30:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:38:42[0m [[32m[1minfo     [0m] [1mCQL_20230504013510: epoch=28 step=868[0m [36mepoch[0m=[35m28[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009072211480909778, 'time_algorithm_update': 0.2457528498864943, 'temp_loss': 10.515092357512444, 'temp': 0.8926707083179105, 'alpha_loss': -35.19337340324156, 'alpha': 1.1266276067303074, 'critic_loss': 546.7645037251134, 'actor_loss': -72.57306720364478, 'time_step': 0.2467866482273225}[0m [36mstep[0m=[35m868[0m
[2m2023-05-04 01:38:42[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504013510/model_868.pt[0m


Epoch 29/30:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:38:50[0m [[32m[1minfo     [0m] [1mCQL_20230504013510: epoch=29 step=899[0m [36mepoch[0m=[35m29[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009510055665046938, 'time_algorithm_update': 0.24219419879298057, 'temp_loss': 10.488471923335906, 'temp': 0.8900320722210792, 'alpha_loss': -35.31450283911921, 'alpha': 1.130293846130371, 'critic_loss': 518.6090343844506, 'actor_loss': -73.55164829377205, 'time_step': 0.24326179873558781}[0m [36mstep[0m=[35m899[0m
[2m2023-05-04 01:38:50[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504013510/model_899.pt[0m


Epoch 30/30:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:38:57[0m [[32m[1minfo     [0m] [1mCQL_20230504013510: epoch=30 step=930[0m [36mepoch[0m=[35m30[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010590476374472341, 'time_algorithm_update': 0.23783430745524745, 'temp_loss': 10.458904973922238, 'temp': 0.8874029421037243, 'alpha_loss': -35.43197853334488, 'alpha': 1.1339762172391337, 'critic_loss': 490.8682034400202, 'actor_loss': -74.2077405375819, 'time_step': 0.23901430253059633}[0m [36mstep[0m=[35m930[0m
[2m2023-05-04 01:38:57[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504013510/model_930.pt[0m


[(1,
  {'time_sample_batch': 0.0010161707478184854,
   'time_algorithm_update': 0.3290915950652092,
   'temp_loss': 11.405332565307617,
   'temp': 0.9681331226902623,
   'alpha_loss': -32.28081783171623,
   'alpha': 1.0333774628177765,
   'critic_loss': 2233.6551868069555,
   'actor_loss': -38.560919607839274,
   'time_step': 0.33023206649288056}),
 (2,
  {'time_sample_batch': 0.000981999981787897,
   'time_algorithm_update': 0.2366661871633222,
   'temp_loss': 11.369608602216166,
   'temp': 0.9651807642752125,
   'alpha_loss': -32.39259682932207,
   'alpha': 1.0366356103650984,
   'critic_loss': 2119.0245991368447,
   'actor_loss': -40.40380416377898,
   'time_step': 0.23776971909307665}),
 (3,
  {'time_sample_batch': 0.0009399690935688634,
   'time_algorithm_update': 0.2364673614501953,
   'temp_loss': 11.33581853681995,
   'temp': 0.9622422341377505,
   'alpha_loss': -32.4925417746267,
   'alpha': 1.039909182056304,
   'critic_loss': 2022.7026839717741,
   'actor_loss': -41.97166701

In [21]:
# start training
cql.fit(cql_dataset,
        eval_episodes=None,
        n_epochs=10,
        scorers=None)

[2m2023-05-04 01:43:38[0m [[32m[1mdebug    [0m] [1mRoundIterator is selected.[0m
[2m2023-05-04 01:43:38[0m [[32m[1minfo     [0m] [1mDirectory is created at d3rlpy_logs/CQL_20230504014338[0m
[2m2023-05-04 01:43:38[0m [[32m[1minfo     [0m] [1mParameters are saved to d3rlpy_logs/CQL_20230504014338/params.json[0m [36mparams[0m=[35m{'action_scaler': None, 'actor_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'actor_learning_rate': 0.0001, 'actor_optim_factory': {'optim_cls': 'Adam', 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'alpha_learning_rate': 0.0001, 'alpha_optim_factory': {'optim_cls': 'Adam', 'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'alpha_threshold': 10.0, 'batch_size': 256, 'conservative_weight': 5.0, 'critic_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': No

Epoch 1/10:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:43:46[0m [[32m[1minfo     [0m] [1mCQL_20230504014338: epoch=1 step=31[0m [36mepoch[0m=[35m1[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010034115083755986, 'time_algorithm_update': 0.26138419489706716, 'temp_loss': 9.549976348876953, 'temp': 0.8105977639075248, 'alpha_loss': -39.20918446202432, 'alpha': 1.255960222213499, 'critic_loss': 134.8759064212922, 'actor_loss': -90.1291038759293, 'time_step': 0.26251200706728045}[0m [36mstep[0m=[35m31[0m
[2m2023-05-04 01:43:46[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504014338/model_31.pt[0m


Epoch 2/10:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:43:53[0m [[32m[1minfo     [0m] [1mCQL_20230504014338: epoch=2 step=62[0m [36mepoch[0m=[35m2[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00093654663332047, 'time_algorithm_update': 0.24153282565455284, 'temp_loss': 9.522798476680633, 'temp': 0.8082602850852474, 'alpha_loss': -39.35825692453692, 'alpha': 1.2601512670516968, 'critic_loss': 130.87871649957472, 'actor_loss': -90.3958499047064, 'time_step': 0.24259078887201124}[0m [36mstep[0m=[35m62[0m
[2m2023-05-04 01:43:53[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504014338/model_62.pt[0m


Epoch 3/10:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:44:01[0m [[32m[1minfo     [0m] [1mCQL_20230504014338: epoch=3 step=93[0m [36mepoch[0m=[35m3[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009963973875968687, 'time_algorithm_update': 0.23778229375039378, 'temp_loss': 9.501792630841654, 'temp': 0.8059299838158392, 'alpha_loss': -39.50162998322518, 'alpha': 1.2643601317559519, 'critic_loss': 127.34921116982737, 'actor_loss': -90.75865074896043, 'time_step': 0.23890026923148863}[0m [36mstep[0m=[35m93[0m
[2m2023-05-04 01:44:01[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504014338/model_93.pt[0m


Epoch 4/10:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:44:08[0m [[32m[1minfo     [0m] [1mCQL_20230504014338: epoch=4 step=124[0m [36mepoch[0m=[35m4[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009777622838174143, 'time_algorithm_update': 0.23983653130069857, 'temp_loss': 9.464464956714261, 'temp': 0.8036083098380796, 'alpha_loss': -39.62295667586788, 'alpha': 1.268586197207051, 'critic_loss': 123.86523043724799, 'actor_loss': -90.94298479633946, 'time_step': 0.24093523333149572}[0m [36mstep[0m=[35m124[0m
[2m2023-05-04 01:44:08[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504014338/model_124.pt[0m


Epoch 5/10:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:44:16[0m [[32m[1minfo     [0m] [1mCQL_20230504014338: epoch=5 step=155[0m [36mepoch[0m=[35m5[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009543895721435547, 'time_algorithm_update': 0.24044697515426144, 'temp_loss': 9.448508324161653, 'temp': 0.8012937864949626, 'alpha_loss': -39.784160737068426, 'alpha': 1.2728292519046414, 'critic_loss': 120.35653194304436, 'actor_loss': -91.22422126031691, 'time_step': 0.24152227370969712}[0m [36mstep[0m=[35m155[0m
[2m2023-05-04 01:44:16[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504014338/model_155.pt[0m


Epoch 6/10:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:44:23[0m [[32m[1minfo     [0m] [1mCQL_20230504014338: epoch=6 step=186[0m [36mepoch[0m=[35m6[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009299401314027848, 'time_algorithm_update': 0.23765828532557334, 'temp_loss': 9.41686648707236, 'temp': 0.7989864310910625, 'alpha_loss': -39.8959968320785, 'alpha': 1.2770898765133274, 'critic_loss': 117.2592534711284, 'actor_loss': -91.54512836087135, 'time_step': 0.2387126261188138}[0m [36mstep[0m=[35m186[0m
[2m2023-05-04 01:44:23[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504014338/model_186.pt[0m


Epoch 7/10:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:44:31[0m [[32m[1minfo     [0m] [1mCQL_20230504014338: epoch=7 step=217[0m [36mepoch[0m=[35m7[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009609883831393334, 'time_algorithm_update': 0.23920016134938887, 'temp_loss': 9.390970107047789, 'temp': 0.796687412646509, 'alpha_loss': -40.013798006119266, 'alpha': 1.2813646370364773, 'critic_loss': 113.86450490643901, 'actor_loss': -91.7711651709772, 'time_step': 0.24028462748373708}[0m [36mstep[0m=[35m217[0m
[2m2023-05-04 01:44:31[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504014338/model_217.pt[0m


Epoch 8/10:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:44:38[0m [[32m[1minfo     [0m] [1mCQL_20230504014338: epoch=8 step=248[0m [36mepoch[0m=[35m8[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009419764241864604, 'time_algorithm_update': 0.24076934014597245, 'temp_loss': 9.365213978675104, 'temp': 0.7943963331560935, 'alpha_loss': -40.15049177600491, 'alpha': 1.2856551101130824, 'critic_loss': 111.25508166897681, 'actor_loss': -92.0129881828062, 'time_step': 0.2418307642782888}[0m [36mstep[0m=[35m248[0m
[2m2023-05-04 01:44:38[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504014338/model_248.pt[0m


Epoch 9/10:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:44:46[0m [[32m[1minfo     [0m] [1mCQL_20230504014338: epoch=9 step=279[0m [36mepoch[0m=[35m9[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009593809804608745, 'time_algorithm_update': 0.23928150823039393, 'temp_loss': 9.333456808520902, 'temp': 0.7921131945425465, 'alpha_loss': -40.298756630189956, 'alpha': 1.2899631184916343, 'critic_loss': 108.35731334071005, 'actor_loss': -92.19059408864668, 'time_step': 0.2403572836229878}[0m [36mstep[0m=[35m279[0m
[2m2023-05-04 01:44:46[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504014338/model_279.pt[0m


Epoch 10/10:   0%|          | 0/31 [00:00<?, ?it/s]

[2m2023-05-04 01:44:54[0m [[32m[1minfo     [0m] [1mCQL_20230504014338: epoch=10 step=310[0m [36mepoch[0m=[35m10[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009909829785746912, 'time_algorithm_update': 0.24242529561442713, 'temp_loss': 9.30837578927317, 'temp': 0.7898380525650517, 'alpha_loss': -40.429276251023815, 'alpha': 1.2942882160986624, 'critic_loss': 106.09569254229146, 'actor_loss': -92.38607049757435, 'time_step': 0.2435398793989612}[0m [36mstep[0m=[35m310[0m
[2m2023-05-04 01:44:54[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/CQL_20230504014338/model_310.pt[0m


[(1,
  {'time_sample_batch': 0.0010034115083755986,
   'time_algorithm_update': 0.26138419489706716,
   'temp_loss': 9.549976348876953,
   'temp': 0.8105977639075248,
   'alpha_loss': -39.20918446202432,
   'alpha': 1.255960222213499,
   'critic_loss': 134.8759064212922,
   'actor_loss': -90.1291038759293,
   'time_step': 0.26251200706728045}),
 (2,
  {'time_sample_batch': 0.00093654663332047,
   'time_algorithm_update': 0.24153282565455284,
   'temp_loss': 9.522798476680633,
   'temp': 0.8082602850852474,
   'alpha_loss': -39.35825692453692,
   'alpha': 1.2601512670516968,
   'critic_loss': 130.87871649957472,
   'actor_loss': -90.3958499047064,
   'time_step': 0.24259078887201124}),
 (3,
  {'time_sample_batch': 0.0009963973875968687,
   'time_algorithm_update': 0.23778229375039378,
   'temp_loss': 9.501792630841654,
   'temp': 0.8059299838158392,
   'alpha_loss': -39.50162998322518,
   'alpha': 1.2643601317559519,
   'critic_loss': 127.34921116982737,
   'actor_loss': -90.75865074896