In [13]:
import argparse
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torchvision.utils as vutils
from torch.autograd import Variable
import os
import numpy as np
import matplotlib.pyplot as plt
import math
import json
import sys
from PIL import Image
from model_lin_cond import get_cond_model, load_cond_model
from torch.utils.data import DataLoader, Dataset, TensorDataset
import re
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import mean_squared_error
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import pickle
import copy
from sklearn import metrics
from sklearn import linear_model
import warnings
warnings.filterwarnings('ignore')

# Fall back to CPU so the script still runs on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
latent_dim = 32  # size of latent vector
GAME = 'blend'
dims = (16, 16)
smb_folder = 'smb_chunks_fixed/'
ki_folder = 'ki_chunks/'
mm_folder = 'mm_chunks_fixed/'
pats_folder = 'smb_pats/'

# Chunk folder per game; 'blend' reads the smb/ki/mm folders directly.
# NOTE(review): an 'ng' entry pointed at `ng_folder`, which is never defined in
# this file (NameError when run standalone), and 'ng' has no entry in `labels`,
# so it could never be selected; it has been dropped.
folders = {'smb': smb_folder, 'ki': ki_folder, 'mm': mm_folder,
           'smb_pats': pats_folder, 'blend': None}
# Number of binary conditioning labels for each game.
labels = {'smb': 5, 'ki': 4, 'mm': 5, 'blend': 3, 'smb_pats': 10}
label_size = labels[GAME]
num_labels = 2 ** label_size  # one entry per combination of the binary labels
folder = folders[GAME]
#manual_seed = random.randint(1, 10000)
#random.seed(manual_seed)
#torch.manual_seed(0)
#np.random.seed(0)

# Sprite tables mapping level characters to the 16x16 tile images pasted by
# get_image_from_segment.  Many characters share a file (e.g. 'tiles/o.png'),
# so files are read through a small cache instead of being reopened.
_tile_cache = {}


def _tile(path):
    """Return the tile image stored at `path`, reading each file only once.

    Image.open is lazy and keeps its file handle open until the pixels are
    needed; copying inside a `with` block loads the pixels and closes the
    file immediately.  The returned object is shared by every key that uses
    the same file, so callers must only read from it (paste does).
    """
    if path not in _tile_cache:
        with Image.open(path) as img:
            _tile_cache[path] = img.copy()
    return _tile_cache[path]


smb_ki_images = {
    "E": _tile('tiles/E.png'),
    "H": _tile('tiles/H.png'),
    "G": _tile('tiles/G.png'),
    "M": _tile('tiles/M.png'),
    "o": _tile('tiles/o.png'),
    "S": _tile('tiles/S.png'),
    "T": _tile('tiles/T.png'),
    "?": _tile('tiles/Q.png'),
    "Q": _tile('tiles/Q.png'),
    "X": _tile('tiles/X1.png'),
    "#": _tile('tiles/X.png'),
    "-": _tile('tiles/-.png'),
    "0": _tile('tiles/0.png'),
    "D": _tile('tiles/D.png'),
    "<": _tile('tiles/PTL.png'),
    ">": _tile('tiles/PTR.png'),
    "[": _tile('tiles/[.png'),
    "]": _tile('tiles/].png'),
    "P": _tile('tiles/P.png')
}

mm_images = {
    "#": _tile('tiles/X.png'),
    "*": _tile('tiles/o.png'),
    "+": _tile('tiles/o.png'),
    "-": _tile('tiles/-.png'),
    "B": _tile('tiles/X1.png'),
    "C": _tile('tiles/CMM.png'),
    "D": _tile('tiles/DMM.png'),
    "H": _tile('tiles/HMM.png'),
    "L": _tile('tiles/o1.png'),
    "M": _tile('tiles/MMM.png'),
    "P": _tile('tiles/P.png'),
    "U": _tile('tiles/o1.png'),
    "W": _tile('tiles/o.png'),
    "l": _tile('tiles/o1.png'),
    "t": _tile('tiles/TMM.png'),
    "w": _tile('tiles/o.png'),
    "|": _tile('tiles/LMM.png')
}

# Combined table for the 'blend' model.  'H' uses the Mega Man sprite
# (HMM.png, further down), not tiles/H.png.
smb_ki_mm_images = {
    "E": _tile('tiles/E.png'),
    "M": _tile('tiles/M.png'),
    "o": _tile('tiles/o.png'),
    "S": _tile('tiles/S.png'),
    "T": _tile('tiles/T.png'),
    "?": _tile('tiles/Q.png'),
    "Q": _tile('tiles/Q.png'),
    "X": _tile('tiles/X1.png'),
    "#": _tile('tiles/X.png'),
    "-": _tile('tiles/-.png'),
    "D": _tile('tiles/D.png'),
    "<": _tile('tiles/PTL.png'),
    ">": _tile('tiles/PTR.png'),
    "[": _tile('tiles/[.png'),
    "]": _tile('tiles/].png'),
    "P": _tile('tiles/P.png'),
    "*": _tile('tiles/o.png'),
    "+": _tile('tiles/o.png'),
    "B": _tile('tiles/X1.png'),
    "C": _tile('tiles/CMM.png'),
    "H": _tile('tiles/HMM.png'),
    "L": _tile('tiles/o1.png'),
    "U": _tile('tiles/o1.png'),
    "W": _tile('tiles/o.png'),
    "l": _tile('tiles/o1.png'),
    "t": _tile('tiles/TMM.png'),
    "w": _tile('tiles/o.png'),
    "|": _tile('tiles/LMM.png')
}
print(len(smb_ki_mm_images))
# Sprite table used for rendering, selected by GAME.
all_images = {'smb':smb_ki_images, 'ki':smb_ki_images, 'mm':mm_images,'smb_pats':smb_ki_images,'blend':smb_ki_mm_images}
images = all_images[GAME]
# Zero count for every possible label string: the binary expansions of
# 0 .. num_labels-1, left-padded with '0' to label_size digits
# (e.g. '000' .. '111' when label_size == 3).
label_counts = {bin(n)[2:].zfill(label_size): 0 for n in range(num_labels)}

