TO DO

* refactor validation and test steps
* probably need the finbert model in training mode too when training?
* extend for multi-class classification - change loss function etc.
* keeping LSTM for classification?
* needs some hyperparameter tuning

# Imports

In [None]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
!pip install datasets
!pip install transformers

In [None]:
import pandas as pd
import yfinance as yf
from concurrent.futures import ThreadPoolExecutor
import datetime
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
from datasets import load_dataset
import numpy as np
from statistics import mean
import pickle
from sklearn.preprocessing import LabelEncoder


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from transformers import AutoTokenizer, AutoModel
from torch.optim import SGD

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = AutoTokenizer.from_pretrained("ProsusAI/finbert")
finbert = AutoModel.from_pretrained("ProsusAI/finbert").to(device)
bert = AutoModel.from_pretrained("bert-base-uncased").to(device)

# Data loading and cleaning

In [None]:
dataset = load_dataset("edarchimbaud/news-stocks")
dataset.set_format(type='pandas')

df = dataset['train'][:]
df

Unnamed: 0,symbol,body,publisher,publish_time,title,url,uuid
0,A,Vipshop Holdings Limited VIPS is set to report...,Zacks,2023-05-18 16:14:04+00:00,Vipshop (VIPS) to Post Q1 Earnings: What's in ...,https://finance.yahoo.com/news/vipshop-vips-po...,27293957-38d0-3710-8e36-c33a88428202
1,A,"SANTA CLARA, Calif., May 19, 2023--(BUSINESS W...",Business Wire,2023-05-19 01:25:00+00:00,Agilent to Appeal Patent Office Decision on CR...,https://finance.yahoo.com/news/agilent-appeal-...,5477d2fd-f0e0-3721-aeb5-9e4184af5fc1
2,A,Agilent Technologies A is set to report its se...,Zacks,2023-05-19 15:08:03+00:00,Agilent Technologies (A) to Post Q2 Earnings: ...,https://finance.yahoo.com/news/agilent-technol...,619f9fb9-c102-37fc-8b8c-50c11237b9a7
3,A,"Earnings reports from Zoom Video, Lowe’s, Snow...",Barrons.com,2023-05-21 19:00:00+00:00,"Costco, JPMorgan, Snowflake, Ford, Zoom, and M...",https://finance.yahoo.com/m/5f7a781e-1e0c-30b0...,5f7a781e-1e0c-30b0-a052-1d88fc0ce184
4,A,Agilent Technologies A reported second-quarter...,Zacks,2023-05-24 17:17:00+00:00,"Agilent (A) Q2 Earnings Match Estimates, Reven...",https://finance.yahoo.com/news/agilent-q2-earn...,e9307f15-4308-33d8-a649-57f4df3530b1
...,...,...,...,...,...,...,...
22020,ZTS,When considering what names to put on your wat...,Investor's Business Daily,2023-08-09 18:18:00+00:00,Drugmaker Zoetis Stock Shows Rising Relative S...,https://finance.yahoo.com/m/de58f2c0-d31a-310c...,de58f2c0-d31a-310c-a293-02e6c4206e8c
22021,ZTS,"LINCOLN, Neb., August 09, 2023--(BUSINESS WIRE...",Business Wire,2023-08-09 22:33:00+00:00,Zoetis Welcomes Officials to Open New State-of...,https://finance.yahoo.com/news/zoetis-welcomes...,51662184-e4c7-3e7f-a6a2-7b0666de92cd
22022,ZTS,Investors interested in stocks from the Medica...,Zacks,2023-08-10 15:40:11+00:00,USNA or ZTS: Which Is the Better Value Stock R...,https://finance.yahoo.com/news/usna-zts-better...,14a5e82f-6510-30a5-bbf1-cc4982348c42
22023,ZTS,"In this article, we will be taking a look at t...",Insider Monkey,2023-11-09 08:49:40+00:00,Top 20 Drug Companies in the US by Revenue,https://finance.yahoo.com/news/top-20-drug-com...,636f3ccc-872c-3532-ad4b-675b269c4602


In [None]:
df = df.drop(['publisher', 'url', 'uuid'], axis=1)
df

