In [None]:
import os
# XLA reads XLA_FLAGS when its backend is first initialised. JAX is pulled in
# by solax (and flax later on), and any JAX work done at import time would
# freeze the flags. Set the variable before any JAX-dependent import so the
# deterministic-ops flag is guaranteed to take effect.
os.environ['XLA_FLAGS'] = '--xla_gpu_deterministic_ops=true'

import sys

import numpy as np
import scipy as sp
# Older SciPy releases do not load subpackages on `import scipy`; import the
# one used below (`sp.sparse.linalg.eigsh`) explicitly.
import scipy.sparse.linalg
import matplotlib.pyplot as plt

# Make the local (not pip-installed) solax checkout importable.
sys.path.append("..")
import solax as sx

In [None]:
def build_bath(N_bath):
    """Return the bath level energies and hybridisation amplitudes.

    The levels are e_i = -2 cos(k_i) with k_i = i * pi / (N_bath + 1),
    i = 1..N_bath, so every energy lies strictly inside (-2, 2). The
    amplitudes are V_i = V0 * sqrt(1 - (e_i / 2)**2) with
    V0 = sqrt(20 / (N_bath + 1)). Since sum_i sin^2(k_i) = (N_bath + 1) / 2,
    sum_i V_i**2 equals 10 for every N_bath.

    Parameters
    ----------
    N_bath : int
        Number of bath levels.

    Returns
    -------
    (e_bath, V_bath) : tuple of two float arrays of length N_bath.
    """
    k = np.arange(1, N_bath + 1) * np.pi / (N_bath + 1)
    e_bath = -2 * np.cos(k)
    V_bath = np.sqrt(20 / (N_bath + 1)) * np.sqrt(1 - (e_bath / 2)**2)
    return e_bath, V_bath

In [None]:
def build_start_dets(N_bath):
    """Return the two starting determinants as occupation bit strings.

    Each string has length 2 * N_bath + 2 and contains N_bath + 1 ones
    (half filling). The two strings differ only by swapping the "01"/"10"
    pairs at positions 0-1 and N_bath+1 .. N_bath+2. With the interleaved
    layout used by the Hamiltonian cell (bath level k on orbitals 2 + 2k and
    3 + 2k) and odd N_bath, the second pair is the two orbitals of the middle
    (zero-energy) bath level.
    """
    occupied = "1" * (N_bath - 1)
    vacant = "0" * (N_bath - 1)
    det_a = f"01{occupied}10{vacant}"
    det_b = f"10{occupied}01{vacant}"
    return det_a, det_b

In [None]:
# Model parameters: interaction strength on the impurity and number of bath levels.
U = 10
N_bath = 21
e_bath, V_bath = build_bath(N_bath)
start_dets = build_start_dets(N_bath)

# Cheap sanity checks on the inputs built above.
assert len(e_bath) == len(V_bath) == N_bath
assert all(len(det) == 2 * N_bath + 2 for det in start_dets)

# Reuse the determinants computed above instead of rebuilding them.
basis_start = sx.Basis(start_dets)

# Orbital layout implied by the operators below: orbitals 0 and 1 are the
# impurity; bath level k sits on orbitals 2 + 2k and 3 + 2k. Impurity orbital
# 0 couples to the even bath orbitals and orbital 1 to the odd ones
# (presumably spin up / spin down -- confirm).
# NOTE(review): the tuple passed first to sx.Operator looks like a
# creation (1) / annihilation (0) pattern -- confirm against the solax docs.

# Assuming that pattern: U * n_0 * n_1
H_imp2 = sx.Operator(
    (1, 0, 1, 0),
    np.array([
        [0, 0, 1, 1]
    ]),
    np.array([U])
)

# Assuming that pattern: -U/2 * (n_0 + n_1)
H_imp1 = sx.Operator(
    (1, 0),
    np.array([
        [0, 0],
        [1, 1]
    ]),
    np.array([-U / 2, -U / 2])
)

# Adding the constant U/4 gives U * (n_0 - 1/2) * (n_1 - 1/2).
H_imp = H_imp2 + H_imp1 + U / 4

# Diagonal bath term: rows [2, 2], [3, 3], [4, 4], ... with energies
# e_0, e_0, e_1, e_1, ... (each level on its two orbitals).
H_bath = sx.Operator(
    (1, 0),
    np.arange(2, 2 * N_bath + 2).repeat(2).reshape(-1, 2),
    e_bath.repeat(2)
)

