In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
from torch.autograd import Variable

from IPython.display import clear_output
!pip install transformers
!pip install datasets
clear_output()

#Prepare Dataset

In [2]:
from transformers import AutoTokenizer
from datasets import load_dataset

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
train_dataset = load_dataset('trec', split='train')
train_dataset = train_dataset.map(
    lambda e: tokenizer(e['text'], truncation=True, padding='max_length'),
    batched=True
    )
test_dataset = load_dataset('trec', split='test')
test_dataset = test_dataset.map(
    lambda e: tokenizer(e['text'], truncation=True, padding='max_length'),
    batched=True
    )
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=64)
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=64)

Using custom data configuration default
Reusing dataset trec (/root/.cache/huggingface/datasets/trec/default/1.1.0/1902c380fe66cc215f989888b1b35e8da7e79a3a97520f00dce753fd1f8f5c48)
Loading cached processed dataset at /root/.cache/huggingface/datasets/trec/default/1.1.0/1902c380fe66cc215f989888b1b35e8da7e79a3a97520f00dce753fd1f8f5c48/cache-a128e73c7344c66a.arrow
Using custom data configuration default
Reusing dataset trec (/root/.cache/huggingface/datasets/trec/default/1.1.0/1902c380fe66cc215f989888b1b35e8da7e79a3a97520f00dce753fd1f8f5c48)
Loading cached processed dataset at /root/.cache/huggingface/datasets/trec/default/1.1.0/1902c380fe66cc215f989888b1b35e8da7e79a3a97520f00dce753fd1f8f5c48/cache-628b05de24021c11.arrow


# Build model 

In [3]:
class TextCNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, 
                 dropout, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
        self.convs = nn.ModuleList([nn.Conv1d(in_channels=embedding_dim, out_channels=n_filters, kernel_size=fs) for fs in filter_sizes])
        self.fc = nn.Linear(len(filter_sizes)*n_filters, output_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, text):
        embs = self.embedding(text)
        embs = embs.permute(0, 2, 1)
        out = [F.relu(c(embs)) for c in self.convs]
        out_pool = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in out]
        cat = self.dropout(torch.cat(out_pool, dim=1))
        final = self.fc(cat)
        return final

In [4]:
#set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
model = TextCNN(vocab_size=tokenizer.vocab_size,
                embedding_dim=100,
                n_filters=8,
                filter_sizes=[3,4,5],
                output_dim=6,
                dropout=0.1,
                pad_idx=tokenizer.pad_token_id)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
model = model.to(device)
criterion = criterion.to(device)

In [6]:
def trec_accuracy(preds, y):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
    """
    max_preds = preds.argmax(dim=1, keepdim=True) # get the index of the max probability
    correct = max_preds.squeeze(1).eq(y)
    correct = correct.detach().to('cpu')
    return correct.sum() / torch.FloatTensor([y.shape[0]])

In [7]:
def train(model, iterator, optimizer, criterion, tokenizer):
    global device
    epoch_loss = 0
    epoch_acc = 0
    model.train()
    half = len(iterator) // 2 
    for i, batch in enumerate(iterator):
        if i <= half:
            batch_ = torch.stack(batch['input_ids'], dim=0).permute(1, 0)
            batch_ = batch_.to(device)
            optimizer.zero_grad()
            predictions = model(batch_)
            loss = criterion(predictions, batch['label-coarse'].long().to(device))
            acc = trec_accuracy(predictions, batch['label-coarse'].to(device))
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_acc += acc.item()
        else:
            break
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

In [8]:
def evaluate(model, iterator, criterion, tokenizer):
    global device
    epoch_loss = 0
    epoch_acc = 0
    model.eval()
    with torch.no_grad():
        for batch in iterator:
            batch_ = torch.stack(batch['input_ids'], dim=0).permute(1, 0)
            batch_ = batch_.to(device)
            predictions = model(batch_)
            loss = criterion(predictions, batch['label-coarse'].long().to(device))
            acc = trec_accuracy(predictions, batch['label-coarse'].to(device))
            epoch_loss += loss.item()
            epoch_acc += acc.item()
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

In [9]:
import time
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [10]:
N_EPOCHS = 50
best_valid_loss = float('inf')
losses = []
for epoch in range(N_EPOCHS):
    start_time = time.time()
    train_loss, train_acc = train(model, trainloader, optimizer, criterion, tokenizer)
    valid_loss, valid_acc = evaluate(model, testloader, criterion, tokenizer)
    losses.append(train_loss)
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'textcnn_trec.pt')
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

    if epoch >= 3:
        if train_loss >= losses[-1] and train_loss >= losses[-2] and train_loss >= losses[-3]:
            print('Early stopping')
            break 

Epoch: 01 | Epoch Time: 0m 2s
	Train Loss: 0.807 | Train Acc: 16.84%
	 Val. Loss: 1.337 |  Val. Acc: 53.17%
Epoch: 02 | Epoch Time: 0m 2s
	Train Loss: 0.606 | Train Acc: 30.12%
	 Val. Loss: 1.031 |  Val. Acc: 68.24%
Epoch: 03 | Epoch Time: 0m 2s
	Train Loss: 0.480 | Train Acc: 35.30%
	 Val. Loss: 0.854 |  Val. Acc: 72.13%
Epoch: 04 | Epoch Time: 0m 2s
	Train Loss: 0.400 | Train Acc: 37.74%
	 Val. Loss: 0.756 |  Val. Acc: 76.13%
Epoch: 05 | Epoch Time: 0m 2s
	Train Loss: 0.347 | Train Acc: 39.34%
	 Val. Loss: 0.696 |  Val. Acc: 77.10%
Epoch: 06 | Epoch Time: 0m 2s
	Train Loss: 0.301 | Train Acc: 41.37%
	 Val. Loss: 0.657 |  Val. Acc: 77.00%
Epoch: 07 | Epoch Time: 0m 2s
	Train Loss: 0.262 | Train Acc: 42.82%
	 Val. Loss: 0.620 |  Val. Acc: 78.71%
Epoch: 08 | Epoch Time: 0m 2s
	Train Loss: 0.228 | Train Acc: 44.20%
	 Val. Loss: 0.597 |  Val. Acc: 78.56%
Epoch: 09 | Epoch Time: 0m 2s
	Train Loss: 0.187 | Train Acc: 45.97%
	 Val. Loss: 0.581 |  Val. Acc: 78.47%
Epoch: 10 | Epoch Time: 0m 2