In [3]:
import sys, os
# Force the CPU backend. This must be set before `import jax` below to take effect.
# NOTE(review): recent JAX documents JAX_PLATFORMS for this purpose; confirm for the installed version.
os.environ["JAX_PLATFORM_NAME"] = "cpu"
import jax
print(jax.devices())

import netket as nk
import netket.experimental as nkx
from netket.experimental.operator.fermion import destroy as c
from netket.experimental.operator.fermion import create as cdag
from netket.experimental.operator.fermion import number as nc
from netket.models.slater import Slater2nd, MultiSlater2nd
from scipy.sparse.linalg import eigsh
import jax.numpy as jnp
import matplotlib.pyplot as plt

import sys, os  # repeats the first line; harmless re-import
# NOTE(review): hardcoded absolute path to a shared helper library (source of `networks`,
# and presumably `netket_system`); consider making it configurable for portability.
sys.path.append('/global/homes/w/wttai/machine_learning/common_lib')
#from models import get_qwz_graph, get_qwz_Ham, get_qwz_exchange_graph

# NOTE(review): wildcard import; `LogSlaterDeterminant` used in later cells presumably comes from here.
from networks import *
from datetime import datetime

[CpuDevice(id=0)]


In [4]:
from netket_system import NetketQWZSystem

# System size: an Lx x Ly = (2L) x L lattice handed to NetketQWZSystem.
L = 2
Lx, Ly = 2 * L, L
N_fill = Lx * Ly // 2  # particle number; presumably half filling -- confirm against NetketQWZSystem
pbc = True             # boundary-condition flag forwarded as `pbc`

# Model parameters, bundled into `args` for NetketQWZSystem.
U, t, m = 8.0, 1.0, 5.0
bias = 1e-5
# NOTE(review): this rebinds the builtin `complex`; later cells read this name,
# so renaming it would require updating them as well.
complex = True
args = dict(U=U, t=t, m=m, bias=bias, complex=complex)

system = NetketQWZSystem(Lx, L2=Ly, N=N_fill, pbc=pbc, args=args)

# Objects reused by the VMC cells below.
Ham = system.get_hamiltonian()
hi = system.hi
exchange_graph = system.get_exchange_graph()

# Optimizer hyperparameters: SGD step size and SR diagonal shift.
learning_rate, diag_shift = 0.01, 0.01

# Exact-diagonalization reference with k = 6 eigenpairs. The printed eigenvalues
# come out in ascending order, so evals[0] serves as the ground-state energy.
evals, evecs = system.get_ed_data(k=6)
E_gs = evals[0]
print(evals)

Exact ground state energy: -28.462538109212584
[-28.46253811 -28.46123746 -26.92752973 -26.92752731 -26.2436958
 -26.24369332]


In [5]:
# Labels passed as the 2nd/3rd argument of system.corr_func
# (presumably the two orbital flavours of the QWZ model -- confirm).
s, p = 1, -1

# One correlation observable per index i in range(L**2), for each label pair.
# NOTE(review): L**2 is smaller than Lx * Ly = 2 * L**2; presumably i runs over
# unit cells rather than individual modes -- confirm against corr_func.
n_corr = L**2
corrs = {'ss': [0] * n_corr, 'sp': [0] * n_corr, 'pp': [0] * n_corr}
for i in range(n_corr):
    # Same per-index call order as before: pp, then sp, then ss.
    for corr_key, orb_a, orb_b in (('pp', p, p), ('sp', s, p), ('ss', s, s)):
        corrs[corr_key][i] = system.corr_func(i, orb_a, orb_b)

# Slater determinant

In [6]:
# Run configuration.
n_iter = 50                         # iterations for the full VMC run
# Cap on probe-run restarts. The probe loop runs only while restart_count < max_restarts,
# so with restart_count starting at 0 a value of -1 skips it entirely.
max_restarts = -1
restart_count, converged = 0, False  # restarts done so far / set once a probe run passes
print(f"Running on {jax.devices()}")
start_time = datetime.now()

# Variational ansatz: Slater determinant, with the `complex` flag from the configuration cell.
model = LogSlaterDeterminant(hi, complex=complex)

