[WIP] add new APIs to build dataset (#556)
zhangguanheng66 authored and cpuhrsch committed Jul 23, 2019
1 parent 284a516 commit 1ebee35
Showing 4 changed files with 111 additions and 3 deletions.
18 changes: 16 additions & 2 deletions torchtext/data/dataset.py
@@ -323,8 +323,22 @@ def stratify(examples, strata_field):


def rationed_split(examples, train_ratio, test_ratio, val_ratio, rnd):
# Create a random permutation of examples, then split them
# by ratio x length slices for each of the train/test/dev? splits
"""Create a random permutation of examples, then split them by ratios
Arguments:
examples: a list of data
train_ratio, test_ratio, val_ratio: split fractions.
rnd: a random shuffler
Examples:
>>> examples = []
>>> train_ratio, test_ratio, val_ratio = 0.7, 0.2, 0.1
>>> rnd = torchtext.data.dataset.RandomShuffler(None)
>>> train_examples, test_examples, valid_examples = \
torchtext.data.dataset.rationed_split(examples, train_ratio,
test_ratio, val_ratio,
rnd)
"""
N = len(examples)
randperm = rnd(range(N))
train_len = int(round(train_ratio * N))
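The remainder of rationed_split is collapsed in the diff above. A plausible sketch of how it continues, based on the visible prefix (the names test_len, val_len, and indices are assumptions, not the committed code):

    test_len = int(round(test_ratio * N))
    val_len = N - train_len - test_len
    # Slice the permuted indices into contiguous train/test/validation blocks.
    indices = (randperm[:train_len],
               randperm[train_len:train_len + test_len],
               randperm[train_len + test_len:])
    return tuple([examples[i] for i in index] for index in indices)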
23 changes: 23 additions & 0 deletions torchtext/data/utils.py
@@ -114,6 +114,29 @@ def dtype_to_attr(dtype):
return dtype


def generate_ngrams(token_list, ngrams):
"""Generate a list of token up to ngrams.
Arguments:
token_list: A list of tokens
ngrams: the number of ngrams.
Examples:
>>> token_list = ['here', 'we', 'are']
>>> torchtext.data.utils.generate_ngrams(token_list, 2)
>>> ['here', 'here we', 'we', 'we are', 'are']
"""

    re_list = []
    for i in range(len(token_list)):
        # Start with the unigram at position i, then extend it one token at a
        # time to produce every n-gram (up to length `ngrams`) starting there.
        x = token_list[i]
        re_list.append(x)
        for j in range(i + 1, min(i + ngrams, len(token_list))):
            x += ' ' + token_list[j]
            re_list.append(x)
    return re_list
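A quick usage sketch showing the ordering with ngrams=3: each unigram is immediately followed by the longer n-grams that start at the same position.

    from torchtext.data.utils import generate_ngrams

    tokens = ['new', 'york', 'city']
    print(generate_ngrams(tokens, 3))
    # ['new', 'new york', 'new york city', 'york', 'york city', 'city']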


class RandomShuffler(object):
"""Use random functions while keeping track of the random state to make it
reproducible and deterministic."""
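The class body is collapsed above. A brief usage sketch, assuming the instance is callable as rationed_split's rnd(range(N)) suggests:

    shuffler = RandomShuffler()      # optionally seeded via RandomShuffler(random_state)
    order = shuffler(range(10))      # a permuted list of the indices 0..9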
43 changes: 42 additions & 1 deletion torchtext/utils.py
@@ -2,6 +2,8 @@
import requests
import csv
from tqdm import tqdm
import os
import tarfile


def reporthook(t):
@@ -25,7 +27,18 @@ def inner(b=1, bsize=1, tsize=None):


def download_from_url(url, path):
"""Download file, with logic (from tensor2tensor) for Google Drive"""
"""Download file, with logic (from tensor2tensor) for Google Drive
Arguments:
url: the url for online Dataset
path: directory and filename for the downloaded dataset.
Examples:
>>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
>>> path = './validation.tar.gz'
>>> torchtext.utils.download_from_url(url, path)
"""

def process_response(r):
chunk_size = 16 * 1024
total_size = int(r.headers.get('Content-length', 0))
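The rest of the function body is collapsed above. For orientation, a minimal sketch of the chunked streaming loop a helper like process_response typically drives; the name _stream_to_file and its exact structure are assumptions, not the committed code:

    import requests
    from tqdm import tqdm

    def _stream_to_file(url, path, chunk_size=16 * 1024):
        # Stream the response in fixed-size chunks so large files are never
        # held fully in memory; tqdm reports download progress in bytes.
        with requests.get(url, stream=True) as r:
            total_size = int(r.headers.get('Content-length', 0))
            with open(path, 'wb') as f, tqdm(total=total_size, unit='B', unit_scale=True) as t:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
                        t.update(len(chunk))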
@@ -75,3 +88,31 @@ def unicode_csv_reader(unicode_csv_data, **kwargs):
def utf_8_encoder(unicode_csv_data):
for line in unicode_csv_data:
yield line.encode('utf-8')


def extract_archive(from_path, to_path=None, remove_finished=False):
"""Extract tar.gz archives.
Arguments:
from_path: the path where the tar.gz file is.
to_path: the path where the extracted files are.
remove_finished: remove the original tar.gz file. Default: False
Examples:
>>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
>>> from_path = './validation.tar.gz'
>>> to_path = './'
>>> torchtext.utils.download_from_url(url, from_path)
>>> torchtext.utils.extract_archive(from_path, to_path)
"""
if to_path is None:
to_path = os.path.dirname(from_path)

if from_path.endswith(".tar.gz"):
with tarfile.open(from_path, 'r:gz') as tar:
tar.extractall(path=to_path)
else:
raise ValueError("Extraction of {} not supported".format(from_path))

if remove_finished:
os.remove(from_path)
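Together with download_from_url, this gives a small end-to-end flow. A usage sketch based on the docstring examples; remove_finished=True additionally deletes the archive once extraction succeeds:

    from torchtext.utils import download_from_url, extract_archive

    url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
    from_path = './validation.tar.gz'
    download_from_url(url, from_path)
    extract_archive(from_path, to_path='./', remove_finished=True)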
30 changes: 30 additions & 0 deletions torchtext/vocab.py
@@ -13,6 +13,8 @@
import tarfile

from .utils import reporthook
from collections import Counter, OrderedDict
from itertools import chain

logger = logging.getLogger(__name__)

@@ -216,6 +218,34 @@ def set_vectors(self, stoi, vectors, dim, unk_init=torch.Tensor.zero_):
self.vectors[i] = unk_init(self.vectors[i])


def build_dictionary(dataset, field, data_name, **kwargs):
"""Construct the Vocab object for the field from a dataset.
Arguments:
dataset: Dataset with the iterable data.
field: Field object with the information of the special tokens.
        data_name: the name of the attribute used to build the vocab
            (e.g. 'text', 'label'). It must be an attribute of the
            dataset's examples.
        Remaining keyword arguments: Passed to the constructor of Vocab.
    Examples:
        >>> field.vocab = build_dictionary(dataset, field, 'text')
"""
counter = Counter()
for x in dataset:
x = getattr(x, data_name)
if not field.sequential:
x = [x]
        try:
            counter.update(x)
        except TypeError:
            # x is a nested sequence (e.g. a list of token lists), so flatten
            # it before counting; updating directly fails on unhashable lists.
            counter.update(chain.from_iterable(x))
specials = list(OrderedDict.fromkeys(
tok for tok in [field.unk_token, field.pad_token, field.init_token,
field.eos_token] if tok is not None))
return Vocab(counter, specials=specials, **kwargs)
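A hypothetical end-to-end sketch of building vocabularies with the new helper; the CSV file, field names, and the max_size keyword (forwarded to Vocab) are illustrative assumptions:

    from torchtext import data
    from torchtext.vocab import build_dictionary

    TEXT = data.Field(sequential=True, lower=True)
    LABEL = data.Field(sequential=False)
    train = data.TabularDataset(path='train.csv', format='csv',
                                fields=[('text', TEXT), ('label', LABEL)])
    TEXT.vocab = build_dictionary(train, TEXT, 'text', max_size=25000)
    LABEL.vocab = build_dictionary(train, LABEL, 'label')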


class SubwordVocab(Vocab):

def __init__(self, counter, max_size=None, specials=['<pad>'],
