In [304]:
#####確率的潜在意味解析(トピックモデル)#####
import numpy as np
import pandas as pd
import matplotlib.pyplot  as plt
import numpy.matlib
import itertools
from scipy.special import gammaln
from scipy.misc import factorial
from pandas.tools.plotting import scatter_matrix
from numpy.random import *
from scipy import optimize
from scipy import sparse
import seaborn as sns

In [305]:
## Draw one categorical (single-trial multinomial) sample per row
def rmnom(pr, n, k, no):
    """Sample one category per row of `pr` and return the draws as a sparse one-hot matrix.

    Parameters
    ----------
    pr : array-like, n rows by k columns
        Row-wise category probabilities; each row is expected to sum to 1.
    n : int
        Number of draws (one uniform random number is generated per draw).
    k : int
        Number of categories, i.e. the number of columns of the result.
    no : array-like of int, length n
        Row of the output matrix that receives each draw.

    Returns
    -------
    scipy.sparse.coo_matrix of shape (n, k)
        Holds a 1 at (no[i], z_i), where z_i is the category drawn for row i.
    """
    # Inverse-CDF sampling: the first category whose cumulative probability
    # exceeds the uniform draw is selected.
    cum_pr = np.cumsum(pr, axis=1)
    u = np.random.rand(n).reshape(n, 1)
    exceeded = cum_pr > u

    # Floating-point rounding can leave the final cumulative sum at or just below
    # the uniform draw. Without this line such a row would be all False and
    # argmax would silently return category 0 instead of the last category.
    exceeded[:, -1] = True

    z_id = np.argmax(exceeded, axis=1)
    Z = sparse.coo_matrix((np.repeat(1, n), (no, z_id)), shape=(n, k))   # sparse one-hot encoding
    return Z

In [306]:
#### Generate the synthetic data ####
# Fix the global NumPy RNG state so that Restart & Run All reproduces the same
# corpus and the same estimates.
np.random.seed(2018)

# Problem dimensions
k = 15   # number of topics
d = 5000   # number of documents
v = 1250   # vocabulary size

# Words per document: Poisson counts whose rates are drawn from Gamma(shape=100, scale=1/0.5)
w = np.random.poisson(np.random.gamma(100, 1/0.5, d), d)
f = int(np.sum(w))   # total number of word tokens (vectorised sum, stored as a plain int)
k_vec = np.arange(k)
v_vec = np.arange(v)

In [307]:
## Document IDs and per-document token indices
# Document ID of every token, and the position of every token inside its document
d_id = np.repeat(np.arange(d), w)
no = np.concatenate([np.arange(n_words) for n_words in w])

# Token positions belonging to each document, plus a matching vector of ones.
# d_id is sorted by document, so cutting the token range at the cumulative word
# counts yields exactly the arrays that index[d_id == i] would give, without
# scanning all f tokens once per document.
index = np.arange(f)
d_index = np.split(index, np.cumsum(w)[:-1])
d_vec = [np.repeat(1, n_words) for n_words in w]

In [308]:
## Generate the response variables
# Symmetric Dirichlet priors used to draw the true parameters
alpha1 = np.full(k, 0.1)   # prior over topics, one entry per topic
alpha2 = np.full(v, 0.05)   # prior over words, one entry per vocabulary word

In [309]:
## Keep regenerating the data until every vocabulary word occurs at least once
rp = 0
while True:
    rp = rp + 1
    print(rp)

    # Draw the true parameters from their Dirichlet priors
    theta = np.random.dirichlet(alpha1, d)
    phi = np.random.dirichlet(alpha2, k)

    # Words whose probability is at or below the floor under every topic would
    # almost never be sampled: lift each of them to the floor in one uniformly
    # chosen topic.
    min_prob = (np.power(k, 2))/f
    rare_words = np.flatnonzero(np.max(phi, axis=0) <= min_prob)
    phi[np.random.randint(k, size=rare_words.shape[0]), rare_words] = min_prob

    # The overwrite above breaks the sum-to-one constraint of the affected topic
    # rows, so renormalise before phi is used (and stored) as the true parameter.
    phi = phi / np.sum(phi, axis=1).reshape(k, 1)
    thetat = theta.copy(); phit = phi.copy()

    # Draw a topic for every token from its document's topic distribution
    Z = rmnom(theta[d_id, ], f, k, np.arange(f)).toarray()
    z_vec = np.dot(Z, np.arange(k))   # one-hot rows -> topic number per token

    # Draw every token's word from the word distribution of its topic, one document at a time
    WX = np.zeros((d, v))   # document-by-word count matrix
    wd = np.repeat(0, f)   # word ID of every token
    for i in range(d):
        word = rmnom(phi[z_vec[d_index[i]], ], w[i], v, np.arange(w[i])).toarray()
        WX[i, ] = np.sum(word, axis=0)
        wd[d_index[i]] = np.dot(word, v_vec)

    # Stop as soon as every word appears somewhere in the corpus
    if np.min(np.sum(WX, axis=0)) > 0:
        break

