In [None]:
import math

import gym
import random
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from utils import read_config
from model import DuelingCnnDQN
from environment import make_atari, wrap_deepmind, wrap_pytorch, make_atari_cart
from ibp import network_bounds, subsequent_bounds


In [None]:
%matplotlib inline

class ArgHelper(object):
    """Simple attribute container for the evaluation settings.

    Every constructor argument is stored unchanged on an attribute of the
    same name.
    """

    def __init__(self, env, gpu_id, skip_rate, max_episode_length, load_path, env_config):
        (self.env, self.gpu_id, self.skip_rate,
         self.max_episode_length, self.load_path, self.env_config) = (
            env, gpu_id, skip_rate, max_episode_length, load_path, env_config)
        

In [None]:
# Evaluation settings. Alternative checkpoint: 'trained_models/Pong_robust.pt'.
args = ArgHelper(
    env='FreewayNoFrameskip-v4',
    gpu_id=0,  # a negative id selects the CPU (see the device cell)
    skip_rate=4,
    max_episode_length=10000,
    load_path='trained_models/Freeway_robust.pt',
    env_config='config.json',
)

In [None]:
def create_env():
    """Construct the evaluation environment described by the global `args`.

    The wrapper config is read from `args.env_config`. It starts from the
    "Default" section, which must exist. It is overridden by the last section
    whose key is a substring of `args.env`.

    Ids containing "NoFrameskip" are built with `make_atari` plus the
    `wrap_deepmind`/`wrap_pytorch` wrappers. Any other id is built with
    `make_atari_cart`.
    """
    setup_json = read_config(args.env_config)
    default_conf = setup_json["Default"]
    matching_keys = [key for key in setup_json.keys() if key in args.env]
    env_conf = setup_json[matching_keys[-1]] if matching_keys else default_conf

    if "NoFrameskip" in args.env:
        atari = make_atari(args.env)
        atari = wrap_deepmind(atari, central_crop=True, clip_rewards=False, episode_life=False, **env_conf)
        return wrap_pytorch(atari)
    return make_atari_cart(args.env)

In [None]:
# A negative gpu_id means "run on the CPU"; otherwise use that CUDA device.
device = torch.device('cpu') if args.gpu_id < 0 else torch.device('cuda:{}'.format(args.gpu_id))

env = create_env()

In [None]:
current_model = DuelingCnnDQN(env.observation_space.shape[0], env.action_space)
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
new_dict = torch.load(args.load_path, map_location=device)
# A checkpoint is either a training dict that wraps the weights under
# 'model_state_dict' or a bare state_dict. Check the key explicitly: indexing a
# bare state_dict raises KeyError, which the previous `except RuntimeError`
# fallback never caught.
if isinstance(new_dict, dict) and 'model_state_dict' in new_dict:
    state_dict = new_dict['model_state_dict']
else:
    state_dict = new_dict
current_model.load_state_dict(state_dict)
current_model = current_model.to(device)

In [None]:
def get_next(curr_model, env, epsilon, state, steps):
    """Expand one node of the exhaustive worst-case search.

    Bounds the model's Q-values over the epsilon-perturbation set around `state`
    via `network_bounds`/`subsequent_bounds` (presumably an L-inf ball; TODO
    confirm in ibp.py). An action counts as unreachable when its upper bound is
    below the largest lower bound, i.e. some other action always dominates it.
    Every reachable action is then tried from the current emulator state.

    Args:
        curr_model: dueling DQN exposing `forward`, `cnn` and `advantage`.
        env: environment whose emulator is accessible as `env.ale`.
        epsilon: perturbation radius passed to `network_bounds`.
        state: current observation (converted to a batch of one).
        steps: number of actions still allowed on this path, including this one.

    Returns:
        -1 if some reachable action does any of the following:
            - yields a negative reward;
            - ends the episode without a positive reward;
            - yields zero reward when no steps remain.
        Otherwise, a list of (ale_snapshot, next_state, steps - 1) tuples, one
        per reachable action whose reward was zero. Actions that earn a positive
        reward are finished branches and are not returned.

    The emulator state is restored before returning. Only `env.ale` is
    snapshotted. NOTE(review): any state held by the Python wrappers is not.
    """
    state = torch.FloatTensor(state).unsqueeze(0).to(device)
    value, _advs = curr_model.forward(state)

    upper, lower = network_bounds(curr_model.cnn, state, epsilon)
    upper, lower = subsequent_bounds(curr_model.advantage, upper, lower)
    # The value stream is shared by all actions (assumed shape (1, 1)).
    upper += value
    lower += value
    impossible = upper < torch.max(lower, dim=1)[0]

    snapshot = env.ale.cloneState()
    next_envs = []
    # Read the mask once instead of syncing with the device for every action.
    for action, is_impossible in enumerate(impossible[0].tolist()):
        if is_impossible:
            continue
        next_state, reward, done, _ = env.step(action)
        if reward > 1e-5:
            # Scored: this branch is already a success, nothing to expand.
            env.ale.restoreState(snapshot)
            continue
        # Failure if the reward is negative, the episode ended without scoring,
        # or no steps remain to reach the goal.
        if steps <= 1 or reward < -1e-5 or done:
            env.ale.restoreState(snapshot)
            return -1
        next_envs.append((env.ale.cloneState(), next_state, steps - 1))
        env.ale.restoreState(snapshot)
    return next_envs

