## Multiclass Text Classification (AG_NEWS)
### Model: `nn.EmbeddingBag` followed by a linear layer

### 1. Prepare Dataset

In [1]:
import torch
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader
from torchtext.data import functional
from torch.utils.data.dataset import random_split
from tqdm import tqdm
import sys

#### Build Vocabulary

In [12]:
# Word-level tokenizer shared by the vocabulary builder and the text pipeline.
tokenizer = get_tokenizer("basic_english")
# Training split, iterated here only to collect tokens for the vocabulary
# (the dataloader cell re-creates its own train/test iterators).
train_iter = AG_NEWS(split='train')

def yield_tokens(data_iter):
    """Lazily yield the token list of each text in ``data_iter``.

    Items of ``data_iter`` are unpacked as ``(label, text)`` pairs; the label
    is ignored and only ``tokenizer(text)`` is produced.
    """
    yield from (tokenizer(text) for _label, text in data_iter)
        
# Vocabulary over every token seen in the training split, plus an explicit "<unk>" special.
vocab = build_vocab_from_iterator(
    yield_tokens(train_iter),
    min_freq=1,
    specials=["<unk>"],
)
# Use the "<unk>" index as the default for tokens that are not in the vocabulary.
vocab.set_default_index(vocab["<unk>"])

#### Build DataLoaders (train / validation / test)

In [37]:
def text_pipeline(x):
    """Tokenize raw text and map every token to its vocabulary index."""
    return vocab(tokenizer(x))


def label_pipeline(x):
    """Convert a 1-based AG_NEWS label into a 0-based class index."""
    return int(x) - 1

BATCH_SIZE = 64  # samples per mini-batch for the train/valid/test DataLoaders below
# Re-create both splits: the `train_iter` used to build the vocabulary may
# already have been consumed.
train_iter, test_iter = AG_NEWS()

def collate_batch(batch):
    """Collate ``(label, text)`` samples into ``nn.EmbeddingBag`` inputs.

    Returns:
        labels: int64 tensor of 0-based class ids, one per sample.
        text: 1-D int64 tensor with all samples' token ids concatenated.
        offsets: int64 tensor giving the start position of each sample in ``text``.
    """
    labels, token_tensors, lengths = [], [], []
    for raw_label, raw_text in batch:
        labels.append(label_pipeline(raw_label))
        ids = torch.tensor(text_pipeline(raw_text), dtype=torch.int64)
        token_tensors.append(ids)
        lengths.append(ids.size(0))
    # Each bag starts where the previous ones end: exclusive prefix sum of lengths.
    offsets = torch.tensor([0] + lengths[:-1]).cumsum(dim=0)
    return torch.tensor(labels), torch.cat(token_tensors), offsets


# Materialise the iterable datasets so they support len() and index access.
train_dataset = functional.to_map_style_dataset(train_iter)
test_dataset = functional.to_map_style_dataset(test_iter)

# Hold out 5% of the training data for validation. A seeded generator makes the
# split identical across kernel restarts, so validation scores stay comparable.
SPLIT_SEED = 42
num_train = int(len(train_dataset) * 0.95)
split_train, split_valid = random_split(
    train_dataset,
    [num_train, len(train_dataset) - num_train],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_dataloader = DataLoader(split_train, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid, batch_size=BATCH_SIZE,
                              shuffle=False, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                             shuffle=False, collate_fn=collate_batch)

### 2. Define the Model

In [38]:
from torch import nn

class MulticlassModel(nn.Module):
    """Bag-of-embeddings text classifier.

    The token ids of each document are pooled by ``nn.EmbeddingBag`` (its
    default pooling mode), and a linear layer maps the pooled vector to
    ``num_class`` logits.
    """

    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        # sparse=True makes the embedding table produce sparse gradients.
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)

    def forward(self, X, offsets):
        """Return one row of ``num_class`` logits per bag.

        ``X`` is the concatenated token ids, and ``offsets`` holds the start
        index of each bag in ``X``.
        """
        pooled = self.embedding(X, offsets)
        logits = self.fc(pooled)
        return logits

### 3. Build and Train Model

