In [130]:
####混合ガウス過程回帰モデル
import numpy as np
import pandas as pd
import matplotlib.pyplot  as plt
import numpy.matlib
import scipy.linalg
import itertools
import seaborn as sns
import gc
from scipy import sparse
from scipy.stats import norm
from pandas.tools.plotting import scatter_matrix
from numpy.random import *
from scipy import optimize

#np.random.seed(98537)

In [131]:
## Draw one category per row from a categorical (multinomial with size 1) distribution
def rmnom(pr, n, k, no, pattern):
    """Sample one category index per row of a probability matrix.

    Parameters
    ----------
    pr : ndarray, shape (n, k)
        Row-wise category weights. Each row is scaled by its own total, so rows
        need not sum exactly to 1.
    n : int
        Number of rows (one draw per row).
    k : int
        Number of categories; used only for the shape of the sparse indicator.
    no : array_like of int, length n
        Row indices for the sparse indicator (presumably ``np.arange(n)``).
    pattern : int
        If 1, also return the sparse one-hot indicator matrix.

    Returns
    -------
    z_id : ndarray of int, shape (n,)
        Sampled category for each row.
    Z : scipy.sparse.coo_matrix, shape (n, k)
        Returned only when ``pattern == 1``; ``Z[no[i], z_id[i]] = 1``.
    """
    cum_pr = np.cumsum(pr, axis=1)
    # Inverse-CDF sampling. Scaling the uniform by the row total keeps the draw
    # valid when a row does not sum exactly to 1; otherwise an all-False row
    # would make argmax silently return category 0.
    u = np.random.uniform(0, 1, n)[:, np.newaxis] * cum_pr[:, -1:]
    z_id = np.argmax(cum_pr >= u, axis=1)
    if pattern == 1:
        # Previously unreachable because of an early return above it.
        Z = sparse.coo_matrix((np.repeat(1, n), (no, z_id)), shape=(n, k))
        return z_id, Z
    return z_id

In [164]:
#### Data generation ####
## Data and segment settings
# Sizes: number of segments, number of records (rows), number of input columns
seg, T, k = 8, 3000, 100
# Per-record total count: Poisson with a gamma-distributed rate, floored at 3
Lambda = np.random.gamma(15.0, 1/0.5, T)
w = np.maximum(np.random.poisson(Lambda, T), 3)

# Segment mixing weights and a one-hot segment assignment for every record
pi = np.random.dirichlet(np.full(seg, 5.0), 1)[0]
Z = np.random.multinomial(1, pi, T)
z = Z @ np.arange(seg)
index_z = [np.array(np.where(z == j)[0], dtype="int") for j in range(seg)]
pit = np.copy(pi)

In [165]:
## Generate data from a Gaussian process
# Inputs: Dirichlet category probabilities per segment; each record is a
# multinomial count vector divided by its total count w[i] (row proportions)
alpha = np.full(k, 0.15)
theta = np.random.dirichlet(alpha, seg)
data = np.zeros((T, k))
for i, (w_i, z_i) in enumerate(zip(w, z)):
    data[i] = np.random.multinomial(w_i, theta[z_i], 1) / w_i
thetat = np.copy(theta)

# Linear kernel (Gram matrix) among the records of each segment
n = np.zeros(seg, dtype=int)
K = [None] * seg
for j, rows in enumerate(index_z):
    n[j] = len(rows)
    K[j] = np.dot(data[rows], data[rows].T)

In [166]:
# Model parameters and response: within segment j, y is beta[j] plus a draw
# from N(0, K[j] + Sigma * I)
y = np.zeros(T)
Sigma = np.power(0.1, 2)
beta = np.random.normal(0, 0.75, seg)
for j in range(seg):
    Sigma_diag = Sigma * np.eye(n[j])
    draw = np.random.multivariate_normal(np.zeros(n[j]), K[j] + Sigma_diag, 1)[0]
    y[index_z[j]] = beta[j] + draw
betat = np.copy(beta)

