In [1]:
import argparse
import logging
import time

import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
from torchtext.data.utils import get_tokenizer
from torchtext.datasets import DATASETS
from torchtext.prototype.transforms import load_sp_model, PRETRAINED_SP_MODEL, SentencePieceTokenizer
from torchtext.utils import download_from_url
from torchtext.vocab import build_vocab_from_iterator
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from torchtext.vocab import GloVe, FastText

### Information
- torchtext repo: https://github.com/pytorch/text/tree/main/torchtext
- torchtext documentation: https://pytorch.org/text/stable/index.html

### Constants

In [2]:
DATASET = "AG_NEWS"
DATA_DIR = ".data"
DEVICE = "cpu"
EMBED_DIM = 300
LR = 4.0
BATCH_SIZE = 16
NUM_EPOCHS = 5
PADDING_VALUE = 0
PADDING_IDX = PADDING_VALUE

### Get the tokenizer
- Different models tolenize in different ways. 
    - Word2Vec / GloVe does words (WordLevel).
    - BERT does WordPiece.
    - The original transformer did BytePairEncoding.
    - FastText uses n-grams.


In [3]:
basic_english_tokenizer = get_tokenizer("basic_english")

In [4]:
sp_model_path = download_from_url(PRETRAINED_SP_MODEL["text_unigram_15000"], root=".data")
sp_model = load_sp_model(sp_model_path)
sentence_piece_tokenizer = SentencePieceTokenizer(sp_model)

In [5]:
basic_english_tokenizer("This is some text ...")

['this', 'is', 'some', 'text', '.', '.', '.']

In [6]:
sentence_piece_tokenizer("This is some text ...")

['▁This', '▁is', '▁some', '▁text', '▁...']

In [7]:
# Needed later.
TOKENIZER = basic_english_tokenizer

### Get the data and get the vocabulary.

In [8]:
def yield_tokens(data_iter):
    for _, text in data_iter:
        yield TOKENIZER(text)

In [9]:
train_iter = DATASETS[DATASET](root=DATA_DIR, split="train")
VOCAB = build_vocab_from_iterator(yield_tokens(train_iter), specials=('<pad>', '<unk>'))

# Make the default index the same as that of the unk_token.
VOCAB.set_default_index(VOCAB['<unk>'])

### Get GloVe vectors

Information about pretrained vectors: 
- https://pytorch.org/text/stable/_modules/torchtext/vocab/vectors.html#GloVe
- https://github.com/pytorch/text/blob/e3799a6eecef451f6e66c9c20b6432c5f078697f/torchtext/vocab/vectors.py#L263

In [22]:
GLOVE = GloVe(name='840B', dim=300)
FASTTEXT = FastText()

NameError: name 'FastText' is not defined

If the embeddings are not in the token space, a zero vector will be returned.

In [11]:
GLOVE.get_vecs_by_tokens(TOKENIZER("Hello, How are you?"), lower_case_backup=True).shape

torch.Size([6, 300])

In [12]:
GLOVE.get_vecs_by_tokens(TOKENIZER("<pad> <unk> the man Man ahsdhashdahsdhash"), lower_case_backup=True)

tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.2720, -0.0620, -0.1884,  ...,  0.1302, -0.1832,  0.1323],
        [-0.1731,  0.2066,  0.0165,  ...,  0.1666, -0.3834, -0.0738],
        [-0.1731,  0.2066,  0.0165,  ...,  0.1666, -0.3834, -0.0738],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]])

### Helper functions

These functions tokenize the string input and then map each token to the integer representation in the vocabulary.

In [13]:
def text_pipeline(x):
    return VOCAB(TOKENIZER(x))

def label_pipeline(x):
    return int(x) - 1

Nice link on collate_fn and DataLoader in PyTorch: https://python.plainenglish.io/understanding-collate-fn-in-pytorch-f9d1742647d3

In [14]:
def collate_batch(batch):
    label_list, text_list = [], []
    for (_label, _text) in batch:
        # Get the label from {1, 2, 3, 4} to {0, 1, 2, 3}
        label_list.append(label_pipeline(_label))
                
        # Return a list of ints.
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text.clone().detach())
    
    label_list = torch.tensor(label_list, dtype=torch.int64)
    text_list = pad_sequence(
        text_list,
        batch_first=True,
        padding_value=PADDING_VALUE)
            
    return label_list.to(DEVICE), text_list.to(DEVICE)

### Get the data

In [15]:
train_iter = DATASETS[DATASET](root=DATA_DIR, split="train")
num_class = len(set([label for (label, _) in train_iter]))
# What are the classes?
print(f"The number of classes is {num_class} ...")

