In [1]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax.scipy.stats.multivariate_normal import pdf as gaussian_pdf
from jax import jacfwd, jacrev
import optax
import numpy as np
import matplotlib.pyplot as plt

import pickle

from scipy.optimize import minimize

# 1. Problem Description

In [2]:
# Observation model: `sig_mean` is the mean of the Gaussian noise density and
# `d_obs` is the observed datum compared against the forward model output.
sig_mean = 0.0
d_obs = 1
# Gaussian prior on (x1, x2): component means and standard deviations
# (the sigmas are squared into variances where the prior is evaluated).
mu_1 = 1.
mu_2 = 2.
sig_1 = 2.
sig_2 = 2.

def M(x):
    """Forward model: map a 2-vector ``x`` to the scalar 0.5*x[0]**2 + 1.5*x[1]**2."""
    first, second = x[0], x[1]
    first_term = 0.5*jnp.power(first, 2)
    second_term = 1.5*jnp.power(second, 2)
    return first_term + second_term

@jit
def updf(x):
    """Unnormalised posterior density at ``x``: Gaussian likelihood times Gaussian prior.

    The likelihood evaluates the residual ``d_obs - M(x)`` under a 1-D Gaussian
    with mean ``sig_mean`` and unit variance; the prior is a 2-D Gaussian with
    mean ``(mu_1, mu_2)`` and diagonal covariance ``(sig_1**2, sig_2**2)``.

    Returns:
        The product of the two densities, with an exact zero replaced by 1e-16.
    """
    point = jnp.array(x)

    def prior(theta):
        prior_mean = jnp.array([mu_1, mu_2])
        prior_cov = np.diag([sig_1**2, sig_2**2])
        return gaussian_pdf(theta, prior_mean, prior_cov)

    def likelihood(theta):
        residual = d_obs - M(theta)
        return gaussian_pdf(residual, jnp.array([sig_mean]), np.diag([1]))

    density = likelihood(point) * prior(point)
    # The samplers divide by this value, so an underflow to exactly 0 is
    # replaced by a tiny positive floor.
    return jnp.where(density == 0, 1e-16, density)

# Batched version: evaluates updf along the leading axis of its argument.
vupdf = vmap(updf, in_axes=(0,))

def sampler(n):
    """Draw ``n`` samples from the standard 2-D normal using NumPy's global RNG.

    Returns:
        Array of shape ``(n, 2)``.
    """
    zero_mean = np.zeros((2, ))
    unit_cov = np.identity(2)
    draws = np.random.multivariate_normal(zero_mean, unit_cov, size=(n, ))
    return draws

def reference_pdf(r):
    """Density of the standard 2-D normal (zero mean, identity covariance) at ``r``."""
    origin = np.zeros((2, ))
    identity_cov = np.identity(2)
    return gaussian_pdf(r, origin, identity_cov)

# 2. MCMC

In [3]:
def RWMCMC(iter, x_set=None):
    """Random-walk Metropolis sampler targeting ``updf``.

    Args:
        iter: Loop bound; the loop index runs from 1 to ``iter - 1``, so
            ``iter - 1`` proposal steps are performed.
        x_set: Optional existing chain to continue. When given it is converted
            to a list of lists and sampling resumes from its last entry;
            otherwise the chain starts at (0, 0) with an empty history.

    Returns:
        The list of visited states (one entry appended per step). A
        ``KeyboardInterrupt`` stops sampling early and returns what has been
        collected so far.
    """
    if x_set is not None:
        x_set = np.array(x_set).tolist()
        x_curr = x_set[-1]
    else:
        x_curr = np.array([0, 0])
        x_set = []

    try:
        for _ in range(1, iter):
            # Symmetric Gaussian proposal with identity covariance.
            proposal = np.random.multivariate_normal(x_curr, np.identity(2))

            density_ratio = updf(proposal)/updf(x_curr)
            alpha = np.min([1, density_ratio])
            if np.random.uniform(0, 1) < alpha:
                x_curr = proposal

            x_set.append(x_curr)
            print(len(x_set), alpha, x_curr)
    except KeyboardInterrupt:
        return x_set
    return x_set


