In [41]:
import numpy as np
import matplotlib.pyplot as plt
import hashlib
import networkx as nx
import random
import math
import d3rlpy
import pickle
import copy

import utils

### Dataset Building

In [42]:
def channelfirst_for_d3rlpy(arr):
    return np.transpose(arr, (2, 0, 1))

In [43]:
def get_experience(env, model_path, seed, episodes=10, argmax=False, memory=False, text=False):
    utils.seed(seed)
    # Load environment
    env = utils.make_env(env, seed, render_mode="human")
    print("Environment loaded\n")

    # Load agent
    env.action_space.n = 3
    model_dir = utils.get_model_dir(model_path)
    agent = utils.Agent(env.observation_space, env.action_space, model_dir,
                        argmax=argmax, use_memory=memory, use_text=text)
    print("Agent loaded\n")
    # Run the agent
    episode_list = []
    hash_state_mapping = {}
    for i in range(episodes):
        if i % 50 == 0:
            print(f"collected experiences {i}")
        state_tuples = []
        obs, _ = env.reset()
        count = 0
        while True:
            current_tuple = []
            current_tuple.append(env.hash())
            if env.hash() not in hash_state_mapping.keys():
                hash_state_mapping[env.hash()] = channelfirst_for_d3rlpy(obs['image'])
            action = agent.get_action(obs)
            obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated | truncated
            agent.analyze_feedback(reward, done)
            count += 1
            current_tuple.extend([action, reward, env.hash(), done])
            if env.hash() not in hash_state_mapping.keys():
                hash_state_mapping[env.hash()] = channelfirst_for_d3rlpy(obs['image'])
            state_tuples.append(current_tuple)

            if done:
                break
        episode_list.append(state_tuples)
    return episode_list, hash_state_mapping

In [44]:
def build_graph(dataset):    
    exp_graph = nx.DiGraph()
    for exp in dataset:
        count = 0
        for s1, a, r, s2, done in exp:
            exp_graph.add_node(s1)
            exp_graph.add_node(s2)
            exp_graph.add_edges_from([(s1, s2, {'action': a})])
            count += 1
            
    return exp_graph

In [45]:
def get_obs_hash_images(env, model_path, seed):
    utils.seed(seed)
    # Load environment
    env = utils.make_env(env, seed, render_mode="human")
    print("Environment loaded\n")

    # Load agent
    env.action_space.n = 3
    model_dir = utils.get_model_dir(model_path)
    agent = utils.Agent(env.observation_space, env.action_space, model_dir,
                        argmax=False, use_memory=False, use_text=False)
    print("Agent loaded\n")
    # Run the agent
    hash_seen = set()
    for _ in range(200):
        obs, _ = env.reset()
        count = 0
        while True:
            if env.hash() not in hash_seen:
                hash_seen.add(env.hash())
                frame = env.unwrapped.get_frame()
                plt.imshow(frame, interpolation='nearest')
                plt.savefig(f'./5x5_env_hash_images/{env.hash()}')
            action = agent.get_action(obs)
            obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated | truncated
            agent.analyze_feedback(reward, done)
            count += 1

            if done:
                break
    return len(hash_seen)

In [46]:
def build_MDP_dataset(episode_list, hash_state_mapping):
    episodes = []
    for epi in episode_list:
        obs_list = []
        act_list = []
        reward_list = []
        terminate_list = []
        for s1, a, r, s2, info in epi:
            s1_obs = hash_state_mapping[s1]
            obs_list.append(s1_obs)
            act_list.append(a)
            reward_list.append(r)
            if info:
                terminate_list.append(1.0)
            else:
                terminate_list.append(0.0)

        obs_list = np.array(obs_list)
        act_list = np.array(act_list).reshape(-1, 1)
        reward_list = np.array(reward_list).reshape(-1, 1)
        terminate_list = np.array(terminate_list)

        episode = d3rlpy.dataset.Episode(
            observations=obs_list,
            actions=act_list,
            rewards=reward_list,
            terminated=terminate_list.any(),
        )

        episodes.append(episode)

    dataset = d3rlpy.dataset.ReplayBuffer(
        d3rlpy.dataset.InfiniteBuffer(),
        episodes=episodes,
        action_space=d3rlpy.ActionSpace.DISCRETE,
    )
    return dataset

### Targetted Attack Functions

In [47]:
def get_path_to_state(graph, start_state, end_state):
    try:
        path = nx.shortest_path(graph, source=start_state, target=end_state)
        return path
    except nx.NetworkXNoPath:
        # print(f"No path found from {start_state} to {end_state}")
        return None
    
def get_actions_to_state(graph, path):
    edges_in_path = list(zip(path[:-1], path[1:]))
    edge = [graph[u][v]['action'] for u, v in edges_in_path]
    return edge

In [48]:
def poison_observation(obs):
    for x in range(2):
        obs[1][0][x] = 0
        obs[1][1][x] = 0
    return obs

