In [1]:
import os

# Process-wide environment configuration. It is applied here, before netket
# is imported in the following cell.
#   * CUDA_VISIBLE_DEVICES limits the CUDA devices visible to this process
#     to the ones with indices 0-3.
#   * The two NETKET_EXPERIMENTAL_* switches are set to '1'; presumably they
#     turn on NetKet's experimental sharding and FFT-based autocorrelation
#     estimate -- confirm against the installed NetKet version.
os.environ.update(
    {
        "CUDA_VISIBLE_DEVICES": "0,1,2,3",
        "NETKET_EXPERIMENTAL_SHARDING": "1",
        "NETKET_EXPERIMENTAL_FFT_AUTOCORRELATION": "1",
    }
)

# Here we run the different optimizations with sign structure

In [2]:
import netket as nk
import numpy as np
import matplotlib.pyplot as plt

import sys
sys.path.append('/scratch/samiz/GPU_ViT_Calcs/models')
sys.path.append('/scratch/samiz/GPU_ViT_Calcs/Logger_Pickle')

from Afm_Model_functions import H_afm_1d as H_xyz_1d
from ViT_1d_translation import *
import ViT_1d_translation_Xavier as xavier
from vmc_2spins_sampler import VMC_SR, grad_norms_callback

from json_log import PickledJsonLog

from scipy.sparse.linalg import eigsh

from optax.schedules import linear_schedule
from convergence_stopping import LateConvergenceStopping
from netket.callbacks import InvalidLossStopping
import pickle




In [3]:
# Directory under which every log file of this notebook is written.
DataDir = '/scratch/samiz/GPU_ViT_Calcs/ViT_1d_Calcs/'

# In-memory log object that is handed to every gs.run(...) call below.
log_curr = nk.logging.RuntimeLog()

# Early-stopping callbacks shared by all optimisation runs below.
# Stopper1 monitors the 'mean' of the loss, Stopper2 its 'variance'.
Stopper1 = InvalidLossStopping(monitor='mean', patience=20)
Stopper2 = LateConvergenceStopping(
    target=0.001,
    monitor='variance',
    patience=20,
    # presumably the convergence check is only active from step 100 on --
    # confirm in convergence_stopping.LateConvergenceStopping
    start_from_step=100,
)

## 16 Spins with Sign Structure:

In [4]:
# Arguments forwarded to H_afm_1d (imported as H_xyz_1d) for the 16-site chain.
p_Ha16 = dict(
    L=16,
    J=1.0,
    Dxy=0.75,
    d=0.1,
    parity=0.,
    make_rot=True,
    exchange_XY=True,
    return_hi=True,
)

# With return_space=True the helper hands back the operator together with
# its Hilbert space.
Ha16, hi16 = H_xyz_1d(
    L=p_Ha16['L'],
    J1=p_Ha16['J'],
    Dxy=p_Ha16['Dxy'],
    d=p_Ha16['d'],
    parity=p_Ha16['parity'],
    make_rotation=p_Ha16['make_rot'],
    exchange_XY=p_Ha16['exchange_XY'],
    return_space=p_Ha16['return_hi'],
)

# Metropolis sampler whose proposals are generated from the Hamiltonian itself.
sampler_16 = nk.sampler.MetropolisHamiltonian(
    hilbert=hi16,
    hamiltonian=Ha16.to_jax_operator(),
    n_chains=32,
)

In [5]:
# Optimisation settings used by the 16-site training cells.
p_opt_16 = dict(
    # Constant learning rate. A schedule that was tried instead:
    # linear_schedule(init_value=1e-3, end_value=1e-4, transition_steps=150, transition_begin=400)
    learning_rate=1e-3,
    # Diagonal shift passed to VMC_SR: linear ramp from 1e-4 to 1e-5 over
    # 100 steps, starting at step 300 (constant alternative: 1e-4).
    dshift=linear_schedule(init_value=1e-4, end_value=1e-5, transition_steps=100, transition_begin=300),
    n_iter=800,
    n_samples=2**12,
    chunk_size=2**10,
    holom=True,
)

# ViT hyper-parameters: 'p' is passed as patch_size, 'd' as embed_dim,
# 'h' as heads and 'nl' as nl.
pvit_16 = dict(p=4, d=16, h=4, nl=1)

