In [1]:
default_path = " "

#-------------------------------save/load--------------------------------------#
pickle_path = default_path + "pickles/"
log_file = default_path + "logs.txt"
csv_file = default_path + "logs.csv"

log_file_handler = open(log_file,"a")
csv_file_handler = open(csv_file,"a")

file1 = open(log_file , "w+")
file2 = open(csv_file , "w+")

import pickle

def save(obj , filename):
  print("saving {} ..".format(filename))
  with open(filename, 'wb') as handle:
      pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)
      
def load(filename):
  print("loading {} ..".format(filename))
  with open(filename, 'rb') as handle:
    return pickle.load(handle)
#-----------------------------------------------------------------------------------#  

In [17]:
import glob
import random
import struct
import csv
from tensorflow.core.example import example_pb2
import operator
import numpy as np
np.random.seed(123)
SENTENCE_START = '<s>'
SENTENCE_END = '</s>'

PAD_TOKEN = '[PAD]'
UNKNOWN_TOKEN = '[UNK]' 
START_DECODING = '[START]' 
STOP_DECODING = '[STOP]' 

class Vocab(object):
  
  def __init__(self, vocab_file, max_size):
    self._word_to_id = {}
    self._id_to_word = {}
    self._count = 0 
    for w in [UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]:
      self._word_to_id[w.lower()] = self._count
      self._id_to_word[self._count] = w.lower()
      self._count += 1

    with open(vocab_file, 'r',encoding='utf-8') as vocab_f:
      for line in vocab_f:
        pieces = line.split()
        if len(pieces) != 2:
          print('Warning: incorrectly formatted line in vocabulary file: %s\n' % line)
          continue
        w = pieces[0].lower()
        if w in [SENTENCE_START, SENTENCE_END, UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]:
          raise Exception('<s>, </s>, [UNK], [PAD], [START] and [STOP] shouldn\'t be in the vocab file, but %s is' % w)
        if w in self._word_to_id:
          print("Duplicate:",w)
          continue
          raise Exception('Duplicated word in vocabulary file: %s' % w)
        self._word_to_id[w] = self._count
        self._id_to_word[self._count] = w
        self._count += 1
        if max_size != 0 and self._count >= max_size:
          print("max_size of vocab was specified as %i; we now have %i words. Stopping reading." % (max_size, self._count))
          break

    print("Finished constructing vocabulary of %i total words. Last word added: %s" % (self._count, self._id_to_word[self._count-1]))

  def word2id(self, word):
    if word not in self._word_to_id:
      return self._word_to_id[UNKNOWN_TOKEN.lower()]
    return self._word_to_id[word.lower()]

  def id2word(self, word_id):
    if word_id not in self._id_to_word:
      raise ValueError('Id not found in vocab: %d' % word_id)
    return self._id_to_word[word_id]

  def size(self):
        return self._count

  def write_metadata(self, fpath):
    
    print("Writing word embedding metadata file to %s..." % (fpath))
    with open(fpath, "w") as f:
      fieldnames = ['word']
      writer = csv.DictWriter(f,delimiter=str(u"\t").encode('utf-8') , fieldnames=fieldnames)
      for i in range(self.size()):
        writer.writerow({"word": self._id_to_word[i]})

  def LoadWordEmbedding(self, w2v_file, word_dim):
    self.wordDict = {}
    self.word_dim = word_dim

    self.wordDict[UNKNOWN_TOKEN] = np.zeros(self.word_dim,dtype=np.float32)
    self.wordDict[PAD_TOKEN] = np.random.uniform(-1,1,self.word_dim)
    self.wordDict[START_DECODING] = np.random.uniform(-1,1,self.word_dim)
    self.wordDict[STOP_DECODING] = np.random.uniform(-1,1,self.word_dim)
    with open(w2v_file) as wf:
      for line in wf:
        info = line.strip().split()
        word = info[0]
        coef = np.asarray(info[1:], dtype='float32')
        self.wordDict[word] = coef
        assert self.word_dim == len(coef)

    self.MakeWordEmbedding()

  def MakeWordEmbedding(self):
    sorted_x = sorted(self._word_to_id.items(), key=operator.itemgetter(1))
    self._wordEmbedding = np.zeros((self.size(), self.word_dim),dtype=np.float32) # replace unknown words with UNKNOWN_TOKEN embedding (zero vector)
    for word,i in sorted_x:
      if word in self.wordDict:
        self._wordEmbedding[i,:] = self.wordDict[word.lower()]
    print('Word Embedding Reading done.')

  def getWordEmbedding(self):
    return self._wordEmbedding

def example_generator(data_path, single_pass):
  
  while True:
    import glob
    filelist = glob.glob(data_path)
    
    if single_pass:
      filelist = sorted(filelist)
    else:
      random.shuffle(filelist)
    for f in filelist:
      reader = open(f, 'rb')
      while True:
        len_bytes = reader.read(8)
        if not len_bytes: break # finished reading this file
        str_len = struct.unpack('q', len_bytes)[0]
        example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
        yield example_pb2.Example.FromString(example_str)
    if single_pass:
      print("example_generator completed reading all datafiles. No more data.")
      break


def article2ids(article_words, vocab):
  
  ids = []
  oovs = []
  unk_id = vocab.word2id(UNKNOWN_TOKEN)
  for w in article_words:
    i = vocab.word2id(w)
    if i == unk_id: # If w is OOV
      if w not in oovs: 
        oovs.append(w)
      oov_num = oovs.index(w) 
      ids.append(vocab.size() + oov_num) 
    else:
      ids.append(i)
  return ids, oovs


def abstract2ids(abstract_words, vocab, article_oovs):
  
  ids = []
  unk_id = vocab.word2id(UNKNOWN_TOKEN)
  for w in abstract_words:
    i = vocab.word2id(w)
    if i == unk_id: 
      if w in article_oovs: 
        vocab_idx = vocab.size() + article_oovs.index(w) 
        ids.append(vocab_idx)
      else: 
        ids.append(unk_id) 
    else:
      ids.append(i)
  return ids


def outputids2words(id_list, vocab, article_oovs):
  
  words = []
  for i in id_list:
    try:
      w = vocab.id2word(i) # might be [UNK]
    except ValueError as e: # w is OOV
      assert article_oovs is not None, "Error: model produced a word ID that isn't in the vocabulary. This should not happen in baseline (no pointer-generator) mode"
      article_oov_idx = i - vocab.size()
      try:
        w = article_oovs[article_oov_idx]
      except ValueError as e: # i doesn't correspond to an article oov
        raise ValueError('Error: model produced word ID %i which corresponds to article OOV %i but this example only has %i article OOVs' % (i, article_oov_idx, len(article_oovs)))
    words.append(w)
  return words


def abstract2sents(abstract):
  
  abstract = abstract.encode("utf-8").decode(encoding="utf-8", errors="strict")
  cur = 0
  sents = []
  while True:
    try:
      # print("SENCENCE TYPE:",type(SENTENCE_START))
      # print("CUR TYPE:", type(cur))
      # print("ABSTRACT", type(abstract))
      start_p = abstract.index(SENTENCE_START,cur)
      end_p = abstract.index(SENTENCE_END, start_p + 1)
      cur = end_p + len(SENTENCE_END)
      sents.append(abstract[start_p+len(SENTENCE_START):end_p])
    except ValueError as e: # no more sentences
      return sents


def show_art_oovs(article, vocab):
  
  unk_token = vocab.word2id(UNKNOWN_TOKEN)
  words = article.split(' ')
  words = [("__%s__" % w) if vocab.word2id(w)==unk_token else w for w in words]
  out_str = ' '.join(words)
  return out_str


def show_abs_oovs(abstract, vocab, article_oovs):
  
  unk_token = vocab.word2id(UNKNOWN_TOKEN)
  words = abstract.split(' ')
  new_words = []
  for w in words:
    if vocab.word2id(w) == unk_token: # w is oov
      if article_oovs is None: # baseline mode
        new_words.append("__%s__" % w)
      else: # pointer-generator mode
        if w in article_oovs:
          new_words.append("__%s__" % w)
        else:
          new_words.append("!!__%s__!!" % w)
    else: # w is in-vocab word
      new_words.append(w)
  out_str = ' '.join(new_words)
  return out_str

In [3]:

try:
  import queue
except:
  import Queue as queue
from random import shuffle
from random import seed
seed(123)
from threading import Thread
import time
import numpy as np
import tensorflow as tf
#import data

FLAGS = tf.app.flags.FLAGS

class Example(object):
  """Class representing a train/val/test example for text summarization."""

  def __init__(self, article, abstract_sentences, vocab, hps):
    
    self.hps = hps

    start_decoding = vocab.word2id(START_DECODING)
    stop_decoding = vocab.word2id(STOP_DECODING)

    article_words = article.split()
    if len(article_words) > hps.max_enc_steps:
      article_words = article_words[:hps.max_enc_steps]
    self.enc_len = len(article_words) 
    self.enc_input = [vocab.word2id(w) for w in article_words] 
    
    
    abstract = ' '.join(abstract_sentences) # string
    abstract_words = abstract.split() 
    abs_ids = [vocab.word2id(w) for w in abstract_words] 

    self.dec_input, self.target = self.get_dec_inp_targ_seqs(abs_ids, hps.max_dec_steps, start_decoding, stop_decoding)
    self.dec_len = len(self.dec_input)

    if hps.pointer_gen:
      self.enc_input_extend_vocab, self.article_oovs = article2ids(article_words, vocab)

      abs_ids_extend_vocab = abstract2ids(abstract_words, vocab, self.article_oovs)

      _, self.target = self.get_dec_inp_targ_seqs(abs_ids_extend_vocab, hps.max_dec_steps, start_decoding, stop_decoding)

    self.original_article = article
    self.original_abstract = abstract
    self.original_abstract_sents = abstract_sentences


  def get_dec_inp_targ_seqs(self, sequence, max_len, start_id, stop_id):
    
    inp = [start_id] + sequence[:]
    target = sequence[:]
    if len(inp) > max_len: # truncate
      inp = inp[:max_len]
      target = target[:max_len] # no end_token
    else: # no truncation
      target.append(stop_id) # end token
    assert len(inp) == len(target)
    return inp, target


  def pad_decoder_inp_targ(self, max_len, pad_id):
    while len(self.dec_input) < max_len:
      self.dec_input.append(pad_id)
    while len(self.target) < max_len:
      self.target.append(pad_id)


  def pad_encoder_input(self, max_len, pad_id):
    while len(self.enc_input) < max_len:
      self.enc_input.append(pad_id)
    if self.hps.pointer_gen:
      while len(self.enc_input_extend_vocab) < max_len:
        self.enc_input_extend_vocab.append(pad_id)


class Batch(object):

  def __init__(self, example_list, hps, vocab):
    
    self.pad_id = vocab.word2id(PAD_TOKEN) # id of the PAD token used to pad sequences
    self.init_encoder_seq(example_list, hps) # initialize the input to the encoder
    self.init_decoder_seq(example_list, hps) # initialize the input and targets for the decoder
    self.store_orig_strings(example_list) # store the original strings

  def init_encoder_seq(self, example_list, hps):
    
    max_enc_seq_len = max([ex.enc_len for ex in example_list])

    for ex in example_list:
      ex.pad_encoder_input(max_enc_seq_len, self.pad_id)

    self.enc_batch = np.zeros((hps.batch_size, max_enc_seq_len), dtype=np.int32)
    self.enc_lens = np.zeros((hps.batch_size), dtype=np.int32)
    self.enc_padding_mask = np.zeros((hps.batch_size, max_enc_seq_len), dtype=np.float32)

    for i, ex in enumerate(example_list):
      self.enc_batch[i, :] = ex.enc_input[:]
      self.enc_lens[i] = ex.enc_len
      for j in range(ex.enc_len):
        self.enc_padding_mask[i][j] = 1

    if hps.pointer_gen:
      # Determine the max number of in-article OOVs in this batch
      self.max_art_oovs = max([len(ex.article_oovs) for ex in example_list])
      # Store the in-article OOVs themselves
      self.art_oovs = [ex.article_oovs for ex in example_list]
      # Store the version of the enc_batch that uses the article OOV ids
      self.enc_batch_extend_vocab = np.zeros((hps.batch_size, max_enc_seq_len), dtype=np.int32)
      for i, ex in enumerate(example_list):
        self.enc_batch_extend_vocab[i, :] = ex.enc_input_extend_vocab[:]

  def init_decoder_seq(self, example_list, hps):
    
    for ex in example_list:
      ex.pad_decoder_inp_targ(hps.max_dec_steps, self.pad_id)

    self.dec_batch = np.zeros((hps.batch_size, hps.max_dec_steps), dtype=np.int32)
    self.target_batch = np.zeros((hps.batch_size, hps.max_dec_steps), dtype=np.int32)
    self.dec_padding_mask = np.zeros((hps.batch_size, hps.max_dec_steps), dtype=np.float32)

    for i, ex in enumerate(example_list):
      self.dec_batch[i, :] = ex.dec_input[:]
      self.target_batch[i, :] = ex.target[:]
      for j in range(ex.dec_len):
        self.dec_padding_mask[i][j] = 1

  def store_orig_strings(self, example_list):
    self.original_articles = [ex.original_article for ex in example_list] # list of lists
    self.original_abstracts = [ex.original_abstract for ex in example_list] # list of lists
    self.original_abstracts_sents = [ex.original_abstract_sents for ex in example_list] # list of list of lists


class Batcher(object):
  
  BATCH_QUEUE_MAX = 100 # max number of batches the batch_queue can hold

  def __init__(self, data_path, vocab, hps, single_pass, decode_after):
    
    self._data_path = data_path
    self._vocab = vocab
    self._hps = hps
    self._single_pass = single_pass
    self._decode_after = decode_after

    self._batch_queue = queue.Queue(self.BATCH_QUEUE_MAX)
    self._example_queue = queue.Queue(self.BATCH_QUEUE_MAX * self._hps.batch_size)

    if single_pass:
      self._num_example_q_threads = 1 # just one thread, so we read through the dataset just once
      self._num_batch_q_threads = 1  # just one thread to batch examples
      self._bucketing_cache_size = 1 # only load one batch's worth of examples before bucketing; this essentially means no bucketing
      self._finished_reading = False # this will tell us when we're finished reading the dataset
    else:
      self._num_example_q_threads = FLAGS.example_queue_threads # num threads to fill example queue
      self._num_batch_q_threads = FLAGS.batch_queue_threads  # num threads to fill batch queue
      self._bucketing_cache_size = FLAGS.bucketing_cache_size # how many batches-worth of examples to load into cache before bucketing

    self._example_q_threads = []
    for _ in range(self._num_example_q_threads):
      self._example_q_threads.append(Thread(target=self.fill_example_queue))
      self._example_q_threads[-1].daemon = True
      self._example_q_threads[-1].start()
    self._batch_q_threads = []
    for _ in range(self._num_batch_q_threads):
      self._batch_q_threads.append(Thread(target=self.fill_batch_queue))
      self._batch_q_threads[-1].daemon = True
      self._batch_q_threads[-1].start()

    # Start a thread that watches the other threads and restarts them if they're dead
    if not single_pass: # We don't want a watcher in single_pass mode because the threads shouldn't run forever
      self._watch_thread = Thread(target=self.watch_threads)
      self._watch_thread.daemon = True
      self._watch_thread.start()

  def next_batch(self):
    if self._batch_queue.qsize() == 0:
      tf.logging.warning('Bucket input queue is empty when calling next_batch. Bucket queue size: %i, Input queue size: %i', self._batch_queue.qsize(), self._example_queue.qsize())
      if self._single_pass and self._finished_reading:
        tf.logging.info("Finished reading dataset in single_pass mode.")
        return None

    batch = self._batch_queue.get() # get the next Batch
    return batch

  def fill_example_queue(self):
    """Reads data from file and processes into Examples which are then placed into the example queue."""

    input_gen = self.text_generator(example_generator(self._data_path, self._single_pass))

    while True:
      try:
        (article, abstract) = next(input_gen)
        article = article.decode("utf-8")
        abstract = abstract.decode("utf-8")
        
      except StopIteration: 
        tf.logging.info("The example generator for this example queue filling thread has exhausted data.")
        if self._single_pass:
          tf.logging.info("single_pass mode is on, so we've finished reading dataset. This thread is stopping.")
          self._finished_reading = True
          break
        else:
          raise Exception("single_pass mode is off but the example generator is out of data; error.")

      abstract_sentences = [sent.strip() for sent in abstract2sents(abstract)] # Use the <s> and </s> tags in abstract to get a list of sentences.
      example = Example(article, abstract_sentences, self._vocab, self._hps) # Process into an Example.
      self._example_queue.put(example) # place the Example in the example queue.

  def fill_batch_queue(self):
    
    while True:
      if self._hps.mode != 'decode':
        inputs = []
        for _ in range(self._hps.batch_size * self._bucketing_cache_size):
          inputs.append(self._example_queue.get())
        inputs = sorted(inputs, key=lambda inp: inp.enc_len) # sort by length of encoder sequence

        batches = []
        for i in range(0, len(inputs), self._hps.batch_size):
          batches.append(inputs[i:i + self._hps.batch_size])
        if not self._single_pass:
          shuffle(batches)
        for b in batches:  # each b is a list of Example objects
          self._batch_queue.put(Batch(b, self._hps, self._vocab))

      else: # beam search decode mode
        ex = self._example_queue.get()
        b = [ex for _ in range(self._hps.batch_size)]
        self._batch_queue.put(Batch(b, self._hps, self._vocab))

  def watch_threads(self):
    
    while True:
      time.sleep(60)
      for idx,t in enumerate(self._example_q_threads):
        if not t.is_alive(): # if the thread is dead
          tf.logging.error('Found example queue thread dead. Restarting.')
          new_t = Thread(target=self.fill_example_queue)
          self._example_q_threads[idx] = new_t
          new_t.daemon = True
          new_t.start()
      for idx,t in enumerate(self._batch_q_threads):
        if not t.is_alive(): # if the thread is dead
          tf.logging.error('Found batch queue thread dead. Restarting.')
          new_t = Thread(target=self.fill_batch_queue)
          self._batch_q_threads[idx] = new_t
          new_t.daemon = True
          new_t.start()

  def text_generator(self, example_generator):
    
    cnt = 0
    while True:
      e = next(example_generator) # e is a tf.Example
      try:
        article_text = e.features.feature['article'].bytes_list.value[0] # the article text was saved under the key 'article' in the data files
        abstract_text = e.features.feature['abstract'].bytes_list.value[0] # the abstract text was saved under the key 'abstract' in the data files
      except ValueError:
        tf.logging.error('Failed to get article or abstract from example')
        continue
      if len(article_text)==0:
        tf.logging.warning('Found an example with empty article text. Skipping it.')
      else:
        if self._single_pass and cnt < self._decode_after: #skip already decoded docs
          cnt +=1
          continue
        yield (article_text, abstract_text)

In [4]:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import itertools
import numpy as np

#pylint: disable=C0103


def _get_ngrams(n, text):
  
  ngram_set = set()
  text_length = len(text)
  max_index_ngram_start = text_length - n
  for i in range(max_index_ngram_start + 1):
    ngram_set.add(tuple(text[i:i + n]))
  return ngram_set

def _split_into_words(sentences):
  return list(itertools.chain(*[_.split(" ") for _ in sentences]))

def _get_word_ngrams(n, sentences):
  
  assert len(sentences) > 0
  assert n > 0

  words = _split_into_words(sentences)
  return _get_ngrams(n, words)

def _len_lcs(x, y):
  
  table = _lcs(x, y)
  n, m = len(x), len(y)
  return table[n, m]

def _lcs(x, y):
  
  n, m = len(x), len(y)
  table = dict()
  for i in range(n + 1):
    for j in range(m + 1):
      if i == 0 or j == 0:
        table[i, j] = 0
      elif x[i - 1] == y[j - 1]:
        table[i, j] = table[i - 1, j - 1] + 1
      else:
        table[i, j] = max(table[i - 1, j], table[i, j - 1])
  return table

def _recon_lcs(x, y):
  
  i, j = len(x), len(y)
  table = _lcs(x, y)

  def _recon(i, j):
    """private recon calculation"""
    if i == 0 or j == 0:
      return []
    elif x[i - 1] == y[j - 1]:
      return _recon(i - 1, j - 1) + [(x[i - 1], i)]
    elif table[i - 1, j] > table[i, j - 1]:
      return _recon(i - 1, j)
    else:
      return _recon(i, j - 1)

  recon_tuple = tuple(map(lambda x: x[0], _recon(i, j)))
  return recon_tuple

def rouge_n(evaluated_sentences, reference_sentences, n=2):
  
  if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0:
    raise ValueError("Collections must contain at least 1 sentence.")

  evaluated_ngrams = _get_word_ngrams(n, evaluated_sentences)
  reference_ngrams = _get_word_ngrams(n, reference_sentences)
  reference_count = len(reference_ngrams)
  evaluated_count = len(evaluated_ngrams)

  overlapping_ngrams = evaluated_ngrams.intersection(reference_ngrams)
  overlapping_count = len(overlapping_ngrams)

  if evaluated_count == 0:
    precision = 0.0
  else:
    precision = overlapping_count / evaluated_count

  if reference_count == 0:
    recall = 0.0
  else:
    recall = overlapping_count / reference_count

  f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8))

  return f1_score, precision, recall

def _f_p_r_lcs(llcs, m, n):
  
  r_lcs = llcs / m
  p_lcs = llcs / n
  beta = p_lcs / (r_lcs + 1e-12)
  num = (1 + (beta**2)) * r_lcs * p_lcs
  denom = r_lcs + ((beta**2) * p_lcs)
  f_lcs = num / (denom + 1e-12)
  return f_lcs, p_lcs, r_lcs


def rouge_l_sentence_level(evaluated_sentences, reference_sentences):
  
  if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0:
    raise ValueError("Collections must contain at least 1 sentence.")
  reference_words = _split_into_words(reference_sentences)
  evaluated_words = _split_into_words(evaluated_sentences)
  m = len(reference_words)
  n = len(evaluated_words)
  lcs = _len_lcs(evaluated_words, reference_words)
  return _f_p_r_lcs(lcs, m, n)


def _union_lcs(evaluated_sentences, reference_sentence):
  
  if len(evaluated_sentences) <= 0:
    raise ValueError("Collections must contain at least 1 sentence.")

  lcs_union = set()
  reference_words = _split_into_words([reference_sentence])
  combined_lcs_length = 0
  for eval_s in evaluated_sentences:
    evaluated_words = _split_into_words([eval_s])
    lcs = set(_recon_lcs(reference_words, evaluated_words))
    combined_lcs_length += len(lcs)
    lcs_union = lcs_union.union(lcs)

  union_lcs_count = len(lcs_union)
  union_lcs_value = union_lcs_count / combined_lcs_length
  return union_lcs_value


def rouge_l_summary_level(evaluated_sentences, reference_sentences):
  
  if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0:
    raise ValueError("Collections must contain at least 1 sentence.")

  m = len(_split_into_words(reference_sentences))

  n = len(_split_into_words(evaluated_sentences))

  union_lcs_sum_across_all_references = 0
  for ref_s in reference_sentences:
    union_lcs_sum_across_all_references += _union_lcs(evaluated_sentences,
                                                      ref_s)
  return _f_p_r_lcs(union_lcs_sum_across_all_references, m, n)


