In [4]:
import os
import pickle
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import norm

# STIX fonts for both regular text and math text in figures
plt.rcParams.update({
    'font.family': 'STIXGeneral',
    'mathtext.fontset': 'stix',
})

from dphbmu.conditionals import *
from dphbmu.splitmerge_collapsed import *
from dphbmu.relabeling import *

from frame2d.smf_3story2bay import smf_3story2bay

# project code (used in the output file names)
proj_code = 'num_dp'


# define simulator
def simulator(x):
    """Evaluate the 3-story 2-bay frame model at latent point ``x``.

    ``x`` is mapped into (0, 1) element-wise with the standard normal CDF
    before being passed to ``smf_3story2bay``; the first four entries of the
    model output are dropped and the remainder is returned.
    NOTE(review): why entries [:4] are discarded is not visible here -- confirm.

    A named function (rather than a lambda bound to a name) gives readable
    tracebacks and can be pickled, e.g. for multiprocessing.
    """
    return smf_3story2bay(norm.cdf(x))[4:]


# ------------------------------
# function for create_obs 
# ------------------------------

def create_obs(ne, sg, rng):
    """Generate a synthetic data set of noisy simulator observations.

    Three fixed 9-component centres in the unit hypercube are each repeated
    ``ne`` times (all copies of the first centre, then the second, then the
    third). Every row is jittered with N(0, 0.02^2) noise, mapped to latent
    space with the standard normal quantile function, pushed through
    ``simulator``, and the outputs are corrupted with N(0, sg^2) noise.

    Args:
        ne: number of observations generated per centre.
        sg: standard deviation of the additive output noise.
        rng: random generator passed as ``random_state`` to scipy's ``rvs``.

    Returns:
        (xo, yo): latent inputs with shape (3 * ne, 9) and the matching noisy
        simulator outputs, one row per observation.
    """
    centres = np.array([
        [0.90, 0.90, 0.90, 0.90, 0.90, 0.90, 0.90, 0.90, 0.90],
        [0.40, 0.70, 0.60, 0.70, 0.90, 0.80, 0.90, 0.90, 0.90],
        [0.20, 0.30, 0.30, 0.50, 0.60, 0.60, 0.70, 0.80, 0.80],
    ])
    # ne consecutive copies of each centre, centre by centre
    u_clean = np.repeat(centres, ne, axis = 0)
    # jitter first, output noise second: draw order on rng matters
    u_jittered = u_clean + norm(0, 0.02).rvs(u_clean.shape, random_state = rng)
    x_obs = norm.ppf(u_jittered)
    y_obs = np.array([simulator(row) for row in x_obs])
    y_obs += norm(0, sg).rvs(y_obs.shape, random_state = rng)
    return x_obs, y_obs


# ------------------------------
# util. functions
# ------------------------------

def sigmoid(u):
    """Return the logistic function 1 / (1 + exp(-u)), element-wise.

    For very negative ``u``, ``exp(-u)`` overflows to inf and the result is
    correctly 0.0, but numpy would emit an overflow RuntimeWarning (or raise
    under ``np.errstate(over='raise')``). The overflow is therefore silenced
    locally; computed values are unchanged.
    """
    with np.errstate(over = 'ignore'):
        return 1.0 / (1.0 + np.exp(-u))

def logit(p):
    """Return the log-odds log(p) - log(1 - p), element-wise.

    ``p`` is clipped into [1e-8, 1 - 1e-8] first, so inputs of 0 or 1 give
    large finite values instead of -inf / +inf.
    """
    lo = 1e-8
    hi = 1.0 - lo
    q = np.clip(p, lo, hi)
    return np.subtract(np.log(q), np.log(1.0 - q))


# ------------------------------
# experiment settings
# ------------------------------

# num. of seeds drawn from each seed stream
n_sim = 10

# num. of leading seeds discarded from each stream: only n_sim - n_skip
# (= 9) experiments actually run.
# NOTE(review): the reason for skipping the first seed is not stated here
# (an experiment already completed?) -- confirm. Changing n_skip does not
# change which seeds are drawn, only which of them are used.
n_skip = 1

# hyperparameters for synthetic data
sg = 0.10  # std. of the additive output noise in create_obs
ne = 5     # num. of observations per cluster centre in create_obs

