In [None]:
import os
import time
from concurrent.futures import ThreadPoolExecutor
import numpy as np
from scipy.sparse.linalg import eigsh
import h5py
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec, NamedSharding

from qiskit_aer import AerSimulator
from qiskit_addon_sqd.qubit import project_operator_to_subspace
from heavyhex_qft.triangular_z2 import TriangularZ2Lattice
from skqd_z2lgt.circuits import make_plaquette_circuits, compose_trotter_circuits
from skqd_z2lgt.sqd import keys_to_intset, qiskit_sqd
from skqd_z2lgt.ising_hamiltonian import ising_hamiltonian, make_matvec, parse_hamiltonian
from skqd_z2lgt.jax_experimental_sparse_linalg import lobpcg_standard
from skqd_z2lgt.ising_dmrg import ising_dmrg

# GPUs made visible to CUDA-based libraries (Aer, JAX).
# NOTE(review): CUDA reads this variable when it is first initialized, so it only takes
# effect if nothing imported above has already touched the GPU -- confirm.
os.environ.update(CUDA_VISIBLE_DEVICES='1,3,4,5,6,2,7')
# Enable 64-bit (float64 / complex128) arrays in JAX.
jax.config.update('jax_enable_x64', True)

# Root directory under which the sampled data files are written.
data_dir = '/data/iiyama/2dz2'
# Julia command line (with a custom sysimage) handed to ising_dmrg.
julia_bin = 'julia --sysimage /opt/julia/iiyama/sysimages/sys_itensors.so'.split()

In [None]:
# ASCII lattice layouts passed to TriangularZ2Lattice, smallest to largest
# (presumably one '*' per vertex; every other row is indented by one column).
configs = [
    '''
* * *
 * * *
* * *
''',
    '''
* * * * *
 * * * * *
* * * * *
''',
    '''
* * * * *
 * * * * *
* * * * *
 * * * * *
''',
    '''
* * * * *
 * * * * *
* * * * *
 * * * * *
* * * * *
''',
]

In [None]:
# Sampling scan parameters.
max_steps = 12                          # circuits per (coupling, dt): steps 1..max_steps
shots = 100_000                         # shots per experiment
num_experiments = 5                     # repetitions; each circuit runs shots * num_experiments
couplings = np.array([0.5, 1.0, 1.5])   # coupling values scanned
dtvals = np.linspace(0.1, 0.7, num=7)   # time steps 0.1, 0.2, ..., 0.7

### Theoretical optimal $\Delta t = \pi / (E_{\max} - E_{\min})$ for coupling $\lambda = 1$

In [None]:
# Spectral width E_max - E_min of the coupling-1 Hamiltonian for each lattice layout.
# The loop variables (dual_lattice, hamiltonian, emin, emax) are left as globals.
energy_gaps = []
for config in configs:
    dual_lattice = TriangularZ2Lattice(config).plaquette_dual()
    hamiltonian = dual_lattice.make_hamiltonian(1.)
    # Lowest energy of H.
    emin = ising_dmrg(hamiltonian, julia_bin=julia_bin)
    # Highest energy of H, obtained as minus the lowest energy of -H.
    emax = -ising_dmrg(-hamiltonian, julia_bin=julia_bin)
    energy_gaps.append(emax - emin)


In [None]:
# Compare pi / (E_max - E_min) with the heuristic 0.8 / N, one point per entry of `configs`.
# NOTE(review): 8, 16, 24, 32 look like the plaquette counts of the four layouts -- confirm.
fig, ax = plt.subplots()
ax.plot(np.pi / np.array(energy_gaps), marker='o', label=r'$\pi / (E_{\max} - E_{\min})$')
ax.plot(0.8 / np.array([8, 16, 24, 32]), marker='s', label=r'$0.8 / N$, $N = 8, 16, 24, 32$')
ax.set(xlabel='index into configs', ylabel=r'$\Delta t$',
       title=r'Optimal $\Delta t$ at $\lambda = 1$')
ax.legend()
plt.show()

