Fix WT2, restructure splits, and add WMT14 (#138)
* add nonoptional dependencies to setup.py

* fix WT2 and restructure Dataset.splits and Dataset.download

* add WMT14 dataset
jekbradbury committed Oct 16, 2017
1 parent ee595c4 commit 247598d
Showing 9 changed files with 152 additions and 62 deletions.
12 changes: 7 additions & 5 deletions requirements.txt
@@ -1,14 +1,16 @@
# NLP tools
nltk
spacy
git+git://github.com/jekbradbury/revtok.git

# Progress bars on iterators
tqdm

# Downloading data and other files
requests

# Optional NLP tools
nltk
spacy
git+git://github.com/jekbradbury/revtok.git

# Required for tests only:

# Style-checking for PEP8
flake8

4 changes: 4 additions & 0 deletions setup.py
@@ -17,6 +17,10 @@
long_description=long_description,
license='BSD',

install_requires=[
'tqdm', 'requests'
],

# Package info
packages=find_packages(exclude=('test',)),

30 changes: 20 additions & 10 deletions torchtext/data/dataset.py
@@ -4,9 +4,9 @@
import tarfile

import torch.utils.data
from six.moves import urllib

from .example import Example
from ..utils import download_from_url


class Dataset(torch.utils.data.Dataset):
@@ -42,11 +42,14 @@ def __init__(self, examples, fields, filter_pred=None):
self.fields = dict(fields)

@classmethod
def splits(cls, path, train=None, validation=None, test=None, **kwargs):
def splits(cls, path=None, root='.data', train=None, validation=None,
test=None, **kwargs):
"""Create Dataset objects for multiple splits of a dataset.
Arguments:
path (str): Common prefix of the splits' file paths.
path (str): Common prefix of the splits' file paths, or None to use
the result of cls.download(root).
root (str): Root dataset storage directory. Default is '.data'.
train (str): Suffix to add to path for the train set, or None for no
train set. Default is None.
validation (str): Suffix to add to path for the validation set, or None
@@ -60,10 +63,14 @@ def splits(cls, path, train=None, validation=None, test=None, **kwargs):
split_datasets (tuple(Dataset)): Datasets for train, validation, and
test splits in that order, if provided.
"""
train_data = None if train is None else cls(path + train, **kwargs)
val_data = None if validation is None else cls(path + validation,
**kwargs)
test_data = None if test is None else cls(path + test, **kwargs)
if path is None:
path = cls.download(root)
train_data = None if train is None else cls(
os.path.join(path, train), **kwargs)
val_data = None if validation is None else cls(
os.path.join(path, validation), **kwargs)
test_data = None if test is None else cls(
os.path.join(path, test), **kwargs)
return tuple(d for d in (train_data, val_data, test_data)
if d is not None)

@@ -93,7 +100,7 @@ def download(cls, root, check=None):
root (str): Folder to download data to.
check (str or None): Folder whose existence indicates
that the dataset has already been downloaded, or
None to check the existence of root.
None to check the existence of root/{cls.name}.
Returns:
dataset_path (str): Path to extracted dataset.
@@ -102,13 +109,16 @@
check = path if check is None else check
if not os.path.isdir(check):
for url in cls.urls:
filename = os.path.basename(url)
if isinstance(url, tuple):
url, filename = url
else:
filename = os.path.basename(url)
zpath = os.path.join(path, filename)
if not os.path.isfile(zpath):
if not os.path.exists(os.path.dirname(zpath)):
os.makedirs(os.path.dirname(zpath))
print('downloading {}'.format(filename))
urllib.request.urlretrieve(url, zpath)
download_from_url(url, zpath)
ext = os.path.splitext(filename)[-1]
if ext == '.zip':
with zipfile.ZipFile(zpath, 'r') as zfile:
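For illustration, a dataset subclass built on the restructured plumbing might look like the sketch below: splits() called without a path falls back to cls.download(root), and entries in cls.urls can be plain URL strings or (url, filename) tuples for links (e.g. Google Drive) whose URLs do not end in a usable filename. The class name, URLs, and field setup are placeholders, not part of this commit.

    from torchtext import data

    class MyCorpus(data.Dataset):
        """Hypothetical one-example-per-line corpus."""
        urls = ['http://example.com/mycorpus/data.zip',
                # (url, filename) tuple form for links without a filename:
                ('https://drive.google.com/uc?export=download&id=FAKE_ID',
                 'mycorpus.tar.gz')]
        name = 'mycorpus'
        dirname = ''

        def __init__(self, path, text_field, **kwargs):
            fields = [('text', text_field)]
            with open(path, encoding='utf-8') as f:
                examples = [data.Example.fromlist([line.strip()], fields)
                            for line in f]
            super(MyCorpus, self).__init__(examples, fields, **kwargs)

    TEXT = data.Field()
    # With path=None (the default), splits() first calls cls.download(root),
    # which downloads and extracts into root/mycorpus, then joins each split
    # name onto the returned directory:
    # train, valid = MyCorpus.splits(root='.data', train='train.txt',
    #                                validation='valid.txt', text_field=TEXT)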
5 changes: 3 additions & 2 deletions torchtext/datasets/__init__.py
@@ -1,7 +1,7 @@
from .language_modeling import LanguageModelingDataset, WikiText2
from .language_modeling import LanguageModelingDataset, WikiText2 # NOQA
from .snli import SNLI
from .sst import SST
from .translation import TranslationDataset, Multi30k, IWSLT
from .translation import TranslationDataset, Multi30k, IWSLT, WMT14 # NOQA
from .trec import TREC
from .imdb import IMDB

@@ -12,6 +12,7 @@
'TranslationDataset',
'Multi30k',
'IWSLT',
'WMT14',
'WikiText2',
'TREC',
'IMDB']
13 changes: 4 additions & 9 deletions torchtext/datasets/imdb.py
@@ -43,20 +43,15 @@ def splits(cls, text_field, label_field, root='.data',
Arguments:
text_field: The field that will be used for the sentence.
label_field: The field that will be used for label data.
root: The root directory that contains the IMDB dataset subdirectory
root: Root dataset storage directory. Default is '.data'.
train: The directory that contains the training examples
test: The directory that contains the test examples
Remaining keyword arguments: Passed to the splits method of
Dataset.
"""
path = cls.download(root)

train_data = None if train is None else cls(
os.path.join(path, train), text_field, label_field, **kwargs)
test_data = None if test is None else cls(
os.path.join(path, test), text_field, label_field, **kwargs)
return tuple(d for d in (train_data, test_data)
if d is not None)
return super(IMDB, cls).splits(
root=root, text_field=text_field, label_field=label_field,
train=train, validation=None, test=test, **kwargs)

@classmethod
def iters(cls, batch_size=32, device=0, root='.data', vectors=None, **kwargs):
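Assuming the usual Field setup, IMDB loading after this refactor could be exercised as follows; the field parameters are illustrative, and the call itself is commented out because it downloads the dataset.

    from torchtext import data
    from torchtext.datasets import IMDB

    TEXT = data.Field(lower=True)
    LABEL = data.Field(sequential=False)

    # Downloads and extracts under .data/imdb on first use, then builds the
    # train and test splits through the shared Dataset.splits machinery:
    # train, test = IMDB.splits(TEXT, LABEL, root='.data')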
11 changes: 5 additions & 6 deletions torchtext/datasets/language_modeling.py
@@ -32,7 +32,7 @@ class WikiText2(LanguageModelingDataset):

urls = ['https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip']
name = 'wikitext-2'
dirname = ''
dirname = 'wikitext-2'

@classmethod
def splits(cls, text_field, root='.data', train='wiki.train.tokens',
@@ -46,15 +46,14 @@ def splits(cls, text_field, root='.data', train='wiki.train.tokens',
root: The root directory that the dataset's zip archive will be
expanded into; therefore the directory in whose wikitext-2
subdirectory the data files will be stored.
train: The filename of the train data. Default: 'train.tokens'.
train: The filename of the train data. Default: 'wiki.train.tokens'.
validation: The filename of the validation data, or None to not
load the validation set. Default: 'valid.tokens'.
load the validation set. Default: 'wiki.valid.tokens'.
test: The filename of the test data, or None to not load the test
set. Default: 'test.tokens'.
set. Default: 'wiki.test.tokens'.
"""
path = cls.download(root)
return super(WikiText2, cls).splits(
path, train, validation, test,
root=root, train=train, validation=validation, test=test,
text_field=text_field)

@classmethod
100 changes: 79 additions & 21 deletions torchtext/datasets/translation.py
@@ -39,9 +39,37 @@ def __init__(self, path, exts, fields, **kwargs):

super(TranslationDataset, self).__init__(examples, fields, **kwargs)

@classmethod
def splits(cls, exts, fields, root='.data',
train='train', validation='val', test='test', **kwargs):
"""Create dataset objects for splits of a TranslationDataset.
Arguments:
root: Root dataset storage directory. Default is '.data'.
exts: A tuple containing the extension to path for each language.
fields: A tuple containing the fields that will be used for data
in each language.
train: The prefix of the train data. Default: 'train'.
validation: The prefix of the validation data. Default: 'val'.
test: The prefix of the test data. Default: 'test'.
Remaining keyword arguments: Passed to the splits method of
Dataset.
"""
path = cls.download(root)

train_data = None if train is None else cls(
os.path.join(path, train), exts, fields, **kwargs)
val_data = None if validation is None else cls(
os.path.join(path, validation), exts, fields, **kwargs)
test_data = None if test is None else cls(
os.path.join(path, test), exts, fields, **kwargs)
return tuple(d for d in (train_data, val_data, test_data)
if d is not None)


class Multi30k(TranslationDataset, data.Dataset):
"""Defines a dataset for the multi-modal WMT 2016 task"""
class Multi30k(TranslationDataset):
"""The small-dataset WMT 2016 multimodal task, also known as Flickr30k"""

urls = ['http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz',
'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz',
@@ -51,12 +79,12 @@ class Multi30k(TranslationDataset, data.Dataset):

@classmethod
def splits(cls, exts, fields, root='.data',
train='train', val='val', test='test', **kwargs):
train='train', validation='val', test='test', **kwargs):
"""Create dataset objects for splits of the Multi30k dataset.
Arguments:
root: directory containing Multi30k data
root: Root dataset storage directory. Default is '.data'.
exts: A tuple containing the extension to path for each language.
fields: A tuple containing the fields that will be used for data
in each language.
@@ -66,34 +94,26 @@ def splits(cls, exts, fields, root='.data',
Remaining keyword arguments: Passed to the splits method of
Dataset.
"""
path = cls.download(root)
return super(Multi30k, cls).splits(
exts, fields, root, train, validation, test, **kwargs)

train_data = None if train is None else cls(
os.path.join(path, train), exts, fields, **kwargs)
val_data = None if val is None else cls(
os.path.join(path, val), exts, fields, **kwargs)
test_data = None if test is None else cls(
os.path.join(path, test), exts, fields, **kwargs)
return tuple(d for d in (train_data, val_data, test_data)
if d is not None)


class IWSLT(TranslationDataset, data.Dataset):
"""Defines a dataset for the IWSLT 2016 task"""
class IWSLT(TranslationDataset):
"""The IWSLT 2016 TED talk translation task"""

base_url = 'https://wit3.fbk.eu/archive/2016-01//texts/{}/{}/{}.tgz'
name = 'iwslt'
base_dirname = '{}-{}'

@classmethod
def splits(cls, exts, fields, root='.data',
train='train', val='IWSLT16.TED.tst2013',
train='train', validation='IWSLT16.TED.tst2013',
test='IWSLT16.TED.tst2014', **kwargs):
"""Create dataset objects for splits of the IWSLT dataset.
Arguments:
root: directory containing Multi30k data
root: Root dataset storage directory. Default is '.data'.
exts: A tuple containing the extension to path for each language.
fields: A tuple containing the fields that will be used for data
in each language.
@@ -109,7 +129,7 @@ def splits(cls, exts, fields, root='.data',
path = cls.download(root, check=check)

train = '.'.join([train, cls.dirname])
val = '.'.join([val, cls.dirname])
validation = '.'.join([validation, cls.dirname])
if test is not None:
test = '.'.join([test, cls.dirname])

@@ -118,8 +138,8 @@

train_data = None if train is None else cls(
os.path.join(path, train), exts, fields, **kwargs)
val_data = None if val is None else cls(
os.path.join(path, val), exts, fields, **kwargs)
val_data = None if validation is None else cls(
os.path.join(path, validation), exts, fields, **kwargs)
test_data = None if test is None else cls(
os.path.join(path, test), exts, fields, **kwargs)
return tuple(d for d in (train_data, val_data, test_data)
@@ -146,3 +166,41 @@ def clean(path):
for l in fd_orig:
if not any(tag in l for tag in xml_tags):
fd_txt.write(l.strip() + '\n')


class WMT14(TranslationDataset):
"""The WMT 2014 English-German dataset, as preprocessed by Google Brain.
Though this download contains test sets from 2015 and 2016, the train set
differs slightly from WMT 2015 and 2016 and significantly from WMT 2017."""

urls = [('https://drive.google.com/uc?export=download&'
'id=0B_bZck-ksdkpM25jRUN2X2UxMm8', 'wmt16_en_de.tar.gz')]

Inline comment from @IdiosyncraticDragon (Nov 4, 2017):
The commit message said the new dataset is WMT14, but the name here is "wmt16_en_de.tar.gz". Is there something wrong? @jekbradbury

Reply from @jekbradbury (commit author, Nov 4, 2017):
This is a tgz produced by Google Brain that contains the WMT14 train data and the test sets through 2016 (the train data remained similar between those years but not identical). This allows direct comparison with all Google MT papers, which use this set and preprocessing setup.

Reply from @IdiosyncraticDragon (Nov 6, 2017):
Ok, I see. Thank you for answering.

name = 'wmt14'
dirname = ''

@classmethod
def splits(cls, exts, fields, root='.data',
train='train.tok.clean.bpe.32000',
validation='newstest2013.tok.bpe.32000',
test='newstest2014.tok.bpe.32000', **kwargs):
"""Create dataset objects for splits of the WMT 2014 dataset.
Arguments:
root: Root dataset storage directory. Default is '.data'.
exts: A tuple containing the extensions for each language. Must be
either ('.en', '.de') or the reverse.
fields: A tuple containing the fields that will be used for data
in each language.
train: The prefix of the train data. Default:
'train.tok.clean.bpe.32000'.
validation: The prefix of the validation data. Default:
'newstest2013.tok.bpe.32000'.
test: The prefix of the test data. Default:
'newstest2014.tok.bpe.32000'.
Remaining keyword arguments: Passed to the splits method of
Dataset.
"""
return super(WMT14, cls).splits(
exts, fields, root, train, validation, test, **kwargs)
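A rough usage sketch for the new WMT14 loader follows; the Field configuration is illustrative, and the splits call is left commented out because it triggers a multi-gigabyte Google Drive download.

    from torchtext import data
    from torchtext.datasets import WMT14

    # The data is already BPE-tokenized, so the default whitespace Field works;
    # exts selects the translation direction and must pair '.en' with '.de'.
    SRC = data.Field()
    TRG = data.Field(init_token='<sos>', eos_token='<eos>')

    # train, val, test = WMT14.splits(exts=('.en', '.de'), fields=(SRC, TRG))
    # SRC.build_vocab(train, max_size=50000)
    # TRG.build_vocab(train, max_size=50000)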
13 changes: 4 additions & 9 deletions torchtext/datasets/trec.py
@@ -49,21 +49,16 @@ def splits(cls, text_field, label_field, root='.data',
Arguments:
text_field: The field that will be used for the sentence.
label_field: The field that will be used for label data.
root: The root directory that contains the trec dataset subdirectory
root: Root dataset storage directory. Default is '.data'.
train: The filename of the train data. Default: 'train_5500.label'.
test: The filename of the test data, or None to not load the test
set. Default: 'TREC_10.label'.
Remaining keyword arguments: Passed to the splits method of
Dataset.
"""
path = cls.download(root)

train_data = None if train is None else cls(
os.path.join(path, test), text_field, label_field, **kwargs)
test_data = None if test is None else cls(
os.path.join(path, test), text_field, label_field, **kwargs)
return tuple(d for d in (train_data, test_data)
if d is not None)
return super(TREC, cls).splits(
root=root, text_field=text_field, label_field=label_field,
train=train, validation=None, test=test, **kwargs)

@classmethod
def iters(cls, batch_size=32, device=0, root='.data', vectors=None, **kwargs):
26 changes: 26 additions & 0 deletions torchtext/utils.py
@@ -1,3 +1,6 @@
from six.moves import urllib
import requests


def reporthook(t):
"""https://github.com/tqdm/tqdm"""
@@ -17,3 +20,26 @@ def inner(b=1, bsize=1, tsize=None):
t.update((b - last_b[0]) * bsize)
last_b[0] = b
return inner


def download_from_url(url, path):
"""Download file, with logic (from tensor2tensor) for Google Drive"""
if 'drive.google.com' not in url:
return urllib.request.urlretrieve(url, path)
print('downloading from Google Drive; may take a few minutes')
confirm_token = None
session = requests.Session()
response = session.get(url, stream=True)
for k, v in response.cookies.items():
if k.startswith("download_warning"):
confirm_token = v

if confirm_token:
url = url + "&confirm=" + confirm_token
response = session.get(url, stream=True)

chunk_size = 16 * 1024
with open(path, "wb") as f:
for chunk in response.iter_content(chunk_size):
if chunk:
f.write(chunk)
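As a rough illustration (the URL and paths here are placeholders), the helper can be called directly for either a plain HTTP file or a Google Drive link:

    from torchtext.utils import download_from_url

    # Ordinary URL: falls through to urllib.request.urlretrieve.
    # download_from_url('http://example.com/files/data.zip', '/tmp/data.zip')

    # Google Drive URL: streamed with requests; if Drive answers with a
    # "can't scan for viruses" warning page, the confirm cookie is appended
    # to the URL and the request is retried before writing to disk.
    # download_from_url('https://drive.google.com/uc?export=download&id=FAKE_ID',
    #                   '/tmp/data.tar.gz')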