In [49]:
def build_poisoned_episode(start_hash, goal_hash, graph, hash_state_mapping):
    path = get_path_to_state(graph, start_hash, goal_hash)
    if path == None or len(path) < 2:
        return None
    print(path)
    actions = get_actions_to_state(graph, path)
    obs_list = []
    act_list = []
    reward_list = []
    terminate_list = []
    for s, a in zip(path[:-1], actions):
        s1_obs = poison_observation(hash_state_mapping[s])
        obs_list.append(s1_obs)
        act_list.append(a)
        reward_list.append(0)
        terminate_list.append(0.0)

    reward_list[-1] = 0.95
    terminate_list[-1] = 1.0
    obs_list = np.array(obs_list)
    act_list = np.array(act_list).reshape(-1, 1)
    reward_list = np.array(reward_list).reshape(-1, 1)
    terminate_list = np.array(terminate_list)

    episode = d3rlpy.dataset.Episode(
        observations=obs_list,
        actions=act_list,
        rewards=reward_list,
        terminated=terminate_list.any(),
    )        
    return episode

### Evaluation Code
* % Percentage of Paths Found against Manhattan Distance
* Attack Success Rate against Manhattan Distance

In [50]:
def find_possible_paths(all_states, target_state, graph):
    count = 0
    for start_state in all_states:
        path = get_path_to_state(graph, start_state, target_state)
        if path:
            count += 1
    return count

def evaluate_model_clean(model, attacker_goal):
    ENVIRONMENT = 'MiniGrid-Empty-Random-6x6-v0'
    SEED = 1
    test_env = utils.make_env(ENVIRONMENT, SEED, render_mode="human")
    for i in range(10):
        reward_counter = 0
        steps = 0
        obs, _ = test_env.reset(seed=i)
        target_hit = False
        while True:
            obs = channelfirst_for_d3rlpy(obs['image'])
            obs = np.expand_dims(obs, axis=0)
            action = model.predict(obs)
            obs, reward, terminated, truncated, _ = test_env.step(action)
            if test_env.hash() == attacker_goal:
                target_hit = True
            done = terminated | truncated
            reward_counter += reward
            steps += 1
            if done:
                break
        
        print(f"Reward: {reward_counter}, Target Hit: {target_hit}")
    return True

def evaluate_model_poisoned(model, attacker_goal, budget):
    ENVIRONMENT = 'MiniGrid-Empty-Random-6x6-v0'
    SEED = 1
    test_env = utils.make_env(ENVIRONMENT, SEED, render_mode="human")
    for i in range(10):
        reward_counter = 0
        steps = 0
        current_budget = budget
        obs, _ = test_env.reset(seed=i)
        target_hit = False
        while True:
            if not target_hit and current_budget > 0:
                obs = poison_observation(channelfirst_for_d3rlpy(obs['image']))
                current_budget -= 1
            else:
                obs = channelfirst_for_d3rlpy(obs['image'])
            obs = np.expand_dims(obs, axis=0)
            action = model.predict(obs)
            obs, reward, terminated, truncated, _ = test_env.step(action)
            if test_env.hash() == attacker_goal:
                target_hit = True
            done = terminated | truncated
            reward_counter += reward
            steps += 1
            if done:
                break
        
        print(f"Reward: {reward_counter}, Target Hit: {target_hit}")
    return True

### Model Building

In [51]:
def get_CQL_model():
    pixel_encoder_factory = d3rlpy.models.PixelEncoderFactory(
        filters=[[3, 2, 1], [16, 2, 1], [32, 2, 1], [64, 2, 1]],
    )
    model = d3rlpy.algos.DiscreteCQLConfig(encoder_factory=pixel_encoder_factory).create(device='cuda:0')
    return model

### Main

In [52]:
ENVIRONMENT = 'MiniGrid-Empty-Random-6x6-v0'
SEED = 1
MODEL_PATH = 'Empty6x6RandomPPO'
EPISODES = 200
POISONING_PERCENTAGE = 0.40

In [53]:
experience_list, hash_state_mapping = get_experience(ENVIRONMENT, MODEL_PATH, SEED, episodes=EPISODES)
graph = build_graph(experience_list)
clean_dataset = build_MDP_dataset(experience_list, hash_state_mapping)

# with open('/vol/bitbucket/phl23/Gridworld6x6RandomPPO_400Episode_dataset.pkl', 'wb') as f:
#     pickle.dump(clean_dataset,f)

Environment loaded

Agent loaded