In [None]:
def extract_samples(exp_samples, num_bits=None):
    """Accumulate sampled bitstrings step by step into a growing, sorted subspace.

    Args:
        exp_samples: Sequence with one entry per step; each entry is a 1D array of
            non-negative integers encoding the measured bitstrings of that step.
        num_bits: Number of bits per bitstring. Defaults to the module-level ``npl``
            (set by the sampling loop) so existing callers keep working.

    Returns:
        List with one entry per step. Entry ``i`` is a ``(n_unique, num_bits)`` uint8 array
        of the sorted unique integers seen in steps ``0..i``, unpacked most significant
        bit first.
    """
    if num_bits is None:
        # Backward-compatible fallback to the global defined in the sampling cell.
        num_bits = npl
    # Bit positions from most to least significant.
    shifts = np.arange(num_bits)[::-1]
    # int64 start: keeps the union integer-typed even when every step so far is empty
    # (np.array(sorted(set())) would be float64 and make `>>` fail).
    subspace = np.empty(0, dtype=np.int64)
    bitstrings = []
    for step_samples in exp_samples:
        step_values = np.asarray(step_samples).astype(np.int64)
        # Sorted unique union -- the vectorized equivalent of sorted(set(...)).
        subspace = np.union1d(subspace, step_values)
        bitstrings.append(((subspace[:, None] >> shifts[None, :]) & 1).astype(np.uint8))
    return bitstrings


In [None]:
executor = ThreadPoolExecutor()
# NOTE(review): unseeded, so the random split of shots into experiments is not reproducible.
rng = np.random.default_rng()
simulator = AerSimulator(method='statevector', device='GPU', max_parallel_experiments=7,
                         num_threads_per_device=1)

for config in configs:
    lattice = TriangularZ2Lattice(config)
    npl = lattice.num_plaquettes  # also read as a global by extract_samples
    print(npl, 'plaquettes')
    dual_lattice = lattice.plaquette_dual()

    # Circuits are appended in (coupling, dt, step) order; the reshape of `samples`
    # below relies on this ordering.
    all_circuits = []
    for icp, coupling in enumerate(couplings):
        print('cp', coupling)
        # NOTE(review): unused in this loop, but the value left after the last iteration
        # is what the Krylov cells further down operate on.
        hamiltonian = dual_lattice.make_hamiltonian(coupling)

        for idt, delta_t in enumerate(dtvals):
            print('dt', delta_t)
            step_circuits = make_plaquette_circuits(dual_lattice, coupling, delta_t)
            trotter_step = step_circuits[0]
            measure = step_circuits[-1]
            circuits = compose_trotter_circuits(trotter_step, measure, max_steps)
            all_circuits += circuits
    print(len(all_circuits))
    start = time.time()
    result = simulator.run(all_circuits, shots=shots * num_experiments).result()
    print('sim completed', time.time() - start)

    num_rows = len(couplings) * len(dtvals) * max_steps
    total_shots = num_experiments * shots
    if len(result.results) != num_rows:
        raise RuntimeError(f'expected {num_rows} circuit results, got {len(result.results)}')
    # Each bitstring is stored as one integer: uint32 covers up to 32 plaquettes,
    # uint64 up to 64.
    if npl > 64:
        raise ValueError(f'{npl} plaquettes do not fit in a 64-bit sample encoding')
    samples = np.empty((num_rows, total_shots), dtype=np.uint32 if npl <= 32 else np.uint64)
    for ires, res in enumerate(result.results):
        pos = 0
        for key, value in res.data.counts.items():
            # Count keys are parsed as hexadecimal integers.
            samples[ires, pos:pos + value] = int(key, 16)
            pos += value
        # np.empty leaves garbage in any slot that was not filled above.
        if pos != total_shots:
            raise RuntimeError(f'result {ires} has {pos} shots, expected {total_shots}')

    # Shuffle every row independently so that splitting the shot axis into experiments
    # yields independent replicas. Generator.shuffle(samples, axis=1) would apply the
    # same column permutation to all rows and correlate the splits across circuits.
    rng.permuted(samples, axis=1, out=samples)

    # (coupling, dt, step, experiment, shot) -> (coupling, dt, experiment, step, shot)
    samples = samples.reshape((len(couplings), len(dtvals), max_steps, num_experiments, shots))
    samples = samples.transpose((0, 1, 3, 2, 4))

    # Accumulate the subspace over steps for each (coupling, dt, experiment) in parallel.
    # All futures are resolved below before `npl` changes in the next iteration.
    futures = {}
    for icp, coupling in enumerate(couplings):
        print('cp', coupling)
        for idt, delta_t in enumerate(dtvals):
            for iexp in range(num_experiments):
                futures[(icp, idt, iexp)] = executor.submit(extract_samples, samples[icp, idt, iexp])

    bitstrings = {}
    for (icp, idt, iexp), future in futures.items():
        for istep, subspace in enumerate(future.result()):
            bitstrings[(icp, idt, istep, iexp)] = subspace
    print('samples extracted')

    out_dir = f'{data_dir}/plaqsim_data'
    os.makedirs(out_dir, exist_ok=True)
    filename = f'{out_dir}/{npl}plaqs_{max_steps}steps_{shots}shots_samples.h5'

    with h5py.File(filename, 'w') as out:
        out.create_dataset('lattice_config', data=config)
        out.create_dataset('max_steps', data=max_steps)
        out.create_dataset('shots', data=shots)
        out.create_dataset('num_experiments', data=num_experiments)
        out.create_dataset('couplings', data=couplings)
        out.create_dataset('dtvals', data=dtvals)
        # Layout: c{icp}/dt{idt}/step{istep + 1}/exp{iexp}
        for icp in range(len(couplings)):
            cpgroup = out.create_group(f'c{icp}')
            for idt in range(len(dtvals)):
                dtgroup = cpgroup.create_group(f'dt{idt}')
                for istep in range(max_steps):
                    stgroup = dtgroup.create_group(f'step{istep + 1}')
                    for iexp in range(num_experiments):
                        stgroup.create_dataset(f'exp{iexp}', data=bitstrings[(icp, idt, istep, iexp)])

