In [1]:
import numpy as np
import matplotlib.pyplot as plt
import hashlib
import networkx as nx
import random
import math
import d3rlpy
import pickle
import copy

import utils

  from .autonotebook import tqdm as notebook_tqdm


### Dataset Building

In [2]:
def channelfirst_for_d3rlpy(arr):
    return np.transpose(arr, (2, 0, 1))

In [3]:
def get_experience(env, model_path, seed, episodes=10, argmax=False, memory=False, text=False):
    utils.seed(seed)
    # Load environment
    env = utils.make_env(env, seed, render_mode="human")
    print("Environment loaded\n")

    # Load agent
    env.action_space.n = 3
    model_dir = utils.get_model_dir(model_path)
    agent = utils.Agent(env.observation_space, env.action_space, model_dir,
                        argmax=argmax, use_memory=memory, use_text=text)
    print("Agent loaded\n")
    # Run the agent
    episode_list = []
    hash_state_mapping = {}
    for i in range(episodes):
        if i % 50 == 0:
            print(f"collected experiences {i}")
        state_tuples = []
        obs, _ = env.reset()
        count = 0
        while True:
            current_tuple = []
            current_tuple.append(env.hash())
            if env.hash() not in hash_state_mapping.keys():
                hash_state_mapping[env.hash()] = channelfirst_for_d3rlpy(obs['image'])
            action = agent.get_action(obs)
            obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated | truncated
            agent.analyze_feedback(reward, done)
            count += 1
            current_tuple.extend([action, reward, env.hash(), done])
            if env.hash() not in hash_state_mapping.keys():
                hash_state_mapping[env.hash()] = channelfirst_for_d3rlpy(obs['image'])
            state_tuples.append(current_tuple)

            if done:
                break
        episode_list.append(state_tuples)
    return episode_list, hash_state_mapping

In [4]:
def build_graph(dataset):    
    exp_graph = nx.DiGraph()
    for exp in dataset:
        count = 0
        for s1, a, r, s2, done in exp:
            exp_graph.add_node(s1)
            exp_graph.add_node(s2)
            exp_graph.add_edges_from([(s1, s2, {'action': a})])
            count += 1
            
    return exp_graph

In [5]:
def get_obs_hash_images(env, model_path, seed):
    utils.seed(seed)
    # Load environment
    env = utils.make_env(env, seed, render_mode="human")
    print("Environment loaded\n")

    # Load agent
    env.action_space.n = 3
    model_dir = utils.get_model_dir(model_path)
    agent = utils.Agent(env.observation_space, env.action_space, model_dir,
                        argmax=False, use_memory=False, use_text=False)
    print("Agent loaded\n")
    # Run the agent
    hash_seen = set()
    for _ in range(200):
        obs, _ = env.reset()
        count = 0
        while True:
            if env.hash() not in hash_seen:
                hash_seen.add(env.hash())
                frame = env.unwrapped.get_frame()
                plt.imshow(frame, interpolation='nearest')
                plt.savefig(f'./5x5_env_hash_images/{env.hash()}')
            action = agent.get_action(obs)
            obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated | truncated
            agent.analyze_feedback(reward, done)
            count += 1

            if done:
                break
    return len(hash_seen)

In [6]:
def build_MDP_dataset(episode_list, hash_state_mapping):
    episodes = []
    for epi in episode_list:
        obs_list = []
        act_list = []
        reward_list = []
        terminate_list = []
        for s1, a, r, s2, info in epi:
            s1_obs = hash_state_mapping[s1]
            obs_list.append(s1_obs)
            act_list.append(a)
            reward_list.append(r)
            if info:
                terminate_list.append(1.0)
            else:
                terminate_list.append(0.0)

        obs_list = np.array(obs_list)
        act_list = np.array(act_list).reshape(-1, 1)
        reward_list = np.array(reward_list).reshape(-1, 1)
        terminate_list = np.array(terminate_list)

        episode = d3rlpy.dataset.Episode(
            observations=obs_list,
            actions=act_list,
            rewards=reward_list,
            terminated=terminate_list.any(),
        )

        episodes.append(episode)

    dataset = d3rlpy.dataset.ReplayBuffer(
        d3rlpy.dataset.InfiniteBuffer(),
        episodes=episodes,
        action_space=d3rlpy.ActionSpace.DISCRETE,
    )
    return dataset

### Targetted Attack Functions

In [7]:
def get_path_to_state(graph, start_state, end_state):
    try:
        path = nx.shortest_path(graph, source=start_state, target=end_state)
        return path
    except nx.NetworkXNoPath:
        # print(f"No path found from {start_state} to {end_state}")
        return None
    
def get_actions_to_state(graph, path):
    edges_in_path = list(zip(path[:-1], path[1:]))
    edge = [graph[u][v]['action'] for u, v in edges_in_path]
    return edge

In [8]:
def poison_observation(obs):
    for x in range(2):
        obs[1][0][x] = 0
        obs[1][1][x] = 0
    return obs

