In [None]:
from custom_models.CustomViT import CustomViT
from custom_models.CustomViTMAE import CustomViTMAE
import torch
# call CustomViT
from transformers import AutoImageProcessor, ViTMAEForPreTraining, ViTMAEConfig
from PIL import Image

# Directory that holds the fine-tuned checkpoint (`pytorch_model.bin`).
output_dir='/home/ubuntu/camelmera'
# trained_model_name = 'multimodal'
# output_dir='/home/ubuntu/weights/' + trained_model_name

# Both models start from the configuration of the same pretrained checkpoint.
model_name = "facebook/vit-mae-base"

# Initialize a new CustomViT model (the encoder).
vit_config = ViTMAEConfig.from_pretrained(model_name)
vit_config.output_hidden_states=True
vit_model = CustomViT(config=vit_config)

# Initialize a new CustomViTMAE model and replace its `vit` attribute with the
# CustomViT instance built above. A separate config object is loaded on
# purpose so the two models never share a mutable config.
config = ViTMAEConfig.from_pretrained(model_name)
config.output_hidden_states=True
custom_model = CustomViTMAE(config=config)
custom_model.vit = vit_model

# Load the state_dict from the saved model.
# map_location="cpu" makes the load independent of the device the checkpoint
# was saved from (e.g. a GPU index that does not exist on this machine); the
# encoder is moved to the GPU later, when it is actually used.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
state_dict = torch.load(f"{output_dir}/pytorch_model.bin", map_location="cpu")
custom_model.load_state_dict(state_dict)

# don't need decoders: only the encoder is used to produce embeddings.
vit_encoder = custom_model.vit
# The encoder is used for inference only, so switch off training-mode layers.
vit_encoder.eval()

In [None]:
import numpy as np

def reward_function(state_embedding, goal_embedding, threshold=0.02, goal_reward=100):
    """Score a state by how close its embedding is to the goal embedding.

    Args:
        state_embedding: embedding of the current state.
        goal_embedding: embedding of the goal state.
        threshold: largest Euclidean distance that still counts as having
            reached the goal.
        goal_reward: reward returned when the goal is reached.

    Returns:
        `goal_reward` when the distance between the two embeddings is at most
        `threshold`; otherwise the negated distance, so states further from
        the goal are penalised more.
    """
    gap = np.linalg.norm(state_embedding - goal_embedding)
    return goal_reward if gap <= threshold else -gap

In [None]:
from tem_dataloader import MultimodalDatasetPerTrajectory
import functools
import os
from torch.utils.data import Dataset, DataLoader

# ---------------------------------------------------------------------------
# Convert every trajectory of one environment into flat
# (observation, action, reward, terminal) arrays and save them as .npy files.
# ---------------------------------------------------------------------------
environment_name = 'AbandonedFactoryExposure'
# (sic) the misspelled name is kept in case other cells reference it.
environemnt_directory = f'/mnt/temp_mount/{environment_name}/Data_hard'
OBSERVATION_SIZE = 768  # width of one embedding row; presumably the ViT hidden size - confirm
ACTION_SIZE = 7         # width of one pose row; presumably position + quaternion - confirm
BATCH_SIZE = 64
GOAL_REWARD = 100       # reward handed out at the goal; also used to flag terminal steps
OUTPUT_DIR = 'hard'     # folder the per-trajectory .npy files are written to
# NOTE(review): trajectory P007 is skipped; the reason is not recorded in this notebook.
SKIPPED_TRAJECTORIES = {7}

# np.save does not create missing directories.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Moving the encoder to the GPU and switching it to eval mode does not depend
# on the batch, so do it once instead of once per batch.
vit_encoder.cuda()
vit_encoder.eval()