In [167]:
####MCMC-EMアルゴリズムで混合ガウス過程回帰モデルを推定####
##パラメータ推定のための関数を定義
# Conditional mean and variance of one coordinate of a multivariate normal
def cdMVN(mu, Cov, department, U):
    """Conditional mean/variance of coordinate ``department`` given the others.

    With block 1 = ``department`` and block 2 = all remaining coordinates:
        CDmu  = mu_1 + Cov12 Cov22^{-1} (U_2 - mu_2)
        CDvar = Cov11 - Cov12 Cov22^{-1} Cov21

    Parameters
    ----------
    mu : array_like, shape (d,), (1, d) or (n, d)
        Mean vector, or one mean row per row of ``U``; a 1-D vector is
        treated as a single row and broadcast.
    Cov : ndarray, shape (d, d)
        Covariance matrix.
    department : int
        Index of the coordinate being conditioned.
    U : ndarray, shape (n, d)
        Values; only the columns other than ``department`` are used.

    Returns
    -------
    CDmu : ndarray, shape (n, 1)
        Conditional mean for each row of ``U``.
    CDvar : ndarray, shape (1, 1)
        Conditional variance (the same for every row).
    """
    # Accept a 1-D mean vector; 2-D input is left unchanged.
    mu = np.atleast_2d(mu)

    # Block partition of the covariance matrix
    department = np.array([department])
    index = np.delete(np.arange(Cov.shape[0]), department)
    Cov11 = Cov[np.ix_(department, department)]
    Cov12 = Cov[np.ix_(department, index)]
    Cov21 = Cov[np.ix_(index, department)]
    Cov22 = Cov[np.ix_(index, index)]

    # Cov12 @ inv(Cov22) via a linear solve: X Cov22 = Cov12 <=> Cov22^T X^T = Cov12^T.
    # Computed once and reused for both the mean and the variance.
    CDinv = np.linalg.solve(Cov22.T, Cov12.T).T
    CDmu = mu[:, department] + np.dot(CDinv, (U[:, index] - mu[:, index]).T).T   # conditional mean
    CDvar = Cov11 - np.dot(CDinv, Cov21)   # conditional variance
    return CDmu, CDvar

# Multivariate normal density
def mvdnorm(u, mu, Cov, k):
    """Evaluate the k-dimensional normal density N(u | mu, Cov) row by row.

    Parameters
    ----------
    u : ndarray, shape (n, k) or (k,)
        Points at which to evaluate the density, one per row.
    mu : array_like, shape (k,) or (n, k)
        Mean, broadcast against ``u``.
    Cov : ndarray, shape (k, k)
        Covariance matrix; assumed symmetric positive definite.
    k : int
        Dimension of the distribution.

    Returns
    -------
    Lho : ndarray, shape (n,)
        Density for each row of ``u`` (shape (1,) for a single 1-D point).
    """
    # Residual of the argument ``u`` (not a global) from the mean
    er = np.atleast_2d(u - mu)
    # Quadratic form er_i Cov^{-1} er_i^T per row, via a solve instead of an inverse
    quad = np.sum(er * np.linalg.solve(Cov, er.T).T, axis=1)
    # Normaliser is (2*pi)^(-k/2) * |Cov|^(-1/2); slogdet avoids det under/overflow
    _, logdet = np.linalg.slogdet(Cov)
    Lho = np.exp(-0.5 * (k * np.log(2 * np.pi) + logdet + quad))
    return Lho
    return Lho 

In [168]:
## Algorithm settings
# Number of iterations and thinning interval
R, keep = 2000, 2
# Burn-in: 500 divided by the thinning interval (presumably counted in stored draws)
burnin = 500 // keep
# NOTE(review): ``iter`` shadows the Python builtin of the same name
iter = 0
# Progress-display interval
disp = 10

In [222]:
## Prior settings
# Prior for the latent variables
# NOTE(review): reuses the name ``alpha`` that earlier held the Dirichlet parameter vector
alpha = 0.1

# Priors for the model parameters
gamma, tau = 0.0, 100
eta = np.power(0.1, 2)
# Linear-kernel self-similarity x_i . x_i plus eta, one entry per record
Cov_diag = np.square(data).sum(axis=1) + eta

