In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification # pip install transformers
from datasets import load_dataset # pip install datasets

  from .autonotebook import tqdm as notebook_tqdm


Download data to fine-tune the Bert model

In [2]:
# load the first 1K rows of the TREC dataset
trec = load_dataset('trec', split='train[:1000]')
trec

Found cached dataset trec (/Users/yunusskeete/.cache/huggingface/datasets/trec/default/2.0.0/f2469cab1b5fceec7249fda55360dfdbd92a7a5b545e91ea0f78ad108ffac1c2)


Dataset({
    features: ['text', 'coarse_label', 'fine_label'],
    num_rows: 1000
})

In [3]:
trec[0]

{'text': 'How did serfdom develop in and then leave Russia ?',
 'coarse_label': 2,
 'fine_label': 26}

Use the bert-base-uncased tokenizer from Hugging Face

In [4]:
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# we have a small dataset so we can tokenize everything at once
# tokenize everything
tokens = tokenizer(
    trec['text'], max_length=512,
    truncation=True, padding='max_length'
)

In [5]:
tokens[:2]

[Encoding(num_tokens=512, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing]),
 Encoding(num_tokens=512, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])]

In [6]:
tokens[0].ids[:10], tokens[0].ids[10:20], tokens[0].ids[-10:]

([101, 2129, 2106, 14262, 2546, 9527, 4503, 1999, 1998, 2059],
 [2681, 3607, 1029, 102, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

One-hot encode labels

In [7]:
import numpy as np

# initialize array to be used
labels = np.zeros(
    (len(trec), max(trec['coarse_label'])+1)
)
# one-hot encode
labels[np.arange(len(trec)), trec['coarse_label']] = 1
labels[:5]

array([[0., 0., 1., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0.]])

In [8]:
labels = torch.Tensor(labels)

Create the dataset object

In [9]:
class TrecDataset(torch.utils.data.Dataset):
    def __init__(self, tokens, labels):
        self.tokens = tokens
        self.labels = labels

    def __getitem__(self, idx):
        input_ids = self.tokens[idx].ids
        attention_mask = self.tokens[idx].attention_mask
        labels = self.labels[idx]
        return {
            'input_ids': torch.tensor(input_ids),
            'attention_mask': torch.tensor(attention_mask),
            'labels': torch.tensor(labels)
        }

    def __len__(self):
        return len(self.labels)

dataset = TrecDataset(tokens, labels)

Create the data loader

In [10]:
loader = torch.utils.data.DataLoader(
    dataset, batch_size=32
)

In [11]:
from transformers import BertForSequenceClassification, BertConfig

config = BertConfig.from_pretrained('bert-base-uncased')
config.num_labels = max(trec['coarse_label'])+1
model = BertForSequenceClassification(config) 

Train

In [12]:
from transformers import AdamW
from tqdm.auto import tqdm

# activate training mode of model
model.train()

# initialize adam optimizer with weight decay
optim = AdamW(model.parameters(), lr=5e-5)



In [13]:
from time import time
from tqdm.auto import tqdm

loop_time = []

# setup loop (using tqdm for the progress bar)
loop = tqdm(loader, leave=True)
for batch in loop:
    t0 = time()
    # initialize calculated gradients (from prev step)
    optim.zero_grad()
    # train model on batch and return outputs (incl. loss)
    outputs = model(**batch)
    # extract loss
    loss = outputs[0]
    # calculate loss for every parameter that needs grad update
    loss.backward()
    # update parameters
    optim.step()
    loop_time.append(time()-t0)
    # print relevant info to progress bar
    loop.set_postfix(loss=loss.item())

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


  'labels': torch.tensor(labels)
  6%|▋         | 2/32 [05:08<1:18:43, 157.46s/it, loss=0.452]

In [None]:
loop_time