for i in range(0, 10):
    if i in SKIPPED_TRAJECTORIES:
        continue
    trajectory_folder_path = os.path.join(environemnt_directory, f'P00{i}')
    my_dataset = MultimodalDatasetPerTrajectory(trajectory_folder_path)
    # shuffle=False: the actions below are differences between consecutive
    # poses, so the frames must stay in their original order.
    train_dataloader = DataLoader(my_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Per-batch results are collected in lists and concatenated once at the
    # end (growing an array with vstack/hstack in a loop copies everything on
    # every iteration). Each list starts with an empty array of the expected
    # width/dtype, which keeps the final dtypes and shapes well defined even
    # for an empty trajectory.
    observation_chunks = [np.empty((0, OBSERVATION_SIZE))]
    action_chunks = [np.empty((0, ACTION_SIZE))]
    reward_chunks = [np.empty(0)]
    terminal_chunks = [np.empty(0, dtype=int)]

    for data in train_dataloader:
        # get embedding
        pixel_values = data["pixel_values"].cuda()
        pixel_values1 = data["pixel_values1"].cuda()
        pixel_values2 = data["pixel_values2"].cuda()
        # Inference only: no autograd graph is needed.
        # NOTE(review): in the stock ViTMAE model `noise=None` means a random
        # mask is sampled; if CustomViT does the same, the embeddings are not
        # reproducible across runs - confirm.
        with torch.no_grad():
            outputs = vit_encoder(pixel_values,pixel_values1,pixel_values2,noise=None)
        # Token 0 of the last hidden state is used as the state embedding
        # (presumably the CLS token - confirm against CustomViT).
        embedding = outputs.last_hidden_state[:,0,:]
        observation = embedding.cpu().detach().numpy()

        # get action: difference between consecutive poses inside the batch.
        pose = data["pose_values"]
        action = torch.diff(pose,axis = 0).numpy()
        # The last frame of the batch has no successor inside the batch, so
        # its action is padded with zeros.
        # NOTE(review): this also zeroes the action at every batch boundary
        # inside a trajectory - confirm this is intended.
        action = np.concatenate((action, np.zeros((1, ACTION_SIZE))), axis=0)

        # get reward: the goal is the last embedding of the *batch*, so every
        # batch ends with a goal-reaching step.
        # NOTE(review): confirm a per-batch (not per-trajectory) goal is intended.
        goal = observation[-1]
        # Build the reward vector with an explicit float dtype.
        # np.apply_along_axis takes its output dtype from the first row's
        # result; when that row already reaches the goal the result is the
        # integer goal reward and all later `-distance` values are truncated.
        reward = np.array(
            [reward_function(row, goal_embedding=goal, goal_reward=GOAL_REWARD)
             for row in observation],
            dtype=float,
        )

        # get terminals: a step is terminal exactly when it earned the goal reward.
        terminals = np.zeros_like(reward, dtype=int)
        terminals[reward == GOAL_REWARD] = 1

        observation_chunks.append(observation)
        action_chunks.append(action)
        reward_chunks.append(reward)
        terminal_chunks.append(terminals)

    # Concatenate observations, actions, rewards, and terminals
    all_observations = np.concatenate(observation_chunks, axis=0)
    all_actions = np.concatenate(action_chunks, axis=0)
    all_rewards = np.concatenate(reward_chunks)
    all_terminals = np.concatenate(terminal_chunks)

    print("All observations shape:", all_observations.shape)
    print("All actions shape:", all_actions.shape)
    print("All rewards shape:", all_rewards.shape)
    print("All terminals shape:", all_terminals.shape)

    np.save(f'{OUTPUT_DIR}/all_observations_P00{i}.npy', all_observations)
    np.save(f'{OUTPUT_DIR}/all_actions_P00{i}.npy', all_actions)
    np.save(f'{OUTPUT_DIR}/all_rewards_P00{i}.npy', all_rewards)
    np.save(f'{OUTPUT_DIR}/all_terminals_P00{i}.npy', all_terminals)

In [None]:
'''
Args:
        observations (numpy.ndarray): N-D array. If the
            observation is a vector, the shape should be
            `(N, dim_observation)`. If the observations is an image, the shape
            should be `(N, C, H, W)`.
        actions (numpy.ndarray): N-D array. If the actions-space is
            continuous, the shape should be `(N, dim_action)`. If the
            action-space is discrete, the shape should be `(N,)`.
        rewards (numpy.ndarray): array of scalar rewards. The reward function
            should be defined as :math:`r_t = r(s_t, a_t)`.
        terminals (numpy.ndarray): array of binary terminal flags.
        episode_terminals (numpy.ndarray): array of binary episode terminal
            flags. The given data will be splitted based on this flag.
            This is useful if you want to specify the non-environment
            terminations (e.g. timeout). If ``None``, the episode terminations
            match the environment terminations.
        discrete_action (bool): flag to use the given actions as discrete
            action-space actions. If ``None``, the action type is automatically
            determined.
    '''
# The string above is kept as a reminder of the array layout expected by the
# (commented-out) MDPDataset call at the bottom of this cell.

# Load the merged "hard" arrays in the order observations, actions, rewards,
# terminals.
# NOTE(review): the extraction cell writes per-trajectory files
# (`hard/all_*_P00{i}.npy`); no cell visible in this notebook writes the
# merged files loaded here - confirm where they come from.
hard_all_observations, hard_all_actions, hard_all_rewards, hard_all_terminals = (
    np.load(npy_path)
    for npy_path in (
        'hard/all_observations.npy',
        'hard/all_actions.npy',
        'hard/all_rewards.npy',
        'hard/all_terminals.npy',
    )
)

# Number of steps flagged as terminal.
hard_terminal_count = np.count_nonzero(hard_all_terminals == 1)
print(hard_terminal_count)
# cql_dataset = MDPDataset(observations=all_observations,actions=all_actions,rewards=all_rewards,terminals=all_terminals,episode_terminals=all_terminals)

In [2]:
import numpy as np
from d3rlpy.dataset import Episode, MDPDataset, Transition

# Load the transition arrays from the working directory, in the order
# observations, actions, rewards, terminals.
all_observations, all_actions, all_rewards, all_terminals = (
    np.load(npy_path)
    for npy_path in (
        'all_observations.npy',
        'all_actions.npy',
        'all_rewards.npy',
        'all_terminals.npy',
    )
)

# Number of steps flagged as terminal.
print(np.count_nonzero(all_terminals == 1))

# Optional: append the "hard" arrays loaded in the cell above before building
# the dataset.
# all_observations = np.vstack((all_observations, hard_all_observations))
# all_actions = np.vstack((all_actions, hard_all_actions))
# all_rewards = np.hstack((all_rewards, hard_all_rewards))
# all_terminals = np.hstack((all_terminals, hard_all_terminals))
# print("All observations shape:", all_observations.shape)
# print("All actions shape:", all_actions.shape)
# print("All rewards shape:", all_rewards.shape)
# print("All terminals shape:", all_terminals.shape)

# The terminal flags are also passed as `episode_terminals`, so the data is
# split into episodes at every terminal step.
cql_dataset = MDPDataset(
    observations=all_observations,
    actions=all_actions,
    rewards=all_rewards,
    terminals=all_terminals,
    episode_terminals=all_terminals,
)

132


In [3]:
# Sanity check: show the shape of the action array held by the dataset.
actions_shape = cql_dataset.actions.shape
print(actions_shape)

(8185, 7)


In [9]:
from d3rlpy.algos import BEAR,CQL,TD3PlusBC,AWAC,BCQ

# Set up the offline RL algorithm. BCQ is the one trained here; the other
# imported algorithms (BEAR, CQL, TD3PlusBC, AWAC) are alternatives.
model = BCQ(use_gpu=True)

# split train and test episodes
# train_episodes, test_episodes = train_test_split(cql_dataset, test_size=0.25)

# Train on the whole dataset for 80 epochs; no evaluation episodes or scorers
# are passed to fit.
fit_options = dict(
    eval_episodes=None,
    n_epochs=80,
    scorers=None,
)
model.fit(cql_dataset, **fit_options)

[2m2023-05-04 21:01:18[0m [[32m[1mdebug    [0m] [1mRoundIterator is selected.[0m
[2m2023-05-04 21:01:18[0m [[32m[1minfo     [0m] [1mDirectory is created at d3rlpy_logs/BCQ_20230504210118[0m
[2m2023-05-04 21:01:18[0m [[32m[1mdebug    [0m] [1mBuilding models...[0m
[2m2023-05-04 21:01:18[0m [[32m[1mdebug    [0m] [1mModels have been built.[0m
[2m2023-05-04 21:01:18[0m [[32m[1minfo     [0m] [1mParameters are saved to d3rlpy_logs/BCQ_20230504210118/params.json[0m [36mparams[0m=[35m{'action_flexibility': 0.05, 'action_scaler': None, 'actor_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'actor_learning_rate': 0.001, 'actor_optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'batch_size': 100, 'beta': 0.5, 'critic_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': N

Epoch 1/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:01:20[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=1 step=81[0m [36mepoch[0m=[35m1[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00032325732855149257, 'time_algorithm_update': 0.02111275107772262, 'imitator_loss': 0.010488384989676651, 'critic_loss': 319.94853836518746, 'actor_loss': -1.7687759609134108, 'time_step': 0.021529871740458922}[0m [36mstep[0m=[35m81[0m
[2m2023-05-04 21:01:20[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_81.pt[0m


Epoch 2/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:01:22[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=2 step=162[0m [36mepoch[0m=[35m2[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003158074838143808, 'time_algorithm_update': 0.019842662929016867, 'imitator_loss': 0.008717886540164917, 'critic_loss': 311.5140074432632, 'actor_loss': -2.2175454239786405, 'time_step': 0.02024473378687729}[0m [36mstep[0m=[35m162[0m
[2m2023-05-04 21:01:22[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_162.pt[0m


Epoch 3/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:01:24[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=3 step=243[0m [36mepoch[0m=[35m3[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00029529171225465374, 'time_algorithm_update': 0.01951718036039376, 'imitator_loss': 0.008638597133764882, 'critic_loss': 310.9405100698824, 'actor_loss': -2.7947657858883894, 'time_step': 0.019894932523185825}[0m [36mstep[0m=[35m243[0m
[2m2023-05-04 21:01:24[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_243.pt[0m


Epoch 4/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:01:25[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=4 step=324[0m [36mepoch[0m=[35m4[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00028861010516131365, 'time_algorithm_update': 0.01956101405767747, 'imitator_loss': 0.008628414179209941, 'critic_loss': 309.75574048065846, 'actor_loss': -3.3226466590975536, 'time_step': 0.019930300889191805}[0m [36mstep[0m=[35m324[0m
[2m2023-05-04 21:01:25[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_324.pt[0m


Epoch 5/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:01:27[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=5 step=405[0m [36mepoch[0m=[35m5[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0002907293814199942, 'time_algorithm_update': 0.019490059511161145, 'imitator_loss': 0.008613380488514163, 'critic_loss': 304.17909417329014, 'actor_loss': -3.8598726472736877, 'time_step': 0.019863099227716893}[0m [36mstep[0m=[35m405[0m
[2m2023-05-04 21:01:27[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_405.pt[0m


Epoch 6/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:01:29[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=6 step=486[0m [36mepoch[0m=[35m6[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0002993477715386285, 'time_algorithm_update': 0.01999244866547761, 'imitator_loss': 0.008602776737124831, 'critic_loss': 304.97822382190714, 'actor_loss': -4.202215241797177, 'time_step': 0.020375519622991115}[0m [36mstep[0m=[35m486[0m
[2m2023-05-04 21:01:29[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_486.pt[0m


Epoch 7/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:01:30[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=7 step=567[0m [36mepoch[0m=[35m7[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0002915976959982036, 'time_algorithm_update': 0.01950434696527175, 'imitator_loss': 0.008586331184401556, 'critic_loss': 298.02980420766056, 'actor_loss': -4.683914084493378, 'time_step': 0.019874596301420234}[0m [36mstep[0m=[35m567[0m
[2m2023-05-04 21:01:30[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_567.pt[0m


Epoch 8/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:01:32[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=8 step=648[0m [36mepoch[0m=[35m8[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.000288000813236943, 'time_algorithm_update': 0.01953754601655183, 'imitator_loss': 0.0085996125974221, 'critic_loss': 296.0382679730286, 'actor_loss': -5.281346303445321, 'time_step': 0.019906997680664062}[0m [36mstep[0m=[35m648[0m
[2m2023-05-04 21:01:32[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_648.pt[0m


Epoch 9/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:01:34[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=9 step=729[0m [36mepoch[0m=[35m9[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0002993978100058473, 'time_algorithm_update': 0.01964838710832007, 'imitator_loss': 0.008577175978801133, 'critic_loss': 290.7807633711232, 'actor_loss': -5.784067424727075, 'time_step': 0.020030207104153104}[0m [36mstep[0m=[35m729[0m
[2m2023-05-04 21:01:34[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_729.pt[0m


Epoch 10/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:01:35[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=10 step=810[0m [36mepoch[0m=[35m10[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00029605111958068094, 'time_algorithm_update': 0.01965213116304374, 'imitator_loss': 0.008585055324214476, 'critic_loss': 292.4338002782545, 'actor_loss': -6.121489583710094, 'time_step': 0.020029733210434147}[0m [36mstep[0m=[35m810[0m
[2m2023-05-04 21:01:35[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_810.pt[0m


Epoch 11/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:01:37[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=11 step=891[0m [36mepoch[0m=[35m11[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00029386414421929256, 'time_algorithm_update': 0.019658736240716627, 'imitator_loss': 0.008587525789568454, 'critic_loss': 285.46332990313755, 'actor_loss': -6.631961869604794, 'time_step': 0.020036720935209296}[0m [36mstep[0m=[35m891[0m
[2m2023-05-04 21:01:37[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_891.pt[0m


Epoch 12/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:01:39[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=12 step=972[0m [36mepoch[0m=[35m12[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00029506212399329666, 'time_algorithm_update': 0.019657779622960975, 'imitator_loss': 0.008583633781031327, 'critic_loss': 281.0507535934448, 'actor_loss': -6.929077442781425, 'time_step': 0.020035764317453644}[0m [36mstep[0m=[35m972[0m
[2m2023-05-04 21:01:39[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_972.pt[0m


Epoch 13/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:01:40[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=13 step=1053[0m [36mepoch[0m=[35m13[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0002831676859914521, 'time_algorithm_update': 0.01958253354202082, 'imitator_loss': 0.008586852259382053, 'critic_loss': 281.0065495122086, 'actor_loss': -7.586777080724269, 'time_step': 0.019947163852644556}[0m [36mstep[0m=[35m1053[0m
[2m2023-05-04 21:01:40[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_1053.pt[0m


Epoch 14/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:01:42[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=14 step=1134[0m [36mepoch[0m=[35m14[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003364880879720052, 'time_algorithm_update': 0.019871847129162446, 'imitator_loss': 0.008585193907313141, 'critic_loss': 275.7055416938699, 'actor_loss': -7.786989783063347, 'time_step': 0.020289712482028537}[0m [36mstep[0m=[35m1134[0m
[2m2023-05-04 21:01:42[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_1134.pt[0m


Epoch 15/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:01:44[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=15 step=1215[0m [36mepoch[0m=[35m15[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0002932725129304109, 'time_algorithm_update': 0.01957367084644459, 'imitator_loss': 0.008613189429412653, 'critic_loss': 278.78586706629505, 'actor_loss': -8.38448565683247, 'time_step': 0.019948561986287434}[0m [36mstep[0m=[35m1215[0m
[2m2023-05-04 21:01:44[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_1215.pt[0m


Epoch 16/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:01:45[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=16 step=1296[0m [36mepoch[0m=[35m16[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00029707837987829136, 'time_algorithm_update': 0.01953971238783848, 'imitator_loss': 0.00858488255812798, 'critic_loss': 272.0738640273059, 'actor_loss': -8.874909271428615, 'time_step': 0.019919913492085023}[0m [36mstep[0m=[35m1296[0m
[2m2023-05-04 21:01:45[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_1296.pt[0m


Epoch 17/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:01:47[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=17 step=1377[0m [36mepoch[0m=[35m17[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00030937312561788675, 'time_algorithm_update': 0.01971698984687711, 'imitator_loss': 0.008574377624662939, 'critic_loss': 268.9036940821895, 'actor_loss': -9.32814915974935, 'time_step': 0.02011249389177487}[0m [36mstep[0m=[35m1377[0m
[2m2023-05-04 21:01:47[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_1377.pt[0m


Epoch 18/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:01:49[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=18 step=1458[0m [36mepoch[0m=[35m18[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0002893106437023775, 'time_algorithm_update': 0.019582763130282177, 'imitator_loss': 0.008603044486616128, 'critic_loss': 265.49002343423115, 'actor_loss': -9.618602258187753, 'time_step': 0.019952824086318783}[0m [36mstep[0m=[35m1458[0m
[2m2023-05-04 21:01:49[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_1458.pt[0m


Epoch 19/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:01:50[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=19 step=1539[0m [36mepoch[0m=[35m19[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00028756224078896607, 'time_algorithm_update': 0.01956742192492073, 'imitator_loss': 0.008574987479979977, 'critic_loss': 267.6051225691666, 'actor_loss': -10.32958807768645, 'time_step': 0.019937032534752364}[0m [36mstep[0m=[35m1539[0m
[2m2023-05-04 21:01:50[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_1539.pt[0m


Epoch 20/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:01:52[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=20 step=1620[0m [36mepoch[0m=[35m20[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00029066756919578267, 'time_algorithm_update': 0.019503499254768276, 'imitator_loss': 0.008585109317928184, 'critic_loss': 262.6185732204467, 'actor_loss': -10.673474335376127, 'time_step': 0.01987500543947573}[0m [36mstep[0m=[35m1620[0m
[2m2023-05-04 21:01:52[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_1620.pt[0m


Epoch 21/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:01:54[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=21 step=1701[0m [36mepoch[0m=[35m21[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00028769175211588543, 'time_algorithm_update': 0.01948656564877357, 'imitator_loss': 0.008581367551268619, 'critic_loss': 262.568152303313, 'actor_loss': -11.19072115862811, 'time_step': 0.01985523730148504}[0m [36mstep[0m=[35m1701[0m
[2m2023-05-04 21:01:54[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_1701.pt[0m


Epoch 22/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:01:55[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=22 step=1782[0m [36mepoch[0m=[35m22[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0002863848650896991, 'time_algorithm_update': 0.019428412119547527, 'imitator_loss': 0.008553540868753637, 'critic_loss': 259.82380175112206, 'actor_loss': -11.612115695152754, 'time_step': 0.019796677577642748}[0m [36mstep[0m=[35m1782[0m
[2m2023-05-04 21:01:55[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_1782.pt[0m


Epoch 23/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:01:57[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=23 step=1863[0m [36mepoch[0m=[35m23[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.000301993923422731, 'time_algorithm_update': 0.019750600979651933, 'imitator_loss': 0.008592207065243045, 'critic_loss': 256.29457344666673, 'actor_loss': -12.062135190139582, 'time_step': 0.020136947985048646}[0m [36mstep[0m=[35m1863[0m
[2m2023-05-04 21:01:57[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_1863.pt[0m


Epoch 24/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:01:59[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=24 step=1944[0m [36mepoch[0m=[35m24[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0002901053722993827, 'time_algorithm_update': 0.019522331379078054, 'imitator_loss': 0.008576232859473905, 'critic_loss': 253.16603427904624, 'actor_loss': -12.534609300118905, 'time_step': 0.01989228931474097}[0m [36mstep[0m=[35m1944[0m
[2m2023-05-04 21:01:59[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_1944.pt[0m


Epoch 25/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:00[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=25 step=2025[0m [36mepoch[0m=[35m25[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0002924571802586685, 'time_algorithm_update': 0.01956528498802656, 'imitator_loss': 0.008589022608910814, 'critic_loss': 249.79497300401147, 'actor_loss': -13.018383650132167, 'time_step': 0.01993870440824532}[0m [36mstep[0m=[35m2025[0m
[2m2023-05-04 21:02:00[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_2025.pt[0m


Epoch 26/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:02[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=26 step=2106[0m [36mepoch[0m=[35m26[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0002986148551658348, 'time_algorithm_update': 0.019821973494541498, 'imitator_loss': 0.008588179146847974, 'critic_loss': 246.74216432924624, 'actor_loss': -13.383858786688911, 'time_step': 0.020204267384093484}[0m [36mstep[0m=[35m2106[0m
[2m2023-05-04 21:02:02[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_2106.pt[0m


Epoch 27/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:04[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=27 step=2187[0m [36mepoch[0m=[35m27[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00029372874601387686, 'time_algorithm_update': 0.01947074466281467, 'imitator_loss': 0.008583230397629517, 'critic_loss': 244.9531644071326, 'actor_loss': -13.919462039146895, 'time_step': 0.01984746073499138}[0m [36mstep[0m=[35m2187[0m
[2m2023-05-04 21:02:04[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_2187.pt[0m


Epoch 28/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:05[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=28 step=2268[0m [36mepoch[0m=[35m28[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003183506153247975, 'time_algorithm_update': 0.019655348342141988, 'imitator_loss': 0.008590785981972276, 'critic_loss': 241.60206060792194, 'actor_loss': -14.306414462901929, 'time_step': 0.020060854193605024}[0m [36mstep[0m=[35m2268[0m
[2m2023-05-04 21:02:05[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_2268.pt[0m


Epoch 29/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:07[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=29 step=2349[0m [36mepoch[0m=[35m29[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00031783551345636817, 'time_algorithm_update': 0.01996451837045175, 'imitator_loss': 0.008577923963053359, 'critic_loss': 241.47752985836547, 'actor_loss': -14.900261337374463, 'time_step': 0.020370171393877194}[0m [36mstep[0m=[35m2349[0m
[2m2023-05-04 21:02:07[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_2349.pt[0m


Epoch 30/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:09[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=30 step=2430[0m [36mepoch[0m=[35m30[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003113187389609254, 'time_algorithm_update': 0.01984747839562687, 'imitator_loss': 0.008579155240483858, 'critic_loss': 237.2097368174129, 'actor_loss': -15.226341129821023, 'time_step': 0.020245943540408286}[0m [36mstep[0m=[35m2430[0m
[2m2023-05-04 21:02:09[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_2430.pt[0m


Epoch 31/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:11[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=31 step=2511[0m [36mepoch[0m=[35m31[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003238548467188706, 'time_algorithm_update': 0.020016785021181458, 'imitator_loss': 0.00856896513429137, 'critic_loss': 235.6596167664745, 'actor_loss': -15.68218226491669, 'time_step': 0.02042752136418849}[0m [36mstep[0m=[35m2511[0m
[2m2023-05-04 21:02:11[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_2511.pt[0m


Epoch 32/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:12[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=32 step=2592[0m [36mepoch[0m=[35m32[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003244670820824894, 'time_algorithm_update': 0.019992269115683473, 'imitator_loss': 0.008584673637179312, 'critic_loss': 230.6327829596437, 'actor_loss': -16.056468857659233, 'time_step': 0.020405107074313693}[0m [36mstep[0m=[35m2592[0m
[2m2023-05-04 21:02:12[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_2592.pt[0m


Epoch 33/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:14[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=33 step=2673[0m [36mepoch[0m=[35m33[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003041190865599079, 'time_algorithm_update': 0.01950350808508602, 'imitator_loss': 0.00857972527888638, 'critic_loss': 229.69744522409675, 'actor_loss': -16.512379328409832, 'time_step': 0.01989323710217888}[0m [36mstep[0m=[35m2673[0m
[2m2023-05-04 21:02:14[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_2673.pt[0m


Epoch 34/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:16[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=34 step=2754[0m [36mepoch[0m=[35m34[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0002971931740089699, 'time_algorithm_update': 0.0194277321850812, 'imitator_loss': 0.008593928856476222, 'critic_loss': 227.87622442289634, 'actor_loss': -17.00969763155337, 'time_step': 0.019806017110377182}[0m [36mstep[0m=[35m2754[0m
[2m2023-05-04 21:02:16[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_2754.pt[0m


Epoch 35/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:17[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=35 step=2835[0m [36mepoch[0m=[35m35[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00031328495637870127, 'time_algorithm_update': 0.02002512084113227, 'imitator_loss': 0.00858172970927424, 'critic_loss': 225.1070935372953, 'actor_loss': -17.381909405743635, 'time_step': 0.020423812630735796}[0m [36mstep[0m=[35m2835[0m
[2m2023-05-04 21:02:17[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_2835.pt[0m


Epoch 36/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:19[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=36 step=2916[0m [36mepoch[0m=[35m36[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00030391304581253617, 'time_algorithm_update': 0.0196403103110231, 'imitator_loss': 0.008579459202326374, 'critic_loss': 225.44697609635782, 'actor_loss': -17.786373774210613, 'time_step': 0.02002823794329608}[0m [36mstep[0m=[35m2916[0m
[2m2023-05-04 21:02:19[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_2916.pt[0m


Epoch 37/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:21[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=37 step=2997[0m [36mepoch[0m=[35m37[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003177060021294488, 'time_algorithm_update': 0.019822441501381956, 'imitator_loss': 0.008581828854886103, 'critic_loss': 220.60348973110501, 'actor_loss': -18.27165963914659, 'time_step': 0.020225533732661494}[0m [36mstep[0m=[35m2997[0m
[2m2023-05-04 21:02:21[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_2997.pt[0m


Epoch 38/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:22[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=38 step=3078[0m [36mepoch[0m=[35m38[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0002910472728587963, 'time_algorithm_update': 0.019359429677327473, 'imitator_loss': 0.008593121330817173, 'critic_loss': 217.06708073100927, 'actor_loss': -18.635358975257404, 'time_step': 0.01973062680091387}[0m [36mstep[0m=[35m3078[0m
[2m2023-05-04 21:02:22[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_3078.pt[0m


Epoch 39/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:24[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=39 step=3159[0m [36mepoch[0m=[35m39[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003031506950472608, 'time_algorithm_update': 0.01977776598047327, 'imitator_loss': 0.008561335383328024, 'critic_loss': 217.6804795226105, 'actor_loss': -19.11231971081392, 'time_step': 0.02016476054250458}[0m [36mstep[0m=[35m3159[0m
[2m2023-05-04 21:02:24[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_3159.pt[0m


Epoch 40/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:26[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=40 step=3240[0m [36mepoch[0m=[35m40[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00030450173366216964, 'time_algorithm_update': 0.019567498454341182, 'imitator_loss': 0.008567064672846484, 'critic_loss': 207.93111621367711, 'actor_loss': -19.40344695103021, 'time_step': 0.01995435467472783}[0m [36mstep[0m=[35m3240[0m
[2m2023-05-04 21:02:26[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_3240.pt[0m


Epoch 41/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:27[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=41 step=3321[0m [36mepoch[0m=[35m41[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0004039517155400029, 'time_algorithm_update': 0.01981258981021834, 'imitator_loss': 0.008580409530780198, 'critic_loss': 204.1167601597162, 'actor_loss': -19.861642672691815, 'time_step': 0.020297812826839495}[0m [36mstep[0m=[35m3321[0m
[2m2023-05-04 21:02:27[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_3321.pt[0m


Epoch 42/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:29[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=42 step=3402[0m [36mepoch[0m=[35m42[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00029250721872588735, 'time_algorithm_update': 0.01954695913526747, 'imitator_loss': 0.008572966032833964, 'critic_loss': 212.1454207389443, 'actor_loss': -20.22589820108296, 'time_step': 0.01992030202606578}[0m [36mstep[0m=[35m3402[0m
[2m2023-05-04 21:02:29[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_3402.pt[0m


Epoch 43/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:31[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=43 step=3483[0m [36mepoch[0m=[35m43[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00030634727007077065, 'time_algorithm_update': 0.019667431160255714, 'imitator_loss': 0.00855419520334697, 'critic_loss': 207.344197915798, 'actor_loss': -20.76913390924901, 'time_step': 0.020058525933159724}[0m [36mstep[0m=[35m3483[0m
[2m2023-05-04 21:02:31[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_3483.pt[0m


Epoch 44/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:32[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=44 step=3564[0m [36mepoch[0m=[35m44[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00030345092585057386, 'time_algorithm_update': 0.019676379215570142, 'imitator_loss': 0.008593746001061834, 'critic_loss': 208.2647953795982, 'actor_loss': -21.002663388664338, 'time_step': 0.020062690899695878}[0m [36mstep[0m=[35m3564[0m
[2m2023-05-04 21:02:33[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_3564.pt[0m


Epoch 45/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:34[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=45 step=3645[0m [36mepoch[0m=[35m45[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00030498151425962096, 'time_algorithm_update': 0.019568967230526018, 'imitator_loss': 0.008574647300037336, 'critic_loss': 204.33430000219815, 'actor_loss': -21.427218707991234, 'time_step': 0.01995798099188157}[0m [36mstep[0m=[35m3645[0m
[2m2023-05-04 21:02:34[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_3645.pt[0m


Epoch 46/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:36[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=46 step=3726[0m [36mepoch[0m=[35m46[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00031056521851339456, 'time_algorithm_update': 0.019859711329142254, 'imitator_loss': 0.008579743911086776, 'critic_loss': 203.23920614134383, 'actor_loss': -21.895208971000013, 'time_step': 0.020256433957888755}[0m [36mstep[0m=[35m3726[0m
[2m2023-05-04 21:02:36[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_3726.pt[0m


Epoch 47/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:38[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=47 step=3807[0m [36mepoch[0m=[35m47[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00030515223373601464, 'time_algorithm_update': 0.019550223409393688, 'imitator_loss': 0.008588346807907024, 'critic_loss': 200.49648917380168, 'actor_loss': -22.212132936642494, 'time_step': 0.019940903157363705}[0m [36mstep[0m=[35m3807[0m
[2m2023-05-04 21:02:38[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_3807.pt[0m


Epoch 48/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:39[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=48 step=3888[0m [36mepoch[0m=[35m48[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003171202577190635, 'time_algorithm_update': 0.019863102171156142, 'imitator_loss': 0.008580992389240382, 'critic_loss': 197.15230824623578, 'actor_loss': -22.61043404944149, 'time_step': 0.020267039169499904}[0m [36mstep[0m=[35m3888[0m
[2m2023-05-04 21:02:39[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_3888.pt[0m


Epoch 49/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:41[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=49 step=3969[0m [36mepoch[0m=[35m49[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003302568270836347, 'time_algorithm_update': 0.02022744991161205, 'imitator_loss': 0.00859149586248361, 'critic_loss': 196.63896854424183, 'actor_loss': -23.018530504203138, 'time_step': 0.02064745514481156}[0m [36mstep[0m=[35m3969[0m
[2m2023-05-04 21:02:41[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_3969.pt[0m


Epoch 50/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:43[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=50 step=4050[0m [36mepoch[0m=[35m50[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0002751085493299696, 'time_algorithm_update': 0.019288419205465434, 'imitator_loss': 0.008578642863596294, 'critic_loss': 194.3654245366285, 'actor_loss': -23.395609820330584, 'time_step': 0.01964350982948586}[0m [36mstep[0m=[35m4050[0m
[2m2023-05-04 21:02:43[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_4050.pt[0m


Epoch 51/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:44[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=51 step=4131[0m [36mepoch[0m=[35m51[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0002807216879762249, 'time_algorithm_update': 0.019557487817458166, 'imitator_loss': 0.008567929371363588, 'critic_loss': 192.8705237272805, 'actor_loss': -23.737825723341953, 'time_step': 0.019916340156837745}[0m [36mstep[0m=[35m4131[0m
[2m2023-05-04 21:02:44[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_4131.pt[0m


Epoch 52/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:46[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=52 step=4212[0m [36mepoch[0m=[35m52[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0002744021239104094, 'time_algorithm_update': 0.01943147035292637, 'imitator_loss': 0.008563689174659458, 'critic_loss': 192.4923261663428, 'actor_loss': -24.069495730929905, 'time_step': 0.019786743470180182}[0m [36mstep[0m=[35m4212[0m
[2m2023-05-04 21:02:46[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_4212.pt[0m


Epoch 53/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:48[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=53 step=4293[0m [36mepoch[0m=[35m53[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0002870530257990331, 'time_algorithm_update': 0.019594195448322062, 'imitator_loss': 0.008632757091595803, 'critic_loss': 187.36710734978135, 'actor_loss': -24.482150631186403, 'time_step': 0.019959888340514383}[0m [36mstep[0m=[35m4293[0m
[2m2023-05-04 21:02:48[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_4293.pt[0m


Epoch 54/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:49[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=54 step=4374[0m [36mepoch[0m=[35m54[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0002793412149688344, 'time_algorithm_update': 0.019350210825602215, 'imitator_loss': 0.008576859472848383, 'critic_loss': 185.3161191697014, 'actor_loss': -24.87325272736726, 'time_step': 0.019709590040607218}[0m [36mstep[0m=[35m4374[0m
[2m2023-05-04 21:02:49[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_4374.pt[0m


Epoch 55/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:51[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=55 step=4455[0m [36mepoch[0m=[35m55[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0002891575848614728, 'time_algorithm_update': 0.0195412900712755, 'imitator_loss': 0.008551271632313728, 'critic_loss': 187.54402527728197, 'actor_loss': -25.242511419602383, 'time_step': 0.019911857298862787}[0m [36mstep[0m=[35m4455[0m
[2m2023-05-04 21:02:51[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_4455.pt[0m


Epoch 56/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:53[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=56 step=4536[0m [36mepoch[0m=[35m56[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0002896197048234351, 'time_algorithm_update': 0.019960156193485967, 'imitator_loss': 0.008584547448719356, 'critic_loss': 183.81578587308343, 'actor_loss': -25.6595144154113, 'time_step': 0.020333346025443372}[0m [36mstep[0m=[35m4536[0m
[2m2023-05-04 21:02:53[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_4536.pt[0m


Epoch 57/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:54[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=57 step=4617[0m [36mepoch[0m=[35m57[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00029779657905484424, 'time_algorithm_update': 0.01975553712727111, 'imitator_loss': 0.00857699627953547, 'critic_loss': 182.74270579108486, 'actor_loss': -25.890892758781526, 'time_step': 0.020137336519029405}[0m [36mstep[0m=[35m4617[0m
[2m2023-05-04 21:02:54[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_4617.pt[0m


Epoch 58/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:56[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=58 step=4698[0m [36mepoch[0m=[35m58[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003050668739978178, 'time_algorithm_update': 0.019717749254203137, 'imitator_loss': 0.008564690576383361, 'critic_loss': 178.842926175575, 'actor_loss': -26.2914646760917, 'time_step': 0.020108458436565634}[0m [36mstep[0m=[35m4698[0m
[2m2023-05-04 21:02:56[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_4698.pt[0m


Epoch 59/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:58[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=59 step=4779[0m [36mepoch[0m=[35m59[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00030622070218309943, 'time_algorithm_update': 0.019873636740225332, 'imitator_loss': 0.00856865888438475, 'critic_loss': 178.17652978466387, 'actor_loss': -26.740089416503906, 'time_step': 0.02026771910396623}[0m [36mstep[0m=[35m4779[0m
[2m2023-05-04 21:02:58[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_4779.pt[0m


Epoch 60/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:02:59[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=60 step=4860[0m [36mepoch[0m=[35m60[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00033512821903935184, 'time_algorithm_update': 0.020614824177306375, 'imitator_loss': 0.008545585844757748, 'critic_loss': 176.15818910420307, 'actor_loss': -27.04816578052662, 'time_step': 0.02104204378010314}[0m [36mstep[0m=[35m4860[0m
[2m2023-05-04 21:03:00[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_4860.pt[0m


Epoch 61/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:03:01[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=61 step=4941[0m [36mepoch[0m=[35m61[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00033481915791829425, 'time_algorithm_update': 0.020199946415277174, 'imitator_loss': 0.008580900699958021, 'critic_loss': 168.20369695016632, 'actor_loss': -27.328501713128738, 'time_step': 0.020628475848539374}[0m [36mstep[0m=[35m4941[0m
[2m2023-05-04 21:03:01[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_4941.pt[0m


Epoch 62/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:03:03[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=62 step=5022[0m [36mepoch[0m=[35m62[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00032446119520399306, 'time_algorithm_update': 0.02027632866376712, 'imitator_loss': 0.008568738690680928, 'critic_loss': 174.49320276552973, 'actor_loss': -27.708910200330948, 'time_step': 0.020690882647479023}[0m [36mstep[0m=[35m5022[0m
[2m2023-05-04 21:03:03[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_5022.pt[0m


Epoch 63/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:03:05[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=63 step=5103[0m [36mepoch[0m=[35m63[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003046871703348042, 'time_algorithm_update': 0.019693613052368164, 'imitator_loss': 0.00858806240958748, 'critic_loss': 172.9761761188231, 'actor_loss': -28.065712210572798, 'time_step': 0.02008332735226478}[0m [36mstep[0m=[35m5103[0m
[2m2023-05-04 21:03:05[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_5103.pt[0m


Epoch 64/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:03:06[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=64 step=5184[0m [36mepoch[0m=[35m64[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0002975375564010055, 'time_algorithm_update': 0.019607682286957164, 'imitator_loss': 0.008581663441648821, 'critic_loss': 170.60141326652632, 'actor_loss': -28.43981476771979, 'time_step': 0.019987079832288954}[0m [36mstep[0m=[35m5184[0m
[2m2023-05-04 21:03:06[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_5184.pt[0m


Epoch 65/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:03:08[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=65 step=5265[0m [36mepoch[0m=[35m65[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0002978054093725887, 'time_algorithm_update': 0.019540345227276837, 'imitator_loss': 0.0085637679748973, 'critic_loss': 167.77948281136744, 'actor_loss': -28.78329116915479, 'time_step': 0.019920516897130897}[0m [36mstep[0m=[35m5265[0m
[2m2023-05-04 21:03:08[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_5265.pt[0m


Epoch 66/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:03:10[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=66 step=5346[0m [36mepoch[0m=[35m66[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0002961129318048925, 'time_algorithm_update': 0.019559536451174888, 'imitator_loss': 0.008569841135155272, 'critic_loss': 169.29991422901736, 'actor_loss': -29.003870763896423, 'time_step': 0.019939419663982626}[0m [36mstep[0m=[35m5346[0m
[2m2023-05-04 21:03:10[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_5346.pt[0m


Epoch 67/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:03:11[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=67 step=5427[0m [36mepoch[0m=[35m67[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003034744733645592, 'time_algorithm_update': 0.019954696113680614, 'imitator_loss': 0.008584248212476572, 'critic_loss': 165.37565677079522, 'actor_loss': -29.41000516915027, 'time_step': 0.02034114319601177}[0m [36mstep[0m=[35m5427[0m
[2m2023-05-04 21:03:11[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_5427.pt[0m


Epoch 68/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:03:13[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=68 step=5508[0m [36mepoch[0m=[35m68[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0002893489084126037, 'time_algorithm_update': 0.019410321741928287, 'imitator_loss': 0.008565288267017883, 'critic_loss': 164.758679613655, 'actor_loss': -29.77650296246564, 'time_step': 0.01978248725702733}[0m [36mstep[0m=[35m5508[0m
[2m2023-05-04 21:03:13[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_5508.pt[0m


Epoch 69/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:03:15[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=69 step=5589[0m [36mepoch[0m=[35m69[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0002899552568977262, 'time_algorithm_update': 0.019475860360227984, 'imitator_loss': 0.008582294631151505, 'critic_loss': 163.19943907378632, 'actor_loss': -30.03664621894742, 'time_step': 0.019844081666734483}[0m [36mstep[0m=[35m5589[0m
[2m2023-05-04 21:03:15[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_5589.pt[0m


Epoch 70/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:03:16[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=70 step=5670[0m [36mepoch[0m=[35m70[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00028898980882432725, 'time_algorithm_update': 0.019495292946144386, 'imitator_loss': 0.008591349236667156, 'critic_loss': 162.63384576105042, 'actor_loss': -30.380650296623323, 'time_step': 0.019865436318479937}[0m [36mstep[0m=[35m5670[0m
[2m2023-05-04 21:03:16[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_5670.pt[0m


Epoch 71/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:03:18[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=71 step=5751[0m [36mepoch[0m=[35m71[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.000293151831921236, 'time_algorithm_update': 0.019521969336050528, 'imitator_loss': 0.008584323827821164, 'critic_loss': 159.00899995182766, 'actor_loss': -30.693816714816624, 'time_step': 0.019896860475893372}[0m [36mstep[0m=[35m5751[0m
[2m2023-05-04 21:03:18[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_5751.pt[0m


Epoch 72/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:03:20[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=72 step=5832[0m [36mepoch[0m=[35m72[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00029590394761827256, 'time_algorithm_update': 0.019475201029836395, 'imitator_loss': 0.008572669918246475, 'critic_loss': 158.92796539736014, 'actor_loss': -31.099993058192876, 'time_step': 0.019852305635993862}[0m [36mstep[0m=[35m5832[0m
[2m2023-05-04 21:03:20[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_5832.pt[0m


Epoch 73/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:03:21[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=73 step=5913[0m [36mepoch[0m=[35m73[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00030042507030345777, 'time_algorithm_update': 0.019603923515037255, 'imitator_loss': 0.008566719445365446, 'critic_loss': 155.08595515603636, 'actor_loss': -31.39128225821036, 'time_step': 0.019984265904367707}[0m [36mstep[0m=[35m5913[0m
[2m2023-05-04 21:03:21[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_5913.pt[0m


Epoch 74/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:03:23[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=74 step=5994[0m [36mepoch[0m=[35m74[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00029909463576328606, 'time_algorithm_update': 0.019706364031191224, 'imitator_loss': 0.008587786755352109, 'critic_loss': 155.74564448661275, 'actor_loss': -31.685066411524645, 'time_step': 0.02009045636212384}[0m [36mstep[0m=[35m5994[0m
[2m2023-05-04 21:03:23[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_5994.pt[0m


Epoch 75/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:03:25[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=75 step=6075[0m [36mepoch[0m=[35m75[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00029222464855806326, 'time_algorithm_update': 0.01958937409483356, 'imitator_loss': 0.00854798013712337, 'critic_loss': 153.80123550849564, 'actor_loss': -31.95303803903085, 'time_step': 0.019965863522188165}[0m [36mstep[0m=[35m6075[0m
[2m2023-05-04 21:03:25[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_6075.pt[0m


Epoch 76/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:03:26[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=76 step=6156[0m [36mepoch[0m=[35m76[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00029075292893397957, 'time_algorithm_update': 0.019562385700367117, 'imitator_loss': 0.008582773940338765, 'critic_loss': 151.0496697206004, 'actor_loss': -32.28977172757372, 'time_step': 0.01993482201187699}[0m [36mstep[0m=[35m6156[0m
[2m2023-05-04 21:03:26[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_6156.pt[0m


Epoch 77/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:03:28[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=77 step=6237[0m [36mepoch[0m=[35m77[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00030041918342496143, 'time_algorithm_update': 0.019622661449291087, 'imitator_loss': 0.00859365806006539, 'critic_loss': 152.12897086755177, 'actor_loss': -32.61534712049696, 'time_step': 0.0200056411601879}[0m [36mstep[0m=[35m6237[0m
[2m2023-05-04 21:03:28[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_6237.pt[0m


Epoch 78/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:03:30[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=78 step=6318[0m [36mepoch[0m=[35m78[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00029383470982681086, 'time_algorithm_update': 0.019603661548944166, 'imitator_loss': 0.008567686204189136, 'critic_loss': 148.91914911001322, 'actor_loss': -32.933667830478996, 'time_step': 0.019979382738654995}[0m [36mstep[0m=[35m6318[0m
[2m2023-05-04 21:03:30[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_6318.pt[0m


Epoch 79/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:03:31[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=79 step=6399[0m [36mepoch[0m=[35m79[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00029874142305350596, 'time_algorithm_update': 0.019537351749561453, 'imitator_loss': 0.008567059326365038, 'critic_loss': 148.79886917599742, 'actor_loss': -33.278686947292755, 'time_step': 0.019919569109692985}[0m [36mstep[0m=[35m6399[0m
[2m2023-05-04 21:03:31[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_6399.pt[0m


Epoch 80/80:   0%|          | 0/81 [00:00<?, ?it/s]

[2m2023-05-04 21:03:33[0m [[32m[1minfo     [0m] [1mBCQ_20230504210118: epoch=80 step=6480[0m [36mepoch[0m=[35m80[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0002937199156961323, 'time_algorithm_update': 0.019468310438556437, 'imitator_loss': 0.008571585755289338, 'critic_loss': 145.1495082878772, 'actor_loss': -33.54675942880136, 'time_step': 0.019845320854657962}[0m [36mstep[0m=[35m6480[0m
[2m2023-05-04 21:03:33[0m [[32m[1minfo     [0m] [1mModel parameters are saved to d3rlpy_logs/BCQ_20230504210118/model_6480.pt[0m


[(1,
  {'time_sample_batch': 0.00032325732855149257,
   'time_algorithm_update': 0.02111275107772262,
   'imitator_loss': 0.010488384989676651,
   'critic_loss': 319.94853836518746,
   'actor_loss': -1.7687759609134108,
   'time_step': 0.021529871740458922}),
 (2,
  {'time_sample_batch': 0.0003158074838143808,
   'time_algorithm_update': 0.019842662929016867,
   'imitator_loss': 0.008717886540164917,
   'critic_loss': 311.5140074432632,
   'actor_loss': -2.2175454239786405,
   'time_step': 0.02024473378687729}),
 (3,
  {'time_sample_batch': 0.00029529171225465374,
   'time_algorithm_update': 0.01951718036039376,
   'imitator_loss': 0.008638597133764882,
   'critic_loss': 310.9405100698824,
   'actor_loss': -2.7947657858883894,
   'time_step': 0.019894932523185825}),
 (4,
  {'time_sample_batch': 0.00028861010516131365,
   'time_algorithm_update': 0.01956101405767747,
   'imitator_loss': 0.008628414179209941,
   'critic_loss': 309.75574048065846,
   'actor_loss': -3.3226466590975536,
   

In [None]:
# Total of the first 64 entries of `all_rewards`.
first_chunk_total = np.sum(all_rewards[:64])
print(first_chunk_total)

In [None]:
import numpy as np
from d3rlpy.dataset import Episode, MDPDataset, Transition

# Column counts used for the observation and action accumulators below.
OBSERVATION_SIZE = 768
ACTION_SIZE = 7
BATCH_SIZE = 64  # not referenced in this cell

# Empty accumulators; each trajectory's arrays are stacked onto these in turn.
all_observations = np.empty((0, OBSERVATION_SIZE))
all_actions = np.empty((0, ACTION_SIZE))
all_rewards = np.empty(0)
all_terminals = np.empty(0, dtype=bool)

# Walk trajectories P000..P009, leaving out P007.
for i in (n for n in range(0, 10) if n != 7):
    observation = np.load(f'hard/all_observations_P00{i}.npy')
    action = np.load(f'hard/all_actions_P00{i}.npy')
    reward = np.load(f'hard/all_rewards_P00{i}.npy')
    terminals = np.load(f'hard/all_terminals_P00{i}.npy')

    # 2-D arrays grow by rows, 1-D arrays grow end to end.
    all_observations = np.vstack([all_observations, observation])
    all_actions = np.vstack([all_actions, action])
    all_rewards = np.hstack([all_rewards, reward])
    all_terminals = np.hstack([all_terminals, terminals])

    # Report the running sizes after this trajectory has been appended.
    for label, stacked in (
        ("All observations shape:", all_observations),
        ("All actions shape:", all_actions),
        ("All rewards shape:", all_rewards),
        ("All terminals shape:", all_terminals),
    ):
        print(label, stacked.shape)

# The same flags are passed as both `terminals` and `episode_terminals`.
cql_dataset = MDPDataset(
    observations=all_observations,
    actions=all_actions,
    rewards=all_rewards,
    terminals=all_terminals,
    episode_terminals=all_terminals,
)

In [None]:
from d3rlpy.algos import CQL
# Checkpoint restored into `cql01` below.
CQL01_CHECKPOINT = '/home/ubuntu/camelmera/models/gym/multimodal/d3rlpy_logs/CQL_20230503011241/model_40.pt'

# CPU-only learner; it is built against `cql_dataset` before the weights are loaded.
cql01 = CQL(use_gpu=False)
cql01.build_with_dataset(cql_dataset)
cql01.load_model(CQL01_CHECKPOINT)

In [None]:
from d3rlpy.algos import CQL

# Files written by run CQL_20230504012746: its saved parameters and one checkpoint.
cql_params_path = 'd3rlpy_logs/CQL_20230504012746/params.json'
cql_weights_path = 'd3rlpy_logs/CQL_20230504012746/model_310.pt'

# Recreate the learner from the saved parameters, then load the checkpoint into it.
cql = CQL.from_json(cql_params_path)
cql.load_model(cql_weights_path)

In [None]:
# Fit `cql` on `cql_dataset` for 30 epochs; no evaluation episodes or scorers are supplied.
cql.fit(cql_dataset, n_epochs=30, eval_episodes=None, scorers=None)

In [None]:
# Run 10 epochs of fitting on the same `cql` object; no evaluation episodes or scorers are supplied.
cql.fit(cql_dataset, n_epochs=10, eval_episodes=None, scorers=None)

In [None]:
import numpy as np
# Reload the saved rewards and show the first 64 of them.
all_rewards = np.load('all_rewards.npy')
first_rewards = all_rewards[:64]
print(first_rewards)

In [None]:
import functools
import numpy as np

CHUNK_SIZE = 64        # number of consecutive observations that share one goal
GOAL_REWARD = 100      # value passed to reward_function as `goal_reward`
GOAL_THRESHOLD = 0.01  # value passed to reward_function as `threshold`

all_observations = np.load('all_observations.npy')
print(all_observations.shape)

# Gather per-chunk results and join them once at the end, rather than
# re-copying the growing arrays on every iteration.
reward_chunks = []
terminal_chunks = []
for i in range(0, len(all_observations), CHUNK_SIZE):
    # Slicing clamps at the end of the array, so the last chunk may be shorter.
    observation = all_observations[i:i + CHUNK_SIZE]
    # The last embedding of each chunk is used as that chunk's goal.
    goal = observation[-1]
    partial_function = functools.partial(
        reward_function,
        goal_embedding=goal,
        threshold=GOAL_THRESHOLD,
        goal_reward=GOAL_REWARD,
    )
    # Build the array with an explicit float dtype. np.apply_along_axis takes
    # its output dtype from the first call's result, so an integer goal reward
    # on the first row would truncate every later negative distance.
    reward = np.array([partial_function(row) for row in observation], dtype=float)
    reward_chunks.append(reward)
    # get terminals: 1 where the goal reward was given, 0 elsewhere
    terminal = (reward == GOAL_REWARD).astype(int)
    terminal_chunks.append(terminal)

all_rewards = np.concatenate(reward_chunks) if reward_chunks else np.empty(0)
# Kept as integers (previously promoted to float by a float-typed seed array).
all_terminals = np.concatenate(terminal_chunks) if terminal_chunks else np.empty(0, dtype=int)
if terminal_chunks:
    # Shows the flags of the last chunk only.
    print(terminal)

In [None]:
# Count the entries flagged as terminal, then express the array length in units of 64.
terminal_mask = all_terminals == 1
count_100 = np.count_nonzero(terminal_mask)
print("Number of terminals:", count_100)
print(len(all_terminals) / 64)

In [None]:
# Write the terminal flags and rewards to .npy files in the working directory.
for filename, values in (
    ('all_terminals.npy', all_terminals),
    ('all_rewards.npy', all_rewards),
):
    np.save(filename, values)

In [None]:
# Inspect the stored terminal flags of trajectory P000: the raw array, how many
# entries equal 1, and the array length in units of 64.
original_terminals = np.load('hard/all_terminals_P000.npy')
original_terminal_count = np.count_nonzero(original_terminals == 1)
original_chunk_count = len(original_terminals) / 64
print(original_terminals)
print(original_terminal_count)
print(original_chunk_count)

In [None]:
# Sum of the first 100 entries of `reward`; left as a bare expression so the cell displays it.
reward[:100].sum()

In [None]:
# Number of entries currently held in `reward`.
reward_length = len(reward)
print(reward_length)

In [None]:
# Keep every 10th row of `observation` (all columns) to check the difference
# around goal = observation[-100].
DOWNSAMPLE_STEP = 10
downsampled_observation = observation[::DOWNSAMPLE_STEP, :]
print(downsampled_observation.shape)

In [None]:
# get reward
GOAL_REWARD = 100  # value passed to reward_function as `goal_reward`
# The goal is the 23rd row from the end of the downsampled observations.
goal = downsampled_observation[-23]
partial_function = functools.partial(
    reward_function,
    goal_embedding=goal,
    threshold=0.02,
    goal_reward=GOAL_REWARD,
)
# Build the array with an explicit float dtype. np.apply_along_axis takes its
# output dtype from the first call's result, so an integer goal reward on the
# first row would truncate every later negative distance.
reward = np.array(
    [partial_function(row) for row in downsampled_observation], dtype=float
)
# get terminals: 1 where the goal reward was given, 0 elsewhere
terminals = (reward == GOAL_REWARD).astype(int)
count_100 = np.count_nonzero(reward == GOAL_REWARD)
print("Number of elements with the value 100:", count_100)

In [None]:
# Compare the two action sets: for columns 0 and 3 of each, print the
# per-column min, max, mean and standard deviation on one line.
easy_actions = np.load('all_actions.npy')
hard_actions = np.load('hard/all_actions.npy')

for action_set in (easy_actions, hard_actions):
    for column in (0, 3):
        print(*(
            reduce_fn(action_set, axis=0)[column]
            for reduce_fn in (np.min, np.max, np.mean, np.std)
        ))