In [None]:
from __future__ import print_function
import argparse
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torchvision.utils as vutils
from torch.autograd import Variable
import os
import numpy as np
import matplotlib.pyplot as plt
import math
import json
import sys
from PIL import Image
from model_lin_cond import get_cond_model, load_cond_model
from torch.utils.data import DataLoader, Dataset, TensorDataset
import re
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn import metrics
import pickle

# ---- Run configuration ----------------------------------------------------
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# end-to-end on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
GAME = 'blend'  # smb, ki, mm, smb_pats, blend
latent_dim = 32 # size of latent vector
batch_size = 32 # input batch size
is_pats = GAME == 'smb_pats'  # only smb_pats parses pattern codes out of the file names

# Folders holding the level chunks of each game.
smb_folder = 'smb_chunks_fixed/'
ki_folder = 'ki_chunks/'
mm_folder = 'mm_chunks_fixed/'
pats_folder = 'smb_pats'

# 'blend' has no folder of its own: it loads the smb, ki and mm folders together.
folders = {'smb':smb_folder,'ki':ki_folder,'mm':mm_folder,'blend':None,'smb_pats':pats_folder}
# Number of conditioning labels per game.
# NOTE: the name `labels` is rebound to the per-level label list further down the notebook.
labels = {'smb':5, 'ki':4, 'mm':5, 'blend':3,'smb_pats':10}
num_labels = labels[GAME]
folder = folders[GAME]

# Fixed seeds so a fresh "Restart & Run All" reproduces the same run.
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)

In [None]:
def natural_sort(l):
    """Return a new list sorted in "natural" order.

    Each string is split into alternating non-digit / digit runs; digit runs
    are compared as integers and the remaining runs case-insensitively, so
    'a2' sorts before 'a10'.
    """
    def _natural_key(name):
        pieces = re.split('([0-9]+)', name)
        return [int(piece) if piece.isdigit() else piece.lower() for piece in pieces]

    return sorted(l, key=_natural_key)
    
def parse_folder(folder,pats=False,dir=False):
    """Read every level file in `folder`, visiting files in natural-sort order.

    Args:
        folder: directory containing one text file per level chunk.
        pats: when True, also extract the pattern codes embedded in each
            file name (see the comment in the loop body).
        dir: when True, return the per-file `dirs` list instead of `patterns`.

    Returns:
        (levels, text, dirs) if `dir` is true, otherwise (levels, text, patterns).
        `levels` holds one list of rows per file, each row a list of characters
        with trailing whitespace stripped; `text` is the raw concatenation of
        every line read (newlines included); `dirs` holds, per file, the fifth
        character from the end of its name; `patterns` holds, per file, the
        list of '_'-separated codes taken from its name (empty unless `pats`).
    """
    levels, dirs, patterns = [], [], []
    text_parts = []  # joined once at the end instead of repeated `text += line`
    for file in natural_sort(os.listdir(folder)):
        path = os.path.join(folder, file)
        # Skip hidden entries (names starting with '.') and anything that is
        # not a regular file, e.g. a sub-directory, which open() cannot read.
        if file.startswith('.') or not os.path.isfile(path):
            continue
        if pats:
            # Drop everything up to and including the third '_' as well as the
            # last four characters of the name; the rest is split on '_'.
            ps1 = file[file.find('_')+1:]
            ps2 = ps1[ps1.find('_')+1:]
            ps3 = ps2[ps2.find('_')+1:-4]
            patterns.append(ps3.split('_'))
        with open(path,'r') as infile:
            lines = infile.readlines()
        text_parts.extend(lines)
        levels.append([list(line.rstrip()) for line in lines])
        dirs.append(file[-5])
    text = ''.join(text_parts)
    if dir:
        return levels, text, dirs
    return levels, text, patterns

