In [1]:
!pip install transformers==4.6.1



In [2]:
# Download the dev split and save it as ./dev.txt
# (the URL path points at the Stanford Sentiment Treebank tree files).
!curl https://s3.amazonaws.com/realworldnlpbook/data/stanfordSentimentTreebank/trees/dev.txt --output dev.txt

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  274k  100  274k    0     0   804k      0 --:--:-- --:--:-- --:--:--  801k


In [3]:
# Download the training split and save it as ./train.txt
# (the URL path points at the Stanford Sentiment Treebank tree files).
!curl https://s3.amazonaws.com/realworldnlpbook/data/stanfordSentimentTreebank/trees/train.txt --output train.txt

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 2109k  100 2109k    0     0  4646k      0 --:--:-- --:--:-- --:--:-- 4646k


In [4]:
import re

import torch
from torch import nn, optim
from transformers import AutoTokenizer, AutoModel, AdamW, get_cosine_schedule_with_warmup

In [5]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [6]:
# Name of the pretrained checkpoint; used for both the tokenizer and the classifier's encoder.
BERT_MODEL = 'bert-base-cased'

In [7]:
# Load the tokenizer for BERT_MODEL (its files are downloaded on first use).
tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL)

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/213k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/436k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

In [8]:
class BertClassifier(nn.Module):
    """Sequence classifier: a pretrained encoder followed by a single linear layer.

    The encoder's pooled output is projected to ``num_labels`` logits. When
    gold labels are passed to ``forward`` a cross-entropy loss is computed too.
    """

    def __init__(self, model_name, num_labels):
        """Load the pretrained encoder ``model_name`` and build the classification head.

        Args:
            model_name: Checkpoint name or path given to ``AutoModel.from_pretrained``.
            num_labels: Number of output classes (size of the logits' last dimension).
        """
        super().__init__()
        self.bert_model = AutoModel.from_pretrained(model_name)

        # The head's input width follows the encoder's configured hidden size.
        hidden_size = self.bert_model.config.hidden_size
        self.linear = nn.Linear(hidden_size, num_labels)

        self.loss_function = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, token_type_ids, label=None):
        """Classify a batch of tokenized inputs.

        Args:
            input_ids, attention_mask, token_type_ids: Passed through to the
                encoder unchanged (the tensors produced by the tokenizer).
            label: Optional gold class indices; when given, the loss is computed.

        Returns:
            A ``(loss, logits)`` tuple. ``loss`` is ``None`` when ``label`` is ``None``.
        """
        encoded = self.bert_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        logits = self.linear(encoded.pooler_output)

        # Inference path: there is nothing to compare the logits against.
        if label is None:
            return None, logits

        return self.loss_function(logits, label), logits

In [9]:
# Encode one sentence into vocabulary ids; encode() also adds the [CLS] and [SEP] special tokens.
token_ids = tokenizer.encode('The best movie ever!')

In [10]:
token_ids

[101, 1109, 1436, 2523, 1518, 106, 102]

In [11]:
tokenizer.decode(token_ids)

'[CLS] The best movie ever! [SEP]'

In [12]:
# Tokenize a small batch, padding/truncating every sequence to exactly 10 tokens.
# `padding='max_length'` is the supported spelling of the deprecated
# `pad_to_max_length=True` and produces the same tensors without the warning.
result = tokenizer(
    ['The best movie ever!', 'Aweful movie'],
    max_length=10,
    padding='max_length',
    truncation=True,
    return_tensors='pt')



In [13]:
result

{'input_ids': tensor([[ 101, 1109, 1436, 2523, 1518,  106,  102,    0,    0,    0],
        [ 101,  138, 7921, 2365, 2523,  102,    0,    0,    0,    0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]])}

In [14]:
result['input_ids']

tensor([[ 101, 1109, 1436, 2523, 1518,  106,  102,    0,    0,    0],
        [ 101,  138, 7921, 2365, 2523,  102,    0,    0,    0,    0]])

In [15]:
result['token_type_ids']

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

In [16]:
result['attention_mask']

tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]])

In [17]:
def _encode_batch(texts, labels, tokenizer, max_length):
    """Tokenize ``texts`` into one fixed-length batch and attach ``labels`` under the 'label' key."""
    batch = tokenizer(
        texts,
        max_length=max_length,
        # Supported replacement for the deprecated `pad_to_max_length=True`.
        padding='max_length',
        truncation=True,
        return_tensors='pt')
    batch['label'] = torch.tensor(labels)
    return batch


