In [None]:
%pylab inline
%run bizhawk.py
%run helper.py
import os
import subprocess
import shutil
from keras.applications.inception_v3 import InceptionV3

# RRT (Rapidly-exploring Random Tree) exploration

In [None]:
#RRT function
def explore_with_rrt(bh, initial_state, successor, get_goal,
                     projection, max_samples=1000):
    edges = []
    states = [initial_state]
    projections = [projection(initial_state)]
    available_actions = [set(np.arange(ACTION_NUM))]
    
    for i in range(max_samples):
        goal = get_goal()
        
        #Get the cloest point in the tree
        min_index = 0
        min_value = np.linalg.norm(projections[0] - goal)
        for j in range(len(states)):
            temp_projection = np.linalg.norm(projections[j] - goal)
            if temp_projection <= min_value:
                min_index = j
                min_value = temp_projection
        chosen_index = min_index
        
        #Get successor state
        chosen_state = states[chosen_index]
        selected_action, successor_state = successor(bh, chosen_state, goal,
                                                     available_actions[chosen_state])
        
        #Append successor state
        available_actions.append(set(np.arange(ACTION_NUM)))
        chosen_projection = projections[chosen_index]
        successor_projection = projection(successor_state)
        states.append(successor_state)
        projections.append(successor_projection)
        
        edges.append((goal, chosen_state, selected_action,
                      successor_state, chosen_projection,
                      successor_projection))
        
    return edges
        
#Project states to embeddings
def state_projection(state):
    """Look up the stored embedding vector for an RRT state id.

    State ids are positions in the notebook-level ``embeddings`` list.
    """
    embedding = embeddings[state]
    return embedding

#Select a random action, hold it for FRAMES_PER_STEP frames and return the successor state
def get_random_successor(bh, state, goal, available_action):
    """Expand ``state`` by holding one randomly sampled button combination.

    An action id in ``[0, ACTION_NUM)`` is drawn from ``distribution`` until it
    lies in ``available_action``. Its bits (least-significant first) select
    which entries of ``button_names`` are pressed. The savestate of ``state``
    is loaded and the buttons are held for ``FRAMES_PER_STEP`` frames. The
    resulting screen is then embedded and saved as a new state.

    Args:
        bh: Open BizHawk connection.
        state: Id of the state to expand (index into ``paths``/``embeddings``).
        goal: Unused; kept so the signature matches what ``explore_with_rrt``
            passes to its successor function.
        available_action: Set of action ids allowed from ``state``.

    Returns:
        ``(action_id, new_state_id)``.

    Side effects:
        Appends the new embedding to ``embeddings`` and its savestate path to
        ``paths``, and asks the emulator to write that savestate file.
    """
    button_names = ['Up', 'Down', 'Left', 'Right', 'Select',
                    'Start', 'B', 'A', 'X', 'Y', 'L', 'R']

    # Rejection-sample an action from the button-combination distribution.
    # NOTE(review): this never terminates if no action in `available_action`
    # has non-zero probability.
    temp_action = np.random.choice(ACTION_NUM, p=distribution)
    while temp_action not in available_action:
        temp_action = np.random.choice(ACTION_NUM, p=distribution)
    # Character i (LSB first) is '1' when button_names[i] is pressed.
    action = '{0:b}'.format(temp_action).rjust(12, '0')[::-1]

    # Lua that builds the `buttons` table once; later frames reuse that table.
    set_buttons = b'buttons = {};' + b''.join(
        b'buttons["' + str.encode(name) + b'"] = 1;'
        for name, bit in zip(button_names, action) if bit == '1')
    hold_one_frame = b'joypad.set(buttons, 1);emu.frameadvance();'

    next_state = len(embeddings)

    # Restore the parent state, then hold the buttons for FRAMES_PER_STEP
    # frames (the first frame is sent together with the table definition).
    bh.send(b'savestate.load("' + str.encode(paths[state]) + b'");')
    bh.send(set_buttons + hold_one_frame)
    for _ in range(FRAMES_PER_STEP - 1):
        bh.send(hold_one_frame)

    # Embed the resulting screen with the same helper used for the root state,
    # so every entry of `embeddings` comes from one pipeline.
    embeddings.append(get_embedding(bh.read_img('temp.png'), embedded_model))
    paths.append(state_path + str(len(embeddings)) + '.state')
    bh.send(b'savestate.save("' + str.encode(paths[-1]) + b'");')

    return (temp_action, next_state)

#Initialize a random goal point
def random_goal(dim=1000, low=0.0, high=1.0):
    """Sample a goal point uniformly from the box ``[low, high)^dim``.

    Args:
        dim: Length of the goal vector. It should match the length of the
            vectors returned by the state projection. The default of 1000
            presumably matches the ``InceptionV3()`` output; confirm if the
            embedding model changes.
        low: Lower bound of each coordinate (inclusive).
        high: Upper bound of each coordinate (exclusive).

    Returns:
        1-D numpy array of shape ``(dim,)``.
    """
    return np.random.uniform(low, high, dim)
        
    
# Paths: ROM to explore, BizHawk install, Lua bridge script, savestate directory.
rom_path = ''  # TODO: set to the ROM to explore
bizhawk_path = 'BizHawk-2.2.2/'
lua_path = '../action.lua'
state_path = 'states/'

steps = 200            # number of RRT expansions (passed as max_samples)
FRAMES_PER_STEP = 30   # emulator frames each sampled action is held for
ACTION_NUM = 4096      # 2**12: one bit per button in get_random_successor's button list
SEED = 0               # seed for numpy's global RNG (goal and action sampling)
np.random.seed(SEED)

embedded_model = InceptionV3()
# Sampling probability of each of the ACTION_NUM button combinations.
distribution = np.load('SNES_distribution.npy')
# np.random.choice(ACTION_NUM, p=distribution) needs exactly one probability per action.
assert distribution.shape == (ACTION_NUM,), 'distribution must have one probability per action'

In [None]:
embeddings = []   # embedding of each explored state, indexed by state id
paths = []        # savestate file of each state, parallel to `embeddings`

os.makedirs(state_path, exist_ok=True)

with BizHawk(rom_path, bizhawk_path, lua_path=lua_path) as bh:
    bh.send(b'client.speedmode(400);')

    # Embed and save the starting screen as state 0, the root of the tree.
    temp_img = bh.read_img('temp.png')
    embeddings.append(get_embedding(temp_img, embedded_model))

    paths.append(state_path + str(len(embeddings)) + '.state')
    bh.send(b'savestate.save("' + str.encode(paths[-1]) + b'");print("ok");')

    # Run RRT from state 0
    edges = explore_with_rrt(bh, 0, get_random_successor,
                             random_goal, state_projection,
                             max_samples=steps)

# Visualization

In [None]:
# Remove the scratch screenshot and the savestates. Only the in-memory
# embeddings are used from here on. The guards keep the cell re-runnable
# after the files are already gone.
if os.path.exists('temp.png'):
    os.remove('temp.png')
if os.path.isdir(state_path):
    shutil.rmtree(state_path)

# Scores computed by helper.py's get_score over all state embeddings.
bbox_sum, nuc_norm = get_score(embeddings)

for values, name in ((bbox_sum, 'Bounding Box Sum'), (nuc_norm, 'Nuclear Norm')):
    fig, ax = plt.subplots()
    ax.plot(values)
    ax.set_title(name)
    plt.show()