The number of classes is 4 ...


### Set up the model

In [16]:
# A more complicated model. We'll explore this after we learn word embeddings.
class TextClassificationModel(nn.Module):
    def __init__(
        self,
        vocab_size,
        embed_dim,
        num_class,
        initialize_with_glove = True,
        fine_tune_embeddings = True
    ):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.Embedding(
            vocab_size,
            embed_dim,
            padding_idx=PADDING_IDX
        )
        
        if initialize_with_glove:
            self.embedding.weight.requires_grad = False
            for i in range(vocab_size):
                token = VOCAB.lookup_token(i)
                
                self.embedding.weight[i, :] = GLOVE.get_vecs_by_tokens(
                    TOKENIZER(token), 
                    lower_case_backup=True
                )
            self.embedding.weight.requires_grad = True
                
        if not fine_tune_embeddings:
            self.embedding.weight.requires_grad = False
        
        
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text):
        embedded = self.embedding(text)
        embedded_sum = embedded.mean(axis=1).squeeze(1)
        out = self.fc(embedded_sum)
        return out

### Set up the 

In [17]:
criterion = torch.nn.CrossEntropyLoss().to(DEVICE)
model = TextClassificationModel(
    len(VOCAB),
    EMBED_DIM,
    num_class,
    initialize_with_glove = True,
    fine_tune_embeddings = True
).to(DEVICE)

optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)

### Set up the data

In [18]:
train_iter, test_iter = DATASETS[DATASET]()
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)

num_train = int(len(train_dataset) * 0.95)
split_train_, split_valid_ = random_split(train_dataset, [num_train, len(train_dataset) - num_train])

train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)

### Train the model

In [19]:
def train(dataloader, model, optimizer, criterion, epoch):
    model.train()
    total_acc, total_count = 0, 0
    log_interval = 500

    for idx, (label, text) in enumerate(dataloader):
        optimizer.zero_grad()
        logits = model(text)
                
        # Get the loss.
        loss = criterion(input=logits, target=label)
        
        # Do back propagation.
        loss.backward()
        
        # Clip the gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        
        # Do an optimization step.
        optimizer.step()
        total_acc += (logits.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            print(
                "| epoch {:3d} | {:5d}/{:5d} batches "
                "| accuracy {:8.3f}".format(epoch, idx, len(dataloader), total_acc / total_count)
            )
            total_acc, total_count = 0, 0

In [20]:
def evaluate(dataloader, model):
    model.eval()
    total_acc, total_count = 0, 0

    with torch.no_grad():
        for idx, (label, text) in enumerate(dataloader):
            logits = model(text)
            total_acc += (logits.argmax(1) == label).sum().item()
            total_count += label.size(0)
    return total_acc / total_count

In [21]:
for epoch in range(1, NUM_EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader, model, optimizer, criterion, epoch)
    accu_val = evaluate(valid_dataloader, model)
    scheduler.step()
    print("-" * 59)
    print(
        "| end of epoch {:3d} | time: {:5.2f}s | "
        "valid accuracy {:8.3f} ".format(epoch, time.time() - epoch_start_time, accu_val)
    )
    print("-" * 59)

print("Checking the results of test dataset.")
accu_test = evaluate(test_dataloader, model)
print("test accuracy {:8.3f}".format(accu_test))

| epoch   1 |   500/ 7125 batches | accuracy    0.428
| epoch   1 |  1000/ 7125 batches | accuracy    0.642
| epoch   1 |  1500/ 7125 batches | accuracy    0.741
| epoch   1 |  2000/ 7125 batches | accuracy    0.796
| epoch   1 |  2500/ 7125 batches | accuracy    0.811
| epoch   1 |  3000/ 7125 batches | accuracy    0.831
| epoch   1 |  3500/ 7125 batches | accuracy    0.852
| epoch   1 |  4000/ 7125 batches | accuracy    0.853
| epoch   1 |  4500/ 7125 batches | accuracy    0.859
| epoch   1 |  5000/ 7125 batches | accuracy    0.867
| epoch   1 |  5500/ 7125 batches | accuracy    0.875
| epoch   1 |  6000/ 7125 batches | accuracy    0.879
| epoch   1 |  6500/ 7125 batches | accuracy    0.868
| epoch   1 |  7000/ 7125 batches | accuracy    0.880
-----------------------------------------------------------
| end of epoch   1 | time: 113.27s | valid accuracy    0.897 
-----------------------------------------------------------
| epoch   2 |   500/ 7125 batches | accuracy    0.900
| epoch 