## Imports

In [1]:
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.optimizers.schedules import ExponentialDecay  # from https://arxiv.org/pdf/1506.02078.pdf
from tensorflow.keras.callbacks import EarlyStopping

from tqdm.notebook import tqdm

In [2]:
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))

Num GPUs Available:  1


## Hyper-parameters

In [4]:
tunable_hparams = {
    'stateful_generation' : True,
    'mapping_type' : 'seq2seq',
    'early_stopping' : False,
    'seq_length' : 200,
    'game' : 'mario'
}
fixed_hparams = {
    'hidden_size' : 128,
    'learning_rate' : 2e-3,
    'learning_rate_decay' : 0.95,
    'dropout' : 0.5,
    'batch_size' : 100,
    'num_layers' : 3,
    'max_epochs' : 50
}

In [5]:
for key, val in tunable_hparams.items():
    exec(key + '=val')
for key, val in fixed_hparams.items():
    exec(key + '=val')

## Data

In [24]:
data = open('corpuses/mario_corpus_conditional.txt', 'r').read()
level_strs = data.rstrip().split(')')[:-1]
print(len(level_strs))

37


In [25]:
chars = []
for level_str in level_strs:
    chars.extend(list(level_str))
chars = sorted(list(set(chars)))  # sorting is very important; otherwise the order is different each time
vocab_size = len(chars)
print(chars, vocab_size)

['\n', '-', '<', '>', '?', 'B', 'E', 'N', 'Q', 'S', 'X', 'Y', '[', ']', 'b', 'o', 'x'] 17


In [26]:
char_to_ix = { ch:i for i, ch in enumerate(chars) }
ix_to_char = { i:ch for i, ch in enumerate(chars) }

In [28]:
level_arrays = []
for level_str in level_strs:
    level_arrays.append(np.array([char_to_ix[char] for char in list(level_str)]))

In [29]:
def get_inputs_and_targets_from_level_array(level_array):
    
    inputs, targets = [], []
    
    for i in range(len(level_array) - seq_length):
        inputs.append(level_array[i:i+seq_length])
        targets.append(level_array[i+1:i+seq_length+1])
    
    inputs, targets = map(np.array, [inputs, targets])
    inputs = np.eye(vocab_size)[inputs]
    
    return inputs, targets

In [30]:
inputs, targets = [], []
for level_array in tqdm(level_arrays, leave=False):
    inputs_temp, targets_temp = get_inputs_and_targets_from_level_array(level_array)
    inputs.extend(inputs_temp); targets.extend(targets_temp)
inputs, targets = map(np.array, [inputs, targets])

HBox(children=(FloatProgress(value=0.0, max=37.0), HTML(value='')))

In [31]:
inputs.shape, targets.shape

((156360, 200, 17), (156360, 200))

## Model callbacks

In [32]:
lr_scheduler = ExponentialDecay(
    initial_learning_rate=learning_rate,
    decay_steps=len(inputs) // batch_size, 
    decay_rate=learning_rate_decay, 
)
optimizer = RMSprop(learning_rate=lr_scheduler)

In [33]:
es_callback = EarlyStopping(
    monitor='val_out_acc_custom_acc', mode='max', patience=2, restore_best_weights=early_stopping
)

In [34]:
def custom_loss(y_true, y_pred):
    scce = tf.keras.losses.SparseCategoricalCrossentropy()
    return scce(
        tf.reshape(y_true, shape=(tf.shape(y_true)[0] * seq_length, )), 
        tf.reshape(y_pred, shape=(tf.shape(y_pred)[0] * seq_length, vocab_size))
    )

In [35]:
def custom_acc(y_true, y_pred):
    return tf.math.reduce_mean(
        tf.cast(
            tf.math.equal(
                tf.math.argmax(tf.reshape(y_pred, shape=(tf.shape(y_pred)[0] * seq_length, vocab_size)), axis=-1), 
                tf.cast(tf.reshape(y_true, shape=(tf.shape(y_true)[0] * seq_length, )), dtype=tf.int64)
            ), 
            dtype=tf.float32
        )
    )

## Model definition

In [36]:
lstm_1_state_h_in = keras.layers.Input(shape=[hidden_size])
lstm_1_state_c_in = keras.layers.Input(shape=[hidden_size])

