Skip to content

Commit

Permalink
Redesigned SST2Dataset to subclass IterableDataset (#1421)
Browse files Browse the repository at this point in the history
* Updated SST2Dataset to subclass IterableDataset. Updated SST2 functional call to return SST2Dataset object

* Updated get_datapipe to be private, passed class parameters directly into get_datapipe function

Co-authored-by: nayef211 <n63ahmed@edu.uwaterloo.ca>
  • Loading branch information
Nayef211 and nayef211 committed Oct 22, 2021
1 parent 0153ead commit bcc1455
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 14 deletions.
12 changes: 8 additions & 4 deletions test/experimental/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,28 @@ class TestDataset(TorchtextTestCase):
@skipIfNoModule("torchdata")
def test_sst2_dataset(self):
    """Smoke-test the experimental SST2 dataset.

    Verifies that ``sst2.SST2`` returns one ``SST2Dataset`` per requested
    split, and that the first sample of each split hashes to the known MD5
    recorded in ``sst2._FIRST_LINE_MD5``.
    """
    split = ("train", "dev", "test")
    train_dataset, dev_dataset, test_dataset = sst2.SST2(split=split)

    # verify dataset objects are instances of SST2Dataset
    for dataset in (train_dataset, dev_dataset, test_dataset):
        self.assertTrue(isinstance(dataset, sst2.SST2Dataset))

    # verify hashes of the first line in each split against known values;
    # a single loop replaces three copy-pasted assertion stanzas
    for dataset, split_name in zip(
        (train_dataset, dev_dataset, test_dataset), split
    ):
        self.assertEqual(
            hashlib.md5(
                json.dumps(next(iter(dataset)), sort_keys=True).encode("utf-8")
            ).hexdigest(),
            sst2._FIRST_LINE_MD5[split_name],
        )
22 changes: 12 additions & 10 deletions torchtext/experimental/datasets/sst2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import os

from torch.utils.data.dataset import IterableDataset
from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
_add_docstring_header,
Expand Down Expand Up @@ -50,10 +51,10 @@
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "dev", "test"))
def SST2(root, split):
    """Return an ``SST2Dataset`` for each requested split.

    Args:
        root: directory where the dataset archive is cached on disk.
        split: one or more of ``"train"``, ``"dev"``, ``"test"``
            (normalized by the ``_wrap_split_argument`` decorator, which
            maps over multiple splits and returns one dataset per split).

    Returns:
        SST2Dataset: an iterable dataset over the requested split.
    """
    # Return the dataset object itself (an IterableDataset) rather than a
    # raw datapipe; iteration is handled by SST2Dataset.__iter__.
    return SST2Dataset(root, split)


class SST2Dataset:
class SST2Dataset(IterableDataset):
"""The SST2 dataset uses torchdata datapipes end-2-end.
To avoid download at every epoch, we cache the data on-disk
We do sanity check on dowloaded and extracted data
Expand All @@ -67,26 +68,27 @@ def __init__(self, root, split):
"how to install the package."
)

self.root = root
self.split = split
self._dp = self._get_datapipe(root, split)

def get_datapipe(self):
def __iter__(self):
    """Yield samples from the underlying datapipe, one at a time."""
    # Delegate iteration directly to the datapipe built in __init__.
    yield from self._dp

def _get_datapipe(self, root, split):
# cache data on-disk
cache_dp = IterableWrapper([URL]).on_disk_cache(
HttpReader,
op_map=lambda x: (x[0], x[1].read()),
filepath_fn=lambda x: os.path.join(self.root, os.path.basename(x)),
filepath_fn=lambda x: os.path.join(root, os.path.basename(x)),
)

# do sanity check
check_cache_dp = cache_dp.check_hash(
{os.path.join(self.root, "SST-2.zip"): MD5}, "md5"
{os.path.join(root, "SST-2.zip"): MD5}, "md5"
)

# extract data from zip
extracted_files = check_cache_dp.read_from_zip().filter(
lambda x: self.split in x[0]
)
extracted_files = check_cache_dp.read_from_zip().filter(lambda x: split in x[0])

# Parse CSV file and yield data samples
return extracted_files.parse_csv(skip_lines=1, delimiter="\t").map(
Expand Down

0 comments on commit bcc1455

Please sign in to comment.