In [1]:
import jax.numpy as jnp
from jax import grad, vmap, hessian

import jax.ops as jop
from jax.config import config; 
config.update("jax_enable_x64", True)

# numpy
import numpy as onp
from numpy import random 

# Plot
import matplotlib.pyplot as plt

In [2]:
# up to 2nd derivatives

def kappa(x,y,d,sigma):
    dist2 = jnp.sum((x-y)**2)
    return jnp.exp(-dist2/(2*sigma**2))

def D_wy_kappa(x,y,d, sigma,w):
    dist2 = jnp.sum((x-y)**2)
    val = -jnp.sum(w*(y-x))/(sigma**2)*jnp.exp(-dist2/(2*sigma**2))
    return val

def Delta_y_kappa(x,y,d,sigma):
    dist2 = jnp.sum((x-y)**2)
    val = (-d*(sigma**2)+dist2)/(sigma**4)*jnp.exp(-dist2/(2*sigma**2))
    return val

def D_wx_kappa(x,y,d, sigma,w):
    dist2 = jnp.sum((x-y)**2)
    val = -jnp.sum(w*(x-y))/(sigma**2)*jnp.exp(-dist2/(2*sigma**2))
    return val

# Dx vector
def D_x_kappa(x,y,d, sigma):
    dist2 = jnp.sum((x-y)**2)
    val = -(x-y)/(sigma**2)*jnp.exp(-dist2/(2*sigma**2))
    return val

def D_wx_D_wy_kappa(x,y,d,sigma,wx,wy):
    dist2 = jnp.sum((x-y)**2)
    val = (jnp.sum(wx*wy)/(sigma**2)+jnp.sum(wx*(x-y))*jnp.sum(wy*(y-x))/(sigma**4))*jnp.exp(-dist2/(2*sigma**2))
    return val

# DxDwy vector
def D_x_D_wy_kappa(x,y,d,sigma,wy):
    dist2 = jnp.sum((x-y)**2)
    val = (wy/(sigma**2)+(x-y)*jnp.sum(wy*(y-x))/(sigma**4))*jnp.exp(-dist2/(2*sigma**2))
    return val

def D_wx_Delta_y_kappa(x,y, d,sigma,w):
    dist2 = jnp.sum((x-y)**2)
    val = jnp.sum(w*(x-y))*((sigma**2)*(2+d)-dist2)/(sigma**6)*jnp.exp(-dist2/(2*sigma**2))
    return val


def Delta_x_kappa(x,y,d,sigma):
    dist2 = jnp.sum((x-y)**2)
    val = (-d*(sigma**2)+dist2)/(sigma**4)*jnp.exp(-dist2/(2*sigma**2))
    return val

def Delta_x_D_wy_kappa(x,y, d,sigma,w):
    dist2 = jnp.sum((x-y)**2)
    val = jnp.sum(w*(y-x))*((sigma**2)*(2+d)-dist2)/(sigma**6)*jnp.exp(-dist2/(2*sigma**2))
    return val

def Delta_x_Delta_y_kappa(x,y,d,sigma):
    dist2 = jnp.sum((x-y)**2)
    val = ((sigma**4)*d*(2+d)-2*(sigma**2)*(2+d)*dist2+dist2**2)/(sigma**8)*jnp.exp(-dist2/(2*sigma**2))
    return val

def get_GNkernel_train(x,y,wx0,wx1,wy0,wy1,d,sigma):
    return wx0*wy0*kappa(x,y,d,sigma) + wx0*D_wy_kappa(x,y,d, sigma,wy1) + wy0* D_wx_kappa(x,y,d, sigma,wx1) + D_wx_D_wy_kappa(x,y,d,sigma,wx1,wy1)

def get_GNkernel_val_predict(x,y,wy0,wy1,d,sigma):
    return wy0*kappa(x,y,d,sigma) + D_wy_kappa(x,y,d, sigma,wy1)

