## Prerequisites and presentation functions

## Based on: https://medium.com/swlh/a-simple-guide-on-using-bert-for-text-classification-bbf041ac8d04 (the version used for binary classification in mlt)

In [19]:
# Python ≥3.5 is required
import sys
assert sys.version_info >= (3, 5)

# Scikit-Learn ≥0.20 is required
import re
import sklearn
# Compare (major, minor) as integers. Comparing the raw version strings is
# lexicographic, so e.g. "0.3" would wrongly be accepted as >= "0.20".
_sklearn_major_minor = tuple(
    int(part) for part in re.match(r"(\d+)\.(\d+)", sklearn.__version__).groups()
)
assert _sklearn_major_minor >= (0, 20)

# Common imports
import pandas as pd
import numpy as np
import os

# figure plotting
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
# Default font sizes for axis labels and tick labels on every figure.
mpl.rcParams.update({
    'axes.labelsize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
})

# Where to save the figures: <project root>/images/<chapter id>
PROJECT_ROOT_DIR, CHAPTER_ID = ".", "figures"
IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, "images", CHAPTER_ID)
# Create the folder up front; an already-existing folder is not an error.
os.makedirs(name=IMAGES_PATH, exist_ok=True)

def save_fig(fig_id, tight_layout=True, fig_extension="png", resolution=300):
    """Save the current pyplot figure under IMAGES_PATH.

    The file is named ``<fig_id>.<fig_extension>``. ``tight_layout`` controls
    whether ``plt.tight_layout()`` is applied before saving, and
    ``resolution`` is passed to ``plt.savefig`` as the dpi.
    """
    file_name = "".join((fig_id, ".", fig_extension))
    destination = os.path.join(IMAGES_PATH, file_name)
    print("Saving figure", fig_id)
    if tight_layout:
        plt.tight_layout()
    plt.savefig(destination, dpi=resolution, format=fig_extension)

