# GAN for Supervised Formality Transfer
This was an exceptionally dumb attempt to use a GAN. I thought I 

In [1]:
import numpy as np
import tensorflow as tf

import re 
import os
from datetime import datetime

from nltk.translate.bleu_score import sentence_bleu
from nltk.tokenize import TweetTokenizer

from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

import workflow_manager as wm

## Load Data

In [2]:
EMBEDDING_DIM = 50
ENCODER_UNITS = 512
DECODER_UNITS = 512
ATTENTION_UNITS = 256

### Seq2Seq Data

In [3]:
BASE_PATH = '../Data'
train, val, test, context = wm.load_and_tokenize(BASE_PATH)

E_weights = wm.embedding_matrix(context['input_tokenizer'], 
                                context['input_vocab_size'], 
                                BASE_PATH)
DE_weights = wm.embedding_matrix(context['target_tokenizer'],
                                 context['target_vocab_size'],
                                 BASE_PATH)

## Get Models 

In [4]:
# encoder = wm.Encoder(context['input_vocab_size'], EMBEDDING_DIM,
#                      ENCODER_UNITS, E_weights)
# decoder = wm.Decoder(context['target_vocab_size'], EMBEDDING_DIM,
#                      ATTENTION_UNITS, DECODER_UNITS)
RNN = wm.AttentionalEncoderDecoder(context['input_vocab_size'], context['target_vocab_size'],
                                   EMBEDDING_DIM, ATTENTION_UNITS, DECODER_UNITS, ENCODER_UNITS,
                                   E_weights)
discriminator = wm.Discriminator(context['target_vocab_size'], EMBEDDING_DIM, DE_weights)

## Define Optimizers and Loss Functions
### Discriminator
The loss function for the discriminator is calculated on how well it can discern informal and formal outputs
from the generator. 

In [22]:
cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
static_loss = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

In [6]:
def discriminator_loss_func(reference, generated):
    reference_loss = cross_entropy(tf.ones_like(reference), reference)
    generated_loss = cross_entropy(tf.zeros_like(generated), generated)
    return reference_loss + generated_loss

In [23]:
example_input_batch, example_target_batch = next(iter(train))

In [26]:
init_state = [tf.zeros((BATCH_SIZE, ENCODER_UNITS)) for _ in range(4)]

# Get start token for every sequence in batch
dec_input = tf.expand_dims([context['target_tokenizer'].word_index['<start>']] * BATCH_SIZE, 1)

predictions = RNN(example_input_batch, init_state, dec_input)

In [None]:
cross_entropy()

### Generator
This loss only applies to how the BA net did in its efforts to trick the discriminator

In [7]:
def generator_loss_func(generator_results):
    return cross_entropy(tf.ones_like(generator_results), generator_results)

### Optimizers

In [8]:
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

## Training Loop
This learns a sequence and then goes through usual GAN paradigm

In [9]:
BATCH_SIZE = 32

In [10]:
def rnn_loss_function(real, preds):
    """this is normal seq2seq loss"""

    # caclulate loss
    loss = static_loss(real, preds)
    
    # create padding mask 
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    mask = tf.cast(mask, dtype=loss.dtype)
    
    # apply mask
    loss *= mask

    return tf.reduce_mean(loss)

### Predict Funciton

In [11]:
def predict(inpt, trgt, train=True):
    target_tokenizer = context['target_tokenizer']
    
    # initialize seqs tensor
    gen_seqs = tf.constant([target_tokenizer.word_index['<start>']] * BATCH_SIZE, dtype=tf.int64)
    gen_seqs = tf.expand_dims(gen_seqs, axis=1)
    
    # This resets the hidden state of the LSTM for every epoch
    init_state = [tf.zeros((BATCH_SIZE, ENCODER_UNITS)) for _ in range(4)]
    
    ## Generate Sequences
    enc_output, dec_hidden_forward, dec_hidden_backward = encoder(inpt, init_state)

    # Get start token for every sequence in batch
    dec_input = tf.expand_dims([target_tokenizer.word_index['<start>']] * BATCH_SIZE, 1)

    for i in range(1, trgt.shape[1]):
        # dec_hidden shape: (batch_size, decoder_units)
        # dec_input shape: (batch_size, 1)
        predictions, dec_hidden_forward, _ = decoder(dec_input, 
                                                     dec_hidden_forward, 
                                                     dec_hidden_backward, 
                                                     enc_output)

        loss += loss_function(trgt[:, i], predictions)
        dec_input = tf.expand_dims(trgt[:, i], 1)
        
        # Need to hold onto seqs for discriminator
        new_preds = tf.argmax(predictions, axis=1)
        new_preds = tf.expand_dims(new_preds, axis=1)
        gen_seqs = tf.concat([gen_seqs, new_preds], axis=1)
    
    if not train:
        reference_results = discriminator(trgt, True)
        generated_results = discriminator(seqs, True)
        
        # compute losses
        gen_loss = generator_loss(generated_results)
        disc_loss = discriminator_loss(reference_results, generated_results)
        
        return gen_loss, disc_loss
        
    return gen_seqs

