Use Vocab for supervised learning datasets #567

Merged: 69 commits, merged Jul 26, 2019
Commits (showing changes from all 69 commits)
c30ec38
add new APIs to build dataset.
Jul 9, 2019
ffd1f49
Add new datasets for text classification.
Jul 11, 2019
3b7b0e2
Add docs and examples.
Jul 9, 2019
5a31dd3
Split text_normalize out of preprocess function.
Jul 11, 2019
5efa58e
Add docs and test case.
Jul 11, 2019
844242a
Update README file.
Jul 12, 2019
b373de9
revise generate_iters() function.
Jul 22, 2019
6d5cb03
Remove TextDataset class.
Jul 22, 2019
3f0c523
Remove generate_iterators() API
Jul 22, 2019
2f20914
remove unnecessary library loading
Jul 22, 2019
57d0d03
Re-name build_vocab to build_dictionary
Jul 22, 2019
4cf4099
change build_vocab to build_dictionary.
Jul 22, 2019
c8ec403
convert two functions to the internals.
Jul 22, 2019
0568a04
Change the API of _load_text_classification_data() function.
Jul 22, 2019
78673a5
use a static list for url.
Jul 22, 2019
58e3bac
use logging.info as print.
Jul 22, 2019
81e5a31
combine download and extract_archive
Jul 22, 2019
e05d7fe
Merge branch 'master' into new_pattern
cpuhrsch Jul 23, 2019
7ffb267
Merge branch 'new_supervised_learning_dataset' into new_pattern
Jul 23, 2019
e138fa8
examples
Jul 23, 2019
1e9f0e1
remove more
Jul 23, 2019
c746d86
less
Jul 23, 2019
fea3bad
split
Jul 24, 2019
5c90fbc
ordered dict
Jul 24, 2019
ba23ae1
Merge remote-tracking branch 'upstream/master' into tutorial
Jul 24, 2019
3df4dc1
rename
Jul 24, 2019
ea639c2
Simplifications
Jul 24, 2019
193a670
clean more
Jul 24, 2019
285a515
more efficient dictionary building
Jul 24, 2019
fc1fcc1
Merge branch 'master' into tutorial
Jul 24, 2019
3e27dcd
Reduce code
Jul 24, 2019
4678478
tar and extraction
Jul 24, 2019
2a18586
Merge branch 'additionalstuff' into tutorial
Jul 24, 2019
ee9894f
rebase
Jul 24, 2019
197c70d
remove legacy
Jul 24, 2019
bc2369f
more logging and args
Jul 25, 2019
0e81889
more
Jul 25, 2019
75fd515
small changes
Jul 25, 2019
e7ea6c2
More small changes
Jul 25, 2019
accf587
Update docs
Jul 25, 2019
5506a2e
bring back examples
Jul 25, 2019
2c8c4bf
bring back examples
Jul 25, 2019
28b0976
small fix
Jul 25, 2019
93f2f18
class to function
Jul 25, 2019
2809221
Use DataLoader
Jul 25, 2019
ae540b0
format
Jul 25, 2019
5520b51
Small test fix
Jul 25, 2019
9013bbd
Merge branch 'additionalstuff' into splitmore
Jul 25, 2019
2313ffb
More logging and nits
Jul 25, 2019
ff9132d
Use io.open
Jul 25, 2019
3a0db45
Merge branch 'additionalstuff' into splitmore
Jul 25, 2019
eab2708
build vocab from iterator
Jul 25, 2019
331bf79
flake8
Jul 25, 2019
90f8bbb
Merge branch 'additionalstuff' into splitmore
Jul 25, 2019
2460eaf
flake8
Jul 25, 2019
c6dc406
Merge branch 'splitmore' into vocab
Jul 25, 2019
b0edd62
flake8
Jul 25, 2019
8d7f014
more iterators
Jul 25, 2019
221b064
remove one include
Jul 25, 2019
73772e2
flake8
Jul 25, 2019
6dcc052
Merge branch 'additionalstuff' into splitmore
Jul 25, 2019
b4e5067
flake8
Jul 25, 2019
63b8e2e
Merge branch 'master' into splitmore
Jul 26, 2019
bc0c3be
Merge branch 'splitmore' into vocab
Jul 26, 2019
c1dc6ae
Merge remote-tracking branch 'upstream/master' into splitmore
Jul 26, 2019
7256ffc
Deal with tests
Jul 26, 2019
94fd475
Merge branch 'splitmore' into vocab
Jul 26, 2019
7af6a2d
Formatting
Jul 26, 2019
ed5a99e
Merge remote-tracking branch 'upstream/master' into vocab
Jul 26, 2019
examples/text_classification/model.py (3 changes: 1 addition & 2 deletions)

@@ -15,5 +15,4 @@ def init_weights(self):
         self.fc.bias.data.zero_()
 
     def forward(self, text, offsets):
-        embedded = self.embedding(text, offsets)
-        return self.fc(embedded)
+        return self.fc(self.embedding(text, offsets))
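For reference, the change above only folds the temporary `embedded` variable into a single expression; behaviour is identical. Below is a minimal sketch of the surrounding TextSentiment module, assuming it pairs an nn.EmbeddingBag with a single linear layer; the constructor signature and init ranges are assumptions inferred from the calls in train.py, not the exact file contents.

import torch.nn as nn


class TextSentiment(nn.Module):
    """Sketch: bag-of-embeddings classifier (EmbeddingBag + Linear)."""

    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextSentiment, self).__init__()
        # EmbeddingBag pools the embeddings of every token in an example,
        # which is why forward() takes a flat token tensor plus per-example offsets.
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        return self.fc(self.embedding(text, offsets))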
examples/text_classification/train.py (72 changes: 38 additions & 34 deletions)

@@ -1,42 +1,38 @@
 import os
 import logging
-import random
 import argparse
 
 import torch
 
-from torchtext.datasets.text_classification import AG_NEWS
+from torchtext.datasets import text_classification
+from torch.utils.data import DataLoader
 
 from model import TextSentiment
 
 
-def generate_offsets(data_batch):
-    offsets = [0]
-    for entry in data_batch:
-        offsets.append(offsets[-1] + len(entry))
-    offsets = torch.tensor(offsets[:-1])
-    return offsets
+def generate_batch(batch):
 
+    def generate_offsets(data_batch):
+        offsets = [0]
+        for entry in data_batch:
+            offsets.append(offsets[-1] + len(entry))
+        offsets = torch.tensor(offsets[:-1])
+        return offsets
 
-def generate_batch(data, labels, i, batch_size):
-    data_batch = data[i:i + batch_size]
-    text = torch.cat(data_batch)
-    offsets = generate_offsets(data_batch)
-    cls = torch.tensor(labels[i:i + batch_size])
-    text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
+    cls = torch.tensor([entry[0] for entry in batch])
+    text = [entry[1] for entry in batch]
+    offsets = generate_offsets(text)
+    text = torch.cat(text)
     return text, offsets, cls
 
 
-def train(lr_, num_epoch, data, labels):
+def train(lr_, num_epoch, data_):
+    data = DataLoader(data_, batch_size=batch_size, shuffle=True,
+                      collate_fn=generate_batch, num_workers=args.num_workers)
     num_lines = num_epochs * len(data)
     for epoch in range(num_epochs):
-        perm = list(range(len(data)))
-        random.shuffle(perm)
-        data = [data[i] for i in perm]
-        labels = [labels[i] for i in perm]
-
-        for i in range(0, len(data), batch_size):
-            text, offsets, cls = generate_batch(data, labels, i, batch_size)
+        for i, (text, offsets, cls) in enumerate(data):
+            text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
             output = model(text, offsets)
             loss = criterion(output, cls)
             loss.backward()
@@ -50,11 +46,12 @@ def train(lr_, num_epoch, data, labels):
         print("")
 
 
-def test(data, labels):
+def test(data_):
+    data = DataLoader(data_, batch_size=batch_size, collate_fn=generate_batch)
     total_accuracy = []
-    for i in range(0, len(data), batch_size):
+    for text, offsets, cls in data:
+        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
         with torch.no_grad():
-            text, offsets, cls = generate_batch(data, labels, i, batch_size)
             output = model(text, offsets)
             accuracy = (output.argmax(1) == cls).float().mean().item()
             total_accuracy.append(accuracy)
@@ -64,15 +61,18 @@ def test(data, labels):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description='Train a text classification model on AG_NEWS')
+    parser.add_argument('dataset', choices=text_classification.DATASETS)
     parser.add_argument('--num-epochs', type=int, default=3)
     parser.add_argument('--embed-dim', type=int, default=128)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--lr', type=float, default=64.0)
     parser.add_argument('--ngrams', type=int, default=2)
+    parser.add_argument('--num-workers', type=int, default=1)
    parser.add_argument('--device', default='cpu')
     parser.add_argument('--data', default='.data')
     parser.add_argument('--save-model-path')
-    parser.add_argument('--save-dictionary-path')
+    parser.add_argument('--save-vocab-path')
+    parser.add_argument('--load-vocab-path')
     parser.add_argument('--logging-level', default='WARNING')
     args = parser.parse_args()
 
@@ -82,24 +82,28 @@ def test(data, labels):
     lr = args.lr
     device = args.device
     data = args.data
+    vocab = args.load_vocab_path
 
     logging.basicConfig(level=getattr(logging, args.logging_level))
 
     if not os.path.exists(data):
         print("Creating directory {}".format(data))
         os.mkdir(data)
 
-    dataset = AG_NEWS(root=data, ngrams=args.ngrams)
-    model = TextSentiment(len(dataset.dictionary), embed_dim,
-                          len(set(dataset.labels))).to(device)
+    train_dataset, test_dataset = text_classification.DATASETS[args.dataset](
+        root=data, ngrams=args.ngrams, vocab=vocab)
+
+    if args.save_vocab_path:
+        print("Saving vocab to {}".format(args.save_vocab_path))
+        torch.save(train_dataset.get_vocab, args.save_vocab_path)
+
+    model = TextSentiment(len(train_dataset.get_vocab()),
+                          embed_dim, len(train_dataset.get_labels())).to(device)
     criterion = torch.nn.CrossEntropyLoss().to(device)
 
-    train(lr, num_epochs, dataset.train_data, dataset.train_labels)
-    test(dataset.test_data, dataset.test_labels)
+    train(lr, num_epochs, train_dataset)
+    test(test_dataset)
 
     if args.save_model_path:
         print("Saving model to {}".format(args.save_model_path))
         torch.save(model.to('cpu'), args.save_model_path)
-    if args.save_dictionary_path:
-        print("Saving dictionary to {}".format(args.save_dictionary_path))
-        torch.save(dataset.dictionary, args.save_dictionary_path)
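The key change in train.py is that batching now goes through torch.utils.data.DataLoader with generate_batch as the collate_fn, instead of manual slicing and shuffling. A self-contained sketch of that mechanism follows, using made-up token ids and inlining generate_offsets for brevity.

import torch
from torch.utils.data import DataLoader

# Each dataset entry is assumed to be a (label, token_id_tensor) pair, which is
# what generate_batch() above indexes with entry[0] and entry[1].
# The token ids below are made up purely for illustration.
toy_data = [
    (0, torch.tensor([4, 7, 2])),     # 3 tokens
    (1, torch.tensor([9, 1])),        # 2 tokens
    (0, torch.tensor([5, 5, 5, 5])),  # 4 tokens
]


def generate_batch(batch):
    # Flatten all token ids into one tensor and record where each example
    # starts, so an nn.EmbeddingBag can pool per example (offsets inlined here).
    cls = torch.tensor([entry[0] for entry in batch])
    text = [entry[1] for entry in batch]
    offsets = [0]
    for entry in text:
        offsets.append(offsets[-1] + len(entry))
    return torch.cat(text), torch.tensor(offsets[:-1]), cls


loader = DataLoader(toy_data, batch_size=3, collate_fn=generate_batch)
text, offsets, cls = next(iter(loader))
print(text.tolist())     # [4, 7, 2, 9, 1, 5, 5, 5, 5]
print(offsets.tolist())  # [0, 3, 5] -- start index of each example
print(cls.tolist())      # [0, 1, 0]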
test/test_vocab.py (5 changes: 5 additions & 0 deletions)

@@ -356,3 +356,8 @@ def test_vectors_get_vecs(self):
         for dim in ["50", "100", "200", "300"]:
             conditional_remove(os.path.join(self.project_root, ".vector_cache",
                                             "glove.6B.{}d.txt".format(dim)))
+
+    def test_has_unk(self):
+        c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2})
+        v = vocab.Vocab(c)
+        self.assertEqual(v['not_in_it'], 0)
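The new test leans on the fact that Vocab maps any out-of-vocabulary token to the '<unk>' entry at index 0. A small sketch of that behaviour, assuming the default specials ('<unk>', '<pad>') and using min_freq to drop rare tokens:

from collections import Counter

from torchtext import vocab

# Tokens occurring fewer than min_freq times are left out of the vocabulary.
c = Counter({'hello': 4, 'world': 3, 'freq_too_low': 2})
v = vocab.Vocab(c, min_freq=3)

print(v['hello'])         # an index >= 2, after the '<unk>' and '<pad>' specials
print(v['not_in_it'])     # 0 -- unknown tokens fall back to '<unk>'
print(v['freq_too_low'])  # 0 as well, since it fell below min_freq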
torchtext/data/utils.py (12 changes: 5 additions & 7 deletions)

@@ -114,27 +114,25 @@ def dtype_to_attr(dtype):
     return dtype
 
 
-def generate_ngrams(token_list, ngrams):
-    """Generate a list of token up to ngrams.
+def ngrams_iterator(token_list, ngrams):
+    """Return an iterator that yields the given tokens and their ngrams.
 
     Arguments:
         token_list: A list of tokens
        ngrams: the number of ngrams.
 
     Examples:
         >>> token_list = ['here', 'we', 'are']
-        >>> torchtext.data.utils.generate_ngrams(token_list, 2)
+        >>> list(ngrams_iterator(token_list, 2))
         >>> ['here', 'here we', 'we', 'we are', 'are']
     """
 
-    re_list = []
     for i in range(0, len(token_list)):
         x = token_list[i]
-        re_list.append(x)
+        yield x
         for j in range(i + 1, min(i + ngrams, len(token_list))):
             x += ' ' + token_list[j]
-            re_list.append(x)
-    return re_list
+            yield x
 
 
 class RandomShuffler(object):
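Since ngrams_iterator now yields lazily instead of returning a list, a typical caller streams it straight into a Counter when building a vocabulary. A short usage sketch:

from collections import Counter

from torchtext.data.utils import ngrams_iterator

tokens = ['here', 'we', 'are']

# Yields each unigram followed by the n-grams it starts, without building
# the whole list in memory.
print(list(ngrams_iterator(tokens, 2)))
# ['here', 'here we', 'we', 'we are', 'are']

# Typical use: stream the tokens-plus-ngrams into a Counter for a Vocab.
counter = Counter(ngrams_iterator(tokens, 2))
print(counter['we are'])  # 1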