def get_GNkernel_grad_predict(x,y,wy0,wy1,d,sigma):
    return wy0*D_x_kappa(x,y,d, sigma) + D_x_D_wy_kappa(x,y,d,sigma,wy1)


In [3]:
def g(x):
    return jnp.log(1/2+1/2*sum(x**2))


In [4]:
def assembly_Theta(X_domain, w0, w1, sigma):
    # X_domain, dim: N_domain*d; 
    # w0 col vec: coefs of Diracs, dim: N_domain; 
    # w1 coefs of gradients, dim: N_domain*d
    
    N_domain,d = onp.shape(X_domain)
    Theta = onp.zeros((N_domain,N_domain))
    
    XdXd0 = onp.reshape(onp.tile(X_domain,(1,N_domain)),(-1,d))
    XdXd1 = onp.tile(X_domain,(N_domain,1))
    
    arr_wx0 = onp.reshape(onp.tile(w0,(1,N_domain)),(-1,1))
    arr_wx1 = onp.reshape(onp.tile(w1,(1,N_domain)),(-1,d))
    arr_wy0 = onp.tile(w0,(N_domain,1))
    arr_wy1 = onp.tile(w1,(N_domain,1))
    
    val = vmap(lambda x,y,wx0,wx1,wy0,wy1: get_GNkernel_train(x,y,wx0,wx1,wy0,wy1,d,sigma))(XdXd0,XdXd1,arr_wx0,arr_wx1,arr_wy0,arr_wy1)
    Theta[:N_domain,:N_domain] = onp.reshape(val, (N_domain,N_domain))
    return Theta
    
def assembly_Theta_value_and_grad_predict(X_infer, X_domain, w0, w1, sigma):
    N_infer, d = onp.shape(X_infer)
    N_domain, _ = onp.shape(X_domain)
    Theta = onp.zeros((N_infer*(d+1),N_domain))
    
    XiXd0 = onp.reshape(onp.tile(X_infer,(1,N_domain)),(-1,d))
    XiXd1 = onp.tile(X_domain,(N_infer,1))
    
    arr_wy0 = onp.tile(w0,(N_infer,1))
    arr_wy1 = onp.tile(w1,(N_infer,1))
    
    val = vmap(lambda x,y,wy0,wy1: get_GNkernel_val_predict(x,y,wy0,wy1,d,sigma))(XiXd0,XiXd1,arr_wy0,arr_wy1)
    Theta[:N_infer,:N_domain] = onp.reshape(val, (N_infer,N_domain))

    val = vmap(lambda x,y,wy0,wy1: get_GNkernel_grad_predict(x,y,wy0,wy1,d,sigma))(XiXd0,XiXd1,arr_wy0,arr_wy1)
    Theta[N_infer:,:N_domain] = onp.reshape(val,(N_infer*d,N_domain))
    return Theta
    
def assembly_Theta_stanGP(X_domain,sigma):
    N_domain,d = onp.shape(X_domain)
    Theta = onp.zeros((N_domain,N_domain))
    
    XdXd0 = onp.reshape(onp.tile(X_domain,(1,N_domain)),(-1,d))
    XdXd1 = onp.tile(X_domain,(N_domain,1))
    
    val = vmap(lambda x,y: kappa(x,y,d,sigma))(XdXd0,XdXd1)
    Theta[:N_domain,:N_domain] = onp.reshape(val, (N_domain,N_domain))
    return Theta
    
