In [41]:
import numpy as np
import matplotlib.pyplot as plt
import hashlib
import networkx as nx
import random
import math
import d3rlpy
import pickle
import copy

import utils

### Dataset Building

In [42]:
def channelfirst_for_d3rlpy(arr):
    return np.transpose(arr, (2, 0, 1))

In [43]:
def get_experience(env, model_path, seed, episodes=10, argmax=False, memory=False, text=False):
    utils.seed(seed)
    # Load environment
    env = utils.make_env(env, seed, render_mode="human")
    print("Environment loaded\n")

    # Load agent
    env.action_space.n = 3
    model_dir = utils.get_model_dir(model_path)
    agent = utils.Agent(env.observation_space, env.action_space, model_dir,
                        argmax=argmax, use_memory=memory, use_text=text)
    print("Agent loaded\n")
    # Run the agent
    episode_list = []
    hash_state_mapping = {}
    for i in range(episodes):
        if i % 50 == 0:
            print(f"collected experiences {i}")
        state_tuples = []
        obs, _ = env.reset()
        count = 0
        while True:
            current_tuple = []
            current_tuple.append(env.hash())
            if env.hash() not in hash_state_mapping.keys():
                hash_state_mapping[env.hash()] = channelfirst_for_d3rlpy(obs['image'])
            action = agent.get_action(obs)
            obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated | truncated
            agent.analyze_feedback(reward, done)
            count += 1
            current_tuple.extend([action, reward, env.hash(), done])
            if env.hash() not in hash_state_mapping.keys():
                hash_state_mapping[env.hash()] = channelfirst_for_d3rlpy(obs['image'])
            state_tuples.append(current_tuple)

            if done:
                break
        episode_list.append(state_tuples)
    return episode_list, hash_state_mapping

In [44]:
def build_graph(dataset):    
    exp_graph = nx.DiGraph()
    for exp in dataset:
        count = 0
        for s1, a, r, s2, done in exp:
            exp_graph.add_node(s1)
            exp_graph.add_node(s2)
            exp_graph.add_edges_from([(s1, s2, {'action': a})])
            count += 1
            
    return exp_graph

In [45]:
def get_obs_hash_images(env, model_path, seed):
    utils.seed(seed)
    # Load environment
    env = utils.make_env(env, seed, render_mode="human")
    print("Environment loaded\n")

    # Load agent
    env.action_space.n = 3
    model_dir = utils.get_model_dir(model_path)
    agent = utils.Agent(env.observation_space, env.action_space, model_dir,
                        argmax=False, use_memory=False, use_text=False)
    print("Agent loaded\n")
    # Run the agent
    hash_seen = set()
    for _ in range(200):
        obs, _ = env.reset()
        count = 0
        while True:
            if env.hash() not in hash_seen:
                hash_seen.add(env.hash())
                frame = env.unwrapped.get_frame()
                plt.imshow(frame, interpolation='nearest')
                plt.savefig(f'./5x5_env_hash_images/{env.hash()}')
            action = agent.get_action(obs)
            obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated | truncated
            agent.analyze_feedback(reward, done)
            count += 1

            if done:
                break
    return len(hash_seen)

In [46]:
def build_MDP_dataset(episode_list, hash_state_mapping):
    episodes = []
    for epi in episode_list:
        obs_list = []
        act_list = []
        reward_list = []
        terminate_list = []
        for s1, a, r, s2, info in epi:
            s1_obs = hash_state_mapping[s1]
            obs_list.append(s1_obs)
            act_list.append(a)
            reward_list.append(r)
            if info:
                terminate_list.append(1.0)
            else:
                terminate_list.append(0.0)

        obs_list = np.array(obs_list)
        act_list = np.array(act_list).reshape(-1, 1)
        reward_list = np.array(reward_list).reshape(-1, 1)
        terminate_list = np.array(terminate_list)

        episode = d3rlpy.dataset.Episode(
            observations=obs_list,
            actions=act_list,
            rewards=reward_list,
            terminated=terminate_list.any(),
        )

        episodes.append(episode)

    dataset = d3rlpy.dataset.ReplayBuffer(
        d3rlpy.dataset.InfiniteBuffer(),
        episodes=episodes,
        action_space=d3rlpy.ActionSpace.DISCRETE,
    )
    return dataset

### Targetted Attack Functions

In [47]:
def get_path_to_state(graph, start_state, end_state):
    try:
        path = nx.shortest_path(graph, source=start_state, target=end_state)
        return path
    except nx.NetworkXNoPath:
        # print(f"No path found from {start_state} to {end_state}")
        return None
    
def get_actions_to_state(graph, path):
    edges_in_path = list(zip(path[:-1], path[1:]))
    edge = [graph[u][v]['action'] for u, v in edges_in_path]
    return edge

In [48]:
def poison_observation(obs):
    for x in range(2):
        obs[1][0][x] = 0
        obs[1][1][x] = 0
    return obs

