In [1]:
import sys
import os
# Extend the module search path with the repository's source directory
# (presumably where z2_lgt and z2_vqe live; the path is relative to the
# notebook's working directory).
sys.path.append('z2vqe/src')
from dataclasses import dataclass
import numpy as np
import h5py
import jax
# Turn on 64-bit types in JAX. This is done right after importing jax and
# before the project modules below are imported.
jax.config.update('jax_enable_x64', True)
from z2_lgt import calculate_num_params, z2_ansatz_layer, create_hamiltonian, initial_state
from z2_vqe import make_cost_fn, vqe_jaxopt

In [2]:
@dataclass
class Options:
    sites: int
    layers: int
    maxiter: int
    stepsize: float
    instances: int
    seed: int
    gpus: list[str | int]
    out: str

options = Options(
    sites=4,
    layers=1,
    maxiter=2000,
    stepsize=0.002,
    instances=512,
    seed=12345,
    gpus=[f'{i}' for i in [0, 1, 4, 5, 6, 7, 8, 9]],
    out='vqe_2000iter.h5'
)

In [3]:
# Expose only the GPUs chosen in the options to this process.
# NOTE(review): this only takes effect if JAX has not initialised its
# backends yet (e.g. during the project-module imports) -- confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(gpu) for gpu in options.gpus)

# The instances are split evenly over the devices JAX reports, so the
# instance count has to be divisible by the device count.
num_devices = len(jax.devices())
if options.instances % num_devices:
    raise ValueError(
        f'Number of instances {options.instances} must be a multiple of the number'
        f' of GPUs {num_devices}'
    )

In [4]:
# Model couplings.
j_hopping = 1.  # coupling constant J
f_gauge = 0.5  # coupling constant f
mass = 2.  # value of mass

trans_inv = True  # forwarded to calculate_num_params
overall_coeff = 10  # forwarded to create_hamiltonian
boundary_cond = 'closed'
reference_values = {}

# Build the initial state, one ansatz layer and the Hamiltonian for the
# configured number of sites and boundary condition.
init_state, _ = initial_state(options.sites, boundary_cond)
ansatz_layer = z2_ansatz_layer(options.sites, boundary_cond)
hamiltonian = create_hamiltonian(options.sites, j_hopping, f_gauge, mass, overall_coeff,
                                 boundary_cond, overall_coeff_cond=False)

# JIT-compile the cost function once; it is reused for every instance.
cost_fn = jax.jit(make_cost_fn(init_state, ansatz_layer, options.layers, hamiltonian))

num_parameters = calculate_num_params(options.sites, options.layers, trans_inv)
instances_per_device = max(1, options.instances // num_devices)

# Draw the initial parameters uniformly from [0, 2*pi) with a generator seeded
# from options.seed, so that a fresh-kernel rerun reproduces the same x0.
# (Previously the unseeded global NumPy state was used and the seed ignored.)
rng = np.random.default_rng(options.seed)
x0 = 2 * np.pi * rng.random((num_devices, instances_per_device, num_parameters))

In [None]:
# Run the optimisation on the jitted cost function, starting from x0, with
# the iteration count and step size taken from the options.
energies, parameters = vqe_jaxopt(
    cost_fn,
    x0,
    options.maxiter,
    stepsize=options.stepsize,
)

In [None]:
# Fall back to a descriptive default file name when none was configured.
if not options.out:
    options.out = (f'vqe_{options.sites}sites_{options.layers}layers_'
                   f'{options.maxiter}iter_jaxopt.h5')

# Append a new group to the output file; the group index is the number of
# top-level entries already present, so earlier runs are kept.
with h5py.File(options.out, 'a') as out:
    group = out.create_group(f'vqe_{len(out.keys())}')
    group.create_dataset('num_sites', data=options.sites)
    group.create_dataset('num_layers', data=options.layers)
    # The couplings are notebook-level variables, not fields of Options;
    # reading them from `options` raised AttributeError and left a
    # half-written group behind in the file.
    group.create_dataset('j_hopping', data=j_hopping)
    group.create_dataset('f_gauge', data=f_gauge)
    group.create_dataset('mass', data=mass)
    group.create_dataset('maxiter', data=options.maxiter)
    group.create_dataset('stepsize', data=options.stepsize)
    # Initial parameters are stored flattened to one row per instance.
    group.create_dataset('x0', data=x0.reshape(-1, num_parameters))
    group.create_dataset('energies', data=energies)
    group.create_dataset('parameters', data=parameters)