In [38]:
#####Mixed membership block model#####
import numpy as np
import pandas as pd
import matplotlib.pyplot  as plt
import numpy.matlib
import gensim
import itertools
from numpy.random import *
from scipy import optimize
from scipy import sparse
import seaborn as sns

In [39]:
## Function that draws one categorical (single-trial multinomial) sample per row
def rmnom(pr, n, k, no):
    """Draw one category per row of ``pr`` and return the draws one-hot encoded.

    Parameters
    ----------
    pr : array-like of shape (n, k)
        Row-wise category probabilities (non-negative, each row summing to 1).
    n : int
        Number of rows to sample.
    k : int
        Number of categories, i.e. the number of columns of the result.
    no : array-like of length n
        Output row index for each draw.

    Returns
    -------
    scipy.sparse.coo_matrix of shape (n, k)
        Integer matrix with a single 1 per draw at ``(no[i], z_i)``, where
        ``z_i`` is the category sampled for row ``i``.
    """
    cdf = np.cumsum(pr, axis=1)
    u = np.random.rand(n).reshape(n, 1)

    # Inverse-CDF sampling: the sampled category is the first index whose
    # cumulative probability exceeds u, which equals the number of cumulative
    # probabilities that are <= u.
    z_id = np.sum(cdf <= u, axis=1)

    # Floating-point rounding can leave the final cumulative value slightly
    # below 1. If u lands above it, no entry exceeds u; clamp to the last
    # category instead of silently falling back to category 0.
    z_id = np.minimum(z_id, cdf.shape[1] - 1)

    Z = sparse.coo_matrix((np.ones(n, dtype=int), (no, np.array(z_id))), shape=(n, k))   # sparse one-hot matrix
    return Z

In [40]:
#### Data generation ####
## Problem dimensions
k = 10                           # number of mixture components (blocks)
d = 3000                         # number of nodes
hhpt = d * (d - 1) // 2          # number of unordered node pairs
vec_k = np.ones(k, dtype=int)    # length-k vector of ones, used for row sums

In [41]:
## Enumerate the node pairs that may carry a link
# All unordered pairs (a, b) with a < b, listed row by row:
# (0,1), (0,2), ..., (0,d-1), (1,2), ...
# np.triu_indices returns both endpoint vectors directly, so no dense d x d
# index matrices or per-row Python lists have to be built and flattened.
d_id01, d_id02 = np.triu_indices(d, 1)

# Index setup: for every node, the (sorted) positions of the pairs containing it
d_list = [
    np.flatnonzero((d_id01 == node) | (d_id02 == node)).astype(int)
    for node in range(d)
]

In [42]:
# Select the observed pairs with Bernoulli draws.
# Every node draws its own rate from Beta(15, 35) and then one 0/1 indicator
# for each of the d-1 pairs it belongs to; a pair keeps S == 1 only if the
# indicators drawn for both of its endpoints are 1.
S = np.ones(hhpt, dtype="int")
for node in range(d):
    prob = np.random.beta(15.0, 35.0, 1)
    pair_pos = d_list[node]
    S[pair_pos] = S[pair_pos] * np.random.binomial(1, prob, d-1)

# Keep only the selected pairs
index_s = np.flatnonzero(S == 1).astype(int)
d_id1 = d_id01[index_s]
d_id2 = d_id02[index_s]
hhpt = d_id1.shape[0]   # number of retained pairs

# Index setup: positions of the retained pairs containing each node, and their count
d_list = [
    np.flatnonzero((d_id1 == node) | (d_id2 == node)).astype(int)
    for node in range(d)
]
n = np.array([pair_pos.shape[0] for pair_pos in d_list], dtype=int)

# Sparse node-by-pair incidence matrices:
# d_dt1[a, p] == 1 when node a is the first endpoint of pair p, d_dt2 likewise
# for the second endpoint.
d_dt1 = sparse.coo_matrix(
    (np.ones(hhpt, dtype=int), (d_id1, np.arange(hhpt))), shape=(d, hhpt)
).tocsr()
d_dt2 = sparse.coo_matrix(
    (np.ones(hhpt, dtype=int), (d_id2, np.arange(hhpt))), shape=(d, hhpt)
).tocsr()

