In [None]:
import numpy as np
import common
# Load the training verse pairs, split them into X / Y with common.cleanAndSplitVerses
# and encode both sides with common.encXandY.
versePairs = common.loadTrainingData()
X, Y = common.cleanAndSplitVerses(versePairs)
Xenc, Yenc = common.encXandY(X, Y)
# Check that regexClean and regexUnclean are perfect inverses on the training data.
# Every verse listed in the failure message does not survive the round trip, so the
# corresponding line of the training set probably needs cleaning up.
notInvertible = [v for v in versePairs if common.regexUnclean(common.regexClean(v)) != v]
assert len(notInvertible) == 0, \
    "regexClean/regexUnclean do not round-trip for: {0}".format(notInvertible)
# Length handed to common.padXandY; the model Input layers below hard-code the same 100.
maxlen = 100
Xnp, Ynp = common.padXandY(Xenc, Yenc, maxlen)

In [None]:
from keras import backend as K
# Define the model from scratch. To reuse a previously saved model instead, see the commented-out load_model lines below.
from keras.models import Sequential, Model
from keras.layers import LSTM, GRU, Dense, TimeDistributed, Bidirectional, Input, Embedding
from keras.layers.merge import Concatenate
from keras.layers.core import Dropout

# Generator
# Maps a (100, 31) input sequence to a 4-way softmax at every one of the 100 timesteps.

inputs = Input(shape=(100, 31))
# return_sequences=True keeps one output vector per timestep for the Dense layer below.
bidi = Bidirectional(LSTM(256, return_sequences=True))(inputs)
den = Dense(4, activation='softmax')(bidi)
gen = Model(inputs=inputs, outputs=den, name='gen_out')
# sample_weight_mode='temporal' lets fit() take one weight per timestep
# (pre_train_gen passes the sample_weight array for this).
gen.compile(loss='categorical_crossentropy', optimizer='adam', sample_weight_mode='temporal', metrics=['categorical_accuracy'])
# Alternative to building the model above: load a previously saved one.
#from keras.models import load_model
#model = load_model("saved_1step_model")

gen.summary()

In [None]:
# Discriminator
# Takes a (100, 35) sequence -- the same width as the generator's 31 input features
# concatenated with its 4 output classes -- and emits a single sigmoid score per sequence.
discIn = Input((100, 35))
disc = Bidirectional(GRU(128, return_sequences=True))(discIn)
disc = GRU(64, return_sequences=True)(disc)
# No return_sequences here, so only the last GRU output is passed on to the Dense layer.
disc = GRU(32)(disc)
disc = Dense(1, activation='sigmoid')(disc)
# From here on the name `disc` refers to the finished Model, not to a layer output.
disc = Model(inputs=discIn, outputs=disc, name='disc_out')
disc.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
disc.summary()


In [None]:
# Combined GAN model: generator followed by the discriminator.
# The discriminator is switched to non-trainable first (via common.make_trainable)
# so that, presumably, fitting `gan` only updates the generator -- TODO confirm
# against common.make_trainable.
common.make_trainable(disc, False)
gan_input = Input((100,31))
H = gen(gan_input)
# The discriminator judges the input together with the generator's output: (100, 31 + 4).
conc = Concatenate()([gan_input, H])
gan_V = disc(conc)
gan = Model(gan_input, outputs=[H, gan_V])
# One loss per output, in output order: the generator output is a 4-class softmax and is
# trained with categorical cross-entropy (as in gen.compile); only the discriminator's
# sigmoid score is a binary problem. Using binary cross-entropy for both would score the
# softmax output class-by-class and report an inflated accuracy for it.
gan.compile(loss=['categorical_crossentropy', 'binary_crossentropy'], optimizer='adam', metrics=['accuracy'])
gan.summary()