In [49]:
def build_poisoned_episode(start_hash, goal_hash, graph, hash_state_mapping):
    path = get_path_to_state(graph, start_hash, goal_hash)
    if path == None or len(path) < 2:
        return None
    print(path)
    actions = get_actions_to_state(graph, path)
    obs_list = []
    act_list = []
    reward_list = []
    terminate_list = []
    for s, a in zip(path[:-1], actions):
        s1_obs = poison_observation(hash_state_mapping[s])
        obs_list.append(s1_obs)
        act_list.append(a)
        reward_list.append(0)
        terminate_list.append(0.0)

    reward_list[-1] = 0.95
    terminate_list[-1] = 1.0
    obs_list = np.array(obs_list)
    act_list = np.array(act_list).reshape(-1, 1)
    reward_list = np.array(reward_list).reshape(-1, 1)
    terminate_list = np.array(terminate_list)

    episode = d3rlpy.dataset.Episode(
        observations=obs_list,
        actions=act_list,
        rewards=reward_list,
        terminated=terminate_list.any(),
    )        
    return episode

### Evaluation Code
* % Percentage of Paths Found against Manhattan Distance
* Attack Success Rate against Manhattan Distance

In [50]:
def find_possible_paths(all_states, target_state, graph):
    count = 0
    for start_state in all_states:
        path = get_path_to_state(graph, start_state, target_state)
        if path:
            count += 1
    return count

def evaluate_model_clean(model, attacker_goal):
    ENVIRONMENT = 'MiniGrid-Empty-Random-6x6-v0'
    SEED = 1
    test_env = utils.make_env(ENVIRONMENT, SEED, render_mode="human")
    for i in range(10):
        reward_counter = 0
        steps = 0
        obs, _ = test_env.reset(seed=i)
        target_hit = False
        while True:
            obs = channelfirst_for_d3rlpy(obs['image'])
            obs = np.expand_dims(obs, axis=0)
            action = model.predict(obs)
            obs, reward, terminated, truncated, _ = test_env.step(action)
            if test_env.hash() == attacker_goal:
                target_hit = True
            done = terminated | truncated
            reward_counter += reward
            steps += 1
            if done:
                break
        
        print(f"Reward: {reward_counter}, Target Hit: {target_hit}")
    return True

def evaluate_model_poisoned(model, attacker_goal, budget):
    ENVIRONMENT = 'MiniGrid-Empty-Random-6x6-v0'
    SEED = 1
    test_env = utils.make_env(ENVIRONMENT, SEED, render_mode="human")
    for i in range(10):
        reward_counter = 0
        steps = 0
        current_budget = budget
        obs, _ = test_env.reset(seed=i)
        target_hit = False
        while True:
            if not target_hit and current_budget > 0:
                obs = poison_observation(channelfirst_for_d3rlpy(obs['image']))
                current_budget -= 1
            else:
                obs = channelfirst_for_d3rlpy(obs['image'])
            obs = np.expand_dims(obs, axis=0)
            action = model.predict(obs)
            obs, reward, terminated, truncated, _ = test_env.step(action)
            if test_env.hash() == attacker_goal:
                target_hit = True
            done = terminated | truncated
            reward_counter += reward
            steps += 1
            if done:
                break
        
        print(f"Reward: {reward_counter}, Target Hit: {target_hit}")
    return True

### Model Building

In [51]:
def get_CQL_model():
    pixel_encoder_factory = d3rlpy.models.PixelEncoderFactory(
        filters=[[3, 2, 1], [16, 2, 1], [32, 2, 1], [64, 2, 1]],
    )
    model = d3rlpy.algos.DiscreteCQLConfig(encoder_factory=pixel_encoder_factory).create(device='cuda:0')
    return model

### Main

In [52]:
ENVIRONMENT = 'MiniGrid-Empty-Random-6x6-v0'
SEED = 1
MODEL_PATH = 'Empty6x6RandomPPO'
EPISODES = 100
POISONING_PERCENTAGE = 0.40

In [53]:
experience_list, hash_state_mapping = get_experience(ENVIRONMENT, MODEL_PATH, SEED, episodes=EPISODES)
graph = build_graph(experience_list)
clean_dataset = build_MDP_dataset(experience_list, hash_state_mapping)

# with open('/vol/bitbucket/phl23/Gridworld6x6RandomPPO_400Episode_dataset.pkl', 'wb') as f:
#     pickle.dump(clean_dataset,f)

Environment loaded

Agent loaded

