Merge branch 'master' of github.com:pytorch/text
mttk committed Sep 25, 2018
2 parents 3422c1d + 64ef022 commit f2a939a
Showing 18 changed files with 237 additions and 44 deletions.
7 changes: 4 additions & 3 deletions README.rst
@@ -41,7 +41,7 @@ Alternatively, you might want to use Moses tokenizer from `NLTK <http://nltk.org
Documentation
=============

-Find the documentation `here <https://torchtext.readthedocs.io/en/latest/index.html>`.
+Find the documentation `here <https://torchtext.readthedocs.io/en/latest/index.html>`_.

Data
====
@@ -118,9 +118,10 @@ The datasets module currently contains:
* Sentiment analysis: SST and IMDb
* Question classification: TREC
* Entailment: SNLI, MultiNLI
-* Language modeling: abstract class + WikiText-2
+* Language modeling: abstract class + WikiText-2, WikiText103, PennTreebank
* Machine translation: abstract class + Multi30k, IWSLT, WMT14
-* Sequence tagging (e.g. POS/NER): abstract class + UDPOS
+* Sequence tagging (e.g. POS/NER): abstract class + UDPOS, CoNLL2000Chunking
+* Question answering: 20 QA bAbI tasks

Others are planned or a work in progress:

4 changes: 2 additions & 2 deletions docs/requirements.txt
@@ -1,5 +1,5 @@
-# http://download.pytorch.org/whl/cu80/torch-0.3.0.post4-cp35-cp35m-linux_x86_64.whl
-torch
+http://download.pytorch.org/whl/cu80/torch-0.4.0-cp35-cp35m-linux_x86_64.whl

# Progress bars on iterators
tqdm
sphinx_rtd_theme
50 changes: 50 additions & 0 deletions docs/source/datasets.rst
@@ -73,6 +73,12 @@ SNLI
   :members: splits, iters


MultiNLI
~~~~~~~~

.. autoclass:: MultiNLI
   :members: splits, iters


Language Modeling
^^^^^^^^^^^^^^^^^
@@ -90,6 +96,19 @@ WikiText-2
   :members: splits, iters


WikiText103
~~~~~~~~~~~

.. autoclass:: WikiText103
   :members: splits, iters


PennTreebank
~~~~~~~~~~~~

.. autoclass:: PennTreebank
   :members: splits, iters


Machine Translation
^^^^^^^^^^^^^^^^^^^
@@ -117,3 +136,34 @@ WMT14

.. autoclass:: WMT14
   :members: splits


Sequence Tagging
^^^^^^^^^^^^^^^^

Sequence tagging datasets are subclasses of the ``SequenceTaggingDataset`` class.

.. autoclass:: SequenceTaggingDataset
   :members: __init__


UDPOS
~~~~~

.. autoclass:: UDPOS
   :members: splits

CoNLL2000Chunking
~~~~~~~~~~~~~~~~~

.. autoclass:: CoNLL2000Chunking
   :members: splits

Question Answering
^^^^^^^^^^^^^^^^^^

BABI20
~~~~~~

.. autoclass:: BABI20
   :members: __init__, splits, iters
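The added language-modeling datasets expose the same ``splits``/``iters`` interface as WikiText-2. As a rough usage sketch (not part of this commit; the first call downloads the corpus into ``.data/``):

    from torchtext import data, datasets

    # WikiText103 and PennTreebank follow the WikiText-2 API.
    TEXT = data.Field(lower=True)
    train, valid, test = datasets.WikiText103.splits(TEXT)

    # Build the vocabulary from the training split as usual.
    TEXT.build_vocab(train, max_size=50000)
    print(len(TEXT.vocab))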
11 changes: 11 additions & 0 deletions readthedocs.yml
@@ -0,0 +1,11 @@
build:
  image: latest

python:
  version: 3.6
  setup_py_install: true

# Don't build any extra formats
formats: []

requirements_file: docs/requirements.txt
25 changes: 25 additions & 0 deletions test/data/test_batch.py
@@ -1,4 +1,5 @@
from __future__ import unicode_literals
import torch
import torchtext.data as data

from ..common.torchtext_test_case import TorchtextTestCase
@@ -17,3 +18,27 @@ def test_batch_with_missing_field(self):
                                          ("label", None)])
        itr = data.Iterator(dst, batch_size=64)
        str(next(itr.__iter__()))

    def test_batch_iter(self):
        self.write_test_numerical_features_dataset()
        FLOAT = data.Field(use_vocab=False, sequential=False,
                           dtype=torch.float)
        INT = data.Field(use_vocab=False, sequential=False, is_target=True)
        TEXT = data.Field(sequential=False)

        dst = data.TabularDataset(path=self.test_numerical_features_dataset_path,
                                  format="tsv", skip_header=False,
                                  fields=[("float", FLOAT),
                                          ("int", INT),
                                          ("text", TEXT)])
        TEXT.build_vocab(dst)
        itr = data.Iterator(dst, batch_size=2, device=-1, shuffle=False)
        fld_order = [k for k, v in dst.fields.items() if
                     v is not None and not v.is_target]
        batch = next(iter(itr))
        (x1, x2), y = batch
        x = (x1, x2)[fld_order.index("float")]
        self.assertEquals(y.data[0], 1)
        self.assertEquals(y.data[1], 12)
        self.assertAlmostEqual(x.data[0], 0.1, places=4)
        self.assertAlmostEqual(x.data[1], 0.5, places=4)
37 changes: 37 additions & 0 deletions test/data/test_field.py
@@ -9,6 +9,7 @@
from torch.nn import init

from ..common.torchtext_test_case import TorchtextTestCase, verify_numericalized_example
from ..common.test_markers import slow


class TestField(TorchtextTestCase):
@@ -331,6 +332,35 @@ def reverse_postprocess(arr, vocab):
                                     reversed_test_example_data,
                                     postprocessed_numericalized)

    def test_numericalize_stop_words(self):
        # Based on request from #354
        self.write_test_ppid_dataset(data_format="tsv")
        question_field = data.Field(sequential=True, batch_first=True,
                                    stop_words=set(["do", "you"]))
        tsv_fields = [("id", None), ("q1", question_field),
                      ("q2", question_field), ("label", None)]
        tsv_dataset = data.TabularDataset(
            path=self.test_ppid_dataset_path, format="tsv",
            fields=tsv_fields)
        question_field.build_vocab(tsv_dataset)

        test_example_data = question_field.pad(
            [question_field.preprocess(x) for x in
             [["When", "do", "you", "use", "シ",
               "instead", "of", "し?"],
              ["What", "is", "2+2", "<pad>", "<pad>",
               "<pad>", "<pad>", "<pad>"],
              ["Here", "is", "a", "sentence", "with",
               "some", "oovs", "<pad>"]]]
        )

        # Test with batch_first
        stopwords_removed_numericalized = question_field.numericalize(test_example_data)
        verify_numericalized_example(question_field,
                                     test_example_data,
                                     stopwords_removed_numericalized,
                                     batch_first=True)

    def test_numerical_features_no_vocab(self):
        self.write_test_numerical_features_dataset()
        # Test basic usage
@@ -484,6 +514,9 @@ def test_build_vocab_from_dataset(self):
        for c in expected:
            assert c in CHARS.vocab.stoi

        expected_freqs = Counter({"a": 6, "b": 6, "c": 1})
        assert CHARS.vocab.freqs == CHARS.nesting_field.vocab.freqs == expected_freqs

    def test_build_vocab_from_iterable(self):
        nesting_field = data.Field(unk_token="<cunk>", pad_token="<cpad>")
        CHARS = data.NestedField(nesting_field)
@@ -497,6 +530,9 @@ def test_build_vocab_from_iterable(self):
        for c in expected:
            assert c in CHARS.vocab.stoi

        expected_freqs = Counter({"a": 6, "b": 12, "c": 4})
        assert CHARS.vocab.freqs == CHARS.nesting_field.vocab.freqs == expected_freqs

    def test_pad(self):
        nesting_field = data.Field(tokenize=list, unk_token="<cunk>", pad_token="<cpad>",
                                   init_token="<w>", eos_token="</w>")
@@ -749,6 +785,7 @@ def test_numericalize(self):
        verify_numericalized_example(
            field, example, numericalized_example, batch_first=True)

    @slow
    def test_build_vocab(self):
        nesting_field = data.Field(tokenize=list, init_token="<w>", eos_token="</w>")
17 changes: 17 additions & 0 deletions test/test_vocab.py
@@ -32,6 +32,23 @@ def test_vocab_basic(self):
        self.assertEqual(v.itos, expected_itos)
        self.assertEqual(dict(v.stoi), expected_stoi)

    def test_vocab_specials_first(self):
        c = Counter("a a b b c c".split())

        # add specials into vocabulary at first
        v = vocab.Vocab(c, max_size=2, specials=['<pad>', '<eos>'])
        expected_itos = ['<pad>', '<eos>', 'a', 'b']
        expected_stoi = {x: index for index, x in enumerate(expected_itos)}
        self.assertEqual(v.itos, expected_itos)
        self.assertEqual(dict(v.stoi), expected_stoi)

        # add specials into vocabulary at last
        v = vocab.Vocab(c, max_size=2, specials=['<pad>', '<eos>'], specials_first=False)
        expected_itos = ['a', 'b', '<pad>', '<eos>']
        expected_stoi = {x: index for index, x in enumerate(expected_itos)}
        self.assertEqual(v.itos, expected_itos)
        self.assertEqual(dict(v.stoi), expected_stoi)

    def test_vocab_set_vectors(self):
        c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5,
                     'test': 4, 'freq_too_low': 2})
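For reference, a small sketch of what the ``specials_first`` flag exercised above does (not part of the diff):

    from collections import Counter
    from torchtext import vocab

    counter = Counter("a a b b c".split())

    # Default: specials are prepended to the vocabulary.
    v = vocab.Vocab(counter, max_size=2, specials=['<pad>', '<eos>'])
    print(v.itos)   # ['<pad>', '<eos>', 'a', 'b']

    # specials_first=False appends them after the corpus tokens instead.
    v = vocab.Vocab(counter, max_size=2, specials=['<pad>', '<eos>'],
                    specials_first=False)
    print(v.itos)   # ['a', 'b', '<pad>', '<eos>']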
19 changes: 19 additions & 0 deletions torchtext/data/batch.py
@@ -10,6 +10,9 @@ class Batch(object):
            (which itself contains the dataset's Field objects).
        train: Deprecated: this attribute is left for backwards compatibility,
            however it is UNUSED as of the merger with pytorch 0.4.
        input_fields: The names of the fields that are used as input for the model
        target_fields: The names of the fields that are used as targets during
            model training

    Also stores the Variable for each column in the batch as an attribute.
    """
@@ -20,6 +23,10 @@ def __init__(self, data=None, dataset=None, device=None):
            self.batch_size = len(data)
            self.dataset = dataset
            self.fields = dataset.fields.keys()  # copy field names
            self.input_fields = [k for k, v in dataset.fields.items() if
                                 v is not None and not v.is_target]
            self.target_fields = [k for k, v in dataset.fields.items() if
                                  v is not None and v.is_target]

            for (name, field) in dataset.fields.items():
                if field is not None:
@@ -59,6 +66,18 @@ def __str__(self):
    def __len__(self):
        return self.batch_size

    def _get_field_values(self, fields):
        if len(fields) == 0:
            return None
        elif len(fields) == 1:
            return getattr(self, fields[0])
        else:
            return tuple(getattr(self, f) for f in fields)

    def __iter__(self):
        yield self._get_field_values(self.input_fields)
        yield self._get_field_values(self.target_fields)


def _short_str(tensor):
    # unwrap variable to tensor
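Together with ``input_fields``/``target_fields``, the new ``__iter__`` lets a training loop unpack a batch directly into inputs and targets. A minimal self-contained sketch (not part of the commit; the tiny in-memory dataset is made up for illustration):

    import torchtext.data as data

    TEXT = data.Field(sequential=True, batch_first=True)
    LABEL = data.Field(sequential=False, use_vocab=False, is_target=True)
    fields = [("text", TEXT), ("label", LABEL)]

    examples = [data.Example.fromlist(["hello world", 1], fields),
                data.Example.fromlist(["goodbye world", 0], fields)]
    dataset = data.Dataset(examples, fields)
    TEXT.build_vocab(dataset)

    itr = data.Iterator(dataset, batch_size=2, shuffle=False)
    for x, y in itr:       # Batch.__iter__ yields (inputs, targets)
        print(x.shape, y)  # x: text tensor, y: label tensor
        break              # train iterators of this era repeat by default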
6 changes: 3 additions & 3 deletions torchtext/data/dataset.py
@@ -70,7 +70,7 @@ def splits(cls, path=None, root='.data', train=None, validation=None,
        Returns:
            Tuple[Dataset]: Datasets for train, validation, and
-            test splits in that order, if provided.
+                test splits in that order, if provided.
        """
        if path is None:
            path = cls.download(root)
@@ -102,7 +102,7 @@ def split(self, split_ratio=0.7, stratified=False, strata_field='label',
        Returns:
            Tuple[Dataset]: Datasets for train, validation, and
-            test splits in that order, if the splits are provided.
+                test splits in that order, if the splits are provided.
        """
        train_ratio, test_ratio, val_ratio = check_split_ratio(split_ratio)

@@ -266,7 +266,7 @@ def check_split_ratio(split_ratio):
    if isinstance(split_ratio, float):
        # Only the train set relative ratio is provided
        # Assert in bounds, validation size is zero
-        assert split_ratio > 0. and split_ratio < 1., (
+        assert 0. < split_ratio < 1., (
            "Split ratio {} not between 0 and 1".format(split_ratio))

        test_ratio = 1. - split_ratio
29 changes: 23 additions & 6 deletions torchtext/data/field.py
@@ -51,7 +51,7 @@ def process(self, batch, *args, **kwargs):
            batch (list(object)): A list of object from a batch of examples.
        Returns:
            object: Processed object given the input and custom
-            postprocessing Pipeline.
+                postprocessing Pipeline.
        """
        if self.postprocessing is not None:
            batch = self.postprocessing(batch)
@@ -106,6 +106,9 @@ class Field(RawField):
        unk_token: The string token used to represent OOV words. Default: "<unk>".
        pad_first: Do the padding of the sequence at the beginning. Default: False.
        truncate_first: Do the truncating of the sequence at the beginning. Default: False
+        stop_words: Tokens to discard during the preprocessing step. Default: None
+        is_target: Whether this field is a target variable.
+            Affects iteration over batches. Default: False
    """

    vocab_cls = Vocab
@@ -134,7 +137,8 @@ def __init__(self, sequential=True, use_vocab=True, init_token=None,
                 preprocessing=None, postprocessing=None, lower=False,
                 tokenize=(lambda s: s.split()), include_lengths=False,
                 batch_first=False, pad_token="<pad>", unk_token="<unk>",
-                 pad_first=False, truncate_first=False):
+                 pad_first=False, truncate_first=False, stop_words=None,
+                 is_target=False):
        self.sequential = sequential
        self.use_vocab = use_vocab
        self.init_token = init_token
@@ -151,6 +155,15 @@ def __init__(self, sequential=True, use_vocab=True, init_token=None,
        self.pad_token = pad_token if self.sequential else None
        self.pad_first = pad_first
        self.truncate_first = truncate_first
+        if stop_words is not None:
+            try:
+                self.stop_words = set(stop_words)
+            except TypeError:
+                raise ValueError("Stop words must be convertible to a set")
+        else:
+            self.stop_words = stop_words
+        self.is_target = is_target

    def preprocess(self, x):
        """Load a single example using this field, tokenizing if necessary.
@@ -166,6 +179,8 @@ def preprocess(self, x):
            x = self.tokenize(x.rstrip('\n'))
        if self.lower:
            x = Pipeline(six.text_type.lower)(x)
+        if self.sequential and self.use_vocab and self.stop_words is not None:
+            x = [w for w in x if w not in self.stop_words]
        if self.preprocessing is not None:
            return self.preprocessing(x)
        else:
@@ -180,7 +195,7 @@ def process(self, batch, device=None):
            batch (list(object)): A list of object from a batch of examples.
        Returns:
            torch.autograd.Variable: Processed object given the input
-            and custom postprocessing Pipeline.
+                and custom postprocessing Pipeline.
        """
        padded = self.pad(batch)
        tensor = self.numericalize(padded, device=device)
@@ -296,7 +311,7 @@ def numericalize(self, arr, device=None):
"Please raise an issue at "
"https://github.com/pytorch/text/issues".format(self.dtype))
numericalization_func = self.dtypes[self.dtype]
# It doesn't make sense to explictly coerce to a numeric type if
# It doesn't make sense to explicitly coerce to a numeric type if
# the data is sequential, since it's unclear how to coerce padding tokens
# to a numeric type.
if not self.sequential:
@@ -629,6 +644,7 @@ def build_vocab(self, *args, **kwargs):
        self.nesting_field.build_vocab(*flattened, **kwargs)
        super(NestedField, self).build_vocab()
        self.vocab.extend(self.nesting_field.vocab)
+        self.vocab.freqs = self.nesting_field.vocab.freqs.copy()
        if old_vectors is not None:
            self.vocab.load_vectors(old_vectors,
                                    unk_init=old_unk_init, cache=old_vectors_cache)
@@ -660,8 +676,9 @@ def numericalize(self, arrs, device=None):

        self.nesting_field.include_lengths = True
        if self.include_lengths:
-            sentence_lengths = torch.LongTensor(sentence_lengths, device=device)
-            word_lengths = torch.LongTensor(word_lengths, device=device)
+            sentence_lengths = \
+                torch.tensor(sentence_lengths, dtype=self.dtype, device=device)
+            word_lengths = torch.tensor(word_lengths, dtype=self.dtype, device=device)
            return (padded_batch, sentence_lengths, word_lengths)
        return padded_batch

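A short sketch of the new ``stop_words`` behavior (not from the diff): as the ``preprocess`` hunk above shows, filtering happens after tokenization and lowercasing, so with ``lower=True`` the stop words should themselves be lowercase.

    import torchtext.data as data

    # Tokens in stop_words are dropped after tokenization/lowercasing,
    # before any custom `preprocessing` pipeline runs.
    QUESTION = data.Field(sequential=True, lower=True,
                          stop_words={"do", "you", "the"})
    print(QUESTION.preprocess("Do you know the way"))
    # -> ['know', 'way']

    # A non-iterable argument raises the new error:
    # data.Field(stop_words=0)  # ValueError: Stop words must be convertible to a set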