lstm_2_state_h_in = keras.layers.Input(shape=[hidden_size])
lstm_2_state_c_in = keras.layers.Input(shape=[hidden_size])

lstm_3_state_h_in = keras.layers.Input(shape=[hidden_size])
lstm_3_state_c_in = keras.layers.Input(shape=[hidden_size])

input = keras.layers.Input(shape=[seq_length, vocab_size])

out, lstm_1_state_h_out, lstm_1_state_c_out = keras.layers.LSTM(hidden_size, return_sequences=True, return_state=True)(
    input, initial_state=[lstm_1_state_h_in, lstm_1_state_c_in]
)
out = layers.Dropout(dropout)(out)

out, lstm_2_state_h_out, lstm_2_state_c_out = keras.layers.LSTM(hidden_size, return_sequences=True, return_state=True)(
    out, initial_state=[lstm_2_state_h_in, lstm_2_state_c_in]
)
out = layers.Dropout(dropout)(out)

out, lstm_3_state_h_out, lstm_3_state_c_out = keras.layers.LSTM(hidden_size, return_sequences=True, return_state=True)(
    out, initial_state=[lstm_3_state_h_in, lstm_3_state_c_in]
)
out = layers.Dropout(dropout)(out)

out = layers.Dense(vocab_size)(out)
out = layers.Activation('softmax')(out)

out_acc = layers.Lambda(lambda x:x, name = "out_acc")(out)

model = keras.models.Model(
    inputs=[
        input, 
        lstm_1_state_h_in, lstm_1_state_c_in,
        lstm_2_state_h_in, lstm_2_state_c_in,
        lstm_3_state_h_in, lstm_3_state_c_in
    ], 
    outputs=[
        out_acc,
        lstm_1_state_h_out, lstm_1_state_c_out,
        lstm_2_state_h_out, lstm_2_state_c_out,
        lstm_3_state_h_out, lstm_3_state_c_out
    ]
)

