In [1]:
import os
import re
import shutil
from datetime import datetime
from time import time

import numpy as np
import tensorflow as tf
from tensorflow.python.ops import rnn_cell

In [2]:
# Start every run from scratch: wipe checkpoints left by earlier runs so that
# run_session() finds none to restore. Done in pure Python (instead of the
# `rm -rf` / `mkdir` shell escape) so it also works without a POSIX shell.
shutil.rmtree('checkpoints', ignore_errors=True)
os.makedirs('checkpoints')

In [3]:
class SyllableParser(object):
    """Character-level RNN tagger that splits words into syllables.

    Each character of a word gets a binary label: 1 if a syllable ends at
    that character, 0 otherwise (see ``encode_syllables``). ``train`` loads a
    tab-separated training file, builds the TensorFlow graph and runs the
    training loop; ``sample`` tags the words of a one-word-per-line file.

    The graph's placeholders have a fixed batch dimension of ``batch_size``,
    so every batch fed to the graph holds exactly ``batch_size`` rows.
    """

    def __init__(self, num_epochs=100, batch_size=20, hidden_size=128, cell_type='lstm',
                 net_type='brnn', num_layers=3, treshold=0.5):
        """Store the model and training hyper-parameters.

        Args:
            num_epochs: number of passes over the training split.
            batch_size: words per batch; fixed in the graph's placeholders.
            hidden_size: size of each RNN cell and of the character embedding.
            cell_type: 'lstm' or 'gru'.
            net_type: 'rnn' (unidirectional) or 'brnn' (bidirectional).
            num_layers: number of stacked RNN layers.
            treshold: a character is tagged as a syllable end when the
                predicted probability of label 1 exceeds this value.
                (The misspelt name is kept for backward compatibility.)
        """
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.cell_type = cell_type
        self.num_layers = num_layers
        self.net_type = net_type
        self.treshold = treshold

    def encode(self, word):
        """Map a word (stripped and lower-cased) to a list of indices into self.mapping.

        Raises:
            ValueError: if the word contains a character not seen by fit_data.
        """
        return [self.mapping.index(ch) for ch in word.strip().lower()]

    def decode(self, encoded_word):
        """Inverse of ``encode``: turn a sequence of indices back into a string."""
        return ''.join(self.mapping[int(idx)] for idx in encoded_word)

    def encode_syllables(self, syllables):
        """Convert syllable strings into per-character labels.

        Every character of a syllable is labelled 0 except the last one,
        which is labelled 1. Empty / whitespace-only entries are skipped.
        """
        encoded_syllables = []
        for syllable in syllables:
            if len(syllable.strip()) > 0:
                encoded_syllables.extend([0] * (len(syllable.strip()) - 1) + [1])
        return encoded_syllables

    def pad_into_matrix(self, rows, padding_val=0):
        """Right-pad variable-length rows with ``padding_val`` into a 2-D array.

        Returns:
            (matrix, lengths): the padded array and an array with the
            original length of each row. ``rows`` must be non-empty.
        """
        matrix = []
        lengths = np.array(list(map(len, rows)))
        matrix_width = np.max(lengths)

        for row in rows:
            if len(row) < matrix_width:
                matrix.append(np.hstack((np.array(row), np.array([padding_val] * (matrix_width - len(row))))))
            else:
                matrix.append(np.array(row))
        matrix = np.vstack(matrix)
        return matrix, lengths

    def fit_data(self, filename):
        """Load and encode the tab-separated training file.

        The first line is a header and is skipped; blank lines are ignored.
        Every other line holds a word, a tab, and the word split into
        whitespace-separated syllables.

        Sets self.mapping (sorted list of all characters in the file),
        self.X (encoded words), self.lengths and self.y (per-character labels).
        """
        with open(filename) as fin:
            lines = fin.readlines()
        # Sorted, not raw set order: set iteration order of strings changes
        # between Python processes (hash randomization), which would give a
        # different char->index mapping than the one a restored checkpoint
        # was trained with.
        self.mapping = sorted(set(''.join(lines)))
        X = []
        y = []
        for line in lines[1:]:  # lines[0] is the csv header
            if not line.strip():
                continue  # e.g. a trailing empty line; it has no tab-separated fields
            tokens = re.split(r'\t', line)
            X.append(self.encode(tokens[0]))
            y.append(self.encode_syllables(re.split(r'\s+', tokens[1])))
        self.X, self.lengths = X, list(map(len, X))
        self.y = y

    def prepare_test_data(self, filename):
        """Encode a one-word-per-line file into self.X_test / self.test_lengths."""
        X = []
        with open(filename) as fin:
            for line in fin:
                X.append(self.encode(line))
        self.X_test, self.test_lengths = X, list(map(len, X))

    def decode_prediction(self, words, predicted_labels, lengths):
        """Turn encoded words plus per-character labels into readable syllables.

        Returns:
            A list of (word, [syllable, ...]) tuples. A label of 1 closes the
            current syllable; trailing characters after the last 1 form a
            final syllable. Positions beyond each word's length are ignored.
        """
        items = []
        for word, pred, length in zip(words, predicted_labels, lengths):
            text = self.decode(word[:length])
            syllables = []
            syllable = []
            for ch, label in zip(text, pred[:length]):
                if label == 0:
                    syllable.append(ch)
                elif label == 1:
                    syllable.append(ch)
                    syllables.append(''.join(syllable))
                    syllable = []
            if len(syllable) > 0:
                syllables.append(''.join(syllable))
            items.append((text, syllables))
        return items

    def get_batches(self, mode):
        """Yield (words, labels, lengths) batches of exactly ``batch_size`` rows.

        Args:
            mode: 'train' (first 90% of the training data), 'val' (remaining
                10%) or 'test' (words loaded by prepare_test_data).

        ``words`` and ``labels`` are zero-padded 2-D arrays and ``lengths``
        holds the true word lengths. Because the graph's batch dimension is
        fixed, a trailing incomplete batch is dropped for 'train' and 'val'.
        For 'test' it is padded up to ``batch_size`` by repeating the last
        word, so every test word is predicted; callers must ignore rows past
        len(self.X_test). Test labels are all-zero dummies of matching shape.

        Raises:
            ValueError: on an unknown mode (when the generator is first advanced).
        """
        if mode in ('train', 'val'):
            train_size = int(0.9 * len(self.X))
            if mode == 'train':
                X, y, lengths = self.X[:train_size], self.y[:train_size], self.lengths[:train_size]
            else:
                X, y, lengths = self.X[train_size:], self.y[train_size:], self.lengths[train_size:]
        elif mode == 'test':
            # Copies, so the padding below does not leak into self.X_test.
            X, lengths = list(self.X_test), list(self.test_lengths)
            y = [[0] * length for length in lengths]
            remainder = len(X) % self.batch_size
            if remainder:
                pad = self.batch_size - remainder
                X += [X[-1]] * pad
                y += [y[-1]] * pad
                lengths += [lengths[-1]] * pad
        else:
            raise ValueError('Unknown mode.')
        for start in range(0, len(X) - self.batch_size + 1, self.batch_size):
            stop = start + self.batch_size
            X_batch, lengths_batch = self.pad_into_matrix(X[start:stop], 0)
            y_batch, _ = self.pad_into_matrix(y[start:stop], 0)
            yield X_batch, y_batch, lengths_batch

    def accuracy(self, true_syllables, pred_syllables, seq_lengths):
        """Fraction of words whose labels match exactly over their true length.

        ``true_syllables`` and ``pred_syllables`` are 2-D (row, position)
        arrays; positions beyond seq_lengths[i] are ignored.
        """
        num_true_predictions = 0
        for i in range(len(true_syllables)):
            length = seq_lengths[i]
            true_syllable, pred_syllable = true_syllables[i, :length], pred_syllables[i, :length]
            if np.all(np.equal(true_syllable, pred_syllable)):
                num_true_predictions += 1
        return num_true_predictions / float(len(true_syllables))

    def construct_graph(self):
        """Build the graph: embedding -> (bi)RNN -> per-character 2-way softmax.

        Creates the placeholders (words, syllable_labels, seq_lengths), the
        masked cross-entropy loss, an Adam step, the thresholded prediction
        ops and a Saver. Requires fit_data (uses self.mapping).

        Raises:
            ValueError: on an unknown cell_type or net_type.
        """
        self.graph = tf.Graph()
        hidden_state_size = self.hidden_size
        if self.net_type == 'brnn':
            # forward and backward outputs are concatenated
            hidden_state_size *= 2
        with self.graph.as_default():
            self.words = tf.placeholder(tf.int32, shape=(self.batch_size, None), name='words')
            self.syllable_labels = tf.placeholder(tf.int32, shape=(self.batch_size, None), name='syllable_labels')
            self.seq_lengths = tf.placeholder(tf.int32, shape=(self.batch_size), name='lengths')
            W = tf.Variable(tf.truncated_normal([hidden_state_size, 2]), dtype=tf.float32)
            b = tf.Variable(np.zeros([2]), dtype=tf.float32)
            embedding_matrix = tf.Variable(tf.truncated_normal([len(self.mapping), self.hidden_size],
                                                               stddev=np.sqrt(2.0 / self.hidden_size)))
            embedding = tf.nn.embedding_lookup(embedding_matrix, self.words)
            # A constant, not a Variable: a Variable would be stored in the
            # checkpoint and silently override self.treshold on restore.
            treshold = tf.constant(self.treshold, dtype=tf.float32, name='treshold')
            # Fresh cell objects per layer and per direction; no cell instance is reused.
            multicell = self._make_multicell()
            if self.net_type == 'rnn':
                self.outputs, _ = tf.nn.dynamic_rnn(multicell, embedding, sequence_length=self.seq_lengths,
                                                    dtype=tf.float32, swap_memory=True)
            elif self.net_type == 'brnn':
                outputs_fw_bw, _ = tf.nn.bidirectional_dynamic_rnn(multicell, self._make_multicell(), embedding,
                                                                   sequence_length=self.seq_lengths,
                                                                   dtype=tf.float32, swap_memory=True)
                self.outputs = tf.concat(2, outputs_fw_bw)
            else:
                raise ValueError('Unknown net type.')
            outputs_reshape = tf.reshape(self.outputs, [-1, hidden_state_size])
            logits = tf.matmul(outputs_reshape, W) + b
            self.logits = tf.reshape(logits, [self.batch_size, -1, 2])
            probs = tf.nn.softmax(self.logits)
            # probability of label 1 only: shape (batch, time, 1)
            sliced_probs = tf.slice(probs, [0, 0, 1], [-1, -1, -1])
            greater = tf.greater(sliced_probs, treshold)
            # coordinates of positions predicted as syllable ends
            self.separation_indices = tf.where(greater)
            # all-zero template of the same shape; filled in numpy (see _mark_separations)
            self.prediction = tf.zeros_like(greater)
            mask = tf.sequence_mask(self.seq_lengths, tf.reduce_max(self.seq_lengths), dtype=tf.float32)
            cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(self.logits, self.syllable_labels)
            # mean cross-entropy over real (non-padding) characters only
            self.loss = tf.reduce_sum(cross_entropy * mask) / tf.reduce_sum(mask)
            self.optimizer = tf.train.AdamOptimizer().minimize(self.loss)
            self.saver = tf.train.Saver()

    def _make_cell(self):
        """Create one new RNN cell of the configured type."""
        if self.cell_type == 'lstm':
            return rnn_cell.LSTMCell(self.hidden_size)
        if self.cell_type == 'gru':
            return rnn_cell.GRUCell(self.hidden_size)
        raise ValueError('Unknown cell type.')

    def _make_multicell(self):
        """Stack ``num_layers`` distinct cells into a MultiRNNCell."""
        return rnn_cell.MultiRNNCell([self._make_cell() for _ in range(self.num_layers)])

    @staticmethod
    def _mark_separations(pred, indices):
        """Convert fetched graph outputs into a 2-D int32 label matrix.

        ``pred`` is the all-zero (batch, time, 1) array fetched from
        self.prediction; ``indices`` are the coordinate rows fetched from
        self.separation_indices. Those positions are set to 1 and the
        trailing singleton axis is dropped.
        """
        labels = np.array(pred, dtype=np.int32)
        labels[indices[:, 0], indices[:, 1], indices[:, 2]] = 1
        return labels.reshape((labels.shape[0], labels.shape[1]))

    def _feed(self, words_batch, labels_batch, lengths_batch):
        """Build the feed_dict for one batch."""
        return {self.words: words_batch, self.syllable_labels: labels_batch,
                self.seq_lengths: lengths_batch}

    def _train_epoch(self):
        """Run one optimisation pass over the training split; return the mean batch loss."""
        batch_losses = []
        for words_batch, syllable_labels_batch, lengths_batch in self.get_batches('train'):
            if words_batch.shape != syllable_labels_batch.shape:
                # word and label rows disagree in length somewhere in this batch
                prediction = self.decode_prediction(words_batch, syllable_labels_batch, lengths_batch)
                print("Bad train examples:")
                for word, syllables in prediction:
                    print('\t', word, ' '.join(syllables))
            batch_loss, _ = self.session.run([self.loss, self.optimizer],
                                             feed_dict=self._feed(words_batch, syllable_labels_batch,
                                                                  lengths_batch))
            batch_losses.append(batch_loss)
        return np.mean(batch_losses)

    def _validate(self, num_samples=3):
        """Evaluate on the validation split; print loss, accuracy and a few samples.

        Accuracy is averaged over all validation batches. Every batch has
        exactly batch_size rows, so this equals per-word accuracy over the
        evaluated words.
        """
        val_losses = []
        val_accuracies = []
        last_batch = None
        for words_batch, labels_batch, lengths_batch in self.get_batches('val'):
            pred, indices, val_loss = self.session.run([self.prediction, self.separation_indices, self.loss],
                                                       feed_dict=self._feed(words_batch, labels_batch,
                                                                            lengths_batch))
            pred = self._mark_separations(pred, indices)
            val_losses.append(val_loss)
            val_accuracies.append(self.accuracy(labels_batch, pred, lengths_batch))
            last_batch = (words_batch, labels_batch, lengths_batch, pred)
        if last_batch is None:
            print('Validation split is smaller than one batch; skipping validation.')
            return
        print('Validation loss: {}'.format(np.mean(val_losses)))
        print('Accuracy: {}'.format(np.mean(val_accuracies)))
        words_batch, labels_batch, lengths_batch, pred = last_batch
        sample_indices = np.random.choice(len(words_batch), min(num_samples, len(words_batch)), replace=False)
        prediction = self.decode_prediction(words_batch[sample_indices],
                                            pred[sample_indices], lengths_batch[sample_indices])
        true_values = self.decode_prediction(words_batch[sample_indices],
                                             labels_batch[sample_indices], lengths_batch[sample_indices])
        print("Sample predictions:")
        for (word, syllables), (_, true_syllables) in zip(prediction, true_values):
            print('\t', word, ' '.join(syllables), ' '.join(true_syllables))

    def run_session(self, checkpoints_dir):
        """Create a session, restore the latest checkpoint if any, and train.

        After every epoch the validation split is evaluated; every 10th epoch
        (starting with epoch 0) a checkpoint is written to checkpoints_dir.
        The session is kept in self.session.
        """
        with self.graph.as_default():
            self.session = tf.Session()
            self.session.run(tf.initialize_all_variables())
            print("Checking for checkpoints...")
            latest_checkpoint = tf.train.latest_checkpoint(checkpoints_dir)
            if latest_checkpoint is not None:
                print("Found checkpoints, using them.")
                self.saver.restore(self.session, latest_checkpoint)
            else:
                print("No checkpoints found, starting training from scratch.")
            for epoch in range(self.num_epochs):
                print("Starting epoch {}".format(epoch))
                start = time()
                train_loss = self._train_epoch()
                end = time()
                print('Epoch {} done. Loss: {}. Training took {} sec.'.format(epoch, train_loss,
                                                                              end - start))
                self._validate()
                if epoch % 10 == 0:
                    print("Saving model...")
                    save_path = self.saver.save(self.session, os.path.join(checkpoints_dir,
                                                "checkpoint" + datetime.now().strftime("_%d.%m.%y_%H:%M")))
                    print("Saved in " + save_path)

    def train(self, filename, checkpoints_dir):
        """Load data, build the graph and train; return the live session for sampling."""
        self.fit_data(filename)
        self.construct_graph()
        self.run_session(checkpoints_dir)
        return self.session

    def sample(self, session, filename, out_file='output.txt'):
        """Predict syllables for every word in ``filename`` (one word per line).

        Each result is printed and written to ``out_file`` as
        "<word> <syl1> <syl2> ...". The last batch is padded by get_batches;
        its padding rows are not printed or written.
        """
        self.prepare_test_data(filename)
        remaining = len(self.X_test)
        with open(out_file, 'w') as fout:
            with self.graph.as_default():
                print('Sampling...')
                for words_batch, syllable_labels_batch, lengths_batch in self.get_batches('test'):
                    pred, indices = session.run([self.prediction, self.separation_indices],
                                                feed_dict=self._feed(words_batch, syllable_labels_batch,
                                                                     lengths_batch))
                    pred = self._mark_separations(pred, indices)
                    prediction = self.decode_prediction(words_batch, pred, lengths_batch)
                    for word, syllables in prediction[:max(remaining, 0)]:
                        print(word, ' '.join(syllables))
                        fout.write(word + ' ' + ' '.join(syllables) + '\n')
                    remaining -= len(prediction)

