In [121]:
# Notebook-wide settings: the game key that later cells branch on, the device
# identifier passed as `dev=` to the model and data-pipeline helpers, and the
# path of each game's integer-encoded chunk file.
global_params = dict(
    current_game='smba',
    dev='cpu',
    smba_chunks_json_fpath='../data/smbWithPath-allLevels-chunks-int.json',
    kia_chunks_json_fpath='../data/kiWithPath-allLevels-chunks-int.json',
)

In [105]:
import numpy as np
import torch

In [109]:
def log_progress(sequence, every=None, size=None, name='Items'):
    """Yield the records of ``sequence`` while driving a progress widget.

    An ``IntProgress`` bar and an ``HTML`` label are shown inside a ``VBox``.
    They are refreshed on the first record and afterwards on every
    ``every``-th record. When iteration finishes normally the bar is marked
    ``'success'``; if anything is raised inside the generator (including it
    being closed early) the bar is marked ``'danger'`` and the exception is
    re-raised.

    :param sequence: iterable to wrap; when it has no ``len()`` and ``size``
        is not given, the total is unknown and the label shows ``?`` instead
    :param every: refresh interval in records; when omitted and the total is
        known it is 1 for totals up to 200 and ``int(size / 200)`` otherwise
    :param size: total number of records, used instead of ``len(sequence)``
    :param name: prefix of the label text
    :raises AssertionError: if the total is unknown and ``every`` is not set
    :return: generator yielding each record of ``sequence`` unchanged
    """
    from ipywidgets import IntProgress, HTML, VBox
    from IPython.display import display

    if size is None:
        try:
            size = len(sequence)
        except TypeError:
            pass  # no len(): the total stays unknown
    total_unknown = size is None

    if total_unknown:
        assert every is not None, 'sequence is iterator, set every'
    elif every is None:
        every = 1 if size <= 200 else int(size / 200)     # every 0.5%

    if total_unknown:
        # Nothing to measure against: show a full bar in the 'info' style.
        progress = IntProgress(min=0, max=1, value=1)
        progress.bar_style = 'info'
    else:
        progress = IntProgress(min=0, max=size, value=0)
    label = HTML()
    display(VBox(children=[label, progress]))

    def refresh(position):
        # Push the current position to the widgets.
        if total_unknown:
            label.value = '{name}: {index} / ?'.format(name=name, index=position)
            return
        progress.value = position
        label.value = u'{name}: {index} / {size}'.format(
            name=name, index=position, size=size
        )

    count = 0
    try:
        for count, item in enumerate(sequence, 1):
            # The first record always refreshes; the modulo is only evaluated afterwards.
            if count == 1 or count % every == 0:
                refresh(count)
            yield item
    except BaseException:
        progress.bar_style = 'danger'
        raise
    else:
        progress.bar_style = 'success'
        progress.value = count
        label.value = "{name}: {index}".format(name=name, index=str(count or '?'))

## Get VAE and its optimizer

In [125]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [126]:
import sys
sys.path.append('../modules')

In [127]:
# Choose the encoder/decoder input shapes for the selected game.
# The first entry of enc_input_shape differs per game (12 for 'smba', 7 for
# 'kia'); presumably it is the number of tile-type channels — TODO confirm.
if global_params['current_game'] == 'smba':
    enc_input_shape = (12, 16, 16)
    dec_input_shape = (64, 1, 1)
elif global_params['current_game'] == 'kia':
    enc_input_shape = (7, 16, 16)
    dec_input_shape = (64, 1, 1)
else:
    # Fail fast: otherwise the shapes stay undefined (or keep stale values from
    # an earlier run of this cell in the same kernel) and later cells misbehave.
    raise ValueError(
        f"Unsupported current_game {global_params['current_game']!r}; "
        "expected 'smba' or 'kia'."
    )

In [128]:
from vae_designer import VAEDesigner
from custom_vae import VAEDesign

