Added RawField that represents any datatype. #147

Merged: 4 commits, Oct 16, 2017
28 changes: 26 additions & 2 deletions test/data/test_field.py
@@ -5,11 +5,35 @@
from numpy.testing import assert_allclose
import torch
import torchtext.data as data
import pytest

from ..common.torchtext_test_case import TorchtextTestCase, verify_numericalized_example


class TestField(TorchtextTestCase):
def test_process(self):
raw_field = data.RawField()
field = data.Field(sequential=True, use_vocab=False, batch_first=True)

# Test tensor-like batch data which is accepted by both RawField and Field
batch = [[1, 2, 3], [2, 3, 4]]
batch_tensor = torch.LongTensor(batch)

raw_field_processed = raw_field.process(batch)
field_processed = field.process(batch, device=-1, train=False)

assert raw_field_processed == batch
assert field_processed.data.equal(batch_tensor)

# Test non-tensor data which is only accepted by RawField
any_obj = [object() for _ in range(5)]

raw_field_processed = raw_field.process(any_obj)
assert any_obj == raw_field_processed

with pytest.raises(TypeError):
field.process(any_obj)

def test_preprocess(self):
# Default case.
field = data.Field()
@@ -329,10 +353,10 @@ def test_numerical_features_no_vocab(self):

# Test with postprocessing applied
int_field = data.Field(sequential=False, use_vocab=False,
postprocessing=lambda arr, _: [x + 1 for x in arr])
postprocessing=lambda arr, _, __: [x + 1 for x in arr])
float_field = data.Field(sequential=False, use_vocab=False,
tensor_type=torch.FloatTensor,
postprocessing=lambda arr, _: [x * 0.5 for x in arr])
postprocessing=lambda arr, _, __: [x * 0.5 for x in arr])
tsv_fields = [("int", int_field), ("float", float_field), ("string", None)]
tsv_dataset = data.TabularDataset(
path=self.test_numerical_features_dataset_path, format="tsv",
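The updated lambdas above reflect an interface change that ships with this PR: a Field's postprocessing callable now receives three arguments, the batch as a list, the field's Vocab (or None when use_vocab=False), and the train flag. A minimal sketch of a hook written against that signature, assuming a non-sequential numeric field; the add_offset helper and the offset value are illustrative, not part of the PR:

```python
import torchtext.data as data

# Illustrative postprocessing hook using the new (batch, vocab, train) signature.
def add_offset(batch, vocab, train):
    # vocab is None here because use_vocab=False; train can be used to
    # enable train-only behaviour such as augmentation or noise.
    return [x + 10 for x in batch]

int_field = data.Field(sequential=False, use_vocab=False,
                       postprocessing=add_offset)
```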
4 changes: 2 additions & 2 deletions torchtext/data/__init__.py
@@ -1,7 +1,7 @@
from .batch import Batch
from .dataset import Dataset, TabularDataset
from .example import Example
from .field import Field, ReversibleField, SubwordField
from .field import RawField, Field, ReversibleField, SubwordField
from .iterator import (batch, BucketIterator, Iterator, BPTTIterator,
pool)
from .pipeline import Pipeline
@@ -10,7 +10,7 @@
__all__ = ["Batch",
"Dataset", "TabularDataset", "ZipDataset",
"Example",
"Field", "ReversibleField", "SubwordField",
"RawField", "Field", "ReversibleField", "SubwordField",
"batch", "BucketIterator", "Iterator", "BPTTIterator",
"pool",
"Pipeline",
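With this export change, RawField is importable from the package namespace alongside the existing field types, for example:

```python
from torchtext.data import RawField, Field, ReversibleField, SubwordField

meta = RawField()  # a field with no Vocab and no tensor conversion
```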
5 changes: 2 additions & 3 deletions torchtext/data/batch.py
@@ -18,9 +18,8 @@ def __init__(self, data=None, dataset=None, device=None, train=True):
self.train = train
for (name, field) in dataset.fields.items():
if field is not None:
setattr(self, name, field.numericalize(
field.pad(x.__dict__[name] for x in data),
device=device, train=train))
batch = [x.__dict__[name] for x in data]
setattr(self, name, field.process(batch, device=device, train=train))

@classmethod
def fromvars(cls, dataset, batch_size, train=True, **kwargs):
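With this change, Batch.__init__ no longer pads and numericalizes directly; it hands the raw column of values to each field's process() method, which is what lets a RawField participate in batching without a Vocab or tensor conversion. A rough sketch of the two code paths, mirroring the new test above; the example data is illustrative:

```python
import torch
import torchtext.data as data

text_field = data.Field(sequential=True, use_vocab=False, batch_first=True)
raw_field = data.RawField()  # values stay arbitrary Python objects

batch = [[1, 2, 3], [2, 3, 4]]

# Field.process pads, numericalizes, and returns a Variable-wrapped tensor.
tensor = text_field.process(batch, device=-1, train=False)

# RawField.process only applies the optional postprocessing Pipeline.
raw = raw_field.process(["any", object(), 3.14])
```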
75 changes: 68 additions & 7 deletions torchtext/data/field.py
@@ -11,15 +11,59 @@
from ..vocab import Vocab, SubwordVocab


class Field(object):
"""Defines a datatype together with instructions for converting to Tensor.
class RawField(object):
""" Defines a general datatype.

Every dataset consists of one or more types of data. For instance, a text
classification dataset contains sentences and their classes, while a
machine translation dataset contains paired examples of text in two
languages. Each of these types of data is represented by a Field object,
which holds a Vocab object that defines the set of possible values for
elements of the field and their corresponding numerical representations.
languages. Each of these types of data is represented by a RawField object.
A RawField object does not assume any property of the data type and
holds parameters relating to how a datatype should be processed.

Attributes:
preprocessing: The Pipeline that will be applied to examples
using this field before creating an example.
Default: None.
postprocessing: A Pipeline that will be applied to a list of examples
using this field before assigning to a batch.
Function signature: (batch(list)) -> object
Default: None.
"""

def __init__(self, preprocessing=None, postprocessing=None):
self.preprocessing = preprocessing
self.postprocessing = postprocessing

def preprocess(self, x):
""" Preprocess an example if the `preprocessing` Pipeline is provided. """
if self.preprocessing is not None:
return self.preprocessing(x)
else:
return x

def process(self, batch, *args, **kwargs):
""" Process a list of examples to create a batch.

Postprocess the batch with user-provided Pipeline.

Args:
batch (list(object)): A list of objects from a batch of examples.
Returns:
data (object): Processed object given the input and custom
postprocessing Pipeline.
"""
if self.postprocessing is not None:
batch = self.postprocessing(batch)
return batch


class Field(RawField):
"""Defines a datatype together with instructions for converting to Tensor.

Field class models common text processing datatypes that can be represented
by tensors. It holds a Vocab object that defines the set of possible values
for elements of the field and their corresponding numerical representations.
The Field object also holds other parameters relating to how a datatype
should be numericalized, such as a tokenization method and the kind of
Tensor that should be produced.
@@ -46,7 +90,9 @@ class Field(object):
Default: None.
postprocessing: A Pipeline that will be applied to examples using
this field after numericalizing but before the numbers are turned
into a Tensor. Default: None.
into a Tensor. The pipeline function takes the batch as a list,
the field's Vocab, and train (a bool).
Default: None.
lower: Whether to lowercase the text in this field. Default: False.
tokenize: The function used to tokenize strings using this field into
sequential examples. If "spacy", the SpaCy English tokenizer is
@@ -123,6 +169,21 @@ def preprocess(self, x):
else:
return x

def process(self, batch, device, train):
""" Process a list of examples to create a torch.Tensor.

Pad, numericalize, and postprocess a batch and create a tensor.

Args:
batch (list(object)): A list of objects from a batch of examples.
Returns:
data (torch.autograd.Variable): Processed object given the input
and custom postprocessing Pipeline.
"""
padded = self.pad(batch)
tensor = self.numericalize(padded, device=device, train=train)
return tensor

def pad(self, minibatch):
"""Pad a batch of examples using this field.

@@ -232,7 +293,7 @@ def numericalize(self, arr, device=None, train=True):
if not self.sequential:
arr = [numericalization_func(x) for x in arr]
if self.postprocessing is not None:
arr = self.postprocessing(arr, train)
arr = self.postprocessing(arr, None, train)

arr = self.tensor_type(arr)
if self.sequential and not self.batch_first:
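Taken together, RawField exposes the same preprocess/process hook points as Field but makes no assumptions about the underlying values: preprocessing runs per example, postprocessing runs once per batch, and nothing is tokenized, padded, or converted to a tensor. A minimal usage sketch with plain callables standing in for Pipelines; the field name and callables are illustrative:

```python
import torchtext.data as data

# Illustrative callables; RawField simply calls whatever it is given.
doc_id_field = data.RawField(
    preprocessing=lambda x: x.strip(),            # applied to each example
    postprocessing=lambda batch: sorted(batch))   # applied to the whole batch

single = doc_id_field.preprocess("  doc-0007  ")  # -> "doc-0007"
batch = doc_id_field.process(["doc-3", "doc-1"])  # -> ["doc-1", "doc-3"]
```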