In [1]:
import torch, os, sys, json, math, random, re, pickle, warnings
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image
from scipy.spatial import distance
from sklearn import linear_model
from sklearn.metrics import mean_squared_error
from sklearn.manifold import TSNE
from model_lin import get_model, load_model
from model_conv import get_conv_model, load_conv_model
from model_conv_big import get_conv_big_model, load_conv_big_model

# Model architecture selector: LIN loads the fully connected VAEs through
# model_lin.load_model, CONV loads convolutional ones through
# model_conv_big.load_conv_big_model (see the model-loading cell).
LIN, CONV = 0, 1
MODEL_TYPE = LIN

# Silences every warning for the whole session.
warnings.filterwarnings("ignore")
# Prefer the GPU but fall back to the CPU: with a bare torch.device('cuda')
# the first .to(device) raises on a machine without CUDA.
# NOTE(review): load_model / load_conv_big_model may also need a map_location
# for CPU-only runs if the checkpoints were saved from GPU tensors — confirm.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dims = (15, 16)  # (rows, cols) of one level chunk
latent_dim = 8   # size of the VAE latent vector z
smb_folder = 'smb_chunks_all/'
ki_folder = 'ki_chunks_all/'
mm_folder = 'mm_chunks_all/'
# Uncomment for reproducible runs.
#manual_seed = random.randint(1, 10000)
#random.seed(manual_seed)
#torch.manual_seed(0)
#np.random.seed(0)

# Default game name; used when building the module-level model_name_conv.
GAME = 'mm'

# Game -> directory of level-chunk text files ('blend' has no folder).
folders = {'smb': smb_folder, 'ki': ki_folder, 'mm': mm_folder, 'blend': None}

def _open_tile_images(tile_paths):
    """Open one PIL image per (tile character, image path) pair, in order.

    Returns a dict mapping each tile character to its opened image.
    """
    return {char: Image.open(path) for char, path in tile_paths}


# Sprite tables: level character -> PIL image pasted for it when rendering.
# Several characters deliberately share one sprite (e.g. '?' and 'Q').
# NOTE(review): the SMB vocabulary printed at load time contains 'B' and 'b',
# which have no entry here, so rendering such SMB output would raise KeyError.
smb_ki_images = _open_tile_images((
    # TODO: Get T, D, M tiles from Icarus
    ("E", 'tiles/E.png'),
    ("H", 'tiles/H.png'),
    ("G", 'tiles/G.png'),
    ("M", 'tiles/M.png'),
    ("o", 'tiles/o.png'),
    ("S", 'tiles/S.png'),
    ("T", 'tiles/T.png'),
    ("?", 'tiles/Q.png'),
    ("Q", 'tiles/Q.png'),
    ("X", 'tiles/X1.png'),
    ("#", 'tiles/X.png'),
    ("-", 'tiles/-.png'),
    ("0", 'tiles/0.png'),
    ("D", 'tiles/D.png'),
    ("<", 'tiles/PTL.png'),
    (">", 'tiles/PTR.png'),
    ("[", 'tiles/[.png'),
    ("]", 'tiles/].png'),
    ("*", 'tiles/-.png'),
    ("P", 'tiles/P.png'),
    ("Y", 'tiles/P.png'),
    ("Z", 'tiles/P.png'),
))

# Sprites for the shared sketch alphabet.
sketch_images = _open_tile_images((
    ("X", 'tiles/X.png'),
    ("E", 'tiles/E.png'),
    ("|", 'tiles/D.png'),
    ("P", 'tiles/P.png'),
    ("*", 'tiles/o.png'),
    ("-", 'tiles/-.png'),
))

