In [1]:
import os
import sys

# Set XLA options before importing anything that may pull in JAX (solax
# presumably does): XLA reads XLA_FLAGS when the backend is initialised, so
# setting it only after the import is fragile. Existing flags are kept rather
# than overwritten.
os.environ["XLA_FLAGS"] = " ".join(
    filter(None, [os.environ.get("XLA_FLAGS", ""), "--xla_gpu_deterministic_ops=true"])
)

import numpy as np
import scipy as sp
# Explicit submodule import so `sp.sparse.linalg.eigsh` works even on SciPy
# versions that do not lazily load submodules on attribute access.
import scipy.sparse.linalg
import matplotlib.pyplot as plt

# Make the in-repo solax package importable from this notebook's directory.
sys.path.append("..")
import solax as sx

In [2]:
def build_bath(N_bath):
    """Return the on-site energies and couplings of an N_bath-site bath.

    Site k = 1..N_bath gets energy e_k = -2 cos(k * pi / (N_bath + 1)) and
    coupling V_k = V0 * sqrt(1 - (e_k / 2)**2), where V0 = sqrt(20 / (N_bath + 1)).

    Returns:
        (e_bath, V_bath): two float arrays of length N_bath.
    """
    n_plus_one = N_bath + 1
    k = np.arange(1, n_plus_one)
    e_bath = -2 * np.cos(k * np.pi / n_plus_one)
    V_bath = np.sqrt(20 / n_plus_one) * np.sqrt(1 - (e_bath / 2) ** 2)
    return e_bath, V_bath

In [3]:
def build_start_dets(N_bath):
    """Return the two starting determinants as bit strings of length 2 * N_bath + 2.

    Both strings have N_bath + 1 ones. They agree everywhere except at
    positions 0-1 and N_bath+1 - N_bath+2, where det1 has "01" ... "10" and
    det2 has "10" ... "01".

    Raises:
        ValueError: if N_bath < 1. For smaller values the padding runs would be
            empty and the strings would no longer have length 2 * N_bath + 2.
    """
    if N_bath < 1:
        raise ValueError(f"N_bath must be >= 1, got {N_bath}")
    filled = "1" * (N_bath - 1)
    empty = "0" * (N_bath - 1)
    det1 = f"01{filled}10{empty}"
    det2 = f"10{filled}01{empty}"
    return det1, det2

In [4]:
# Model parameters: interaction strength U and number of bath sites.
U = 10
N_bath = 21
e_bath, V_bath = build_bath(N_bath)
start_dets = build_start_dets(N_bath)

# Reuse the determinants built above rather than calling build_start_dets twice.
basis_start = sx.Basis(start_dets)

# Orbital layout implied by the position arrays below: 0 and 1 are the two
# impurity orbitals; bath site k (0-based) owns orbitals 2 + 2k and 3 + 2k.

# Two-body impurity term with strength U on orbitals (0, 0, 1, 1).
# NOTE(review): assuming the type tuple means 1 = creation, 0 = annihilation,
# this is U * n_0 * n_1 - confirm against sx.Operator.
H_imp2 = sx.Operator(
    (1, 0, 1, 0),
    np.array([
        [0, 0, 1, 1]
    ]),
    np.array([U])
)

# One-body impurity terms: -U/2 on orbital 0 and on orbital 1.
H_imp1 = sx.Operator(
    (1, 0),
    np.array([
        [0, 0],
        [1, 1]
    ]),
    np.array([-U / 2, -U / 2])
)

# Under the same assumption, adding the constant U/4 gives
# U * (n_0 - 1/2) * (n_1 - 1/2).
H_imp = H_imp2 + H_imp1 + U / 4

# Bath levels: rows [2, 2], [3, 3], ...; e_bath[k] sits on both orbitals of site k.
H_bath = sx.Operator(
    (1, 0),
    np.arange(2, 2 * N_bath + 2).repeat(2).reshape(-1, 2),
    e_bath.repeat(2)
)

