# FinBERT Example Notebook

This notebook shows how to train and use the FinBERT pre-trained language model for financial sentiment analysis.

## Modules 

In [1]:
from pathlib import Path
import shutil
import os
import logging
import sys
sys.path.append('..')

from textblob import TextBlob
from pprint import pprint
from sklearn.metrics import classification_report

from transformers import AutoModelForSequenceClassification

from finbert.finbert import *
import finbert.utils as tools

%load_ext autoreload
%autoreload 2

project_dir = Path.cwd().parent
pd.set_option('display.max_colwidth', None)  # show full sentences in DataFrame cells (negative values are deprecated)



In [3]:
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S',
                    level = logging.ERROR)

## Prepare the model

### Setting path variables:
1. `lm_path`: the path of the further pre-trained language model (not needed if vanilla BERT is used).
2. `cl_path`: the path where the classification model is saved.
3. `cl_data_path`: the path of the directory that contains the data files `train.csv`, `validation.csv` and `test.csv`.
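Before training, it can help to confirm that `cl_data_path` actually contains the three expected files. The helper below is a hypothetical convenience (`missing_data_files` is not part of finBERT); it only checks that the files exist, not their contents or column format.

```python
from pathlib import Path

# File names the notebook expects inside cl_data_path
REQUIRED_FILES = ('train.csv', 'validation.csv', 'test.csv')


def missing_data_files(data_dir):
    """Return the expected data files that are absent from data_dir (hypothetical helper)."""
    data_dir = Path(data_dir)
    return [name for name in REQUIRED_FILES if not (data_dir / name).is_file()]
```

For example, `missing_data_files(cl_data_path)` should return an empty list before you call `finbert.get_data('train')`.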
---

When initializing `bertmodel`, we can use either Google's original pre-trained weights by passing `'bert-base-uncased'` to `from_pretrained`, or our further pre-trained language model by passing `lm_path`.


---
All of the model's configuration is controlled through the `config` variable.

In [4]:
lm_path = project_dir/'models'/'language_model'/'finbertTRC2'
cl_path = project_dir/'models'/'classifier_model'/'finbert-sentiment'
cl_data_path = project_dir/'data'/'sentiment_data'
print('lm_path: {}'.format(lm_path))
print('cl_path: {}'.format(cl_path))
print('cl_data_path: {}'.format(cl_data_path))

lm_path: /home/rrmorris/project/ucsd-mle/finBERT/models/language_model/finbertTRC2
cl_path: /home/rrmorris/project/ucsd-mle/finBERT/models/classifier_model/finbert-sentiment
cl_data_path: /home/rrmorris/project/ucsd-mle/finBERT/data/sentiment_data


### Configuring training parameters

Explanations of the training parameters can be found in the docstrings of the `Config` class.

In [5]:
# Clean the cl_path
#try:
#    shutil.rmtree(cl_path) 
#except:
#    pass

#bertmodel = AutoModelForSequenceClassification.from_pretrained(lm_path,cache_dir=None, num_labels=3)
bertmodel = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased',cache_dir=None, num_labels=3)


config = Config(   data_dir=cl_data_path,
                   bert_model=bertmodel,
                   num_train_epochs=4,
                   model_dir=cl_path,
                   max_seq_length = 48,
                   train_batch_size = 32,
                   learning_rate = 2e-5,
                   output_mode='classification',
                   warm_up_proportion=0.2,
                   local_rank=-1,
                   discriminate=True,
                   gradual_unfreeze=True)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at
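The configuration above trains for `num_train_epochs=4` with `train_batch_size=32`, `learning_rate=2e-5` and `warm_up_proportion=0.2`. The sketch below shows one common way such settings become a learning-rate schedule: linear warm-up over the first 20% of optimization steps, then linear decay to zero, the same shape as `transformers.get_linear_schedule_with_warmup`. It is an assumption that finBERT uses exactly this schedule and this step count; check `finbert.finbert` for the actual implementation. The batch counts do agree with the progress bars in the training log below (23 iterations per epoch for 719 training examples, 3 validation batches for 80 examples, 7 test batches for 200 examples), assuming the evaluation batch size is also 32.

```python
import math


def batches_per_epoch(num_examples, batch_size):
    # The last, partial batch still counts as one iteration
    return math.ceil(num_examples / batch_size)


def linear_warmup_lr(step, total_steps, warmup_steps, base_lr):
    """Learning rate at a given optimizer step: linear warm-up, then linear decay to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


# Values taken from the Config cell and the training log
steps_per_epoch = batches_per_epoch(719, 32)  # 23
total_steps = steps_per_epoch * 4             # 92
warmup_steps = int(0.2 * total_steps)         # 18

for step in (0, 9, warmup_steps, 55, total_steps):
    print(step, linear_warmup_lr(step, total_steps, warmup_steps, 2e-5))
```

Under these assumptions the rate climbs from 0 to `2e-5` over the first 18 steps and then falls back to 0 by step 92.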

`FinBert` is our main class and encapsulates all the functionality. The list of class labels is passed to the `prepare_model` method through the `label_list` parameter.
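The order of `label_list` fixes the integer id of each class. A minimal sketch of that mapping, assuming ids are assigned by position in the list, which matches the training log below (`positive (id = 0)`, `neutral (id = 2)`). The same ids appear as the row names `0`, `1` and `2` in the classification report further down.

```python
label_list = ['positive', 'negative', 'neutral']

# Position in label_list -> integer class id used by the model
label_map = {label: i for i, label in enumerate(label_list)}
# Inverse mapping, handy for turning argmax outputs back into names
id_to_label = {i: label for label, i in label_map.items()}

print(label_map)       # {'positive': 0, 'negative': 1, 'neutral': 2}
print(id_to_label[1])  # negative
```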

In [6]:
finbert = FinBert(config)
finbert.base_model = 'bert-base-uncased'
finbert.config.discriminate=True
finbert.config.gradual_unfreeze=True
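Setting `discriminate=True` enables discriminative fine-tuning, in which lower layers are trained with smaller learning rates than upper layers. The sketch below computes per-layer rates by dividing the base rate by a constant factor for each level down the network. The function `discriminative_lrs` and the divisor `dft_rate = 1.2` are illustrative assumptions, not values read from finBERT; the actual parameter groups and factor are defined in `finbert.finbert`.

```python
def discriminative_lrs(base_lr, num_layers=12, dft_rate=1.2):
    """Per-group learning rates, from the embeddings (lowest) to the classifier (highest).

    Each group gets base_lr divided by dft_rate once per level it sits below the top.
    dft_rate = 1.2 is a hypothetical example value.
    """
    rates = {'embeddings': base_lr / dft_rate ** (num_layers + 1)}
    for i in range(num_layers):
        # Encoder layer i sits (num_layers - i) levels below the classifier head
        rates[f'encoder.layer.{i}'] = base_lr / dft_rate ** (num_layers - i)
    rates['classifier'] = base_lr  # the freshly initialized head uses the full rate
    return rates


for name, lr in discriminative_lrs(2e-5).items():
    print(f'{name:>18}: {lr:.2e}')
```

The design choice behind this scheme is that lower layers hold general language knowledge that should change slowly, while the new classification head needs the largest updates.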

In [7]:
finbert.prepare_model(label_list=['positive','negative','neutral'])

06/08/2021 22:09:17 - INFO - finbert.finbert -   device: cuda n_gpu: 1, distributed training: False, 16-bits training: False


## Fine-tune the model

In [8]:
# Get the training examples
train_data = finbert.get_data('train')

In [9]:
model = finbert.create_the_model()

### [Optional] Fine-tune only a subset of the model
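The variable `freeze` sets how many of the 12 encoder layers are frozen, counting from the bottom; the embedding layer is frozen as well. With `freeze = 6`, encoder layers 0 to 5 stay fixed and only layers 6 to 11, plus the layers on top of the encoder such as the classifier head, are fine-tuned. You can skip this part if you want to fine-tune the whole model.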

<span style="color:red">Important: </span>
Execute this step only if you want a shorter training time at the expense of accuracy.

In [20]:
# This is for fine-tuning a subset of the model.

freeze = 6

for param in model.bert.embeddings.parameters():
    param.requires_grad = False
    
for i in range(freeze):
    for param in model.bert.encoder.layer[i].parameters():
        param.requires_grad = False

### Training

In [21]:
trained_model = finbert.train(train_examples = train_data, model = model)

06/08/2021 22:16:08 - INFO - finbert.utils -   *** Example ***
06/08/2021 22:16:08 - INFO - finbert.utils -   guid: train-1
06/08/2021 22:16:08 - INFO - finbert.utils -   tokens: [CLS] damn , i would also really rather bt ##c over airline miles [SEP]
06/08/2021 22:16:08 - INFO - finbert.utils -   input_ids: 101 4365 1010 1045 2052 2036 2428 2738 18411 2278 2058 8582 2661 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
06/08/2021 22:16:08 - INFO - finbert.utils -   attention_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
06/08/2021 22:16:08 - INFO - finbert.utils -   token_type_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
06/08/2021 22:16:08 - INFO - finbert.utils -   label: positive (id = 0)
06/08/2021 22:16:08 - INFO - finbert.finbert -   ***** Loading data *****
06/08/2021 22:16:08 - INFO - finbert.finbert -     Num examples = 719
06/08/2021 22:16:08 

Iteration:   0%|          | 0/23 [00:00<?, ?it/s]

06/08/2021 22:16:11 - INFO - finbert.utils -   *** Example ***
06/08/2021 22:16:11 - INFO - finbert.utils -   guid: validation-1
06/08/2021 22:16:11 - INFO - finbert.utils -   tokens: [CLS] bt ##c is the root of this market . [SEP]
06/08/2021 22:16:11 - INFO - finbert.utils -   input_ids: 101 18411 2278 2003 1996 7117 1997 2023 3006 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
06/08/2021 22:16:11 - INFO - finbert.utils -   attention_mask: 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
06/08/2021 22:16:11 - INFO - finbert.utils -   token_type_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
06/08/2021 22:16:11 - INFO - finbert.utils -   label: positive (id = 0)
06/08/2021 22:16:11 - INFO - finbert.finbert -   ***** Loading data *****
06/08/2021 22:16:11 - INFO - finbert.finbert -     Num examples = 80
06/08/2021 22:16:11 - INFO - finbert.finbert -   

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

Validation losses: [0.9105482896169027]
No best model found


Epoch:  25%|██▌       | 1/4 [00:03<00:11,  3.79s/it]

Iteration:   0%|          | 0/23 [00:00<?, ?it/s]

06/08/2021 22:16:16 - INFO - finbert.utils -   *** Example ***
06/08/2021 22:16:16 - INFO - finbert.utils -   guid: validation-1
06/08/2021 22:16:16 - INFO - finbert.utils -   tokens: [CLS] bt ##c is the root of this market . [SEP]
06/08/2021 22:16:16 - INFO - finbert.utils -   input_ids: 101 18411 2278 2003 1996 7117 1997 2023 3006 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
06/08/2021 22:16:16 - INFO - finbert.utils -   attention_mask: 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
06/08/2021 22:16:16 - INFO - finbert.utils -   token_type_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
06/08/2021 22:16:16 - INFO - finbert.utils -   label: positive (id = 0)
06/08/2021 22:16:16 - INFO - finbert.finbert -   ***** Loading data *****
06/08/2021 22:16:16 - INFO - finbert.finbert -     Num examples = 80
06/08/2021 22:16:16 - INFO - finbert.finbert -   

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

Validation losses: [0.9105482896169027, 0.9105482896169027]


Epoch:  50%|█████     | 2/4 [00:08<00:08,  4.42s/it]

Iteration:   0%|          | 0/23 [00:00<?, ?it/s]

06/08/2021 22:16:22 - INFO - finbert.utils -   *** Example ***
06/08/2021 22:16:22 - INFO - finbert.utils -   guid: validation-1
06/08/2021 22:16:22 - INFO - finbert.utils -   tokens: [CLS] bt ##c is the root of this market . [SEP]
06/08/2021 22:16:22 - INFO - finbert.utils -   input_ids: 101 18411 2278 2003 1996 7117 1997 2023 3006 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
06/08/2021 22:16:22 - INFO - finbert.utils -   attention_mask: 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
06/08/2021 22:16:22 - INFO - finbert.utils -   token_type_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
06/08/2021 22:16:22 - INFO - finbert.utils -   label: positive (id = 0)
06/08/2021 22:16:22 - INFO - finbert.finbert -   ***** Loading data *****
06/08/2021 22:16:22 - INFO - finbert.finbert -     Num examples = 80
06/08/2021 22:16:22 - INFO - finbert.finbert -   

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

Validation losses: [0.9105482896169027, 0.9105482896169027, 0.9105482896169027]


Epoch:  75%|███████▌  | 3/4 [00:14<00:05,  5.10s/it]

Iteration:   0%|          | 0/23 [00:00<?, ?it/s]

06/08/2021 22:16:28 - INFO - finbert.utils -   *** Example ***
06/08/2021 22:16:28 - INFO - finbert.utils -   guid: validation-1
06/08/2021 22:16:28 - INFO - finbert.utils -   tokens: [CLS] bt ##c is the root of this market . [SEP]
06/08/2021 22:16:28 - INFO - finbert.utils -   input_ids: 101 18411 2278 2003 1996 7117 1997 2023 3006 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
06/08/2021 22:16:28 - INFO - finbert.utils -   attention_mask: 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
06/08/2021 22:16:28 - INFO - finbert.utils -   token_type_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
06/08/2021 22:16:28 - INFO - finbert.utils -   label: positive (id = 0)
06/08/2021 22:16:28 - INFO - finbert.finbert -   ***** Loading data *****
06/08/2021 22:16:28 - INFO - finbert.finbert -     Num examples = 80
06/08/2021 22:16:28 - INFO - finbert.finbert -   

Validating:   0%|          | 0/3 [00:00<?, ?it/s]

Validation losses: [0.9105482896169027, 0.9105482896169027, 0.9105482896169027, 0.9105482896169027]


Epoch: 100%|██████████| 4/4 [00:20<00:00,  5.24s/it]


## Test the model

`finbert.evaluate` returns a DataFrame that contains the true label and the logit values for each example.

In [22]:
test_data = finbert.get_data('test')

In [23]:
results = finbert.evaluate(examples=test_data, model=trained_model)

06/08/2021 22:16:53 - INFO - finbert.utils -   *** Example ***
06/08/2021 22:16:53 - INFO - finbert.utils -   guid: test-1
06/08/2021 22:16:53 - INFO - finbert.utils -   tokens: [CLS] second wave will be taking over ether ##eum chain . [SEP]
06/08/2021 22:16:53 - INFO - finbert.utils -   input_ids: 101 2117 4400 2097 2022 2635 2058 28855 14820 4677 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
06/08/2021 22:16:53 - INFO - finbert.utils -   attention_mask: 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
06/08/2021 22:16:53 - INFO - finbert.utils -   token_type_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
06/08/2021 22:16:53 - INFO - finbert.utils -   label: neutral (id = 2)
06/08/2021 22:16:53 - INFO - finbert.finbert -   ***** Loading data *****
06/08/2021 22:16:53 - INFO - finbert.finbert -     Num examples = 200
06/08/2021 22:16:53 - INFO - finber

Testing:   0%|          | 0/7 [00:00<?, ?it/s]

### Prepare the classification report

In [24]:
import numpy as np
import torch
from torch.nn import CrossEntropyLoss

def report(df, cols=('label', 'prediction', 'logits')):
    # cols: names of the true-label, predicted-label and logit columns in df
    label_col, pred_col, logit_col = cols
    # Cross-entropy between the stored logits and the true labels, weighted by class
    cs = CrossEntropyLoss(weight=finbert.class_weights)
    loss = cs(torch.tensor(np.stack(df[logit_col].tolist())), torch.tensor(df[label_col].tolist()))
    print("Loss:{0:.2f}".format(loss.item()))
    print("Accuracy:{0:.2f}".format((df[label_col] == df[pred_col]).mean()))
    print("\nClassification Report:")
    print(classification_report(df[label_col], df[pred_col]))
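`report` measures the loss with PyTorch's `CrossEntropyLoss`, weighted by `finbert.class_weights`. With the default `reduction='mean'`, PyTorch divides the weighted sum of per-example losses by the sum of the weights of the true classes, not by the number of examples. The NumPy sketch below reproduces that computation so the printed `Loss` can be checked by hand; `weighted_cross_entropy` is a hypothetical helper, not part of finBERT.

```python
import numpy as np


def weighted_cross_entropy(logits, labels, weights):
    """Mean weighted cross-entropy, matching torch.nn.CrossEntropyLoss(weight=...)."""
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.int64)
    weights = np.asarray(weights, dtype=np.float64)
    # Numerically stable log-softmax over the class axis
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Negative log-likelihood of the true class for every example
    nll = -log_probs[np.arange(len(labels)), labels]
    # Each example is weighted by the weight of its true class
    w = weights[labels]
    return float((w * nll).sum() / w.sum())


logits = [[2.0, 0.0, 0.0], [2.0, 0.0, 0.0]]
labels = [0, 1]
print(weighted_cross_entropy(logits, labels, weights=[3.0, 1.0, 1.0]))
```

Because of this normalization, up-weighting a rare class changes how much its examples count relative to the others, but the loss stays on the same scale as the unweighted one.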

In [25]:
results['prediction'] = results.predictions.apply(lambda x: np.argmax(x,axis=0))

In [26]:
report(results,cols=['labels','prediction','predictions'])

Loss:0.89
Accuracy:0.56

Classification Report:
              precision    recall  f1-score   support

           0       0.38      0.56      0.45        45
           1       0.31      0.64      0.42        22
           2       0.82      0.56      0.66       133

    accuracy                           0.56       200
   macro avg       0.51      0.58      0.51       200
weighted avg       0.67      0.56      0.59       200



### Get predictions

Given a piece of text, the `predict` function splits it into sentences and then predicts the sentiment of each sentence. The output is a DataFrame in which predictions are represented in three columns:

1) `logit`: the probability of each class, in the order positive, negative, neutral

2) `prediction`: the predicted label

3) `sentiment_score`: the sentiment score, calculated as the probability of positive minus the probability of negative
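The three columns can be derived from one another. The sketch below assumes the `logit` column holds softmax probabilities in the order of `label_list` (positive, negative, neutral); the values in the results table below sum to 1, which supports this. `softmax` and `summarize` are illustrative helpers, not finBERT functions.

```python
import numpy as np

LABELS = ('positive', 'negative', 'neutral')  # same order as label_list


def softmax(logits):
    """Turn a vector of raw logits into probabilities that sum to 1."""
    z = np.asarray(logits, dtype=np.float64)
    e = np.exp(z - z.max())  # subtract the max for numerical stability
    return e / e.sum()


def summarize(probs, labels=LABELS):
    """Return (prediction, sentiment_score) for one sentence's class probabilities."""
    probs = np.asarray(probs, dtype=np.float64)
    prediction = labels[int(np.argmax(probs))]
    sentiment_score = float(probs[0] - probs[1])  # P(positive) - P(negative)
    return prediction, sentiment_score


# Probabilities reported for the first Bitcoin sentence in the results below
print(summarize([0.41733512, 0.46646473, 0.11620015]))
```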

For example, the following paragraph is taken from [this](https://www.economist.com/finance-and-economics/2019/01/03/a-profit-warning-from-apple-jolts-markets) article from The Economist:
> Later that day Apple said it was revising down its earnings expectations in the fourth quarter of 2018, largely because of lower sales and signs of economic weakness in China. The news rapidly infected financial markets. Apple’s share price fell by around 7% in after-hours trading and the decline was extended to more than 10% when the market opened. The dollar fell by 3.7% against the yen in a matter of minutes after the announcement, before rapidly recovering some ground. Asian stockmarkets closed down on January 3rd and European ones opened lower. Yields on government bonds fell as investors fled to the traditional haven in a market storm.

Below we analyze a short passage about cryptocurrency prices in the same way; the Economist paragraph can be passed to `predict` just as well. For comparison, we also show the sentiment polarity predicted by TextBlob for each sentence.

In [27]:
text = "Bitcoin’s price has fallen to its lowest point in over a week as traders \
stare down prospects of shifting U.S. monetary policy and continued tightening of \
regulation of cryptocurrencies in China. Other notable cryptos were also trading in \
the red, with the top 10 by market capitalization having fallen between 7.3% and \
12.9% over the previous 24 hours. Polkadot and XRP were the hardest hit, down 12.93% \
and 11.39%, respectively."

In [28]:
cl_path = project_dir/'models'/'classifier_model'/'finbert-sentiment'
model = AutoModelForSequenceClassification.from_pretrained(cl_path, cache_dir=None, num_labels=3)

In [29]:
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to /home/rrmorris/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [30]:
result = predict(text,model)

06/08/2021 22:20:22 - INFO - finbert.utils -   *** Example ***
06/08/2021 22:20:22 - INFO - finbert.utils -   guid: 0
06/08/2021 22:20:22 - INFO - finbert.utils -   tokens: [CLS] bit ##co ##in ’ s price has fallen to its lowest point in over a week as traders stare down prospects of shifting u . s . monetary policy and continued tightening of regulation of crypt ##oc ##ur ##ren ##cies in china . [SEP]
06/08/2021 22:20:22 - INFO - finbert.utils -   input_ids: 101 2978 3597 2378 1521 1055 3976 2038 5357 2000 2049 7290 2391 1999 2058 1037 2733 2004 13066 6237 2091 16746 1997 9564 1057 1012 1055 1012 12194 3343 1998 2506 18711 1997 7816 1997 19888 10085 3126 7389 9243 1999 2859 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
06/08/2021 22:20:22 - INFO - finbert.utils -   attention_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
06/08/2021 22:20:22 - INFO - finbert.utils -   token_type_ids: 0 0 0 0 0 0 0 0 

In [31]:
blob = TextBlob(text)
result['textblob_prediction'] = [sentence.sentiment.polarity for sentence in blob.sentences]
result

| | sentence | logit | prediction | sentiment_score | textblob_prediction |
|---|---|---|---|---|---|
| 0 | Bitcoin’s price has fallen to its lowest point in over a week as traders stare down prospects of shifting U.S. monetary policy and continued tightening of regulation of cryptocurrencies in China. | [0.41733512, 0.46646473, 0.11620015] | negative | -0.04913 | -0.155556 |
| 1 | Other notable cryptos were also trading in the red, with the top 10 by market capitalization having fallen between 7.3% and 12.9% over the previous 24 hours. | [0.5750858, 0.27490568, 0.15000847] | positive | 0.30018 | 0.141667 |
| 2 | Polkadot and XRP were the hardest hit, down 12.93% and 11.39%, respectively. | [0.5511973, 0.31143925, 0.13736351] | positive | 0.239758 | -0.077778 |


In [32]:
print(f'Average sentiment is {result.sentiment_score.mean():.2f}.')

Average sentiment is 0.16.
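Both scores lie on the same scale from -1 to 1 (TextBlob polarity by definition, FinBERT's score because it is a difference of two probabilities), so one simple way to compare them is to count the sentences on which they agree in sign. `sign_agreement` is a hypothetical helper; in the notebook you would pass `result.sentiment_score` and `result.textblob_prediction`. For the three sentences above, the two methods agree on the first two and disagree on the third.

```python
import numpy as np


def sign_agreement(scores_a, scores_b):
    """Fraction of items where two sentiment scores have the same sign (0 counts as its own sign)."""
    a = np.sign(np.asarray(scores_a, dtype=np.float64))
    b = np.sign(np.asarray(scores_b, dtype=np.float64))
    return float((a == b).mean())


# sentiment_score and textblob_prediction columns from the table above
finbert_scores = [-0.04913, 0.30018, 0.239758]
textblob_scores = [-0.155556, 0.141667, -0.077778]
print(sign_agreement(finbert_scores, textblob_scores))
```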


Here is another example:

In [33]:
text2 = "Ether (ETH) eclipsed $4,000 for the first time on Monday, passing the \
psychologically significant barrier on multiple exchanges, including Coinbase. \
The new milestone comes just a week after breaking $3,000. The remarkable run has \
even prompted renewed speculation that there could be a “flippening” on the horizon \
— a long-anticipated event among the Ethereum community in which ETH overtakes \
Bitcoin (BTC) in market capitalization."

In [34]:
result2 = predict(text2,model)
blob = TextBlob(text2)
result2['textblob_prediction'] = [sentence.sentiment.polarity for sentence in blob.sentences]

06/08/2021 22:31:35 - INFO - finbert.utils -   *** Example ***
06/08/2021 22:31:35 - INFO - finbert.utils -   guid: 0
06/08/2021 22:31:35 - INFO - finbert.utils -   tokens: [CLS] ether ( et ##h ) eclipse ##d $ 4 , 000 for the first time on monday , passing the psychological ##ly significant barrier on multiple exchanges , including coin ##base . [SEP]
06/08/2021 22:31:35 - INFO - finbert.utils -   input_ids: 101 28855 1006 3802 2232 1007 13232 2094 1002 1018 1010 2199 2005 1996 2034 2051 2006 6928 1010 4458 1996 8317 2135 3278 8803 2006 3674 15800 1010 2164 9226 15058 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
06/08/2021 22:31:35 - INFO - finbert.utils -   attention_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
06/08/2021 22:31:35 - INFO - finbert.utils -   token_type_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 

In [35]:
result2

| | sentence | logit | prediction | sentiment_score | textblob_prediction |
|---|---|---|---|---|---|
| 0 | Ether (ETH) eclipsed \$4,000 for the first time on Monday, passing the psychologically significant barrier on multiple exchanges, including Coinbase. | [0.6277018, 0.24478924, 0.12750891] | positive | 0.382913 | 0.208333 |
| 1 | The new milestone comes just a week after breaking \$3,000. | [0.65775746, 0.21923997, 0.12300254] | positive | 0.438518 | 0.136364 |
| 2 | The remarkable run has even prompted renewed speculation that there could be a “flippening” on the horizon — a long-anticipated event among the Ethereum community in which ETH overtakes Bitcoin (BTC) in market capitalization. | [0.5580745, 0.29772595, 0.1441996] | positive | 0.260349 | 0.75 |


In [36]:
print(f'Average sentiment is {result2.sentiment_score.mean():.2f}.')

Average sentiment is 0.36.