In [None]:
# Generate a weighting for the different characters. We want to penalise the model according to how rare a symbol is
# so that rare symbols are more important to place correctly than common ones.
import numpy as np
from collections import Counter
# Class index of every timestep of every target sequence.
Yclass = Ynp.argmax(axis=2)
# Count how often each symbol occurs across all target strings.
chars = Counter("".join(Y))
freq = chars.values()
total = np.sum(list(freq))
# Inverse-frequency weight per class index; the position of a symbol in the tuple
# is the class index it is weighted for.
toReplace = {
    clas: total/chars[symbol]
    for clas, symbol in enumerate(("0", "|", "·", "*"))
}
def replace(clas):
    """Return the weight assigned to class index `clas`."""
    return toReplace[clas]
# Same shape as Yclass: one weight per timestep.
sample_weight = np.vectorize(replace)(Yclass)

In [None]:
# A callback to display a particular verse after each epoch.
from keras.callbacks import Callback
class ShowVerse(Callback):
    """Print the model's current output for one training verse after each epoch.

    The verse is decoded twice: once from the arg-max class per timestep
    (via common.decClasses) and once from the raw prediction
    (via common.getToComb).

    Args:
        verse: index into Xnp / X of the verse to display.
    """
    def __init__(self, verse=0):
        # Let the Keras base class set up its own state before adding ours.
        super(ShowVerse, self).__init__()
        self.verse = verse

    def on_epoch_end(self, batch, logs=None):
        # Keras passes the epoch index as the first positional argument; it is unused here.
        sample = Xnp[self.verse:self.verse+1]
        # Predict once and derive the classes with argmax: predict_classes only exists on
        # Sequential models, while the generator in this notebook is a functional Model.
        pred = self.model.predict(sample, batch_size=256)
        classes = np.argmax(pred, axis=-1)
        toComb = common.decClasses(classes[0])
        print(common.regexUnclean(common.mergeStrings(X[self.verse], toComb)))
        toComb = common.getToComb(pred[0])
        print(common.regexUnclean(common.mergeStrings(X[self.verse], toComb)))

In [None]:
from keras.callbacks import Callback
class EarlyStoppingByAccuracy(Callback):
    """Stop training as soon as a monitored metric rises above a threshold.

    Args:
        monitor: key to look up in the epoch logs.
        value: training stops once the monitored quantity is strictly greater than this.
        verbose: if > 0, print a message when training is stopped.
    """
    def __init__(self, monitor='acc', value=0.6, verbose=0):
        # super(Callback, self) would skip Callback.__init__ and go straight to object.
        super(EarlyStoppingByAccuracy, self).__init__()
        self.monitor = monitor
        self.value = value
        self.verbose = verbose

    def on_epoch_end(self, epoch, logs=None):
        import warnings

        # logs may legitimately be None; treat that as "nothing recorded".
        logs = logs or {}
        current = logs.get(self.monitor)
        if current is None:
            warnings.warn("Early stopping requires %s available!" % self.monitor, RuntimeWarning)

        elif current > self.value:
            if self.verbose > 0:
                print("Epoch %05d: early stopping THR" % epoch)
            self.model.stop_training = True

class EarlyStoppingByLoss(Callback):
    """Stop training as soon as a monitored loss falls below a threshold.

    Args:
        monitor: key to look up in the epoch logs.
        value: training stops once the monitored quantity is strictly less than this.
        verbose: if > 0, print a message when training is stopped.
    """
    def __init__(self, monitor='loss', value=0.5, verbose=0):
        # super(Callback, self) would skip Callback.__init__ and go straight to object.
        super(EarlyStoppingByLoss, self).__init__()
        self.monitor = monitor
        self.value = value
        self.verbose = verbose

    def on_epoch_end(self, epoch, logs=None):
        import warnings

        # logs may legitimately be None; treat that as "nothing recorded".
        logs = logs or {}
        current = logs.get(self.monitor)
        if current is None:
            warnings.warn("Early stopping requires %s available!" % self.monitor, RuntimeWarning)

        elif current < self.value:
            if self.verbose > 0:
                print("Epoch %05d: early stopping THR" % epoch)
            self.model.stop_training = True

