In [46]:
import numpy as np
import matplotlib.pyplot as plt
import hashlib
import networkx as nx
import random
import math
import d3rlpy
import pickle
import copy

import utils

### Dataset Building

In [47]:
def channelfirst_for_d3rlpy(arr):
    return np.transpose(arr, (2, 0, 1))

In [48]:
def get_experience(env, model_path, seed, episodes=10, argmax=False, memory=False, text=False):
    utils.seed(seed)
    # Load environment
    env = utils.make_env(env, seed, render_mode="human")
    print("Environment loaded\n")

    # Load agent
    env.action_space.n = 3
    model_dir = utils.get_model_dir(model_path)
    agent = utils.Agent(env.observation_space, env.action_space, model_dir,
                        argmax=argmax, use_memory=memory, use_text=text)
    print("Agent loaded\n")
    # Run the agent
    episode_list = []
    hash_state_mapping = {}
    for i in range(episodes):
        if i % 50 == 0:
            print(f"collected experiences {i}")
        state_tuples = []
        obs, _ = env.reset()
        count = 0
        while True:
            current_tuple = []
            current_tuple.append(env.hash())
            if env.hash() not in hash_state_mapping.keys():
                hash_state_mapping[env.hash()] = channelfirst_for_d3rlpy(obs['image'])
            action = agent.get_action(obs)
            obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated | truncated
            agent.analyze_feedback(reward, done)
            count += 1
            current_tuple.extend([action, reward, env.hash(), done])
            if env.hash() not in hash_state_mapping.keys():
                hash_state_mapping[env.hash()] = channelfirst_for_d3rlpy(obs['image'])
            state_tuples.append(current_tuple)

            if done:
                break
        episode_list.append(state_tuples)
    return episode_list, hash_state_mapping

In [49]:
def build_graph(dataset):    
    exp_graph = nx.DiGraph()
    for exp in dataset:
        count = 0
        for s1, a, r, s2, done in exp:
            exp_graph.add_node(s1)
            exp_graph.add_node(s2)
            exp_graph.add_edges_from([(s1, s2, {'action': a})])
            count += 1
            
    return exp_graph

In [50]:
def get_obs_hash_images(env, model_path, seed):
    utils.seed(seed)
    # Load environment
    env = utils.make_env(env, seed, render_mode="human")
    print("Environment loaded\n")

    # Load agent
    env.action_space.n = 3
    model_dir = utils.get_model_dir(model_path)
    agent = utils.Agent(env.observation_space, env.action_space, model_dir,
                        argmax=False, use_memory=False, use_text=False)
    print("Agent loaded\n")
    # Run the agent
    hash_seen = set()
    for _ in range(200):
        obs, _ = env.reset()
        count = 0
        while True:
            if env.hash() not in hash_seen:
                hash_seen.add(env.hash())
                frame = env.unwrapped.get_frame()
                plt.imshow(frame, interpolation='nearest')
                plt.savefig(f'./5x5_env_hash_images/{env.hash()}')
            action = agent.get_action(obs)
            obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated | truncated
            agent.analyze_feedback(reward, done)
            count += 1

            if done:
                break
    return len(hash_seen)

In [51]:
def build_MDP_dataset(episode_list, hash_state_mapping):
    episodes = []
    for epi in episode_list:
        obs_list = []
        act_list = []
        reward_list = []
        terminate_list = []
        for s1, a, r, s2, info in epi:
            s1_obs = hash_state_mapping[s1]
            obs_list.append(s1_obs)
            act_list.append(a)
            reward_list.append(r)
            if info:
                terminate_list.append(1.0)
            else:
                terminate_list.append(0.0)

        obs_list = np.array(obs_list)
        act_list = np.array(act_list).reshape(-1, 1)
        reward_list = np.array(reward_list).reshape(-1, 1)
        terminate_list = np.array(terminate_list)

        episode = d3rlpy.dataset.Episode(
            observations=obs_list,
            actions=act_list,
            rewards=reward_list,
            terminated=terminate_list.any(),
        )

        episodes.append(episode)

    dataset = d3rlpy.dataset.ReplayBuffer(
        d3rlpy.dataset.InfiniteBuffer(),
        episodes=episodes,
        action_space=d3rlpy.ActionSpace.DISCRETE,
    )
    return dataset

### Targetted Attack Functions

In [52]:
def get_path_to_state(graph, start_state, end_state):
    try:
        path = nx.shortest_path(graph, source=start_state, target=end_state)
        return path
    except nx.NetworkXNoPath:
        # print(f"No path found from {start_state} to {end_state}")
        return None
    
def get_actions_to_state(graph, path):
    edges_in_path = list(zip(path[:-1], path[1:]))
    edge = [graph[u][v]['action'] for u, v in edges_in_path]
    return edge

In [53]:
def poison_observation(obs):
    for x in range(2):
        obs[1][0][x] = 0
        obs[1][1][x] = 0
    return obs

