In [1]:
import json
from time import perf_counter
from time import time as timer

import jax
import jax.numpy as jnp
from IPython.display import clear_output

import openferro as of
from openferro.engine.elastic import *
from openferro.engine.ferroelectric import *
from openferro.units import Constants
from openferro.parallelism import DeviceMesh

##########################################################################################
## Define the lattice  (L x L x L simple cubic lattice, L = 256)
##########################################################################################
L = 256
# Hydrostatic pressure handed to the NPT simulation; applied to correct for the error of the
# DFT approximation used to parameterize the lattice model.
# NOTE(review): units follow openferro's pressure convention — confirm.
hydropres = -4.8e4

# Model parameters (lattice constant, on-site, short-range, Born, elastic and
# elastic-dipole coefficients). A context manager closes the file instead of leaking it.
with open('BaTiO3.json', encoding='utf-8') as config_file:
    config = json.load(config_file)

# Three orthogonal lattice vectors, each of length config['lattice']['a1'].
latt_vecs = jnp.eye(3) * config['lattice']['a1']
latt = of.SimpleCubic3D(L, L, L, latt_vecs[0], latt_vecs[1], latt_vecs[2])


In [None]:
bto = of.System(latt)
##########################################################################################
## Define the fields
##########################################################################################
# One fictitious mass per lattice site, shared by the dipole and local-strain fields.
# The single global-strain variable gets that mass times the number of sites (L**3).
site_mass = 200 * Constants.amu

dipole_field = bto.add_field(
    ID="dipole", ftype="Rn", dim=3, value=0.0, mass=site_mass,
)
lstrain_field = bto.add_field(
    ID="lstrain", ftype="LocalStrain3D", value=0.0, mass=site_mass,
)

# Initial global strain: 1% on the first three components, zero on the last three.
# NOTE(review): presumably Voigt ordering (xx, yy, zz, yz, xz, xy) — confirm in openferro.
initial_gstrain = jnp.array([0.01, 0.01, 0.01, 0, 0, 0])
gstrain = bto.add_global_strain(value=initial_gstrain, mass=site_mass * L**3)

##########################################################################################
## Define the Hamiltonian
##########################################################################################
# Each parameter group is looked up from `config` right before the terms that use it.

# On-site dipole energy.
onsite_params = config["onsite"]
bto.add_dipole_onsite_interaction(
    'self_onsite', field_ID="dipole",
    K2=onsite_params["k2"], alpha=onsite_params["alpha"], gamma=onsite_params["gamma"],
)

# Short-range dipole couplings for the 1st, 2nd and 3rd neighbour shells.
short_range_params = config["short_range"]
bto.add_dipole_interaction_1st_shell(
    'short_range_1', field_ID="dipole",
    j1=short_range_params["j1"], j2=short_range_params["j2"],
)
bto.add_dipole_interaction_2nd_shell(
    'short_range_2', field_ID="dipole",
    j3=short_range_params["j3"], j4=short_range_params["j4"], j5=short_range_params["j5"],
)
bto.add_dipole_interaction_3rd_shell(
    'short_range_3', field_ID="dipole",
    j6=short_range_params["j6"], j7=short_range_params["j7"],
)

# Long-range dipole-dipole term, scaled by Z*^2 / epsilon_inf.
born_params = config["born"]
bto.add_dipole_dipole_interaction(
    'dipole_ewald', field_ID="dipole",
    prefactor=born_params["Z_star"]**2 / born_params["epsilon_inf"],
)

# Elastic energy and strain-dipole coupling: "homo" terms act on the global strain
# field, "inhomo" terms on the local strain field; both reuse the same coefficients.
elastic_params = config["elastic"]
bto.add_homo_elastic_interaction(
    'homo_elastic', field_ID="gstrain",
    B11=elastic_params["B11"], B12=elastic_params["B12"], B44=elastic_params["B44"],
)
coupling_params = config["elastic_dipole"]
bto.add_homo_strain_dipole_interaction(
    'homo_strain_dipole', field_1_ID="gstrain", field_2_ID="dipole",
    B1xx=coupling_params["B1xx"], B1yy=coupling_params["B1yy"], B4yz=coupling_params["B4yz"],
)
bto.add_inhomo_elastic_interaction(
    'inhomo_elastic', field_ID="lstrain",
    B11=elastic_params["B11"], B12=elastic_params["B12"], B44=elastic_params["B44"],
)
bto.add_inhomo_strain_dipole_interaction(
    'inhomo_strain_dipole', field_1_ID="lstrain", field_2_ID="dipole",
    B1xx=coupling_params["B1xx"], B1yy=coupling_params["B1yy"], B4yz=coupling_params["B4yz"],
)

