In [13]:
%matplotlib inline
import numpy as np
from scipy.stats import multivariate_normal, norm
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rc
rc('text', usetex=True)

import os
import pickle

from copy import deepcopy

import tqdm

from sklearn.datasets import make_spd_matrix

In [2]:
outdir = './output/synthetic/univariate/abrupt_mixture/'
if not os.path.exists(outdir):
    os.makedirs(outdir)

In [3]:
class Theta:
    def __init__(self, pi, mu, sigma):
        self.pi = pi
        self.mu = mu
        self.sigma = sigma

In [4]:
def compute_tau(X, theta):
    N = X.shape[0]
    K = len(theta.pi)

    tau = np.zeros((N, K))
    for k in range(K):
        d_k = norm(theta.mu[k], theta.sigma[k])
        for i in range(N):
            tau[i, k] = theta.pi[k] * d_k.pdf(X[i])
    
    tau /= np.sum(tau, axis=1, keepdims=True)
    
    return tau

In [5]:
def compute_stat(X, Z):
    K = Z.shape[0]
    D = X.shape[1]

    s1 = np.sum(Z, axis=0)
    s2 = X.T.dot(Z).ravel()
    s3 = (X**2).T.dot(Z).ravel()
    
    return s1, s2, s3

In [6]:
def step_M(s, x, eps=1e-8):
    pi = (s.s1 + eps) / (1.0 + eps * len(s.s1))
    mu = s.s2 / (s.s1 + eps)
    sigma = np.sqrt(s.s3 / (s.s1 + eps) - mu**2)
    return pi, mu, sigma

In [7]:
class SuffStat:
    def __init__(self):
        pass

In [8]:
class SuffStatGMM(SuffStat):
    def __init__(self, K, D, seed=0):
        super().__init__()
        np.random.seed(seed)
        self.s1 = np.random.random(K)
        self.s2 = np.random.random(K)
        self.s3 = np.random.random(K)

In [9]:
def sra(X, theta0, rho, gamma):
    K = len(theta0.mu)
    N = X.shape[0]
    
    theta_est = np.zeros((N, 3*K+1))
    theta = deepcopy(theta0)
    
    s = SuffStatGMM(len(theta0.pi), 1)
    
    # initialization
    tau = compute_tau(X[:10, :], theta)
    s.s1[:], s.s2[:], s.s3[:] = compute_stat(X[:10, :], tau)
    
    inds = np.arange(N)
    
    for n in range(N):
        ind_n = inds[n]
        
        tau_indiv_new = compute_tau(X[ind_n, :], theta).ravel()
        s2_indiv_new = X[ind_n, :] * tau_indiv_new
        s3_indiv_new = X[ind_n, :]**2 * tau_indiv_new

        if ( np.any(np.isnan(tau_indiv_new)) | (np.linalg.norm(s2_indiv_new - s.s2) >= gamma) ):
            theta_est[n, :] = np.hstack((n, theta.pi, theta.mu, theta.sigma))
            continue
            
        # update the statistics
        s.s1 += rho[n] * (tau_indiv_new - s.s1)
        s.s2 += rho[n] * (s2_indiv_new - s.s2)
        s.s3 += rho[n] * (s3_indiv_new - s.s3)
                
        # M-step
        theta.pi[:], theta.mu[:], theta.sigma[:] = step_M(s, X[ind_n, :])
        theta_est[n, :] = np.hstack((n, theta.pi, theta.mu, theta.sigma))

    return theta_est

In [10]:
def generate_data(N, pi, mu, sigma):
    
    z_list = np.random.choice(np.arange(len(pi)), N, replace=True)
    
    X = np.zeros((N, 1))
    for i, z in enumerate(z_list):
        X_i = np.random.normal(mu[z], sigma[z])
        X[i, :] = X_i
    
    return X

In [11]:
pi = np.array([0.2, 0.8])
pi0 = pi.copy()

sigma = np.array([0.1, 0.1])
sigma0 = np.array([0.2, 0.2])

mu0 = [0.1, -0.1]

theta0 = Theta(pi0, mu0, sigma0)

seed0 = 1
seed = seed0

In [12]:
gamma = 3.0
rho = 0.0116
ratio = 0.1
alpha_orig = 0.01
M = 5
alpha_list = [0.01, 0.05, 0.1]
u = 20
n_trial = 10
N = 10000

pi = np.array([0.5, 0.5])
mu1 = np.array([0.5, -0.5])
mu2 = np.array([1.0, -1.0])
sigma = np.array([0.1, 0.1])

mse_eval_alpha =  np.zeros((n_trial, len(alpha_list)))
mse_former_alpha = np.zeros((n_trial, len(alpha_list)))
mse_latter_alpha = np.zeros((n_trial, len(alpha_list)))
mse_overall_alpha =  np.zeros((n_trial, len(alpha_list)))

