In [10]:
from test_model import *
from model import *
import jax

In [3]:
# Re-import Model explicitly so this cell does not depend on the
# `from model import *` star import in the first cell.
from model import Model

# Build a Model(2, 2) instance and draw one random spin configuration from it;
# binds the module-level names `model` and `spin` used by subsequent cells.
model = Model(2, 2)
spin = model.get_random_spins()

In [11]:
@jit
def _flip_spin_at(spin, i, j):
    def flip_if_one(spin, i, j):
        spin = spin.at[i, j].set(jnp.array([1, 0]))
        return spin

    def flip_if_zero(spin, i, j):
        spin = spin.at[i, j].set(jnp.array([0, 1]))
        return spin
    
    condition = jnp.array_equal(spin[i, j], jnp.array([0, 1]))
    spin = jax.lax.cond(condition, flip_if_one, flip_if_zero, spin, i, j)
    
    return spin

In [20]:
# Re-import Model explicitly rather than relying on the earlier star import.
from model import Model

# Rebind `model` and `spin` to a Model(4, 4) instance (presumably a 4x4
# lattice -- confirm against Model) for the flip-spin benchmarks that follow.
model = Model(4, 4)
spin = model.get_random_spins()

In [21]:
# Benchmark the jitted notebook helper. block_until_ready() waits for the
# result, so the timing covers the computation rather than only JAX's
# asynchronous dispatch.
%timeit _flip_spin_at(spin, 0, 0).block_until_ready()

15.7 µs ± 705 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [22]:
%timeit model.flip_spin_at(spin, 0, 0)

893 µs ± 21.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [5]:
# Model is defined in model.py; re-imported explicitly so this cell does not
# depend on the earlier star import.
from model import Model

# Round-trip check on a Model(2, 2): project_spin maps the (2, 2, 2) spin
# array to a length-4 vector, and unproject_spin must restore the original.
model = Model(2, 2)
spin = model.get_random_spins()
assert spin.shape == (2, 2, 2), f"Random spins have incorrect shape: {spin.shape}."
projected_spin = model.project_spin(spin)
# Messages name the stage that failed (previously all said "Random spins").
assert projected_spin.shape == (4,), f"Projected spin has incorrect shape: {projected_spin.shape}."
unprojected_spin = model.unproject_spin(projected_spin)
assert unprojected_spin.shape == (2, 2, 2), f"Unprojected spin has incorrect shape: {unprojected_spin.shape}."
assert jnp.array_equal(spin, unprojected_spin), "Unprojected spin does not match original spin."

In [6]:
# Time random spin generation.
# NOTE(review): unlike the blocking project_spin/unproject_spin benchmarks,
# this does not wait on the result; if get_random_spins returns a JAX array,
# the number may reflect only asynchronous dispatch.
%timeit model.get_random_spins()

101 µs ± 8.54 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [7]:
# Non-blocking timing of project_spin (returns as soon as work is dispatched);
# paired with the block_until_ready() variant to show the difference.
%timeit model.project_spin(spin)

8.75 µs ± 1.09 µs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [10]:
# Blocking timing of project_spin: waits for the result to be ready.
%timeit model.project_spin(spin).block_until_ready()

9.96 µs ± 654 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [12]:
# Non-blocking timing of unproject_spin; paired with the
# block_until_ready() variant.
%timeit model.unproject_spin(projected_spin)

12.4 µs ± 891 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [13]:
# Blocking timing of unproject_spin: waits for the result to be ready.
%timeit model.unproject_spin(projected_spin).block_until_ready()

11.7 µs ± 1.58 µs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [7]:
# Interactive reference lookup: signature/docs of jax.random.split (PRNG key
# splitting). Not needed for the analysis itself.
help(jax.random.split)

Help on function split in module jax._src.random:

split(key: 'KeyArrayLike', num: 'int | tuple[int, ...]' = 2) -> 'KeyArray'
    Splits a PRNG key into `num` new keys by adding a leading axis.
    
    Args:
      key: a PRNG key (from ``key``, ``split``, ``fold_in``).
      num: optional, a positive integer (or tuple of integers) indicating
        the number (or shape) of keys to produce. Defaults to 2.
    
    Returns:
      An array-like object of `num` new PRNG keys.

