Multi30k and Dataset download refactor (#116)
* adding Multi30k wrapper

* abstracting tar.gz compression

* removing cls.filename

* refactor datasets for downloading

* bug in sst tree examples
bmccann authored and jekbradbury committed Sep 14, 2017
1 parent 4f22159 commit 6f930eb
Showing 11 changed files with 137 additions and 102 deletions.
34 changes: 30 additions & 4 deletions test/translation.py
@@ -18,12 +18,38 @@ def tokenize_en(text):
return [tok.text for tok in spacy_en.tokenizer(url.sub('@URL@', text))]


DE = data.Field(tokenize=tokenize_de)
EN = data.Field(tokenize=tokenize_en)

train, val, test = datasets.Multi30k.splits(exts=('.de', '.en'), fields=(DE, EN))

print(train.fields)
print(len(train))
print(vars(train[0]))
print(vars(train[100]))

DE.build_vocab(train.src, min_freq=3)
EN.build_vocab(train.trg, max_size=50000)

train_iter, val_iter = data.BucketIterator.splits(
(train, val), batch_size=3, device=0)

print(DE.vocab.freqs.most_common(10))
print(len(DE.vocab))
print(EN.vocab.freqs.most_common(10))
print(len(EN.vocab))

batch = next(iter(train_iter))
print(batch.src)
print(batch.trg)


DE = data.Field(tokenize=tokenize_de)
EN = data.Field(tokenize=tokenize_en)

train, val = datasets.TranslationDataset.splits(
path='~/iwslt2016/de-en/', train='train.tags.de-en',
validation='IWSLT16.TED.tst2013.de-en', exts=('.de', '.en'),
path='.data/multi30k/', train='train',
validation='val', exts=('.de', '.en'),
fields=(DE, EN))

print(train.fields)
@@ -38,9 +64,9 @@ def tokenize_en(text):
(train, val), batch_size=3, device=0)

print(DE.vocab.freqs.most_common(10))
print(DE.vocab.size)
print(len(DE.vocab))
print(EN.vocab.freqs.most_common(10))
print(EN.vocab.size)
print(len(EN.vocab))

batch = next(iter(train_iter))
print(batch.src)
2 changes: 1 addition & 1 deletion torchtext/data/__init__.py
@@ -1,5 +1,5 @@
from .batch import Batch
from .dataset import Dataset, TabularDataset, ZipDataset
from .dataset import Dataset, TabularDataset
from .example import Example
from .field import Field
from .iterator import (batch, BucketIterator, Iterator, BPTTIterator,
48 changes: 24 additions & 24 deletions torchtext/data/dataset.py
@@ -1,6 +1,7 @@
import io
import os
import zipfile
import tarfile

import torch.utils.data
from six.moves import urllib
@@ -78,6 +79,29 @@ def __getattr__(self, attr):
for x in self.examples:
yield getattr(x, attr)

@classmethod
def download(cls, root):
path = os.path.join(root, cls.name)
if not os.path.isdir(path):
for url in cls.urls:
filename = os.path.basename(url)
zpath = os.path.join(path, filename)
if not os.path.isfile(zpath):
if not os.path.exists(os.path.dirname(zpath)):
os.makedirs(os.path.dirname(zpath))
print('downloading {}'.format(filename))
urllib.request.urlretrieve(url, zpath)
ext = os.path.splitext(filename)[-1]
if ext == '.zip':
with zipfile.ZipFile(zpath, 'r') as zfile:
print('extracting')
zfile.extractall(path)
elif ext in ['.gz', '.tgz']:
with tarfile.open(zpath, 'r:gz') as tar:
dirs = [member for member in tar.getmembers()]
tar.extractall(path=path, members=dirs)
return os.path.join(path, cls.dirname)


class TabularDataset(Dataset):
"""Defines a Dataset of columns stored in CSV, TSV, or JSON format."""
@@ -113,27 +137,3 @@ def __init__(self, path, format, fields, **kwargs):
fields.append(field)

super(TabularDataset, self).__init__(examples, fields, **kwargs)


class ZipDataset(Dataset):
"""Defines a Dataset loaded from a downloadable zip archive.
Attributes:
url: URL where the zip archive can be downloaded.
filename: Filename of the downloaded zip archive.
dirname: Name of the top-level directory within the zip archive that
contains the data files.
"""

@classmethod
def download_or_unzip(cls, root):
path = os.path.join(root, cls.dirname)
if not os.path.isdir(path):
zpath = os.path.join(root, cls.filename)
if not os.path.isfile(zpath):
print('downloading')
urllib.request.urlretrieve(cls.url, zpath)
with zipfile.ZipFile(zpath, 'r') as zfile:
print('extracting')
zfile.extractall(root)
return os.path.join(path, '')
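
Note: the new `download` classmethod above replaces the per-dataset `download`/`download_or_unzip` helpers. A subclass now declares `urls`, `name`, and `dirname`; the base class fetches each archive into `root/name`, unpacks `.zip` with `zipfile` and `.gz`/`.tgz` with `tarfile`, and returns `os.path.join(root, name, dirname)`. The sketch below shows how a hypothetical subclass would plug into this; the class name, URL, and file layout are illustrative and not part of the commit.

```python
import os

from torchtext import data


class TinyCorpus(data.Dataset):
    # Class attributes consumed by the shared Dataset.download(root):
    # every URL in `urls` is fetched into root/name and unpacked there.
    urls = ['http://example.com/tiny_corpus.tar.gz']  # illustrative URL
    name = 'tiny_corpus'   # subdirectory created under `root`
    dirname = ''           # top-level directory inside the archive, if any

    @classmethod
    def splits(cls, text_field, root='.data', train='train.txt', **kwargs):
        # download() is a no-op if root/name already exists; it returns
        # os.path.join(root, cls.name, cls.dirname).
        path = cls.download(root)
        fields = [('text', text_field)]
        with open(os.path.join(path, train)) as f:
            examples = [data.Example.fromlist([line], fields) for line in f]
        return cls(examples, fields, **kwargs)
```
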
4 changes: 2 additions & 2 deletions torchtext/data/example.py
@@ -71,5 +71,5 @@ def fromtree(cls, data, fields, subtrees=False):
tree = Tree.fromstring(data)
if subtrees:
return [cls.fromlist(
[t.leaves(), t.label()], fields) for t in tree.subtrees()]
return cls.fromlist([tree.leaves(), tree.label()], fields)
[' '.join(t.leaves()), t.label()], fields) for t in tree.subtrees()]
return cls.fromlist([' '.join(tree.leaves()), tree.label()], fields)
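
Note: this is the "bug in sst tree examples" fix from the commit message. `fromlist` passes each column through the field's `preprocess`, so the tree leaves are now joined back into a single string and the text field sees raw text to tokenize rather than a pre-split token list. A small illustration of the difference, assuming `nltk` is installed (which `Example.fromtree` already requires):

```python
from nltk.tree import Tree

# An SST-style parse with a sentiment label on each node.
tree = Tree.fromstring('(3 (2 An) (4 (3 effective) (2 film)))')

print(tree.leaves())            # ['An', 'effective', 'film']  (old: a pre-split list)
print(' '.join(tree.leaves()))  # 'An effective film'          (new: a single string)
print(tree.label())             # '3'
```
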
3 changes: 2 additions & 1 deletion torchtext/datasets/__init__.py
@@ -1,7 +1,7 @@
from .language_modeling import LanguageModelingDataset, WikiText2
from .snli import SNLI
from .sst import SST
from .translation import TranslationDataset
from .translation import TranslationDataset, Multi30k
from .trec import TREC
from .imdb import IMDB

@@ -10,6 +10,7 @@
'SNLI',
'SST',
'TranslationDataset',
'Multi30k',
'WikiText2',
'TREC',
'IMDB']
31 changes: 7 additions & 24 deletions torchtext/datasets/imdb.py
@@ -1,15 +1,13 @@
import os
import tarfile
from six.moves import urllib
import glob

from .. import data


class IMDB(data.Dataset):

url = 'http://ai.stanford.edu/~amaas/data/sentiment/'
filename = 'aclImdb_v1.tar.gz'
urls = ['http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz']
name = 'imdb'
dirname = 'aclImdb'

@staticmethod
@@ -20,7 +18,7 @@ def __init__(self, path, text_field, label_field, **kwargs):
"""Create an IMDB dataset instance given a path and fields.
Arguments:
path: Path to the datasets highest level directory
path: Path to the dataset's highest level directory
text_field: The field that will be used for text data.
label_field: The field that will be used for label data.
Remaining keyword arguments: Passed to the constructor of
@@ -38,22 +36,7 @@ def __init__(self, path, text_field, label_field, **kwargs):
super(IMDB, self).__init__(examples, fields, **kwargs)

@classmethod
def download(cls, root):
path = os.path.join(root, cls.dirname)
if not os.path.isdir(path):
fpath = os.path.join(path, cls.filename)
if not os.path.isfile(fpath):
if not os.path.exists(os.path.dirname(fpath)):
os.makedirs(os.path.dirname(fpath))
print('downloading {}'.format(cls.filename))
urllib.request.urlretrieve(os.path.join(cls.url, cls.filename), fpath)
with tarfile.open(fpath, 'r:gz') as tar:
dirs = [member for member in tar.getmembers()]
tar.extractall(path=root, members=dirs)
return os.path.join(path, '')

@classmethod
def splits(cls, text_field, label_field, root='.',
def splits(cls, text_field, label_field, root='.data',
train='train', test='test', **kwargs):
"""Create dataset objects for splits of the IMDB dataset.
@@ -69,14 +52,14 @@ def splits(cls, text_field, label_field, root='.',
path = cls.download(root)

train_data = None if train is None else cls(
path + train, text_field, label_field, **kwargs)
os.path.join(path, train), text_field, label_field, **kwargs)
test_data = None if test is None else cls(
path + test, text_field, label_field, **kwargs)
os.path.join(path, test), text_field, label_field, **kwargs)
return tuple(d for d in (train_data, test_data)
if d is not None)

@classmethod
def iters(cls, batch_size=32, device=0, root='.', vectors=None, **kwargs):
def iters(cls, batch_size=32, device=0, root='.data', vectors=None, **kwargs):
"""Creater iterator objects for splits of the IMDB dataset.
Arguments:
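
Note: with its per-class `download` removed, IMDB now relies on the shared machinery via the `urls`/`name`/`dirname` attributes above and the new `root='.data'` default, so the archive is unpacked under `./.data/imdb/aclImdb`. A hedged usage sketch in the style of the test script at the top of this commit; the field settings and vocab size are illustrative, not part of the change.

```python
from torchtext import data, datasets

TEXT = data.Field()
LABEL = data.Field(sequential=False)

# Downloads and extracts aclImdb_v1.tar.gz into ./.data/imdb on first use.
train, test = datasets.IMDB.splits(TEXT, LABEL)

TEXT.build_vocab(train.text, max_size=25000)
LABEL.build_vocab(train.label)

train_iter, test_iter = data.BucketIterator.splits(
    (train, test), batch_size=3, device=0)
```
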
14 changes: 7 additions & 7 deletions torchtext/datasets/language_modeling.py
@@ -28,14 +28,14 @@ def __init__(self, path, text_field, newline_eos=True, **kwargs):
examples, fields, **kwargs)


class WikiText2(LanguageModelingDataset, data.ZipDataset):
class WikiText2(LanguageModelingDataset):

url = 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip'
filename = 'wikitext-2-v1.zip'
dirname = 'wikitext-2'
urls = ['https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip']
name = 'wikitext-2'
dirname = ''

@classmethod
def splits(cls, text_field, root='.', train='wiki.train.tokens',
def splits(cls, text_field, root='.data', train='wiki.train.tokens',
validation='wiki.valid.tokens', test='wiki.test.tokens'):
"""Create dataset objects for splits of the WikiText-2 dataset.
@@ -52,13 +52,13 @@ def splits(cls, text_field, root='.', train='wiki.train.tokens',
test: The filename of the test data, or None to not load the test
set. Default: 'test.tokens'.
"""
path = cls.download_or_unzip(root)
path = cls.download(root)
return super(WikiText2, cls).splits(
path, train, validation, test,
text_field=text_field)

@classmethod
def iters(cls, batch_size=32, bptt_len=35, device=0, root='.', wv_dir='.',
def iters(cls, batch_size=32, bptt_len=35, device=0, root='.data', wv_dir='.',
wv_type=None, wv_dim='300d', **kwargs):
"""Create iterator objects for splits of the WikiText-2 dataset.
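
Note: WikiText2 likewise drops `ZipDataset` and `download_or_unzip` for the shared `download`, declaring its own `urls`, `name`, and `dirname`. A hedged sketch of the splits-plus-BPTT workflow under the new `.data` default, assuming the archive unpacks so that `wiki.train.tokens` and friends end up under the path `download` returns; the batch size and BPTT length are illustrative.

```python
from torchtext import data, datasets

TEXT = data.Field(lower=True)

# Downloads wikitext-2-v1.zip on first use (root defaults to '.data').
train, valid, test = datasets.WikiText2.splits(TEXT)

TEXT.build_vocab(train)

train_iter, valid_iter, test_iter = data.BPTTIterator.splits(
    (train, valid, test), batch_size=32, bptt_len=35, device=0)
```
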
13 changes: 7 additions & 6 deletions torchtext/datasets/snli.py
@@ -24,19 +24,19 @@ def __init__(self, eos_token='<pad>', lower=False):
list(reversed(p)) for p in parse])


class SNLI(data.ZipDataset, data.TabularDataset):
class SNLI(data.TabularDataset):

url = 'http://nlp.stanford.edu/projects/snli/snli_1.0.zip'
filename = 'snli_1.0.zip'
urls = ['http://nlp.stanford.edu/projects/snli/snli_1.0.zip']
dirname = 'snli_1.0'
name = 'snli'

@staticmethod
def sort_key(ex):
return data.interleave_keys(
len(ex.premise), len(ex.hypothesis))

@classmethod
def splits(cls, text_field, label_field, parse_field=None, root='.',
def splits(cls, text_field, label_field, parse_field=None, root='.data',
train='train.jsonl', validation='dev.jsonl', test='test.jsonl'):
"""Create dataset objects for splits of the SNLI dataset.
@@ -57,7 +57,8 @@ def splits(cls, text_field, label_field, parse_field=None, root='.',
test: The filename of the test data, or None to not load the test
set. Default: 'test.jsonl'.
"""
path = cls.download_or_unzip(root)
path = cls.download(root)

if parse_field is None:
return super(SNLI, cls).splits(
os.path.join(path, 'snli_1.0_'), train, validation, test,
@@ -77,7 +78,7 @@ def splits(cls, text_field, label_field, parse_field=None, root='.',
filter_pred=lambda ex: ex.label != '-')

@classmethod
def iters(cls, batch_size=32, device=0, root='.',
def iters(cls, batch_size=32, device=0, root='.data',
vectors=None, trees=False, **kwargs):
"""Create iterator objects for splits of the SNLI dataset.
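
Note: SNLI follows the same pattern on top of `TabularDataset`: `download` fetches and unzips snli_1.0.zip under `./.data/snli`, and `splits` then reads the jsonl files. A hedged quick-start without the optional parse field; the field options and vocab size are illustrative.

```python
from torchtext import data, datasets

TEXT = data.Field(lower=True)
LABEL = data.Field(sequential=False)

# Downloads snli_1.0.zip into ./.data/snli on first use.
train, dev, test = datasets.SNLI.splits(TEXT, LABEL)

TEXT.build_vocab(train.premise, train.hypothesis, max_size=50000)
LABEL.build_vocab(train.label)

train_iter, dev_iter, test_iter = data.BucketIterator.splits(
    (train, dev, test), batch_size=3, device=0)
```
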
20 changes: 10 additions & 10 deletions torchtext/datasets/sst.py
@@ -3,11 +3,11 @@
from .. import data


class SST(data.ZipDataset):
class SST(data.Dataset):

url = 'http://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip'
filename = 'trainDevTestTrees_PTB.zip'
urls = ['http://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip']
dirname = 'trees'
name = 'sst'

@staticmethod
def sort_key(ex):
@@ -18,7 +18,7 @@ def __init__(self, path, text_field, label_field, subtrees=False,
"""Create an SST dataset instance given a path and fields.
Arguments:
path: Path to the data file.
path: Path to the data file
text_field: The field that will be used for text data.
label_field: The field that will be used for label data.
subtrees: Whether to include sentiment-tagged subphrases
@@ -44,7 +44,7 @@ def get_label_str(label):
super(SST, self).__init__(examples, fields, **kwargs)

@classmethod
def splits(cls, text_field, label_field, root='.',
def splits(cls, text_field, label_field, root='.data',
train='train.txt', validation='dev.txt', test='test.txt',
train_subtrees=False, **kwargs):
"""Create dataset objects for splits of the SST dataset.
@@ -65,20 +65,20 @@ def splits(cls, text_field, label_field, root='.',
Remaining keyword arguments: Passed to the splits method of
Dataset.
"""
path = cls.download_or_unzip(root)
path = cls.download(root)

train_data = None if train is None else cls(
path + train, text_field, label_field, subtrees=train_subtrees,
os.path.join(path, train), text_field, label_field, subtrees=train_subtrees,
**kwargs)
val_data = None if validation is None else cls(
path + validation, text_field, label_field, **kwargs)
os.path.join(path, validation), text_field, label_field, **kwargs)
test_data = None if test is None else cls(
path + test, text_field, label_field, **kwargs)
os.path.join(path, test), text_field, label_field, **kwargs)
return tuple(d for d in (train_data, val_data, test_data)
if d is not None)

@classmethod
def iters(cls, batch_size=32, device=0, root='.', vectors=None, **kwargs):
def iters(cls, batch_size=32, device=0, root='.data', vectors=None, **kwargs):
"""Creater iterator objects for splits of the SST dataset.
Arguments:
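
Note: together with the `Example.fromtree` fix earlier in this commit, SST builds subtree examples whose text goes through the field's tokenizer as a plain string. A hedged usage sketch under the new `.data` default (the zip unpacks to `./.data/sst/trees`); the field settings are illustrative.

```python
from torchtext import data, datasets

TEXT = data.Field()
LABEL = data.Field(sequential=False)

# Downloads trainDevTestTrees_PTB.zip into ./.data/sst on first use;
# train_subtrees=True exercises the fixed Example.fromtree path.
train, val, test = datasets.SST.splits(TEXT, LABEL, train_subtrees=True)

TEXT.build_vocab(train.text)
LABEL.build_vocab(train.label)

train_iter, val_iter, test_iter = data.BucketIterator.splits(
    (train, val, test), batch_size=3, device=0)
```
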
38 changes: 38 additions & 0 deletions torchtext/datasets/translation.py
@@ -35,3 +35,41 @@ def __init__(self, path, exts, fields, **kwargs):
[src_line, trg_line], fields))

super(TranslationDataset, self).__init__(examples, fields, **kwargs)


class Multi30k(TranslationDataset, data.Dataset):
"""Defines a dataset for the multi-modal WMT 2017 task"""

urls = ['http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz',
'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz',
'https://staff.fnwi.uva.nl/d.elliott/wmt16/mmt16_task1_test.tgz']
name = 'multi30k'
dirname = ''

@classmethod
def splits(cls, exts, fields, root='.data',
train='train', val='val', test='test', **kwargs):
"""Create dataset objects for splits of the Multi30k dataset.
Arguments:
root: directory containing Multi30k data
exts: A tuple containing the extension to path for each language.
fields: A tuple containing the fields that will be used for data
in each language.
train: The prefix of the train data. Default: 'train'.
validation: The prefix of the validation data. Default: 'val'.
test: The prefix of the test data. Default: 'test'.
Remaining keyword arguments: Passed to the splits method of
Dataset.
"""
path = cls.download(root)

train_data = None if train is None else cls(
os.path.join(path, train), exts, fields, **kwargs)
val_data = None if val is None else cls(
os.path.join(path, val), exts, fields, **kwargs)
test_data = None if test is None else cls(
os.path.join(path, test), exts, fields, **kwargs)
return tuple(d for d in (train_data, val_data, test_data)
if d is not None)