In [1]:
import os
import sys
# Put the directory two levels above the current working directory on the
# import path (only once) so the `src.*` imports that follow can be resolved.
project_root_dir = os.path.relpath(os.path.join('..', '..'), start=os.curdir)
if project_root_dir not in sys.path:
    sys.path.append(project_root_dir)
from src.data.dataload import *
from src.models.bertmodel import *
from pprint import pprint

# Data<br>
Loading data

In [2]:
# Load the dataset wrapper. The print below reads its NAME, SENTENCE and TARGET
# attributes (dataset name and the sentence / target column names).
# NOTE(review): load_sst presumably comes from the wildcard import of
# src.data.dataload — confirm.
data = load_sst()
print(f'loading data {data.NAME} (sentence column: {data.SENTENCE}, target column: {data.TARGET})')
# `train_val_test` unpacks into the three splits used by the rest of the notebook.
train, val, test = data.train_val_test
# Bare last expression so the notebook renders the training split.
train

loading data sst (sentence column: sentence, target column: label)


Unnamed: 0,sentence,label
0,The Rock is destined to be the 21st Century 's...,3
1,The gorgeously elaborate continuation of `` Th...,4
2,Singer/composer Bryan Adams contributes a slew...,3
3,You 'd think by now America would have had eno...,2
4,Yet the act is still charming here .,3
...,...,...
8539,A real snooze .,0
8540,No surprises .,1
8541,We 've seen the hippie-turned-yuppie plot befo...,3
8542,Her fans walked out muttering words like `` ho...,0


# Model<br>
Fine-tuning and loading BERT

In [3]:
# Instantiate the model wrapper.
# NOTE(review): presumably the project class from src.models.bertmodel (wildcard
# import), not the transformers class of the same name — confirm.
bert = BertModel()
# Prepare the model for this dataset. Whether this fine-tunes or restores saved
# weights is not visible from this cell — check `load_model`.
bert.load_model(data)
# Bare last expression so the notebook renders the underlying network.
bert.model

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

Prediction

In [4]:
print(f'Individual prediction for {data.NAME}')
# Run a single prediction for the test sentence labelled 0. The column is
# looked up through data.SENTENCE so this cell does not depend on the column
# being literally named `sentence`, and inference runs only once (a bare
# `bert.predict(...)` before the print would compute a result that is never shown).
print(bert.predict(test[data.SENTENCE][0]))

100%|██████████| 1/1 [00:00<00:00, 25.11it/s]
100%|██████████| 1/1 [00:00<00:00, 65.59it/s]

Individual prediction for sst
     logits  class_probabilities  label
0 -0.986764             0.040054      2
1  1.235412             0.369594      2
2  1.532937             0.497667      2
3 -0.238388             0.084657      2
4 -2.594063             0.008028      2





In [5]:
print(f'Batch prediction for {data.NAME}')
# Exercise the DataFrame-based batch API on the first 100 test rows. The return
# value is discarded, so nothing from this call is displayed.
# NOTE(review): unclear from here whether this call modifies `test` — confirm.
bert.predict_batch_df(test[:100], input_col=data.SENTENCE)
# Same rows through the batch API that takes the sentence column directly;
# printed so the result table appears in the cell output.
print(bert.predict_batch(test[data.SENTENCE][:100]))

  0%|          | 0/4 [00:00<?, ?it/s]

Batch prediction for sst


100%|██████████| 4/4 [00:00<00:00,  4.54it/s]
100%|██████████| 4/4 [00:00<00:00,  4.70it/s]

                                               logits  \
