<a href="https://colab.research.google.com/github/zhangguanheng66/tutorials/blob/sentiment_analysis/Torchtext_with_text_classification_analysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%shell

# pip uninstall torch torchtext
rm -r /usr/local/lib/python3.6/dist-packages/torch*
pip install --pre torch torchtext -f https://download.pytorch.org/whl/nightly/cu101/torch_nightly.html
# pip install --pre torch==1.8.0.dev20201008+cu101 torchtext==0.8.0.dev20201008 -f https://download.pytorch.org/whl/nightly/cu101/torch_nightly.html

In [50]:
import torch
import torchtext

# Prototype pipeline with the new torchtext library

In this tutorial, we show how to use the new torchtext library to build the dataset for text classification. The nightly release of torchtext provides a few prototype building blocks for data processing. With the new torchtext library, you have the flexibility to

*   Access the raw data as an iterator
*   Build a data processing pipeline that converts the raw text strings into `torch.Tensor` objects that can be used to train the model
*   Shuffle and iterate the data with [torch.utils.data.DataLoader](https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader)


## Step 1: Access the raw dataset iterators
----------------------------

Some advanced users prefer to work on the raw data strings with their own data processing pipeline. The new torchtext library provides a few raw dataset iterators, which yield the raw text strings. For example, the AG_NEWS dataset iterators yield each sample as a tuple of label and text.

In [51]:
from torchtext.experimental.datasets.raw import AG_NEWS
# Pass a one-element tuple so the call returns a one-element tuple of iterators
train_iter, = AG_NEWS(data_select=('train',))



```
next(iter(train_iter))
>>> (3, "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - 
Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green 
again.")

next(iter(train_iter))
>>> (3, 'Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private 
investment firm Carlyle Group,\\which has a reputation for making well-timed 
and occasionally\\controversial plays in the defense industry, has quietly 
placed\\its bets on another part of the market.')

next(iter(train_iter))
>>> (3, "Oil and Economy Cloud Stocks' Outlook (Reuters) Reuters - Soaring 
crude prices plus worries\\about the economy and the outlook for earnings are 
expected to\\hang over the stock market next week during the depth of 
the\\summer doldrums.")
```
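
As the three calls above show, every `next` call advances the stream instead of restarting it, which is presumably why later cells call `AG_NEWS` again before each new pass over the data. The sketch below illustrates this one-shot behavior with a plain Python generator; `fake_ag_news` and its sample texts are hypothetical stand-ins, not part of torchtext.

```python
def fake_ag_news():
    """Hypothetical stand-in for a raw dataset: yields (label, text) tuples."""
    yield (3, "first business headline")
    yield (3, "second business headline")
    yield (4, "a science headline")


train_iter = fake_ag_news()

first = next(iter(train_iter))    # consumes the first sample
second = next(iter(train_iter))   # advances to the second sample, no restart
remaining = list(train_iter)      # consumes whatever is left
exhausted = list(train_iter)      # the stream is now empty

# Materialize the samples once if they must be traversed several times.
samples = list(fake_ag_news())
```

Materializing with `list(...)`, as the later cells do before building a `DataLoader`, trades memory for the ability to index and reshuffle the samples.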



## Step 2: Prepare data processing pipelines
----------------------------
We have revisited the basic components of the torchtext library, including the vocabulary, word vectors, a tokenizer backed by regular expressions, and sentencepiece. These are the basic data processing building blocks for raw text strings.

### 2.1 Tokenizer-vocabulary data processing pipeline

Here is an example for typical NLP data processing with tokenizer and vocabulary.

The first step is to build a vocabulary with the raw training dataset. We provide a function `build_vocab_from_iterator` to build the vocabulary from a text iterator.

In [52]:
from torchtext.experimental.vocab import build_vocab_from_iterator
from torchtext.experimental.transforms import basic_english_normalize
tokenizer = basic_english_normalize()
train_iter, = AG_NEWS(data_select=('train',))
vocab = build_vocab_from_iterator(iter(tokenizer(line) for label, line in train_iter))



The vocabulary block converts a list of tokens into integers.
```
vocab(['here', 'is', 'an', 'example'])
>>> [475, 21, 30, 5286]
```
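
To make the lookup concrete, here is a minimal plain-Python sketch of what a vocabulary does: count tokens, assign ids by descending frequency, and map each token to its id. The helpers `build_vocab` and `lookup`, the whitespace tokenizer, and the `<unk>` fallback are assumptions of this sketch and do not reproduce the torchtext implementation.

```python
from collections import Counter


def build_vocab(token_lists, specials=("<unk>",)):
    """Hypothetical helper: map tokens to ids, most frequent tokens first."""
    counts = Counter(tok for tokens in token_lists for tok in tokens)
    itos = list(specials) + [
        tok for tok, _ in counts.most_common() if tok not in specials
    ]
    return {tok: idx for idx, tok in enumerate(itos)}


def lookup(stoi, tokens, unk="<unk>"):
    """Convert tokens to ids, falling back to the unknown-token id."""
    return [stoi.get(tok, stoi[unk]) for tok in tokens]


corpus = ["here is an example", "here is another example", "here we go"]
stoi = build_vocab(line.split() for line in corpus)
ids = lookup(stoi, ["here", "is", "unseen"])
```

Because the ids depend on token frequency in the training data, the integers returned by the real `vocab` object will differ from those in this toy corpus.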

Next, prepare the data pipelines with the tokenizer and vocabulary. The pipelines will be applied to the raw data strings from the dataset iterators.

In [53]:
def generate_text_pipeline(tokenizer, vocab):
  def _forward(text):
    return vocab(tokenizer(text))
  return _forward
text_pipeline = generate_text_pipeline(basic_english_normalize(), vocab)
label_pipeline = lambda x: int(x) - 1

The text pipeline converts a text string into a list of integers based on the lookup defined in the vocab. The label pipeline converts the label into a zero-based integer by subtracting one. For example,

```
text_pipeline('here is the an example')
>>> [475, 21, 2, 30, 5286]
label_pipeline('10')
>>> 9
```

### 2.2 SentencePiece data processing pipeline

---



SentencePiece is an unsupervised text tokenizer and detokenizer, mainly for neural-network-based text generation systems in which the vocabulary size is fixed before the neural model is trained. The sentencepiece transforms in torchtext support both subword units (e.g., byte-pair encoding (BPE)) and the unigram language model. We provide a few pretrained SentencePiece models, which are accessible from `PRETRAINED_SP_MODEL`. Here is an example that applies the SentencePiece transform to build the dataset.

In [54]:
from torchtext.experimental.transforms import (
    PRETRAINED_SP_MODEL,
    sentencepiece_processor,
    load_sp_model,
)
from torchtext.utils import download_from_url
spm_filepath = download_from_url(PRETRAINED_SP_MODEL['text_unigram_25000'])
spm_transform = sentencepiece_processor(spm_filepath)
sp_model = load_sp_model(spm_filepath)

The sentencepiece processor converts a text string into a list of integers. You can use the `decode` method to convert a list of integers back to the original string.

```
spm_transform('here is the an example')
>>> [130, 46, 9, 76, 1798]
spm_transform.decode([6468, 17151, 4024, 8246, 16887, 87, 23985, 12, 581, 15120])
>>> 'torchtext sentencepiece processor can encode and decode'
```



## Step 3: Generate data batch and iterator

The PyTorch data loading utility is the [torch.utils.data.DataLoader](https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader) class. It works with a map-style dataset that implements the `__getitem__()` and `__len__()` protocols and represents a map from indices/keys to data samples. It also works with an iterable dataset when the `shuffle` argument is `False`.
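
As a rough illustration of the map-style protocol, the sketch below defines a tiny dataset class and shows why shuffling needs both a length and index access. `ListDataset` and the sample texts are hypothetical; a plain Python `list`, which is what this tutorial passes to `DataLoader`, already provides both methods.

```python
import random


class ListDataset:
    """Hypothetical map-style dataset: index -> (label, text) sample."""

    def __init__(self, samples):
        self._samples = list(samples)

    def __len__(self):
        return len(self._samples)

    def __getitem__(self, index):
        return self._samples[index]


dataset = ListDataset([(1, "world news"), (2, "sports news"), (3, "business news")])

# Shuffling amounts to drawing a random permutation of the valid indices
# (which needs __len__) and then fetching samples by index (__getitem__).
order = random.Random(0).sample(range(len(dataset)), len(dataset))
shuffled = [dataset[i] for i in order]
```

An iterable dataset offers neither operation, so it can only be read in the order it yields samples.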

Before a batch is sent to the model, the `collate_fn` function processes the samples that `DataLoader` has grouped together. The input to `collate_fn` is a list of samples whose length is the batch size set in `DataLoader`, and `collate_fn` processes them according to the data processing pipelines declared in Step 2. Make sure that `collate_fn` is declared as a top-level `def`, so that the function is available in each worker process.

In this example, the text entries in the original data batch input are packed into a list and concatenated as a single tensor for the input of `nn.EmbeddingBag`. The offsets tensor holds the starting index of each individual sequence in the text tensor. The label tensor holds the labels of the individual text entries.
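
The offsets are simply a running sum of the lengths of all earlier sequences. The plain-Python sketch below shows the arithmetic on a small batch of token-id lists; `flatten_with_offsets` is a hypothetical helper, and the ids are arbitrary.

```python
from itertools import accumulate


def flatten_with_offsets(sequences):
    """Concatenate sequences and record where each one starts."""
    if not sequences:
        return [], []
    lengths = [len(seq) for seq in sequences]
    # The first sequence starts at 0; each later one starts where the
    # previous sequences end, so the last length is never needed.
    offsets = [0] + list(accumulate(lengths[:-1]))
    flat = [tok for seq in sequences for tok in seq]
    return flat, offsets


batch = [[475, 21, 30], [5286], [2, 30]]
flat, offsets = flatten_with_offsets(batch)
# flat    -> [475, 21, 30, 5286, 2, 30]
# offsets -> [0, 3, 4]
```

The collate function in this tutorial does the same thing with tensors: it collects the lengths, drops the last one, and calls `cumsum`.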

In [60]:
from torch.utils.data import DataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def collate_batch(batch):
    # offsets starts with 0 because the first sequence begins at index 0
    label_list, text_list, offsets = [], [], [0]
    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(spm_transform(_text), dtype=torch.int64)
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    # Drop the last length, then take the running sum to get each start index
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text_list = torch.cat(text_list)
    return label_list.to(device), text_list.to(device), offsets.to(device)

train_iter, = AG_NEWS(data_select=('train',))
dataloader = DataLoader(list(train_iter), batch_size=8, shuffle=True, collate_fn=collate_batch)

## Step 4: Model for text classification task
---

We use a simple model here for text classification. The model is composed of the [nn.EmbeddingBag](https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html#embeddingbag) layer plus a linear layer for classification. With its default mode, `nn.EmbeddingBag` computes the mean value of a “bag” of embeddings. Although the text entries here have different lengths, the `nn.EmbeddingBag` module requires no padding, since the boundaries between entries are recorded in the offsets.
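
The following sketch checks this behavior on a toy input: two bags of different lengths are pooled without padding, and the result matches the mean computed by hand from the embedding weights. The sizes and token ids are arbitrary values chosen for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)
bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode="mean")

# Two sequences of lengths 3 and 2, concatenated into one flat tensor.
text = torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64)
offsets = torch.tensor([0, 3], dtype=torch.int64)  # start index of each bag

pooled = bag(text, offsets)  # one row per bag, shape (2, 4)

# Reference computation: average the embedding rows of each bag by hand.
with torch.no_grad():
    manual = torch.stack([
        bag.weight[text[0:3]].mean(dim=0),
        bag.weight[text[3:5]].mean(dim=0),
    ])
```

Each row of `pooled` is a fixed-size representation of one text entry, which is what the linear layer then classifies.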

In [65]:
from torch import nn

class TextClassificationModel(nn.Module):

    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

We build a model with the embedding dimension of 64. The AG_NEWS dataset has four labels and therefore the number of classes is four.

*   1 : World
*   2 : Sports
*   3 : Business
*   4 : Sci/Tech
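
Since the label pipeline from Step 2 shifts these labels to the zero-based class indices that `CrossEntropyLoss` expects as targets, a predicted index has to be shifted back before it is looked up in the list above. A minimal sketch, in which the dictionary `AG_NEWS_LABELS` and the helper `index_to_name` are hypothetical names:

```python
# Label names as listed in this tutorial (1-based, as stored in the raw data).
AG_NEWS_LABELS = {1: "World", 2: "Sports", 3: "Business", 4: "Sci/Tech"}

# Same definition as the label pipeline in Step 2.
label_pipeline = lambda x: int(x) - 1


def index_to_name(class_index):
    """Map a zero-based model output index back to its label name."""
    return AG_NEWS_LABELS[class_index + 1]


target = label_pipeline(3)           # raw label 3 becomes class index 2
predicted_name = index_to_name(target)
```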



In [66]:
from torchtext.experimental.datasets.text_classification import LABELS
num_class = len(LABELS['AG_NEWS'])
vocab_size = sp_model.GetPieceSize()
emsize = 64
model = TextClassificationModel(vocab_size, emsize, num_class).to(device)


## Step 5: Train and test the model
---

Next, we train and evaluate the model on the text classification dataset.

In [67]:
import time

def train(model, dataloader):
    model.train()
    total_acc, total_count = 0, 0
    log_interval = 500
    start_time = time.time()

    for idx, (label, text, offsets) in enumerate(dataloader):
        optimizer.zero_grad()
        predicted_label = model(text, offsets)
        loss = criterion(predicted_label, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | '
                  'ms/batch {:5.2f} '
                  '| accuracy {:8.3f}'.format(epoch, idx, len(dataloader),
                                              scheduler.get_last_lr()[0],
                                              elapsed * 1000 / log_interval,
                                              total_acc/total_count))
            total_acc, total_count = 0, 0
            start_time = time.time()

def evaluate(model, dataloader):
    model.eval()
    total_acc, total_count = 0, 0

    # Gradients are not needed for evaluation
    with torch.no_grad():
        for idx, (label, text, offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    return total_acc/total_count

[CrossEntropyLoss](https://pytorch.org/docs/stable/nn.html?highlight=crossentropyloss#torch.nn.CrossEntropyLoss) criterion combines `nn.LogSoftmax()` and `nn.NLLLoss()` in a single class. It is useful when training a classification problem with C classes. [SGD](https://pytorch.org/docs/stable/_modules/torch/optim/sgd.html) implements stochastic gradient descent method as optimizer. [StepLR](https://pytorch.org/docs/master/_modules/torch/optim/lr_scheduler.html#StepLR) is used here to adjust the learning rate through epochs.
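
The first claim can be checked numerically. The sketch below compares `nn.CrossEntropyLoss` with the explicit two-step computation on random logits; the batch size, number of classes, and targets are arbitrary values chosen for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)
logits = torch.randn(3, 4)          # 3 samples, 4 classes (raw model outputs)
target = torch.tensor([0, 3, 1])    # zero-based class indices

# Single-step loss on the raw logits.
combined = nn.CrossEntropyLoss()(logits, target)

# Equivalent two-step computation: log-probabilities, then negative log-likelihood.
log_probs = nn.LogSoftmax(dim=1)(logits)
two_step = nn.NLLLoss()(log_probs, target)
```

This is why the model's `forward` returns the raw output of the linear layer: the softmax is applied inside the loss.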

Here are a few hyperparameters used in the pipeline:


*   The number of epochs - 10
*   The initial learning rate - 5.0
*   The batch size - 64
*   The maximum sequence length - 768
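
The initial learning rate does not stay fixed. The training loop in this tutorial calls `scheduler.step()` only when the validation accuracy falls below the best value seen so far, and each call multiplies the rate by `gamma`. The plain-Python sketch below replays that rule; `simulate_lr` is a hypothetical helper and the accuracies are made-up illustrative values, not results from this tutorial.

```python
def simulate_lr(valid_accuracies, initial_lr=5.0, gamma=0.1):
    """Return the learning rate in effect during each epoch."""
    lr, best, history = initial_lr, None, []
    for accu_val in valid_accuracies:
        history.append(lr)                 # rate used while training this epoch
        if best is not None and best > accu_val:
            lr *= gamma                    # effect of one scheduler step
        else:
            best = accu_val                # new best validation accuracy
    return history


# Illustrative accuracies: the drop after the third epoch triggers one decay.
history = simulate_lr([0.80, 0.85, 0.84, 0.86])
```

With these values the rate stays at 5.0 for the first three epochs and drops by a factor of ten for the fourth.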

Since the original AG_NEWS has no validation dataset, we split the training dataset into train/valid sets with a split ratio of 0.95 (train) and 0.05 (valid). Here we use the [torch.utils.data.dataset.random_split](https://pytorch.org/docs/stable/data.html?highlight=random_split#torch.utils.data.random_split) function from the PyTorch core library.
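
The split sizes follow from simple arithmetic. The sketch below assumes the standard AG_NEWS training set of 120,000 samples and the batch size of 64 listed above; the resulting batch count agrees with the `1782 batches` shown in the training log.

```python
import math

num_samples = 120_000   # assumed size of the AG_NEWS training set
batch_size = 64

num_train = int(num_samples * 0.95)   # samples kept for training
num_valid = num_samples - num_train   # samples held out for validation

# The last batch may be smaller than batch_size, hence the ceiling.
num_batches = math.ceil(num_train / batch_size)
```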



In [68]:
from torch.utils.data.dataset import random_split
# Hyperparameters
EPOCHS = 10 # epoch
LR = 5  # learning rate
BATCH_SIZE = 64 # batch size for training
  
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
total_accu = None
train_iter, test_iter = AG_NEWS()
train_dataset = list(train_iter)
num_train = int(len(train_dataset) * 0.95)
split_train_, split_valid_ = \
    random_split(train_dataset, [num_train, len(train_dataset) - num_train])

train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(list(test_iter), batch_size=BATCH_SIZE,
                             shuffle=True, collate_fn=collate_batch)

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(model, train_dataloader)
    accu_val = evaluate(model, valid_dataloader)
    if total_accu is not None and total_accu > accu_val:
        scheduler.step()
    else:
        total_accu = accu_val
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | '
          'valid accuracy {:8.3f} '.format(epoch,
                                           time.time() - epoch_start_time,
                                           accu_val))
    print('-' * 89)
print('Checking the results of test dataset...')
accu_test = evaluate(model, test_dataloader)
print('test accuracy {:8.3f}'.format(accu_test))

| epoch   1 |   500/ 1782 batches | lr 5.00000 | ms/batch 11.80 | accuracy    0.641
| epoch   1 |  1000/ 1782 batches | lr 5.00000 | ms/batch 11.80 | accuracy    0.841
| epoch   1 |  1500/ 1782 batches | lr 5.00000 | ms/batch 11.51 | accuracy    0.872
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 21.68s | valid accuracy    0.885 
-----------------------------------------------------------------------------------------
| epoch   2 |   500/ 1782 batches | lr 5.00000 | ms/batch 12.23 | accuracy    0.893
| epoch   2 |  1000/ 1782 batches | lr 5.00000 | ms/batch 12.09 | accuracy    0.897
| epoch   2 |  1500/ 1782 batches | lr 5.00000 | ms/batch 11.43 | accuracy    0.901
-----------------------------------------------------------------------------------------
| end of epoch   2 | time: 22.09s | valid accuracy    0.904 
-----------------------------------------------------------------------------------------
| epoch   3 | 