def rouge(hypotheses, references):
  
  rouge_1 = [
      rouge_n([hyp], [ref], 1) for hyp, ref in zip(hypotheses, references)
  ]
  rouge_1_f, rouge_1_p, rouge_1_r = map(np.mean, zip(*rouge_1))

  rouge_2 = [
      rouge_n([hyp], [ref], 2) for hyp, ref in zip(hypotheses, references)
  ]
  rouge_2_f, rouge_2_p, rouge_2_r = map(np.mean, zip(*rouge_2))

  rouge_l = [
      rouge_l_sentence_level([hyp], [ref])
      for hyp, ref in zip(hypotheses, references)
  ]
  rouge_l_f, rouge_l_p, rouge_l_r = map(np.mean, zip(*rouge_l))

  return {
      "rouge_1/f_score": rouge_1_f,
      "rouge_1/r_score": rouge_1_r,
      "rouge_1/p_score": rouge_1_p,
      "rouge_2/f_score": rouge_2_f,
      "rouge_2/r_score": rouge_2_r,
      "rouge_2/p_score": rouge_2_p,
      "rouge_l/f_score": rouge_l_f,
      "rouge_l/r_score": rouge_l_r,
      "rouge_l/p_score": rouge_l_p,
  }

In [5]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import numpy as np

import tensorflow as tf


def _len_lcs(x, y):
  
  table = _lcs(x, y)
  n, m = len(x), len(y)
  return table[n, m]


def _lcs(x, y):
  
  n, m = len(x), len(y)
  table = dict()
  for i in range(n + 1):
    for j in range(m + 1):
      if i == 0 or j == 0:
        table[i, j] = 0
      elif x[i - 1] == y[j - 1]:
        table[i, j] = table[i - 1, j - 1] + 1
      else:
        table[i, j] = max(table[i - 1, j], table[i, j - 1])
  return table


def _f_lcs(llcs, m, n):
  
  r_lcs = llcs / m
  p_lcs = llcs / n
  beta = p_lcs / (r_lcs + 1e-12)
  num = (1 + (beta**2)) * r_lcs * p_lcs
  denom = r_lcs + ((beta**2) * p_lcs)
  f_lcs = num / (denom + 1e-12)
  return f_lcs


def rouge_l_sentence_level(eval_sentences, ref_sentences):
  

  f1_scores = []
  for eval_sentence, ref_sentence in zip(eval_sentences, ref_sentences):
    m = len(ref_sentence)
    n = len(eval_sentence)
    lcs = _len_lcs(eval_sentence, ref_sentence)
    f1_scores.append(_f_lcs(lcs, m, n))
  return np.array(f1_scores).astype(np.float32)


def rouge_l_fscore(hypothesis, references, **unused_kwargs):
  
  rouge_l_f_score = tf.py_func(rouge_l_sentence_level, (hypothesis, references), [tf.float32])
  return rouge_l_f_score


def _get_ngrams(n, text):
  
  ngram_set = set()
  text_length = len(text)
  max_index_ngram_start = text_length - n
  for i in range(max_index_ngram_start + 1):
    ngram_set.add(tuple(text[i:i + n]))
  return ngram_set


def rouge_n(eval_sentences, ref_sentences, n=2):
  

  f1_scores = []
  for eval_sentence, ref_sentence in zip(eval_sentences, ref_sentences):
    eval_ngrams = _get_ngrams(n, eval_sentence)
    ref_ngrams = _get_ngrams(n, ref_sentence)
    ref_count = len(ref_ngrams)
    eval_count = len(eval_ngrams)

    # Gets the overlapping ngrams between evaluated and reference
    overlapping_ngrams = eval_ngrams.intersection(ref_ngrams)
    overlapping_count = len(overlapping_ngrams)

    # Handle edge case. This isn't mathematically correct, but it's good enough
    if eval_count == 0:
      precision = 0.0
    else:
      precision = overlapping_count / eval_count

    if ref_count == 0:
      recall = 0.0
    else:
      recall = overlapping_count / ref_count

    f1_scores.append(2.0 * ((precision * recall) / (precision + recall + 1e-8)))

  return np.array(f1_scores).astype(np.float32)


def rouge_2_fscore(predictions, labels, **unused_kwargs):
  

  rouge_2_f_score = tf.py_func(rouge_n, (predictions, labels), [tf.float32])
  return rouge_2_f_score, tf.constant(1.0)

In [6]:
import tensorflow as tf
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import categorical
from tensorflow.python.ops.distributions import bernoulli
#from rouge_tensor import rouge_l_fscore

FLAGS = tf.app.flags.FLAGS

def print_shape(str, var):
  tf.logging.info('shape of {}: {}'.format(str, [k for k in var.get_shape()]))


def _calc_final_dist(_hps, v_size, _max_art_oovs, _enc_batch_extend_vocab, p_gen, vocab_dist, attn_dist):
  
  with tf.variable_scope('final_distribution'):
    vocab_dist = p_gen * vocab_dist
    attn_dist = (1-p_gen) * attn_dist

    extended_vsize = v_size + _max_art_oovs # the maximum (over the batch) size of the extended vocabulary
    extra_zeros = tf.zeros((_hps.batch_size, _max_art_oovs))
    vocab_dists_extended = tf.concat(axis=1, values=[vocab_dist, extra_zeros]) # list length max_dec_steps of shape (batch_size, extended_vsize)

    batch_nums = tf.range(0, limit=_hps.batch_size) # shape (batch_size)
    batch_nums = tf.expand_dims(batch_nums, 1) # shape (batch_size, 1)
    attn_len = tf.shape(_enc_batch_extend_vocab)[1] # number of states we attend over
    batch_nums = tf.tile(batch_nums, [1, attn_len]) # shape (batch_size, attn_len)
    indices = tf.stack( (batch_nums, _enc_batch_extend_vocab), axis=2) # shape (batch_size, enc_t, 2)
    shape = [_hps.batch_size, extended_vsize]
    attn_dists_projected = tf.scatter_nd(indices, attn_dist, shape) # list length max_dec_steps (batch_size, extended_vsize)

    final_dist = vocab_dists_extended + attn_dists_projected
    final_dist +=1e-15 # for cases where we have zero in the final dist, especially for oov words
    dist_sums = tf.reduce_sum(final_dist, axis=1)
    final_dist = final_dist / tf.reshape(dist_sums, [-1, 1]) # re-normalize

  return final_dist

def attention_decoder(_hps, 
  v_size, 
  _max_art_oovs, 
  _enc_batch_extend_vocab, 
  emb_dec_inputs,
  target_batch,
  _dec_in_state, 
  _enc_states, 
  enc_padding_mask, 
  dec_padding_mask, 
  cell, 
  embedding, 
  sampling_probability,
  alpha,
  unk_id,
  initial_state_attention=False,
  pointer_gen=True, 
  use_coverage=False, 
  prev_coverage=None, 
  prev_decoder_outputs=[], 
  prev_encoder_es = []):
  
  with variable_scope.variable_scope("attention_decoder") as scope:
    batch_size = _enc_states.get_shape()[0] # if this line fails, it's because the batch size isn't defined
    attn_size = _enc_states.get_shape()[2] # if this line fails, it's because the attention length isn't defined
    emb_size = emb_dec_inputs[0].get_shape()[1] # if this line fails, it's because the embedding isn't defined
    decoder_attn_size = _dec_in_state.c.get_shape()[1]
    tf.logging.info("batch_size %i, attn_size: %i, emb_size: %i", batch_size, attn_size, emb_size)
    # Reshape _enc_states (need to insert a dim)
    _enc_states = tf.expand_dims(_enc_states, axis=2) # now is shape (batch_size, max_enc_steps, 1, attn_size)

    attention_vec_size = attn_size

    if _hps.matrix_attention:
      w_attn = variable_scope.get_variable("w_attn", [attention_vec_size, attention_vec_size])
      if _hps.intradecoder:
        w_dec_attn = variable_scope.get_variable("w_dec_attn", [decoder_attn_size, decoder_attn_size])
    else:
      W_h = variable_scope.get_variable("W_h", [1, 1, attn_size, attention_vec_size])
      v = variable_scope.get_variable("v", [attention_vec_size])
      encoder_features = nn_ops.conv2d(_enc_states, W_h, [1, 1, 1, 1], "SAME") # shape (batch_size,max_enc_steps,1,attention_vec_size)
    if _hps.intradecoder:
      W_h_d = variable_scope.get_variable("W_h_d", [1, 1, decoder_attn_size, decoder_attn_size])
      v_d = variable_scope.get_variable("v_d", [decoder_attn_size])

    if use_coverage:
      with variable_scope.variable_scope("coverage"):
        w_c = variable_scope.get_variable("w_c", [1, 1, 1, attention_vec_size])

    if prev_coverage is not None: # for beam search mode with coverage
      prev_coverage = tf.expand_dims(tf.expand_dims(prev_coverage,2),3)

    def attention(decoder_state, temporal_e, coverage=None):
      
      with variable_scope.variable_scope("Attention"):
        # Pass the decoder state through a linear layer (this is W_s s_t + b_attn in the paper)
        decoder_features = linear(decoder_state, attention_vec_size, True) # shape (batch_size, attention_vec_size)
        decoder_features = tf.expand_dims(tf.expand_dims(decoder_features, 1), 1) # reshape to (batch_size, 1, 1, attention_vec_size)

        # We can't have coverage with matrix attention
        if not _hps.matrix_attention and use_coverage and coverage is not None: # non-first step of coverage
          # Multiply coverage vector by w_c to get coverage_features.
          coverage_features = nn_ops.conv2d(coverage, w_c, [1, 1, 1, 1], "SAME") # c has shape (batch_size, max_enc_steps, 1, attention_vec_size)
          # Calculate v^T tanh(W_h h_i + W_s s_t + w_c c_i^t + b_attn)
          e_not_masked = math_ops.reduce_sum(v * math_ops.tanh(encoder_features + decoder_features + coverage_features), [2, 3])  # shape (batch_size,max_enc_steps)
          masked_e = nn_ops.softmax(e_not_masked) * enc_padding_mask # (batch_size, max_enc_steps)
          masked_sums = tf.reduce_sum(masked_e, axis=1) # shape (batch_size)
          masked_e = masked_e / tf.reshape(masked_sums, [-1, 1])
          # Equation 3 in 
          if _hps.use_temporal_attention:
            try:
              len_temporal_e = temporal_e.get_shape()[0]
            except:
              len_temporal_e = 0
            if len_temporal_e==0:
              attn_dist = masked_e
            else:
              masked_sums = tf.reduce_sum(temporal_e,axis=0)+1e-10 # if it's zero due to masking we set it to a small value
              attn_dist = masked_e / masked_sums # (batch_size, max_enc_steps)
          else:
            attn_dist = masked_e
          masked_attn_sums = tf.reduce_sum(attn_dist, axis=1)
          attn_dist = attn_dist / tf.reshape(masked_attn_sums, [-1, 1]) # re-normalize
          # Update coverage vector
          coverage += array_ops.reshape(attn_dist, [batch_size, -1, 1, 1])
        else:
          if _hps.matrix_attention:
            # Calculate h_d * W_attn * h_i, equation 2 in https://arxiv.org/pdf/1705.04304.pdf
            _dec_attn = tf.unstack(tf.matmul(tf.squeeze(decoder_features,axis=[1,2]),w_attn),axis=0) # batch_size * (attention_vec_size)
            _enc_states_lst = tf.unstack(tf.squeeze(_enc_states,axis=2),axis=0) # batch_size * (max_enc_steps, attention_vec_size)

            e_not_masked = tf.squeeze(tf.stack([tf.matmul(tf.reshape(_dec,[1,-1]), tf.transpose(_enc)) for _dec, _enc in zip(_dec_attn,_enc_states_lst)]),axis=1) # (batch_size, max_enc_steps)
            masked_e = tf.exp(e_not_masked * enc_padding_mask) # (batch_size, max_enc_steps)
          else:
            # Calculate v^T tanh(W_h h_i + W_s s_t + b_attn)
            e_not_masked = math_ops.reduce_sum(v * math_ops.tanh(encoder_features + decoder_features), [2, 3]) # calculate e, (batch_size, max_enc_steps)
            masked_e = nn_ops.softmax(e_not_masked) * enc_padding_mask # (batch_size, max_enc_steps)
            masked_sums = tf.reduce_sum(masked_e, axis=1) # shape (batch_size)
            masked_e = masked_e / tf.reshape(masked_sums, [-1, 1])
          if _hps.use_temporal_attention:
            try:
              len_temporal_e = temporal_e.get_shape()[0]
            except:
              len_temporal_e = 0
            if len_temporal_e==0:
              attn_dist = masked_e
            else:
              masked_sums = tf.reduce_sum(temporal_e,axis=0)+1e-10 # if it's zero due to masking we set it to a small value
              attn_dist = masked_e / masked_sums # (batch_size, max_enc_steps)
          else:
            attn_dist = masked_e
          masked_attn_sums = tf.reduce_sum(attn_dist, axis=1)
          attn_dist = attn_dist / tf.reshape(masked_attn_sums, [-1, 1]) # re-normalize

          if use_coverage: # first step of training
            coverage = tf.expand_dims(tf.expand_dims(attn_dist,2),2) # initialize coverage

        context_vector = math_ops.reduce_sum(array_ops.reshape(attn_dist, [batch_size, -1, 1, 1]) * _enc_states, [1, 2]) # shape (batch_size, attn_size).
        context_vector = array_ops.reshape(context_vector, [-1, attn_size])

      return context_vector, attn_dist, coverage, masked_e

    def intra_decoder_attention(decoder_state, outputs):
      
      attention_dec_vec_size = attn_dec_size = decoder_state.c.get_shape()[1] # hidden_dim
      try:
        len_dec_states = outputs.get_shape()[0]
      except:
        len_dec_states = 0
      attention_dec_vec_size = attn_dec_size = decoder_state.c.get_shape()[1] # hidden_dim
      _decoder_states = tf.expand_dims(tf.reshape(outputs,[batch_size,-1,attn_dec_size]), axis=2) # now is shape (batch_size,len(decoder_states), 1, attn_size)
      _prev_decoder_features = nn_ops.conv2d(_decoder_states, W_h_d, [1, 1, 1, 1], "SAME") # shape (batch_size,len(decoder_states),1,attention_vec_size)
      with variable_scope.variable_scope("DecoderAttention"):
        # Pass the decoder state through a linear layer (this is W_s s_t + b_attn in the paper)
        try:
          decoder_features = linear(decoder_state, attention_dec_vec_size, True) # shape (batch_size, attention_vec_size)
          decoder_features = tf.expand_dims(tf.expand_dims(decoder_features, 1), 1) # reshape to (batch_size, 1, 1, attention_dec_vec_size)
          # Calculate v^T tanh(W_h h_i + W_s s_t + b_attn)
          if _hps.matrix_attention:
            # Calculate h_d * W_attn * h_d, equation 6 in https://arxiv.org/pdf/1705.04304.pdf
            _dec_attn = tf.matmul(tf.squeeze(decoder_features),w_dec_attn) # (batch_size, decoder_attn_size)
            _dec_states_lst = tf.unstack(tf.reshape(_prev_decoder_features,[batch_size,-1,decoder_attn_size])) # batch_size * (len(decoder_states), decoder_attn_size)
            e_not_masked = tf.reshape(tf.stack([tf.matmul(_dec_attn, tf.transpose(k)) for k in _dec_states_lst]),[batch_size,-1]) # (batch_size, len(decoder_states))
            masked_e = tf.exp(e_not_masked * dec_padding_mask[:,:len_dec_states]) # (batch_size, len(decoder_states))
          else:
            # Calculate v^T tanh(W_h h_i + W_s s_t + b_attn)
            e_not_masked = math_ops.reduce_sum(v_d * math_ops.tanh(_prev_decoder_features + decoder_features), [2, 3]) # calculate e, (batch_size,len(decoder_states))
            masked_e = nn_ops.softmax(e_not_masked) * dec_padding_mask[:,:len_dec_states] # (batch_size,len(decoder_states))
          if len_dec_states <= 1:
            masked_e = array_ops.ones([batch_size,1]) # first step is filled with equal values
          masked_sums = tf.reshape(tf.reduce_sum(masked_e,axis=1),[-1,1]) # (batch_size,1), # if it's zero due to masking we set it to a small value
          decoder_attn_dist = masked_e / masked_sums # (batch_size,len(decoder_states))
          context_decoder_vector = math_ops.reduce_sum(array_ops.reshape(decoder_attn_dist, [batch_size, -1, 1, 1]) * _decoder_states, [1, 2]) # (batch_size, attn_size)
          context_decoder_vector = array_ops.reshape(context_decoder_vector, [-1, attn_dec_size]) # (batch_size, attn_size)
        except:
          return array_ops.zeros([batch_size, decoder_attn_size]), array_ops.zeros([batch_size, 0])
      return context_decoder_vector, decoder_attn_dist

    outputs = []
    temporal_e = []
    attn_dists = []
    vocab_scores = []
    vocab_dists = []
    final_dists = []
    p_gens = []
    samples = [] # this holds the words chosen by sampling based on the final distribution for each decoding step, list of max_dec_steps of (batch_size, 1)
    greedy_search_samples = [] # this holds the words chosen by greedy search (taking the max) on the final distribution for each decoding step, list of max_dec_steps of (batch_size, 1)
    sampling_rewards = [] # list of size max_dec_steps (batch_size, k)
    greedy_rewards = [] # list of size max_dec_steps (batch_size, k)
    state = _dec_in_state
    coverage = prev_coverage # initialize coverage to None or whatever was passed in
    context_vector = array_ops.zeros([batch_size, attn_size])
    context_decoder_vector = array_ops.zeros([batch_size, decoder_attn_size])
    context_vector.set_shape([None, attn_size])  # Ensure the second shape of attention vectors is set.
    if initial_state_attention: # true in decode mode
      # Re-calculate the context vector from the previous step so that we can pass it through a linear layer with this step's input to get a modified version of the input
      context_vector, _, coverage, _ = attention(_dec_in_state, tf.stack(prev_encoder_es,axis=0), coverage) # in decode mode, this is what updates the coverage vector
      if _hps.intradecoder:
        context_decoder_vector, _ = intra_decoder_attention(_dec_in_state, tf.stack(prev_decoder_outputs,axis=0))
    for i, inp in enumerate(emb_dec_inputs):
      tf.logging.info("Adding attention_decoder timestep %i of %i", i, len(emb_dec_inputs))
      
      
      if i > 0:
        variable_scope.get_variable_scope().reuse_variables()

      if _hps.mode in ['train','eval'] and _hps.scheduled_sampling and i > 0: # start scheduled sampling after we received the first decoder's output
        # modify the input to next decoder using scheduled sampling
        if FLAGS.scheduled_sampling_final_dist:
          inp = scheduled_sampling(_hps, sampling_probability, final_dist, embedding, inp, alpha)
        else:
          inp = scheduled_sampling_vocab_dist(_hps, sampling_probability, vocab_dist, embedding, inp, alpha)

      emb_dim = inp.get_shape().with_rank(2)[1]
      if emb_dim is None:
        raise ValueError("Could not infer input size from input: %s" % inp.name)

      x = linear([inp] + [context_vector], emb_dim, True)
      cell_output, state = cell(x, state)

      if i == 0 and initial_state_attention:  # always true in decode mode
        with variable_scope.variable_scope(variable_scope.get_variable_scope()):#, reuse=True): # you need this because you've already run the initial attention(...) call
          context_vector, attn_dist, _, masked_e = attention(state, tf.stack(prev_encoder_es,axis=0), coverage) # don't allow coverage to update
          if _hps.intradecoder:
            context_decoder_vector, _ = intra_decoder_attention(state, tf.stack(prev_decoder_outputs,axis=0))
      else:
        context_vector, attn_dist, coverage, masked_e = attention(state, tf.stack(temporal_e,axis=0), coverage)
        if _hps.intradecoder:
          context_decoder_vector, _ = intra_decoder_attention(state, tf.stack(outputs,axis=0))
      attn_dists.append(attn_dist)
      temporal_e.append(masked_e)

      with variable_scope.variable_scope("combined_context"):
        if _hps.intradecoder:
          context_vector = linear([context_vector] + [context_decoder_vector], attn_size, False)
      # Calculate p_gen
      if pointer_gen:
        with tf.variable_scope('calculate_pgen'):
          p_gen = linear([context_vector, state.c, state.h, x], 1, True) # Tensor shape (batch_size, 1)
          p_gen = tf.sigmoid(p_gen)
          p_gens.append(p_gen)

      with variable_scope.variable_scope("AttnOutputProjection"):
        output = linear([cell_output] + [context_vector], cell.output_size, True)
      outputs.append(output)

      with tf.variable_scope('output_projection'):
        if i > 0:
          tf.get_variable_scope().reuse_variables()
        trunc_norm_init = tf.truncated_normal_initializer(stddev=_hps.trunc_norm_init_std)
        w_out = tf.get_variable('w', [_hps.dec_hidden_dim, v_size], dtype=tf.float32, initializer=trunc_norm_init)
        #w_t_out = tf.transpose(w)
        v_out = tf.get_variable('v', [v_size], dtype=tf.float32, initializer=trunc_norm_init)
        if i > 0:
          tf.get_variable_scope().reuse_variables()
        if FLAGS.share_decoder_weights: # Eq. 13 in https://arxiv.org/pdf/1705.04304.pdf
          w_out = tf.transpose(
            math_ops.tanh(linear([embedding] + [tf.transpose(w_out)], _hps.dec_hidden_dim, bias=False)))
        score = tf.nn.xw_plus_b(output, w_out, v_out)
        if _hps.scheduled_sampling and not _hps.greedy_scheduled_sampling:
          # Gumbel reparametrization trick: https://arxiv.org/abs/1704.06970
          U = tf.random_uniform(score.get_shape(),10e-12,(1-10e-12)) # add a small number to avoid log(0)
          G = -tf.log(-tf.log(U))
          score = score + G
        vocab_scores.append(score) # apply the linear layer
        vocab_dist = tf.nn.softmax(score)
        vocab_dists.append(vocab_dist) # The vocabulary distributions. List length max_dec_steps of (batch_size, vsize) arrays. The words are in the order they appear in the vocabulary file.

      if _hps.pointer_gen:
        final_dist = _calc_final_dist(_hps, v_size, _max_art_oovs, _enc_batch_extend_vocab, p_gen, vocab_dist,
                                      attn_dist)
      else: # final distribution is just vocabulary distribution
        final_dist = vocab_dist
      final_dists.append(final_dist)

      one_hot_k_samples = tf.distributions.Multinomial(total_count=1., probs=final_dist).sample(
        _hps.k)  # sample k times according to https://arxiv.org/pdf/1705.04304.pdf, size (k, batch_size, extended_vsize)
      k_argmax = tf.argmax(one_hot_k_samples, axis=2, output_type=tf.int32) # (k, batch_size)
      k_sample = tf.transpose(k_argmax) # shape (batch_size, k)
      greedy_search_prob, greedy_search_sample = tf.nn.top_k(final_dist, k=_hps.k) # (batch_size, k)
      greedy_search_samples.append(greedy_search_sample)
      samples.append(k_sample)
      if FLAGS.use_discounted_rewards:
        _sampling_rewards = []
        _greedy_rewards = []
        for _ in range(_hps.k):
          rl_fscore = tf.reshape(rouge_l_fscore(tf.transpose(tf.stack(samples)[:, :, _]), target_batch),
                                 [-1, 1])  # shape (batch_size, 1)
          _sampling_rewards.append(tf.reshape(rl_fscore, [-1, 1]))
          rl_fscore = tf.reshape(rouge_l_fscore(tf.transpose(tf.stack(greedy_search_samples)[:, :, _]), target_batch),
                                 [-1, 1])  # shape (batch_size, 1)
          _greedy_rewards.append(tf.reshape(rl_fscore, [-1, 1]))
        sampling_rewards.append(tf.squeeze(tf.stack(_sampling_rewards, axis=1), axis = -1)) # (batch_size, k)
        greedy_rewards.append(tf.squeeze(tf.stack(_greedy_rewards, axis=1), axis = -1))  # (batch_size, k)

    if FLAGS.use_discounted_rewards:
      sampling_rewards = tf.stack(sampling_rewards)
      greedy_rewards = tf.stack(greedy_rewards)
    else:
      _sampling_rewards = []
      _greedy_rewards = []
      for _ in range(_hps.k):
        rl_fscore = rouge_l_fscore(tf.transpose(tf.stack(samples)[:, :, _]), target_batch) # shape (batch_size, 1)
        _sampling_rewards.append(tf.reshape(rl_fscore, [-1, 1]))
        rl_fscore = rouge_l_fscore(tf.transpose(tf.stack(greedy_search_samples)[:, :, _]), target_batch)  # shape (batch_size, 1)
        _greedy_rewards.append(tf.reshape(rl_fscore, [-1, 1]))
      sampling_rewards = tf.squeeze(tf.stack(_sampling_rewards, axis=1), axis=-1) # (batch_size, k)
      greedy_rewards = tf.squeeze(tf.stack(_greedy_rewards, axis=1), axis=-1) # (batch_size, k)
    if coverage is not None:
      coverage = array_ops.reshape(coverage, [batch_size, -1])

  return (
  outputs, state, attn_dists, p_gens, coverage, vocab_scores, final_dists, samples, greedy_search_samples, temporal_e,
  sampling_rewards, greedy_rewards)

