### 20 topics 중에 8개 topic을 뽑은 후, 8개 topic으로 재구성했을때 각 topic을 표현하고 있는지 체크하는 예제

In [4]:
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.decomposition import LatentDirichletAllocation

cats = ['rec.motorcycles', 'rec.sport.baseball', 'comp.graphics', 'comp.windows.x',
        'talk.politics.mideast', 'soc.religion.christian', 'sci.electronics', 'sci.med']

# categories: 8개만 지정
news_df = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'), categories=cats, random_state=0)

# LDA applies count based vectorizer
count_vect = CountVectorizer(max_df=0.95, max_features=1000, min_df=2, stop_words='english', ngram_range=(1,2))
feat_vect = count_vect.fit_transform(news_df.data)
print('CountVectorizer Shape:', feat_vect.shape)

CountVectorizer Shape: (7862, 1000)


In [5]:
lda = LatentDirichletAllocation(n_components=8, random_state=0)
lda.fit(feat_vect)

LatentDirichletAllocation(n_components=8, random_state=0)

In [6]:
print(lda.components_.shape)
lda.components_

(8, 1000)


array([[3.60992018e+01, 1.35626798e+02, 2.15751867e+01, ...,
        3.02911688e+01, 8.66830093e+01, 6.79285199e+01],
       [1.25199920e-01, 1.44401815e+01, 1.25045596e-01, ...,
        1.81506995e+02, 1.25097844e-01, 9.39593286e+01],
       [3.34762663e+02, 1.25176265e-01, 1.46743299e+02, ...,
        1.25105772e-01, 3.63689741e+01, 1.25025218e-01],
       ...,
       [3.60204965e+01, 2.08640688e+01, 4.29606813e+00, ...,
        1.45056650e+01, 8.33854413e+00, 1.55690009e+01],
       [1.25128711e-01, 1.25247756e-01, 1.25005143e-01, ...,
        9.17278769e+01, 1.25177668e-01, 3.74575887e+01],
       [5.49258690e+01, 4.47009532e+00, 9.88524814e+00, ...,
        4.87048440e+01, 1.25034678e-01, 1.25074632e-01]])

In [23]:
# i.e. check to see if topic 0 is related to topics listed below
# topic 0: year+game+medical+health+team+desease+cancer+patients -> sci.med 로 추정됨 
def display_topic_words(model, feature_names, no_top_words):
    for topic_index, topic in enumerate(model.components_):
        print('\nTopic #', topic_index)
        print(topic[0])
        
        # argsort 로 일단 index 순서 확보하고, 역순으로(큰 순서대로)
        topic_word_indexes = topic.argsort()[::-1]
        top_indexes = topic_word_indexes[:no_top_words]
        
        feature_concat = '+'.join([str(feature_names[i]) + '*' + str(round(topic[i],1)) for i in top_indexes])
        print(feature_concat)


feature_names = count_vect.get_feature_names()

display_topic_words(lda, feature_names, 15)

# cats = ['rec.motorcycles', 'rec.sport.baseball', 'comp.graphics', 'comp.windows.x',
#         'talk.politics.mideast', 'soc.religion.christian', 'sci.electronics', 'sci.med']


Topic # 0
36.09920175842141
year*703.2+10*563.6+game*476.3+medical*413.2+health*377.4+team*346.8+12*343.9+20*340.9+disease*332.1+cancer*319.9+1993*318.3+games*317.0+years*306.5+patients*299.8+good*286.3

Topic # 1
0.1251999196700124
don*1454.3+just*1392.8+like*1190.8+know*1178.1+people*836.9+said*802.5+think*799.7+time*754.2+ve*676.3+didn*675.9+right*636.3+going*625.4+say*620.7+ll*583.9+way*570.3

Topic # 2
334.76266340911604
image*1047.7+file*999.1+jpeg*799.1+program*495.6+gif*466.0+images*443.7+output*442.3+format*442.3+files*438.5+color*406.3+entry*387.6+00*334.8+use*308.5+bit*308.4+03*258.7

Topic # 3
0.12510914916233595
like*620.7+know*591.7+don*543.7+think*528.4+use*514.3+does*510.2+just*509.1+good*425.8+time*417.4+book*410.7+read*402.9+information*395.2+people*393.5+used*388.2+post*368.4

Topic # 4
12.8163316007951
armenian*960.6+israel*815.9+armenians*699.7+jews*690.9+turkish*686.1+people*653.0+israeli*476.1+jewish*467.0+government*464.4+war*417.8+dos dos*401.1+turkey*393.5+ar

In [28]:
# transform makes 1000 feature -> 8 feature (SVD similar)
doc_topics = lda.transform(feat_vect)
print(doc_topics.shape)
print(doc_topics[:3])

(7862, 8)
[[0.01389701 0.01394362 0.01389104 0.48221844 0.01397882 0.01389205
  0.01393501 0.43424401]
 [0.27750436 0.18151826 0.0021208  0.53037189 0.00212129 0.00212102
  0.00212113 0.00212125]
 [0.00544459 0.22166575 0.00544539 0.00544528 0.00544039 0.00544168
  0.00544182 0.74567512]]


In [43]:
# this is actually only for this dataset so not geneally acceptable
def get_filename_list(newsdata):
    filename_list = []
    
    for file in newsdata.filenames:
        filename_temp = file.split('/')[-2:]
        # combine two values by adding .
        filename = '.'.join(filename_temp)
        filename_list.append(filename)
        
    return filename_list

filename_list = get_filename_list(news_df)
print(filename_list[:10])


['soc.religion.christian.20630', 'sci.med.59422', 'comp.graphics.38765', 'comp.graphics.38810', 'sci.med.59449', 'comp.graphics.38461', 'comp.windows.x.66959', 'rec.motorcycles.104487', 'sci.electronics.53875', 'sci.electronics.53617']


In [46]:
# Visualization w dataframe 
import pandas as pd

topic_names = ['Topic #'+ str(i) for i in range(8)]
doc_topic_df = pd.DataFrame(data=doc_topics, columns=topic_names, index=filename_list)
doc_topic_df.head()

Unnamed: 0,Topic #0,Topic #1,Topic #2,Topic #3,Topic #4,Topic #5,Topic #6,Topic #7
soc.religion.christian.20630,0.013897,0.013944,0.013891,0.482218,0.013979,0.013892,0.013935,0.434244
sci.med.59422,0.277504,0.181518,0.002121,0.530372,0.002121,0.002121,0.002121,0.002121
comp.graphics.38765,0.005445,0.221666,0.005445,0.005445,0.00544,0.005442,0.005442,0.745675
comp.graphics.38810,0.005439,0.005441,0.005449,0.578959,0.00544,0.388387,0.005442,0.005442
sci.med.59449,0.006584,0.552,0.006587,0.408485,0.006585,0.006585,0.006588,0.006585
