In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchtext import data, datasets

In [2]:
BATCH_SIZE = 64
EPOCHS = 10
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda" if USE_CUDA else "cpu")
LR = 0.001

In [3]:
TEXT = data.Field(sequential=True, batch_first=True, lower=True)
LABEL = data.Field(sequential=False, batch_first=True)
trainset, testset = datasets.IMDB.splits(TEXT, LABEL)

In [4]:
print(repr(TEXT))
print(repr(LABEL))

<torchtext.data.field.Field object at 0x0000019276340348>
<torchtext.data.field.Field object at 0x0000019276340308>


In [5]:
trainset, valset = trainset.split(split_ratio=0.8)
train_iter, val_iter, test_iter = data.BucketIterator.splits(
    (trainset, valset, testset), batch_size=BATCH_SIZE,
    shuffle=True, repeat=False
)

In [6]:
TEXT.build_vocab(trainset, min_freq=5)
LABEL.build_vocab(trainset)

In [7]:
vocab_size = len(TEXT.vocab)
n_classes = 2

In [8]:
print(f"Trainset: {len(trainset)}, Validationset: {len(valset)}, Testset: {len(testset)}, VocabSize: {vocab_size}")

Trainset: 20000, Validationset: 5000, Testset: 25000, VocabSize: 39899


In [9]:
class BasicGRU(nn.Module):
    def __init__(self, n_layers, hidden_dim, n_vocab, embed_dim, n_classes, dropout_p=0.2):
        super().__init__()
        self.n_layers = n_layers
        self.embed = nn.Embedding(n_vocab, embed_dim)
        self.hidden_dim = hidden_dim
        self.dropout = nn.Dropout(dropout_p)
        self.gru = nn.GRU(embed_dim, self.hidden_dim,
                          num_layers=self.n_layers,
                          batch_first=True)
        self.out = nn.Linear(self.hidden_dim, n_classes)
    
    def forward(self, x):
        x = self.embed(x)
        h_0 = self._init_state(batch_size=x.size(0))
        x, _ = self.gru(x, h_0)
        h_t = x[:, -1, :]
        self.dropout(h_t)
        logit = self.out(h_t)
        return logit
    
    def _init_state(self, batch_size = 1):
        weight = next(self.parameters()).data
        return weight.new(self.n_layers, batch_size, self.hidden_dim).zero_()


In [10]:
def train(model, optimizer, train_iter):
    model.train()
    for batch in train_iter:
        x, y = batch.text.to(DEVICE), batch.label.to(DEVICE)
        y.data.sub_(1)
        optimizer.zero_grad()
        
        logit = model(x)
        loss = F.cross_entropy(logit, y)
        loss.backward()
        optimizer.step()

def evaluate(model, val_iter):
    model.eval()
    corrects, total_loss = 0, 0
    for batch in val_iter:
        x, y = batch.text.to(DEVICE), batch.label.to(DEVICE)
        y.data.sub_(1)
        
        logit = model(x)
        loss = F.cross_entropy(logit, y, reduction="sum")
        total_loss += loss.item()
        corrects += (logit.max(1)[1].view(y.size()).data == y.data).sum()
    size = len(val_iter.dataset)
    avg_loss = total_loss / size
    avg_acc = 100.0 * corrects / size
    return avg_loss, avg_acc
    

In [11]:
model = BasicGRU(n_layers=1, hidden_dim=256, n_vocab=vocab_size, embed_dim=128, n_classes=n_classes, dropout_p=0.5).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

In [12]:
best_val_loss = None
for e in range(1, EPOCHS+1):
    train(model, optimizer, train_iter)
    val_loss, val_acc = evaluate(model, val_iter)
    
    print(f"Epoch {e}: Val loss: {val_loss:5.2f} | Val acc: {val_acc:5.2f}")
    
    if best_val_loss is None or val_loss < best_val_loss:
        best_val_loss = val_loss
        if not os.path.isdir("snapshot"):
            os.makedirs("snapshot")
        torch.save(model.state_dict(), "./snapshot/imdb.pt")

Epoch 1: Val loss:  0.70 | Val acc: 49.54
Epoch 2: Val loss:  0.70 | Val acc: 51.72
Epoch 3: Val loss:  0.83 | Val acc: 51.22
Epoch 4: Val loss:  0.52 | Val acc: 76.74
Epoch 5: Val loss:  0.37 | Val acc: 84.66
Epoch 6: Val loss:  0.32 | Val acc: 86.56
Epoch 7: Val loss:  0.35 | Val acc: 86.06
Epoch 8: Val loss:  0.39 | Val acc: 86.26
Epoch 9: Val loss:  0.42 | Val acc: 85.38
Epoch 10: Val loss:  0.40 | Val acc: 86.70


In [13]:
model.load_state_dict(torch.load("./snapshot/imdb.pt"))
test_loss, test_acc = evaluate(model, test_iter)
print(f"Test loss: {test_loss:5.2f} | Test acc: {test_acc:5.2f}")

Test loss:  0.33 | Test acc: 86.57