In [39]:
def train(dataloader, model, optimizer=None, criterion=None):
    """Run one training epoch over ``dataloader`` and return the mean loss.

    Args:
        dataloader: iterable of ``(labels, texts, offsets)`` batches, such as
            those produced by ``collate_batch``.
        model: module called as ``model(texts, offsets)`` that returns logits.
        optimizer: optimizer to step. Defaults to the notebook-level ``optimizer``.
        criterion: loss function. Defaults to the notebook-level ``criterion``.

    Returns:
        Mean per-sample loss over the epoch, or ``nan`` if ``dataloader`` is
        empty. Existing callers that ignore the return value are unaffected.
    """
    # Explicit arguments remove the hidden dependency on notebook globals; the
    # fallback keeps existing `train(dataloader, model)` calls working.
    if optimizer is None:
        optimizer = globals()["optimizer"]
    if criterion is None:
        criterion = globals()["criterion"]

    model.train()
    total_loss, n_samples = 0.0, 0
    for labels, texts, offsets in tqdm(dataloader, desc='training...', file=sys.stdout):
        optimizer.zero_grad()
        outputs = model(texts, offsets)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Assumes the criterion averages over the batch (the CrossEntropyLoss
        # default). Weighting by batch size yields a per-sample epoch mean.
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        n_samples += batch_size
    return total_loss / n_samples if n_samples else float("nan")

            
def evaluate(dataloader, model):
    """Return the classification accuracy of ``model`` over ``dataloader``.

    The model is switched to eval mode and gradients are disabled. Each batch
    is a ``(labels, texts, offsets)`` triple, and the predicted class is the
    argmax of the model's logits.
    """
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for labels, texts, offsets in dataloader:
            predicted = model(texts, offsets).argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return correct / total

In [40]:
# Hyperparameters
SEED = 42  # seeds model initialisation and DataLoader shuffling
LR = 1
N_EPOCHS = 30
vocab_size = len(vocab)
embed_dim = 100
num_class = 4  # four AG_NEWS categories

# Seed the global torch RNG right before building the model. Initial weights
# and shuffled batch order are then identical after a kernel restart.
torch.manual_seed(SEED)
classifier = MulticlassModel(vocab_size, embed_dim, num_class)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(classifier.parameters(), lr=LR)
# Multiply the learning rate by 0.9 after every epoch.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

for epoch in range(1, N_EPOCHS + 1):
    train(train_dataloader, classifier)
    # Accuracy on the training subset (an extra full pass) and on the held-out validation subset.
    train_acc = evaluate(train_dataloader, classifier)
    valid_acc = evaluate(valid_dataloader, classifier)
    scheduler.step()
    print(f"| Epoch: {epoch}/{N_EPOCHS} | train_accuracy : {train_acc: .3f} | val_accuracy : {valid_acc: .3f}")

# Test with test set
accu_test = evaluate(test_dataloader, classifier)
print("="*60)
print(f"Test Accuracy: {accu_test: .3f}")

training...: 100%|█████████████████████████████████████████████████████████████████| 1782/1782 [00:21<00:00, 83.62it/s]
| Epoch: 1/30 | train_accuracy :  0.747 | val_accuracy :  0.749
training...: 100%|█████████████████████████████████████████████████████████████████| 1782/1782 [00:21<00:00, 81.69it/s]
| Epoch: 2/30 | train_accuracy :  0.808 | val_accuracy :  0.807
training...: 100%|█████████████████████████████████████████████████████████████████| 1782/1782 [00:22<00:00, 80.74it/s]
| Epoch: 3/30 | train_accuracy :  0.834 | val_accuracy :  0.830
training...: 100%|█████████████████████████████████████████████████████████████████| 1782/1782 [00:22<00:00, 79.93it/s]
| Epoch: 4/30 | train_accuracy :  0.844 | val_accuracy :  0.838
training...: 100%|█████████████████████████████████████████████████████████████████| 1782/1782 [00:22<00:00, 80.95it/s]
| Epoch: 5/30 | train_accuracy :  0.855 | val_accuracy :  0.848
training...: 100%|██████████████████████████████████████████████████████████████

### Evaluate the model on the test dataset

In [41]:
# Final accuracy of the trained classifier on the held-out AG_NEWS test split.
print('Checking the results of test dataset.')
accu_test = evaluate(test_dataloader, classifier)
print(f'test accuracy {accu_test:8.3f}')

Checking the results of test dataset.
test accuracy    0.872


### Predict the category of a sample news article

In [42]:
# Human-readable names for the 1-based class ids returned by `predict`.
ag_news_label = dict(enumerate(["World", "Sports", "Business", "Sci/Tec"], start=1))

def predict(text, text_pipeline, model=None):
    """Predict the AG_NEWS category id of a single raw text.

    Args:
        text: raw news text.
        text_pipeline: callable that maps text to a list of vocabulary indices.
        model: classifier called as ``model(token_ids, offsets)``. Defaults to
            the notebook-level ``classifier``, so existing two-argument calls
            still work.

    Returns:
        The 1-based class id, i.e. the argmax of the logits plus one.
    """
    if model is None:
        model = classifier
    # Inference mode. This matters if the model ever gains dropout or batch-norm layers.
    model.eval()
    with torch.no_grad():
        # Explicit int64, consistent with collate_batch. An empty token list
        # would otherwise become a float tensor, which EmbeddingBag rejects as indices.
        token_ids = torch.tensor(text_pipeline(text), dtype=torch.int64)
        # The whole text is a single bag starting at position 0.
        logits = model(token_ids, torch.tensor([0]))
        return logits.argmax(1).item() + 1

# Sample article used to sanity-check `predict`. Implicit string concatenation
# keeps the source lines short. Backslash continuations would have copied each
# line's indentation into the text.
ex_text_str = (
    "MEMPHIS, Tenn. – Four days ago, Jon Rahm was "
    "enduring the season’s worst weather conditions on Sunday at The "
    "Open on his way to a closing 75 at Royal Portrush, which "
    "considering the wind and the rain was a respectable showing. "
    "Thursday’s first round at the WGC-FedEx St. Jude Invitational "
    "was another story. With temperatures in the mid-80s and hardly any "
    "wind, the Spaniard was 13 strokes better in a flawless round. "
    "Thanks to his best putting performance on the PGA Tour, Rahm "
    "finished with an 8-under 62 for a three-stroke lead, which "
    "was even more impressive considering he’d never played the "
    "front nine at TPC Southwind."
)


# Classify the sample article and print its category name.
print(f"This is a {ag_news_label[predict(ex_text_str, text_pipeline)]} news")

This is a Sports news