def read_dataset(file_path, batch_size, tokenizer, max_length):
    """Read a file of bracketed sentiment trees and return a list of tokenized batches.

    Each non-blank line is expected to look like ``(3 (2 It) (4 ...))``: the
    character right after the opening parenthesis is the sentence-level label.
    The tree markup is stripped to recover the plain sentence, and ``-LRB-`` /
    ``-RRB-`` placeholders are turned back into literal parentheses.

    Args:
        file_path: Path of the tree file, read as UTF-8.
        batch_size: Number of sentences per batch; the last batch may be smaller.
        tokenizer: Callable accepting ``(texts, max_length=..., padding=...,
            truncation=..., return_tensors=...)`` and returning a mutable mapping.
        max_length: Length every sequence is padded/truncated to.

    Returns:
        A list of tokenizer outputs, each with an extra ``'label'`` tensor.

    Raises:
        ValueError: If the second character of a non-blank line is not a digit.
    """
    batches = []
    texts = []
    labels = []
    # Explicit encoding: the platform default is not guaranteed to be UTF-8.
    with open(file_path, encoding='utf-8') as f:
        for line in f:
            tree = line.strip()
            if not tree:
                # Blank (e.g. trailing) lines carry no example; indexing them would fail.
                continue

            label = int(tree[1])
            # Drop every "(<digit> " node opener, then every run of closing parentheses.
            text = re.sub(r'\)+', '', re.sub(r'\(\d ', '', tree))
            # Restore real parentheses only after the structural ones are gone.
            text = text.replace('-LRB-', '(').replace('-RRB-', ')')

            texts.append(text)
            labels.append(label)

            if len(texts) == batch_size:
                batches.append(_encode_batch(texts, labels, tokenizer, max_length))
                texts = []
                labels = []

    # Flush the final, possibly smaller, batch.
    if texts:
        batches.append(_encode_batch(texts, labels, tokenizer, max_length))

    return batches

In [18]:
# Pre-tokenize both splits into lists of ready-made batches:
# 32 sentences per batch, every sequence padded/truncated to 128 tokens.
train_data = read_dataset('train.txt', batch_size=32, tokenizer=tokenizer, max_length=128)
dev_data = read_dataset('dev.txt', batch_size=32, tokenizer=tokenizer, max_length=128)

In [19]:
len(train_data), len(dev_data)

(267, 35)

In [20]:
def move_to(batch, device):
    """Move every value of ``batch`` to ``device``, updating the mapping in place.

    Each value is replaced by the result of its ``.to(device)`` call.
    Nothing is returned.
    """
    # Snapshot the items so the mapping is not being iterated while it is written to.
    for name, value in list(batch.items()):
        batch[name] = value.to(device)

In [21]:
# Build the classifier on top of BERT_MODEL and place it on `device`.
# num_labels=5: presumably one class per sentiment level (read_dataset parses the
# label as a single digit) — confirm against the data.
model = BertClassifier(model_name=BERT_MODEL, num_labels=5).to(device)

Downloading:   0%|          | 0.00/436M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [22]:
# Smoke test before any fine-tuning: push one dev batch through the model and
# display the (loss, logits) pair it returns.
# Note: move_to replaces the tensors of dev_data[0] in place, so that batch stays on `device`.
move_to(dev_data[0], device)
model(**dev_data[0])

