In [48]:
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.feature_extraction.text import CountVectorizer


In [None]:
categories = ['alt.atheism', 'soc.religion.christian',
...               'comp.graphics', 'sci.med']
from sklearn.datasets import fetch_20newsgroups
twenty_train = fetch_20newsgroups(subset='train',
    categories=categories, shuffle=True, random_state=42)
twenty_test = fetch_20newsgroups(subset='test',
    categories=categories, shuffle=True, random_state=42)
print(len(twenty_train.data))
print(len(twenty_test.data))
print(twenty_train.target_names)

2257
1502
['alt.atheism', 'comp.graphics', 'sci.med', 'soc.religion.christian']


In [62]:
from sklearn.model_selection import GridSearchCV
text_clf = Pipeline([
    ('vect', CountVectorizer()),
    ('tfidf', TfidfTransformer()),
    ('clf', MultinomialNB()),
])
parameters = [{  'vect__stop_words': ('english', None)},{'clf__alpha': (0.1, 1.0)}]
grridsearch = GridSearchCV(text_clf, parameters, n_jobs=-1)
grridsearch.fit(twenty_train.data, twenty_train.target)


In [63]:
best_parameters = grridsearch.best_params_
print(f"Best parameters found: {best_parameters}")

Best parameters found: {'clf__alpha': 0.1}


In [64]:
from sklearn.metrics import classification_report

In [65]:
from sklearn.metrics import accuracy_score, classification_report

# Predict on the test set
y_pred = grridsearch.predict(twenty_test.data)

# Calculate accuracy
accuracy = accuracy_score(twenty_test.target, y_pred)
print("Accuracy on test set:", accuracy)

# Print a detailed classification report
print(classification_report(twenty_test.target, y_pred, target_names=twenty_test.target_names))

Accuracy on test set: 0.9254327563249002
                        precision    recall  f1-score   support

           alt.atheism       0.97      0.84      0.90       319
         comp.graphics       0.95      0.96      0.95       389
               sci.med       0.97      0.91      0.94       396
soc.religion.christian       0.85      0.98      0.91       398

              accuracy                           0.93      1502
             macro avg       0.93      0.92      0.92      1502
          weighted avg       0.93      0.93      0.93      1502