In [4]:
# Train a GRU model (other hyper-parameters keep their constructor defaults);
# checkpoints are written to ./checkpoints and the live session is kept for sampling.
parser = SyllableParser(
    num_epochs=100,
    batch_size=200,
    cell_type='gru',
    treshold=0.5,
)
trained_sess = parser.train('normal_syllables.txt', 'checkpoints')

Checking for checkpoints...
No checkpoints found, starting training from scratch.
Starting epoch 0
Epoch 0 done. Loss: 0.2647058367729187. Training took 14.024769067764282 sec.
Validation loss: 0.12295219302177429
Accuracy: 0.645
Sample predictions:
	 прочнеть проч неть проч неть
	 примереть при ме реть при ме реть
	 депроприация де про п ри а ци я де про при а ци я
Saving model...
Saved in checkpoints/checkpoint_09.11.16_15:22
Starting epoch 1
Epoch 1 done. Loss: 0.09427566826343536. Training took 13.886988878250122 sec.
Validation loss: 0.17933526635169983
Accuracy: 0.56
Sample predictions:
	 филлипика фил ли пи ка фил ли пи ка
	 антипрусский ан ти п рус с кий ан ти прус ский
	 подюжить по дю жить по дю жить
Starting epoch 2
Epoch 2 done. Loss: 0.08523307740688324. Training took 13.929564476013184 sec.
Validation loss: 0.15520554780960083
Accuracy: 0.58
Sample predictions:
	 падали па да ли па да ли
	 госнии гос ни и госнии
	 патогенность па то ген ность па то ген ность
