# Notebook used to generate Figure 3.3

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colormaps
import jax
import jax.numpy as jnp
import jax.random as random
import scipy.sparse
from jax import grad, hessian, jit, vmap
from GEK import GEKRunner

# Run JAX in double precision.
jax.config.update("jax_enable_x64", True)

# Plot styling comes from a matplotlib style sheet named 'thesis.mplstyle'.
plt.style.use('thesis.mplstyle')

# --- Experiment settings (read as globals by the cells below) ---
seed=43
np.random.seed(seed)
d = 12  # number of sites, i.e. bits per configuration
h = 1.5  # every single-bit-flip matrix entry is -h
alpha = 5  # number of feature rows (and biases) in the ansatz
iterations = 50  # iterations per optimisation run
sigma_f = 1  # std-dev of the noise added to energy evaluations
sigma_g = 1  # std-dev of the noise added to gradient evaluations
length_scale = 1.2  # forwarded to GEKRunner
sigma = 1  # forwarded to GEKRunner
var_threshold = 1  # forwarded to GEKRunner.GEK_optimize

# Small random complex initial weights: alpha * d feature entries followed by
# alpha biases (that is how the weight vector is sliced further down).
key = random.PRNGKey(seed)
key, key1, key2 = random.split(key, num=3)
weights_save = 0.001 * random.normal(key1, shape=(alpha * (d + 1),)) + \
               0.001j * random.normal(key2, shape=(alpha * (d + 1),))

# All 2**d bit strings as boolean rows, most significant bit first.
configs = jnp.arange(2 ** d)[:, None] >> jnp.arange(d)[::-1] & 1
configs = configs.astype(jnp.bool_)
# True wherever a site differs from its cyclic right neighbour.
diffs = configs ^ configs[:, (jnp.arange(d) + 1) % d]
# Diagonal entries: (#differing neighbour pairs) - (#equal neighbour pairs).
spin_z = 2 * jnp.sum(diffs, axis=1) - d

# spin_x[k, i] is the integer index of configuration k with bit i flipped.
spin_x = jnp.zeros(configs.shape, dtype=jnp.int_)
for i in range(d):
    new = configs.at[:, i].set(~configs[:, i])
    new = jnp.dot(new, 2 ** jnp.arange(d)[::-1])
    spin_x = spin_x.at[:, i].set(new)
# Assemble the dense matrix from COO triplets: spin_z on the diagonal and -h
# between every pair of configurations that differ by one bit.
# NOTE(review): this looks like a periodic transverse-field Ising
# Hamiltonian -- confirm against the thesis text.
i_vals = jnp.repeat(jnp.arange(2 ** d), d)
j_vals = jnp.ravel(spin_x)
data = jnp.append(spin_z, jnp.repeat(-h, i_vals.size))
i_vals = jnp.append(jnp.arange(2 ** d), i_vals)
j_vals = jnp.append(jnp.arange(2 ** d), j_vals)
matrix = scipy.sparse.coo_matrix((data, (i_vals, j_vals))).toarray()

# Keep one representative of every {configuration, bitwise complement} pair:
# fewer than d/2 set bits, or exactly d/2 set bits with the first bit clear.
subset1 = (2 * jnp.sum(configs, axis=1) + configs[:, 0] <= d)
subset1 = jnp.arange(2 ** d)[subset1]
# subset2 holds the indices of the complements of the kept configurations.
subset2 = jnp.dot(configs[subset1, :], 2 ** jnp.arange(d)[::-1])
subset2 = 2 ** d - 1 - subset2
# Reduced matrix: add the columns of each representative and its complement.
matrix = matrix[subset1[:, None], subset1] + matrix[subset1[:, None], subset2]
configs = configs[subset1, :]

# Closed-form reference value, drawn as the 'Exact' line in the final plot.
# NOTE(review): presumably the exact ground-state energy per site -- confirm.
q = np.arange(1 / d, 1, 2 / d)
soln = -np.mean(np.sqrt(1 + h ** 2 + 2 * h * np.cos(np.pi * q)))

