In [16]:
import pandas as pd
import numpy as np
from scipy.linalg import norm, eigh
from sklearn.decomposition import PCA
import plotly_express as px

In [2]:
# Index levels of CORPUS, from the coarsest unit (source) down to the single token.
OHCO = 'doc_source doc_id sent_num token_num'.split()

In [3]:
# Load the document library (with parsed dates), the token table and the vocabulary.
LIB = (
    pd.read_csv('LIB.csv')
    .assign(doc_date=lambda frame: pd.to_datetime(frame['doc_date']))
    .set_index('doc_id')
)
CORPUS = pd.read_csv('CORPUS.csv').set_index(OHCO)
VOCAB = pd.read_csv('VOCAB.csv').set_index('term_str')

In [4]:
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.decomposition import LatentDirichletAllocation as LDA
import plotly_express as px

class TopicExplorer:
    """Fit an LDA topic model to a token table and build summary tables from it.

    Usage::

        M = TopicExplorer(CORPUS, LIB, ['doc_id', 'doc_source'], ['year']).generate_tables()

    ``generate_tables()`` populates:

    * ``DOCS``   -- one row per bag, its kept tokens joined into ``doc_str``
    * ``TERMS``  -- the CountVectorizer vocabulary
    * ``THETA``  -- bag x topic weights (index = the ``bag`` levels)
    * ``PHI``    -- topic x term weights
    * ``TOPICS`` -- top terms per topic, a readable ``label`` and ``doc_weight_sum``
    * ``LABELS`` / ``LABEL_VALUES`` -- mean topic weights per value of each LIB label column
    """

    # Model settings; override on the class or instance before calling generate_tables().
    n_features = 4000          # CountVectorizer max_features (vocabulary size cap)
    stopwords = 'english'      # CountVectorizer stop_words
    lda_num_topics = 20
    lda_max_iter = 5
    lda_n_top_terms = 7        # top terms kept per topic in TOPICS and in topic labels

    def __init__(self, tokens_df, lib_df, bag, labels=None):
        """
        Args:
            tokens_df: token table with ``pos`` and ``term_str`` columns; the names in
                ``bag`` must be columns or index levels of it.
            lib_df: document metadata, looked up by the values of the first ``bag``
                level (e.g. indexed by ``doc_id``).
            bag: list of column/level names whose combinations define one document.
            labels: LIB column names used to aggregate THETA (see ``LABELS``).
        """
        self.TOKENS = tokens_df
        self.LIB = lib_df
        self.bag = bag
        # None instead of a shared mutable [] default.
        self.labels = list(labels) if labels is not None else []

    def generate_tables(self):
        """Run the whole pipeline (docs -> counts -> LDA -> topics -> labels); return self."""
        print("BAG:", self.bag[-1])
        print("LABELS:", self.labels)
        print("Getting DOCS")
        self._get_docs()
        print("Getting TERMS")
        self._get_count_model()
        print("Getting THETA, PHI")
        self._get_topic_model()
        print("Getting TOPICS")
        self._get_topics()
        print('Binding LIB labels to THETA')
        self._bind_labels()
        print("Done.")
        return self

    def _get_docs(self, pos_remove_pat=r'^NNS?$'):
        """Join the tokens of each bag into one document string (``DOCS.doc_str``).

        Despite its name, ``pos_remove_pat`` selects the tokens to KEEP: only tokens
        whose ``pos`` tag matches it contribute (default: tags ``NN``/``NNS``,
        presumably Penn Treebank nouns). Tokens with a missing ``pos`` or
        ``term_str`` are skipped.
        """
        tokens = self.TOKENS
        # na=False keeps the mask boolean when a POS tag is missing; missing terms
        # would otherwise make ' '.join fail on a float NaN.
        keep = tokens.pos.str.match(pos_remove_pat, na=False) & tokens.term_str.notna()
        self.DOCS = (tokens[keep]
                     .groupby(self.bag).term_str
                     .apply(lambda x: ' '.join(x))
                     .to_frame('doc_str'))

    def _get_count_model(self):
        """Vectorise ``DOCS.doc_str`` into term counts; store the vocabulary in ``TERMS``."""
        self.count_engine = CountVectorizer(max_features=self.n_features,
                                            stop_words=self.stopwords)
        self.count_model = self.count_engine.fit_transform(self.DOCS.doc_str)
        self.TERMS = self.count_engine.get_feature_names_out()

    def _get_topic_model(self):
        """Fit LDA on the count matrix; build THETA (bag x topic) and PHI (topic x term)."""
        self.lda_engine = LDA(n_components=self.lda_num_topics,
                              max_iter=self.lda_max_iter,
                              learning_offset=50.,
                              random_state=0)   # fixed seed for reproducible topics
        self.THETA = pd.DataFrame(self.lda_engine.fit_transform(self.count_model),
                                  index=self.DOCS.index)
        self.THETA.columns.name = 'topic_id'
        self.PHI = pd.DataFrame(self.lda_engine.components_, columns=self.TERMS)
        self.PHI.index.name = 'topic_id'
        self.PHI.columns.name = 'term_str'

    def _get_topics(self, n_terms=None):
        """Build TOPICS from PHI and THETA.

        Columns ``0..n_terms-1`` hold each topic's highest-weighted terms, ``label`` is
        the zero-padded topic id followed by those terms, and ``doc_weight_sum`` is the
        topic's total weight over all rows of THETA. Also sets ``topic_cols``.

        Args:
            n_terms: number of top terms per topic; defaults to ``lda_n_top_terms``.
        """
        if n_terms is None:
            n_terms = self.lda_n_top_terms
        self.topic_cols = self.PHI.index.to_list()
        top_terms = {
            topic_id: weights.sort_values(ascending=False).head(n_terms).index.to_list()
            for topic_id, weights in self.PHI.iterrows()
        }
        self.TOPICS = pd.DataFrame.from_dict(top_terms, orient='index')
        self.TOPICS.index.name = 'topic_id'
        self.TOPICS.columns.name = 'term_str'
        width = len(str(self.lda_num_topics))
        self.TOPICS['label'] = [str(topic_id).zfill(width) + ' ' + ' '.join(terms)
                                for topic_id, terms in top_terms.items()]
        self.TOPICS['doc_weight_sum'] = self.THETA[self.topic_cols].sum()

    def _bind_labels(self):
        """Average THETA over the values of each LIB label column.

        For each name in ``labels``:

        * ``LABELS[label]`` -- topic x label-value frame of mean topic weights, plus a
          ``label`` column holding the topic label from TOPICS.
        * ``LABEL_VALUES[label]`` -- the (sorted) label values that occur among the
          modelled documents, i.e. exactly the value columns of ``LABELS[label]``.

        A document key missing from LIB raises KeyError (from ``.loc``).
        """
        self.LABELS = {}
        self.LABEL_VALUES = {}
        # The first bag level (e.g. doc_id) is the key into LIB; works for a
        # single-level index too.
        doc_keys = self.THETA.index.get_level_values(0)
        for label in self.labels:
            # One vectorised lookup instead of a Python call per THETA row.
            doc_label = pd.Series(self.LIB.loc[doc_keys, label].to_numpy(),
                                  index=self.THETA.index, name=label)
            means = self.THETA[self.topic_cols].groupby(doc_label).mean().T
            means.index.name = 'topic_id'
            # groupby sorts its keys, so the columns are already in sorted order.
            self.LABEL_VALUES[label] = means.columns.to_list()
            means['label'] = self.TOPICS['label']
            self.LABELS[label] = means

    def show_dominant_label_topic(self, label):
        """Return, for each value of ``label``, the label of its highest-weighted topic."""
        values = self.LABEL_VALUES[label]
        return (self.LABELS[label][values]
                .idxmax()
                .rename('topic_id')
                .map(self.TOPICS['label']))

    def show_label_values(self):
        """Print the distinct values of every bound label."""
        for label in self.LABEL_VALUES:
            print(label, ": ", self.LABEL_VALUES[label])

    def show_topic_bar(self):
        """Horizontal bar chart of topic labels ordered by ``doc_weight_sum``."""
        fig_height = self.lda_num_topics / 3
        self.TOPICS.sort_values('doc_weight_sum', ascending=True)\
            .plot.barh(y='doc_weight_sum', x='label', figsize=(5, fig_height))

    def show_topic_label_heatmap(self, label):
        """Styled heatmap of mean topic weight (rows) per value of ``label`` (columns)."""
        # Use this instance's tables (previously referenced an undefined global `MP`).
        return self.LABELS[label][self.LABEL_VALUES[label]].style.background_gradient()

    def show_label_comparison_plot(self, label, label_value_x, label_value_y):
        """Scatter topics by their mean weight under two values of ``label``."""
        px.scatter(self.LABELS[label].reset_index(), label_value_x, label_value_y,
                   hover_name='label', text='topic_id', width=800, height=600)\
            .update_traces(mode='text').show()

