In [1]:
import os
# Runtime switches, written to the process environment before the NetKet
# import in the next cell.
# NOTE(review): presumably JAX/NetKet read these at import time — confirm.
os.environ.update(
    {
        "CUDA_VISIBLE_DEVICES": "0,1,2,3",  # device indices 0-3
        "NETKET_EXPERIMENTAL_SHARDING": "1",
        "NETKET_EXPERIMENTAL_FFT_AUTOCORRELATION": "1",
    }
)


In [2]:
import netket as nk
import numpy as np
import matplotlib.pyplot as plt

# Extend the module search path so the project-local modules imported below
# can be found.
# NOTE(review): absolute, user-specific paths — the notebook only runs on a
# machine where these directories exist.
import sys
sys.path.append('/scratch/samiz/GPU_ViT_Calcs/models')
sys.path.append('/scratch/samiz/GPU_ViT_Calcs/Logger_Pickle')

# The Hamiltonian builder H_afm_1d is used under the alias H_xyz_1d throughout.
from Afm_Model_functions import H_afm_1d as H_xyz_1d
# NOTE(review): star import. Simplified_ViT_TranslationSymmetric, Simplified_ViT
# and get_translations used in later cells are presumably provided here —
# confirm and import them explicitly.
from ViT_1d_translation import *
# Second module that also exposes a Simplified_ViT_TranslationSymmetric class;
# kept namespaced as `xavier` so the two do not clash.
import ViT_1d_translation_Xavier as xavier
from vmc_2spins_sampler import VMC_SR, grad_norms_callback

from json_log import PickledJsonLog

from scipy.sparse.linalg import eigsh

from optax.schedules import linear_schedule
from convergence_stopping import LateConvergenceStopping
from netket.callbacks import InvalidLossStopping
import pickle



In [40]:
# Hamiltonian settings for the 100-site chain. The keys are mapped onto the
# keyword arguments of H_xyz_1d below (and reused by the 64-site cell).
p_Ha = dict(
    L=100,
    J=1.0,
    Dxy=0.75,
    d=0.1,
    parity=0.,
    make_rot=False,
    exchange_XY=False,
    return_hi=True,
)

# The call is unpacked into (operator, Hilbert space); return_space is True.
Ha100, hi100 = H_xyz_1d(
    L=p_Ha['L'],
    J1=p_Ha['J'],
    Dxy=p_Ha['Dxy'],
    d=p_Ha['d'],
    parity=p_Ha['parity'],
    make_rotation=p_Ha['make_rot'],
    exchange_XY=p_Ha['exchange_XY'],
    return_space=p_Ha['return_hi'],
)

# Metropolis sampler driven by the Hamiltonian, 32 chains.
# Alternative tried earlier: nk.sampler.MetropolisLocal(hilbert=hi100, n_chains=8)
sampler_100 = nk.sampler.MetropolisHamiltonian(
    hilbert=hi100,
    hamiltonian=Ha100.to_jax_operator(),
    n_chains=32,
)

In [41]:
# ViT hyper-parameters for the 100-site system; they are forwarded to the
# model constructors as patch_size / embed_dim / heads / nl.
pvit_100 = dict(p=4, d=32, h=8, nl=1)

# Translation table for this system size and patch size, stored with the
# other model settings.
transl_arr = get_translations(number_nodes=100, patch_size=pvit_100['p'])
pvit_100['translations'] = transl_arr

# Both model variants receive exactly the same constructor arguments.
_vit_100_kwargs = dict(
    patch_size=pvit_100['p'],
    embed_dim=pvit_100['d'],
    heads=pvit_100['h'],
    nl=pvit_100['nl'],
    translations=pvit_100['translations'],
)
m_vit_100 = Simplified_ViT_TranslationSymmetric(**_vit_100_kwargs)
m_vit_xavier = xavier.Simplified_ViT_TranslationSymmetric(**_vit_100_kwargs)

# The variational state is built on the `xavier` variant; m_vit_100 is the
# alternative that could be passed as `model` instead.
vs_vit100_trasl = nk.vqs.MCState(
    sampler=sampler_100,
    model=m_vit_xavier,
    n_samples=2**10,
)

In [6]:
# Inspect the keys stored under the top-level 'Simplified_ViT_0' entry of the
# state's parameter tree.
vs_vit100_trasl.parameters['Simplified_ViT_0'].keys()

