In [None]:
from __future__ import print_function
import tensorflow as tf
from preeminence_utils import tf_utils
import numpy as np
import random
import os

In [None]:
def get_latest_epoch(model_dir="./model/model_southpark/"):
    """Return the highest epoch number among the checkpoints stored in `model_dir`.

    Checkpoints are recognised by their `.meta` file; the epoch is the integer that
    follows the last "-" in the file name (e.g. `model.ckpt-42.meta` -> 42). Files
    that do not end in `.meta`, or whose suffix is not an integer, are ignored.

    Args:
        model_dir: Directory to scan. Defaults to the training checkpoint directory.

    Returns:
        The largest epoch number found, as an int.

    Raises:
        ValueError: If `model_dir` contains no parsable `.meta` checkpoint file.
        OSError: If `model_dir` cannot be listed (propagated from `os.listdir`).
    """
    meta_suffix = ".meta"
    epochs = []
    for file_name in os.listdir(model_dir):
        # Match on the extension rather than the substring "meta", so unrelated
        # files such as "metadata.txt" cannot break the integer parsing below.
        if not file_name.endswith(meta_suffix):
            continue
        epoch_text = file_name[:-len(meta_suffix)].rpartition("-")[2]
        if epoch_text.isdigit():
            epochs.append(int(epoch_text))
    if not epochs:
        # max([]) would raise a ValueError with an unhelpful message; keep the
        # exception type but say what is actually wrong.
        raise ValueError("No checkpoint .meta files found in {}".format(model_dir))
    return max(epochs)

In [None]:
# initialise text variables
data_file = "./data/south_park/all_scripts.txt"
# Read through a context manager so the file handle is closed as soon as the text is loaded.
with open(data_file) as script_file:
    text = script_file.read().strip()
print(len(text))

# Character vocabulary and the two lookup tables between characters and integer ids.
vocab = sorted(set(text))
vocab_length = len(vocab)
characters2id = {c: i for i, c in enumerate(vocab)}
id2characters = {i: c for i, c in enumerate(vocab)}

# Cut the text into overlapping windows of `section_length` characters, taken every
# `step` characters; the label of a window is the character that immediately follows it.
section_length = 50
step = 10
section_starts = np.arange(0, len(text) - section_length, step, dtype=np.intp)
sections = [text[start:start + section_length] for start in section_starts]
section_labels = [text[start + section_length] for start in section_starts]

# One-hot encode windows and labels. The whole text is mapped to ids once, and the
# ones are written with a single indexed assignment instead of a per-character loop.
text_ids = np.array([characters2id[c] for c in text], dtype=np.intp)
X_data = np.zeros((len(sections), section_length, vocab_length))
Y_data = np.zeros((len(sections), vocab_length))

section_rows = np.arange(len(sections))
window_offsets = np.arange(section_length)
# window_ids[i, j] is the id of the j-th character of the i-th section.
window_ids = text_ids[section_starts[:, None] + window_offsets[None, :]]
X_data[section_rows[:, None], window_offsets[None, :], window_ids] = 1
Y_data[section_rows, text_ids[section_starts + section_length]] = 1

print(X_data.shape,Y_data.shape)


In [None]:
# Create the model wrapper provided by the project-local `preeminence_utils.tf_utils` module.
model = tf_utils.Model()
# NOTE(review): `as_default()` is called here without a `with` statement. If `model.init()`
# returns a tf.Graph, this only creates a context manager and does not by itself make that
# graph the default one — confirm against tf_utils.
model_graph = model.init().as_default()

In [None]:
# Optimiser settings. Note that `batch_size` is reassigned to 1024 in the training cell.
learning_rate = 1e-2
batch_size = 128
total_epochs = 500

# Number of units in the LSTM cell (also the first dimension of the output weights).
hidden_nodes = 1024

# Logging / checkpoint intervals. These are not referenced by the cells visible in this
# notebook — presumably kept for an earlier training loop; verify before relying on them.
log_every = 100
save_every = 10

