Skip to content

Commit

Permalink
[BC Breaking] Remove unicode_csv_reader from raw translation datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangguanheng66 committed Nov 19, 2020
1 parent d7138b2 commit 633548a
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 20 deletions.
28 changes: 14 additions & 14 deletions test/data/test_builtin_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,42 +167,42 @@ def test_multi30k(self):
# smoke test to ensure multi30k works properly
train_dataset, valid_dataset, test_dataset = Multi30k()
self._helper_test_func(len(train_dataset), 29000, train_dataset[20],
- ([3, 443, 2530, 46, 17478, 7422, 7, 157, 9, 11, 5848, 2],
- [4, 60, 529, 136, 1493, 9, 8, 279, 5, 2, 3748, 3]))
+ ([4, 444, 2531, 47, 17480, 7423, 8, 158, 10, 12, 5849, 3, 2],
+ [5, 61, 530, 137, 1494, 10, 9, 280, 6, 2, 3749, 4, 3]))
self._helper_test_func(len(valid_dataset), 1014, valid_dataset[30],
- ([3, 178, 25, 84, 1003, 56, 18, 153, 2],
- [4, 23, 31, 80, 46, 1347, 5, 2, 118, 3]))
+ ([4, 179, 26, 85, 1005, 57, 19, 154, 3, 2],
+ [5, 24, 32, 81, 47, 1348, 6, 2, 119, 4, 3]))
self._helper_test_func(len(test_dataset), 1000, test_dataset[40],
- ([3, 25, 5, 11, 3914, 1536, 20, 63, 2],
- [4, 31, 19, 2, 746, 344, 1914, 5, 45, 3]))
+ ([4, 26, 6, 12, 3915, 1538, 21, 64, 3, 2],
+ [5, 32, 20, 2, 747, 345, 1915, 6, 46, 4, 3]))

de_vocab, en_vocab = train_dataset.get_vocab()
de_tokens_ids = [
de_vocab[token] for token in
'Zwei Männer verpacken Donuts in Kunststofffolie'.split()
]
- self.assertEqual(de_tokens_ids, [19, 29, 18703, 4448, 5, 6240])
+ self.assertEqual(de_tokens_ids, [20, 30, 18705, 4448, 6, 6241])

en_tokens_ids = [
en_vocab[token] for token in
'Two young White males are outside near many bushes'.split()
]
self.assertEqual(en_tokens_ids,
- [17, 23, 1167, 806, 15, 55, 82, 334, 1337])
+ [18, 24, 1168, 807, 16, 56, 83, 335, 1338])

# Add test for the subset of the standard datasets
train_iter, valid_iter = torchtext.experimental.datasets.raw.Multi30k(data_select=('train', 'valid'))
self._helper_test_func(len(train_iter), 29000, ' '.join(next(iter(train_iter))),
- ' '.join(['Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.',
- 'Two young White males are outside near many bushes.']))
+ ' '.join(['Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.\n',
+ 'Two young, White males are outside near many bushes.\n']))
self._helper_test_func(len(valid_iter), 1014, ' '.join(next(iter(valid_iter))),
- ' '.join(['Eine Gruppe von Männern lädt Baumwolle auf einen Lastwagen',
- 'A group of men are loading cotton onto a truck']))
+ ' '.join(['Eine Gruppe von Männern lädt Baumwolle auf einen Lastwagen\n',
+ 'A group of men are loading cotton onto a truck\n']))
del train_iter, valid_iter
train_dataset, = Multi30k(data_select=('train'))
self._helper_test_func(len(train_dataset), 29000, train_dataset[20],
- ([3, 443, 2530, 46, 17478, 7422, 7, 157, 9, 11, 5848, 2],
- [4, 60, 529, 136, 1493, 9, 8, 279, 5, 2, 3748, 3]))
+ ([4, 444, 2531, 47, 17480, 7423, 8, 158, 10, 12, 5849, 3, 2],
+ [5, 61, 530, 137, 1494, 10, 9, 280, 6, 2, 3749, 4, 3]))

datafile = os.path.join(self.project_root, ".data", "train*")
conditional_remove(datafile)
Expand Down
10 changes: 4 additions & 6 deletions torchtext/experimental/datasets/raw/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import codecs
import xml.etree.ElementTree as ET
from collections import defaultdict
- from torchtext.utils import (download_from_url, extract_archive,
- unicode_csv_reader)
+ from torchtext.utils import (download_from_url, extract_archive)
from torchtext.experimental.datasets.raw.common import RawTextIterableDataset
from torchtext.experimental.datasets.raw.common import check_default_set

Expand Down Expand Up @@ -69,9 +68,8 @@

def _read_text_iterator(path):
with io.open(path, encoding="utf8") as f:
- reader = unicode_csv_reader(f)
- for row in reader:
- yield " ".join(row)
+ for row in f:
+ yield row


def _clean_xml_file(f_xml):
Expand Down Expand Up @@ -516,7 +514,7 @@ def WMT14(train_filenames=('train.tok.clean.bpe.32000.de',
}
NUM_LINES = {
'Multi30k': {'train': 29000, 'valid': 1014, 'test': 1000},
- 'IWSLT': {'train': 173939, 'valid': 823, 'test': 1096},
+ 'IWSLT': {'train': 196884, 'valid': 993, 'test': 1305},
'WMT14': {'train': 4500966, 'valid': 3000, 'test': 3003}
}
MD5 = {
Expand Down

0 comments on commit 633548a

Please sign in to comment.