def scheduled_sampling(hps, sampling_probability, output, embedding, inp, alpha = 0):
  vocab_size = embedding.get_shape()[0]

  def soft_argmax(alpha, _output):
    new_oov_scores = tf.reshape(_output[:, 0] + tf.reduce_sum(_output[:, vocab_size:], axis=1),
                                [-1, 1])  # add score for all OOV to the UNK score
    _output = tf.concat([new_oov_scores, _output[:, 1:vocab_size]], axis=1) # select only the vocab_size outputs
    _output = _output / tf.reshape(tf.reduce_sum(output, axis=1), [-1, 1]) # re-normalize scores

    one_hot_scores = tf.nn.softmax((alpha * _output))
    return one_hot_scores

  def soft_top_k(alpha, _output, K):
    copy = tf.identity(_output)
    p = []
    arg_top_k = []
    for k in range(K):
      sargmax = soft_argmax(alpha, copy)
      copy = (1-sargmax)* copy
      p.append(tf.reduce_sum(sargmax * _output, axis=1))
      arg_top_k.append(sargmax)

    return tf.stack(p, axis=1), tf.stack(arg_top_k)

  with variable_scope.variable_scope("ScheduledEmbedding"):
    # Return -1s where we did not sample, and sample_ids elsewhere
    select_sampler = bernoulli.Bernoulli(probs=sampling_probability, dtype=tf.bool)
    select_sample = select_sampler.sample(sample_shape=hps.batch_size)
    sample_id_sampler = categorical.Categorical(probs=output) # equals to argmax{ Multinomial(output, total_count=1) }, our greedy search selection
    sample_ids = array_ops.where(
            select_sample,
            sample_id_sampler.sample(seed=123),
            gen_array_ops.fill([hps.batch_size], -1))

    where_sampling = math_ops.cast(
        array_ops.where(sample_ids > -1), tf.int32)
    where_not_sampling = math_ops.cast(
        array_ops.where(sample_ids <= -1), tf.int32)

    if hps.greedy_scheduled_sampling:
      sample_ids = tf.argmax(output, axis=1, output_type=tf.int32)

    sample_ids_sampling = array_ops.gather_nd(sample_ids, where_sampling)

    cond = tf.less(sample_ids_sampling, vocab_size) # replace oov with unk
    sample_ids_sampling = tf.cast(cond, tf.int32) * sample_ids_sampling
    inputs_not_sampling = array_ops.gather_nd(inp, where_not_sampling)

    if hps.E2EBackProp:
      if hps.hard_argmax:
        greedy_search_prob, greedy_search_sample = tf.nn.top_k(output, k=hps.k) # (batch_size, k)
        greedy_search_prob_normalized = greedy_search_prob/tf.reshape(tf.reduce_sum(greedy_search_prob,axis=1),[-1,1])

        cond = tf.less(greedy_search_sample, vocab_size) # replace oov with unk
        greedy_search_sample = tf.cast(cond, tf.int32) * greedy_search_sample

        greedy_embedding = tf.nn.embedding_lookup(embedding, greedy_search_sample)
        normalized_embedding = tf.multiply(tf.reshape(greedy_search_prob_normalized,[hps.batch_size,hps.k,1]), greedy_embedding)
        e2e_embedding = tf.reduce_mean(normalized_embedding,axis=1)
      else:
        e = []
        greedy_search_prob, greedy_search_sample = soft_top_k(alpha, output,
                                                              K=hps.k)  # (batch_size, k), (k, batch_size, vocab_size)
        greedy_search_prob_normalized = greedy_search_prob / tf.reshape(tf.reduce_sum(greedy_search_prob, axis=1),
                                                                        [-1, 1])

        for _ in range(hps.k):
          a_k = greedy_search_sample[_]
          e_k = tf.matmul(tf.reshape(greedy_search_prob_normalized[:,_],[-1,1]) * a_k, embedding)
          e.append(e_k)
        e2e_embedding = tf.reduce_sum(e, axis=0) # (batch_size, emb_dim)
      sampled_next_inputs = array_ops.gather_nd(e2e_embedding, where_sampling)
    else:
      if hps.hard_argmax:
        sampled_next_inputs = tf.nn.embedding_lookup(embedding, sample_ids_sampling)
      else: # using soft armax (greedy) proposed in: https://arxiv.org/abs/1704.06970
        #alpha_exp = tf.exp(alpha * (output_not_extended + G)) # (batch_size, vocab_size)
        #one_hot_scores = alpha_exp / tf.reduce_sum(alpha_exp, axis=1) #(batch_size, vocab_size)
        one_hot_scores = soft_argmax(alpha, output) #(batch_size, vocab_size)
        soft_argmax_embedding = tf.matmul(one_hot_scores, embedding) #(batch_size, emb_size)
        sampled_next_inputs = array_ops.gather_nd(soft_argmax_embedding, where_sampling)

    base_shape = array_ops.shape(inp)
    result1 = array_ops.scatter_nd(indices=where_sampling, updates=sampled_next_inputs, shape=base_shape)
    result2 = array_ops.scatter_nd(indices=where_not_sampling, updates=inputs_not_sampling, shape=base_shape)
    return result1 + result2

def scheduled_sampling_vocab_dist(hps, sampling_probability, output, embedding, inp, alpha = 0):
  
  def soft_argmax(alpha, output):
    
    one_hot_scores = tf.nn.softmax(alpha * output)
    return one_hot_scores

  def soft_top_k(alpha, output, K):
    copy = tf.identity(output)
    p = []
    arg_top_k = []
    for k in range(K):
      sargmax = soft_argmax(alpha, copy)
      copy = (1-sargmax)* copy
      p.append(tf.reduce_sum(sargmax * output, axis=1))
      arg_top_k.append(sargmax)

    return tf.stack(p, axis=1), tf.stack(arg_top_k)

  with variable_scope.variable_scope("ScheduledEmbedding"):
    
    select_sampler = bernoulli.Bernoulli(probs=sampling_probability, dtype=tf.bool)
    select_sample = select_sampler.sample(sample_shape=hps.batch_size)
    sample_id_sampler = categorical.Categorical(probs=output) # equals to argmax{ Multinomial(output, total_count=1) }, our greedy search selection
    sample_ids = array_ops.where(
            select_sample,
            sample_id_sampler.sample(seed=123),
            gen_array_ops.fill([hps.batch_size], -1))

    where_sampling = math_ops.cast(
        array_ops.where(sample_ids > -1), tf.int32)
    where_not_sampling = math_ops.cast(
        array_ops.where(sample_ids <= -1), tf.int32)

    if hps.greedy_scheduled_sampling:
      sample_ids = tf.argmax(output, axis=1, output_type=tf.int32)

    sample_ids_sampling = array_ops.gather_nd(sample_ids, where_sampling)
    inputs_not_sampling = array_ops.gather_nd(inp, where_not_sampling)

    if hps.E2EBackProp:
      if hps.hard_argmax:
        greedy_search_prob, greedy_search_sample = tf.nn.top_k(output, k=hps.k) # (batch_size, k)
        greedy_search_prob_normalized = greedy_search_prob/tf.reshape(tf.reduce_sum(greedy_search_prob,axis=1),[-1,1])
        greedy_embedding = tf.nn.embedding_lookup(embedding, greedy_search_sample)
        normalized_embedding = tf.multiply(tf.reshape(greedy_search_prob_normalized,[hps.batch_size,hps.k,1]), greedy_embedding)
        e2e_embedding = tf.reduce_mean(normalized_embedding,axis=1)
      else:
        e = []
        greedy_search_prob, greedy_search_sample = soft_top_k(alpha, output,
                                                              K=hps.k)  # (batch_size, k), (k, batch_size, vocab_size)
        greedy_search_prob_normalized = greedy_search_prob / tf.reshape(tf.reduce_sum(greedy_search_prob, axis=1),
                                                                        [-1, 1])

        for _ in range(hps.k):
          a_k = greedy_search_sample[_]
          e_k = tf.matmul(tf.reshape(greedy_search_prob_normalized[:,_],[-1,1]) * a_k, embedding)
          e.append(e_k)
        e2e_embedding = tf.reduce_sum(e, axis=0) # (batch_size, emb_dim)
      sampled_next_inputs = array_ops.gather_nd(e2e_embedding, where_sampling)
    else:
      if hps.hard_argmax:
        sampled_next_inputs = tf.nn.embedding_lookup(embedding, sample_ids_sampling)
      else: 
        one_hot_scores = soft_argmax(alpha, output) #(batch_size, vocab_size)
        soft_argmax_embedding = tf.matmul(one_hot_scores, embedding) #(batch_size, emb_size)
        sampled_next_inputs = array_ops.gather_nd(soft_argmax_embedding, where_sampling)

    base_shape = array_ops.shape(inp)
    result1 = array_ops.scatter_nd(indices=where_sampling, updates=sampled_next_inputs, shape=base_shape)
    result2 = array_ops.scatter_nd(indices=where_not_sampling, updates=inputs_not_sampling, shape=base_shape)
    return result1 + result2

def linear(args, output_size, bias, bias_start=0.0, scope=None):
  
  if args is None or (isinstance(args, (list, tuple)) and not args):
    raise ValueError("`args` must be specified")
  if not isinstance(args, (list, tuple)):
    args = [args]
  
  total_arg_size = 0
  shapes = [a.get_shape().as_list() for a in args]
  for shape in shapes:
    if len(shape) != 2:
      raise ValueError("Linear is expecting 2D arguments: %s" % str(shapes))
    if not shape[1]:
      raise ValueError("Linear expects shape[1] of arguments: %s" % str(shapes))
    else:
      total_arg_size += shape[1]

  # Now the computation.
  with tf.variable_scope(scope or "Linear" , reuse=tf.AUTO_REUSE):
    matrix = tf.get_variable("Matrix", [total_arg_size, output_size])
    if len(args) == 1:
      res = tf.matmul(args[0], matrix)
    else:
      res = tf.matmul(tf.concat(axis=1, values=args), matrix)
    if not bias:
      return res
    bias_term = tf.get_variable(
        "Bias", [output_size], initializer=tf.constant_initializer(bias_start))
  return res + bias_term

In [7]:
import os
import time
import numpy as np
import tensorflow as tf
#from attention_decoder import attention_decoder
##from tensorflow.contrib.tensorboard.plugins import projector
from nltk.translate.bleu_score import sentence_bleu
#from rouge import rouge
#from rouge_tensor import rouge_l_fscore
#import data
#from replay_buffer import Transition

FLAGS = tf.app.flags.FLAGS