def AMCMC(iter):
    """Adaptive Metropolis sampler targeting ``updf``, started at (0, 0).

    For the first five iterations (``i <= 4``) the proposal is Gaussian around
    the current state with covariance ``0.01 * I / 2``. Afterwards the proposal
    is ``0.95 * a + 0.05 * b`` where ``a`` is drawn with covariance
    ``2.38**2 * cov(history) / 2`` and ``b`` with the fixed covariance above,
    both centred on the current state.

    Args:
        iter: Number of iterations to run.

    Returns:
        The list of visited states; a ``KeyboardInterrupt`` returns the partial
        chain.
    """
    x_curr = np.array([0, 0])
    x_set = []

    try:
        for i in range(0, iter):
            if i > 4:
                # Empirical covariance of the whole chain so far.
                history_cov = np.cov(np.array(x_set).T)
                adaptive_draw = np.random.multivariate_normal(x_curr, 2.38**2*history_cov/2)
                fixed_draw = np.random.multivariate_normal(x_curr, 0.01*(np.identity(2)/2))
                x_next = (1-0.05)*adaptive_draw + 0.05*fixed_draw
            else:
                x_next = np.random.multivariate_normal(x_curr, 0.01*(np.identity(2)/2))

            density_ratio = updf(x_next)/updf(x_curr)
            alpha = np.min([1, density_ratio])
            if np.random.uniform(0, 1) < alpha:
                x_curr = x_next

            x_set.append(x_curr)
            print(len(x_set), alpha, x_curr)
    except KeyboardInterrupt:
        return x_set
    return x_set


def EMCEEMCMC(iter):
    """Sampler targeting ``updf`` that proposes along the line to a random pool point.

    A pool ``FS`` of ``iter`` points is drawn uniformly from [-1, 1]^2. At step
    ``i`` a pool entry other than slot ``i`` is chosen at random, a factor
    ``z = (1 + r)**2 / 2`` with ``r ~ U[0, 1)`` is drawn, and the proposal is
    ``partner + z * (x_curr - partner)``. It is accepted with probability
    ``z * updf(proposal) / updf(x_curr)`` (not clipped to 1), and the resulting
    state overwrites ``FS[i]``.

    Args:
        iter: Number of iterations and size of the pool. ``iter == 1`` leaves
            no other slot to pick from, so ``np.random.choice`` raises
            ``ValueError``.

    Returns:
        The list of visited states; a ``KeyboardInterrupt`` returns the partial
        chain.
    """
    x_curr = np.array([0.0, 0.0])
    x_set = []
    FS = np.random.uniform(-1, 1, (iter, 2))
    # Loop-invariant index array, built once instead of on every iteration.
    full_range = np.arange(iter)

    try:
        for i in range(0, iter):
            # Every pool slot except the one this step will overwrite.
            left_range = full_range[full_range != i]
            x_next_index = np.random.choice(left_range)
            x_next = FS[x_next_index]

            r = np.random.rand()
            z = ((1 + r * (2 - 1))**2) / 2
            x_next = x_next + z * (x_curr - x_next)

            alpha = z * (updf(x_next)/ updf(x_curr))
            x_curr = x_next if np.random.uniform(0, 1) < alpha else x_curr

            x_set.append(x_curr)
            FS[i] = x_curr
            print(len(x_set), alpha, x_curr)
        return x_set
    except KeyboardInterrupt:
        return x_set

In [4]:
# RMCMC = RWMCMC(100000)
# AMCMC = AMCMC(100000)
# EMCMC = EMCEEMCMC(100000)

# np.save('/content/drive/MyDrive/No. 21 MCMC+TM/E1/Data_RMCMC.npy', RMCMC)
# np.save('/content/drive/MyDrive/No. 21 MCMC+TM/E1/Data_AMCMC.npy', AMCMC)
# np.save('/content/drive/MyDrive/No. 21 MCMC+TM/E1/Data_EMCMC.npy', EMCMC)

# Load the three previously saved chains (the sampling runs and np.save calls
# that produced them are commented out above). The paths are hard-coded
# Google Drive locations, so this cell only works with that drive mounted.
x_RMCMC = np.load('/content/drive/MyDrive/No. 21 MCMC+TM/E1/Data_RMCMC.npy')
x_AMCMC = np.load('/content/drive/MyDrive/No. 21 MCMC+TM/E1/Data_AMCMC.npy')
x_EMCMC = np.load('/content/drive/MyDrive/No. 21 MCMC+TM/E1/Data_EMCMC.npy')