In [None]:
def get_label_smb(level):
    """Build the 5-element boolean label for an SMB level.

    Order: Enemy, Pipes, Coins, Breakable, QMs. The pipe flag is set only
    when all four of '[', ']', '<' and '>' occur somewhere in the level.
    """
    tiles = ''.join(''.join(row) for row in level)
    has_enemy = 'E' in tiles
    has_pipe = all(part in tiles for part in '[]<>')
    has_coin = 'o' in tiles
    has_breakable = 'S' in tiles
    has_qm = any(mark in tiles for mark in 'Q?')
    return [has_enemy, has_pipe, has_coin, has_breakable, has_qm]

def get_label_ki(level):
    """Build the 4-element boolean label for a KI level.

    Order: Enemy ('H'), Doors ('D'), Moving ('M'), 'T'. Each flag records
    whether that character occurs anywhere in the level.
    """
    present = set(''.join(''.join(row) for row in level))
    return [code in present for code in ('H', 'D', 'M', 'T')]

def get_label_mm(level):
    """Build the 5-element boolean label for an MM level.

    Order: any of H/T/C, 'D', '|', 'M', any of */U/W/w/+/l/L. Each flag
    records whether one of the listed characters occurs in the level.
    """
    present = set(''.join(''.join(row) for row in level))
    first_group = not present.isdisjoint('HTC')
    last_group = not present.isdisjoint('*UWw+lL')
    return [first_group, 'D' in present, '|' in present, 'M' in present, last_group]

def get_label_pat(pat):
    """Build the 10-element boolean label for a pattern description.

    Order: EH, G, PV, GV, NV, EV, MP, RR, SU, SD. Each flag is the result of
    `code in pat`, so `pat` may be any container supporting that test.
    """
    codes = ('EH', 'G', 'PV', 'GV', 'NV', 'EV', 'MP', 'RR', 'SU', 'SD')
    return [code in pat for code in codes]

def get_label_pat_lean(pat):
    """Build the condensed 5-element boolean label for a pattern description.

    Order: EH, G, V, MP, S, where
      * V is set by any of PV / GV / NV / EV (or a bare 'V'), and
      * S is set by either SU or SD (or a bare 'S').

    `pat` is normally the list of codes parsed from a file name. A plain
    '_'-separated string is split into codes first, so 'G' only matches the
    code 'G' itself and not the 'G' inside 'GV'.

    The previous implementation tested `'V' in pat` and `'S' in pat`, which
    on a list of codes only matched the literal codes 'V' and 'S', so codes
    such as 'PV' or 'SU' never set those two labels.
    NOTE(review): the V / S grouping is inferred from the code names used by
    get_label_pat - confirm it matches the intended label scheme.
    """
    codes = pat.split('_') if isinstance(pat, str) else list(pat)
    v_codes = ('V', 'PV', 'GV', 'NV', 'EV')
    s_codes = ('S', 'SU', 'SD')
    return [
        'EH' in codes,
        'G' in codes,
        any(code in v_codes for code in codes),
        'MP' in codes,
        any(code in s_codes for code in codes),
    ]

def pipe_check(level):
    """Return True when the level's pipe characters come in matching pairs.

    '[' must be accompanied by ']' somewhere in the level (and vice versa),
    and likewise '<' with '>'. A level containing none of them passes.
    """
    tiles = ''.join(''.join(row) for row in level)
    for left, right in (('[', ']'), ('<', '>')):
        # Exactly one half of a pair being present means a broken pipe.
        if (left in tiles) != (right in tiles):
            return False
    return True

# Label function for the selected game. 'blend' has none: its label is the
# game each level came from.
all_get_labels = {'smb':get_label_smb, 'ki':get_label_ki, 'mm':get_label_mm, 'smb_pats':get_label_pat, 'blend':None}
get_label = all_get_labels[GAME]


def _build_vocab(raw_text):
    """Return (text, chars, int2char, char2int) for the tiles in `raw_text`.

    Newlines are removed first; `chars` is the sorted list of distinct
    remaining characters, and the two dicts map index <-> character.
    """
    flat_text = raw_text.replace('\n','')
    tile_chars = sorted(set(flat_text))
    index_to_char = dict(enumerate(tile_chars))
    char_to_index = {ch: ii for ii, ch in index_to_char.items()}
    return flat_text, tile_chars, index_to_char, char_to_index