class SummarizationModel(object):
  

  def __init__(self, hps, vocab):
    self._hps = hps
    self._vocab = vocab

  def reward_function(self, reference, summary, measure='rouge_l/f_score'):
    
    if 'rouge' in measure:
      return rouge([summary],[reference])[measure]
    else:
      return sentence_bleu([reference.split()],summary.split(),weights=(0.25,0.25,0.25,0.25))

  def variable_summaries(self, var_name, var):
    
    with tf.name_scope('summaries_{}'.format(var_name)):
      mean = tf.reduce_mean(var)
      tf.summary.scalar('mean', mean)
      with tf.name_scope('stddev'):
        stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
      tf.summary.scalar('stddev', stddev)
      tf.summary.scalar('max', tf.reduce_max(var))
      tf.summary.scalar('min', tf.reduce_min(var))
      tf.summary.histogram('histogram', var)

  def _add_placeholders(self):
    
    hps = self._hps

    # encoder part
    self._enc_batch = tf.placeholder(tf.int32, [hps.batch_size, None], name='enc_batch')
    self._enc_lens = tf.placeholder(tf.int32, [hps.batch_size], name='enc_lens')
    self._enc_padding_mask = tf.placeholder(tf.float32, [hps.batch_size, None], name='enc_padding_mask')
    self._eta = tf.placeholder(tf.float32, None, name='eta')
    if FLAGS.embedding:
      self.embedding_place = tf.placeholder(tf.float32, [self._vocab.size(), hps.emb_dim])
    if FLAGS.pointer_gen:
      self._enc_batch_extend_vocab = tf.placeholder(tf.int32, [hps.batch_size, None], name='enc_batch_extend_vocab')
      self._max_art_oovs = tf.placeholder(tf.int32, [], name='max_art_oovs')
    if FLAGS.ac_training: # added by yaserkl@vt.edu for the purpose of calculating rouge loss
      self._q_estimates = tf.placeholder(tf.float32, [self._hps.batch_size,self._hps.k,self._hps.max_dec_steps, None], name='q_estimates')
    if FLAGS.scheduled_sampling:
      self._sampling_probability = tf.placeholder(tf.float32, None, name='sampling_probability')
      self._alpha = tf.placeholder(tf.float32, None, name='alpha')

    # decoder part
    self._dec_batch = tf.placeholder(tf.int32, [hps.batch_size, hps.max_dec_steps], name='dec_batch')
    self._target_batch = tf.placeholder(tf.int32, [hps.batch_size, hps.max_dec_steps], name='target_batch')
    self._dec_padding_mask = tf.placeholder(tf.float32, [hps.batch_size, hps.max_dec_steps], name='dec_padding_mask')

    if hps.mode == "decode":
      if hps.coverage:
        self.prev_coverage = tf.placeholder(tf.float32, [hps.batch_size, None], name='prev_coverage')
      if hps.intradecoder:
        self.prev_decoder_outputs = tf.placeholder(tf.float32, [None, hps.batch_size, hps.dec_hidden_dim], name='prev_decoder_outputs')
      if hps.use_temporal_attention:
        self.prev_encoder_es = tf.placeholder(tf.float32, [None, hps.batch_size, None], name='prev_encoder_es')

  def _make_feed_dict(self, batch, just_enc=False):
    
    feed_dict = {}
    feed_dict[self._enc_batch] = batch.enc_batch
    feed_dict[self._enc_lens] = batch.enc_lens
    feed_dict[self._enc_padding_mask] = batch.enc_padding_mask
    if FLAGS.pointer_gen:
      feed_dict[self._enc_batch_extend_vocab] = batch.enc_batch_extend_vocab
      feed_dict[self._max_art_oovs] = batch.max_art_oovs
    if not just_enc:
      feed_dict[self._dec_batch] = batch.dec_batch
      feed_dict[self._target_batch] = batch.target_batch
      feed_dict[self._dec_padding_mask] = batch.dec_padding_mask
    return feed_dict

  def _add_encoder(self, emb_enc_inputs, seq_len):
    
    with tf.variable_scope('encoder'):
      cell_fw = tf.contrib.rnn.LSTMCell(self._hps.enc_hidden_dim, initializer=self.rand_unif_init, state_is_tuple=True)
      cell_bw = tf.contrib.rnn.LSTMCell(self._hps.enc_hidden_dim, initializer=self.rand_unif_init, state_is_tuple=True)
      (encoder_outputs, (fw_st, bw_st)) = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, emb_enc_inputs, dtype=tf.float32, sequence_length=seq_len, swap_memory=True)
      encoder_outputs = tf.concat(axis=2, values=encoder_outputs) # concatenate the forwards and backwards states
    return encoder_outputs, fw_st, bw_st

  def _reduce_states(self, fw_st, bw_st):
    
    enc_hidden_dim = self._hps.enc_hidden_dim
    dec_hidden_dim = self._hps.dec_hidden_dim

    with tf.variable_scope('reduce_final_st'):

      # Define weights and biases to reduce the cell and reduce the state
      w_reduce_c = tf.get_variable('w_reduce_c', [enc_hidden_dim * 2, dec_hidden_dim], dtype=tf.float32, initializer=self.trunc_norm_init)
      w_reduce_h = tf.get_variable('w_reduce_h', [enc_hidden_dim * 2, dec_hidden_dim], dtype=tf.float32, initializer=self.trunc_norm_init)
      bias_reduce_c = tf.get_variable('bias_reduce_c', [dec_hidden_dim], dtype=tf.float32, initializer=self.trunc_norm_init)
      bias_reduce_h = tf.get_variable('bias_reduce_h', [dec_hidden_dim], dtype=tf.float32, initializer=self.trunc_norm_init)

      # Apply linear layer
      old_c = tf.concat(axis=1, values=[fw_st.c, bw_st.c]) # Concatenation of fw and bw cell
      old_h = tf.concat(axis=1, values=[fw_st.h, bw_st.h]) # Concatenation of fw and bw state
      new_c = tf.nn.relu(tf.matmul(old_c, w_reduce_c) + bias_reduce_c) # Get new cell from old cell
      new_h = tf.nn.relu(tf.matmul(old_h, w_reduce_h) + bias_reduce_h) # Get new state from old state
      return tf.contrib.rnn.LSTMStateTuple(new_c, new_h) # Return new cell and state

  def _add_decoder(self, emb_dec_inputs, embedding):
    
    hps = self._hps
    cell = tf.contrib.rnn.LSTMCell(hps.dec_hidden_dim, state_is_tuple=True, initializer=self.rand_unif_init)

    prev_coverage = self.prev_coverage if (hps.mode=="decode" and hps.coverage) else None # In decode mode, we run attention_decoder one step at a time and so need to pass in the previous step's coverage vector each time
    prev_decoder_outputs = self.prev_decoder_outputs if (hps.intradecoder and hps.mode=="decode") else tf.stack([],axis=0)
    prev_encoder_es = self.prev_encoder_es if (hps.use_temporal_attention and hps.mode=="decode") else tf.stack([],axis=0)
    return attention_decoder(hps,
      self._vocab.size(),
      self._max_art_oovs,
      self._enc_batch_extend_vocab,
      emb_dec_inputs,
      self._target_batch,
      self._dec_in_state,
      self._enc_states,
      self._enc_padding_mask,
      self._dec_padding_mask,
      cell,
      embedding,
      self._sampling_probability if FLAGS.scheduled_sampling else 0,
      self._alpha if FLAGS.E2EBackProp else 0,
      self._vocab.word2id(UNKNOWN_TOKEN),
      initial_state_attention=(hps.mode=="decode"),
      pointer_gen=hps.pointer_gen,
      use_coverage=hps.coverage,
      prev_coverage=prev_coverage,
      prev_decoder_outputs=prev_decoder_outputs,
      prev_encoder_es = prev_encoder_es)

  def _add_emb_vis(self, embedding_var):
    
    train_dir = os.path.join(FLAGS.log_root, "train")
    vocab_metadata_path = os.path.join(train_dir, "vocab_metadata.tsv")
    self._vocab.write_metadata(vocab_metadata_path) # write metadata file
    summary_writer = tf.summary.FileWriter(train_dir)
    config = projector.ProjectorConfig()
    embedding = config.embeddings.add()
    embedding.tensor_name = embedding_var.name
    embedding.metadata_path = vocab_metadata_path
    projector.visualize_embeddings(summary_writer, config)

  def discount_rewards(self, r):
    
    discounted_r = []
    running_add = tf.constant(0, tf.float32)
    for t in reversed(range(0, len(r))):
      running_add = running_add * self._hps.gamma + r[t] # rd_t = r_t + gamma * r_{t+1}
      discounted_r.append(running_add)
    discounted_r = tf.stack(discounted_r[::-1]) # (max_dec_step, batch_size, k)
    normalized_discounted_r = tf.nn.l2_normalize(discounted_r, axis=0)
    return tf.unstack(normalized_discounted_r) # list of max_dec_step * (batch_size, k)

  def intermediate_rewards(self, r):
    
    intermediate_r = []
    intermediate_r.append(r[0])
    for t in range(1, len(r)):
      intermediate_r.append(r[t]-r[t-1])
    return intermediate_r # list of max_dec_step * (batch_size, k)

  def _add_seq2seq(self):
    
    hps = self._hps
    vsize = self._vocab.size() # size of the vocabulary

    with tf.variable_scope('seq2seq'):
      # Some initializers
      self.rand_unif_init = tf.random_uniform_initializer(-hps.rand_unif_init_mag, hps.rand_unif_init_mag, seed=123)
      self.trunc_norm_init = tf.truncated_normal_initializer(stddev=hps.trunc_norm_init_std)

      # Add embedding matrix (shared by the encoder and decoder inputs)
      with tf.variable_scope('embedding'):
        if FLAGS.embedding:
          embedding = tf.Variable(self.embedding_place)
        else:
          embedding = tf.get_variable('embedding', [vsize, hps.emb_dim], dtype=tf.float32, initializer=self.trunc_norm_init)
        ##if hps.mode=="train": self._add_emb_vis(embedding) # add to tensorboard
        emb_enc_inputs = tf.nn.embedding_lookup(embedding, self._enc_batch) # tensor with shape (batch_size, max_enc_steps, emb_size)
        emb_dec_inputs = [tf.nn.embedding_lookup(embedding, x) for x in tf.unstack(self._dec_batch, axis=1)] # list length max_dec_steps containing shape (batch_size, emb_size)

      # Add the encoder.
      enc_outputs, fw_st, bw_st = self._add_encoder(emb_enc_inputs, self._enc_lens)
      self._enc_states = enc_outputs

      # Our encoder is bidirectional and our decoder is unidirectional so we need to reduce the final encoder hidden state to the right size to be the initial decoder hidden state
      self._dec_in_state = self._reduce_states(fw_st, bw_st)

      # Add the decoder.
      with tf.variable_scope('decoder'):
        (self.decoder_outputs, self._dec_out_state, self.attn_dists, self.p_gens, self.coverage, self.vocab_scores,
         self.final_dists, self.samples, self.greedy_search_samples, self.temporal_es,
         self.sampling_rewards, self.greedy_rewards) = self._add_decoder(emb_dec_inputs, embedding)

      if FLAGS.use_discounted_rewards and hps.rl_training and hps.mode in ['train', 'eval']:
        
        self.sampling_discounted_rewards = tf.stack(self.discount_rewards(tf.unstack(self.sampling_rewards))) # list of max_dec_steps * (batch_size, k)
        self.greedy_discounted_rewards = tf.stack(self.discount_rewards(tf.unstack(self.greedy_rewards))) # list of max_dec_steps * (batch_size, k)
      elif FLAGS.use_intermediate_rewards and hps.rl_training and hps.mode in ['train', 'eval']:
        
        self.sampling_discounted_rewards = tf.stack(self.intermediate_rewards(tf.unstack(self.sampling_rewards))) # list of max_dec_steps * (batch_size, k)
        self.greedy_discounted_rewards = tf.stack(self.intermediate_rewards(tf.unstack(self.greedy_rewards))) # list of max_dec_steps * (batch_size, k)
      elif hps.ac_training and hps.mode in ['train', 'eval']:
        # Get the sampled and greedy sentence from model output
        self.sampled_sentences = tf.transpose(tf.stack(self.samples), perm=[1,2,0]) # (batch_size, k, <=max_dec_steps) word indices
        self.greedy_search_sentences = tf.transpose(tf.stack(self.greedy_search_samples), perm=[1,2,0]) # (batch_size, k, <=max_dec_steps) word indices

    if hps.mode == "decode":
      # We run decode beam search mode one decoder step at a time
      assert len(self.final_dists)==1 # final_dists is a singleton list containing shape (batch_size, extended_vsize)
      self.final_dists = self.final_dists[0]
      topk_probs, self._topk_ids = tf.nn.top_k(self.final_dists, hps.batch_size*2) # take the k largest probs. note batch_size=beam_size in decode mode
      self._topk_log_probs = tf.log(topk_probs)

  def _add_shared_loss_op(self):
    
    with tf.variable_scope('shared_loss'):
      
      loss_per_step = [] # will be list length max_dec_steps containing shape (batch_size)
      batch_nums = tf.range(0, limit=self._hps.batch_size) # shape (batch_size)
      for dec_step, dist in enumerate(self.final_dists):
        targets = self._target_batch[:,dec_step] # The indices of the target words. shape (batch_size)
        indices = tf.stack( (batch_nums, targets), axis=1) # shape (batch_size, 2)
        gold_probs = tf.gather_nd(dist, indices) # shape (batch_size). prob of correct words on this step
        losses = -tf.log(gold_probs)
        loss_per_step.append(losses)
      self._pgen_loss = _mask_and_avg(loss_per_step, self._dec_padding_mask)
      self.variable_summaries('pgen_loss', self._pgen_loss)
      # Adding Q-Estimation to CE loss in Actor-Critic Model
      if self._hps.ac_training:
        
        loss_per_step = [] # will be list length k each containing a list of shape <=max_dec_steps which each has the shape (batch_size)
        q_loss_per_step = [] # will be list length k each containing a list of shape <=max_dec_steps which each has the shape (batch_size)
        batch_nums = tf.range(0, limit=self._hps.batch_size) # shape (batch_size)
        unstacked_q = tf.unstack(self._q_estimates, axis =1) # list of k with size (batch_size, <=max_dec_steps, vsize_extended)
        for sample_id in range(self._hps.k):
          loss_per_sample = [] # length <=max_dec_steps of batch_sizes
          q_loss_per_sample = [] # length <=max_dec_steps of batch_sizes
          q_val_per_sample = tf.unstack(unstacked_q[sample_id], axis =1) # list of <=max_dec_step (batch_size, vsize_extended)
          for dec_step, (dist, q_value) in enumerate(zip(self.final_dists, q_val_per_sample)):
            targets = tf.squeeze(self.samples[dec_step][:,sample_id]) # The indices of the sampled words. shape (batch_size)
            indices = tf.stack((batch_nums, targets), axis=1) # shape (batch_size, 2)
            gold_probs = tf.gather_nd(dist, indices) # shape (batch_size). prob of correct words on this step
            losses = -tf.log(gold_probs)
            dist_q_val = -tf.log(dist) * q_value
            q_losses = tf.gather_nd(dist_q_val, indices) # shape (batch_size). prob of correct words on this step
            loss_per_sample.append(losses)
            q_loss_per_sample.append(q_losses)
          loss_per_step.append(loss_per_sample)
          q_loss_per_step.append(q_loss_per_sample)
        with tf.variable_scope('reinforce_loss'):
          #### this is the actual loss
          self._rl_avg_logprobs = tf.reduce_mean([_mask_and_avg(loss_per_sample, self._dec_padding_mask) for loss_per_sample in loss_per_step])
          self._rl_loss = tf.reduce_mean([_mask_and_avg(q_loss_per_sample, self._dec_padding_mask) for q_loss_per_sample in q_loss_per_step])
          # Eq. 34 in https://arxiv.org/pdf/1805.09461.pdf
          self._reinforce_shared_loss = self._eta * self._rl_loss + (tf.constant(1.,dtype=tf.float32) - self._eta) * self._pgen_loss # equation 16 in https://arxiv.org/pdf/1705.04304.pdf
          #### the following is only for monitoring purposes
          self.variable_summaries('reinforce_avg_logprobs', self._rl_avg_logprobs)
          self.variable_summaries('reinforce_loss', self._rl_loss)
          self.variable_summaries('reinforce_shared_loss', self._reinforce_shared_loss)

      if self._hps.rl_training:
        loss_per_step = [] # will be list length max_dec_steps*k containing shape (batch_size)
        rl_loss_per_step = [] # will be list length max_dec_steps*k containing shape (batch_size)
        batch_nums = tf.range(0, limit=self._hps.batch_size) # shape (batch_size)
        self._sampled_rouges = []
        self._greedy_rouges = []
        self._reward_diff = []
        for _ in range(self._hps.k):
          if FLAGS.use_discounted_rewards or FLAGS.use_intermediate_rewards:
            self._sampled_rouges.append(self.sampling_discounted_rewards[:, :, _]) # shape (max_enc_steps, batch_size)
            self._greedy_rouges.append(self.greedy_discounted_rewards[:, :, _]) # shape (max_enc_steps, batch_size)
          else:
            # use the reward of last step, since we use the reward of the whole sentence in this case
            self._sampled_rouges.append(self.sampling_rewards[:, _]) # shape (batch_size)
            self._greedy_rouges.append(self.greedy_rewards[:, _]) # shape (batch_size)
          if FLAGS.self_critic:
            self._reward_diff.append(self._greedy_rouges[_]-self._sampled_rouges[_])
          else:
            self._reward_diff.append(self._sampled_rouges[_])
        for dec_step, dist in enumerate(self.final_dists):
          _targets = self.samples[dec_step] # The indices of the sampled words. shape (batch_size, k)
          for _k, targets in enumerate(tf.unstack(_targets,axis=1)): # list of k samples of size (batch_size)
            indices = tf.stack( (batch_nums, targets), axis=1) # shape (batch_size, 2)
            gold_probs = tf.gather_nd(dist, indices) # shape (batch_size). prob of correct words on this step
            losses = -tf.log(gold_probs)
            loss_per_step.append(losses)
            if FLAGS.use_discounted_rewards or FLAGS.use_intermediate_rewards:
              rl_losses = -tf.log(gold_probs) * self._reward_diff[_k][dec_step, :]  # positive values
            else:
              rl_losses = -tf.log(gold_probs) * self._reward_diff[_k] # positive values
            rl_loss_per_step.append(rl_losses)

        # new size: (k, max_dec_steps, batch_size)
        rl_loss_per_step = tf.unstack(
          tf.transpose(tf.reshape(rl_loss_per_step, [-1, self._hps.k, self._hps.batch_size]),perm=[1,0,2]))
        loss_per_step = tf.unstack(
          tf.transpose(tf.reshape(loss_per_step, [-1, self._hps.k, self._hps.batch_size]), perm=[1, 0, 2]))

        if FLAGS.use_intermediate_rewards:
          self._sampled_rouges = tf.reduce_sum(self._sampled_rouges, axis=1) # shape (k, batch_size)
          self._greedy_rouges = tf.reduce_sum(self._greedy_rouges, axis=1) # shape (k, batch_size)
          self._reward_diff = tf.reduce_sum(self._reward_diff, axis=1) # shape (k, batch_size)

        with tf.variable_scope('reinforce_loss'):
          self._rl_avg_logprobs = []
          self._rl_loss = []

          for _k in range(self._hps.k):
            self._rl_avg_logprobs.append(_mask_and_avg(tf.unstack(loss_per_step[_k]), self._dec_padding_mask))
            self._rl_loss.append(_mask_and_avg(tf.unstack(tf.reshape(rl_loss_per_step[_k], [self._hps.max_dec_steps, self._hps.batch_size])), self._dec_padding_mask))

          self._rl_avg_logprobs = tf.reduce_mean(self._rl_avg_logprobs)
          self._rl_loss = tf.reduce_mean(self._rl_loss)
          self._reinforce_shared_loss = self._eta * self._rl_loss + (tf.constant(1.,dtype=tf.float32) - self._eta) * self._pgen_loss
          self.variable_summaries('reinforce_avg_logprobs', self._rl_avg_logprobs)
          self.variable_summaries('reinforce_loss', self._rl_loss)
          self.variable_summaries('reinforce_sampled_r_value', tf.reduce_mean(self._sampled_rouges))
          self.variable_summaries('reinforce_greedy_r_value', tf.reduce_mean(self._greedy_rouges))
          self.variable_summaries('reinforce_r_diff', tf.reduce_mean(self._reward_diff))
          self.variable_summaries('reinforce_shared_loss', self._reinforce_shared_loss)

      if self._hps.coverage:
        with tf.variable_scope('coverage_loss'):
          self._coverage_loss = _coverage_loss(self.attn_dists, self._dec_padding_mask)
          self.variable_summaries('coverage_loss', self._coverage_loss)
        if self._hps.rl_training or self._hps.ac_training:
          with tf.variable_scope('reinforce_loss'):
            self._reinforce_cov_total_loss = self._reinforce_shared_loss + self._hps.cov_loss_wt * self._coverage_loss
            self.variable_summaries('reinforce_coverage_loss', self._reinforce_cov_total_loss)
        if self._hps.pointer_gen:
          self._pointer_cov_total_loss = self._pgen_loss + self._hps.cov_loss_wt * self._coverage_loss
          self.variable_summaries('pointer_coverage_loss', self._pointer_cov_total_loss)

  def _add_shared_train_op(self):
    
    if self._hps.rl_training or self._hps.ac_training:
      loss_to_minimize = self._reinforce_shared_loss
      if self._hps.coverage:
        loss_to_minimize = self._reinforce_cov_total_loss
    else:
      loss_to_minimize = self._pgen_loss
      if self._hps.coverage:
        loss_to_minimize = self._pointer_cov_total_loss

    tvars = tf.trainable_variables()
    gradients = tf.gradients(loss_to_minimize, tvars, aggregation_method=tf.AggregationMethod.EXPERIMENTAL_TREE)

    with tf.device("/gpu:{}".format(self._hps.gpu_num)):
      grads, global_norm = tf.clip_by_global_norm(gradients, self._hps.max_grad_norm)

    tf.summary.scalar('global_norm', global_norm)

    optimizer = tf.train.AdagradOptimizer(self._hps.lr, initial_accumulator_value=self._hps.adagrad_init_acc)
    with tf.device("/gpu:{}".format(self._hps.gpu_num)):
      self._shared_train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=self.global_step, name='train_step')

  def build_graph(self):
    tf.logging.info('Building graph...')
    t0 = time.time()
    self.global_step = tf.Variable(0, name='global_step', trainable=False)
    self._add_placeholders()
    with tf.device("/gpu:{}".format(self._hps.gpu_num)):
      self._add_seq2seq()
      if self._hps.mode in ['train', 'eval']:
        self._add_shared_loss_op()
      if self._hps.mode == 'train':
        self._add_shared_train_op()
      self._summaries = tf.summary.merge_all()
    t1 = time.time()
    tf.logging.info('Time to build graph: %i seconds', t1 - t0)

  def collect_dqn_transitions(self, sess, batch, step, max_art_oovs):
    

    feed_dict = self._make_feed_dict(batch)
    if self._hps.fixed_eta:
      feed_dict[self._eta] = self._hps.eta
    else:
      feed_dict[self._eta] = min(step * self._hps.eta,1.)
    if self._hps.scheduled_sampling:
      if self._hps.fixed_sampling_probability:
        feed_dict[self._sampling_probability] = self._hps.sampling_probability
      else:
        feed_dict[self._sampling_probability] = min(step * self._hps.sampling_probability,1.) # linear decay function
      ranges = [np.exp(float(step) * self._hps.alpha),np.finfo(np.float64).max] # to avoid overflow
      feed_dict[self._alpha] = np.log(ranges[np.argmin(ranges)]) # linear decay function

    vsize_extended = self._vocab.size() + max_art_oovs
    if self._hps.calculate_true_q:
      self.advantages = np.zeros((self._hps.batch_size, self._hps.k, self._hps.max_dec_steps, vsize_extended),dtype=np.float32) # (batch_size, k, <=max_dec_steps,vocab_size)
      self.q_values = np.zeros((self._hps.batch_size, self._hps.k, self._hps.max_dec_steps, vsize_extended),dtype=np.float32) # (batch_size, k, <=max_dec_steps,vocab_size)
      self.r_values = np.zeros((self._hps.batch_size, self._hps.k, self._hps.max_dec_steps, vsize_extended),dtype=np.float32) # (batch_size, k, <=max_dec_steps,vocab_size)
      self.v_values = np.zeros((self._hps.batch_size, self._hps.k, self._hps.max_dec_steps),dtype=np.float32) # (batch_size, k, <=max_dec_steps)
    else:
      self.r_values = np.zeros((self._hps.batch_size, self._hps.k, self._hps.max_dec_steps),dtype=np.float32) # (batch_size, k, <=max_dec_steps)
    to_return = {
      'sampled_sentences': self.sampled_sentences,
      'greedy_search_sentences': self.greedy_search_sentences,
      'decoder_outputs': self.decoder_outputs,
    }
    ret_dict = sess.run(to_return, feed_dict)


    _t = time.time()
    for _n, (sampled_sentence, greedy_search_sentence, target_sentence) in enumerate(zip(ret_dict['sampled_sentences'],ret_dict['greedy_search_sentences'], batch.target_batch)): # run batch_size time
      _gts = target_sentence
      for i in range(self._hps.k):
        _ss = greedy_search_sentence[i] # reward is calculated over the best action through greedy search
        if self._hps.calculate_true_q:
          # Collect Reward, Q, V, and A only when we want to train DDQN using true Q-estimation
          A, Q, V, R = self.caluclate_advantage_function(_ss, _gts, vsize_extended)
          self.r_values[_n,i,:,:] = R
          self.advantages[_n,i,:,:] = A
          self.q_values[_n,i,:,:] = Q
          self.v_values[_n,i,:] = V
        else:
          # if using DDQN estimates, we only need to calculate the reward and later on estimate Q values.
          self.r_values[_n, i, :] = self.caluclate_single_reward(_ss, _gts) # len max_dec_steps
    tf.logging.info('seconds for dqn collection: {}'.format(time.time()-_t))
    trasitions = self.prepare_dqn_transitions(self._hps, ret_dict['decoder_outputs'], ret_dict['greedy_search_sentences'], vsize_extended)
    return trasitions

  def caluclate_advantage_function(self, _ss, _gts, vsize_extended):
    

    R = np.zeros((self._hps.max_dec_steps, vsize_extended)) # shape (max_dec_steps, vocab_size)
    Q = np.zeros((self._hps.max_dec_steps, vsize_extended)) # shape (max_dec_steps, vocab_size)
    for t in range(self._hps.max_dec_steps,0,-1):
      R[t-1][:] = self.reward(t, _ss, _gts, vsize_extended)
      try:
        Q[t-1][:] = R[t-1][:] + self._hps.gamma * Q[t,:].max()
      except:
        Q[t-1][:] = R[t-1][:]

    V = np.reshape(np.mean(Q,axis=1),[-1,1])

    A = Q - V
    return A, Q, np.squeeze(V), R

  def caluclate_single_reward(self, _ss, _gts):
    

    return [self.calc_reward(t, _ss,_gts) for t in range(1,self._hps.max_dec_steps+1)]

  def prepare_dqn_transitions(self, hps, decoder_states, greedy_samples, vsize_extended):
    
    decoder_states = np.transpose(np.stack(decoder_states),[1,0,2]) # now of shape (batch_size, <=max_dec_steps, hidden_dim)
    greedy_samples = np.stack(greedy_samples) # now of shape (batch_size, <=max_dec_steps)

    dec_length = decoder_states.shape[1]
    hidden_dim = decoder_states.shape[-1]

    # modifying decoder state tensor to shape (batch_size, k, <=max_dec_steps, hidden_dim)
    _decoder_states = np.expand_dims(decoder_states, 1)
    _decoder_states = np.concatenate([_decoder_states] * hps.k, axis=1) # shape (batch_size, k, <=max_dec_steps, hidden_dim)
    features = _decoder_states # shape (batch_size, k, <=max_dec_steps, hidden_dim)

    q_func = lambda i,k,t: self.q_values[i,k,t] # (vsize_extended)
    zero_func = lambda i, k, t: np.zeros((vsize_extended))
    raction_func = lambda i,k,t,action: self.r_values[i,k,t,action]
    r_func = lambda i,k,t,action: self.r_values[i,k,t]

    if self._hps.calculate_true_q:
      pass_q_func = q_func
      pass_r_func = raction_func
    else:
      pass_q_func = zero_func
      pass_r_func = r_func

    transitions = [] # (h_t, w_t, h_{t+1}, r_t, q_t, done)
    for i in range(self._hps.batch_size):
      for k in range(self._hps.k):
        for t in range(self._hps.max_dec_steps):
          action = greedy_samples[i,k,t]
          done = (t==(self._hps.max_dec_steps-1) or action==3) # 3 is the id for [STOP] in our vocabularity to stop decoding
          state = features[i, k, t]
          if done:
            state_prime = np.zeros((features.shape[-1]))
            action_prime = 3 # 3 is the id for [STOP] in our vocabularity to stop decoding
          else:
            state_prime = features[i,k,t+1]
            action_prime = greedy_samples[i,k,t+1]
          transitions.append(Transition(state, action, state_prime, action_prime, pass_r_func(i,k,t,action), pass_q_func(i,k,t), done))

    return transitions

  def calc_reward(self, _ss, _gts): # optimizing based on ROUGE-L
    

    summary = ' '.join([str(k) for k in _ss])
    reference = ' '.join([str(k) for k in _gts])
    reward = self.reward_function(reference, summary, self._hps.reward_function)
    return reward

  def reward(self, t, _ss, _gts, vsize_extended): # shape (vocab_size)
    """ A wrapper for calculating the reward. """
    first_case = np.append(_ss[0:t],[0]) # our special character is '[UNK]' which has the id of 0 in our vocabulary
    special_reward = self.calc_reward(first_case, _gts)
    reward = [special_reward for _ in range(vsize_extended)]
    # change the ground-truth reward
    second_case = np.append(_ss[0:t],[_gts[t-1]])
    reward[_gts[t-1]] = self.calc_reward(second_case, _gts)

    return reward

  def run_train_steps(self, sess, batch, step, q_estimates=None):
    
    feed_dict = self._make_feed_dict(batch)
    if self._hps.ac_training or self._hps.rl_training:
      if self._hps.fixed_eta:
        feed_dict[self._eta] = self._hps.eta
      else:
        feed_dict[self._eta] = min(step * self._hps.eta, 1.)
    if self._hps.scheduled_sampling:
      if self._hps.fixed_sampling_probability:
        feed_dict[self._sampling_probability] = self._hps.sampling_probability
      else:
        feed_dict[self._sampling_probability] = min(step * self._hps.sampling_probability, 1.) # linear decay function
      ranges = [np.exp(float(step) * self._hps.alpha), np.finfo(np.float64).max] # to avoid overflow
      feed_dict[self._alpha] = np.log(ranges[np.argmin(ranges)]) # linear decay function
    if self._hps.ac_training:
      self.q_estimates = q_estimates
      feed_dict[self._q_estimates]= self.q_estimates
    to_return = {
        'train_op': self._shared_train_op,
        'summaries': self._summaries,
        'pgen_loss': self._pgen_loss,
        'global_step': self.global_step,
        'decoder_outputs': self.decoder_outputs
    }

    if self._hps.rl_training:
      to_return['sampled_sentence_r_values'] = self._sampled_rouges
      to_return['greedy_sentence_r_values'] = self._greedy_rouges

    if self._hps.coverage:
      to_return['coverage_loss'] = self._coverage_loss
      if self._hps.rl_training or self._hps.ac_training:
        to_return['reinforce_cov_total_loss']= self._reinforce_cov_total_loss
      if self._hps.pointer_gen:
        to_return['pointer_cov_total_loss'] = self._pointer_cov_total_loss
    if self._hps.rl_training or self._hps.ac_training:
      to_return['shared_loss']= self._reinforce_shared_loss
      to_return['rl_loss']= self._rl_loss
      to_return['rl_avg_logprobs']= self._rl_avg_logprobs

    # We feed the collected reward and feed it back to model to update the loss
    return sess.run(to_return, feed_dict)

  def run_eval_step(self, sess, batch, step, q_estimates=None):
    
    feed_dict = self._make_feed_dict(batch)
    if self._hps.ac_training or self._hps.rl_training:
      if self._hps.fixed_eta:
        feed_dict[self._eta] = self._hps.eta
      else:
        feed_dict[self._eta] = min(step * self._hps.eta,1.)
    if self._hps.scheduled_sampling:
      if self._hps.fixed_sampling_probability:
        feed_dict[self._sampling_probability] = self._hps.sampling_probability
      else:
        feed_dict[self._sampling_probability] = min(step * self._hps.sampling_probability,1.) # linear decay function
      ranges = [np.exp(float(step) * self._hps.alpha),np.finfo(np.float64).max] # to avoid overflow
      feed_dict[self._alpha] = np.log(ranges[np.argmin(ranges)]) # linear decay function
    if self._hps.ac_training:
      self.q_estimates = q_estimates
      feed_dict[self._q_estimates]= self.q_estimates
    to_return = {
        'summaries': self._summaries,
        'pgen_loss': self._pgen_loss,
        'global_step': self.global_step,
        'decoder_outputs': self.decoder_outputs
    }

    if self._hps.rl_training:
      to_return['sampled_sentence_r_values'] = self._sampled_rouges
      to_return['greedy_sentence_r_values'] = self._greedy_rouges

    if self._hps.coverage:
      to_return['coverage_loss'] = self._coverage_loss
      if self._hps.rl_training or self._hps.ac_training:
        to_return['reinforce_cov_total_loss']= self._reinforce_cov_total_loss
      if self._hps.pointer_gen:
        to_return['pointer_cov_total_loss'] = self._pointer_cov_total_loss
    if self._hps.rl_training or self._hps.ac_training:
      to_return['shared_loss']= self._reinforce_shared_loss
      to_return['rl_loss']= self._rl_loss
      to_return['rl_avg_logprobs']= self._rl_avg_logprobs

    # We feed the collected reward and feed it back to model to update the loss
    return sess.run(to_return, feed_dict)

  def run_encoder(self, sess, batch):
    
    feed_dict = self._make_feed_dict(batch, just_enc=True) # feed the batch into the placeholders
    (enc_states, dec_in_state, global_step) = sess.run([self._enc_states, self._dec_in_state, self.global_step], feed_dict) # run the encoder

    # dec_in_state is LSTMStateTuple shape ([batch_size,hidden_dim],[batch_size,hidden_dim])
    # Given that the batch is a single example repeated, dec_in_state is identical across the batch so we just take the top row.
    dec_in_state = tf.contrib.rnn.LSTMStateTuple(dec_in_state.c[0], dec_in_state.h[0])
    return enc_states, dec_in_state

  def decode_onestep(self, sess, batch, latest_tokens, enc_states, dec_init_states, prev_coverage, prev_decoder_outputs, prev_encoder_es):
    

    beam_size = len(dec_init_states)

    # Turn dec_init_states (a list of LSTMStateTuples) into a single LSTMStateTuple for the batch
    cells = [np.expand_dims(state.c, axis=0) for state in dec_init_states]
    hiddens = [np.expand_dims(state.h, axis=0) for state in dec_init_states]
    new_c = np.concatenate(cells, axis=0)  # shape [batch_size,hidden_dim]
    new_h = np.concatenate(hiddens, axis=0)  # shape [batch_size,hidden_dim]
    new_dec_in_state = tf.contrib.rnn.LSTMStateTuple(new_c, new_h)

    feed = {
        self._enc_states: enc_states,
        self._enc_padding_mask: batch.enc_padding_mask,
        self._dec_in_state: new_dec_in_state,
        self._dec_batch: np.transpose(np.array([latest_tokens])),
        self._dec_padding_mask: np.ones((beam_size,1),dtype=np.float32)
    }

    to_return = {
      "ids": self._topk_ids,
      "probs": self._topk_log_probs,
      "states": self._dec_out_state,
      "attn_dists": self.attn_dists,
      "final_dists": self.final_dists
    }

    if FLAGS.pointer_gen:
      feed[self._enc_batch_extend_vocab] = batch.enc_batch_extend_vocab
      feed[self._max_art_oovs] = batch.max_art_oovs
      to_return['p_gens'] = self.p_gens

    if self._hps.coverage:
      feed[self.prev_coverage] = np.stack(prev_coverage, axis=0)
      to_return['coverage'] = self.coverage

    if FLAGS.ac_training or FLAGS.intradecoder:
      to_return['output']=self.decoder_outputs
    if FLAGS.intradecoder:
      feed[self.prev_decoder_outputs]= prev_decoder_outputs
    if FLAGS.use_temporal_attention:
      to_return['temporal_e'] = self.temporal_es
      feed[self.prev_encoder_es] = prev_encoder_es

    results = sess.run(to_return, feed_dict=feed) # run the decoder step

    new_states = [tf.contrib.rnn.LSTMStateTuple(results['states'].c[i, :], results['states'].h[i, :]) for i in range(beam_size)]

    assert len(results['attn_dists'])==1
    attn_dists = results['attn_dists'][0].tolist()
    final_dists = results['final_dists'][0].tolist()

    if FLAGS.pointer_gen:
      assert len(results['p_gens'])==1
      p_gens = results['p_gens'][0].tolist()
    else:
      p_gens = [None for _ in range(beam_size)]

    if FLAGS.ac_training or FLAGS.intradecoder:
      output = results['output'][0] # used for calculating the intradecoder at later steps and for calcualting q-estimate in Actor-Critic training.
    else:
      output = None
    if FLAGS.use_temporal_attention:
      temporal_e = results['temporal_e'][0] # used for calculating the attention at later steps
    else:
      temporal_e = None

    if FLAGS.coverage:
      new_coverage = results['coverage'].tolist()
      assert len(new_coverage) == beam_size
    else:
      new_coverage = [None for _ in range(beam_size)]

    return results['ids'], results['probs'], new_states, attn_dists, final_dists, p_gens, new_coverage, output, temporal_e

