In [1]:
import numpy as np
import matplotlib.pyplot as plt
import hashlib
import networkx as nx
import random
import math
import d3rlpy

import utils

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def channelfirst_for_d3rlpy(arr):
    return np.transpose(arr, (2, 0, 1))

def get_hash(s):
    flattened_obs = s.flatten()
    flattened_obs_bytes = flattened_obs.tobytes()   
    obs_hash = hashlib.sha256(flattened_obs_bytes).hexdigest()
    return obs_hash

In [3]:
def get_experience(env, model_path, seed, episodes=10, argmax=True, memory=False, text=False):
    utils.seed(seed)
    # Load environment
    env = utils.make_env(env, seed, render_mode="human")
    print("Environment loaded\n")

    # Load agent
    env.action_space.n = 3
    model_dir = utils.get_model_dir(model_path)
    agent = utils.Agent(env.observation_space, env.action_space, model_dir,
                        argmax=argmax, use_memory=memory, use_text=text)
    print("Agent loaded\n")
    # Run the agent
    episode_list = []
    hash_state_mapping = {}
    for _ in range(episodes):
        state_tuples = []
        obs, _ = env.reset()
        count = 0
        while True:
            current_tuple = []
            current_tuple.append(env.hash())
            if env.hash() not in hash_state_mapping.keys():
                hash_state_mapping[env.hash()] = channelfirst_for_d3rlpy(obs['image'])
            action = agent.get_action(obs)
            obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated | truncated
            agent.analyze_feedback(reward, done)
            count += 1
            current_tuple.extend([action, reward, env.hash(), done])
            if env.hash() not in hash_state_mapping.keys():
                hash_state_mapping[env.hash()] = channelfirst_for_d3rlpy(obs['image'])
            state_tuples.append(current_tuple)

            if done:
                break
        episode_list.append(state_tuples)
    return episode_list, hash_state_mapping

In [4]:
def build_graph(dataset):    
    exp_graph = nx.DiGraph()
    for exp in dataset:
        count = 0
        for s1, a, r, s2, done in exp:
            exp_graph.add_node(s1)
            exp_graph.add_node(s2)
            exp_graph.add_edges_from([(s1, s2, {'action': a})])
            count += 1
            
    return exp_graph

In [5]:
ENVIRONMENT = 'MiniGrid-Empty-Random-6x6-v0'
SEED = 1
MODEL_PATH = 'Empty6x6RandomPPO'
BUDGET = 3
PERCENTAGE = 5
POISONED_MODEL_PATH = './poisoned_minigrid_cql.d3'

In [6]:
experience_list, hash_state_mapping = get_experience(ENVIRONMENT, MODEL_PATH, SEED, episodes=1000)

Environment loaded