(tensor(1.7918, device='cuda:0', grad_fn=<NllLossBackward>),
 tensor([[-1.3490e-01, -2.0231e-01,  1.0244e+00, -1.2546e-01,  5.7981e-02],
         [-2.0607e-02, -3.4415e-01,  9.9455e-01, -2.5073e-01, -1.3179e-02],
         [-1.9296e-02, -3.8154e-01,  1.0262e+00, -2.9049e-01, -6.5379e-02],
         [-1.4048e-01, -1.7977e-01,  1.0589e+00, -1.6484e-01,  8.2360e-02],
         [-4.8007e-02, -1.9766e-01,  9.7215e-01, -1.9844e-01,  1.0912e-01],
         [-8.7912e-02, -1.2186e-01,  1.0786e+00, -1.2343e-01,  2.1999e-02],
         [-3.2315e-02, -2.4036e-01,  1.0648e+00, -2.6812e-01, -9.6161e-03],
         [-9.7170e-02, -2.2085e-01,  9.9116e-01, -1.7828e-01,  4.6678e-02],
         [-1.3612e-01, -2.1868e-01,  9.5757e-01, -1.4523e-01,  1.0221e-01],
         [ 1.0864e-02, -3.7965e-01,  9.8214e-01, -2.5621e-01, -2.0941e-02],
         [-8.7822e-02, -2.9356e-01,  9.3817e-01, -1.9523e-01,  1.0290e-01],
         [-1.1786e-02, -3.7825e-01,  9.4710e-01, -2.7533e-01, -4.2856e-02],
         [-1.0005e-01, -1.8

In [23]:
# Fine-tuning schedule: AdamW with a cosine learning-rate schedule that warms up
# over the first 1000 steps. The scheduler is stepped once per training batch,
# so the total number of steps is (batches per epoch) * (epochs).
epochs = 30
total_steps = epochs * len(train_data)

optimizer = AdamW(params=model.parameters(), lr=1e-5)
scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=1000,
    num_training_steps=total_steps,
)

In [24]:
def _run_epoch(model, batches, device, optimizer=None, scheduler=None):
    """Run one full pass over ``batches`` and return ``(mean_loss, accuracy)``.

    Args:
        model: Callable module returning ``(loss, logits)`` for ``model(**batch)``.
        batches: Sequence of batches; each has 'input_ids' and 'label' entries.
        device: Device every batch is moved to before the forward pass.
        optimizer: When given, the pass is a training pass: the model is put in
            train mode and updated after every batch. When ``None``, the model
            is put in eval mode and no gradients are tracked.
        scheduler: Optional learning-rate scheduler, stepped once per training batch.

    Returns:
        The mean of the per-batch losses and the fraction of correctly
        classified instances, both as plain Python floats.
    """
    training = optimizer is not None
    model.train(training)

    losses = []
    total_instances = 0
    correct_instances = 0
    for batch in batches:
        move_to(batch, device)

        if training:
            optimizer.zero_grad()

        # Gradient tracking is only needed when the weights will be updated.
        with torch.set_grad_enabled(training):
            loss, logits = model(**batch)

        if training:
            loss.backward()
            optimizer.step()
            if scheduler is not None:
                scheduler.step()

        # Store a plain float: keeping the loss tensor itself would keep every
        # batch's autograd graph alive until the end of the epoch.
        losses.append(loss.item())

        total_instances += batch['input_ids'].size(0)
        correct_instances += torch.sum(torch.argmax(logits, dim=-1) == batch['label']).item()

    return sum(losses) / len(losses), correct_instances / total_instances


for epoch in range(epochs):
    print(f'epoch = {epoch}')

    avr_loss, accuracy = _run_epoch(model, train_data, device, optimizer, scheduler)
    print(f'train loss = {avr_loss}, accuracy = {accuracy}')

    avr_loss, accuracy = _run_epoch(model, dev_data, device)
    print(f'dev loss = {avr_loss}, accuracy = {accuracy}')

epoch = 0
train loss = 1.6137793064117432, accuracy = 0.24367977528089887
dev loss = 1.660991907119751, accuracy = 0.259763851044505
epoch = 1
train loss = 1.4579780101776123, accuracy = 0.36739232209737827
dev loss = 1.6124693155288696, accuracy = 0.2851952770208901
epoch = 2
train loss = 1.2845360040664673, accuracy = 0.43469101123595505
dev loss = 1.3206464052200317, accuracy = 0.4150772025431426
epoch = 3
train loss = 1.0848054885864258, accuracy = 0.5149812734082397
dev loss = 1.3720569610595703, accuracy = 0.4223433242506812
epoch = 4
train loss = 0.9170249700546265, accuracy = 0.6067415730337079
dev loss = 1.2367522716522217, accuracy = 0.49227974568574023
epoch = 5
train loss = 0.7437633275985718, accuracy = 0.7036516853932584
dev loss = 1.3211069107055664, accuracy = 0.5131698455949137
epoch = 6
train loss = 0.5801069140434265, accuracy = 0.7814840823970037
dev loss = 1.469807744026184, accuracy = 0.4895549500454133
epoch = 7
train loss = 0.4505126178264618, accuracy = 0.83988