In [12]:
def combined_predict(inpt, target, train=True):
    loss = 0
    
    target_tokenizer = context['target_tokenizer']
    
    # initialize seqs tensor
    gen_seqs = tf.constant([target_tokenizer.word_index['<start>']] * BATCH_SIZE, dtype=tf.int64)
    gen_seqs = tf.expand_dims(gen_seqs, axis=1)

    # This resets the hidden state of the LSTM for every epoch
    init_state = [tf.zeros((BATCH_SIZE, ENCODER_UNITS)) for _ in range(4)]

    # Get start token for every sequence in batch
    dec_input = tf.expand_dims([target_tokenizer.word_index['<start>']] * BATCH_SIZE, 1)

    decoder_hidden_forward = None

    for i in range(1, trgt.shape[1]):
        # dec_hidden shape: (batch_size, decoder_units)
        # dec_input shape: (batch_size, 1)
        predictions, decoder_hidden_forward = RNN(inpt, init_state, dec_input, 
                                                  decoder_hidden_forward)

        loss += rnn_loss_function(trgt[:, i], predictions)

        # Need to hold onto seqs for discriminator
        new_preds = tf.argmax(predictions, axis=1)
        new_preds = tf.expand_dims(new_preds, axis=1)
        gen_seqs = tf.concat([gen_seqs, new_preds], axis=1)

        dec_input = new_preds
    
    if not train:
        return gen_seqs
        
    return gen_seqs, loss

In [21]:
test_seqs = []
for inpt, trgt in test.take(context['steps_per_epoch']):
    try:
        test_seqs.append(combined_predict(inpt, trgt, False))
    except InvalidArgumentError:
        pass

NameError: name 'InvalidArgumentError' is not defined

In [None]:
len(test)

In [18]:
context['target_tokenizer'].sequences_to_texts(test_seqs.numpy())

['<start> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end>',
 '<start> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end>',
 '<start> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end>',
 '<start> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end>',
 '<start> <end> 

### Training Step

In [15]:
@tf.function
def train_step(inpt, trgt):
    with tf.GradientTape() as gtape, tf.GradientTape() as dtape:
        gen_seqs, loss = combined_predict(inpt, trgt)
        
        ## discriminator
        reference_results = discriminator(trgt, True)
        generated_results = discriminator(gen_seqs, True)
        
        # compute losses
        gen_loss = 0.5 * generator_loss_func(generated_results) + loss
        disc_loss = discriminator_loss_func(reference_results, generated_results)
    
    # gradients
    discriminator_gradients = dtape.gradient(disc_loss, discriminator.trainable_variables)
    generator_gradients = gtape.gradient(gen_loss, RNN.trainable_variables)

    
    # apply gradients
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, 
                                                discriminator.trainable_variables))
    generator_optimizer.apply_gradients(zip(generator_gradients, RNN.trainable_variables))
    
    return gen_loss, disc_loss

### Training Loop

In [16]:
EPOCHS = 50

for epoch in range(EPOCHS):
    start = datetime.now()

    generator_loss = 0
    discriminator_loss = 0
    

    # This resets the hidden state of the LSTM for every epoch
    init_state = [tf.zeros((BATCH_SIZE, ENCODER_UNITS)) for _ in range(4)]

    for inpt, trgt in train.take(context['steps_per_epoch']):
        batch_gen_loss, batch_disc_loss = train_step(inpt, trgt)
        generator_loss += batch_gen_loss
        discriminator_loss += batch_disc_loss
    
    epoch_print = 'Epoch {} | Generator Loss {:.4f} | Discriminator Loss {:.4f}'
    
    print(epoch_print.format(epoch + 1, generator_loss / BATCH_SIZE, 
                             discriminator_loss / BATCH_SIZE))

    print('Time taken {}\n'.format(datetime.now() - start))

Epoch 1 | Generator Loss 753.9606 | Discriminator Loss 0.8977
Time taken 0:10:54.582248

Epoch 2 | Generator Loss 679.5986 | Discriminator Loss 0.0210
Time taken 0:08:54.193252

Epoch 3 | Generator Loss 686.2482 | Discriminator Loss 0.0328
Time taken 0:08:53.780044

Epoch 4 | Generator Loss 691.3499 | Discriminator Loss 0.0121
Time taken 0:08:54.628292

Epoch 5 | Generator Loss 692.3023 | Discriminator Loss 0.0086
Time taken 0:08:54.766416



KeyboardInterrupt: 