In [None]:
# Graph inputs: X is fed with X_data (encoded sections) and Y with Y_data (encoded labels)
# in the training cell; the leading dimension is left open for the batch size.
X = tf.placeholder(
    dtype=tf.float32,
    shape=[None, section_length, vocab_length],
    name="X_train",
)
Y = tf.placeholder(
    dtype=tf.float32,
    shape=[None, vocab_length],
    name="Y_train",
)

# Output projection applied to the last LSTM output: hidden_nodes -> vocab_length.
W = tf.Variable(
    tf.random_normal(shape=[hidden_nodes, vocab_length]),
    name="Output_weights",
)
b = tf.Variable(
    tf.random_normal(shape=[vocab_length]),
    name="Output_bias",
)

In [None]:
def lstm(x,weights,bias,name_scope="lstm"):
    """Build a single-layer LSTM over `x` and project its last output to logits.

    Args:
        x: Input tensor that is unstacked along axis 1 into `section_length` time steps.
        weights: Output projection matrix multiplied with the last LSTM output.
        bias: Bias added to the projected output.
        name_scope: Name scope under which the ops are created.

    Returns:
        The tensor `matmul(last_output, weights) + bias`.
    """
    with tf.name_scope(name_scope):
        # static_rnn expects a Python list with one tensor per time step.
        timesteps = tf.unstack(x,section_length,1)
        lstm_cell = tf.contrib.rnn.BasicLSTMCell(hidden_nodes,forget_bias=1.0)
        # The final cell state is not needed; only the per-step outputs are used.
        outputs, _ = tf.contrib.rnn.static_rnn(lstm_cell, timesteps, dtype=tf.float32)
        # Use the arguments passed by the caller rather than the module-level W and b,
        # which the original body referenced while silently ignoring its parameters.
        return tf.matmul(outputs[-1],weights)+bias

In [None]:
# Forward pass: logits from the LSTM, and their softmax for sampling / accuracy.
logits = lstm(X, W, b)
prediction = tf.nn.softmax(logits)

# Mean softmax cross-entropy between the logits and the labels fed through Y,
# minimised with plain gradient descent.
loss_op = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=Y, logits=logits)
)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(loss_op)

# Fraction of examples whose arg-max prediction equals the arg-max of the label.
correct_pred = tf.equal(
    tf.argmax(prediction, axis=1),
    tf.argmax(Y, axis=1),
)
accuracy = tf.reduce_mean(tf.cast(correct_pred, dtype=tf.float32))

init = tf.global_variables_initializer()
# saver = tf.train.Saver(max_to_keep=0)


In [None]:
# Training
# Look up the highest checkpoint epoch on disk. In the visible cells the value is only
# used for the message printed below.
epoch_start = get_latest_epoch()
print("Resuming training from epoch: {}".format(epoch_start))
# Replaces the batch_size of 128 assigned in the hyperparameter cell.
batch_size = 1024
with model.session() as sess:
    # Variable initialisation is disabled because weights are restored from disk instead.
#     sess.run(init)
    # presumably restores the most recent checkpoint found in this directory — TODO confirm
    # against tf_utils.Model.restore_weights
    model.restore_weights("./model/model_southpark/")
    # The loop body runs exactly once: one call to model.train with num_epochs=1,
    # followed by one checkpoint save.
    for i in range(1):
        # NOTE(review): `ops` is not read anywhere in the visible cells.
        ops = model.train([train_op,loss_op],X,Y,X_data,Y_data,num_epochs=1,batch_size=batch_size)
        model.save_weights(checkpoint_path="./model/model_southpark/")

In [None]:
def sample2(preds, temperature=1.0):
    """Draw one sample from `preds` after rescaling it by `temperature`.

    The probabilities are re-weighted as softmax(log(preds) / temperature) and a
    single multinomial draw is taken from the result.

    Args:
        preds: 1-D array-like of probabilities.
        temperature: Non-zero scaling factor; values below 1 sharpen the
            distribution, values above 1 flatten it.

    Returns:
        Integer array of shape (1, len(preds)) containing a single 1 at the
        sampled index (the output of `np.random.multinomial(1, p, 1)`).

    Raises:
        ValueError: If `temperature` is zero, or if the rescaled probabilities
            are invalid (raised by `np.random.multinomial`).
    """
    if temperature == 0:
        raise ValueError("temperature must be non-zero")
    preds = np.asarray(preds).astype('float64')
    # log(0) is -inf, which is the intended value here (exp maps it back to 0),
    # so silence the divide-by-zero warning instead of letting it spam the output.
    with np.errstate(divide='ignore'):
        scaled = np.log(preds) / temperature
    # Subtract the maximum before exponentiating: mathematically a no-op after
    # normalisation, but it keeps exp() from underflowing every entry to zero
    # (which would produce NaN probabilities) at low temperatures.
    scaled = scaled - np.max(scaled)
    exp_preds = np.exp(scaled)
    probas = exp_preds / np.sum(exp_preds)
    return np.random.multinomial(1, probas, 1)

