In [1]:
# Reload edited modules (e.g. `hodel`) automatically before each cell runs,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import jax
import jax.numpy as jnp
import jaxtyping

import hodel.dismech as dismech

# Use float64 throughout (JAX defaults to float32). This has to happen before
# any jnp arrays are created in this notebook, e.g. the default `g` of
# `get_gravity` in the next cell.
# NOTE(review): `hodel.dismech` is imported above this line; if it builds jnp
# arrays at import time those would still be float32 — confirm.
jax.config.update("jax_enable_x64", True)

# Rod radius and shell thickness.
# NOTE(review): units assumed to be SI (metres) — confirm against dismech.
geom = dismech.Geometry(
    rod_r0=1e-3,
    shell_h=1e-3,
)

# Material parameters for rod and shell elements.
# NOTE(review): values look like a soft, (nearly) incompressible solid in SI
# units (rho = 1200 kg/m^3, E = 2 MPa, nu = 0.5) — confirm.
mat = dismech.Material(
    density=1200,
    youngs_rod=2e6,
    youngs_shell=2e6,
    poisson_rod=0.5,
    poisson_shell=0.5,
)

# Load the plate mesh and convert it to the dismech representation.
mesh = dismech.Mesh.from_txt("plate.txt")
# `from_legacy` returns five values; the fourth is not used in this notebook.
top, ini_state, mass, _, hinges = dismech.from_legacy(mesh, geom, mat)

# `hinges` is the stencil batch handed to `get_energy_fn` below. The check uses
# `type(...) is`, so subclasses of DESHinge are rejected as well.
assert type(hinges) is dismech.DESHinge

In [3]:
from dataclasses import dataclass
from jax.tree_util import register_dataclass

import hodel


def get_indices(
    q: jax.Array,
    top: dismech.Connectivity,
    fixed_nodes: jax.Array | None = None,
    fixed_edges: jax.Array | None = None,
) -> tuple[jax.Array, jax.Array]:
    """Split the DOF indices of ``q`` into a boundary set and a free set.

    Args:
        q: DOF vector; only its length is used.
        top: Connectivity; only ``top.edge_dofs`` is read, and only when
            ``fixed_edges`` is given.
        fixed_nodes: Node ids to clamp, expanded to DOF indices with
            ``dismech.map_node_to_dof``. ``None`` clamps no nodes.
        fixed_edges: Edge ids to clamp, offset by ``top.edge_dofs[0]``.
            ``None`` clamps no edges.
            NOTE(review): the offset assumes edge DOFs are stored contiguously
            starting at ``top.edge_dofs[0]`` — confirm.

    Returns:
        ``(idx_b, idx_f)``: the sorted, de-duplicated boundary indices and the
        sorted remaining (free) indices of ``range(q.shape[0])``.
    """
    no_dofs = jnp.array([], dtype=jnp.int32)

    if fixed_nodes is None:
        node_dofs = no_dofs
    else:
        node_dofs = dismech.map_node_to_dof(fixed_nodes)

    if fixed_edges is None:
        edge_dofs = no_dofs
    else:
        edge_dofs = top.edge_dofs[0] + fixed_edges

    boundary = jnp.union1d(node_dofs, edge_dofs)
    free = jnp.setdiff1d(jnp.arange(q.shape[0]), boundary)
    return boundary, free


# Clamp nodes 0, 1 and 2 (their DOFs via `map_node_to_dof`); all other DOFs are free.
idx_b, idx_f = get_indices(ini_state.q, top, fixed_nodes=jnp.array([0, 1, 2]))


@register_dataclass
@dataclass(frozen=True)
class TripletAux:
    """Auxiliary data threaded through the HODEL solver callbacks.

    Bundles the mesh connectivity with the free/boundary split of the DOF
    vector. ``get_q`` uses the index sets to reassemble a full ``q`` from its
    free (``xf``) and boundary (``xb``) parts, and the energy callback passes
    ``top`` to ``ini_state.update``. Because it is registered with
    ``register_dataclass`` it is a JAX pytree, so it can be passed (unmapped)
    through ``jax.vmap`` and the solver like any other argument.
    """

    top: dismech.Connectivity  # connectivity, needed by state.update()
    idx_f: jax.Array  # indices of the free DOFs in q (from get_indices)
    idx_b: jax.Array  # indices of the boundary (fixed) DOFs in q (from get_indices)