# Hybridisation pairs (0, 2), (1, 3), (0, 4), (1, 5), ...: impurity orbital
# j % 2 with bath orbital 2 + j, amplitude V_bath[j // 2].
H_hyb_posits = np.vstack([
    np.array([0, 1] * N_bath),
    np.arange(2, 2 * N_bath + 2)
]).T

H_hyb_nohc = sx.Operator(
    (1, 0),
    H_hyb_posits,
    V_bath.repeat(2)
)

# Full Hamiltonian; the hopping term is made Hermitian by adding its conjugate.
H = H_imp + H_bath + H_hyb_nohc + H_hyb_nohc.hconj

In [None]:
num_iterations = 4

# Start from the two seed determinants. Each pass diagonalises H in the
# current basis, reports the lowest eigenvalue, and then replaces the basis
# with H(basis) (presumably all determinants connected by H). The final pass
# skips the expansion.
basis = basis_start

for i in range(num_iterations):
    matrix = H.build_matrix(basis)
    # Smallest-algebraic eigenvalue of the sparse Hamiltonian matrix.
    eigvals = sp.sparse.linalg.eigsh(matrix.to_scipy(), k=1, which="SA")[0]
    energy = eigvals[0]

    basis_size = len(basis)
    print(f"Iteration: {i+1:<8d}" f"Basis size = {basis_size:<12d}" f"Energy = {energy}")

    if i == num_iterations - 1:
        break
    basis = H(basis)

## BasisClassifier

In [None]:
from flax import linen as nn
import optax

