# Authorship Style Transfer

In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf

---

## Data Preprocessing

In [None]:
# Input files, given relative to the notebook's working directory.
# The article file is iterated line by line further down, so each line is
# treated as one text.
# NOTE(review): presumably the label file is aligned line-for-line with the
# article file (one label per article) — confirm against the data.
text_file_path = "data/c50-articles.txt"
label_file_path = "data/c50-labels.txt"

### Conversion of texts into integer sequences

In [None]:
# Number of token ids kept per text: sequences are padded / truncated to this
# length and the model's input placeholder uses it as its second dimension.
MAX_SEQUENCE_LENGTH = 100
# Width of each row of the learned word-embedding matrix.
EMBEDDING_SIZE = 300

In [None]:
# Vocabulary cap for the text tokenizer: only the most frequent
# VOCABULARY_SIZE - 1 words are kept, so emitted ids are always < VOCABULARY_SIZE
# (id 0 is reserved and doubles as the padding value below).
VOCABULARY_SIZE = 1000

text_tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=VOCABULARY_SIZE)

# Read the corpus once and reuse the lines for both fitting and conversion;
# every line of the file is treated as one text.
with open(text_file_path) as text_file:
    text_lines = list(text_file)

text_tokenizer.fit_on_texts(text_lines)
integer_text_sequences = text_tokenizer.texts_to_sequences(text_lines)

# A bare expression in the middle of a cell is never displayed, so print it.
print("number of texts:", len(integer_text_sequences))

# Pad and truncate at the end so every row has exactly MAX_SEQUENCE_LENGTH ids.
padded_sequences = tf.keras.preprocessing.sequence.pad_sequences(
    integer_text_sequences, maxlen=MAX_SEQUENCE_LENGTH, padding='post', truncating='post')

padded_sequences.shape

### Conversion of labels to one-hot representations

In [None]:
# lower=False keeps label strings case-sensitive.
label_tokenizer = tf.keras.preprocessing.text.Tokenizer(lower=False)

# Read the label file once; every line is treated as one label text.
with open(label_file_path) as label_file:
    label_lines = list(label_file)

label_tokenizer.fit_on_texts(label_lines)
label_sequences = label_tokenizer.texts_to_sequences(label_lines)

# Tokenizer ids start at 1, so label id i is mapped to one-hot column i - 1.
# (Using the raw id as the column left column 0 unused and gave the label with
# the highest id an all-zero vector.)
num_labels = len(label_tokenizer.word_index)
label_indices = np.array(
    [sequence[0] for sequence in label_sequences], dtype=int)

# Row-indexing a single identity matrix yields one one-hot row per example;
# kept as a list of 1-D arrays, as before.
one_hot_labels = list(np.eye(num_labels)[label_indices - 1])

---

## Deep Learning Model

### Setup Instructions

