In [1]:
#@title Get started
mount_drive = True #@param {type:"boolean"}
replace_existing = False #@param {type:"boolean"}

import os

if mount_drive:
    from google.colab import drive
    drive.mount('/content/drive')
    os.chdir('/content/drive/My Drive/')

if os.path.exists('VQVAE-Clean'):
    if replace_existing:
        proceed = ''
        while not (proceed == 'y' or proceed == 'n'):
            proceed = input('\nAre you sure that you want to replace the existing directory?\nThis will delete all existing model checkpoints and samples (y/n): ')
        if proceed == 'y':
            print ('Removing existing code directory...')
            import shutil
            shutil.rmtree('VQVAE-Clean')
            print ('Cloning repository...')
            !git clone https://github.com/Ryan-Rudes/VQVAE-Clean
        else:
            print ('Alright... stopping.')
            raise KeyboardInterrupt
    else:
        print ('Using existing code directory')
else:
    print ('Cloning repository...')
    !git clone https://github.com/Ryan-Rudes/VQVAE-Clean

%cd VQVAE-Clean

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Using existing code directory
/content/drive/My Drive/VQVAE-Clean


In [None]:
#@title Install dependencies
!pip install googledrivedownloader
!pip install rich

In [None]:
#@title Download a dataset
dataset = "Breakout" #@param ["MontezumaRevenge", "Breakout", "Pitfall", "Qbert"]
from utils import *

# Three steps, all using helpers pulled in by the wildcard import above:
#   1. download the archive for the selected game to /content/<dataset>.zip
#   2. unpack that archive under /content
#   3. run `resize` over the extracted `cells` folder
# NOTE(review): the Setup cell later reads /content/<dataset>/resized, so
# `resize` presumably writes there -- confirm in utils.py.
dataset_zip = f'/content/{dataset}.zip'
download(dataset, dataset_zip)
unzip(dataset_zip, '/content')
resize(f'/content/{dataset}/cells')

In [4]:
import distributed as dist
from torch import optim
from train import train
import sys
import os

#@title Setup

class Args:
    """Bag of settings that is instantiated below and handed to `train(args)`.

    Every value is a class attribute evaluated once, when this cell runs.
    The lines carrying `#@param` are exposed as Colab form fields.
    """

    # --- process / distributed settings ---------------------------------
    n_gpu = 1
    # Port in the range 49152..65535: a fixed base of 2**15 + 2**14 plus an
    # offset below 2**14 derived from the user id. On Windows (no os.getuid)
    # the constant 1 is hashed instead.
    port = (
        2 ** 15
        + 2 ** 14
        + hash(1 if sys.platform == "win32" else os.getuid()) % 2 ** 14
    )
    dist_url = f'tcp://127.0.0.1:{port}'
    # True only when the distributed helper reports more than one process.
    distributed = dist.get_world_size() > 1

    # --- training hyper-parameters (editable through the form) ----------
    epoch = 500 #@param {type:"integer"}
    lr = 3e-4 #@param {type:"number"}
    batch_size = 128 #@param {type:"integer"}
    num_workers = 4 #@param {type:"integer"}
    normalize = True #@param {type:"boolean"}

    # The optimizer *class* (not an instance) is stored here.
    optimizer = optim.Adam
    # NOTE(review): empty string presumably means "no LR schedule" -- confirm in train.py.
    sched = ''

    # Reads the notebook-level `dataset` variable, so the dataset cell must
    # have been run before this one.
    path = f'/content/{dataset}/resized'

args = Args()

In [None]:
#@title Train
#@markdown Results are saved to `./runs/<timestamp>`
# Kick off training using the settings object `args` (an `Args` instance)
# with the `train` function imported from train.py.
train(args)

Output()

In [35]:
#@title Load Model For Evaluation
timestamp = '1619471050.8708587' #@param {type:"string"}

from vqvae import VQVAE
import torch

# Everything from here on is inference only, so autograd is switched off
# globally (it stays off until explicitly re-enabled).
torch.set_grad_enabled(False)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Checkpoints of a run live in runs/<timestamp>/checkpoint.
checkpoints_dir = os.path.join('runs', timestamp, 'checkpoint')
checkpoints = [name for name in os.listdir(checkpoints_dir) if name.endswith('.pt')]

# Fail with a clear message instead of the opaque
# "max() arg is an empty sequence" when the run has no checkpoints yet.
if not checkpoints:
    raise FileNotFoundError(f'No .pt checkpoints found in {checkpoints_dir!r}')

# File names are parsed as <prefix>_<number>.pt; the highest number wins.
latest = max(checkpoints, key = lambda name: int(name[:-3].split('_')[1]))
PATH = os.path.join(checkpoints_dir, latest)

