# SmartSifter

## ライブラリのインポート

In [1]:
import numpy as np
import numpy.random as rd
import os
import scipy.stats as st
import itertools

In [2]:
# The project modules (sdle, sdem, smartsifter) are imported from ../src relative
# to the directory the notebook starts in, so step into that directory only for
# the duration of the imports.
# NOTE(review): this presumably relies on the current directory being on the
# import path — confirm if the imports ever fail to resolve.
# NOTE: os.chdir("..") is relative to the current directory and this cell ends in
# home_dir, so re-running it without a kernel restart climbs one more level.
os.chdir("..")
home_dir = os.getcwd()
try:
    os.chdir("src")
    from sdle import SDLE
    from sdem import SDEM
    from smartsifter import SmartSifter
finally:
    # Restore the working directory even when an import raises; otherwise the
    # kernel would be left inside src and later relative paths would break.
    os.chdir(home_dir)

## データの生成

In [3]:
# Settings shared by both synthetic data streams.
seed = 0
T = 500  # number of data points (observations)

# --- Data generation: discrete-valued vector ---------------------------------
rd.seed(seed)
# One-hot multinomial draws over a single category, so every symbol is 0.
# The cell previously drew from six equiprobable categories (pvals=[1/6]*6) and
# overwrote that result on the very next line; the dead draw has been removed.
x_1 = np.argmax(rd.multinomial(1, [1] * 1, T), axis=1)
x = x_1.copy()
# Optional second discrete variable (disabled):
# x_2 = rd.binomial(1, 0.5, T)
# x = np.c_[x_1, x_2]

# --- Data generation: continuous-valued vector -------------------------------
n = [200, 150, 150]  # number of samples drawn from each component
K = 3  # number of latent variables
D = 2  # dimension

# mu: one D-dimensional mean per component
mu_true = np.array(
    [[0.2, 0.5],
     [1.2, 0.5],
     [2.0, 0.5]])

# sigma: one D x D covariance matrix per component
sigma_true = np.array(
    [[[0.1,  0.085], [0.085, 0.1]],
     [[0.1, -0.085], [-0.085, 0.1]],
     [[0.1,  0.085], [0.085, 0.1]]
    ])

# x and y are consumed pairwise later on, so the component sizes have to add up
# to the number of observations and match the number of components.
assert len(n) == K, "n must list one sample count per component"
assert sum(n) == T, "component sample counts must sum to T"

# Reseed so the Gaussian draws do not depend on how much randomness the
# discrete stream consumed.
rd.seed(seed)
component_blocks = []
for i in range(K):
    # Draw n[i] points from component i and append the component index as a
    # trailing label column.
    samples = st.multivariate_normal.rvs(mean=mu_true[i], cov=sigma_true[i], size=n[i])
    component_blocks.append(np.c_[samples, np.ones(n[i]) * i])
# Stack once instead of re-concatenating the growing array on every iteration.
org_data = np.vstack(component_blocks)

#print(org_data)
# Keep only the D coordinate columns; the label column is dropped.
y = org_data[:, :D]

## パラメータの設定

In [4]:
# Parameter settings

# SmartSifter
r_h = 0.1

# SDEM
k = K  # number of latent variables
d = D  # dimension
alpha = 1.0  # (1.0 - 2.0)
r = 1 / T * 5

# SDLE
beta = 1  # positive constant
# One 1-tuple per distinct value observed in x.
A = [(symbol,) for symbol in set(x)]
# Two-variable alternative (disabled):
# A = list(itertools.product(set(x_1), set(x_2)))

In [5]:
# Build the detector from the values set in the parameter cell; the arguments
# are passed positionally in the order r, beta, A, alpha, k, r_h.
ss = SmartSifter(r, beta, A, alpha, k, r_h)
# Train on the discrete sequence x together with the continuous sequence y.
# NOTE(review): the recorded run emits a warning at `-np.log(p_prev_params)` and
# the first score printed from ss.S_L is inf — presumably a log of zero on the
# first observation; confirm inside SmartSifter.
ss.train(x, y) 

  s_l = -np.log(p_prev_params) #シャノン情報量を計算


In [6]:
# Print every observation (discrete value, continuous vector) beside the
# corresponding entry of ss.S_L; zip stops at the shortest of the three.
for record in zip(x, y, ss.S_L):
    print(*record)

0 [-0.37117021 -0.00186094] inf
0 [-0.29173858  0.39639551] 2.565752096387263
0 [-0.28336084 -0.15263034] 3.6299518848897225
0 [-0.07585019  0.19793397] 3.5227376843421006
0 [0.19583391 0.56695166] 3.436152096071763
0 [0.03024708 0.58213464] 3.75290662290217
0 [-0.04199795  0.27907678] 4.099310705493703
0 [0.03610722 0.39390131] 4.123708728566701
0 [-0.23663918  0.02782636] 4.697679072183912
0 [0.17875104 0.33081731] 4.287401160977097
0 [0.91985651 1.33306657] 4.255336791661482
0 [0.00136537 0.17281862] 4.677500252317472
0 [-0.36436716 -0.31627068] 5.5146175849841
0 [0.20229369 0.46987249] 4.547501463496073
0 [-0.3934268  0.1610736] 5.237090799959432
0 [0.12012475 0.48562442] 4.708391585424886
0 [0.6415515  0.59846749] 4.570966179837828
0 [0.29227313 0.61935357] 4.697698795010845
0 [-0.27830745  0.22995085] 5.280738234538637
0 [0.34398104 0.59162067] 4.774922347047759
0 [0.6418821  0.69592777] 4.744918362937192
0 [0.54999971 1.18788392] 4.885831440461769
0 [0.39294301 0.61706631] 4.873