In [None]:
def nn_call_on_bits(x):
    """Forward pass of the basis classifier (passed to sx.BasisClassifier).

    `x` is reshaped into rows of two values, presumably the occupation bits
    of one determinant grouped in orbital pairs (TODO confirm against
    BasisClassifier). A width-2 "valid" convolution with 64 features and a
    1x1 convolution with 4 features are applied and the result is flattened.
    Three ReLU dense layers of widths dense_size, dense_size // 2 and
    dense_size // 4 follow, then a final 2-unit dense layer (presumably the
    logits of the two classes -- confirm).

    NOTE: reads the module-level global `dense_size` at call time. That name
    is defined in a later cell, so it must exist before the classifier is
    initialised.
    """
    # Convolutional front end over orbital pairs.
    h = x.reshape(-1, 2)
    h = nn.relu(nn.Conv(features=64, kernel_size=(2,), padding="valid")(h))
    h = nn.relu(nn.Conv(features=4, kernel_size=(1,), padding="valid")(h))
    h = h.reshape(-1)

    # Shrinking dense tower. The layers are created in the same order as
    # before, so flax's auto-generated names (Dense_0 .. Dense_3) are unchanged.
    for width in (dense_size, dense_size // 2, dense_size // 4):
        h = nn.relu(nn.Dense(features=width)(h))
    return nn.Dense(features=2)(h)

In [None]:
# Width of the first dense layer of nn_call_on_bits, which reads this global.
# It scales as 7 * sqrt(number of orbitals), with 2 * N_bath + 2 orbitals.
dense_size = int(np.sqrt(2 * (N_bath + 1)) * 7)
print(dense_size)

In [None]:
# Wrap the network definition in a solax classifier. Its parameters are set
# up by `classifier.initialize(...)` in the next cell.
classifier = sx.BasisClassifier(nn_call_on_bits)

In [None]:
# Adam optimiser with a fixed learning rate for training the classifier.
optimizer = optax.adam(learning_rate=0.005)

# Seeded key stream. Each `next(rand_keys)` in this notebook presumably
# yields a fresh key, so results are reproducible only for a top-to-bottom run.
rand_keys = sx.RandomKeys(seed=1234)
key_for_nn = next(rand_keys)

# Initialise the network (with the starting basis, presumably to fix the
# input shape) and show its layer summary.
classifier.initialize(key_for_nn, basis_start, optimizer)
classifier.print_summary()

In [None]:
# Keep the basis from the expansion loop as the "small" basis, and apply H
# once more to obtain the large basis.
basis_small, basis_big = basis, H(basis)
# Drop the loop's name so later cells cannot pick it up by accident.
# NOTE: this makes the cell raise NameError if it is re-run on its own.
del basis
print(len(basis_big))

In [None]:
# Candidate determinants: those in basis_big but not in basis_small
# (`%` appears to act as a set difference here -- confirm with solax docs).
candidates = basis_big % basis_small
print(len(candidates))

## BigBasisManager

In [None]:
# Ties the candidate set to the classifier. Used below to sample candidates,
# train the classifier and predict the important candidates.
bbm = sx.BigBasisManager(candidates, classifier)

In [None]:
# Desired number of selected determinants: 50 * sqrt(size of the big basis).
target_num = int(50 * np.sqrt(len(basis_big)))
print(target_num)

In [None]:
# Number of candidates drawn at random for the intermediate diagonalisation
# (about two thirds of target_num).
random_num = int(target_num / 1.5)
print(random_num)

In [None]:
# Random subset of `random_num` candidates. A new key is drawn from the stream
# on every run, so re-running this cell presumably changes the sample.
random_sel = bbm.sample_subbasis(next(rand_keys), random_num)

In [None]:
# Sanity check: the sample is an sx.Basis; print its size on the next line.
print(isinstance(random_sel, sx.Basis), len(random_sel), sep="\n")

In [None]:
# Basis for the intermediate diagonalisation: small basis plus random sample
# (`+` presumably a union -- confirm).
basis_diag = basis_small + random_sel
print(len(basis_diag))

In [None]:
# Lowest eigenpair of H restricted to basis_diag. `result` is
# (eigenvalues, eigenvectors) and its eigenvector is reused in the next cell.
matrix = H.build_matrix(basis_diag)
result = sp.sparse.linalg.eigsh(matrix.to_scipy(), k=1, which="SA")

energy = result[0][0]
print(f"Intermediate energy:\t{energy}")

In [None]:
# Ground-state eigenvector (first and only column) as a State on basis_diag.
eigenvec = result[1][:, 0]
state_diag = sx.State(basis_diag, eigenvec)

In [None]:
# Training data: the part of the state outside basis_small (presumably the
# random sample with its coefficients). The bare expression shows its size.
state_train = state_diag % basis_small
len(state_train)

In [None]:
# Coefficient-magnitude cutoff derived from target_num and the training state
# (presumably chosen so that roughly target_num candidates would pass -- confirm).
abs_coeff_cut = bbm.derive_abs_coeff_cut(target_num, state_train)
print(f"Cutoff:\t{abs_coeff_cut}")

In [None]:
# Training determinants kept by chopping at the cutoff (the "important" ones).
state_train_impt = state_train.chop(abs_coeff_cut)
print(len(state_train_impt))

In [None]:
# Fraction of training determinants that survive the cutoff.
print(len(state_train_impt) / len(state_train))

In [None]:
# Fraction of all candidates that the selection aims to keep; compare with the fraction above.
print(target_num / len(candidates))

In [None]:
# Training hyper-parameters. Early stopping uses a patience of 3
# (presumably epochs -- confirm).
training_options = dict(
    batch_size=256,
    epochs=200,
    early_stop=True,
    early_stop_params={"patience": 3},
)

# Train on the labelled training state, with labels set by the cutoff
# (presumably |coefficient| above vs. below -- confirm).
early_stopped = bbm.train_classifier(
    next(rand_keys), state_train, abs_coeff_cut, **training_options
)

print(early_stopped)

In [None]:
# Candidates the classifier predicts to be important, minus those already
# in the training set (their importance is already known).
nn_selected = bbm.predict_impt_subbasis(batch_size=256)
nn_selected = nn_selected % state_train.basis
print(len(nn_selected))

In [None]:
# All determinants judged important: NN predictions plus the important
# training determinants. The last print is the relative deviation from target_num.
basis_impt = nn_selected + state_train_impt.basis
print(len(basis_impt))
print(abs(len(basis_impt) - target_num) / target_num)

In [None]:
# Diagonalise H in the small basis extended by the selected determinants.
# `result` is reused in the next cell.
basis = basis_small + basis_impt

matrix = H.build_matrix(basis)
result = sp.sparse.linalg.eigsh(matrix.to_scipy(), k=1, which="SA")

energy = result[0][0]
print(f"Basis:\t{len(basis)}")
print(f"Energy:\t{energy}")

In [None]:
# Ground state on the final basis. Restrict it to the NN-selected part by
# removing basis_small and the training determinants, then check that this
# restriction's basis equals nn_selected.
eigenvec = result[1][:, 0]
state = sx.State(basis, eigenvec)

nn_selected_state = state % basis_small % state_train.basis
print(nn_selected_state.basis == nn_selected)

In [None]:
# NN-selected determinants whose coefficient in the final ground state
# actually passes the cutoff, i.e. the correct predictions.
nn_selected_right = nn_selected_state.chop(abs_coeff_cut).basis
print(len(nn_selected_right))
# Precision of the classifier. Print nan instead of raising
# ZeroDivisionError when the classifier selected nothing.
print(len(nn_selected_right) / len(nn_selected) if len(nn_selected) else float("nan"))

In [None]:
# Drop the wrongly selected determinants (below the cutoff) from the final basis.
nn_selected_wrong = nn_selected % nn_selected_right
basis_final = basis % nn_selected_wrong
print(len(basis_final))

## Save / load

In [None]:
# Persist the final basis under the path "solax_basis_" (removed by the clean-up cell at the end).
sx.save(basis_final, "solax_basis_")

In [None]:
# Round-trip check: the reloaded basis should compare equal to the saved one.
basis_loaded = sx.load("solax_basis_")
print(basis_loaded == basis_final)

In [None]:
# A dict of solax objects (a basis and the Hamiltonian) can be saved in one call.
dict_to_save = {
    "basis_from_nn": basis_final,
    "hamiltonian": H,
}

sx.save(dict_to_save, "solax_basis_ham_")

In [None]:
# Reload the dict and report the type each entry comes back as.
loaded_dict = sx.load("solax_basis_ham_")

for entry_name, entry_obj in loaded_dict.items():
    print(f"{entry_name} has type {type(entry_obj).__name__}")

In [None]:
# A nested container mixing plain Python values, a solax basis, NumPy arrays
# and the RandomKeys stream, saved in a single call.
# NOTE: `last_epochs` holds hard-coded example numbers. They are not read from
# the training run above.
dict_to_save = {
    "info": "This computation is a demonstration of SOLAX",
    "params": {
        "N_bath": N_bath,
        "U_impurity": U,
    },
    "basis_from_nn": basis_final,
    "last_epochs": {
        "epochs": np.array([22, 23, 24, 25, 26]),
        "accuracies": np.array([9.613544e-01, 9.513097e-01, 9.507517e-01,
                                9.456353e-01, 9.568901e-01]),
    },
    "random_keys_after": rand_keys,
}

sx.save(dict_to_save, "solax_big_save_")

In [None]:
# Reload the nested save; the bare expression displays the result.
sx.load("solax_big_save_")

In [None]:
# Persist the trained classifier's state under "solax_nn_".
classifier.save_state("solax_nn_")

In [None]:
# Fresh classifier with the same network. It is initialised with a fake key
# because its parameters are overwritten by load_state in the next cell
# (presumably initialize is needed to build the parameter structure -- confirm).
loaded_nn = sx.BasisClassifier(nn_call_on_bits)

fake_key = sx.RandomKeys.fake_key()
loaded_nn.initialize(fake_key, basis_start, optimizer)

In [None]:
# Restore the saved classifier state into the freshly initialised network.
loaded_nn.load_state("solax_nn_")

In [None]:
## clean up
# Set CLEAN_UP = True to delete everything written by the save cells above.
# This is a portable replacement for `!rm -r ...` and is off by default.
import shutil
from pathlib import Path

CLEAN_UP = False
SAVED_OUTPUTS = ["solax_basis_", "solax_basis_ham_", "solax_big_save_", "solax_nn_"]


def remove_saved_outputs(paths):
    """Delete each existing path in `paths`: directories recursively, files directly.

    Missing paths are skipped silently, like `rm -rf`.
    Returns the list of paths (as strings) that were removed.
    """
    removed = []
    for path in map(Path, paths):
        if path.is_dir():
            shutil.rmtree(path)
        elif path.exists():
            path.unlink()
        else:
            continue
        removed.append(str(path))
    return removed


if CLEAN_UP:
    print(remove_saved_outputs(SAVED_OUTPUTS))