Merge pull request #670 from ufal/wordpieces
Wordpiece implementation
jlibovicky committed Mar 14, 2018
2 parents 07ec99c + 86ea5a1 commit b95b8de
Showing 5 changed files with 325 additions and 1 deletion.
122 changes: 122 additions & 0 deletions neuralmonkey/processors/wordpiece.py
@@ -0,0 +1,122 @@
"""Loose reimplementation of the t2t tokenizer.
Original code:
https://github.com/tensorflow/tensor2tensor/blob/v1.5.5/tensor2tensor/data_generators/tokenizer.py
Provides a WordpiecePreprocessor, a higher order function which takes a
vocabulary object and returns a preprocessor, and a WordpiecePostprocessor.
Note that the latter is not a higher order function and can be used directly
without making a new section in the configuration.
"""
from typing import List, Callable, Set
import re

from typeguard import check_argument_types
from neuralmonkey.vocabulary import Vocabulary


UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);")


def escape_token(token: str, alphabet: Set[str]) -> str:
    """Escape the token in the t2t fashion.
    Underscores are regarded as the end of a token, so they must be escaped.
    Additionally, out-of-alphabet (OOA) characters are escaped using their
    Unicode code points.
    """
    esc_token = token.replace("\\", "\\\\")  # replace 1 backslash with 2
    esc_token = esc_token.replace("_", "\\u")  # replace underscore with "\u"

    # replace an OOA symbol `s` with \1234; where 1234 is `ord(s)`
    # (newlines are escaped even when they are part of the alphabet)
    characters = [c if c in alphabet and c != "\n" else "\\{};".format(ord(c))
                  for c in esc_token]

    return "".join(characters) + "_"


def unescape_token(escaped_token: str) -> str:
    """Inverse function for escape_token."""

    # Ends with an underscore -> remove it
    token = escaped_token
    token = token[:-1] if token.endswith("_") else token

    def match(m):
        if m.group(1) is None:
            return "_" if m.group(0) == "\\u" else "\\"

        try:
            return chr(int(m.group(1)))
        except (ValueError, OverflowError):
            return u"\u3013"  # Unicode for undefined character.

    # The substitution works because of the left-to-right nature of matching
    return UNESCAPE_REGEX.sub(match, token)


def wordpiece_encode(sentence: List[str], vocabulary: Vocabulary) -> List[str]:
    """Convert tokens to subtokens using a vocabulary of subtokens.
    A greedy implementation, as in the t2t code referenced above: we search
    for the longest subtoken available in the vocabulary, going from left
    to right.
    """
    tokens = []
    for token in sentence:
        esc_token = escape_token(token, vocabulary.alphabet)

        subtokens = []
        current_subtoken_start = 0
        token_len = len(esc_token)

        while current_subtoken_start < len(esc_token):

            # TODO: t2t optimizes this by iterating only down from
            # min(token_len, current_subtoken_start + max_subtoken_len);
            # this can be achieved by storing the length of the longest
            # word in the vocabulary
            for end in range(token_len, current_subtoken_start, -1):
                subtoken = esc_token[current_subtoken_start:end]

                if subtoken in vocabulary.word_to_index:
                    subtokens.append(subtoken)
                    current_subtoken_start = end
                    break
            else:  # executed if the loop is not exited by the break statement
                raise AssertionError(
                    "No token substring found in the vocab ({})."
                    .format(esc_token[current_subtoken_start:]))

        # TODO: t2t also optimizes this by caching the segmentation of the
        # escaped tokens.
        tokens.extend(subtokens)
    return tokens


def wordpiece_decode(sentence: List[str]) -> List[str]:
    """Postprocess the wordpieces into a sentence.
    First, retokenize the sentence - join and split around underscores.
    Second, unescape the tokens, throwing away any empty tokens encountered.
    """
    retokenized = "".join(sentence).split("_")
    unescaped = [unescape_token(tok) for tok in retokenized if tok]
    return [tok for tok in unescaped if tok]


def wordpiece_decode_batch(sentences: List[List[str]]) -> List[List[str]]:
    return [wordpiece_decode(s) for s in sentences]


def get_wordpiece_preprocessor(
        vocabulary: Vocabulary) -> Callable[[List[str]], List[str]]:
    check_argument_types()
    return lambda s: wordpiece_encode(s, vocabulary)


# pylint: disable=invalid-name
# Syntactic sugar for configuration
WordpiecePreprocessor = get_wordpiece_preprocessor
WordpiecePostprocessor = wordpiece_decode_batch
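
The module exposes the preprocessor as a higher-order function (it takes a Vocabulary and returns the actual callable), while the postprocessor is used directly. A minimal usage sketch follows; it is not part of the commit, and the toy vocabulary and the example sentence are invented here, mirroring the unit tests added further below.

    from neuralmonkey.vocabulary import Vocabulary
    from neuralmonkey.processors.wordpiece import (
        WordpiecePreprocessor, WordpiecePostprocessor)

    # Toy vocabulary, invented for this sketch: every character of the
    # expected input plus the t2t escape characters, each of them also with
    # the end-of-token underscore, and one full word.
    vocabulary = Vocabulary()
    for char in set("I am the walrus") | set("\\_u;0123456789"):
        vocabulary.add_word(char)
        vocabulary.add_word(char + "_")
    vocabulary.add_word("walrus_")

    preprocess = WordpiecePreprocessor(vocabulary)
    pieces = preprocess("I am the walrus".split())
    # greedy longest-match segmentation, e.g.:
    # ["I_", "a", "m_", "t", "h", "e_", "walrus_"]

    sentences = WordpiecePostprocessor([pieces])
    # [["I", "am", "the", "walrus"]]

The escape characters are added the same way the new test does, so that out-of-alphabet symbols escaped by escape_token can still be segmented into known pieces.
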
48 changes: 48 additions & 0 deletions neuralmonkey/readers/plain_text_reader.py
@@ -3,6 +3,7 @@
import csv
import io
import sys
import unicodedata

from neuralmonkey.logging import warn

@@ -40,6 +41,52 @@ def reader(files: List[str]) -> Iterable[List[str]]:
    return reader


def t2t_tokenized_text_reader(encoding: str = "utf-8") -> PlainTextFileReader:
    """Get a tokenizing reader for plain text.
    Tokenization is inspired by the tensor2tensor tokenizer:
    https://github.com/tensorflow/tensor2tensor/blob/v1.5.5/tensor2tensor/data_generators/text_encoder.py
    The text is split into groups of consecutive alphanumeric or
    non-alphanumeric characters, dropping single spaces inside the text. The
    goal is to preserve whitespace around unusual characters as well as
    whitespace at unusual positions (the beginning and end of the text).
    """
    alphanumeric_charset = set(
        chr(i) for i in range(sys.maxunicode)
        if (unicodedata.category(chr(i)).startswith("L") or
            unicodedata.category(chr(i)).startswith("N")))

    def reader(files: List[str]) -> Iterable[List[str]]:
        lines = string_reader(encoding)
        for line in lines(files):
            if not line:
                yield []
                continue

            tokens = []
            is_alnum = [ch in alphanumeric_charset for ch in line]
            current_token_start = 0

            for pos in range(1, len(line)):
                # Boundary of alnum and non-alnum character groups
                if is_alnum[pos] != is_alnum[pos - 1]:
                    token = line[current_token_start:pos]

                    # Drop a single space unless it is at the beginning
                    if token != " " or current_token_start == 0:
                        tokens.append(token)

                    current_token_start = pos

            # Add the final token (even if it is a single space)
            final_token = line[current_token_start:]
            tokens.append(final_token)

            yield tokens

    return reader


def column_separated_reader(
        column: int, delimiter: str = "\t", quotechar: str = None,
        encoding: str = "utf-8") -> PlainTextFileReader:
@@ -87,4 +134,5 @@ def tsv_reader(column: int):

# pylint: disable=invalid-name
UtfPlainTextReader = tokenized_text_reader()
T2TReader = t2t_tokenized_text_reader()
# pylint: enable=invalid-name
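
To make the splitting rule concrete, here is a standalone sketch (not part of the commit; the helper name t2t_tokenize_line and the sample strings are invented) that replays the same boundary logic on single strings:

    import sys
    import unicodedata
    from typing import List

    # Letters and numbers, the same charset the reader builds internally.
    ALNUM = set(
        chr(i) for i in range(sys.maxunicode)
        if unicodedata.category(chr(i))[0] in ("L", "N"))


    def t2t_tokenize_line(line: str) -> List[str]:
        """Split one line the way the new reader does (illustrative copy)."""
        if not line:
            return []
        tokens = []
        is_alnum = [ch in ALNUM for ch in line]
        start = 0
        for pos in range(1, len(line)):
            if is_alnum[pos] != is_alnum[pos - 1]:
                token = line[start:pos]
                # single inner spaces are dropped; leading whitespace is kept
                if token != " " or start == 0:
                    tokens.append(token)
                start = pos
        tokens.append(line[start:])
        return tokens


    print(t2t_tokenize_line("Hello world"))    # ['Hello', 'world']
    print(t2t_tokenize_line(" čermák -=- !"))  # [' ', 'čermák', ' -=- !']

Single separating spaces can be dropped because, downstream, the wordpiece representation marks token boundaries explicitly, so a lone inner space carries no additional information.
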
23 changes: 23 additions & 0 deletions neuralmonkey/tests/test_readers.py
@@ -6,6 +6,7 @@
import numpy as np

from neuralmonkey.readers.string_vector_reader import get_string_vector_reader
from neuralmonkey.readers.plain_text_reader import T2TReader

STRING_INTS = """
1 2 3
@@ -96,5 +97,27 @@ def tearDown(self):
        self.tmpfile_ints_fine.close()


class TestT2TReader(unittest.TestCase):

    def setUp(self):
        self.reader = T2TReader

    def test_reader(self):
        text = "Ich bin der čermák -=- - !!! alfonso "
        gold_tokens = ["Ich", "bin", " ", "der", "čermák", " -=- - !!! ",
                       "alfonso", " "]

        tmpfile = _make_file(text)

        read = []
        for line in self.reader([tmpfile.name]):
            read.append(line)

        tmpfile.close()

        self.assertEqual(len(read), 1)
        self.assertSequenceEqual(read[0], gold_tokens)


if __name__ == "__main__":
    unittest.main()
92 changes: 92 additions & 0 deletions neuralmonkey/tests/test_wordpiece.py
@@ -0,0 +1,92 @@
#!/usr/bin/env python3.5
import unittest

from neuralmonkey.vocabulary import Vocabulary
from neuralmonkey.processors.wordpiece import (
    WordpiecePreprocessor, WordpiecePostprocessor)

CORPUS = [
    "the colorless ideas slept furiously",
    "pooh slept all night",
    "working class hero is something to be",
    "I am the working class walrus",
    "walrus for president"
]

TOKENIZED_CORPUS = [[a + "_" for a in s.split()] for s in CORPUS]

# Create list of characters required to process the CORPUS with wordpieces
CORPUS_CHARS = [x for c in set("".join(CORPUS)) for x in [c, c + "_"]]
ESCAPE_CHARS = "\\_u0987654321;"
C_CARON = "\\269;"
A_ACUTE = "225"


class TestWordpieces(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        vocabulary = Vocabulary()

        for c in CORPUS_CHARS + list(ESCAPE_CHARS):
            vocabulary.add_word(c)

        for sent in TOKENIZED_CORPUS:
            vocabulary.add_tokenized_text(sent)

        vocabulary.add_word(C_CARON)
        vocabulary.add_word(A_ACUTE)

        cls.preprocessor = WordpiecePreprocessor(vocabulary)
        cls.postprocessor = WordpiecePostprocessor

    def test_preprocess_ok(self):
        raw = "I am the walrus".split()
        gold = "I_ am_ the_ walrus_".split()

        preprocessed = TestWordpieces.preprocessor(raw)
        self.assertSequenceEqual(preprocessed, gold)

    def test_preprocess_split(self):
        raw = "Ich bin der walrus".split()
        gold = "I c h_ b i n_ d e r_ walrus_".split()

        preprocessed = TestWordpieces.preprocessor(raw)
        self.assertSequenceEqual(preprocessed, gold)

    def test_preprocess_unk(self):
        raw = "Ich bin der čermák".split()
        gold = "I c h_ b i n_ d e r_ \\269; e r m \\ 225 ; k_".split()

        preprocessed = TestWordpieces.preprocessor(raw)
        self.assertSequenceEqual(preprocessed, gold)

    def test_postprocess_ok(self):
        output = "I_ am_ the_ walrus_".split()
        gold = ["I am the walrus".split()]

        postprocessed = TestWordpieces.postprocessor([output])
        self.assertSequenceEqual(postprocessed, gold)

    def test_postprocess_split(self):
        output = "I c h_ b i n_ d e r_ walrus_".split()
        gold = ["Ich bin der walrus".split()]

        postprocessed = TestWordpieces.postprocessor([output])
        self.assertSequenceEqual(postprocessed, gold)

    def test_postprocess_unk(self):
        output = "I c h_ b i n_ d e r_ \\269; e r m \\ 225 ; k_".split()
        gold = ["Ich bin der čermák".split()]

        postprocessed = TestWordpieces.postprocessor([output])
        self.assertSequenceEqual(postprocessed, gold)

    # TODO (#669): implement wordpiece generator
    @unittest.skip("not implemented yet")
    def test_make_wordpieces(self):
        pass


if __name__ == "__main__":
    unittest.main()
41 changes: 40 additions & 1 deletion neuralmonkey/vocabulary.py
@@ -66,7 +66,7 @@ def from_wordlist(path: str,
    Arguments:
        path: The path to the wordlist file
        encoding: The encoding of the merge file (defaults to UTF-8)
        encoding: The encoding of the wordlist file (defaults to UTF-8)
        contains_header: if the file has a header on the first line
        contains_frequencies: if the file contains frequencies in the second column
@@ -111,6 +111,39 @@ def from_wordlist(path: str,
    return vocabulary


def from_t2t_vocabulary(path: str,
                        encoding: str = "utf-8") -> "Vocabulary":
    """Load a vocabulary generated during tensor2tensor training.
    Arguments:
        path: The path to the vocabulary file.
        encoding: The encoding of the vocabulary file (defaults to UTF-8).
    Returns:
        The new Vocabulary instance.
    """
    vocabulary = Vocabulary()

    with open(path, encoding=encoding) as wordlist:
        for line in wordlist:
            line = line.strip()

            # T2T vocab tends to wrap words in single quotes
            if ((line.startswith("'") and line.endswith("'")) or
                    (line.startswith('"') and line.endswith('"'))):
                line = line[1:-1]

            if line in ["<pad>", "<EOS>"]:
                continue

            vocabulary.add_word(line)

    log("Vocabulary from wordlist loaded, containing {} words"
        .format(len(vocabulary)))
    vocabulary.log_sample()
    return vocabulary


def from_nematus_json(path: str, max_size: int = None,
                      pad_to_max_size: bool = False) -> "Vocabulary":
    """Load vocabulary from Nematus JSON format.
@@ -295,6 +328,7 @@ def __init__(self, tokenized_text: List[str] = None,
        self.word_to_index = {}  # type: Dict[str, int]
        self.index_to_word = []  # type: List[str]
        self.word_count = {}  # type: Dict[str, int]
        self.alphabet = {tok for tok in _SPECIAL_TOKENS}

        # flag whether the word counts are in use
        self.correct_counts = False
@@ -339,8 +373,13 @@ def add_word(self, word: str, occurences: int = 1) -> None:
            self.word_to_index[word] = len(self.index_to_word)
            self.index_to_word.append(word)
            self.word_count[word] = 0
            if word not in _SPECIAL_TOKENS:
                self.add_characters(word)
        self.word_count[word] += occurences

    def add_characters(self, word: str) -> None:
        self.alphabet |= {c for c in word}

    def add_tokenized_text(self, tokenized_text: List[str]) -> None:
        """Add words from a list to the vocabulary.
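
As a usage sketch for the new loader (not part of the commit; the file path and its contents are invented, following the format the parsing code above tolerates: one subword per line, optionally quoted, with <pad> and <EOS> entries that are skipped):

    from neuralmonkey.vocabulary import from_t2t_vocabulary

    # Write a tiny, made-up t2t-style vocabulary file.
    with open("/tmp/t2t_vocab.subwords", "w", encoding="utf-8") as f:
        f.write("'<pad>'\n'<EOS>'\n'the_'\n'walrus_'\n'wal'\n'rus_'\n")

    vocab = from_t2t_vocabulary("/tmp/t2t_vocab.subwords")

    # Each added subword also feeds the new `alphabet` attribute (via
    # add_characters), which escape_token consults during preprocessing.
    print(len(vocab))             # vocabulary size, incl. special tokens
    print("w" in vocab.alphabet)  # True
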
