## Binary Text Classification - IMDB
### (EmbeddingBag Layer, Linear Layer)

In [43]:
import torch
from torchtext.datasets import IMDB
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader
from torch.nn import EmbeddingBag
import torch.nn as nn
from torchtext.data.functional import to_map_style_dataset
from torch.utils.data.dataset import random_split
from torch.optim.lr_scheduler import StepLR

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Build Vocabulary

In [2]:
tokenizer = get_tokenizer('basic_english')
train_iter = IMDB(split='train')

def yield_tokens(train_iter):
    for _, text in train_iter:
        yield tokenizer(text)      
        
vocab = build_vocab_from_iterator(yield_tokens(train_iter), 
                                  min_freq=1,
                                  specials=["<unk>"])

vocab.set_default_index(vocab["<unk>"])

### Prepare Data

In [34]:
text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: 0. if x=='neg'else 1.

BATCH_SIZE = 100
train_iter, test_iter = IMDB()

def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    for labels, texts in batch:
        label_list.append(label_pipeline(labels))
        processed_text = torch.tensor(text_pipeline(texts), dtype=torch.int64)
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))
    label_tensor = torch.tensor(label_list)
    label_tensor = torch.unsqueeze(label_tensor, dim=1)
    offset_tensor = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text_tensor = torch.cat(text_list)
    return label_tensor, text_tensor, offset_tensor


train_dataset, test_dataset = to_map_style_dataset(train_iter), to_map_style_dataset(test_iter)
num_test = int(len(test_dataset)*0.95)
split_test, split_valid = random_split(train_dataset, [num_test, len(test_dataset)-num_test])

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                             shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid, batch_size=BATCH_SIZE,
                             shuffle=False, collate_fn=collate_batch)
test_dataloader = DataLoader(split_test, batch_size=BATCH_SIZE,
                             shuffle=False, collate_fn=collate_batch)

### Define Classifier

In [60]:
class BinaryTextClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim):
        super(BinaryTextClassifier, self).__init__()
        self.embedding = EmbeddingBag(vocab_size, embed_dim)
        self.fc = nn.Linear(embed_dim, 1)
        self.sigmoid = nn.Sigmoid()
                
    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        out = self.fc(embedded)
        out = self.sigmoid(out)
        return out

### Train Model

In [61]:
def train(dataloader, model):
    for labels, texts, offsets in dataloader:
        optimizer.zero_grad()
        outputs = model(texts, offsets)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

            
def evaluate(dataloader, model):
    n_samples, n_accurates = 0, 0
    with torch.no_grad():
        for labels, texts, offsets in dataloader:
            outputs = model(texts, offsets)
            n_accurates += (torch.round(outputs)==labels).sum().item() 
            n_samples += labels.size(0)
    return n_accurates/n_samples

In [63]:
# Define some hyperparameters
LR = 0.6
N_EPOCHS = 50
vocab_size = len(vocab)
embed_dim = 100
classifier = BinaryTextClassifier(vocab_size, embed_dim)

# Criterion, Optimizer, learning rate scheduler
criterion = torch.nn.BCELoss()
optimizer = torch.optim.SGD(classifier.parameters(), lr=LR)
scheduler = StepLR(optimizer, step_size=1, gamma=0.9)

for epoch in range(1, N_EPOCHS + 1):
    train(train_dataloader, classifier)
    train_acc = evaluate(train_dataloader, classifier)
    valid_acc = evaluate(valid_dataloader, classifier)
    scheduler.step()
    print(f"| Epoch: {epoch}/{N_EPOCHS} | train_accuracy : {train_acc: .3f} | val_accuracy : {valid_acc: .3f}")

# Test with test set
accu_test = evaluate(test_dataloader, classifier)
print("="*60)
print(f"Test Accuracy: {accu_test: .3f}")

| Epoch: 1/50 | train_accuracy :  0.657 | val_accuracy :  0.650
| Epoch: 2/50 | train_accuracy :  0.672 | val_accuracy :  0.661
| Epoch: 3/50 | train_accuracy :  0.689 | val_accuracy :  0.688
| Epoch: 4/50 | train_accuracy :  0.698 | val_accuracy :  0.698
| Epoch: 5/50 | train_accuracy :  0.707 | val_accuracy :  0.708
| Epoch: 6/50 | train_accuracy :  0.713 | val_accuracy :  0.716
| Epoch: 7/50 | train_accuracy :  0.719 | val_accuracy :  0.722
| Epoch: 8/50 | train_accuracy :  0.722 | val_accuracy :  0.730
| Epoch: 9/50 | train_accuracy :  0.728 | val_accuracy :  0.734
| Epoch: 10/50 | train_accuracy :  0.731 | val_accuracy :  0.735
| Epoch: 11/50 | train_accuracy :  0.733 | val_accuracy :  0.738
| Epoch: 12/50 | train_accuracy :  0.736 | val_accuracy :  0.740
| Epoch: 13/50 | train_accuracy :  0.739 | val_accuracy :  0.743
| Epoch: 14/50 | train_accuracy :  0.740 | val_accuracy :  0.745
| Epoch: 15/50 | train_accuracy :  0.743 | val_accuracy :  0.747
| Epoch: 16/50 | train_accuracy : 