# set seeds
# NOTE(review): choice() samples with replacement, so duplicate seeds (and
# therefore overwritten output files, which are named by seed_obs) are
# possible in principle.
rng_for_seed = np.random.default_rng(1001)
seeds_obs = rng_for_seed.choice(9999, n_sim)[n_skip:]
seeds_rng = rng_for_seed.choice(9999, n_sim)[n_skip:]

# ------------------------------
# hyperparameters settings
# ------------------------------

# hyperparameters for base distribution
D = 9                                # latent dim. (centres in create_obs have 9 entries)
mu0 = norm.ppf(np.array([0.5] * D))  # Phi^-1(0.5) = 0 in every component
rh0 = 0.05

# hyperparameters for gamma distributions (a0*: shape, b0*: second parameter)
# NOTE(review): assuming a shape-rate parametrisation the prior mean is
# a0 / b0, e.g. a0g / b0g = 0.10**-2, i.e. a noise std. of about 0.10.
a0g = 2.0
b0g = a0g / 0.10**-2
a0t = 2.0
b0t = a0t / 0.10**-2
a0l = 1.0
b0l = a0l / 1.00

# hyperparameters for pCN stepsize adaptation
arate_tar = 0.8  # target acceptance rate
n_win = 50       # window (num. of sweeps) of the running acceptance rate
v = 0.6          # decay exponent of the adaptation gain (i + 1)**-v

# num. of iterations
n_iter = 20000
n_burn = 5000    # draws before this index are discarded; adaptation stops here


# ------------------------------
# EXPERIMENTS!
# ------------------------------

# Create the output directory before sampling: a missing ``out/`` would
# otherwise only raise FileNotFoundError at the final pickle.dump of the
# first experiment, discarding all of its n_iter sweeps.
os.makedirs('out', exist_ok = True)