# Translation table for the symmetric models, stored next to the other
# hyper-parameters.
transl_arr = get_translations(number_nodes=p_Ha16['L'], patch_size=pvit_16['p'])
pvit_16['translations'] = transl_arr

# Translation-symmetric ViT taken from the ViT_1d_translation_Xavier module.
m_vit_xavier_16 = xavier.Simplified_ViT_TranslationSymmetric(
    patch_size=pvit_16['p'],
    embed_dim=pvit_16['d'],
    heads=pvit_16['h'],
    nl=pvit_16['nl'],
    translations=pvit_16['translations'],
)

In [None]:
# Optimise the translation-symmetric ViT on the 16-site chain, once for every
# value of nl listed in nls.
nls = [1]

for nl in nls:
    m_vit_16 = Simplified_ViT_TranslationSymmetric(patch_size=pvit_16['p'], embed_dim=pvit_16['d'], heads=pvit_16['h'], nl=nl,
                                                     translations=pvit_16['translations'])


    gs, vs = VMC_SR(hamiltonian=Ha16.to_jax_operator(), sampler=sampler_16, model = m_vit_16, learning_rate=p_opt_16['learning_rate'], diag_shift=p_opt_16['dshift'],
                n_samples=p_opt_16['n_samples'], chunk_size=p_opt_16['chunk_size'], holomorph=p_opt_16['holom'], discards=8)

    # On-disk log for this run (the variable name is historical; this is the
    # 16-site logger).
    StateLogger64 = PickledJsonLog(output_prefix=DataDir + 'log_vit_16S_Sign_nl_{}_transl'.format(nl), save_params_every=10, save_params=True)

    # Start every run with an empty in-memory log. A RuntimeLog keeps
    # appending to the histories it already holds, so reusing one instance
    # across runs would make the file serialised below contain the data of
    # all earlier runs as well.
    log_curr = nk.logging.RuntimeLog()

    gs.run(out=(log_curr, StateLogger64), n_iter=p_opt_16['n_iter'], callback=[grad_norms_callback, Stopper1, Stopper2])

    log_curr.serialize(DataDir + 'log_vit_16S_Sign_nl_{}_transl'.format(nl))

Defaulting to a slow, possibly infinitely-looping method to generate random state of
the current Hilbert space with a custom constraint. Consider implementing a
custom `random_state` method for your constraint if this method takes a long time to
generate a random state.

in your code.

To generate a custom random_state dispatched method, you should use multiple dispatch
following the following syntax:

>>> import netket as nk
>>> from netket.utils import dispatch
>>>
>>> @dispatch.dispatch
>>> def random_state(hilb: netket.hilbert.spin.Spin,
                    constraint: vmc_2spins_sampler.Mtot_Parity_Constraint,
                    key,
                    batches: int,
                    *,
                    dtype=None):
>>>    # your custom implementation here
>>>    # You should return a batch of `batches` random states, with the given dtype.
>>>    # return jax.Array with shape (batches, hilb.size) and dtype dtype.



-------------------------------------------------------
Fo

number of parameters:  352
using regular SR


  0%|          | 0/800 [00:00<?, ?it/s]

  x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)


In [8]:
# Same optimisation as the preceding training cell, but with the plain
# Simplified_ViT (no translations argument). The list nls is the one defined
# by that preceding cell.
for j, nl in enumerate(nls):
    m_vit_16 = Simplified_ViT(
        patch_size=pvit_16['p'],
        embed_dim=pvit_16['d'],
        heads=pvit_16['h'],
        nl=nl,
    )

    gs, vs = VMC_SR(
        hamiltonian=Ha16.to_jax_operator(),
        sampler=sampler_16,
        model=m_vit_16,
        learning_rate=p_opt_16['learning_rate'],
        diag_shift=p_opt_16['dshift'],
        n_samples=p_opt_16['n_samples'],
        chunk_size=p_opt_16['chunk_size'],
        holomorph=p_opt_16['holom'],
        discards=8,
    )

    # On-disk log; parameters are requested every 10 steps.
    StateLogger64 = PickledJsonLog(
        output_prefix=DataDir + 'log_vit_16S_Sign_nl_{}'.format(nl),
        save_params_every=10,
        save_params=True,
    )

    gs.run(
        out=(log_curr, StateLogger64),
        n_iter=p_opt_16['n_iter'],
        callback=[grad_norms_callback, Stopper1, Stopper2],
    )

    # Serialising the in-memory log is switched off for this run:
    # log_curr.serialize(DataDir + 'log_vit_16S_Sign_nl_{}'.format(nl))