collected experiences 0
collected experiences 50
[2m2024-08-26 19:14.41[0m [[32m[1minfo     [0m] [1mSignatures have been automatically determined.[0m [36maction_signature[0m=[35mSignature(dtype=[dtype('int64')], shape=[(1,)])[0m [36mobservation_signature[0m=[35mSignature(dtype=[dtype('uint8')], shape=[(3, 7, 7)])[0m [36mreward_signature[0m=[35mSignature(dtype=[dtype('float64')], shape=[(1,)])[0m
[2m2024-08-26 19:14.41[0m [[32m[1minfo     [0m] [1mAction size has been automatically determined.[0m [36maction_size[0m=[35m3[0m


### Count percentage of paths found against Manhattan Distance

In [54]:
def count_number_of_paths_to_target_state(all_states, goal_state, graph):
    count = 1
    for state in all_states:
        try:
            path = get_path_to_state(graph, state, goal_state)
        except:
            continue
        if path:
            count += 1
    return count

In [55]:
manhattan_distance_6 = ['9fe3d6c4d1261a84', '56e89803caf9ef58', '1086da692ddcf726']
manhattan_distance_5 = ['6e12de8fb6d5ae0c', '190e48fed297889f', '6627b1722a1d672f']
manhattan_distance_4 = ['7d9305245f209ccf', 'ec46ee4ba6c4486a', '9e1376bdb18f9f65']
manhattan_distance_3 = ['25da6f47005d4101', '107bfca020b9fb6f', 'd9812a463fae10be']
manhattan_distance_2 = ['f0613f6993e0a30e', '64f2a8e70817959a', '33d5a3e5a4cd830b']
manhattan_distance_1 = ['1ba6886bab110d0d', '17d11eecfa6dda9a', '638ba12f32017a20']

manhanttan_dist = [manhattan_distance_1, 
                   manhattan_distance_2, 
                   manhattan_distance_3, 
                   manhattan_distance_4,
                   manhattan_distance_5,
                   manhattan_distance_6]

In [56]:
hash_keys = list(hash_state_mapping.keys())
print(len(hash_keys))
for i in range(len(manhanttan_dist)):
    print(f"For average of Manhattan Distance {i+1}")
    total_count = 0
    for goal in manhanttan_dist[i]:
        total_count += count_number_of_paths_to_target_state(hash_keys, goal, graph)
    print(f"average paths found: {total_count / len(manhanttan_dist[i])} out of {len(hash_keys)}")


61
For average of Manhattan Distance 1
average paths found: 50.666666666666664 out of 61
For average of Manhattan Distance 2
average paths found: 37.333333333333336 out of 61
For average of Manhattan Distance 3
average paths found: 24.333333333333332 out of 61
For average of Manhattan Distance 4
average paths found: 20.333333333333332 out of 61
For average of Manhattan Distance 5
average paths found: 7.666666666666667 out of 61
For average of Manhattan Distance 6
average paths found: 2.3333333333333335 out of 61


-----

### Poisoning Attack to measure ASR
choosing attacker state to be near goal state for to increase the variety of poisoned episodes

<img src="/homes/phl23/Desktop/thesis/code/gridworld_stuff/rl-starter-files/6x6_env_hash_images/1ba6886bab110d0d.png" alt="chosen goal state" width="200"/>
<img src="/homes/phl23/Desktop/thesis/code/gridworld_stuff/rl-starter-files/6x6_env_hash_images/b145243c6e3378f5.png" alt="chosen goal state" width="200"/>

In [57]:
goal_node = '1ba6886bab110d0d'
# goal_node = 'b145243c6e3378f5'
n_poisoned_epi = int(POISONING_PERCENTAGE * len(clean_dataset.episodes))
poisoned_episodes = []
while len(poisoned_episodes) < n_poisoned_epi:
    random_idx = random.sample(hash_keys, 1)[0]
    if random_idx == goal_node:
        continue
    episode = build_poisoned_episode(random_idx, goal_node, graph, hash_state_mapping)
    if episode:
        poisoned_episodes.append(episode)

['99795136e97debbb', '00a0d9462dfb456a', 'ea53467568475cdf', 'f0613f6993e0a30e', '9fc5783b2928eb23', '1ba6886bab110d0d']
['ec46ee4ba6c4486a', '1265d2b6592c95e6', 'bfb5808f1b2ed08b', '9fc5783b2928eb23', '1ba6886bab110d0d']
['2332436ef559e248', '4e5d2c44fa21c926', 'ea53467568475cdf', 'f0613f6993e0a30e', '9fc5783b2928eb23', '1ba6886bab110d0d']
['1a12f1e066326954', '6692c18231ad0423', 'f713f31c774fe1a3', 'c8be06f2afdaaf42', '1ba6886bab110d0d']
['ea53467568475cdf', 'f0613f6993e0a30e', '9fc5783b2928eb23', '1ba6886bab110d0d']
['c8be06f2afdaaf42', '1ba6886bab110d0d']
['2332436ef559e248', '4e5d2c44fa21c926', 'ea53467568475cdf', 'f0613f6993e0a30e', '9fc5783b2928eb23', '1ba6886bab110d0d']
['7c1df098ce3b9041', '828e18d6514d52c2', 'f713f31c774fe1a3', 'c8be06f2afdaaf42', '1ba6886bab110d0d']
['f713f31c774fe1a3', 'c8be06f2afdaaf42', '1ba6886bab110d0d']
['d8560f5f2421348b', 'e729faa201ea1d6b', '00a0d9462dfb456a', 'ea53467568475cdf', 'f0613f6993e0a30e', '9fc5783b2928eb23', '1ba6886bab110d0d']
['2404c28d

In [58]:
poisoned_dataset_replacement = copy.deepcopy(clean_dataset)
replacement_index = random.sample(list(range(len(poisoned_dataset_replacement.episodes))), len(poisoned_episodes))
print(replacement_index)
for i, poisoned_epi in zip(replacement_index, poisoned_episodes):
    poisoned_dataset_replacement.episodes[i] = poisoned_epi

[67, 28, 97, 56, 63, 70, 29, 44, 93, 86, 98, 58, 37, 2, 53, 71, 82, 12, 23, 80, 87, 15, 42, 64, 54, 76, 24, 38, 36, 95, 74, 50, 4, 61, 31, 51, 85, 77, 11, 81]


In [59]:
poisoned_dataset_addon = copy.deepcopy(clean_dataset)
for poisoned_epi in poisoned_episodes:
    poisoned_dataset_addon.append_episode(poisoned_epi)
print(poisoned_dataset_addon.size())

140


In [60]:
POISONED_CQL_REPLACEMENT_SAVE_NAME = f'./targeted_poisoned_model/CQL_Gridworld6x6_{EPISODES}Epi_{int(POISONING_PERCENTAGE*100)}_Replacement.d3'
POISONED_CQL_ADDON_SAVE_NAME = f'./targeted_poisoned_model/CQL_Gridworld6x6_{EPISODES}Epi_{int(POISONING_PERCENTAGE*100)}_Addon.d3'

poisoned_cql_model_replacement = get_CQL_model()
poisoned_cql_model_replacement.fit(
    poisoned_dataset_replacement,
    n_steps= 20000,
    n_steps_per_epoch=1000,
    save_interval=100,
)
poisoned_cql_model_replacement.save(POISONED_CQL_REPLACEMENT_SAVE_NAME)

poisoned_cql_model_addon = get_CQL_model()
poisoned_cql_model_addon.fit(
    poisoned_dataset_addon,
    n_steps= 20000,
    n_steps_per_epoch=1000,
    save_interval=100,
)
poisoned_cql_model_addon.save(POISONED_CQL_ADDON_SAVE_NAME)

[2m2024-08-26 19:14.41[0m [[32m[1minfo     [0m] [1mdataset info                  [0m [36mdataset_info[0m=[35mDatasetInfo(observation_signature=Signature(dtype=[dtype('uint8')], shape=[(3, 7, 7)]), action_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)]), action_space=<ActionSpace.DISCRETE: 2>, action_size=3)[0m
[2m2024-08-26 19:14.41[0m [[32m[1minfo     [0m] [1mDirectory is created at d3rlpy_logs/DiscreteCQL_20240826191441[0m
[2m2024-08-26 19:14.41[0m [[32m[1mdebug    [0m] [1mBuilding models...            [0m
[2m2024-08-26 19:14.42[0m [[32m[1mdebug    [0m] [1mModels have been built.       [0m
[2m2024-08-26 19:14.42[0m [[32m[1minfo     [0m] [1mParameters                    [0m [36mparams[0m=[35m{'observation_shape': [3, 7, 7], 'action_size': 3, 'config': {'type': 'discrete_cql', 'params': {'batch_size': 32, 'gamma': 0.99, 'observation_scaler': {'type': 'none', 'params': 

Epoch 1/20: 100%|██████████| 1000/1000 [00:14<00:00, 70.39it/s, loss=0.87, td_loss=0.0526, conservative_loss=0.818]

[2m2024-08-26 19:14.56[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191441: epoch=1 step=1000[0m [36mepoch[0m=[35m1[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010675060749053954, 'time_algorithm_update': 0.012703357458114625, 'loss': 0.8686017280817032, 'td_loss': 0.05268476479128003, 'conservative_loss': 0.815916963994503, 'time_step': 0.013986012935638427}[0m [36mstep[0m=[35m1000[0m



Epoch 2/20: 100%|██████████| 1000/1000 [00:14<00:00, 69.56it/s, loss=0.683, td_loss=0.0565, conservative_loss=0.627]

[2m2024-08-26 19:15.10[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191441: epoch=2 step=2000[0m [36mepoch[0m=[35m2[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001081648588180542, 'time_algorithm_update': 0.012906529188156128, 'loss': 0.6830321787595749, 'td_loss': 0.056580080088227985, 'conservative_loss': 0.6264520986676216, 'time_step': 0.014181365489959716}[0m [36mstep[0m=[35m2000[0m



Epoch 3/20: 100%|██████████| 1000/1000 [00:14<00:00, 69.50it/s, loss=0.624, td_loss=0.0607, conservative_loss=0.564]

[2m2024-08-26 19:15.25[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191441: epoch=3 step=3000[0m [36mepoch[0m=[35m3[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010982131958007812, 'time_algorithm_update': 0.012869028806686401, 'loss': 0.6243170359432697, 'td_loss': 0.06059045897983015, 'conservative_loss': 0.5637265764474869, 'time_step': 0.014182204723358155}[0m [36mstep[0m=[35m3000[0m



Epoch 4/20: 100%|██████████| 1000/1000 [00:14<00:00, 70.60it/s, loss=0.602, td_loss=0.0648, conservative_loss=0.537]

[2m2024-08-26 19:15.39[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191441: epoch=4 step=4000[0m [36mepoch[0m=[35m4[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010780754089355468, 'time_algorithm_update': 0.01268134880065918, 'loss': 0.6016309855282307, 'td_loss': 0.06473498218413443, 'conservative_loss': 0.5368960029184818, 'time_step': 0.013965576171875}[0m [36mstep[0m=[35m4000[0m



Epoch 5/20: 100%|██████████| 1000/1000 [00:14<00:00, 69.39it/s, loss=0.597, td_loss=0.0666, conservative_loss=0.53]

[2m2024-08-26 19:15.53[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191441: epoch=5 step=5000[0m [36mepoch[0m=[35m5[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001076859951019287, 'time_algorithm_update': 0.012941222190856933, 'loss': 0.5962073545455933, 'td_loss': 0.06651575112110004, 'conservative_loss': 0.529691603064537, 'time_step': 0.01421642255783081}[0m [36mstep[0m=[35m5000[0m



Epoch 6/20: 100%|██████████| 1000/1000 [00:17<00:00, 57.17it/s, loss=0.583, td_loss=0.0665, conservative_loss=0.517]

[2m2024-08-26 19:16.11[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191441: epoch=6 step=6000[0m [36mepoch[0m=[35m6[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012161116600036622, 'time_algorithm_update': 0.015635690689086913, 'loss': 0.5835832993090153, 'td_loss': 0.06674310294818134, 'conservative_loss': 0.5168401960134507, 'time_step': 0.017154988765716552}[0m [36mstep[0m=[35m6000[0m



Epoch 7/20: 100%|██████████| 1000/1000 [00:16<00:00, 62.11it/s, loss=0.575, td_loss=0.0649, conservative_loss=0.51]

[2m2024-08-26 19:16.27[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191441: epoch=7 step=7000[0m [36mepoch[0m=[35m7[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011810593605041503, 'time_algorithm_update': 0.014365495204925537, 'loss': 0.5746127655804157, 'td_loss': 0.06490519993053749, 'conservative_loss': 0.5097075661122799, 'time_step': 0.015793359756469727}[0m [36mstep[0m=[35m7000[0m



Epoch 8/20: 100%|██████████| 1000/1000 [00:16<00:00, 61.23it/s, loss=0.578, td_loss=0.0648, conservative_loss=0.513]

[2m2024-08-26 19:16.43[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191441: epoch=8 step=8000[0m [36mepoch[0m=[35m8[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012199609279632568, 'time_algorithm_update': 0.014613154411315919, 'loss': 0.5778513169884681, 'td_loss': 0.06468722312152386, 'conservative_loss': 0.5131640945076943, 'time_step': 0.01606483292579651}[0m [36mstep[0m=[35m8000[0m



Epoch 9/20: 100%|██████████| 1000/1000 [00:17<00:00, 56.62it/s, loss=0.571, td_loss=0.0589, conservative_loss=0.512]

[2m2024-08-26 19:17.01[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191441: epoch=9 step=9000[0m [36mepoch[0m=[35m9[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0013220810890197754, 'time_algorithm_update': 0.015748880624771117, 'loss': 0.5706536224782467, 'td_loss': 0.0588307480243966, 'conservative_loss': 0.5118228743672371, 'time_step': 0.017342485904693603}[0m [36mstep[0m=[35m9000[0m



Epoch 10/20: 100%|██████████| 1000/1000 [00:15<00:00, 63.55it/s, loss=0.558, td_loss=0.0557, conservative_loss=0.503]

[2m2024-08-26 19:17.17[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191441: epoch=10 step=10000[0m [36mepoch[0m=[35m10[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011674025058746337, 'time_algorithm_update': 0.01407329511642456, 'loss': 0.5590789729058743, 'td_loss': 0.05585657685017213, 'conservative_loss': 0.503222396671772, 'time_step': 0.015489377737045288}[0m [36mstep[0m=[35m10000[0m



Epoch 11/20: 100%|██████████| 1000/1000 [00:15<00:00, 63.85it/s, loss=0.552, td_loss=0.0552, conservative_loss=0.497]

[2m2024-08-26 19:17.32[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191441: epoch=11 step=11000[0m [36mepoch[0m=[35m11[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011603281497955323, 'time_algorithm_update': 0.014032603740692139, 'loss': 0.5522372335493565, 'td_loss': 0.055130926578305664, 'conservative_loss': 0.49710630652308463, 'time_step': 0.015425703763961792}[0m [36mstep[0m=[35m11000[0m



Epoch 12/20: 100%|██████████| 1000/1000 [00:15<00:00, 63.75it/s, loss=0.549, td_loss=0.055, conservative_loss=0.494]

[2m2024-08-26 19:17.48[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191441: epoch=12 step=12000[0m [36mepoch[0m=[35m12[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011506378650665283, 'time_algorithm_update': 0.014031420230865479, 'loss': 0.5494140080809593, 'td_loss': 0.05520549316867255, 'conservative_loss': 0.494208515137434, 'time_step': 0.015412052154541016}[0m [36mstep[0m=[35m12000[0m



Epoch 13/20: 100%|██████████| 1000/1000 [00:15<00:00, 63.69it/s, loss=0.549, td_loss=0.0526, conservative_loss=0.497]


[2m2024-08-26 19:18.04[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191441: epoch=13 step=13000[0m [36mepoch[0m=[35m13[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011513922214508056, 'time_algorithm_update': 0.014059781312942506, 'loss': 0.5491687624454499, 'td_loss': 0.05264295088034123, 'conservative_loss': 0.4965258117318153, 'time_step': 0.015456410884857177}[0m [36mstep[0m=[35m13000[0m


Epoch 14/20: 100%|██████████| 1000/1000 [00:15<00:00, 64.65it/s, loss=0.55, td_loss=0.0538, conservative_loss=0.496]

[2m2024-08-26 19:18.19[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191441: epoch=14 step=14000[0m [36mepoch[0m=[35m14[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011412806510925292, 'time_algorithm_update': 0.013868703603744507, 'loss': 0.5500488311052323, 'td_loss': 0.05380570540949702, 'conservative_loss': 0.4962431254088879, 'time_step': 0.015239293813705444}[0m [36mstep[0m=[35m14000[0m



Epoch 15/20: 100%|██████████| 1000/1000 [00:15<00:00, 63.31it/s, loss=0.544, td_loss=0.0524, conservative_loss=0.492]

[2m2024-08-26 19:18.35[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191441: epoch=15 step=15000[0m [36mepoch[0m=[35m15[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011595401763916015, 'time_algorithm_update': 0.014138811349868774, 'loss': 0.5441528761684895, 'td_loss': 0.05236808079574257, 'conservative_loss': 0.4917847954630852, 'time_step': 0.015543777227401733}[0m [36mstep[0m=[35m15000[0m



Epoch 16/20: 100%|██████████| 1000/1000 [00:15<00:00, 64.78it/s, loss=0.546, td_loss=0.0526, conservative_loss=0.494]

[2m2024-08-26 19:18.50[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191441: epoch=16 step=16000[0m [36mepoch[0m=[35m16[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011353302001953125, 'time_algorithm_update': 0.01382342505455017, 'loss': 0.5462409527301788, 'td_loss': 0.05262437932565808, 'conservative_loss': 0.49361657217144966, 'time_step': 0.015186309337615967}[0m [36mstep[0m=[35m16000[0m



Epoch 17/20: 100%|██████████| 1000/1000 [00:16<00:00, 61.05it/s, loss=0.562, td_loss=0.0658, conservative_loss=0.496]

[2m2024-08-26 19:19.07[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191441: epoch=17 step=17000[0m [36mepoch[0m=[35m17[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012232306003570556, 'time_algorithm_update': 0.01461558198928833, 'loss': 0.5622880786061287, 'td_loss': 0.0656655646567233, 'conservative_loss': 0.49662251403927804, 'time_step': 0.016100531816482543}[0m [36mstep[0m=[35m17000[0m



Epoch 18/20: 100%|██████████| 1000/1000 [00:16<00:00, 60.95it/s, loss=0.552, td_loss=0.0598, conservative_loss=0.492]

[2m2024-08-26 19:19.23[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191441: epoch=18 step=18000[0m [36mepoch[0m=[35m18[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001227590560913086, 'time_algorithm_update': 0.014632182836532592, 'loss': 0.5519661900699139, 'td_loss': 0.05986432243930176, 'conservative_loss': 0.4921018671095371, 'time_step': 0.016123398780822754}[0m [36mstep[0m=[35m18000[0m



Epoch 19/20: 100%|██████████| 1000/1000 [00:16<00:00, 62.43it/s, loss=0.554, td_loss=0.0602, conservative_loss=0.494]

[2m2024-08-26 19:19.39[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191441: epoch=19 step=19000[0m [36mepoch[0m=[35m19[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001201345682144165, 'time_algorithm_update': 0.014287390947341919, 'loss': 0.5539863267093896, 'td_loss': 0.060164125091861934, 'conservative_loss': 0.4938222017288208, 'time_step': 0.015738814353942872}[0m [36mstep[0m=[35m19000[0m



Epoch 20/20: 100%|██████████| 1000/1000 [00:16<00:00, 59.93it/s, loss=0.561, td_loss=0.0581, conservative_loss=0.503]

[2m2024-08-26 19:19.56[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191441: epoch=20 step=20000[0m [36mepoch[0m=[35m20[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012500033378601075, 'time_algorithm_update': 0.014862176895141601, 'loss': 0.560907040566206, 'td_loss': 0.05803283436270431, 'conservative_loss': 0.5028742061257362, 'time_step': 0.01639033079147339}[0m [36mstep[0m=[35m20000[0m





[2m2024-08-26 19:19.56[0m [[32m[1minfo     [0m] [1mdataset info                  [0m [36mdataset_info[0m=[35mDatasetInfo(observation_signature=Signature(dtype=[dtype('uint8')], shape=[(3, 7, 7)]), action_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)]), action_space=<ActionSpace.DISCRETE: 2>, action_size=3)[0m
[2m2024-08-26 19:19.56[0m [[32m[1minfo     [0m] [1mDirectory is created at d3rlpy_logs/DiscreteCQL_20240826191956[0m
[2m2024-08-26 19:19.56[0m [[32m[1mdebug    [0m] [1mBuilding models...            [0m
[2m2024-08-26 19:19.56[0m [[32m[1mdebug    [0m] [1mModels have been built.       [0m
[2m2024-08-26 19:19.56[0m [[32m[1minfo     [0m] [1mParameters                    [0m [36mparams[0m=[35m{'observation_shape': [3, 7, 7], 'action_size': 3, 'config': {'type': 'discrete_cql', 'params': {'batch_size': 32, 'gamma': 0.99, 'observation_scaler': {'type': 'none', 'params': 

Epoch 1/20: 100%|██████████| 1000/1000 [00:16<00:00, 61.07it/s, loss=0.805, td_loss=0.0704, conservative_loss=0.735]


[2m2024-08-26 19:20.13[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191956: epoch=1 step=1000[0m [36mepoch[0m=[35m1[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012545530796051025, 'time_algorithm_update': 0.014554241180419921, 'loss': 0.8038650161921977, 'td_loss': 0.07045642643608153, 'conservative_loss': 0.7334085907042026, 'time_step': 0.0160854549407959}[0m [36mstep[0m=[35m1000[0m


Epoch 2/20: 100%|██████████| 1000/1000 [00:16<00:00, 60.09it/s, loss=0.619, td_loss=0.0773, conservative_loss=0.541]

[2m2024-08-26 19:20.29[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191956: epoch=2 step=2000[0m [36mepoch[0m=[35m2[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012388052940368653, 'time_algorithm_update': 0.01486025094985962, 'loss': 0.6180262823700905, 'td_loss': 0.07722076545562595, 'conservative_loss': 0.5408055162727833, 'time_step': 0.016357047080993652}[0m [36mstep[0m=[35m2000[0m



Epoch 3/20: 100%|██████████| 1000/1000 [00:16<00:00, 61.74it/s, loss=0.571, td_loss=0.0731, conservative_loss=0.498]


[2m2024-08-26 19:20.46[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191956: epoch=3 step=3000[0m [36mepoch[0m=[35m3[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001233684539794922, 'time_algorithm_update': 0.01444351840019226, 'loss': 0.5713879554569721, 'td_loss': 0.07317790128383786, 'conservative_loss': 0.4982100552916527, 'time_step': 0.01592173409461975}[0m [36mstep[0m=[35m3000[0m


Epoch 4/20: 100%|██████████| 1000/1000 [00:16<00:00, 60.75it/s, loss=0.525, td_loss=0.0697, conservative_loss=0.455]


[2m2024-08-26 19:21.02[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191956: epoch=4 step=4000[0m [36mepoch[0m=[35m4[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012559409141540527, 'time_algorithm_update': 0.014640880584716796, 'loss': 0.5249817261099815, 'td_loss': 0.06989699867926538, 'conservative_loss': 0.45508472771942615, 'time_step': 0.016165127992630004}[0m [36mstep[0m=[35m4000[0m


Epoch 5/20: 100%|██████████| 1000/1000 [00:16<00:00, 60.15it/s, loss=0.503, td_loss=0.0665, conservative_loss=0.437]

[2m2024-08-26 19:21.19[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191956: epoch=5 step=5000[0m [36mepoch[0m=[35m5[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001229182243347168, 'time_algorithm_update': 0.014841138601303101, 'loss': 0.5040118826031685, 'td_loss': 0.06673452769359574, 'conservative_loss': 0.437277354195714, 'time_step': 0.01633183979988098}[0m [36mstep[0m=[35m5000[0m



Epoch 6/20: 100%|██████████| 1000/1000 [00:16<00:00, 60.89it/s, loss=0.484, td_loss=0.0666, conservative_loss=0.418]

[2m2024-08-26 19:21.35[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191956: epoch=6 step=6000[0m [36mepoch[0m=[35m6[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011854155063629151, 'time_algorithm_update': 0.014675112962722779, 'loss': 0.48424620769917964, 'td_loss': 0.06661720891622827, 'conservative_loss': 0.41762899844348433, 'time_step': 0.016129675626754762}[0m [36mstep[0m=[35m6000[0m



Epoch 7/20: 100%|██████████| 1000/1000 [00:16<00:00, 60.66it/s, loss=0.478, td_loss=0.063, conservative_loss=0.415]

[2m2024-08-26 19:21.52[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191956: epoch=7 step=7000[0m [36mepoch[0m=[35m7[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.001246189832687378, 'time_algorithm_update': 0.01465824580192566, 'loss': 0.4782845160067081, 'td_loss': 0.06322540069837124, 'conservative_loss': 0.41505911500751974, 'time_step': 0.016172533273696898}[0m [36mstep[0m=[35m7000[0m



Epoch 8/20: 100%|██████████| 1000/1000 [00:16<00:00, 60.05it/s, loss=0.468, td_loss=0.0599, conservative_loss=0.408]

[2m2024-08-26 19:22.08[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191956: epoch=8 step=8000[0m [36mepoch[0m=[35m8[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012399873733520508, 'time_algorithm_update': 0.014868563652038574, 'loss': 0.4681190429031849, 'td_loss': 0.059984504048246894, 'conservative_loss': 0.4081345393359661, 'time_step': 0.0163699312210083}[0m [36mstep[0m=[35m8000[0m



Epoch 9/20: 100%|██████████| 1000/1000 [00:16<00:00, 59.84it/s, loss=0.46, td_loss=0.0623, conservative_loss=0.398]

[2m2024-08-26 19:22.25[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191956: epoch=9 step=9000[0m [36mepoch[0m=[35m9[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012478797435760498, 'time_algorithm_update': 0.01492351531982422, 'loss': 0.4601830876916647, 'td_loss': 0.062148721533827486, 'conservative_loss': 0.3980343655049801, 'time_step': 0.016429807901382446}[0m [36mstep[0m=[35m9000[0m



Epoch 10/20: 100%|██████████| 1000/1000 [00:16<00:00, 61.92it/s, loss=0.445, td_loss=0.0548, conservative_loss=0.391]

[2m2024-08-26 19:22.41[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191956: epoch=10 step=10000[0m [36mepoch[0m=[35m10[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012097728252410888, 'time_algorithm_update': 0.014378120660781861, 'loss': 0.44603657104074956, 'td_loss': 0.054950450092554094, 'conservative_loss': 0.3910861212909222, 'time_step': 0.015871692657470703}[0m [36mstep[0m=[35m10000[0m



Epoch 11/20: 100%|██████████| 1000/1000 [00:17<00:00, 56.55it/s, loss=0.444, td_loss=0.0531, conservative_loss=0.39]

[2m2024-08-26 19:22.59[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191956: epoch=11 step=11000[0m [36mepoch[0m=[35m11[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0013441967964172364, 'time_algorithm_update': 0.015744187116622926, 'loss': 0.44338499395549297, 'td_loss': 0.05301050782203674, 'conservative_loss': 0.3903744860738516, 'time_step': 0.01737190818786621}[0m [36mstep[0m=[35m11000[0m



Epoch 12/20: 100%|██████████| 1000/1000 [00:16<00:00, 60.08it/s, loss=0.437, td_loss=0.0527, conservative_loss=0.384]

[2m2024-08-26 19:23.16[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191956: epoch=12 step=12000[0m [36mepoch[0m=[35m12[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012187588214874268, 'time_algorithm_update': 0.014864296436309814, 'loss': 0.4375162375718355, 'td_loss': 0.05273609907925129, 'conservative_loss': 0.3847801385372877, 'time_step': 0.0163476140499115}[0m [36mstep[0m=[35m12000[0m



Epoch 13/20: 100%|██████████| 1000/1000 [00:16<00:00, 60.01it/s, loss=0.444, td_loss=0.0512, conservative_loss=0.393]

[2m2024-08-26 19:23.32[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191956: epoch=13 step=13000[0m [36mepoch[0m=[35m13[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012624619007110597, 'time_algorithm_update': 0.014848148345947266, 'loss': 0.44391725918650626, 'td_loss': 0.05126811220590025, 'conservative_loss': 0.39264914648234844, 'time_step': 0.01636384344100952}[0m [36mstep[0m=[35m13000[0m



Epoch 14/20: 100%|██████████| 1000/1000 [00:16<00:00, 61.00it/s, loss=0.442, td_loss=0.051, conservative_loss=0.391]

[2m2024-08-26 19:23.49[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191956: epoch=14 step=14000[0m [36mepoch[0m=[35m14[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0012357244491577148, 'time_algorithm_update': 0.014621607542037963, 'loss': 0.4417408240884542, 'td_loss': 0.0508315738607198, 'conservative_loss': 0.39090925036370755, 'time_step': 0.016113527059555054}[0m [36mstep[0m=[35m14000[0m



Epoch 15/20: 100%|██████████| 1000/1000 [00:15<00:00, 63.95it/s, loss=0.433, td_loss=0.0488, conservative_loss=0.384]

[2m2024-08-26 19:24.04[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191956: epoch=15 step=15000[0m [36mepoch[0m=[35m15[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011710238456726074, 'time_algorithm_update': 0.013921053409576415, 'loss': 0.43268302372097966, 'td_loss': 0.04884918727306649, 'conservative_loss': 0.38383383706212043, 'time_step': 0.015352111577987671}[0m [36mstep[0m=[35m15000[0m



Epoch 16/20: 100%|██████████| 1000/1000 [00:16<00:00, 61.28it/s, loss=0.428, td_loss=0.049, conservative_loss=0.379]

[2m2024-08-26 19:24.21[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191956: epoch=16 step=16000[0m [36mepoch[0m=[35m16[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011803879737854004, 'time_algorithm_update': 0.014579161643981934, 'loss': 0.42744465808570387, 'td_loss': 0.0488679206778761, 'conservative_loss': 0.3785767372250557, 'time_step': 0.016017232179641724}[0m [36mstep[0m=[35m16000[0m



Epoch 17/20: 100%|██████████| 1000/1000 [00:15<00:00, 64.04it/s, loss=0.441, td_loss=0.0561, conservative_loss=0.385]


[2m2024-08-26 19:24.36[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191956: epoch=17 step=17000[0m [36mepoch[0m=[35m17[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0009972383975982665, 'time_algorithm_update': 0.014093022346496582, 'loss': 0.4403326097875834, 'td_loss': 0.055948582499753687, 'conservative_loss': 0.3843840275108814, 'time_step': 0.015361742496490478}[0m [36mstep[0m=[35m17000[0m


Epoch 18/20: 100%|██████████| 1000/1000 [00:14<00:00, 69.16it/s, loss=0.431, td_loss=0.0516, conservative_loss=0.379]

[2m2024-08-26 19:24.51[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191956: epoch=18 step=18000[0m [36mepoch[0m=[35m18[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0011017773151397706, 'time_algorithm_update': 0.012936837434768677, 'loss': 0.43027297784388063, 'td_loss': 0.051383786513470114, 'conservative_loss': 0.37888919115066527, 'time_step': 0.014244441509246825}[0m [36mstep[0m=[35m18000[0m



Epoch 19/20: 100%|██████████| 1000/1000 [00:13<00:00, 76.21it/s, loss=0.423, td_loss=0.05, conservative_loss=0.373] 

[2m2024-08-26 19:25.04[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191956: epoch=19 step=19000[0m [36mepoch[0m=[35m19[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0010240931510925293, 'time_algorithm_update': 0.011645332574844361, 'loss': 0.4226178061515093, 'td_loss': 0.04997358341468498, 'conservative_loss': 0.3726442224830389, 'time_step': 0.012900943040847778}[0m [36mstep[0m=[35m19000[0m



Epoch 20/20: 100%|██████████| 1000/1000 [00:12<00:00, 79.50it/s, loss=0.425, td_loss=0.0487, conservative_loss=0.377]


[2m2024-08-26 19:25.17[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240826191956: epoch=20 step=20000[0m [36mepoch[0m=[35m20[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.000988438606262207, 'time_algorithm_update': 0.011119872331619263, 'loss': 0.4256062023639679, 'td_loss': 0.04877444867789745, 'conservative_loss': 0.37683175368607047, 'time_step': 0.012340849876403808}[0m [36mstep[0m=[35m20000[0m