for trial in tqdm.tqdm(range(n_trial)):
    np.random.seed(trial)
    X = np.zeros((2*N, 1))
    Xj_former = generate_data(N, pi, mu1, sigma)
    X[:N, 0] = Xj_former.ravel()
    Xj_latter = generate_data(N, pi, mu2, sigma)
    X[N:, 0] = Xj_latter.ravel()
    
    for j, alpha in enumerate(alpha_list):
        idxes_ol = np.random.choice(np.arange(2*N), int(alpha*2*N), replace=False)
        X[idxes_ol, :] = np.random.uniform(-u, u, len(idxes_ol)).reshape(-1, 1)

        rho = ratio * (1-alpha_orig)/(1-alpha) *np.exp(-gamma**2/M**2)
        rho_const = np.repeat([rho], X.shape[0])
        theta_est_sra = sra(X, theta0, rho_const, gamma)

        mse_eval = np.sum((theta_est_sra[500:1000, 3:5] - mu1)**2, axis=1)
        mse_eval_alpha[trial, j] = np.mean(mse_eval)
        
        mse_former = np.sum((theta_est_sra[1000:10000, 3:5] - mu1)**2, axis=1)
        mse_former_alpha[trial, j] = np.mean(mse_former)
        
        mse_latter = np.sum((theta_est_sra[10000:, 3:5] - mu2)**2, axis=1)
        mse_latter_alpha[trial, j] = np.mean(mse_latter)

        mse_overall = np.hstack((mse_former, mse_latter))
        mse_overall_alpha[trial, j] = np.mean(mse_overall)

        print(np.mean(mse_eval), np.mean(mse_former), np.mean(mse_latter), np.mean(mse_overall))

  # This is added back by InteractiveShellApp.init_path()


0.0021014678704488984 0.0024049441945100955 0.001977626655594734 0.0021800402266599055
0.0037351928475942266 0.007617017596990774 0.1474191127166607 0.08119706765997493


 10%|█         | 1/10 [01:43<15:33, 103.72s/it]

0.015944313751274013 0.014559992578456556 0.0073454431443337655 0.0107628612973393
0.0027828433603861945 0.0018379621222346211 0.0021207383218861 0.0019867917009985573
0.006326138154849464 0.0060009474489323795 0.004369323892663926 0.005142198208791088


 20%|██        | 2/10 [03:28<13:52, 104.11s/it]

0.012195422031230236 0.012033868348255504 0.015955423537026572 0.014097844763398173
0.0014123332761470627 0.001671735690172463 0.0020298678961288025 0.001860226324886326
0.0029853958819001593 0.007123537298849829 0.0044354254519208165 0.0057087415899398225


 30%|███       | 3/10 [05:11<12:07, 103.91s/it]

0.0128441895129767 0.011654651564770798 0.007877752291665378 0.009666809842083736
0.0020706076811293894 0.00208497607553327 0.021873114904892677 0.01249978598572243
0.002322682069984328 0.007467447861879079 0.046088828913180606 0.02779449052045883


 40%|████      | 4/10 [06:57<10:26, 104.45s/it]

0.007866869406835892 0.013058438260500075 0.03441129472013395 0.024296783765570533
0.002330023609156684 0.0017571541488169275 0.026632256380925725 0.014849313218347877
0.003517924375631876 0.00446327498382807 0.029041382816959897 0.01739912121179219


 50%|█████     | 5/10 [08:46<08:50, 106.18s/it]

0.014768283992623381 0.011727761848507068 0.717299408337969 0.38308126000085546
0.0020292456352065465 0.0027479503640732597 0.0021517168362636552 0.0024341432441734676
0.0038861648976974265 0.008605015157426897 0.0044275691888244284 0.006406359384478228


 60%|██████    | 6/10 [10:32<07:04, 106.15s/it]

0.01371200859256011 0.017774243104683493 0.007183124406605695 0.012199970105695178
0.001135221319158567 0.0022398797097421246 0.0024200034643247405 0.0023346816858382377
0.0021734610840270716 0.00867252819320332 0.21647742380021265 0.11804352588110298


 70%|███████   | 7/10 [12:18<05:18, 106.24s/it]

0.008154894862514042 0.015945454960452812 0.013103052035710962 0.014449453421114996
0.001589899391728482 0.001944687929468963 0.0022371398609707656 0.0020986099986804382
0.003768008628888501 0.004227680710075858 0.004992971734886595 0.004630465459976246


 80%|████████  | 8/10 [14:02<03:31, 105.57s/it]

0.011143417228330761 0.012397823463835183 0.25025461728976306 0.1375856096880077
nan nan nan nan
nan nan nan nan


 90%|█████████ | 9/10 [15:32<01:40, 100.65s/it]

nan nan nan nan
0.002306593082290716 0.002096051863714809 0.002213868933919168 0.002158060848032893
0.002426232092928858 0.008961644709108306 0.0047389187784530155 0.0067391573771844685


100%|██████████| 10/10 [17:01<00:00, 102.15s/it]

0.010913768698788878 0.014197858345640597 0.008192886074239042 0.011037346623850303





In [14]:
with open(os.path.join(outdir, 'mse_eval.pkl'), 'wb') as f:
    pickle.dump(mse_eval_alpha, f)
    
with open(os.path.join(outdir, 'mse_former.pkl'), 'wb') as f:
    pickle.dump(mse_former_alpha, f)
    
with open(os.path.join(outdir, 'mse_latter.pkl'), 'wb') as f:
    pickle.dump(mse_latter_alpha, f)
    
with open(os.path.join(outdir, 'mse_overall.pkl'), 'wb') as f:
    pickle.dump(mse_overall_alpha, f)