mm_images = _open_tile_images((
    ("#", 'tiles/X.png'),
    ("*", 'tiles/o.png'),
    ("+", 'tiles/o.png'),
    ("-", 'tiles/-.png'),
    ("B", 'tiles/X1.png'),
    ("C", 'tiles/H.png'),
    ("D", 'tiles/D.png'),
    ("H", 'tiles/H.png'),
    ("L", 'tiles/o1.png'),
    ("M", 'tiles/M.png'),
    ("P", 'tiles/P.png'),
    ("U", 'tiles/o1.png'),
    ("W", 'tiles/o.png'),
    ("l", 'tiles/o1.png'),
    ("t", 'tiles/E.png'),
    ("w", 'tiles/o.png'),
    ("|", 'tiles/D.png'),
))

ng_images = _open_tile_images((
    ("X", 'tiles/X.png'),
    ("*", 'tiles/o.png'),
    ("+", 'tiles/o.png'),
    (")", 'tiles/o.png'),
    ("1", 'tiles/o.png'),
    ("-", 'tiles/-.png'),
    ("B", 'tiles/E.png'),
    ("D", 'tiles/E.png'),
    ("R", 'tiles/H.png'),
    ("L", 'tiles/D.png'),
    ("J", 'tiles/o.png'),
    ("P", 'tiles/P.png'),
    ("F", 'tiles/o1.png'),
    ("W", 'tiles/o.png'),
    ("T", 'tiles/o.png'),
    ("A", 'tiles/o.png'),
    ("%", 'tiles/o.png'),
    ("K", 'tiles/H.png'),
    ("E", 'tiles/E.png'),
    ("C", 'tiles/E.png'),
))

# Game name -> sprite table used by get_image_from_level.
images_all = {'smb': smb_ki_images, 'ki': smb_ki_images, 'mm': mm_images, 'sketch': sketch_images}

# Tile vocabularies (char -> index) recorded from earlier runs, presumably NG,
# MM and a combined SMB+KI set; they may be stale — the live ones are printed
# by the model-loading cell.
#{'%': 0, ')': 1, '*': 2, '+': 3, '-': 4, '1': 5, 'A': 6, 'B': 7, 'D': 8, 'E': 9, 'F': 10, 'J': 11, 'K': 12, 'L': 13, 'P': 14, 'R': 15, 'T': 16, 'W': 17, 'X': 18}
#{'#': 0, '*': 1, '+': 2, '-': 3, 'B': 4, 'C': 5, 'D': 6, 'H': 7, 'L': 8, 'M': 9, 'P': 10, 'U': 11, 'W': 12, 'l': 13, 't': 14, 'w': 15, '|': 16}
#{'#': 0, '-': 1, '<': 2, '>': 3, '?': 4, 'D': 5, 'E': 6, 'H': 7, 'M': 8, 'P': 9, 'Q': 10, 'S': 11, 'T': 12, 'X': 13, '[': 14, ']': 15, 'o': 16}
def natural_sort(l):
    """Return the items of *l* sorted in natural (human) order.

    Each string is split into runs of ASCII digits and everything else; digit
    runs compare as integers and the other runs compare lower-cased, so
    'chunk_2' sorts before 'chunk_10'.
    """
    def sort_key(name):
        pieces = re.split('([0-9]+)', name)
        return [int(piece) if piece.isdigit() else piece.lower() for piece in pieces]

    return sorted(l, key=sort_key)
    
def parse_folder(folder, dir=False):
    """Read every level file in *folder*, in natural-sort order.

    Hidden files (names starting with '.') are skipped.

    Args:
        folder: directory containing one level chunk per text file.
        dir: when True, also return the per-file tag list described below.
            (The name shadows the builtin but is kept for callers that pass
            it by keyword.)

    Returns:
        (levels, text) or, when *dir* is True, (levels, text, dirs):
        - levels: one entry per file; each is a list of rows, and each row is
          a list of characters with trailing whitespace (incl. the newline)
          stripped.
        - text: every raw line of every file concatenated, newlines included.
        - dirs: ``name[-5]`` for each file, i.e. the character just before a
          four-character extension such as '.txt' (presumably a direction tag
          encoded in the file name — confirm against the data).
    """
    levels, dirs = [], []
    text_parts = []  # joined once at the end instead of quadratic +=
    # os.listdir never returns '.' or '..', so only dot-files need skipping.
    for name in natural_sort(os.listdir(folder)):
        if name.startswith('.'):
            continue
        with open(os.path.join(folder, name), 'r') as infile:
            lines = infile.readlines()
        text_parts.extend(lines)
        levels.append([list(line.rstrip()) for line in lines])
        dirs.append(name[-5])
    text = ''.join(text_parts)
    if dir:
        return levels, text, dirs
    return levels, text

