34 changes: 25 additions & 9 deletions tensor2tensor/data_generators/cnn_dailymail.py
@@ -19,9 +19,10 @@
 from __future__ import division
 from __future__ import print_function
 
+import hashlib
+import io
 import os
 import tarfile
-import hashlib
 
 # Dependency imports
 
@@ -46,7 +47,7 @@
 # Train/Dev/Test Splits for summarization data
 _TRAIN_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt"
 _DEV_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt"
-_TEST_URLS = "https://github.com/abisee/cnn-dailymail/blob/master/url_lists/all_test.txt"
+_TEST_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt"
 
 
 # End-of-sentence marker.
@@ -128,9 +129,7 @@ def generate_hash(inp):
 
   return filelist
 
-
-def example_generator(tmp_dir, is_training, sum_token):
-  """Generate examples."""
+def example_generator(all_files, urls_path, sum_token):
   def fix_run_on_sents(line):
     if u"@highlight" in line:
       return line
@@ -140,7 +139,6 @@ def fix_run_on_sents(line):
       return line
     return line + u"."
 
-  all_files, urls_path = _maybe_download_corpora(tmp_dir, is_training)
   filelist = example_splits(urls_path, all_files)
   story_summary_split_token = u" <summary> " if sum_token else " "
 
@@ -170,13 +168,29 @@ def fix_run_on_sents(line):
 
     yield " ".join(story) + story_summary_split_token + " ".join(summary)
 
-
 def _story_summary_split(story):
   split_str = u" <summary> "
   split_str_len = len(split_str)
   split_pos = story.find(split_str)
   return story[:split_pos], story[split_pos+split_str_len:]  # story, summary
 
+def write_raw_text_to_files(all_files, urls_path, data_dir, tmp_dir, is_training):
+  def write_to_file(all_files, urls_path, data_dir, filename):
+    with io.open(os.path.join(data_dir, filename+".source"), "w") as fstory, io.open(os.path.join(data_dir, filename+".target"), "w") as fsummary:
+      for example in example_generator(all_files, urls_path, sum_token=True):
+        story, summary = _story_summary_split(example)
+        fstory.write(story+"\n")
+        fsummary.write(summary+"\n")
+
+  filename = "cnndm.train" if is_training else "cnndm.dev"
+  tf.logging.info("Writing %s" % filename)
+  write_to_file(all_files, urls_path, data_dir, filename)
+
+  if not is_training:
+    test_urls_path = generator_utils.maybe_download(tmp_dir, "all_test.txt", _TEST_URLS)
+    filename = "cnndm.test"
+    tf.logging.info("Writing %s" % filename)
+    write_to_file(all_files, test_urls_path, data_dir, filename)
+
 @registry.register_problem
 class SummarizeCnnDailymail32k(problem.Text2TextProblem):
@@ -219,10 +233,12 @@ def use_train_shards_for_dev(self):
     return False
 
   def generator(self, data_dir, tmp_dir, is_training):
+    all_files, urls_path = _maybe_download_corpora(tmp_dir, is_training)
     encoder = generator_utils.get_or_generate_vocab_inner(
         data_dir, self.vocab_file, self.targeted_vocab_size,
-        example_generator(tmp_dir, is_training, sum_token=False))
-    for example in example_generator(tmp_dir, is_training, sum_token=True):
+        example_generator(all_files, urls_path, sum_token=False))
+    write_raw_text_to_files(all_files, urls_path, data_dir, tmp_dir, is_training)
+    for example in example_generator(all_files, urls_path, sum_token=True):
       story, summary = _story_summary_split(example)
       encoded_summary = encoder.encode(summary) + [EOS]
       encoded_story = encoder.encode(story) + [EOS]
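Taken together, the cnn_dailymail.py changes hoist the corpus download out of example_generator so that one download can feed vocabulary generation, raw-text export, and example encoding. A minimal sketch of the resulting flow inside generator() (illustrative only; _maybe_download_corpora is defined earlier in cnn_dailymail.py and is unchanged by this diff):

    # Download/locate the CNN and Daily Mail corpora and the URL split list once.
    all_files, urls_path = _maybe_download_corpora(tmp_dir, is_training)

    # Side effect: writes cnndm.train/.dev/.test as .source/.target text files,
    # which get_cnndm_rouge.sh consumes later.
    write_raw_text_to_files(all_files, urls_path, data_dir, tmp_dir, is_training)

    # With sum_token=True each yielded string is "<story> <summary> <summary text>",
    # which _story_summary_split cuts back into a (story, summary) pair.
    for example in example_generator(all_files, urls_path, sum_token=True):
      story, summary = _story_summary_split(example)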
4 changes: 2 additions & 2 deletions tensor2tensor/utils/decoding.py
@@ -83,9 +83,9 @@ def log_decode_results(inputs,
 
   decoded_targets = None
   if identity_output:
-    decoded_outputs = " ".join(map(str, outputs.flatten()))
+    decoded_outputs = "".join(map(str, outputs.flatten()))
     if targets is not None:
-      decoded_targets = " ".join(map(str, targets.flatten()))
+      decoded_targets = "".join(map(str, targets.flatten()))
   else:
     decoded_outputs = targets_vocab.decode(_save_until_eos(outputs, is_image))
     if targets is not None:
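The decoding.py change only affects the identity-output logging path: flattened integer outputs are now concatenated without a space separator. A toy illustration (not part of the diff):

    import numpy as np

    outputs = np.array([[3, 7, 1]])
    print(" ".join(map(str, outputs.flatten())))  # old behaviour: "3 7 1"
    print("".join(map(str, outputs.flatten())))   # new behaviour: "371"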
16 changes: 16 additions & 0 deletions tensor2tensor/utils/get_cnndm_rouge.sh
@@ -0,0 +1,16 @@
+#!/bin/bash
+
+# Path to moses dir
+mosesdecoder=$1
+
+# Path to file containing gold summaries, one per line
+targets_file=$2
+# Path to file containing model generated summaries, one per line
+decodes_file=$3
+
+# Tokenize.
+perl $mosesdecoder/scripts/tokenizer/tokenizer.perl -l en < $targets_file > $targets_file.tok
+perl $mosesdecoder/scripts/tokenizer/tokenizer.perl -l en < $decodes_file > $decodes_file.tok
+
+# Get rouge scores
+python get_rouge.py --decodes_filename $decodes_file.tok --targets_filename $targets_file.tok
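Assuming a local Moses checkout for the tokenizer, the intended invocation is something like `bash get_cnndm_rouge.sh ~/mosesdecoder cnndm.test.target decodes.txt` (paths illustrative), where both files contain one summary per line in the same order.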
88 changes: 88 additions & 0 deletions tensor2tensor/utils/get_rouge.py
@@ -0,0 +1,88 @@
+# coding=utf-8
+# Copyright 2017 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Computing ROUGE scores using pyrouge."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+import os
+import shutil
+from tempfile import mkdtemp
+
+# Dependency imports
+from pyrouge import Rouge155
+
+import tensorflow as tf
+
+FLAGS = tf.flags.FLAGS
+
+tf.flags.DEFINE_string("decodes_filename", None, "File containing tokenized model-generated summaries, one per line")
+tf.flags.DEFINE_string("targets_filename", None, "File containing tokenized target summaries, one per line")
+
+def write_to_file(filename, data):
+  # ROUGE-1.5.5 expects one sentence per line.
+  data = ".\n".join(data.split(". "))
+  with open(filename, "w") as fp:
+    fp.write(data)
+
+def prep_data(decode_dir, target_dir):
+  with open(FLAGS.decodes_filename, "r") as fdecodes, open(FLAGS.targets_filename, "r") as ftargets:
+    # One file per example, numbered the way the filename patterns in main expect.
+    for i, (d, t) in enumerate(zip(fdecodes, ftargets)):
+      write_to_file(os.path.join(decode_dir, "rouge.%06d.txt" % (i+1)), d)
+      write_to_file(os.path.join(target_dir, "rouge.A.%06d.txt" % (i+1)), t)
+
+      if (i + 1) % 1000 == 0:
+        tf.logging.info("Written %d examples to file" % (i + 1))
+
+def main(_):
+  rouge = Rouge155()
+  rouge.log.setLevel(logging.ERROR)
+  rouge.system_filename_pattern = r"rouge.(\d+).txt"
+  rouge.model_filename_pattern = "rouge.[A-Z].#ID#.txt"
+
+  tf.logging.set_verbosity(tf.logging.INFO)
+
+  tmpdir = mkdtemp()
+  tf.logging.info("tmpdir: %s" % tmpdir)
+  # system = decodes/predictions
+  system_dir = os.path.join(tmpdir, "system")
+  # model = targets/gold
+  model_dir = os.path.join(tmpdir, "model")
+  os.mkdir(system_dir)
+  os.mkdir(model_dir)
+
+  rouge.system_dir = system_dir
+  rouge.model_dir = model_dir
+
+  prep_data(rouge.system_dir, rouge.model_dir)
+
+  rouge_scores = rouge.convert_and_evaluate()
+  rouge_scores = rouge.output_to_dict(rouge_scores)
+  for prefix in ["rouge_1", "rouge_2", "rouge_l"]:
+    for suffix in ["f_score", "precision", "recall"]:
+      key = "_".join([prefix, suffix])
+      tf.logging.info("%s: %.4f" % (key, rouge_scores[key]))
+
+  # clean up after pyrouge
+  shutil.rmtree(tmpdir)
+  shutil.rmtree(rouge._config_dir)
+  shutil.rmtree(os.path.split(rouge._system_dir)[0])
+
+if __name__ == "__main__":
+  tf.app.run()
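get_rouge.py is normally driven by get_cnndm_rouge.sh above, but it can also be run directly once pyrouge and ROUGE-1.5.5 are installed, e.g. `python get_rouge.py --decodes_filename decodes.txt.tok --targets_filename targets.txt.tok` (filenames illustrative). It stages one file per example in a temporary directory, in the rouge.NNNNNN.txt / rouge.A.NNNNNN.txt layout matched by the filename patterns set in main, then runs ROUGE via pyrouge and logs ROUGE-1/2/L F-score, precision, and recall from the output dict.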