# =====================
# Ansatz and gradient
# =====================
@jit
def ansatz(state, features2, bias):
    """Log-amplitude of the ansatz for a single configuration.

    ``features2`` holds the features already transformed with an FFT along
    the last axis. Multiplying by the conjugate spectrum of ``state`` and
    transforming back evaluates the feature/state products at every cyclic
    shift in one pass. ``bias`` is added to those angles (broadcast), and the
    result is the sum of ``log(cosh(angle))`` over all angles.
    """
    state_hat = jnp.fft.fft(state)
    shifted_products = jnp.fft.ifft(features2 * jnp.conj(state_hat))
    log_cosh = jnp.log(jnp.cosh(shifted_products + bias))
    return jnp.sum(log_cosh)

# Batched ansatz: map over the leading axis of the configurations while the
# Fourier-space features and the biases are shared by every configuration.
ansatz1 = vmap(ansatz, in_axes=(0, None, None), out_axes=0)
jansatz = jit(ansatz1)

@jit
def gradient(state, features2, bias):
    """Derivatives of the log-amplitude for a single configuration.

    The angles are built exactly as in ``ansatz``; their ``tanh`` is the
    derivative of ``log(cosh(angle))``.

    Returns a pair ``(grad_features, grad_bias)``:
    ``grad_features`` is the inverse FFT of the tanh spectrum times the state
    spectrum, and ``grad_bias`` is the tanh values summed over the last axis.
    """
    state_hat = jnp.fft.fft(state)
    angles = jnp.fft.ifft(features2 * jnp.conj(state_hat)) + bias
    activation = jnp.tanh(angles)
    activation_hat = jnp.fft.fft(activation)
    wrt_features = jnp.fft.ifft(activation_hat * state_hat)
    wrt_bias = jnp.sum(activation, axis=-1)
    return wrt_features, wrt_bias

# Batched gradient: map over the leading axis of the configurations; both
# returned arrays gain a leading batch axis.
gradient1 = vmap(gradient, in_axes=(0, None, None), out_axes=(0, 0))
jgradient = jit(gradient1)

@jit
def rayleigh(weights):
    """Real part of the Rayleigh quotient of ``matrix`` for the ansatz state.

    ``weights`` is the flat complex parameter vector: ``alpha * d`` feature
    entries followed by ``alpha`` biases. The amplitudes are evaluated on
    every row of the global ``configs``.
    """
    feature_part, bias_part = weights[:-alpha], weights[-alpha:]
    features_hat = jnp.fft.fft(jnp.reshape(feature_part, (alpha, d)))
    bias_column = jnp.reshape(bias_part, (alpha, 1))
    amplitudes = jnp.exp(jansatz(configs, features_hat, bias_column))
    numerator = jnp.vdot(amplitudes, jnp.dot(matrix, amplitudes))
    denominator = jnp.vdot(amplitudes, amplitudes)
    return jnp.real(numerator / denominator)

@jit
def objective(weights):
    """Rayleigh quotient as a function of a real parameter vector.

    ``weights`` stacks the real parts of the ``alpha * (d + 1)`` complex
    parameters followed by their imaginary parts, so the function can be
    differentiated with respect to real inputs. The amplitudes are evaluated
    for all rows of the global ``configs`` by broadcasting rather than via
    the vmapped ansatz.
    """
    n_complex = alpha * (d + 1)
    complex_weights = weights[:n_complex] + 1j * weights[n_complex:]
    bias = jnp.reshape(complex_weights[-alpha:], (alpha, 1))
    features_hat = jnp.fft.fft(jnp.reshape(complex_weights[:-alpha], (alpha, d)))
    configs_hat = jnp.fft.fft(configs, axis=1)
    # Axes: (configuration, feature row, site).
    spectra = features_hat[None, :, :] * jnp.conj(configs_hat)[:, None, :]
    angles = jnp.fft.ifft(spectra, axis=-1) + bias[None, :, :]
    log_amplitudes = jnp.sum(jnp.log(jnp.cosh(angles)), axis=(1, 2))
    amplitudes = jnp.exp(log_amplitudes)
    numerator = jnp.vdot(amplitudes, jnp.dot(matrix, amplitudes))
    return jnp.real(numerator / jnp.vdot(amplitudes, amplitudes))

# Compiled second and first derivatives of the real-parameter objective.
hess = jit(hessian(objective))
grad_energy = jit(grad(objective))