##########################################################################################
## NPT simulation setup
##########################################################################################
dt = 0.002
temperature = 300

# Every field uses the same 'isothermal' integrator with a shared time step and target
# temperature; only `tau` differs (presumably a thermostat relaxation time — confirm).
# NOTE(review): dt / tau units follow openferro's internal convention — confirm.
integrator_taus = (
    (dipole_field, 0.1),
    (lstrain_field, 1),
    (gstrain, 1),
)
for integrated_field, field_tau in integrator_taus:
    integrated_field.set_integrator('isothermal', dt=dt, temp=temperature, tau=field_tau)

# Langevin NPT driver held at the hydrostatic pressure `hydropres`; velocities are drawn
# in 'gaussian' mode at the target temperature.
simulation = of.SimulationNPTLangevin(bto, pressure=hydropres)
simulation.init_velocity(mode='gaussian', temp=temperature)

##########################################################################################
## Run
##########################################################################################
# Single source of truth for the production-run length; the report below reuses it so the
# printed message cannot drift from what was actually run.
n_steps = 500

# A 1-step warm-up absorbs one-time costs (e.g. JAX/XLA compilation) so the timed
# production run reflects steady-state speed. perf_counter is a monotonic clock meant for
# interval timing, unlike time.time which can jump if the system clock is adjusted.
# NOTE(review): block_until_ready only waits on the arrays `simulation.run` returns; if it
# returns None, confirm run() synchronizes internally or the timings may be understated.
t0_cpu_init = perf_counter()
jax.block_until_ready(simulation.run(1, profile=False))
t1_cpu_init = perf_counter()

t0_cpu_run = perf_counter()
jax.block_until_ready(simulation.run(n_steps, profile=False))
t1_cpu_run = perf_counter()

##########################################################################################
## Report
##########################################################################################
clear_output()   # Clean the "constant folding warning" from JAX. It is expected for large lattices.

print(f"initialization takes: {t1_cpu_init-t0_cpu_init} seconds")
print(f"{n_steps} steps takes: {t1_cpu_run-t0_cpu_run} seconds")

2025-01-15 22:58:09.368440: E external/xla/xla/service/slow_operation_alarm.cc:65] Constant folding an instruction is taking > 1s:

  %slice.71 = f32[256,256,256,1]{3,2,1,0} slice(f32[256,256,256,6]{3,2,1,0} %constant.17), slice={[0:256], [0:256], [0:256], [3:4]}, metadata={op_name="jit(energy_engine)/jit(main)/jvp(jit(energy_engine))/jit(_ewald_ksum)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1, 2), collapsed_slice_dims=(3,), start_index_map=(3,)) slice_sizes=(256, 256, 256, 1) unique_indices=True indices_are_sorted=True mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/scratch/gpfs/pinchenx/OpenFerro/openferro/engine/ewald.py" source_line=124}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If 

In [None]:
import os


def cpu_name_lines(cpuinfo_path: str = "/proc/cpuinfo") -> list[str]:
    """Return the distinct lines of a cpuinfo file that contain 'name', in first-seen order.

    Python equivalent of ``grep 'name' <file> | uniq``, except that duplicates are removed
    globally rather than only when adjacent.

    Args:
        cpuinfo_path: Path of the cpuinfo-style text file to scan.

    Returns:
        Matching lines with their trailing newline stripped.

    Raises:
        OSError: If the file cannot be opened.
    """
    with open(cpuinfo_path, encoding="utf-8", errors="replace") as cpuinfo:
        matches = (line.rstrip("\n") for line in cpuinfo if "name" in line)
        # dict.fromkeys deduplicates while preserving order; consume inside the `with`.
        return list(dict.fromkeys(matches))


# os.system() writes to the kernel process's stdout (not the notebook cell) and only returns
# the exit code, so read the file in Python and print the result in the cell instead.
if os.path.exists("/proc/cpuinfo"):
    print("\n".join(cpu_name_lines()))
else:
    print("/proc/cpuinfo not available on this platform")