def natural_sort(l):
    """Return the items of `l` sorted in natural (human) order.

    Digit runs compare numerically and all other text compares
    case-insensitively, so 'chunk_2' sorts before 'chunk_10'.
    """
    def _natural_key(item):
        pieces = re.split('([0-9]+)', item)
        return [int(piece) if piece.isdigit() else piece.lower() for piece in pieces]

    return sorted(l, key=_natural_key)
    
def parse_folder(folder,pats=False,dir=False):
    """Read every level-chunk text file in `folder`.

    Files are visited in natural-sort order and hidden files (leading '.')
    are skipped.  Each file becomes a list of rows, each row being the list
    of characters of the line with trailing whitespace removed.

    Args:
        folder: directory containing one chunk per text file.
        pats: also parse tags from each file name: the '_'-separated fields
            after the third '_', with the last 4 characters (the extension)
            dropped.
        dir: also collect the character at index -5 of every file name.

    Returns:
        (levels, text, dirs) when `dir` is true (takes precedence over
        `pats`); (levels, text, patterns) when `pats` is true; otherwise
        (levels, text).  `text` is every line read, concatenated with its
        newline characters intact.
    """
    levels, dirs, patterns = [], [], []
    pieces = []
    visible = [name for name in os.listdir(folder) if not name.startswith('.')]
    for name in natural_sort(visible):
        with open(os.path.join(folder, name), 'r') as infile:
            if pats:
                tail = name
                for _ in range(3):
                    tail = tail[tail.find('_') + 1:]
                patterns.append(tail[:-4].split('_'))
            lines = list(infile)
            pieces.extend(lines)
            levels.append([list(line.rstrip()) for line in lines])
            dirs.append(name[-5])
    text = ''.join(pieces)
    if dir:
        return levels, text, dirs
    if pats:
        return levels, text, patterns
    return levels, text

# Load the level chunks for the configured game and build the tile vocabulary.
if GAME == 'smb_pats':
    levels, text, pats = parse_folder(folder,True)
    # NOTE(review): this text keeps interior newlines, so '\n' ends up in the
    # vocabulary below (strip('\n') only trims the ends); confirm this matches
    # how the smb_pats model was trained.
else:
    smb_levels, smb_text = parse_folder(smb_folder)
    ki_levels, ki_text = parse_folder(ki_folder)
    mm_levels, mm_text = parse_folder(mm_folder)
    blend_levels = smb_levels + ki_levels + mm_levels
    # Drop newlines so only tile characters reach the vocabulary.
    smb_text = smb_text.replace('\n','')
    ki_text = ki_text.replace('\n','')
    mm_text = mm_text.replace('\n','')
    blend_text = smb_text + ki_text + mm_text
    # The per-game names only exist in this branch, so the lookups live here
    # too (building them unconditionally raised NameError for smb_pats).
    all_levels = {'smb':smb_levels, 'ki':ki_levels, 'mm':mm_levels,'blend':blend_levels}
    all_text = {'smb':smb_text, 'ki':ki_text, 'mm':mm_text,'blend':blend_text}
    levels = all_levels[GAME]
    text = all_text[GAME]

# Tile vocabulary: the sorted distinct characters, indexed in both directions.
chars = sorted(set(text.strip('\n')))
int2char = dict(enumerate(chars))
char2int = {ch: ii for ii, ch in int2char.items()}
print(int2char)
print(char2int)
num_tiles = len(char2int)
print('Num tiles: ', num_tiles)

# NOTE: pickle.load can execute arbitrary code - only load trusted files.
with open('blend_classifier.pickle','rb') as f:
    classifier = pickle.load(f)
print(classifier)

model_name = 'cvae_' + GAME + '_ld_' + str(latent_dim) + '_final.pth'
# dims[0] * dims[1] == 256 (the literal used previously); presumably the
# number of cells per chunk - confirm against load_cond_model.
model = load_cond_model(model_name,dims[0]*dims[1],num_tiles,latent_dim,label_size,device)
model.eval()
model.to(device)
print(model)

def get_label_smb(level):
    """Return the five SMB feature flags of `level` as a list of bools.

    Order: [enemy 'E', pipe (all of '[', ']', '<', '>'), coin 'o',
    breakable 'S', question block 'Q' or '?'].  `level` is a sequence of
    rows, each a string or a list of tile characters.
    """
    tiles = ''.join(''.join(row) for row in level)
    has_enemy = 'E' in tiles
    has_pipe = all(piece in tiles for piece in '[]<>')
    has_coin = 'o' in tiles
    has_breakable = 'S' in tiles
    has_qmark = 'Q' in tiles or '?' in tiles
    return [has_enemy, has_pipe, has_coin, has_breakable, has_qmark]

