In [1]:
from __future__ import print_function
import argparse, random, torch, os, math, json, sys, re
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torchvision.utils as vutils
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from model_lin import get_model, load_model
from torch.utils.data import DataLoader, Dataset, TensorDataset

device = torch.device('cuda')
latent_dim = 32 # size of latent vector
batch_size = 32 # input batch size
GAME = 'mm'

smb_folder = 'smb_chunks_all/'
ki_folder = 'ki_chunks_all/'
mm_folder = 'mm_chunks_all/'
ng_folder = 'ng_chunks/'
met_folder = 'met_chunks_all/'

folders = {'smb':smb_folder,'ki':ki_folder,'mm':mm_folder,'ng':ng_folder,'met':met_folder}

folder = folders[GAME]
#manual_seed = random.randint(1, 10000)
#random.seed(manual_seed)
torch.manual_seed(0)
np.random.seed(0)

In [2]:
def natural_sort(l): 
    convert = lambda text: int(text) if text.isdigit() else text.lower() 
    alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ] 
    return sorted(l, key = alphanum_key)
    
def parse_folder(folder):
    levels, dirs = [], []
    text = ''
    files = os.listdir(folder)
    files[:] = (value for value in files if value != '.')
    files = natural_sort(files)
    for file in files:
        if file.startswith('.'):
            continue
        with open(os.path.join(folder,file),'r') as infile:
            #print(infile)
            level = []
            for line in infile:
                text += line
                level.append(list(line.rstrip()))
            #level = [list(line.rstrip()) for line in infile]
            #level = [line.rstrip() for line in infile]
            #print(level)
            levels.append(level)
            #print(levels)
    return levels, text

In [3]:
def pipe_check(level):
    temp = ''
    for l in level:
        temp += ''.join(l)
    if '[' in temp and ']' not in temp:
        return False
    if ']' in temp and '[' not in temp:
        return False
    if '<' in temp and '>' not in temp:
        return False
    if '>' in temp and '<' not in temp:
        return False
    return True


next_dirs = []
levels, text = parse_folder(folder)

text = text.replace('\n','')
print(len(levels))
print("Num batches: ", len(levels)/batch_size)
chars = sorted(list(set(text.strip('\n'))))
#sketch_chars = ['X','E','|','*','P','-']
sketch_chars = ['X','E','|','*','-']
#print(lr_chars)
int2char = dict(enumerate(chars))
int2sc = dict(enumerate(sketch_chars))
char2int = {ch: ii for ii, ch in int2char.items()}
sc2int = {ch: ii for ii, ch in int2sc.items()}
print(char2int)
print(sc2int)
print(int2sc)
num_tiles = len(char2int)
num_sketch_tiles = len(sc2int)
print('Tiles: ', num_tiles)
print('Sketch Tiles: ', num_sketch_tiles)


def translate_mm(level):
    outs = []
    label_size = len(sc2int)-1
    for p in range(1,int(math.pow(2,label_size))):
        label = [int(i) for i in bin(p)[2:]]
        label = [0] * (label_size - len(label)) + label
        ones = [i for i, j in enumerate(label) if j == 1]
        t_level = []
        for line in level:
            t_line = ''
            for c in line:
                if c in '#BM':
                    t_line += 'X'
                elif c in 'CHt':
                    t_line += 'E'
                elif c in 'D|':
                    t_line += '|'
                elif c in '*+LUWlw':
                    t_line += '*'
                else:
                    t_line += c
            t_line = [l if sc2int[l] in ones else '-' for l in t_line]
            t_level.append(t_line)
        if t_level not in outs:
            outs.append(t_level)
    return outs

def translate_ki(level):
    outs = []
    label_size = len(sc2int)-1
    for p in range(1,int(math.pow(2,label_size))):
        label = [int(i) for i in bin(p)[2:]]
        label = [0] * (label_size - len(label)) + label
        ones = [i for i, j in enumerate(label) if j == 1]
        t_level = []
        for line in level:
            t_line = ''
            for c in line:
                if c in '#MT':
                    t_line += 'X'
                elif c in 'H':
                    t_line += 'E'
                elif c in 'D':
                    t_line += '|'
                else:
                    t_line += c
            t_line = [l if sc2int[l] in ones else '-' for l in t_line]
            t_level.append(t_line)
        if t_level not in outs:
            outs.append(t_level)
    return outs

def translate_ki_lean(level):
    t_level = []
    for line in level:
        t_line = ''
        for c in line:
            if c in '#MT':
                t_line += 'X'
            elif c in 'H':
                t_line += 'E'
            elif c in 'D':
                t_line += '|'
            elif c in 'P':
                t_line += '-'
            else:
                t_line += c
        t_level.append(t_line)
    return [t_level]


def translate_smb(level):
    outs = []
    label_size = len(sc2int)-1
    for p in range(1,int(math.pow(2,label_size))):
        label = [int(i) for i in bin(p)[2:]]
        label = [0] * (label_size - len(label)) + label
        ones = [i for i, j in enumerate(label) if j == 1]
        t_level = []
        for line in level:
            t_line = ''
            for c in line:
                if c in 'X<>[]S':
                    t_line += 'X'
                elif c in 'o?Q':
                    t_line += '*'
                else:
                    t_line += c
            t_line = [l if sc2int[l] in ones else '-' for l in t_line]
            t_level.append(t_line)
        if t_level not in outs:
            outs.append(t_level)
    return outs

def translate_smb_lean(level):
    t_level = []
    for line in level:
        t_line = ''
        for c in line:
            if c in 'X<>[]S':
                t_line += 'X'
            elif c in 'o?Q':
                t_line += '*'
            elif c in 'Bb':
                t_line += 'E'
            else:
                t_line += c
        t_level.append(t_line)
    return [t_level]