In [20]:
def plot_confusion_matrix(cm, classes, title, normalize=False, cmap=plt.cm.Blues):
    """
    Print a status line and plot a confusion matrix as an annotated heatmap.

    See full source and example:
    http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

    Parameters
    ----------
    cm : 2-D array-like
        Confusion matrix. Rows are drawn along the y axis ("True label") and
        columns along the x axis ("Predicted label").
    classes : sequence
        Tick labels, one per row/column of ``cm``.
    title : str
        Title of the plot.
    normalize : bool, default False
        When True, every row is divided by its sum before plotting, so both
        the colours and the printed numbers show row-wise fractions.
    cmap : matplotlib colormap
        Colormap passed to ``plt.imshow``.

    Returns
    -------
    None. The matrix is drawn on the current pyplot figure.
    """
    cm = np.asarray(cm)

    if normalize:
        # Normalize BEFORE drawing so the heatmap colours and the cell
        # annotations describe the same numbers.
        row_totals = cm.sum(axis=1, keepdims=True)
        # Rows without any sample stay at 0 instead of turning into NaN.
        cm = np.divide(
            cm.astype('float'),
            row_totals,
            out=np.zeros(cm.shape, dtype='float'),
            where=row_totals != 0,
        )
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Cells darker than half the maximum get white text for contrast.
    thresh = cm.max() / 2.
    # np.ndindex walks every (row, column) cell in row-major order, so no
    # additional import is required for the iteration.
    for i, j in np.ndindex(*cm.shape):
        # Fractions are rounded for readability; raw counts are shown as-is.
        label = format(cm[i, j], '.2f') if normalize else cm[i, j]
        plt.text(j, i, label,
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Lay out last so the axis labels are taken into account.
    plt.tight_layout()

In [64]:
# Fields are separated by '@'. Column names are supplied here, so pandas
# does not treat the first line of the file as a header.
df_fpb = pd.read_csv(
    "./data/financial-phrase-bank-v1.0/Sentences_66Agree.txt",
    names=['Text', 'Rating'],
    sep='@',
    encoding='latin-1',
)

In [65]:
# Preview the first rows to check that text and rating were parsed into separate columns.
df_fpb.head()

Unnamed: 0,Text,Rating
0,"According to Gran , the company has no plans t...",neutral
1,Technopolis plans to develop in stages an area...,neutral
2,With the new production plant the company woul...,positive
3,According to the company 's updated strategy f...,positive
4,"For the last quarter of 2010 , Componenta 's n...",positive


In [66]:
# Number of rows loaded.
len(df_fpb)

4217

In [67]:
"""Changed the getlabel function in binaryprocessor class to have 3 labels, negative, neutral, positive"""
# Encode the three sentiment strings as integer class ids in a single pass.
df_fpb['Rating'] = df_fpb['Rating'].replace({'negative': 0, 'neutral': 1, 'positive': 2})

In [68]:
# Shuffle the rows with a fixed seed so the resulting order is the same on every run.
# NOTE: this rebinds df_fpb, so re-running the cell shuffles the already-shuffled frame.
df_fpb = sklearn.utils.shuffle(df_fpb, random_state=42)


In [69]:
# Inspect the shuffled frame: ratings are now integers and the index is no longer in order.
df_fpb

Unnamed: 0,Text,Rating
463,Tielinja generated net sales of 7.5 mln euro $...,1
2426,"Cohen & Steers , Inc. : 5 534 626 shares repre...",1
2661,"SAN FRANCISCO ( MarketWatch ) -- Nokia Corp , ...",1
1483,Raute said it has won an order worth around 15...,2
2860,"The power supplies , DC power systems and inve...",1
...,...,...
3444,To see a slide show of all the newest product ...,1
466,"Under the rental agreement , Stockmann was com...",1
3092,"Eero Katajavuori , currently Group Vice Presid...",1
3772,The floor area of the Yliopistonrinne project ...,1


In [70]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the rows for testing. train_test_split shuffles again by
# default, so it is seeded as well (same seed as the shuffle above); without
# a seed the train/test partition would differ on every kernel restart.
df_train, df_test = train_test_split(df_fpb, test_size=0.2, random_state=42)

## BERT

In [50]:
from simpletransformers.classification import ClassificationModel

In [54]:
# Create a ClassificationModel: a 3-label classifier built from the
# "bert-base-cased" checkpoint, with CUDA explicitly disabled.
model = ClassificationModel(
    "bert",
    "bert-base-cased",
    num_labels=3,
    use_cuda=False,
    args=dict(reprocess_input_data=True, overwrite_output_dir=True),
)

INFO:filelock:Lock 140438700248992 acquired on /Users/Roshan/.cache/huggingface/transformers/092cc582560fc3833e556b3f833695c26343cb54b7e88cd02d40821462a74999.1f48cab6c959fc6c360d22bea39d06959e90f5b002e77e836d2da45464875cda.lock


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=435779157.0, style=ProgressStyle(descri…

INFO:filelock:Lock 140438700248992 released on /Users/Roshan/.cache/huggingface/transformers/092cc582560fc3833e556b3f833695c26343cb54b7e88cd02d40821462a74999.1f48cab6c959fc6c360d22bea39d06959e90f5b002e77e836d2da45464875cda.lock





Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=213450.0, style=ProgressStyle(descripti…

INFO:filelock:Lock 140438562834176 released on /Users/Roshan/.cache/huggingface/transformers/6508e60ab3c1200bffa26c95f4b58ac6b6d95fba4db1f195f632fa3cd7bc64cc.437aa611e89f6fc6675a049d2b5545390adbc617e7d655286421c191d2be2791.lock





INFO:filelock:Lock 140438700381952 acquired on /Users/Roshan/.cache/huggingface/transformers/ec84e86ee39bfe112543192cf981deebf7e6cbe8c91b8f7f8f63c9be44366158.ec5c189f89475aac7d8cbd243960a0655cfadc3d0474da8ff2ed0bf1699c2a5f.lock


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=29.0, style=ProgressStyle(description_w…

INFO:filelock:Lock 140438700381952 released on /Users/Roshan/.cache/huggingface/transformers/ec84e86ee39bfe112543192cf981deebf7e6cbe8c91b8f7f8f63c9be44366158.ec5c189f89475aac7d8cbd243960a0655cfadc3d0474da8ff2ed0bf1699c2a5f.lock





INFO:filelock:Lock 140438700381952 acquired on /Users/Roshan/.cache/huggingface/transformers/226a307193a9f4344264cdc76a12988448a25345ba172f2c7421f3b6810fddad.3dab63143af66769bbb35e3811f75f7e16b2320e12b7935e216bd6159ce6d9a6.lock


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=435797.0, style=ProgressStyle(descripti…

INFO:filelock:Lock 140438700381952 released on /Users/Roshan/.cache/huggingface/transformers/226a307193a9f4344264cdc76a12988448a25345ba172f2c7421f3b6810fddad.3dab63143af66769bbb35e3811f75f7e16b2320e12b7935e216bd6159ce6d9a6.lock





In [71]:
# Train the model
# Per the simpletransformers docs, a frame with a header should expose "text"
# and "labels" columns; rename explicitly instead of relying on the positional
# fallback (column 0 = text, column 1 = labels). df_train itself is unchanged.
model.train_model(df_train.rename(columns={"Text": "text", "Rating": "labels"}))

INFO:simpletransformers.classification.classification_utils: Converting to features started. Cache is not used.


HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))

INFO:simpletransformers.classification.classification_utils: Saving features into cached file cache_dir/cached_train_bert_128_3_3373





HBox(children=(FloatProgress(value=0.0, description='Epoch', max=1.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Running Epoch 0 of 1', max=422.0, style=ProgressStyle(des…





KeyboardInterrupt: 