In [3]:
import glob
import itertools
import os

import cv2
import numpy as np
import tqdm
from matplotlib import pyplot as plt
%matplotlib notebook

In [4]:
# Where the gym_recording traces are read from and where rendered PNGs are written.
# Both can be overridden via environment variables so the notebook runs on other
# machines; the defaults are the original author's paths.
trace_dir = os.environ.get('TRACE_DIR', '/home/ruben/play/traces/MontezumaRevenge-v0')
out_dir = os.environ.get('OUT_DIR', '/home/ruben/play/out')
os.makedirs(out_dir, exist_ok=True)

In [5]:
from gym_recording import playback

In [6]:
episodes = list()


def save_episode(o, a, r):
    """Callback for ``playback.scan_recorded_traces``: append one episode to ``episodes``.

    ``o``, ``a`` and ``r`` are presumably the observations, actions and rewards of
    the episode (TODO confirm against gym_recording). Each one is stored as an
    independent copy under the keys 'o', 'a' and 'r', so later mutation of the
    caller's arrays does not affect what was stored.
    """
    record = {}
    for name, arr in (('o', o), ('a', a), ('r', r)):
        record[name] = np.copy(arr)
    episodes.append(record)
# Replay the recorded traces under trace_dir; save_episode is the callback that
# collects each one into `episodes` (presumably invoked once per recorded episode).
playback.scan_recorded_traces(trace_dir, save_episode)

In [None]:
def show_key_points(im, kp):
    """Draw keypoints ``kp`` on top of ``im`` and display the result in a new small figure.

    Parameters
    ----------
    im : image array accepted by ``cv2.drawKeypoints``.
    kp : keypoints, e.g. the first value returned by ``sift.detectAndCompute``.
    """
    canvas = cv2.drawKeypoints(im, kp, np.empty_like(im))
    # Interactive mode so the figure renders immediately (the notebook uses %matplotlib notebook).
    plt.ion()
    fig, ax = plt.subplots(1, 1, figsize=(3, 4))
    ax.imshow(canvas)

In [None]:
# Detect SIFT keypoints on a sample frame (episode 0, frame 101) and visualise them.
im = episodes[0]['o'][101]
# OpenCV >= 4.4 ships SIFT in the main module; older builds only expose it through
# the contrib package (cv2.xfeatures2d).
sift_create = getattr(cv2, 'SIFT_create', None) or cv2.xfeatures2d.SIFT_create
sift = sift_create(contrastThreshold=.15, edgeThreshold=20, sigma=1.4)
kp, des = sift.detectAndCompute(im, None)
show_key_points(im, kp)

In [None]:
# Match SIFT descriptors of frame 101 (`im`, `kp`, `des` from the cell above)
# against frame 31 of the same episode.
im2 = episodes[0]['o'][31]
kp2, des2 = sift.detectAndCompute(im2, None)

# Brute-force matcher with the default L2 norm (suitable for SIFT). crossCheck keeps
# only pairs that are mutually each other's best match.
bf = cv2.BFMatcher(crossCheck=True)
matches = sorted(bf.match(des, des2), key=lambda m: m.distance)
# Passing None lets OpenCV allocate an output image of the right size and dtype
# (both frames side by side); flags=2 is NOT_DRAW_SINGLE_POINTS, i.e. unmatched
# keypoints are not drawn.
out = cv2.drawMatches(im, kp, im2, kp2, matches, None, flags=2)

In [None]:
plt.ion()  # interactive mode: the figure shows up immediately under %matplotlib notebook
fig, ax = plt.subplots(figsize=(6, 4))  # a single Axes (nrows=ncols=1 are the defaults)
ax.imshow(out)

In [None]:
# Grid of SIFT hyper-parameters to sweep; each combination is rendered by explore_sift_params.
sift_params = {
    'contrastThreshold': np.arange(0, .11, .02),
    'edgeThreshold': np.arange(0, 10, 2),
    'sigma': np.arange(1, 3, .4),
}
# One keyword-argument dict per point of the Cartesian product of the grid.
sift_param_combos = [dict(zip(sift_params.keys(), combo))
                     for combo in itertools.product(*sift_params.values())]
