# In [1]:
import jax
import jax.numpy as jnp
import equinox as eqx
import dismech_jax

# jax.config.update("jax_enable_x64", True)

# Rod geometry and material, passed positionally to dismech_jax.
# NOTE(review): the meaning of each positional argument (e.g. length/radius,
# density/Young's modulus/Poisson ratio) is inferred from names only --
# confirm against the dismech_jax API.
geom = dismech_jax.Geometry(0.1, 1e-3)
mat = dismech_jax.Material(1000, 1e7, 0.5)
# Boundary conditions: three parallel arrays of 8 entries each. The first is
# an integer index array (negative values index from the end); the other two
# are presumably per-index target values and rates -- TODO confirm.
bc = dismech_jax.BC(
    jnp.array([0, 1, 2, 3, -4, -3, -2, -1]),
    jnp.array([0.0, 0.0, 0.0, 0.0, 1.0, -1e-2, 0.0, 0.0]),
    jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.0, 0.0]),
)

# Discretize the rod (N=11 is forwarded to from_geometry), build the DER model
# from the same geometry/material, and attach the boundary conditions above.
base_rod, q0, aux = dismech_jax.Rod.from_geometry(geom, mat, N=11)
der = base_rod.get_DER(geom, mat)
rod = base_rod.with_bc(bc)

# Ten evenly spaced values in [0, 1] handed to solve() -- presumably the
# load/continuation parameters at which states are returned.
lambdas = jnp.linspace(0.0, 1.0, 10)
# Reference trajectory produced by the DER model.
all_qs = dismech_jax.solve(der, lambdas, q0, aux, rod)
# Synthetic target: the reference plus standard-normal noise (mean 0, std 1)
# drawn from a fixed PRNG key, so every run sees the same noise.
# NOTE(review): the noise std is not scaled to the problem. If q holds
# positions on the order of the geometry values above, the noise will swamp
# the signal. It is also added elementwise to every entry, presumably
# including those pinned by `bc`. Confirm this is intended.
truth = all_qs + jax.random.normal(jax.random.PRNGKey(42), all_qs.shape)


@eqx.filter_jit
def loss(model: eqx.Module):
    """Return the misfit between ``model``'s solution and the target ``truth``.

    ``model`` is solved with the same ``lambdas``, ``q0``, ``aux`` and ``rod``
    that produced ``truth``, so only the model differs from the reference run.
    The result is the 2-norm of the flattened residual (``jnp.linalg.norm``
    with its default ``ord``/``axis``).

    Those module-level values are read when ``eqx.filter_jit`` traces this
    function. Rebinding them afterwards does not affect an already-compiled
    call.

    NOTE: as with any sqrt-of-sum-of-squares, the gradient is NaN when the
    residual is exactly zero.
    """
    residual = dismech_jax.solve(model, lambdas, q0, aux, rod) - truth
    return jnp.linalg.norm(residual)
