In [1]:
import os

import numpy as np
import torch
from transformers import \
    BertForSequenceClassification, \
    BertTokenizer

In [2]:
# Switch the working directory to the repository root so the `src` package
# imported in the next cell can be resolved (assumes this notebook lives two
# levels below that root).
# The guard keeps the cell idempotent: without it, re-running the cell in a
# live kernel would climb two more directories each time and break the
# `src.*` imports and any relative paths used afterwards.
if not os.path.isdir('src'):
    os.chdir('../..')

In [3]:
from src.data.dataload import load_sst, load_agnews
from src.models.bert_utils import PreTrainedBERT

In [4]:
# Prefer the first CUDA device when one is available; otherwise fall back to
# the CPU. The chosen device is printed so the run records where inference ran.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cpu


## SST

In [5]:
# Load the SST dataset through the project's loader (src.data.dataload).
sst = load_sst()

In [6]:
# Unpack the three splits exposed by the dataset object's `train_val_test`
# attribute, in train / validation / test order.
train_sst, val_sst, test_sst = sst.train_val_test

In [7]:
# Report the training split's (rows, columns), then preview its first rows
# (the recorded output shows `sentence` and `label` columns).
print(train_sst.shape)
train_sst.head()

(8544, 2)


Unnamed: 0,sentence,label
0,The Rock is destined to be the 21st Century 's...,3
1,The gorgeously elaborate continuation of `` Th...,4
2,Singer/composer Bryan Adams contributes a slew...,3
3,You 'd think by now America would have had eno...,2
4,Yet the act is still charming here .,3


In [8]:
# Build the project's BERT wrapper for the 'sst' dataset on the selected device.
# NOTE(review): presumably this loads weights fine-tuned on SST — confirm in
# src.models.bert_utils.
bert_sst = PreTrainedBERT(device=device, dataset='sst')

In [9]:
# Display the tokenizer attached to the SST wrapper to check its configuration.
bert_sst.tokenizer