number of parameters:  352
using regular SR


  0%|          | 0/800 [00:00<?, ?it/s]

## Now for 64 Spins

In [9]:
# Arguments forwarded to H_afm_1d (imported as H_xyz_1d) for the 64-site chain.
p_Ha64 = dict(
    L=64,
    J=1.0,
    Dxy=0.75,
    d=0.1,
    parity=0.,
    make_rot=True,
    exchange_XY=True,
    return_hi=True,
)

# With return_space=True the helper hands back the operator together with
# its Hilbert space.
Ha64, hi64 = H_xyz_1d(
    L=p_Ha64['L'],
    J1=p_Ha64['J'],
    Dxy=p_Ha64['Dxy'],
    d=p_Ha64['d'],
    parity=p_Ha64['parity'],
    make_rotation=p_Ha64['make_rot'],
    exchange_XY=p_Ha64['exchange_XY'],
    return_space=p_Ha64['return_hi'],
)

# Metropolis sampler whose proposals are generated from the Hamiltonian itself.
sampler_64 = nk.sampler.MetropolisHamiltonian(
    hilbert=hi64,
    hamiltonian=Ha64.to_jax_operator(),
    n_chains=32,
)

In [11]:
# Optimisation settings used by the 64-site training cells.
p_opt_64 = {
    # 'learning_rate' : linear_schedule(init_value=1e-3, end_value=1e-4, transition_steps=200, transition_begin=500),
    'learning_rate': 1e-3,
    # 'dshift' : 1e-4,
    # Diagonal shift: linear ramp from 1e-4 to 1e-5 over 100 steps, starting at step 300.
    'dshift': linear_schedule(init_value=1e-4, end_value=1e-5, transition_steps=100, transition_begin=300),
    'n_iter' : 800,
    'n_samples' : 2**12,
    'chunk_size' : 2**10,
    'holom' : True,

}

# ViT hyper-parameters: 'p' is passed as patch_size, 'd' as embed_dim,
# 'h' as heads and 'nl' as nl.
pvit_64 = {
    'p' : 4,
    'd' : 32,
    'h' : 8,
    'nl' : 1, 
}

# Translation table for the symmetric models, stored next to the other
# hyper-parameters.
transl_arr = get_translations(number_nodes=p_Ha64['L'], patch_size=pvit_64['p'])
pvit_64['translations'] = transl_arr

# m_vit_64 = Simplified_ViT_TranslationSymmetric(patch_size=pvit_64['p'], embed_dim=pvit_64['d'], heads=pvit_64['h'], nl=pvit_64['nl'],
                                                #  translations=pvit_64['translations'])

# Fixed: the number of heads is now read from pvit_64. It used to be taken
# from pvit_16 (4 instead of 8), which also made this cell depend on the
# 16-site configuration cell having been run.
m_vit_xavier_64 = xavier.Simplified_ViT_TranslationSymmetric(patch_size=pvit_64['p'], embed_dim=pvit_64['d'], heads=pvit_64['h'], nl=pvit_64['nl'],
                                                 translations=pvit_64['translations'])


In [12]:
# Optimise the translation-symmetric ViT on the 64-site chain, once for every
# value of nl listed in nls.
nls = [1]

for nl in nls:
    m_vit_64 = Simplified_ViT_TranslationSymmetric(patch_size=pvit_64['p'], embed_dim=pvit_64['d'], heads=pvit_64['h'], nl=nl,
                                                     translations=pvit_64['translations'])


    gs, vs = VMC_SR(hamiltonian=Ha64.to_jax_operator(), sampler=sampler_64, model = m_vit_64, learning_rate=p_opt_64['learning_rate'], diag_shift=p_opt_64['dshift'],
                n_samples=p_opt_64['n_samples'], chunk_size=p_opt_64['chunk_size'], holomorph=p_opt_64['holom'], discards=8)

    # Fixed: the prefix used to end in 'nl_{}transl' (underscore missing),
    # unlike every other '_transl' prefix in this notebook. Files written by
    # earlier executions keep the old name.
    StateLogger64 = PickledJsonLog(output_prefix=DataDir + 'log_vit_64S_Sign_nl_{}_transl'.format(nl), save_params_every=10, save_params=True)

    gs.run(out=(log_curr, StateLogger64), n_iter=p_opt_64['n_iter'], callback=[grad_norms_callback, Stopper1, Stopper2])

    # log_curr.serialize(DataDir + 'log_vit_64S_Sign_nl_{}_transl'.format(nl))

