# GAN for Supervised Formality Transfer

In [1]:
import numpy as np
import tensorflow as tf

import re 
import os
from datetime import datetime

from nltk.translate.bleu_score import sentence_bleu
from nltk.tokenize import TweetTokenizer

from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

import workflow_manager as wm

## Load Data

In [2]:
# Network sizes used when building the encoder, decoder and discriminator below.
EMBEDDING_DIM = 200      # word-embedding width (passed to all three models)
ENCODER_UNITS = 1024     # encoder hidden-state size (also sizes its zero initial states)
DECODER_UNITS = 1024     # decoder hidden-state size, passed to wm.Decoder
ATTENTION_UNITS = 512    # attention size, passed to wm.Decoder

### Seq2Seq Data

In [3]:
BASE_PATH = '../Data'

# Train/val/test splits (train and val are iterated as (input, target) batches
# further down) plus a `context` dict holding the tokenizers, vocab sizes and
# steps_per_epoch.
train, val, test, context = wm.load_and_tokenize(BASE_PATH)

# One embedding matrix per vocabulary, built in order: input side first, then
# target side. Presumably filled from pretrained vectors under BASE_PATH —
# see wm.embedding_matrix to confirm.
E_weights, DE_weights = (
    wm.embedding_matrix(context[tokenizer_key], context[vocab_size_key], BASE_PATH)
    for tokenizer_key, vocab_size_key in (
        ('input_tokenizer', 'input_vocab_size'),
        ('target_tokenizer', 'target_vocab_size'),
    )
)

## Get Models 

In [4]:
# Encoder over the input vocabulary, initialised with the input-side embedding matrix.
encoder = wm.Encoder(
    context['input_vocab_size'],
    EMBEDDING_DIM,
    ENCODER_UNITS,
    E_weights,
)

# Attention decoder over the target vocabulary.
# NOTE(review): it receives no embedding weights; DE_weights only goes to the
# discriminator — confirm this is intentional.
decoder = wm.Decoder(
    context['target_vocab_size'],
    EMBEDDING_DIM,
    ATTENTION_UNITS,
    DECODER_UNITS,
)

# Discriminator over target-vocabulary token sequences, using the target-side embeddings.
discriminator = wm.Discriminator(
    context['target_vocab_size'],
    EMBEDDING_DIM,
    DE_weights,
)

## Define Optimizers and Loss Functions
### Discriminator
The discriminator loss measures how well it tells the reference target sentences (the formal side of each pair) apart from
sequences produced by the generator: references are labelled real (1) and generated sequences fake (0).

In [5]:
# Shared binary cross-entropy; from_logits=True means predictions are treated as
# unnormalised logits (a sigmoid is applied inside the loss).
cross_entropy = tf.keras.losses.BinaryCrossentropy(
    from_logits=True,
)

In [6]:
def discriminator_loss(reference, generated):
    """Standard GAN discriminator loss.

    Args:
        reference: discriminator logits for reference (real) sequences.
        generated: discriminator logits for generator outputs (fake).

    Returns:
        Sum of the BCE that pushes reference logits towards label 1 and the BCE
        that pushes generated logits towards label 0.
    """
    real_term = cross_entropy(tf.ones_like(reference), reference)
    fake_term = cross_entropy(tf.zeros_like(generated), generated)
    return real_term + fake_term

### Generator
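The generator loss depends only on how well the generator (the attention-based encoder–decoder, the "BA net") fools the discriminator: its outputs are scored against the "real" label (1).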

In [7]:
def generator_loss(discriminator_results):
    """Generator loss: how far the discriminator is from calling the fakes real.

    Computes binary cross-entropy between the discriminator's logits on
    generated sequences and an all-ones ("real") target, so the loss is lower
    the more the generator fools the discriminator.

    Args:
        discriminator_results: discriminator logits for generated sequences.

    Returns:
        The loss returned by ``cross_entropy``.
    """
    # Uses the module-level ``cross_entropy`` defined above (the previous code
    # referenced a non-existent ``crossentropy`` and raised NameError).
    return cross_entropy(tf.ones_like(discriminator_results), discriminator_results)

### Optimizers

In [8]:
# Independent Adam optimizers (default hyper-parameters), one per network, so
# each keeps its own moment estimates.
discriminator_optimizer = tf.keras.optimizers.Adam()
generator_optimizer = tf.keras.optimizers.Adam()

## Training Loop
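Each step generates a batch of sequences with the encoder–decoder and then follows the usual GAN paradigm: update the discriminator to separate reference from generated sequences, and update the generator to fool it.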

In [9]:
# Sequences per batch; assumed to match the batching done by
# wm.load_and_tokenize — TODO confirm.
BATCH_SIZE = 64

In [10]:
# predict() below converts tensors with `.numpy()`, which only works when
# TensorFlow runs eagerly; fail early with a clear message otherwise.
assert tf.executing_eagerly(), 'predict() requires TensorFlow eager execution'

### Predict Function

In [11]:
def predict(inpt, trgt, train=True):
    """Greedily decode a batch with the encoder/decoder, optionally scoring it.

    Args:
        inpt: batch of tokenized source sequences, fed to ``encoder``.
        trgt: batch of tokenized reference sequences. Its second dimension sets
            the decode length; the sequences themselves are only used when
            ``train`` is False.
        train: when True (default), return the generated sequences. When False,
            run the discriminator on ``trgt`` and on the generated sequences and
            return ``(gen_loss, disc_loss)`` instead.

    Returns:
        A NumPy array of token ids of shape ``(batch, trgt.shape[1])`` whose
        first column is the ``<start>`` id, or the ``(gen_loss, disc_loss)``
        pair when ``train`` is False.

    NOTE(review): tokens are chosen with ``argmax`` and converted to NumPy, so
    no gradient flows from the discriminator back into the encoder/decoder.
    """
    target_tokenizer = context['target_tokenizer']
    start_id = target_tokenizer.word_index['<start>']

    # Size everything from the actual batch so a smaller final batch does not
    # mismatch states sized with the global BATCH_SIZE.
    batch_size = int(inpt.shape[0])

    # Fresh zero initial states on every call. The encoder takes four of them
    # (presumably h/c for each direction of a bidirectional LSTM — TODO confirm).
    init_state = [tf.zeros((batch_size, ENCODER_UNITS)) for _ in range(4)]
    enc_output, dec_hidden_forward, dec_hidden_backward = encoder(inpt, init_state)

    # Every sequence starts with <start>; dec_input shape: (batch_size, 1).
    dec_input = tf.fill([batch_size, 1], start_id)
    columns = [np.full(batch_size, start_id, dtype=np.int32)]

    for _step in range(1, trgt.shape[1]):
        # dec_hidden_forward shape: (batch_size, decoder_units)
        # NOTE(review): dec_hidden_backward is passed unchanged at every step;
        # confirm against wm.Decoder's contract.
        predictions, dec_hidden_forward, _ = decoder(dec_input,
                                                     dec_hidden_forward,
                                                     dec_hidden_backward,
                                                     enc_output)

        next_ids = tf.argmax(predictions, axis=1)
        # Keep the generated tokens for the discriminator.
        columns.append(next_ids.numpy())
        # Greedy decoding: feed this step's prediction back in. Previously the
        # decoder was given <start> at every step.
        dec_input = tf.cast(tf.expand_dims(next_ids, 1), dec_input.dtype)

    gen_seqs = np.column_stack(columns)

    if not train:
        # NOTE(review): the discriminator is still called with True here; if that
        # argument is Keras' `training` flag, validation should pass False.
        reference_results = discriminator(trgt, True)
        generated_results = discriminator(gen_seqs, True)

        gen_loss = generator_loss(generated_results)
        disc_loss = discriminator_loss(reference_results, generated_results)
        return gen_loss, disc_loss

    return gen_seqs

### Training Step

In [12]:
def train_step(inpt, trgt):
    """Run one adversarial update on a single batch.

    The generator (encoder + decoder) produces sequences via ``predict``. The
    discriminator scores both the reference targets and the generated
    sequences. Each network is then updated from its own tape with its own
    optimizer.

    Args:
        inpt: batch of source sequences.
        trgt: batch of reference target sequences.

    Returns:
        ``(gen_loss, disc_loss)`` for this batch.

    Raises:
        ValueError: if no generator variable received a gradient (see FIXME).

    FIXME: ``predict`` picks tokens with ``argmax`` and converts them to NumPy,
    so ``gen_loss`` has no differentiable path to the encoder/decoder weights
    and ``gtape.gradient`` returns only ``None``. Training the generator needs
    a differentiable relaxation (e.g. Gumbel-softmax) or a policy-gradient
    (REINFORCE-style) objective.
    """
    with tf.GradientTape() as gtape, tf.GradientTape() as dtape:
        seqs = predict(inpt, trgt)

        # Discriminator logits for real (reference) and generated sequences.
        reference_results = discriminator(trgt, True)
        generated_results = discriminator(seqs, True)

        gen_loss = generator_loss(generated_results)
        disc_loss = discriminator_loss(reference_results, generated_results)

    generator_variables = encoder.trainable_variables + decoder.trainable_variables
    generator_gradients = gtape.gradient(gen_loss, generator_variables)
    discriminator_gradients = dtape.gradient(disc_loss, discriminator.trainable_variables)

    # Fail with a diagnosis rather than the optimizer's generic
    # "No gradients provided for any variable" error.
    if all(grad is None for grad in generator_gradients):
        raise ValueError(
            'No gradient reaches the generator: predict() builds sequences with '
            'argmax/.numpy(), which is not differentiable (see train_step FIXME).')

    generator_optimizer.apply_gradients(zip(generator_gradients, generator_variables))
    discriminator_optimizer.apply_gradients(
        zip(discriminator_gradients, discriminator.trainable_variables))

    return gen_loss, disc_loss

### Training Loop

In [13]:
EPOCHS = 20

epoch_print = 'Epoch {} | Generator Loss {:.4f} | Discriminator Loss {:.4f}'
epoch_val_print = 'Epoch {} | Generator Val Loss {:.4f} | Discriminator Val Loss {:.4f}'

for epoch in range(EPOCHS):
    start = datetime.now()

    # Running totals. These must NOT be named generator_loss/discriminator_loss:
    # rebinding those names to 0 shadows the loss functions that train_step()
    # and predict() call, making the first train_step raise TypeError.
    total_gen_loss = 0.0
    total_disc_loss = 0.0
    n_batches = 0

    for inpt, trgt in train.take(context['steps_per_epoch']):
        batch_gen_loss, batch_disc_loss = train_step(inpt, trgt)
        total_gen_loss += float(batch_gen_loss)
        total_disc_loss += float(batch_disc_loss)
        n_batches += 1

    # Validation is evaluated on the first batch of `val` only.
    inpt_val, trgt_val = next(iter(val))
    val_gen_loss, val_disc_loss = predict(inpt_val, trgt_val, False)

    # Per-batch losses are already means over the batch, so average over the
    # number of batches (not BATCH_SIZE); guard against an empty epoch.
    n_batches = max(n_batches, 1)
    print(epoch_print.format(epoch + 1, total_gen_loss / n_batches,
                             total_disc_loss / n_batches))

    print(epoch_val_print.format(epoch + 1, float(val_gen_loss), float(val_disc_loss)))

    print('Time taken {}\n'.format(datetime.now() - start))

InternalError: Failed to call ThenRnnForward with model config: [rnn_mode, rnn_input_mode, rnn_direction_mode]: 2, 0, 0 , [num_layers, input_size, num_units, dir_count, max_seq_length, batch_size, cell_num_units]: [1, 2248, 1024, 1, 1, 64, 1024]  [Op:CudnnRNN]