# 3. TM_polynomial

In [5]:
from itertools import product

def generate_J(n, i, p):
    """Enumerate multi-indices of length ``n`` with total degree at most ``p``.

    Only the first ``i`` entries vary (each from 0 to ``p``); the remaining
    ``n - i`` entries are zero. Rows follow ``itertools.product`` order.

    Returns:
        A ``jnp`` array with one multi-index per row.
    """
    zero_tail = (0,)*(n-i)
    j_vectors = []
    for comb in product(range(p+1), repeat=i):
        if sum(comb) <= p:
            j_vectors.append(comb + zero_tail)
    return jnp.array(j_vectors)

def _legendre_expansion(coeffs, multi_indices, x1, x2):
    """Evaluate ``sum_k coeffs[k] * prod_d P_{multi_indices[k, d]}(x_d)`` at (x1, x2).

    ``P_m`` is the Legendre polynomial of degree ``m`` for m = 0..4; any other
    degree falls through to the degree-5 polynomial (the final ``where``
    branch). Shared by ``T1`` and ``T2``, which previously each carried an
    identical copy of this code.
    """
    def uni_polynomial(p, ri):
        # Selected with nested `where` so that `p` may be a traced value.
        return jnp.where(p == 0, 1,
            jnp.where(p == 1, ri,
            jnp.where(p == 2, 0.5 * (3 * jnp.power(ri, 2) - 1),
            jnp.where(p == 3, 0.5 * (5 * jnp.power(ri, 3) - 3 * ri),
            jnp.where(p == 4, (1/8) * (35 * jnp.power(ri, 4) - 30 * jnp.power(ri, 2) + 3),
                                (1/8) * (63 * jnp.power(ri, 5) - 70 * jnp.power(ri, 3) + 15*ri)
            )))))
    vuni_polynomial = vmap(uni_polynomial, in_axes=(0, 0))

    def psi(Ji, r):
        # Product over coordinates of the univariate polynomials of one multi-index.
        return jnp.prod(vuni_polynomial(Ji, r))
    vpsi = vmap(psi, in_axes=(0, None))

    x = jnp.array([x1, x2]).reshape((-1, ))
    return jnp.dot(coeffs, vpsi(multi_indices, x))


@jit
def T1(param1, J1, x1, x2):
    """First map component: expansion with coefficients ``param1`` over the multi-indices ``J1``.

    Args:
        param1: Coefficient vector, one entry per row of ``J1``.
        J1: Multi-index array (rows of per-coordinate polynomial degrees).
        x1, x2: Coordinates of the evaluation point.

    Returns:
        Scalar value of the expansion at (x1, x2).
    """
    return _legendre_expansion(param1, J1, x1, x2)


@jit
def T2(param2, J2, x1, x2):
    """Second map component: expansion with coefficients ``param2`` over the multi-indices ``J2``.

    Args:
        param2: Coefficient vector, one entry per row of ``J2``.
        J2: Multi-index array (rows of per-coordinate polynomial degrees).
        x1, x2: Coordinates of the evaluation point.

    Returns:
        Scalar value of the expansion at (x1, x2).
    """
    return _legendre_expansion(param2, J2, x1, x2)

# Derivative of each map component with respect to its own coordinate:
# gT1 = dT1/dx1 (positional argument 2), gT2 = dT2/dx2 (positional argument 3).
gT1 = grad(T1, argnums=2)
gT2 = grad(T2, argnums=3)

def TM(param1, param2, J1, J2, x):
    """Evaluate both map components at the point ``x`` and return them as a length-2 array."""
    first_component = T1(param1, J1, x[0], x[1])
    second_component = T2(param2, J2, x[0], x[1])
    return jnp.array([first_component, second_component])
# Batched over the leading axis of `x`; parameters and multi-indices are shared.
vTM = vmap(TM, in_axes=(None, None, None, None, 0))
# Jacobian of the map with respect to `x` (forward mode).
gTM = jacfwd(TM, argnums=4)