0   [-0.9867643713951111, 1.2354116439819336, 1.53...   
1   [-1.7541735172271729, -0.8248212337493896, 0.3...   
2   [-2.8735034465789795, -1.9821761846542358, -0....   
3   [-2.72558331489563, -1.7797045707702637, 0.189...   
4   [-2.49272084236145, -2.249992847442627, -0.629...   
..                                                ...   
95  [-2.1240527629852295, -0.612809419631958, 0.46...   
96  [-0.2899676263332367, 1.4170408248901367, 1.39...   
97  [-2.727703809738159, -2.3031370639801025, -0.2...   
98  [-2.707550048828125, -1.5871951580047607, 0.29...   
99  [-2.750370740890503, -2.013284921646118, -0.08...   

                                  class_probabilities  label  
0   [0.04005402699112892, 0.3695940375328064, 0.49...      2  
1   [0.027608176693320274, 0.06992786377668381, 0....      3  
2   [0.0026741919573396444, 0.006520651280879974, ...      3  
3   [0.00502944178879261, 0.012951194308698177, 0....      3  





In [6]:
# Label-only predictions. The sentence column is always looked up through
# data.SENTENCE (matching the `input_col` argument below) instead of the
# hard-coded `.sentence` attribute.
print(f'Individual label prediction for {data.NAME}')
print(bert.predict_label(test[data.SENTENCE][0]))
print(f'Batched label prediction for {data.NAME}')
# DataFrame-based variant on the first 100 test rows; its return value is not displayed.
bert.predict_label_batch_df(test[:100], input_col=data.SENTENCE)
print(bert.predict_label_batch(test[data.SENTENCE][:100]))

100%|██████████| 1/1 [00:00<00:00, 63.44it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

Individual label prediction for sst
[2]
Batched label prediction for sst


100%|██████████| 4/4 [00:00<00:00,  4.70it/s]
100%|██████████| 4/4 [00:00<00:00,  4.73it/s]

[2 3 3 3 4 1 1 3 3 4 3 3 4 2 3 3 4 3 3 4 4 3 3 3 3 3 3 1 3 3 4 3 1 3 1 3 2
 3 1 1 3 1 3 4 3 2 4 4 3 3 3 3 3 3 3 3 3 3 3 4 3 3 3 1 3 3 3 4 4 3 1 4 4 3
 3 2 4 3 3 2 3 3 3 3 4 3 3 3 3 3 2 3 1 1 3 3 1 3 3 3]





In [7]:
# Class-probability predictions. The sentence column is always looked up
# through data.SENTENCE (matching the `input_col` argument below) instead of
# the hard-coded `.sentence` attribute.
print(f'Individual proba prediction for {data.NAME}')
print(bert.predict_proba(test[data.SENTENCE][0]))
# This banner introduces the batched *probability* output printed below.
print(f'Batched proba prediction for {data.NAME}')
# DataFrame-based variant on the first 10 test rows; its return value is not displayed.
bert.predict_proba_batch_df(test[:10], input_col=data.SENTENCE)
print(bert.predict_proba_batch(test[data.SENTENCE][:10]))

100%|██████████| 1/1 [00:00<00:00, 65.54it/s]
100%|██████████| 1/1 [00:00<00:00, 70.40it/s]
100%|██████████| 1/1 [00:00<00:00, 67.19it/s]


Individual proba prediction for sst
0    0.040054
1    0.369594
2    0.497667
3    0.084657
4    0.008028
Name: class_probabilities, dtype: float64
Batched label prediction for sst
[[0.04005403 0.36959404 0.49766704 0.08465688 0.00802796]
 [0.02760818 0.06992786 0.23198433 0.56439441 0.10608514]
 [0.00267419 0.00652065 0.04303093 0.68525881 0.26251534]
 [0.00502944 0.01295119 0.09282301 0.74466848 0.14452784]
 [0.00272763 0.00347697 0.01757069 0.48324239 0.49298233]
 [0.08925641 0.36362293 0.35126275 0.15365337 0.04220459]
 [0.0871759  0.6077832  0.26014316 0.03931797 0.00557976]
 [0.00996889 0.07069953 0.35954407 0.51221049 0.04757705]
 [0.00326375 0.00511813 0.03281545 0.65560311 0.30319953]
 [0.00429091 0.00317768 0.01184258 0.25880146 0.72188741]]


In [8]:
# Raw-logit predictions. The sentence column is always looked up through
# data.SENTENCE (matching the `input_col` argument below) instead of the
# hard-coded `.sentence` attribute.
print(f'Individual logits prediction for {data.NAME}')
print(bert.predict_logits(test[data.SENTENCE][0]))
print(f'Batched logits prediction for {data.NAME}')
# DataFrame-based variant on the first 10 test rows; its return value is not displayed.
bert.predict_logits_batch_df(test[:10], input_col=data.SENTENCE)
print(bert.predict_logits_batch(test[data.SENTENCE][:10]))

100%|██████████| 1/1 [00:00<00:00, 63.84it/s]
100%|██████████| 1/1 [00:00<00:00, 67.55it/s]
100%|██████████| 1/1 [00:00<00:00, 67.18it/s]

Individual logits prediction for sst
0   -0.986764
1    1.235412
2    1.532937
3   -0.238388
4   -2.594063
Name: logits, dtype: float64
Batched logits prediction for sst





[[-0.98676437  1.23541164  1.53293765 -0.23838727 -2.59406304]
 [-1.75417352 -0.82482123  0.37438425  1.26346779 -0.40804356]
 [-2.87350345 -1.98217618 -0.09523159  2.67264605  1.71315897]
 [-2.72558331 -1.77970457  0.18980189  2.27204657  0.63257962]
 [-2.49272084 -2.24999285 -0.62992013  2.68436575  2.70432067]
 [-0.74530673  0.65929747  0.62471455 -0.20212078 -1.49429095]
 [-0.15909223  1.78279817  0.93421203 -0.9553383  -2.90787506]
 [-2.47936702 -0.52039778  1.10600019  1.45989907 -0.91648608]
 [-2.6689229  -2.21900988 -0.3609007   2.63375568  1.86259115]
 [-1.92020428 -2.2205534  -0.90500206  2.17935753  3.20516539]]