dict_keys(['Simplified_SelfAttention_0'])

In [7]:
# Display the state's n_parameters attribute (parameter count of the model).
vs_vit100_trasl.n_parameters

1384

In [None]:
# Optimisation settings consumed by the 100-site training loop.
p_opt = dict(
    # Constant learning rate. Alternative tried earlier:
    # linear_schedule(init_value=1e-3, end_value=1e-4, transition_steps=200, transition_begin=500)
    learning_rate=1e-3,
    # Passed on as diag_shift: linear schedule from 1e-4 to 1e-5 with
    # transition_steps=100 and transition_begin=300.
    # Alternative tried earlier: constant 1e-4.
    dshift=linear_schedule(
        init_value=1e-4,
        end_value=1e-5,
        transition_steps=100,
        transition_begin=300,
    ),
    n_iter=800,
    n_samples=2**12,
    chunk_size=2**10,
    holom=True,
)

In [4]:
# Output directory used as the prefix of every log file written below.
DataDir = '/scratch/samiz/GPU_ViT_Calcs/ViT_1d_Calcs/'

# In-memory logger handed to every gs.run(...) call in this notebook.
log_curr = nk.logging.RuntimeLog()

# Stopping callbacks passed to every gs.run(...) call:
#   Stopper1 monitors the 'mean' of the loss, Stopper2 its 'variance'.
# NOTE(review): the exact meaning of target / patience / start_from_step is
# defined by the callback classes — confirm against their implementations.
Stopper1 = InvalidLossStopping(monitor='mean', patience=20)
Stopper2 = LateConvergenceStopping(
    target=0.001,
    monitor='variance',
    patience=20,
    start_from_step=100,
)

In [45]:
# Train the translation-symmetric ViT on the 100-site chain, one run per value
# in `nls`. Each run writes two outputs sharing the same file prefix:
#   * a PickledJsonLog (with parameters saved every 10 steps), and
#   * the serialized in-memory RuntimeLog.
nls = [1]  # values of the model's `nl` argument to scan

# Loop-invariant: convert the Hamiltonian once instead of once per run.
ha100_jax = Ha100.to_jax_operator()

for nl in nls:
    run_prefix = DataDir + 'log_vit_100S_nl_{}'.format(nl)

    m_vit_100 = Simplified_ViT_TranslationSymmetric(
        patch_size=pvit_100['p'],
        embed_dim=pvit_100['d'],
        heads=pvit_100['h'],
        nl=nl,
        translations=pvit_100['translations'],
    )

    gs, vs = VMC_SR(
        hamiltonian=ha100_jax,
        sampler=sampler_100,
        model=m_vit_100,
        learning_rate=p_opt['learning_rate'],
        diag_shift=p_opt['dshift'],
        n_samples=p_opt['n_samples'],
        chunk_size=p_opt['chunk_size'],
        holomorph=p_opt['holom'],
        discards=8,
    )

    # A RuntimeLog accumulates history in memory. Sharing one instance between
    # runs would serialize the histories of earlier runs into this run's file,
    # so a fresh logger is created for every run. `log_curr` stays bound to the
    # most recent run's log afterwards.
    log_curr = nk.logging.RuntimeLog()
    StateLogger = PickledJsonLog(output_prefix=run_prefix, save_params_every=10, save_params=True)

    # NOTE(review): Stopper1/Stopper2 are the same callback objects in every
    # run of the notebook; if they keep counters between calls they should be
    # re-created per run as well — confirm.
    gs.run(
        out=(log_curr, StateLogger),
        n_iter=p_opt['n_iter'],
        callback=[grad_norms_callback, Stopper1, Stopper2],
    )

    log_curr.serialize(run_prefix)


number of parameters:  1384
using regular SR


  0%|          | 0/800 [00:00<?, ?it/s]

# Do the same for 64 spins (run on 4 GPUs)

In [21]:
# 64-site chain: same couplings as p_Ha, only the length differs.
Ha64, hi64 = H_xyz_1d(
    L=64,
    J1=p_Ha['J'],
    Dxy=p_Ha['Dxy'],
    d=p_Ha['d'],
    parity=p_Ha['parity'],
    make_rotation=p_Ha['make_rot'],
    exchange_XY=p_Ha['exchange_XY'],
    return_space=p_Ha['return_hi'],
)