def translate_mm_lean(level):
    t_level = []
    for line in level:
        t_line = ''
        for c in line:
            if c in '#BM':
                t_line += 'X'
            elif c in 'CHt':
                t_line += 'E'
            elif c in 'D|':
                t_line += '|'
            elif c in '*+LUWlw':
                t_line += '*'
            elif c in 'P':
                t_line += '-'
            else:
                t_line += c
        t_level.append(t_line)
    return [t_level]

            
translate = {'smb':translate_smb_lean, 'ki':translate_ki_lean, 'mm':translate_mm_lean}
inputs, targets = [], []
for level in levels:
    if GAME == 'smb' and not pipe_check(level):
        continue
    tar, inp = [], []
    translate_func = translate[GAME]
    t_levels = translate_func(level)
    for t_level in t_levels:
        inp = []
        for line in t_level:
            encoded_line = [sc2int[x] for x in line]
            inp.append(encoded_line)
        inputs.append(inp)
    for line in level:
        encoded_line = [char2int[x] for x in line]
        tar.append(encoded_line)
    for _ in range(len(t_levels)):
        targets.append(tar)
    
inputs = np.array(inputs)
targets = np.array(targets)
print(inputs.shape, targets.shape)

inputs_onehot = np.eye(num_sketch_tiles, dtype='uint8')[inputs]
inputs_onehot = np.rollaxis(inputs_onehot, 3, 1)
targets_onehot = np.eye(num_tiles, dtype='uint8')[targets]
targets_onehot = np.rollaxis(targets_onehot, 3, 1)

inputs_train = torch.from_numpy(inputs_onehot).to(dtype=torch.float64)
targets_train = torch.from_numpy(targets_onehot).to(dtype=torch.float64)
train_ds = TensorDataset(inputs_train,targets_train)
train_dl = DataLoader(train_ds, batch_size=batch_size,shuffle=True)

vae, opt = get_model(device, 240, num_sketch_tiles, num_tiles, latent_dim,1e-3)
#vae, opt = get_conv_big_model(device, num_sketch_tiles, num_tiles, latent_dim)
scheduler = optim.lr_scheduler.StepLR(opt, step_size=2500)
print(vae)
#sys.exit()

def loss_fn(recon_x, x, mu, logvar):
    #BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')/recon_x.size(0)
    #BCE = F.binary_cross_entropy(recon_x, x, size_average=False)
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD, BCE, KLD

143
Num batches:  4.46875
{'#': 0, '*': 1, '+': 2, '-': 3, 'B': 4, 'C': 5, 'H': 6, 'L': 7, 'M': 8, 'P': 9, 'U': 10, 'W': 11, 'l': 12, 't': 13, 'w': 14, '|': 15}
{'X': 0, 'E': 1, '|': 2, '*': 3, '-': 4}
{0: 'X', 1: 'E', 2: '|', 3: '*', 4: '-'}
Tiles:  16
Sketch Tiles:  5
(143, 15, 16) (143, 15, 16)
VAE(
  (encoder): Encoder(
    (fc1): Linear(in_features=1200, out_features=1024, bias=True)
    (fc2): Linear(in_features=1024, out_features=512, bias=True)
    (fc3): Linear(in_features=512, out_features=256, bias=True)
    (fc41): Linear(in_features=256, out_features=32, bias=True)
    (fc42): Linear(in_features=256, out_features=32, bias=True)
  )
  (decoder): Decoder(
    (fc5): Linear(in_features=32, out_features=256, bias=True)
    (fc6): Linear(in_features=256, out_features=512, bias=True)
    (fc7): Linear(in_features=512, out_features=1024, bias=True)
    (fc8): Linear(in_features=1024, out_features=3840, bias=True)
  )
)


In [4]:
model_name = 'filter_vae_fc_nored_' + GAME + '_ld_' + str(latent_dim)
out_file = open(model_name + '_loss.csv','w')
out_file.write('Train Loss,KLD,KLW\n')
epochs = 10000 # num epochs to train for
k = 0.0025
rate = 2500

for i in range(epochs):
    vae.train()
    train_loss = 0
    kld_loss = 0
    for batch, (x,y) in enumerate(train_dl):
        x, y = x.to(device), y.to(device)
        x_lin = x.view(x.size(0),-1)
        print(x_lin.shape)
        sys.exit()
        opt.zero_grad()
        recon_x_lin, mu, logvar, z = vae(x_lin)
        recon_x = recon_x_lin.reshape(recon_x_lin.size(0),y.size(1),x.size(2),x.size(3))
        #print(x.shape, recon_x.shape, y.shape)
        #sys.exit()
        #recon_x, mu, logvar, z = vae(x)
        loss, bce, kld = loss_fn(recon_x, y, mu, logvar)
        klw = min(1.0, i/rate)
        loss = bce + kld*klw
        train_loss += loss.item()
        kld_loss += kld.item()
        loss.backward()
        opt.step()
    train_loss /= len(train_dl.dataset)
    kld_loss /= len(train_dl.dataset)
    if i % 100 == 0:
        print('Epoch: ', i,'\tLoss: ',train_loss,"\tKLD: ", kld_loss, "\tKLW: ", klw)
    if i % 1000 == 0:
        torch.save(vae.state_dict(), model_name + '_' + str(i) + '.pth')
        out_file.write(str(train_loss)+','+str(kld_loss)+','+str(klw)+'\n')
    scheduler.step()
print('Epoch: ', i,'\tLoss: ',train_loss,"\tKLD: ", kld_loss, "\tKLW: ", klw)
torch.save(vae.state_dict(), model_name + '_final' + '.pth')
out_file.close()

torch.Size([32, 1200])


SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
