# NHTSA Complaints
https://www.nhtsa.gov/nhtsa-datasets-and-apis
Complaint information entered into NHTSA’s Office of Defects Investigation vehicle owner's complaint database is used with other data sources to identify safety issues that warrant investigation and to determine if a safety-related defect trend exists. Complaint information is also analyzed to monitor existing recalls for proper scope and adequacy.

Overview: Predict the primary failed component from the text of a customer complaint.

Technical: This is a text classification task.

Steps:
- Produce an embedding (vector representation) for each customer comment
- Use the vector elements as features for a classification model (logistic regression, etc.)

Embedding approaches:
- GloVe
- Sentence Transformer

Classification approach:
- Logistic regression
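
To make the two steps concrete, the sketch below shows how a multinomial logistic model turns an embedding vector into class probabilities: one linear score per class, followed by a softmax. The three-dimensional embedding, the weights, the biases, and the two class names are made-up illustrative values, not learned parameters from this notebook.

```python
import math

def softmax(scores):
    """Convert raw class scores into probabilities that sum to 1."""
    shift = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def predict_proba(embedding, weights, biases):
    """Score each class as dot(weights[c], embedding) + biases[c], then apply softmax."""
    scores = [
        sum(w * x for w, x in zip(class_weights, embedding)) + b
        for class_weights, b in zip(weights, biases)
    ]
    return softmax(scores)

# hypothetical values for illustration only
classes = ["ENGINE", "STEERING"]
embedding = [0.2, -0.1, 0.4]
weights = [[1.0, 0.0, 2.0], [-1.0, 1.0, 0.0]]
biases = [0.0, 0.1]

probs = predict_proba(embedding, weights, biases)
predicted = classes[probs.index(max(probs))]
print(predicted, [round(p, 3) for p in probs])
```

In the notebook itself, scikit-learn's `LogisticRegression` learns the weights and biases from the training embeddings.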

## Environment

**NOTE**: Use a GPU instance; otherwise the sentence-transformer embedding computations will be very slow.

In [1]:
# mount google drive
from google.colab import drive
import os
drive.mount('/content/drive')

# change dir
cur_path = '/content/drive/MyDrive/transformers/transformer_presentation'
os.chdir(cur_path)
!pwd

Mounted at /content/drive
/content/drive/.shortcut-targets-by-id/1vfLs8lJDQqBTM3iH8G_76AWSt_g_WCC7/Transformers/transformer_presentation


In [2]:
! pip install sentence_transformers tqdm -qqq



In [3]:
import pandas as pd
import pyarrow as pa
from sentence_transformers import SentenceTransformer
import torch

## Data Prep

In [None]:
! pwd

/content/drive/.shortcut-targets-by-id/1vfLs8lJDQqBTM3iH8G_76AWSt_g_WCC7/Transformers/transformer_presentation


In [None]:
# download data
! wget https://static.nhtsa.gov/odi/ffdd/cmpl/COMPLAINTS_RECEIVED_2020-2023.zip

--2023-03-20 00:02:22--  https://static.nhtsa.gov/odi/ffdd/cmpl/COMPLAINTS_RECEIVED_2020-2023.zip
Resolving static.nhtsa.gov (static.nhtsa.gov)... 104.69.162.16, 2600:1413:1:59c::27ea, 2600:1413:1:592::27ea
Connecting to static.nhtsa.gov (static.nhtsa.gov)|104.69.162.16|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [application/octet-stream]
Saving to: ‘COMPLAINTS_RECEIVED_2020-2023.zip’

COMPLAINTS_RECEIVED     [                <=> ]  42.43M  13.8MB/s    in 3.7s    

2023-03-20 00:02:27 (11.5 MB/s) - ‘COMPLAINTS_RECEIVED_2020-2023.zip’ saved [44494113]



In [19]:
df = pd.read_csv('COMPLAINTS_RECEIVED_2020-2023.zip', sep='\t', low_memory=False)
print(df.shape)

(246876, 49)


In [20]:
# columns in appendix a: https://static.nhtsa.gov/odi/ffdd/cmpl/Import_Instructions_Excel_5-year.pdf
df.columns = ['CMPLID','ODINO','MFR_NAME','MAKETXT','MODELTXT','YEARTXT','CRASH'
              ,'FAILDATE','FIRE','INJURED','DEATHS','COMPDESC','CITY','STATE','VIN'
              ,'DATEA','LDATE','MILES','OCCURENCES','CDESCR','CMPL_TYPE'
              ,'POLICE_RPT_YN','PURCH_DT','ORIG_OWNER_YN','ANTI_BRAKES_YN'
              ,'CRUISE_CONT_YN','NUM_CYLS','DRIVE_TRAIN','FUEL_SYS','FUEL_TYPE'
              ,'TRANS_TYPE','VEH_SPEED','DOT','TIRE_SIZE','LOC_OF_TIRE','TIRE_FAIL_TYPE'
              ,'ORIG_EQUIP_YN','MANUF_DT','SEAT_TYPE','RESTRAINT_TYPE','DEALER_NAME'
              ,'DEALER_TEL','DEALER_CITY','DEALER_STATE','DEALER_ZIP','PROD_TYPE'
              ,'REPAIRED_YN','MEDICAL_ATTN','VEHICLES_TOWED_YN']

In [21]:
selected_cols = ['CMPLID','MAKETXT','MODELTXT','YEARTXT','FAILDATE','CMPL_TYPE','COMPDESC','CDESCR']
df = df[selected_cols]
df.columns = ['id','make','model','year','event_date','complaint_type','vehicle_component','complaint']

In [22]:
# convert the column of string dates to datetime format
df['event_date'] = pd.to_datetime(df['event_date'], format='%Y%m%d')

Select complaints for components with more than 10,000 complaints, excluding the catch-all "UNKNOWN OR OTHER" category.

In [23]:
pd.set_option("display.max_rows", 25)
df['vehicle_component'].value_counts().head(25)

ENGINE                                                      34361
ELECTRICAL SYSTEM                                           29209
UNKNOWN OR OTHER                                            24780
POWER TRAIN                                                 22347
STEERING                                                    17760
SERVICE BRAKES                                              14681
AIR BAGS                                                    12180
FUEL/PROPULSION SYSTEM                                      11321
STRUCTURE:BODY                                               8302
SUSPENSION                                                   8093
VEHICLE SPEED CONTROL                                        6631
VISIBILITY/WIPER                                             6538
EXTERIOR LIGHTING                                            6134
FORWARD COLLISION AVOIDANCE: AUTOMATIC EMERGENCY BRAKING     4210
WHEELS                                                       2862
SEATS     

In [24]:
top_components = list(df['vehicle_component'].value_counts().head(8).index)
top_components.remove("UNKNOWN OR OTHER")
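
The cell above keeps the eight most frequent components and then drops the catch-all category. The stated rule (more than 10,000 complaints) can also be written as an explicit threshold, sketched here in plain Python with counts copied from the output above. The names `MIN_COMPLAINTS` and `EXCLUDED` are illustrative assumptions.

```python
# complaint counts copied from the value_counts() output above
component_counts = {
    "ENGINE": 34361,
    "ELECTRICAL SYSTEM": 29209,
    "UNKNOWN OR OTHER": 24780,
    "POWER TRAIN": 22347,
    "STEERING": 17760,
    "SERVICE BRAKES": 14681,
    "AIR BAGS": 12180,
    "FUEL/PROPULSION SYSTEM": 11321,
    "STRUCTURE:BODY": 8302,
    "SUSPENSION": 8093,
}

MIN_COMPLAINTS = 10_000            # hypothetical name for the threshold
EXCLUDED = {"UNKNOWN OR OTHER"}    # catch-all label, dropped as in the cell above

selected = [
    name
    for name, count in component_counts.items()
    if count > MIN_COMPLAINTS and name not in EXCLUDED
]
print(len(selected), selected)
```

On the DataFrame, the same filter could be expressed as `counts = df['vehicle_component'].value_counts()` followed by `counts[counts > 10_000].index`.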

In [25]:
df = df[df['vehicle_component'].isin(top_components)]
df.shape

(141859, 8)

In [26]:
balanced_df = df.groupby('vehicle_component').apply(lambda x: x.sample(n=10000, replace=False)).reset_index(drop=True)

In [27]:
balanced_df['vehicle_component'].value_counts()

AIR BAGS                  10000
ELECTRICAL SYSTEM         10000
ENGINE                    10000
FUEL/PROPULSION SYSTEM    10000
POWER TRAIN               10000
SERVICE BRAKES            10000
STEERING                  10000
Name: vehicle_component, dtype: int64
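
The sampling call above draws a different subset on every run because no seed is set. The sketch below illustrates the same per-class down-sampling idea in plain Python with a fixed seed. The toy records, the `per_class` size, and the seed value are assumptions for illustration.

```python
import random
from collections import defaultdict

def balanced_sample(records, label_key, per_class, seed=42):
    """Draw `per_class` records without replacement from every label group."""
    rng = random.Random(seed)  # a dedicated generator keeps the draw reproducible
    groups = defaultdict(list)
    for record in records:
        groups[record[label_key]].append(record)

    sample = []
    for label in sorted(groups):  # sorted so the output order is deterministic
        sample.extend(rng.sample(groups[label], per_class))
    return sample

# toy data: 5 ENGINE complaints and 3 STEERING complaints
records = (
    [{"id": i, "vehicle_component": "ENGINE"} for i in range(5)]
    + [{"id": 100 + i, "vehicle_component": "STEERING"} for i in range(3)]
)

first = balanced_sample(records, "vehicle_component", per_class=2)
second = balanced_sample(records, "vehicle_component", per_class=2)
print(first == second)  # the same seed gives the same sample
```

In pandas, passing `random_state=...` to `DataFrame.sample` serves the same purpose.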

In [28]:
df = balanced_df

In [29]:
# quick look at a sample of complaints
pd.set_option('max_colwidth', 500)
df.sample(5).T

Unnamed: 0,12716,18843,67367,20128,64417
id,1666084,1832859,1742139,1817255,1795724
make,HONDA,HONDA,FORD,CHEVROLET,FORD
model,CR-V,ODYSSEY,FUSION,SUBURBAN,F-150
year,2019.0,2019.0,2012.0,2017.0,2016.0
event_date,2020-05-30 00:00:00,2022-02-14 00:00:00,2021-04-20 00:00:00,2022-02-08 00:00:00,2022-02-04 00:00:00
complaint_type,IVOQ,IVOQ,IVOQ,IVOQ,EVOQ
vehicle_component,ELECTRICAL SYSTEM,ELECTRICAL SYSTEM,STEERING,ENGINE,STEERING
complaint,"WHILE DRIVING ON A HIGHWAY THE INFOTAINMENT CENTER SUDDENLY BEGAN BEEPING AND TURNING ON AND OFF REPEATEDLY. THE BRIGHTNESS CONTROL BEGAN TO BEEP AND TURN TO A HIGHER SETTING AND THEN A LOWER ONE. NO ONE WAS TOUCHING THE SCREENS/CONTROLS. I VIDEOED THIS. EARLIER IN THE MONTH WHILE I WAS DRIVING AND THE TIRE PRESSURE MONITOR SAID TIRES WERE LOW BUT WHEN I MEASURED THEM, THEY WERE NOT SO WHILE THE CAR WAS STATIONARY, I CHECKED AND THE MESSAGE WAS STILL THERE.. BEFORE THAT I HAD ERROR ME...","The passenger side, rear sliding door does not engage in the locking mechanism. It intermittently lacks securely latching. In looking into the door, it appears that the actuator is not engaging properly. Upon researching this for a few minutes, I do see that there were recalls on 2018-2019 Honda Odysseys (as well as other models of Hondas). My car is a 2019 Odyssey. I have called Honda about this issue and they do not acknowledge this recall in regards to my VIN. Service Bulletin 18-128...","VEHICLE STEERING HAS BEEN ACTING UP FOR A FEW DAYS LIKE IT IS CATCHING AND PULLING IN THE DIRECTION I JUST SLIGHTLY VEER TO. 2 DAYS LATER, WHILE DRIVING ON MAIN HIGHWAY, 60MPH, SERVICE POWER STEERING AND SERVICE ADVANCETRAC DISPLAY ON DASH. COMPLETE LOSS OF POWER STEERING ALMOST IMPOSSIBLE TO TURN STEERING WHEEL.","Lifter and cam shaft failure in 2017 Suburban at 87,000 miles. Vehicle was purchased new and well maintained when engine knocking or clicking started. What component or system failed or malfunctioned, and is it available for inspection upon request? Camshaft, lifters, pistons How was your safety or the safety of others put at risk? Had failure not been detected I could have been stranded or lost engine power in situation that created a collision and injury. Has the problem been reproduced ...","The contact owns a 2016 Ford F-150. The contact stated that after stopping and making a right turn, the vehicle lost power steering assist and locked up, while driving 5 MPH. 
The contact stopped, turned off, and then restarted the vehicle however, the power steering assist did not return. The vehicle was driven to the residence in manual steering mode. The vehicle was taken to the local dealer to be diagnosed. The contact was informed that the power steering unit needed to be replaced. The v..."


### Normalization

Note: this preprocessing is not optimized; the goal here is a working baseline.

In [30]:
import nltk
import re
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize

In [31]:
nltk.download('stopwords')
nltk.download('punkt')
nltk.download('wordnet')
nltk.download('averaged_perceptron_tagger')


[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


True

In [32]:
from nltk.corpus import wordnet

def get_wordnet_pos(treebank_tag):
    if treebank_tag.startswith('J'):
        return wordnet.ADJ
    elif treebank_tag.startswith('V'):
        return wordnet.VERB
    elif treebank_tag.startswith('N'):
        return wordnet.NOUN
    elif treebank_tag.startswith('R'):
        return wordnet.ADV
    else:
        return wordnet.NOUN


In [33]:
# build the stop-word set and lemmatizer once instead of on every call
STOP_WORDS = set(stopwords.words('english'))
LEMMATIZER = WordNetLemmatizer()

def preprocess_text(text):
    # lowercase
    text = text.lower()

    # replace everything except letters and whitespace (this also drops digits)
    text = re.sub(r'[^a-zA-Z\s]', ' ', text)

    # tokenize
    tokens = word_tokenize(text)

    # remove stopwords
    tokens = [word for word in tokens if word not in STOP_WORDS]

    # lemmatize using the part-of-speech tag of each token
    tagged = nltk.pos_tag(tokens)
    tokens = [LEMMATIZER.lemmatize(word, get_wordnet_pos(tag)) for word, tag in tagged]

    # rejoin string
    return ' '.join(tokens)
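
One side effect of the normalization step is worth seeing directly: replacing every non-letter character with a space also removes digits, so diagnostic codes and speeds lose their numeric part. The snippet below applies only the lowercase and regex steps to a made-up complaint string (the sample text is an assumption, not taken from the dataset). The whitespace collapse stands in for the tokenize-and-rejoin steps, and the pattern is equivalent to the one in `preprocess_text` once the text is lowercased.

```python
import re

def strip_non_letters(text):
    """Lowercase the text and replace every non-letter character with a space."""
    text = text.lower()
    text = re.sub(r'[^a-z\s]', ' ', text)
    return ' '.join(text.split())  # collapse the repeated spaces left behind

sample = "Check engine light came on at 60 MPH, code P0300."
cleaned = strip_non_letters(sample)
print(cleaned)  # check engine light came on at mph code p
```

This is why fragments such as `code p b` appear in the processed complaints shown later in the notebook.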


In [34]:
# tqdm.auto picks the notebook or console progress bar automatically
from tqdm.auto import tqdm
tqdm.pandas()

# preprocess text
df['complaint'] = df['complaint'].progress_apply(preprocess_text)


In [35]:
df['complaint_length'] = df['complaint'].str.len()
df['complaint_length'].describe()

count    70000.000000
mean       342.898629
std        255.326318
min          0.000000
25%        154.000000
50%        286.000000
75%        455.000000
max       1552.000000
Name: complaint_length, dtype: float64

In [36]:
cutoff_length = 100
df = df[df['complaint_length'] > cutoff_length].reset_index(drop=True)

## Compute Comment Embeddings

A sentence transformer starts with static token embeddings and uses a deep neural network to compute contextual representations of the words. It then pools the contextual word embeddings into a single embedding for the full sentence or comment.
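
The pooling step can be illustrated with a small sketch. Mean pooling, a common choice, averages the contextual token vectors while skipping padding positions via an attention mask. The token vectors below are made-up two-dimensional values (the model used in this notebook produces 384-dimensional vectors), and the function name `mean_pool` is hypothetical.

```python
def mean_pool(token_vectors, attention_mask):
    """Average token vectors where the mask is 1; ignore padding (mask 0)."""
    dim = len(token_vectors[0])
    totals = [0.0] * dim
    count = 0
    for vector, keep in zip(token_vectors, attention_mask):
        if keep:
            count += 1
            for i, value in enumerate(vector):
                totals[i] += value
    if count == 0:
        raise ValueError("attention_mask selects no tokens")
    return [total / count for total in totals]

# three real tokens followed by one padding position (made-up values)
token_vectors = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [9.0, 9.0]]
attention_mask = [1, 1, 1, 0]

sentence_embedding = mean_pool(token_vectors, attention_mask)
print(sentence_embedding)  # [3.0, 4.0]
```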

### Transformer
Pretrained Sentence Transformer Models: https://www.sbert.net/docs/pretrained_models.html

all-MiniLM-L6-v2: https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2

Nils Reimers, Iryna Gurevych. 2019. [Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks](https://arxiv.org/abs/1908.10084)

In [37]:
# on the first run, uncomment the lines below to download the model and save it locally
model_name = 'all-MiniLM-L6-v2'
#model_path = 'sentence-transformers/'
#model = SentenceTransformer(model_path + model_name)
#model.save('./' + model_name)

In [38]:
# use the GPU when available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
transformer_model = SentenceTransformer('./' + model_name, device=device)

In [39]:
complaints = df['complaint'].tolist()  # encode() expects a list of strings
print(len(complaints))

60219


In [40]:
%%time

# compute vector representations of the comments
transformer_embeddings = transformer_model.encode(complaints)

CPU times: user 1min 2s, sys: 796 ms, total: 1min 3s
Wall time: 58.5 s


In [41]:
transformer_embeddings.shape

(60219, 384)

In [42]:
df['transformer_embedding'] = list(transformer_embeddings)

### GloVe
https://nlp.stanford.edu/projects/glove/

Jeffrey Pennington, Richard Socher, and Christopher D. Manning. 2014. [GloVe: Global Vectors for Word Representation](https://nlp.stanford.edu/pubs/glove.pdf).

In [43]:
# on the first run, uncomment the lines below to download the model and save it locally
model_name = 'average_word_embeddings_glove.840B.300d'
#model_path = 'sentence-transformers/'
#model = SentenceTransformer(model_path + model_name)
#model.save('./' + model_name)

In [44]:
glove_model = SentenceTransformer('./' + model_name)

In [45]:
%%time

# compute vector representations of the comments
glove_comment_embeddings = glove_model.encode(complaints)

CPU times: user 6.49 s, sys: 222 ms, total: 6.71 s
Wall time: 6.87 s


In [46]:
df['glove_embedding'] = list(glove_comment_embeddings)

### Results

In [47]:
df[['id','vehicle_component','complaint','transformer_embedding','glove_embedding']].sample(5)

Unnamed: 0,id,vehicle_component,complaint,transformer_embedding,glove_embedding
12418,1749693,ELECTRICAL SYSTEM,think unfair manufacture problem unnoticed left people pay already pay enough car also worry manufacture problem leave unattended company certainly dangerous might car battery die way get immediate help right car acknowledge park mode one day car start use wrong shift unknowingly end car accident wont able anything leave unnoticed recall make even danger present driver car information find https mazda oemdtc com ignition turn press startstop button mazda,"[-0.030233692, 0.032133438, 0.022760015, 0.021018686, 0.05256883, 0.13174456, 0.089021936, 0.055530675, -0.07448465, 0.037720595, 0.12936272, 0.018751448, -0.005752621, 0.0077214474, -0.014856036, -0.04564667, 0.0426663, -0.011937177, -0.0778882, 0.018217236, -0.060160227, -0.018547496, -0.045764185, 0.07866061, -0.081829645, -0.0068552443, 0.0015151186, 0.034842126, 0.07637082, -0.0868866, -0.0038601076, 0.07929665, 0.040965002, -0.07605302, 0.0041219736, -0.067651115, -0.04374159, -0.01342...","[0.00060105877, 0.22678807, -0.13958123, 0.015119894, 0.024455028, -0.08711454, -0.07454371, -0.14263138, -0.04353049, 1.9338084, -0.13678229, 0.041464515, -0.05595109, -0.0056651365, -0.3084877, -0.12516601, -0.21803729, 1.1765188, -0.038645644, 0.05742539, 0.019516852, -0.08444442, 0.09006737, -0.07368893, 0.04290676, -0.07467095, -0.046565898, -0.16555168, 0.09879612, -0.15640289, 0.036210693, -0.11712872, -0.107714295, 0.17498374, 0.064284556, 0.021991171, 0.05353815, 0.07843531, 0.00479..."
50833,1664128,SERVICE BRAKES,happen three time three year suddenly error message pop say electrical park brake problem first two time go away fiddled electrical parking error message go away third time start car brake pad move car stationary happen park lot tr,"[-0.015238783, -0.051335715, 0.07837318, -0.033296652, 0.01381467, 0.0027175515, 0.051877927, 0.034465775, 0.09381481, 0.03440375, 0.107392564, -0.04904155, -0.0009952801, 0.027641423, -0.038501553, 0.006690969, -0.018557493, -0.019476881, -0.060030974, 0.043450166, 0.045401365, -0.014163181, -0.08441091, 0.08502019, -0.09116967, 0.055372763, -0.06364056, 0.026626162, 0.01674105, 0.023404744, -0.032732133, 0.033560693, 0.011472247, -0.013425471, 0.035449237, 0.056625273, -0.082529604, -0.036...","[0.25179362, 0.25886455, -0.06877896, -0.16750361, 0.032142807, -0.028235191, -0.08752924, -0.23182622, -0.11843743, 1.8754237, -4.5582172e-05, 0.08300912, -0.014669971, 0.0051850597, -0.35737783, -0.14584093, -0.009774891, 1.3699056, -0.08814926, 0.11366577, 0.09847581, -0.05505011, 0.069685616, -0.12687835, 0.12357699, -0.03263268, -0.03939826, -0.14043011, 0.0953412, -0.062985994, -0.07978461, -0.022129044, -0.057104617, 0.016791968, 0.11957993, 0.12575503, 0.07800268, 0.101803035, 0.0370..."
6744,1637923,AIR BAGS,one day air bag failiure light come om along fuel cutoff unavailble light light temporarily fix manually lock unlock door recently fix worked error happen cruise control turn,"[-0.033651054, 0.043016322, -0.017984105, 0.12798707, 0.03197828, -0.014826645, 0.058550425, -0.027831689, -0.012732637, 0.04005575, 0.15731415, 0.016853487, -0.036108844, 0.026284905, -0.0190675, 0.059456788, -0.04531917, -0.09567009, -0.057926938, 0.015499791, 0.03423588, -0.025321709, -0.00068101, 0.01597516, -0.082938485, -0.02117651, -0.023821779, 0.018956577, -0.01648202, -0.02052895, 0.032284204, 0.062101085, -0.029241113, 0.005826448, 0.099739715, 0.01028861, -0.022457954, -0.0194876...","[0.078259125, 0.14325035, 0.08262812, -0.014705657, -0.08233604, -0.0027936536, -0.10907388, -0.10794346, -0.05134443, 1.3804443, 0.03177326, 0.077025846, -0.046471033, -0.049769457, -0.176685, -0.17192644, -0.025476553, 1.3686185, -0.062250655, 0.010685643, 0.10074313, -0.035462413, -0.0048686126, -0.069500536, -0.08423969, -0.05831504, -0.11176141, -0.21926877, 0.10325934, -0.05728754, -0.10215102, -0.040776614, -0.051137574, 0.025116425, -0.0038626115, -0.031530745, 0.0008640335, -0.01012..."
34095,1654939,POWER TRAIN,drive ford escape turbo interstate start feel like transmission slip day later check engine light come code p b dealer code p reilly auto part reilly say generic code cam shaft sensor ford say code p b coolant bypass valve c need replace cost noticed ton complaint escape even recall problem ford escape turbo obviously need recall well also ford dealer mike murphy ford morton il say transmission fine though feel like,"[-0.06923738, -0.08915639, 0.06522927, 0.071054265, 0.022880306, 0.06815305, 0.060104504, 0.001432855, -0.01150523, -0.0442396, 0.06568161, -0.04851977, -0.01130552, -0.051788237, 0.017203078, -0.018692557, 0.07753968, -0.044262946, -0.027133225, -0.023756236, -0.09774977, 0.0155427335, -0.06688957, 0.08468998, -0.06959488, 0.059664965, -0.050436925, 0.0834055, 0.013586959, -0.09034566, -0.07832587, 0.03893884, -0.061033264, 0.018631365, 0.065544486, -0.06120817, -0.057029065, 0.019242944, 0...","[0.027058586, 0.28849366, -0.09220874, -0.078520805, 0.13816594, 0.016369203, 0.021603566, -0.18672574, 0.015348776, 1.2862654, -0.08398353, -0.011511967, 0.06992565, -0.20077991, -0.29659903, -0.13963878, -0.17402352, 1.431752, -0.12479244, -0.020641379, 0.09096382, -0.1506973, 0.07126474, -0.14716166, -0.039823968, -0.024986615, -0.14328073, -0.19125618, 0.18559957, 0.06650853, -0.036680404, -0.005337121, -0.014197332, 0.014320342, 0.017113538, -0.045151167, -0.008949817, 0.086723484, -0.1..."
12164,1730610,ELECTRICAL SYSTEM,tl contact own ram contact state vehicle park contact notice oil leak underneath vehicle vehicle take independent mechanic diagnose crack engine filter housing coolant unit vehicle repair contact receive notification nhtsa campaign number v electrical system vehicle speed control v latch lock linkages however part recall repair unavailable dealer central chrysler jeep boston providence turnpike norwood contact confirm part yet available contact state manufacturer exceed reasonable amount tim...,"[-0.13712965, -0.037152294, 0.1050495, 0.0064556003, 0.0145426355, 0.032075662, 0.0034053654, 0.013819418, 0.00999706, -0.07504839, 0.024878073, -0.08774059, -0.039908983, -0.08653034, -0.021644251, -0.015854424, 0.054424774, 0.0070585436, -0.034778275, -0.13067068, 0.023058238, 0.061378364, -0.01867224, 0.07331347, -0.14929198, 0.04418584, 0.043427736, 0.024671527, -0.024701195, -0.035148557, -0.034046054, 0.024392517, -0.047043495, 0.0032199903, 0.05119901, -0.021034362, -0.068591714, 0.00...","[0.14200008, 0.22929566, -0.03855455, 0.014933971, 0.1069296, -0.10290361, -0.13335665, -0.23999237, -0.04076034, 1.528655, -0.117373385, -0.003327625, -0.052089456, -0.1881613, -0.27471337, -0.11826014, -0.10564815, 1.4854552, -0.145039, 0.0039143637, 0.13205452, -0.16630968, 0.020787312, -0.13542758, 0.1051707, -0.003425244, -0.10748004, -0.14007246, 0.13954604, -0.015087329, 0.009022529, -0.05506965, -0.0333714, 0.16917354, -0.046575915, 0.01118408, -0.058261096, 0.059432324, -0.003268461..."


In [48]:
# save embeddings
df.to_parquet('df_embeddings.parquet')

## Classification

In [49]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

def assess_performance(y_pred, y_val):
    """Print accuracy and support-weighted precision, recall, and F1 for the predictions."""
    y_true = y_val.tolist()
    print("Accuracy:", round(accuracy_score(y_true, y_pred), 3))
    print("Precision:", round(precision_score(y_true, y_pred, average='weighted'), 3))
    print("Recall:", round(recall_score(y_true, y_pred, average='weighted'), 3))
    print("F1 score:", round(f1_score(y_true, y_pred, average='weighted'), 3))
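
With `average='weighted'`, each per-class score is weighted by that class's share of the true labels. One consequence is that weighted recall always equals accuracy, which is why those two numbers match in the results later in this section. The sketch below recomputes both by hand on a made-up set of six labels.

```python
from collections import Counter

def weighted_recall(y_true, y_pred):
    """Per-class recall, weighted by each class's share of the true labels."""
    support = Counter(y_true)
    true_positives = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    n = len(y_true)
    return sum(
        (support[c] / n) * (true_positives[c] / support[c]) for c in support
    )

def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

# made-up labels for illustration
y_true = ["ENGINE", "ENGINE", "ENGINE", "STEERING", "STEERING", "AIR BAGS"]
y_pred = ["ENGINE", "ENGINE", "STEERING", "STEERING", "ENGINE", "AIR BAGS"]

print(round(weighted_recall(y_true, y_pred), 3), round(accuracy(y_true, y_pred), 3))
```

The support weights cancel the per-class denominators, leaving the total number of correct predictions divided by the number of samples.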
    

In [50]:
from sklearn.linear_model import LogisticRegression

from sklearn.model_selection import train_test_split

X = df[['transformer_embedding','glove_embedding']]
y = df['vehicle_component']

# Split the data into train (60%) and test (40%) sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=42)

# Print the number of rows in each set
print("Train set size: ", len(X_train))
print("Test set size: ", len(X_test))


Train set size:  36131
Test set size:  24088
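
The split sizes printed above follow from the row count and `test_size=0.4`. The sketch below reproduces the arithmetic, assuming that scikit-learn rounds a fractional test size up and gives the remainder to the training set.

```python
import math

n_samples = 60219   # rows remaining after the length filter
test_size = 0.4

n_test = math.ceil(test_size * n_samples)   # fractional test size rounded up
n_train = n_samples - n_test                # remainder goes to training

print("Train set size: ", n_train)
print("Test set size: ", n_test)
```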


### Transformer

In [51]:
X_train_t, X_test_t = X_train['transformer_embedding'], X_test['transformer_embedding']

In [52]:
# train
transformer_log_model = LogisticRegression(random_state=42, max_iter = 1000)
transformer_log_model.fit(X_train_t.tolist(), y_train.tolist())

# assess
y_pred_tl = transformer_log_model.predict(X_test_t.tolist())
assess_performance(y_pred_tl,y_test)

Accuracy: 0.694
Precision: 0.695
Recall: 0.694
F1 score: 0.694


### GloVe

In [53]:
X_train_g, X_test_g = X_train['glove_embedding'], X_test['glove_embedding']

In [54]:
# train
glove_log_model = LogisticRegression(random_state=42, max_iter = 1000)
glove_log_model.fit(X_train_g.tolist(), y_train.tolist())

# assess
y_pred_gl = glove_log_model.predict(X_test_g.tolist())
assess_performance(y_pred_gl,y_test)

Accuracy: 0.683
Precision: 0.683
Recall: 0.683
F1 score: 0.682