In [5]:
# Documents are (doc_id, doc_source) bags; doc_id comes first because it is the key into LIB.
# Topic weights are then averaged per value of each LIB column listed in LABELS.
bag, LABELS = ['doc_id', 'doc_source'], ['year']

In [6]:
#CORPUS.set_index(['doc_id', 'doc_source', 'sent_num', 'token_num'])

In [7]:
M = TopicExplorer(tokens_df=CORPUS, lib_df=LIB, bag=bag, labels=LABELS).generate_tables()

BAG: doc_source
LABELS: ['year']
Getting DOCS
Getting TERMS
Getting THETA, PHI
Getting TOPICS
Binding LIB labels to THETA
Done.


In [13]:
TOPICS = M.TOPICS  # top terms, label and doc_weight_sum per topic
TOPICS

term_str,0,1,2,3,4,5,6,label,doc_weight_sum
topic_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
0,deal,game,trial,points,coverage,law,stock,00 deal game trial points coverage law stock,2021.935317
1,man,police,authorities,death,woman,prison,officer,01 man police authorities death woman prison o...,4718.925643
2,government,border,state,budget,wall,emergency,night,02 government border state budget wall emergen...,2212.921888
3,news,media,investigation,articles,security,people,press,03 news media investigation articles security ...,2098.992193
4,story,link,column,advertise,second,headline,support,04 story link column advertise second headline...,2776.797631
5,years,program,climate,change,day,companies,immigrants,05 years program climate change day companies ...,2099.95219
6,campaign,president,tax,plan,candidate,reading,race,06 campaign president tax plan candidate readi...,2548.417608
7,school,students,student,college,earnings,schools,people,07 school students student college earnings sc...,2250.310906
8,case,sex,abuse,assault,years,allegations,traffic,08 case sex abuse assault years allegations tr...,2256.564663
9,officials,authorities,woman,residents,people,cases,girl,09 officials authorities woman residents peopl...,2275.60988