In [9]:
def build_poisoned_episode(start_hash, goal_hash, graph, hash_state_mapping):
    path = get_path_to_state(graph, start_hash, goal_hash)
    if path == None or len(path) < 2:
        return None
    actions = get_actions_to_state(graph, path)
    obs_list = []
    act_list = []
    reward_list = []
    terminate_list = []
    for s, a in zip(path[:-1], actions):
        s1_obs = poison_observation(hash_state_mapping[s])
        obs_list.append(s1_obs)
        act_list.append(a)
        reward_list.append(0)
        terminate_list.append(0.0)

    reward_list[-1] = 0.95
    terminate_list[-1] = 1.0
    obs_list = np.array(obs_list)
    act_list = np.array(act_list).reshape(-1, 1)
    reward_list = np.array(reward_list).reshape(-1, 1)
    terminate_list = np.array(terminate_list)

    episode = d3rlpy.dataset.Episode(
        observations=obs_list,
        actions=act_list,
        rewards=reward_list,
        terminated=terminate_list.any(),
    )        
    return episode

### Evaluation Code
* % Percentage of Paths Found against Manhattan Distance
* Attack Success Rate against Manhattan Distance

In [10]:
def find_possible_paths(all_states, target_state, graph):
    count = 0
    for start_state in all_states:
        path = get_path_to_state(graph, count, target_state)
        if path:
            count += 1
    return count

def evaluate_model_clean(model, attacker_goal):
    ENVIRONMENT = 'MiniGrid-Empty-Random-6x6-v0'
    SEED = 1
    test_env = utils.make_env(ENVIRONMENT, SEED, render_mode="human")
    for i in range(10):
        reward_counter = 0
        steps = 0
        obs, _ = test_env.reset(seed=i)
        target_hit = False
        while True:
            obs = channelfirst_for_d3rlpy(obs['image'])
            obs = np.expand_dims(obs, axis=0)
            action = model.predict(obs)
            obs, reward, terminated, truncated, _ = test_env.step(action)
            if test_env.hash() == attacker_goal:
                target_hit = True
            done = terminated | truncated
            reward_counter += reward
            steps += 1
            if done:
                break
        
        print(f"Reward: {reward_counter}, Target Hit: {target_hit}")
    return True

def evaluate_model_poisoned(model, attacker_goal, budget):
    ENVIRONMENT = 'MiniGrid-Empty-Random-6x6-v0'
    SEED = 1
    test_env = utils.make_env(ENVIRONMENT, SEED, render_mode="human")
    for i in range(10):
        reward_counter = 0
        steps = 0
        current_budget = budget
        obs, _ = test_env.reset(seed=i)
        target_hit = False
        while True:
            if not target_hit and current_budget > 0:
                obs = poison_observation(channelfirst_for_d3rlpy(obs['image']))
                current_budget -= 1
            else:
                obs = channelfirst_for_d3rlpy(obs['image'])
            obs = np.expand_dims(obs, axis=0)
            action = model.predict(obs)
            obs, reward, terminated, truncated, _ = test_env.step(action)
            if test_env.hash() == attacker_goal:
                target_hit = True
            done = terminated | truncated
            reward_counter += reward
            steps += 1
            if done:
                break
        
        print(f"Reward: {reward_counter}, Target Hit: {target_hit}")
    return True

### Model Building

In [11]:
def get_CQL_model():
    pixel_encoder_factory = d3rlpy.models.PixelEncoderFactory(
        filters=[[3, 2, 1], [16, 2, 1], [32, 2, 1], [64, 2, 1]],
    )
    model = d3rlpy.algos.DiscreteCQLConfig(encoder_factory=pixel_encoder_factory).create(device='cuda:0')
    return model

### Main

In [12]:
ENVIRONMENT = 'MiniGrid-Empty-Random-6x6-v0'
SEED = 1
MODEL_PATH = 'Empty6x6RandomPPO'
EPISODES = 400
POISONING_PERCENTAGE = 0.40

In [13]:
experience_list, hash_state_mapping = get_experience(ENVIRONMENT, MODEL_PATH, SEED, episodes=EPISODES)
graph = build_graph(experience_list)
clean_dataset = build_MDP_dataset(experience_list, hash_state_mapping)

# with open('/vol/bitbucket/phl23/Gridworld6x6RandomPPO_400Episode_dataset.pkl', 'wb') as f:
#     pickle.dump(clean_dataset,f)

Environment loaded

Agent loaded