def _mask_and_avg(values, padding_mask):
  

  dec_lens = tf.reduce_sum(padding_mask, axis=1) # shape batch_size. float32
  values_per_step = [v * padding_mask[:,dec_step] for dec_step,v in enumerate(values)] # list of k
  values_per_ex = sum(values_per_step)/dec_lens # shape (batch_size); normalized value for each batch member
  return tf.reduce_mean(values_per_ex) # overall average

def _coverage_loss(attn_dists, padding_mask):
  
  coverage = tf.zeros_like(attn_dists[0]) # shape (batch_size, attn_length). Initial coverage is zero.
  covlosses = [] # Coverage loss per decoder timestep. Will be list length max_dec_steps containing shape (batch_size).
  for a in attn_dists:
    covloss = tf.reduce_sum(tf.minimum(a, coverage), [1]) # calculate the coverage loss for this step
    covlosses.append(covloss)
    coverage += a # update the coverage vector
  coverage_loss = _mask_and_avg(covlosses, padding_mask)
  return coverage_loss

In [8]:
try:
  import queue
except:
  import Queue as queue
from random import shuffle
from random import seed
seed(123)
from threading import Thread
import numpy as np
import time
import tensorflow as tf
try:
  import Queue as Q  # ver. < 3.0
except ImportError:
  import queue as Q
from sklearn.preprocessing import normalize

PriorityQueue = Q.PriorityQueue


class CustomQueue(PriorityQueue):
  
  def __init__(self, size):
    PriorityQueue.__init__(self, size)

  def clear(self):
    
    with self.mutex:
      unfinished = self.unfinished_tasks - len(self.queue)
      if unfinished <= 0:
        if unfinished < 0:
          raise ValueError('task_done() called too many times')
        self.all_tasks_done.notify_all()
      self.queue = self.queue[0:len(self.queue)/2]
      self.unfinished_tasks = unfinished + len(self.queue)
      self.not_full.notify_all()

  def isempty(self):
    with self.mutex:
      return len(self.queue) == 0

  def isfull(self):
    with self.mutex:
      return len(self.queue) == self.maxsize

class Transition(object):
  
  def __init__(self, state, action, state_prime, action_prime, reward, q_value, done):
    
    self.state = state # size: dqn_input_feature_len
    self.action = action # size: 1
    self.state_prime = state_prime # size: dqn_input_feature_len
    self.action_prime = action_prime
    self.reward = reward # size: vocab_size
    self.q_value = q_value # size: vocab_size
    self.done = done # true/false

  def __cmp__(self, item):
    
    return cmp(item.reward, self.reward) # bigger numbers have more priority

class ReplayBatch(object):
  

  def __init__(self, hps, example_list, dqn_batch_size, use_state_prime = False, max_art_oovs = 0):
    
    self._x = np.zeros((dqn_batch_size, hps.dqn_input_feature_len))
    self._y = np.zeros((dqn_batch_size, hps.vocab_size))
    self._y_extended = np.zeros((dqn_batch_size, hps.vocab_size + max_art_oovs))
    for i,e in enumerate(example_list):
      if use_state_prime:
        self._x[i,:]=e.state_prime
      else:
        self._x[i,:]=e.state
        self._y[i,:]=normalize([e.q_value[0:hps.vocab_size]], axis=1, norm='l1')
      if max_art_oovs == 0:
        self._y_extended[i,:] = normalize([e.q_value[0:hps.vocab_size]], axis=1, norm='l1')
      else:
        self._y_extended[i,:] = e.q_value

class ReplayBuffer(object):
  

  BATCH_QUEUE_MAX = 100 # max number of batches the batch_queue can hold

  def __init__(self, hps):
    self._hps = hps
    self._buffer = CustomQueue(self._hps.dqn_replay_buffer_size)

    self._batch_queue = queue.Queue(self.BATCH_QUEUE_MAX)
    self._example_queue = queue.Queue(self.BATCH_QUEUE_MAX * self._hps.dqn_batch_size)
    self._num_example_q_threads = 1 # num threads to fill example queue
    self._num_batch_q_threads = 1  # num threads to fill batch queue
    self._bucketing_cache_size = 100 # how many batches-worth of examples to load into cache before bucketing

    
    self._example_q_threads = []
    for _ in range(self._num_example_q_threads):
      self._example_q_threads.append(Thread(target=self.fill_example_queue))
      self._example_q_threads[-1].daemon = True
      self._example_q_threads[-1].start()
    self._batch_q_threads = []
    for _ in range(self._num_batch_q_threads):
      self._batch_q_threads.append(Thread(target=self.fill_batch_queue))
      self._batch_q_threads[-1].daemon = True
      self._batch_q_threads[-1].start()

    
    self._watch_thread = Thread(target=self.watch_threads)
    self._watch_thread.daemon = True
    self._watch_thread.start()

  def next_batch(self):
    
    if self._batch_queue.qsize() == 0:
      tf.logging.warning('Bucket input queue is empty when calling next_batch. Bucket queue size: %i, Input queue size: %i', self._batch_queue.qsize(), self._example_queue.qsize())
      return None

    batch = self._batch_queue.get() # get the next Batch
    return batch

  def create_batch(_hps, batch, batch_size, use_state_prime=False, max_art_oovs=0):
    

    return ReplayBatch(_hps, batch, batch_size, use_state_prime, max_art_oovs)

  def fill_example_queue(self):
    
    while True:
      try:
        input_gen = self._example_generator().next()
      except StopIteration: # if there are no more examples:
        tf.logging.info("The example generator for this example queue filling thread has exhausted data.")
        raise Exception("single_pass mode is off but the example generator is out of data; error.")
      self._example_queue.put(input_gen) # place the pair in the example queue.

  def fill_batch_queue(self):
    while True:
      # Get bucketing_cache_size-many batches of Examples into a list, then sort
      inputs = []
      for _ in range(self._hps.dqn_batch_size * self._bucketing_cache_size):
        inputs.append(self._example_queue.get())

      self.add(inputs)

      batches = []
      for i in range(0, len(inputs), self._hps.dqn_batch_size):
        batches.append(inputs[i:i + self._hps.dqn_batch_size])
        shuffle(batches)
      for b in batches:  # each b is a list of Example objects
        self._batch_queue.put(ReplayBatch(self._hps, b, self._hps.dqn_batch_size))

  def watch_threads(self):
    while True:
      time.sleep(60)
      for idx,t in enumerate(self._example_q_threads):
        if not t.is_alive(): # if the thread is dead
          tf.logging.error('Found example queue thread dead. Restarting.')
          new_t = Thread(target=self.fill_example_queue)
          self._example_q_threads[idx] = new_t
          new_t.daemon = True
          new_t.start()
      for idx,t in enumerate(self._batch_q_threads):
        if not t.is_alive(): # if the thread is dead
          tf.logging.error('Found batch queue thread dead. Restarting.')
          new_t = Thread(target=self.fill_batch_queue)
          self._batch_q_threads[idx] = new_t
          new_t.daemon = True
          new_t.start()

  def add(self, items):
    
    for item in items:
      if not self._buffer.isfull():
        self._buffer.put_nowait(item)
      else:
        print('Replay Buffer is full, getting rid of unimportant transitions...')
        self._buffer.clear()
        self._buffer.put_nowait(item)
    print('ReplayBatch size: {}'.format(self._buffer.qsize()))
    print('ReplayBatch example queue size: {}'.format(self._example_queue.qsize()))
    print('ReplayBatch batch queue size: {}'.format(self._batch_queue.qsize()))

  def _buffer_len(self):
    return self._buffer.qsize()

  def _example_generator(self):
    while True:
      if not self._buffer.isempty():
        item = self._buffer.get_nowait()
        self._buffer.task_done()
        yield item

In [9]:
import tensorflow as tf
import numpy as np
#import data
#from replay_buffer import Transition, ReplayBuffer
from collections import Counter
from sklearn.preprocessing import normalize

FLAGS = tf.app.flags.FLAGS

class Hypothesis(object):
  

  def __init__(self, tokens, log_probs, state, decoder_output, encoder_mask, attn_dists, p_gens, coverage):
    
    self.tokens = tokens
    self.log_probs = log_probs
    self.state = state
    self.decoder_output = decoder_output
    self.encoder_mask = encoder_mask
    self.attn_dists = attn_dists
    self.p_gens = p_gens
    self.coverage = coverage

  def extend(self, token, log_prob, state, decoder_output, encoder_mask, attn_dist, p_gen, coverage):
    
    if FLAGS.avoid_trigrams and self._has_trigram(self.tokens + [token]):
        log_prob = -np.infty
    return Hypothesis(tokens = self.tokens + [token],
                      log_probs = self.log_probs + [log_prob],
                      state = state,
                      decoder_output= self.decoder_output + [decoder_output] if decoder_output is not None else [],
                      encoder_mask = self.encoder_mask + [encoder_mask] if encoder_mask is not None else [],
                      attn_dists = self.attn_dists + [attn_dist],
                      p_gens = self.p_gens + [p_gen],
                      coverage = coverage)

  def _find_ngrams(self, input_list, n):
      return zip(*[input_list[i:] for i in range(n)])

  def _has_trigram(self, tokens):
      tri_grams = self._find_ngrams(tokens, 3)
      cnt = Counter(tri_grams)
      return not all((cnt[g] == 1 for g in cnt))

  def latest_token(self):
    return self.tokens[-1]

  def log_prob(self):
    # the log probability of the hypothesis so far is the sum of the log probabilities of the tokens so far
    return sum(self.log_probs)

  def avg_log_prob(self):
    return self.log_prob / len(self.tokens)


def run_beam_search(sess, model, vocab, batch, dqn = None, dqn_sess = None, dqn_graph = None):
  
  enc_states, dec_in_state = model.run_encoder(sess, batch)
  
  hyps = [Hypothesis(tokens=[vocab.word2id(START_DECODING)],
                     log_probs=[0.0],
                     state=dec_in_state,
                     decoder_output = [np.zeros([FLAGS.dec_hidden_dim])],
                     encoder_mask = [np.zeros([batch.enc_batch.shape[1]])],
                     attn_dists=[],
                     p_gens=[],
                     coverage=np.zeros([batch.enc_batch.shape[1]]) # zero vector of length attention_length
                     ) for _ in range(FLAGS.beam_size)]
  results = [] # this will contain finished hypotheses (those that have emitted the [STOP] token)

  steps = 0
  while steps < FLAGS.max_dec_steps and len(results) < FLAGS.beam_size:
    latest_tokens = [h.latest_token for h in hyps] # latest token produced by each hypothesis
    latest_tokens = [t if t in range(vocab.size()) else vocab.word2id(UNKNOWN_TOKEN) for t in latest_tokens] # change any in-article temporary OOV ids to [UNK] id, so that we can lookup word embeddings
    states = [h.state for h in hyps] # list of current decoder states of the hypotheses
    prev_coverage = [h.coverage for h in hyps] # list of coverage vectors (or None)
    decoder_outputs = np.array([h.decoder_output for h in hyps]).swapaxes(0, 1) # shape (?, batch_size, dec_hidden_dim)
    encoder_es = np.array([h.encoder_mask for h in hyps]).swapaxes(0, 1)  # shape (?, batch_size, enc_hidden_dim)
    # Run one step of the decoder to get the new info
    (topk_ids, topk_log_probs, new_states, attn_dists, final_dists, p_gens, new_coverage, decoder_output, encoder_e) = model.decode_onestep(sess=sess,
                        batch=batch,
                        latest_tokens=latest_tokens,
                        enc_states=enc_states,
                        dec_init_states=states,
                        prev_coverage=prev_coverage,
                        prev_decoder_outputs= decoder_outputs if FLAGS.intradecoder else tf.stack([], axis=0),
                        prev_encoder_es = encoder_es if FLAGS.use_temporal_attention else tf.stack([], axis=0))

    if FLAGS.ac_training:
      with dqn_graph.as_default():
        dqn_results = dqn.run_test_steps(dqn_sess, x=decoder_output)
        q_estimates = dqn_results['estimates'] # shape (len(transitions), vocab_size)
        # we use the q_estimate of UNK token for all the OOV tokens
        q_estimates = np.concatenate([q_estimates,np.reshape(q_estimates[:,0],[-1,1])*np.ones((FLAGS.beam_size,batch.max_art_oovs))],axis=-1)
        # normalized q_estimate
        q_estimates = normalize(q_estimates, axis=1, norm='l1')
        combined_estimates = final_dists * q_estimates
        combined_estimates = normalize(combined_estimates, axis=1, norm='l1')
        # overwriting topk ids and probs
        topk_ids = np.argsort(combined_estimates,axis=-1)[:,-FLAGS.beam_size*2:][:,::-1]
        topk_probs = [combined_estimates[i,_] for i,_ in enumerate(topk_ids)]
        topk_log_probs = np.log(topk_probs)

    # Extend each hypothesis and collect them all in all_hyps
    all_hyps = []
    num_orig_hyps = 1 if steps == 0 else len(hyps) # On the first step, we only had one original hypothesis (the initial hypothesis). On subsequent steps, all original hypotheses are distinct.
    for i in range(num_orig_hyps):
      h, new_state, attn_dist, p_gen, new_coverage_i = hyps[i], new_states[i], attn_dists[i], p_gens[i], new_coverage[i]  # take the ith hypothesis and new decoder state info
      decoder_output_i = None
      encoder_mask_i = None
      if FLAGS.intradecoder:
        decoder_output_i = decoder_output[i]
      if FLAGS.use_temporal_attention:
        encoder_mask_i = encoder_e[i]
      for j in range(FLAGS.beam_size * 2):  # for each of the top 2*beam_size hyps:
        new_hyp = h.extend(token=topk_ids[i, j],
                           log_prob=topk_log_probs[i, j],
                           state=new_state,
                           decoder_output = decoder_output_i,
                           encoder_mask = encoder_mask_i,
                           attn_dist=attn_dist,
                           p_gen=p_gen,
                           coverage=new_coverage_i)
        all_hyps.append(new_hyp)

    # Filter and collect any hypotheses that have produced the end token.
    hyps = [] # will contain hypotheses for the next step
    for h in sort_hyps(all_hyps): # in order of most likely h
      if h.latest_token == vocab.word2id(STOP_DECODING): # if stop token is reached...
        # If this hypothesis is sufficiently long, put in results. Otherwise discard.
        if steps >= FLAGS.min_dec_steps:
          results.append(h)
      else: # hasn't reached stop token, so continue to extend this hypothesis
        hyps.append(h)
      if len(hyps) == FLAGS.beam_size or len(results) == FLAGS.beam_size:
        break

    steps += 1

  

  if len(results)==0: # if we don't have any complete results, add all current hypotheses (incomplete summaries) to results
    results = hyps

  
  hyps_sorted = sort_hyps(results)

  
  return hyps_sorted[0]

def sort_hyps(hyps):
  
  return sorted(hyps, key=lambda h: h.avg_log_prob, reverse=True)

In [10]:
import tensorflow as tf
import time
import os
FLAGS = tf.app.flags.FLAGS

def get_config():
  
  config = tf.ConfigProto(allow_soft_placement=True)
  #config = tf.ConfigProto(log_device_placement=True)
  config.gpu_options.allow_growth=True
  return config

def load_ckpt(saver, sess, ckpt_dir="train"):
  
  while True:
    try:
      latest_filename = "checkpoint_best" if ckpt_dir=="eval" else None
      ckpt_dir = os.path.join(FLAGS.log_root, ckpt_dir)
      ckpt_state = tf.train.get_checkpoint_state(ckpt_dir, latest_filename=latest_filename)
      tf.logging.info('Loading checkpoint %s', ckpt_state.model_checkpoint_path)
      saver.restore(sess, ckpt_state.model_checkpoint_path)
      return ckpt_state.model_checkpoint_path
    except:
      tf.logging.info("Failed to load checkpoint from %s. Sleeping for %i secs...", ckpt_dir, 10)
      time.sleep(10)

def load_dqn_ckpt(saver, sess):
  
  while True:
    try:
      ckpt_dir = os.path.join(FLAGS.log_root, "dqn", "train")
      ckpt_state = tf.train.get_checkpoint_state(ckpt_dir)
      tf.logging.info('Loading checkpoint %s', ckpt_state.model_checkpoint_path)
      saver.restore(sess, ckpt_state.model_checkpoint_path)
      return ckpt_state.model_checkpoint_path
    except:
      tf.logging.info("Failed to load checkpoint from %s. Sleeping for %i secs...", ckpt_dir, 10)
      time.sleep(10)

In [11]:
import os
import time
import tensorflow as tf
#import beam_search
#import data
import json
import pyrouge
#import util
import logging
#from unidecode import unidecode

FLAGS = tf.app.flags.FLAGS

SECS_UNTIL_NEW_CKPT = 60  # max number of seconds before loading new checkpoint

article = []
reference = []
summary = []

class BeamSearchDecoder(object):
  

  def __init__(self, model, batcher, vocab, dqn = None):
    
    self._model = model
    self._model.build_graph()
    self._batcher = batcher
    self._vocab = vocab
    self._saver = tf.train.Saver() # we use this to load checkpoints for decoding
    self._sess = tf.Session(config=get_config())

    if FLAGS.ac_training:
      self._dqn = dqn
      self._dqn_graph = tf.Graph()
      with self._dqn_graph.as_default():
        self._dqn.build_graph()
        self._dqn_saver = tf.train.Saver() # we use this to load checkpoints for decoding
        self._dqn_sess = tf.Session(config=get_config())
        _ = load_dqn_ckpt(self._dqn_saver, self._dqn_sess)

    
    ckpt_path = load_ckpt(self._saver, self._sess, FLAGS.decode_from)

    if FLAGS.single_pass:
      
      ckpt_name = "{}-ckpt-".format(FLAGS.decode_from) + ckpt_path.split('-')[
        -1]  # this is something of the form "ckpt-123456"
      self._decode_dir = os.path.join(FLAGS.log_root, get_decode_dir_name(ckpt_name))
    else: 
      self._decode_dir = os.path.join(FLAGS.log_root, "decode")

    
    if not os.path.exists(self._decode_dir): os.mkdir(self._decode_dir)

    if FLAGS.single_pass:
      
      self._rouge_ref_dir = os.path.join(self._decode_dir, "reference")
      if not os.path.exists(self._rouge_ref_dir): os.mkdir(self._rouge_ref_dir)
      self._rouge_dec_dir = os.path.join(self._decode_dir, "decoded")
      if not os.path.exists(self._rouge_dec_dir): os.mkdir(self._rouge_dec_dir)

  def decode(self):
    
    t0 = time.time()
    counter = FLAGS.decode_after
    while True:
      tf.reset_default_graph()
      batch = self._batcher.next_batch()  # 1 example repeated across batch
      if batch is None: # finished decoding dataset in single_pass mode
        assert FLAGS.single_pass, "Dataset exhausted, but we are not in single_pass mode"
        tf.logging.info("Decoder has finished reading dataset for single_pass.")
        tf.logging.info("Output has been saved in %s and %s. Now starting ROUGE eval...", self._rouge_ref_dir, self._rouge_dec_dir)
        results_dict = rouge_eval(self._rouge_ref_dir, self._rouge_dec_dir)
        rouge_log(results_dict, self._decode_dir)
        return

      original_article = batch.original_articles[0]  # string
      original_abstract = batch.original_abstracts[0]  # string
      original_abstract_sents = batch.original_abstracts_sents[0]  # list of strings

      article_withunks = show_art_oovs(original_article, self._vocab) # string
      abstract_withunks = show_abs_oovs(original_abstract, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None)) # string

      
      if FLAGS.ac_training:
        best_hyp = run_beam_search(self._sess, self._model, self._vocab, batch, self._dqn, self._dqn_sess, self._dqn_graph)
      else:
        best_hyp = run_beam_search(self._sess, self._model, self._vocab, batch)
      # Extract the output ids from the hypothesis and convert back to words
      output_ids = [int(t) for t in best_hyp.tokens[1:]]
      decoded_words = outputids2words(output_ids, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None))

      try:
        fst_stop_idx = decoded_words.index(STOP_DECODING) # index of the (first) [STOP] symbol
        decoded_words = decoded_words[:fst_stop_idx]
      except ValueError:
        decoded_words = decoded_words
      decoded_output = ' '.join(decoded_words) # single string

      if FLAGS.single_pass:#me when 
        self.write_for_rouge(original_abstract_sents, decoded_words, counter) # write ref summary and decoded summary to file, to eval with pyrouge later
        
        article.append(article_withunks)
        reference.append(abstract_withunks)
        summary.append(decoded_output)
        
        counter += 1 # this is how many examples we've decoded
        if counter == 10 :
            tf.logging.info("Counter 10 stopped.")
            return
      else:
        print_results(article_withunks, abstract_withunks, decoded_output) # log output to screen
        self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.p_gens) # write info to .json file for visualization tool

        t1 = time.time()
        if t1-t0 > SECS_UNTIL_NEW_CKPT:
          tf.logging.info('We\'ve been decoding with same checkpoint for %i seconds. Time to load new checkpoint', t1-t0)
          _ = load_ckpt(self._saver, self._sess, FLAGS.decode_from)
          t0 = time.time()

  def remove_non_ascii(self, text):
    
    return text #str(unidecode(text))

  def write_for_rouge(self, reference_sents, decoded_words, ex_index):
    
    decoded_sents = []
    while len(decoded_words) > 0:
      try:
        fst_period_idx = decoded_words.index(".")
      except ValueError: # there is text remaining that doesn't end in "."
        fst_period_idx = len(decoded_words)
      sent = decoded_words[:fst_period_idx+1] # sentence up to and including the period
      decoded_words = decoded_words[fst_period_idx+1:] # everything else
      decoded_sents.append(' '.join(sent))

    decoded_sents = [self.remove_non_ascii(make_html_safe(w)) for w in decoded_sents]
    reference_sents = [self.remove_non_ascii(make_html_safe(w)) for w in reference_sents]

    # Write to file
    ref_file = os.path.join(self._rouge_ref_dir, "%06d_reference.txt" % ex_index)
    decoded_file = os.path.join(self._rouge_dec_dir, "%06d_decoded.txt" % ex_index)

    with open(ref_file, "w",encoding='utf-8') as f:
      for idx,sent in enumerate(reference_sents):
        f.write(sent) if idx==len(reference_sents)-1 else f.write(sent+"\n")
    with open(decoded_file, "w",encoding='utf-8') as f:
      for idx,sent in enumerate(decoded_sents):
        f.write(sent) if idx==len(decoded_sents)-1 else f.write(sent+"\n")

    tf.logging.info("Wrote example %i to file" % ex_index)

  def write_for_attnvis(self, article, abstract, decoded_words, attn_dists, p_gens):
    
    article_lst = article.split() # list of words
    decoded_lst = decoded_words # list of decoded words
    to_write = {
        'article_lst': [make_html_safe(t) for t in article_lst],
        'decoded_lst': [make_html_safe(t) for t in decoded_lst],
        'abstract_str': make_html_safe(abstract),
        'attn_dists': attn_dists
    }
    if FLAGS.pointer_gen:
      to_write['p_gens'] = p_gens
    output_fname = os.path.join(self._decode_dir, 'attn_vis_data.json')
    with open(output_fname, 'w') as output_file:
      json.dump(to_write, output_file)
    tf.logging.info('Wrote visualization data to %s', output_fname)


