In [None]:
import sys
import os
sys.path.append('z2vqe/src')
from dataclasses import dataclass
import numpy as np
import h5py
import jax
jax.config.update('jax_enable_x64', True)
from z2_lgt import calculate_num_params, z2_ansatz_layer, initial_state
from z2_vqe import make_qfim_fn

In [None]:
@dataclass
class Options:
    """Run configuration for the QFIM rank scan in this notebook.

    ``sites`` and ``layers`` accept either a single integer or a list of
    integers; the notebook wraps bare integers into one-element lists right
    after the instance is created.
    """
    # Lattice size(s) to scan; each value is passed to initial_state / z2_ansatz_layer.
    sites: int | list[int]
    # Ansatz depth(s) to scan for every lattice size.
    layers: int | list[int]
    # Number of random parameter points drawn per (sites, layers) combination.
    points: int
    # GPU identifiers joined with commas into CUDA_VISIBLE_DEVICES.
    gpus: list[str | int]
    # Path of the HDF5 output file.
    out: str

# Scan configuration: even lattice sizes 2..8, ansatz depths 1..26,
# 100 random parameter points per combination, GPU 1, output in z2_qfim.h5.
options = Options(
    sites=[2, 4, 6, 8],
    layers=[*range(1, 27)],
    points=100,
    gpus=[1],
    out='z2_qfim.h5'
)

# Allow scalar shorthand: a bare integer becomes a one-element list so the
# scan loops can always iterate over these fields.
options.sites = [options.sites] if isinstance(options.sites, int) else options.sites
options.layers = [options.layers] if isinstance(options.layers, int) else options.layers

In [None]:
# Restrict the process to the GPUs requested in the options.
# NOTE(review): this environment variable is set after `import jax`; it only
# takes effect if no JAX backend has been initialized yet -- confirm that the
# project imports above do not touch devices at import time.
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, options.gpus))
num_devices = len(jax.devices())

# The parameter points are split evenly across devices further down, so the
# total must be divisible by the device count.
if options.points % num_devices != 0:
    # Fixed: the message used the non-existent `options.instances`, which
    # raised AttributeError instead of the intended ValueError.
    raise ValueError(f'Number of instances {options.points} must be a multiple of the number'
                     f' of GPUs {num_devices}')

In [None]:
# Scan over lattice sizes and ansatz depths: for each combination, evaluate the
# QFIM at random parameter points, compute its rank, and store parameters,
# matrices and ranks in one HDF5 group.
trans_inv = True
boundary_cond = 'closed'
reference_values = {}

# Leading axes of the parameter array are (device, point-on-device) for pmap(vmap(...)).
points_per_device = max(1, options.points // num_devices)

if not options.out:
    options.out = 'qfim.h5'

# Mode 'w' truncates any existing file at this path. The context manager
# guarantees the file is closed (and flushed) even if the scan raises partway
# through; previously the handle leaked on any exception.
with h5py.File(options.out, 'w') as out:
    for num_sites in options.sites:
        init_state, _ = initial_state(num_sites, boundary_cond)
        ansatz_layer = z2_ansatz_layer(num_sites, boundary_cond)
        # Rank statistics per evaluated depth, used for the saturation check below.
        mean_ranks = []
        max_ranks = []
        for num_layers in options.layers:
            print(num_sites, 'sites', num_layers, 'layers')
            # vmap batches over the points on one device, pmap distributes over devices.
            qfim_fn = jax.pmap(jax.vmap(make_qfim_fn(init_state, ansatz_layer, num_layers)))
            num_parameters = calculate_num_params(num_sites, num_layers, trans_inv)
            # Uniform random parameters in [0, 2*pi).
            params = 2 * np.pi * np.random.random((num_devices, points_per_device, num_parameters))
            # Merge the device and per-device axes into one batch of square matrices.
            matrices = qfim_fn(params).reshape(-1, num_parameters, num_parameters)
            # hermitian=True selects the eigenvalue-based rank computation;
            # singular values below the absolute tolerance count as zero.
            ranks = np.linalg.matrix_rank(matrices, tol=1.e-12, hermitian=True)
            group = out.create_group(f'qfim_{num_sites}sites_{num_layers}layers')
            group.create_dataset('params', data=params.reshape(-1, num_parameters))
            group.create_dataset('qfim', data=matrices)
            group.create_dataset('rank', data=ranks)

            mean_ranks.append(np.mean(ranks))
            max_ranks.append(np.amax(ranks))
            print('  rank mean', mean_ranks[-1])
            print('  rank max', max_ranks[-1])

            # Stop adding layers once at least five depths have been evaluated and
            # both the mean and the max rank equal those of the two previous depths.
            if (len(mean_ranks) > 4 and np.allclose(mean_ranks[-1], mean_ranks[-3:-1])
                    and np.allclose(max_ranks[-1], max_ranks[-3:-1])):
                break