# Metropolis sampler driven by the 64-site Hamiltonian, 32 chains.
sampler_64 = nk.sampler.MetropolisHamiltonian(
    hilbert=hi64,
    hamiltonian=Ha64.to_jax_operator(),
    n_chains=32,
)


In [28]:
# ViT hyper-parameters for the 64-site system.
pvit_64 = dict(p=4, d=32, h=8, nl=1)

# Translation table for 64 sites, stored with the other model settings.
transl_arr = get_translations(number_nodes=64, patch_size=pvit_64['p'])
pvit_64['translations'] = transl_arr

# Only the `xavier` variant is instantiated in this cell; the model from
# ViT_1d_translation is built inside the 64-site training loop.
m_vit_xavier_64 = xavier.Simplified_ViT_TranslationSymmetric(
    patch_size=pvit_64['p'],
    embed_dim=pvit_64['d'],
    heads=pvit_64['h'],
    nl=pvit_64['nl'],
    translations=pvit_64['translations'],
)

vs_vit64_trasl = nk.vqs.MCState(
    sampler=sampler_64,
    model=m_vit_xavier_64,
    n_samples=2**10,
)

In [29]:
# Sanity check on the translation table built for the 64-site system
# (the recorded output is (4, 64)).
transl_arr.shape

(4, 64)

In [38]:
# Optimisation settings consumed by the 64-site training loop.
p_opt_64 = dict(
    # Constant learning rate. Alternative tried earlier:
    # linear_schedule(init_value=1e-3, end_value=1e-4, transition_steps=150, transition_begin=400)
    learning_rate=1e-3,
    # Passed on as diag_shift: linear schedule from 1e-4 to 1e-5 with
    # transition_steps=100 and transition_begin=300.
    # Alternative tried earlier: constant 1e-4.
    dshift=linear_schedule(
        init_value=1e-4,
        end_value=1e-5,
        transition_steps=100,
        transition_begin=300,
    ),
    n_iter=800,
    n_samples=2**12,
    chunk_size=2**10,
    holom=True,
)

In [39]:
# Train the translation-symmetric ViT on the 64-site chain, one run per value
# in `nls`. Each run writes two outputs sharing the same file prefix:
#   * a PickledJsonLog (with parameters saved every 10 steps), and
#   * the serialized in-memory RuntimeLog.
nls = [1]  # values of the model's `nl` argument to scan

# Loop-invariant: convert the Hamiltonian once instead of once per run.
ha64_jax = Ha64.to_jax_operator()

for nl in nls:
    run_prefix = DataDir + 'log_vit_64S_nl_{}'.format(nl)

    m_vit_64 = Simplified_ViT_TranslationSymmetric(
        patch_size=pvit_64['p'],
        embed_dim=pvit_64['d'],
        heads=pvit_64['h'],
        nl=nl,
        translations=pvit_64['translations'],
    )

    gs, vs = VMC_SR(
        hamiltonian=ha64_jax,
        sampler=sampler_64,
        model=m_vit_64,
        learning_rate=p_opt_64['learning_rate'],
        diag_shift=p_opt_64['dshift'],
        n_samples=p_opt_64['n_samples'],
        chunk_size=p_opt_64['chunk_size'],
        holomorph=p_opt_64['holom'],
        discards=8,
    )

    # A RuntimeLog accumulates history in memory. Sharing one instance between
    # runs (here: with the other system sizes in this notebook) would
    # serialize their histories into this run's file, so a fresh logger is
    # created for every run. `log_curr` stays bound to the most recent run's
    # log afterwards.
    log_curr = nk.logging.RuntimeLog()
    StateLogger64 = PickledJsonLog(output_prefix=run_prefix, save_params_every=10, save_params=True)

    # NOTE(review): Stopper1/Stopper2 are the same callback objects in every
    # run of the notebook; if they keep counters between calls they should be
    # re-created per run as well — confirm.
    gs.run(
        out=(log_curr, StateLogger64),
        n_iter=p_opt_64['n_iter'],
        callback=[grad_norms_callback, Stopper1, Stopper2],
    )

    log_curr.serialize(run_prefix)

number of parameters:  1312
using regular SR


  0%|          | 0/800 [00:00<?, ?it/s]