def print_results(article, abstract, decoded_output):
  
  print("")
  tf.logging.info('ARTICLE:  %s', article)
  tf.logging.info('REFERENCE SUMMARY: %s', abstract)
  tf.logging.info('GENERATED SUMMARY: %s', decoded_output)
  print("")


def make_html_safe(s):
  try:
    
    s.replace("<", "&lt;")
    s.replace(">", "&gt;")
  except:
    pass
  return s


def rouge_eval(ref_dir, dec_dir):
  
  r = pyrouge.Rouge155()
  r.model_filename_pattern = '#ID#_reference.txt'
  r.system_filename_pattern = '(\d+)_decoded.txt'
  r.model_dir = ref_dir
  r.system_dir = dec_dir
  logging.getLogger('global').setLevel(logging.WARNING) # silence pyrouge logging
  rouge_results = r.convert_and_evaluate()
  return r.output_to_dict(rouge_results)


def rouge_log(results_dict, dir_to_write):
  
  log_str = ""
  for x in ["1","2","l"]:
    log_str += "\nROUGE-%s:\n" % x
    for y in ["f_score", "recall", "precision"]:
      key = "rouge_%s_%s" % (x,y)
      key_cb = key + "_cb"
      key_ce = key + "_ce"
      val = results_dict[key]
      val_cb = results_dict[key_cb]
      val_ce = results_dict[key_ce]
      log_str += "%s: %.4f with confidence interval (%.4f, %.4f)\n" % (key, val, val_cb, val_ce)
  tf.logging.info(log_str) # log to screen
  results_file = os.path.join(dir_to_write, "ROUGE_results.txt")
  tf.logging.info("Writing final ROUGE results to %s...", results_file)
  with open(results_file, "w") as f:
    f.write(log_str)

def get_decode_dir_name(ckpt_name):
  
  if "train" in FLAGS.data_path: dataset = "train"
  elif "val" in FLAGS.data_path: dataset = "val"
  elif "test" in FLAGS.data_path: dataset = "test"
  else: raise ValueError("FLAGS.data_path %s should contain one of train, val or test" % (FLAGS.data_path))
  dirname = "decode_%s_%s_%imaxenc_%ibeam_%imindec_%imaxdec" % (dataset, FLAGS.decode_from, FLAGS.max_enc_steps, FLAGS.beam_size, FLAGS.min_dec_steps, FLAGS.max_dec_steps)
  if ckpt_name is not None:
    dirname += "_%s" % ckpt_name
  return dirname

In [12]:
import tensorflow as tf
#import tensorlayer as tl
import numpy as np

class DQN(object):
    def __init__(self, hps, name_variable):
        self._hps = hps
        self._name_variable = name_variable

    def variable_summaries(self, var_name, var):
        
        with tf.name_scope('summaries_{}'.format(var_name)):
            mean = tf.reduce_mean(var)
            tf.summary.scalar('mean', mean)
            with tf.name_scope('stddev'):
                stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
            tf.summary.scalar('stddev', stddev)
            tf.summary.scalar('max', tf.reduce_max(var))
            tf.summary.scalar('min', tf.reduce_min(var))
            tf.summary.histogram('histogram', var)

    def _add_placeholders(self):
        
        self._x = tf.placeholder(tf.float32, [None, self._hps.dqn_input_feature_len], name='x') # size (dataset_len, input_feature_len)
        self._y = tf.placeholder(tf.float32, [None, self._hps.vocab_size], name='y') # size (dataset_len, 1)
        self._train_step = tf.placeholder(tf.int32, None,name='train_step')

    def _make_feed_dict(self, batch):
        feed_dict = {}
        feed_dict[self._x] = batch._x
        feed_dict[self._y] = batch._y
        return feed_dict

    def _add_tf_layers(self):
        

        h = tf.layers.dense(self._x, units = self._hps.dqn_input_feature_len, activation=tf.nn.relu, name='{}_input_layer'.format(self._name_variable))
        for i, layer in enumerate(self._hps.dqn_layers.split(',')):
            h = tf.layers.dense(h, units = int(layer), activation = tf.nn.relu, name='{}_h_{}'.format(self._name_variable, i))

        self.advantage_layer = tf.layers.dense(h, units = self._hps.vocab_size, activation = tf.nn.softmax, name='{}_advantage'.format(self._name_variable))
        if self._hps.dueling_net:
            # in dueling net, we have two extra output layers; one for value function estimation
            # and the other for advantage estimation, we then use the difference between these two layers
            # to calculate the q-estimation
            self_layer = tf.layers.dense(h, units = 1, activation = tf.identity, name='{}_value'.format(self._name_variable))
            normalized_al = self.advantage_layer-tf.reshape(tf.reduce_mean(self.advantage_layer,axis=1),[-1,1]) # equation 9 in https://arxiv.org/pdf/1511.06581.pdf
            value_extended = tf.concat([self_layer] * self._hps.vocab_size, axis=1)
            self.output = value_extended + normalized_al
        else:
            self.output = self.advantage_layer

    def _add_train_op(self):
        # In regression, the objective loss is Mean Squared Error (MSE).
        self.loss = tf.losses.mean_squared_error(labels = self._y, predictions = self.output)

        tvars = tf.trainable_variables()
        gradients = tf.gradients(self.loss, tvars, aggregation_method=tf.AggregationMethod.EXPERIMENTAL_TREE)

        # Clip the gradients
        with tf.device("/gpu:{}".format(self._hps.dqn_gpu_num)):
            grads, global_norm = tf.clip_by_global_norm(gradients, self._hps.max_grad_norm)

        # Add a summary
        tf.summary.scalar('global_norm', global_norm)

        # Apply adagrad optimizer
        optimizer = tf.train.AdamOptimizer(self._hps.lr)
        with tf.device("/gpu:{}".format(self._hps.dqn_gpu_num)):
            self.train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=self.global_step, name='train_step')

        self.variable_summaries('dqn_loss',self.loss)

    def _add_update_weights_op(self):
        
        self.model_trainables = tf.trainable_variables(scope='{}_relay_network'.format(self._name_variable)) # target variables
        self._new_trainables = [tf.placeholder(tf.float32, None,name='trainables_{}'.format(i)) for i in range(len(self.model_trainables))]
        self.assign_ops = []
        if self._hps.dqn_polyak_averaging: # target parameters are slowly updating using: \phi_target = \tau * \phi_target + (1-\tau) * \phi_target
            tau = (tf.cast(self._train_step,tf.float32) % self._hps.dqn_target_update)/float(self._hps.dqn_target_update)
            for i, mt in enumerate(self.model_trainables):
                nt = self._new_trainables[i]
                self.assign_ops.append(mt.assign(tau * mt + (1-tau) * nt))
        else:
          if self._train_step % self._hps.dqn_target_update == 0:
            for i, mt in enumerate(self.model_trainables):
                nt = self._new_trainables[i]
                self.assign_ops.append(mt.assign(nt))

    def build_graph(self):
        with tf.variable_scope('{}_relay_network'.format(self._name_variable)), tf.device("/gpu:{}".format(self._hps.dqn_gpu_num)):
            self.global_step = tf.Variable(0, name='global_step', trainable=False)
            self._add_placeholders()
            self._add_tf_layers()
            self._add_train_op()
            self._add_update_weights_op()
            self._summaries = tf.summary.merge_all()

    def run_train_steps(self, sess, batch):
        feed_dict = self._make_feed_dict(batch)
        to_return = {'train_op': self.train_op,
        'summaries': self._summaries,
        'loss': self.loss,
        'global_step': self.global_step}
        return sess.run(to_return, feed_dict)

    def run_test_steps(self, sess, x, y=None, return_loss=False, return_best_action=False):
        # when return_loss is True, the model will return the loss of the prediction
        # return_loss should be False, during estimation (decoding)
        feed_dict = {self._x:x}
        to_return = {'estimates': self.output}
        if return_loss:
            feed_dict.update({self._y:y})
            to_return.update({'loss': self.loss})
        output = sess.run(to_return, feed_dict)
        if return_best_action:
            output['best_action']=np.argmax(output['estimates'],axis=1)

        return output

    def run_update_weights(self, sess, train_step, weights):
        feed_dict = {self._train_step:train_step}
        for i, w in enumerate(weights):
            feed_dict.update({self._new_trainables[i]:w})
        _ = sess.run(self.assign_ops, feed_dict)

In [13]:
from xml.etree import ElementTree
from xml.dom import minidom
from functools import reduce

def prettify(elem):
    
    rough_string = ElementTree.tostring(elem, 'utf-8')
    reparsed = minidom.parseString(rough_string)
    return reparsed.toprettyxml(indent="  ")
  
from xml.etree.ElementTree import Element, SubElement, Comment


def zaksum(article , reference , summary_array , directory_path):
  top = Element('ZakSum')

  comment = Comment('Generated without scores')
  top.append(comment)

  i=0
  for summ in summary_array:
    example = SubElement(top, 'example')
    article_element   = SubElement(example, 'article')
    article_element.text = article[i]

    reference_element = SubElement(example, 'reference')
    reference_element.text = reference[i]

    summary_element   = SubElement(example, 'summary')
    summary_element.text = summ
    i+=1
    
  with open(directory_path + "result_rlfeaasf_9_1_2019_1_46am.xml", "w") as f:
    f.write(prettify(top))

In [14]:
import time
import os
import tensorflow as tf
from collections import namedtuple
import numpy as np
from glob import glob
from tensorflow.python import debug as tf_debug
#from replay_buffer import ReplayBuffer
#from dqn import DQN
from threading import Thread
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bernoulli
import logging


#FLAGS = tf.app.flags.FLAGS
#FLAGS.remove_flag_values(FLAGS.flag_values_dict()) 