# Real-valued starting point: real parts first, imaginary parts second,
# matching the layout `objective` expects.
w0_real = np.concatenate((weights_save.real, weights_save.imag))

def energy_and_grad(w):
    """Return the noise-free objective value and its gradient at ``w``."""
    energy = objective(w)
    slope = grad_energy(w)
    return energy, slope

def noisy_energy_and_grad(w):
    """Return the objective value and gradient at ``w`` with Gaussian noise.

    The energy gets one scalar draw with standard deviation ``sigma_f``; the
    gradient gets an independent draw per component with standard deviation
    ``sigma_g``. Both draws come from NumPy's global random state, energy
    noise first.
    """
    energy_noise = np.random.normal(scale=sigma_f)
    noisy_energy = objective(w) + energy_noise
    gradient_noise = np.random.normal(scale=sigma_g, size=w.shape)
    noisy_slope = grad_energy(w) + gradient_noise
    return noisy_energy, noisy_slope

In [None]:
# ==============
# Run optimizer
# ==============
# Ten independent noisy runs of the GEK optimizer (method "NLC") and of plain
# gradient descent, all started from w0_real. Each entry of NLC_paths /
# GD_paths is the sequence of parameter vectors visited in one run.
NLC_paths = []
GD_paths = []

np.random.seed(seed)
for _ in range(10):
    runner = GEKRunner(length_scale=length_scale, sigma=sigma, sigma_f=sigma_f, sigma_g=sigma_g)
    x_opt, path, surrogate, *rest = runner.GEK_optimize(
        noisy_energy_and_grad, w0_real,
        var_threshold=var_threshold,
        outer_tol=1e-7,
        inner_tol=1e-2,
        alpha=0.1,
        max_iter=50,
        internal_max_iter=100,
        method="NLC",
        return_path=True,
        return_surrogate=True
    )
    NLC_paths.append(path)

# Re-seed so the gradient-descent baseline gets a reproducible noise stream.
np.random.seed(seed)
gd_step_size = 0.03
for _ in range(10):
    gd_path = [w0_real]
    # Run length is the global `iterations` from the setup cell.
    for _step in range(iterations):
        # Bound to `noisy_grad`, not `grad`: the latter would overwrite the
        # `jax.grad` imported at the top of the notebook.
        _, noisy_grad = noisy_energy_and_grad(gd_path[-1])
        gd_path.append(gd_path[-1] - noisy_grad * gd_step_size)
    GD_paths.append(gd_path)

# Score every visited point with the noise-free objective.
energy_paths_NLC = [[objective(point) for point in path] for path in NLC_paths]
energy_paths_GD = [[objective(point) for point in path] for path in GD_paths]

In [None]:
# RGN baseline (presumably Rayleigh-Gauss-Newton -- confirm against the thesis
# text): ten independent noisy runs from the same initial weights.

# Change of basis taking stacked (Re w, Im w) coordinates to
# (w, conj(w)) / 2; used below to split the real Hessian of `objective`
# into complex blocks.
rotate = jnp.eye(alpha * (d + 1))
rotate1 = jnp.column_stack((rotate, 1j * rotate))
rotate2 = jnp.column_stack((rotate, -1j * rotate))
rotate = jnp.vstack((rotate1, rotate2)) / 2