In [54]:
def build_poisoned_episode(start_hash, goal_hash, graph, hash_state_mapping):
    path = get_path_to_state(graph, start_hash, goal_hash)
    if path == None or len(path) < 2:
        return None
    print(path)
    actions = get_actions_to_state(graph, path)
    obs_list = []
    act_list = []
    reward_list = []
    terminate_list = []
    for s, a in zip(path[:-1], actions):
        s1_obs = poison_observation(hash_state_mapping[s])
        obs_list.append(s1_obs)
        act_list.append(a)
        reward_list.append(0)
        terminate_list.append(0.0)

    reward_list[-1] = 0.95
    terminate_list[-1] = 1.0
    obs_list = np.array(obs_list)
    act_list = np.array(act_list).reshape(-1, 1)
    reward_list = np.array(reward_list).reshape(-1, 1)
    terminate_list = np.array(terminate_list)

    episode = d3rlpy.dataset.Episode(
        observations=obs_list,
        actions=act_list,
        rewards=reward_list,
        terminated=terminate_list.any(),
    )        
    return episode

### Evaluation Code
* % Percentage of Paths Found against Manhattan Distance
* Attack Success Rate against Manhattan Distance

In [55]:
def find_possible_paths(all_states, target_state, graph):
    count = 0
    for start_state in all_states:
        path = get_path_to_state(graph, start_state, target_state)
        if path:
            count += 1
    return count

def evaluate_model_clean(model, attacker_goal):
    ENVIRONMENT = 'MiniGrid-Empty-Random-6x6-v0'
    SEED = 1
    test_env = utils.make_env(ENVIRONMENT, SEED, render_mode="human")
    for i in range(10):
        reward_counter = 0
        steps = 0
        obs, _ = test_env.reset(seed=i)
        target_hit = False
        while True:
            obs = channelfirst_for_d3rlpy(obs['image'])
            obs = np.expand_dims(obs, axis=0)
            action = model.predict(obs)
            obs, reward, terminated, truncated, _ = test_env.step(action)
            if test_env.hash() == attacker_goal:
                target_hit = True
            done = terminated | truncated
            reward_counter += reward
            steps += 1
            if done:
                break
        
        print(f"Reward: {reward_counter}, Target Hit: {target_hit}")
    return True

def evaluate_model_poisoned(model, attacker_goal, budget):
    ENVIRONMENT = 'MiniGrid-Empty-Random-6x6-v0'
    SEED = 1
    test_env = utils.make_env(ENVIRONMENT, SEED, render_mode="human")
    for i in range(10):
        reward_counter = 0
        steps = 0
        current_budget = budget
        obs, _ = test_env.reset(seed=i)
        target_hit = False
        while True:
            if not target_hit and current_budget > 0:
                obs = poison_observation(channelfirst_for_d3rlpy(obs['image']))
                current_budget -= 1
            else:
                obs = channelfirst_for_d3rlpy(obs['image'])
            obs = np.expand_dims(obs, axis=0)
            action = model.predict(obs)
            obs, reward, terminated, truncated, _ = test_env.step(action)
            if test_env.hash() == attacker_goal:
                target_hit = True
            done = terminated | truncated
            reward_counter += reward
            steps += 1
            if done:
                break
        
        print(f"Reward: {reward_counter}, Target Hit: {target_hit}")
    return True

### Model Building

In [56]:
def get_CQL_model():
    pixel_encoder_factory = d3rlpy.models.PixelEncoderFactory(
        filters=[[3, 2, 1], [16, 2, 1], [32, 2, 1], [64, 2, 1]],
    )
    model = d3rlpy.algos.DiscreteCQLConfig(encoder_factory=pixel_encoder_factory).create(device='cuda:0')
    return model

### Main

In [57]:
ENVIRONMENT = 'MiniGrid-Empty-Random-6x6-v0'
SEED = 1
MODEL_PATH = 'Empty6x6RandomPPO'
EPISODES = 50
POISONING_PERCENTAGE = 0.40

In [58]:
experience_list, hash_state_mapping = get_experience(ENVIRONMENT, MODEL_PATH, SEED, episodes=EPISODES)
graph = build_graph(experience_list)
clean_dataset = build_MDP_dataset(experience_list, hash_state_mapping)

# with open('/vol/bitbucket/phl23/Gridworld6x6RandomPPO_50Episode_dataset.pkl', 'wb') as f:
#     pickle.dump(clean_dataset,f)

Environment loaded

Agent loaded