# Metropolis sampler that proposes exchanges between sites connected in `exchange_graph`.
# (nk.sampler.ExactSampler(hi) is an exact alternative for small Hilbert spaces.)
sa = nk.sampler.MetropolisExchange(hi, graph=exchange_graph)

# Plain SGD updates, preconditioned with stochastic reconfiguration (SR).
op = nk.optimizer.Sgd(learning_rate=learning_rate)
preconditioner = nk.optimizer.SR(diag_shift=diag_shift, holomorphic=complex)

# Function to run the VMC simulation
def run_simulation(n_iter=50, n_samples=2**12, n_discard_per_chain=32, obs=None, seed=None):
    """Run a Slater-determinant VMC optimization from a fresh variational state.

    Uses the notebook-level sampler ``sa``, model ``model``, Hamiltonian ``Ham``,
    optimizer ``op`` and SR ``preconditioner``. A new ``nk.vqs.MCState`` is built
    on every call without passing parameters, so NetKet initializes them anew;
    nothing is carried over between calls.

    Args:
        n_iter: number of optimization steps.
        n_samples: total Monte Carlo samples per step (default 2**12).
        n_discard_per_chain: samples discarded at the start of each Markov chain (default 32).
        obs: observables evaluated and logged at every step. ``None`` uses the
            global ``corrs`` dict, looked up at call time.
        seed: forwarded to ``MCState`` for reproducible parameter initialization.
            ``None`` keeps NetKet's default (a random seed).

    Returns:
        ``(gs, slater_log)``: the ``nk.VMC`` driver and the ``nk.logging.RuntimeLog``
        holding the per-iteration energy and observable data.
    """
    if obs is None:
        obs = corrs

    vstate = nk.vqs.MCState(
        sa,
        model,
        n_samples=n_samples,
        n_discard_per_chain=n_discard_per_chain,
        seed=seed,
    )
    gs = nk.VMC(Ham, op, variational_state=vstate, preconditioner=preconditioner)

    # In-memory logger so the optimization history can be inspected afterwards.
    slater_log = nk.logging.RuntimeLog()
    gs.run(n_iter=n_iter, out=slater_log, obs=obs)

    return gs, slater_log

# Probe loop: run short VMC optimizations, restarting while the final energy variance
# stays above 1. The body never runs when restart_count >= max_restarts initially
# (e.g. max_restarts = -1).
while restart_count < max_restarts and not converged:
    # run_simulation returns (driver, log); unpack so the log itself is indexed below.
    gs, slater_log = run_simulation()

    energy_variance = slater_log['Energy']['Variance']
    print(energy_variance)
    # Restart if the energy variance at the last iteration is too high (threshold: 1).
    if energy_variance[-1] > 1:
        print(f"Bad convergence detected. Restarting attempt {restart_count + 1} of {max_restarts}...")
        restart_count += 1
        if restart_count >= max_restarts:
            raise RuntimeError(f"Failed to converge after {max_restarts} attempts. Aborting the run.")
    else:
        converged = True
        print("Good convergence. Continuing with the full run...")



# Production run with the full iteration count. This executes whether or not
# the probe loop above ran.
print("Starting full simulation...")
gs, slater_log = run_simulation(n_iter=n_iter)
print("Full simulation completed.")

# Wall-clock time since start_time: setup, any probe runs, and the full run.
end_time = datetime.now()
elapsed_time = end_time - start_time
elapsed_seconds = elapsed_time.total_seconds()
print(f"Elapsed time in seconds: {elapsed_seconds} seconds", flush=True)

Running on [CpuDevice(id=0)]
Starting full simulation...


  0%|          | 0/50 [00:00<?, ?it/s]

Full simulation completed.
Elapsed time in seconds: 63.25752 seconds


In [11]:
# NOTE(review): this cell repeats the earlier Slater-determinant run verbatim
# (fresh model, sampler, optimizer and start_time); consider reusing one cell.

