# Adversarial Training on Pointer Generator
## Introduction
    The beginning of the introduction. 

## Table Of Contents:
* [Load Data & Initialize Model](#load-initialize)
* [Example Generation](#example-generation)
* [Train Pointer Generator](#train-global-1)
    * [Without Coverage](#train-global-1-sub-1)
        * [Generate Tokens](#gen-global-1-sub-1)
        * [Rouge Evaluation](#rouge-global-1-sub-1)
    * [With Coverage](#train-global-1-sub-2)
        * [Generate Tokens](#gen-global-1-sub-2)
        * [Rouge Evaluation](#rouge-global-1-sub-2)
* [Train Generative Adversarial Network](#train-global-2)
    * [Pretrain Discriminator](#train-global-2-sub-1)
        * [Generate Tokens](#gen-global-2-sub-1)
        * [Rouge Evaluation](#rouge-global-2-sub-1)
    * [Adversarial Training](#train-global-2-sub-2)
        * [Generate Tokens](#gen-global-2-sub-2)
        * [Rouge Evaluation](#rouge-global-2-sub-2)
* [Analysis & Conclusion](#analysis-conclusion)
* [Limitations & Future Work](#limit-future)


## Load Data & Initialize Model <a class="anchor" id="load-initialize"></a>

In [1]:
import numpy as np
from data import Data
from model import SummaryModel
import argparse

import tensorflow as tf

tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity('ERROR')

parser = argparse.ArgumentParser(description = 'Train/Test summarization model', formatter_class = argparse.ArgumentDefaultsHelpFormatter)

# Import Setting
parser.add_argument("--doc_file", type = str, default = './data/doc.p', help = 'path to document file')
parser.add_argument("--vocab_file", type = str, default = './data/vocab.p', help = 'path to vocabulary file')
parser.add_argument("--emb_file", type = str, default = './data/emb.p', help = 'path to embedding file')
parser.add_argument("--src_time", type = int, default = 200, help = 'maximal # of time steps in source text')
parser.add_argument("--sum_time", type = int, default = 50, help = 'maximal # of time steps in summary')
parser.add_argument("--max_oov_bucket", type = int, default = 280, help = 'maximal # of out-of-vocabulary word in one summary')
parser.add_argument("--train_ratio", type = float, default = 0.8, help = 'ratio of training data')
parser.add_argument("--seed", type = int, default = 888, help = 'seed for spliting data')

# Saving Setting
parser.add_argument("--log", type = str, default = './log/', help = 'logging directory')
parser.add_argument("--save", type = str, default = './model/', help = 'model saving directory')
parser.add_argument("--checkpoint", type = str, help = 'path to checkpoint point')
parser.add_argument("--autosearch", type = bool, default = False, help = "[NOT AVAILABLE] Set 'True' if searching for latest checkpoint")
parser.add_argument("--save_interval", type = int, default = 1900, help = "Save interval for training")

# Hyperparameter Setting
parser.add_argument("--batch_size", type = int, default = 16, help = 'number of samples in one batch')
parser.add_argument("--gen_lr", type = float, default = 1e-3, help = 'learning rate for generator')
parser.add_argument("--dis_lr", type = float, default = 1e-3, help = 'learning rate for discriminator')
parser.add_argument("--cov_weight", type = float, default = 1e-3, help = 'learning rate for coverage')

params = vars(parser.parse_args([]))

params['load_pretrain'] = True
# 1900 no coverage
# 
params['checkpoint'] = './model/pointer_cov_supervised-1900' # Uncomment when requiring reloading model

model = SummaryModel(**params)
data = Data(**params)


  self.enc_fw_unit = tf.compat.v1.nn.rnn_cell.LSTMCell(self.num_unit, name='encoder_forward_cell')
  self.enc_bw_unit = tf.compat.v1.nn.rnn_cell.LSTMCell(self.num_unit, name='encoder_backward_cell')
  self.dec_unit = tf.compat.v1.nn.rnn_cell.LSTMCell(self.num_unit, state_is_tuple=False, name='decoder_cell')
  self.dis_enc_unit = tf.compat.v1.nn.rnn_cell.LSTMCell(self.num_unit, name='dis_enc_unit')
  self.dis_dec_unit = tf.compat.v1.nn.rnn_cell.LSTMCell(self.num_unit, name='dis_dec_unit')
  self.bas_enc_unit = tf.compat.v1.nn.rnn_cell.LSTMCell(self.num_unit, name='bas_enc_unit')
  self.bas_dec_unit = tf.compat.v1.nn.rnn_cell.LSTMCell(self.num_unit, name='bas_dec_unit')


Restore Model from ./model/pointer_cov_supervised-1900


## Example Generation <a class="anchor" id="example-generation"></a>

In [2]:
train_data = data.get_next_epoch()
test_data = data.get_next_epoch_test()
src, ref, gen, tokens, scores, attens, gt_attens = None, None, None, None, None, None, None
for feed_dict in train_data:
    real, fake, real_len, fake_len = model.sess.run(
        [model.real_reward, model.fake_reward, model.sum_len, model.tokens_len], feed_dict=feed_dict)
    print(np.mean(real[1, 0:int(real_len[1])]))
    print(np.mean(fake[1, 0:int(fake_len[1])]))
    break

for feed_dict in test_data:
    tokens, scores, attens = model.beam_search(feed_dict)
    src, ref, gen = data.id2word(feed_dict, tokens)
    gt_attens = model.sess.run(model.atten_dist, feed_dict = feed_dict)
#     print(src, ref, gen, gt_attens)
    x = 0
    print ("".join(src[x]).replace("(OOV)",""), end = '\n\n')
    print ("".join(ref[x]).replace("(OOV)",""), end = '\n\n')
    print (gen)
#     for i in range(len(src)):
#         print ("".join(gen[x][i]))
    break

0.48081478
0.47486126
daw aung san suu kyi appeared video link yangon tell audience hong kong would seek broaden audience country since released house arrest<END>

Dissident Plans a More Active Role in Myanmar<END>

[['A', ' (OOV)', 'S', 'a', 's', 'e', ' (OOV)', 'a', 'n', 'd', ' (OOV)', 'R', 'e', 'a', 'd', 'e', 'r', ' (OOV)', 'i', 'n', ' (OOV)', 'A', 'f', 'g', 'h', 'a', 'n', 'i', 's', 't', 'e', 'r'], ['A', ' (OOV)', 'S', 'a', 's', 'e', ' (OOV)', 'a', 'n', 'd', ' (OOV)', 'R', 'a', 's', 'e', ' (OOV)', 'a', 'n', 'd', ' (OOV)', 'R', 'a', 's', 'e', ' (OOV)', 'a', 'n', 'd', ' (OOV)', 'R', 'a', 's', 'e', ' (OOV)', 'a', 'n', 'd', ' (OOV)', 'R', 'a', 's', 'e', ' (OOV)', 'a', 'n', 'd'], ['A', ' (OOV)', 'S', 'a', 's', 'e', ' (OOV)', 'a', 'n', 'd', ' (OOV)', 'R', 'a', 's', 'e', ' (OOV)', 'a', 'n', 'd', ' (OOV)', 'R', 'a', 's', 'e', ' (OOV)', 'a', 'n', 'd', ' (OOV)', 'R', 'a', 's', 'e', ' (OOV)', 'a', 'n', 'd', ' (OOV)', 'R', 'a', 's', 'e', ' (OOV)', 'a', 'n', 'd', ' (OOV)', 'T'], ['A', ' (OOV)', '

## Train Pointer Generator<a class="anchor" id="train-global-1"></a>
### Train without coverage<a class="anchor" id="train-global-1-sub-1"></a>

In [3]:
train_max_epoch = 1
print (f'Start from step {model.sess.run(model.gen_global_step)}')
for i in range(train_max_epoch):
    print (f'Train Epoch {i}')
    train_data = data.get_next_epoch()
    model.train_one_epoch(train_data, data.n_train_batch, coverage_on = False)

Start from step 4
Train Epoch 0


  0%|          | 0/1908 [00:00<?, ?it/s]

#### Generate tokens <a class="anchor" id="gen-global-1-sub-1"></a>

In [93]:
def generate_top_k_tokens(coverage, top_k=1):
    test_data = data.get_next_epoch_test()
    src = [[] for i in range(top_k)]
    ref = [[] for i in range(top_k)]
    gen = [[] for i in range(top_k)]
    for feed_dict in test_data:
        tokens, scores, attens = model.beam_search(feed_dict, coverage_on = coverage, top_k = top_k)
        if top_k == 1:
            tokens = [tokens]
            scores = [scores]
        for i in range(top_k):
            src[i], ref[i], gen[i] = data.id2word(feed_dict, tokens[i])
#         feed_dict['coverage_on:0'] = coverage
#         gt_attens = model.sess.run(model.atten_dist, feed_dict = feed_dict)
        break
    return src, ref, gen, scores

def clean_generated_tokens(src, ref, gen):
    src_str_list = ["".join(src[0][i]).replace("(OOV)", "") for i in range(len(src[0]))]
    ref_str_list = ["".join(ref[0][i]).replace("(OOV)", "") for i in range(len(ref[0]))]
    gen_str_list = []
    for i in range(len(test1_gen)):
        gen_str_list.append(["".join(gen[i][j]).replace("(OOV)","") for j in range(len(gen[i]))])
    return src_str_list, ref_str_list, gen_str_list

def print_cleaned_tokens(src, ref, gen, scores):
    for i in range(len(gen)):
        print ("Batch: " + str(i))
        for j in range(len(gen[i])):
            print("\tGeneration: " + str(j))
            print ("\t\tAbstract: " + src[j])
            print ("\t\tTitle: "+ ref[j])
            print("\t\tGenerated: " + gen[i][j])
            print("\t\tScore: " + str(scores[i][j]))

In [74]:
test1_src, test1_ref, test1_gen, test1_scores = generate_top_k_tokens(coverage=False, top_k=1)
test1_src_list, test1_ref_list, test1_gen_list = clean_generated_tokens(test1_src, test1_ref, test1_gen)
print_cleaned_tokens(test1_src_list, test1_ref_list, test1_gen_list, test1_scores)

Batch: 0
	Generation: 0
		Abstract: nuclear powered american submarine collided navy warship early friday strait hormuz narrow passage much world oil must pas<END>
		Title: 2 Navy Vessels Collide in Strait of Hormuz<END>
		Generated: A Sase and Rase and Rase and Rase and Rase of 
		Score: -1.1149016847001745
	Generation: 1
		Abstract: barack obama said israel interest find peace palestinian said israel able defend<END>
		Title: 2 Navy Vessels Collide in Strait of Hormuz<END>
		Generated: A Sase and Rase and Rase and Rase and Rase and T
		Score: -1.0473748031927614
	Generation: 2
		Abstract: ruling african national congress steamrolling big election victory third ballot counted<END>
		Title: 2 Navy Vessels Collide in Strait of Hormuz<END>
		Generated: A Sase and Rase and Rase and Rase and Rase and
		Score: -1.11961932892495
	Generation: 3
		Abstract: ultra orthodox israeli protest plan include community military draft arguing study torah important defending israel carrying weapon army<E

#### Rouge Evaluation<a class="anchor" id="rouge-global-1-sub-1"></a>

In [79]:
from rouge import Rouge
from tqdm import tqdm_notebook
rouge = Rouge()

def generate_evaluation_sets(coverage):
    refs = []
    gens = []
    cnt = 0
    test_n = min(data.n_test_batch, 50)
    test_data = data.get_next_epoch_test()
    for feed_dict in tqdm_notebook(test_data, total = test_n):
        tokens, scores, attens = model.beam_search(feed_dict, coverage_on = coverage)
        # sample_tokens = model.sess.run(model.tokens, feed_dict = feed_dict)
        src, ref, gen = data.id2word(feed_dict, tokens)
        for i in range(len(ref)):
            refs.append(" ".join(ref[i][:-1]))
            gens.append(" ".join(gen[i][:-1]))
        cnt += 1
        if cnt > test_n:
            break
    new_gens = []
    new_refs = []
    for i in range(len(gens)):
        if not (gens[i] == ""):
            new_gens.append(gens[i])
            new_refs.append(refs[i])
    return new_refs, new_gens


def rouge_evaluation(ref, gen):
    rouge_score = rouge.get_scores(gen, ref)
    r1, r2, rl = 0., 0., 0.
    for score in rouge_score:
        r1 = r1 + score['rouge-1']['f']
        r2 = r2 + score['rouge-2']['f']
        rl = rl + score['rouge-l']['f']
    r1 /= len(rouge_score)
    r2 /= len(rouge_score)
    rl /= len(rouge_score)
    print (r1, r2, rl)
    return r1, r2, rl

In [80]:
test1_eval_refs, test1_eval_gens = generate_evaluation_sets(coverage=False)
test1_r1, test1_r2, test1_rl = rouge_evaluation(test1_eval_refs, test1_eval_gens)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for feed_dict in tqdm_notebook(test_data, total = test_n):


  0%|          | 0/50 [00:00<?, ?it/s]

0.5283459926872293 0.14496090360978436 0.4170432907200313


### Train with coverage<a class="anchor" id="train-global-1-sub-2"></a>

In [82]:
train_max_epoch = 1
print (f'Start from step {model.sess.run(model.gen_global_step)}')
for i in range(train_max_epoch):
    print (f'Train Epoch {i}')
    train_data = data.get_next_epoch()
    model.train_one_epoch(train_data, data.n_train_batch, coverage_on = True, model_name = 'with_coverage')

Start from step 1899
Train Epoch 0


  0%|          | 0/1908 [00:00<?, ?it/s]

#### Generate tokens<a class="anchor" id="gen-global-1-sub-2"></a>

In [85]:
test2_src, test2_ref, test2_gen, test2_scores = generate_top_k_tokens(coverage=True, top_k=1)
test2_src_list, test2_ref_list, test2_gen_list = clean_generated_tokens(test2_src, test2_ref, test2_gen)
print_cleaned_tokens(test2_src_list, test2_ref_list, test2_gen_list, test2_scores)

Batch: 0
	Generation: 0
		Abstract: republican senator voted border security deal come state large house republican delegation<END>
		Title: Prospects in the House<END>
		Generated: Afghanistan:
		Score: -0.8208288046029898
	Generation: 1
		Abstract: danish celebrity chef claus meyer opening restaurant la paz hoping start bolivian food movement rediscover local ingredient<END>
		Title: Prospects in the House<END>
		Generated: Afghanistan:
		Score: -0.8298874634962815
	Generation: 2
		Abstract: see latest chart map coronavirus case death hospitalization harris county georgia<END>
		Title: Prospects in the House<END>
		Generated: Afabama Povid Case and Risk Trackers of
		Score: -0.5524458885192871
	Generation: 3
		Abstract: detention journalist unspecified security offense first time hamas arrested foreigner since took gaza 2007<END>
		Title: Prospects in the House<END>
		Generated: Aftermath on the
		Score: -0.7901968114516315
	Generation: 4
		Abstract: elderly farmer demonstration new 

#### Rouge Evaluation<a class="anchor" id="rouge-global-1-sub-2"></a>

In [86]:
test2_eval_refs, test2_eval_gens = generate_evaluation_sets(coverage=True)
test2_r1, test2_r2, test2_rl = rouge_evaluation(test2_eval_refs, test2_eval_gens)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for feed_dict in tqdm_notebook(test_data, total = test_n):


  0%|          | 0/50 [00:00<?, ?it/s]

0.49432528823714433 0.1006288352997728 0.3716033682313441


## Train GAN<a class="anchor" id="train-global-2"></a>
### Pretrain Discriminator<a class="anchor" id="train-global-2-sub-1"></a>

In [87]:
train_max_epoch = 1
print (f'Start from step {model.sess.run(model.gen_global_step_2)}')
for i in range(train_max_epoch):
    print (f'Train Epoch {i}')
    train_data = data.get_next_epoch()
    model.train_one_epoch_pre_dis(train_data, data.n_train_batch, coverage_on = True)

Start from step 0
Train Epoch 0


  0%|          | 0/1908 [00:00<?, ?it/s]

#### Generate tokens<a class="anchor" id="gen-global-2-sub-1"></a>

In [88]:
test3_src, test3_ref, test3_gen, test3_scores = generate_top_k_tokens(coverage=True, top_k=1)
test3_src_list, test3_ref_list, test3_gen_list = clean_generated_tokens(test3_src, test3_ref, test3_gen)
print_cleaned_tokens(test3_src_list, test3_ref_list, test3_gen_list, test3_scores)

Batch: 0
	Generation: 0
		Abstract: many greek taking hard look country day seeing mismanagement corruption<END>
		Title: Digging Deep and Seeing Greece’s Flaws<END>
		Generated: Aftermath on the
		Score: -0.8156286127427045
	Generation: 1
		Abstract: gen sarath fonseka led military campaign ended two decade civil war ran unsuccessfully president freed nearly two year prison<END>
		Title: Digging Deep and Seeing Greece’s Flaws<END>
		Generated: Aftermath of ther
		Score: -0.8385739856296115
	Generation: 2
		Abstract: see latest chart map coronavirus case death hospitalization ames area<END>
		Title: Digging Deep and Seeing Greece’s Flaws<END>
		Generated: Afabama Covid Case and Risk Trackers
		Score: -0.5203059428447002
	Generation: 3
		Abstract: everyone want democracy sympathizes protest crashing across middle east<END>
		Title: Digging Deep and Seeing Greece’s Flaws<END>
		Generated: Aftermath on the
		Score: -0.7647354462567497
	Generation: 4
		Abstract: president trump prime minis

#### Rouge Evaluation<a class="anchor" id="rouge-global-2-sub-1"></a>

In [89]:
test3_eval_refs, test3_eval_gens = generate_evaluation_sets(coverage=True)
test3_r1, test3_r2, test3_rl = rouge_evaluation(test3_eval_refs, test3_eval_gens)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for feed_dict in tqdm_notebook(test_data, total = test_n):


  0%|          | 0/50 [00:00<?, ?it/s]

0.49311720341328225 0.09652547784773684 0.3691956659410668


### Adversarial Training<a class="anchor" id="train-global-2-sub-2"></a>

In [91]:
train_max_epoch = 1                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            
print (f'Start from step {model.sess.run(model.gen_global_step_2)}')
for i in range(train_max_epoch):
    print (f'Train Epoch {i}')                                                
    train_data = data.get_next_epoch()                                        
    model.train_one_epoch_unsup(train_data, data.n_train_batch, coverage_on = True)

Start from step 14
Train Epoch 0


  0%|          | 0/1908 [00:00<?, ?it/s]

Restart Done!
Restart Done!
Restart Done!
Restart Done!
Restart Done!
Restart Done!
Restart Done!
Restart Done!


#### Generate Tokens<a class="anchor" id="gen-global-2-sub-2"></a>

In [94]:
test4_src, test4_ref, test4_gen, test4_scores = generate_top_k_tokens(coverage=True, top_k=1)
test4_src_list, test4_ref_list, test4_gen_list = clean_generated_tokens(test4_src, test4_ref, test4_gen)
print_cleaned_tokens(test4_src_list, test4_ref_list, test4_gen_list, test4_scores)

Batch: 0
	Generation: 0
		Abstract: call foreigner pause fighting talk rebel increasingly angered government colombo<END>
		Title: Sri Lanka Bars Swede Over Stand on War<END>
		Generated: Tndia Sattler th the
		Score: -0.9276477268763951
	Generation: 1
		Generated: Attacks on the
		Score: -0.8075149536132813
	Generation: 2
		Abstract: white house friday released 2015 joint tax return president michelle obama document suggests planning future beyond mr obama white house tenure<END>
		Title: The Obamas’ 2015 Tax Return<END>
		Generated: Attacks on the
		Score: -0.7968966801961263
	Generation: 3
		Abstract: tahrir square birthplace uprising place many felt like dream egypt like<END>
		Title: Birthplace of Uprising Welcomes Its Success<END>
		Generated: Attacks on the
		Score: -0.7942705790201823
	Generation: 4
		Abstract: child terrified noise finding food challenge rarely power many people yemen beyond dream end fighting<END>
		Title: The Many Miseries of Yemeni Families<END>
		Generated

#### Rouge Evaluation<a class="anchor" id="rouge-global-2-sub-2"></a>

In [95]:
test4_eval_refs, test4_eval_gens = generate_evaluation_sets(coverage=True)
test4_r1, test4_r2, test4_rl = rouge_evaluation(test4_eval_refs, test4_eval_gens)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for feed_dict in tqdm_notebook(test_data, total = test_n):


  0%|          | 0/50 [00:00<?, ?it/s]

0.5139641795130951 0.137870443942698 0.3837935762467773


In [96]:
test_data = data.get_next_epoch_test()
src, ref, gen, tokens, scores, attens, gt_attens = None, None, None, None, None, None, None
for feed_dict in train_data:
    real, fake, real_len, fake_len = model.sess.run(
        [model.real_reward, model.fake_reward, model.sum_len, model.tokens_len], feed_dict=feed_dict)
    print(np.mean(real[1, 0:int(real_len[1])]))
    print(np.mean(fake[1, 0:int(fake_len[1])]))
    break

## Analysis & Conclusion<a class="anchor" id="analysis-conclusion"></a>



You should consider upgrading via the 'A:\SHANE_STUFF\Text-Generation-GAN\venv\Scripts\python.exe -m pip install --upgrade pip' command.





## Limitations & Future Work<a class="anchor" id="limit-future"></a>