-
Notifications
You must be signed in to change notification settings - Fork 814
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix WT2, restructure splits, and add WMT14 (#138)
* add non-optional dependencies to setup.py
* fix WT2 and restructure Dataset.splits and Dataset.download
* add WMT14 dataset
- Loading branch information
1 parent
ee595c4
commit 247598d
Showing
9 changed files
with
152 additions
and
62 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,9 +39,37 @@ def __init__(self, path, exts, fields, **kwargs): | |
|
||
super(TranslationDataset, self).__init__(examples, fields, **kwargs) | ||
|
||
@classmethod
def splits(cls, exts, fields, root='.data',
           train='train', validation='val', test='test', **kwargs):
    """Create dataset objects for splits of a TranslationDataset.

    Arguments:
        exts: A tuple containing the extension to path for each language.
        fields: A tuple containing the fields that will be used for data
            in each language.
        root: Root dataset storage directory. Default is '.data'.
        train: The prefix of the train data. Default: 'train'.
        validation: The prefix of the validation data. Default: 'val'.
        test: The prefix of the test data. Default: 'test'.
        Remaining keyword arguments: Passed to the splits method of
            Dataset.
    """
    # Download (or reuse a cached copy of) the dataset; `path` is the
    # directory the archive was extracted into.
    path = cls.download(root)

    def _make_split(prefix):
        # A split whose prefix is None is skipped entirely.
        if prefix is None:
            return None
        return cls(os.path.join(path, prefix), exts, fields, **kwargs)

    # Build in train/validation/test order and drop the splits that were
    # disabled by passing None.
    splits = (_make_split(train), _make_split(validation),
              _make_split(test))
    return tuple(split for split in splits if split is not None)
|
||
|
||
class Multi30k(TranslationDataset, data.Dataset): | ||
"""Defines a dataset for the multi-modal WMT 2016 task""" | ||
class Multi30k(TranslationDataset): | ||
"""The small-dataset WMT 2016 multimodal task, also known as Flickr30k""" | ||
|
||
urls = ['http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz', | ||
'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz', | ||
|
@@ -51,12 +79,12 @@ class Multi30k(TranslationDataset, data.Dataset): | |
|
||
@classmethod | ||
def splits(cls, exts, fields, root='.data', | ||
train='train', val='val', test='test', **kwargs): | ||
train='train', validation='val', test='test', **kwargs): | ||
"""Create dataset objects for splits of the Multi30k dataset. | ||
Arguments: | ||
root: directory containing Multi30k data | ||
root: Root dataset storage directory. Default is '.data'. | ||
exts: A tuple containing the extension to path for each language. | ||
fields: A tuple containing the fields that will be used for data | ||
in each language. | ||
|
@@ -66,34 +94,26 @@ def splits(cls, exts, fields, root='.data', | |
Remaining keyword arguments: Passed to the splits method of | ||
Dataset. | ||
""" | ||
path = cls.download(root) | ||
return super(Multi30k, cls).splits( | ||
exts, fields, root, train, validation, test, **kwargs) | ||
|
||
train_data = None if train is None else cls( | ||
os.path.join(path, train), exts, fields, **kwargs) | ||
val_data = None if val is None else cls( | ||
os.path.join(path, val), exts, fields, **kwargs) | ||
test_data = None if test is None else cls( | ||
os.path.join(path, test), exts, fields, **kwargs) | ||
return tuple(d for d in (train_data, val_data, test_data) | ||
if d is not None) | ||
|
||
|
||
class IWSLT(TranslationDataset, data.Dataset): | ||
"""Defines a dataset for the IWSLT 2016 task""" | ||
class IWSLT(TranslationDataset): | ||
"""The IWSLT 2016 TED talk translation task""" | ||
|
||
base_url = 'https://wit3.fbk.eu/archive/2016-01//texts/{}/{}/{}.tgz' | ||
name = 'iwslt' | ||
base_dirname = '{}-{}' | ||
|
||
@classmethod | ||
def splits(cls, exts, fields, root='.data', | ||
train='train', val='IWSLT16.TED.tst2013', | ||
train='train', validation='IWSLT16.TED.tst2013', | ||
test='IWSLT16.TED.tst2014', **kwargs): | ||
"""Create dataset objects for splits of the IWSLT dataset. | ||
Arguments: | ||
root: directory containing Multi30k data | ||
root: Root dataset storage directory. Default is '.data'. | ||
exts: A tuple containing the extension to path for each language. | ||
fields: A tuple containing the fields that will be used for data | ||
in each language. | ||
|
@@ -109,7 +129,7 @@ def splits(cls, exts, fields, root='.data', | |
path = cls.download(root, check=check) | ||
|
||
train = '.'.join([train, cls.dirname]) | ||
val = '.'.join([val, cls.dirname]) | ||
validation = '.'.join([validation, cls.dirname]) | ||
if test is not None: | ||
test = '.'.join([test, cls.dirname]) | ||
|
||
|
@@ -118,8 +138,8 @@ def splits(cls, exts, fields, root='.data', | |
|
||
train_data = None if train is None else cls( | ||
os.path.join(path, train), exts, fields, **kwargs) | ||
val_data = None if val is None else cls( | ||
os.path.join(path, val), exts, fields, **kwargs) | ||
val_data = None if validation is None else cls( | ||
os.path.join(path, validation), exts, fields, **kwargs) | ||
test_data = None if test is None else cls( | ||
os.path.join(path, test), exts, fields, **kwargs) | ||
return tuple(d for d in (train_data, val_data, test_data) | ||
|
@@ -146,3 +166,41 @@ def clean(path): | |
for l in fd_orig: | ||
if not any(tag in l for tag in xml_tags): | ||
fd_txt.write(l.strip() + '\n') | ||
|
||
|
||
class WMT14(TranslationDataset):
    """The WMT 2014 English-German dataset, as preprocessed by Google Brain.

    Though this download contains test sets from 2015 and 2016, the train set
    differs slightly from WMT 2015 and 2016 and significantly from WMT 2017.
    """

    # NOTE(review): the archive is named 'wmt16_en_de.tar.gz' even though the
    # dataset is labeled WMT14 -- presumably the filename comes from the
    # Google Brain preprocessing run that produced it; confirm the contents
    # actually match the WMT14 description above.
    urls = [('https://drive.google.com/uc?export=download&'
             'id=0B_bZck-ksdkpM25jRUN2X2UxMm8', 'wmt16_en_de.tar.gz')]
    name = 'wmt14'
    dirname = ''

    @classmethod
    def splits(cls, exts, fields, root='.data',
               train='train.tok.clean.bpe.32000',
               validation='newstest2013.tok.bpe.32000',
               test='newstest2014.tok.bpe.32000', **kwargs):
        """Create dataset objects for splits of the WMT 2014 dataset.

        Arguments:
            exts: A tuple containing the extensions for each language. Must be
                either ('.en', '.de') or the reverse.
            fields: A tuple containing the fields that will be used for data
                in each language.
            root: Root dataset storage directory. Default is '.data'.
            train: The prefix of the train data. Default:
                'train.tok.clean.bpe.32000'.
            validation: The prefix of the validation data. Default:
                'newstest2013.tok.bpe.32000'.
            test: The prefix of the test data. Default:
                'newstest2014.tok.bpe.32000'.
            Remaining keyword arguments: Passed to the splits method of
                Dataset.
        """
        # All of the download/construction logic lives in the base class;
        # this subclass only supplies the WMT14-specific split prefixes.
        return super(WMT14, cls).splits(
            exts, fields, root, train, validation, test, **kwargs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
The commit message says the new dataset is WMT14, but the archive here is named "wmt16_en_de.tar.gz". Is that intentional, or is something wrong? @jekbradbury