# GAN for Supervised Formality Transfer
This was an exceptionally dumb attempt to use a GAN. I thought I 

In [1]:
import numpy as np
import tensorflow as tf

import re 
import os
from datetime import datetime

from nltk.translate.bleu_score import sentence_bleu
from nltk.tokenize import TweetTokenizer

from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

import workflow_manager as wm

## Load Data

In [2]:
EMBEDDING_DIM = 50
ENCODER_UNITS = 512
DECODER_UNITS = 512
ATTENTION_UNITS = 256
GENERATOR_UNITS = 1024
BATCH_SIZE = 32

### Seq2Seq Data

In [3]:
BASE_PATH = '../Data'
train, val, test, context = wm.load_and_tokenize(BASE_PATH)

---
quick sanity check

In [4]:
import pickle

with open('GAN Seq model weights/input_tokenizer.pickle', 'rb') as handle:
    temp = pickle.load(handle)

with open('GAN Seq model weights/target_tokenizer.pickle', 'rb') as handle:
    other_temp = pickle.load(handle)
    
assert context['input_tokenizer'].word_index == temp.word_index
assert context['target_tokenizer'].word_index == other_temp.word_index

---

## Get Models 

In [5]:
E_weights = wm.embedding_matrix(context['input_tokenizer'], 
                                context['input_vocab_size'], 
                                BASE_PATH)
DE_weights = wm.embedding_matrix(context['target_tokenizer'],
                                 context['target_vocab_size'],
                                 BASE_PATH)

In [6]:
encoder = wm.Encoder(context['input_vocab_size'], EMBEDDING_DIM,
                     ENCODER_UNITS, E_weights)
decoder = wm.Decoder(context['target_vocab_size'], EMBEDDING_DIM,
                     ATTENTION_UNITS, DECODER_UNITS)
generator = wm.Generator(GENERATOR_UNITS, context['input_vocab_size'], 
                         EMBEDDING_DIM)
discriminator = wm.Discriminator(context['target_vocab_size'], EMBEDDING_DIM, DE_weights)

In [7]:
encoder.load_weights('GAN Seq model weights/encoder/encoder')
decoder.load_weights('GAN Seq model weights/decoder/decoder')

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f7408d80320>

## Define Optimizers and Loss Functions
### Discriminator
The loss function for the discriminator is calculated on how well it can discern informal and formal outputs
from the generator. 

In [8]:
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction='none')

In [9]:
def discriminator_loss_func(reference, generated):
    reference_loss = cross_entropy(tf.ones_like(reference) * 0.8, reference)
    generated_loss = cross_entropy(tf.zeros_like(generated) * 0.8, generated)
    return tf.reduce_mean(reference_loss + generated_loss)

### Generator
This loss only applies to how the BA net did in its efforts to trick the discriminator

Example of Generator

In [10]:
def generator_loss_func(generator_results):
    return tf.reduce_mean(cross_entropy(tf.ones_like(generator_results) * 0.8, generator_results))

### Optimizers

In [11]:
generator_optimizer = tf.keras.optimizers.Adam(1e-3)
discriminator_optimizer = tf.keras.optimizers.Adam(5e-4)

## Training Loop
This learns a sequence and then goes through usual GAN paradigm

In [12]:
BATCH_SIZE = 32

## Training Step

In [13]:
@tf.function
def train_step(inpt, trgt):
    # This resets the hidden state of the LSTM for every epoch
    init_state = [tf.zeros((BATCH_SIZE, ENCODER_UNITS)) for _ in range(4)]

    ## Get outputs
    gen_input = gen_input = tf.round(
        tf.random.uniform(
            [BATCH_SIZE, inpt.shape[1]], 
            minval=1,
            maxval=len(context['input_tokenizer'].word_index)
        )
    )
    
    
    with tf.GradientTape() as gtape, tf.GradientTape() as dtape:
        # get outputs
        gen_output = generator(gen_input)
        enc_output, _, _ = encoder(inpt, init_state)
        
        ## test discriminator
        reference_results = discriminator(enc_output, True)
        generated_results = discriminator(gen_output, True)
        
        ## add some noise
        reference_results = reference_results  #+ tf.random.normal([BATCH_SIZE, 1])
        generated_results = generated_results  # + tf.random.normal([BATCH_SIZE, 1])
        

        # compute losses
        gen_loss = generator_loss_func(generated_results)
        disc_loss = discriminator_loss_func(reference_results, generated_results)
    
    # gradients
    discriminator_gradients = dtape.gradient(disc_loss, discriminator.trainable_variables)
    generator_gradients = gtape.gradient(gen_loss, generator.trainable_variables)

    
    # apply gradients
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, 
                                                discriminator.trainable_variables))
    generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
    
    return gen_loss, disc_loss

