# 03_train_tfidf

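Train two TF-IDF baselines — Logistic Regression and Multinomial Naive Bayes — on the training data in `data/imdb`, report their metrics on a held-out validation split, and save each model together with its fitted vectorizer under `artifacts/tfidf`.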

In [1]:
import os, joblib, pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import MultinomialNB
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split

# --- Configuration: everything a reader might want to tune -------------------
DATA_DIR = 'data/imdb'
art = 'artifacts/tfidf'   # output directory for the saved model bundles
SEED = 42                 # seed for the train/validation split
VAL_FRACTION = 0.1        # share of train.csv held out for validation
MAX_FEATURES = 20000      # cap on the TF-IDF vocabulary size
NGRAM_RANGE = (1, 2)      # n-gram range passed to the vectorizer
LR_MAX_ITER = 1000        # iteration limit for LogisticRegression

os.makedirs(art, exist_ok=True)

# --- Load data ----------------------------------------------------------------
train_df = pd.read_csv(os.path.join(DATA_DIR, 'train.csv'))
test_df = pd.read_csv(os.path.join(DATA_DIR, 'test.csv'))

# Stratify on the label so both splits keep the same class balance.
X_train, X_val, y_train, y_val = train_test_split(
    train_df['text'],
    train_df['label'],
    test_size=VAL_FRACTION,
    random_state=SEED,
    stratify=train_df['label'],
)

# --- Vectorize ----------------------------------------------------------------
# The vectorizer is fitted on the training split only; the validation and test
# texts are merely transformed, so their vocabulary never influences the fit.
vect = TfidfVectorizer(max_features=MAX_FEATURES, ngram_range=NGRAM_RANGE)
X_tr = vect.fit_transform(X_train)
X_val_t = vect.transform(X_val)
# NOTE(review): X_test_t is not consumed anywhere in this cell; it is kept in
# case a later cell evaluates on the test set — confirm or drop.
X_test_t = vect.transform(test_df['text'])


def fit_report_save(model, label, filename):
    """Fit `model`, print its validation report, and save it with the vectorizer.

    Reads the cell-level variables `X_tr`, `y_train`, `X_val_t`, `y_val`,
    `vect` and `art` defined above.

    Parameters
    ----------
    model : estimator exposing `fit` and `predict`
        Fitted in place on the training features.
    label : str
        Short model name used as the prefix of the printed report header.
    filename : str
        File name of the joblib bundle written inside `art`.

    Returns
    -------
    The value returned by `joblib.dump` for the written bundle.
    """
    model.fit(X_tr, y_train)
    print(f'{label} val report:\n', classification_report(y_val, model.predict(X_val_t)))
    # Bundle the vectorizer with the model so the saved artifact can turn raw
    # text into features without any other file.
    return joblib.dump({'model': model, 'vectorizer': vect}, os.path.join(art, filename))


# --- Train, evaluate, persist -------------------------------------------------
lr = LogisticRegression(max_iter=LR_MAX_ITER)
fit_report_save(lr, 'LR', 'lr_tfidf.joblib')
nb = MultinomialNB()
# Last expression of the cell: its return value is what the notebook displays.
fit_report_save(nb, 'NB', 'nb_tfidf.joblib')


LR val report:
               precision    recall  f1-score   support

           0       0.90      0.89      0.90      1250
           1       0.89      0.90      0.90      1250

    accuracy                           0.90      2500
   macro avg       0.90      0.90      0.90      2500
weighted avg       0.90      0.90      0.90      2500

NB val report:
               precision    recall  f1-score   support

           0       0.87      0.88      0.87      1250
           1       0.88      0.87      0.87      1250

    accuracy                           0.87      2500
   macro avg       0.87      0.87      0.87      2500
weighted avg       0.87      0.87      0.87      2500



['artifacts/tfidf/nb_tfidf.joblib']