In [14]:
THETA = M.THETA  # document x topic weights
THETA

Unnamed: 0_level_0,topic_id,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
doc_id,doc_source,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
23,Google News,0.116042,0.003846,0.380782,0.215584,0.003846,0.003846,0.003846,0.003846,0.003846,0.003846,0.003846,0.003846,0.003846,0.003846,0.003846,0.003846,0.003846,0.003846,0.003846,0.226053
42,Google News,0.005000,0.005000,0.219991,0.579468,0.005000,0.005000,0.005000,0.005000,0.005000,0.005000,0.005000,0.005000,0.115541,0.005000,0.005000,0.005000,0.005000,0.005000,0.005000,0.005000
52,PowerLine,0.103175,0.045346,0.000633,0.000633,0.000633,0.000633,0.000633,0.000633,0.000633,0.000633,0.462802,0.000633,0.000633,0.000633,0.335747,0.043436,0.000633,0.000633,0.000633,0.000633
81,Guardian,0.000321,0.000321,0.000321,0.000321,0.000321,0.680926,0.000321,0.000321,0.000321,0.000321,0.000321,0.078916,0.000321,0.022891,0.198078,0.000321,0.000321,0.014381,0.000321,0.000321
87,Guardian,0.000385,0.000385,0.000385,0.000385,0.000385,0.000385,0.000385,0.000385,0.000385,0.000385,0.000385,0.000385,0.000385,0.000385,0.869871,0.108991,0.000385,0.000385,0.000385,0.014599
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1026234,US News,0.597850,0.008333,0.008333,0.008333,0.008333,0.008333,0.008333,0.008333,0.008333,0.008333,0.252150,0.008333,0.008333,0.008333,0.008333,0.008333,0.008333,0.008333,0.008333,0.008333
1026260,US News,0.005000,0.005000,0.005000,0.005000,0.005000,0.375058,0.005000,0.005000,0.005000,0.534942,0.005000,0.005000,0.005000,0.005000,0.005000,0.005000,0.005000,0.005000,0.005000,0.005000
1026320,Fox,0.007143,0.007143,0.007143,0.007143,0.007143,0.007143,0.007143,0.007143,0.007143,0.864286,0.007143,0.007143,0.007143,0.007143,0.007143,0.007143,0.007143,0.007143,0.007143,0.007143
1026322,Fox,0.008333,0.008333,0.008333,0.008333,0.008333,0.008333,0.008333,0.008333,0.008333,0.669326,0.008333,0.180674,0.008333,0.008333,0.008333,0.008333,0.008333,0.008333,0.008333,0.008333