### Attempts at Krylov diagonalization

In [None]:
# NOTE(review): `hamiltonian` is not defined in this section -- it is whatever the sampling
# loop above left behind (last config, last coupling), so this cell depends on that loop.
nq, zzops, zops, xops = parse_hamiltonian(hamiltonian)
matvec = make_matvec(nq, zzops, zops, xops, jnp)
# The (2**nq, 1) state vector is sharded along axis 0 over the 'dev' mesh axis. That
# requires 2**nq to divide evenly across the devices, so use the largest power-of-two
# device count (capped at 2**nq) rather than every visible device. For example, this
# gives 4 when the 7 GPUs listed in CUDA_VISIBLE_DEVICES are visible.
num_mesh_devices = min(1 << (jax.device_count().bit_length() - 1), 2 ** nq)
mesh = jax.sharding.Mesh(
    np.array(jax.devices()[:num_mesh_devices]).reshape(num_mesh_devices, 1), ('dev', 'dum')
)
# Normalized uniform vector: every entry is 2**(-nq/2).
xmat = jax.device_put(np.full((2 ** nq, 1), np.power(2., -nq / 2), dtype=np.complex128),
                      NamedSharding(mesh, PartitionSpec('dev', 'dum')))
print('compiling')
# The first call compiles; block so 'done' is printed only once the result is available.
jax.block_until_ready(matvec(xmat))
print('done')

In [None]:
# SciPy eigsh on the operator built by ising_hamiltonian (again using the `hamiltonian`
# left over from the sampling loop).
isingh = ising_hamiltonian(hamiltonian, npmod=jnp)
print('compiling')
# NOTE(review): presumably a warm-up product so JIT compilation happens before eigsh.
isingh.matvec(np.zeros((2 ** hamiltonian.num_qubits,), dtype=np.complex128))
print('eigsh')
# Single smallest-algebraic eigenpair.
evals, evecs = eigsh(isingh, which='SA', k=1)

In [None]:
# Same setup as the matvec-compilation cell, repeated so this cell stands alone.
# NOTE(review): `hamiltonian` is again the value left over from the sampling loop.
nq, zzops, zops, xops = parse_hamiltonian(hamiltonian)
matvec = make_matvec(nq, zzops, zops, xops, jnp)
# The 2**nq rows must divide evenly across the 'dev' axis. Use the largest power-of-two
# device count (capped at 2**nq) instead of every visible device; this is 4 when the
# 7 GPUs listed in CUDA_VISIBLE_DEVICES are visible.
num_mesh_devices = min(1 << (jax.device_count().bit_length() - 1), 2 ** nq)
mesh = jax.sharding.Mesh(
    np.array(jax.devices()[:num_mesh_devices]).reshape(num_mesh_devices, 1), ('dev', 'dum')
)
# Normalized uniform starting block (a single column).
xmat = jax.device_put(np.full((2 ** nq, 1), np.power(2., -nq / 2), dtype=np.complex128),
                      NamedSharding(mesh, PartitionSpec('dev', 'dum')))
# NOTE(review): presumably (eigenvalues, eigenvectors, iteration info); the third is unused.
evalsj, evecsj, _ = lobpcg_standard(matvec, xmat)