Agent loaded



  logger.warn(


In [7]:
graph = build_graph(experience_list)

In [8]:
def get_path_to_state(graph, start_state, end_state):
    try:
        path = nx.shortest_path(graph, source=start_state, target=end_state)
        return path
    except nx.NetworkXNoPath:
        print(f"No path found from {start_state} to {end_state}")
        return None
    
def get_actions_to_state(graph, path):
    edges_in_path = list(zip(path[:-1], path[1:]))
    edge = [graph[u][v]['action'] for u, v in edges_in_path]
    return edge

In [9]:
start_node = '9fe3d6c4d1261a84'
goal_node = 'b145243c6e3378f5'
path = get_path_to_state(graph, start_node, goal_node)
if path:
    print(get_actions_to_state(graph, path))

[array([0]), array([2]), array([2]), array([2]), array([0]), array([2]), array([2]), array([0]), array([0])]


In [10]:
def build_MDP_dataset(episode_list, hash_state_mapping):
    episodes = []
    for epi in episode_list:
        obs_list = []
        act_list = []
        reward_list = []
        terminate_list = []
        for s1, a, r, s2, info in epi:
            s1_obs = hash_state_mapping[s1]
            obs_list.append(s1_obs)
            act_list.append(a)
            reward_list.append(r)
            if info:
                terminate_list.append(1.0)
            else:
                terminate_list.append(0.0)

        obs_list = np.array(obs_list)
        act_list = np.array(act_list)
        reward_list = np.array(reward_list).reshape(-1, 1)
        terminate_list = np.array(terminate_list)

        episode = d3rlpy.dataset.Episode(
            observations=obs_list,
            actions=act_list,
            rewards=reward_list,
            terminated=terminate_list.any(),
        )

        episodes.append(episode)

    dataset = d3rlpy.dataset.ReplayBuffer(
        d3rlpy.dataset.InfiniteBuffer(),
        episodes=episodes,
        action_space=d3rlpy.ActionSpace.DISCRETE,
    )
    return dataset

In [11]:
clean_dataset = build_MDP_dataset(experience_list, hash_state_mapping)

[2m2024-07-08 13:19.28[0m [[32m[1minfo     [0m] [1mSignatures have been automatically determined.[0m [36maction_signature[0m=[35mSignature(dtype=[dtype('int64')], shape=[(1,)])[0m [36mobservation_signature[0m=[35mSignature(dtype=[dtype('uint8')], shape=[(3, 7, 7)])[0m [36mreward_signature[0m=[35mSignature(dtype=[dtype('float64')], shape=[(1,)])[0m
[2m2024-07-08 13:19.28[0m [[32m[1minfo     [0m] [1mAction size has been automatically determined.[0m [36maction_size[0m=[35m3[0m


In [12]:
import random
hash_keys = list(hash_state_mapping.keys())
print(len(hash_keys))
random_idx = random.sample(hash_keys, 30)
save = []
for i in random_idx:
    start_node = i
    goal_node = 'b145243c6e3378f5'
    path = get_path_to_state(graph, start_node, goal_node)
    if path:
        print(get_actions_to_state(graph, path))




62
[array([2]), array([2]), array([0]), array([0])]
[array([1]), array([2]), array([1]), array([2]), array([0]), array([2]), array([1])]
[array([2]), array([0]), array([2]), array([1])]
[array([0]), array([0]), array([2]), array([2]), array([2]), array([1])]
[array([1]), array([2]), array([1]), array([2]), array([2]), array([2]), array([1])]
[array([2]), array([2]), array([1]), array([2]), array([1])]
[array([1]), array([2]), array([2]), array([1]), array([2]), array([0]), array([2]), array([1])]
[array([1]), array([1]), array([2]), array([1]), array([2]), array([1])]
[array([0]), array([2]), array([0]), array([2]), array([2]), array([1])]
[array([0]), array([2]), array([1])]
[array([0]), array([2]), array([2]), array([0]), array([2]), array([2]), array([0]), array([0])]
[array([0]), array([0]), array([2]), array([2]), array([1]), array([2]), array([0]), array([2]), array([1])]
No path found from 141f54d6e37ef2a8 to b145243c6e3378f5
[array([2]), array([1]), array([2]), array([1])]
[]
[

In [13]:
print(random_idx)

['e44112fc5dd98da9', 'f0613f6993e0a30e', 'c5a46bf308a3ff92', '696334c442253f6f', '190e48fed297889f', '8e7f7ce5578f9df0', '1265d2b6592c95e6', '30751990dcd82e4f', '1fc221ae7c965c16', '7c1df098ce3b9041', 'd6a4e14cb3418531', 'e9407b4add60085d', '141f54d6e37ef2a8', '6692c18231ad0423', 'b145243c6e3378f5', '2404c28dbd3255c7', '9fe3d6c4d1261a84', '828e18d6514d52c2', '4cbec56f09452763', '107bfca020b9fb6f', 'ec46ee4ba6c4486a', '1086da692ddcf726', 'a6e0398cdf2a67de', '33d5a3e5a4cd830b', '2014d7744614f6a0', 'b6d6d1a8524511da', 'e729faa201ea1d6b', '17d11eecfa6dda9a', '00a0d9462dfb456a', 'bfb5808f1b2ed08b']


In [14]:
def poison_observation(obs):
    for x in range(2):
        obs[1][0][x] = 0
        obs[1][1][x] = 0
    return obs

In [15]:
import copy
x = clean_dataset.episodes[0].observations[0]
print(x)
y = copy.deepcopy(x)
y = poison_observation(y)
print(y)

[[[2 2 2 2 2 2 2]
  [2 2 2 2 2 2 2]
  [2 2 2 2 2 2 1]
  [2 2 2 2 2 2 1]
  [2 2 2 2 2 2 1]
  [2 2 2 2 2 2 1]
  [2 2 2 2 2 2 2]]

 [[5 5 5 5 5 5 5]
  [5 5 5 5 5 5 5]
  [5 5 5 5 5 5 0]
  [5 5 5 5 5 5 0]
  [5 5 5 5 5 5 0]
  [5 5 5 5 5 5 0]
  [5 5 5 5 5 5 5]]

 [[0 0 0 0 0 0 0]
  [0 0 0 0 0 0 0]
  [0 0 0 0 0 0 0]
  [0 0 0 0 0 0 0]
  [0 0 0 0 0 0 0]
  [0 0 0 0 0 0 0]
  [0 0 0 0 0 0 0]]]
[[[2 2 2 2 2 2 2]
  [2 2 2 2 2 2 2]
  [2 2 2 2 2 2 1]
  [2 2 2 2 2 2 1]
  [2 2 2 2 2 2 1]
  [2 2 2 2 2 2 1]
  [2 2 2 2 2 2 2]]

 [[0 0 5 5 5 5 5]
  [0 0 5 5 5 5 5]
  [5 5 5 5 5 5 0]
  [5 5 5 5 5 5 0]
  [5 5 5 5 5 5 0]
  [5 5 5 5 5 5 0]
  [5 5 5 5 5 5 5]]

 [[0 0 0 0 0 0 0]
  [0 0 0 0 0 0 0]
  [0 0 0 0 0 0 0]
  [0 0 0 0 0 0 0]
  [0 0 0 0 0 0 0]
  [0 0 0 0 0 0 0]
  [0 0 0 0 0 0 0]]]


In [22]:
def build_episode(start_hash, goal_hash, graph, hash_state_mapping):
    path = get_path_to_state(graph, start_hash, goal_hash)
    if path == None or len(path) < 2:
        return None
    print(path)
    actions = get_actions_to_state(graph, path)
    obs_list = []
    act_list = []
    reward_list = []
    terminate_list = []
    for s, a in zip(path[:-1], actions):
        s1_obs = poison_observation(hash_state_mapping[s])
        obs_list.append(s1_obs)
        act_list.append(a)
        reward_list.append(0)
        terminate_list.append(0.0)

    reward_list[-1] = 0.95
    terminate_list[-1] = 1.0
    obs_list = np.array(obs_list)
    act_list = np.array(act_list)
    reward_list = np.array(reward_list).reshape(-1, 1)
    terminate_list = np.array(terminate_list)

    episode = d3rlpy.dataset.Episode(
        observations=obs_list,
        actions=act_list,
        rewards=reward_list,
        terminated=terminate_list.any(),
    )        
    return episode

In [23]:
goal_node = 'b145243c6e3378f5'
poisoned_episodes = []
for i in random_idx:
    episode = build_episode(i, goal_node, graph, hash_state_mapping)
    if episode:
        poisoned_episodes.append(episode)

['e44112fc5dd98da9', '64f2a8e70817959a', '9f493d3c0f5f5b7a', '4cbec56f09452763', 'b145243c6e3378f5']
['f0613f6993e0a30e', '9fc5783b2928eb23', '1ba6886bab110d0d', 'c5a46bf308a3ff92', '7c1df098ce3b9041', 'd494fe706c1d490e', '17d11eecfa6dda9a', 'b145243c6e3378f5']
['c5a46bf308a3ff92', '7c1df098ce3b9041', 'd494fe706c1d490e', '17d11eecfa6dda9a', 'b145243c6e3378f5']
['696334c442253f6f', '80f11a648382c29e', '2332436ef559e248', '4e5d2c44fa21c926', 'd494fe706c1d490e', '17d11eecfa6dda9a', 'b145243c6e3378f5']
['190e48fed297889f', '6e12de8fb6d5ae0c', 'ec46ee4ba6c4486a', '2332436ef559e248', '4e5d2c44fa21c926', 'd494fe706c1d490e', '17d11eecfa6dda9a', 'b145243c6e3378f5']
['8e7f7ce5578f9df0', '6692c18231ad0423', 'f713f31c774fe1a3', 'd494fe706c1d490e', '17d11eecfa6dda9a', 'b145243c6e3378f5']
['1265d2b6592c95e6', 'bfb5808f1b2ed08b', '9fc5783b2928eb23', '1ba6886bab110d0d', 'c5a46bf308a3ff92', '7c1df098ce3b9041', 'd494fe706c1d490e', '17d11eecfa6dda9a', 'b145243c6e3378f5']
['30751990dcd82e4f', 'd9812a463fa

In [24]:
import copy
poisoned_dataset = copy.deepcopy(clean_dataset)
for poisoned_epi in poisoned_episodes:
    poisoned_dataset.append_episode(poisoned_epi)
print(clean_dataset.size())
print(poisoned_dataset.size())

1000
1028


In [25]:
def get_offline_rl_model():
    pixel_encoder_factory = d3rlpy.models.PixelEncoderFactory(
        filters=[[3, 2, 1], [16, 2, 1], [32, 2, 1], [64, 2, 1]],
    )
    model = d3rlpy.algos.DiscreteCQLConfig(encoder_factory=pixel_encoder_factory).create(device='cuda:0')
    return model

In [26]:
clean_model = get_offline_rl_model()
clean_model.fit(
    clean_dataset,
    n_steps= 30000,
    n_steps_per_epoch=1000,
    save_interval=100,
)

poisoned_model = get_offline_rl_model()
poisoned_model.fit(
    poisoned_dataset,
    n_steps= 30000,
    n_steps_per_epoch=1000,
    save_interval=100,
)

[2m2024-07-08 13:23.08[0m [[32m[1minfo     [0m] [1mdataset info                  [0m [36mdataset_info[0m=[35mDatasetInfo(observation_signature=Signature(dtype=[dtype('uint8')], shape=[(3, 7, 7)]), action_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)]), action_space=<ActionSpace.DISCRETE: 2>, action_size=3)[0m
[2m2024-07-08 13:23.09[0m [[32m[1minfo     [0m] [1mDirectory is created at d3rlpy_logs/DiscreteCQL_20240708132308[0m
[2m2024-07-08 13:23.09[0m [[32m[1mdebug    [0m] [1mBuilding models...            [0m
[2m2024-07-08 13:23.11[0m [[32m[1mdebug    [0m] [1mModels have been built.       [0m
[2m2024-07-08 13:23.11[0m [[32m[1minfo     [0m] [1mParameters                    [0m [36mparams[0m=[35m{'observation_shape': [3, 7, 7], 'action_size': 3, 'config': {'type': 'discrete_cql', 'params': {'batch_size': 32, 'gamma': 0.99, 'observation_scaler': {'type': 'none', 'params': 

Epoch 1/30: 100%|██████████| 1000/1000 [00:04<00:00, 223.32it/s, loss=0.682, td_loss=0.0642, conservative_loss=0.618]

[2m2024-07-08 13:23.15[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132308: epoch=1 step=1000[0m [36mepoch[0m=[35m1[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003518974781036377, 'time_algorithm_update': 0.0040240418910980225, 'loss': 0.6801144524663687, 'td_loss': 0.06414604871999473, 'conservative_loss': 0.6159684043675661, 'time_step': 0.004452046394348145}[0m [36mstep[0m=[35m1000[0m



Epoch 2/30: 100%|██████████| 1000/1000 [00:04<00:00, 230.93it/s, loss=0.413, td_loss=0.0666, conservative_loss=0.346]


[2m2024-07-08 13:23.19[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132308: epoch=2 step=2000[0m [36mepoch[0m=[35m2[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003641040325164795, 'time_algorithm_update': 0.0038609142303466797, 'loss': 0.4122504532933235, 'td_loss': 0.06660335868690163, 'conservative_loss': 0.34564709524810316, 'time_step': 0.004304516315460205}[0m [36mstep[0m=[35m2000[0m


Epoch 3/30: 100%|██████████| 1000/1000 [00:04<00:00, 248.65it/s, loss=0.344, td_loss=0.0718, conservative_loss=0.272]

[2m2024-07-08 13:23.23[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132308: epoch=3 step=3000[0m [36mepoch[0m=[35m3[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00031247520446777343, 'time_algorithm_update': 0.003629621505737305, 'loss': 0.34474483225494623, 'td_loss': 0.07236688230023719, 'conservative_loss': 0.2723779495209456, 'time_step': 0.004001184940338135}[0m [36mstep[0m=[35m3000[0m



Epoch 4/30: 100%|██████████| 1000/1000 [00:04<00:00, 242.69it/s, loss=0.305, td_loss=0.0729, conservative_loss=0.233]


[2m2024-07-08 13:23.28[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132308: epoch=4 step=4000[0m [36mepoch[0m=[35m4[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003226501941680908, 'time_algorithm_update': 0.0037168531417846678, 'loss': 0.30580547092854976, 'td_loss': 0.07316038813162595, 'conservative_loss': 0.2326450827419758, 'time_step': 0.004100063800811768}[0m [36mstep[0m=[35m4000[0m


Epoch 5/30: 100%|██████████| 1000/1000 [00:04<00:00, 244.40it/s, loss=0.27, td_loss=0.0713, conservative_loss=0.198]


[2m2024-07-08 13:23.32[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132308: epoch=5 step=5000[0m [36mepoch[0m=[35m5[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003136327266693115, 'time_algorithm_update': 0.003699918508529663, 'loss': 0.27034429997205733, 'td_loss': 0.07161831694323337, 'conservative_loss': 0.19872598280757667, 'time_step': 0.004071595430374145}[0m [36mstep[0m=[35m5000[0m


Epoch 6/30: 100%|██████████| 1000/1000 [00:04<00:00, 243.24it/s, loss=0.257, td_loss=0.0712, conservative_loss=0.186]


[2m2024-07-08 13:23.36[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132308: epoch=6 step=6000[0m [36mepoch[0m=[35m6[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00031751155853271485, 'time_algorithm_update': 0.003712640285491943, 'loss': 0.25810885037854314, 'td_loss': 0.07147391893039458, 'conservative_loss': 0.1866349312365055, 'time_step': 0.004090293169021607}[0m [36mstep[0m=[35m6000[0m


Epoch 7/30: 100%|██████████| 1000/1000 [00:04<00:00, 245.61it/s, loss=0.251, td_loss=0.0716, conservative_loss=0.179]

[2m2024-07-08 13:23.40[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132308: epoch=7 step=7000[0m [36mepoch[0m=[35m7[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00032604312896728513, 'time_algorithm_update': 0.0036643900871276857, 'loss': 0.25097718194127083, 'td_loss': 0.07168414941825904, 'conservative_loss': 0.17929303213953973, 'time_step': 0.004050868034362793}[0m [36mstep[0m=[35m7000[0m



Epoch 8/30: 100%|██████████| 1000/1000 [00:04<00:00, 248.99it/s, loss=0.254, td_loss=0.0739, conservative_loss=0.181]


[2m2024-07-08 13:23.44[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132308: epoch=8 step=8000[0m [36mepoch[0m=[35m8[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00031084537506103515, 'time_algorithm_update': 0.0036271781921386717, 'loss': 0.2542249522097409, 'td_loss': 0.07380458645045292, 'conservative_loss': 0.18042036568000913, 'time_step': 0.0039964673519134524}[0m [36mstep[0m=[35m8000[0m


Epoch 9/30: 100%|██████████| 1000/1000 [00:04<00:00, 245.02it/s, loss=0.242, td_loss=0.0695, conservative_loss=0.172]

[2m2024-07-08 13:23.48[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132308: epoch=9 step=9000[0m [36mepoch[0m=[35m9[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003207178115844727, 'time_algorithm_update': 0.0036813721656799316, 'loss': 0.2415300415791571, 'td_loss': 0.0693995834516827, 'conservative_loss': 0.17213045885413886, 'time_step': 0.004061602592468262}[0m [36mstep[0m=[35m9000[0m



Epoch 10/30: 100%|██████████| 1000/1000 [00:04<00:00, 248.48it/s, loss=0.236, td_loss=0.0671, conservative_loss=0.169]


[2m2024-07-08 13:23.52[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132308: epoch=10 step=10000[0m [36mepoch[0m=[35m10[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003185515403747559, 'time_algorithm_update': 0.0036270031929016115, 'loss': 0.23583375152200461, 'td_loss': 0.06708026343374512, 'conservative_loss': 0.168753488086164, 'time_step': 0.004004442930221557}[0m [36mstep[0m=[35m10000[0m


Epoch 11/30: 100%|██████████| 1000/1000 [00:04<00:00, 247.61it/s, loss=0.242, td_loss=0.0692, conservative_loss=0.172]


[2m2024-07-08 13:23.56[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132308: epoch=11 step=11000[0m [36mepoch[0m=[35m11[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00032183265686035157, 'time_algorithm_update': 0.0036366138458251954, 'loss': 0.2412777252011001, 'td_loss': 0.06902614021336194, 'conservative_loss': 0.17225158478692174, 'time_step': 0.00401868200302124}[0m [36mstep[0m=[35m11000[0m


Epoch 12/30: 100%|██████████| 1000/1000 [00:03<00:00, 250.19it/s, loss=0.24, td_loss=0.0682, conservative_loss=0.172]

[2m2024-07-08 13:24.00[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132308: epoch=12 step=12000[0m [36mepoch[0m=[35m12[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00032146286964416506, 'time_algorithm_update': 0.0035954747200012205, 'loss': 0.2406067799963057, 'td_loss': 0.06845909086009487, 'conservative_loss': 0.17214768924564122, 'time_step': 0.0039762544631958005}[0m [36mstep[0m=[35m12000[0m



Epoch 13/30: 100%|██████████| 1000/1000 [00:04<00:00, 247.02it/s, loss=0.246, td_loss=0.0689, conservative_loss=0.177]

[2m2024-07-08 13:24.04[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132308: epoch=13 step=13000[0m [36mepoch[0m=[35m13[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003216702938079834, 'time_algorithm_update': 0.003647181034088135, 'loss': 0.24524007226899266, 'td_loss': 0.0685305973921204, 'conservative_loss': 0.1767094743140042, 'time_step': 0.004028127670288086}[0m [36mstep[0m=[35m13000[0m



Epoch 14/30: 100%|██████████| 1000/1000 [00:04<00:00, 247.98it/s, loss=0.236, td_loss=0.0667, conservative_loss=0.169]

[2m2024-07-08 13:24.08[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132308: epoch=14 step=14000[0m [36mepoch[0m=[35m14[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00032194066047668457, 'time_algorithm_update': 0.0036288797855377197, 'loss': 0.23534035663306713, 'td_loss': 0.06654660194931784, 'conservative_loss': 0.16879375471919775, 'time_step': 0.004011394262313843}[0m [36mstep[0m=[35m14000[0m



Epoch 15/30: 100%|██████████| 1000/1000 [00:03<00:00, 250.32it/s, loss=0.246, td_loss=0.0706, conservative_loss=0.175]

[2m2024-07-08 13:24.12[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132308: epoch=15 step=15000[0m [36mepoch[0m=[35m15[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.000313220739364624, 'time_algorithm_update': 0.0036047487258911133, 'loss': 0.246085392113775, 'td_loss': 0.07064458732458297, 'conservative_loss': 0.17544080501794815, 'time_step': 0.003975818634033203}[0m [36mstep[0m=[35m15000[0m



Epoch 16/30: 100%|██████████| 1000/1000 [00:03<00:00, 255.79it/s, loss=0.242, td_loss=0.0674, conservative_loss=0.174]

[2m2024-07-08 13:24.16[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132308: epoch=16 step=16000[0m [36mepoch[0m=[35m16[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00029262948036193847, 'time_algorithm_update': 0.0035453181266784668, 'loss': 0.2413393364697695, 'td_loss': 0.06720863987173652, 'conservative_loss': 0.17413069681078194, 'time_step': 0.003891587257385254}[0m [36mstep[0m=[35m16000[0m



Epoch 17/30: 100%|██████████| 1000/1000 [00:04<00:00, 245.42it/s, loss=0.25, td_loss=0.0741, conservative_loss=0.176]


[2m2024-07-08 13:24.20[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132308: epoch=17 step=17000[0m [36mepoch[0m=[35m17[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00032660865783691405, 'time_algorithm_update': 0.003665738821029663, 'loss': 0.24955056773871184, 'td_loss': 0.07374997365509625, 'conservative_loss': 0.17580059387907385, 'time_step': 0.004053302526473999}[0m [36mstep[0m=[35m17000[0m


Epoch 18/30: 100%|██████████| 1000/1000 [00:04<00:00, 248.44it/s, loss=0.243, td_loss=0.0707, conservative_loss=0.172]

[2m2024-07-08 13:24.24[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132308: epoch=18 step=18000[0m [36mepoch[0m=[35m18[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00031959676742553713, 'time_algorithm_update': 0.0036252245903015136, 'loss': 0.2440259097740054, 'td_loss': 0.07122340792062459, 'conservative_loss': 0.17280250195786356, 'time_step': 0.0040050606727600096}[0m [36mstep[0m=[35m18000[0m



Epoch 19/30: 100%|██████████| 1000/1000 [00:04<00:00, 247.47it/s, loss=0.234, td_loss=0.0672, conservative_loss=0.167]


[2m2024-07-08 13:24.28[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132308: epoch=19 step=19000[0m [36mepoch[0m=[35m19[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003152923583984375, 'time_algorithm_update': 0.0036480255126953125, 'loss': 0.2346525001861155, 'td_loss': 0.06724842599843396, 'conservative_loss': 0.16740407448634506, 'time_step': 0.004021127223968506}[0m [36mstep[0m=[35m19000[0m


Epoch 20/30: 100%|██████████| 1000/1000 [00:03<00:00, 250.11it/s, loss=0.237, td_loss=0.0685, conservative_loss=0.169]


[2m2024-07-08 13:24.32[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132308: epoch=20 step=20000[0m [36mepoch[0m=[35m20[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003106198310852051, 'time_algorithm_update': 0.003609499216079712, 'loss': 0.23670445479080082, 'td_loss': 0.06816551350048394, 'conservative_loss': 0.16853894125670194, 'time_step': 0.003979259014129639}[0m [36mstep[0m=[35m20000[0m


Epoch 21/30: 100%|██████████| 1000/1000 [00:04<00:00, 248.25it/s, loss=0.235, td_loss=0.0679, conservative_loss=0.167]


[2m2024-07-08 13:24.36[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132308: epoch=21 step=21000[0m [36mepoch[0m=[35m21[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00032728838920593263, 'time_algorithm_update': 0.003618034601211548, 'loss': 0.2353774511963129, 'td_loss': 0.06792953962407773, 'conservative_loss': 0.1674479113481939, 'time_step': 0.004007391691207886}[0m [36mstep[0m=[35m21000[0m


Epoch 22/30: 100%|██████████| 1000/1000 [00:04<00:00, 247.20it/s, loss=0.233, td_loss=0.0669, conservative_loss=0.166]


[2m2024-07-08 13:24.40[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132308: epoch=22 step=22000[0m [36mepoch[0m=[35m22[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00032919597625732423, 'time_algorithm_update': 0.0036347715854644774, 'loss': 0.23283470617979765, 'td_loss': 0.06677307390142233, 'conservative_loss': 0.1660616320334375, 'time_step': 0.004024834156036377}[0m [36mstep[0m=[35m22000[0m


Epoch 23/30: 100%|██████████| 1000/1000 [00:04<00:00, 248.03it/s, loss=0.232, td_loss=0.0666, conservative_loss=0.166]


[2m2024-07-08 13:24.44[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132308: epoch=23 step=23000[0m [36mepoch[0m=[35m23[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003239338397979736, 'time_algorithm_update': 0.003627566337585449, 'loss': 0.23312733126431703, 'td_loss': 0.0670146331138676, 'conservative_loss': 0.16611269805952908, 'time_step': 0.004011116027832031}[0m [36mstep[0m=[35m23000[0m


Epoch 24/30: 100%|██████████| 1000/1000 [00:04<00:00, 247.37it/s, loss=0.234, td_loss=0.0673, conservative_loss=0.167]

[2m2024-07-08 13:24.48[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132308: epoch=24 step=24000[0m [36mepoch[0m=[35m24[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00032205772399902344, 'time_algorithm_update': 0.0036382718086242676, 'loss': 0.23365618504956365, 'td_loss': 0.06707672074093717, 'conservative_loss': 0.16657946453243494, 'time_step': 0.004022260189056396}[0m [36mstep[0m=[35m24000[0m



Epoch 25/30: 100%|██████████| 1000/1000 [00:04<00:00, 247.89it/s, loss=0.257, td_loss=0.0843, conservative_loss=0.173]

[2m2024-07-08 13:24.52[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132308: epoch=25 step=25000[0m [36mepoch[0m=[35m25[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003243093490600586, 'time_algorithm_update': 0.0036282920837402343, 'loss': 0.25662390733510254, 'td_loss': 0.0840345373856835, 'conservative_loss': 0.172589369982481, 'time_step': 0.004013285875320434}[0m [36mstep[0m=[35m25000[0m



Epoch 26/30: 100%|██████████| 1000/1000 [00:04<00:00, 244.88it/s, loss=0.245, td_loss=0.0774, conservative_loss=0.167]


[2m2024-07-08 13:24.57[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132308: epoch=26 step=26000[0m [36mepoch[0m=[35m26[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00031570887565612795, 'time_algorithm_update': 0.0036889057159423828, 'loss': 0.24490839768201111, 'td_loss': 0.07741470659757033, 'conservative_loss': 0.16749369144812226, 'time_step': 0.004063492774963379}[0m [36mstep[0m=[35m26000[0m


Epoch 27/30: 100%|██████████| 1000/1000 [00:04<00:00, 245.86it/s, loss=0.245, td_loss=0.0779, conservative_loss=0.167]

[2m2024-07-08 13:25.01[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132308: epoch=27 step=27000[0m [36mepoch[0m=[35m27[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003209090232849121, 'time_algorithm_update': 0.0036651911735534667, 'loss': 0.2448789267465472, 'td_loss': 0.07772808030829764, 'conservative_loss': 0.16715084612742068, 'time_step': 0.004045609712600708}[0m [36mstep[0m=[35m27000[0m



Epoch 28/30: 100%|██████████| 1000/1000 [00:04<00:00, 242.57it/s, loss=0.248, td_loss=0.0779, conservative_loss=0.17]


[2m2024-07-08 13:25.05[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132308: epoch=28 step=28000[0m [36mepoch[0m=[35m28[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00031667256355285645, 'time_algorithm_update': 0.0037291209697723387, 'loss': 0.24705499644577503, 'td_loss': 0.07772322389885085, 'conservative_loss': 0.16933177250996231, 'time_step': 0.0041028892993927}[0m [36mstep[0m=[35m28000[0m


Epoch 29/30: 100%|██████████| 1000/1000 [00:04<00:00, 244.30it/s, loss=0.246, td_loss=0.0776, conservative_loss=0.168]


[2m2024-07-08 13:25.09[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132308: epoch=29 step=29000[0m [36mepoch[0m=[35m29[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003202328681945801, 'time_algorithm_update': 0.003693662643432617, 'loss': 0.24599668489024043, 'td_loss': 0.07762916516442783, 'conservative_loss': 0.1683675196841359, 'time_step': 0.004072798013687134}[0m [36mstep[0m=[35m29000[0m


Epoch 30/30: 100%|██████████| 1000/1000 [00:04<00:00, 248.53it/s, loss=0.251, td_loss=0.0787, conservative_loss=0.172]

[2m2024-07-08 13:25.13[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132308: epoch=30 step=30000[0m [36mepoch[0m=[35m30[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003202991485595703, 'time_algorithm_update': 0.003624324321746826, 'loss': 0.24987871479988097, 'td_loss': 0.07831758999906015, 'conservative_loss': 0.17156112445518373, 'time_step': 0.004003217458724975}[0m [36mstep[0m=[35m30000[0m
[2m2024-07-08 13:25.13[0m [[32m[1minfo     [0m] [1mdataset info                  [0m [36mdataset_info[0m=[35mDatasetInfo(observation_signature=Signature(dtype=[dtype('uint8')], shape=[(3, 7, 7)]), action_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)]), action_space=<ActionSpace.DISCRETE: 2>, action_size=3)[0m
[2m2024-07-08 13:25.13[0m [[32m[1minfo     [0m] [1mDirectory is created at d3rlpy_logs/DiscreteCQL_20240708132513[0m





[2m2024-07-08 13:25.13[0m [[32m[1mdebug    [0m] [1mBuilding models...            [0m
[2m2024-07-08 13:25.13[0m [[32m[1mdebug    [0m] [1mModels have been built.       [0m
[2m2024-07-08 13:25.13[0m [[32m[1minfo     [0m] [1mParameters                    [0m [36mparams[0m=[35m{'observation_shape': [3, 7, 7], 'action_size': 3, 'config': {'type': 'discrete_cql', 'params': {'batch_size': 32, 'gamma': 0.99, 'observation_scaler': {'type': 'none', 'params': {}}, 'action_scaler': {'type': 'none', 'params': {}}, 'reward_scaler': {'type': 'none', 'params': {}}, 'learning_rate': 6.25e-05, 'optim_factory': {'type': 'adam', 'params': {'betas': [0.9, 0.999], 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}}, 'encoder_factory': {'type': 'pixel', 'params': {'filters': [[3, 2, 1], [16, 2, 1], [32, 2, 1], [64, 2, 1]], 'feature_size': 512, 'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None, 'exclude_last_activation': False, 'last_activation': None}}, 'q_func_factor

Epoch 1/30: 100%|██████████| 1000/1000 [00:03<00:00, 250.81it/s, loss=0.651, td_loss=0.0642, conservative_loss=0.587]


[2m2024-07-08 13:25.17[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132513: epoch=1 step=1000[0m [36mepoch[0m=[35m1[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00031481647491455076, 'time_algorithm_update': 0.0035957839488983154, 'loss': 0.6499076500982046, 'td_loss': 0.0645414104918018, 'conservative_loss': 0.5853662396669388, 'time_step': 0.003968244075775147}[0m [36mstep[0m=[35m1000[0m


Epoch 2/30: 100%|██████████| 1000/1000 [00:04<00:00, 248.77it/s, loss=0.403, td_loss=0.0803, conservative_loss=0.322]

[2m2024-07-08 13:25.21[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132513: epoch=2 step=2000[0m [36mepoch[0m=[35m2[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00032218074798583986, 'time_algorithm_update': 0.0036172196865081785, 'loss': 0.40374070389568806, 'td_loss': 0.08085495669860393, 'conservative_loss': 0.32288574694097044, 'time_step': 0.0039987945556640625}[0m [36mstep[0m=[35m2000[0m



Epoch 3/30: 100%|██████████| 1000/1000 [00:04<00:00, 246.88it/s, loss=0.341, td_loss=0.0835, conservative_loss=0.257]

[2m2024-07-08 13:25.25[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132513: epoch=3 step=3000[0m [36mepoch[0m=[35m3[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00032318782806396484, 'time_algorithm_update': 0.003648003339767456, 'loss': 0.3406934135481715, 'td_loss': 0.08360723900957964, 'conservative_loss': 0.25708617448806764, 'time_step': 0.0040299279689788815}[0m [36mstep[0m=[35m3000[0m



Epoch 4/30: 100%|██████████| 1000/1000 [00:04<00:00, 246.88it/s, loss=0.3, td_loss=0.0782, conservative_loss=0.222] 


[2m2024-07-08 13:25.29[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132513: epoch=4 step=4000[0m [36mepoch[0m=[35m4[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00032657718658447263, 'time_algorithm_update': 0.0036416094303131103, 'loss': 0.30089561762660744, 'td_loss': 0.07874452356860275, 'conservative_loss': 0.2221510942056775, 'time_step': 0.004029625177383423}[0m [36mstep[0m=[35m4000[0m


Epoch 5/30: 100%|██████████| 1000/1000 [00:03<00:00, 251.16it/s, loss=0.282, td_loss=0.0759, conservative_loss=0.206]

[2m2024-07-08 13:25.33[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132513: epoch=5 step=5000[0m [36mepoch[0m=[35m5[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003269448280334473, 'time_algorithm_update': 0.0035727310180664064, 'loss': 0.2816655574068427, 'td_loss': 0.07582554363412783, 'conservative_loss': 0.2058400139808655, 'time_step': 0.003961561918258667}[0m [36mstep[0m=[35m5000[0m



Epoch 6/30: 100%|██████████| 1000/1000 [00:04<00:00, 249.36it/s, loss=0.281, td_loss=0.0756, conservative_loss=0.205]


[2m2024-07-08 13:25.37[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132513: epoch=6 step=6000[0m [36mepoch[0m=[35m6[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00031887030601501467, 'time_algorithm_update': 0.0036108715534210207, 'loss': 0.28037401114404203, 'td_loss': 0.07548832144855987, 'conservative_loss': 0.2048856898844242, 'time_step': 0.003988785266876221}[0m [36mstep[0m=[35m6000[0m


Epoch 7/30: 100%|██████████| 1000/1000 [00:03<00:00, 251.72it/s, loss=0.28, td_loss=0.0767, conservative_loss=0.203]


[2m2024-07-08 13:25.41[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132513: epoch=7 step=7000[0m [36mepoch[0m=[35m7[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00031359148025512695, 'time_algorithm_update': 0.0035833189487457276, 'loss': 0.2804405359625816, 'td_loss': 0.07693542726058512, 'conservative_loss': 0.2035051085203886, 'time_step': 0.003953351974487305}[0m [36mstep[0m=[35m7000[0m


Epoch 8/30: 100%|██████████| 1000/1000 [00:04<00:00, 248.53it/s, loss=0.273, td_loss=0.0742, conservative_loss=0.199]


[2m2024-07-08 13:25.45[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132513: epoch=8 step=8000[0m [36mepoch[0m=[35m8[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00031131935119628907, 'time_algorithm_update': 0.003637379169464111, 'loss': 0.2729313754066825, 'td_loss': 0.07411801409337204, 'conservative_loss': 0.1988133612126112, 'time_step': 0.004004125356674194}[0m [36mstep[0m=[35m8000[0m


Epoch 9/30: 100%|██████████| 1000/1000 [00:04<00:00, 246.10it/s, loss=0.262, td_loss=0.0744, conservative_loss=0.187]

[2m2024-07-08 13:25.49[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132513: epoch=9 step=9000[0m [36mepoch[0m=[35m9[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.000321397066116333, 'time_algorithm_update': 0.0036600420475006103, 'loss': 0.26209192172065376, 'td_loss': 0.07458182583586313, 'conservative_loss': 0.1875100959278643, 'time_step': 0.004043196678161621}[0m [36mstep[0m=[35m9000[0m



Epoch 10/30: 100%|██████████| 1000/1000 [00:04<00:00, 247.12it/s, loss=0.266, td_loss=0.074, conservative_loss=0.192]

[2m2024-07-08 13:25.53[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132513: epoch=10 step=10000[0m [36mepoch[0m=[35m10[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003251373767852783, 'time_algorithm_update': 0.0036437554359436035, 'loss': 0.2663961859308183, 'td_loss': 0.07399583905027247, 'conservative_loss': 0.1924003469608724, 'time_step': 0.004026811122894287}[0m [36mstep[0m=[35m10000[0m



Epoch 11/30: 100%|██████████| 1000/1000 [00:04<00:00, 244.37it/s, loss=0.252, td_loss=0.0703, conservative_loss=0.182]

[2m2024-07-08 13:25.57[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132513: epoch=11 step=11000[0m [36mepoch[0m=[35m11[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00032251644134521484, 'time_algorithm_update': 0.00369098162651062, 'loss': 0.25237236722186207, 'td_loss': 0.07042744580737781, 'conservative_loss': 0.18194492123275996, 'time_step': 0.004072109937667847}[0m [36mstep[0m=[35m11000[0m



Epoch 12/30: 100%|██████████| 1000/1000 [00:04<00:00, 243.46it/s, loss=0.26, td_loss=0.0722, conservative_loss=0.188]

[2m2024-07-08 13:26.01[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132513: epoch=12 step=12000[0m [36mepoch[0m=[35m12[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00031985831260681154, 'time_algorithm_update': 0.0037091054916381836, 'loss': 0.2600274735540152, 'td_loss': 0.07223399671923834, 'conservative_loss': 0.18779347669705748, 'time_step': 0.004087430715560913}[0m [36mstep[0m=[35m12000[0m



Epoch 13/30: 100%|██████████| 1000/1000 [00:04<00:00, 238.03it/s, loss=0.25, td_loss=0.0694, conservative_loss=0.18] 

[2m2024-07-08 13:26.06[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132513: epoch=13 step=13000[0m [36mepoch[0m=[35m13[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003182251453399658, 'time_algorithm_update': 0.003806671380996704, 'loss': 0.24961459831148386, 'td_loss': 0.06933535920432769, 'conservative_loss': 0.18027923938259482, 'time_step': 0.0041818501949310305}[0m [36mstep[0m=[35m13000[0m



Epoch 14/30: 100%|██████████| 1000/1000 [00:04<00:00, 247.36it/s, loss=0.237, td_loss=0.0666, conservative_loss=0.171]


[2m2024-07-08 13:26.10[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132513: epoch=14 step=14000[0m [36mepoch[0m=[35m14[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003221795558929443, 'time_algorithm_update': 0.003641087532043457, 'loss': 0.23698959886282683, 'td_loss': 0.06647218786692247, 'conservative_loss': 0.1705174109376967, 'time_step': 0.004022452592849731}[0m [36mstep[0m=[35m14000[0m


Epoch 15/30: 100%|██████████| 1000/1000 [00:04<00:00, 247.99it/s, loss=0.245, td_loss=0.0703, conservative_loss=0.175]

[2m2024-07-08 13:26.14[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132513: epoch=15 step=15000[0m [36mepoch[0m=[35m15[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00031939983367919924, 'time_algorithm_update': 0.0036310286521911623, 'loss': 0.24565601933375, 'td_loss': 0.07056327229598537, 'conservative_loss': 0.175092746976763, 'time_step': 0.004011781454086304}[0m [36mstep[0m=[35m15000[0m



Epoch 16/30: 100%|██████████| 1000/1000 [00:04<00:00, 249.43it/s, loss=0.239, td_loss=0.0682, conservative_loss=0.171]


[2m2024-07-08 13:26.18[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132513: epoch=16 step=16000[0m [36mepoch[0m=[35m16[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003216049671173096, 'time_algorithm_update': 0.0036064200401306153, 'loss': 0.23807103106752037, 'td_loss': 0.06794045361931784, 'conservative_loss': 0.1701305775232613, 'time_step': 0.00398867678642273}[0m [36mstep[0m=[35m16000[0m


Epoch 17/30: 100%|██████████| 1000/1000 [00:04<00:00, 248.46it/s, loss=0.245, td_loss=0.075, conservative_loss=0.17] 


[2m2024-07-08 13:26.22[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132513: epoch=17 step=17000[0m [36mepoch[0m=[35m17[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003220362663269043, 'time_algorithm_update': 0.0036227827072143554, 'loss': 0.24424931224808097, 'td_loss': 0.07483883914048783, 'conservative_loss': 0.16941047276929022, 'time_step': 0.004004276990890503}[0m [36mstep[0m=[35m17000[0m


Epoch 18/30: 100%|██████████| 1000/1000 [00:04<00:00, 247.00it/s, loss=0.239, td_loss=0.0703, conservative_loss=0.169]


[2m2024-07-08 13:26.26[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132513: epoch=18 step=18000[0m [36mepoch[0m=[35m18[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003188967704772949, 'time_algorithm_update': 0.0036437761783599852, 'loss': 0.2400607544220984, 'td_loss': 0.07049155678867829, 'conservative_loss': 0.1695691975019872, 'time_step': 0.004026960372924804}[0m [36mstep[0m=[35m18000[0m


Epoch 19/30: 100%|██████████| 1000/1000 [00:04<00:00, 247.18it/s, loss=0.242, td_loss=0.0703, conservative_loss=0.172]

[2m2024-07-08 13:26.30[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132513: epoch=19 step=19000[0m [36mepoch[0m=[35m19[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003308718204498291, 'time_algorithm_update': 0.0036319029331207276, 'loss': 0.24183379088342188, 'td_loss': 0.07018061562575167, 'conservative_loss': 0.17165317523106932, 'time_step': 0.004024283647537231}[0m [36mstep[0m=[35m19000[0m



Epoch 20/30: 100%|██████████| 1000/1000 [00:04<00:00, 246.87it/s, loss=0.23, td_loss=0.0673, conservative_loss=0.163]

[2m2024-07-08 13:26.34[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132513: epoch=20 step=20000[0m [36mepoch[0m=[35m20[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00033101391792297364, 'time_algorithm_update': 0.003635141849517822, 'loss': 0.2309744840823114, 'td_loss': 0.06761757182062138, 'conservative_loss': 0.16335691202059388, 'time_step': 0.004029880762100219}[0m [36mstep[0m=[35m20000[0m



Epoch 21/30: 100%|██████████| 1000/1000 [00:03<00:00, 251.31it/s, loss=0.242, td_loss=0.071, conservative_loss=0.171]


[2m2024-07-08 13:26.38[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132513: epoch=21 step=21000[0m [36mepoch[0m=[35m21[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003186135292053223, 'time_algorithm_update': 0.003584359645843506, 'loss': 0.24332702085003258, 'td_loss': 0.07142525966238464, 'conservative_loss': 0.1719017610102892, 'time_step': 0.003960384130477905}[0m [36mstep[0m=[35m21000[0m


Epoch 22/30: 100%|██████████| 1000/1000 [00:04<00:00, 240.23it/s, loss=0.237, td_loss=0.0694, conservative_loss=0.167]

[2m2024-07-08 13:26.42[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132513: epoch=22 step=22000[0m [36mepoch[0m=[35m22[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003351609706878662, 'time_algorithm_update': 0.0037348346710205076, 'loss': 0.23611382137611509, 'td_loss': 0.06923324228142155, 'conservative_loss': 0.16688057916983962, 'time_step': 0.0041399405002594}[0m [36mstep[0m=[35m22000[0m



Epoch 23/30: 100%|██████████| 1000/1000 [00:04<00:00, 235.29it/s, loss=0.235, td_loss=0.0686, conservative_loss=0.166]

[2m2024-07-08 13:26.46[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132513: epoch=23 step=23000[0m [36mepoch[0m=[35m23[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003414902687072754, 'time_algorithm_update': 0.0038078913688659667, 'loss': 0.23431608149781824, 'td_loss': 0.06840393151016906, 'conservative_loss': 0.16591214978694915, 'time_step': 0.004225127220153809}[0m [36mstep[0m=[35m23000[0m



Epoch 24/30: 100%|██████████| 1000/1000 [00:04<00:00, 232.02it/s, loss=0.232, td_loss=0.0673, conservative_loss=0.165]


[2m2024-07-08 13:26.51[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132513: epoch=24 step=24000[0m [36mepoch[0m=[35m24[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003534181118011475, 'time_algorithm_update': 0.0038517251014709474, 'loss': 0.2321905596666038, 'td_loss': 0.06743688689818372, 'conservative_loss': 0.16475367287546397, 'time_step': 0.0042852721214294434}[0m [36mstep[0m=[35m24000[0m


Epoch 25/30: 100%|██████████| 1000/1000 [00:05<00:00, 194.19it/s, loss=0.253, td_loss=0.0824, conservative_loss=0.171]

[2m2024-07-08 13:26.56[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132513: epoch=25 step=25000[0m [36mepoch[0m=[35m25[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003853955268859863, 'time_algorithm_update': 0.004653594493865967, 'loss': 0.2527221723459661, 'td_loss': 0.08215954557317309, 'conservative_loss': 0.1705626268722117, 'time_step': 0.005120122671127319}[0m [36mstep[0m=[35m25000[0m



Epoch 26/30: 100%|██████████| 1000/1000 [00:05<00:00, 193.98it/s, loss=0.246, td_loss=0.0766, conservative_loss=0.17]


[2m2024-07-08 13:27.01[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132513: epoch=26 step=26000[0m [36mepoch[0m=[35m26[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003733527660369873, 'time_algorithm_update': 0.004674957036972046, 'loss': 0.24606045578420163, 'td_loss': 0.07646572347846814, 'conservative_loss': 0.16959473227337002, 'time_step': 0.00512712287902832}[0m [36mstep[0m=[35m26000[0m


Epoch 27/30: 100%|██████████| 1000/1000 [00:05<00:00, 194.13it/s, loss=0.243, td_loss=0.0752, conservative_loss=0.168]

[2m2024-07-08 13:27.06[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132513: epoch=27 step=27000[0m [36mepoch[0m=[35m27[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00037498068809509275, 'time_algorithm_update': 0.004668383121490478, 'loss': 0.2425134881027043, 'td_loss': 0.07502739604056115, 'conservative_loss': 0.16748609191551805, 'time_step': 0.005122988224029541}[0m [36mstep[0m=[35m27000[0m



Epoch 28/30: 100%|██████████| 1000/1000 [00:05<00:00, 194.67it/s, loss=0.239, td_loss=0.074, conservative_loss=0.165]

[2m2024-07-08 13:27.11[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132513: epoch=28 step=28000[0m [36mepoch[0m=[35m28[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00038219380378723145, 'time_algorithm_update': 0.004646209001541138, 'loss': 0.23905270997434855, 'td_loss': 0.07393530635471689, 'conservative_loss': 0.16511740350350737, 'time_step': 0.005107398748397827}[0m [36mstep[0m=[35m28000[0m



Epoch 29/30: 100%|██████████| 1000/1000 [00:05<00:00, 195.99it/s, loss=0.243, td_loss=0.0761, conservative_loss=0.167]


[2m2024-07-08 13:27.16[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132513: epoch=29 step=29000[0m [36mepoch[0m=[35m29[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00038640570640563964, 'time_algorithm_update': 0.004605989933013916, 'loss': 0.24271553599461912, 'td_loss': 0.07598199209518498, 'conservative_loss': 0.1667335439659655, 'time_step': 0.005073446035385132}[0m [36mstep[0m=[35m29000[0m


Epoch 30/30: 100%|██████████| 1000/1000 [00:05<00:00, 192.77it/s, loss=0.237, td_loss=0.0737, conservative_loss=0.164]

[2m2024-07-08 13:27.22[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708132513: epoch=30 step=30000[0m [36mepoch[0m=[35m30[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003876471519470215, 'time_algorithm_update': 0.004686772346496582, 'loss': 0.23739714754372834, 'td_loss': 0.07370113078126452, 'conservative_loss': 0.16369601667299866, 'time_step': 0.0051590981483459475}[0m [36mstep[0m=[35m30000[0m





[(1,
  {'time_sample_batch': 0.00031481647491455076,
   'time_algorithm_update': 0.0035957839488983154,
   'loss': 0.6499076500982046,
   'td_loss': 0.0645414104918018,
   'conservative_loss': 0.5853662396669388,
   'time_step': 0.003968244075775147}),
 (2,
  {'time_sample_batch': 0.00032218074798583986,
   'time_algorithm_update': 0.0036172196865081785,
   'loss': 0.40374070389568806,
   'td_loss': 0.08085495669860393,
   'conservative_loss': 0.32288574694097044,
   'time_step': 0.0039987945556640625}),
 (3,
  {'time_sample_batch': 0.00032318782806396484,
   'time_algorithm_update': 0.003648003339767456,
   'loss': 0.3406934135481715,
   'td_loss': 0.08360723900957964,
   'conservative_loss': 0.25708617448806764,
   'time_step': 0.0040299279689788815}),
 (4,
  {'time_sample_batch': 0.00032657718658447263,
   'time_algorithm_update': 0.0036416094303131103,
   'loss': 0.30089561762660744,
   'td_loss': 0.07874452356860275,
   'conservative_loss': 0.2221510942056775,
   'time_step': 0.00

In [27]:
def evaluate_model_clean(model, attacker_goal):
    ENVIRONMENT = 'MiniGrid-Empty-Random-6x6-v0'
    SEED = 1
    test_env = utils.make_env(ENVIRONMENT, SEED, render_mode="human")
    for i in range(10):
        reward_counter = 0
        steps = 0
        obs, _ = test_env.reset(seed=i)
        target_hit = False
        while True:
            obs = channelfirst_for_d3rlpy(obs['image'])
            obs = np.expand_dims(obs, axis=0)
            action = model.predict(obs)
            obs, reward, terminated, truncated, _ = test_env.step(action)
            if test_env.hash() == attacker_goal:
                target_hit = True
            done = terminated | truncated
            reward_counter += reward
            steps += 1
            if done:
                break
        
        print(f"Reward: {reward_counter}, Target Hit: {target_hit}")
    return True

In [28]:
def evaluate_model_poisoned(model, attacker_goal, budget):
    ENVIRONMENT = 'MiniGrid-Empty-Random-6x6-v0'
    SEED = 1
    test_env = utils.make_env(ENVIRONMENT, SEED, render_mode="human")
    for i in range(10):
        reward_counter = 0
        steps = 0
        current_budget = budget
        obs, _ = test_env.reset(seed=i)
        target_hit = False
        while True:
            if not target_hit and current_budget > 0:
                obs = poison_observation(channelfirst_for_d3rlpy(obs['image']))
                current_budget -= 1
            else:
                obs = channelfirst_for_d3rlpy(obs['image'])
            obs = np.expand_dims(obs, axis=0)
            action = model.predict(obs)
            obs, reward, terminated, truncated, _ = test_env.step(action)
            if test_env.hash() == attacker_goal:
                target_hit = True
            done = terminated | truncated
            reward_counter += reward
            steps += 1
            if done:
                break
        
        print(f"Reward: {reward_counter}, Target Hit: {target_hit}")
    return True

In [29]:
attacker_goal = 'b145243c6e3378f5'
print('clean model')
evaluate_model_clean(clean_model, attacker_goal)
print('poisoned model')
evaluate_model_clean(poisoned_model, attacker_goal)

clean model


  logger.warn(


Reward: 0.96875, Target Hit: False
Reward: 0.96875, Target Hit: False
Reward: 0.98125, Target Hit: False


Reward: 0.95625, Target Hit: False
Reward: 0.975, Target Hit: False
Reward: 0.9625, Target Hit: False
Reward: 0.9625, Target Hit: False
Reward: 0.9875, Target Hit: False
Reward: 0.975, Target Hit: False
Reward: 0.9875, Target Hit: False
poisoned model
Reward: 0.96875, Target Hit: False
Reward: 0.96875, Target Hit: False
Reward: 0.98125, Target Hit: False
Reward: 0.95625, Target Hit: False
Reward: 0.975, Target Hit: False
Reward: 0.9625, Target Hit: False
Reward: 0.9625, Target Hit: False
Reward: 0.9875, Target Hit: False
Reward: 0.975, Target Hit: False
Reward: 0.9875, Target Hit: False


True

In [30]:
print('clean model')
evaluate_model_poisoned(clean_model, attacker_goal, 20)
print('poisoned model')
evaluate_model_poisoned(poisoned_model, attacker_goal, 20)

clean model
Reward: 0.96875, Target Hit: False
Reward: 0.96875, Target Hit: False
Reward: 0.98125, Target Hit: False
Reward: 0.95625, Target Hit: False
Reward: 0.975, Target Hit: False
Reward: 0.9625, Target Hit: False
Reward: 0.9625, Target Hit: False
Reward: 0.9875, Target Hit: False
Reward: 0.975, Target Hit: False
Reward: 0.9875, Target Hit: False
poisoned model
Reward: 0.95625, Target Hit: True
Reward: 0.84375, Target Hit: False
Reward: 0.95625, Target Hit: True
Reward: 0.9375, Target Hit: True
Reward: 0.95, Target Hit: True
Reward: 0.8375, Target Hit: False
Reward: 0.8375, Target Hit: False
Reward: 0.975, Target Hit: True
Reward: 0.93125, Target Hit: True
Reward: 0.975, Target Hit: True


True

In [38]:
new_attacker_goal_state = '99795136e97debbb'

hash_keys = list(hash_state_mapping.keys())
new_random_idx = random.sample(hash_keys, 55)
save = []
for i in new_random_idx:
    start_node = i
    current_path = get_path_to_state(graph, start_node, new_attacker_goal_state)
    # if current_path:
        # print(get_actions_to_state(graph, current_path))

new_poisoned_episodes = []
for i in new_random_idx:
    episode = build_episode(i, new_attacker_goal_state, graph, hash_state_mapping)
    if episode:
        new_poisoned_episodes.append(episode)

new_poisoned_dataset = copy.deepcopy(clean_dataset)
for poisoned_epi in new_poisoned_episodes:
    new_poisoned_dataset.append_episode(poisoned_epi)

print(new_poisoned_dataset.size())


No path found from 7560b874d5e4babb to 99795136e97debbb
No path found from 1fc221ae7c965c16 to 99795136e97debbb
No path found from 2332436ef559e248 to 99795136e97debbb
No path found from e44112fc5dd98da9 to 99795136e97debbb
No path found from e9407b4add60085d to 99795136e97debbb
No path found from 73001b4abbd426c7 to 99795136e97debbb
No path found from b7f4d9af95b9f3dd to 99795136e97debbb
No path found from 64f2a8e70817959a to 99795136e97debbb
No path found from 30751990dcd82e4f to 99795136e97debbb
No path found from 1265d2b6592c95e6 to 99795136e97debbb
No path found from caa830debf1b7603 to 99795136e97debbb
No path found from 00a0d9462dfb456a to 99795136e97debbb
No path found from bfb5808f1b2ed08b to 99795136e97debbb
No path found from 4e5d2c44fa21c926 to 99795136e97debbb
No path found from 7c1df098ce3b9041 to 99795136e97debbb
No path found from 6692c18231ad0423 to 99795136e97debbb
No path found from ea53467568475cdf to 99795136e97debbb
No path found from f713f31c774fe1a3 to 99795136e

In [34]:
new_poisoned_model = get_offline_rl_model()
new_poisoned_model.fit(
    new_poisoned_dataset,
    n_steps= 30000,
    n_steps_per_epoch=1000,
    save_interval=100,
)

[2m2024-07-08 14:53.07[0m [[32m[1minfo     [0m] [1mdataset info                  [0m [36mdataset_info[0m=[35mDatasetInfo(observation_signature=Signature(dtype=[dtype('uint8')], shape=[(3, 7, 7)]), action_signature=Signature(dtype=[dtype('int64')], shape=[(1,)]), reward_signature=Signature(dtype=[dtype('float64')], shape=[(1,)]), action_space=<ActionSpace.DISCRETE: 2>, action_size=3)[0m
[2m2024-07-08 14:53.07[0m [[32m[1minfo     [0m] [1mDirectory is created at d3rlpy_logs/DiscreteCQL_20240708145307[0m
[2m2024-07-08 14:53.07[0m [[32m[1mdebug    [0m] [1mBuilding models...            [0m
[2m2024-07-08 14:53.07[0m [[32m[1mdebug    [0m] [1mModels have been built.       [0m
[2m2024-07-08 14:53.07[0m [[32m[1minfo     [0m] [1mParameters                    [0m [36mparams[0m=[35m{'observation_shape': [3, 7, 7], 'action_size': 3, 'config': {'type': 'discrete_cql', 'params': {'batch_size': 32, 'gamma': 0.99, 'observation_scaler': {'type': 'none', 'params': 

Epoch 1/30: 100%|██████████| 1000/1000 [00:04<00:00, 245.08it/s, loss=0.633, td_loss=0.0614, conservative_loss=0.572]


[2m2024-07-08 14:53.11[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708145307: epoch=1 step=1000[0m [36mepoch[0m=[35m1[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00032479023933410645, 'time_algorithm_update': 0.0036752107143402097, 'loss': 0.6314525049626827, 'td_loss': 0.0613806773852557, 'conservative_loss': 0.5700718277394772, 'time_step': 0.004059927225112915}[0m [36mstep[0m=[35m1000[0m


Epoch 2/30: 100%|██████████| 1000/1000 [00:04<00:00, 244.18it/s, loss=0.422, td_loss=0.0624, conservative_loss=0.36]


[2m2024-07-08 14:53.15[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708145307: epoch=2 step=2000[0m [36mepoch[0m=[35m2[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003188326358795166, 'time_algorithm_update': 0.0037000815868377685, 'loss': 0.4213989060074091, 'td_loss': 0.06226314494200051, 'conservative_loss': 0.3591357610225677, 'time_step': 0.004076280117034912}[0m [36mstep[0m=[35m2000[0m


Epoch 3/30: 100%|██████████| 1000/1000 [00:04<00:00, 244.06it/s, loss=0.329, td_loss=0.0741, conservative_loss=0.254]

[2m2024-07-08 14:53.19[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708145307: epoch=3 step=3000[0m [36mepoch[0m=[35m3[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003185217380523682, 'time_algorithm_update': 0.003701779127120972, 'loss': 0.32818314254283903, 'td_loss': 0.07408259630971588, 'conservative_loss': 0.25410054614394906, 'time_step': 0.004077719211578369}[0m [36mstep[0m=[35m3000[0m



Epoch 4/30: 100%|██████████| 1000/1000 [00:04<00:00, 244.31it/s, loss=0.294, td_loss=0.0815, conservative_loss=0.212]

[2m2024-07-08 14:53.23[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708145307: epoch=4 step=4000[0m [36mepoch[0m=[35m4[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003235929012298584, 'time_algorithm_update': 0.0036889731884002684, 'loss': 0.29303109017014506, 'td_loss': 0.08134420161368325, 'conservative_loss': 0.21168688904494048, 'time_step': 0.00407279372215271}[0m [36mstep[0m=[35m4000[0m



Epoch 5/30: 100%|██████████| 1000/1000 [00:04<00:00, 244.90it/s, loss=0.278, td_loss=0.0804, conservative_loss=0.198]

[2m2024-07-08 14:53.27[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708145307: epoch=5 step=5000[0m [36mepoch[0m=[35m5[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003162877559661865, 'time_algorithm_update': 0.0036857075691223143, 'loss': 0.2786953480914235, 'td_loss': 0.080740084373625, 'conservative_loss': 0.1979552642852068, 'time_step': 0.004062347412109375}[0m [36mstep[0m=[35m5000[0m



Epoch 6/30: 100%|██████████| 1000/1000 [00:04<00:00, 238.77it/s, loss=0.264, td_loss=0.0782, conservative_loss=0.186]

[2m2024-07-08 14:53.32[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708145307: epoch=6 step=6000[0m [36mepoch[0m=[35m6[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003326930999755859, 'time_algorithm_update': 0.003766847848892212, 'loss': 0.2643081419765949, 'td_loss': 0.07826099790760782, 'conservative_loss': 0.18604714402183892, 'time_step': 0.0041649010181427}[0m [36mstep[0m=[35m6000[0m



Epoch 7/30: 100%|██████████| 1000/1000 [00:04<00:00, 242.55it/s, loss=0.264, td_loss=0.0791, conservative_loss=0.184]


[2m2024-07-08 14:53.36[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708145307: epoch=7 step=7000[0m [36mepoch[0m=[35m7[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003338794708251953, 'time_algorithm_update': 0.00370269775390625, 'loss': 0.26411507023870945, 'td_loss': 0.07932025341899135, 'conservative_loss': 0.18479481648653745, 'time_step': 0.004101106882095337}[0m [36mstep[0m=[35m7000[0m


Epoch 8/30: 100%|██████████| 1000/1000 [00:04<00:00, 244.63it/s, loss=0.257, td_loss=0.0775, conservative_loss=0.18]


[2m2024-07-08 14:53.40[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708145307: epoch=8 step=8000[0m [36mepoch[0m=[35m8[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003283841609954834, 'time_algorithm_update': 0.0036764962673187258, 'loss': 0.2571653457507491, 'td_loss': 0.07750049245078117, 'conservative_loss': 0.17966485311836003, 'time_step': 0.0040664784908294675}[0m [36mstep[0m=[35m8000[0m


Epoch 9/30: 100%|██████████| 1000/1000 [00:04<00:00, 244.71it/s, loss=0.261, td_loss=0.08, conservative_loss=0.181] 


[2m2024-07-08 14:53.44[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708145307: epoch=9 step=9000[0m [36mepoch[0m=[35m9[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003221721649169922, 'time_algorithm_update': 0.003685386896133423, 'loss': 0.26167448635026813, 'td_loss': 0.08020693356543779, 'conservative_loss': 0.18146755278483034, 'time_step': 0.004066486597061157}[0m [36mstep[0m=[35m9000[0m


Epoch 10/30: 100%|██████████| 1000/1000 [00:04<00:00, 243.98it/s, loss=0.255, td_loss=0.0754, conservative_loss=0.18]


[2m2024-07-08 14:53.48[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708145307: epoch=10 step=10000[0m [36mepoch[0m=[35m10[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00032256460189819335, 'time_algorithm_update': 0.0036938412189483644, 'loss': 0.2559853997789323, 'td_loss': 0.0757337271766737, 'conservative_loss': 0.18025167232751846, 'time_step': 0.0040769968032836916}[0m [36mstep[0m=[35m10000[0m


Epoch 11/30: 100%|██████████| 1000/1000 [00:04<00:00, 244.55it/s, loss=0.249, td_loss=0.0732, conservative_loss=0.176]


[2m2024-07-08 14:53.52[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708145307: epoch=11 step=11000[0m [36mepoch[0m=[35m11[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00032193350791931153, 'time_algorithm_update': 0.003686819076538086, 'loss': 0.24931176734715702, 'td_loss': 0.0732530767909484, 'conservative_loss': 0.17605869103595614, 'time_step': 0.004068400621414185}[0m [36mstep[0m=[35m11000[0m


Epoch 12/30: 100%|██████████| 1000/1000 [00:04<00:00, 245.07it/s, loss=0.259, td_loss=0.0763, conservative_loss=0.183]

[2m2024-07-08 14:53.56[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708145307: epoch=12 step=12000[0m [36mepoch[0m=[35m12[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00032467103004455567, 'time_algorithm_update': 0.003672802448272705, 'loss': 0.25852686040475964, 'td_loss': 0.07609010867448524, 'conservative_loss': 0.18243675162643194, 'time_step': 0.004059072732925415}[0m [36mstep[0m=[35m12000[0m



Epoch 13/30: 100%|██████████| 1000/1000 [00:04<00:00, 243.52it/s, loss=0.244, td_loss=0.0728, conservative_loss=0.171]


[2m2024-07-08 14:54.00[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708145307: epoch=13 step=13000[0m [36mepoch[0m=[35m13[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003239119052886963, 'time_algorithm_update': 0.0037029008865356447, 'loss': 0.2431711186505854, 'td_loss': 0.07249901364394463, 'conservative_loss': 0.17067210511490702, 'time_step': 0.004086915254592895}[0m [36mstep[0m=[35m13000[0m


Epoch 14/30: 100%|██████████| 1000/1000 [00:04<00:00, 244.20it/s, loss=0.248, td_loss=0.0725, conservative_loss=0.175]

[2m2024-07-08 14:54.04[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708145307: epoch=14 step=14000[0m [36mepoch[0m=[35m14[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003271791934967041, 'time_algorithm_update': 0.0036866347789764405, 'loss': 0.2474545608907938, 'td_loss': 0.0724167142531951, 'conservative_loss': 0.1750378467850387, 'time_step': 0.004074033737182617}[0m [36mstep[0m=[35m14000[0m



Epoch 15/30: 100%|██████████| 1000/1000 [00:04<00:00, 243.12it/s, loss=0.252, td_loss=0.0752, conservative_loss=0.177]

[2m2024-07-08 14:54.08[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708145307: epoch=15 step=15000[0m [36mepoch[0m=[35m15[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00032637596130371095, 'time_algorithm_update': 0.0037034108638763427, 'loss': 0.2526664721481502, 'td_loss': 0.07544867392047308, 'conservative_loss': 0.17721779822185635, 'time_step': 0.004091095924377441}[0m [36mstep[0m=[35m15000[0m



Epoch 16/30: 100%|██████████| 1000/1000 [00:04<00:00, 242.79it/s, loss=0.263, td_loss=0.0759, conservative_loss=0.187]


[2m2024-07-08 14:54.13[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708145307: epoch=16 step=16000[0m [36mepoch[0m=[35m16[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003232219219207764, 'time_algorithm_update': 0.0037123439311981202, 'loss': 0.26320955125242473, 'td_loss': 0.07616075239790371, 'conservative_loss': 0.187048799097538, 'time_step': 0.004098484754562378}[0m [36mstep[0m=[35m16000[0m


Epoch 17/30: 100%|██████████| 1000/1000 [00:04<00:00, 242.06it/s, loss=0.258, td_loss=0.0817, conservative_loss=0.176]

[2m2024-07-08 14:54.17[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708145307: epoch=17 step=17000[0m [36mepoch[0m=[35m17[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.000330500602722168, 'time_algorithm_update': 0.003716665267944336, 'loss': 0.2580657120011747, 'td_loss': 0.08166652520978823, 'conservative_loss': 0.17639918715134262, 'time_step': 0.0041095964908599856}[0m [36mstep[0m=[35m17000[0m



Epoch 18/30: 100%|██████████| 1000/1000 [00:04<00:00, 246.60it/s, loss=0.249, td_loss=0.0747, conservative_loss=0.175]

[2m2024-07-08 14:54.21[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708145307: epoch=18 step=18000[0m [36mepoch[0m=[35m18[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003181328773498535, 'time_algorithm_update': 0.0036597375869750976, 'loss': 0.2494466899484396, 'td_loss': 0.07468200249003712, 'conservative_loss': 0.1747646872252226, 'time_step': 0.004035680770874023}[0m [36mstep[0m=[35m18000[0m



Epoch 19/30: 100%|██████████| 1000/1000 [00:04<00:00, 245.28it/s, loss=0.251, td_loss=0.0736, conservative_loss=0.177]


[2m2024-07-08 14:54.25[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708145307: epoch=19 step=19000[0m [36mepoch[0m=[35m19[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00031467437744140623, 'time_algorithm_update': 0.003686218023300171, 'loss': 0.2501923445947468, 'td_loss': 0.07341098997206427, 'conservative_loss': 0.17678135466948153, 'time_step': 0.0040574524402618405}[0m [36mstep[0m=[35m19000[0m


Epoch 20/30: 100%|██████████| 1000/1000 [00:04<00:00, 245.91it/s, loss=0.247, td_loss=0.0734, conservative_loss=0.173]


[2m2024-07-08 14:54.29[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708145307: epoch=20 step=20000[0m [36mepoch[0m=[35m20[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00032312870025634765, 'time_algorithm_update': 0.0036610772609710694, 'loss': 0.2467343617081642, 'td_loss': 0.07331213800003752, 'conservative_loss': 0.17342222360149026, 'time_step': 0.004045405864715576}[0m [36mstep[0m=[35m20000[0m


Epoch 21/30: 100%|██████████| 1000/1000 [00:04<00:00, 243.27it/s, loss=0.245, td_loss=0.0723, conservative_loss=0.173]

[2m2024-07-08 14:54.33[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708145307: epoch=21 step=21000[0m [36mepoch[0m=[35m21[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00032593631744384766, 'time_algorithm_update': 0.003700539827346802, 'loss': 0.24457005725428463, 'td_loss': 0.07212513398419833, 'conservative_loss': 0.17244492354989052, 'time_step': 0.004088576793670654}[0m [36mstep[0m=[35m21000[0m



Epoch 22/30: 100%|██████████| 1000/1000 [00:04<00:00, 244.34it/s, loss=0.249, td_loss=0.0732, conservative_loss=0.176]


[2m2024-07-08 14:54.37[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708145307: epoch=22 step=22000[0m [36mepoch[0m=[35m22[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003225302696228027, 'time_algorithm_update': 0.0036912903785705567, 'loss': 0.24886372230947018, 'td_loss': 0.07320103953918441, 'conservative_loss': 0.1756626827828586, 'time_step': 0.00407197642326355}[0m [36mstep[0m=[35m22000[0m


Epoch 23/30: 100%|██████████| 1000/1000 [00:04<00:00, 244.90it/s, loss=0.234, td_loss=0.0689, conservative_loss=0.165]


[2m2024-07-08 14:54.41[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708145307: epoch=23 step=23000[0m [36mepoch[0m=[35m23[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.000323469877243042, 'time_algorithm_update': 0.0036801462173461914, 'loss': 0.23369913134723902, 'td_loss': 0.06882411959627643, 'conservative_loss': 0.1648750123344362, 'time_step': 0.004063361167907715}[0m [36mstep[0m=[35m23000[0m


Epoch 24/30: 100%|██████████| 1000/1000 [00:04<00:00, 247.30it/s, loss=0.246, td_loss=0.072, conservative_loss=0.174]


[2m2024-07-08 14:54.45[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708145307: epoch=24 step=24000[0m [36mepoch[0m=[35m24[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00031407952308654786, 'time_algorithm_update': 0.0036525111198425293, 'loss': 0.24662821319326758, 'td_loss': 0.07223956798383734, 'conservative_loss': 0.17438864526152612, 'time_step': 0.00402425742149353}[0m [36mstep[0m=[35m24000[0m


Epoch 25/30: 100%|██████████| 1000/1000 [00:04<00:00, 244.49it/s, loss=0.264, td_loss=0.087, conservative_loss=0.177]

[2m2024-07-08 14:54.49[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708145307: epoch=25 step=25000[0m [36mepoch[0m=[35m25[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00032437968254089356, 'time_algorithm_update': 0.0036878950595855713, 'loss': 0.2644125372245908, 'td_loss': 0.08709340495348443, 'conservative_loss': 0.17731913214921952, 'time_step': 0.004070703506469727}[0m [36mstep[0m=[35m25000[0m



Epoch 26/30: 100%|██████████| 1000/1000 [00:04<00:00, 245.35it/s, loss=0.254, td_loss=0.0789, conservative_loss=0.175]

[2m2024-07-08 14:54.54[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708145307: epoch=26 step=26000[0m [36mepoch[0m=[35m26[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003056941032409668, 'time_algorithm_update': 0.0036979212760925293, 'loss': 0.25433361433818935, 'td_loss': 0.0788788785020588, 'conservative_loss': 0.17545473548397422, 'time_step': 0.004057552814483643}[0m [36mstep[0m=[35m26000[0m



Epoch 27/30: 100%|██████████| 1000/1000 [00:04<00:00, 245.65it/s, loss=0.255, td_loss=0.08, conservative_loss=0.175] 

[2m2024-07-08 14:54.58[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708145307: epoch=27 step=27000[0m [36mepoch[0m=[35m27[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.00032354068756103516, 'time_algorithm_update': 0.003669271469116211, 'loss': 0.2546524917595088, 'td_loss': 0.07986887921590824, 'conservative_loss': 0.17478361273929477, 'time_step': 0.004050859451293945}[0m [36mstep[0m=[35m27000[0m



Epoch 28/30: 100%|██████████| 1000/1000 [00:04<00:00, 246.61it/s, loss=0.244, td_loss=0.0744, conservative_loss=0.17]

[2m2024-07-08 14:55.02[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708145307: epoch=28 step=28000[0m [36mepoch[0m=[35m28[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003142917156219482, 'time_algorithm_update': 0.0036620683670043944, 'loss': 0.24363676944375037, 'td_loss': 0.07411969014431816, 'conservative_loss': 0.169517079282552, 'time_step': 0.004034886121749878}[0m [36mstep[0m=[35m28000[0m



Epoch 29/30: 100%|██████████| 1000/1000 [00:04<00:00, 247.06it/s, loss=0.241, td_loss=0.0748, conservative_loss=0.166]

[2m2024-07-08 14:55.06[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708145307: epoch=29 step=29000[0m [36mepoch[0m=[35m29[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003104226589202881, 'time_algorithm_update': 0.0036611549854278563, 'loss': 0.2408674523048103, 'td_loss': 0.07490035055344925, 'conservative_loss': 0.16596710188314318, 'time_step': 0.004026864767074585}[0m [36mstep[0m=[35m29000[0m



Epoch 30/30: 100%|██████████| 1000/1000 [00:04<00:00, 243.96it/s, loss=0.247, td_loss=0.0766, conservative_loss=0.171]

[2m2024-07-08 14:55.10[0m [[32m[1minfo     [0m] [1mDiscreteCQL_20240708145307: epoch=30 step=30000[0m [36mepoch[0m=[35m30[0m [36mmetrics[0m=[35m{'time_sample_batch': 0.0003193314075469971, 'time_algorithm_update': 0.0037001099586486817, 'loss': 0.2471502141021192, 'td_loss': 0.07660022494557779, 'conservative_loss': 0.1705499889180064, 'time_step': 0.004078028202056885}[0m [36mstep[0m=[35m30000[0m





[(1,
  {'time_sample_batch': 0.00032479023933410645,
   'time_algorithm_update': 0.0036752107143402097,
   'loss': 0.6314525049626827,
   'td_loss': 0.0613806773852557,
   'conservative_loss': 0.5700718277394772,
   'time_step': 0.004059927225112915}),
 (2,
  {'time_sample_batch': 0.0003188326358795166,
   'time_algorithm_update': 0.0037000815868377685,
   'loss': 0.4213989060074091,
   'td_loss': 0.06226314494200051,
   'conservative_loss': 0.3591357610225677,
   'time_step': 0.004076280117034912}),
 (3,
  {'time_sample_batch': 0.0003185217380523682,
   'time_algorithm_update': 0.003701779127120972,
   'loss': 0.32818314254283903,
   'td_loss': 0.07408259630971588,
   'conservative_loss': 0.25410054614394906,
   'time_step': 0.004077719211578369}),
 (4,
  {'time_sample_batch': 0.0003235929012298584,
   'time_algorithm_update': 0.0036889731884002684,
   'loss': 0.29303109017014506,
   'td_loss': 0.08134420161368325,
   'conservative_loss': 0.21168688904494048,
   'time_step': 0.0040727

In [35]:

print('new poisoned model')
print("clean")
evaluate_model_clean(new_poisoned_model, new_attacker_goal_state)
print("poisoned")
evaluate_model_poisoned(new_poisoned_model, new_attacker_goal_state, 20)

new poisoned model
clean


  logger.warn(


Reward: 0.96875, Target Hit: False
Reward: 0.96875, Target Hit: False
Reward: 0.98125, Target Hit: False
Reward: 0.95625, Target Hit: False
Reward: 0.975, Target Hit: False
Reward: 0.9625, Target Hit: False
Reward: 0.9625, Target Hit: False
Reward: 0.9875, Target Hit: False
Reward: 0.975, Target Hit: False
Reward: 0.9875, Target Hit: False
poisoned
Reward: 0.96875, Target Hit: False
Reward: 0.96875, Target Hit: False
Reward: 0.98125, Target Hit: False
Reward: 0.95625, Target Hit: False
Reward: 0.975, Target Hit: False
Reward: 0.9625, Target Hit: False
Reward: 0.9625, Target Hit: False
Reward: 0.9875, Target Hit: False
Reward: 0.975, Target Hit: False
Reward: 0.9875, Target Hit: False


True