### Training Loop

In [14]:
EPOCHS = 5

for epoch in range(EPOCHS):
    start = datetime.now()

    generator_loss = 0
    discriminator_loss = 0
    

    # This resets the hidden state of the LSTM for every epoch
    init_state = [tf.zeros((BATCH_SIZE, ENCODER_UNITS)) for _ in range(4)]

    for inpt, trgt in train.take(context['steps_per_epoch']):
        batch_gen_loss, batch_disc_loss = train_step(inpt, trgt)
        generator_loss += batch_gen_loss
        discriminator_loss += batch_disc_loss
    
    epoch_print = 'Epoch {} | Generator Loss {:.4f} | Discriminator Loss {:.4f}'
    
    print(epoch_print.format(epoch + 1, tf.reduce_mean(generator_loss).numpy() / BATCH_SIZE, 
                             tf.reduce_mean(discriminator_loss).numpy() / BATCH_SIZE))
    
    print('Time taken {}\n'.format(datetime.now() - start))

Epoch 1 | Generator Loss 243.5871 | Discriminator Loss 14.7101
Time taken 0:05:59.076012

Epoch 2 | Generator Loss 221.3409 | Discriminator Loss 13.0750
Time taken 0:05:40.775138

Epoch 3 | Generator Loss 257.6608 | Discriminator Loss 13.5526
Time taken 0:05:40.813919

Epoch 4 | Generator Loss 311.0698 | Discriminator Loss 12.9440
Time taken 0:05:40.781125

Epoch 5 | Generator Loss 330.0643 | Discriminator Loss 12.5536
Time taken 0:05:40.723267



In [15]:
def predict(inpt, trgt, train=True):
    loss = 0
    target_tokenizer = context['target_tokenizer']
    
    # initialize seqs tensor
    gen_seqs = tf.constant([target_tokenizer.word_index['<start>']] * BATCH_SIZE, dtype=tf.int64)
    gen_seqs = tf.expand_dims(gen_seqs, axis=1)
    
    # This resets the hidden state of the LSTM for every epoch
    init_state = [tf.zeros((BATCH_SIZE, ENCODER_UNITS)) for _ in range(4)]
    
    gen_input = gen_input = tf.round(
        tf.random.uniform(
            [BATCH_SIZE, inpt.shape[1]], 
            minval=1,
            maxval=len(context['input_tokenizer'].word_index)
        )
    )

    ## Generate Sequences
    _, h_f, h_b = encoder(inpt, init_state)
    enc_output = generator(gen_input)

    # Get start token for every sequence in batch
    dec_input = tf.expand_dims([target_tokenizer.word_index['<start>']] * BATCH_SIZE, 1)

    for i in range(1, trgt.shape[1]):
        # dec_hidden shape: (batch_size, decoder_units)
        # dec_input shape: (batch_size, 1)
        predictions, h_f = decoder(dec_input, h_b, h_f, enc_output)
        
        # Need to hold onto seqs for discriminator
        new_preds = tf.argmax(predictions, axis=1)
        new_preds = tf.expand_dims(new_preds, axis=1)
        gen_seqs = tf.concat([gen_seqs, new_preds], axis=1)
        

    if not train:
        translated = target_tokenizer.sequences_to_texts(gen_seqs.numpy())
        return translated
        
    return gen_seqs, loss

In [16]:
x, y = next(iter(train))

In [17]:
temp = predict(x, y, False)

In [22]:
temp[:5]

['<start> everyone like everyone dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb',
 '<start> everyone like everyone dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb',
 '<start> <OOV> like everyone dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb',
 '<start> <OOV> like everyone dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb',
 '<start> <OOV> like everyone dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb d

In [23]:
temp[-5:]

['<start> <OOV> like everyone dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb',
 '<start> <OOV> like everyone dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb',
 '<start> <OOV> like everyone dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb',
 '<start> <OOV> like everyone dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb',
 '<start> <OOV> like everyone dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb dumb du

We can see this GAN has suffered from mode collapse.