def assembly_Theta_predict_value_and_grad_stanGP(X_infer, X_domain,sigma):
    N_infer,d = onp.shape(X_infer)
    N_domain = onp.shape(X_domain)[0]
    Theta = onp.zeros((N_infer*(d+1),N_domain))
    
    XdXd0 = onp.reshape(onp.tile(X_infer,(1,N_domain)),(-1,d))
    XdXd1 = onp.tile(X_domain,(N_infer,1))
    
    val = vmap(lambda x,y: kappa(x,y,d,sigma))(XdXd0,XdXd1)
    Theta[:N_infer,:N_domain] = onp.reshape(val, (N_infer,N_domain))
    val = vmap(lambda x,y: D_x_kappa(x,y,d,sigma))(XdXd0,XdXd1)
    Theta[N_infer:,:N_domain] = onp.reshape(val,(N_infer*d,N_domain))
    return Theta
    

In [5]:
def generate_path(X_init, N_domain, dt, T):
    if onp.ndim(X_init)==1: X_init = X_init[onp.newaxis,:]
    _,d = onp.shape(X_init)
    Nt = int(T/dt)+1
    arr_X = onp.zeros((Nt,N_domain,d))
    arr_X[0,:,:] = X_init
    rdn = onp.random.normal(0, 1, (Nt-1, N_domain,d))
    for i in range(Nt-1):
        arr_X[i+1,:,:] = arr_X[i,:,:] + onp.sqrt(2*dt)*rdn[i,:,:]
    return arr_X

def one_step_iteration(V_future, X_future, X_now, dt, sigma, nugget, GN_step):
    N_domain = onp.shape(X_now)[0]
    Theta_train = assembly_Theta_stanGP(X_future,sigma)
    Theta_infer = assembly_Theta_predict_value_and_grad_stanGP(X_now, X_future,sigma)
    
    V_val_n_grad = Theta_infer @ (onp.linalg.solve(Theta_train + nugget*onp.eye(onp.shape(Theta_train)[0]),V_future))
    w0 = onp.ones((N_domain,1))
    for i in range(GN_step):
        # get grad V_{old}
        V_old = V_val_n_grad[:N_domain]
        print('GN step', i, ' and sol val at 1st point', V_old[0])
        V_old_grad = onp.reshape(V_val_n_grad[N_domain:],(N_domain,d))
        
        w1 = 2*V_old_grad+(X_future-X_now)
        Theta_train = assembly_Theta(X_now, w0, w1, sigma)
        Theta_infer = assembly_Theta_value_and_grad_predict(X_now, X_now, w0, w1, sigma)
        rhs = V_future + onp.sum(V_old_grad**2,axis=1)*dt
        V_val_n_grad = Theta_infer @ (onp.linalg.solve(Theta_train + nugget*onp.eye(onp.shape(Theta_train)[0]),rhs))
    
    return V_val_n_grad[:N_domain]


def GPsolver(X_init, N_domain, dt, T, sigma, nugget, GN_step = 4):
    if onp.ndim(X_init)==1: X_init = X_init[onp.newaxis,:]
    _,d = onp.shape(X_init)
    Nt = int(T/dt)+1
    arr_X = generate_path(X_init, N_domain, dt, T)
    V = onp.zeros((Nt,N_domain))
    V[-1,:] = vmap(g)(arr_X[-1,:,:])
    
    # solve V[-i-1,:] from V[-i,:]
    for i in onp.arange(1,Nt):
        V[-i-1,:] = one_step_iteration(V[-i,:], arr_X[-i,:,:], arr_X[-i-1,:,:], dt, sigma, nugget, GN_step)
        if i % 10 ==0:
            print(f'time t = {(Nt-i)*dt} solved')
    return V

In [8]:
d = 10
X_init = onp.zeros((1,d))
N_domain = 500
dt = 1e-2
T = 1
sigma = onp.sqrt(d)*50
nugget = 1e-2
GN_step = 10
V = GPsolver(X_init, N_domain, dt, T, sigma, nugget, GN_step)