Unnamed: 0,symbol,body,publish_time,title
0,A,Vipshop Holdings Limited VIPS is set to report...,2023-05-18 16:14:04+00:00,Vipshop (VIPS) to Post Q1 Earnings: What's in ...
1,A,"SANTA CLARA, Calif., May 19, 2023--(BUSINESS W...",2023-05-19 01:25:00+00:00,Agilent to Appeal Patent Office Decision on CR...
2,A,Agilent Technologies A is set to report its se...,2023-05-19 15:08:03+00:00,Agilent Technologies (A) to Post Q2 Earnings: ...
3,A,"Earnings reports from Zoom Video, Lowe’s, Snow...",2023-05-21 19:00:00+00:00,"Costco, JPMorgan, Snowflake, Ford, Zoom, and M..."
4,A,Agilent Technologies A reported second-quarter...,2023-05-24 17:17:00+00:00,"Agilent (A) Q2 Earnings Match Estimates, Reven..."
...,...,...,...,...
22020,ZTS,When considering what names to put on your wat...,2023-08-09 18:18:00+00:00,Drugmaker Zoetis Stock Shows Rising Relative S...
22021,ZTS,"LINCOLN, Neb., August 09, 2023--(BUSINESS WIRE...",2023-08-09 22:33:00+00:00,Zoetis Welcomes Officials to Open New State-of...
22022,ZTS,Investors interested in stocks from the Medica...,2023-08-10 15:40:11+00:00,USNA or ZTS: Which Is the Better Value Stock R...
22023,ZTS,"In this article, we will be taking a look at t...",2023-11-09 08:49:40+00:00,Top 20 Drug Companies in the US by Revenue


In [None]:
# select stocks
ticker = 'AAPL'

df = df[df['symbol'] == ticker]

df = df.reset_index(drop=True)
df

Unnamed: 0,symbol,body,publish_time,title
0,AAPL,Apple has stopped some of its employees from u...,2023-05-19 12:08:00+00:00,Apple Bans Some Staff From Using ChatGPT. But ...
1,AAPL,Samsung shelved a review that could have seen ...,2023-05-19 13:19:00+00:00,Google Parent Alphabet Stock Rises. It Got Som...
2,AAPL,Nvidia (NVDA) stock has become a Wall Street d...,2023-05-22 12:26:04+00:00,Nvidia stock is trading on 'heroic' valuations...
3,AAPL,Yahoo Finance markets contributor Remy Blaire ...,2023-05-22 20:15:42+00:00,Apple nears $3 trillion market cap amid Loop C...
4,AAPL,Yahoo Finance Senior Reporter Alexandra Canal ...,2023-05-24 16:12:43+00:00,"Streaming wars evolving between Netflix, Disne..."
...,...,...,...,...
91,AAPL,Epic Games is facing a setback in its legal ba...,2023-08-09 19:55:42+00:00,Supreme Court rules in Apple's favor in Epic G...
92,AAPL,As data has taken a backseat in driving market...,2023-08-10 13:33:34+00:00,Tech: Nvidia earnings 'big catalyst' to watch ...
93,AAPL,Apple (AAPL) shares sink in August amid growth...,2023-08-10 16:23:26+00:00,Apple and Nvidia: How the tech stocks are perf...
94,AAPL,"ChatGPT, OpenAI’s text-generating AI chatbot, ...",2023-11-20 07:00:39+00:00,ChatGPT: Everything you need to know about the...


# Load stock prices

In [None]:
# start date is 7 days before the earliest day from the df
# to make sure at least one trading day before is included
# (there might be weekend days/holidays etc.)

start_date = df['publish_time'].min() - datetime.timedelta(7)
end_date = df['publish_time'].max() + datetime.timedelta(7)

prices = yf.download('AAPL', start_date, end_date)

prices.index = pd.to_datetime(prices.index, format='%Y-%m-%d', utc=True)
prices

[*********************100%%**********************]  1 of 1 completed