# Legacy checkpoint name for the convolutional model. It is built from the
# global GAME, so it names only one game's checkpoint; the loop below builds a
# per-game name instead. Kept for any cell that still references it.
model_name_conv = 'filter_vae_big_lean' + GAME + '_ld_' + str(latent_dim) + '_final.pth'

# Sketch alphabet shared by all games: translate_smb/ki/mm map every game's
# tiles onto these characters, and the encoders consume them one-hot encoded.
sketch_chars = ['X', 'E', '|', '*', '-']
int2sc = dict(enumerate(sketch_chars))
sc2int = {ch: ii for ii, ch in int2sc.items()}
num_sketch_tiles = len(sc2int)
print(num_sketch_tiles)

# Per-game lookup tables filled by the loop below:
#   models        game -> VAE moved to `device` and put in eval mode
#   num_tiles_all game -> size of that game's tile vocabulary
#   int2chars     game -> {index: tile char}
#   char2ints     game -> {tile char: index}
models = {}
num_tiles_all = {}
int2chars = {}
char2ints = {}
for game in ['smb', 'ki', 'mm']:
    print('\n', game)
    folder = folders[game]
    levels, text = parse_folder(folder)
    # The vocabulary is every distinct non-newline character in the game's
    # chunk files, indexed in sorted order.
    text = text.replace('\n', '')
    print(len(levels))
    chars = sorted(set(text))
    int2char = dict(enumerate(chars))
    char2int = {ch: ii for ii, ch in int2char.items()}
    int2chars[game] = int2char
    char2ints[game] = char2int
    print(char2int)
    num_tiles = len(char2int)
    num_tiles_all[game] = num_tiles
    print(num_tiles)
    model_name = 'filter_vae_fc_nored_' + game + '_ld_' + str(latent_dim) + '_final.pth'
    if MODEL_TYPE == LIN:
        # The fully connected model is told the flattened grid size (rows * cols).
        model = load_model(model_name, dims[0] * dims[1], num_sketch_tiles, num_tiles, latent_dim)
    else:
        # Each game has its own tile vocabulary (num_tiles), so each needs its
        # own checkpoint; model_name_conv above always names GAME's one.
        game_model_name_conv = 'filter_vae_big_lean' + game + '_ld_' + str(latent_dim) + '_final.pth'
        model = load_conv_big_model(game_model_name_conv, num_sketch_tiles, num_tiles, latent_dim)
    model = model.to(device)
    model.eval()
    print(model)
    models[game] = model

def translate_smb(level):
    """Map a Super Mario Bros. level chunk onto the shared sketch alphabet.

    Tiles in 'X<>[]S' become 'X', tiles in 'o?Q' become '*', and 'B'/'b'
    become 'E'; every other tile (e.g. '-' or 'E') is copied unchanged.

    Args:
        level: iterable of rows, each an iterable of single-character tiles.

    Returns:
        List of translated rows as strings.
    """
    rules = (('X<>[]S', 'X'), ('o?Q', '*'), ('Bb', 'E'))

    def sketch_tile(tile):
        for sources, target in rules:
            if tile in sources:
                return target
        return tile

    return [''.join(sketch_tile(tile) for tile in row) for row in level]

