In [1]:
from torchtext import data
from torchtext import datasets
from torchtext.vocab import GloVe

import torch
import torch.nn as nn
import torch.nn.functional as F
from ignite.engine import Engine, Events
from ignite.metrics import Accuracy, Loss, RunningAverage
from ignite.handlers import ModelCheckpoint, EarlyStopping
from ignite.contrib.handlers import ProgressBar
from ignite.utils import convert_tensor

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
TEXT = data.Field(lower=True, include_lengths=True, batch_first=True)
LABEL = data.Field(sequential=False)

train, test = datasets.IMDB.splits(TEXT, LABEL)
train_iter, test_iter = data.BucketIterator.splits((train, test), batch_size=8, device="cuda:0")

TEXT.build_vocab(train, vectors=GloVe(name='6B', dim=300))
LABEL.build_vocab(train)

In [4]:
class TextCNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, kernel_sizes, num_filters, num_classes, d_prob):
        super(TextCNN, self).__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.kernel_sizes = kernel_sizes
        self.num_filters = num_filters
        self.num_classes = num_classes
        self.d_prob = d_prob
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.conv = nn.ModuleList([nn.Conv1d(in_channels=embedding_dim,
                                             out_channels=num_filters,
                                             kernel_size=k, stride=1) for k in kernel_sizes])
        self.dropout = nn.Dropout(d_prob)
        self.fc = nn.Linear(len(kernel_sizes) * num_filters, num_classes)

    def forward(self, x):
        batch_size, sequence_length = x.shape
        x = self.embedding(x).transpose(1, 2)
        x = [F.relu(conv(x)) for conv in self.conv]
        x = [F.max_pool1d(c, c.size(-1)).squeeze(dim=-1) for c in x]
        x = torch.cat(x, dim=1)
        x = self.fc(self.dropout(x))
        if self.num_classes > 1:
            return F.log_softmax(x, dim=1)
        else:
            return torch.sigmoid(x).squeeze()

In [5]:
vocab_size, embedding_dim = TEXT.vocab.vectors.shape

model = TextCNN(vocab_size=vocab_size,
                embedding_dim=embedding_dim,
                kernel_sizes=[3, 4, 5],
                num_filters=100,
                num_classes=1, d_prob=0.3)

mode = 'static'

if 'static' in mode:
    model.embedding.weight.data.copy_(TEXT.vocab.vectors)
    if 'non' not in mode:
        model.embedding.weight.data.requires_grad = False
else:
    pass
    
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.BCELoss()

In [6]:
def process_function(engine, batch):
    model.to(device)
    model.train()
    optimizer.zero_grad()
    x, y = batch.text[0], batch.label.float() - 1
    y_pred = model(x)
    loss = criterion(y_pred, y)
    loss.backward()
    optimizer.step()
    return loss.item()

def eval_function(engine, batch):
    model.to(device)
    model.eval()
    with torch.no_grad():
        x, y = batch.text[0], batch.label.float() - 1
        y_pred = model(x)
        return y_pred, y

def thresholded_output_transform(output):
    y_pred, y = output
    y_pred = torch.round(y_pred)
    return y_pred, y

In [7]:
trainer = Engine(process_function)
RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')
pbar = ProgressBar(persist=True)
pbar.attach(trainer, ['loss'])

evaluator = Engine(eval_function)
Accuracy(output_transform=thresholded_output_transform).attach(evaluator, 'accuracy')
Loss(criterion).attach(evaluator, 'bce')

@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(engine):
    evaluator.run(train_iter)
    metrics = evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    avg_bce = metrics['bce']
    pbar.log_message(
        "Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
        .format(engine.state.epoch, avg_accuracy, avg_bce))

@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
    evaluator.run(test_iter)
    metrics = evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    avg_bce = metrics['bce']
    pbar.log_message(
        "Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
        .format(engine.state.epoch, avg_accuracy, avg_bce))
    pbar.n = pbar.last_print_n = 0

In [8]:
trainer.run(train_iter, max_epochs=10)

Epoch [1/10]: [3125/3125] 100%|██████████, loss=3.16e-01 [01:19<00:00]


Training Results - Epoch: 1  Avg accuracy: 0.96 Avg loss: 0.12


Epoch [2/10]: [5/3125]   0%|          , loss=1.10e-01 [00:00<01:05]

Validation Results - Epoch: 1  Avg accuracy: 0.89 Avg loss: 0.27


Epoch [2/10]: [3125/3125] 100%|██████████, loss=1.39e-01 [01:19<00:00]


Training Results - Epoch: 2  Avg accuracy: 1.00 Avg loss: 0.02


Epoch [3/10]: [6/3125]   0%|          , loss=2.45e-02 [00:00<01:02]

Validation Results - Epoch: 2  Avg accuracy: 0.87 Avg loss: 0.35


Epoch [3/10]: [3125/3125] 100%|██████████, loss=3.84e-02 [01:19<00:00]


Training Results - Epoch: 3  Avg accuracy: 1.00 Avg loss: 0.01


Epoch [4/10]: [5/3125]   0%|          , loss=5.95e-05 [00:00<01:06]

Validation Results - Epoch: 3  Avg accuracy: 0.86 Avg loss: 0.58


Epoch [4/10]: [3125/3125] 100%|██████████, loss=4.00e-03 [01:19<00:00]


Training Results - Epoch: 4  Avg accuracy: 1.00 Avg loss: 0.00


Epoch [5/10]: [6/3125]   0%|          , loss=2.17e-04 [00:00<01:03]

Validation Results - Epoch: 4  Avg accuracy: 0.85 Avg loss: 0.78


Epoch [5/10]: [218/3125]   7%|▋         , loss=1.03e-03 [00:05<01:13]

KeyboardInterrupt: 

In [None]:
def eval_f(batch):
    model.to(device)
    model.eval()
    with torch.no_grad():
        x, y = batch.text[0], batch.label.float()
        y_pred = model(x)
        return y_pred, y

In [None]:
eval_f(batch)