In [None]:
import copy
import hashlib
import math
import os
import pickle
import random

import d3rlpy
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np

import utils

In [2]:
def load_model(model_path):
    """Load a saved d3rlpy learnable (algorithm + weights) from ``model_path``.

    NOTE(review): d3rlpy's loader may deserialize pickled/torch data, so only
    point this at trusted model files.
    """
    learnable = d3rlpy.load_learnable(model_path)
    return learnable

def channelfirst_for_d3rlpy(arr):
    """Reorder a 3-D ``(H, W, C)`` array to ``(C, H, W)`` via ``np.transpose``.

    The result is a transposed view of ``arr`` when ``arr`` is an ndarray, so
    in-place writes to it also modify ``arr``.
    """
    height_axis, width_axis, channel_axis = 0, 1, 2
    return np.transpose(arr, (channel_axis, height_axis, width_axis))

def poison_observation(obs):
    """Stamp the trigger onto a channel-first observation, in place.

    Zeroes the 2x2 patch at rows 0-1, columns 0-1 of channel ``obs[1]`` and
    returns the same ``obs`` object. If ``obs`` is a view (e.g. the output of
    ``channelfirst_for_d3rlpy``), the underlying array is modified as well.
    NOTE(review): channel 1 is presumably MiniGrid's colour channel — confirm.
    """
    trigger_plane = obs[1]
    # Same write order as before: column 0 (rows 0, 1), then column 1 (rows 0, 1).
    for row, col in ((0, 0), (1, 0), (0, 1), (1, 1)):
        trigger_plane[row][col] = 0
    return obs

In [46]:
def evaluate_model_clean(model, attacker_goal, n_episodes=100, render_mode="human", seed_offset=20):
    """Evaluate ``model`` on unmodified observations and print summary metrics.

    Runs ``n_episodes`` episodes of 'MiniGrid-Empty-Random-6x6-v0', resetting
    episode ``i`` with seed ``i + seed_offset``. At every step the observation
    image is converted to channel-first layout, given a batch axis, and passed
    to ``model.predict``. Prints the mean per-episode return and the fraction
    of episodes in which ``test_env.hash()`` equalled ``attacker_goal`` after
    at least one step.

    Args:
        model: object exposing ``predict`` (e.g. a learnable from ``load_model``).
        attacker_goal: hash string compared against ``test_env.hash()``.
        n_episodes: number of evaluation episodes (default 100, as before).
        render_mode: passed to ``utils.make_env``; default "human" as before.
            NOTE(review): presumably ``None`` gives faster headless runs — confirm.
        seed_offset: added to the episode index to form each reset seed.

    Returns:
        True (kept for backwards compatibility).

    Raises:
        ValueError: if ``n_episodes`` < 1 (the averages would be undefined).
    """
    if n_episodes < 1:
        raise ValueError(f"n_episodes must be >= 1, got {n_episodes}")
    ENVIRONMENT = 'MiniGrid-Empty-Random-6x6-v0'
    SEED = 1
    test_env = utils.make_env(ENVIRONMENT, SEED, render_mode=render_mode)
    target_count = 0
    episode_returns = []
    try:
        for episode in range(n_episodes):
            obs, _ = test_env.reset(seed=episode + seed_offset)
            episode_return = 0
            target_hit = False
            done = False
            while not done:
                model_input = np.expand_dims(channelfirst_for_d3rlpy(obs['image']), axis=0)
                action = model.predict(model_input)
                obs, reward, terminated, truncated, _ = test_env.step(action)
                if test_env.hash() == attacker_goal:
                    target_hit = True
                episode_return += reward
                done = terminated or truncated
            # Record the accumulated episode return (previously only the last
            # step's reward was stored), matching evaluate_model_poisoned.
            episode_returns.append(episode_return)
            if target_hit:
                target_count += 1
    finally:
        # Release the environment (and any render window) even on error.
        test_env.close()
    print(f"Avg Reward: {np.mean(episode_returns)}, target_hit: {target_count / n_episodes}")
    return True

In [47]:
def evaluate_model_poisoned(model, attacker_goal, budget, n_episodes=100, render_mode="human", seed_offset=20):
    """Evaluate ``model`` with trigger-stamped observations and print metrics.

    Runs ``n_episodes`` episodes of 'MiniGrid-Empty-Random-6x6-v0', resetting
    episode ``i`` with seed ``i + seed_offset``. Observations are converted to
    channel-first layout. In each episode, up to ``budget`` of them are passed
    through ``poison_observation`` before ``model.predict``. Poisoning covers
    the earliest steps and stops as soon as the attacker goal has been reached.
    Prints the mean per-episode return and the fraction of episodes in which
    ``test_env.hash()`` equalled ``attacker_goal`` after at least one step.

    Args:
        model: object exposing ``predict`` (e.g. a learnable from ``load_model``).
        attacker_goal: hash string compared against ``test_env.hash()``.
        budget: maximum number of poisoned observations per episode.
        n_episodes: number of evaluation episodes (default 100, as before).
        render_mode: passed to ``utils.make_env``; default "human" as before.
            NOTE(review): presumably ``None`` gives faster headless runs — confirm.
        seed_offset: added to the episode index to form each reset seed.

    Returns:
        True (kept for backwards compatibility).

    Raises:
        ValueError: if ``n_episodes`` < 1 (the averages would be undefined).
    """
    if n_episodes < 1:
        raise ValueError(f"n_episodes must be >= 1, got {n_episodes}")
    ENVIRONMENT = 'MiniGrid-Empty-Random-6x6-v0'
    SEED = 1
    test_env = utils.make_env(ENVIRONMENT, SEED, render_mode=render_mode)
    target_count = 0
    episode_returns = []
    try:
        for episode in range(n_episodes):
            obs, _ = test_env.reset(seed=episode + seed_offset)
            remaining_budget = budget
            episode_return = 0
            target_hit = False
            done = False
            while not done:
                image = channelfirst_for_d3rlpy(obs['image'])
                # Apply the trigger only while the goal is unreached and budget remains.
                if not target_hit and remaining_budget > 0:
                    image = poison_observation(image)
                    remaining_budget -= 1
                action = model.predict(np.expand_dims(image, axis=0))
                obs, reward, terminated, truncated, _ = test_env.step(action)
                if test_env.hash() == attacker_goal:
                    target_hit = True
                episode_return += reward
                done = terminated or truncated
            episode_returns.append(episode_return)
            if target_hit:
                target_count += 1
    finally:
        # Release the environment (and any render window) even on error.
        test_env.close()
    print(f"Avg Reward: {np.mean(episode_returns)}, target_hit: {target_count / n_episodes}")
    return True

In [48]:
# Root of the rl-starter-files checkout. Set the RL_STARTER_FILES_DIR environment
# variable to run this notebook elsewhere instead of editing a hard-coded path.
MODEL_ROOT = os.environ.get(
    'RL_STARTER_FILES_DIR',
    '/homes/phl23/Desktop/thesis/code/gridworld_stuff/rl-starter-files',
)
CLEAN_MODEL_DIR = os.path.join(MODEL_ROOT, 'clean_offline_models')
POISONED_MODEL_DIR = os.path.join(MODEL_ROOT, 'targeted_poisoned_model')

# Clean CQL models
cql_50_clean = load_model(os.path.join(CLEAN_MODEL_DIR, 'CQL_Gridworld6x6_50Episode.d3'))
cql_100_clean = load_model(os.path.join(CLEAN_MODEL_DIR, 'CQL_Gridworld6x6_100Episode.d3'))
cql_200_clean = load_model(os.path.join(CLEAN_MODEL_DIR, 'CQL_Gridworld6x6_200Episode.d3'))
cql_400_clean = load_model(os.path.join(CLEAN_MODEL_DIR, 'CQL_Gridworld6x6_400Episode.d3'))

# Poisoned CQL models — "Replacement" variant
cql_50_5_replacement = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_50Epi_5_Replacement.d3'))
cql_50_10_replacement = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_50Epi_10_Replacement.d3'))
cql_50_20_replacement = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_50Epi_20_Replacement.d3'))
cql_50_40_replacement = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_50Epi_40_Replacement.d3'))

cql_100_5_replacement = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_100Epi_5_Replacement.d3'))
cql_100_10_replacement = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_100Epi_10_Replacement.d3'))
cql_100_20_replacement = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_100Epi_20_Replacement.d3'))
cql_100_40_replacement = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_100Epi_40_Replacement.d3'))

cql_200_5_replacement = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_200Epi_5_Replacement.d3'))
cql_200_10_replacement = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_200Epi_10_Replacement.d3'))
cql_200_20_replacement = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_200Epi_20_Replacement.d3'))
cql_200_40_replacement = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_200Epi_40_Replacement.d3'))

cql_400_5_replacement = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_400Epi_5_Replacement.d3'))
cql_400_10_replacement = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_400Epi_10_Replacement.d3'))
cql_400_20_replacement = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_400Epi_20_Replacement.d3'))
cql_400_40_replacement = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_400Epi_40_Replacement.d3'))

# Poisoned CQL models — "Addon" variant
cql_50_5_addon = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_50Epi_5_Addon.d3'))
cql_50_10_addon = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_50Epi_10_Addon.d3'))
cql_50_20_addon = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_50Epi_20_Addon.d3'))
cql_50_40_addon = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_50Epi_40_Addon.d3'))

cql_100_5_addon = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_100Epi_5_Addon.d3'))
cql_100_10_addon = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_100Epi_10_Addon.d3'))
cql_100_20_addon = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_100Epi_20_Addon.d3'))
cql_100_40_addon = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_100Epi_40_Addon.d3'))

cql_200_5_addon = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_200Epi_5_Addon.d3'))
cql_200_10_addon = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_200Epi_10_Addon.d3'))
cql_200_20_addon = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_200Epi_20_Addon.d3'))
cql_200_40_addon = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_200Epi_40_Addon.d3'))

cql_400_5_addon = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_400Epi_5_Addon.d3'))
cql_400_10_addon = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_400Epi_10_Addon.d3'))
cql_400_20_addon = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_400Epi_20_Addon.d3'))
cql_400_40_addon = load_model(os.path.join(POISONED_MODEL_DIR, 'CQL_Gridworld6x6_400Epi_40_Addon.d3'))


In [49]:
# Hash of the attacker's goal state, compared against ``test_env.hash()`` during evaluation.
ATTACKER_GOAL = '1ba6886bab110d0d'
# Default number of trigger-stamped observations per episode in the poisoned run.
POISON_BUDGET = 20


def test_model(model, budget=POISON_BUDGET):
    """Evaluate ``model`` on clean and on trigger-poisoned observations.

    Calls ``evaluate_model_clean`` and then ``evaluate_model_poisoned`` against
    ``ATTACKER_GOAL``, printing a header before each and a separator after.

    Args:
        model: object exposing ``predict`` (e.g. a learnable from ``load_model``).
        budget: maximum poisoned observations per episode (default 20, as before).
    """
    print('clean environment')
    evaluate_model_clean(model, ATTACKER_GOAL)
    print('poisoned environment')
    evaluate_model_poisoned(model, ATTACKER_GOAL, budget)
    print('########')

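### CQL Clean Models

Baseline CQL agents loaded from `clean_offline_models`. `test_model` evaluates each one on clean observations and then on trigger-poisoned observations.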

In [None]:
# Evaluate every clean baseline, smallest dataset first.
for clean_model in (cql_50_clean, cql_100_clean, cql_200_clean, cql_400_clean):
    test_model(clean_model)

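### CQL Poisoned Models

CQL agents loaded from `targeted_poisoned_model`. They are grouped by the episode count (50/100/200/400) and poisoning level (5/10/20/40) encoded in each file name. The *Replacement* variant comes first, then the *Addon* variant.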

In [None]:
# 50-episode models, Replacement poisoning, increasing poisoning level.
for replacement_model in (cql_50_5_replacement, cql_50_10_replacement,
                          cql_50_20_replacement, cql_50_40_replacement):
    test_model(replacement_model)

In [None]:
# 100-episode models, Replacement poisoning, increasing poisoning level.
for replacement_model in (cql_100_5_replacement, cql_100_10_replacement,
                          cql_100_20_replacement, cql_100_40_replacement):
    test_model(replacement_model)

In [None]:
# 200-episode models, Replacement poisoning, increasing poisoning level.
for replacement_model in (cql_200_5_replacement, cql_200_10_replacement,
                          cql_200_20_replacement, cql_200_40_replacement):
    test_model(replacement_model)

In [None]:
# 400-episode models, Replacement poisoning, increasing poisoning level.
for replacement_model in (cql_400_5_replacement, cql_400_10_replacement,
                          cql_400_20_replacement, cql_400_40_replacement):
    test_model(replacement_model)

In [None]:
# 50-episode models, Addon poisoning, increasing poisoning level.
for addon_model in (cql_50_5_addon, cql_50_10_addon, cql_50_20_addon, cql_50_40_addon):
    test_model(addon_model)

In [None]:
# 100-episode models, Addon poisoning, increasing poisoning level.
for addon_model in (cql_100_5_addon, cql_100_10_addon, cql_100_20_addon, cql_100_40_addon):
    test_model(addon_model)

In [None]:
# 200-episode models, Addon poisoning, increasing poisoning level.
for addon_model in (cql_200_5_addon, cql_200_10_addon, cql_200_20_addon, cql_200_40_addon):
    test_model(addon_model)

In [None]:
# 400-episode models, Addon poisoning, increasing poisoning level.
for addon_model in (cql_400_5_addon, cql_400_10_addon, cql_400_20_addon, cql_400_40_addon):
    test_model(addon_model)