if GAME != 'blend':
    levels, text, pats = parse_folder(folder,is_pats)
    print(len(levels))
    print("Num batches: ", len(levels)/batch_size)
    text, chars, int2char, char2int = _build_vocab(text)
    print(char2int)
    num_tiles = len(char2int)
    print(num_tiles)
    print(len(levels))
else:
    smb_levels, smb_text, _ = parse_folder(smb_folder)
    ki_levels, ki_text, _ = parse_folder(ki_folder)
    mm_levels, mm_text, _ = parse_folder(mm_folder)
    # One shared vocabulary across the three games.
    text, chars, int2char, char2int = _build_vocab(smb_text + ki_text + mm_text)
    print(char2int)
    num_tiles = len(char2int)
    print(num_tiles)
    # Every KI level is included twice - presumably to offset a smaller KI
    # set; confirm this oversampling is still wanted.
    ki_levels = ki_levels + ki_levels
    print(len(smb_levels), len(ki_levels), len(mm_levels))
    levels = smb_levels + ki_levels + mm_levels
    print(len(levels))
    

# Integer-encode every level with char2int and build one uint8 label vector
# per encoded level, so `encoded` and `labels` always stay aligned.
encoded, labels = [], []
if GAME == 'smb_pats':
    for level, pat in zip(levels,pats):
        if not pipe_check(level):
            continue
        encoded.append([[char2int[x] for x in line] for line in level])
        # The label was previously computed but never stored, which left
        # `labels` empty for this game.
        labels.append(np.array(get_label(pat)).astype('uint8'))
else:
    if GAME == 'blend':
        # `levels` is smb_levels + ki_levels + mm_levels, so a level's position
        # identifies its game. This avoids a deep-list scan per level and the
        # mislabelling of a chunk that appears in more than one game.
        smb_end = len(smb_levels)
        ki_end = smb_end + len(ki_levels)
    for idx, level in enumerate(levels):
        if GAME == 'smb' and not pipe_check(level):
            continue
        encoded.append([[char2int[x] for x in line] for line in level])
        if GAME != 'blend':
            labels.append(np.array(get_label(level)).astype('uint8'))
        elif idx < smb_end:
            labels.append(np.array([True,False,False]).astype('uint8'))
        elif idx < ki_end:
            labels.append(np.array([False,True,False]).astype('uint8'))
        else:
            labels.append(np.array([False,False,True]).astype('uint8'))

encoded = np.array(encoded)
labels = np.array(labels)
# Fail early with a clear message if any level ended up without a label.
assert len(encoded) == len(labels), "encoded levels and labels are misaligned"
num_labels = labels.shape[1]
print(encoded.shape,labels.shape, num_labels)
# Index of the solid-block tile: 'X' for the SMB games, '#' otherwise.
block = char2int['X'] if GAME in ('smb', 'smb_pats') else char2int['#']

# One-hot encode each tile index, then move the new last axis (tiles) to
# position 1: (levels, rows, cols, tiles) -> (levels, tiles, rows, cols).
inputs_onehot = np.eye(num_tiles, dtype='uint8')[encoded]
inputs_onehot = np.moveaxis(inputs_onehot, 3, 1)
labels_onehot = labels

print(encoded.shape, labels_onehot.shape, labels_onehot[1], type(labels_onehot[0]))

# Flattened copy of the one-hot levels for the scikit-learn classifier.
inputs_class_train = inputs_onehot
inputs_class_train = inputs_class_train.reshape(inputs_class_train.shape[0],-1)
# Fixed random_state so re-running this cell yields the same split.
X_train, X_test, Y_train, Y_test = train_test_split(inputs_class_train, labels_onehot, test_size=0.2, random_state=0)

# The VAE trains on every level (no hold-out) as float64 tensors.
# NOTE(review): this assumes get_cond_model returns a double-precision model - confirm.
inputs_train = torch.from_numpy(inputs_onehot).to(dtype=torch.float64)
labels_train = torch.from_numpy(labels_onehot).to(dtype=torch.float64)
train_ds = TensorDataset(inputs_train,labels_train)
train_dl = DataLoader(train_ds, batch_size=batch_size,shuffle=True)

