In [2]:
import numpy as np
import nltk
from sklearn.datasets import fetch_20newsgroups
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score
from gensim.utils import simple_preprocess
from gensim.models import TfidfModel, LdaModel
from gensim import corpora
nltk.download('stopwords')
stop_words = nltk.corpus.stopwords.words('english')

categories = ['rec.autos', 'comp.graphics', 'sci.space']
newsgroup = fetch_20newsgroups(subset='all', categories=categories, shuffle=True, remove=('headers', 'footers', 'quotes'))
print(newsgroup.filenames.shape)

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\zhest\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


(2950,)


In [3]:
tokenized_documents = [simple_preprocess(text) for text in newsgroup.data]
dictionary = corpora.Dictionary(tokenized_documents)
bow_corpus = [dictionary.doc2bow(doc) for doc in tokenized_documents]
modelTf = TfidfModel(bow_corpus)
tf_corpus = [modelTf[corpus_item] for corpus_item in bow_corpus]

In [4]:
def train(X_train, X_test, y_train, y_test):
    clf = RandomForestClassifier()
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    return f1_score(y_test, y_pred, average='macro')

In [5]:
for n in [200,210,230,250,270,280,300,350,400]:
    lda = LdaModel(tf_corpus, num_topics=n, id2word=dictionary,passes=15, minimum_probability = 0)
    vectorized_corpus = lda[tf_corpus]
    vectorized_corpus_new = [0 for i in range(2950)]
    for i in range(len(vectorized_corpus)):
        curr = []
        a = vectorized_corpus[i]
        length = len(a)
        curr = [a[j][1] for j in range(n)]
        vectorized_corpus_new[i] = curr
    X_train, X_test, y_train, y_test = train_test_split(vectorized_corpus_new, newsgroup.target, test_size=0.33)
    print(f"N:{n},result is {train(X_train, X_test, y_train, y_test)}")

N:200,result is 0.6198551859430407
N:210,result is 0.5700157261701455
N:230,result is 0.5568983127820979
N:250,result is 0.5969885420884736
N:270,result is 0.4938585303644336
N:280,result is 0.44891257356875336
N:300,result is 0.43422628442844186
N:350,result is 0.38608868534419366
N:400,result is 0.16408668730650156