Unnamed: 0_level_0,Open,High,Low,Close,Adj Close,Volume
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
2023-05-12 00:00:00+00:00,173.619995,174.059998,171.000000,172.570007,171.891205,45497800
2023-05-15 00:00:00+00:00,173.160004,173.210007,171.470001,172.070007,171.393173,37266700
2023-05-16 00:00:00+00:00,171.990005,173.139999,171.800003,172.070007,171.393173,42110300
2023-05-17 00:00:00+00:00,171.710007,172.929993,170.419998,172.690002,172.010727,57951600
2023-05-18 00:00:00+00:00,173.000000,175.240005,172.580002,175.050003,174.361450,65496700
...,...,...,...,...,...,...
2023-11-20 00:00:00+00:00,189.889999,191.910004,189.880005,191.449997,191.206009,46505100
2023-11-21 00:00:00+00:00,191.410004,191.520004,189.740005,190.639999,190.397049,38134500
2023-11-22 00:00:00+00:00,191.490005,192.929993,190.830002,191.309998,191.066193,39617700
2023-11-24 00:00:00+00:00,190.869995,190.899994,189.250000,189.970001,189.727905,24048300


In [None]:
def trend(date):

  prev_date = prices[prices.index < date].index.max()
  prev_date_index = prices.index.get_loc(prev_date.strftime('%Y-%m-%d'))

  if date.strftime('%Y-%m-%d') == prev_date.strftime('%Y-%m-%d'):
    prev_date_index -= 1

  next_date = prices[prices.index > date].index.min()
  next_date_index = prices.index.get_loc(next_date.strftime('%Y-%m-%d'))
  ret = ((prices['Open'][next_date_index] - prices['Close'][prev_date_index]) / prices['Close'][prev_date_index]) * 100

  return_threshold = 1.0 # (1%)
  if ret >= return_threshold:
    return 'increase'
  elif ret <= -return_threshold:
    return 'decrease'
  else:
    return 'stable'


df['trend'] = df['publish_time'].apply(trend)

# Tokenize data

In [None]:
def wrap_tokenizer(tokenizer, padding=True, truncation=True, return_tensors='pt', max_length=None):
    def tokenize(text):
        text = list(text)
        tokens = tokenizer(
            text,
            padding=padding,
            return_attention_mask=False,
            truncation=truncation,
            max_length=max_length,
            return_tensors=return_tensors
            )['input_ids']
        return tokens
    return tokenize

In [None]:
tokenize = wrap_tokenizer(tokenizer, padding=False, truncation = False, return_tensors=None)
tokens = tokenize(df['body'])
num_tokens = [len(x) for x in tokens]
df['length'] = pd.Series(num_tokens)

Token indices sequence length is longer than the specified maximum sequence length for this model (1035 > 512). Running this sequence through the model will result in indexing errors


In [None]:
max_length = df['length'].unique().max()
max_length
max_tokens = 512
for i in range(1,20):
    num = sum(df['length']>i*max_tokens)
    print(f"Number of text that have more than {i}*max_tokens is {num}")

Number of text that have more than 1*max_tokens is 49
Number of text that have more than 2*max_tokens is 26
Number of text that have more than 3*max_tokens is 15
Number of text that have more than 4*max_tokens is 9
Number of text that have more than 5*max_tokens is 6
Number of text that have more than 6*max_tokens is 3
Number of text that have more than 7*max_tokens is 3
Number of text that have more than 8*max_tokens is 3
Number of text that have more than 9*max_tokens is 2
Number of text that have more than 10*max_tokens is 1
Number of text that have more than 11*max_tokens is 1
Number of text that have more than 12*max_tokens is 1
Number of text that have more than 13*max_tokens is 1
Number of text that have more than 14*max_tokens is 1
Number of text that have more than 15*max_tokens is 1
Number of text that have more than 16*max_tokens is 1
Number of text that have more than 17*max_tokens is 1
Number of text that have more than 18*max_tokens is 0
Number of text that have more than