In [6]:
def obj1(param1, J1, x_batch):
    """Mean over the rows of ``x_batch`` of ``0.5*T1**2 - log(dT1/dx1)``.

    Args:
        param1: Coefficients of the first map component.
        J1: Multi-indices of the first map component.
        x_batch: 2-D array; column 0 is passed as x1 and column 1 as x2.
    """
    def pointwise_loss(coeffs, indices, u, v):
        return 0.5*T1(coeffs, indices, u, v)**2 - jnp.log(gT1(coeffs, indices, u, v))

    batched_loss = vmap(pointwise_loss, in_axes=(None, None, 0, 0))
    per_sample = batched_loss(param1, J1, x_batch[:, 0], x_batch[:, 1])
    return per_sample.mean()

def obj_test1(param1, J1, x_batch):
    """Per-sample values of ``0.5*T1**2 - log(dT1/dx1)`` for each row of ``x_batch``.

    Same integrand as ``obj1`` but returned unreduced (one value per row).
    """
    def pointwise_loss(coeffs, indices, u, v):
        return 0.5*T1(coeffs, indices, u, v)**2 - jnp.log(gT1(coeffs, indices, u, v))

    batched_loss = vmap(pointwise_loss, in_axes=(None, None, 0, 0))
    return batched_loss(param1, J1, x_batch[:, 0], x_batch[:, 1])

def constrant1(param1, J1, x_batch):
    """Total monotonicity violation of ``T1`` over ``x_batch``.

    Returns the sum of ``max(0, -dT1/dx1)`` over all rows: a non-negative
    number that is zero exactly when no derivative is negative.
    """
    batched_slope = vmap(gT1, in_axes=(None, None, 0, 0))
    slopes = batched_slope(param1, J1, x_batch[:, 0], x_batch[:, 1])
    violations = jnp.maximum(0, -slopes)
    return jnp.sum(violations)

def obj2(param2, J2, x_batch):
    """Mean over the rows of ``x_batch`` of ``0.5*T2**2 - log(dT2/dx2)``.

    Args:
        param2: Coefficients of the second map component.
        J2: Multi-indices of the second map component.
        x_batch: 2-D array; column 0 is passed as x1 and column 1 as x2.
    """
    def pointwise_loss(coeffs, indices, u, v):
        return 0.5*T2(coeffs, indices, u, v)**2 - jnp.log(gT2(coeffs, indices, u, v))

    batched_loss = vmap(pointwise_loss, in_axes=(None, None, 0, 0))
    per_sample = batched_loss(param2, J2, x_batch[:, 0], x_batch[:, 1])
    return per_sample.mean()

def obj_test2(param2, J2, x_batch):
    """Per-sample values of ``0.5*T2**2 - log(dT2/dx2)`` for each row of ``x_batch``.

    Same integrand as ``obj2`` but returned unreduced (one value per row).
    """
    def pointwise_loss(coeffs, indices, u, v):
        return 0.5*T2(coeffs, indices, u, v)**2 - jnp.log(gT2(coeffs, indices, u, v))

    batched_loss = vmap(pointwise_loss, in_axes=(None, None, 0, 0))
    return batched_loss(param2, J2, x_batch[:, 0], x_batch[:, 1])

def constrant2(param2, J2, x_batch):
    """Total monotonicity violation of ``T2`` over ``x_batch``.

    Returns the sum of ``max(0, -dT2/dx2)`` over all rows: a non-negative
    number that is zero exactly when no derivative is negative.
    """
    batched_slope = vmap(gT2, in_axes=(None, None, 0, 0))
    slopes = batched_slope(param2, J2, x_batch[:, 0], x_batch[:, 1])
    violations = jnp.maximum(0, -slopes)
    return jnp.sum(violations)

In [None]:
def _fit_map_component(objective, per_sample_objective, violation, param, J, x_set):
    """Fit one map component with SLSQP and return its coefficients.

    Args:
        objective: Scalar loss ``f(param, J, x)`` to minimise.
        per_sample_objective: Unreduced loss, used to find samples whose loss
            is NaN at the starting coefficients.
        violation: Function ``g(param, J, x)`` returning a non-negative total
            monotonicity violation (zero when feasible).
        param: Starting coefficients.
        J: Multi-indices for this component.
        x_set: 2-D NumPy array of samples.

    Returns:
        The optimiser's solution if it reports success, otherwise ``param``
        unchanged.
    """
    # Drop samples whose loss is already NaN at the starting point.
    nan_rows = np.where(np.isnan(per_sample_objective(param, J, x_set)))[0]
    x_valid = np.delete(x_set, nan_rows, axis=0)

    def feasibility(p, J_, x_):
        # SLSQP treats an 'ineq' constraint as fun(x) >= 0. `violation` is
        # non-negative by construction, so passing it directly would make the
        # constraint vacuous; its negation is >= 0 only when nothing is violated.
        return -violation(p, J_, x_)

    cons = {'type': 'ineq', 'fun': feasibility, 'args': (J, x_valid)}
    solution = minimize(objective, param, args=(J, x_valid), method='SLSQP', jac=grad(objective), constraints=cons, options={'disp': False, 'maxiter': 300})
    return solution.x if solution.success else param