Starting epoc

In [5]:
# Tag every word of untitled.txt with the trained session; results also go to test_output.txt.
parser.sample(session=trained_sess, filename='untitled.txt', out_file='test_output.txt')

Sampling...
в в
израиль из раиль
дама дама
ус ус
хгапп хгапп
антк антк
ао ао
аарон а а рон
ессентуки ес сен ту ки
еэп еэп
ес ес
рф рф
дата дата
япония я по ни я
безымянный бе зы мян ный
дхм дхм
карандаш карандаш
алфёров ал фё ров
папа па па
эвм эвм
он он
сестра сест ра
и и
ад ад
адаптация а дап та ци я
антигистаминный ан ти ги с та мин ный
барабан ба ра бан
белка бел ка
визуализация ви зу а ли за ци я
гм гм
демократ де мо крат
диез ди ез
дилетант ди ле тант
досуг до суг
искусно ис ку сно
какао ка ка о
медиевистика ме ди е вис ти ка
аргон аргон
платина пла ти на
например на при мер
паркетник пар кет ник
пиньинь пинь инь
полиция по ли ци я
дупа ду па
язик язик
сладенький сла день кий
тело те ло
улучшение у луч ше ни е
утешение у те ше ни е
ученье у чень е
последовавший по сле до вав ший
фламандский фла манд ский
экзекуция эк зе ку ци я
в в
спид спид
алгебра алгеб ра
двоемыслие дво е мыс ли е
г г
кенотрон ке но трон
пеночка-теньковка пе ноч ка- тень ков ка
таннид тан нид
литера ли те ра
ё