In [1]:
import os
import sys

In [2]:
# Make the directory two levels above the notebook's working directory
# importable, so that the `src.*` imports used by this notebook resolve.
# The entry is stored as an absolute path: a relative entry such as '../..'
# is interpreted against the process working directory when the import
# system uses it, so a later os.chdir() could break or redirect imports.
project_root_dir = os.path.abspath(os.path.join(os.curdir, '..', '..'))
if project_root_dir not in sys.path:  # keeps re-running this cell idempotent
    sys.path.append(project_root_dir)

In [3]:
from src.data.dataload import *
from src.models.bcn_model import *

# Data<br>
Loading data

In [4]:
# Load the dataset wrapper; its NAME, SENTENCE and TARGET attributes are
# reported in the message below and reused by the later cells.
data = load_sst()
print(f'loading data {data.NAME} (sentence column: {data.SENTENCE}, target column: {data.TARGET})')
# Unpack the three splits exposed by the wrapper's train_val_test attribute.
train, val, test = data.train_val_test
# Bare last expression: shows the training split as the cell output.
train

loading data sst (sentence column: sentence, target column: label)


Unnamed: 0,sentence,label
0,The Rock is destined to be the 21st Century 's...,3
1,The gorgeously elaborate continuation of `` Th...,4
2,Singer/composer Bryan Adams contributes a slew...,3
3,You 'd think by now America would have had eno...,2
4,Yet the act is still charming here .,3
...,...,...
8539,A real snooze .,0
8540,No surprises .,1
8541,We 've seen the hippie-turned-yuppie plot befo...,3
8542,Her fans walked out muttering words like `` ho...,0


# Model<br>
Fine-tuning and loading the Biattentive Classification Network (BCN)

In [5]:
# Create the BCN wrapper and prepare its model for the dataset loaded above.
# NOTE(review): whether load_model() fine-tunes from scratch or restores saved
# weights is not visible from this notebook — confirm in the BCNModel source.
bcn = BCNModel()
bcn.load_model(data)
# Bare last expression: shows the wrapped model's module structure.
bcn.model

BiattentiveClassificationNetwork(
  (_text_field_embedder): BasicTextFieldEmbedder(
    (token_embedder_tokens): Embedding()
  )
  (_embedding_dropout): Dropout(p=0.25, inplace=False)
  (_pre_encode_feedforward): FeedForward(
    (_activations): ModuleList(
      (0): ReLU()
    )
    (_linear_layers): ModuleList(
      (0): Linear(in_features=300, out_features=300, bias=True)
    )
    (_dropout): ModuleList(
      (0): Dropout(p=0.25, inplace=False)
    )
  )
  (_encoder): LstmSeq2SeqEncoder(
    (_module): LSTM(300, 300, batch_first=True, bidirectional=True)
  )
  (_integrator): LstmSeq2SeqEncoder(
    (_module): LSTM(1800, 300, batch_first=True, bidirectional=True)
  )
  (_integrator_dropout): Dropout(p=0.1, inplace=False)
  (_self_attentive_pooling_projection): Linear(in_features=600, out_features=1, bias=True)
  (_output_layer): Maxout(
    (_linear_layers): ModuleList(
      (0): Linear(in_features=2400, out_features=4800, bias=True)
      (1): Linear(in_features=1200, out_featu

# Prediction

In [6]:
# Single-sentence prediction on the test example labelled 0.
# The sentence column is resolved through data.SENTENCE (not a hard-coded
# attribute name) so the cell keeps working if the dataset wrapper changes.
# The model is invoked exactly once; only this printed call produces output.
print(f'Individual prediction for {data.NAME}')
print(bcn.predict(test[data.SENTENCE][0]))

Individual prediction for sst
     logits  class_probabilities label
0  0.620440             0.016874     1
1  4.414236             0.749643     1
2  2.922851             0.168715     1
3 -3.867865             0.000190     1
4  1.962503             0.064577     1


In [7]:
print(f'Batch prediction for {data.NAME}')
# DataFrame entry point: receives the first 100 test rows plus the name of the
# sentence column.
# NOTE(review): the return value is discarded, so nothing from this call is
# displayed — presumably kept as a smoke test of the API; confirm, or show it.
bcn.predict_batch_df(test[:100], input_col=data.SENTENCE)
# predict_batch receives only the sentence column of the same 100 rows; its
# result is what appears in the cell output.
print(bcn.predict_batch(test[data.SENTENCE][:100]))

Batch prediction for sst
                                               logits  \
0   [0.6204397678375244, 4.414236068725586, 2.9228...   
1   [1.6763371229171753, -0.020200669765472412, 0....   
2   [1.2694412469863892, 1.169205665588379, 0.8316...   
3   [1.9059851169586182, -0.21952968835830688, 0.4...   
4   [3.5735950469970703, -2.2221274375915527, 0.29...   
..                                                ...   
95  [1.1532161235809326, -0.2188836634159088, 0.24...   
96  [1.259469747543335, 1.6719794273376465, 1.1938...   
97  [3.3070015907287598, -1.3257590532302856, 0.48...   
98  [0.9012832641601562, 1.8801426887512207, 1.477...   
99  [3.3964028358459473, -0.5213455557823181, 0.84...   

                                  class_probabilities label  
0   [0.016874458640813828, 0.7496432662010193, 0.1...     1  
1   [0.5133780837059021, 0.09411099553108215, 0.15...     3  
2   [0.3514782786369324, 0.31795579195022583, 0.22...     3  
3   [0.5327665209770203, 0.063597142696380

In [8]:
# Label prediction for the test example labelled 0.
# The sentence column is resolved through data.SENTENCE, matching the other
# prediction cells, instead of a hard-coded `.sentence` attribute.
print(f'Individual label prediction for {data.NAME}')
print(bcn.predict_label(test[data.SENTENCE][0]))

Individual label prediction for sst
[1]


In [9]:
# Label prediction for the first 100 test rows through both batch entry points.
print(f'Batched label prediction for {data.NAME}')
# NOTE(review): the return value of the DataFrame variant is discarded, so it
# contributes nothing to the output — presumably a smoke test; confirm.
bcn.predict_label_batch_df(test[:100], input_col=data.SENTENCE)
# The sentence column is resolved through data.SENTENCE, consistent with the
# input_col argument above, instead of a hard-coded `.sentence` attribute.
print(bcn.predict_label_batch(test[data.SENTENCE][:100]))

Batched label prediction for sst
[1 3 3 3 4 1 3 1 3 4 3 3 4 1 3 1 4 3 4 4 4 3 3 3 3 3 3 3 1 0 4 1 1 3 1 3 3
 3 1 1 3 1 3 4 3 1 4 3 3 3 3 1 3 3 3 1 3 1 4 3 3 3 3 3 3 1 3 4 4 3 1 3 4 1
 3 1 3 4 3 3 3 3 3 4 3 3 3 3 3 1 3 3 1 3 1 3 1 4 1 3]