In [136]:
# Switch for the design cells: every `if designing:` cell in this section is
# skipped while this is False.
designing = False

In [130]:
# Create the designer for the down-sampling side (up_sample=False), starting
# from the encoder input shape with 3 layers.
if designing:
    enc_designer = VAEDesigner(input_shape=enc_input_shape, num_layers=3, up_sample=False)

VBox(children=(HBox(children=(Label(value='DOWNSAMPLING MODE'),)), HBox(children=(Button(description='ADD LAYE…

In [131]:
# Create the designer for the up-sampling side (up_sample=True), starting
# from the decoder input shape with 3 layers.
if designing:
    dec_designer = VAEDesigner(input_shape=dec_input_shape, num_layers=3, up_sample=True)

VBox(children=(HBox(children=(Label(value='UPSAMPLING MODE'),)), HBox(children=(Button(description='ADD LAYER'…

In [132]:
# Combine the `.design` of both designers into a single VAEDesign.
# NOTE(review): unflatten_out_shape repeats the literal (64, 1, 1) that is also
# assigned to dec_input_shape — presumably the two must agree; confirm.
if designing:
    vae_design = VAEDesign(
        down_sampler_design=enc_designer.design, 
        up_sampler_design=dec_designer.design, 
        h_dim=64, 
        z_dim=32, 
        unflatten_out_shape=(64, 1, 1)
    )

In [135]:
# Save the design as JSON under a per-game file name; the same path pattern is
# later passed to get_vae_and_opt as design_json_fpath.
if designing:
    current_game = global_params['current_game']
    vae_design.save_as_json(f'designs/pcgml_gmmvae_{current_game}.json')

In [137]:
from functools import partial
from custom_vae import get_vae_and_opt

In [138]:
# Show the docstring of get_vae_and_opt for reference.
print(get_vae_and_opt.__doc__)

Get a trainable VAE and its optimizer.


In the cell below, we use `functools.partial` to bind the design JSON path and the device to `get_vae_and_opt`, giving a single function that can be called with no arguments to get a trainable VAE and its optimizer.

In [141]:
# Pre-bind the per-game design JSON path and the device, so that
# get_vae_and_opt_gmmvae() can be called without arguments.
current_game = global_params['current_game']
get_vae_and_opt_gmmvae = partial(
    get_vae_and_opt, 
    design_json_fpath=f'designs/pcgml_gmmvae_{current_game}.json',
    dev=global_params['dev'],
)

## Get dataloader

In [142]:
from vglc_with_path_encodings import array_from_json, array_to_image

In [143]:
# Load the integer-encoded level chunks for the selected game.
if global_params['current_game'] == 'smba':
    np_int_chunks = array_from_json(global_params['smba_chunks_json_fpath'])
elif global_params['current_game'] == 'kia':
    np_int_chunks = array_from_json(global_params['kia_chunks_json_fpath'])
else:
    # Fail fast: otherwise np_int_chunks stays undefined (or keeps data loaded
    # for another game earlier in the same kernel session).
    raise ValueError(
        f"Unsupported current_game {global_params['current_game']!r}; "
        "expected 'smba' or 'kia'."
    )

2698 chunks loaded from ../data/smbWithPath-allLevels-chunks-int.json.


In [144]:
from fast_train import DataPipeline

In [145]:
# Show the docstring of the data-pipeline helper used in the next cell.
print(DataPipeline.pcgml_gmmvae.__doc__)


        Convert an array of 2d arrays of integers to a WrappedDataLoader that can be used to train a binary VAE.
        A binary VAE is a VAE that takes in and outputs one-hot encoded arrays.
        
        :param np_int_imgs: an array of 2d arrays of integers, shape: (bs, height, width)
        :param bs: batch_size
        :param shuffle: whether training examples and targets are shuffled, True for training, False for validation
        :return a WrappedDataLoader instance that can be used directly for training a binary VAE
        


In [146]:
# Build the training data loader from the integer chunks (batch size 64).
# num_channels is the number of distinct integer values present in the data.
# NOTE(review): this presumably has to equal enc_input_shape[0] (12 for 'smba',
# 7 for 'kia'); if some value never occurs in the chunks the two would
# disagree — confirm.
train_dl = DataPipeline.pcgml_gmmvae(
    np_int_chunks, 
    num_channels=len(np.unique(np_int_chunks)),
    bs=64,
    dev=global_params['dev']
)

## Get loss functions

In [147]:
from fast_train import Loss

In [148]:
# Loss callable handed to the Learner. The training loop calls it as
# loss(recon, target, mu, logvar) and unpacks three values (total, bce, kld).
loss = Loss.bce_kld_total

## Integrate dataloader, model, loss and optimizer using Learner

In [149]:
from learner import Learner

In [150]:
def get_new_learner():
    """Return a Learner wrapping a freshly created VAE and its optimizer.

    Each call builds a new model/optimizer pair via ``get_vae_and_opt_gmmvae``
    and combines it with the notebook-level ``train_dl`` and ``loss``.
    """
    vae, optimizer = get_vae_and_opt_gmmvae()
    learner_kwargs = dict(train_data=train_dl, model=vae, loss=loss, optim=optimizer)
    return Learner(**learner_kwargs)

## Train

In [151]:
from generic_callbacks import *

In [152]:
class GenLogger(Callback):
    """Visualize generated arrays for generative models.

    At the end of every epoch the tensor stored in the shared state under
    ``gen_name`` is collapsed with an argmax over axis 1, rendered with
    ``array_to_image`` and passed to ``self.sd['writer'].add_images`` under
    the tag ``{group}/{gen_name}``.
    """
    # NOTE(review): presumably used by the callback handler to order callbacks — confirm.
    _order = 1

    def __init__(self, gen_name: str, group: str):
        """
        :param gen_name: key in ``self.sd`` holding the tensor to visualize;
            also the second part of the logged tag
        :param group: first part of the logged tag
        """
        self.gen_name = gen_name
        self.group = group

    def on_epoch_end(self):
        """Render ``self.sd[self.gen_name]`` and log it at the current epoch."""
        torch_binary_imgs = self.sd[self.gen_name]
        # Tensor.numpy() raises for tensors that require grad (e.g. raw model
        # output) or that live on a non-CPU device, so detach and move first.
        np_binary_imgs = torch_binary_imgs.detach().cpu().numpy()
        np_int_imgs = np.argmax(np_binary_imgs, axis=1)
        # [:-1] drops the last character of the game key ('smba' -> 'smb', 'kia' -> 'ki').
        pil_imgs = array_to_image(np_int_imgs, game=global_params['current_game'][:-1])
        np_rgb_imgs = np.array([np.array(im) for im in pil_imgs])
        # NOTE(review): arrays built from PIL images are usually channel-last;
        # confirm that the writer's add_images expects this layout.
        torch_rgb_imgs = torch.from_numpy(np_rgb_imgs)

        self.sd['writer'].add_images(
            f'{self.group}/{self.gen_name}',
            torch_rgb_imgs,
            global_step=self.sd['epoch']
        )

In [153]:
class VAETrainer(CallbackHandler):
    """Training loop for a VAE that fires callback hooks around every stage.

    The ``on_*`` hooks and the shared state dict ``self.sd`` are used here but
    not defined in this class; presumably they come from ``CallbackHandler``
    — confirm.
    """

    def __init__(self, learn, cbs):
        """
        :param learn: object exposing ``train_data``, ``model``, ``loss`` and ``optim``
        :param cbs: list of callbacks to run during training
        """
        self.learn = learn
        self.cbs = cbs

    def train(self, num_epochs):
        """Train ``self.learn.model`` for ``num_epochs`` passes over the training data.

        After each batch the per-batch metrics are written to ``self.sd``;
        after each epoch the last batch's inputs and reconstructions plus five
        generated samples are stored under ``'orgs'``, ``'recons'`` and ``'gens'``.

        :param num_epochs: number of passes over ``self.learn.train_data``
        """
        self.on_train_begin()  # create empty accumulators
        for epoch in range(num_epochs):
            self.on_epoch_begin()
            for xb, yb in log_progress(self.learn.train_data, name=f'epoch: {epoch+1}'):

                self.on_batch_begin()

                recon, mu, logvar = self.learn.model(xb)

                self.on_loss_begin()

                loss, bce, kld = self.learn.loss(recon, yb, mu, logvar)
                batch_size = recon.size(0)
                # Per-example averages.
                # NOTE(review): the extra factor 64 on kld is presumably a
                # latent/hidden dimension — confirm.
                loss, bce, kld = loss / batch_size, bce / batch_size, kld / (batch_size * 64)

                # Clear gradients left over from the previous batch; without this
                # backward() keeps adding to them. Harmless if a callback already
                # zeroes them.
                self.learn.optim.zero_grad()
                self.on_backward_begin()
                loss.backward()
                self.on_backward_end()
                self.learn.optim.step()
                self.on_step_end()

                self.sd.update({
                    'model': self.learn.model,
                    'epoch': epoch + 1,
                    'batch_size': int(xb.size(0)),
                    'loss_b': float(loss),
                    'bce_b': float(bce),
                    'kld_b': float(kld)
                })

                self.on_batch_end()

            # Keep the last batch for visualization. The reconstruction is detached
            # so the stored tensor does not retain the autograd graph and can be
            # converted to numpy by logging callbacks.
            self.sd.update({'orgs': xb, 'recons': recon.detach()})

            gens = self.learn.model.generate(n=5)
            self.sd.update({'gens': gens})

            self.on_epoch_end()  # calculate average loss per example, visualize metrics, perform validation if required
        self.on_train_end()  # close tensorboard writer, output csv of metrics

In [156]:
# Experiment identity and overwrite switches. exp_name and trial determine the
# tensorboard log dir, the metrics CSV path and the saved-model path below.
exp_name = 'pcgml_gmmvae'
trial = 1
overwrite_vis = True
overwrite_csv = True
overwrite_pth = True
debugger_on = False

learner = get_new_learner()

# Callbacks passed to the trainer. The GenLogger keys 'orgs', 'recons' and
# 'gens' are the ones VAETrainer.train writes to the shared state every epoch.
# NOTE(review): 'losss' looks like metric name + 's' (cf. 'bces', 'klds')
# rather than a typo — confirm against MetricsSaver before changing it.
vae_cbs = [
    TensorboardCreator(log_dir=f'runs/{exp_name}/{trial}', overwrite=overwrite_vis),  # automatically mkdir
    MetricLogger(metric_name='loss', group='train', on_tensorboard=True),
    MetricLogger(metric_name='bce', group='train', on_tensorboard=True),
    MetricLogger(metric_name='kld', group='train', on_tensorboard=True),
    MetricsPrinter(metrics_to_print=['last_loss', 'last_bce', 'last_kld']),
    MetricsSaver(metrics_to_save=['losss', 'bces', 'klds'], csv_path=f'training_csv/{exp_name}/{trial}.csv', overwrite=overwrite_csv),
    GenLogger(gen_name='orgs', group='group1'),
    GenLogger(gen_name='recons', group='group1'),
    GenLogger(gen_name='gens', group='group1'),
    ModelSaver(model_path=f'trained_models/{exp_name}/{trial}.pth', overwrite=overwrite_pth),
    Debugger(on=debugger_on),
]

# Wire the learner and the callback list into the trainer defined above.
vae_trainer = VAETrainer(learn=learner, cbs=vae_cbs)

In [None]:
# Run the training loop for 10 epochs; one progress widget is shown per epoch.
vae_trainer.train(num_epochs=10)

VBox(children=(HTML(value=''), IntProgress(value=0, max=43)))