In [None]:
# for i, nl in enumerate(nls):
#     for j, d in enumerate(ds):
#         for k, p in enumerate(ps):
#             print('Starting training for p = ', p, ' d = ', d, ' nl = ', nl)
#             transl_arr = get_translations(number_nodes=100, patch_size=p)
#             pvit_100['translations'] = transl_arr
#             pvit_100['p'] = p
#             pvit_100['d'] = d
#             pvit_100['nl'] = nl

#             m_vit_100 = Simplified_ViT_TranslationSymmetric(patch_size=pvit_100['p'], embed_dim=pvit_100['d'], heads=pvit_100['h'], nl=pvit_100['nl'], translations=pvit_100['translations'])
#             vs100 = nk.vqs.MCState(sampler=sampler_100, model=m_vit_100, n_samples=2**10, chunk_size=2**9)

#             log_curr = nk.logging.RuntimeLog()
#             gs100 = nk.driver.VMC(H=Ha100_SS, sampler=sampler_100, optimizer=sgd, n_samples=2**10, preconditioner=sr)

#             gs100.run(n_iter=600, out=log_curr)

#             log_curr.serialize(dataDir + '/Log_XYZ_S100_vit_transl_p{}_d{}_h{}_nl{}_SS'.format(pvit_100['p'], pvit_100['d'], pvit_100['h'], pvit_100['nl']))

# Do the same for 16 spins (run on 4 GPUs)

In [5]:
# Hamiltonian settings for the 16-site chain; the keys are mapped onto the
# keyword arguments of H_xyz_1d below.
p_Ha16 = dict(
    L=16,
    J=1.0,
    Dxy=0.75,
    d=0.1,
    parity=0.,
    make_rot=False,
    exchange_XY=False,
    return_hi=True,
)

# The call is unpacked into (operator, Hilbert space); return_space is True.
Ha16, hi16 = H_xyz_1d(
    L=p_Ha16['L'],
    J1=p_Ha16['J'],
    Dxy=p_Ha16['Dxy'],
    d=p_Ha16['d'],
    parity=p_Ha16['parity'],
    make_rotation=p_Ha16['make_rot'],
    exchange_XY=p_Ha16['exchange_XY'],
    return_space=p_Ha16['return_hi'],
)

# Metropolis sampler driven by the 16-site Hamiltonian, 32 chains.
sampler_16 = nk.sampler.MetropolisHamiltonian(
    hilbert=hi16,
    hamiltonian=Ha16.to_jax_operator(),
    n_chains=32,
)



In [6]:
# ViT hyper-parameters for the 16-site system (smaller embedding and fewer
# heads than the 64/100-site settings).
pvit_16 = dict(p=4, d=16, h=4, nl=1)

# Translation table for the chain length taken from p_Ha16, stored with the
# other model settings.
transl_arr = get_translations(number_nodes=p_Ha16['L'], patch_size=pvit_16['p'])
pvit_16['translations'] = transl_arr

# `xavier` variant of the model. No MCState is built from it in this cell, and
# the training cells visible below construct their own models.
m_vit_xavier_16 = xavier.Simplified_ViT_TranslationSymmetric(
    patch_size=pvit_16['p'],
    embed_dim=pvit_16['d'],
    heads=pvit_16['h'],
    nl=pvit_16['nl'],
    translations=pvit_16['translations'],
)

In [7]:
# Optimisation settings consumed by the 16-site training loops.
p_opt_16 = dict(
    # Constant learning rate. Alternative tried earlier:
    # linear_schedule(init_value=1e-3, end_value=1e-4, transition_steps=150, transition_begin=400)
    learning_rate=1e-3,
    # Passed on as diag_shift: linear schedule from 1e-4 to 1e-5 with
    # transition_steps=100 and transition_begin=300.
    # Alternative tried earlier: constant 1e-4.
    dshift=linear_schedule(
        init_value=1e-4,
        end_value=1e-5,
        transition_steps=100,
        transition_begin=300,
    ),
    n_iter=800,
    n_samples=2**12,
    chunk_size=2**10,
    holom=True,
)

In [11]:
# Train the translation-symmetric ViT on the 16-site chain, one run per value
# in `nls`. Each run writes two outputs sharing the same file prefix:
#   * a PickledJsonLog (with parameters saved every 10 steps), and
#   * the serialized in-memory RuntimeLog.
nls = [1]  # values of the model's `nl` argument to scan