# The meaning of the positional 256 is defined in model_lin_cond - TODO confirm.
vae, opt = get_cond_model(device, 256, num_tiles, latent_dim, num_labels)
# Learning-rate decay every 2500 scheduler steps (gamma left at the library default).
scheduler = optim.lr_scheduler.StepLR(opt, step_size=2500)
print(vae)

def loss_fn(recon_x, x, mu, logvar):
    """Return (total, reconstruction, KL) loss terms for a VAE batch.

    The reconstruction term is the summed binary cross-entropy between
    `recon_x` and `x`. The KL term follows Appendix B of Kingma and Welling,
    "Auto-Encoding Variational Bayes" (ICLR 2014):
        -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    with `logvar` = log(sigma^2). `total` is their unweighted sum.
    """
    recon_term = F.binary_cross_entropy(recon_x, x, reduction='sum')
    kl_inner = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * torch.sum(kl_inner)
    total = recon_term + kl_term
    return total, recon_term, kl_term

In [None]:
model_name = 'cvae_' + GAME + '_ld_' + str(latent_dim)
epochs = 10000 # num epochs to train for
rate = 2500    # number of epochs over which the KL weight ramps from 0 to 1

for i in range(epochs):
    vae.train()
    # KL weight depends only on the epoch, so compute it once per epoch
    # (this also keeps it defined for the summary print below).
    klw = min(1.0, i/rate)
    train_loss, kld_loss = 0, 0
    for x, c in train_dl:
        x, c = x.to(device), c.to(device)
        # The model takes flattened levels; reshape its output back for the loss.
        x_lin = x.view(x.size(0),-1)
        opt.zero_grad()
        recon_x_lin, mu, logvar = vae(x_lin,c)
        recon_x = recon_x_lin.reshape(recon_x_lin.size(0),x.size(1),x.size(2),x.size(3))
        # Only the separate terms are used: the KL term is re-weighted by klw.
        _, bce, kld = loss_fn(recon_x, x, mu, logvar)
        loss = bce + kld*klw
        train_loss += loss.item()
        kld_loss += kld.item()
        loss.backward()
        opt.step()
    # Per-sample averages over the whole dataset.
    train_loss /= len(train_dl.dataset)
    kld_loss /= len(train_dl.dataset)
    if i % 100 == 0:
        print('Epoch: ', i,'\tLoss: ',train_loss,"\tKLD: ", kld_loss, "\tKLW: ", klw)
    if i % 1000 == 0:
        # Periodic checkpoint of the model weights.
        torch.save(vae.state_dict(), model_name + '_' + str(i) + '.pth')
    scheduler.step()
print('Epoch: ', i,'\tLoss: ',train_loss,"\tKLD: ", kld_loss, "\tKLW: ", klw)
torch.save(vae.state_dict(), model_name + '_final.pth')

In [None]:
# LABEL CLASSIFIER
# Random forest trained on the flattened one-hot levels of the selected GAME
# and saved to '<GAME>_classifier.pickle'.
clf = RandomForestClassifier(n_estimators=500)
clf.fit(X_train,Y_train)
y_pred = clf.predict(X_test)
print('Accuracy: ', metrics.accuracy_score(Y_test, y_pred) * 100)
classifier_path = GAME + '_classifier.pickle'
with open(classifier_path,'wb') as f:
    pickle.dump(clf,f)

# Round-trip check: reload the classifier that was just written and predict
# with it. pickle.load can execute arbitrary code, so only load pickle files
# from a trusted source. The context manager closes the file handle.
with open(classifier_path,'rb') as f2:
    clf2 = pickle.load(f2)
print(clf2)
print(X_test.shape)
print(X_test[0].shape)
Y_pred = clf2.predict(X_test)
print(Y_pred.shape)
print(Y_pred[0].shape)
print(Y_pred[0])
Y_pred_probs = clf2.predict_proba(X_test)
print(Y_pred_probs[0])