def get_label_ki(level):
    """Return the four KI feature flags of `level` as a list of bools.

    Order: [enemy 'H', door 'D', moving 'M', 'T'] - one flag per tile
    character, True when that character appears anywhere in the level.
    """
    tiles = ''.join(''.join(row) for row in level)
    return [tile in tiles for tile in ('H', 'D', 'M', 'T')]

def get_label_mm(level):
    """Return the five Mega Man feature flags of `level` as a list of bools.

    Flag order (the previous inline comment listed only four groups for the
    five slots):
        0: any of 'H', 'T', 'C'
        1: 'D'
        2: '|'
        3: 'M'
        4: any of '*', 'U', 'W', 'w', '+', 'l', 'L'

    NOTE(review): leniency()/difficulty() count lowercase 't' as a hazard tile
    and the MM sprite table has a 't' entry but no 'T', so slot 0 may have
    been meant to test 't'.  Behaviour is unchanged here - confirm against
    the data.
    """
    tiles = ''.join(''.join(row) for row in level)
    return [
        any(ch in tiles for ch in 'HTC'),
        'D' in tiles,
        '|' in tiles,
        'M' in tiles,
        any(ch in tiles for ch in '*UWw+lL'),
    ]

def pipe_check(level):
    """Return False if `level` contains only one half of a pipe.

    A level fails when exactly one of '[' / ']' is present, or exactly one of
    '<' / '>' is present; otherwise it passes (including no pipe at all).
    """
    tiles = ''.join(''.join(row) for row in level)
    for left, right in (('[', ']'), ('<', '>')):
        if (left in tiles) != (right in tiles):
            return False
    return True

def get_label_pat(pat):
    """Return ten flags marking which design-pattern tags occur in `pat`.

    Tag order: EH, G, PV, GV, NV, EV, MP, RR, SU, SD.  Membership uses `in`,
    so `pat` may be a list of tags (as built by parse_folder) or a string.
    """
    tags = ('EH', 'G', 'PV', 'GV', 'NV', 'EV', 'MP', 'RR', 'SU', 'SD')
    return [tag in pat for tag in tags]

def get_label_blend(level):
    """Run the pickled game classifier on `level`.

    The level is mapped through `char2int`, one-hot encoded over `num_tiles`
    into shape (1, num_tiles, H, W), flattened to (1, num_tiles*H*W) as
    float64 and passed to `classifier.predict` / `predict_proba`.

    Returns:
        (pred[0], probs[0]) - the first element of each classifier output.
    """
    indices = np.asarray([[char2int[ch] for ch in row] for row in level])
    onehot = np.eye(num_tiles, dtype='uint8')[indices]      # (H, W, num_tiles)
    onehot = np.rollaxis(onehot, 2, 0)[None, :, :]          # (1, num_tiles, H, W)
    features = onehot.reshape(onehot.shape[0], -1).astype(np.float64)
    pred = classifier.predict(features)
    probs = classifier.predict_proba(features)
    return pred[0], probs[0]

def get_label_string_from_array(label):
    """Render a sequence of flags as a digit string, e.g. [True, False] -> '10'.

    Values are cast to uint8 first, so booleans become 0/1.
    """
    digits = np.array(label).astype('uint8')
    return ''.join(map(str, digits))

# Count how often each label string occurs among the training chunks.
# 'blend' has no rule-based labeller (its labels come from the classifier),
# so its counts stay at zero.
if GAME == 'smb_pats':
    for level, pat in zip(levels, pats):
        label = get_label_string_from_array(get_label_pat(pat))
        label_counts[label] += 1
elif GAME != 'blend':
    labeller = {'smb': get_label_smb, 'ki': get_label_ki, 'mm': get_label_mm}[GAME]
    for level in levels:
        # SMB chunks containing only half of a pipe are left out.
        if GAME == 'smb' and not pipe_check(level):
            continue
        label = get_label_string_from_array(labeller(level))
        label_counts[label] += 1

def get_image_from_segment(segment,name):
    """Render `segment` with the sprites in `images` and save '<name>.png'.

    Every tile character is pasted at (column*16, row*16) on a fixed
    256x256 RGB canvas.
    """
    cell = 16
    canvas = Image.new('RGB', (16 * cell, 16 * cell))
    for y, row in enumerate(segment):
        for x, ch in enumerate(row):
            canvas.paste(images[ch], (x * cell, y * cell))
    canvas.save(name + '.png')

def get_z_from_segment_c(segment,c):
    """Encode `segment` under conditioning label tensor `c` and return z.

    Tiles are mapped through `char2int`, one-hot encoded into shape
    (1, num_tiles, H, W), moved to `device`, flattened to (1, num_tiles*H*W)
    and passed to `model.encoder.encode`; the first of its three outputs
    is returned.
    """
    indices = np.asarray([[char2int[ch] for ch in row] for row in segment])
    onehot = np.rollaxis(np.eye(num_tiles, dtype='uint8')[indices], 2, 0)[None, :, :]
    flat = torch.DoubleTensor(onehot).to(device)
    flat = flat.view(flat.size(0), -1)
    z, _, _ = model.encoder.encode(flat, c)
    return z

