In [1]:
import os

import torch
from transformers import \
    BertForSequenceClassification, \
    BertTokenizer

In [2]:
# Work from the directory two levels above the one the kernel started in
# (presumably the repository root), so that the `src` imports and the relative
# `models/...` paths used below resolve.
# The starting directory is remembered on the first run so that re-running this
# cell lands in the same place instead of climbing two more levels each time.
if 'NOTEBOOK_START_DIR' not in globals():
    NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_START_DIR, os.pardir, os.pardir))

In [3]:
from src.data.dataload import load_sst, load_agnews
from src.models.bert_utils import \
    SST_MAX_LENGTH, \
    SST_BERT_HYPERPARAMETERS, \
    make_predictions, \
    AGN_MAX_LENGTH, \
    AGN_BERT_HYPERPARAMETERS

In [4]:
# Use the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cpu


## BERT on SST

In [5]:
# Load the pretrained `bert-base-uncased` tokenizer; `do_lower_case=True` makes it
# lower-case text before tokenizing, matching the uncased checkpoint.
tokenizer = BertTokenizer.from_pretrained(
    'bert-base-uncased',
    do_lower_case=True,
)

In [6]:
# Load the sequence-classification model from the local checkpoint directory.
# The path is relative, so it depends on the working directory set by the earlier
# `os.chdir` cell.
bert_sst = BertForSequenceClassification.from_pretrained(
    "models/fine-tuned-bert-base-sst",
)

In [7]:
%%capture
bert_sst.to(device)

In [8]:
# Load the SST dataset object via the project helper; its `train_val_test`
# attribute is unpacked into splits in the next cell.
sst = load_sst()

In [9]:
# Pull the three splits out of the dataset object; presumably ordered
# train / validation / test, as the attribute and target names suggest.
sst_splits = sst.train_val_test
train_sst, val_sst, test_sst = sst_splits

In [10]:
# Preview the first five training rows (five is also the default for `head`).
train_sst.head(n=5)

Unnamed: 0,sentence,label
0,The Rock is destined to be the 21st Century 's...,3
1,The gorgeously elaborate continuation of `` Th...,4
2,Singer/composer Bryan Adams contributes a slew...,3
3,You 'd think by now America would have had eno...,2
4,Yet the act is still charming here .,3


In [11]:
# Smoke-test the SST model on a small sample: only the first ten training rows.
sst_sample = train_sst.head(10)
predictions = make_predictions(
    df=sst_sample,
    model=bert_sst,
    tokenizer=tokenizer,
    sentence_col_name='sentence',  # text column shown in the preview above
    device=device,
    max_length=SST_MAX_LENGTH,
    hyperparameter_dict=SST_BERT_HYPERPARAMETERS,
)

100%|██████████| 1/1 [00:01<00:00,  1.28s/it]


In [12]:
# Display the result of the previous cell; as the bare last expression of the
# cell, Jupyter echoes its value.
predictions

array([4, 4, 3, 1, 3, 3, 4, 3, 3, 2])

## AG News

In [13]:
# Re-create the tokenizer with the same arguments as in the section above, so
# `tokenizer` is (re)bound here before the AG News model is used.
tokenizer = BertTokenizer.from_pretrained(
    'bert-base-uncased',
    do_lower_case=True,
)

In [14]:
# Load the sequence-classification model from the local AG News checkpoint
# directory. The path is relative, so it depends on the working directory set by
# the earlier `os.chdir` cell.
bert_ag = BertForSequenceClassification.from_pretrained(
    "models/fine-tuned-bert-base-agn/",
)

In [15]:
%%capture
bert_ag.to(device)

In [16]:
# Load the AG News dataset object via the project helper; its `train_val_test`
# attribute is unpacked into splits in the next cell.
ag = load_agnews()

In [17]:
# Pull the three splits out of the dataset object; presumably ordered
# train / validation / test, as the attribute and target names suggest.
# NOTE(review): the Hugging Face cache messages appear under this cell rather than
# the `load_agnews()` cell, so the data seems to be read on this attribute access.
ag_splits = ag.train_val_test
train_ag, val_ag, test_ag = ag_splits

Using custom data configuration default
Reusing dataset ag_news (/Users/stevengeorge/.cache/huggingface/datasets/ag_news/default/0.0.0/17ec33e23df9e89565131f989e0fdf78b0cc4672337b582da83fc3c9f79fe34d)


In [18]:
# Preview the first five training rows (five is also the default for `head`).
train_ag.head(n=5)

Unnamed: 0,sentence,label,title
0,"Reuters - Short-sellers, Wall Street's dwindli...",2,Wall St. Bears Claw Back Into the Black (Reuters)
1,Reuters - Private investment firm Carlyle Grou...,2,Carlyle Looks Toward Commercial Aerospace (Reu...
2,Reuters - Soaring crude prices plus worriesabo...,2,Oil and Economy Cloud Stocks' Outlook (Reuters)
3,Reuters - Authorities have halted oil exportfl...,2,Iraq Halts Oil Exports from Main Southern Pipe...
4,"AFP - Tearaway world oil prices, toppling reco...",2,"Oil prices soar to all-time record, posing new..."


In [20]:
# Smoke-test the AG News model on a small sample: only the first ten training rows.
# Note that this rebinds `predictions`, replacing the SST result assigned earlier.
ag_sample = train_ag.head(10)
predictions = make_predictions(
    df=ag_sample,
    model=bert_ag,
    tokenizer=tokenizer,
    sentence_col_name='sentence',  # text column shown in the preview above
    device=device,
    max_length=AGN_MAX_LENGTH,
    hyperparameter_dict=AGN_BERT_HYPERPARAMETERS,
)

100%|██████████| 1/1 [00:07<00:00,  7.63s/it]


In [21]:
# Display the AG News result; `predictions` was rebound in the previous cell, so
# the SST predictions stored earlier under the same name are no longer available.
predictions

array([2, 2, 2, 0, 2, 2, 2, 2, 2, 2])