# Hopping pairs [0, 2], [1, 3], [0, 4], [1, 5], ...: impurity orbital 0 couples
# to the even bath orbitals and orbital 1 to the odd ones, with amplitude V_bath[k].
H_hyb_posits = np.vstack([
    np.array([0, 1] * N_bath),
    np.arange(2, 2 * N_bath + 2)
]).T

H_hyb_nohc = sx.Operator(
    (1, 0),
    H_hyb_posits,
    V_bath.repeat(2)
)

# Full Hamiltonian; the hybridisation is added together with `.hconj`
# (presumably its Hermitian conjugate) so that H is Hermitian.
H = H_imp + H_bath + H_hyb_nohc + H_hyb_nohc.hconj

In [5]:
num_iterations = 4

# Start from the two reference determinants and repeatedly grow the basis by
# applying H; each pass reports the lowest eigenvalue of H in the current basis.
basis = basis_start

for i in range(num_iterations):
    basis_size = len(basis)

    # which="SA" asks eigsh for the smallest algebraic eigenvalue.
    matrix = H.build_matrix(basis)
    energy = sp.sparse.linalg.eigsh(matrix.to_scipy(), k=1, which="SA")[0][0]

    print(f"Iteration: {i + 1:<8d}Basis size = {basis_size:<12d}Energy = {energy}")

    # Keep the basis from the final pass as-is (no further expansion).
    if i + 1 < num_iterations:
        basis = H(basis)

2024-08-20 15:54:11.119129: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.5.82). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


Iteration: 1       Basis size = 2           Energy = -28.463653910211487
Iteration: 2       Basis size = 44          Energy = -30.19530217404953
Iteration: 3       Basis size = 684         Energy = -31.242891311317663
Iteration: 4       Basis size = 7084        Energy = -31.70729257122757


## BasisClassifier

In [6]:
from flax import linen as nn
import optax