def get_z_from_file(folder,f):
    """Encode the chunk file `folder + f` with `model.encode` and return z.

    Each line of the file is echoed to stdout before encoding.  The chunk is
    one-hot encoded over `num_tiles` and flattened to (1, num_tiles*H*W).

    NOTE(review): the rest of the script uses the conditional API
    (model.encoder.encode(x, c)); confirm the loaded model exposes encode(x).
    """
    # `with` closes the file (it was previously left open); splitlines()
    # already strips '\r\n' endings, so no extra replace is needed.
    with open(folder + f, 'r') as infile:
        chunk = infile.read().splitlines()
    out = []
    for line in chunk:
        print(line)
        out.append([char2int[x] for x in line])
    out = np.asarray(out)
    out_onehot = np.eye(num_tiles, dtype='uint8')[out]       # (H, W, num_tiles)
    out_onehot = np.rollaxis(out_onehot, 2, 0)[None, :, :]   # (1, num_tiles, H, W)

    data = torch.DoubleTensor(out_onehot).to(device)
    data = data.view(data.size(0),-1)
    z, _, _ = model.encode(data)

    return z

def get_z_from_segment(segment):
    """Encode `segment` with the unconditional `model.encode` and return z.

    The segment is one-hot encoded over `num_tiles`, moved to `device` and
    flattened to (1, num_tiles*H*W) before encoding, matching
    get_z_from_file.

    NOTE(review): the rest of the script uses the conditional API
    (model.encoder.encode(x, c)); confirm the loaded model exposes encode(x).
    """
    indices = np.asarray([[char2int[ch] for ch in row] for row in segment])
    # `num_tiles` is the vocabulary size; `num_sketch_tiles` is not defined
    # anywhere in this file and raised NameError.
    onehot = np.eye(num_tiles, dtype='uint8')[indices]   # (H, W, num_tiles)
    onehot = np.rollaxis(onehot, 2, 0)[None, :, :]       # (1, num_tiles, H, W)
    out = torch.DoubleTensor(onehot).to(device)
    out_lin = out.view(out.size(0), -1)
    # Encode the flattened tensor (the 4-D tensor was passed before, leaving
    # out_lin unused).
    z, _, _ = model.encode(out_lin)
    return z
    
def get_segment_from_file(folder,f):
    """Return the chunk stored at `folder + f` as a list of row strings.

    `folder` is concatenated directly with `f`, so it must end with a path
    separator (the *_folder constants in this file do).  splitlines() already
    removes every kind of line ending, so rows carry no newline characters.
    The file is closed before returning.
    """
    with open(folder + f, 'r') as infile:
        return infile.read().splitlines()

def get_z_from_file_c(folder,f,c):
    """Encode the chunk stored at `folder + f` under conditioning label `c`.

    The file's lines are mapped through `char2int`, one-hot encoded into
    (1, num_tiles, H, W), flattened to (1, num_tiles*H*W) on `device` and
    passed to `model.encoder.encode`; the first of its three outputs (z) is
    returned.
    """
    # `with` closes the file (it was previously left open); splitlines()
    # already strips '\r\n' endings.
    with open(folder + f, 'r') as infile:
        chunk = infile.read().splitlines()
    out = np.asarray([[char2int[x] for x in line] for line in chunk])
    out_onehot = np.eye(num_tiles, dtype='uint8')[out]       # (H, W, num_tiles)
    out_onehot = np.rollaxis(out_onehot, 2, 0)[None, :, :]   # (1, num_tiles, H, W)

    data = torch.DoubleTensor(out_onehot).to(device)
    data = data.view(data.size(0),-1)
    z, _, _ = model.encoder.encode(data,c)

    return z

def get_segment_from_z(z):
    """Decode latent `z` with `model.decode` into a list of row strings.

    The decoder output is reshaped to (batch, num_tiles, dims[0], dims[1]) and
    the highest-scoring tile per cell (argmax over the tile axis) is mapped
    back through `int2char`.  squeeze(0) only drops the batch axis when it
    has size 1, so `z` is expected to hold a single latent.
    """
    # Inference only: no autograd graph is needed for the decoded tiles.
    with torch.no_grad():
        level = model.decode(z)
    level = level.reshape(level.size(0),num_tiles,dims[0],dims[1])
    im = np.argmax(level.data.cpu().numpy(), axis=1).squeeze(0)
    return [''.join([int2char[t] for t in row]) for row in im]

def get_segment_from_zc(z,c):
    """Decode latent `z` under conditioning label `c` into a list of row strings.

    The decoder output is reshaped to (batch, num_tiles, dims[0], dims[1]) and
    the highest-scoring tile per cell (argmax over the tile axis) is mapped
    back through `int2char`.  squeeze(0) only drops the batch axis when it
    has size 1, so `z` and `c` are expected to hold a single sample.
    """
    # Inference only: no autograd graph is needed for the decoded tiles.
    with torch.no_grad():
        level = model.decoder.decode(z,c)
    level = level.reshape(level.size(0),num_tiles,dims[0],dims[1])
    im = np.argmax(level.data.cpu().numpy(), axis=1).squeeze(0)
    return [''.join([int2char[t] for t in row]) for row in im]

