Empty file modified tensor2tensor/bin/t2t-datagen
100644 → 100755
Empty file modified tensor2tensor/bin/t2t-trainer
100644 → 100755
7 changes: 4 additions & 3 deletions tensor2tensor/data_generators/generator_utils.py
100644 → 100755
@@ -324,18 +324,19 @@ def get_or_generate_tabbed_vocab(tmp_dir, source_filename,
     return vocab
 
   # Use Tokenizer to count the word occurrences.
+  token_counts = defaultdict(int)
   filepath = os.path.join(tmp_dir, source_filename)
   with tf.gfile.GFile(filepath, mode="r") as source_file:
     for line in source_file:
       line = line.strip()
       if line and "\t" in line:
         parts = line.split("\t", maxsplit=1)
         part = parts[index].strip()
-        _ = tokenizer.encode(text_encoder.native_to_unicode(part))
+        for tok in tokenizer.encode(text_encoder.native_to_unicode(part)):
+          token_counts[tok] += 1
 
   vocab = text_encoder.SubwordTextEncoder.build_to_target_size(
-      vocab_size, tokenizer.token_counts, 1,
-      min(1e3, vocab_size + text_encoder.NUM_RESERVED_TOKENS))
+      vocab_size, token_counts, 1, 1e3)
   vocab.store_to_file(vocab_filepath)
   return vocab
 
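The rewrite makes the occurrence counting explicit: tokens are tallied into a local defaultdict and that dict is passed straight to build_to_target_size, instead of relying on counts accumulated inside the tokenizer object. A minimal self-contained sketch of the same counting pattern (count_tokens and the whitespace encode below are made-up stand-ins, not part of the codebase):

from collections import defaultdict

def count_tokens(lines, encode):
  """Tally token occurrences across an iterable of text lines."""
  token_counts = defaultdict(int)
  for line in lines:
    for tok in encode(line.strip()):
      token_counts[tok] += 1  # missing keys start at 0
  return token_counts

# Example with a whitespace "tokenizer" standing in for tokenizer.encode:
counts = count_tokens(["hello world", "hello there"], lambda s: s.split())
assert counts["hello"] == 2 and counts["world"] == 1
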
48 changes: 19 additions & 29 deletions tensor2tensor/data_generators/text_encoder.py
@@ -24,6 +24,7 @@
 from __future__ import print_function
 
 from collections import defaultdict
+import re
 
 # Dependency imports
 
@@ -225,6 +226,7 @@ class SubwordTextEncoder(TextEncoder):
 
   def __init__(self, filename=None):
     """Initialize and read from a file, if provided."""
+    self._alphabet = set()
     if filename is not None:
       self._load_from_file(filename)
     super(SubwordTextEncoder, self).__init__(num_reserved_ids=None)
@@ -503,6 +505,12 @@ def _escape_token(self, token):
       ret += u"\\%d;" % ord(c)
     return ret
 
+  # Regular expression for unescaping token strings
+  # '\u' is converted to '_'
+  # '\\' is converted to '\'
+  # '\213;' is converted to unichr(213)
+  _UNESCAPE_REGEX = re.compile(u'|'.join([r"\\u", r"\\\\", r"\\([0-9]+);"]))
+
   def _unescape_token(self, escaped_token):
     """Inverse of _escape_token().
 
@@ -511,32 +519,14 @@ def _unescape_token(self, escaped_token):
     Returns:
       token: a unicode string
     """
-    ret = u""
-    escaped_token = escaped_token[:-1]
-    pos = 0
-    while pos < len(escaped_token):
-      c = escaped_token[pos]
-      if c == "\\":
-        pos += 1
-        if pos >= len(escaped_token):
-          break
-        c = escaped_token[pos]
-        if c == u"u":
-          ret += u"_"
-          pos += 1
-        elif c == "\\":
-          ret += u"\\"
-          pos += 1
-        else:
-          semicolon_pos = escaped_token.find(u";", pos)
-          if semicolon_pos == -1:
-            continue
-          try:
-            ret += unichr(int(escaped_token[pos:semicolon_pos]))
-            pos = semicolon_pos + 1
-          except (ValueError, OverflowError) as _:
-            pass
-      else:
-        ret += c
-        pos += 1
-    return ret
+    def match(m):
+      if m.group(1) is not None:
+        # Convert '\213;' to unichr(213)
+        try:
+          return unichr(int(m.group(1)))
+        except (ValueError, OverflowError) as _:
+          return ""
+      # Convert '\u' to '_' and '\\' to '\'
+      return u"_" if m.group(0) == u"\\u" else u"\\"
+    # Cut off the trailing underscore and apply the regex substitution
+    return self._UNESCAPE_REGEX.sub(match, escaped_token[:-1])
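
The new implementation collapses the hand-rolled scanning loop into one compiled regex with three alternatives (one per escape form) plus a match callback that maps each hit to its replacement; re.sub then handles the left-to-right, non-overlapping traversal. A standalone sketch of the same technique, using Python 3's chr where the Python 2 original uses unichr:

import re

# The same three alternatives as _UNESCAPE_REGEX above:
# '\u' -> '_', '\\' -> '\', and '\NNN;' -> chr(NNN).
UNESCAPE_REGEX = re.compile(u"|".join([r"\\u", r"\\\\", r"\\([0-9]+);"]))

def unescape_token(escaped_token):
  """Inverse of the escaping scheme; expects the trailing '_' terminator."""
  def match(m):
    if m.group(1) is not None:  # matched '\NNN;'
      try:
        return chr(int(m.group(1)))
      except (ValueError, OverflowError):
        return u""
    return u"_" if m.group(0) == u"\\u" else u"\\"
  return UNESCAPE_REGEX.sub(match, escaped_token[:-1])

# Round trip covering all three escape forms:
assert unescape_token(u"foo\\u bar\\\\baz\\213;_") == u"foo_ bar\\baz" + chr(213)
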
2 changes: 1 addition & 1 deletion tensor2tensor/data_generators/tokenizer_test.py
100644 → 100755
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
 # Copyright 2017 The Tensor2Tensor Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# coding=utf-8
 """Tests for tensor2tensor.data_generators.tokenizer."""
 
 from __future__ import absolute_import
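The encoding declaration moves to line 1 because Python only honors it on the first or second line of a file (PEP 263); the old # coding=utf-8 buried beneath the license header was ignored by the parser, which matters once the test file contains non-ASCII literals. In effect:

# -*- coding: utf-8 -*-  # honored: first line of the file
# ...license header...
# coding=utf-8           # ignored: too late to take effect
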
9 changes: 6 additions & 3 deletions tensor2tensor/utils/trainer_utils.py
100644 → 100755
@@ -585,6 +585,7 @@ def decode_from_dataset(estimator):
   tf.logging.info("Performing local inference.")
   infer_problems_data = get_datasets_for_mode(hparams.data_dir,
                                               tf.contrib.learn.ModeKeys.INFER)
+
   infer_input_fn = get_input_fn(
       mode=tf.contrib.learn.ModeKeys.INFER,
       hparams=hparams,
@@ -625,9 +626,11 @@ def log_fn(inputs,
 
   # The function predict() returns an iterable over the network's
   # predictions from the test input. We use it to log inputs and decodes.
-  for j, result in enumerate(result_iter):
-    inputs, targets, outputs = (result["inputs"], result["targets"],
-                                result["outputs"])
+  inputs_iter = result_iter["inputs"]
+  targets_iter = result_iter["targets"]
+  outputs_iter = result_iter["outputs"]
+  for j, result in enumerate(zip(inputs_iter, targets_iter, outputs_iter)):
+    inputs, targets, outputs = result
    if FLAGS.decode_return_beams:
      output_beams = np.split(outputs, FLAGS.decode_beam_size, axis=0)
      for k, beam in enumerate(output_beams):
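After this change, result_iter is consumed as a dict of per-key streams (one iterator each for inputs, targets, and outputs) rather than a single iterator of per-example dicts, with zip reassembling the per-example triples. A minimal sketch of the pattern, using made-up data in place of real Estimator predictions:

# Hypothetical stand-in for the post-change predict() output:
# a dict mapping each output key to an iterable of per-example values.
result_iter = {
    "inputs": [[1, 2], [3, 4]],
    "targets": [[5], [6]],
    "outputs": [[7, 8], [9, 10]],
}

inputs_iter = result_iter["inputs"]
targets_iter = result_iter["targets"]
outputs_iter = result_iter["outputs"]
# zip stops at the shortest stream and yields one (input, target, output)
# tuple per decoded example, preserving the original loop body.
for j, (inputs, targets, outputs) in enumerate(
    zip(inputs_iter, targets_iter, outputs_iter)):
  print(j, inputs, targets, outputs)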