In [None]:
# A regex that matches a more-or-less correctly pointed verse, spanning two lines:
#   line 1:  text | text [·] text | text *              (followed by a newline)
#   line 2:  text | text [·] text | text [·] text | text   (then end of string)
# "text" is [^|·*]+, i.e. one or more characters that are none of the three marks.
# Note that the unescaped "·*" means "zero or more ·" (the * is a quantifier there),
# whereas "\*" is a literal asterisk.
regex = r"[^|·*]+\|[^|·*]+·*[^|·*]+\|[^|·*]+\*\n[^|·*]+\|[^|·*]+·*[^|·*]+\|[^|·*]+·*[^|·*]+\|[^|·*]+$"

In [None]:
# "Real" examples for the discriminator: every training input joined with its true target
# along the last axis, the same layout the GAN builds with Concatenate()([gan_input, H]).
true = np.concatenate([Xnp, Ynp], axis=-1)
def pre_train_gen(epochs):
    """Fit the generator on its own against the true targets.

    The generator is made trainable for the duration of the fit and switched
    back to non-trainable afterwards. Per-timestep `sample_weight` is applied.

    Args:
        epochs: number of epochs to fit for.
    """
    common.make_trainable(gen, True)
    gen.fit(
        Xnp,
        Ynp,
        epochs=epochs,
        batch_size=256,
        sample_weight=sample_weight,
    )
    common.make_trainable(gen, False)

def train_disc(epochs, up_to_value=None):
    """Train the discriminator on true pairs versus generator predictions.

    True (input, target) pairs are labelled 1. A generated pair is labelled 0
    unless its decoded verse already matches `regex`, in which case it is
    labelled 1 as well.

    Args:
        epochs: maximum number of epochs to fit for.
        up_to_value: if given, stop early once the training loss drops below it.
    """
    # Local import: the notebook has no top-level `import re`.
    import re

    gen_pred = gen.predict(Xnp)
    false = np.concatenate([Xnp, gen_pred], axis=-1)
    false_len = false.shape[0]
    false_goal = np.zeros(false_len)
    for i in range(false_len):
        # The helpers live in the `common` module; the notebook only does `import common`.
        toComb = common.getClasses(gen_pred[i])
        candidate = common.regexUnclean(common.mergeStrings(X[i], toComb))
        if re.match(regex, candidate):
            false_goal[i] = 1
    true_then_false = np.concatenate([true, false], axis=0)
    true_len = true.shape[0]
    goal = np.concatenate([np.ones(true_len), false_goal])

    callbacks = []
    if up_to_value:
        callbacks.append(EarlyStoppingByLoss(value=up_to_value))

    common.make_trainable(disc, True)
    disc.fit(true_then_false, goal, epochs=epochs, batch_size=256, callbacks=callbacks)
    common.make_trainable(disc, False)

from keras.callbacks import ModelCheckpoint
    
def train_gen(epochs, up_to_value=None, i=0):
    """Train the generator through the stacked `gan` model.

    The targets are the true labels for the generator output and all ones for
    the discriminator output. A checkpoint callback named after `i` and the
    epoch number is always attached.

    Args:
        epochs: maximum number of epochs to fit for.
        up_to_value: if truthy, also stop early once `disc_out_loss` drops below it.
        i: outer round index, used only in the checkpoint file name.
    """
    common.make_trainable(gen, True)

    checkpoint_name = 'weights.{0}.{1}'.format(i, '{epoch:02d}')
    fit_callbacks = [ModelCheckpoint(checkpoint_name, monitor='disc_out_loss')]
    if up_to_value:
        fit_callbacks.append(EarlyStoppingByLoss(monitor="disc_out_loss", value=up_to_value))

    targets = [Ynp, np.ones(true.shape[0])]
    gan.fit(Xnp, targets, epochs=epochs, batch_size=256, callbacks=fit_callbacks)

    common.make_trainable(gen, False)

def train_both(epochs):
    """Alternate discriminator and generator training for `epochs` rounds.

    Each round trains the discriminator (at most 100 epochs) and then the
    generator (at most 1000 epochs), both with an early-stopping value of 0.2.
    """
    for round_idx in range(epochs):
        print("Epoch is {0}".format(round_idx))

        print("Training discriminator")
        train_disc(100, up_to_value=0.2)

        print("Training generator")
        train_gen(1000, up_to_value=0.2, i=round_idx)