# Run configuration.
n_iter = 50                         # iterations for the full VMC run
# Cap on probe-run restarts. The probe loop runs only while restart_count < max_restarts,
# so with restart_count starting at 0 a value of -1 skips it entirely.
max_restarts = -1
restart_count, converged = 0, False  # restarts done so far / set once a probe run passes
print(f"Running on {jax.devices()}")
start_time = datetime.now()

# Variational ansatz: Slater determinant, with the `complex` flag from the configuration cell.
model = LogSlaterDeterminant(hi, complex=complex)

# Metropolis sampler that proposes exchanges between sites connected in `exchange_graph`.
# (nk.sampler.ExactSampler(hi) is an exact alternative for small Hilbert spaces.)
sa = nk.sampler.MetropolisExchange(hi, graph=exchange_graph)

# Plain SGD updates, preconditioned with stochastic reconfiguration (SR).
op = nk.optimizer.Sgd(learning_rate=learning_rate)
preconditioner = nk.optimizer.SR(diag_shift=diag_shift, holomorphic=complex)

# Function to run the VMC simulation
def run_simulation(n_iter=50, n_samples=2**12, n_discard_per_chain=32, obs=None, seed=None):
    """Run a Slater-determinant VMC optimization from a fresh variational state.

    Uses the notebook-level sampler ``sa``, model ``model``, Hamiltonian ``Ham``,
    optimizer ``op`` and SR ``preconditioner``. A new ``nk.vqs.MCState`` is built
    on every call without passing parameters, so NetKet initializes them anew;
    nothing is carried over between calls.

    Args:
        n_iter: number of optimization steps.
        n_samples: total Monte Carlo samples per step (default 2**12).
        n_discard_per_chain: samples discarded at the start of each Markov chain (default 32).
        obs: observables evaluated and logged at every step. ``None`` uses the
            global ``corrs`` dict, looked up at call time.
        seed: forwarded to ``MCState`` for reproducible parameter initialization.
            ``None`` keeps NetKet's default (a random seed).

    Returns:
        ``(gs, slater_log)``: the ``nk.VMC`` driver and the ``nk.logging.RuntimeLog``
        holding the per-iteration energy and observable data.
    """
    if obs is None:
        obs = corrs

    vstate = nk.vqs.MCState(
        sa,
        model,
        n_samples=n_samples,
        n_discard_per_chain=n_discard_per_chain,
        seed=seed,
    )
    gs = nk.VMC(Ham, op, variational_state=vstate, preconditioner=preconditioner)

    # In-memory logger so the optimization history can be inspected afterwards.
    slater_log = nk.logging.RuntimeLog()
    gs.run(n_iter=n_iter, out=slater_log, obs=obs)

    return gs, slater_log

# Probe loop: run short VMC optimizations, restarting while the final energy variance
# stays above 1. The body never runs when restart_count >= max_restarts initially
# (e.g. max_restarts = -1).
while restart_count < max_restarts and not converged:
    # run_simulation returns (driver, log); unpack so the log itself is indexed below.
    gs, slater_log = run_simulation()

    energy_variance = slater_log['Energy']['Variance']
    print(energy_variance)
    # Restart if the energy variance at the last iteration is too high (threshold: 1).
    if energy_variance[-1] > 1:
        print(f"Bad convergence detected. Restarting attempt {restart_count + 1} of {max_restarts}...")
        restart_count += 1
        if restart_count >= max_restarts:
            raise RuntimeError(f"Failed to converge after {max_restarts} attempts. Aborting the run.")
    else:
        converged = True
        print("Good convergence. Continuing with the full run...")



# Production run with the full iteration count. This executes whether or not
# the probe loop above ran.
print("Starting full simulation...")
gs, slater_log = run_simulation(n_iter=n_iter)
print("Full simulation completed.")

# Wall-clock time since start_time: setup, any probe runs, and the full run.
end_time = datetime.now()
elapsed_time = end_time - start_time
elapsed_seconds = elapsed_time.total_seconds()
print(f"Elapsed time in seconds: {elapsed_seconds} seconds", flush=True)

Starting full simulation...


  0%|          | 0/50 [00:00<?, ?it/s]

Full simulation completed.
Elapsed time in seconds: 44.029835 seconds