In [None]:
def get_text_split(text, length=200, overlap=50, max_chunks=4):
    chunks = []

    words = text.split()
    n_words = len(words)

    n = max(1, min(max_chunks, (n_words - length) // (length - overlap) + 1))

    for i in range(n):
        start_idx = i * (length - overlap)
        end_idx = min(start_idx + length, n_words)

        chunk_words = words[start_idx:end_idx]

        chunk_text = " ".join(chunk_words)

        # If it's the last chunk and its length is less than 75% of the desired length, skip
        if i == n - 1 and len(chunk_words) < 0.75 * length and n > 1:
            continue

        chunks.append(chunk_text)

    return chunks

**ENCODE LABELS**

In [None]:
df = df[df['trend'] != 'decrease']

In [None]:
labels = ["increase", "stable"]

label_encoder = LabelEncoder()

df['trend'] = label_encoder.fit_transform(df['trend'])

#Train-val-test split

In [None]:
def split_df(df):
    n_rows = len(df)

    df_train = df.iloc[:int(0.8*n_rows),:]
    df_val = df.iloc[int(0.8*n_rows):int(0.9*n_rows),:]
    df_test = df.iloc[int(0.9*n_rows):,:]

    return df_train, df_val, df_test

In [None]:
n_rows = len(df)
dfs_train, dfs_val, dfs_test = [],[],[]
gb = df.groupby('symbol')
for x in gb.groups:
    group = gb.get_group(x)
    df_train, df_val, df_test = split_df(group)
    dfs_train.append(df_train)
    dfs_val.append(df_val)
    dfs_test.append(df_test)

df_train = pd.concat(dfs_train, ignore_index=True)

df_val = pd.concat(dfs_val, ignore_index=True)

df_test = pd.concat(dfs_test, ignore_index=True)

print(f'Number of training examples: {len(df_train)}')
print(f'Number of validation examples: {len(df_val)}')
print(f'Number of test examples: {len(df_test)}')

Number of training examples: 75
Number of validation examples: 9
Number of test examples: 10


In [None]:
df_train.body = df_train.body.apply(lambda x: get_text_split(x))
df_val.body = df_val.body.apply(lambda x: get_text_split(x))
df_test.body = df_test.body.apply(lambda x: get_text_split(x))

In [None]:
df_train['n_chunks'] = df_train.body.apply(lambda x: len(x))
df_val['n_chunks'] = df_val.body.apply(lambda x: len(x))
df_test['n_chunks'] = df_test.body.apply(lambda x: len(x))

In [None]:
df_train.head()

Unnamed: 0,symbol,body,publish_time,title,trend,length,n_chunks
0,AAPL,[Apple has stopped some of its employees from ...,2023-05-19 12:08:00+00:00,Apple Bans Some Staff From Using ChatGPT. But ...,1,105,1
1,AAPL,[Samsung shelved a review that could have seen...,2023-05-19 13:19:00+00:00,Google Parent Alphabet Stock Rises. It Got Som...,1,32,1
2,AAPL,[Yahoo Finance Senior Reporter Alexandra Canal...,2023-05-24 16:12:43+00:00,"Streaming wars evolving between Netflix, Disne...",1,622,3
3,AAPL,[The best Dow Jones stocks to buy and watch in...,2023-05-24 16:50:33+00:00,Best Dow Jones Stocks To Buy And Watch In May ...,1,31,1
4,AAPL,[By Blake Brittain(Reuters) - The U.S. solicit...,2023-05-24 20:04:04+00:00,Biden administration urges Supreme Court not t...,1,407,2


# Model

In [None]:
class MyDataset(Dataset):
    def __init__(self,df):
        self.n_chunks = df['n_chunks'].to_list()
        self.X = df['body'].to_list()
        self.Y = df['trend']

    def __len__(self):
        return len(self.X)

    def __getitem__(self,index):
        return self.X[index], self.Y.iloc[index], self.n_chunks[index]

def collate_func(batch):
    X = [x[0] for x in batch]
    Y = torch.Tensor([x[1] for x in batch])
    c = [x[2] for x in batch]
    return [X,Y,c]

In [None]:
class Classifier(nn.Module):
    def __init__(self, lstm_size, emb_dim, out_dim_lin, lstm_do):
        super().__init__()

        self.lstm = nn.LSTM(input_size=emb_dim, hidden_size=lstm_size, batch_first=True)
        self.dropout = nn.Dropout(lstm_do)
        self.linear = nn.Linear(in_features=lstm_size, out_features=out_dim_lin)

    def forward(self, x, n_chunks):
        x = pad_sequence(x, batch_first=True, padding_value=0)
        x = pack_padded_sequence(input=x, lengths=n_chunks, batch_first=True, enforce_sorted=False)
        x, _ = self.lstm(x)
        x, _ = pad_packed_sequence(x, batch_first=True)
        x = x[:,-1,:]
        x = self.dropout(x)
        logit = self.linear(x)
        return logit

In [None]:
# class Classifier(nn.Module):
#     def __init__(self, bert_model, out_dim_lin, dropout):
#         super().__init__()

#         self.bert = bert_model
#         self.dropout = nn.Dropout(dropout)
#         self.linear = nn.Linear(in_features=self.bert.config.hidden_size, out_features=out_dim_lin)

#     def forward(self, x, n_chunks):

#         x = pad_sequence(x, batch_first=True, padding_value=0)

#         with torch.no_grad():
#             bert_outputs = self.bert(x)

#         bert_last_hidden_state = bert_outputs.last_hidden_state

#         x = self.dropout(bert_last_hidden_state)

#         logit = self.linear(x[:, 0, :])

#         return logit

In [1]:
def save_to_disk(txt_path, values):
    if os.path.isfile(txt_path):
        os.remove(txt_path)
    with open(txt_path, "wb") as fp:
        pickle.dump(values, fp)
    return


def load_from_disk(txt_path):
    with open(txt_path, "rb") as f:
        values =  pickle.load(f)
    return values


def save_checkpoint(xlmr, classifier, optimizer, logs, checkpoint_dir, epoch):
    print('')
    print('Saving checkpoint...')
    state_dict = {
        'classifier':classifier.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    torch.save(state_dict, os.path.join(checkpoint_dir, 'checkpoint_{}.pt'.format(epoch)))
    save_to_disk(os.path.join(checkpoint_dir, 'logs.txt'),logs)
    print(f'Checkpoint saved!')

checkpoint_dir = '/content/google_drive/MyDrive/checkpoints'

def load_checkpoint(checkpoint_dir, epoch, xlmr, classifier, device, optimizer=None):
    pretrained_dict = torch.load(os.path.join(checkpoint_dir,'checkpoint_{}.pt'.format(epoch)),map_location=torch.device(device))
    classifier.load_state_dict(pretrained_dict['classifier'])
    if optimizer is not None:
        optimizer.load_state_dict(pretrained_dict['optimizer'])
        return classifier, optimizer
    return classifier

In [None]:
# def get_lr(optimizer):
#     for param_group in optimizer.param_groups:
#         return param_group['lr']

# def set_lr(optimizer, lr):
#     for param_group in optimizer.param_groups:
#         param_group['lr'] = lr

In [None]:
def train(train_loader, tokenize, device, model, optimizer, classifier, dropout=0.0):

    classifier.train()
    #model.train()

    accuracy = []
    losses = []

    for text, target, n_chunks in train_loader:
        target = target.to(device)

        flat_text = [item for sublist in text for item in sublist]
        tokens = tokenize(flat_text)
        tokens = tokens.to(device)

        # Based on the results - I think we should finetune the model
        with torch.no_grad():
          outputs = model(tokens)
        #outputs = model(tokens)

        embeddings = outputs.last_hidden_state

        pooled_emb = torch.mean(embeddings, axis=1)
        pooled_emb = nn.Dropout(dropout)(pooled_emb)
        x = [s for s in torch.split(pooled_emb, n_chunks, dim=0)]

        logit = classifier(x, n_chunks)
        prob = torch.sigmoid(logit)

        target = torch.reshape(target, shape=(-1, 1))
        loss = nn.BCELoss()(input=prob, target=target.float())

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        prediction = (prob >= 0.5).float()
        correct_predictions = torch.sum(prediction == target).item()
        acc = correct_predictions / len(target)
        accuracy.append(acc)
        losses.append(loss.item())

    return accuracy, losses


In [None]:
# def val_step(val_loader, tokenize, device, xlmr, classifier, is_test=False):
#     print('')
#     print('Validating...')
#     xlmr.eval()
#     classifier.eval()
#     iteration = 0
#     accuracy = []
#     losses = []
#     cms = []

#     if is_test:
#         predictions,targets = [],[]

#     for text, target, n_chunks in val_loader:
#         target = target.to(device)
#         flat_text = [item for sublist in text for item in sublist]
#         tokens = tokenize(flat_text)
#         tokens = tokens.to(device)
#         with torch.no_grad():
#             embeddings = xlmr.extract_features(tokens)
#             pooled_emb = torch.mean(embeddings, axis=1)
#             # perform sum pooling
#             x = [s for s in torch.split(pooled_emb, n_chunks,dim=0)]
#             logit = classifier(x, n_chunks)
#             prob = torch.sigmoid(logit)
#         target = torch.reshape(target,shape=(-1,1))
#         loss = nn.CrossEntropyLoss()(input=prob, target=target.long())

#         # calculate accuracy
#         prob = torch.squeeze(prob)
#         prediction = torch.clone(prob)
#         prediction[prediction >= 0.5] = 1
#         prediction[prediction < 0.5] = 0
#         target = torch.squeeze(target)
#         acc = torch.sum(target==prediction)/float(len(target))
#         accuracy.append(float(acc))
#         losses.append(float(loss.cpu().numpy()))

#         # calculate confusion matrix
#         cm = confusion_matrix(target.cpu().numpy(),
#                               prediction.cpu().numpy(),
#                               labels = np.array([0,1])
#                               )
#         cms.append(cm)
#         iteration+=1
#         print(f"\r iter: {iteration}/{len(val_loader)}",end='')
#         if is_test:
#             predictions.append(prediction.cpu().numpy())
#             targets.append(target.cpu().numpy())
#     if is_test:
#         return accuracy, losses, cms, predictions, targets
#     return accuracy, losses, cms

Hyperparameters

In [None]:
BATCH_SIZE = 32
EPOCHS = 5
LR = 1e-7
EMBEDDING_DIM = 768 # 768 for base and 1024 for large
LSTM_SIZE = 128
LSTM_DO = 0.2
POOLED_EMB_DO = 0.3
OUT_DIM_LIN = 1
WEIGHT_DECAY = 1e-3
LR_FREEZE = 6e-5



In [None]:
train_dataset = MyDataset(df_train)
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, collate_fn=collate_func)

val_dataset = MyDataset(df_val)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, collate_fn=collate_func)

test_dataset = MyDataset(df_test)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, collate_fn=collate_func)

# Sentiment Analysis

In [None]:
import os

checkpoint_dir = '/content/google_drive/MyDrive/checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Selected device is {}'.format(device))

classifier = Classifier(lstm_size=LSTM_SIZE,
                        emb_dim=EMBEDDING_DIM,
                        out_dim_lin=OUT_DIM_LIN,
                        lstm_do = LSTM_DO).to(device)


params = list(classifier.parameters())
# optimizer = SGD(params, lr=LR_FREEZE, weight_decay=WEIGHT_DECAY)
optimizer = SGD(params)

tokenize = wrap_tokenizer(tokenizer)

Selected device is cpu


In [None]:
logs = {'train_acc':[],'train_loss':[],
        'val_acc':[],'val_loss':[]}

for epoch in range(EPOCHS):

    train_acc, train_loss = train(train_loader = train_loader,
                                  tokenize = tokenize,
                                  device = device,
                                  optimizer = optimizer,
                                  model = finbert,
                                  classifier = classifier,
                                  dropout=POOLED_EMB_DO)

    # val_acc, val_loss, val_cm = val_step(
    #     val_loader=val_loader,
    #     tokenize=tokenize,
    #     device=device,
    #     xlmr=xlmr,
    #     classifier=classifier
    # )

    logs['train_acc'] += train_acc
    logs['train_loss'] += train_loss
    # logs['val_acc'] += val_acc
    # logs['val_loss'] += val_loss

    # if epoch % 2 == 0 and epoch != 0 :
    #     save_checkpoint(
    #         xlmr=xlmr,
    #         classifier=classifier,
    #         optimizer=optimizer,
    #         logs=logs,
    #         checkpoint_dir=checkpoint_dir,
    #         epoch=epoch
    #         )


    print(f"Epoch {epoch} --> loss:{mean(train_loss):.4f},\
                               acc:{mean(train_acc): .4f}% ")

FinBert (training)

* Epoch 0 --> loss:0.6814,                               acc: 0.6089%
* Epoch 1 --> loss:0.6841,                               acc: 0.5881%
* Epoch 2 --> loss:0.6828,                               acc: 0.6600%
* Epoch 3 --> loss:0.6844,                               acc: 0.5881%


Distil Bert (training)

* Epoch 0 --> loss:0.6897,                               acc: 0.5028%
* Epoch 1 --> loss:0.6923,                               acc: 0.4934%
* Epoch 2 --> loss:0.6905,                               acc: 0.5038%
* Epoch 3 --> loss:0.6934,                               acc: 0.4839%
* Epoch 4 --> loss:0.6882,                               acc: 0.5350%