def TM_train(param1, param2, J1, J2, x_set):
    """Fit both polynomial map components to the samples in ``x_set``.

    Each component is optimised independently with SLSQP under the constraint
    that its derivative with respect to its own coordinate is non-negative at
    the samples. A component whose optimisation does not report success keeps
    the coefficients it was called with.

    Args:
        param1, param2: Starting coefficients of the two components.
        J1, J2: Multi-indices of the two components.
        x_set: Samples, convertible to a 2-D NumPy array.

    Returns:
        Tuple ``(param1, param2)`` of fitted coefficients.
    """
    x_set = np.array(x_set)

    param1 = _fit_map_component(obj1, obj_test1, constrant1, param1, J1, x_set)
    param2 = _fit_map_component(obj2, obj_test2, constrant2, param2, J2, x_set)

    return param1, param2


def TM_algo(order, x_batch):
    """Fit a polynomial map of total degree ``order`` to the samples in ``x_batch``.

    Coefficients start at all ones and are trained twice: first on the first
    10000 samples, then on the full set starting from that result.

    Returns:
        Tuple ``(param1, param2)`` of fitted coefficients.
    """
    samples = np.array(x_batch)

    J1 = generate_J(n=2, i=1, p=order)
    J2 = generate_J(n=2, i=2, p=order)

    coeffs1 = np.ones(len(J1))
    coeffs2 = np.ones(len(J2))
    for subset in (samples[:10000], samples):
        coeffs1, coeffs2 = TM_train(coeffs1, coeffs2, J1, J2, subset)

    return coeffs1, coeffs2

# for order in [1, 2, 3, 4, 5]:
#     param1, param2 = TM_algo(order, x_EMCEEMCMC)
    # pickle.dump(param1, open(f'/content/drive/MyDrive/No. 21 MCMC+TM/E1/EMCMC_TM1_o{order}.pkl', 'wb'))
    # pickle.dump(param2, open(f'/content/drive/MyDrive/No. 21 MCMC+TM/E1/EMCMC_TM2_o{order}.pkl', 'wb'))

In [None]:
%timeit TM_algo(1, x_RMCMC)
%timeit TM_algo(1, x_AMCMC)
%timeit TM_algo(1, x_EMCMC)

1.22 s ± 31.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.2 s ± 26.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.48 s ± 36.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
%timeit TM_algo(2, x_RMCMC)
%timeit TM_algo(2, x_AMCMC)
%timeit TM_algo(2, x_EMCMC)

4.36 s ± 97 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
5.26 s ± 62.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
4.79 s ± 131 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
%timeit TM_algo(3, x_RMCMC)
%timeit TM_algo(3, x_AMCMC)
%timeit TM_algo(3, x_EMCMC)

13.7 s ± 214 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
11.6 s ± 224 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
15.1 s ± 251 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
%timeit TM_algo(4, x_RMCMC)
%timeit TM_algo(4, x_AMCMC)
%timeit TM_algo(4, x_EMCMC)

1min ± 2.38 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
54.8 s ± 879 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
50.5 s ± 997 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
%timeit TM_algo(5, x_RMCMC)
%timeit TM_algo(5, x_AMCMC)
%timeit TM_algo(5, x_EMCMC)

1min 35s ± 1.68 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
1min 9s ± 1.37 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
1min 22s ± 828 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# 4. TM_NN

In [7]:
from sklearn.model_selection import train_test_split