def encode_fn(PATH):
    """Load a VQVAE checkpoint and build a frame -> code-indices function.

    Args:
        PATH: Path to a saved state dict that `torch.load` can read and
            `VQVAE.load_state_dict` accepts.

    Returns:
        A function `encode(observation)` that resizes one frame to 160x160,
        scales it by 1/255, runs it through `model.encode` and returns the
        fourth element of the result (the indices) for that single frame as
        a NumPy array.

    Relies on the notebook-level names `VQVAE`, `torch` and `device`.
    """
    # `cv2` is not imported anywhere visible in this notebook; it would only
    # resolve if one of the wildcard imports happened to leak it. Import it
    # explicitly so this function does not depend on that.
    import cv2

    print ('Loading model...')
    model = VQVAE()
    # `device` is already a torch.device, so it can be passed as-is.
    saved = torch.load(PATH, map_location = device)
    model.load_state_dict(saved)
    model = model.to(device)
    model.eval()

    def encode(observation):
        """Encode a single frame; expects a 3-D array with channels last."""
        x = cv2.resize(observation, (160, 160), interpolation = cv2.INTER_AREA)
        x = x / 255.0
        x = torch.Tensor(x)
        x = x.permute(2, 0, 1)   # channels last -> channels first
        x = x.unsqueeze(0)       # add a leading batch dimension of 1
        x = x.to(device)

        # Disable autograd locally so the function is inference-safe even if
        # the global grad mode has been switched back on elsewhere.
        with torch.no_grad():
            _, _, _, indices, _ = model.encode(x)

        # Drop the batch dimension again.
        encoded = indices.cpu().numpy()[0]

        return encoded

    return encode

In [40]:
from goexplore.algorithm import GoExplore
from goexplore.wrappers import *

#@title Evaluate with VQVAE2
#@markdown **Algorithm**
iterations = 1000 #@param {type:"integer"}
env = 'Qbert' #@param ['MontezumaRevenge', 'SpaceInvaders', 'VideoPinball', 'Breakout', 'Qbert', 'Pong', 'Pitfall']
method = 'ram' #@param ["ram", "trajectory"]
repeat = 0.95 #@param {type:"number"}
nsteps = 100 #@param {type:"integer"}
seed = 42 #@param {type:"integer"}
#@markdown **Logging**
verbosity = 1 #@param [0, 1, 2]
# NOTE: '\s' is not a Python escape sequence, so the value below is the two
# characters backslash + "s".
delimeter = '\s' #@param {type:"string"}
separator = True #@param {type:"boolean"}

print ('Creating environment...')
# `env` starts out as the game name chosen in the form and is rebound here to
# the environment object produced by the matching `name2env` entry.
env = name2env[env]()

print ('Starting algorithm...')
goexplore = GoExplore(env)
# The cell function comes from the VQVAE checkpoint at PATH, so the
# "Load Model For Evaluation" cell must have been run first.
goexplore.initialize(
    repeat = repeat,
    nsteps = nsteps,
    seed = seed,
    method = method,
    cellfn = encode_fn(PATH),
)
goexplore.run_for(
    iterations,
    verbose = verbosity,
    separator = separator,
    delimeter = delimeter,
)

Creating environment...
Starting algorithm...
Loading model...


Output()

In [3]:
# Scratch cell: move up one directory and list its contents.
%cd ..
!ls

/Users/ryanrudes/Downloads/Code/VQVAE-Clean
README.md            learn_and_explore.py scheduler.py
[1m[36mdistributed[m[m          [1m[36mnotebooks[m[m            train.py
explore.py           preprocess.py        utils.py
[1m[36mgoexplore[m[m            [1m[36mruns[m[m                 vqvae.py


In [None]:
from goexplore.algorithm import GoExplore
from goexplore.wrappers import *
from goexplore.utils import *

#@title Evaluate with Downscaled Representations
#@markdown **Algorithm**
duration = 1000000 #@param {type:"integer"}
units = 'frames' #@param ["frames", "iterations"]
env = 'Qbert' #@param ['MontezumaRevenge', 'SpaceInvaders', 'VideoPinball', 'Breakout', 'Qbert', 'Pong', 'Pitfall']
method = 'ram' #@param ["ram", "trajectory"]
repeat = 0.95 #@param {type:"number"}
nsteps = 100 #@param {type:"integer"}
seed = 42 #@param {type:"integer"}
#@markdown **Logging**
verbosity = 1 #@param [0, 1, 2]
# NOTE: '\s' is not a Python escape sequence, so the value below is the two
# characters backslash + "s".
delimeter = '\s' #@param {type:"string"}
separator = True #@param {type:"boolean"}
#@markdown **Cell Representations**
width = 11 #@param {type:"slider", min:1, max:20, step:1}
height = 8 #@param {type:"slider", min:1, max:20, step:1}
intensities = 8 #@param {type:"slider", min:2, max:64, step:1}
grayscale = True #@param {type:"boolean"}

print ('Creating environment...')
# `env` starts out as the game name chosen in the form and is rebound here to
# the environment object produced by the matching `name2env` entry.
env = name2env[env]()

print ('Starting algorithm...')
goexplore = GoExplore(env)
# Cells are built by `makecellfn` from the form's width / height /
# intensities / grayscale settings instead of a learned model.
goexplore.initialize(
    repeat = repeat,
    nsteps = nsteps,
    seed = seed,
    method = method,
    cellfn = makecellfn(
        width = width,
        height = height,
        interpolation = cv2.INTER_AREA,
        grayscale = grayscale,
        intensities = intensities,
    ),
)
goexplore.run_for(
    duration,
    verbose = verbosity,
    units = units,
    separator = separator,
    delimeter = delimeter,
    # The render predicate is true on every 50th value it is called with.
    renderfn = lambda iterations: iterations % 50 == 0,
)

Creating environment...
Starting algorithm...