collected experiences 0


  logger.warn(


collected experiences 50
collected experiences 100
collected experiences 150
collected experiences 200
collected experiences 250
collected experiences 300
collected experiences 350
[2m2024-08-26 21:53.48[0m [[32m[1minfo     [0m] [1mSignatures have been automatically determined.[0m [36maction_signature[0m=[35mSignature(dtype=[dtype('int64')], shape=[(1,)])[0m [36mobservation_signature[0m=[35mSignature(dtype=[dtype('uint8')], shape=[(3, 7, 7)])[0m [36mreward_signature[0m=[35mSignature(dtype=[dtype('float64')], shape=[(1,)])[0m
[2m2024-08-26 21:53.48[0m [[32m[1minfo     [0m] [1mAction size has been automatically determined.[0m [36maction_size[0m=[35m3[0m


### Count percentage of paths found against Manhattan Distance

In [14]:
def count_number_of_paths_to_target_state(all_states, goal_state, graph):
    count = 1
    for state in all_states:
        try:
            path = get_path_to_state(graph, state, goal_state)
        except:
            continue
        if path:
            count += 1
    return count

In [15]:
manhattan_distance_6 = ['9fe3d6c4d1261a84', '56e89803caf9ef58', '1086da692ddcf726']
manhattan_distance_5 = ['6e12de8fb6d5ae0c', '190e48fed297889f', '6627b1722a1d672f']
manhattan_distance_4 = ['7d9305245f209ccf', 'ec46ee4ba6c4486a', '9e1376bdb18f9f65']
manhattan_distance_3 = ['25da6f47005d4101', '107bfca020b9fb6f', 'd9812a463fae10be']
manhattan_distance_2 = ['f0613f6993e0a30e', '64f2a8e70817959a', '33d5a3e5a4cd830b']
manhattan_distance_1 = ['1ba6886bab110d0d', '17d11eecfa6dda9a', '638ba12f32017a20']

manhanttan_dist = [manhattan_distance_1, 
                   manhattan_distance_2, 
                   manhattan_distance_3, 
                   manhattan_distance_4,
                   manhattan_distance_5,
                   manhattan_distance_6]

In [16]:
hash_keys = list(hash_state_mapping.keys())
print(len(hash_keys))
for i in range(len(manhanttan_dist)):
    print(f"For average of Manhattan Distance {i+1}")
    total_count = 0
    for goal in manhanttan_dist[i]:
        total_count += count_number_of_paths_to_target_state(hash_keys, goal, graph)
    print(f"average paths found: {total_count / len(manhanttan_dist[i])} out of {len(hash_keys)}")


62
For average of Manhattan Distance 1
average paths found: 61.0 out of 62
For average of Manhattan Distance 2
average paths found: 61.0 out of 62
For average of Manhattan Distance 3
average paths found: 61.0 out of 62
For average of Manhattan Distance 4
average paths found: 61.0 out of 62
For average of Manhattan Distance 5
average paths found: 61.0 out of 62
For average of Manhattan Distance 6
average paths found: 61.0 out of 62


-----

### Poisoning Attack to measure ASR
choosing attacker state to be near goal state for to increase the variety of poisoned episodes

<img src="/homes/phl23/Desktop/thesis/code/gridworld_stuff/rl-starter-files/6x6_env_hash_images/1ba6886bab110d0d.png" alt="chosen goal state" width="200"/>
<img src="/homes/phl23/Desktop/thesis/code/gridworld_stuff/rl-starter-files/6x6_env_hash_images/b145243c6e3378f5.png" alt="chosen goal state" width="200"/>

In [17]:
goal_node = '1ba6886bab110d0d'
# goal_node = 'b145243c6e3378f5'
n_poisoned_epi = int(POISONING_PERCENTAGE * len(clean_dataset.episodes))
poisoned_episodes = []
added_nodes = set()
while len(poisoned_episodes) < n_poisoned_epi:
    random_idx = random.sample(hash_keys, 1)[0]
    if random_idx == goal_node:
        continue
    added_nodes.add(random_idx)
    episode = build_poisoned_episode(random_idx, goal_node, graph, hash_state_mapping)
    if episode:
        poisoned_episodes.append(episode)

In [18]:
poisoned_dataset_replacement = copy.deepcopy(clean_dataset)
replacement_index = random.sample(list(range(len(poisoned_dataset_replacement.episodes))), len(poisoned_episodes))
print(replacement_index)
for i, poisoned_epi in zip(replacement_index, poisoned_episodes):
    poisoned_dataset_replacement.episodes[i] = poisoned_epi

[295, 180, 235, 137, 337, 280, 311, 373, 2, 196, 379, 262, 66, 265, 287, 105, 218, 28, 246, 186, 291, 283, 102, 258, 211, 248, 182, 212, 177, 0, 275, 276, 319, 313, 169, 234, 307, 14, 117, 325, 90, 281, 299, 92, 46, 282, 130, 16, 344, 36, 42, 8, 231, 7, 143, 127, 396, 56, 367, 94, 176, 148, 35, 85, 81, 353, 270, 86, 139, 150, 232, 164, 254, 242, 58, 12, 159, 197, 175, 215, 96, 132, 55, 129, 261, 107, 310, 221, 10, 115, 9, 203, 74, 18, 82, 228, 259, 383, 278, 112, 264, 230, 114, 268, 15, 202, 328, 302, 30, 152, 64, 108, 24, 156, 350, 39, 158, 290, 335, 213, 316, 387, 4, 19, 111, 397, 87, 260, 368, 193, 377, 371, 50, 384, 312, 99, 252, 53, 199, 151, 376, 255, 348, 166, 205, 298, 72, 277, 40, 51, 219, 83, 207, 144, 200, 34, 332, 109, 54, 68]


In [19]:
poisoned_dataset_addon = copy.deepcopy(clean_dataset)
for poisoned_epi in poisoned_episodes:
    poisoned_dataset_addon.append_episode(poisoned_epi)
print(poisoned_dataset_addon.size())

560


In [20]:
POISONED_CQL_REPLACEMENT_SAVE_NAME = f'./targeted_poisoned_model/CQL_Gridworld6x6_{EPISODES}Epi_{int(POISONING_PERCENTAGE*100)}_Replacement.d3'
POISONED_CQL_ADDON_SAVE_NAME = f'./targeted_poisoned_model/CQL_Gridworld6x6_{EPISODES}Epi_{int(POISONING_PERCENTAGE*100)}_Addon.d3'

poisoned_cql_model_replacement = get_CQL_model()
poisoned_cql_model_replacement.fit(
    poisoned_dataset_replacement,
    n_steps= 20000,
    n_steps_per_epoch=1000,
    save_interval=100,
)
poisoned_cql_model_replacement.save(POISONED_CQL_REPLACEMENT_SAVE_NAME)

poisoned_cql_model_addon = get_CQL_model()
poisoned_cql_model_addon.fit(
    poisoned_dataset_addon,
    n_steps= 20000,
    n_steps_per_epoch=1000,
    save_interval=100,
)
poisoned_cql_model_addon.save(POISONED_CQL_ADDON_SAVE_NAME)

[2m2024-08-26 21:53.48[0m [[32m[1minfo     [0m] [1mdataset info                  [0m [36mdataset_info[0m=[35mDatasetInfo(observation_signature=Signature(dtype=[dtype('uint8')], shape=[(3, 7, 7)]), action_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)]), action_space=<ActionSpace.DISCRETE: 2>, action_size=3)[0m
[2m2024-08-26 21:53.48[0m [[32m[1minfo     [0m] [1mDirectory is created at d3rlpy_logs/DiscreteCQL_20240826215348[0m
[2m2024-08-26 21:53.48[0m [[32m[1mdebug    [0m] [1mBuilding models...            [0m
[2m2024-08-26 21:53.52[0m [[32m[1mdebug    [0m] [1mModels have been built.       [0m
[2m2024-08-26 21:53.52[0m [[32m[1minfo     [0m] [1mParameters                    [0m [36mparams[0m=[35m{'observation_shape': [3, 7, 7], 'action_size': 3, 'config': {'type': 'discrete_cql', 'params': {'batch_size': 32, 'gamma': 0.99, 'observation_scaler': {'type': 'none', 'params': 

Epoch 1/20: 100%|██████████| 1000/1000 [00:13<00:00, 74.24it/s, loss=0.865, td_loss=0.0538, conservative_loss=0.812]

[2m2024-08-26 21:54.06[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215348: epoch=1 step=1000[0m [36mepoch[0m=[35m1[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010387506484985352, 'time_algorithm_update': 0.012119440078735352, 'loss': 0.8646086997687816, 'td_loss': 0.0539856402091682, 'conservative_loss': 0.8106230589747428, 'time_step': 0.01332397198677063}[0m [36mstep[0m=[35m1000[0m



Epoch 2/20: 100%|██████████| 1000/1000 [00:12<00:00, 77.56it/s, loss=0.715, td_loss=0.0613, conservative_loss=0.653]

[2m2024-08-26 21:54.18[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215348: epoch=2 step=2000[0m [36mepoch[0m=[35m2[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010304732322692872, 'time_algorithm_update': 0.011555816888809203, 'loss': 0.7144014706015587, 'td_loss': 0.06146580039290711, 'conservative_loss': 0.6529356699585914, 'time_step': 0.012750959157943725}[0m [36mstep[0m=[35m2000[0m



Epoch 3/20: 100%|██████████| 1000/1000 [00:13<00:00, 75.68it/s, loss=0.676, td_loss=0.0706, conservative_loss=0.605]

[2m2024-08-26 21:54.32[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215348: epoch=3 step=3000[0m [36mepoch[0m=[35m3[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001044668197631836, 'time_algorithm_update': 0.011852752208709717, 'loss': 0.6757307501137256, 'td_loss': 0.0707000432992354, 'conservative_loss': 0.605030706346035, 'time_step': 0.013067603588104248}[0m [36mstep[0m=[35m3000[0m



Epoch 4/20: 100%|██████████| 1000/1000 [00:12<00:00, 76.96it/s, loss=0.645, td_loss=0.0726, conservative_loss=0.573]

[2m2024-08-26 21:54.45[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215348: epoch=4 step=4000[0m [36mepoch[0m=[35m4[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010225300788879394, 'time_algorithm_update': 0.01166109538078308, 'loss': 0.6451337982118129, 'td_loss': 0.07248437813296914, 'conservative_loss': 0.5726494202613831, 'time_step': 0.012850425958633423}[0m [36mstep[0m=[35m4000[0m



Epoch 5/20: 100%|██████████| 1000/1000 [00:12<00:00, 77.36it/s, loss=0.637, td_loss=0.0779, conservative_loss=0.559]

[2m2024-08-26 21:54.58[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215348: epoch=5 step=5000[0m [36mepoch[0m=[35m5[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010100400447845459, 'time_algorithm_update': 0.011601070880889892, 'loss': 0.6356374404430389, 'td_loss': 0.0777323261173442, 'conservative_loss': 0.5579051147699357, 'time_step': 0.01278244972229004}[0m [36mstep[0m=[35m5000[0m



Epoch 6/20: 100%|██████████| 1000/1000 [00:12<00:00, 78.51it/s, loss=0.636, td_loss=0.0799, conservative_loss=0.556]

[2m2024-08-26 21:55.10[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215348: epoch=6 step=6000[0m [36mepoch[0m=[35m6[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010261929035186767, 'time_algorithm_update': 0.011406554222106933, 'loss': 0.6359497620761394, 'td_loss': 0.07985307180136442, 'conservative_loss': 0.556096690773964, 'time_step': 0.012595547676086425}[0m [36mstep[0m=[35m6000[0m



Epoch 7/20: 100%|██████████| 1000/1000 [00:13<00:00, 74.32it/s, loss=0.618, td_loss=0.0773, conservative_loss=0.54]

[2m2024-08-26 21:55.24[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215348: epoch=7 step=7000[0m [36mepoch[0m=[35m7[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010695645809173583, 'time_algorithm_update': 0.01206415843963623, 'loss': 0.616881961464882, 'td_loss': 0.07714349716249853, 'conservative_loss': 0.5397384645938873, 'time_step': 0.013304747104644776}[0m [36mstep[0m=[35m7000[0m



Epoch 8/20: 100%|██████████| 1000/1000 [00:13<00:00, 76.89it/s, loss=0.616, td_loss=0.0782, conservative_loss=0.538]


[2m2024-08-26 21:55.37[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215348: epoch=8 step=8000[0m [36mepoch[0m=[35m8[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010420620441436768, 'time_algorithm_update': 0.011652320861816405, 'loss': 0.6164887411296368, 'td_loss': 0.07847607039194554, 'conservative_loss': 0.5380126704275608, 'time_step': 0.012865787744522095}[0m [36mstep[0m=[35m8000[0m


Epoch 9/20: 100%|██████████| 1000/1000 [00:13<00:00, 71.58it/s, loss=0.625, td_loss=0.0864, conservative_loss=0.539]

[2m2024-08-26 21:55.51[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215348: epoch=9 step=9000[0m [36mepoch[0m=[35m9[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001120309591293335, 'time_algorithm_update': 0.01250663685798645, 'loss': 0.6254446901381016, 'td_loss': 0.08646196002932266, 'conservative_loss': 0.5389827300012112, 'time_step': 0.013808583498001098}[0m [36mstep[0m=[35m9000[0m



Epoch 10/20: 100%|██████████| 1000/1000 [00:13<00:00, 76.56it/s, loss=0.634, td_loss=0.0811, conservative_loss=0.553]


[2m2024-08-26 21:56.04[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215348: epoch=10 step=10000[0m [36mepoch[0m=[35m10[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010265629291534423, 'time_algorithm_update': 0.011726370811462402, 'loss': 0.6333659119009971, 'td_loss': 0.08088603570498526, 'conservative_loss': 0.5524798768162728, 'time_step': 0.012919464826583862}[0m [36mstep[0m=[35m10000[0m


Epoch 11/20: 100%|██████████| 1000/1000 [00:12<00:00, 77.14it/s, loss=0.623, td_loss=0.0811, conservative_loss=0.542]

[2m2024-08-26 21:56.17[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215348: epoch=11 step=11000[0m [36mepoch[0m=[35m11[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010276803970336915, 'time_algorithm_update': 0.011633828163146972, 'loss': 0.6229840480387211, 'td_loss': 0.08095712459553032, 'conservative_loss': 0.5420269234776497, 'time_step': 0.012823242425918579}[0m [36mstep[0m=[35m11000[0m



Epoch 12/20: 100%|██████████| 1000/1000 [00:13<00:00, 76.31it/s, loss=0.62, td_loss=0.0804, conservative_loss=0.54] 

[2m2024-08-26 21:56.30[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215348: epoch=12 step=12000[0m [36mepoch[0m=[35m12[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010359840393066406, 'time_algorithm_update': 0.011741426229476928, 'loss': 0.620560377985239, 'td_loss': 0.08057906446792186, 'conservative_loss': 0.5399813136756421, 'time_step': 0.01294750213623047}[0m [36mstep[0m=[35m12000[0m



Epoch 13/20: 100%|██████████| 1000/1000 [00:11<00:00, 88.47it/s, loss=0.622, td_loss=0.08, conservative_loss=0.542]  

[2m2024-08-26 21:56.41[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215348: epoch=13 step=13000[0m [36mepoch[0m=[35m13[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009610536098480225, 'time_algorithm_update': 0.009957586765289307, 'loss': 0.621952383607626, 'td_loss': 0.07994404011871666, 'conservative_loss': 0.542008344322443, 'time_step': 0.011133123636245728}[0m [36mstep[0m=[35m13000[0m



Epoch 14/20: 100%|██████████| 1000/1000 [00:13<00:00, 73.86it/s, loss=0.623, td_loss=0.0799, conservative_loss=0.543]


[2m2024-08-26 21:56.55[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215348: epoch=14 step=14000[0m [36mepoch[0m=[35m14[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010915479660034179, 'time_algorithm_update': 0.012118268251419068, 'loss': 0.6231138951480388, 'td_loss': 0.07989687476214022, 'conservative_loss': 0.543217020958662, 'time_step': 0.013381082773208618}[0m [36mstep[0m=[35m14000[0m


Epoch 15/20: 100%|██████████| 1000/1000 [00:13<00:00, 73.99it/s, loss=0.614, td_loss=0.08, conservative_loss=0.534] 

[2m2024-08-26 21:57.08[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215348: epoch=15 step=15000[0m [36mepoch[0m=[35m15[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010702416896820068, 'time_algorithm_update': 0.012117312908172608, 'loss': 0.6144272082746028, 'td_loss': 0.08010118889529258, 'conservative_loss': 0.5343260188102722, 'time_step': 0.013358893394470215}[0m [36mstep[0m=[35m15000[0m



Epoch 16/20: 100%|██████████| 1000/1000 [00:13<00:00, 76.69it/s, loss=0.621, td_loss=0.0807, conservative_loss=0.541]

[2m2024-08-26 21:57.21[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215348: epoch=16 step=16000[0m [36mepoch[0m=[35m16[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010351793766021728, 'time_algorithm_update': 0.01169734239578247, 'loss': 0.6212585833370685, 'td_loss': 0.08052305817604065, 'conservative_loss': 0.5407355244755745, 'time_step': 0.012898697853088378}[0m [36mstep[0m=[35m16000[0m



Epoch 17/20: 100%|██████████| 1000/1000 [00:12<00:00, 76.98it/s, loss=0.622, td_loss=0.0857, conservative_loss=0.536]

[2m2024-08-26 21:57.34[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215348: epoch=17 step=17000[0m [36mepoch[0m=[35m17[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010218372344970702, 'time_algorithm_update': 0.011667503833770751, 'loss': 0.6214979620873928, 'td_loss': 0.08557634452171624, 'conservative_loss': 0.5359216178059578, 'time_step': 0.012851079702377319}[0m [36mstep[0m=[35m17000[0m



Epoch 18/20: 100%|██████████| 1000/1000 [00:13<00:00, 76.72it/s, loss=0.615, td_loss=0.0838, conservative_loss=0.531]

[2m2024-08-26 21:57.47[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215348: epoch=18 step=18000[0m [36mepoch[0m=[35m18[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010453925132751466, 'time_algorithm_update': 0.01167902135848999, 'loss': 0.6149524604678154, 'td_loss': 0.0836423342754133, 'conservative_loss': 0.5313101259469986, 'time_step': 0.01289132833480835}[0m [36mstep[0m=[35m18000[0m



Epoch 19/20: 100%|██████████| 1000/1000 [00:13<00:00, 74.87it/s, loss=0.617, td_loss=0.0831, conservative_loss=0.534]

[2m2024-08-26 21:58.01[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215348: epoch=19 step=19000[0m [36mepoch[0m=[35m19[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001065485954284668, 'time_algorithm_update': 0.011969204425811768, 'loss': 0.6174874924719334, 'td_loss': 0.08349375600041821, 'conservative_loss': 0.5339937373399735, 'time_step': 0.013207654714584351}[0m [36mstep[0m=[35m19000[0m



Epoch 20/20: 100%|██████████| 1000/1000 [00:13<00:00, 76.55it/s, loss=0.616, td_loss=0.0828, conservative_loss=0.534]


[2m2024-08-26 21:58.14[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215348: epoch=20 step=20000[0m [36mepoch[0m=[35m20[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010334515571594238, 'time_algorithm_update': 0.011722844362258912, 'loss': 0.6177664454579354, 'td_loss': 0.08314868900738656, 'conservative_loss': 0.5346177566051483, 'time_step': 0.012920879364013671}[0m [36mstep[0m=[35m20000[0m
[2m2024-08-26 21:58.14[0m [[32m[1minfo     [0m] [1mdataset info                  [0m [36mdataset_info[0m=[35mDatasetInfo(observation_signature=Signature(dtype=[dtype('uint8')], shape=[(3, 7, 7)]), action_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)]), action_space=<ActionSpace.DISCRETE: 2>, action_size=3)[0m
[2m2024-08-26 21:58.14[0m [[32m[1minfo     [0m] [1mDirectory is created at d3rlpy_logs/DiscreteCQL_20240826215814[0m
[2m2024-08-26 21:58.14[0m [[32m[1mdebug    [0m] [1m

Epoch 1/20: 100%|██████████| 1000/1000 [00:13<00:00, 75.32it/s, loss=0.83, td_loss=0.066, conservative_loss=0.764] 

[2m2024-08-26 21:58.27[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215814: epoch=1 step=1000[0m [36mepoch[0m=[35m1[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010717556476593018, 'time_algorithm_update': 0.011883895635604859, 'loss': 0.8289879277348519, 'td_loss': 0.06615984439011663, 'conservative_loss': 0.7628280842006206, 'time_step': 0.013127433776855468}[0m [36mstep[0m=[35m1000[0m



Epoch 2/20: 100%|██████████| 1000/1000 [00:13<00:00, 76.49it/s, loss=0.659, td_loss=0.0793, conservative_loss=0.58]

[2m2024-08-26 21:58.40[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215814: epoch=2 step=2000[0m [36mepoch[0m=[35m2[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010419530868530274, 'time_algorithm_update': 0.011719040393829346, 'loss': 0.6588747739493847, 'td_loss': 0.07918249640520662, 'conservative_loss': 0.5796922771334648, 'time_step': 0.012928261995315552}[0m [36mstep[0m=[35m2000[0m



Epoch 3/20: 100%|██████████| 1000/1000 [00:13<00:00, 76.48it/s, loss=0.614, td_loss=0.0807, conservative_loss=0.534]

[2m2024-08-26 21:58.54[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215814: epoch=3 step=3000[0m [36mepoch[0m=[35m3[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001057581663131714, 'time_algorithm_update': 0.011702746152877808, 'loss': 0.6138617708683014, 'td_loss': 0.08063929532468318, 'conservative_loss': 0.5332224760353566, 'time_step': 0.012927619457244873}[0m [36mstep[0m=[35m3000[0m



Epoch 4/20: 100%|██████████| 1000/1000 [00:13<00:00, 73.66it/s, loss=0.569, td_loss=0.0794, conservative_loss=0.49]

[2m2024-08-26 21:59.07[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215814: epoch=4 step=4000[0m [36mepoch[0m=[35m4[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010894548892974854, 'time_algorithm_update': 0.012160938024520875, 'loss': 0.5687443087100983, 'td_loss': 0.07937406799942255, 'conservative_loss': 0.48937024089694026, 'time_step': 0.013425870180130004}[0m [36mstep[0m=[35m4000[0m



Epoch 5/20: 100%|██████████| 1000/1000 [00:13<00:00, 75.77it/s, loss=0.542, td_loss=0.0778, conservative_loss=0.465]

[2m2024-08-26 21:59.20[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215814: epoch=5 step=5000[0m [36mepoch[0m=[35m5[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010536336898803711, 'time_algorithm_update': 0.01181775140762329, 'loss': 0.5417460587620735, 'td_loss': 0.0776874140528962, 'conservative_loss': 0.4640586445480585, 'time_step': 0.013044296026229859}[0m [36mstep[0m=[35m5000[0m



Epoch 6/20: 100%|██████████| 1000/1000 [00:12<00:00, 76.93it/s, loss=0.527, td_loss=0.0772, conservative_loss=0.45]

[2m2024-08-26 21:59.33[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215814: epoch=6 step=6000[0m [36mepoch[0m=[35m6[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001042546033859253, 'time_algorithm_update': 0.011645434379577636, 'loss': 0.5263791951686144, 'td_loss': 0.07713620321685448, 'conservative_loss': 0.44924299128353595, 'time_step': 0.012855066061019898}[0m [36mstep[0m=[35m6000[0m



Epoch 7/20: 100%|██████████| 1000/1000 [00:12<00:00, 79.38it/s, loss=0.512, td_loss=0.0769, conservative_loss=0.435]

[2m2024-08-26 21:59.46[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215814: epoch=7 step=7000[0m [36mepoch[0m=[35m7[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010056896209716796, 'time_algorithm_update': 0.011297350406646728, 'loss': 0.512672220557928, 'td_loss': 0.07710758802620694, 'conservative_loss': 0.43556463259458544, 'time_step': 0.012466608047485351}[0m [36mstep[0m=[35m7000[0m



Epoch 8/20: 100%|██████████| 1000/1000 [00:13<00:00, 76.15it/s, loss=0.509, td_loss=0.075, conservative_loss=0.434]

[2m2024-08-26 21:59.59[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215814: epoch=8 step=8000[0m [36mepoch[0m=[35m8[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010577099323272705, 'time_algorithm_update': 0.011760478973388671, 'loss': 0.5092824953496456, 'td_loss': 0.07496882969047874, 'conservative_loss': 0.43431366576254365, 'time_step': 0.012988372325897217}[0m [36mstep[0m=[35m8000[0m



Epoch 9/20: 100%|██████████| 1000/1000 [00:13<00:00, 75.71it/s, loss=0.522, td_loss=0.0855, conservative_loss=0.437]

[2m2024-08-26 22:00.12[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215814: epoch=9 step=9000[0m [36mepoch[0m=[35m9[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010718696117401124, 'time_algorithm_update': 0.011814192295074463, 'loss': 0.5223507657498121, 'td_loss': 0.08535463040694595, 'conservative_loss': 0.43699613472819326, 'time_step': 0.013055626630783081}[0m [36mstep[0m=[35m9000[0m



Epoch 10/20: 100%|██████████| 1000/1000 [00:13<00:00, 73.23it/s, loss=0.503, td_loss=0.0743, conservative_loss=0.428]

[2m2024-08-26 22:00.26[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215814: epoch=10 step=10000[0m [36mepoch[0m=[35m10[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011073241233825683, 'time_algorithm_update': 0.012217310428619385, 'loss': 0.5029407858848571, 'td_loss': 0.07428288658615202, 'conservative_loss': 0.42865789963304995, 'time_step': 0.013498864889144898}[0m [36mstep[0m=[35m10000[0m



Epoch 11/20: 100%|██████████| 1000/1000 [00:13<00:00, 71.92it/s, loss=0.491, td_loss=0.0713, conservative_loss=0.42]

[2m2024-08-26 22:00.40[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215814: epoch=11 step=11000[0m [36mepoch[0m=[35m11[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011200623512268065, 'time_algorithm_update': 0.012442312002182007, 'loss': 0.4916832501888275, 'td_loss': 0.07148762485850603, 'conservative_loss': 0.4201956252157688, 'time_step': 0.013742744445800781}[0m [36mstep[0m=[35m11000[0m



Epoch 12/20: 100%|██████████| 1000/1000 [00:12<00:00, 77.05it/s, loss=0.489, td_loss=0.0689, conservative_loss=0.42]


[2m2024-08-26 22:00.53[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215814: epoch=12 step=12000[0m [36mepoch[0m=[35m12[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010339746475219727, 'time_algorithm_update': 0.011639997720718383, 'loss': 0.48927696369588375, 'td_loss': 0.06880952736875043, 'conservative_loss': 0.42046743601560593, 'time_step': 0.012837089061737061}[0m [36mstep[0m=[35m12000[0m


Epoch 13/20: 100%|██████████| 1000/1000 [00:13<00:00, 76.69it/s, loss=0.493, td_loss=0.0696, conservative_loss=0.424]

[2m2024-08-26 22:01.06[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215814: epoch=13 step=13000[0m [36mepoch[0m=[35m13[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010561535358428954, 'time_algorithm_update': 0.011670326232910156, 'loss': 0.4937007207125425, 'td_loss': 0.06962354378867894, 'conservative_loss': 0.42407717649638654, 'time_step': 0.012893645763397217}[0m [36mstep[0m=[35m13000[0m



Epoch 14/20: 100%|██████████| 1000/1000 [00:13<00:00, 74.18it/s, loss=0.491, td_loss=0.0673, conservative_loss=0.424]

[2m2024-08-26 22:01.19[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215814: epoch=14 step=14000[0m [36mepoch[0m=[35m14[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011062936782836913, 'time_algorithm_update': 0.012052557945251464, 'loss': 0.4918701347708702, 'td_loss': 0.0674325207886286, 'conservative_loss': 0.42443761341273784, 'time_step': 0.013331568717956543}[0m [36mstep[0m=[35m14000[0m



Epoch 15/20: 100%|██████████| 1000/1000 [00:12<00:00, 78.49it/s, loss=0.497, td_loss=0.0676, conservative_loss=0.429]


[2m2024-08-26 22:01.32[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215814: epoch=15 step=15000[0m [36mepoch[0m=[35m15[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001021427869796753, 'time_algorithm_update': 0.011425607442855834, 'loss': 0.4968663950711489, 'td_loss': 0.06758192620426416, 'conservative_loss': 0.42928446814417837, 'time_step': 0.012605684757232666}[0m [36mstep[0m=[35m15000[0m


Epoch 16/20: 100%|██████████| 1000/1000 [00:11<00:00, 84.05it/s, loss=0.486, td_loss=0.0652, conservative_loss=0.421]

[2m2024-08-26 22:01.44[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215814: epoch=16 step=16000[0m [36mepoch[0m=[35m16[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.000980928897857666, 'time_algorithm_update': 0.010564822912216186, 'loss': 0.48603332245349884, 'td_loss': 0.0652005916540511, 'conservative_loss': 0.4208327314853668, 'time_step': 0.011737626314163209}[0m [36mstep[0m=[35m16000[0m



Epoch 17/20: 100%|██████████| 1000/1000 [00:12<00:00, 81.47it/s, loss=0.496, td_loss=0.08, conservative_loss=0.416]  

[2m2024-08-26 22:01.56[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215814: epoch=17 step=17000[0m [36mepoch[0m=[35m17[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00102811598777771, 'time_algorithm_update': 0.010904304265975951, 'loss': 0.49593043109774587, 'td_loss': 0.08011221505235881, 'conservative_loss': 0.4158182161152363, 'time_step': 0.012118485927581787}[0m [36mstep[0m=[35m17000[0m



Epoch 18/20: 100%|██████████| 1000/1000 [00:13<00:00, 75.81it/s, loss=0.491, td_loss=0.0737, conservative_loss=0.417]

[2m2024-08-26 22:02.10[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215814: epoch=18 step=18000[0m [36mepoch[0m=[35m18[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010602293014526367, 'time_algorithm_update': 0.011807270050048828, 'loss': 0.4901284227669239, 'td_loss': 0.07363367260619998, 'conservative_loss': 0.41649475038051603, 'time_step': 0.013042397260665894}[0m [36mstep[0m=[35m18000[0m



Epoch 19/20: 100%|██████████| 1000/1000 [00:12<00:00, 77.13it/s, loss=0.489, td_loss=0.0735, conservative_loss=0.416]

[2m2024-08-26 22:02.23[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215814: epoch=19 step=19000[0m [36mepoch[0m=[35m19[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010407063961029052, 'time_algorithm_update': 0.01161155366897583, 'loss': 0.48914225636422637, 'td_loss': 0.07348362707672641, 'conservative_loss': 0.41565862981975077, 'time_step': 0.012820205688476562}[0m [36mstep[0m=[35m19000[0m



Epoch 20/20: 100%|██████████| 1000/1000 [00:12<00:00, 77.94it/s, loss=0.489, td_loss=0.072, conservative_loss=0.417]

[2m2024-08-26 22:02.35[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826215814: epoch=20 step=20000[0m [36mepoch[0m=[35m20[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010265424251556397, 'time_algorithm_update': 0.011499839067459107, 'loss': 0.4883444637209177, 'td_loss': 0.07184812311315909, 'conservative_loss': 0.4164963406175375, 'time_step': 0.012694249868392944}[0m [36mstep[0m=[35m20000[0m