In [43]:
## Generate the parameters
# Prior hyperparameters: the Beta parameters are larger on the diagonal
# (same-block pairs) than off the diagonal
alpha1 = np.full(k, 0.15)
alpha21 = np.full((k, k), 0.5)
alpha22 = np.full((k, k), 0.8)
np.fill_diagonal(alpha21, 1.25)
np.fill_diagonal(alpha22, 1.75)

# Draw the model parameters
theta = np.random.dirichlet(alpha1, d)    # (d, k) block membership distribution per node
phi = np.random.beta(alpha21, alpha22)    # (k, k) link probability per ordered block pair
# Keep references to the generating values (same array objects, not copies)
thetat, phit = theta, phi

In [44]:
## Generate the response variables
# Draw a latent block for each endpoint of every pair (one-hot rows)
Z1 = rmnom(theta[d_id1], hhpt, k, np.arange(hhpt)).toarray()
Z2 = rmnom(theta[d_id2], hhpt, k, np.arange(hhpt)).toarray()
z1_vec = Z1 @ np.arange(k)   # block index of the first endpoint
z2_vec = Z2 @ np.arange(k)   # block index of the second endpoint

# Draw the links: y1 uses phi[z1, z2] and y2 uses phi[z2, z1].
# Indexing phi directly selects the same single entry that masking a row of
# phi with the one-hot matrix and summing would.
y1 = np.random.binomial(1, phi[z1_vec, z2_vec], hhpt)
y2 = np.random.binomial(1, phi[z2_vec, z1_vec], hhpt)
y_vec1 = y1[:, np.newaxis]   # column-vector views for broadcasting against (hhpt, k)
y_vec2 = y2[:, np.newaxis]
print([np.mean(y1), np.mean(y2)] )

[0.30905378822500845, 0.30873089991455877]


In [45]:
#### Estimate the mixed membership block model by Gibbs sampling ####
## Algorithm settings
R = 3000                 # total number of Gibbs iterations
keep = 2                 # thinning interval: every keep-th draw is stored
burnin = 1500 // keep    # burn-in length expressed in stored draws
iter = 0                 # NOTE(review): shadows the builtin iter(); not read by the cells visible here
disp = 10                # print progress every disp iterations

In [46]:
## Prior hyperparameters
alpha = 0.1           # added to the per-node block counts before theta is drawn
s0, v0 = 0.5, 0.5     # added to the link / non-link counts before phi is drawn

In [47]:
## Start from the true parameter values
# NOTE(review): the initial-value cell that follows rebinds every name set
# here, so in a top-to-bottom run this cell has no lasting effect.
# Model parameters used to generate the data
theta, phi = thetat, phit

# Latent block assignments used to generate the data
Zi1, Zi2 = Z1, Z2
z1_vec = Zi1 @ np.arange(k)
z2_vec = Zi2 @ np.arange(k)

In [48]:
## Initial values
# Random starting point for the model parameters
theta = np.random.dirichlet(np.full(k, 5.0), d)
phi = np.random.beta(2.5, 5.0, k*k).reshape((k, k))

# Random starting point for the latent assignments: uniform over the k blocks
Zi1 = np.random.multinomial(1, np.full(k, 1/k), hhpt)
Zi2 = np.random.multinomial(1, np.full(k, 1/k), hhpt)
z1_vec = Zi1 @ np.arange(k)   # one-hot rows -> block index
z2_vec = Zi2 @ np.arange(k)