In [7]:
def nn_call_on_bits(x):
    """Classifier network applied to a determinant's bits.

    The bits are grouped in pairs (``reshape(-1, 2)``) and passed through two
    ReLU convolutions (kernel width 2 with "valid" padding, then width 1).
    The result is flattened and fed through three ReLU dense layers of width
    dense_size, dense_size // 2 and dense_size // 4, ending in a Dense layer
    with 2 output features.

    NOTE(review): the dense widths come from the module-level ``dense_size``,
    which is defined in a later cell. It must exist before the network is
    first initialised (``classifier.initialize`` / ``loaded_nn.initialize``).
    """
    h = x.reshape(-1, 2)
    h = nn.relu(nn.Conv(features=64, kernel_size=(2,), padding="valid")(h))
    h = nn.relu(nn.Conv(features=4, kernel_size=(1,), padding="valid")(h))
    h = h.reshape(-1)

    # Tapering hidden layers, created in the same order as before so the
    # auto-generated submodule names (Dense_0, Dense_1, ...) are unchanged.
    for width in (dense_size, dense_size // 2, dense_size // 4):
        h = nn.relu(nn.Dense(features=width)(h))

    return nn.Dense(features=2)(h)

In [8]:
# Width of the first hidden dense layer: 7 * sqrt(bit length of a determinant),
# where each determinant has 2 * (N_bath + 1) bits.
dense_size = int(np.sqrt(2 * (N_bath + 1)) * 7)
print(dense_size)

46


In [9]:
# Wrap the network function in solax's basis classifier; its parameters are
# presumably created later by classifier.initialize.
classifier = sx.BasisClassifier(nn_call_on_bits)

In [10]:
# Seeded key stream. Every next(rand_keys) in this notebook draws from it, so
# adding or reordering draws presumably changes all later random results.
rand_keys = sx.RandomKeys(seed=1234)
key_for_nn = next(rand_keys)

optimizer = optax.adam(learning_rate=0.005)

# Initialise the network and attach the Adam optimizer, then print the layer
# summary. NOTE(review): basis_start presumably only supplies the input
# shape/dtype here.
classifier.initialize(key_for_nn, basis_start, optimizer)
classifier.print_summary()


[3m                                 Module Summary                                 [0m
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath   [0m[1m [0m┃[1m [0m[1mmodule[0m[1m [0m┃[1m [0m[1minputs        [0m[1m [0m┃[1m [0m[1moutputs       [0m[1m [0m┃[1m [0m[1mparams                 [0m[1m [0m┃
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         │ Module │ [2muint8[0m[2,6]     │ [2mfloat32[0m[2,2]   │                         │
├─────────┼────────┼────────────────┼────────────────┼─────────────────────────┤
│ Conv_0  │ Conv   │ [2muint8[0m[22,2]    │ [2mfloat32[0m[21,64] │ bias: [2mfloat32[0m[64]       │
│         │        │                │                │ kernel: [2mfloat32[0m[2,2,64] │
│         │        │                │                │                         │
│         │        │                │                │ [1m320 [0m[1;2m(1.3 KB)[0m         

In [11]:
# Keep the final basis from the expansion loop under a clearer name. `basis` is
# deleted so it cannot be used by accident before it is reassigned later.
# NOTE: this cell is not idempotent - re-running it raises NameError.
basis_small = basis
del basis

# One more application of H gives the full space to choose candidates from.
basis_big = H(basis_small)
print(len(basis_big))

58984


In [12]:
# Candidates: determinants of basis_big that are not already in basis_small
# (`%` presumably acts as set difference on bases).
candidates = basis_big % basis_small
print(len(candidates))

51900


## BigBasisManager

In [13]:
# Manager that samples from the candidates and applies the classifier to them.
bbm = sx.BigBasisManager(candidates, classifier)

In [14]:
# How many candidates to end up selecting: 50 * sqrt(len(basis_big)).
target_num = int(50 * np.sqrt(len(basis_big)))
print(target_num)

12143


In [15]:
# Size of the random training sample: two thirds of target_num.
random_num = int(2 * target_num / 3)
print(random_num)

8095


In [16]:
# Draw random_num candidates at random (consumes one key from rand_keys).
random_sel = bbm.sample_subbasis(next(rand_keys), random_num)

In [17]:
# Sanity check: the sample is a solax Basis of the requested size.
print(isinstance(random_sel, sx.Basis))
print(len(random_sel))

True
8095


In [18]:
# Diagonalisation basis: the current basis plus the random sample
# (`+` presumably takes the union).
basis_diag = basis_small + random_sel
print(len(basis_diag))

15179


In [19]:
# Lowest eigenpair of H in basis_diag (which="SA": smallest algebraic).
matrix = H.build_matrix(basis_diag)
result = sp.sparse.linalg.eigsh(matrix.to_scipy(), k=1, which="SA")

# eigsh returns (eigenvalues, eigenvectors); take the single eigenvalue.
energy = result[0][0]
print(f"Intermediate energy:\t{energy}")

Intermediate energy:	-31.720920015599123


In [20]:
# Pair the eigenvector (column 0 of eigsh's eigenvector array) with its basis.
eigenvec = result[1][:, 0]
state_diag = sx.State(basis_diag, eigenvec)

In [21]:
# Remove the basis_small determinants, leaving only the coefficients on the
# randomly sampled candidates; this state is the classifier's training data.
state_train = state_diag % basis_small
len(state_train)

8095

In [22]:
# Coefficient-magnitude cutoff derived from target_num and the training state;
# compare the fractions printed two and three cells below.
abs_coeff_cut = bbm.derive_abs_coeff_cut(target_num, state_train)
print(f"Cutoff:\t{abs_coeff_cut}")

Cutoff:	0.00020762079204639022


In [23]:
# "Important" training determinants: presumably those whose |coefficient|
# passes the cutoff.
state_train_impt = state_train.chop(abs_coeff_cut)
print(len(state_train_impt))

1893


In [24]:
# Fraction of the random sample that was kept as important.
print(len(state_train_impt) / len(state_train))

0.23384805435453984


In [25]:
# Fraction of all candidates we aim to keep, for comparison with the
# fraction printed in the previous cell.
print(target_num / len(candidates))

0.23396917148362234


In [26]:
# Train the classifier on the random sample. Labels are presumably whether each
# determinant's |coefficient| passes abs_coeff_cut. Training runs for at most
# 200 epochs and may stop earlier with early_stop=True (patience=3); the
# returned flag is printed.
early_stopped = bbm.train_classifier(
    next(rand_keys),
    state_train,
    abs_coeff_cut,
    batch_size=256,
    epochs=200,
    early_stop=True,
    early_stop_params={"patience": 3}
)

print(early_stopped)

Started:	accuracy=2.472703e-01
Epoch 0:	accuracy=7.993290e-01
Epoch 1:	accuracy=8.385327e-01
Epoch 2:	accuracy=8.559730e-01
Epoch 3:	accuracy=8.589044e-01
Epoch 4:	accuracy=8.705290e-01
Epoch 5:	accuracy=8.602961e-01
Epoch 6:	accuracy=8.913646e-01
Epoch 7:	accuracy=8.637385e-01
Epoch 8:	accuracy=9.182444e-01
Epoch 9:	accuracy=9.271260e-01
Epoch 10:	accuracy=9.412651e-01
Epoch 11:	accuracy=9.384749e-01
Epoch 12:	accuracy=9.446133e-01
Epoch 13:	accuracy=9.393086e-01
Epoch 14:	accuracy=9.540528e-01
Epoch 15:	accuracy=9.496827e-01
Epoch 16:	accuracy=9.491246e-01
Epoch 17:	accuracy=9.478675e-01
Epoch 18:	accuracy=9.518678e-01
True


In [27]:
# Let the trained classifier pick important determinants (presumably from the
# candidates held by bbm), then drop those already in the training sample.
nn_selected = bbm.predict_impt_subbasis(batch_size=256)
nn_selected = nn_selected % state_train.basis
print(len(nn_selected))

10711


In [28]:
# Important determinants = NN-predicted ones + those known important from training.
basis_impt = nn_selected + state_train_impt.basis
print(len(basis_impt))
# Relative deviation of the selected count from target_num.
print(abs(len(basis_impt) - target_num) / target_num)

12604
0.0379642592440089


In [29]:
# Enlarged basis: previous basis plus all important determinants; re-diagonalise.
basis = basis_small + basis_impt

matrix = H.build_matrix(basis)
result = sp.sparse.linalg.eigsh(matrix.to_scipy(), k=1, which="SA")

# eigsh returns (eigenvalues, eigenvectors); take the single eigenvalue.
energy = result[0][0]
print(f"Basis:\t{len(basis)}")
print(f"Energy:\t{energy}")

Basis:	19688
Energy:	-31.81742496812377


In [30]:
# Eigenvector in the enlarged basis as a solax State.
eigenvec = result[1][:, 0]
state = sx.State(basis, eigenvec)

# Strip basis_small and the training sample; what remains should be exactly
# the NN-selected determinants.
nn_selected_state = state % basis_small % state_train.basis
print(nn_selected_state.basis == nn_selected)

True


In [31]:
# NN selections whose coefficient in the new eigenvector passes the same cutoff,
# and their share of all NN selections.
nn_selected_right = nn_selected_state.chop(abs_coeff_cut).basis
print(len(nn_selected_right))
print(len(nn_selected_right) / len(nn_selected))

9039
0.8438987956306601


In [32]:
# Remove the NN selections that did not pass the cutoff to get the final basis.
nn_selected_wrong = nn_selected % nn_selected_right
basis_final = basis % nn_selected_wrong
print(len(basis_final))

18016


## Save / load

In [33]:
# Persist the final basis under the path "solax_basis_".
sx.save(basis_final, "solax_basis_")

In [34]:
# Round-trip check: the reloaded basis compares equal to the one saved.
basis_loaded = sx.load("solax_basis_")
print(basis_loaded == basis_final)

True


In [35]:
# A dict of solax objects (a Basis and an Operator) can be saved in one call.
dict_to_save = {"basis_from_nn": basis_final, "hamiltonian": H}

sx.save(dict_to_save, "solax_basis_ham_")

In [36]:
loaded_dict = sx.load("solax_basis_ham_")

# Show which Python type each saved entry comes back as.
for name, obj in loaded_dict.items():
    print(f"{name} has type {type(obj).__name__}")

basis_from_nn has type Basis
hamiltonian has type Operator


In [37]:
# One save call mixing strings, nested dicts, numpy arrays and solax objects.
# NOTE(review): `last_epochs` holds hard-coded example numbers. They do not
# come from the training run above, which stopped after epoch 18.
dict_to_save = {
    "info": "This computation is a demonstration of SOLAX",
    "params": {
        "N_bath": N_bath,
        "U_impurity": U,
    },
    "basis_from_nn": basis_final,
    "last_epochs": {
        "epochs": np.array([22, 23, 24, 25, 26]),
        "accuracies": np.array([9.613544e-01, 9.513097e-01, 9.507517e-01,
                                9.456353e-01, 9.568901e-01]),
    },
    "random_keys_after": rand_keys,
}

sx.save(dict_to_save, "solax_big_save_")

In [38]:
# Reload and display the nested structure saved above.
sx.load("solax_big_save_")

{'info': 'This computation is a demonstration of SOLAX',
 'params': {'N_bath': 21, 'U_impurity': 10},
 'basis_from_nn': Basis(_encoding=array([[127, 255, 254,   0,   0,   0],
        [223, 255, 254,   0,   0,   0],
        [247, 255, 254,   0,   0,   0],
        ...,
        [127, 239, 127,   2,   0,   0],
        [123, 247, 254,  64, 128,   0],
        [127, 247, 190,  18,   0,   0]], dtype=uint8), _bitlen=44),
 'last_epochs': {'epochs': array([22, 23, 24, 25, 26]),
  'accuracies': array([0.9613544, 0.9513097, 0.9507517, 0.9456353, 0.9568901])},
 'random_keys_after': RandomKeys(_key=Array([2236037635, 1430511502], dtype=uint32))}

In [39]:
# Save the classifier's state (presumably parameters and optimizer state).
classifier.save_state("solax_nn_")

In [40]:
# To restore, build a classifier with the same network function and initialise
# it so its parameter structure exists. The key's value presumably does not
# matter, because load_state in the next cell replaces the initialised state.
loaded_nn = sx.BasisClassifier(nn_call_on_bits)

fake_key = sx.RandomKeys.fake_key()
loaded_nn.initialize(fake_key, basis_start, optimizer)

In [41]:
# Replace the freshly initialised state with the one saved to "solax_nn_".
loaded_nn.load_state("solax_nn_")

In [42]:
## clean up
# Set CLEAN_UP = True to delete everything the save cells above wrote.
# Removal is done in Python instead of `!rm -r`: the shell escape spawns a
# process via os.forkpty, which is what emitted the warning seen in this cell.
import shutil
from pathlib import Path

CLEAN_UP = False
SAVED_PATHS = ("solax_basis_", "solax_basis_ham_", "solax_big_save_", "solax_nn_")


def remove_saved(paths):
    """Delete each path like `rm -r`: directories recursively, files and
    symlinks directly. Paths that do not exist are skipped."""
    for path in map(Path, paths):
        if path.is_dir() and not path.is_symlink():
            shutil.rmtree(path)
        elif path.exists() or path.is_symlink():
            path.unlink()


if CLEAN_UP:
    remove_saved(SAVED_PATHS)

  pid, fd = os.forkpty()