PreTrainedTokenizer(name_or_path='bert-base-uncased', vocab_size=30522, model_max_len=512, is_fast=False, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

In [10]:
# Run the SST model on the first ten training sentences as a smoke test;
# `predict` hands back the raw logits and the corresponding probabilities.
logits, probs = bert_sst.predict(sentence_array=train_sst['sentence'].head(10))

100%|██████████| 1/1 [00:01<00:00,  1.29s/it]


In [11]:
# Inspect the raw scores: one row per input sentence, one column per class.
logits

array([[-2.1788163 , -2.1071174 , -0.62480575,  2.2194335 ,  2.7322145 ],
       [-1.7536162 , -2.044077  , -1.0482408 ,  1.7289542 ,  3.2294014 ],
       [-3.0301542 , -1.2249986 ,  0.51971126,  2.2774696 ,  0.24015394],
       [-0.5998924 ,  0.95011234,  0.81817615, -0.26388645, -2.0391357 ],
       [-3.2010188 , -1.7842047 ,  0.50896597,  2.8780212 ,  1.0130019 ],
       [-2.6354063 , -2.338395  , -0.34518975,  2.6789994 ,  2.1916533 ],
       [-2.2030933 , -2.1841588 , -0.58300143,  1.9689634 ,  2.2782674 ],
       [-2.9441814 , -1.8061488 ,  0.03783898,  2.6421356 ,  1.1196613 ],
       [-2.1810088 , -0.6658555 ,  0.05744424,  1.4511299 ,  0.8648205 ],
       [-1.4431202 ,  0.7792491 ,  1.176192  ,  0.20800184, -1.9727913 ]],
      dtype=float32)

In [12]:
# Inspect the per-class probabilities returned alongside the logits.
probs

array([[0.00446643, 0.00479843, 0.02112802, 0.36315843, 0.6064487 ],
       [0.00548758, 0.00410426, 0.01111018, 0.17858365, 0.8007144 ],
       [0.00370262, 0.02251533, 0.12888317, 0.7474479 , 0.09745093],
       [0.08713014, 0.41051298, 0.3597723 , 0.12192532, 0.02065917],
       [0.00181747, 0.00749518, 0.0742495 , 0.793526  , 0.1229118 ],
       [0.0029385 , 0.00395472, 0.02902376, 0.59723115, 0.36685184],
       [0.00623923, 0.00635849, 0.03153029, 0.40460595, 0.5512661 ],
       [0.002867  , 0.00894681, 0.05655904, 0.7647751 , 0.16685206],
       [0.01355933, 0.06169656, 0.1271704 , 0.51245534, 0.2851183 ],
       [0.03360545, 0.31015047, 0.46127784, 0.17517936, 0.01978684]],
      dtype=float32)

In [13]:
# Sanity check: the highest-scoring class must be the same whether it is read
# from the logits or from the probabilities.
np.testing.assert_array_equal(np.argmax(logits, axis=1), np.argmax(probs, axis=1))

In [14]:
# Predicted class index for each sentence = column holding the largest logit.
labels = np.argmax(logits, axis=1)
labels

array([4, 4, 3, 1, 3, 3, 4, 3, 3, 2])

## AG News

In [15]:
# Load the AG News dataset through the project's loader (src.data.dataload).
ag = load_agnews()

In [16]:
# Unpack the three splits exposed by the dataset object's `train_val_test`
# attribute, in train / validation / test order.
train_ag, val_ag, test_ag = ag.train_val_test

Using custom data configuration default
Reusing dataset ag_news (/Users/stevengeorge/.cache/huggingface/datasets/ag_news/default/0.0.0/17ec33e23df9e89565131f989e0fdf78b0cc4672337b582da83fc3c9f79fe34d)


In [17]:
# Report the training split's (rows, columns), then preview its first rows
# (the recorded output shows `sentence`, `label` and `title` columns).
print(train_ag.shape)
train_ag.head()

(108000, 3)


Unnamed: 0,sentence,label,title
0,"Reuters - Short-sellers, Wall Street's dwindli...",2,Wall St. Bears Claw Back Into the Black (Reuters)
1,Reuters - Private investment firm Carlyle Grou...,2,Carlyle Looks Toward Commercial Aerospace (Reu...
2,Reuters - Soaring crude prices plus worriesabo...,2,Oil and Economy Cloud Stocks' Outlook (Reuters)
3,Reuters - Authorities have halted oil exportfl...,2,Iraq Halts Oil Exports from Main Southern Pipe...
4,"AFP - Tearaway world oil prices, toppling reco...",2,"Oil prices soar to all-time record, posing new..."


In [18]:
# Build the project's BERT wrapper for the 'agn' (AG News) dataset on the
# selected device.
# NOTE(review): presumably this loads weights fine-tuned on AG News — confirm
# in src.models.bert_utils.
bert_ag = PreTrainedBERT(device=device, dataset='agn')

In [19]:
# Display the tokenizer attached to the AG News wrapper to check its configuration.
bert_ag.tokenizer

PreTrainedTokenizer(name_or_path='bert-base-uncased', vocab_size=30522, model_max_len=512, is_fast=False, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

In [20]:
# Run the AG News model on the first ten training sentences as a smoke test;
# `predict` hands back the raw logits and the corresponding probabilities.
# Note that this rebinds `logits` / `probs` from the SST section above.
logits, probs = bert_ag.predict(sentence_array=train_ag['sentence'].head(10))

100%|██████████| 1/1 [00:07<00:00,  7.61s/it]


In [21]:
# Inspect the raw scores: one row per input sentence, one column per class.
logits

array([[-0.43939516, -3.8122914 ,  4.701639  , -1.4323643 ],
       [-2.5380313 , -3.0589104 ,  5.5920086 , -0.8595964 ],
       [-1.8376439 , -3.129275  ,  5.5971174 , -1.6488357 ],
       [ 2.7521205 , -3.7608418 ,  2.64813   , -2.8582907 ],
       [-0.15808287, -3.7708144 ,  4.61434   , -1.863463  ],
       [-1.4470259 , -3.915394  ,  5.4003773 , -1.0820162 ],
       [-0.33370143, -3.6063373 ,  4.796135  , -1.9481618 ],
       [-1.5505748 , -3.378364  ,  5.5192003 , -1.669473  ],
       [-1.9751476 , -3.1773708 ,  5.6033697 , -1.3229743 ],
       [-1.7016826 , -3.5161715 ,  5.6646147 , -1.500557  ]],
      dtype=float32)

In [22]:
# Inspect the per-class probabilities returned alongside the logits.
probs

array([[5.8039259e-03, 1.9901767e-04, 9.9184680e-01, 2.1502092e-03],
       [2.9395448e-04, 1.7460846e-04, 9.9795663e-01, 1.5747633e-03],
       [5.8950600e-04, 1.6200921e-04, 9.9853647e-01, 7.1201060e-04],
       [5.2455509e-01, 7.7848008e-04, 4.7274676e-01, 1.9196430e-03],
       [8.3742291e-03, 2.2592035e-04, 9.8987818e-01, 1.5216254e-03],
       [1.0593691e-03, 8.9752932e-05, 9.9732482e-01, 1.5260509e-03],
       [5.8745299e-03, 2.2268214e-04, 9.9273384e-01, 1.1690197e-03],
       [8.4894529e-04, 1.3648380e-04, 9.9826080e-01, 7.5377681e-04],
       [5.1047822e-04, 1.5341162e-04, 9.9835616e-01, 9.7996951e-04],
       [6.3125265e-04, 1.0284442e-04, 9.9849403e-01, 7.7188207e-04]],
      dtype=float32)

In [23]:
# Sanity check: the highest-scoring class must be the same whether it is read
# from the logits or from the probabilities.
np.testing.assert_array_equal(np.argmax(logits, axis=1), np.argmax(probs, axis=1))

In [24]:
# Predicted class index for each sentence = column holding the largest logit.
labels = np.argmax(logits, axis=1)
labels

array([2, 2, 2, 0, 2, 2, 2, 2, 2, 2])