# starting condition
fixed_epsilon = 0.0055
fixed_reg     = 0.25
rgn_paths = []
np.random.seed(seed)
for run in range(10):
    epsilons = np.full(iterations, fixed_epsilon)
    regs     = np.full(iterations, fixed_reg)
    # Pristine copies of both schedules, used to restore them after a
    # step-size backoff. `epsilons_reset` was previously never defined in
    # this cell, so the first backoff raised NameError.
    epsilons_reset = np.copy(epsilons)
    regs_reset = np.copy(regs)
    weights = jnp.array(weights_save)

    weight_log = np.zeros((iterations, weights.size)) + 0j
    rgn_exact = np.zeros(iterations)
    hessian_error = np.zeros(iterations)

    energy_log = np.zeros(iterations)

    # Norm of the previous accepted step; infinite until one exists, so no
    # backoff can trigger during the first two iterations.
    guidance = np.inf
    for iteration in range(iterations):
        print(iteration)

        # process weights
        bias = jnp.reshape(weights[-alpha:], (alpha, 1))
        features = jnp.reshape(weights[:-alpha], (alpha, d))
        features2 = jnp.fft.fft(features)

        # normalized wavefunction and derivatives
        vals = jnp.exp(jansatz(configs, features2, bias))
        vals = vals / jnp.sum(jnp.abs(vals) ** 2) ** .5
        g1, g2 = jgradient(configs, features2, bias)
        grads = jnp.column_stack((jnp.reshape(g1, (-1, alpha * d)), g2))
        # Subtract the |vals|^2-weighted mean of the log-derivatives, then
        # weight each row by the amplitude.
        grads = grads - jnp.dot(jnp.abs(vals) ** 2, grads)[jnp.newaxis, :]
        grads = grads * vals[:, jnp.newaxis]
        forces = jnp.dot(grads.conj().T, jnp.dot(matrix, vals))

        # Unit-variance Gaussian noise on the real and imaginary parts.
        noise_real = np.random.normal(0, 1, size=forces.shape)
        noise_imag = np.random.normal(0, 1, size=forces.shape)
        forces = forces + noise_real + 1j * noise_imag

        cov = jnp.matmul(grads.conj().T, grads)
        linear = jnp.matmul(grads.conj().T, jnp.matmul(matrix, grads))
        linear = linear - jnp.vdot(vals, jnp.dot(matrix, vals)) * cov
        regular = linear + (cov + regs[iteration] * jnp.eye(weights.size)) \
                  / epsilons[iteration]

        reset = False
        move = -jnp.linalg.solve(regular, forces)
        # Halve the step size until the move is at most twice the previous one.
        while np.sum(np.abs(move) ** 2) ** .5 > 2 * guidance:
            reset = True
            epsilons[iteration] = epsilons[iteration] / 2
            regular = linear \
                      + (cov + regs[iteration] * jnp.eye(weights.size)) \
                      / epsilons[iteration]
            move = -jnp.linalg.solve(regular, forces)
        weights = weights + move
        if reset:
            # Restart both schedules from their initial values. The explicit
            # length equals the old `[:-iteration]` slice for iteration > 0
            # and stays shape-correct at iteration == 0.
            remaining = iterations - iteration
            epsilons[iteration:] = epsilons_reset[:remaining]
            regs[iteration:] = regs_reset[:remaining]
        if iteration > 0:
            guidance = np.sum(np.abs((weights
                                      - weight_log[iteration - 1])) ** 2) ** .5
        weight_log[iteration, :] = weights

        # error
        rgn_exact[iteration] = rayleigh(weights)
        energy_log[iteration] = rgn_exact[iteration]
        # Hessian of the real-parameter objective, rotated into complex
        # blocks; the diagnostic is twice the squared norm of one
        # off-diagonal block relative to the squared norm of the whole.
        weights2 = jnp.concatenate((weights.real, weights.imag))
        newton = hess(weights2)
        newton = jnp.matmul(jnp.matmul(rotate, newton), rotate.conj().T)
        error = jnp.sum(jnp.abs(newton[:weights.size, weights.size:]) ** 2)
        hessian_error[iteration] = 2 * error / jnp.sum(jnp.abs(newton) ** 2)
    rgn_paths.append(rgn_exact)

In [None]:
# LM baseline (presumably the "linear method" -- confirm against the thesis
# text): ten independent noisy runs from the same initial weights.
# eigh lives in scipy.linalg; the top of the notebook only imports
# scipy.sparse, so import the submodule explicitly here.
import scipy.linalg

fixed_epsilon = 0.011
fixed_reg     = 0.08