In [None]:
def get_weights(weights_path):
    """List the epoch numbers of all checkpoints stored in `weights_path`.

    Checkpoints are recognised by their `.meta` file; the epoch is the integer that
    follows the last "-" in the file name (e.g. `model.ckpt-42.meta` -> 42). Files
    that do not end in `.meta`, or whose suffix is not an integer, are ignored.

    Args:
        weights_path: Directory to scan for checkpoint files.

    Returns:
        A list of epoch numbers sorted in ascending order, or None when
        `weights_path` does not exist.
    """
    if not os.path.exists(weights_path):
        return None
    meta_suffix = ".meta"
    epochs = []
    for file_name in os.listdir(weights_path):
        # Match on the extension rather than the substring "meta", so unrelated
        # files cannot trigger an IndexError / ValueError while parsing.
        if not file_name.endswith(meta_suffix):
            continue
        epoch_text = file_name[:-len(meta_suffix)].rpartition("-")[2]
        if epoch_text.isdigit():
            epochs.append(int(epoch_text))
    return sorted(epochs)
            

In [None]:
# Testing: for every checkpoint in the backup directory, restore its weights and print
# `prediction_length` generated characters seeded with a random section of the text.
prediction_length = 500
# NOTE(review): `epoch_test` is not read by the loop below, and it scans the training
# directory rather than the backup directory — confirm whether it is still needed.
epoch_test = get_latest_epoch()


def _encode_section(section_text):
    """Return `section_text` one-hot encoded as a (1, section_length, vocab_length) array."""
    encoded = np.zeros((1, section_length, vocab_length))
    for position, character in enumerate(section_text):
        encoded[0, position, characters2id[character]] = 1
    return encoded


checkpoint_dir = "./model/model_southpark_bk/"
# get_weights returns None when the directory does not exist; fall back to an empty
# list so the loop is skipped with a message instead of failing with a TypeError.
checkpoint_epochs = get_weights(checkpoint_dir) or []
if not checkpoint_epochs:
    print("No checkpoints found in {}".format(checkpoint_dir))

for weight in checkpoint_epochs:
    print("\n\nEpoch number: {}".format(weight))
    with model.session() as sess:
        latest_checkpoint = checkpoint_dir + "model.ckpt-" + str(weight)
        # NOTE(review): a new Saver is built for every checkpoint, which adds another set
        # of save/restore ops to the graph each time. It is kept inside the session block
        # because it is unclear which graph model.session() makes current — confirm
        # before hoisting it out of the loop.
        saver = tf.train.Saver()
        saver.restore(sess, latest_checkpoint)

        # Seed the generator with a random window taken from the training text.
        start_index = random.randint(0, len(text) - section_length - 1)
        test_start = text[start_index: start_index + section_length]
        X_test = _encode_section(test_start)

        for temperature in [0.2]:
            print(test_start, end="")
            for _ in range(prediction_length):
                pred = sess.run(prediction, feed_dict={X: X_test})
                pred = sample2(pred.reshape(-1), temperature)
                next_char = id2characters[np.argmax(pred)]
                print(next_char, end="")

                # Slide the window: append the new character and keep only the last
                # `section_length` characters as the next model input.
                test_start = (test_start + next_char)[-section_length:]
                X_test = _encode_section(test_start)


In [None]:
# Close the last session bound to `sess` by the cells above.
# NOTE(review): every `sess` in this notebook comes from `with model.session() as sess`;
# if that context manager already closes the session on exit, this call is redundant — confirm.
sess.close()