# Hartree-Fock VQE Demo

In [None]:
import tqdm

import numpy as np
import scipy as sp
import torch

from QuICT.algorithm.quantum_machine_learning.model.chemistry.utils.hamiltonian.moleculardata import MolecularData
from QuICT.algorithm.quantum_machine_learning.model.chemistry.utils.hamiltonian import obi_basis_rotation, tbi_basis_rotation, generate_hamiltonian
from QuICT.algorithm.quantum_machine_learning.model.chemistry.utils.operators.encoder import JordanWigner
from QuICT.algorithm.quantum_machine_learning.model.chemistry import HartreeFockVQE, HartreeFockVQENet
from QuICT.algorithm.quantum_machine_learning.utils import Hamiltonian

## Load the data

In [None]:
# Directory with the stored data for the linear H6 chain (STO-3G basis,
# bond distance 1.3) and the HDF5 file describing the molecule itself.
moldir = "./utils/hamiltonian/molecular_data/hydrogen_chains/h_6_sto-3g/bond_distance_1.3"
molfile = moldir + "/H6_sto-3g_singlet_linear_r-1.3.hdf5"
moldata = MolecularData(molfile)

# Arrays saved next to the molecule file, loaded in the order
# overlap -> core Hamiltonian -> two-electron integrals.
overlap, Hcore, raw_tei = (
    np.load(moldir + name) for name in ("/overlap.npy", "/h_core.npy", "/tei.npy")
)

# einsum is called without an explicit output, so the result axes are ordered
# alphabetically: "psqr" yields the input with its axes permuted to (p, q, r, s).
two_electron_integral = np.einsum("psqr", raw_tei)  # (1, 1, 0, 0)

# Generalized eigenproblem Hcore x = e * overlap x; the eigenvalues are
# discarded and only the eigenvector matrix X is used for the basis rotation.
_, X = sp.linalg.eigh(Hcore, overlap)

# Rotate the one- and two-body integrals with X, then assemble the molecular
# Hamiltonian together with the nuclear repulsion term.
obi = obi_basis_rotation(Hcore, X)
tbi = tbi_basis_rotation(two_electron_integral, X)
molecular_hamiltonian = generate_hamiltonian(moldata.nuclear_repulsion, obi, tbi)

## Convert the Hamiltonian

In [None]:
# Problem size taken from the molecule record. The factor of two presumably
# counts spin orbitals (two per spatial orbital) -- TODO confirm.
orbitals = 2 * moldata.n_orbitals
electrons = moldata.n_electrons

# Express the molecular Hamiltonian as a fermion operator, encode it with the
# JordanWigner encoder, and wrap the result in the Hamiltonian class that the
# VQE network is constructed with.
fermi_op = molecular_hamiltonian.get_fermion_operator()
encoder = JordanWigner(orbitals)
qubit_op = encoder.encode(fermi_op)
hamiltonian = Hamiltonian(qubit_op.to_hamiltonian())

## Calculate the ground-state energy

In [None]:
# Optimisation settings: number of Adam iterations and the learning rate.
MAX_ITERS = 10000
LR = 0.1

# Build the network and an Adam optimizer with a single parameter group.
hfvqe_net = HartreeFockVQENet(orbitals, electrons, hamiltonian)
param_groups = [{"params": hfvqe_net.parameters(), "lr": LR}]
optim = torch.optim.Adam(param_groups)

# Loss value of every iteration, recorded before that iteration's update.
energy = []

hfvqe_net.train()
loader = tqdm.trange(MAX_ITERS, desc="Training", leave=True)
for step in loader:
    # Clear gradients left over from the previous iteration.
    optim.zero_grad()

    # Forward pass: the network takes no input and returns a state, which
    # its own loss function turns into the scalar being minimised.
    state = hfvqe_net()
    loss = hfvqe_net.loss_func(state)
    energy.append(float(loss))

    # Backward pass and parameter update.
    loss.backward()
    optim.step()

    # Show the current loss in the progress bar.
    loader.set_postfix(loss=loss.item())

# Report the loss history, the trained parameters, and the last recorded loss
# (computed before the final optimizer step was applied).
print(energy)
print(hfvqe_net.params)
print(float(loss))