class Seq2Seq(object):

  def calc_running_avg_loss(self, loss, running_avg_loss, step, decay=0.99):
    
    if running_avg_loss == 0:  # on the first iteration just take the loss
      running_avg_loss = loss
    else:
      running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
    running_avg_loss = min(running_avg_loss, 12)  # clip
    loss_sum = tf.Summary()
    
    tag_name = 'running_avg_loss/decay=%f' % (decay)
    loss_sum.value.add(tag=tag_name, simple_value=running_avg_loss)
    self.summary_writer.add_summary(loss_sum, step)
    tf.logging.info('running_avg_loss: %f', running_avg_loss)
    return running_avg_loss

  def restore_best_model(self):
    
    tf.logging.info("Restoring bestmodel for training...")

    
    #sess = tf.InteractiveSession(config=get_config())#me
    sess = tf.Session(config=get_config())
    print("Initializing all variables...")
    sess.run(tf.initialize_all_variables())
    #tf.reset_default_graph() #me
    saver = tf.train.Saver([v for v in tf.all_variables() if "Adagrad" not in v.name])
    print("Restoring all non-adagrad variables from best model in eval dir...")
    curr_ckpt = load_ckpt(saver, sess, "eval")
    print("Restored %s." % curr_ckpt)

    new_model_name = curr_ckpt.split("/")[-1].replace("bestmodel", "model")
    new_fname = os.path.join(FLAGS.log_root, "train", new_model_name)
    print("Saving model to %s..." % (new_fname))
    log_file_handler.write("Saving model to %s..." % (new_fname)+"\n")

    new_saver = tf.train.Saver()
    new_saver.save(sess, new_fname)
    print("Saved.")
    exit()

  def restore_best_eval_model(self):
    
    best_loss = None
    best_step = None
    
    event_files = sorted(glob('{}/eval/events*'.format(FLAGS.log_root)))
    for ef in event_files:
      try:
        for e in tf.train.summary_iterator(ef):
          for v in e.summary.value:
            step = e.step
            if 'running_avg_loss/decay' in v.tag:
              running_avg_loss = v.simple_value
              if best_loss is None or running_avg_loss < best_loss:
                best_loss = running_avg_loss
                best_step = step
      except:
        continue
    tf.logging.info('resotring best loss from the current logs: {}\tstep: {}'.format(best_loss, best_step))
    return best_loss

  def convert_to_coverage_model(self):
    
    tf.logging.info("converting non-coverage model to coverage model..")

    
    sess = tf.Session(config=get_config())
    print("initializing everything...")
    sess.run(tf.global_variables_initializer())

    
    saver = tf.train.Saver([v for v in tf.global_variables() if "coverage" not in v.name and "Adagrad" not in v.name])
    print("restoring non-coverage variables...")
    curr_ckpt = load_ckpt(saver, sess)
    print("restored.")

    new_fname = curr_ckpt + '_cov_init'
    print("saving model to %s..." % (new_fname))
    new_saver = tf.train.Saver() # this one will save all variables that now exist
    new_saver.save(sess, new_fname)
    print("saved.")
    exit()

  def convert_to_reinforce_model(self):
    tf.logging.info("converting non-reinforce model to reinforce model..")

    sess = tf.Session(config=get_config())
    print("initializing everything...")
    sess.run(tf.global_variables_initializer())

    saver = tf.train.Saver([v for v in tf.global_variables() if "reinforce" not in v.name and "Adagrad" not in v.name])
    print("restoring non-reinforce variables...")
    curr_ckpt = load_ckpt(saver, sess)
    print("restored.")

    # save this model and quit
    new_fname = curr_ckpt + '_rl_init'
    print("saving model to %s..." % (new_fname))
    new_saver = tf.train.Saver() # this one will save all variables that now exist
    new_saver.save(sess, new_fname)
    print("saved.")
    exit()

  def setup_training(self):
    log = logging.getLogger('tensorflow')
    log.setLevel(logging.DEBUG)

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    fh = logging.FileHandler(default_path + 'tensorflow.log')
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(formatter)
    log.addHandler(fh)
    
    train_dir = os.path.join(FLAGS.log_root, "train")
    if not os.path.exists(train_dir): os.makedirs(train_dir)
    if FLAGS.ac_training:
      dqn_train_dir = os.path.join(FLAGS.log_root, "dqn", "train")
      if not os.path.exists(dqn_train_dir): os.makedirs(dqn_train_dir)
    #replaybuffer_pcl_path = os.path.join(FLAGS.log_root, "replaybuffer.pcl")
    #if not os.path.exists(dqn_target_train_dir): os.makedirs(dqn_target_train_dir)

    self.model.build_graph() # build the graph

    if FLAGS.convert_to_reinforce_model:
      assert (FLAGS.rl_training or FLAGS.ac_training), "To convert your pointer model to a reinforce model, run with convert_to_reinforce_model=True and either rl_training=True or ac_training=True"
      self.convert_to_reinforce_model()
    if FLAGS.convert_to_coverage_model:
      assert FLAGS.coverage, "To convert your non-coverage model to a coverage model, run with convert_to_coverage_model=True and coverage=True"
      self.convert_to_coverage_model()
    if FLAGS.restore_best_model:
      self.restore_best_model()
    saver = tf.train.Saver(max_to_keep=3) # keep 3 checkpoints at a time

    # Loads pre-trained word-embedding. By default the model learns the embedding.
    if FLAGS.embedding:
      self.vocab.LoadWordEmbedding(FLAGS.embedding, FLAGS.emb_dim)
      word_vector = self.vocab.getWordEmbedding()

    self.sv = tf.train.Supervisor(logdir=train_dir,
                       is_chief=True,
                       saver=saver,
                       summary_op=None,
                       save_summaries_secs=60, # save summaries for tensorboard every 60 secs
                       save_model_secs=60, # checkpoint every 60 secs
                       global_step=self.model.global_step,
                       init_feed_dict= {self.model.embedding_place:word_vector} if FLAGS.embedding else None
                       )
    self.summary_writer = self.sv.summary_writer
    self.sess = self.sv.prepare_or_wait_for_session(config=get_config())
    if FLAGS.ac_training:
      tf.logging.info('DDQN building graph')
      t1 = time.time()
      self.dqn_graph = tf.Graph()
      with self.dqn_graph.as_default():
        self.dqn.build_graph() 
        tf.logging.info('building current network took {} seconds'.format(time.time()-t1))

        self.dqn_target.build_graph() 
        tf.logging.info('building target network took {} seconds'.format(time.time()-t1))

        dqn_saver = tf.train.Saver(max_to_keep=3) # keep 3 checkpoints at a time
        self.dqn_sv = tf.train.Supervisor(logdir=dqn_train_dir,
                           is_chief=True,
                           saver=dqn_saver,
                           summary_op=None,
                           save_summaries_secs=60, # save summaries for tensorboard every 60 secs
                           save_model_secs=60, # checkpoint every 60 secs
                           global_step=self.dqn.global_step,
                           )
        self.dqn_summary_writer = self.dqn_sv.summary_writer
        self.dqn_sess = self.dqn_sv.prepare_or_wait_for_session(config=get_config())
      
      self.replay_buffer = ReplayBuffer(self.dqn_hps)
    tf.logging.info("Preparing or waiting for session...")
    tf.logging.info("Created session.")
    try:
      self.run_training() # this is an infinite loop until interrupted
    except (KeyboardInterrupt, SystemExit):
      tf.logging.info("Caught keyboard interrupt on worker. Stopping supervisor...")
      self.sv.stop()
      if FLAGS.ac_training:
        self.dqn_sv.stop()

  def run_training(self):
    tf.logging.info("Starting run_training")
    log_file_handler.write("starting run_training..\n")
    
    if FLAGS.debug: 
      self.sess = tf_debug.LocalCLIDebugWrapperSession(self.sess)
      self.sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)

    self.train_step = 0
    if FLAGS.ac_training:
      
      tf.logging.info('Starting DQN training thread...')
      self.dqn_train_step = 0
      self.thrd_dqn_training = Thread(target=self.dqn_training)
      self.thrd_dqn_training.daemon = True
      self.thrd_dqn_training.start()

      watcher = Thread(target=self.watch_threads)
      watcher.daemon = True
      watcher.start()
    
    tf.logging.info('Starting Seq2Seq training...')
    log_file_handler.write("Starting Seq2Seq training...\n")
    while True: 
      batch = self.batcher.next_batch()
      t0=time.time()
      if FLAGS.ac_training:
        
        transitions = self.model.collect_dqn_transitions(self.sess, batch, self.train_step, batch.max_art_oovs) # len(batch_size * k * max_dec_steps)
        tf.logging.info('Q-values collection time: {}'.format(time.time()-t0))
        
        with self.dqn_graph.as_default():
          batch_len = len(transitions)
          
          b = ReplayBuffer.create_batch(self.dqn_hps, transitions,len(transitions), use_state_prime = False, max_art_oovs = batch.max_art_oovs)
          
          b_prime = ReplayBuffer.create_batch(self.dqn_hps, transitions,len(transitions), use_state_prime = True, max_art_oovs = batch.max_art_oovs)
          
          dqn_results = self.dqn.run_test_steps(sess=self.dqn_sess, x= b._x, return_best_action=True)
          q_estimates = dqn_results['estimates'] # shape (len(transitions), vocab_size)
          dqn_best_action = dqn_results['best_action']
          #dqn_q_estimate_loss = dqn_results['loss']

          dqn_target_results = self.dqn_target.run_test_steps(self.dqn_sess, x= b_prime._x)
          q_vals_new_t = dqn_target_results['estimates'] # shape (len(transitions), vocab_size)

          q_estimates = np.concatenate([q_estimates,
            np.reshape(q_estimates[:,0],[-1,1])*np.ones((len(transitions),batch.max_art_oovs))],axis=-1)
          for i, tr in enumerate(transitions):
            if tr.done:
              q_estimates[i][tr.action] = tr.reward
            else:
              q_estimates[i][tr.action] = tr.reward + FLAGS.gamma * q_vals_new_t[i][dqn_best_action[i]]
          if FLAGS.dqn_scheduled_sampling:
            q_estimates = self.scheduled_sampling(batch_len, FLAGS.sampling_probability, b._y_extended, q_estimates)
          if not FLAGS.calculate_true_q:
            # when we are not training DDQN based on true Q-values,
            # we need to update Q-values in our transitions based on the q_estimates we collected from DQN current network.
            for trans, q_val in zip(transitions,q_estimates):
              trans.q_values = q_val # each have the size vocab_extended
          q_estimates = np.reshape(q_estimates, [FLAGS.batch_size, FLAGS.k, FLAGS.max_dec_steps, -1]) # shape (batch_size, k, max_dec_steps, vocab_size_extended)
        
        self.replay_buffer.add(transitions)
        
        if FLAGS.dqn_pretrain:
          tf.logging.info('RUNNNING DQN PRETRAIN: Adding data to relplay buffer only...')
          continue
        
        results = self.model.run_train_steps(self.sess, batch, self.train_step, q_estimates)
      else:
          results = self.model.run_train_steps(self.sess, batch, self.train_step)
      t1=time.time()
      
      summaries = results['summaries'] 
      self.train_step = results['global_step'] 
      tf.logging.info('seconds for training step {}: {}'.format(self.train_step, t1-t0))

      printer_helper = {}
      printer_helper['pgen_loss']= results['pgen_loss']
      if FLAGS.coverage:
        printer_helper['coverage_loss'] = results['coverage_loss']
        if FLAGS.rl_training or FLAGS.ac_training:
          printer_helper['rl_cov_total_loss']= results['reinforce_cov_total_loss']
        else:
          printer_helper['pointer_cov_total_loss'] = results['pointer_cov_total_loss']
      if FLAGS.rl_training or FLAGS.ac_training:
        printer_helper['shared_loss'] = results['shared_loss']
        printer_helper['rl_loss'] = results['rl_loss']
        printer_helper['rl_avg_logprobs'] = results['rl_avg_logprobs']
      if FLAGS.rl_training:
        printer_helper['sampled_r'] = np.mean(results['sampled_sentence_r_values'])
        printer_helper['greedy_r'] = np.mean(results['greedy_sentence_r_values'])
        printer_helper['r_diff'] = printer_helper['greedy_r'] - printer_helper['sampled_r']
      if FLAGS.ac_training:
        printer_helper['dqn_loss'] = np.mean(self.avg_dqn_loss) if len(self.avg_dqn_loss)>0 else 0

      for (k,v) in printer_helper.items():
        if not np.isfinite(v):
          raise Exception("{} is not finite. Stopping.".format(k))
        tf.logging.info('{}: {}\t'.format(k,v))
      tf.logging.info('-------------------------------------------')

      self.summary_writer.add_summary(summaries, self.train_step) 
      if self.train_step % 100 == 0: 
        self.summary_writer.flush()
      if FLAGS.ac_training:
        self.dqn_summary_writer.flush()
      if self.train_step > FLAGS.max_iter: break

  def dqn_training(self):
    
    try:
      while True:
        if self.dqn_train_step == FLAGS.dqn_pretrain_steps: raise SystemExit()
        _t = time.time()
        self.avg_dqn_loss = []
        avg_dqn_target_loss = []
        
        dqn_batch = self.replay_buffer.next_batch()
        if dqn_batch is None:
          tf.logging.info('replay buffer not loaded enough yet...')
          time.sleep(60)
          continue
        
        dqn_results = self.dqn.run_train_steps(self.dqn_sess, dqn_batch)
        
        dqn_target_results = self.dqn_target.run_test_steps(self.dqn_sess, x=dqn_batch._x, y=dqn_batch._y, return_loss=True)
        self.dqn_train_step = dqn_results['global_step']
        self.dqn_summary_writer.add_summary(dqn_results['summaries'], self.dqn_train_step) # write the summaries
        self.avg_dqn_loss.append(dqn_results['loss'])
        avg_dqn_target_loss.append(dqn_target_results['loss'])
        self.dqn_train_step = self.dqn_train_step + 1
        tf.logging.info('seconds for training dqn model: {}'.format(time.time()-_t))
        
        with self.dqn_graph.as_default():
          current_model_weights = self.dqn_sess.run([self.dqn.model_trainables])[0] # get weights of current model
          self.dqn_target.run_update_weights(self.dqn_sess, self.dqn_train_step, current_model_weights) # update target model weights with current model weights
        tf.logging.info('DQN loss at step {}: {}'.format(self.dqn_train_step, np.mean(self.avg_dqn_loss)))
        tf.logging.info('DQN Target loss at step {}: {}'.format(self.dqn_train_step, np.mean(avg_dqn_target_loss)))
        
        time.sleep(FLAGS.dqn_sleep_time)
    except (KeyboardInterrupt, SystemExit):
      tf.logging.info("Caught keyboard interrupt on worker. Stopping supervisor...")
      self.sv.stop()
      self.dqn_sv.stop()

  def watch_threads(self):
    
    while True:
      time.sleep(60)
      if not self.thrd_dqn_training.is_alive(): # if the thread is dead
        tf.logging.error('Found DQN Learning thread dead. Restarting.')
        self.thrd_dqn_training = Thread(target=self.dqn_training)
        self.thrd_dqn_training.daemon = True
        self.thrd_dqn_training.start()

  def run_eval(self):
    
    self.model.build_graph() 
    saver = tf.train.Saver(max_to_keep=3) 
    sess = tf.Session(config=get_config())

    if FLAGS.embedding:
      sess.run(tf.global_variables_initializer(),feed_dict={self.model.embedding_place:self.word_vector})
    eval_dir = os.path.join(FLAGS.log_root, "eval") 
    bestmodel_save_path = os.path.join(eval_dir, 'bestmodel') 
    self.summary_writer = tf.summary.FileWriter(eval_dir)

    if FLAGS.ac_training:
      tf.logging.info('DDQN building graph')
      t1 = time.time()
      dqn_graph = tf.Graph()
      with dqn_graph.as_default():
        self.dqn.build_graph() 
        tf.logging.info('building current network took {} seconds'.format(time.time()-t1))
        self.dqn_target.build_graph() 
        tf.logging.info('building target network took {} seconds'.format(time.time()-t1))
        dqn_saver = tf.train.Saver(max_to_keep=3)
        dqn_sess = tf.Session(config=get_config())
      dqn_train_step = 0
      replay_buffer = ReplayBuffer(self.dqn_hps)

    running_avg_loss = 0 
    best_loss = self.restore_best_eval_model() 
    train_step = 0

    while True:
      _ = load_ckpt(saver, sess) # load a new checkpoint
      if FLAGS.ac_training:
        _ = load_dqn_ckpt(dqn_saver, dqn_sess) # load a new checkpoint
      processed_batch = 0
      avg_losses = []
      while processed_batch < 100*FLAGS.batch_size:
        processed_batch += FLAGS.batch_size
        batch = self.batcher.next_batch() # get the next batch
        if FLAGS.ac_training:
          t0 = time.time()
          transitions = self.model.collect_dqn_transitions(sess, batch, train_step, batch.max_art_oovs) # len(batch_size * k * max_dec_steps)
          tf.logging.info('Q values collection time: {}'.format(time.time()-t0))
          with dqn_graph.as_default():
            
            batch_len = len(transitions)
            b = ReplayBuffer.create_batch(self.dqn_hps, transitions,len(transitions), use_state_prime = True, max_art_oovs = batch.max_art_oovs)
            b_prime = ReplayBuffer.create_batch(self.dqn_hps, transitions,len(transitions), use_state_prime = True, max_art_oovs = batch.max_art_oovs)
            dqn_results = self.dqn.run_test_steps(sess=dqn_sess, x= b._x, return_best_action=True)
            q_estimates = dqn_results['estimates'] # shape (len(transitions), vocab_size)
            dqn_best_action = dqn_results['best_action']

            tf.logging.info('running test step on dqn_target')
            dqn_target_results = self.dqn_target.run_test_steps(dqn_sess, x= b_prime._x)
            q_vals_new_t = dqn_target_results['estimates'] # shape (len(transitions), vocab_size)

            q_estimates = np.concatenate([q_estimates,np.zeros((len(transitions),batch.max_art_oovs))],axis=-1)

            tf.logging.info('fixing the action q-estimates')
            for i, tr in enumerate(transitions):
              if tr.done:
                q_estimates[i][tr.action] = tr.reward
              else:
                q_estimates[i][tr.action] = tr.reward + FLAGS.gamma * q_vals_new_t[i][dqn_best_action[i]]
            if FLAGS.dqn_scheduled_sampling:
              tf.logging.info('scheduled sampling on q-estimates')
              q_estimates = self.scheduled_sampling(batch_len, FLAGS.sampling_probability, b._y_extended, q_estimates)
            if not FLAGS.calculate_true_q:
              for trans, q_val in zip(transitions,q_estimates):
                trans.q_values = q_val # each have the size vocab_extended
            q_estimates = np.reshape(q_estimates, [FLAGS.batch_size, FLAGS.k, FLAGS.max_dec_steps, -1]) # shape (batch_size, k, max_dec_steps, vocab_size_extended)
          tf.logging.info('run eval step on seq2seq model.')
          t0=time.time()
          results = self.model.run_eval_step(sess, batch, train_step, q_estimates)
          t1=time.time()
        else:
          tf.logging.info('run eval step on seq2seq model.')
          t0=time.time()
          results = self.model.run_eval_step(sess, batch, train_step)
          t1=time.time()

        tf.logging.info('experiment: {}'.format(FLAGS.exp_name))
        tf.logging.info('processed_batch: {}, seconds for batch: {}'.format(processed_batch, t1-t0))

        printer_helper = {}
        loss = printer_helper['pgen_loss']= results['pgen_loss']
        if FLAGS.coverage:
          printer_helper['coverage_loss'] = results['coverage_loss']
          if FLAGS.rl_training or FLAGS.ac_training:
            printer_helper['rl_cov_total_loss']= results['reinforce_cov_total_loss']
          loss = printer_helper['pointer_cov_total_loss'] = results['pointer_cov_total_loss']
        if FLAGS.rl_training or FLAGS.ac_training:
          printer_helper['shared_loss'] = results['shared_loss']
          printer_helper['rl_loss'] = results['rl_loss']
          printer_helper['rl_avg_logprobs'] = results['rl_avg_logprobs']
        if FLAGS.rl_training:
          printer_helper['sampled_r'] = np.mean(results['sampled_sentence_r_values'])
          printer_helper['greedy_r'] = np.mean(results['greedy_sentence_r_values'])
          printer_helper['r_diff'] = printer_helper['greedy_r'] - printer_helper['sampled_r']
        if FLAGS.ac_training:
          printer_helper['dqn_loss'] = np.mean(self.avg_dqn_loss) if len(self.avg_dqn_loss) > 0 else 0

        for (k,v) in printer_helper.items():
          if not np.isfinite(v):
            raise Exception("{} is not finite. Stopping.".format(k))
          tf.logging.info('{}: {}\t'.format(k,v))

        # add summaries
        summaries = results['summaries']
        train_step = results['global_step']
        self.summary_writer.add_summary(summaries, train_step)

        # calculate running avg loss
        avg_losses.append(self.calc_running_avg_loss(np.asscalar(loss), running_avg_loss, train_step))
        tf.logging.info('-------------------------------------------')

      running_avg_loss = np.mean(avg_losses)
      tf.logging.info('==========================================')
      tf.logging.info('best_loss: {}\trunning_avg_loss: {}\t'.format(best_loss, running_avg_loss))
      tf.logging.info('==========================================')

      if best_loss is None or running_avg_loss < best_loss:
        tf.logging.info('Found new best model with %.3f running_avg_loss. Saving to %s', running_avg_loss, bestmodel_save_path)
        saver.save(sess, bestmodel_save_path, global_step=train_step, latest_filename='checkpoint_best')
        best_loss = running_avg_loss

      # flush the summary writer every so often
      if train_step % 100 == 0:
        self.summary_writer.flush()
      #time.sleep(600) # run eval every 10 minute

  def main(self):
    tf.reset_default_graph()

    #FLAGS.log_root = os.path.join(FLAGS.log_root, FLAGS.exp_name)
    tf.logging.set_verbosity(tf.logging.INFO) # choose what level of logging you want
    tf.logging.info('Starting seq2seq_attention in %s mode...', (FLAGS.mode))

    #flags = getattr(FLAGS,"__flags")

    if not os.path.exists(FLAGS.log_root):
      if FLAGS.mode=="train":
        os.makedirs(FLAGS.log_root)
      else:
        raise Exception("Logdir %s doesn't exist. Run in train mode to create it." % (FLAGS.log_root))

    fw = open('{}/config.txt'.format(FLAGS.log_root), 'w')
    #for k, v in flags.items():
    #  fw.write('{}\t{}\n'.format(k, v))
    #fw.close()

    self.vocab = Vocab(FLAGS.vocab_path, FLAGS.vocab_size) # create a vocabulary

    if FLAGS.mode == 'decode':
      FLAGS.batch_size = FLAGS.beam_size

    # If single_pass=True, check we're in decode mode
    if FLAGS.single_pass and FLAGS.mode!='decode':
      raise Exception("The single_pass flag should only be True in decode mode")

    
    hparam_list = ['mode', 'lr', 'gpu_num',
    #'sampled_greedy_flag', 
    'gamma', 'eta', 
    'fixed_eta', 'reward_function', 'intradecoder', 
    'use_temporal_attention', 'ac_training','rl_training', 'matrix_attention', 'calculate_true_q',
    'enc_hidden_dim', 'dec_hidden_dim', 'k', 
    'scheduled_sampling', 'sampling_probability','fixed_sampling_probability',
    'alpha', 'hard_argmax', 'greedy_scheduled_sampling',
    'adagrad_init_acc', 'rand_unif_init_mag', 
    'trunc_norm_init_std', 'max_grad_norm', 
    'emb_dim', 'batch_size', 'max_dec_steps', 'max_enc_steps',
    'dqn_scheduled_sampling', 'dqn_sleep_time', 'E2EBackProp',
    'coverage', 'cov_loss_wt', 'pointer_gen']
    hps_dict = {}
    
    flag_members = [attr for attr in dir(FLAGS) if not callable(getattr(FLAGS, attr)) and not attr.startswith("__")]
    for m in flag_members:
        hps_dict[m] = getattr(FLAGS, m)
        
    if FLAGS.ac_training:
      hps_dict.update({'dqn_input_feature_len':(FLAGS.dec_hidden_dim)})
    self.hps = namedtuple("HParams", hps_dict.keys())(**hps_dict)
    # creating all the required parameters for DDQN model.
    if FLAGS.ac_training:
      hparam_list = ['lr', 'dqn_gpu_num', 
      'dqn_layers', 
      'dqn_replay_buffer_size', 
      'dqn_batch_size', 
      'dqn_target_update',
      'dueling_net',
      'dqn_polyak_averaging',
      'dqn_sleep_time',
      'dqn_scheduled_sampling',
      'max_grad_norm']
      hps_dict = {}
      
      flag_members = [attr for attr in dir(FLAGS) if not callable(getattr(FLAGS, attr)) and not attr.startswith("__")]
      for m in flag_members:
        hps_dict[m] = getattr(FLAGS, m)
    
      hps_dict.update({'dqn_input_feature_len':(FLAGS.dec_hidden_dim)})
      hps_dict.update({'vocab_size':self.vocab.size()})
      self.dqn_hps = namedtuple("HParams", hps_dict.keys())(**hps_dict)

    self.batcher = Batcher(FLAGS.data_path, self.vocab, self.hps, single_pass=FLAGS.single_pass, decode_after=FLAGS.decode_after)

    tf.set_random_seed(111) # a seed value for randomness

    if self.hps.mode == 'train':
      print("creating model...")
      self.model = SummarizationModel(self.hps, self.vocab)
      if FLAGS.ac_training:
        # current DQN with paramters \Psi
        self.dqn = DQN(self.dqn_hps,'current')
        # target DQN with paramters \Psi^{\prime}
        self.dqn_target = DQN(self.dqn_hps,'target')
      self.setup_training()
    elif self.hps.mode == 'eval':
      self.model = SummarizationModel(self.hps, self.vocab)
      if FLAGS.ac_training:
        self.dqn = DQN(self.dqn_hps,'current')
        self.dqn_target = DQN(self.dqn_hps,'target')
      self.run_eval()
    elif self.hps.mode == 'decode':
      decode_model_hps = self.hps  # This will be the hyperparameters for the decoder model
      decode_model_hps = self.hps._replace(max_dec_steps=1) # The model is configured with max_dec_steps=1 because we only ever run one step of the decoder at a time (to do beam search). Note that the batcher is initialized with max_dec_steps equal to e.g. 100 because the batches need to contain the full summaries
      model = SummarizationModel(decode_model_hps, self.vocab)
      if FLAGS.ac_training:
        # We need our target DDQN network for collecting Q-estimation at each decoder step.
        dqn_target = DQN(self.dqn_hps,'target')
      else:
        dqn_target = None
      decoder = BeamSearchDecoder(model, self.batcher, self.vocab, dqn = dqn_target)
      decoder.decode() # decode indefinitely (unless single_pass=True, in which case deocde the dataset exactly once)
    else:
      raise ValueError("The 'mode' flag must be one of train/eval/decode")

  def scheduled_sampling(self, batch_size, sampling_probability, true, estimate):
    with variable_scope.variable_scope("ScheduledEmbedding"):
      # Return -1s where we do not sample, and sample_ids elsewhere
      select_sampler = bernoulli.Bernoulli(probs=sampling_probability, dtype=tf.bool)
      select_sample = select_sampler.sample(sample_shape=batch_size)
      sample_ids = array_ops.where(
                  select_sample,
                  tf.range(batch_size),
                  gen_array_ops.fill([batch_size], -1))
      where_sampling = math_ops.cast(
          array_ops.where(sample_ids > -1), tf.int32)
      where_not_sampling = math_ops.cast(
          array_ops.where(sample_ids <= -1), tf.int32)
      _estimate = array_ops.gather_nd(estimate, where_sampling)
      _true = array_ops.gather_nd(true, where_not_sampling)

      base_shape = array_ops.shape(true)
      result1 = array_ops.scatter_nd(indices=where_sampling, updates=_estimate, shape=base_shape)
      result2 = array_ops.scatter_nd(indices=where_not_sampling, updates=_true, shape=base_shape)
      result = result1 + result2
      return result1 + result2

def main():
  seq2seq = Seq2Seq()
  seq2seq.main()

In [None]:
class flags_:
  pass
FLAGS = flags_()

default_path = ""
data_path = ""

# Where to find data
FLAGS.data_path = data_path + 'news_finished_files/chunked/train_*'
FLAGS.vocab_path = data_path + 'news_finished_files/vocab'

FLAGS.mode = 'train'
FLAGS.single_pass = False
FLAGS.decode_after = 0
FLAGS.decode_from = 'train'

FLAGS.log_root = default_path +'logs_26_5'
FLAGS.exp_name = 'intradecoder-temporalattention-withpretraining4'

FLAGS.example_queue_threads = 4
FLAGS.batch_queue_threads   = 2
FLAGS.bucketing_cache_size  = 100

FLAGS.enc_hidden_dim= 256
FLAGS.dec_hidden_dim= 256
FLAGS.emb_dim= 128
FLAGS.batch_size=25
FLAGS.max_enc_steps= 400
FLAGS.max_dec_steps= 100 
FLAGS.beam_size= 4 
FLAGS.min_dec_steps= 35 
FLAGS.max_iter= 20000 
FLAGS.vocab_size= 50000 
FLAGS.lr= 0.15
FLAGS.adagrad_init_acc= 0.1 
FLAGS.rand_unif_init_mag= 0.02
FLAGS.trunc_norm_init_std= 1e-4
FLAGS.max_grad_norm= 2.0
FLAGS.embedding= False
FLAGS.gpu_num= 0

FLAGS.pointer_gen= True
FLAGS.avoid_trigrams= True
FLAGS.share_decoder_weights= False
FLAGS.rl_training= True
FLAGS.self_critic= True
FLAGS.use_discounted_rewards= False
FLAGS.use_intermediate_rewards= False
FLAGS.convert_to_reinforce_model= True
FLAGS.intradecoder= True
FLAGS.use_temporal_attention=  True
FLAGS.matrix_attention= False#, 
FLAGS.eta= 0
FLAGS.fixed_eta= False
FLAGS.gamma= 0.99#,
FLAGS.reward_function= 'rouge_l/f_score'

FLAGS.ac_training= False
FLAGS.dqn_scheduled_sampling= True
FLAGS.dqn_layers= '512,256,128'
FLAGS.dqn_replay_buffer_size= 100000
FLAGS.dqn_batch_size= 100 
FLAGS.dqn_target_update= 10000
FLAGS.dqn_sleep_time= 2
FLAGS.dqn_gpu_num= 0
FLAGS.dueling_net= True
FLAGS.dqn_polyak_averaging= True
FLAGS.calculate_true_q= False
FLAGS.dqn_pretrain= False
FLAGS.dqn_pretrain_steps= 10000
FLAGS.scheduled_sampling= False
FLAGS.decay_function= 'linear'#,'linear#, exponential#, inv_sigmoid') 
FLAGS.sampling_probability= 0
FLAGS.fixed_sampling_probability= False
FLAGS.hard_argmax= True
FLAGS.greedy_scheduled_sampling= False
FLAGS.E2EBackProp= False
FLAGS.alpha= 1
FLAGS.k= 1
FLAGS.scheduled_sampling_final_dist= True

FLAGS.coverage= True
FLAGS.cov_loss_wt= 1.0

FLAGS.convert_to_coverage_model= True
FLAGS.restore_best_model= False

FLAGS.debug= False


seq2seq = Seq2Seq()
seq2seq.main()

In [15]:
zaksum(article,reference ,summary ,"")

In [16]:
class flags_:
  pass
FLAGS = flags_()

default_path = ""
data_path = ""

# Where to find data
FLAGS.data_path = data_path + 'news_finished_files/chunked/train_*'
FLAGS.vocab_path = data_path + 'news_finished_files/vocab'

FLAGS.mode = 'decode'
FLAGS.single_pass = False
FLAGS.decode_after = 0
FLAGS.decode_from = 'train'

FLAGS.log_root = default_path +'logs_26_5'
FLAGS.exp_name = 'intradecoder-temporalattention-withpretraining4'

FLAGS.example_queue_threads = 4
FLAGS.batch_queue_threads   = 2
FLAGS.bucketing_cache_size  = 100

FLAGS.enc_hidden_dim= 256
FLAGS.dec_hidden_dim= 256
FLAGS.emb_dim= 128
FLAGS.batch_size=25
FLAGS.max_enc_steps= 400
FLAGS.max_dec_steps= 100 
FLAGS.beam_size= 4 
FLAGS.min_dec_steps= 35 
FLAGS.max_iter= 20000 
FLAGS.vocab_size= 50000 
FLAGS.lr= 0.15
FLAGS.adagrad_init_acc= 0.1 
FLAGS.rand_unif_init_mag= 0.02
FLAGS.trunc_norm_init_std= 1e-4
FLAGS.max_grad_norm= 2.0
FLAGS.embedding= False
FLAGS.gpu_num= 0

FLAGS.pointer_gen= True
FLAGS.avoid_trigrams= True
FLAGS.share_decoder_weights= False
FLAGS.rl_training= True
FLAGS.self_critic= True
FLAGS.use_discounted_rewards= False
FLAGS.use_intermediate_rewards= False
FLAGS.convert_to_reinforce_model= True
FLAGS.intradecoder= True
FLAGS.use_temporal_attention=  True
FLAGS.matrix_attention= False#, 
FLAGS.eta= 0
FLAGS.fixed_eta= False
FLAGS.gamma= 0.99#,
FLAGS.reward_function= 'rouge_l/f_score'

FLAGS.ac_training= False
FLAGS.dqn_scheduled_sampling= True
FLAGS.dqn_layers= '512,256,128'
FLAGS.dqn_replay_buffer_size= 100000
FLAGS.dqn_batch_size= 100 
FLAGS.dqn_target_update= 10000
FLAGS.dqn_sleep_time= 2
FLAGS.dqn_gpu_num= 0
FLAGS.dueling_net= True
FLAGS.dqn_polyak_averaging= True
FLAGS.calculate_true_q= False
FLAGS.dqn_pretrain= False
FLAGS.dqn_pretrain_steps= 10000
FLAGS.scheduled_sampling= False
FLAGS.decay_function= 'linear'#,'linear#, exponential#, inv_sigmoid') 
FLAGS.sampling_probability= 0
FLAGS.fixed_sampling_probability= False
FLAGS.hard_argmax= True
FLAGS.greedy_scheduled_sampling= False
FLAGS.E2EBackProp= False
FLAGS.alpha= 1
FLAGS.k= 1
FLAGS.scheduled_sampling_final_dist= True

FLAGS.coverage= True
FLAGS.cov_loss_wt= 1.0

FLAGS.convert_to_coverage_model= True
FLAGS.restore_best_model= False

FLAGS.debug= False