In [234]:
## True parameter values
# Latent variables set to their simulated (true) values
pi = pit.copy()
Zi = Z.copy()
z = Zi @ np.arange(seg)
n = Zi.sum(axis=0)
# index_z[j]: column 0 = record index, column 1 = position within segment j
index_z = []
for j in range(seg):
    index_z.append(np.column_stack((
        np.array(np.where(z == j)[0], dtype="int"),
        np.arange(np.sum(Zi[:, j])),
    )))

# Model parameters set to their simulated (true) values
theta = thetat.copy()
beta = betat.copy()
# Linear kernel (Gram matrix) among each segment's records
K = [np.dot(data[ix[:, 0], ], data[ix[:, 0], ].T) for ix in index_z]

In [235]:
#### Sample the parameters with the MCMC-EM algorithm ####
## Per-segment conditional (predictive) distributions
# Storage: row = record, column = segment
CDmu = np.zeros((T, seg))
CDvar = np.zeros((T, seg))

# GP predictive mean/variance of every record under segment j, conditioning on
# that segment's members. Cov_diag = x_i . x_i + eta, so CDvar includes the
# noise variance eta (predictive distribution of an observed y).
for j in range(seg):
    index = index_z[j][:, 0]
    x = np.dot(data, data[index, ].T)
    u = y[index, ]
    inv_K = np.linalg.inv(K[j] + np.diag(np.repeat(eta, n[j])))
    CDmu[:, j] = np.dot(np.dot(x, inv_K), u)
    CDvar[:, j] = Cov_diag - np.sum(np.dot(x, inv_K) * x, axis=1)

## Leave-one-out conditional distribution for each segment's own members
# A member must be excluded from its own conditioning set. With C = K + eta*I
# the leave-one-out predictive moments have a closed form
# (Rasmussen & Williams 2006, eq. 5.12):
#   mean_{-j} = y_j - [C^{-1} y]_j / [C^{-1}]_{jj},   var_{-j} = 1 / [C^{-1}]_{jj}
# This equals refitting the GP without record j, but needs one inverse per
# segment instead of one per member. The variance equals
# K_jj + eta - x^T (K_{-j} + eta I)^{-1} x, so like the loop above it includes
# eta; using K[i][j, j] alone would omit the noise term.
for i in range(seg):
    index_z1 = index_z[i][:, 0]
    u = y[index_z1]
    inv_C = np.linalg.inv(K[i] + np.diag(np.repeat(eta, n[i])))
    inv_C_diag = np.diag(inv_C)
    CDmu[index_z1, i] = u - np.dot(inv_C, u) / inv_C_diag
    CDvar[index_z1, i] = 1.0 / inv_C_diag

In [171]:
## Sample the latent variable z
# Per-segment likelihood: density of each y under segment j's predictive normal
# N(CDmu[:, j], CDvar[:, j]); broadcasting gives shape (T, seg)
Lho1 = norm.pdf(y[:, None], loc=CDmu, scale=np.sqrt(CDvar))


-0.19579003968492964

array([[6.54356256e-05, 3.32602545e-01, 1.13643091e+00, ...,
        1.00854310e+00, 4.49099993e-02, 1.03912071e+00],
       [7.96804353e-01, 3.39607502e-04, 2.75780880e-06, ...,
        4.12287132e-05, 1.23083886e+00, 1.12654646e-04],
       [1.58957377e-04, 6.68108513e-09, 4.23840109e-07, ...,
        3.54560728e-17, 1.63384713e-06, 3.99124158e-12],
       ...,
       [5.89839531e-03, 7.29562995e-04, 4.02990126e-05, ...,
        2.70412820e-09, 1.98329043e-02, 3.39205366e-05],
       [9.87154520e-01, 7.10451102e-01, 5.35291758e-04, ...,
        1.54259323e-02, 1.34681525e+00, 3.30759994e-01],
       [1.32752918e-07, 5.53336615e-07, 2.41753035e-06, ...,
        1.66395962e-06, 6.92529828e-18, 2.97537822e+00]])

In [241]:
# Inspect the simulated response vector y
y

array([ 0.93711621, -0.79888588, -1.41815118, ..., -0.97232973,
       -0.11087972,  1.54516553])