def get_gravity(mass: jax.Array, g: jax.Array = jnp.array([0.0, 0.0, -9.81])):
    return mass * jnp.concat(
        [
            jnp.tile(g, mesh.nodes.shape[0]),
            jnp.zeros(mesh.rod_edges.shape[0]),
        ]
    )


def get_W(lambda_: jax.Array, aux: TripletAux):
    """External load on the free DOFs: gravity scaled by the load parameter.

    Uses the notebook-global ``mass`` and keeps only the entries at
    ``aux.idx_f``.
    """
    full_gravity = get_gravity(mass)
    free_gravity = full_gravity[aux.idx_f]
    return free_gravity * lambda_


def fixed_0(lambda_: jax.Array, aux: TripletAux) -> jax.Array:
    """Boundary DOF values, held at the initial configuration.

    ``lambda_`` is unused: the returned values are the same for every load
    step. It is accepted so the function matches the ``get_xb_fn`` call
    pattern used with ``hodel.HODEL`` and ``jax.vmap`` in this notebook.
    """
    q_initial = ini_state.q
    return q_initial[aux.idx_b]


def get_q(xf: jax.Array, xb: jax.Array, aux: TripletAux) -> jax.Array:
    """Assemble the full DOF vector ``q`` from its free and boundary parts.

    ``xf`` is scattered to positions ``aux.idx_f`` and ``xb`` to
    ``aux.idx_b``. The result has ``xf``'s dtype and length
    ``len(aux.idx_f) + len(aux.idx_b)``. A position covered by neither index
    set keeps the ``jnp.empty`` fill value, which is zero in JAX.
    """
    n_dofs = aux.idx_f.shape[0] + aux.idx_b.shape[0]
    with_free = jnp.empty((n_dofs,), xf.dtype).at[aux.idx_f].set(xf)
    return with_free.at[aux.idx_b].set(xb)


def get_energy_fn(stencils):
    """Build the total-energy callback handed to ``hodel.HODEL``.

    Args:
        stencils: Batch of energy stencils (here ``hinges``), mapped over its
            leading axis with ``jax.vmap``. Each element must provide
            ``get_energy(state, Theta)``.

    Returns:
        ``get_batch_energy(xf, xb, Theta, aux, carry)``, which rebuilds the
        state from the configuration assembled out of ``xf``/``xb`` and
        returns the sum of all stencil energies. ``carry`` is ignored.
    """

    def get_batch_energy(
        xf: jax.Array,
        xb: jax.Array,
        Theta: jaxtyping.PyTree,
        aux: TripletAux,
        carry: None,
    ) -> jax.Array:
        current = ini_state.update(get_q(xf, xb, aux), aux.top)

        # Closure over `current` and `Theta`; only the stencils are mapped.
        def energy_of(stencil):
            return stencil.get_energy(current, Theta)

        per_stencil = jax.vmap(energy_of)(stencils)
        return jnp.sum(per_stencil)

    return get_batch_energy


# Connectivity plus the free/boundary split, shared by all solver callbacks.
aux = TripletAux(top=top, idx_f=idx_f, idx_b=idx_b)
# Starting guess for the free DOFs, taken from the initial state.
xf0 = ini_state.q[idx_f]
sim = hodel.HODEL(
    get_energy_fn(hinges),
    get_W_fn=get_W,
    get_xb_fn=fixed_0,
)

In [6]:
# Sweep the load parameter passed to `get_W` from 0 (unloaded) to 1 in 100 values.
lambdas = jnp.linspace(0.0, 1.0, num=100)
xf_stars = sim.solve(lambdas, xf0, None, aux, nsteps=5)

# Boundary DOFs for every lambda (constant here), then reassemble the full q's.
xb_stars = jax.vmap(fixed_0, in_axes=(0, None))(lambdas, aux)
qs = jax.vmap(get_q, in_axes=(0, 0, None))(xf_stars, xb_stars, aux)

dismech.animate(lambdas, qs, top)