def density(segment):
    """Number of cells that are neither '-' nor 'P', divided by 256.

    256 is the cell count of a 16x16 chunk; the divisor is fixed regardless
    of the segment's actual size.
    """
    filled = sum(len(row) - row.count('-') - row.count('P') for row in segment)
    return filled / 256

def leniency(level):
    """Number of cells that are not 'E', 'H', 't' or 'C', divided by 256."""
    hazards = ('E', 'H', 't', 'C')
    safe = 0
    for row in level:
        safe += len(row) - sum(row.count(tile) for tile in hazards)
    return safe / 256

def difficulty(segment):
    """Hazard score of `segment`, divided by 256.

    Counts the tiles 'E', 'H', 't' and 'C' in every row, and adds a 0.5
    penalty for each gap cell in the last row.  A gap cell is '-' or ' '.

    NOTE(review): ' ' does not occur in the tile vocabulary loaded from the
    chunk files, so counting only ' ' meant the penalty never fired.  '-' is
    the tile density() and nonlinearity() treat as empty, and is now counted
    as well.  The last row is taken as len(segment) - 1 instead of a
    hard-coded 15 (identical for 16-row chunks).  Confirm this matches the
    intended metric.
    """
    last = len(segment) - 1
    total = 0
    for i, row in enumerate(segment):
        total += row.count('E') + row.count('H') + row.count('t') + row.count('C')
        if i == last:
            total += 0.5 * (row.count('-') + row.count(' '))
    return total / 256

def interestingness(level):
    """Number of collectible/special tiles in `level`, divided by 256.

    Counted tiles: 'o', 'Q', '?', 'D', '*', '+', 'L', 'U', 'W', 'l', 'w'.
    'W' was previously counted twice while 'w' was never counted.  Both map to
    the same sprite and are grouped together by get_label_mm, so each is
    now counted once.
    """
    special = 'oQ?D*+LUWlw'
    total = sum(row.count(tile) for row in level for tile in special)
    return total / 256

def nonlinearity(segment):
    """MSE of a straight-line fit to the segment's first-solid-tile profile.

    Unless GAME == 'ki', the segment is transposed so each column is scanned
    top to bottom; for 'ki' each row is scanned as-is.  For every scanned
    line the index j of the first tile that is neither '-' nor 'P' is turned
    into a height len(line) - 1 - j.  A line with no such tile gets 0.
    A linear regression of height against line position is fitted and its
    mean squared error returned.

    The positions and heights were previously hard-coded for 16x16 segments
    (np.arange(16), 15 - j); they now follow the segment's real size, which
    gives identical results for 16x16 input.
    """
    if GAME != 'ki':
        level = [[segment[j][i] for j in range(len(segment))] for i in range(len(segment[0]))]
    else:
        level = segment
    y = []
    for line in level:
        top = len(line) - 1
        height = 0
        for j, tile in enumerate(line):
            if tile != '-' and tile != 'P':
                height = top - j
                break
        y.append(height)
    x = np.arange(len(y)).reshape(-1, 1)
    y = np.asarray(y)
    reg = linear_model.LinearRegression()
    reg.fit(x, y)
    y_pred = reg.predict(x)
    return mean_squared_error(y, y_pred)

def h_symmetry(level):
    """Count positions where a row's left half matches its right half.

    Each row is split at index 8, and the tile at offset k of the left half
    is compared with the tile at offset k of the right half.  Equal pairs of
    tiles other than '-' and 'P' are counted over all rows.

    NOTE(review): this compares row[k] with row[8 + k] (a shifted copy), not
    the mirrored row[15 - k]; confirm that repetition rather than mirror
    symmetry is the intended measure.
    """
    return sum(
        1
        for row in level
        for left, right in zip(row[:8], row[8:])
        if left == right and left not in ('-', 'P')
    )

def v_symmetry(level):
    """h_symmetry of the transposed level (columns treated as rows)."""
    columns = [[row[i] for row in level] for i in range(len(level[0]))]
    return h_symmetry(columns)

def symmetry(level):
    """Sum of the horizontal and vertical symmetry counts, divided by 256."""
    combined = h_symmetry(level) + v_symmetry(level)
    return combined / 256

def h_similarity(level, reference_rows=None):
    """Count the rows of `level` whose joined string occurs in a reference set.

    Args:
        level: sequence of rows (strings or lists of tile characters).
        reference_rows: collection of row strings to compare against.  When
            omitted, the module-level `level_rows` is used.
            NOTE(review): `level_rows` is not assigned anywhere in this file
            (presumably built in an earlier notebook cell), so pass it
            explicitly when calling from elsewhere.
    """
    if reference_rows is None:
        reference_rows = level_rows
    return sum(1 for row in level if ''.join(row) in reference_rows)

def v_similarity(level, reference_cols=None):
    """Count the columns of `level` whose joined string occurs in a reference set.

    Args:
        level: sequence of equal-length rows (strings or lists of tiles).
        reference_cols: collection of column strings to compare against.
            When omitted, the module-level `level_cols` is used.
            NOTE(review): `level_cols` is not assigned anywhere in this file
            (presumably built in an earlier notebook cell), so pass it
            explicitly when calling from elsewhere.
    """
    if reference_cols is None:
        reference_cols = level_cols
    columns = [''.join(level[j][i] for j in range(len(level))) for i in range(len(level[0]))]
    return sum(1 for col in columns if col in reference_cols)

