# Benchmark engines

## Overview

In this notebook, we measure the run time of every engine in TenCirChem and how each engine interacts with the different backends.

A hydrogen chain is used as the benchmark system, with chain lengths of 2, 4 and 6 atoms (4, 8 and 12 qubits).
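The mapping from chain length to qubit count can be sketched as follows. This assumes that `h_chain` uses a minimal basis (e.g. STO-3G) by default, so that each hydrogen atom contributes one spatial orbital, i.e. two spin orbitals, each mapped to one qubit. The helper `n_qubits_for_h_chain` is hypothetical and not part of TenCirChem.

```python
def n_qubits_for_h_chain(n_h: int) -> int:
    """Hypothetical helper: qubit count for an H_n chain in a minimal basis (assumption)."""
    n_spatial_orbitals = n_h        # one spatial orbital per H atom (minimal-basis assumption)
    return 2 * n_spatial_orbitals   # two spin orbitals per spatial orbital, one qubit each


n_h_list = list(range(2, 8, 2))     # same chain lengths as the benchmark: 2, 4, 6
print([(n_h, n_qubits_for_h_chain(n_h)) for n_h in n_h_list])
```

Under this assumption the counts agree with the 4, 8 and 12 qubits reported in the benchmark output.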

## Setup

In [1]:
import time

import numpy as np
import pandas as pd

from tencirchem import UCCSD, set_backend
from tencirchem.molecule import h_chain

In [2]:
n_h_list = list(range(2, 8, 2))
uccsd_list = [UCCSD(h_chain(n_h)) for n_h in n_h_list]
params_list = [np.random.rand(uccsd.n_params) for uccsd in uccsd_list]

In [3]:
# the tensornetwork and statevector engines are only compatible with the JAX backend
jax_engines = ["tensornetwork", "statevector", "civector", "civector-large", "pyscf"]
numpy_engines = ["civector", "civector-large", "pyscf"]
cupy_engines = numpy_engines
tested_engines_list = [jax_engines, numpy_engines, cupy_engines]
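The nested loops of the benchmark visit every (backend, system size, engine) combination in a fixed order. The sketch below enumerates the same combinations with plain Python, using qubit counts instead of `UCCSD` objects so it runs without TenCirChem; it shows how many measurements the benchmark makes and in which order the results are printed.

```python
from itertools import product

# Backend -> engines tested, mirroring the lists in the cell above
tested = {
    "jax": ["tensornetwork", "statevector", "civector", "civector-large", "pyscf"],
    "numpy": ["civector", "civector-large", "pyscf"],
    "cupy": ["civector", "civector-large", "pyscf"],
}
n_qubits_list = [4, 8, 12]  # H2, H4, H6

# Backend is the outer loop, then system size, then engine, as in the benchmark cell
configs = [
    (backend, n_qubits, engine)
    for backend, engines in tested.items()
    for n_qubits, engine in product(n_qubits_list, engines)
]
print(len(configs))  # 3 sizes x (5 + 3 + 3) engines
```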

## Benchmark

In [5]:
time_data = []
for backend, tested_engines in zip(["jax", "numpy", "cupy"], tested_engines_list):
    set_backend(backend)
    for uccsd, params in zip(uccsd_list, params_list):
        for engine in tested_engines:
            # dry run first: lets JAX compile and/or the engine build its caches
            time1 = time.time()
            uccsd.energy_and_grad(params, engine=engine)
            time2 = time.time()
            staging_time = time2 - time1
            # timed runs: assume `n_run` energy/gradient evaluations during a typical optimization
            n_run = 20
            for _ in range(n_run):
                uccsd.energy_and_grad(params, engine=engine)
            run_time = (time.time() - time2) / n_run
            item = (backend, uccsd.n_qubits, engine, staging_time, run_time, staging_time + n_run * run_time)
            print(item)
            time_data.append(item)

('jax', 4, 'tensornetwork', 0.6734888553619385, 0.0013010025024414063, 0.6995089054107666)
('jax', 4, 'statevector', 0.5532455444335938, 0.0008256316184997558, 0.5697581768035889)
('jax', 4, 'civector', 0.6092183589935303, 0.0061431884765625, 0.7320821285247803)
('jax', 4, 'civector-large', 0.6665019989013672, 0.0064354419708251955, 0.7952108383178711)
('jax', 4, 'pyscf', 0.0294039249420166, 0.016276955604553223, 0.35494303703308105)
('jax', 8, 'tensornetwork', 4.530415773391724, 0.003479158878326416, 4.599998950958252)
('jax', 8, 'statevector', 2.3404414653778076, 0.001799321174621582, 2.3764278888702393)
('jax', 8, 'civector', 0.6547484397888184, 0.007375049591064453, 0.8022494316101074)
('jax', 8, 'civector-large', 1.63521146774292, 0.006232154369354248, 1.7598545551300049)
('jax', 8, 'pyscf', 0.1072854995727539, 0.09403668642044068, 1.9880192279815674)
('jax', 12, 'tensornetwork', 24.58689785003662, 0.05817370414733887, 25.7503719329834)
('jax', 12, 'statevector', 7.56268572807312,
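The timing scheme above separates a one-off staging cost from a steady-state cost per evaluation. A minimal, self-contained sketch of the same idea is shown below; `time_staging_and_runs` and `toy_energy_and_grad` are hypothetical names, the toy function merely imitates a cache built on the first call, and `time.perf_counter` is used because it is intended for measuring short intervals.

```python
import time
from functools import lru_cache
from typing import Callable, Tuple


def time_staging_and_runs(fn: Callable[[], object], n_run: int = 20) -> Tuple[float, float, float]:
    """Return (staging_time, mean_run_time, estimated_total_time) in seconds."""
    t0 = time.perf_counter()
    fn()  # dry run: may include one-off work such as JIT compilation or cache building
    t1 = time.perf_counter()
    staging_time = t1 - t0
    for _ in range(n_run):
        fn()  # steady-state evaluations, as during an optimization
    run_time = (time.perf_counter() - t1) / n_run
    return staging_time, run_time, staging_time + n_run * run_time


@lru_cache(maxsize=None)
def _build_cache() -> int:
    time.sleep(0.05)  # stands in for expensive one-off setup
    return sum(range(1000))


def toy_energy_and_grad() -> int:
    # Hypothetical stand-in for uccsd.energy_and_grad: cheap once the cache exists
    return 2 * _build_cache()


staging, run, total = time_staging_and_runs(toy_energy_and_grad)
print(f"staging {staging:.4f} s, run {run:.2e} s, total {total:.4f} s")
```

If the timed function returns JAX arrays without converting them, JAX's asynchronous dispatch may return before the computation finishes; calling `block_until_ready()` on the result before reading the clock avoids under-reporting.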

## Results and Discussion

In [6]:
df = pd.DataFrame(
    time_data, columns=["backend", "qubits", "engine", "staging time", "run time", "total time"]
).set_index(["backend", "qubits", "engine"])
df

| backend | qubits | engine | staging time (s) | run time (s) | total time (s) |
|---|---|---|---|---|---|
| jax | 4 | tensornetwork | 0.673489 | 0.001301 | 0.699509 |
| jax | 4 | statevector | 0.553246 | 0.000826 | 0.569758 |
| jax | 4 | civector | 0.609218 | 0.006143 | 0.732082 |
| jax | 4 | civector-large | 0.666502 | 0.006435 | 0.795211 |
| jax | 4 | pyscf | 0.029404 | 0.016277 | 0.354943 |
| jax | 8 | tensornetwork | 4.530416 | 0.003479 | 4.599999 |
| jax | 8 | statevector | 2.340441 | 0.001799 | 2.376428 |
| jax | 8 | civector | 0.654748 | 0.007375 | 0.802249 |
| jax | 8 | civector-large | 1.635211 | 0.006232 | 1.759855 |
| jax | 8 | pyscf | 0.107285 | 0.094037 | 1.988019 |


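The table contains a lot of information, but it is hard to draw conclusions from it directly.

Next, for each system size we identify the backend/engine combination with the shortest run time and the one with the shortest total time.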

In [7]:
interesting_indices = []
for n_qubits, ddf in df.groupby("qubits"):
    run_time_idx = ddf["run time"].idxmin()
    total_time_idx = ddf["total time"].idxmin()
    print(run_time_idx, total_time_idx)
    interesting_indices.extend([run_time_idx, total_time_idx])

('jax', 4, 'statevector') ('numpy', 4, 'civector')
('jax', 8, 'statevector') ('numpy', 8, 'civector')
('jax', 12, 'statevector') ('numpy', 12, 'civector')
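The `groupby` + `idxmin` pattern can be reproduced in isolation. The sketch below rebuilds a small frame from the JAX rows of the results table (the only rows shown there) and applies the same selection; because the frame has a `MultiIndex`, `idxmin` returns the full `(backend, qubits, engine)` tuple. Since the NumPy and CuPy rows are left out, the winners by total time differ from the full comparison.

```python
import pandas as pd

# JAX rows copied from the results table above (times in seconds)
rows = [
    ("jax", 4, "tensornetwork", 0.673489, 0.001301, 0.699509),
    ("jax", 4, "statevector", 0.553246, 0.000826, 0.569758),
    ("jax", 4, "civector", 0.609218, 0.006143, 0.732082),
    ("jax", 4, "civector-large", 0.666502, 0.006435, 0.795211),
    ("jax", 4, "pyscf", 0.029404, 0.016277, 0.354943),
    ("jax", 8, "tensornetwork", 4.530416, 0.003479, 4.599999),
    ("jax", 8, "statevector", 2.340441, 0.001799, 2.376428),
    ("jax", 8, "civector", 0.654748, 0.007375, 0.802249),
    ("jax", 8, "civector-large", 1.635211, 0.006232, 1.759855),
    ("jax", 8, "pyscf", 0.107285, 0.094037, 1.988019),
]
df = pd.DataFrame(
    rows, columns=["backend", "qubits", "engine", "staging time", "run time", "total time"]
).set_index(["backend", "qubits", "engine"])

best = {}
for n_qubits, ddf in df.groupby("qubits"):  # "qubits" refers to an index level here
    # idxmin on a MultiIndex frame returns the whole (backend, qubits, engine) label
    best[n_qubits] = (ddf["run time"].idxmin(), ddf["total time"].idxmin())
print(best)
```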


In [8]:
df.loc[interesting_indices]

| backend | qubits | engine | staging time (s) | run time (s) | total time (s) |
|---|---|---|---|---|---|
| jax | 4 | statevector | 0.553246 | 0.000826 | 0.569758 |
| numpy | 4 | civector | 0.003418 | 0.003229 | 0.068 |
| jax | 8 | statevector | 2.340441 | 0.001799 | 2.376428 |
| numpy | 8 | civector | 0.005615 | 0.003976 | 0.085136 |
| jax | 12 | statevector | 7.562686 | 0.005127 | 7.665227 |
| numpy | 12 | civector | 0.017621 | 0.008579 | 0.189198 |


For every system size tested, JAX + statevector is the fastest in terms of run time.

However, once the staging time is taken into account (with `n_run = 20` evaluations), NumPy + civector is the most efficient.
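The trade-off can be made quantitative. With staging time $s$ and run time $r$, option A (lower run time) beats option B once $s_A + N r_A < s_B + N r_B$, i.e. once $N > (s_A - s_B)/(r_B - r_A)$. The sketch below applies this to the timings in the table above; `break_even_evaluations` is a hypothetical helper, and the result is only as reliable as those single measurements.

```python
import math


def break_even_evaluations(staging_a, run_a, staging_b, run_b):
    """Smallest number of evaluations N for which option A (lower run time)
    has a strictly lower total time staging + N * run than option B."""
    if run_a >= run_b:
        raise ValueError("option A must have the lower per-evaluation run time")
    x = (staging_a - staging_b) / (run_b - run_a)
    return max(math.floor(x) + 1, 0)


# (staging time, run time) in seconds, copied from the table above:
# first pair is JAX + statevector, second pair is NumPy + civector
timings = {
    4: ((0.553246, 0.000826), (0.003418, 0.003229)),
    8: ((2.340441, 0.001799), (0.005615, 0.003976)),
    12: ((7.562686, 0.005127), (0.017621, 0.008579)),
}
for n_qubits, ((s_a, r_a), (s_b, r_b)) in timings.items():
    print(n_qubits, break_even_evaluations(s_a, r_a, s_b, r_b))
```

With these numbers, JAX + statevector only pays off after roughly 229, 1073 and 2186 evaluations for 4, 8 and 12 qubits, far more than the 20 evaluations assumed in the benchmark.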

Note that this conclusion holds only for systems of up to 16 qubits.
For larger systems, CuPy + civector-large is the most scalable choice.
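The recommendations of this notebook can be condensed into a rule of thumb. The function below is a hypothetical summary of the measurements above, not a TenCirChem API; the `runtime_dominated` flag is an assumption standing for "the optimization will run far more evaluations than the break-even count", and the 16-qubit threshold is the one stated in the text.

```python
def suggest_backend_engine(n_qubits: int, runtime_dominated: bool) -> tuple:
    """Hypothetical rule of thumb summarising this benchmark (not a TenCirChem API)."""
    if n_qubits > 16:
        return ("cupy", "civector-large")  # most scalable choice for larger systems
    if runtime_dominated:
        return ("jax", "statevector")      # fastest per evaluation, but slow to stage
    return ("numpy", "civector")           # lowest total time including staging


backend, engine = suggest_backend_engine(12, runtime_dominated=False)
print(backend, engine)
# With TenCirChem installed, the choice would then be applied as in the benchmark:
# set_backend(backend); uccsd.energy_and_grad(params, engine=engine)
```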