In [2]:
import nltk
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import string

nltk.download('stopwords')

[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/pepijnschouten/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

In [3]:
categories = ['alt.atheism', 'comp.graphics',
'comp.os.ms-windows.misc', 'comp.sys.ibm.pc.hardware',
'comp.sys.mac.hardware', 'comp.windows.x', 'misc.forsale',
'rec.autos','rec.motorcycles', 'rec.sport.baseball', 'rec.sport.hockey',
'sci.crypt', 'sci.electronics', 'sci.med', 'sci.space',
'soc.religion.christian', 'talk.politics.guns',
'talk.politics.mideast', 'talk.politics.misc', 'talk.religion.misc']

newsgroup_data = fetch_20newsgroups(
    subset='all', 
    remove=('headers', 'footers', 'quotes'),
    shuffle=True,
    random_state=87)

X_train, X_test, y_train, y_test = train_test_split(
    newsgroup_data.data,
    newsgroup_data.target,
    test_size=0.2,
    random_state=87)

In [None]:
def tokenize(text):
    tokens = nltk.word_tokenize(text)
    stop_words = set(nltk.corpus.stopwords.words('english'))
    punctuations = set(string.punctuation)
    filtered_tokens = [token for token in tokens
                       if token.lower() not in stop_words
                       and token not in punctuations]
    return filtered_tokens

vectorizer = TfidfVectorizer(
    stop_words='english',
    tokenizer=tokenize,)

X_train_vectors = vectorizer.fit_transform(X_train)
X_test_vectors = vectorizer.transform(X_test)



In [8]:
clf = MultinomialNB()
clf.fit(X_train_vectors, y_train)
y_pred = clf.predict(X_test_vectors)

print("-" * 35)
for i in range(5):
    actual_label = newsgroup_data.target_names[y_test[i]]
    predicted_label = newsgroup_data.target_names[y_pred[i]]
    print("Actual:", actual_label)
    print("Predicted:", predicted_label)
    print("-" * 35)
    
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)
print("-" * 35)

-----------------------------------
Actual: comp.sys.mac.hardware
Predicted: comp.sys.mac.hardware
-----------------------------------
Actual: rec.motorcycles
Predicted: rec.motorcycles
-----------------------------------
Actual: soc.religion.christian
Predicted: sci.med
-----------------------------------
Actual: alt.atheism
Predicted: soc.religion.christian
-----------------------------------
Actual: comp.sys.mac.hardware
Predicted: comp.sys.mac.hardware
-----------------------------------
Accuracy: 0.7124668435013263
-----------------------------------
