In [1]:
from data import Data
from model import SummaryModel
import argparse

import tensorflow as tf
# Only let TensorFlow's Python-side logger emit ERROR and above (same effect as the
# 'ERROR' level name). The I/W lines from the C++ runtime shown in this notebook's
# output still appear after this call; those are presumably governed separately
# (e.g. the TF_CPP_MIN_LOG_LEVEL environment variable).
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

def _str2bool(value):
    """Convert a command-line flag value such as "True", "false" or "1" into a bool.

    argparse's ``type=bool`` just calls ``bool(text)``, which is True for every
    non-empty string, so ``--load_pretrain False`` used to *enable* the option.
    This converter accepts the usual spellings case-insensitively and rejects
    anything else with a normal argparse usage error.

    Args:
        value: the raw argument string (an actual bool is passed through unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser(description = 'Train/Test summarization model', formatter_class = argparse.ArgumentDefaultsHelpFormatter)

# Data / input settings
parser.add_argument("--doc_file", type = str, default = './data/doc.p', help = 'path to document file')
parser.add_argument("--vocab_file", type = str, default = './data/vocab.p', help = 'path to vocabulary file')
parser.add_argument("--emb_file", type = str, default = './data/emb.p', help = 'path to embedding file')
parser.add_argument("--src_time", type = int, default = 1000, help = 'maximal # of time steps in source text')
parser.add_argument("--sum_time", type = int, default = 120, help = 'maximal # of time steps in summary')
parser.add_argument("--max_oov_bucket", type = int, default = 280, help = 'maximal # of out-of-vocabulary word in one summary')
parser.add_argument("--train_ratio", type = float, default = 0.9, help = 'ratio of training data')
parser.add_argument("--seed", type = int, default = 888, help = 'seed for spliting data')

# Saving / checkpoint settings
parser.add_argument("--log", type = str, default = './log/', help = 'logging directory')
parser.add_argument("--save", type = str, default = './model/', help = 'model saving directory')
# Boolean flags go through _str2bool: with type=bool, "False" would parse as True.
parser.add_argument("--load_pretrain", type = _str2bool, default = False, help = 'whether load from old version pre-trained model')
parser.add_argument("--checkpoint", type = str, help = 'path to checkpoint point')
parser.add_argument("--autosearch", type = _str2bool, default = False, help = "[NOT AVAILABLE] Set 'True' if searching for latest checkpoint")
parser.add_argument("--save_interval", type = int, default = 2000, help = "Save interval for training")

# Hyperparameter settings
parser.add_argument("--batch_size", type = int, default = 16, help = 'number of samples in one batch')
parser.add_argument("--gen_lr", type = float, default = 1e-3, help = 'learning rate for generator')
parser.add_argument("--dis_lr", type = float, default = 1e-3, help = 'learning rate for discriminator')
parser.add_argument("--cov_weight", type = float, default = 1., help = 'weight for coverage loss')

# Parse an explicit empty argument list so the kernel's own command line is not
# interpreted: every value comes from the defaults above. Override entries of
# `params` directly (see the next lines) instead of passing CLI flags.
params = vars(parser.parse_args([]))

# Optional overrides: to resume from a previously saved model, uncomment both lines
# and point 'checkpoint' at the checkpoint to load.
# params['load_pretrain'] = True
# params['checkpoint'] = './model/pointer_cov_supervised-1250' # Uncomment when requiring reloading model

# Build the model and the data pipeline from the same params dict; both constructors
# receive every key. NOTE(review): presumably each ignores the keys it does not use —
# confirm against model.py / data.py.
model = SummaryModel(**params)
data = Data(**params)


  self.enc_fw_unit = tf.compat.v1.nn.rnn_cell.LSTMCell(self.num_unit, name='encoder_forward_cell')
  self.enc_bw_unit = tf.compat.v1.nn.rnn_cell.LSTMCell(self.num_unit, name='encoder_backward_cell')
  self.dec_unit = tf.compat.v1.nn.rnn_cell.LSTMCell(self.num_unit, state_is_tuple=False, name='decoder_cell')
  self.bas_enc_unit = tf.compat.v1.nn.rnn_cell.LSTMCell(self.num_unit, name='bas_enc_unit')
  self.bas_dec_unit = tf.compat.v1.nn.rnn_cell.LSTMCell(self.num_unit, name='bas_dec_unit')
2022-12-08 22:55:49.709902: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:354] MLIR V1 optimization pass is not enabled
2022-12-08 22:55:49.835433: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


### Train Model
#### Train without coverage

In [None]:
# Run `train_max_epoch` epochs of `train_one_epoch` with the coverage loss switched off.
train_max_epoch = 10
start_step = model.sess.run(model.gen_global_step)
print(f'Start from step {start_step}')
for epoch in range(train_max_epoch):
    print(f'Train Epoch {epoch}')
    # A new epoch's data is requested from `data` on every pass; n_train_batch is
    # presumably the number of batches it yields.
    train_data = data.get_next_epoch()
    model.train_one_epoch(train_data, data.n_train_batch, coverage_on=False)

Start from step 0
Train Epoch 0


  0%|          | 0/1983 [00:00<?, ?it/s]

#### Train with coverage

In [None]:
# Continue training for `train_max_epoch` epochs with the coverage loss enabled,
# passing model_name='with_coverage' to train_one_epoch.
train_max_epoch = 2
coverage_kwargs = dict(coverage_on=True, model_name='with_coverage')
current_step = model.sess.run(model.gen_global_step)
print(f'Start from step {current_step}')
for epoch_idx in range(train_max_epoch):
    print(f'Train Epoch {epoch_idx}')
    train_data = data.get_next_epoch()
    model.train_one_epoch(train_data, data.n_train_batch, **coverage_kwargs)

#### Train with GAN (Pretrain Discriminator)

In [None]:
# Discriminator pre-training: `train_max_epoch` epochs of train_one_epoch_pre_dis
# with coverage enabled.
# NOTE(review): the start step is read from gen_global_step_2, the same counter the
# adversarial-training cell reports — confirm that is the intended counter here.
train_max_epoch = 2
print('Start from step {}'.format(model.sess.run(model.gen_global_step_2)))
for epoch_no in range(train_max_epoch):
    print(f'Train Epoch {epoch_no}')
    train_data = data.get_next_epoch()
    model.train_one_epoch_pre_dis(
        train_data,
        data.n_train_batch,
        coverage_on=True,
    )

#### Train with GAN (Adversarial Training)

In [None]:
# Adversarial phase: `train_max_epoch` epochs of train_one_epoch_unsup with coverage on.
train_max_epoch = 12
resume_step = model.sess.run(model.gen_global_step_2)
print(f'Start from step {resume_step}')
for adv_epoch in range(train_max_epoch):
    print(f'Train Epoch {adv_epoch}')
    train_data = data.get_next_epoch()
    model.train_one_epoch_unsup(
        train_data,
        data.n_train_batch,
        coverage_on=True,
    )

In [None]:
# Debug cell: evaluate the model's `real_reward` tensor for one feed.
# NOTE(review): `feed_dict` is not defined by any earlier cell of this notebook, so this
# raises NameError on a fresh kernel (Restart & Run All). It presumably relied on a
# variable left over from a deleted cell — build the feed explicitly before this line,
# or remove the cell.
model.sess.run(model.real_reward, feed_dict = feed_dict)

In [None]:
# Manually write a snapshot tagged 'cov_after_4_epoch' through SummaryModel's
# underscore-prefixed _save_model helper.
# NOTE(review): whether the leading 4 is an epoch count or a global step is defined in
# model.py — confirm before reusing this cell with other values.
checkpoint_tag = 'cov_after_4_epoch'
model._save_model(4, checkpoint_tag)

In [None]:
# Evaluation setup: create a ROUGE scorer, presumably for scoring generated summaries
# against references.
# NOTE(review): this import belongs in the imports cell at the top of the notebook, and
# `rouge` is not used by any cell shown here.
from rouge import Rouge
rouge = Rouge()