1
2
3


In [310]:
# Build the per-word token indices
index = np.arange(f)

# Number of tokens of each word
word_n = np.bincount(wd, minlength=v)

# A stable sort groups the token positions by word while keeping them in
# ascending order inside each group, so cutting at the cumulative counts gives
# the same arrays as index[wd == i] without scanning all f tokens once per word.
token_order = np.argsort(wd, kind="stable")
w_index = np.split(token_order, np.cumsum(word_n)[:-1])
w_vec = [np.repeat(1, n_tokens) for n_tokens in word_n]

In [311]:
#### Estimate PLSA with the EM algorithm ####
# Per-token topic likelihoods and responsibilities
def loglike(theta, phi, d_id, wd, f, k):
    """Return the unnormalised topic likelihood of every token and its row-normalised version.

    Parameters
    ----------
    theta : ndarray
        Indexed by row with `d_id` (one row per document).
    phi : ndarray
        Transposed and then indexed by row with `wd` (one column per word).
    d_id : integer index array
        Document ID of every token.
    wd : integer index array
        Word ID of every token.
    f : int
        Number of tokens; used to reshape the row totals into a column.
    k : int
        Accepted for interface compatibility; not used in the computation.

    Returns
    -------
    (Lho, topic_rate)
        `Lho` is the element-wise product of the selected theta rows and phi
        columns; `topic_rate` is `Lho` with every row divided by its row sum.
    """
    doc_part = theta[d_id, :]
    word_part = phi.T[wd, :]
    Lho = doc_part * word_part

    row_total = Lho.sum(axis=1)
    topic_rate = Lho / row_total.reshape(f, 1)
    return Lho, topic_rate

In [312]:
# Random starting values: flat (all-ones) Dirichlet draws for both parameter sets
theta = np.random.dirichlet(np.ones(k), d)
phi = np.random.dirichlet(np.ones(v), k)

In [313]:
## Reference values for the log-likelihood
ones_k = np.repeat(1, k)   # summing vector: multiplying by it adds up the k topic columns

# Log-likelihood at the random starting values
LLi = loglike(theta, phi, d_id, wd, f, k)[0]
LL = np.log(np.dot(LLi, ones_k)).sum()

# Log-likelihood of the unigram model (corpus-wide word frequencies)
par = WX.sum(axis=0) / f
LLst = np.dot(WX, np.log(par)).sum()

# Log-likelihood at the true, data-generating parameters
LLi = loglike(thetat, phit, d_id, wd, f, k)[0]
LLbest = np.log(np.dot(LLi, ones_k)).sum()
print(np.array([LL, LLst, LLbest]).round(1))

[-7208471.5 -6565466.3 -5503211.6]


In [314]:
# Bookkeeping for the EM iterations
iter = 1   # iteration counter (this name shadows the builtin iter() for the rest of the notebook)
dl, tol = 100, 1.0   # initial log-likelihood change between EM steps / convergence threshold
LL = -10**9   # starting value of the log-likelihood

In [315]:
## Update the parameters with the EM algorithm
while abs(dl) >= tol:

    ## E-step: token-level likelihoods and responsibilities under the current parameters
    Topic_par = loglike(theta, phi, d_id, wd, f, k)
    Zi = Topic_par[1]
    Zi_T = Zi.T   # transposed view; not needed below, kept as a notebook-level name

    # Log-likelihood of the current parameters
    LLi = Topic_par[0]
    LL1 = np.sum(np.log(np.dot(LLi, np.repeat(1, k))))

    # A token with zero likelihood under every topic gives -inf here and NaN
    # responsibilities. Without this check the NaNs would flow into theta/phi and
    # the loop would then stop as if converged, because abs(nan) >= tol is False.
    if not np.isfinite(LL1):
        raise FloatingPointError("EM log-likelihood is not finite: " + str(LL1))

    ## M-step: re-estimate the parameters from the responsibilities
    # Expected topic counts per document (dsums) and per word (vsums). One
    # weighted bincount per topic replaces a Python loop over every document
    # and every vocabulary word.
    dsums = np.zeros((d, k))
    vsums = np.zeros((k, v))
    for j in range(k):
        dsums[:, j] = np.bincount(d_id, weights=Zi[:, j], minlength=d)
        vsums[j, :] = np.bincount(wd, weights=Zi[:, j], minlength=v)

    # Topic distribution of each document
    theta = dsums / w.reshape(d, 1)

    # Word distribution of each topic
    phi = vsums / np.sum(vsums, axis=1).reshape(k, 1)

    ## Convergence bookkeeping
    iter = iter + 1
    dl = LL1 - LL
    LL = LL1
    print(np.round(np.array([LL, LLst, LLbest]), 3))