def get_greedy_worst_case(curr_model, env, epsilon, state, max_steps):
    """Roll out one adversarial trajectory chosen greedily, step by step.

    At each step, consider only the actions that are not dominated according to
    the epsilon-perturbation bounds, i.e. those the agent could still be driven
    to pick. Among them, take the one with the lowest nominal Q-value.

    Returns:
        1 if the rollout collects a positive reward within `max_steps` actions.
        -1 on a negative reward, on episode end, or when steps run out.

    The emulator state (`env.ale`) is restored before returning.
    """
    orig_env = env.ale.cloneState()
    for _ in range(max_steps):
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        value, advs = curr_model.forward(state)
        output = value + advs

        upper, lower = network_bounds(curr_model.cnn, state, epsilon)
        upper, lower = subsequent_bounds(curr_model.advantage, upper, lower)
        upper += value
        lower += value
        impossible = upper < torch.max(lower, dim=1)[0]
        # Exclude unreachable actions from the argmin with +inf. This does not
        # rely on Q-values staying far below a fixed 1e6 offset.
        masked_output = output.masked_fill(impossible, float('inf'))
        worst_case_action = torch.argmin(masked_output, dim=1)
        # Pass a plain int: gym action spaces/wrappers expect Python ints, not tensors.
        next_state, reward, done, _ = env.step(worst_case_action[0].item())

        if reward > 1e-5:
            env.ale.restoreState(orig_env)
            return 1
        if reward < -1e-5 or done:
            env.ale.restoreState(orig_env)
            return -1
        state = next_state

    env.ale.restoreState(orig_env)
    return -1
    

def get_action_cert_rate(curr_model, env, epsilon, state, max_steps):
    """Fraction of steps on the nominal greedy trajectory whose action is certified.

    The agent follows the unperturbed argmax action. A step is certified when
    that action's lower bound exceeds the upper bound of every other action
    under the epsilon-perturbation bounds.

    The rollout stops at the first nonzero reward, at episode end, or after
    `max_steps` actions. Returns certified / total. Returns 0.0 when no step
    was evaluated (max_steps <= 0) instead of raising ZeroDivisionError.

    The emulator state (`env.ale`) is restored before returning.
    """
    certified = 0
    total = 0
    orig_env = env.ale.cloneState()
    for _ in range(max_steps):
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        value, advs = curr_model.forward(state)
        output = value + advs
        action = torch.argmax(output, dim=1).item()

        upper, lower = network_bounds(curr_model.cnn, state, epsilon)
        upper, lower = subsequent_bounds(curr_model.advantage, upper, lower)
        upper += value
        lower += value

        # Remove the chosen action from the max over its competitors.
        upper[:, action] = -1e10
        max_other = torch.max(upper, dim=1)[0]
        if lower[0, action] > max_other[0]:
            certified += 1
        total += 1

        next_state, reward, done, _ = env.step(action)
        if abs(reward) > 1e-5 or done:
            break
        state = next_state
    env.ale.restoreState(orig_env)
    return certified / total if total else 0.0

