In [10]:
import jax
import jax.numpy as jnp
from model import Model

In [11]:
@jax.jit
def vdot_computation(spin1, spin2):
    """Return the overlap <spin1|spin2> of two spin configurations (jitted).

    The inputs are multiplied element-wise and summed over their last axis.
    The remaining per-entry sums are then multiplied together into a single
    scalar. Equivalently, the result is the product over all leading indices of
    ``sum_k spin1[..., k] * spin2[..., k]``.

    NOTE(review): earlier documentation described the inputs as 2D ``(L1, L2)``
    arrays from ``Model.get_random_spins``. The tests in this notebook expect
    an overlap of exactly 1 for identical configurations and 0 after a flip.
    A plain ``(L1, L2)`` array of +/-1 values would instead give ``L2 ** L1``
    for identical inputs. The spins are therefore presumably stored as
    per-site one-hot vectors, e.g. shape ``(L1, L2, d)``. Confirm against
    ``Model``.

    Args:
        spin1: array-like spin configuration; must be broadcast-compatible
            with ``spin2``.
        spin2: array-like spin configuration.

    Returns:
        A 0-d jax array holding the overlap.
    """
    # Contract the last axis (the per-site components), then multiply the
    # per-site overlaps: the product vanishes as soon as any site disagrees.
    per_site_overlap = (spin1 * spin2).sum(axis=-1)
    return per_site_overlap.prod()


In [12]:
def test_generate_local_spins():
    """Smoke-test ``Model.generate_local_spins`` on a 4x4 model.

    Builds one random configuration and generates its local spins with
    ``change=1`` and ``change=2``. Each call must produce at least one
    configuration.

    TODO(review): exact sizes were previously asserted as 10 (change=1) and
    46 (change=2), but those checks had been commented out. Confirm the
    intended counts against ``Model`` and tighten these assertions.
    """
    model = Model(4, 4)
    spin = model.get_random_spins()

    local_spins = model.generate_local_spins(spin, change=1)
    assert len(local_spins) > 0, "generate_local_spins returned no spins for change=1."

    local_spins_change_2 = model.generate_local_spins(spin, change=2)
    assert len(local_spins_change_2) > 0, "generate_local_spins returned no spins for change=2."

def test_vdot():
    """Check ``Model.vdot`` against identical and single-flip configurations.

    A random 4x4 configuration must have overlap 1 with its own copy. It must
    have overlap 0 with the configuration returned by
    ``Model.flip_random_spin``.
    """
    model = Model(4, 4)
    reference = model.get_random_spins()

    self_overlap = model.vdot(reference, reference.copy())
    assert self_overlap == 1, "vdot of identical spins should be 1."

    flipped = model.flip_random_spin(reference)
    flipped_overlap = model.vdot(reference, flipped)
    assert flipped_overlap == 0, "vdot of orthogonal spins should be 0."

In [13]:
# Times the whole test_vdot call: Model construction, random spin generation,
# the copy, the flip and both Model.vdot calls -- not Model.vdot on its own.
%timeit test_vdot()

4.09 ms ± 296 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [24]:
def test_vdot_jit():
    """Check the jitted ``vdot_computation`` on identical and flipped spins.

    This mirrors ``test_vdot`` but calls the module-level jitted function
    instead of ``Model.vdot``.
    """
    model = Model(4, 4)
    reference = model.get_random_spins()

    # block_until_ready() waits for JAX's asynchronous dispatch to finish, so
    # each overlap is fully computed before it is compared (and timed).
    self_overlap = vdot_computation(reference, reference.copy()).block_until_ready()
    assert self_overlap == 1, "vdot of identical spins should be 1."

    flipped = model.flip_random_spin(reference)
    flipped_overlap = vdot_computation(reference, flipped).block_until_ready()
    assert flipped_overlap == 0, "vdot of orthogonal spins should be 0."

In [25]:
# Times the whole test_vdot_jit call, including Model construction and spin
# generation, so the difference from the test_vdot timing is not the vdot
# speed-up alone. vdot_computation is traced/compiled on its first call for a
# given input shape/dtype and the cached version is reused afterwards.
%timeit test_vdot_jit()

3.16 ms ± 124 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