model.compile(
    loss=[custom_loss, None, None, None, None, None, None], 
    loss_weights=[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    metrics={'out_acc':custom_acc},
    optimizer=optimizer
)

## Model training

In [37]:
dummy = np.zeros((len(inputs), hidden_size))

In [38]:
history = model.fit(
    [inputs, dummy, dummy, dummy, dummy, dummy, dummy], 
    [targets, dummy, dummy, dummy, dummy, dummy, dummy], 
    batch_size=batch_size,
    validation_split=0.1,
    shuffle=True,
    epochs=max_epochs, 
    callbacks=[es_callback]
)
for i in range(10):
    model.save('trained_models/lstm_conditional_elements.h5')

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50


## Load trained model

In [39]:
model = keras.models.load_model(
    'trained_models/lstm_conditional_elements.h5', 
    custom_objects={'custom_loss':custom_loss, 'custom_acc':custom_acc}
)

In [None]:
model.evaluate(
    [inputs, dummy, dummy, dummy, dummy, dummy, dummy], 
    [targets, dummy, dummy, dummy, dummy, dummy, dummy],
    batch_size=5, verbose=1
)  # sanity check

## Generate level

In [57]:
def onehot_to_string(onehot):
    ints = np.argmax(onehot, axis=-1)
    chars = [ix_to_char[ix] for ix in ints]
    string = "".join(chars)
    char_array = []
    if len(string.rstrip().split('\n')[-1]) < 22 - 1:  # minus 1 for new line char
        for line in string.strip().split('\n')[:-1]:
            char_array.append(list(line[:21]))
    else:
        for line in string.strip().split('\n')[:-1]:
            char_array.append(list(line[:21]))
    char_array = np.array(char_array).T
    string = ""
    for row in char_array:
        string += "".join(row) + "\n"
    return string

In [58]:
seed = inputs[0][:3 * 22 - 1].copy()  # 3 cols * 18 tiles per col - newline char (to be generated)
seed[22+19] = 0
seed[22+19][char_to_ix['x']] = 1
seed[22*2+19] = 0
seed[22*2+19][char_to_ix['x']] = 1
print(seed.shape)
print(onehot_to_string(seed))

(65, 17)
NN
NN
NN
NN
NN
--
--
--
--
--
--
--
--
--
--
--
--
--
--
-x
XX



In [59]:
mode = 'single_alternate'

In [60]:
if mode == 'single':
    condition_tapess_dir = "condition_tapes_single"
elif mode == 'single_alternate':
    condition_tapess_dir = "condition_tapes_single_alternate"
elif mode == 'double_alternate':
    condition_tapess_dir = "condition_tapes_double_alternate"
gen_parent_dir = "lstm_conditional_generate_levels_txt_by_tape"

In [61]:
num_levels_to_gen = 1
level_height = 16 + 1 + 5

In [62]:
print(condition_tapess_dir)

condition_tapes_single_alternate


In [63]:
for tape_fname in os.listdir(condition_tapess_dir):
    
    condition_tapes_path = f"{condition_tapess_dir}/{tape_fname}"
    gen_dir = f"{gen_parent_dir}/{tape_fname.split('.')[0]}"
    os.makedirs(gen_dir, exist_ok=True)
    
    with open(condition_tapes_path, "r") as json_f:
        condition_tapes = np.array(json.load(json_f))
        
    num_tile_to_gen = condition_tapes.shape[1] * level_height
    
    for j in tqdm(range(1, num_levels_to_gen+1)):

        seed = inputs[0][:3 * 22 - 1].copy()  # 3 cols * 18 tiles per col - newline char (to be generated)
        seed[22+19] = 0
        seed[22+19][char_to_ix['x']] = 1
        seed[22*2+19] = 0
        seed[22*2+19][char_to_ix['x']] = 1
        gen = seed.copy()

        # initialize all hidden and cell states to zeros
        lstm1_h = np.zeros((1, hidden_size))
        lstm1_c = np.zeros((1, hidden_size))
        lstm2_h = np.zeros((1, hidden_size))
        lstm2_c = np.zeros((1, hidden_size))
        lstm3_h = np.zeros((1, hidden_size))
        lstm3_c = np.zeros((1, hidden_size))

        num_condition_chars_to_add = False
        col_ix_generating = -1

        for i in tqdm(range(num_tile_to_gen), leave=False):

            seed = np.expand_dims(seed, axis=0)

            # predict probas and update hidden and cell states
            probas, lstm1_h, lstm1_c, lstm2_h, lstm2_c, lstm3_h, lstm3_c = model.predict([
                seed, lstm1_h, lstm1_c, lstm2_h, lstm2_c, lstm3_h, lstm3_c
            ])

            # ========== generic prediction ==========

            if num_condition_chars_to_add == 0:

                probas = probas[0][-1]  # first batch, last timestep

                idx = np.random.choice(np.arange(len(probas)), p=probas)
                seed = np.zeros((1, vocab_size))
                seed[:, idx] = 1.
                gen = np.vstack([gen, seed])

                if ix_to_char[idx] == '\n':
                    num_condition_chars_to_add = 5
                    col_ix_generating += 1

            # ========== condition char are not generated, they are loaded from the condition tape ==========

            else:

                seed = np.zeros((1, vocab_size))

                if num_condition_chars_to_add == 5:
                    condition_tape = condition_tapes[0]
                elif num_condition_chars_to_add == 4:
                    condition_tape = condition_tapes[1]
                elif num_condition_chars_to_add == 3:
                    condition_tape = condition_tapes[2]
                elif num_condition_chars_to_add == 2:
                    condition_tape = condition_tapes[3]
                elif num_condition_chars_to_add == 1:
                    condition_tape = condition_tapes[4]

                if condition_tape[col_ix_generating] == 0:
                    seed[:, char_to_ix['N']] = 1
                elif condition_tape[col_ix_generating] == 1:
                    seed[:, char_to_ix['Y']] = 1
                gen = np.vstack([gen, seed])

                num_condition_chars_to_add -= 1

        with open(f'{gen_dir}/{j}.txt', 'w+') as txt_f:
            txt_f.write(onehot_to_string(gen))

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=3520.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=3520.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=3520.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=3520.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=3520.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=3520.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=3520.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=3520.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=3520.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=3520.0), HTML(value='')))




In [67]:
onehot_to_string(gen)

'NNN\nNNN\nNNN\nNNN\nNNN\n---\n---\n---\n---\n---\n---\n---\n---\n---\n---\n---\n---\n---\n---\n-xx\nXXX\n'