# Loop-invariant: convert the Hamiltonian once instead of once per run.
ha16_jax = Ha16.to_jax_operator()

for nl in nls:
    # Single prefix for both outputs. The RuntimeLog used to be serialized as
    # 'log_vit_64S_nl_{}_transl', i.e. 16-site data labelled as a 64-site run.
    run_prefix = DataDir + 'log_vit_16S_nl_{}_transl'.format(nl)

    m_vit_16 = Simplified_ViT_TranslationSymmetric(
        patch_size=pvit_16['p'],
        embed_dim=pvit_16['d'],
        heads=pvit_16['h'],
        nl=nl,
        translations=pvit_16['translations'],
    )

    gs, vs = VMC_SR(
        hamiltonian=ha16_jax,
        sampler=sampler_16,
        model=m_vit_16,
        learning_rate=p_opt_16['learning_rate'],
        diag_shift=p_opt_16['dshift'],
        n_samples=p_opt_16['n_samples'],
        chunk_size=p_opt_16['chunk_size'],
        holomorph=p_opt_16['holom'],
        discards=8,
    )

    # A RuntimeLog accumulates history in memory. Sharing one instance between
    # runs would serialize the histories of earlier runs into this run's file,
    # so a fresh logger is created for every run. `log_curr` stays bound to the
    # most recent run's log afterwards.
    log_curr = nk.logging.RuntimeLog()
    # Named for the 16-site system so the 64-site logger handle is not clobbered.
    StateLogger16 = PickledJsonLog(output_prefix=run_prefix, save_params_every=10, save_params=True)

    # NOTE(review): Stopper1/Stopper2 are the same callback objects in every
    # run of the notebook; if they keep counters between calls they should be
    # re-created per run as well — confirm.
    gs.run(
        out=(log_curr, StateLogger16),
        n_iter=p_opt_16['n_iter'],
        callback=[grad_norms_callback, Stopper1, Stopper2],
    )

    log_curr.serialize(run_prefix)

number of parameters:  352
using regular SR


  0%|          | 0/800 [00:00<?, ?it/s]

In [12]:
# Train the plain (non translation-symmetric) Simplified_ViT on the 16-site
# chain, one run per value in `nls` (defined in the previous training cell).
# Each run writes a PickledJsonLog and the serialized RuntimeLog under the
# same file prefix.

# Loop-invariant: convert the Hamiltonian once instead of once per run.
ha16_jax = Ha16.to_jax_operator()

for nl in nls:
    # Single prefix for both outputs. The RuntimeLog used to be serialized as
    # 'log_vit_64S_nl_{}', which is the path written by the 64-site run and
    # therefore overwrote that run's log with 16-site data.
    run_prefix = DataDir + 'log_vit_16S_nl_{}'.format(nl)

    m_vit_16 = Simplified_ViT(
        patch_size=pvit_16['p'],
        embed_dim=pvit_16['d'],
        heads=pvit_16['h'],
        nl=nl,
    )

    gs, vs = VMC_SR(
        hamiltonian=ha16_jax,
        sampler=sampler_16,
        model=m_vit_16,
        learning_rate=p_opt_16['learning_rate'],
        diag_shift=p_opt_16['dshift'],
        n_samples=p_opt_16['n_samples'],
        chunk_size=p_opt_16['chunk_size'],
        holomorph=p_opt_16['holom'],
        discards=8,
    )

    # A RuntimeLog accumulates history in memory. Sharing one instance between
    # runs would serialize the histories of earlier runs into this run's file,
    # so a fresh logger is created for every run. `log_curr` stays bound to the
    # most recent run's log afterwards.
    log_curr = nk.logging.RuntimeLog()
    # Named for the 16-site system so the 64-site logger handle is not clobbered.
    StateLogger16 = PickledJsonLog(output_prefix=run_prefix, save_params_every=10, save_params=True)

    # NOTE(review): Stopper1/Stopper2 are the same callback objects in every
    # run of the notebook; if they keep counters between calls they should be
    # re-created per run as well — confirm.
    gs.run(
        out=(log_curr, StateLogger16),
        n_iter=p_opt_16['n_iter'],
        callback=[grad_norms_callback, Stopper1, Stopper2],
    )

    log_curr.serialize(run_prefix)

number of parameters:  352
using regular SR


  0%|          | 0/800 [00:00<?, ?it/s]

  x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)