def init_network_params(sizes, key=random.PRNGKey(4)):
    """Create dense-layer parameters for a network with the given layer sizes.

    Args:
        sizes: Layer widths, input first; consecutive pairs define one layer.
        key: JAX PRNG key (defaults to ``PRNGKey(4)``, created once at
            definition time).

    Returns:
        A list of ``(weights, biases)`` tuples with shapes ``(n, m)`` and
        ``(n,)``; every entry is ``scale`` (1e-2) times a uniform draw from
        [1, 3), so all initial values are positive.
    """
    def random_layer_params(m, n, key, scale=1e-2):
        w_key, b_key = random.split(key)
        weights = scale * random.uniform(w_key, (n, m), minval=1, maxval=3)
        biases = scale * random.uniform(b_key, (n,), minval=1, maxval=3)
        return weights, biases

    layer_keys = random.split(key, len(sizes))
    layers = []
    for fan_in, fan_out, layer_key in zip(sizes[:-1], sizes[1:], layer_keys):
        layers.append(random_layer_params(fan_in, fan_out, layer_key))
    return layers

def TM_NN1(NN1, x1):
    """Evaluate the first map component as a sigmoid MLP of the single input ``x1``.

    Args:
        NN1: List of ``(weights, biases)`` pairs; all but the last layer are
            followed by a sigmoid, the last layer is linear.
        x1: Array input; it is flattened to 1-D before the first layer.

    Returns:
        The network output with singleton dimensions squeezed away.
    """
    hidden = x1.reshape((-1, ))
    *hidden_layers, (final_w, final_b) = NN1
    for w, b in hidden_layers:
        hidden = jax.nn.sigmoid(jnp.dot(w, hidden) + b)
    logit = jnp.dot(final_w, hidden) + final_b
    return logit.squeeze()
# Derivative of the network output with respect to its input x1.
gTM_NN1 = grad(TM_NN1, argnums=1)

# Initial parameters of the first NN map component: 1 input, three hidden
# layers of 100 units each, 1 output.
NN1 = init_network_params([1, 100, 100, 100, 1])

# NOTE: this reuses the name `obj1` of the polynomial-map objective defined
# earlier in the notebook, with a different signature; once this cell has run,
# `TM_train` resolves `obj1` to this function.
@jit
def obj1(NN1, x_batch):
    """Mean over ``x_batch[:, 0]`` of ``0.5*TM_NN1**2 - log(dTM_NN1/dx1 + 1e-12)``.

    Args:
        NN1: Network parameters for ``TM_NN1``.
        x_batch: 2-D array; only column 0 is used.
    """
    def pointwise_loss(net, u):
        slope = gTM_NN1(net, u)
        # The 1e-12 offset keeps the log finite when the slope is exactly zero.
        return 0.5*TM_NN1(net, u)**2 - jnp.log(slope+1e-12)

    batched_loss = vmap(pointwise_loss, in_axes=(None, 0))
    per_sample = batched_loss(NN1, x_batch[:, 0])
    return per_sample.mean()

@jit
def constraint1(NN1, x_batch):
    """Total monotonicity violation of ``TM_NN1`` over ``x_batch[:, 0]``.

    Returns the sum of ``max(0, -dTM_NN1/dx1)``: non-negative, and zero when
    no derivative is negative.
    """
    batched_slope = vmap(gTM_NN1, in_axes=(None, 0))
    slopes = batched_slope(NN1, x_batch[:, 0])
    violations = jnp.maximum(0, -slopes)
    return jnp.sum(violations)

In [8]:
def init_network_params(sizes, key=random.PRNGKey(4)):
    """Create dense-layer parameters for a network with the given layer sizes.

    This definition repeats the one given in the preceding NN cell.

    Args:
        sizes: Layer widths, input first; consecutive pairs define one layer.
        key: JAX PRNG key (defaults to ``PRNGKey(4)``, created once at
            definition time).

    Returns:
        A list of ``(weights, biases)`` tuples with shapes ``(n, m)`` and
        ``(n,)``; every entry is ``scale`` (1e-2) times a uniform draw from
        [1, 3).
    """
    def random_layer_params(m, n, key, scale=1e-2):
        w_key, b_key = random.split(key)
        weights = scale * random.uniform(w_key, (n, m), minval=1, maxval=3)
        biases = scale * random.uniform(b_key, (n,), minval=1, maxval=3)
        return weights, biases

    layer_keys = random.split(key, len(sizes))
    layer_dims = zip(sizes[:-1], sizes[1:], layer_keys)
    return [random_layer_params(fan_in, fan_out, layer_key) for fan_in, fan_out, layer_key in layer_dims]

