updated text sentiment tutorial #1563

Merged
merged 1 commit into from Jun 14, 2021
108 changes: 17 additions & 91 deletions beginner_source/text_sentiment_ngrams_tutorial.py
@@ -49,32 +49,35 @@
#
# We have revisited the very basic components of the torchtext library, including the vocab, word vectors, and tokenizer. Those are the basic data processing building blocks for raw text strings.
#
# Here is an example of typical NLP data processing with the tokenizer and vocabulary. The first step is to build a vocabulary with the raw training dataset. Users can have a customized vocab by setting up arguments in the constructor of the Vocab class, for example, the minimum frequency ``min_freq`` for the tokens to be included.
# Here is an example of typical NLP data processing with the tokenizer and vocabulary. The first step is to build a vocabulary with the raw training dataset. Here we use the built-in
# factory function ``build_vocab_from_iterator``, which accepts an iterator that yields a list or iterator of tokens. Users can also pass any special symbols to be added to the
# vocabulary.


from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import Vocab
from torchtext.vocab import build_vocab_from_iterator

tokenizer = get_tokenizer('basic_english')
train_iter = AG_NEWS(split='train')
counter = Counter()
for (label, line) in train_iter:
    counter.update(tokenizer(line))
vocab = Vocab(counter, min_freq=1)

def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])
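
######################################################################
# Because ``<unk>`` was registered as the default index, out-of-vocabulary
# tokens map to it instead of raising an error. A quick sanity check
# (a minimal sketch; the exact integer ids depend on the training data):

unk_index = vocab["<unk>"]
ids = vocab(["here", "is", "an", "unseen_token_xyz"])
print(ids[-1] == unk_index)  # True: the unseen token falls back to <unk>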

######################################################################
# The vocabulary block converts a list of tokens into integers.
#
# ::
#
#     [vocab[token] for token in ['here', 'is', 'an', 'example']]
#     >>> [476, 22, 31, 5298]
#     vocab(['here', 'is', 'an', 'example'])
#     >>> [475, 21, 30, 5286]
#
# Prepare the text processing pipeline with the tokenizer and vocabulary. The text and label pipelines will be used to process the raw data strings from the dataset iterators.

text_pipeline = lambda x: [vocab[token] for token in tokenizer(x)]
text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1
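
######################################################################
# The text pipeline converts a text string into a list of integers based on
# the lookup table defined in the vocabulary; the label pipeline converts a
# label string into a zero-based integer. For example (the exact indices
# depend on the built vocabulary):
#
# ::
#
#     text_pipeline('here is the an example')
#     >>> [475, 21, 2, 30, 5286]
#     label_pipeline('10')
#     >>> 9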


@@ -246,6 +249,7 @@ def evaluate(dataloader):


from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
# Hyperparameters
EPOCHS = 10 # epoch
LR = 5 # learning rate
@@ -256,8 +260,8 @@ def evaluate(dataloader):
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None
train_iter, test_iter = AG_NEWS()
train_dataset = list(train_iter)
test_dataset = list(test_iter)
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)
num_train = int(len(train_dataset) * 0.95)
split_train_, split_valid_ = \
    random_split(train_dataset, [num_train, len(train_dataset) - num_train])
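
######################################################################
# A note on ``to_map_style_dataset``: AG_NEWS is an iterable-style dataset,
# while ``random_split`` needs a map-style dataset that supports ``len()``
# and indexing. Conceptually the conversion is close to the sketch below
# (a simplified sketch for illustration, not the actual torchtext
# implementation):

from torch.utils.data import Dataset

class _ToMapStyle(Dataset):
    def __init__(self, iter_data):
        self._data = list(iter_data)  # materialize the stream once

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        return self._data[idx]

# e.g. _ToMapStyle(AG_NEWS(split='train')) would behave like train_dataset above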
@@ -285,72 +289,6 @@ def evaluate(dataloader):
print('-' * 59)


######################################################################
# Running the model on a GPU produces the following printout:
#
# ::
#
# | epoch 1 | 500/ 1782 batches | accuracy 0.684
# | epoch 1 | 1000/ 1782 batches | accuracy 0.852
# | epoch 1 | 1500/ 1782 batches | accuracy 0.877
# -----------------------------------------------------------
# | end of epoch 1 | time: 8.33s | valid accuracy 0.867
# -----------------------------------------------------------
# | epoch 2 | 500/ 1782 batches | accuracy 0.895
# | epoch 2 | 1000/ 1782 batches | accuracy 0.900
# | epoch 2 | 1500/ 1782 batches | accuracy 0.903
# -----------------------------------------------------------
# | end of epoch 2 | time: 8.18s | valid accuracy 0.890
# -----------------------------------------------------------
# | epoch 3 | 500/ 1782 batches | accuracy 0.914
# | epoch 3 | 1000/ 1782 batches | accuracy 0.914
# | epoch 3 | 1500/ 1782 batches | accuracy 0.916
# -----------------------------------------------------------
# | end of epoch 3 | time: 8.20s | valid accuracy 0.897
# -----------------------------------------------------------
# | epoch 4 | 500/ 1782 batches | accuracy 0.926
# | epoch 4 | 1000/ 1782 batches | accuracy 0.924
# | epoch 4 | 1500/ 1782 batches | accuracy 0.921
# -----------------------------------------------------------
# | end of epoch 4 | time: 8.18s | valid accuracy 0.895
# -----------------------------------------------------------
# | epoch 5 | 500/ 1782 batches | accuracy 0.938
# | epoch 5 | 1000/ 1782 batches | accuracy 0.935
# | epoch 5 | 1500/ 1782 batches | accuracy 0.937
# -----------------------------------------------------------
# | end of epoch 5 | time: 8.16s | valid accuracy 0.902
# -----------------------------------------------------------
# | epoch 6 | 500/ 1782 batches | accuracy 0.939
# | epoch 6 | 1000/ 1782 batches | accuracy 0.939
# | epoch 6 | 1500/ 1782 batches | accuracy 0.938
# -----------------------------------------------------------
# | end of epoch 6 | time: 8.16s | valid accuracy 0.906
# -----------------------------------------------------------
# | epoch 7 | 500/ 1782 batches | accuracy 0.941
# | epoch 7 | 1000/ 1782 batches | accuracy 0.939
# | epoch 7 | 1500/ 1782 batches | accuracy 0.939
# -----------------------------------------------------------
# | end of epoch 7 | time: 8.19s | valid accuracy 0.903
# -----------------------------------------------------------
# | epoch 8 | 500/ 1782 batches | accuracy 0.942
# | epoch 8 | 1000/ 1782 batches | accuracy 0.941
# | epoch 8 | 1500/ 1782 batches | accuracy 0.942
# -----------------------------------------------------------
# | end of epoch 8 | time: 8.16s | valid accuracy 0.904
# -----------------------------------------------------------
# | epoch 9 | 500/ 1782 batches | accuracy 0.942
# | epoch 9 | 1000/ 1782 batches | accuracy 0.941
# | epoch 9 | 1500/ 1782 batches | accuracy 0.942
# -----------------------------------------------------------
# | end of epoch 9 | time: 8.16s | valid accuracy 0.904
# -----------------------------------------------------------
# | epoch 10 | 500/ 1782 batches | accuracy 0.940
# | epoch 10 | 1000/ 1782 batches | accuracy 0.942
# | epoch 10 | 1500/ 1782 batches | accuracy 0.942
# -----------------------------------------------------------
# | end of epoch 10 | time: 8.15s | valid accuracy 0.904
# -----------------------------------------------------------


######################################################################
# Evaluate the model with the test dataset
@@ -366,12 +304,7 @@ def evaluate(dataloader):
accu_test = evaluate(test_dataloader)
print('test accuracy {:8.3f}'.format(accu_test))

################################################
#
# ::
#
# test accuracy 0.906
#



######################################################################
@@ -409,10 +342,3 @@ def predict(text, text_pipeline):

print("This is a %s news" %ag_news_label[predict(ex_text_str, text_pipeline)])


################################################
#
# ::
#
# This is a Sports news
#
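
######################################################################
# The same helper works on any raw string. A small usage sketch (the example
# headlines below are made up; ``predict``, ``text_pipeline``, and
# ``ag_news_label`` are defined above):

example_headlines = [
    "Wall St. rallies as tech shares rebound after a volatile week.",
    "The champions clinched the title with a stoppage-time goal.",
]
for headline in example_headlines:
    label = ag_news_label[predict(headline, text_pipeline)]
    print("'%s' -> %s" % (headline, label))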