Defaulting to a slow, possibly infinitely-looping method to generate random state of
the current Hilbert space with a custom constraint. Consider implementing a
custom `random_state` method for your constraint if this method takes a long time to
generate a random state.

in your code.

To generate a custom random_state dispatched method, you should use multiple dispatch
following the following syntax:

>>> import netket as nk
>>> from netket.utils import dispatch
>>>
>>> @dispatch.dispatch
>>> def random_state(hilb: netket.hilbert.spin.Spin,
                    constraint: vmc_2spins_sampler.Mtot_Parity_Constraint,
                    key,
                    batches: int,
                    *,
                    dtype=None):
>>>    # your custom implementation here
>>>    # You should return a batch of `batches` random states, with the given dtype.
>>>    # return jax.Array with shape (batches, hilb.size) and dtype dtype.



-------------------------------------------------------
Fo

number of parameters:  1312
using regular SR


  0%|          | 0/800 [00:00<?, ?it/s]

  x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)


In [18]:
# Optimise the plain Simplified_ViT (no translations argument) on the
# 64-site chain, once for every value of nl listed in nls.
nls = [1]

for j, nl in enumerate(nls):
    m_vit_64 = Simplified_ViT(
        patch_size=pvit_64['p'],
        embed_dim=pvit_64['d'],
        heads=pvit_64['h'],
        nl=nl,
    )

    gs, vs = VMC_SR(
        hamiltonian=Ha64.to_jax_operator(),
        sampler=sampler_64,
        model=m_vit_64,
        learning_rate=p_opt_64['learning_rate'],
        diag_shift=p_opt_64['dshift'],
        n_samples=p_opt_64['n_samples'],
        chunk_size=p_opt_64['chunk_size'],
        holomorph=p_opt_64['holom'],
        discards=8,
    )

    # On-disk log; parameters are requested every 10 steps.
    StateLogger64 = PickledJsonLog(
        output_prefix=DataDir + 'log_vit_64S_Sign_nl_{}'.format(nl),
        save_params_every=10,
        save_params=True,
    )

    # Note: this run uses a literal 500 iterations rather than p_opt_64['n_iter'].
    gs.run(
        out=(log_curr, StateLogger64),
        n_iter=500,
        callback=[grad_norms_callback, Stopper1, Stopper2],
    )

    # Serialising the in-memory log is switched off for this run:
    # log_curr.serialize(DataDir + 'log_vit_64S_Sign_nl_{}'.format(nl))

number of parameters:  1312
using regular SR


  0%|          | 0/500 [00:00<?, ?it/s]

## Now for 100 Spins

In [4]:
# Arguments forwarded to H_afm_1d (imported as H_xyz_1d) for the 100-site chain.
p_Ha = dict(
    L=100,
    J=1.0,
    Dxy=0.75,
    d=0.1,
    parity=0.,
    make_rot=True,
    exchange_XY=True,
    return_hi=True,
)

# With return_space=True the helper hands back the operator together with
# its Hilbert space.
Ha100, hi100 = H_xyz_1d(
    L=p_Ha['L'],
    J1=p_Ha['J'],
    Dxy=p_Ha['Dxy'],
    d=p_Ha['d'],
    parity=p_Ha['parity'],
    make_rotation=p_Ha['make_rot'],
    exchange_XY=p_Ha['exchange_XY'],
    return_space=p_Ha['return_hi'],
)

# Metropolis sampler whose proposals are generated from the Hamiltonian itself.
sampler_100 = nk.sampler.MetropolisHamiltonian(
    hilbert=hi100,
    hamiltonian=Ha100.to_jax_operator(),
    n_chains=32,
)

In [21]:
# First ViT configuration for the 100-site chain (patch size 4): 'p' is
# passed as patch_size, 'd' as embed_dim, 'h' as heads and 'nl' as nl.
pvit_100_Vers1 = {
    'p' : 4,
    'd' : 32,
    'h' : 8,
    'nl' : 1, 
}