In [15]:
PHI = M.PHI  # topic x term weights
PHI

term_str,100000,3pointers,49ers,500000,76ers,ab,abduction,ability,abortion,abortions,...,youll,young,youre,youth,youths,youve,zero,zone,zones,zoo
topic_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,0.05,0.05,2.322354,5.044,0.05,0.05,0.05,0.05,16.472382,12.270463,...,0.05,0.05,0.05,0.050001,0.05,0.423996,2.282459,0.05,0.05,0.05
1,8.671345,0.05,0.05,0.05,0.05,0.05,10.960599,0.0516,0.05,0.059887,...,0.05,0.05,0.05,0.690028,0.05,0.05,0.05,0.05,0.05,0.05
2,0.05,13.05,0.05,0.05,0.05,0.05,0.05,3.648752,0.05,0.440942,...,0.05,0.05,0.05,0.05,0.05,0.05,0.05,3.599531,0.651163,0.056195
3,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.05,...,0.05,0.05,2.706064,0.05,1.049346,0.05,0.456825,0.05,10.000257,0.05
4,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.214512,0.05,0.05,...,0.05,5.399635,0.05,0.593451,0.05,0.05,0.05,0.05,0.05,2.696106
5,0.05,0.05,9.948503,0.05,0.05,0.05,0.05,1.962257,0.051543,0.05,...,0.05,0.05,7.6477,4.854318,0.05,0.05,0.05,0.050001,0.05,0.05
6,0.05,0.05,0.05,0.05,0.05,0.05,0.05,8.11835,4.630611,0.05,...,0.05,4.501143,0.05,1.552383,0.057033,0.05,0.05,2.945503,0.05,0.05
7,5.40212,0.05,0.05,0.05,0.05,45.049984,0.05,0.05,23.225842,0.678571,...,0.502025,0.05,0.05,0.05,4.041828,0.05,2.077375,5.321335,0.05,0.05
8,0.05,0.05,0.05,0.05,0.05,0.050016,0.05,0.05,0.292887,0.05,...,0.05,0.05,2.329056,30.333053,0.05,2.124573,0.05,7.346649,0.05,0.05
9,0.05,0.05,0.05,0.05,0.05,0.05,1.139401,1.835544,0.05,0.05,...,0.05,0.05,0.05,7.268132,8.047336,0.053525,0.05,2.545459,0.05,0.05