lm_paths = []
np.random.seed(seed)
for run in range(10):
    epsilons = np.full(iterations, fixed_epsilon)
    regs     = np.full(iterations, fixed_reg)
    # Pristine copies of both schedules, used to restore them after a
    # step-size backoff. `epsilons_reset` was previously never defined in
    # this cell, so the first backoff raised NameError.
    epsilons_reset = np.copy(epsilons)
    regs_reset = np.copy(regs)
    weights = jnp.array(weights_save)

    weight_log = np.zeros((iterations, weights.size)) + 0j
    lm_exact = np.zeros(iterations)

    # Norm of the previous accepted step; infinite until one exists.
    guidance = np.inf
    for iteration in range(iterations):

        bias = jnp.reshape(weights[-alpha:], (alpha, 1))
        features = jnp.reshape(weights[:-alpha], (alpha, d))
        features2 = jnp.fft.fft(features)

        # Normalized amplitudes and centred, amplitude-weighted derivatives.
        vals = jnp.exp(jansatz(configs, features2, bias))
        vals = vals / jnp.sum(jnp.abs(vals) ** 2) ** .5
        g1, g2 = jgradient(configs, features2, bias)
        grads = jnp.column_stack((jnp.reshape(g1, (-1, alpha * d)), g2))
        grads = grads - jnp.dot(jnp.abs(vals) ** 2, grads)[jnp.newaxis, :]
        grads = grads * vals[:, jnp.newaxis]
        forces = jnp.dot(grads.conj().T, jnp.dot(matrix, vals))

        # Unit-variance Gaussian noise on the real and imaginary parts.
        noise_real = np.random.normal(0, 1, size=forces.shape)
        noise_imag = np.random.normal(0, 1, size=forces.shape)
        forces = forces + noise_real + 1j * noise_imag

        cov = jnp.matmul(grads.conj().T, grads)
        linear = jnp.matmul(grads.conj().T, jnp.matmul(matrix, grads))
        linear = linear - jnp.vdot(vals, jnp.dot(matrix, vals)) * cov
        # Right-hand matrix of the generalized eigenproblem: a leading 1 for
        # the current state, then the regularized covariance.
        cov = jnp.block([[1, jnp.zeros((alpha * (d + 1),))],
                         [jnp.zeros((alpha * (d + 1), 1)),
                          cov + regs[iteration] * jnp.eye(weights.size)]])

        reset = False
        # Left-hand matrix: forces border the shifted `linear` block.
        regular = jnp.block([[0, forces.conj()],
                             [jnp.expand_dims(forces, -1),
                              linear + jnp.eye(weights.size)
                              / epsilons[iteration]]])
        # eigh returns eigenvalues in ascending order, so column 0 belongs to
        # the smallest one; the first component is dropped to get the move.
        # NOTE(review): the move is not divided by vecs[0, 0]; confirm that
        # eigh's normalisation is the intended scaling.
        _, vecs = scipy.linalg.eigh(regular, cov)
        move = vecs[1:, 0]
        # Halve the step size until the move is at most twice the previous one.
        while np.sum(np.abs(move) ** 2) ** .5 > 2 * guidance:
            reset = True
            epsilons[iteration] = epsilons[iteration] / 2
            regular = jnp.block([[0, forces.conj()],
                                 [jnp.expand_dims(forces, -1),
                                  linear + jnp.eye(weights.size)
                                  / epsilons[iteration]]])
            _, vecs = scipy.linalg.eigh(regular, cov)
            move = vecs[1:, 0]
        weights = weights + move
        if reset:
            # Restart both schedules from their initial values. The explicit
            # length equals the old `[:-iteration]` slice for iteration > 0
            # and stays shape-correct at iteration == 0.
            remaining = iterations - iteration
            epsilons[iteration:] = epsilons_reset[:remaining]
            regs[iteration:] = regs_reset[:remaining]
        if iteration > 0:
            guidance = np.sum(np.abs((weights
                                      - weight_log[iteration - 1])) ** 2) ** .5
        weight_log[iteration, :] = weights

        lm_exact[iteration] = rayleigh(weights)
    lm_paths.append(lm_exact)