# The chain length is read from the Hamiltonian parameters (as in the 16- and
# 64-site cells) instead of being hard-coded, so the translation table cannot
# silently disagree with the system that is actually simulated.
transl_arr = get_translations(number_nodes=p_Ha['L'], patch_size=pvit_100_Vers1['p'])
pvit_100_Vers1['translations'] = transl_arr

m_vit_100 = Simplified_ViT_TranslationSymmetric(patch_size=pvit_100_Vers1['p'], embed_dim=pvit_100_Vers1['d'], heads=pvit_100_Vers1['h'], nl=pvit_100_Vers1['nl'],
                                                 translations=pvit_100_Vers1['translations'])

# Same architecture, taken from the ViT_1d_translation_Xavier module.
m_vit_xavier = xavier.Simplified_ViT_TranslationSymmetric(patch_size=pvit_100_Vers1['p'], embed_dim=pvit_100_Vers1['d'], heads=pvit_100_Vers1['h'], nl=pvit_100_Vers1['nl'],
                                                 translations=pvit_100_Vers1['translations'])

# Variational state built on the Xavier variant; the line below is the
# alternative using m_vit_100.
# vs_vit100_trasl = nk.vqs.MCState(sampler=sampler_100, model=m_vit_100, n_samples=2**10)
vs_vit100_trasl = nk.vqs.MCState(sampler=sampler_100, model=m_vit_xavier, n_samples=2**10)

In [7]:
# Optimisation settings used by the 100-site training cells.
p_opt_100 = dict(
    # Learning rate: linear ramp from 1e-3 to 1e-4 over 100 steps, starting
    # at step 400. A schedule that was tried instead:
    # linear_schedule(init_value=1e-3, end_value=1e-4, transition_steps=200, transition_begin=500)
    learning_rate=linear_schedule(init_value=1e-3, end_value=1e-4, transition_steps=100, transition_begin=400),
    # Diagonal shift: linear ramp from 1e-4 to 1e-5 over 100 steps, starting
    # at step 300 (constant alternative: 1e-4).
    dshift=linear_schedule(init_value=1e-4, end_value=1e-5, transition_steps=100, transition_begin=300),
    n_iter=800,
    n_samples=2**12,
    chunk_size=2**10,
    holom=True,
)

In [25]:
# Optimise the translation-symmetric ViT (configuration pvit_100_Vers1) on the
# 100-site chain, once for every value of nl listed in nls.
nls = [1]

for nl in nls:
    m_vit_100 = Simplified_ViT_TranslationSymmetric(patch_size=pvit_100_Vers1['p'], embed_dim=pvit_100_Vers1['d'], heads=pvit_100_Vers1['h'], nl=nl,
                                                     translations=pvit_100_Vers1['translations'])


    gs, vs = VMC_SR(hamiltonian=Ha100.to_jax_operator(), sampler=sampler_100, model = m_vit_100, learning_rate=p_opt_100['learning_rate'], diag_shift=p_opt_100['dshift'],
                n_samples=p_opt_100['n_samples'], chunk_size=p_opt_100['chunk_size'], holomorph=p_opt_100['holom'], discards=8)

    StateLogger = PickledJsonLog(output_prefix=DataDir + 'log_vit_100S_Sign_nl_{}_p{}_d{}_transl'.format(nl, pvit_100_Vers1['p'], pvit_100_Vers1['d']), save_params_every=10, save_params=True)

    # Start every run with an empty in-memory log. A RuntimeLog keeps
    # appending to the histories it already holds, so the instance shared
    # with the 16- and 64-site cells would otherwise carry their data into
    # the file serialised below.
    log_curr = nk.logging.RuntimeLog()

    gs.run(out=(log_curr, StateLogger), n_iter=p_opt_100['n_iter'], callback=[grad_norms_callback, Stopper1, Stopper2])

    log_curr.serialize(DataDir + 'log_vit_100S_Sign_nl_{}_p{}_d{}_transl'.format(nl, pvit_100_Vers1['p'], pvit_100_Vers1['d']))

    

number of parameters:  1384
using regular SR


  0%|          | 0/800 [00:00<?, ?it/s]

In [5]:
# Second ViT configuration for the 100-site chain (patch size 10): 'p' is
# passed as patch_size, 'd' as embed_dim, 'h' as heads and 'nl' as nl.
pvit_100_Vers2 = {
    'p' : 10,
    'd' : 32,
    'h' : 8,
    'nl' : 1, 
}