In [None]:
# Pre-train the generator on its own for 500 epochs (see pre_train_gen).
pre_train_gen(500)

In [None]:
# Restore generator weights from the "gen_weights" file; presumably an alternative to
# re-running the pre-training cell once that file exists.
gen.load_weights("gen_weights")

In [None]:
# Save the current generator weights to "gen_weights" so they can be reloaded later.
gen.save_weights('gen_weights')

In [None]:
# Run 30 rounds of alternating discriminator / generator training (see train_both).
train_both(30)

In [None]:
# Train only the discriminator for 100 epochs, with no early-stopping threshold.
train_disc(100)

In [None]:
# Train only the generator (through the gan model) for 100 epochs, with no early-stopping
# threshold; checkpoint file names use the default round index i=0.
train_gen(100)

In [None]:
# Run the generator over the training inputs, let the discriminator score every
# (input, prediction) pair, then print each score followed by the decoded verse.
pred = gen.predict(Xnp, batch_size=256)
tf = disc.predict(np.concatenate([Xnp, pred], axis=-1))
for i, probs in enumerate(pred):
    toComb = common.getClasses(probs)
    print(tf[i])
    print(common.regexUnclean(common.mergeStrings(X[i], toComb)))

In [None]:
# Decode every training verse from the generator's arg-max class per timestep.
# Uses `gen` (no `model` is defined in this notebook) and argmax over predict(),
# because predict_classes only exists on Sequential models.
pred = np.argmax(gen.predict(Xnp, batch_size=256), axis=-1)
for i in range(pred.shape[0]):
    toComb = common.decClasses(pred[i])
    print(common.regexUnclean(common.mergeStrings(X[i], toComb)))

In [None]:
# Decode every training verse from the generator's raw predictions via common.getToComb.
# Uses `gen` (no `model` is defined in this notebook) and the `common.`-qualified helpers.
pred = gen.predict(Xnp, batch_size=256)
for i in range(pred.shape[0]):
    toComb = common.getToComb(pred[i])
    print(common.regexUnclean(common.mergeStrings(X[i], toComb)))

In [None]:
# Load and predict on some test data.
from keras.preprocessing.sequence import pad_sequences
test = []
with open("testCleaned.txt", 'r', encoding="utf-8") as file:
    for line in file:
        # Strip only the line terminator; slicing off the last character would eat a real
        # character when the final line has no trailing newline.
        test.append(line.rstrip("\n"))
testEnc = [common.encString(l) for l in test]
# Pad to the same length the generator was built for (its Input is fixed at 100 timesteps).
testNp = pad_sequences(testEnc, maxlen=maxlen)
# predict_classes only exists on Sequential models; take the arg-max class per timestep instead.
testPred = np.argmax(gen.predict(testNp, batch_size=256), axis=-1)
for i in range(len(test)):
    toComb = common.decClasses(testPred[i])
    print(common.regexUnclean(common.mergeStrings(test[i], toComb)))

In [None]:
# Load and predict on some test data, printing only the verses whose decoded form
# matches `regex`, and report how many matched.
import re
from keras.preprocessing.sequence import pad_sequences
test = []
with open("john1.txt", 'r', encoding="utf-8") as file:
    for line in file:
        # Strip only the line terminator; slicing off the last character would eat a real
        # character when the final line has no trailing newline.
        test.append(line.rstrip("\n"))
# The helpers live in the `common` module; the notebook only does `import common`.
testEnc = [common.encString(l) for l in test]
testNp = pad_sequences(testEnc, maxlen=maxlen)
testPred = gen.predict(testNp, batch_size=256)
matches = 0
for i in range(len(test)):
    toComb = common.getClasses(testPred[i])
    candidate = common.regexUnclean(common.mergeStrings(test[i], toComb))
    if re.match(regex, candidate):
        matches += 1
        print(candidate)
print("There were {0} verses and {1} matches.".format(len(test), matches))