def similarity(level):
    """Matched rows plus matched columns, divided by 32 (16 rows + 16 columns)."""
    matched = h_similarity(level) + v_similarity(level)
    return matched / 32

def h_dissimilarity(level, reference_rows=None):
    """Count the rows of `level` whose joined string is absent from a reference set.

    Args:
        level: sequence of rows (strings or lists of tile characters).
        reference_rows: collection of row strings to compare against.  When
            omitted, the module-level `level_rows` is used.
            NOTE(review): `level_rows` is not assigned anywhere in this file
            (presumably built in an earlier notebook cell), so pass it
            explicitly when calling from elsewhere.
    """
    if reference_rows is None:
        reference_rows = level_rows
    return sum(1 for row in level if ''.join(row) not in reference_rows)

def v_dissimilarity(level, reference_cols=None):
    """Count the columns of `level` whose joined string is absent from a reference set.

    Args:
        level: sequence of equal-length rows (strings or lists of tiles).
        reference_cols: collection of column strings to compare against.
            When omitted, the module-level `level_cols` is used.
            NOTE(review): `level_cols` is not assigned anywhere in this file
            (presumably built in an earlier notebook cell), so pass it
            explicitly when calling from elsewhere.
    """
    if reference_cols is None:
        reference_cols = level_cols
    columns = [''.join(level[j][i] for j in range(len(level))) for i in range(len(level[0]))]
    return sum(1 for col in columns if col not in reference_cols)

def dissimilarity(level):
    """Unmatched rows plus unmatched columns, divided by 32 (16 rows + 16 columns)."""
    unmatched = h_dissimilarity(level) + v_dissimilarity(level)
    return unmatched / 32

def v_traversability(level):
    """Fraction of adjacent row pairs whose 'P' tiles are less than 2 columns apart.

    Only pairs of consecutive rows that both contain 'P' can count.  The
    first 'P' of each row is compared, and the pair counts when the column
    difference is 0 or 1.  The count is divided by the number of adjacent
    pairs, len(level) - 1 (15 for a 16-row chunk, the value previously
    hard-coded).  Levels with fewer than two rows return 0.0.
    """
    pairs = len(level) - 1
    if pairs <= 0:
        return 0.0
    total = 0
    for upper, lower in zip(level[:-1], level[1:]):
        upper_str = ''.join(upper)
        lower_str = ''.join(lower)
        if 'P' not in upper_str or 'P' not in lower_str:
            continue
        if abs(upper_str.find('P') - lower_str.find('P')) < 2:
            total += 1
    return total / pairs

def h_traversability(level):
    """v_traversability of the transposed level (columns treated as rows)."""
    transposed = [[row[col] for row in level] for col in range(len(level[0]))]
    return v_traversability(transposed)

def traversability(level):
    """The larger of the horizontal and vertical traversability scores."""
    horizontal = h_traversability(level)
    vertical = v_traversability(level)
    return max(horizontal, vertical)

# Shared random latent plus one-hot conditioning labels, one per source game
# (position order SMB, KI, MM - the same order as the 'Label,SMB,KI,MM' CSVs).
z = torch.DoubleTensor(1,latent_dim).normal_(0,1).to(device)
smb_label, ki_label, mm_label = [1,0,0], [0,1,0], [0,0,1]
smb_label_tensor, ki_label_tensor, mm_label_tensor = [
    torch.DoubleTensor(onehot).reshape(1,-1).to(device)
    for onehot in (smb_label, ki_label, mm_label)
]

# The three string literals below are disabled experiments: they are bare
# expression statements, so none of the code inside them runs.

# (disabled) Decode `z` under each game label and print the classifier's guess.
# NOTE(review): the KI part classifies ki_levels[0], not the decoded `level`.
"""
# TEST
level = get_segment_from_zc(z,smb_label_tensor)
print(level)
out_label, probs = get_label_blend(level)
game = np.argmax(out_label)
print(game)

label_tensor = torch.DoubleTensor(ki_label).reshape(1,-1).to(device)
level = get_segment_from_zc(z,ki_label_tensor)
print(level)
out_label, probs = get_label_blend(ki_levels[0])
game = np.argmax(out_label)
print(game)

label_tensor = torch.DoubleTensor(mm_label).reshape(1,-1).to(device)
level = get_segment_from_zc(z,mm_label_tensor)
print(level)
out_label, probs = get_label_blend(level)
game = np.argmax(out_label)
print(game)
"""

# (disabled) Write per-chunk metrics of the training data to
# cvae_metrics_<game>.csv.
"""
# COMPUTE TRAINING METRICS
print('Computing training metrics...')
print('SMB')
game_file = open('cvae_metrics_smb.csv','w')
game_file.write('Density,Nonlinearity,Leniency,Interestingness,Traversability\n')
for level in smb_levels:
    d, n, f, i, t = density(level), nonlinearity(level), leniency(level), interestingness(level), traversability(level)
    game_file.write(str(d) + ',' + str(n) + ',' + str(f) + ',' + str(i) + ',' + str(t) + '\n')
game_file.close()

print('KI')
game_file = open('cvae_metrics_ki.csv','w')
game_file.write('Density,Nonlinearity,Leniency,Interestingness,Traversability\n')
for level in ki_levels:
    d, n, f, i, t = density(level), nonlinearity(level), leniency(level), interestingness(level), traversability(level)
    game_file.write(str(d) + ',' + str(n) + ',' + str(f) + ',' + str(i) + ',' + str(t) + '\n')
game_file.close()

print('MM')
game_file = open('cvae_metrics_mm.csv','w')
game_file.write('Density,Nonlinearity,Leniency,Interestingness,Traversability\n')
for level in mm_levels:
    d, n, f, i, t = density(level), nonlinearity(level), leniency(level), interestingness(level), traversability(level)
    game_file.write(str(d) + ',' + str(n) + ',' + str(f) + ',' + str(i) + ',' + str(t) + '\n')
game_file.close()
"""