def worst_case_reward(curr_model, env, epsilon, max_steps, max_paths=2000):
    """Reset `env` and compute the robustness metrics from its initial state.

    Args:
        curr_model: dueling DQN evaluated under perturbation.
        env: environment; it is reset here.
        epsilon: perturbation radius.
        max_steps: horizon (number of actions) for every metric.
        max_paths: search budget. Once more paths than this are open, the
            absolute worst case is reported as undecided (0).

    Returns:
        (awc, greedy_reward, paths, acr), where awc is one of:
            1   no reachable branch fails, and the search finished within budget;
            -1  `get_next` reported a failing reachable branch;
            0   undecided, because the search budget was exhausted.
        greedy_reward comes from `get_greedy_worst_case`, acr comes from
        `get_action_cert_rate`, and paths is the number of paths counted so far.
    """
    state = env.reset()
    paths = 1
    with torch.no_grad():
        envs_to_check = [(env.ale.cloneState(), state, max_steps)]

        # Both helpers restore the emulator state before returning.
        greedy_reward = get_greedy_worst_case(curr_model, env, epsilon, state, max_steps)
        acr = get_action_cert_rate(curr_model, env, epsilon, state, max_steps)

        # Depth-first search: pop from the end of the stack.
        while envs_to_check:
            snapshot, state, steps_remaining = envs_to_check.pop()
            env.ale.restoreState(snapshot)
            out = get_next(curr_model, env, epsilon, state, steps_remaining)
            # get_next returns -1 for a failing branch, otherwise a list of children.
            if not isinstance(out, list):
                return -1, greedy_reward, paths, acr
            envs_to_check.extend(out)
            # Every child beyond the first opens an additional path.
            paths += max(0, len(out) - 1)

            if paths > max_paths:
                print(paths, len(envs_to_check))
                return 0, greedy_reward, paths, acr
        return 1, greedy_reward, paths, acr

In [None]:
def set_seed(random_seed, env):
    """Seed every RNG used during a trial so results are reproducible.

    Seeds the following with `random_seed`:
        - torch;
        - the environment;
        - Python's `random`;
        - NumPy's global RNG;
        - the environment's action space.
    """
    torch.manual_seed(random_seed)
    for seed_fn in (env.seed, random.seed, np.random.seed, env.action_space.seed):
        seed_fn(random_seed)

In [None]:
%%time
verified_rewards = []
greedy_rewards = []
acrs = []
# Epsilons expressed as multiples of 1/255 (presumably one 8-bit intensity
# step of [0, 1]-scaled observations).
epsilons = np.array([1, 1.2, 1.3, 1.4, 1.5])/255
# Alternative sweep: np.array([0.1, 0.3, 1, 1.1, 1.15, 1.2, 1.3, 3, 8])/255
N_TRIALS = 20   # episodes per epsilon, seeded 0..N_TRIALS-1
MAX_STEPS = 80  # horizon passed to worst_case_reward

for epsilon in epsilons:
    verified = []
    greedy = []
    acr = []
    print('Epsilon: {}'.format(epsilon))
    for j in range(N_TRIALS):
        trial_env = create_env()
        try:
            set_seed(j, trial_env)
            reward, greedy_reward, paths, acr_res = worst_case_reward(current_model, trial_env, epsilon, max_steps=MAX_STEPS)
        finally:
            # Every trial builds a fresh environment; release it rather than
            # leaking one emulator per trial.
            trial_env.close()
        print('Greedy: {}, Absolute worst case reward:{}, paths checked:{}, action cert rate:{:.4f}'.format(greedy_reward,
                                                                                                     reward, paths, acr_res))
        # Undecided trials (search budget exhausted, reward 0) are left out of the averages.
        if reward != 0:
            verified.append(reward)
            greedy.append(greedy_reward)
            acr.append(acr_res)
    verified_rewards.append(verified)
    greedy_rewards.append(greedy)
    acrs.append(acr)

In [None]:
font = {'size': 18}
matplotlib.rc('font', **font)

# Per-epsilon averages over the decided trials (those with AWC != 0).
greed, ver, acr_ = (
    [np.mean(trials) for trials in per_eps]
    for per_eps in (greedy_rewards, verified_rewards, acrs)
)

eps_scaled = epsilons*255
curves = (
    (greed, 'o', 'Greedy worst case reward'),
    (ver, '.', 'Absolute worst case reward'),
    # ACR lies in [0, 1]; 2*ACR-1 puts it on the same [-1, 1] scale as the rewards.
    (np.array(acr_)*2-1, 's', '(Action certification rate)*2-1'),
)
for curve_values, curve_marker, curve_name in curves:
    plt.plot(eps_scaled, curve_values, marker=curve_marker, label=curve_name)

plt.legend(bbox_to_anchor=[1,1.5])
plt.xlabel('epsilon*255')
plt.ylabel('Average result')
plt.show()

In [None]:
font = {'size'   : 14}
matplotlib.rc('font', **font)

x = epsilons*255  # bar-group centers: one group per epsilon
width = 0.01  # width of each bar

