Added RawField that represents any datatype. (#147)
kylegao91 authored and jekbradbury committed Oct 16, 2017
1 parent 247598d commit b579fbe
Showing 4 changed files with 98 additions and 14 deletions.
28 changes: 26 additions & 2 deletions test/data/test_field.py
@@ -5,11 +5,35 @@
from numpy.testing import assert_allclose
import torch
import torchtext.data as data
import pytest

from ..common.torchtext_test_case import TorchtextTestCase, verify_numericalized_example


class TestField(TorchtextTestCase):
def test_process(self):
raw_field = data.RawField()
field = data.Field(sequential=True, use_vocab=False, batch_first=True)

# Test tensor-like batch data which is accepted by both RawField and Field
batch = [[1, 2, 3], [2, 3, 4]]
batch_tensor = torch.LongTensor(batch)

raw_field_processed = raw_field.process(batch)
field_processed = field.process(batch, device=-1, train=False)

assert raw_field_processed == batch
assert field_processed.data.equal(batch_tensor)

# Test non-tensor data which is only accepted by RawField
any_obj = [object() for _ in range(5)]

raw_field_processed = raw_field.process(any_obj)
assert any_obj == raw_field_processed

with pytest.raises(TypeError):
field.process(any_obj)

def test_preprocess(self):
# Default case.
field = data.Field()
@@ -329,10 +353,10 @@ def test_numerical_features_no_vocab(self):

# Test with postprocessing applied
int_field = data.Field(sequential=False, use_vocab=False,
postprocessing=lambda arr, _: [x + 1 for x in arr])
postprocessing=lambda arr, _, __: [x + 1 for x in arr])
float_field = data.Field(sequential=False, use_vocab=False,
tensor_type=torch.FloatTensor,
postprocessing=lambda arr, _: [x * 0.5 for x in arr])
postprocessing=lambda arr, _, __: [x * 0.5 for x in arr])
tsv_fields = [("int", int_field), ("float", float_field), ("string", None)]
tsv_dataset = data.TabularDataset(
path=self.test_numerical_features_dataset_path, format="tsv",
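The postprocessing change in the test above is the part of this commit most likely to affect existing users: a Field's postprocessing pipeline now receives three arguments, the batch (as a list), the field's Vocab (None when use_vocab=False), and the train flag, instead of two. A minimal sketch of a pipeline written for the new signature, mirroring int_field above; the function name and the +1 are purely illustrative:

import torchtext.data as data

# New signature: (batch, vocab, train). vocab is None here because
# use_vocab=False; train can be used to switch behaviour off at eval time.
def add_one(batch, vocab, train):
    return [x + 1 for x in batch]

int_field = data.Field(sequential=False, use_vocab=False,
                       postprocessing=add_one)

# process() pads, numericalizes (applying add_one), and builds the tensor;
# this should yield a Variable wrapping LongTensor([2, 3, 4]).
tensor = int_field.process([1, 2, 3], device=-1, train=False)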
4 changes: 2 additions & 2 deletions torchtext/data/__init__.py
@@ -1,7 +1,7 @@
from .batch import Batch
from .dataset import Dataset, TabularDataset
from .example import Example
from .field import Field, ReversibleField, SubwordField
from .field import RawField, Field, ReversibleField, SubwordField
from .iterator import (batch, BucketIterator, Iterator, BPTTIterator,
pool)
from .pipeline import Pipeline
@@ -10,7 +10,7 @@
__all__ = ["Batch",
"Dataset", "TabularDataset", "ZipDataset",
"Example",
"Field", "ReversibleField", "SubwordField",
"RawField", "Field", "ReversibleField", "SubwordField",
"batch", "BucketIterator", "Iterator", "BPTTIterator",
"pool",
"Pipeline",
5 changes: 2 additions & 3 deletions torchtext/data/batch.py
@@ -18,9 +18,8 @@ def __init__(self, data=None, dataset=None, device=None, train=True):
self.train = train
for (name, field) in dataset.fields.items():
if field is not None:
setattr(self, name, field.numericalize(
field.pad(x.__dict__[name] for x in data),
device=device, train=train))
batch = [x.__dict__[name] for x in data]
setattr(self, name, field.process(batch, device=device, train=train))

@classmethod
def fromvars(cls, dataset, batch_size, train=True, **kwargs):
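Because Batch.__init__ now routes every column through field.process, a custom field only needs to provide preprocess (used when Examples are built) and process to take part in batching; nothing about padding, vocabularies, or tensors is assumed. A hypothetical sketch of such a field (JSONField is not part of this commit, it only illustrates the extension point):

import json
import torchtext.data as data

class JSONField(data.RawField):
    """Keeps parsed JSON objects per example instead of building a tensor."""
    def preprocess(self, x):
        # Called once per example when the Example is constructed.
        return json.loads(x) if isinstance(x, str) else x
    # The inherited RawField.process(batch, *args, **kwargs) accepts the
    # device/train keywords passed by Batch above and returns the list of
    # parsed objects unchanged (after the optional postprocessing pipeline).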
75 changes: 68 additions & 7 deletions torchtext/data/field.py
@@ -11,15 +11,59 @@
from ..vocab import Vocab, SubwordVocab


class Field(object):
"""Defines a datatype together with instructions for converting to Tensor.
class RawField(object):
""" Defines a general datatype.
Every dataset consists of one or more types of data. For instance, a text
classification dataset contains sentences and their classes, while a
machine translation dataset contains paired examples of text in two
languages. Each of these types of data is represented by a Field object,
which holds a Vocab object that defines the set of possible values for
elements of the field and their corresponding numerical representations.
languages. Each of these types of data is represented by a RawField object.
A RawField object does not assume any property of the data type; it only
holds parameters relating to how the data should be processed.
Attributes:
preprocessing: The Pipeline that will be applied to examples
using this field before creating an example.
Default: None.
postprocessing: A Pipeline that will be applied to a list of examples
using this field before assigning to a batch.
Function signature: (batch(list)) -> object
Default: None.
"""

def __init__(self, preprocessing=None, postprocessing=None):
self.preprocessing = preprocessing
self.postprocessing = postprocessing

def preprocess(self, x):
""" Preprocess an example if the `preprocessing` Pipeline is provided. """
if self.preprocessing is not None:
return self.preprocessing(x)
else:
return x

def process(self, batch, *args, **kwargs):
""" Process a list of examples to create a batch.
Postprocess the batch with user-provided Pipeline.
Args:
batch (list(object)): A list of objects from a batch of examples.
Returns:
data (object): Processed object given the input and custom
postprocessing Pipeline.
"""
if self.postprocessing is not None:
batch = self.postprocessing(batch)
return batch
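As the docstring above states, a RawField applies its preprocessing pipeline once per example and its postprocessing pipeline once per batch, and otherwise passes data through untouched. A minimal sketch of that contract; plain callables work in place of Pipeline objects here, since preprocess and process simply call them, and the values are illustrative:

import torchtext.data as data

id_field = data.RawField(preprocessing=lambda x: x.strip(),
                         postprocessing=lambda batch: sorted(batch))

id_field.preprocess("  doc-42 ")        # -> "doc-42"
id_field.process(["doc-7", "doc-3"])    # -> ["doc-3", "doc-7"]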


class Field(RawField):
"""Defines a datatype together with instructions for converting to Tensor.
Field class models common text processing datatypes that can be represented
by tensors. It holds a Vocab object that defines the set of possible values
for elements of the field and their corresponding numerical representations.
The Field object also holds other parameters relating to how a datatype
should be numericalized, such as a tokenization method and the kind of
Tensor that should be produced.
@@ -46,7 +90,9 @@ class Field(object):
Default: None.
postprocessing: A Pipeline that will be applied to examples using
this field after numericalizing but before the numbers are turned
into a Tensor. Default: None.
into a Tensor. The pipeline function takes the batch as a list,
the field's Vocab, and train (a bool).
Default: None.
lower: Whether to lowercase the text in this field. Default: False.
tokenize: The function used to tokenize strings using this field into
sequential examples. If "spacy", the SpaCy English tokenizer is
@@ -123,6 +169,21 @@ def preprocess(self, x):
else:
return x

def process(self, batch, device, train):
""" Process a list of examples to create a torch.Tensor.
Pad, numericalize, and postprocess a batch and create a tensor.
Args:
batch (list(object)): A list of objects from a batch of examples.
Returns:
data (torch.autograd.Variable): Processed object given the input
and custom postprocessing Pipeline.
"""
padded = self.pad(batch)
tensor = self.numericalize(padded, device=device, train=train)
return tensor

def pad(self, minibatch):
"""Pad a batch of examples using this field.
@@ -232,7 +293,7 @@ def numericalize(self, arr, device=None, train=True):
if not self.sequential:
arr = [numericalization_func(x) for x in arr]
if self.postprocessing is not None:
arr = self.postprocessing(arr, train)
arr = self.postprocessing(arr, None, train)

arr = self.tensor_type(arr)
if self.sequential and not self.batch_first:
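Taken together, Field.process is exactly the pad-then-numericalize sequence that Batch.__init__ used to inline, so existing Fields batch the same way as before. A small sketch of that equivalence, using an equal-length batch (as in the test above) so that padding is a no-op:

import torchtext.data as data

field = data.Field(sequential=True, use_vocab=False, batch_first=True)
batch = [[1, 2, 3], [2, 3, 4]]

via_process = field.process(batch, device=-1, train=False)
via_steps = field.numericalize(field.pad(batch), device=-1, train=False)

# Both paths should produce the same Variable wrapping a 2x3 LongTensor.
assert via_process.data.equal(via_steps.data)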