In [None]:
with tf.device("/gpu:0"):

    # needed to clear the existing graph when the cell is re-run
    # NOTE(review): the device scope above is entered before this reset, i.e.
    # presumably on the previous default graph, so the ops created below may
    # not actually be pinned to /gpu:0 — confirm and, if pinning is wanted,
    # move the reset above the `with` (and allow soft placement in the session).
    tf.reset_default_graph()

    def get_sentence_representation(index_sequence, word_embeddings):
        """Encode a batch of token-id sequences into one vector per sequence.

        Args:
            index_sequence: integer tensor of token ids to look up.
            word_embeddings: embedding matrix the ids index into.

        Returns:
            The final hidden states of the forward and backward LSTMs,
            concatenated along axis 1.
        """
        # dense embedded sequence (uses the argument, not the global placeholder)
        embedded_sequence = tf.nn.embedding_lookup(
            word_embeddings, index_sequence, name="embedded_sequence")

        lstm_cell_fw = tf.contrib.rnn.BasicLSTMCell(num_units=128, name="lstm_cell_fw_content")
        lstm_cell_bw = tf.contrib.rnn.BasicLSTMCell(num_units=128, name="lstm_cell_bw_content")

        rnn_outputs, rnn_states = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=lstm_cell_fw, cell_bw=lstm_cell_bw, inputs=embedded_sequence,
            dtype=tf.float32, time_major=False)
        # rnn_states is (forward_state, backward_state); keep only the .h parts
        rnn_state = tf.concat([rnn_states[0].h, rnn_states[1].h], axis=1)

        return rnn_state

    def get_content_representation(sentence_representation):
        """Project the sentence vector through a 128-unit ReLU layer ("dense_content")."""
        dense_content = tf.layers.dense(
            inputs=sentence_representation, units=128,
            activation=tf.nn.relu, name="dense_content")

        return dense_content

    def get_style_representation(sentence_representation):
        """Project the sentence vector through a 128-unit ReLU layer ("dense_style")."""
        dense_style = tf.layers.dense(
            inputs=sentence_representation, units=128,
            activation=tf.nn.relu, name="dense_style")

        return dense_style

    def get_label_prediction(content_representation, return_logits=False):
        """Predict a label distribution from the content representation.

        Args:
            content_representation: tensor fed to the classification layer.
            return_logits: when True, also return the pre-softmax logits so the
                loss can be computed from them.

        Returns:
            The softmax output, or ``(logits, softmax_output)`` when
            ``return_logits`` is True.
        """
        # Linear layer: the cross-entropy loss applies softmax itself and
        # expects unscaled logits, so no activation is applied here.
        logits = tf.layers.dense(
            inputs=content_representation, units=len(label_tokenizer.word_index),
            activation=None, name="dense_1")

        softmax_output = tf.nn.softmax(logits, name="softmax")

        if return_logits:
            return logits, softmax_output
        return softmax_output


    # input variable - text sequence converted to an index sequence
    input_sequence = tf.placeholder(
        tf.int32, [None, MAX_SEQUENCE_LENGTH], name="input_sequence")
    print("input_sequence: ", input_sequence)

    input_label = tf.placeholder(
        tf.float32, [None, len(label_tokenizer.word_index)], name="input_label")
    print("input_label: ", input_label)

    # The embedding table must cover every id the *text* tokenizer can emit:
    # ids are < num_words when a cap is set, otherwise 1..len(word_index)
    # (plus 0 for padding).
    vocabulary_size = text_tokenizer.num_words or len(text_tokenizer.word_index) + 1

    # learn embeddings matrix - can be initialized with pre-trained embeddings
    word_embeddings = tf.get_variable(
        shape=[vocabulary_size, EMBEDDING_SIZE], name="word_embeddings",
        dtype=tf.float32)
    print("word_embeddings: ", word_embeddings)

    # get sentence representation
    sentence_representation = get_sentence_representation(input_sequence, word_embeddings)
    print("sentence_representation:", sentence_representation)

    # get content representation
    content_representation = get_content_representation(sentence_representation)
    print("content_representation:", content_representation)

    # get style representation
    style_representation = get_style_representation(sentence_representation)
    print("style_representation:", style_representation)

    # use content representation to predict a label
    label_logits, label_prediction = get_label_prediction(
        content_representation, return_logits=True)
    print("label_prediction:", label_prediction)

    # the loss takes the raw logits; passing the softmax output would apply
    # softmax twice and flatten the gradients
    adversarial_loss = tf.losses.softmax_cross_entropy(
        onehot_labels=input_label, logits=label_logits)

    adversarial_loss_summary = tf.summary.scalar(
        "adversarial_loss", adversarial_loss)

    adversarial_optimizer = tf.train.AdamOptimizer()
    adversarial_training_operation = adversarial_optimizer.minimize(adversarial_loss)

### Train Network

In [None]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    epoch_reporting_interval = 1
    training_examples_fraction = 0.9
    # NOTE(review): the train/test split is positional (first 90% vs. the rest)
    # and nothing is shuffled; if the corpus file is grouped by author the
    # held-out tail may not be representative — confirm the file ordering.
    training_examples_size = int(training_examples_fraction * len(one_hot_labels))
    batch_size = 100
    training_epochs = 50
    # Ceiling division so the trailing partial batch is trained on as well
    # instead of being silently dropped every epoch.
    num_batches = (training_examples_size + batch_size - 1) // batch_size

    writer = tf.summary.FileWriter(logdir="tensorflow_logs")
    try:
        training_step = 1
        for current_epoch in range(1, training_epochs + 1):
            # stays None if there is nothing to train on, so the report below
            # cannot hit an unbound name
            loss_var = None
            for batch_number in range(num_batches):
                batch_start = batch_number * batch_size
                # never read past the training split into the held-out examples
                batch_end = min(batch_start + batch_size, training_examples_size)
                _, loss_var, loss_summary_var = sess.run(
                    [adversarial_training_operation, adversarial_loss, adversarial_loss_summary],
                    feed_dict={
                        input_sequence: padded_sequences[batch_start:batch_end],
                        input_label: one_hot_labels[batch_start:batch_end]})
                writer.add_summary(loss_summary_var, training_step)
                training_step += 1

            # flush once per epoch rather than forcing a disk write every step
            writer.flush()

            if (current_epoch % epoch_reporting_interval == 0):
                # reports the loss of the last batch of the epoch
                print("Training epoch: {}; Loss:{}".format(current_epoch, loss_var))
    finally:
        # release the event file even when training is interrupted or fails
        writer.close()

    # Predictions only need the input sequences; the label placeholder is not
    # part of the prediction sub-graph, so it is not fed here.
    training_predictions = sess.run(
        label_prediction,
        feed_dict={input_sequence: padded_sequences[:training_examples_size]})

    test_predictions = sess.run(
        label_prediction,
        feed_dict={input_sequence: padded_sequences[training_examples_size:]})