In [None]:
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import numpy as np
import haiku as hk
import optax
import matplotlib.pyplot as plt 
from IPython import display

In [None]:
# Report the devices JAX will run on, plus NVIDIA GPU/driver status.
# NOTE(review): `!nvidia-smi` is an IPython shell escape; presumably only informative
# on machines with an NVIDIA GPU and driver installed.
print(jax.devices())
!nvidia-smi

In [None]:
# ---- Run configuration ----
# Root PRNG key used below for parameter initialization and for sampling z.
key = jax.random.PRNGKey(42)

n = 20          # number of points per configuration; energy_fn reshapes samples to (n, dim)
dim = 2         # spatial dimension of each point
batch = 8192    # number of base samples z ~ N(0, I) drawn per optimization step
beta = 10.0     # enters the free-energy estimate as logp(x) / beta + E(x) (an inverse temperature)
lr = 0.001      # Adam learning rate

nvp_depth = 8   # number of RealNVP coupling layers (masks alternate between layers)
mlp_width = 32  # hidden width of each coupling layer's MLP
mlp_depth = 2   # number of hidden layers in each coupling MLP

epochs = 10000  # number of optimization steps, one fresh batch per step

In [None]:
#################################################################################### 
def energy_fn(x, n, dim):
    """Energy of one configuration: pairwise 1/r repulsion plus quadratic confinement.

    Args:
        x: coordinates of n points in dim dimensions; anything reshapeable to (n, dim).
        n: number of points.
        dim: spatial dimension of each point.

    Returns:
        Scalar  sum_{i<j} 1 / |x_i - x_j|  +  sum of x**2 over all coordinates.
    """
    pos = jnp.reshape(x, (n, dim))
    # Unique unordered pairs i < j; NumPy builds these static index arrays at trace time.
    rows, cols = np.triu_indices(n, k=1)
    pair_dist = jnp.linalg.norm(pos[rows] - pos[cols], axis=-1)
    repulsion = jnp.sum(1 / pair_dist)
    confinement = jnp.sum(x ** 2)
    return repulsion + confinement


# Map energy_fn over a leading batch axis of x; n and dim are shared (not batched).
batch_energy = jax.vmap(energy_fn, (0, None, None), 0)

In [None]:
class RealNVP(hk.Module):
    """
        Real-valued non-volume preserving (real NVP) transform.
        The implementation follows the paper "arXiv:1605.08803."

        Maps one configuration x of shape (n, dim) to y of the same shape and returns
        log|det dy/dx|. Internally x is flattened to length n*dim (which must equal
        event_size) and passed through `nvp_depth` affine coupling layers.

        In layer l:
          - Entries where maskflow[l] is True are copied through unchanged. They condition
            an MLP + Linear head that produces a shift t and a raw log-scale.
          - The remaining entries are updated as y = x * exp(s) + t, with
            s = tanh(raw) * zoom.
        Because |tanh| < 1, each |s| is bounded by |zoom| for that coordinate.

        Args:
            maskflow: sequence of `nvp_depth` boolean masks, broadcastable against the
                flattened input (length event_size). True marks entries that are held
                fixed (and conditioned on) in that layer.
            nvp_depth: number of coupling layers.
            mlp_width: hidden width of each coupling MLP.
            mlp_depth: number of hidden layers in each coupling MLP.
            event_size: flattened size n*dim of one configuration.
    """
    def __init__(self, maskflow, nvp_depth, mlp_width, mlp_depth, event_size):
        super().__init__()
        self.maskflow = maskflow
        self.nvp_depth = nvp_depth
        self.event_size = event_size
        
        # One tanh MLP per coupling layer; activate_final=True so the Linear head below
        # receives activated features.
        self.fc_mlp = [hk.nets.MLP([mlp_width]*mlp_depth, 
                        activation=jax.nn.tanh,
                        activate_final=True)
                        for _ in range(nvp_depth)]
        
        # Per-layer head with 2*event_size outputs, split into (shift, raw log-scale).
        # Small weight init; the truncated-normal bias init (centered at 1.0) applies to
        # both halves, i.e. to the shift as well as the raw log-scale.
        self.fc_lin = [hk.Linear(event_size * 2,
                        w_init=hk.initializers.TruncatedNormal(stddev=0.1), 
                        b_init=hk.initializers.TruncatedNormal(stddev=1.0, mean=1.0, lower=0.1, upper=4.0))
                        for _ in range(nvp_depth)]
        
        # Learnable per-coordinate bound on the log-scale (s = tanh(raw) * zoom), shared by
        # all layers. NOTE(review): requested as float64, which presumably relies on
        # jax_enable_x64 being enabled.
        self.zoom = hk.get_parameter("zoom", [event_size, ], 
                        init=hk.initializers.Constant(-1.5), dtype=jnp.float64)

    ####################################################################################
    def coupling_forward(self, x1, x2, l):
        """Affine coupling for layer l.

        Args:
            x1: flattened input with the non-conditioning entries zeroed (fed to the MLP).
            x2: flattened input with the conditioning entries zeroed (to be transformed).
            l: layer index into fc_mlp / fc_lin / maskflow.

        Returns:
            (y2, sum_logscale):
              - y2 = x2 * exp(s) + shift, where s is forced to 0 on the conditioning
                entries. On those entries y2 equals the shift, and __call__ discards them.
              - sum_logscale = sum(s), this layer's contribution to log|det Jacobian|.
        """
        ## get shift and log(scale) from x1
        shift_and_logscale = self.fc_lin[l](self.fc_mlp[l](x1))
        shift, logscale = jnp.split(shift_and_logscale, 2, axis=-1)

        # Bounded log-scale; zero on conditioning entries so they don't enter the logdet.
        logscale = jnp.where(self.maskflow[l], 0, jax.nn.tanh(logscale)*self.zoom)
        
        ## transform: y2 = x2 * scale + shift
        y2 = x2 * jnp.exp(logscale) + shift
        ## calculate: logjacdet for each layer
        sum_logscale = jnp.sum(logscale)
        
        return y2, sum_logscale
    
    ####################################################################################   
    def __call__(self, x):
        """Apply all coupling layers to one configuration x of shape (n, dim).

        Returns:
            (y, logjacdet):
              - y has the same (n, dim) shape as x.
              - logjacdet is the sum over layers of sum(s), i.e. log|det dy/dx| of the
                full transform.
        """
        #========== Real NVP (forward) ==========
        n, dim = x.shape  
        
        ## initial x and logjacdet
        x_flatten = jnp.reshape(x, (n*dim, ))
        logjacdet = 0
        
        for l in range(self.nvp_depth):
            ## split x into two parts: x1, x2
            x1 = jnp.where(self.maskflow[l], x_flatten, 0)
            x2 = jnp.where(self.maskflow[l], 0, x_flatten)
            
            ## get y2 from fc(x1), and calculate logjacdet = sum_l log(scale_l)
            y2, sum_logscale = self.coupling_forward(x1, x2, l)
            logjacdet += sum_logscale

            ## update: [x1, x2] -> [x1, y2]  (conditioning entries keep their old values)
            x_flatten = jnp.where(self.maskflow[l], x_flatten, y2)
            
        x = jnp.reshape(x_flatten, (n, dim))
        return x , logjacdet

####################################################################################
def get_maskflow(key, nvp_depth, event_size):
    
    mask1 = jnp.arange(0, jnp.prod(event_size)) % 2 == 0
    mask1 = (jnp.reshape(mask1, event_size)).astype(bool)
    mask2 = jnp.arange(0, jnp.prod(event_size)) % 2 == 1
    mask2 = (jnp.reshape(mask2, event_size)).astype(bool)
    
    maskflow = []
    for ii in range(nvp_depth):
        if   ii % 2 == 1: mask = mask1
        elif ii % 2 == 0: mask = mask2
        maskflow += [mask]
    return maskflow

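# Alternative masking (unused): condition on one half of the flattened coordinates,
# alternating between the second half (even layers) and the first half (odd layers).
# def get_maskflow(key, nvp_depth, event_size):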
    
#     mask1 = jnp.concatenate([jnp.ones(event_size//2), jnp.zeros(event_size//2)])
#     mask1 = (jnp.reshape(mask1, event_size)).astype(bool)
#     mask2 = jnp.concatenate([jnp.zeros(event_size//2), jnp.ones(event_size//2)])
#     mask2 = (jnp.reshape(mask2, event_size)).astype(bool)
    
#     maskflow = []
#     for ii in range(nvp_depth):
#         if   ii % 2 == 1: mask = mask1
#         elif ii % 2 == 0: mask = mask2
#         maskflow += [mask]
#     return maskflow

#################################################################################### 
def make_flow(nvp_depth, mlp_width, mlp_depth, n, dim):
    """Build a Haiku-transformed RealNVP flow acting on (n, dim) configurations.

    Prints the coupling-mask pattern (one row per layer, 1 = entry held fixed) and
    returns the `hk.transform` pair whose `apply(params, rng, z)` gives
    `(x, logjacdet)` for a single unbatched z.
    """
    event_size = n * dim
    masks = get_maskflow(jax.random.PRNGKey(42), nvp_depth, event_size)
    print(jnp.array(masks).astype(jnp.int32))

    def forward_fn(z):
        return RealNVP(masks, nvp_depth, mlp_width, mlp_depth, event_size)(z)

    return hk.transform(forward_fn)

In [None]:
flow = make_flow(nvp_depth, mlp_width, mlp_depth, n, dim)

# Split off a dedicated key for initialization. Haiku's init consumes (and internally
# splits) the key it receives, so passing `key` itself and later reusing `key` for
# sampling/training would reuse randomness.
key, init_key = jax.random.split(key)
# Haiku traces forward_fn once on a dummy (n, dim) configuration to create all parameters.
params = flow.init(init_key, jnp.zeros((n, dim)))

from jax.flatten_util import ravel_pytree
raveled_params, _ = ravel_pytree(params)

print("parameters in the flow model: %d" % raveled_params.size, flush=True)

In [None]:
def make_logp(flow):
    """Build a per-sample log-density function for the flow x = f(z), z ~ N(0, I).

    Args:
        flow: transformed model whose ``apply(params, rng, z)`` returns
            ``(x, logjacdet)`` with ``logjacdet = log|det dx/dz|``.

    Returns:
        ``logp(z, params) -> (logpx, logqz, logjacdet, x)`` for a single unbatched ``z``:
          - ``logqz`` is the standard-normal log-density of ``z``.
          - ``logpx = logqz - logjacdet`` is the model log-density of ``x``
            (change of variables).
    """

    def logp(z, params):
        # rng=None: the flow's forward pass draws no random numbers.
        x, logjacdet = flow.apply(params, None, z)
        logqz = jnp.sum(jax.scipy.stats.norm.logpdf(z))
        # Debug cross-check (slow): logjacdet should match
        #   jnp.linalg.slogdet(jax.jacfwd(
        #       lambda v: flow.apply(params, None, v.reshape(z.shape))[0].reshape(-1)
        #   )(z.reshape(-1)))[1]
        logpx = logqz - logjacdet
        return logpx, logqz, logjacdet, x

    return logp

# Per-sample log-density of the flow; vmap maps over the leading batch axis of z while
# sharing the same params across every sample, and batches all four outputs on axis 0.
logp_novmap = make_logp(flow)
logp = jax.vmap(logp_novmap, in_axes=(0, None), out_axes=0)

In [None]:
# Sanity check of the freshly initialized flow on one batch. Split first so the sampling
# key has never been consumed before (and `key` itself stays reserved for later splits).
key, subkey = jax.random.split(key)
z = jax.random.normal(subkey, (batch, n, dim))
print("z.shape:", z.shape)

logpx, logqz, logjacdet, x = logp(z, params)
print("batch logp(x):", logpx)
print("batch logq(z):", logqz)
# The flow returns log|det dx/dz| of the forward map z -> x, so logpx = logqz - logjacdet.
print("batch logjacdet(dx/dz):", logjacdet)
print("mean -logp(x):", jnp.mean(-logpx))
print("mean -logq(z):", jnp.mean(-logqz))

In [None]:

# Scatter every point of every sample from the initialized flow (first two coordinates).
# Reshape by `dim`, not a hard-coded 2, so coordinates of different points are never mixed.
plot_x = x.reshape(-1, dim)
fig, ax = plt.subplots(figsize=(3, 3), dpi=300)
ax.scatter(plot_x[:, 0], plot_x[:, 1], alpha=0.1, s=1.0)
ax.set(xlabel="$x_1$", ylabel="$x_2$", title="initial flow samples")
ax.set_aspect("equal")
plt.show()


In [None]:
# Adam on the flow parameters; its state is initialized from the current params.
optimizer = optax.adam(learning_rate=lr)
opt_state = optimizer.init(params)

In [None]:
def make_loss(logp, batch, n, dim, beta):
    """Build the variational free-energy objective for the flow.

    Args:
        logp: batched log-density ``logp(z, params) -> (logpx, logqz, logjacdet, x)``,
            every output carrying a leading sample axis.
        batch: nominal batch size. Kept for signature compatibility; the standard errors
            are computed from the number of samples actually present in ``z``.
        n, dim: configuration shape, forwarded to ``batch_energy``.
        beta: inverse temperature; per-sample free energy is ``logpx / beta + E(x)``.

    Returns:
        ``observable_and_lossfn(params, z) -> (loss, observable)``:
          - ``observable = (F_mean, F_std, E_mean, E_std, S_mean, S_std, x)``.
          - ``*_std`` are standard errors of the mean.
          - ``S`` is the entropy estimate ``-mean(logpx)``.
    """

    def observable_and_lossfn(params, z):
        logpx, logqz, logjacdet, x = logp(z, params)

        Eloc = batch_energy(x, n, dim)
        Floc = logpx / beta + Eloc

        # Standard error of the mean uses the number of samples actually drawn.
        sqrt_num = jnp.sqrt(Floc.shape[0])
        F_mean, F_std = Floc.mean(), jnp.std(Floc) / sqrt_num
        E_mean, E_std = Eloc.mean(), jnp.std(Eloc) / sqrt_num
        S_mean, S_std = -logpx.mean(), jnp.std(-logpx) / sqrt_num
        observable = (F_mean, F_std, E_mean, E_std, S_mean, S_std, x)

        # Surrogate loss. logpx = logqz - logjacdet, and logqz does not depend on params
        # (z is a fixed input). So this has the same params-gradient as mean(Floc), while
        # omitting the constant mean(logqz) / beta term.
        loss = jnp.mean(-logjacdet / beta + Eloc)
        return loss, observable

    return observable_and_lossfn

# Loss closure for the current flow/config, and its value-and-gradient w.r.t. params
# (the default argument 0). has_aux=True passes the observables through unchanged.
observable_and_lossfn = make_loss(logp, batch, n, dim, beta)
value_and_gradfn = jax.value_and_grad(observable_and_lossfn, has_aux=True)

In [None]:
@jax.jit
def update(params, opt_state, z):
    """Take one Adam step on the batch z.

    Returns the new params, the new optimizer state, and ``(loss, observable)``
    evaluated at the incoming params (i.e. before the step).
    """
    (loss, observable), grads = value_and_gradfn(params, z)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, (loss, observable)

In [None]:
def _plot_density(ax, points, title, lim=4.0, bins=100):
    """Draw a 2D density histogram of `points` (shape (m, 2)) on `ax`, x horizontal, y up."""
    H, xedges, yedges = np.histogram2d(points[:, 0], points[:, 1], bins=bins,
                                       range=((-lim, lim), (-lim, lim)), density=True)
    # np.histogram2d indexes H[x_bin, y_bin], while imshow draws rows vertically and,
    # by default, from the top. Transpose and use origin="lower" for a Cartesian picture.
    ax.imshow(H.T, interpolation="nearest", origin="lower",
              extent=(xedges[0], xedges[-1], yedges[0], yedges[-1]), cmap="inferno")
    ax.set_xlim([-lim, lim])
    ax.set_ylim([-lim, lim])
    ax.set_title(title)


loss_history = []
for i in range(epochs):
    # Sample with the fresh subkey; `key` is only ever used for splitting.
    key, subkey = jax.random.split(key)
    z = jax.random.normal(subkey, (batch, n, dim))
    params, opt_state, datas = update(params, opt_state, z)

    F_mean, F_std, E_mean, E_std, S_mean, S_std, x = datas[1]
    stats = np.asarray([F_mean, F_std, E_mean, E_std, S_mean, S_std], dtype=np.float64)
    loss_history.append(stats)
    status = ("epoch: %04d    F: %.6f (%.6f)    E: %.6f (%.6f)    S: %.6f (%.6f)"
              % (i, *stats))

    # Clear first so this epoch's status line stays visible next to its figure.
    display.clear_output(wait=True)
    print(status)

    plot_x = np.asarray(x).reshape(batch * n, dim)
    plot_z = np.asarray(z).reshape(batch * n, dim)

    fig, axes = plt.subplots(1, 3, figsize=(18, 6), dpi=300)
    fig.suptitle(status, fontsize=16)
    _plot_density(axes[0], plot_x, "x (flow samples)")
    _plot_density(axes[1], plot_z, "z (base samples)")

    y = np.asarray(loss_history)
    axes[2].errorbar(np.arange(i + 1), y[:, 0], yerr=y[:, 1], marker='o', capsize=8)
    axes[2].set_xlabel('epochs')
    axes[2].set_ylabel('variational free energy')

    display.display(fig)
    # One new figure per epoch: close it, or pyplot keeps every figure alive.
    plt.close(fig)

if loss_history:
    print(loss_history[-1])