In [1]:
#@title Get started
mount_drive = True #@param {type:"boolean"}
replace_existing = True #@param {type:"boolean"}

import os

if mount_drive:
    from google.colab import drive
    drive.mount('/content/drive')
    os.chdir('/content/drive/My Drive/')

if os.path.exists('VQVAE-Clean'):
    if replace_existing:
        proceed = ''
        while not (proceed == 'y' or proceed == 'n'):
            proceed = input('\nAre you sure that you want to replace the existing directory?\nThis will delete all existing model checkpoints and samples (y/n): ')
        if proceed == 'y':
            print ('Removing existing code directory...')
            import shutil
            shutil.rmtree('VQVAE-Clean')
            print ('Cloning repository...')
            !git clone https://github.com/Ryan-Rudes/VQVAE-Clean
        else:
            print ('Alright... stopping.')
            raise KeyboardInterrupt
    else:
        print ('Using existing code directory')
else:
    print ('Cloning repository...')
    !git clone https://github.com/Ryan-Rudes/VQVAE-Clean

%cd VQVAE-Clean

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).

Are you sure that you want to replace the existing directory?
This will delete all existing model checkpoints and samples (y/n): y
Removing existing code directory...
Cloning repository...
Cloning into 'VQVAE-Clean'...
remote: Enumerating objects: 306, done.[K
remote: Counting objects: 100% (306/306), done.[K
remote: Compressing objects: 100% (217/217), done.[K
remote: Total 306 (delta 208), reused 183 (delta 85), pack-reused 0[K
Receiving objects: 100% (306/306), 78.67 KiB | 4.63 MiB/s, done.
Resolving deltas: 100% (208/208), done.
/content/drive/My Drive/VQVAE-Clean


In [13]:
#@title Install dependencies
!pip install googledrivedownloader
!pip install rich



In [7]:
#@title Download a dataset
dataset = "Qbert" #@param ["MontezumaRevenge", "Breakout", "Pitfall", "Qbert"]
# The star import supplies download/unzip/resize used below; later cells may
# also rely on names it brings in, so it is kept as-is.
from utils import *

# Fetch the archive for the selected dataset, unpack it under /content, then
# run the repository's resize step over the extracted 'cells' folder.
# NOTE(review): presumably resize() writes to /content/<dataset>/resized, which
# the Setup cell reads as its training path — confirm in utils.
_dataset_zip = f'/content/{dataset}.zip'
download(dataset, _dataset_zip)
unzip(_dataset_zip, '/content')
resize(f'/content/{dataset}/cells')

Downloading 1039o1Z40_8nZbf_XM_cxZWpbSjKZRk0Y into /content/Qbert.zip... Done.


Output()

Output()

In [8]:
import distributed as dist
from torch import optim
from train import train
import sys
import os

#@title Setup

class Args:
    """Plain attribute bag of settings that is instantiated and passed to ``train``.

    All values are class attributes evaluated once, when this cell runs.
    Lines tagged ``#@param`` are exposed as editable fields in the Colab form.
    Requires the global ``dataset`` from the "Download a dataset" cell.
    """
    # GPU count; how it is consumed is decided by train (not visible here).
    n_gpu = 1
    # 49152 plus (uid mod 16384), i.e. a port in 49152-65535. On Windows the
    # constant 1 stands in for the uid, so os.getuid() is never called there.
    port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
    # Loopback TCP address built from the port above.
    dist_url = f'tcp://127.0.0.1:{port}'
    # Form-editable training settings (presumably epoch count, learning rate,
    # batch size, data-loader workers and input normalisation — confirm in train).
    epoch = 560 #@param {type:"integer"}
    lr = 3e-4 #@param {type:"number"}
    batch_size = 128 #@param {type:"integer"}
    num_workers =  4#@param {type:"integer"}
    normalize = True #@param {type:"boolean"}
    # The optimizer class itself (not an instance).
    optimizer = optim.Adam
    # Snapshot of the world size taken when the class body executes.
    distributed = dist.get_world_size() > 1
    # Empty string; presumably means "no LR schedule" — confirm in train.
    sched = ''
    # Depends on the global `dataset`; raises NameError if that cell has not run.
    path = f'/content/{dataset}/resized'

# Instance handed to train(args) in the Train cell.
args = Args()

In [None]:
#@title Train
#@markdown Results are saved to `./runs/<timestamp>`
# Start training with the settings collected in `args` by the Setup cell.
train(args)

  cpuset_checked))


Output()

In [None]:
#@title Load Model For Evaluation
timestamp = '' #@param {type:"string"}

from vqvae import VQVAE
import torch

# Evaluation only: switch autograd off globally from here on.
torch.set_grad_enabled(False)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# `timestamp` is joined as-is, so it must be a path (relative to the current
# directory, or absolute) to the run folder that contains 'checkpoint'.
# NOTE(review): the Train cell says results go to ./runs/<timestamp>; if so the
# value entered here probably needs the 'runs/' prefix — confirm.
checkpoints_dir = os.path.join(timestamp, 'checkpoint')
checkpoints = [name for name in os.listdir(checkpoints_dir) if name.endswith('.pt')]
if not checkpoints:
    # Fail with a clear message instead of max()'s "arg is an empty sequence".
    raise ValueError(f"No '.pt' checkpoints found in {checkpoints_dir!r}")

# Pick the file with the largest number in its second underscore-separated
# token (the '.pt' suffix is stripped before splitting).
latest = max(checkpoints, key = lambda name: int(name[:-3].split('_')[1]))
PATH = os.path.join(checkpoints_dir, latest)

def encode_fn(PATH):
    """Load a VQVAE checkpoint and return a function that encodes one frame.

    Args:
        PATH: Path to a checkpoint file. The loaded object must have a
            'model' entry holding the state dict for ``VQVAE``.

    Returns:
        A function ``encode(observation)`` that returns the code indices for
        a single image as a NumPy array.

    Relies on the notebook globals ``VQVAE``, ``torch``, ``device`` and
    ``cv2``. NOTE(review): ``cv2`` is not imported in any visible cell —
    presumably it arrives via ``from utils import *``; confirm.
    """
    print ('Loading model...')
    model = VQVAE()
    saved = torch.load(PATH, map_location = device)['model']
    model.load_state_dict(saved)
    # load_state_dict copies values into the module's existing parameters and
    # does not relocate them, so move the model to the device the inputs use.
    model.to(device)
    model.eval()

    def encode(observation):
        """Resize, scale and encode one image; return its code indices."""
        # cv2 takes dsize as (width, height), so the result is 148 rows x 196 cols.
        # Assumes `observation` is an H x W x C image array — TODO confirm.
        x = cv2.resize(observation, (196, 148), interpolation = cv2.INTER_AREA)
        x = x / 255.0
        x = torch.Tensor(x)
        # Add the batch axis first: the 4-axis permute below is only valid on
        # a 4-D tensor (N, H, W, C) -> (N, C, H, W).
        x = x.unsqueeze(0)
        x = x.permute(0, 3, 1, 2)
        x = x.to(device)

        # Only the fourth value returned by model.encode is used.
        _, _, _, indices, _ = model.encode(x)
        # Drop the batch axis and hand back a host-side array.
        encoded = indices.cpu().numpy()[0]

        return encoded

    return encode

In [None]:
#@title Evaluate
#@markdown **Algorithm**
iterations = 10000 #@param {type:"integer"}
method = 'ram' #@param ["ram", "trajectory"]
verbose = 1 #@param [0, 1, 2]
repeat = 0.95 #@param {type:"number"}
nsteps = 100 #@param {type:"integer"}
seed = 42 #@param {type:"integer"}
#@markdown **Logging**
# Backslash followed by 's'. Written with an escaped backslash because a bare
# '\s' is an invalid escape sequence; the runtime value is unchanged.
# The 'delimeter' spelling matches the keyword expected by run_for below.
delimeter = '\\s' #@param {type:"string"}
separator = True #@param {type:"boolean"}

# NOTE(review): MontezumaRevenge and GoExplore are not imported in any visible
# cell — presumably they come from `from utils import *`; confirm.
# NOTE(review): the environment is hard-coded while `dataset` is chosen in the
# download cell; verify that it matches the data the loaded model was trained on.
print ('Creating environment...')
env = MontezumaRevenge()

print ('Starting algorithm...')
goexplore = GoExplore(env)
# The encoder built from the latest checkpoint (PATH) is supplied as the cell function.
goexplore.initialize(repeat = repeat, nsteps = nsteps, seed = seed, method = method, cellfn = encode_fn(PATH))
goexplore.run_for(iterations, verbose = verbose, separator = separator, delimeter = delimeter)