def translate_mm(level):
    """Map a Mega Man level chunk onto the shared sketch alphabet.

    '#BM' -> 'X', 'CHt' -> 'E', 'D|' -> '|', '*+LUWlw' -> '*', 'P' -> '-';
    any other tile is copied unchanged.

    Args:
        level: iterable of rows, each an iterable of single-character tiles.

    Returns:
        List of translated rows as strings.
    """
    groups = (
        ('X', '#BM'),
        ('E', 'CHt'),
        ('|', 'D|'),
        ('*', '*+LUWlw'),
        ('-', 'P'),
    )
    translated = []
    for row in level:
        pieces = []
        for tile in row:
            mapped = next((sketch for sketch, members in groups if tile in members), tile)
            pieces.append(mapped)
        translated.append(''.join(pieces))
    return translated


def translate_ki(level):
    """Map a Kid Icarus level chunk onto the shared sketch alphabet.

    '#', 'M' and 'T' become 'X', 'H' becomes 'E', 'D' becomes '|', and 'P'
    becomes '-'; any other tile is copied unchanged.

    Args:
        level: iterable of rows, each an iterable of single-character tiles.

    Returns:
        List of translated rows as strings.
    """
    def convert(tile):
        if tile in '#MT':
            return 'X'
        if tile in 'H':
            return 'E'
        if tile in 'D':
            return '|'
        if tile in 'P':
            return '-'
        return tile

    return [''.join(map(convert, row)) for row in level]

def translate_file(folder, f, game):
    """Read a level chunk for *game*, echo it, and translate it to sketch tiles.

    Args:
        folder: directory containing the chunk ('' for the working directory).
        f: file name of the chunk inside *folder*.
        game: 'smb', 'ki' or 'mm'; selects translate_smb / translate_ki /
            translate_mm.

    Returns:
        List of translated rows (strings).

    Raises:
        ValueError: *game* has no translator (checked before the file is read).
        OSError: the chunk file cannot be opened (e.g. FileNotFoundError).
    """
    # Pick the translator first so an unknown game fails with a clear error
    # instead of reaching the return with no translation computed.
    if game == 'smb':
        translate = translate_smb
    elif game == 'ki':
        translate = translate_ki
    elif game == 'mm':
        translate = translate_mm
    else:
        raise ValueError('no translator for game %r' % (game,))
    # `with` closes the handle; splitlines() already strips '\n' and '\r\n'.
    with open(os.path.join(folder, f), 'r') as infile:
        chunk = infile.read().splitlines()
    print('\n'.join(chunk))
    return translate(chunk)

def get_z_from_file(folder, f, game):
    """Encode a sketch chunk stored in a text file into a latent vector.

    Prints an 'Input:' header followed by every line of the chunk, then hands
    the lines to get_z_from_level, which does the one-hot encoding and the
    encoder forward pass with ``models[game]``.

    Args:
        folder: directory containing the file ('' for the working directory).
        f: file name inside *folder*; each character must be in sc2int.
        game: key into models.

    Returns:
        The latent vector produced by get_z_from_level.
    """
    print('\nInput:')
    # `with` closes the handle; splitlines() already strips line endings.
    with open(os.path.join(folder, f), 'r') as infile:
        chunk = infile.read().splitlines()
    for line in chunk:
        print(line)
    return get_z_from_level(chunk, game)

def get_z_from_level(level, game):
    """Encode a sketch-tile level into a latent vector with *game*'s encoder.

    Args:
        level: sequence of equal-length rows; each row is a string or list of
            characters that must all be keys of sc2int.
        game: key into models.

    Returns:
        z, the first of the three values returned by the encoder's encode().

    Raises:
        KeyError: unknown *game* or a character outside the sketch alphabet.
        ValueError: MODEL_TYPE is neither LIN nor CONV.
    """
    model = models[game]
    # (rows, cols) grid of sketch-tile indices.
    indices = np.asarray([[sc2int[ch] for ch in row] for row in level])
    # One-hot -> (rows, cols, tiles); move tiles first -> (tiles, rows, cols);
    # then add a batch axis -> (1, tiles, rows, cols).
    onehot = np.eye(num_sketch_tiles, dtype='uint8')[indices]
    onehot = np.rollaxis(onehot, 2, 0)[None, :, :]
    data = torch.DoubleTensor(onehot).to(device)
    if MODEL_TYPE == LIN:
        # The fully connected encoder takes one flat vector per sample.
        z, _, _ = model.encoder.encode(data.view(data.size(0), -1))
    elif MODEL_TYPE == CONV:
        z, _, _ = model.encode(data)
    else:
        raise ValueError('unsupported MODEL_TYPE: %r' % (MODEL_TYPE,))
    return z