def TM_NN2(NN2, x1, x2):
    """Evaluate the second map component as a sigmoid MLP of the inputs ``(x1, x2)``.

    Args:
        NN2: List of ``(weights, biases)`` pairs; all but the last layer are
            followed by a sigmoid, the last layer is linear.
        x1, x2: The two input coordinates, stacked into one flat vector.

    Returns:
        The network output with singleton dimensions squeezed away.
    """
    hidden = jnp.array([x1, x2]).reshape((-1, ))
    *hidden_layers, (final_w, final_b) = NN2
    for w, b in hidden_layers:
        hidden = jax.nn.sigmoid(jnp.dot(w, hidden) + b)
    logit = jnp.dot(final_w, hidden) + final_b
    return logit.squeeze()
# Derivative of the network output with respect to its second input x2.
gTM_NN2 = grad(TM_NN2, argnums=2)

# Initial parameters of the second NN map component: 2 inputs, three hidden
# layers of 150 units each, 1 output.
NN2 = init_network_params([2, 100+50, 100+50, 100+50, 1])

# NOTE: this reuses the name `obj2` of the polynomial-map objective defined
# earlier in the notebook, with a different signature; once this cell has run,
# `TM_train` resolves `obj2` to this function.
@jit
def obj2(NN2, x_batch):
    """Mean over the rows of ``x_batch`` of ``0.5*TM_NN2**2 - log(dTM_NN2/dx2 + 1e-12)``.

    Args:
        NN2: Network parameters for ``TM_NN2``.
        x_batch: 2-D array; column 0 is passed as x1 and column 1 as x2.
    """
    def pointwise_loss(net, u, v):
        slope = gTM_NN2(net, u, v)
        # The 1e-12 offset keeps the log finite when the slope is exactly zero.
        return 0.5*TM_NN2(net, u, v)**2 - jnp.log(slope+1e-12)

    batched_loss = vmap(pointwise_loss, in_axes=(None, 0, 0))
    per_sample = batched_loss(NN2, x_batch[:, 0], x_batch[:, 1])
    return per_sample.mean()

@jit
def constraint2(NN2, x_batch):
    """Total monotonicity violation of ``TM_NN2`` over the rows of ``x_batch``.

    Returns the sum of ``max(0, -dTM_NN2/dx2)``: non-negative, and zero when
    no derivative is negative.
    """
    batched_slope = vmap(gTM_NN2, in_axes=(None, 0, 0))
    slopes = batched_slope(NN2, x_batch[:, 0], x_batch[:, 1])
    violations = jnp.maximum(0, -slopes)
    return jnp.sum(violations)

In [None]:
# Split the random-walk chain into train / test sets (15% held out, fixed seed).
x_batch = np.array(x_RMCMC)
x_batch_train, x_batch_test = train_test_split(x_batch, test_size=0.15, random_state=42)

# Train the first NN map component with Adam, keeping the parameters that give
# the lowest test loss seen so far.
tobj1 = jit(grad(obj1))
optimizer= optax.adam(0.0005)
opt_state = optimizer.init(NN1)
lowest_value = 10000
# Bound up front so NN1_best exists even if the test loss never improves.
NN1_best = NN1
for i in range(10000):
    grads = tobj1(NN1, x_batch_train)
    updates, opt_state = optimizer.update(grads, opt_state)
    NN1 = optax.apply_updates(NN1, updates)
    # Evaluate the test loss once per step and reuse it for logging and model selection.
    test_loss = obj1(NN1, x_batch_test)
    print(f'Train loss: {i}, {obj1(NN1, x_batch_train):.4f}, {constraint1(NN1, x_batch_train)}   \
            Test loss: {test_loss:.4f}, {constraint1(NN1, x_batch_test)}     {lowest_value:.4f}')
    if test_loss<lowest_value:
        lowest_value = test_loss
        NN1_best = NN1


