# Train a sentiment classifier with torchtext, BERT, and skorch

This notebook is based on [another notebook](https://github.com/bentrevett/pytorch-sentiment-analysis/blob/master/6%20-%20Transformers%20for%20Sentiment%20Analysis.ipynb); see that notebook for more details.

<table align="left"><td>
<a target="_blank" href="https://colab.research.google.com/github/skorch-dev/skorch/blob/master/notebooks/torchtext_bert.ipynb">
    <img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>  
</td><td>
<a target="_blank" href="https://github.com/skorch-dev/skorch/blob/master/notebooks/torchtext_bert.ipynb"><img width=32px src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a></td></table>

**Note**: If you are running this in [a colab notebook](https://colab.research.google.com/github/skorch-dev/skorch/blob/master/notebooks/torchtext_bert.ipynb), we recommend enabling a free GPU via:

> **Runtime**   →   **Change runtime type**   →   **Hardware Accelerator: GPU**

If you are running in Colab, install the dependencies by running the following cell (the IMDB dataset and the pretrained BERT weights are downloaded later, when they are first loaded):

## Install packages

In [1]:
! [ ! -z "$COLAB_GPU" ] && pip install torch torchtext transformers skorch

## Imports

In [0]:
import random

In [0]:
import torch
import torchtext
from torch import nn
from torchtext.data import Field, LabelField
from torchtext.data import BucketIterator
from torchtext.datasets import IMDB
from transformers import BertTokenizer
from transformers import BertModel
from skorch import NeuralNetClassifier
from skorch.callbacks import Freezer
from skorch.callbacks import ProgressBar

## Constants

In [0]:
SEED = 0
# Reviews longer than this (in BERT tokens, including [CLS]/[SEP]) are truncated.
# 512 is also the largest input length bert-base supports (its position
# embedding table has 512 rows); smaller values train faster.
MAX_SEQ_LEN = 512

# Seed every RNG the notebook relies on: Python's `random` (my_split below also
# reseeds it before splitting), torch's CPU generator, and the CUDA generator.
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
# Ask cuDNN for deterministic kernels (may be slower).
torch.backends.cudnn.deterministic = True

## Load data

When running this notebook for the first time, loading data and the pretrained model will take a couple of minutes.

In [0]:
# Tokenizer for the pretrained 'bert-base-uncased' checkpoint; the BertModel
# loaded further below uses the same checkpoint name, so token ids line up.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [0]:
def tokenize_and_cut(sentence):
    """Split ``sentence`` into BERT sub-word tokens, truncated to fit MAX_SEQ_LEN.

    Two positions are reserved for the [CLS] and [SEP] ids that the TEXT field
    adds as its init/eos tokens, so at most ``MAX_SEQ_LEN - 2`` tokens are kept.
    """
    return tokenizer.tokenize(sentence)[:MAX_SEQ_LEN - 2]

In [0]:
# Field for the review text. The BERT tokenizer itself turns tokens into ids
# (``preprocessing``), so torchtext builds no vocabulary (``use_vocab=False``)
# and every special token is supplied as a BERT token id rather than a string.
TEXT = Field(
    tokenize=tokenize_and_cut,
    preprocessing=tokenizer.convert_tokens_to_ids,
    use_vocab=False,
    batch_first=True,  # batches come out as [batch size, sent len]
    init_token=tokenizer.cls_token_id,
    eos_token=tokenizer.sep_token_id,
    pad_token=tokenizer.pad_token_id,
    unk_token=tokenizer.unk_token_id,
)

In [0]:
# Field for the sentiment label, numericalized as int64 tensors. Its
# label -> index vocabulary is built from ds_train in a later cell.
LABEL = LabelField(dtype=torch.int64)

In [9]:
%%time
# Load IMDB's train and test splits, processing the review text with TEXT and
# the label with LABEL. This is slow; presumably most of the time goes to
# running TEXT's BERT tokenization/preprocessing on every review.
ds_train, ds_test = IMDB.splits(TEXT, LABEL)

CPU times: user 3min 19s, sys: 1.34 s, total: 3min 20s
Wall time: 3min 20s


In [0]:
# Build LABEL's label -> index mapping from the training split only.
LABEL.build_vocab(ds_train)

In [0]:
# Pretrained BERT encoder (same checkpoint as the tokenizer). It is passed into
# the model below, where it is run without gradients and frozen via Freezer.
bert = BertModel.from_pretrained('bert-base-uncased')

## Model definition

In [0]:
class BERTGRUSentiment(nn.Module):
    """Sentiment classifier: frozen BERT embeddings -> (bi)GRU -> linear -> softmax.

    BERT is run under ``torch.no_grad()``, so it only supplies contextual token
    embeddings; the GRU and the output layer form the trainable head.

    Parameters
    ----------
    bert : pretrained BERT model
        Called as ``bert(text)``; element ``[0]`` of its output is used as the
        per-token embeddings, whose size is read from ``bert.config.hidden_size``.
    hidden_dim : int
        Hidden size of the GRU (per direction).
    output_dim : int
        Number of output classes.
    n_layers : int
        Number of stacked GRU layers.
    bidirectional : bool
        Whether the GRU runs in both directions.
    dropout : float
        Dropout applied to the final hidden state, and between GRU layers when
        ``n_layers >= 2`` (PyTorch's GRU only applies inter-layer dropout to
        stacked layers).
    """

    def __init__(
            self,
            bert,
            hidden_dim,
            output_dim,
            n_layers,
            bidirectional,
            dropout
    ):
        super().__init__()

        self.bert = bert
        embedding_dim = bert.config.hidden_size
        self.rnn = nn.GRU(
            embedding_dim,
            hidden_dim,
            num_layers=n_layers,
            bidirectional=bidirectional,
            batch_first=True,
            # inter-layer dropout is only meaningful for stacked GRUs
            dropout=0 if n_layers < 2 else dropout,
        )

        self.dropout = nn.Dropout(dropout)
        num_directions = 2 if bidirectional else 1
        self.out = nn.Linear(hidden_dim * num_directions, output_dim)
        self.sm = nn.Softmax(dim=-1)

    def forward(self, text):
        """Return class probabilities for a batch of token ids.

        ``text`` is expected as [batch size, sent len] (the TEXT field uses
        batch_first). The result has shape [batch size, output_dim] and each
        row sums to 1: probabilities, not logits, because the net is built with
        skorch's NeuralNetClassifier default criterion (NLLLoss), which takes
        the log of probabilities itself.
        """
        # BERT is frozen: don't build its autograd graph (saves memory and time).
        with torch.no_grad():
            embedded = self.bert(text)[0]
        # embedded = [batch size, sent len, emb dim]

        _, hidden = self.rnn(embedded)
        # hidden = [n layers * n directions, batch size, hid dim]

        if self.rnn.bidirectional:
            # last layer's forward (-2) and backward (-1) final states
            final_state = torch.cat((hidden[-2], hidden[-1]), dim=1)
        else:
            final_state = hidden[-1]
        hidden = self.dropout(final_state)
        # hidden = [batch size, hid dim * n directions]

        output = self.out(hidden)
        # output = [batch size, out dim]

        return self.sm(output)

In [0]:
# Model hyper-parameters, passed to BERTGRUSentiment through the module__* arguments.
# GRU head:
HIDDEN_DIM = 256       # GRU hidden size per direction
N_LAYERS = 2           # number of stacked GRU layers
BIDIRECTIONAL = True   # run the GRU in both directions
DROPOUT = 0.25         # on the final hidden state, and between GRU layers
# Classifier output:
OUTPUT_DIM = 2         # number of output classes

## Custom code

In [0]:
class SkorchBucketIterator(BucketIterator):
    """BucketIterator adapted to skorch's batch protocol.

    torchtext yields Batch objects, while skorch expects every batch to be an
    ``(X, y)`` pair. Each batch is therefore unpacked into its ``text`` tensor
    (X) and its ``label`` tensor converted with ``.long()`` (y).
    """

    def __iter__(self):
        for torchtext_batch in super().__iter__():
            X = torchtext_batch.text
            y = torchtext_batch.label.long()
            yield X, y

In [0]:
def my_split(dataset, y, seed=SEED, split_ratio=0.7):
    """Split ``dataset`` into skorch's internal training and validation sets.

    Used as the net's ``train_split``: skorch passes the training dataset and
    ``y`` and expects ``(dataset_train, dataset_valid)`` back.

    Parameters
    ----------
    dataset : torchtext Dataset
        Dataset to split; its ``split`` method does the work.
    y : ignored
        Unused; the labels live inside ``dataset``.
    seed : int
        Seed controlling the shuffle performed by ``dataset.split``.
    split_ratio : float
        Fraction of ``dataset`` used for training; the remaining
        ``1 - split_ratio`` (30% by default) is used for validation.
    """
    # torchtext's ``split`` takes an RNG *state* (as returned by
    # ``random.getstate()``), not a seed. Seeding the global ``random`` module
    # is deliberate: it also makes later torchtext shuffling reproducible.
    random.seed(seed)
    return dataset.split(split_ratio=split_ratio, random_state=random.getstate())

## Define and train neural net

In [0]:
net = NeuralNetClassifier(
    module=BERTGRUSentiment,
    module__bert=bert,
    module__hidden_dim=HIDDEN_DIM,
    module__output_dim=OUTPUT_DIM,
    module__n_layers=N_LAYERS,
    module__bidirectional=BIDIRECTIONAL,
    module__dropout=DROPOUT,

    optimizer=torch.optim.Adam,

    # torchtext batches are unpacked into (X, y) pairs for skorch
    iterator_train=SkorchBucketIterator,
    iterator_valid=SkorchBucketIterator,
    # hold out part of the training data for per-epoch validation
    train_split=my_split,

    callbacks=[
        # don't update the pretrained bert model parameters
        Freezer(['bert*']),
        # each epoch takes many minutes on colab (~19 min per epoch in the run
        # below); uncomment the next line to see a progress bar
        # ProgressBar(batches_per_epoch=len(ds_train) // 128 + 1),
    ],

    # Use the GPU when there is one; fall back to CPU instead of raising an
    # error on CPU-only kernels (training on CPU works but is very slow).
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

In [17]:
# We can pass y=None: the labels are stored inside the torchtext dataset and
# come out of SkorchBucketIterator as the y part of each batch.
net.fit(ds_train, y=None)

  epoch    train_loss    valid_acc    valid_loss        dur
-------  ------------  -----------  ------------  ---------
      1        [36m0.7713[0m       [32m0.8348[0m        [35m0.3604[0m  1145.7862
      2        [36m0.3791[0m       [32m0.8864[0m        [35m0.2884[0m  1150.6875
      3        [36m0.3496[0m       [32m0.8881[0m        [35m0.2822[0m  1148.2985
      4        0.3603       [32m0.8900[0m        [35m0.2790[0m  1147.4423
      5        [36m0.3465[0m       [32m0.8977[0m        [35m0.2656[0m  1146.9673
      6        0.3468       0.8861        0.2969  1149.5763
      7        [36m0.3455[0m       0.8787        0.2821  1148.0665
      8        [36m0.3348[0m       0.8929        0.2704  1147.7167
      9        0.3455       [32m0.8981[0m        0.2776  1148.3576
     10        [36m0.3329[0m       [32m0.8988[0m        0.2670  1150.3590


<class 'skorch.classifier.NeuralNetClassifier'>[initialized](
  module_=BERTGRUSentiment(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_feature