def get_level_from_z(z, target):
    """Decode latent *z* with *target*'s decoder and return the level as text.

    Args:
        z: latent batch for a single sample (the result is squeezed on the
            batch axis, so a batch larger than 1 raises).
        target: game key into models, num_tiles_all and int2chars.

    Returns:
        List of strings, one per row, using *target*'s tile characters. The
        level is also printed under an 'Output:' header.
    """
    model = models[target]
    if MODEL_TYPE == LIN:
        logits = model.decoder.decode(z)
        # The fully connected decoder yields a flat vector per sample;
        # reshape to (batch, tiles, rows, cols).
        logits = logits.reshape(logits.size(0), num_tiles_all[target], dims[0], dims[1])
    else:
        logits = model.decode(z)

    # Most likely tile per cell for the single sample -> (rows, cols).
    tile_ids = np.argmax(logits.detach().cpu().numpy(), axis=1).squeeze(0)
    to_char = int2chars[target]
    level = [''.join(to_char[t] for t in row) for row in tile_ids]
    print('Output:')
    print('\n'.join(level), '\n')
    return level

def get_image_from_level(level, name, game, save=False):
    """Render *level* with *game*'s tile sprites and save it as a PNG.

    Rendering is opt-in: unless ``save=True`` is passed the function returns
    immediately, so callers that do not pass it write no files.

    Args:
        level: sequence of rows, each a string/sequence of tile characters
            that must all be keys of images_all[game].
        name: the file is written to 'filter_' + name + '.png'.
        game: key into images_all ('smb', 'ki', 'mm' or 'sketch').
        save: when True, actually render and write the image.

    Returns:
        None.

    Raises:
        KeyError: *game* or a tile has no sprite (only when save=True).
    """
    if not save:
        return None
    tiles = images_all[game]
    tile_px = 16  # sprites are pasted on a 16-pixel grid
    # Size the canvas from the level itself (256x240 for a 15x16 chunk).
    n_rows = len(level)
    n_cols = max((len(row) for row in level), default=0)
    img = Image.new('RGB', (n_cols * tile_px, n_rows * tile_px))
    for row, seq in enumerate(level):
        for col, tile in enumerate(seq):
            img.paste(tiles[tile], (col * tile_px, row * tile_px))
    img.save('filter_' + name + '.png')
    return None

def get_image_from_multiple_levels(levels, name, game=None):
    """Render several levels onto one canvas and save it as '<name>.png'.

    The canvas is len(levels) level-widths wide and len(levels) level-heights
    tall, and level i occupies the i-th cell along the main diagonal, so no
    two levels overlap.
    NOTE(review): the diagonal layout follows the canvas size this function
    allocates; a single side-by-side strip may be what was intended — confirm.

    Args:
        levels: sequence of levels; each is a sequence of rows of tile
            characters, at most dims[0] rows by dims[1] columns.
        name: output path without the '.png' extension.
        game: key into images_all selecting the sprite set; defaults to GAME.

    Returns:
        The rendered PIL image.

    Raises:
        KeyError: unknown game or a tile without a sprite.
    """
    tiles = images_all[GAME if game is None else game]
    tile_px = 16  # sprites are pasted on a 16-pixel grid
    level_h, level_w = dims[0] * tile_px, dims[1] * tile_px
    count = len(levels)
    img = Image.new('RGB', (level_w * count, level_h * count))
    for i, level in enumerate(levels):
        # Offset each level by a whole level size so levels do not overlap.
        x0, y0 = i * level_w, i * level_h
        for row, seq in enumerate(level):
            for col, tile in enumerate(seq):
                img.paste(tiles[tile], (x0 + col * tile_px, y0 + row * tile_px))
    img.save(name + '.png')
    return img