# Train the second NN map component with Adam on the same train / test split,
# keeping the parameters that give the lowest test loss seen so far.
tobj2 = jit(grad(obj2))
optimizer= optax.adam(0.0005)
opt_state = optimizer.init(NN2)
lowest_value = 100
# Bound up front so NN2_best exists even if the test loss never improves.
NN2_best = NN2
for i in range(10000):
    grads = tobj2(NN2, x_batch_train)
    updates, opt_state = optimizer.update(grads, opt_state)
    NN2 = optax.apply_updates(NN2, updates)
    # Evaluate the test loss once per step and reuse it for logging and model selection.
    test_loss = obj2(NN2, x_batch_test)
    print(f'Train loss: {i}, {obj2(NN2, x_batch_train):.4f}, {constraint2(NN2, x_batch_train)}      Test loss: {test_loss:.4f}, {constraint2(NN2, x_batch_test)}     {lowest_value:.4f}')
    if test_loss<lowest_value:
        lowest_value = test_loss
        NN2_best = NN2

# pickle.dump(NN1_best, open('/content/drive/MyDrive/No. 21 MCMC+TM/RMCMC_NN1_best.pkl', 'wb'))
# pickle.dump(NN2_best, open('/content/drive/MyDrive/No. 21 MCMC+TM/RMCMC_NN2_best.pkl', 'wb'))

In [14]:
import time

def NNNN(NN1):
    """Benchmark helper: run the 10000-step Adam fit of the first NN map component without logging.

    Uses the same split of ``x_RMCMC`` (15% held out, seed 42) and the same
    optimiser settings as the logged training loop, but performs only the
    parameter updates so that it can be timed.

    Args:
        NN1: Initial network parameters (the caller's list is not mutated;
            updates produce new parameter trees).

    Returns:
        The parameters after the final update.
    """
    x_batch = np.array(x_RMCMC)
    # Only the training part is used here; the split keeps the data identical
    # to the logged training run.
    x_batch_train, _x_batch_test = train_test_split(x_batch, test_size=0.15, random_state=42)

    tobj1 = jit(grad(obj1))
    optimizer= optax.adam(0.0005)
    opt_state = optimizer.init(NN1)
    for _step in range(10000):
        grads = tobj1(NN1, x_batch_train)
        updates, opt_state = optimizer.update(grads, opt_state)
        NN1 = optax.apply_updates(NN1, updates)
    return NN1

%timeit NNNN(NN1)

2min 17s ± 1.45 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
# NOTE(review): CPUTIME is not assigned in any visible cell — presumably it was
# set by a timing cell that is no longer in the notebook. If so, this cell
# raises NameError on a fresh kernel. TODO confirm and restore or remove.
CPUTIME

240.78335857391357

In [None]:
# x_init = jnp.array([0.0, 0.0])
# def MCMC(x_init, param1, param2):
#     x_curr = x_init
#     x_set = []

#     try:
#         for i in range(1, 10001):
#             r_curr = TM(param1, param2, x_curr)
#             r_next = np.random.multivariate_normal(r_curr, np.identity(2))
#             x_next = inverse(param1, param2, r_next)

#             det_inv_jac1 = gTM(param1, param2, x_next)
#             det_inv_jac1 = jnp.linalg.inv(det_inv_jac1)
#             det_inv_jac1 = jnp.linalg.det(det_inv_jac1)
#             det_inv_jac1 = jnp.abs(det_inv_jac1)
#             alpha1 = updf(x_next)*gaussian_pdf(r_curr, r_next, np.identity(2))*det_inv_jac1

#             det_inv_jac2 = gTM(param1, param2, x_curr)
#             det_inv_jac2 = jnp.linalg.inv(det_inv_jac2)
#             det_inv_jac2 = jnp.linalg.det(det_inv_jac2)
#             det_inv_jac2 = jnp.abs(det_inv_jac2)
#             alpha2 = updf(x_curr)*gaussian_pdf(r_next, r_curr, np.identity(2))*det_inv_jac2

#             alpha = alpha1/alpha2
#             alpha = np.min([1, alpha])
#             x_curr = x_next if np.random.uniform(0, 1) < alpha else x_curr

#             x_set.append(x_curr)
#             print(i, alpha, x_curr)

#             if i % 10 == 0:
#                 param1, param2 = train(x_set, param1, param2)
#                 # plot(param1, param2)

#         return x_set, param1, param2
#     except KeyboardInterrupt:
#         return x_set, param1, param2

# x_set, param1, param2 = MCMC(x_init, param1, param2)