[-7208471.457 -6565466.327 -5503211.649]
[-6574985.765 -6565466.327 -5503211.649]
[-6561916.748 -6565466.327 -5503211.649]
[-6546644.727 -6565466.327 -5503211.649]
[-6523875.437 -6565466.327 -5503211.649]
[-6487212.167 -6565466.327 -5503211.649]
[-6429943.332 -6565466.327 -5503211.649]
[-6349585.915 -6565466.327 -5503211.649]
[-6253221.972 -6565466.327 -5503211.649]
[-6155265.724 -6565466.327 -5503211.649]
[-6067387.304 -6565466.327 -5503211.649]
[-5992823.45  -6565466.327 -5503211.649]
[-5930647.069 -6565466.327 -5503211.649]
[-5879152.451 -6565466.327 -5503211.649]
[-5835985.756 -6565466.327 -5503211.649]
[-5799107.692 -6565466.327 -5503211.649]
[-5766087.989 -6565466.327 -5503211.649]
[-5734097.515 -6565466.327 -5503211.649]
[-5702167.992 -6565466.327 -5503211.649]
[-5672322.648 -6565466.327 -5503211.649]
[-5647127.175 -6565466.327 -5503211.649]
[-5626615.612 -6565466.327 -5503211.649]
[-5608876.065 -6565466.327 -5503211.649]
[-5592639.802 -6565466.327 -5503211.649]
[-5578311.402 -6

In [316]:
#### Estimates and summary statistics ####
# Estimated parameters (left half of the columns) next to the true ones (right half),
# rounded to three decimals.
# NOTE(review): the estimated topic order presumably need not match the true topic
# order, so column j on the left is not necessarily the counterpart of column j on
# the right -- confirm before comparing column by column.
phi_compare = pd.DataFrame(np.concatenate((phi.T, phit.T), axis=1)).round(3)
theta_compare = pd.DataFrame(np.concatenate((theta, thetat), axis=1)).round(3)

# Only the last expression of a cell is rendered automatically, so the phi table
# has to be displayed explicitly (display() is predefined in IPython/Jupyter kernels).
display(phi_compare)
theta_compare

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,20,21,22,23,24,25,26,27,28,29
0,0.006,0.676,0.007,0.000,0.000,0.000,0.000,0.005,0.000,0.000,...,0.000,0.653,0.329,0.000,0.004,0.000,0.000,0.000,0.000,0.000
1,0.000,0.000,0.000,0.197,0.009,0.000,0.000,0.000,0.000,0.404,...,0.000,0.000,0.000,0.204,0.004,0.059,0.000,0.394,0.012,0.024
2,0.000,0.000,0.021,0.009,0.000,0.242,0.000,0.000,0.557,0.087,...,0.013,0.000,0.000,0.000,0.533,0.000,0.000,0.088,0.000,0.000
3,0.006,0.211,0.119,0.006,0.000,0.055,0.012,0.024,0.008,0.000,...,0.014,0.170,0.004,0.000,0.005,0.000,0.021,0.000,0.545,0.000
4,0.005,0.013,0.000,0.738,0.108,0.000,0.000,0.000,0.014,0.000,...,0.000,0.001,0.154,0.706,0.010,0.000,0.005,0.000,0.000,0.108
5,0.118,0.007,0.120,0.000,0.185,0.000,0.170,0.147,0.000,0.000,...,0.175,0.000,0.003,0.000,0.000,0.028,0.152,0.000,0.000,0.186
6,0.425,0.000,0.000,0.000,0.000,0.000,0.000,0.007,0.000,0.062,...,0.000,0.000,0.000,0.000,0.000,0.222,0.000,0.100,0.247,0.000
7,0.000,0.100,0.000,0.000,0.137,0.000,0.000,0.000,0.001,0.530,...,0.000,0.077,0.000,0.000,0.001,0.000,0.006,0.512,0.006,0.151
8,0.000,0.000,0.000,0.211,0.016,0.000,0.000,0.000,0.000,0.003,...,0.000,0.028,0.040,0.199,0.000,0.000,0.000,0.000,0.730,0.000
9,0.016,0.000,0.048,0.080,0.239,0.001,0.145,0.016,0.000,0.322,...,0.047,0.001,0.000,0.087,0.000,0.001,0.141,0.328,0.001,0.237
