In [1]:
import netket as nk
import json
from qutip import *
import numpy as np
import time
import multiprocessing as mp
from collections import OrderedDict
from pickle import dump
import os
import matplotlib.pyplot as plt
import scipy
from matplotlib import gridspec
from functools import reduce
# The bare 'seaborn' style was deprecated in matplotlib 3.6 (renamed 'seaborn-v0_8')
# and is unavailable in newer releases; fall back to the old name on older versions.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')
from scipy.stats import norm
import sys
sys.path.append("/Users/victorwei/Research projects/Neural Network Quantum State/penalty excited states")
import expect_grad_ex
import vmc_ex
import jax
import optax

from typing import Callable, Optional
from functools import partial

import jax
import jax.numpy as jnp

from netket.utils.types import PyTree, Array
import netket.jax as nkjax

from netket.vqs.mc import (
    kernels,
    check_hilbert,
    get_local_kernel_arguments,
    get_local_kernel,
)

from netket.vqs.mc.mc_state.state import MCState

from netket.operator import (
    AbstractOperator,
    DiscreteOperator,
    Squared,
    ContinuousOperator,
)

from netket.vqs.mc import (
    get_local_kernel_arguments,
    get_local_kernel,
)
from netket.stats import Stats, statistics

In [2]:
# Helper kernels and functions used by the rest of the notebook.
def batch_discrete_kernel(kernel):
    """
    Turn a single-sample discrete kernel into one that accepts a batch of samples.

    `kernel` is called as ``kernel(logpsi_1, logpsi_2, pars1, pars2, σ, (σp, mels))``
    for ONE sample σ together with its connected configurations σp and matrix
    elements mels. The returned function has the same signature, but maps over
    the leading axis of σ, σp and mels; the two log-amplitude functions and the
    two parameter pytrees are shared (not mapped) across the batch.
    """
    # Map only over σ and the (σp, mels) pair; broadcast everything else.
    per_sample = jax.vmap(
        kernel,
        in_axes=(None, None, None, None, 0, (0, 0)),
        out_axes=0,
    )

    def vmapped_kernel(logpsi_1, logpsi_2, pars1, pars2, σ, args):
        """
        local_value kernel for MCState and generic operators, evaluated on a batch of samples.
        """
        conn_configs, conn_mels = args

        # Connected configurations that are not already 3-D are regrouped as
        # (σ.shape[0], -1, σ.shape[-1]) so their leading axis lines up with σ;
        # the matrix elements follow the same grouping.
        if jnp.ndim(conn_configs) != 3:
            conn_configs = conn_configs.reshape((σ.shape[0], -1, σ.shape[-1]))
            conn_mels = conn_mels.reshape(conn_configs.shape[:-1])

        return per_sample(logpsi_1, logpsi_2, pars1, pars2, σ, (conn_configs, conn_mels))

    return vmapped_kernel

# Kernel used to evaluate (transition) expectation values: logpsi_1 / pars1 are
# evaluated on the connected configurations σp, and logpsi_2 / pars2 on the
# sample σ itself. The decorator batches it over samples.
@batch_discrete_kernel
def corr_kernel(logpsi_1: Callable, logpsi_2: Callable, pars1: PyTree, pars2: PyTree, σ: Array, args: PyTree):
    """
    Single-sample local value: sum over σ' of mel(σ, σ') * exp(logpsi_1(σ') - logpsi_2(σ)).

    `args` is the pair (σp, mel) of connected configurations and matrix
    elements for this sample. NOTE(review): if σ is drawn from |psi_2|^2, the
    sample mean of this quantity estimates <psi_2|O|psi_1> / <psi_2|psi_2>;
    confirm against the sampler that produces σ.
    """
    σp, mel = args
    log_ratio = logpsi_1(pars1, σp) - logpsi_2(pars2, σ)
    weighted = mel * jnp.exp(log_ratio)
    return jnp.sum(weighted)


def get_local_kernel_arguments(vstate: MCState, Ô: DiscreteOperator):  # noqa: F811
    """
    Collect the inputs a batched discrete kernel needs for `vstate` and `Ô`.

    Shadows the `get_local_kernel_arguments` imported from `netket.vqs.mc` at the
    top of the notebook (hence the F811 suppression).

    Returns:
        (σ, (σp, mels)): the current samples of `vstate`, and the padded
        connected configurations and matrix elements of `Ô` for those samples.
    """
    check_hilbert(vstate.hilbert, Ô.hilbert)

    samples = vstate.samples
    connected, matrix_elements = Ô.get_conn_padded(samples)
    return samples, (connected, matrix_elements)