# (disabled) Sample 1000 latents per conditioning label, write their metrics
# and tally the classifier's predicted game per label.
# NOTE(review): out_file.close() sits after the outer loop, so only the last
# per-label CSV is closed explicitly.
"""
# COMPUTE GENERATED METRICS AND CLASSIFICATION
game_predictions = {}
for key in label_counts:
    game_predictions[key] = [0,0,0]

print('Classifying blends...')
for c in range(num_labels):
    print(c)
    out_file = open('cvae_metrics_blend_' + str(latent_dim) + '_' + str(c) + '.csv','w')
    out_file.write('Density,Nonlinearity,Leniency,Interestingness,Traversability\n')
    label = [int(j) for j in bin(c)[2:]]
    if len(label) < label_size:
        label = [0] * (label_size - len(label)) + label
    label_str = ''.join([str(j) for j in label])
    label_tensor = torch.DoubleTensor(label).reshape(1,-1).to(device)
    for i in range(1000):
        if i % 50 == 0:
            print(i)
        z = torch.DoubleTensor(1,latent_dim).normal_(0,1).to(device)
        segment = get_segment_from_zc(z,label_tensor)
        d, n, l, i, t = density(segment), nonlinearity(segment), leniency(segment), interestingness(segment), traversability(segment)
        out_file.write(str(d) + ',' + str(n) + ',' + str(l) + ',' + str(i) + ',' + str(t) + '\n')
        out_label, probs = get_label_blend(segment)
        game = np.argmax(out_label)
        game_predictions[label_str][game] += 1
out_file.close()

outfile = open('blend_' + str(latent_dim) + '_elements.csv','w')
outfile.write('Label,SMB,KI,MM\n')
for key in game_predictions:
    outfile.write(key + ',' + str(game_predictions[key][0]) + ',' + str(game_predictions[key][1]) + ',' + str(game_predictions[key][2]) + '\n')
outfile.close()
"""

# GENERATE EXAMPLES WITH EACH CONDITIONING LABEL
# Encode one reference chunk from each source game under its own game label.
# Only mm_z is decoded below; smb_z and ki_z are available to swap in.
smb_file = 'smb_chunk_10.txt'
smb_z = get_z_from_file_c(smb_folder,smb_file,smb_label_tensor)

mm_file = 'mm_chunk_2000.txt'
mm_z = get_z_from_file_c(mm_folder,mm_file,mm_label_tensor)

ki_file = 'ki_chunk_10.txt'
ki_z = get_z_from_file_c(ki_folder,ki_file,ki_label_tensor)

# Decode the MM chunk's latent under every possible label, print each result
# and save it as cvae_new/allm_<latent_dim>_<label bits>.png.
os.makedirs('cvae_new', exist_ok=True)  # saving fails if the folder is missing
for c in range(num_labels):
    print(c)
    label = [int(j) for j in bin(c)[2:].zfill(label_size)]
    label_tensor = torch.DoubleTensor(label).reshape(1,-1).to(device)
    level = get_segment_from_zc(mm_z,label_tensor)
    get_image_from_segment(level,'cvae_new/allm_' + str(latent_dim) + '_' + ''.join(map(str,label)))
    print('\n'.join(level))
    print('\n')

# Re-label the last decoded segment with the labeller for the current game.
if GAME == 'smb':
    # `get_label` is not defined in this file; get_label_smb is the SMB labeller.
    label = np.array(get_label_smb(level)).astype('uint8')
elif GAME == 'ki':
    label = np.array(get_label_ki(level)).astype('uint8')
elif GAME == 'mm':
    label = np.array(get_label_mm(level)).astype('uint8')
elif GAME == 'smb_pats':
    # Pattern labels are parsed from a chunk's file name (as in parse_folder),
    # but no file name is in scope here; referencing an undefined `file`
    # raised NameError.
    raise NotImplementedError('smb_pats labels come from chunk file names; none is available here')
elif GAME == 'blend':
    label, probs = get_label_blend(level)
    print(label, type(label))
    print(probs, type(probs))
print(label,type(label))

# The two string literals below are disabled experiments (bare expression
# statements, never executed).

# (disabled) Decode one random latent under single-label conditions and save
# images to cond_pats_lean/.
"""
# DESIGN PATTERN CONDITIONING
z = torch.DoubleTensor(1,latent_dim).normal_(0,1).to(device)
for c in [1,2,4,8,16]:
#for c in [2,8,168,256,264,512,520,768]:
#for c in range(num_labels):
    #print(c)
    label = [int(i) for i in bin(c)[2:]]
    if len(label) < label_size:
        label = [0] * (label_size - len(label)) + label
    label_tensor = torch.DoubleTensor(label).reshape(1,-1).to(device)
    segment = get_segment_from_zc(z,label_tensor)
    #label = get_label_mm(segment)
    get_image_from_segment(segment,'cond_pats_lean/pl_' + str(latent_dim) + '_' + ''.join(map(str,label)))
#"""