len(sift_param_combos)

In [None]:
def explore_sift_params(param_combos, image=None):
    """Save one keypoint visualisation per SIFT parameter combination.

    For each dict in ``param_combos``, a SIFT detector is built with those keyword
    arguments and run on ``image``. The detected keypoints are drawn and the figure
    is saved as a PNG in ``out_dir``, named after the parameter values.

    Parameters
    ----------
    param_combos : iterable of dict
        Keyword arguments for the SIFT constructor. Each dict must contain
        'contrastThreshold', 'edgeThreshold' and 'sigma' (used in the file name).
    image : array, optional
        Image to analyse. Defaults to the notebook-level ``im``.
    """
    if image is None:
        image = im
    # OpenCV >= 4.4 ships SIFT in the main module; older builds only have it in contrib.
    sift_create = getattr(cv2, 'SIFT_create', None) or cv2.xfeatures2d.SIFT_create
    was_interactive = plt.isinteractive()
    plt.ioff()  # figures are only saved to disk, never shown
    fig, ax = plt.subplots(1, 1, figsize=(12, 12))
    try:
        for param_combo in tqdm.tqdm(param_combos):
            sift = sift_create(**param_combo)
            kp, _ = sift.detectAndCompute(image, None)
            drawn = cv2.drawKeypoints(image, kp, None)
            # Clear first: otherwise every imshow adds another image layer to the same
            # axes, and each savefig re-renders all previously drawn images.
            ax.clear()
            ax.imshow(drawn)
            fname = 'contrastTh={}_edgeTh={}_sigma={:.1f}.png'.format(
                param_combo['contrastThreshold'], param_combo['edgeThreshold'], param_combo['sigma'])
            fig.savefig(os.path.join(out_dir, fname), bbox_inches='tight')
    finally:
        # Release the figure and restore the caller's interactive-mode setting.
        plt.close(fig)
        if was_interactive:
            plt.ion()

In [10]:
%matplotlib inline
# find the colors within the specified boundaries and apply
# the mask
images = episodes[1]['o']
n_images = images.shape[0]

plt.rcParams.update({'font.size': 22})

for i in tqdm.trange(n_images):
    image = images[i]
 
    # show the images
    fig, ax = plt.subplots(1, 1, figsize=(10,8))
    #key, avatar, skull = has_key(image), avatar_xy(image), skull_xy(image)
    #ax.set_title('avatar x,y=({},{}) \n skull x,y=({},{}) \n has_key={}'.format(avatar[1], avatar[0], skull[1], skull[0], key))
    ax.imshow(image)
    ax.set_xticks([])
    ax.set_yticks([])
    fig.savefig(out_dir + '/{:03d}.png'.format(i), bbox_inches='tight')
    plt.close(fig)

100%|██████████| 1822/1822 [02:54<00:00, 10.41it/s]


In [6]:
# Inclusive per-channel colour bounds for cv2.inRange (a narrow band around the key colour).
lower, upper = (np.array(bound, dtype="uint8") for bound in ([232, 204, 99], [232, 205, 101]))

# Exact colour of each tracked sprite, in the frames' channel order, as uint8 arrays.
item_colours = {
    name: np.array(colour, dtype=np.uint8)
    for name, colour in (
        ('key', [232, 204, 99]),
        ('avatar', [200, 72, 72]),
        ('skull', [236, 236, 236]),
    )
}

def _colour_pixels(image, colour):
    """Return the (row, col) coordinates of all pixels of ``image`` exactly equal to ``colour``.

    Equivalent to ``cv2.inRange(image, colour, colour)`` followed by ``np.argwhere``,
    without needing OpenCV. ``image`` is expected to be (height, width, channels) and
    ``colour`` a length-``channels`` array. The result is an int array of shape
    (n_matches, 2), in row-major order (top to bottom, then left to right).
    """
    mask = np.all(np.asarray(image) == np.asarray(colour), axis=-1)
    return np.argwhere(mask)