# One color per metric, shared by its bars and its value annotations.
gwc_color = np.array((255, 153, 51))/255
awc_color = np.array((30, 144, 255))/255
acr_color = np.array((40, 164, 40))/255

fig, ax = plt.subplots()

# GWC/AWC averages lie in [-1, 1]; (r+1)/2 maps them onto the [0, 1] ACR scale.
rects1 = ax.bar(x - width - 0.0015, (np.array(greed)+1)/2, width, label='GWC', color=gwc_color)
rects2 = ax.bar(x, (np.array(ver)+1)/2, width, label='AWC', color=awc_color)
rects3 = ax.bar(x + width + 0.002, acr_, width, label='ACR', color=acr_color)

ax.set_ylabel('Average result')
# Raw string: in a plain literal '\e' is an invalid escape sequence
# (DeprecationWarning; SyntaxWarning since Python 3.12).
ax.set_xlabel(r'$\epsilon*255$')
plt.ylim(0,1)
plt.xlim(0.95,1.55)
ax.set_xticks(x)
ax.legend()

def label(rect, color=(0,0,0), offset=(0,3)):
    """Annotate a bar with its height, formatted to two decimals.

    rect: a bar artist returned by `ax.bar`.
    color: annotation text color.
    offset: (dx, dy) displacement in points, applied to the text relative to
        the bar's top-center.
    """
    height = rect.get_height()
    ax.annotate('{:.2f}'.format(height),
                xy=(rect.get_x() + rect.get_width() / 2, height),
                xytext=offset,
                textcoords="offset points", color=color,
                ha='center', va='bottom')

# Hand-tuned offsets for the bars of the last three epsilons (presumably to
# keep neighbouring labels from overlapping). Requires at least five epsilons.
label(rects1[2], gwc_color, offset=(-20,-17))
label(rects1[3], gwc_color, offset=(-19,-14))
label(rects1[4], gwc_color, offset=(-7,16))

label(rects2[2], awc_color, offset=(-26,-22))
label(rects2[3], awc_color, offset=(-26,-5))
label(rects2[4], awc_color, offset=(-7,3))

label(rects3[2], acr_color, offset=(-33,-14))
label(rects3[3], acr_color, offset=(-33,-25))
label(rects3[4], acr_color, (-20,-140))

plt.show()

In [None]:
# For each epsilon (in 1/255 units), the number of decided trials recorded.
for eps_index, eps in enumerate(epsilons):
    print(eps*255, len(greedy_rewards[eps_index]))
               

In [None]:
def result_with_eps(eps_index, n_trials=20):
    """Print AWC/GWC/ACR summaries for one epsilon, split by the AWC outcome.

    Reads the global `verified_rewards`, `greedy_rewards`, `acrs` and
    `epsilons` filled by the evaluation loop. Only decided trials (AWC of
    +1 or -1) are stored there.

    Args:
        eps_index: index into `epsilons` and the per-epsilon result lists.
        n_trials: episodes run per epsilon; the denominator of the printed
            counts. The default of 20 matches the evaluation loop.
    """
    def _mean(values):
        # np.mean([]) warns "Mean of empty slice"; report nan explicitly instead.
        return np.mean(values) if len(values) else float('nan')

    awcs = verified_rewards[eps_index]
    gwcs = greedy_rewards[eps_index]
    eps_acrs = acrs[eps_index]

    pos_acrs, pos_gwcs = [], []
    neg_acrs, neg_gwcs = [], []
    for awc, gwc, acr_value in zip(awcs, gwcs, eps_acrs):
        if awc == 1:
            pos_acrs.append(acr_value)
            pos_gwcs.append(gwc)
        elif awc == -1:
            neg_acrs.append(acr_value)
            neg_gwcs.append(gwc)

    print('Epsilon: {}/255'.format(epsilons[eps_index]*255))
    print('Average total AWC:{} GWC:{} ACR:{}'.format(_mean(awcs), _mean(gwcs), _mean(eps_acrs)))
    print('Positive AWC: {}/{}'.format(len(pos_acrs), n_trials))
    print('Average pos GWC:{} ACR:{}'.format(_mean(pos_gwcs), _mean(pos_acrs)))
    print('Negative AWC: {}/{}'.format(len(neg_acrs), n_trials))
    print('Average neg GWC:{} ACR:{}'.format(_mean(neg_gwcs), _mean(neg_acrs)))
    print('')

In [None]:
# Print the per-epsilon summary for every evaluated epsilon.
for eps_index, _ in enumerate(epsilons):
    result_with_eps(eps_index)