# (disabled) Encode every training chunk under its own label.
# NOTE(review): it calls the undefined `get_label` for 'smb' and tests
# GAME == 'all', whereas the rest of the script uses 'blend'.
"""
print(len(levels))
print(num_labels)
for idx, segment in enumerate(levels):
    if idx % 100 == 0:
        print(idx)
    if GAME == 'smb':
        label = np.array(get_label(segment)).astype('uint8')
    elif GAME == 'ki':
        label = np.array(get_label_ki(segment)).astype('uint8')
    elif GAME == 'mm':
        label = np.array(get_label_mm(segment)).astype('uint8')
    elif GAME == 'all':
        label, probs = get_label_blend(segment)
    label_tensor = torch.DoubleTensor(label).reshape(1,-1).to(device)
    z = get_z_from_segment_c(segment,label_tensor)
"""

# LABEL ACCURACIES
# Decode 1000 random latents under every conditioning label.  For single
# games, the label recomputed from each decoded segment is compared with the
# requested one; for 'blend', the classifier's predicted source game is tallied.
label_accuracy_exact = {}
label_accuracy_atleast = {}  # not updated below (its update is commented out)
label_accuracy_none = {}
game_predictions = {}
for key in label_counts:
    label_accuracy_exact[key] = 0
    label_accuracy_atleast[key] = 0
    label_accuracy_none[key] = 0
    game_predictions[key] = [0,0,0]  # SMB, KI, MM

# Rule-based labellers for the single-game modes.  `get_label` is not
# defined in this file, so SMB uses get_label_smb.
single_game_labellers = {'smb': get_label_smb, 'ki': get_label_ki, 'mm': get_label_mm}

# The conditioning labels are the same for every sample, so build them once
# instead of 1000 times each.
conditions = []
for c in range(num_labels):
    label = [int(j) for j in bin(c)[2:].zfill(label_size)]
    label_str = ''.join(str(j) for j in label)
    conditions.append((label_str, torch.DoubleTensor(label).reshape(1,-1).to(device)))

total = 1000  # number of random latents sampled
with torch.no_grad():  # inference only; no autograd graphs are needed
    for i in range(total):
        if i % 50 == 0:
            print(i)
        z = torch.DoubleTensor(1,latent_dim).normal_(0,1).to(device)
        for label_str, label_tensor in conditions:
            segment = get_segment_from_zc(z,label_tensor)
            if GAME == 'blend':
                out_label, probs = get_label_blend(segment)
                game = np.argmax(out_label)
                game_predictions[label_str][game] += 1
                continue
            out_label = single_game_labellers[GAME](segment)
            out_label_str = get_label_string_from_array(out_label)
            if label_str == out_label_str:
                label_accuracy_exact[label_str] += 1
            elif '1' in label_str:
                # "None": not a single requested feature ('1') was produced.
                if not any(a == '1' and b == '1' for a, b in zip(label_str, out_label_str)):
                    label_accuracy_none[label_str] += 1

print('Total: ', total)
# 'all_32_elements.csv' for the configured latent_dim of 32.
with open('all_' + str(latent_dim) + '_elements.csv','w') as outfile:
    if GAME != 'blend':
        outfile.write('Label,Frequency,Exact,None\n')
        for key in label_accuracy_exact:
            outfile.write(key + ',' + str(label_counts[key]) + ',' + str((label_accuracy_exact[key]/total)*100) + ',' + str((label_accuracy_none[key]/total)*100) + '\n')
    else:
        outfile.write('Label,SMB,KI,MM\n')
        for key in game_predictions:
            outfile.write(key + ',' + str(game_predictions[key][0]) + ',' + str(game_predictions[key][1]) + ',' + str(game_predictions[key][2]) + '\n')

28
{0: '#', 1: '*', 2: '+', 3: '-', 4: '<', 5: '>', 6: '?', 7: 'B', 8: 'C', 9: 'D', 10: 'E', 11: 'H', 12: 'L', 13: 'M', 14: 'P', 15: 'Q', 16: 'S', 17: 'T', 18: 'U', 19: 'W', 20: 'X', 21: '[', 22: ']', 23: 'l', 24: 'o', 25: 't', 26: 'w', 27: '|'}
{'#': 0, '*': 1, '+': 2, '-': 3, '<': 4, '>': 5, '?': 6, 'B': 7, 'C': 8, 'D': 9, 'E': 10, 'H': 11, 'L': 12, 'M': 13, 'P': 14, 'Q': 15, 'S': 16, 'T': 17, 'U': 18, 'W': 19, 'X': 20, '[': 21, ']': 22, 'l': 23, 'o': 24, 't': 25, 'w': 26, '|': 27}
Num tiles:  28
RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight=None,
                       criterion='gini', max_depth=None, max_features='auto',
                       max_leaf_nodes=None, max_samples=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, n_estimators=500,
                       n_jobs=None, oob_score=False, random_state=None,
 

SystemExit: 