def _first_match_below(image, colour, hud_rows):
    """Return the first (row, col) pixel equal to ``colour`` with row >= ``hud_rows``, or None."""
    matched_pixels = _colour_pixels(image, colour)
    # Get rid of lives on top
    matched_pixels = matched_pixels[matched_pixels[:, 0] >= hud_rows]
    if len(matched_pixels) == 0:
        # Sprite not visible in the play area of this frame.
        return None
    return matched_pixels[0]


def has_key(image, colour=None, hud_rows=50):
    """Return whether the key colour appears in the top strip of the frame.

    Presumably the collected key is shown in the top inventory strip; TODO confirm.

    Parameters
    ----------
    image : array of shape (height, width, channels)
        Frame to inspect.
    colour : array, optional
        Exact colour to look for. Defaults to ``item_colours['key']``.
    hud_rows : int, optional
        Pixels with row index <= ``hud_rows`` count as the top strip (default 50).
        NOTE(review): row ``hud_rows`` itself is also searched by ``avatar_xy`` and
        ``skull_xy`` (which use >=); kept as-is for compatibility.

    Returns
    -------
    numpy.bool_
        True if at least one matching pixel lies in the top strip.
    """
    if colour is None:
        colour = item_colours['key']
    rows = _colour_pixels(image, colour)[:, 0]
    return np.any(rows <= hud_rows)


def avatar_xy(image, colour=None, hud_rows=50):
    """Return the position of the avatar as a (row, col) array, or None if it is not visible.

    Despite the name, the result is ordered (row, col), i.e. (y, x). This is the
    first avatar-coloured pixel in row-major order at or below row ``hud_rows``.
    Pixels above that row are skipped because the lives shown at the top share the
    avatar colour.

    Parameters
    ----------
    image : array of shape (height, width, channels)
    colour : array, optional
        Defaults to ``item_colours['avatar']``.
    hud_rows : int, optional
        First row of the play area (default 50).
    """
    if colour is None:
        colour = item_colours['avatar']
    return _first_match_below(image, colour, hud_rows)


def skull_xy(image, colour=None, hud_rows=50):
    """Return the position of the skull as a (row, col) array, or None if it is not visible.

    Same conventions as ``avatar_xy``: (row, col) order, first match in row-major
    order at or below row ``hud_rows``.

    Parameters
    ----------
    image : array of shape (height, width, channels)
    colour : array, optional
        Defaults to ``item_colours['skull']``.
    hud_rows : int, optional
        First row of the play area (default 50).
    """
    if colour is None:
        colour = item_colours['skull']
    return _first_match_below(image, colour, hud_rows)

 
# Sample frame (episode 0, frame 101) used by the cells below, plus a binary mask
# that is 255 where lower <= pixel <= upper on every channel and 0 elsewhere.
image = episodes[0]['o'][101]
mask = cv2.inRange(image, lowerb=lower, upperb=upper)

In [None]:
# Flat feature vector [has_key, avatar_row, avatar_col, skull_row, skull_col].
# np.stack would need equal shapes ((1,) vs (2,) raises ValueError); concatenate joins them.
np.concatenate([np.array([has_key(image)]), avatar_xy(image), skull_xy(image)])

In [None]:
# Shape of the one-element has_key feature, expected (1,).
np.shape([has_key(image)])

In [9]:
# `avatar` was previously only assigned in commented-out code (hidden kernel state);
# compute it explicitly so the cell works on a fresh kernel. Order is (row, col).
avatar = avatar_xy(image)
avatar

array([ 75, 147])

In [10]:
# Reverse (row, col) into (x, y) order.
np.flip(avatar)

array([147,  75])