Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pipeline bugfixes, docstrings, and tests #113

Merged
merged 9 commits into from
Sep 11, 2017
54 changes: 54 additions & 0 deletions test/data/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
import six
import torchtext.data as data

from ..common.torchtext_test_case import TorchtextTestCase


class TestPipeline(TorchtextTestCase):
    """Tests for ``data.Pipeline``: application, composition, validation."""

    @staticmethod
    def repeat_n(x, n=3):
        """
        Return the sequence ``x`` concatenated with itself ``n`` times.
        """
        return x * n

    def test_pipeline(self):
        """Check identity, single-function and extra-argument pipelines."""
        # A pipeline built without a function must hand inputs back as-is.
        id_pipeline = data.Pipeline()
        for sample in ("Test STring",
                       "ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T",
                       ["1241", "Some String"]):
            assert id_pipeline(sample) == sample

        # A pipeline wrapping one function applies it to a string, or to
        # each element when given a list.
        pipeline = data.Pipeline(six.text_type.lower)
        lowercase_cases = [
            ("Test STring", "test string"),
            ("ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T", "ᑌᑎiᑕoᗪᕮ_tᕮ᙭t"),
            (["1241", "Some String"], ["1241", "some string"]),
        ]
        for raw, expected in lowercase_cases:
            assert pipeline(raw) == expected

        # Extra positional arguments are forwarded to the wrapped function.
        args_pipeline = data.Pipeline(TestPipeline.repeat_n)
        repeat_cases = [
            ("test", 5, "testtesttesttesttest"),
            (["ele1", "ele2"], 2, ["ele1ele1", "ele2ele2"]),
        ]
        for raw, count, expected in repeat_cases:
            assert args_pipeline(raw, count) == expected

    def test_composition(self):
        """Check results after chaining pipelines with add_before/add_after."""
        id_pipeline = data.Pipeline()
        # add_before/add_after return the pipeline itself, so the four
        # composition steps can be chained in their original order.
        pipeline = (
            data.Pipeline(TestPipeline.repeat_n)
            .add_before(id_pipeline)
            .add_after(id_pipeline)
            .add_before(six.text_type.lower)
            .add_after(six.text_type.capitalize)
        )

        other_pipeline = data.Pipeline(six.text_type.swapcase)
        other_pipeline.add_before(pipeline)

        # The composed pipeline must still give its own results after being
        # added to another one (i.e. its pipes member was not modified).
        expected_own = [
            ("teST", "Testtesttest"),
            (["ElE1", "eLe2"], ["Ele1ele1ele1", "Ele2ele2ele2"]),
        ]
        for raw, expected in expected_own:
            assert pipeline(raw) == expected

        # The pipeline that was extended must run the added steps first.
        expected_other = [
            ("teST", "tESTTESTTEST"),
            (["ElE1", "eLe2"], ["eLE1ELE1ELE1", "eLE2ELE2ELE2"]),
        ]
        for raw, expected in expected_other:
            assert other_pipeline(raw) == expected

    def test_exceptions(self):
        """A non-callable, non-None argument must raise ValueError."""
        self.assertRaises(ValueError, data.Pipeline, "Not Callable")
68 changes: 61 additions & 7 deletions torchtext/data/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,85 @@
class Pipeline(object):
    """Defines a pipeline for transforming sequence data.

    The input is assumed to be utf-8 encoded `str` (Python 3) or
    `unicode` (Python 2).

    Attributes:
        convert_token: The function to apply to input sequence data.
        pipes: The Pipelines that will be applied to input sequence
            data in order.
    """

    def __init__(self, convert_token=None):
        """Create a pipeline.

        Arguments:
            convert_token: The function to apply to input sequence data.
                If None, the identity function is used. Default: None

        Raises:
            ValueError: If convert_token is neither None nor callable.
        """
        if convert_token is None:
            # Use the named static method instead of a lambda; it exists
            # for serialization compatibility with pickle.
            self.convert_token = Pipeline.identity
        elif callable(convert_token):
            self.convert_token = convert_token
        else:
            raise ValueError("Pipeline input convert_token {} is not None "
                             "or callable".format(convert_token))
        # A fresh pipeline consists of itself only; add_before/add_after
        # extend this list.
        self.pipes = [self]

    def __call__(self, x, *args):
        """Apply the current Pipeline(s) to an input.

        Arguments:
            x: The input to process with the Pipeline(s).
            Positional arguments: Forwarded to the `call` function
                of the Pipeline(s).

        Returns:
            The result of passing x through every Pipeline in `pipes`,
            in order.
        """
        for pipe in self.pipes:
            # Use `call`, not `pipe(...)`: each entry contributes only its
            # own convert_token, and `self` is itself a member of `pipes`.
            x = pipe.call(x, *args)
        return x

    def call(self, x, *args):
        """Apply _only_ the convert_token function of the current pipeline
        to the input. If the input is a list, a list with the results of
        applying the `convert_token` function to all input elements is
        returned.

        Arguments:
            x: The input to apply the convert_token function to.
            Positional arguments: Forwarded to the `convert_token` function
                of the current Pipeline.
        """
        if isinstance(x, list):
            return [self.convert_token(tok, *args) for tok in x]
        return self.convert_token(x, *args)

    def add_before(self, pipeline):
        """Add a Pipeline to be applied before this processing pipeline.

        Arguments:
            pipeline: The Pipeline or callable to apply before this
                Pipeline.

        Returns:
            This Pipeline, so that calls can be chained.
        """
        if not isinstance(pipeline, Pipeline):
            pipeline = Pipeline(pipeline)
        # Concatenation builds a new list, so the `pipes` list of the
        # pipeline being added is left untouched.
        self.pipes = pipeline.pipes + self.pipes
        return self

    def add_after(self, pipeline):
        """Add a Pipeline to be applied after this processing pipeline.

        Arguments:
            pipeline: The Pipeline or callable to apply after this
                Pipeline.

        Returns:
            This Pipeline, so that calls can be chained.
        """
        if not isinstance(pipeline, Pipeline):
            pipeline = Pipeline(pipeline)
        # Concatenation builds a new list, so the `pipes` list of the
        # pipeline being added is left untouched.
        self.pipes = self.pipes + pipeline.pipes
        return self

    @staticmethod
    def identity(x):
        """Return the input unchanged (the same object, not a copy).

        This is here for serialization compatibility with pickle.
        """
        return x