In [17]:
def get_pca(TFIDF,
            k=10,
            norm_docs=True,
            norm_level=2,
            center_by_mean=True,
            center_by_variance=False):
    """PCA via eigen-decomposition of the column covariance of a row x term matrix.

    Parameters
    ----------
    TFIDF : pd.DataFrame
        Rows are the items to project (documents, or topics when PHI is passed);
        columns are terms.
    k : int
        Number of components to return; clamped to the number of columns.
    norm_docs : bool
        If True, divide each row by its ``norm_level`` vector norm. All-zero rows
        are left as zeros instead of becoming NaN.
    norm_level : int or float
        Order of the row norm used when ``norm_docs`` is True (2 = Euclidean, 1 = L1).
    center_by_mean : bool
        Subtract each column's mean before projecting.
    center_by_variance : bool
        Divide each column by its standard deviation.

    Returns
    -------
    LOADINGS : pd.DataFrame
        Terms x components (columns ``pc_id`` 0..k-1), ordered by decreasing eigenvalue.
    DCM : pd.DataFrame
        The processed rows of ``TFIDF`` projected onto the components.
    COMPINF : pd.DataFrame
        Per component: the five terms with the largest (``pos``) and smallest (``neg``)
        loadings, the eigenvalue (``eig_val``) and its share of the total variance
        (``exp_var``, relative to the sum of ALL eigenvalues, not just the k kept).
    """
    n_top_terms = 5

    if norm_docs:
        row_norms = norm(TFIDF, norm_level, axis=1)
        # Dividing all-zero rows by 1 keeps them at zero instead of producing 0/0 = NaN.
        row_norms[row_norms == 0] = 1
        TFIDF = TFIDF.div(row_norms, axis=0)

    if center_by_mean:
        TFIDF = TFIDF - TFIDF.mean()

    if center_by_variance:
        TFIDF = TFIDF / TFIDF.std()

    COV = TFIDF.cov()
    # eigh is for symmetric matrices such as COV; pick the k largest eigenvalues.
    eig_vals, eig_vecs = eigh(COV)
    k = min(k, len(eig_vals))
    top = np.argsort(eig_vals)[::-1][:k]

    LOADINGS = pd.DataFrame(eig_vecs[:, top], index=COV.index,
                            columns=pd.RangeIndex(k, name='pc_id'))

    DCM = TFIDF.dot(LOADINGS)

    COMPINF = pd.DataFrame({
        'pos': [' '.join(LOADINGS[pc].sort_values(ascending=False).head(n_top_terms).index)
                for pc in LOADINGS.columns],
        'neg': [' '.join(LOADINGS[pc].sort_values(ascending=True).head(n_top_terms).index)
                for pc in LOADINGS.columns],
    }, index=LOADINGS.columns)
    COMPINF['eig_val'] = eig_vals[top]
    COMPINF['exp_var'] = COMPINF.eig_val / eig_vals.sum()

    return LOADINGS, DCM, COMPINF

In [18]:
# PCA preprocessing switches passed to get_pca below: no column centering or scaling.
center_by_mean = center_by_variance = False

In [19]:
# PCA over PHI: its rows are topics, so DCM places each topic in component space.
LOADINGS, DCM, COMPINF = get_pca(
    PHI,
    norm_docs=True,
    norm_level=2,
    center_by_mean=center_by_mean,
    center_by_variance=center_by_variance,
)

In [25]:
# Each point is a topic (a row of PHI) on the first two principal components.
# Marker size is the topic's total weight over documents. Hovering shows the
# topic label (zero-padded id followed by its top terms) rather than the bare id.
vis1 = px.scatter(DCM, 0, 1,
                  color=TOPICS.index.to_series(),
                  hover_name=TOPICS.label,
                  size=TOPICS.doc_weight_sum,
                  marginal_x='box', marginal_y='box', height=1000,
                  title='LDA topics on the first two principal components of PHI')

In [26]:
# Render the interactive topic scatter built in the previous cell.
vis1

In [27]:
# Write the model tables to CSV.
for csv_path, table in [('TOPICS.csv', TOPICS), ('THETA.csv', THETA), ('PHI.csv', PHI)]:
    table.to_csv(csv_path)

In [28]:
import kaleido

In [29]:
# Static PNG export of the scatter (image export needs the kaleido package imported above).
vis1.write_image(file='lda.png')