# ML Classifiers

**Goal:** Given a sentence as input, classify it as either a prediction or non-prediction.

In [1]:
import os
import sys
import warnings

import pandas as pd

from tqdm import tqdm

# Working directory of the notebook; later cells build the data paths from it.
notebook_dir = os.getcwd()
# Make the parent directory importable so the local modules imported below can be
# found (presumably that is where they live). The path is normalised and only
# appended when missing, so re-running this cell does not pile up duplicate
# sys.path entries.
project_root = os.path.abspath(os.path.join(notebook_dir, os.pardir))
if project_root not in sys.path:
    sys.path.append(project_root)

# import log_files
from data_processing import DataProcessing
from feature_extraction import SpacyFeatureExtraction
# from classification_models import SkLearnPerceptronModel, SkLearnSGDClassifier, EvaluationMetric
from classification_models import SkLearnModelFactory
from metrics import EvaluationMetric

In [2]:
# Display settings: show up to 800 characters per cell so long sentences are
# readable, and never truncate the number of rows pandas renders.
pd.set_option('display.max_colwidth', 800)
pd.set_option('display.max_rows', None)

# Silence every warning for the rest of the session.
warnings.filterwarnings(action='ignore')

## Load Data

In [3]:
# Stage banner: marks the start of data loading in the printed output.
print("======= LOAD DATA =======")



In [4]:
# Data locations, each derived from the previous one:
#   notebook dir -> data root -> combined phrase-bank folder -> v3 CSV file.
base_data_path = os.path.join(notebook_dir, '../data')
combine_data_path = os.path.join(
    base_data_path,
    'financial_phrase_bank/combined_generated_fin_phrase_bank',
)
data_path = os.path.join(
    combine_data_path,
    'combined_generated_fin_phrase_bank-v3.csv',
)

In [5]:
# Read the combined CSV into a DataFrame.
df = DataProcessing.load_from_file(data_path, 'csv', sep=',')
# Show a preview only: display.max_rows is set to None in the options cell, so a
# bare `df` here would render every row of the dataset into the notebook.
df.head(10)

Unnamed: 0,Base Sentence,Sentence Label,Author Type
0,JPMorgan Chase forecasts that the net profit at Amazon potentially decrease in Q3 of 2027.,1,0
1,"On August 21, 2024, Bank of America speculates the revenue at Microsoft will likely increase.",1,0
2,"Citigroup predicts on 2024-08-21, the operating income at Alphabet may rise.",1,0
3,"According to Goldman Sachs, the research and development expenses at Facebook would fall in 2025.",1,0
4,"In 21 August 2024, Morgan Stanley envisions that the gross profit at Johnson & Johnson has some probability to remain stable.",1,0
5,"The stock price at Visa should stay same in Q2 of 2026, according to Wells Fargo.",1,0
6,JPMorgan forecasts that the revenue at Microsoft potentially decrease in Q3 of 2027.,1,0
7,"On August 25, 2024, to September 25, 2025, Citigroup speculates the net profit at Johnson & Johnson will likely increase.",1,0
8,"Bank of America predicts on 2024-08-21, the operating income at Visa may rise.",1,0
9,"According to Goldman Sachs, the research and development expenses at Alphabet would fall in 2029 Q2.",1,0


In [6]:
# Quick sanity check of the loaded data: row count, shape and a few rows.
n_rows = len(df)
print(n_rows)
print(f"\tShape: {df.shape}, \nSubset of Data:{df.head(7)}")
# Cell output: the shape together with the last three rows.
df.shape, df.tail(3)

2825
	Shape: (2825, 3), 
Subset of Data:                                                                                                                   Base Sentence  \
0                                     JPMorgan Chase forecasts that the net profit at Amazon potentially decrease in Q3 of 2027.   
1                                  On August 21, 2024, Bank of America speculates the revenue at Microsoft will likely increase.   
2                                                   Citigroup predicts on 2024-08-21, the operating income at Alphabet may rise.   
3                              According to Goldman Sachs, the research and development expenses at Facebook would fall in 2025.   
4  In 21 August 2024, Morgan Stanley envisions that the gross profit at Johnson & Johnson has some probability to remain stable.   
5                                              The stock price at Visa should stay same in Q2 of 2026, according to Wells Fargo.   
6                                   

((2825, 3),
                                                                                                        Base Sentence  \
 2822  These moderate but significant changes resulted in a significant 24-32 % reduction in the estimated CVD risk .   
 2823                Uponor improved its performance in spite of the decrease in residential construction in the US .   
 2824                                                                       The inventor was issued U.S. Patent No. .   
 
       Sentence Label  Author Type  
 2822               0            1  
 2823               0            1  
 2824               0            1  )

## Shuffle Data

In [7]:
# First three rows in their original (pre-shuffle) order, for comparison with the shuffled frame.
df.head(3)

Unnamed: 0,Base Sentence,Sentence Label,Author Type
0,JPMorgan Chase forecasts that the net profit at Amazon potentially decrease in Q3 of 2027.,1,0
1,"On August 21, 2024, Bank of America speculates the revenue at Microsoft will likely increase.",1,0
2,"Citigroup predicts on 2024-08-21, the operating income at Alphabet may rise.",1,0


In [8]:
# Stage banner: marks the start of shuffling in the printed output.
print("======= SHUFFLE DATA =======")



In [9]:
# Shuffle the rows and keep the result under a new name; `df` itself is not reassigned here.
# NOTE(review): no seed is passed. The note in the Evaluation section says results
# differ between runs because of shuffling — check whether shuffle_df accepts a
# seed/random_state so runs can be reproduced.
shuffled_df = DataProcessing.shuffle_df(df)
print(f"\tShape: {shuffled_df.shape}, \nSubset of Data:{shuffled_df.head(7)}")

	Shape: (2825, 3), 
Subset of Data:                                                                                                                                            Base Sentence  \
0  Net income from life insurance rose to EUR 16.5 mn from EUR 14.0 mn , and net income from non-life insurance to EUR 22.6 mn from EUR 15.2 mn in 2009 .   
1                                      According to the financial senior level person at JPMorgan Chase, the revenue at Alphabet had risen in Q1 of 2029.   
2                                       On 11th of October 2025, the World Health Organization monitored the obesity rates at rural high schools changed.   
3                                                                      To be number one means creating added value for stakeholders in everything we do .   
4                                                                                    Dr. Smith noted on 09/15/2024, the prevalence of heart disease fell.   
5                      

## Extract Sentence Embeddings

In [10]:
# Stage banner: marks the start of sentence embedding in the printed output.
print("======= EMBED SENTENCES: Spacy =======")



In [11]:
# Feature extractor over the shuffled frame; 'Base Sentence' names the text column to embed.
spacy_fe = SpacyFeatureExtraction(shuffled_df, 'Base Sentence')
# Echo the object (shows only its repr).
spacy_fe

<feature_extraction.SpacyFeatureExtraction at 0x346f4d710>

In [12]:
# Compute one embedding per sentence. Later cells read the vectors from an
# 'Embedding' column of the returned frame (presumably added because
# attach_to_df=True — confirm against SpacyFeatureExtraction).
spacy_sentence_embeddings_df = spacy_fe.sentence_feature_extraction(attach_to_df=True)
# print(f"{spacy_sentence_embeddings_df.head(3)}")

100%|██████████| 2825/2825 [00:09<00:00, 293.51it/s]


## Normalize Embeddings

- Why: without scaling, training on the raw embeddings produced the warnings below
    1. sklearn/utils/extmath.py:203: RuntimeWarning: divide by zero encountered in matmul ret = a @ b
    2. sklearn/utils/extmath.py:203: RuntimeWarning: overflow encountered in matmul ret = a @ b
    3. sklearn/utils/extmath.py:203: RuntimeWarning: invalid value encountered in matmul ret = a @ b

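- Standardizing (via `StandardScaler`) rescales every embedding dimension to zero mean and unit variance, so all features are on one comparable scale; this should keep the matrix products inside the models from overflowing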

In [13]:
# Stage banner: marks the start of embedding normalization in the printed output.
print("======= NORMALIZE EMBEDDINGS =======")



In [14]:
from sklearn.preprocessing import StandardScaler

# Stack the per-sentence vectors into a table: one row per sentence, one column
# per embedding dimension.
embedding_vectors = spacy_sentence_embeddings_df["Embedding"].tolist()
embeddings_matrix = pd.DataFrame(embedding_vectors)

# Standardize each dimension (column) independently.
# NOTE(review): the scaler is fitted on every row, before the train/test split
# done later in the notebook, so test rows influence the scaling statistics.
# Consider fitting on the training rows only.
scaler = StandardScaler()
scaled_embeddings = scaler.fit_transform(embeddings_matrix)

# Store the scaled vectors next to the raw ones, one array per row.
spacy_sentence_embeddings_df['Normalized Embeddings'] = [vector for vector in scaled_embeddings]

In [15]:
# Print sentence, label and the raw vs. normalized vectors for the rows whose
# index is below 7. The rows are selected *before* iterating, so the loop no
# longer walks the entire frame with iterrows() just to print a handful of entries.
preview_rows = spacy_sentence_embeddings_df.loc[spacy_sentence_embeddings_df.index < 7]

for idx, row in preview_rows.iterrows():
    text = row['Base Sentence']
    label = row['Sentence Label']
    embedding = row['Embedding']
    norm_embedding = row['Normalized Embeddings']
    print(f"{idx}\n Sentence: {text}\n Label: {label}\n Embeddings Shape: {embedding.shape}\n\t Embeddings Subset [:6]: {embedding[:6]} \n Norm Embeddings: {norm_embedding.shape}, \n\tNorm Embeddings Subset [:6]: {norm_embedding[:6]}")

0
 Sentence: Net income from life insurance rose to EUR 16.5 mn from EUR 14.0 mn , and net income from non-life insurance to EUR 22.6 mn from EUR 15.2 mn in 2009 .
 Label: 0
 Embeddings Shape: (300,)
	 Embeddings Subset [:6]: [-0.22492191  0.33076787  0.00159167  0.02455215  0.24869722 -0.12270349] 
 Norm Embeddings: (300,), 
	Norm Embeddings Subset [:6]: [-1.6847831   1.3524358  -0.06767292  0.96886927  2.4268584  -1.2660933 ]
1
 Sentence: According to the financial senior level person at JPMorgan Chase, the revenue at Alphabet had risen in Q1 of 2029.
 Label: 0
 Embeddings Shape: (300,)
	 Embeddings Subset [:6]: [-0.14797592  0.18903823  0.02756777  0.00196567  0.06092546  0.05799879] 
 Norm Embeddings: (300,), 
	Norm Embeddings Subset [:6]: [-0.753743   -0.5607003   0.32978797  0.60701054  0.11148097  1.2488207 ]
2
 Sentence: On 11th of October 2025, the World Health Organization monitored the obesity rates at rural high schools changed.
 Label: 0
 Embeddings Shape: (300,)
	 Embeddi

In [16]:
# Column that the training and prediction cells read their feature vectors from.
embeddings_col_name = 'Normalized Embeddings'

## Split Data


> **Stratification preserves the original dataset ratio when splitting into train-test splits.**  
> **Example**: If we have 1,000 samples: 920 non-predictions, 80 predictions  
> - **Train size = 0.8 → 800 samples in train**  
>     - **With stratify**: train ≈ 736 non-predictions, 64 predictions  
>     - **Without stratify**: train could randomly have 797 non-predictions and 3 predictions (or worse)  
>     - **Without stratify**: train could randomly have 720 non-predictions and all 80 predictions, leaving no predictions for the test set  
> - **Test size = 0.2 → 200 samples in test**  
>     - **With stratify**: test ≈ 184 non-predictions, 16 predictions  
>     - **Without stratify**: test could randomly have 198 non-predictions and 2 predictions (or worse)  


In [17]:
# Stage banner: marks the start of the train/test split in the printed output.
print("======= SPLIT DATA =======")



In [18]:
# Peek at the frame that goes into the split: text, both label columns, raw and normalized embeddings.
spacy_sentence_embeddings_df.head(3)

Unnamed: 0,Base Sentence,Sentence Label,Author Type,Embedding,Normalized Embeddings
0,"Net income from life insurance rose to EUR 16.5 mn from EUR 14.0 mn , and net income from non-life insurance to EUR 22.6 mn from EUR 15.2 mn in 2009 .",0,1,"[-0.22492191, 0.33076787, 0.0015916728, 0.024552153, 0.24869722, -0.12270349, 0.077695, 0.036100093, 0.17108163, 1.3147347, -0.44673762, 0.13499488, -0.09395125, 0.022185089, 0.10875409, 0.061217517, 0.03683422, 1.1383543, -0.05511938, 0.033903327, 0.035543617, -0.053965002, -0.092551306, -0.051658418, 0.11317858, 0.07838255, -0.10281078, -0.0828436, -0.08695139, 0.16029881, 0.03223964, -0.046120018, -0.12969352, 0.04341615, 0.021033173, -0.009246599, -0.16340375, -0.09167824, -0.018575398, 0.13226306, 0.017988915, 0.051376373, 0.15791693, -0.10549688, -0.07, 0.111039765, 0.15711744, 0.0023394853, 0.12900947, 0.01061755, -0.12041667, -0.042009674, -0.09062783, -0.08716391, 0.14827925, -0.10065144, -0.048354708, 0.044614907, -0.06126854, -0.093334354, -0.11243491, -0.107763015, -0.26356...","[-1.6847831, 1.3524358, -0.06767292, 0.96886927, 2.4268584, -1.2660933, 1.5015978, 0.6130948, 1.875508, -1.563784, -1.9667782, 1.3606664, -2.5874863, 0.6566864, 1.5288833, 1.6460268, 0.77185035, 0.25685945, 0.9508728, 1.0484687, 0.20961311, -1.4369309, -1.2569767, 0.38518983, 1.2346611, 0.22325297, 0.3019607, -1.615992, -1.5019004, 1.2467152, 0.5224793, -0.69948846, -2.0667944, -0.527591, -0.31377298, 0.21632737, -2.021523, -2.089863, -0.021239078, 2.0683947, 0.167936, 0.08498458, 1.0196956, -1.0416915, -1.3515515, 1.7234211, 3.0996788, 0.5156986, 1.2091638, 0.01454015, -1.8745446, -1.008702, -0.65537107, -0.3519377, 1.4470301, -0.9149368, -0.85097164, 1.0063423, -0.95718586, 0.3945605, -1.3345459, -0.8587651, -2.2379487, -1.3992167, -0.76099265, -0.51154155, 1.9765085, -1.163064, 0.48..."
1,"According to the financial senior level person at JPMorgan Chase, the revenue at Alphabet had risen in Q1 of 2029.",0,0,"[-0.14797592, 0.18903823, 0.027567772, 0.0019656746, 0.060925458, 0.05799879, -0.024907697, 0.13308369, 0.11494873, 2.010199, -0.25961187, 0.0109938225, 0.1105645, 0.07865309, -0.029267984, -0.02273451, -0.024475409, 0.9206356, -0.111091964, 0.0010093455, 0.04381073, 0.10480573, -0.018486956, -0.061253823, 0.07465677, 0.12606715, -0.060426887, 0.029329566, -0.025426142, 0.06758911, 0.057725355, 0.08223564, 0.025920369, 0.18043904, 0.06054687, -0.044482682, -0.026366549, 0.08514258, 0.044186402, 0.010201949, 0.070893876, -0.0047824113, -0.0060838233, -0.093484946, 0.10798281, 0.02062661, -0.0039412733, -0.035238508, -0.013841703, 0.023276232, -0.04481264, -0.020503543, -0.043965098, 0.080525406, -0.011630089, -0.025144821, 0.035508182, -0.07530105, -0.046570312, -0.12124377, 0.03996019,...","[-0.753743, -0.5607003, 0.32978797, 0.60701054, 0.111480966, 1.2488207, -0.25328374, 1.7680104, 1.1024882, 0.7523284, 0.2574567, -0.22799437, 0.5773869, 1.4955589, 0.15318674, 0.24765712, -0.24251279, -1.034374, 0.10102975, 0.5221763, 0.34579566, 0.78445023, -0.17936698, 0.2528741, 0.6491704, 0.89950097, 0.87393886, 0.07694304, -0.5634217, -0.051457874, 0.9559546, 1.1519892, 0.42239207, 1.3638468, 0.29068092, -0.3707338, -0.13043657, 0.6421685, 0.8772152, 0.45633778, 1.0627156, -0.7821341, -1.0406907, -0.85462475, 1.1389904, 0.23680475, 0.83151436, 0.074242115, -0.71656907, 0.18575948, -0.75267917, -0.68441117, 0.13332179, 1.667282, -0.59339595, 0.29462445, 0.4232794, -0.65243214, -0.7144797, -0.05391862, 1.2617244, -0.47301805, 0.69809985, -0.36569864, -0.68550366, 0.779293, 0.7943997..."
2,"On 11th of October 2025, the World Health Organization monitored the obesity rates at rural high schools changed.",0,0,"[-0.076041564, 0.3122237, 0.026510298, -0.0717942, -0.086790755, -0.010789577, -0.020441111, 0.030986553, 0.026818525, 2.277497, -0.40642172, -0.021653151, 0.09790225, 0.04019418, -0.040514797, -0.040429495, 0.0054392503, 1.2545536, -0.13120344, -0.09417785, -0.0004992186, -0.010959702, 0.094102845, -0.064638995, 0.043657802, -0.00772529, -0.1283721, -0.024550527, -0.054511957, 0.15769775, 0.03518964, 0.011812801, 0.0805304, -0.041667435, 0.051297754, -0.048048753, -0.076382495, 0.114326835, -0.03039441, 0.19640622, -0.010173494, 0.088735, -0.0689217, -0.015006678, 0.11604055, -0.023721, -0.108259395, 0.0055881417, 0.024729151, -0.030221347, -0.08574061, 0.06973127, -0.050878346, -0.078708425, 0.035788704, -0.048028205, 0.15204714, 0.0071707987, 0.12800325, -0.1753889, -0.05890914, -0....","[0.11665679, 1.1021175, 0.31360754, -0.5746989, -1.709979, 0.29146248, -0.17688876, 0.5522009, -0.11117445, 1.6425138, -1.4875709, -0.6462566, 0.38143912, 0.92422426, 0.041087262, -0.04708421, 0.25242296, 0.9460072, -0.2043269, -1.000786, -0.38411364, -0.83523834, 1.4587744, 0.20619436, 0.17801897, -0.9979014, -0.042993914, -0.73622376, -1.0070838, 1.2102938, 0.57265455, 0.1361687, 1.2959294, -1.7020686, 0.14919417, -0.43014738, -0.8206468, 1.0930899, -0.19043183, 2.915531, -0.30837375, 0.6618198, -1.83014, 0.36754954, 1.2517437, -0.49237984, -0.6375812, 0.55386305, -0.1966072, -0.53784037, -1.3599969, 0.67623925, 0.016474152, -0.25012156, 0.011662799, -0.07195073, 2.1940258, 0.4883851, 2.1681864, -0.9239821, -0.4226573, -0.62790334, 0.5726152, -1.0644513, 0.01750101, 1.6099285, 0.0101..."


In [19]:
# The two target columns handed to the split as labels.
label_columns = ['Sentence Label', 'Author Type']
cols_with_labels = spacy_sentence_embeddings_df.loc[:, label_columns]
cols_with_labels.head(3)

Unnamed: 0,Sentence Label,Author Type
0,0,1
1,0,0
2,0,0


In [21]:
# Split features and labels into train/test parts, stratifying on the sentence
# label so both parts keep the same prediction / non-prediction ratio.
splits = DataProcessing.split_data(
    spacy_sentence_embeddings_df,
    cols_with_labels,
    stratify=True,
    stratify_by='Sentence Label',
)
splits

[                                                                                                                                                                                                                                                                                                             Base Sentence  \
 54                                                                                                                                                                              According to Atria 's President and CEO Matti Tikkakoski , the company 's Swedish operations significantly improved in the first quarter .   
 2090                                                                                                            The tower it chose to add , due to go into operation in the summer , will increase total capacity to 80 broadsheet or 160 tabloid pages , with the ability to run 32 - and 48-page products side by side .   
 1677                                      

In [23]:
# Unpack the six parts of the split: features, sentence labels and author labels,
# each as a train/test pair.
X_train_df, X_test_df, y_train_sentence_df, y_test_sentence_df, y_train_author_df, y_test_author_df = splits
# The cells below (saving, training, evaluation) use `y_train_df` / `y_test_df`,
# which no earlier cell assigns, so they fail with NameError on a fresh kernel.
# This notebook classifies sentences as prediction / non-prediction, so bind
# those names to the sentence-label parts.
y_train_df, y_test_df = y_train_sentence_df, y_test_sentence_df
X_train_df.head(3)

Unnamed: 0,Base Sentence,Sentence Label,Author Type,Embedding,Normalized Embeddings
54,"According to Atria 's President and CEO Matti Tikkakoski , the company 's Swedish operations significantly improved in the first quarter .",0,1,"[0.041027006, 0.1820981, -0.020990776, -0.017873267, 0.02760573, -0.05120162, 0.037013326, -0.092839114, -0.009454541, 1.8566082, -0.21250282, -0.073992275, 0.13637276, -0.026959503, 0.0036987744, -0.12699701, 0.054408632, 0.86367, -0.11449091, -0.087293826, 0.051253065, 0.060785774, 0.0076474054, -0.2108174, 0.016746499, 0.09260433, -0.15951684, 0.030492347, -0.013935759, -0.07429332, 0.0011344169, -0.001241451, -0.010624001, 0.1094713, 0.0713425, 0.037630994, 0.025294207, -0.009313684, -0.021804791, -0.123197086, 0.0461751, 0.06474435, 0.01552536, -0.073185295, -0.015361861, -0.10975688, 0.039395545, 0.003920167, 0.109445505, 0.115822986, 0.11245847, -0.008725389, -0.047884904, 0.032616522, 0.0598614, -0.0945988, 0.013164907, 0.022863908, -0.015512865, -0.15775046, -0.017384807, 0.01...","[1.5331769, -0.6543815, -0.41320747, 0.28917015, -0.2993782, -0.27096987, 0.80579233, -0.9223599, -0.6106997, 0.24082291, 0.8174096, -1.3168082, 0.97676873, -0.07339154, 0.48177373, -1.4890194, 1.0626177, -1.3722222, 0.049422853, -0.89064395, 0.46839184, 0.16856156, 0.20087865, -1.8095307, -0.2310043, 0.4249418, -0.4632977, 0.0944919, -0.3881526, -2.038175, -0.006575713, -0.05213353, -0.16216853, 0.38422164, 0.45582512, 0.9973453, 0.58247167, -0.8172608, -0.067468815, -1.3054585, 0.6446479, 0.29139262, -0.76920855, -0.53848994, -0.5869913, -1.9070227, 1.4418199, 0.5342681, 0.945428, 1.4375323, 1.5810192, -0.50680846, 0.06706928, 1.0903904, 0.3188278, -0.81797785, 0.08378548, 0.7054649, -0.20164002, -0.64054817, 0.28476956, 0.8700376, -0.15381446, -0.16897859, 0.4016923, -0.8785516, 0.3..."
2090,"The tower it chose to add , due to go into operation in the summer , will increase total capacity to 80 broadsheet or 160 tabloid pages , with the ability to run 32 - and 48-page products side by side .",0,1,"[0.022676554, 0.27905893, -0.005440206, 0.04101572, 0.028533207, -0.022087641, -0.10377527, 0.017977975, -0.037837636, 1.8846042, -0.15306629, 0.07733075, 0.11042544, 0.0016491526, -0.12795655, -0.022667203, -0.06467701, 1.4026767, -0.13643815, -0.061641958, -0.047967315, 0.048904773, -0.040466562, -0.009019414, 0.086929806, 0.008409476, -0.054103896, -0.03744836, -0.026812427, -0.009389524, -0.04112014, 0.1139884, -0.11598987, -0.0002831363, 0.07660422, -0.1346757, -0.09740868, 0.0975301, 0.061730504, -0.041590095, -0.070093796, 0.08762044, 0.018985458, -0.14367935, -0.062360097, 0.041349366, -0.07251717, 0.036065783, -0.019729113, 0.016377065, -0.01018686, 0.038048115, 0.089512795, -0.009121782, 0.019252788, 0.009468689, 0.045341544, -0.07612066, -0.029661197, -0.036368743, 0.0429762...","[1.3111378, 0.65444326, -0.17526786, 1.2326326, -0.2879416, 0.13422228, -1.6022079, 0.39729008, -1.0015702, 0.3340585, 1.5238906, 0.6218925, 0.57523495, 0.35161042, -0.8304641, 0.24877822, -0.9076452, 1.8244872, -0.28380644, -0.48022288, -1.1660464, 0.002332985, -0.4991624, 0.9731598, 0.83570755, -0.7690832, 0.9592687, -0.9308799, -0.5845675, -1.1293558, -0.72526485, 1.6100098, -1.8475919, -1.1308076, 0.5363154, -1.8734211, -1.1108041, 0.8335661, 1.1283643, -0.22767793, -1.3218036, 0.6446104, -0.7257384, -1.6363227, -1.2446452, 0.5775382, -0.13422915, 0.9119065, -0.7959354, 0.092442445, -0.23887774, 0.19848904, 2.3893635, 0.5878014, -0.19933371, 0.84910774, 0.57269204, -0.6637696, -0.435266, 1.3099465, 1.3131078, 1.3458992, -0.08271862, 0.75795156, 0.83473533, 0.14551637, -1.310539, -0..."
1677,"In Q3 2028, Professor Chen envisioned that her patients' body mass index at the local clinic decreased.",0,0,"[-0.061861046, 0.23680219, -0.0379443, 0.058220647, 0.078254, -0.025313035, 0.031442486, -0.0184335, 0.03042985, 1.8410904, -0.20648515, 0.08007239, 0.18419315, -0.11648165, -0.033785056, 0.017133156, 0.029407453, 1.1232946, -0.17044091, -0.11203464, 0.003896701, 0.088495836, -0.12435456, -0.123019256, 0.08988704, 0.08324601, -0.004029304, 0.051926084, 0.09357063, 0.012874378, -0.036120635, 0.010841602, -0.001974003, 0.04227495, 0.1447469, -0.061285317, 0.041888863, 0.09170372, -0.042803027, -0.018771304, 0.03643465, 0.079883896, 0.0161318, -0.08886372, 0.009178597, 0.05170872, -0.023176154, -0.039468963, 0.09726059, -0.10507055, -0.08307594, 0.016027248, -0.033854246, -0.10610999, 0.050839346, -0.014455145, 0.0532704, -0.12634611, -0.07461981, -0.16359253, -0.112019554, -0.07795816, 0...","[0.2882399, 0.08404104, -0.67261374, 1.5082734, 0.32515588, 0.08933305, 0.7105106, -0.03631092, -0.06144202, 0.1891441, 0.8889375, 0.6570174, 1.7167873, -1.4033067, 0.10816407, 0.91172373, 0.648975, 0.16754422, -0.8000774, -1.2864884, -0.31170046, 0.5562565, -1.7197028, -0.5988394, 0.88065434, 0.29222497, 1.6350342, 0.41797322, 1.2517016, -0.8176042, -0.6402305, 0.12215956, -0.023803912, -0.5433439, 1.5787162, -0.65067905, 0.8114747, 0.7435437, -0.36806497, 0.07368905, 0.47990805, 0.5251545, -0.7615897, -0.78265655, -0.24359277, 0.74787164, 0.5606326, 0.02454383, 0.7811668, -1.550239, -1.3204567, -0.13356352, 0.30421528, -0.5800757, 0.20370738, 0.46586528, 0.6931667, -1.3585286, -1.1776502, -0.73442507, -1.3274697, -0.4216261, 0.7265246, -0.51856166, 0.1758846, -0.014314739, 1.0441087,..."


In [None]:
# NOTE(review): `sentence_labels_col` is not assigned by any cell above in this
# notebook. On a fresh kernel this cell raises NameError unless the name is
# defined somewhere not shown here — confirm, or remove the cell.
sentence_labels_col

In [None]:
# Toggle: set to False to skip writing the test split to disk.
save_df = True

if save_df:
    print("Save test set so we can pass these into LLMs")
    # Features and labels go to separate files in the combined-data folder.
    DataProcessing.save_to_file(X_test_df, combine_data_path, 'x_test_set', 'csv')
    DataProcessing.save_to_file(y_test_df, combine_data_path, 'y_test_set', 'csv')

In [None]:
# Number of training labels (should equal the number of training rows).
len(y_train_df)

In [None]:
# Peek at the feature vectors in the list form the models receive. Only the
# first three are shown; listing every training vector floods the notebook
# with output.
X_train_df[embeddings_col_name].head(3).to_list()

## Models

In [None]:
# Stage banner: marks the start of model training/testing in the printed output.
print("======= TRAIN x TEST MODELS =======")

> Track loss: try BCE (Binary Cross Entropy)

In [None]:
sklmf = SkLearnModelFactory

# Keys passed to the factory, in the order the models are trained and evaluated.
# 'linear_regression' and 'elastic_net' are currently not selected.
model_keys = [
    'perceptron',
    'sgd_classifier',
    'logistic_regression',
    'ridge_classifier',
    'decision_tree_classifier',
    'random_forest_classifier',
    'gradient_boosting_classifier',
]
ml_models = [sklmf.select_model(model_key) for model_key in model_keys]

# Keep each model reachable under its existing individual name as well.
(
    perception_model,
    sgd_classifier_model,
    logistic_regression_model,
    ridge_classifier_model,
    decision_tree_classifier_model,
    random_forest_classifier_model,
    gradient_boosting_classifier_model,
) = ml_models

In [None]:
# Convert the feature columns to lists once. The conversion does not depend on
# the model, so it is hoisted out of the loop instead of being repeated per model.
train_features = X_train_df[embeddings_col_name].to_list()
test_features = X_test_df[embeddings_col_name].to_list()

# Fit each model on the training split and collect its test-set predictions,
# keyed by the model's name.
models_with_predictions = {}
for ml_model in ml_models:
    model_name = ml_model.get_model_name()
    print(f"Train -> Predict for {model_name}")
    ml_model.train_model(train_features, y_train_df)
    ml_model_predictions = ml_model.predict(test_features)
    models_with_predictions[model_name] = ml_model_predictions

models_with_predictions

In [None]:
# models_predictions_df = pd.DataFrame(models_to_predictions)
# models_predictions_df

In [None]:
# Name the true labels 'Actual Label'; the next cells concatenate them onto the
# test features and read them back through a column of that name.
# NOTE(review): a scalar `index` renames a Series itself, so this assumes
# y_test_df is a Series — TODO confirm.
y_test_df.rename(index='Actual Label', inplace=True)

In [None]:
# Side-by-side table: test rows, their true labels, then one column of
# predictions per model (column name = model name).
test_and_models_df = pd.concat([X_test_df, y_test_df], axis=1)

for model_name, predictions in models_with_predictions.items():
    # Flatten to 1-D so the predictions line up row by row with the frame.
    test_and_models_df[model_name] = predictions.to_numpy().ravel()

test_and_models_df.head(3)

In [None]:
# Shape of the subset of test rows whose 'Sentence Label' is 0.
test_and_models_df[(test_and_models_df['Sentence Label'] == 0)].shape

## Save Output

In [None]:
# Persist the combined table of test rows, true labels and per-model predictions.
# The format is passed as 'csv' (no leading dot), matching every other
# load_from_file / save_to_file call in this notebook.
DataProcessing.save_to_file(test_and_models_df, combine_data_path, 'ml_classifiers', 'csv')

## Evaluation

In [None]:
# Stage banner: marks the start of evaluation in the printed output.
print("======= EVALUATION/RESULTS =======")

In [None]:
# Metric helper used by the evaluation loop below.
get_metrics = EvaluationMetric()
# Echo the object (shows only its repr).
get_metrics

> - Results may differ (from previous runs and even terminal runs) because we shuffle the data.

![image.png](attachment:image.png)

In [None]:
# Build one classification report per model, keyed by model name.
eval_reports = {}
actual_label = test_and_models_df['Actual Label'].values

for ml_model in ml_models:
    ml_model_name = ml_model.get_model_name()
    # Print the true labels above each model's predictions for a visual comparison.
    print(f"Actual Label:\t\t{actual_label}")
    ml_model_predictions = test_and_models_df[ml_model_name].values
    print(f"{ml_model_name}:\t\t{ml_model_predictions}")
    print()
    eval_reports[ml_model_name] = get_metrics.eval_classification_report(
        y_test_df, ml_model_predictions
    )

In [None]:
# Show the collected reports (a dict keyed by model name).
eval_reports

In [None]:
# Tabulate the reports; with a dict input the model names become the columns
# (assuming each report is itself dict-like — TODO confirm).
eval_reports_df = pd.DataFrame(eval_reports)
# The cell output is the LaTeX source of the table, as returned by to_latex().
eval_reports_df.to_latex()