In [49]:
## Arrays that hold the sampling results
THETA = np.zeros((d, k, R // keep))   # one (d, k) slice per stored draw
PHI = np.zeros((k, k, R // keep))     # one (k, k) slice per stored draw
SEG1 = np.zeros((hhpt, k))            # accumulated one-hot assignments, first endpoint
SEG2 = np.zeros((hhpt, k))            # accumulated one-hot assignments, second endpoint

In [50]:
## Reference values for the log-likelihood
# One-parameter model: each response is Bernoulli with its own sample mean
rate1 = np.mean(y1)
rate2 = np.mean(y2)
LLst = np.sum(y1*np.log(rate1) + (1-y1)*np.log(1-rate1) + y2*np.log(rate2) + (1-y2)*np.log(1-rate2))

# Log-likelihood at the generating parameters and generating assignments
phi1 = (phit[Z1 @ np.arange(k)] * Z2).sum(axis=1)   # phit[z1, z2] for every pair
phi2 = (phit[Z2 @ np.arange(k)] * Z1).sum(axis=1)   # phit[z2, z1] for every pair
LLbest = np.sum(y1*np.log(phi1) + (1-y1)*np.log(1-phi1) + y2*np.log(phi2) + (1-y2)*np.log(1-phi2))

In [51]:
#### Sample the parameters by Gibbs sampling
for rp in range(R):

    ## Sample the first latent variable (block of the first endpoint of each pair)
    # Assignment probabilities. Column j of LLho is the log-likelihood of both
    # responses if z1 were j, given the current z2:
    #   y1 ~ Bernoulli(phi[j, z2])  ->  phi.T[z2_vec][:, j]
    #   y2 ~ Bernoulli(phi[z2, j])  ->  phi[z2_vec][:, j]
    phi11 = phi.T[z2_vec, ]
    phi12 = phi[z2_vec, ]
    LLho = y_vec1*np.log(phi11) + (1-y_vec1)*np.log(1-phi11) + y_vec2*np.log(phi12) + (1-y_vec2)*np.log(1-phi12)   # log-likelihood
    # Subtract the row maximum before exponentiating to avoid underflow
    Lho = theta[d_id1,] * np.exp(LLho - np.max(LLho, axis=1).reshape(hhpt, 1))
    z_rate = Lho / np.dot(Lho, vec_k).reshape(hhpt, 1)

    # Draw the latent variable from the multinomial distribution
    sparse_z1 = rmnom(z_rate, hhpt, k, np.arange(hhpt)).tocsr()
    Zi1 = np.array(sparse_z1.todense())
    z1_vec = np.dot(Zi1, np.arange(k))


    ## Sample the second latent variable (block of the second endpoint of each pair)
    # Assignment probabilities. Column j of LLho is the log-likelihood of both
    # responses if z2 were j, given the freshly drawn z1:
    #   y1 ~ Bernoulli(phi[z1, j])  ->  phi[z1_vec][:, j]
    #   y2 ~ Bernoulli(phi[j, z1])  ->  phi.T[z1_vec][:, j]
    phi21 = phi[z1_vec, ]
    phi22 = phi.T[z1_vec, ]
    LLho = y_vec1*np.log(phi21) + (1-y_vec1)*np.log(1-phi21) + y_vec2*np.log(phi22) + (1-y_vec2)*np.log(1-phi22)   # log-likelihood
    Lho = theta[d_id2,] * np.exp(LLho - np.max(LLho, axis=1).reshape(hhpt, 1))
    z_rate = Lho / np.dot(Lho, vec_k).reshape(hhpt, 1)

    # Draw the latent variable from the multinomial distribution
    sparse_z2 = rmnom(z_rate, hhpt, k, np.arange(hhpt)).tocsr()
    Zi2 = np.array(sparse_z2.todense())
    z2_vec = np.dot(Zi2, np.arange(k))


    ## Sample the model parameters
    # Block membership distributions: per-node block counts over both endpoint
    # roles plus the Dirichlet prior. The sparse products use `@` (the sparse
    # matrix product) rather than np.dot, which is not sparse-aware.
    wsum = (d_dt1 @ sparse_z1 + d_dt2 @ sparse_z2).toarray() + alpha
    theta = np.zeros((d, k))
    for i in range(d):
        theta[i, ] = np.random.dirichlet(wsum[i, ], 1)

    # Link probabilities: phi[j, l] collects y1 on pairs with (z1, z2) == (j, l)
    # and y2 on pairs with (z2, z1) == (j, l)
    phi = np.zeros((k, k))
    for j in range(k):
        Zi12 = Zi1[:, j].reshape(hhpt, 1) * Zi2
        Zi21 = Zi2[:, j].reshape(hhpt, 1) * Zi1
        s = np.sum(y_vec1 * Zi12, axis=0) + np.sum(y_vec2 * Zi21, axis=0) + s0
        v = np.sum((1-y_vec1) * Zi12, axis=0) + np.sum((1-y_vec2) * Zi21, axis=0) + v0
        phi[j, ] = np.random.beta(s, v, k)


    ## Store the draws and report progress
    # Store every keep-th draw
    if rp%keep==0:
        mkeep = rp//keep
        THETA[:, :, mkeep] = theta
        PHI[:, :, mkeep] = phi

        # Accumulate block assignments once the burn-in has passed.
        # `burnin` is defined in stored-draw units (1500/keep), so it is
        # compared with mkeep rather than rp. The test must be a boolean
        # `and`: bitwise `&` binds tighter than `==`/`>=`, which would turn the
        # condition into a chained comparison that is never true.
        if mkeep >= burnin:
            SEG1 += Zi1
            SEG2 += Zi2

    if rp%disp==0:
        # Log-likelihood at the current draw
        phi1 = np.sum(phi[z1_vec, ] * Zi2, axis=1)
        phi2 = np.sum(phi[z2_vec, ] * Zi1, axis=1)
        LL = np.sum(y1*np.log(phi1) + (1-y1)*np.log(1-phi1) + y2*np.log(phi2) + (1-y2)*np.log(1-phi2))

        # Show the current, one-parameter-model and true-value log-likelihoods
        print(rp)
        print(np.round(np.array([LL, LLst, LLbest]), 2))

0
[-444064.77 -497803.66 -307932.8 ]
10
[-442574.39 -497803.66 -307932.8 ]
20
[-428170.88 -497803.66 -307932.8 ]
30
[-418519.4  -497803.66 -307932.8 ]
40
[-416234.43 -497803.66 -307932.8 ]
50
[-418483.02 -497803.66 -307932.8 ]
60
[-421223.5  -497803.66 -307932.8 ]
70
[-422103.16 -497803.66 -307932.8 ]
80
[-422894.97 -497803.66 -307932.8 ]
90
[-424255.21 -497803.66 -307932.8 ]
100
[-424325.62 -497803.66 -307932.8 ]
110
[-422962.52 -497803.66 -307932.8 ]
120
[-422791.28 -497803.66 -307932.8 ]
130
[-423322.33 -497803.66 -307932.8 ]
140
[-421828.04 -497803.66 -307932.8 ]
150
[-421320.39 -497803.66 -307932.8 ]
160
[-422332.89 -497803.66 -307932.8 ]
170
[-421519.42 -497803.66 -307932.8 ]
180
[-421115.3  -497803.66 -307932.8 ]
190
[-420396.58 -497803.66 -307932.8 ]
200
[-418981.49 -497803.66 -307932.8 ]
210
[-420138.25 -497803.66 -307932.8 ]
220
[-419531.64 -497803.66 -307932.8 ]
230
[-418209.61 -497803.66 -307932.8 ]
240
[-418243.27 -497803.66 -307932.8 ]
250
[-417395.47 -497803.66 -307932.8

2080
[-350736.23 -497803.66 -307932.8 ]
2090
[-352710.62 -497803.66 -307932.8 ]
2100
[-350918.13 -497803.66 -307932.8 ]
2110
[-350672.78 -497803.66 -307932.8 ]
2120
[-350927.13 -497803.66 -307932.8 ]
2130
[-351662.23 -497803.66 -307932.8 ]
2140
[-352438.74 -497803.66 -307932.8 ]
2150
[-350748.51 -497803.66 -307932.8 ]
2160
[-352470.09 -497803.66 -307932.8 ]
2170
[-351706.75 -497803.66 -307932.8 ]
2180
[-352425.18 -497803.66 -307932.8 ]
2190
[-351084.67 -497803.66 -307932.8 ]
2200
[-352263.82 -497803.66 -307932.8 ]
2210
[-351550.95 -497803.66 -307932.8 ]
2220
[-350737.36 -497803.66 -307932.8 ]
2230
[-349990.   -497803.66 -307932.8 ]
2240
[-349922.14 -497803.66 -307932.8 ]
2250
[-350526.45 -497803.66 -307932.8 ]
2260
[-349024.28 -497803.66 -307932.8 ]
2270
[-349624.28 -497803.66 -307932.8 ]
2280
[-349105.27 -497803.66 -307932.8 ]
2290
[-351197.46 -497803.66 -307932.8 ]
2300
[-351413.52 -497803.66 -307932.8 ]
2310
[-350936.12 -497803.66 -307932.8 ]
2320
[-352353.45 -497803.66 -307932.8 ]