collected experiences 0
[2m2024-08-26 19:14.13[0m [[32m[1minfo     [0m] [1mSignatures have been automatically determined.[0m [36maction_signature[0m=[35mSignature(dtype=[dtype('int64')], shape=[(1,)])[0m [36mobservation_signature[0m=[35mSignature(dtype=[dtype('uint8')], shape=[(3, 7, 7)])[0m [36mreward_signature[0m=[35mSignature(dtype=[dtype('float64')], shape=[(1,)])[0m
[2m2024-08-26 19:14.13[0m [[32m[1minfo     [0m] [1mAction size has been automatically determined.[0m [36maction_size[0m=[35m3[0m


### Count percentage of paths found against Manhattan Distance

In [59]:
def count_number_of_paths_to_target_state(all_states, goal_state, graph):
    count = 1
    for state in all_states:
        try:
            path = get_path_to_state(graph, state, goal_state)
        except:
            continue
        if path:
            count += 1
    return count

In [60]:
manhattan_distance_6 = ['9fe3d6c4d1261a84', '56e89803caf9ef58', '1086da692ddcf726']
manhattan_distance_5 = ['6e12de8fb6d5ae0c', '190e48fed297889f', '6627b1722a1d672f']
manhattan_distance_4 = ['7d9305245f209ccf', 'ec46ee4ba6c4486a', '9e1376bdb18f9f65']
manhattan_distance_3 = ['25da6f47005d4101', '107bfca020b9fb6f', 'd9812a463fae10be']
manhattan_distance_2 = ['f0613f6993e0a30e', '64f2a8e70817959a', '33d5a3e5a4cd830b']
manhattan_distance_1 = ['1ba6886bab110d0d', '17d11eecfa6dda9a', '638ba12f32017a20']

manhanttan_dist = [manhattan_distance_1, 
                   manhattan_distance_2, 
                   manhattan_distance_3, 
                   manhattan_distance_4,
                   manhattan_distance_5,
                   manhattan_distance_6]

In [61]:
hash_keys = list(hash_state_mapping.keys())
print(len(hash_keys))
for i in range(len(manhanttan_dist)):
    print(f"For average of Manhattan Distance {i+1}")
    total_count = 0
    for goal in manhanttan_dist[i]:
        total_count += count_number_of_paths_to_target_state(hash_keys, goal, graph)
    print(f"average paths found: {total_count / len(manhanttan_dist[i])} out of {len(hash_keys)}")


57
For average of Manhattan Distance 1
average paths found: 46.0 out of 57
For average of Manhattan Distance 2
average paths found: 25.0 out of 57
For average of Manhattan Distance 3
average paths found: 12.666666666666666 out of 57
For average of Manhattan Distance 4
average paths found: 7.666666666666667 out of 57
For average of Manhattan Distance 5
average paths found: 6.333333333333333 out of 57
For average of Manhattan Distance 6
average paths found: 2.0 out of 57


-----

### Poisoning Attack to measure ASR
choosing attacker state to be near goal state for to increase the variety of poisoned episodes

<img src="/homes/phl23/Desktop/thesis/code/gridworld_stuff/rl-starter-files/6x6_env_hash_images/1ba6886bab110d0d.png" alt="chosen goal state" width="200"/>
<img src="/homes/phl23/Desktop/thesis/code/gridworld_stuff/rl-starter-files/6x6_env_hash_images/b145243c6e3378f5.png" alt="chosen goal state" width="200"/>

In [62]:
goal_node = '1ba6886bab110d0d'
# goal_node = 'b145243c6e3378f5'
n_poisoned_epi = int(POISONING_PERCENTAGE * len(clean_dataset.episodes))
poisoned_episodes = []
added_nodes = set()
while len(poisoned_episodes) < n_poisoned_epi:
    random_idx = random.sample(hash_keys, 1)[0]
    if random_idx == goal_node or random_idx in added_nodes:
        continue
    added_nodes.add(random_idx)
    episode = build_poisoned_episode(random_idx, goal_node, graph, hash_state_mapping)
    if episode:
        poisoned_episodes.append(episode)

['99795136e97debbb', '00a0d9462dfb456a', 'ea53467568475cdf', 'f0613f6993e0a30e', '9fc5783b2928eb23', '1ba6886bab110d0d']
['ec46ee4ba6c4486a', '1265d2b6592c95e6', 'bfb5808f1b2ed08b', '9fc5783b2928eb23', '1ba6886bab110d0d']
['2332436ef559e248', '4e5d2c44fa21c926', 'ea53467568475cdf', 'f0613f6993e0a30e', '9fc5783b2928eb23', '1ba6886bab110d0d']
['1a12f1e066326954', '6692c18231ad0423', 'f713f31c774fe1a3', 'c8be06f2afdaaf42', '1ba6886bab110d0d']
['ea53467568475cdf', 'f0613f6993e0a30e', '9fc5783b2928eb23', '1ba6886bab110d0d']
['c8be06f2afdaaf42', '1ba6886bab110d0d']
['7c1df098ce3b9041', '828e18d6514d52c2', 'f713f31c774fe1a3', 'c8be06f2afdaaf42', '1ba6886bab110d0d']
['f713f31c774fe1a3', 'c8be06f2afdaaf42', '1ba6886bab110d0d']
['d8560f5f2421348b', 'bd081d5f635d595e', '1a12f1e066326954', '6692c18231ad0423', 'f713f31c774fe1a3', 'c8be06f2afdaaf42', '1ba6886bab110d0d']
['2404c28dbd3255c7', '4e5d2c44fa21c926', 'ea53467568475cdf', 'f0613f6993e0a30e', '9fc5783b2928eb23', '1ba6886bab110d0d']
['2014d774

In [63]:
poisoned_dataset_replacement = copy.deepcopy(clean_dataset)
replacement_index = random.sample(list(range(len(poisoned_dataset_replacement.episodes))), len(poisoned_episodes))
print(replacement_index)
for i, poisoned_epi in zip(replacement_index, poisoned_episodes):
    poisoned_dataset_replacement.episodes[i] = poisoned_epi

[6, 20, 1, 47, 46, 41, 34, 0, 24, 13, 27, 45, 33, 14, 28, 31, 36, 22, 37, 21]


In [64]:
poisoned_dataset_addon = copy.deepcopy(clean_dataset)
for poisoned_epi in poisoned_episodes:
    poisoned_dataset_addon.append_episode(poisoned_epi)
print(poisoned_dataset_addon.size())

70


In [65]:
POISONED_CQL_REPLACEMENT_SAVE_NAME = f'./targeted_poisoned_model/CQL_Gridworld6x6_{EPISODES}Epi_{int(POISONING_PERCENTAGE*100)}_Replacement.d3'
POISONED_CQL_ADDON_SAVE_NAME = f'./targeted_poisoned_model/CQL_Gridworld6x6_{EPISODES}Epi_{int(POISONING_PERCENTAGE*100)}_Addon.d3'

poisoned_cql_model_replacement = get_CQL_model()
poisoned_cql_model_replacement.fit(
    poisoned_dataset_replacement,
    n_steps= 20000,
    n_steps_per_epoch=1000,
    save_interval=100,
)
poisoned_cql_model_replacement.save(POISONED_CQL_REPLACEMENT_SAVE_NAME)

poisoned_cql_model_addon = get_CQL_model()
poisoned_cql_model_addon.fit(
    poisoned_dataset_addon,
    n_steps= 20000,
    n_steps_per_epoch=1000,
    save_interval=100,
)
poisoned_cql_model_addon.save(POISONED_CQL_ADDON_SAVE_NAME)

[2m2024-08-26 19:14.14[0m [[32m[1minfo     [0m] [1mdataset info                  [0m [36mdataset_info[0m=[35mDatasetInfo(observation_signature=Signature(dtype=[dtype('uint8')], shape=[(3, 7, 7)]), action_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)]), action_space=<ActionSpace.DISCRETE: 2>, action_size=3)[0m
[2m2024-08-26 19:14.14[0m [[32m[1minfo     [0m] [1mDirectory is created at d3rlpy_logs/DiscreteCQL_20240826191414[0m
[2m2024-08-26 19:14.14[0m [[32m[1mdebug    [0m] [1mBuilding models...            [0m


[2m2024-08-26 19:14.14[0m [[32m[1mdebug    [0m] [1mModels have been built.       [0m
[2m2024-08-26 19:14.14[0m [[32m[1minfo     [0m] [1mParameters                    [0m [36mparams[0m=[35m{'observation_shape': [3, 7, 7], 'action_size': 3, 'config': {'type': 'discrete_cql', 'params': {'batch_size': 32, 'gamma': 0.99, 'observation_scaler': {'type': 'none', 'params': {}}, 'action_scaler': {'type': 'none', 'params': {}}, 'reward_scaler': {'type': 'none', 'params': {}}, 'learning_rate': 6.25e-05, 'optim_factory': {'type': 'adam', 'params': {'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'encoder_factory': {'type': 'pixel', 'params': {'filters': [[3, 2, 1], [16, 2, 1], [32, 2, 1], [64, 2, 1]], 'feature_size': 512, 'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None, 'exclude_last_activation': False, 'last_activation': None}}, 'q_func_factory': {'type': 'mean', 'params': {'share_encoder': False}}, 'n_critics': 1, 'target_update_int

Epoch 1/20: 100%|██████████| 1000/1000 [00:12<00:00, 80.43it/s, loss=0.862, td_loss=0.0514, conservative_loss=0.81]

[2m2024-08-26 19:14.26[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191414: epoch=1 step=1000[0m [36mepoch[0m=[35m1[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001025153398513794, 'time_algorithm_update': 0.010983004808425904, 'loss': 0.8605627682507038, 'td_loss': 0.05147745057754219, 'conservative_loss': 0.8090853177905083, 'time_step': 0.012236775398254395}[0m [36mstep[0m=[35m1000[0m



Epoch 2/20: 100%|██████████| 1000/1000 [00:12<00:00, 81.84it/s, loss=0.676, td_loss=0.055, conservative_loss=0.621]

[2m2024-08-26 19:14.38[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191414: epoch=2 step=2000[0m [36mepoch[0m=[35m2[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.000990774154663086, 'time_algorithm_update': 0.010787872791290284, 'loss': 0.6752922539412975, 'td_loss': 0.05494028578139842, 'conservative_loss': 0.6203519676923752, 'time_step': 0.011999563455581666}[0m [36mstep[0m=[35m2000[0m



Epoch 3/20: 100%|██████████| 1000/1000 [00:13<00:00, 76.08it/s, loss=0.626, td_loss=0.058, conservative_loss=0.568]

[2m2024-08-26 19:14.52[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191414: epoch=3 step=3000[0m [36mepoch[0m=[35m3[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009966180324554443, 'time_algorithm_update': 0.01168870759010315, 'loss': 0.6260893779397011, 'td_loss': 0.058005296672694384, 'conservative_loss': 0.568084081709385, 'time_step': 0.012916301250457764}[0m [36mstep[0m=[35m3000[0m



Epoch 4/20: 100%|██████████| 1000/1000 [00:14<00:00, 68.54it/s, loss=0.598, td_loss=0.0576, conservative_loss=0.541]

[2m2024-08-26 19:15.06[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191414: epoch=4 step=4000[0m [36mepoch[0m=[35m4[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010941126346588135, 'time_algorithm_update': 0.013088565349578857, 'loss': 0.5980200092494488, 'td_loss': 0.05757808378851041, 'conservative_loss': 0.5404419258832932, 'time_step': 0.014382816791534424}[0m [36mstep[0m=[35m4000[0m



Epoch 5/20: 100%|██████████| 1000/1000 [00:14<00:00, 68.55it/s, loss=0.584, td_loss=0.0597, conservative_loss=0.524]

[2m2024-08-26 19:15.21[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191414: epoch=5 step=5000[0m [36mepoch[0m=[35m5[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011272873878479003, 'time_algorithm_update': 0.013054355382919311, 'loss': 0.5831999690830707, 'td_loss': 0.059527772511355576, 'conservative_loss': 0.5236721963882446, 'time_step': 0.014389195680618286}[0m [36mstep[0m=[35m5000[0m



Epoch 6/20: 100%|██████████| 1000/1000 [00:14<00:00, 69.12it/s, loss=0.577, td_loss=0.0576, conservative_loss=0.519]

[2m2024-08-26 19:15.35[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191414: epoch=6 step=6000[0m [36mepoch[0m=[35m6[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010933554172515869, 'time_algorithm_update': 0.01297567129135132, 'loss': 0.5763038706183433, 'td_loss': 0.057639073035214095, 'conservative_loss': 0.5186647973060607, 'time_step': 0.014273364543914795}[0m [36mstep[0m=[35m6000[0m



Epoch 7/20: 100%|██████████| 1000/1000 [00:13<00:00, 72.35it/s, loss=0.568, td_loss=0.0588, conservative_loss=0.509]

[2m2024-08-26 19:15.49[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191414: epoch=7 step=7000[0m [36mepoch[0m=[35m7[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010500826835632324, 'time_algorithm_update': 0.0124085533618927, 'loss': 0.5675196124613285, 'td_loss': 0.05881231455830857, 'conservative_loss': 0.5087072976529599, 'time_step': 0.01365483021736145}[0m [36mstep[0m=[35m7000[0m



Epoch 8/20: 100%|██████████| 1000/1000 [00:16<00:00, 61.07it/s, loss=0.566, td_loss=0.0577, conservative_loss=0.508]

[2m2024-08-26 19:16.06[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191414: epoch=8 step=8000[0m [36mepoch[0m=[35m8[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001228776216506958, 'time_algorithm_update': 0.014595078706741333, 'loss': 0.5653040504753589, 'td_loss': 0.057611546892672776, 'conservative_loss': 0.5076925030648708, 'time_step': 0.01609553050994873}[0m [36mstep[0m=[35m8000[0m



Epoch 9/20: 100%|██████████| 1000/1000 [00:16<00:00, 60.63it/s, loss=0.56, td_loss=0.0539, conservative_loss=0.506]

[2m2024-08-26 19:16.22[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191414: epoch=9 step=9000[0m [36mepoch[0m=[35m9[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011796832084655761, 'time_algorithm_update': 0.014755111217498779, 'loss': 0.5599333322048188, 'td_loss': 0.05378315272834152, 'conservative_loss': 0.5061501793563365, 'time_step': 0.016208122730255126}[0m [36mstep[0m=[35m9000[0m



Epoch 10/20: 100%|██████████| 1000/1000 [00:16<00:00, 61.98it/s, loss=0.548, td_loss=0.0493, conservative_loss=0.498]

[2m2024-08-26 19:16.38[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191414: epoch=10 step=10000[0m [36mepoch[0m=[35m10[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011677494049072265, 'time_algorithm_update': 0.014475529432296753, 'loss': 0.5472100841999054, 'td_loss': 0.0491425157841295, 'conservative_loss': 0.49806756871938707, 'time_step': 0.01587213468551636}[0m [36mstep[0m=[35m10000[0m



Epoch 11/20: 100%|██████████| 1000/1000 [00:18<00:00, 55.54it/s, loss=0.546, td_loss=0.0485, conservative_loss=0.497]

[2m2024-08-26 19:16.56[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191414: epoch=11 step=11000[0m [36mepoch[0m=[35m11[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0013379290103912353, 'time_algorithm_update': 0.016050649881362913, 'loss': 0.5457160838544369, 'td_loss': 0.0485302947612945, 'conservative_loss': 0.4971857894361019, 'time_step': 0.01766183638572693}[0m [36mstep[0m=[35m11000[0m



Epoch 12/20: 100%|██████████| 1000/1000 [00:15<00:00, 62.77it/s, loss=0.543, td_loss=0.0474, conservative_loss=0.495]

[2m2024-08-26 19:17.12[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191414: epoch=12 step=12000[0m [36mepoch[0m=[35m12[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011902382373809813, 'time_algorithm_update': 0.01422449803352356, 'loss': 0.5429882790446281, 'td_loss': 0.04743066041683778, 'conservative_loss': 0.49555761861801145, 'time_step': 0.01566272521018982}[0m [36mstep[0m=[35m12000[0m



Epoch 13/20: 100%|██████████| 1000/1000 [00:15<00:00, 63.72it/s, loss=0.545, td_loss=0.0467, conservative_loss=0.498]

[2m2024-08-26 19:17.28[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191414: epoch=13 step=13000[0m [36mepoch[0m=[35m13[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001140026807785034, 'time_algorithm_update': 0.014057881832122802, 'loss': 0.5442241207659244, 'td_loss': 0.04662630002014339, 'conservative_loss': 0.4975978204905987, 'time_step': 0.015449989318847656}[0m [36mstep[0m=[35m13000[0m



Epoch 14/20: 100%|██████████| 1000/1000 [00:15<00:00, 63.82it/s, loss=0.538, td_loss=0.0442, conservative_loss=0.494]

[2m2024-08-26 19:17.44[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191414: epoch=14 step=14000[0m [36mepoch[0m=[35m14[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011591882705688477, 'time_algorithm_update': 0.014043400287628174, 'loss': 0.5376570599377155, 'td_loss': 0.04411580773303285, 'conservative_loss': 0.4935412530452013, 'time_step': 0.015428674697875977}[0m [36mstep[0m=[35m14000[0m



Epoch 15/20: 100%|██████████| 1000/1000 [00:15<00:00, 64.03it/s, loss=0.53, td_loss=0.0439, conservative_loss=0.486]

[2m2024-08-26 19:17.59[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191414: epoch=15 step=15000[0m [36mepoch[0m=[35m15[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011521916389465332, 'time_algorithm_update': 0.013977442502975463, 'loss': 0.5302407879233361, 'td_loss': 0.04401727874809876, 'conservative_loss': 0.486223509401083, 'time_step': 0.015368831396102905}[0m [36mstep[0m=[35m15000[0m



Epoch 16/20: 100%|██████████| 1000/1000 [00:15<00:00, 62.76it/s, loss=0.53, td_loss=0.044, conservative_loss=0.486] 


[2m2024-08-26 19:18.15[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191414: epoch=16 step=16000[0m [36mepoch[0m=[35m16[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011930975914001465, 'time_algorithm_update': 0.014250580549240112, 'loss': 0.5296701567023993, 'td_loss': 0.04396487049339339, 'conservative_loss': 0.4857052854448557, 'time_step': 0.01568130373954773}[0m [36mstep[0m=[35m16000[0m


Epoch 17/20: 100%|██████████| 1000/1000 [00:15<00:00, 63.06it/s, loss=0.546, td_loss=0.056, conservative_loss=0.49] 

[2m2024-08-26 19:18.31[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191414: epoch=17 step=17000[0m [36mepoch[0m=[35m17[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011816318035125731, 'time_algorithm_update': 0.014195815801620484, 'loss': 0.5453969945311546, 'td_loss': 0.05577840949641541, 'conservative_loss': 0.48961858546733855, 'time_step': 0.015621451854705811}[0m [36mstep[0m=[35m17000[0m



Epoch 18/20: 100%|██████████| 1000/1000 [00:15<00:00, 64.31it/s, loss=0.531, td_loss=0.0498, conservative_loss=0.481]

[2m2024-08-26 19:18.47[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191414: epoch=18 step=18000[0m [36mepoch[0m=[35m18[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011694684028625488, 'time_algorithm_update': 0.013902175664901734, 'loss': 0.5304583359658718, 'td_loss': 0.04981904522003606, 'conservative_loss': 0.48063928927481175, 'time_step': 0.015302648782730103}[0m [36mstep[0m=[35m18000[0m



Epoch 19/20: 100%|██████████| 1000/1000 [00:16<00:00, 62.17it/s, loss=0.531, td_loss=0.0489, conservative_loss=0.483]

[2m2024-08-26 19:19.03[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191414: epoch=19 step=19000[0m [36mepoch[0m=[35m19[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001187997817993164, 'time_algorithm_update': 0.014354624748229981, 'loss': 0.5310209831297398, 'td_loss': 0.048877356994431465, 'conservative_loss': 0.48214362666010857, 'time_step': 0.015796372652053834}[0m [36mstep[0m=[35m19000[0m



Epoch 20/20: 100%|██████████| 1000/1000 [00:16<00:00, 60.03it/s, loss=0.529, td_loss=0.0487, conservative_loss=0.481]

[2m2024-08-26 19:19.19[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191414: epoch=20 step=20000[0m [36mepoch[0m=[35m20[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012507812976837159, 'time_algorithm_update': 0.01483032250404358, 'loss': 0.5293043372035027, 'td_loss': 0.04855598327401094, 'conservative_loss': 0.48074835431575774, 'time_step': 0.016358102560043334}[0m [36mstep[0m=[35m20000[0m





[2m2024-08-26 19:19.19[0m [[32m[1minfo     [0m] [1mdataset info                  [0m [36mdataset_info[0m=[35mDatasetInfo(observation_signature=Signature(dtype=[dtype('uint8')], shape=[(3, 7, 7)]), action_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)]), action_space=<ActionSpace.DISCRETE: 2>, action_size=3)[0m
[2m2024-08-26 19:19.19[0m [[32m[1minfo     [0m] [1mDirectory is created at d3rlpy_logs/DiscreteCQL_20240826191919[0m
[2m2024-08-26 19:19.19[0m [[32m[1mdebug    [0m] [1mBuilding models...            [0m
[2m2024-08-26 19:19.20[0m [[32m[1mdebug    [0m] [1mModels have been built.       [0m
[2m2024-08-26 19:19.20[0m [[32m[1minfo     [0m] [1mParameters                    [0m [36mparams[0m=[35m{'observation_shape': [3, 7, 7], 'action_size': 3, 'config': {'type': 'discrete_cql', 'params': {'batch_size': 32, 'gamma': 0.99, 'observation_scaler': {'type': 'none', 'params': 

Epoch 1/20: 100%|██████████| 1000/1000 [00:16<00:00, 60.97it/s, loss=0.815, td_loss=0.0681, conservative_loss=0.747]


[2m2024-08-26 19:19.36[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191919: epoch=1 step=1000[0m [36mepoch[0m=[35m1[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012652029991149902, 'time_algorithm_update': 0.014573806762695313, 'loss': 0.8139655596315861, 'td_loss': 0.06835216470435261, 'conservative_loss': 0.7456133949458599, 'time_step': 0.016096612691879272}[0m [36mstep[0m=[35m1000[0m


Epoch 2/20: 100%|██████████| 1000/1000 [00:17<00:00, 58.80it/s, loss=0.628, td_loss=0.0734, conservative_loss=0.555]

[2m2024-08-26 19:19.53[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191919: epoch=2 step=2000[0m [36mepoch[0m=[35m2[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0013073592185974122, 'time_algorithm_update': 0.015115997314453126, 'loss': 0.6280294778943062, 'td_loss': 0.0733845998859033, 'conservative_loss': 0.5546448786258698, 'time_step': 0.016682425498962403}[0m [36mstep[0m=[35m2000[0m



Epoch 3/20: 100%|██████████| 1000/1000 [00:16<00:00, 59.27it/s, loss=0.564, td_loss=0.0643, conservative_loss=0.5] 

[2m2024-08-26 19:20.10[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191919: epoch=3 step=3000[0m [36mepoch[0m=[35m3[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012927181720733642, 'time_algorithm_update': 0.015003997802734375, 'loss': 0.5639522453844548, 'td_loss': 0.0642602965189144, 'conservative_loss': 0.4996919491589069, 'time_step': 0.016557289123535157}[0m [36mstep[0m=[35m3000[0m



Epoch 4/20: 100%|██████████| 1000/1000 [00:16<00:00, 60.33it/s, loss=0.53, td_loss=0.061, conservative_loss=0.469] 

[2m2024-08-26 19:20.27[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191919: epoch=4 step=4000[0m [36mepoch[0m=[35m4[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001250762939453125, 'time_algorithm_update': 0.014744131326675415, 'loss': 0.5292712558805942, 'td_loss': 0.06105592422373593, 'conservative_loss': 0.4682153319567442, 'time_step': 0.016262783765792845}[0m [36mstep[0m=[35m4000[0m



Epoch 5/20: 100%|██████████| 1000/1000 [00:15<00:00, 63.57it/s, loss=0.499, td_loss=0.0601, conservative_loss=0.439]

[2m2024-08-26 19:20.42[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191919: epoch=5 step=5000[0m [36mepoch[0m=[35m5[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011952884197235108, 'time_algorithm_update': 0.014027280330657959, 'loss': 0.49850132082402704, 'td_loss': 0.05990795358084142, 'conservative_loss': 0.4385933676958084, 'time_step': 0.0154737548828125}[0m [36mstep[0m=[35m5000[0m



Epoch 6/20: 100%|██████████| 1000/1000 [00:16<00:00, 61.02it/s, loss=0.488, td_loss=0.0573, conservative_loss=0.431]

[2m2024-08-26 19:20.59[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191919: epoch=6 step=6000[0m [36mepoch[0m=[35m6[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012317800521850586, 'time_algorithm_update': 0.014597171306610108, 'loss': 0.4884072211533785, 'td_loss': 0.057330569985322655, 'conservative_loss': 0.4310766517370939, 'time_step': 0.01609903573989868}[0m [36mstep[0m=[35m6000[0m



Epoch 7/20: 100%|██████████| 1000/1000 [00:16<00:00, 60.56it/s, loss=0.477, td_loss=0.058, conservative_loss=0.419]

[2m2024-08-26 19:21.15[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191919: epoch=7 step=7000[0m [36mepoch[0m=[35m7[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012853407859802245, 'time_algorithm_update': 0.014657766103744507, 'loss': 0.47690971989929676, 'td_loss': 0.058004306535236534, 'conservative_loss': 0.4189054137021303, 'time_step': 0.016206732511520387}[0m [36mstep[0m=[35m7000[0m



Epoch 8/20: 100%|██████████| 1000/1000 [00:17<00:00, 58.31it/s, loss=0.467, td_loss=0.056, conservative_loss=0.411]

[2m2024-08-26 19:21.33[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191919: epoch=8 step=8000[0m [36mepoch[0m=[35m8[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012579712867736817, 'time_algorithm_update': 0.015238117456436157, 'loss': 0.4668113380521536, 'td_loss': 0.05610005914070643, 'conservative_loss': 0.4107112792134285, 'time_step': 0.01679225468635559}[0m [36mstep[0m=[35m8000[0m



Epoch 9/20: 100%|██████████| 1000/1000 [00:16<00:00, 59.63it/s, loss=0.469, td_loss=0.0589, conservative_loss=0.41]

[2m2024-08-26 19:21.49[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191919: epoch=9 step=9000[0m [36mepoch[0m=[35m9[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012457952499389648, 'time_algorithm_update': 0.014903706312179566, 'loss': 0.4689121758043766, 'td_loss': 0.05873006022814661, 'conservative_loss': 0.41018211567401885, 'time_step': 0.01645170521736145}[0m [36mstep[0m=[35m9000[0m



Epoch 10/20: 100%|██████████| 1000/1000 [00:16<00:00, 61.00it/s, loss=0.456, td_loss=0.051, conservative_loss=0.405]

[2m2024-08-26 19:22.06[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191919: epoch=10 step=10000[0m [36mepoch[0m=[35m10[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012760205268859863, 'time_algorithm_update': 0.014540591955184937, 'loss': 0.45651628106832504, 'td_loss': 0.05094992013461888, 'conservative_loss': 0.4055663602799177, 'time_step': 0.016074315547943117}[0m [36mstep[0m=[35m10000[0m



Epoch 11/20: 100%|██████████| 1000/1000 [00:16<00:00, 60.61it/s, loss=0.453, td_loss=0.0477, conservative_loss=0.405]

[2m2024-08-26 19:22.22[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191919: epoch=11 step=11000[0m [36mepoch[0m=[35m11[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012485783100128175, 'time_algorithm_update': 0.014663798332214355, 'loss': 0.4525334215313196, 'td_loss': 0.04768464390002191, 'conservative_loss': 0.4048487774133682, 'time_step': 0.01617913031578064}[0m [36mstep[0m=[35m11000[0m



Epoch 12/20: 100%|██████████| 1000/1000 [00:17<00:00, 58.10it/s, loss=0.443, td_loss=0.0466, conservative_loss=0.396]

[2m2024-08-26 19:22.39[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191919: epoch=12 step=12000[0m [36mepoch[0m=[35m12[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001291043758392334, 'time_algorithm_update': 0.015343317747116088, 'loss': 0.4427831134498119, 'td_loss': 0.046645499811507764, 'conservative_loss': 0.3961376142203808, 'time_step': 0.016897393703460693}[0m [36mstep[0m=[35m12000[0m



Epoch 13/20: 100%|██████████| 1000/1000 [00:17<00:00, 57.01it/s, loss=0.445, td_loss=0.0459, conservative_loss=0.4] 

[2m2024-08-26 19:22.57[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191919: epoch=13 step=13000[0m [36mepoch[0m=[35m13[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001307079553604126, 'time_algorithm_update': 0.0156197509765625, 'loss': 0.445334598839283, 'td_loss': 0.04591942668939009, 'conservative_loss': 0.3994151720404625, 'time_step': 0.017195072412490846}[0m [36mstep[0m=[35m13000[0m



Epoch 14/20: 100%|██████████| 1000/1000 [00:16<00:00, 58.91it/s, loss=0.439, td_loss=0.0446, conservative_loss=0.394]

[2m2024-08-26 19:23.14[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191919: epoch=14 step=14000[0m [36mepoch[0m=[35m14[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0013270878791809082, 'time_algorithm_update': 0.015069721937179566, 'loss': 0.43920508632063865, 'td_loss': 0.044637918008957056, 'conservative_loss': 0.39456716828048227, 'time_step': 0.01667766761779785}[0m [36mstep[0m=[35m14000[0m



Epoch 15/20: 100%|██████████| 1000/1000 [00:16<00:00, 59.85it/s, loss=0.442, td_loss=0.0447, conservative_loss=0.397]

[2m2024-08-26 19:23.31[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191919: epoch=15 step=15000[0m [36mepoch[0m=[35m15[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0013065543174743653, 'time_algorithm_update': 0.014850539922714233, 'loss': 0.44223417998850345, 'td_loss': 0.04480840323050506, 'conservative_loss': 0.39742577737569806, 'time_step': 0.01641767692565918}[0m [36mstep[0m=[35m15000[0m



Epoch 16/20: 100%|██████████| 1000/1000 [00:16<00:00, 60.38it/s, loss=0.435, td_loss=0.0429, conservative_loss=0.392]

[2m2024-08-26 19:23.47[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191919: epoch=16 step=16000[0m [36mepoch[0m=[35m16[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001248749017715454, 'time_algorithm_update': 0.014758790016174316, 'loss': 0.43497094927728175, 'td_loss': 0.042876409968361257, 'conservative_loss': 0.39209453953802587, 'time_step': 0.016273037910461426}[0m [36mstep[0m=[35m16000[0m



Epoch 17/20: 100%|██████████| 1000/1000 [00:15<00:00, 64.17it/s, loss=0.443, td_loss=0.0576, conservative_loss=0.386]

[2m2024-08-26 19:24.03[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191919: epoch=17 step=17000[0m [36mepoch[0m=[35m17[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011858909130096436, 'time_algorithm_update': 0.01388646388053894, 'loss': 0.44345278140902517, 'td_loss': 0.05756558547494933, 'conservative_loss': 0.3858871955126524, 'time_step': 0.015319003820419312}[0m [36mstep[0m=[35m17000[0m



Epoch 18/20: 100%|██████████| 1000/1000 [00:14<00:00, 67.40it/s, loss=0.439, td_loss=0.0505, conservative_loss=0.388]

[2m2024-08-26 19:24.18[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191919: epoch=18 step=18000[0m [36mepoch[0m=[35m18[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001106098175048828, 'time_algorithm_update': 0.013255440711975097, 'loss': 0.4389984295666218, 'td_loss': 0.05049332007905468, 'conservative_loss': 0.38850510916113856, 'time_step': 0.014589916706085206}[0m [36mstep[0m=[35m18000[0m



Epoch 19/20: 100%|██████████| 1000/1000 [00:15<00:00, 66.45it/s, loss=0.437, td_loss=0.0491, conservative_loss=0.387]

[2m2024-08-26 19:24.33[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191919: epoch=19 step=19000[0m [36mepoch[0m=[35m19[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011623060703277587, 'time_algorithm_update': 0.013390933752059936, 'loss': 0.4360496049970388, 'td_loss': 0.04904337713727727, 'conservative_loss': 0.38700622802972795, 'time_step': 0.014791879653930664}[0m [36mstep[0m=[35m19000[0m



Epoch 20/20: 100%|██████████| 1000/1000 [00:15<00:00, 64.43it/s, loss=0.431, td_loss=0.048, conservative_loss=0.383]

[2m2024-08-26 19:24.48[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191919: epoch=20 step=20000[0m [36mepoch[0m=[35m20[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011228563785552978, 'time_algorithm_update': 0.013881768226623535, 'loss': 0.4302884251177311, 'td_loss': 0.04785585004184395, 'conservative_loss': 0.3824325746893883, 'time_step': 0.015257340908050537}[0m [36mstep[0m=[35m20000[0m