In [None]:
# SR baseline (presumably stochastic reconfiguration -- confirm against the
# thesis text): ten independent noisy runs from the same initial weights.
fixed_epsilon = 0.0066
sr_reg = 0.2
sr_paths = []
np.random.seed(seed)
for run in range(10):
    epsilons = np.full(iterations, fixed_epsilon)
    # Pristine copy used to restore the step-size schedule after a backoff.
    epsilons_reset = np.copy(epsilons)
    weights = jnp.array(weights_save)
    
    weight_log = np.zeros((iterations, weights.size)) + 0j
    sr_exact = np.zeros(iterations)
    
    # Norm of the previous accepted step; infinite until one exists.
    guidance = np.inf
    for iteration in range(iterations):
        #print(iteration)
    
        # Split the flat weight vector into features and biases.
        bias = jnp.reshape(weights[-alpha:], (alpha, 1))
        features = jnp.reshape(weights[:-alpha], (alpha, d))
        features2 = jnp.fft.fft(features)
    
        # Normalized amplitudes and centred, amplitude-weighted derivatives.
        vals = jnp.exp(jansatz(configs, features2, bias))
        vals = vals / jnp.sum(jnp.abs(vals) ** 2) ** .5
        g1, g2 = jgradient(configs, features2, bias)
        grads = jnp.column_stack((jnp.reshape(g1, (-1, alpha * d)), g2))
        grads = grads - jnp.dot(jnp.abs(vals) ** 2, grads)[jnp.newaxis, :]
        grads = grads * vals[:, jnp.newaxis]
        forces = jnp.dot(grads.conj().T, jnp.dot(matrix, vals))
    
        # Unit-variance Gaussian noise on the real and imaginary parts.
        noise_real = np.random.normal(0, 1, size=forces.shape)
        noise_imag = np.random.normal(0, 1, size=forces.shape)
        forces = forces + noise_real + 1j * noise_imag
    
        # Regularized covariance scaled by the inverse step size.
        cov = jnp.matmul(grads.conj().T, grads)
        regular = (cov + sr_reg * jnp.eye(weights.size)) / epsilons[iteration]
    
        reset = False
        move = -jnp.linalg.solve(regular, forces)
        # Halve the step size until the move is at most twice the previous one.
        while np.sum(np.abs(move) ** 2) ** .5 > 2 * guidance:
            reset = True
            epsilons[iteration] = epsilons[iteration] / 2
            regular = (cov + sr_reg * jnp.eye(weights.size)) \
                      / epsilons[iteration]
            move = -jnp.linalg.solve(regular, forces)
        weights = weights + move
        if reset:
            # np.save('sr_error_' + str(iteration) + '.npy', weights)
            # Restart the step-size schedule from its initial values.
            epsilons[iteration:] = epsilons_reset[:-iteration]
        if iteration > 0:
            # weight_log[iteration - 1] holds the weights before this move,
            # so this is the norm of the step just taken.
            guidance = np.sum(np.abs((weights
                                      - weight_log[iteration - 1])) ** 2) ** .5
        weight_log[iteration, :] = weights
    
        # Noise-free energy of the updated weights.
        sr_exact[iteration] = rayleigh(weights)
    sr_paths.append(sr_exact)

In [None]:
def stack_paths(paths):
    """Stack energy traces into a 2-D array of per-site energies.

    Every trace is cut to the length of the shortest one and divided by the
    global site count ``d``; the result has one row per trace.
    """
    shortest = min(map(len, paths))
    per_site = [np.array(trace[:shortest]) / d for trace in paths]
    return np.array(per_site)

# Energy traces per method, in legend order.
method_paths = {
    'RVO-GEK': energy_paths_NLC,
    'GD': energy_paths_GD,
    'SR': sr_paths,
    'LM': lm_paths,
    'RGN': rgn_paths,
}

colors = colormaps['tab10'].colors
fig, ax = plt.subplots(figsize=(4, 3))

# One mean curve per method with a +/- one standard deviation band across runs.
for i, (label, paths) in enumerate(method_paths.items()):
    arr = stack_paths(paths)
    # Named `spread`, not `sigma`: `sigma` is the hyper-parameter from the
    # setup cell that the optimizer cell passes to GEKRunner, and must not be
    # overwritten by an array here.
    mu, spread = arr.mean(0), arr.std(0)
    ax.plot(mu, label=f'{label}', color=colors[i])
    ax.fill_between(range(len(mu)), mu - spread, mu + spread, color=colors[i], alpha=0.2)

# Closed-form reference value from the setup cell.
ax.axhline(soln, linestyle='--', color='red', label='Exact')
ax.set_ylim(-1.7, -1.45)
ax.set_xlabel("Iteration")
ax.set_ylabel("Energy")
ax.legend(loc="upper right")
fig.tight_layout()
plt.savefig("Comparison_d12h1.5_zoom.pdf", dpi=600)