In [3]:
# `expect` is jitted, so it takes plain apply functions and parameter pytrees
# instead of variational-state objects. The first four arguments are static
# (and therefore must be hashable).
@partial(jax.jit, static_argnums=(0, 1, 2, 3))
def expect(local_kernel, operator, left_apply_fun, right_apply_fun, left_para, right_para, left_model_state, right_model_state,
          σ, args):
    """
    Sample mean of `local_kernel` for a (left, right) pair of states.

    Args:
        local_kernel: static; a batched kernel called as
            local_kernel(right_apply_fun, left_apply_fun, right_vars, left_vars, σ, args).
        operator: static; not used in the body. σ and args are expected to have
            been computed from it beforehand (e.g. with get_local_kernel_arguments).
        left_apply_fun, right_apply_fun: static log-amplitude functions.
        left_para, right_para: parameter pytrees, wrapped as {"params": ...}.
        left_model_state, right_model_state: dicts merged into the variables
            next to "params".
        σ: samples; any leading axes are flattened to (-1, σ.shape[-1]).
        args: extra kernel arguments, forwarded unchanged.

    Returns:
        The `.mean` field of netket's `statistics` over the local values.
    """
    # The kernel expects a 2-D batch of samples.
    samples = σ if jnp.ndim(σ) == 2 else σ.reshape((-1, σ.shape[-1]))

    right_vars = {"params": right_para, **right_model_state}
    left_vars = {"params": left_para, **left_model_state}

    # The right state fills the kernel's logpsi_1 slot and the left state its
    # logpsi_2 slot (for corr_kernel: logpsi_1 is evaluated on σp, logpsi_2 on σ).
    local_values = local_kernel(right_apply_fun, left_apply_fun, right_vars, left_vars, samples, args)

    stats = statistics(local_values)
    # Under jax.jit this Python print runs at trace time, i.e. once per compilation.
    print("expectation done once")
    return stats.mean

In [None]:
def _transition_element(operator, left_state, right_state):
    """
    Estimate the (left, right) matrix element of `operator` with the jitted `expect`.

    `expect` cannot take variational-state objects, so this helper unpacks them.
    The samples and connected elements come from `left_state`. Both states
    contribute their apply function, parameters and model state. The local
    estimator is `corr_kernel`.
    """
    σ, args = get_local_kernel_arguments(left_state, operator)
    return expect(
        corr_kernel,
        operator,
        left_state._apply_fun,
        right_state._apply_fun,
        left_state.parameters,
        right_state.parameters,
        left_state.model_state,
        right_state.model_state,
        σ,
        args,
    )


def evo_corr(t_obver, obver, sta_list, eng_list, t_end, t_step, n_samples):
    """
    Time-dependent correlation function assembled from a list of variational states.

    For every time t = i * t_step, with i = 0, ..., round(t_end / t_step), this computes

        C(t) = sum_j exp(-1j * (E_j - E_0) * t) * M(t_obver; 0, j) * M(obver; j, 0)

    where E_j = eng_list[j] and M(O; l, r) = _transition_element(O, sta_list[l], sta_list[r]).
    Each M is a Monte Carlo estimate whose samples are drawn from sta_list[l].

    Args:
        t_obver: operator in the first factor (left state sta_list[0]).
        obver: operator in the second factor (right state sta_list[0]).
        sta_list: variational states; only the first len(eng_list) entries are used.
        eng_list: energies paired with sta_list; eng_list[0] is the reference energy.
        t_end: final time (reached up to rounding of t_end / t_step).
        t_step: time step.
        n_samples: number of Monte Carlo samples set on each used state.

    Returns:
        (times, correlation): two lists, each of length round(t_end / t_step) + 1.

    Side effects:
        Sets `n_samples` on the used states, and calls `reset()` on them at every
        time step, so that each step draws fresh samples.
    """
    num_step = int(np.around(t_end / t_step))
    n_states = len(eng_list)
    times = []  # not named `time`, which would shadow the imported module
    correlation = []

    # Re-set the number of samples.
    for k in range(n_states):
        sta_list[k].n_samples = n_samples

    for i in range(num_step + 1):
        t = t_step * i
        times.append(t)
        corr_t = complex(0, 0)
        for k in range(n_states):
            sta_list[k].reset()

        for j in range(n_states):
            phase = np.exp(-1j * (eng_list[j] - eng_list[0]) * t)
            corr_t += (
                phase
                * _transition_element(t_obver, sta_list[0], sta_list[j])
                * _transition_element(obver, sta_list[j], sta_list[0])
            )

        correlation.append(corr_t)

    return times, correlation