GN step 0  and sol val at 1st point 2.2306380987141097
GN step 1  and sol val at 1st point 2.230264104587036
GN step 2  and sol val at 1st point 2.2302369972009046
GN step 3  and sol val at 1st point 2.230243722916981
GN step 4  and sol val at 1st point 2.2302437359298892
GN step 5  and sol val at 1st point 2.2302437359001317
GN step 6  and sol val at 1st point 2.2302437359141436
GN step 7  and sol val at 1st point 2.2302437359138594
GN step 8  and sol val at 1st point 2.2302437359136604
GN step 9  and sol val at 1st point 2.23024373591511
GN step 0  and sol val at 1st point 2.2396135771697097
GN step 1  and sol val at 1st point 2.2396497870570613
GN step 2  and sol val at 1st point 2.2396497163850304
GN step 3  and sol val at 1st point 2.2396497162713267
GN step 4  and sol val at 1st point 2.2396497162714657
GN step 5  and sol val at 1st point 2.2396497162714617
GN step 6  and sol val at 1st point 2.2396497162714795
GN step 7  and sol val at 1st point 2.2396497162714724
GN step 8  and

GN step 0  and sol val at 1st point 2.2733951381093163
GN step 1  and sol val at 1st point 2.2733863593963997
GN step 2  and sol val at 1st point 2.273386362650854
GN step 3  and sol val at 1st point 2.2733863626506157
GN step 4  and sol val at 1st point 2.2733863626506166
GN step 5  and sol val at 1st point 2.2733863626506152
GN step 6  and sol val at 1st point 2.2733863626506152
GN step 7  and sol val at 1st point 2.2733863626506152
GN step 8  and sol val at 1st point 2.2733863626506152
GN step 9  and sol val at 1st point 2.2733863626506152
GN step 0  and sol val at 1st point 2.273716474662583
GN step 1  and sol val at 1st point 2.2737072432820127
GN step 2  and sol val at 1st point 2.2737072449533198
GN step 3  and sol val at 1st point 2.2737072449533926
GN step 4  and sol val at 1st point 2.27370724495339
GN step 5  and sol val at 1st point 2.2737072449533895
GN step 6  and sol val at 1st point 2.27370724495339
GN step 7  and sol val at 1st point 2.27370724495339
GN step 8  and sol

time t = 0.71 solved
GN step 0  and sol val at 1st point 2.273863469519427
GN step 1  and sol val at 1st point 2.2738540932285956
GN step 2  and sol val at 1st point 2.2738540935045064
GN step 3  and sol val at 1st point 2.273854093504493
GN step 4  and sol val at 1st point 2.2738540935044957
GN step 5  and sol val at 1st point 2.273854093504495
GN step 6  and sol val at 1st point 2.273854093504495
GN step 7  and sol val at 1st point 2.273854093504495
GN step 8  and sol val at 1st point 2.273854093504495
GN step 9  and sol val at 1st point 2.273854093504495
GN step 0  and sol val at 1st point 2.2739386229141276
GN step 1  and sol val at 1st point 2.2739290437910755
GN step 2  and sol val at 1st point 2.273929044493436
GN step 3  and sol val at 1st point 2.2739290444934723
GN step 4  and sol val at 1st point 2.2739290444934683
GN step 5  and sol val at 1st point 2.2739290444934683
GN step 6  and sol val at 1st point 2.2739290444934683
GN step 7  and sol val at 1st point 2.27392904449346

GN step 0  and sol val at 1st point 2.27320209765025
GN step 1  and sol val at 1st point 2.2731936260705528
GN step 2  and sol val at 1st point 2.273193626314333
GN step 3  and sol val at 1st point 2.2731936263143595
GN step 4  and sol val at 1st point 2.27319362631436
GN step 5  and sol val at 1st point 2.27319362631436
GN step 6  and sol val at 1st point 2.27319362631436
GN step 7  and sol val at 1st point 2.27319362631436
GN step 8  and sol val at 1st point 2.27319362631436
GN step 9  and sol val at 1st point 2.27319362631436
GN step 0  and sol val at 1st point 2.273202452844444
GN step 1  and sol val at 1st point 2.2731933952538372
GN step 2  and sol val at 1st point 2.2731933957417936
GN step 3  and sol val at 1st point 2.27319339574178
GN step 4  and sol val at 1st point 2.27319339574178
GN step 5  and sol val at 1st point 2.27319339574178
GN step 6  and sol val at 1st point 2.27319339574178
GN step 7  and sol val at 1st point 2.27319339574178
GN step 8  and sol val at 1st point 

