Skip to content

Commit

Permalink
Exposed csv reader parameters to the user through the dataset (#432)
Browse files Browse the repository at this point in the history
* Exposed csv reader parameters to the user through the dataset

* Added tests
  • Loading branch information
keitakurita authored and mttk committed Sep 29, 2018
1 parent 558fff6 commit 499e327
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 3 deletions.
36 changes: 36 additions & 0 deletions test/data/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
import torchtext.data as data
import tempfile
import six

import pytest

Expand Down Expand Up @@ -206,6 +208,40 @@ def test_csv_file_no_header_one_col_multiple_fields(self):
# 6 Fields including None for ids
assert len(dataset.fields) == 6

def test_csv_dataset_quotechar(self):
    # Regression test for issue #349: passing quotechar=None through
    # csv_reader_params must disable quote handling so that literal
    # '"' characters in a field survive parsing intact.
    rows = [("text", "label"),
            ('" hello world', "0"),
            ('goodbye " world', "1"),
            ('this is a pen " ', "0")]

    with tempfile.NamedTemporaryFile(dir=self.test_dir) as f:
        # Write the header plus data rows as a plain comma-separated file.
        payload = "".join("{}\n".format(",".join(row)) for row in rows)
        f.write(six.b(payload))

        text_field = data.Field(lower=True, tokenize=lambda s: s.split())
        fields = {
            "label": ("label", data.Field(use_vocab=False,
                                          sequential=False)),
            "text": ("text", text_field)
        }

        # Rewind (and implicitly flush) so the reader sees the data.
        f.seek(0)

        dataset = data.TabularDataset(
            path=f.name, format="csv",
            skip_header=False, fields=fields,
            csv_reader_params={"quotechar": None})

        text_field.build_vocab(dataset)

        # The header row is consumed by the fields dict, not the data.
        self.assertEqual(len(dataset), len(rows) - 1)

        # Each example must round-trip: lowercased, whitespace-tokenized
        # text with quotes preserved, and the label passed through as-is.
        for idx, example in enumerate(dataset, start=1):
            expected_text, expected_label = rows[idx]
            self.assertEqual(example.text,
                             expected_text.lower().split())
            self.assertEqual(example.label, expected_label)

def test_dataset_split_arguments(self):
num_examples, num_labels = 30, 3
self.write_test_splitting_dataset(num_examples=num_examples,
Expand Down
12 changes: 9 additions & 3 deletions torchtext/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,8 @@ def filter_examples(self, field_names):
class TabularDataset(Dataset):
"""Defines a Dataset of columns stored in CSV, TSV, or JSON format."""

def __init__(self, path, format, fields, skip_header=False, **kwargs):
def __init__(self, path, format, fields, skip_header=False,
csv_reader_params={}, **kwargs):
"""Create a TabularDataset given a path, file format, and field list.
Arguments:
Expand All @@ -236,6 +237,11 @@ def __init__(self, path, format, fields, skip_header=False, **kwargs):
This allows the user to rename columns from their JSON/CSV/TSV key names
and also enables selecting a subset of columns to load.
skip_header (bool): Whether to skip the first line of the input file.
csv_reader_params(dict): Parameters to pass to the csv reader.
Only relevant when format is csv or tsv.
See
https://docs.python.org/3/library/csv.html#csv.reader
for more details.
"""
format = format.lower()
make_example = {
Expand All @@ -244,9 +250,9 @@ def __init__(self, path, format, fields, skip_header=False, **kwargs):

with io.open(os.path.expanduser(path), encoding="utf8") as f:
if format == 'csv':
reader = unicode_csv_reader(f)
reader = unicode_csv_reader(f, **csv_reader_params)
elif format == 'tsv':
reader = unicode_csv_reader(f, delimiter='\t')
reader = unicode_csv_reader(f, delimiter='\t', **csv_reader_params)
else:
reader = f

Expand Down

0 comments on commit 499e327

Please sign in to comment.