seq2seq = Seq2Seq()
seq2seq.main()

INFO:tensorflow:Starting seq2seq_attention in decode mode...
max_size of vocab was specified as 50000; we now have 50000 words. Stopping reading.
Finished constructing vocabulary of 50000 total words. Last word added: â¹958
INFO:tensorflow:Building graph...
Instructions for updating:
Colocations handled automatically by placer.


  data = yaml.load(f.read()) or {}



For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.

Instructions for updating:
This class is equivalent as tf.keras.layers.LSTMCell, and will be replaced by that in Tensorflow 2.0.
Instructions for updating:
Please use `keras.layers.Bidirectional(keras.layers.RNN(cell))`, which is equivalent to this API
Instructions for updating:
Please use `keras.layers.RNN(cell)`, which is equivalent to this API
Instructions for updating:
Use tf.cast instead.
INFO:tensorflow:batch_size 4, attn_size: 512, emb_size: 128
INFO:tensorflow:Adding attention_decoder timestep 0 of 1
Instructions for updating:
The TensorFlow Distributions library has moved to TensorFlow Probability (https://github.com/tensorflow/probability). You should update all references to use `tfp.distributions` instead of `tf.distributions`.
Instructions f


INFO:tensorflow:ARTICLE:  madras . high . court . examined . medical . report . __birthmark__ . scar . found . actor . dhanush . body . mentioned . elderly . couple . claiming . parents . however . report . stated . small . superficial . mole . another . large . mole . removed . hearing . case . deferred . march . 27 .
INFO:tensorflow:REFERENCE SUMMARY: hc . gets . medical . report . of . no . __birthmark__ . on . dhanush . s . body .
INFO:tensorflow:GENERATED SUMMARY: madras . high . court . report . report over . scar . dhanush . report in . found . dhanush s . s actor . madras . s birthmark . medical . examined . examined stated . s small . medical mole . mole . large . mole large . [unk]

INFO:tensorflow:Wrote visualization data to logs_26_5\decode\attn_vis_data.json

INFO:tensorflow:ARTICLE:  karnataka . government . urged . supreme . court . declare . late . tamil . nadu . cm . j . jayalalithaa . convict . disproportionate . assets . case . state . argued . apex . court . februa

INFO:tensorflow:REFERENCE SUMMARY: uk . bans . laptops . tablets . on . flights . from . 6 . nations .
INFO:tensorflow:GENERATED SUMMARY: uk . announced . electronic . ban . mobiles . laptops . laptops over . laptops in . baggage . flights . uk . flights several . s six . s cabin . s devices . s . devices . ban imposed . mobiles imposed . flights restrictions . originating . middle . [unk]

INFO:tensorflow:Wrote visualization data to logs_26_5\decode\attn_vis_data.json

INFO:tensorflow:ARTICLE:  amid . controversy . punjab . minister . navjot . singh . sidhu . tv . commitments . cricketer . turned . politician . said . nobody . business . 6 . pm . sidhu . indicated . appearing . tv . shows . amount . breach . office . profit . law . however . punjab . cm . amarinder . singh . sought . legal . opinion . sidhu . tv . commitments . despite . minister .
INFO:tensorflow:REFERENCE SUMMARY: nobody . s . business . what . i . do . after . 6 . pm . sidhu . on . tv . shows .
INFO:tensorflow:GENE

INFO:tensorflow:GENERATED SUMMARY: man . man . altaf . six . â¹500 . notes . notes in . bank . atm . baroda . atm s . s jamnagar . s serial . s misprinted . altaf side . s chaki . named . six others . clearly . currency . value . notes written . [unk]

INFO:tensorflow:Wrote visualization data to logs_26_5\decode\attn_vis_data.json

INFO:tensorflow:ARTICLE:  newly . restored . site . jesus . christ . tomb . church . holy . sepulchre . jerusalem . old . city . unveiled . public . wednesday . comes . underwent . restoration . nine . months . cost . 4 . million . â¹26 . crore . around . 50 . experts . national . technical . university . athens . involved . restoration .
INFO:tensorflow:REFERENCE SUMMARY: jesus . tomb . to . be . unveiled . to . public . after . 4mn . restoration .
INFO:tensorflow:GENERATED SUMMARY: newly . restored . site . jesus . tomb . church . jerusalem . holy . jerusalem 2 . s . church around . s old . s cost . s city . s newly . site 4 . s million . crore . crore a

INFO:tensorflow:REFERENCE SUMMARY: never . faced . any . camp . system . in . bollywood . diljit . dosanjh .
INFO:tensorflow:GENERATED SUMMARY: never . actor . faced . never . camp . system . system over . never s . bollywood . diljit . dosanjh . diljit said diljit . s respect . s love . never said never . s actor . never known . faced known . s never . never made . [unk]

INFO:tensorflow:Wrote visualization data to logs_26_5\decode\attn_vis_data.json

INFO:tensorflow:ARTICLE:  former . manchester . city . liverpool . striker . mario . balotelli . missed . opening . two . minutes . team . nice . ligue . 1 . match . nantes . struggled . loosen . shoelaces . balotelli . walking . start . match . went . towards . bench . assistance . helped . coach . loosened . laces .
INFO:tensorflow:REFERENCE SUMMARY: player . misses . two . mins . of . match . as . he . fails . to . undo . laces .
INFO:tensorflow:GENERATED SUMMARY: manchester . former . city . liverpool . mario . balotelli . mario in .

INFO:tensorflow:REFERENCE SUMMARY: sc . directs . k . taka . to . release . 2 . 000 . cusecs . water . to . tn .
INFO:tensorflow:GENERATED SUMMARY: supreme . supreme . karnataka . supply . 2 . 000 . cauvery . cusecs . water . tamil . dispute . tamil related . s compensation . s tamil . s nadu . s tuesday . court . karnataka seeking . s states . loss . property . cauvery [unk]

INFO:tensorflow:Wrote visualization data to logs_26_5\decode\attn_vis_data.json

INFO:tensorflow:ARTICLE:  japanese . conglomerate . softbank . backed . deal . invest . 100 . million . android . co . founder . andy . rubin . smartphone . startup . named . essential . startup . reportedly . working . smartphone . compete . apple . iphone . 7 . fall . deal . comes . apple . committed . invest . 1 . billion . softbank . 100 . billion . fund . january . year .
INFO:tensorflow:REFERENCE SUMMARY: softbank . cancels . 100mn . funding . in . android . founder . s . startup .
INFO:tensorflow:GENERATED SUMMARY: softbank . 

INFO:tensorflow:GENERATED SUMMARY: uber . hailing . startup . uber . kalanick . travis . uber in . kalanick over . kalanick s . kalanick including . uber including . s ceo . uber according . s uber . uber number . s co . uber president . s kalanick . uber [unk]

INFO:tensorflow:Wrote visualization data to logs_26_5\decode\attn_vis_data.json

INFO:tensorflow:ARTICLE:  managing . director . idea . himanshu . kapania . said . integrating . operations . vodafone . idea . cost . combined . entity . around . â¹13 . 400 . crore . however . said . synergies . leading . cost . saving . result . net . benefit . â¹65 . 000 . crore . merged . entity . idea . vodafone . confirmed . merger . monday . form . india . largest . telecom . firm .
INFO:tensorflow:REFERENCE SUMMARY: vodafone . idea . merger . to . cost . â¹13 . 400 . crore . says . idea . md .
INFO:tensorflow:GENERATED SUMMARY: integrating . leading . integrating . operations . vodafone . idea . cost . combined . idea i . 2 . crore . 40

INFO:tensorflow:REFERENCE SUMMARY: i . might . take . ashwin . s . advice . and . hit . him . on . head . says . starc .
INFO:tensorflow:GENERATED SUMMARY: australian . pacer . looking . bowl . bowl i . ashwin . ashwin in . australia . advice . starc . ashwin said ashwin . s second . s take . s pacer . mitchell . looking said bowl . looking tapping . bowl forehead . ashwin reply . [unk]

INFO:tensorflow:Wrote visualization data to logs_26_5\decode\attn_vis_data.json

INFO:tensorflow:ARTICLE:  two . member . probe . committee . pakistan . cricket . board . pcb . travel . united . kingdom . look . pakistani . opener . nasir . jamshed . alleged . involvement . pakistan . super . league . psl . spot . fixing . scandal . according . reports . jamshed . earlier . investigated . national . crime . agency . uk . passport . withheld . authorities . granted . bail .
INFO:tensorflow:REFERENCE SUMMARY: pcb . panel . to . probe . nasir . jamshed . in . uk . over . psl . spot . fixing .
INFO:tensorf

INFO:tensorflow:REFERENCE SUMMARY: killing . of . maoists . witnessed . 150 . increase . in . 2016 . centre .
INFO:tensorflow:GENERATED SUMMARY: maoists . number . killed . forces . 150 . jump . centre . 150 s . 150 in . s . maoists . security . security compared . s compared . maoists compared . [unk]

INFO:tensorflow:Wrote visualization data to logs_26_5\decode\attn_vis_data.json

INFO:tensorflow:ARTICLE:  cbi . registered . fir . unknown . persons . allegedly . using . national . emblem . name . prime . minister . office . dupe . people . pretext . giving . houses . poor . families . case . filed . basis . reference . prime . minister . office . attached . complaint . village . head . jharkhand .
INFO:tensorflow:REFERENCE SUMMARY: fir . filed . over . fraud . using . pm . s . office . name . to . dupe . people .
INFO:tensorflow:GENERATED SUMMARY: cbi . registered . fir . using . emblem . cbi . name . name in . name 2 . complaint . name filed . s case . s prime . fir pretext . s pers

INFO:tensorflow:GENERATED SUMMARY: instances . tuesday . instances . custodial . maharashtra . deaths . maharashtra i . little . states . 35 . 35 however . s compared . s states . s instances . instances said instances . s . bjp . custodial said singh . mumbai . police . maharashtra however . [unk]

INFO:tensorflow:Wrote visualization data to logs_26_5\decode\attn_vis_data.json

INFO:tensorflow:ARTICLE:  tibetan . religious . leader . dalai . lama . tuesday . said . global . terrorist . crisis . fast . reaching . nuclear . threshold . crosses . threshold . could . mean . entire . mankind . could . perish . added . dalai . lama . also . said . terrorism . crisis . can . not . resolved . unless . nations . come . together . single . platform .
INFO:tensorflow:REFERENCE SUMMARY: global . terrorism . fast . reaching . nuclear . threshold . dalai . lama .
INFO:tensorflow:GENERATED SUMMARY: terrorist . s . global . terrorist . crisis . crisis i . i . fast . nuclear . threshold . lama . dalai


INFO:tensorflow:Wrote visualization data to logs_26_5\decode\attn_vis_data.json

INFO:tensorflow:ARTICLE:  indian . captain . virat . kohli . accused . australian . players . trolling . indian . team . physio . patrick . farhart . australian . daily . called . kohli . donald . trump . world . sport . like . president . trump . kohli . decided . blame . media . means . trying . hide . egg . smeared . right . across . face . article . added .
INFO:tensorflow:REFERENCE SUMMARY: kohli . has . become . the . donald . trump . of . world . sport . aus . media .
INFO:tensorflow:GENERATED SUMMARY: indian . indian . virat . kohli . players . trolling . players over . physio . team . trump . daily . kohli pak . s decided . s australian . s kohli . s accused . virat decided . kohli media . s trying . hide . egg [unk]

INFO:tensorflow:Wrote visualization data to logs_26_5\decode\attn_vis_data.json

INFO:tensorflow:ARTICLE:  actor . ranbir . kapoor . seen . playing . sanjay . dutt . biopic . said .

INFO:tensorflow:GENERATED SUMMARY: contract . actor . contract . killers . romantic . romance . help . guns . romance i . romance report . killer . nawazuddin . s actor . s many . s also . contract said contract . s . contract actor . killers played . romantic gangs . [unk]

INFO:tensorflow:Wrote visualization data to logs_26_5\decode\attn_vis_data.json

INFO:tensorflow:ARTICLE:  american . express . biggest . credit . card . issuer . customer . spending . paid . ceo . ken . chenault . 22 . million . nearly . â¹144 . crore . work . 2016 . 19 . increase . compensation . puts . par . goldman . sachs . ceo . lloyd . blankfein . notably . __amex__ . counts . warren . buffettâs . berkshire . hathaway . biggest . shareholder . climbed . 6 . 5 . 2016 . tumbling . 25 . previous . year .
INFO:tensorflow:REFERENCE SUMMARY: american . express . raises . ceo . s . pay . to . â¹144 . crore . in . a . year .
INFO:tensorflow:GENERATED SUMMARY: express . american . biggest . credit . card . issuer

INFO:tensorflow:REFERENCE SUMMARY: govt . proposes . ban . on . cash . transactions . above . â¹2 . lakh .
INFO:tensorflow:GENERATED SUMMARY: centre . centre . lower . limit . cash . â¹2 . transactions . lakh . â¹3 . lakh sc . s . lakh tweeted . money . centre earlier . s union . s tuesday . proposed . lower tweeted . limit tweeted tweeted . cash tweeted . lakh [unk]

INFO:tensorflow:Wrote visualization data to logs_26_5\decode\attn_vis_data.json

INFO:tensorflow:ARTICLE:  railway . track . designed . run . 19 . storey . residential . building . city . chongqing . china . __liziba__ . station . located . building . sixth . eighth . floor . built . combat . lack . space . city . station . noise . insulation . system . trains . reportedly . run . rubber . tyres . air . suspension . units .
INFO:tensorflow:REFERENCE SUMMARY: china . has . a . train . track . passing . through . an . apartment .
INFO:tensorflow:GENERATED SUMMARY: railway . railway . run . 19 . storey . building . china 

INFO:tensorflow:REFERENCE SUMMARY: us . bans . devices . larger . than . mobiles . on . some . flights . reports .
INFO:tensorflow:GENERATED SUMMARY: united . united . travellers . 10 . 10 over . airports . saudi . north . middle . north s . north i . qatar . s east . s carrying . s devices . banned . banned travelling . to . egypt . turkey . saudi arabia . [unk]

INFO:tensorflow:Wrote visualization data to logs_26_5\decode\attn_vis_data.json

INFO:tensorflow:ARTICLE:  supreme . court . tuesday . pulled . centre . rbi . demonetisation . measures . taken . asking . give . window . citizens . deposit . old . currency . notes . december . 31 . apex . court . gave . central . government . reserve . bank . india . two . weeks . time . reply . notice .
INFO:tensorflow:REFERENCE SUMMARY: why . no . deposit . window . for . old . notes . after . dec . 31 . sc . to . govt .
INFO:tensorflow:GENERATED SUMMARY: supreme . court . pulled . centre . rbi . demonetisation . rbi over . demonetisation in

INFO:tensorflow:REFERENCE SUMMARY: isro . commissions . world . s . 3rd . largest . hypersonic . wind . tunnel .
INFO:tensorflow:GENERATED SUMMARY: isro . isro . commissioned . tunnel . tunnel over . tunnel in . tunnel 2 . tunnel s . tunnel creating . s third . isro aerospace . s tunnel . s hypersonic . s commissioned . isro atmosphere . s earth . tunnel high . shock . tunnel [unk]

INFO:tensorflow:Wrote visualization data to logs_26_5\decode\attn_vis_data.json

INFO:tensorflow:ARTICLE:  sri . lanka . cricket . monday . declared . dead . local . newspaper . published . obituary . reading . affectionate . remembrance . sri . lankan . cricket . died . oval . 19 . march . 2017 . body . cremated . ashes . taken . bangladesh . similar . satirical . obituary . english . cricket . 1882 . gave . birth . ashes . series .
INFO:tensorflow:REFERENCE SUMMARY: newspaper . writes . ashes . like . obituary . for . sri . lanka . cricket .
INFO:tensorflow:GENERATED SUMMARY: sri . sri . declared . dead .


INFO:tensorflow:Wrote visualization data to logs_26_5\decode\attn_vis_data.json

INFO:tensorflow:ARTICLE:  actor . tom . cruise . training . stunt . upcoming . film . mission . impossible . 6 . year . revealed . producer . david . ellison . going . unbelievable . thing . tom . cruise . done . movie . added . david . earlier . tom . filmed . stunt . top . 2 . 717 . feet . tall . burj . khalifa . world . tallest . skyscraper .
INFO:tensorflow:REFERENCE SUMMARY: tom . cruise . training . for . mission . impossible . 6 . stunt . for . a . yr .
INFO:tensorflow:GENERATED SUMMARY: tom . actor . cruise . tom . stunt . mission . impossible . impossible in . 6 . 6 s . impossible revealed . s upcoming . s year . training . s . cruise revealed . tom revealed . producer . david . ellison . unbelievable . tom [unk]

INFO:tensorflow:Wrote visualization data to logs_26_5\decode\attn_vis_data.json

INFO:tensorflow:ARTICLE:  iit . kanpur . annual . technological . entrepreneurial . festival . __techkri


INFO:tensorflow:Wrote visualization data to logs_26_5\decode\attn_vis_data.json

INFO:tensorflow:ARTICLE:  25 . year . old . man . arrested . ghazipur . posting . objectionable . picture . uttar . pradesh . chief . minister . yogi . adityanath . facebook . police . said . monday . badshah . abdul . razak . forged . identity . create . facebook . account . uploaded . image . tension . intensified . hindu . yuva . vahini . members . gathered . protest . act .
INFO:tensorflow:REFERENCE SUMMARY: man . arrested . for . posting . !!__âobjectionable__!! . !!__picâ__!! . of . up . cm . on . fb .
INFO:tensorflow:GENERATED SUMMARY: 25 . year . old . man . ghazipur . objectionable . objectionable over . objectionable s . picture . facebook . yogi . facebook said uttar . s pradesh . s minister . s . arrested . man said facebook . s image . tension . intensified . hindu . yogi [unk]

INFO:tensorflow:Wrote visualization data to logs_26_5\decode\attn_vis_data.json

INFO:tensorflow:ARTICLE:  acco


INFO:tensorflow:ARTICLE:  madhya . pradesh . man . arrested . lodging . false . complaint . via . twitter . indian . railways . personnel . board . bhind . indore . __intercity__ . express . monday . tweet . railway . minister . suresh . prabhu . railway . protection . force . sub . inspector . boarded . train . look . issue . found . intoxicated . tweeted . fake . complaint .
INFO:tensorflow:REFERENCE SUMMARY: man . arrested . for . fake . tweet . complaint . to . railway . minister .
INFO:tensorflow:GENERATED SUMMARY: madhya . pradesh . arrested . man . false . complaint . twitter . railways . railways s . twitter tweeted . railways tweeted . s via . s indian . s railway . man protection . s protection . lodging . arrested inspector . train . look . twitter [unk]

INFO:tensorflow:Wrote visualization data to logs_26_5\decode\attn_vis_data.json

INFO:tensorflow:ARTICLE:  two . sufi . clerics . hazrat . nizamuddin . dargah . delhi . gone . missing . pakistan . last . week . returned . 

INFO:tensorflow:GENERATED SUMMARY: supreme . supreme . karnataka . supply . 2 . 000 . cauvery . cusecs . water . tamil . dispute . tamil related . s compensation . s tamil . s nadu . s tuesday . court . karnataka seeking . s states . loss . property . cauvery [unk]

INFO:tensorflow:Wrote visualization data to logs_26_5\decode\attn_vis_data.json

INFO:tensorflow:ARTICLE:  tibetan . religious . leader . dalai . lama . tuesday . said . global . terrorist . crisis . fast . reaching . nuclear . threshold . crosses . threshold . could . mean . entire . mankind . could . perish . added . dalai . lama . also . said . terrorism . crisis . can . not . resolved . unless . nations . come . together . single . platform .
INFO:tensorflow:REFERENCE SUMMARY: global . terrorism . fast . reaching . nuclear . threshold . dalai . lama .
INFO:tensorflow:GENERATED SUMMARY: terrorist . s . global . terrorist . crisis . crisis i . i . fast . nuclear . threshold . lama . dalai . lama said lama . s threshold . 


INFO:tensorflow:Wrote visualization data to logs_26_5\decode\attn_vis_data.json

INFO:tensorflow:ARTICLE:  indian . captain . virat . kohli . accused . australian . players . trolling . indian . team . physio . patrick . farhart . australian . daily . called . kohli . donald . trump . world . sport . like . president . trump . kohli . decided . blame . media . means . trying . hide . egg . smeared . right . across . face . article . added .
INFO:tensorflow:REFERENCE SUMMARY: kohli . has . become . the . donald . trump . of . world . sport . aus . media .
INFO:tensorflow:GENERATED SUMMARY: indian . indian . virat . kohli . players . trolling . players over . physio . team . trump . daily . kohli pak . s decided . s australian . s kohli . s accused . virat decided . kohli media . s trying . hide . egg [unk]

INFO:tensorflow:Wrote visualization data to logs_26_5\decode\attn_vis_data.json

INFO:tensorflow:ARTICLE:  actor . ranbir . kapoor . seen . playing . sanjay . dutt . biopic . said .


INFO:tensorflow:Wrote visualization data to logs_26_5\decode\attn_vis_data.json

INFO:tensorflow:ARTICLE:  kitchen . gadget . company . __kitchenaid__ . accused . sexism . posting . ad . pink . appliances . aimed . women . ad . featured . words . __kitchenaid__ . women . alongside . pink . products . part . limited . edition . range . raise . money . breast . cancer . charity . __kitchenaid__ . however . stated . colour . used . symbol . hope .
INFO:tensorflow:REFERENCE SUMMARY: firm . accused . of . sexism . for . its . pink . kitchen . gadgets . ad .
INFO:tensorflow:GENERATED SUMMARY: kitchen . gadget . kitchenaid . sexism . ad . ad in . pink . appliances . ad 2 . s . women . ad however . s aimed . s pink . s kitchenaid . gadget edition . s raise . money . breast . cancer . charity . ad stated . [unk]

INFO:tensorflow:Wrote visualization data to logs_26_5\decode\attn_vis_data.json

INFO:tensorflow:ARTICLE:  singer . sukhwinder . singh . said . get . married . soon . jump . well . di

INFO:tensorflow:GENERATED SUMMARY: apple . major . ceo . apple . cook . cook in . billion . dollar . cook 2 . s . china . ofo . s tuesday . apple provides . s tim . s technology . apple 2016 . to . apple said apple . didi . chuxing . apple [unk]

INFO:tensorflow:Wrote visualization data to logs_26_5\decode\attn_vis_data.json

INFO:tensorflow:ARTICLE:  founders . 170 . startups . including . paytm . ola . written . letter . home . minister . rajnath . singh . stayzilla . row . co . founder . yogendra . vasupal . arrested . unpaid . dues . founders . opposed . arrest . called . free . fair . investigation . matter . notably . another . stayzilla . co . founder . received . threats . son . life .
INFO:tensorflow:REFERENCE SUMMARY: 170 . founders . write . open . letter . to . home . minister . on . stayzilla .
INFO:tensorflow:GENERATED SUMMARY: founders . founders . startups . paytm . ola . ola in . letter . rajnath . home . rajnath in . rajnath notably . s minister . s written . s home .

INFO:tensorflow:REFERENCE SUMMARY: 77 . cases . of . sedition . registered . in . 2014 . 2015 . centre .
INFO:tensorflow:GENERATED SUMMARY: 77 . many . cases . sedition . parts . parts i . 30 . 2014 . home . bihar . hansraj . home i . s country . s registered . s cases . cases said 77 . 77 . jharkhand . highest . 18 . 30 witnessed . [unk]

INFO:tensorflow:Wrote visualization data to logs_26_5\decode\attn_vis_data.json


KeyboardInterrupt: 