GN step 1  and sol val at 1st point 2.272762265109355
GN step 2  and sol val at 1st point 2.272762264735303
GN step 3  and sol val at 1st point 2.2727622647352854
GN step 4  and sol val at 1st point 2.272762264735284
GN step 5  and sol val at 1st point 2.272762264735285
GN step 6  and sol val at 1st point 2.272762264735285
GN step 7  and sol val at 1st point 2.272762264735285
GN step 8  and sol val at 1st point 2.272762264735285
GN step 9  and sol val at 1st point 2.272762264735285
GN step 0  and sol val at 1st point 2.272746752069847
GN step 1  and sol val at 1st point 2.272737952716346
GN step 2  and sol val at 1st point 2.272737952343892
GN step 3  and sol val at 1st point 2.27273795234383
GN step 4  and sol val at 1st point 2.2727379523438316
GN step 5  and sol val at 1st point 2.2727379523438316
GN step 6  and sol val at 1st point 2.2727379523438316
GN step 7  and sol val at 1st point 2.2727379523438316
GN step 8  and sol val at 1st point 2.2727379523438316
GN step 9  and sol val 

GN step 2  and sol val at 1st point 2.2722643618550156
GN step 3  and sol val at 1st point 2.2722643618550107
GN step 4  and sol val at 1st point 2.2722643618550116
GN step 5  and sol val at 1st point 2.2722643618550116
GN step 6  and sol val at 1st point 2.2722643618550116
GN step 7  and sol val at 1st point 2.2722643618550116
GN step 8  and sol val at 1st point 2.2722643618550116
GN step 9  and sol val at 1st point 2.2722643618550116
GN step 0  and sol val at 1st point 2.2722500341503906
GN step 1  and sol val at 1st point 2.2722415400590954
GN step 2  and sol val at 1st point 2.272241540004097
GN step 3  and sol val at 1st point 2.2722415400041154
GN step 4  and sol val at 1st point 2.272241540004118
GN step 5  and sol val at 1st point 2.2722415400041185
GN step 6  and sol val at 1st point 2.2722415400041185
GN step 7  and sol val at 1st point 2.2722415400041185
GN step 8  and sol val at 1st point 2.2722415400041185
GN step 9  and sol val at 1st point 2.2722415400041185
GN step 0  a

GN step 3  and sol val at 1st point 2.2716276612101454
GN step 4  and sol val at 1st point 2.271627661210146
GN step 5  and sol val at 1st point 2.271627661210146
GN step 6  and sol val at 1st point 2.271627661210146
GN step 7  and sol val at 1st point 2.271627661210146
GN step 8  and sol val at 1st point 2.271627661210146
GN step 9  and sol val at 1st point 2.271627661210146
GN step 0  and sol val at 1st point 2.2715797719110764
GN step 1  and sol val at 1st point 2.271570295357461
GN step 2  and sol val at 1st point 2.271570295377986
GN step 3  and sol val at 1st point 2.2715702953779906
GN step 4  and sol val at 1st point 2.2715702953779906
GN step 5  and sol val at 1st point 2.2715702953779906
GN step 6  and sol val at 1st point 2.2715702953779906
GN step 7  and sol val at 1st point 2.2715702953779906
GN step 8  and sol val at 1st point 2.2715702953779906
GN step 9  and sol val at 1st point 2.2715702953779906
GN step 0  and sol val at 1st point 2.271534248931692
GN step 1  and sol 

In [None]:
print(V[-1,:])

In [None]:
print()