In [None]:
# pylint: disable=redefined-outer-name, wrong-import-position, invalid-name
"""Run the VQE for Z2 LGT."""
import sys
import os
sys.path.append('z2vqe/src')
import time
from dataclasses import dataclass
from numbers import Number
from argparse import ArgumentParser
import numpy as np
import h5py
from scipy.optimize import minimize
import jax
import jax.numpy as jnp
import jaxopt
from qiskit.circuit.parametervector import ParameterVectorElement
from z2_lgt import calculate_num_params, z2_ansatz_layer, create_hamiltonian, initial_state
from z2_vqe import make_cost_fn, vqe_jaxopt

In [2]:
@dataclass
class Options:
    """Run configuration for the Z2 lattice-gauge-theory VQE notebook.

    All later cells read their settings from the module-level ``options``
    instance defined below.
    """

    sites: int  # number of lattice sites (stored as `num_sites` in the output file)
    layers: int  # number of ansatz layers (stored as `num_layers`)
    maxiter: int  # iteration budget handed to vqe_jaxopt
    instances: int  # total VQE instances; must be a multiple of the visible device count
    gpus: list[str]  # GPU indices joined into CUDA_VISIBLE_DEVICES
    out: str  # HDF5 output path; an empty string selects an auto-generated file name


# Current run: 2 sites, 1 layer, 4 random starts on GPU 9.
# NOTE(review): `out` is an absolute, user-specific path; adjust before running elsewhere.
options = Options(
    sites=2,
    layers=1,
    maxiter=2000,
    instances=4,
    gpus=[str(gpu_index) for gpu_index in range(9, 10)],
    out='/data/iiyama/vqe_2000iter.h5',
)

In [3]:
# Expose only the requested GPUs to JAX.
# NOTE(review): CUDA_VISIBLE_DEVICES is presumably read when JAX first initialises
# its backend (the jax.devices() call below). If JAX was already used earlier in
# this kernel, e.g. when re-running the cell, changing it here has no effect.
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(options.gpus)
num_devices = len(jax.devices())

# The instances are split evenly over the devices, so reject uneven counts up front.
leftover_instances = options.instances % num_devices
if leftover_instances:
    raise ValueError(f'Number of instances {options.instances} must be a multiple of the number'
                     f' of GPUs {num_devices}')

In [6]:
# --- Hamiltonian couplings -------------------------------------------------
j_hopping = 1.  # coupling constant J
f_gauge = 0.5  # coupling constant f
mass = 2.  # value of mass

# --- Model / ansatz switches -----------------------------------------------
trans_inv = True  # forwarded to calculate_num_params (presumably translation-invariant parameter sharing)
overall_coeff = 10  # forwarded to create_hamiltonian (used with overall_coeff_cond=False)
boundary_cond = 'closed'
reference_values = {}  # NOTE(review): not read by any visible cell

# --- Circuit pieces and Hamiltonian for the configured lattice --------------
# The call order is kept as-is because these project helpers are opaque here.
init_state, _ = initial_state(options.sites, boundary_cond)
ansatz_layer = z2_ansatz_layer(options.sites, boundary_cond)
hamiltonian = create_hamiltonian(options.sites, j_hopping, f_gauge, mass, overall_coeff,
                                 boundary_cond, overall_coeff_cond=False)

# JIT-compiled energy of the `options.layers`-deep ansatz.
cost_fn = jax.jit(make_cost_fn(init_state, ansatz_layer, options.layers, hamiltonian))

# --- Random initial parameters ---------------------------------------------
num_parameters = calculate_num_params(options.sites, options.layers, trans_inv)
# max(1, ...) guarantees every device gets at least one instance (e.g. when
# options.instances is 0, which passes the divisibility check).
instances_per_device = max(1, options.instances // num_devices)
x0_shape = (num_devices, instances_per_device, num_parameters)
# Uniform angles in [0, 2*pi). No seed is set, so every run starts differently.
x0 = np.random.random(x0_shape) * (2 * np.pi)

In [5]:
# Run every VQE instance for up to `options.maxiter` iterations. The result is
# reshaped to one column per instance (num_devices * instances_per_device columns),
# presumably giving an (iteration, instance) layout.
total_instances = num_devices * instances_per_device
energies = vqe_jaxopt(cost_fn, x0, options.maxiter).reshape(-1, total_instances)

Compiling the cost function..
Compilation of the cost function took 3.6475534439086914 seconds
Iteration: 1, elapsed time: 0.8401436805725098 seconds
Iteration: 11, elapsed time: 0.8960425853729248 seconds
Iteration: 21, elapsed time: 0.9411425590515137 seconds
Iteration: 31, elapsed time: 0.9873313903808594 seconds
Iteration: 41, elapsed time: 1.0296621322631836 seconds
Iteration: 51, elapsed time: 1.0752172470092773 seconds
Iteration: 61, elapsed time: 1.1162433624267578 seconds
Iteration: 71, elapsed time: 1.159902811050415 seconds
Iteration: 81, elapsed time: 1.2020177841186523 seconds
Iteration: 91, elapsed time: 1.2470216751098633 seconds
Iteration: 101, elapsed time: 1.289790153503418 seconds
Iteration: 111, elapsed time: 1.3320093154907227 seconds
Iteration: 121, elapsed time: 1.3738667964935303 seconds
Iteration: 131, elapsed time: 1.4161343574523926 seconds
Iteration: 141, elapsed time: 1.460127830505371 seconds
Iteration: 151, elapsed time: 1.5043866634368896 seconds
Iterati

In [None]:
# Fall back to a descriptive file name when no output path is configured.
if not options.out:
    options.out = (f'vqe_{options.sites}sites_{options.layers}layers_'
                   f'{options.maxiter}iter_jaxopt.h5')

# Append this run as a new `vqe_<n>` group; earlier runs in the file are preserved.
with h5py.File(options.out, 'a') as out:
    # Start from the number of existing keys (the previous naming rule) but skip
    # names already taken. With gaps (e.g. vqe_0 and vqe_2 present) the old rule
    # produced an existing name, and create_group raised.
    run_index = len(out.keys())
    while f'vqe_{run_index}' in out:
        run_index += 1
    group = out.create_group(f'vqe_{run_index}')

    group.create_dataset('num_sites', data=options.sites)
    group.create_dataset('num_layers', data=options.layers)
    group.create_dataset('j_hopping', data=j_hopping)
    group.create_dataset('f_gauge', data=f_gauge)
    group.create_dataset('mass', data=mass)
    group.create_dataset('maxiter', data=options.maxiter)
    group.create_dataset('energies', data=energies)

    # Extra settings needed to reproduce the run. No RNG seed is set for the
    # starting parameters, so they are stored explicitly.
    group.create_dataset('boundary_cond', data=boundary_cond)
    group.create_dataset('overall_coeff', data=overall_coeff)
    group.create_dataset('trans_inv', data=trans_inv)
    group.create_dataset('initial_params', data=x0)