# For each hand-drawn sketch file filter_<i>.txt (i = 0..5), encode it with
# every game's model and decode it back in that same game, printing a header
# per game and a blank separator between games.
for i in range(6):
    print(i)
    sketch_file = 'filter_' + str(i) + '.txt'
    for step, (game, header) in enumerate((('smb', 'SMB: '), ('ki', 'KI: '), ('mm', 'MM:'))):
        if step > 0:
            print('\n')
        print(header)
        z = get_z_from_file('', sketch_file, game)
        level = get_level_from_z(z, game)
        get_image_from_level(level, game + '_' + str(i), game)

# Cross-game demos: translate a chunk of the source game to sketch tiles,
# encode it with the destination game's encoder, decode it in that game, and
# (optionally) render it. Columns: title, source folder, chunk file,
# source game, destination game, image name.
_translation_demos = (
    ('SMB-to-KI', smb_folder, 'smb_chunk_100.txt', 'smb', 'ki', 'filter_ki_from_smb_chunk_100'),
    ('KI-to-SMB', ki_folder, 'ki_chunk_100.txt', 'ki', 'smb', 'smb_from_ki_chunk_100'),
    ('SMB-to-MM', smb_folder, 'smb_chunk_10.txt', 'smb', 'mm', 'mm_from_smb_chunk_10'),
    ('MM-to-SMB', mm_folder, 'mm_chunk_100.txt', 'mm', 'smb', 'smb_from_mm_chunk_100'),
    ('KI-to-MM', ki_folder, 'ki_chunk_10.txt', 'ki', 'mm', 'mm_from_ki_chunk_10'),
    ('MM-to-KI', mm_folder, 'mm_chunk_100.txt', 'mm', 'ki', 'ki_from_mm_chunk_100'),
)

for title, src_folder, chunk_file, src_game, dst_game, image_name in _translation_demos:
    print(title)
    level = translate_file(src_folder, chunk_file, src_game)
    z = get_z_from_level(level, dst_game)
    level = get_level_from_z(z, dst_game)
    get_image_from_level(level, image_name, dst_game)
    print('\n')

5

 smb
176
{'-': 0, '<': 1, '>': 2, '?': 3, 'B': 4, 'E': 5, 'Q': 6, 'S': 7, 'X': 8, '[': 9, ']': 10, 'b': 11, 'o': 12}
13
VAE(
  (encoder): Encoder(
    (fc1): Linear(in_features=1200, out_features=1024, bias=True)
    (fc2): Linear(in_features=1024, out_features=512, bias=True)
    (fc3): Linear(in_features=512, out_features=256, bias=True)
    (fc41): Linear(in_features=256, out_features=8, bias=True)
    (fc42): Linear(in_features=256, out_features=8, bias=True)
  )
  (decoder): Decoder(
    (fc5): Linear(in_features=8, out_features=256, bias=True)
    (fc6): Linear(in_features=256, out_features=512, bias=True)
    (fc7): Linear(in_features=512, out_features=1024, bias=True)
    (fc8): Linear(in_features=1024, out_features=3120, bias=True)
  )
)

 ki
80
{'#': 0, '-': 1, 'D': 2, 'H': 3, 'M': 4, 'T': 5}
6
VAE(
  (encoder): Encoder(
    (fc1): Linear(in_features=1200, out_features=1024, bias=True)
    (fc2): Linear(in_features=1024, out_features=512, bias=True)
    (fc3): Linear(in_fe

FileNotFoundError: [Errno 2] No such file or directory: 'smb_chunks_all/smb_chunk_100.txt'