# The chain length is read from the Hamiltonian parameters (as in the 16- and
# 64-site cells) instead of being hard-coded, so the translation table cannot
# silently disagree with the system that is actually simulated.
transl_arr = get_translations(number_nodes=p_Ha['L'], patch_size=pvit_100_Vers2['p'])
pvit_100_Vers2['translations'] = transl_arr

m_vit_100 = Simplified_ViT_TranslationSymmetric(patch_size=pvit_100_Vers2['p'], embed_dim=pvit_100_Vers2['d'], heads=pvit_100_Vers2['h'], nl=pvit_100_Vers2['nl'],
                                                 translations=pvit_100_Vers2['translations'])


In [8]:
# Optimise the translation-symmetric ViT (configuration pvit_100_Vers2) on the
# 100-site chain, once for every value of nl listed in nls.
nls = [1]

for j, nl in enumerate(nls):
    m_vit_100 = Simplified_ViT_TranslationSymmetric(
        patch_size=pvit_100_Vers2['p'],
        embed_dim=pvit_100_Vers2['d'],
        heads=pvit_100_Vers2['h'],
        nl=nl,
        translations=pvit_100_Vers2['translations'],
    )

    gs, vs = VMC_SR(
        hamiltonian=Ha100.to_jax_operator(),
        sampler=sampler_100,
        model=m_vit_100,
        learning_rate=p_opt_100['learning_rate'],
        diag_shift=p_opt_100['dshift'],
        n_samples=p_opt_100['n_samples'],
        chunk_size=p_opt_100['chunk_size'],
        holomorph=p_opt_100['holom'],
        discards=8,
    )

    # On-disk log; the prefix encodes nl, the patch size and the embedding dimension.
    StateLogger = PickledJsonLog(
        output_prefix=DataDir + 'log_vit_100S_Sign_nl_{}_p{}_d{}_transl'.format(nl, pvit_100_Vers2['p'], pvit_100_Vers2['d']),
        save_params_every=10,
        save_params=True,
    )

    gs.run(
        out=(log_curr, StateLogger),
        n_iter=p_opt_100['n_iter'],
        callback=[grad_norms_callback, Stopper1, Stopper2],
    )

    # The in-memory log is not serialised for this run (the call is left out on purpose of
    # keeping the cell's behaviour unchanged; only StateLogger writes to disk here).


Defaulting to a slow, possibly infinitely-looping method to generate random state of
the current Hilbert space with a custom constraint. Consider implementing a
custom `random_state` method for your constraint if this method takes a long time to
generate a random state.

in your code.

To generate a custom random_state dispatched method, you should use multiple dispatch
following the following syntax:

>>> import netket as nk
>>> from netket.utils import dispatch
>>>
>>> @dispatch.dispatch
>>> def random_state(hilb: netket.hilbert.spin.Spin,
                    constraint: vmc_2spins_sampler.Mtot_Parity_Constraint,
                    key,
                    batches: int,
                    *,
                    dtype=None):
>>>    # your custom implementation here
>>>    # You should return a batch of `batches` random states, with the given dtype.
>>>    # return jax.Array with shape (batches, hilb.size) and dtype dtype.



-------------------------------------------------------
Fo

number of parameters:  1456
using regular SR


  0%|          | 0/800 [00:00<?, ?it/s]

  x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)


In [None]:
# nls = [1]

# for j, nl in enumerate(nls):
#     m_vit_100 = Simplified_ViT(patch_size=pvit_100_Vers2['p'], embed_dim=pvit_100_Vers2['d'], heads=pvit_100_Vers2['h'], nl=nl)
                                                     


#     gs, vs = VMC_SR(hamiltonian=Ha100.to_jax_operator(), sampler=sampler_100, model = m_vit_100, learning_rate=p_opt_100['learning_rate'], diag_shift=p_opt_100['dshift'],
#                 n_samples=p_opt_100['n_samples'], chunk_size=p_opt_100['chunk_size'], holomorph=p_opt_100['holom'], discards=8)

#     StateLogger = PickledJsonLog(output_prefix=DataDir + 'log_vit_100S_Sign_nl_{}_p{}_d{}'.format(nl, pvit_100_Vers2['p'], pvit_100_Vers2['d']), save_params_every=10, save_params=True)

#     gs.run(out=(log_curr, StateLogger), n_iter=p_opt_100['n_iter'], callback=[grad_norms_callback, Stopper1, Stopper2])