# In [17]:
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import MultinomialNB
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import joblib
import pandas as pd
import os

def train_and_save_model(data_dir: str = 'data',
                         model_path: str = 'models/best_model.pkl',
                         selection_metric: str = 'accuracy',
                         random_state: int = 42) -> dict:
    """Train several classifiers, print their test metrics and save the best.

    Reads ``X_train.csv``, ``X_test.csv``, ``y_train.csv`` and ``y_test.csv``
    from *data_dir*, fits a Logistic Regression, Multinomial Naive Bayes,
    SVM and Random Forest (library-default hyperparameters), scores each on
    the test split, prints the collected scores and dumps the model with the
    highest *selection_metric* to *model_path* using ``joblib``.

    Args:
        data_dir: Directory containing the four preprocessed CSV files.
        model_path: Destination file for the serialized best model. Its
            parent directory is created when missing.
        selection_metric: Metric used to pick the winner; one of
            ``'accuracy'``, ``'precision'``, ``'recall'`` or ``'f1'``.
        random_state: Seed handed to the Random Forest so repeated runs on
            the same data produce the same scores and the same winner.

    Returns:
        Mapping of model name to a dict with the keys ``'accuracy'``,
        ``'precision'``, ``'recall'`` and ``'f1'``.

    Raises:
        ValueError: If *selection_metric* is not one of the supported names.

    NOTE(review): the winner is chosen on the same test split that is
    reported, so its score is an optimistic estimate; consider a separate
    validation split or cross-validation for model selection.
    """
    # Insertion order is kept so the printed result layout stays stable.
    metric_functions = {
        'accuracy': accuracy_score,
        'precision': precision_score,
        'recall': recall_score,
        'f1': f1_score,
    }
    if selection_metric not in metric_functions:
        raise ValueError(
            f"selection_metric must be one of {sorted(metric_functions)}, "
            f"got {selection_metric!r}"
        )

    # Load preprocessed data
    X_train = pd.read_csv(os.path.join(data_dir, 'X_train.csv'))
    X_test = pd.read_csv(os.path.join(data_dir, 'X_test.csv'))
    # Labels are stored as single-column CSVs; flatten both to 1-D arrays so
    # the estimators and the metric functions receive the same shape.
    y_train = pd.read_csv(os.path.join(data_dir, 'y_train.csv')).values.ravel()
    y_test = pd.read_csv(os.path.join(data_dir, 'y_test.csv')).values.ravel()

    # Initialize models
    models = {
        'Logistic Regression': LogisticRegression(),
        'Naive Bayes': MultinomialNB(),
        'SVM': SVC(),
        # Bootstrap sampling and feature sub-sampling are random; without a
        # seed the selected "best" model can differ from run to run.
        'Random Forest': RandomForestClassifier(random_state=random_state),
    }

    results = {}

    # Train and evaluate each model
    for name, model in models.items():
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        results[name] = {
            metric: score(y_test, y_pred)
            for metric, score in metric_functions.items()
        }

    # Print results
    print(results)

    # Pick the best model; ties resolve to the first model in `models`.
    best_model_name = max(results, key=lambda k: results[k][selection_metric])
    best_model = models[best_model_name]

    # Ensure the output folder exists (dirname is '' for a bare file name,
    # which os.makedirs would reject).
    model_dir = os.path.dirname(model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    joblib.dump(best_model, model_path)

    return results

# Script entry point: run the full train/evaluate/save pipeline only when
# this file is executed directly, not when it is imported.
if __name__ == '__main__':
    train_and_save_model()


{'Logistic Regression': {'accuracy': 0.97847533632287, 'precision': 0.8881987577639752, 'recall': 0.959731543624161, 'f1': 0.9225806451612903}, 'Naive Bayes': {'accuracy': 0.9668161434977578, 'precision': 0.8255813953488372, 'recall': 0.9530201342281879, 'f1': 0.8847352024922118}, 'SVM': {'accuracy': 0.957847533632287, 'precision': 0.7965116279069767, 'recall': 0.9194630872483222, 'f1': 0.8535825545171339}, 'Random Forest': {'accuracy': 0.9901345291479821, 'precision': 0.9928571428571429, 'recall': 0.9328859060402684, 'f1': 0.9619377162629759}}