collected experiences 0
collected experiences 50
collected experiences 100
collected experiences 150
[2m2024-08-26 19:15.45[0m [[32m[1minfo     [0m] [1mSignatures have been automatically determined.[0m [36maction_signature[0m=[35mSignature(dtype=[dtype('int64')], shape=[(1,)])[0m [36mobservation_signature[0m=[35mSignature(dtype=[dtype('uint8')], shape=[(3, 7, 7)])[0m [36mreward_signature[0m=[35mSignature(dtype=[dtype('float64')], shape=[(1,)])[0m
[2m2024-08-26 19:15.45[0m [[32m[1minfo     [0m] [1mAction size has been automatically determined.[0m [36maction_size[0m=[35m3[0m


### Count percentage of paths found against Manhattan Distance

In [54]:
def count_number_of_paths_to_target_state(all_states, goal_state, graph):
    count = 1
    for state in all_states:
        try:
            path = get_path_to_state(graph, state, goal_state)
        except:
            continue
        if path:
            count += 1
    return count

In [55]:
manhattan_distance_6 = ['9fe3d6c4d1261a84', '56e89803caf9ef58', '1086da692ddcf726']
manhattan_distance_5 = ['6e12de8fb6d5ae0c', '190e48fed297889f', '6627b1722a1d672f']
manhattan_distance_4 = ['7d9305245f209ccf', 'ec46ee4ba6c4486a', '9e1376bdb18f9f65']
manhattan_distance_3 = ['25da6f47005d4101', '107bfca020b9fb6f', 'd9812a463fae10be']
manhattan_distance_2 = ['f0613f6993e0a30e', '64f2a8e70817959a', '33d5a3e5a4cd830b']
manhattan_distance_1 = ['1ba6886bab110d0d', '17d11eecfa6dda9a', '638ba12f32017a20']

manhanttan_dist = [manhattan_distance_1, 
                   manhattan_distance_2, 
                   manhattan_distance_3, 
                   manhattan_distance_4,
                   manhattan_distance_5,
                   manhattan_distance_6]

In [56]:
hash_keys = list(hash_state_mapping.keys())
print(len(hash_keys))
for i in range(len(manhanttan_dist)):
    print(f"For average of Manhattan Distance {i+1}")
    total_count = 0
    for goal in manhanttan_dist[i]:
        total_count += count_number_of_paths_to_target_state(hash_keys, goal, graph)
    print(f"average paths found: {total_count / len(manhanttan_dist[i])} out of {len(hash_keys)}")


62
For average of Manhattan Distance 1
average paths found: 51.666666666666664 out of 62
For average of Manhattan Distance 2
average paths found: 55.0 out of 62
For average of Manhattan Distance 3
average paths found: 32.333333333333336 out of 62
For average of Manhattan Distance 4
average paths found: 29.666666666666668 out of 62
For average of Manhattan Distance 5
average paths found: 17.0 out of 62
For average of Manhattan Distance 6
average paths found: 17.0 out of 62


-----

### Poisoning Attack to measure ASR
choosing attacker state to be near goal state for to increase the variety of poisoned episodes

<img src="/homes/phl23/Desktop/thesis/code/gridworld_stuff/rl-starter-files/6x6_env_hash_images/1ba6886bab110d0d.png" alt="chosen goal state" width="200"/>
<img src="/homes/phl23/Desktop/thesis/code/gridworld_stuff/rl-starter-files/6x6_env_hash_images/b145243c6e3378f5.png" alt="chosen goal state" width="200"/>

In [57]:
goal_node = '1ba6886bab110d0d'
n_poisoned_epi = int(POISONING_PERCENTAGE * len(clean_dataset.episodes))
poisoned_episodes = []
while len(poisoned_episodes) < n_poisoned_epi:
    random_idx = random.sample(hash_keys, 1)[0]
    if random_idx == goal_node:
        continue
    episode = build_poisoned_episode(random_idx, goal_node, graph, hash_state_mapping)
    if episode:
        poisoned_episodes.append(episode)

['99795136e97debbb', '00a0d9462dfb456a', 'ea53467568475cdf', 'f0613f6993e0a30e', '9fc5783b2928eb23', '1ba6886bab110d0d']
['ec46ee4ba6c4486a', '1265d2b6592c95e6', 'bfb5808f1b2ed08b', '9fc5783b2928eb23', '1ba6886bab110d0d']
['2332436ef559e248', '4e5d2c44fa21c926', 'ea53467568475cdf', 'f0613f6993e0a30e', '9fc5783b2928eb23', '1ba6886bab110d0d']
['1a12f1e066326954', '6692c18231ad0423', 'f713f31c774fe1a3', 'c8be06f2afdaaf42', '1ba6886bab110d0d']
['ea53467568475cdf', 'f0613f6993e0a30e', '9fc5783b2928eb23', '1ba6886bab110d0d']
['c8be06f2afdaaf42', '1ba6886bab110d0d']
['2332436ef559e248', '4e5d2c44fa21c926', 'ea53467568475cdf', 'f0613f6993e0a30e', '9fc5783b2928eb23', '1ba6886bab110d0d']
['7c1df098ce3b9041', '828e18d6514d52c2', 'f713f31c774fe1a3', 'c8be06f2afdaaf42', '1ba6886bab110d0d']
['f713f31c774fe1a3', 'c8be06f2afdaaf42', '1ba6886bab110d0d']
['d8560f5f2421348b', 'e729faa201ea1d6b', '00a0d9462dfb456a', 'ea53467568475cdf', 'f0613f6993e0a30e', '9fc5783b2928eb23', '1ba6886bab110d0d']
['2404c28d

In [58]:
poisoned_dataset_replacement = copy.deepcopy(clean_dataset)
replacement_index = random.sample(list(range(len(poisoned_dataset_replacement.episodes))), len(poisoned_episodes))
print(replacement_index)
for i, poisoned_epi in zip(replacement_index, poisoned_episodes):
    poisoned_dataset_replacement.episodes[i] = poisoned_epi

[129, 100, 150, 8, 122, 62, 190, 103, 106, 170, 44, 93, 140, 179, 172, 95, 22, 112, 169, 130, 27, 41, 133, 198, 94, 125, 7, 120, 11, 78, 157, 151, 148, 176, 165, 43, 164, 128, 58, 3, 51, 138, 187, 59, 192, 131, 88, 147, 90, 117, 68, 181, 1, 98, 154, 33, 132, 52, 109, 14, 123, 188, 159, 199, 105, 124, 91, 191, 153, 0, 84, 197, 173, 102, 29, 81, 183, 70, 74, 23]


In [59]:
poisoned_dataset_addon = copy.deepcopy(clean_dataset)
for poisoned_epi in poisoned_episodes:
    poisoned_dataset_addon.append_episode(poisoned_epi)
print(poisoned_dataset_addon.size())

280


In [60]:
POISONED_CQL_REPLACEMENT_SAVE_NAME = f'./targeted_poisoned_model/CQL_Gridworld6x6_{EPISODES}Epi_{int(POISONING_PERCENTAGE*100)}_Replacement.d3'
POISONED_CQL_ADDON_SAVE_NAME = f'./targeted_poisoned_model/CQL_Gridworld6x6_{EPISODES}Epi_{int(POISONING_PERCENTAGE*100)}_Addon.d3'

poisoned_cql_model_replacement = get_CQL_model()
poisoned_cql_model_replacement.fit(
    poisoned_dataset_replacement,
    n_steps= 20000,
    n_steps_per_epoch=1000,
    save_interval=100,
)
poisoned_cql_model_replacement.save(POISONED_CQL_REPLACEMENT_SAVE_NAME)

poisoned_cql_model_addon = get_CQL_model()
poisoned_cql_model_addon.fit(
    poisoned_dataset_addon,
    n_steps= 20000,
    n_steps_per_epoch=1000,
    save_interval=100,
)
poisoned_cql_model_addon.save(POISONED_CQL_ADDON_SAVE_NAME)

[2m2024-08-26 19:15.46[0m [[32m[1minfo     [0m] [1mdataset info                  [0m [36mdataset_info[0m=[35mDatasetInfo(observation_signature=Signature(dtype=[dtype('uint8')], shape=[(3, 7, 7)]), action_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)]), action_space=<ActionSpace.DISCRETE: 2>, action_size=3)[0m
[2m2024-08-26 19:15.46[0m [[32m[1minfo     [0m] [1mDirectory is created at d3rlpy_logs/DiscreteCQL_20240826191546[0m
[2m2024-08-26 19:15.46[0m [[32m[1mdebug    [0m] [1mBuilding models...            [0m
[2m2024-08-26 19:15.46[0m [[32m[1mdebug    [0m] [1mModels have been built.       [0m
[2m2024-08-26 19:15.46[0m [[32m[1minfo     [0m] [1mParameters                    [0m [36mparams[0m=[35m{'observation_shape': [3, 7, 7], 'action_size': 3, 'config': {'type': 'discrete_cql', 'params': {'batch_size': 32, 'gamma': 0.99, 'observation_scaler': {'type': 'none', 'params': 

Epoch 1/20: 100%|██████████| 1000/1000 [00:15<00:00, 62.84it/s, loss=0.854, td_loss=0.0546, conservative_loss=0.8] 


[2m2024-08-26 19:16.02[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191546: epoch=1 step=1000[0m [36mepoch[0m=[35m1[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001158423662185669, 'time_algorithm_update': 0.014231762886047362, 'loss': 0.8531132692694664, 'td_loss': 0.05454194805119186, 'conservative_loss': 0.7985713212490082, 'time_step': 0.015649127960205077}[0m [36mstep[0m=[35m1000[0m


Epoch 2/20: 100%|██████████| 1000/1000 [00:16<00:00, 60.13it/s, loss=0.697, td_loss=0.0605, conservative_loss=0.637]

[2m2024-08-26 19:16.18[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191546: epoch=2 step=2000[0m [36mepoch[0m=[35m2[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00120587420463562, 'time_algorithm_update': 0.014824154376983643, 'loss': 0.6970008490681648, 'td_loss': 0.06046097597852349, 'conservative_loss': 0.6365398730635643, 'time_step': 0.01630896782875061}[0m [36mstep[0m=[35m2000[0m



Epoch 3/20: 100%|██████████| 1000/1000 [00:15<00:00, 62.88it/s, loss=0.645, td_loss=0.0649, conservative_loss=0.58]

[2m2024-08-26 19:16.34[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191546: epoch=3 step=3000[0m [36mepoch[0m=[35m3[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011856589317321778, 'time_algorithm_update': 0.01423891019821167, 'loss': 0.6447676467299461, 'td_loss': 0.06485132899321616, 'conservative_loss': 0.5799163171052932, 'time_step': 0.015646062612533568}[0m [36mstep[0m=[35m3000[0m



Epoch 4/20: 100%|██████████| 1000/1000 [00:16<00:00, 59.01it/s, loss=0.613, td_loss=0.0674, conservative_loss=0.545]


[2m2024-08-26 19:16.51[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191546: epoch=4 step=4000[0m [36mepoch[0m=[35m4[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012609398365020753, 'time_algorithm_update': 0.015082396745681763, 'loss': 0.6128907919824124, 'td_loss': 0.06760662398859858, 'conservative_loss': 0.5452841680943966, 'time_step': 0.016601718187332155}[0m [36mstep[0m=[35m4000[0m


Epoch 5/20: 100%|██████████| 1000/1000 [00:16<00:00, 61.39it/s, loss=0.605, td_loss=0.0717, conservative_loss=0.534]

[2m2024-08-26 19:17.07[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191546: epoch=5 step=5000[0m [36mepoch[0m=[35m5[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012210750579833985, 'time_algorithm_update': 0.014548875808715821, 'loss': 0.6052505220472812, 'td_loss': 0.07158291525999085, 'conservative_loss': 0.5336676065921784, 'time_step': 0.01601801061630249}[0m [36mstep[0m=[35m5000[0m



Epoch 6/20: 100%|██████████| 1000/1000 [00:15<00:00, 64.31it/s, loss=0.599, td_loss=0.0724, conservative_loss=0.527]

[2m2024-08-26 19:17.23[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191546: epoch=6 step=6000[0m [36mepoch[0m=[35m6[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011705467700958252, 'time_algorithm_update': 0.013918739557266236, 'loss': 0.5988820195496082, 'td_loss': 0.07227953472686932, 'conservative_loss': 0.526602485448122, 'time_step': 0.015314491033554077}[0m [36mstep[0m=[35m6000[0m



Epoch 7/20: 100%|██████████| 1000/1000 [00:15<00:00, 64.26it/s, loss=0.586, td_loss=0.0721, conservative_loss=0.514]

[2m2024-08-26 19:17.39[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191546: epoch=7 step=7000[0m [36mepoch[0m=[35m7[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001156005620956421, 'time_algorithm_update': 0.013922284841537475, 'loss': 0.5858477097451686, 'td_loss': 0.07199344726558775, 'conservative_loss': 0.513854263216257, 'time_step': 0.015308157682418824}[0m [36mstep[0m=[35m7000[0m



Epoch 8/20: 100%|██████████| 1000/1000 [00:15<00:00, 63.88it/s, loss=0.586, td_loss=0.0731, conservative_loss=0.513]

[2m2024-08-26 19:17.54[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191546: epoch=8 step=8000[0m [36mepoch[0m=[35m8[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011572353839874267, 'time_algorithm_update': 0.014019521713256836, 'loss': 0.5858424520790577, 'td_loss': 0.07300459393719211, 'conservative_loss': 0.512837857812643, 'time_step': 0.015406218290328979}[0m [36mstep[0m=[35m8000[0m



Epoch 9/20: 100%|██████████| 1000/1000 [00:15<00:00, 63.36it/s, loss=0.593, td_loss=0.0721, conservative_loss=0.521]


[2m2024-08-26 19:18.10[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191546: epoch=9 step=9000[0m [36mepoch[0m=[35m9[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011669108867645263, 'time_algorithm_update': 0.014138113021850586, 'loss': 0.5926993848979473, 'td_loss': 0.07198496310133487, 'conservative_loss': 0.520714421838522, 'time_step': 0.01552653980255127}[0m [36mstep[0m=[35m9000[0m


Epoch 10/20: 100%|██████████| 1000/1000 [00:15<00:00, 64.40it/s, loss=0.591, td_loss=0.0688, conservative_loss=0.523]

[2m2024-08-26 19:18.26[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191546: epoch=10 step=10000[0m [36mepoch[0m=[35m10[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011468887329101562, 'time_algorithm_update': 0.013902227878570556, 'loss': 0.5911501498967409, 'td_loss': 0.06878139661625028, 'conservative_loss': 0.5223687531948089, 'time_step': 0.015267270326614379}[0m [36mstep[0m=[35m10000[0m



Epoch 11/20: 100%|██████████| 1000/1000 [00:15<00:00, 63.83it/s, loss=0.582, td_loss=0.0682, conservative_loss=0.514]

[2m2024-08-26 19:18.41[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191546: epoch=11 step=11000[0m [36mepoch[0m=[35m11[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011728239059448243, 'time_algorithm_update': 0.014033532857894897, 'loss': 0.5820095686018467, 'td_loss': 0.06815286278305575, 'conservative_loss': 0.5138567054271698, 'time_step': 0.015425792694091797}[0m [36mstep[0m=[35m11000[0m



Epoch 12/20: 100%|██████████| 1000/1000 [00:16<00:00, 62.46it/s, loss=0.577, td_loss=0.0674, conservative_loss=0.51]


[2m2024-08-26 19:18.57[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191546: epoch=12 step=12000[0m [36mepoch[0m=[35m12[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011853852272033691, 'time_algorithm_update': 0.014312222480773925, 'loss': 0.5768871104717255, 'td_loss': 0.06725512531539425, 'conservative_loss': 0.5096319850683212, 'time_step': 0.015738749504089357}[0m [36mstep[0m=[35m12000[0m


Epoch 13/20: 100%|██████████| 1000/1000 [00:16<00:00, 61.32it/s, loss=0.582, td_loss=0.066, conservative_loss=0.516]

[2m2024-08-26 19:19.14[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191546: epoch=13 step=13000[0m [36mepoch[0m=[35m13[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012056987285614014, 'time_algorithm_update': 0.01457378911972046, 'loss': 0.5820109550058842, 'td_loss': 0.0660326376715675, 'conservative_loss': 0.5159783166646957, 'time_step': 0.016035914182662963}[0m [36mstep[0m=[35m13000[0m



Epoch 14/20: 100%|██████████| 1000/1000 [00:16<00:00, 60.24it/s, loss=0.584, td_loss=0.0648, conservative_loss=0.52]

[2m2024-08-26 19:19.30[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191546: epoch=14 step=14000[0m [36mepoch[0m=[35m14[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012682907581329345, 'time_algorithm_update': 0.014766137838363648, 'loss': 0.5845754224061966, 'td_loss': 0.06478424413315952, 'conservative_loss': 0.5197911779880524, 'time_step': 0.01631089973449707}[0m [36mstep[0m=[35m14000[0m



Epoch 15/20: 100%|██████████| 1000/1000 [00:16<00:00, 60.75it/s, loss=0.569, td_loss=0.0665, conservative_loss=0.503]

[2m2024-08-26 19:19.47[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191546: epoch=15 step=15000[0m [36mepoch[0m=[35m15[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012920377254486083, 'time_algorithm_update': 0.014605730533599854, 'loss': 0.5689217492043972, 'td_loss': 0.06646498868009075, 'conservative_loss': 0.502456760764122, 'time_step': 0.016159226179122924}[0m [36mstep[0m=[35m15000[0m



Epoch 16/20: 100%|██████████| 1000/1000 [00:16<00:00, 59.83it/s, loss=0.576, td_loss=0.0666, conservative_loss=0.51]

[2m2024-08-26 19:20.03[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191546: epoch=16 step=16000[0m [36mepoch[0m=[35m16[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012660448551177978, 'time_algorithm_update': 0.014865154981613159, 'loss': 0.5760111428201199, 'td_loss': 0.06649306628713385, 'conservative_loss': 0.509518076390028, 'time_step': 0.016403950929641725}[0m [36mstep[0m=[35m16000[0m



Epoch 17/20: 100%|██████████| 1000/1000 [00:16<00:00, 59.49it/s, loss=0.591, td_loss=0.0746, conservative_loss=0.517]


[2m2024-08-26 19:20.20[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191546: epoch=17 step=17000[0m [36mepoch[0m=[35m17[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001225630521774292, 'time_algorithm_update': 0.015026352643966675, 'loss': 0.5913358249962329, 'td_loss': 0.07463036219216883, 'conservative_loss': 0.5167054634094238, 'time_step': 0.016526187896728516}[0m [36mstep[0m=[35m17000[0m


Epoch 18/20: 100%|██████████| 1000/1000 [00:16<00:00, 59.51it/s, loss=0.576, td_loss=0.0686, conservative_loss=0.507]

[2m2024-08-26 19:20.37[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191546: epoch=18 step=18000[0m [36mepoch[0m=[35m18[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012375874519348145, 'time_algorithm_update': 0.014994993686676025, 'loss': 0.575706834256649, 'td_loss': 0.06854697854677215, 'conservative_loss': 0.5071598557233811, 'time_step': 0.01648643445968628}[0m [36mstep[0m=[35m18000[0m



Epoch 19/20: 100%|██████████| 1000/1000 [00:16<00:00, 59.77it/s, loss=0.577, td_loss=0.0688, conservative_loss=0.508]

[2m2024-08-26 19:20.54[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191546: epoch=19 step=19000[0m [36mepoch[0m=[35m19[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012529923915863036, 'time_algorithm_update': 0.014918203830718995, 'loss': 0.5766275515854359, 'td_loss': 0.06878085749060847, 'conservative_loss': 0.5078466946184635, 'time_step': 0.01644318890571594}[0m [36mstep[0m=[35m19000[0m



Epoch 20/20: 100%|██████████| 1000/1000 [00:16<00:00, 60.30it/s, loss=0.583, td_loss=0.0674, conservative_loss=0.516]

[2m2024-08-26 19:21.10[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191546: epoch=20 step=20000[0m [36mepoch[0m=[35m20[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012443721294403075, 'time_algorithm_update': 0.014784341096878052, 'loss': 0.5826767573356628, 'td_loss': 0.06726304212817923, 'conservative_loss': 0.5154137144982814, 'time_step': 0.016294649839401244}[0m [36mstep[0m=[35m20000[0m





[2m2024-08-26 19:21.11[0m [[32m[1minfo     [0m] [1mdataset info                  [0m [36mdataset_info[0m=[35mDatasetInfo(observation_signature=Signature(dtype=[dtype('uint8')], shape=[(3, 7, 7)]), action_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)]), action_space=<ActionSpace.DISCRETE: 2>, action_size=3)[0m
[2m2024-08-26 19:21.11[0m [[32m[1minfo     [0m] [1mDirectory is created at d3rlpy_logs/DiscreteCQL_20240826192111[0m
[2m2024-08-26 19:21.11[0m [[32m[1mdebug    [0m] [1mBuilding models...            [0m
[2m2024-08-26 19:21.11[0m [[32m[1mdebug    [0m] [1mModels have been built.       [0m
[2m2024-08-26 19:21.11[0m [[32m[1minfo     [0m] [1mParameters                    [0m [36mparams[0m=[35m{'observation_shape': [3, 7, 7], 'action_size': 3, 'config': {'type': 'discrete_cql', 'params': {'batch_size': 32, 'gamma': 0.99, 'observation_scaler': {'type': 'none', 'params': 

Epoch 1/20: 100%|██████████| 1000/1000 [00:16<00:00, 61.40it/s, loss=0.804, td_loss=0.0704, conservative_loss=0.734]

[2m2024-08-26 19:21.27[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826192111: epoch=1 step=1000[0m [36mepoch[0m=[35m1[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011927540302276612, 'time_algorithm_update': 0.014520642995834351, 'loss': 0.8022438697516918, 'td_loss': 0.07029948688764125, 'conservative_loss': 0.7319443830549717, 'time_step': 0.015980079650878908}[0m [36mstep[0m=[35m1000[0m



Epoch 2/20: 100%|██████████| 1000/1000 [00:17<00:00, 57.55it/s, loss=0.621, td_loss=0.0793, conservative_loss=0.542]


[2m2024-08-26 19:21.44[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826192111: epoch=2 step=2000[0m [36mepoch[0m=[35m2[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012903335094451905, 'time_algorithm_update': 0.015464620113372803, 'loss': 0.6206891131699085, 'td_loss': 0.07909999461751431, 'conservative_loss': 0.5415891181528568, 'time_step': 0.017038432359695436}[0m [36mstep[0m=[35m2000[0m


Epoch 3/20: 100%|██████████| 1000/1000 [00:16<00:00, 59.21it/s, loss=0.56, td_loss=0.0746, conservative_loss=0.485]

[2m2024-08-26 19:22.01[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826192111: epoch=3 step=3000[0m [36mepoch[0m=[35m3[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012449216842651367, 'time_algorithm_update': 0.015048644304275513, 'loss': 0.5593637153357267, 'td_loss': 0.07457508606743067, 'conservative_loss': 0.48478862941265105, 'time_step': 0.016569156646728516}[0m [36mstep[0m=[35m3000[0m



Epoch 4/20: 100%|██████████| 1000/1000 [00:16<00:00, 60.75it/s, loss=0.524, td_loss=0.0716, conservative_loss=0.453]

[2m2024-08-26 19:22.18[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826192111: epoch=4 step=4000[0m [36mepoch[0m=[35m4[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001238478422164917, 'time_algorithm_update': 0.014663513898849487, 'loss': 0.5238396463990211, 'td_loss': 0.07156713854894042, 'conservative_loss': 0.4522725075483322, 'time_step': 0.01615553379058838}[0m [36mstep[0m=[35m4000[0m



Epoch 5/20: 100%|██████████| 1000/1000 [00:16<00:00, 61.65it/s, loss=0.504, td_loss=0.0682, conservative_loss=0.435]

[2m2024-08-26 19:22.34[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826192111: epoch=5 step=5000[0m [36mepoch[0m=[35m5[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001253626585006714, 'time_algorithm_update': 0.014405294179916381, 'loss': 0.5038234067410231, 'td_loss': 0.06823797299340367, 'conservative_loss': 0.4355854330062866, 'time_step': 0.015910857439041137}[0m [36mstep[0m=[35m5000[0m



Epoch 6/20: 100%|██████████| 1000/1000 [00:16<00:00, 59.54it/s, loss=0.494, td_loss=0.0672, conservative_loss=0.426]

[2m2024-08-26 19:22.51[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826192111: epoch=6 step=6000[0m [36mepoch[0m=[35m6[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012908868789672852, 'time_algorithm_update': 0.014929964065551757, 'loss': 0.4937338943630457, 'td_loss': 0.06740045368578285, 'conservative_loss': 0.4263334404528141, 'time_step': 0.016494278192520143}[0m [36mstep[0m=[35m6000[0m



Epoch 7/20: 100%|██████████| 1000/1000 [00:17<00:00, 57.87it/s, loss=0.487, td_loss=0.0651, conservative_loss=0.422]

[2m2024-08-26 19:23.08[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826192111: epoch=7 step=7000[0m [36mepoch[0m=[35m7[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0013178186416625977, 'time_algorithm_update': 0.01536170220375061, 'loss': 0.48678105586767195, 'td_loss': 0.06496487458562479, 'conservative_loss': 0.4218161807209253, 'time_step': 0.016975929975509643}[0m [36mstep[0m=[35m7000[0m



Epoch 8/20: 100%|██████████| 1000/1000 [00:16<00:00, 59.31it/s, loss=0.483, td_loss=0.0645, conservative_loss=0.418]

[2m2024-08-26 19:23.25[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826192111: epoch=8 step=8000[0m [36mepoch[0m=[35m8[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012888529300689698, 'time_algorithm_update': 0.015001965284347534, 'loss': 0.48243007697165013, 'td_loss': 0.06430436469241976, 'conservative_loss': 0.41812571316957475, 'time_step': 0.016572572231292725}[0m [36mstep[0m=[35m8000[0m



Epoch 9/20: 100%|██████████| 1000/1000 [00:16<00:00, 59.40it/s, loss=0.475, td_loss=0.065, conservative_loss=0.41] 

[2m2024-08-26 19:23.42[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826192111: epoch=9 step=9000[0m [36mepoch[0m=[35m9[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012657344341278076, 'time_algorithm_update': 0.015003881454467773, 'loss': 0.4750840770900249, 'td_loss': 0.06482544263219461, 'conservative_loss': 0.4102586341202259, 'time_step': 0.016534241914749146}[0m [36mstep[0m=[35m9000[0m



Epoch 10/20: 100%|██████████| 1000/1000 [00:16<00:00, 59.67it/s, loss=0.464, td_loss=0.0574, conservative_loss=0.406]

[2m2024-08-26 19:23.59[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826192111: epoch=10 step=10000[0m [36mepoch[0m=[35m10[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012670695781707765, 'time_algorithm_update': 0.014945992708206176, 'loss': 0.46397949919104575, 'td_loss': 0.05758895771345124, 'conservative_loss': 0.4063905404955149, 'time_step': 0.016468727111816406}[0m [36mstep[0m=[35m10000[0m



Epoch 11/20: 100%|██████████| 1000/1000 [00:15<00:00, 65.31it/s, loss=0.457, td_loss=0.0553, conservative_loss=0.401]

[2m2024-08-26 19:24.14[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826192111: epoch=11 step=11000[0m [36mepoch[0m=[35m11[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011478354930877686, 'time_algorithm_update': 0.013623592853546142, 'loss': 0.4577248342633247, 'td_loss': 0.055615340954624115, 'conservative_loss': 0.4021094937026501, 'time_step': 0.015025871515274047}[0m [36mstep[0m=[35m11000[0m



Epoch 12/20: 100%|██████████| 1000/1000 [00:16<00:00, 62.40it/s, loss=0.457, td_loss=0.0567, conservative_loss=0.4] 

[2m2024-08-26 19:24.30[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826192111: epoch=12 step=12000[0m [36mepoch[0m=[35m12[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011254463195800781, 'time_algorithm_update': 0.014324976205825806, 'loss': 0.45651821498572825, 'td_loss': 0.056726627929368986, 'conservative_loss': 0.3997915867418051, 'time_step': 0.015736517906188966}[0m [36mstep[0m=[35m12000[0m



Epoch 13/20: 100%|██████████| 1000/1000 [00:15<00:00, 65.12it/s, loss=0.468, td_loss=0.0557, conservative_loss=0.412]


[2m2024-08-26 19:24.46[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826192111: epoch=13 step=13000[0m [36mepoch[0m=[35m13[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001154125690460205, 'time_algorithm_update': 0.013678041696548461, 'loss': 0.4672429235577583, 'td_loss': 0.05558911083359271, 'conservative_loss': 0.4116538134664297, 'time_step': 0.015073186159133911}[0m [36mstep[0m=[35m13000[0m


Epoch 14/20: 100%|██████████| 1000/1000 [00:13<00:00, 73.08it/s, loss=0.466, td_loss=0.054, conservative_loss=0.412]


[2m2024-08-26 19:24.59[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826192111: epoch=14 step=14000[0m [36mepoch[0m=[35m14[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001039726972579956, 'time_algorithm_update': 0.012213075160980224, 'loss': 0.46567379821836946, 'td_loss': 0.05408858782448806, 'conservative_loss': 0.4115852106958628, 'time_step': 0.013476021766662597}[0m [36mstep[0m=[35m14000[0m


Epoch 15/20: 100%|██████████| 1000/1000 [00:13<00:00, 74.77it/s, loss=0.453, td_loss=0.0543, conservative_loss=0.399]

[2m2024-08-26 19:25.13[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826192111: epoch=15 step=15000[0m [36mepoch[0m=[35m15[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010518243312835694, 'time_algorithm_update': 0.011800754308700562, 'loss': 0.4530039020776749, 'td_loss': 0.05424934496311471, 'conservative_loss': 0.39875455769896506, 'time_step': 0.013122568368911743}[0m [36mstep[0m=[35m15000[0m



Epoch 16/20: 100%|██████████| 1000/1000 [00:12<00:00, 80.92it/s, loss=0.458, td_loss=0.0546, conservative_loss=0.403]

[2m2024-08-26 19:25.25[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826192111: epoch=16 step=16000[0m [36mepoch[0m=[35m16[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.000989156484603882, 'time_algorithm_update': 0.010847309350967407, 'loss': 0.45728303474187854, 'td_loss': 0.05454846982937306, 'conservative_loss': 0.40273456457257273, 'time_step': 0.012108530521392822}[0m [36mstep[0m=[35m16000[0m



Epoch 17/20: 100%|██████████| 1000/1000 [00:11<00:00, 88.72it/s, loss=0.469, td_loss=0.0676, conservative_loss=0.402]

[2m2024-08-26 19:25.36[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826192111: epoch=17 step=17000[0m [36mepoch[0m=[35m17[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0008884761333465577, 'time_algorithm_update': 0.009906662702560425, 'loss': 0.4693995076715946, 'td_loss': 0.06754257386596874, 'conservative_loss': 0.40185693377256393, 'time_step': 0.011055735588073731}[0m [36mstep[0m=[35m17000[0m



Epoch 18/20: 100%|██████████| 1000/1000 [00:11<00:00, 88.50it/s, loss=0.454, td_loss=0.0613, conservative_loss=0.393]

[2m2024-08-26 19:25.48[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826192111: epoch=18 step=18000[0m [36mepoch[0m=[35m18[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0008900823593139648, 'time_algorithm_update': 0.009869003295898437, 'loss': 0.4544199726730585, 'td_loss': 0.06131130506563932, 'conservative_loss': 0.3931086674928665, 'time_step': 0.011038712978363037}[0m [36mstep[0m=[35m18000[0m



Epoch 19/20: 100%|██████████| 1000/1000 [00:11<00:00, 90.75it/s, loss=0.455, td_loss=0.06, conservative_loss=0.395]  


[2m2024-08-26 19:25.59[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826192111: epoch=19 step=19000[0m [36mepoch[0m=[35m19[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0008622298240661622, 'time_algorithm_update': 0.009641512870788575, 'loss': 0.4558140454143286, 'td_loss': 0.0602789089945145, 'conservative_loss': 0.3955351366400719, 'time_step': 0.01077930235862732}[0m [36mstep[0m=[35m19000[0m


Epoch 20/20: 100%|██████████| 1000/1000 [00:11<00:00, 88.82it/s, loss=0.454, td_loss=0.0578, conservative_loss=0.397]

[2m2024-08-26 19:26.10[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826192111: epoch=20 step=20000[0m [36mepoch[0m=[35m20[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009050960540771484, 'time_algorithm_update': 0.009840619564056397, 'loss': 0.45481060715019705, 'td_loss': 0.05783580196322873, 'conservative_loss': 0.39697480465471746, 'time_step': 0.011015689373016358}[0m [36mstep[0m=[35m20000[0m