for k, seed_obs in enumerate(seeds_obs):

    # init. random number generators (one for the data, one for the MCMC)
    rng_obs = np.random.default_rng(seed_obs)
    seed_rng = seeds_rng[k]
    rng = np.random.default_rng(seed_rng)

    # observation
    xo, yo = create_obs(ne, sg, rng_obs)

    # data slots (one entry appended per Gibbs sweep)
    xxs, zzs, als, mus, tus, gms = [], [], [], [], [], []
    acs, scs = [], []  # for acceptance rate & scales

    # initialize: latent inputs drawn from N(0, 1), every observation in
    # cluster 0, and tu / gm / al set to the ratios a0 / b0 of their priors
    xx = norm(0, 1).rvs(
        size = (len(yo), D),
        random_state = rng
    )
    yy = np.array([simulator(x) for x in xx])
    zz = np.zeros(len(yo), dtype = int)
    tu = a0t / b0t
    gm = a0g / b0g
    al = a0l / b0l
    k_ini = 1
    mu = np.array([sample_mu(tu, mu0, rh0, rng = rng) for _ in range(k_ini)])


    # --------------------
    # gibbs sampling!
    # --------------------

    # current pCN scale, adapted on the logit scale so it stays in (0, 1)
    lam = logit(0.5)
    scale = np.full(len(xx), sigmoid(lam))

    for i in range(n_iter):

        # metropolis-within-gibbs rule
        zz = update_zz_split_merge_collapsed(zz, xx, tu, mu0, rh0, al, rng = rng)
        al = update_al(al, len(np.unique(zz)), len(zz), a0l, b0l, rng = rng)
        mu = update_mu(zz, xx, tu, mu0, rh0, rng = rng)
        tu = update_tu(zz, xx, mu0, rh0, a0t, b0t, rng = rng)
        xx, yy, accepts = update_xx_pcn(simulator, zz, xx, yy, yo, gm, mu, tu, scale, rng = rng)
        gm = update_gm(yy, yo, a0g, b0g, rng = rng)

        # append
        als.append(deepcopy(al))
        zzs.append(deepcopy(zz))
        xxs.append(deepcopy(xx))
        mus.append(deepcopy(mu))
        tus.append(deepcopy(tu))
        gms.append(deepcopy(gm))
        acs.append(deepcopy(accepts))
        scs.append(deepcopy(scale))

        # running acceptance rate over the last (up to) n_win sweeps
        ar = np.array(acs[-n_win:]).mean()

        # tuning stepsize in pCN: move lam toward arate_tar with decaying
        # gain (i + 1)**-v, only before burn-in ends
        if n_win <= i < n_burn:
            zet = (i + 1)**-v
            lam = lam + zet * (ar - arate_tar)
            scale = np.full(len(xx), sigmoid(lam))

        print(
            f'Iter {i:05}:',
            zz,
            f'tu**-0.5 = {tu**-0.5:.3f}',
            f'gm**-0.5 = {gm**-0.5:.3f}',
            f'acc.rate = {ar:.3f}',
        )


    # --------------------
    # relabeling!
    # --------------------

    # number of clusters per draw, taken as the number of rows of mu
    # NOTE(review): assumes update_mu returns one mean per occupied cluster
    classnums = np.array([len(m) for m in mus])

    # post-burn-in draws conditional on K_hat clusters
    # (K_hat = 3 matches the three centres used in create_obs)
    K_hat = 3
    keep = classnums[n_burn:] == K_hat
    zzr = np.array(zzs[n_burn:])[keep]
    mur = np.array([mus[i] for i in range(n_burn, len(mus)) if classnums[i] == K_hat])

    # relabeling; skipped when no draw has K_hat clusters so that the
    # posterior draws are still saved below instead of being lost
    if len(zzr) > 0:
        zzr, mur = relabelling(zzr, mur)
    else:
        print(f'Warning: no post-burn-in draw has {K_hat} clusters; relabeling skipped.')


    # --------------------
    # save
    # --------------------

    dic = {
        'seeds': {
            'obs': seed_obs,
            'rng': seed_rng
        },
        'observation': {
            'sg': sg,
            'ne': ne,
            'xo': xo,
            'yo': yo,
        },
        'posterior': {
            'zzs': np.array(zzs),
            'xxs': np.array(xxs),
            'mus': mus,
            'als': np.array(als),
            'gms': np.array(gms),
            'tus': np.array(tus),
            'classnums': classnums
        },
        'relabeling': {
            'zzs': np.array(zzr),
            'mus': np.array(mur)
        },
        'hyper_step': {
            'arate_tar': arate_tar, 'n_win': n_win, 'v': v,
        },
        'hyper_mcmc': {
            'mu0': mu0,
            'rh0': rh0,
            'a0g': a0g,
            'b0g': b0g,
            'a0t': a0t,
            'b0t': b0t,
            'a0l': a0l,
            'b0l': b0l,
        },
        'n_burn': n_burn,
        'acceptance': np.array(acs),
        'scales': np.array(scs)
    }

    with open(f'out/result_{proj_code}_seed_{seed_obs:04}.pickle', mode = 'wb') as f:
        pickle.dump(dic, f)

Iter 00000: [1 0 1 0 0 0 0 0 0 0 1 0 1 0 0] tu**-0.5 = 0.933 gm**-0.5 = 0.933 acc.rate = 0.667
Iter 00001: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] tu**-0.5 = 0.998 gm**-0.5 = 1.041 acc.rate = 0.767
Iter 00002: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] tu**-0.5 = 1.012 gm**-0.5 = 0.900 acc.rate = 0.733
Iter 00003: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] tu**-0.5 = 0.924 gm**-0.5 = 0.773 acc.rate = 0.767
Iter 00004: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] tu**-0.5 = 0.959 gm**-0.5 = 0.836 acc.rate = 0.720
Iter 00005: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] tu**-0.5 = 0.994 gm**-0.5 = 0.743 acc.rate = 0.722
Iter 00006: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] tu**-0.5 = 0.913 gm**-0.5 = 0.782 acc.rate = 0.705
Iter 00007: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] tu**-0.5 = 0.908 gm**-0.5 = 0.771 acc.rate = 0.708
Iter 00008: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] tu**-0.5 = 0.753 gm**-0.5 = 0.662 acc.rate = 0.719
Iter 00009: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] tu**-0.5 = 0.852 gm**-0.5 = 0.641 acc.rate = 0.693
Iter 00010: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] tu**-0

KeyboardInterrupt: 