In [232]:
# Compare observed y (column 0) with the predictive mean CDmu (column 1)
# for the members of segment j
j = 1
members = index_z[j][:, 0]
np.column_stack((y[members], CDmu[members, j]))

array([[ 1.34325472,  1.32182881],
       [ 0.49283635,  0.43861455],
       [ 0.1859264 ,  0.28605758],
       [ 0.82556473,  0.83349161],
       [ 0.73062403,  0.81371724],
       [ 1.42386141,  1.32136164],
       [ 1.01097189,  0.90302731],
       [ 0.63039932,  0.77203827],
       [ 0.70218037,  0.87769371],
       [ 0.39763304,  0.50195265],
       [ 0.5104942 ,  0.62313279],
       [ 1.21193378,  1.11739648],
       [ 1.2203697 ,  1.18464079],
       [ 0.56488297,  0.60663932],
       [ 0.71925162,  0.648289  ],
       [ 0.40265107,  0.43644956],
       [ 0.7342942 ,  0.66341066],
       [ 0.66145904,  0.71561101],
       [ 0.60937762,  0.53714922],
       [ 0.58906413,  0.51274889],
       [ 0.33361267,  0.29030196],
       [ 0.98269863,  1.12030051],
       [ 0.46562682,  0.38768257],
       [ 1.01926629,  0.81528923],
       [ 0.64933548,  0.81634367],
       [ 0.56359558,  0.46823292],
       [ 0.98271355,  0.94047009],
       [ 1.07076942,  0.95444811],
       [ 0.66518596,

In [229]:
# Set the segment index j to 0 (presumably to inspect segment 0, like the j = 1 cell above).
# NOTE(review): a bare assignment displays nothing; the outputs saved after this
# cell appear to come from code that has since been removed or edited.
j = 0


array([-0.46969735, -0.52691992, -0.2325083 , -0.67771555, -0.44026041,
       -0.57990909, -0.50312682, -0.38914443, -0.24415096, -0.45875874,
       -0.43833882, -0.35335872, -0.47326256, -0.81806726, -0.51826173,
       -0.34841057, -0.56466711, -0.44183995, -0.28761193, -0.34442343,
       -0.20014448, -0.5410723 , -0.53269233, -0.54465324, -0.47465017,
       -0.17677848, -0.38289298, -0.41256506, -0.61491328, -0.40376421,
       -0.52514863, -0.21506853, -0.45988499, -0.40577184, -0.31394998,
       -0.02981596, -0.48440656, -0.17278468, -0.43143887, -0.24038073,
       -0.56671128, -0.3922602 , -0.41395013, -0.26707641, -0.67117818,
       -0.37478391, -0.43651074, -0.50483399, -0.40170089, -0.6158564 ,
       -0.1833619 , -0.12234696, -0.3223337 , -0.07783484, -0.39973833,
       -0.51816653, -0.10938504, -0.48006523, -0.22244704, -0.39472933,
       -0.36221712, -0.32813527, -0.45037302, -0.40372626, -0.53695605,
       -0.53489298, -0.649021  , -0.44782266, -0.23858261, -0.23

[array([[  21,    0],
        [  27,    1],
        [  34,    2],
        [  46,    3],
        [  66,    4],
        [  80,    5],
        [  98,    6],
        [ 105,    7],
        [ 107,    8],
        [ 128,    9],
        [ 162,   10],
        [ 176,   11],
        [ 187,   12],
        [ 188,   13],
        [ 229,   14],
        [ 231,   15],
        [ 241,   16],
        [ 257,   17],
        [ 273,   18],
        [ 280,   19],
        [ 301,   20],
        [ 304,   21],
        [ 323,   22],
        [ 357,   23],
        [ 365,   24],
        [ 367,   25],
        [ 370,   26],
        [ 375,   27],
        [ 387,   28],
        [ 405,   29],
        [ 409,   30],
        [ 424,   31],
        [ 426,   32],
        [ 435,   33],
        [ 446,   34],
        [ 459,   35],
        [ 472,   36],
        [ 517,   37],
        [ 518,   38],
        [ 533,   39],
        [ 543,   40],
        [ 552,   41],
        [ 555,   42],
        [ 576,   43],
        [ 595,   44],
        [ 